From ec85bf90deb6a7fc46b3e8a2ba479e9ac939e122 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mike=20F=C3=A4hrmann?= Date: Tue, 25 Feb 2020 23:08:47 +0100 Subject: [PATCH] use context managers in cache.py & add tests --- gallery_dl/cache.py | 33 +++---- scripts/run_tests.sh | 2 +- test/test_cache.py | 202 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 217 insertions(+), 20 deletions(-) create mode 100644 test/test_cache.py diff --git a/gallery_dl/cache.py b/gallery_dl/cache.py index 89e2f5d3..6cde65dc 100644 --- a/gallery_dl/cache.py +++ b/gallery_dl/cache.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2016-2019 Mike Fährmann +# Copyright 2016-2020 Mike Fährmann # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 as @@ -96,12 +96,12 @@ class DatabaseCacheDecorator(): # database lookup fullkey = "%s-%s" % (self.key, key) - cursor = self.cursor() - try: - cursor.execute("BEGIN EXCLUSIVE") - except sqlite3.OperationalError: - pass # Silently swallow exception - workaround for Python 3.6 - try: + with self.database() as db: + cursor = db.cursor() + try: + cursor.execute("BEGIN EXCLUSIVE") + except sqlite3.OperationalError: + pass # Silently swallow exception - workaround for Python 3.6 cursor.execute( "SELECT value, expires FROM data WHERE key=? LIMIT 1", (fullkey,), @@ -118,43 +118,38 @@ class DatabaseCacheDecorator(): "INSERT OR REPLACE INTO data VALUES (?,?,?)", (fullkey, pickle.dumps(value), expires), ) - finally: - self.db.commit() + self.cache[key] = value, expires return value def update(self, key, value): expires = int(time.time()) + self.maxage self.cache[key] = value, expires - try: - self.cursor().execute( + with self.database() as db: + db.execute( "INSERT OR REPLACE INTO data VALUES (?,?,?)", ("%s-%s" % (self.key, key), pickle.dumps(value), expires), ) - finally: - self.db.commit() def invalidate(self, key): try: del self.cache[key] except KeyError: pass - try: - self.cursor().execute( + with self.database() as db: + db.execute( "DELETE FROM data WHERE key=?", ("%s-%s" % (self.key, key),), ) - finally: - self.db.commit() - def cursor(self): + def database(self): if self._init: self.db.execute( "CREATE TABLE IF NOT EXISTS data " "(key TEXT PRIMARY KEY, value TEXT, expires INTEGER)" ) DatabaseCacheDecorator._init = False - return self.db.cursor() + return self.db def memcache(maxage=None, keyarg=None): diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index d8c8a03b..de9fe4b0 100755 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -2,7 +2,7 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -TESTS_CORE=(config cookies downloader extractor oauth postprocessor text util) +TESTS_CORE=(cache config cookies downloader extractor oauth postprocessor text util) TESTS_RESULTS=(results) diff --git a/test/test_cache.py b/test/test_cache.py new file mode 100644 index 00000000..31ece7e9 --- /dev/null +++ b/test/test_cache.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Mike Fährmann +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation. + +import unittest +import tempfile +import time + +from gallery_dl import config, util +dbpath = tempfile.mkstemp()[1] +config.set(("cache",), "file", dbpath) +from gallery_dl import cache # noqa + + +def tearDownModule(): + util.remove_file(dbpath) + + +class TestCache(unittest.TestCase): + + def test_decorator(self): + + @cache.memcache() + def mc1(): + pass + + @cache.memcache(maxage=10) + def mc2(): + pass + + @cache.cache() + def dbc(): + pass + + self.assertIsInstance(mc1, cache.CacheDecorator) + self.assertIsInstance(mc2, cache.MemoryCacheDecorator) + self.assertIsInstance(dbc, cache.DatabaseCacheDecorator) + + def test_keyarg_mem_simple(self): + @cache.memcache(keyarg=2) + def ka(a, b, c): + return a+b+c + + self.assertEqual(ka(1, 1, 1), 3) + self.assertEqual(ka(2, 2, 2), 6) + + self.assertEqual(ka(0, 0, 1), 3) + self.assertEqual(ka(9, 9, 1), 3) + self.assertEqual(ka(0, 0, 2), 6) + self.assertEqual(ka(9, 9, 2), 6) + + def test_keyarg_mem(self): + @cache.memcache(keyarg=2, maxage=10) + def ka(a, b, c): + return a+b+c + + self.assertEqual(ka(1, 1, 1), 3) + self.assertEqual(ka(2, 2, 2), 6) + + self.assertEqual(ka(0, 0, 1), 3) + self.assertEqual(ka(9, 9, 1), 3) + self.assertEqual(ka(0, 0, 2), 6) + self.assertEqual(ka(9, 9, 2), 6) + + def test_keyarg_db(self): + @cache.cache(keyarg=2, maxage=10) + def ka(a, b, c): + return a+b+c + + self.assertEqual(ka(1, 1, 1), 3) + self.assertEqual(ka(2, 2, 2), 6) + + self.assertEqual(ka(0, 0, 1), 3) + self.assertEqual(ka(9, 9, 1), 3) + self.assertEqual(ka(0, 0, 2), 6) + self.assertEqual(ka(9, 9, 2), 6) + + def test_expires_mem(self): + @cache.memcache(maxage=1) + def ex(a, b, c): + return a+b+c + + self.assertEqual(ex(1, 1, 1), 3) + self.assertEqual(ex(2, 2, 2), 3) + self.assertEqual(ex(3, 3, 3), 3) + + time.sleep(2) + self.assertEqual(ex(3, 3, 3), 9) + self.assertEqual(ex(2, 2, 2), 9) + self.assertEqual(ex(1, 1, 1), 9) + + def test_expires_db(self): + @cache.cache(maxage=1) + def ex(a, b, c): + return a+b+c + + self.assertEqual(ex(1, 1, 1), 3) + self.assertEqual(ex(2, 2, 2), 3) + self.assertEqual(ex(3, 3, 3), 3) + + time.sleep(2) + self.assertEqual(ex(3, 3, 3), 9) + self.assertEqual(ex(2, 2, 2), 9) + self.assertEqual(ex(1, 1, 1), 9) + + def test_update_mem_simple(self): + @cache.memcache(keyarg=0) + def up(a, b, c): + return a+b+c + + self.assertEqual(up(1, 1, 1), 3) + up.update(1, 0) + up.update(2, 9) + self.assertEqual(up(1, 0, 0), 0) + self.assertEqual(up(2, 0, 0), 9) + + def test_update_mem(self): + @cache.memcache(keyarg=0, maxage=10) + def up(a, b, c): + return a+b+c + + self.assertEqual(up(1, 1, 1), 3) + up.update(1, 0) + up.update(2, 9) + self.assertEqual(up(1, 0, 0), 0) + self.assertEqual(up(2, 0, 0), 9) + + def test_update_db(self): + @cache.cache(keyarg=0, maxage=10) + def up(a, b, c): + return a+b+c + + self.assertEqual(up(1, 1, 1), 3) + up.update(1, 0) + up.update(2, 9) + self.assertEqual(up(1, 0, 0), 0) + self.assertEqual(up(2, 0, 0), 9) + + def test_invalidate_mem_simple(self): + @cache.memcache(keyarg=0) + def inv(a, b, c): + return a+b+c + + self.assertEqual(inv(1, 1, 1), 3) + inv.invalidate(1) + inv.invalidate(2) + self.assertEqual(inv(1, 0, 0), 1) + self.assertEqual(inv(2, 0, 0), 2) + + def test_invalidate_mem(self): + @cache.memcache(keyarg=0, maxage=10) + def inv(a, b, c): + return a+b+c + + self.assertEqual(inv(1, 1, 1), 3) + inv.invalidate(1) + inv.invalidate(2) + self.assertEqual(inv(1, 0, 0), 1) + self.assertEqual(inv(2, 0, 0), 2) + + def test_invalidate_db(self): + @cache.cache(keyarg=0, maxage=10) + def inv(a, b, c): + return a+b+c + + self.assertEqual(inv(1, 1, 1), 3) + inv.invalidate(1) + inv.invalidate(2) + self.assertEqual(inv(1, 0, 0), 1) + self.assertEqual(inv(2, 0, 0), 2) + + def test_database_read(self): + @cache.cache(keyarg=0, maxage=10) + def db(a, b, c): + return a+b+c + + # initialize cache + self.assertEqual(db(1, 1, 1), 3) + db.update(2, 6) + + # check and clear the in-memory portion of said cache + self.assertEqual(db.cache[1][0], 3) + self.assertEqual(db.cache[2][0], 6) + db.cache.clear() + self.assertEqual(db.cache, {}) + + # fetch results from database + self.assertEqual(db(1, 0, 0), 3) + self.assertEqual(db(2, 0, 0), 6) + + # check in-memory cache updates + self.assertEqual(db.cache[1][0], 3) + self.assertEqual(db.cache[2][0], 6) + + +if __name__ == '__main__': + unittest.main()