diff --git a/scripts/s3_media_upload b/scripts/s3_media_upload index 40a5772..42ba3c3 100755 --- a/scripts/s3_media_upload +++ b/scripts/s3_media_upload @@ -394,13 +394,46 @@ def get_sqlite_conn(parser): return sqlite_conn +def get_homeserver_db_conn(parser): + """Attempt to get a connection based on homeserver.yaml to Synapse's + database, or exit. + """ -def get_synapse_db_conn(parser): + try: + with open("homeserver.yaml") as f: + homeserver_yaml = yaml.safe_load(f) + except FileNotFoundError: + parser.error("Could not find homeserver.yaml") + except yaml.YAMLError as e: + parser.error("homeserver.yaml is not valid yaml: %s" % (e,)) + + try: + database_name = homeserver_yaml["database"]["name"] + database_args = homeserver_yaml["database"]["args"] + if database_name == "sqlite3": + synapse_db_conn = sqlite3.connect(database=database_args["database"]) + else: + synapse_db_conn = psycopg2.connect( + user=database_args["user"], + password=database_args["password"], + database=database_args["database"], + host=database_args["host"], + ) + except sqlite3.OperationalError as e: + parser.error("Could not connect to sqlite3 database: %s" % (e,)) + except psycopg2.Error as e: + parser.error("Could not connect to postgres database: %s" % (e,)) + + return synapse_db_conn + +def get_database_db_conn(parser): """Attempt to get a connection based on database.yaml to Synapse's database, or exit. """ + try: - database_yaml = yaml.safe_load(open("database.yaml")) + with open("database.yaml") as f: + database_yaml = yaml.safe_load(f) except FileNotFoundError: parser.error("Could not find database.yaml") except yaml.YAMLError as e: @@ -420,6 +453,17 @@ def get_synapse_db_conn(parser): return synapse_db_conn +def get_synapse_db_conn(parser): + """Attempt to get a connection based on database.yaml or homeserver.yaml + to Synapse's database, or exit. + """ + + if os.path.isfile("database.yaml"): + conn = get_database_db_conn(parser) + else: + conn = get_homeserver_db_conn(parser) + + return conn def main(): parser = argparse.ArgumentParser(prog="s3_media_upload")