diff --git a/gallery_dl/extractor/__init__.py b/gallery_dl/extractor/__init__.py index b42e5c47..ec39961e 100644 --- a/gallery_dl/extractor/__init__.py +++ b/gallery_dl/extractor/__init__.py @@ -97,7 +97,7 @@ def extractors(): class blacklist(): """Context Manager to blacklist extractor modules""" - def __init__(self, *categories): + def __init__(self, categories): self.categories = categories def __enter__(self): diff --git a/gallery_dl/extractor/flickr.py b/gallery_dl/extractor/flickr.py index fdc46c2c..38a3ce36 100644 --- a/gallery_dl/extractor/flickr.py +++ b/gallery_dl/extractor/flickr.py @@ -9,7 +9,7 @@ """Extract images from https://www.flickr.com/""" from .common import Extractor, Message -from .. import text, util, oauth, exception +from .. import text, util, exception import urllib.parse @@ -249,7 +249,7 @@ class FlickrAPI(): token = extractor.config("access-token") token_secret = extractor.config("access-token-secret") if token and token_secret: - self.session = oauth.OAuthSession( + self.session = util.OAuthSession( extractor.session, self.API_KEY, self.API_SECRET, token, token_secret) self.API_KEY = None diff --git a/gallery_dl/extractor/oauthhelper.py b/gallery_dl/extractor/oauthhelper.py index 29eeb53a..64deb198 100644 --- a/gallery_dl/extractor/oauthhelper.py +++ b/gallery_dl/extractor/oauthhelper.py @@ -10,7 +10,7 @@ from .common import Extractor, Message from . import reddit, flickr -from .. import oauth +from .. import util import os import urllib.parse @@ -80,7 +80,7 @@ class OAuthReddit(OAuthBase): self.session.headers["User-Agent"] = reddit.RedditAPI.USER_AGENT self.client_id = reddit.RedditAPI.CLIENT_ID self.state = "gallery-dl:{}:{}".format( - self.subcategory, oauth.OAuthSession.nonce(8)) + self.subcategory, util.OAuthSession.nonce(8)) def items(self): yield Message.Version, 1 @@ -128,7 +128,7 @@ class OAuthFlickr(OAuthBase): def __init__(self, match): OAuthBase.__init__(self) - self.session = oauth.OAuthSession( + self.session = util.OAuthSession( self.session, flickr.FlickrAPI.API_KEY, flickr.FlickrAPI.API_SECRET ) diff --git a/gallery_dl/extractor/recursive.py b/gallery_dl/extractor/recursive.py index 0592bf42..d7b33419 100644 --- a/gallery_dl/extractor/recursive.py +++ b/gallery_dl/extractor/recursive.py @@ -10,7 +10,7 @@ import re from .common import Extractor, Message -from .. import extractor, adapter +from .. import extractor, adapter, util class RecursiveExtractor(Extractor): @@ -27,8 +27,10 @@ class RecursiveExtractor(Extractor): self.url = match.group(1) def items(self): + blist = self.config( + "blacklist", ("directlink",) + util.SPECIAL_EXTRACTORS) page = self.request(self.url).text yield Message.Version, 1 - with extractor.blacklist("directlink"): + with extractor.blacklist(blist): for match in re.finditer(r"https?://[^\s\"']+", page): yield Message.Queue, match.group(0) diff --git a/gallery_dl/extractor/reddit.py b/gallery_dl/extractor/reddit.py index bd3dcacd..589c4a46 100644 --- a/gallery_dl/extractor/reddit.py +++ b/gallery_dl/extractor/reddit.py @@ -9,7 +9,7 @@ """Extract images subreddits at https://reddit.com/""" from .common import Extractor, Message -from .. import text, extractor, exception +from .. import text, util, extractor, exception from ..cache import cache import time import re @@ -31,7 +31,7 @@ class RedditExtractor(Extractor): depth = 0 yield Message.Version, 1 - with extractor.blacklist("reddit"): + with extractor.blacklist(("reddit",) + util.SPECIAL_EXTRACTORS): while True: extra = [] for url in self._urls(submissions): diff --git a/gallery_dl/oauth.py b/gallery_dl/oauth.py deleted file mode 100644 index a27da285..00000000 --- a/gallery_dl/oauth.py +++ /dev/null @@ -1,54 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2017 Mike Fährmann -# -# This program is free software; you can redistribute it and/or modify -# it under the terms of the GNU General Public License version 2 as -# published by the Free Software Foundation. - -"""Utility classes for OAuth 1.0""" - -import hmac -import time -import base64 -import random -import string -import hashlib -import urllib.parse - - -class OAuthSession(): - """Minimal wrapper for requests.session objects to support OAuth 1.0""" - def __init__(self, session, consumer_key, consumer_secret, - token=None, token_secret=None): - self.session = session - self.consumer_secret = consumer_secret - self.token_secret = token_secret or "" - self.params = session.params - self.params["oauth_consumer_key"] = consumer_key - self.params["oauth_token"] = token - self.params["oauth_signature_method"] = "HMAC-SHA1" - self.params["oauth_version"] = "1.0" - - def get(self, url, params): - params.update(self.params) - params["oauth_nonce"] = self.nonce(16) - params["oauth_timestamp"] = int(time.time()) - params["oauth_signature"] = self.signature(url, params) - return self.session.get(url, params=params) - - def signature(self, url, params): - """Generate 'oauth_signature' value""" - query = urllib.parse.urlencode(sorted(params.items())) - message = self.concat("GET", url, query).encode() - key = self.concat(self.consumer_secret, self.token_secret).encode() - signature = hmac.new(key, message, hashlib.sha1).digest() - return base64.b64encode(signature).decode() - - @staticmethod - def concat(*args): - return "&".join(urllib.parse.quote(item, "") for item in args) - - @staticmethod - def nonce(N, alphabet=string.ascii_letters): - return "".join(random.choice(alphabet) for _ in range(N)) diff --git a/gallery_dl/util.py b/gallery_dl/util.py index 6ae5975d..c51c2349 100644 --- a/gallery_dl/util.py +++ b/gallery_dl/util.py @@ -10,6 +10,13 @@ import os import sys +import hmac +import time +import base64 +import random +import string +import hashlib +import urllib.parse from . import config, text, exception @@ -75,19 +82,19 @@ def bdecode(data, alphabet="0123456789"): def code_to_language(code, default="English"): """Map an ISO 639-1 language code to its actual name""" - return codes.get(code.lower(), default) + return CODES.get(code.lower(), default) def language_to_code(lang, default="en"): """Map a language name to its ISO 639-1 code""" lang = lang.capitalize() - for code, language in codes.items(): + for code, language in CODES.items(): if language == lang: return code return default -codes = { +CODES = { "ar": "Arabic", "cs": "Czech", "da": "Danish", @@ -117,6 +124,8 @@ codes = { "zh": "Chinese", } +SPECIAL_EXTRACTORS = ("oauth", "recursive", "test") + class RangePredicate(): """Predicate; is True if the current index is in the given range""" @@ -224,3 +233,40 @@ class PathFormat(): def adjust_path(path): """Enable longer-than-260-character paths on windows""" return "\\\\?\\" + os.path.abspath(path) if os.name == "nt" else path + + +class OAuthSession(): + """Minimal wrapper for requests.session objects to support OAuth 1.0""" + def __init__(self, session, consumer_key, consumer_secret, + token=None, token_secret=None): + self.session = session + self.consumer_secret = consumer_secret + self.token_secret = token_secret or "" + self.params = session.params + self.params["oauth_consumer_key"] = consumer_key + self.params["oauth_token"] = token + self.params["oauth_signature_method"] = "HMAC-SHA1" + self.params["oauth_version"] = "1.0" + + def get(self, url, params): + params.update(self.params) + params["oauth_nonce"] = self.nonce(16) + params["oauth_timestamp"] = int(time.time()) + params["oauth_signature"] = self.signature(url, params) + return self.session.get(url, params=params) + + def signature(self, url, params): + """Generate 'oauth_signature' value""" + query = urllib.parse.urlencode(sorted(params.items())) + message = self.concat("GET", url, query).encode() + key = self.concat(self.consumer_secret, self.token_secret).encode() + signature = hmac.new(key, message, hashlib.sha1).digest() + return base64.b64encode(signature).decode() + + @staticmethod + def concat(*args): + return "&".join(urllib.parse.quote(item, "") for item in args) + + @staticmethod + def nonce(N, alphabet=string.ascii_letters): + return "".join(random.choice(alphabet) for _ in range(N)) diff --git a/scripts/build_testresult_db.py b/scripts/build_testresult_db.py index 0570f2e7..ecaf2747 100755 --- a/scripts/build_testresult_db.py +++ b/scripts/build_testresult_db.py @@ -9,7 +9,7 @@ sys.path.insert(0, os.path.realpath(ROOTDIR)) from gallery_dl import extractor, job, config tests = [ - ([url[0] for url in extr.test], extr) + ([url[0] for url in extr.test if url[1]], extr) for extr in extractor.extractors() if hasattr(extr, "test") ]