Improve the efficiency of the S3 storage provider (#50)

There are a few separate things here, which I hope will mean that (a) we use less CPU, and (b) what CPU we do use gets traced to the requests that cause it rather than getting lost down the sofa.
This commit is contained in:
Richard van der Hoff 2021-01-21 12:22:43 +00:00 committed by GitHub
parent 887ee24d76
commit 236e0cddb8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 31 deletions

View file

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd # Copyright 2018 New Vector Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -22,7 +23,7 @@ from six import string_types
import boto3 import boto3
import botocore import botocore
from twisted.internet import defer, reactor from twisted.internet import defer, reactor, threads
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool from twisted.python.threadpool import ThreadPool
@@ -76,42 +77,68 @@ class S3StorageProviderBackend(StorageProvider):
if "secret_access_key" in config: if "secret_access_key" in config:
self.api_kwargs["aws_secret_access_key"] = config["secret_access_key"] self.api_kwargs["aws_secret_access_key"] = config["secret_access_key"]
self._s3_client = None
self._s3_client_lock = threading.Lock()
threadpool_size = config.get("threadpool_size", 40) threadpool_size = config.get("threadpool_size", 40)
self._download_pool = ThreadPool( self._s3_pool = ThreadPool(name="s3-pool", maxthreads=threadpool_size)
name="s3-download-pool", maxthreads=threadpool_size self._s3_pool.start()
)
self._download_pool.start()
# Manually stop the thread pool on shutdown. If we don't do this then # Manually stop the thread pool on shutdown. If we don't do this then
# stopping Synapse takes an extra ~30s as Python waits for the threads # stopping Synapse takes an extra ~30s as Python waits for the threads
# to exit. # to exit.
reactor.addSystemEventTrigger( reactor.addSystemEventTrigger(
"during", "shutdown", self._download_pool.stop, "during", "shutdown", self._s3_pool.stop,
) )
def _get_s3_client(self):
    """Return the shared boto3 S3 client, creating it on first use.

    Designed to be thread-safe so a single client can be shared across the
    threads of the pool: a fast unlocked read covers the common case, and
    the lock only guards the one-time construction.

    (XXX: is creating a client actually a blocking operation, or could we do
    this on the main thread, to simplify all this?)
    """
    # Fast path: no locking needed once the client exists.
    client = self._s3_client
    if client:
        return client

    # Slow path: take the lock and re-check, so concurrent callers build
    # at most one client between them.
    with self._s3_client_lock:
        if not self._s3_client:
            session = boto3.session.Session()
            self._s3_client = session.client("s3", **self.api_kwargs)
        return self._s3_client
def store_file(self, path, file_info):
    """See StorageProvider.store_file"""

    # Capture the request's logcontext so the CPU spent uploading on the
    # thread pool is attributed to the request that caused it.
    parent_logcontext = current_context()

    def _do_upload():
        with LoggingContext(parent_context=parent_logcontext):
            client = self._get_s3_client()
            client.upload_file(
                Filename=os.path.join(self.cache_directory, path),
                Bucket=self.bucket,
                Key=path,
                ExtraArgs={"StorageClass": self.storage_class},
            )

    upload_deferred = threads.deferToThreadPool(reactor, self._s3_pool, _do_upload)
    return make_deferred_yieldable(upload_deferred)
def fetch(self, path, file_info):
    """See StorageProvider.fetch"""
    logcontext = current_context()
    deferred = defer.Deferred()

    # Run the download on the shared pool; the task resolves `deferred`
    # itself (with an _S3Responder, or None if the key is absent).
    def _start_download():
        s3_download_task(
            self._get_s3_client(), self.bucket, path, deferred, logcontext
        )

    self._s3_pool.callInThread(_start_download)
    return make_deferred_yieldable(deferred)
@staticmethod @staticmethod
@@ -149,13 +176,12 @@ class S3StorageProviderBackend(StorageProvider):
return result return result
def s3_download_task(bucket, api_kwargs, key, deferred, parent_logcontext): def s3_download_task(s3_client, bucket, key, deferred, parent_logcontext):
"""Attempts to download a file from S3. """Attempts to download a file from S3.
Args: Args:
s3_client: boto3 s3 client
bucket (str): The S3 bucket which may have the file bucket (str): The S3 bucket which may have the file
api_kwargs (dict): Keyword arguments to pass when invoking the API.
Generally `endpoint_url`.
key (str): The key of the file key (str): The key of the file
deferred (Deferred[_S3Responder|None]): If file exists deferred (Deferred[_S3Responder|None]): If file exists
resolved with an _S3Responder instance, if it doesn't resolved with an _S3Responder instance, if it doesn't
@@ -166,16 +192,8 @@ def s3_download_task(bucket, api_kwargs, key, deferred, parent_logcontext):
with LoggingContext(parent_context=parent_logcontext): with LoggingContext(parent_context=parent_logcontext):
logger.info("Fetching %s from S3", key) logger.info("Fetching %s from S3", key)
local_data = threading.local()
try: try:
s3 = local_data.b3_client resp = s3_client.get_object(Bucket=bucket, Key=key)
except AttributeError:
b3_session = boto3.session.Session()
local_data.b3_client = s3 = b3_session.client("s3", **api_kwargs)
try:
resp = s3.get_object(Bucket=bucket, Key=key)
except botocore.exceptions.ClientError as e: except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] in ("404", "NoSuchKey",): if e.response["Error"]["Code"] in ("404", "NoSuchKey",):
logger.info("Media %s not found in S3", key) logger.info("Media %s not found in S3", key)

View file

@@ -16,5 +16,6 @@ setup(
"psycopg2>=2.7.5<3.0", "psycopg2>=2.7.5<3.0",
"PyYAML>=3.13<4.0", "PyYAML>=3.13<4.0",
"tqdm>=4.26.0<5.0", "tqdm>=4.26.0<5.0",
"Twisted",
], ],
) )