smaller fixes and "security" measures

- move the OAuthSession class into util.py
- block special extractors for reddit and recursive
- ignore 'only matching' tests for testresults script
pull/21/head
Mike Fährmann 7 years ago
parent fb1904dd59
commit 2993206c4b
No known key found for this signature in database
GPG Key ID: 5680CA389D365A88

@ -97,7 +97,7 @@ def extractors():
class blacklist(): class blacklist():
"""Context Manager to blacklist extractor modules""" """Context Manager to blacklist extractor modules"""
def __init__(self, *categories): def __init__(self, categories):
self.categories = categories self.categories = categories
def __enter__(self): def __enter__(self):

@ -9,7 +9,7 @@
"""Extract images from https://www.flickr.com/""" """Extract images from https://www.flickr.com/"""
from .common import Extractor, Message from .common import Extractor, Message
from .. import text, util, oauth, exception from .. import text, util, exception
import urllib.parse import urllib.parse
@ -249,7 +249,7 @@ class FlickrAPI():
token = extractor.config("access-token") token = extractor.config("access-token")
token_secret = extractor.config("access-token-secret") token_secret = extractor.config("access-token-secret")
if token and token_secret: if token and token_secret:
self.session = oauth.OAuthSession( self.session = util.OAuthSession(
extractor.session, extractor.session,
self.API_KEY, self.API_SECRET, token, token_secret) self.API_KEY, self.API_SECRET, token, token_secret)
self.API_KEY = None self.API_KEY = None

@ -10,7 +10,7 @@
from .common import Extractor, Message from .common import Extractor, Message
from . import reddit, flickr from . import reddit, flickr
from .. import oauth from .. import util
import os import os
import urllib.parse import urllib.parse
@ -80,7 +80,7 @@ class OAuthReddit(OAuthBase):
self.session.headers["User-Agent"] = reddit.RedditAPI.USER_AGENT self.session.headers["User-Agent"] = reddit.RedditAPI.USER_AGENT
self.client_id = reddit.RedditAPI.CLIENT_ID self.client_id = reddit.RedditAPI.CLIENT_ID
self.state = "gallery-dl:{}:{}".format( self.state = "gallery-dl:{}:{}".format(
self.subcategory, oauth.OAuthSession.nonce(8)) self.subcategory, util.OAuthSession.nonce(8))
def items(self): def items(self):
yield Message.Version, 1 yield Message.Version, 1
@ -128,7 +128,7 @@ class OAuthFlickr(OAuthBase):
def __init__(self, match): def __init__(self, match):
OAuthBase.__init__(self) OAuthBase.__init__(self)
self.session = oauth.OAuthSession( self.session = util.OAuthSession(
self.session, self.session,
flickr.FlickrAPI.API_KEY, flickr.FlickrAPI.API_SECRET flickr.FlickrAPI.API_KEY, flickr.FlickrAPI.API_SECRET
) )

@ -10,7 +10,7 @@
import re import re
from .common import Extractor, Message from .common import Extractor, Message
from .. import extractor, adapter from .. import extractor, adapter, util
class RecursiveExtractor(Extractor): class RecursiveExtractor(Extractor):
@ -27,8 +27,10 @@ class RecursiveExtractor(Extractor):
self.url = match.group(1) self.url = match.group(1)
def items(self): def items(self):
blist = self.config(
"blacklist", ("directlink",) + util.SPECIAL_EXTRACTORS)
page = self.request(self.url).text page = self.request(self.url).text
yield Message.Version, 1 yield Message.Version, 1
with extractor.blacklist("directlink"): with extractor.blacklist(blist):
for match in re.finditer(r"https?://[^\s\"']+", page): for match in re.finditer(r"https?://[^\s\"']+", page):
yield Message.Queue, match.group(0) yield Message.Queue, match.group(0)

@ -9,7 +9,7 @@
"""Extract images subreddits at https://reddit.com/""" """Extract images subreddits at https://reddit.com/"""
from .common import Extractor, Message from .common import Extractor, Message
from .. import text, extractor, exception from .. import text, util, extractor, exception
from ..cache import cache from ..cache import cache
import time import time
import re import re
@ -31,7 +31,7 @@ class RedditExtractor(Extractor):
depth = 0 depth = 0
yield Message.Version, 1 yield Message.Version, 1
with extractor.blacklist("reddit"): with extractor.blacklist(("reddit",) + util.SPECIAL_EXTRACTORS):
while True: while True:
extra = [] extra = []
for url in self._urls(submissions): for url in self._urls(submissions):

@ -1,54 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation.
"""Utility classes for OAuth 1.0"""
import hmac
import time
import base64
import random
import string
import hashlib
import urllib.parse
class OAuthSession():
    """Minimal wrapper for requests.session objects to support OAuth 1.0

    Requests sent through get() are signed with HMAC-SHA1, using the
    consumer secret and the token secret as key material.
    """

    def __init__(self, session, consumer_key, consumer_secret,
                 token=None, token_secret=None):
        """Store the signing secrets and set the static oauth_* parameters

        The 'params' mapping of 'session' is shared with this wrapper
        (not copied), so the entries set here are written to the session.
        """
        self.session = session
        self.consumer_secret = consumer_secret
        # a missing token secret takes part in the key as an empty string
        self.token_secret = token_secret or ""
        self.params = session.params
        self.params["oauth_consumer_key"] = consumer_key
        # NOTE(review): with token=None, urlencode() in signature() renders
        # this entry as the text "None" -- confirm that callers set a real
        # token before the first signed request.
        self.params["oauth_token"] = token
        self.params["oauth_signature_method"] = "HMAC-SHA1"
        self.params["oauth_version"] = "1.0"

    def get(self, url, params):
        """Send a signed GET request for 'url' and return its result

        'params' holds the request specific query parameters. It is left
        unmodified; signing happens on a private copy, so the same mapping
        can safely be passed to several calls.
        """
        signed = dict(params)
        signed.update(self.params)
        # a signature left over from an earlier request must never become
        # part of the new signature base string
        signed.pop("oauth_signature", None)
        signed["oauth_nonce"] = self.nonce(16)
        signed["oauth_timestamp"] = int(time.time())
        signed["oauth_signature"] = self.signature(url, signed)
        return self.session.get(url, params=signed)

    def signature(self, url, params):
        """Generate 'oauth_signature' value"""
        # base string: method, URL and the key-sorted, urlencoded parameters
        query = urllib.parse.urlencode(sorted(params.items()))
        message = self.concat("GET", url, query).encode()
        key = self.concat(self.consumer_secret, self.token_secret).encode()
        signature = hmac.new(key, message, hashlib.sha1).digest()
        return base64.b64encode(signature).decode()

    @staticmethod
    def concat(*args):
        """Percent-encode all arguments and join them with '&'"""
        # safe="" also escapes "/", which quote() would keep by default
        return "&".join(urllib.parse.quote(item, "") for item in args)

    @staticmethod
    def nonce(N, alphabet=string.ascii_letters):
        """Return a random string of 'N' characters taken from 'alphabet'"""
        # nonce values double as OAuth 'state' tokens, so they are drawn
        # from the OS entropy source instead of the default generator
        choice = random.SystemRandom().choice
        return "".join(choice(alphabet) for _ in range(N))

@ -10,6 +10,13 @@
import os import os
import sys import sys
import hmac
import time
import base64
import random
import string
import hashlib
import urllib.parse
from . import config, text, exception from . import config, text, exception
@ -75,19 +82,19 @@ def bdecode(data, alphabet="0123456789"):
def code_to_language(code, default="English"): def code_to_language(code, default="English"):
"""Map an ISO 639-1 language code to its actual name""" """Map an ISO 639-1 language code to its actual name"""
return codes.get(code.lower(), default) return CODES.get(code.lower(), default)
def language_to_code(lang, default="en"): def language_to_code(lang, default="en"):
"""Map a language name to its ISO 639-1 code""" """Map a language name to its ISO 639-1 code"""
lang = lang.capitalize() lang = lang.capitalize()
for code, language in codes.items(): for code, language in CODES.items():
if language == lang: if language == lang:
return code return code
return default return default
codes = { CODES = {
"ar": "Arabic", "ar": "Arabic",
"cs": "Czech", "cs": "Czech",
"da": "Danish", "da": "Danish",
@ -117,6 +124,8 @@ codes = {
"zh": "Chinese", "zh": "Chinese",
} }
SPECIAL_EXTRACTORS = ("oauth", "recursive", "test")
class RangePredicate(): class RangePredicate():
"""Predicate; is True if the current index is in the given range""" """Predicate; is True if the current index is in the given range"""
@ -224,3 +233,40 @@ class PathFormat():
def adjust_path(path): def adjust_path(path):
"""Enable longer-than-260-character paths on windows""" """Enable longer-than-260-character paths on windows"""
return "\\\\?\\" + os.path.abspath(path) if os.name == "nt" else path return "\\\\?\\" + os.path.abspath(path) if os.name == "nt" else path
class OAuthSession():
    """Minimal wrapper for requests.session objects to support OAuth 1.0

    Requests sent through get() are signed with HMAC-SHA1, using the
    consumer secret and the token secret as key material.
    """

    def __init__(self, session, consumer_key, consumer_secret,
                 token=None, token_secret=None):
        """Store the signing secrets and set the static oauth_* parameters

        The 'params' mapping of 'session' is shared with this wrapper
        (not copied), so the entries set here are written to the session.
        """
        self.session = session
        self.consumer_secret = consumer_secret
        # a missing token secret takes part in the key as an empty string
        self.token_secret = token_secret or ""
        self.params = session.params
        self.params["oauth_consumer_key"] = consumer_key
        # NOTE(review): with token=None, urlencode() in signature() renders
        # this entry as the text "None" -- confirm that callers set a real
        # token before the first signed request.
        self.params["oauth_token"] = token
        self.params["oauth_signature_method"] = "HMAC-SHA1"
        self.params["oauth_version"] = "1.0"

    def get(self, url, params):
        """Send a signed GET request for 'url' and return its result

        'params' holds the request specific query parameters. It is left
        unmodified; signing happens on a private copy, so the same mapping
        can safely be passed to several calls.
        """
        signed = dict(params)
        signed.update(self.params)
        # a signature left over from an earlier request must never become
        # part of the new signature base string
        signed.pop("oauth_signature", None)
        signed["oauth_nonce"] = self.nonce(16)
        signed["oauth_timestamp"] = int(time.time())
        signed["oauth_signature"] = self.signature(url, signed)
        return self.session.get(url, params=signed)

    def signature(self, url, params):
        """Generate 'oauth_signature' value"""
        # base string: method, URL and the key-sorted, urlencoded parameters
        query = urllib.parse.urlencode(sorted(params.items()))
        message = self.concat("GET", url, query).encode()
        key = self.concat(self.consumer_secret, self.token_secret).encode()
        signature = hmac.new(key, message, hashlib.sha1).digest()
        return base64.b64encode(signature).decode()

    @staticmethod
    def concat(*args):
        """Percent-encode all arguments and join them with '&'"""
        # safe="" also escapes "/", which quote() would keep by default
        return "&".join(urllib.parse.quote(item, "") for item in args)

    @staticmethod
    def nonce(N, alphabet=string.ascii_letters):
        """Return a random string of 'N' characters taken from 'alphabet'"""
        # nonce values double as OAuth 'state' tokens, so they are drawn
        # from the OS entropy source instead of the default generator
        choice = random.SystemRandom().choice
        return "".join(choice(alphabet) for _ in range(N))

@ -9,7 +9,7 @@ sys.path.insert(0, os.path.realpath(ROOTDIR))
from gallery_dl import extractor, job, config from gallery_dl import extractor, job, config
tests = [ tests = [
([url[0] for url in extr.test], extr) ([url[0] for url in extr.test if url[1]], extr)
for extr in extractor.extractors() for extr in extractor.extractors()
if hasattr(extr, "test") if hasattr(extr, "test")
] ]

Loading…
Cancel
Save