mirror of
https://github.com/matrix-org/synapse-s3-storage-provider.git
synced 2024-10-23 07:29:40 +00:00
Merge pull request #9 from matrix-org/erikj/push_producer
Convert to being PushProducer and add tests
This commit is contained in:
commit
2af74f2e28
3 changed files with 286 additions and 53 deletions
13
.travis.yml
Normal file
13
.travis.yml
Normal file
|
@ -0,0 +1,13 @@
|
|||
language: python
|
||||
python:
|
||||
- "2.7"
|
||||
- "3.5"
|
||||
- "3.6"
|
||||
# command to install dependencies
|
||||
install:
|
||||
- pip install twisted
|
||||
- pip install boto3
|
||||
- pip install https://github.com/matrix-org/synapse/tarball/master
|
||||
# command to run tests
|
||||
script:
|
||||
- PYTHONPATH=. trial test_s3
|
|
@ -34,6 +34,9 @@ logger = logging.getLogger("synapse.s3")
|
|||
# The list of valid AWS storage class names
|
||||
_VALID_STORAGE_CLASSES = ('STANDARD', 'REDUCED_REDUNDANCY', 'STANDARD_IA')
|
||||
|
||||
# Chunk size to use when reading from s3 connection in bytes
|
||||
READ_CHUNK_SIZE = 16 * 1024
|
||||
|
||||
|
||||
class S3StorageProviderBackend(StorageProvider):
|
||||
"""
|
||||
|
@ -97,14 +100,8 @@ class _S3DownloadThread(threading.Thread):
|
|||
deferred (Deferred[_S3Responder|None]): If file exists
|
||||
resolved with an _S3Responder instance, if it doesn't
|
||||
exist then resolves with None.
|
||||
|
||||
Attributes:
|
||||
READ_CHUNK_SIZE (int): The chunk size in bytes used when downloading
|
||||
file.
|
||||
"""
|
||||
|
||||
READ_CHUNK_SIZE = 16 * 1024
|
||||
|
||||
def __init__(self, bucket, key, deferred):
|
||||
super(_S3DownloadThread, self).__init__(name="s3-download")
|
||||
self.bucket = bucket
|
||||
|
@ -125,39 +122,58 @@ class _S3DownloadThread(threading.Thread):
|
|||
reactor.callFromThread(self.deferred.errback, Failure())
|
||||
return
|
||||
|
||||
# Triggered by responder when more data has been requested (or
|
||||
# stop_event has been triggered)
|
||||
wakeup_event = threading.Event()
|
||||
# Trigered by responder when we should abort the download.
|
||||
stop_event = threading.Event()
|
||||
|
||||
producer = _S3Responder(wakeup_event, stop_event)
|
||||
producer = _S3Responder()
|
||||
reactor.callFromThread(self.deferred.callback, producer)
|
||||
_stream_to_producer(reactor, producer, resp["Body"], timeout=90.)
|
||||
|
||||
|
||||
def _stream_to_producer(reactor, producer, body, status=None, timeout=None):
|
||||
"""Streams a file like object to the producer.
|
||||
|
||||
Correctly handles producer being paused/resumed/stopped.
|
||||
|
||||
Args:
|
||||
reactor
|
||||
producer (_S3Responder): Producer object to stream results to
|
||||
body (file like): The object to read from
|
||||
status (_ProducerStatus|None): Used to track whether we're currently
|
||||
paused or not. Used for testing
|
||||
timeout (float|None): Timeout in seconds to wait for consume to resume
|
||||
after being paused
|
||||
"""
|
||||
|
||||
# Set when we should be producing, cleared when we are paused
|
||||
wakeup_event = producer.wakeup_event
|
||||
|
||||
# Set if we should stop producing forever
|
||||
stop_event = producer.stop_event
|
||||
|
||||
if not status:
|
||||
status = _ProducerStatus()
|
||||
|
||||
try:
|
||||
body = resp["Body"]
|
||||
|
||||
while not stop_event.is_set():
|
||||
# We wait for the producer to signal that the consumer wants
|
||||
# more data (or we should abort)
|
||||
wakeup_event.wait()
|
||||
if not wakeup_event.is_set():
|
||||
status.set_paused(True)
|
||||
ret = wakeup_event.wait(timeout)
|
||||
if not ret:
|
||||
raise Exception("Timed out waiting to resume")
|
||||
status.set_paused(False)
|
||||
|
||||
# Check if we were woken up so that we abort the download
|
||||
if stop_event.is_set():
|
||||
return
|
||||
|
||||
chunk = body.read(self.READ_CHUNK_SIZE)
|
||||
chunk = body.read(READ_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
return
|
||||
|
||||
# We clear the wakeup_event flag just before we write the data
|
||||
# to producer.
|
||||
wakeup_event.clear()
|
||||
reactor.callFromThread(producer._write, chunk)
|
||||
|
||||
except Exception:
|
||||
reactor.callFromThread(producer._error, Failure())
|
||||
return
|
||||
finally:
|
||||
reactor.callFromThread(producer._finish)
|
||||
if body:
|
||||
|
@ -166,18 +182,13 @@ class _S3DownloadThread(threading.Thread):
|
|||
|
||||
class _S3Responder(Responder):
|
||||
"""A Responder for S3. Created by _S3DownloadThread
|
||||
|
||||
Args:
|
||||
wakeup_event (threading.Event): Used to signal to _S3DownloadThread
|
||||
that consumer is ready for more data (or that we've triggered
|
||||
stop_event).
|
||||
stop_event (threading.Event): Used to signal to _S3DownloadThread that
|
||||
it should stop producing. `wakeup_event` must also be set if
|
||||
`stop_event` is used.
|
||||
"""
|
||||
def __init__(self, wakeup_event, stop_event):
|
||||
self.wakeup_event = wakeup_event
|
||||
self.stop_event = stop_event
|
||||
def __init__(self):
|
||||
# Triggered by responder when more data has been requested (or
|
||||
# stop_event has been triggered)
|
||||
self.wakeup_event = threading.Event()
|
||||
# Trigered by responder when we should abort the download.
|
||||
self.stop_event = threading.Event()
|
||||
|
||||
# The consumer we're registered to
|
||||
self.consumer = None
|
||||
|
@ -190,9 +201,10 @@ class _S3Responder(Responder):
|
|||
"""See Responder.write_to_consumer
|
||||
"""
|
||||
self.consumer = consumer
|
||||
# We are a IPullProducer, so we expect consumer to call resumeProducing
|
||||
# each time they want a new chunk of data.
|
||||
consumer.registerProducer(self, False)
|
||||
# We are a IPushProducer, so we start producing immediately until we
|
||||
# get a pauseProducing or stopProducing
|
||||
consumer.registerProducer(self, True)
|
||||
self.wakeup_event.set()
|
||||
return self.deferred
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
|
@ -200,17 +212,23 @@ class _S3Responder(Responder):
|
|||
self.wakeup_event.set()
|
||||
|
||||
def resumeProducing(self):
|
||||
"""See IPullProducer.resumeProducing
|
||||
"""See IPushProducer.resumeProducing
|
||||
"""
|
||||
# The consumer is asking for more data, signal _S3DownloadThread
|
||||
self.wakeup_event.set()
|
||||
|
||||
def pauseProducing(self):
|
||||
"""See IPushProducer.stopProducing
|
||||
"""
|
||||
self.wakeup_event.clear()
|
||||
|
||||
def stopProducing(self):
|
||||
"""See IPullProducer.stopProducing
|
||||
"""See IPushProducer.stopProducing
|
||||
"""
|
||||
# The consumer wants no more data ever, signal _S3DownloadThread
|
||||
self.stop_event.set()
|
||||
self.wakeup_event.set()
|
||||
if not self.deferred.called:
|
||||
self.deferred.errback(Exception("Consumer ask to stop producing"))
|
||||
|
||||
def _write(self, chunk):
|
||||
|
@ -239,3 +257,24 @@ class _S3Responder(Responder):
|
|||
|
||||
if not self.deferred.called:
|
||||
self.deferred.callback(None)
|
||||
|
||||
|
||||
class _ProducerStatus(object):
|
||||
"""Used to track whether the s3 download thread is currently paused
|
||||
waiting for consumer to resume. Used for testing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.is_paused = threading.Event()
|
||||
self.is_paused.clear()
|
||||
|
||||
def wait_until_paused(self, timeout=None):
|
||||
is_paused = self.is_paused.wait(timeout)
|
||||
if not is_paused:
|
||||
raise Exception("Timed out waiting")
|
||||
|
||||
def set_paused(self, paused):
|
||||
if paused:
|
||||
self.is_paused.set()
|
||||
else:
|
||||
self.is_paused.clear()
|
||||
|
|
181
test_s3.py
Normal file
181
test_s3.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.test.proto_helpers import MemoryReactorClock
|
||||
from twisted.trial import unittest
|
||||
|
||||
from queue import Queue
|
||||
from threading import Event, Thread
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from s3_storage_provider import (
|
||||
_stream_to_producer, _S3Responder, _ProducerStatus,
|
||||
)
|
||||
|
||||
|
||||
class StreamingProducerTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
|
||||
self.body = Channel()
|
||||
self.consumer = Mock()
|
||||
self.written = ""
|
||||
|
||||
def write(data):
|
||||
self.written += data
|
||||
|
||||
self.consumer.write.side_effect = write
|
||||
|
||||
self.producer_status = _ProducerStatus()
|
||||
self.producer = _S3Responder()
|
||||
self.thread = Thread(
|
||||
target=_stream_to_producer,
|
||||
args=(self.reactor, self.producer, self.body),
|
||||
kwargs={
|
||||
"status": self.producer_status,
|
||||
"timeout": 1.0,
|
||||
},
|
||||
)
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def tearDown(self):
|
||||
# Really ensure that we've stopped the thread
|
||||
self.producer.stopProducing()
|
||||
|
||||
def test_simple_produce(self):
|
||||
deferred = self.producer.write_to_consumer(self.consumer)
|
||||
|
||||
self.body.write("test")
|
||||
self.wait_for_thread()
|
||||
self.assertEqual("test", self.written)
|
||||
|
||||
self.body.write(" string")
|
||||
self.wait_for_thread()
|
||||
self.assertEqual("test string", self.written)
|
||||
|
||||
self.body.finish()
|
||||
self.wait_for_thread()
|
||||
|
||||
self.assertTrue(deferred.called)
|
||||
self.assertEqual(deferred.result, None)
|
||||
|
||||
def test_pause_produce(self):
|
||||
deferred = self.producer.write_to_consumer(self.consumer)
|
||||
|
||||
self.body.write("test")
|
||||
self.wait_for_thread()
|
||||
self.assertEqual("test", self.written)
|
||||
|
||||
# We pause producing, but the thread will currently be blocked waiting
|
||||
# to read data, so we wake it up by writing before asserting that
|
||||
# it actually pauses.
|
||||
self.producer.pauseProducing()
|
||||
self.body.write(" string")
|
||||
self.wait_for_thread()
|
||||
self.producer_status.wait_until_paused(10.)
|
||||
self.assertEqual("test string", self.written)
|
||||
|
||||
# If we write again we remain paused and nothing gets written
|
||||
self.body.write(" second")
|
||||
self.producer_status.wait_until_paused(10.)
|
||||
self.assertEqual("test string", self.written)
|
||||
|
||||
# If we call resumeProducing the buffered data gets read and written.
|
||||
self.producer.resumeProducing()
|
||||
self.wait_for_thread()
|
||||
self.assertEqual("test string second", self.written)
|
||||
|
||||
# We can continue writing as normal now
|
||||
self.body.write(" third")
|
||||
self.wait_for_thread()
|
||||
self.assertEqual("test string second third", self.written)
|
||||
|
||||
self.body.finish()
|
||||
self.wait_for_thread()
|
||||
|
||||
self.assertTrue(deferred.called)
|
||||
self.assertEqual(deferred.result, None)
|
||||
|
||||
def test_error(self):
|
||||
deferred = self.producer.write_to_consumer(self.consumer)
|
||||
|
||||
self.body.write("test")
|
||||
self.wait_for_thread()
|
||||
self.assertEqual("test", self.written)
|
||||
|
||||
excp = Exception("Test Exception")
|
||||
self.body.error(excp)
|
||||
self.wait_for_thread()
|
||||
|
||||
self.assertTrue(deferred.called)
|
||||
self.assertIsInstance(deferred.result, Failure)
|
||||
|
||||
def wait_for_thread(self):
|
||||
"""Wait for something to call `callFromThread` and advance reactor
|
||||
"""
|
||||
self.reactor.thread_event.wait(1)
|
||||
self.reactor.thread_event.clear()
|
||||
self.reactor.advance(0)
|
||||
|
||||
|
||||
class ThreadedMemoryReactorClock(MemoryReactorClock):
|
||||
"""
|
||||
A MemoryReactorClock that supports callFromThread.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(ThreadedMemoryReactorClock, self).__init__()
|
||||
self.thread_event = Event()
|
||||
|
||||
def callFromThread(self, callback, *args, **kwargs):
|
||||
"""
|
||||
Make the callback fire in the next reactor iteration.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
d.addCallback(lambda x: callback(*args, **kwargs))
|
||||
self.callLater(0, d.callback, True)
|
||||
|
||||
self.thread_event.set()
|
||||
|
||||
return d
|
||||
|
||||
|
||||
class Channel(object):
|
||||
"""Simple channel to mimic a thread safe file like object
|
||||
"""
|
||||
def __init__(self):
|
||||
self._queue = Queue()
|
||||
|
||||
def read(self, _):
|
||||
val = self._queue.get()
|
||||
if isinstance(val, Exception):
|
||||
raise val
|
||||
return val
|
||||
|
||||
def write(self, val):
|
||||
self._queue.put(val)
|
||||
|
||||
def error(self, err):
|
||||
self._queue.put(err)
|
||||
|
||||
def finish(self):
|
||||
self._queue.put(None)
|
||||
|
||||
def close(self):
|
||||
pass
|
Loading…
Reference in a new issue