diff --git a/test_s3.py b/test_s3.py new file mode 100644 index 0000000..379d381 --- /dev/null +++ b/test_s3.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactorClock +from twisted.trial import unittest + +from queue import Queue +from threading import Event, Thread + +from mock import Mock + +from s3_storage_provider import ( + _stream_to_producer, _S3Responder, _ProducerStatus, +) + + +class StreamingProducerTestCase(unittest.TestCase): + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + + self.body = Channel() + self.consumer = Mock() + self.written = "" + + def write(data): + self.written += data + + self.consumer.write.side_effect = write + + self.producer_status = _ProducerStatus() + self.producer = _S3Responder() + self.thread = Thread( + target=_stream_to_producer, + args=(self.reactor, self.producer, self.body), + kwargs={ + "status": self.producer_status, + "timeout": 1.0, + }, + ) + self.thread.daemon = True + self.thread.start() + + def tearDown(self): + # Really ensure that we've stopped the thread + self.producer.stopProducing() + + def test_simple_produce(self): + deferred = self.producer.write_to_consumer(self.consumer) + + self.body.write("test") + self.wait_for_thread() + self.assertEqual("test", self.written) + + self.body.write(" string") + self.wait_for_thread() + self.assertEqual("test string", self.written) + + self.body.finish() + self.wait_for_thread() + + self.assertTrue(deferred.called) + self.assertEqual(deferred.result, None) + + def test_pause_produce(self): + deferred = self.producer.write_to_consumer(self.consumer) + + self.body.write("test") + self.wait_for_thread() + self.assertEqual("test", self.written) + + # We pause producing, but the thread will currently be blocked waiting + # to read data, so we wake it up by writing before asserting that + # it actually pauses. + self.producer.pauseProducing() + self.body.write(" string") + self.wait_for_thread() + self.producer_status.wait_until_paused(10.) + self.assertEqual("test string", self.written) + + # If we write again we remain paused and nothing gets written + self.body.write(" second") + self.producer_status.wait_until_paused(10.) + self.assertEqual("test string", self.written) + + # If we call resumeProducing the buffered data gets read and written. + self.producer.resumeProducing() + self.wait_for_thread() + self.assertEqual("test string second", self.written) + + # We can continue writing as normal now + self.body.write(" third") + self.wait_for_thread() + self.assertEqual("test string second third", self.written) + + self.body.finish() + self.wait_for_thread() + + self.assertTrue(deferred.called) + self.assertEqual(deferred.result, None) + + def test_error(self): + deferred = self.producer.write_to_consumer(self.consumer) + + self.body.write("test") + self.wait_for_thread() + self.assertEqual("test", self.written) + + excp = Exception("Test Exception") + self.body.error(excp) + self.wait_for_thread() + + self.assertTrue(deferred.called) + self.assertIsInstance(deferred.result, Failure) + + def wait_for_thread(self): + """Wait for something to call `callFromThread` and advance reactor + """ + self.reactor.thread_event.wait(1) + self.reactor.thread_event.clear() + self.reactor.advance(0) + + +class ThreadedMemoryReactorClock(MemoryReactorClock): + """ + A MemoryReactorClock that supports callFromThread. + """ + + def __init__(self): + super(ThreadedMemoryReactorClock, self).__init__() + self.thread_event = Event() + + def callFromThread(self, callback, *args, **kwargs): + """ + Make the callback fire in the next reactor iteration. + """ + d = defer.Deferred() + d.addCallback(lambda x: callback(*args, **kwargs)) + self.callLater(0, d.callback, True) + + self.thread_event.set() + + return d + + +class Channel(object): + """Simple channel to mimic a thread safe file like object + """ + def __init__(self): + self._queue = Queue() + + def read(self, _): + val = self._queue.get() + if isinstance(val, Exception): + raise val + return val + + def write(self, val): + self._queue.put(val) + + def error(self, err): + self._queue.put(err) + + def finish(self): + self._queue.put(None) + + def close(self): + pass