Cody Wyatt Neiman 2022-10-24 14:20:20 -04:00
parent b0fa9bb140
commit 93b41de6b4
No known key found for this signature in database
GPG key ID: 94475C8B94E4698D
3 changed files with 65 additions and 12 deletions

README.md

@@ -27,6 +27,9 @@ media_storage_providers:
     endpoint_url: <S3_LIKE_SERVICE_ENDPOINT_URL>
     access_key_id: <S3_ACCESS_KEY_ID>
     secret_access_key: <S3_SECRET_ACCESS_KEY>
+    sse_customer_key: <S3_SSEC_KEY>
+    # Your SSE-C algorithm is very likely AES256
+    sse_customer_algo: <S3_SSEC_ALGO>
     # The object storage class used when uploading files to the bucket.
     # Default is STANDARD.
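For reference, a fully populated provider block with SSE-C enabled would look roughly like the following. This is a sketch assembled from the README's existing placeholders; sse_customer_algo can be omitted entirely, since the provider code defaults it to AES256:

    media_storage_providers:
    - module: s3_storage_provider.S3StorageProviderBackend
      store_local: True
      store_remote: True
      store_synchronous: True
      config:
        bucket: <S3_BUCKET_NAME>
        endpoint_url: <S3_LIKE_SERVICE_ENDPOINT_URL>
        access_key_id: <S3_ACCESS_KEY_ID>
        secret_access_key: <S3_SECRET_ACCESS_KEY>
        sse_customer_key: <S3_SSEC_KEY>
        sse_customer_algo: AES256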

s3_storage_provider.py

@@ -62,7 +62,7 @@ class S3StorageProviderBackend(StorageProvider):
     def __init__(self, hs, config):
         self.cache_directory = hs.config.media.media_store_path
         self.bucket = config["bucket"]
-        self.storage_class = config["storage_class"]
+        self.eargs = config["eargs"]
         self.api_kwargs = {}

         if "region_name" in config:
@@ -118,11 +118,12 @@ class S3StorageProviderBackend(StorageProvider):
         def _store_file():
             with LoggingContext(parent_context=parent_logcontext):
                 self._get_s3_client().upload_file(
                     Filename=os.path.join(self.cache_directory, path),
                     Bucket=self.bucket,
                     Key=path,
-                    ExtraArgs={"StorageClass": self.storage_class},
+                    ExtraArgs=self.eargs,
                 )

         return make_deferred_yieldable(
@@ -136,7 +137,9 @@ class S3StorageProviderBackend(StorageProvider):
         d = defer.Deferred()

         def _get_file():
-            s3_download_task(self._get_s3_client(), self.bucket, path, d, logcontext)
+            s3_download_task(
+                self._get_s3_client(), self.bucket, path, self.eargs, d, logcontext
+            )

         self._s3_pool.callInThread(_get_file)
         return make_deferred_yieldable(d)
@@ -158,7 +161,7 @@ class S3StorageProviderBackend(StorageProvider):

         result = {
             "bucket": bucket,
-            "storage_class": storage_class,
+            "eargs": {"StorageClass": storage_class},
         }

         if "region_name" in config:
@@ -173,10 +176,16 @@ class S3StorageProviderBackend(StorageProvider):
         if "secret_access_key" in config:
             result["secret_access_key"] = config["secret_access_key"]

+        if "sse_customer_key" in config:
+            result["eargs"]["SSECustomerKey"] = config["sse_customer_key"]
+            result["eargs"]["SSECustomerAlgorithm"] = config.get(
+                "sse_customer_algo", "AES256"
+            )
+
         return result


-def s3_download_task(s3_client, bucket, key, deferred, parent_logcontext):
+def s3_download_task(s3_client, bucket, key, eargs, deferred, parent_logcontext):
     """Attempts to download a file from S3.

     Args:
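Background on what eargs feeds into: with SSE-C, the client supplies a 256-bit key on every request and S3 encrypts or decrypts with it server-side without ever storing the key. boto3 accepts the raw key via SSECustomerKey and handles the base64 encoding and the required key-MD5 header itself. A minimal standalone sketch (bucket and object names are made up):

    import os

    import boto3

    s3 = boto3.client("s3")
    sse_key = os.urandom(32)  # 256-bit customer-provided key; S3 never stores it

    # The same key/algorithm pair must accompany every request that touches
    # the object, uploads and downloads alike.
    s3.put_object(
        Bucket="example-bucket",
        Key="example-object",
        Body=b"hello",
        SSECustomerKey=sse_key,
        SSECustomerAlgorithm="AES256",
    )
    resp = s3.get_object(
        Bucket="example-bucket",
        Key="example-object",
        SSECustomerKey=sse_key,
        SSECustomerAlgorithm="AES256",
    )
    assert resp["Body"].read() == b"hello"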
@@ -193,9 +202,21 @@ def s3_download_task(s3_client, bucket, key, deferred, parent_logcontext):
         logger.info("Fetching %s from S3", key)

         try:
-            resp = s3_client.get_object(Bucket=bucket, Key=key)
+            if eargs.get("SSECustomerKey") and eargs.get("SSECustomerAlgorithm"):
+                resp = s3_client.get_object(
+                    Bucket=bucket,
+                    Key=key,
+                    SSECustomerKey=eargs["SSECustomerKey"],
+                    SSECustomerAlgorithm=eargs["SSECustomerAlgorithm"],
+                )
+            else:
+                resp = s3_client.get_object(Bucket=bucket, Key=key)
         except botocore.exceptions.ClientError as e:
-            if e.response["Error"]["Code"] in ("404", "NoSuchKey",):
+            if e.response["Error"]["Code"] in (
+                "404",
+                "NoSuchKey",
+            ):
                 logger.info("Media %s not found in S3", key)
                 reactor.callFromThread(deferred.callback, None)
                 return

scripts/s3_media_upload

@@ -167,11 +167,19 @@ def get_local_files(base_path, origin, filesystem_id, m_type):
     return local_files


-def check_file_in_s3(s3, bucket, key):
+def check_file_in_s3(s3, bucket, key, eargs):
     """Check the file exists in S3 (though it could be different)
     """
     try:
-        s3.head_object(Bucket=bucket, Key=key)
+        if eargs.get("SSECustomerKey") and eargs.get("SSECustomerAlgorithm"):
+            s3.head_object(
+                Bucket=bucket,
+                Key=key,
+                SSECustomerKey=eargs["SSECustomerKey"],
+                SSECustomerAlgorithm=eargs["SSECustomerAlgorithm"],
+            )
+        else:
+            s3.head_object(Bucket=bucket, Key=key)
     except botocore.exceptions.ClientError as e:
         if int(e.response["Error"]["Code"]) == 404:
             return False
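A possible tightening rather than anything this commit does: boto3 treats absent keyword arguments as simply unset, so the duplicated if/else here (and the matching branch in s3_download_task above) could collapse into a single call that splats only the SSE-C entries that are actually present:

    # Build keyword arguments only for the SSE-C entries found in eargs.
    sse_kwargs = {
        k: eargs[k]
        for k in ("SSECustomerKey", "SSECustomerAlgorithm")
        if eargs.get(k)
    }
    s3.head_object(Bucket=bucket, Key=key, **sse_kwargs)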
@@ -327,13 +335,13 @@ def run_upload(s3, bucket, sqlite_conn, base_path, should_delete, storage_class)
     for rel_file_path in local_files:
         local_path = os.path.join(base_path, rel_file_path)

-        if not check_file_in_s3(s3, bucket, rel_file_path):
+        if not check_file_in_s3(s3, bucket, rel_file_path, eargs):
             try:
                 s3.upload_file(
                     local_path,
                     bucket,
                     rel_file_path,
-                    ExtraArgs={"StorageClass": storage_class},
+                    ExtraArgs=eargs,
                 )
             except Exception as e:
                 print("Failed to upload file %s: %s", local_path, e)
@@ -481,6 +489,7 @@ def main():
         "base_path", help="Base path of the media store directory"
     )
     upload_parser.add_argument("bucket", help="S3 bucket to upload to")
+
     upload_parser.add_argument(
         "--storage-class",
         help="S3 storage class to use",
@@ -495,6 +504,17 @@ def main():
         default="STANDARD",
     )

+    upload_parser.add_argument(
+        "--sse-customer-key", help="SSE-C key to use",
+    )
+
+    upload_parser.add_argument(
+        "--sse-customer-algo",
+        help="Algorithm for SSE-C, only used if sse-customer-key is also specified",
+        nargs="?",
+        default="AES256",
+    )
+
     upload_parser.add_argument(
         "--delete",
         action="store_const",
@@ -537,13 +557,22 @@ def main():
     if args.cmd == "upload":
         sqlite_conn = get_sqlite_conn(parser)
         s3 = boto3.client("s3", endpoint_url=args.endpoint_url)
+        eargs = {"StorageClass": args.storage_class}
+
+        if args.sse_customer_key:
+            eargs["SSECustomerKey"] = args.sse_customer_key
+            if args.sse_customer_algo:
+                eargs["SSECustomerAlgorithm"] = args.sse_customer_algo
+            else:
+                eargs["SSECustomerAlgorithm"] = "AES256"
+
         run_upload(
             s3,
             args.bucket,
             sqlite_conn,
             args.base_path,
             should_delete=args.delete,
-            storage_class=args.storage_class,
+            eargs=eargs,
         )
         return
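Putting the CLI pieces together, an SSE-C upload run would look roughly like the line below, assuming the upload subcommand's existing base_path and bucket positionals shown above; the paths and the key value are illustrative, and --sse-customer-algo can be dropped since it defaults to AES256:

    s3_media_upload upload /var/lib/synapse/media_store my-media-bucket \
        --storage-class STANDARD --sse-customer-key "$S3_SSEC_KEY" --delete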