@ -8,13 +8,16 @@
# published by the Free Software Foundation.
import re
import sys
import base64
import os . path
import tempfile
import unittest
import threading
import http . server
import unittest
from unittest . mock import Mock , MagicMock , patch
import gallery_dl . downloader as downloader
import gallery_dl . extractor as extractor
import gallery_dl . config as config
@ -23,6 +26,73 @@ from gallery_dl.output import NullOutput
from gallery_dl . util import PathFormat
class MockDownloaderModule ( Mock ) :
__downloader__ = " mock "
class TestDownloaderModule ( unittest . TestCase ) :
@classmethod
def setUpClass ( cls ) :
# allow import of ytdl downloader module without youtube_dl installed
sys . modules [ " youtube_dl " ] = MagicMock ( )
@classmethod
def tearDownClass ( cls ) :
del sys . modules [ " youtube_dl " ]
def tearDown ( self ) :
downloader . _cache . clear ( )
def test_find ( self ) :
cls = downloader . find ( " http " )
self . assertEqual ( cls . __name__ , " HttpDownloader " )
self . assertEqual ( cls . scheme , " http " )
cls = downloader . find ( " https " )
self . assertEqual ( cls . __name__ , " HttpDownloader " )
self . assertEqual ( cls . scheme , " http " )
cls = downloader . find ( " text " )
self . assertEqual ( cls . __name__ , " TextDownloader " )
self . assertEqual ( cls . scheme , " text " )
cls = downloader . find ( " ytdl " )
self . assertEqual ( cls . __name__ , " YoutubeDLDownloader " )
self . assertEqual ( cls . scheme , " ytdl " )
self . assertEqual ( downloader . find ( " ftp " ) , None )
self . assertEqual ( downloader . find ( " foo " ) , None )
self . assertEqual ( downloader . find ( 1234 ) , None )
self . assertEqual ( downloader . find ( None ) , None )
@patch ( " importlib.import_module " )
def test_cache ( self , import_module ) :
import_module . return_value = MockDownloaderModule ( )
downloader . find ( " http " )
downloader . find ( " text " )
downloader . find ( " ytdl " )
self . assertEqual ( import_module . call_count , 3 )
downloader . find ( " http " )
downloader . find ( " text " )
downloader . find ( " ytdl " )
self . assertEqual ( import_module . call_count , 3 )
@patch ( " importlib.import_module " )
def test_cache_http ( self , import_module ) :
import_module . return_value = MockDownloaderModule ( )
downloader . find ( " http " )
downloader . find ( " https " )
self . assertEqual ( import_module . call_count , 1 )
@patch ( " importlib.import_module " )
def test_cache_https ( self , import_module ) :
import_module . return_value = MockDownloaderModule ( )
downloader . find ( " https " )
downloader . find ( " http " )
self . assertEqual ( import_module . call_count , 1 )
class TestDownloaderBase ( unittest . TestCase ) :
@classmethod