diff --git a/gallery_dl/formatter.py b/gallery_dl/formatter.py index 107c8ed6..d9c307e5 100644 --- a/gallery_dl/formatter.py +++ b/gallery_dl/formatter.py @@ -17,17 +17,9 @@ import operator import functools from . import text, util -_CACHE = {} -_CONVERSIONS = None -_GLOBALS = { - "_env": lambda: os.environ, - "_lit": lambda: _literal, - "_now": datetime.datetime.now, -} - -def parse(format_string, default=None): - key = format_string, default +def parse(format_string, default=None, fmt=format): + key = format_string, default, fmt try: return _CACHE[key] @@ -48,7 +40,7 @@ def parse(format_string, default=None): elif kind == "F": cls = FStringFormatter - formatter = _CACHE[key] = cls(format_string, default) + formatter = _CACHE[key] = cls(format_string, default, fmt) return formatter @@ -95,8 +87,9 @@ class StringFormatter(): Example: {f:R /_/} -> "f_o_o_b_a_r" (if "f" is "f o o b a r") """ - def __init__(self, format_string, default=None): + def __init__(self, format_string, default=None, fmt=format): self.default = default + self.format = fmt self.result = [] self.fields = [] @@ -126,7 +119,7 @@ class StringFormatter(): return "".join(result) def _field_access(self, field_name, format_spec, conversion): - fmt = parse_format_spec(format_spec, conversion) + fmt = self._parse_format_spec(format_spec, conversion) if "|" in field_name: return self._apply_list([ @@ -184,27 +177,38 @@ class StringFormatter(): return fmt(obj) return wrap + def _parse_format_spec(self, format_spec, conversion): + fmt = _build_format_func(format_spec, self.format) + if not conversion: + return fmt + + conversion = _CONVERSIONS[conversion] + if fmt is self.format: + return conversion + else: + return lambda obj: fmt(conversion(obj)) + class TemplateFormatter(StringFormatter): """Read format_string from file""" - def __init__(self, path, default=None): + def __init__(self, path, default=None, fmt=format): with open(util.expand_path(path)) as fp: format_string = fp.read() - StringFormatter.__init__(self, format_string, default) + StringFormatter.__init__(self, format_string, default, fmt) class ExpressionFormatter(): """Generate text by evaluating a Python expression""" - def __init__(self, expression, default=None): + def __init__(self, expression, default=None, fmt=None): self.format_map = util.compile_expression(expression) class ModuleFormatter(): """Generate text by calling an external function""" - def __init__(self, function_spec, default=None): + def __init__(self, function_spec, default=None, fmt=None): module_name, _, function_name = function_spec.partition(":") module = __import__(module_name) self.format_map = getattr(module, function_name) @@ -213,7 +217,7 @@ class ModuleFormatter(): class FStringFormatter(): """Generate text by evaluaring an f-string literal""" - def __init__(self, fstring, default=None): + def __init__(self, fstring, default=None, fmt=None): self.format_map = util.compile_expression("f'''" + fstring + "'''") @@ -251,81 +255,37 @@ def _slice(indices): ) -def parse_format_spec(format_spec, conversion): - fmt = build_format_func(format_spec) - if not conversion: - return fmt - - global _CONVERSIONS - if _CONVERSIONS is None: - _CONVERSIONS = { - "l": str.lower, - "u": str.upper, - "c": str.capitalize, - "C": string.capwords, - "j": functools.partial(json.dumps, default=str), - "t": str.strip, - "T": util.datetime_to_timestamp_string, - "d": text.parse_timestamp, - "U": text.unescape, - "S": util.to_string, - "s": str, - "r": repr, - "a": ascii, - } - - conversion = _CONVERSIONS[conversion] - if fmt is format: - return conversion - else: - def chain(obj): - return fmt(conversion(obj)) - return chain - - -def build_format_func(format_spec): +def _build_format_func(format_spec, default): if format_spec: - fmt = format_spec[0] - if fmt == "?": - return _parse_optional(format_spec) - if fmt == "[": - return _parse_slice(format_spec) - if fmt == "L": - return _parse_maxlen(format_spec) - if fmt == "J": - return _parse_join(format_spec) - if fmt == "R": - return _parse_replace(format_spec) - if fmt == "D": - return _parse_datetime(format_spec) - return _default_format(format_spec) - return format - - -def _parse_optional(format_spec): + return _FORMAT_SPECIFIERS.get( + format_spec[0], _default_format)(format_spec, default) + return default + + +def _parse_optional(format_spec, default): before, after, format_spec = format_spec.split("/", 2) before = before[1:] - fmt = build_format_func(format_spec) + fmt = _build_format_func(format_spec, default) def optional(obj): return before + fmt(obj) + after if obj else "" return optional -def _parse_slice(format_spec): +def _parse_slice(format_spec, default): indices, _, format_spec = format_spec.partition("]") slice = _slice(indices[1:]) - fmt = build_format_func(format_spec) + fmt = _build_format_func(format_spec, default) def apply_slice(obj): return fmt(obj[slice]) return apply_slice -def _parse_maxlen(format_spec): +def _parse_maxlen(format_spec, default): maxlen, replacement, format_spec = format_spec.split("/", 2) maxlen = text.parse_int(maxlen[1:]) - fmt = build_format_func(format_spec) + fmt = _build_format_func(format_spec, default) def mlen(obj): obj = fmt(obj) @@ -333,37 +293,37 @@ def _parse_maxlen(format_spec): return mlen -def _parse_join(format_spec): +def _parse_join(format_spec, default): separator, _, format_spec = format_spec.partition("/") separator = separator[1:] - fmt = build_format_func(format_spec) + fmt = _build_format_func(format_spec, default) def join(obj): return fmt(separator.join(obj)) return join -def _parse_replace(format_spec): +def _parse_replace(format_spec, default): old, new, format_spec = format_spec.split("/", 2) old = old[1:] - fmt = build_format_func(format_spec) + fmt = _build_format_func(format_spec, default) def replace(obj): return fmt(obj.replace(old, new)) return replace -def _parse_datetime(format_spec): +def _parse_datetime(format_spec, default): dt_format, _, format_spec = format_spec.partition("/") dt_format = dt_format[1:] - fmt = build_format_func(format_spec) + fmt = _build_format_func(format_spec, default) def dt(obj): return fmt(text.parse_datetime(obj, dt_format)) return dt -def _default_format(format_spec): +def _default_format(format_spec, default): def wrap(obj): return format(obj, format_spec) return wrap @@ -379,3 +339,33 @@ class Literal(): _literal = Literal() + +_CACHE = {} +_GLOBALS = { + "_env": lambda: os.environ, + "_lit": lambda: _literal, + "_now": datetime.datetime.now, +} +_CONVERSIONS = { + "l": str.lower, + "u": str.upper, + "c": str.capitalize, + "C": string.capwords, + "j": functools.partial(json.dumps, default=str), + "t": str.strip, + "T": util.datetime_to_timestamp_string, + "d": text.parse_timestamp, + "U": text.unescape, + "S": util.to_string, + "s": str, + "r": repr, + "a": ascii, +} +_FORMAT_SPECIFIERS = { + "?": _parse_optional, + "[": _parse_slice, + "D": _parse_datetime, + "L": _parse_maxlen, + "J": _parse_join, + "R": _parse_replace, +} diff --git a/test/test_formatter.py b/test/test_formatter.py index 5b8ca0ab..f2ce310c 100644 --- a/test/test_formatter.py +++ b/test/test_formatter.py @@ -14,7 +14,7 @@ import datetime import tempfile sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from gallery_dl import formatter # noqa E402 +from gallery_dl import formatter, text, util # noqa E402 class TestFormatter(unittest.TestCase): @@ -98,6 +98,14 @@ class TestFormatter(unittest.TestCase): self._run_test("{missing[key]}", replacement, default) self._run_test("{missing:?a//}", "a" + default, default) + def test_fmt_func(self): + self._run_test("{t}" , self.kwdict["t"] , None, int) + self._run_test("{t}" , self.kwdict["t"] , None, util.identity) + self._run_test("{dt}", self.kwdict["dt"], None, util.identity) + self._run_test("{ds}", self.kwdict["dt"], None, text.parse_datetime) + self._run_test("{ds:D%Y-%m-%dT%H:%M:%S%z}", self.kwdict["dt"], + None, util.identity) + def test_alternative(self): self._run_test("{a|z}" , "hElLo wOrLd") self._run_test("{z|a}" , "hElLo wOrLd") @@ -316,8 +324,8 @@ def noarg(): with self.assertRaises(TypeError): self.assertEqual(fmt3.format_map(self.kwdict), "") - def _run_test(self, format_string, result, default=None): - fmt = formatter.parse(format_string, default) + def _run_test(self, format_string, result, default=None, fmt=format): + fmt = formatter.parse(format_string, default, fmt) output = fmt.format_map(self.kwdict) self.assertEqual(output, result, format_string)