diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..7d99673 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,13 @@ +language: python +python: + - "2.7" + - "3.5" + - "3.6" +# command to install dependencies +install: + - pip install twisted + - pip install boto3 + - pip install https://github.com/matrix-org/synapse/tarball/master +# command to run tests +script: + - PYTHONPATH=. trial test_s3 diff --git a/s3_storage_provider.py b/s3_storage_provider.py index 1ec4951..ac75514 100644 --- a/s3_storage_provider.py +++ b/s3_storage_provider.py @@ -34,6 +34,9 @@ logger = logging.getLogger("synapse.s3") # The list of valid AWS storage class names _VALID_STORAGE_CLASSES = ('STANDARD', 'REDUCED_REDUNDANCY', 'STANDARD_IA') +# Chunk size to use when reading from s3 connection in bytes +READ_CHUNK_SIZE = 16 * 1024 + class S3StorageProviderBackend(StorageProvider): """ @@ -97,14 +100,8 @@ class _S3DownloadThread(threading.Thread): deferred (Deferred[_S3Responder|None]): If file exists resolved with an _S3Responder instance, if it doesn't exist then resolves with None. - - Attributes: - READ_CHUNK_SIZE (int): The chunk size in bytes used when downloading - file. """ - READ_CHUNK_SIZE = 16 * 1024 - def __init__(self, bucket, key, deferred): super(_S3DownloadThread, self).__init__(name="s3-download") self.bucket = bucket @@ -125,59 +122,73 @@ class _S3DownloadThread(threading.Thread): reactor.callFromThread(self.deferred.errback, Failure()) return - # Triggered by responder when more data has been requested (or - # stop_event has been triggered) - wakeup_event = threading.Event() - # Trigered by responder when we should abort the download. - stop_event = threading.Event() - - producer = _S3Responder(wakeup_event, stop_event) + producer = _S3Responder() reactor.callFromThread(self.deferred.callback, producer) + _stream_to_producer(reactor, producer, resp["Body"], timeout=90.) - try: - body = resp["Body"] - while not stop_event.is_set(): - # We wait for the producer to signal that the consumer wants - # more data (or we should abort) - wakeup_event.wait() +def _stream_to_producer(reactor, producer, body, status=None, timeout=None): + """Streams a file like object to the producer. - # Check if we were woken up so that we abort the download - if stop_event.is_set(): - return + Correctly handles producer being paused/resumed/stopped. - chunk = body.read(self.READ_CHUNK_SIZE) - if not chunk: - return + Args: + reactor + producer (_S3Responder): Producer object to stream results to + body (file like): The object to read from + status (_ProducerStatus|None): Used to track whether we're currently + paused or not. Used for testing + timeout (float|None): Timeout in seconds to wait for consume to resume + after being paused + """ - # We clear the wakeup_event flag just before we write the data - # to producer. - wakeup_event.clear() - reactor.callFromThread(producer._write, chunk) + # Set when we should be producing, cleared when we are paused + wakeup_event = producer.wakeup_event - except Exception: - reactor.callFromThread(producer._error, Failure()) - return - finally: - reactor.callFromThread(producer._finish) - if body: - body.close() + # Set if we should stop producing forever + stop_event = producer.stop_event + + if not status: + status = _ProducerStatus() + + try: + while not stop_event.is_set(): + # We wait for the producer to signal that the consumer wants + # more data (or we should abort) + if not wakeup_event.is_set(): + status.set_paused(True) + ret = wakeup_event.wait(timeout) + if not ret: + raise Exception("Timed out waiting to resume") + status.set_paused(False) + + # Check if we were woken up so that we abort the download + if stop_event.is_set(): + return + + chunk = body.read(READ_CHUNK_SIZE) + if not chunk: + return + + reactor.callFromThread(producer._write, chunk) + + except Exception: + reactor.callFromThread(producer._error, Failure()) + finally: + reactor.callFromThread(producer._finish) + if body: + body.close() class _S3Responder(Responder): """A Responder for S3. Created by _S3DownloadThread - - Args: - wakeup_event (threading.Event): Used to signal to _S3DownloadThread - that consumer is ready for more data (or that we've triggered - stop_event). - stop_event (threading.Event): Used to signal to _S3DownloadThread that - it should stop producing. `wakeup_event` must also be set if - `stop_event` is used. """ - def __init__(self, wakeup_event, stop_event): - self.wakeup_event = wakeup_event - self.stop_event = stop_event + def __init__(self): + # Triggered by responder when more data has been requested (or + # stop_event has been triggered) + self.wakeup_event = threading.Event() + # Trigered by responder when we should abort the download. + self.stop_event = threading.Event() # The consumer we're registered to self.consumer = None @@ -190,9 +201,10 @@ class _S3Responder(Responder): """See Responder.write_to_consumer """ self.consumer = consumer - # We are a IPullProducer, so we expect consumer to call resumeProducing - # each time they want a new chunk of data. - consumer.registerProducer(self, False) + # We are a IPushProducer, so we start producing immediately until we + # get a pauseProducing or stopProducing + consumer.registerProducer(self, True) + self.wakeup_event.set() return self.deferred def __exit__(self, exc_type, exc_val, exc_tb): @@ -200,18 +212,24 @@ class _S3Responder(Responder): self.wakeup_event.set() def resumeProducing(self): - """See IPullProducer.resumeProducing + """See IPushProducer.resumeProducing """ # The consumer is asking for more data, signal _S3DownloadThread self.wakeup_event.set() + def pauseProducing(self): + """See IPushProducer.stopProducing + """ + self.wakeup_event.clear() + def stopProducing(self): - """See IPullProducer.stopProducing + """See IPushProducer.stopProducing """ # The consumer wants no more data ever, signal _S3DownloadThread self.stop_event.set() self.wakeup_event.set() - self.deferred.errback(Exception("Consumer ask to stop producing")) + if not self.deferred.called: + self.deferred.errback(Exception("Consumer ask to stop producing")) def _write(self, chunk): """Writes the chunk of data to consumer. Called by _S3DownloadThread. @@ -239,3 +257,24 @@ class _S3Responder(Responder): if not self.deferred.called: self.deferred.callback(None) + + +class _ProducerStatus(object): + """Used to track whether the s3 download thread is currently paused + waiting for consumer to resume. Used for testing. + """ + + def __init__(self): + self.is_paused = threading.Event() + self.is_paused.clear() + + def wait_until_paused(self, timeout=None): + is_paused = self.is_paused.wait(timeout) + if not is_paused: + raise Exception("Timed out waiting") + + def set_paused(self, paused): + if paused: + self.is_paused.set() + else: + self.is_paused.clear() diff --git a/test_s3.py b/test_s3.py new file mode 100644 index 0000000..379d381 --- /dev/null +++ b/test_s3.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactorClock +from twisted.trial import unittest + +from queue import Queue +from threading import Event, Thread + +from mock import Mock + +from s3_storage_provider import ( + _stream_to_producer, _S3Responder, _ProducerStatus, +) + + +class StreamingProducerTestCase(unittest.TestCase): + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + + self.body = Channel() + self.consumer = Mock() + self.written = "" + + def write(data): + self.written += data + + self.consumer.write.side_effect = write + + self.producer_status = _ProducerStatus() + self.producer = _S3Responder() + self.thread = Thread( + target=_stream_to_producer, + args=(self.reactor, self.producer, self.body), + kwargs={ + "status": self.producer_status, + "timeout": 1.0, + }, + ) + self.thread.daemon = True + self.thread.start() + + def tearDown(self): + # Really ensure that we've stopped the thread + self.producer.stopProducing() + + def test_simple_produce(self): + deferred = self.producer.write_to_consumer(self.consumer) + + self.body.write("test") + self.wait_for_thread() + self.assertEqual("test", self.written) + + self.body.write(" string") + self.wait_for_thread() + self.assertEqual("test string", self.written) + + self.body.finish() + self.wait_for_thread() + + self.assertTrue(deferred.called) + self.assertEqual(deferred.result, None) + + def test_pause_produce(self): + deferred = self.producer.write_to_consumer(self.consumer) + + self.body.write("test") + self.wait_for_thread() + self.assertEqual("test", self.written) + + # We pause producing, but the thread will currently be blocked waiting + # to read data, so we wake it up by writing before asserting that + # it actually pauses. + self.producer.pauseProducing() + self.body.write(" string") + self.wait_for_thread() + self.producer_status.wait_until_paused(10.) + self.assertEqual("test string", self.written) + + # If we write again we remain paused and nothing gets written + self.body.write(" second") + self.producer_status.wait_until_paused(10.) + self.assertEqual("test string", self.written) + + # If we call resumeProducing the buffered data gets read and written. + self.producer.resumeProducing() + self.wait_for_thread() + self.assertEqual("test string second", self.written) + + # We can continue writing as normal now + self.body.write(" third") + self.wait_for_thread() + self.assertEqual("test string second third", self.written) + + self.body.finish() + self.wait_for_thread() + + self.assertTrue(deferred.called) + self.assertEqual(deferred.result, None) + + def test_error(self): + deferred = self.producer.write_to_consumer(self.consumer) + + self.body.write("test") + self.wait_for_thread() + self.assertEqual("test", self.written) + + excp = Exception("Test Exception") + self.body.error(excp) + self.wait_for_thread() + + self.assertTrue(deferred.called) + self.assertIsInstance(deferred.result, Failure) + + def wait_for_thread(self): + """Wait for something to call `callFromThread` and advance reactor + """ + self.reactor.thread_event.wait(1) + self.reactor.thread_event.clear() + self.reactor.advance(0) + + +class ThreadedMemoryReactorClock(MemoryReactorClock): + """ + A MemoryReactorClock that supports callFromThread. + """ + + def __init__(self): + super(ThreadedMemoryReactorClock, self).__init__() + self.thread_event = Event() + + def callFromThread(self, callback, *args, **kwargs): + """ + Make the callback fire in the next reactor iteration. + """ + d = defer.Deferred() + d.addCallback(lambda x: callback(*args, **kwargs)) + self.callLater(0, d.callback, True) + + self.thread_event.set() + + return d + + +class Channel(object): + """Simple channel to mimic a thread safe file like object + """ + def __init__(self): + self._queue = Queue() + + def read(self, _): + val = self._queue.get() + if isinstance(val, Exception): + raise val + return val + + def write(self, val): + self._queue.put(val) + + def error(self, err): + self._queue.put(err) + + def finish(self): + self._queue.put(None) + + def close(self): + pass