# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from twisted.internet import defer, reactor
from twisted.python.failure import Failure

from synapse.rest.media.v1.storage_provider import StorageProvider
from synapse.rest.media.v1._base import Responder

import boto3
import botocore
import logging
import threading


logger = logging.getLogger("synapse.s3")


# Python 2/3 compatibility: `basestring` only exists on Python 2.
try:
    string_types = basestring  # noqa: F821
except NameError:
    string_types = str


class S3StorageProviderBackend(StorageProvider):
    """A storage provider that fetches media from an S3 bucket.

    Args:
        hs (HomeServer)
        config: The config returned by `parse_config`, i.e. the bucket name.
    """

    def __init__(self, hs, config):
        self.cache_directory = hs.config.media_store_path
        self.bucket = config

        # Default client using ambient credentials; only used on the
        # reactor thread. Download threads create their own session/client
        # because boto3 clients are not guaranteed thread-safe.
        self.s3 = boto3.client('s3')

    def store_file(self, path, file_info):
        """See StorageProvider.store_file"""

        # Uploading to S3 is not yet implemented; media is only fetched.
        pass

    def fetch(self, path, file_info):
        """See StorageProvider.fetch

        Returns:
            Deferred[_S3Responder|None]: Resolves with a responder if the
            key exists in the bucket, or with None if it does not.
        """
        d = defer.Deferred()
        # The actual download happens on a background thread so we don't
        # block the reactor on S3 network I/O.
        _S3DownloadThread(self.bucket, path, d).start()
        return d

    @staticmethod
    def parse_config(config):
        """Called on startup to parse config supplied. This should parse
        the config and raise if there is a problem.

        The returned value is passed into the constructor.

        In this case we only care about a single param, the bucket, so lets
        just pull that out.
        """
        bucket = config["bucket"]
        # Raise (rather than assert) so the check survives `python -O`.
        if not isinstance(bucket, string_types):
            raise Exception("'bucket' must be a string")
        return bucket


class _S3Responder(Responder):
    """A Responder for S3. Created by _S3DownloadThread

    Args:
        wakeup_event (threading.Event): Used to signal to _S3DownloadThread
            that consumer is ready for more data (or that we've triggered
            stop_event).
        stop_event (threading.Event): Used to signal to _S3DownloadThread that
            it should stop producing. `wakeup_event` must also be set if
            `stop_event` is used.
    """
    def __init__(self, wakeup_event, stop_event):
        self.wakeup_event = wakeup_event
        self.stop_event = stop_event

        # The consumer we're registered to
        self.consumer = None

        # The deferred returned by write_to_consumer, which should resolve when
        # all the data has been written (or there has been a fatal error).
        self.deferred = defer.Deferred()

    def write_to_consumer(self, consumer):
        """See Responder.write_to_consumer
        """
        self.consumer = consumer
        # We are a IPullProducer, so we expect consumer to call resumeProducing
        # each time they want a new chunk of data.
        consumer.registerProducer(self, False)
        return self.deferred

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Leaving the context manager aborts the download: wake the
        # download thread so it notices stop_event and exits.
        self.stop_event.set()
        self.wakeup_event.set()

    def resumeProducing(self):
        """See IPullProducer.resumeProducing
        """
        # The consumer is asking for more data, signal _S3DownloadThread
        self.wakeup_event.set()

    def stopProducing(self):
        """See IPullProducer.stopProducing
        """
        # The consumer wants no more data ever, signal _S3DownloadThread
        self.stop_event.set()
        self.wakeup_event.set()
        self.deferred.errback(Exception("Consumer ask to stop producing"))

    def _write(self, chunk):
        """Writes the chunk of data to consumer. Called by _S3DownloadThread.
        """
        # Drop the chunk if the consumer has already asked us to stop.
        if self.consumer and not self.stop_event.is_set():
            self.consumer.write(chunk)

    def _error(self, failure):
        """Called when a fatal error occured while getting data. Called by
        _S3DownloadThread.
        """
        if self.consumer:
            self.consumer.unregisterProducer()
            self.consumer = None

        # Guard against double-firing: stopProducing may already have
        # errbacked the deferred.
        if not self.deferred.called:
            self.deferred.errback(failure)

    def _finish(self):
        """Called when there is no more data to write. Called by _S3DownloadThread.
        """
        if self.consumer:
            self.consumer.unregisterProducer()
            self.consumer = None

        if not self.deferred.called:
            self.deferred.callback(None)


class _S3DownloadThread(threading.Thread):
    """Attempts to download a file from S3.

    Args:
        bucket (str): The S3 bucket which may have the file
        key (str): The key of the file
        deferred (Deferred[_S3Responder|None]): If files exists
            resolved with an _S3Responder instance, if it doesn't
            exist then resolves with None.

    Attributes:
        READ_CHUNK_SIZE (int): The chunk size in bytes used when downloading
            file.
    """

    READ_CHUNK_SIZE = 16 * 1024

    def __init__(self, bucket, key, deferred):
        super(_S3DownloadThread, self).__init__(name="s3-download")
        self.bucket = bucket
        self.key = key
        self.deferred = deferred

    def run(self):
        # Each thread gets its own session/client as boto3 clients are not
        # guaranteed to be thread-safe.
        session = boto3.session.Session()
        s3 = session.client('s3')

        try:
            resp = s3.get_object(Bucket=self.bucket, Key=self.key)
        except botocore.exceptions.ClientError as e:
            if e.response['Error']['Code'] == "404":
                # Missing key is not an error: resolve with None so the
                # caller can fall back to other storage providers.
                reactor.callFromThread(self.deferred.callback, None)
                return

            reactor.callFromThread(self.deferred.errback, Failure())
            return

        # Triggered by responder when more data has been requested (or
        # stop_event has been triggered)
        wakeup_event = threading.Event()
        # Triggered by responder when we should abort the download.
        stop_event = threading.Event()

        producer = _S3Responder(wakeup_event, stop_event)
        reactor.callFromThread(self.deferred.callback, producer)

        # Initialise before the try so the finally clause never hits a
        # NameError if fetching the body itself raises.
        body = None
        try:
            body = resp["Body"]

            while not stop_event.is_set():
                # We wait for the producer to signal that the consumer wants
                # more data (or we should abort)
                wakeup_event.wait()

                # Check if we were woken up so that we abort the download
                if stop_event.is_set():
                    return

                chunk = body.read(self.READ_CHUNK_SIZE)
                if not chunk:
                    return

                # We clear the wakeup_event flag just before we write the data
                # to producer.
                wakeup_event.clear()
                reactor.callFromThread(producer._write, chunk)

        except Exception:
            reactor.callFromThread(producer._error, Failure())
            return
        finally:
            reactor.callFromThread(producer._finish)
            if body is not None:
                body.close()