[formatter] allow for custom "format" functions (#2721)

pull/2739/head
Mike Fährmann 2 years ago
parent 311e9383af
commit 04bed1eba3
No known key found for this signature in database
GPG Key ID: 5680CA389D365A88

@ -17,17 +17,9 @@ import operator
import functools
from . import text, util
_CACHE = {}
_CONVERSIONS = None
_GLOBALS = {
"_env": lambda: os.environ,
"_lit": lambda: _literal,
"_now": datetime.datetime.now,
}
def parse(format_string, default=None):
key = format_string, default
def parse(format_string, default=None, fmt=format):
key = format_string, default, fmt
try:
return _CACHE[key]
@ -48,7 +40,7 @@ def parse(format_string, default=None):
elif kind == "F":
cls = FStringFormatter
formatter = _CACHE[key] = cls(format_string, default)
formatter = _CACHE[key] = cls(format_string, default, fmt)
return formatter
@ -95,8 +87,9 @@ class StringFormatter():
Example: {f:R /_/} -> "f_o_o_b_a_r" (if "f" is "f o o b a r")
"""
def __init__(self, format_string, default=None):
def __init__(self, format_string, default=None, fmt=format):
self.default = default
self.format = fmt
self.result = []
self.fields = []
@ -126,7 +119,7 @@ class StringFormatter():
return "".join(result)
def _field_access(self, field_name, format_spec, conversion):
fmt = parse_format_spec(format_spec, conversion)
fmt = self._parse_format_spec(format_spec, conversion)
if "|" in field_name:
return self._apply_list([
@ -184,27 +177,38 @@ class StringFormatter():
return fmt(obj)
return wrap
def _parse_format_spec(self, format_spec, conversion):
fmt = _build_format_func(format_spec, self.format)
if not conversion:
return fmt
conversion = _CONVERSIONS[conversion]
if fmt is self.format:
return conversion
else:
return lambda obj: fmt(conversion(obj))
class TemplateFormatter(StringFormatter):
"""Read format_string from file"""
def __init__(self, path, default=None):
def __init__(self, path, default=None, fmt=format):
with open(util.expand_path(path)) as fp:
format_string = fp.read()
StringFormatter.__init__(self, format_string, default)
StringFormatter.__init__(self, format_string, default, fmt)
class ExpressionFormatter():
"""Generate text by evaluating a Python expression"""
def __init__(self, expression, default=None):
def __init__(self, expression, default=None, fmt=None):
self.format_map = util.compile_expression(expression)
class ModuleFormatter():
"""Generate text by calling an external function"""
def __init__(self, function_spec, default=None):
def __init__(self, function_spec, default=None, fmt=None):
module_name, _, function_name = function_spec.partition(":")
module = __import__(module_name)
self.format_map = getattr(module, function_name)
@ -213,7 +217,7 @@ class ModuleFormatter():
class FStringFormatter():
"""Generate text by evaluaring an f-string literal"""
def __init__(self, fstring, default=None):
def __init__(self, fstring, default=None, fmt=None):
self.format_map = util.compile_expression("f'''" + fstring + "'''")
@ -251,81 +255,37 @@ def _slice(indices):
)
def parse_format_spec(format_spec, conversion):
fmt = build_format_func(format_spec)
if not conversion:
return fmt
global _CONVERSIONS
if _CONVERSIONS is None:
_CONVERSIONS = {
"l": str.lower,
"u": str.upper,
"c": str.capitalize,
"C": string.capwords,
"j": functools.partial(json.dumps, default=str),
"t": str.strip,
"T": util.datetime_to_timestamp_string,
"d": text.parse_timestamp,
"U": text.unescape,
"S": util.to_string,
"s": str,
"r": repr,
"a": ascii,
}
conversion = _CONVERSIONS[conversion]
if fmt is format:
return conversion
else:
def chain(obj):
return fmt(conversion(obj))
return chain
def build_format_func(format_spec):
def _build_format_func(format_spec, default):
if format_spec:
fmt = format_spec[0]
if fmt == "?":
return _parse_optional(format_spec)
if fmt == "[":
return _parse_slice(format_spec)
if fmt == "L":
return _parse_maxlen(format_spec)
if fmt == "J":
return _parse_join(format_spec)
if fmt == "R":
return _parse_replace(format_spec)
if fmt == "D":
return _parse_datetime(format_spec)
return _default_format(format_spec)
return format
def _parse_optional(format_spec):
return _FORMAT_SPECIFIERS.get(
format_spec[0], _default_format)(format_spec, default)
return default
def _parse_optional(format_spec, default):
before, after, format_spec = format_spec.split("/", 2)
before = before[1:]
fmt = build_format_func(format_spec)
fmt = _build_format_func(format_spec, default)
def optional(obj):
return before + fmt(obj) + after if obj else ""
return optional
def _parse_slice(format_spec):
def _parse_slice(format_spec, default):
indices, _, format_spec = format_spec.partition("]")
slice = _slice(indices[1:])
fmt = build_format_func(format_spec)
fmt = _build_format_func(format_spec, default)
def apply_slice(obj):
return fmt(obj[slice])
return apply_slice
def _parse_maxlen(format_spec):
def _parse_maxlen(format_spec, default):
maxlen, replacement, format_spec = format_spec.split("/", 2)
maxlen = text.parse_int(maxlen[1:])
fmt = build_format_func(format_spec)
fmt = _build_format_func(format_spec, default)
def mlen(obj):
obj = fmt(obj)
@ -333,37 +293,37 @@ def _parse_maxlen(format_spec):
return mlen
def _parse_join(format_spec):
def _parse_join(format_spec, default):
separator, _, format_spec = format_spec.partition("/")
separator = separator[1:]
fmt = build_format_func(format_spec)
fmt = _build_format_func(format_spec, default)
def join(obj):
return fmt(separator.join(obj))
return join
def _parse_replace(format_spec):
def _parse_replace(format_spec, default):
old, new, format_spec = format_spec.split("/", 2)
old = old[1:]
fmt = build_format_func(format_spec)
fmt = _build_format_func(format_spec, default)
def replace(obj):
return fmt(obj.replace(old, new))
return replace
def _parse_datetime(format_spec):
def _parse_datetime(format_spec, default):
dt_format, _, format_spec = format_spec.partition("/")
dt_format = dt_format[1:]
fmt = build_format_func(format_spec)
fmt = _build_format_func(format_spec, default)
def dt(obj):
return fmt(text.parse_datetime(obj, dt_format))
return dt
def _default_format(format_spec):
def _default_format(format_spec, default):
def wrap(obj):
return format(obj, format_spec)
return wrap
@ -379,3 +339,33 @@ class Literal():
_literal = Literal()
_CACHE = {}
_GLOBALS = {
"_env": lambda: os.environ,
"_lit": lambda: _literal,
"_now": datetime.datetime.now,
}
_CONVERSIONS = {
"l": str.lower,
"u": str.upper,
"c": str.capitalize,
"C": string.capwords,
"j": functools.partial(json.dumps, default=str),
"t": str.strip,
"T": util.datetime_to_timestamp_string,
"d": text.parse_timestamp,
"U": text.unescape,
"S": util.to_string,
"s": str,
"r": repr,
"a": ascii,
}
_FORMAT_SPECIFIERS = {
"?": _parse_optional,
"[": _parse_slice,
"D": _parse_datetime,
"L": _parse_maxlen,
"J": _parse_join,
"R": _parse_replace,
}

@ -14,7 +14,7 @@ import datetime
import tempfile
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gallery_dl import formatter # noqa E402
from gallery_dl import formatter, text, util # noqa E402
class TestFormatter(unittest.TestCase):
@ -98,6 +98,14 @@ class TestFormatter(unittest.TestCase):
self._run_test("{missing[key]}", replacement, default)
self._run_test("{missing:?a//}", "a" + default, default)
def test_fmt_func(self):
self._run_test("{t}" , self.kwdict["t"] , None, int)
self._run_test("{t}" , self.kwdict["t"] , None, util.identity)
self._run_test("{dt}", self.kwdict["dt"], None, util.identity)
self._run_test("{ds}", self.kwdict["dt"], None, text.parse_datetime)
self._run_test("{ds:D%Y-%m-%dT%H:%M:%S%z}", self.kwdict["dt"],
None, util.identity)
def test_alternative(self):
self._run_test("{a|z}" , "hElLo wOrLd")
self._run_test("{z|a}" , "hElLo wOrLd")
@ -316,8 +324,8 @@ def noarg():
with self.assertRaises(TypeError):
self.assertEqual(fmt3.format_map(self.kwdict), "")
def _run_test(self, format_string, result, default=None):
fmt = formatter.parse(format_string, default)
def _run_test(self, format_string, result, default=None, fmt=format):
fmt = formatter.parse(format_string, default, fmt)
output = fmt.format_map(self.kwdict)
self.assertEqual(output, result, format_string)

Loading…
Cancel
Save