use __import__() to dynamically load modules

pull/1352/head
Mike Fährmann 4 years ago
parent 69ea781d32
commit 8821dceb79
No known key found for this signature in database
GPG Key ID: 5680CA389D365A88

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015-2019 Mike Fährmann
# Copyright 2015-2021 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
@ -8,8 +8,6 @@
"""Downloader modules"""
import importlib
modules = [
"http",
"text",
@ -24,22 +22,22 @@ def find(scheme):
except KeyError:
pass
klass = None
cls = None
if scheme == "https":
scheme = "http"
if scheme in modules: # prevent unwanted imports
try:
module = importlib.import_module("." + scheme, __package__)
module = __import__(scheme, globals(), None, (), 1)
except ImportError:
pass
else:
klass = module.__downloader__
cls = module.__downloader__
if scheme == "http":
_cache["http"] = _cache["https"] = klass
_cache["http"] = _cache["https"] = cls
else:
_cache[scheme] = klass
return klass
_cache[scheme] = cls
return cls
# --------------------------------------------------------------------

@ -7,7 +7,6 @@
# published by the Free Software Foundation.
import re
import importlib
modules = [
"2chan",
@ -185,11 +184,12 @@ def _list_classes():
"""Yield all available extractor classes"""
yield from _cache
globals_ = globals()
for module_name in _module_iter:
module = importlib.import_module("."+module_name, __package__)
module = __import__(module_name, globals_, None, (), 1)
yield from add_module(module)
globals()["_list_classes"] = lambda : _cache
globals_["_list_classes"] = lambda : _cache
def _get_classes(module):

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018-2020 Mike Fährmann
# Copyright 2018-2021 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
@ -8,8 +8,6 @@
"""Post-processing modules"""
import importlib
modules = [
"classify",
"compare",
@ -28,16 +26,16 @@ def find(name):
except KeyError:
pass
klass = None
cls = None
if name in modules: # prevent unwanted imports
try:
module = importlib.import_module("." + name, __package__)
module = __import__(name, globals(), None, (), 1)
except ImportError:
pass
else:
klass = module.__postprocessor__
_cache[name] = klass
return klass
cls = module.__postprocessor__
_cache[name] = cls
return cls
# --------------------------------------------------------------------

@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2018-2020 Mike Fährmann
# Copyright 2018-2021 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
@ -74,7 +74,7 @@ class TestDownloaderModule(unittest.TestCase):
self.assertEqual(downloader.find(1234) , None)
self.assertEqual(downloader.find(None) , None)
@patch("importlib.import_module")
@patch("builtins.__import__")
def test_cache(self, import_module):
import_module.return_value = MockDownloaderModule()
downloader.find("http")
@ -86,14 +86,14 @@ class TestDownloaderModule(unittest.TestCase):
downloader.find("ytdl")
self.assertEqual(import_module.call_count, 3)
@patch("importlib.import_module")
@patch("builtins.__import__")
def test_cache_http(self, import_module):
import_module.return_value = MockDownloaderModule()
downloader.find("http")
downloader.find("https")
self.assertEqual(import_module.call_count, 1)
@patch("importlib.import_module")
@patch("builtins.__import__")
def test_cache_https(self, import_module):
import_module.return_value = MockDownloaderModule()
downloader.find("https")

@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019-2020 Mike Fährmann
# Copyright 2019-2021 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
@ -53,7 +53,7 @@ class TestPostprocessorModule(unittest.TestCase):
self.assertEqual(postprocessor.find(1234) , None)
self.assertEqual(postprocessor.find(None) , None)
@patch("importlib.import_module")
@patch("builtins.__import__")
def test_cache(self, import_module):
import_module.return_value = MockPostprocessorModule()

Loading…
Cancel
Save