Mirror of https://github.com/matrix-org/synapse-s3-storage-provider.git, synced 2024-10-23 07:29:40 +00:00
Refactor eargs to extra_args

parent ff28d0a02f
commit 5957773cf7
2 changed files with 22 additions and 22 deletions
@@ -62,7 +62,7 @@ class S3StorageProviderBackend(StorageProvider):
     def __init__(self, hs, config):
         self.cache_directory = hs.config.media.media_store_path
         self.bucket = config["bucket"]
-        self.eargs = config["eargs"]
+        self.extra_args = config["extra_args"]
         self.api_kwargs = {}

         if "region_name" in config:
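For context, a minimal sketch (not part of the commit) of the config dict the provider now expects after the rename; the bucket, region and storage-class values are illustrative:

# Illustrative only: the provider reads "extra_args" where it previously read "eargs".
config = {
    "bucket": "matrix-media",                       # assumed bucket name
    "region_name": "us-east-1",                     # optional, handled just below the changed lines
    "extra_args": {"StorageClass": "STANDARD_IA"},  # renamed from "eargs"
}

provider_extra_args = config["extra_args"]  # mirrors the new __init__ line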
@@ -123,7 +123,7 @@ class S3StorageProviderBackend(StorageProvider):
                     Filename=os.path.join(self.cache_directory, path),
                     Bucket=self.bucket,
                     Key=path,
-                    ExtraArgs=self.eargs,
+                    ExtraArgs=self.extra_args,
                 )

         return make_deferred_yieldable(
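The changed call is a standard boto3 upload_file with ExtraArgs; a minimal standalone sketch, where the bucket, key and local path are assumptions rather than values from the commit:

import boto3

s3 = boto3.client("s3")
extra_args = {"StorageClass": "STANDARD_IA"}  # same shape as self.extra_args

# upload_file streams the local file to S3 and applies ExtraArgs to the PUT.
s3.upload_file(
    Filename="/var/lib/synapse/media_store/local_content/aa/bb/example",  # assumed path
    Bucket="matrix-media",                                                # assumed bucket
    Key="local_content/aa/bb/example",
    ExtraArgs=extra_args,
)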
@@ -138,7 +138,7 @@ class S3StorageProviderBackend(StorageProvider):

         def _get_file():
             s3_download_task(
-                self._get_s3_client(), self.bucket, path, self.eargs, d, logcontext
+                self._get_s3_client(), self.bucket, path, self.extra_args, d, logcontext
             )

         self._s3_pool.callInThread(_get_file)
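The surrounding pattern here (blocking S3 work pushed to a thread pool, result delivered through a Deferred) can be sketched generically with Twisted's deferToThread; this illustrates the pattern only and is not the provider's exact code:

from twisted.internet import reactor, threads

def blocking_download(key, extra_args):
    # Stand-in for the boto3 get_object call; extra_args would carry the SSE-C fields.
    return b"file bytes for " + key.encode()

def fetch(key, extra_args):
    # Run the blocking call in Twisted's thread pool and return a Deferred.
    return threads.deferToThread(blocking_download, key, extra_args)

def _demo():
    d = fetch("local_content/aa/bb/example", {})
    d.addCallback(lambda data: print("fetched", len(data), "bytes"))
    d.addBoth(lambda _: reactor.stop())

reactor.callWhenRunning(_demo)
reactor.run()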
@@ -161,7 +161,7 @@ class S3StorageProviderBackend(StorageProvider):

         result = {
             "bucket": bucket,
-            "eargs": {"StorageClass": storage_class},
+            "extra_args": {"StorageClass": storage_class},
         }

         if "region_name" in config:
@@ -177,15 +177,15 @@ class S3StorageProviderBackend(StorageProvider):
             result["secret_access_key"] = config["secret_access_key"]

         if "sse_customer_key" in config:
-            result["eargs"]["SSECustomerKey"] = config["sse_customer_key"]
-            result["eargs"]["SSECustomerAlgorithm"] = config.get(
+            result["extra_args"]["SSECustomerKey"] = config["sse_customer_key"]
+            result["extra_args"]["SSECustomerAlgorithm"] = config.get(
                 "sse_customer_algo", "AES256"
             )

         return result


-def s3_download_task(s3_client, bucket, key, eargs, deferred, parent_logcontext):
+def s3_download_task(s3_client, bucket, key, extra_args, deferred, parent_logcontext):
     """Attempts to download a file from S3.

     Args:
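Putting the two parse_config hunks together, a hedged example of the dict the method now returns when an SSE-C key is configured (the bucket name and key material are made up):

# Hypothetical parse_config output after this commit.
result = {
    "bucket": "matrix-media",
    "extra_args": {
        "StorageClass": "STANDARD",
        "SSECustomerKey": "0123456789abcdef0123456789abcdef",  # example 32-byte SSE-C key
        "SSECustomerAlgorithm": "AES256",                      # default for "sse_customer_algo"
    },
}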
@@ -202,12 +202,12 @@ def s3_download_task(s3_client, bucket, key, eargs, deferred, parent_logcontext)
         logger.info("Fetching %s from S3", key)

         try:
-            if eargs["SSECustomerKey"] and eargs["SSECustomerAlgorithm"]:
+            if extra_args["SSECustomerKey"] and extra_args["SSECustomerAlgorithm"]:
                 resp = s3_client.get_object(
                     Bucket=bucket,
                     Key=key,
-                    SSECustomerKey=eargs["SSECustomerKey"],
-                    SSECustomerAlgorithm=eargs["SSECustomerAlgorithm"],
+                    SSECustomerKey=extra_args["SSECustomerKey"],
+                    SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
                 )
             else:
                 resp = s3_client.get_object(Bucket=bucket, Key=key)
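For reference, a self-contained sketch of the SSE-C download path the renamed dict feeds; the bucket, key and key material are assumptions:

import boto3

s3_client = boto3.client("s3")
extra_args = {
    "SSECustomerKey": "0123456789abcdef0123456789abcdef",  # example 32-byte key
    "SSECustomerAlgorithm": "AES256",
}

# SSE-C objects can only be read back with the same customer-provided key.
resp = s3_client.get_object(
    Bucket="matrix-media",
    Key="local_content/aa/bb/example",
    SSECustomerKey=extra_args["SSECustomerKey"],
    SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
)
body = resp["Body"].read()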
@@ -167,16 +167,16 @@ def get_local_files(base_path, origin, filesystem_id, m_type):
     return local_files


-def check_file_in_s3(s3, bucket, key, eargs):
+def check_file_in_s3(s3, bucket, key, extra_args):
     """Check the file exists in S3 (though it could be different)
     """
     try:
-        if eargs["SSECustomerKey"] and eargs["SSECustomerAlgorithm"]:
+        if extra_args["SSECustomerKey"] and extra_args["SSECustomerAlgorithm"]:
             s3.head_object(
                 Bucket=bucket,
                 Key=key,
-                SSECustomerKey=eargs["SSECustomerKey"],
-                SSECustomerAlgorithm=eargs["SSECustomerAlgorithm"],
+                SSECustomerKey=extra_args["SSECustomerKey"],
+                SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
             )
         else:
             s3.head_object(Bucket=bucket, Key=key)
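A sketch of how such a head_object existence check is typically completed with 404 handling; the except branch is outside this hunk, so it is an assumption about the surrounding code rather than a quote of it:

import boto3
from botocore.exceptions import ClientError

def object_exists(s3, bucket, key, extra_args):
    """Return True if the key exists in the bucket, handling SSE-C objects."""
    try:
        if extra_args.get("SSECustomerKey") and extra_args.get("SSECustomerAlgorithm"):
            s3.head_object(
                Bucket=bucket,
                Key=key,
                SSECustomerKey=extra_args["SSECustomerKey"],
                SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
            )
        else:
            s3.head_object(Bucket=bucket, Key=key)
    except ClientError as e:
        if e.response["Error"]["Code"] in ("404", "NoSuchKey"):
            return False
        raise
    return True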
@@ -302,7 +302,7 @@ def run_check_delete(sqlite_conn, base_path):
     print("Updated", len(deleted), "as deleted")


-def run_upload(s3, bucket, sqlite_conn, base_path, eargs, should_delete):
+def run_upload(s3, bucket, sqlite_conn, base_path, extra_args, should_delete):
     """Entry point for upload command
     """
     total = get_not_deleted_count(sqlite_conn)
@@ -335,10 +335,10 @@ def run_upload(s3, bucket, sqlite_conn, base_path, eargs, should_delete):
         for rel_file_path in local_files:
             local_path = os.path.join(base_path, rel_file_path)

-            if not check_file_in_s3(s3, bucket, rel_file_path, eargs):
+            if not check_file_in_s3(s3, bucket, rel_file_path, extra_args):
                 try:
                     s3.upload_file(
-                        local_path, bucket, rel_file_path, ExtraArgs=eargs,
+                        local_path, bucket, rel_file_path, ExtraArgs=extra_args,
                     )
                 except Exception as e:
                     print("Failed to upload file %s: %s", local_path, e)
@@ -554,20 +554,20 @@ def main():
         sqlite_conn = get_sqlite_conn(parser)
         s3 = boto3.client("s3", endpoint_url=args.endpoint_url)

-        eargs = {"StorageClass": args.storage_class}
+        extra_args = {"StorageClass": args.storage_class}
         if args.sse_customer_key:
-            eargs["SSECustomerKey"] = args.sse_customer_key
+            extra_args["SSECustomerKey"] = args.sse_customer_key
             if args.sse_customer_algo:
-                eargs["SSECustomerAlgorithm"] = args.sse_customer_algo
+                extra_args["SSECustomerAlgorithm"] = args.sse_customer_algo
             else:
-                eargs["SSECustomerAlgorithm"] = "AES256"
+                extra_args["SSECustomerAlgorithm"] = "AES256"

         run_upload(
             s3,
             args.bucket,
             sqlite_conn,
             args.base_path,
-            eargs,
+            extra_args,
             should_delete=args.delete,
         )
         return
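The same dict can be assembled outside the script; a standalone sketch using argparse attributes that mirror those in the hunk (the actual CLI flag spellings are assumed, not confirmed by the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--storage-class", dest="storage_class", default="STANDARD")
parser.add_argument("--sse-customer-key", dest="sse_customer_key", default=None)
parser.add_argument("--sse-customer-algo", dest="sse_customer_algo", default=None)
args = parser.parse_args([])  # defaults only, for illustration

extra_args = {"StorageClass": args.storage_class}
if args.sse_customer_key:
    extra_args["SSECustomerKey"] = args.sse_customer_key
    # fall back to AES256 when no algorithm is given, as in the hunk above
    extra_args["SSECustomerAlgorithm"] = args.sse_customer_algo or "AES256"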