diff --git a/gallery_dl/util.py b/gallery_dl/util.py index 7fcb6dac..a52daae2 100644 --- a/gallery_dl/util.py +++ b/gallery_dl/util.py @@ -526,13 +526,12 @@ class OAuthSession(): return "".join(random.choice(alphabet) for _ in range(N)) @staticmethod - def quote(value, _=None, encoding=None, errors=None): - return urllib.parse.quote(value, "~", encoding, errors) + def quote(value, quote=urllib.parse.quote): + return quote(value, "~") @staticmethod def urlencode(params): - quote = OAuthSession.quote - return "&".join([ - quote(str(key)) + "=" + quote(str(value)) - for key, value in sorted(params.items()) - ]) + return "&".join( + OAuthSession.quote(str(key)) + "=" + OAuthSession.quote(str(value)) + for key, value in sorted(params.items()) if value + ) diff --git a/test/test_oauth.py b/test/test_oauth.py new file mode 100644 index 00000000..36f1bd96 --- /dev/null +++ b/test/test_oauth.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2018 Mike Fährmann +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation. + +import unittest +import requests + +from gallery_dl import text +from gallery_dl.util import OAuthSession + +TESTSERVER = "http://oauthbin.com" +CONSUMER_KEY = "key" +CONSUMER_SECRET = "secret" +REQUEST_TOKEN = "requestkey" +REQUEST_TOKEN_SECRET = "requestsecret" +ACCESS_TOKEN = "accesskey" +ACCESS_TOKEN_SECRET = "accesssecret" + + +class TestOAuthSession(unittest.TestCase): + + def test_concat(self): + concat = OAuthSession.concat + + self.assertEqual(concat(), "") + self.assertEqual(concat("str"), "str") + self.assertEqual(concat("str1", "str2"), "str1&str2") + + self.assertEqual(concat("&", "?/"), "%26&%3F%2F") + self.assertEqual( + concat("GET", "http://example.org/", "foo=bar&baz=a"), + "GET&http%3A%2F%2Fexample.org%2F&foo%3Dbar%26baz%3Da" + ) + + def test_nonce(self, N=16): + nonce_values = set(OAuthSession.nonce(N) for _ in range(N)) + + # uniqueness + self.assertEqual(len(nonce_values), N) + + # length + for nonce in nonce_values: + self.assertEqual(len(nonce), N) + + def test_quote(self): + quote = OAuthSession.quote + + reserved = ",;:!\"§$%&/(){}[]=?`´+*'äöü" + unreserved = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789-._~") + + for char in unreserved: + self.assertEqual(quote(char), char) + + for char in reserved: + quoted = quote(char) + quoted_hex = quoted.replace("%", "") + self.assertTrue(quoted.startswith("%")) + self.assertTrue(len(quoted) >= 3) + self.assertEqual(quoted_hex.upper(), quoted_hex) + + def test_urlencode(self): + urlencode = OAuthSession.urlencode + + self.assertEqual(urlencode({}), "") + self.assertEqual(urlencode({"foo": "bar"}), "foo=bar") + + self.assertEqual( + urlencode({"foo": "bar", "baz": "a", "a": "baz"}), + "a=baz&baz=a&foo=bar" + ) + self.assertEqual( + urlencode({ + "oauth_consumer_key": "0685bd9184jfhq22", + "oauth_token": "ad180jjd733klru7", + "oauth_signature_method": "HMAC-SHA1", + "oauth_timestamp": 137131200, + "oauth_nonce": "4572616e48616d6d65724c61686176", + "oauth_version": "1.0" + }), + "oauth_consumer_key=0685bd9184jfhq22&" + "oauth_nonce=4572616e48616d6d65724c61686176&" + "oauth_signature_method=HMAC-SHA1&" + "oauth_timestamp=137131200&" + "oauth_token=ad180jjd733klru7&" + "oauth_version=1.0" + ) + + def test_request_token(self): + response = self._oauth_request( + "/v1/request-token", {}) + expected = "oauth_token=requestkey&oauth_token_secret=requestsecret" + self.assertEqual(response, expected, msg=response) + + data = text.parse_query(response) + self.assertTrue(data["oauth_token"], REQUEST_TOKEN) + self.assertTrue(data["oauth_token_secret"], REQUEST_TOKEN_SECRET) + + def test_access_token(self): + response = self._oauth_request( + "/v1/access-token", {}, REQUEST_TOKEN, REQUEST_TOKEN_SECRET) + expected = "oauth_token=accesskey&oauth_token_secret=accesssecret" + self.assertEqual(response, expected, msg=response) + + data = text.parse_query(response) + self.assertTrue(data["oauth_token"], ACCESS_TOKEN) + self.assertTrue(data["oauth_token_secret"], ACCESS_TOKEN_SECRET) + + def test_authenticated_call(self): + params = {"method": "foo", "bar": "baz", "a": "äöüß/?&#"} + response = self._oauth_request( + "/v1/echo", params, ACCESS_TOKEN, ACCESS_TOKEN_SECRET) + expected = OAuthSession.urlencode(params) + + self.assertEqual(response, expected, msg=response) + self.assertEqual(text.parse_query(response), params) + + def _oauth_request(self, endpoint, params=None, + oauth_token=None, oauth_token_secret=None): + session = OAuthSession( + requests.session(), + CONSUMER_KEY, CONSUMER_SECRET, + oauth_token, oauth_token_secret, + ) + url = TESTSERVER + endpoint + return session.get(url, params.copy()).text + + +if __name__ == "__main__": + unittest.main(warnings="ignore")