diff --git a/README.md b/README.md
index fc6444f..23146a2 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,9 @@ media_storage_providers:
     endpoint_url:
     access_key_id:
     secret_access_key:
+    sse_customer_key:
+    # The SSE-C algorithm to use with sse_customer_key; almost always AES256
+    sse_customer_algo:
 
     # The object storage class used when uploading files to the bucket.
     # Default is STANDARD.
diff --git a/s3_storage_provider.py b/s3_storage_provider.py
index 68137ac..c74fa5c 100644
--- a/s3_storage_provider.py
+++ b/s3_storage_provider.py
@@ -62,7 +62,7 @@ class S3StorageProviderBackend(StorageProvider):
     def __init__(self, hs, config):
         self.cache_directory = hs.config.media.media_store_path
         self.bucket = config["bucket"]
-        self.storage_class = config["storage_class"]
+        self.eargs = config["eargs"]
         self.api_kwargs = {}
 
         if "region_name" in config:
@@ -118,11 +118,12 @@ class S3StorageProviderBackend(StorageProvider):
 
         def _store_file():
             with LoggingContext(parent_context=parent_logcontext):
+
                 self._get_s3_client().upload_file(
                     Filename=os.path.join(self.cache_directory, path),
                     Bucket=self.bucket,
                     Key=path,
-                    ExtraArgs={"StorageClass": self.storage_class},
+                    ExtraArgs=self.eargs,
                 )
 
         return make_deferred_yieldable(
@@ -136,7 +137,9 @@ class S3StorageProviderBackend(StorageProvider):
         d = defer.Deferred()
 
         def _get_file():
-            s3_download_task(self._get_s3_client(), self.bucket, path, d, logcontext)
+            s3_download_task(
+                self._get_s3_client(), self.bucket, path, self.eargs, d, logcontext
+            )
 
         self._s3_pool.callInThread(_get_file)
         return make_deferred_yieldable(d)
@@ -158,7 +161,7 @@ class S3StorageProviderBackend(StorageProvider):
 
         result = {
             "bucket": bucket,
-            "storage_class": storage_class,
+            "eargs": {"StorageClass": storage_class},
         }
 
         if "region_name" in config:
@@ -173,10 +176,16 @@ class S3StorageProviderBackend(StorageProvider):
         if "secret_access_key" in config:
             result["secret_access_key"] = config["secret_access_key"]
 
+        if "sse_customer_key" in config:
+            result["eargs"]["SSECustomerKey"] = config["sse_customer_key"]
+            result["eargs"]["SSECustomerAlgorithm"] = config.get(
+                "sse_customer_algo", "AES256"
+            )
+
         return result
 
 
-def s3_download_task(s3_client, bucket, key, deferred, parent_logcontext):
+def s3_download_task(s3_client, bucket, key, eargs, deferred, parent_logcontext):
     """Attempts to download a file from S3.
 
     Args:
@@ -193,9 +202,21 @@ def s3_download_task(s3_client, bucket, key, deferred, parent_logcontext):
         logger.info("Fetching %s from S3", key)
 
         try:
-            resp = s3_client.get_object(Bucket=bucket, Key=key)
+            if eargs.get("SSECustomerKey") and eargs.get("SSECustomerAlgorithm"):
+                resp = s3_client.get_object(
+                    Bucket=bucket,
+                    Key=key,
+                    SSECustomerKey=eargs["SSECustomerKey"],
+                    SSECustomerAlgorithm=eargs["SSECustomerAlgorithm"],
+                )
+            else:
+                resp = s3_client.get_object(Bucket=bucket, Key=key)
+
         except botocore.exceptions.ClientError as e:
-            if e.response["Error"]["Code"] in ("404", "NoSuchKey",):
+            if e.response["Error"]["Code"] in (
+                "404",
+                "NoSuchKey",
+            ):
                 logger.info("Media %s not found in S3", key)
                 reactor.callFromThread(deferred.callback, None)
                 return
diff --git a/scripts/s3_media_upload b/scripts/s3_media_upload
index 6be77e7..108d39e 100755
--- a/scripts/s3_media_upload
+++ b/scripts/s3_media_upload
@@ -167,11 +167,19 @@ def get_local_files(base_path, origin, filesystem_id, m_type):
     return local_files
 
 
-def check_file_in_s3(s3, bucket, key):
+def check_file_in_s3(s3, bucket, key, eargs):
     """Check the file exists in S3 (though it could be different)
     """
     try:
-        s3.head_object(Bucket=bucket, Key=key)
+        if eargs.get("SSECustomerKey") and eargs.get("SSECustomerAlgorithm"):
+            s3.head_object(
+                Bucket=bucket,
+                Key=key,
+                SSECustomerKey=eargs["SSECustomerKey"],
+                SSECustomerAlgorithm=eargs["SSECustomerAlgorithm"],
+            )
+        else:
+            s3.head_object(Bucket=bucket, Key=key)
     except botocore.exceptions.ClientError as e:
         if int(e.response["Error"]["Code"]) == 404:
             return False
@@ -321,1 +329,1 @@
-def run_upload(s3, bucket, sqlite_conn, base_path, should_delete, storage_class):
+def run_upload(s3, bucket, sqlite_conn, base_path, should_delete, eargs):
@@ -327,13 +335,13 @@ def run_upload(s3, bucket, sqlite_conn, base_path, should_delete, storage_class)
     for rel_file_path in local_files:
         local_path = os.path.join(base_path, rel_file_path)
 
-        if not check_file_in_s3(s3, bucket, rel_file_path):
+        if not check_file_in_s3(s3, bucket, rel_file_path, eargs):
             try:
                 s3.upload_file(
                     local_path,
                     bucket,
                     rel_file_path,
-                    ExtraArgs={"StorageClass": storage_class},
+                    ExtraArgs=eargs,
                 )
             except Exception as e:
                 print("Failed to upload file %s: %s", local_path, e)
@@ -481,6 +489,7 @@ def main():
         "base_path", help="Base path of the media store directory"
     )
     upload_parser.add_argument("bucket", help="S3 bucket to upload to")
+
    upload_parser.add_argument(
         "--storage-class",
         help="S3 storage class to use",
@@ -495,6 +504,17 @@ def main():
         default="STANDARD",
     )
 
+    upload_parser.add_argument(
+        "--sse-customer-key", help="SSE-C key to use",
+    )
+
+    upload_parser.add_argument(
+        "--sse-customer-algo",
+        help="SSE-C algorithm; only used if --sse-customer-key is also given",
+        nargs="?",
+        default="AES256",
+    )
+
     upload_parser.add_argument(
         "--delete",
         action="store_const",
@@ -537,13 +557,20 @@ def main():
     if args.cmd == "upload":
         sqlite_conn = get_sqlite_conn(parser)
         s3 = boto3.client("s3", endpoint_url=args.endpoint_url)
+
+        eargs = {"StorageClass": args.storage_class}
+        if args.sse_customer_key:
+            eargs["SSECustomerKey"] = args.sse_customer_key
+            # Default to AES256 if --sse-customer-algo was passed without a value
+            eargs["SSECustomerAlgorithm"] = args.sse_customer_algo or "AES256"
+
         run_upload(
             s3,
             args.bucket,
             sqlite_conn,
             args.base_path,
             should_delete=args.delete,
-            storage_class=args.storage_class,
+            eargs=eargs,
         )
         return
 
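A note on the `eargs` plumbing above: S3 never stores an SSE-C key, so every `GET`/`HEAD` against an SSE-C object has to re-present the key and algorithm. That is why `check_file_in_s3` and `s3_download_task` gain an `eargs` argument rather than only `upload_file` taking `ExtraArgs`. Below is a minimal standalone sketch of the boto3 calls this patch relies on; the bucket, object key, and file path are placeholder values, and the key is generated on the spot purely for illustration (a real deployment must persist one stable 32-byte key):

```python
import os

import boto3

# Placeholder values; a real deployment loads a stable 32-byte key from config.
sse_key = os.urandom(32)  # AES256 SSE-C keys must be exactly 256 bits
bucket = "my-bucket"
key = "remote/object-key"

extra_args = {
    "StorageClass": "STANDARD",
    "SSECustomerKey": sse_key,
    "SSECustomerAlgorithm": "AES256",
}

s3 = boto3.client("s3")

# Upload: the SSE-C parameters travel inside ExtraArgs, as in _store_file().
s3.upload_file(
    Filename="local/path/to/file",
    Bucket=bucket,
    Key=key,
    ExtraArgs=extra_args,
)

# Download (and head_object likewise): the same key and algorithm must be
# supplied again, or S3 rejects the request with a 400 error.
resp = s3.get_object(
    Bucket=bucket,
    Key=key,
    SSECustomerKey=sse_key,
    SSECustomerAlgorithm="AES256",
)
print(resp["ContentLength"])
```

boto3 base64-encodes the key and computes the required `SSECustomerKeyMD5` header itself, so only the raw key and the algorithm name need to be threaded through, which is exactly what the `eargs` dict carries.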