From 6a31ada9e3733c48a58ea03baece22aefd2eea0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mike=20F=C3=A4hrmann?= Date: Thu, 10 May 2018 18:26:10 +0200 Subject: [PATCH] re-implement OAuth1.0 code OAuth support for SmugMug needs some additional features (auth-rebuild on redirect, query parameters in URL, ...) and fixing this in the old code wouldn't work all that well. --- gallery_dl/extractor/flickr.py | 17 +++--- gallery_dl/extractor/oauth.py | 18 +++--- gallery_dl/extractor/smugmug.py | 5 +- gallery_dl/extractor/tumblr.py | 5 +- gallery_dl/oauth.py | 101 ++++++++++++++++++++++++++++++++ gallery_dl/util.py | 53 ----------------- test/test_oauth.py | 52 ++++------------ 7 files changed, 132 insertions(+), 119 deletions(-) create mode 100644 gallery_dl/oauth.py diff --git a/gallery_dl/extractor/flickr.py b/gallery_dl/extractor/flickr.py index 150b2a28..967ec114 100644 --- a/gallery_dl/extractor/flickr.py +++ b/gallery_dl/extractor/flickr.py @@ -9,7 +9,7 @@ """Extract images from https://www.flickr.com/""" from .common import Extractor, Message -from .. import text, util, exception +from .. import text, oauth, util, exception class FlickrExtractor(Extractor): @@ -264,17 +264,20 @@ class FlickrAPI(): ] def __init__(self, extractor): - self.api_key = extractor.config("api-key", self.API_KEY) - self.api_secret = extractor.config("api-secret", self.API_SECRET) + api_key = extractor.config("api-key", self.API_KEY) + api_secret = extractor.config("api-secret", self.API_SECRET) token = extractor.config("access-token") token_secret = extractor.config("access-token-secret") - if token and token_secret: - self.session = util.OAuthSession( - extractor.session, - self.api_key, self.api_secret, token, token_secret) + + if api_key and api_secret and token and token_secret: + self.session = oauth.OAuth1Session( + api_key, api_secret, + token, token_secret, + ) self.api_key = None else: self.session = extractor.session + self.api_key = api_key self.maxsize = extractor.config("size-max") if isinstance(self.maxsize, str): diff --git a/gallery_dl/extractor/oauth.py b/gallery_dl/extractor/oauth.py index f161126e..6d8a1c7a 100644 --- a/gallery_dl/extractor/oauth.py +++ b/gallery_dl/extractor/oauth.py @@ -10,7 +10,7 @@ from .common import Extractor, Message from . import deviantart, flickr, reddit, tumblr -from .. import text, util, config +from .. import text, oauth, config import os import urllib.parse @@ -70,21 +70,19 @@ class OAuthBase(Extractor): def _oauth1_authorization_flow( self, request_token_url, authorize_url, access_token_url): """Perform the OAuth 1.0a authorization flow""" - del self.session.params["oauth_token"] - # get a request token params = {"oauth_callback": self.redirect_uri} data = self.session.get(request_token_url, params=params).text data = text.parse_query(data) - self.session.params["oauth_token"] = token = data["oauth_token"] - self.session.token_secret = data["oauth_token_secret"] + self.session.auth.token_secret = data["oauth_token_secret"] # get the user's authorization - params = {"oauth_token": token, "perms": "read"} + params = {"oauth_token": data["oauth_token"], "perms": "read"} data = self.open(authorize_url, params) # exchange the request token for an access token + # self.session.token = data["oauth_token"] data = self.session.get(access_token_url, params=data).text data = text.parse_query(data) @@ -101,7 +99,7 @@ class OAuthBase(Extractor): state = "gallery-dl_{}_{}".format( self.subcategory, - util.OAuthSession.nonce(8) + oauth.nonce(8), ) auth_params = { @@ -182,8 +180,7 @@ class OAuthFlickr(OAuthBase): def __init__(self, match): OAuthBase.__init__(self, match) - self.session = util.OAuthSession( - self.session, + self.session = oauth.OAuth1Session( self.oauth_config("api-key", flickr.FlickrAPI.API_KEY), self.oauth_config("api-secret", flickr.FlickrAPI.API_SECRET), ) @@ -221,8 +218,7 @@ class OAuthTumblr(OAuthBase): def __init__(self, match): OAuthBase.__init__(self, match) - self.session = util.OAuthSession( - self.session, + self.session = oauth.OAuth1Session( self.oauth_config("api-key", tumblr.TumblrAPI.API_KEY), self.oauth_config("api-secret", tumblr.TumblrAPI.API_SECRET), ) diff --git a/gallery_dl/extractor/smugmug.py b/gallery_dl/extractor/smugmug.py index 3306bc58..fa5fa751 100644 --- a/gallery_dl/extractor/smugmug.py +++ b/gallery_dl/extractor/smugmug.py @@ -9,7 +9,7 @@ """Extract images from https://www.smugmug.com/""" from .common import Extractor, Message -from .. import text, util, exception +from .. import text, oauth, exception BASE_PATTERN = ( r"(?:smugmug:(?!album:)(?:https?://)?([^/]+)|" @@ -186,8 +186,7 @@ class SmugmugAPI(): token_secret = extractor.config("access-token-secret") if api_key and api_secret and token and token_secret: - self.session = util.OAuthSession( - extractor.session, + self.session = oauth.OAuth1Session( api_key, api_secret, token, token_secret, ) diff --git a/gallery_dl/extractor/tumblr.py b/gallery_dl/extractor/tumblr.py index 770ca03f..8a87d8a5 100644 --- a/gallery_dl/extractor/tumblr.py +++ b/gallery_dl/extractor/tumblr.py @@ -9,7 +9,7 @@ """Extract images from https://www.tumblr.com/""" from .common import Extractor, Message -from .. import text, util, exception +from .. import text, oauth, exception from datetime import datetime, timedelta import re import time @@ -261,8 +261,7 @@ class TumblrAPI(): token_secret = extractor.config("access-token-secret") if api_key and api_secret and token and token_secret: - self.session = util.OAuthSession( - extractor.session, + self.session = oauth.OAuth1Session( api_key, api_secret, token, token_secret, ) diff --git a/gallery_dl/oauth.py b/gallery_dl/oauth.py new file mode 100644 index 00000000..459b0d43 --- /dev/null +++ b/gallery_dl/oauth.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- + +# Copyright 2018 Mike Fährmann +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation. + +"""OAuth helper functions and classes""" + +import hmac +import time +import base64 +import random +import string +import hashlib +import urllib.parse + +import requests +import requests.auth + +from . import text + + +class OAuth1Session(requests.Session): + """Extension to requests.Session objects to support OAuth 1.0""" + + def __init__(self, consumer_key, consumer_secret, + token=None, token_secret=None): + + requests.Session.__init__(self) + self.auth = OAuth1Client( + consumer_key, consumer_secret, + token, token_secret, + ) + + def rebuild_auth(self, prepared_request, response): + if "Authorization" in prepared_request.headers: + del prepared_request.headers["Authorization"] + prepared_request.prepare_auth(self.auth) + + +class OAuth1Client(requests.auth.AuthBase): + """OAuth1.0a authentication""" + def __init__(self, consumer_key, consumer_secret, + token=None, token_secret=None): + + self.consumer_key = consumer_key + self.consumer_secret = consumer_secret + self.token = token + self.token_secret = token_secret + + def __call__(self, request): + oauth_params = [ + ("oauth_consumer_key", self.consumer_key), + ("oauth_nonce", nonce(16)), + ("oauth_signature_method", "HMAC-SHA1"), + ("oauth_timestamp", str(int(time.time()))), + ("oauth_version", "1.0"), + ] + if self.token: + oauth_params.append(("oauth_token", self.token)) + + signature = self.generate_signature(request, oauth_params) + oauth_params.append(("oauth_signature", signature)) + + request.headers["Authorization"] = "OAuth " + ",".join( + key + '="' + value + '"' for key, value in oauth_params) + + return request + + def generate_signature(self, request, params): + """Generate 'oauth_signature' value""" + url, _, query = request.url.partition("?") + + params = params.copy() + for key, value in text.parse_query(query).items(): + params.append((quote(key), quote(value))) + params.sort() + query = "&".join("=".join(item) for item in params) + + message = concat(request.method, url, query).encode() + key = concat(self.consumer_secret, self.token_secret or "").encode() + signature = hmac.new(key, message, hashlib.sha1).digest() + + return quote(base64.b64encode(signature).decode()) + + +def concat(*args): + """Concatenate 'args'""" + return "&".join(quote(item) for item in args) + + +def nonce(size, alphabet=string.ascii_letters): + """Generate a nonce value with 'size' characters""" + return "".join(random.choice(alphabet) for _ in range(size)) + + +def quote(value, quote=urllib.parse.quote): + """Quote 'value' according to the OAuth1.0 standard""" + return quote(value, "~") diff --git a/gallery_dl/util.py b/gallery_dl/util.py index 642162e4..d488d039 100644 --- a/gallery_dl/util.py +++ b/gallery_dl/util.py @@ -11,14 +11,9 @@ import re import os import sys -import hmac -import time -import base64 -import random import shutil import string import _string -import hashlib import sqlite3 import datetime import itertools @@ -497,54 +492,6 @@ class PathFormat(): return "\\\\?\\" + os.path.abspath(path) if os.name == "nt" else path -class OAuthSession(): - """Minimal wrapper for requests.session objects to support OAuth 1.0""" - def __init__(self, session, consumer_key, consumer_secret, - token=None, token_secret=None): - self.session = session - self.consumer_secret = consumer_secret - self.token_secret = token_secret or "" - self.params = {} - self.params["oauth_consumer_key"] = consumer_key - self.params["oauth_token"] = token - self.params["oauth_signature_method"] = "HMAC-SHA1" - self.params["oauth_version"] = "1.0" - - def get(self, url, params, **kwargs): - params.update(self.params) - params["oauth_nonce"] = self.nonce(16) - params["oauth_timestamp"] = int(time.time()) - return self.session.get(url + self.sign(url, params), **kwargs) - - def sign(self, url, params): - """Generate 'oauth_signature' value and return query string""" - query = self.urlencode(params) - message = self.concat("GET", url, query).encode() - key = self.concat(self.consumer_secret, self.token_secret).encode() - signature = hmac.new(key, message, hashlib.sha1).digest() - return "?{}&oauth_signature={}".format( - query, self.quote(base64.b64encode(signature).decode())) - - @staticmethod - def concat(*args): - return "&".join(OAuthSession.quote(item) for item in args) - - @staticmethod - def nonce(N, alphabet=string.ascii_letters): - return "".join(random.choice(alphabet) for _ in range(N)) - - @staticmethod - def quote(value, quote=urllib.parse.quote): - return quote(value, "~") - - @staticmethod - def urlencode(params): - return "&".join( - OAuthSession.quote(str(key)) + "=" + OAuthSession.quote(str(value)) - for key, value in sorted(params.items()) if value - ) - - class DownloadArchive(): def __init__(self, path, extractor): diff --git a/test/test_oauth.py b/test/test_oauth.py index 36f1bd96..5176cc35 100644 --- a/test/test_oauth.py +++ b/test/test_oauth.py @@ -8,10 +8,8 @@ # published by the Free Software Foundation. import unittest -import requests -from gallery_dl import text -from gallery_dl.util import OAuthSession +from gallery_dl import oauth, text TESTSERVER = "http://oauthbin.com" CONSUMER_KEY = "key" @@ -25,7 +23,7 @@ ACCESS_TOKEN_SECRET = "accesssecret" class TestOAuthSession(unittest.TestCase): def test_concat(self): - concat = OAuthSession.concat + concat = oauth.concat self.assertEqual(concat(), "") self.assertEqual(concat("str"), "str") @@ -37,18 +35,18 @@ class TestOAuthSession(unittest.TestCase): "GET&http%3A%2F%2Fexample.org%2F&foo%3Dbar%26baz%3Da" ) - def test_nonce(self, N=16): - nonce_values = set(OAuthSession.nonce(N) for _ in range(N)) + def test_nonce(self, size=16): + nonce_values = set(oauth.nonce(size) for _ in range(size)) # uniqueness - self.assertEqual(len(nonce_values), N) + self.assertEqual(len(nonce_values), size) # length for nonce in nonce_values: - self.assertEqual(len(nonce), N) + self.assertEqual(len(nonce), size) def test_quote(self): - quote = OAuthSession.quote + quote = oauth.quote reserved = ",;:!\"§$%&/(){}[]=?`´+*'äöü" unreserved = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -65,33 +63,6 @@ class TestOAuthSession(unittest.TestCase): self.assertTrue(len(quoted) >= 3) self.assertEqual(quoted_hex.upper(), quoted_hex) - def test_urlencode(self): - urlencode = OAuthSession.urlencode - - self.assertEqual(urlencode({}), "") - self.assertEqual(urlencode({"foo": "bar"}), "foo=bar") - - self.assertEqual( - urlencode({"foo": "bar", "baz": "a", "a": "baz"}), - "a=baz&baz=a&foo=bar" - ) - self.assertEqual( - urlencode({ - "oauth_consumer_key": "0685bd9184jfhq22", - "oauth_token": "ad180jjd733klru7", - "oauth_signature_method": "HMAC-SHA1", - "oauth_timestamp": 137131200, - "oauth_nonce": "4572616e48616d6d65724c61686176", - "oauth_version": "1.0" - }), - "oauth_consumer_key=0685bd9184jfhq22&" - "oauth_nonce=4572616e48616d6d65724c61686176&" - "oauth_signature_method=HMAC-SHA1&" - "oauth_timestamp=137131200&" - "oauth_token=ad180jjd733klru7&" - "oauth_version=1.0" - ) - def test_request_token(self): response = self._oauth_request( "/v1/request-token", {}) @@ -113,23 +84,20 @@ class TestOAuthSession(unittest.TestCase): self.assertTrue(data["oauth_token_secret"], ACCESS_TOKEN_SECRET) def test_authenticated_call(self): - params = {"method": "foo", "bar": "baz", "a": "äöüß/?&#"} + params = {"method": "foo", "a": "äöüß/?&#", "äöüß/?&#": "a"} response = self._oauth_request( "/v1/echo", params, ACCESS_TOKEN, ACCESS_TOKEN_SECRET) - expected = OAuthSession.urlencode(params) - self.assertEqual(response, expected, msg=response) self.assertEqual(text.parse_query(response), params) def _oauth_request(self, endpoint, params=None, oauth_token=None, oauth_token_secret=None): - session = OAuthSession( - requests.session(), + session = oauth.OAuth1Session( CONSUMER_KEY, CONSUMER_SECRET, oauth_token, oauth_token_secret, ) url = TESTSERVER + endpoint - return session.get(url, params.copy()).text + return session.get(url, params=params).text if __name__ == "__main__":