Add support to read database configuration from homeserver.yaml (#91)

This commit is contained in:
Victor Freire 2023-03-14 12:38:51 -03:00 committed by GitHub
parent fa27fa1a92
commit 6f8b3821aa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -394,13 +394,46 @@ def get_sqlite_conn(parser):
return sqlite_conn
def get_homeserver_db_conn(parser):
"""Attempt to get a connection based on homeserver.yaml to Synapse's
database, or exit.
"""
def get_synapse_db_conn(parser):
try:
with open("homeserver.yaml") as f:
homeserver_yaml = yaml.safe_load(f)
except FileNotFoundError:
parser.error("Could not find homeserver.yaml")
except yaml.YAMLError as e:
parser.error("homeserver.yaml is not valid yaml: %s" % (e,))
try:
database_name = homeserver_yaml["database"]["name"]
database_args = homeserver_yaml["database"]["args"]
if database_name == "sqlite3":
synapse_db_conn = sqlite3.connect(database=database_args["database"])
else:
synapse_db_conn = psycopg2.connect(
user=database_args["user"],
password=database_args["password"],
database=database_args["database"],
host=database_args["host"],
)
except sqlite3.OperationalError as e:
parser.error("Could not connect to sqlite3 database: %s" % (e,))
except psycopg2.Error as e:
parser.error("Could not connect to postgres database: %s" % (e,))
return synapse_db_conn
def get_database_db_conn(parser):
"""Attempt to get a connection based on database.yaml to Synapse's
database, or exit.
"""
try:
database_yaml = yaml.safe_load(open("database.yaml"))
with open("database.yaml") as f:
database_yaml = yaml.safe_load(f)
except FileNotFoundError:
parser.error("Could not find database.yaml")
except yaml.YAMLError as e:
@ -420,6 +453,17 @@ def get_synapse_db_conn(parser):
return synapse_db_conn
def get_synapse_db_conn(parser):
"""Attempt to get a connection based on database.yaml or homeserver.yaml
to Synapse's database, or exit.
"""
if os.path.isfile("database.yaml"):
conn = get_database_db_conn(parser)
else:
conn = get_homeserver_db_conn(parser)
return conn
def main():
parser = argparse.ArgumentParser(prog="s3_media_upload")