diff --git a/gallery_dl/extractor/reddit.py b/gallery_dl/extractor/reddit.py
index 4c830195..44627d5a 100644
--- a/gallery_dl/extractor/reddit.py
+++ b/gallery_dl/extractor/reddit.py
@@ -9,74 +9,88 @@
 """Extractors for https://www.reddit.com/"""
 
 from .common import Extractor, Message
-from .. import text, util, extractor, exception
+from .. import text, util, exception
 from ..cache import cache
 
 
 class RedditExtractor(Extractor):
     """Base class for reddit extractors"""
     category = "reddit"
+    directory_fmt = ("{category}", "{subreddit}")
+    filename_fmt = "{id} {title[:242]}.{extension}"
+    archive_fmt = "{url}"
     cookiedomain = None
 
     def __init__(self, match):
         Extractor.__init__(self, match)
         self.api = RedditAPI(self)
-        self.max_depth = int(self.config("recursion", 0))
-        self._visited = set()
+        self.max_depth = self.config("recursion", 0)
 
     def items(self):
-        subre = RedditSubmissionExtractor.pattern
+        match_submission = RedditSubmissionExtractor.pattern.match
+        match_subreddit = RedditSubredditExtractor.pattern.match
+        match_user = RedditUserExtractor.pattern.match
+
         submissions = self.submissions()
+        visited = set()
         depth = 0
 
         yield Message.Version, 1
-        with extractor.blacklist(
-                util.SPECIAL_EXTRACTORS,
-                [RedditSubredditExtractor, RedditUserExtractor]):
-            while True:
-                extra = []
-                for url, data in self._urls(submissions):
-                    if url[0] == "#":
+
+        while True:
+            extra = []
+
+            for submission, comments in submissions:
+                urls = []
+
+                if submission:
+                    yield Message.Directory, submission
+                    visited.add(submission["id"])
+                    url = submission["url"]
+
+                    if url.startswith("https://i.redd.it/"):
+                        text.nameext_from_url(url, submission)
+                        yield Message.Url, url, submission
+                    elif submission["is_video"]:
+                        submission["extension"] = None
+                        yield Message.Url, "ytdl:" + url, submission
+                    elif not submission["is_self"]:
+                        urls.append((url, submission))
+
+                if self.api.comments:
+                    if submission:
+                        for url in text.extract_iter(
+                                submission["selftext_html"] or "",
+                                ' href="', '"'):
+                            urls.append((url, submission))
+                    for comment in comments:
+                        for url in text.extract_iter(
+                                comment["body_html"] or "", ' href="', '"'):
+                            urls.append((url, comment))
+
+                for url, data in urls:
+                    if not url or url[0] == "#":
                         continue
                     if url[0] == "/":
                         url = "https://www.reddit.com" + url
 
-                    match = subre.match(url)
+                    match = match_submission(url)
                     if match:
                         extra.append(match.group(1))
-                    else:
+                    elif not match_user(url) and not match_subreddit(url):
                         yield Message.Queue, text.unescape(url), data
 
-                if not extra or depth == self.max_depth:
-                    return
-                depth += 1
-                submissions = (
-                    self.api.submission(sid) for sid in extra
-                    if sid not in self._visited
-                )
+            if not extra or depth == self.max_depth:
+                return
+            depth += 1
+            submissions = (
+                self.api.submission(sid) for sid in extra
+                if sid not in visited
+            )
 
     def submissions(self):
         """Return an iterable containing all (submission, comments) tuples"""
 
-    def _urls(self, submissions):
-        for submission, comments in submissions:
-
-            if submission:
-                self._visited.add(submission["id"])
-
-                if not submission["is_self"]:
-                    yield submission["url"], submission
-
-                for url in text.extract_iter(
-                        submission["selftext_html"] or "", ' href="', '"'):
-                    yield url, submission
-
-            if comments:
-                for comment in comments:
-                    for url in text.extract_iter(
-                            comment["body_html"] or "", ' href="', '"'):
-                        yield url, comment
-
 
 class RedditSubredditExtractor(RedditExtractor):
     """Extractor for URLs from subreddits on reddit.com"""
@@ -84,7 +98,10 @@ class RedditSubredditExtractor(RedditExtractor):
     pattern = (r"(?:https?://)?(?:\w+\.)?reddit\.com/r/"
                r"([^/?&#]+(?:/[a-z]+)?)/?(?:\?([^#]*))?(?:$|#)")
     test = (
-        ("https://www.reddit.com/r/lavaporn/"),
+        ("https://www.reddit.com/r/lavaporn/", {
+            "range": "1-20",
+            "count": ">= 20",
+        }),
         ("https://www.reddit.com/r/lavaporn/top/?sort=top&t=month"),
         ("https://old.reddit.com/r/lavaporn/"),
         ("https://np.reddit.com/r/lavaporn/"),
@@ -210,7 +227,7 @@ class RedditAPI():
         link_id = "t3_" + submission_id if self.morecomments else None
         submission, comments = self._call(endpoint, {"limit": self.comments})
         return (submission["data"]["children"][0]["data"],
-                self._flatten(comments, link_id) if self.comments else None)
+                self._flatten(comments, link_id) if self.comments else ())
 
     def submissions_subreddit(self, subreddit, params):
         """Collect all (submission, comments)-tuples of a subreddit"""
@@ -290,7 +307,8 @@ class RedditAPI():
                 raise exception.AuthorizationError()
             if data["error"] == 404:
                 raise exception.NotFoundError()
-            raise Exception(data["message"])
+            self.log.debug(data)
+            raise exception.StopExtraction(data.get("message"))
         return data
 
     def _pagination(self, endpoint, params):
@@ -315,7 +333,7 @@ class RedditAPI():
                     except exception.AuthorizationError:
                         pass
                     else:
-                        yield post, None
+                        yield post, ()
                 elif kind == "t1" and self.comments:
                     yield None, (post,)
 