Generalize for usage in other S3-like services.

This commit is contained in:
Saad Rhoulam 2018-08-21 01:02:43 -04:00
parent 54c58922de
commit 503fd7ba38

View file

@ -46,12 +46,15 @@ class S3StorageProviderBackend(StorageProvider):
self.cache_directory = hs.config.media_store_path
self.bucket = config["bucket"]
self.storage_class = config["storage_class"]
self.api_kwargs = {}
if "endpoint_url" in config:
self.api_kwargs["endpoint_url"] = config["endpoint_url"]
def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
def _store_file():
boto3.resource('s3').Bucket(self.bucket).upload_file(
boto3.resource('s3', **self.api_kwargs).Bucket(self.bucket).upload_file(
Filename=os.path.join(self.cache_directory, path),
Key=path,
ExtraArgs={"StorageClass": self.storage_class},
@ -64,7 +67,7 @@ class S3StorageProviderBackend(StorageProvider):
def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
d = defer.Deferred()
_S3DownloadThread(self.bucket, path, d).start()
_S3DownloadThread(self.bucket, self.api_kwargs, path, d).start()
return make_deferred_yieldable(d)
@staticmethod
@ -82,17 +85,23 @@ class S3StorageProviderBackend(StorageProvider):
assert isinstance(bucket, basestring)
assert storage_class in _VALID_STORAGE_CLASSES
return {
result = {
"bucket": bucket,
"storage_class": storage_class,
}
if "endpoint_url" in config:
result["endpoint_url"] = config["endpoint_url"]
return result
class _S3DownloadThread(threading.Thread):
"""Attempts to download a file from S3.
Args:
bucket (str): The S3 bucket which may have the file
api_kwargs (dict): Keyword arguments to pass when invoking the API. Generally `endpoint_url`.
key (str): The key of the file
deferred (Deferred[_S3Responder|None]): If file exists
resolved with an _S3Responder instance, if it doesn't
@ -105,15 +114,16 @@ class _S3DownloadThread(threading.Thread):
READ_CHUNK_SIZE = 16 * 1024
def __init__(self, bucket, key, deferred):
def __init__(self, bucket, api_kwargs, key, deferred):
super(_S3DownloadThread, self).__init__(name="s3-download")
self.bucket = bucket
self.api_kwargs = api_kwargs
self.key = key
self.deferred = deferred
def run(self):
session = boto3.session.Session()
s3 = session.client('s3')
s3 = session.client('s3', **self.api_kwargs)
try:
resp = s3.get_object(Bucket=self.bucket, Key=self.key)