diff --git a/gallery_dl/job.py b/gallery_dl/job.py index 4e185d04..97a8d3fa 100644 --- a/gallery_dl/job.py +++ b/gallery_dl/job.py @@ -11,7 +11,6 @@ import json import time import errno import logging -import operator import functools import collections from . import extractor, downloader, postprocessor @@ -201,7 +200,6 @@ class DownloadJob(Job): def __init__(self, url, parent=None): Job.__init__(self, url, parent) self.log = self.get_logger("download") - self.blacklist = None self.fallback = None self.archive = None self.sleep = None @@ -209,6 +207,7 @@ class DownloadJob(Job): self.downloaders = {} self.out = output.select() self.visited = parent.visited if parent else set() + self._extractor_filter = None self._skipcnt = 0 def handle_url(self, url, kwdict): @@ -297,9 +296,9 @@ class DownloadJob(Job): else: extr = extractor.find(url) if extr: - if self.blacklist is None: - self.blacklist = self._build_blacklist() - if extr.category in self.blacklist: + if self._extractor_filter is None: + self._extractor_filter = self._build_extractor_filter() + if not self._extractor_filter(extr): extr = None if extr: @@ -444,22 +443,20 @@ class DownloadJob(Job): self.hooks = collections.defaultdict(list) pp_log = self.get_logger("postprocessor") pp_list = [] - category = self.extractor.category - basecategory = self.extractor.basecategory pp_conf = config.get((), "postprocessor") or {} for pp_dict in postprocessors: if isinstance(pp_dict, str): pp_dict = pp_conf.get(pp_dict) or {"name": pp_dict} - whitelist = pp_dict.get("whitelist") - if whitelist and category not in whitelist and \ - basecategory not in whitelist: - continue - - blacklist = pp_dict.get("blacklist") - if blacklist and ( - category in blacklist or basecategory in blacklist): + clist = pp_dict.get("whitelist") + if clist is not None: + negate = False + else: + clist = pp_dict.get("blacklist") + negate = True + if clist and not util.build_extractor_filter( + clist, negate)(self.extractor): continue name = pp_dict.get("name") @@ -500,38 +497,18 @@ class DownloadJob(Job): if condition(pathfmt.kwdict): callback(pathfmt) - def _build_blacklist(self): - wlist = self.extractor.config("whitelist") - if wlist is not None: - if isinstance(wlist, str): - wlist = wlist.split(",") - - # build a set of all categories - blist = set() - add = blist.add - update = blist.update - get = operator.itemgetter(0) - - for extr in extractor._list_classes(): - category = extr.category - if category: - add(category) - else: - update(map(get, extr.instances)) - - # remove whitelisted categories - blist.difference_update(wlist) - return blist - - blist = self.extractor.config("blacklist") - if blist is not None: - if isinstance(blist, str): - blist = blist.split(",") - blist = set(blist) + def _build_extractor_filter(self): + clist = self.extractor.config("whitelist") + if clist is not None: + negate = False else: - blist = {self.extractor.category} - blist |= util.SPECIAL_EXTRACTORS - return blist + clist = self.extractor.config("blacklist") + negate = True + if clist is None: + clist = (self.extractor.category,) + + return util.build_extractor_filter( + clist, negate, util.SPECIAL_EXTRACTORS) class SimulationJob(DownloadJob): diff --git a/gallery_dl/util.py b/gallery_dl/util.py index 4a7fdbf4..d25194e3 100644 --- a/gallery_dl/util.py +++ b/gallery_dl/util.py @@ -81,6 +81,16 @@ def identity(x): return x +def true(_): + """Always returns True""" + return True + + +def false(_): + """Always returns False""" + return False + + def noop(): """Does nothing""" @@ -432,6 +442,66 @@ def build_duration_func(duration, min=0.0): return functools.partial(identity, duration if duration > min else min) +def build_extractor_filter(categories, negate=True, special=None): + """Build a function that takes an Extractor class as argument + and returns True if that class is allowed by 'categories' + """ + if isinstance(categories, str): + categories = categories.split(",") + + catset = set() # set of categories / basecategories + subset = set() # set of subcategories + catsub = [] # list of category-subcategory pairs + + for item in categories: + category, _, subcategory = item.partition(":") + if category and category != "*": + if subcategory and subcategory != "*": + catsub.append((category, subcategory)) + else: + catset.add(category) + elif subcategory and subcategory != "*": + subset.add(subcategory) + + if special: + catset |= special + elif not catset and not subset and not catsub: + return true if negate else false + + tests = [] + + if negate: + if catset: + tests.append(lambda extr: + extr.category not in catset and + extr.basecategory not in catset) + if subset: + tests.append(lambda extr: extr.subcategory not in subset) + else: + if catset: + tests.append(lambda extr: + extr.category in catset or + extr.basecategory in catset) + if subset: + tests.append(lambda extr: extr.subcategory in subset) + + if catsub: + def test(extr): + for category, subcategory in catsub: + if category in (extr.category, extr.basecategory) and \ + subcategory == extr.subcategory: + return not negate + return negate + tests.append(test) + + if len(tests) == 1: + return tests[0] + if negate: + return lambda extr: all(t(extr) for t in tests) + else: + return lambda extr: any(t(extr) for t in tests) + + def build_predicate(predicates): if not predicates: return lambda url, kwdict: True diff --git a/test/test_job.py b/test/test_job.py index 1aeec1c0..02765555 100644 --- a/test/test_job.py +++ b/test/test_job.py @@ -37,6 +37,31 @@ class TestJob(unittest.TestCase): return buffer.getvalue() +class TestDownloadJob(TestJob): + jobclass = job.DownloadJob + + def test_extractor_filter(self): + extr = TestExtractor.from_url("test:") + tjob = self.jobclass(extr) + + func = tjob._build_extractor_filter() + self.assertEqual(func(TestExtractor) , False) + self.assertEqual(func(TestExtractorParent), False) + self.assertEqual(func(TestExtractorAlt) , True) + + config.set((), "blacklist", ":test_subcategory") + func = tjob._build_extractor_filter() + self.assertEqual(func(TestExtractor) , False) + self.assertEqual(func(TestExtractorParent), True) + self.assertEqual(func(TestExtractorAlt) , False) + + config.set((), "whitelist", "test_category:test_subcategory") + func = tjob._build_extractor_filter() + self.assertEqual(func(TestExtractor) , True) + self.assertEqual(func(TestExtractorParent), False) + self.assertEqual(func(TestExtractorAlt) , False) + + class TestKeywordJob(TestJob): jobclass = job.KeywordJob @@ -334,5 +359,10 @@ class TestExtractorException(Extractor): return 1/0 +class TestExtractorAlt(Extractor): + category = "test_category_alt" + subcategory = "test_subcategory" + + if __name__ == '__main__': unittest.main() diff --git a/test/test_util.py b/test/test_util.py index 0fbbbcea..32e97849 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -357,6 +357,58 @@ class TestOther(unittest.TestCase): with self.assertRaises(exception.StopExtraction): expr() + def test_extractor_filter(self): + # empty + func = util.build_extractor_filter("") + self.assertEqual(func(TestExtractor) , True) + self.assertEqual(func(TestExtractorParent), True) + self.assertEqual(func(TestExtractorAlt) , True) + + # category + func = util.build_extractor_filter("test_category") + self.assertEqual(func(TestExtractor) , False) + self.assertEqual(func(TestExtractorParent), False) + self.assertEqual(func(TestExtractorAlt) , True) + + # subcategory + func = util.build_extractor_filter("*:test_subcategory") + self.assertEqual(func(TestExtractor) , False) + self.assertEqual(func(TestExtractorParent), True) + self.assertEqual(func(TestExtractorAlt) , False) + + # basecategory + func = util.build_extractor_filter("test_basecategory") + self.assertEqual(func(TestExtractor) , False) + self.assertEqual(func(TestExtractorParent), False) + self.assertEqual(func(TestExtractorAlt) , False) + + # category-subcategory pair + func = util.build_extractor_filter("test_category:test_subcategory") + self.assertEqual(func(TestExtractor) , False) + self.assertEqual(func(TestExtractorParent), True) + self.assertEqual(func(TestExtractorAlt) , True) + + # combination + func = util.build_extractor_filter( + ["test_category", "*:test_subcategory"]) + self.assertEqual(func(TestExtractor) , False) + self.assertEqual(func(TestExtractorParent), False) + self.assertEqual(func(TestExtractorAlt) , False) + + # whitelist + func = util.build_extractor_filter( + "test_category:test_subcategory", negate=False) + self.assertEqual(func(TestExtractor) , True) + self.assertEqual(func(TestExtractorParent), False) + self.assertEqual(func(TestExtractorAlt) , False) + + func = util.build_extractor_filter( + ["test_category:test_subcategory", "*:test_subcategory_parent"], + negate=False) + self.assertEqual(func(TestExtractor) , True) + self.assertEqual(func(TestExtractorParent), True) + self.assertEqual(func(TestExtractorAlt) , False) + def test_generate_token(self): tokens = set() for _ in range(100): @@ -469,5 +521,21 @@ class TestOther(unittest.TestCase): self.assertIs(obj["key"], obj) +class TestExtractor(): + category = "test_category" + subcategory = "test_subcategory" + basecategory = "test_basecategory" + + +class TestExtractorParent(TestExtractor): + category = "test_category" + subcategory = "test_subcategory_parent" + + +class TestExtractorAlt(TestExtractor): + category = "test_category_alt" + subcategory = "test_subcategory" + + if __name__ == '__main__': unittest.main()