diff --git a/s3_storage_provider.py b/s3_storage_provider.py index ef4eaab..30b229c 100644 --- a/s3_storage_provider.py +++ b/s3_storage_provider.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +23,7 @@ from six import string_types import boto3 import botocore -from twisted.internet import defer, reactor +from twisted.internet import defer, reactor, threads from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool @@ -76,42 +77,68 @@ class S3StorageProviderBackend(StorageProvider): if "secret_access_key" in config: self.api_kwargs["aws_secret_access_key"] = config["secret_access_key"] + self._s3_client = None + self._s3_client_lock = threading.Lock() + threadpool_size = config.get("threadpool_size", 40) - self._download_pool = ThreadPool( - name="s3-download-pool", maxthreads=threadpool_size - ) - self._download_pool.start() + self._s3_pool = ThreadPool(name="s3-pool", maxthreads=threadpool_size) + self._s3_pool.start() # Manually stop the thread pool on shutdown. If we don't do this then # stopping Synapse takes an extra ~30s as Python waits for the threads # to exit. reactor.addSystemEventTrigger( - "during", "shutdown", self._download_pool.stop, + "during", "shutdown", self._s3_pool.stop, ) + def _get_s3_client(self): + # this method is designed to be thread-safe, so that we can share a + # single boto3 client across multiple threads. + # + # (XXX: is creating a client actually a blocking operation, or could we do + # this on the main thread, to simplify all this?) + + # first of all, do a fast lock-free check + s3 = self._s3_client + if s3: + return s3 + + # no joy, grab the lock and repeat the check + with self._s3_client_lock: + s3 = self._s3_client + if not s3: + b3_session = boto3.session.Session() + self._s3_client = s3 = b3_session.client("s3", **self.api_kwargs) + return s3 + def store_file(self, path, file_info): """See StorageProvider.store_file""" - def _store_file(): - session = boto3.session.Session() - session.resource("s3", **self.api_kwargs).Bucket(self.bucket).upload_file( - Filename=os.path.join(self.cache_directory, path), - Key=path, - ExtraArgs={"StorageClass": self.storage_class}, - ) + parent_logcontext = current_context() - # XXX: reactor.callInThread doesn't return anything, so I don't think this does - # what the author intended. - return make_deferred_yieldable(reactor.callInThread(_store_file)) + def _store_file(): + with LoggingContext(parent_context=parent_logcontext): + self._get_s3_client().upload_file( + Filename=os.path.join(self.cache_directory, path), + Bucket=self.bucket, + Key=path, + ExtraArgs={"StorageClass": self.storage_class}, + ) + + return make_deferred_yieldable( + threads.deferToThreadPool(reactor, self._s3_pool, _store_file) + ) def fetch(self, path, file_info): """See StorageProvider.fetch""" logcontext = current_context() d = defer.Deferred() - self._download_pool.callInThread( - s3_download_task, self.bucket, self.api_kwargs, path, d, logcontext - ) + + def _get_file(): + s3_download_task(self._get_s3_client(), self.bucket, path, d, logcontext) + + self._s3_pool.callInThread(_get_file) return make_deferred_yieldable(d) @staticmethod @@ -149,13 +176,12 @@ class S3StorageProviderBackend(StorageProvider): return result -def s3_download_task(bucket, api_kwargs, key, deferred, parent_logcontext): +def s3_download_task(s3_client, bucket, key, deferred, parent_logcontext): """Attempts to download a file from S3. Args: + s3_client: boto3 s3 client bucket (str): The S3 bucket which may have the file - api_kwargs (dict): Keyword arguments to pass when invoking the API. - Generally `endpoint_url`. key (str): The key of the file deferred (Deferred[_S3Responder|None]): If file exists resolved with an _S3Responder instance, if it doesn't @@ -166,16 +192,8 @@ def s3_download_task(bucket, api_kwargs, key, deferred, parent_logcontext): with LoggingContext(parent_context=parent_logcontext): logger.info("Fetching %s from S3", key) - local_data = threading.local() - try: - s3 = local_data.b3_client - except AttributeError: - b3_session = boto3.session.Session() - local_data.b3_client = s3 = b3_session.client("s3", **api_kwargs) - - try: - resp = s3.get_object(Bucket=bucket, Key=key) + resp = s3_client.get_object(Bucket=bucket, Key=key) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] in ("404", "NoSuchKey",): logger.info("Media %s not found in S3", key) diff --git a/setup.py b/setup.py index a880b1e..654544b 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,6 @@ setup( "psycopg2>=2.7.5<3.0", "PyYAML>=3.13<4.0", "tqdm>=4.26.0<5.0", + "Twisted", ], )