Merge pull request #16 from Sharparam/feature/improvements

Generalize and improve threading
This commit is contained in:
Erik Johnston 2019-05-22 14:32:50 +01:00 committed by GitHub
commit 197709f1f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 49 additions and 6 deletions

View file

@ -21,6 +21,12 @@ media_storage_providers:
store_synchronous: True store_synchronous: True
config: config:
bucket: <S3_BUCKET_NAME> bucket: <S3_BUCKET_NAME>
# All of the below options are optional, for use with non-AWS S3-like
# services, or to specify access tokens here instead of some external method.
region_name: <S3_REGION_NAME>
endpoint_url: <S3_LIKE_SERVICE_ENDPOINT_URL>
access_key_id: <S3_ACCESS_KEY_ID>
secret_access_key: <S3_SECRET_ACCESS_KEY>
``` ```
This module uses `boto3`, and so the credentials should be specified as This module uses `boto3`, and so the credentials should be specified as

View file

@ -50,12 +50,26 @@ class S3StorageProviderBackend(StorageProvider):
self.cache_directory = hs.config.media_store_path self.cache_directory = hs.config.media_store_path
self.bucket = config["bucket"] self.bucket = config["bucket"]
self.storage_class = config["storage_class"] self.storage_class = config["storage_class"]
self.api_kwargs = {}
if "region_name" in config:
self.api_kwargs["region_name"] = config["region_name"]
if "endpoint_url" in config:
self.api_kwargs["endpoint_url"] = config["endpoint_url"]
if "access_key_id" in config:
self.api_kwargs["aws_access_key_id"] = config["access_key_id"]
if "secret_access_key" in config:
self.api_kwargs["aws_secret_access_key"] = config["secret_access_key"]
def store_file(self, path, file_info): def store_file(self, path, file_info):
"""See StorageProvider.store_file""" """See StorageProvider.store_file"""
def _store_file(): def _store_file():
boto3.resource('s3').Bucket(self.bucket).upload_file( session = boto3.session.Session()
session.resource('s3', **self.api_kwargs).Bucket(self.bucket).upload_file(
Filename=os.path.join(self.cache_directory, path), Filename=os.path.join(self.cache_directory, path),
Key=path, Key=path,
ExtraArgs={"StorageClass": self.storage_class}, ExtraArgs={"StorageClass": self.storage_class},
@ -68,7 +82,7 @@ class S3StorageProviderBackend(StorageProvider):
def fetch(self, path, file_info): def fetch(self, path, file_info):
"""See StorageProvider.fetch""" """See StorageProvider.fetch"""
d = defer.Deferred() d = defer.Deferred()
_S3DownloadThread(self.bucket, path, d).start() _S3DownloadThread(self.bucket, self.api_kwargs, path, d).start()
return make_deferred_yieldable(d) return make_deferred_yieldable(d)
@staticmethod @staticmethod
@ -86,32 +100,49 @@ class S3StorageProviderBackend(StorageProvider):
assert isinstance(bucket, string_types) assert isinstance(bucket, string_types)
assert storage_class in _VALID_STORAGE_CLASSES assert storage_class in _VALID_STORAGE_CLASSES
return { result = {
"bucket": bucket, "bucket": bucket,
"storage_class": storage_class, "storage_class": storage_class,
} }
if "region_name" in config:
result["region_name"] = config["region_name"]
if "endpoint_url" in config:
result["endpoint_url"] = config["endpoint_url"]
if "access_key_id" in config:
result["access_key_id"] = config["access_key_id"]
if "secret_access_key" in config:
result["secret_access_key"] = config["secret_access_key"]
return result
class _S3DownloadThread(threading.Thread): class _S3DownloadThread(threading.Thread):
"""Attempts to download a file from S3. """Attempts to download a file from S3.
Args: Args:
bucket (str): The S3 bucket which may have the file bucket (str): The S3 bucket which may have the file
api_kwargs (dict): Keyword arguments to pass when invoking the API.
Generally `endpoint_url`.
key (str): The key of the file key (str): The key of the file
deferred (Deferred[_S3Responder|None]): If file exists deferred (Deferred[_S3Responder|None]): If file exists
resolved with an _S3Responder instance, if it doesn't resolved with an _S3Responder instance, if it doesn't
exist then resolves with None. exist then resolves with None.
""" """
def __init__(self, bucket, key, deferred): def __init__(self, bucket, api_kwargs, key, deferred):
super(_S3DownloadThread, self).__init__(name="s3-download") super(_S3DownloadThread, self).__init__(name="s3-download")
self.bucket = bucket self.bucket = bucket
self.api_kwargs = api_kwargs
self.key = key self.key = key
self.deferred = deferred self.deferred = deferred
def run(self): def run(self):
session = boto3.session.Session() session = boto3.session.Session()
s3 = session.client('s3') s3 = session.client('s3', **self.api_kwargs)
try: try:
resp = s3.get_object(Bucket=self.bucket, Key=self.key) resp = s3.get_object(Bucket=self.bucket, Key=self.key)

View file

@ -18,7 +18,13 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactorClock from twisted.test.proto_helpers import MemoryReactorClock
from twisted.trial import unittest from twisted.trial import unittest
import sys
is_py2 = sys.version[0] == '2'
if is_py2:
from Queue import Queue
else:
from queue import Queue from queue import Queue
from threading import Event, Thread from threading import Event, Thread
from mock import Mock from mock import Mock