#!/usr/bin/env python import argparse import datetime import os import sqlite3 import boto3 import botocore import humanize import psycopg2 import tqdm import yaml # Schema for our sqlite database cache SCHEMA = """ CREATE TABLE IF NOT EXISTS media ( origin TEXT NOT NULL, -- empty string if local media media_id TEXT NOT NULL, filesystem_id TEXT NOT NULL, -- Type is "local" or "remote" type TEXT NOT NULL, -- indicates whether the media and all its thumbnails have been deleted from the -- local cache known_deleted BOOLEAN NOT NULL ); CREATE UNIQUE INDEX IF NOT EXISTS media_id_idx ON media(origin, media_id); CREATE INDEX IF NOT EXISTS deleted_idx ON media(known_deleted); """ progress = True def parse_duration(string): """Parse a string into a duration supports suffix of d, m or y. """ suffix = string[-1] number = string[:-1] try: number = int(number) except ValueError: raise argparse.ArgumentTypeError( "duration must be an integer followed by a 'd', 'm' or 'y' suffix" ) now = datetime.datetime.now() if suffix == "d": then = now - datetime.timedelta(days=number) elif suffix == "m": then = now - datetime.timedelta(days=30 * number) elif suffix == "y": then = now - datetime.timedelta(days=365 * number) else: raise argparse.ArgumentTypeError("duration must end in 'd', 'm' or 'y'") return then def mark_as_deleted(sqlite_conn, origin, media_id): with sqlite_conn: sqlite_conn.execute( """ UPDATE media SET known_deleted = ? WHERE origin = ? AND media_id = ? """, (True, origin, media_id), ) def get_not_deleted_count(sqlite_conn): """Get count of all rows in our cache that we don't think have been deleted """ cur = sqlite_conn.cursor() cur.execute( """ SELECT COALESCE(count(*), 0) FROM media WHERE NOT known_deleted """ ) (count,) = cur.fetchone() return count def get_not_deleted(sqlite_conn): """Get all rows in our cache that we don't think have been deleted """ cur = sqlite_conn.cursor() cur.execute( """ SELECT origin, media_id, filesystem_id, type FROM media WHERE NOT known_deleted """ ) return cur def to_path(origin, filesystem_id, m_type): """Get a relative path to the given media """ if m_type == "local": file_path = os.path.join( "local_content", filesystem_id[:2], filesystem_id[2:4], filesystem_id[4:], ) elif m_type == "remote": file_path = os.path.join( "remote_content", origin, filesystem_id[:2], filesystem_id[2:4], filesystem_id[4:], ) else: raise Exception("Unexpected media type %r", m_type) return file_path def to_thumbnail_dir(origin, filesystem_id, m_type): """Get a relative path to the given media's thumbnail directory """ if m_type == "local": thumbnail_path = os.path.join( "local_thumbnails", filesystem_id[:2], filesystem_id[2:4], filesystem_id[4:], ) elif m_type == "remote": thumbnail_path = os.path.join( "remote_thumbnail", origin, filesystem_id[:2], filesystem_id[2:4], filesystem_id[4:], ) else: raise Exception("Unexpected media type %r", m_type) return thumbnail_path def get_local_files(base_path, origin, filesystem_id, m_type): """Get a list of relative paths to undeleted files for the given media """ local_files = [] original_path = to_path(origin, filesystem_id, m_type) if os.path.exists(os.path.join(base_path, original_path)): local_files.append(original_path) thumbnail_path = to_thumbnail_dir(origin, filesystem_id, m_type) try: with os.scandir(os.path.join(base_path, thumbnail_path)) as dir_entries: for dir_entry in dir_entries: if dir_entry.is_file(): local_files.append(os.path.join(thumbnail_path, dir_entry.name)) except FileNotFoundError: # The thumbnail directory does not exist pass except NotADirectoryError: # The thumbnail directory is not a directory for some reason pass return local_files def check_file_in_s3(s3, bucket, key, extra_args): """Check the file exists in S3 (though it could be different) """ try: if "SSECustomerKey" in extra_args and "SSECustomerAlgorithm" in extra_args: s3.head_object( Bucket=bucket, Key=key, SSECustomerKey=extra_args["SSECustomerKey"], SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"], ) else: s3.head_object(Bucket=bucket, Key=key) except botocore.exceptions.ClientError as e: if int(e.response["Error"]["Code"]) == 404: return False raise return True def run_write(sqlite_conn, output_file): """Entry point for write command """ for origin, _, filesystem_id, m_type in get_not_deleted(sqlite_conn): file_path = to_path(origin, filesystem_id, m_type) print(file_path, file=output_file) # Print thumbnail directories with a trailing '/' thumbnail_path = to_thumbnail_dir(origin, filesystem_id, m_type) thumbnail_path = os.path.join(thumbnail_path, "") print(thumbnail_path, file=output_file) def run_update_db(synapse_db_conn, sqlite_conn, before_date): """Entry point for update-db command """ local_sql = """ SELECT '', media_id, media_id FROM local_media_repository WHERE COALESCE(last_access_ts, created_ts) < %s AND url_cache IS NULL """ remote_sql = """ SELECT media_origin, media_id, filesystem_id FROM remote_media_cache WHERE COALESCE(last_access_ts, created_ts) < %s """ last_access_ts = int(before_date.timestamp() * 1000) print( "Syncing files that haven't been accessed since:", before_date.isoformat(" "), ) update_count = 0 with sqlite_conn: sqlite_cur = sqlite_conn.cursor() if isinstance(synapse_db_conn, sqlite3.Connection): synapse_db_curs = synapse_db_conn.cursor() for sql, mtype in ((local_sql, "local"), (remote_sql, "remote")): synapse_db_curs.execute(sql.replace("%s", "?"), (last_access_ts,)) update_count += update_db_process_rows( mtype, sqlite_cur, synapse_db_curs ) else: with synapse_db_conn.cursor() as synapse_db_curs: for sql, mtype in ((local_sql, "local"), (remote_sql, "remote")): synapse_db_curs.execute(sql, (last_access_ts,)) update_count += update_db_process_rows( mtype, sqlite_cur, synapse_db_curs ) print("Synced", update_count, "new rows") synapse_db_conn.close() def update_db_process_rows(mtype, sqlite_cur, synapse_db_curs): """Process rows extracted from Synapse's database and insert them in cache """ update_count = 0 for (origin, media_id, filesystem_id) in synapse_db_curs: sqlite_cur.execute( """ INSERT OR IGNORE INTO media (origin, media_id, filesystem_id, type, known_deleted) VALUES (?, ?, ?, ?, ?) """, (origin, media_id, filesystem_id, mtype, False), ) update_count += sqlite_cur.rowcount return update_count def run_check_delete(sqlite_conn, base_path): """Entry point for check-deleted command """ deleted = [] if progress: it = tqdm.tqdm( get_not_deleted(sqlite_conn), unit="files", total=get_not_deleted_count(sqlite_conn), ) else: it = get_not_deleted(sqlite_conn) print("Checking on ", get_not_deleted_count(sqlite_conn), " undeleted files") for origin, media_id, filesystem_id, m_type in it: local_files = get_local_files(base_path, origin, filesystem_id, m_type) if not local_files: deleted.append((origin, media_id)) with sqlite_conn: sqlite_conn.executemany( """ UPDATE media SET known_deleted = ? WHERE origin = ? AND media_id = ? """, ((True, o, m) for o, m in deleted), ) print("Updated", len(deleted), "as deleted") def run_upload(s3, bucket, sqlite_conn, base_path, s3_prefix, extra_args, should_delete): """Entry point for upload command """ total = get_not_deleted_count(sqlite_conn) uploaded_media = 0 uploaded_files = 0 uploaded_bytes = 0 deleted_media = 0 deleted_files = 0 deleted_bytes = 0 # This is a progress bar if progress: it = tqdm.tqdm(get_not_deleted(sqlite_conn), unit="files", total=total) else: print("Uploading ", total, " files") it = get_not_deleted(sqlite_conn) for origin, media_id, filesystem_id, m_type in it: local_files = get_local_files(base_path, origin, filesystem_id, m_type) if not local_files: mark_as_deleted(sqlite_conn, origin, media_id) continue # Counters of uploaded and deleted files for this media only media_uploaded_files = 0 media_deleted_files = 0 for rel_file_path in local_files: local_path = os.path.join(base_path, rel_file_path) key = s3_prefix + rel_file_path if not check_file_in_s3(s3, bucket, key, extra_args): try: s3.upload_file( local_path, bucket, key, ExtraArgs=extra_args, ) except Exception as e: print("Failed to upload file %s: %s", local_path, e) continue media_uploaded_files += 1 uploaded_files += 1 uploaded_bytes += os.path.getsize(local_path) if should_delete: size = os.path.getsize(local_path) os.remove(local_path) try: # This may have lead to an empty directory, so lets remove all # that are empty os.removedirs(os.path.dirname(local_path)) except Exception: # The directory might not be empty, or maybe we don't have # permission. Either way doesn't really matter. pass media_deleted_files += 1 deleted_files += 1 deleted_bytes += size if media_uploaded_files: uploaded_media += 1 if media_deleted_files: deleted_media += 1 if media_deleted_files == len(local_files): # Mark as deleted only if *all* the local files have been deleted mark_as_deleted(sqlite_conn, origin, media_id) print("Uploaded", uploaded_media, "media out of", total) print("Uploaded", uploaded_files, "files") print("Uploaded", humanize.naturalsize(uploaded_bytes, gnu=True)) print("Deleted", deleted_media, "media") print("Deleted", deleted_files, "files") print("Deleted", humanize.naturalsize(deleted_bytes, gnu=True)) def get_sqlite_conn(parser): """Attempt to get a sqlite connection to cache.db, or exit. """ try: sqlite_conn = sqlite3.connect("cache.db") sqlite_conn.executescript(SCHEMA) except sqlite3.Error as e: parser.error("Could not open 'cache.db' as sqlite DB: %s" % (e,)) return sqlite_conn def get_homeserver_db_conn(parser): """Attempt to get a connection based on homeserver.yaml to Synapse's database, or exit. """ try: with open("homeserver.yaml") as f: homeserver_yaml = yaml.safe_load(f) except FileNotFoundError: parser.error("Could not find homeserver.yaml") except yaml.YAMLError as e: parser.error("homeserver.yaml is not valid yaml: %s" % (e,)) try: database_name = homeserver_yaml["database"]["name"] database_args = homeserver_yaml["database"]["args"] if database_name == "sqlite3": synapse_db_conn = sqlite3.connect(database=database_args["database"]) else: synapse_db_conn = psycopg2.connect( user=database_args["user"], password=database_args["password"], database=database_args["database"], host=database_args["host"], ) except sqlite3.OperationalError as e: parser.error("Could not connect to sqlite3 database: %s" % (e,)) except psycopg2.Error as e: parser.error("Could not connect to postgres database: %s" % (e,)) return synapse_db_conn def get_database_db_conn(parser): """Attempt to get a connection based on database.yaml to Synapse's database, or exit. """ try: with open("database.yaml") as f: database_yaml = yaml.safe_load(f) except FileNotFoundError: parser.error("Could not find database.yaml") except yaml.YAMLError as e: parser.error("database.yaml is not valid yaml: %s" % (e,)) try: if "sqlite" in database_yaml: synapse_db_conn = sqlite3.connect(**database_yaml["sqlite"]) elif "postgres" in database_yaml: synapse_db_conn = psycopg2.connect(**database_yaml["postgres"]) else: synapse_db_conn = psycopg2.connect(**database_yaml) except sqlite3.OperationalError as e: parser.error("Could not connect to sqlite3 database: %s" % (e,)) except psycopg2.Error as e: parser.error("Could not connect to postgres database: %s" % (e,)) return synapse_db_conn def get_synapse_db_conn(parser): """Attempt to get a connection based on database.yaml or homeserver.yaml to Synapse's database, or exit. """ if os.path.isfile("database.yaml"): conn = get_database_db_conn(parser) else: conn = get_homeserver_db_conn(parser) return conn def main(): parser = argparse.ArgumentParser(prog="s3_media_upload") parser.add_argument( "--no-progress", help="do not show progress bars", action="store_true", dest="no_progress", ) subparsers = parser.add_subparsers(help="command to run", dest="cmd") update_db_parser = subparsers.add_parser( "update-db", help="Syncs rows from database to local cache" ) update_db_parser.add_argument( "duration", type=parse_duration, help="Fetch rows that haven't been accessed in the duration given," " accepts duration of the form of e.g. 1m (one month). Valid suffixes" " are d, m or y. NOTE: Currently does not remove entries from the cache", ) deleted_parser = subparsers.add_parser( "check-deleted", help="Check whether files in the local cache still exist under given path", ) deleted_parser.add_argument( "base_path", help="Base path of the media store directory" ) update_parser = subparsers.add_parser( "update", help="Updates local cache. Equivalent to running update-db and" " check-deleted", ) update_parser.add_argument( "base_path", help="Base path of the media store directory" ) update_parser.add_argument( "duration", type=parse_duration, help="Fetch rows that haven't been accessed in the duration given," " accepts duration of the form of e.g. 1m (one month). Valid suffixes" " are d, m or y. NOTE: Currently does not remove entries from the cache", ) write_parser = subparsers.add_parser( "write", help="Outputs all file and directory paths in the local cache that we may not" " have deleted. check-deleted should be run first to update cache.", ) write_parser.add_argument( "out", type=argparse.FileType("w", encoding="UTF-8"), default="-", nargs="?", help="File to output list to, or '-' for stdout", ) upload_parser = subparsers.add_parser( "upload", help="Uploads media to s3 based on local cache" ) upload_parser.add_argument( "base_path", help="Base path of the media store directory" ) upload_parser.add_argument("bucket", help="S3 bucket to upload to") upload_parser.add_argument("--prefix", help="Prefix of files in the bucket", default="") upload_parser.add_argument( "--storage-class", help="S3 storage class to use", nargs="?", choices=[ "STANDARD", "REDUCED_REDUNDANCY", "STANDARD_IA", "ONEZONE_IA", "INTELLIGENT_TIERING", ], default="STANDARD", ) upload_parser.add_argument( "--sse-customer-key", help="SSE-C key to use", ) upload_parser.add_argument( "--sse-customer-algo", help="Algorithm for SSE-C, only used if sse-customer-key is also specified", default="AES256", ) upload_parser.add_argument( "--delete", action="store_const", const=True, help="Deletes local copy from media store on succesful upload", ) upload_parser.add_argument( "--endpoint-url", help="S3 endpoint url to use", default=None ) args = parser.parse_args() if args.no_progress: global progress progress = False if args.cmd == "write": sqlite_conn = get_sqlite_conn(parser) run_write(sqlite_conn, args.out) return if args.cmd == "update-db": sqlite_conn = get_sqlite_conn(parser) synapse_db_conn = get_synapse_db_conn(parser) run_update_db(synapse_db_conn, sqlite_conn, args.duration) return if args.cmd == "check-deleted": sqlite_conn = get_sqlite_conn(parser) run_check_delete(sqlite_conn, args.base_path) return if args.cmd == "update": sqlite_conn = get_sqlite_conn(parser) synapse_db_conn = get_synapse_db_conn(parser) run_update_db(synapse_db_conn, sqlite_conn, args.duration) run_check_delete(sqlite_conn, args.base_path) return if args.cmd == "upload": sqlite_conn = get_sqlite_conn(parser) s3 = boto3.client("s3", endpoint_url=args.endpoint_url) extra_args = {"StorageClass": args.storage_class} if args.sse_customer_key: extra_args["SSECustomerKey"] = args.sse_customer_key if args.sse_customer_algo: extra_args["SSECustomerAlgorithm"] = args.sse_customer_algo else: extra_args["SSECustomerAlgorithm"] = "AES256" run_upload( s3, args.bucket, sqlite_conn, args.base_path, args.prefix, extra_args, should_delete=args.delete, ) return parser.error("Valid subcommand must be specified") if __name__ == "__main__": main()