@ -9,7 +9,7 @@
""" Extract images subreddits at https://reddit.com/ """
from . common import Extractor , Message
from . . import text , ex ception
from . . import text , ex tractor, ex ception
from . . cache import cache
import re
@ -21,36 +21,54 @@ class RedditExtractor(Extractor):
def __init__ ( self ) :
Extractor . __init__ ( self )
self . api = RedditAPI ( self . session , self . log )
self . max_depth = int ( self . config ( " recursion " , 0 ) )
self . _visited = set ( )
def items ( self ) :
regex = re . compile ( r " https?://(?:[^.]+ \ .)?reddit.com/ " )
subre = re . compile ( RedditSubmissionExtractor . pattern [ 0 ] )
submissions = self . submissions ( )
depth = 0
yield Message . Version , 1
for submission , comments in self . submissions ( ) :
urls = [ submission [ " url " ] ]
urls . extend (
text . extract_iter (
" " . join ( self . _collect ( submission , comments ) ) ,
' href= " ' , ' " '
with extractor . blacklist ( " reddit " ) :
while True :
extra = [ ]
for url in self . _urls ( submissions ) :
if url [ 0 ] == " # " :
continue
if url [ 0 ] == " / " :
url = " https://www.reddit.com " + url
match = subre . match ( url )
if match :
extra . append ( match . group ( 1 ) )
else :
yield Message . Queue , url
if not extra or depth == self . max_depth :
return
depth + = 1
submissions = (
self . api . submission ( sid ) for sid in extra
if sid not in self . _visited
)
)
for url in urls :
if url [ 0 ] == " # " :
continue
elif url [ 0 ] == " / " :
url = " nofollow:https://www.reddit.com " + url
elif regex . match ( url ) :
url = " nofollow: " + url
yield Message . Queue , url
def _collect ( self , submission , comments ) :
yield submission [ " selftext_html " ] or " "
for comment in comments :
yield comment [ " body_html " ] or " "
def submissions ( self ) :
""" Return an iterable containing all (submission, comments) tuples """
def _urls ( self , submissions ) :
for submission , comments in submissions :
self . _visited . add ( submission [ " id " ] )
if not submission [ " is_self " ] :
yield submission [ " url " ]
strings = [ submission [ " selftext_html " ] or " " ]
strings + = [ c [ " body_html " ] or " " for c in comments ]
yield from text . extract_iter ( " " . join ( strings ) , ' href= " ' , ' " ' )
class RedditSubredditExtractor ( RedditExtractor ) :
""" Extractor for images from subreddits on reddit.com """
subcategory = " sub mission "
subcategory = " sub reddit "
pattern = [ r " (?:https?://)?(?:m \ .|www \ .)?reddit \ .com/r/([^/]+)/?$ " ]
def __init__ ( self , match ) :
@ -63,10 +81,11 @@ class RedditSubredditExtractor(RedditExtractor):
class RedditSubmissionExtractor ( RedditExtractor ) :
""" Extractor for images from a submission on reddit.com """
subcategory = " subreddit "
pattern = [ ( r " (?:https?://)?(?:m \ .|www \ .)?reddit \ .com/r/[^/]+ "
r " /comments/([a-z0-9]+) " ) ,
( r " (?:https?://)?redd \ .it/([a-z0-9]+) " ) ]
subcategory = " submission "
pattern = [ ( r " (?:https?://)?(?: "
r " (?:m \ .|www \ .)?reddit \ .com/r/[^/]+/comments| "
r " redd \ .it "
r " )/([a-z0-9]+) " ) ]
def __init__ ( self , match ) :
RedditExtractor . __init__ ( self )
@ -119,10 +138,15 @@ class RedditAPI():
def _call ( self , endpoint , params ) :
url = " https://oauth.reddit.com " + endpoint
# TODO: handle errors / rate limits
self . authenticate ( )
response = self . session . get ( url , params = params )
return response . json ( )
data = self . session . get ( url , params = params ) . json ( )
if " error " in data :
if data [ " error " ] == 403 :
raise exception . AuthorizationError ( )
if data [ " error " ] == 404 :
raise exception . NotFoundError ( )
raise Exception ( data [ " message " ] )
return data
def _pagination ( self , endpoint , params , _empty = ( ) ) :
while True :
@ -139,7 +163,8 @@ class RedditAPI():
return
params [ " after " ] = data [ " after " ]
def _unfold ( self , comments ) :
@staticmethod
def _unfold ( comments ) :
# TODO: order?
queue = comments [ " data " ] [ " children " ]
while queue :