re-implement OAuth1.0 code

OAuth support for SmugMug needs some additional features
(auth-rebuild on redirect, query parameters in URL, ...)
and fixing this in the old code wouldn't work all that well.
pull/86/head
Mike Fährmann 6 years ago
parent 0e3883303f
commit 6a31ada9e3
No known key found for this signature in database
GPG Key ID: 5680CA389D365A88

@ -9,7 +9,7 @@
"""Extract images from https://www.flickr.com/"""
from .common import Extractor, Message
from .. import text, util, exception
from .. import text, oauth, util, exception
class FlickrExtractor(Extractor):
@ -264,17 +264,20 @@ class FlickrAPI():
]
def __init__(self, extractor):
self.api_key = extractor.config("api-key", self.API_KEY)
self.api_secret = extractor.config("api-secret", self.API_SECRET)
api_key = extractor.config("api-key", self.API_KEY)
api_secret = extractor.config("api-secret", self.API_SECRET)
token = extractor.config("access-token")
token_secret = extractor.config("access-token-secret")
if token and token_secret:
self.session = util.OAuthSession(
extractor.session,
self.api_key, self.api_secret, token, token_secret)
if api_key and api_secret and token and token_secret:
self.session = oauth.OAuth1Session(
api_key, api_secret,
token, token_secret,
)
self.api_key = None
else:
self.session = extractor.session
self.api_key = api_key
self.maxsize = extractor.config("size-max")
if isinstance(self.maxsize, str):

@ -10,7 +10,7 @@
from .common import Extractor, Message
from . import deviantart, flickr, reddit, tumblr
from .. import text, util, config
from .. import text, oauth, config
import os
import urllib.parse
@ -70,21 +70,19 @@ class OAuthBase(Extractor):
def _oauth1_authorization_flow(
self, request_token_url, authorize_url, access_token_url):
"""Perform the OAuth 1.0a authorization flow"""
del self.session.params["oauth_token"]
# get a request token
params = {"oauth_callback": self.redirect_uri}
data = self.session.get(request_token_url, params=params).text
data = text.parse_query(data)
self.session.params["oauth_token"] = token = data["oauth_token"]
self.session.token_secret = data["oauth_token_secret"]
self.session.auth.token_secret = data["oauth_token_secret"]
# get the user's authorization
params = {"oauth_token": token, "perms": "read"}
params = {"oauth_token": data["oauth_token"], "perms": "read"}
data = self.open(authorize_url, params)
# exchange the request token for an access token
# self.session.token = data["oauth_token"]
data = self.session.get(access_token_url, params=data).text
data = text.parse_query(data)
@ -101,7 +99,7 @@ class OAuthBase(Extractor):
state = "gallery-dl_{}_{}".format(
self.subcategory,
util.OAuthSession.nonce(8)
oauth.nonce(8),
)
auth_params = {
@ -182,8 +180,7 @@ class OAuthFlickr(OAuthBase):
def __init__(self, match):
OAuthBase.__init__(self, match)
self.session = util.OAuthSession(
self.session,
self.session = oauth.OAuth1Session(
self.oauth_config("api-key", flickr.FlickrAPI.API_KEY),
self.oauth_config("api-secret", flickr.FlickrAPI.API_SECRET),
)
@ -221,8 +218,7 @@ class OAuthTumblr(OAuthBase):
def __init__(self, match):
OAuthBase.__init__(self, match)
self.session = util.OAuthSession(
self.session,
self.session = oauth.OAuth1Session(
self.oauth_config("api-key", tumblr.TumblrAPI.API_KEY),
self.oauth_config("api-secret", tumblr.TumblrAPI.API_SECRET),
)

@ -9,7 +9,7 @@
"""Extract images from https://www.smugmug.com/"""
from .common import Extractor, Message
from .. import text, util, exception
from .. import text, oauth, exception
BASE_PATTERN = (
r"(?:smugmug:(?!album:)(?:https?://)?([^/]+)|"
@ -186,8 +186,7 @@ class SmugmugAPI():
token_secret = extractor.config("access-token-secret")
if api_key and api_secret and token and token_secret:
self.session = util.OAuthSession(
extractor.session,
self.session = oauth.OAuth1Session(
api_key, api_secret,
token, token_secret,
)

@ -9,7 +9,7 @@
"""Extract images from https://www.tumblr.com/"""
from .common import Extractor, Message
from .. import text, util, exception
from .. import text, oauth, exception
from datetime import datetime, timedelta
import re
import time
@ -261,8 +261,7 @@ class TumblrAPI():
token_secret = extractor.config("access-token-secret")
if api_key and api_secret and token and token_secret:
self.session = util.OAuthSession(
extractor.session,
self.session = oauth.OAuth1Session(
api_key, api_secret,
token, token_secret,
)

@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
# Copyright 2018 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation.
"""OAuth helper functions and classes"""
import hmac
import time
import base64
import random
import string
import hashlib
import urllib.parse
import requests
import requests.auth
from . import text
class OAuth1Session(requests.Session):
"""Extension to requests.Session objects to support OAuth 1.0"""
def __init__(self, consumer_key, consumer_secret,
token=None, token_secret=None):
requests.Session.__init__(self)
self.auth = OAuth1Client(
consumer_key, consumer_secret,
token, token_secret,
)
def rebuild_auth(self, prepared_request, response):
if "Authorization" in prepared_request.headers:
del prepared_request.headers["Authorization"]
prepared_request.prepare_auth(self.auth)
class OAuth1Client(requests.auth.AuthBase):
"""OAuth1.0a authentication"""
def __init__(self, consumer_key, consumer_secret,
token=None, token_secret=None):
self.consumer_key = consumer_key
self.consumer_secret = consumer_secret
self.token = token
self.token_secret = token_secret
def __call__(self, request):
oauth_params = [
("oauth_consumer_key", self.consumer_key),
("oauth_nonce", nonce(16)),
("oauth_signature_method", "HMAC-SHA1"),
("oauth_timestamp", str(int(time.time()))),
("oauth_version", "1.0"),
]
if self.token:
oauth_params.append(("oauth_token", self.token))
signature = self.generate_signature(request, oauth_params)
oauth_params.append(("oauth_signature", signature))
request.headers["Authorization"] = "OAuth " + ",".join(
key + '="' + value + '"' for key, value in oauth_params)
return request
def generate_signature(self, request, params):
"""Generate 'oauth_signature' value"""
url, _, query = request.url.partition("?")
params = params.copy()
for key, value in text.parse_query(query).items():
params.append((quote(key), quote(value)))
params.sort()
query = "&".join("=".join(item) for item in params)
message = concat(request.method, url, query).encode()
key = concat(self.consumer_secret, self.token_secret or "").encode()
signature = hmac.new(key, message, hashlib.sha1).digest()
return quote(base64.b64encode(signature).decode())
def concat(*args):
"""Concatenate 'args'"""
return "&".join(quote(item) for item in args)
def nonce(size, alphabet=string.ascii_letters):
"""Generate a nonce value with 'size' characters"""
return "".join(random.choice(alphabet) for _ in range(size))
def quote(value, quote=urllib.parse.quote):
"""Quote 'value' according to the OAuth1.0 standard"""
return quote(value, "~")

@ -11,14 +11,9 @@
import re
import os
import sys
import hmac
import time
import base64
import random
import shutil
import string
import _string
import hashlib
import sqlite3
import datetime
import itertools
@ -497,54 +492,6 @@ class PathFormat():
return "\\\\?\\" + os.path.abspath(path) if os.name == "nt" else path
class OAuthSession():
"""Minimal wrapper for requests.session objects to support OAuth 1.0"""
def __init__(self, session, consumer_key, consumer_secret,
token=None, token_secret=None):
self.session = session
self.consumer_secret = consumer_secret
self.token_secret = token_secret or ""
self.params = {}
self.params["oauth_consumer_key"] = consumer_key
self.params["oauth_token"] = token
self.params["oauth_signature_method"] = "HMAC-SHA1"
self.params["oauth_version"] = "1.0"
def get(self, url, params, **kwargs):
params.update(self.params)
params["oauth_nonce"] = self.nonce(16)
params["oauth_timestamp"] = int(time.time())
return self.session.get(url + self.sign(url, params), **kwargs)
def sign(self, url, params):
"""Generate 'oauth_signature' value and return query string"""
query = self.urlencode(params)
message = self.concat("GET", url, query).encode()
key = self.concat(self.consumer_secret, self.token_secret).encode()
signature = hmac.new(key, message, hashlib.sha1).digest()
return "?{}&oauth_signature={}".format(
query, self.quote(base64.b64encode(signature).decode()))
@staticmethod
def concat(*args):
return "&".join(OAuthSession.quote(item) for item in args)
@staticmethod
def nonce(N, alphabet=string.ascii_letters):
return "".join(random.choice(alphabet) for _ in range(N))
@staticmethod
def quote(value, quote=urllib.parse.quote):
return quote(value, "~")
@staticmethod
def urlencode(params):
return "&".join(
OAuthSession.quote(str(key)) + "=" + OAuthSession.quote(str(value))
for key, value in sorted(params.items()) if value
)
class DownloadArchive():
def __init__(self, path, extractor):

@ -8,10 +8,8 @@
# published by the Free Software Foundation.
import unittest
import requests
from gallery_dl import text
from gallery_dl.util import OAuthSession
from gallery_dl import oauth, text
TESTSERVER = "http://oauthbin.com"
CONSUMER_KEY = "key"
@ -25,7 +23,7 @@ ACCESS_TOKEN_SECRET = "accesssecret"
class TestOAuthSession(unittest.TestCase):
def test_concat(self):
concat = OAuthSession.concat
concat = oauth.concat
self.assertEqual(concat(), "")
self.assertEqual(concat("str"), "str")
@ -37,18 +35,18 @@ class TestOAuthSession(unittest.TestCase):
"GET&http%3A%2F%2Fexample.org%2F&foo%3Dbar%26baz%3Da"
)
def test_nonce(self, N=16):
nonce_values = set(OAuthSession.nonce(N) for _ in range(N))
def test_nonce(self, size=16):
nonce_values = set(oauth.nonce(size) for _ in range(size))
# uniqueness
self.assertEqual(len(nonce_values), N)
self.assertEqual(len(nonce_values), size)
# length
for nonce in nonce_values:
self.assertEqual(len(nonce), N)
self.assertEqual(len(nonce), size)
def test_quote(self):
quote = OAuthSession.quote
quote = oauth.quote
reserved = ",;:!\"§$%&/(){}[]=?`´+*'äöü"
unreserved = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ"
@ -65,33 +63,6 @@ class TestOAuthSession(unittest.TestCase):
self.assertTrue(len(quoted) >= 3)
self.assertEqual(quoted_hex.upper(), quoted_hex)
def test_urlencode(self):
urlencode = OAuthSession.urlencode
self.assertEqual(urlencode({}), "")
self.assertEqual(urlencode({"foo": "bar"}), "foo=bar")
self.assertEqual(
urlencode({"foo": "bar", "baz": "a", "a": "baz"}),
"a=baz&baz=a&foo=bar"
)
self.assertEqual(
urlencode({
"oauth_consumer_key": "0685bd9184jfhq22",
"oauth_token": "ad180jjd733klru7",
"oauth_signature_method": "HMAC-SHA1",
"oauth_timestamp": 137131200,
"oauth_nonce": "4572616e48616d6d65724c61686176",
"oauth_version": "1.0"
}),
"oauth_consumer_key=0685bd9184jfhq22&"
"oauth_nonce=4572616e48616d6d65724c61686176&"
"oauth_signature_method=HMAC-SHA1&"
"oauth_timestamp=137131200&"
"oauth_token=ad180jjd733klru7&"
"oauth_version=1.0"
)
def test_request_token(self):
response = self._oauth_request(
"/v1/request-token", {})
@ -113,23 +84,20 @@ class TestOAuthSession(unittest.TestCase):
self.assertTrue(data["oauth_token_secret"], ACCESS_TOKEN_SECRET)
def test_authenticated_call(self):
params = {"method": "foo", "bar": "baz", "a": "äöüß/?&#"}
params = {"method": "foo", "a": "äöüß/?&#", "äöüß/?&#": "a"}
response = self._oauth_request(
"/v1/echo", params, ACCESS_TOKEN, ACCESS_TOKEN_SECRET)
expected = OAuthSession.urlencode(params)
self.assertEqual(response, expected, msg=response)
self.assertEqual(text.parse_query(response), params)
def _oauth_request(self, endpoint, params=None,
oauth_token=None, oauth_token_secret=None):
session = OAuthSession(
requests.session(),
session = oauth.OAuth1Session(
CONSUMER_KEY, CONSUMER_SECRET,
oauth_token, oauth_token_secret,
)
url = TESTSERVER + endpoint
return session.get(url, params.copy()).text
return session.get(url, params=params).text
if __name__ == "__main__":

Loading…
Cancel
Save