update downloader.find() and related code

Instead of replacing 'https' with 'http' for every URL in
'get_downloader()', this now only happens once during downloader
initialization. Also unit tests.
pull/359/head
Mike Fährmann 5 years ago
parent f4ba98771d
commit ee4d7c3d89
No known key found for this signature in database
GPG Key ID: 5680CA389D365A88

@ -22,15 +22,23 @@ def find(scheme):
try:
return _cache[scheme]
except KeyError:
klass = None
pass
klass = None
if scheme == "https":
scheme = "http"
if scheme in modules: # prevent unwanted imports
try:
if scheme in modules: # prevent unwanted imports
module = importlib.import_module("." + scheme, __package__)
klass = module.__downloader__
except (ImportError, AttributeError, TypeError):
module = importlib.import_module("." + scheme, __package__)
klass = module.__downloader__
except ImportError:
pass
if scheme == "http":
_cache["http"] = _cache["https"] = klass
else:
_cache[scheme] = klass
return klass
return klass
# --------------------------------------------------------------------

@ -281,20 +281,22 @@ class DownloadJob(Job):
def get_downloader(self, scheme):
"""Return a downloader suitable for 'scheme'"""
if scheme == "https":
scheme = "http"
try:
return self.downloaders[scheme]
except KeyError:
pass
klass = downloader.find(scheme)
if klass and config.get(("downloader", scheme, "enabled"), True):
if klass and config.get(("downloader", klass.scheme, "enabled"), True):
instance = klass(self.extractor, self.out)
else:
instance = None
self.log.error("'%s:' URLs are not supported/enabled", scheme)
self.downloaders[scheme] = instance
if klass.scheme == "http":
self.downloaders["http"] = self.downloaders["https"] = instance
else:
self.downloaders[scheme] = instance
return instance
def initialize(self, keywords=None):

@ -8,13 +8,16 @@
# published by the Free Software Foundation.
import re
import sys
import base64
import os.path
import tempfile
import unittest
import threading
import http.server
import unittest
from unittest.mock import Mock, MagicMock, patch
import gallery_dl.downloader as downloader
import gallery_dl.extractor as extractor
import gallery_dl.config as config
@ -23,6 +26,73 @@ from gallery_dl.output import NullOutput
from gallery_dl.util import PathFormat
class MockDownloaderModule(Mock):
__downloader__ = "mock"
class TestDownloaderModule(unittest.TestCase):
@classmethod
def setUpClass(cls):
# allow import of ytdl downloader module without youtube_dl installed
sys.modules["youtube_dl"] = MagicMock()
@classmethod
def tearDownClass(cls):
del sys.modules["youtube_dl"]
def tearDown(self):
downloader._cache.clear()
def test_find(self):
cls = downloader.find("http")
self.assertEqual(cls.__name__, "HttpDownloader")
self.assertEqual(cls.scheme , "http")
cls = downloader.find("https")
self.assertEqual(cls.__name__, "HttpDownloader")
self.assertEqual(cls.scheme , "http")
cls = downloader.find("text")
self.assertEqual(cls.__name__, "TextDownloader")
self.assertEqual(cls.scheme , "text")
cls = downloader.find("ytdl")
self.assertEqual(cls.__name__, "YoutubeDLDownloader")
self.assertEqual(cls.scheme , "ytdl")
self.assertEqual(downloader.find("ftp"), None)
self.assertEqual(downloader.find("foo"), None)
self.assertEqual(downloader.find(1234) , None)
self.assertEqual(downloader.find(None) , None)
@patch("importlib.import_module")
def test_cache(self, import_module):
import_module.return_value = MockDownloaderModule()
downloader.find("http")
downloader.find("text")
downloader.find("ytdl")
self.assertEqual(import_module.call_count, 3)
downloader.find("http")
downloader.find("text")
downloader.find("ytdl")
self.assertEqual(import_module.call_count, 3)
@patch("importlib.import_module")
def test_cache_http(self, import_module):
import_module.return_value = MockDownloaderModule()
downloader.find("http")
downloader.find("https")
self.assertEqual(import_module.call_count, 1)
@patch("importlib.import_module")
def test_cache_https(self, import_module):
import_module.return_value = MockDownloaderModule()
downloader.find("https")
downloader.find("http")
self.assertEqual(import_module.call_count, 1)
class TestDownloaderBase(unittest.TestCase):
@classmethod

Loading…
Cancel
Save