diff --git a/gallery_dl/text.py b/gallery_dl/text.py
index 7e5cb29f..e439c2b8 100644
--- a/gallery_dl/text.py
+++ b/gallery_dl/text.py
@@ -15,14 +15,22 @@ import html
import urllib.parse
-INVALID_XML_CHARS = (1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 14, 15, 16, 17, 18,
- 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31)
+INVALID_XML_CHARS = (
+ "\x00", "\x01", "\x02", "\x03", "\x04", "\x05", "\x06", "\x07",
+ "\x08", "\x0b", "\x0c", "\x0e", "\x0f", "\x10", "\x11", "\x12",
+ "\x13", "\x14", "\x15", "\x16", "\x17", "\x18", "\x19", "\x1a",
+ "\x1b", "\x1c", "\x1d", "\x1e", "\x1f",
+)
def clean_xml(xmldata, repl=""):
- """Replace/Remove invalid control characters in XML data"""
+ """Replace/Remove invalid control characters in 'xmldata'"""
+ if not isinstance(xmldata, str):
+ try:
+ xmldata = "".join(xmldata)
+ except TypeError:
+ return ""
for char in INVALID_XML_CHARS:
- char = chr(char)
if char in xmldata:
xmldata = xmldata.replace(char, repl)
return xmldata
diff --git a/test/test_text.py b/test/test_text.py
index 767952fd..c4b02969 100644
--- a/test/test_text.py
+++ b/test/test_text.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-# Copyright 2015 Mike Fährmann
+# Copyright 2015-2018 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
@@ -9,11 +9,35 @@
import unittest
import sys
-import gallery_dl.text as text
+
+from gallery_dl import text
class TestText(unittest.TestCase):
+ def test_clean_xml(self, f=text.clean_xml):
+ # standard usage
+ self.assertEqual(f(""), "")
+ self.assertEqual(f("foo"), "foo")
+ self.assertEqual(f("\tfoo\nbar\r"), "\tfoo\nbar\r")
+ self.assertEqual(f("\ab\ba\fr\v"), "bar")
+
+ # 'repl' argument
+ repl = "#"
+ self.assertEqual(f("", repl), "")
+ self.assertEqual(f("foo", repl), "foo")
+ self.assertEqual(f("\tfoo\nbar\r", repl), "\tfoo\nbar\r")
+ self.assertEqual(
+ f("\ab\ba\fr\v", repl), "#b#a#r#")
+
+ # removal of all illegal control characters
+ value = "".join(chr(x) for x in range(32))
+ self.assertEqual(f(value), "\t\n\r")
+
+ # 'invalid' arguments
+ for value in ((), [], {}, None, 1, 2.3):
+ self.assertEqual(f(value), "")
+
def test_remove_html(self):
cases = (
"Hello World.",