Refactor eargs to extra_args

This commit is contained in:
Cody Wyatt Neiman 2022-11-04 13:17:00 -04:00
parent ff28d0a02f
commit 5957773cf7
No known key found for this signature in database
GPG key ID: 94475C8B94E4698D
2 changed files with 22 additions and 22 deletions

View file

@ -62,7 +62,7 @@ class S3StorageProviderBackend(StorageProvider):
def __init__(self, hs, config): def __init__(self, hs, config):
self.cache_directory = hs.config.media.media_store_path self.cache_directory = hs.config.media.media_store_path
self.bucket = config["bucket"] self.bucket = config["bucket"]
self.eargs = config["eargs"] self.extra_args = config["extra_args"]
self.api_kwargs = {} self.api_kwargs = {}
if "region_name" in config: if "region_name" in config:
@ -123,7 +123,7 @@ class S3StorageProviderBackend(StorageProvider):
Filename=os.path.join(self.cache_directory, path), Filename=os.path.join(self.cache_directory, path),
Bucket=self.bucket, Bucket=self.bucket,
Key=path, Key=path,
ExtraArgs=self.eargs, ExtraArgs=self.extra_args,
) )
return make_deferred_yieldable( return make_deferred_yieldable(
@ -138,7 +138,7 @@ class S3StorageProviderBackend(StorageProvider):
def _get_file(): def _get_file():
s3_download_task( s3_download_task(
self._get_s3_client(), self.bucket, path, self.eargs, d, logcontext self._get_s3_client(), self.bucket, path, self.extra_args, d, logcontext
) )
self._s3_pool.callInThread(_get_file) self._s3_pool.callInThread(_get_file)
@ -161,7 +161,7 @@ class S3StorageProviderBackend(StorageProvider):
result = { result = {
"bucket": bucket, "bucket": bucket,
"eargs": {"StorageClass": storage_class}, "extra_args": {"StorageClass": storage_class},
} }
if "region_name" in config: if "region_name" in config:
@ -177,15 +177,15 @@ class S3StorageProviderBackend(StorageProvider):
result["secret_access_key"] = config["secret_access_key"] result["secret_access_key"] = config["secret_access_key"]
if "sse_customer_key" in config: if "sse_customer_key" in config:
result["eargs"]["SSECustomerKey"] = config["sse_customer_key"] result["extra_args"]["SSECustomerKey"] = config["sse_customer_key"]
result["eargs"]["SSECustomerAlgorithm"] = config.get( result["extra_args"]["SSECustomerAlgorithm"] = config.get(
"sse_customer_algo", "AES256" "sse_customer_algo", "AES256"
) )
return result return result
def s3_download_task(s3_client, bucket, key, eargs, deferred, parent_logcontext): def s3_download_task(s3_client, bucket, key, extra_args, deferred, parent_logcontext):
"""Attempts to download a file from S3. """Attempts to download a file from S3.
Args: Args:
@ -202,12 +202,12 @@ def s3_download_task(s3_client, bucket, key, eargs, deferred, parent_logcontext)
logger.info("Fetching %s from S3", key) logger.info("Fetching %s from S3", key)
try: try:
if eargs["SSECustomerKey"] and eargs["SSECustomerAlgorithm"]: if extra_args["SSECustomerKey"] and extra_args["SSECustomerAlgorithm"]:
resp = s3_client.get_object( resp = s3_client.get_object(
Bucket=bucket, Bucket=bucket,
Key=key, Key=key,
SSECustomerKey=eargs["SSECustomerKey"], SSECustomerKey=extra_args["SSECustomerKey"],
SSECustomerAlgorithm=eargs["SSECustomerAlgorithm"], SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
) )
else: else:
resp = s3_client.get_object(Bucket=bucket, Key=key) resp = s3_client.get_object(Bucket=bucket, Key=key)

View file

@ -167,16 +167,16 @@ def get_local_files(base_path, origin, filesystem_id, m_type):
return local_files return local_files
def check_file_in_s3(s3, bucket, key, eargs): def check_file_in_s3(s3, bucket, key, extra_args):
"""Check the file exists in S3 (though it could be different) """Check the file exists in S3 (though it could be different)
""" """
try: try:
if eargs["SSECustomerKey"] and eargs["SSECustomerAlgorithm"]: if extra_args["SSECustomerKey"] and extra_args["SSECustomerAlgorithm"]:
s3.head_object( s3.head_object(
Bucket=bucket, Bucket=bucket,
Key=key, Key=key,
SSECustomerKey=eargs["SSECustomerKey"], SSECustomerKey=extra_args["SSECustomerKey"],
SSECustomerAlgorithm=eargs["SSECustomerAlgorithm"], SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
) )
else: else:
s3.head_object(Bucket=bucket, Key=key) s3.head_object(Bucket=bucket, Key=key)
@ -302,7 +302,7 @@ def run_check_delete(sqlite_conn, base_path):
print("Updated", len(deleted), "as deleted") print("Updated", len(deleted), "as deleted")
def run_upload(s3, bucket, sqlite_conn, base_path, eargs, should_delete): def run_upload(s3, bucket, sqlite_conn, base_path, extra_args, should_delete):
"""Entry point for upload command """Entry point for upload command
""" """
total = get_not_deleted_count(sqlite_conn) total = get_not_deleted_count(sqlite_conn)
@ -335,10 +335,10 @@ def run_upload(s3, bucket, sqlite_conn, base_path, eargs, should_delete):
for rel_file_path in local_files: for rel_file_path in local_files:
local_path = os.path.join(base_path, rel_file_path) local_path = os.path.join(base_path, rel_file_path)
if not check_file_in_s3(s3, bucket, rel_file_path, eargs): if not check_file_in_s3(s3, bucket, rel_file_path, extra_args):
try: try:
s3.upload_file( s3.upload_file(
local_path, bucket, rel_file_path, ExtraArgs=eargs, local_path, bucket, rel_file_path, ExtraArgs=extra_args,
) )
except Exception as e: except Exception as e:
print("Failed to upload file %s: %s", local_path, e) print("Failed to upload file %s: %s", local_path, e)
@ -554,20 +554,20 @@ def main():
sqlite_conn = get_sqlite_conn(parser) sqlite_conn = get_sqlite_conn(parser)
s3 = boto3.client("s3", endpoint_url=args.endpoint_url) s3 = boto3.client("s3", endpoint_url=args.endpoint_url)
eargs = {"StorageClass": args.storage_class} extra_args = {"StorageClass": args.storage_class}
if args.sse_customer_key: if args.sse_customer_key:
eargs["SSECustomerKey"] = args.sse_customer_key extra_args["SSECustomerKey"] = args.sse_customer_key
if args.sse_customer_algo: if args.sse_customer_algo:
eargs["SSECustomerAlgorithm"] = args.sse_customer_algo extra_args["SSECustomerAlgorithm"] = args.sse_customer_algo
else: else:
eargs["SSECustomerAlgorithm"] = "AES256" extra_args["SSECustomerAlgorithm"] = "AES256"
run_upload( run_upload(
s3, s3,
args.bucket, args.bucket,
sqlite_conn, sqlite_conn,
args.base_path, args.base_path,
eargs, extra_args,
should_delete=args.delete, should_delete=args.delete,
) )
return return