synapse-s3-storage-provider/s3_storage_provider.py

# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import threading
from six import string_types
import boto3
import botocore
from twisted.internet import defer, reactor, threads
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.rest.media.v1._base import Responder
from synapse.rest.media.v1.storage_provider import StorageProvider
# Synapse 1.13.0 moved current_context to a module-level function.
try:
from synapse.logging.context import current_context
except ImportError:
current_context = LoggingContext.current_context
logger = logging.getLogger("synapse.s3")
# The list of valid AWS storage class names
_VALID_STORAGE_CLASSES = (
"STANDARD",
"REDUCED_REDUNDANCY",
"STANDARD_IA",
"INTELLIGENT_TIERING",
)
# Chunk size to use when reading from s3 connection in bytes
READ_CHUNK_SIZE = 16 * 1024
class S3StorageProviderBackend(StorageProvider):
"""Storage provider that stores and fetches media in an S3 bucket.

Args:
hs (HomeServer)
config: The config returned by `parse_config`
"""
def __init__(self, hs, config):
self.cache_directory = hs.config.media.media_store_path
self.bucket = config["bucket"]
self.prefix = config["prefix"]
# A dictionary of extra arguments for uploading files.
# See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS
# for a list of possible keys.
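# As a purely illustrative example (assumed values), the extra_args dict
# produced by `parse_config` below might look like:
#     {"StorageClass": "STANDARD_IA", "SSECustomerAlgorithm": "AES256",
#      "SSECustomerKey": "<customer-provided key>"}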
self.extra_args = config["extra_args"]
self.api_kwargs = {}
if "region_name" in config:
self.api_kwargs["region_name"] = config["region_name"]
if "endpoint_url" in config:
self.api_kwargs["endpoint_url"] = config["endpoint_url"]
if "access_key_id" in config:
self.api_kwargs["aws_access_key_id"] = config["access_key_id"]
if "secret_access_key" in config:
self.api_kwargs["aws_secret_access_key"] = config["secret_access_key"]
if "session_token" in config:
self.api_kwargs["aws_session_token"] = config["session_token"]
self._s3_client = None
self._s3_client_lock = threading.Lock()
threadpool_size = config.get("threadpool_size", 40)
self._s3_pool = ThreadPool(name="s3-pool", maxthreads=threadpool_size)
self._s3_pool.start()
# Manually stop the thread pool on shutdown. If we don't do this then
# stopping Synapse takes an extra ~30s as Python waits for the threads
# to exit.
reactor.addSystemEventTrigger(
"during", "shutdown", self._s3_pool.stop,
)
def _get_s3_client(self):
# this method is designed to be thread-safe, so that we can share a
# single boto3 client across multiple threads.
#
# (XXX: is creating a client actually a blocking operation, or could we do
# this on the main thread, to simplify all this?)
# first of all, do a fast lock-free check
s3 = self._s3_client
if s3:
return s3
# no joy, grab the lock and repeat the check
with self._s3_client_lock:
s3 = self._s3_client
if not s3:
b3_session = boto3.session.Session()
self._s3_client = s3 = b3_session.client("s3", **self.api_kwargs)
return s3
def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
parent_logcontext = current_context()
def _store_file():
with LoggingContext(parent_context=parent_logcontext):
self._get_s3_client().upload_file(
Filename=os.path.join(self.cache_directory, path),
Bucket=self.bucket,
Key=self.prefix + path,
ExtraArgs=self.extra_args,
)
return make_deferred_yieldable(
threads.deferToThreadPool(reactor, self._s3_pool, _store_file)
)
def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
logcontext = current_context()
d = defer.Deferred()
def _get_file():
s3_download_task(
self._get_s3_client(), self.bucket, self.prefix + path, self.extra_args, d, logcontext
)
self._s3_pool.callInThread(_get_file)
return make_deferred_yieldable(d)
@staticmethod
def parse_config(config):
"""Called on startup to parse config supplied. This should parse
the config and raise if there is a problem.

The returned value is passed into the constructor.

In this case we return a dict with the fields `bucket`, `prefix` and
`extra_args` (which carries the storage class and any SSE-C settings),
plus any optional credential and endpoint settings.
"""
bucket = config["bucket"]
prefix = config.get("prefix", "")
storage_class = config.get("storage_class", "STANDARD")
assert isinstance(bucket, string_types)
assert storage_class in _VALID_STORAGE_CLASSES
result = {
"bucket": bucket,
"prefix": prefix,
"extra_args": {"StorageClass": storage_class},
}
if "region_name" in config:
result["region_name"] = config["region_name"]
if "endpoint_url" in config:
result["endpoint_url"] = config["endpoint_url"]
if "access_key_id" in config:
result["access_key_id"] = config["access_key_id"]
if "secret_access_key" in config:
result["secret_access_key"] = config["secret_access_key"]
if "session_token" in config:
result["session_token"] = config["session_token"]
if "sse_customer_key" in config:
result["extra_args"]["SSECustomerKey"] = config["sse_customer_key"]
result["extra_args"]["SSECustomerAlgorithm"] = config.get(
"sse_customer_algo", "AES256"
)
return result
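
# A minimal sketch of the mapping performed by `parse_config` above, assuming
# an already-parsed config dict (bucket/region values are illustrative only):
#
#     config = {
#         "bucket": "my-matrix-media",
#         "region_name": "eu-west-1",
#         "storage_class": "STANDARD_IA",
#     }
#
# would be turned into roughly:
#
#     {
#         "bucket": "my-matrix-media",
#         "prefix": "",
#         "extra_args": {"StorageClass": "STANDARD_IA"},
#         "region_name": "eu-west-1",
#     }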
def s3_download_task(s3_client, bucket, key, extra_args, deferred, parent_logcontext):
"""Attempts to download a file from S3.
Args:
s3_client: boto3 s3 client
bucket (str): The S3 bucket which may have the file
key (str): The key of the file
extra_args (dict): Extra arguments to pass to the S3 request, e.g. the
SSE-C key and algorithm when customer-provided encryption is in use
deferred (Deferred[_S3Responder|None]): Resolved with an _S3Responder
instance if the file exists, or with None if it does not.
parent_logcontext (LoggingContext): the logcontext to report logs and metrics
against.
"""
with LoggingContext(parent_context=parent_logcontext):
logger.info("Fetching %s from S3", key)
try:
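# Objects stored with SSE-C can only be retrieved by supplying the same
# customer-provided key and algorithm on the GET request, so pass them
# through when they were configured.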
if "SSECustomerKey" in extra_args and "SSECustomerAlgorithm" in extra_args:
resp = s3_client.get_object(
Bucket=bucket,
Key=key,
SSECustomerKey=extra_args["SSECustomerKey"],
SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
)
else:
resp = s3_client.get_object(Bucket=bucket, Key=key)
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] in ("404", "NoSuchKey",):
logger.info("Media %s not found in S3", key)
reactor.callFromThread(deferred.callback, None)
return
reactor.callFromThread(deferred.errback, Failure())
return
producer = _S3Responder()
reactor.callFromThread(deferred.callback, producer)
_stream_to_producer(reactor, producer, resp["Body"], timeout=90.0)
def _stream_to_producer(reactor, producer, body, status=None, timeout=None):
"""Streams a file like object to the producer.
Correctly handles producer being paused/resumed/stopped.
Args:
reactor
producer (_S3Responder): Producer object to stream results to
body (file like): The object to read from
status (_ProducerStatus|None): Used to track whether we're currently
paused or not. Used for testing
timeout (float|None): Timeout in seconds to wait for consume to resume
after being paused
"""
# Set when we should be producing, cleared when we are paused
wakeup_event = producer.wakeup_event
# Set if we should stop producing forever
stop_event = producer.stop_event
if not status:
status = _ProducerStatus()
try:
while not stop_event.is_set():
# We wait for the producer to signal that the consumer wants
# more data (or we should abort)
if not wakeup_event.is_set():
status.set_paused(True)
ret = wakeup_event.wait(timeout)
if not ret:
raise Exception("Timed out waiting to resume")
status.set_paused(False)
# Check if we were woken up so that we abort the download
if stop_event.is_set():
return
chunk = body.read(READ_CHUNK_SIZE)
if not chunk:
return
reactor.callFromThread(producer._write, chunk)
except Exception:
reactor.callFromThread(producer._error, Failure())
finally:
reactor.callFromThread(producer._finish)
if body:
body.close()
class _S3Responder(Responder):
"""A Responder for S3 objects. Created by s3_download_task.
"""
def __init__(self):
# Triggered by responder when more data has been requested (or
# stop_event has been triggered)
self.wakeup_event = threading.Event()
# Triggered by responder when we should abort the download.
self.stop_event = threading.Event()
# The consumer we're registered to
self.consumer = None
# The deferred returned by write_to_consumer, which should resolve when
# all the data has been written (or there has been a fatal error).
self.deferred = defer.Deferred()
def write_to_consumer(self, consumer):
"""See Responder.write_to_consumer
"""
self.consumer = consumer
# We are an IPushProducer, so we start producing immediately until we
# get a pauseProducing or stopProducing
consumer.registerProducer(self, True)
self.wakeup_event.set()
return make_deferred_yieldable(self.deferred)
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_event.set()
self.wakeup_event.set()
def resumeProducing(self):
"""See IPushProducer.resumeProducing
"""
# The consumer is asking for more data, signal _S3DownloadThread
self.wakeup_event.set()
def pauseProducing(self):
"""See IPushProducer.pauseProducing
"""
self.wakeup_event.clear()
def stopProducing(self):
"""See IPushProducer.stopProducing
"""
# The consumer wants no more data ever, signal _S3DownloadThread
self.stop_event.set()
self.wakeup_event.set()
if not self.deferred.called:
self.deferred.errback(Exception("Consumer asked to stop producing"))
def _write(self, chunk):
"""Writes the chunk of data to the consumer. Called by the S3 download thread.
"""
if self.consumer and not self.stop_event.is_set():
self.consumer.write(chunk)
def _error(self, failure):
"""Called when a fatal error occurred while getting data. Called by
the S3 download thread.
"""
if self.consumer:
self.consumer.unregisterProducer()
self.consumer = None
if not self.deferred.called:
self.deferred.errback(failure)
def _finish(self):
"""Called when there is no more data to write. Called by the S3 download
thread.
"""
if self.consumer:
self.consumer.unregisterProducer()
self.consumer = None
if not self.deferred.called:
self.deferred.callback(None)
class _ProducerStatus(object):
"""Used to track whether the s3 download thread is currently paused
waiting for consumer to resume. Used for testing.
"""
def __init__(self):
self.is_paused = threading.Event()
self.is_paused.clear()
def wait_until_paused(self, timeout=None):
is_paused = self.is_paused.wait(timeout)
if not is_paused:
raise Exception("Timed out waiting")
def set_paused(self, paused):
if paused:
self.is_paused.set()
else:
self.is_paused.clear()