From 425ece17ace5bed5775176413d0a7927ced55f06 Mon Sep 17 00:00:00 2001 From: Aaron Cohen Date: Sun, 22 Jul 2012 22:18:44 -0700 Subject: [PATCH 01/84] Synoindex tweaks, now will always log when sending notification --- headphones/notifiers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/headphones/notifiers.py b/headphones/notifiers.py index ae04fe99..3c16b647 100644 --- a/headphones/notifiers.py +++ b/headphones/notifiers.py @@ -197,6 +197,8 @@ class Synoindex: return os.path.exists(self.util_loc) def notify(self, path): + path = os.path.abspath(path) + if not self.util_exists(): logger.warn("Error sending notification: synoindex utility not found at %s" % self.util_loc) return @@ -209,12 +211,12 @@ class Synoindex: logger.warn("Error sending notification: Path passed to synoindex was not a file or folder.") return - cmd = [self.util_loc, cmd_arg, '\"%s\"' % os.path.abspath(path)] - logger.debug("Calling synoindex command: %s" % str(cmd)) + cmd = [self.util_loc, cmd_arg, path] + logger.info("Calling synoindex command: %s" % str(cmd)) try: p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=headphones.PROG_DIR) out, error = p.communicate() - logger.debug("Synoindex result: %s" % str(out)) + #synoindex never returns any codes other than '0', highly irritating except OSError, e: logger.warn("Error sending notification: %s" % str(e)) From 2b51447a9e7885c5d3835154566e0d6e919efd14 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Wed, 25 Jul 2012 17:54:41 +0530 Subject: [PATCH 02/84] Initial changes to allow multiple newznab providers --- data/interfaces/default/config.html | 31 ++++++++++++++++------------- headphones/__init__.py | 5 ++++- headphones/webserve.py | 4 +++- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/data/interfaces/default/config.html b/data/interfaces/default/config.html index ef0f534c..b5d62ecc 100644 --- a/data/interfaces/default/config.html +++ 
b/data/interfaces/default/config.html @@ -184,7 +184,7 @@ m<%inherit file="base.html"/>
- +
@@ -200,17 +200,20 @@ m<%inherit file="base.html"/>
-
- - - e.g. http://nzb.su -
-
- - -
-
- +
+ + + e.g. http://nzb.su +
+
+ + +
+
+ +
+
+ @@ -233,11 +236,11 @@ m<%inherit file="base.html"/>
- +
- +
diff --git a/headphones/__init__.py b/headphones/__init__.py index a6ab5255..91b396f7 100644 --- a/headphones/__init__.py +++ b/headphones/__init__.py @@ -121,6 +121,7 @@ NZBMATRIX_APIKEY = None NEWZNAB = False NEWZNAB_HOST = None NEWZNAB_APIKEY = None +NEWZNAB_ENABLED = False NZBSORG = False NZBSORG_UID = None @@ -240,7 +241,7 @@ def initialize(): ADD_ALBUM_ART, EMBED_ALBUM_ART, EMBED_LYRICS, DOWNLOAD_DIR, BLACKHOLE, BLACKHOLE_DIR, USENET_RETENTION, SEARCH_INTERVAL, \ TORRENTBLACKHOLE_DIR, NUMBEROFSEEDERS, ISOHUNT, KAT, MININOVA, WAFFLES, WAFFLES_UID, WAFFLES_PASSKEY, DOWNLOAD_TORRENT_DIR, \ LIBRARYSCAN_INTERVAL, DOWNLOAD_SCAN_INTERVAL, SAB_HOST, SAB_USERNAME, SAB_PASSWORD, SAB_APIKEY, SAB_CATEGORY, \ - NZBMATRIX, NZBMATRIX_USERNAME, NZBMATRIX_APIKEY, NEWZNAB, NEWZNAB_HOST, NEWZNAB_APIKEY, \ + NZBMATRIX, NZBMATRIX_USERNAME, NZBMATRIX_APIKEY, NEWZNAB, NEWZNAB_HOST, NEWZNAB_APIKEY, NEWZNAB_ENABLED, \ NZBSORG, NZBSORG_UID, NZBSORG_HASH, NEWZBIN, NEWZBIN_UID, NEWZBIN_PASSWORD, LASTFM_USERNAME, INTERFACE, FOLDER_PERMISSIONS, \ ENCODERFOLDER, ENCODER, BITRATE, SAMPLINGFREQUENCY, MUSIC_ENCODER, ADVANCEDENCODER, ENCODEROUTPUTFORMAT, ENCODERQUALITY, ENCODERVBRCBR, \ ENCODERLOSSLESS, PROWL_ENABLED, PROWL_PRIORITY, PROWL_KEYS, PROWL_ONSNATCH, MIRRORLIST, MIRROR, CUSTOMHOST, CUSTOMPORT, \ @@ -340,6 +341,7 @@ def initialize(): NEWZNAB = bool(check_setting_int(CFG, 'Newznab', 'newznab', 0)) NEWZNAB_HOST = check_setting_str(CFG, 'Newznab', 'newznab_host', '') NEWZNAB_APIKEY = check_setting_str(CFG, 'Newznab', 'newznab_apikey', '') + NEWZNAB_ENABLED = bool(check_setting_int(CFG, 'Newznab', 'newznab_enabled', 1)) NZBSORG = bool(check_setting_int(CFG, 'NZBsorg', 'nzbsorg', 0)) NZBSORG_UID = check_setting_str(CFG, 'NZBsorg', 'nzbsorg_uid', '') @@ -613,6 +615,7 @@ def config_write(): new_config['Newznab']['newznab'] = int(NEWZNAB) new_config['Newznab']['newznab_host'] = NEWZNAB_HOST new_config['Newznab']['newznab_apikey'] = NEWZNAB_APIKEY + new_config['Newznab']['newznab_enabled'] = 
int(NEWZNAB_ENABLED) new_config['NZBsorg'] = {} new_config['NZBsorg']['nzbsorg'] = int(NZBSORG) diff --git a/headphones/webserve.py b/headphones/webserve.py index fc0332e7..93865a6a 100644 --- a/headphones/webserve.py +++ b/headphones/webserve.py @@ -384,6 +384,7 @@ class WebInterface(object): "use_newznab" : checked(headphones.NEWZNAB), "newznab_host" : headphones.NEWZNAB_HOST, "newznab_api" : headphones.NEWZNAB_APIKEY, + "newznab_enabled" : checked(headphones.NEWZNAB_ENABLED), "use_nzbsorg" : checked(headphones.NZBSORG), "nzbsorg_uid" : headphones.NZBSORG_UID, "nzbsorg_hash" : headphones.NZBSORG_HASH, @@ -458,7 +459,7 @@ class WebInterface(object): def configUpdate(self, http_host='0.0.0.0', http_username=None, http_port=8181, http_password=None, launch_browser=0, api_enabled=0, api_key=None, download_scan_interval=None, nzb_search_interval=None, libraryscan_interval=None, sab_host=None, sab_username=None, sab_apikey=None, sab_password=None, sab_category=None, download_dir=None, blackhole=0, blackhole_dir=None, - usenet_retention=None, nzbmatrix=0, nzbmatrix_username=None, nzbmatrix_apikey=None, newznab=0, newznab_host=None, newznab_apikey=None, + usenet_retention=None, nzbmatrix=0, nzbmatrix_username=None, nzbmatrix_apikey=None, newznab=0, newznab_host=None, newznab_apikey=None, newznab_enabled=0, nzbsorg=0, nzbsorg_uid=None, nzbsorg_hash=None, newzbin=0, newzbin_uid=None, newzbin_password=None, preferred_quality=0, preferred_bitrate=None, detect_bitrate=0, move_files=0, torrentblackhole_dir=None, download_torrent_dir=None, numberofseeders=10, use_isohunt=0, use_kat=0, use_mininova=0, waffles=0, waffles_uid=None, waffles_passkey=None, rename_files=0, correct_metadata=0, cleanup_files=0, add_album_art=0, embed_album_art=0, embed_lyrics=0, destination_dir=None, folder_format=None, file_format=None, include_extras=0, autowant_upcoming=False, autowant_all=False, interface=None, log_dir=None, @@ -491,6 +492,7 @@ class WebInterface(object): headphones.NEWZNAB = 
newznab headphones.NEWZNAB_HOST = newznab_host headphones.NEWZNAB_APIKEY = newznab_apikey + headphones.NEWZNAB_ENABLED = newznab_enabled headphones.NZBSORG = nzbsorg headphones.NZBSORG_UID = nzbsorg_uid headphones.NZBSORG_HASH = nzbsorg_hash From bf43f65995a5093efc393da81bd975b720d641ee Mon Sep 17 00:00:00 2001 From: rembo10 Date: Wed, 25 Jul 2012 18:12:30 +0530 Subject: [PATCH 03/84] Added a timeout to last.fm album art calls from the post processor --- headphones/albumart.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/headphones/albumart.py b/headphones/albumart.py index 3b80e5f5..5fd6f8a8 100644 --- a/headphones/albumart.py +++ b/headphones/albumart.py @@ -13,6 +13,7 @@ # You should have received a copy of the GNU General Public License # along with Headphones. If not, see . +import urllib2 from headphones import db def getAlbumArt(albumid): @@ -36,7 +37,7 @@ def getCachedArt(albumid): return None if artwork_path.startswith('http://'): - artwork = urllib.urlopen(artwork_path).read() + artwork = urllib2.urlopen(artwork_path, timeout=20).read() return artwork else: artwork = open(artwork_path, "r").read() From 4f8e5877a7e4642984f556338c5cbfd3b11ed5cb Mon Sep 17 00:00:00 2001 From: rembo10 Date: Wed, 25 Jul 2012 20:18:43 +0530 Subject: [PATCH 04/84] Added jquery to allow adding/removing newznab providers --- data/interfaces/default/config.html | 41 ++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/data/interfaces/default/config.html b/data/interfaces/default/config.html index b5d62ecc..3a0ac8a8 100644 --- a/data/interfaces/default/config.html +++ b/data/interfaces/default/config.html @@ -199,19 +199,22 @@ m<%inherit file="base.html"/>
-
-
- - - e.g. http://nzb.su -
-
- - -
-
- +
+
+
+ + + e.g. http://nzb.su +
+
+ + +
+
+ +
+
@@ -792,7 +795,19 @@ m<%inherit file="base.html"/> initConfigCheckbox("#useapi"); } $(document).ready(function() { - initThisPage(); + initThisPage(); + $("#add_newznab").click(function() { + var intIdPrev = $("#newznab_providers > div").size() + var intId = intIdPrev + 1; + var formfields = $("
"); + var removeButton = $("
"); + removeButton.click(function() { + $(this).parent().remove(); + }); + formfields.append(removeButton); + formfields.append("
"); + $("#newznab" + intIdPrev).append(formfields); + }); }); From 0cf59d94ffea31f277e1cdde45fec671a14e0875 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Wed, 25 Jul 2012 20:33:11 +0530 Subject: [PATCH 05/84] Allow for extra newznabs to be inserted into the form on load --- data/interfaces/default/config.html | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/data/interfaces/default/config.html b/data/interfaces/default/config.html index 3a0ac8a8..9bceb8db 100644 --- a/data/interfaces/default/config.html +++ b/data/interfaces/default/config.html @@ -214,6 +214,28 @@ m<%inherit file="base.html"/>
+ <% + newznab_number = 2 + %> + %for newznab in config['extra_newznabs']: +
+
+ + + e.g. http://nzb.su +
+
+ + +
+
+ +
+
+ <% + newznab_number += 1 + %> + %endfor @@ -799,7 +821,7 @@ m<%inherit file="base.html"/> $("#add_newznab").click(function() { var intIdPrev = $("#newznab_providers > div").size() var intId = intIdPrev + 1; - var formfields = $("
"); + var formfields = $("
"); var removeButton = $("
"); removeButton.click(function() { $(this).parent().remove(); From cc2adb40ebda1381cb0421e720e9a316dc826ac6 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Thu, 26 Jul 2012 01:29:16 +0530 Subject: [PATCH 06/84] Extra newznabs can be saved to and pulled from the config --- data/interfaces/default/config.html | 27 ++++++++++++++++++++------- headphones/__init__.py | 5 ++++- headphones/webserve.py | 19 ++++++++++++++++++- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/data/interfaces/default/config.html b/data/interfaces/default/config.html index 9bceb8db..5f29847b 100644 --- a/data/interfaces/default/config.html +++ b/data/interfaces/default/config.html @@ -1,6 +1,7 @@ m<%inherit file="base.html"/> <%! import headphones + from operator import itemgetter %> <%def name="headerIncludes()"> @@ -217,19 +218,27 @@ m<%inherit file="base.html"/> <% newznab_number = 2 %> - %for newznab in config['extra_newznabs']: + %for newznab in sorted(config['extra_newznabs'], key=itemgetter(0)): + <% + if newznab[2]: + newznab_enabled = "checked" + else: + newznab_enabled = "" + %>
- e.g. http://nzb.su
- +
- + +
+
+
<% @@ -804,6 +813,10 @@ m<%inherit file="base.html"/> $("#mirror").change(handleNewSelection); handleNewSelection.apply($("#mirror")); + + $(".remove").click(function() { + $(this).parent().parent().remove(); + }); $(function() { $( "#tabs" ).tabs(); }); @@ -819,16 +832,16 @@ m<%inherit file="base.html"/> $(document).ready(function() { initThisPage(); $("#add_newznab").click(function() { - var intIdPrev = $("#newznab_providers > div").size() + var intIdPrev = $("#newznab_providers > div").size(); var intId = intIdPrev + 1; - var formfields = $("
"); + var formfields = $("
"); var removeButton = $("
"); removeButton.click(function() { $(this).parent().remove(); }); formfields.append(removeButton); formfields.append("
"); - $("#newznab" + intIdPrev).append(formfields); + $("#newznab" + intIdPrev).after(formfields); }); }); diff --git a/headphones/__init__.py b/headphones/__init__.py index 91b396f7..f8eacc1a 100644 --- a/headphones/__init__.py +++ b/headphones/__init__.py @@ -122,6 +122,7 @@ NEWZNAB = False NEWZNAB_HOST = None NEWZNAB_APIKEY = None NEWZNAB_ENABLED = False +EXTRA_NEWZNABS = [] NZBSORG = False NZBSORG_UID = None @@ -241,7 +242,7 @@ def initialize(): ADD_ALBUM_ART, EMBED_ALBUM_ART, EMBED_LYRICS, DOWNLOAD_DIR, BLACKHOLE, BLACKHOLE_DIR, USENET_RETENTION, SEARCH_INTERVAL, \ TORRENTBLACKHOLE_DIR, NUMBEROFSEEDERS, ISOHUNT, KAT, MININOVA, WAFFLES, WAFFLES_UID, WAFFLES_PASSKEY, DOWNLOAD_TORRENT_DIR, \ LIBRARYSCAN_INTERVAL, DOWNLOAD_SCAN_INTERVAL, SAB_HOST, SAB_USERNAME, SAB_PASSWORD, SAB_APIKEY, SAB_CATEGORY, \ - NZBMATRIX, NZBMATRIX_USERNAME, NZBMATRIX_APIKEY, NEWZNAB, NEWZNAB_HOST, NEWZNAB_APIKEY, NEWZNAB_ENABLED, \ + NZBMATRIX, NZBMATRIX_USERNAME, NZBMATRIX_APIKEY, NEWZNAB, NEWZNAB_HOST, NEWZNAB_APIKEY, NEWZNAB_ENABLED, EXTRA_NEWZNABS,\ NZBSORG, NZBSORG_UID, NZBSORG_HASH, NEWZBIN, NEWZBIN_UID, NEWZBIN_PASSWORD, LASTFM_USERNAME, INTERFACE, FOLDER_PERMISSIONS, \ ENCODERFOLDER, ENCODER, BITRATE, SAMPLINGFREQUENCY, MUSIC_ENCODER, ADVANCEDENCODER, ENCODEROUTPUTFORMAT, ENCODERQUALITY, ENCODERVBRCBR, \ ENCODERLOSSLESS, PROWL_ENABLED, PROWL_PRIORITY, PROWL_KEYS, PROWL_ONSNATCH, MIRRORLIST, MIRROR, CUSTOMHOST, CUSTOMPORT, \ @@ -342,6 +343,7 @@ def initialize(): NEWZNAB_HOST = check_setting_str(CFG, 'Newznab', 'newznab_host', '') NEWZNAB_APIKEY = check_setting_str(CFG, 'Newznab', 'newznab_apikey', '') NEWZNAB_ENABLED = bool(check_setting_int(CFG, 'Newznab', 'newznab_enabled', 1)) + EXTRA_NEWZNABS = check_setting_str(CFG, 'Newznab', 'extra_newznabs', [], log=False) NZBSORG = bool(check_setting_int(CFG, 'NZBsorg', 'nzbsorg', 0)) NZBSORG_UID = check_setting_str(CFG, 'NZBsorg', 'nzbsorg_uid', '') @@ -616,6 +618,7 @@ def config_write(): new_config['Newznab']['newznab_host'] = 
NEWZNAB_HOST new_config['Newznab']['newznab_apikey'] = NEWZNAB_APIKEY new_config['Newznab']['newznab_enabled'] = int(NEWZNAB_ENABLED) + new_config['Newznab']['extra_newznabs'] = EXTRA_NEWZNABS new_config['NZBsorg'] = {} new_config['NZBsorg']['nzbsorg'] = int(NZBSORG) diff --git a/headphones/webserve.py b/headphones/webserve.py index 93865a6a..2d190bf8 100644 --- a/headphones/webserve.py +++ b/headphones/webserve.py @@ -385,6 +385,7 @@ class WebInterface(object): "newznab_host" : headphones.NEWZNAB_HOST, "newznab_api" : headphones.NEWZNAB_APIKEY, "newznab_enabled" : checked(headphones.NEWZNAB_ENABLED), + "extra_newznabs" : headphones.EXTRA_NEWZNABS, "use_nzbsorg" : checked(headphones.NZBSORG), "nzbsorg_uid" : headphones.NZBSORG_UID, "nzbsorg_hash" : headphones.NZBSORG_HASH, @@ -465,7 +466,7 @@ class WebInterface(object): rename_files=0, correct_metadata=0, cleanup_files=0, add_album_art=0, embed_album_art=0, embed_lyrics=0, destination_dir=None, folder_format=None, file_format=None, include_extras=0, autowant_upcoming=False, autowant_all=False, interface=None, log_dir=None, music_encoder=0, encoder=None, bitrate=None, samplingfrequency=None, encoderfolder=None, advancedencoder=None, encoderoutputformat=None, encodervbrcbr=None, encoderquality=None, encoderlossless=0, prowl_enabled=0, prowl_onsnatch=0, prowl_keys=None, prowl_priority=0, xbmc_enabled=0, xbmc_host=None, xbmc_username=None, xbmc_password=None, xbmc_update=0, xbmc_notify=0, - nma_enabled=False, nma_apikey=None, nma_priority=0, synoindex_enabled=False, mirror=None, customhost=None, customport=None, customsleep=None, hpuser=None, hppass=None): + nma_enabled=False, nma_apikey=None, nma_priority=0, synoindex_enabled=False, mirror=None, customhost=None, customport=None, customsleep=None, hpuser=None, hppass=None, **kwargs): headphones.HTTP_HOST = http_host headphones.HTTP_PORT = http_port @@ -556,6 +557,22 @@ class WebInterface(object): headphones.CUSTOMSLEEP = customsleep headphones.HPUSER = hpuser 
headphones.HPPASS = hppass + + # Handle the variable config options. Note - keys with False values aren't getting passed + + headphones.EXTRA_NEWZNABS = [] + + for kwarg in kwargs: + if kwarg.startswith('newznab_host'): + newznab_number = kwarg[12:] + newznab_host = kwargs['newznab_host' + newznab_number] + newznab_api = kwargs['newznab_api' + newznab_number] + try: + newznab_enabled = int(kwargs['newznab_enabled' + newznab_number]) + except KeyError: + newznab_enabled = 0 + + headphones.EXTRA_NEWZNABS.append([newznab_host, newznab_api, newznab_enabled]) headphones.config_write() From 362338926c9aad9e961a73dfa83b0cc35934a632 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Thu, 26 Jul 2012 01:41:52 +0530 Subject: [PATCH 07/84] Make sure we can still add new newznab providers after a config save without refreshing the page. Moved the add function from the document ready function to the main functions --- data/interfaces/default/config.html | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/data/interfaces/default/config.html b/data/interfaces/default/config.html index 5f29847b..120ec49a 100644 --- a/data/interfaces/default/config.html +++ b/data/interfaces/default/config.html @@ -817,6 +817,20 @@ m<%inherit file="base.html"/> $(".remove").click(function() { $(this).parent().parent().remove(); }); + + $("#add_newznab").click(function() { + var intIdPrev = $("#newznab_providers > div").size(); + var intId = intIdPrev + 1; + var formfields = $("
"); + var removeButton = $("
"); + removeButton.click(function() { + $(this).parent().remove(); + }); + formfields.append(removeButton); + formfields.append("
"); + $("#newznab" + intIdPrev).after(formfields); + }); + $(function() { $( "#tabs" ).tabs(); }); @@ -831,18 +845,6 @@ m<%inherit file="base.html"/> } $(document).ready(function() { initThisPage(); - $("#add_newznab").click(function() { - var intIdPrev = $("#newznab_providers > div").size(); - var intId = intIdPrev + 1; - var formfields = $("
"); - var removeButton = $("
"); - removeButton.click(function() { - $(this).parent().remove(); - }); - formfields.append(removeButton); - formfields.append("
"); - $("#newznab" + intIdPrev).after(formfields); - }); }); From 351c5de23ebc5c9c2ce1c94cce77207f7cffa5c3 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Thu, 26 Jul 2012 02:05:54 +0530 Subject: [PATCH 08/84] Modified searcher.py to use the new multiple newznab format --- headphones/searcher.py | 81 +++++++++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/headphones/searcher.py b/headphones/searcher.py index 95e14a5e..a7808220 100644 --- a/headphones/searcher.py +++ b/headphones/searcher.py @@ -216,6 +216,16 @@ def searchNZB(albumid=None, new=False, losslessOnly=False): logger.info(u"No results found from NZBMatrix for %s" % term) if headphones.NEWZNAB: + + newznab_hosts = [[headphones.NEWZNAB_HOST, headphones.NEWZNAB_APIKEY, headphones.NEWZNAB_ENABLED]] + + # This is just to make sure we don't have any empty string for EXTRA_NEWZNABS + if not headphones.EXTRA_NEWZNABS: + headphones.EXTRA_NEWZNABS = [] + + for newznab_host in headphones.EXTRA_NEWZNABS: + newznab_hosts.append(newznab_host) + provider = "newznab" if headphones.PREFERRED_QUALITY == 3 or losslessOnly: categories = "3040" @@ -227,44 +237,49 @@ def searchNZB(albumid=None, new=False, losslessOnly=False): if albums['Type'] == 'Other': categories = "3030" logger.info("Album type is audiobook/spokenword. Using audiobook category") + + for newznab_host in newznab_hosts: + + if newznab_host[2] == 0 or newznab_host[2] == '0': + continue - params = { "t": "search", - "apikey": headphones.NEWZNAB_APIKEY, - "cat": categories, - "maxage": headphones.USENET_RETENTION, - "q": term - } - - searchURL = headphones.NEWZNAB_HOST + '/api?' 
+ urllib.urlencode(params) - - logger.info(u'Parsing results from %s' % (searchURL, headphones.NEWZNAB_HOST)) + params = { "t": "search", + "apikey": newznab_host[1], + "cat": categories, + "maxage": headphones.USENET_RETENTION, + "q": term + } - try: - data = urllib2.urlopen(searchURL, timeout=20).read() - except urllib2.URLError, e: - logger.warn('Error fetching data from %s: %s' % (headphones.NEWZNAB_HOST, e)) - data = False + searchURL = newznab_host[0] + '/api?' + urllib.urlencode(params) + + logger.info(u'Parsing results from %s' % (searchURL, newznab_host[0])) - if data: - - d = feedparser.parse(data) + try: + data = urllib2.urlopen(searchURL, timeout=20).read() + except urllib2.URLError, e: + logger.warn('Error fetching data from %s: %s' % (newznab_host[0], e)) + data = False + + if data: - if not len(d.entries): - logger.info(u"No results found from %s for %s" % (headphones.NEWZNAB_HOST, term)) - pass - - else: - for item in d.entries: - try: - url = item.link - title = item.title - size = int(item.links[1]['length']) + d = feedparser.parse(data) + + if not len(d.entries): + logger.info(u"No results found from %s for %s" % (newznab_host[0], term)) + pass + + else: + for item in d.entries: + try: + url = item.link + title = item.title + size = int(item.links[1]['length']) + + resultlist.append((title, size, url, provider)) + logger.info('Found %s. Size: %s' % (title, helpers.bytes_to_mb(size))) - resultlist.append((title, size, url, provider)) - logger.info('Found %s. 
Size: %s' % (title, helpers.bytes_to_mb(size))) - - except Exception, e: - logger.error(u"An unknown error occurred trying to parse the feed: %s" % e) + except Exception, e: + logger.error(u"An unknown error occurred trying to parse the feed: %s" % e) if headphones.NZBSORG: provider = "nzbsorg" From f8ef52b8eefa4fd6828c138fdd08131f7ef6f0f0 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Thu, 26 Jul 2012 17:22:54 +0530 Subject: [PATCH 09/84] Some fixes to get mult_newznabs working: unpack & repack settings when saving to/pulling from config, modified searcher.py to work with tuples, fixed config.html to create new intIds no matter what, place new newznabs before add button, instead of after last div --- data/interfaces/default/config.html | 15 +++++++++------ headphones/__init__.py | 15 ++++++++++++--- headphones/searcher.py | 12 +++--------- headphones/webserve.py | 2 +- 4 files changed, 25 insertions(+), 19 deletions(-) diff --git a/data/interfaces/default/config.html b/data/interfaces/default/config.html index 120ec49a..216e4eb6 100644 --- a/data/interfaces/default/config.html +++ b/data/interfaces/default/config.html @@ -1,7 +1,6 @@ m<%inherit file="base.html"/> <%! 
import headphones - from operator import itemgetter %> <%def name="headerIncludes()"> @@ -218,9 +217,9 @@ m<%inherit file="base.html"/> <% newznab_number = 2 %> - %for newznab in sorted(config['extra_newznabs'], key=itemgetter(0)): + %for newznab in config['extra_newznabs']: <% - if newznab[2]: + if newznab[2] == '1' or newznab[2] == 1: newznab_enabled = "checked" else: newznab_enabled = "" @@ -814,21 +813,25 @@ m<%inherit file="base.html"/> $("#mirror").change(handleNewSelection); handleNewSelection.apply($("#mirror")); + var deletedNewznabs = 0; + $(".remove").click(function() { $(this).parent().parent().remove(); + deletedNewznabs = deletedNewznabs + 1; }); $("#add_newznab").click(function() { - var intIdPrev = $("#newznab_providers > div").size(); - var intId = intIdPrev + 1; + var intId = $("#newznab_providers > div").size() + deletedNewznabs + 1; var formfields = $("
"); var removeButton = $("
"); removeButton.click(function() { $(this).parent().remove(); + deletedNewznabs = deletedNewznabs + 1; + }); formfields.append(removeButton); formfields.append("
"); - $("#newznab" + intIdPrev).after(formfields); + $("#add_newznab").before(formfields); }); $(function() { diff --git a/headphones/__init__.py b/headphones/__init__.py index f8eacc1a..265a016e 100644 --- a/headphones/__init__.py +++ b/headphones/__init__.py @@ -20,6 +20,7 @@ import os, sys, subprocess import threading import webbrowser import sqlite3 +import itertools from lib.apscheduler.scheduler import Scheduler from lib.configobj import ConfigObj @@ -229,7 +230,6 @@ def check_setting_str(config, cfg_name, item_name, def_val, log=True): else: logger.debug(item_name + " -> ******") return my_val - def initialize(): @@ -343,7 +343,10 @@ def initialize(): NEWZNAB_HOST = check_setting_str(CFG, 'Newznab', 'newznab_host', '') NEWZNAB_APIKEY = check_setting_str(CFG, 'Newznab', 'newznab_apikey', '') NEWZNAB_ENABLED = bool(check_setting_int(CFG, 'Newznab', 'newznab_enabled', 1)) - EXTRA_NEWZNABS = check_setting_str(CFG, 'Newznab', 'extra_newznabs', [], log=False) + + # Need to pack the extra newznabs back into a list of tuples + flattened_newznabs = check_setting_str(CFG, 'Newznab', 'extra_newznabs', [], log=False) + EXTRA_NEWZNABS = list(itertools.izip(*[itertools.islice(flattened_newznabs, i, None, 3) for i in range(3)])) NZBSORG = bool(check_setting_int(CFG, 'NZBsorg', 'nzbsorg', 0)) NZBSORG_UID = check_setting_str(CFG, 'NZBsorg', 'nzbsorg_uid', '') @@ -618,7 +621,13 @@ def config_write(): new_config['Newznab']['newznab_host'] = NEWZNAB_HOST new_config['Newznab']['newznab_apikey'] = NEWZNAB_APIKEY new_config['Newznab']['newznab_enabled'] = int(NEWZNAB_ENABLED) - new_config['Newznab']['extra_newznabs'] = EXTRA_NEWZNABS + # Need to unpack the extra newznabs for saving in config.ini + flattened_newznabs = [] + for newznab in EXTRA_NEWZNABS: + for item in newznab: + flattened_newznabs.append(item) + + new_config['Newznab']['extra_newznabs'] = flattened_newznabs new_config['NZBsorg'] = {} new_config['NZBsorg']['nzbsorg'] = int(NZBSORG) diff --git 
a/headphones/searcher.py b/headphones/searcher.py index a7808220..85b699e5 100644 --- a/headphones/searcher.py +++ b/headphones/searcher.py @@ -217,14 +217,11 @@ def searchNZB(albumid=None, new=False, losslessOnly=False): if headphones.NEWZNAB: - newznab_hosts = [[headphones.NEWZNAB_HOST, headphones.NEWZNAB_APIKEY, headphones.NEWZNAB_ENABLED]] - - # This is just to make sure we don't have any empty string for EXTRA_NEWZNABS - if not headphones.EXTRA_NEWZNABS: - headphones.EXTRA_NEWZNABS = [] + newznab_hosts = [(headphones.NEWZNAB_HOST, headphones.NEWZNAB_APIKEY, headphones.NEWZNAB_ENABLED)] for newznab_host in headphones.EXTRA_NEWZNABS: - newznab_hosts.append(newznab_host) + if newznab_host[2] == '1' or newznab_host[2] == 1: + newznab_hosts.append(newznab_host) provider = "newznab" if headphones.PREFERRED_QUALITY == 3 or losslessOnly: @@ -239,9 +236,6 @@ def searchNZB(albumid=None, new=False, losslessOnly=False): logger.info("Album type is audiobook/spokenword. Using audiobook category") for newznab_host in newznab_hosts: - - if newznab_host[2] == 0 or newznab_host[2] == '0': - continue params = { "t": "search", "apikey": newznab_host[1], diff --git a/headphones/webserve.py b/headphones/webserve.py index 2d190bf8..2b2e89fd 100644 --- a/headphones/webserve.py +++ b/headphones/webserve.py @@ -572,7 +572,7 @@ class WebInterface(object): except KeyError: newznab_enabled = 0 - headphones.EXTRA_NEWZNABS.append([newznab_host, newznab_api, newznab_enabled]) + headphones.EXTRA_NEWZNABS.append((newznab_host, newznab_api, newznab_enabled)) headphones.config_write() From 780facc67582cd338893da374d742858f9fab507 Mon Sep 17 00:00:00 2001 From: rodikal Date: Fri, 27 Jul 2012 19:14:28 +0300 Subject: [PATCH 10/84] fixed typo in footer --- data/interfaces/brink/base.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/interfaces/brink/base.html b/data/interfaces/brink/base.html index e608ace8..6c38cb14 100644 --- a/data/interfaces/brink/base.html +++ 
b/data/interfaces/brink/base.html @@ -160,7 +160,7 @@ -
IWanted Albums UActive Artists JPost-Process From d245428ca2fa1bcc0f7562d96b1a4c71b45ec139 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Sat, 28 Jul 2012 23:45:08 +0530 Subject: [PATCH 11/84] InRough update of the beets lib to 1.0b15 --- lib/beets/__init__.py | 4 +- lib/beets/autotag/__init__.py | 59 ++- lib/beets/autotag/hooks.py | 78 ++- lib/beets/autotag/match.py | 189 ++++--- lib/beets/autotag/mb.py | 141 +++++- lib/beets/importer.py | 328 +++++++------ lib/beets/library.py | 865 ++++++++++++++++++++++----------- lib/beets/mediafile.py | 546 +++++++++++---------- lib/beets/plugins.py | 30 +- lib/beets/ui/__init__.py | 145 ++++-- lib/beets/ui/commands.py | 534 ++++++++++---------- lib/beets/util/__init__.py | 234 ++++++--- lib/beets/util/bluelet.py | 562 +++++++++++++++++++++ lib/beets/util/enumeration.py | 56 +-- lib/beets/util/functemplate.py | 215 +++++++- lib/beets/util/pipeline.py | 61 +-- lib/beets/vfs.py | 2 +- 17 files changed, 2786 insertions(+), 1263 deletions(-) create mode 100644 lib/beets/util/bluelet.py diff --git a/lib/beets/__init__.py b/lib/beets/__init__.py index c7ef23b6..3ef490cf 100644 --- a/lib/beets/__init__.py +++ b/lib/beets/__init__.py @@ -8,7 +8,7 @@ # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: -# +# # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. @@ -16,7 +16,7 @@ # MODIFIED TO WORK WITH HEADPHONES!! # -__version__ = '1.0b14' +__version__ = '1.0b15' __author__ = 'Adrian Sampson ' from lib.beets.library import Library diff --git a/lib/beets/autotag/__init__.py b/lib/beets/autotag/__init__.py index 2ea52e03..e4e4d1a0 100644 --- a/lib/beets/autotag/__init__.py +++ b/lib/beets/autotag/__init__.py @@ -1,5 +1,5 @@ # This file is part of beets. -# Copyright 2011, Adrian Sampson. +# Copyright 2012, Adrian Sampson. 
# # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the @@ -8,7 +8,7 @@ # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: -# +# # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. @@ -22,7 +22,7 @@ from lib.beets import library, mediafile from lib.beets.util import sorted_walk, ancestry # Parts of external interface. -from .hooks import AlbumInfo, TrackInfo +from .hooks import AlbumInfo, TrackInfo, AlbumMatch, TrackMatch from .match import AutotagError from .match import tag_item, tag_album from .match import RECOMMEND_STRONG, RECOMMEND_MEDIUM, RECOMMEND_NONE @@ -93,7 +93,7 @@ def albums_in_dir(path, ignore=()): collapse_root = root collapse_items = [] continue - + # If it's nonempty, yield it. if items: yield root, items @@ -106,6 +106,8 @@ def apply_item_metadata(item, track_info): """Set an item's metadata from its matched TrackInfo object. """ item.artist = track_info.artist + item.artist_sort = track_info.artist_sort + item.artist_credit = track_info.artist_credit item.title = track_info.title item.mb_trackid = track_info.track_id if track_info.artist_id: @@ -113,11 +115,12 @@ def apply_item_metadata(item, track_info): # At the moment, the other metadata is left intact (including album # and track number). Perhaps these should be emptied? -def apply_metadata(items, album_info): - """Set the items' metadata to match an AlbumInfo object. The list - of items must be ordered. +def apply_metadata(album_info, mapping, per_disc_numbering=False): + """Set the items' metadata to match an AlbumInfo object using a + mapping from Items to TrackInfo objects. If `per_disc_numbering`, + then the track numbers are per-disc instead of per-release. 
""" - for index, (item, track_info) in enumerate(zip(items, album_info.tracks)): + for item, track_info in mapping.iteritems(): # Album, artist, track count. if not item: continue @@ -127,8 +130,15 @@ def apply_metadata(items, album_info): item.artist = album_info.artist item.albumartist = album_info.artist item.album = album_info.album - item.tracktotal = len(items) - + item.tracktotal = len(album_info.tracks) + + # Artist sort and credit names. + item.artist_sort = track_info.artist_sort or album_info.artist_sort + item.artist_credit = track_info.artist_credit or \ + album_info.artist_credit + item.albumartist_sort = album_info.artist_sort + item.albumartist_credit = album_info.artist_credit + # Release date. if album_info.year: item.year = album_info.year @@ -136,15 +146,19 @@ def apply_metadata(items, album_info): item.month = album_info.month if album_info.day: item.day = album_info.day - - # Title and track index. + + # Title. item.title = track_info.title - item.track = index + 1 + + if per_disc_numbering: + item.track = track_info.medium_index + else: + item.track = track_info.index # Disc and disc count. item.disc = track_info.medium item.disctotal = album_info.mediums - + # MusicBrainz IDs. item.mb_trackid = track_info.track_id item.mb_albumid = album_info.album_id @@ -153,12 +167,25 @@ def apply_metadata(items, album_info): else: item.mb_artistid = album_info.artist_id item.mb_albumartistid = album_info.artist_id + item.mb_releasegroupid = album_info.releasegroup_id + + # Compilation flag. + item.comp = album_info.va + + # Miscellaneous metadata. item.albumtype = album_info.albumtype if album_info.label: item.label = album_info.label - # Compilation flag. 
- item.comp = album_info.va + item.asin = album_info.asin + item.catalognum = album_info.catalognum + item.script = album_info.script + item.language = album_info.language + item.country = album_info.country + item.albumstatus = album_info.albumstatus + item.media = album_info.media + item.albumdisambig = album_info.albumdisambig + item.disctitle = track_info.disctitle # Headphones seal of approval item.comments = 'tagged by headphones/beets' diff --git a/lib/beets/autotag/hooks.py b/lib/beets/autotag/hooks.py index b4fa9826..d0042ce3 100644 --- a/lib/beets/autotag/hooks.py +++ b/lib/beets/autotag/hooks.py @@ -1,5 +1,5 @@ # This file is part of beets. -# Copyright 2011, Adrian Sampson. +# Copyright 2012, Adrian Sampson. # # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the @@ -8,15 +8,20 @@ # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: -# +# # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. """Glue between metadata sources and the matching logic.""" +import logging +from collections import namedtuple from lib.beets import plugins from lib.beets.autotag import mb +log = logging.getLogger('beets') + + # Classes used to represent candidate options. 
class AlbumInfo(object): @@ -36,13 +41,26 @@ class AlbumInfo(object): - ``day``: release day - ``label``: music label responsible for the release - ``mediums``: the number of discs in this release + - ``artist_sort``: name of the release's artist for sorting + - ``releasegroup_id``: MBID for the album's release group + - ``catalognum``: the label's catalog number for the release + - ``script``: character set used for metadata + - ``language``: human language of the metadata + - ``country``: the release country + - ``albumstatus``: MusicBrainz release status (Official, etc.) + - ``media``: delivery mechanism (Vinyl, etc.) + - ``albumdisambig``: MusicBrainz release disambiguation comment + - ``artist_credit``: Release-specific artist name The fields up through ``tracks`` are required. The others are optional and may be None. """ def __init__(self, album, album_id, artist, artist_id, tracks, asin=None, albumtype=None, va=False, year=None, month=None, day=None, - label=None, mediums=None): + label=None, mediums=None, artist_sort=None, + releasegroup_id=None, catalognum=None, script=None, + language=None, country=None, albumstatus=None, media=None, + albumdisambig=None, artist_credit=None): self.album = album self.album_id = album_id self.artist = artist @@ -56,6 +74,16 @@ class AlbumInfo(object): self.day = day self.label = label self.mediums = mediums + self.artist_sort = artist_sort + self.releasegroup_id = releasegroup_id + self.catalognum = catalognum + self.script = script + self.language = language + self.country = country + self.albumstatus = albumstatus + self.media = media + self.albumdisambig = albumdisambig + self.artist_credit = artist_credit class TrackInfo(object): """Describes a canonical track present on a release. 
Appears as part @@ -66,32 +94,53 @@ class TrackInfo(object): - ``artist``: individual track artist name - ``artist_id`` - ``length``: float: duration of the track in seconds + - ``index``: position on the entire release - ``medium``: the disc number this track appears on in the album - ``medium_index``: the track's position on the disc + - ``artist_sort``: name of the track artist for sorting + - ``disctitle``: name of the individual medium (subtitle) + - ``artist_credit``: Recording-specific artist name Only ``title`` and ``track_id`` are required. The rest of the fields - may be None. + may be None. The indices ``index``, ``medium``, and ``medium_index`` + are all 1-based. """ def __init__(self, title, track_id, artist=None, artist_id=None, - length=None, medium=None, medium_index=None): + length=None, index=None, medium=None, medium_index=None, + artist_sort=None, disctitle=None, artist_credit=None): self.title = title self.track_id = track_id self.artist = artist self.artist_id = artist_id self.length = length + self.index = index self.medium = medium self.medium_index = medium_index + self.artist_sort = artist_sort + self.disctitle = disctitle + self.artist_credit = artist_credit + +AlbumMatch = namedtuple('AlbumMatch', ['distance', 'info', 'mapping', + 'extra_items', 'extra_tracks']) + +TrackMatch = namedtuple('TrackMatch', ['distance', 'info']) # Aggregation of sources. def _album_for_id(album_id): """Get an album corresponding to a MusicBrainz release ID.""" - return mb.album_for_id(album_id) + try: + return mb.album_for_id(album_id) + except mb.MusicBrainzAPIError as exc: + exc.log(log) def _track_for_id(track_id): """Get an item for a recording MBID.""" - return mb.track_for_id(track_id) + try: + return mb.track_for_id(track_id) + except mb.MusicBrainzAPIError as exc: + exc.log(log) def _album_candidates(items, artist, album, va_likely): """Search for album matches. 
``items`` is a list of Item objects @@ -104,11 +153,17 @@ def _album_candidates(items, artist, album, va_likely): # Base candidates if we have album and artist to match. if artist and album: - out.extend(mb.match_album(artist, album, len(items))) + try: + out.extend(mb.match_album(artist, album, len(items))) + except mb.MusicBrainzAPIError as exc: + exc.log(log) # Also add VA matches from MusicBrainz where appropriate. if va_likely and album: - out.extend(mb.match_album(None, album, len(items))) + try: + out.extend(mb.match_album(None, album, len(items))) + except mb.MusicBrainzAPIError as exc: + exc.log(log) # Candidates from plugins. out.extend(plugins.candidates(items)) @@ -124,7 +179,10 @@ def _item_candidates(item, artist, title): # MusicBrainz candidates. if artist and title: - out.extend(mb.match_track(artist, title)) + try: + out.extend(mb.match_track(artist, title)) + except mb.MusicBrainzAPIError as exc: + exc.log(log) # Plugin candidates. out.extend(plugins.item_candidates(item)) diff --git a/lib/beets/autotag/match.py b/lib/beets/autotag/match.py index ac4d6cd0..1b42da49 100644 --- a/lib/beets/autotag/match.py +++ b/lib/beets/autotag/match.py @@ -1,5 +1,5 @@ # This file is part of beets. -# Copyright 2011, Adrian Sampson. +# Copyright 2012, Adrian Sampson. # # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the @@ -8,13 +8,15 @@ # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: -# +# # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. """Matches existing metadata with canonical information to identify releases and tracks. 
""" +from __future__ import division + import logging import re from lib.munkres import Munkres @@ -33,6 +35,8 @@ ALBUM_WEIGHT = 3.0 TRACK_WEIGHT = 1.0 # The weight of a missing track. MISSING_WEIGHT = 0.9 +# The weight of an extra (umatched) track. +UNMATCHED_WEIGHT = 0.6 # These distances are components of the track distance (that is, they # compete against each other but not ARTIST_WEIGHT and ALBUM_WEIGHT; # the overall TRACK_WEIGHT does that). @@ -112,7 +116,7 @@ def string_dist(str1, str2): """ str1 = str1.lower() str2 = str2.lower() - + # Don't penalize strings that move certain words to the end. For # example, "the something" should be considered equal to # "something, the". @@ -126,7 +130,7 @@ def string_dist(str1, str2): for pat, repl in SD_REPLACE: str1 = re.sub(pat, repl, str1) str2 = re.sub(pat, repl, str2) - + # Change the weight for certain string portions matched by a set # of regular expressions. We gradually change the strings and build # up penalties associated with parts of the string that were @@ -137,7 +141,7 @@ def string_dist(str1, str2): # Get strings that drop the pattern. case_str1 = re.sub(pat, '', str1) case_str2 = re.sub(pat, '', str2) - + if case_str1 != str1 or case_str2 != str2: # If the pattern was present (i.e., it is deleted in the # the current case), recalculate the distances for the @@ -146,7 +150,7 @@ def string_dist(str1, str2): case_delta = max(0.0, base_dist - case_dist) if case_delta == 0.0: continue - + # Shift our baseline strings down (to avoid rematching the # same part of the string) and add a scaled distance # amount to the penalties. 
@@ -155,7 +159,7 @@ def string_dist(str1, str2): base_dist = case_dist penalty += weight * case_delta dist = base_dist + penalty - + return dist def current_metadata(items): @@ -171,42 +175,33 @@ def current_metadata(items): consensus[key] = (freq == len(values)) return likelies['artist'], likelies['album'], consensus['artist'] -def order_items(items, trackinfo): - """Orders the items based on how they match some canonical track - information. Returns a list of Items whose length is equal to the - length of ``trackinfo``. This always produces a result if the - numbers of items is at most the number of TrackInfo objects - (otherwise, returns None). In the case of a partial match, the - returned list may contain None in some positions. +def assign_items(items, tracks): + """Given a list of Items and a list of TrackInfo objects, find the + best mapping between them. Returns a mapping from Items to TrackInfo + objects, a set of extra Items, and a set of extra TrackInfo + objects. These "extra" objects occur when there is an unequal number + of objects of the two types. """ - # Make sure lengths match: If there is less items, it might just be that - # there is some tracks missing. - if len(items) > len(trackinfo): - return None - # Construct the cost matrix. costs = [] - for cur_item in items: + for item in items: row = [] - for i, canon_item in enumerate(trackinfo): - row.append(track_distance(cur_item, canon_item, i+1)) + for i, track in enumerate(tracks): + row.append(track_distance(item, track)) costs.append(row) - + # Find a minimum-cost bipartite matching. matching = Munkres().compute(costs) - # Order items based on the matching. - ordered_items = [None]*len(trackinfo) - for cur_idx, canon_idx in matching: - ordered_items[canon_idx] = items[cur_idx] - return ordered_items + # Produce the output matching. 
+ mapping = dict((items[i], tracks[j]) for (i, j) in matching) + extra_items = set(items) - set(mapping.keys()) + extra_tracks = set(tracks) - set(mapping.values()) + return mapping, extra_items, extra_tracks -def track_distance(item, track_info, track_index=None, incl_artist=False): - """Determines the significance of a track metadata change. Returns - a float in [0.0,1.0]. `track_index` is the track number of the - `track_info` metadata set. If `track_index` is provided and - item.track is set, then these indices are used as a component of - the distance calculation. `incl_artist` indicates that a distance +def track_distance(item, track_info, incl_artist=False): + """Determines the significance of a track metadata change. Returns a + float in [0.0,1.0]. `incl_artist` indicates that a distance component should be included for the track artist (i.e., for various-artist releases). """ @@ -221,7 +216,7 @@ def track_distance(item, track_info, track_index=None, incl_artist=False): diff = min(diff, TRACK_LENGTH_MAX) dist += (diff / TRACK_LENGTH_MAX) * TRACK_LENGTH_WEIGHT dist_max += TRACK_LENGTH_WEIGHT - + # Track title. dist += string_dist(item.title, track_info.title) * TRACK_TITLE_WEIGHT dist_max += TRACK_TITLE_WEIGHT @@ -237,11 +232,11 @@ def track_distance(item, track_info, track_index=None, incl_artist=False): dist_max += TRACK_ARTIST_WEIGHT # Track index. - if track_index and item.track: - if item.track not in (track_index, track_info.medium_index): + if track_info.index and item.track: + if item.track not in (track_info.index, track_info.medium_index): dist += TRACK_INDEX_WEIGHT dist_max += TRACK_INDEX_WEIGHT - + # MusicBrainz track ID. if item.mb_trackid: if item.mb_trackid != track_info.track_id: @@ -255,35 +250,43 @@ def track_distance(item, track_info, track_index=None, incl_artist=False): return dist / dist_max -def distance(items, album_info): +def distance(items, album_info, mapping): """Determines how "significant" an album metadata change would be. 
- Returns a float in [0.0,1.0]. The list of items must be ordered. + Returns a float in [0.0,1.0]. `album_info` is an AlbumInfo object + reflecting the album to be compared. `items` is a sequence of all + Item objects that will be matched (order is not important). + `mapping` is a dictionary mapping Items to TrackInfo objects; the + keys are a subset of `items` and the values are a subset of + `album_info.tracks`. """ cur_artist, cur_album, _ = current_metadata(items) cur_artist = cur_artist or '' cur_album = cur_album or '' - + # These accumulate the possible distance components. The final # distance will be dist/dist_max. dist = 0.0 dist_max = 0.0 - + # Artist/album metadata. if not album_info.va: dist += string_dist(cur_artist, album_info.artist) * ARTIST_WEIGHT dist_max += ARTIST_WEIGHT dist += string_dist(cur_album, album_info.album) * ALBUM_WEIGHT dist_max += ALBUM_WEIGHT - - # Track distances. - for i, (item, track_info) in enumerate(zip(items, album_info.tracks)): - if item: - dist += track_distance(item, track_info, i+1, album_info.va) * \ - TRACK_WEIGHT - dist_max += TRACK_WEIGHT - else: - dist += MISSING_WEIGHT - dist_max += MISSING_WEIGHT + + # Matched track distances. + for item, track in mapping.iteritems(): + dist += track_distance(item, track, album_info.va) * TRACK_WEIGHT + dist_max += TRACK_WEIGHT + + # Extra and unmatched tracks. + for track in set(album_info.tracks) - set(mapping.values()): + dist += MISSING_WEIGHT + dist_max += MISSING_WEIGHT + for item in set(items) - set(mapping.keys()): + dist += UNMATCHED_WEIGHT + dist_max += UNMATCHED_WEIGHT # Plugin distances. plugin_d, plugin_dm = plugins.album_distance(items, album_info) @@ -294,18 +297,19 @@ def distance(items, album_info): if dist_max == 0.0: return 0.0 else: - return dist/dist_max + return dist / dist_max def match_by_id(items): """If the items are tagged with a MusicBrainz album ID, returns an - info dict for the corresponding album. Otherwise, returns None. 
+ AlbumInfo object for the corresponding album. Otherwise, returns + None. """ # Is there a consensus on the MB album ID? albumids = [item.mb_albumid for item in items if item.mb_albumid] if not albumids: log.debug('No album IDs found.') return None - + # If all album IDs are equal, look up the album. if bool(reduce(lambda x,y: x if x==y else (), albumids)): albumid = albumids[0] @@ -314,21 +318,21 @@ def match_by_id(items): else: log.debug('No album ID consensus.') return None - + #fixme In the future, at the expense of performance, we could use # other IDs (i.e., track and artist) in case the album tag isn't # present, but that event seems very unlikely. def recommendation(results): - """Given a sorted list of result tuples, returns a recommendation - flag (RECOMMEND_STRONG, RECOMMEND_MEDIUM, RECOMMEND_NONE) based - on the results' distances. + """Given a sorted list of AlbumMatch or TrackMatch objects, return a + recommendation flag (RECOMMEND_STRONG, RECOMMEND_MEDIUM, + RECOMMEND_NONE) based on the results' distances. """ if not results: # No candidates: no recommendation. rec = RECOMMEND_NONE else: - min_dist = results[0][0] + min_dist = results[0].distance if min_dist < STRONG_REC_THRESH: # Strong recommendation level. rec = RECOMMEND_STRONG @@ -338,7 +342,7 @@ def recommendation(results): elif min_dist <= MEDIUM_REC_THRESH: # Medium recommendation level. rec = RECOMMEND_MEDIUM - elif results[1][0] - min_dist >= REC_GAP_THRESH: + elif results[1].distance - min_dist >= REC_GAP_THRESH: # Gap between first two candidates is large. rec = RECOMMEND_MEDIUM else: @@ -346,36 +350,28 @@ def recommendation(results): rec = RECOMMEND_NONE return rec -def validate_candidate(items, tuple_dict, info): +def _add_candidate(items, results, info): """Given a candidate AlbumInfo object, attempt to add the candidate - to the output dictionary of result tuples. This involves checking - the track count, ordering the items, checking for duplicates, and - calculating the distance. 
+ to the output dictionary of AlbumMatch objects. This involves + checking the track count, ordering the items, checking for + duplicates, and calculating the distance. """ log.debug('Candidate: %s - %s' % (info.artist, info.album)) # Don't duplicate. - if info.album_id in tuple_dict: + if info.album_id in results: log.debug('Duplicate.') return - # Make sure the album has the correct number of tracks. - if len(items) > len(info.tracks): - log.debug('Too many items to match: %i > %i.' % - (len(items), len(info.tracks))) - return - - # Put items in order. - ordered = order_items(items, info.tracks) - if not ordered: - log.debug('Not orderable.') - return + # Find mapping between the items and the track info. + mapping, extra_items, extra_tracks = assign_items(items, info.tracks) # Get the change distance. - dist = distance(ordered, info) + dist = distance(items, info, mapping) log.debug('Success. Distance: %f' % dist) - tuple_dict[info.album_id] = dist, ordered, info + results[info.album_id] = hooks.AlbumMatch(dist, info, mapping, + extra_items, extra_tracks) def tag_album(items, timid=False, search_artist=None, search_album=None, search_id=None): @@ -383,10 +379,8 @@ def tag_album(items, timid=False, search_artist=None, search_album=None, set of items comprised by an album. Returns everything relevant: - The current artist. - The current album. - - A list of (distance, items, info) tuples where info is a - dictionary containing the inferred tags and items is a - reordered version of the input items list. The candidates are - sorted by distance (i.e., best match first). + - A list of AlbumMatch objects. The candidates are sorted by + distance (i.e., best match first). - A recommendation, one of RECOMMEND_STRONG, RECOMMEND_MEDIUM, or RECOMMEND_NONE; indicating that the first candidate is very likely, it is somewhat likely, or no conclusion could @@ -398,11 +392,11 @@ def tag_album(items, timid=False, search_artist=None, search_album=None, # Get current metadata. 
cur_artist, cur_album, artist_consensus = current_metadata(items) log.debug('Tagging %s - %s' % (cur_artist, cur_album)) - + # The output result (distance, AlbumInfo) tuples (keyed by MB album # ID). candidates = {} - + # Try to find album indicated by MusicBrainz IDs. if search_id: log.debug('Searching for album ID: ' + search_id) @@ -410,7 +404,7 @@ def tag_album(items, timid=False, search_artist=None, search_album=None, else: id_info = match_by_id(items) if id_info: - validate_candidate(items, candidates, id_info) + _add_candidate(items, candidates, id_info) rec = recommendation(candidates.values()) log.debug('Album ID match recommendation is ' + str(rec)) if candidates and not timid: @@ -427,13 +421,13 @@ def tag_album(items, timid=False, search_artist=None, search_album=None, return cur_artist, cur_album, candidates.values(), rec else: return cur_artist, cur_album, [], RECOMMEND_NONE - + # Search terms. if not (search_artist and search_album): # No explicit search terms -- use current metadata. search_artist, search_album = cur_artist, cur_album log.debug(u'Search terms: %s - %s' % (search_artist, search_album)) - + # Is this album likely to be a "various artist" release? va_likely = ((not artist_consensus) or (search_artist.lower() in VA_ARTISTS) or @@ -445,8 +439,8 @@ def tag_album(items, timid=False, search_artist=None, search_album=None, va_likely) log.debug(u'Evaluating %i candidates.' % len(search_cands)) for info in search_cands: - validate_candidate(items, candidates, info) - + _add_candidate(items, candidates, info) + # Sort and get the recommendation. candidates = sorted(candidates.itervalues()) rec = recommendation(candidates) @@ -455,10 +449,10 @@ def tag_album(items, timid=False, search_artist=None, search_album=None, def tag_item(item, timid=False, search_artist=None, search_title=None, search_id=None): """Attempts to find metadata for a single track. 
Returns a - `(candidates, recommendation)` pair where `candidates` is a list - of `(distance, track_info)` pairs. `search_artist` and - `search_title` may be used to override the current metadata for - the purposes of the MusicBrainz title; likewise `search_id`. + `(candidates, recommendation)` pair where `candidates` is a list of + TrackMatch objects. `search_artist` and `search_title` may be used + to override the current metadata for the purposes of the MusicBrainz + title; likewise `search_id`. """ # Holds candidates found so far: keys are MBIDs; values are # (distance, TrackInfo) pairs. @@ -471,7 +465,8 @@ def tag_item(item, timid=False, search_artist=None, search_title=None, track_info = hooks._track_for_id(trackid) if track_info: dist = track_distance(item, track_info, incl_artist=True) - candidates[track_info.track_id] = (dist, track_info) + candidates[track_info.track_id] = \ + hooks.TrackMatch(dist, track_info) # If this is a good match, then don't keep searching. rec = recommendation(candidates.values()) if rec == RECOMMEND_STRONG and not timid: @@ -484,7 +479,7 @@ def tag_item(item, timid=False, search_artist=None, search_title=None, return candidates.values(), rec else: return [], RECOMMEND_NONE - + # Search terms. if not (search_artist and search_title): search_artist, search_title = item.artist, item.title @@ -493,7 +488,7 @@ def tag_item(item, timid=False, search_artist=None, search_title=None, # Get and evaluate candidate metadata. for track_info in hooks._item_candidates(item, search_artist, search_title): dist = track_distance(item, track_info, incl_artist=True) - candidates[track_info.track_id] = (dist, track_info) + candidates[track_info.track_id] = hooks.TrackMatch(dist, track_info) # Sort by distance and return with recommendation. log.debug('Found %i candidates.' 
% len(candidates)) diff --git a/lib/beets/autotag/mb.py b/lib/beets/autotag/mb.py index 6d286f57..b5ad589f 100644 --- a/lib/beets/autotag/mb.py +++ b/lib/beets/autotag/mb.py @@ -1,5 +1,5 @@ # This file is part of beets. -# Copyright 2011, Adrian Sampson. +# Copyright 2012, Adrian Sampson. # # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the @@ -8,7 +8,7 @@ # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: -# +# # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. @@ -16,9 +16,11 @@ """ import logging import lib.musicbrainzngs as musicbrainzngs +import traceback import lib.beets.autotag.hooks import lib.beets +from lib.beets import util SEARCH_LIMIT = 5 VARIOUS_ARTISTS_ID = '89ad4ac3-39f7-470e-963a-56509c546377' @@ -26,8 +28,18 @@ VARIOUS_ARTISTS_ID = '89ad4ac3-39f7-470e-963a-56509c546377' musicbrainzngs.set_useragent('beets', lib.beets.__version__, 'http://beets.radbox.org/') -class ServerBusyError(Exception): pass -class BadResponseError(Exception): pass +class MusicBrainzAPIError(util.HumanReadableException): + """An error while talking to MusicBrainz. The `query` field is the + parameter to the action and may have any type. 
+ """ + def __init__(self, reason, verb, query, tb=None): + self.query = query + super(MusicBrainzAPIError, self).__init__(reason, verb, tb) + + def get_message(self): + return u'"{0}" in {1} with query {2}'.format( + self._reasonstr(), self.verb, repr(self.query) + ) log = logging.getLogger('beets') @@ -45,22 +57,64 @@ else: _mb_release_search = musicbrainzngs.search_releases _mb_recording_search = musicbrainzngs.search_recordings -def track_info(recording, medium=None, medium_index=None): +def _flatten_artist_credit(credit): + """Given a list representing an ``artist-credit`` block, flatten the + data into a triple of joined artist name strings: canonical, sort, and + credit. + """ + artist_parts = [] + artist_sort_parts = [] + artist_credit_parts = [] + for el in credit: + if isinstance(el, basestring): + # Join phrase. + artist_parts.append(el) + artist_credit_parts.append(el) + artist_sort_parts.append(el) + + else: + # An artist. + cur_artist_name = el['artist']['name'] + artist_parts.append(cur_artist_name) + + # Artist sort name. + if 'sort-name' in el['artist']: + artist_sort_parts.append(el['artist']['sort-name']) + else: + artist_sort_parts.append(cur_artist_name) + + # Artist credit. + if 'name' in el: + artist_credit_parts.append(el['name']) + else: + artist_credit_parts.append(cur_artist_name) + + return ( + ''.join(artist_parts), + ''.join(artist_sort_parts), + ''.join(artist_credit_parts), + ) + +def track_info(recording, index=None, medium=None, medium_index=None): """Translates a MusicBrainz recording result dictionary into a beets - ``TrackInfo`` object. ``medium_index``, if provided, is the track's - index (1-based) on its medium. + ``TrackInfo`` object. Three parameters are optional and are used + only for tracks that appear on releases (non-singletons): ``index``, + the overall track number; ``medium``, the disc number; + ``medium_index``, the track's index on its medium. Each number is a + 1-based index. 
""" info = lib.beets.autotag.hooks.TrackInfo(recording['title'], recording['id'], + index=index, medium=medium, medium_index=medium_index) - # Get the name of the track artist. - if recording.get('artist-credit-phrase'): - info.artist = recording['artist-credit-phrase'] + if recording.get('artist-credit'): + # Get the artist names. + info.artist, info.artist_sort, info.artist_credit = \ + _flatten_artist_credit(recording['artist-credit']) - # Get the ID of the first artist. - if 'artist-credit' in recording: + # Get the ID and sort name of the first artist. artist = recording['artist-credit'][0]['artist'] info.artist_id = artist['id'] @@ -84,25 +138,25 @@ def album_info(release): AlbumInfo object containing the interesting data about that release. """ # Get artist name using join phrases. - artist_parts = [] - for el in release['artist-credit']: - if isinstance(el, basestring): - artist_parts.append(el) - else: - artist_parts.append(el['artist']['name']) - artist_name = ''.join(artist_parts) + artist_name, artist_sort_name, artist_credit_name = \ + _flatten_artist_credit(release['artist-credit']) # Basic info. track_infos = [] + index = 0 for medium in release['medium-list']: + disctitle = medium.get('title') for track in medium['track-list']: + index += 1 ti = track_info(track['recording'], + index, int(medium['position']), int(track['position'])) if track.get('title'): # Track title may be distinct from underling recording # title. 
ti.title = track['title'] + ti.disctitle = disctitle track_infos.append(ti) info = lib.beets.autotag.hooks.AlbumInfo( release['title'], @@ -111,10 +165,15 @@ def album_info(release): release['artist-credit'][0]['artist']['id'], track_infos, mediums=len(release['medium-list']), + artist_sort=artist_sort_name, + artist_credit=artist_credit_name, ) info.va = info.artist_id == VARIOUS_ARTISTS_ID - if 'asin' in release: - info.asin = release['asin'] + info.asin = release.get('asin') + info.releasegroup_id = release['release-group']['id'] + info.albumdisambig = release['release-group'].get('disambiguation') + info.country = release.get('country') + info.albumstatus = release.get('status') # Release type not always populated. if 'type' in release['release-group']: @@ -137,12 +196,25 @@ def album_info(release): label = label_info['label']['name'] if label != '[no label]': info.label = label + info.catalognum = label_info.get('catalog-number') + + # Text representation data. + if release.get('text-representation'): + rep = release['text-representation'] + info.script = rep.get('script') + info.language = rep.get('language') + + # Media (format). + if release['medium-list']: + first_medium = release['medium-list'][0] + info.media = first_medium.get('format') return info def match_album(artist, album, tracks=None, limit=SEARCH_LIMIT): """Searches for a single album ("release" in MusicBrainz parlance) - and returns an iterator over AlbumInfo objects. + and returns an iterator over AlbumInfo objects. May raise a + MusicBrainzAPIError. The query consists of an artist name, an album name, and, optionally, a number of tracks on the album. 
@@ -161,7 +233,11 @@ def match_album(artist, album, tracks=None, limit=SEARCH_LIMIT): if not any(criteria.itervalues()): return - res = _mb_release_search(limit=limit, **criteria) + try: + res = _mb_release_search(limit=limit, **criteria) + except musicbrainzngs.MusicBrainzError as exc: + raise MusicBrainzAPIError(exc, 'release search', criteria, + traceback.format_exc()) for release in res['release-list']: # The search result is missing some data (namely, the tracks), # so we just use the ID and fetch the rest of the information. @@ -171,7 +247,7 @@ def match_album(artist, album, tracks=None, limit=SEARCH_LIMIT): def match_track(artist, title, limit=SEARCH_LIMIT): """Searches for a single track and returns an iterable of TrackInfo - objects. + objects. May raise a MusicBrainzAPIError. """ criteria = { 'artist': artist.lower(), @@ -181,28 +257,39 @@ def match_track(artist, title, limit=SEARCH_LIMIT): if not any(criteria.itervalues()): return - res = _mb_recording_search(limit=limit, **criteria) + try: + res = _mb_recording_search(limit=limit, **criteria) + except musicbrainzngs.MusicBrainzError as exc: + raise MusicBrainzAPIError(exc, 'recording search', criteria, + traceback.format_exc()) for recording in res['recording-list']: yield track_info(recording) def album_for_id(albumid): """Fetches an album by its MusicBrainz ID and returns an AlbumInfo - object or None if the album is not found. + object or None if the album is not found. May raise a + MusicBrainzAPIError. """ try: res = musicbrainzngs.get_release_by_id(albumid, RELEASE_INCLUDES) except musicbrainzngs.ResponseError: log.debug('Album ID match failed.') return None + except musicbrainzngs.MusicBrainzError as exc: + raise MusicBrainzAPIError(exc, 'get release by ID', albumid, + traceback.format_exc()) return album_info(res['release']) def track_for_id(trackid): """Fetches a track by its MusicBrainz ID. Returns a TrackInfo object - or None if no track is found. + or None if no track is found. 
May raise a MusicBrainzAPIError. """ try: res = musicbrainzngs.get_recording_by_id(trackid, TRACK_INCLUDES) except musicbrainzngs.ResponseError: log.debug('Track ID match failed.') return None + except musicbrainzngs.MusicBrainzError as exc: + raise MusicBrainzAPIError(exc, 'get recording by ID', trackid, + traceback.format_exc()) return track_info(res['recording']) diff --git a/lib/beets/importer.py b/lib/beets/importer.py index 1e7affd8..048077b8 100644 --- a/lib/beets/importer.py +++ b/lib/beets/importer.py @@ -8,14 +8,15 @@ # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: -# +# # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. """Provides the basic, interface-agnostic workflow for importing and autotagging music files. """ -from __future__ import with_statement # Python 2.5 +from __future__ import print_function + import os import logging import pickle @@ -23,7 +24,6 @@ from collections import defaultdict from lib.beets import autotag from lib.beets import library -import lib.beets.autotag.art from lib.beets import plugins from lib.beets import util from lib.beets.util import pipeline @@ -56,7 +56,7 @@ def tag_log(logfile, status, path): reflect the reason the album couldn't be tagged. """ if logfile: - print >>logfile, '%s %s' % (status, path) + print('{0} {1}'.format(status, path), file=logfile) logfile.flush() def log_choice(config, task, duplicate=False): @@ -80,23 +80,6 @@ def log_choice(config, task, duplicate=False): elif task.choice_flag is action.SKIP: tag_log(config.logfile, 'skip', path) -def _reopen_lib(lib): - """Because of limitations in SQLite, a given Library is bound to - the thread in which it was created. This function reopens Library - objects so that they can be used from separate threads. 
- """ - if isinstance(lib, library.Library): - return library.Library( - lib.path, - lib.directory, - lib.path_formats, - lib.art_filename, - lib.timeout, - lib.replacements, - ) - else: - return lib - def _duplicate_check(lib, task): """Check whether an album already exists in the library. Returns a list of Album objects (empty if no duplicates are found). @@ -193,7 +176,7 @@ def _save_state(state): try: with open(STATE_FILE, 'w') as f: pickle.dump(state, f) - except IOError, exc: + except IOError as exc: log.error(u'state file could not be written: %s' % unicode(exc)) @@ -259,11 +242,11 @@ class ImportConfig(object): then never touched again. """ _fields = ['lib', 'paths', 'resume', 'logfile', 'color', 'quiet', - 'quiet_fallback', 'copy', 'write', 'art', 'delete', + 'quiet_fallback', 'copy', 'move', 'write', 'delete', 'choose_match_func', 'should_resume_func', 'threaded', 'autot', 'singletons', 'timid', 'choose_item_func', 'query', 'incremental', 'ignore', - 'resolve_duplicate_func'] + 'resolve_duplicate_func', 'per_disc_numbering'] def __init__(self, **kwargs): for slot in self._fields: setattr(self, slot, kwargs[slot]) @@ -283,6 +266,14 @@ class ImportConfig(object): self.resume = False self.incremental = False + # Copy and move are mutually exclusive. + if self.move: + self.copy = False + + # Only delete when copying. + if not self.copy: + self.delete = False + # The importer task class. @@ -296,6 +287,7 @@ class ImportTask(object): self.items = items self.sentinel = False self.remove_duplicates = False + self.is_album = True @classmethod def done_sentinel(cls, toppath): @@ -324,56 +316,50 @@ class ImportTask(object): obj.is_album = False return obj - def set_match(self, cur_artist, cur_album, candidates, rec): + def set_candidates(self, cur_artist, cur_album, candidates, rec): """Sets the candidates for this album matched by the `autotag.tag_album` method. 
""" + assert self.is_album assert not self.sentinel self.cur_artist = cur_artist self.cur_album = cur_album self.candidates = candidates self.rec = rec - self.is_album = True - def set_null_match(self): + def set_null_candidates(self): """Set the candidates to indicate no album match was found. """ - self.set_match(None, None, None, None) + self.cur_artist = None + self.cur_album = None + self.candidates = None + self.rec = None - def set_item_match(self, candidates, rec): + def set_item_candidates(self, candidates, rec): """Set the match for a single-item task.""" assert not self.is_album assert self.item is not None - self.item_match = (candidates, rec) - - def set_null_item_match(self): - """For single-item tasks, mark the item as having no matches. - """ - assert not self.is_album - assert self.item is not None - self.item_match = None + self.candidates = candidates + self.rec = rec def set_choice(self, choice): - """Given either an (info, items) tuple or an action constant, - indicates that an action has been selected by the user (or - automatically). + """Given an AlbumMatch or TrackMatch object or an action constant, + indicates that an action has been selected for this task. """ assert not self.sentinel # Not part of the task structure: assert choice not in (action.MANUAL, action.MANUAL_ID) - assert choice != action.APPLY # Only used internally. + assert choice != action.APPLY # Only used internally. if choice in (action.SKIP, action.ASIS, action.TRACKS): self.choice_flag = choice - self.info = None + self.match = None else: - assert not isinstance(choice, action) if self.is_album: - info, items = choice - self.items = items # Reordered items list. + assert isinstance(choice, autotag.AlbumMatch) else: - info = choice - self.info = info - self.choice_flag = action.APPLY # Implicit choice. + assert isinstance(choice, autotag.TrackMatch) + self.choice_flag = action.APPLY # Implicit choice. 
+ self.match = choice def save_progress(self): """Updates the progress state to indicate that this album has @@ -393,7 +379,9 @@ class ImportTask(object): if self.sentinel or self.is_album: history_add(self.path) + # Logical decisions. + def should_write_tags(self): """Should new info be written to the files' metadata?""" if self.choice_flag == action.APPLY: @@ -402,16 +390,16 @@ class ImportTask(object): return False else: assert False - def should_fetch_art(self): - """Should album art be downloaded for this album?""" - return self.should_write_tags() and self.is_album + def should_skip(self): """After a choice has been made, returns True if this is a sentinel or it has been marked for skipping. """ return self.sentinel or self.choice_flag == action.SKIP - # Useful data. + + # Convenient data. + def chosen_ident(self): """Returns identifying metadata about the current choice. For albums, this is an (artist, album) pair. For items, this is @@ -424,12 +412,41 @@ class ImportTask(object): if self.choice_flag is action.ASIS: return (self.cur_artist, self.cur_album) elif self.choice_flag is action.APPLY: - return (self.info.artist, self.info.album) + return (self.match.info.artist, self.match.info.album) else: if self.choice_flag is action.ASIS: return (self.item.artist, self.item.title) elif self.choice_flag is action.APPLY: - return (self.info.artist, self.info.title) + return (self.match.info.artist, self.match.info.title) + + def imported_items(self): + """Return a list of Items that should be added to the library. + If this is an album task, return the list of items in the + selected match or everything if the choice is ASIS. If this is a + singleton task, return a list containing the item. + """ + if self.is_album: + if self.choice_flag == action.ASIS: + return list(self.items) + elif self.choice_flag == action.APPLY: + return self.match.mapping.keys() + else: + assert False + else: + return [self.item] + + + # Utilities. 
+ + def prune(self, filename): + """Prune any empty directories above the given file. If this + task has no `toppath` or the file path provided is not within + the `toppath`, then this function has no effect. Similarly, if + the file still exists, no pruning is performed, so it's safe to + call when the file in question may not have been removed. + """ + if self.toppath and not os.path.exists(filename): + util.prune_dirs(os.path.dirname(filename), self.toppath) # Full-album pipeline stages. @@ -464,14 +481,14 @@ def read_tasks(config): if config.incremental: incremental_skipped = 0 history_dirs = history_get() - + for toppath in config.paths: # Check whether the path is to a file. if config.singletons and not os.path.isdir(syspath(toppath)): item = library.Item.from_path(toppath) yield ImportTask.item_task(item) continue - + # Produce paths under this directory. if progress: resume_dir = resume_dirs.get(toppath) @@ -513,16 +530,14 @@ def query_tasks(config): Instead of finding files from the filesystem, a query is used to match items from the library. """ - lib = _reopen_lib(config.lib) - if config.singletons: # Search for items. - for item in lib.items(config.query): + for item in config.lib.items(config.query): yield ImportTask.item_task(item) else: # Search for albums. 
- for album in lib.albums(config.query): + for album in config.lib.albums(config.query): log.debug('yielding album %i: %s - %s' % (album.id, album.albumartist, album.album)) items = list(album.items()) @@ -540,11 +555,13 @@ def initial_lookup(config): if task.sentinel: continue + plugins.send('import_task_start', task=task, config=config) + log.debug('Looking up: %s' % task.path) try: - task.set_match(*autotag.tag_album(task.items, config.timid)) + task.set_candidates(*autotag.tag_album(task.items, config.timid)) except autotag.AutotagError: - task.set_null_match() + task.set_null_candidates() def user_query(config): """A coroutine for interfacing with the user about the tagging @@ -552,18 +569,18 @@ def user_query(config): a file-like object for logging the import process. The coroutine accepts and yields ImportTask objects. """ - lib = _reopen_lib(config.lib) recent = set() task = None while True: task = yield task if task.sentinel: continue - + # Ask the user for a choice. choice = config.choose_match_func(task, config) task.set_choice(choice) log_choice(config, task) + plugins.send('import_task_choice', task=task, config=config) # As-tracks: transition to singleton workflow. if choice is action.TRACKS: @@ -577,7 +594,7 @@ def user_query(config): while True: item_task = yield item_tasks.append(item_task) - ipl = pipeline.Pipeline((emitter(), item_lookup(config), + ipl = pipeline.Pipeline((emitter(), item_lookup(config), item_query(config), collector())) ipl.run_sequential() task = pipeline.multiple(item_tasks) @@ -589,7 +606,7 @@ def user_query(config): # The "recent" set keeps track of identifiers for recently # imported albums -- those that haven't reached the database # yet. 
- if ident in recent or _duplicate_check(lib, task): + if ident in recent or _duplicate_check(config.lib, task): config.resolve_duplicate_func(task, config) log_choice(config, task, True) recent.add(ident) @@ -608,21 +625,20 @@ def show_progress(config): log.info(task.path) # Behave as if ASIS were selected. - task.set_null_match() + task.set_null_candidates() task.set_choice(action.ASIS) - + def apply_choices(config): - """A coroutine for applying changes to albums during the autotag - process. + """A coroutine for applying changes to albums and singletons during + the autotag process. """ - lib = _reopen_lib(config.lib) task = None - while True: + while True: task = yield task if task.should_skip(): continue - items = [i for i in task.items if i] if task.is_album else [task.item] + items = task.imported_items() # Clear IDs in case the items are being re-tagged. for item in items: item.id = None @@ -631,9 +647,13 @@ def apply_choices(config): # Change metadata. if task.should_write_tags(): if task.is_album: - autotag.apply_metadata(task.items, task.info) + autotag.apply_metadata( + task.match.info, task.match.mapping, + per_disc_numbering=config.per_disc_numbering + ) else: - autotag.apply_item_metadata(task.item, task.info) + autotag.apply_item_metadata(task.item, task.match.info) + plugins.send('import_task_apply', config=config, task=task) # Infer album-level fields. if task.is_album: @@ -642,14 +662,14 @@ def apply_choices(config): # Find existing item entries that these are replacing (for # re-imports). Old album structures are automatically cleaned up # when the last item is removed. 
- replaced_items = defaultdict(list) + task.replaced_items = defaultdict(list) for item in items: - dup_items = lib.items(library.MatchQuery('path', item.path)) + dup_items = config.lib.items(library.MatchQuery('path', item.path)) for dup_item in dup_items: - replaced_items[item].append(dup_item) + task.replaced_items[item].append(dup_item) log.debug('replacing item %i: %s' % (dup_item.id, displayable_path(item.path))) - log.debug('%i of %i items replaced' % (len(replaced_items), + log.debug('%i of %i items replaced' % (len(task.replaced_items), len(items))) # Find old items that should be replaced as part of a duplicate @@ -657,93 +677,111 @@ def apply_choices(config): duplicate_items = [] if task.remove_duplicates: if task.is_album: - for album in _duplicate_check(lib, task): + for album in _duplicate_check(config.lib, task): duplicate_items += album.items() else: - duplicate_items = _item_duplicate_check(lib, task) + duplicate_items = _item_duplicate_check(config.lib, task) log.debug('removing %i old duplicated items' % len(duplicate_items)) # Delete duplicate files that are located inside the library # directory. for duplicate_path in [i.path for i in duplicate_items]: - if lib.directory in util.ancestry(duplicate_path): + if config.lib.directory in util.ancestry(duplicate_path): log.debug(u'deleting replaced duplicate %s' % util.displayable_path(duplicate_path)) - util.soft_remove(duplicate_path) + util.remove(duplicate_path) util.prune_dirs(os.path.dirname(duplicate_path), - lib.directory) + config.lib.directory) - # Move/copy files. - task.old_paths = [item.path for item in items] - for item in items: - if config.copy: - # If we're replacing an item, then move rather than - # copying. - old_path = item.path - do_copy = not bool(replaced_items[item]) - lib.move(item, do_copy, task.is_album) - if not do_copy: - # If we moved the item, remove the now-nonexistent - # file from old_paths. 
- task.old_paths.remove(old_path) - if config.write and task.should_write_tags(): - item.write() - - # Add items to library. We consolidate this at the end to avoid - # locking while we do the copying and tag updates. - try: + # Add items -- before path changes -- to the library. We add the + # items now (rather than at the end) so that album structures + # are in place before calls to destination(). + with config.lib.transaction(): # Remove old items. - for replaced in replaced_items.itervalues(): + for replaced in task.replaced_items.itervalues(): for item in replaced: - lib.remove(item) + config.lib.remove(item) for item in duplicate_items: - lib.remove(item) + config.lib.remove(item) # Add new ones. if task.is_album: # Add an album. - album = lib.add_album(items) + album = config.lib.add_album(items) task.album_id = album.id else: # Add tracks. for item in items: - lib.add(item) - finally: - lib.save() + config.lib.add(item) -def fetch_art(config): - """A coroutine that fetches and applies album art for albums where - appropriate. +def plugin_stage(config, func): + """A coroutine (pipeline stage) that calls the given function with + each non-skipped import task. These stages occur between applying + metadata changes and moving/copying/writing files. + """ + task = None + while True: + task = yield task + if task.should_skip(): + continue + func(config, task) + +def manipulate_files(config): + """A coroutine (pipeline stage) that performs necessary file + manipulations *after* items have been added to the library. """ - lib = _reopen_lib(config.lib) task = None while True: task = yield task if task.should_skip(): continue - if task.should_fetch_art(): - artpath = lib.beets.autotag.art.art_for_album(task.info, task.path) + # Move/copy files. + items = task.imported_items() + task.old_paths = [item.path for item in items] # For deletion. + for item in items: + if config.move: + # Just move the file. 
+ old_path = item.path + config.lib.move(item, False) + task.prune(old_path) + elif config.copy: + # If it's a reimport, move in-library files and copy + # out-of-library files. Otherwise, copy and keep track + # of the old path. + old_path = item.path + if task.replaced_items[item]: + # This is a reimport. Move in-library files and copy + # out-of-library files. + if config.lib.directory in util.ancestry(old_path): + config.lib.move(item, False) + # We moved the item, so remove the + # now-nonexistent file from old_paths. + task.old_paths.remove(old_path) + else: + config.lib.move(item, True) + else: + # A normal import. Just copy files and keep track of + # old paths. + config.lib.move(item, True) - # Save the art if any was found. - if artpath: - try: - album = lib.get_album(task.album_id) - album.set_art(artpath) - if config.delete and not util.samefile(artpath, - album.artpath): - # Delete the original file after it's imported. - os.remove(artpath) - finally: - lib.save(False) + if config.write and task.should_write_tags(): + item.write() + + # Save new paths. + with config.lib.transaction(): + for item in items: + config.lib.store(item) + + # Plugin event. + plugins.send('import_task_files', config=config, task=task) def finalize(config): """A coroutine that finishes up importer tasks. In particular, the coroutine sends plugin events, deletes old files, and saves progress. This is a "terminal" coroutine (it yields None). """ - lib = _reopen_lib(config.lib) while True: task = yield if task.should_skip(): @@ -753,15 +791,17 @@ def finalize(config): task.save_history() continue - items = [i for i in task.items if i] if task.is_album else [task.item] + items = task.imported_items() # Announce that we've added an album. 
if task.is_album: - album = lib.get_album(task.album_id) - plugins.send('album_imported', lib=lib, album=album, config=config) + album = config.lib.get_album(task.album_id) + plugins.send('album_imported', + lib=config.lib, album=album, config=config) else: for item in items: - plugins.send('item_imported', lib=lib, item=item, config=config) + plugins.send('item_imported', + lib=config.lib, item=item, config=config) # Finally, delete old files. if config.copy and config.delete: @@ -769,11 +809,8 @@ def finalize(config): for old_path in task.old_paths: # Only delete files that were actually copied. if old_path not in new_paths: - os.remove(syspath(old_path)) - # Clean up directory if it is emptied. - if task.toppath: - util.prune_dirs(os.path.dirname(old_path), - task.toppath) + util.remove(syspath(old_path), False) + task.prune(old_path) # Update progress. if config.resume is not False: @@ -794,13 +831,14 @@ def item_lookup(config): if task.sentinel: continue - task.set_item_match(*autotag.tag_item(task.item, config.timid)) + plugins.send('import_task_start', task=task, config=config) + + task.set_item_candidates(*autotag.tag_item(task.item, config.timid)) def item_query(config): """A coroutine that queries the user for input on single-item lookups. """ - lib = _reopen_lib(config.lib) task = None recent = set() while True: @@ -811,11 +849,12 @@ def item_query(config): choice = config.choose_item_func(task, config) task.set_choice(choice) log_choice(config, task) + plugins.send('import_task_choice', task=task, config=config) # Duplicate check. 
if task.choice_flag in (action.ASIS, action.APPLY): ident = task.chosen_ident() - if ident in recent or _item_duplicate_check(lib, task): + if ident in recent or _item_duplicate_check(config.lib, task): config.resolve_duplicate_func(task, config) log_choice(config, task, True) recent.add(ident) @@ -832,7 +871,7 @@ def item_progress(config): continue log.info(displayable_path(task.item.path)) - task.set_null_item_match() + task.set_null_candidates() task.set_choice(action.ASIS) @@ -843,7 +882,7 @@ def run_import(**kwargs): ImportConfig. """ config = ImportConfig(**kwargs) - + # Set up the pipeline. if config.query is None: stages = [read_tasks(config)] @@ -864,8 +903,9 @@ def run_import(**kwargs): # When not autotagging, just display progress. stages += [show_progress(config)] stages += [apply_choices(config)] - if config.art: - stages += [fetch_art(config)] + for stage_func in plugins.import_stages(): + stages.append(plugin_stage(config, stage_func)) + stages += [manipulate_files(config)] stages += [finalize(config)] pl = pipeline.Pipeline(stages) diff --git a/lib/beets/library.py b/lib/beets/library.py index 97e7c865..31be2460 100644 --- a/lib/beets/library.py +++ b/lib/beets/library.py @@ -1,5 +1,5 @@ # This file is part of beets. -# Copyright 2011, Adrian Sampson. +# Copyright 2012, Adrian Sampson. # # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the @@ -8,17 +8,23 @@ # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: -# +# # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. +"""The core data store and collection logic for beets. 
+""" import sqlite3 import os import re import sys import logging import shlex -#from unidecode import unidecode +import unicodedata +import threading +import contextlib +from collections import defaultdict +# from unidecode import unidecode from lib.beets.mediafile import MediaFile from lib.beets import plugins from lib.beets import util @@ -40,30 +46,47 @@ ITEM_FIELDS = [ ('path', 'blob', False, False), ('album_id', 'int', False, False), - ('title', 'text', True, True), - ('artist', 'text', True, True), - ('album', 'text', True, True), - ('albumartist', 'text', True, True), - ('genre', 'text', True, True), - ('composer', 'text', True, True), - ('grouping', 'text', True, True), - ('year', 'int', True, True), - ('month', 'int', True, True), - ('day', 'int', True, True), - ('track', 'int', True, True), - ('tracktotal', 'int', True, True), - ('disc', 'int', True, True), - ('disctotal', 'int', True, True), - ('lyrics', 'text', True, True), - ('comments', 'text', True, True), - ('bpm', 'int', True, True), - ('comp', 'bool', True, True), - ('mb_trackid', 'text', True, True), - ('mb_albumid', 'text', True, True), - ('mb_artistid', 'text', True, True), - ('mb_albumartistid', 'text', True, True), - ('albumtype', 'text', True, True), - ('label', 'text', True, True), + ('title', 'text', True, True), + ('artist', 'text', True, True), + ('artist_sort', 'text', True, True), + ('artist_credit', 'text', True, True), + ('album', 'text', True, True), + ('albumartist', 'text', True, True), + ('albumartist_sort', 'text', True, True), + ('albumartist_credit', 'text', True, True), + ('genre', 'text', True, True), + ('composer', 'text', True, True), + ('grouping', 'text', True, True), + ('year', 'int', True, True), + ('month', 'int', True, True), + ('day', 'int', True, True), + ('track', 'int', True, True), + ('tracktotal', 'int', True, True), + ('disc', 'int', True, True), + ('disctotal', 'int', True, True), + ('lyrics', 'text', True, True), + ('comments', 'text', True, True), + 
('bpm', 'int', True, True), + ('comp', 'bool', True, True), + ('mb_trackid', 'text', True, True), + ('mb_albumid', 'text', True, True), + ('mb_artistid', 'text', True, True), + ('mb_albumartistid', 'text', True, True), + ('albumtype', 'text', True, True), + ('label', 'text', True, True), + ('acoustid_fingerprint', 'text', True, True), + ('acoustid_id', 'text', True, True), + ('mb_releasegroupid', 'text', True, True), + ('asin', 'text', True, True), + ('catalognum', 'text', True, True), + ('script', 'text', True, True), + ('language', 'text', True, True), + ('country', 'text', True, True), + ('albumstatus', 'text', True, True), + ('media', 'text', True, True), + ('albumdisambig', 'text', True, True), + ('disctitle', 'text', True, True), + ('encoder', 'text', True, True), ('length', 'real', False, True), ('bitrate', 'int', False, True), @@ -84,19 +107,30 @@ ALBUM_FIELDS = [ ('id', 'integer primary key', False), ('artpath', 'blob', False), - ('albumartist', 'text', True), - ('album', 'text', True), - ('genre', 'text', True), - ('year', 'int', True), - ('month', 'int', True), - ('day', 'int', True), - ('tracktotal', 'int', True), - ('disctotal', 'int', True), - ('comp', 'bool', True), - ('mb_albumid', 'text', True), - ('mb_albumartistid', 'text', True), - ('albumtype', 'text', True), - ('label', 'text', True), + ('albumartist', 'text', True), + ('albumartist_sort', 'text', True), + ('albumartist_credit', 'text', True, True), + ('album', 'text', True), + ('genre', 'text', True), + ('year', 'int', True), + ('month', 'int', True), + ('day', 'int', True), + ('tracktotal', 'int', True), + ('disctotal', 'int', True), + ('comp', 'bool', True), + ('mb_albumid', 'text', True), + ('mb_albumartistid', 'text', True), + ('albumtype', 'text', True), + ('label', 'text', True), + ('mb_releasegroupid', 'text', True), + ('asin', 'text', True), + ('catalognum', 'text', True), + ('script', 'text', True), + ('language', 'text', True), + ('country', 'text', True), + ('albumstatus', 'text', 
True), + ('media', 'text', True), + ('albumdisambig', 'text', True), ] ALBUM_KEYS = [f[0] for f in ALBUM_FIELDS] ALBUM_KEYS_ITEM = [f[0] for f in ALBUM_FIELDS if f[2]] @@ -110,11 +144,37 @@ ITEM_DEFAULT_FIELDS = ARTIST_DEFAULT_FIELDS + ALBUM_DEFAULT_FIELDS + \ # Special path format key. PF_KEY_DEFAULT = 'default' + # Logger. log = logging.getLogger('beets') if not log.handlers: log.addHandler(logging.StreamHandler()) +# A little SQL utility. +def _orelse(exp1, exp2): + """Generates an SQLite expression that evaluates to exp1 if exp1 is + non-null and non-empty or exp2 otherwise. + """ + return ('(CASE {0} WHEN NULL THEN {1} ' + 'WHEN "" THEN {1} ' + 'ELSE {0} END)').format(exp1, exp2) + +# An SQLite function for regular expression matching. +def _regexp(expr, val): + """Return a boolean indicating whether the regular expression `expr` + matches `val`. + """ + if val is None or expr is None: + return False + if not isinstance(val, basestring): + val = unicode(val) + try: + res = re.search(expr, val) + except re.error: + # Invalid regular expression. + return False + return res is not None + # Exceptions. @@ -129,7 +189,7 @@ class Item(object): self.dirty = {} self._fill_record(values) self._clear_dirty() - + @classmethod def from_path(cls, path): """Creates a new item from the media file at the specified path. @@ -139,7 +199,7 @@ class Item(object): 'album_id': None, }) i.read(path) - i.mtime = i.current_mtime() # Initial mtime. + i.mtime = i.current_mtime() # Initial mtime. return i def _fill_record(self, values): @@ -175,7 +235,7 @@ class Item(object): sets the record entry for that key to value. Note that to change the attribute in the database or in the file's tags, one must call store() or write(). - + Otherwise, performs an ordinary setattr. """ # Encode unicode paths and read buffers. @@ -187,17 +247,17 @@ class Item(object): if key in ITEM_KEYS: # If the value changed, mark the field as dirty. 
- if (not (key in self.record)) or (self.record[key] != value): + if (key not in self.record) or (self.record[key] != value): self.record[key] = value self.dirty[key] = True if key in ITEM_KEYS_WRITABLE: self.mtime = 0 # Reset mtime on dirty. else: super(Item, self).__setattr__(key, value) - - + + # Interaction with file metadata. - + def read(self, read_path=None): """Read the metadata from the associated file. If read_path is specified, read metadata from that file instead. @@ -215,7 +275,7 @@ class Item(object): # Database's mtime should now reflect the on-disk value. if read_path == self.path: self.mtime = self.current_mtime() - + def write(self): """Writes the item's metadata to the associated file. """ @@ -242,7 +302,7 @@ class Item(object): util.copy(self.path, dest) else: util.move(self.path, dest) - + # Either copying or moving succeeded, so update the stored path. self.path = dest @@ -253,6 +313,57 @@ class Item(object): return int(os.path.getmtime(syspath(self.path))) + # Templating. + + def evaluate_template(self, template, lib=None, sanitize=False, + pathmod=None): + """Evaluates a Template object using the item's fields. If `lib` + is provided, it is used to map some fields to the item's album + (if available) and is made available to template functions. If + `sanitize`, then each value will be sanitized for inclusion in a + file path. + """ + pathmod = pathmod or os.path + + # Get the item's Album if it has one. + album = lib.get_album(self) + + # Build the mapping for substitution in the template, + # beginning with the values from the database. + mapping = {} + for key in ITEM_KEYS_META: + # Get the values from either the item or its album. + if key in ALBUM_KEYS_ITEM and album is not None: + # From album. + value = getattr(album, key) + else: + # From Item. 
+ value = getattr(self, key) + if sanitize: + value = util.sanitize_for_path(value, pathmod, key) + mapping[key] = value + + # Use the album artist if the track artist is not set and + # vice-versa. + if not mapping['artist']: + mapping['artist'] = mapping['albumartist'] + if not mapping['albumartist']: + mapping['albumartist'] = mapping['artist'] + + # Get values from plugins. + for key, value in plugins.template_values(self).iteritems(): + if sanitize: + value = util.sanitize_for_path(value, pathmod, key) + mapping[key] = value + + # Get template functions. + funcs = DefaultTemplateFunctions(self, lib, pathmod).functions() + funcs.update(plugins.template_funcs()) + + # Perform substitution. + return template.substitute(mapping, funcs) + + # Library queries. class Query(object): @@ -279,16 +390,14 @@ class Query(object): clause, subvals = self.clause() return ('SELECT ' + columns + ' FROM items WHERE ' + clause, subvals) - def count(self, library): + def count(self, tx): """Returns `(num, length)` where `num` is the number of items in the library matching this query and `length` is their total length in seconds. 
""" clause, subvals = self.clause() statement = 'SELECT COUNT(id), SUM(length) FROM items WHERE ' + clause - c = library.conn.execute(statement, subvals) - result = c.fetchone() - c.close() + result = tx.query(statement, subvals)[0] return (result[0], result[1] or 0.0) class FieldQuery(Query): @@ -300,7 +409,7 @@ class FieldQuery(Query): raise InvalidFieldError(field + ' is not an item key') self.field = field self.pattern = pattern - + class MatchQuery(FieldQuery): """A query that looks for exact matches in an item field.""" def clause(self): @@ -325,6 +434,21 @@ class SubstringQuery(FieldQuery): value = getattr(item, self.field) or '' return self.pattern.lower() in value.lower() +class RegexpQuery(FieldQuery): + """A query that matches a regular expression in a specific item field.""" + def __init__(self, field, pattern): + super(RegexpQuery, self).__init__(field, pattern) + self.regexp = re.compile(pattern) + + def clause(self): + clause = self.field + " REGEXP ?" + subvals = [self.pattern] + return clause, subvals + + def match(self, item): + value = getattr(item, self.field) or '' + return self.regexp.search(value) is not None + class BooleanQuery(MatchQuery): """Matches a boolean field. Pattern should either be a boolean or a string reflecting a boolean. @@ -355,7 +479,7 @@ class CollectionQuery(Query): """ def __init__(self, subqueries=()): self.subqueries = subqueries - + # is there a better way to do this? def __len__(self): return len(self.subqueries) def __getitem__(self, key): return self.subqueries[key] @@ -374,24 +498,34 @@ class CollectionQuery(Query): subvals += subq_subvals clause = (' ' + joiner + ' ').join(clause_parts) return clause, subvals - - # regular expression for _parse_query_part, below - _pq_regex = re.compile(# non-grouping optional segment for the keyword - r'(?:' - r'(\S+?)' # the keyword - r'(? %s' % (field, oldval, newval)) +# fields: Shows a list of available fields for queries and format strings. 
+fields_cmd = ui.Subcommand('fields', + help='show fields available for queries and format strings') +def fields_func(lib, config, opts, args): + print("Available item fields:") + print(" " + "\n ".join([key for key in library.ITEM_KEYS])) + print("\nAvailable album fields:") + print(" " + "\n ".join([key for key in library.ALBUM_KEYS])) + +fields_cmd.func = fields_func +default_commands.append(fields_cmd) + + # import: Autotagger and importer. DEFAULT_IMPORT_COPY = True +DEFAULT_IMPORT_MOVE = False DEFAULT_IMPORT_WRITE = True DEFAULT_IMPORT_DELETE = False DEFAULT_IMPORT_AUTOT = True DEFAULT_IMPORT_TIMID = False -DEFAULT_IMPORT_ART = True DEFAULT_IMPORT_QUIET = False DEFAULT_IMPORT_QUIET_FALLBACK = 'skip' DEFAULT_IMPORT_RESUME = None # "ask" @@ -101,6 +114,7 @@ DEFAULT_COLOR = True DEFAULT_IGNORE = [ '.*', '*~', ] +DEFAULT_PER_DISC_NUMBERING = False VARIOUS_ARTISTS = u'Various Artists' @@ -122,10 +136,11 @@ def dist_string(dist, color): out = ui.colorize('red', out) return out -def show_change(cur_artist, cur_album, items, info, dist, color=True): - """Print out a representation of the changes that will be made if - tags are changed from (cur_artist, cur_album, items) to info with - distance dist. +def show_change(cur_artist, cur_album, match, color=True, + per_disc_numbering=False): + """Print out a representation of the changes that will be made if an + album's tags are changed according to `match`, which must be an AlbumMatch + object. """ def show_album(artist, album, partial=False): if artist: @@ -148,14 +163,25 @@ def show_change(cur_artist, cur_album, items, info, dist, color=True): out += u' ' + warning print_(out) - # Record if the match is partial or not. - partial_match = None in items + def format_index(track_info): + """Return a string representing the track index of the given + TrackInfo object. 
+ """ + if per_disc_numbering: + if match.info.mediums > 1: + return u'{0}-{1}'.format(track_info.medium, + track_info.medium_index) + else: + return unicode(track_info.medium_index) + else: + return unicode(track_info.index) # Identify the album in question. - if cur_artist != info.artist or \ - (cur_album != info.album and info.album != VARIOUS_ARTISTS): - artist_l, artist_r = cur_artist or '', info.artist - album_l, album_r = cur_album or '', info.album + if cur_artist != match.info.artist or \ + (cur_album != match.info.album and + match.info.album != VARIOUS_ARTISTS): + artist_l, artist_r = cur_artist or '', match.info.artist + album_l, album_r = cur_album or '', match.info.album if artist_r == VARIOUS_ARTISTS: # Hide artists for VA releases. artist_l, artist_r = u'', u'' @@ -169,8 +195,8 @@ def show_change(cur_artist, cur_album, items, info, dist, color=True): print_("To:") show_album(artist_r, album_r) else: - message = u"Tagging: %s - %s" % (info.artist, info.album) - if partial_match: + message = u"Tagging: %s - %s" % (match.info.artist, match.info.album) + if match.extra_items or match.extra_tracks: warning = PARTIAL_MATCH_MESSAGE if color: warning = ui.colorize('yellow', PARTIAL_MATCH_MESSAGE) @@ -178,18 +204,17 @@ def show_change(cur_artist, cur_album, items, info, dist, color=True): print_(message) # Distance/similarity. - print_('(Similarity: %s)' % dist_string(dist, color)) + print_('(Similarity: %s)' % dist_string(match.distance, color)) # Tracks. - missing_tracks = [] - for i, (item, track_info) in enumerate(zip(items, info.tracks)): - if not item: - missing_tracks.append((i, track_info)) - continue - + pairs = match.mapping.items() + pairs.sort(key=lambda (_, track_info): track_info.index) + for item, track_info in pairs: # Get displayable LHS and RHS values. 
cur_track = unicode(item.track) - new_track = unicode(i+1) + new_track = format_index(track_info) + tracks_differ = item.track not in (track_info.index, + track_info.medium_index) cur_title = item.title new_title = track_info.title if item.length and track_info.length: @@ -198,48 +223,55 @@ def show_change(cur_artist, cur_album, items, info, dist, color=True): if color: cur_length = ui.colorize('red', cur_length) new_length = ui.colorize('red', new_length) - + # Possibly colorize changes. if color: cur_title, new_title = ui.colordiff(cur_title, new_title) - if cur_track != new_track: - cur_track = ui.colorize('red', cur_track) - new_track = ui.colorize('red', new_track) + cur_track = ui.colorize('red', cur_track) + new_track = ui.colorize('red', new_track) # Show filename (non-colorized) when title is not set. if not item.title.strip(): cur_title = displayable_path(os.path.basename(item.path)) - + if cur_title != new_title: lhs, rhs = cur_title, new_title - if cur_track != new_track: + if tracks_differ: lhs += u' (%s)' % cur_track rhs += u' (%s)' % new_track print_(u" * %s -> %s" % (lhs, rhs)) else: line = u' * %s' % item.title display = False - if cur_track != new_track: + if tracks_differ: display = True line += u' (%s -> %s)' % (cur_track, new_track) if item.length and track_info.length and \ abs(item.length - track_info.length) > 2.0: display = True - line += u' (%s -> %s)' % (cur_length, new_length) + line += u' (%s vs. %s)' % (cur_length, new_length) if display: print_(line) - for i, track_info in missing_tracks: - line = u' * Missing track: %s (%d)' % (track_info.title, i+1) + + # Missing and unmatched tracks. 
+ for track_info in match.extra_tracks: + line = u' * Missing track: {0} ({1})'.format(track_info.title, + format_index(track_info)) + if color: + line = ui.colorize('yellow', line) + print_(line) + for item in match.extra_items: + line = u' * Unmatched track: {0} ({1})'.format(item.title, item.track) if color: line = ui.colorize('yellow', line) print_(line) -def show_item_change(item, info, dist, color): +def show_item_change(item, match, color): """Print out the change that would occur by tagging `item` with the - metadata from `info`. + metadata from `match`, a TrackMatch object. """ - cur_artist, new_artist = item.artist, info.artist - cur_title, new_title = item.title, info.title + cur_artist, new_artist = item.artist, match.info.artist + cur_title, new_title = item.title, match.info.title if cur_artist != new_artist or cur_title != new_title: if color: @@ -254,7 +286,7 @@ def show_item_change(item, info, dist, color): else: print_("Tagging track: %s - %s" % (cur_artist, cur_title)) - print_('(Similarity: %s)' % dist_string(dist, color)) + print_('(Similarity: %s)' % dist_string(match.distance, color)) def should_resume(config, path): return ui.input_yn("Import of the directory:\n%s" @@ -273,17 +305,17 @@ def _quiet_fall_back(config): return config.quiet_fallback def choose_candidate(candidates, singleton, rec, color, timid, - cur_artist=None, cur_album=None, item=None): + cur_artist=None, cur_album=None, item=None, + itemcount=None, per_disc_numbering=False): """Given a sorted list of candidates, ask the user for a selection - of which candidate to use. Applies to both full albums and - singletons (tracks). For albums, the candidates are `(dist, items, - info)` triples and `cur_artist` and `cur_album` must be provided. - For singletons, the candidates are `(dist, info)` pairs and `item` - must be provided. + of which candidate to use. Applies to both full albums and + singletons (tracks). 
Candidates are either AlbumMatch or TrackMatch + objects depending on `singleton`. for albums, `cur_artist`, + `cur_album`, and `itemcount` must be provided. For singletons, + `item` must be provided. Returns the result of the choice, which may SKIP, ASIS, TRACKS, or - MANUAL or a candidate. For albums, a candidate is a `(info, items)` - pair; for items, it is just a TrackInfo object. + MANUAL or a candidate (an AlbumMatch/TrackMatch object). """ # Sanity check. if singleton: @@ -294,11 +326,15 @@ def choose_candidate(candidates, singleton, rec, color, timid, # Zero candidates. if not candidates: - print_("No match found.") if singleton: + print_("No matching recordings found.") opts = ('Use as-is', 'Skip', 'Enter search', 'enter Id', 'aBort') else: + print_("No matching release found for {0} tracks." + .format(itemcount)) + print_('For help, see: ' + 'https://github.com/sampsyo/beets/wiki/FAQ#wiki-nomatch') opts = ('Use as-is', 'as Tracks', 'Skip', 'Enter search', 'enter Id', 'aBort') sel = ui.input_options(opts, color=color) @@ -321,12 +357,9 @@ def choose_candidate(candidates, singleton, rec, color, timid, # Is the change good enough? bypass_candidates = False if rec != autotag.RECOMMEND_NONE: - if singleton: - dist, info = candidates[0] - else: - dist, items, info = candidates[0] + match = candidates[0] bypass_candidates = True - + while True: # Display and choose from candidates. if not bypass_candidates: @@ -335,22 +368,24 @@ def choose_candidate(candidates, singleton, rec, color, timid, print_('Finding tags for track "%s - %s".' % (item.artist, item.title)) print_('Candidates:') - for i, (dist, info) in enumerate(candidates): - print_('%i. %s - %s (%s)' % (i+1, info.artist, - info.title, dist_string(dist, color))) + for i, match in enumerate(candidates): + print_('%i. %s - %s (%s)' % + (i + 1, match.info.artist, match.info.title, + dist_string(match.distance, color))) else: print_('Finding tags for album "%s - %s".' 
% (cur_artist, cur_album)) print_('Candidates:') - for i, (dist, items, info) in enumerate(candidates): - line = '%i. %s - %s' % (i+1, info.artist, info.album) + for i, match in enumerate(candidates): + line = '%i. %s - %s' % (i + 1, match.info.artist, + match.info.album) # Label and year disambiguation, if available. label, year = None, None - if info.label: - label = info.label - if info.year: - year = unicode(info.year) + if match.info.label: + label = match.info.label + if match.info.year: + year = unicode(match.info.year) if label and year: line += u' [%s, %s]' % (label, year) elif label: @@ -358,17 +393,17 @@ def choose_candidate(candidates, singleton, rec, color, timid, elif year: line += u' [%s]' % year - line += ' (%s)' % dist_string(dist, color) + line += ' (%s)' % dist_string(match.distance, color) # Point out the partial matches. - if None in items: + if match.extra_items or match.extra_tracks: warning = PARTIAL_MATCH_MESSAGE if color: warning = ui.colorize('yellow', warning) line += u' %s' % warning print_(line) - + # Ask the user for a choice. if singleton: opts = ('Skip', 'Use as-is', 'Enter search', 'enter Id', @@ -391,26 +426,24 @@ def choose_candidate(candidates, singleton, rec, color, timid, raise importer.ImportAbort() elif sel == 'i': return importer.action.MANUAL_ID - else: # Numerical selection. + else: # Numerical selection. if singleton: - dist, info = candidates[sel-1] + match = candidates[sel - 1] else: - dist, items, info = candidates[sel-1] + match = candidates[sel - 1] bypass_candidates = False - + # Show what we're about to do. if singleton: - show_item_change(item, info, dist, color) + show_item_change(item, match, color) else: - show_change(cur_artist, cur_album, items, info, dist, color) - + show_change(cur_artist, cur_album, match, color, + per_disc_numbering) + # Exact match => tag automatically if we're not in timid mode. 
if rec == autotag.RECOMMEND_STRONG and not timid: - if singleton: - return info - else: - return info, items - + return match + # Ask for confirmation. if singleton: opts = ('Apply', 'More candidates', 'Skip', 'Use as-is', @@ -420,10 +453,7 @@ def choose_candidate(candidates, singleton, rec, color, timid, 'as Tracks', 'Enter search', 'enter Id', 'aBort') sel = ui.input_options(opts, color=color) if sel == 'a': - if singleton: - return info - else: - return info, items + return match elif sel == 'm': pass elif sel == 's': @@ -444,18 +474,17 @@ def manual_search(singleton): """Input either an artist and album (for full albums) or artist and track name (for singletons) for manual search. """ - artist = raw_input('Artist: ').decode(sys.stdin.encoding) - name = raw_input('Track: ' if singleton else 'Album: ') \ - .decode(sys.stdin.encoding) + artist = input_('Artist:') + name = input_('Track:' if singleton else 'Album:') return artist.strip(), name.strip() def manual_id(singleton): """Input a MusicBrainz ID, either for an album ("release") or a track ("recording"). If no valid ID is entered, returns None. """ - prompt = 'Enter MusicBrainz %s ID: ' % \ + prompt = 'Enter MusicBrainz %s ID:' % \ ('recording' if singleton else 'release') - entry = raw_input(prompt).decode(sys.stdin.encoding).strip() + entry = input_(prompt).strip() # Find the first thing that looks like a UUID/MBID. match = re.search('[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}', entry) @@ -468,7 +497,7 @@ def manual_id(singleton): def choose_match(task, config): """Given an initial autotagging of items, go through an interactive dance with the user to ask for a choice of metadata. Returns an - (info, items) pair, ASIS, or SKIP. + AlbumMatch object, ASIS, or SKIP. """ # Show what we're tagging. print_() @@ -477,10 +506,9 @@ def choose_match(task, config): if config.quiet: # No input; just make a decision. 
if task.rec == autotag.RECOMMEND_STRONG: - dist, items, info = task.candidates[0] - show_change(task.cur_artist, task.cur_album, items, info, dist, - config.color) - return info, items + match = task.candidates[0] + show_change(task.cur_artist, task.cur_album, match, config.color) + return match else: return _quiet_fall_back(config) @@ -488,10 +516,11 @@ def choose_match(task, config): candidates, rec = task.candidates, task.rec while True: # Ask for a choice from the user. - choice = choose_candidate(candidates, False, rec, config.color, + choice = choose_candidate(candidates, False, rec, config.color, config.timid, task.cur_artist, - task.cur_album) - + task.cur_album, itemcount=len(task.items), + per_disc_numbering=config.per_disc_numbering) + # Choose which tags to use. if choice in (importer.action.SKIP, importer.action.ASIS, importer.action.TRACKS): @@ -517,25 +546,25 @@ def choose_match(task, config): except autotag.AutotagError: candidates, rec = None, None else: - # We have a candidate! Finish tagging. Here, choice is - # an (info, items) pair as desired. - assert not isinstance(choice, importer.action) + # We have a candidate! Finish tagging. Here, choice is an + # AlbumMatch object. + assert isinstance(choice, autotag.AlbumMatch) return choice def choose_item(task, config): """Ask the user for a choice about tagging a single item. Returns - either an action constant or a TrackInfo object. + either an action constant or a TrackMatch object. """ print_() print_(task.item.path) - candidates, rec = task.item_match + candidates, rec = task.candidates, task.rec if config.quiet: # Quiet mode; make a decision. 
if rec == autotag.RECOMMEND_STRONG: - dist, track_info = candidates[0] - show_item_change(task.item, track_info, dist, config.color) - return track_info + match = candidates[0] + show_item_change(task.item, match, config.color) + return match else: return _quiet_fall_back(config) @@ -558,10 +587,10 @@ def choose_item(task, config): search_id = manual_id(True) if search_id: candidates, rec = autotag.tag_item(task.item, config.timid, - search_id=search_id) + search_id=search_id) else: # Chose a candidate. - assert not isinstance(choice, importer.action) + assert isinstance(choice, autotag.TrackMatch) return choice def resolve_duplicate(task, config): @@ -595,30 +624,30 @@ def resolve_duplicate(task, config): # The import command. -def import_files(lib, paths, copy, write, autot, logpath, art, threaded, +def import_files(lib, paths, copy, move, write, autot, logpath, threaded, color, delete, quiet, resume, quiet_fallback, singletons, - timid, query, incremental, ignore): + timid, query, incremental, ignore, per_disc_numbering): """Import the files in the given list of paths, tagging each leaf - directory as an album. If copy, then the files are copied into - the library folder. If write, then new metadata is written to the - files themselves. If not autot, then just import the files - without attempting to tag. If logpath is provided, then untaggable - albums will be logged there. If art, then attempt to download - cover art for each album. If threaded, then accelerate autotagging + directory as an album. If copy, then the files are copied into the + library folder. If write, then new metadata is written to the files + themselves. If not autot, then just import the files without + attempting to tag. If logpath is provided, then untaggable albums + will be logged there. If threaded, then accelerate autotagging imports by running them in multiple threads. If color, then ANSI-colorize some terminal output. If delete, then old files are - deleted when they are copied. 
If quiet, then the user is - never prompted for input; instead, the tagger just skips anything - it is not confident about. resume indicates whether interrupted - imports can be resumed and is either a boolean or None. - quiet_fallback should be either ASIS or SKIP and indicates what - should happen in quiet mode when the recommendation is not strong. + deleted when they are copied. If quiet, then the user is never + prompted for input; instead, the tagger just skips anything it is + not confident about. resume indicates whether interrupted imports + can be resumed and is either a boolean or None. quiet_fallback + should be either ASIS or SKIP and indicates what should happen in + quiet mode when the recommendation is not strong. """ # Check the user-specified directories. for path in paths: - if not singletons and not os.path.isdir(syspath(path)): + fullpath = syspath(normpath(path)) + if not singletons and not os.path.isdir(fullpath): raise ui.UserError('not a directory: ' + path) - elif singletons and not os.path.exists(syspath(path)): + elif singletons and not os.path.exists(fullpath): raise ui.UserError('no such file: ' + path) # Check parameter consistency. 
@@ -633,7 +662,7 @@ def import_files(lib, paths, copy, write, autot, logpath, art, threaded, except IOError: raise ui.UserError(u"could not open log file for writing: %s" % displayable_path(logpath)) - print >>logfile, 'import started', time.asctime() + print('import started', time.asctime(), file=logfile) else: logfile = None @@ -652,8 +681,8 @@ def import_files(lib, paths, copy, write, autot, logpath, art, threaded, quiet = quiet, quiet_fallback = quiet_fallback, copy = copy, + move = move, write = write, - art = art, delete = delete, threaded = threaded, autot = autot, @@ -666,12 +695,13 @@ def import_files(lib, paths, copy, write, autot, logpath, art, threaded, incremental = incremental, ignore = ignore, resolve_duplicate_func = resolve_duplicate, + per_disc_numbering = per_disc_numbering, ) - + finally: # If we were logging, close the file. if logfile: - print >>logfile, '' + print('', file=logfile) logfile.close() # Emit event. @@ -696,10 +726,6 @@ import_cmd.parser.add_option('-p', '--resume', action='store_true', default=None, help="resume importing if interrupted") import_cmd.parser.add_option('-P', '--noresume', action='store_false', dest='resume', help="do not try to resume importing") -import_cmd.parser.add_option('-r', '--art', action='store_true', - default=None, help="try to download album art") -import_cmd.parser.add_option('-R', '--noart', action='store_false', - dest='art', help="don't album art (opposite of -r)") import_cmd.parser.add_option('-q', '--quiet', action='store_true', dest='quiet', help="never prompt for input: skip albums instead") import_cmd.parser.add_option('-l', '--log', dest='logpath', @@ -712,19 +738,20 @@ import_cmd.parser.add_option('-L', '--library', dest='library', action='store_true', help='retag items matching a query') import_cmd.parser.add_option('-i', '--incremental', dest='incremental', action='store_true', help='skip already-imported directories') +import_cmd.parser.add_option('-I', '--noincremental', 
dest='incremental', + action='store_false', help='do not skip already-imported directories') def import_func(lib, config, opts, args): copy = opts.copy if opts.copy is not None else \ ui.config_val(config, 'beets', 'import_copy', DEFAULT_IMPORT_COPY, bool) + move = ui.config_val(config, 'beets', 'import_move', + DEFAULT_IMPORT_MOVE, bool) write = opts.write if opts.write is not None else \ ui.config_val(config, 'beets', 'import_write', DEFAULT_IMPORT_WRITE, bool) delete = ui.config_val(config, 'beets', 'import_delete', DEFAULT_IMPORT_DELETE, bool) autot = opts.autotag if opts.autotag is not None else DEFAULT_IMPORT_AUTOT - art = opts.art if opts.art is not None else \ - ui.config_val(config, 'beets', 'import_art', - DEFAULT_IMPORT_ART, bool) threaded = ui.config_val(config, 'beets', 'threaded', DEFAULT_THREADED, bool) color = ui.config_val(config, 'beets', 'color', DEFAULT_COLOR, bool) @@ -741,6 +768,8 @@ def import_func(lib, config, opts, args): ui.config_val(config, 'beets', 'import_incremental', DEFAULT_IMPORT_INCREMENTAL, bool) ignore = ui.config_val(config, 'beets', 'ignore', DEFAULT_IGNORE, list) + per_disc_numbering = ui.config_val(config, 'beets', 'per_disc_numbering', + DEFAULT_PER_DISC_NUMBERING, bool) # Resume has three options: yes, no, and "ask" (None). resume = opts.resume if opts.resume is not None else \ @@ -753,6 +782,11 @@ def import_func(lib, config, opts, args): else: resume = None + # Special case: --copy flag suppresses import_move (which would + # otherwise take precedence). 
+ if opts.copy: + move = False + if quiet_fallback_str == 'asis': quiet_fallback = importer.action.ASIS else: @@ -765,26 +799,23 @@ def import_func(lib, config, opts, args): query = None paths = args - import_files(lib, paths, copy, write, autot, logpath, art, threaded, + import_files(lib, paths, copy, move, write, autot, logpath, threaded, color, delete, quiet, resume, quiet_fallback, singletons, - timid, query, incremental, ignore) + timid, query, incremental, ignore, per_disc_numbering) import_cmd.func = import_func default_commands.append(import_cmd) # list: Query and show library contents. +DEFAULT_LIST_FORMAT_ITEM = '$artist - $album - $title' +DEFAULT_LIST_FORMAT_ALBUM = '$albumartist - $album' + def list_items(lib, query, album, path, fmt): """Print out items in lib matching query. If album, then search for albums instead of single items. If path, print the matched objects' paths instead of human-readable information about them. """ - if fmt is None: - # If no specific template is supplied, use a default. 
- if album: - fmt = u'$albumartist - $album' - else: - fmt = u'$artist - $album - $title' template = Template(fmt) if album: @@ -792,13 +823,13 @@ def list_items(lib, query, album, path, fmt): if path: print_(album.item_dir()) elif fmt is not None: - print_(template.substitute(album._record)) + print_(album.evaluate_template(template)) else: for item in lib.items(query): if path: print_(item.path) elif fmt is not None: - print_(template.substitute(item.record)) + print_(item.evaluate_template(template, lib)) list_cmd = ui.Subcommand('list', help='query the library', aliases=('ls',)) list_cmd.parser.add_option('-a', '--album', action='store_true', @@ -808,7 +839,16 @@ list_cmd.parser.add_option('-p', '--path', action='store_true', list_cmd.parser.add_option('-f', '--format', action='store', help='print with custom format', default=None) def list_func(lib, config, opts, args): - list_items(lib, decargs(args), opts.album, opts.path, opts.format) + fmt = opts.format + if not fmt: + # If no format is specified, fall back to a default. + if opts.album: + fmt = ui.config_val(config, 'beets', 'list_format_album', + DEFAULT_LIST_FORMAT_ALBUM) + else: + fmt = ui.config_val(config, 'beets', 'list_format_item', + DEFAULT_LIST_FORMAT_ITEM) + list_items(lib, decargs(args), opts.album, opts.path, fmt) list_cmd.func = list_func default_commands.append(list_cmd) @@ -819,89 +859,89 @@ def update_items(lib, query, album, move, color, pretend): """For all the items matched by the query, update the library to reflect the item's embedded tags. """ - items, _ = _do_query(lib, query, album) + with lib.transaction(): + items, _ = _do_query(lib, query, album) - # Walk through the items and pick up their changes. - affected_albums = set() - for item in items: - # Item deleted? 
- if not os.path.exists(syspath(item.path)): - print_(u'X %s - %s' % (item.artist, item.title)) - if not pretend: - lib.remove(item, True) - affected_albums.add(item.album_id) - continue - - # Did the item change since last checked? - if item.current_mtime() <= item.mtime: - log.debug(u'skipping %s because mtime is up to date (%i)' % - (displayable_path(item.path), item.mtime)) - continue - - # Read new data. - old_data = dict(item.record) - item.read() - - # Special-case album artist when it matches track artist. (Hacky - # but necessary for preserving album-level metadata for non- - # autotagged imports.) - if not item.albumartist and \ - old_data['albumartist'] == old_data['artist'] == item.artist: - item.albumartist = old_data['albumartist'] - item.dirty['albumartist'] = False - - # Get and save metadata changes. - changes = {} - for key in library.ITEM_KEYS_META: - if item.dirty[key]: - changes[key] = old_data[key], getattr(item, key) - if changes: - # Something changed. - print_(u'* %s - %s' % (item.artist, item.title)) - for key, (oldval, newval) in changes.iteritems(): - _showdiff(key, oldval, newval, color) - - # If we're just pretending, then don't move or save. - if pretend: + # Walk through the items and pick up their changes. + affected_albums = set() + for item in items: + # Item deleted? + if not os.path.exists(syspath(item.path)): + print_(u'X %s - %s' % (item.artist, item.title)) + if not pretend: + lib.remove(item, True) + affected_albums.add(item.album_id) continue - # Move the item if it's in the library. - if move and lib.directory in ancestry(item.path): - lib.move(item) + # Did the item change since last checked? + if item.current_mtime() <= item.mtime: + log.debug(u'skipping %s because mtime is up to date (%i)' % + (displayable_path(item.path), item.mtime)) + continue - lib.store(item) - affected_albums.add(item.album_id) - elif not pretend: - # The file's mtime was different, but there were no changes - # to the metadata. 
Store the new mtime, which is set in the - # call to read(), so we don't check this again in the - # future. - lib.store(item) + # Read new data. + old_data = dict(item.record) + item.read() - # Skip album changes while pretending. - if pretend: - return + # Special-case album artist when it matches track artist. (Hacky + # but necessary for preserving album-level metadata for non- + # autotagged imports.) + if not item.albumartist and \ + old_data['albumartist'] == old_data['artist'] == \ + item.artist: + item.albumartist = old_data['albumartist'] + item.dirty['albumartist'] = False - # Modify affected albums to reflect changes in their items. - for album_id in affected_albums: - if album_id is None: # Singletons. - continue - album = lib.get_album(album_id) - if not album: # Empty albums have already been removed. - log.debug('emptied album %i' % album_id) - continue - al_items = list(album.items()) + # Get and save metadata changes. + changes = {} + for key in library.ITEM_KEYS_META: + if item.dirty[key]: + changes[key] = old_data[key], getattr(item, key) + if changes: + # Something changed. + print_(u'* %s - %s' % (item.artist, item.title)) + for key, (oldval, newval) in changes.iteritems(): + _showdiff(key, oldval, newval, color) - # Update album structure to reflect an item in it. - for key in library.ALBUM_KEYS_ITEM: - setattr(album, key, getattr(al_items[0], key)) + # If we're just pretending, then don't move or save. + if pretend: + continue - # Move album art (and any inconsistent items). - if move and lib.directory in ancestry(al_items[0].path): - log.debug('moving album %i' % album_id) - album.move() + # Move the item if it's in the library. + if move and lib.directory in ancestry(item.path): + lib.move(item) - lib.save() + lib.store(item) + affected_albums.add(item.album_id) + elif not pretend: + # The file's mtime was different, but there were no changes + # to the metadata. 
Store the new mtime, which is set in the + # call to read(), so we don't check this again in the + # future. + lib.store(item) + + # Skip album changes while pretending. + if pretend: + return + + # Modify affected albums to reflect changes in their items. + for album_id in affected_albums: + if album_id is None: # Singletons. + continue + album = lib.get_album(album_id) + if not album: # Empty albums have already been removed. + log.debug('emptied album %i' % album_id) + continue + al_items = list(album.items()) + + # Update album structure to reflect an item in it. + for key in library.ALBUM_KEYS_ITEM: + setattr(album, key, getattr(al_items[0], key)) + + # Move album art (and any inconsistent items). + if move and lib.directory in ancestry(al_items[0].path): + log.debug('moving album %i' % album_id) + album.move() update_cmd = ui.Subcommand('update', help='update the library', aliases=('upd','up',)) @@ -942,14 +982,13 @@ def remove_items(lib, query, album, delete=False): return # Remove (and possibly delete) items. - if album: - for al in albums: - al.remove(delete) - else: - for item in items: - lib.remove(item, delete) - - lib.save() + with lib.transaction(): + if album: + for al in albums: + al.remove(delete) + else: + for item in items: + lib.remove(item, delete) remove_cmd = ui.Subcommand('remove', help='remove matching items from the library', aliases=('rm',)) @@ -1007,16 +1046,18 @@ default_commands.append(stats_cmd) # version: Show current beets version. def show_version(lib, config, opts, args): - print 'beets version %s' % lib.beets.__version__ + + print_('beets version %s' % lib.beets.__version__) + # Show plugins. 
names = [] for plugin in plugins.find_plugins(): modname = plugin.__module__ names.append(modname.split('.')[-1]) if names: - print 'plugins:', ', '.join(names) + print_('plugins:', ', '.join(names)) else: - print 'no plugins loaded' + print_('no plugins loaded') version_cmd = ui.Subcommand('version', help='output version information') version_cmd.func = show_version @@ -1061,23 +1102,23 @@ def modify_items(lib, mods, query, write, move, album, color, confirm): return # Apply changes to database. - for obj in objs: - for field, value in fsets.iteritems(): - setattr(obj, field, value) + with lib.transaction(): + for obj in objs: + for field, value in fsets.iteritems(): + setattr(obj, field, value) - if move: - cur_path = obj.item_dir() if album else obj.path - if lib.directory in ancestry(cur_path): # In library? - log.debug('moving object %s' % cur_path) - if album: - obj.move() - else: - lib.move(obj) + if move: + cur_path = obj.item_dir() if album else obj.path + if lib.directory in ancestry(cur_path): # In library? + log.debug('moving object %s' % cur_path) + if album: + obj.move() + else: + lib.move(obj) - # When modifying items, we have to store them to the database. - if not album: - lib.store(obj) - lib.save() + # When modifying items, we have to store them to the database. + if not album: + lib.store(obj) # Apply tags if requested. if write: @@ -1136,7 +1177,6 @@ def move_items(lib, dest, query, copy, album): else: lib.move(obj, copy, basedir=dest) lib.store(obj) - lib.save() move_cmd = ui.Subcommand('move', help='move or copy items', aliases=('mv',)) diff --git a/lib/beets/util/__init__.py b/lib/beets/util/__init__.py index b0ec38ba..380bfaf8 100644 --- a/lib/beets/util/__init__.py +++ b/lib/beets/util/__init__.py @@ -1,5 +1,5 @@ # This file is part of beets. -# Copyright 2011, Adrian Sampson. +# Copyright 2012, Adrian Sampson. 
# # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the @@ -8,20 +8,102 @@ # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: -# +# # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. """Miscellaneous utility functions.""" +from __future__ import division + import os import sys import re import shutil import fnmatch from collections import defaultdict +import traceback MAX_FILENAME_LENGTH = 200 +class HumanReadableException(Exception): + """An Exception that can include a human-readable error message to + be logged without a traceback. Can preserve a traceback for + debugging purposes as well. + + Has at least two fields: `reason`, the underlying exception or a + string describing the problem; and `verb`, the action being + performed during the error. + + If `tb` is provided, it is a string containing a traceback for the + associated exception. (Note that this is not necessary in Python 3.x + and should be removed when we make the transition.) + """ + error_kind = 'Error' # Human-readable description of error type. + + def __init__(self, reason, verb, tb=None): + self.reason = reason + self.verb = verb + self.tb = tb + super(HumanReadableException, self).__init__(self.get_message()) + + def _gerund(self): + """Generate a (likely) gerund form of the English verb. 
+ """ + if ' ' in self.verb: + return self.verb + gerund = self.verb[:-1] if self.verb.endswith('e') else self.verb + gerund += 'ing' + return gerund + + def _reasonstr(self): + """Get the reason as a string.""" + if isinstance(self.reason, basestring): + return self.reason + elif hasattr(self.reason, 'strerror'): # i.e., EnvironmentError + return self.reason.strerror + else: + return u'"{0}"'.format(self.reason) + + def get_message(self): + """Create the human-readable description of the error, sans + introduction. + """ + raise NotImplementedError + + def log(self, logger): + """Log to the provided `logger` a human-readable message as an + error and a verbose traceback as a debug message. + """ + if self.tb: + logger.debug(self.tb) + logger.error(u'{0}: {1}'.format(self.error_kind, self.args[0])) + +class FilesystemError(HumanReadableException): + """An error that occurred while performing a filesystem manipulation + via a function in this module. The `paths` field is a sequence of + pathnames involved in the operation. + """ + def __init__(self, reason, verb, paths, tb=None): + self.paths = paths + super(FilesystemError, self).__init__(reason, verb, tb) + + def get_message(self): + # Use a nicer English phrasing for some specific verbs. + if self.verb in ('move', 'copy', 'rename'): + clause = 'while {0} {1} to {2}'.format( + self._gerund(), repr(self.paths[0]), repr(self.paths[1]) + ) + elif self.verb in ('delete',): + clause = 'while {0} {1}'.format( + self._gerund(), repr(self.paths[0]) + ) + else: + clause = 'during {0} of paths {1}'.format( + self.verb, u', '.join(repr(p) for p in self.paths) + ) + + return u'{0} {1}'.format(self._reasonstr(), clause) + def normpath(path): """Provide the canonical form of the path suitable for storing in the database. 
@@ -39,11 +121,11 @@ def ancestry(path, pathmod=None): last_path = None while path: path = pathmod.dirname(path) - + if path == last_path: break last_path = path - + if path: # don't yield '' out.insert(0, path) return out @@ -59,7 +141,9 @@ def sorted_walk(path, ignore=()): # Get all the directories and files at this level. dirs = [] files = [] - for base in os.listdir(path): + for base in os.listdir(syspath(path)): + base = bytestring_path(base) + # Skip ignored filenames. skip = False for pat in ignore: @@ -84,7 +168,7 @@ def sorted_walk(path, ignore=()): # Recurse into directories. for base in dirs: cur = os.path.join(path, base) - # yield from _sorted_walk(cur) + # yield from sorted_walk(...) for res in sorted_walk(cur, ignore): yield res @@ -149,13 +233,13 @@ def components(path, pathmod=None): comp = pathmod.basename(anc) if comp: comps.append(comp) - else: # root + else: # root comps.append(anc) - + last = pathmod.basename(path) if last: comps.append(last) - + return comps def bytestring_path(path): @@ -168,6 +252,13 @@ def bytestring_path(path): # Try to encode with default encodings, but fall back to UTF8. encoding = sys.getfilesystemencoding() or sys.getdefaultencoding() + if encoding == 'mbcs': + # On Windows, a broken encoding known to Python as "MBCS" is + # used for the filesystem. However, we only use the Unicode API + # for Windows paths, so the encoding is actually immaterial so + # we can avoid dealing with this nastiness. We arbitrarily + # choose UTF-8. + encoding = 'utf8' try: return path.encode(encoding) except (UnicodeError, LookupError): @@ -202,12 +293,16 @@ def syspath(path, pathmod=None): return path if not isinstance(path, unicode): - # Try to decode with default encodings, but fall back to UTF8. - encoding = sys.getfilesystemencoding() or sys.getdefaultencoding() + # Beets currently represents Windows paths internally with UTF-8 + # arbitrarily. But earlier versions used MBCS because it is + # reported as the FS encoding by Windows. 
Try both. try: - path = path.decode(encoding, 'replace') + path = path.decode('utf8') except UnicodeError: - path = path.decode('utf8', 'replace') + # The encoding should always be MBCS, Windows' broken + # Unicode representation. + encoding = sys.getfilesystemencoding() or sys.getdefaultencoding() + path = path.decode(encoding, 'replace') # Add the magic prefix if it isn't already there if not path.startswith(u'\\\\?\\'): @@ -219,42 +314,63 @@ def samefile(p1, p2): """Safer equality for paths.""" return shutil._samefile(syspath(p1), syspath(p2)) -def soft_remove(path): - """Remove the file if it exists.""" +def remove(path, soft=True): + """Remove the file. If `soft`, then no error will be raised if the + file does not exist. + """ path = syspath(path) - if os.path.exists(path): + if soft and not os.path.exists(path): + return + try: os.remove(path) + except (OSError, IOError) as exc: + raise FilesystemError(exc, 'delete', (path,), traceback.format_exc()) -def _assert_not_exists(path, pathmod=None): - """Raises an OSError if the path exists.""" - pathmod = pathmod or os.path - if pathmod.exists(path): - raise OSError('file exists: %s' % path) - -def copy(path, dest, replace=False, pathmod=None): - """Copy a plain file. Permissions are not copied. If dest already - exists, raises an OSError unless replace is True. Has no effect if - path is the same as dest. Paths are translated to system paths - before the syscall. +def copy(path, dest, replace=False, pathmod=os.path): + """Copy a plain file. Permissions are not copied. If `dest` already + exists, raises a FilesystemError unless `replace` is True. Has no + effect if `path` is the same as `dest`. Paths are translated to + system paths before the syscall. 
""" if samefile(path, dest): return path = syspath(path) dest = syspath(dest) - _assert_not_exists(dest, pathmod) - return shutil.copyfile(path, dest) + if not replace and pathmod.exists(dest): + raise FilesystemError('file exists', 'copy', (path, dest)) + try: + shutil.copyfile(path, dest) + except (OSError, IOError) as exc: + raise FilesystemError(exc, 'copy', (path, dest), + traceback.format_exc()) -def move(path, dest, replace=False, pathmod=None): - """Rename a file. dest may not be a directory. If dest already - exists, raises an OSError unless replace is True. Hos no effect if - path is the same as dest. Paths are translated to system paths. +def move(path, dest, replace=False, pathmod=os.path): + """Rename a file. `dest` may not be a directory. If `dest` already + exists, raises an OSError unless `replace` is True. Has no effect if + `path` is the same as `dest`. If the paths are on different + filesystems (or the rename otherwise fails), a copy is attempted + instead, in which case metadata will *not* be preserved. Paths are + translated to system paths. """ if samefile(path, dest): return path = syspath(path) dest = syspath(dest) - _assert_not_exists(dest, pathmod) - return shutil.move(path, dest) + if pathmod.exists(dest): + raise FilesystemError('file exists', 'rename', (path, dest), + traceback.format_exc()) + + # First, try renaming the file. + try: + os.rename(path, dest) + except OSError: + # Otherwise, copy and delete the original. + try: + shutil.copyfile(path, dest) + os.remove(path) + except (OSError, IOError) as exc: + raise FilesystemError(exc, 'move', (path, dest), + traceback.format_exc()) def unique_path(path): """Returns a version of ``path`` that does not exist on the @@ -277,33 +393,33 @@ def unique_path(path): if not os.path.exists(new_path): return new_path -# Note: POSIX actually supports \ and : -- I just think they're -# a pain. And ? has caused problems for some. 
+# Note: The Windows "reserved characters" are, of course, allowed on +# Unix. They are forbidden here because they cause problems on Samba +# shares, which are sufficiently common as to cause frequent problems. +# http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx CHAR_REPLACE = [ - (re.compile(r'[\\/\?"]|^\.'), '_'), - (re.compile(r':'), '-'), -] -CHAR_REPLACE_WINDOWS = [ - (re.compile(r'["\*<>\|]|^\.|\.$|\s+$'), '_'), + (re.compile(ur'[\\/]'), u'_'), # / and \ -- forbidden everywhere. + (re.compile(ur'^\.'), u'_'), # Leading dot (hidden files on Unix). + (re.compile(ur'[\x00-\x1f]'), u''), # Control characters. + (re.compile(ur'[<>:"\?\*\|]'), u'_'), # Windows "reserved characters". + (re.compile(ur'\.$'), u'_'), # Trailing dots. + (re.compile(ur'\s+$'), u''), # Trailing whitespace. ] def sanitize_path(path, pathmod=None, replacements=None): - """Takes a path and makes sure that it is legal. Returns a new path. - Only works with fragments; won't work reliably on Windows when a - path begins with a drive letter. Path separators (including altsep!) - should already be cleaned from the path components. If replacements - is specified, it is used *instead* of the default set of - replacements for the platform; it must be a list of (compiled regex, - replacement string) pairs. + """Takes a path (as a Unicode string) and makes sure that it is + legal. Returns a new path. Only works with fragments; won't work + reliably on Windows when a path begins with a drive letter. Path + separators (including altsep!) should already be cleaned from the + path components. If replacements is specified, it is used *instead* + of the default set of replacements for the platform; it must be a + list of (compiled regex, replacement string) pairs. """ pathmod = pathmod or os.path - windows = pathmod.__name__ == 'ntpath' # Choose the appropriate replacements. 
if not replacements: replacements = list(CHAR_REPLACE) - if windows: - replacements += CHAR_REPLACE_WINDOWS - + comps = components(path, pathmod) if not comps: return '' @@ -311,10 +427,10 @@ def sanitize_path(path, pathmod=None, replacements=None): # Replace special characters. for regex, repl in replacements: comp = regex.sub(repl, comp) - + # Truncate each component. comp = comp[:MAX_FILENAME_LENGTH] - + comps[i] = comp return pathmod.join(*comps) @@ -336,10 +452,10 @@ def sanitize_for_path(value, pathmod, key=None): value = u'%02i' % (value or 0) elif key == 'bitrate': # Bitrate gets formatted as kbps. - value = u'%ikbps' % ((value or 0) / 1000) + value = u'%ikbps' % ((value or 0) // 1000) elif key == 'samplerate': # Sample rate formatted as kHz. - value = u'%ikHz' % ((value or 0) / 1000) + value = u'%ikHz' % ((value or 0) // 1000) else: value = unicode(value) return value @@ -360,7 +476,7 @@ def levenshtein(s1, s2): return levenshtein(s2, s1) if not s1: return len(s2) - + previous_row = xrange(len(s2) + 1) for i, c1 in enumerate(s1): current_row = [i + 1] @@ -370,7 +486,7 @@ def levenshtein(s1, s2): substitutions = previous_row[j] + (c1 != c2) current_row.append(min(insertions, deletions, substitutions)) previous_row = current_row - + return previous_row[-1] def plurality(objs): diff --git a/lib/beets/util/bluelet.py b/lib/beets/util/bluelet.py new file mode 100644 index 00000000..aee63116 --- /dev/null +++ b/lib/beets/util/bluelet.py @@ -0,0 +1,562 @@ +"""Extremely simple pure-Python implementation of coroutine-style +asynchronous socket I/O. Inspired by, but inferior to, Eventlet. +Bluelet can also be thought of as a less-terrible replacement for +asyncore. + +Bluelet: easy concurrency without all the messy parallelism. +""" +import socket +import select +import sys +import types +import errno +import traceback +import time +import collections + + +# A little bit of "six" (Python 2/3 compatibility): cope with PEP 3109 syntax +# changes. 
+ +PY3 = sys.version_info[0] == 3 +if PY3: + def _reraise(typ, exc, tb): + raise exc.with_traceback(tb) +else: + exec(""" +def _reraise(typ, exc, tb): + raise typ, exc, tb""") + + +# Basic events used for thread scheduling. + +class Event(object): + """Just a base class identifying Bluelet events. An event is an + object yielded from a Bluelet thread coroutine to suspend operation + and communicate with the scheduler. + """ + pass + +class WaitableEvent(Event): + """A waitable event is one encapsulating an action that can be + waited for using a select() call. That is, it's an event with an + associated file descriptor. + """ + def waitables(self): + """Return "waitable" objects to pass to select(). Should return + three iterables for input readiness, output readiness, and + exceptional conditions (i.e., the three lists passed to + select()). + """ + return (), (), () + + def fire(self): + """Called when an assoicated file descriptor becomes ready + (i.e., is returned from a select() call). + """ + pass + +class ValueEvent(Event): + """An event that does nothing but return a fixed value.""" + def __init__(self, value): + self.value = value + +class ExceptionEvent(Event): + """Raise an exception at the yield point. Used internally.""" + def __init__(self, exc_info): + self.exc_info = exc_info + +class SpawnEvent(Event): + """Add a new coroutine thread to the scheduler.""" + def __init__(self, coro): + self.spawned = coro + +class JoinEvent(Event): + """Suspend the thread until the specified child thread has + completed. + """ + def __init__(self, child): + self.child = child + +class DelegationEvent(Event): + """Suspend execution of the current thread, start a new thread and, + once the child thread finished, return control to the parent + thread. + """ + def __init__(self, coro): + self.spawned = coro + +class ReturnEvent(Event): + """Return a value the current thread's delegator at the point of + delegation. Ends the current (delegate) thread. 
+ """ + def __init__(self, value): + self.value = value + +class SleepEvent(WaitableEvent): + """Suspend the thread for a given duration. + """ + def __init__(self, duration): + self.wakeup_time = time.time() + duration + + def time_left(self): + return max(self.wakeup_time - time.time(), 0.0) + +class ReadEvent(WaitableEvent): + """Reads from a file-like object.""" + def __init__(self, fd, bufsize): + self.fd = fd + self.bufsize = bufsize + + def waitables(self): + return (self.fd,), (), () + + def fire(self): + return self.fd.read(self.bufsize) + +class WriteEvent(WaitableEvent): + """Writes to a file-like object.""" + def __init__(self, fd, data): + self.fd = fd + self.data = data + + def waitable(self): + return (), (self.fd,), () + + def fire(self): + self.fd.write(self.data) + + +# Core logic for executing and scheduling threads. + +def _event_select(events): + """Perform a select() over all the Events provided, returning the + ones ready to be fired. Only WaitableEvents (including SleepEvents) + matter here; all other events are ignored (and thus postponed). + """ + # Gather waitables and wakeup times. + waitable_to_event = {} + rlist, wlist, xlist = [], [], [] + earliest_wakeup = None + for event in events: + if isinstance(event, SleepEvent): + if not earliest_wakeup: + earliest_wakeup = event.wakeup_time + else: + earliest_wakeup = min(earliest_wakeup, event.wakeup_time) + elif isinstance(event, WaitableEvent): + r, w, x = event.waitables() + rlist += r + wlist += w + xlist += x + for waitable in r: + waitable_to_event[('r', waitable)] = event + for waitable in w: + waitable_to_event[('w', waitable)] = event + for waitable in x: + waitable_to_event[('x', waitable)] = event + + # If we have a any sleeping threads, determine how long to sleep. + if earliest_wakeup: + timeout = max(earliest_wakeup - time.time(), 0.0) + else: + timeout = None + + # Perform select() if we have any waitables. 
+ if rlist or wlist or xlist: + rready, wready, xready = select.select(rlist, wlist, xlist, timeout) + else: + rready, wready, xready = (), (), () + if timeout: + time.sleep(timeout) + + # Gather ready events corresponding to the ready waitables. + ready_events = set() + for ready in rready: + ready_events.add(waitable_to_event[('r', ready)]) + for ready in wready: + ready_events.add(waitable_to_event[('w', ready)]) + for ready in xready: + ready_events.add(waitable_to_event[('x', ready)]) + + # Gather any finished sleeps. + for event in events: + if isinstance(event, SleepEvent) and event.time_left() == 0.0: + ready_events.add(event) + + return ready_events + +class ThreadException(Exception): + def __init__(self, coro, exc_info): + self.coro = coro + self.exc_info = exc_info + def reraise(self): + _reraise(self.exc_info[0], self.exc_info[1], self.exc_info[2]) + +SUSPENDED = Event() # Special sentinel placeholder for suspended threads. + +def run(root_coro): + """Schedules a coroutine, running it to completion. This + encapsulates the Bluelet scheduler, which the root coroutine can + add to by spawning new coroutines. + """ + # The "threads" dictionary keeps track of all the currently- + # executing and suspended coroutines. It maps coroutines to their + # currently "blocking" event. The event value may be SUSPENDED if + # the coroutine is waiting on some other condition: namely, a + # delegated coroutine or a joined coroutine. In this case, the + # coroutine should *also* appear as a value in one of the below + # dictionaries `delegators` or `joiners`. + threads = {root_coro: ValueEvent(None)} + + # Maps child coroutines to delegating parents. + delegators = {} + + # Maps child coroutines to joining (exit-waiting) parents. + joiners = collections.defaultdict(list) + + def complete_thread(coro, return_value): + """Remove a coroutine from the scheduling pool, awaking + delegators and joiners as necessary and returning the specified + value to any delegating parent. 
+ """ + del threads[coro] + + # Resume delegator. + if coro in delegators: + threads[delegators[coro]] = ValueEvent(return_value) + del delegators[coro] + + # Resume joiners. + if coro in joiners: + for parent in joiners[coro]: + threads[parent] = ValueEvent(None) + del joiners[coro] + + def advance_thread(coro, value, is_exc=False): + """After an event is fired, run a given coroutine associated with + it in the threads dict until it yields again. If the coroutine + exits, then the thread is removed from the pool. If the coroutine + raises an exception, it is reraised in a ThreadException. If + is_exc is True, then the value must be an exc_info tuple and the + exception is thrown into the coroutine. + """ + try: + if is_exc: + next_event = coro.throw(*value) + else: + next_event = coro.send(value) + except StopIteration: + # Thread is done. + complete_thread(coro, None) + except: + # Thread raised some other exception. + del threads[coro] + raise ThreadException(coro, sys.exc_info()) + else: + if isinstance(next_event, types.GeneratorType): + # Automatically invoke sub-coroutines. (Shorthand for + # explicit bluelet.call().) + next_event = DelegationEvent(next_event) + threads[coro] = next_event + + # Continue advancing threads until root thread exits. + exit_te = None + while threads: + try: + # Look for events that can be run immediately. Continue + # running immediate events until nothing is ready. + while True: + have_ready = False + for coro, event in list(threads.items()): + if isinstance(event, SpawnEvent): + threads[event.spawned] = ValueEvent(None) # Spawn. + advance_thread(coro, None) + have_ready = True + elif isinstance(event, ValueEvent): + advance_thread(coro, event.value) + have_ready = True + elif isinstance(event, ExceptionEvent): + advance_thread(coro, event.exc_info, True) + have_ready = True + elif isinstance(event, DelegationEvent): + threads[coro] = SUSPENDED # Suspend. + threads[event.spawned] = ValueEvent(None) # Spawn. 
+ delegators[event.spawned] = coro + have_ready = True + elif isinstance(event, ReturnEvent): + # Thread is done. + complete_thread(coro, event.value) + have_ready = True + elif isinstance(event, JoinEvent): + threads[coro] = SUSPENDED # Suspend. + joiners[event.child].append(coro) + have_ready = True + + # Only start the select when nothing else is ready. + if not have_ready: + break + + # Wait and fire. + event2coro = dict((v,k) for k,v in threads.items()) + for event in _event_select(threads.values()): + # Run the IO operation, but catch socket errors. + try: + value = event.fire() + except socket.error as exc: + if isinstance(exc.args, tuple) and \ + exc.args[0] == errno.EPIPE: + # Broken pipe. Remote host disconnected. + pass + else: + traceback.print_exc() + # Abort the coroutine. + threads[event2coro[event]] = ReturnEvent(None) + else: + advance_thread(event2coro[event], value) + + except ThreadException as te: + # Exception raised from inside a thread. + event = ExceptionEvent(te.exc_info) + if te.coro in delegators: + # The thread is a delegate. Raise exception in its + # delegator. + threads[delegators[te.coro]] = event + del delegators[te.coro] + else: + # The thread is root-level. Raise in client code. + exit_te = te + break + + except: + # For instance, KeyboardInterrupt during select(). Raise + # into root thread and terminate others. + threads = {root_coro: ExceptionEvent(sys.exc_info())} + + # If any threads still remain, kill them. + for coro in threads: + coro.close() + + # If we're exiting with an exception, raise it in the client. + if exit_te: + exit_te.reraise() + + +# Sockets and their associated events. + +class Listener(object): + """A socket wrapper object for listening sockets. + """ + def __init__(self, host, port): + """Create a listening socket on the given hostname and port. 
+ """ + self.host = host + self.port = port + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.sock.bind((host, port)) + self.sock.listen(5) + + def accept(self): + """An event that waits for a connection on the listening socket. + When a connection is made, the event returns a Connection + object. + """ + return AcceptEvent(self) + + def close(self): + """Immediately close the listening socket. (Not an event.) + """ + self.sock.close() + +class Connection(object): + """A socket wrapper object for connected sockets. + """ + def __init__(self, sock, addr): + self.sock = sock + self.addr = addr + self._buf = b'' + + def close(self): + """Close the connection.""" + self.sock.close() + + def recv(self, size): + """Read at most size bytes of data from the socket.""" + if self._buf: + # We already have data read previously. + out = self._buf[:size] + self._buf = self._buf[size:] + return ValueEvent(out) + else: + return ReceiveEvent(self, size) + + def send(self, data): + """Sends data on the socket, returning the number of bytes + successfully sent. + """ + return SendEvent(self, data) + + def sendall(self, data): + """Send all of data on the socket.""" + return SendEvent(self, data, True) + + def readline(self, terminator=b"\n", bufsize=1024): + """Reads a line (delimited by terminator) from the socket.""" + while True: + if terminator in self._buf: + line, self._buf = self._buf.split(terminator, 1) + line += terminator + yield ReturnEvent(line) + break + data = yield ReceiveEvent(self, bufsize) + if data: + self._buf += data + else: + line = self._buf + self._buf = b'' + yield ReturnEvent(line) + break + +class AcceptEvent(WaitableEvent): + """An event for Listener objects (listening sockets) that suspends + execution until the socket gets a connection. 
+ """ + def __init__(self, listener): + self.listener = listener + + def waitables(self): + return (self.listener.sock,), (), () + + def fire(self): + sock, addr = self.listener.sock.accept() + return Connection(sock, addr) + +class ReceiveEvent(WaitableEvent): + """An event for Connection objects (connected sockets) for + asynchronously reading data. + """ + def __init__(self, conn, bufsize): + self.conn = conn + self.bufsize = bufsize + + def waitables(self): + return (self.conn.sock,), (), () + + def fire(self): + return self.conn.sock.recv(self.bufsize) + +class SendEvent(WaitableEvent): + """An event for Connection objects (connected sockets) for + asynchronously writing data. + """ + def __init__(self, conn, data, sendall=False): + self.conn = conn + self.data = data + self.sendall = sendall + + def waitables(self): + return (), (self.conn.sock,), () + + def fire(self): + if self.sendall: + return self.conn.sock.sendall(self.data) + else: + return self.conn.sock.send(self.data) + + +# Public interface for threads; each returns an event object that +# can immediately be "yield"ed. + +def null(): + """Event: yield to the scheduler without doing anything special. + """ + return ValueEvent(None) + +def spawn(coro): + """Event: add another coroutine to the scheduler. Both the parent + and child coroutines run concurrently. + """ + if not isinstance(coro, types.GeneratorType): + raise ValueError('%s is not a coroutine' % str(coro)) + return SpawnEvent(coro) + +def call(coro): + """Event: delegate to another coroutine. The current coroutine + is resumed once the sub-coroutine finishes. If the sub-coroutine + returns a value using end(), then this event returns that value. + """ + if not isinstance(coro, types.GeneratorType): + raise ValueError('%s is not a coroutine' % str(coro)) + return DelegationEvent(coro) + +def end(value = None): + """Event: ends the coroutine and returns a value to its + delegator. 
+ """ + return ReturnEvent(value) + +def read(fd, bufsize = None): + """Event: read from a file descriptor asynchronously.""" + if bufsize is None: + # Read all. + def reader(): + buf = [] + while True: + data = yield read(fd, 1024) + if not data: + break + buf.append(data) + yield ReturnEvent(''.join(buf)) + return DelegationEvent(reader()) + + else: + return ReadEvent(fd, bufsize) + +def write(fd, data): + """Event: write to a file descriptor asynchronously.""" + return WriteEvent(fd, data) + +def connect(host, port): + """Event: connect to a network address and return a Connection + object for communicating on the socket. + """ + addr = (host, port) + sock = socket.create_connection(addr) + return ValueEvent(Connection(sock, addr)) + +def sleep(duration): + """Event: suspend the thread for ``duration`` seconds. + """ + return SleepEvent(duration) + +def join(coro): + """Suspend the thread until another, previously `spawn`ed thread + completes. + """ + return JoinEvent(coro) + + +# Convenience function for running socket servers. + +def server(host, port, func): + """A coroutine that runs a network server. Host and port specify the + listening address. func should be a coroutine that takes a single + parameter, a Connection object. The coroutine is invoked for every + incoming connection on the listening socket. 
+ """ + def handler(conn): + try: + yield func(conn) + finally: + conn.close() + + listener = Listener(host, port) + try: + while True: + conn = yield listener.accept() + yield spawn(handler(conn)) + except KeyboardInterrupt: + pass + finally: + listener.close() diff --git a/lib/beets/util/enumeration.py b/lib/beets/util/enumeration.py index 794a0624..f4968025 100644 --- a/lib/beets/util/enumeration.py +++ b/lib/beets/util/enumeration.py @@ -8,7 +8,7 @@ # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: -# +# # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. @@ -35,7 +35,7 @@ how you would expect them to. 'west' >>> Direction.north < Direction.west True - + Enumerations are classes; their instances represent the possible values of the enumeration. Because Python classes must have names, you may provide a `name` parameter to `enum`; if you don't, a meaningless one @@ -45,31 +45,31 @@ import random class Enumeration(type): """A metaclass whose classes are enumerations. - + The `values` attribute of the class is used to populate the enumeration. Values may either be a list of enumerated names or a string containing a space-separated list of names. When the class is created, it is instantiated for each name value in `values`. Each such instance is the name of the enumerated item as the sole argument. - + The `Enumerated` class is a good choice for a superclass. """ - + def __init__(cls, name, bases, dic): super(Enumeration, cls).__init__(name, bases, dic) - + if 'values' not in dic: # Do nothing if no values are provided (i.e., with # Enumerated itself). return - + # May be called with a single string, in which case we split on # whitespace for convenience. 
values = dic['values'] if isinstance(values, basestring): values = values.split() - + # Create the Enumerated instances for each value. We have to use # super's __setattr__ here because we disallow setattr below. super(Enumeration, cls).__setattr__('_items_dict', {}) @@ -78,56 +78,56 @@ class Enumeration(type): item = cls(value, len(cls._items_list)) cls._items_dict[value] = item cls._items_list.append(item) - + def __getattr__(cls, key): try: return cls._items_dict[key] except KeyError: raise AttributeError("enumeration '" + cls.__name__ + "' has no item '" + key + "'") - + def __setattr__(cls, key, val): raise TypeError("enumerations do not support attribute assignment") - + def __getitem__(cls, key): if isinstance(key, int): return cls._items_list[key] else: return getattr(cls, key) - + def __len__(cls): return len(cls._items_list) - + def __iter__(cls): return iter(cls._items_list) - + def __nonzero__(cls): # Ensures that __len__ doesn't get called before __init__ by # pydoc. return True - + class Enumerated(object): """An item in an enumeration. - + Contains instance methods inherited by enumerated objects. The metaclass is preset to `Enumeration` for your convenience. - - Instance attributes: + + Instance attributes: name -- The name of the item. index -- The index of the item in its enumeration. - + >>> from enumeration import Enumerated >>> class Garment(Enumerated): ... values = 'hat glove belt poncho lederhosen suspenders' ... def wear(self): - ... print 'now wearing a ' + self.name + ... print('now wearing a ' + self.name) ... >>> Garment.poncho.wear() now wearing a poncho """ - + __metaclass__ = Enumeration - + def __init__(self, name, index): self.name = name self.index = index @@ -149,18 +149,18 @@ class Enumerated(object): def enum(*values, **kwargs): """Shorthand for creating a new Enumeration class. - + Call with enumeration values as a list, a space-delimited string, or just an argument list. 
To give the class a name, pass it as the `name` keyword argument. Otherwise, a name will be chosen for you. - + The following are all equivalent: - + enum('pinkie ring middle index thumb') enum('pinkie', 'ring', 'middle', 'index', 'thumb') enum(['pinkie', 'ring', 'middle', 'index', 'thumb']) """ - + if ('name' not in kwargs) or kwargs['name'] is None: # Create a probably-unique name. It doesn't really have to be # unique, but getting distinct names each time helps with @@ -168,11 +168,11 @@ def enum(*values, **kwargs): name = 'Enumeration' + hex(random.randint(0,0xfffffff))[2:].upper() else: name = kwargs['name'] - + if len(values) == 1: # If there's only one value, we have a couple of alternate calling # styles. if isinstance(values[0], basestring) or hasattr(values[0], '__iter__'): values = values[0] - + return type(name, (Enumerated,), {'values': values}) diff --git a/lib/beets/util/functemplate.py b/lib/beets/util/functemplate.py index 5d692179..94cdf6c1 100644 --- a/lib/beets/util/functemplate.py +++ b/lib/beets/util/functemplate.py @@ -8,7 +8,7 @@ # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: -# +# # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. @@ -25,7 +25,12 @@ library: unknown symbols are left intact. This is sort of like a tiny, horrible degeneration of a real templating engine like Jinja2 or Mustache. """ +from __future__ import print_function + import re +import ast +import dis +import types SYMBOL_DELIM = u'$' FUNC_DELIM = u'%' @@ -34,6 +39,9 @@ GROUP_CLOSE = u'}' ARG_SEP = u',' ESCAPE_CHAR = u'$' +VARIABLE_PREFIX = '__var_' +FUNCTION_PREFIX = '__func_' + class Environment(object): """Contains the values and functions to be substituted into a template. 
@@ -42,6 +50,88 @@ class Environment(object): self.values = values self.functions = functions + +# Code generation helpers. + +def ex_lvalue(name): + """A variable load expression.""" + return ast.Name(name, ast.Store()) + +def ex_rvalue(name): + """A variable store expression.""" + return ast.Name(name, ast.Load()) + +def ex_literal(val): + """An int, float, long, bool, string, or None literal with the given + value. + """ + if val is None: + return ast.Name('None', ast.Load()) + elif isinstance(val, (int, float, long)): + return ast.Num(val) + elif isinstance(val, bool): + return ast.Name(str(val), ast.Load()) + elif isinstance(val, basestring): + return ast.Str(val) + raise TypeError('no literal for {0}'.format(type(val))) + +def ex_varassign(name, expr): + """Assign an expression into a single variable. The expression may + either be an `ast.expr` object or a value to be used as a literal. + """ + if not isinstance(expr, ast.expr): + expr = ex_literal(expr) + return ast.Assign([ex_lvalue(name)], expr) + +def ex_call(func, args): + """A function-call expression with only positional parameters. The + function may be an expression or the name of a function. Each + argument may be an expression or a value to be used as a literal. + """ + if isinstance(func, basestring): + func = ex_rvalue(func) + + args = list(args) + for i in range(len(args)): + if not isinstance(args[i], ast.expr): + args[i] = ex_literal(args[i]) + + return ast.Call(func, args, [], None, None) + +def compile_func(arg_names, statements, name='_the_func', debug=False): + """Compile a list of statements as the body of a function and return + the resulting Python function. If `debug`, then print out the + bytecode of the compiled function. 
+ """ + func_def = ast.FunctionDef( + name, + ast.arguments( + [ast.Name(n, ast.Param()) for n in arg_names], + None, None, + [ex_literal(None) for _ in arg_names], + ), + statements, + [], + ) + mod = ast.Module([func_def]) + ast.fix_missing_locations(mod) + + prog = compile(mod, '', 'exec') + + # Debug: show bytecode. + if debug: + dis.dis(prog) + for const in prog.co_consts: + if isinstance(const, types.CodeType): + dis.dis(const) + + the_locals = {} + exec prog in {}, the_locals + return the_locals[name] + + +# AST nodes for the template language. + class Symbol(object): """A variable-substitution symbol in a template.""" def __init__(self, ident, original): @@ -62,6 +152,11 @@ class Symbol(object): # Keep original text. return self.original + def translate(self): + """Compile the variable lookup.""" + expr = ex_rvalue(VARIABLE_PREFIX + self.ident.encode('utf8')) + return [expr], set([self.ident.encode('utf8')]), set() + class Call(object): """A function call in a template.""" def __init__(self, ident, args, original): @@ -81,7 +176,7 @@ class Call(object): arg_vals = [expr.evaluate(env) for expr in self.args] try: out = env.functions[self.ident](*arg_vals) - except Exception, exc: + except Exception as exc: # Function raised exception! Maybe inlining the name of # the exception will help debug. return u'<%s>' % unicode(exc) @@ -89,6 +184,36 @@ class Call(object): else: return self.original + def translate(self): + """Compile the function call.""" + varnames = set() + funcnames = set([self.ident.encode('utf8')]) + + arg_exprs = [] + for arg in self.args: + subexprs, subvars, subfuncs = arg.translate() + varnames.update(subvars) + funcnames.update(subfuncs) + + # Create a subexpression that joins the result components of + # the arguments. 
+ arg_exprs.append(ex_call( + ast.Attribute(ex_literal(u''), 'join', ast.Load()), + [ex_call( + 'map', + [ + ex_rvalue('unicode'), + ast.List(subexprs, ast.Load()), + ] + )], + )) + + subexpr_call = ex_call( + FUNCTION_PREFIX + self.ident.encode('utf8'), + arg_exprs + ) + return [subexpr_call], varnames, funcnames + class Expression(object): """Top-level template construct: contains a list of text blobs, Symbols, and Calls. @@ -111,6 +236,26 @@ class Expression(object): out.append(part.evaluate(env)) return u''.join(map(unicode, out)) + def translate(self): + """Compile the expression to a list of Python AST expressions, a + set of variable names used, and a set of function names. + """ + expressions = [] + varnames = set() + funcnames = set() + for part in self.parts: + if isinstance(part, basestring): + expressions.append(ex_literal(part)) + else: + e, v, f = part.translate() + expressions.extend(e) + varnames.update(v) + funcnames.update(f) + return expressions, varnames, funcnames + + +# Parser. + class ParseError(Exception): pass @@ -266,7 +411,7 @@ class Parser(object): # No function name. self.parts.append(FUNC_DELIM) return - + if self.pos >= len(self.string): # Identifier terminates string. self.parts.append(self.string[start_pos:self.pos]) @@ -304,7 +449,7 @@ class Parser(object): # Extract and advance past the parsed expression. expressions.append(Expression(subparser.parts)) - self.pos += subparser.pos + self.pos += subparser.pos if self.pos >= len(self.string) or \ self.string[self.pos] == GROUP_CLOSE: @@ -340,14 +485,74 @@ def _parse(template): parts.append(remainder) return Expression(parts) + +# External interface. + class Template(object): """A string template, including text, Symbols, and Calls. 
""" def __init__(self, template): self.expr = _parse(template) self.original = template + self.compiled = self.translate() + + def interpret(self, values={}, functions={}): + """Like `substitute`, but forces the interpreter (rather than + the compiled version) to be used. The interpreter includes + exception-handling code for missing variables and buggy template + functions but is much slower. + """ + return self.expr.evaluate(Environment(values, functions)) def substitute(self, values={}, functions={}): """Evaluate the template given the values and functions. """ - return self.expr.evaluate(Environment(values, functions)) + try: + res = self.compiled(values, functions) + except: # Handle any exceptions thrown by compiled version. + res = self.interpret(values, functions) + return res + + def translate(self): + """Compile the template to a Python function.""" + expressions, varnames, funcnames = self.expr.translate() + + argnames = [] + for varname in varnames: + argnames.append(VARIABLE_PREFIX.encode('utf8') + varname) + for funcname in funcnames: + argnames.append(FUNCTION_PREFIX.encode('utf8') + funcname) + + func = compile_func( + argnames, + [ast.Return(ast.List(expressions, ast.Load()))], + ) + + def wrapper_func(values={}, functions={}): + args = {} + for varname in varnames: + args[VARIABLE_PREFIX + varname] = values[varname] + for funcname in funcnames: + args[FUNCTION_PREFIX + funcname] = functions[funcname] + parts = func(**args) + return u''.join(parts) + + return wrapper_func + + +# Performance tests. 
+ +if __name__ == '__main__': + import timeit + _tmpl = Template(u'foo $bar %baz{foozle $bar barzle} $bar') + _vars = {'bar': 'qux'} + _funcs = {'baz': unicode.upper} + interp_time = timeit.timeit('_tmpl.interpret(_vars, _funcs)', + 'from __main__ import _tmpl, _vars, _funcs', + number=10000) + print(interp_time) + comp_time = timeit.timeit('_tmpl.substitute(_vars, _funcs)', + 'from __main__ import _tmpl, _vars, _funcs', + number=10000) + print(comp_time) + print('Speedup:', interp_time / comp_time) diff --git a/lib/beets/util/pipeline.py b/lib/beets/util/pipeline.py index 6adbf160..b81db3c7 100644 --- a/lib/beets/util/pipeline.py +++ b/lib/beets/util/pipeline.py @@ -8,7 +8,7 @@ # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: -# +# # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. @@ -30,7 +30,8 @@ up a bottleneck stage by dividing its work among multiple threads. To do so, pass an iterable of coroutines to the Pipeline constructor in place of any single coroutine. """ -from __future__ import with_statement # for Python 2.5 +from __future__ import print_function + import Queue from threading import Thread, Lock import sys @@ -177,23 +178,23 @@ class FirstPipelineThread(PipelineThread): self.coro = coro self.out_queue = out_queue self.out_queue.acquire() - + self.abort_lock = Lock() self.abort_flag = False - + def run(self): try: while True: with self.abort_lock: if self.abort_flag: return - + # Get the value from the generator. try: msg = self.coro.next() except StopIteration: break - + # Send messages to the next stage. for msg in _allmsgs(msg): with self.abort_lock: @@ -207,7 +208,7 @@ class FirstPipelineThread(PipelineThread): # Generator finished; shut down the pipeline. 
self.out_queue.release() - + class MiddlePipelineThread(PipelineThread): """A thread running any stage in the pipeline except the first or last. @@ -223,7 +224,7 @@ class MiddlePipelineThread(PipelineThread): try: # Prime the coroutine. self.coro.next() - + while True: with self.abort_lock: if self.abort_flag: @@ -233,14 +234,14 @@ class MiddlePipelineThread(PipelineThread): msg = self.in_queue.get() if msg is POISON: break - + with self.abort_lock: if self.abort_flag: return # Invoke the current stage. out = self.coro.send(msg) - + # Send messages to next stage. for msg in _allmsgs(out): with self.abort_lock: @@ -251,7 +252,7 @@ class MiddlePipelineThread(PipelineThread): except: self.abort_all(sys.exc_info()) return - + # Pipeline is shutting down normally. self.out_queue.release() @@ -273,12 +274,12 @@ class LastPipelineThread(PipelineThread): with self.abort_lock: if self.abort_flag: return - + # Get the message from the previous stage. msg = self.in_queue.get() if msg is POISON: break - + with self.abort_lock: if self.abort_flag: return @@ -308,7 +309,7 @@ class Pipeline(object): self.stages.append((stage,)) else: self.stages.append(stage) - + def run_sequential(self): """Run the pipeline sequentially in the current thread. The stages are run one after the other. Only the first coroutine @@ -319,7 +320,7 @@ class Pipeline(object): # "Prime" the coroutines. for coro in coros[1:]: coro.next() - + # Begin the pipeline. for out in coros[0]: msgs = _allmsgs(out) @@ -329,7 +330,7 @@ class Pipeline(object): out = coro.send(msg) next_msgs.extend(_allmsgs(out)) msgs = next_msgs - + def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE): """Run the pipeline in parallel using one thread per stage. The messages between the stages are stored in queues of the given @@ -354,11 +355,11 @@ class Pipeline(object): threads.append( LastPipelineThread(coro, queues[-1], threads) ) - + # Start threads. for thread in threads: thread.start() - + # Wait for termination. 
The final thread lasts the longest. try: # Using a timeout allows us to receive KeyboardInterrupt @@ -371,7 +372,7 @@ class Pipeline(object): for thread in threads: thread.abort() raise - + finally: # Make completely sure that all the threads have finished # before we return. They should already be either finished, @@ -388,25 +389,25 @@ class Pipeline(object): # Smoke test. if __name__ == '__main__': import time - + # Test a normally-terminating pipeline both in sequence and # in parallel. def produce(): for i in range(5): - print 'generating %i' % i + print('generating %i' % i) time.sleep(1) yield i def work(): num = yield while True: - print 'processing %i' % num + print('processing %i' % num) time.sleep(2) num = yield num*2 def consume(): while True: num = yield time.sleep(1) - print 'received %i' % num + print('received %i' % num) ts_start = time.time() Pipeline([produce(), work(), consume()]).run_sequential() ts_seq = time.time() @@ -414,21 +415,21 @@ if __name__ == '__main__': ts_par = time.time() Pipeline([produce(), (work(), work()), consume()]).run_parallel() ts_end = time.time() - print 'Sequential time:', ts_seq - ts_start - print 'Parallel time:', ts_par - ts_seq - print 'Multiply-parallel time:', ts_end - ts_par - print + print('Sequential time:', ts_seq - ts_start) + print('Parallel time:', ts_par - ts_seq) + print('Multiply-parallel time:', ts_end - ts_par) + print() # Test a pipeline that raises an exception. 
def exc_produce(): for i in range(10): - print 'generating %i' % i + print('generating %i' % i) time.sleep(1) yield i def exc_work(): num = yield while True: - print 'processing %i' % num + print('processing %i' % num) time.sleep(3) if num == 3: raise Exception() @@ -438,5 +439,5 @@ if __name__ == '__main__': num = yield #if num == 4: # raise Exception() - print 'received %i' % num + print('received %i' % num) Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1) diff --git a/lib/beets/vfs.py b/lib/beets/vfs.py index 614bc8f5..815f8db3 100644 --- a/lib/beets/vfs.py +++ b/lib/beets/vfs.py @@ -8,7 +8,7 @@ # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: -# +# # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. From 69f51241cc56ea59b41cfc498d214a7119900230 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Sat, 28 Jul 2012 23:48:19 +0530 Subject: [PATCH 12/84] Rough update of the musicbrainzngs lib --- lib/musicbrainzngs/__init__.py | 2 +- lib/musicbrainzngs/mbxml.py | 185 +++--- lib/musicbrainzngs/musicbrainz.py | 977 +++++++++++++++--------------- 3 files changed, 590 insertions(+), 574 deletions(-) diff --git a/lib/musicbrainzngs/__init__.py b/lib/musicbrainzngs/__init__.py index 40a89036..36962ef5 100644 --- a/lib/musicbrainzngs/__init__.py +++ b/lib/musicbrainzngs/__init__.py @@ -1 +1 @@ -from lib.musicbrainzngs.musicbrainz import * +from musicbrainzngs.musicbrainz import * diff --git a/lib/musicbrainzngs/mbxml.py b/lib/musicbrainzngs/mbxml.py index dd4ca961..7f6bd9f2 100644 --- a/lib/musicbrainzngs/mbxml.py +++ b/lib/musicbrainzngs/mbxml.py @@ -6,8 +6,7 @@ import xml.etree.ElementTree as ET import logging -from lib.musicbrainzngs import compat -from lib.musicbrainzngs import util +from musicbrainzngs import util try: from ET import fixtag @@ -40,7 +39,11 @@ 
def make_artist_credit(artists): names = [] for artist in artists: if isinstance(artist, dict): - names.append(artist.get("artist", {}).get("name", "")) + if "name" in artist: + names.append(artist.get("name", "")) + else: + names.append(artist.get("artist", {}).get("name", "")) + else: names.append(artist) return "".join(names) @@ -60,7 +63,7 @@ def parse_elements(valid_els, element): if ":" in t: t = t.split(":")[1] if t in valid_els: - result[t] = sub.text + result[t] = sub.text or "" else: _log.debug("in <%s>, uncaught <%s>", fixtag(element.tag, NS_MAP)[0], t) return result @@ -175,52 +178,48 @@ def parse_artist_list(al): return [parse_artist(a) for a in al] def parse_artist(artist): - result = {} - attribs = ["id", "type", "ext:score"] - elements = ["name", "sort-name", "country", "user-rating", - "disambiguation", "gender", "ipi"] - inner_els = {"life-span": parse_artist_lifespan, - "recording-list": parse_recording_list, - "release-list": parse_release_list, - "release-group-list": parse_release_group_list, - "work-list": parse_work_list, - "tag-list": parse_tag_list, - "user-tag-list": parse_tag_list, - "rating": parse_rating, - "alias-list": parse_alias_list} + result = {} + attribs = ["id", "type", "ext:score"] + elements = ["name", "sort-name", "country", "user-rating", + "disambiguation", "gender", "ipi"] + inner_els = {"life-span": parse_artist_lifespan, + "recording-list": parse_recording_list, + "release-list": parse_release_list, + "release-group-list": parse_release_group_list, + "work-list": parse_work_list, + "tag-list": parse_tag_list, + "user-tag-list": parse_tag_list, + "rating": parse_rating, + "ipi-list": parse_element_list, + "alias-list": parse_element_list} - result.update(parse_attributes(attribs, artist)) - result.update(parse_elements(elements, artist)) - result.update(parse_inner(inner_els, artist)) + result.update(parse_attributes(attribs, artist)) + result.update(parse_elements(elements, artist)) + 
result.update(parse_inner(inner_els, artist)) - return result + return result def parse_label_list(ll): - return [parse_label(l) for l in ll] + return [parse_label(l) for l in ll] def parse_label(label): - result = {} - attribs = ["id", "type", "ext:score"] - elements = ["name", "sort-name", "country", "label-code", "user-rating", - "ipi", "disambiguation"] - inner_els = {"life-span": parse_artist_lifespan, - "release-list": parse_release_list, - "tag-list": parse_tag_list, - "user-tag-list": parse_tag_list, - "rating": parse_rating, - "alias-list": parse_alias_list} + result = {} + attribs = ["id", "type", "ext:score"] + elements = ["name", "sort-name", "country", "label-code", "user-rating", + "ipi", "disambiguation"] + inner_els = {"life-span": parse_artist_lifespan, + "release-list": parse_release_list, + "tag-list": parse_tag_list, + "user-tag-list": parse_tag_list, + "rating": parse_rating, + "ipi-list": parse_element_list, + "alias-list": parse_element_list} - result.update(parse_attributes(attribs, label)) - result.update(parse_elements(elements, label)) - result.update(parse_inner(inner_els, label)) + result.update(parse_attributes(attribs, label)) + result.update(parse_elements(elements, label)) + result.update(parse_inner(inner_els, label)) - return result - -def parse_attribute_list(al): - return [parse_attribute_tag(a) for a in al] - -def parse_attribute_tag(attribute): - return attribute.text + return result def parse_relation_list(rl): attribs = ["target-type"] @@ -237,7 +236,7 @@ def parse_relation(relation): "recording": parse_recording, "release": parse_release, "release-group": parse_release_group, - "attribute-list": parse_attribute_list, + "attribute-list": parse_element_list, "work": parse_work } result.update(parse_attributes(attribs, relation)) @@ -285,22 +284,23 @@ def parse_text_representation(textr): return parse_elements(["language", "script"], textr) def parse_release_group(rg): - result = {} - attribs = ["id", "type", "ext:score"] - 
elements = ["title", "user-rating", "first-release-date"] - inner_els = {"artist-credit": parse_artist_credit, - "release-list": parse_release_list, - "tag-list": parse_tag_list, - "user-tag-list": parse_tag_list, - "rating": parse_rating} + result = {} + attribs = ["id", "type", "ext:score"] + elements = ["title", "user-rating", "first-release-date", "primary-type"] + inner_els = {"artist-credit": parse_artist_credit, + "release-list": parse_release_list, + "tag-list": parse_tag_list, + "user-tag-list": parse_tag_list, + "secondary-type-list": parse_element_list, + "rating": parse_rating} - result.update(parse_attributes(attribs, rg)) - result.update(parse_elements(elements, rg)) - result.update(parse_inner(inner_els, rg)) - if "artist-credit" in result: - result["artist-credit-phrase"] = make_artist_credit(result["artist-credit"]) + result.update(parse_attributes(attribs, rg)) + result.update(parse_elements(elements, rg)) + result.update(parse_inner(inner_els, rg)) + if "artist-credit" in result: + result["artist-credit-phrase"] = make_artist_credit(result["artist-credit"]) - return result + return result def parse_recording(recording): result = {} @@ -313,7 +313,8 @@ def parse_recording(recording): "rating": parse_rating, "puid-list": parse_external_id_list, "isrc-list": parse_external_id_list, - "echoprint-list": parse_external_id_list} + "echoprint-list": parse_external_id_list, + "relation-list": parse_relation_list} result.update(parse_attributes(attribs, recording)) result.update(parse_elements(elements, recording)) @@ -326,26 +327,28 @@ def parse_recording(recording): def parse_external_id_list(pl): return [parse_attributes(["id"], p)["id"] for p in pl] +def parse_element_list(el): + return [e.text for e in el] + def parse_work_list(wl): - result = [] - for w in wl: - result.append(parse_work(w)) - return result + return [parse_work(w) for w in wl] def parse_work(work): - result = {} - attribs = ["id", "ext:score"] - elements = ["title", "user-rating"] - 
inner_els = {"tag-list": parse_tag_list, - "user-tag-list": parse_tag_list, - "rating": parse_rating, - "alias-list": parse_alias_list} + result = {} + attribs = ["id", "ext:score"] + elements = ["title", "user-rating", "language", "iswc"] + inner_els = {"tag-list": parse_tag_list, + "user-tag-list": parse_tag_list, + "rating": parse_rating, + "alias-list": parse_element_list, + "iswc-list": parse_element_list, + "relation-list": parse_relation_list} - result.update(parse_attributes(attribs, work)) - result.update(parse_elements(elements, work)) - result.update(parse_inner(inner_els, work)) + result.update(parse_attributes(attribs, work)) + result.update(parse_elements(elements, work)) + result.update(parse_inner(inner_els, work)) - return result + return result def parse_disc(disc): result = {} @@ -429,19 +432,29 @@ def parse_track_list(tl): return result def parse_track(track): - result = {} - elements = ["position", "title","length"] #CHANGED!!! - inner_els = {"recording": parse_recording} + result = {} + elements = ["number", "position", "title", "length"] + inner_els = {"recording": parse_recording, + "artist-credit": parse_artist_credit} - result.update(parse_elements(elements, track)) - result.update(parse_inner(inner_els, track)) - return result + result.update(parse_elements(elements, track)) + result.update(parse_inner(inner_els, track)) + if "artist-credit" in result.get("recording", {}) and "artist-credit" not in result: + result["artist-credit"] = result["recording"]["artist-credit"] + if "artist-credit" in result: + result["artist-credit-phrase"] = make_artist_credit(result["artist-credit"]) + # Make a length field that contains track length or recording length + track_or_recording = None + if "length" in result: + track_or_recording = result["length"] + elif result.get("recording", {}).get("length"): + track_or_recording = result.get("recording", {}).get("length") + if track_or_recording: + result["track_or_recording_length"] = track_or_recording + 
return result def parse_tag_list(tl): - result = [] - for t in tl: - result.append(parse_tag(t)) - return result + return [parse_tag(t) for t in tl] def parse_tag(tag): result = {} @@ -462,12 +475,6 @@ def parse_rating(rating): return result -def parse_alias_list(al): - result = [] - for a in al: - result.append(a.text) - return result - ### def make_barcode_request(barcodes): NS = "http://musicbrainz.org/ns/mmd-2.0#" diff --git a/lib/musicbrainzngs/musicbrainz.py b/lib/musicbrainzngs/musicbrainz.py index c5a3d65e..b0e94fed 100644 --- a/lib/musicbrainzngs/musicbrainz.py +++ b/lib/musicbrainzngs/musicbrainz.py @@ -23,153 +23,157 @@ _log = logging.getLogger("musicbrainzngs") # Constants for validation. VALID_INCLUDES = { - 'artist': [ - "recordings", "releases", "release-groups", "works", # Subqueries - "various-artists", "discids", "media", - "aliases", "tags", "user-tags", "ratings", "user-ratings", # misc - "artist-rels", "label-rels", "recording-rels", "release-rels", - "release-group-rels", "url-rels", "work-rels" - ], - 'label': [ - "releases", # Subqueries - "discids", "media", - "aliases", "tags", "user-tags", "ratings", "user-ratings", # misc - "artist-rels", "label-rels", "recording-rels", "release-rels", - "release-group-rels", "url-rels", "work-rels" - ], - 'recording': [ - "artists", "releases", # Subqueries - "discids", "media", "artist-credits", - "tags", "user-tags", "ratings", "user-ratings", # misc - "artist-rels", "label-rels", "recording-rels", "release-rels", - "release-group-rels", "url-rels", "work-rels" - ], - 'release': [ - "artists", "labels", "recordings", "release-groups", "media", - "artist-credits", "discids", "puids", "echoprints", "isrcs", - "artist-rels", "label-rels", "recording-rels", "release-rels", - "release-group-rels", "url-rels", "work-rels", "recording-level-rels", - "work-level-rels" - ], - 'release-group': [ - "artists", "releases", "discids", "media", - "artist-credits", "tags", "user-tags", "ratings", "user-ratings", # 
misc - "artist-rels", "label-rels", "recording-rels", "release-rels", - "release-group-rels", "url-rels", "work-rels" - ], - 'work': [ - "artists", # Subqueries - "aliases", "tags", "user-tags", "ratings", "user-ratings", # misc - "artist-rels", "label-rels", "recording-rels", "release-rels", - "release-group-rels", "url-rels", "work-rels" - ], - 'discid': [ - "artists", "labels", "recordings", "release-groups", "media", - "artist-credits", "discids", "puids", "echoprints", "isrcs", - "artist-rels", "label-rels", "recording-rels", "release-rels", - "release-group-rels", "url-rels", "work-rels", "recording-level-rels", - "work-level-rels" - ], - 'echoprint': ["artists", "releases"], - 'puid': ["artists", "releases", "puids", "echoprints", "isrcs"], - 'isrc': ["artists", "releases", "puids", "echoprints", "isrcs"], - 'iswc': ["artists"], - 'collection': ['releases'], + 'artist': [ + "recordings", "releases", "release-groups", "works", # Subqueries + "various-artists", "discids", "media", + "aliases", "tags", "user-tags", "ratings", "user-ratings", # misc + "artist-rels", "label-rels", "recording-rels", "release-rels", + "release-group-rels", "url-rels", "work-rels" + ], + 'label': [ + "releases", # Subqueries + "discids", "media", + "aliases", "tags", "user-tags", "ratings", "user-ratings", # misc + "artist-rels", "label-rels", "recording-rels", "release-rels", + "release-group-rels", "url-rels", "work-rels" + ], + 'recording': [ + "artists", "releases", # Subqueries + "discids", "media", "artist-credits", + "tags", "user-tags", "ratings", "user-ratings", # misc + "artist-rels", "label-rels", "recording-rels", "release-rels", + "release-group-rels", "url-rels", "work-rels" + ], + 'release': [ + "artists", "labels", "recordings", "release-groups", "media", + "artist-credits", "discids", "puids", "echoprints", "isrcs", + "artist-rels", "label-rels", "recording-rels", "release-rels", + "release-group-rels", "url-rels", "work-rels", "recording-level-rels", + 
"work-level-rels" + ], + 'release-group': [ + "artists", "releases", "discids", "media", + "artist-credits", "tags", "user-tags", "ratings", "user-ratings", # misc + "artist-rels", "label-rels", "recording-rels", "release-rels", + "release-group-rels", "url-rels", "work-rels" + ], + 'work': [ + "artists", # Subqueries + "aliases", "tags", "user-tags", "ratings", "user-ratings", # misc + "artist-rels", "label-rels", "recording-rels", "release-rels", + "release-group-rels", "url-rels", "work-rels" + ], + 'discid': [ + "artists", "labels", "recordings", "release-groups", "media", + "artist-credits", "discids", "puids", "echoprints", "isrcs", + "artist-rels", "label-rels", "recording-rels", "release-rels", + "release-group-rels", "url-rels", "work-rels", "recording-level-rels", + "work-level-rels" + ], + 'echoprint': ["artists", "releases"], + 'puid': ["artists", "releases", "puids", "echoprints", "isrcs"], + 'isrc': ["artists", "releases", "puids", "echoprints", "isrcs"], + 'iswc': ["artists"], + 'collection': ['releases'], } VALID_RELEASE_TYPES = [ - "nat", "album", "single", "ep", "compilation", "soundtrack", "spokenword", - "interview", "audiobook", "live", "remix", "other" + "nat", "album", "single", "ep", "compilation", "soundtrack", "spokenword", + "interview", "audiobook", "live", "remix", "other" ] VALID_RELEASE_STATUSES = ["official", "promotion", "bootleg", "pseudo-release"] VALID_SEARCH_FIELDS = { - 'artist': [ - 'arid', 'artist', 'sortname', 'type', 'begin', 'end', 'comment', - 'alias', 'country', 'gender', 'tag', 'ipi', 'artistaccent' - ], - 'release-group': [ - 'rgid', 'releasegroup', 'reid', 'release', 'arid', 'artist', - 'artistname', 'creditname', 'type', 'tag', 'releasegroupaccent', - 'releases', 'comment' - ], - 'release': [ - 'reid', 'release', 'arid', 'artist', 'artistname', 'creditname', - 'type', 'status', 'tracks', 'tracksmedium', 'discids', - 'discidsmedium', 'mediums', 'date', 'asin', 'lang', 'script', - 'country', 'date', 'label', 'catno', 
'barcode', 'puid', 'comment', - 'format', 'releaseaccent', 'rgid' - ], - 'recording': [ - 'rid', 'recording', 'isrc', 'arid', 'artist', 'artistname', - 'creditname', 'reid', 'release', 'type', 'status', 'tracks', - 'tracksrelease', 'dur', 'qdur', 'tnum', 'position', 'tag', 'comment', - 'country', 'date' 'format', 'recordingaccent' - ], - 'label': [ - 'laid', 'label', 'sortname', 'type', 'code', 'country', 'begin', - 'end', 'comment', 'alias', 'tag', 'ipi', 'labelaccent' - ], - 'work': [ - 'wid', 'work', 'iswc', 'type', 'arid', 'artist', 'alias', 'tag', - 'comment', 'workaccent' - ], + 'artist': [ + 'arid', 'artist', 'sortname', 'type', 'begin', 'end', 'comment', + 'alias', 'country', 'gender', 'tag', 'ipi', 'artistaccent' + ], + 'release-group': [ + 'rgid', 'releasegroup', 'reid', 'release', 'arid', 'artist', + 'artistname', 'creditname', 'type', 'tag', 'releasegroupaccent', + 'releases', 'comment' + ], + 'release': [ + 'reid', 'release', 'arid', 'artist', 'artistname', 'creditname', + 'type', 'status', 'tracks', 'tracksmedium', 'discids', + 'discidsmedium', 'mediums', 'date', 'asin', 'lang', 'script', + 'country', 'date', 'label', 'catno', 'barcode', 'puid', 'comment', + 'format', 'releaseaccent', 'rgid' + ], + 'recording': [ + 'rid', 'recording', 'isrc', 'arid', 'artist', 'artistname', + 'creditname', 'reid', 'release', 'type', 'status', 'tracks', + 'tracksrelease', 'dur', 'qdur', 'tnum', 'position', 'tag', 'comment', + 'country', 'date' 'format', 'recordingaccent' + ], + 'label': [ + 'laid', 'label', 'sortname', 'type', 'code', 'country', 'begin', + 'end', 'comment', 'alias', 'tag', 'ipi', 'labelaccent' + ], + 'work': [ + 'wid', 'work', 'iswc', 'type', 'arid', 'artist', 'alias', 'tag', + 'comment', 'workaccent' + ], } # Exceptions. 
class MusicBrainzError(Exception): - """Base class for all exceptions related to MusicBrainz.""" - pass + """Base class for all exceptions related to MusicBrainz.""" + pass class UsageError(MusicBrainzError): - """Error related to misuse of the module API.""" - pass + """Error related to misuse of the module API.""" + pass class InvalidSearchFieldError(UsageError): - pass + pass class InvalidIncludeError(UsageError): - def __init__(self, msg='Invalid Includes', reason=None): - super(InvalidIncludeError, self).__init__(self) - self.msg = msg - self.reason = reason + def __init__(self, msg='Invalid Includes', reason=None): + super(InvalidIncludeError, self).__init__(self) + self.msg = msg + self.reason = reason - def __str__(self): - return self.msg + def __str__(self): + return self.msg class InvalidFilterError(UsageError): - def __init__(self, msg='Invalid Includes', reason=None): - super(InvalidFilterError, self).__init__(self) - self.msg = msg - self.reason = reason + def __init__(self, msg='Invalid Includes', reason=None): + super(InvalidFilterError, self).__init__(self) + self.msg = msg + self.reason = reason - def __str__(self): - return self.msg + def __str__(self): + return self.msg class WebServiceError(MusicBrainzError): - """Error related to MusicBrainz API requests.""" - def __init__(self, message=None, cause=None): - """Pass ``cause`` if this exception was caused by another - exception. - """ - self.message = message - self.cause = cause + """Error related to MusicBrainz API requests.""" + def __init__(self, message=None, cause=None): + """Pass ``cause`` if this exception was caused by another + exception. 
+ """ + self.message = message + self.cause = cause - def __str__(self): - if self.message: - msg = "%s, " % self.message - else: - msg = "" - msg += "caused by: %s" % str(self.cause) - return msg + def __str__(self): + if self.message: + msg = "%s, " % self.message + else: + msg = "" + msg += "caused by: %s" % str(self.cause) + return msg class NetworkError(WebServiceError): - """Problem communicating with the MB server.""" - pass + """Problem communicating with the MB server.""" + pass class ResponseError(WebServiceError): - """Bad response sent by the MB server.""" - pass + """Bad response sent by the MB server.""" + pass + +class AuthenticationError(WebServiceError): + """Received a HTTP 401 response while accessing a protected resource.""" + pass # Helpers for validating and formatting allowed sets. @@ -182,37 +186,37 @@ def _check_includes(entity, inc): _check_includes_impl(inc, VALID_INCLUDES[entity]) def _check_filter(values, valid): - for v in values: - if v not in valid: - raise InvalidFilterError(v) + for v in values: + if v not in valid: + raise InvalidFilterError(v) def _check_filter_and_make_params(entity, includes, release_status=[], release_type=[]): - """Check that the status or type values are valid. Then, check that - the filters can be used with the given includes. Return a params - dict that can be passed to _do_mb_query. - """ - if isinstance(release_status, compat.basestring): - release_status = [release_status] - if isinstance(release_type, compat.basestring): - release_type = [release_type] - _check_filter(release_status, VALID_RELEASE_STATUSES) - _check_filter(release_type, VALID_RELEASE_TYPES) + """Check that the status or type values are valid. Then, check that + the filters can be used with the given includes. Return a params + dict that can be passed to _do_mb_query. 
+ """ + if isinstance(release_status, compat.basestring): + release_status = [release_status] + if isinstance(release_type, compat.basestring): + release_type = [release_type] + _check_filter(release_status, VALID_RELEASE_STATUSES) + _check_filter(release_type, VALID_RELEASE_TYPES) - if release_status and "releases" not in includes: - raise InvalidFilterError("Can't have a status with no release include") - if release_type and ("release-groups" not in includes and - "releases" not in includes and - entity != "release-group"): - raise InvalidFilterError("Can't have a release type with no " - "release-group include") + if release_status and "releases" not in includes: + raise InvalidFilterError("Can't have a status with no release include") + if release_type and ("release-groups" not in includes and + "releases" not in includes and + entity != "release-group"): + raise InvalidFilterError("Can't have a release type with no " + "release-group include") - # Build parameters. - params = {} - if len(release_status): - params["status"] = "|".join(release_status) - if len(release_type): - params["type"] = "|".join(release_type) - return params + # Build parameters. + params = {} + if len(release_status): + params["status"] = "|".join(release_status) + if len(release_type): + params["type"] = "|".join(release_type) + return params # Global authentication and endpoint details. @@ -223,20 +227,20 @@ _client = "" _useragent = "" def auth(u, p): - """Set the username and password to be used in subsequent queries to - the MusicBrainz XML API that require authentication. - """ - global user, password - user = u - password = p - + """Set the username and password to be used in subsequent queries to + the MusicBrainz XML API that require authentication. + """ + global user, password + user = u + password = p + def hpauth(u, p): - """Set the username and password to be used in subsequent queries to - the MusicBrainz XML API that require authentication. 
- """ - global hpuser, hppassword - hpuser = u - hppassword = p + """Set the username and password to be used in subsequent queries to + the MusicBrainz XML API that require authentication. + """ + global hpuser, hppassword + hpuser = u + hppassword = p def set_useragent(app, version, contact=None): """Set the User-Agent to be used for requests to the MusicBrainz webservice. @@ -261,24 +265,27 @@ limit_interval = 1.0 limit_requests = 1 do_rate_limit = True -def set_rate_limit(rate_limit=True, new_interval=1.0, new_requests=1): +def set_rate_limit(limit_or_interval=1.0, new_requests=1): """Sets the rate limiting behavior of the module. Must be invoked before the first Web service call. - If the `rate_limit` parameter is set to True, then only a set number - of requests (`new_requests`) will be made per given interval - (`new_interval`). If `rate_limit` is False, then no rate limiting - will occur. + If the `limit_or_interval` parameter is set to False then + rate limiting will be disabled. If it is a number then only + a set number of requests (`new_requests`) will be made per + given interval (`limit_or_interval`). 
""" global limit_interval global limit_requests global do_rate_limit - if new_interval <= 0.0: - raise ValueError("new_interval can't be less than 0") - if new_requests <= 0: - raise ValueError("new_requests can't be less than 0") - limit_interval = new_interval - limit_requests = new_requests - do_rate_limit = rate_limit + if isinstance(limit_or_interval, bool): + do_rate_limit = limit_or_interval + else: + if limit_or_interval <= 0.0: + raise ValueError("limit_or_interval can't be less than 0") + if new_requests <= 0: + raise ValueError("new_requests can't be less than 0") + do_rate_limit = True + limit_interval = limit_or_interval + limit_requests = new_requests class _rate_limit(object): """A decorator that limits the rate at which the function may be @@ -329,405 +336,407 @@ class _rate_limit(object): # From pymb2 class _RedirectPasswordMgr(compat.HTTPPasswordMgr): - def __init__(self): - self._realms = { } + def __init__(self): + self._realms = { } - def find_user_password(self, realm, uri): - # ignoring the uri parameter intentionally - try: - return self._realms[realm] - except KeyError: - return (None, None) + def find_user_password(self, realm, uri): + # ignoring the uri parameter intentionally + try: + return self._realms[realm] + except KeyError: + return (None, None) - def add_password(self, realm, uri, username, password): - # ignoring the uri parameter intentionally - self._realms[realm] = (username, password) + def add_password(self, realm, uri, username, password): + # ignoring the uri parameter intentionally + self._realms[realm] = (username, password) class _DigestAuthHandler(compat.HTTPDigestAuthHandler): - def get_authorization (self, req, chal): - qop = chal.get ('qop', None) - if qop and ',' in qop and 'auth' in qop.split (','): - chal['qop'] = 'auth' + def get_authorization (self, req, chal): + qop = chal.get ('qop', None) + if qop and ',' in qop and 'auth' in qop.split (','): + chal['qop'] = 'auth' - return 
compat.HTTPDigestAuthHandler.get_authorization (self, req, chal) + return compat.HTTPDigestAuthHandler.get_authorization (self, req, chal) class _MusicbrainzHttpRequest(compat.Request): - """ A custom request handler that allows DELETE and PUT""" - def __init__(self, method, url, data=None): - compat.Request.__init__(self, url, data) - allowed_m = ["GET", "POST", "DELETE", "PUT"] - if method not in allowed_m: - raise ValueError("invalid method: %s" % method) - self.method = method + """ A custom request handler that allows DELETE and PUT""" + def __init__(self, method, url, data=None): + compat.Request.__init__(self, url, data) + allowed_m = ["GET", "POST", "DELETE", "PUT"] + if method not in allowed_m: + raise ValueError("invalid method: %s" % method) + self.method = method - def get_method(self): - return self.method + def get_method(self): + return self.method # Core (internal) functions for calling the MB API. def _safe_open(opener, req, body=None, max_retries=3, retry_delay_delta=2.0): - """Open an HTTP request with a given URL opener and (optionally) a - request body. Transient errors lead to retries. Permanent errors - and repeated errors are translated into a small set of handleable - exceptions. Returns a file-like object. - """ - last_exc = None - for retry_num in range(max_retries): - if retry_num: # Not the first try: delay an increasing amount. - _log.debug("retrying after delay (#%i)" % retry_num) - time.sleep(retry_num * retry_delay_delta) + """Open an HTTP request with a given URL opener and (optionally) a + request body. Transient errors lead to retries. Permanent errors + and repeated errors are translated into a small set of handleable + exceptions. Returns a file-like object. + """ + last_exc = None + for retry_num in range(max_retries): + if retry_num: # Not the first try: delay an increasing amount. 
+ _log.debug("retrying after delay (#%i)" % retry_num) + time.sleep(retry_num * retry_delay_delta) - try: - if body: - f = opener.open(req, body) - else: - f = opener.open(req) + try: + if body: + f = opener.open(req, body) + else: + f = opener.open(req) - except compat.HTTPError as exc: - if exc.code in (400, 404, 411): - # Bad request, not found, etc. - raise ResponseError(cause=exc) - elif exc.code in (503, 502, 500): - # Rate limiting, internal overloading... - _log.debug("HTTP error %i" % exc.code) - else: - # Other, unknown error. Should handle more cases, but - # retrying for now. - _log.debug("unknown HTTP error %i" % exc.code) - last_exc = exc - except compat.BadStatusLine as exc: - _log.debug("bad status line") - last_exc = exc - except compat.HTTPException as exc: - _log.debug("miscellaneous HTTP exception: %s" % str(exc)) - last_exc = exc - except compat.URLError as exc: - if isinstance(exc.reason, socket.error): - code = exc.reason.errno - if code == 104: # "Connection reset by peer." - continue - raise NetworkError(cause=exc) - except socket.error as exc: - if exc.errno == 104: - continue - raise NetworkError(cause=exc) - except IOError as exc: - raise NetworkError(cause=exc) - else: - # No exception! Yay! - return f + except compat.HTTPError as exc: + if exc.code in (400, 404, 411): + # Bad request, not found, etc. + raise ResponseError(cause=exc) + elif exc.code in (503, 502, 500): + # Rate limiting, internal overloading... + _log.debug("HTTP error %i" % exc.code) + elif exc.code in (401, ): + raise AuthenticationError(cause=exc) + else: + # Other, unknown error. Should handle more cases, but + # retrying for now. 
+ _log.debug("unknown HTTP error %i" % exc.code) + last_exc = exc + except compat.BadStatusLine as exc: + _log.debug("bad status line") + last_exc = exc + except compat.HTTPException as exc: + _log.debug("miscellaneous HTTP exception: %s" % str(exc)) + last_exc = exc + except compat.URLError as exc: + if isinstance(exc.reason, socket.error): + code = exc.reason.errno + if code == 104: # "Connection reset by peer." + continue + raise NetworkError(cause=exc) + except socket.error as exc: + if exc.errno == 104: + continue + raise NetworkError(cause=exc) + except IOError as exc: + raise NetworkError(cause=exc) + else: + # No exception! Yay! + return f - # Out of retries! - raise NetworkError("retried %i times" % max_retries, last_exc) + # Out of retries! + raise NetworkError("retried %i times" % max_retries, last_exc) # Get the XML parsing exceptions to catch. The behavior chnaged with Python 2.7 # and ElementTree 1.3. if hasattr(etree, 'ParseError'): - ETREE_EXCEPTIONS = (etree.ParseError, expat.ExpatError) + ETREE_EXCEPTIONS = (etree.ParseError, expat.ExpatError) else: - ETREE_EXCEPTIONS = (expat.ExpatError) + ETREE_EXCEPTIONS = (expat.ExpatError) @_rate_limit def _mb_request(path, method='GET', auth_required=False, client_required=False, - args=None, data=None, body=None): - """Makes a request for the specified `path` (endpoint) on /ws/2 on - the globally-specified hostname. Parses the responses and returns - the resulting object. `auth_required` and `client_required` control - whether exceptions should be raised if the client and - username/password are left unspecified, respectively. - """ - if args is None: - args = {} - else: - args = dict(args) or {} + args=None, data=None, body=None): + """Makes a request for the specified `path` (endpoint) on /ws/2 on + the globally-specified hostname. Parses the responses and returns + the resulting object. 
`auth_required` and `client_required` control + whether exceptions should be raised if the client and + username/password are left unspecified, respectively. + """ + if args is None: + args = {} + else: + args = dict(args) or {} - if _useragent == "": - raise UsageError("set a proper user-agent with " - "set_useragent(\"application name\", \"application version\", \"contact info (preferably URL or email for your application)\")") + if _useragent == "": + raise UsageError("set a proper user-agent with " + "set_useragent(\"application name\", \"application version\", \"contact info (preferably URL or email for your application)\")") - if client_required: - args["client"] = _client + if client_required: + args["client"] = _client - # Encode Unicode arguments using UTF-8. - for key, value in args.items(): - if isinstance(value, compat.unicode): - args[key] = value.encode('utf8') + # Encode Unicode arguments using UTF-8. + for key, value in args.items(): + if isinstance(value, compat.unicode): + args[key] = value.encode('utf8') - # Construct the full URL for the request, including hostname and - # query string. - url = compat.urlunparse(( - 'http', - hostname, - '/ws/2/%s' % path, - '', - compat.urlencode(args), - '' - )) - _log.debug("%s request for %s" % (method, url)) + # Construct the full URL for the request, including hostname and + # query string. + url = compat.urlunparse(( + 'http', + hostname, + '/ws/2/%s' % path, + '', + compat.urlencode(args), + '' + )) + _log.debug("%s request for %s" % (method, url)) - # Set up HTTP request handler and URL opener. - httpHandler = compat.HTTPHandler(debuglevel=0) - handlers = [httpHandler] + # Set up HTTP request handler and URL opener. + httpHandler = compat.HTTPHandler(debuglevel=0) + handlers = [httpHandler] - # Add credentials if required. 
- if auth_required: - _log.debug("Auth required for %s" % url) - if not user: - raise UsageError("authorization required; " - "use auth(user, pass) first") - passwordMgr = _RedirectPasswordMgr() - authHandler = _DigestAuthHandler(passwordMgr) - authHandler.add_password("musicbrainz.org", (), user, password) - handlers.append(authHandler) + # Add credentials if required. + if auth_required: + _log.debug("Auth required for %s" % url) + if not user: + raise UsageError("authorization required; " + "use auth(user, pass) first") + passwordMgr = _RedirectPasswordMgr() + authHandler = _DigestAuthHandler(passwordMgr) + authHandler.add_password("musicbrainz.org", (), user, password) + handlers.append(authHandler) - opener = compat.build_opener(*handlers) + opener = compat.build_opener(*handlers) - # Make request. - req = _MusicbrainzHttpRequest(method, url, data) - req.add_header('User-Agent', _useragent) - + # Make request. + req = _MusicbrainzHttpRequest(method, url, data) + req.add_header('User-Agent', _useragent) + # Add headphones credentials - if hostname == '178.63.142.150:8181': - base64string = base64.encodestring('%s:%s' % (hpuser, hppassword)).replace('\n', '') - req.add_header("Authorization", "Basic %s" % base64string) - - _log.debug("requesting with UA %s" % _useragent) - if body: - req.add_header('Content-Type', 'application/xml; charset=UTF-8') - elif not data and not req.has_header('Content-Length'): - # Explicitly indicate zero content length if no request data - # will be sent (avoids HTTP 411 error). 
- req.add_header('Content-Length', '0') - f = _safe_open(opener, req, body) + if hostname == '178.63.142.150:8181': + base64string = base64.encodestring('%s:%s' % (hpuser, hppassword)).replace('\n', '') + req.add_header("Authorization", "Basic %s" % base64string) + + _log.debug("requesting with UA %s" % _useragent) + if body: + req.add_header('Content-Type', 'application/xml; charset=UTF-8') + elif not data and not req.has_header('Content-Length'): + # Explicitly indicate zero content length if no request data + # will be sent (avoids HTTP 411 error). + req.add_header('Content-Length', '0') + f = _safe_open(opener, req, body) - # Parse the response. - try: - return mbxml.parse_message(f) - except UnicodeError as exc: - raise ResponseError(cause=exc) - except Exception as exc: - if isinstance(exc, ETREE_EXCEPTIONS): - raise ResponseError(cause=exc) - else: - raise + # Parse the response. + try: + return mbxml.parse_message(f) + except UnicodeError as exc: + raise ResponseError(cause=exc) + except Exception as exc: + if isinstance(exc, ETREE_EXCEPTIONS): + raise ResponseError(cause=exc) + else: + raise def _is_auth_required(entity, includes): - """ Some calls require authentication. This returns - True if a call does, False otherwise - """ - if "user-tags" in includes or "user-ratings" in includes: - return True - elif entity.startswith("collection"): - return True - else: - return False + """ Some calls require authentication. This returns + True if a call does, False otherwise + """ + if "user-tags" in includes or "user-ratings" in includes: + return True + elif entity.startswith("collection"): + return True + else: + return False def _do_mb_query(entity, id, includes=[], params={}): - """Make a single GET call to the MusicBrainz XML API. `entity` is a - string indicated the type of object to be retrieved. The id may be - empty, in which case the query is a search. `includes` is a list - of strings that must be valid includes for the entity type. 
`params` - is a dictionary of additional parameters for the API call. The - response is parsed and returned. - """ - # Build arguments. - if not isinstance(includes, list): - includes = [includes] - _check_includes(entity, includes) - auth_required = _is_auth_required(entity, includes) - args = dict(params) - if len(includes) > 0: - inc = " ".join(includes) - args["inc"] = inc + """Make a single GET call to the MusicBrainz XML API. `entity` is a + string indicated the type of object to be retrieved. The id may be + empty, in which case the query is a search. `includes` is a list + of strings that must be valid includes for the entity type. `params` + is a dictionary of additional parameters for the API call. The + response is parsed and returned. + """ + # Build arguments. + if not isinstance(includes, list): + includes = [includes] + _check_includes(entity, includes) + auth_required = _is_auth_required(entity, includes) + args = dict(params) + if len(includes) > 0: + inc = " ".join(includes) + args["inc"] = inc - # Build the endpoint components. - path = '%s/%s' % (entity, id) - return _mb_request(path, 'GET', auth_required, args=args) + # Build the endpoint components. + path = '%s/%s' % (entity, id) + return _mb_request(path, 'GET', auth_required, args=args) def _do_mb_search(entity, query='', fields={}, - limit=None, offset=None, strict=False): - """Perform a full-text search on the MusicBrainz search server. - `query` is a lucene query string when no fields are set, - but is escaped when any fields are given. `fields` is a dictionary - of key/value query parameters. They keys in `fields` must be valid - for the given entity type. - """ - # Encode the query terms as a Lucene query string. 
- query_parts = [] - if query: - clean_query = util._unicode(query) - if fields: - clean_query = re.sub(r'([+\-&|!(){}\[\]\^"~*?:\\])', - r'\\\1', clean_query) - if strict: - query_parts.append('"%s"' % clean_query) - else: - query_parts.append(clean_query.lower()) - else: - query_parts.append(clean_query) - for key, value in fields.items(): - # Ensure this is a valid search field. - if key not in VALID_SEARCH_FIELDS[entity]: - raise InvalidSearchFieldError( - '%s is not a valid search field for %s' % (key, entity) - ) + limit=None, offset=None, strict=False): + """Perform a full-text search on the MusicBrainz search server. + `query` is a lucene query string when no fields are set, + but is escaped when any fields are given. `fields` is a dictionary + of key/value query parameters. They keys in `fields` must be valid + for the given entity type. + """ + # Encode the query terms as a Lucene query string. + query_parts = [] + if query: + clean_query = util._unicode(query) + if fields: + clean_query = re.sub(r'([+\-&|!(){}\[\]\^"~*?:\\])', + r'\\\1', clean_query) + if strict: + query_parts.append('"%s"' % clean_query) + else: + query_parts.append(clean_query.lower()) + else: + query_parts.append(clean_query) + for key, value in fields.items(): + # Ensure this is a valid search field. + if key not in VALID_SEARCH_FIELDS[entity]: + raise InvalidSearchFieldError( + '%s is not a valid search field for %s' % (key, entity) + ) - # Escape Lucene's special characters. - value = util._unicode(value) - value = re.sub(r'([+\-&|!(){}\[\]\^"~*?:\\])', r'\\\1', value) - if value: - if strict: - query_parts.append('%s:"%s"' % (key, value)) - else: - value = value.lower() # avoid AND / OR - query_parts.append('%s:(%s)' % (key, value)) - if strict: - full_query = ' AND '.join(query_parts).strip() - else: - full_query = ' '.join(query_parts).strip() + # Escape Lucene's special characters. 
+ value = util._unicode(value) + value = re.sub(r'([+\-&|!(){}\[\]\^"~*?:\\])', r'\\\1', value) + if value: + if strict: + query_parts.append('%s:"%s"' % (key, value)) + else: + value = value.lower() # avoid AND / OR + query_parts.append('%s:(%s)' % (key, value)) + if strict: + full_query = ' AND '.join(query_parts).strip() + else: + full_query = ' '.join(query_parts).strip() - if not full_query: - raise ValueError('at least one query term is required') + if not full_query: + raise ValueError('at least one query term is required') - # Additional parameters to the search. - params = {'query': full_query} - if limit: - params['limit'] = str(limit) - if offset: - params['offset'] = str(offset) + # Additional parameters to the search. + params = {'query': full_query} + if limit: + params['limit'] = str(limit) + if offset: + params['offset'] = str(offset) - return _do_mb_query(entity, '', [], params) + return _do_mb_query(entity, '', [], params) def _do_mb_delete(path): - """Send a DELETE request for the specified object. - """ - return _mb_request(path, 'DELETE', True, True) + """Send a DELETE request for the specified object. + """ + return _mb_request(path, 'DELETE', True, True) def _do_mb_put(path): - """Send a PUT request for the specified object. - """ - return _mb_request(path, 'PUT', True, True) + """Send a PUT request for the specified object. + """ + return _mb_request(path, 'PUT', True, True) def _do_mb_post(path, body): - """Perform a single POST call for an endpoint with a specified - request body. - """ - return _mb_request(path, 'POST', True, True, body=body) + """Perform a single POST call for an endpoint with a specified + request body. + """ + return _mb_request(path, 'POST', True, True, body=body) # The main interface! 
# Single entity by ID def get_artist_by_id(id, includes=[], release_status=[], release_type=[]): - params = _check_filter_and_make_params("artist", includes, release_status, release_type) - return _do_mb_query("artist", id, includes, params) + params = _check_filter_and_make_params("artist", includes, release_status, release_type) + return _do_mb_query("artist", id, includes, params) def get_label_by_id(id, includes=[], release_status=[], release_type=[]): - params = _check_filter_and_make_params("label", includes, release_status, release_type) - return _do_mb_query("label", id, includes, params) + params = _check_filter_and_make_params("label", includes, release_status, release_type) + return _do_mb_query("label", id, includes, params) def get_recording_by_id(id, includes=[], release_status=[], release_type=[]): - params = _check_filter_and_make_params("recording", includes, release_status, release_type) - return _do_mb_query("recording", id, includes, params) + params = _check_filter_and_make_params("recording", includes, release_status, release_type) + return _do_mb_query("recording", id, includes, params) def get_release_by_id(id, includes=[], release_status=[], release_type=[]): - params = _check_filter_and_make_params("release", includes, release_status, release_type) - return _do_mb_query("release", id, includes, params) + params = _check_filter_and_make_params("release", includes, release_status, release_type) + return _do_mb_query("release", id, includes, params) def get_release_group_by_id(id, includes=[], release_status=[], release_type=[]): - params = _check_filter_and_make_params("release-group", includes, release_status, release_type) - return _do_mb_query("release-group", id, includes, params) + params = _check_filter_and_make_params("release-group", includes, release_status, release_type) + return _do_mb_query("release-group", id, includes, params) def get_work_by_id(id, includes=[]): - return _do_mb_query("work", id, includes) + return 
_do_mb_query("work", id, includes) # Searching def search_artists(query='', limit=None, offset=None, strict=False, **fields): - """Search for artists by a free-form `query` string or any of - the following keyword arguments specifying field queries: - arid, artist, sortname, type, begin, end, comment, alias, country, - gender, tag - When `fields` are set, special lucene characters are escaped - in the `query`. - """ - return _do_mb_search('artist', query, fields, limit, offset, strict) + """Search for artists by a free-form `query` string or any of + the following keyword arguments specifying field queries: + arid, artist, sortname, type, begin, end, comment, alias, country, + gender, tag + When `fields` are set, special lucene characters are escaped + in the `query`. + """ + return _do_mb_search('artist', query, fields, limit, offset, strict) def search_labels(query='', limit=None, offset=None, strict=False, **fields): - """Search for labels by a free-form `query` string or any of - the following keyword arguments specifying field queries: - laid, label, sortname, type, code, country, begin, end, comment, - alias, tag - When `fields` are set, special lucene characters are escaped - in the `query`. - """ - return _do_mb_search('label', query, fields, limit, offset, strict) + """Search for labels by a free-form `query` string or any of + the following keyword arguments specifying field queries: + laid, label, sortname, type, code, country, begin, end, comment, + alias, tag + When `fields` are set, special lucene characters are escaped + in the `query`. 
+ """ + return _do_mb_search('label', query, fields, limit, offset, strict) def search_recordings(query='', limit=None, offset=None, strict=False, **fields): - """Search for recordings by a free-form `query` string or any of - the following keyword arguments specifying field queries: - rid, recording, isrc, arid, artist, artistname, creditname, reid, - release, type, status, tracks, tracksrelease, dur, qdur, tnum, - position, tag - When `fields` are set, special lucene characters are escaped - in the `query`. - """ - return _do_mb_search('recording', query, fields, limit, offset, strict) + """Search for recordings by a free-form `query` string or any of + the following keyword arguments specifying field queries: + rid, recording, isrc, arid, artist, artistname, creditname, reid, + release, type, status, tracks, tracksrelease, dur, qdur, tnum, + position, tag + When `fields` are set, special lucene characters are escaped + in the `query`. + """ + return _do_mb_search('recording', query, fields, limit, offset, strict) def search_releases(query='', limit=None, offset=None, strict=False, **fields): - """Search for releases by a free-form `query` string or any of - the following keyword arguments specifying field queries: - reid, release, arid, artist, artistname, creditname, type, status, - tracks, tracksmedium, discids, discidsmedium, mediums, date, asin, - lang, script, country, date, label, catno, barcode, puid - When `fields` are set, special lucene characters are escaped - in the `query`. - """ - return _do_mb_search('release', query, fields, limit, offset, strict) + """Search for releases by a free-form `query` string or any of + the following keyword arguments specifying field queries: + reid, release, arid, artist, artistname, creditname, type, status, + tracks, tracksmedium, discids, discidsmedium, mediums, date, asin, + lang, script, country, date, label, catno, barcode, puid + When `fields` are set, special lucene characters are escaped + in the `query`. 
+ """ + return _do_mb_search('release', query, fields, limit, offset, strict) def search_release_groups(query='', limit=None, offset=None, - strict=False, **fields): - """Search for release groups by a free-form `query` string or - any of the following keyword arguments specifying field queries: - rgid, releasegroup, reid, release, arid, artist, artistname, - creditname, type, tag - When `fields` are set, special lucene characters are escaped - in the `query`. - """ - return _do_mb_search('release-group', query, fields, - limit, offset, strict) + strict=False, **fields): + """Search for release groups by a free-form `query` string or + any of the following keyword arguments specifying field queries: + rgid, releasegroup, reid, release, arid, artist, artistname, + creditname, type, tag + When `fields` are set, special lucene characters are escaped + in the `query`. + """ + return _do_mb_search('release-group', query, fields, + limit, offset, strict) def search_works(query='', limit=None, offset=None, strict=False, **fields): - """Search for works by a free-form `query` string or any of - the following keyword arguments specifying field queries: - wid, work, iswc, type, arid, artist, alias, tag - When `fields` are set, special lucene characters are escaped - in the `query`. - """ - return _do_mb_search('work', query, fields, limit, offset, strict) + """Search for works by a free-form `query` string or any of + the following keyword arguments specifying field queries: + wid, work, iswc, type, arid, artist, alias, tag + When `fields` are set, special lucene characters are escaped + in the `query`. 
+ """ + return _do_mb_search('work', query, fields, limit, offset, strict) # Lists of entities def get_releases_by_discid(id, includes=[], release_status=[], release_type=[]): - params = _check_filter_and_make_params(includes, release_status, release_type=release_type) - return _do_mb_query("discid", id, includes, params) + params = _check_filter_and_make_params(includes, release_status, release_type=release_type) + return _do_mb_query("discid", id, includes, params) def get_recordings_by_echoprint(echoprint, includes=[], release_status=[], release_type=[]): - params = _check_filter_and_make_params(includes, release_status, release_type) - return _do_mb_query("echoprint", echoprint, includes, params) + params = _check_filter_and_make_params(includes, release_status, release_type) + return _do_mb_query("echoprint", echoprint, includes, params) def get_recordings_by_puid(puid, includes=[], release_status=[], release_type=[]): - params = _check_filter_and_make_params(includes, release_status, release_type) - return _do_mb_query("puid", puid, includes, params) + params = _check_filter_and_make_params(includes, release_status, release_type) + return _do_mb_query("puid", puid, includes, params) def get_recordings_by_isrc(isrc, includes=[], release_status=[], release_type=[]): - params = _check_filter_and_make_params(includes, release_status, release_type) - return _do_mb_query("isrc", isrc, includes, params) + params = _check_filter_and_make_params(includes, release_status, release_type) + return _do_mb_query("isrc", isrc, includes, params) def get_works_by_iswc(iswc, includes=[]): - return _do_mb_query("iswc", iswc, includes) + return _do_mb_query("iswc", iswc, includes) def _browse_impl(entity, includes, valid_includes, limit, offset, params, release_status=[], release_type=[]): _check_includes_impl(includes, valid_includes) @@ -784,20 +793,20 @@ def browse_release_groups(artist=None, release=None, release_type=[], includes=[ # Collections def get_collections(): - # 
Missing the count in the reply - return _do_mb_query("collection", '') + # Missing the count in the reply + return _do_mb_query("collection", '') def get_releases_in_collection(collection): - return _do_mb_query("collection", "%s/releases" % collection) + return _do_mb_query("collection", "%s/releases" % collection) # Submission methods def submit_barcodes(barcodes): - """Submits a set of {release1: barcode1, release2:barcode2} + """Submits a set of {release1: barcode1, release2:barcode2} - Must call auth(user, pass) first""" - query = mbxml.make_barcode_request(barcodes) - return _do_mb_post("release", query) + Must call auth(user, pass) first""" + query = mbxml.make_barcode_request(barcodes) + return _do_mb_post("release", query) def submit_puids(puids): """Submit PUIDs. @@ -815,7 +824,7 @@ def submit_echoprints(echoprints): def submit_isrcs(recordings_isrcs): """Submit ISRCs. - Submits a set of {recording-id: [isrc1, isrc1, ...]} + Submits a set of {recording-id: [isrc1, isrc2, ...]} Must call auth(user, pass) first""" query = mbxml.make_isrc_request(recordings_isrcs=recordings_isrcs) From 2141ab45b71eb7fbe1c53d42bdc1f7342b683599 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Sat, 28 Jul 2012 23:52:24 +0530 Subject: [PATCH 13/84] Updated mako libs to 0.7.2 --- mako/__init__.py | 4 +- mako/_ast_util.py | 2 +- mako/ast.py | 74 ++-- mako/cache.py | 336 ++++++++++++------ mako/codegen.py | 724 +++++++++++++++++++++++--------------- mako/exceptions.py | 137 +++++--- mako/ext/autohandler.py | 2 +- mako/ext/babelplugin.py | 5 +- mako/ext/beaker_cache.py | 70 ++++ mako/ext/preprocessors.py | 2 +- mako/ext/pygmentplugin.py | 51 ++- mako/ext/turbogears.py | 5 +- mako/filters.py | 23 +- mako/lexer.py | 253 +++++++------ mako/lookup.py | 261 +++++++------- mako/parsetree.py | 294 +++++++++------- mako/pygen.py | 124 ++++--- mako/pyparser.py | 64 ++-- mako/runtime.py | 440 ++++++++++++++--------- mako/template.py | 558 +++++++++++++++++------------ mako/util.py | 146 ++++++-- 
21 files changed, 2188 insertions(+), 1387 deletions(-) create mode 100644 mako/ext/beaker_cache.py diff --git a/mako/__init__.py b/mako/__init__.py index c0f78adc..a16564be 100644 --- a/mako/__init__.py +++ b/mako/__init__.py @@ -1,9 +1,9 @@ # mako/__init__.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -__version__ = '0.4.1' +__version__ = '0.7.2' diff --git a/mako/_ast_util.py b/mako/_ast_util.py index 9521ccbb..a1bd54c4 100644 --- a/mako/_ast_util.py +++ b/mako/_ast_util.py @@ -1,5 +1,5 @@ # mako/_ast_util.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php diff --git a/mako/ast.py b/mako/ast.py index 4365b0b1..76311e9d 100644 --- a/mako/ast.py +++ b/mako/ast.py @@ -1,10 +1,10 @@ # mako/ast.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""utilities for analyzing expressions and blocks of Python +"""utilities for analyzing expressions and blocks of Python code, as well as generating Python from AST nodes""" from mako import exceptions, pyparser, util @@ -14,20 +14,23 @@ class PythonCode(object): """represents information about a string containing Python code""" def __init__(self, code, **exception_kwargs): self.code = code - - # represents all identifiers which are assigned to at some point in the code - self.declared_identifiers = set() - - # represents all identifiers which are referenced before their assignment, if any - self.undeclared_identifiers = set() - - # note 
that an identifier can be in both the undeclared and declared lists. - # using AST to parse instead of using code.co_varnames, + # represents all identifiers which are assigned to at some point in + # the code + self.declared_identifiers = set() + + # represents all identifiers which are referenced before their + # assignment, if any + self.undeclared_identifiers = set() + + # note that an identifier can be in both the undeclared and declared + # lists. + + # using AST to parse instead of using code.co_varnames, # code.co_names has several advantages: - # - we can locate an identifier as "undeclared" even if + # - we can locate an identifier as "undeclared" even if # its declared later in the same block of code - # - AST is less likely to break with version changes + # - AST is less likely to break with version changes # (for example, the behavior of co_names changed a little bit # in python version 2.5) if isinstance(code, basestring): @@ -56,11 +59,12 @@ class ArgumentList(object): f = pyparser.FindTuple(self, PythonCode, **exception_kwargs) f.visit(expr) - + class PythonFragment(PythonCode): - """extends PythonCode to provide identifier lookups in partial control statements - - e.g. + """extends PythonCode to provide identifier lookups in partial control + statements + + e.g. 
for x in 5: elif y==9: except (MyException, e): @@ -70,8 +74,8 @@ class PythonFragment(PythonCode): m = re.match(r'^(\w+)(?:\s+(.*?))?:\s*(#|$)', code.strip(), re.S) if not m: raise exceptions.CompileException( - "Fragment '%s' is not a partial control statement" % - code, **exception_kwargs) + "Fragment '%s' is not a partial control statement" % + code, **exception_kwargs) if m.group(3): code = code[:m.start(3)] (keyword, expr) = m.group(1,2) @@ -83,33 +87,36 @@ class PythonFragment(PythonCode): code = "if False:pass\n" + code + "pass" elif keyword == 'except': code = "try:pass\n" + code + "pass" + elif keyword == 'with': + code = code + "pass" else: raise exceptions.CompileException( - "Unsupported control keyword: '%s'" % + "Unsupported control keyword: '%s'" % keyword, **exception_kwargs) super(PythonFragment, self).__init__(code, **exception_kwargs) - - + + class FunctionDecl(object): """function declaration""" def __init__(self, code, allow_kwargs=True, **exception_kwargs): self.code = code expr = pyparser.parse(code, "exec", **exception_kwargs) - + f = pyparser.ParseFunc(self, **exception_kwargs) f.visit(expr) if not hasattr(self, 'funcname'): raise exceptions.CompileException( - "Code '%s' is not a function declaration" % code, - **exception_kwargs) + "Code '%s' is not a function declaration" % code, + **exception_kwargs) if not allow_kwargs and self.kwargs: raise exceptions.CompileException( - "'**%s' keyword argument not allowed here" % + "'**%s' keyword argument not allowed here" % self.argnames[-1], **exception_kwargs) - + def get_argument_expressions(self, include_defaults=True): - """return the argument declarations of this FunctionDecl as a printable list.""" - + """return the argument declarations of this FunctionDecl as a printable + list.""" + namedecls = [] defaults = [d for d in self.defaults] kwargs = self.kwargs @@ -127,8 +134,8 @@ class FunctionDecl(object): else: default = len(defaults) and defaults.pop() or None if include_defaults and 
default: - namedecls.insert(0, "%s=%s" % - (arg, + namedecls.insert(0, "%s=%s" % + (arg, pyparser.ExpressionGenerator(default).value() ) ) @@ -138,6 +145,7 @@ class FunctionDecl(object): class FunctionArgs(FunctionDecl): """the argument portion of a function declaration""" - + def __init__(self, code, **kwargs): - super(FunctionArgs, self).__init__("def ANON(%s):pass" % code, **kwargs) + super(FunctionArgs, self).__init__("def ANON(%s):pass" % code, + **kwargs) diff --git a/mako/cache.py b/mako/cache.py index ce73ae5c..f50ce58a 100644 --- a/mako/cache.py +++ b/mako/cache.py @@ -1,124 +1,236 @@ # mako/cache.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from mako import exceptions +from mako import exceptions, util -cache = None +_cache_plugins = util.PluginLoader("mako.cache") + +register_plugin = _cache_plugins.register +register_plugin("beaker", "mako.ext.beaker_cache", "BeakerCacheImpl") -class BeakerMissing(object): - def get_cache(self, name, **kwargs): - raise exceptions.RuntimeException("the Beaker package is required to use cache functionality.") class Cache(object): """Represents a data content cache made available to the module - space of a :class:`.Template` object. - - :class:`.Cache` is a wrapper on top of a Beaker CacheManager object. - This object in turn references any number of "containers", each of - which defines its own backend (i.e. file, memory, memcached, etc.) - independently of the rest. - - """ - - def __init__(self, id, starttime): - self.id = id - self.starttime = starttime - self.def_regions = {} - - def put(self, key, value, **kwargs): - """Place a value in the cache. - - :param key: the value's key. - :param value: the value - :param \**kwargs: cache configuration arguments. 
The - backend is configured using these arguments upon first request. - Subsequent requests that use the same series of configuration - values will use that same backend. - - """ - - defname = kwargs.pop('defname', None) - expiretime = kwargs.pop('expiretime', None) - createfunc = kwargs.pop('createfunc', None) - - self._get_cache(defname, **kwargs).put_value(key, starttime=self.starttime, expiretime=expiretime) - - def get(self, key, **kwargs): - """Retrieve a value from the cache. - - :param key: the value's key. - :param \**kwargs: cache configuration arguments. The - backend is configured using these arguments upon first request. - Subsequent requests that use the same series of configuration - values will use that same backend. - - """ - - defname = kwargs.pop('defname', None) - expiretime = kwargs.pop('expiretime', None) - createfunc = kwargs.pop('createfunc', None) - - return self._get_cache(defname, **kwargs).get_value(key, starttime=self.starttime, expiretime=expiretime, createfunc=createfunc) - - def invalidate(self, key, **kwargs): - """Invalidate a value in the cache. - - :param key: the value's key. - :param \**kwargs: cache configuration arguments. The - backend is configured using these arguments upon first request. - Subsequent requests that use the same series of configuration - values will use that same backend. - - """ - defname = kwargs.pop('defname', None) - expiretime = kwargs.pop('expiretime', None) - createfunc = kwargs.pop('createfunc', None) - - self._get_cache(defname, **kwargs).remove_value(key, starttime=self.starttime, expiretime=expiretime) - - def invalidate_body(self): - """Invalidate the cached content of the "body" method for this template. 
- - """ - self.invalidate('render_body', defname='render_body') - - def invalidate_def(self, name): - """Invalidate the cached content of a particular <%def> within this template.""" - - self.invalidate('render_%s' % name, defname='render_%s' % name) - - def invalidate_closure(self, name): - """Invalidate a nested <%def> within this template. - - Caching of nested defs is a blunt tool as there is no - management of scope - nested defs that use cache tags - need to have names unique of all other nested defs in the - template, else their content will be overwritten by - each other. - - """ - - self.invalidate(name, defname=name) - - def _get_cache(self, defname, type=None, **kw): - global cache - if not cache: - try: - from beaker import cache as beaker_cache - cache = beaker_cache.CacheManager() - except ImportError: - # keep a fake cache around so subsequent - # calls don't attempt to re-import - cache = BeakerMissing() + space of a specific :class:`.Template` object. - if type == 'memcached': - type = 'ext:memcached' - if not type: - (type, kw) = self.def_regions.get(defname, ('memory', {})) + .. versionadded:: 0.6 + :class:`.Cache` by itself is mostly a + container for a :class:`.CacheImpl` object, which implements + a fixed API to provide caching services; specific subclasses exist to + implement different + caching strategies. Mako includes a backend that works with + the Beaker caching system. Beaker itself then supports + a number of backends (i.e. file, memory, memcached, etc.) + + The construction of a :class:`.Cache` is part of the mechanics + of a :class:`.Template`, and programmatic access to this + cache is typically via the :attr:`.Template.cache` attribute. + + """ + + impl = None + """Provide the :class:`.CacheImpl` in use by this :class:`.Cache`. + + This accessor allows a :class:`.CacheImpl` with additional + methods beyond that of :class:`.Cache` to be used programmatically. + + """ + + id = None + """Return the 'id' that identifies this cache. 
+ + This is a value that should be globally unique to the + :class:`.Template` associated with this cache, and can + be used by a caching system to name a local container + for data specific to this template. + + """ + + starttime = None + """Epochal time value for when the owning :class:`.Template` was + first compiled. + + A cache implementation may wish to invalidate data earlier than + this timestamp; this has the effect of the cache for a specific + :class:`.Template` starting clean any time the :class:`.Template` + is recompiled, such as when the original template file changed on + the filesystem. + + """ + + def __init__(self, template, *args): + # check for a stale template calling the + # constructor + if isinstance(template, basestring) and args: + return + self.template = template + self.id = template.module.__name__ + self.starttime = template.module._modified_time + self._def_regions = {} + self.impl = self._load_impl(self.template.cache_impl) + + def _load_impl(self, name): + return _cache_plugins.load(name)(self) + + def get_or_create(self, key, creation_function, **kw): + """Retrieve a value from the cache, using the given creation function + to generate a new value.""" + + return self._ctx_get_or_create(key, creation_function, None, **kw) + + def _ctx_get_or_create(self, key, creation_function, context, **kw): + """Retrieve a value from the cache, using the given creation function + to generate a new value.""" + + if not self.template.cache_enabled: + return creation_function() + + return self.impl.get_or_create(key, + creation_function, + **self._get_cache_kw(kw, context)) + + def set(self, key, value, **kw): + """Place a value in the cache. + + :param key: the value's key. + :param value: the value. + :param \**kw: cache configuration arguments. + + """ + + self.impl.set(key, value, **self._get_cache_kw(kw, None)) + + put = set + """A synonym for :meth:`.Cache.set`. + + This is here for backwards compatibility. 
+ + """ + + def get(self, key, **kw): + """Retrieve a value from the cache. + + :param key: the value's key. + :param \**kw: cache configuration arguments. The + backend is configured using these arguments upon first request. + Subsequent requests that use the same series of configuration + values will use that same backend. + + """ + return self.impl.get(key, **self._get_cache_kw(kw, None)) + + def invalidate(self, key, **kw): + """Invalidate a value in the cache. + + :param key: the value's key. + :param \**kw: cache configuration arguments. The + backend is configured using these arguments upon first request. + Subsequent requests that use the same series of configuration + values will use that same backend. + + """ + self.impl.invalidate(key, **self._get_cache_kw(kw, None)) + + def invalidate_body(self): + """Invalidate the cached content of the "body" method for this + template. + + """ + self.invalidate('render_body', __M_defname='render_body') + + def invalidate_def(self, name): + """Invalidate the cached content of a particular ``<%def>`` within this + template. + + """ + + self.invalidate('render_%s' % name, __M_defname='render_%s' % name) + + def invalidate_closure(self, name): + """Invalidate a nested ``<%def>`` within this template. + + Caching of nested defs is a blunt tool as there is no + management of scope -- nested defs that use cache tags + need to have names unique of all other nested defs in the + template, else their content will be overwritten by + each other. 
+ + """ + + self.invalidate(name, __M_defname=name) + + def _get_cache_kw(self, kw, context): + defname = kw.pop('__M_defname', None) + if not defname: + tmpl_kw = self.template.cache_args.copy() + tmpl_kw.update(kw) + elif defname in self._def_regions: + tmpl_kw = self._def_regions[defname] else: - self.def_regions[defname] = (type, kw) - return cache.get_cache(self.id, type=type, **kw) - \ No newline at end of file + tmpl_kw = self.template.cache_args.copy() + tmpl_kw.update(kw) + self._def_regions[defname] = tmpl_kw + if context and self.impl.pass_context: + tmpl_kw = tmpl_kw.copy() + tmpl_kw.setdefault('context', context) + return tmpl_kw + +class CacheImpl(object): + """Provide a cache implementation for use by :class:`.Cache`.""" + + def __init__(self, cache): + self.cache = cache + + pass_context = False + """If ``True``, the :class:`.Context` will be passed to + :meth:`get_or_create <.CacheImpl.get_or_create>` as the name ``'context'``. + """ + + def get_or_create(self, key, creation_function, **kw): + """Retrieve a value from the cache, using the given creation function + to generate a new value. + + This function *must* return a value, either from + the cache, or via the given creation function. + If the creation function is called, the newly + created value should be populated into the cache + under the given key before being returned. + + :param key: the value's key. + :param creation_function: function that when called generates + a new value. + :param \**kw: cache configuration arguments. + + """ + raise NotImplementedError() + + def set(self, key, value, **kw): + """Place a value in the cache. + + :param key: the value's key. + :param value: the value. + :param \**kw: cache configuration arguments. + + """ + raise NotImplementedError() + + def get(self, key, **kw): + """Retrieve a value from the cache. + + :param key: the value's key. + :param \**kw: cache configuration arguments. 
+ + """ + raise NotImplementedError() + + def invalidate(self, key, **kw): + """Invalidate a value in the cache. + + :param key: the value's key. + :param \**kw: cache configuration arguments. + + """ + raise NotImplementedError() diff --git a/mako/codegen.py b/mako/codegen.py index 53691807..3cec0eec 100644 --- a/mako/codegen.py +++ b/mako/codegen.py @@ -1,67 +1,79 @@ # mako/codegen.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""provides functionality for rendering a parsetree constructing into module source code.""" +"""provides functionality for rendering a parsetree constructing into module +source code.""" import time import re from mako.pygen import PythonPrinter from mako import util, ast, parsetree, filters, exceptions -MAGIC_NUMBER = 6 +MAGIC_NUMBER = 8 -def compile(node, - uri, - filename=None, - default_filters=None, - buffer_filters=None, - imports=None, - source_encoding=None, +# names which are hardwired into the +# template and are not accessed via the +# context itself +RESERVED_NAMES = set(['context', 'loop', 'UNDEFINED']) + +def compile(node, + uri, + filename=None, + default_filters=None, + buffer_filters=None, + imports=None, + source_encoding=None, generate_magic_comment=True, disable_unicode=False, - strict_undefined=False): - - """Generate module source code given a parsetree node, + strict_undefined=False, + enable_loop=True, + reserved_names=()): + + """Generate module source code given a parsetree node, uri, and optional source filename""" # if on Py2K, push the "source_encoding" string to be - # a bytestring itself, as we will be embedding it into - # the generated source and we don't want to coerce the + # a bytestring itself, as we will be embedding it into + # the generated source and we don't want to coerce the # result into a 
unicode object, in "disable_unicode" mode if not util.py3k and isinstance(source_encoding, unicode): source_encoding = source_encoding.encode(source_encoding) - - + + buf = util.FastEncodingBuffer() printer = PythonPrinter(buf) - _GenerateRenderMethod(printer, - _CompileContext(uri, - filename, - default_filters, + _GenerateRenderMethod(printer, + _CompileContext(uri, + filename, + default_filters, buffer_filters, - imports, + imports, source_encoding, generate_magic_comment, disable_unicode, - strict_undefined), + strict_undefined, + enable_loop, + reserved_names), node) return buf.getvalue() class _CompileContext(object): - def __init__(self, - uri, - filename, - default_filters, - buffer_filters, - imports, - source_encoding, + def __init__(self, + uri, + filename, + default_filters, + buffer_filters, + imports, + source_encoding, generate_magic_comment, disable_unicode, - strict_undefined): + strict_undefined, + enable_loop, + reserved_names): self.uri = uri self.filename = filename self.default_filters = default_filters @@ -71,11 +83,13 @@ class _CompileContext(object): self.generate_magic_comment = generate_magic_comment self.disable_unicode = disable_unicode self.strict_undefined = strict_undefined - + self.enable_loop = enable_loop + self.reserved_names = reserved_names + class _GenerateRenderMethod(object): - """A template visitor object which generates the + """A template visitor object which generates the full module source for a template. 
- + """ def __init__(self, printer, compiler, node): self.printer = printer @@ -83,13 +97,13 @@ class _GenerateRenderMethod(object): self.compiler = compiler self.node = node self.identifier_stack = [None] - + self.in_def = isinstance(node, (parsetree.DefTag, parsetree.BlockTag)) if self.in_def: name = "render_%s" % node.funcname args = node.get_argument_expressions() - filtered = len(node.filter_args.args) > 0 + filtered = len(node.filter_args.args) > 0 buffered = eval(node.attributes.get('buffered', 'False')) cached = eval(node.attributes.get('cached', 'False')) defs = None @@ -105,6 +119,10 @@ class _GenerateRenderMethod(object): if not pagetag.body_decl.kwargs: args += ['**pageargs'] cached = eval(pagetag.attributes.get('cached', 'False')) + self.compiler.enable_loop = self.compiler.enable_loop or eval( + pagetag.attributes.get( + 'enable_loop', 'False') + ) else: args = ['**pageargs'] cached = False @@ -113,24 +131,24 @@ class _GenerateRenderMethod(object): args = ['context'] else: args = [a for a in ['context'] + args] - + self.write_render_callable( - pagetag or node, - name, args, + pagetag or node, + name, args, buffered, filtered, cached) - + if defs is not None: for node in defs: _GenerateRenderMethod(printer, compiler, node) - + @property def identifiers(self): return self.identifier_stack[-1] - + def write_toplevel(self): """Traverse a template structure for module-level directives and generate the start of module-level code. 
- + """ inherit = [] namespaces = {} @@ -138,7 +156,7 @@ class _GenerateRenderMethod(object): encoding =[None] self.compiler.pagetag = None - + class FindTopLevel(object): def visitInheritTag(s, node): inherit.append(node) @@ -149,7 +167,7 @@ class _GenerateRenderMethod(object): def visitCode(s, node): if node.ismodule: module_code.append(node) - + f = FindTopLevel() for n in self.node.nodes: n.accept_visitor(f) @@ -160,41 +178,40 @@ class _GenerateRenderMethod(object): for n in module_code: module_ident = module_ident.union(n.declared_identifiers()) - module_identifiers = _Identifiers() + module_identifiers = _Identifiers(self.compiler) module_identifiers.declared = module_ident - + # module-level names, python code if self.compiler.generate_magic_comment and \ self.compiler.source_encoding: self.printer.writeline("# -*- encoding:%s -*-" % self.compiler.source_encoding) - + self.printer.writeline("from mako import runtime, filters, cache") self.printer.writeline("UNDEFINED = runtime.UNDEFINED") self.printer.writeline("__M_dict_builtin = dict") self.printer.writeline("__M_locals_builtin = locals") self.printer.writeline("_magic_number = %r" % MAGIC_NUMBER) self.printer.writeline("_modified_time = %r" % time.time()) + self.printer.writeline("_enable_loop = %r" % self.compiler.enable_loop) self.printer.writeline( - "_template_filename=%r" % self.compiler.filename) - self.printer.writeline("_template_uri=%r" % self.compiler.uri) + "_template_filename = %r" % self.compiler.filename) + self.printer.writeline("_template_uri = %r" % self.compiler.uri) self.printer.writeline( - "_template_cache=cache.Cache(__name__, _modified_time)") - self.printer.writeline( - "_source_encoding=%r" % self.compiler.source_encoding) + "_source_encoding = %r" % self.compiler.source_encoding) if self.compiler.imports: buf = '' for imp in self.compiler.imports: buf += imp + "\n" self.printer.writeline(imp) impcode = ast.PythonCode( - buf, - source='', lineno=0, - pos=0, + buf, + source='', 
lineno=0, + pos=0, filename='template defined imports') else: impcode = None - + main_identifiers = module_identifiers.branch(self.node) module_identifiers.topleveldefs = \ module_identifiers.topleveldefs.\ @@ -202,9 +219,9 @@ class _GenerateRenderMethod(object): module_identifiers.declared.add("UNDEFINED") if impcode: module_identifiers.declared.update(impcode.declared_identifiers) - + self.compiler.identifiers = module_identifiers - self.printer.writeline("_exports = %r" % + self.printer.writeline("_exports = %r" % [n.name for n in main_identifiers.topleveldefs.values()] ) @@ -221,25 +238,29 @@ class _GenerateRenderMethod(object): return main_identifiers.topleveldefs.values() - def write_render_callable(self, node, name, args, buffered, filtered, cached): + def write_render_callable(self, node, name, args, buffered, filtered, + cached): """write a top-level render callable. - + this could be the main render() method or that of a top-level def.""" - + if self.in_def: decorator = node.decorator if decorator: - self.printer.writeline("@runtime._decorate_toplevel(%s)" % decorator) - + self.printer.writeline( + "@runtime._decorate_toplevel(%s)" % decorator) + self.printer.writelines( "def %s(%s):" % (name, ','.join(args)), - "context.caller_stack._push_frame()", + # push new frame, assign current frame to __M_caller + "__M_caller = context.caller_stack._push_frame()", "try:" ) if buffered or filtered or cached: self.printer.writeline("context._push_buffer()") - - self.identifier_stack.append(self.compiler.identifiers.branch(self.node)) + + self.identifier_stack.append( + self.compiler.identifiers.branch(self.node)) if (not self.in_def or self.node.is_block) and '**pageargs' in args: self.identifier_stack[-1].argument_declared.add('pageargs') @@ -247,7 +268,7 @@ class _GenerateRenderMethod(object): len(self.identifiers.locally_assigned) > 0 or len(self.identifiers.argument_declared) > 0 ): - self.printer.writeline("__M_locals = __M_dict_builtin(%s)" % + 
self.printer.writeline("__M_locals = __M_dict_builtin(%s)" % ','.join([ "%s=%s" % (x, x) for x in self.identifiers.argument_declared @@ -263,12 +284,12 @@ class _GenerateRenderMethod(object): self.printer.write("\n\n") if cached: self.write_cache_decorator( - node, name, - args, buffered, + node, name, + args, buffered, self.identifiers, toplevel=True) - + def write_module_code(self, module_code): - """write module-level template code, i.e. that which + """write module-level template code, i.e. that which is enclosed in <%! %> tags in the template.""" for n in module_code: self.write_source_comment(n) @@ -276,7 +297,7 @@ class _GenerateRenderMethod(object): def write_inherit(self, node): """write the module-level inheritance-determination callable.""" - + self.printer.writelines( "def _mako_inherit(template, context):", "_mako_generate_namespaces(context)", @@ -298,7 +319,7 @@ class _GenerateRenderMethod(object): ) self.printer.writeline("def _mako_generate_namespaces(context):") - + for node in namespaces.values(): if node.attributes.has_key('import'): self.compiler.has_ns_imports = True @@ -318,7 +339,8 @@ class _GenerateRenderMethod(object): def visitDefOrBase(s, node): if node.is_anonymous: raise exceptions.CompileException( - "Can't put anonymous blocks inside <%namespace>", + "Can't put anonymous blocks inside " + "<%namespace>", **node.exception_kwargs ) self.write_inline_def(node, identifiers, nested=False) @@ -335,45 +357,51 @@ class _GenerateRenderMethod(object): if 'file' in node.parsed_attributes: self.printer.writeline( - "ns = runtime.TemplateNamespace(%r, context._clean_inheritance_tokens()," - " templateuri=%s, callables=%s, calling_uri=_template_uri)" % + "ns = runtime.TemplateNamespace(%r," + " context._clean_inheritance_tokens()," + " templateuri=%s, callables=%s, " + " calling_uri=_template_uri)" % ( - node.name, - node.parsed_attributes.get('file', 'None'), - callable_name, + node.name, + node.parsed_attributes.get('file', 'None'), + 
callable_name, ) ) elif 'module' in node.parsed_attributes: self.printer.writeline( - "ns = runtime.ModuleNamespace(%r, context._clean_inheritance_tokens()," - " callables=%s, calling_uri=_template_uri, module=%s)" % + "ns = runtime.ModuleNamespace(%r," + " context._clean_inheritance_tokens()," + " callables=%s, calling_uri=_template_uri," + " module=%s)" % ( - node.name, - callable_name, - node.parsed_attributes.get('module', 'None') + node.name, + callable_name, + node.parsed_attributes.get('module', 'None') ) ) else: self.printer.writeline( - "ns = runtime.Namespace(%r, context._clean_inheritance_tokens()," + "ns = runtime.Namespace(%r," + " context._clean_inheritance_tokens()," " callables=%s, calling_uri=_template_uri)" % ( node.name, - callable_name, + callable_name, ) ) if eval(node.attributes.get('inheritable', "False")): self.printer.writeline("context['self'].%s = ns" % (node.name)) - - self.printer.writeline("context.namespaces[(__name__, %s)] = ns" % repr(node.name)) + + self.printer.writeline( + "context.namespaces[(__name__, %s)] = ns" % repr(node.name)) self.printer.write("\n") if not len(namespaces): self.printer.writeline("pass") self.printer.writeline(None) - + def write_variable_declares(self, identifiers, toplevel=False, limit=None): """write variable declarations at the top of a function. - + the variable declarations are in the form of callable definitions for defs and/or name lookup within the function's context argument. the names declared are based @@ -382,53 +410,67 @@ class _GenerateRenderMethod(object): operation. names that are assigned within the body are assumed to be locally-scoped variables and are not separately declared. - + for def callable definitions, if the def is a top-level callable then a 'stub' callable is generated which wraps the current Context into a closure. if the def is not top-level, it is fully rendered as a local closure. 
- + """ + # collection of all defs available to us in this scope comp_idents = dict([(c.funcname, c) for c in identifiers.defs]) to_write = set() - - # write "context.get()" for all variables we are going to + + # write "context.get()" for all variables we are going to # need that arent in the namespace yet to_write = to_write.union(identifiers.undeclared) - - # write closure functions for closures that we define - # right here - to_write = to_write.union([c.funcname for c in identifiers.closuredefs.values()]) - # remove identifiers that are declared in the argument + # write closure functions for closures that we define + # right here + to_write = to_write.union( + [c.funcname for c in identifiers.closuredefs.values()]) + + # remove identifiers that are declared in the argument # signature of the callable to_write = to_write.difference(identifiers.argument_declared) - # remove identifiers that we are going to assign to. + # remove identifiers that we are going to assign to. # in this way we mimic Python's behavior, - # i.e. assignment to a variable within a block + # i.e. assignment to a variable within a block # means that variable is now a "locally declared" var, - # which cannot be referenced beforehand. + # which cannot be referenced beforehand. 
to_write = to_write.difference(identifiers.locally_declared) - + + if self.compiler.enable_loop: + has_loop = "loop" in to_write + to_write.discard("loop") + else: + has_loop = False + # if a limiting set was sent, constraint to those items in that list # (this is used for the caching decorator) if limit is not None: to_write = to_write.intersection(limit) - + if toplevel and getattr(self.compiler, 'has_ns_imports', False): self.printer.writeline("_import_ns = {}") self.compiler.has_imports = True for ident, ns in self.compiler.namespaces.iteritems(): if ns.attributes.has_key('import'): self.printer.writeline( - "_mako_get_namespace(context, %r)._populate(_import_ns, %r)" % + "_mako_get_namespace(context, %r)."\ + "_populate(_import_ns, %r)" % ( ident, re.split(r'\s*,\s*', ns.attributes['import']) )) - + + if has_loop: + self.printer.writeline( + 'loop = __M_loop = runtime.LoopStack()' + ) + for ident in to_write: if ident in comp_idents: comp = comp_idents[ident] @@ -445,26 +487,26 @@ class _GenerateRenderMethod(object): elif ident in self.compiler.namespaces: self.printer.writeline( - "%s = _mako_get_namespace(context, %r)" % + "%s = _mako_get_namespace(context, %r)" % (ident, ident) ) else: if getattr(self.compiler, 'has_ns_imports', False): if self.compiler.strict_undefined: self.printer.writelines( - "%s = _import_ns.get(%r, UNDEFINED)" % + "%s = _import_ns.get(%r, UNDEFINED)" % (ident, ident), "if %s is UNDEFINED:" % ident, "try:", "%s = context[%r]" % (ident, ident), "except KeyError:", - "raise NameError(\"'%s' is not defined\")" % + "raise NameError(\"'%s' is not defined\")" % ident, None, None ) else: self.printer.writeline( - "%s = _import_ns.get(%r, context.get(%r, UNDEFINED))" % + "%s = _import_ns.get(%r, context.get(%r, UNDEFINED))" % (ident, ident, ident)) else: if self.compiler.strict_undefined: @@ -472,7 +514,7 @@ class _GenerateRenderMethod(object): "try:", "%s = context[%r]" % (ident, ident), "except KeyError:", - "raise NameError(\"'%s' is not 
defined\")" % + "raise NameError(\"'%s' is not defined\")" % ident, None ) @@ -480,11 +522,12 @@ class _GenerateRenderMethod(object): self.printer.writeline( "%s = context.get(%r, UNDEFINED)" % (ident, ident) ) - + self.printer.writeline("__M_writer = context.writer()") - + def write_source_comment(self, node): - """write a source comment containing the line number of the corresponding template line.""" + """write a source comment containing the line number of the + corresponding template line.""" if self.last_source_line != node.lineno: self.printer.writeline("# SOURCE LINE %d" % node.lineno) self.last_source_line = node.lineno @@ -494,7 +537,7 @@ class _GenerateRenderMethod(object): funcname = node.funcname namedecls = node.get_argument_expressions() nameargs = node.get_argument_expressions(include_defaults=False) - + if not self.in_def and ( len(self.identifiers.locally_assigned) > 0 or len(self.identifiers.argument_declared) > 0): @@ -502,23 +545,27 @@ class _GenerateRenderMethod(object): else: nameargs.insert(0, 'context') self.printer.writeline("def %s(%s):" % (funcname, ",".join(namedecls))) - self.printer.writeline("return render_%s(%s)" % (funcname, ",".join(nameargs))) + self.printer.writeline( + "return render_%s(%s)" % (funcname, ",".join(nameargs))) self.printer.writeline(None) - + def write_inline_def(self, node, identifiers, nested): """write a locally-available def callable inside an enclosing def.""" namedecls = node.get_argument_expressions() - + decorator = node.decorator if decorator: - self.printer.writeline("@runtime._decorate_inline(context, %s)" % decorator) - self.printer.writeline("def %s(%s):" % (node.funcname, ",".join(namedecls))) - filtered = len(node.filter_args.args) > 0 + self.printer.writeline( + "@runtime._decorate_inline(context, %s)" % decorator) + self.printer.writeline( + "def %s(%s):" % (node.funcname, ",".join(namedecls))) + filtered = len(node.filter_args.args) > 0 buffered = eval(node.attributes.get('buffered', 'False')) 
cached = eval(node.attributes.get('cached', 'False')) self.printer.writelines( - "context.caller_stack._push_frame()", + # push new frame, assign current frame to __M_caller + "__M_caller = context.caller_stack._push_frame()", "try:" ) if buffered or filtered or cached: @@ -529,26 +576,29 @@ class _GenerateRenderMethod(object): identifiers = identifiers.branch(node, nested=nested) self.write_variable_declares(identifiers) - + self.identifier_stack.append(identifiers) for n in node.nodes: n.accept_visitor(self) self.identifier_stack.pop() - + self.write_def_finish(node, buffered, filtered, cached) self.printer.writeline(None) if cached: - self.write_cache_decorator(node, node.funcname, - namedecls, False, identifiers, + self.write_cache_decorator(node, node.funcname, + namedecls, False, identifiers, inline=True, toplevel=False) - - def write_def_finish(self, node, buffered, filtered, cached, callstack=True): - """write the end section of a rendering function, either outermost or inline. - - this takes into account if the rendering function was filtered, buffered, etc. - and closes the corresponding try: block if any, and writes code to retrieve - captured content, apply filters, send proper return value.""" - + + def write_def_finish(self, node, buffered, filtered, cached, + callstack=True): + """write the end section of a rendering function, either outermost or + inline. + + this takes into account if the rendering function was filtered, + buffered, etc. 
and closes the corresponding try: block if any, and + writes code to retrieve captured content, apply filters, send proper + return value.""" + if not buffered and not cached and not filtered: self.printer.writeline("return ''") if callstack: @@ -557,7 +607,7 @@ class _GenerateRenderMethod(object): "context.caller_stack._pop_frame()", None ) - + if buffered or filtered or cached: if buffered or cached: # in a caching scenario, don't try to get a writer @@ -570,19 +620,21 @@ class _GenerateRenderMethod(object): ) else: self.printer.writelines( - "finally:", - "__M_buf, __M_writer = context._pop_buffer_and_writer()" + "finally:", + "__M_buf, __M_writer = context._pop_buffer_and_writer()" ) - + if callstack: self.printer.writeline("context.caller_stack._pop_frame()") - + s = "__M_buf.getvalue()" if filtered: - s = self.create_filter_callable(node.filter_args.args, s, False) + s = self.create_filter_callable(node.filter_args.args, s, + False) self.printer.writeline(None) if buffered and not cached: - s = self.create_filter_callable(self.compiler.buffer_filters, s, False) + s = self.create_filter_callable(self.compiler.buffer_filters, + s, False) if buffered or cached: self.printer.writeline("return %s" % s) else: @@ -591,71 +643,81 @@ class _GenerateRenderMethod(object): "return ''" ) - def write_cache_decorator(self, node_or_pagetag, name, - args, buffered, identifiers, + def write_cache_decorator(self, node_or_pagetag, name, + args, buffered, identifiers, inline=False, toplevel=False): - """write a post-function decorator to replace a rendering + """write a post-function decorator to replace a rendering callable with a cached version of itself.""" - + self.printer.writeline("__M_%s = %s" % (name, name)) - cachekey = node_or_pagetag.parsed_attributes.get('cache_key', repr(name)) - cacheargs = {} - for arg in ( - ('cache_type', 'type'), ('cache_dir', 'data_dir'), - ('cache_timeout', 'expiretime'), ('cache_url', 'url')): - val = 
node_or_pagetag.parsed_attributes.get(arg[0], None) - if val is not None: - if arg[1] == 'expiretime': - cacheargs[arg[1]] = int(eval(val)) - else: - cacheargs[arg[1]] = val - else: - if self.compiler.pagetag is not None: - val = self.compiler.pagetag.parsed_attributes.get(arg[0], None) - if val is not None: - if arg[1] == 'expiretime': - cacheargs[arg[1]] == int(eval(val)) - else: - cacheargs[arg[1]] = val - + cachekey = node_or_pagetag.parsed_attributes.get('cache_key', + repr(name)) + + cache_args = {} + if self.compiler.pagetag is not None: + cache_args.update( + ( + pa[6:], + self.compiler.pagetag.parsed_attributes[pa] + ) + for pa in self.compiler.pagetag.parsed_attributes + if pa.startswith('cache_') and pa != 'cache_key' + ) + cache_args.update( + ( + pa[6:], + node_or_pagetag.parsed_attributes[pa] + ) for pa in node_or_pagetag.parsed_attributes + if pa.startswith('cache_') and pa != 'cache_key' + ) + if 'timeout' in cache_args: + cache_args['timeout'] = int(eval(cache_args['timeout'])) + self.printer.writeline("def %s(%s):" % (name, ','.join(args))) - + # form "arg1, arg2, arg3=arg3, arg4=arg4", etc. 
pass_args = [ - '=' in a and "%s=%s" % ((a.split('=')[0],)*2) or a + '=' in a and "%s=%s" % ((a.split('=')[0],)*2) or a for a in args ] self.write_variable_declares( - identifiers, - toplevel=toplevel, + identifiers, + toplevel=toplevel, limit=node_or_pagetag.undeclared_identifiers() ) if buffered: s = "context.get('local')."\ - "get_cached(%s, defname=%r, %screatefunc=lambda:__M_%s(%s))" % \ - (cachekey, name, - ''.join(["%s=%s, " % (k,v) for k, v in cacheargs.iteritems()]), - name, ','.join(pass_args)) + "cache._ctx_get_or_create("\ + "%s, lambda:__M_%s(%s), context, %s__M_defname=%r)" % \ + (cachekey, name, ','.join(pass_args), + ''.join(["%s=%s, " % (k,v) + for k, v in cache_args.items()]), + name + ) # apply buffer_filters - s = self.create_filter_callable(self.compiler.buffer_filters, s, False) + s = self.create_filter_callable(self.compiler.buffer_filters, s, + False) self.printer.writelines("return " + s,None) else: self.printer.writelines( "__M_writer(context.get('local')." - "get_cached(%s, defname=%r, %screatefunc=lambda:__M_%s(%s)))" % - (cachekey, name, - ''.join(["%s=%s, " % (k,v) for k, v in cacheargs.iteritems()]), - name, ','.join(pass_args)), + "cache._ctx_get_or_create("\ + "%s, lambda:__M_%s(%s), context, %s__M_defname=%r))" % + (cachekey, name, ','.join(pass_args), + ''.join(["%s=%s, " % (k,v) + for k, v in cache_args.items()]), + name, + ), "return ''", None ) def create_filter_callable(self, args, target, is_expression): - """write a filter-applying expression based on the filters - present in the given filter names, adjusting for the global + """write a filter-applying expression based on the filters + present in the given filter names, adjusting for the global 'default' filter aliases as needed.""" - + def locate_encode(name): if re.match(r'decode\..+', name): return "filters." 
+ name @@ -663,7 +725,7 @@ class _GenerateRenderMethod(object): return filters.NON_UNICODE_ESCAPES.get(name, name) else: return filters.DEFAULT_ESCAPES.get(name, name) - + if 'n' not in args: if is_expression: if self.compiler.pagetag: @@ -685,7 +747,7 @@ class _GenerateRenderMethod(object): assert e is not None target = "%s(%s)" % (e, target) return target - + def visitExpression(self, node): self.write_source_comment(node) if len(node.escapes) or \ @@ -694,25 +756,46 @@ class _GenerateRenderMethod(object): len(self.compiler.pagetag.filter_args.args) ) or \ len(self.compiler.default_filters): - - s = self.create_filter_callable(node.escapes_code.args, "%s" % node.text, True) + + s = self.create_filter_callable(node.escapes_code.args, + "%s" % node.text, True) self.printer.writeline("__M_writer(%s)" % s) else: self.printer.writeline("__M_writer(%s)" % node.text) - + def visitControlLine(self, node): if node.isend: - if not node.get_children(): - self.printer.writeline("pass") self.printer.writeline(None) + if node.has_loop_context: + self.printer.writeline('finally:') + self.printer.writeline("loop = __M_loop._exit()") + self.printer.writeline(None) else: self.write_source_comment(node) - self.printer.writeline(node.text) - + if self.compiler.enable_loop and node.keyword == 'for': + text = mangle_mako_loop(node, self.printer) + else: + text = node.text + self.printer.writeline(text) + children = node.get_children() + # this covers the three situations where we want to insert a pass: + # 1) a ternary control line with no children, + # 2) a primary control line with nothing but its own ternary + # and end control lines, and + # 3) any control line with no content other than comments + if not children or ( + util.all(isinstance(c, (parsetree.Comment, + parsetree.ControlLine)) + for c in children) and + util.all((node.is_ternary(c.keyword) or c.isend) + for c in children + if isinstance(c, parsetree.ControlLine))): + self.printer.writeline("pass") + def visitText(self, 
node): self.write_source_comment(node) self.printer.writeline("__M_writer(%s)" % repr(node.content)) - + def visitTextTag(self, node): filtered = len(node.filter_args.args) > 0 if filtered: @@ -726,46 +809,47 @@ class _GenerateRenderMethod(object): self.printer.writelines( "finally:", "__M_buf, __M_writer = context._pop_buffer_and_writer()", - "__M_writer(%s)" % + "__M_writer(%s)" % self.create_filter_callable( - node.filter_args.args, - "__M_buf.getvalue()", + node.filter_args.args, + "__M_buf.getvalue()", False), None ) - + def visitCode(self, node): if not node.ismodule: self.write_source_comment(node) self.printer.write_indented_block(node.text) if not self.in_def and len(self.identifiers.locally_assigned) > 0: - # if we are the "template" def, fudge locally + # if we are the "template" def, fudge locally # declared/modified variables into the "__M_locals" dictionary, - # which is used for def calls within the same template, + # which is used for def calls within the same template, # to simulate "enclosing scope" - self.printer.writeline('__M_locals_builtin_stored = __M_locals_builtin()') self.printer.writeline( - '__M_locals.update(__M_dict_builtin([(__M_key,' - ' __M_locals_builtin_stored[__M_key]) for ' - '__M_key in [%s] if __M_key in __M_locals_builtin_stored]))' % - ','.join([repr(x) for x in node.declared_identifiers()])) + '__M_locals_builtin_stored = __M_locals_builtin()') + self.printer.writeline( + '__M_locals.update(__M_dict_builtin([(__M_key,' + ' __M_locals_builtin_stored[__M_key]) for __M_key in' + ' [%s] if __M_key in __M_locals_builtin_stored]))' % + ','.join([repr(x) for x in node.declared_identifiers()])) def visitIncludeTag(self, node): self.write_source_comment(node) args = node.attributes.get('args') if args: self.printer.writeline( - "runtime._include_file(context, %s, _template_uri, %s)" % - (node.parsed_attributes['file'], args)) + "runtime._include_file(context, %s, _template_uri, %s)" % + (node.parsed_attributes['file'], args)) else: 
self.printer.writeline( "runtime._include_file(context, %s, _template_uri)" % (node.parsed_attributes['file'])) - + def visitNamespaceTag(self, node): pass - + def visitDefTag(self, node): pass @@ -776,9 +860,10 @@ class _GenerateRenderMethod(object): nameargs = node.get_argument_expressions(include_defaults=False) nameargs += ['**pageargs'] self.printer.writeline("if 'parent' not in context._data or " - "not hasattr(context._data['parent'], '%s'):" - % node.funcname) - self.printer.writeline("context['self'].%s(%s)" % (node.funcname, ",".join(nameargs))) + "not hasattr(context._data['parent'], '%s'):" + % node.funcname) + self.printer.writeline( + "context['self'].%s(%s)" % (node.funcname, ",".join(nameargs))) self.printer.writeline("\n") def visitCallNamespaceTag(self, node): @@ -786,18 +871,18 @@ class _GenerateRenderMethod(object): # as ensure the given namespace will be imported, # pre-import the namespace, etc. self.visitCallTag(node) - + def visitCallTag(self, node): self.printer.writeline("def ccall(caller):") export = ['body'] callable_identifiers = self.identifiers.branch(node, nested=True) body_identifiers = callable_identifiers.branch(node, nested=False) - # we want the 'caller' passed to ccall to be used - # for the body() function, but for other non-body() - # <%def>s within <%call> we want the current caller + # we want the 'caller' passed to ccall to be used + # for the body() function, but for other non-body() + # <%def>s within <%call> we want the current caller # off the call stack (if any) body_identifiers.add_declared('caller') - + self.identifier_stack.append(body_identifiers) class DefVisitor(object): def visitDefTag(s, node): @@ -810,8 +895,8 @@ class _GenerateRenderMethod(object): self.write_inline_def(node, callable_identifiers, nested=False) if not node.is_anonymous: export.append(node.funcname) - # remove defs that are within the <%call> from the "closuredefs" defined - # in the body, so they dont render twice + # remove defs that are 
within the <%call> from the + # "closuredefs" defined in the body, so they dont render twice if node.funcname in body_identifiers.closuredefs: del body_identifiers.closuredefs[node.funcname] @@ -819,11 +904,11 @@ class _GenerateRenderMethod(object): for n in node.nodes: n.accept_visitor(vis) self.identifier_stack.pop() - - bodyargs = node.body_decl.get_argument_expressions() + + bodyargs = node.body_decl.get_argument_expressions() self.printer.writeline("def body(%s):" % ','.join(bodyargs)) - - # TODO: figure out best way to specify + + # TODO: figure out best way to specify # buffering/nonbuffering (at call time would be better) buffered = False if buffered: @@ -833,11 +918,11 @@ class _GenerateRenderMethod(object): ) self.write_variable_declares(body_identifiers) self.identifier_stack.append(body_identifiers) - + for n in node.nodes: n.accept_visitor(self) self.identifier_stack.pop() - + self.write_def_finish(node, buffered, False, False, callstack=False) self.printer.writelines( None, @@ -846,15 +931,15 @@ class _GenerateRenderMethod(object): ) self.printer.writelines( - # get local reference to current caller, if any - "caller = context.caller_stack._get_caller()", # push on caller for nested call "context.caller_stack.nextcaller = " - "runtime.Namespace('caller', context, callables=ccall(caller))", + "runtime.Namespace('caller', context, " + "callables=ccall(__M_caller))", "try:") self.write_source_comment(node) self.printer.writelines( - "__M_writer(%s)" % self.create_filter_callable([], node.expression, True), + "__M_writer(%s)" % self.create_filter_callable( + [], node.expression, True), "finally:", "context.caller_stack.nextcaller = None", None @@ -862,9 +947,8 @@ class _GenerateRenderMethod(object): class _Identifiers(object): """tracks the status of identifier names as template code is rendered.""" - - def __init__(self, node=None, parent=None, nested=False): - + + def __init__(self, compiler, node=None, parent=None, nested=False): if parent is not None: 
# if we are the branch created in write_namespaces(), # we don't share any context from the main body(). @@ -872,65 +956,76 @@ class _Identifiers(object): self.declared = set() self.topleveldefs = util.SetLikeDict() else: - # things that have already been declared + # things that have already been declared # in an enclosing namespace (i.e. names we can just use) self.declared = set(parent.declared).\ - union([c.name for c in parent.closuredefs.values()]).\ - union(parent.locally_declared).\ - union(parent.argument_declared) - - # if these identifiers correspond to a "nested" - # scope, it means whatever the parent identifiers - # had as undeclared will have been declared by that parent, + union([c.name for c in parent.closuredefs.values()]).\ + union(parent.locally_declared).\ + union(parent.argument_declared) + + # if these identifiers correspond to a "nested" + # scope, it means whatever the parent identifiers + # had as undeclared will have been declared by that parent, # and therefore we have them in our scope. if nested: self.declared = self.declared.union(parent.undeclared) - + # top level defs that are available self.topleveldefs = util.SetLikeDict(**parent.topleveldefs) else: self.declared = set() self.topleveldefs = util.SetLikeDict() - - # things within this level that are referenced before they + + self.compiler = compiler + + # things within this level that are referenced before they # are declared (e.g. assigned to) self.undeclared = set() - - # things that are declared locally. some of these things - # could be in the "undeclared" list as well if they are + + # things that are declared locally. some of these things + # could be in the "undeclared" list as well if they are # referenced before declared self.locally_declared = set() - - # assignments made in explicit python blocks. - # these will be propagated to + + # assignments made in explicit python blocks. + # these will be propagated to # the context of local def calls. 
self.locally_assigned = set() - - # things that are declared in the argument + + # things that are declared in the argument # signature of the def callable self.argument_declared = set() - + # closure defs that are defined in this level self.closuredefs = util.SetLikeDict() - + self.node = node - + if node is not None: node.accept_visitor(self) - + + illegal_names = self.compiler.reserved_names.intersection( + self.locally_declared) + if illegal_names: + raise exceptions.NameConflictError( + "Reserved words declared in template: %s" % + ", ".join(illegal_names)) + + def branch(self, node, **kwargs): - """create a new Identifiers for a new Node, with + """create a new Identifiers for a new Node, with this Identifiers as the parent.""" - - return _Identifiers(node, self, **kwargs) - + + return _Identifiers(self.compiler, node, self, **kwargs) + @property def defs(self): return set(self.topleveldefs.union(self.closuredefs).values()) - + def __repr__(self): return "Identifiers(declared=%r, locally_declared=%r, "\ - "undeclared=%r, topleveldefs=%r, closuredefs=%r, argumentdeclared=%r)" %\ + "undeclared=%r, topleveldefs=%r, closuredefs=%r, "\ + "argumentdeclared=%r)" %\ ( list(self.declared), list(self.locally_declared), @@ -938,36 +1033,38 @@ class _Identifiers(object): [c.name for c in self.topleveldefs.values()], [c.name for c in self.closuredefs.values()], self.argument_declared) - + def check_declared(self, node): - """update the state of this Identifiers with the undeclared + """update the state of this Identifiers with the undeclared and declared identifiers of the given node.""" - + for ident in node.undeclared_identifiers(): - if ident != 'context' and ident not in self.declared.union(self.locally_declared): + if ident != 'context' and\ + ident not in self.declared.union(self.locally_declared): self.undeclared.add(ident) for ident in node.declared_identifiers(): self.locally_declared.add(ident) - + def add_declared(self, ident): self.declared.add(ident) if ident 
in self.undeclared: self.undeclared.remove(ident) - + def visitExpression(self, node): self.check_declared(node) - + def visitControlLine(self, node): self.check_declared(node) - + def visitCode(self, node): if not node.ismodule: self.check_declared(node) - self.locally_assigned = self.locally_assigned.union(node.declared_identifiers()) - + self.locally_assigned = self.locally_assigned.union( + node.declared_identifiers()) + def visitNamespaceTag(self, node): - # only traverse into the sub-elements of a - # <%namespace> tag if we are the branch created in + # only traverse into the sub-elements of a + # <%namespace> tag if we are the branch created in # write_namespaces() if self.node is node: for n in node.nodes: @@ -981,7 +1078,7 @@ class _Identifiers(object): (node.is_block or existing.is_block): raise exceptions.CompileException( "%%def or %%block named '%s' already " - "exists in this template." % + "exists in this template." % node.funcname, **node.exception_kwargs) def visitDefTag(self, node): @@ -991,13 +1088,15 @@ class _Identifiers(object): self._check_name_exists(self.closuredefs, node) for ident in node.undeclared_identifiers(): - if ident != 'context' and ident not in self.declared.union(self.locally_declared): + if ident != 'context' and\ + ident not in self.declared.union(self.locally_declared): self.undeclared.add(ident) - + # visit defs only one level deep if node is self.node: for ident in node.declared_identifiers(): self.argument_declared.add(ident) + for n in node.nodes: n.accept_visitor(self) @@ -1007,13 +1106,19 @@ class _Identifiers(object): if isinstance(self.node, parsetree.DefTag): raise exceptions.CompileException( - "Named block '%s' not allowed inside of def '%s'" + "Named block '%s' not allowed inside of def '%s'" % (node.name, self.node.name), **node.exception_kwargs) - elif isinstance(self.node, (parsetree.CallTag, parsetree.CallNamespaceTag)): + elif isinstance(self.node, + (parsetree.CallTag, parsetree.CallNamespaceTag)): raise 
exceptions.CompileException( - "Named block '%s' not allowed inside of <%%call> tag" + "Named block '%s' not allowed inside of <%%call> tag" % (node.name, ), **node.exception_kwargs) + for ident in node.undeclared_identifiers(): + if ident != 'context' and\ + ident not in self.declared.union(self.locally_declared): + self.undeclared.add(ident) + if not node.is_anonymous: self._check_name_exists(self.topleveldefs, node) self.undeclared.add(node.funcname) @@ -1026,19 +1131,20 @@ class _Identifiers(object): def visitIncludeTag(self, node): self.check_declared(node) - + def visitPageTag(self, node): for ident in node.declared_identifiers(): self.argument_declared.add(ident) self.check_declared(node) - + def visitCallNamespaceTag(self, node): self.visitCallTag(node) - + def visitCallTag(self, node): if node is self.node: for ident in node.undeclared_identifiers(): - if ident != 'context' and ident not in self.declared.union(self.locally_declared): + if ident != 'context' and\ + ident not in self.declared.union(self.locally_declared): self.undeclared.add(ident) for ident in node.declared_identifiers(): self.argument_declared.add(ident) @@ -1046,6 +1152,58 @@ class _Identifiers(object): n.accept_visitor(self) else: for ident in node.undeclared_identifiers(): - if ident != 'context' and ident not in self.declared.union(self.locally_declared): + if ident != 'context' and\ + ident not in self.declared.union(self.locally_declared): self.undeclared.add(ident) - + + +_FOR_LOOP = re.compile( + r'^for\s+((?:\(?)\s*[A-Za-z_][A-Za-z_0-9]*' + r'(?:\s*,\s*(?:[A-Za-z_][A-Za-z0-9_]*),??)*\s*(?:\)?))\s+in\s+(.*):' + ) + +def mangle_mako_loop(node, printer): + """converts a for loop into a context manager wrapped around a for loop + when access to the `loop` variable has been detected in the for loop body + """ + loop_variable = LoopVariable() + node.accept_visitor(loop_variable) + if loop_variable.detected: + node.nodes[-1].has_loop_context = True + match = _FOR_LOOP.match(node.text) + 
if match: + printer.writelines( + 'loop = __M_loop._enter(%s)' % match.group(2), + 'try:' + #'with __M_loop(%s) as loop:' % match.group(2) + ) + text = 'for %s in loop:' % match.group(1) + else: + raise SyntaxError("Couldn't apply loop context: %s" % node.text) + else: + text = node.text + return text + + +class LoopVariable(object): + """A node visitor which looks for the name 'loop' within undeclared + identifiers.""" + + def __init__(self): + self.detected = False + + def _loop_reference_detected(self, node): + if 'loop' in node.undeclared_identifiers(): + self.detected = True + else: + for n in node.get_children(): + n.accept_visitor(self) + + def visitControlLine(self, node): + self._loop_reference_detected(node) + + def visitCode(self, node): + self._loop_reference_detected(node) + + def visitExpression(self, node): + self._loop_reference_detected(node) diff --git a/mako/exceptions.py b/mako/exceptions.py index 491d2af5..b8d5ef33 100644 --- a/mako/exceptions.py +++ b/mako/exceptions.py @@ -1,5 +1,5 @@ # mako/exceptions.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -20,19 +20,21 @@ def _format_filepos(lineno, pos, filename): return " at line: %d char: %d" % (lineno, pos) else: return " in file '%s' at line: %d char: %d" % (filename, lineno, pos) - - + + class CompileException(MakoException): def __init__(self, message, source, lineno, pos, filename): - MakoException.__init__(self, message + _format_filepos(lineno, pos, filename)) + MakoException.__init__(self, + message + _format_filepos(lineno, pos, filename)) self.lineno =lineno self.pos = pos self.filename = filename self.source = source - + class SyntaxException(MakoException): def __init__(self, message, source, lineno, pos, filename): - MakoException.__init__(self, message + _format_filepos(lineno, 
pos, filename)) + MakoException.__init__(self, + message + _format_filepos(lineno, pos, filename)) self.lineno =lineno self.pos = pos self.filename = filename @@ -40,47 +42,50 @@ class SyntaxException(MakoException): class UnsupportedError(MakoException): """raised when a retired feature is used.""" - + +class NameConflictError(MakoException): + """raised when a reserved word is used inappropriately""" + class TemplateLookupException(MakoException): pass class TopLevelLookupException(TemplateLookupException): pass - + class RichTraceback(object): - """Pulls the current exception from the sys traceback and extracts + """Pull the current exception from the ``sys`` traceback and extracts Mako-specific template information. - + See the usage examples in :ref:`handling_exceptions`. - + """ def __init__(self, error=None, traceback=None): self.source, self.lineno = "", 0 if error is None or traceback is None: t, value, tback = sys.exc_info() - + if error is None: error = value or t - + if traceback is None: traceback = tback - + self.error = error self.records = self._init(traceback) - + if isinstance(self.error, (CompileException, SyntaxException)): import mako.template self.source = self.error.source self.lineno = self.error.lineno self._has_source = True - + self._init_message() - + @property def errorname(self): return util.exception_name(self.error) - + def _init_message(self): """Find a unicode representation of self.error""" try: @@ -101,25 +106,25 @@ class RichTraceback(object): yield (rec[4], rec[5], rec[2], rec[6]) else: yield tuple(rec[0:4]) - + @property def traceback(self): - """return a list of 4-tuple traceback records (i.e. normal python + """Return a list of 4-tuple traceback records (i.e. normal python format) with template-corresponding lines remapped to the originating template. 
- + """ return list(self._get_reformatted_records(self.records)) - + @property def reverse_records(self): return reversed(self.records) - + @property def reverse_traceback(self): - """return the same data as traceback, except in reverse order. + """Return the same data as traceback, except in reverse order. """ - + return list(self._get_reformatted_records(self.reverse_records)) def _init(self, trcback): @@ -156,7 +161,7 @@ class RichTraceback(object): line = line.decode(encoding) else: line = line.decode('ascii', 'replace') - new_trcback.append((filename, lineno, function, line, + new_trcback.append((filename, lineno, function, line, None, None, None, None)) continue @@ -177,8 +182,8 @@ class RichTraceback(object): template_line = template_lines[template_ln - 1] else: template_line = None - new_trcback.append((filename, lineno, function, - line, template_filename, template_ln, + new_trcback.append((filename, lineno, function, + line, template_filename, template_ln, template_line, template_source)) if not self.source: for l in range(len(new_trcback)-1, 0, -1): @@ -202,13 +207,13 @@ class RichTraceback(object): self.lineno = new_trcback[-1][1] return new_trcback - + def text_error_template(lookup=None): """Provides a template that renders a stack trace in a similar format to the Python interpreter, substituting source template filenames, line numbers and code for that of the originating source template, as applicable. 
- + """ import mako.template return mako.template.Template(r""" @@ -227,22 +232,33 @@ Traceback (most recent call last): ${tback.errorname}: ${tback.message} """) + +try: + from mako.ext.pygmentplugin import syntax_highlight,\ + pygments_html_formatter +except ImportError: + from mako.filters import html_escape + pygments_html_formatter = None + def syntax_highlight(filename='', language=None): + return html_escape + def html_error_template(): """Provides a template that renders a stack trace in an HTML format, providing an excerpt of code as well as substituting source template filenames, line numbers and code for that of the originating source template, as applicable. - The template's default encoding_errors value is 'htmlentityreplace'. the - template has two options. With the full option disabled, only a section of - an HTML document is returned. with the css option disabled, the default + The template's default ``encoding_errors`` value is ``'htmlentityreplace'``. The + template has two options. With the ``full`` option disabled, only a section of + an HTML document is returned. With the ``css`` option disabled, the default stylesheet won't be included. - + """ import mako.template return mako.template.Template(r""" <%! 
- from mako.exceptions import RichTraceback + from mako.exceptions import RichTraceback, syntax_highlight,\ + pygments_html_formatter %> <%page args="full=True, css=True, error=None, traceback=None"/> % if full: @@ -256,10 +272,29 @@ def html_error_template(): .stacktrace { margin:5px 5px 5px 5px; } .highlight { padding:0px 10px 0px 10px; background-color:#9F9FDF; } .nonhighlight { padding:0px; background-color:#DFDFDF; } - .sample { padding:10px; margin:10px 10px 10px 10px; font-family:monospace; } + .sample { padding:10px; margin:10px 10px 10px 10px; + font-family:monospace; } .sampleline { padding:0px 10px 0px 10px; } .sourceline { margin:5px 5px 10px 5px; font-family:monospace;} .location { font-size:80%; } + .highlight { white-space:pre; } + .sampleline { white-space:pre; } + + % if pygments_html_formatter: + ${pygments_html_formatter.get_style_defs()} + .linenos { min-width: 2.5em; text-align: right; } + pre { margin: 0; } + .syntax-highlighted { padding: 0 10px; } + .syntax-highlightedtable { border-spacing: 1px; } + .nonhighlight { border-top: 1px solid #DFDFDF; + border-bottom: 1px solid #DFDFDF; } + .stacktrace .nonhighlight { margin: 5px 15px 10px; } + .sourceline { margin: 0 0; font-family:monospace; } + .code { background-color: #F8F8F8; width: 100%; } + .error .code { background-color: #FFBDBD; } + .error .syntax-highlighted { background-color: #FFBDBD; } + % endif + % endif % if full: @@ -277,16 +312,29 @@ def html_error_template(): else: lines = None %> -

${tback.errorname}: ${tback.message}

+

${tback.errorname}: ${tback.message|h}

% if lines:
% for index in range(max(0, line-4),min(len(lines), line+5)): + <% + if pygments_html_formatter: + pygments_html_formatter.linenostart = index + 1 + %> % if index + 1 == line: -
${index + 1} ${lines[index] | h}
+ <% + if pygments_html_formatter: + old_cssclass = pygments_html_formatter.cssclass + pygments_html_formatter.cssclass = 'error ' + old_cssclass + %> + ${lines[index] | syntax_highlight(language='mako')} + <% + if pygments_html_formatter: + pygments_html_formatter.cssclass = old_cssclass + %> % else: -
${index + 1} ${lines[index] | h}
+ ${lines[index] | syntax_highlight(language='mako')} % endif % endfor
@@ -296,7 +344,13 @@ def html_error_template():
% for (filename, lineno, function, line) in tback.reverse_traceback:
${filename}, line ${lineno}:
-
${line | h}
+
+ <% + if pygments_html_formatter: + pygments_html_formatter.linenostart = lineno + %> +
${line | syntax_highlight(filename)}
+
% endfor
@@ -304,4 +358,5 @@ def html_error_template(): % endif -""", output_encoding=sys.getdefaultencoding(), encoding_errors='htmlentityreplace') +""", output_encoding=sys.getdefaultencoding(), + encoding_errors='htmlentityreplace') diff --git a/mako/ext/autohandler.py b/mako/ext/autohandler.py index 5d89ac59..93c60866 100644 --- a/mako/ext/autohandler.py +++ b/mako/ext/autohandler.py @@ -1,5 +1,5 @@ # ext/autohandler.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php diff --git a/mako/ext/babelplugin.py b/mako/ext/babelplugin.py index 6b7c1d35..65f7e02f 100644 --- a/mako/ext/babelplugin.py +++ b/mako/ext/babelplugin.py @@ -1,5 +1,5 @@ # ext/babelplugin.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -68,6 +68,9 @@ def extract_nodes(nodes, keywords, comment_tags, options): if isinstance(node, parsetree.DefTag): code = node.function_decl.code child_nodes = node.nodes + elif isinstance(node, parsetree.BlockTag): + code = node.body_decl.code + child_nodes = node.nodes elif isinstance(node, parsetree.CallTag): code = node.code.code child_nodes = node.nodes diff --git a/mako/ext/beaker_cache.py b/mako/ext/beaker_cache.py new file mode 100644 index 00000000..f0b50fac --- /dev/null +++ b/mako/ext/beaker_cache.py @@ -0,0 +1,70 @@ +"""Provide a :class:`.CacheImpl` for the Beaker caching system.""" + +from mako import exceptions + +from mako.cache import CacheImpl + +_beaker_cache = None +class BeakerCacheImpl(CacheImpl): + """A :class:`.CacheImpl` provided for the Beaker caching system. 
+ + This plugin is used by default, based on the default + value of ``'beaker'`` for the ``cache_impl`` parameter of the + :class:`.Template` or :class:`.TemplateLookup` classes. + + """ + + def __init__(self, cache): + global _beaker_cache + if _beaker_cache is None: + try: + from beaker import cache as beaker_cache + except ImportError, e: + raise exceptions.RuntimeException( + "the Beaker package is required to use cache " + "functionality.") + + if 'manager' in cache.template.cache_args: + _beaker_cache = cache.template.cache_args['manager'] + else: + _beaker_cache = beaker_cache.CacheManager() + super(BeakerCacheImpl, self).__init__(cache) + + def _get_cache(self, **kw): + expiretime = kw.pop('timeout', None) + if 'dir' in kw: + kw['data_dir'] = kw.pop('dir') + elif self.cache.template.module_directory: + kw['data_dir'] = self.cache.template.module_directory + + if 'manager' in kw: + kw.pop('manager') + + if kw.get('type') == 'memcached': + kw['type'] = 'ext:memcached' + + if 'region' in kw: + region = kw.pop('region') + cache = _beaker_cache.get_cache_region(self.cache.id, region, **kw) + else: + cache = _beaker_cache.get_cache(self.cache.id, **kw) + cache_args = {'starttime':self.cache.starttime} + if expiretime: + cache_args['expiretime'] = expiretime + return cache, cache_args + + def get_or_create(self, key, creation_function, **kw): + cache, kw = self._get_cache(**kw) + return cache.get(key, createfunc=creation_function, **kw) + + def put(self, key, value, **kw): + cache, kw = self._get_cache(**kw) + cache.put(key, value, **kw) + + def get(self, key, **kw): + cache, kw = self._get_cache(**kw) + return cache.get(key, **kw) + + def invalidate(self, key, **kw): + cache, kw = self._get_cache(**kw) + cache.remove_value(key, **kw) diff --git a/mako/ext/preprocessors.py b/mako/ext/preprocessors.py index 2c0d9935..fcc55007 100644 --- a/mako/ext/preprocessors.py +++ b/mako/ext/preprocessors.py @@ -1,5 +1,5 @@ # ext/preprocessors.py -# Copyright (C) 2006-2011 the 
Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php diff --git a/mako/ext/pygmentplugin.py b/mako/ext/pygmentplugin.py index 0ce57c47..773f47a7 100644 --- a/mako/ext/pygmentplugin.py +++ b/mako/ext/pygmentplugin.py @@ -1,23 +1,19 @@ # ext/pygmentplugin.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re -try: - set -except NameError: - from sets import Set as set - from pygments.lexers.web import \ HtmlLexer, XmlLexer, JavascriptLexer, CssLexer -from pygments.lexers.agile import PythonLexer -from pygments.lexer import Lexer, DelegatingLexer, RegexLexer, bygroups, \ - include, using, this -from pygments.token import Error, Punctuation, \ - Text, Comment, Operator, Keyword, Name, String, Number, Other, Literal -from pygments.util import html_doctype_matches, looks_like_xml +from pygments.lexers.agile import PythonLexer, Python3Lexer +from pygments.lexer import DelegatingLexer, RegexLexer, bygroups, \ + include, using +from pygments.token import \ + Text, Comment, Operator, Keyword, Name, String, Other +from pygments.formatters.html import HtmlFormatter +from pygments import highlight +from mako import util class MakoLexer(RegexLexer): name = 'Mako' @@ -30,13 +26,16 @@ class MakoLexer(RegexLexer): bygroups(Text, Comment.Preproc, Keyword, Other)), (r'(\s*)(\%(?!%))([^\n]*)(\n|\Z)', bygroups(Text, Comment.Preproc, using(PythonLexer), Other)), - (r'(\s*)(##[^\n]*)(\n|\Z)', + (r'(\s*)(##[^\n]*)(\n|\Z)', bygroups(Text, Comment.Preproc, Other)), - (r'''(?s)<%doc>.*?''', Comment.Preproc), - (r'(<%)([\w\.\:]+)', bygroups(Comment.Preproc, Name.Builtin), 'tag'), - (r'()', bygroups(Comment.Preproc, Name.Builtin, 
Comment.Preproc)), + (r'''(?s)<%doc>.*?''', Comment.Preproc), + (r'(<%)([\w\.\:]+)', + bygroups(Comment.Preproc, Name.Builtin), 'tag'), + (r'()', + bygroups(Comment.Preproc, Name.Builtin, Comment.Preproc)), (r'<%(?=([\w\.\:]+))', Comment.Preproc, 'ondeftags'), - (r'(<%(?:!?))(.*?)(%>)(?s)', bygroups(Comment.Preproc, using(PythonLexer), Comment.Preproc)), + (r'(<%(?:!?))(.*?)(%>)(?s)', + bygroups(Comment.Preproc, using(PythonLexer), Comment.Preproc)), (r'(\$\{)(.*?)(\})', bygroups(Comment.Preproc, using(PythonLexer), Comment.Preproc)), (r'''(?sx) @@ -105,3 +104,19 @@ class MakoCssLexer(DelegatingLexer): def __init__(self, **options): super(MakoCssLexer, self).__init__(CssLexer, MakoLexer, **options) + + +pygments_html_formatter = HtmlFormatter(cssclass='syntax-highlighted', + linenos=True) +def syntax_highlight(filename='', language=None): + mako_lexer = MakoLexer() + if util.py3k: + python_lexer = Python3Lexer() + else: + python_lexer = PythonLexer() + if filename.startswith('memory:') or language == 'mako': + return lambda string: highlight(string, mako_lexer, + pygments_html_formatter) + return lambda string: highlight(string, python_lexer, + pygments_html_formatter) + diff --git a/mako/ext/turbogears.py b/mako/ext/turbogears.py index f7822eea..e453ada1 100644 --- a/mako/ext/turbogears.py +++ b/mako/ext/turbogears.py @@ -1,5 +1,5 @@ # ext/turbogears.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -39,7 +39,8 @@ class TGPlugin(object): return Template(template_string, **self.tmpl_options) # Translate TG dot notation to normal / template path if '/' not in templatename: - templatename = '/' + templatename.replace('.', '/') + '.' + self.extension + templatename = '/' + templatename.replace('.', '/') + '.' 
+\ + self.extension # Lookup template return self.lookup.get_template(templatename) diff --git a/mako/filters.py b/mako/filters.py index 30c792f2..37c8fe4c 100644 --- a/mako/filters.py +++ b/mako/filters.py @@ -1,5 +1,5 @@ # mako/filters.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -11,10 +11,10 @@ from mako import util xml_escapes = { '&' : '&', - '>' : '>', - '<' : '<', + '>' : '>', + '<' : '<', '"' : '"', # also " in html-only - "'" : ''' # also ' in html-only + "'" : ''' # also ' in html-only } # XXX: " is valid in HTML and XML @@ -31,7 +31,7 @@ try: except ImportError: html_escape = legacy_html_escape - + def xml_escape(string): return re.sub(r'([&<"\'>])', lambda m: xml_escapes[m.group()], string) @@ -61,14 +61,14 @@ class Decode(object): return unicode(x, encoding=key) return decode decode = Decode() - - + + _ASCII_re = re.compile(r'\A[\x00-\x7f]*\Z') def is_ascii_str(text): return isinstance(text, str) and _ASCII_re.match(text) -################################################################ +################################################################ class XMLEntityEscaper(object): def __init__(self, codepoint2name, name2codepoint): @@ -115,7 +115,7 @@ class XMLEntityEscaper(object): | ( (?!\d) [:\w] [-.:\w]+ ) ) ;''', re.X | re.UNICODE) - + def __unescape(self, m): dval, hval, name = m.groups() if dval: @@ -128,7 +128,7 @@ class XMLEntityEscaper(object): if codepoint < 128: return chr(codepoint) return unichr(codepoint) - + def unescape(self, text): """Unescape character references. 
@@ -165,7 +165,8 @@ def htmlentityreplace_errors(ex): codecs.register_error('htmlentityreplace', htmlentityreplace_errors) -# TODO: options to make this dynamic per-compilation will be added in a later release +# TODO: options to make this dynamic per-compilation will be added in a later +# release DEFAULT_ESCAPES = { 'x':'filters.xml_escape', 'h':'filters.html_escape', diff --git a/mako/lexer.py b/mako/lexer.py index cf06bb52..267c0d13 100644 --- a/mako/lexer.py +++ b/mako/lexer.py @@ -1,5 +1,5 @@ # mako/lexer.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -13,8 +13,8 @@ from mako.pygen import adjust_whitespace _regexp_cache = {} class Lexer(object): - def __init__(self, text, filename=None, - disable_unicode=False, + def __init__(self, text, filename=None, + disable_unicode=False, input_encoding=None, preprocessor=None): self.text = text self.filename = filename @@ -25,31 +25,32 @@ class Lexer(object): self.match_position = 0 self.tag = [] self.control_line = [] + self.ternary_stack = [] self.disable_unicode = disable_unicode self.encoding = input_encoding - + if util.py3k and disable_unicode: raise exceptions.UnsupportedError( "Mako for Python 3 does not " "support disabling Unicode") - + if preprocessor is None: self.preprocessor = [] elif not hasattr(preprocessor, '__iter__'): self.preprocessor = [preprocessor] else: self.preprocessor = preprocessor - + @property def exception_kwargs(self): - return {'source':self.text, - 'lineno':self.matched_lineno, - 'pos':self.matched_charpos, + return {'source':self.text, + 'lineno':self.matched_lineno, + 'pos':self.matched_charpos, 'filename':self.filename} - + def match(self, regexp, flags=None): """compile the given regexp, cache the reg, and call match_reg().""" - + try: reg = _regexp_cache[(regexp, flags)] except 
KeyError: @@ -58,14 +59,15 @@ class Lexer(object): else: reg = re.compile(regexp) _regexp_cache[(regexp, flags)] = reg - + return self.match_reg(reg) - + def match_reg(self, reg): - """match the given regular expression object to the current text position. - + """match the given regular expression object to the current text + position. + if a match occurs, update the current text and line position. - + """ mp = self.match_position @@ -84,39 +86,43 @@ class Lexer(object): cp -=1 self.matched_charpos = mp - cp self.lineno += len(lines) - #print "MATCHED:", match.group(0), "LINE START:", + #print "MATCHED:", match.group(0), "LINE START:", # self.matched_lineno, "LINE END:", self.lineno - #print "MATCH:", regexp, "\n", self.text[mp : mp + 15], (match and "TRUE" or "FALSE") + #print "MATCH:", regexp, "\n", self.text[mp : mp + 15], \ + # (match and "TRUE" or "FALSE") return match - + def parse_until_text(self, *text): startpos = self.match_position + text_re = r'|'.join(text) + brace_level = 0 while True: match = self.match(r'#.*\n') if match: continue - match = self.match(r'(\"\"\"|\'\'\'|\"|\')') + match = self.match(r'(\"\"\"|\'\'\'|\"|\')((? 
0: + brace_level -= 1 + continue + return \ + self.text[startpos:\ + self.match_position-len(match.group(1))],\ + match.group(1) + match = self.match(r"(.*?)(?=\"|\'|#|%s)" % text_re, re.S) + if match: + brace_level += match.group(1).count('{') + brace_level -= match.group(1).count('}') + continue + raise exceptions.SyntaxException( + "Expected: %s" % + ','.join(text), + **self.exception_kwargs) + def append_node(self, nodecls, *args, **kwargs): kwargs.setdefault('source', self.text) kwargs.setdefault('lineno', self.matched_lineno) @@ -127,6 +133,17 @@ class Lexer(object): self.tag[-1].nodes.append(node) else: self.template.nodes.append(node) + # build a set of child nodes for the control line + # (used for loop variable detection) + # also build a set of child nodes on ternary control lines + # (used for determining if a pass needs to be auto-inserted + if self.control_line: + control_frame = self.control_line[-1] + control_frame.nodes.append(node) + if not (isinstance(node, parsetree.ControlLine) and + control_frame.is_ternary(node.keyword)): + if self.ternary_stack and self.ternary_stack[-1]: + self.ternary_stack[-1][-1].nodes.append(node) if isinstance(node, parsetree.Tag): if len(self.tag): node.parent = self.tag[-1] @@ -134,14 +151,19 @@ class Lexer(object): elif isinstance(node, parsetree.ControlLine): if node.isend: self.control_line.pop() + self.ternary_stack.pop() elif node.is_primary: self.control_line.append(node) - elif len(self.control_line) and \ + self.ternary_stack.append([]) + elif self.control_line and \ + self.control_line[-1].is_ternary(node.keyword): + self.ternary_stack[-1].append(node) + elif self.control_line and \ not self.control_line[-1].is_ternary(node.keyword): raise exceptions.SyntaxException( - "Keyword '%s' not a legal ternary for keyword '%s'" % - (node.keyword, self.control_line[-1].keyword), - **self.exception_kwargs) + "Keyword '%s' not a legal ternary for keyword '%s'" % + (node.keyword, self.control_line[-1].keyword), + 
**self.exception_kwargs) _coding_re = re.compile(r'#.*coding[:=]\s*([-\w.]+).*\r?\n') @@ -163,8 +185,8 @@ class Lexer(object): if m is not None and m.group(1) != 'utf-8': raise exceptions.CompileException( "Found utf-8 BOM in file, with conflicting " - "magic encoding comment of '%s'" % m.group(1), - text.decode('utf-8', 'ignore'), + "magic encoding comment of '%s'" % m.group(1), + text.decode('utf-8', 'ignore'), 0, 0, filename) else: m = self._coding_re.match(text.decode('utf-8', 'ignore')) @@ -178,32 +200,32 @@ class Lexer(object): text = text.decode(parsed_encoding) except UnicodeDecodeError, e: raise exceptions.CompileException( - "Unicode decode operation of encoding '%s' failed" % - parsed_encoding, - text.decode('utf-8', 'ignore'), - 0, 0, filename) + "Unicode decode operation of encoding '%s' failed" % + parsed_encoding, + text.decode('utf-8', 'ignore'), + 0, 0, filename) return parsed_encoding, text def parse(self): - self.encoding, self.text = self.decode_raw_stream(self.text, - not self.disable_unicode, + self.encoding, self.text = self.decode_raw_stream(self.text, + not self.disable_unicode, self.encoding, self.filename,) for preproc in self.preprocessor: self.text = preproc(self.text) - - # push the match marker past the + + # push the match marker past the # encoding comment. 
self.match_reg(self._coding_re) - + self.textlength = len(self.text) - + while (True): - if self.match_position > self.textlength: + if self.match_position > self.textlength: break - + if self.match_end(): break if self.match_expression(): @@ -212,53 +234,56 @@ class Lexer(object): continue if self.match_comment(): continue - if self.match_tag_start(): + if self.match_tag_start(): continue if self.match_tag_end(): continue if self.match_python_block(): continue - if self.match_text(): + if self.match_text(): continue - - if self.match_position > self.textlength: + + if self.match_position > self.textlength: break raise exceptions.CompileException("assertion failed") - + if len(self.tag): - raise exceptions.SyntaxException("Unclosed tag: <%%%s>" % - self.tag[-1].keyword, + raise exceptions.SyntaxException("Unclosed tag: <%%%s>" % + self.tag[-1].keyword, **self.exception_kwargs) if len(self.control_line): - raise exceptions.SyntaxException("Unterminated control keyword: '%s'" % - self.control_line[-1].keyword, - self.text, - self.control_line[-1].lineno, - self.control_line[-1].pos, self.filename) + raise exceptions.SyntaxException( + "Unterminated control keyword: '%s'" % + self.control_line[-1].keyword, + self.text, + self.control_line[-1].lineno, + self.control_line[-1].pos, self.filename) return self.template def match_tag_start(self): match = self.match(r''' \<% # opening tag - + ([\w\.\:]+) # keyword - - ((?:\s+\w+|\s*=\s*|".*?"|'.*?')*) # attrname, = sign, string expression - + + ((?:\s+\w+|\s*=\s*|".*?"|'.*?')*) # attrname, = \ + # sign, string expression + \s* # more whitespace - + (/)?> # closing - - ''', - + + ''', + re.I | re.S | re.X) - + if match: - keyword, attr, isend = match.group(1), match.group(2), match.group(3) + keyword, attr, isend = match.groups() self.keyword = keyword attributes = {} if attr: - for att in re.findall(r"\s*(\w+)\s*=\s*(?:'([^']*)'|\"([^\"]*)\")", attr): + for att in re.findall( + r"\s*(\w+)\s*=\s*(?:'([^']*)'|\"([^\"]*)\")", 
attr): key, val1, val2 = att text = val1 or val2 text = text.replace('\r\n', '\n') @@ -271,33 +296,33 @@ class Lexer(object): match = self.match(r'(.*?)(?=\)', re.S) if not match: raise exceptions.SyntaxException( - "Unclosed tag: <%%%s>" % - self.tag[-1].keyword, + "Unclosed tag: <%%%s>" % + self.tag[-1].keyword, **self.exception_kwargs) self.append_node(parsetree.Text, match.group(1)) return self.match_tag_end() return True - else: + else: return False - + def match_tag_end(self): match = self.match(r'\') if match: if not len(self.tag): raise exceptions.SyntaxException( - "Closing tag without opening tag: " % - match.group(1), - **self.exception_kwargs) + "Closing tag without opening tag: " % + match.group(1), + **self.exception_kwargs) elif self.tag[-1].keyword != match.group(1): raise exceptions.SyntaxException( - "Closing tag does not match tag: <%%%s>" % - (match.group(1), self.tag[-1].keyword), - **self.exception_kwargs) + "Closing tag does not match tag: <%%%s>" % + (match.group(1), self.tag[-1].keyword), + **self.exception_kwargs) self.tag.pop() return True else: return False - + def match_end(self): match = self.match(r'\Z', re.S) if match: @@ -308,13 +333,13 @@ class Lexer(object): return True else: return False - + def match_text(self): match = self.match(r""" (.*?) 
# anything, followed by: ( - (?<=\n)(?=[ \t]*(?=%|\#\#)) # an eval or line-based - # comment preceded by a + (?<=\n)(?=[ \t]*(?=%|\#\#)) # an eval or line-based + # comment preceded by a # consumed newline and whitespace | (?=\${) # an expression @@ -328,7 +353,7 @@ class Lexer(object): | \Z # end of string )""", re.X | re.S) - + if match: text = match.group(1) if text: @@ -336,23 +361,23 @@ class Lexer(object): return True else: return False - + def match_python_block(self): match = self.match(r"<%(!)?") if match: line, pos = self.matched_lineno, self.matched_charpos text, end = self.parse_until_text(r'%>') - # the trailing newline helps + # the trailing newline helps # compiler.parse() not complain about indentation - text = adjust_whitespace(text) + "\n" + text = adjust_whitespace(text) + "\n" self.append_node( - parsetree.Code, - text, + parsetree.Code, + text, match.group(1)=='!', lineno=line, pos=pos) return True else: return False - + def match_expression(self): match = self.match(r"\${") if match: @@ -364,15 +389,17 @@ class Lexer(object): escapes = "" text = text.replace('\r\n', '\n') self.append_node( - parsetree.Expression, - text, escapes.strip(), + parsetree.Expression, + text, escapes.strip(), lineno=line, pos=pos) return True else: return False def match_control_line(self): - match = self.match(r"(?<=^)[\t ]*(%(?!%)|##)[\t ]*((?:(?:\\r?\n)|[^\r\n])*)(?:\r?\n|\Z)", re.M) + match = self.match( + r"(?<=^)[\t ]*(%(?!%)|##)[\t ]*((?:(?:\\r?\n)|[^\r\n])*)" + r"(?:\r?\n|\Z)", re.M) if match: operator = match.group(1) text = match.group(2) @@ -380,22 +407,22 @@ class Lexer(object): m2 = re.match(r'(end)?(\w+)\s*(.*)', text) if not m2: raise exceptions.SyntaxException( - "Invalid control line: '%s'" % - text, + "Invalid control line: '%s'" % + text, **self.exception_kwargs) isend, keyword = m2.group(1, 2) isend = (isend is not None) - + if isend: if not len(self.control_line): raise exceptions.SyntaxException( - "No starting keyword '%s' for '%s'" % - 
(keyword, text), + "No starting keyword '%s' for '%s'" % + (keyword, text), **self.exception_kwargs) elif self.control_line[-1].keyword != keyword: raise exceptions.SyntaxException( - "Keyword '%s' doesn't match keyword '%s'" % - (text, self.control_line[-1].keyword), + "Keyword '%s' doesn't match keyword '%s'" % + (text, self.control_line[-1].keyword), **self.exception_kwargs) self.append_node(parsetree.ControlLine, keyword, isend, text) else: @@ -412,4 +439,4 @@ class Lexer(object): return True else: return False - + diff --git a/mako/lookup.py b/mako/lookup.py index b397d21f..4d86696a 100644 --- a/mako/lookup.py +++ b/mako/lookup.py @@ -1,5 +1,5 @@ # mako/lookup.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -12,30 +12,30 @@ try: import threading except: import dummy_threading as threading - + class TemplateCollection(object): - """Represent a collection of :class:`.Template` objects, - identifiable via uri. - + """Represent a collection of :class:`.Template` objects, + identifiable via URI. + A :class:`.TemplateCollection` is linked to the usage of - all template tags that address other templates, such + all template tags that address other templates, such as ``<%include>``, ``<%namespace>``, and ``<%inherit>``. The ``file`` attribute of each of those tags refers to a string URI that is passed to that :class:`.Template` object's :class:`.TemplateCollection` for resolution. - + :class:`.TemplateCollection` is an abstract class, with the usual default implementation being :class:`.TemplateLookup`. - + """ def has_template(self, uri): """Return ``True`` if this :class:`.TemplateLookup` is capable of returning a :class:`.Template` object for the - given URL. + given ``uri``. + + :param uri: String URI of the template to be resolved. 
- :param uri: String uri of the template to be resolved. - """ try: self.get_template(uri) @@ -44,124 +44,135 @@ class TemplateCollection(object): return False def get_template(self, uri, relativeto=None): - """Return a :class:`.Template` object corresponding to the given - URL. - + """Return a :class:`.Template` object corresponding to the given + ``uri``. + The default implementation raises :class:`.NotImplementedError`. Implementations should - raise :class:`.TemplateLookupException` if the given uri + raise :class:`.TemplateLookupException` if the given ``uri`` cannot be resolved. - - :param uri: String uri of the template to be resolved. - :param relativeto: if present, the given URI is assumed to - be relative to this uri. - + + :param uri: String URI of the template to be resolved. + :param relativeto: if present, the given ``uri`` is assumed to + be relative to this URI. + """ raise NotImplementedError() def filename_to_uri(self, uri, filename): - """Convert the given filename to a uri relative to - this TemplateCollection.""" - + """Convert the given ``filename`` to a URI relative to + this :class:`.TemplateCollection`.""" + return uri - + def adjust_uri(self, uri, filename): - """Adjust the given uri based on the calling filename. - + """Adjust the given ``uri`` based on the calling ``filename``. + When this method is called from the runtime, the - 'filename' parameter is taken directly to the 'filename' + ``filename`` parameter is taken directly to the ``filename`` attribute of the calling template. Therefore a custom - TemplateCollection subclass can place any string - identifier desired in the "filename" parameter of the - Template objects it constructs and have them come back + :class:`.TemplateCollection` subclass can place any string + identifier desired in the ``filename`` parameter of the + :class:`.Template` objects it constructs and have them come back here. 
- + """ return uri - + class TemplateLookup(TemplateCollection): """Represent a collection of templates that locates template source files from the local filesystem. - + The primary argument is the ``directories`` argument, the list of - directories to search:: - + directories to search: + + .. sourcecode:: python + lookup = TemplateLookup(["/path/to/templates"]) some_template = lookup.get_template("/index.html") - + The :class:`.TemplateLookup` can also be given :class:`.Template` objects - programatically using :meth:`.put_string` or :meth:`.put_template`:: - + programatically using :meth:`.put_string` or :meth:`.put_template`: + + .. sourcecode:: python + lookup = TemplateLookup() lookup.put_string("base.html", ''' ${self.next()} ''') lookup.put_string("hello.html", ''' <%include file='base.html'/> - + Hello, world ! ''') - - - :param directories: A list of directory names which will be + + + :param directories: A list of directory names which will be searched for a particular template URI. The URI is appended to each directory and the filesystem checked. - - :param collection_size: Approximate size of the collection used - to store templates. If left at its default of -1, the size + + :param collection_size: Approximate size of the collection used + to store templates. If left at its default of ``-1``, the size is unbounded, and a plain Python dictionary is used to relate URI strings to :class:`.Template` instances. Otherwise, a least-recently-used cache object is used which will maintain the size of the collection approximately to the number given. - - :param filesystem_checks: When at its default value of ``True``, - each call to :meth:`TemplateLookup.get_template()` will + + :param filesystem_checks: When at its default value of ``True``, + each call to :meth:`.TemplateLookup.get_template()` will compare the filesystem last modified time to the time in which an existing :class:`.Template` object was created. 
This allows the :class:`.TemplateLookup` to regenerate a new :class:`.Template` whenever the original source has been updated. Set this to ``False`` for a very minor performance increase. - - :param modulename_callable: A callable which, when present, + + :param modulename_callable: A callable which, when present, is passed the path of the source file as well as the requested URI, and then returns the full path of the generated Python module file. This is used to inject - alternate schemes for Pyhton module location. If left at + alternate schemes for Python module location. If left at its default of ``None``, the built in system of generation based on ``module_directory`` plus ``uri`` is used. - + All other keyword parameters available for :class:`.Template` are mirrored here. When new :class:`.Template` objects are created, the keywords established with this :class:`.TemplateLookup` are passed on to each new :class:`.Template`. - + """ - - def __init__(self, - directories=None, - module_directory=None, - filesystem_checks=True, - collection_size=-1, - format_exceptions=False, - error_handler=None, - disable_unicode=False, + + def __init__(self, + directories=None, + module_directory=None, + filesystem_checks=True, + collection_size=-1, + format_exceptions=False, + error_handler=None, + disable_unicode=False, bytestring_passthrough=False, - output_encoding=None, - encoding_errors='strict', - cache_type=None, - cache_dir=None, cache_url=None, - cache_enabled=True, - modulename_callable=None, - default_filters=None, - buffer_filters=(), + output_encoding=None, + encoding_errors='strict', + + cache_args=None, + cache_impl='beaker', + cache_enabled=True, + cache_type=None, + cache_dir=None, + cache_url=None, + + modulename_callable=None, + module_writer=None, + default_filters=None, + buffer_filters=(), strict_undefined=False, - imports=None, - input_encoding=None, + imports=None, + enable_loop=True, + input_encoding=None, preprocessor=None): - + self.directories = 
[posixpath.normpath(d) for d in util.to_list(directories, ()) ] @@ -170,23 +181,34 @@ class TemplateLookup(TemplateCollection): self.filesystem_checks = filesystem_checks self.collection_size = collection_size + if cache_args is None: + cache_args = {} + # transfer deprecated cache_* args + if cache_dir: + cache_args.setdefault('dir', cache_dir) + if cache_url: + cache_args.setdefault('url', cache_url) + if cache_type: + cache_args.setdefault('type', cache_type) + self.template_args = { - 'format_exceptions':format_exceptions, - 'error_handler':error_handler, - 'disable_unicode':disable_unicode, + 'format_exceptions':format_exceptions, + 'error_handler':error_handler, + 'disable_unicode':disable_unicode, 'bytestring_passthrough':bytestring_passthrough, - 'output_encoding':output_encoding, - 'encoding_errors':encoding_errors, - 'input_encoding':input_encoding, - 'module_directory':module_directory, - 'cache_type':cache_type, - 'cache_dir':cache_dir or module_directory, - 'cache_url':cache_url, - 'cache_enabled':cache_enabled, - 'default_filters':default_filters, - 'buffer_filters':buffer_filters, + 'output_encoding':output_encoding, + 'cache_impl':cache_impl, + 'encoding_errors':encoding_errors, + 'input_encoding':input_encoding, + 'module_directory':module_directory, + 'module_writer':module_writer, + 'cache_args':cache_args, + 'cache_enabled':cache_enabled, + 'default_filters':default_filters, + 'buffer_filters':buffer_filters, 'strict_undefined':strict_undefined, - 'imports':imports, + 'imports':imports, + 'enable_loop':enable_loop, 'preprocessor':preprocessor} if collection_size == -1: @@ -196,15 +218,15 @@ class TemplateLookup(TemplateCollection): self._collection = util.LRUCache(collection_size) self._uri_cache = util.LRUCache(collection_size) self._mutex = threading.Lock() - + def get_template(self, uri): - """Return a :class:`.Template` object corresponding to the given - URL. - - Note the "relativeto" argument is not supported here at the moment. 
- + """Return a :class:`.Template` object corresponding to the given + ``uri``. + + .. note:: The ``relativeto`` argument is not supported here at the moment. + """ - + try: if self.filesystem_checks: return self._check(uri, self._collection[uri]) @@ -221,25 +243,26 @@ class TemplateLookup(TemplateCollection): "Cant locate template for uri %r" % uri) def adjust_uri(self, uri, relativeto): - """adjust the given uri based on the given relative uri.""" - + """Adjust the given ``uri`` based on the given relative URI.""" + key = (uri, relativeto) if key in self._uri_cache: return self._uri_cache[key] if uri[0] != '/': if relativeto is not None: - v = self._uri_cache[key] = posixpath.join(posixpath.dirname(relativeto), uri) + v = self._uri_cache[key] = posixpath.join( + posixpath.dirname(relativeto), uri) else: v = self._uri_cache[key] = '/' + uri else: v = self._uri_cache[key] = uri return v - - + + def filename_to_uri(self, filename): - """Convert the given filename to a uri relative to - this TemplateCollection.""" + """Convert the given ``filename`` to a URI relative to + this :class:`.TemplateCollection`.""" try: return self._uri_cache[filename] @@ -247,25 +270,25 @@ class TemplateLookup(TemplateCollection): value = self._relativeize(filename) self._uri_cache[filename] = value return value - + def _relativeize(self, filename): - """Return the portion of a filename that is 'relative' + """Return the portion of a filename that is 'relative' to the directories in this lookup. 
- + """ - + filename = posixpath.normpath(filename) for dir in self.directories: if filename[0:len(dir)] == dir: return filename[len(dir):] else: return None - + def _load(self, filename, uri): self._mutex.acquire() try: try: - # try returning from collection one + # try returning from collection one # more time in case concurrent thread already loaded return self._collection[uri] except KeyError: @@ -278,19 +301,19 @@ class TemplateLookup(TemplateCollection): self._collection[uri] = template = Template( uri=uri, filename=posixpath.normpath(filename), - lookup=self, + lookup=self, module_filename=module_filename, **self.template_args) return template except: - # if compilation fails etc, ensure + # if compilation fails etc, ensure # template is removed from collection, # re-raise self._collection.pop(uri, None) raise finally: self._mutex.release() - + def _check(self, uri, template): if template.filename is None: return template @@ -308,24 +331,24 @@ class TemplateLookup(TemplateCollection): raise exceptions.TemplateLookupException( "Cant locate template for uri %r" % uri) - + def put_string(self, uri, text): """Place a new :class:`.Template` object into this :class:`.TemplateLookup`, based on the given string of - text. - + ``text``. + """ self._collection[uri] = Template( - text, - lookup=self, - uri=uri, + text, + lookup=self, + uri=uri, **self.template_args) - + def put_template(self, uri, template): """Place a new :class:`.Template` object into this :class:`.TemplateLookup`, based on the given :class:`.Template` object. 
- + """ self._collection[uri] = template - + diff --git a/mako/parsetree.py b/mako/parsetree.py index 31b9b4f0..ecd82425 100644 --- a/mako/parsetree.py +++ b/mako/parsetree.py @@ -1,5 +1,5 @@ # mako/parsetree.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -11,117 +11,125 @@ import re class Node(object): """base class for a Node in the parse tree.""" + def __init__(self, source, lineno, pos, filename): self.source = source self.lineno = lineno self.pos = pos self.filename = filename - + @property def exception_kwargs(self): - return {'source':self.source, 'lineno':self.lineno, + return {'source':self.source, 'lineno':self.lineno, 'pos':self.pos, 'filename':self.filename} - + def get_children(self): return [] - + def accept_visitor(self, visitor): def traverse(node): for n in node.get_children(): n.accept_visitor(visitor) + method = getattr(visitor, "visit" + self.__class__.__name__, traverse) method(self) class TemplateNode(Node): """a 'container' node that stores the overall collection of nodes.""" - + def __init__(self, filename): super(TemplateNode, self).__init__('', 0, 0, filename) self.nodes = [] self.page_attributes = {} - + def get_children(self): return self.nodes - + def __repr__(self): return "TemplateNode(%s, %r)" % ( - util.sorted_dict_repr(self.page_attributes), + util.sorted_dict_repr(self.page_attributes), self.nodes) - + class ControlLine(Node): """defines a control line, a line-oriented python line or end tag. 
- + e.g.:: % if foo: (markup) % endif - + """ + has_loop_context = False + def __init__(self, keyword, isend, text, **kwargs): super(ControlLine, self).__init__(**kwargs) self.text = text self.keyword = keyword self.isend = isend - self.is_primary = keyword in ['for','if', 'while', 'try'] + self.is_primary = keyword in ['for', 'if', 'while', 'try', 'with'] + self.nodes = [] if self.isend: self._declared_identifiers = [] self._undeclared_identifiers = [] else: code = ast.PythonFragment(text, **self.exception_kwargs) - self._declared_identifiers = code.declared_identifiers + self._declared_identifiers = code.declared_identifiers self._undeclared_identifiers = code.undeclared_identifiers + def get_children(self): + return self.nodes + def declared_identifiers(self): return self._declared_identifiers def undeclared_identifiers(self): return self._undeclared_identifiers - + def is_ternary(self, keyword): """return true if the given keyword is a ternary keyword for this ControlLine""" - + return keyword in { 'if':set(['else', 'elif']), 'try':set(['except', 'finally']), 'for':set(['else']) }.get(self.keyword, []) - + def __repr__(self): return "ControlLine(%r, %r, %r, %r)" % ( - self.keyword, - self.text, - self.isend, + self.keyword, + self.text, + self.isend, (self.lineno, self.pos) ) class Text(Node): """defines plain text in the template.""" - + def __init__(self, content, **kwargs): super(Text, self).__init__(**kwargs) self.content = content - + def __repr__(self): return "Text(%r, %r)" % (self.content, (self.lineno, self.pos)) - + class Code(Node): """defines a Python code block, either inline or module level. - + e.g.:: inline: <% x = 12 %> - + module level: <%! import logger %> - + """ def __init__(self, text, ismodule, **kwargs): @@ -138,32 +146,32 @@ class Code(Node): def __repr__(self): return "Code(%r, %r, %r)" % ( - self.text, - self.ismodule, + self.text, + self.ismodule, (self.lineno, self.pos) ) - + class Comment(Node): """defines a comment line. 
- + # this is a comment - + """ - + def __init__(self, text, **kwargs): super(Comment, self).__init__(**kwargs) self.text = text def __repr__(self): return "Comment(%r, %r)" % (self.text, (self.lineno, self.pos)) - + class Expression(Node): """defines an inline expression. - + ${x+y} - + """ - + def __init__(self, text, escapes, **kwargs): super(Expression, self).__init__(**kwargs) self.text = text @@ -184,74 +192,74 @@ class Expression(Node): def __repr__(self): return "Expression(%r, %r, %r)" % ( - self.text, - self.escapes_code.args, + self.text, + self.escapes_code.args, (self.lineno, self.pos) ) - + class _TagMeta(type): """metaclass to allow Tag to produce a subclass according to its keyword""" - + _classmap = {} - + def __init__(cls, clsname, bases, dict): if cls.__keyword__ is not None: cls._classmap[cls.__keyword__] = cls super(_TagMeta, cls).__init__(clsname, bases, dict) - + def __call__(cls, keyword, attributes, **kwargs): if ":" in keyword: ns, defname = keyword.split(':') - return type.__call__(CallNamespaceTag, ns, defname, + return type.__call__(CallNamespaceTag, ns, defname, attributes, **kwargs) try: cls = _TagMeta._classmap[keyword] except KeyError: raise exceptions.CompileException( - "No such tag: '%s'" % keyword, - source=kwargs['source'], - lineno=kwargs['lineno'], - pos=kwargs['pos'], + "No such tag: '%s'" % keyword, + source=kwargs['source'], + lineno=kwargs['lineno'], + pos=kwargs['pos'], filename=kwargs['filename'] ) return type.__call__(cls, keyword, attributes, **kwargs) - + class Tag(Node): """abstract base class for tags. - + <%sometag/> - + <%someothertag> stuff - + """ - + __metaclass__ = _TagMeta __keyword__ = None - - def __init__(self, keyword, attributes, expressions, + + def __init__(self, keyword, attributes, expressions, nonexpressions, required, **kwargs): """construct a new Tag instance. - + this constructor not called directly, and is only called by subclasses. 
- + :param keyword: the tag keyword - + :param attributes: raw dictionary of attribute key/value pairs - - :param expressions: a set of identifiers that are legal attributes, + + :param expressions: a set of identifiers that are legal attributes, which can also contain embedded expressions - - :param nonexpressions: a set of identifiers that are legal + + :param nonexpressions: a set of identifiers that are legal attributes, which cannot contain embedded expressions - + :param \**kwargs: other arguments passed to the Node superclass (lineno, pos) - + """ super(Tag, self).__init__(**kwargs) self.keyword = keyword @@ -260,18 +268,18 @@ class Tag(Node): missing = [r for r in required if r not in self.parsed_attributes] if len(missing): raise exceptions.CompileException( - "Missing attribute(s): %s" % - ",".join([repr(m) for m in missing]), + "Missing attribute(s): %s" % + ",".join([repr(m) for m in missing]), **self.exception_kwargs) self.parent = None self.nodes = [] - + def is_root(self): return self.parent is None - + def get_children(self): return self.nodes - + def _parse_attributes(self, expressions, nonexpressions): undeclared_identifiers = set() self.parsed_attributes = {} @@ -285,8 +293,8 @@ class Tag(Node): code = ast.PythonCode(m.group(1).rstrip(), **self.exception_kwargs) # we aren't discarding "declared_identifiers" here, - # which we do so that list comprehension-declared - # variables aren't counted. As yet can't find a + # which we do so that list comprehension-declared + # variables aren't counted. As yet can't find a # condition that requires it here. 
undeclared_identifiers = \ undeclared_identifiers.union( @@ -299,14 +307,14 @@ class Tag(Node): elif key in nonexpressions: if re.search(r'\${.+?}', self.attributes[key]): raise exceptions.CompileException( - "Attibute '%s' in tag '%s' does not allow embedded " - "expressions" % (key, self.keyword), - **self.exception_kwargs) + "Attibute '%s' in tag '%s' does not allow embedded " + "expressions" % (key, self.keyword), + **self.exception_kwargs) self.parsed_attributes[key] = repr(self.attributes[key]) else: raise exceptions.CompileException( "Invalid attribute for tag '%s': '%s'" % - (self.keyword, key), + (self.keyword, key), **self.exception_kwargs) self.expression_undeclared_identifiers = undeclared_identifiers @@ -317,21 +325,21 @@ class Tag(Node): return self.expression_undeclared_identifiers def __repr__(self): - return "%s(%r, %s, %r, %r)" % (self.__class__.__name__, - self.keyword, + return "%s(%r, %s, %r, %r)" % (self.__class__.__name__, + self.keyword, util.sorted_dict_repr(self.attributes), - (self.lineno, self.pos), + (self.lineno, self.pos), self.nodes ) - + class IncludeTag(Tag): __keyword__ = 'include' def __init__(self, keyword, attributes, **kwargs): super(IncludeTag, self).__init__( - keyword, - attributes, - ('file', 'import', 'args'), + keyword, + attributes, + ('file', 'import', 'args'), (), ('file',), **kwargs) self.page_args = ast.PythonCode( "__DUMMY(%s)" % attributes.get('args', ''), @@ -346,18 +354,18 @@ class IncludeTag(Tag): difference(self.page_args.declared_identifiers) return identifiers.union(super(IncludeTag, self). 
undeclared_identifiers()) - + class NamespaceTag(Tag): __keyword__ = 'namespace' def __init__(self, keyword, attributes, **kwargs): super(NamespaceTag, self).__init__( - keyword, attributes, - ('file',), + keyword, attributes, + ('file',), ('name','inheritable', - 'import','module'), + 'import','module'), (), **kwargs) - + self.name = attributes.get('name', '__anon_%s' % hex(abs(id(self)))) if not 'name' in attributes and not 'import' in attributes: raise exceptions.CompileException( @@ -378,36 +386,39 @@ class TextTag(Tag): def __init__(self, keyword, attributes, **kwargs): super(TextTag, self).__init__( - keyword, - attributes, (), + keyword, + attributes, (), ('filter'), (), **kwargs) self.filter_args = ast.ArgumentList( - attributes.get('filter', ''), + attributes.get('filter', ''), **self.exception_kwargs) - + class DefTag(Tag): __keyword__ = 'def' def __init__(self, keyword, attributes, **kwargs): + expressions = ['buffered', 'cached'] + [ + c for c in attributes if c.startswith('cache_')] + + super(DefTag, self).__init__( - keyword, - attributes, - ('buffered', 'cached', 'cache_key', 'cache_timeout', - 'cache_type', 'cache_dir', 'cache_url'), - ('name','filter', 'decorator'), - ('name',), + keyword, + attributes, + expressions, + ('name','filter', 'decorator'), + ('name',), **kwargs) name = attributes['name'] if re.match(r'^[\w_]+$',name): raise exceptions.CompileException( - "Missing parenthesis in %def", + "Missing parenthesis in %def", **self.exception_kwargs) - self.function_decl = ast.FunctionDecl("def " + name + ":pass", + self.function_decl = ast.FunctionDecl("def " + name + ":pass", **self.exception_kwargs) self.name = self.function_decl.funcname self.decorator = attributes.get('decorator', '') self.filter_args = ast.ArgumentList( - attributes.get('filter', ''), + attributes.get('filter', ''), **self.exception_kwargs) is_anonymous = False @@ -428,40 +439,47 @@ class DefTag(Tag): for c in self.function_decl.defaults: res += list(ast.PythonCode(c, 
**self.exception_kwargs). undeclared_identifiers) - return res + list(self.filter_args.\ + return set(res).union( + self.filter_args.\ undeclared_identifiers.\ difference(filters.DEFAULT_ESCAPES.keys()) - ) + ).union( + self.expression_undeclared_identifiers + ).difference( + self.function_decl.argnames + ) class BlockTag(Tag): __keyword__ = 'block' def __init__(self, keyword, attributes, **kwargs): + expressions = ['buffered', 'cached', 'args'] + [ + c for c in attributes if c.startswith('cache_')] + super(BlockTag, self).__init__( - keyword, - attributes, - ('buffered', 'cached', 'cache_key', 'cache_timeout', - 'cache_type', 'cache_dir', 'cache_url', 'args'), - ('name','filter', 'decorator'), - (), + keyword, + attributes, + expressions, + ('name','filter', 'decorator'), + (), **kwargs) name = attributes.get('name') if name and not re.match(r'^[\w_]+$',name): raise exceptions.CompileException( - "%block may not specify an argument signature", - **self.exception_kwargs) + "%block may not specify an argument signature", + **self.exception_kwargs) if not name and attributes.get('args', None): raise exceptions.CompileException( "Only named %blocks may specify args", **self.exception_kwargs ) - self.body_decl = ast.FunctionArgs(attributes.get('args', ''), + self.body_decl = ast.FunctionArgs(attributes.get('args', ''), **self.exception_kwargs) self.name = name self.decorator = attributes.get('decorator', '') self.filter_args = ast.ArgumentList( - attributes.get('filter', ''), + attributes.get('filter', ''), **self.exception_kwargs) @@ -482,17 +500,22 @@ class BlockTag(Tag): return self.body_decl.argnames def undeclared_identifiers(self): - return [] + return (self.filter_args.\ + undeclared_identifiers.\ + difference(filters.DEFAULT_ESCAPES.keys()) + ).union(self.expression_undeclared_identifiers) + + class CallTag(Tag): __keyword__ = 'call' def __init__(self, keyword, attributes, **kwargs): - super(CallTag, self).__init__(keyword, attributes, + super(CallTag, 
self).__init__(keyword, attributes, ('args'), ('expr',), ('expr',), **kwargs) self.expression = attributes['expr'] self.code = ast.PythonCode(self.expression, **self.exception_kwargs) - self.body_decl = ast.FunctionArgs(attributes.get('args', ''), + self.body_decl = ast.FunctionArgs(attributes.get('args', ''), **self.exception_kwargs) def declared_identifiers(self): @@ -506,23 +529,23 @@ class CallNamespaceTag(Tag): def __init__(self, namespace, defname, attributes, **kwargs): super(CallNamespaceTag, self).__init__( - namespace + ":" + defname, - attributes, - tuple(attributes.keys()) + ('args', ), - (), - (), + namespace + ":" + defname, + attributes, + tuple(attributes.keys()) + ('args', ), + (), + (), **kwargs) - + self.expression = "%s.%s(%s)" % ( - namespace, - defname, + namespace, + defname, ",".join(["%s=%s" % (k, v) for k, v in - self.parsed_attributes.iteritems() + self.parsed_attributes.iteritems() if k != 'args']) ) self.code = ast.PythonCode(self.expression, **self.exception_kwargs) self.body_decl = ast.FunctionArgs( - attributes.get('args', ''), + attributes.get('args', ''), **self.exception_kwargs) def declared_identifiers(self): @@ -537,23 +560,24 @@ class InheritTag(Tag): def __init__(self, keyword, attributes, **kwargs): super(InheritTag, self).__init__( - keyword, attributes, + keyword, attributes, ('file',), (), ('file',), **kwargs) class PageTag(Tag): __keyword__ = 'page' def __init__(self, keyword, attributes, **kwargs): + expressions = ['cached', 'args', 'expression_filter', 'enable_loop'] + [ + c for c in attributes if c.startswith('cache_')] + super(PageTag, self).__init__( - keyword, - attributes, - ('cached', 'cache_key', 'cache_timeout', - 'cache_type', 'cache_dir', 'cache_url', - 'args', 'expression_filter'), - (), - (), + keyword, + attributes, + expressions, + (), + (), **kwargs) - self.body_decl = ast.FunctionArgs(attributes.get('args', ''), + self.body_decl = ast.FunctionArgs(attributes.get('args', ''), **self.exception_kwargs) 
self.filter_args = ast.ArgumentList( attributes.get('expression_filter', ''), @@ -561,5 +585,5 @@ class PageTag(Tag): def declared_identifiers(self): return self.body_decl.argnames - - + + diff --git a/mako/pygen.py b/mako/pygen.py index 07f26670..e946de50 100644 --- a/mako/pygen.py +++ b/mako/pygen.py @@ -1,5 +1,5 @@ # mako/pygen.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -14,48 +14,48 @@ class PythonPrinter(object): def __init__(self, stream): # indentation counter self.indent = 0 - - # a stack storing information about why we incremented + + # a stack storing information about why we incremented # the indentation counter, to help us determine if we # should decrement it self.indent_detail = [] - + # the string of whitespace multiplied by the indent # counter to produce a line self.indentstring = " " - + # the stream we are writing to self.stream = stream - + # a list of lines that represents a buffered "block" of code, - # which can be later printed relative to an indent level + # which can be later printed relative to an indent level self.line_buffer = [] - + self.in_indent_lines = False - + self._reset_multi_line_flags() def write(self, text): self.stream.write(text) - + def write_indented_block(self, block): """print a line or lines of python which already contain indentation. - + The indentation of the total block of lines will be adjusted to that of - the current indent level.""" + the current indent level.""" self.in_indent_lines = False for l in re.split(r'\r?\n', block): self.line_buffer.append(l) - + def writelines(self, *lines): """print a series of lines of python.""" for line in lines: self.writeline(line) - + def writeline(self, line): """print a line of python, indenting it according to the current indent level. 
- + this also adjusts the indentation counter according to the content of the line. @@ -65,9 +65,7 @@ class PythonPrinter(object): self._flush_adjusted_lines() self.in_indent_lines = True - decreased_indent = False - - if (line is None or + if (line is None or re.match(r"^\s*#",line) or re.match(r"^\s*$", line) ): @@ -76,31 +74,30 @@ class PythonPrinter(object): hastext = True is_comment = line and len(line) and line[0] == '#' - + # see if this line should decrease the indentation level - if (not decreased_indent and - not is_comment and + if (not is_comment and (not hastext or self._is_unindentor(line)) ): - - if self.indent > 0: + + if self.indent > 0: self.indent -=1 # if the indent_detail stack is empty, the user # probably put extra closures - the resulting - # module wont compile. - if len(self.indent_detail) == 0: + # module wont compile. + if len(self.indent_detail) == 0: raise exceptions.SyntaxException( "Too many whitespace closures") self.indent_detail.pop() - + if line is None: return - + # write the line self.stream.write(self._indent_line(line) + "\n") - + # see if this line should increase the indentation level. - # note that a line can both decrase (before printing) and + # note that a line can both decrease (before printing) and # then increase (after printing) the indentation level. if re.search(r":[ \t]*(?:#.*)?$", line): @@ -108,7 +105,7 @@ class PythonPrinter(object): # keep track of what the keyword was that indented us, # if it is a python compound statement keyword # where we might have to look for an "unindent" keyword - match = re.match(r"^\s*(if|try|elif|while|for)", line) + match = re.match(r"^\s*(if|try|elif|while|for|with)", line) if match: # its a "compound" keyword, so we will check for "unindentors" indentor = match.group(1) @@ -119,7 +116,8 @@ class PythonPrinter(object): # its not a "compound" keyword.
but lets also # test for valid Python keywords that might be indenting us, # else assume its a non-indenting line - m2 = re.match(r"^\s*(def|class|else|elif|except|finally)", line) + m2 = re.match(r"^\s*(def|class|else|elif|except|finally)", + line) if m2: self.indent += 1 self.indent_detail.append(indentor) @@ -127,53 +125,53 @@ class PythonPrinter(object): def close(self): """close this printer, flushing any remaining lines.""" self._flush_adjusted_lines() - + def _is_unindentor(self, line): - """return true if the given line is an 'unindentor', + """return true if the given line is an 'unindentor', relative to the last 'indent' event received. - + """ - + # no indentation detail has been pushed on; return False - if len(self.indent_detail) == 0: + if len(self.indent_detail) == 0: return False indentor = self.indent_detail[-1] - - # the last indent keyword we grabbed is not a + + # the last indent keyword we grabbed is not a # compound statement keyword; return False - if indentor is None: + if indentor is None: return False - + # if the current line doesnt have one of the "unindentor" keywords, # return False match = re.match(r"^\s*(else|elif|except|finally).*\:", line) - if not match: + if not match: return False - + # whitespace matches up, we have a compound indentor, # and this line has an unindentor, this # is probably good enough return True - + # should we decide that its not good enough, heres # more stuff to check. #keyword = match.group(1) - - # match the original indent keyword + + # match the original indent keyword #for crit in [ # (r'if|elif', r'else|elif'), # (r'try', r'except|finally|else'), # (r'while|for', r'else'), #]: - # if re.match(crit[0], indentor) and re.match(crit[1], keyword): + # if re.match(crit[0], indentor) and re.match(crit[1], keyword): # return True - + #return False - + def _indent_line(self, line, stripspace=''): """indent the given line according to the current indent level. 
- + stripspace is a string of space that will be truncated from the start of the line before indenting.""" @@ -185,7 +183,7 @@ class PythonPrinter(object): or triple-quoted section.""" self.backslashed, self.triplequoted = False, False - + def _in_multi_line(self, line): """return true if the given line is part of a multi-line block, via backslash or triple-quote.""" @@ -195,24 +193,24 @@ class PythonPrinter(object): # guard against the possibility of modifying the space inside of # a literal multiline string with unfortunately placed # whitespace - - current_state = (self.backslashed or self.triplequoted) - + + current_state = (self.backslashed or self.triplequoted) + if re.search(r"\\$", line): self.backslashed = True else: self.backslashed = False - + triples = len(re.findall(r"\"\"\"|\'\'\'", line)) if triples == 1 or triples % 2 != 0: self.triplequoted = not self.triplequoted - + return current_state def _flush_adjusted_lines(self): stripspace = None self._reset_multi_line_flags() - + for entry in self.line_buffer: if self._in_multi_line(entry): self.stream.write(entry + "\n") @@ -221,32 +219,32 @@ class PythonPrinter(object): if stripspace is None and re.search(r"^[ \t]*[^# \t]", entry): stripspace = re.match(r"^([ \t]*)", entry).group(1) self.stream.write(self._indent_line(entry, stripspace) + "\n") - + self.line_buffer = [] self._reset_multi_line_flags() def adjust_whitespace(text): """remove the left-whitespace margin of a block of Python code.""" - + state = [False, False] (backslashed, triplequoted) = (0, 1) def in_multi_line(line): start_state = (state[backslashed] or state[triplequoted]) - + if re.search(r"\\$", line): state[backslashed] = True else: state[backslashed] = False - + def match(reg, t): m = re.match(reg, t) if m: return m, t[len(m.group(0)):] else: return None, t - + while line: if state[triplequoted]: m, line = match(r"%s" % state[triplequoted], line) @@ -258,14 +256,14 @@ def adjust_whitespace(text): m, line = match(r'#', line) if m: 
return start_state - + m, line = match(r"\"\"\"|\'\'\'", line) if m: state[triplequoted] = m.group(0) continue m, line = match(r".*?(?=\"\"\"|\'\'\'|#|$)", line) - + return start_state def _indent_line(line, stripspace = ''): diff --git a/mako/pyparser.py b/mako/pyparser.py index 953596af..1f39756e 100644 --- a/mako/pyparser.py +++ b/mako/pyparser.py @@ -1,5 +1,5 @@ # mako/pyparser.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -15,17 +15,17 @@ from mako import exceptions, util import operator if util.py3k: - # words that cannot be assigned to (notably + # words that cannot be assigned to (notably # smaller than the total keys in __builtins__) reserved = set(['True', 'False', 'None', 'print']) # the "id" attribute on a function node arg_id = operator.attrgetter('arg') else: - # words that cannot be assigned to (notably + # words that cannot be assigned to (notably # smaller than the total keys in __builtins__) reserved = set(['True', 'False', 'None']) - + # the "id" attribute on a function node arg_id = operator.attrgetter('id') @@ -42,7 +42,7 @@ except ImportError: def parse(code, mode='exec', **exception_kwargs): """Parse an expression into AST""" - + try: if _ast: @@ -54,8 +54,8 @@ def parse(code, mode='exec', **exception_kwargs): except Exception, e: raise exceptions.SyntaxException( "(%s) %s (%r)" % ( - e.__class__.__name__, - e, + e.__class__.__name__, + e, code[0:50] ), **exception_kwargs) @@ -66,13 +66,15 @@ if _ast: def __init__(self, listener, **exception_kwargs): self.in_function = False self.in_assign_targets = False - self.local_ident_stack = {} + self.local_ident_stack = set() self.listener = listener self.exception_kwargs = exception_kwargs def _add_declared(self, name): if not self.in_function: self.listener.declared_identifiers.add(name) + else: + 
self.local_ident_stack.add(name) def visit_ClassDef(self, node): self._add_declared(node.name) @@ -118,23 +120,20 @@ if _ast: # argument names in each function header so they arent # counted as "undeclared" - saved = {} inf = self.in_function self.in_function = True - for arg in node.args.args: - if arg_id(arg) in self.local_ident_stack: - saved[arg_id(arg)] = True - else: - self.local_ident_stack[arg_id(arg)] = True + + local_ident_stack = self.local_ident_stack + self.local_ident_stack = local_ident_stack.union([ + arg_id(arg) for arg in node.args.args + ]) if islambda: self.visit(node.body) else: for n in node.body: self.visit(n) self.in_function = inf - for arg in node.args.args: - if arg_id(arg) not in saved: - del self.local_ident_stack[arg_id(arg)] + self.local_ident_stack = local_ident_stack def visit_For(self, node): @@ -149,8 +148,10 @@ if _ast: def visit_Name(self, node): if isinstance(node.ctx, _ast.Store): + # this is equivalent to visit_AssName in + # compiler self._add_declared(node.id) - if node.id not in reserved and node.id \ + elif node.id not in reserved and node.id \ not in self.listener.declared_identifiers and node.id \ not in self.local_ident_stack: self.listener.undeclared_identifiers.add(node.id) @@ -228,13 +229,15 @@ else: def __init__(self, listener, **exception_kwargs): self.in_function = False - self.local_ident_stack = {} + self.local_ident_stack = set() self.listener = listener self.exception_kwargs = exception_kwargs def _add_declared(self, name): if not self.in_function: self.listener.declared_identifiers.add(name) + else: + self.local_ident_stack.add(name) def visitClass(self, node, *args): self._add_declared(node.name) @@ -247,7 +250,6 @@ else: # flip around the visiting of Assign so the expression gets # evaluated first, in the case of a clause like "x=x+5" (x # is undeclared) - self.visit(node.expr, *args) for n in node.nodes: self.visit(n, *args) @@ -267,20 +269,18 @@ else: # argument names in each function header so they
arent # counted as "undeclared" - saved = {} inf = self.in_function self.in_function = True - for arg in node.argnames: - if arg in self.local_ident_stack: - saved[arg] = True - else: - self.local_ident_stack[arg] = True + + local_ident_stack = self.local_ident_stack + self.local_ident_stack = local_ident_stack.union([ + arg for arg in node.argnames + ]) + for n in node.getChildNodes(): self.visit(n, *args) self.in_function = inf - for arg in node.argnames: - if arg not in saved: - del self.local_ident_stack[arg] + self.local_ident_stack = local_ident_stack def visitFor(self, node, *args): @@ -333,9 +333,11 @@ else: self.listener.codeargs.append(p) self.listener.args.append(ExpressionGenerator(n).value()) self.listener.declared_identifiers = \ - self.listener.declared_identifiers.union(p.declared_identifiers) + self.listener.declared_identifiers.union( + p.declared_identifiers) self.listener.undeclared_identifiers = \ - self.listener.undeclared_identifiers.union(p.undeclared_identifiers) + self.listener.undeclared_identifiers.union( + p.undeclared_identifiers) def visit(self, expr): visitor.walk(expr, self) # , walker=walker()) diff --git a/mako/runtime.py b/mako/runtime.py index dfd701aa..f890c809 100644 --- a/mako/runtime.py +++ b/mako/runtime.py @@ -1,5 +1,5 @@ # mako/runtime.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -10,65 +10,75 @@ Namespace, and various helper functions.""" from mako import exceptions, util import __builtin__, inspect, sys + class Context(object): """Provides runtime namespace, output buffer, and various callstacks for templates. - - See :ref:`runtime_toplevel` for detail on the usage of + + See :ref:`runtime_toplevel` for detail on the usage of :class:`.Context`. 
- + """ - + def __init__(self, buffer, **data): self._buffer_stack = [buffer] - + self._data = data + self._kwargs = data.copy() self._with_template = None self._outputting_as_unicode = None self.namespaces = {} - - # "capture" function which proxies to the + + # "capture" function which proxies to the # generic "capture" function self._data['capture'] = util.partial(capture, self) - + # "caller" stack used by def calls with content self.caller_stack = self._data['caller'] = CallerStack() - + + def _set_with_template(self, t): + self._with_template = t + illegal_names = t.reserved_names.intersection(self._data) + if illegal_names: + raise exceptions.NameConflictError( + "Reserved words passed to render(): %s" % + ", ".join(illegal_names)) + @property def lookup(self): - """Return the :class:`.TemplateLookup` associated + """Return the :class:`.TemplateLookup` associated with this :class:`.Context`. - + """ return self._with_template.lookup - + @property def kwargs(self): - """Return the dictionary of keyword argments associated with this + """Return the dictionary of keyword arguments associated with this :class:`.Context`. 
- + """ return self._kwargs.copy() - + def push_caller(self, caller): - """Pushes a 'caller' callable onto the callstack for + """Push a ``caller`` callable onto the callstack for this :class:`.Context`.""" - - + + self.caller_stack.append(caller) - + def pop_caller(self): - """Pops a 'caller' callable onto the callstack for this + """Pop a ``caller`` callable onto the callstack for this :class:`.Context`.""" del self.caller_stack[-1] - + def keys(self): """Return a list of all names established in this :class:`.Context`.""" return self._data.keys() - + def __getitem__(self, key): if key in self._data: return self._data[key] @@ -78,45 +88,45 @@ class Context(object): def _push_writer(self): """push a capturing buffer onto this Context and return the new writer function.""" - + buf = util.FastEncodingBuffer() self._buffer_stack.append(buf) return buf.write def _pop_buffer_and_writer(self): - """pop the most recent capturing buffer from this Context + """pop the most recent capturing buffer from this Context and return the current writer after the pop. 
- + """ buf = self._buffer_stack.pop() return buf, self._buffer_stack[-1].write - + def _push_buffer(self): """push a capturing buffer onto this Context.""" - + self._push_writer() - + def _pop_buffer(self): """pop the most recent capturing buffer from this Context.""" - + return self._buffer_stack.pop() - + def get(self, key, default=None): """Return a value from this :class:`.Context`.""" - - return self._data.get(key, + + return self._data.get(key, __builtin__.__dict__.get(key, default) ) - + def write(self, string): """Write a string to this :class:`.Context` object's underlying output buffer.""" - + self._buffer_stack[-1].write(string) - + def writer(self): - """Return the current writer function""" + """Return the current writer function.""" return self._buffer_stack[-1].write @@ -130,17 +140,17 @@ class Context(object): c.namespaces = self.namespaces c.caller_stack = self.caller_stack return c - + def locals_(self, d): - """create a new :class:`.Context` with a copy of this - :class:`Context`'s current state, updated with the given dictionary.""" - + """Create a new :class:`.Context` with a copy of this + :class:`.Context`'s current state, updated with the given dictionary.""" + if len(d) == 0: return self c = self._copy() c._data.update(d) return c - + def _clean_inheritance_tokens(self): """create a new copy of this :class:`.Context`. 
with tokens related to inheritance state removed.""" @@ -158,23 +168,27 @@ class CallerStack(list): def __nonzero__(self): return self._get_caller() and True or False def _get_caller(self): + # this method can be removed once + # codegen MAGIC_NUMBER moves past 7 return self[-1] def __getattr__(self, key): return getattr(self._get_caller(), key) def _push_frame(self): - self.append(self.nextcaller or None) + frame = self.nextcaller or None + self.append(frame) self.nextcaller = None + return frame def _pop_frame(self): self.nextcaller = self.pop() - - + + class Undefined(object): """Represents an undefined value in a template. - - All template modules have a constant value + + All template modules have a constant value ``UNDEFINED`` present which is an instance of this object. - + """ def __str__(self): raise NameError("Undefined") @@ -183,6 +197,110 @@ class Undefined(object): UNDEFINED = Undefined() +class LoopStack(object): + """a stack for LoopContexts that implements the context manager protocol + to automatically pop off the top of the stack on context exit + """ + + def __init__(self): + self.stack = [] + + def _enter(self, iterable): + self._push(iterable) + return self._top + + def _exit(self): + self._pop() + return self._top + + @property + def _top(self): + if self.stack: + return self.stack[-1] + else: + return self + + def _pop(self): + return self.stack.pop() + + def _push(self, iterable): + new = LoopContext(iterable) + if self.stack: + new.parent = self.stack[-1] + return self.stack.append(new) + + def __getattr__(self, key): + raise exceptions.RuntimeException("No loop context is established") + + def __iter__(self): + return iter(self._top) + + +class LoopContext(object): + """A magic loop variable. + Automatically accessible in any ``% for`` block. + + See the section :ref:`loop_context` for usage + notes. + + :attr:`parent` -> :class:`.LoopContext` or ``None`` + The parent loop, if one exists. 
+ :attr:`index` -> `int` + The 0-based iteration count. + :attr:`reverse_index` -> `int` + The number of iterations remaining. + :attr:`first` -> `bool` + ``True`` on the first iteration, ``False`` otherwise. + :attr:`last` -> `bool` + ``True`` on the last iteration, ``False`` otherwise. + :attr:`even` -> `bool` + ``True`` when ``index`` is even. + :attr:`odd` -> `bool` + ``True`` when ``index`` is odd. + """ + + def __init__(self, iterable): + self._iterable = iterable + self.index = 0 + self.parent = None + + def __iter__(self): + for i in self._iterable: + yield i + self.index += 1 + + @util.memoized_instancemethod + def __len__(self): + return len(self._iterable) + + @property + def reverse_index(self): + return len(self) - self.index - 1 + + @property + def first(self): + return self.index == 0 + + @property + def last(self): + return self.index == len(self) - 1 + + @property + def even(self): + return not self.odd + + @property + def odd(self): + return bool(self.index % 2) + + def cycle(self, *values): + """Cycle through values as the loop progresses. + """ + if not values: + raise ValueError("You must provide values to cycle through") + return values[self.index % len(values)] + + class _NSAttr(object): def __init__(self, parent): self.__parent = parent @@ -193,24 +311,26 @@ class _NSAttr(object): return getattr(ns.module, key) else: ns = ns.inherits - raise AttributeError(key) - + raise AttributeError(key) + class Namespace(object): - """Provides access to collections of rendering methods, which + """Provides access to collections of rendering methods, which can be local, from other templates, or from imported modules. - - To access a particular rendering method referenced by a - :class:`.Namespace`, use plain attribute access:: - + + To access a particular rendering method referenced by a + :class:`.Namespace`, use plain attribute access: + + .. 
sourcecode:: mako + ${some_namespace.foo(x, y, z)} - - :class:`.Namespace` also contains several built-in attributes + + :class:`.Namespace` also contains several built-in attributes described here. - + """ - - def __init__(self, name, context, - callables=None, inherits=None, + + def __init__(self, name, context, + callables=None, inherits=None, populate_self=True, calling_uri=None): self.name = name self.context = context @@ -221,7 +341,7 @@ class Namespace(object): callables = () module = None - """The Python module referenced by this Namespace. + """The Python module referenced by this :class:`.Namespace`. If the namespace references a :class:`.Template`, then this module is the equivalent of ``template.module``, @@ -236,8 +356,8 @@ class Namespace(object): """ context = None - """The :class:`.Context` object for this namespace. - + """The :class:`.Context` object for this :class:`.Namespace`. + Namespaces are often created with copies of contexts that contain slightly different data, particularly in inheritance scenarios. Using the :class:`.Context` off of a :class:`.Namespace` one @@ -245,24 +365,24 @@ class Namespace(object): one-another. """ - + filename = None """The path of the filesystem file used for this - Namespace's module or template. + :class:`.Namespace`'s module or template. If this is a pure module-based - Namespace, this evaluates to ``module.__file__``. If a + :class:`.Namespace`, this evaluates to ``module.__file__``. If a template-based namespace, it evaluates to the original template file location. """ - + uri = None - """The uri for this Namespace's template. + """The URI for this :class:`.Namespace`'s template. I.e. whatever was sent to :meth:`.TemplateLookup.get_template()`. - This is the equivalent of :attr:`Template.uri`. + This is the equivalent of :attr:`.Template.uri`. """ @@ -270,8 +390,8 @@ class Namespace(object): @util.memoized_property def attr(self): - """Access module level attributes by name. 
- + """Access module level attributes by name. + This accessor allows templates to supply "scalar" attributes which are particularly handy in inheritance relationships. See the example in @@ -281,86 +401,72 @@ class Namespace(object): return _NSAttr(self) def get_namespace(self, uri): - """Return a :class:`.Namespace` corresponding to the given uri. - - If the given uri is a relative uri (i.e. it does not - contain ia leading slash ``/``), the uri is adjusted to - be relative to the uri of the namespace itself. This + """Return a :class:`.Namespace` corresponding to the given ``uri``. + + If the given ``uri`` is a relative URI (i.e. it does not + contain a leading slash ``/``), the ``uri`` is adjusted to + be relative to the ``uri`` of the namespace itself. This method is therefore mostly useful off of the built-in - ``local`` namespace, described in :ref:`namespace_local` + ``local`` namespace, described in :ref:`namespace_local`. In most cases, a template wouldn't need this function, and should instead use the ``<%namespace>`` tag to load namespaces. However, since all ``<%namespace>`` tags are - evaulated before the body of a template ever runs, + evaluated before the body of a template ever runs, this method can be used to locate namespaces using expressions that were generated within the body code of the template, or to conditionally use a particular namespace. - + """ key = (self, uri) if key in self.context.namespaces: return self.context.namespaces[key] else: - ns = TemplateNamespace(uri, self.context._copy(), - templateuri=uri, - calling_uri=self._templateuri) + ns = TemplateNamespace(uri, self.context._copy(), + templateuri=uri, + calling_uri=self._templateuri) self.context.namespaces[key] = ns return ns - + def get_template(self, uri): - """Return a :class:`.Template` from the given uri. - - The uri resolution is relative to the uri of this :class:`.Namespace` + """Return a :class:`.Template` from the given ``uri``. 
+ + The ``uri`` resolution is relative to the ``uri`` of this :class:`.Namespace` object's :class:`.Template`. - + """ return _lookup_template(self.context, uri, self._templateuri) - + def get_cached(self, key, **kwargs): - """Return a value from the :class:`.Cache` referenced by this + """Return a value from the :class:`.Cache` referenced by this :class:`.Namespace` object's :class:`.Template`. - - The advantage to this method versus direct access to the + + The advantage to this method versus direct access to the :class:`.Cache` is that the configuration parameters declared in ``<%page>`` take effect here, thereby calling up the same configured backend as that configured by ``<%page>``. - + """ - - if self.template: - if not self.template.cache_enabled: - createfunc = kwargs.get('createfunc', None) - if createfunc: - return createfunc() - else: - return None - - if self.template.cache_dir: - kwargs.setdefault('data_dir', self.template.cache_dir) - if self.template.cache_type: - kwargs.setdefault('type', self.template.cache_type) - if self.template.cache_url: - kwargs.setdefault('url', self.template.cache_url) + return self.cache.get(key, **kwargs) - + @property def cache(self): - """Return the :class:`.Cache` object referenced - by this :class:`.Namespace` object's + """Return the :class:`.Cache` object referenced + by this :class:`.Namespace` object's :class:`.Template`. 
- + """ return self.template.cache - + def include_file(self, uri, **kwargs): - """Include a file at the given uri""" - + """Include a file at the given ``uri``.""" + _include_file(self.context, uri, self._templateuri, **kwargs) - + def _populate(self, d, l): for ident in l: if ident == '*': @@ -368,7 +474,7 @@ class Namespace(object): d[k] = v else: d[ident] = getattr(self, ident) - + def _get_star(self): if self.callables: for key in self.callables: @@ -381,7 +487,7 @@ class Namespace(object): val = getattr(self.inherits, key) else: raise AttributeError( - "Namespace '%s' has no member '%s'" % + "Namespace '%s' has no member '%s'" % (self.name, key)) setattr(self, key, val) return val @@ -389,8 +495,8 @@ class Namespace(object): class TemplateNamespace(Namespace): """A :class:`.Namespace` specific to a :class:`.Template` instance.""" - def __init__(self, name, context, template=None, templateuri=None, - callables=None, inherits=None, + def __init__(self, name, context, template=None, templateuri=None, + callables=None, inherits=None, populate_self=True, calling_uri=None): self.name = name self.context = context @@ -399,7 +505,7 @@ class TemplateNamespace(Namespace): self.callables = dict([(c.func_name, c) for c in callables]) if templateuri is not None: - self.template = _lookup_template(context, templateuri, + self.template = _lookup_template(context, templateuri, calling_uri) self._templateuri = self.template.module._template_uri elif template is not None: @@ -410,13 +516,13 @@ class TemplateNamespace(Namespace): if populate_self: lclcallable, lclcontext = \ - _populate_self_namespace(context, self.template, + _populate_self_namespace(context, self.template, self_ns=self) @property def module(self): - """The Python module referenced by this Namespace. - + """The Python module referenced by this :class:`.Namespace`. + If the namespace references a :class:`.Template`, then this module is the equivalent of ``template.module``, i.e. 
the generated module for the template. @@ -427,17 +533,17 @@ class TemplateNamespace(Namespace): @property def filename(self): """The path of the filesystem file used for this - Namespace's module or template. + :class:`.Namespace`'s module or template. """ return self.template.filename @property def uri(self): - """The uri for this Namespace's template. - + """The URI for this :class:`.Namespace`'s template. + I.e. whatever was sent to :meth:`.TemplateLookup.get_template()`. - - This is the equivalent of :attr:`Template.uri`. + + This is the equivalent of :attr:`.Template.uri`. """ return self.template.uri @@ -463,7 +569,7 @@ class TemplateNamespace(Namespace): else: raise AttributeError( - "Namespace '%s' has no member '%s'" % + "Namespace '%s' has no member '%s'" % (self.name, key)) setattr(self, key, val) return val @@ -471,8 +577,8 @@ class TemplateNamespace(Namespace): class ModuleNamespace(Namespace): """A :class:`.Namespace` specific to a Python module instance.""" - def __init__(self, name, context, module, - callables=None, inherits=None, + def __init__(self, name, context, module, + callables=None, inherits=None, populate_self=True, calling_uri=None): self.name = name self.context = context @@ -488,7 +594,7 @@ class ModuleNamespace(Namespace): @property def filename(self): """The path of the filesystem file used for this - Namespace's module or template. + :class:`.Namespace`'s module or template. """ return self.module.__file__ @@ -513,7 +619,7 @@ class ModuleNamespace(Namespace): val = getattr(self.inherits, key) else: raise AttributeError( - "Namespace '%s' has no member '%s'" % + "Namespace '%s' has no member '%s'" % (self.name, key)) setattr(self, key, val) return val @@ -521,11 +627,11 @@ class ModuleNamespace(Namespace): def supports_caller(func): """Apply a caller_stack compatibility decorator to a plain Python function. - + See the example in :ref:`namespaces_python_modules`. 
- + """ - + def wrap_stackframe(context, *args, **kwargs): context.caller_stack._push_frame() try: @@ -533,19 +639,19 @@ def supports_caller(func): finally: context.caller_stack._pop_frame() return wrap_stackframe - + def capture(context, callable_, *args, **kwargs): """Execute the given template def, capturing the output into a buffer. - + See the example in :ref:`namespaces_python_modules`. - + """ - + if not callable(callable_): raise exceptions.RuntimeException( - "capture() function expects a callable as " - "its argument (i.e. capture(func, *args, **kwargs))" + "capture() function expects a callable as " + "its argument (i.e. capture(func, *args, **kwargs))" ) context._push_buffer() try: @@ -567,7 +673,7 @@ def _decorate_toplevel(fn): return fn(y)(context, *args, **kw) return go return decorate_render - + def _decorate_inline(context, fn): def decorate_render(render_fn): dec = fn(render_fn) @@ -575,17 +681,17 @@ def _decorate_inline(context, fn): return dec(context, *args, **kw) return go return decorate_render - + def _include_file(context, uri, calling_uri, **kwargs): """locate the template from the given uri and include it in the current output.""" - + template = _lookup_template(context, uri, calling_uri) (callable_, ctx) = _populate_self_namespace( - context._clean_inheritance_tokens(), + context._clean_inheritance_tokens(), template) callable_(ctx, **_kwargs_for_include(callable_, context._data, **kwargs)) - + def _inherit_from(context, uri, calling_uri): """called by the _inherit method in template modules to set up the inheritance chain at the start of a template's @@ -599,9 +705,9 @@ def _inherit_from(context, uri, calling_uri): while ih.inherits is not None: ih = ih.inherits lclcontext = context.locals_({'next':ih}) - ih.inherits = TemplateNamespace("self:%s" % template.uri, - lclcontext, - template = template, + ih.inherits = TemplateNamespace("self:%s" % template.uri, + lclcontext, + template = template, populate_self=False) 
context._data['parent'] = lclcontext._data['local'] = ih.inherits callable_ = getattr(template.module, '_mako_inherit', None) @@ -619,7 +725,7 @@ def _lookup_template(context, uri, relativeto): lookup = context._with_template.lookup if lookup is None: raise exceptions.TemplateLookupException( - "Template '%s' has no TemplateLookup associated" % + "Template '%s' has no TemplateLookup associated" % context._with_template.uri) uri = lookup.adjust_uri(uri, relativeto) try: @@ -629,8 +735,8 @@ def _lookup_template(context, uri, relativeto): def _populate_self_namespace(context, template, self_ns=None): if self_ns is None: - self_ns = TemplateNamespace('self:%s' % template.uri, - context, template=template, + self_ns = TemplateNamespace('self:%s' % template.uri, + context, template=template, populate_self=False) context._data['self'] = context._data['local'] = self_ns if hasattr(template.module, '_mako_inherit'): @@ -640,7 +746,7 @@ def _populate_self_namespace(context, template, self_ns=None): return (template.callable_, context) def _render(template, callable_, args, data, as_unicode=False): - """create a Context and return the string + """create a Context and return the string output of the given template and template callable.""" if as_unicode: @@ -649,14 +755,14 @@ def _render(template, callable_, args, data, as_unicode=False): buf = util.StringIO() else: buf = util.FastEncodingBuffer( - unicode=as_unicode, - encoding=template.output_encoding, + unicode=as_unicode, + encoding=template.output_encoding, errors=template.encoding_errors) context = Context(buf, **data) context._outputting_as_unicode = as_unicode - context._with_template = template - - _render_context(template, callable_, context, *args, + context._set_with_template(template) + + _render_context(template, callable_, context, *args, **_kwargs_for_callable(callable_, data)) return context._pop_buffer().getvalue() @@ -665,7 +771,7 @@ def _kwargs_for_callable(callable_, data): # for normal pages, **pageargs 
is usually present if argspec[2]: return data - + # for rendering defs from the top level, figure out the args namedargs = argspec[0] + [v for v in argspec[1:3] if v is not None] kwargs = {} @@ -681,10 +787,10 @@ def _kwargs_for_include(callable_, data, **kwargs): if arg != 'context' and arg in data and arg not in kwargs: kwargs[arg] = data[arg] return kwargs - + def _render_context(tmpl, callable_, context, *args, **kwargs): import mako.template as template - # create polymorphic 'self' namespace for this + # create polymorphic 'self' namespace for this # template with possibly updated context if not isinstance(tmpl, template.DefTemplate): # if main render method, call from the base of the inheritance stack @@ -694,7 +800,7 @@ def _render_context(tmpl, callable_, context, *args, **kwargs): # otherwise, call the actual rendering method specified (inherit, lclcontext) = _populate_self_namespace(context, tmpl.parent) _exec_template(callable_, context, args=args, kwargs=kwargs) - + def _exec_template(callable_, context, args=None, kwargs=None): """execute a rendering callable given the callable, a Context, and optional explicit arguments @@ -711,7 +817,7 @@ def _exec_template(callable_, context, args=None, kwargs=None): callable_(context, *args, **kwargs) except Exception, e: _render_error(template, context, e) - except: + except: e = sys.exc_info()[0] _render_error(template, context, e) else: @@ -730,6 +836,6 @@ def _render_error(template, context, error): context._buffer_stack[:] = [util.FastEncodingBuffer( error_template.output_encoding, error_template.encoding_errors)] - - context._with_template = error_template + + context._set_with_template(error_template) error_template.render_context(context, error=error) diff --git a/mako/template.py b/mako/template.py index 903dc425..b0691391 100644 --- a/mako/template.py +++ b/mako/template.py @@ -1,5 +1,5 @@ # mako/template.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako 
authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -8,57 +8,74 @@ template strings, as well as template runtime operations.""" from mako.lexer import Lexer -from mako import runtime, util, exceptions, codegen -import imp, os, re, shutil, stat, sys, tempfile, time, types, weakref +from mako import runtime, util, exceptions, codegen, cache +import os, re, shutil, stat, sys, tempfile, types, weakref + - class Template(object): """Represents a compiled template. - + :class:`.Template` includes a reference to the original - template source (via the ``.source`` attribute) + template source (via the :attr:`.source` attribute) as well as the source code of the - generated Python module (i.e. the ``.code`` attribute), + generated Python module (i.e. the :attr:`.code` attribute), as well as a reference to an actual Python module. :class:`.Template` is constructed using either a literal string representing the template text, or a filename representing a filesystem path to a source file. - - :param text: textual template source. This argument is mutually - exclusive versus the "filename" parameter. - :param filename: filename of the source template. This argument is - mutually exclusive versus the "text" parameter. + :param text: textual template source. This argument is mutually + exclusive versus the ``filename`` parameter. + + :param filename: filename of the source template. This argument is + mutually exclusive versus the ``text`` parameter. :param buffer_filters: string list of filters to be applied - to the output of %defs which are buffered, cached, or otherwise + to the output of ``%def``\ s which are buffered, cached, or otherwise filtered, after all filters - defined with the %def itself have been applied. Allows the + defined with the ``%def`` itself have been applied. 
Allows the creation of default expression filters that let the output - of return-valued %defs "opt out" of that filtering via + of return-valued ``%def``\ s "opt out" of that filtering via passing special attributes or objects. - :param bytestring_passthrough: When True, and output_encoding is - set to None, and :meth:`.Template.render` is used to render, - the StringIO or cStringIO buffer will be used instead of the + :param bytestring_passthrough: When ``True``, and ``output_encoding`` is + set to ``None``, and :meth:`.Template.render` is used to render, + the `StringIO` or `cStringIO` buffer will be used instead of the default "fast" buffer. This allows raw bytestrings in the output stream, such as in expressions, to pass straight - through to the buffer. New in 0.4 to provide the same - behavior as that of the previous series. This flag is forced - to True if disable_unicode is also configured. + through to the buffer. This flag is forced + to ``True`` if ``disable_unicode`` is also configured. - :param cache_dir: Filesystem directory where cache files will be - placed. See :ref:`caching_toplevel`. + .. versionadded:: 0.4 + Added to provide the same behavior as that of the previous series. + + :param cache_args: Dictionary of cache configuration arguments that + will be passed to the :class:`.CacheImpl`. See :ref:`caching_toplevel`. + + :param cache_dir: + + .. deprecated:: 0.6 + Use the ``'dir'`` argument in the ``cache_args`` dictionary. + See :ref:`caching_toplevel`. :param cache_enabled: Boolean flag which enables caching of this template. See :ref:`caching_toplevel`. - :param cache_type: Type of Beaker caching to be applied to the - template. See :ref:`caching_toplevel`. - - :param cache_url: URL of a memcached server with which to use - for caching. See :ref:`caching_toplevel`. + :param cache_impl: String name of a :class:`.CacheImpl` caching + implementation to use. Defaults to ``'beaker'``. + + :param cache_type: + + .. 
deprecated:: 0.6 + Use the ``'type'`` argument in the ``cache_args`` dictionary. + See :ref:`caching_toplevel`. + + :param cache_url: + + .. deprecated:: 0.6 + Use the ``'url'`` argument in the ``cache_args`` dictionary. + See :ref:`caching_toplevel`. :param default_filters: List of string filter names that will be applied to all expressions. See :ref:`filtering_default_filters`. @@ -66,9 +83,16 @@ class Template(object): :param disable_unicode: Disables all awareness of Python Unicode objects. See :ref:`unicode_disabled`. + :param enable_loop: When ``True``, enable the ``loop`` context variable. + This can be set to ``False`` to support templates that may + be making usage of the name "``loop``". Individual templates can + re-enable the "loop" context by placing the directive + ``enable_loop="True"`` inside the ``<%page>`` tag -- see + :ref:`migrating_loop`. + :param encoding_errors: Error parameter passed to ``encode()`` when string encoding is performed. See :ref:`usage_unicode`. - + :param error_handler: Python callable which is called whenever compile or runtime exceptions occur. The callable is passed the current context as well as the exception. If the @@ -76,13 +100,13 @@ class Template(object): be handled, else it is re-raised after the function completes. Is used to provide custom error-rendering functions. - + :param format_exceptions: if ``True``, exceptions which occur during the render phase of this template will be caught and formatted into an HTML error page, which then becomes the - rendered result of the :meth:`render` call. Otherwise, + rendered result of the :meth:`.render` call. Otherwise, runtime exceptions are propagated outwards. - + :param imports: String list of Python statements, typically individual "import" lines, which will be placed into the module level preamble of all generated Python modules. See the example @@ -92,66 +116,107 @@ class Template(object): be used in lieu of the coding comment. 
See :ref:`usage_unicode` as well as :ref:`unicode_toplevel` for details on source encoding. - + :param lookup: a :class:`.TemplateLookup` instance that will be used for all file lookups via the ``<%namespace>``, ``<%include>``, and ``<%inherit>`` tags. See :ref:`usage_templatelookup`. - - :param module_directory: Filesystem location where generated + + :param module_directory: Filesystem location where generated Python module files will be placed. - :param module_filename: Overrides the filename of the generated + :param module_filename: Overrides the filename of the generated Python module file. For advanced usage only. - - :param output_encoding: The encoding to use when :meth:`.render` - is called. + + :param module_writer: A callable which overrides how the Python + module is written entirely. The callable is passed the + encoded source content of the module and the destination + path to be written to. The default behavior of module writing + uses a tempfile in conjunction with a file move in order + to make the operation atomic. So a user-defined module + writing function that mimics the default behavior would be: + + .. sourcecode:: python + + import tempfile + import os + import shutil + + def module_writer(source, outputpath): + (dest, name) = \\ + tempfile.mkstemp( + dir=os.path.dirname(outputpath) + ) + + os.write(dest, source) + os.close(dest) + shutil.move(name, outputpath) + + from mako.template import Template + mytemplate = Template( + file="index.html", + module_directory="/path/to/modules", + module_writer=module_writer + ) + + The function is provided for unusual configurations where + certain platform-specific permissions or other special + steps are needed. + + :param output_encoding: The encoding to use when :meth:`.render` + is called. See :ref:`usage_unicode` as well as :ref:`unicode_toplevel`. 
- - :param preprocessor: Python callable which will be passed + + :param preprocessor: Python callable which will be passed the full template source before it is parsed. The return result of the callable will be used as the template source code. - - :param strict_undefined: Replaces the automatic usage of + + :param strict_undefined: Replaces the automatic usage of ``UNDEFINED`` for any undeclared variables not located in the :class:`.Context` with an immediate raise of ``NameError``. The advantage is immediate reporting of - missing variables which include the name. New in 0.3.6. - - :param uri: string uri or other identifier for this template. - If not provided, the uri is generated from the filesystem + missing variables which include the name. + + .. versionadded:: 0.3.6 + + :param uri: string URI or other identifier for this template. + If not provided, the ``uri`` is generated from the filesystem path, or from the in-memory identity of a non-file-based - template. The primary usage of the uri is to provide a key + template. The primary usage of the ``uri`` is to provide a key within :class:`.TemplateLookup`, as well as to generate the file path of the generated Python module file, if ``module_directory`` is specified. 
- + """ - - def __init__(self, - text=None, - filename=None, - uri=None, - format_exceptions=False, - error_handler=None, - lookup=None, - output_encoding=None, - encoding_errors='strict', - module_directory=None, - cache_type=None, - cache_dir=None, - cache_url=None, - module_filename=None, - input_encoding=None, + + def __init__(self, + text=None, + filename=None, + uri=None, + format_exceptions=False, + error_handler=None, + lookup=None, + output_encoding=None, + encoding_errors='strict', + module_directory=None, + cache_args=None, + cache_impl='beaker', + cache_enabled=True, + cache_type=None, + cache_dir=None, + cache_url=None, + module_filename=None, + input_encoding=None, disable_unicode=False, - bytestring_passthrough=False, - default_filters=None, - buffer_filters=(), + module_writer=None, + bytestring_passthrough=False, + default_filters=None, + buffer_filters=(), strict_undefined=False, - imports=None, - preprocessor=None, - cache_enabled=True): + imports=None, + enable_loop=True, + preprocessor=None): if uri: self.module_id = re.sub(r'\W', "_", uri) self.uri = uri @@ -163,13 +228,25 @@ class Template(object): else: self.module_id = "memory:" + hex(id(self)) self.uri = self.module_id - + + u_norm = self.uri + if u_norm.startswith("/"): + u_norm = u_norm[1:] + u_norm = os.path.normpath(u_norm) + if u_norm.startswith(".."): + raise exceptions.TemplateLookupException( + "Template uri \"%s\" is invalid - " + "it cannot be relative outside " + "of the root path." 
% self.uri) + self.input_encoding = input_encoding self.output_encoding = output_encoding self.encoding_errors = encoding_errors self.disable_unicode = disable_unicode self.bytestring_passthrough = bytestring_passthrough or disable_unicode + self.enable_loop = enable_loop self.strict_undefined = strict_undefined + self.module_writer = module_writer if util.py3k and disable_unicode: raise exceptions.UnsupportedError( @@ -187,10 +264,10 @@ class Template(object): else: self.default_filters = default_filters self.buffer_filters = buffer_filters - + self.imports = imports self.preprocessor = preprocessor - + # if plain text, compile code in memory only if text is not None: (code, module) = _compile_text(self, text, filename) @@ -203,18 +280,14 @@ class Template(object): if module_filename is not None: path = module_filename elif module_directory is not None: - u = self.uri - if u[0] == '/': - u = u[1:] path = os.path.abspath( os.path.join( - os.path.normpath(module_directory), - os.path.normpath(u) + ".py" + os.path.normpath(module_directory), + u_norm + ".py" ) ) else: path = None - module = self._compile_from_file(path, filename) else: raise exceptions.RuntimeException( @@ -226,147 +299,192 @@ class Template(object): self.format_exceptions = format_exceptions self.error_handler = error_handler self.lookup = lookup - self.cache_type = cache_type - self.cache_dir = cache_dir - self.cache_url = cache_url + + self.module_directory = module_directory + + self._setup_cache_args( + cache_impl, cache_enabled, cache_args, + cache_type, cache_dir, cache_url + ) + + @util.memoized_property + def reserved_names(self): + if self.enable_loop: + return codegen.RESERVED_NAMES + else: + return codegen.RESERVED_NAMES.difference(['loop']) + + def _setup_cache_args(self, + cache_impl, cache_enabled, cache_args, + cache_type, cache_dir, cache_url): + self.cache_impl = cache_impl self.cache_enabled = cache_enabled - + if cache_args: + self.cache_args = cache_args + else: + self.cache_args 
= {} + + # transfer deprecated cache_* args + if cache_type: + self.cache_args['type'] = cache_type + if cache_dir: + self.cache_args['dir'] = cache_dir + if cache_url: + self.cache_args['url'] = cache_url + def _compile_from_file(self, path, filename): if path is not None: util.verify_directory(os.path.dirname(path)) filemtime = os.stat(filename)[stat.ST_MTIME] if not os.path.exists(path) or \ os.stat(path)[stat.ST_MTIME] < filemtime: + data = util.read_file(filename) _compile_module_file( - self, - open(filename, 'rb').read(), - filename, - path) - module = imp.load_source(self.module_id, path, open(path, 'rb')) + self, + data, + filename, + path, + self.module_writer) + module = util.load_module(self.module_id, path) del sys.modules[self.module_id] if module._magic_number != codegen.MAGIC_NUMBER: + data = util.read_file(filename) _compile_module_file( - self, - open(filename, 'rb').read(), - filename, - path) - module = imp.load_source(self.module_id, path, open(path, 'rb')) + self, + data, + filename, + path, + self.module_writer) + module = util.load_module(self.module_id, path) del sys.modules[self.module_id] ModuleInfo(module, path, self, filename, None, None) else: # template filename and no module directory, compile code # in memory + data = util.read_file(filename) code, module = _compile_text( - self, - open(filename, 'rb').read(), + self, + data, filename) self._source = None self._code = code ModuleInfo(module, None, self, filename, code, None) return module - + @property def source(self): - """return the template source code for this Template.""" - + """Return the template source code for this :class:`.Template`.""" + return _get_module_info_from_callable(self.callable_).source @property def code(self): - """return the module source code for this Template""" - + """Return the module source code for this :class:`.Template`.""" + return _get_module_info_from_callable(self.callable_).code - - @property + + @util.memoized_property def cache(self): - 
return self.module._template_cache - + return cache.Cache(self) + + @property + def cache_dir(self): + return self.cache_args['dir'] + @property + def cache_url(self): + return self.cache_args['url'] + @property + def cache_type(self): + return self.cache_args['type'] + def render(self, *args, **data): """Render the output of this template as a string. - - if the template specifies an output encoding, the string + + If the template specifies an output encoding, the string will be encoded accordingly, else the output is raw (raw - output uses cStringIO and can't handle multibyte - characters). a Context object is created corresponding - to the given data. Arguments that are explictly declared + output uses `cStringIO` and can't handle multibyte + characters). A :class:`.Context` object is created corresponding + to the given data. Arguments that are explicitly declared by this template's internal rendering method are also - pulled from the given \*args, \**data members. - + pulled from the given ``*args``, ``**data`` members. + """ return runtime._render(self, self.callable_, args, data) - + def render_unicode(self, *args, **data): - """render the output of this template as a unicode object.""" - - return runtime._render(self, - self.callable_, - args, - data, + """Render the output of this template as a unicode object.""" + + return runtime._render(self, + self.callable_, + args, + data, as_unicode=True) - + def render_context(self, context, *args, **kwargs): - """Render this Template with the given context. - - the data is written to the context's buffer. - + """Render this :class:`.Template` with the given context. + + The data is written to the context's buffer. 
+ """ if getattr(context, '_with_template', None) is None: - context._with_template = self - runtime._render_context(self, - self.callable_, - context, - *args, + context._set_with_template(self) + runtime._render_context(self, + self.callable_, + context, + *args, **kwargs) - + def has_def(self, name): return hasattr(self.module, "render_%s" % name) - + def get_def(self, name): """Return a def of this template as a :class:`.DefTemplate`.""" - + return DefTemplate(self, getattr(self.module, "render_%s" % name)) def _get_def_callable(self, name): return getattr(self.module, "render_%s" % name) - + @property - def last_modified(self): - return self.module._modified_time - + def last_modified(self): + return self.module._modified_time + class ModuleTemplate(Template): """A Template which is constructed given an existing Python module. - + e.g.:: - + t = Template("this is a template") f = file("mymodule.py", "w") f.write(t.code) f.close() - + import mymodule - + t = ModuleTemplate(mymodule) print t.render() - + """ - - def __init__(self, module, - module_filename=None, - template=None, - template_filename=None, - module_source=None, + + def __init__(self, module, + module_filename=None, + template=None, + template_filename=None, + module_source=None, template_source=None, - output_encoding=None, + output_encoding=None, encoding_errors='strict', - disable_unicode=False, + disable_unicode=False, bytestring_passthrough=False, format_exceptions=False, - error_handler=None, - lookup=None, + error_handler=None, + lookup=None, + cache_args=None, + cache_impl='beaker', + cache_enabled=True, cache_type=None, - cache_dir=None, - cache_url=None, - cache_enabled=True + cache_dir=None, + cache_url=None, ): self.module_id = re.sub(r'\W', "_", module._template_uri) self.uri = module._template_uri @@ -375,6 +493,7 @@ class ModuleTemplate(Template): self.encoding_errors = encoding_errors self.disable_unicode = disable_unicode self.bytestring_passthrough = bytestring_passthrough or 
disable_unicode + self.enable_loop = module._enable_loop if util.py3k and disable_unicode: raise exceptions.UnsupportedError( @@ -387,26 +506,26 @@ class ModuleTemplate(Template): self.module = module self.filename = template_filename - ModuleInfo(module, - module_filename, - self, - template_filename, - module_source, + ModuleInfo(module, + module_filename, + self, + template_filename, + module_source, template_source) - + self.callable_ = self.module.render_body self.format_exceptions = format_exceptions self.error_handler = error_handler self.lookup = lookup - self.cache_type = cache_type - self.cache_dir = cache_dir - self.cache_url = cache_url - self.cache_enabled = cache_enabled - + self._setup_cache_args( + cache_impl, cache_enabled, cache_args, + cache_type, cache_dir, cache_url + ) + class DefTemplate(Template): - """a Template which represents a callable def in a parent + """A :class:`.Template` which represents a callable def in a parent template.""" - + def __init__(self, parent, callable_): self.parent = parent self.callable_ = callable_ @@ -415,6 +534,7 @@ class DefTemplate(Template): self.encoding_errors = parent.encoding_errors self.format_exceptions = parent.format_exceptions self.error_handler = parent.error_handler + self.enable_loop = parent.enable_loop self.lookup = parent.lookup self.bytestring_passthrough = parent.bytestring_passthrough @@ -425,16 +545,16 @@ class ModuleInfo(object): """Stores information about a module currently loaded into memory, provides reverse lookups of template source, module source code based on a module's identifier. 
- + """ _modules = weakref.WeakValueDictionary() - def __init__(self, - module, - module_filename, - template, - template_filename, - module_source, + def __init__(self, + module, + module_filename, + template, + template_filename, + module_source, template_source): self.module = module self.module_filename = module_filename @@ -444,14 +564,14 @@ class ModuleInfo(object): self._modules[module.__name__] = template._mmarker = self if module_filename: self._modules[module_filename] = self - + @property def code(self): if self.module_source is not None: return self.module_source else: - return open(self.module_filename).read() - + return util.read_file(self.module_filename) + @property def source(self): if self.template_source is not None: @@ -462,31 +582,37 @@ class ModuleInfo(object): else: return self.template_source else: + data = util.read_file(self.template_filename) if self.module._source_encoding: - return open(self.template_filename, 'rb').read().\ - decode(self.module._source_encoding) + return data.decode(self.module._source_encoding) else: - return open(self.template_filename).read() - -def _compile_text(template, text, filename): - identifier = template.module_id - lexer = Lexer(text, - filename, + return data + +def _compile(template, text, filename, generate_magic_comment): + lexer = Lexer(text, + filename, disable_unicode=template.disable_unicode, input_encoding=template.input_encoding, preprocessor=template.preprocessor) node = lexer.parse() - - source = codegen.compile(node, - template.uri, + source = codegen.compile(node, + template.uri, filename, default_filters=template.default_filters, - buffer_filters=template.buffer_filters, - imports=template.imports, + buffer_filters=template.buffer_filters, + imports=template.imports, source_encoding=lexer.encoding, - generate_magic_comment=template.disable_unicode, + generate_magic_comment=generate_magic_comment, disable_unicode=template.disable_unicode, - strict_undefined=template.strict_undefined) + 
strict_undefined=template.strict_undefined, + enable_loop=template.enable_loop, + reserved_names=template.reserved_names) + return source, lexer + +def _compile_text(template, text, filename): + identifier = template.module_id + source, lexer = _compile(template, text, filename, + generate_magic_comment=template.disable_unicode) cid = identifier if not util.py3k and isinstance(cid, unicode): @@ -496,41 +622,29 @@ def _compile_text(template, text, filename): exec code in module.__dict__, module.__dict__ return (source, module) -def _compile_module_file(template, text, filename, outputpath): +def _compile_module_file(template, text, filename, outputpath, module_writer): identifier = template.module_id - lexer = Lexer(text, - filename, - disable_unicode=template.disable_unicode, - input_encoding=template.input_encoding, - preprocessor=template.preprocessor) - - node = lexer.parse() - source = codegen.compile(node, - template.uri, - filename, - default_filters=template.default_filters, - buffer_filters=template.buffer_filters, - imports=template.imports, - source_encoding=lexer.encoding, - generate_magic_comment=True, - disable_unicode=template.disable_unicode, - strict_undefined=template.strict_undefined) - - # make tempfiles in the same location as the ultimate - # location. this ensures they're on the same filesystem, - # avoiding synchronization issues. - (dest, name) = tempfile.mkstemp(dir=os.path.dirname(outputpath)) - + source, lexer = _compile(template, text, filename, + generate_magic_comment=True) + if isinstance(source, unicode): source = source.encode(lexer.encoding or 'ascii') - - os.write(dest, source) - os.close(dest) - shutil.move(name, outputpath) + + if module_writer: + module_writer(source, outputpath) + else: + # make tempfiles in the same location as the ultimate + # location. this ensures they're on the same filesystem, + # avoiding synchronization issues. 
+ (dest, name) = tempfile.mkstemp(dir=os.path.dirname(outputpath)) + + os.write(dest, source) + os.close(dest) + shutil.move(name, outputpath) def _get_module_info_from_callable(callable_): return _get_module_info(callable_.func_globals['__name__']) - + def _get_module_info(filename): return ModuleInfo._modules[filename] - + diff --git a/mako/util.py b/mako/util.py index 5518b4dd..df4bf4b7 100644 --- a/mako/util.py +++ b/mako/util.py @@ -1,13 +1,15 @@ # mako/util.py -# Copyright (C) 2006-2011 the Mako authors and contributors +# Copyright (C) 2006-2012 the Mako authors and contributors # # This module is part of Mako and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import imp import sys py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0) +py26 = sys.version_info >= (2, 6) py24 = sys.version_info >= (2, 4) and sys.version_info < (2, 5) jython = sys.platform.startswith('java') win32 = sys.platform.startswith('win') @@ -33,8 +35,8 @@ except ImportError: if win32 or jython: time_func = time.clock else: - time_func = time.time - + time_func = time.time + def function_named(fn, name): """Return a function with a given __name__. 
@@ -56,20 +58,57 @@ except: return newfunc if py24: + def all(iterable): + for i in iterable: + if not i: + return False + return True + def exception_name(exc): try: return exc.__class__.__name__ except AttributeError: return exc.__name__ else: + all = all + def exception_name(exc): return exc.__class__.__name__ - + + +class PluginLoader(object): + def __init__(self, group): + self.group = group + self.impls = {} + + def load(self, name): + if name in self.impls: + return self.impls[name]() + else: + import pkg_resources + for impl in pkg_resources.iter_entry_points( + self.group, + name): + self.impls[name] = impl.load + return impl.load() + else: + raise exceptions.RuntimeException( + "Can't load plugin %s %s" % + (self.group, name)) + + def register(self, name, modulepath, objname): + def load(): + mod = __import__(modulepath) + for token in modulepath.split(".")[1:]: + mod = getattr(mod, token) + return getattr(mod, objname) + self.impls[name] = load + def verify_directory(dir): """create and/or verify a filesystem directory.""" - + tries = 0 - + while not os.path.exists(dir): try: tries += 1 @@ -100,20 +139,47 @@ class memoized_property(object): obj.__dict__[self.__name__] = result = self.fget(obj) return result +class memoized_instancemethod(object): + """Decorate a method memoize its return value. + + Best applied to no-arg methods: memoization is not sensitive to + argument values, and will always return the same value even when + called with different arguments. 
+ + """ + def __init__(self, fget, doc=None): + self.fget = fget + self.__doc__ = doc or fget.__doc__ + self.__name__ = fget.__name__ + + def __get__(self, obj, cls): + if obj is None: + return self + def oneshot(*args, **kw): + result = self.fget(obj, *args, **kw) + memo = lambda *a, **kw: result + memo.__name__ = self.__name__ + memo.__doc__ = self.__doc__ + obj.__dict__[self.__name__] = memo + return result + oneshot.__name__ = self.__name__ + oneshot.__doc__ = self.__doc__ + return oneshot + class SetLikeDict(dict): """a dictionary that has some setlike methods on it""" def union(self, other): """produce a 'union' of this dict and another (at the key level). - + values in the second dict take precedence over that of the first""" x = SetLikeDict(**self) x.update(other) return x class FastEncodingBuffer(object): - """a very rudimentary buffer that is faster than StringIO, + """a very rudimentary buffer that is faster than StringIO, but doesn't crash on unicode data like cStringIO.""" - + def __init__(self, encoding=None, errors='strict', unicode=False): self.data = collections.deque() self.encoding = encoding @@ -124,26 +190,27 @@ class FastEncodingBuffer(object): self.unicode = unicode self.errors = errors self.write = self.data.append - + def truncate(self): self.data = collections.deque() self.write = self.data.append - + def getvalue(self): if self.encoding: - return self.delim.join(self.data).encode(self.encoding, self.errors) + return self.delim.join(self.data).encode(self.encoding, + self.errors) else: return self.delim.join(self.data) class LRUCache(dict): - """A dictionary-like object that stores a limited number of items, discarding - lesser used items periodically. - + """A dictionary-like object that stores a limited number of items, + discarding lesser used items periodically. + this is a rewrite of LRUCache from Myghty to use a periodic timestamp-based - paradigm so that synchronization is not really needed. 
the size management + paradigm so that synchronization is not really needed. the size management is inexact. """ - + class _Item(object): def __init__(self, key, value): self.key = key @@ -151,26 +218,26 @@ class LRUCache(dict): self.timestamp = time_func() def __repr__(self): return repr(self.value) - + def __init__(self, capacity, threshold=.5): self.capacity = capacity self.threshold = threshold - + def __getitem__(self, key): item = dict.__getitem__(self, key) item.timestamp = time_func() return item.value - + def values(self): return [i.value for i in dict.values(self)] - + def setdefault(self, key, value): if key in self: return self[key] else: self[key] = value return value - + def __setitem__(self, key, value): item = dict.get(self, key) if item is None: @@ -179,17 +246,17 @@ class LRUCache(dict): else: item.value = value self._manage_size() - + def _manage_size(self): while len(self) > self.capacity + self.capacity * self.threshold: - bytime = sorted(dict.values(self), + bytime = sorted(dict.values(self), key=operator.attrgetter('timestamp'), reverse=True) for item in bytime[self.capacity:]: try: del self[item.key] except KeyError: - # if we couldnt find a key, most likely some other thread broke in - # on us. loop around and try again + # if we couldn't find a key, most likely some other thread + # broke in on us. loop around and try again break # Regexp to match python magic encoding line @@ -198,7 +265,8 @@ _PYTHON_MAGIC_COMMENT_re = re.compile( re.VERBOSE) def parse_encoding(fp): - """Deduce the encoding of a Python source file (binary mode) from magic comment. + """Deduce the encoding of a Python source file (binary mode) from magic + comment. 
It does this in the same way as the `Python interpreter`__ @@ -227,7 +295,8 @@ def parse_encoding(fp): pass else: line2 = fp.readline() - m = _PYTHON_MAGIC_COMMENT_re.match(line2.decode('ascii', 'ignore')) + m = _PYTHON_MAGIC_COMMENT_re.match( + line2.decode('ascii', 'ignore')) if has_bom: if m: @@ -244,14 +313,14 @@ def parse_encoding(fp): def sorted_dict_repr(d): """repr() a dictionary with the keys in order. - + Used by the lexer unit test to compare parse trees based on strings. - + """ keys = d.keys() keys.sort() return "{" + ", ".join(["%r: %r" % (k, d[k]) for k in keys]) + "}" - + def restore__ast(_ast): """Attempt to restore the required classes to the _ast module if it appears to be missing them @@ -350,3 +419,18 @@ except ImportError: import inspect def inspect_func_args(fn): return inspect.getargspec(fn) + +def read_file(path, mode='rb'): + fp = open(path, mode) + try: + data = fp.read() + return data + finally: + fp.close() + +def load_module(module_id, path): + fp = open(path, 'rb') + try: + return imp.load_source(module_id, path, fp) + finally: + fp.close() From d2ae29d2d6d8c9c6328488e6e5943aafa8e561b5 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Sat, 28 Jul 2012 23:56:28 +0530 Subject: [PATCH 14/84] Updated cherrypy to 3.2.2 --- cherrypy/__init__.py | 16 +- cherrypy/_cpcompat.py | 45 +- cherrypy/_cpdispatch.py | 28 +- cherrypy/_cperror.py | 53 +- cherrypy/_cplogging.py | 85 +- cherrypy/_cpmodpy.py | 15 +- cherrypy/_cpreqbody.py | 50 +- cherrypy/_cprequest.py | 48 +- cherrypy/_cpserver.py | 24 +- cherrypy/_cptools.py | 4 +- cherrypy/_cptree.py | 35 +- cherrypy/_cpwsgi.py | 95 +- cherrypy/_cpwsgi_server.py | 13 +- cherrypy/lib/__init__.py | 2 +- cherrypy/lib/cpstats.py | 7 +- cherrypy/lib/cptools.py | 14 +- cherrypy/lib/gctools.py | 214 ++ cherrypy/lib/httputil.py | 51 +- cherrypy/lib/jsontools.py | 2 +- cherrypy/lib/reprconf.py | 168 +- cherrypy/lib/sessions.py | 73 +- cherrypy/lib/static.py | 13 +- cherrypy/lib/{xmlrpc.py => xmlrpcutil.py} | 26 +- 
cherrypy/process/plugins.py | 12 +- cherrypy/process/servers.py | 21 +- cherrypy/process/wspbus.py | 57 +- cherrypy/test/__init__.py | 25 - cherrypy/test/_test_decorators.py | 41 - cherrypy/test/_test_states_demo.py | 66 - cherrypy/test/benchmark.py | 409 ---- cherrypy/test/checkerdemo.py | 47 - cherrypy/test/fastcgi.conf | 18 - cherrypy/test/fcgi.conf | 14 - cherrypy/test/helper.py | 476 ---- cherrypy/test/logtest.py | 181 -- cherrypy/test/modfastcgi.py | 135 -- cherrypy/test/modfcgid.py | 125 -- cherrypy/test/modpy.py | 163 -- cherrypy/test/modwsgi.py | 148 -- cherrypy/test/native-server.ini | 9 - cherrypy/test/sessiondemo.py | 153 -- cherrypy/test/static/dirback.jpg | Bin 18238 -> 0 bytes cherrypy/test/static/index.html | 1 - cherrypy/test/style.css | 1 - cherrypy/test/test.pem | 38 - cherrypy/test/test_auth_basic.py | 79 - cherrypy/test/test_auth_digest.py | 115 - cherrypy/test/test_bus.py | 263 --- cherrypy/test/test_caching.py | 329 --- cherrypy/test/test_config.py | 249 --- cherrypy/test/test_config_server.py | 121 - cherrypy/test/test_conn.py | 734 ------- cherrypy/test/test_core.py | 617 ------ cherrypy/test/test_dynamicobjectmapping.py | 403 ---- cherrypy/test/test_encoding.py | 363 --- cherrypy/test/test_etags.py | 81 - cherrypy/test/test_http.py | 168 -- cherrypy/test/test_httpauth.py | 151 -- cherrypy/test/test_httplib.py | 29 - cherrypy/test/test_json.py | 79 - cherrypy/test/test_logging.py | 149 -- cherrypy/test/test_mime.py | 128 -- cherrypy/test/test_misc_tools.py | 202 -- cherrypy/test/test_objectmapping.py | 403 ---- cherrypy/test/test_proxy.py | 129 -- cherrypy/test/test_refleaks.py | 119 - cherrypy/test/test_request_obj.py | 722 ------ cherrypy/test/test_routes.py | 69 - cherrypy/test/test_session.py | 464 ---- cherrypy/test/test_sessionauthenticate.py | 62 - cherrypy/test/test_states.py | 436 ---- cherrypy/test/test_static.py | 300 --- cherrypy/test/test_tools.py | 393 ---- cherrypy/test/test_tutorials.py | 201 -- 
cherrypy/test/test_virtualhost.py | 107 - cherrypy/test/test_wsgi_ns.py | 80 - cherrypy/test/test_wsgi_vhost.py | 36 - cherrypy/test/test_wsgiapps.py | 111 - cherrypy/test/test_xmlrpc.py | 172 -- cherrypy/test/webtest.py | 535 ----- cherrypy/wsgiserver/__init__.py | 2227 +------------------ cherrypy/wsgiserver/ssl_builtin.py | 25 +- cherrypy/wsgiserver/wsgiserver2.py | 2322 ++++++++++++++++++++ cherrypy/wsgiserver/wsgiserver3.py | 2040 +++++++++++++++++ 84 files changed, 5359 insertions(+), 13075 deletions(-) create mode 100644 cherrypy/lib/gctools.py rename cherrypy/lib/{xmlrpc.py => xmlrpcutil.py} (61%) delete mode 100644 cherrypy/test/__init__.py delete mode 100644 cherrypy/test/_test_decorators.py delete mode 100644 cherrypy/test/_test_states_demo.py delete mode 100644 cherrypy/test/benchmark.py delete mode 100644 cherrypy/test/checkerdemo.py delete mode 100644 cherrypy/test/fastcgi.conf delete mode 100644 cherrypy/test/fcgi.conf delete mode 100644 cherrypy/test/helper.py delete mode 100644 cherrypy/test/logtest.py delete mode 100644 cherrypy/test/modfastcgi.py delete mode 100644 cherrypy/test/modfcgid.py delete mode 100644 cherrypy/test/modpy.py delete mode 100644 cherrypy/test/modwsgi.py delete mode 100644 cherrypy/test/native-server.ini delete mode 100755 cherrypy/test/sessiondemo.py delete mode 100644 cherrypy/test/static/dirback.jpg delete mode 100644 cherrypy/test/static/index.html delete mode 100644 cherrypy/test/style.css delete mode 100644 cherrypy/test/test.pem delete mode 100644 cherrypy/test/test_auth_basic.py delete mode 100644 cherrypy/test/test_auth_digest.py delete mode 100644 cherrypy/test/test_bus.py delete mode 100644 cherrypy/test/test_caching.py delete mode 100644 cherrypy/test/test_config.py delete mode 100644 cherrypy/test/test_config_server.py delete mode 100644 cherrypy/test/test_conn.py delete mode 100644 cherrypy/test/test_core.py delete mode 100644 cherrypy/test/test_dynamicobjectmapping.py delete mode 100644 
cherrypy/test/test_encoding.py delete mode 100644 cherrypy/test/test_etags.py delete mode 100644 cherrypy/test/test_http.py delete mode 100644 cherrypy/test/test_httpauth.py delete mode 100644 cherrypy/test/test_httplib.py delete mode 100644 cherrypy/test/test_json.py delete mode 100644 cherrypy/test/test_logging.py delete mode 100644 cherrypy/test/test_mime.py delete mode 100644 cherrypy/test/test_misc_tools.py delete mode 100644 cherrypy/test/test_objectmapping.py delete mode 100644 cherrypy/test/test_proxy.py delete mode 100644 cherrypy/test/test_refleaks.py delete mode 100644 cherrypy/test/test_request_obj.py delete mode 100644 cherrypy/test/test_routes.py delete mode 100755 cherrypy/test/test_session.py delete mode 100644 cherrypy/test/test_sessionauthenticate.py delete mode 100644 cherrypy/test/test_states.py delete mode 100644 cherrypy/test/test_static.py delete mode 100644 cherrypy/test/test_tools.py delete mode 100644 cherrypy/test/test_tutorials.py delete mode 100644 cherrypy/test/test_virtualhost.py delete mode 100644 cherrypy/test/test_wsgi_ns.py delete mode 100644 cherrypy/test/test_wsgi_vhost.py delete mode 100644 cherrypy/test/test_wsgiapps.py delete mode 100644 cherrypy/test/test_xmlrpc.py delete mode 100644 cherrypy/test/webtest.py create mode 100644 cherrypy/wsgiserver/wsgiserver2.py create mode 100644 cherrypy/wsgiserver/wsgiserver3.py diff --git a/cherrypy/__init__.py b/cherrypy/__init__.py index eb7cabf6..41e3898b 100644 --- a/cherrypy/__init__.py +++ b/cherrypy/__init__.py @@ -57,10 +57,10 @@ These API's are described in the CherryPy specification: http://www.cherrypy.org/wiki/CherryPySpec """ -__version__ = "3.2.0" +__version__ = "3.2.2" from cherrypy._cpcompat import urljoin as _urljoin, urlencode as _urlencode -from cherrypy._cpcompat import basestring, unicodestr +from cherrypy._cpcompat import basestring, unicodestr, set from cherrypy._cperror import HTTPError, HTTPRedirect, InternalRedirect from cherrypy._cperror import NotFound, 
CherryPyException, TimeoutError @@ -89,17 +89,21 @@ except ImportError: engine = process.bus -# Timeout monitor +# Timeout monitor. We add two channels to the engine +# to which cherrypy.Application will publish. +engine.listeners['before_request'] = set() +engine.listeners['after_request'] = set() + class _TimeoutMonitor(process.plugins.Monitor): def __init__(self, bus): self.servings = [] process.plugins.Monitor.__init__(self, bus, self.run) - def acquire(self): + def before_request(self): self.servings.append((serving.request, serving.response)) - def release(self): + def after_request(self): try: self.servings.remove((serving.request, serving.response)) except ValueError: @@ -585,7 +589,7 @@ def url(path="", qs="", script_name=None, base=None, relative=None): elif relative: # "A relative reference that does not begin with a scheme name # or a slash character is termed a relative-path reference." - old = url().split('/')[:-1] + old = url(relative=False).split('/')[:-1] new = newurl.split('/') while old and new: a, b = old[0], new[0] diff --git a/cherrypy/_cpcompat.py b/cherrypy/_cpcompat.py index 216ddddc..ed24c1ab 100644 --- a/cherrypy/_cpcompat.py +++ b/cherrypy/_cpcompat.py @@ -16,9 +16,11 @@ It also provides a 'base64_decode' function with native strings as input and output. 
""" import os +import re import sys if sys.version_info >= (3, 0): + py3k = True bytestr = bytes unicodestr = str nativestr = unicodestr @@ -31,12 +33,19 @@ if sys.version_info >= (3, 0): """Return the given native string as a unicode string with the given encoding.""" # In Python 3, the native string type is unicode return n + def tonative(n, encoding='ISO-8859-1'): + """Return the given string as a native string in the given encoding.""" + # In Python 3, the native string type is unicode + if isinstance(n, bytes): + return n.decode(encoding) + return n # type("") from io import StringIO # bytes: from io import BytesIO as BytesIO else: # Python 2 + py3k = False bytestr = str unicodestr = unicode nativestr = bytestr @@ -49,10 +58,25 @@ else: return n def ntou(n, encoding='ISO-8859-1'): """Return the given native string as a unicode string with the given encoding.""" - # In Python 2, the native string type is bytes. Assume it's already - # in the given encoding, which for ISO-8859-1 is almost always what - # was intended. + # In Python 2, the native string type is bytes. + # First, check for the special encoding 'escape'. The test suite uses this + # to signal that it wants to pass a string with embedded \uXXXX escapes, + # but without having to prefix it with u'' for Python 2, but no prefix + # for Python 3. + if encoding == 'escape': + return unicode( + re.sub(r'\\u([0-9a-zA-Z]{4})', + lambda m: unichr(int(m.group(1), 16)), + n.decode('ISO-8859-1'))) + # Assume it's already in the given encoding, which for ISO-8859-1 is almost + # always what was intended. return n.decode(encoding) + def tonative(n, encoding='ISO-8859-1'): + """Return the given string as a native string in the given encoding.""" + # In Python 2, the native string type is bytes. 
+ if isinstance(n, unicode): + return n.encode(encoding) + return n try: # type("") from cStringIO import StringIO @@ -185,6 +209,18 @@ except ImportError: from http.client import BadStatusLine, HTTPConnection, HTTPSConnection, IncompleteRead, NotConnected from http.server import BaseHTTPRequestHandler +try: + # Python 2. We have to do it in this order so Python 2 builds + # don't try to import the 'http' module from cherrypy.lib + from httplib import HTTPSConnection +except ImportError: + try: + # Python 3 + from http.client import HTTPSConnection + except ImportError: + # Some platforms which don't have SSL don't expose HTTPSConnection + HTTPSConnection = None + try: # Python 2 xrange = xrange @@ -229,7 +265,7 @@ try: json_decode = json.JSONDecoder().decode json_encode = json.JSONEncoder().iterencode except ImportError: - if sys.version_info >= (3, 0): + if py3k: # Python 3.0: json is part of the standard library, # but outputs unicode. We need bytes. import json @@ -280,4 +316,3 @@ except NameError: # Python 2 def next(i): return i.next() - diff --git a/cherrypy/_cpdispatch.py b/cherrypy/_cpdispatch.py index 7250ac92..d614e086 100644 --- a/cherrypy/_cpdispatch.py +++ b/cherrypy/_cpdispatch.py @@ -12,8 +12,13 @@ to a hierarchical arrangement of objects, starting at request.app.root. 
import string import sys import types +try: + classtype = (type, types.ClassType) +except AttributeError: + classtype = type import cherrypy +from cherrypy._cpcompat import set class PageHandler(object): @@ -197,8 +202,18 @@ class LateParamPageHandler(PageHandler): 'cherrypy.request.params copied in)') -punctuation_to_underscores = string.maketrans( - string.punctuation, '_' * len(string.punctuation)) +if sys.version_info < (3, 0): + punctuation_to_underscores = string.maketrans( + string.punctuation, '_' * len(string.punctuation)) + def validate_translator(t): + if not isinstance(t, str) or len(t) != 256: + raise ValueError("The translate argument must be a str of len 256.") +else: + punctuation_to_underscores = str.maketrans( + string.punctuation, '_' * len(string.punctuation)) + def validate_translator(t): + if not isinstance(t, dict): + raise ValueError("The translate argument must be a dict.") class Dispatcher(object): """CherryPy Dispatcher which walks a tree of objects to find a handler. @@ -222,8 +237,7 @@ class Dispatcher(object): def __init__(self, dispatch_method_name=None, translate=punctuation_to_underscores): - if not isinstance(translate, str) or len(translate) != 256: - raise ValueError("The translate argument must be a str of len 256.") + validate_translator(translate) self.translate = translate if dispatch_method_name: self.dispatch_method_name = dispatch_method_name @@ -524,7 +538,7 @@ class RoutesDispatcher(object): controller = result.get('controller') controller = self.controllers.get(controller, controller) if controller: - if isinstance(controller, (type, types.ClassType)): + if isinstance(controller, classtype): controller = controller() # Get config from the controller. 
if hasattr(controller, "_cp_config"): @@ -550,9 +564,9 @@ class RoutesDispatcher(object): def XMLRPCDispatcher(next_dispatcher=Dispatcher()): - from cherrypy.lib import xmlrpc + from cherrypy.lib import xmlrpcutil def xmlrpc_dispatch(path_info): - path_info = xmlrpc.patched_path(path_info) + path_info = xmlrpcutil.patched_path(path_info) return next_dispatcher(path_info) return xmlrpc_dispatch diff --git a/cherrypy/_cperror.py b/cherrypy/_cperror.py index 00e5b532..76a409ff 100644 --- a/cherrypy/_cperror.py +++ b/cherrypy/_cperror.py @@ -107,7 +107,7 @@ and not simply return an error message as a result. from cgi import escape as _escape from sys import exc_info as _exc_info from traceback import format_exception as _format_exception -from cherrypy._cpcompat import basestring, iteritems, urljoin as _urljoin +from cherrypy._cpcompat import basestring, bytestr, iteritems, ntob, tonative, urljoin as _urljoin from cherrypy.lib import httputil as _httputil @@ -183,7 +183,7 @@ class HTTPRedirect(CherryPyException): """The list of URL's to emit.""" encoding = 'utf-8' - """The encoding when passed urls are unicode objects""" + """The encoding when passed urls are not native strings""" def __init__(self, urls, status=None, encoding=None): import cherrypy @@ -194,8 +194,7 @@ class HTTPRedirect(CherryPyException): abs_urls = [] for url in urls: - if isinstance(url, unicode): - url = url.encode(encoding or self.encoding) + url = tonative(url, encoding or self.encoding) # Note that urljoin will "do the right thing" whether url is: # 1. a complete URL with host (e.g. "http://www.example.com/test") @@ -248,7 +247,7 @@ class HTTPRedirect(CherryPyException): 307: "This resource has moved temporarily to %s.", }[status] msgs = [msg % (u, u) for u in self.urls] - response.body = "
\n".join(msgs) + response.body = ntob("
\n".join(msgs), 'utf-8') # Previous code may have set C-L, so we have to reset it # (allow finalize to set it). response.headers.pop('Content-Length', None) @@ -341,8 +340,8 @@ class HTTPError(CherryPyException): self.status = status try: self.code, self.reason, defaultmsg = _httputil.valid_status(status) - except ValueError, x: - raise self.__class__(500, x.args[0]) + except ValueError: + raise self.__class__(500, _exc_info()[1].args[0]) if self.code < 400 or self.code > 599: raise ValueError("status must be between 400 and 599.") @@ -373,8 +372,8 @@ class HTTPError(CherryPyException): response.headers['Content-Type'] = "text/html;charset=utf-8" response.headers.pop('Content-Length', None) - content = self.get_error_page(self.status, traceback=tb, - message=self._message) + content = ntob(self.get_error_page(self.status, traceback=tb, + message=self._message), 'utf-8') response.body = content _be_ie_unfriendly(self.code) @@ -442,8 +441,8 @@ def get_error_page(status, **kwargs): try: code, reason, message = _httputil.valid_status(status) - except ValueError, x: - raise cherrypy.HTTPError(500, x.args[0]) + except ValueError: + raise cherrypy.HTTPError(500, _exc_info()[1].args[0]) # We can't use setdefault here, because some # callers send None for kwarg values. @@ -470,7 +469,8 @@ def get_error_page(status, **kwargs): if hasattr(error_page, '__call__'): return error_page(**kwargs) else: - return open(error_page, 'rb').read() % kwargs + data = open(error_page, 'rb').read() + return tonative(data) % kwargs except: e = _format_exception(*_exc_info())[-1] m = kwargs['message'] @@ -508,19 +508,22 @@ def _be_ie_unfriendly(status): if l and l < s: # IN ADDITION: the response must be written to IE # in one chunk or it will still get replaced! Bah. 
- content = content + (" " * (s - l)) + content = content + (ntob(" ") * (s - l)) response.body = content response.headers['Content-Length'] = str(len(content)) def format_exc(exc=None): """Return exc (or sys.exc_info if None), formatted.""" - if exc is None: - exc = _exc_info() - if exc == (None, None, None): - return "" - import traceback - return "".join(traceback.format_exception(*exc)) + try: + if exc is None: + exc = _exc_info() + if exc == (None, None, None): + return "" + import traceback + return "".join(traceback.format_exception(*exc)) + finally: + del exc def bare_error(extrabody=None): """Produce status, headers, body for a critical error. @@ -539,15 +542,15 @@ def bare_error(extrabody=None): # it cannot be allowed to fail. Therefore, don't add to it! # In particular, don't call any other CP functions. - body = "Unrecoverable error in the server." + body = ntob("Unrecoverable error in the server.") if extrabody is not None: - if not isinstance(extrabody, str): + if not isinstance(extrabody, bytestr): extrabody = extrabody.encode('utf-8') - body += "\n" + extrabody + body += ntob("\n") + extrabody - return ("500 Internal Server Error", - [('Content-Type', 'text/plain'), - ('Content-Length', str(len(body)))], + return (ntob("500 Internal Server Error"), + [(ntob('Content-Type'), ntob('text/plain')), + (ntob('Content-Length'), ntob(str(len(body)),'ISO-8859-1'))], [body]) diff --git a/cherrypy/_cplogging.py b/cherrypy/_cplogging.py index d6ca979e..e10c9420 100644 --- a/cherrypy/_cplogging.py +++ b/cherrypy/_cplogging.py @@ -109,6 +109,20 @@ import sys import cherrypy from cherrypy import _cperror +from cherrypy._cpcompat import ntob, py3k + + +class NullHandler(logging.Handler): + """A no-op logging handler to silence the logging.lastResort handler.""" + + def handle(self, record): + pass + + def emit(self, record): + pass + + def createLock(self): + self.lock = None class LogManager(object): @@ -127,8 +141,12 @@ class LogManager(object): access_log = None 
"""The actual :class:`logging.Logger` instance for access messages.""" - access_log_format = \ - '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"' + if py3k: + access_log_format = \ + '{h} {l} {u} {t} "{r}" {s} {b} "{f}" "{a}"' + else: + access_log_format = \ + '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"' logger_root = None """The "top-level" logger name. @@ -152,8 +170,13 @@ class LogManager(object): self.access_log = logging.getLogger("%s.access.%s" % (logger_root, appid)) self.error_log.setLevel(logging.INFO) self.access_log.setLevel(logging.INFO) + + # Silence the no-handlers "warning" (stderr write!) in stdlib logging + self.error_log.addHandler(NullHandler()) + self.access_log.addHandler(NullHandler()) + cherrypy.engine.subscribe('graceful', self.reopen_files) - + def reopen_files(self): """Close and reopen all file handlers.""" for log in (self.error_log, self.access_log): @@ -206,7 +229,9 @@ class LogManager(object): if response.output_status is None: status = "-" else: - status = response.output_status.split(" ", 1)[0] + status = response.output_status.split(ntob(" "), 1)[0] + if py3k: + status = status.decode('ISO-8859-1') atoms = {'h': remote.name or remote.ip, 'l': '-', @@ -218,21 +243,43 @@ class LogManager(object): 'f': dict.get(inheaders, 'Referer', ''), 'a': dict.get(inheaders, 'User-Agent', ''), } - for k, v in atoms.items(): - if isinstance(v, unicode): - v = v.encode('utf8') - elif not isinstance(v, str): - v = str(v) - # Fortunately, repr(str) escapes unprintable chars, \n, \t, etc - # and backslash for us. All we have to do is strip the quotes. - v = repr(v)[1:-1] - # Escape double-quote. 
- atoms[k] = v.replace('"', '\\"') - - try: - self.access_log.log(logging.INFO, self.access_log_format % atoms) - except: - self(traceback=True) + if py3k: + for k, v in atoms.items(): + if not isinstance(v, str): + v = str(v) + v = v.replace('"', '\\"').encode('utf8') + # Fortunately, repr(str) escapes unprintable chars, \n, \t, etc + # and backslash for us. All we have to do is strip the quotes. + v = repr(v)[2:-1] + + # in python 3.0 the repr of bytes (as returned by encode) + # uses double \'s. But then the logger escapes them yet, again + # resulting in quadruple slashes. Remove the extra one here. + v = v.replace('\\\\', '\\') + + # Escape double-quote. + atoms[k] = v + + try: + self.access_log.log(logging.INFO, self.access_log_format.format(**atoms)) + except: + self(traceback=True) + else: + for k, v in atoms.items(): + if isinstance(v, unicode): + v = v.encode('utf8') + elif not isinstance(v, str): + v = str(v) + # Fortunately, repr(str) escapes unprintable chars, \n, \t, etc + # and backslash for us. All we have to do is strip the quotes. + v = repr(v)[1:-1] + # Escape double-quote. 
+ atoms[k] = v.replace('"', '\\"') + + try: + self.access_log.log(logging.INFO, self.access_log_format % atoms) + except: + self(traceback=True) def time(self): """Return now() in Apache Common Log Format (no timezone).""" diff --git a/cherrypy/_cpmodpy.py b/cherrypy/_cpmodpy.py index ba2ab22f..76ef6ead 100644 --- a/cherrypy/_cpmodpy.py +++ b/cherrypy/_cpmodpy.py @@ -224,7 +224,7 @@ def handler(req): qs = ir.query_string rfile = BytesIO() - send_response(req, response.status, response.header_list, + send_response(req, response.output_status, response.header_list, response.body, response.stream) finally: app.release_serving() @@ -266,11 +266,22 @@ def send_response(req, status, headers, body, stream=False): import os import re +try: + import subprocess + def popen(fullcmd): + p = subprocess.Popen(fullcmd, shell=True, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + close_fds=True) + return p.stdout +except ImportError: + def popen(fullcmd): + pipein, pipeout = os.popen4(fullcmd) + return pipeout def read_process(cmd, args=""): fullcmd = "%s %s" % (cmd, args) - pipein, pipeout = os.popen4(fullcmd) + pipeout = popen(fullcmd) try: firstline = pipeout.readline() if (re.search(ntob("(not recognized|No such file|not found)"), firstline, diff --git a/cherrypy/_cpreqbody.py b/cherrypy/_cpreqbody.py index 1b0496e3..5d72c854 100644 --- a/cherrypy/_cpreqbody.py +++ b/cherrypy/_cpreqbody.py @@ -101,10 +101,28 @@ If we were defining a custom processor, we can do so without making a ``Tool``. Note that you can only replace the ``processors`` dict wholesale this way, not update the existing one. 
""" +try: + from io import DEFAULT_BUFFER_SIZE +except ImportError: + DEFAULT_BUFFER_SIZE = 8192 import re import sys import tempfile -from urllib import unquote_plus +try: + from urllib import unquote_plus +except ImportError: + def unquote_plus(bs): + """Bytes version of urllib.parse.unquote_plus.""" + bs = bs.replace(ntob('+'), ntob(' ')) + atoms = bs.split(ntob('%')) + for i in range(1, len(atoms)): + item = atoms[i] + try: + pct = int(item[:2], 16) + atoms[i] = bytes([pct]) + item[2:] + except ValueError: + pass + return ntob('').join(atoms) import cherrypy from cherrypy._cpcompat import basestring, ntob, ntou @@ -399,7 +417,6 @@ class Entity(object): # Copy the class 'attempt_charsets', prepending any Content-Type charset dec = self.content_type.params.get("charset", None) if dec: - #dec = dec.decode('ISO-8859-1') self.attempt_charsets = [dec] + [c for c in self.attempt_charsets if c != dec] else: @@ -446,11 +463,14 @@ class Entity(object): def __iter__(self): return self - def next(self): + def __next__(self): line = self.readline() if not line: raise StopIteration return line + + def next(self): + return self.__next__() def read_into_file(self, fp_out=None): """Read the request body into fp_out (or make_file() if None). 
Return fp_out.""" @@ -671,13 +691,16 @@ class Part(Entity): Entity.part_class = Part - -class Infinity(object): - def __cmp__(self, other): - return 1 - def __sub__(self, other): - return self -inf = Infinity() +try: + inf = float('inf') +except ValueError: + # Python 2.4 and lower + class Infinity(object): + def __cmp__(self, other): + return 1 + def __sub__(self, other): + return self + inf = Infinity() comma_separated_headers = ['Accept', 'Accept-Charset', 'Accept-Encoding', @@ -689,7 +712,7 @@ comma_separated_headers = ['Accept', 'Accept-Charset', 'Accept-Encoding', class SizedReader: - def __init__(self, fp, length, maxbytes, bufsize=8192, has_trailers=False): + def __init__(self, fp, length, maxbytes, bufsize=DEFAULT_BUFFER_SIZE, has_trailers=False): # Wrap our fp in a buffer so peek() works self.fp = fp self.length = length @@ -930,8 +953,9 @@ class RequestBody(Entity): request_params = self.request_params for key, value in self.params.items(): # Python 2 only: keyword arguments must be byte strings (type 'str'). 
- if isinstance(key, unicode): - key = key.encode('ISO-8859-1') + if sys.version_info < (3, 0): + if isinstance(key, unicode): + key = key.encode('ISO-8859-1') if key in request_params: if not isinstance(request_params[key], list): diff --git a/cherrypy/_cprequest.py b/cherrypy/_cprequest.py index ae5e8971..5890c728 100644 --- a/cherrypy/_cprequest.py +++ b/cherrypy/_cprequest.py @@ -6,7 +6,7 @@ import warnings import cherrypy from cherrypy._cpcompat import basestring, copykeys, ntob, unicodestr -from cherrypy._cpcompat import SimpleCookie, CookieError +from cherrypy._cpcompat import SimpleCookie, CookieError, py3k from cherrypy import _cpreqbody, _cpconfig from cherrypy._cperror import format_exc, bare_error from cherrypy.lib import httputil, file_generator @@ -49,7 +49,12 @@ class Hook(object): self.kwargs = kwargs + def __lt__(self, other): + # Python 3 + return self.priority < other.priority + def __cmp__(self, other): + # Python 2 return cmp(self.priority, other.priority) def __call__(self): @@ -104,7 +109,7 @@ class HookMap(dict): exc = sys.exc_info()[1] cherrypy.log(traceback=True, severity=40) if exc: - raise + raise exc def __copy__(self): newmap = self.__class__() @@ -488,14 +493,20 @@ class Request(object): self.stage = 'close' def run(self, method, path, query_string, req_protocol, headers, rfile): - """Process the Request. (Core) + r"""Process the Request. (Core) method, path, query_string, and req_protocol should be pulled directly from the Request-Line (e.g. "GET /path?key=val HTTP/1.0"). path This should be %XX-unquoted, but query_string should not be. - They both MUST be byte strings, not unicode strings. + + When using Python 2, they both MUST be byte strings, + not unicode strings. + + When using Python 3, they both MUST be unicode strings, + not byte strings, and preferably not bytes \x00-\xFF + disguised as unicode. headers A list of (name, value) tuples. 
@@ -676,10 +687,11 @@ class Request(object): self.query_string_encoding) # Python 2 only: keyword arguments must be byte strings (type 'str'). - for key, value in p.items(): - if isinstance(key, unicode): - del p[key] - p[key.encode(self.query_string_encoding)] = value + if not py3k: + for key, value in p.items(): + if isinstance(key, unicode): + del p[key] + p[key.encode(self.query_string_encoding)] = value self.params.update(p) def process_headers(self): @@ -770,6 +782,10 @@ class Request(object): class ResponseBody(object): """The body of the HTTP response (the response entity).""" + if py3k: + unicode_err = ("Page handlers MUST return bytes. Use tools.encode " + "if you wish to return unicode.") + def __get__(self, obj, objclass=None): if obj is None: # When calling on the class instead of an instance... @@ -779,6 +795,9 @@ class ResponseBody(object): def __set__(self, obj, value): # Convert the given value to an iterable object. + if py3k and isinstance(value, str): + raise ValueError(self.unicode_err) + if isinstance(value, basestring): # strings get wrapped in a list because iterating over a single # item list is much faster than iterating over every character @@ -788,6 +807,11 @@ class ResponseBody(object): else: # [''] doesn't evaluate to False, so replace it with []. value = [] + elif py3k and isinstance(value, list): + # every item in a list must be bytes... + for i, item in enumerate(value): + if isinstance(item, str): + raise ValueError(self.unicode_err) # Don't use isinstance here; io.IOBase which has an ABC takes # 1000 times as long as, say, isinstance(value, str) elif hasattr(value, 'read'): @@ -862,7 +886,12 @@ class Response(object): if isinstance(self.body, basestring): return self.body - newbody = ''.join([chunk for chunk in self.body]) + newbody = [] + for chunk in self.body: + if py3k and not isinstance(chunk, bytes): + raise TypeError("Chunk %s is not of type 'bytes'." 
% repr(chunk)) + newbody.append(chunk) + newbody = ntob('').join(newbody) self.body = newbody return newbody @@ -876,6 +905,7 @@ class Response(object): headers = self.headers + self.status = "%s %s" % (code, reason) self.output_status = ntob(str(code), 'ascii') + ntob(" ") + headers.encode(reason) if self.stream: diff --git a/cherrypy/_cpserver.py b/cherrypy/_cpserver.py index c1695a66..2eecd6ec 100644 --- a/cherrypy/_cpserver.py +++ b/cherrypy/_cpserver.py @@ -4,7 +4,7 @@ import warnings import cherrypy from cherrypy.lib import attributes -from cherrypy._cpcompat import basestring +from cherrypy._cpcompat import basestring, py3k # We import * because we want to export check_port # et al as attributes of this module. @@ -98,12 +98,22 @@ class Server(ServerAdapter): ssl_private_key = None """The filename of the private key to use with SSL.""" - ssl_module = 'pyopenssl' - """The name of a registered SSL adaptation module to use with the builtin - WSGI server. Builtin options are 'builtin' (to use the SSL library built - into recent versions of Python) and 'pyopenssl' (to use the PyOpenSSL - project, which you must install separately). You may also register your - own classes in the wsgiserver.ssl_adapters dict.""" + if py3k: + ssl_module = 'builtin' + """The name of a registered SSL adaptation module to use with the builtin + WSGI server. Builtin options are: 'builtin' (to use the SSL library built + into recent versions of Python). You may also register your + own classes in the wsgiserver.ssl_adapters dict.""" + else: + ssl_module = 'pyopenssl' + """The name of a registered SSL adaptation module to use with the builtin + WSGI server. Builtin options are 'builtin' (to use the SSL library built + into recent versions of Python) and 'pyopenssl' (to use the PyOpenSSL + project, which you must install separately). 
You may also register your + own classes in the wsgiserver.ssl_adapters dict.""" + + statistics = False + """Turns statistics-gathering on or off for aware HTTP servers.""" nodelay = True """If True (the default since 3.1), sets the TCP_NODELAY socket option.""" diff --git a/cherrypy/_cptools.py b/cherrypy/_cptools.py index d3eab059..22316b31 100644 --- a/cherrypy/_cptools.py +++ b/cherrypy/_cptools.py @@ -243,7 +243,7 @@ class ErrorTool(Tool): # Builtin tools # from cherrypy.lib import cptools, encoding, auth, static, jsontools -from cherrypy.lib import sessions as _sessions, xmlrpc as _xmlrpc +from cherrypy.lib import sessions as _sessions, xmlrpcutil as _xmlrpc from cherrypy.lib import caching as _caching from cherrypy.lib import auth_basic, auth_digest @@ -367,7 +367,7 @@ class XMLRPCController(object): # http://www.cherrypy.org/ticket/533 # if a method is not found, an xmlrpclib.Fault should be returned # raising an exception here will do that; see - # cherrypy.lib.xmlrpc.on_error + # cherrypy.lib.xmlrpcutil.on_error raise Exception('method "%s" is not supported' % attr) conf = cherrypy.serving.request.toolmaps['tools'].get("xmlrpc", {}) diff --git a/cherrypy/_cptree.py b/cherrypy/_cptree.py index 67ce5465..3aa4b9e1 100644 --- a/cherrypy/_cptree.py +++ b/cherrypy/_cptree.py @@ -1,8 +1,10 @@ """CherryPy Application and Tree objects.""" import os +import sys + import cherrypy -from cherrypy._cpcompat import ntou +from cherrypy._cpcompat import ntou, py3k from cherrypy import _cpconfig, _cplogging, _cprequest, _cpwsgi, tools from cherrypy.lib import httputil @@ -123,8 +125,8 @@ class Application(object): resp = self.response_class() cherrypy.serving.load(req, resp) - cherrypy.engine.timeout_monitor.acquire() cherrypy.engine.publish('acquire_thread') + cherrypy.engine.publish('before_request') return req, resp @@ -132,7 +134,7 @@ class Application(object): """Release the current serving (request and response).""" req = cherrypy.serving.request - 
cherrypy.engine.timeout_monitor.release() + cherrypy.engine.publish('after_request') try: req.close() @@ -266,14 +268,23 @@ class Tree(object): # Correct the SCRIPT_NAME and PATH_INFO environ entries. environ = environ.copy() - if environ.get(u'wsgi.version') == (u'u', 0): - # Python 2/WSGI u.0: all strings MUST be of type unicode - enc = environ[u'wsgi.url_encoding'] - environ[u'SCRIPT_NAME'] = sn.decode(enc) - environ[u'PATH_INFO'] = path[len(sn.rstrip("/")):].decode(enc) + if not py3k: + if environ.get(ntou('wsgi.version')) == (ntou('u'), 0): + # Python 2/WSGI u.0: all strings MUST be of type unicode + enc = environ[ntou('wsgi.url_encoding')] + environ[ntou('SCRIPT_NAME')] = sn.decode(enc) + environ[ntou('PATH_INFO')] = path[len(sn.rstrip("/")):].decode(enc) + else: + # Python 2/WSGI 1.x: all strings MUST be of type str + environ['SCRIPT_NAME'] = sn + environ['PATH_INFO'] = path[len(sn.rstrip("/")):] else: - # Python 2/WSGI 1.x: all strings MUST be of type str - environ['SCRIPT_NAME'] = sn - environ['PATH_INFO'] = path[len(sn.rstrip("/")):] + if environ.get(ntou('wsgi.version')) == (ntou('u'), 0): + # Python 3/WSGI u.0: all strings MUST be full unicode + environ['SCRIPT_NAME'] = sn + environ['PATH_INFO'] = path[len(sn.rstrip("/")):] + else: + # Python 3/WSGI 1.x: all strings MUST be ISO-8859-1 str + environ['SCRIPT_NAME'] = sn.encode('utf-8').decode('ISO-8859-1') + environ['PATH_INFO'] = path[len(sn.rstrip("/")):].encode('utf-8').decode('ISO-8859-1') return app(environ, start_response) - diff --git a/cherrypy/_cpwsgi.py b/cherrypy/_cpwsgi.py index aa4b7631..91cd044e 100644 --- a/cherrypy/_cpwsgi.py +++ b/cherrypy/_cpwsgi.py @@ -10,7 +10,7 @@ still be translatable to bytes via the Latin-1 encoding!" 
import sys as _sys import cherrypy as _cherrypy -from cherrypy._cpcompat import BytesIO +from cherrypy._cpcompat import BytesIO, bytestr, ntob, ntou, py3k, unicodestr from cherrypy import _cperror from cherrypy.lib import httputil @@ -19,11 +19,11 @@ def downgrade_wsgi_ux_to_1x(environ): """Return a new environ dict for WSGI 1.x from the given WSGI u.x environ.""" env1x = {} - url_encoding = environ[u'wsgi.url_encoding'] - for k, v in environ.items(): - if k in [u'PATH_INFO', u'SCRIPT_NAME', u'QUERY_STRING']: + url_encoding = environ[ntou('wsgi.url_encoding')] + for k, v in list(environ.items()): + if k in [ntou('PATH_INFO'), ntou('SCRIPT_NAME'), ntou('QUERY_STRING')]: v = v.encode(url_encoding) - elif isinstance(v, unicode): + elif isinstance(v, unicodestr): v = v.encode('ISO-8859-1') env1x[k.encode('ISO-8859-1')] = v @@ -94,7 +94,8 @@ class InternalRedirector(object): environ = environ.copy() try: return self.nextapp(environ, start_response) - except _cherrypy.InternalRedirect, ir: + except _cherrypy.InternalRedirect: + ir = _sys.exc_info()[1] sn = environ.get('SCRIPT_NAME', '') path = environ.get('PATH_INFO', '') qs = environ.get('QUERY_STRING', '') @@ -152,8 +153,12 @@ class _TrappedResponse(object): self.started_response = True return self - def next(self): - return self.trap(self.iter_response.next) + if py3k: + def __next__(self): + return self.trap(next, self.iter_response) + else: + def next(self): + return self.trap(self.iter_response.next) def close(self): if hasattr(self.response, 'close'): @@ -173,6 +178,11 @@ class _TrappedResponse(object): if not _cherrypy.request.show_tracebacks: tb = "" s, h, b = _cperror.bare_error(tb) + if py3k: + # What fun. 
+ s = s.decode('ISO-8859-1') + h = [(k.decode('ISO-8859-1'), v.decode('ISO-8859-1')) + for k, v in h] if self.started_response: # Empty our iterable (so future calls raise StopIteration) self.iter_response = iter([]) @@ -191,7 +201,7 @@ class _TrappedResponse(object): raise if self.started_response: - return "".join(b) + return ntob("").join(b) else: return b @@ -203,24 +213,52 @@ class AppResponse(object): """WSGI response iterable for CherryPy applications.""" def __init__(self, environ, start_response, cpapp): - if environ.get(u'wsgi.version') == (u'u', 0): - environ = downgrade_wsgi_ux_to_1x(environ) - self.environ = environ self.cpapp = cpapp try: + if not py3k: + if environ.get(ntou('wsgi.version')) == (ntou('u'), 0): + environ = downgrade_wsgi_ux_to_1x(environ) + self.environ = environ self.run() + + r = _cherrypy.serving.response + + outstatus = r.output_status + if not isinstance(outstatus, bytestr): + raise TypeError("response.output_status is not a byte string.") + + outheaders = [] + for k, v in r.header_list: + if not isinstance(k, bytestr): + raise TypeError("response.header_list key %r is not a byte string." % k) + if not isinstance(v, bytestr): + raise TypeError("response.header_list value %r is not a byte string." % v) + outheaders.append((k, v)) + + if py3k: + # According to PEP 3333, when using Python 3, the response status + # and headers must be bytes masquerading as unicode; that is, they + # must be of type "str" but are restricted to code points in the + # "latin-1" set. 
+ outstatus = outstatus.decode('ISO-8859-1') + outheaders = [(k.decode('ISO-8859-1'), v.decode('ISO-8859-1')) + for k, v in outheaders] + + self.iter_response = iter(r.body) + self.write = start_response(outstatus, outheaders) except: self.close() raise - r = _cherrypy.serving.response - self.iter_response = iter(r.body) - self.write = start_response(r.output_status, r.header_list) def __iter__(self): return self - def next(self): - return self.iter_response.next() + if py3k: + def __next__(self): + return next(self.iter_response) + else: + def next(self): + return self.iter_response.next() def close(self): """Close and de-reference the current request and response. (Core)""" @@ -253,6 +291,29 @@ class AppResponse(object): path = httputil.urljoin(self.environ.get('SCRIPT_NAME', ''), self.environ.get('PATH_INFO', '')) qs = self.environ.get('QUERY_STRING', '') + + if py3k: + # This isn't perfect; if the given PATH_INFO is in the wrong encoding, + # it may fail to match the appropriate config section URI. But meh. + old_enc = self.environ.get('wsgi.url_encoding', 'ISO-8859-1') + new_enc = self.cpapp.find_config(self.environ.get('PATH_INFO', ''), + "request.uri_encoding", 'utf-8') + if new_enc.lower() != old_enc.lower(): + # Even though the path and qs are unicode, the WSGI server is + # required by PEP 3333 to coerce them to ISO-8859-1 masquerading + # as unicode. So we have to encode back to bytes and then decode + # again using the "correct" encoding. + try: + u_path = path.encode(old_enc).decode(new_enc) + u_qs = qs.encode(old_enc).decode(new_enc) + except (UnicodeEncodeError, UnicodeDecodeError): + # Just pass them through without transcoding and hope. + pass + else: + # Only set transcoded values if they both succeed. 
+ path = u_path + qs = u_qs + rproto = self.environ.get('SERVER_PROTOCOL') headers = self.translate_headers(self.environ) rfile = self.environ['wsgi.input'] diff --git a/cherrypy/_cpwsgi_server.py b/cherrypy/_cpwsgi_server.py index 49fd5a19..21af5134 100644 --- a/cherrypy/_cpwsgi_server.py +++ b/cherrypy/_cpwsgi_server.py @@ -37,8 +37,11 @@ class CPWSGIServer(wsgiserver.CherryPyWSGIServer): ) self.protocol = self.server_adapter.protocol_version self.nodelay = self.server_adapter.nodelay - - ssl_module = self.server_adapter.ssl_module or 'pyopenssl' + + if sys.version_info >= (3, 0): + ssl_module = self.server_adapter.ssl_module or 'builtin' + else: + ssl_module = self.server_adapter.ssl_module or 'pyopenssl' if self.server_adapter.ssl_context: adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module) self.ssl_adapter = adapter_class( @@ -52,3 +55,9 @@ class CPWSGIServer(wsgiserver.CherryPyWSGIServer): self.server_adapter.ssl_certificate, self.server_adapter.ssl_private_key, self.server_adapter.ssl_certificate_chain) + + self.stats['Enabled'] = getattr(self.server_adapter, 'statistics', False) + + def error_log(self, msg="", level=20, traceback=False): + cherrypy.engine.log(msg, level, traceback) + diff --git a/cherrypy/lib/__init__.py b/cherrypy/lib/__init__.py index 611350c9..3fc0ec58 100644 --- a/cherrypy/lib/__init__.py +++ b/cherrypy/lib/__init__.py @@ -1,7 +1,7 @@ """CherryPy Library""" # Deprecated in CherryPy 3.2 -- remove in CherryPy 3.3 -from cherrypy.lib.reprconf import _Builder, unrepr, modules, attributes +from cherrypy.lib.reprconf import unrepr, modules, attributes class file_generator(object): """Yield the given input (a file object) in chunks (default 64k). 
(Core)""" diff --git a/cherrypy/lib/cpstats.py b/cherrypy/lib/cpstats.py index 79d5c3a9..9be947f2 100644 --- a/cherrypy/lib/cpstats.py +++ b/cherrypy/lib/cpstats.py @@ -320,20 +320,21 @@ class StatsTool(cherrypy.Tool): def record_stop(self, uriset=None, slow_queries=1.0, slow_queries_count=100, debug=False, **kwargs): """Record the end of a request.""" + resp = cherrypy.serving.response w = appstats['Requests'][threading._get_ident()] r = cherrypy.request.rfile.bytes_read w['Bytes Read'] = r appstats['Total Bytes Read'] += r - if cherrypy.response.stream: + if resp.stream: w['Bytes Written'] = 'chunked' else: - cl = int(cherrypy.response.headers.get('Content-Length', 0)) + cl = int(resp.headers.get('Content-Length', 0)) w['Bytes Written'] = cl appstats['Total Bytes Written'] += cl - w['Response Status'] = cherrypy.response.status + w['Response Status'] = getattr(resp, 'output_status', None) or resp.status w['End Time'] = time.time() p = w['End Time'] - w['Start Time'] diff --git a/cherrypy/lib/cptools.py b/cherrypy/lib/cptools.py index 3eedf97a..b426a3e7 100644 --- a/cherrypy/lib/cptools.py +++ b/cherrypy/lib/cptools.py @@ -116,7 +116,7 @@ def validate_since(): # Tool code # def allow(methods=None, debug=False): - """Raise 405 if request.method not in methods (default GET/HEAD). + """Raise 405 if request.method not in methods (default ['GET', 'HEAD']). The given methods are case-insensitive, and may be in any order. If only one method is allowed, you may supply a single string; @@ -151,6 +151,10 @@ def proxy(base=None, local='X-Forwarded-Host', remote='X-Forwarded-For', For running a CP server behind Apache, lighttpd, or other HTTP server. + For Apache and lighttpd, you should leave the 'local' argument at the + default value of 'X-Forwarded-Host'. For Squid, you probably want to set + tools.proxy.local = 'Origin'. 
+ If you want the new request.base to include path info (not just the host), you must explicitly set base to the full base path, and ALSO set 'local' to '', so that the X-Forwarded-Host request header (which never includes @@ -581,9 +585,11 @@ class MonitoredHeaderMap(_httputil.HeaderMap): self.accessed_headers.add(key) return _httputil.HeaderMap.get(self, key, default=default) - def has_key(self, key): - self.accessed_headers.add(key) - return _httputil.HeaderMap.has_key(self, key) + if hasattr({}, 'has_key'): + # Python 2 + def has_key(self, key): + self.accessed_headers.add(key) + return _httputil.HeaderMap.has_key(self, key) def autovary(ignore=None, debug=False): diff --git a/cherrypy/lib/gctools.py b/cherrypy/lib/gctools.py new file mode 100644 index 00000000..183148b2 --- /dev/null +++ b/cherrypy/lib/gctools.py @@ -0,0 +1,214 @@ +import gc +import inspect +import os +import sys +import time + +try: + import objgraph +except ImportError: + objgraph = None + +import cherrypy +from cherrypy import _cprequest, _cpwsgi +from cherrypy.process.plugins import SimplePlugin + + +class ReferrerTree(object): + """An object which gathers all referrers of an object to a given depth.""" + + peek_length = 40 + + def __init__(self, ignore=None, maxdepth=2, maxparents=10): + self.ignore = ignore or [] + self.ignore.append(inspect.currentframe().f_back) + self.maxdepth = maxdepth + self.maxparents = maxparents + + def ascend(self, obj, depth=1): + """Return a nested list containing referrers of the given object.""" + depth += 1 + parents = [] + + # Gather all referrers in one step to minimize + # cascading references due to repr() logic. 
+ refs = gc.get_referrers(obj) + self.ignore.append(refs) + if len(refs) > self.maxparents: + return [("[%s referrers]" % len(refs), [])] + + try: + ascendcode = self.ascend.__code__ + except AttributeError: + ascendcode = self.ascend.im_func.func_code + for parent in refs: + if inspect.isframe(parent) and parent.f_code is ascendcode: + continue + if parent in self.ignore: + continue + if depth <= self.maxdepth: + parents.append((parent, self.ascend(parent, depth))) + else: + parents.append((parent, [])) + + return parents + + def peek(self, s): + """Return s, restricted to a sane length.""" + if len(s) > (self.peek_length + 3): + half = self.peek_length // 2 + return s[:half] + '...' + s[-half:] + else: + return s + + def _format(self, obj, descend=True): + """Return a string representation of a single object.""" + if inspect.isframe(obj): + filename, lineno, func, context, index = inspect.getframeinfo(obj) + return "" % func + + if not descend: + return self.peek(repr(obj)) + + if isinstance(obj, dict): + return "{" + ", ".join(["%s: %s" % (self._format(k, descend=False), + self._format(v, descend=False)) + for k, v in obj.items()]) + "}" + elif isinstance(obj, list): + return "[" + ", ".join([self._format(item, descend=False) + for item in obj]) + "]" + elif isinstance(obj, tuple): + return "(" + ", ".join([self._format(item, descend=False) + for item in obj]) + ")" + + r = self.peek(repr(obj)) + if isinstance(obj, (str, int, float)): + return r + return "%s: %s" % (type(obj), r) + + def format(self, tree): + """Return a list of string reprs from a nested list of referrers.""" + output = [] + def ascend(branch, depth=1): + for parent, grandparents in branch: + output.append((" " * depth) + self._format(parent)) + if grandparents: + ascend(grandparents, depth + 1) + ascend(tree) + return output + + +def get_instances(cls): + return [x for x in gc.get_objects() if isinstance(x, cls)] + + +class RequestCounter(SimplePlugin): + + def start(self): + self.count = 0 + 
+ def before_request(self): + self.count += 1 + + def after_request(self): + self.count -=1 +request_counter = RequestCounter(cherrypy.engine) +request_counter.subscribe() + + +def get_context(obj): + if isinstance(obj, _cprequest.Request): + return "path=%s;stage=%s" % (obj.path_info, obj.stage) + elif isinstance(obj, _cprequest.Response): + return "status=%s" % obj.status + elif isinstance(obj, _cpwsgi.AppResponse): + return "PATH_INFO=%s" % obj.environ.get('PATH_INFO', '') + elif hasattr(obj, "tb_lineno"): + return "tb_lineno=%s" % obj.tb_lineno + return "" + + +class GCRoot(object): + """A CherryPy page handler for testing reference leaks.""" + + classes = [(_cprequest.Request, 2, 2, + "Should be 1 in this request thread and 1 in the main thread."), + (_cprequest.Response, 2, 2, + "Should be 1 in this request thread and 1 in the main thread."), + (_cpwsgi.AppResponse, 1, 1, + "Should be 1 in this request thread only."), + ] + + def index(self): + return "Hello, world!" + index.exposed = True + + def stats(self): + output = ["Statistics:"] + + for trial in range(10): + if request_counter.count > 0: + break + time.sleep(0.5) + else: + output.append("\nNot all requests closed properly.") + + # gc_collect isn't perfectly synchronous, because it may + # break reference cycles that then take time to fully + # finalize. Call it thrice and hope for the best. + gc.collect() + gc.collect() + unreachable = gc.collect() + if unreachable: + if objgraph is not None: + final = objgraph.by_type('Nondestructible') + if final: + objgraph.show_backrefs(final, filename='finalizers.png') + + trash = {} + for x in gc.garbage: + trash[type(x)] = trash.get(type(x), 0) + 1 + if trash: + output.insert(0, "\n%s unreachable objects:" % unreachable) + trash = [(v, k) for k, v in trash.items()] + trash.sort() + for pair in trash: + output.append(" " + repr(pair)) + + # Check declared classes to verify uncollected instances. 
+ # These don't have to be part of a cycle; they can be + # any objects that have unanticipated referrers that keep + # them from being collected. + allobjs = {} + for cls, minobj, maxobj, msg in self.classes: + allobjs[cls] = get_instances(cls) + + for cls, minobj, maxobj, msg in self.classes: + objs = allobjs[cls] + lenobj = len(objs) + if lenobj < minobj or lenobj > maxobj: + if minobj == maxobj: + output.append( + "\nExpected %s %r references, got %s." % + (minobj, cls, lenobj)) + else: + output.append( + "\nExpected %s to %s %r references, got %s." % + (minobj, maxobj, cls, lenobj)) + + for obj in objs: + if objgraph is not None: + ig = [id(objs), id(inspect.currentframe())] + fname = "graph_%s_%s.png" % (cls.__name__, id(obj)) + objgraph.show_backrefs( + obj, extra_ignore=ig, max_depth=4, too_many=20, + filename=fname, extra_info=get_context) + output.append("\nReferrers for %s (refcount=%s):" % + (repr(obj), sys.getrefcount(obj))) + t = ReferrerTree(ignore=[objs], maxdepth=3) + tree = t.ascend(obj) + output.extend(t.format(tree)) + + return "\n".join(output) + stats.exposed = True + diff --git a/cherrypy/lib/httputil.py b/cherrypy/lib/httputil.py index e0058751..5f77d547 100644 --- a/cherrypy/lib/httputil.py +++ b/cherrypy/lib/httputil.py @@ -9,7 +9,7 @@ to a public caning. from binascii import b2a_base64 from cherrypy._cpcompat import BaseHTTPRequestHandler, HTTPDate, ntob, ntou, reversed, sorted -from cherrypy._cpcompat import basestring, iteritems, unicodestr, unquote_qs +from cherrypy._cpcompat import basestring, bytestr, iteritems, nativestr, unicodestr, unquote_qs response_codes = BaseHTTPRequestHandler.responses.copy() # From http://www.cherrypy.org/ticket/361 @@ -38,6 +38,18 @@ def urljoin(*atoms): # Special-case the final url of "", and return "/" instead. return url or "/" +def urljoin_bytes(*atoms): + """Return the given path *atoms, joined into a single URL. 
+ + This will correctly join a SCRIPT_NAME and PATH_INFO into the + original URL, even if either atom is blank. + """ + url = ntob("/").join([x for x in atoms if x]) + while ntob("//") in url: + url = url.replace(ntob("//"), ntob("/")) + # Special-case the final url of "", and return "/" instead. + return url or ntob("/") + def protocol_from_http(protocol_str): """Return a protocol tuple from the given 'HTTP/x.y' string.""" return int(protocol_str[5]), int(protocol_str[7]) @@ -105,9 +117,15 @@ class HeaderElement(object): def __cmp__(self, other): return cmp(self.value, other.value) + def __lt__(self, other): + return self.value < other.value + def __str__(self): p = [";%s=%s" % (k, v) for k, v in iteritems(self.params)] return "%s%s" % (self.value, "".join(p)) + + def __bytes__(self): + return ntob(self.__str__()) def __unicode__(self): return ntou(self.__str__()) @@ -181,6 +199,12 @@ class AcceptElement(HeaderElement): if diff == 0: diff = cmp(str(self), str(other)) return diff + + def __lt__(self, other): + if self.qvalue == other.qvalue: + return str(self) < str(other) + else: + return self.qvalue < other.qvalue def header_elements(fieldname, fieldvalue): @@ -199,8 +223,12 @@ def header_elements(fieldname, fieldvalue): return list(reversed(sorted(result))) def decode_TEXT(value): - r"""Decode :rfc:`2047` TEXT (e.g. "=?utf-8?q?f=C3=BCr?=" -> u"f\xfcr").""" - from email.Header import decode_header + r"""Decode :rfc:`2047` TEXT (e.g. "=?utf-8?q?f=C3=BCr?=" -> "f\xfcr").""" + try: + # Python 3 + from email.header import decode_header + except ImportError: + from email.Header import decode_header atoms = decode_header(value) decodedvalue = "" for atom, charset in atoms: @@ -253,6 +281,10 @@ def valid_status(status): return code, reason, message +# NOTE: the parse_qs functions that follow are modified version of those +# in the python3.0 source - we need to pass through an encoding to the unquote +# method, but the default parse_qs function doesn't allow us to. 
These do. + def _parse_qs(qs, keep_blank_values=0, strict_parsing=0, encoding='utf-8'): """Parse a query given as a string argument. @@ -338,8 +370,9 @@ class CaseInsensitiveDict(dict): def get(self, key, default=None): return dict.get(self, str(key).title(), default) - def has_key(self, key): - return dict.has_key(self, str(key).title()) + if hasattr({}, 'has_key'): + def has_key(self, key): + return dict.has_key(self, str(key).title()) def update(self, E): for k in E.keys(): @@ -369,8 +402,12 @@ class CaseInsensitiveDict(dict): # A CRLF is allowed in the definition of TEXT only as part of a header # field continuation. It is expected that the folding LWS will be # replaced with a single SP before interpretation of the TEXT value." -header_translate_table = ''.join([chr(i) for i in xrange(256)]) -header_translate_deletechars = ''.join([chr(i) for i in xrange(32)]) + chr(127) +if nativestr == bytestr: + header_translate_table = ''.join([chr(i) for i in xrange(256)]) + header_translate_deletechars = ''.join([chr(i) for i in xrange(32)]) + chr(127) +else: + header_translate_table = None + header_translate_deletechars = bytes(range(32)) + bytes([127]) class HeaderMap(CaseInsensitiveDict): diff --git a/cherrypy/lib/jsontools.py b/cherrypy/lib/jsontools.py index 09042e45..20925791 100644 --- a/cherrypy/lib/jsontools.py +++ b/cherrypy/lib/jsontools.py @@ -82,6 +82,6 @@ def json_out(content_type='application/json', debug=False, handler=json_handler) request.handler = handler if content_type is not None: if debug: - cherrypy.log('Setting Content-Type to %s' % ct, 'TOOLS.JSON_OUT') + cherrypy.log('Setting Content-Type to %s' % content_type, 'TOOLS.JSON_OUT') cherrypy.serving.response.headers['Content-Type'] = content_type diff --git a/cherrypy/lib/reprconf.py b/cherrypy/lib/reprconf.py index e18949ee..ba8ff51e 100644 --- a/cherrypy/lib/reprconf.py +++ b/cherrypy/lib/reprconf.py @@ -28,6 +28,20 @@ try: set except NameError: from sets import Set as set + +try: + basestring 
+except NameError: + basestring = str + +try: + # Python 3 + import builtins +except ImportError: + # Python 2 + import __builtin__ as builtins + +import operator as _operator import sys def as_dict(config): @@ -195,10 +209,11 @@ class Parser(ConfigParser): if section not in result: result[section] = {} for option in self.options(section): - value = self.get(section, option, raw, vars) + value = self.get(section, option, raw=raw, vars=vars) try: value = unrepr(value) - except Exception, x: + except Exception: + x = sys.exc_info()[1] msg = ("Config error in section: %r, option: %r, " "value: %r. Config values must be valid Python." % (section, option, value)) @@ -216,7 +231,8 @@ class Parser(ConfigParser): # public domain "unrepr" implementation, found on the web and then improved. -class _Builder: + +class _Builder2: def build(self, o): m = getattr(self, 'build_' + o.__class__.__name__, None) @@ -225,6 +241,18 @@ class _Builder: repr(o.__class__.__name__)) return m(o) + def astnode(self, s): + """Return a Python2 ast Node compiled from a string.""" + try: + import compiler + except ImportError: + # Fallback to eval when compiler package is not available, + # e.g. IronPython 1.0. + return eval(s) + + p = compiler.parse("__tempvalue__ = " + s) + return p.getChildren()[1].getChildren()[0].getChildren()[1] + def build_Subscript(self, o): expr, flags, subs = o.getChildren() expr = self.build(expr) @@ -272,8 +300,7 @@ class _Builder: # See if the Name is in builtins. 
try: - import __builtin__ - return getattr(__builtin__, name) + return getattr(builtins, name) except AttributeError: pass @@ -282,6 +309,10 @@ class _Builder: def build_Add(self, o): left, right = map(self.build, o.getChildren()) return left + right + + def build_Mul(self, o): + left, right = map(self.build, o.getChildren()) + return left * right def build_Getattr(self, o): parent = self.build(o.expr) @@ -297,25 +328,128 @@ class _Builder: return self.build(o.getChildren()[0]) -def _astnode(s): - """Return a Python ast Node compiled from a string.""" - try: - import compiler - except ImportError: - # Fallback to eval when compiler package is not available, - # e.g. IronPython 1.0. - return eval(s) +class _Builder3: - p = compiler.parse("__tempvalue__ = " + s) - return p.getChildren()[1].getChildren()[0].getChildren()[1] + def build(self, o): + m = getattr(self, 'build_' + o.__class__.__name__, None) + if m is None: + raise TypeError("unrepr does not recognize %s" % + repr(o.__class__.__name__)) + return m(o) + def astnode(self, s): + """Return a Python3 ast Node compiled from a string.""" + try: + import ast + except ImportError: + # Fallback to eval when ast package is not available, + # e.g. IronPython 1.0. 
+ return eval(s) + + p = ast.parse("__tempvalue__ = " + s) + return p.body[0].value + + def build_Subscript(self, o): + return self.build(o.value)[self.build(o.slice)] + + def build_Index(self, o): + return self.build(o.value) + + def build_Call(self, o): + callee = self.build(o.func) + + if o.args is None: + args = () + else: + args = tuple([self.build(a) for a in o.args]) + + if o.starargs is None: + starargs = () + else: + starargs = self.build(o.starargs) + + if o.kwargs is None: + kwargs = {} + else: + kwargs = self.build(o.kwargs) + + return callee(*(args + starargs), **kwargs) + + def build_List(self, o): + return list(map(self.build, o.elts)) + + def build_Str(self, o): + return o.s + + def build_Num(self, o): + return o.n + + def build_Dict(self, o): + return dict([(self.build(k), self.build(v)) + for k, v in zip(o.keys, o.values)]) + + def build_Tuple(self, o): + return tuple(self.build_List(o)) + + def build_Name(self, o): + name = o.id + if name == 'None': + return None + if name == 'True': + return True + if name == 'False': + return False + + # See if the Name is a package or module. If it is, import it. + try: + return modules(name) + except ImportError: + pass + + # See if the Name is in builtins. 
+ try: + import builtins + return getattr(builtins, name) + except AttributeError: + pass + + raise TypeError("unrepr could not resolve the name %s" % repr(name)) + + def build_UnaryOp(self, o): + op, operand = map(self.build, [o.op, o.operand]) + return op(operand) + + def build_BinOp(self, o): + left, op, right = map(self.build, [o.left, o.op, o.right]) + return op(left, right) + + def build_Add(self, o): + return _operator.add + + def build_Mult(self, o): + return _operator.mul + + def build_USub(self, o): + return _operator.neg + + def build_Attribute(self, o): + parent = self.build(o.value) + return getattr(parent, o.attr) + + def build_NoneType(self, o): + return None + def unrepr(s): """Return a Python object compiled from a string.""" if not s: return s - obj = _astnode(s) - return _Builder().build(obj) + if sys.version_info < (3, 0): + b = _Builder2() + else: + b = _Builder3() + obj = b.astnode(s) + return b.build(obj) def modules(modulePath): diff --git a/cherrypy/lib/sessions.py b/cherrypy/lib/sessions.py index 42c28009..9763f120 100644 --- a/cherrypy/lib/sessions.py +++ b/cherrypy/lib/sessions.py @@ -93,7 +93,7 @@ import types from warnings import warn import cherrypy -from cherrypy._cpcompat import copyitems, pickle, random20 +from cherrypy._cpcompat import copyitems, pickle, random20, unicodestr from cherrypy.lib import httputil @@ -171,7 +171,15 @@ class Session(object): self.id = None self.missing = True self._regenerate() - + + def now(self): + """Generate the session specific concept of 'now'. + + Other session providers can override this to use alternative, + possibly timezone aware, versions of 'now'. 
+ """ + return datetime.datetime.now() + def regenerate(self): """Replace the current session (with a new id).""" self.regenerated = True @@ -210,7 +218,7 @@ class Session(object): # accessed: no need to save it if self.loaded: t = datetime.timedelta(seconds = self.timeout * 60) - expiration_time = datetime.datetime.now() + t + expiration_time = self.now() + t if self.debug: cherrypy.log('Saving with expiry %s' % expiration_time, 'TOOLS.SESSIONS') @@ -225,7 +233,7 @@ class Session(object): """Copy stored session data into this session instance.""" data = self._load() # data is either None or a tuple (session_data, expiration_time) - if data is None or data[1] < datetime.datetime.now(): + if data is None or data[1] < self.now(): if self.debug: cherrypy.log('Expired session, flushing data', 'TOOLS.SESSIONS') self._data = {} @@ -277,10 +285,11 @@ class Session(object): if not self.loaded: self.load() return key in self._data - def has_key(self, key): - """D.has_key(k) -> True if D has a key k, else False.""" - if not self.loaded: self.load() - return key in self._data + if hasattr({}, 'has_key'): + def has_key(self, key): + """D.has_key(k) -> True if D has a key k, else False.""" + if not self.loaded: self.load() + return key in self._data def get(self, key, default=None): """D.get(k[,d]) -> D[k] if k in D, else d. 
d defaults to None.""" @@ -326,7 +335,7 @@ class RamSession(Session): def clean_up(self): """Clean up expired sessions.""" - now = datetime.datetime.now() + now = self.now() for id, (data, expiration_time) in copyitems(self.cache): if expiration_time <= now: try: @@ -337,6 +346,11 @@ class RamSession(Session): del self.locks[id] except KeyError: pass + + # added to remove obsolete lock objects + for id in list(self.locks): + if id not in self.cache: + self.locks.pop(id, None) def _exists(self): return self.id in self.cache @@ -467,7 +481,7 @@ class FileSession(Session): def clean_up(self): """Clean up expired sessions.""" - now = datetime.datetime.now() + now = self.now() # Iterate over all session files in self.storage_path for fname in os.listdir(self.storage_path): if (fname.startswith(self.SESSION_PREFIX) @@ -575,7 +589,7 @@ class PostgresqlSession(Session): def clean_up(self): """Clean up expired sessions.""" self.cursor.execute('delete from session where expiration_time < %s', - (datetime.datetime.now(),)) + (self.now(),)) class MemcachedSession(Session): @@ -602,6 +616,19 @@ class MemcachedSession(Session): cls.cache = memcache.Client(cls.servers) setup = classmethod(setup) + def _get_id(self): + return self._id + def _set_id(self, value): + # This encode() call is where we differ from the superclass. + # Memcache keys MUST be byte strings, not unicode. + if isinstance(value, unicodestr): + value = value.encode('utf-8') + + self._id = value + for o in self.id_observers: + o(value) + id = property(_get_id, _set_id, doc="The current session ID.") + def _exists(self): self.mc_lock.acquire() try: @@ -683,12 +710,12 @@ close.priority = 90 def init(storage_type='ram', path=None, path_header=None, name='session_id', timeout=60, domain=None, secure=False, clean_freq=5, - persistent=True, debug=False, **kwargs): + persistent=True, httponly=False, debug=False, **kwargs): """Initialize session object (using cookies). storage_type - One of 'ram', 'file', 'postgresql'. 
This will be used - to look up the corresponding class in cherrypy.lib.sessions + One of 'ram', 'file', 'postgresql', 'memcached'. This will be + used to look up the corresponding class in cherrypy.lib.sessions globals. For example, 'file' will use the FileSession class. path @@ -722,6 +749,10 @@ def init(storage_type='ram', path=None, path_header=None, name='session_id', and the cookie will be a "session cookie" which expires when the browser is closed. + httponly + If False (the default) the cookie 'httponly' value will not be set. + If True, the cookie 'httponly' value will be set (to 1). + Any additional kwargs will be bound to the new Session instance, and may be specific to the storage type. See the subclass of Session you're using for more information. @@ -772,11 +803,12 @@ def init(storage_type='ram', path=None, path_header=None, name='session_id', # and http://support.mozilla.com/en-US/kb/Cookies cookie_timeout = None set_response_cookie(path=path, path_header=path_header, name=name, - timeout=cookie_timeout, domain=domain, secure=secure) + timeout=cookie_timeout, domain=domain, secure=secure, + httponly=httponly) def set_response_cookie(path=None, path_header=None, name='session_id', - timeout=60, domain=None, secure=False): + timeout=60, domain=None, secure=False, httponly=False): """Set a response cookie for the client. path @@ -801,6 +833,10 @@ def set_response_cookie(path=None, path_header=None, name='session_id', if False (the default) the cookie 'secure' value will not be set. If True, the cookie 'secure' value will be set (to 1). + httponly + If False (the default) the cookie 'httponly' value will not be set. + If True, the cookie 'httponly' value will be set (to 1). 
+ """ # Set response cookie cookie = cherrypy.serving.response.cookie @@ -820,7 +856,10 @@ def set_response_cookie(path=None, path_header=None, name='session_id', cookie[name]['domain'] = domain if secure: cookie[name]['secure'] = 1 - + if httponly: + if not cookie[name].isReservedKey('httponly'): + raise ValueError("The httponly cookie token is not supported.") + cookie[name]['httponly'] = 1 def expire(): """Expire the current session cookie.""" diff --git a/cherrypy/lib/static.py b/cherrypy/lib/static.py index cb9a68cb..2d142307 100644 --- a/cherrypy/lib/static.py +++ b/cherrypy/lib/static.py @@ -1,3 +1,7 @@ +try: + from io import UnsupportedOperation +except ImportError: + UnsupportedOperation = object() import logging import mimetypes mimetypes.init() @@ -115,6 +119,8 @@ def serve_fileobj(fileobj, content_type=None, disposition=None, name=None, if debug: cherrypy.log('os has no fstat attribute', 'TOOLS.STATIC') content_length = None + except UnsupportedOperation: + content_length = None else: # Set the Last-Modified response header, so that # modified-since validation code can work. @@ -174,7 +180,12 @@ def _serve_fileobj(fileobj, content_type, content_length, debug=False): else: # Return a multipart/byteranges response. 
response.status = "206 Partial Content" - from mimetools import choose_boundary + try: + # Python 3 + from email.generator import _make_boundary as choose_boundary + except ImportError: + # Python 2 + from mimetools import choose_boundary boundary = choose_boundary() ct = "multipart/byteranges; boundary=%s" % boundary response.headers['Content-Type'] = ct diff --git a/cherrypy/lib/xmlrpc.py b/cherrypy/lib/xmlrpcutil.py similarity index 61% rename from cherrypy/lib/xmlrpc.py rename to cherrypy/lib/xmlrpcutil.py index 8a5ef546..9a44464b 100644 --- a/cherrypy/lib/xmlrpc.py +++ b/cherrypy/lib/xmlrpcutil.py @@ -1,13 +1,19 @@ import sys import cherrypy +from cherrypy._cpcompat import ntob +def get_xmlrpclib(): + try: + import xmlrpc.client as x + except ImportError: + import xmlrpclib as x + return x def process_body(): """Return (params, method) from request body.""" try: - import xmlrpclib - return xmlrpclib.loads(cherrypy.request.body.read()) + return get_xmlrpclib().loads(cherrypy.request.body.read()) except Exception: return ('ERROR PARAMS', ), 'ERRORMETHOD' @@ -29,21 +35,21 @@ def _set_response(body): # as a "Protocol Error", we'll just return 200 every time. 
response = cherrypy.response response.status = '200 OK' - response.body = body + response.body = ntob(body, 'utf-8') response.headers['Content-Type'] = 'text/xml' response.headers['Content-Length'] = len(body) def respond(body, encoding='utf-8', allow_none=0): - from xmlrpclib import Fault, dumps - if not isinstance(body, Fault): + xmlrpclib = get_xmlrpclib() + if not isinstance(body, xmlrpclib.Fault): body = (body,) - _set_response(dumps(body, methodresponse=1, - encoding=encoding, - allow_none=allow_none)) + _set_response(xmlrpclib.dumps(body, methodresponse=1, + encoding=encoding, + allow_none=allow_none)) def on_error(*args, **kwargs): body = str(sys.exc_info()[1]) - from xmlrpclib import Fault, dumps - _set_response(dumps(Fault(1, body))) + xmlrpclib = get_xmlrpclib() + _set_response(xmlrpclib.dumps(xmlrpclib.Fault(1, body))) diff --git a/cherrypy/process/plugins.py b/cherrypy/process/plugins.py index 488958eb..ba618a0b 100644 --- a/cherrypy/process/plugins.py +++ b/cherrypy/process/plugins.py @@ -453,13 +453,14 @@ class BackgroundTask(threading.Thread): it won't delay stopping the whole process. """ - def __init__(self, interval, function, args=[], kwargs={}): + def __init__(self, interval, function, args=[], kwargs={}, bus=None): threading.Thread.__init__(self) self.interval = interval self.function = function self.args = args self.kwargs = kwargs self.running = False + self.bus = bus def cancel(self): self.running = False @@ -473,8 +474,9 @@ class BackgroundTask(threading.Thread): try: self.function(*self.args, **self.kwargs) except Exception: - self.bus.log("Error in background task thread function %r." % - self.function, level=40, traceback=True) + if self.bus: + self.bus.log("Error in background task thread function %r." + % self.function, level=40, traceback=True) # Quit on first error to avoid massive logs. 
raise @@ -506,8 +508,8 @@ class Monitor(SimplePlugin): if self.frequency > 0: threadname = self.name or self.__class__.__name__ if self.thread is None: - self.thread = BackgroundTask(self.frequency, self.callback) - self.thread.bus = self.bus + self.thread = BackgroundTask(self.frequency, self.callback, + bus = self.bus) self.thread.setName(threadname) self.thread.start() self.bus.log("Started monitor thread %r." % threadname) diff --git a/cherrypy/process/servers.py b/cherrypy/process/servers.py index 272e8436..fa714d65 100644 --- a/cherrypy/process/servers.py +++ b/cherrypy/process/servers.py @@ -385,34 +385,43 @@ def check_port(host, port, timeout=1.0): if s: s.close() -def wait_for_free_port(host, port): + +# Feel free to increase these defaults on slow systems: +free_port_timeout = 0.1 +occupied_port_timeout = 1.0 + +def wait_for_free_port(host, port, timeout=None): """Wait for the specified port to become free (drop requests).""" if not host: raise ValueError("Host values of '' or None are not allowed.") + if timeout is None: + timeout = free_port_timeout for trial in range(50): try: # we are expecting a free port, so reduce the timeout - check_port(host, port, timeout=0.1) + check_port(host, port, timeout=timeout) except IOError: # Give the old server thread time to free the port. 
- time.sleep(0.1) + time.sleep(timeout) else: return raise IOError("Port %r not free on %r" % (port, host)) -def wait_for_occupied_port(host, port): +def wait_for_occupied_port(host, port, timeout=None): """Wait for the specified port to become active (receive requests).""" if not host: raise ValueError("Host values of '' or None are not allowed.") + if timeout is None: + timeout = occupied_port_timeout for trial in range(50): try: - check_port(host, port) + check_port(host, port, timeout=timeout) except IOError: return else: - time.sleep(.1) + time.sleep(timeout) raise IOError("Port %r not bound on %r" % (port, host)) diff --git a/cherrypy/process/wspbus.py b/cherrypy/process/wspbus.py index 46cd75a2..6ef768dc 100644 --- a/cherrypy/process/wspbus.py +++ b/cherrypy/process/wspbus.py @@ -90,11 +90,11 @@ class ChannelFailures(Exception): def handle_exception(self): """Append the current exception to self.""" - self._exceptions.append(sys.exc_info()) + self._exceptions.append(sys.exc_info()[1]) def get_instances(self): """Return a list of seen exception instances.""" - return [instance for cls, instance, traceback in self._exceptions] + return self._exceptions[:] def __str__(self): exception_strings = map(repr, self.get_instances()) @@ -102,8 +102,9 @@ class ChannelFailures(Exception): __repr__ = __str__ - def __nonzero__(self): + def __bool__(self): return bool(self._exceptions) + __nonzero__ = __bool__ # Use a flag to indicate the state of the bus. class _StateEnum(object): @@ -124,6 +125,17 @@ states.STOPPING = states.State() states.EXITING = states.State() +try: + import fcntl +except ImportError: + max_files = 0 +else: + try: + max_files = os.sysconf('SC_OPEN_MAX') + except AttributeError: + max_files = 1024 + + class Bus(object): """Process state-machine and messenger for HTTP site deployment. 
@@ -137,6 +149,7 @@ class Bus(object): states = states state = states.STOPPED execv = False + max_cloexec_files = max_files def __init__(self): self.execv = False @@ -173,13 +186,19 @@ class Bus(object): items = [(self._priorities[(channel, listener)], listener) for listener in self.listeners[channel]] - items.sort() + try: + items.sort(key=lambda item: item[0]) + except TypeError: + # Python 2.3 had no 'key' arg, but that doesn't matter + # since it could sort dissimilar types just fine. + items.sort() for priority, listener in items: try: output.append(listener(*args, **kwargs)) except KeyboardInterrupt: raise - except SystemExit, e: + except SystemExit: + e = sys.exc_info()[1] # If we have previous errors ensure the exit code is non-zero if exc and e.code == 0: e.code = 1 @@ -221,13 +240,14 @@ class Bus(object): except: self.log("Shutting down due to error in start listener:", level=40, traceback=True) - e_info = sys.exc_info() + e_info = sys.exc_info()[1] try: self.exit() except: # Any stop/exit errors will be logged inside publish(). pass - raise e_info[0], e_info[1], e_info[2] + # Re-raise the original error + raise e_info def exit(self): """Stop all services and prepare to exit the process.""" @@ -354,8 +374,28 @@ class Bus(object): args = ['"%s"' % arg for arg in args] os.chdir(_startup_cwd) + if self.max_cloexec_files: + self._set_cloexec() os.execv(sys.executable, args) + def _set_cloexec(self): + """Set the CLOEXEC flag on all open files (except stdin/out/err). + + If self.max_cloexec_files is an integer (the default), then on + platforms which support it, it represents the max open files setting + for the operating system. This function will be called just before + the process is restarted via os.execv() to prevent open files + from persisting into the new process. + + Set self.max_cloexec_files to 0 to disable this behavior. 
+ """ + for fd in range(3, self.max_cloexec_files): # skip stdin/out/err + try: + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + except IOError: + continue + fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) + def stop(self): """Stop all services.""" self.state = states.STOPPING @@ -386,8 +426,7 @@ class Bus(object): def log(self, msg="", level=20, traceback=False): """Log the given message. Append the last traceback if requested.""" if traceback: - exc = sys.exc_info() - msg += "\n" + "".join(_traceback.format_exception(*exc)) + msg += "\n" + "".join(_traceback.format_exception(*sys.exc_info())) self.publish('log', msg, level) bus = Bus() diff --git a/cherrypy/test/__init__.py b/cherrypy/test/__init__.py deleted file mode 100644 index e4c400d6..00000000 --- a/cherrypy/test/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Regression test suite for CherryPy. - -Run 'nosetests -s test/' to exercise all tests. - -The '-s' flag instructs nose to output stdout messages, wihch is crucial to -the 'interactive' mode of webtest.py. If you run these tests without the '-s' -flag, don't be surprised if the test seems to hang: it's waiting for your -interactive input. -""" - -import sys -def newexit(): - raise SystemExit('Exit called') - -def setup(): - # We want to monkey patch sys.exit so that we can get some - # information about where exit is being called. - newexit._old = sys.exit - sys.exit = newexit - -def teardown(): - try: - sys.exit = sys.exit._old - except AttributeError: - sys.exit = sys._exit diff --git a/cherrypy/test/_test_decorators.py b/cherrypy/test/_test_decorators.py deleted file mode 100644 index 5bcbc1e6..00000000 --- a/cherrypy/test/_test_decorators.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Test module for the @-decorator syntax, which is version-specific""" - -from cherrypy import expose, tools -from cherrypy._cpcompat import ntob - - -class ExposeExamples(object): - - @expose - def no_call(self): - return "Mr E. R. 
Bradshaw" - - @expose() - def call_empty(self): - return "Mrs. B.J. Smegma" - - @expose("call_alias") - def nesbitt(self): - return "Mr Nesbitt" - - @expose(["alias1", "alias2"]) - def andrews(self): - return "Mr Ken Andrews" - - @expose(alias="alias3") - def watson(self): - return "Mr. and Mrs. Watson" - - -class ToolExamples(object): - - @expose - @tools.response_headers(headers=[('Content-Type', 'application/data')]) - def blah(self): - yield ntob("blah") - # This is here to demonstrate that _cp_config = {...} overwrites - # the _cp_config attribute added by the Tool decorator. You have - # to write _cp_config[k] = v or _cp_config.update(...) instead. - blah._cp_config['response.stream'] = True - - diff --git a/cherrypy/test/_test_states_demo.py b/cherrypy/test/_test_states_demo.py deleted file mode 100644 index 3f8f196c..00000000 --- a/cherrypy/test/_test_states_demo.py +++ /dev/null @@ -1,66 +0,0 @@ -import os -import sys -import time -starttime = time.time() - -import cherrypy - - -class Root: - - def index(self): - return "Hello World" - index.exposed = True - - def mtimes(self): - return repr(cherrypy.engine.publish("Autoreloader", "mtimes")) - mtimes.exposed = True - - def pid(self): - return str(os.getpid()) - pid.exposed = True - - def start(self): - return repr(starttime) - start.exposed = True - - def exit(self): - # This handler might be called before the engine is STARTED if an - # HTTP worker thread handles it before the HTTP server returns - # control to engine.start. We avoid that race condition here - # by waiting for the Bus to be STARTED. 
- cherrypy.engine.wait(state=cherrypy.engine.states.STARTED) - cherrypy.engine.exit() - exit.exposed = True - - -def unsub_sig(): - cherrypy.log("unsubsig: %s" % cherrypy.config.get('unsubsig', False)) - if cherrypy.config.get('unsubsig', False): - cherrypy.log("Unsubscribing the default cherrypy signal handler") - cherrypy.engine.signal_handler.unsubscribe() - try: - from signal import signal, SIGTERM - except ImportError: - pass - else: - def old_term_handler(signum=None, frame=None): - cherrypy.log("I am an old SIGTERM handler.") - sys.exit(0) - cherrypy.log("Subscribing the new one.") - signal(SIGTERM, old_term_handler) -cherrypy.engine.subscribe('start', unsub_sig, priority=100) - - -def starterror(): - if cherrypy.config.get('starterror', False): - zerodiv = 1 / 0 -cherrypy.engine.subscribe('start', starterror, priority=6) - -def log_test_case_name(): - if cherrypy.config.get('test_case_name', False): - cherrypy.log("STARTED FROM: %s" % cherrypy.config.get('test_case_name')) -cherrypy.engine.subscribe('start', log_test_case_name, priority=6) - - -cherrypy.tree.mount(Root(), '/', {'/': {}}) diff --git a/cherrypy/test/benchmark.py b/cherrypy/test/benchmark.py deleted file mode 100644 index 90536a56..00000000 --- a/cherrypy/test/benchmark.py +++ /dev/null @@ -1,409 +0,0 @@ -"""CherryPy Benchmark Tool - - Usage: - benchmark.py --null --notests --help --cpmodpy --modpython --ab=path --apache=path - - --null: use a null Request object (to bench the HTTP server only) - --notests: start the server but do not run the tests; this allows - you to check the tested pages with a browser - --help: show this help message - --cpmodpy: run tests via apache on 8080 (with the builtin _cpmodpy) - --modpython: run tests via apache on 8080 (with modpython_gateway) - --ab=path: Use the ab script/executable at 'path' (see below) - --apache=path: Use the apache script/exe at 'path' (see below) - - To run the benchmarks, the Apache Benchmark tool "ab" must either be on - your system 
path, or specified via the --ab=path option. - - To run the modpython tests, the "apache" executable or script must be - on your system path, or provided via the --apache=path option. On some - platforms, "apache" may be called "apachectl" or "apache2ctl"--create - a symlink to them if needed. -""" - -import getopt -import os -curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) - -import re -import sys -import time -import traceback - -import cherrypy -from cherrypy._cpcompat import ntob -from cherrypy import _cperror, _cpmodpy -from cherrypy.lib import httputil - - -AB_PATH = "" -APACHE_PATH = "apache" -SCRIPT_NAME = "/cpbench/users/rdelon/apps/blog" - -__all__ = ['ABSession', 'Root', 'print_report', - 'run_standard_benchmarks', 'safe_threads', - 'size_report', 'startup', 'thread_report', - ] - -size_cache = {} - -class Root: - - def index(self): - return """ - - CherryPy Benchmark - - - - -""" - index.exposed = True - - def hello(self): - return "Hello, world\r\n" - hello.exposed = True - - def sizer(self, size): - resp = size_cache.get(size, None) - if resp is None: - size_cache[size] = resp = "X" * int(size) - return resp - sizer.exposed = True - - -cherrypy.config.update({ - 'log.error.file': '', - 'environment': 'production', - 'server.socket_host': '127.0.0.1', - 'server.socket_port': 8080, - 'server.max_request_header_size': 0, - 'server.max_request_body_size': 0, - 'engine.deadlock_poll_freq': 0, - }) - -# Cheat mode on ;) -del cherrypy.config['tools.log_tracebacks.on'] -del cherrypy.config['tools.log_headers.on'] -del cherrypy.config['tools.trailing_slash.on'] - -appconf = { - '/static': { - 'tools.staticdir.on': True, - 'tools.staticdir.dir': 'static', - 'tools.staticdir.root': curdir, - }, - } -app = cherrypy.tree.mount(Root(), SCRIPT_NAME, appconf) - - -class NullRequest: - """A null HTTP request class, returning 200 and an empty body.""" - - def __init__(self, local, remote, scheme="http"): - pass - - def close(self): - pass - - def 
run(self, method, path, query_string, protocol, headers, rfile): - cherrypy.response.status = "200 OK" - cherrypy.response.header_list = [("Content-Type", 'text/html'), - ("Server", "Null CherryPy"), - ("Date", httputil.HTTPDate()), - ("Content-Length", "0"), - ] - cherrypy.response.body = [""] - return cherrypy.response - - -class NullResponse: - pass - - -class ABSession: - """A session of 'ab', the Apache HTTP server benchmarking tool. - -Example output from ab: - -This is ApacheBench, Version 2.0.40-dev <$Revision: 1.121.2.1 $> apache-2.0 -Copyright (c) 1996 Adam Twiss, Zeus Technology Ltd, http://www.zeustech.net/ -Copyright (c) 1998-2002 The Apache Software Foundation, http://www.apache.org/ - -Benchmarking 127.0.0.1 (be patient) -Completed 100 requests -Completed 200 requests -Completed 300 requests -Completed 400 requests -Completed 500 requests -Completed 600 requests -Completed 700 requests -Completed 800 requests -Completed 900 requests - - -Server Software: CherryPy/3.1beta -Server Hostname: 127.0.0.1 -Server Port: 8080 - -Document Path: /static/index.html -Document Length: 14 bytes - -Concurrency Level: 10 -Time taken for tests: 9.643867 seconds -Complete requests: 1000 -Failed requests: 0 -Write errors: 0 -Total transferred: 189000 bytes -HTML transferred: 14000 bytes -Requests per second: 103.69 [#/sec] (mean) -Time per request: 96.439 [ms] (mean) -Time per request: 9.644 [ms] (mean, across all concurrent requests) -Transfer rate: 19.08 [Kbytes/sec] received - -Connection Times (ms) - min mean[+/-sd] median max -Connect: 0 0 2.9 0 10 -Processing: 20 94 7.3 90 130 -Waiting: 0 43 28.1 40 100 -Total: 20 95 7.3 100 130 - -Percentage of the requests served within a certain time (ms) - 50% 100 - 66% 100 - 75% 100 - 80% 100 - 90% 100 - 95% 100 - 98% 100 - 99% 110 - 100% 130 (longest request) -Finished 1000 requests -""" - - parse_patterns = [('complete_requests', 'Completed', - ntob(r'^Complete requests:\s*(\d+)')), - ('failed_requests', 'Failed', - 
ntob(r'^Failed requests:\s*(\d+)')), - ('requests_per_second', 'req/sec', - ntob(r'^Requests per second:\s*([0-9.]+)')), - ('time_per_request_concurrent', 'msec/req', - ntob(r'^Time per request:\s*([0-9.]+).*concurrent requests\)$')), - ('transfer_rate', 'KB/sec', - ntob(r'^Transfer rate:\s*([0-9.]+)')), - ] - - def __init__(self, path=SCRIPT_NAME + "/hello", requests=1000, concurrency=10): - self.path = path - self.requests = requests - self.concurrency = concurrency - - def args(self): - port = cherrypy.server.socket_port - assert self.concurrency > 0 - assert self.requests > 0 - # Don't use "localhost". - # Cf http://mail.python.org/pipermail/python-win32/2008-March/007050.html - return ("-k -n %s -c %s http://127.0.0.1:%s%s" % - (self.requests, self.concurrency, port, self.path)) - - def run(self): - # Parse output of ab, setting attributes on self - try: - self.output = _cpmodpy.read_process(AB_PATH or "ab", self.args()) - except: - print(_cperror.format_exc()) - raise - - for attr, name, pattern in self.parse_patterns: - val = re.search(pattern, self.output, re.MULTILINE) - if val: - val = val.group(1) - setattr(self, attr, val) - else: - setattr(self, attr, None) - - -safe_threads = (25, 50, 100, 200, 400) -if sys.platform in ("win32",): - # For some reason, ab crashes with > 50 threads on my Win2k laptop. - safe_threads = (10, 20, 30, 40, 50) - - -def thread_report(path=SCRIPT_NAME + "/hello", concurrency=safe_threads): - sess = ABSession(path) - attrs, names, patterns = list(zip(*sess.parse_patterns)) - avg = dict.fromkeys(attrs, 0.0) - - yield ('threads',) + names - for c in concurrency: - sess.concurrency = c - sess.run() - row = [c] - for attr in attrs: - val = getattr(sess, attr) - if val is None: - print(sess.output) - row = None - break - val = float(val) - avg[attr] += float(val) - row.append(val) - if row: - yield row - - # Add a row of averages. 
- yield ["Average"] + [str(avg[attr] / len(concurrency)) for attr in attrs] - -def size_report(sizes=(10, 100, 1000, 10000, 100000, 100000000), - concurrency=50): - sess = ABSession(concurrency=concurrency) - attrs, names, patterns = list(zip(*sess.parse_patterns)) - yield ('bytes',) + names - for sz in sizes: - sess.path = "%s/sizer?size=%s" % (SCRIPT_NAME, sz) - sess.run() - yield [sz] + [getattr(sess, attr) for attr in attrs] - -def print_report(rows): - for row in rows: - print("") - for i, val in enumerate(row): - sys.stdout.write(str(val).rjust(10) + " | ") - print("") - - -def run_standard_benchmarks(): - print("") - print("Client Thread Report (1000 requests, 14 byte response body, " - "%s server threads):" % cherrypy.server.thread_pool) - print_report(thread_report()) - - print("") - print("Client Thread Report (1000 requests, 14 bytes via staticdir, " - "%s server threads):" % cherrypy.server.thread_pool) - print_report(thread_report("%s/static/index.html" % SCRIPT_NAME)) - - print("") - print("Size Report (1000 requests, 50 client threads, " - "%s server threads):" % cherrypy.server.thread_pool) - print_report(size_report()) - - -# modpython and other WSGI # - -def startup_modpython(req=None): - """Start the CherryPy app server in 'serverless' mode (for modpython/WSGI).""" - if cherrypy.engine.state == cherrypy._cpengine.STOPPED: - if req: - if "nullreq" in req.get_options(): - cherrypy.engine.request_class = NullRequest - cherrypy.engine.response_class = NullResponse - ab_opt = req.get_options().get("ab", "") - if ab_opt: - global AB_PATH - AB_PATH = ab_opt - cherrypy.engine.start() - if cherrypy.engine.state == cherrypy._cpengine.STARTING: - cherrypy.engine.wait() - return 0 # apache.OK - - -def run_modpython(use_wsgi=False): - print("Starting mod_python...") - pyopts = [] - - # Pass the null and ab=path options through Apache - if "--null" in opts: - pyopts.append(("nullreq", "")) - - if "--ab" in opts: - pyopts.append(("ab", opts["--ab"])) - - s = 
_cpmodpy.ModPythonServer - if use_wsgi: - pyopts.append(("wsgi.application", "cherrypy::tree")) - pyopts.append(("wsgi.startup", "cherrypy.test.benchmark::startup_modpython")) - handler = "modpython_gateway::handler" - s = s(port=8080, opts=pyopts, apache_path=APACHE_PATH, handler=handler) - else: - pyopts.append(("cherrypy.setup", "cherrypy.test.benchmark::startup_modpython")) - s = s(port=8080, opts=pyopts, apache_path=APACHE_PATH) - - try: - s.start() - run() - finally: - s.stop() - - - -if __name__ == '__main__': - longopts = ['cpmodpy', 'modpython', 'null', 'notests', - 'help', 'ab=', 'apache='] - try: - switches, args = getopt.getopt(sys.argv[1:], "", longopts) - opts = dict(switches) - except getopt.GetoptError: - print(__doc__) - sys.exit(2) - - if "--help" in opts: - print(__doc__) - sys.exit(0) - - if "--ab" in opts: - AB_PATH = opts['--ab'] - - if "--notests" in opts: - # Return without stopping the server, so that the pages - # can be tested from a standard web browser. - def run(): - port = cherrypy.server.socket_port - print("You may now open http://127.0.0.1:%s%s/" % - (port, SCRIPT_NAME)) - - if "--null" in opts: - print("Using null Request object") - else: - def run(): - end = time.time() - start - print("Started in %s seconds" % end) - if "--null" in opts: - print("\nUsing null Request object") - try: - try: - run_standard_benchmarks() - except: - print(_cperror.format_exc()) - raise - finally: - cherrypy.engine.exit() - - print("Starting CherryPy app server...") - - class NullWriter(object): - """Suppresses the printing of socket errors.""" - def write(self, data): - pass - sys.stderr = NullWriter() - - start = time.time() - - if "--cpmodpy" in opts: - run_modpython() - elif "--modpython" in opts: - run_modpython(use_wsgi=True) - else: - if "--null" in opts: - cherrypy.server.request_class = NullRequest - cherrypy.server.response_class = NullResponse - - cherrypy.engine.start_with_callback(run) - cherrypy.engine.block() diff --git 
a/cherrypy/test/checkerdemo.py b/cherrypy/test/checkerdemo.py deleted file mode 100644 index 32a7dee2..00000000 --- a/cherrypy/test/checkerdemo.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Demonstration app for cherrypy.checker. - -This application is intentionally broken and badly designed. -To demonstrate the output of the CherryPy Checker, simply execute -this module. -""" - -import os -import cherrypy -thisdir = os.path.dirname(os.path.abspath(__file__)) - -class Root: - pass - -if __name__ == '__main__': - conf = {'/base': {'tools.staticdir.root': thisdir, - # Obsolete key. - 'throw_errors': True, - }, - # This entry should be OK. - '/base/static': {'tools.staticdir.on': True, - 'tools.staticdir.dir': 'static'}, - # Warn on missing folder. - '/base/js': {'tools.staticdir.on': True, - 'tools.staticdir.dir': 'js'}, - # Warn on dir with an abs path even though we provide root. - '/base/static2': {'tools.staticdir.on': True, - 'tools.staticdir.dir': '/static'}, - # Warn on dir with a relative path with no root. - '/static3': {'tools.staticdir.on': True, - 'tools.staticdir.dir': 'static'}, - # Warn on unknown namespace - '/unknown': {'toobles.gzip.on': True}, - # Warn special on cherrypy..* - '/cpknown': {'cherrypy.tools.encode.on': True}, - # Warn on mismatched types - '/conftype': {'request.show_tracebacks': 14}, - # Warn on unknown tool. - '/web': {'tools.unknown.on': True}, - # Warn on server.* in app config. - '/app1': {'server.socket_host': '0.0.0.0'}, - # Warn on 'localhost' - 'global': {'server.socket_host': 'localhost'}, - # Warn on '[name]' - '[/extra_brackets]': {}, - } - cherrypy.quickstart(Root(), config=conf) diff --git a/cherrypy/test/fastcgi.conf b/cherrypy/test/fastcgi.conf deleted file mode 100644 index e5c5163c..00000000 --- a/cherrypy/test/fastcgi.conf +++ /dev/null @@ -1,18 +0,0 @@ - -# Apache2 server conf file for testing CherryPy with mod_fastcgi. 
-# fumanchu: I had to hard-code paths due to crazy Debian layouts :( -ServerRoot /usr/lib/apache2 -User #1000 -ErrorLog /usr/lib/python2.5/site-packages/cproot/trunk/cherrypy/test/mod_fastcgi.error.log - -DocumentRoot "/usr/lib/python2.5/site-packages/cproot/trunk/cherrypy/test" -ServerName 127.0.0.1 -Listen 8080 -LoadModule fastcgi_module modules/mod_fastcgi.so -LoadModule rewrite_module modules/mod_rewrite.so - -Options +ExecCGI -SetHandler fastcgi-script -RewriteEngine On -RewriteRule ^(.*)$ /fastcgi.pyc [L] -FastCgiExternalServer "/usr/lib/python2.5/site-packages/cproot/trunk/cherrypy/test/fastcgi.pyc" -host 127.0.0.1:4000 diff --git a/cherrypy/test/fcgi.conf b/cherrypy/test/fcgi.conf deleted file mode 100644 index 8cf24b64..00000000 --- a/cherrypy/test/fcgi.conf +++ /dev/null @@ -1,14 +0,0 @@ - -# Apache2 server conf file for testing CherryPy with mod_fcgid. - -DocumentRoot "C:\Python25\Lib\site-packages\cherrypy\test" -ServerName 127.0.0.1 -Listen 8080 -LoadModule fastcgi_module modules/mod_fastcgi.dll -LoadModule rewrite_module modules/mod_rewrite.so - -Options ExecCGI -SetHandler fastcgi-script -RewriteEngine On -RewriteRule ^(.*)$ /fastcgi.pyc [L] -FastCgiExternalServer "C:\\Python25\\Lib\\site-packages\\cherrypy\\test\\fastcgi.pyc" -host 127.0.0.1:4000 diff --git a/cherrypy/test/helper.py b/cherrypy/test/helper.py deleted file mode 100644 index ff9e06cf..00000000 --- a/cherrypy/test/helper.py +++ /dev/null @@ -1,476 +0,0 @@ -"""A library of helper functions for the CherryPy test suite.""" - -import datetime -import logging -log = logging.getLogger(__name__) -import os -thisdir = os.path.abspath(os.path.dirname(__file__)) -serverpem = os.path.join(os.getcwd(), thisdir, 'test.pem') - -import re -import sys -import time -import warnings - -import cherrypy -from cherrypy._cpcompat import basestring, copyitems, HTTPSConnection, ntob -from cherrypy.lib import httputil -from cherrypy.lib.reprconf import unrepr -from cherrypy.test import webtest - -import nose - 
-_testconfig = None - -def get_tst_config(overconf = {}): - global _testconfig - if _testconfig is None: - conf = { - 'scheme': 'http', - 'protocol': "HTTP/1.1", - 'port': 8080, - 'host': '127.0.0.1', - 'validate': False, - 'conquer': False, - 'server': 'wsgi', - } - try: - import testconfig - _conf = testconfig.config.get('supervisor', None) - if _conf is not None: - for k, v in _conf.items(): - if isinstance(v, basestring): - _conf[k] = unrepr(v) - conf.update(_conf) - except ImportError: - pass - _testconfig = conf - conf = _testconfig.copy() - conf.update(overconf) - - return conf - -class Supervisor(object): - """Base class for modeling and controlling servers during testing.""" - - def __init__(self, **kwargs): - for k, v in kwargs.items(): - if k == 'port': - setattr(self, k, int(v)) - setattr(self, k, v) - - -log_to_stderr = lambda msg, level: sys.stderr.write(msg + os.linesep) - -class LocalSupervisor(Supervisor): - """Base class for modeling/controlling servers which run in the same process. - - When the server side runs in a different process, start/stop can dump all - state between each test module easily. When the server side runs in the - same process as the client, however, we have to do a bit more work to ensure - config and mounted apps are reset between tests. - """ - - using_apache = False - using_wsgi = False - - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - cherrypy.server.httpserver = self.httpserver_class - - engine = cherrypy.engine - if hasattr(engine, "signal_handler"): - engine.signal_handler.subscribe() - if hasattr(engine, "console_control_handler"): - engine.console_control_handler.subscribe() - #engine.subscribe('log', log_to_stderr) - - def start(self, modulename=None): - """Load and start the HTTP server.""" - if modulename: - # Unhook httpserver so cherrypy.server.start() creates a new - # one (with config from setup_server, if declared). 
- cherrypy.server.httpserver = None - - cherrypy.engine.start() - - self.sync_apps() - - def sync_apps(self): - """Tell the server about any apps which the setup functions mounted.""" - pass - - def stop(self): - td = getattr(self, 'teardown', None) - if td: - td() - - cherrypy.engine.exit() - - for name, server in copyitems(getattr(cherrypy, 'servers', {})): - server.unsubscribe() - del cherrypy.servers[name] - - -class NativeServerSupervisor(LocalSupervisor): - """Server supervisor for the builtin HTTP server.""" - - httpserver_class = "cherrypy._cpnative_server.CPHTTPServer" - using_apache = False - using_wsgi = False - - def __str__(self): - return "Builtin HTTP Server on %s:%s" % (self.host, self.port) - - -class LocalWSGISupervisor(LocalSupervisor): - """Server supervisor for the builtin WSGI server.""" - - httpserver_class = "cherrypy._cpwsgi_server.CPWSGIServer" - using_apache = False - using_wsgi = True - - def __str__(self): - return "Builtin WSGI Server on %s:%s" % (self.host, self.port) - - def sync_apps(self): - """Hook a new WSGI app into the origin server.""" - cherrypy.server.httpserver.wsgi_app = self.get_app() - - def get_app(self, app=None): - """Obtain a new (decorated) WSGI app to hook into the origin server.""" - if app is None: - app = cherrypy.tree - - if self.conquer: - try: - import wsgiconq - except ImportError: - warnings.warn("Error importing wsgiconq. pyconquer will not run.") - else: - app = wsgiconq.WSGILogger(app, c_calls=True) - - if self.validate: - try: - from wsgiref import validate - except ImportError: - warnings.warn("Error importing wsgiref. 
The validator will not run.") - else: - #wraps the app in the validator - app = validate.validator(app) - - return app - - -def get_cpmodpy_supervisor(**options): - from cherrypy.test import modpy - sup = modpy.ModPythonSupervisor(**options) - sup.template = modpy.conf_cpmodpy - return sup - -def get_modpygw_supervisor(**options): - from cherrypy.test import modpy - sup = modpy.ModPythonSupervisor(**options) - sup.template = modpy.conf_modpython_gateway - sup.using_wsgi = True - return sup - -def get_modwsgi_supervisor(**options): - from cherrypy.test import modwsgi - return modwsgi.ModWSGISupervisor(**options) - -def get_modfcgid_supervisor(**options): - from cherrypy.test import modfcgid - return modfcgid.ModFCGISupervisor(**options) - -def get_modfastcgi_supervisor(**options): - from cherrypy.test import modfastcgi - return modfastcgi.ModFCGISupervisor(**options) - -def get_wsgi_u_supervisor(**options): - cherrypy.server.wsgi_version = ('u', 0) - return LocalWSGISupervisor(**options) - - -class CPWebCase(webtest.WebCase): - - script_name = "" - scheme = "http" - - available_servers = {'wsgi': LocalWSGISupervisor, - 'wsgi_u': get_wsgi_u_supervisor, - 'native': NativeServerSupervisor, - 'cpmodpy': get_cpmodpy_supervisor, - 'modpygw': get_modpygw_supervisor, - 'modwsgi': get_modwsgi_supervisor, - 'modfcgid': get_modfcgid_supervisor, - 'modfastcgi': get_modfastcgi_supervisor, - } - default_server = "wsgi" - - def _setup_server(cls, supervisor, conf): - v = sys.version.split()[0] - log.info("Python version used to run this test script: %s" % v) - log.info("CherryPy version: %s" % cherrypy.__version__) - if supervisor.scheme == "https": - ssl = " (ssl)" - else: - ssl = "" - log.info("HTTP server version: %s%s" % (supervisor.protocol, ssl)) - log.info("PID: %s" % os.getpid()) - - cherrypy.server.using_apache = supervisor.using_apache - cherrypy.server.using_wsgi = supervisor.using_wsgi - - if sys.platform[:4] == 'java': - cherrypy.config.update({'server.nodelay': 
False}) - - if isinstance(conf, basestring): - parser = cherrypy.lib.reprconf.Parser() - conf = parser.dict_from_file(conf).get('global', {}) - else: - conf = conf or {} - baseconf = conf.copy() - baseconf.update({'server.socket_host': supervisor.host, - 'server.socket_port': supervisor.port, - 'server.protocol_version': supervisor.protocol, - 'environment': "test_suite", - }) - if supervisor.scheme == "https": - #baseconf['server.ssl_module'] = 'builtin' - baseconf['server.ssl_certificate'] = serverpem - baseconf['server.ssl_private_key'] = serverpem - - # helper must be imported lazily so the coverage tool - # can run against module-level statements within cherrypy. - # Also, we have to do "from cherrypy.test import helper", - # exactly like each test module does, because a relative import - # would stick a second instance of webtest in sys.modules, - # and we wouldn't be able to globally override the port anymore. - if supervisor.scheme == "https": - webtest.WebCase.HTTP_CONN = HTTPSConnection - return baseconf - _setup_server = classmethod(_setup_server) - - def setup_class(cls): - '' - #Creates a server - conf = get_tst_config() - supervisor_factory = cls.available_servers.get(conf.get('server', 'wsgi')) - if supervisor_factory is None: - raise RuntimeError('Unknown server in config: %s' % conf['server']) - supervisor = supervisor_factory(**conf) - - #Copied from "run_test_suite" - cherrypy.config.reset() - baseconf = cls._setup_server(supervisor, conf) - cherrypy.config.update(baseconf) - setup_client() - - if hasattr(cls, 'setup_server'): - # Clear the cherrypy tree and clear the wsgi server so that - # it can be updated with the new root - cherrypy.tree = cherrypy._cptree.Tree() - cherrypy.server.httpserver = None - cls.setup_server() - supervisor.start(cls.__module__) - - cls.supervisor = supervisor - setup_class = classmethod(setup_class) - - def teardown_class(cls): - '' - if hasattr(cls, 'setup_server'): - cls.supervisor.stop() - teardown_class = 
classmethod(teardown_class) - - def prefix(self): - return self.script_name.rstrip("/") - - def base(self): - if ((self.scheme == "http" and self.PORT == 80) or - (self.scheme == "https" and self.PORT == 443)): - port = "" - else: - port = ":%s" % self.PORT - - return "%s://%s%s%s" % (self.scheme, self.HOST, port, - self.script_name.rstrip("/")) - - def exit(self): - sys.exit() - - def getPage(self, url, headers=None, method="GET", body=None, protocol=None): - """Open the url. Return status, headers, body.""" - if self.script_name: - url = httputil.urljoin(self.script_name, url) - return webtest.WebCase.getPage(self, url, headers, method, body, protocol) - - def skip(self, msg='skipped '): - raise nose.SkipTest(msg) - - def assertErrorPage(self, status, message=None, pattern=''): - """Compare the response body with a built in error page. - - The function will optionally look for the regexp pattern, - within the exception embedded in the error page.""" - - # This will never contain a traceback - page = cherrypy._cperror.get_error_page(status, message=message) - - # First, test the response body without checking the traceback. - # Stick a match-all group (.*) in to grab the traceback. - esc = re.escape - epage = esc(page) - epage = epage.replace(esc('
'),
-                              esc('
') + '(.*)' + esc('
')) - m = re.match(ntob(epage, self.encoding), self.body, re.DOTALL) - if not m: - self._handlewebError('Error page does not match; expected:\n' + page) - return - - # Now test the pattern against the traceback - if pattern is None: - # Special-case None to mean that there should be *no* traceback. - if m and m.group(1): - self._handlewebError('Error page contains traceback') - else: - if (m is None) or ( - not re.search(ntob(re.escape(pattern), self.encoding), - m.group(1))): - msg = 'Error page does not contain %s in traceback' - self._handlewebError(msg % repr(pattern)) - - date_tolerance = 2 - - def assertEqualDates(self, dt1, dt2, seconds=None): - """Assert abs(dt1 - dt2) is within Y seconds.""" - if seconds is None: - seconds = self.date_tolerance - - if dt1 > dt2: - diff = dt1 - dt2 - else: - diff = dt2 - dt1 - if not diff < datetime.timedelta(seconds=seconds): - raise AssertionError('%r and %r are not within %r seconds.' % - (dt1, dt2, seconds)) - - -def setup_client(): - """Set up the WebCase classes to match the server's socket settings.""" - webtest.WebCase.PORT = cherrypy.server.socket_port - webtest.WebCase.HOST = cherrypy.server.socket_host - if cherrypy.server.ssl_certificate: - CPWebCase.scheme = 'https' - -# --------------------------- Spawning helpers --------------------------- # - - -class CPProcess(object): - - pid_file = os.path.join(thisdir, 'test.pid') - config_file = os.path.join(thisdir, 'test.conf') - config_template = """[global] -server.socket_host: '%(host)s' -server.socket_port: %(port)s -checker.on: False -log.screen: False -log.error_file: r'%(error_log)s' -log.access_file: r'%(access_log)s' -%(ssl)s -%(extra)s -""" - error_log = os.path.join(thisdir, 'test.error.log') - access_log = os.path.join(thisdir, 'test.access.log') - - def __init__(self, wait=False, daemonize=False, ssl=False, socket_host=None, socket_port=None): - self.wait = wait - self.daemonize = daemonize - self.ssl = ssl - self.host = socket_host or 
cherrypy.server.socket_host - self.port = socket_port or cherrypy.server.socket_port - - def write_conf(self, extra=""): - if self.ssl: - serverpem = os.path.join(thisdir, 'test.pem') - ssl = """ -server.ssl_certificate: r'%s' -server.ssl_private_key: r'%s' -""" % (serverpem, serverpem) - else: - ssl = "" - - conf = self.config_template % { - 'host': self.host, - 'port': self.port, - 'error_log': self.error_log, - 'access_log': self.access_log, - 'ssl': ssl, - 'extra': extra, - } - f = open(self.config_file, 'wb') - f.write(ntob(conf, 'utf-8')) - f.close() - - def start(self, imports=None): - """Start cherryd in a subprocess.""" - cherrypy._cpserver.wait_for_free_port(self.host, self.port) - - args = [sys.executable, os.path.join(thisdir, '..', 'cherryd'), - '-c', self.config_file, '-p', self.pid_file] - - if not isinstance(imports, (list, tuple)): - imports = [imports] - for i in imports: - if i: - args.append('-i') - args.append(i) - - if self.daemonize: - args.append('-d') - - env = os.environ.copy() - # Make sure we import the cherrypy package in which this module is defined. - grandparentdir = os.path.abspath(os.path.join(thisdir, '..', '..')) - if env.get('PYTHONPATH', ''): - env['PYTHONPATH'] = os.pathsep.join((grandparentdir, env['PYTHONPATH'])) - else: - env['PYTHONPATH'] = grandparentdir - if self.wait: - self.exit_code = os.spawnve(os.P_WAIT, sys.executable, args, env) - else: - os.spawnve(os.P_NOWAIT, sys.executable, args, env) - cherrypy._cpserver.wait_for_occupied_port(self.host, self.port) - - # Give the engine a wee bit more time to finish STARTING - if self.daemonize: - time.sleep(2) - else: - time.sleep(1) - - def get_pid(self): - return int(open(self.pid_file, 'rb').read()) - - def join(self): - """Wait for the process to exit.""" - try: - try: - # Mac, UNIX - os.wait() - except AttributeError: - # Windows - try: - pid = self.get_pid() - except IOError: - # Assume the subprocess deleted the pidfile on shutdown. 
- pass - else: - os.waitpid(pid, 0) - except OSError: - x = sys.exc_info()[1] - if x.args != (10, 'No child processes'): - raise - diff --git a/cherrypy/test/logtest.py b/cherrypy/test/logtest.py deleted file mode 100644 index c093da2c..00000000 --- a/cherrypy/test/logtest.py +++ /dev/null @@ -1,181 +0,0 @@ -"""logtest, a unittest.TestCase helper for testing log output.""" - -import sys -import time - -import cherrypy - - -try: - # On Windows, msvcrt.getch reads a single char without output. - import msvcrt - def getchar(): - return msvcrt.getch() -except ImportError: - # Unix getchr - import tty, termios - def getchar(): - fd = sys.stdin.fileno() - old_settings = termios.tcgetattr(fd) - try: - tty.setraw(sys.stdin.fileno()) - ch = sys.stdin.read(1) - finally: - termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) - return ch - - -class LogCase(object): - """unittest.TestCase mixin for testing log messages. - - logfile: a filename for the desired log. Yes, I know modes are evil, - but it makes the test functions so much cleaner to set this once. - - lastmarker: the last marker in the log. This can be used to search for - messages since the last marker. - - markerPrefix: a string with which to prefix log markers. This should be - unique enough from normal log output to use for marker identification. 
- """ - - logfile = None - lastmarker = None - markerPrefix = "test suite marker: " - - def _handleLogError(self, msg, data, marker, pattern): - print("") - print(" ERROR: %s" % msg) - - if not self.interactive: - raise self.failureException(msg) - - p = " Show: [L]og [M]arker [P]attern; [I]gnore, [R]aise, or sys.e[X]it >> " - print p, - # ARGH - sys.stdout.flush() - while True: - i = getchar().upper() - if i not in "MPLIRX": - continue - print(i.upper()) # Also prints new line - if i == "L": - for x, line in enumerate(data): - if (x + 1) % self.console_height == 0: - # The \r and comma should make the next line overwrite - print "<-- More -->\r", - m = getchar().lower() - # Erase our "More" prompt - print " \r", - if m == "q": - break - print(line.rstrip()) - elif i == "M": - print(repr(marker or self.lastmarker)) - elif i == "P": - print(repr(pattern)) - elif i == "I": - # return without raising the normal exception - return - elif i == "R": - raise self.failureException(msg) - elif i == "X": - self.exit() - print p, - - def exit(self): - sys.exit() - - def emptyLog(self): - """Overwrite self.logfile with 0 bytes.""" - open(self.logfile, 'wb').write("") - - def markLog(self, key=None): - """Insert a marker line into the log and set self.lastmarker.""" - if key is None: - key = str(time.time()) - self.lastmarker = key - - open(self.logfile, 'ab+').write("%s%s\n" % (self.markerPrefix, key)) - - def _read_marked_region(self, marker=None): - """Return lines from self.logfile in the marked region. - - If marker is None, self.lastmarker is used. If the log hasn't - been marked (using self.markLog), the entire log will be returned. - """ -## # Give the logger time to finish writing? 
-## time.sleep(0.5) - - logfile = self.logfile - marker = marker or self.lastmarker - if marker is None: - return open(logfile, 'rb').readlines() - - data = [] - in_region = False - for line in open(logfile, 'rb'): - if in_region: - if (line.startswith(self.markerPrefix) and not marker in line): - break - else: - data.append(line) - elif marker in line: - in_region = True - return data - - def assertInLog(self, line, marker=None): - """Fail if the given (partial) line is not in the log. - - The log will be searched from the given marker to the next marker. - If marker is None, self.lastmarker is used. If the log hasn't - been marked (using self.markLog), the entire log will be searched. - """ - data = self._read_marked_region(marker) - for logline in data: - if line in logline: - return - msg = "%r not found in log" % line - self._handleLogError(msg, data, marker, line) - - def assertNotInLog(self, line, marker=None): - """Fail if the given (partial) line is in the log. - - The log will be searched from the given marker to the next marker. - If marker is None, self.lastmarker is used. If the log hasn't - been marked (using self.markLog), the entire log will be searched. - """ - data = self._read_marked_region(marker) - for logline in data: - if line in logline: - msg = "%r found in log" % line - self._handleLogError(msg, data, marker, line) - - def assertLog(self, sliceargs, lines, marker=None): - """Fail if log.readlines()[sliceargs] is not contained in 'lines'. - - The log will be searched from the given marker to the next marker. - If marker is None, self.lastmarker is used. If the log hasn't - been marked (using self.markLog), the entire log will be searched. - """ - data = self._read_marked_region(marker) - if isinstance(sliceargs, int): - # Single arg. Use __getitem__ and allow lines to be str or list. 
- if isinstance(lines, (tuple, list)): - lines = lines[0] - if lines not in data[sliceargs]: - msg = "%r not found on log line %r" % (lines, sliceargs) - self._handleLogError(msg, [data[sliceargs]], marker, lines) - else: - # Multiple args. Use __getslice__ and require lines to be list. - if isinstance(lines, tuple): - lines = list(lines) - elif isinstance(lines, basestring): - raise TypeError("The 'lines' arg must be a list when " - "'sliceargs' is a tuple.") - - start, stop = sliceargs - for line, logline in zip(lines, data[start:stop]): - if line not in logline: - msg = "%r not found in log" % line - self._handleLogError(msg, data[start:stop], marker, line) - diff --git a/cherrypy/test/modfastcgi.py b/cherrypy/test/modfastcgi.py deleted file mode 100644 index 95acf141..00000000 --- a/cherrypy/test/modfastcgi.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Wrapper for mod_fastcgi, for use as a CherryPy HTTP server when testing. - -To autostart fastcgi, the "apache" executable or script must be -on your system path, or you must override the global APACHE_PATH. -On some platforms, "apache" may be called "apachectl", "apache2ctl", -or "httpd"--create a symlink to them if needed. - -You'll also need the WSGIServer from flup.servers. -See http://projects.amor.org/misc/wiki/ModPythonGateway - - -KNOWN BUGS -========== - -1. Apache processes Range headers automatically; CherryPy's truncated - output is then truncated again by Apache. See test_core.testRanges. - This was worked around in http://www.cherrypy.org/changeset/1319. -2. Apache does not allow custom HTTP methods like CONNECT as per the spec. - See test_core.testHTTPMethods. -3. Max request header and body settings do not work with Apache. -4. Apache replaces status "reason phrases" automatically. For example, - CherryPy may set "304 Not modified" but Apache will write out - "304 Not Modified" (capital "M"). -5. Apache does not allow custom error codes as per the spec. -6. 
Apache (or perhaps modpython, or modpython_gateway) unquotes %xx in the - Request-URI too early. -7. mod_python will not read request bodies which use the "chunked" - transfer-coding (it passes REQUEST_CHUNKED_ERROR to ap_setup_client_block - instead of REQUEST_CHUNKED_DECHUNK, see Apache2's http_protocol.c and - mod_python's requestobject.c). -8. Apache will output a "Content-Length: 0" response header even if there's - no response entity body. This isn't really a bug; it just differs from - the CherryPy default. -""" - -import os -curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) -import re -import sys -import time - -import cherrypy -from cherrypy.process import plugins, servers -from cherrypy.test import helper - - -def read_process(cmd, args=""): - pipein, pipeout = os.popen4("%s %s" % (cmd, args)) - try: - firstline = pipeout.readline() - if (re.search(r"(not recognized|No such file|not found)", firstline, - re.IGNORECASE)): - raise IOError('%s must be on your system path.' % cmd) - output = firstline + pipeout.read() - finally: - pipeout.close() - return output - - -APACHE_PATH = "apache2ctl" -CONF_PATH = "fastcgi.conf" - -conf_fastcgi = """ -# Apache2 server conf file for testing CherryPy with mod_fastcgi. 
-# fumanchu: I had to hard-code paths due to crazy Debian layouts :( -ServerRoot /usr/lib/apache2 -User #1000 -ErrorLog %(root)s/mod_fastcgi.error.log - -DocumentRoot "%(root)s" -ServerName 127.0.0.1 -Listen %(port)s -LoadModule fastcgi_module modules/mod_fastcgi.so -LoadModule rewrite_module modules/mod_rewrite.so - -Options +ExecCGI -SetHandler fastcgi-script -RewriteEngine On -RewriteRule ^(.*)$ /fastcgi.pyc [L] -FastCgiExternalServer "%(server)s" -host 127.0.0.1:4000 -""" - -def erase_script_name(environ, start_response): - environ['SCRIPT_NAME'] = '' - return cherrypy.tree(environ, start_response) - -class ModFCGISupervisor(helper.LocalWSGISupervisor): - - httpserver_class = "cherrypy.process.servers.FlupFCGIServer" - using_apache = True - using_wsgi = True - template = conf_fastcgi - - def __str__(self): - return "FCGI Server on %s:%s" % (self.host, self.port) - - def start(self, modulename): - cherrypy.server.httpserver = servers.FlupFCGIServer( - application=erase_script_name, bindAddress=('127.0.0.1', 4000)) - cherrypy.server.httpserver.bind_addr = ('127.0.0.1', 4000) - cherrypy.server.socket_port = 4000 - # For FCGI, we both start apache... - self.start_apache() - # ...and our local server - cherrypy.engine.start() - self.sync_apps() - - def start_apache(self): - fcgiconf = CONF_PATH - if not os.path.isabs(fcgiconf): - fcgiconf = os.path.join(curdir, fcgiconf) - - # Write the Apache conf file. 
- f = open(fcgiconf, 'wb') - try: - server = repr(os.path.join(curdir, 'fastcgi.pyc'))[1:-1] - output = self.template % {'port': self.port, 'root': curdir, - 'server': server} - output = output.replace('\r\n', '\n') - f.write(output) - finally: - f.close() - - result = read_process(APACHE_PATH, "-k start -f %s" % fcgiconf) - if result: - print(result) - - def stop(self): - """Gracefully shutdown a server that is serving forever.""" - read_process(APACHE_PATH, "-k stop") - helper.LocalWSGISupervisor.stop(self) - - def sync_apps(self): - cherrypy.server.httpserver.fcgiserver.application = self.get_app(erase_script_name) - diff --git a/cherrypy/test/modfcgid.py b/cherrypy/test/modfcgid.py deleted file mode 100644 index 736aa4c8..00000000 --- a/cherrypy/test/modfcgid.py +++ /dev/null @@ -1,125 +0,0 @@ -"""Wrapper for mod_fcgid, for use as a CherryPy HTTP server when testing. - -To autostart fcgid, the "apache" executable or script must be -on your system path, or you must override the global APACHE_PATH. -On some platforms, "apache" may be called "apachectl", "apache2ctl", -or "httpd"--create a symlink to them if needed. - -You'll also need the WSGIServer from flup.servers. -See http://projects.amor.org/misc/wiki/ModPythonGateway - - -KNOWN BUGS -========== - -1. Apache processes Range headers automatically; CherryPy's truncated - output is then truncated again by Apache. See test_core.testRanges. - This was worked around in http://www.cherrypy.org/changeset/1319. -2. Apache does not allow custom HTTP methods like CONNECT as per the spec. - See test_core.testHTTPMethods. -3. Max request header and body settings do not work with Apache. -4. Apache replaces status "reason phrases" automatically. For example, - CherryPy may set "304 Not modified" but Apache will write out - "304 Not Modified" (capital "M"). -5. Apache does not allow custom error codes as per the spec. -6. Apache (or perhaps modpython, or modpython_gateway) unquotes %xx in the - Request-URI too early. -7. 
mod_python will not read request bodies which use the "chunked" - transfer-coding (it passes REQUEST_CHUNKED_ERROR to ap_setup_client_block - instead of REQUEST_CHUNKED_DECHUNK, see Apache2's http_protocol.c and - mod_python's requestobject.c). -8. Apache will output a "Content-Length: 0" response header even if there's - no response entity body. This isn't really a bug; it just differs from - the CherryPy default. -""" - -import os -curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) -import re -import sys -import time - -import cherrypy -from cherrypy._cpcompat import ntob -from cherrypy.process import plugins, servers -from cherrypy.test import helper - - -def read_process(cmd, args=""): - pipein, pipeout = os.popen4("%s %s" % (cmd, args)) - try: - firstline = pipeout.readline() - if (re.search(r"(not recognized|No such file|not found)", firstline, - re.IGNORECASE)): - raise IOError('%s must be on your system path.' % cmd) - output = firstline + pipeout.read() - finally: - pipeout.close() - return output - - -APACHE_PATH = "httpd" -CONF_PATH = "fcgi.conf" - -conf_fcgid = """ -# Apache2 server conf file for testing CherryPy with mod_fcgid. - -DocumentRoot "%(root)s" -ServerName 127.0.0.1 -Listen %(port)s -LoadModule fastcgi_module modules/mod_fastcgi.dll -LoadModule rewrite_module modules/mod_rewrite.so - -Options ExecCGI -SetHandler fastcgi-script -RewriteEngine On -RewriteRule ^(.*)$ /fastcgi.pyc [L] -FastCgiExternalServer "%(server)s" -host 127.0.0.1:4000 -""" - -class ModFCGISupervisor(helper.LocalSupervisor): - - using_apache = True - using_wsgi = True - template = conf_fcgid - - def __str__(self): - return "FCGI Server on %s:%s" % (self.host, self.port) - - def start(self, modulename): - cherrypy.server.httpserver = servers.FlupFCGIServer( - application=cherrypy.tree, bindAddress=('127.0.0.1', 4000)) - cherrypy.server.httpserver.bind_addr = ('127.0.0.1', 4000) - # For FCGI, we both start apache... 
- self.start_apache() - # ...and our local server - helper.LocalServer.start(self, modulename) - - def start_apache(self): - fcgiconf = CONF_PATH - if not os.path.isabs(fcgiconf): - fcgiconf = os.path.join(curdir, fcgiconf) - - # Write the Apache conf file. - f = open(fcgiconf, 'wb') - try: - server = repr(os.path.join(curdir, 'fastcgi.pyc'))[1:-1] - output = self.template % {'port': self.port, 'root': curdir, - 'server': server} - output = ntob(output.replace('\r\n', '\n')) - f.write(output) - finally: - f.close() - - result = read_process(APACHE_PATH, "-k start -f %s" % fcgiconf) - if result: - print(result) - - def stop(self): - """Gracefully shutdown a server that is serving forever.""" - read_process(APACHE_PATH, "-k stop") - helper.LocalServer.stop(self) - - def sync_apps(self): - cherrypy.server.httpserver.fcgiserver.application = self.get_app() - diff --git a/cherrypy/test/modpy.py b/cherrypy/test/modpy.py deleted file mode 100644 index 519571fc..00000000 --- a/cherrypy/test/modpy.py +++ /dev/null @@ -1,163 +0,0 @@ -"""Wrapper for mod_python, for use as a CherryPy HTTP server when testing. - -To autostart modpython, the "apache" executable or script must be -on your system path, or you must override the global APACHE_PATH. -On some platforms, "apache" may be called "apachectl" or "apache2ctl"-- -create a symlink to them if needed. - -If you wish to test the WSGI interface instead of our _cpmodpy interface, -you also need the 'modpython_gateway' module at: -http://projects.amor.org/misc/wiki/ModPythonGateway - - -KNOWN BUGS -========== - -1. Apache processes Range headers automatically; CherryPy's truncated - output is then truncated again by Apache. See test_core.testRanges. - This was worked around in http://www.cherrypy.org/changeset/1319. -2. Apache does not allow custom HTTP methods like CONNECT as per the spec. - See test_core.testHTTPMethods. -3. Max request header and body settings do not work with Apache. -4. 
Apache replaces status "reason phrases" automatically. For example, - CherryPy may set "304 Not modified" but Apache will write out - "304 Not Modified" (capital "M"). -5. Apache does not allow custom error codes as per the spec. -6. Apache (or perhaps modpython, or modpython_gateway) unquotes %xx in the - Request-URI too early. -7. mod_python will not read request bodies which use the "chunked" - transfer-coding (it passes REQUEST_CHUNKED_ERROR to ap_setup_client_block - instead of REQUEST_CHUNKED_DECHUNK, see Apache2's http_protocol.c and - mod_python's requestobject.c). -8. Apache will output a "Content-Length: 0" response header even if there's - no response entity body. This isn't really a bug; it just differs from - the CherryPy default. -""" - -import os -curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) -import re -import time - -from cherrypy.test import helper - - -def read_process(cmd, args=""): - pipein, pipeout = os.popen4("%s %s" % (cmd, args)) - try: - firstline = pipeout.readline() - if (re.search(r"(not recognized|No such file|not found)", firstline, - re.IGNORECASE)): - raise IOError('%s must be on your system path.' % cmd) - output = firstline + pipeout.read() - finally: - pipeout.close() - return output - - -APACHE_PATH = "httpd" -CONF_PATH = "test_mp.conf" - -conf_modpython_gateway = """ -# Apache2 server conf file for testing CherryPy with modpython_gateway. - -ServerName 127.0.0.1 -DocumentRoot "/" -Listen %(port)s -LoadModule python_module modules/mod_python.so - -SetHandler python-program -PythonFixupHandler cherrypy.test.modpy::wsgisetup -PythonOption testmod %(modulename)s -PythonHandler modpython_gateway::handler -PythonOption wsgi.application cherrypy::tree -PythonOption socket_host %(host)s -PythonDebug On -""" - -conf_cpmodpy = """ -# Apache2 server conf file for testing CherryPy with _cpmodpy. 
- -ServerName 127.0.0.1 -DocumentRoot "/" -Listen %(port)s -LoadModule python_module modules/mod_python.so - -SetHandler python-program -PythonFixupHandler cherrypy.test.modpy::cpmodpysetup -PythonHandler cherrypy._cpmodpy::handler -PythonOption cherrypy.setup cherrypy.test.%(modulename)s::setup_server -PythonOption socket_host %(host)s -PythonDebug On -""" - -class ModPythonSupervisor(helper.Supervisor): - - using_apache = True - using_wsgi = False - template = None - - def __str__(self): - return "ModPython Server on %s:%s" % (self.host, self.port) - - def start(self, modulename): - mpconf = CONF_PATH - if not os.path.isabs(mpconf): - mpconf = os.path.join(curdir, mpconf) - - f = open(mpconf, 'wb') - try: - f.write(self.template % - {'port': self.port, 'modulename': modulename, - 'host': self.host}) - finally: - f.close() - - result = read_process(APACHE_PATH, "-k start -f %s" % mpconf) - if result: - print(result) - - def stop(self): - """Gracefully shutdown a server that is serving forever.""" - read_process(APACHE_PATH, "-k stop") - - -loaded = False -def wsgisetup(req): - global loaded - if not loaded: - loaded = True - options = req.get_options() - - import cherrypy - cherrypy.config.update({ - "log.error_file": os.path.join(curdir, "test.log"), - "environment": "test_suite", - "server.socket_host": options['socket_host'], - }) - - modname = options['testmod'] - mod = __import__(modname, globals(), locals(), ['']) - mod.setup_server() - - cherrypy.server.unsubscribe() - cherrypy.engine.start() - from mod_python import apache - return apache.OK - - -def cpmodpysetup(req): - global loaded - if not loaded: - loaded = True - options = req.get_options() - - import cherrypy - cherrypy.config.update({ - "log.error_file": os.path.join(curdir, "test.log"), - "environment": "test_suite", - "server.socket_host": options['socket_host'], - }) - from mod_python import apache - return apache.OK - diff --git a/cherrypy/test/modwsgi.py b/cherrypy/test/modwsgi.py deleted file 
mode 100644 index 309a541c..00000000 --- a/cherrypy/test/modwsgi.py +++ /dev/null @@ -1,148 +0,0 @@ -"""Wrapper for mod_wsgi, for use as a CherryPy HTTP server. - -To autostart modwsgi, the "apache" executable or script must be -on your system path, or you must override the global APACHE_PATH. -On some platforms, "apache" may be called "apachectl" or "apache2ctl"-- -create a symlink to them if needed. - - -KNOWN BUGS -========== - -##1. Apache processes Range headers automatically; CherryPy's truncated -## output is then truncated again by Apache. See test_core.testRanges. -## This was worked around in http://www.cherrypy.org/changeset/1319. -2. Apache does not allow custom HTTP methods like CONNECT as per the spec. - See test_core.testHTTPMethods. -3. Max request header and body settings do not work with Apache. -##4. Apache replaces status "reason phrases" automatically. For example, -## CherryPy may set "304 Not modified" but Apache will write out -## "304 Not Modified" (capital "M"). -##5. Apache does not allow custom error codes as per the spec. -##6. Apache (or perhaps modpython, or modpython_gateway) unquotes %xx in the -## Request-URI too early. -7. mod_wsgi will not read request bodies which use the "chunked" - transfer-coding (it passes REQUEST_CHUNKED_ERROR to ap_setup_client_block - instead of REQUEST_CHUNKED_DECHUNK, see Apache2's http_protocol.c and - mod_python's requestobject.c). -8. When responding with 204 No Content, mod_wsgi adds a Content-Length - header for you. -9. When an error is raised, mod_wsgi has no facility for printing a - traceback as the response content (it's sent to the Apache log instead). -10. Startup and shutdown of Apache when running mod_wsgi seems slow. 
-""" - -import os -curdir = os.path.abspath(os.path.dirname(__file__)) -import re -import sys -import time - -import cherrypy -from cherrypy.test import helper, webtest - - -def read_process(cmd, args=""): - pipein, pipeout = os.popen4("%s %s" % (cmd, args)) - try: - firstline = pipeout.readline() - if (re.search(r"(not recognized|No such file|not found)", firstline, - re.IGNORECASE)): - raise IOError('%s must be on your system path.' % cmd) - output = firstline + pipeout.read() - finally: - pipeout.close() - return output - - -if sys.platform == 'win32': - APACHE_PATH = "httpd" -else: - APACHE_PATH = "apache" - -CONF_PATH = "test_mw.conf" - -conf_modwsgi = r""" -# Apache2 server conf file for testing CherryPy with modpython_gateway. - -ServerName 127.0.0.1 -DocumentRoot "/" -Listen %(port)s - -AllowEncodedSlashes On -LoadModule rewrite_module modules/mod_rewrite.so -RewriteEngine on -RewriteMap escaping int:escape - -LoadModule log_config_module modules/mod_log_config.so -LogFormat "%%h %%l %%u %%t \"%%r\" %%>s %%b \"%%{Referer}i\" \"%%{User-agent}i\"" combined -CustomLog "%(curdir)s/apache.access.log" combined -ErrorLog "%(curdir)s/apache.error.log" -LogLevel debug - -LoadModule wsgi_module modules/mod_wsgi.so -LoadModule env_module modules/mod_env.so - -WSGIScriptAlias / "%(curdir)s/modwsgi.py" -SetEnv testmod %(testmod)s -""" - - -class ModWSGISupervisor(helper.Supervisor): - """Server Controller for ModWSGI and CherryPy.""" - - using_apache = True - using_wsgi = True - template=conf_modwsgi - - def __str__(self): - return "ModWSGI Server on %s:%s" % (self.host, self.port) - - def start(self, modulename): - mpconf = CONF_PATH - if not os.path.isabs(mpconf): - mpconf = os.path.join(curdir, mpconf) - - f = open(mpconf, 'wb') - try: - output = (self.template % - {'port': self.port, 'testmod': modulename, - 'curdir': curdir}) - f.write(output) - finally: - f.close() - - result = read_process(APACHE_PATH, "-k start -f %s" % mpconf) - if result: - print(result) - - # 
Make a request so mod_wsgi starts up our app. - # If we don't, concurrent initial requests will 404. - cherrypy._cpserver.wait_for_occupied_port("127.0.0.1", self.port) - webtest.openURL('/ihopetheresnodefault', port=self.port) - time.sleep(1) - - def stop(self): - """Gracefully shutdown a server that is serving forever.""" - read_process(APACHE_PATH, "-k stop") - - -loaded = False -def application(environ, start_response): - import cherrypy - global loaded - if not loaded: - loaded = True - modname = "cherrypy.test." + environ['testmod'] - mod = __import__(modname, globals(), locals(), ['']) - mod.setup_server() - - cherrypy.config.update({ - "log.error_file": os.path.join(curdir, "test.error.log"), - "log.access_file": os.path.join(curdir, "test.access.log"), - "environment": "test_suite", - "engine.SIGHUP": None, - "engine.SIGTERM": None, - }) - return cherrypy.tree(environ, start_response) - diff --git a/cherrypy/test/native-server.ini b/cherrypy/test/native-server.ini deleted file mode 100644 index b32d98dd..00000000 --- a/cherrypy/test/native-server.ini +++ /dev/null @@ -1,9 +0,0 @@ -[supervisor] -scheme="http" -protocol="HTTP/1.1" -port= 8080 -host= "127.0.0.1" -profile= False -validate= False -conquer= False -server="wsgi" diff --git a/cherrypy/test/sessiondemo.py b/cherrypy/test/sessiondemo.py deleted file mode 100755 index 342e5b59..00000000 --- a/cherrypy/test/sessiondemo.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/python -"""A session demonstration app.""" - -import calendar -from datetime import datetime -import sys -import cherrypy -from cherrypy.lib import sessions -from cherrypy._cpcompat import copyitems - - -page = """ - - - - - - - -

Session Demo

-

Reload this page. The session ID should not change from one reload to the next

-

Index | Expire | Regenerate

- - - - - - - - - -
Session ID:%(sessionid)s

%(changemsg)s

Request Cookie%(reqcookie)s
Response Cookie%(respcookie)s

Session Data%(sessiondata)s
Server Time%(servertime)s (Unix time: %(serverunixtime)s)
Browser Time 
Cherrypy Version:%(cpversion)s
Python Version:%(pyversion)s
- -""" - -class Root(object): - - def page(self): - changemsg = [] - if cherrypy.session.id != cherrypy.session.originalid: - if cherrypy.session.originalid is None: - changemsg.append('Created new session because no session id was given.') - if cherrypy.session.missing: - changemsg.append('Created new session due to missing (expired or malicious) session.') - if cherrypy.session.regenerated: - changemsg.append('Application generated a new session.') - - try: - expires = cherrypy.response.cookie['session_id']['expires'] - except KeyError: - expires = '' - - return page % { - 'sessionid': cherrypy.session.id, - 'changemsg': '
'.join(changemsg), - 'respcookie': cherrypy.response.cookie.output(), - 'reqcookie': cherrypy.request.cookie.output(), - 'sessiondata': copyitems(cherrypy.session), - 'servertime': datetime.utcnow().strftime("%Y/%m/%d %H:%M") + " UTC", - 'serverunixtime': calendar.timegm(datetime.utcnow().timetuple()), - 'cpversion': cherrypy.__version__, - 'pyversion': sys.version, - 'expires': expires, - } - - def index(self): - # Must modify data or the session will not be saved. - cherrypy.session['color'] = 'green' - return self.page() - index.exposed = True - - def expire(self): - sessions.expire() - return self.page() - expire.exposed = True - - def regen(self): - cherrypy.session.regenerate() - # Must modify data or the session will not be saved. - cherrypy.session['color'] = 'yellow' - return self.page() - regen.exposed = True - -if __name__ == '__main__': - cherrypy.config.update({ - #'environment': 'production', - 'log.screen': True, - 'tools.sessions.on': True, - }) - cherrypy.quickstart(Root()) - diff --git a/cherrypy/test/static/dirback.jpg b/cherrypy/test/static/dirback.jpg deleted file mode 100644 index 530e6d6a386fc097f3a1dbabbde2d80fec1175ac..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 18238 zcmb5VRajfk7cQI-+=Dy8ixwxi1qsEiMSnO1cPq5GyESOh;1mg3+T!jWq{RyqC|W9% z9{%UzeAnOF&)yfapS5P)tasMT`_8|$f7<|ZEp@m$00;yEG#?+pzYTyY00)GPjSa$i z{NUi=;NlVE<2@P~5fK3~n2dq~Oa=y1(lF6dQZZ12!E|hN49v`|tgMu@?40Z@oJ=gN zEdL7ve00Ub#UsVXCuN}mQ?dO2wtsy9Fg}n7%K-#r2VjALATaRX5P$&y06iuP1pGe( zVu7#$IQWnM6vzQU5EcmcF>P!tAT|gH06sc`*eFDl*+mU(eKN5rBg;6%R1EF1TKac< zqvkJEji@-q%P;Mt2NoXv>HlwF(Ep1J_@6%r8|Q!1g8w(?W5oZ@fM7NeWvqYe0OH5t z$7#R-MZj-6JTiY9Z~0Z+2Qq&pR9Op8Meoo~Tq;rKtSEkl_E3tu zCM=B{oR-z5Uw`_JYl%ol9qWg4hL^5pGP6&Xaz}Gvz%!&qN_iet7@9s1g4FTs1rcyp^&2O1O0JGrrK!FN<9$ zFaN_y?{ufh+PM3#bOK!}lwQ%FTrC@6@S*hL$gzbk@vk&_6TzV$=9zSvcPx`)g%Z`E zw9nPZ=kj`vMTrfHVN!y#1WYKUF 
zj4(lda${I45pXf3jR)GbdcD}^(3KblfRCV(EVb~COmu5+WNPHI7FXXir>o9V{x%&h(ax_sg_y%QzJ%}NlS&ks%Lf4jKsQ;u@iEV@l;JHO;12~U$u zS8INEphU1F>2A?L_7C%!<@oa5kGCwly%>COoL_!2jz`37iim*OwT)OV>Sz4}G;A2-3kO>J>M0ck4hhyxomHl^Z!F`Ql|O5-8UaRe zh~M(bA+I*cI3s87UzM2vUV604T~t5ye3GYzh96Rv2)q+C9$^-?Ngt>+NEQ3h;-~g( z7fU=U+6^~rHJ>A^#y=6X<*J!rZnL3d8o1z3jyR@o^x_u1GhhjsC7e!TJ5cq|v*tsa zn=^TX2flrjqj%F;=ItpfveDqu%>OTOM$%6!)LpFiwB5>Bc3M^_6;g`99t(UXJ*kCT z6#duMU-tVaSgK#-R)}&F=il(`T5ay2(`aR^a~uUcnEemS71L!hnq@~DLF{SyElQY}xo`%f_Ce-h{r1KsFmBDo`OeTCuq)R}Kp$Xyk$p%c zj6Zh+moU?da$7i|)BbMeOAYq7gJUlPtKDMwB-=S4A}iSYSfK%Qt~&$Dg_PeOsFL?s z%q@Z>mAfeWb&FpNotm)5SZYYEDN~Aa?z;imra+p#hk@>;Pe>=>LSLAP6ec9? znF4Xe3{5glCE?9@wwet@^s>6O@7%R#TIv755ey5OJXj+?3b=9%baafY771!*R z4A>7Xw{5y~cPm=3DaIZsME0rEw51{LUlcT8++weVQ2Ra>Ttm$V5zuzuPwbPw2NqpYX%ogL^;4(E1nQAGG5_@~6MmV_F zRp2+(gViWH57=?pVG=V-EKB1iD%y<8FJBD4_LGdOT?m#O!aU^a{~90h+Zq#8j$#Hd zIeh*BqriVtf~in?Gd`qwNUbE7SGmLUVpxhkEy+&|z2j4D5`^v|XJ*=Yn)Ce+{y2@R z@x7G3 zwHUZvl7P~uOMVv}p-rz3kJC*oOi0mz{N|>c(j*vmF+@Zs0K+MFfV!**Q1QHrzX#D3 zFCIYGRl!1(^v(q9cS9esMUh7$)!;AfD3QWz9o-TT+fYqQ^I*#>Y}7>s34CB1A2w2S z*IT%tkjQ6;>)mSssoe!U-SX*(6_kHN@+R3e8H^+lOI+ElnFk4nR+32%zURZ!EXt1R z8r`wfL92(vDDp+Zh3%+BvTvh32}h&__BVnVnZY;ny@N4l1nS9~K2|6w=XlIy=zte7 zC^-gw1M974G|CA@*#7a8>Y2{1Sag^kTNFR&NYOQ_V4IJ_ zqcE8@sTnt41M}2BxHJ7Lryj*oWm#tSM6L z3(>Db0IKlagp|@CB_eIf<7hqH(!9zLPYBS%*DOt!QcsAPrXV6ow@s~$58Qf&a?J5l zQzL(&Hz@jgqB9!&T=Jm0Jc9!mt)0lKchR_mAE2V_e&Jdkl%xn=4MtnsTJpQk;3=h& zUqX9#chdSCmt_%hx>(3%?)US~`cnPPKN_#mmK7IsgymP;W zHV!bFmDjFQ%$Ikc16$1pNnf`jOHU20Pru&&1Mm$m$};J)Tj@tEawb32-is9M?HaTD za@6@K_pv58aoiYJ!Q$g--`@QLptqNL)wq+|I-KBX$7+&sNRMX;t)0nEGL4=~Ok+y= z&R>jrglk6ax=Cr5`1GU#4X^S6OzMF46wlD!*RPn@DzB7Q) zo~<*h)q+K)qSp3OMW58!lar(>evLs691*Yg)e?H7WF-W|-@^?>szDv}&{yY0%W}dj z!6bvW!<22POOJw7;RWej84qp(s{YJOlS|lcqC4!2)2*r$x~P+F78kGX|2;o9usw<> zTD5gr6YNa{d2ywS%R5AWk7B4n4SLJZUu~S`b}4~rw+*nt2zkOenr3kTK}7ViomfIN z>-UqlJ+WpZftDWCb$oFFwDy_$L;HrU*~3cOP?DDr%dkm?Old^9xrC+lmn2POl-M{E zO}s;!AtTEp9K6jmLmZr#oTQ{z>&b-WQPDMD+B^je`LSg2!W 
zUAVnN#^&fx9VPP1X0;5LWgqHF;cTYogz3`F$rWaix1Jm7zVX@G;TVo5QpjH^TtUch zUV)kjf{d>*Gbmz@E_2FUcc#~;Lf6=%r`?T%x0K7`X>4ZtvVTlNG6K;uf;cNv0%rjg z^kYZ$#6pl4@L+|+8n{qA$juC}t^6ioVZ6qxFd>bIvbEUMct}LLx(-_zmtNk;hKmkJdctlDyvCA!{0dt~6?;CmK zTNj9u@=$Ux zyU*i2&Jv_1E*cx>h5Ks;wN7lla;I6EdidtV@5M67?|HUf(&JQ+s10pjU=hR(r~GO9 zX~wPW0X0rHANz&m&!@etcd>3^J5-j2ihx5?XyHXa;yyZ~og@f*kzIC*p)*@mg@s}(bvE3XM|4c?(#A!q> zNMDH^*Q3hwySpw}TfM6Y;5*ByJN6SPp>sTCL2fy(fgSPLF(@h8kh?@>W*RrFEi~NL zZH1aWXPl%vSIJCTGPD+;>r}J>yM&PQ&Yq%mg*78&Q6wbfjG(#<02r$z%mj8KF^A#_ zmDXx=D1O5j)M7fB9>6yM$8cI$YyLh~0blmVr@n)sQ&v3ZrC+2ajw$#LUhigQe-y*z z@#iQXt*ZScuykkoX1{Im*Q@H==Veifx1@^rmYZ)6T&s1qvAmO&0Rn?xv%*@?UQ*_& z9S*&g>l2Zt!qL!({Tc?FFvCg$lBmCaB-*3qB0XQB7DyA9Xz1bi2bByN``3;J%TL-n zNoLDaGFG3xMRu7s*v7m3=B0;hVUq9Ebc<8(0&Yyode6#3f7epBVD~sH(CF=qXL080 zu4FZq*o+$R7W~-Zb`rBP*=hRFU6$g2|CLSi%+EpXK(RbcGX6jORW7BkP$o_@!-`Pq zb%&P~uZUV$2y_+dp>DB*&~MAuA=D5zyQq2m78P7_=8N5R4z)@X8t;?hclEgDJW8oN zOG4`|AhjYkBZc+sO`(S&E5 z;#pQJl=`A)S9?q}&W_6rIx6&WgZf>D-eBUAgttXI?y=cSR|G@)>!}GO+BiB&R-h?KaKHwmM5~I@)>=iXaP|-Djy$12wGnv;PfWuUcKkCALBF2rC(0^E2LL(4hedCIyQ}K)-HE$0HwP zkSKbzh-VVrvglS&;Uq+St0}V1oAtWZAmR(9-V@#k+z^&;9609_cB3Z;%WeR&$~&{d zc+x%GarD)~NgYT8H=pVkyx63d`)4z9abBm?MtB9p(d1e}v5Zg0E`9?$ooiZ^X&H>@ zqff3lSi0(?1C|if{VH)6!!<_@8olA_u+mA(vS!fDXrP;O@1u&SA-R|c4tpVCKqu1D z1?~iGc~9*AeeNtOxXe2TWjQf1Cy?b_bFDvmS+j!Oy+I*Uajeo6W_S9Ob{6)L?hI{T z8_q;2{fTOliTIG#pxr6ztH0R|TG7AZZ0nX^^-gPMdhV?%<*SfiHgCW~hwY--7kz>> zAG#azjeT%4*hYRgiWhwaT>?8&58JTL5P$lPip0MesKx9QIs6w@VU_HlaM|zvmf$VY zXewM8;%RtgE8i7UZjyq%ds^)>jY&(&eP!0`Z-n%`Zo{u>YqR@qowOum<@@@NA!pq8 zE_#{&xXc-6;%Ng6{YsZ2Lf9gvi)6X#Y|l5?F)D)VH2(SBv7xC^niHFuHh#exoFF~w z{&z(i7%MKGEbL+|(|D4jn1Jyo&RG$4yGJ{3pE5OoB{d+-kF!8#-;eRc;IhY-0iV-Mt<--dW4M#(o2pKxrpodD*#4_^XCk@Bk+5lTaj=auGBh(g4$6UYJS z6;S&pZr#UhTRCF+Z3W?4N_@arr{Idv3trI*OB>jYo~NK{eL>H?IeOf!&vDX%)1_2* z<`+F8H@BYhAcBceXo}a4y@(;^OFXB6+Qc)8n*V>mohRM&%Pj?$LjMtO^%Hw@85YSWi~OweN1o`__yq%@=gHmQPv*oI)yH zOPV;!r&b1o>=**NT2~l7_Q=b5=UMd=c10uVsnk-T)^i)0tNF$ksqB?I(M+K{RL2Ax 
zl%k&XI~(vOwlwV9TwCW$?SfANpPtsy8dYlpz6ztG@3?Dq-d&OnVF36a7A13q0VKjj z42P{nO1Na=XVuy+8isGiQD5PRlDfD;d;dD69RL6W?kksJ`4$byVmOwYkM6rOPZ`zV zzcp*M#n@mZx&4XYW5{z;+b$bX;o=Zd7Zj9&!F|rbAtd)doNH2*)pRuoJ}L-j7sl|47t|`%S8M4ENhMdXg2#Ji3EE`t@ADyObdKy zyqmF+W|4vI>vDkNM)n80Y)w|nP(F#iaM<{EhcP`#sR1q~rfcIN#(W!=D$Rvss7g$C zSh))+(w%asH$G*qIYRk{Cqfb;UucG8yefuy`G|2ge{Ba$>_!^8yxeS+y z_ZYC-qy^o?Czjx6;b_1!=~V*zkuL4I35^hOyTsi>N?rXl#`74P@UkaBFV1lJRi!`r zlJlB+YBb*+Hhj3v&g&Ks*CRiw|7hWXUArq5`u5)BTgQLUq74VCCb5@IQ2uXdTb-e?;M3aJ zMPD{0TDxB(Ut+EEX-T2!bjG@D_ZArZ;HeJE(^M>MQuvTYYP3k!l%}9;y+w0>Q4YL} z21vs$7R248CqCUr-fhpILBOFq7DtaHR6a=U_Ip?GTAYafG&ABy_Zfgm*vox^;b{W( zQ?VyMGem;n9sifp=NcSh#c^=8yv%EdS2itr4 zu?Cr}%W_BS0i|m(?wc}-nKdbHVmlu=v9k=}an#7?Y*nn~wbjYOCAih5>u1EpSm_pS zHBh2fuaF9riG#BU3ve%rXj^^}Fb5Hj=&l@D0?9}Vi}}beB^oZV?;5Q|aD!sop! z>DX&88+8Y`stjn#1e>lH76R3S{_4A(*tZ$wrL`XJbM9088$rlGY#2(P{{6fik zdQ!LMW)Zd=5U1Ky+xHPJc=&NpZ_Ds$))E%5C_D0J_=4J%?qf+W6#xBrG`f>+@@4&H zQ~JQmTA9Vn$C_p(KfPFXNz>?=HK?AF1u1q@13fpY;;sCq0{}Sc5vyr6-EVjZxPP(K z1p82TT1yw@g$t5y@k1Gb|10m2-DP5>BCqvI$@B%ZAzh(>H+!(asjRM;o?_>1>j&C< z=n#f`2%m`Jis8j#MGSY}t6K&WZh%kzACqX;@P|p?V=)*}kHKji?eN>v0%8?4E+Z0d zJ1FnW_%Y1JYC69CPx8Hoheqte(CtmC#%6d9gtzBJdT3bU}t<1N$E-_-Ynh`Q@#%s&)U$p8k7%T{HqIo2 zZD@AUwf*g1d-hJU3C5bQdu&C#w~g_x1@?4p^IG> zQriEL1e=9d9o`Q{xsJQVTV^t_lG~xYCSj@mogu$YsJ~nXj8jhi;dF&2ax)o_ig5in zs&BIO_hZgMM_JkYg>nNTvC&QUHG%5n-Wb_3C-xERwKO6)BagqpB8jZM5Pt16mQ*3e zS3q#^VkKLw!fVfpARgTFzpTjJ;Lul=?U!mDSS@|{BW>1Ayqly>c^>d-x zh_mZ=0_o?b?tOVZp*0;9@3`8$>VFyAa`q=Awq-Stli>yFydpjCBGo;AKX+#V_&aK$ zn%_FLlEFem|GxGRZL%5gF*|?1f>DSLkqq>kcB&zb8e5w+8hZqx?8KosI0L;w25%(Q zc!^x@JtxJkux^)$tdNhJt1LN|o%A76K6VbD^IGc%tx1ZY=uSxpwX#lW0t(yAh0Zz4 zXL0KZ?be<4|LQ)5#QRUv!{U}VQ+mMTl5WWCE5BsT_0`#0p8DmTp})N({pC?ondMgN zGak8g<0qngpxEvI#FhiWsT)j>ENTRM4Q^q4L#qDsz+-*Jr1LF4khi}MacY2~aK9(H z&FSpvJ6$4mTB1TG;WqKjBKH=l!Fxd}$LWa6#=cl2tvx=LwU`dx!?@;(!WRl?T{AhP zi=va7{-Z79CDFSfp<65`uT=Z>ti0wKEZu`QRi_{}$8q`rS; z=u=?l%5)7h8q92OS4fC?=hi~U{!3H2BGEI9R_?9v8EFjrVB(1P`xu^iEGrLGv_CwI 
z8Fk?BQ(3l*2@=jXAQ%O0iV;xA{jzQx|0&=}McT;qsEO)aELj5`{iTUIdyW!^(356E z1o?TCabae22K39V9&B(ITQ^O;+H_QEkP=};R|<@n?&fJyEq3P?!@*@Pb5fFyv0fq{ zRT}CqNlF=xcTwIZq`;(aNtK=Fvj-ijDAxMmzN-3?ZzA*b_6hh>Xz)X=IWYS63C+FE zVo8+~GFxq?pJulpw+@p=ng2vwI|YwDl{-w|JCmWdiofXTF!qL0j$B5drA+K+j+ZBs zu$+ygpCuegjR&ms`F}GdaDx;Q-uV5{k=ww)6-7QBX%i>5icPV}t(O5u9roKKu1M!j z(eHzVThW=aK}T)R-rc5HV^2H z#)DK^0h@nEe1R{!)N6L(*V3@DmF|HJ`iQ?;RB_&!eWJeeBEB{c6PHQke4{KK^RCzB+FX3{Nrts`a4vVKGdT%XQ8AmWCMm{|e+Z*h80oJA5qhUBy%-OK@G>CZs1rJE zdOb{Z;EpUy%jF!4^kql3jf>r5B3ACSHTov+<>n*ifpOAQ48+7bVq>WEf%6m1m1vzG z91e%>ZOwGv-Y(6nZ%xXZ3v^q?4!=26t{@>c*gpQ@TIDqCIol`#85S#(dyc>&+tLfX ze0{Nl{y|FQYp&N=kdM>M)Ur=jC?9vY_(hlM2}W*-n9AAd$-0y2g70f6hoLvG`IZ<~ zq*nkXI7aAF)ToT*0DR^?Te6S=ZjaAXK+GY{vus4^)c4?sdcx>({;el^%bx|WN7$ZY z3xVUd_-b~+k6+e7JZi%Je+i<0d19s5JGu#A_e9>5=Nol_)~7bi&CUnA{JD)l4P*n)3|%qc~^T&o7g&9MhJ;xO!t zZV7X|33!6iZ1b3g5~t2@$yG!tE);(9M=CRRU+Wybtd-SyStA#w+A7s`V@{@0dU;7Y zCbe{CP zKg97l&{C#W|9Ktj1SWw#oW>AmwM+$Qy2T#iQagDI{fNy673vewWc3$%mzXa8l&$yZer-kr_?2dZM*kYzE9hbQ|7>B;J z^2Q7J&?oA^n-I{Hs7g4=7w>QLhKX+&v&yypZ)tPcLQ^;LRK=Hh=;;b%VyCg`**fnu z%8B4912)@Y>>*UqDIK(W+!h7?^$EV`5gdPc0h!q^Q6%(nn-_*6G)Va+g@yviNF-eN5AugZl`0TE5O%ikribhz;W}zc__utQ}?( zKy(35v#RWV6m~ZU$3%3kn>9V-gv9TX*zWjGhpRy1QXU$gTe4PO=(Rxl}Bw>Tdij=A6%GhqV z4~yPyMs{zO$;?y==DU|1*nLmPPUPR%aTsB=^QuI5N!dA~{)1$&R}AUahAI2_xnJAo znbj)eWUNPFKwkO*$>iq6*Zpi~{0tYCS#f?yyMB8DuoVvHu_Tn)nQcVP+Ct*c%%8#clj&cv*58Uh# zm}!UM{#>LH!Vm5tySC+8#g*qWSV@g?*@!%k{mF9l4`2(U`>p5QL0Qk89f9CXNi?$O z!_v#z9MO1d5*kOR37~7?Q>juR!Eo@GAC`xhfRwHZw2yYhnU?PvEfGp2gedHjv6>?{=q3t@7o`05=~Y(E&p3vDnK#@vh>RZ{`}lYgTCthZ zM#Vx6pi*~ne_rzm-6#h9?5uZX=SVHk6$v>W)rJbTMcp2{DxQA)K)udlzQE@qKJA240Eem${R3##U~6X> z*FMu->`&9`+F%H@h}p2r{KPdl@9%5G6<&y>h2MM~ywTHlKg3xeA}ZE6ssN5KNLEOg zt2^o(ZfyE7Q_hVI@ft3P=x=>21HcY3*aM+yo}!Y@X7E;mi1CwHkW=3dIkbGfNTMXx z2c+IE##Vg{%G9M*p2UaQ^(O4#C_wyGX}NfU&P(Z3@36I}8nI_$Va9<^FtbT@+D6_P zFGOROFqfRIki+*=HU0?n8VYr{7Y~ruf-3Ja4IJfdBp4#;MTyt2Z_T%lxiWMrx$FtL z`kOT0qYN?}GCq$d7luJ)o36IFZK5BH2zBU%wS1$78|ED4)Z=^Dnj!ei&owxld9WvS_* 
zM;Q|-{dg!2h>jUMe76hIr2Jiji&h?dv#WNh3}CGqm@jzE_15oCr8i?$d`~^xO$m$5 zklfyUl-7d+#x2A|E3uy~D5!2Rh&`X!-I4w>jrf-rqf658 zI;S$@F07)%PkSU{*ESEZq5Aqji;wPE#U!{gSG}W@oDQ&XqV2b7Ta_jU%GD^udLex# zs3b&&HT}u>bc=$x=se?BCx(9gH+$coU3Y(egEywp@IEP#+SVDRzEf?di~JzV^LNH! zJnm6~y+;QAG+o1Tq6+ojrm1Dg27uF5Vt2kwD{pb^>MX#KAFJSW>Vt)`|?Gx5&m6fKT3{vTo|i3qOqL0*ncB0`Vx0* z{HN82|5Ut8moWS;ausL+mac%G2dFU3vHS9=o?vM4PNof03R*45i3N^KQO9PZ6RZpG zjQnr(y(XC{UkW_16c)|>LSSoRJ{_S6y|9MJU?p^Rm5a_}Qzs2LohqDIbdOR%GT2VX z7zCSSo;PBjCfDB$NDk=5nIB`vgOqB$?~V)*fqG|4BPPqMlv_n<>1+SWP8T9_bL>;d zA<=%-7qzfxVP!*3CAMk$<9WlPoj|%`fdsfw_^NZASFALtbm_$@tVgEFMrI&$gy8G4 zzwo;JQXkuP#Ks|9hnk&;A>62DTioGko*R1XPLtse69-4rv-HBGKHHN@zonW8pM9L4 zD4`l2r;rJ%b+Iszv{oI5J4r;#YL22}oe0Qlp89fQawdpu?liWK)(c=VQMWB#TojVsDfQ)B@|k<&&SZN; zrSDXf=k9cLz4y;gNYjNe-6sb@&e`u9&b@|;K=M}ePBltGf$H6|ztp2&J%KVvPvBI8 zp^WdAmji&8461-s)hd$uT`wBEOg`mQauSi5-+VvIa$7b|#Y(;A>BqSm-eRKT55LIRoN*A&62#Apj zTeMti8<9q2Bzz7aRt*c~m+lS=(eScxMK0DH`>BcID~X&73A!$7+4YvU0!wpJ2sIQ~LC5VUh-0E-Bc`ovy;|Fi+{MO1iKUvJ- z_OJ{^@gEBm%pA!&$|KBxH&l1bIczo!i#7`fKcEB~_G{xV1lq%2JG zt741&y_K;|NOlY;`TVc};&WbZP6F@ap`|iZ;-DErN>eQMetMH|7s@DoG;Cn}sS+Pg36B#w{W|duI(#$5ns#MN;W|GobrKc8`mjB_1Q zma=?3AR%bnOQ{Ewu+gE6jbtmxfYWs6u8APZ4r3t!(?5yPv{9&Q87tisbSKWl z_0=J@mDlI~r%-`DqBR44A-yi<6RHU@hUd}*>ef1=JS>UlPNbze;t;!MqVzVA#p4g#l zmiAuGk6%0k$e4ECE%>m|&ERd(ItNN*rq$`3b@1_%&)~UanZOS&yA-WCE2ijG#^U!( zNK5Q}rdr2})JB|Sol%}OoGa~LX!wyvA}u!04<*1nLgd>FCZ(Cit0ekgm$xaAry zrI0M#2W(bt9|yyDyECE3nK^L?j8)UccU&A%Ay38FV)RxVXw8!3hsSSjH%o@&aNfex zD)Gf|aO)JH7^ljJjOI^64x$H$sAN)0@VaPUwC`xi@N&_n^8OOEB}u^e@n{DRCXD^V zu!+40H)X_WQo(;_r)%oUysj1L^x05MQ}J@rZoVl@aX>)xqpCRof$Jgn$OEuvCjvX? 
zW+FLKiL@_!|0fN2UBGr;lxyx*&DM|sd>!WwOJj|?U}h8%E>QCWZLb*)Rr9AjO7{s z8dv2a$mOr)$|t<5)UH=aH%oT>x-=wZzTX~dFPAV7$^C@k4HLf9pYc1bU!OalH^QUA zxJUq_atU8bcX};{!z;(~LD82rN$v^2`DG|-N+f1riKi2*u*cfs*}Lf{U-FemKFrpq z%Pn_&qDa7-+LWH26-~+>fTRH2H(Y}s?|59Nj4j^t?f~NuYOTYMauZ+^1Rv~nV|LSM z4F-0JqoBHjq)GwRk3kk^O!Sw}UCB7Caq>@4}=Kh>sX0kCf16AJ|>sizGzw(BR2% z2=Q07WZMOo?~ZddtJmrhVX9OA0X}PLxFLY-Pwm4NBL*A@ zg+l|78mhB;3X1(`+-K~on|b+UkhY!b!W*H@NZ0BoB(%jS8M zy&ERz?hjki#^d_|>beo13y8w1vWlPhFmmQdYfaXd* zeJp^-KLEi$0RFEWrC%NictYV6dA4ptXe@B#k9L@8rMUa+-t6}2noVsN$TwaGIT&Gx zI05anlH4z|MQm@6)}uzVnC3W%HV(P05EC_MlCN&`8#N=~PzKl3*cU;QUY4g^i;vAK zeCdRy)3!3@YWLmFU?D7C9K@GYfiwFktMC=sC1oR_5q7%IV4DtRn*u{C*^+7`!=;Fd zu8fJ=W|Pr+#(PtL{QrE)m3b8s-ReLejS+u7YD45f;fYS$e4*^8Dbo-82M7v)UP9CX z-@YcAB7JP0wsm3EjPpkz$YGl=kjqX3a^>D+(?3mn(Ca^T0u9CZUg03n66!}^`tmYgwrly>1nrS>g@js{q*Gid4b?=m84 z?us|UIW44CI@@4$tnkpXBw73Yc3*yx)-B(v%!X)aop_O`YCIPWv64GeW7smN2pd-h z7oJo5D|7Kglj&&FSM)eV$b88?eW?`p^Ukc%%E?~sX`9AS3f{|fV`H2OP-J!W*1S}# z3z=}Aq~9=KusUs+?1pvXJ?EZ}Al9aH)NEDg`&egeLP$~cYsb*{X-93@{{X0+)&lyY z4^<<(Izyd5pI7>*7Mbse+1?A>SkLWRA9j~I#{G77tTZkcBnCfRXiT)I7rfA9d$U^` zTj_}(E7_CXI-phRr$_qDhq;2F8Ou-oo&p^u_#1Wezn;VrFi6Sw%Kl>9@AX^>DsR=W z@@zBz?bf(D(}($9E_oh1+YxgZFMR~atVHQrM~O_ZFD|ly3c=qTf4cLP%NqcWk?MI) z{;U9``L>%$a2O)#%WWKwKMZ{VUkOmhTf@xCKbI(lswzhg(?qkIl>`|YQMVDb_-L+t zCXo%$Ao3;QRk7-a8P}$f(v|Q%-qXM*1XfSKxA;(0UP0JSeEx#jQ9>sDDPTKI-Pj@% zW82AkAi*!+y7G0Di%SkR)|`@(%Wf@hb36Tfh!R39(ojn)%g(ED%l9J(f$IZiP&UTN ze%mv142WPvK5l>}Tg1uxq^Y&?xZVc!K-JpL-+%(_yi~U|@tKxNp6qeWi@mHH<=Aqn z(<`tibwLl>ma|Hh)U==-Y!eqwlsxE&SDcp_lxVRSHgEW4ydU^}!az#{iHwXXOwqZ$ z`bCm%JxBxM!Vr(>l7)A!th*F;NRAYghvZOqTxi*)suINTeZ)2*2eg*v#Y-uOT3Z@g~GvY`KJSz&J<%6`4!5@fwipT2K;Fpj(R>JrXIOu9=$gouXlor$QmJZH%% zE2VW(8#(iui-pqjetw|QeWS;3ip>EuGQI?ZWp-28{l1p2{4^E<2H4 z`4IgRN@YbNNSUcuD`ixWrl1^fEi_GI1WoT_AGKEV@fw1*I+;% zL>+cftBe$89dIvrJg_NlHf)F!o8i=0u86B&rtwx6zGqWRq-XUts?I-cv)NE2!92>9 z$~CC*aTZC11LO^EA{)H1e>n%G!#;zX_XGy)d2>?SyqV*H?ti7ZFFHcjygQudofl2( 
z0=1^T1^d0OmMCx9G28MY)gGB}lb|uIQ|Hbj-x0O&MT?ei&@48#sEQU6nP#ZJ4>OHzWz?o=O}m$_PsQei=x<*$&4W>d9n~`Qh`tV zr;6WP>AfGGHM}MvI1ayL%=}#fEn;9ng95c;1b7pt$6h#;$l+LBYXBLtt!7=mub->+ z;1I==MY$o`7Z`whPYB&{muCHk%Q@EE99zVK`6QC#{-k(48IfV>0fZ4D z)q&3GYBqSFE$7aAv`C;yuS5q*vR^yPI8lD|_nK)?Z z^Xe~o4R^&4SDE`z2!q_Q_=%TO<#r!&hzT>t{gru`l@f0pu~MHH{p;pd+^-jPM4$hL z2$rz7z1ZYrYK!HzH8pb{G3zXFG8r@_g3B&9Ry!P`7zM&Tv{`n#%X) zgSL_r$dBlmh5mOaxCqdeRGMroP3eRnys7&YWo`$p!uQY>oQw*L2L?t3@4p8Z%a%Xu z4$@yRvEN*?H^+z67Dzl}SPcpJO7EDP{oF4aCQHDhW32O>(9V@Ui~Alcack#{(O0}(DICX8GSDsmf z(C4z5)>bEjoP0~Px@`F*m9VnMm?U}@x8oQ%Z3I86!ItX)>=~lsC-0PpZ>k0fw(|?x zvHA})$>a{lqs=Wnj8cE$NWLdA+@5IaY2%}m<9*{T03va+vKWur z?#r5FToA27dY;QLZ^Xrc+5Z5QqmBKRVK~Oav&{PPMNR@5ejA{)=6NQf&?nFtfS+lLt1nL0(Oy! zs@qcnbSpa2zArJd9c@e~hBh{s(QL6Rc%TcLveVSUh7d?4jmrC9%BYHh@&*=yBHKul zQr{2VAaphw2TV!ZP@I6D>Qtq~zBezzQE7BZX$Pv|h?6b1RNjv6w1D(I)?hZs!qPge zK926Q`EJKPSnjh@P=3~TT$=-hunzwK%H^I;m98bSNgBi>Iqe?Hu{v{H9O%tjv6by> zsuPBx+Mm%DhMK|HDg!($Zn)m7fy*EN0MJ5>>lx7h0RE6%!H^WHkN2Y0<7zZB1i7Zi z5FhUal9`MrQmNyf_S7M+AYmHyR6xp@T4uALh!!8^2Ry)cm~W;XL@73#oWj$@^=UW3@&JE0@m+#o>BNs+`Bwp`gF+d1dSYVf9f zF}qF9LK9S{PTLL3f_MRj1ST_jV!}hiaj+h3zonB?ksEZgy`o2MYIN#6y6&8_=(gZ& z?u@u26FVDF?1M;-pTc%x!~ybPn}DrNjDxZ=8Yc3o$v6gTkZ780)N47W0C7FnVH_7V zqn~oL0qtRE#7=XXiKXGoV-=_FYY58Db5sUUxGpNya5s27E^3ma% z)>4L<3nfN0wX!GuD@TM8Kt~l$+GQkTKRQda)oEog+4zS1PD>RRxz+CCZgpB-OG}`U zm+Gj>!Loi~?yHOO1(u87_Z+`)os}PnyGKQ-Q${bu5ykj|%jP~t(bx4>@PxUfxXF~2 zCi}r>7tzMYKUJk-1j|gEj=Y?Tgf+p3Ve@Mjb5G!;6SEfLp+sT^Z(;i@crEvpnRzRC zkCg6Wo9_fSKv9xY) zb!u3fX0ou&NWkp)^;(u7Nit5q^Ev(TMGt}juB)UKP59kr_-#x$Y z6`xAJDsixew2^z5ONyne-^`!HwXQPZCkmNkU1XpThF9<}R~L zO_r))H(#pgQu(bIS)OxsLe}|)%5A2>QKU6wt>Oa+0L1lB^4I`4T1os$^*|&@l`_`a z%BmDKx7c9UJySJ-q&P8>KN10qld=< -# -*- coding: utf-8 -*- -# vim:ts=4:sw=4:expandtab:fileencoding=utf-8 - -import cherrypy -from cherrypy._cpcompat import md5, ntob -from cherrypy.lib import auth_basic -from cherrypy.test import helper - - -class BasicAuthTest(helper.CPWebCase): - - def setup_server(): - class Root: - def 
index(self): - return "This is public." - index.exposed = True - - class BasicProtected: - def index(self): - return "Hello %s, you've been authorized." % cherrypy.request.login - index.exposed = True - - class BasicProtected2: - def index(self): - return "Hello %s, you've been authorized." % cherrypy.request.login - index.exposed = True - - userpassdict = {'xuser' : 'xpassword'} - userhashdict = {'xuser' : md5(ntob('xpassword')).hexdigest()} - - def checkpasshash(realm, user, password): - p = userhashdict.get(user) - return p and p == md5(ntob(password)).hexdigest() or False - - conf = {'/basic': {'tools.auth_basic.on': True, - 'tools.auth_basic.realm': 'wonderland', - 'tools.auth_basic.checkpassword': auth_basic.checkpassword_dict(userpassdict)}, - '/basic2': {'tools.auth_basic.on': True, - 'tools.auth_basic.realm': 'wonderland', - 'tools.auth_basic.checkpassword': checkpasshash}, - } - - root = Root() - root.basic = BasicProtected() - root.basic2 = BasicProtected2() - cherrypy.tree.mount(root, config=conf) - setup_server = staticmethod(setup_server) - - def testPublic(self): - self.getPage("/") - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/html;charset=utf-8') - self.assertBody('This is public.') - - def testBasic(self): - self.getPage("/basic/") - self.assertStatus(401) - self.assertHeader('WWW-Authenticate', 'Basic realm="wonderland"') - - self.getPage('/basic/', [('Authorization', 'Basic eHVzZXI6eHBhc3N3b3JX')]) - self.assertStatus(401) - - self.getPage('/basic/', [('Authorization', 'Basic eHVzZXI6eHBhc3N3b3Jk')]) - self.assertStatus('200 OK') - self.assertBody("Hello xuser, you've been authorized.") - - def testBasic2(self): - self.getPage("/basic2/") - self.assertStatus(401) - self.assertHeader('WWW-Authenticate', 'Basic realm="wonderland"') - - self.getPage('/basic2/', [('Authorization', 'Basic eHVzZXI6eHBhc3N3b3JX')]) - self.assertStatus(401) - - self.getPage('/basic2/', [('Authorization', 'Basic eHVzZXI6eHBhc3N3b3Jk')]) - 
self.assertStatus('200 OK') - self.assertBody("Hello xuser, you've been authorized.") - diff --git a/cherrypy/test/test_auth_digest.py b/cherrypy/test/test_auth_digest.py deleted file mode 100644 index 1960fa81..00000000 --- a/cherrypy/test/test_auth_digest.py +++ /dev/null @@ -1,115 +0,0 @@ -# This file is part of CherryPy -# -*- coding: utf-8 -*- -# vim:ts=4:sw=4:expandtab:fileencoding=utf-8 - - -import cherrypy -from cherrypy.lib import auth_digest - -from cherrypy.test import helper - -class DigestAuthTest(helper.CPWebCase): - - def setup_server(): - class Root: - def index(self): - return "This is public." - index.exposed = True - - class DigestProtected: - def index(self): - return "Hello %s, you've been authorized." % cherrypy.request.login - index.exposed = True - - def fetch_users(): - return {'test': 'test'} - - - get_ha1 = cherrypy.lib.auth_digest.get_ha1_dict_plain(fetch_users()) - conf = {'/digest': {'tools.auth_digest.on': True, - 'tools.auth_digest.realm': 'localhost', - 'tools.auth_digest.get_ha1': get_ha1, - 'tools.auth_digest.key': 'a565c27146791cfb', - 'tools.auth_digest.debug': 'True'}} - - root = Root() - root.digest = DigestProtected() - cherrypy.tree.mount(root, config=conf) - setup_server = staticmethod(setup_server) - - def testPublic(self): - self.getPage("/") - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/html;charset=utf-8') - self.assertBody('This is public.') - - def testDigest(self): - self.getPage("/digest/") - self.assertStatus(401) - - value = None - for k, v in self.headers: - if k.lower() == "www-authenticate": - if v.startswith("Digest"): - value = v - break - - if value is None: - self._handlewebError("Digest authentification scheme was not found") - - value = value[7:] - items = value.split(', ') - tokens = {} - for item in items: - key, value = item.split('=') - tokens[key.lower()] = value - - missing_msg = "%s is missing" - bad_value_msg = "'%s' was expecting '%s' but found '%s'" - nonce = None - if 
'realm' not in tokens: - self._handlewebError(missing_msg % 'realm') - elif tokens['realm'] != '"localhost"': - self._handlewebError(bad_value_msg % ('realm', '"localhost"', tokens['realm'])) - if 'nonce' not in tokens: - self._handlewebError(missing_msg % 'nonce') - else: - nonce = tokens['nonce'].strip('"') - if 'algorithm' not in tokens: - self._handlewebError(missing_msg % 'algorithm') - elif tokens['algorithm'] != '"MD5"': - self._handlewebError(bad_value_msg % ('algorithm', '"MD5"', tokens['algorithm'])) - if 'qop' not in tokens: - self._handlewebError(missing_msg % 'qop') - elif tokens['qop'] != '"auth"': - self._handlewebError(bad_value_msg % ('qop', '"auth"', tokens['qop'])) - - get_ha1 = auth_digest.get_ha1_dict_plain({'test' : 'test'}) - - # Test user agent response with a wrong value for 'realm' - base_auth = 'Digest username="test", realm="wrong realm", nonce="%s", uri="/digest/", algorithm=MD5, response="%s", qop=auth, nc=%s, cnonce="1522e61005789929"' - - auth_header = base_auth % (nonce, '11111111111111111111111111111111', '00000001') - auth = auth_digest.HttpDigestAuthorization(auth_header, 'GET') - # calculate the response digest - ha1 = get_ha1(auth.realm, 'test') - response = auth.request_digest(ha1) - # send response with correct response digest, but wrong realm - auth_header = base_auth % (nonce, response, '00000001') - self.getPage('/digest/', [('Authorization', auth_header)]) - self.assertStatus(401) - - # Test that must pass - base_auth = 'Digest username="test", realm="localhost", nonce="%s", uri="/digest/", algorithm=MD5, response="%s", qop=auth, nc=%s, cnonce="1522e61005789929"' - - auth_header = base_auth % (nonce, '11111111111111111111111111111111', '00000001') - auth = auth_digest.HttpDigestAuthorization(auth_header, 'GET') - # calculate the response digest - ha1 = get_ha1('localhost', 'test') - response = auth.request_digest(ha1) - # send response with correct response digest - auth_header = base_auth % (nonce, response, '00000001') 
- self.getPage('/digest/', [('Authorization', auth_header)]) - self.assertStatus('200 OK') - self.assertBody("Hello test, you've been authorized.") - diff --git a/cherrypy/test/test_bus.py b/cherrypy/test/test_bus.py deleted file mode 100644 index 51c10220..00000000 --- a/cherrypy/test/test_bus.py +++ /dev/null @@ -1,263 +0,0 @@ -import threading -import time -import unittest - -import cherrypy -from cherrypy._cpcompat import get_daemon, set -from cherrypy.process import wspbus - - -msg = "Listener %d on channel %s: %s." - - -class PublishSubscribeTests(unittest.TestCase): - - def get_listener(self, channel, index): - def listener(arg=None): - self.responses.append(msg % (index, channel, arg)) - return listener - - def test_builtin_channels(self): - b = wspbus.Bus() - - self.responses, expected = [], [] - - for channel in b.listeners: - for index, priority in enumerate([100, 50, 0, 51]): - b.subscribe(channel, self.get_listener(channel, index), priority) - - for channel in b.listeners: - b.publish(channel) - expected.extend([msg % (i, channel, None) for i in (2, 1, 3, 0)]) - b.publish(channel, arg=79347) - expected.extend([msg % (i, channel, 79347) for i in (2, 1, 3, 0)]) - - self.assertEqual(self.responses, expected) - - def test_custom_channels(self): - b = wspbus.Bus() - - self.responses, expected = [], [] - - custom_listeners = ('hugh', 'louis', 'dewey') - for channel in custom_listeners: - for index, priority in enumerate([None, 10, 60, 40]): - b.subscribe(channel, self.get_listener(channel, index), priority) - - for channel in custom_listeners: - b.publish(channel, 'ah so') - expected.extend([msg % (i, channel, 'ah so') for i in (1, 3, 0, 2)]) - b.publish(channel) - expected.extend([msg % (i, channel, None) for i in (1, 3, 0, 2)]) - - self.assertEqual(self.responses, expected) - - def test_listener_errors(self): - b = wspbus.Bus() - - self.responses, expected = [], [] - channels = [c for c in b.listeners if c != 'log'] - - for channel in channels: - 
b.subscribe(channel, self.get_listener(channel, 1)) - # This will break since the lambda takes no args. - b.subscribe(channel, lambda: None, priority=20) - - for channel in channels: - self.assertRaises(wspbus.ChannelFailures, b.publish, channel, 123) - expected.append(msg % (1, channel, 123)) - - self.assertEqual(self.responses, expected) - - -class BusMethodTests(unittest.TestCase): - - def log(self, bus): - self._log_entries = [] - def logit(msg, level): - self._log_entries.append(msg) - bus.subscribe('log', logit) - - def assertLog(self, entries): - self.assertEqual(self._log_entries, entries) - - def get_listener(self, channel, index): - def listener(arg=None): - self.responses.append(msg % (index, channel, arg)) - return listener - - def test_start(self): - b = wspbus.Bus() - self.log(b) - - self.responses = [] - num = 3 - for index in range(num): - b.subscribe('start', self.get_listener('start', index)) - - b.start() - try: - # The start method MUST call all 'start' listeners. - self.assertEqual(set(self.responses), - set([msg % (i, 'start', None) for i in range(num)])) - # The start method MUST move the state to STARTED - # (or EXITING, if errors occur) - self.assertEqual(b.state, b.states.STARTED) - # The start method MUST log its states. - self.assertLog(['Bus STARTING', 'Bus STARTED']) - finally: - # Exit so the atexit handler doesn't complain. - b.exit() - - def test_stop(self): - b = wspbus.Bus() - self.log(b) - - self.responses = [] - num = 3 - for index in range(num): - b.subscribe('stop', self.get_listener('stop', index)) - - b.stop() - - # The stop method MUST call all 'stop' listeners. - self.assertEqual(set(self.responses), - set([msg % (i, 'stop', None) for i in range(num)])) - # The stop method MUST move the state to STOPPED - self.assertEqual(b.state, b.states.STOPPED) - # The stop method MUST log its states. 
- self.assertLog(['Bus STOPPING', 'Bus STOPPED']) - - def test_graceful(self): - b = wspbus.Bus() - self.log(b) - - self.responses = [] - num = 3 - for index in range(num): - b.subscribe('graceful', self.get_listener('graceful', index)) - - b.graceful() - - # The graceful method MUST call all 'graceful' listeners. - self.assertEqual(set(self.responses), - set([msg % (i, 'graceful', None) for i in range(num)])) - # The graceful method MUST log its states. - self.assertLog(['Bus graceful']) - - def test_exit(self): - b = wspbus.Bus() - self.log(b) - - self.responses = [] - num = 3 - for index in range(num): - b.subscribe('stop', self.get_listener('stop', index)) - b.subscribe('exit', self.get_listener('exit', index)) - - b.exit() - - # The exit method MUST call all 'stop' listeners, - # and then all 'exit' listeners. - self.assertEqual(set(self.responses), - set([msg % (i, 'stop', None) for i in range(num)] + - [msg % (i, 'exit', None) for i in range(num)])) - # The exit method MUST move the state to EXITING - self.assertEqual(b.state, b.states.EXITING) - # The exit method MUST log its states. - self.assertLog(['Bus STOPPING', 'Bus STOPPED', 'Bus EXITING', 'Bus EXITED']) - - def test_wait(self): - b = wspbus.Bus() - - def f(method): - time.sleep(0.2) - getattr(b, method)() - - for method, states in [('start', [b.states.STARTED]), - ('stop', [b.states.STOPPED]), - ('start', [b.states.STARTING, b.states.STARTED]), - ('exit', [b.states.EXITING]), - ]: - threading.Thread(target=f, args=(method,)).start() - b.wait(states) - - # The wait method MUST wait for the given state(s). 
- if b.state not in states: - self.fail("State %r not in %r" % (b.state, states)) - - def test_block(self): - b = wspbus.Bus() - self.log(b) - - def f(): - time.sleep(0.2) - b.exit() - def g(): - time.sleep(0.4) - threading.Thread(target=f).start() - threading.Thread(target=g).start() - threads = [t for t in threading.enumerate() if not get_daemon(t)] - self.assertEqual(len(threads), 3) - - b.block() - - # The block method MUST wait for the EXITING state. - self.assertEqual(b.state, b.states.EXITING) - # The block method MUST wait for ALL non-main, non-daemon threads to finish. - threads = [t for t in threading.enumerate() if not get_daemon(t)] - self.assertEqual(len(threads), 1) - # The last message will mention an indeterminable thread name; ignore it - self.assertEqual(self._log_entries[:-1], - ['Bus STOPPING', 'Bus STOPPED', - 'Bus EXITING', 'Bus EXITED', - 'Waiting for child threads to terminate...']) - - def test_start_with_callback(self): - b = wspbus.Bus() - self.log(b) - try: - events = [] - def f(*args, **kwargs): - events.append(("f", args, kwargs)) - def g(): - events.append("g") - b.subscribe("start", g) - b.start_with_callback(f, (1, 3, 5), {"foo": "bar"}) - # Give wait() time to run f() - time.sleep(0.2) - - # The callback method MUST wait for the STARTED state. - self.assertEqual(b.state, b.states.STARTED) - # The callback method MUST run after all start methods. - self.assertEqual(events, ["g", ("f", (1, 3, 5), {"foo": "bar"})]) - finally: - b.exit() - - def test_log(self): - b = wspbus.Bus() - self.log(b) - self.assertLog([]) - - # Try a normal message. 
- expected = [] - for msg in ["O mah darlin'"] * 3 + ["Clementiiiiiiiine"]: - b.log(msg) - expected.append(msg) - self.assertLog(expected) - - # Try an error message - try: - foo - except NameError: - b.log("You are lost and gone forever", traceback=True) - lastmsg = self._log_entries[-1] - if "Traceback" not in lastmsg or "NameError" not in lastmsg: - self.fail("Last log message %r did not contain " - "the expected traceback." % lastmsg) - else: - self.fail("NameError was not raised as expected.") - - -if __name__ == "__main__": - unittest.main() diff --git a/cherrypy/test/test_caching.py b/cherrypy/test/test_caching.py deleted file mode 100644 index 720a933a..00000000 --- a/cherrypy/test/test_caching.py +++ /dev/null @@ -1,329 +0,0 @@ -import datetime -import gzip -from itertools import count -import os -curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) -import sys -import threading -import time -import urllib - -import cherrypy -from cherrypy._cpcompat import next, ntob, quote, xrange -from cherrypy.lib import httputil - -gif_bytes = ntob('GIF89a\x01\x00\x01\x00\x82\x00\x01\x99"\x1e\x00\x00\x00\x00\x00' - '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' - '\x00,\x00\x00\x00\x00\x01\x00\x01\x00\x02\x03\x02\x08\t\x00;') - - - -from cherrypy.test import helper - -class CacheTest(helper.CPWebCase): - - def setup_server(): - - class Root: - - _cp_config = {'tools.caching.on': True} - - def __init__(self): - self.counter = 0 - self.control_counter = 0 - self.longlock = threading.Lock() - - def index(self): - self.counter += 1 - msg = "visit #%s" % self.counter - return msg - index.exposed = True - - def control(self): - self.control_counter += 1 - return "visit #%s" % self.control_counter - control.exposed = True - - def a_gif(self): - cherrypy.response.headers['Last-Modified'] = httputil.HTTPDate() - return gif_bytes - a_gif.exposed = True - - def long_process(self, seconds='1'): - try: - self.longlock.acquire() - time.sleep(float(seconds)) 
- finally: - self.longlock.release() - return 'success!' - long_process.exposed = True - - def clear_cache(self, path): - cherrypy._cache.store[cherrypy.request.base + path].clear() - clear_cache.exposed = True - - class VaryHeaderCachingServer(object): - - _cp_config = {'tools.caching.on': True, - 'tools.response_headers.on': True, - 'tools.response_headers.headers': [('Vary', 'Our-Varying-Header')], - } - - def __init__(self): - self.counter = count(1) - - def index(self): - return "visit #%s" % next(self.counter) - index.exposed = True - - class UnCached(object): - _cp_config = {'tools.expires.on': True, - 'tools.expires.secs': 60, - 'tools.staticdir.on': True, - 'tools.staticdir.dir': 'static', - 'tools.staticdir.root': curdir, - } - - def force(self): - cherrypy.response.headers['Etag'] = 'bibbitybobbityboo' - self._cp_config['tools.expires.force'] = True - self._cp_config['tools.expires.secs'] = 0 - return "being forceful" - force.exposed = True - force._cp_config = {'tools.expires.secs': 0} - - def dynamic(self): - cherrypy.response.headers['Etag'] = 'bibbitybobbityboo' - cherrypy.response.headers['Cache-Control'] = 'private' - return "D-d-d-dynamic!" - dynamic.exposed = True - - def cacheable(self): - cherrypy.response.headers['Etag'] = 'bibbitybobbityboo' - return "Hi, I'm cacheable." 
- cacheable.exposed = True - - def specific(self): - cherrypy.response.headers['Etag'] = 'need_this_to_make_me_cacheable' - return "I am being specific" - specific.exposed = True - specific._cp_config = {'tools.expires.secs': 86400} - - class Foo(object):pass - - def wrongtype(self): - cherrypy.response.headers['Etag'] = 'need_this_to_make_me_cacheable' - return "Woops" - wrongtype.exposed = True - wrongtype._cp_config = {'tools.expires.secs': Foo()} - - cherrypy.tree.mount(Root()) - cherrypy.tree.mount(UnCached(), "/expires") - cherrypy.tree.mount(VaryHeaderCachingServer(), "/varying_headers") - cherrypy.config.update({'tools.gzip.on': True}) - setup_server = staticmethod(setup_server) - - def testCaching(self): - elapsed = 0.0 - for trial in range(10): - self.getPage("/") - # The response should be the same every time, - # except for the Age response header. - self.assertBody('visit #1') - if trial != 0: - age = int(self.assertHeader("Age")) - self.assert_(age >= elapsed) - elapsed = age - - # POST, PUT, DELETE should not be cached. - self.getPage("/", method="POST") - self.assertBody('visit #2') - # Because gzip is turned on, the Vary header should always Vary for content-encoding - self.assertHeader('Vary', 'Accept-Encoding') - # The previous request should have invalidated the cache, - # so this request will recalc the response. - self.getPage("/", method="GET") - self.assertBody('visit #3') - # ...but this request should get the cached copy. - self.getPage("/", method="GET") - self.assertBody('visit #3') - self.getPage("/", method="DELETE") - self.assertBody('visit #4') - - # The previous request should have invalidated the cache, - # so this request will recalc the response. 
- self.getPage("/", method="GET", headers=[('Accept-Encoding', 'gzip')]) - self.assertHeader('Content-Encoding', 'gzip') - self.assertHeader('Vary') - self.assertEqual(cherrypy.lib.encoding.decompress(self.body), ntob("visit #5")) - - # Now check that a second request gets the gzip header and gzipped body - # This also tests a bug in 3.0 to 3.0.2 whereby the cached, gzipped - # response body was being gzipped a second time. - self.getPage("/", method="GET", headers=[('Accept-Encoding', 'gzip')]) - self.assertHeader('Content-Encoding', 'gzip') - self.assertEqual(cherrypy.lib.encoding.decompress(self.body), ntob("visit #5")) - - # Now check that a third request that doesn't accept gzip - # skips the cache (because the 'Vary' header denies it). - self.getPage("/", method="GET") - self.assertNoHeader('Content-Encoding') - self.assertBody('visit #6') - - def testVaryHeader(self): - self.getPage("/varying_headers/") - self.assertStatus("200 OK") - self.assertHeaderItemValue('Vary', 'Our-Varying-Header') - self.assertBody('visit #1') - - # Now check that different 'Vary'-fields don't evict each other. - # This test creates 2 requests with different 'Our-Varying-Header' - # and then tests if the first one still exists. 
- self.getPage("/varying_headers/", headers=[('Our-Varying-Header', 'request 2')]) - self.assertStatus("200 OK") - self.assertBody('visit #2') - - self.getPage("/varying_headers/", headers=[('Our-Varying-Header', 'request 2')]) - self.assertStatus("200 OK") - self.assertBody('visit #2') - - self.getPage("/varying_headers/") - self.assertStatus("200 OK") - self.assertBody('visit #1') - - def testExpiresTool(self): - # test setting an expires header - self.getPage("/expires/specific") - self.assertStatus("200 OK") - self.assertHeader("Expires") - - # test exceptions for bad time values - self.getPage("/expires/wrongtype") - self.assertStatus(500) - self.assertInBody("TypeError") - - # static content should not have "cache prevention" headers - self.getPage("/expires/index.html") - self.assertStatus("200 OK") - self.assertNoHeader("Pragma") - self.assertNoHeader("Cache-Control") - self.assertHeader("Expires") - - # dynamic content that sets indicators should not have - # "cache prevention" headers - self.getPage("/expires/cacheable") - self.assertStatus("200 OK") - self.assertNoHeader("Pragma") - self.assertNoHeader("Cache-Control") - self.assertHeader("Expires") - - self.getPage('/expires/dynamic') - self.assertBody("D-d-d-dynamic!") - # the Cache-Control header should be untouched - self.assertHeader("Cache-Control", "private") - self.assertHeader("Expires") - - # configure the tool to ignore indicators and replace existing headers - self.getPage("/expires/force") - self.assertStatus("200 OK") - # This also gives us a chance to test 0 expiry with no other headers - self.assertHeader("Pragma", "no-cache") - if cherrypy.server.protocol_version == "HTTP/1.1": - self.assertHeader("Cache-Control", "no-cache, must-revalidate") - self.assertHeader("Expires", "Sun, 28 Jan 2007 00:00:00 GMT") - - # static content should now have "cache prevention" headers - self.getPage("/expires/index.html") - self.assertStatus("200 OK") - self.assertHeader("Pragma", "no-cache") - if 
cherrypy.server.protocol_version == "HTTP/1.1": - self.assertHeader("Cache-Control", "no-cache, must-revalidate") - self.assertHeader("Expires", "Sun, 28 Jan 2007 00:00:00 GMT") - - # the cacheable handler should now have "cache prevention" headers - self.getPage("/expires/cacheable") - self.assertStatus("200 OK") - self.assertHeader("Pragma", "no-cache") - if cherrypy.server.protocol_version == "HTTP/1.1": - self.assertHeader("Cache-Control", "no-cache, must-revalidate") - self.assertHeader("Expires", "Sun, 28 Jan 2007 00:00:00 GMT") - - self.getPage('/expires/dynamic') - self.assertBody("D-d-d-dynamic!") - # dynamic sets Cache-Control to private but it should be - # overwritten here ... - self.assertHeader("Pragma", "no-cache") - if cherrypy.server.protocol_version == "HTTP/1.1": - self.assertHeader("Cache-Control", "no-cache, must-revalidate") - self.assertHeader("Expires", "Sun, 28 Jan 2007 00:00:00 GMT") - - def testLastModified(self): - self.getPage("/a.gif") - self.assertStatus(200) - self.assertBody(gif_bytes) - lm1 = self.assertHeader("Last-Modified") - - # this request should get the cached copy. - self.getPage("/a.gif") - self.assertStatus(200) - self.assertBody(gif_bytes) - self.assertHeader("Age") - lm2 = self.assertHeader("Last-Modified") - self.assertEqual(lm1, lm2) - - # this request should match the cached copy, but raise 304. - self.getPage("/a.gif", [('If-Modified-Since', lm1)]) - self.assertStatus(304) - self.assertNoHeader("Last-Modified") - if not getattr(cherrypy.server, "using_apache", False): - self.assertHeader("Age") - - def test_antistampede(self): - SECONDS = 4 - # We MUST make an initial synchronous request in order to create the - # AntiStampedeCache object, and populate its selecting_headers, - # before the actual stampede. 
- self.getPage("/long_process?seconds=%d" % SECONDS) - self.assertBody('success!') - self.getPage("/clear_cache?path=" + - quote('/long_process?seconds=%d' % SECONDS, safe='')) - self.assertStatus(200) - sys.stdout.write("prepped... ") - sys.stdout.flush() - - start = datetime.datetime.now() - def run(): - self.getPage("/long_process?seconds=%d" % SECONDS) - # The response should be the same every time - self.assertBody('success!') - ts = [threading.Thread(target=run) for i in xrange(100)] - for t in ts: - t.start() - for t in ts: - t.join() - self.assertEqualDates(start, datetime.datetime.now(), - # Allow a second for our thread/TCP overhead etc. - seconds=SECONDS + 1.1) - - def test_cache_control(self): - self.getPage("/control") - self.assertBody('visit #1') - self.getPage("/control") - self.assertBody('visit #1') - - self.getPage("/control", headers=[('Cache-Control', 'no-cache')]) - self.assertBody('visit #2') - self.getPage("/control") - self.assertBody('visit #2') - - self.getPage("/control", headers=[('Pragma', 'no-cache')]) - self.assertBody('visit #3') - self.getPage("/control") - self.assertBody('visit #3') - - time.sleep(1) - self.getPage("/control", headers=[('Cache-Control', 'max-age=0')]) - self.assertBody('visit #4') - self.getPage("/control") - self.assertBody('visit #4') - diff --git a/cherrypy/test/test_config.py b/cherrypy/test/test_config.py deleted file mode 100644 index a0bd8ab9..00000000 --- a/cherrypy/test/test_config.py +++ /dev/null @@ -1,249 +0,0 @@ -"""Tests for the CherryPy configuration system.""" - -import os, sys -localDir = os.path.join(os.getcwd(), os.path.dirname(__file__)) - -from cherrypy._cpcompat import ntob, StringIO -import unittest - -import cherrypy - -def setup_server(): - - class Root: - - _cp_config = {'foo': 'this', - 'bar': 'that'} - - def __init__(self): - cherrypy.config.namespaces['db'] = self.db_namespace - - def db_namespace(self, k, v): - if k == "scheme": - self.db = v - - # @cherrypy.expose(alias=('global_', 
'xyz')) - def index(self, key): - return cherrypy.request.config.get(key, "None") - index = cherrypy.expose(index, alias=('global_', 'xyz')) - - def repr(self, key): - return repr(cherrypy.request.config.get(key, None)) - repr.exposed = True - - def dbscheme(self): - return self.db - dbscheme.exposed = True - - def plain(self, x): - return x - plain.exposed = True - plain._cp_config = {'request.body.attempt_charsets': ['utf-16']} - - favicon_ico = cherrypy.tools.staticfile.handler( - filename=os.path.join(localDir, '../favicon.ico')) - - class Foo: - - _cp_config = {'foo': 'this2', - 'baz': 'that2'} - - def index(self, key): - return cherrypy.request.config.get(key, "None") - index.exposed = True - nex = index - - def silly(self): - return 'Hello world' - silly.exposed = True - silly._cp_config = {'response.headers.X-silly': 'sillyval'} - - def bar(self, key): - return repr(cherrypy.request.config.get(key, None)) - bar.exposed = True - bar._cp_config = {'foo': 'this3', 'bax': 'this4'} - - class Another: - - def index(self, key): - return str(cherrypy.request.config.get(key, "None")) - index.exposed = True - - - def raw_namespace(key, value): - if key == 'input.map': - handler = cherrypy.request.handler - def wrapper(): - params = cherrypy.request.params - for name, coercer in list(value.items()): - try: - params[name] = coercer(params[name]) - except KeyError: - pass - return handler() - cherrypy.request.handler = wrapper - elif key == 'output': - handler = cherrypy.request.handler - def wrapper(): - # 'value' is a type (like int or str). 
- return value(handler()) - cherrypy.request.handler = wrapper - - class Raw: - - _cp_config = {'raw.output': repr} - - def incr(self, num): - return num + 1 - incr.exposed = True - incr._cp_config = {'raw.input.map': {'num': int}} - - ioconf = StringIO(""" -[/] -neg: -1234 -filename: os.path.join(sys.prefix, "hello.py") -thing1: cherrypy.lib.httputil.response_codes[404] -thing2: __import__('cherrypy.tutorial', globals(), locals(), ['']).thing2 -complex: 3+2j -ones: "11" -twos: "22" -stradd: %%(ones)s + %%(twos)s + "33" - -[/favicon.ico] -tools.staticfile.filename = %r -""" % os.path.join(localDir, 'static/dirback.jpg')) - - root = Root() - root.foo = Foo() - root.raw = Raw() - app = cherrypy.tree.mount(root, config=ioconf) - app.request_class.namespaces['raw'] = raw_namespace - - cherrypy.tree.mount(Another(), "/another") - cherrypy.config.update({'luxuryyacht': 'throatwobblermangrove', - 'db.scheme': r"sqlite///memory", - }) - - -# Client-side code # - -from cherrypy.test import helper - -class ConfigTests(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def testConfig(self): - tests = [ - ('/', 'nex', 'None'), - ('/', 'foo', 'this'), - ('/', 'bar', 'that'), - ('/xyz', 'foo', 'this'), - ('/foo/', 'foo', 'this2'), - ('/foo/', 'bar', 'that'), - ('/foo/', 'bax', 'None'), - ('/foo/bar', 'baz', "'that2'"), - ('/foo/nex', 'baz', 'that2'), - # If 'foo' == 'this', then the mount point '/another' leaks into '/'. 
- ('/another/','foo', 'None'), - ] - for path, key, expected in tests: - self.getPage(path + "?key=" + key) - self.assertBody(expected) - - expectedconf = { - # From CP defaults - 'tools.log_headers.on': False, - 'tools.log_tracebacks.on': True, - 'request.show_tracebacks': True, - 'log.screen': False, - 'environment': 'test_suite', - 'engine.autoreload_on': False, - # From global config - 'luxuryyacht': 'throatwobblermangrove', - # From Root._cp_config - 'bar': 'that', - # From Foo._cp_config - 'baz': 'that2', - # From Foo.bar._cp_config - 'foo': 'this3', - 'bax': 'this4', - } - for key, expected in expectedconf.items(): - self.getPage("/foo/bar?key=" + key) - self.assertBody(repr(expected)) - - def testUnrepr(self): - self.getPage("/repr?key=neg") - self.assertBody("-1234") - - self.getPage("/repr?key=filename") - self.assertBody(repr(os.path.join(sys.prefix, "hello.py"))) - - self.getPage("/repr?key=thing1") - self.assertBody(repr(cherrypy.lib.httputil.response_codes[404])) - - if not getattr(cherrypy.server, "using_apache", False): - # The object ID's won't match up when using Apache, since the - # server and client are running in different processes. - self.getPage("/repr?key=thing2") - from cherrypy.tutorial import thing2 - self.assertBody(repr(thing2)) - - self.getPage("/repr?key=complex") - self.assertBody("(3+2j)") - - self.getPage("/repr?key=stradd") - self.assertBody(repr("112233")) - - def testRespNamespaces(self): - self.getPage("/foo/silly") - self.assertHeader('X-silly', 'sillyval') - self.assertBody('Hello world') - - def testCustomNamespaces(self): - self.getPage("/raw/incr?num=12") - self.assertBody("13") - - self.getPage("/dbscheme") - self.assertBody(r"sqlite///memory") - - def testHandlerToolConfigOverride(self): - # Assert that config overrides tool constructor args. Above, we set - # the favicon in the page handler to be '../favicon.ico', - # but then overrode it in config to be './static/dirback.jpg'. 
- self.getPage("/favicon.ico") - self.assertBody(open(os.path.join(localDir, "static/dirback.jpg"), - "rb").read()) - - def test_request_body_namespace(self): - self.getPage("/plain", method='POST', headers=[ - ('Content-Type', 'application/x-www-form-urlencoded'), - ('Content-Length', '13')], - body=ntob('\xff\xfex\x00=\xff\xfea\x00b\x00c\x00')) - self.assertBody("abc") - - -class VariableSubstitutionTests(unittest.TestCase): - setup_server = staticmethod(setup_server) - - def test_config(self): - from textwrap import dedent - - # variable substitution with [DEFAULT] - conf = dedent(""" - [DEFAULT] - dir = "/some/dir" - my.dir = %(dir)s + "/sub" - - [my] - my.dir = %(dir)s + "/my/dir" - my.dir2 = %(my.dir)s + '/dir2' - - """) - - fp = StringIO(conf) - - cherrypy.config.update(fp) - self.assertEqual(cherrypy.config["my"]["my.dir"], "/some/dir/my/dir") - self.assertEqual(cherrypy.config["my"]["my.dir2"], "/some/dir/my/dir/dir2") - diff --git a/cherrypy/test/test_config_server.py b/cherrypy/test/test_config_server.py deleted file mode 100644 index 0b9718da..00000000 --- a/cherrypy/test/test_config_server.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Tests for the CherryPy configuration system.""" - -import os, sys -localDir = os.path.join(os.getcwd(), os.path.dirname(__file__)) -import socket -import time - -import cherrypy - - -# Client-side code # - -from cherrypy.test import helper - -class ServerConfigTests(helper.CPWebCase): - - def setup_server(): - - class Root: - def index(self): - return cherrypy.request.wsgi_environ['SERVER_PORT'] - index.exposed = True - - def upload(self, file): - return "Size: %s" % len(file.file.read()) - upload.exposed = True - - def tinyupload(self): - return cherrypy.request.body.read() - tinyupload.exposed = True - tinyupload._cp_config = {'request.body.maxbytes': 100} - - cherrypy.tree.mount(Root()) - - cherrypy.config.update({ - 'server.socket_host': '0.0.0.0', - 'server.socket_port': 9876, - 'server.max_request_body_size': 200, - 
'server.max_request_header_size': 500, - 'server.socket_timeout': 0.5, - - # Test explicit server.instance - 'server.2.instance': 'cherrypy._cpwsgi_server.CPWSGIServer', - 'server.2.socket_port': 9877, - - # Test non-numeric - # Also test default server.instance = builtin server - 'server.yetanother.socket_port': 9878, - }) - setup_server = staticmethod(setup_server) - - PORT = 9876 - - def testBasicConfig(self): - self.getPage("/") - self.assertBody(str(self.PORT)) - - def testAdditionalServers(self): - if self.scheme == 'https': - return self.skip("not available under ssl") - self.PORT = 9877 - self.getPage("/") - self.assertBody(str(self.PORT)) - self.PORT = 9878 - self.getPage("/") - self.assertBody(str(self.PORT)) - - def testMaxRequestSizePerHandler(self): - if getattr(cherrypy.server, "using_apache", False): - return self.skip("skipped due to known Apache differences... ") - - self.getPage('/tinyupload', method="POST", - headers=[('Content-Type', 'text/plain'), - ('Content-Length', '100')], - body="x" * 100) - self.assertStatus(200) - self.assertBody("x" * 100) - - self.getPage('/tinyupload', method="POST", - headers=[('Content-Type', 'text/plain'), - ('Content-Length', '101')], - body="x" * 101) - self.assertStatus(413) - - def testMaxRequestSize(self): - if getattr(cherrypy.server, "using_apache", False): - return self.skip("skipped due to known Apache differences... ") - - for size in (500, 5000, 50000): - self.getPage("/", headers=[('From', "x" * 500)]) - self.assertStatus(413) - - # Test for http://www.cherrypy.org/ticket/421 - # (Incorrect border condition in readline of SizeCheckWrapper). - # This hangs in rev 891 and earlier. 
- lines256 = "x" * 248 - self.getPage("/", - headers=[('Host', '%s:%s' % (self.HOST, self.PORT)), - ('From', lines256)]) - - # Test upload - body = '\r\n'.join([ - '--x', - 'Content-Disposition: form-data; name="file"; filename="hello.txt"', - 'Content-Type: text/plain', - '', - '%s', - '--x--']) - partlen = 200 - len(body) - b = body % ("x" * partlen) - h = [("Content-type", "multipart/form-data; boundary=x"), - ("Content-Length", "%s" % len(b))] - self.getPage('/upload', h, "POST", b) - self.assertBody('Size: %d' % partlen) - - b = body % ("x" * 200) - h = [("Content-type", "multipart/form-data; boundary=x"), - ("Content-Length", "%s" % len(b))] - self.getPage('/upload', h, "POST", b) - self.assertStatus(413) - diff --git a/cherrypy/test/test_conn.py b/cherrypy/test/test_conn.py deleted file mode 100644 index 1346f593..00000000 --- a/cherrypy/test/test_conn.py +++ /dev/null @@ -1,734 +0,0 @@ -"""Tests for TCP connection handling, including proper and timely close.""" - -import socket -import sys -import time -timeout = 1 - - -import cherrypy -from cherrypy._cpcompat import HTTPConnection, HTTPSConnection, NotConnected, BadStatusLine -from cherrypy._cpcompat import ntob, urlopen, unicodestr -from cherrypy.test import webtest -from cherrypy import _cperror - - -pov = 'pPeErRsSiIsStTeEnNcCeE oOfF vViIsSiIoOnN' - -def setup_server(): - - def raise500(): - raise cherrypy.HTTPError(500) - - class Root: - - def index(self): - return pov - index.exposed = True - page1 = index - page2 = index - page3 = index - - def hello(self): - return "Hello, world!" 
- hello.exposed = True - - def timeout(self, t): - return str(cherrypy.server.httpserver.timeout) - timeout.exposed = True - - def stream(self, set_cl=False): - if set_cl: - cherrypy.response.headers['Content-Length'] = 10 - - def content(): - for x in range(10): - yield str(x) - - return content() - stream.exposed = True - stream._cp_config = {'response.stream': True} - - def error(self, code=500): - raise cherrypy.HTTPError(code) - error.exposed = True - - def upload(self): - if not cherrypy.request.method == 'POST': - raise AssertionError("'POST' != request.method %r" % - cherrypy.request.method) - return "thanks for '%s'" % cherrypy.request.body.read() - upload.exposed = True - - def custom(self, response_code): - cherrypy.response.status = response_code - return "Code = %s" % response_code - custom.exposed = True - - def err_before_read(self): - return "ok" - err_before_read.exposed = True - err_before_read._cp_config = {'hooks.on_start_resource': raise500} - - def one_megabyte_of_a(self): - return ["a" * 1024] * 1024 - one_megabyte_of_a.exposed = True - - def custom_cl(self, body, cl): - cherrypy.response.headers['Content-Length'] = cl - if not isinstance(body, list): - body = [body] - newbody = [] - for chunk in body: - if isinstance(chunk, unicodestr): - chunk = chunk.encode('ISO-8859-1') - newbody.append(chunk) - return newbody - custom_cl.exposed = True - # Turn off the encoding tool so it doens't collapse - # our response body and reclaculate the Content-Length. 
- custom_cl._cp_config = {'tools.encode.on': False} - - cherrypy.tree.mount(Root()) - cherrypy.config.update({ - 'server.max_request_body_size': 1001, - 'server.socket_timeout': timeout, - }) - - -from cherrypy.test import helper - -class ConnectionCloseTests(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def test_HTTP11(self): - if cherrypy.server.protocol_version != "HTTP/1.1": - return self.skip() - - self.PROTOCOL = "HTTP/1.1" - - self.persistent = True - - # Make the first request and assert there's no "Connection: close". - self.getPage("/") - self.assertStatus('200 OK') - self.assertBody(pov) - self.assertNoHeader("Connection") - - # Make another request on the same connection. - self.getPage("/page1") - self.assertStatus('200 OK') - self.assertBody(pov) - self.assertNoHeader("Connection") - - # Test client-side close. - self.getPage("/page2", headers=[("Connection", "close")]) - self.assertStatus('200 OK') - self.assertBody(pov) - self.assertHeader("Connection", "close") - - # Make another request on the same connection, which should error. - self.assertRaises(NotConnected, self.getPage, "/") - - def test_Streaming_no_len(self): - self._streaming(set_cl=False) - - def test_Streaming_with_len(self): - self._streaming(set_cl=True) - - def _streaming(self, set_cl): - if cherrypy.server.protocol_version == "HTTP/1.1": - self.PROTOCOL = "HTTP/1.1" - - self.persistent = True - - # Make the first request and assert there's no "Connection: close". - self.getPage("/") - self.assertStatus('200 OK') - self.assertBody(pov) - self.assertNoHeader("Connection") - - # Make another, streamed request on the same connection. - if set_cl: - # When a Content-Length is provided, the content should stream - # without closing the connection. 
- self.getPage("/stream?set_cl=Yes") - self.assertHeader("Content-Length") - self.assertNoHeader("Connection", "close") - self.assertNoHeader("Transfer-Encoding") - - self.assertStatus('200 OK') - self.assertBody('0123456789') - else: - # When no Content-Length response header is provided, - # streamed output will either close the connection, or use - # chunked encoding, to determine transfer-length. - self.getPage("/stream") - self.assertNoHeader("Content-Length") - self.assertStatus('200 OK') - self.assertBody('0123456789') - - chunked_response = False - for k, v in self.headers: - if k.lower() == "transfer-encoding": - if str(v) == "chunked": - chunked_response = True - - if chunked_response: - self.assertNoHeader("Connection", "close") - else: - self.assertHeader("Connection", "close") - - # Make another request on the same connection, which should error. - self.assertRaises(NotConnected, self.getPage, "/") - - # Try HEAD. See http://www.cherrypy.org/ticket/864. - self.getPage("/stream", method='HEAD') - self.assertStatus('200 OK') - self.assertBody('') - self.assertNoHeader("Transfer-Encoding") - else: - self.PROTOCOL = "HTTP/1.0" - - self.persistent = True - - # Make the first request and assert Keep-Alive. - self.getPage("/", headers=[("Connection", "Keep-Alive")]) - self.assertStatus('200 OK') - self.assertBody(pov) - self.assertHeader("Connection", "Keep-Alive") - - # Make another, streamed request on the same connection. - if set_cl: - # When a Content-Length is provided, the content should - # stream without closing the connection. - self.getPage("/stream?set_cl=Yes", - headers=[("Connection", "Keep-Alive")]) - self.assertHeader("Content-Length") - self.assertHeader("Connection", "Keep-Alive") - self.assertNoHeader("Transfer-Encoding") - self.assertStatus('200 OK') - self.assertBody('0123456789') - else: - # When a Content-Length is not provided, - # the server should close the connection. 
- self.getPage("/stream", headers=[("Connection", "Keep-Alive")]) - self.assertStatus('200 OK') - self.assertBody('0123456789') - - self.assertNoHeader("Content-Length") - self.assertNoHeader("Connection", "Keep-Alive") - self.assertNoHeader("Transfer-Encoding") - - # Make another request on the same connection, which should error. - self.assertRaises(NotConnected, self.getPage, "/") - - def test_HTTP10_KeepAlive(self): - self.PROTOCOL = "HTTP/1.0" - if self.scheme == "https": - self.HTTP_CONN = HTTPSConnection - else: - self.HTTP_CONN = HTTPConnection - - # Test a normal HTTP/1.0 request. - self.getPage("/page2") - self.assertStatus('200 OK') - self.assertBody(pov) - # Apache, for example, may emit a Connection header even for HTTP/1.0 -## self.assertNoHeader("Connection") - - # Test a keep-alive HTTP/1.0 request. - self.persistent = True - - self.getPage("/page3", headers=[("Connection", "Keep-Alive")]) - self.assertStatus('200 OK') - self.assertBody(pov) - self.assertHeader("Connection", "Keep-Alive") - - # Remove the keep-alive header again. - self.getPage("/page3") - self.assertStatus('200 OK') - self.assertBody(pov) - # Apache, for example, may emit a Connection header even for HTTP/1.0 -## self.assertNoHeader("Connection") - - -class PipelineTests(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def test_HTTP11_Timeout(self): - # If we timeout without sending any data, - # the server will close the conn with a 408. - if cherrypy.server.protocol_version != "HTTP/1.1": - return self.skip() - - self.PROTOCOL = "HTTP/1.1" - - # Connect but send nothing. - self.persistent = True - conn = self.HTTP_CONN - conn.auto_open = False - conn.connect() - - # Wait for our socket timeout - time.sleep(timeout * 2) - - # The request should have returned 408 already. - response = conn.response_class(conn.sock, method="GET") - response.begin() - self.assertEqual(response.status, 408) - conn.close() - - # Connect but send half the headers only. 
- self.persistent = True - conn = self.HTTP_CONN - conn.auto_open = False - conn.connect() - conn.send(ntob('GET /hello HTTP/1.1')) - conn.send(("Host: %s" % self.HOST).encode('ascii')) - - # Wait for our socket timeout - time.sleep(timeout * 2) - - # The conn should have already sent 408. - response = conn.response_class(conn.sock, method="GET") - response.begin() - self.assertEqual(response.status, 408) - conn.close() - - def test_HTTP11_Timeout_after_request(self): - # If we timeout after at least one request has succeeded, - # the server will close the conn without 408. - if cherrypy.server.protocol_version != "HTTP/1.1": - return self.skip() - - self.PROTOCOL = "HTTP/1.1" - - # Make an initial request - self.persistent = True - conn = self.HTTP_CONN - conn.putrequest("GET", "/timeout?t=%s" % timeout, skip_host=True) - conn.putheader("Host", self.HOST) - conn.endheaders() - response = conn.response_class(conn.sock, method="GET") - response.begin() - self.assertEqual(response.status, 200) - self.body = response.read() - self.assertBody(str(timeout)) - - # Make a second request on the same socket - conn._output(ntob('GET /hello HTTP/1.1')) - conn._output(ntob("Host: %s" % self.HOST, 'ascii')) - conn._send_output() - response = conn.response_class(conn.sock, method="GET") - response.begin() - self.assertEqual(response.status, 200) - self.body = response.read() - self.assertBody("Hello, world!") - - # Wait for our socket timeout - time.sleep(timeout * 2) - - # Make another request on the same socket, which should error - conn._output(ntob('GET /hello HTTP/1.1')) - conn._output(ntob("Host: %s" % self.HOST, 'ascii')) - conn._send_output() - response = conn.response_class(conn.sock, method="GET") - try: - response.begin() - except: - if not isinstance(sys.exc_info()[1], - (socket.error, BadStatusLine)): - self.fail("Writing to timed out socket didn't fail" - " as it should have: %s" % sys.exc_info()[1]) - else: - if response.status != 408: - self.fail("Writing to 
timed out socket didn't fail" - " as it should have: %s" % - response.read()) - - conn.close() - - # Make another request on a new socket, which should work - self.persistent = True - conn = self.HTTP_CONN - conn.putrequest("GET", "/", skip_host=True) - conn.putheader("Host", self.HOST) - conn.endheaders() - response = conn.response_class(conn.sock, method="GET") - response.begin() - self.assertEqual(response.status, 200) - self.body = response.read() - self.assertBody(pov) - - - # Make another request on the same socket, - # but timeout on the headers - conn.send(ntob('GET /hello HTTP/1.1')) - # Wait for our socket timeout - time.sleep(timeout * 2) - response = conn.response_class(conn.sock, method="GET") - try: - response.begin() - except: - if not isinstance(sys.exc_info()[1], - (socket.error, BadStatusLine)): - self.fail("Writing to timed out socket didn't fail" - " as it should have: %s" % sys.exc_info()[1]) - else: - self.fail("Writing to timed out socket didn't fail" - " as it should have: %s" % - response.read()) - - conn.close() - - # Retry the request on a new connection, which should work - self.persistent = True - conn = self.HTTP_CONN - conn.putrequest("GET", "/", skip_host=True) - conn.putheader("Host", self.HOST) - conn.endheaders() - response = conn.response_class(conn.sock, method="GET") - response.begin() - self.assertEqual(response.status, 200) - self.body = response.read() - self.assertBody(pov) - conn.close() - - def test_HTTP11_pipelining(self): - if cherrypy.server.protocol_version != "HTTP/1.1": - return self.skip() - - self.PROTOCOL = "HTTP/1.1" - - # Test pipelining. httplib doesn't support this directly. 
- self.persistent = True - conn = self.HTTP_CONN - - # Put request 1 - conn.putrequest("GET", "/hello", skip_host=True) - conn.putheader("Host", self.HOST) - conn.endheaders() - - for trial in range(5): - # Put next request - conn._output(ntob('GET /hello HTTP/1.1')) - conn._output(ntob("Host: %s" % self.HOST, 'ascii')) - conn._send_output() - - # Retrieve previous response - response = conn.response_class(conn.sock, method="GET") - response.begin() - body = response.read(13) - self.assertEqual(response.status, 200) - self.assertEqual(body, ntob("Hello, world!")) - - # Retrieve final response - response = conn.response_class(conn.sock, method="GET") - response.begin() - body = response.read() - self.assertEqual(response.status, 200) - self.assertEqual(body, ntob("Hello, world!")) - - conn.close() - - def test_100_Continue(self): - if cherrypy.server.protocol_version != "HTTP/1.1": - return self.skip() - - self.PROTOCOL = "HTTP/1.1" - - self.persistent = True - conn = self.HTTP_CONN - - # Try a page without an Expect request header first. - # Note that httplib's response.begin automatically ignores - # 100 Continue responses, so we must manually check for it. - conn.putrequest("POST", "/upload", skip_host=True) - conn.putheader("Host", self.HOST) - conn.putheader("Content-Type", "text/plain") - conn.putheader("Content-Length", "4") - conn.endheaders() - conn.send(ntob("d'oh")) - response = conn.response_class(conn.sock, method="POST") - version, status, reason = response._read_status() - self.assertNotEqual(status, 100) - conn.close() - - # Now try a page with an Expect header... 
- conn.connect() - conn.putrequest("POST", "/upload", skip_host=True) - conn.putheader("Host", self.HOST) - conn.putheader("Content-Type", "text/plain") - conn.putheader("Content-Length", "17") - conn.putheader("Expect", "100-continue") - conn.endheaders() - response = conn.response_class(conn.sock, method="POST") - - # ...assert and then skip the 100 response - version, status, reason = response._read_status() - self.assertEqual(status, 100) - while True: - line = response.fp.readline().strip() - if line: - self.fail("100 Continue should not output any headers. Got %r" % line) - else: - break - - # ...send the body - body = ntob("I am a small file") - conn.send(body) - - # ...get the final response - response.begin() - self.status, self.headers, self.body = webtest.shb(response) - self.assertStatus(200) - self.assertBody("thanks for '%s'" % body) - conn.close() - - -class ConnectionTests(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def test_readall_or_close(self): - if cherrypy.server.protocol_version != "HTTP/1.1": - return self.skip() - - self.PROTOCOL = "HTTP/1.1" - - if self.scheme == "https": - self.HTTP_CONN = HTTPSConnection - else: - self.HTTP_CONN = HTTPConnection - - # Test a max of 0 (the default) and then reset to what it was above. 
- old_max = cherrypy.server.max_request_body_size - for new_max in (0, old_max): - cherrypy.server.max_request_body_size = new_max - - self.persistent = True - conn = self.HTTP_CONN - - # Get a POST page with an error - conn.putrequest("POST", "/err_before_read", skip_host=True) - conn.putheader("Host", self.HOST) - conn.putheader("Content-Type", "text/plain") - conn.putheader("Content-Length", "1000") - conn.putheader("Expect", "100-continue") - conn.endheaders() - response = conn.response_class(conn.sock, method="POST") - - # ...assert and then skip the 100 response - version, status, reason = response._read_status() - self.assertEqual(status, 100) - while True: - skip = response.fp.readline().strip() - if not skip: - break - - # ...send the body - conn.send(ntob("x" * 1000)) - - # ...get the final response - response.begin() - self.status, self.headers, self.body = webtest.shb(response) - self.assertStatus(500) - - # Now try a working page with an Expect header... - conn._output(ntob('POST /upload HTTP/1.1')) - conn._output(ntob("Host: %s" % self.HOST, 'ascii')) - conn._output(ntob("Content-Type: text/plain")) - conn._output(ntob("Content-Length: 17")) - conn._output(ntob("Expect: 100-continue")) - conn._send_output() - response = conn.response_class(conn.sock, method="POST") - - # ...assert and then skip the 100 response - version, status, reason = response._read_status() - self.assertEqual(status, 100) - while True: - skip = response.fp.readline().strip() - if not skip: - break - - # ...send the body - body = ntob("I am a small file") - conn.send(body) - - # ...get the final response - response.begin() - self.status, self.headers, self.body = webtest.shb(response) - self.assertStatus(200) - self.assertBody("thanks for '%s'" % body) - conn.close() - - def test_No_Message_Body(self): - if cherrypy.server.protocol_version != "HTTP/1.1": - return self.skip() - - self.PROTOCOL = "HTTP/1.1" - - # Set our HTTP_CONN to an instance so it persists between requests. 
- self.persistent = True - - # Make the first request and assert there's no "Connection: close". - self.getPage("/") - self.assertStatus('200 OK') - self.assertBody(pov) - self.assertNoHeader("Connection") - - # Make a 204 request on the same connection. - self.getPage("/custom/204") - self.assertStatus(204) - self.assertNoHeader("Content-Length") - self.assertBody("") - self.assertNoHeader("Connection") - - # Make a 304 request on the same connection. - self.getPage("/custom/304") - self.assertStatus(304) - self.assertNoHeader("Content-Length") - self.assertBody("") - self.assertNoHeader("Connection") - - def test_Chunked_Encoding(self): - if cherrypy.server.protocol_version != "HTTP/1.1": - return self.skip() - - if (hasattr(self, 'harness') and - "modpython" in self.harness.__class__.__name__.lower()): - # mod_python forbids chunked encoding - return self.skip() - - self.PROTOCOL = "HTTP/1.1" - - # Set our HTTP_CONN to an instance so it persists between requests. - self.persistent = True - conn = self.HTTP_CONN - - # Try a normal chunked request (with extensions) - body = ntob("8;key=value\r\nxx\r\nxxxx\r\n5\r\nyyyyy\r\n0\r\n" - "Content-Type: application/json\r\n" - "\r\n") - conn.putrequest("POST", "/upload", skip_host=True) - conn.putheader("Host", self.HOST) - conn.putheader("Transfer-Encoding", "chunked") - conn.putheader("Trailer", "Content-Type") - # Note that this is somewhat malformed: - # we shouldn't be sending Content-Length. - # RFC 2616 says the server should ignore it. - conn.putheader("Content-Length", "3") - conn.endheaders() - conn.send(body) - response = conn.getresponse() - self.status, self.headers, self.body = webtest.shb(response) - self.assertStatus('200 OK') - self.assertBody("thanks for '%s'" % ntob('xx\r\nxxxxyyyyy')) - - # Try a chunked request that exceeds server.max_request_body_size. - # Note that the delimiters and trailer are included. 
- body = ntob("3e3\r\n" + ("x" * 995) + "\r\n0\r\n\r\n") - conn.putrequest("POST", "/upload", skip_host=True) - conn.putheader("Host", self.HOST) - conn.putheader("Transfer-Encoding", "chunked") - conn.putheader("Content-Type", "text/plain") - # Chunked requests don't need a content-length -## conn.putheader("Content-Length", len(body)) - conn.endheaders() - conn.send(body) - response = conn.getresponse() - self.status, self.headers, self.body = webtest.shb(response) - self.assertStatus(413) - conn.close() - - def test_Content_Length_in(self): - # Try a non-chunked request where Content-Length exceeds - # server.max_request_body_size. Assert error before body send. - self.persistent = True - conn = self.HTTP_CONN - conn.putrequest("POST", "/upload", skip_host=True) - conn.putheader("Host", self.HOST) - conn.putheader("Content-Type", "text/plain") - conn.putheader("Content-Length", "9999") - conn.endheaders() - response = conn.getresponse() - self.status, self.headers, self.body = webtest.shb(response) - self.assertStatus(413) - self.assertBody("The entity sent with the request exceeds " - "the maximum allowed bytes.") - conn.close() - - def test_Content_Length_out_preheaders(self): - # Try a non-chunked response where Content-Length is less than - # the actual bytes in the response body. - self.persistent = True - conn = self.HTTP_CONN - conn.putrequest("GET", "/custom_cl?body=I+have+too+many+bytes&cl=5", - skip_host=True) - conn.putheader("Host", self.HOST) - conn.endheaders() - response = conn.getresponse() - self.status, self.headers, self.body = webtest.shb(response) - self.assertStatus(500) - self.assertBody( - "The requested resource returned more bytes than the " - "declared Content-Length.") - conn.close() - - def test_Content_Length_out_postheaders(self): - # Try a non-chunked response where Content-Length is less than - # the actual bytes in the response body. 
- self.persistent = True - conn = self.HTTP_CONN - conn.putrequest("GET", "/custom_cl?body=I+too&body=+have+too+many&cl=5", - skip_host=True) - conn.putheader("Host", self.HOST) - conn.endheaders() - response = conn.getresponse() - self.status, self.headers, self.body = webtest.shb(response) - self.assertStatus(200) - self.assertBody("I too") - conn.close() - - def test_598(self): - remote_data_conn = urlopen('%s://%s:%s/one_megabyte_of_a/' % - (self.scheme, self.HOST, self.PORT,)) - buf = remote_data_conn.read(512) - time.sleep(timeout * 0.6) - remaining = (1024 * 1024) - 512 - while remaining: - data = remote_data_conn.read(remaining) - if not data: - break - else: - buf += data - remaining -= len(data) - - self.assertEqual(len(buf), 1024 * 1024) - self.assertEqual(buf, ntob("a" * 1024 * 1024)) - self.assertEqual(remaining, 0) - remote_data_conn.close() - - -class BadRequestTests(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def test_No_CRLF(self): - self.persistent = True - - conn = self.HTTP_CONN - conn.send(ntob('GET /hello HTTP/1.1\n\n')) - response = conn.response_class(conn.sock, method="GET") - response.begin() - self.body = response.read() - self.assertBody("HTTP requires CRLF terminators") - conn.close() - - conn.connect() - conn.send(ntob('GET /hello HTTP/1.1\r\n\n')) - response = conn.response_class(conn.sock, method="GET") - response.begin() - self.body = response.read() - self.assertBody("HTTP requires CRLF terminators") - conn.close() - diff --git a/cherrypy/test/test_core.py b/cherrypy/test/test_core.py deleted file mode 100644 index 09544e34..00000000 --- a/cherrypy/test/test_core.py +++ /dev/null @@ -1,617 +0,0 @@ -"""Basic tests for the CherryPy core: request handling.""" - -import os -localDir = os.path.dirname(__file__) -import sys -import types - -import cherrypy -from cherrypy._cpcompat import IncompleteRead, itervalues, ntob -from cherrypy import _cptools, tools -from cherrypy.lib import httputil, static - - 
-favicon_path = os.path.join(os.getcwd(), localDir, "../favicon.ico") - -# Client-side code # - -from cherrypy.test import helper - -class CoreRequestHandlingTest(helper.CPWebCase): - - def setup_server(): - class Root: - - def index(self): - return "hello" - index.exposed = True - - favicon_ico = tools.staticfile.handler(filename=favicon_path) - - def defct(self, newct): - newct = "text/%s" % newct - cherrypy.config.update({'tools.response_headers.on': True, - 'tools.response_headers.headers': - [('Content-Type', newct)]}) - defct.exposed = True - - def baseurl(self, path_info, relative=None): - return cherrypy.url(path_info, relative=bool(relative)) - baseurl.exposed = True - - root = Root() - - if sys.version_info >= (2, 5): - from cherrypy.test._test_decorators import ExposeExamples - root.expose_dec = ExposeExamples() - - - class TestType(type): - """Metaclass which automatically exposes all functions in each subclass, - and adds an instance of the subclass as an attribute of root. - """ - def __init__(cls, name, bases, dct): - type.__init__(cls, name, bases, dct) - for value in itervalues(dct): - if isinstance(value, types.FunctionType): - value.exposed = True - setattr(root, name.lower(), cls()) - class Test(object): - __metaclass__ = TestType - - - class URL(Test): - - _cp_config = {'tools.trailing_slash.on': False} - - def index(self, path_info, relative=None): - if relative != 'server': - relative = bool(relative) - return cherrypy.url(path_info, relative=relative) - - def leaf(self, path_info, relative=None): - if relative != 'server': - relative = bool(relative) - return cherrypy.url(path_info, relative=relative) - - - class Status(Test): - - def index(self): - return "normal" - - def blank(self): - cherrypy.response.status = "" - - # According to RFC 2616, new status codes are OK as long as they - # are between 100 and 599. - - # Here is an illegal code... 
- def illegal(self): - cherrypy.response.status = 781 - return "oops" - - # ...and here is an unknown but legal code. - def unknown(self): - cherrypy.response.status = "431 My custom error" - return "funky" - - # Non-numeric code - def bad(self): - cherrypy.response.status = "error" - return "bad news" - - - class Redirect(Test): - - class Error: - _cp_config = {"tools.err_redirect.on": True, - "tools.err_redirect.url": "/errpage", - "tools.err_redirect.internal": False, - } - - def index(self): - raise NameError("redirect_test") - index.exposed = True - error = Error() - - def index(self): - return "child" - - def custom(self, url, code): - raise cherrypy.HTTPRedirect(url, code) - - def by_code(self, code): - raise cherrypy.HTTPRedirect("somewhere%20else", code) - by_code._cp_config = {'tools.trailing_slash.extra': True} - - def nomodify(self): - raise cherrypy.HTTPRedirect("", 304) - - def proxy(self): - raise cherrypy.HTTPRedirect("proxy", 305) - - def stringify(self): - return str(cherrypy.HTTPRedirect("/")) - - def fragment(self, frag): - raise cherrypy.HTTPRedirect("/some/url#%s" % frag) - - def login_redir(): - if not getattr(cherrypy.request, "login", None): - raise cherrypy.InternalRedirect("/internalredirect/login") - tools.login_redir = _cptools.Tool('before_handler', login_redir) - - def redir_custom(): - raise cherrypy.InternalRedirect("/internalredirect/custom_err") - - class InternalRedirect(Test): - - def index(self): - raise cherrypy.InternalRedirect("/") - - def choke(self): - return 3 / 0 - choke.exposed = True - choke._cp_config = {'hooks.before_error_response': redir_custom} - - def relative(self, a, b): - raise cherrypy.InternalRedirect("cousin?t=6") - - def cousin(self, t): - assert cherrypy.request.prev.closed - return cherrypy.request.prev.query_string - - def petshop(self, user_id): - if user_id == "parrot": - # Trade it for a slug when redirecting - raise cherrypy.InternalRedirect('/image/getImagesByUser?user_id=slug') - elif user_id == 
"terrier": - # Trade it for a fish when redirecting - raise cherrypy.InternalRedirect('/image/getImagesByUser?user_id=fish') - else: - # This should pass the user_id through to getImagesByUser - raise cherrypy.InternalRedirect( - '/image/getImagesByUser?user_id=%s' % str(user_id)) - - # We support Python 2.3, but the @-deco syntax would look like this: - # @tools.login_redir() - def secure(self): - return "Welcome!" - secure = tools.login_redir()(secure) - # Since calling the tool returns the same function you pass in, - # you could skip binding the return value, and just write: - # tools.login_redir()(secure) - - def login(self): - return "Please log in" - - def custom_err(self): - return "Something went horribly wrong." - - def early_ir(self, arg): - return "whatever" - early_ir._cp_config = {'hooks.before_request_body': redir_custom} - - - class Image(Test): - - def getImagesByUser(self, user_id): - return "0 images for %s" % user_id - - - class Flatten(Test): - - def as_string(self): - return "content" - - def as_list(self): - return ["con", "tent"] - - def as_yield(self): - yield ntob("content") - - def as_dblyield(self): - yield self.as_yield() - as_dblyield._cp_config = {'tools.flatten.on': True} - - def as_refyield(self): - for chunk in self.as_yield(): - yield chunk - - - class Ranges(Test): - - def get_ranges(self, bytes): - return repr(httputil.get_ranges('bytes=%s' % bytes, 8)) - - def slice_file(self): - path = os.path.join(os.getcwd(), os.path.dirname(__file__)) - return static.serve_file(os.path.join(path, "static/index.html")) - - - class Cookies(Test): - - def single(self, name): - cookie = cherrypy.request.cookie[name] - # Python2's SimpleCookie.__setitem__ won't take unicode keys. - cherrypy.response.cookie[str(name)] = cookie.value - - def multiple(self, names): - for name in names: - cookie = cherrypy.request.cookie[name] - # Python2's SimpleCookie.__setitem__ won't take unicode keys. 
- cherrypy.response.cookie[str(name)] = cookie.value - - - cherrypy.tree.mount(root) - setup_server = staticmethod(setup_server) - - - def testStatus(self): - self.getPage("/status/") - self.assertBody('normal') - self.assertStatus(200) - - self.getPage("/status/blank") - self.assertBody('') - self.assertStatus(200) - - self.getPage("/status/illegal") - self.assertStatus(500) - msg = "Illegal response status from server (781 is out of range)." - self.assertErrorPage(500, msg) - - if not getattr(cherrypy.server, 'using_apache', False): - self.getPage("/status/unknown") - self.assertBody('funky') - self.assertStatus(431) - - self.getPage("/status/bad") - self.assertStatus(500) - msg = "Illegal response status from server ('error' is non-numeric)." - self.assertErrorPage(500, msg) - - def testSlashes(self): - # Test that requests for index methods without a trailing slash - # get redirected to the same URI path with a trailing slash. - # Make sure GET params are preserved. - self.getPage("/redirect?id=3") - self.assertStatus(301) - self.assertInBody("" - "%s/redirect/?id=3" % (self.base(), self.base())) - - if self.prefix(): - # Corner case: the "trailing slash" redirect could be tricky if - # we're using a virtual root and the URI is "/vroot" (no slash). - self.getPage("") - self.assertStatus(301) - self.assertInBody("%s/" % - (self.base(), self.base())) - - # Test that requests for NON-index methods WITH a trailing slash - # get redirected to the same URI path WITHOUT a trailing slash. - # Make sure GET params are preserved. - self.getPage("/redirect/by_code/?code=307") - self.assertStatus(301) - self.assertInBody("" - "%s/redirect/by_code?code=307" - % (self.base(), self.base())) - - # If the trailing_slash tool is off, CP should just continue - # as if the slashes were correct. But it needs some help - # inside cherrypy.url to form correct output. 
- self.getPage('/url?path_info=page1') - self.assertBody('%s/url/page1' % self.base()) - self.getPage('/url/leaf/?path_info=page1') - self.assertBody('%s/url/page1' % self.base()) - - def testRedirect(self): - self.getPage("/redirect/") - self.assertBody('child') - self.assertStatus(200) - - self.getPage("/redirect/by_code?code=300") - self.assertMatchesBody(r"\1somewhere%20else") - self.assertStatus(300) - - self.getPage("/redirect/by_code?code=301") - self.assertMatchesBody(r"\1somewhere%20else") - self.assertStatus(301) - - self.getPage("/redirect/by_code?code=302") - self.assertMatchesBody(r"\1somewhere%20else") - self.assertStatus(302) - - self.getPage("/redirect/by_code?code=303") - self.assertMatchesBody(r"\1somewhere%20else") - self.assertStatus(303) - - self.getPage("/redirect/by_code?code=307") - self.assertMatchesBody(r"\1somewhere%20else") - self.assertStatus(307) - - self.getPage("/redirect/nomodify") - self.assertBody('') - self.assertStatus(304) - - self.getPage("/redirect/proxy") - self.assertBody('') - self.assertStatus(305) - - # HTTPRedirect on error - self.getPage("/redirect/error/") - self.assertStatus(('302 Found', '303 See Other')) - self.assertInBody('/errpage') - - # Make sure str(HTTPRedirect()) works. 
- self.getPage("/redirect/stringify", protocol="HTTP/1.0") - self.assertStatus(200) - self.assertBody("(['%s/'], 302)" % self.base()) - if cherrypy.server.protocol_version == "HTTP/1.1": - self.getPage("/redirect/stringify", protocol="HTTP/1.1") - self.assertStatus(200) - self.assertBody("(['%s/'], 303)" % self.base()) - - # check that #fragments are handled properly - # http://skrb.org/ietf/http_errata.html#location-fragments - frag = "foo" - self.getPage("/redirect/fragment/%s" % frag) - self.assertMatchesBody(r"\1\/some\/url\#%s" % (frag, frag)) - loc = self.assertHeader('Location') - assert loc.endswith("#%s" % frag) - self.assertStatus(('302 Found', '303 See Other')) - - # check injection protection - # See http://www.cherrypy.org/ticket/1003 - self.getPage("/redirect/custom?code=303&url=/foobar/%0d%0aSet-Cookie:%20somecookie=someval") - self.assertStatus(303) - loc = self.assertHeader('Location') - assert 'Set-Cookie' in loc - self.assertNoHeader('Set-Cookie') - - def test_InternalRedirect(self): - # InternalRedirect - self.getPage("/internalredirect/") - self.assertBody('hello') - self.assertStatus(200) - - # Test passthrough - self.getPage("/internalredirect/petshop?user_id=Sir-not-appearing-in-this-film") - self.assertBody('0 images for Sir-not-appearing-in-this-film') - self.assertStatus(200) - - # Test args - self.getPage("/internalredirect/petshop?user_id=parrot") - self.assertBody('0 images for slug') - self.assertStatus(200) - - # Test POST - self.getPage("/internalredirect/petshop", method="POST", - body="user_id=terrier") - self.assertBody('0 images for fish') - self.assertStatus(200) - - # Test ir before body read - self.getPage("/internalredirect/early_ir", method="POST", - body="arg=aha!") - self.assertBody("Something went horribly wrong.") - self.assertStatus(200) - - self.getPage("/internalredirect/secure") - self.assertBody('Please log in') - self.assertStatus(200) - - # Relative path in InternalRedirect. - # Also tests request.prev. 
- self.getPage("/internalredirect/relative?a=3&b=5") - self.assertBody("a=3&b=5") - self.assertStatus(200) - - # InternalRedirect on error - self.getPage("/internalredirect/choke") - self.assertStatus(200) - self.assertBody("Something went horribly wrong.") - - def testFlatten(self): - for url in ["/flatten/as_string", "/flatten/as_list", - "/flatten/as_yield", "/flatten/as_dblyield", - "/flatten/as_refyield"]: - self.getPage(url) - self.assertBody('content') - - def testRanges(self): - self.getPage("/ranges/get_ranges?bytes=3-6") - self.assertBody("[(3, 7)]") - - # Test multiple ranges and a suffix-byte-range-spec, for good measure. - self.getPage("/ranges/get_ranges?bytes=2-4,-1") - self.assertBody("[(2, 5), (7, 8)]") - - # Get a partial file. - if cherrypy.server.protocol_version == "HTTP/1.1": - self.getPage("/ranges/slice_file", [('Range', 'bytes=2-5')]) - self.assertStatus(206) - self.assertHeader("Content-Type", "text/html;charset=utf-8") - self.assertHeader("Content-Range", "bytes 2-5/14") - self.assertBody("llo,") - - # What happens with overlapping ranges (and out of order, too)? 
- self.getPage("/ranges/slice_file", [('Range', 'bytes=4-6,2-5')]) - self.assertStatus(206) - ct = self.assertHeader("Content-Type") - expected_type = "multipart/byteranges; boundary=" - self.assert_(ct.startswith(expected_type)) - boundary = ct[len(expected_type):] - expected_body = ("\r\n--%s\r\n" - "Content-type: text/html\r\n" - "Content-range: bytes 4-6/14\r\n" - "\r\n" - "o, \r\n" - "--%s\r\n" - "Content-type: text/html\r\n" - "Content-range: bytes 2-5/14\r\n" - "\r\n" - "llo,\r\n" - "--%s--\r\n" % (boundary, boundary, boundary)) - self.assertBody(expected_body) - self.assertHeader("Content-Length") - - # Test "416 Requested Range Not Satisfiable" - self.getPage("/ranges/slice_file", [('Range', 'bytes=2300-2900')]) - self.assertStatus(416) - # "When this status code is returned for a byte-range request, - # the response SHOULD include a Content-Range entity-header - # field specifying the current length of the selected resource" - self.assertHeader("Content-Range", "bytes */14") - elif cherrypy.server.protocol_version == "HTTP/1.0": - # Test Range behavior with HTTP/1.0 request - self.getPage("/ranges/slice_file", [('Range', 'bytes=2-5')]) - self.assertStatus(200) - self.assertBody("Hello, world\r\n") - - def testFavicon(self): - # favicon.ico is served by staticfile. 
- icofilename = os.path.join(localDir, "../favicon.ico") - icofile = open(icofilename, "rb") - data = icofile.read() - icofile.close() - - self.getPage("/favicon.ico") - self.assertBody(data) - - def testCookies(self): - if sys.version_info >= (2, 5): - header_value = lambda x: x - else: - header_value = lambda x: x+';' - - self.getPage("/cookies/single?name=First", - [('Cookie', 'First=Dinsdale;')]) - self.assertHeader('Set-Cookie', header_value('First=Dinsdale')) - - self.getPage("/cookies/multiple?names=First&names=Last", - [('Cookie', 'First=Dinsdale; Last=Piranha;'), - ]) - self.assertHeader('Set-Cookie', header_value('First=Dinsdale')) - self.assertHeader('Set-Cookie', header_value('Last=Piranha')) - - self.getPage("/cookies/single?name=Something-With:Colon", - [('Cookie', 'Something-With:Colon=some-value')]) - self.assertStatus(400) - - def testDefaultContentType(self): - self.getPage('/') - self.assertHeader('Content-Type', 'text/html;charset=utf-8') - self.getPage('/defct/plain') - self.getPage('/') - self.assertHeader('Content-Type', 'text/plain;charset=utf-8') - self.getPage('/defct/html') - - def test_cherrypy_url(self): - # Input relative to current - self.getPage('/url/leaf?path_info=page1') - self.assertBody('%s/url/page1' % self.base()) - self.getPage('/url/?path_info=page1') - self.assertBody('%s/url/page1' % self.base()) - # Other host header - host = 'www.mydomain.example' - self.getPage('/url/leaf?path_info=page1', - headers=[('Host', host)]) - self.assertBody('%s://%s/url/page1' % (self.scheme, host)) - - # Input is 'absolute'; that is, relative to script_name - self.getPage('/url/leaf?path_info=/page1') - self.assertBody('%s/page1' % self.base()) - self.getPage('/url/?path_info=/page1') - self.assertBody('%s/page1' % self.base()) - - # Single dots - self.getPage('/url/leaf?path_info=./page1') - self.assertBody('%s/url/page1' % self.base()) - self.getPage('/url/leaf?path_info=other/./page1') - self.assertBody('%s/url/other/page1' % self.base()) 
- self.getPage('/url/?path_info=/other/./page1') - self.assertBody('%s/other/page1' % self.base()) - - # Double dots - self.getPage('/url/leaf?path_info=../page1') - self.assertBody('%s/page1' % self.base()) - self.getPage('/url/leaf?path_info=other/../page1') - self.assertBody('%s/url/page1' % self.base()) - self.getPage('/url/leaf?path_info=/other/../page1') - self.assertBody('%s/page1' % self.base()) - - # Output relative to current path or script_name - self.getPage('/url/?path_info=page1&relative=True') - self.assertBody('page1') - self.getPage('/url/leaf?path_info=/page1&relative=True') - self.assertBody('../page1') - self.getPage('/url/leaf?path_info=page1&relative=True') - self.assertBody('page1') - self.getPage('/url/leaf?path_info=leaf/page1&relative=True') - self.assertBody('leaf/page1') - self.getPage('/url/leaf?path_info=../page1&relative=True') - self.assertBody('../page1') - self.getPage('/url/?path_info=other/../page1&relative=True') - self.assertBody('page1') - - # Output relative to / - self.getPage('/baseurl?path_info=ab&relative=True') - self.assertBody('ab') - # Output relative to / - self.getPage('/baseurl?path_info=/ab&relative=True') - self.assertBody('ab') - - # absolute-path references ("server-relative") - # Input relative to current - self.getPage('/url/leaf?path_info=page1&relative=server') - self.assertBody('/url/page1') - self.getPage('/url/?path_info=page1&relative=server') - self.assertBody('/url/page1') - # Input is 'absolute'; that is, relative to script_name - self.getPage('/url/leaf?path_info=/page1&relative=server') - self.assertBody('/page1') - self.getPage('/url/?path_info=/page1&relative=server') - self.assertBody('/page1') - - def test_expose_decorator(self): - if not sys.version_info >= (2, 5): - return self.skip("skipped (Python 2.5+ only) ") - - # Test @expose - self.getPage("/expose_dec/no_call") - self.assertStatus(200) - self.assertBody("Mr E. R. 
Bradshaw") - - # Test @expose() - self.getPage("/expose_dec/call_empty") - self.assertStatus(200) - self.assertBody("Mrs. B.J. Smegma") - - # Test @expose("alias") - self.getPage("/expose_dec/call_alias") - self.assertStatus(200) - self.assertBody("Mr Nesbitt") - # Does the original name work? - self.getPage("/expose_dec/nesbitt") - self.assertStatus(200) - self.assertBody("Mr Nesbitt") - - # Test @expose(["alias1", "alias2"]) - self.getPage("/expose_dec/alias1") - self.assertStatus(200) - self.assertBody("Mr Ken Andrews") - self.getPage("/expose_dec/alias2") - self.assertStatus(200) - self.assertBody("Mr Ken Andrews") - # Does the original name work? - self.getPage("/expose_dec/andrews") - self.assertStatus(200) - self.assertBody("Mr Ken Andrews") - - # Test @expose(alias="alias") - self.getPage("/expose_dec/alias3") - self.assertStatus(200) - self.assertBody("Mr. and Mrs. Watson") - diff --git a/cherrypy/test/test_dynamicobjectmapping.py b/cherrypy/test/test_dynamicobjectmapping.py deleted file mode 100644 index 1e04d089..00000000 --- a/cherrypy/test/test_dynamicobjectmapping.py +++ /dev/null @@ -1,403 +0,0 @@ -import cherrypy -from cherrypy._cptree import Application -from cherrypy.test import helper - -script_names = ["", "/foo", "/users/fred/blog", "/corp/blog"] - - - -def setup_server(): - class SubSubRoot: - def index(self): - return "SubSubRoot index" - index.exposed = True - - def default(self, *args): - return "SubSubRoot default" - default.exposed = True - - def handler(self): - return "SubSubRoot handler" - handler.exposed = True - - def dispatch(self): - return "SubSubRoot dispatch" - dispatch.exposed = True - - subsubnodes = { - '1': SubSubRoot(), - '2': SubSubRoot(), - } - - class SubRoot: - def index(self): - return "SubRoot index" - index.exposed = True - - def default(self, *args): - return "SubRoot %s" % (args,) - default.exposed = True - - def handler(self): - return "SubRoot handler" - handler.exposed = True - - def _cp_dispatch(self, vpath): - 
return subsubnodes.get(vpath[0], None) - - subnodes = { - '1': SubRoot(), - '2': SubRoot(), - } - class Root: - def index(self): - return "index" - index.exposed = True - - def default(self, *args): - return "default %s" % (args,) - default.exposed = True - - def handler(self): - return "handler" - handler.exposed = True - - def _cp_dispatch(self, vpath): - return subnodes.get(vpath[0]) - - #-------------------------------------------------------------------------- - # DynamicNodeAndMethodDispatcher example. - # This example exposes a fairly naive HTTP api - class User(object): - def __init__(self, id, name): - self.id = id - self.name = name - - def __unicode__(self): - return unicode(self.name) - - user_lookup = { - 1: User(1, 'foo'), - 2: User(2, 'bar'), - } - - def make_user(name, id=None): - if not id: - id = max(*user_lookup.keys()) + 1 - user_lookup[id] = User(id, name) - return id - - class UserContainerNode(object): - exposed = True - - def POST(self, name): - """ - Allow the creation of a new Object - """ - return "POST %d" % make_user(name) - - def GET(self): - keys = user_lookup.keys() - keys.sort() - return unicode(keys) - - def dynamic_dispatch(self, vpath): - try: - id = int(vpath[0]) - except (ValueError, IndexError): - return None - return UserInstanceNode(id) - - class UserInstanceNode(object): - exposed = True - def __init__(self, id): - self.id = id - self.user = user_lookup.get(id, None) - - # For all but PUT methods there MUST be a valid user identified - # by self.id - if not self.user and cherrypy.request.method != 'PUT': - raise cherrypy.HTTPError(404) - - def GET(self, *args, **kwargs): - """ - Return the appropriate representation of the instance. - """ - return unicode(self.user) - - def POST(self, name): - """ - Update the fields of the user instance. 
- """ - self.user.name = name - return "POST %d" % self.user.id - - def PUT(self, name): - """ - Create a new user with the specified id, or edit it if it already exists - """ - if self.user: - # Edit the current user - self.user.name = name - return "PUT %d" % self.user.id - else: - # Make a new user with said attributes. - return "PUT %d" % make_user(name, self.id) - - def DELETE(self): - """ - Delete the user specified at the id. - """ - id = self.user.id - del user_lookup[self.user.id] - del self.user - return "DELETE %d" % id - - - class ABHandler: - class CustomDispatch: - def index(self, a, b): - return "custom" - index.exposed = True - - def _cp_dispatch(self, vpath): - """Make sure that if we don't pop anything from vpath, - processing still works. - """ - return self.CustomDispatch() - - def index(self, a, b=None): - body = [ 'a:' + str(a) ] - if b is not None: - body.append(',b:' + str(b)) - return ''.join(body) - index.exposed = True - - def delete(self, a, b): - return 'deleting ' + str(a) + ' and ' + str(b) - delete.exposed = True - - class IndexOnly: - def _cp_dispatch(self, vpath): - """Make sure that popping ALL of vpath still shows the index - handler. 
- """ - while vpath: - vpath.pop() - return self - - def index(self): - return "IndexOnly index" - index.exposed = True - - class DecoratedPopArgs: - """Test _cp_dispatch with @cherrypy.popargs.""" - def index(self): - return "no params" - index.exposed = True - - def hi(self): - return "hi was not interpreted as 'a' param" - hi.exposed = True - DecoratedPopArgs = cherrypy.popargs('a', 'b', handler=ABHandler())(DecoratedPopArgs) - - class NonDecoratedPopArgs: - """Test _cp_dispatch = cherrypy.popargs()""" - - _cp_dispatch = cherrypy.popargs('a') - - def index(self, a): - return "index: " + str(a) - index.exposed = True - - class ParameterizedHandler: - """Special handler created for each request""" - - def __init__(self, a): - self.a = a - - def index(self): - if 'a' in cherrypy.request.params: - raise Exception("Parameterized handler argument ended up in request.params") - return self.a - index.exposed = True - - class ParameterizedPopArgs: - """Test cherrypy.popargs() with a function call handler""" - ParameterizedPopArgs = cherrypy.popargs('a', handler=ParameterizedHandler)(ParameterizedPopArgs) - - Root.decorated = DecoratedPopArgs() - Root.undecorated = NonDecoratedPopArgs() - Root.index_only = IndexOnly() - Root.parameter_test = ParameterizedPopArgs() - - Root.users = UserContainerNode() - - md = cherrypy.dispatch.MethodDispatcher('dynamic_dispatch') - for url in script_names: - conf = {'/': { - 'user': (url or "/").split("/")[-2], - }, - '/users': { - 'request.dispatch': md - }, - } - cherrypy.tree.mount(Root(), url, conf) - -class DynamicObjectMappingTest(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def testObjectMapping(self): - for url in script_names: - prefix = self.script_name = url - - self.getPage('/') - self.assertBody('index') - - self.getPage('/handler') - self.assertBody('handler') - - # Dynamic dispatch will succeed here for the subnodes - # so the subroot gets called - self.getPage('/1/') - self.assertBody('SubRoot index') 
- - self.getPage('/2/') - self.assertBody('SubRoot index') - - self.getPage('/1/handler') - self.assertBody('SubRoot handler') - - self.getPage('/2/handler') - self.assertBody('SubRoot handler') - - # Dynamic dispatch will fail here for the subnodes - # so the default gets called - self.getPage('/asdf/') - self.assertBody("default ('asdf',)") - - self.getPage('/asdf/asdf') - self.assertBody("default ('asdf', 'asdf')") - - self.getPage('/asdf/handler') - self.assertBody("default ('asdf', 'handler')") - - # Dynamic dispatch will succeed here for the subsubnodes - # so the subsubroot gets called - self.getPage('/1/1/') - self.assertBody('SubSubRoot index') - - self.getPage('/2/2/') - self.assertBody('SubSubRoot index') - - self.getPage('/1/1/handler') - self.assertBody('SubSubRoot handler') - - self.getPage('/2/2/handler') - self.assertBody('SubSubRoot handler') - - self.getPage('/2/2/dispatch') - self.assertBody('SubSubRoot dispatch') - - # The exposed dispatch will not be called as a dispatch - # method. - self.getPage('/2/2/foo/foo') - self.assertBody("SubSubRoot default") - - # Dynamic dispatch will fail here for the subsubnodes - # so the SubRoot gets called - self.getPage('/1/asdf/') - self.assertBody("SubRoot ('asdf',)") - - self.getPage('/1/asdf/asdf') - self.assertBody("SubRoot ('asdf', 'asdf')") - - self.getPage('/1/asdf/handler') - self.assertBody("SubRoot ('asdf', 'handler')") - - def testMethodDispatch(self): - # GET acts like a container - self.getPage("/users") - self.assertBody("[1, 2]") - self.assertHeader('Allow', 'GET, HEAD, POST') - - # POST to the container URI allows creation - self.getPage("/users", method="POST", body="name=baz") - self.assertBody("POST 3") - self.assertHeader('Allow', 'GET, HEAD, POST') - - # POST to a specific instanct URI results in a 404 - # as the resource does not exit. 
- self.getPage("/users/5", method="POST", body="name=baz") - self.assertStatus(404) - - # PUT to a specific instanct URI results in creation - self.getPage("/users/5", method="PUT", body="name=boris") - self.assertBody("PUT 5") - self.assertHeader('Allow', 'DELETE, GET, HEAD, POST, PUT') - - # GET acts like a container - self.getPage("/users") - self.assertBody("[1, 2, 3, 5]") - self.assertHeader('Allow', 'GET, HEAD, POST') - - test_cases = ( - (1, 'foo', 'fooupdated', 'DELETE, GET, HEAD, POST, PUT'), - (2, 'bar', 'barupdated', 'DELETE, GET, HEAD, POST, PUT'), - (3, 'baz', 'bazupdated', 'DELETE, GET, HEAD, POST, PUT'), - (5, 'boris', 'borisupdated', 'DELETE, GET, HEAD, POST, PUT'), - ) - for id, name, updatedname, headers in test_cases: - self.getPage("/users/%d" % id) - self.assertBody(name) - self.assertHeader('Allow', headers) - - # Make sure POSTs update already existings resources - self.getPage("/users/%d" % id, method='POST', body="name=%s" % updatedname) - self.assertBody("POST %d" % id) - self.assertHeader('Allow', headers) - - # Make sure PUTs Update already existing resources. - self.getPage("/users/%d" % id, method='PUT', body="name=%s" % updatedname) - self.assertBody("PUT %d" % id) - self.assertHeader('Allow', headers) - - # Make sure DELETES Remove already existing resources. 
- self.getPage("/users/%d" % id, method='DELETE') - self.assertBody("DELETE %d" % id) - self.assertHeader('Allow', headers) - - - # GET acts like a container - self.getPage("/users") - self.assertBody("[]") - self.assertHeader('Allow', 'GET, HEAD, POST') - - def testVpathDispatch(self): - self.getPage("/decorated/") - self.assertBody("no params") - - self.getPage("/decorated/hi") - self.assertBody("hi was not interpreted as 'a' param") - - self.getPage("/decorated/yo/") - self.assertBody("a:yo") - - self.getPage("/decorated/yo/there/") - self.assertBody("a:yo,b:there") - - self.getPage("/decorated/yo/there/delete") - self.assertBody("deleting yo and there") - - self.getPage("/decorated/yo/there/handled_by_dispatch/") - self.assertBody("custom") - - self.getPage("/undecorated/blah/") - self.assertBody("index: blah") - - self.getPage("/index_only/a/b/c/d/e/f/g/") - self.assertBody("IndexOnly index") - - self.getPage("/parameter_test/argument2/") - self.assertBody("argument2") - diff --git a/cherrypy/test/test_encoding.py b/cherrypy/test/test_encoding.py deleted file mode 100644 index 67b28ede..00000000 --- a/cherrypy/test/test_encoding.py +++ /dev/null @@ -1,363 +0,0 @@ - -import gzip -import sys - -import cherrypy -from cherrypy._cpcompat import BytesIO, IncompleteRead, ntob, ntou - -europoundUnicode = ntou('\x80\xa3') -sing = u"\u6bdb\u6cfd\u4e1c: Sing, Little Birdie?" 
-sing8 = sing.encode('utf-8') -sing16 = sing.encode('utf-16') - - -from cherrypy.test import helper - - -class EncodingTests(helper.CPWebCase): - - def setup_server(): - class Root: - def index(self, param): - assert param == europoundUnicode, "%r != %r" % (param, europoundUnicode) - yield europoundUnicode - index.exposed = True - - def mao_zedong(self): - return sing - mao_zedong.exposed = True - - def utf8(self): - return sing8 - utf8.exposed = True - utf8._cp_config = {'tools.encode.encoding': 'utf-8'} - - def cookies_and_headers(self): - # if the headers have non-ascii characters and a cookie has - # any part which is unicode (even ascii), the response - # should not fail. - cherrypy.response.cookie['candy'] = 'bar' - cherrypy.response.cookie['candy']['domain'] = 'cherrypy.org' - cherrypy.response.headers['Some-Header'] = 'My d\xc3\xb6g has fleas' - return 'Any content' - cookies_and_headers.exposed = True - - def reqparams(self, *args, **kwargs): - return ntob(', ').join([": ".join((k, v)).encode('utf8') - for k, v in cherrypy.request.params.items()]) - reqparams.exposed = True - - def nontext(self, *args, **kwargs): - cherrypy.response.headers['Content-Type'] = 'application/binary' - return '\x00\x01\x02\x03' - nontext.exposed = True - nontext._cp_config = {'tools.encode.text_only': False, - 'tools.encode.add_charset': True, - } - - class GZIP: - def index(self): - yield "Hello, world" - index.exposed = True - - def noshow(self): - # Test for ticket #147, where yield showed no exceptions (content- - # encoding was still gzip even though traceback wasn't zipped). - raise IndexError() - yield "Here be dragons" - noshow.exposed = True - # Turn encoding off so the gzip tool is the one doing the collapse. - noshow._cp_config = {'tools.encode.on': False} - - def noshow_stream(self): - # Test for ticket #147, where yield showed no exceptions (content- - # encoding was still gzip even though traceback wasn't zipped). 
- raise IndexError() - yield "Here be dragons" - noshow_stream.exposed = True - noshow_stream._cp_config = {'response.stream': True} - - class Decode: - def extra_charset(self, *args, **kwargs): - return ', '.join([": ".join((k, v)) - for k, v in cherrypy.request.params.items()]) - extra_charset.exposed = True - extra_charset._cp_config = { - 'tools.decode.on': True, - 'tools.decode.default_encoding': ['utf-16'], - } - - def force_charset(self, *args, **kwargs): - return ', '.join([": ".join((k, v)) - for k, v in cherrypy.request.params.items()]) - force_charset.exposed = True - force_charset._cp_config = { - 'tools.decode.on': True, - 'tools.decode.encoding': 'utf-16', - } - - root = Root() - root.gzip = GZIP() - root.decode = Decode() - cherrypy.tree.mount(root, config={'/gzip': {'tools.gzip.on': True}}) - setup_server = staticmethod(setup_server) - - def test_query_string_decoding(self): - europoundUtf8 = europoundUnicode.encode('utf-8') - self.getPage(ntob('/?param=') + europoundUtf8) - self.assertBody(europoundUtf8) - - # Encoded utf8 query strings MUST be parsed correctly. - # Here, q is the POUND SIGN U+00A3 encoded in utf8 and then %HEX - self.getPage("/reqparams?q=%C2%A3") - # The return value will be encoded as utf8. - self.assertBody(ntob("q: \xc2\xa3")) - - # Query strings that are incorrectly encoded MUST raise 404. - # Here, q is the POUND SIGN U+00A3 encoded in latin1 and then %HEX - self.getPage("/reqparams?q=%A3") - self.assertStatus(404) - self.assertErrorPage(404, - "The given query string could not be processed. Query " - "strings for this resource must be encoded with 'utf8'.") - - def test_urlencoded_decoding(self): - # Test the decoding of an application/x-www-form-urlencoded entity. 
- europoundUtf8 = europoundUnicode.encode('utf-8') - body=ntob("param=") + europoundUtf8 - self.getPage('/', method='POST', - headers=[("Content-Type", "application/x-www-form-urlencoded"), - ("Content-Length", str(len(body))), - ], - body=body), - self.assertBody(europoundUtf8) - - # Encoded utf8 entities MUST be parsed and decoded correctly. - # Here, q is the POUND SIGN U+00A3 encoded in utf8 - body = ntob("q=\xc2\xa3") - self.getPage('/reqparams', method='POST', - headers=[("Content-Type", "application/x-www-form-urlencoded"), - ("Content-Length", str(len(body))), - ], - body=body), - self.assertBody(ntob("q: \xc2\xa3")) - - # ...and in utf16, which is not in the default attempt_charsets list: - body = ntob("\xff\xfeq\x00=\xff\xfe\xa3\x00") - self.getPage('/reqparams', method='POST', - headers=[("Content-Type", "application/x-www-form-urlencoded;charset=utf-16"), - ("Content-Length", str(len(body))), - ], - body=body), - self.assertBody(ntob("q: \xc2\xa3")) - - # Entities that are incorrectly encoded MUST raise 400. - # Here, q is the POUND SIGN U+00A3 encoded in utf16, but - # the Content-Type incorrectly labels it utf-8. - body = ntob("\xff\xfeq\x00=\xff\xfe\xa3\x00") - self.getPage('/reqparams', method='POST', - headers=[("Content-Type", "application/x-www-form-urlencoded;charset=utf-8"), - ("Content-Length", str(len(body))), - ], - body=body), - self.assertStatus(400) - self.assertErrorPage(400, - "The request entity could not be decoded. The following charsets " - "were attempted: ['utf-8']") - - def test_decode_tool(self): - # An extra charset should be tried first, and succeed if it matches. - # Here, we add utf-16 as a charset and pass a utf-16 body. 
- body = ntob("\xff\xfeq\x00=\xff\xfe\xa3\x00") - self.getPage('/decode/extra_charset', method='POST', - headers=[("Content-Type", "application/x-www-form-urlencoded"), - ("Content-Length", str(len(body))), - ], - body=body), - self.assertBody(ntob("q: \xc2\xa3")) - - # An extra charset should be tried first, and continue to other default - # charsets if it doesn't match. - # Here, we add utf-16 as a charset but still pass a utf-8 body. - body = ntob("q=\xc2\xa3") - self.getPage('/decode/extra_charset', method='POST', - headers=[("Content-Type", "application/x-www-form-urlencoded"), - ("Content-Length", str(len(body))), - ], - body=body), - self.assertBody(ntob("q: \xc2\xa3")) - - # An extra charset should error if force is True and it doesn't match. - # Here, we force utf-16 as a charset but still pass a utf-8 body. - body = ntob("q=\xc2\xa3") - self.getPage('/decode/force_charset', method='POST', - headers=[("Content-Type", "application/x-www-form-urlencoded"), - ("Content-Length", str(len(body))), - ], - body=body), - self.assertErrorPage(400, - "The request entity could not be decoded. The following charsets " - "were attempted: ['utf-16']") - - def test_multipart_decoding(self): - # Test the decoding of a multipart entity when the charset (utf16) is - # explicitly given. 
- body=ntob('\r\n'.join(['--X', - 'Content-Type: text/plain;charset=utf-16', - 'Content-Disposition: form-data; name="text"', - '', - '\xff\xfea\x00b\x00\x1c c\x00', - '--X', - 'Content-Type: text/plain;charset=utf-16', - 'Content-Disposition: form-data; name="submit"', - '', - '\xff\xfeC\x00r\x00e\x00a\x00t\x00e\x00', - '--X--'])) - self.getPage('/reqparams', method='POST', - headers=[("Content-Type", "multipart/form-data;boundary=X"), - ("Content-Length", str(len(body))), - ], - body=body), - self.assertBody(ntob("text: ab\xe2\x80\x9cc, submit: Create")) - - def test_multipart_decoding_no_charset(self): - # Test the decoding of a multipart entity when the charset (utf8) is - # NOT explicitly given, but is in the list of charsets to attempt. - body=ntob('\r\n'.join(['--X', - 'Content-Disposition: form-data; name="text"', - '', - '\xe2\x80\x9c', - '--X', - 'Content-Disposition: form-data; name="submit"', - '', - 'Create', - '--X--'])) - self.getPage('/reqparams', method='POST', - headers=[("Content-Type", "multipart/form-data;boundary=X"), - ("Content-Length", str(len(body))), - ], - body=body), - self.assertBody(ntob("text: \xe2\x80\x9c, submit: Create")) - - def test_multipart_decoding_no_successful_charset(self): - # Test the decoding of a multipart entity when the charset (utf16) is - # NOT explicitly given, and is NOT in the list of charsets to attempt. - body=ntob('\r\n'.join(['--X', - 'Content-Disposition: form-data; name="text"', - '', - '\xff\xfea\x00b\x00\x1c c\x00', - '--X', - 'Content-Disposition: form-data; name="submit"', - '', - '\xff\xfeC\x00r\x00e\x00a\x00t\x00e\x00', - '--X--'])) - self.getPage('/reqparams', method='POST', - headers=[("Content-Type", "multipart/form-data;boundary=X"), - ("Content-Length", str(len(body))), - ], - body=body), - self.assertStatus(400) - self.assertErrorPage(400, - "The request entity could not be decoded. 
The following charsets " - "were attempted: ['us-ascii', 'utf-8']") - - def test_nontext(self): - self.getPage('/nontext') - self.assertHeader('Content-Type', 'application/binary;charset=utf-8') - self.assertBody('\x00\x01\x02\x03') - - def testEncoding(self): - # Default encoding should be utf-8 - self.getPage('/mao_zedong') - self.assertBody(sing8) - - # Ask for utf-16. - self.getPage('/mao_zedong', [('Accept-Charset', 'utf-16')]) - self.assertHeader('Content-Type', 'text/html;charset=utf-16') - self.assertBody(sing16) - - # Ask for multiple encodings. ISO-8859-1 should fail, and utf-16 - # should be produced. - self.getPage('/mao_zedong', [('Accept-Charset', - 'iso-8859-1;q=1, utf-16;q=0.5')]) - self.assertBody(sing16) - - # The "*" value should default to our default_encoding, utf-8 - self.getPage('/mao_zedong', [('Accept-Charset', '*;q=1, utf-7;q=.2')]) - self.assertBody(sing8) - - # Only allow iso-8859-1, which should fail and raise 406. - self.getPage('/mao_zedong', [('Accept-Charset', 'iso-8859-1, *;q=0')]) - self.assertStatus("406 Not Acceptable") - self.assertInBody("Your client sent this Accept-Charset header: " - "iso-8859-1, *;q=0. We tried these charsets: " - "iso-8859-1.") - - # Ask for x-mac-ce, which should be unknown. See ticket #569. - self.getPage('/mao_zedong', [('Accept-Charset', - 'us-ascii, ISO-8859-1, x-mac-ce')]) - self.assertStatus("406 Not Acceptable") - self.assertInBody("Your client sent this Accept-Charset header: " - "us-ascii, ISO-8859-1, x-mac-ce. We tried these " - "charsets: ISO-8859-1, us-ascii, x-mac-ce.") - - # Test the 'encoding' arg to encode. 
- self.getPage('/utf8') - self.assertBody(sing8) - self.getPage('/utf8', [('Accept-Charset', 'us-ascii, ISO-8859-1')]) - self.assertStatus("406 Not Acceptable") - - def testGzip(self): - zbuf = BytesIO() - zfile = gzip.GzipFile(mode='wb', fileobj=zbuf, compresslevel=9) - zfile.write(ntob("Hello, world")) - zfile.close() - - self.getPage('/gzip/', headers=[("Accept-Encoding", "gzip")]) - self.assertInBody(zbuf.getvalue()[:3]) - self.assertHeader("Vary", "Accept-Encoding") - self.assertHeader("Content-Encoding", "gzip") - - # Test when gzip is denied. - self.getPage('/gzip/', headers=[("Accept-Encoding", "identity")]) - self.assertHeader("Vary", "Accept-Encoding") - self.assertNoHeader("Content-Encoding") - self.assertBody("Hello, world") - - self.getPage('/gzip/', headers=[("Accept-Encoding", "gzip;q=0")]) - self.assertHeader("Vary", "Accept-Encoding") - self.assertNoHeader("Content-Encoding") - self.assertBody("Hello, world") - - self.getPage('/gzip/', headers=[("Accept-Encoding", "*;q=0")]) - self.assertStatus(406) - self.assertNoHeader("Content-Encoding") - self.assertErrorPage(406, "identity, gzip") - - # Test for ticket #147 - self.getPage('/gzip/noshow', headers=[("Accept-Encoding", "gzip")]) - self.assertNoHeader('Content-Encoding') - self.assertStatus(500) - self.assertErrorPage(500, pattern="IndexError\n") - - # In this case, there's nothing we can do to deliver a - # readable page, since 1) the gzip header is already set, - # and 2) we may have already written some of the body. - # The fix is to never stream yields when using gzip. - if (cherrypy.server.protocol_version == "HTTP/1.0" or - getattr(cherrypy.server, "using_apache", False)): - self.getPage('/gzip/noshow_stream', - headers=[("Accept-Encoding", "gzip")]) - self.assertHeader('Content-Encoding', 'gzip') - self.assertInBody('\x1f\x8b\x08\x00') - else: - # The wsgiserver will simply stop sending data, and the HTTP client - # will error due to an incomplete chunk-encoded stream. 
- self.assertRaises((ValueError, IncompleteRead), self.getPage, - '/gzip/noshow_stream', - headers=[("Accept-Encoding", "gzip")]) - - def test_UnicodeHeaders(self): - self.getPage('/cookies_and_headers') - self.assertBody('Any content') - diff --git a/cherrypy/test/test_etags.py b/cherrypy/test/test_etags.py deleted file mode 100644 index 026f9d65..00000000 --- a/cherrypy/test/test_etags.py +++ /dev/null @@ -1,81 +0,0 @@ -import cherrypy -from cherrypy.test import helper - - -class ETagTest(helper.CPWebCase): - - def setup_server(): - class Root: - def resource(self): - return "Oh wah ta goo Siam." - resource.exposed = True - - def fail(self, code): - code = int(code) - if 300 <= code <= 399: - raise cherrypy.HTTPRedirect([], code) - else: - raise cherrypy.HTTPError(code) - fail.exposed = True - - def unicoded(self): - return u'I am a \u1ee4nicode string.' - unicoded.exposed = True - unicoded._cp_config = {'tools.encode.on': True} - - conf = {'/': {'tools.etags.on': True, - 'tools.etags.autotags': True, - }} - cherrypy.tree.mount(Root(), config=conf) - setup_server = staticmethod(setup_server) - - def test_etags(self): - self.getPage("/resource") - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/html;charset=utf-8') - self.assertBody('Oh wah ta goo Siam.') - etag = self.assertHeader('ETag') - - # Test If-Match (both valid and invalid) - self.getPage("/resource", headers=[('If-Match', etag)]) - self.assertStatus("200 OK") - self.getPage("/resource", headers=[('If-Match', "*")]) - self.assertStatus("200 OK") - self.getPage("/resource", headers=[('If-Match', "*")], method="POST") - self.assertStatus("200 OK") - self.getPage("/resource", headers=[('If-Match', "a bogus tag")]) - self.assertStatus("412 Precondition Failed") - - # Test If-None-Match (both valid and invalid) - self.getPage("/resource", headers=[('If-None-Match', etag)]) - self.assertStatus(304) - self.getPage("/resource", method='POST', headers=[('If-None-Match', etag)]) - 
self.assertStatus("412 Precondition Failed") - self.getPage("/resource", headers=[('If-None-Match', "*")]) - self.assertStatus(304) - self.getPage("/resource", headers=[('If-None-Match', "a bogus tag")]) - self.assertStatus("200 OK") - - def test_errors(self): - self.getPage("/resource") - self.assertStatus(200) - etag = self.assertHeader('ETag') - - # Test raising errors in page handler - self.getPage("/fail/412", headers=[('If-Match', etag)]) - self.assertStatus(412) - self.getPage("/fail/304", headers=[('If-Match', etag)]) - self.assertStatus(304) - self.getPage("/fail/412", headers=[('If-None-Match', "*")]) - self.assertStatus(412) - self.getPage("/fail/304", headers=[('If-None-Match', "*")]) - self.assertStatus(304) - - def test_unicode_body(self): - self.getPage("/unicoded") - self.assertStatus(200) - etag1 = self.assertHeader('ETag') - self.getPage("/unicoded", headers=[('If-Match', etag1)]) - self.assertStatus(200) - self.assertHeader('ETag', etag1) - diff --git a/cherrypy/test/test_http.py b/cherrypy/test/test_http.py deleted file mode 100644 index eb72b5bf..00000000 --- a/cherrypy/test/test_http.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Tests for managing HTTP issues (malformed requests, etc).""" - -import mimetypes - -import cherrypy -from cherrypy._cpcompat import HTTPConnection, HTTPSConnection, ntob - - -def encode_multipart_formdata(files): - """Return (content_type, body) ready for httplib.HTTP instance. - - files: a sequence of (name, filename, value) tuples for multipart uploads. 
- """ - BOUNDARY = '________ThIs_Is_tHe_bouNdaRY_$' - L = [] - for key, filename, value in files: - L.append('--' + BOUNDARY) - L.append('Content-Disposition: form-data; name="%s"; filename="%s"' % - (key, filename)) - ct = mimetypes.guess_type(filename)[0] or 'application/octet-stream' - L.append('Content-Type: %s' % ct) - L.append('') - L.append(value) - L.append('--' + BOUNDARY + '--') - L.append('') - body = '\r\n'.join(L) - content_type = 'multipart/form-data; boundary=%s' % BOUNDARY - return content_type, body - - - - -from cherrypy.test import helper - -class HTTPTests(helper.CPWebCase): - - def setup_server(): - class Root: - def index(self, *args, **kwargs): - return "Hello world!" - index.exposed = True - - def no_body(self, *args, **kwargs): - return "Hello world!" - no_body.exposed = True - no_body._cp_config = {'request.process_request_body': False} - - def post_multipart(self, file): - """Return a summary ("a * 65536\nb * 65536") of the uploaded file.""" - contents = file.file.read() - summary = [] - curchar = "" - count = 0 - for c in contents: - if c == curchar: - count += 1 - else: - if count: - summary.append("%s * %d" % (curchar, count)) - count = 1 - curchar = c - if count: - summary.append("%s * %d" % (curchar, count)) - return ", ".join(summary) - post_multipart.exposed = True - - cherrypy.tree.mount(Root()) - cherrypy.config.update({'server.max_request_body_size': 30000000}) - setup_server = staticmethod(setup_server) - - def test_no_content_length(self): - # "The presence of a message-body in a request is signaled by the - # inclusion of a Content-Length or Transfer-Encoding header field in - # the request's message-headers." - # - # Send a message with neither header and no body. Even though - # the request is of method POST, this should be OK because we set - # request.process_request_body to False for our handler. 
- if self.scheme == "https": - c = HTTPSConnection('%s:%s' % (self.interface(), self.PORT)) - else: - c = HTTPConnection('%s:%s' % (self.interface(), self.PORT)) - c.request("POST", "/no_body") - response = c.getresponse() - self.body = response.fp.read() - self.status = str(response.status) - self.assertStatus(200) - self.assertBody(ntob('Hello world!')) - - # Now send a message that has no Content-Length, but does send a body. - # Verify that CP times out the socket and responds - # with 411 Length Required. - if self.scheme == "https": - c = HTTPSConnection('%s:%s' % (self.interface(), self.PORT)) - else: - c = HTTPConnection('%s:%s' % (self.interface(), self.PORT)) - c.request("POST", "/") - response = c.getresponse() - self.body = response.fp.read() - self.status = str(response.status) - self.assertStatus(411) - - def test_post_multipart(self): - alphabet = "abcdefghijklmnopqrstuvwxyz" - # generate file contents for a large post - contents = "".join([c * 65536 for c in alphabet]) - - # encode as multipart form data - files=[('file', 'file.txt', contents)] - content_type, body = encode_multipart_formdata(files) - body = body.encode('Latin-1') - - # post file - if self.scheme == 'https': - c = HTTPSConnection('%s:%s' % (self.interface(), self.PORT)) - else: - c = HTTPConnection('%s:%s' % (self.interface(), self.PORT)) - c.putrequest('POST', '/post_multipart') - c.putheader('Content-Type', content_type) - c.putheader('Content-Length', str(len(body))) - c.endheaders() - c.send(body) - - response = c.getresponse() - self.body = response.fp.read() - self.status = str(response.status) - self.assertStatus(200) - self.assertBody(", ".join(["%s * 65536" % c for c in alphabet])) - - def test_malformed_request_line(self): - if getattr(cherrypy.server, "using_apache", False): - return self.skip("skipped due to known Apache differences...") - - # Test missing version in Request-Line - if self.scheme == 'https': - c = HTTPSConnection('%s:%s' % (self.interface(), self.PORT)) 
- else: - c = HTTPConnection('%s:%s' % (self.interface(), self.PORT)) - c._output(ntob('GET /')) - c._send_output() - if hasattr(c, 'strict'): - response = c.response_class(c.sock, strict=c.strict, method='GET') - else: - # Python 3.2 removed the 'strict' feature, saying: - # "http.client now always assumes HTTP/1.x compliant servers." - response = c.response_class(c.sock, method='GET') - response.begin() - self.assertEqual(response.status, 400) - self.assertEqual(response.fp.read(22), ntob("Malformed Request-Line")) - c.close() - - def test_malformed_header(self): - if self.scheme == 'https': - c = HTTPSConnection('%s:%s' % (self.interface(), self.PORT)) - else: - c = HTTPConnection('%s:%s' % (self.interface(), self.PORT)) - c.putrequest('GET', '/') - c.putheader('Content-Type', 'text/plain') - # See http://www.cherrypy.org/ticket/941 - c._output(ntob('Re, 1.2.3.4#015#012')) - c.endheaders() - - response = c.getresponse() - self.status = str(response.status) - self.assertStatus(400) - self.body = response.fp.read(20) - self.assertBody("Illegal header line.") - diff --git a/cherrypy/test/test_httpauth.py b/cherrypy/test/test_httpauth.py deleted file mode 100644 index 9d0eecb2..00000000 --- a/cherrypy/test/test_httpauth.py +++ /dev/null @@ -1,151 +0,0 @@ -import cherrypy -from cherrypy._cpcompat import md5, sha, ntob -from cherrypy.lib import httpauth - -from cherrypy.test import helper - -class HTTPAuthTest(helper.CPWebCase): - - def setup_server(): - class Root: - def index(self): - return "This is public." - index.exposed = True - - class DigestProtected: - def index(self): - return "Hello %s, you've been authorized." % cherrypy.request.login - index.exposed = True - - class BasicProtected: - def index(self): - return "Hello %s, you've been authorized." % cherrypy.request.login - index.exposed = True - - class BasicProtected2: - def index(self): - return "Hello %s, you've been authorized." 
% cherrypy.request.login - index.exposed = True - - def fetch_users(): - return {'test': 'test'} - - def sha_password_encrypter(password): - return sha(ntob(password)).hexdigest() - - def fetch_password(username): - return sha(ntob('test')).hexdigest() - - conf = {'/digest': {'tools.digest_auth.on': True, - 'tools.digest_auth.realm': 'localhost', - 'tools.digest_auth.users': fetch_users}, - '/basic': {'tools.basic_auth.on': True, - 'tools.basic_auth.realm': 'localhost', - 'tools.basic_auth.users': {'test': md5(ntob('test')).hexdigest()}}, - '/basic2': {'tools.basic_auth.on': True, - 'tools.basic_auth.realm': 'localhost', - 'tools.basic_auth.users': fetch_password, - 'tools.basic_auth.encrypt': sha_password_encrypter}} - - root = Root() - root.digest = DigestProtected() - root.basic = BasicProtected() - root.basic2 = BasicProtected2() - cherrypy.tree.mount(root, config=conf) - setup_server = staticmethod(setup_server) - - - def testPublic(self): - self.getPage("/") - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/html;charset=utf-8') - self.assertBody('This is public.') - - def testBasic(self): - self.getPage("/basic/") - self.assertStatus(401) - self.assertHeader('WWW-Authenticate', 'Basic realm="localhost"') - - self.getPage('/basic/', [('Authorization', 'Basic dGVzdDp0ZX60')]) - self.assertStatus(401) - - self.getPage('/basic/', [('Authorization', 'Basic dGVzdDp0ZXN0')]) - self.assertStatus('200 OK') - self.assertBody("Hello test, you've been authorized.") - - def testBasic2(self): - self.getPage("/basic2/") - self.assertStatus(401) - self.assertHeader('WWW-Authenticate', 'Basic realm="localhost"') - - self.getPage('/basic2/', [('Authorization', 'Basic dGVzdDp0ZX60')]) - self.assertStatus(401) - - self.getPage('/basic2/', [('Authorization', 'Basic dGVzdDp0ZXN0')]) - self.assertStatus('200 OK') - self.assertBody("Hello test, you've been authorized.") - - def testDigest(self): - self.getPage("/digest/") - self.assertStatus(401) - - value = 
None - for k, v in self.headers: - if k.lower() == "www-authenticate": - if v.startswith("Digest"): - value = v - break - - if value is None: - self._handlewebError("Digest authentification scheme was not found") - - value = value[7:] - items = value.split(', ') - tokens = {} - for item in items: - key, value = item.split('=') - tokens[key.lower()] = value - - missing_msg = "%s is missing" - bad_value_msg = "'%s' was expecting '%s' but found '%s'" - nonce = None - if 'realm' not in tokens: - self._handlewebError(missing_msg % 'realm') - elif tokens['realm'] != '"localhost"': - self._handlewebError(bad_value_msg % ('realm', '"localhost"', tokens['realm'])) - if 'nonce' not in tokens: - self._handlewebError(missing_msg % 'nonce') - else: - nonce = tokens['nonce'].strip('"') - if 'algorithm' not in tokens: - self._handlewebError(missing_msg % 'algorithm') - elif tokens['algorithm'] != '"MD5"': - self._handlewebError(bad_value_msg % ('algorithm', '"MD5"', tokens['algorithm'])) - if 'qop' not in tokens: - self._handlewebError(missing_msg % 'qop') - elif tokens['qop'] != '"auth"': - self._handlewebError(bad_value_msg % ('qop', '"auth"', tokens['qop'])) - - # Test a wrong 'realm' value - base_auth = 'Digest username="test", realm="wrong realm", nonce="%s", uri="/digest/", algorithm=MD5, response="%s", qop=auth, nc=%s, cnonce="1522e61005789929"' - - auth = base_auth % (nonce, '', '00000001') - params = httpauth.parseAuthorization(auth) - response = httpauth._computeDigestResponse(params, 'test') - - auth = base_auth % (nonce, response, '00000001') - self.getPage('/digest/', [('Authorization', auth)]) - self.assertStatus(401) - - # Test that must pass - base_auth = 'Digest username="test", realm="localhost", nonce="%s", uri="/digest/", algorithm=MD5, response="%s", qop=auth, nc=%s, cnonce="1522e61005789929"' - - auth = base_auth % (nonce, '', '00000001') - params = httpauth.parseAuthorization(auth) - response = httpauth._computeDigestResponse(params, 'test') - - auth = 
base_auth % (nonce, response, '00000001') - self.getPage('/digest/', [('Authorization', auth)]) - self.assertStatus('200 OK') - self.assertBody("Hello test, you've been authorized.") - diff --git a/cherrypy/test/test_httplib.py b/cherrypy/test/test_httplib.py deleted file mode 100644 index 5dc40fd2..00000000 --- a/cherrypy/test/test_httplib.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Tests for cherrypy/lib/httputil.py.""" - -import unittest -from cherrypy.lib import httputil - - -class UtilityTests(unittest.TestCase): - - def test_urljoin(self): - # Test all slash+atom combinations for SCRIPT_NAME and PATH_INFO - self.assertEqual(httputil.urljoin("/sn/", "/pi/"), "/sn/pi/") - self.assertEqual(httputil.urljoin("/sn/", "/pi"), "/sn/pi") - self.assertEqual(httputil.urljoin("/sn/", "/"), "/sn/") - self.assertEqual(httputil.urljoin("/sn/", ""), "/sn/") - self.assertEqual(httputil.urljoin("/sn", "/pi/"), "/sn/pi/") - self.assertEqual(httputil.urljoin("/sn", "/pi"), "/sn/pi") - self.assertEqual(httputil.urljoin("/sn", "/"), "/sn/") - self.assertEqual(httputil.urljoin("/sn", ""), "/sn") - self.assertEqual(httputil.urljoin("/", "/pi/"), "/pi/") - self.assertEqual(httputil.urljoin("/", "/pi"), "/pi") - self.assertEqual(httputil.urljoin("/", "/"), "/") - self.assertEqual(httputil.urljoin("/", ""), "/") - self.assertEqual(httputil.urljoin("", "/pi/"), "/pi/") - self.assertEqual(httputil.urljoin("", "/pi"), "/pi") - self.assertEqual(httputil.urljoin("", "/"), "/") - self.assertEqual(httputil.urljoin("", ""), "/") - -if __name__ == '__main__': - unittest.main() diff --git a/cherrypy/test/test_json.py b/cherrypy/test/test_json.py deleted file mode 100644 index a02c0767..00000000 --- a/cherrypy/test/test_json.py +++ /dev/null @@ -1,79 +0,0 @@ -import cherrypy -from cherrypy.test import helper - -from cherrypy._cpcompat import json - -class JsonTest(helper.CPWebCase): - def setup_server(): - class Root(object): - def plain(self): - return 'hello' - plain.exposed = True - - def 
json_string(self): - return 'hello' - json_string.exposed = True - json_string._cp_config = {'tools.json_out.on': True} - - def json_list(self): - return ['a', 'b', 42] - json_list.exposed = True - json_list._cp_config = {'tools.json_out.on': True} - - def json_dict(self): - return {'answer': 42} - json_dict.exposed = True - json_dict._cp_config = {'tools.json_out.on': True} - - def json_post(self): - if cherrypy.request.json == [13, 'c']: - return 'ok' - else: - return 'nok' - json_post.exposed = True - json_post._cp_config = {'tools.json_in.on': True} - - root = Root() - cherrypy.tree.mount(root) - setup_server = staticmethod(setup_server) - - def test_json_output(self): - if json is None: - self.skip("json not found ") - return - - self.getPage("/plain") - self.assertBody("hello") - - self.getPage("/json_string") - self.assertBody('"hello"') - - self.getPage("/json_list") - self.assertBody('["a", "b", 42]') - - self.getPage("/json_dict") - self.assertBody('{"answer": 42}') - - def test_json_input(self): - if json is None: - self.skip("json not found ") - return - - body = '[13, "c"]' - headers = [('Content-Type', 'application/json'), - ('Content-Length', str(len(body)))] - self.getPage("/json_post", method="POST", headers=headers, body=body) - self.assertBody('ok') - - body = '[13, "c"]' - headers = [('Content-Type', 'text/plain'), - ('Content-Length', str(len(body)))] - self.getPage("/json_post", method="POST", headers=headers, body=body) - self.assertStatus(415, 'Expected an application/json content type') - - body = '[13, -]' - headers = [('Content-Type', 'application/json'), - ('Content-Length', str(len(body)))] - self.getPage("/json_post", method="POST", headers=headers, body=body) - self.assertStatus(400, 'Invalid JSON document') - diff --git a/cherrypy/test/test_logging.py b/cherrypy/test/test_logging.py deleted file mode 100644 index 5a13cd4a..00000000 --- a/cherrypy/test/test_logging.py +++ /dev/null @@ -1,149 +0,0 @@ -"""Basic tests for the CherryPy 
core: request handling.""" - -import os -localDir = os.path.dirname(__file__) - -import cherrypy - -access_log = os.path.join(localDir, "access.log") -error_log = os.path.join(localDir, "error.log") - -# Some unicode strings. -tartaros = u'\u03a4\u1f71\u03c1\u03c4\u03b1\u03c1\u03bf\u03c2' -erebos = u'\u0388\u03c1\u03b5\u03b2\u03bf\u03c2.com' - - -def setup_server(): - class Root: - - def index(self): - return "hello" - index.exposed = True - - def uni_code(self): - cherrypy.request.login = tartaros - cherrypy.request.remote.name = erebos - uni_code.exposed = True - - def slashes(self): - cherrypy.request.request_line = r'GET /slashed\path HTTP/1.1' - slashes.exposed = True - - def whitespace(self): - # User-Agent = "User-Agent" ":" 1*( product | comment ) - # comment = "(" *( ctext | quoted-pair | comment ) ")" - # ctext = - # TEXT = - # LWS = [CRLF] 1*( SP | HT ) - cherrypy.request.headers['User-Agent'] = 'Browzuh (1.0\r\n\t\t.3)' - whitespace.exposed = True - - def as_string(self): - return "content" - as_string.exposed = True - - def as_yield(self): - yield "content" - as_yield.exposed = True - - def error(self): - raise ValueError() - error.exposed = True - error._cp_config = {'tools.log_tracebacks.on': True} - - root = Root() - - - cherrypy.config.update({'log.error_file': error_log, - 'log.access_file': access_log, - }) - cherrypy.tree.mount(root) - - - -from cherrypy.test import helper, logtest - -class AccessLogTests(helper.CPWebCase, logtest.LogCase): - setup_server = staticmethod(setup_server) - - logfile = access_log - - def testNormalReturn(self): - self.markLog() - self.getPage("/as_string", - headers=[('Referer', 'http://www.cherrypy.org/'), - ('User-Agent', 'Mozilla/5.0')]) - self.assertBody('content') - self.assertStatus(200) - - intro = '%s - - [' % self.interface() - - self.assertLog(-1, intro) - - if [k for k, v in self.headers if k.lower() == 'content-length']: - self.assertLog(-1, '] "GET %s/as_string HTTP/1.1" 200 7 ' - 
'"http://www.cherrypy.org/" "Mozilla/5.0"' - % self.prefix()) - else: - self.assertLog(-1, '] "GET %s/as_string HTTP/1.1" 200 - ' - '"http://www.cherrypy.org/" "Mozilla/5.0"' - % self.prefix()) - - def testNormalYield(self): - self.markLog() - self.getPage("/as_yield") - self.assertBody('content') - self.assertStatus(200) - - intro = '%s - - [' % self.interface() - - self.assertLog(-1, intro) - if [k for k, v in self.headers if k.lower() == 'content-length']: - self.assertLog(-1, '] "GET %s/as_yield HTTP/1.1" 200 7 "" ""' % - self.prefix()) - else: - self.assertLog(-1, '] "GET %s/as_yield HTTP/1.1" 200 - "" ""' - % self.prefix()) - - def testEscapedOutput(self): - # Test unicode in access log pieces. - self.markLog() - self.getPage("/uni_code") - self.assertStatus(200) - self.assertLog(-1, repr(tartaros.encode('utf8'))[1:-1]) - # Test the erebos value. Included inline for your enlightenment. - # Note the 'r' prefix--those backslashes are literals. - self.assertLog(-1, r'\xce\x88\xcf\x81\xce\xb5\xce\xb2\xce\xbf\xcf\x82') - - # Test backslashes in output. - self.markLog() - self.getPage("/slashes") - self.assertStatus(200) - self.assertLog(-1, r'"GET /slashed\\path HTTP/1.1"') - - # Test whitespace in output. - self.markLog() - self.getPage("/whitespace") - self.assertStatus(200) - # Again, note the 'r' prefix. - self.assertLog(-1, r'"Browzuh (1.0\r\n\t\t.3)"') - - -class ErrorLogTests(helper.CPWebCase, logtest.LogCase): - setup_server = staticmethod(setup_server) - - logfile = error_log - - def testTracebacks(self): - # Test that tracebacks get written to the error log. 
- self.markLog() - ignore = helper.webtest.ignored_exceptions - ignore.append(ValueError) - try: - self.getPage("/error") - self.assertInBody("raise ValueError()") - self.assertLog(0, 'HTTP Traceback (most recent call last):') - self.assertLog(-3, 'raise ValueError()') - finally: - ignore.pop() - diff --git a/cherrypy/test/test_mime.py b/cherrypy/test/test_mime.py deleted file mode 100644 index 605071b8..00000000 --- a/cherrypy/test/test_mime.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Tests for various MIME issues, including the safe_multipart Tool.""" - -import cherrypy -from cherrypy._cpcompat import ntob, ntou, sorted - -def setup_server(): - - class Root: - - def multipart(self, parts): - return repr(parts) - multipart.exposed = True - - def multipart_form_data(self, **kwargs): - return repr(list(sorted(kwargs.items()))) - multipart_form_data.exposed = True - - def flashupload(self, Filedata, Upload, Filename): - return ("Upload: %r, Filename: %r, Filedata: %r" % - (Upload, Filename, Filedata.file.read())) - flashupload.exposed = True - - cherrypy.config.update({'server.max_request_body_size': 0}) - cherrypy.tree.mount(Root()) - - -# Client-side code # - -from cherrypy.test import helper - -class MultipartTest(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def test_multipart(self): - text_part = ntou("This is the text version") - html_part = ntou(""" - - - - - - -This is the HTML version - - -""") - body = '\r\n'.join([ - "--123456789", - "Content-Type: text/plain; charset='ISO-8859-1'", - "Content-Transfer-Encoding: 7bit", - "", - text_part, - "--123456789", - "Content-Type: text/html; charset='ISO-8859-1'", - "", - html_part, - "--123456789--"]) - headers = [ - ('Content-Type', 'multipart/mixed; boundary=123456789'), - ('Content-Length', str(len(body))), - ] - self.getPage('/multipart', headers, "POST", body) - self.assertBody(repr([text_part, html_part])) - - def test_multipart_form_data(self): - body='\r\n'.join(['--X', - 
'Content-Disposition: form-data; name="foo"', - '', - 'bar', - '--X', - # Test a param with more than one value. - # See http://www.cherrypy.org/ticket/1028 - 'Content-Disposition: form-data; name="baz"', - '', - '111', - '--X', - 'Content-Disposition: form-data; name="baz"', - '', - '333', - '--X--']) - self.getPage('/multipart_form_data', method='POST', - headers=[("Content-Type", "multipart/form-data;boundary=X"), - ("Content-Length", str(len(body))), - ], - body=body), - self.assertBody(repr([('baz', [u'111', u'333']), ('foo', u'bar')])) - - -class SafeMultipartHandlingTest(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def test_Flash_Upload(self): - headers = [ - ('Accept', 'text/*'), - ('Content-Type', 'multipart/form-data; ' - 'boundary=----------KM7Ij5cH2KM7Ef1gL6ae0ae0cH2gL6'), - ('User-Agent', 'Shockwave Flash'), - ('Host', 'www.example.com:8080'), - ('Content-Length', '499'), - ('Connection', 'Keep-Alive'), - ('Cache-Control', 'no-cache'), - ] - filedata = ntob('\r\n' - '\r\n' - '\r\n') - body = (ntob( - '------------KM7Ij5cH2KM7Ef1gL6ae0ae0cH2gL6\r\n' - 'Content-Disposition: form-data; name="Filename"\r\n' - '\r\n' - '.project\r\n' - '------------KM7Ij5cH2KM7Ef1gL6ae0ae0cH2gL6\r\n' - 'Content-Disposition: form-data; ' - 'name="Filedata"; filename=".project"\r\n' - 'Content-Type: application/octet-stream\r\n' - '\r\n') - + filedata + - ntob('\r\n' - '------------KM7Ij5cH2KM7Ef1gL6ae0ae0cH2gL6\r\n' - 'Content-Disposition: form-data; name="Upload"\r\n' - '\r\n' - 'Submit Query\r\n' - # Flash apps omit the trailing \r\n on the last line: - '------------KM7Ij5cH2KM7Ef1gL6ae0ae0cH2gL6--' - )) - self.getPage('/flashupload', headers, "POST", body) - self.assertBody("Upload: u'Submit Query', Filename: u'.project', " - "Filedata: %r" % filedata) - diff --git a/cherrypy/test/test_misc_tools.py b/cherrypy/test/test_misc_tools.py deleted file mode 100644 index fb94e860..00000000 --- a/cherrypy/test/test_misc_tools.py +++ /dev/null @@ -1,202 +0,0 
@@ -import os -localDir = os.path.dirname(__file__) -logfile = os.path.join(localDir, "test_misc_tools.log") - -import cherrypy -from cherrypy import tools - - -def setup_server(): - class Root: - def index(self): - yield "Hello, world" - index.exposed = True - h = [("Content-Language", "en-GB"), ('Content-Type', 'text/plain')] - tools.response_headers(headers=h)(index) - - def other(self): - return "salut" - other.exposed = True - other._cp_config = { - 'tools.response_headers.on': True, - 'tools.response_headers.headers': [("Content-Language", "fr"), - ('Content-Type', 'text/plain')], - 'tools.log_hooks.on': True, - } - - - class Accept: - _cp_config = {'tools.accept.on': True} - - def index(self): - return 'Atom feed' - index.exposed = True - - # In Python 2.4+, we could use a decorator instead: - # @tools.accept('application/atom+xml') - def feed(self): - return """ - - Unknown Blog -""" - feed.exposed = True - feed._cp_config = {'tools.accept.media': 'application/atom+xml'} - - def select(self): - # We could also write this: mtype = cherrypy.lib.accept.accept(...) - mtype = tools.accept.callable(['text/html', 'text/plain']) - if mtype == 'text/html': - return "

Page Title

" - else: - return "PAGE TITLE" - select.exposed = True - - class Referer: - def accept(self): - return "Accepted!" - accept.exposed = True - reject = accept - - class AutoVary: - def index(self): - # Read a header directly with 'get' - ae = cherrypy.request.headers.get('Accept-Encoding') - # Read a header directly with '__getitem__' - cl = cherrypy.request.headers['Host'] - # Read a header directly with '__contains__' - hasif = 'If-Modified-Since' in cherrypy.request.headers - # Read a header directly with 'has_key' - has = cherrypy.request.headers.has_key('Range') - # Call a lib function - mtype = tools.accept.callable(['text/html', 'text/plain']) - return "Hello, world!" - index.exposed = True - - conf = {'/referer': {'tools.referer.on': True, - 'tools.referer.pattern': r'http://[^/]*example\.com', - }, - '/referer/reject': {'tools.referer.accept': False, - 'tools.referer.accept_missing': True, - }, - '/autovary': {'tools.autovary.on': True}, - } - - root = Root() - root.referer = Referer() - root.accept = Accept() - root.autovary = AutoVary() - cherrypy.tree.mount(root, config=conf) - cherrypy.config.update({'log.error_file': logfile}) - - -from cherrypy.test import helper - -class ResponseHeadersTest(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def testResponseHeadersDecorator(self): - self.getPage('/') - self.assertHeader("Content-Language", "en-GB") - self.assertHeader('Content-Type', 'text/plain;charset=utf-8') - - def testResponseHeaders(self): - self.getPage('/other') - self.assertHeader("Content-Language", "fr") - self.assertHeader('Content-Type', 'text/plain;charset=utf-8') - - -class RefererTest(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def testReferer(self): - self.getPage('/referer/accept') - self.assertErrorPage(403, 'Forbidden Referer header.') - - self.getPage('/referer/accept', - headers=[('Referer', 'http://www.example.com/')]) - self.assertStatus(200) - self.assertBody('Accepted!') - - # Reject - 
self.getPage('/referer/reject') - self.assertStatus(200) - self.assertBody('Accepted!') - - self.getPage('/referer/reject', - headers=[('Referer', 'http://www.example.com/')]) - self.assertErrorPage(403, 'Forbidden Referer header.') - - -class AcceptTest(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def test_Accept_Tool(self): - # Test with no header provided - self.getPage('/accept/feed') - self.assertStatus(200) - self.assertInBody('Unknown Blog') - - # Specify exact media type - self.getPage('/accept/feed', headers=[('Accept', 'application/atom+xml')]) - self.assertStatus(200) - self.assertInBody('Unknown Blog') - - # Specify matching media range - self.getPage('/accept/feed', headers=[('Accept', 'application/*')]) - self.assertStatus(200) - self.assertInBody('Unknown Blog') - - # Specify all media ranges - self.getPage('/accept/feed', headers=[('Accept', '*/*')]) - self.assertStatus(200) - self.assertInBody('Unknown Blog') - - # Specify unacceptable media types - self.getPage('/accept/feed', headers=[('Accept', 'text/html')]) - self.assertErrorPage(406, - "Your client sent this Accept header: text/html. " - "But this resource only emits these media types: " - "application/atom+xml.") - - # Test resource where tool is 'on' but media is None (not set). - self.getPage('/accept/') - self.assertStatus(200) - self.assertBody('Atom feed') - - def test_accept_selection(self): - # Try both our expected media types - self.getPage('/accept/select', [('Accept', 'text/html')]) - self.assertStatus(200) - self.assertBody('

Page Title

') - self.getPage('/accept/select', [('Accept', 'text/plain')]) - self.assertStatus(200) - self.assertBody('PAGE TITLE') - self.getPage('/accept/select', [('Accept', 'text/plain, text/*;q=0.5')]) - self.assertStatus(200) - self.assertBody('PAGE TITLE') - - # text/* and */* should prefer text/html since it comes first - # in our 'media' argument to tools.accept - self.getPage('/accept/select', [('Accept', 'text/*')]) - self.assertStatus(200) - self.assertBody('

Page Title

') - self.getPage('/accept/select', [('Accept', '*/*')]) - self.assertStatus(200) - self.assertBody('

Page Title

') - - # Try unacceptable media types - self.getPage('/accept/select', [('Accept', 'application/xml')]) - self.assertErrorPage(406, - "Your client sent this Accept header: application/xml. " - "But this resource only emits these media types: " - "text/html, text/plain.") - - -class AutoVaryTest(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def testAutoVary(self): - self.getPage('/autovary/') - self.assertHeader( - "Vary", 'Accept, Accept-Charset, Accept-Encoding, Host, If-Modified-Since, Range') - diff --git a/cherrypy/test/test_objectmapping.py b/cherrypy/test/test_objectmapping.py deleted file mode 100644 index 46816fcb..00000000 --- a/cherrypy/test/test_objectmapping.py +++ /dev/null @@ -1,403 +0,0 @@ -import cherrypy -from cherrypy._cptree import Application -from cherrypy.test import helper - -script_names = ["", "/foo", "/users/fred/blog", "/corp/blog"] - - -class ObjectMappingTest(helper.CPWebCase): - - def setup_server(): - class Root: - def index(self, name="world"): - return name - index.exposed = True - - def foobar(self): - return "bar" - foobar.exposed = True - - def default(self, *params, **kwargs): - return "default:" + repr(params) - default.exposed = True - - def other(self): - return "other" - other.exposed = True - - def extra(self, *p): - return repr(p) - extra.exposed = True - - def redirect(self): - raise cherrypy.HTTPRedirect('dir1/', 302) - redirect.exposed = True - - def notExposed(self): - return "not exposed" - - def confvalue(self): - return cherrypy.request.config.get("user") - confvalue.exposed = True - - def redirect_via_url(self, path): - raise cherrypy.HTTPRedirect(cherrypy.url(path)) - redirect_via_url.exposed = True - - def translate_html(self): - return "OK" - translate_html.exposed = True - - def mapped_func(self, ID=None): - return "ID is %s" % ID - mapped_func.exposed = True - setattr(Root, "Von B\xfclow", mapped_func) - - - class Exposing: - def base(self): - return "expose works!" 
- cherrypy.expose(base) - cherrypy.expose(base, "1") - cherrypy.expose(base, "2") - - class ExposingNewStyle(object): - def base(self): - return "expose works!" - cherrypy.expose(base) - cherrypy.expose(base, "1") - cherrypy.expose(base, "2") - - - class Dir1: - def index(self): - return "index for dir1" - index.exposed = True - - def myMethod(self): - return "myMethod from dir1, path_info is:" + repr(cherrypy.request.path_info) - myMethod.exposed = True - myMethod._cp_config = {'tools.trailing_slash.extra': True} - - def default(self, *params): - return "default for dir1, param is:" + repr(params) - default.exposed = True - - - class Dir2: - def index(self): - return "index for dir2, path is:" + cherrypy.request.path_info - index.exposed = True - - def script_name(self): - return cherrypy.tree.script_name() - script_name.exposed = True - - def cherrypy_url(self): - return cherrypy.url("/extra") - cherrypy_url.exposed = True - - def posparam(self, *vpath): - return "/".join(vpath) - posparam.exposed = True - - - class Dir3: - def default(self): - return "default for dir3, not exposed" - - class Dir4: - def index(self): - return "index for dir4, not exposed" - - class DefNoIndex: - def default(self, *args): - raise cherrypy.HTTPRedirect("contact") - default.exposed = True - - # MethodDispatcher code - class ByMethod: - exposed = True - - def __init__(self, *things): - self.things = list(things) - - def GET(self): - return repr(self.things) - - def POST(self, thing): - self.things.append(thing) - - class Collection: - default = ByMethod('a', 'bit') - - Root.exposing = Exposing() - Root.exposingnew = ExposingNewStyle() - Root.dir1 = Dir1() - Root.dir1.dir2 = Dir2() - Root.dir1.dir2.dir3 = Dir3() - Root.dir1.dir2.dir3.dir4 = Dir4() - Root.defnoindex = DefNoIndex() - Root.bymethod = ByMethod('another') - Root.collection = Collection() - - d = cherrypy.dispatch.MethodDispatcher() - for url in script_names: - conf = {'/': {'user': (url or "/").split("/")[-2]}, - 
'/bymethod': {'request.dispatch': d}, - '/collection': {'request.dispatch': d}, - } - cherrypy.tree.mount(Root(), url, conf) - - - class Isolated: - def index(self): - return "made it!" - index.exposed = True - - cherrypy.tree.mount(Isolated(), "/isolated") - - class AnotherApp: - - exposed = True - - def GET(self): - return "milk" - - cherrypy.tree.mount(AnotherApp(), "/app", {'/': {'request.dispatch': d}}) - setup_server = staticmethod(setup_server) - - - def testObjectMapping(self): - for url in script_names: - prefix = self.script_name = url - - self.getPage('/') - self.assertBody('world') - - self.getPage("/dir1/myMethod") - self.assertBody("myMethod from dir1, path_info is:'/dir1/myMethod'") - - self.getPage("/this/method/does/not/exist") - self.assertBody("default:('this', 'method', 'does', 'not', 'exist')") - - self.getPage("/extra/too/much") - self.assertBody("('too', 'much')") - - self.getPage("/other") - self.assertBody('other') - - self.getPage("/notExposed") - self.assertBody("default:('notExposed',)") - - self.getPage("/dir1/dir2/") - self.assertBody('index for dir2, path is:/dir1/dir2/') - - # Test omitted trailing slash (should be redirected by default). - self.getPage("/dir1/dir2") - self.assertStatus(301) - self.assertHeader('Location', '%s/dir1/dir2/' % self.base()) - - # Test extra trailing slash (should be redirected if configured). - self.getPage("/dir1/myMethod/") - self.assertStatus(301) - self.assertHeader('Location', '%s/dir1/myMethod' % self.base()) - - # Test that default method must be exposed in order to match. - self.getPage("/dir1/dir2/dir3/dir4/index") - self.assertBody("default for dir1, param is:('dir2', 'dir3', 'dir4', 'index')") - - # Test *vpath when default() is defined but not index() - # This also tests HTTPRedirect with default. 
- self.getPage("/defnoindex") - self.assertStatus((302, 303)) - self.assertHeader('Location', '%s/contact' % self.base()) - self.getPage("/defnoindex/") - self.assertStatus((302, 303)) - self.assertHeader('Location', '%s/defnoindex/contact' % self.base()) - self.getPage("/defnoindex/page") - self.assertStatus((302, 303)) - self.assertHeader('Location', '%s/defnoindex/contact' % self.base()) - - self.getPage("/redirect") - self.assertStatus('302 Found') - self.assertHeader('Location', '%s/dir1/' % self.base()) - - if not getattr(cherrypy.server, "using_apache", False): - # Test that we can use URL's which aren't all valid Python identifiers - # This should also test the %XX-unquoting of URL's. - self.getPage("/Von%20B%fclow?ID=14") - self.assertBody("ID is 14") - - # Test that %2F in the path doesn't get unquoted too early; - # that is, it should not be used to separate path components. - # See ticket #393. - self.getPage("/page%2Fname") - self.assertBody("default:('page/name',)") - - self.getPage("/dir1/dir2/script_name") - self.assertBody(url) - self.getPage("/dir1/dir2/cherrypy_url") - self.assertBody("%s/extra" % self.base()) - - # Test that configs don't overwrite each other from diferent apps - self.getPage("/confvalue") - self.assertBody((url or "/").split("/")[-2]) - - self.script_name = "" - - # Test absoluteURI's in the Request-Line - self.getPage('http://%s:%s/' % (self.interface(), self.PORT)) - self.assertBody('world') - - self.getPage('http://%s:%s/abs/?service=http://192.168.0.1/x/y/z' % - (self.interface(), self.PORT)) - self.assertBody("default:('abs',)") - - self.getPage('/rel/?service=http://192.168.120.121:8000/x/y/z') - self.assertBody("default:('rel',)") - - # Test that the "isolated" app doesn't leak url's into the root app. - # If it did leak, Root.default() would answer with - # "default:('isolated', 'doesnt', 'exist')". 
- self.getPage("/isolated/") - self.assertStatus("200 OK") - self.assertBody("made it!") - self.getPage("/isolated/doesnt/exist") - self.assertStatus("404 Not Found") - - # Make sure /foobar maps to Root.foobar and not to the app - # mounted at /foo. See http://www.cherrypy.org/ticket/573 - self.getPage("/foobar") - self.assertBody("bar") - - def test_translate(self): - self.getPage("/translate_html") - self.assertStatus("200 OK") - self.assertBody("OK") - - self.getPage("/translate.html") - self.assertStatus("200 OK") - self.assertBody("OK") - - self.getPage("/translate-html") - self.assertStatus("200 OK") - self.assertBody("OK") - - def test_redir_using_url(self): - for url in script_names: - prefix = self.script_name = url - - # Test the absolute path to the parent (leading slash) - self.getPage('/redirect_via_url?path=./') - self.assertStatus(('302 Found', '303 See Other')) - self.assertHeader('Location', '%s/' % self.base()) - - # Test the relative path to the parent (no leading slash) - self.getPage('/redirect_via_url?path=./') - self.assertStatus(('302 Found', '303 See Other')) - self.assertHeader('Location', '%s/' % self.base()) - - # Test the absolute path to the parent (leading slash) - self.getPage('/redirect_via_url/?path=./') - self.assertStatus(('302 Found', '303 See Other')) - self.assertHeader('Location', '%s/' % self.base()) - - # Test the relative path to the parent (no leading slash) - self.getPage('/redirect_via_url/?path=./') - self.assertStatus(('302 Found', '303 See Other')) - self.assertHeader('Location', '%s/' % self.base()) - - def testPositionalParams(self): - self.getPage("/dir1/dir2/posparam/18/24/hut/hike") - self.assertBody("18/24/hut/hike") - - # intermediate index methods should not receive posparams; - # only the "final" index method should do so. 
- self.getPage("/dir1/dir2/5/3/sir") - self.assertBody("default for dir1, param is:('dir2', '5', '3', 'sir')") - - # test that extra positional args raises an 404 Not Found - # See http://www.cherrypy.org/ticket/733. - self.getPage("/dir1/dir2/script_name/extra/stuff") - self.assertStatus(404) - - def testExpose(self): - # Test the cherrypy.expose function/decorator - self.getPage("/exposing/base") - self.assertBody("expose works!") - - self.getPage("/exposing/1") - self.assertBody("expose works!") - - self.getPage("/exposing/2") - self.assertBody("expose works!") - - self.getPage("/exposingnew/base") - self.assertBody("expose works!") - - self.getPage("/exposingnew/1") - self.assertBody("expose works!") - - self.getPage("/exposingnew/2") - self.assertBody("expose works!") - - def testMethodDispatch(self): - self.getPage("/bymethod") - self.assertBody("['another']") - self.assertHeader('Allow', 'GET, HEAD, POST') - - self.getPage("/bymethod", method="HEAD") - self.assertBody("") - self.assertHeader('Allow', 'GET, HEAD, POST') - - self.getPage("/bymethod", method="POST", body="thing=one") - self.assertBody("") - self.assertHeader('Allow', 'GET, HEAD, POST') - - self.getPage("/bymethod") - self.assertBody("['another', u'one']") - self.assertHeader('Allow', 'GET, HEAD, POST') - - self.getPage("/bymethod", method="PUT") - self.assertErrorPage(405) - self.assertHeader('Allow', 'GET, HEAD, POST') - - # Test default with posparams - self.getPage("/collection/silly", method="POST") - self.getPage("/collection", method="GET") - self.assertBody("['a', 'bit', 'silly']") - - # Test custom dispatcher set on app root (see #737). - self.getPage("/app") - self.assertBody("milk") - - def testTreeMounting(self): - class Root(object): - def hello(self): - return "Hello world!" - hello.exposed = True - - # When mounting an application instance, - # we can't specify a different script name in the call to mount. 
- a = Application(Root(), '/somewhere') - self.assertRaises(ValueError, cherrypy.tree.mount, a, '/somewhereelse') - - # When mounting an application instance... - a = Application(Root(), '/somewhere') - # ...we MUST allow in identical script name in the call to mount... - cherrypy.tree.mount(a, '/somewhere') - self.getPage('/somewhere/hello') - self.assertStatus(200) - # ...and MUST allow a missing script_name. - del cherrypy.tree.apps['/somewhere'] - cherrypy.tree.mount(a) - self.getPage('/somewhere/hello') - self.assertStatus(200) - - # In addition, we MUST be able to create an Application using - # script_name == None for access to the wsgi_environ. - a = Application(Root(), script_name=None) - # However, this does not apply to tree.mount - self.assertRaises(TypeError, cherrypy.tree.mount, a, None) - diff --git a/cherrypy/test/test_proxy.py b/cherrypy/test/test_proxy.py deleted file mode 100644 index 2fbb619a..00000000 --- a/cherrypy/test/test_proxy.py +++ /dev/null @@ -1,129 +0,0 @@ -import cherrypy -from cherrypy.test import helper - -script_names = ["", "/path/to/myapp"] - - -class ProxyTest(helper.CPWebCase): - - def setup_server(): - - # Set up site - cherrypy.config.update({ - 'tools.proxy.on': True, - 'tools.proxy.base': 'www.mydomain.test', - }) - - # Set up application - - class Root: - - def __init__(self, sn): - # Calculate a URL outside of any requests. 
- self.thisnewpage = cherrypy.url("/this/new/page", script_name=sn) - - def pageurl(self): - return self.thisnewpage - pageurl.exposed = True - - def index(self): - raise cherrypy.HTTPRedirect('dummy') - index.exposed = True - - def remoteip(self): - return cherrypy.request.remote.ip - remoteip.exposed = True - - def xhost(self): - raise cherrypy.HTTPRedirect('blah') - xhost.exposed = True - xhost._cp_config = {'tools.proxy.local': 'X-Host', - 'tools.trailing_slash.extra': True, - } - - def base(self): - return cherrypy.request.base - base.exposed = True - - def ssl(self): - return cherrypy.request.base - ssl.exposed = True - ssl._cp_config = {'tools.proxy.scheme': 'X-Forwarded-Ssl'} - - def newurl(self): - return ("Browse to this page." - % cherrypy.url("/this/new/page")) - newurl.exposed = True - - for sn in script_names: - cherrypy.tree.mount(Root(sn), sn) - setup_server = staticmethod(setup_server) - - def testProxy(self): - self.getPage("/") - self.assertHeader('Location', - "%s://www.mydomain.test%s/dummy" % - (self.scheme, self.prefix())) - - # Test X-Forwarded-Host (Apache 1.3.33+ and Apache 2) - self.getPage("/", headers=[('X-Forwarded-Host', 'http://www.example.test')]) - self.assertHeader('Location', "http://www.example.test/dummy") - self.getPage("/", headers=[('X-Forwarded-Host', 'www.example.test')]) - self.assertHeader('Location', "%s://www.example.test/dummy" % self.scheme) - # Test multiple X-Forwarded-Host headers - self.getPage("/", headers=[ - ('X-Forwarded-Host', 'http://www.example.test, www.cherrypy.test'), - ]) - self.assertHeader('Location', "http://www.example.test/dummy") - - # Test X-Forwarded-For (Apache2) - self.getPage("/remoteip", - headers=[('X-Forwarded-For', '192.168.0.20')]) - self.assertBody("192.168.0.20") - self.getPage("/remoteip", - headers=[('X-Forwarded-For', '67.15.36.43, 192.168.0.20')]) - self.assertBody("192.168.0.20") - - # Test X-Host (lighttpd; see https://trac.lighttpd.net/trac/ticket/418) - self.getPage("/xhost", 
headers=[('X-Host', 'www.example.test')]) - self.assertHeader('Location', "%s://www.example.test/blah" % self.scheme) - - # Test X-Forwarded-Proto (lighttpd) - self.getPage("/base", headers=[('X-Forwarded-Proto', 'https')]) - self.assertBody("https://www.mydomain.test") - - # Test X-Forwarded-Ssl (webfaction?) - self.getPage("/ssl", headers=[('X-Forwarded-Ssl', 'on')]) - self.assertBody("https://www.mydomain.test") - - # Test cherrypy.url() - for sn in script_names: - # Test the value inside requests - self.getPage(sn + "/newurl") - self.assertBody("Browse to this page.") - self.getPage(sn + "/newurl", headers=[('X-Forwarded-Host', - 'http://www.example.test')]) - self.assertBody("Browse to this page.") - - # Test the value outside requests - port = "" - if self.scheme == "http" and self.PORT != 80: - port = ":%s" % self.PORT - elif self.scheme == "https" and self.PORT != 443: - port = ":%s" % self.PORT - host = self.HOST - if host in ('0.0.0.0', '::'): - import socket - host = socket.gethostname() - expected = ("%s://%s%s%s/this/new/page" - % (self.scheme, host, port, sn)) - self.getPage(sn + "/pageurl") - self.assertBody(expected) - - # Test trailing slash (see http://www.cherrypy.org/ticket/562). 
- self.getPage("/xhost/", headers=[('X-Host', 'www.example.test')]) - self.assertHeader('Location', "%s://www.example.test/xhost" - % self.scheme) - diff --git a/cherrypy/test/test_refleaks.py b/cherrypy/test/test_refleaks.py deleted file mode 100644 index 4df1f082..00000000 --- a/cherrypy/test/test_refleaks.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Tests for refleaks.""" - -import gc -from cherrypy._cpcompat import HTTPConnection, HTTPSConnection, ntob -import threading - -import cherrypy -from cherrypy import _cprequest - - -data = object() - -def get_instances(cls): - return [x for x in gc.get_objects() if isinstance(x, cls)] - - -from cherrypy.test import helper - - -class ReferenceTests(helper.CPWebCase): - - def setup_server(): - - class Root: - def index(self, *args, **kwargs): - cherrypy.request.thing = data - return "Hello world!" - index.exposed = True - - def gc_stats(self): - output = ["Statistics:"] - - # Uncollectable garbage - - # gc_collect isn't perfectly synchronous, because it may - # break reference cycles that then take time to fully - # finalize. Call it twice and hope for the best. - gc.collect() - unreachable = gc.collect() - if unreachable: - output.append("\n%s unreachable objects:" % unreachable) - trash = {} - for x in gc.garbage: - trash[type(x)] = trash.get(type(x), 0) + 1 - trash = [(v, k) for k, v in trash.items()] - trash.sort() - for pair in trash: - output.append(" " + repr(pair)) - - # Request references - reqs = get_instances(_cprequest.Request) - lenreqs = len(reqs) - if lenreqs < 2: - output.append("\nMissing Request reference. Should be 1 in " - "this request thread and 1 in the main thread.") - elif lenreqs > 2: - output.append("\nToo many Request references (%r)." 
% lenreqs) - for req in reqs: - output.append("Referrers for %s:" % repr(req)) - for ref in gc.get_referrers(req): - if ref is not reqs: - output.append(" %s" % repr(ref)) - - # Response references - resps = get_instances(_cprequest.Response) - lenresps = len(resps) - if lenresps < 2: - output.append("\nMissing Response reference. Should be 1 in " - "this request thread and 1 in the main thread.") - elif lenresps > 2: - output.append("\nToo many Response references (%r)." % lenresps) - for resp in resps: - output.append("Referrers for %s:" % repr(resp)) - for ref in gc.get_referrers(resp): - if ref is not resps: - output.append(" %s" % repr(ref)) - - return "\n".join(output) - gc_stats.exposed = True - - cherrypy.tree.mount(Root()) - setup_server = staticmethod(setup_server) - - - def test_threadlocal_garbage(self): - success = [] - - def getpage(): - host = '%s:%s' % (self.interface(), self.PORT) - if self.scheme == 'https': - c = HTTPSConnection(host) - else: - c = HTTPConnection(host) - try: - c.putrequest('GET', '/') - c.endheaders() - response = c.getresponse() - body = response.read() - self.assertEqual(response.status, 200) - self.assertEqual(body, ntob("Hello world!")) - finally: - c.close() - success.append(True) - - ITERATIONS = 25 - ts = [] - for _ in range(ITERATIONS): - t = threading.Thread(target=getpage) - ts.append(t) - t.start() - - for t in ts: - t.join() - - self.assertEqual(len(success), ITERATIONS) - - self.getPage("/gc_stats") - self.assertBody("Statistics:") - diff --git a/cherrypy/test/test_request_obj.py b/cherrypy/test/test_request_obj.py deleted file mode 100644 index 91ee4fd0..00000000 --- a/cherrypy/test/test_request_obj.py +++ /dev/null @@ -1,722 +0,0 @@ -"""Basic tests for the cherrypy.Request object.""" - -import os -localDir = os.path.dirname(__file__) -import sys -import types -from cherrypy._cpcompat import IncompleteRead, ntob, unicodestr - -import cherrypy -from cherrypy import _cptools, tools -from cherrypy.lib import httputil 
- -defined_http_methods = ("OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", - "TRACE", "PROPFIND") - - -# Client-side code # - -from cherrypy.test import helper - -class RequestObjectTests(helper.CPWebCase): - - def setup_server(): - class Root: - - def index(self): - return "hello" - index.exposed = True - - def scheme(self): - return cherrypy.request.scheme - scheme.exposed = True - - root = Root() - - - class TestType(type): - """Metaclass which automatically exposes all functions in each subclass, - and adds an instance of the subclass as an attribute of root. - """ - def __init__(cls, name, bases, dct): - type.__init__(cls, name, bases, dct) - for value in dct.values(): - if isinstance(value, types.FunctionType): - value.exposed = True - setattr(root, name.lower(), cls()) - class Test(object): - __metaclass__ = TestType - - - class Params(Test): - - def index(self, thing): - return repr(thing) - - def ismap(self, x, y): - return "Coordinates: %s, %s" % (x, y) - - def default(self, *args, **kwargs): - return "args: %s kwargs: %s" % (args, kwargs) - default._cp_config = {'request.query_string_encoding': 'latin1'} - - - class ParamErrorsCallable(object): - exposed = True - def __call__(self): - return "data" - - class ParamErrors(Test): - - def one_positional(self, param1): - return "data" - one_positional.exposed = True - - def one_positional_args(self, param1, *args): - return "data" - one_positional_args.exposed = True - - def one_positional_args_kwargs(self, param1, *args, **kwargs): - return "data" - one_positional_args_kwargs.exposed = True - - def one_positional_kwargs(self, param1, **kwargs): - return "data" - one_positional_kwargs.exposed = True - - def no_positional(self): - return "data" - no_positional.exposed = True - - def no_positional_args(self, *args): - return "data" - no_positional_args.exposed = True - - def no_positional_args_kwargs(self, *args, **kwargs): - return "data" - no_positional_args_kwargs.exposed = True - - def 
no_positional_kwargs(self, **kwargs): - return "data" - no_positional_kwargs.exposed = True - - callable_object = ParamErrorsCallable() - - def raise_type_error(self, **kwargs): - raise TypeError("Client Error") - raise_type_error.exposed = True - - def raise_type_error_with_default_param(self, x, y=None): - return '%d' % 'a' # throw an exception - raise_type_error_with_default_param.exposed = True - - def callable_error_page(status, **kwargs): - return "Error %s - Well, I'm very sorry but you haven't paid!" % status - - - class Error(Test): - - _cp_config = {'tools.log_tracebacks.on': True, - } - - def reason_phrase(self): - raise cherrypy.HTTPError("410 Gone fishin'") - - def custom(self, err='404'): - raise cherrypy.HTTPError(int(err), "No, really, not found!") - custom._cp_config = {'error_page.404': os.path.join(localDir, "static/index.html"), - 'error_page.401': callable_error_page, - } - - def custom_default(self): - return 1 + 'a' # raise an unexpected error - custom_default._cp_config = {'error_page.default': callable_error_page} - - def noexist(self): - raise cherrypy.HTTPError(404, "No, really, not found!") - noexist._cp_config = {'error_page.404': "nonexistent.html"} - - def page_method(self): - raise ValueError() - - def page_yield(self): - yield "howdy" - raise ValueError() - - def page_streamed(self): - yield "word up" - raise ValueError() - yield "very oops" - page_streamed._cp_config = {"response.stream": True} - - def cause_err_in_finalize(self): - # Since status must start with an int, this should error. 
- cherrypy.response.status = "ZOO OK" - cause_err_in_finalize._cp_config = {'request.show_tracebacks': False} - - def rethrow(self): - """Test that an error raised here will be thrown out to the server.""" - raise ValueError() - rethrow._cp_config = {'request.throw_errors': True} - - - class Expect(Test): - - def expectation_failed(self): - expect = cherrypy.request.headers.elements("Expect") - if expect and expect[0].value != '100-continue': - raise cherrypy.HTTPError(400) - raise cherrypy.HTTPError(417, 'Expectation Failed') - - class Headers(Test): - - def default(self, headername): - """Spit back out the value for the requested header.""" - return cherrypy.request.headers[headername] - - def doubledheaders(self): - # From http://www.cherrypy.org/ticket/165: - # "header field names should not be case sensitive sayes the rfc. - # if i set a headerfield in complete lowercase i end up with two - # header fields, one in lowercase, the other in mixed-case." - - # Set the most common headers - hMap = cherrypy.response.headers - hMap['content-type'] = "text/html" - hMap['content-length'] = 18 - hMap['server'] = 'CherryPy headertest' - hMap['location'] = ('%s://%s:%s/headers/' - % (cherrypy.request.local.ip, - cherrypy.request.local.port, - cherrypy.request.scheme)) - - # Set a rare header for fun - hMap['Expires'] = 'Thu, 01 Dec 2194 16:00:00 GMT' - - return "double header test" - - def ifmatch(self): - val = cherrypy.request.headers['If-Match'] - assert isinstance(val, unicodestr) - cherrypy.response.headers['ETag'] = val - return val - - - class HeaderElements(Test): - - def get_elements(self, headername): - e = cherrypy.request.headers.elements(headername) - return "\n".join([unicodestr(x) for x in e]) - - - class Method(Test): - - def index(self): - m = cherrypy.request.method - if m in defined_http_methods or m == "CONNECT": - return m - - if m == "LINK": - raise cherrypy.HTTPError(405) - else: - raise cherrypy.HTTPError(501) - - def parameterized(self, data): - 
return data - - def request_body(self): - # This should be a file object (temp file), - # which CP will just pipe back out if we tell it to. - return cherrypy.request.body - - def reachable(self): - return "success" - - class Divorce: - """HTTP Method handlers shouldn't collide with normal method names. - For example, a GET-handler shouldn't collide with a method named 'get'. - - If you build HTTP method dispatching into CherryPy, rewrite this class - to use your new dispatch mechanism and make sure that: - "GET /divorce HTTP/1.1" maps to divorce.index() and - "GET /divorce/get?ID=13 HTTP/1.1" maps to divorce.get() - """ - - documents = {} - - def index(self): - yield "

Choose your document

\n" - yield "
    \n" - for id, contents in self.documents.items(): - yield ("
  • %s: %s
  • \n" - % (id, id, contents)) - yield "
" - index.exposed = True - - def get(self, ID): - return ("Divorce document %s: %s" % - (ID, self.documents.get(ID, "empty"))) - get.exposed = True - - root.divorce = Divorce() - - - class ThreadLocal(Test): - - def index(self): - existing = repr(getattr(cherrypy.request, "asdf", None)) - cherrypy.request.asdf = "rassfrassin" - return existing - - appconf = { - '/method': {'request.methods_with_bodies': ("POST", "PUT", "PROPFIND")}, - } - cherrypy.tree.mount(root, config=appconf) - setup_server = staticmethod(setup_server) - - def test_scheme(self): - self.getPage("/scheme") - self.assertBody(self.scheme) - - def testParams(self): - self.getPage("/params/?thing=a") - self.assertBody("u'a'") - - self.getPage("/params/?thing=a&thing=b&thing=c") - self.assertBody("[u'a', u'b', u'c']") - - # Test friendly error message when given params are not accepted. - cherrypy.config.update({"request.show_mismatched_params": True}) - self.getPage("/params/?notathing=meeting") - self.assertInBody("Missing parameters: thing") - self.getPage("/params/?thing=meeting¬athing=meeting") - self.assertInBody("Unexpected query string parameters: notathing") - - # Test ability to turn off friendly error messages - cherrypy.config.update({"request.show_mismatched_params": False}) - self.getPage("/params/?notathing=meeting") - self.assertInBody("Not Found") - self.getPage("/params/?thing=meeting¬athing=meeting") - self.assertInBody("Not Found") - - # Test "% HEX HEX"-encoded URL, param keys, and values - self.getPage("/params/%d4%20%e3/cheese?Gruy%E8re=Bulgn%e9ville") - self.assertBody(r"args: ('\xd4 \xe3', 'cheese') " - r"kwargs: {'Gruy\xe8re': u'Bulgn\xe9ville'}") - - # Make sure that encoded = and & get parsed correctly - self.getPage("/params/code?url=http%3A//cherrypy.org/index%3Fa%3D1%26b%3D2") - self.assertBody(r"args: ('code',) " - r"kwargs: {'url': u'http://cherrypy.org/index?a=1&b=2'}") - - # Test coordinates sent by - self.getPage("/params/ismap?223,114") - 
self.assertBody("Coordinates: 223, 114") - - # Test "name[key]" dict-like params - self.getPage("/params/dictlike?a[1]=1&a[2]=2&b=foo&b[bar]=baz") - self.assertBody( - "args: ('dictlike',) " - "kwargs: {'a[1]': u'1', 'b[bar]': u'baz', 'b': u'foo', 'a[2]': u'2'}") - - def testParamErrors(self): - - # test that all of the handlers work when given - # the correct parameters in order to ensure that the - # errors below aren't coming from some other source. - for uri in ( - '/paramerrors/one_positional?param1=foo', - '/paramerrors/one_positional_args?param1=foo', - '/paramerrors/one_positional_args/foo', - '/paramerrors/one_positional_args/foo/bar/baz', - '/paramerrors/one_positional_args_kwargs?param1=foo¶m2=bar', - '/paramerrors/one_positional_args_kwargs/foo?param2=bar¶m3=baz', - '/paramerrors/one_positional_args_kwargs/foo/bar/baz?param2=bar¶m3=baz', - '/paramerrors/one_positional_kwargs?param1=foo¶m2=bar¶m3=baz', - '/paramerrors/one_positional_kwargs/foo?param4=foo¶m2=bar¶m3=baz', - '/paramerrors/no_positional', - '/paramerrors/no_positional_args/foo', - '/paramerrors/no_positional_args/foo/bar/baz', - '/paramerrors/no_positional_args_kwargs?param1=foo¶m2=bar', - '/paramerrors/no_positional_args_kwargs/foo?param2=bar', - '/paramerrors/no_positional_args_kwargs/foo/bar/baz?param2=bar¶m3=baz', - '/paramerrors/no_positional_kwargs?param1=foo¶m2=bar', - '/paramerrors/callable_object', - ): - self.getPage(uri) - self.assertStatus(200) - - # query string parameters are part of the URI, so if they are wrong - # for a particular handler, the status MUST be a 404. 
- error_msgs = [ - 'Missing parameters', - 'Nothing matches the given URI', - 'Multiple values for parameters', - 'Unexpected query string parameters', - 'Unexpected body parameters', - ] - for uri, msg in ( - ('/paramerrors/one_positional', error_msgs[0]), - ('/paramerrors/one_positional?foo=foo', error_msgs[0]), - ('/paramerrors/one_positional/foo/bar/baz', error_msgs[1]), - ('/paramerrors/one_positional/foo?param1=foo', error_msgs[2]), - ('/paramerrors/one_positional/foo?param1=foo¶m2=foo', error_msgs[2]), - ('/paramerrors/one_positional_args/foo?param1=foo¶m2=foo', error_msgs[2]), - ('/paramerrors/one_positional_args/foo/bar/baz?param2=foo', error_msgs[3]), - ('/paramerrors/one_positional_args_kwargs/foo/bar/baz?param1=bar¶m3=baz', error_msgs[2]), - ('/paramerrors/one_positional_kwargs/foo?param1=foo¶m2=bar¶m3=baz', error_msgs[2]), - ('/paramerrors/no_positional/boo', error_msgs[1]), - ('/paramerrors/no_positional?param1=foo', error_msgs[3]), - ('/paramerrors/no_positional_args/boo?param1=foo', error_msgs[3]), - ('/paramerrors/no_positional_kwargs/boo?param1=foo', error_msgs[1]), - ('/paramerrors/callable_object?param1=foo', error_msgs[3]), - ('/paramerrors/callable_object/boo', error_msgs[1]), - ): - for show_mismatched_params in (True, False): - cherrypy.config.update({'request.show_mismatched_params': show_mismatched_params}) - self.getPage(uri) - self.assertStatus(404) - if show_mismatched_params: - self.assertInBody(msg) - else: - self.assertInBody("Not Found") - - # if body parameters are wrong, a 400 must be returned. 
- for uri, body, msg in ( - ('/paramerrors/one_positional/foo', 'param1=foo', error_msgs[2]), - ('/paramerrors/one_positional/foo', 'param1=foo¶m2=foo', error_msgs[2]), - ('/paramerrors/one_positional_args/foo', 'param1=foo¶m2=foo', error_msgs[2]), - ('/paramerrors/one_positional_args/foo/bar/baz', 'param2=foo', error_msgs[4]), - ('/paramerrors/one_positional_args_kwargs/foo/bar/baz', 'param1=bar¶m3=baz', error_msgs[2]), - ('/paramerrors/one_positional_kwargs/foo', 'param1=foo¶m2=bar¶m3=baz', error_msgs[2]), - ('/paramerrors/no_positional', 'param1=foo', error_msgs[4]), - ('/paramerrors/no_positional_args/boo', 'param1=foo', error_msgs[4]), - ('/paramerrors/callable_object', 'param1=foo', error_msgs[4]), - ): - for show_mismatched_params in (True, False): - cherrypy.config.update({'request.show_mismatched_params': show_mismatched_params}) - self.getPage(uri, method='POST', body=body) - self.assertStatus(400) - if show_mismatched_params: - self.assertInBody(msg) - else: - self.assertInBody("Bad Request") - - - # even if body parameters are wrong, if we get the uri wrong, then - # it's a 404 - for uri, body, msg in ( - ('/paramerrors/one_positional?param2=foo', 'param1=foo', error_msgs[3]), - ('/paramerrors/one_positional/foo/bar', 'param2=foo', error_msgs[1]), - ('/paramerrors/one_positional_args/foo/bar?param2=foo', 'param3=foo', error_msgs[3]), - ('/paramerrors/one_positional_kwargs/foo/bar', 'param2=bar¶m3=baz', error_msgs[1]), - ('/paramerrors/no_positional?param1=foo', 'param2=foo', error_msgs[3]), - ('/paramerrors/no_positional_args/boo?param2=foo', 'param1=foo', error_msgs[3]), - ('/paramerrors/callable_object?param2=bar', 'param1=foo', error_msgs[3]), - ): - for show_mismatched_params in (True, False): - cherrypy.config.update({'request.show_mismatched_params': show_mismatched_params}) - self.getPage(uri, method='POST', body=body) - self.assertStatus(404) - if show_mismatched_params: - self.assertInBody(msg) - else: - self.assertInBody("Not Found") - - # In 
the case that a handler raises a TypeError we should - # let that type error through. - for uri in ( - '/paramerrors/raise_type_error', - '/paramerrors/raise_type_error_with_default_param?x=0', - '/paramerrors/raise_type_error_with_default_param?x=0&y=0', - ): - self.getPage(uri, method='GET') - self.assertStatus(500) - self.assertTrue('Client Error', self.body) - - def testErrorHandling(self): - self.getPage("/error/missing") - self.assertStatus(404) - self.assertErrorPage(404, "The path '/error/missing' was not found.") - - ignore = helper.webtest.ignored_exceptions - ignore.append(ValueError) - try: - valerr = '\n raise ValueError()\nValueError' - self.getPage("/error/page_method") - self.assertErrorPage(500, pattern=valerr) - - self.getPage("/error/page_yield") - self.assertErrorPage(500, pattern=valerr) - - if (cherrypy.server.protocol_version == "HTTP/1.0" or - getattr(cherrypy.server, "using_apache", False)): - self.getPage("/error/page_streamed") - # Because this error is raised after the response body has - # started, the status should not change to an error status. - self.assertStatus(200) - self.assertBody("word up") - else: - # Under HTTP/1.1, the chunked transfer-coding is used. - # The HTTP client will choke when the output is incomplete. - self.assertRaises((ValueError, IncompleteRead), self.getPage, - "/error/page_streamed") - - # No traceback should be present - self.getPage("/error/cause_err_in_finalize") - msg = "Illegal response status from server ('ZOO' is non-numeric)." - self.assertErrorPage(500, msg, None) - finally: - ignore.pop() - - # Test HTTPError with a reason-phrase in the status arg. - self.getPage('/error/reason_phrase') - self.assertStatus("410 Gone fishin'") - - # Test custom error page for a specific error. - self.getPage("/error/custom") - self.assertStatus(404) - self.assertBody("Hello, world\r\n" + (" " * 499)) - - # Test custom error page for a specific error. 
- self.getPage("/error/custom?err=401") - self.assertStatus(401) - self.assertBody("Error 401 Unauthorized - Well, I'm very sorry but you haven't paid!") - - # Test default custom error page. - self.getPage("/error/custom_default") - self.assertStatus(500) - self.assertBody("Error 500 Internal Server Error - Well, I'm very sorry but you haven't paid!".ljust(513)) - - # Test error in custom error page (ticket #305). - # Note that the message is escaped for HTML (ticket #310). - self.getPage("/error/noexist") - self.assertStatus(404) - msg = ("No, <b>really</b>, not found!
" - "In addition, the custom error page failed:\n
" - "IOError: [Errno 2] No such file or directory: 'nonexistent.html'") - self.assertInBody(msg) - - if getattr(cherrypy.server, "using_apache", False): - pass - else: - # Test throw_errors (ticket #186). - self.getPage("/error/rethrow") - self.assertInBody("raise ValueError()") - - def testExpect(self): - e = ('Expect', '100-continue') - self.getPage("/headerelements/get_elements?headername=Expect", [e]) - self.assertBody('100-continue') - - self.getPage("/expect/expectation_failed", [e]) - self.assertStatus(417) - - def testHeaderElements(self): - # Accept-* header elements should be sorted, with most preferred first. - h = [('Accept', 'audio/*; q=0.2, audio/basic')] - self.getPage("/headerelements/get_elements?headername=Accept", h) - self.assertStatus(200) - self.assertBody("audio/basic\n" - "audio/*;q=0.2") - - h = [('Accept', 'text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c')] - self.getPage("/headerelements/get_elements?headername=Accept", h) - self.assertStatus(200) - self.assertBody("text/x-c\n" - "text/html\n" - "text/x-dvi;q=0.8\n" - "text/plain;q=0.5") - - # Test that more specific media ranges get priority. 
- h = [('Accept', 'text/*, text/html, text/html;level=1, */*')] - self.getPage("/headerelements/get_elements?headername=Accept", h) - self.assertStatus(200) - self.assertBody("text/html;level=1\n" - "text/html\n" - "text/*\n" - "*/*") - - # Test Accept-Charset - h = [('Accept-Charset', 'iso-8859-5, unicode-1-1;q=0.8')] - self.getPage("/headerelements/get_elements?headername=Accept-Charset", h) - self.assertStatus("200 OK") - self.assertBody("iso-8859-5\n" - "unicode-1-1;q=0.8") - - # Test Accept-Encoding - h = [('Accept-Encoding', 'gzip;q=1.0, identity; q=0.5, *;q=0')] - self.getPage("/headerelements/get_elements?headername=Accept-Encoding", h) - self.assertStatus("200 OK") - self.assertBody("gzip;q=1.0\n" - "identity;q=0.5\n" - "*;q=0") - - # Test Accept-Language - h = [('Accept-Language', 'da, en-gb;q=0.8, en;q=0.7')] - self.getPage("/headerelements/get_elements?headername=Accept-Language", h) - self.assertStatus("200 OK") - self.assertBody("da\n" - "en-gb;q=0.8\n" - "en;q=0.7") - - # Test malformed header parsing. See http://www.cherrypy.org/ticket/763. - self.getPage("/headerelements/get_elements?headername=Content-Type", - # Note the illegal trailing ";" - headers=[('Content-Type', 'text/html; charset=utf-8;')]) - self.assertStatus(200) - self.assertBody("text/html;charset=utf-8") - - def test_repeated_headers(self): - # Test that two request headers are collapsed into one. - # See http://www.cherrypy.org/ticket/542. - self.getPage("/headers/Accept-Charset", - headers=[("Accept-Charset", "iso-8859-5"), - ("Accept-Charset", "unicode-1-1;q=0.8")]) - self.assertBody("iso-8859-5, unicode-1-1;q=0.8") - - # Tests that each header only appears once, regardless of case. 
- self.getPage("/headers/doubledheaders") - self.assertBody("double header test") - hnames = [name.title() for name, val in self.headers] - for key in ['Content-Length', 'Content-Type', 'Date', - 'Expires', 'Location', 'Server']: - self.assertEqual(hnames.count(key), 1, self.headers) - - def test_encoded_headers(self): - # First, make sure the innards work like expected. - self.assertEqual(httputil.decode_TEXT(u"=?utf-8?q?f=C3=BCr?="), u"f\xfcr") - - if cherrypy.server.protocol_version == "HTTP/1.1": - # Test RFC-2047-encoded request and response header values - u = u'\u212bngstr\xf6m' - c = u"=E2=84=ABngstr=C3=B6m" - self.getPage("/headers/ifmatch", [('If-Match', u'=?utf-8?q?%s?=' % c)]) - # The body should be utf-8 encoded. - self.assertBody("\xe2\x84\xabngstr\xc3\xb6m") - # But the Etag header should be RFC-2047 encoded (binary) - self.assertHeader("ETag", u'=?utf-8?b?4oSrbmdzdHLDtm0=?=') - - # Test a *LONG* RFC-2047-encoded request and response header value - self.getPage("/headers/ifmatch", - [('If-Match', u'=?utf-8?q?%s?=' % (c * 10))]) - self.assertBody("\xe2\x84\xabngstr\xc3\xb6m" * 10) - # Note: this is different output for Python3, but it decodes fine. 
- etag = self.assertHeader("ETag", - '=?utf-8?b?4oSrbmdzdHLDtm3ihKtuZ3N0csO2beKEq25nc3Ryw7Zt' - '4oSrbmdzdHLDtm3ihKtuZ3N0csO2beKEq25nc3Ryw7Zt' - '4oSrbmdzdHLDtm3ihKtuZ3N0csO2beKEq25nc3Ryw7Zt' - '4oSrbmdzdHLDtm0=?=') - self.assertEqual(httputil.decode_TEXT(etag), u * 10) - - def test_header_presence(self): - # If we don't pass a Content-Type header, it should not be present - # in cherrypy.request.headers - self.getPage("/headers/Content-Type", - headers=[]) - self.assertStatus(500) - - # If Content-Type is present in the request, it should be present in - # cherrypy.request.headers - self.getPage("/headers/Content-Type", - headers=[("Content-type", "application/json")]) - self.assertBody("application/json") - - def test_basic_HTTPMethods(self): - helper.webtest.methods_with_bodies = ("POST", "PUT", "PROPFIND") - - # Test that all defined HTTP methods work. - for m in defined_http_methods: - self.getPage("/method/", method=m) - - # HEAD requests should not return any body. - if m == "HEAD": - self.assertBody("") - elif m == "TRACE": - # Some HTTP servers (like modpy) have their own TRACE support - self.assertEqual(self.body[:5], ntob("TRACE")) - else: - self.assertBody(m) - - # Request a PUT method with a form-urlencoded body - self.getPage("/method/parameterized", method="PUT", - body="data=on+top+of+other+things") - self.assertBody("on top of other things") - - # Request a PUT method with a file body - b = "one thing on top of another" - h = [("Content-Type", "text/plain"), - ("Content-Length", str(len(b)))] - self.getPage("/method/request_body", headers=h, method="PUT", body=b) - self.assertStatus(200) - self.assertBody(b) - - # Request a PUT method with a file body but no Content-Type. - # See http://www.cherrypy.org/ticket/790. 
- b = ntob("one thing on top of another") - self.persistent = True - try: - conn = self.HTTP_CONN - conn.putrequest("PUT", "/method/request_body", skip_host=True) - conn.putheader("Host", self.HOST) - conn.putheader('Content-Length', str(len(b))) - conn.endheaders() - conn.send(b) - response = conn.response_class(conn.sock, method="PUT") - response.begin() - self.assertEqual(response.status, 200) - self.body = response.read() - self.assertBody(b) - finally: - self.persistent = False - - # Request a PUT method with no body whatsoever (not an empty one). - # See http://www.cherrypy.org/ticket/650. - # Provide a C-T or webtest will provide one (and a C-L) for us. - h = [("Content-Type", "text/plain")] - self.getPage("/method/reachable", headers=h, method="PUT") - self.assertStatus(411) - - # Request a custom method with a request body - b = ('\n\n' - '' - '') - h = [('Content-Type', 'text/xml'), - ('Content-Length', str(len(b)))] - self.getPage("/method/request_body", headers=h, method="PROPFIND", body=b) - self.assertStatus(200) - self.assertBody(b) - - # Request a disallowed method - self.getPage("/method/", method="LINK") - self.assertStatus(405) - - # Request an unknown method - self.getPage("/method/", method="SEARCH") - self.assertStatus(501) - - # For method dispatchers: make sure that an HTTP method doesn't - # collide with a virtual path atom. If you build HTTP-method - # dispatching into the core, rewrite these handlers to use - # your dispatch idioms. - self.getPage("/divorce/get?ID=13") - self.assertBody('Divorce document 13: empty') - self.assertStatus(200) - self.getPage("/divorce/", method="GET") - self.assertBody('

Choose your document

\n
    \n
') - self.assertStatus(200) - - def test_CONNECT_method(self): - if getattr(cherrypy.server, "using_apache", False): - return self.skip("skipped due to known Apache differences... ") - - self.getPage("/method/", method="CONNECT") - self.assertBody("CONNECT") - - def testEmptyThreadlocals(self): - results = [] - for x in range(20): - self.getPage("/threadlocal/") - results.append(self.body) - self.assertEqual(results, [ntob("None")] * 20) - diff --git a/cherrypy/test/test_routes.py b/cherrypy/test/test_routes.py deleted file mode 100644 index a8062f8f..00000000 --- a/cherrypy/test/test_routes.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) - -import cherrypy - -from cherrypy.test import helper -import nose - -class RoutesDispatchTest(helper.CPWebCase): - - def setup_server(): - - try: - import routes - except ImportError: - raise nose.SkipTest('Install routes to test RoutesDispatcher code') - - class Dummy: - def index(self): - return "I said good day!" - - class City: - - def __init__(self, name): - self.name = name - self.population = 10000 - - def index(self, **kwargs): - return "Welcome to %s, pop. 
%s" % (self.name, self.population) - index._cp_config = {'tools.response_headers.on': True, - 'tools.response_headers.headers': [('Content-Language', 'en-GB')]} - - def update(self, **kwargs): - self.population = kwargs['pop'] - return "OK" - - d = cherrypy.dispatch.RoutesDispatcher() - d.connect(action='index', name='hounslow', route='/hounslow', - controller=City('Hounslow')) - d.connect(name='surbiton', route='/surbiton', controller=City('Surbiton'), - action='index', conditions=dict(method=['GET'])) - d.mapper.connect('/surbiton', controller='surbiton', - action='update', conditions=dict(method=['POST'])) - d.connect('main', ':action', controller=Dummy()) - - conf = {'/': {'request.dispatch': d}} - cherrypy.tree.mount(root=None, config=conf) - setup_server = staticmethod(setup_server) - - def test_Routes_Dispatch(self): - self.getPage("/hounslow") - self.assertStatus("200 OK") - self.assertBody("Welcome to Hounslow, pop. 10000") - - self.getPage("/foo") - self.assertStatus("404 Not Found") - - self.getPage("/surbiton") - self.assertStatus("200 OK") - self.assertBody("Welcome to Surbiton, pop. 10000") - - self.getPage("/surbiton", method="POST", body="pop=1327") - self.assertStatus("200 OK") - self.assertBody("OK") - self.getPage("/surbiton") - self.assertStatus("200 OK") - self.assertHeader("Content-Language", "en-GB") - self.assertBody("Welcome to Surbiton, pop. 
1327") - diff --git a/cherrypy/test/test_session.py b/cherrypy/test/test_session.py deleted file mode 100755 index 874023e2..00000000 --- a/cherrypy/test/test_session.py +++ /dev/null @@ -1,464 +0,0 @@ -import os -localDir = os.path.dirname(__file__) -import sys -import threading -import time - -import cherrypy -from cherrypy._cpcompat import copykeys, HTTPConnection, HTTPSConnection -from cherrypy.lib import sessions -from cherrypy.lib.httputil import response_codes - -def http_methods_allowed(methods=['GET', 'HEAD']): - method = cherrypy.request.method.upper() - if method not in methods: - cherrypy.response.headers['Allow'] = ", ".join(methods) - raise cherrypy.HTTPError(405) - -cherrypy.tools.allow = cherrypy.Tool('on_start_resource', http_methods_allowed) - - -def setup_server(): - - class Root: - - _cp_config = {'tools.sessions.on': True, - 'tools.sessions.storage_type' : 'ram', - 'tools.sessions.storage_path' : localDir, - 'tools.sessions.timeout': (1.0 / 60), - 'tools.sessions.clean_freq': (1.0 / 60), - } - - def clear(self): - cherrypy.session.cache.clear() - clear.exposed = True - - def data(self): - cherrypy.session['aha'] = 'foo' - return repr(cherrypy.session._data) - data.exposed = True - - def testGen(self): - counter = cherrypy.session.get('counter', 0) + 1 - cherrypy.session['counter'] = counter - yield str(counter) - testGen.exposed = True - - def testStr(self): - counter = cherrypy.session.get('counter', 0) + 1 - cherrypy.session['counter'] = counter - return str(counter) - testStr.exposed = True - - def setsessiontype(self, newtype): - self.__class__._cp_config.update({'tools.sessions.storage_type': newtype}) - if hasattr(cherrypy, "session"): - del cherrypy.session - cls = getattr(sessions, newtype.title() + 'Session') - if cls.clean_thread: - cls.clean_thread.stop() - cls.clean_thread.unsubscribe() - del cls.clean_thread - setsessiontype.exposed = True - setsessiontype._cp_config = {'tools.sessions.on': False} - - def index(self): - sess = 
cherrypy.session - c = sess.get('counter', 0) + 1 - time.sleep(0.01) - sess['counter'] = c - return str(c) - index.exposed = True - - def keyin(self, key): - return str(key in cherrypy.session) - keyin.exposed = True - - def delete(self): - cherrypy.session.delete() - sessions.expire() - return "done" - delete.exposed = True - - def delkey(self, key): - del cherrypy.session[key] - return "OK" - delkey.exposed = True - - def blah(self): - return self._cp_config['tools.sessions.storage_type'] - blah.exposed = True - - def iredir(self): - raise cherrypy.InternalRedirect('/blah') - iredir.exposed = True - - def restricted(self): - return cherrypy.request.method - restricted.exposed = True - restricted._cp_config = {'tools.allow.on': True, - 'tools.allow.methods': ['GET']} - - def regen(self): - cherrypy.tools.sessions.regenerate() - return "logged in" - regen.exposed = True - - def length(self): - return str(len(cherrypy.session)) - length.exposed = True - - def session_cookie(self): - # Must load() to start the clean thread. - cherrypy.session.load() - return cherrypy.session.id - session_cookie.exposed = True - session_cookie._cp_config = { - 'tools.sessions.path': '/session_cookie', - 'tools.sessions.name': 'temp', - 'tools.sessions.persistent': False} - - cherrypy.tree.mount(Root()) - - -from cherrypy.test import helper - -class SessionTest(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def tearDown(self): - # Clean up sessions. - for fname in os.listdir(localDir): - if fname.startswith(sessions.FileSession.SESSION_PREFIX): - os.unlink(os.path.join(localDir, fname)) - - def test_0_Session(self): - self.getPage('/setsessiontype/ram') - self.getPage('/clear') - - # Test that a normal request gets the same id in the cookies. - # Note: this wouldn't work if /data didn't load the session. 
- self.getPage('/data') - self.assertBody("{'aha': 'foo'}") - c = self.cookies[0] - self.getPage('/data', self.cookies) - self.assertEqual(self.cookies[0], c) - - self.getPage('/testStr') - self.assertBody('1') - cookie_parts = dict([p.strip().split('=') - for p in self.cookies[0][1].split(";")]) - # Assert there is an 'expires' param - self.assertEqual(set(cookie_parts.keys()), - set(['session_id', 'expires', 'Path'])) - self.getPage('/testGen', self.cookies) - self.assertBody('2') - self.getPage('/testStr', self.cookies) - self.assertBody('3') - self.getPage('/data', self.cookies) - self.assertBody("{'aha': 'foo', 'counter': 3}") - self.getPage('/length', self.cookies) - self.assertBody('2') - self.getPage('/delkey?key=counter', self.cookies) - self.assertStatus(200) - - self.getPage('/setsessiontype/file') - self.getPage('/testStr') - self.assertBody('1') - self.getPage('/testGen', self.cookies) - self.assertBody('2') - self.getPage('/testStr', self.cookies) - self.assertBody('3') - self.getPage('/delkey?key=counter', self.cookies) - self.assertStatus(200) - - # Wait for the session.timeout (1 second) - time.sleep(2) - self.getPage('/') - self.assertBody('1') - self.getPage('/length', self.cookies) - self.assertBody('1') - - # Test session __contains__ - self.getPage('/keyin?key=counter', self.cookies) - self.assertBody("True") - cookieset1 = self.cookies - - # Make a new session and test __len__ again - self.getPage('/') - self.getPage('/length', self.cookies) - self.assertBody('2') - - # Test session delete - self.getPage('/delete', self.cookies) - self.assertBody("done") - self.getPage('/delete', cookieset1) - self.assertBody("done") - f = lambda: [x for x in os.listdir(localDir) if x.startswith('session-')] - self.assertEqual(f(), []) - - # Wait for the cleanup thread to delete remaining session files - self.getPage('/') - f = lambda: [x for x in os.listdir(localDir) if x.startswith('session-')] - self.assertNotEqual(f(), []) - time.sleep(2) - 
self.assertEqual(f(), []) - - def test_1_Ram_Concurrency(self): - self.getPage('/setsessiontype/ram') - self._test_Concurrency() - - def test_2_File_Concurrency(self): - self.getPage('/setsessiontype/file') - self._test_Concurrency() - - def _test_Concurrency(self): - client_thread_count = 5 - request_count = 30 - - # Get initial cookie - self.getPage("/") - self.assertBody("1") - cookies = self.cookies - - data_dict = {} - errors = [] - - def request(index): - if self.scheme == 'https': - c = HTTPSConnection('%s:%s' % (self.interface(), self.PORT)) - else: - c = HTTPConnection('%s:%s' % (self.interface(), self.PORT)) - for i in range(request_count): - c.putrequest('GET', '/') - for k, v in cookies: - c.putheader(k, v) - c.endheaders() - response = c.getresponse() - body = response.read() - if response.status != 200 or not body.isdigit(): - errors.append((response.status, body)) - else: - data_dict[index] = max(data_dict[index], int(body)) - # Uncomment the following line to prove threads overlap. -## sys.stdout.write("%d " % index) - - # Start requests from each of - # concurrent clients - ts = [] - for c in range(client_thread_count): - data_dict[c] = 0 - t = threading.Thread(target=request, args=(c,)) - ts.append(t) - t.start() - - for t in ts: - t.join() - - hitcount = max(data_dict.values()) - expected = 1 + (client_thread_count * request_count) - - for e in errors: - print(e) - self.assertEqual(hitcount, expected) - - def test_3_Redirect(self): - # Start a new session - self.getPage('/testStr') - self.getPage('/iredir', self.cookies) - self.assertBody("file") - - def test_4_File_deletion(self): - # Start a new session - self.getPage('/testStr') - # Delete the session file manually and retry. 
- id = self.cookies[0][1].split(";", 1)[0].split("=", 1)[1] - path = os.path.join(localDir, "session-" + id) - os.unlink(path) - self.getPage('/testStr', self.cookies) - - def test_5_Error_paths(self): - self.getPage('/unknown/page') - self.assertErrorPage(404, "The path '/unknown/page' was not found.") - - # Note: this path is *not* the same as above. The above - # takes a normal route through the session code; this one - # skips the session code's before_handler and only calls - # before_finalize (save) and on_end (close). So the session - # code has to survive calling save/close without init. - self.getPage('/restricted', self.cookies, method='POST') - self.assertErrorPage(405, response_codes[405]) - - def test_6_regenerate(self): - self.getPage('/testStr') - # grab the cookie ID - id1 = self.cookies[0][1].split(";", 1)[0].split("=", 1)[1] - self.getPage('/regen') - self.assertBody('logged in') - id2 = self.cookies[0][1].split(";", 1)[0].split("=", 1)[1] - self.assertNotEqual(id1, id2) - - self.getPage('/testStr') - # grab the cookie ID - id1 = self.cookies[0][1].split(";", 1)[0].split("=", 1)[1] - self.getPage('/testStr', - headers=[('Cookie', - 'session_id=maliciousid; ' - 'expires=Sat, 27 Oct 2017 04:18:28 GMT; Path=/;')]) - id2 = self.cookies[0][1].split(";", 1)[0].split("=", 1)[1] - self.assertNotEqual(id1, id2) - self.assertNotEqual(id2, 'maliciousid') - - def test_7_session_cookies(self): - self.getPage('/setsessiontype/ram') - self.getPage('/clear') - self.getPage('/session_cookie') - # grab the cookie ID - cookie_parts = dict([p.strip().split('=') for p in self.cookies[0][1].split(";")]) - # Assert there is no 'expires' param - self.assertEqual(set(cookie_parts.keys()), set(['temp', 'Path'])) - id1 = cookie_parts['temp'] - self.assertEqual(copykeys(sessions.RamSession.cache), [id1]) - - # Send another request in the same "browser session". 
- self.getPage('/session_cookie', self.cookies) - cookie_parts = dict([p.strip().split('=') for p in self.cookies[0][1].split(";")]) - # Assert there is no 'expires' param - self.assertEqual(set(cookie_parts.keys()), set(['temp', 'Path'])) - self.assertBody(id1) - self.assertEqual(copykeys(sessions.RamSession.cache), [id1]) - - # Simulate a browser close by just not sending the cookies - self.getPage('/session_cookie') - # grab the cookie ID - cookie_parts = dict([p.strip().split('=') for p in self.cookies[0][1].split(";")]) - # Assert there is no 'expires' param - self.assertEqual(set(cookie_parts.keys()), set(['temp', 'Path'])) - # Assert a new id has been generated... - id2 = cookie_parts['temp'] - self.assertNotEqual(id1, id2) - self.assertEqual(set(sessions.RamSession.cache.keys()), set([id1, id2])) - - # Wait for the session.timeout on both sessions - time.sleep(2.5) - cache = copykeys(sessions.RamSession.cache) - if cache: - if cache == [id2]: - self.fail("The second session did not time out.") - else: - self.fail("Unknown session id in cache: %r", cache) - - -import socket -try: - import memcache - - host, port = '127.0.0.1', 11211 - for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, - socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - s = None - try: - s = socket.socket(af, socktype, proto) - # See http://groups.google.com/group/cherrypy-users/ - # browse_frm/thread/bbfe5eb39c904fe0 - s.settimeout(1.0) - s.connect((host, port)) - s.close() - except socket.error: - if s: - s.close() - raise - break -except (ImportError, socket.error): - class MemcachedSessionTest(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def test(self): - return self.skip("memcached not reachable ") -else: - class MemcachedSessionTest(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def test_0_Session(self): - self.getPage('/setsessiontype/memcached') - - self.getPage('/testStr') - self.assertBody('1') - 
self.getPage('/testGen', self.cookies) - self.assertBody('2') - self.getPage('/testStr', self.cookies) - self.assertBody('3') - self.getPage('/length', self.cookies) - self.assertErrorPage(500) - self.assertInBody("NotImplementedError") - self.getPage('/delkey?key=counter', self.cookies) - self.assertStatus(200) - - # Wait for the session.timeout (1 second) - time.sleep(1.25) - self.getPage('/') - self.assertBody('1') - - # Test session __contains__ - self.getPage('/keyin?key=counter', self.cookies) - self.assertBody("True") - - # Test session delete - self.getPage('/delete', self.cookies) - self.assertBody("done") - - def test_1_Concurrency(self): - client_thread_count = 5 - request_count = 30 - - # Get initial cookie - self.getPage("/") - self.assertBody("1") - cookies = self.cookies - - data_dict = {} - - def request(index): - for i in range(request_count): - self.getPage("/", cookies) - # Uncomment the following line to prove threads overlap. -## sys.stdout.write("%d " % index) - if not self.body.isdigit(): - self.fail(self.body) - data_dict[index] = v = int(self.body) - - # Start concurrent requests from - # each of clients - ts = [] - for c in range(client_thread_count): - data_dict[c] = 0 - t = threading.Thread(target=request, args=(c,)) - ts.append(t) - t.start() - - for t in ts: - t.join() - - hitcount = max(data_dict.values()) - expected = 1 + (client_thread_count * request_count) - self.assertEqual(hitcount, expected) - - def test_3_Redirect(self): - # Start a new session - self.getPage('/testStr') - self.getPage('/iredir', self.cookies) - self.assertBody("memcached") - - def test_5_Error_paths(self): - self.getPage('/unknown/page') - self.assertErrorPage(404, "The path '/unknown/page' was not found.") - - # Note: this path is *not* the same as above. The above - # takes a normal route through the session code; this one - # skips the session code's before_handler and only calls - # before_finalize (save) and on_end (close). 
So the session - # code has to survive calling save/close without init. - self.getPage('/restricted', self.cookies, method='POST') - self.assertErrorPage(405, response_codes[405]) - diff --git a/cherrypy/test/test_sessionauthenticate.py b/cherrypy/test/test_sessionauthenticate.py deleted file mode 100644 index ab1fe51e..00000000 --- a/cherrypy/test/test_sessionauthenticate.py +++ /dev/null @@ -1,62 +0,0 @@ -import cherrypy -from cherrypy.test import helper - - -class SessionAuthenticateTest(helper.CPWebCase): - - def setup_server(): - - def check(username, password): - # Dummy check_username_and_password function - if username != 'test' or password != 'password': - return 'Wrong login/password' - - def augment_params(): - # A simple tool to add some things to request.params - # This is to check to make sure that session_auth can handle request - # params (ticket #780) - cherrypy.request.params["test"] = "test" - - cherrypy.tools.augment_params = cherrypy.Tool('before_handler', - augment_params, None, priority=30) - - class Test: - - _cp_config = {'tools.sessions.on': True, - 'tools.session_auth.on': True, - 'tools.session_auth.check_username_and_password': check, - 'tools.augment_params.on': True, - } - - def index(self, **kwargs): - return "Hi %s, you are logged in" % cherrypy.request.login - index.exposed = True - - cherrypy.tree.mount(Test()) - setup_server = staticmethod(setup_server) - - - def testSessionAuthenticate(self): - # request a page and check for login form - self.getPage('/') - self.assertInBody('
') - - # setup credentials - login_body = 'username=test&password=password&from_page=/' - - # attempt a login - self.getPage('/do_login', method='POST', body=login_body) - self.assertStatus((302, 303)) - - # get the page now that we are logged in - self.getPage('/', self.cookies) - self.assertBody('Hi test, you are logged in') - - # do a logout - self.getPage('/do_logout', self.cookies, method='POST') - self.assertStatus((302, 303)) - - # verify we are logged out - self.getPage('/', self.cookies) - self.assertInBody('') - diff --git a/cherrypy/test/test_states.py b/cherrypy/test/test_states.py deleted file mode 100644 index 0f973374..00000000 --- a/cherrypy/test/test_states.py +++ /dev/null @@ -1,436 +0,0 @@ -from cherrypy._cpcompat import BadStatusLine, ntob -import os -import sys -import threading -import time - -import cherrypy -engine = cherrypy.engine -thisdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) - - -class Dependency: - - def __init__(self, bus): - self.bus = bus - self.running = False - self.startcount = 0 - self.gracecount = 0 - self.threads = {} - - def subscribe(self): - self.bus.subscribe('start', self.start) - self.bus.subscribe('stop', self.stop) - self.bus.subscribe('graceful', self.graceful) - self.bus.subscribe('start_thread', self.startthread) - self.bus.subscribe('stop_thread', self.stopthread) - - def start(self): - self.running = True - self.startcount += 1 - - def stop(self): - self.running = False - - def graceful(self): - self.gracecount += 1 - - def startthread(self, thread_id): - self.threads[thread_id] = None - - def stopthread(self, thread_id): - del self.threads[thread_id] - -db_connection = Dependency(engine) - -def setup_server(): - class Root: - def index(self): - return "Hello World" - index.exposed = True - - def ctrlc(self): - raise KeyboardInterrupt() - ctrlc.exposed = True - - def graceful(self): - engine.graceful() - return "app was (gracefully) restarted succesfully" - graceful.exposed = True - - def 
block_explicit(self): - while True: - if cherrypy.response.timed_out: - cherrypy.response.timed_out = False - return "broken!" - time.sleep(0.01) - block_explicit.exposed = True - - def block_implicit(self): - time.sleep(0.5) - return "response.timeout = %s" % cherrypy.response.timeout - block_implicit.exposed = True - - cherrypy.tree.mount(Root()) - cherrypy.config.update({ - 'environment': 'test_suite', - 'engine.deadlock_poll_freq': 0.1, - }) - - db_connection.subscribe() - - - -# ------------ Enough helpers. Time for real live test cases. ------------ # - - -from cherrypy.test import helper - -class ServerStateTests(helper.CPWebCase): - setup_server = staticmethod(setup_server) - - def setUp(self): - cherrypy.server.socket_timeout = 0.1 - - def test_0_NormalStateFlow(self): - engine.stop() - # Our db_connection should not be running - self.assertEqual(db_connection.running, False) - self.assertEqual(db_connection.startcount, 1) - self.assertEqual(len(db_connection.threads), 0) - - # Test server start - engine.start() - self.assertEqual(engine.state, engine.states.STARTED) - - host = cherrypy.server.socket_host - port = cherrypy.server.socket_port - self.assertRaises(IOError, cherrypy._cpserver.check_port, host, port) - - # The db_connection should be running now - self.assertEqual(db_connection.running, True) - self.assertEqual(db_connection.startcount, 2) - self.assertEqual(len(db_connection.threads), 0) - - self.getPage("/") - self.assertBody("Hello World") - self.assertEqual(len(db_connection.threads), 1) - - # Test engine stop. This will also stop the HTTP server. - engine.stop() - self.assertEqual(engine.state, engine.states.STOPPED) - - # Verify that our custom stop function was called - self.assertEqual(db_connection.running, False) - self.assertEqual(len(db_connection.threads), 0) - - # Block the main thread now and verify that exit() works. 
- def exittest(): - self.getPage("/") - self.assertBody("Hello World") - engine.exit() - cherrypy.server.start() - engine.start_with_callback(exittest) - engine.block() - self.assertEqual(engine.state, engine.states.EXITING) - - def test_1_Restart(self): - cherrypy.server.start() - engine.start() - - # The db_connection should be running now - self.assertEqual(db_connection.running, True) - grace = db_connection.gracecount - - self.getPage("/") - self.assertBody("Hello World") - self.assertEqual(len(db_connection.threads), 1) - - # Test server restart from this thread - engine.graceful() - self.assertEqual(engine.state, engine.states.STARTED) - self.getPage("/") - self.assertBody("Hello World") - self.assertEqual(db_connection.running, True) - self.assertEqual(db_connection.gracecount, grace + 1) - self.assertEqual(len(db_connection.threads), 1) - - # Test server restart from inside a page handler - self.getPage("/graceful") - self.assertEqual(engine.state, engine.states.STARTED) - self.assertBody("app was (gracefully) restarted succesfully") - self.assertEqual(db_connection.running, True) - self.assertEqual(db_connection.gracecount, grace + 2) - # Since we are requesting synchronously, is only one thread used? - # Note that the "/graceful" request has been flushed. - self.assertEqual(len(db_connection.threads), 0) - - engine.stop() - self.assertEqual(engine.state, engine.states.STOPPED) - self.assertEqual(db_connection.running, False) - self.assertEqual(len(db_connection.threads), 0) - - def test_2_KeyboardInterrupt(self): - # Raise a keyboard interrupt in the HTTP server's main thread. - # We must start the server in this, the main thread - engine.start() - cherrypy.server.start() - - self.persistent = True - try: - # Make the first request and assert there's no "Connection: close". 
- self.getPage("/") - self.assertStatus('200 OK') - self.assertBody("Hello World") - self.assertNoHeader("Connection") - - cherrypy.server.httpserver.interrupt = KeyboardInterrupt - engine.block() - - self.assertEqual(db_connection.running, False) - self.assertEqual(len(db_connection.threads), 0) - self.assertEqual(engine.state, engine.states.EXITING) - finally: - self.persistent = False - - # Raise a keyboard interrupt in a page handler; on multithreaded - # servers, this should occur in one of the worker threads. - # This should raise a BadStatusLine error, since the worker - # thread will just die without writing a response. - engine.start() - cherrypy.server.start() - - try: - self.getPage("/ctrlc") - except BadStatusLine: - pass - else: - print(self.body) - self.fail("AssertionError: BadStatusLine not raised") - - engine.block() - self.assertEqual(db_connection.running, False) - self.assertEqual(len(db_connection.threads), 0) - - def test_3_Deadlocks(self): - cherrypy.config.update({'response.timeout': 0.2}) - - engine.start() - cherrypy.server.start() - try: - self.assertNotEqual(engine.timeout_monitor.thread, None) - - # Request a "normal" page. - self.assertEqual(engine.timeout_monitor.servings, []) - self.getPage("/") - self.assertBody("Hello World") - # request.close is called async. - while engine.timeout_monitor.servings: - sys.stdout.write(".") - time.sleep(0.01) - - # Request a page that explicitly checks itself for deadlock. - # The deadlock_timeout should be 2 secs. - self.getPage("/block_explicit") - self.assertBody("broken!") - - # Request a page that implicitly breaks deadlock. - # If we deadlock, we want to touch as little code as possible, - # so we won't even call handle_error, just bail ASAP. 
- self.getPage("/block_implicit") - self.assertStatus(500) - self.assertInBody("raise cherrypy.TimeoutError()") - finally: - engine.exit() - - def test_4_Autoreload(self): - # Start the demo script in a new process - p = helper.CPProcess(ssl=(self.scheme.lower()=='https')) - p.write_conf( - extra='test_case_name: "test_4_Autoreload"') - p.start(imports='cherrypy.test._test_states_demo') - try: - self.getPage("/start") - start = float(self.body) - - # Give the autoreloader time to cache the file time. - time.sleep(2) - - # Touch the file - os.utime(os.path.join(thisdir, "_test_states_demo.py"), None) - - # Give the autoreloader time to re-exec the process - time.sleep(2) - host = cherrypy.server.socket_host - port = cherrypy.server.socket_port - cherrypy._cpserver.wait_for_occupied_port(host, port) - - self.getPage("/start") - self.assert_(float(self.body) > start) - finally: - # Shut down the spawned process - self.getPage("/exit") - p.join() - - def test_5_Start_Error(self): - # If a process errors during start, it should stop the engine - # and exit with a non-zero exit code. - p = helper.CPProcess(ssl=(self.scheme.lower()=='https'), - wait=True) - p.write_conf( - extra="""starterror: True -test_case_name: "test_5_Start_Error" -""" - ) - p.start(imports='cherrypy.test._test_states_demo') - if p.exit_code == 0: - self.fail("Process failed to return nonzero exit code.") - - -class PluginTests(helper.CPWebCase): - def test_daemonize(self): - if os.name not in ['posix']: - return self.skip("skipped (not on posix) ") - self.HOST = '127.0.0.1' - self.PORT = 8081 - # Spawn the process and wait, when this returns, the original process - # is finished. If it daemonized properly, we should still be able - # to access pages. 
- p = helper.CPProcess(ssl=(self.scheme.lower()=='https'), - wait=True, daemonize=True, - socket_host='127.0.0.1', - socket_port=8081) - p.write_conf( - extra='test_case_name: "test_daemonize"') - p.start(imports='cherrypy.test._test_states_demo') - try: - # Just get the pid of the daemonization process. - self.getPage("/pid") - self.assertStatus(200) - page_pid = int(self.body) - self.assertEqual(page_pid, p.get_pid()) - finally: - # Shut down the spawned process - self.getPage("/exit") - p.join() - - # Wait until here to test the exit code because we want to ensure - # that we wait for the daemon to finish running before we fail. - if p.exit_code != 0: - self.fail("Daemonized parent process failed to exit cleanly.") - - -class SignalHandlingTests(helper.CPWebCase): - def test_SIGHUP_tty(self): - # When not daemonized, SIGHUP should shut down the server. - try: - from signal import SIGHUP - except ImportError: - return self.skip("skipped (no SIGHUP) ") - - # Spawn the process. - p = helper.CPProcess(ssl=(self.scheme.lower()=='https')) - p.write_conf( - extra='test_case_name: "test_SIGHUP_tty"') - p.start(imports='cherrypy.test._test_states_demo') - # Send a SIGHUP - os.kill(p.get_pid(), SIGHUP) - # This might hang if things aren't working right, but meh. - p.join() - - def test_SIGHUP_daemonized(self): - # When daemonized, SIGHUP should restart the server. - try: - from signal import SIGHUP - except ImportError: - return self.skip("skipped (no SIGHUP) ") - - if os.name not in ['posix']: - return self.skip("skipped (not on posix) ") - - # Spawn the process and wait, when this returns, the original process - # is finished. If it daemonized properly, we should still be able - # to access pages. 
- p = helper.CPProcess(ssl=(self.scheme.lower()=='https'), - wait=True, daemonize=True) - p.write_conf( - extra='test_case_name: "test_SIGHUP_daemonized"') - p.start(imports='cherrypy.test._test_states_demo') - - pid = p.get_pid() - try: - # Send a SIGHUP - os.kill(pid, SIGHUP) - # Give the server some time to restart - time.sleep(2) - self.getPage("/pid") - self.assertStatus(200) - new_pid = int(self.body) - self.assertNotEqual(new_pid, pid) - finally: - # Shut down the spawned process - self.getPage("/exit") - p.join() - - def test_SIGTERM(self): - # SIGTERM should shut down the server whether daemonized or not. - try: - from signal import SIGTERM - except ImportError: - return self.skip("skipped (no SIGTERM) ") - - try: - from os import kill - except ImportError: - return self.skip("skipped (no os.kill) ") - - # Spawn a normal, undaemonized process. - p = helper.CPProcess(ssl=(self.scheme.lower()=='https')) - p.write_conf( - extra='test_case_name: "test_SIGTERM"') - p.start(imports='cherrypy.test._test_states_demo') - # Send a SIGTERM - os.kill(p.get_pid(), SIGTERM) - # This might hang if things aren't working right, but meh. - p.join() - - if os.name in ['posix']: - # Spawn a daemonized process and test again. - p = helper.CPProcess(ssl=(self.scheme.lower()=='https'), - wait=True, daemonize=True) - p.write_conf( - extra='test_case_name: "test_SIGTERM_2"') - p.start(imports='cherrypy.test._test_states_demo') - # Send a SIGTERM - os.kill(p.get_pid(), SIGTERM) - # This might hang if things aren't working right, but meh. - p.join() - - def test_signal_handler_unsubscribe(self): - try: - from signal import SIGTERM - except ImportError: - return self.skip("skipped (no SIGTERM) ") - - try: - from os import kill - except ImportError: - return self.skip("skipped (no os.kill) ") - - # Spawn a normal, undaemonized process. 
- p = helper.CPProcess(ssl=(self.scheme.lower()=='https')) - p.write_conf( - extra="""unsubsig: True -test_case_name: "test_signal_handler_unsubscribe" -""") - p.start(imports='cherrypy.test._test_states_demo') - # Send a SIGTERM - os.kill(p.get_pid(), SIGTERM) - # This might hang if things aren't working right, but meh. - p.join() - - # Assert the old handler ran. - target_line = open(p.error_log, 'rb').readlines()[-10] - if not ntob("I am an old SIGTERM handler.") in target_line: - self.fail("Old SIGTERM handler did not run.\n%r" % target_line) - diff --git a/cherrypy/test/test_static.py b/cherrypy/test/test_static.py deleted file mode 100644 index 871420bd..00000000 --- a/cherrypy/test/test_static.py +++ /dev/null @@ -1,300 +0,0 @@ -from cherrypy._cpcompat import HTTPConnection, HTTPSConnection, ntob -from cherrypy._cpcompat import BytesIO - -import os -curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) -has_space_filepath = os.path.join(curdir, 'static', 'has space.html') -bigfile_filepath = os.path.join(curdir, "static", "bigfile.log") -BIGFILE_SIZE = 1024 * 1024 -import threading - -import cherrypy -from cherrypy.lib import static -from cherrypy.test import helper - - -class StaticTest(helper.CPWebCase): - - def setup_server(): - if not os.path.exists(has_space_filepath): - open(has_space_filepath, 'wb').write(ntob('Hello, world\r\n')) - if not os.path.exists(bigfile_filepath): - open(bigfile_filepath, 'wb').write(ntob("x" * BIGFILE_SIZE)) - - class Root: - - def bigfile(self): - from cherrypy.lib import static - self.f = static.serve_file(bigfile_filepath) - return self.f - bigfile.exposed = True - bigfile._cp_config = {'response.stream': True} - - def tell(self): - if self.f.input.closed: - return '' - return repr(self.f.input.tell()).rstrip('L') - tell.exposed = True - - def fileobj(self): - f = open(os.path.join(curdir, 'style.css'), 'rb') - return static.serve_fileobj(f, content_type='text/css') - fileobj.exposed = True - - def bytesio(self): 
- f = BytesIO(ntob('Fee\nfie\nfo\nfum')) - return static.serve_fileobj(f, content_type='text/plain') - bytesio.exposed = True - - class Static: - - def index(self): - return 'You want the Baron? You can have the Baron!' - index.exposed = True - - def dynamic(self): - return "This is a DYNAMIC page" - dynamic.exposed = True - - - root = Root() - root.static = Static() - - rootconf = { - '/static': { - 'tools.staticdir.on': True, - 'tools.staticdir.dir': 'static', - 'tools.staticdir.root': curdir, - }, - '/style.css': { - 'tools.staticfile.on': True, - 'tools.staticfile.filename': os.path.join(curdir, 'style.css'), - }, - '/docroot': { - 'tools.staticdir.on': True, - 'tools.staticdir.root': curdir, - 'tools.staticdir.dir': 'static', - 'tools.staticdir.index': 'index.html', - }, - '/error': { - 'tools.staticdir.on': True, - 'request.show_tracebacks': True, - }, - } - rootApp = cherrypy.Application(root) - rootApp.merge(rootconf) - - test_app_conf = { - '/test': { - 'tools.staticdir.index': 'index.html', - 'tools.staticdir.on': True, - 'tools.staticdir.root': curdir, - 'tools.staticdir.dir': 'static', - }, - } - testApp = cherrypy.Application(Static()) - testApp.merge(test_app_conf) - - vhost = cherrypy._cpwsgi.VirtualHost(rootApp, {'virt.net': testApp}) - cherrypy.tree.graft(vhost) - setup_server = staticmethod(setup_server) - - - def teardown_server(): - for f in (has_space_filepath, bigfile_filepath): - if os.path.exists(f): - try: - os.unlink(f) - except: - pass - teardown_server = staticmethod(teardown_server) - - - def testStatic(self): - self.getPage("/static/index.html") - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/html') - self.assertBody('Hello, world\r\n') - - # Using a staticdir.root value in a subdir... 
- self.getPage("/docroot/index.html") - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/html') - self.assertBody('Hello, world\r\n') - - # Check a filename with spaces in it - self.getPage("/static/has%20space.html") - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/html') - self.assertBody('Hello, world\r\n') - - self.getPage("/style.css") - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/css') - # Note: The body should be exactly 'Dummy stylesheet\n', but - # unfortunately some tools such as WinZip sometimes turn \n - # into \r\n on Windows when extracting the CherryPy tarball so - # we just check the content - self.assertMatchesBody('^Dummy stylesheet') - - def test_fallthrough(self): - # Test that NotFound will then try dynamic handlers (see [878]). - self.getPage("/static/dynamic") - self.assertBody("This is a DYNAMIC page") - - # Check a directory via fall-through to dynamic handler. - self.getPage("/static/") - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/html;charset=utf-8') - self.assertBody('You want the Baron? You can have the Baron!') - - def test_index(self): - # Check a directory via "staticdir.index". - self.getPage("/docroot/") - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/html') - self.assertBody('Hello, world\r\n') - # The same page should be returned even if redirected. - self.getPage("/docroot") - self.assertStatus(301) - self.assertHeader('Location', '%s/docroot/' % self.base()) - self.assertMatchesBody("This resource .* " - "%s/docroot/." 
% (self.base(), self.base())) - - def test_config_errors(self): - # Check that we get an error if no .file or .dir - self.getPage("/error/thing.html") - self.assertErrorPage(500) - self.assertMatchesBody(ntob("TypeError: staticdir\(\) takes at least 2 " - "(positional )?arguments \(0 given\)")) - - def test_security(self): - # Test up-level security - self.getPage("/static/../../test/style.css") - self.assertStatus((400, 403)) - - def test_modif(self): - # Test modified-since on a reasonably-large file - self.getPage("/static/dirback.jpg") - self.assertStatus("200 OK") - lastmod = "" - for k, v in self.headers: - if k == 'Last-Modified': - lastmod = v - ims = ("If-Modified-Since", lastmod) - self.getPage("/static/dirback.jpg", headers=[ims]) - self.assertStatus(304) - self.assertNoHeader("Content-Type") - self.assertNoHeader("Content-Length") - self.assertNoHeader("Content-Disposition") - self.assertBody("") - - def test_755_vhost(self): - self.getPage("/test/", [('Host', 'virt.net')]) - self.assertStatus(200) - self.getPage("/test", [('Host', 'virt.net')]) - self.assertStatus(301) - self.assertHeader('Location', self.scheme + '://virt.net/test/') - - def test_serve_fileobj(self): - self.getPage("/fileobj") - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/css;charset=utf-8') - self.assertMatchesBody('^Dummy stylesheet') - - def test_serve_bytesio(self): - self.getPage("/bytesio") - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/plain;charset=utf-8') - self.assertHeader('Content-Length', 14) - self.assertMatchesBody('Fee\nfie\nfo\nfum') - - def test_file_stream(self): - if cherrypy.server.protocol_version != "HTTP/1.1": - return self.skip() - - self.PROTOCOL = "HTTP/1.1" - - # Make an initial request - self.persistent = True - conn = self.HTTP_CONN - conn.putrequest("GET", "/bigfile", skip_host=True) - conn.putheader("Host", self.HOST) - conn.endheaders() - response = conn.response_class(conn.sock, method="GET") - 
response.begin() - self.assertEqual(response.status, 200) - - body = ntob('') - remaining = BIGFILE_SIZE - while remaining > 0: - data = response.fp.read(65536) - if not data: - break - body += data - remaining -= len(data) - - if self.scheme == "https": - newconn = HTTPSConnection - else: - newconn = HTTPConnection - s, h, b = helper.webtest.openURL( - ntob("/tell"), headers=[], host=self.HOST, port=self.PORT, - http_conn=newconn) - if not b: - # The file was closed on the server. - tell_position = BIGFILE_SIZE - else: - tell_position = int(b) - - expected = len(body) - if tell_position >= BIGFILE_SIZE: - # We can't exactly control how much content the server asks for. - # Fudge it by only checking the first half of the reads. - if expected < (BIGFILE_SIZE / 2): - self.fail( - "The file should have advanced to position %r, but has " - "already advanced to the end of the file. It may not be " - "streamed as intended, or at the wrong chunk size (64k)" % - expected) - elif tell_position < expected: - self.fail( - "The file should have advanced to position %r, but has " - "only advanced to position %r. It may not be streamed " - "as intended, or at the wrong chunk size (65536)" % - (expected, tell_position)) - - if body != ntob("x" * BIGFILE_SIZE): - self.fail("Body != 'x' * %d. Got %r instead (%d bytes)." % - (BIGFILE_SIZE, body[:50], len(body))) - conn.close() - - def test_file_stream_deadlock(self): - if cherrypy.server.protocol_version != "HTTP/1.1": - return self.skip() - - self.PROTOCOL = "HTTP/1.1" - - # Make an initial request but abort early. - self.persistent = True - conn = self.HTTP_CONN - conn.putrequest("GET", "/bigfile", skip_host=True) - conn.putheader("Host", self.HOST) - conn.endheaders() - response = conn.response_class(conn.sock, method="GET") - response.begin() - self.assertEqual(response.status, 200) - body = response.fp.read(65536) - if body != ntob("x" * len(body)): - self.fail("Body != 'x' * %d. Got %r instead (%d bytes)." 
% - (65536, body[:50], len(body))) - response.close() - conn.close() - - # Make a second request, which should fetch the whole file. - self.persistent = False - self.getPage("/bigfile") - if self.body != ntob("x" * BIGFILE_SIZE): - self.fail("Body != 'x' * %d. Got %r instead (%d bytes)." % - (BIGFILE_SIZE, self.body[:50], len(body))) - diff --git a/cherrypy/test/test_tools.py b/cherrypy/test/test_tools.py deleted file mode 100644 index bc8579f0..00000000 --- a/cherrypy/test/test_tools.py +++ /dev/null @@ -1,393 +0,0 @@ -"""Test the various means of instantiating and invoking tools.""" - -import gzip -import sys -from cherrypy._cpcompat import BytesIO, copyitems, itervalues, IncompleteRead, ntob, ntou, xrange -import time -timeout = 0.2 -import types - -import cherrypy -from cherrypy import tools - - -europoundUnicode = ntou('\x80\xa3') - - -# Client-side code # - -from cherrypy.test import helper - - -class ToolTests(helper.CPWebCase): - def setup_server(): - - # Put check_access in a custom toolbox with its own namespace - myauthtools = cherrypy._cptools.Toolbox("myauth") - - def check_access(default=False): - if not getattr(cherrypy.request, "userid", default): - raise cherrypy.HTTPError(401) - myauthtools.check_access = cherrypy.Tool('before_request_body', check_access) - - def numerify(): - def number_it(body): - for chunk in body: - for k, v in cherrypy.request.numerify_map: - chunk = chunk.replace(k, v) - yield chunk - cherrypy.response.body = number_it(cherrypy.response.body) - - class NumTool(cherrypy.Tool): - def _setup(self): - def makemap(): - m = self._merged_args().get("map", {}) - cherrypy.request.numerify_map = copyitems(m) - cherrypy.request.hooks.attach('on_start_resource', makemap) - - def critical(): - cherrypy.request.error_response = cherrypy.HTTPError(502).set_response - critical.failsafe = True - - cherrypy.request.hooks.attach('on_start_resource', critical) - cherrypy.request.hooks.attach(self._point, self.callable) - - tools.numerify = 
NumTool('before_finalize', numerify) - - # It's not mandatory to inherit from cherrypy.Tool. - class NadsatTool: - - def __init__(self): - self.ended = {} - self._name = "nadsat" - - def nadsat(self): - def nadsat_it_up(body): - for chunk in body: - chunk = chunk.replace(ntob("good"), ntob("horrorshow")) - chunk = chunk.replace(ntob("piece"), ntob("lomtick")) - yield chunk - cherrypy.response.body = nadsat_it_up(cherrypy.response.body) - nadsat.priority = 0 - - def cleanup(self): - # This runs after the request has been completely written out. - cherrypy.response.body = [ntob("razdrez")] - id = cherrypy.request.params.get("id") - if id: - self.ended[id] = True - cleanup.failsafe = True - - def _setup(self): - cherrypy.request.hooks.attach('before_finalize', self.nadsat) - cherrypy.request.hooks.attach('on_end_request', self.cleanup) - tools.nadsat = NadsatTool() - - def pipe_body(): - cherrypy.request.process_request_body = False - clen = int(cherrypy.request.headers['Content-Length']) - cherrypy.request.body = cherrypy.request.rfile.read(clen) - - # Assert that we can use a callable object instead of a function. - class Rotator(object): - def __call__(self, scale): - r = cherrypy.response - r.collapse_body() - r.body = [chr((ord(x) + scale) % 256) for x in r.body[0]] - cherrypy.tools.rotator = cherrypy.Tool('before_finalize', Rotator()) - - def stream_handler(next_handler, *args, **kwargs): - cherrypy.response.output = o = BytesIO() - try: - response = next_handler(*args, **kwargs) - # Ignore the response and return our accumulated output instead. - return o.getvalue() - finally: - o.close() - cherrypy.tools.streamer = cherrypy._cptools.HandlerWrapperTool(stream_handler) - - class Root: - def index(self): - return "Howdy earth!" 
- index.exposed = True - - def tarfile(self): - cherrypy.response.output.write(ntob('I am ')) - cherrypy.response.output.write(ntob('a tarfile')) - tarfile.exposed = True - tarfile._cp_config = {'tools.streamer.on': True} - - def euro(self): - hooks = list(cherrypy.request.hooks['before_finalize']) - hooks.sort() - cbnames = [x.callback.__name__ for x in hooks] - assert cbnames == ['gzip'], cbnames - priorities = [x.priority for x in hooks] - assert priorities == [80], priorities - yield ntou("Hello,") - yield ntou("world") - yield europoundUnicode - euro.exposed = True - - # Bare hooks - def pipe(self): - return cherrypy.request.body - pipe.exposed = True - pipe._cp_config = {'hooks.before_request_body': pipe_body} - - # Multiple decorators; include kwargs just for fun. - # Note that rotator must run before gzip. - def decorated_euro(self, *vpath): - yield ntou("Hello,") - yield ntou("world") - yield europoundUnicode - decorated_euro.exposed = True - decorated_euro = tools.gzip(compress_level=6)(decorated_euro) - decorated_euro = tools.rotator(scale=3)(decorated_euro) - - root = Root() - - - class TestType(type): - """Metaclass which automatically exposes all functions in each subclass, - and adds an instance of the subclass as an attribute of root. 
- """ - def __init__(cls, name, bases, dct): - type.__init__(cls, name, bases, dct) - for value in itervalues(dct): - if isinstance(value, types.FunctionType): - value.exposed = True - setattr(root, name.lower(), cls()) - class Test(object): - __metaclass__ = TestType - - - # METHOD ONE: - # Declare Tools in _cp_config - class Demo(Test): - - _cp_config = {"tools.nadsat.on": True} - - def index(self, id=None): - return "A good piece of cherry pie" - - def ended(self, id): - return repr(tools.nadsat.ended[id]) - - def err(self, id=None): - raise ValueError() - - def errinstream(self, id=None): - yield "nonconfidential" - raise ValueError() - yield "confidential" - - # METHOD TWO: decorator using Tool() - # We support Python 2.3, but the @-deco syntax would look like this: - # @tools.check_access() - def restricted(self): - return "Welcome!" - restricted = myauthtools.check_access()(restricted) - userid = restricted - - def err_in_onstart(self): - return "success!" - - def stream(self, id=None): - for x in xrange(100000000): - yield str(x) - stream._cp_config = {'response.stream': True} - - - conf = { - # METHOD THREE: - # Declare Tools in detached config - '/demo': { - 'tools.numerify.on': True, - 'tools.numerify.map': {ntob("pie"): ntob("3.14159")}, - }, - '/demo/restricted': { - 'request.show_tracebacks': False, - }, - '/demo/userid': { - 'request.show_tracebacks': False, - 'myauth.check_access.default': True, - }, - '/demo/errinstream': { - 'response.stream': True, - }, - '/demo/err_in_onstart': { - # Because this isn't a dict, on_start_resource will error. 
- 'tools.numerify.map': "pie->3.14159" - }, - # Combined tools - '/euro': { - 'tools.gzip.on': True, - 'tools.encode.on': True, - }, - # Priority specified in config - '/decorated_euro/subpath': { - 'tools.gzip.priority': 10, - }, - # Handler wrappers - '/tarfile': {'tools.streamer.on': True} - } - app = cherrypy.tree.mount(root, config=conf) - app.request_class.namespaces['myauth'] = myauthtools - - if sys.version_info >= (2, 5): - from cherrypy.test import _test_decorators - root.tooldecs = _test_decorators.ToolExamples() - setup_server = staticmethod(setup_server) - - def testHookErrors(self): - self.getPage("/demo/?id=1") - # If body is "razdrez", then on_end_request is being called too early. - self.assertBody("A horrorshow lomtick of cherry 3.14159") - # If this fails, then on_end_request isn't being called at all. - time.sleep(0.1) - self.getPage("/demo/ended/1") - self.assertBody("True") - - valerr = '\n raise ValueError()\nValueError' - self.getPage("/demo/err?id=3") - # If body is "razdrez", then on_end_request is being called too early. - self.assertErrorPage(502, pattern=valerr) - # If this fails, then on_end_request isn't being called at all. - time.sleep(0.1) - self.getPage("/demo/ended/3") - self.assertBody("True") - - # If body is "razdrez", then on_end_request is being called too early. - if (cherrypy.server.protocol_version == "HTTP/1.0" or - getattr(cherrypy.server, "using_apache", False)): - self.getPage("/demo/errinstream?id=5") - # Because this error is raised after the response body has - # started, the status should not change to an error status. - self.assertStatus("200 OK") - self.assertBody("nonconfidential") - else: - # Because this error is raised after the response body has - # started, and because it's chunked output, an error is raised by - # the HTTP client when it encounters incomplete output. 
- self.assertRaises((ValueError, IncompleteRead), self.getPage, - "/demo/errinstream?id=5") - # If this fails, then on_end_request isn't being called at all. - time.sleep(0.1) - self.getPage("/demo/ended/5") - self.assertBody("True") - - # Test the "__call__" technique (compile-time decorator). - self.getPage("/demo/restricted") - self.assertErrorPage(401) - - # Test compile-time decorator with kwargs from config. - self.getPage("/demo/userid") - self.assertBody("Welcome!") - - def testEndRequestOnDrop(self): - old_timeout = None - try: - httpserver = cherrypy.server.httpserver - old_timeout = httpserver.timeout - except (AttributeError, IndexError): - return self.skip() - - try: - httpserver.timeout = timeout - - # Test that on_end_request is called even if the client drops. - self.persistent = True - try: - conn = self.HTTP_CONN - conn.putrequest("GET", "/demo/stream?id=9", skip_host=True) - conn.putheader("Host", self.HOST) - conn.endheaders() - # Skip the rest of the request and close the conn. This will - # cause the server's active socket to error, which *should* - # result in the request being aborted, and request.close being - # called all the way up the stack (including WSGI middleware), - # eventually calling our on_end_request hook. - finally: - self.persistent = False - time.sleep(timeout * 2) - # Test that the on_end_request hook was called. - self.getPage("/demo/ended/9") - self.assertBody("True") - finally: - if old_timeout is not None: - httpserver.timeout = old_timeout - - def testGuaranteedHooks(self): - # The 'critical' on_start_resource hook is 'failsafe' (guaranteed - # to run even if there are failures in other on_start methods). - # This is NOT true of the other hooks. - # Here, we have set up a failure in NumerifyTool.numerify_map, - # but our 'critical' hook should run and set the error to 502. 
- self.getPage("/demo/err_in_onstart") - self.assertErrorPage(502) - self.assertInBody("AttributeError: 'str' object has no attribute 'items'") - - def testCombinedTools(self): - expectedResult = (ntou("Hello,world") + europoundUnicode).encode('utf-8') - zbuf = BytesIO() - zfile = gzip.GzipFile(mode='wb', fileobj=zbuf, compresslevel=9) - zfile.write(expectedResult) - zfile.close() - - self.getPage("/euro", headers=[("Accept-Encoding", "gzip"), - ("Accept-Charset", "ISO-8859-1,utf-8;q=0.7,*;q=0.7")]) - self.assertInBody(zbuf.getvalue()[:3]) - - zbuf = BytesIO() - zfile = gzip.GzipFile(mode='wb', fileobj=zbuf, compresslevel=6) - zfile.write(expectedResult) - zfile.close() - - self.getPage("/decorated_euro", headers=[("Accept-Encoding", "gzip")]) - self.assertInBody(zbuf.getvalue()[:3]) - - # This returns a different value because gzip's priority was - # lowered in conf, allowing the rotator to run after gzip. - # Of course, we don't want breakage in production apps, - # but it proves the priority was changed. 
- self.getPage("/decorated_euro/subpath", - headers=[("Accept-Encoding", "gzip")]) - self.assertInBody(''.join([chr((ord(x) + 3) % 256) for x in zbuf.getvalue()])) - - def testBareHooks(self): - content = "bit of a pain in me gulliver" - self.getPage("/pipe", - headers=[("Content-Length", str(len(content))), - ("Content-Type", "text/plain")], - method="POST", body=content) - self.assertBody(content) - - def testHandlerWrapperTool(self): - self.getPage("/tarfile") - self.assertBody("I am a tarfile") - - def testToolWithConfig(self): - if not sys.version_info >= (2, 5): - return self.skip("skipped (Python 2.5+ only)") - - self.getPage('/tooldecs/blah') - self.assertHeader('Content-Type', 'application/data') - - def testWarnToolOn(self): - # get - try: - numon = cherrypy.tools.numerify.on - except AttributeError: - pass - else: - raise AssertionError("Tool.on did not error as it should have.") - - # set - try: - cherrypy.tools.numerify.on = True - except AttributeError: - pass - else: - raise AssertionError("Tool.on did not error as it should have.") - diff --git a/cherrypy/test/test_tutorials.py b/cherrypy/test/test_tutorials.py deleted file mode 100644 index aab27861..00000000 --- a/cherrypy/test/test_tutorials.py +++ /dev/null @@ -1,201 +0,0 @@ -import sys - -import cherrypy -from cherrypy.test import helper - - -class TutorialTest(helper.CPWebCase): - - def setup_server(cls): - - conf = cherrypy.config.copy() - - def load_tut_module(name): - """Import or reload tutorial module as needed.""" - cherrypy.config.reset() - cherrypy.config.update(conf) - - target = "cherrypy.tutorial." + name - if target in sys.modules: - module = reload(sys.modules[target]) - else: - module = __import__(target, globals(), locals(), ['']) - # The above import will probably mount a new app at "". 
- app = cherrypy.tree.apps[""] - - app.root.load_tut_module = load_tut_module - app.root.sessions = sessions - app.root.traceback_setting = traceback_setting - - cls.supervisor.sync_apps() - load_tut_module.exposed = True - - def sessions(): - cherrypy.config.update({"tools.sessions.on": True}) - sessions.exposed = True - - def traceback_setting(): - return repr(cherrypy.request.show_tracebacks) - traceback_setting.exposed = True - - class Dummy: - pass - root = Dummy() - root.load_tut_module = load_tut_module - cherrypy.tree.mount(root) - setup_server = classmethod(setup_server) - - - def test01HelloWorld(self): - self.getPage("/load_tut_module/tut01_helloworld") - self.getPage("/") - self.assertBody('Hello world!') - - def test02ExposeMethods(self): - self.getPage("/load_tut_module/tut02_expose_methods") - self.getPage("/showMessage") - self.assertBody('Hello world!') - - def test03GetAndPost(self): - self.getPage("/load_tut_module/tut03_get_and_post") - - # Try different GET queries - self.getPage("/greetUser?name=Bob") - self.assertBody("Hey Bob, what's up?") - - self.getPage("/greetUser") - self.assertBody('Please enter your name here.') - - self.getPage("/greetUser?name=") - self.assertBody('No, really, enter your name here.') - - # Try the same with POST - self.getPage("/greetUser", method="POST", body="name=Bob") - self.assertBody("Hey Bob, what's up?") - - self.getPage("/greetUser", method="POST", body="name=") - self.assertBody('No, really, enter your name here.') - - def test04ComplexSite(self): - self.getPage("/load_tut_module/tut04_complex_site") - msg = ''' -

Here are some extra useful links:

- - - -

[Return to links page]

''' - self.getPage("/links/extra/") - self.assertBody(msg) - - def test05DerivedObjects(self): - self.getPage("/load_tut_module/tut05_derived_objects") - msg = ''' - - - Another Page - - -

Another Page

- -

- And this is the amazing second page! -

- - - - ''' - self.getPage("/another/") - self.assertBody(msg) - - def test06DefaultMethod(self): - self.getPage("/load_tut_module/tut06_default_method") - self.getPage('/hendrik') - self.assertBody('Hendrik Mans, CherryPy co-developer & crazy German ' - '(back)') - - def test07Sessions(self): - self.getPage("/load_tut_module/tut07_sessions") - self.getPage("/sessions") - - self.getPage('/') - self.assertBody("\n During your current session, you've viewed this" - "\n page 1 times! Your life is a patio of fun!" - "\n ") - - self.getPage('/', self.cookies) - self.assertBody("\n During your current session, you've viewed this" - "\n page 2 times! Your life is a patio of fun!" - "\n ") - - def test08GeneratorsAndYield(self): - self.getPage("/load_tut_module/tut08_generators_and_yield") - self.getPage('/') - self.assertBody('

Generators rule!

' - '

List of users:

' - 'Remi
Carlos
Hendrik
Lorenzo Lamas
' - '') - - def test09Files(self): - self.getPage("/load_tut_module/tut09_files") - - # Test upload - filesize = 5 - h = [("Content-type", "multipart/form-data; boundary=x"), - ("Content-Length", str(105 + filesize))] - b = '--x\n' + \ - 'Content-Disposition: form-data; name="myFile"; filename="hello.txt"\r\n' + \ - 'Content-Type: text/plain\r\n' + \ - '\r\n' + \ - 'a' * filesize + '\n' + \ - '--x--\n' - self.getPage('/upload', h, "POST", b) - self.assertBody(''' - - myFile length: %d
- myFile filename: hello.txt
- myFile mime-type: text/plain - - ''' % filesize) - - # Test download - self.getPage('/download') - self.assertStatus("200 OK") - self.assertHeader("Content-Type", "application/x-download") - self.assertHeader("Content-Disposition", - # Make sure the filename is quoted. - 'attachment; filename="pdf_file.pdf"') - self.assertEqual(len(self.body), 85698) - - def test10HTTPErrors(self): - self.getPage("/load_tut_module/tut10_http_errors") - - self.getPage("/") - self.assertInBody("""""") - self.assertInBody("""""") - self.assertInBody("""""") - self.assertInBody("""""") - self.assertInBody("""""") - - self.getPage("/traceback_setting") - setting = self.body - self.getPage("/toggleTracebacks") - self.assertStatus((302, 303)) - self.getPage("/traceback_setting") - self.assertBody(str(not eval(setting))) - - self.getPage("/error?code=500") - self.assertStatus(500) - self.assertInBody("The server encountered an unexpected condition " - "which prevented it from fulfilling the request.") - - self.getPage("/error?code=403") - self.assertStatus(403) - self.assertInBody("

You can't do that!

") - - self.getPage("/messageArg") - self.assertStatus(500) - self.assertInBody("If you construct an HTTPError with a 'message'") - diff --git a/cherrypy/test/test_virtualhost.py b/cherrypy/test/test_virtualhost.py deleted file mode 100644 index d6eed0ea..00000000 --- a/cherrypy/test/test_virtualhost.py +++ /dev/null @@ -1,107 +0,0 @@ -import os -curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) - -import cherrypy -from cherrypy.test import helper - - -class VirtualHostTest(helper.CPWebCase): - - def setup_server(): - class Root: - def index(self): - return "Hello, world" - index.exposed = True - - def dom4(self): - return "Under construction" - dom4.exposed = True - - def method(self, value): - return "You sent %s" % repr(value) - method.exposed = True - - class VHost: - def __init__(self, sitename): - self.sitename = sitename - - def index(self): - return "Welcome to %s" % self.sitename - index.exposed = True - - def vmethod(self, value): - return "You sent %s" % repr(value) - vmethod.exposed = True - - def url(self): - return cherrypy.url("nextpage") - url.exposed = True - - # Test static as a handler (section must NOT include vhost prefix) - static = cherrypy.tools.staticdir.handler(section='/static', dir=curdir) - - root = Root() - root.mydom2 = VHost("Domain 2") - root.mydom3 = VHost("Domain 3") - hostmap = {'www.mydom2.com': '/mydom2', - 'www.mydom3.com': '/mydom3', - 'www.mydom4.com': '/dom4', - } - cherrypy.tree.mount(root, config={ - '/': {'request.dispatch': cherrypy.dispatch.VirtualHost(**hostmap)}, - # Test static in config (section must include vhost prefix) - '/mydom2/static2': {'tools.staticdir.on': True, - 'tools.staticdir.root': curdir, - 'tools.staticdir.dir': 'static', - 'tools.staticdir.index': 'index.html', - }, - }) - setup_server = staticmethod(setup_server) - - def testVirtualHost(self): - self.getPage("/", [('Host', 'www.mydom1.com')]) - self.assertBody('Hello, world') - self.getPage("/mydom2/", [('Host', 'www.mydom1.com')]) - 
self.assertBody('Welcome to Domain 2') - - self.getPage("/", [('Host', 'www.mydom2.com')]) - self.assertBody('Welcome to Domain 2') - self.getPage("/", [('Host', 'www.mydom3.com')]) - self.assertBody('Welcome to Domain 3') - self.getPage("/", [('Host', 'www.mydom4.com')]) - self.assertBody('Under construction') - - # Test GET, POST, and positional params - self.getPage("/method?value=root") - self.assertBody("You sent u'root'") - self.getPage("/vmethod?value=dom2+GET", [('Host', 'www.mydom2.com')]) - self.assertBody("You sent u'dom2 GET'") - self.getPage("/vmethod", [('Host', 'www.mydom3.com')], method="POST", - body="value=dom3+POST") - self.assertBody("You sent u'dom3 POST'") - self.getPage("/vmethod/pos", [('Host', 'www.mydom3.com')]) - self.assertBody("You sent 'pos'") - - # Test that cherrypy.url uses the browser url, not the virtual url - self.getPage("/url", [('Host', 'www.mydom2.com')]) - self.assertBody("%s://www.mydom2.com/nextpage" % self.scheme) - - def test_VHost_plus_Static(self): - # Test static as a handler - self.getPage("/static/style.css", [('Host', 'www.mydom2.com')]) - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'text/css;charset=utf-8') - - # Test static in config - self.getPage("/static2/dirback.jpg", [('Host', 'www.mydom2.com')]) - self.assertStatus('200 OK') - self.assertHeader('Content-Type', 'image/jpeg') - - # Test static config with "index" arg - self.getPage("/static2/", [('Host', 'www.mydom2.com')]) - self.assertStatus('200 OK') - self.assertBody('Hello, world\r\n') - # Since tools.trailing_slash is on by default, this should redirect - self.getPage("/static2", [('Host', 'www.mydom2.com')]) - self.assertStatus(301) - diff --git a/cherrypy/test/test_wsgi_ns.py b/cherrypy/test/test_wsgi_ns.py deleted file mode 100644 index d57013c3..00000000 --- a/cherrypy/test/test_wsgi_ns.py +++ /dev/null @@ -1,80 +0,0 @@ -import cherrypy -from cherrypy.test import helper - - -class WSGI_Namespace_Test(helper.CPWebCase): - - def 
setup_server(): - - class WSGIResponse(object): - - def __init__(self, appresults): - self.appresults = appresults - self.iter = iter(appresults) - - def __iter__(self): - return self - - def next(self): - return self.iter.next() - - def close(self): - if hasattr(self.appresults, "close"): - self.appresults.close() - - - class ChangeCase(object): - - def __init__(self, app, to=None): - self.app = app - self.to = to - - def __call__(self, environ, start_response): - res = self.app(environ, start_response) - class CaseResults(WSGIResponse): - def next(this): - return getattr(this.iter.next(), self.to)() - return CaseResults(res) - - class Replacer(object): - - def __init__(self, app, map={}): - self.app = app - self.map = map - - def __call__(self, environ, start_response): - res = self.app(environ, start_response) - class ReplaceResults(WSGIResponse): - def next(this): - line = this.iter.next() - for k, v in self.map.iteritems(): - line = line.replace(k, v) - return line - return ReplaceResults(res) - - class Root(object): - - def index(self): - return "HellO WoRlD!" - index.exposed = True - - - root_conf = {'wsgi.pipeline': [('replace', Replacer)], - 'wsgi.replace.map': {'L': 'X', 'l': 'r'}, - } - - app = cherrypy.Application(Root()) - app.wsgiapp.pipeline.append(('changecase', ChangeCase)) - app.wsgiapp.config['changecase'] = {'to': 'upper'} - cherrypy.tree.mount(app, config={'/': root_conf}) - setup_server = staticmethod(setup_server) - - - def test_pipeline(self): - if not cherrypy.server.httpserver: - return self.skip() - - self.getPage("/") - # If body is "HEXXO WORXD!", the middleware was applied out of order. 
- self.assertBody("HERRO WORRD!") - diff --git a/cherrypy/test/test_wsgi_vhost.py b/cherrypy/test/test_wsgi_vhost.py deleted file mode 100644 index abb1a917..00000000 --- a/cherrypy/test/test_wsgi_vhost.py +++ /dev/null @@ -1,36 +0,0 @@ -import cherrypy -from cherrypy.test import helper - - -class WSGI_VirtualHost_Test(helper.CPWebCase): - - def setup_server(): - - class ClassOfRoot(object): - - def __init__(self, name): - self.name = name - - def index(self): - return "Welcome to the %s website!" % self.name - index.exposed = True - - - default = cherrypy.Application(None) - - domains = {} - for year in range(1997, 2008): - app = cherrypy.Application(ClassOfRoot('Class of %s' % year)) - domains['www.classof%s.example' % year] = app - - cherrypy.tree.graft(cherrypy._cpwsgi.VirtualHost(default, domains)) - setup_server = staticmethod(setup_server) - - def test_welcome(self): - if not cherrypy.server.using_wsgi: - return self.skip("skipped (not using WSGI)... ") - - for year in range(1997, 2008): - self.getPage("/", headers=[('Host', 'www.classof%s.example' % year)]) - self.assertBody("Welcome to the Class of %s website!" 
% year) - diff --git a/cherrypy/test/test_wsgiapps.py b/cherrypy/test/test_wsgiapps.py deleted file mode 100644 index fa5420c5..00000000 --- a/cherrypy/test/test_wsgiapps.py +++ /dev/null @@ -1,111 +0,0 @@ -from cherrypy.test import helper - - -class WSGIGraftTests(helper.CPWebCase): - - def setup_server(): - import os - curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) - - import cherrypy - - def test_app(environ, start_response): - status = '200 OK' - response_headers = [('Content-type', 'text/plain')] - start_response(status, response_headers) - output = ['Hello, world!\n', - 'This is a wsgi app running within CherryPy!\n\n'] - keys = list(environ.keys()) - keys.sort() - for k in keys: - output.append('%s: %s\n' % (k,environ[k])) - return output - - def test_empty_string_app(environ, start_response): - status = '200 OK' - response_headers = [('Content-type', 'text/plain')] - start_response(status, response_headers) - return ['Hello', '', ' ', '', 'world'] - - - class WSGIResponse(object): - - def __init__(self, appresults): - self.appresults = appresults - self.iter = iter(appresults) - - def __iter__(self): - return self - - def next(self): - return self.iter.next() - - def close(self): - if hasattr(self.appresults, "close"): - self.appresults.close() - - - class ReversingMiddleware(object): - - def __init__(self, app): - self.app = app - - def __call__(self, environ, start_response): - results = app(environ, start_response) - class Reverser(WSGIResponse): - def next(this): - line = list(this.iter.next()) - line.reverse() - return "".join(line) - return Reverser(results) - - class Root: - def index(self): - return "I'm a regular CherryPy page handler!" - index.exposed = True - - - cherrypy.tree.mount(Root()) - - cherrypy.tree.graft(test_app, '/hosted/app1') - cherrypy.tree.graft(test_empty_string_app, '/hosted/app3') - - # Set script_name explicitly to None to signal CP that it should - # be pulled from the WSGI environ each time. 
- app = cherrypy.Application(Root(), script_name=None) - cherrypy.tree.graft(ReversingMiddleware(app), '/hosted/app2') - setup_server = staticmethod(setup_server) - - wsgi_output = '''Hello, world! -This is a wsgi app running within CherryPy!''' - - def test_01_standard_app(self): - self.getPage("/") - self.assertBody("I'm a regular CherryPy page handler!") - - def test_04_pure_wsgi(self): - import cherrypy - if not cherrypy.server.using_wsgi: - return self.skip("skipped (not using WSGI)... ") - self.getPage("/hosted/app1") - self.assertHeader("Content-Type", "text/plain") - self.assertInBody(self.wsgi_output) - - def test_05_wrapped_cp_app(self): - import cherrypy - if not cherrypy.server.using_wsgi: - return self.skip("skipped (not using WSGI)... ") - self.getPage("/hosted/app2/") - body = list("I'm a regular CherryPy page handler!") - body.reverse() - body = "".join(body) - self.assertInBody(body) - - def test_06_empty_string_app(self): - import cherrypy - if not cherrypy.server.using_wsgi: - return self.skip("skipped (not using WSGI)... 
") - self.getPage("/hosted/app3") - self.assertHeader("Content-Type", "text/plain") - self.assertInBody('Hello world') - diff --git a/cherrypy/test/test_xmlrpc.py b/cherrypy/test/test_xmlrpc.py deleted file mode 100644 index c4bf61e0..00000000 --- a/cherrypy/test/test_xmlrpc.py +++ /dev/null @@ -1,172 +0,0 @@ -import sys -from xmlrpclib import DateTime, Fault, ServerProxy, SafeTransport - -class HTTPSTransport(SafeTransport): - """Subclass of SafeTransport to fix sock.recv errors (by using file).""" - - def request(self, host, handler, request_body, verbose=0): - # issue XML-RPC request - h = self.make_connection(host) - if verbose: - h.set_debuglevel(1) - - self.send_request(h, handler, request_body) - self.send_host(h, host) - self.send_user_agent(h) - self.send_content(h, request_body) - - errcode, errmsg, headers = h.getreply() - if errcode != 200: - raise xmlrpclib.ProtocolError(host + handler, errcode, errmsg, - headers) - - self.verbose = verbose - - # Here's where we differ from the superclass. It says: - # try: - # sock = h._conn.sock - # except AttributeError: - # sock = None - # return self._parse_response(h.getfile(), sock) - - return self.parse_response(h.getfile()) - -import cherrypy - - -def setup_server(): - from cherrypy import _cptools - - class Root: - def index(self): - return "I'm a standard index!" - index.exposed = True - - - class XmlRpc(_cptools.XMLRPCController): - - def foo(self): - return "Hello world!" 
- foo.exposed = True - - def return_single_item_list(self): - return [42] - return_single_item_list.exposed = True - - def return_string(self): - return "here is a string" - return_string.exposed = True - - def return_tuple(self): - return ('here', 'is', 1, 'tuple') - return_tuple.exposed = True - - def return_dict(self): - return dict(a=1, b=2, c=3) - return_dict.exposed = True - - def return_composite(self): - return dict(a=1,z=26), 'hi', ['welcome', 'friend'] - return_composite.exposed = True - - def return_int(self): - return 42 - return_int.exposed = True - - def return_float(self): - return 3.14 - return_float.exposed = True - - def return_datetime(self): - return DateTime((2003, 10, 7, 8, 1, 0, 1, 280, -1)) - return_datetime.exposed = True - - def return_boolean(self): - return True - return_boolean.exposed = True - - def test_argument_passing(self, num): - return num * 2 - test_argument_passing.exposed = True - - def test_returning_Fault(self): - return Fault(1, "custom Fault response") - test_returning_Fault.exposed = True - - root = Root() - root.xmlrpc = XmlRpc() - cherrypy.tree.mount(root, config={'/': { - 'request.dispatch': cherrypy.dispatch.XMLRPCDispatcher(), - 'tools.xmlrpc.allow_none': 0, - }}) - - -from cherrypy.test import helper - -class XmlRpcTest(helper.CPWebCase): - setup_server = staticmethod(setup_server) - def testXmlRpc(self): - - scheme = "http" - try: - scheme = self.harness.scheme - except AttributeError: - pass - - if scheme == "https": - url = 'https://%s:%s/xmlrpc/' % (self.interface(), self.PORT) - proxy = ServerProxy(url, transport=HTTPSTransport()) - else: - url = 'http://%s:%s/xmlrpc/' % (self.interface(), self.PORT) - proxy = ServerProxy(url) - - # begin the tests ... 
- self.getPage("/xmlrpc/foo") - self.assertBody("Hello world!") - - self.assertEqual(proxy.return_single_item_list(), [42]) - self.assertNotEqual(proxy.return_single_item_list(), 'one bazillion') - self.assertEqual(proxy.return_string(), "here is a string") - self.assertEqual(proxy.return_tuple(), list(('here', 'is', 1, 'tuple'))) - self.assertEqual(proxy.return_dict(), {'a': 1, 'c': 3, 'b': 2}) - self.assertEqual(proxy.return_composite(), - [{'a': 1, 'z': 26}, 'hi', ['welcome', 'friend']]) - self.assertEqual(proxy.return_int(), 42) - self.assertEqual(proxy.return_float(), 3.14) - self.assertEqual(proxy.return_datetime(), - DateTime((2003, 10, 7, 8, 1, 0, 1, 280, -1))) - self.assertEqual(proxy.return_boolean(), True) - self.assertEqual(proxy.test_argument_passing(22), 22 * 2) - - # Test an error in the page handler (should raise an xmlrpclib.Fault) - try: - proxy.test_argument_passing({}) - except Exception: - x = sys.exc_info()[1] - self.assertEqual(x.__class__, Fault) - self.assertEqual(x.faultString, ("unsupported operand type(s) " - "for *: 'dict' and 'int'")) - else: - self.fail("Expected xmlrpclib.Fault") - - # http://www.cherrypy.org/ticket/533 - # if a method is not found, an xmlrpclib.Fault should be raised - try: - proxy.non_method() - except Exception: - x = sys.exc_info()[1] - self.assertEqual(x.__class__, Fault) - self.assertEqual(x.faultString, 'method "non_method" is not supported') - else: - self.fail("Expected xmlrpclib.Fault") - - # Test returning a Fault from the page handler. - try: - proxy.test_returning_Fault() - except Exception: - x = sys.exc_info()[1] - self.assertEqual(x.__class__, Fault) - self.assertEqual(x.faultString, ("custom Fault response")) - else: - self.fail("Expected xmlrpclib.Fault") - diff --git a/cherrypy/test/webtest.py b/cherrypy/test/webtest.py deleted file mode 100644 index 969eab0e..00000000 --- a/cherrypy/test/webtest.py +++ /dev/null @@ -1,535 +0,0 @@ -"""Extensions to unittest for web frameworks. 
- -Use the WebCase.getPage method to request a page from your HTTP server. - -Framework Integration -===================== - -If you have control over your server process, you can handle errors -in the server-side of the HTTP conversation a bit better. You must run -both the client (your WebCase tests) and the server in the same process -(but in separate threads, obviously). - -When an error occurs in the framework, call server_error. It will print -the traceback to stdout, and keep any assertions you have from running -(the assumption is that, if the server errors, the page output will not -be of further significance to your tests). -""" - -import os -import pprint -import re -import socket -import sys -import time -import traceback -import types - -from unittest import * -from unittest import _TextTestResult - -from cherrypy._cpcompat import basestring, HTTPConnection, HTTPSConnection, unicodestr - - - -def interface(host): - """Return an IP address for a client connection given the server host. - - If the server is listening on '0.0.0.0' (INADDR_ANY) - or '::' (IN6ADDR_ANY), this will return the proper localhost.""" - if host == '0.0.0.0': - # INADDR_ANY, which should respond on localhost. - return "127.0.0.1" - if host == '::': - # IN6ADDR_ANY, which should respond on localhost. - return "::1" - return host - - -class TerseTestResult(_TextTestResult): - - def printErrors(self): - # Overridden to avoid unnecessary empty line - if self.errors or self.failures: - if self.dots or self.showAll: - self.stream.writeln() - self.printErrorList('ERROR', self.errors) - self.printErrorList('FAIL', self.failures) - - -class TerseTestRunner(TextTestRunner): - """A test runner class that displays results in textual form.""" - - def _makeResult(self): - return TerseTestResult(self.stream, self.descriptions, self.verbosity) - - def run(self, test): - "Run the given test case or test suite." 
- # Overridden to remove unnecessary empty lines and separators - result = self._makeResult() - test(result) - result.printErrors() - if not result.wasSuccessful(): - self.stream.write("FAILED (") - failed, errored = list(map(len, (result.failures, result.errors))) - if failed: - self.stream.write("failures=%d" % failed) - if errored: - if failed: self.stream.write(", ") - self.stream.write("errors=%d" % errored) - self.stream.writeln(")") - return result - - -class ReloadingTestLoader(TestLoader): - - def loadTestsFromName(self, name, module=None): - """Return a suite of all tests cases given a string specifier. - - The name may resolve either to a module, a test case class, a - test method within a test case class, or a callable object which - returns a TestCase or TestSuite instance. - - The method optionally resolves the names relative to a given module. - """ - parts = name.split('.') - unused_parts = [] - if module is None: - if not parts: - raise ValueError("incomplete test name: %s" % name) - else: - parts_copy = parts[:] - while parts_copy: - target = ".".join(parts_copy) - if target in sys.modules: - module = reload(sys.modules[target]) - parts = unused_parts - break - else: - try: - module = __import__(target) - parts = unused_parts - break - except ImportError: - unused_parts.insert(0,parts_copy[-1]) - del parts_copy[-1] - if not parts_copy: - raise - parts = parts[1:] - obj = module - for part in parts: - obj = getattr(obj, part) - - if type(obj) == types.ModuleType: - return self.loadTestsFromModule(obj) - elif (isinstance(obj, (type, types.ClassType)) and - issubclass(obj, TestCase)): - return self.loadTestsFromTestCase(obj) - elif type(obj) == types.UnboundMethodType: - return obj.im_class(obj.__name__) - elif hasattr(obj, '__call__'): - test = obj() - if not isinstance(test, TestCase) and \ - not isinstance(test, TestSuite): - raise ValueError("calling %s returned %s, " - "not a test" % (obj,test)) - return test - else: - raise ValueError("do not 
know how to make test from: %s" % obj) - - -try: - # Jython support - if sys.platform[:4] == 'java': - def getchar(): - # Hopefully this is enough - return sys.stdin.read(1) - else: - # On Windows, msvcrt.getch reads a single char without output. - import msvcrt - def getchar(): - return msvcrt.getch() -except ImportError: - # Unix getchr - import tty, termios - def getchar(): - fd = sys.stdin.fileno() - old_settings = termios.tcgetattr(fd) - try: - tty.setraw(sys.stdin.fileno()) - ch = sys.stdin.read(1) - finally: - termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) - return ch - - -class WebCase(TestCase): - HOST = "127.0.0.1" - PORT = 8000 - HTTP_CONN = HTTPConnection - PROTOCOL = "HTTP/1.1" - - scheme = "http" - url = None - - status = None - headers = None - body = None - - encoding = 'utf-8' - - time = None - - def get_conn(self, auto_open=False): - """Return a connection to our HTTP server.""" - if self.scheme == "https": - cls = HTTPSConnection - else: - cls = HTTPConnection - conn = cls(self.interface(), self.PORT) - # Automatically re-connect? - conn.auto_open = auto_open - conn.connect() - return conn - - def set_persistent(self, on=True, auto_open=False): - """Make our HTTP_CONN persistent (or not). - - If the 'on' argument is True (the default), then self.HTTP_CONN - will be set to an instance of HTTPConnection (or HTTPS - if self.scheme is "https"). This will then persist across requests. - - We only allow for a single open connection, so if you call this - and we currently have an open connection, it will be closed. 
- """ - try: - self.HTTP_CONN.close() - except (TypeError, AttributeError): - pass - - if on: - self.HTTP_CONN = self.get_conn(auto_open=auto_open) - else: - if self.scheme == "https": - self.HTTP_CONN = HTTPSConnection - else: - self.HTTP_CONN = HTTPConnection - - def _get_persistent(self): - return hasattr(self.HTTP_CONN, "__class__") - def _set_persistent(self, on): - self.set_persistent(on) - persistent = property(_get_persistent, _set_persistent) - - def interface(self): - """Return an IP address for a client connection. - - If the server is listening on '0.0.0.0' (INADDR_ANY) - or '::' (IN6ADDR_ANY), this will return the proper localhost.""" - return interface(self.HOST) - - def getPage(self, url, headers=None, method="GET", body=None, protocol=None): - """Open the url with debugging support. Return status, headers, body.""" - ServerError.on = False - - if isinstance(url, unicodestr): - url = url.encode('utf-8') - if isinstance(body, unicodestr): - body = body.encode('utf-8') - - self.url = url - self.time = None - start = time.time() - result = openURL(url, headers, method, body, self.HOST, self.PORT, - self.HTTP_CONN, protocol or self.PROTOCOL) - self.time = time.time() - start - self.status, self.headers, self.body = result - - # Build a list of request cookies from the previous response cookies. 
- self.cookies = [('Cookie', v) for k, v in self.headers - if k.lower() == 'set-cookie'] - - if ServerError.on: - raise ServerError() - return result - - interactive = True - console_height = 30 - - def _handlewebError(self, msg): - print("") - print(" ERROR: %s" % msg) - - if not self.interactive: - raise self.failureException(msg) - - p = " Show: [B]ody [H]eaders [S]tatus [U]RL; [I]gnore, [R]aise, or sys.e[X]it >> " - sys.stdout.write(p) - sys.stdout.flush() - while True: - i = getchar().upper() - if i not in "BHSUIRX": - continue - print(i.upper()) # Also prints new line - if i == "B": - for x, line in enumerate(self.body.splitlines()): - if (x + 1) % self.console_height == 0: - # The \r and comma should make the next line overwrite - sys.stdout.write("<-- More -->\r") - m = getchar().lower() - # Erase our "More" prompt - sys.stdout.write(" \r") - if m == "q": - break - print(line) - elif i == "H": - pprint.pprint(self.headers) - elif i == "S": - print(self.status) - elif i == "U": - print(self.url) - elif i == "I": - # return without raising the normal exception - return - elif i == "R": - raise self.failureException(msg) - elif i == "X": - self.exit() - sys.stdout.write(p) - sys.stdout.flush() - - def exit(self): - sys.exit() - - def assertStatus(self, status, msg=None): - """Fail if self.status != status.""" - if isinstance(status, basestring): - if not self.status == status: - if msg is None: - msg = 'Status (%r) != %r' % (self.status, status) - self._handlewebError(msg) - elif isinstance(status, int): - code = int(self.status[:3]) - if code != status: - if msg is None: - msg = 'Status (%r) != %r' % (self.status, status) - self._handlewebError(msg) - else: - # status is a tuple or list. 
- match = False - for s in status: - if isinstance(s, basestring): - if self.status == s: - match = True - break - elif int(self.status[:3]) == s: - match = True - break - if not match: - if msg is None: - msg = 'Status (%r) not in %r' % (self.status, status) - self._handlewebError(msg) - - def assertHeader(self, key, value=None, msg=None): - """Fail if (key, [value]) not in self.headers.""" - lowkey = key.lower() - for k, v in self.headers: - if k.lower() == lowkey: - if value is None or str(value) == v: - return v - - if msg is None: - if value is None: - msg = '%r not in headers' % key - else: - msg = '%r:%r not in headers' % (key, value) - self._handlewebError(msg) - - def assertHeaderItemValue(self, key, value, msg=None): - """Fail if the header does not contain the specified value""" - actual_value = self.assertHeader(key, msg=msg) - header_values = map(str.strip, actual_value.split(',')) - if value in header_values: - return value - - if msg is None: - msg = "%r not in %r" % (value, header_values) - self._handlewebError(msg) - - def assertNoHeader(self, key, msg=None): - """Fail if key in self.headers.""" - lowkey = key.lower() - matches = [k for k, v in self.headers if k.lower() == lowkey] - if matches: - if msg is None: - msg = '%r in headers' % key - self._handlewebError(msg) - - def assertBody(self, value, msg=None): - """Fail if value != self.body.""" - if value != self.body: - if msg is None: - msg = 'expected body:\n%r\n\nactual body:\n%r' % (value, self.body) - self._handlewebError(msg) - - def assertInBody(self, value, msg=None): - """Fail if value not in self.body.""" - if value not in self.body: - if msg is None: - msg = '%r not in body: %s' % (value, self.body) - self._handlewebError(msg) - - def assertNotInBody(self, value, msg=None): - """Fail if value in self.body.""" - if value in self.body: - if msg is None: - msg = '%r found in body' % value - self._handlewebError(msg) - - def assertMatchesBody(self, pattern, msg=None, flags=0): - """Fail 
if value (a regex pattern) is not in self.body.""" - if re.search(pattern, self.body, flags) is None: - if msg is None: - msg = 'No match for %r in body' % pattern - self._handlewebError(msg) - - -methods_with_bodies = ("POST", "PUT") - -def cleanHeaders(headers, method, body, host, port): - """Return request headers, with required headers added (if missing).""" - if headers is None: - headers = [] - - # Add the required Host request header if not present. - # [This specifies the host:port of the server, not the client.] - found = False - for k, v in headers: - if k.lower() == 'host': - found = True - break - if not found: - if port == 80: - headers.append(("Host", host)) - else: - headers.append(("Host", "%s:%s" % (host, port))) - - if method in methods_with_bodies: - # Stick in default type and length headers if not present - found = False - for k, v in headers: - if k.lower() == 'content-type': - found = True - break - if not found: - headers.append(("Content-Type", "application/x-www-form-urlencoded")) - headers.append(("Content-Length", str(len(body or "")))) - - return headers - - -def shb(response): - """Return status, headers, body the way we like from a response.""" - h = [] - key, value = None, None - for line in response.msg.headers: - if line: - if line[0] in " \t": - value += line.strip() - else: - if key and value: - h.append((key, value)) - key, value = line.split(":", 1) - key = key.strip() - value = value.strip() - if key and value: - h.append((key, value)) - - return "%s %s" % (response.status, response.reason), h, response.read() - - -def openURL(url, headers=None, method="GET", body=None, - host="127.0.0.1", port=8000, http_conn=HTTPConnection, - protocol="HTTP/1.1"): - """Open the given HTTP resource and return status, headers, and body.""" - - headers = cleanHeaders(headers, method, body, host, port) - - # Trying 10 times is simply in case of socket errors. - # Normal case--it should run once. 
- for trial in range(10): - try: - # Allow http_conn to be a class or an instance - if hasattr(http_conn, "host"): - conn = http_conn - else: - conn = http_conn(interface(host), port) - - conn._http_vsn_str = protocol - conn._http_vsn = int("".join([x for x in protocol if x.isdigit()])) - - # skip_accept_encoding argument added in python version 2.4 - if sys.version_info < (2, 4): - def putheader(self, header, value): - if header == 'Accept-Encoding' and value == 'identity': - return - self.__class__.putheader(self, header, value) - import new - conn.putheader = new.instancemethod(putheader, conn, conn.__class__) - conn.putrequest(method.upper(), url, skip_host=True) - else: - conn.putrequest(method.upper(), url, skip_host=True, - skip_accept_encoding=True) - - for key, value in headers: - conn.putheader(key, value) - conn.endheaders() - - if body is not None: - conn.send(body) - - # Handle response - response = conn.getresponse() - - s, h, b = shb(response) - - if not hasattr(http_conn, "host"): - # We made our own conn instance. Close it. - conn.close() - - return s, h, b - except socket.error: - time.sleep(0.5) - raise - - -# Add any exceptions which your web framework handles -# normally (that you don't want server_error to trap). -ignored_exceptions = [] - -# You'll want set this to True when you can't guarantee -# that each response will immediately follow each request; -# for example, when handling requests via multiple threads. -ignore_all = False - -class ServerError(Exception): - on = False - - -def server_error(exc=None): - """Server debug hook. Return True if exception handled, False if ignored. - - You probably want to wrap this, so you can still handle an error using - your framework when it's ignored. 
- """ - if exc is None: - exc = sys.exc_info() - - if ignore_all or exc[0] in ignored_exceptions: - return False - else: - ServerError.on = True - print("") - print("".join(traceback.format_exception(*exc))) - return True - diff --git a/cherrypy/wsgiserver/__init__.py b/cherrypy/wsgiserver/__init__.py index 55d1dd90..ee6190fe 100644 --- a/cherrypy/wsgiserver/__init__.py +++ b/cherrypy/wsgiserver/__init__.py @@ -1,2219 +1,14 @@ -"""A high-speed, production ready, thread pooled, generic HTTP server. +__all__ = ['HTTPRequest', 'HTTPConnection', 'HTTPServer', + 'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile', + 'MaxSizeExceeded', 'NoSSLError', 'FatalSSLAlert', + 'WorkerThread', 'ThreadPool', 'SSLAdapter', + 'CherryPyWSGIServer', + 'Gateway', 'WSGIGateway', 'WSGIGateway_10', 'WSGIGateway_u0', + 'WSGIPathInfoDispatcher', 'get_ssl_adapter_class'] -Simplest example on how to use this module directly -(without using CherryPy's application machinery):: - - from cherrypy import wsgiserver - - def my_crazy_app(environ, start_response): - status = '200 OK' - response_headers = [('Content-type','text/plain')] - start_response(status, response_headers) - return ['Hello world!'] - - server = wsgiserver.CherryPyWSGIServer( - ('0.0.0.0', 8070), my_crazy_app, - server_name='www.cherrypy.example') - server.start() - -The CherryPy WSGI server can serve as many WSGI applications -as you want in one instance by using a WSGIPathInfoDispatcher:: - - d = WSGIPathInfoDispatcher({'/': my_crazy_app, '/blog': my_blog_app}) - server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', 80), d) - -Want SSL support? Just set server.ssl_adapter to an SSLAdapter instance. - -This won't call the CherryPy engine (application side) at all, only the -HTTP server, which is independent from the rest of CherryPy. Don't -let the name "CherryPyWSGIServer" throw you; the name merely reflects -its origin, not its coupling. 
- -For those of you wanting to understand internals of this module, here's the -basic call flow. The server's listening thread runs a very tight loop, -sticking incoming connections onto a Queue:: - - server = CherryPyWSGIServer(...) - server.start() - while True: - tick() - # This blocks until a request comes in: - child = socket.accept() - conn = HTTPConnection(child, ...) - server.requests.put(conn) - -Worker threads are kept in a pool and poll the Queue, popping off and then -handling each connection in turn. Each connection can consist of an arbitrary -number of requests and their responses, so we run a nested loop:: - - while True: - conn = server.requests.get() - conn.communicate() - -> while True: - req = HTTPRequest(...) - req.parse_request() - -> # Read the Request-Line, e.g. "GET /page HTTP/1.1" - req.rfile.readline() - read_headers(req.rfile, req.inheaders) - req.respond() - -> response = app(...) - try: - for chunk in response: - if chunk: - req.write(chunk) - finally: - if hasattr(response, "close"): - response.close() - if req.close_connection: - return -""" - -CRLF = '\r\n' -import os -import Queue -import re -quoted_slash = re.compile("(?i)%2F") -import rfc822 -import socket import sys -if 'win' in sys.platform and not hasattr(socket, 'IPPROTO_IPV6'): - socket.IPPROTO_IPV6 = 41 -try: - import cStringIO as StringIO -except ImportError: - import StringIO -DEFAULT_BUFFER_SIZE = -1 - -_fileobject_uses_str_type = isinstance(socket._fileobject(None)._rbuf, basestring) - -import threading -import time -import traceback -def format_exc(limit=None): - """Like print_exc() but return a string. 
Backport for Python 2.3.""" - try: - etype, value, tb = sys.exc_info() - return ''.join(traceback.format_exception(etype, value, tb, limit)) - finally: - etype = value = tb = None - - -from urllib import unquote -from urlparse import urlparse -import warnings - -import errno - -def plat_specific_errors(*errnames): - """Return error numbers for all errors in errnames on this platform. - - The 'errno' module contains different global constants depending on - the specific platform (OS). This function will return the list of - numeric values for a given list of potential names. - """ - errno_names = dir(errno) - nums = [getattr(errno, k) for k in errnames if k in errno_names] - # de-dupe the list - return dict.fromkeys(nums).keys() - -socket_error_eintr = plat_specific_errors("EINTR", "WSAEINTR") - -socket_errors_to_ignore = plat_specific_errors( - "EPIPE", - "EBADF", "WSAEBADF", - "ENOTSOCK", "WSAENOTSOCK", - "ETIMEDOUT", "WSAETIMEDOUT", - "ECONNREFUSED", "WSAECONNREFUSED", - "ECONNRESET", "WSAECONNRESET", - "ECONNABORTED", "WSAECONNABORTED", - "ENETRESET", "WSAENETRESET", - "EHOSTDOWN", "EHOSTUNREACH", - ) -socket_errors_to_ignore.append("timed out") -socket_errors_to_ignore.append("The read operation timed out") - -socket_errors_nonblocking = plat_specific_errors( - 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK') - -comma_separated_headers = ['Accept', 'Accept-Charset', 'Accept-Encoding', - 'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control', - 'Connection', 'Content-Encoding', 'Content-Language', 'Expect', - 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE', - 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning', - 'WWW-Authenticate'] - - -import logging -if not hasattr(logging, 'statistics'): logging.statistics = {} - - -def read_headers(rfile, hdict=None): - """Read headers from the given stream into the given header dict. - - If hdict is None, a new header dict is created. Returns the populated - header dict. 
- - Headers which are repeated are folded together using a comma if their - specification so dictates. - - This function raises ValueError when the read bytes violate the HTTP spec. - You should probably return "400 Bad Request" if this happens. - """ - if hdict is None: - hdict = {} - - while True: - line = rfile.readline() - if not line: - # No more data--illegal end of headers - raise ValueError("Illegal end of headers.") - - if line == CRLF: - # Normal end of headers - break - if not line.endswith(CRLF): - raise ValueError("HTTP requires CRLF terminators") - - if line[0] in ' \t': - # It's a continuation line. - v = line.strip() - else: - try: - k, v = line.split(":", 1) - except ValueError: - raise ValueError("Illegal header line.") - # TODO: what about TE and WWW-Authenticate? - k = k.strip().title() - v = v.strip() - hname = k - - if k in comma_separated_headers: - existing = hdict.get(hname) - if existing: - v = ", ".join((existing, v)) - hdict[hname] = v - - return hdict - - -class MaxSizeExceeded(Exception): - pass - -class SizeCheckWrapper(object): - """Wraps a file-like object, raising MaxSizeExceeded if too large.""" - - def __init__(self, rfile, maxlen): - self.rfile = rfile - self.maxlen = maxlen - self.bytes_read = 0 - - def _check_length(self): - if self.maxlen and self.bytes_read > self.maxlen: - raise MaxSizeExceeded() - - def read(self, size=None): - data = self.rfile.read(size) - self.bytes_read += len(data) - self._check_length() - return data - - def readline(self, size=None): - if size is not None: - data = self.rfile.readline(size) - self.bytes_read += len(data) - self._check_length() - return data - - # User didn't specify a size ... - # We read the line in chunks to make sure it's not a 100MB line ! 
- res = [] - while True: - data = self.rfile.readline(256) - self.bytes_read += len(data) - self._check_length() - res.append(data) - # See http://www.cherrypy.org/ticket/421 - if len(data) < 256 or data[-1:] == "\n": - return ''.join(res) - - def readlines(self, sizehint=0): - # Shamelessly stolen from StringIO - total = 0 - lines = [] - line = self.readline() - while line: - lines.append(line) - total += len(line) - if 0 < sizehint <= total: - break - line = self.readline() - return lines - - def close(self): - self.rfile.close() - - def __iter__(self): - return self - - def next(self): - data = self.rfile.next() - self.bytes_read += len(data) - self._check_length() - return data - - -class KnownLengthRFile(object): - """Wraps a file-like object, returning an empty string when exhausted.""" - - def __init__(self, rfile, content_length): - self.rfile = rfile - self.remaining = content_length - - def read(self, size=None): - if self.remaining == 0: - return '' - if size is None: - size = self.remaining - else: - size = min(size, self.remaining) - - data = self.rfile.read(size) - self.remaining -= len(data) - return data - - def readline(self, size=None): - if self.remaining == 0: - return '' - if size is None: - size = self.remaining - else: - size = min(size, self.remaining) - - data = self.rfile.readline(size) - self.remaining -= len(data) - return data - - def readlines(self, sizehint=0): - # Shamelessly stolen from StringIO - total = 0 - lines = [] - line = self.readline(sizehint) - while line: - lines.append(line) - total += len(line) - if 0 < sizehint <= total: - break - line = self.readline(sizehint) - return lines - - def close(self): - self.rfile.close() - - def __iter__(self): - return self - - def __next__(self): - data = next(self.rfile) - self.remaining -= len(data) - return data - - -class ChunkedRFile(object): - """Wraps a file-like object, returning an empty string when exhausted. 
- - This class is intended to provide a conforming wsgi.input value for - request entities that have been encoded with the 'chunked' transfer - encoding. - """ - - def __init__(self, rfile, maxlen, bufsize=8192): - self.rfile = rfile - self.maxlen = maxlen - self.bytes_read = 0 - self.buffer = '' - self.bufsize = bufsize - self.closed = False - - def _fetch(self): - if self.closed: - return - - line = self.rfile.readline() - self.bytes_read += len(line) - - if self.maxlen and self.bytes_read > self.maxlen: - raise MaxSizeExceeded("Request Entity Too Large", self.maxlen) - - line = line.strip().split(";", 1) - - try: - chunk_size = line.pop(0) - chunk_size = int(chunk_size, 16) - except ValueError: - raise ValueError("Bad chunked transfer size: " + repr(chunk_size)) - - if chunk_size <= 0: - self.closed = True - return - -## if line: chunk_extension = line[0] - - if self.maxlen and self.bytes_read + chunk_size > self.maxlen: - raise IOError("Request Entity Too Large") - - chunk = self.rfile.read(chunk_size) - self.bytes_read += len(chunk) - self.buffer += chunk - - crlf = self.rfile.read(2) - if crlf != CRLF: - raise ValueError( - "Bad chunked transfer coding (expected '\\r\\n', " - "got " + repr(crlf) + ")") - - def read(self, size=None): - data = '' - while True: - if size and len(data) >= size: - return data - - if not self.buffer: - self._fetch() - if not self.buffer: - # EOF - return data - - if size: - remaining = size - len(data) - data += self.buffer[:remaining] - self.buffer = self.buffer[remaining:] - else: - data += self.buffer - - def readline(self, size=None): - data = '' - while True: - if size and len(data) >= size: - return data - - if not self.buffer: - self._fetch() - if not self.buffer: - # EOF - return data - - newline_pos = self.buffer.find('\n') - if size: - if newline_pos == -1: - remaining = size - len(data) - data += self.buffer[:remaining] - self.buffer = self.buffer[remaining:] - else: - remaining = min(size - len(data), newline_pos) - 
data += self.buffer[:remaining] - self.buffer = self.buffer[remaining:] - else: - if newline_pos == -1: - data += self.buffer - else: - data += self.buffer[:newline_pos] - self.buffer = self.buffer[newline_pos:] - - def readlines(self, sizehint=0): - # Shamelessly stolen from StringIO - total = 0 - lines = [] - line = self.readline(sizehint) - while line: - lines.append(line) - total += len(line) - if 0 < sizehint <= total: - break - line = self.readline(sizehint) - return lines - - def read_trailer_lines(self): - if not self.closed: - raise ValueError( - "Cannot read trailers until the request body has been read.") - - while True: - line = self.rfile.readline() - if not line: - # No more data--illegal end of headers - raise ValueError("Illegal end of headers.") - - self.bytes_read += len(line) - if self.maxlen and self.bytes_read > self.maxlen: - raise IOError("Request Entity Too Large") - - if line == CRLF: - # Normal end of headers - break - if not line.endswith(CRLF): - raise ValueError("HTTP requires CRLF terminators") - - yield line - - def close(self): - self.rfile.close() - - def __iter__(self): - # Shamelessly stolen from StringIO - total = 0 - line = self.readline(sizehint) - while line: - yield line - total += len(line) - if 0 < sizehint <= total: - break - line = self.readline(sizehint) - - -class HTTPRequest(object): - """An HTTP Request (and response). - - A single HTTP connection may consist of multiple request/response pairs. - """ - - server = None - """The HTTPServer object which is receiving this request.""" - - conn = None - """The HTTPConnection object on which this request connected.""" - - inheaders = {} - """A dict of request headers.""" - - outheaders = [] - """A list of header tuples to write in the response.""" - - ready = False - """When True, the request has been parsed and is ready to begin generating - the response. 
When False, signals the calling Connection that the response - should not be generated and the connection should close.""" - - close_connection = False - """Signals the calling Connection that the request should close. This does - not imply an error! The client and/or server may each request that the - connection be closed.""" - - chunked_write = False - """If True, output will be encoded with the "chunked" transfer-coding. - - This value is set automatically inside send_headers.""" - - def __init__(self, server, conn): - self.server= server - self.conn = conn - - self.ready = False - self.started_request = False - self.scheme = "http" - if self.server.ssl_adapter is not None: - self.scheme = "https" - # Use the lowest-common protocol in case read_request_line errors. - self.response_protocol = 'HTTP/1.0' - self.inheaders = {} - - self.status = "" - self.outheaders = [] - self.sent_headers = False - self.close_connection = self.__class__.close_connection - self.chunked_read = False - self.chunked_write = self.__class__.chunked_write - - def parse_request(self): - """Parse the next HTTP request start-line and message-headers.""" - self.rfile = SizeCheckWrapper(self.conn.rfile, - self.server.max_request_header_size) - try: - self.read_request_line() - except MaxSizeExceeded: - self.simple_response("414 Request-URI Too Long", - "The Request-URI sent with the request exceeds the maximum " - "allowed bytes.") - return - - try: - success = self.read_request_headers() - except MaxSizeExceeded: - self.simple_response("413 Request Entity Too Large", - "The headers sent with the request exceed the maximum " - "allowed bytes.") - return - else: - if not success: - return - - self.ready = True - - def read_request_line(self): - # HTTP/1.1 connections are persistent by default. If a client - # requests a page, then idles (leaves the connection open), - # then rfile.readline() will raise socket.error("timed out"). 
- # Note that it does this based on the value given to settimeout(), - # and doesn't need the client to request or acknowledge the close - # (although your TCP stack might suffer for it: cf Apache's history - # with FIN_WAIT_2). - request_line = self.rfile.readline() - - # Set started_request to True so communicate() knows to send 408 - # from here on out. - self.started_request = True - if not request_line: - # Force self.ready = False so the connection will close. - self.ready = False - return - - if request_line == CRLF: - # RFC 2616 sec 4.1: "...if the server is reading the protocol - # stream at the beginning of a message and receives a CRLF - # first, it should ignore the CRLF." - # But only ignore one leading line! else we enable a DoS. - request_line = self.rfile.readline() - if not request_line: - self.ready = False - return - - if not request_line.endswith(CRLF): - self.simple_response("400 Bad Request", "HTTP requires CRLF terminators") - return - - try: - method, uri, req_protocol = request_line.strip().split(" ", 2) - rp = int(req_protocol[5]), int(req_protocol[7]) - except (ValueError, IndexError): - self.simple_response("400 Bad Request", "Malformed Request-Line") - return - - self.uri = uri - self.method = method - - # uri may be an abs_path (including "http://host.domain.tld"); - scheme, authority, path = self.parse_request_uri(uri) - if '#' in path: - self.simple_response("400 Bad Request", - "Illegal #fragment in Request-URI.") - return - - if scheme: - self.scheme = scheme - - qs = '' - if '?' in path: - path, qs = path.split('?', 1) - - # Unquote the path+params (e.g. "/this%20path" -> "/this path"). - # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2 - # - # But note that "...a URI must be separated into its components - # before the escaped characters within those components can be - # safely decoded." http://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2 - # Therefore, "/this%2Fpath" becomes "/this%2Fpath", not "/this/path". 
- try: - atoms = [unquote(x) for x in quoted_slash.split(path)] - except ValueError, ex: - self.simple_response("400 Bad Request", ex.args[0]) - return - path = "%2F".join(atoms) - self.path = path - - # Note that, like wsgiref and most other HTTP servers, - # we "% HEX HEX"-unquote the path but not the query string. - self.qs = qs - - # Compare request and server HTTP protocol versions, in case our - # server does not support the requested protocol. Limit our output - # to min(req, server). We want the following output: - # request server actual written supported response - # protocol protocol response protocol feature set - # a 1.0 1.0 1.0 1.0 - # b 1.0 1.1 1.1 1.0 - # c 1.1 1.0 1.0 1.0 - # d 1.1 1.1 1.1 1.1 - # Notice that, in (b), the response will be "HTTP/1.1" even though - # the client only understands 1.0. RFC 2616 10.5.6 says we should - # only return 505 if the _major_ version is different. - sp = int(self.server.protocol[5]), int(self.server.protocol[7]) - - if sp[0] != rp[0]: - self.simple_response("505 HTTP Version Not Supported") - return - self.request_protocol = req_protocol - self.response_protocol = "HTTP/%s.%s" % min(rp, sp) - - def read_request_headers(self): - """Read self.rfile into self.inheaders. 
Return success.""" - - # then all the http headers - try: - read_headers(self.rfile, self.inheaders) - except ValueError, ex: - self.simple_response("400 Bad Request", ex.args[0]) - return False - - mrbs = self.server.max_request_body_size - if mrbs and int(self.inheaders.get("Content-Length", 0)) > mrbs: - self.simple_response("413 Request Entity Too Large", - "The entity sent with the request exceeds the maximum " - "allowed bytes.") - return False - - # Persistent connection support - if self.response_protocol == "HTTP/1.1": - # Both server and client are HTTP/1.1 - if self.inheaders.get("Connection", "") == "close": - self.close_connection = True - else: - # Either the server or client (or both) are HTTP/1.0 - if self.inheaders.get("Connection", "") != "Keep-Alive": - self.close_connection = True - - # Transfer-Encoding support - te = None - if self.response_protocol == "HTTP/1.1": - te = self.inheaders.get("Transfer-Encoding") - if te: - te = [x.strip().lower() for x in te.split(",") if x.strip()] - - self.chunked_read = False - - if te: - for enc in te: - if enc == "chunked": - self.chunked_read = True - else: - # Note that, even if we see "chunked", we must reject - # if there is an extension we don't recognize. - self.simple_response("501 Unimplemented") - self.close_connection = True - return False - - # From PEP 333: - # "Servers and gateways that implement HTTP 1.1 must provide - # transparent support for HTTP 1.1's "expect/continue" mechanism. - # This may be done in any of several ways: - # 1. Respond to requests containing an Expect: 100-continue request - # with an immediate "100 Continue" response, and proceed normally. - # 2. Proceed with the request normally, but provide the application - # with a wsgi.input stream that will send the "100 Continue" - # response if/when the application first attempts to read from - # the input stream. The read request must then remain blocked - # until the client responds. - # 3. 
Wait until the client decides that the server does not support - # expect/continue, and sends the request body on its own. - # (This is suboptimal, and is not recommended.) - # - # We used to do 3, but are now doing 1. Maybe we'll do 2 someday, - # but it seems like it would be a big slowdown for such a rare case. - if self.inheaders.get("Expect", "") == "100-continue": - # Don't use simple_response here, because it emits headers - # we don't want. See http://www.cherrypy.org/ticket/951 - msg = self.server.protocol + " 100 Continue\r\n\r\n" - try: - self.conn.wfile.sendall(msg) - except socket.error, x: - if x.args[0] not in socket_errors_to_ignore: - raise - return True - - def parse_request_uri(self, uri): - """Parse a Request-URI into (scheme, authority, path). - - Note that Request-URI's must be one of:: - - Request-URI = "*" | absoluteURI | abs_path | authority - - Therefore, a Request-URI which starts with a double forward-slash - cannot be a "net_path":: - - net_path = "//" authority [ abs_path ] - - Instead, it must be interpreted as an "abs_path" with an empty first - path segment:: - - abs_path = "/" path_segments - path_segments = segment *( "/" segment ) - segment = *pchar *( ";" param ) - param = *pchar - """ - if uri == "*": - return None, None, uri - - i = uri.find('://') - if i > 0 and '?' not in uri[:i]: - # An absoluteURI. - # If there's a scheme (and it must be http or https), then: - # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query ]] - scheme, remainder = uri[:i].lower(), uri[i + 3:] - authority, path = remainder.split("/", 1) - return scheme, authority, path - - if uri.startswith('/'): - # An abs_path. - return None, None, uri - else: - # An authority. 
- return None, uri, None - - def respond(self): - """Call the gateway and write its iterable output.""" - mrbs = self.server.max_request_body_size - if self.chunked_read: - self.rfile = ChunkedRFile(self.conn.rfile, mrbs) - else: - cl = int(self.inheaders.get("Content-Length", 0)) - if mrbs and mrbs < cl: - if not self.sent_headers: - self.simple_response("413 Request Entity Too Large", - "The entity sent with the request exceeds the maximum " - "allowed bytes.") - return - self.rfile = KnownLengthRFile(self.conn.rfile, cl) - - self.server.gateway(self).respond() - - if (self.ready and not self.sent_headers): - self.sent_headers = True - self.send_headers() - if self.chunked_write: - self.conn.wfile.sendall("0\r\n\r\n") - - def simple_response(self, status, msg=""): - """Write a simple response back to the client.""" - status = str(status) - buf = [self.server.protocol + " " + - status + CRLF, - "Content-Length: %s\r\n" % len(msg), - "Content-Type: text/plain\r\n"] - - if status[:3] in ("413", "414"): - # Request Entity Too Large / Request-URI Too Long - self.close_connection = True - if self.response_protocol == 'HTTP/1.1': - # This will not be true for 414, since read_request_line - # usually raises 414 before reading the whole line, and we - # therefore cannot know the proper response_protocol. - buf.append("Connection: close\r\n") - else: - # HTTP/1.0 had no 413/414 status nor Connection header. - # Emit 400 instead and trust the message body is enough. 
- status = "400 Bad Request" - - buf.append(CRLF) - if msg: - if isinstance(msg, unicode): - msg = msg.encode("ISO-8859-1") - buf.append(msg) - - try: - self.conn.wfile.sendall("".join(buf)) - except socket.error, x: - if x.args[0] not in socket_errors_to_ignore: - raise - - def write(self, chunk): - """Write unbuffered data to the client.""" - if self.chunked_write and chunk: - buf = [hex(len(chunk))[2:], CRLF, chunk, CRLF] - self.conn.wfile.sendall("".join(buf)) - else: - self.conn.wfile.sendall(chunk) - - def send_headers(self): - """Assert, process, and send the HTTP response message-headers. - - You must set self.status, and self.outheaders before calling this. - """ - hkeys = [key.lower() for key, value in self.outheaders] - status = int(self.status[:3]) - - if status == 413: - # Request Entity Too Large. Close conn to avoid garbage. - self.close_connection = True - elif "content-length" not in hkeys: - # "All 1xx (informational), 204 (no content), - # and 304 (not modified) responses MUST NOT - # include a message-body." So no point chunking. - if status < 200 or status in (204, 205, 304): - pass - else: - if (self.response_protocol == 'HTTP/1.1' - and self.method != 'HEAD'): - # Use the chunked transfer-coding - self.chunked_write = True - self.outheaders.append(("Transfer-Encoding", "chunked")) - else: - # Closing the conn is the only way to determine len. - self.close_connection = True - - if "connection" not in hkeys: - if self.response_protocol == 'HTTP/1.1': - # Both server and client are HTTP/1.1 or better - if self.close_connection: - self.outheaders.append(("Connection", "close")) - else: - # Server and/or client are HTTP/1.0 - if not self.close_connection: - self.outheaders.append(("Connection", "Keep-Alive")) - - if (not self.close_connection) and (not self.chunked_read): - # Read any remaining request body data on the socket. 
- # "If an origin server receives a request that does not include an - # Expect request-header field with the "100-continue" expectation, - # the request includes a request body, and the server responds - # with a final status code before reading the entire request body - # from the transport connection, then the server SHOULD NOT close - # the transport connection until it has read the entire request, - # or until the client closes the connection. Otherwise, the client - # might not reliably receive the response message. However, this - # requirement is not be construed as preventing a server from - # defending itself against denial-of-service attacks, or from - # badly broken client implementations." - remaining = getattr(self.rfile, 'remaining', 0) - if remaining > 0: - self.rfile.read(remaining) - - if "date" not in hkeys: - self.outheaders.append(("Date", rfc822.formatdate())) - - if "server" not in hkeys: - self.outheaders.append(("Server", self.server.server_name)) - - buf = [self.server.protocol + " " + self.status + CRLF] - for k, v in self.outheaders: - buf.append(k + ": " + v + CRLF) - buf.append(CRLF) - self.conn.wfile.sendall("".join(buf)) - - -class NoSSLError(Exception): - """Exception raised when a client speaks HTTP to an HTTPS socket.""" - pass - - -class FatalSSLAlert(Exception): - """Exception raised when the SSL implementation signals a fatal alert.""" - pass - - -class CP_fileobject(socket._fileobject): - """Faux file object attached to a socket object.""" - - def __init__(self, *args, **kwargs): - self.bytes_read = 0 - self.bytes_written = 0 - socket._fileobject.__init__(self, *args, **kwargs) - - def sendall(self, data): - """Sendall for non-blocking sockets.""" - while data: - try: - bytes_sent = self.send(data) - data = data[bytes_sent:] - except socket.error, e: - if e.args[0] not in socket_errors_nonblocking: - raise - - def send(self, data): - bytes_sent = self._sock.send(data) - self.bytes_written += bytes_sent - return bytes_sent - - 
def flush(self): - if self._wbuf: - buffer = "".join(self._wbuf) - self._wbuf = [] - self.sendall(buffer) - - def recv(self, size): - while True: - try: - data = self._sock.recv(size) - self.bytes_read += len(data) - return data - except socket.error, e: - if (e.args[0] not in socket_errors_nonblocking - and e.args[0] not in socket_error_eintr): - raise - - if not _fileobject_uses_str_type: - def read(self, size=-1): - # Use max, disallow tiny reads in a loop as they are very inefficient. - # We never leave read() with any leftover data from a new recv() call - # in our internal buffer. - rbufsize = max(self._rbufsize, self.default_bufsize) - # Our use of StringIO rather than lists of string objects returned by - # recv() minimizes memory usage and fragmentation that occurs when - # rbufsize is large compared to the typical return value of recv(). - buf = self._rbuf - buf.seek(0, 2) # seek end - if size < 0: - # Read until EOF - self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. - while True: - data = self.recv(rbufsize) - if not data: - break - buf.write(data) - return buf.getvalue() - else: - # Read until size bytes or EOF seen, whichever comes first - buf_len = buf.tell() - if buf_len >= size: - # Already have size bytes in our buffer? Extract and return. - buf.seek(0) - rv = buf.read(size) - self._rbuf = StringIO.StringIO() - self._rbuf.write(buf.read()) - return rv - - self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. - while True: - left = size - buf_len - # recv() will malloc the amount of memory given as its - # parameter even though it often returns much less data - # than that. The returned data string is short lived - # as we copy it into a StringIO and free it. This avoids - # fragmentation issues on many platforms. - data = self.recv(left) - if not data: - break - n = len(data) - if n == size and not buf_len: - # Shortcut. Avoid buffer data copies when: - # - We have no data in our buffer. 
- # AND - # - Our call to recv returned exactly the - # number of bytes we were asked to read. - return data - if n == left: - buf.write(data) - del data # explicit free - break - assert n <= left, "recv(%d) returned %d bytes" % (left, n) - buf.write(data) - buf_len += n - del data # explicit free - #assert buf_len == buf.tell() - return buf.getvalue() - - def readline(self, size=-1): - buf = self._rbuf - buf.seek(0, 2) # seek end - if buf.tell() > 0: - # check if we already have it in our buffer - buf.seek(0) - bline = buf.readline(size) - if bline.endswith('\n') or len(bline) == size: - self._rbuf = StringIO.StringIO() - self._rbuf.write(buf.read()) - return bline - del bline - if size < 0: - # Read until \n or EOF, whichever comes first - if self._rbufsize <= 1: - # Speed up unbuffered case - buf.seek(0) - buffers = [buf.read()] - self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. - data = None - recv = self.recv - while data != "\n": - data = recv(1) - if not data: - break - buffers.append(data) - return "".join(buffers) - - buf.seek(0, 2) # seek end - self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. - while True: - data = self.recv(self._rbufsize) - if not data: - break - nl = data.find('\n') - if nl >= 0: - nl += 1 - buf.write(data[:nl]) - self._rbuf.write(data[nl:]) - del data - break - buf.write(data) - return buf.getvalue() - else: - # Read until size bytes or \n or EOF seen, whichever comes first - buf.seek(0, 2) # seek end - buf_len = buf.tell() - if buf_len >= size: - buf.seek(0) - rv = buf.read(size) - self._rbuf = StringIO.StringIO() - self._rbuf.write(buf.read()) - return rv - self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. - while True: - data = self.recv(self._rbufsize) - if not data: - break - left = size - buf_len - # did we just receive a newline? 
- nl = data.find('\n', 0, left) - if nl >= 0: - nl += 1 - # save the excess data to _rbuf - self._rbuf.write(data[nl:]) - if buf_len: - buf.write(data[:nl]) - break - else: - # Shortcut. Avoid data copy through buf when returning - # a substring of our first recv(). - return data[:nl] - n = len(data) - if n == size and not buf_len: - # Shortcut. Avoid data copy through buf when - # returning exactly all of our first recv(). - return data - if n >= left: - buf.write(data[:left]) - self._rbuf.write(data[left:]) - break - buf.write(data) - buf_len += n - #assert buf_len == buf.tell() - return buf.getvalue() - else: - def read(self, size=-1): - if size < 0: - # Read until EOF - buffers = [self._rbuf] - self._rbuf = "" - if self._rbufsize <= 1: - recv_size = self.default_bufsize - else: - recv_size = self._rbufsize - - while True: - data = self.recv(recv_size) - if not data: - break - buffers.append(data) - return "".join(buffers) - else: - # Read until size bytes or EOF seen, whichever comes first - data = self._rbuf - buf_len = len(data) - if buf_len >= size: - self._rbuf = data[size:] - return data[:size] - buffers = [] - if data: - buffers.append(data) - self._rbuf = "" - while True: - left = size - buf_len - recv_size = max(self._rbufsize, left) - data = self.recv(recv_size) - if not data: - break - buffers.append(data) - n = len(data) - if n >= left: - self._rbuf = data[left:] - buffers[-1] = data[:left] - break - buf_len += n - return "".join(buffers) - - def readline(self, size=-1): - data = self._rbuf - if size < 0: - # Read until \n or EOF, whichever comes first - if self._rbufsize <= 1: - # Speed up unbuffered case - assert data == "" - buffers = [] - while data != "\n": - data = self.recv(1) - if not data: - break - buffers.append(data) - return "".join(buffers) - nl = data.find('\n') - if nl >= 0: - nl += 1 - self._rbuf = data[nl:] - return data[:nl] - buffers = [] - if data: - buffers.append(data) - self._rbuf = "" - while True: - data = 
self.recv(self._rbufsize) - if not data: - break - buffers.append(data) - nl = data.find('\n') - if nl >= 0: - nl += 1 - self._rbuf = data[nl:] - buffers[-1] = data[:nl] - break - return "".join(buffers) - else: - # Read until size bytes or \n or EOF seen, whichever comes first - nl = data.find('\n', 0, size) - if nl >= 0: - nl += 1 - self._rbuf = data[nl:] - return data[:nl] - buf_len = len(data) - if buf_len >= size: - self._rbuf = data[size:] - return data[:size] - buffers = [] - if data: - buffers.append(data) - self._rbuf = "" - while True: - data = self.recv(self._rbufsize) - if not data: - break - buffers.append(data) - left = size - buf_len - nl = data.find('\n', 0, left) - if nl >= 0: - nl += 1 - self._rbuf = data[nl:] - buffers[-1] = data[:nl] - break - n = len(data) - if n >= left: - self._rbuf = data[left:] - buffers[-1] = data[:left] - break - buf_len += n - return "".join(buffers) - - -class HTTPConnection(object): - """An HTTP connection (active socket). - - server: the Server object which received this connection. - socket: the raw socket object (usually TCP) for this connection. - makefile: a fileobject class for reading from the socket. - """ - - remote_addr = None - remote_port = None - ssl_env = None - rbufsize = DEFAULT_BUFFER_SIZE - wbufsize = DEFAULT_BUFFER_SIZE - RequestHandlerClass = HTTPRequest - - def __init__(self, server, sock, makefile=CP_fileobject): - self.server = server - self.socket = sock - self.rfile = makefile(sock, "rb", self.rbufsize) - self.wfile = makefile(sock, "wb", self.wbufsize) - self.requests_seen = 0 - - def communicate(self): - """Read each request and respond appropriately.""" - request_seen = False - try: - while True: - # (re)set req to None so that if something goes wrong in - # the RequestHandlerClass constructor, the error doesn't - # get written to the previous request. - req = None - req = self.RequestHandlerClass(self.server, self) - - # This order of operations should guarantee correct pipelining. 
- req.parse_request() - if self.server.stats['Enabled']: - self.requests_seen += 1 - if not req.ready: - # Something went wrong in the parsing (and the server has - # probably already made a simple_response). Return and - # let the conn close. - return - - request_seen = True - req.respond() - if req.close_connection: - return - except socket.error, e: - errnum = e.args[0] - # sadly SSL sockets return a different (longer) time out string - if errnum == 'timed out' or errnum == 'The read operation timed out': - # Don't error if we're between requests; only error - # if 1) no request has been started at all, or 2) we're - # in the middle of a request. - # See http://www.cherrypy.org/ticket/853 - if (not request_seen) or (req and req.started_request): - # Don't bother writing the 408 if the response - # has already started being written. - if req and not req.sent_headers: - try: - req.simple_response("408 Request Timeout") - except FatalSSLAlert: - # Close the connection. - return - elif errnum not in socket_errors_to_ignore: - if req and not req.sent_headers: - try: - req.simple_response("500 Internal Server Error", - format_exc()) - except FatalSSLAlert: - # Close the connection. - return - return - except (KeyboardInterrupt, SystemExit): - raise - except FatalSSLAlert: - # Close the connection. - return - except NoSSLError: - if req and not req.sent_headers: - # Unwrap our wfile - self.wfile = CP_fileobject(self.socket._sock, "wb", self.wbufsize) - req.simple_response("400 Bad Request", - "The client sent a plain HTTP request, but " - "this server only speaks HTTPS on this port.") - self.linger = True - except Exception: - if req and not req.sent_headers: - try: - req.simple_response("500 Internal Server Error", format_exc()) - except FatalSSLAlert: - # Close the connection. 
- return - - linger = False - - def close(self): - """Close the socket underlying this connection.""" - self.rfile.close() - - if not self.linger: - # Python's socket module does NOT call close on the kernel socket - # when you call socket.close(). We do so manually here because we - # want this server to send a FIN TCP segment immediately. Note this - # must be called *before* calling socket.close(), because the latter - # drops its reference to the kernel socket. - if hasattr(self.socket, '_sock'): - self.socket._sock.close() - self.socket.close() - else: - # On the other hand, sometimes we want to hang around for a bit - # to make sure the client has a chance to read our entire - # response. Skipping the close() calls here delays the FIN - # packet until the socket object is garbage-collected later. - # Someday, perhaps, we'll do the full lingering_close that - # Apache does, but not today. - pass - - -_SHUTDOWNREQUEST = None - -class WorkerThread(threading.Thread): - """Thread which continuously polls a Queue for Connection objects. - - Due to the timing issues of polling a Queue, a WorkerThread does not - check its own 'ready' flag after it has started. To stop the thread, - it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue - (one for each running WorkerThread). 
- """ - - conn = None - """The current connection pulled off the Queue, or None.""" - - server = None - """The HTTP Server which spawned this thread, and which owns the - Queue and is placing active connections into it.""" - - ready = False - """A simple flag for the calling server to know when this thread - has begun polling the Queue.""" - - - def __init__(self, server): - self.ready = False - self.server = server - - self.requests_seen = 0 - self.bytes_read = 0 - self.bytes_written = 0 - self.start_time = None - self.work_time = 0 - self.stats = { - 'Requests': lambda s: self.requests_seen + ((self.start_time is None) and 0 or self.conn.requests_seen), - 'Bytes Read': lambda s: self.bytes_read + ((self.start_time is None) and 0 or self.conn.rfile.bytes_read), - 'Bytes Written': lambda s: self.bytes_written + ((self.start_time is None) and 0 or self.conn.wfile.bytes_written), - 'Work Time': lambda s: self.work_time + ((self.start_time is None) and 0 or time.time() - self.start_time), - 'Read Throughput': lambda s: s['Bytes Read'](s) / (s['Work Time'](s) or 1e-6), - 'Write Throughput': lambda s: s['Bytes Written'](s) / (s['Work Time'](s) or 1e-6), - } - threading.Thread.__init__(self) - - def run(self): - self.server.stats['Worker Threads'][self.getName()] = self.stats - try: - self.ready = True - while True: - conn = self.server.requests.get() - if conn is _SHUTDOWNREQUEST: - return - - self.conn = conn - if self.server.stats['Enabled']: - self.start_time = time.time() - try: - conn.communicate() - finally: - conn.close() - if self.server.stats['Enabled']: - self.requests_seen += self.conn.requests_seen - self.bytes_read += self.conn.rfile.bytes_read - self.bytes_written += self.conn.wfile.bytes_written - self.work_time += time.time() - self.start_time - self.start_time = None - self.conn = None - except (KeyboardInterrupt, SystemExit), exc: - self.server.interrupt = exc - - -class ThreadPool(object): - """A Request Queue for the CherryPyWSGIServer which pools 
threads. - - ThreadPool objects must provide min, get(), put(obj), start() - and stop(timeout) attributes. - """ - - def __init__(self, server, min=10, max=-1): - self.server = server - self.min = min - self.max = max - self._threads = [] - self._queue = Queue.Queue() - self.get = self._queue.get - - def start(self): - """Start the pool of threads.""" - for i in range(self.min): - self._threads.append(WorkerThread(self.server)) - for worker in self._threads: - worker.setName("CP Server " + worker.getName()) - worker.start() - for worker in self._threads: - while not worker.ready: - time.sleep(.1) - - def _get_idle(self): - """Number of worker threads which are idle. Read-only.""" - return len([t for t in self._threads if t.conn is None]) - idle = property(_get_idle, doc=_get_idle.__doc__) - - def put(self, obj): - self._queue.put(obj) - if obj is _SHUTDOWNREQUEST: - return - - def grow(self, amount): - """Spawn new worker threads (not above self.max).""" - for i in range(amount): - if self.max > 0 and len(self._threads) >= self.max: - break - worker = WorkerThread(self.server) - worker.setName("CP Server " + worker.getName()) - self._threads.append(worker) - worker.start() - - def shrink(self, amount): - """Kill off worker threads (not below self.min).""" - # Grow/shrink the pool if necessary. - # Remove any dead threads from our list - for t in self._threads: - if not t.isAlive(): - self._threads.remove(t) - amount -= 1 - - if amount > 0: - for i in range(min(amount, len(self._threads) - self.min)): - # Put a number of shutdown requests on the queue equal - # to 'amount'. Once each of those is processed by a worker, - # that worker will terminate and be culled from our list - # in self.put. - self._queue.put(_SHUTDOWNREQUEST) - - def stop(self, timeout=5): - # Must shut down threads here so the code that calls - # this method can know when all threads are stopped. 
- for worker in self._threads: - self._queue.put(_SHUTDOWNREQUEST) - - # Don't join currentThread (when stop is called inside a request). - current = threading.currentThread() - if timeout and timeout >= 0: - endtime = time.time() + timeout - while self._threads: - worker = self._threads.pop() - if worker is not current and worker.isAlive(): - try: - if timeout is None or timeout < 0: - worker.join() - else: - remaining_time = endtime - time.time() - if remaining_time > 0: - worker.join(remaining_time) - if worker.isAlive(): - # We exhausted the timeout. - # Forcibly shut down the socket. - c = worker.conn - if c and not c.rfile.closed: - try: - c.socket.shutdown(socket.SHUT_RD) - except TypeError: - # pyOpenSSL sockets don't take an arg - c.socket.shutdown() - worker.join() - except (AssertionError, - # Ignore repeated Ctrl-C. - # See http://www.cherrypy.org/ticket/691. - KeyboardInterrupt), exc1: - pass - - def _get_qsize(self): - return self._queue.qsize() - qsize = property(_get_qsize) - - - -try: - import fcntl -except ImportError: - try: - from ctypes import windll, WinError - except ImportError: - def prevent_socket_inheritance(sock): - """Dummy function, since neither fcntl nor ctypes are available.""" - pass - else: - def prevent_socket_inheritance(sock): - """Mark the given socket fd as non-inheritable (Windows).""" - if not windll.kernel32.SetHandleInformation(sock.fileno(), 1, 0): - raise WinError() +if sys.version_info < (3, 0): + from wsgiserver2 import * else: - def prevent_socket_inheritance(sock): - """Mark the given socket fd as non-inheritable (POSIX).""" - fd = sock.fileno() - old_flags = fcntl.fcntl(fd, fcntl.F_GETFD) - fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC) - - -class SSLAdapter(object): - """Base class for SSL driver library adapters. 
- - Required methods: - - * ``wrap(sock) -> (wrapped socket, ssl environ dict)`` - * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> socket file object`` - """ - - def __init__(self, certificate, private_key, certificate_chain=None): - self.certificate = certificate - self.private_key = private_key - self.certificate_chain = certificate_chain - - def wrap(self, sock): - raise NotImplemented - - def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): - raise NotImplemented - - -class HTTPServer(object): - """An HTTP server.""" - - _bind_addr = "127.0.0.1" - _interrupt = None - - gateway = None - """A Gateway instance.""" - - minthreads = None - """The minimum number of worker threads to create (default 10).""" - - maxthreads = None - """The maximum number of worker threads to create (default -1 = no limit).""" - - server_name = None - """The name of the server; defaults to socket.gethostname().""" - - protocol = "HTTP/1.1" - """The version string to write in the Status-Line of all HTTP responses. - - For example, "HTTP/1.1" is the default. This also limits the supported - features used in the response.""" - - request_queue_size = 5 - """The 'backlog' arg to socket.listen(); max queued connections (default 5).""" - - shutdown_timeout = 5 - """The total time, in seconds, to wait for worker threads to cleanly exit.""" - - timeout = 10 - """The timeout in seconds for accepted connections (default 10).""" - - version = "CherryPy/3.2.0" - """A version string for the HTTPServer.""" - - software = None - """The value to set for the SERVER_SOFTWARE entry in the WSGI environ. 
- - If None, this defaults to ``'%s Server' % self.version``.""" - - ready = False - """An internal flag which marks whether the socket is accepting connections.""" - - max_request_header_size = 0 - """The maximum size, in bytes, for request headers, or 0 for no limit.""" - - max_request_body_size = 0 - """The maximum size, in bytes, for request bodies, or 0 for no limit.""" - - nodelay = True - """If True (the default since 3.1), sets the TCP_NODELAY socket option.""" - - ConnectionClass = HTTPConnection - """The class to use for handling HTTP connections.""" - - ssl_adapter = None - """An instance of SSLAdapter (or a subclass). - - You must have the corresponding SSL driver library installed.""" - - def __init__(self, bind_addr, gateway, minthreads=10, maxthreads=-1, - server_name=None): - self.bind_addr = bind_addr - self.gateway = gateway - - self.requests = ThreadPool(self, min=minthreads or 1, max=maxthreads) - - if not server_name: - server_name = socket.gethostname() - self.server_name = server_name - self.clear_stats() - - def clear_stats(self): - self._start_time = None - self._run_time = 0 - self.stats = { - 'Enabled': False, - 'Bind Address': lambda s: repr(self.bind_addr), - 'Run time': lambda s: (not s['Enabled']) and 0 or self.runtime(), - 'Accepts': 0, - 'Accepts/sec': lambda s: s['Accepts'] / self.runtime(), - 'Queue': lambda s: getattr(self.requests, "qsize", None), - 'Threads': lambda s: len(getattr(self.requests, "_threads", [])), - 'Threads Idle': lambda s: getattr(self.requests, "idle", None), - 'Socket Errors': 0, - 'Requests': lambda s: (not s['Enabled']) and 0 or sum([w['Requests'](w) for w - in s['Worker Threads'].values()], 0), - 'Bytes Read': lambda s: (not s['Enabled']) and 0 or sum([w['Bytes Read'](w) for w - in s['Worker Threads'].values()], 0), - 'Bytes Written': lambda s: (not s['Enabled']) and 0 or sum([w['Bytes Written'](w) for w - in s['Worker Threads'].values()], 0), - 'Work Time': lambda s: (not s['Enabled']) and 0 or 
sum([w['Work Time'](w) for w - in s['Worker Threads'].values()], 0), - 'Read Throughput': lambda s: (not s['Enabled']) and 0 or sum( - [w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6) - for w in s['Worker Threads'].values()], 0), - 'Write Throughput': lambda s: (not s['Enabled']) and 0 or sum( - [w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6) - for w in s['Worker Threads'].values()], 0), - 'Worker Threads': {}, - } - logging.statistics["CherryPy HTTPServer %d" % id(self)] = self.stats - - def runtime(self): - if self._start_time is None: - return self._run_time - else: - return self._run_time + (time.time() - self._start_time) - - def __str__(self): - return "%s.%s(%r)" % (self.__module__, self.__class__.__name__, - self.bind_addr) - - def _get_bind_addr(self): - return self._bind_addr - def _set_bind_addr(self, value): - if isinstance(value, tuple) and value[0] in ('', None): - # Despite the socket module docs, using '' does not - # allow AI_PASSIVE to work. Passing None instead - # returns '0.0.0.0' like we want. In other words: - # host AI_PASSIVE result - # '' Y 192.168.x.y - # '' N 192.168.x.y - # None Y 0.0.0.0 - # None N 127.0.0.1 - # But since you can get the same effect with an explicit - # '0.0.0.0', we deny both the empty string and None as values. - raise ValueError("Host values of '' or None are not allowed. " - "Use '0.0.0.0' (IPv4) or '::' (IPv6) instead " - "to listen on all active interfaces.") - self._bind_addr = value - bind_addr = property(_get_bind_addr, _set_bind_addr, - doc="""The interface on which to listen for connections. - - For TCP sockets, a (host, port) tuple. Host values may be any IPv4 - or IPv6 address, or any valid hostname. The string 'localhost' is a - synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6). - The string '0.0.0.0' is a special IPv4 entry meaning "any active - interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for - IPv6. The empty string or None are not allowed. 
- - For UNIX sockets, supply the filename as a string.""") - - def start(self): - """Run the server forever.""" - # We don't have to trap KeyboardInterrupt or SystemExit here, - # because cherrpy.server already does so, calling self.stop() for us. - # If you're using this server with another framework, you should - # trap those exceptions in whatever code block calls start(). - self._interrupt = None - - if self.software is None: - self.software = "%s Server" % self.version - - # SSL backward compatibility - if (self.ssl_adapter is None and - getattr(self, 'ssl_certificate', None) and - getattr(self, 'ssl_private_key', None)): - warnings.warn( - "SSL attributes are deprecated in CherryPy 3.2, and will " - "be removed in CherryPy 3.3. Use an ssl_adapter attribute " - "instead.", - DeprecationWarning - ) - try: - from cherrypy.wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter - except ImportError: - pass - else: - self.ssl_adapter = pyOpenSSLAdapter( - self.ssl_certificate, self.ssl_private_key, - getattr(self, 'ssl_certificate_chain', None)) - - # Select the appropriate socket - if isinstance(self.bind_addr, basestring): - # AF_UNIX socket - - # So we can reuse the socket... - try: os.unlink(self.bind_addr) - except: pass - - # So everyone can access the socket... 
- try: os.chmod(self.bind_addr, 0777) - except: pass - - info = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_addr)] - else: - # AF_INET or AF_INET6 socket - # Get the correct address family for our host (allows IPv6 addresses) - host, port = self.bind_addr - try: - info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, - socket.SOCK_STREAM, 0, socket.AI_PASSIVE) - except socket.gaierror: - if ':' in self.bind_addr[0]: - info = [(socket.AF_INET6, socket.SOCK_STREAM, - 0, "", self.bind_addr + (0, 0))] - else: - info = [(socket.AF_INET, socket.SOCK_STREAM, - 0, "", self.bind_addr)] - - self.socket = None - msg = "No socket could be created" - for res in info: - af, socktype, proto, canonname, sa = res - try: - self.bind(af, socktype, proto) - except socket.error: - if self.socket: - self.socket.close() - self.socket = None - continue - break - if not self.socket: - raise socket.error(msg) - - # Timeout so KeyboardInterrupt can be caught on Win32 - self.socket.settimeout(1) - self.socket.listen(self.request_queue_size) - - # Create worker threads - self.requests.start() - - self.ready = True - self._start_time = time.time() - while self.ready: - self.tick() - if self.interrupt: - while self.interrupt is True: - # Wait for self.stop() to complete. See _set_interrupt. - time.sleep(0.1) - if self.interrupt: - raise self.interrupt - - def bind(self, family, type, proto=0): - """Create (or recreate) the actual socket object.""" - self.socket = socket.socket(family, type, proto) - prevent_socket_inheritance(self.socket) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if self.nodelay and not isinstance(self.bind_addr, str): - self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - if self.ssl_adapter is not None: - self.socket = self.ssl_adapter.bind(self.socket) - - # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), - # activate dual-stack. See http://www.cherrypy.org/ticket/871. 
- if (hasattr(socket, 'AF_INET6') and family == socket.AF_INET6 - and self.bind_addr[0] in ('::', '::0', '::0.0.0.0')): - try: - self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) - except (AttributeError, socket.error): - # Apparently, the socket option is not available in - # this machine's TCP stack - pass - - self.socket.bind(self.bind_addr) - - def tick(self): - """Accept a new connection and put it on the Queue.""" - try: - s, addr = self.socket.accept() - if self.stats['Enabled']: - self.stats['Accepts'] += 1 - if not self.ready: - return - - prevent_socket_inheritance(s) - if hasattr(s, 'settimeout'): - s.settimeout(self.timeout) - - makefile = CP_fileobject - ssl_env = {} - # if ssl cert and key are set, we try to be a secure HTTP server - if self.ssl_adapter is not None: - try: - s, ssl_env = self.ssl_adapter.wrap(s) - except NoSSLError: - msg = ("The client sent a plain HTTP request, but " - "this server only speaks HTTPS on this port.") - buf = ["%s 400 Bad Request\r\n" % self.protocol, - "Content-Length: %s\r\n" % len(msg), - "Content-Type: text/plain\r\n\r\n", - msg] - - wfile = CP_fileobject(s, "wb", DEFAULT_BUFFER_SIZE) - try: - wfile.sendall("".join(buf)) - except socket.error, x: - if x.args[0] not in socket_errors_to_ignore: - raise - return - if not s: - return - makefile = self.ssl_adapter.makefile - # Re-apply our timeout since we may have a new socket object - if hasattr(s, 'settimeout'): - s.settimeout(self.timeout) - - conn = self.ConnectionClass(self, s, makefile) - - if not isinstance(self.bind_addr, basestring): - # optional values - # Until we do DNS lookups, omit REMOTE_HOST - if addr is None: # sometimes this can happen - # figure out if AF_INET or AF_INET6. 
- if len(s.getsockname()) == 2: - # AF_INET - addr = ('0.0.0.0', 0) - else: - # AF_INET6 - addr = ('::', 0) - conn.remote_addr = addr[0] - conn.remote_port = addr[1] - - conn.ssl_env = ssl_env - - self.requests.put(conn) - except socket.timeout: - # The only reason for the timeout in start() is so we can - # notice keyboard interrupts on Win32, which don't interrupt - # accept() by default - return - except socket.error, x: - if self.stats['Enabled']: - self.stats['Socket Errors'] += 1 - if x.args[0] in socket_error_eintr: - # I *think* this is right. EINTR should occur when a signal - # is received during the accept() call; all docs say retry - # the call, and I *think* I'm reading it right that Python - # will then go ahead and poll for and handle the signal - # elsewhere. See http://www.cherrypy.org/ticket/707. - return - if x.args[0] in socket_errors_nonblocking: - # Just try again. See http://www.cherrypy.org/ticket/479. - return - if x.args[0] in socket_errors_to_ignore: - # Our socket was closed. - # See http://www.cherrypy.org/ticket/686. - return - raise - - def _get_interrupt(self): - return self._interrupt - def _set_interrupt(self, interrupt): - self._interrupt = True - self.stop() - self._interrupt = interrupt - interrupt = property(_get_interrupt, _set_interrupt, - doc="Set this to an Exception instance to " - "interrupt the server.") - - def stop(self): - """Gracefully shutdown a server that is serving forever.""" - self.ready = False - if self._start_time is not None: - self._run_time += (time.time() - self._start_time) - self._start_time = None - - sock = getattr(self, "socket", None) - if sock: - if not isinstance(self.bind_addr, basestring): - # Touch our own socket to make accept() return immediately. - try: - host, port = sock.getsockname()[:2] - except socket.error, x: - if x.args[0] not in socket_errors_to_ignore: - # Changed to use error code and not message - # See http://www.cherrypy.org/ticket/860. 
- raise - else: - # Note that we're explicitly NOT using AI_PASSIVE, - # here, because we want an actual IP to touch. - # localhost won't work if we've bound to a public IP, - # but it will if we bound to '0.0.0.0' (INADDR_ANY). - for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, - socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - s = None - try: - s = socket.socket(af, socktype, proto) - # See http://groups.google.com/group/cherrypy-users/ - # browse_frm/thread/bbfe5eb39c904fe0 - s.settimeout(1.0) - s.connect((host, port)) - s.close() - except socket.error: - if s: - s.close() - if hasattr(sock, "close"): - sock.close() - self.socket = None - - self.requests.stop(self.shutdown_timeout) - - -class Gateway(object): - - def __init__(self, req): - self.req = req - - def respond(self): - raise NotImplemented - - -# These may either be wsgiserver.SSLAdapter subclasses or the string names -# of such classes (in which case they will be lazily loaded). -ssl_adapters = { - 'builtin': 'cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter', - 'pyopenssl': 'cherrypy.wsgiserver.ssl_pyopenssl.pyOpenSSLAdapter', - } - -def get_ssl_adapter_class(name='pyopenssl'): - adapter = ssl_adapters[name.lower()] - if isinstance(adapter, basestring): - last_dot = adapter.rfind(".") - attr_name = adapter[last_dot + 1:] - mod_path = adapter[:last_dot] - - try: - mod = sys.modules[mod_path] - if mod is None: - raise KeyError() - except KeyError: - # The last [''] is important. - mod = __import__(mod_path, globals(), locals(), ['']) - - # Let an AttributeError propagate outward. 
- try: - adapter = getattr(mod, attr_name) - except AttributeError: - raise AttributeError("'%s' object has no attribute '%s'" - % (mod_path, attr_name)) - - return adapter - -# -------------------------------- WSGI Stuff -------------------------------- # - - -class CherryPyWSGIServer(HTTPServer): - - wsgi_version = (1, 0) - - def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None, - max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5): - self.requests = ThreadPool(self, min=numthreads or 1, max=max) - self.wsgi_app = wsgi_app - self.gateway = wsgi_gateways[self.wsgi_version] - - self.bind_addr = bind_addr - if not server_name: - server_name = socket.gethostname() - self.server_name = server_name - self.request_queue_size = request_queue_size - - self.timeout = timeout - self.shutdown_timeout = shutdown_timeout - self.clear_stats() - - def _get_numthreads(self): - return self.requests.min - def _set_numthreads(self, value): - self.requests.min = value - numthreads = property(_get_numthreads, _set_numthreads) - - -class WSGIGateway(Gateway): - - def __init__(self, req): - self.req = req - self.started_response = False - self.env = self.get_environ() - self.remaining_bytes_out = None - - def get_environ(self): - """Return a new environ dict targeting the given wsgi.version""" - raise NotImplemented - - def respond(self): - response = self.req.server.wsgi_app(self.env, self.start_response) - try: - for chunk in response: - # "The start_response callable must not actually transmit - # the response headers. Instead, it must store them for the - # server or gateway to transmit only after the first - # iteration of the application return value that yields - # a NON-EMPTY string, or upon the application's first - # invocation of the write() callable." 
(PEP 333) - if chunk: - if isinstance(chunk, unicode): - chunk = chunk.encode('ISO-8859-1') - self.write(chunk) - finally: - if hasattr(response, "close"): - response.close() - - def start_response(self, status, headers, exc_info = None): - """WSGI callable to begin the HTTP response.""" - # "The application may call start_response more than once, - # if and only if the exc_info argument is provided." - if self.started_response and not exc_info: - raise AssertionError("WSGI start_response called a second " - "time with no exc_info.") - self.started_response = True - - # "if exc_info is provided, and the HTTP headers have already been - # sent, start_response must raise an error, and should raise the - # exc_info tuple." - if self.req.sent_headers: - try: - raise exc_info[0], exc_info[1], exc_info[2] - finally: - exc_info = None - - self.req.status = status - for k, v in headers: - if not isinstance(k, str): - raise TypeError("WSGI response header key %r is not a byte string." % k) - if not isinstance(v, str): - raise TypeError("WSGI response header value %r is not a byte string." % v) - if k.lower() == 'content-length': - self.remaining_bytes_out = int(v) - self.req.outheaders.extend(headers) - - return self.write - - def write(self, chunk): - """WSGI callable to write unbuffered data to the client. - - This method is also used internally by start_response (to write - data from the iterable returned by the WSGI application). - """ - if not self.started_response: - raise AssertionError("WSGI write called before start_response.") - - chunklen = len(chunk) - rbo = self.remaining_bytes_out - if rbo is not None and chunklen > rbo: - if not self.req.sent_headers: - # Whew. We can send a 500 to the client. - self.req.simple_response("500 Internal Server Error", - "The requested resource returned more bytes than the " - "declared Content-Length.") - else: - # Dang. We have probably already sent data. 
Truncate the chunk - # to fit (so the client doesn't hang) and raise an error later. - chunk = chunk[:rbo] - - if not self.req.sent_headers: - self.req.sent_headers = True - self.req.send_headers() - - self.req.write(chunk) - - if rbo is not None: - rbo -= chunklen - if rbo < 0: - raise ValueError( - "Response body exceeds the declared Content-Length.") - - -class WSGIGateway_10(WSGIGateway): - - def get_environ(self): - """Return a new environ dict targeting the given wsgi.version""" - req = self.req - env = { - # set a non-standard environ entry so the WSGI app can know what - # the *real* server protocol is (and what features to support). - # See http://www.faqs.org/rfcs/rfc2145.html. - 'ACTUAL_SERVER_PROTOCOL': req.server.protocol, - 'PATH_INFO': req.path, - 'QUERY_STRING': req.qs, - 'REMOTE_ADDR': req.conn.remote_addr or '', - 'REMOTE_PORT': str(req.conn.remote_port or ''), - 'REQUEST_METHOD': req.method, - 'REQUEST_URI': req.uri, - 'SCRIPT_NAME': '', - 'SERVER_NAME': req.server.server_name, - # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol. - 'SERVER_PROTOCOL': req.request_protocol, - 'SERVER_SOFTWARE': req.server.software, - 'wsgi.errors': sys.stderr, - 'wsgi.input': req.rfile, - 'wsgi.multiprocess': False, - 'wsgi.multithread': True, - 'wsgi.run_once': False, - 'wsgi.url_scheme': req.scheme, - 'wsgi.version': (1, 0), - } - - if isinstance(req.server.bind_addr, basestring): - # AF_UNIX. This isn't really allowed by WSGI, which doesn't - # address unix domain sockets. But it's better than nothing. 
- env["SERVER_PORT"] = "" - else: - env["SERVER_PORT"] = str(req.server.bind_addr[1]) - - # Request headers - for k, v in req.inheaders.iteritems(): - env["HTTP_" + k.upper().replace("-", "_")] = v - - # CONTENT_TYPE/CONTENT_LENGTH - ct = env.pop("HTTP_CONTENT_TYPE", None) - if ct is not None: - env["CONTENT_TYPE"] = ct - cl = env.pop("HTTP_CONTENT_LENGTH", None) - if cl is not None: - env["CONTENT_LENGTH"] = cl - - if req.conn.ssl_env: - env.update(req.conn.ssl_env) - - return env - - -class WSGIGateway_u0(WSGIGateway_10): - - def get_environ(self): - """Return a new environ dict targeting the given wsgi.version""" - req = self.req - env_10 = WSGIGateway_10.get_environ(self) - env = dict([(k.decode('ISO-8859-1'), v) for k, v in env_10.iteritems()]) - env[u'wsgi.version'] = ('u', 0) - - # Request-URI - env.setdefault(u'wsgi.url_encoding', u'utf-8') - try: - for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]: - env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding']) - except UnicodeDecodeError: - # Fall back to latin 1 so apps can transcode if needed. - env[u'wsgi.url_encoding'] = u'ISO-8859-1' - for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]: - env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding']) - - for k, v in sorted(env.items()): - if isinstance(v, str) and k not in ('REQUEST_URI', 'wsgi.input'): - env[k] = v.decode('ISO-8859-1') - - return env - -wsgi_gateways = { - (1, 0): WSGIGateway_10, - ('u', 0): WSGIGateway_u0, -} - -class WSGIPathInfoDispatcher(object): - """A WSGI dispatcher for dispatch based on the PATH_INFO. - - apps: a dict or list of (path_prefix, app) pairs. - """ - - def __init__(self, apps): - try: - apps = apps.items() - except AttributeError: - pass - - # Sort the apps by len(path), descending - apps.sort(cmp=lambda x,y: cmp(len(x[0]), len(y[0]))) - apps.reverse() - - # The path_prefix strings must start, but not end, with a slash. - # Use "" instead of "/". 
- self.apps = [(p.rstrip("/"), a) for p, a in apps] - - def __call__(self, environ, start_response): - path = environ["PATH_INFO"] or "/" - for p, app in self.apps: - # The apps list should be sorted by length, descending. - if path.startswith(p + "/") or path == p: - environ = environ.copy() - environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p - environ["PATH_INFO"] = path[len(p):] - return app(environ, start_response) - - start_response('404 Not Found', [('Content-Type', 'text/plain'), - ('Content-Length', '0')]) - return [''] - + # Le sigh. Boo for backward-incompatible syntax. + exec('from .wsgiserver3 import *') diff --git a/cherrypy/wsgiserver/ssl_builtin.py b/cherrypy/wsgiserver/ssl_builtin.py index 64c0eeb0..03bf05de 100644 --- a/cherrypy/wsgiserver/ssl_builtin.py +++ b/cherrypy/wsgiserver/ssl_builtin.py @@ -11,6 +11,16 @@ try: except ImportError: ssl = None +try: + from _pyio import DEFAULT_BUFFER_SIZE +except ImportError: + try: + from io import DEFAULT_BUFFER_SIZE + except ImportError: + DEFAULT_BUFFER_SIZE = -1 + +import sys + from cherrypy import wsgiserver @@ -40,7 +50,8 @@ class BuiltinSSLAdapter(wsgiserver.SSLAdapter): s = ssl.wrap_socket(sock, do_handshake_on_connect=True, server_side=True, certfile=self.certificate, keyfile=self.private_key, ssl_version=ssl.PROTOCOL_SSLv23) - except ssl.SSLError, e: + except ssl.SSLError: + e = sys.exc_info()[1] if e.errno == ssl.SSL_ERROR_EOF: # This is almost certainly due to the cherrypy engine # 'pinging' the socket to assert it's connectable; @@ -50,6 +61,10 @@ class BuiltinSSLAdapter(wsgiserver.SSLAdapter): if e.args[1].endswith('http request'): # The client is speaking HTTP to an HTTPS server. raise wsgiserver.NoSSLError + elif e.args[1].endswith('unknown protocol'): + # The client is speaking some non-HTTP protocol. + # Drop the conn. 
+ return None, {} raise return s, self.get_environ(s) @@ -67,6 +82,10 @@ class BuiltinSSLAdapter(wsgiserver.SSLAdapter): } return ssl_environ - def makefile(self, sock, mode='r', bufsize=-1): - return wsgiserver.CP_fileobject(sock, mode, bufsize) + if sys.version_info >= (3, 0): + def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + return wsgiserver.CP_makefile(sock, mode, bufsize) + else: + def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + return wsgiserver.CP_fileobject(sock, mode, bufsize) diff --git a/cherrypy/wsgiserver/wsgiserver2.py b/cherrypy/wsgiserver/wsgiserver2.py new file mode 100644 index 00000000..b6bd4997 --- /dev/null +++ b/cherrypy/wsgiserver/wsgiserver2.py @@ -0,0 +1,2322 @@ +"""A high-speed, production ready, thread pooled, generic HTTP server. + +Simplest example on how to use this module directly +(without using CherryPy's application machinery):: + + from cherrypy import wsgiserver + + def my_crazy_app(environ, start_response): + status = '200 OK' + response_headers = [('Content-type','text/plain')] + start_response(status, response_headers) + return ['Hello world!'] + + server = wsgiserver.CherryPyWSGIServer( + ('0.0.0.0', 8070), my_crazy_app, + server_name='www.cherrypy.example') + server.start() + +The CherryPy WSGI server can serve as many WSGI applications +as you want in one instance by using a WSGIPathInfoDispatcher:: + + d = WSGIPathInfoDispatcher({'/': my_crazy_app, '/blog': my_blog_app}) + server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', 80), d) + +Want SSL support? Just set server.ssl_adapter to an SSLAdapter instance. + +This won't call the CherryPy engine (application side) at all, only the +HTTP server, which is independent from the rest of CherryPy. Don't +let the name "CherryPyWSGIServer" throw you; the name merely reflects +its origin, not its coupling. + +For those of you wanting to understand internals of this module, here's the +basic call flow. 
The server's listening thread runs a very tight loop, +sticking incoming connections onto a Queue:: + + server = CherryPyWSGIServer(...) + server.start() + while True: + tick() + # This blocks until a request comes in: + child = socket.accept() + conn = HTTPConnection(child, ...) + server.requests.put(conn) + +Worker threads are kept in a pool and poll the Queue, popping off and then +handling each connection in turn. Each connection can consist of an arbitrary +number of requests and their responses, so we run a nested loop:: + + while True: + conn = server.requests.get() + conn.communicate() + -> while True: + req = HTTPRequest(...) + req.parse_request() + -> # Read the Request-Line, e.g. "GET /page HTTP/1.1" + req.rfile.readline() + read_headers(req.rfile, req.inheaders) + req.respond() + -> response = app(...) + try: + for chunk in response: + if chunk: + req.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + if req.close_connection: + return +""" + +__all__ = ['HTTPRequest', 'HTTPConnection', 'HTTPServer', + 'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile', + 'CP_fileobject', + 'MaxSizeExceeded', 'NoSSLError', 'FatalSSLAlert', + 'WorkerThread', 'ThreadPool', 'SSLAdapter', + 'CherryPyWSGIServer', + 'Gateway', 'WSGIGateway', 'WSGIGateway_10', 'WSGIGateway_u0', + 'WSGIPathInfoDispatcher', 'get_ssl_adapter_class'] + +import os +try: + import queue +except: + import Queue as queue +import re +import rfc822 +import socket +import sys +if 'win' in sys.platform and not hasattr(socket, 'IPPROTO_IPV6'): + socket.IPPROTO_IPV6 = 41 +try: + import cStringIO as StringIO +except ImportError: + import StringIO +DEFAULT_BUFFER_SIZE = -1 + +_fileobject_uses_str_type = isinstance(socket._fileobject(None)._rbuf, basestring) + +import threading +import time +import traceback +def format_exc(limit=None): + """Like print_exc() but return a string. 
Backport for Python 2.3.""" + try: + etype, value, tb = sys.exc_info() + return ''.join(traceback.format_exception(etype, value, tb, limit)) + finally: + etype = value = tb = None + + +from urllib import unquote +from urlparse import urlparse +import warnings + +if sys.version_info >= (3, 0): + bytestr = bytes + unicodestr = str + basestring = (bytes, str) + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given encoding.""" + # In Python 3, the native string type is unicode + return n.encode(encoding) +else: + bytestr = str + unicodestr = unicode + basestring = basestring + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given encoding.""" + # In Python 2, the native string type is bytes. Assume it's already + # in the given encoding, which for ISO-8859-1 is almost always what + # was intended. + return n + +LF = ntob('\n') +CRLF = ntob('\r\n') +TAB = ntob('\t') +SPACE = ntob(' ') +COLON = ntob(':') +SEMICOLON = ntob(';') +EMPTY = ntob('') +NUMBER_SIGN = ntob('#') +QUESTION_MARK = ntob('?') +ASTERISK = ntob('*') +FORWARD_SLASH = ntob('/') +quoted_slash = re.compile(ntob("(?i)%2F")) + +import errno + +def plat_specific_errors(*errnames): + """Return error numbers for all errors in errnames on this platform. + + The 'errno' module contains different global constants depending on + the specific platform (OS). This function will return the list of + numeric values for a given list of potential names. 
+ """ + errno_names = dir(errno) + nums = [getattr(errno, k) for k in errnames if k in errno_names] + # de-dupe the list + return list(dict.fromkeys(nums).keys()) + +socket_error_eintr = plat_specific_errors("EINTR", "WSAEINTR") + +socket_errors_to_ignore = plat_specific_errors( + "EPIPE", + "EBADF", "WSAEBADF", + "ENOTSOCK", "WSAENOTSOCK", + "ETIMEDOUT", "WSAETIMEDOUT", + "ECONNREFUSED", "WSAECONNREFUSED", + "ECONNRESET", "WSAECONNRESET", + "ECONNABORTED", "WSAECONNABORTED", + "ENETRESET", "WSAENETRESET", + "EHOSTDOWN", "EHOSTUNREACH", + ) +socket_errors_to_ignore.append("timed out") +socket_errors_to_ignore.append("The read operation timed out") + +socket_errors_nonblocking = plat_specific_errors( + 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK') + +comma_separated_headers = [ntob(h) for h in + ['Accept', 'Accept-Charset', 'Accept-Encoding', + 'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control', + 'Connection', 'Content-Encoding', 'Content-Language', 'Expect', + 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE', + 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning', + 'WWW-Authenticate']] + + +import logging +if not hasattr(logging, 'statistics'): logging.statistics = {} + + +def read_headers(rfile, hdict=None): + """Read headers from the given stream into the given header dict. + + If hdict is None, a new header dict is created. Returns the populated + header dict. + + Headers which are repeated are folded together using a comma if their + specification so dictates. + + This function raises ValueError when the read bytes violate the HTTP spec. + You should probably return "400 Bad Request" if this happens. 
+ """ + if hdict is None: + hdict = {} + + while True: + line = rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError("Illegal end of headers.") + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError("HTTP requires CRLF terminators") + + if line[0] in (SPACE, TAB): + # It's a continuation line. + v = line.strip() + else: + try: + k, v = line.split(COLON, 1) + except ValueError: + raise ValueError("Illegal header line.") + # TODO: what about TE and WWW-Authenticate? + k = k.strip().title() + v = v.strip() + hname = k + + if k in comma_separated_headers: + existing = hdict.get(hname) + if existing: + v = ", ".join((existing, v)) + hdict[hname] = v + + return hdict + + +class MaxSizeExceeded(Exception): + pass + +class SizeCheckWrapper(object): + """Wraps a file-like object, raising MaxSizeExceeded if too large.""" + + def __init__(self, rfile, maxlen): + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + + def _check_length(self): + if self.maxlen and self.bytes_read > self.maxlen: + raise MaxSizeExceeded() + + def read(self, size=None): + data = self.rfile.read(size) + self.bytes_read += len(data) + self._check_length() + return data + + def readline(self, size=None): + if size is not None: + data = self.rfile.readline(size) + self.bytes_read += len(data) + self._check_length() + return data + + # User didn't specify a size ... + # We read the line in chunks to make sure it's not a 100MB line ! 
+ res = [] + while True: + data = self.rfile.readline(256) + self.bytes_read += len(data) + self._check_length() + res.append(data) + # See http://www.cherrypy.org/ticket/421 + if len(data) < 256 or data[-1:] == "\n": + return EMPTY.join(res) + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline() + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline() + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def __next__(self): + data = next(self.rfile) + self.bytes_read += len(data) + self._check_length() + return data + + def next(self): + data = self.rfile.next() + self.bytes_read += len(data) + self._check_length() + return data + + +class KnownLengthRFile(object): + """Wraps a file-like object, returning an empty string when exhausted.""" + + def __init__(self, rfile, content_length): + self.rfile = rfile + self.remaining = content_length + + def read(self, size=None): + if self.remaining == 0: + return '' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.read(size) + self.remaining -= len(data) + return data + + def readline(self, size=None): + if self.remaining == 0: + return '' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.readline(size) + self.remaining -= len(data) + return data + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def __next__(self): + data = next(self.rfile) + self.remaining -= len(data) + return data + + +class ChunkedRFile(object): + """Wraps a 
file-like object, returning an empty string when exhausted. + + This class is intended to provide a conforming wsgi.input value for + request entities that have been encoded with the 'chunked' transfer + encoding. + """ + + def __init__(self, rfile, maxlen, bufsize=8192): + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + self.buffer = EMPTY + self.bufsize = bufsize + self.closed = False + + def _fetch(self): + if self.closed: + return + + line = self.rfile.readline() + self.bytes_read += len(line) + + if self.maxlen and self.bytes_read > self.maxlen: + raise MaxSizeExceeded("Request Entity Too Large", self.maxlen) + + line = line.strip().split(SEMICOLON, 1) + + try: + chunk_size = line.pop(0) + chunk_size = int(chunk_size, 16) + except ValueError: + raise ValueError("Bad chunked transfer size: " + repr(chunk_size)) + + if chunk_size <= 0: + self.closed = True + return + +## if line: chunk_extension = line[0] + + if self.maxlen and self.bytes_read + chunk_size > self.maxlen: + raise IOError("Request Entity Too Large") + + chunk = self.rfile.read(chunk_size) + self.bytes_read += len(chunk) + self.buffer += chunk + + crlf = self.rfile.read(2) + if crlf != CRLF: + raise ValueError( + "Bad chunked transfer coding (expected '\\r\\n', " + "got " + repr(crlf) + ")") + + def read(self, size=None): + data = EMPTY + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + if size: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + data += self.buffer + + def readline(self, size=None): + data = EMPTY + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + newline_pos = self.buffer.find(LF) + if size: + if newline_pos == -1: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = 
self.buffer[remaining:] + else: + remaining = min(size - len(data), newline_pos) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + if newline_pos == -1: + data += self.buffer + else: + data += self.buffer[:newline_pos] + self.buffer = self.buffer[newline_pos:] + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def read_trailer_lines(self): + if not self.closed: + raise ValueError( + "Cannot read trailers until the request body has been read.") + + while True: + line = self.rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError("Illegal end of headers.") + + self.bytes_read += len(line) + if self.maxlen and self.bytes_read > self.maxlen: + raise IOError("Request Entity Too Large") + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError("HTTP requires CRLF terminators") + + yield line + + def close(self): + self.rfile.close() + + def __iter__(self): + # Shamelessly stolen from StringIO + total = 0 + line = self.readline(sizehint) + while line: + yield line + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + + +class HTTPRequest(object): + """An HTTP Request (and response). + + A single HTTP connection may consist of multiple request/response pairs. + """ + + server = None + """The HTTPServer object which is receiving this request.""" + + conn = None + """The HTTPConnection object on which this request connected.""" + + inheaders = {} + """A dict of request headers.""" + + outheaders = [] + """A list of header tuples to write in the response.""" + + ready = False + """When True, the request has been parsed and is ready to begin generating + the response. 
When False, signals the calling Connection that the response + should not be generated and the connection should close.""" + + close_connection = False + """Signals the calling Connection that the request should close. This does + not imply an error! The client and/or server may each request that the + connection be closed.""" + + chunked_write = False + """If True, output will be encoded with the "chunked" transfer-coding. + + This value is set automatically inside send_headers.""" + + def __init__(self, server, conn): + self.server= server + self.conn = conn + + self.ready = False + self.started_request = False + self.scheme = ntob("http") + if self.server.ssl_adapter is not None: + self.scheme = ntob("https") + # Use the lowest-common protocol in case read_request_line errors. + self.response_protocol = 'HTTP/1.0' + self.inheaders = {} + + self.status = "" + self.outheaders = [] + self.sent_headers = False + self.close_connection = self.__class__.close_connection + self.chunked_read = False + self.chunked_write = self.__class__.chunked_write + + def parse_request(self): + """Parse the next HTTP request start-line and message-headers.""" + self.rfile = SizeCheckWrapper(self.conn.rfile, + self.server.max_request_header_size) + try: + success = self.read_request_line() + except MaxSizeExceeded: + self.simple_response("414 Request-URI Too Long", + "The Request-URI sent with the request exceeds the maximum " + "allowed bytes.") + return + else: + if not success: + return + + try: + success = self.read_request_headers() + except MaxSizeExceeded: + self.simple_response("413 Request Entity Too Large", + "The headers sent with the request exceed the maximum " + "allowed bytes.") + return + else: + if not success: + return + + self.ready = True + + def read_request_line(self): + # HTTP/1.1 connections are persistent by default. If a client + # requests a page, then idles (leaves the connection open), + # then rfile.readline() will raise socket.error("timed out"). 
+ # Note that it does this based on the value given to settimeout(), + # and doesn't need the client to request or acknowledge the close + # (although your TCP stack might suffer for it: cf Apache's history + # with FIN_WAIT_2). + request_line = self.rfile.readline() + + # Set started_request to True so communicate() knows to send 408 + # from here on out. + self.started_request = True + if not request_line: + return False + + if request_line == CRLF: + # RFC 2616 sec 4.1: "...if the server is reading the protocol + # stream at the beginning of a message and receives a CRLF + # first, it should ignore the CRLF." + # But only ignore one leading line! else we enable a DoS. + request_line = self.rfile.readline() + if not request_line: + return False + + if not request_line.endswith(CRLF): + self.simple_response("400 Bad Request", "HTTP requires CRLF terminators") + return False + + try: + method, uri, req_protocol = request_line.strip().split(SPACE, 2) + rp = int(req_protocol[5]), int(req_protocol[7]) + except (ValueError, IndexError): + self.simple_response("400 Bad Request", "Malformed Request-Line") + return False + + self.uri = uri + self.method = method + + # uri may be an abs_path (including "http://host.domain.tld"); + scheme, authority, path = self.parse_request_uri(uri) + if NUMBER_SIGN in path: + self.simple_response("400 Bad Request", + "Illegal #fragment in Request-URI.") + return False + + if scheme: + self.scheme = scheme + + qs = EMPTY + if QUESTION_MARK in path: + path, qs = path.split(QUESTION_MARK, 1) + + # Unquote the path+params (e.g. "/this%20path" -> "/this path"). + # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2 + # + # But note that "...a URI must be separated into its components + # before the escaped characters within those components can be + # safely decoded." http://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2 + # Therefore, "/this%2Fpath" becomes "/this%2Fpath", not "/this/path". 
+ try: + atoms = [unquote(x) for x in quoted_slash.split(path)] + except ValueError: + ex = sys.exc_info()[1] + self.simple_response("400 Bad Request", ex.args[0]) + return False + path = "%2F".join(atoms) + self.path = path + + # Note that, like wsgiref and most other HTTP servers, + # we "% HEX HEX"-unquote the path but not the query string. + self.qs = qs + + # Compare request and server HTTP protocol versions, in case our + # server does not support the requested protocol. Limit our output + # to min(req, server). We want the following output: + # request server actual written supported response + # protocol protocol response protocol feature set + # a 1.0 1.0 1.0 1.0 + # b 1.0 1.1 1.1 1.0 + # c 1.1 1.0 1.0 1.0 + # d 1.1 1.1 1.1 1.1 + # Notice that, in (b), the response will be "HTTP/1.1" even though + # the client only understands 1.0. RFC 2616 10.5.6 says we should + # only return 505 if the _major_ version is different. + sp = int(self.server.protocol[5]), int(self.server.protocol[7]) + + if sp[0] != rp[0]: + self.simple_response("505 HTTP Version Not Supported") + return False + + self.request_protocol = req_protocol + self.response_protocol = "HTTP/%s.%s" % min(rp, sp) + + return True + + def read_request_headers(self): + """Read self.rfile into self.inheaders. 
Return success.""" + + # then all the http headers + try: + read_headers(self.rfile, self.inheaders) + except ValueError: + ex = sys.exc_info()[1] + self.simple_response("400 Bad Request", ex.args[0]) + return False + + mrbs = self.server.max_request_body_size + if mrbs and int(self.inheaders.get("Content-Length", 0)) > mrbs: + self.simple_response("413 Request Entity Too Large", + "The entity sent with the request exceeds the maximum " + "allowed bytes.") + return False + + # Persistent connection support + if self.response_protocol == "HTTP/1.1": + # Both server and client are HTTP/1.1 + if self.inheaders.get("Connection", "") == "close": + self.close_connection = True + else: + # Either the server or client (or both) are HTTP/1.0 + if self.inheaders.get("Connection", "") != "Keep-Alive": + self.close_connection = True + + # Transfer-Encoding support + te = None + if self.response_protocol == "HTTP/1.1": + te = self.inheaders.get("Transfer-Encoding") + if te: + te = [x.strip().lower() for x in te.split(",") if x.strip()] + + self.chunked_read = False + + if te: + for enc in te: + if enc == "chunked": + self.chunked_read = True + else: + # Note that, even if we see "chunked", we must reject + # if there is an extension we don't recognize. + self.simple_response("501 Unimplemented") + self.close_connection = True + return False + + # From PEP 333: + # "Servers and gateways that implement HTTP 1.1 must provide + # transparent support for HTTP 1.1's "expect/continue" mechanism. + # This may be done in any of several ways: + # 1. Respond to requests containing an Expect: 100-continue request + # with an immediate "100 Continue" response, and proceed normally. + # 2. Proceed with the request normally, but provide the application + # with a wsgi.input stream that will send the "100 Continue" + # response if/when the application first attempts to read from + # the input stream. The read request must then remain blocked + # until the client responds. + # 3. 
Wait until the client decides that the server does not support + # expect/continue, and sends the request body on its own. + # (This is suboptimal, and is not recommended.) + # + # We used to do 3, but are now doing 1. Maybe we'll do 2 someday, + # but it seems like it would be a big slowdown for such a rare case. + if self.inheaders.get("Expect", "") == "100-continue": + # Don't use simple_response here, because it emits headers + # we don't want. See http://www.cherrypy.org/ticket/951 + msg = self.server.protocol + " 100 Continue\r\n\r\n" + try: + self.conn.wfile.sendall(msg) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + return True + + def parse_request_uri(self, uri): + """Parse a Request-URI into (scheme, authority, path). + + Note that Request-URI's must be one of:: + + Request-URI = "*" | absoluteURI | abs_path | authority + + Therefore, a Request-URI which starts with a double forward-slash + cannot be a "net_path":: + + net_path = "//" authority [ abs_path ] + + Instead, it must be interpreted as an "abs_path" with an empty first + path segment:: + + abs_path = "/" path_segments + path_segments = segment *( "/" segment ) + segment = *pchar *( ";" param ) + param = *pchar + """ + if uri == ASTERISK: + return None, None, uri + + i = uri.find('://') + if i > 0 and QUESTION_MARK not in uri[:i]: + # An absoluteURI. + # If there's a scheme (and it must be http or https), then: + # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query ]] + scheme, remainder = uri[:i].lower(), uri[i + 3:] + authority, path = remainder.split(FORWARD_SLASH, 1) + path = FORWARD_SLASH + path + return scheme, authority, path + + if uri.startswith(FORWARD_SLASH): + # An abs_path. + return None, None, uri + else: + # An authority. 
+ return None, uri, None + + def respond(self): + """Call the gateway and write its iterable output.""" + mrbs = self.server.max_request_body_size + if self.chunked_read: + self.rfile = ChunkedRFile(self.conn.rfile, mrbs) + else: + cl = int(self.inheaders.get("Content-Length", 0)) + if mrbs and mrbs < cl: + if not self.sent_headers: + self.simple_response("413 Request Entity Too Large", + "The entity sent with the request exceeds the maximum " + "allowed bytes.") + return + self.rfile = KnownLengthRFile(self.conn.rfile, cl) + + self.server.gateway(self).respond() + + if (self.ready and not self.sent_headers): + self.sent_headers = True + self.send_headers() + if self.chunked_write: + self.conn.wfile.sendall("0\r\n\r\n") + + def simple_response(self, status, msg=""): + """Write a simple response back to the client.""" + status = str(status) + buf = [self.server.protocol + SPACE + + status + CRLF, + "Content-Length: %s\r\n" % len(msg), + "Content-Type: text/plain\r\n"] + + if status[:3] in ("413", "414"): + # Request Entity Too Large / Request-URI Too Long + self.close_connection = True + if self.response_protocol == 'HTTP/1.1': + # This will not be true for 414, since read_request_line + # usually raises 414 before reading the whole line, and we + # therefore cannot know the proper response_protocol. + buf.append("Connection: close\r\n") + else: + # HTTP/1.0 had no 413/414 status nor Connection header. + # Emit 400 instead and trust the message body is enough. 
+ status = "400 Bad Request" + + buf.append(CRLF) + if msg: + if isinstance(msg, unicodestr): + msg = msg.encode("ISO-8859-1") + buf.append(msg) + + try: + self.conn.wfile.sendall("".join(buf)) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + + def write(self, chunk): + """Write unbuffered data to the client.""" + if self.chunked_write and chunk: + buf = [hex(len(chunk))[2:], CRLF, chunk, CRLF] + self.conn.wfile.sendall(EMPTY.join(buf)) + else: + self.conn.wfile.sendall(chunk) + + def send_headers(self): + """Assert, process, and send the HTTP response message-headers. + + You must set self.status, and self.outheaders before calling this. + """ + hkeys = [key.lower() for key, value in self.outheaders] + status = int(self.status[:3]) + + if status == 413: + # Request Entity Too Large. Close conn to avoid garbage. + self.close_connection = True + elif "content-length" not in hkeys: + # "All 1xx (informational), 204 (no content), + # and 304 (not modified) responses MUST NOT + # include a message-body." So no point chunking. + if status < 200 or status in (204, 205, 304): + pass + else: + if (self.response_protocol == 'HTTP/1.1' + and self.method != 'HEAD'): + # Use the chunked transfer-coding + self.chunked_write = True + self.outheaders.append(("Transfer-Encoding", "chunked")) + else: + # Closing the conn is the only way to determine len. + self.close_connection = True + + if "connection" not in hkeys: + if self.response_protocol == 'HTTP/1.1': + # Both server and client are HTTP/1.1 or better + if self.close_connection: + self.outheaders.append(("Connection", "close")) + else: + # Server and/or client are HTTP/1.0 + if not self.close_connection: + self.outheaders.append(("Connection", "Keep-Alive")) + + if (not self.close_connection) and (not self.chunked_read): + # Read any remaining request body data on the socket. 
+ # "If an origin server receives a request that does not include an + # Expect request-header field with the "100-continue" expectation, + # the request includes a request body, and the server responds + # with a final status code before reading the entire request body + # from the transport connection, then the server SHOULD NOT close + # the transport connection until it has read the entire request, + # or until the client closes the connection. Otherwise, the client + # might not reliably receive the response message. However, this + # requirement is not be construed as preventing a server from + # defending itself against denial-of-service attacks, or from + # badly broken client implementations." + remaining = getattr(self.rfile, 'remaining', 0) + if remaining > 0: + self.rfile.read(remaining) + + if "date" not in hkeys: + self.outheaders.append(("Date", rfc822.formatdate())) + + if "server" not in hkeys: + self.outheaders.append(("Server", self.server.server_name)) + + buf = [self.server.protocol + SPACE + self.status + CRLF] + for k, v in self.outheaders: + buf.append(k + COLON + SPACE + v + CRLF) + buf.append(CRLF) + self.conn.wfile.sendall(EMPTY.join(buf)) + + +class NoSSLError(Exception): + """Exception raised when a client speaks HTTP to an HTTPS socket.""" + pass + + +class FatalSSLAlert(Exception): + """Exception raised when the SSL implementation signals a fatal alert.""" + pass + + +class CP_fileobject(socket._fileobject): + """Faux file object attached to a socket object.""" + + def __init__(self, *args, **kwargs): + self.bytes_read = 0 + self.bytes_written = 0 + socket._fileobject.__init__(self, *args, **kwargs) + + def sendall(self, data): + """Sendall for non-blocking sockets.""" + while data: + try: + bytes_sent = self.send(data) + data = data[bytes_sent:] + except socket.error, e: + if e.args[0] not in socket_errors_nonblocking: + raise + + def send(self, data): + bytes_sent = self._sock.send(data) + self.bytes_written += bytes_sent + return 
bytes_sent + + def flush(self): + if self._wbuf: + buffer = "".join(self._wbuf) + self._wbuf = [] + self.sendall(buffer) + + def recv(self, size): + while True: + try: + data = self._sock.recv(size) + self.bytes_read += len(data) + return data + except socket.error, e: + if (e.args[0] not in socket_errors_nonblocking + and e.args[0] not in socket_error_eintr): + raise + + if not _fileobject_uses_str_type: + def read(self, size=-1): + # Use max, disallow tiny reads in a loop as they are very inefficient. + # We never leave read() with any leftover data from a new recv() call + # in our internal buffer. + rbufsize = max(self._rbufsize, self.default_bufsize) + # Our use of StringIO rather than lists of string objects returned by + # recv() minimizes memory usage and fragmentation that occurs when + # rbufsize is large compared to the typical return value of recv(). + buf = self._rbuf + buf.seek(0, 2) # seek end + if size < 0: + # Read until EOF + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + while True: + data = self.recv(rbufsize) + if not data: + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or EOF seen, whichever comes first + buf_len = buf.tell() + if buf_len >= size: + # Already have size bytes in our buffer? Extract and return. + buf.seek(0) + rv = buf.read(size) + self._rbuf = StringIO.StringIO() + self._rbuf.write(buf.read()) + return rv + + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + while True: + left = size - buf_len + # recv() will malloc the amount of memory given as its + # parameter even though it often returns much less data + # than that. The returned data string is short lived + # as we copy it into a StringIO and free it. This avoids + # fragmentation issues on many platforms. + data = self.recv(left) + if not data: + break + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid buffer data copies when: + # - We have no data in our buffer. 
+ # AND + # - Our call to recv returned exactly the + # number of bytes we were asked to read. + return data + if n == left: + buf.write(data) + del data # explicit free + break + assert n <= left, "recv(%d) returned %d bytes" % (left, n) + buf.write(data) + buf_len += n + del data # explicit free + #assert buf_len == buf.tell() + return buf.getvalue() + + def readline(self, size=-1): + buf = self._rbuf + buf.seek(0, 2) # seek end + if buf.tell() > 0: + # check if we already have it in our buffer + buf.seek(0) + bline = buf.readline(size) + if bline.endswith('\n') or len(bline) == size: + self._rbuf = StringIO.StringIO() + self._rbuf.write(buf.read()) + return bline + del bline + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + buf.seek(0) + buffers = [buf.read()] + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + data = None + recv = self.recv + while data != "\n": + data = recv(1) + if not data: + break + buffers.append(data) + return "".join(buffers) + + buf.seek(0, 2) # seek end + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + while True: + data = self.recv(self._rbufsize) + if not data: + break + nl = data.find('\n') + if nl >= 0: + nl += 1 + buf.write(data[:nl]) + self._rbuf.write(data[nl:]) + del data + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or \n or EOF seen, whichever comes first + buf.seek(0, 2) # seek end + buf_len = buf.tell() + if buf_len >= size: + buf.seek(0) + rv = buf.read(size) + self._rbuf = StringIO.StringIO() + self._rbuf.write(buf.read()) + return rv + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + while True: + data = self.recv(self._rbufsize) + if not data: + break + left = size - buf_len + # did we just receive a newline? 
+ nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + # save the excess data to _rbuf + self._rbuf.write(data[nl:]) + if buf_len: + buf.write(data[:nl]) + break + else: + # Shortcut. Avoid data copy through buf when returning + # a substring of our first recv(). + return data[:nl] + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid data copy through buf when + # returning exactly all of our first recv(). + return data + if n >= left: + buf.write(data[:left]) + self._rbuf.write(data[left:]) + break + buf.write(data) + buf_len += n + #assert buf_len == buf.tell() + return buf.getvalue() + else: + def read(self, size=-1): + if size < 0: + # Read until EOF + buffers = [self._rbuf] + self._rbuf = "" + if self._rbufsize <= 1: + recv_size = self.default_bufsize + else: + recv_size = self._rbufsize + + while True: + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + return "".join(buffers) + else: + # Read until size bytes or EOF seen, whichever comes first + data = self._rbuf + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + left = size - buf_len + recv_size = max(self._rbufsize, left) + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + def readline(self, size=-1): + data = self._rbuf + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + assert data == "" + buffers = [] + while data != "\n": + data = self.recv(1) + if not data: + break + buffers.append(data) + return "".join(buffers) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = 
self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + return "".join(buffers) + else: + # Read until size bytes or \n or EOF seen, whichever comes first + nl = data.find('\n', 0, size) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + left = size - buf_len + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + +class HTTPConnection(object): + """An HTTP connection (active socket). + + server: the Server object which received this connection. + socket: the raw socket object (usually TCP) for this connection. + makefile: a fileobject class for reading from the socket. + """ + + remote_addr = None + remote_port = None + ssl_env = None + rbufsize = DEFAULT_BUFFER_SIZE + wbufsize = DEFAULT_BUFFER_SIZE + RequestHandlerClass = HTTPRequest + + def __init__(self, server, sock, makefile=CP_fileobject): + self.server = server + self.socket = sock + self.rfile = makefile(sock, "rb", self.rbufsize) + self.wfile = makefile(sock, "wb", self.wbufsize) + self.requests_seen = 0 + + def communicate(self): + """Read each request and respond appropriately.""" + request_seen = False + try: + while True: + # (re)set req to None so that if something goes wrong in + # the RequestHandlerClass constructor, the error doesn't + # get written to the previous request. + req = None + req = self.RequestHandlerClass(self.server, self) + + # This order of operations should guarantee correct pipelining. 
+ req.parse_request() + if self.server.stats['Enabled']: + self.requests_seen += 1 + if not req.ready: + # Something went wrong in the parsing (and the server has + # probably already made a simple_response). Return and + # let the conn close. + return + + request_seen = True + req.respond() + if req.close_connection: + return + except socket.error: + e = sys.exc_info()[1] + errnum = e.args[0] + # sadly SSL sockets return a different (longer) time out string + if errnum == 'timed out' or errnum == 'The read operation timed out': + # Don't error if we're between requests; only error + # if 1) no request has been started at all, or 2) we're + # in the middle of a request. + # See http://www.cherrypy.org/ticket/853 + if (not request_seen) or (req and req.started_request): + # Don't bother writing the 408 if the response + # has already started being written. + if req and not req.sent_headers: + try: + req.simple_response("408 Request Timeout") + except FatalSSLAlert: + # Close the connection. + return + elif errnum not in socket_errors_to_ignore: + self.server.error_log("socket.error %s" % repr(errnum), + level=logging.WARNING, traceback=True) + if req and not req.sent_headers: + try: + req.simple_response("500 Internal Server Error") + except FatalSSLAlert: + # Close the connection. + return + return + except (KeyboardInterrupt, SystemExit): + raise + except FatalSSLAlert: + # Close the connection. + return + except NoSSLError: + if req and not req.sent_headers: + # Unwrap our wfile + self.wfile = CP_fileobject(self.socket._sock, "wb", self.wbufsize) + req.simple_response("400 Bad Request", + "The client sent a plain HTTP request, but " + "this server only speaks HTTPS on this port.") + self.linger = True + except Exception: + e = sys.exc_info()[1] + self.server.error_log(repr(e), level=logging.ERROR, traceback=True) + if req and not req.sent_headers: + try: + req.simple_response("500 Internal Server Error") + except FatalSSLAlert: + # Close the connection. 
+ return + + linger = False + + def close(self): + """Close the socket underlying this connection.""" + self.rfile.close() + + if not self.linger: + # Python's socket module does NOT call close on the kernel socket + # when you call socket.close(). We do so manually here because we + # want this server to send a FIN TCP segment immediately. Note this + # must be called *before* calling socket.close(), because the latter + # drops its reference to the kernel socket. + if hasattr(self.socket, '_sock'): + self.socket._sock.close() + self.socket.close() + else: + # On the other hand, sometimes we want to hang around for a bit + # to make sure the client has a chance to read our entire + # response. Skipping the close() calls here delays the FIN + # packet until the socket object is garbage-collected later. + # Someday, perhaps, we'll do the full lingering_close that + # Apache does, but not today. + pass + + +class TrueyZero(object): + """An object which equals and does math like the integer '0' but evals True.""" + def __add__(self, other): + return other + def __radd__(self, other): + return other +trueyzero = TrueyZero() + + +_SHUTDOWNREQUEST = None + +class WorkerThread(threading.Thread): + """Thread which continuously polls a Queue for Connection objects. + + Due to the timing issues of polling a Queue, a WorkerThread does not + check its own 'ready' flag after it has started. To stop the thread, + it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue + (one for each running WorkerThread). 
+ """ + + conn = None + """The current connection pulled off the Queue, or None.""" + + server = None + """The HTTP Server which spawned this thread, and which owns the + Queue and is placing active connections into it.""" + + ready = False + """A simple flag for the calling server to know when this thread + has begun polling the Queue.""" + + + def __init__(self, server): + self.ready = False + self.server = server + + self.requests_seen = 0 + self.bytes_read = 0 + self.bytes_written = 0 + self.start_time = None + self.work_time = 0 + self.stats = { + 'Requests': lambda s: self.requests_seen + ((self.start_time is None) and trueyzero or self.conn.requests_seen), + 'Bytes Read': lambda s: self.bytes_read + ((self.start_time is None) and trueyzero or self.conn.rfile.bytes_read), + 'Bytes Written': lambda s: self.bytes_written + ((self.start_time is None) and trueyzero or self.conn.wfile.bytes_written), + 'Work Time': lambda s: self.work_time + ((self.start_time is None) and trueyzero or time.time() - self.start_time), + 'Read Throughput': lambda s: s['Bytes Read'](s) / (s['Work Time'](s) or 1e-6), + 'Write Throughput': lambda s: s['Bytes Written'](s) / (s['Work Time'](s) or 1e-6), + } + threading.Thread.__init__(self) + + def run(self): + self.server.stats['Worker Threads'][self.getName()] = self.stats + try: + self.ready = True + while True: + conn = self.server.requests.get() + if conn is _SHUTDOWNREQUEST: + return + + self.conn = conn + if self.server.stats['Enabled']: + self.start_time = time.time() + try: + conn.communicate() + finally: + conn.close() + if self.server.stats['Enabled']: + self.requests_seen += self.conn.requests_seen + self.bytes_read += self.conn.rfile.bytes_read + self.bytes_written += self.conn.wfile.bytes_written + self.work_time += time.time() - self.start_time + self.start_time = None + self.conn = None + except (KeyboardInterrupt, SystemExit): + exc = sys.exc_info()[1] + self.server.interrupt = exc + + +class ThreadPool(object): + """A 
Request Queue for an HTTPServer which pools threads. + + ThreadPool objects must provide min, get(), put(obj), start() + and stop(timeout) attributes. + """ + + def __init__(self, server, min=10, max=-1): + self.server = server + self.min = min + self.max = max + self._threads = [] + self._queue = queue.Queue() + self.get = self._queue.get + + def start(self): + """Start the pool of threads.""" + for i in range(self.min): + self._threads.append(WorkerThread(self.server)) + for worker in self._threads: + worker.setName("CP Server " + worker.getName()) + worker.start() + for worker in self._threads: + while not worker.ready: + time.sleep(.1) + + def _get_idle(self): + """Number of worker threads which are idle. Read-only.""" + return len([t for t in self._threads if t.conn is None]) + idle = property(_get_idle, doc=_get_idle.__doc__) + + def put(self, obj): + self._queue.put(obj) + if obj is _SHUTDOWNREQUEST: + return + + def grow(self, amount): + """Spawn new worker threads (not above self.max).""" + for i in range(amount): + if self.max > 0 and len(self._threads) >= self.max: + break + worker = WorkerThread(self.server) + worker.setName("CP Server " + worker.getName()) + self._threads.append(worker) + worker.start() + + def shrink(self, amount): + """Kill off worker threads (not below self.min).""" + # Grow/shrink the pool if necessary. + # Remove any dead threads from our list + for t in self._threads: + if not t.isAlive(): + self._threads.remove(t) + amount -= 1 + + if amount > 0: + for i in range(min(amount, len(self._threads) - self.min)): + # Put a number of shutdown requests on the queue equal + # to 'amount'. Once each of those is processed by a worker, + # that worker will terminate and be culled from our list + # in self.put. + self._queue.put(_SHUTDOWNREQUEST) + + def stop(self, timeout=5): + # Must shut down threads here so the code that calls + # this method can know when all threads are stopped. 
+ for worker in self._threads: + self._queue.put(_SHUTDOWNREQUEST) + + # Don't join currentThread (when stop is called inside a request). + current = threading.currentThread() + if timeout and timeout >= 0: + endtime = time.time() + timeout + while self._threads: + worker = self._threads.pop() + if worker is not current and worker.isAlive(): + try: + if timeout is None or timeout < 0: + worker.join() + else: + remaining_time = endtime - time.time() + if remaining_time > 0: + worker.join(remaining_time) + if worker.isAlive(): + # We exhausted the timeout. + # Forcibly shut down the socket. + c = worker.conn + if c and not c.rfile.closed: + try: + c.socket.shutdown(socket.SHUT_RD) + except TypeError: + # pyOpenSSL sockets don't take an arg + c.socket.shutdown() + worker.join() + except (AssertionError, + # Ignore repeated Ctrl-C. + # See http://www.cherrypy.org/ticket/691. + KeyboardInterrupt): + pass + + def _get_qsize(self): + return self._queue.qsize() + qsize = property(_get_qsize) + + + +try: + import fcntl +except ImportError: + try: + from ctypes import windll, WinError + except ImportError: + def prevent_socket_inheritance(sock): + """Dummy function, since neither fcntl nor ctypes are available.""" + pass + else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (Windows).""" + if not windll.kernel32.SetHandleInformation(sock.fileno(), 1, 0): + raise WinError() +else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (POSIX).""" + fd = sock.fileno() + old_flags = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC) + + +class SSLAdapter(object): + """Base class for SSL driver library adapters. 
+ + Required methods: + + * ``wrap(sock) -> (wrapped socket, ssl environ dict)`` + * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> socket file object`` + """ + + def __init__(self, certificate, private_key, certificate_chain=None): + self.certificate = certificate + self.private_key = private_key + self.certificate_chain = certificate_chain + + def wrap(self, sock): + raise NotImplemented + + def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + raise NotImplemented + + +class HTTPServer(object): + """An HTTP server.""" + + _bind_addr = "127.0.0.1" + _interrupt = None + + gateway = None + """A Gateway instance.""" + + minthreads = None + """The minimum number of worker threads to create (default 10).""" + + maxthreads = None + """The maximum number of worker threads to create (default -1 = no limit).""" + + server_name = None + """The name of the server; defaults to socket.gethostname().""" + + protocol = "HTTP/1.1" + """The version string to write in the Status-Line of all HTTP responses. + + For example, "HTTP/1.1" is the default. This also limits the supported + features used in the response.""" + + request_queue_size = 5 + """The 'backlog' arg to socket.listen(); max queued connections (default 5).""" + + shutdown_timeout = 5 + """The total time, in seconds, to wait for worker threads to cleanly exit.""" + + timeout = 10 + """The timeout in seconds for accepted connections (default 10).""" + + version = "CherryPy/3.2.2" + """A version string for the HTTPServer.""" + + software = None + """The value to set for the SERVER_SOFTWARE entry in the WSGI environ. 
+ + If None, this defaults to ``'%s Server' % self.version``.""" + + ready = False + """An internal flag which marks whether the socket is accepting connections.""" + + max_request_header_size = 0 + """The maximum size, in bytes, for request headers, or 0 for no limit.""" + + max_request_body_size = 0 + """The maximum size, in bytes, for request bodies, or 0 for no limit.""" + + nodelay = True + """If True (the default since 3.1), sets the TCP_NODELAY socket option.""" + + ConnectionClass = HTTPConnection + """The class to use for handling HTTP connections.""" + + ssl_adapter = None + """An instance of SSLAdapter (or a subclass). + + You must have the corresponding SSL driver library installed.""" + + def __init__(self, bind_addr, gateway, minthreads=10, maxthreads=-1, + server_name=None): + self.bind_addr = bind_addr + self.gateway = gateway + + self.requests = ThreadPool(self, min=minthreads or 1, max=maxthreads) + + if not server_name: + server_name = socket.gethostname() + self.server_name = server_name + self.clear_stats() + + def clear_stats(self): + self._start_time = None + self._run_time = 0 + self.stats = { + 'Enabled': False, + 'Bind Address': lambda s: repr(self.bind_addr), + 'Run time': lambda s: (not s['Enabled']) and -1 or self.runtime(), + 'Accepts': 0, + 'Accepts/sec': lambda s: s['Accepts'] / self.runtime(), + 'Queue': lambda s: getattr(self.requests, "qsize", None), + 'Threads': lambda s: len(getattr(self.requests, "_threads", [])), + 'Threads Idle': lambda s: getattr(self.requests, "idle", None), + 'Socket Errors': 0, + 'Requests': lambda s: (not s['Enabled']) and -1 or sum([w['Requests'](w) for w + in s['Worker Threads'].values()], 0), + 'Bytes Read': lambda s: (not s['Enabled']) and -1 or sum([w['Bytes Read'](w) for w + in s['Worker Threads'].values()], 0), + 'Bytes Written': lambda s: (not s['Enabled']) and -1 or sum([w['Bytes Written'](w) for w + in s['Worker Threads'].values()], 0), + 'Work Time': lambda s: (not s['Enabled']) and -1 or 
sum([w['Work Time'](w) for w + in s['Worker Threads'].values()], 0), + 'Read Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0), + 'Write Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0), + 'Worker Threads': {}, + } + logging.statistics["CherryPy HTTPServer %d" % id(self)] = self.stats + + def runtime(self): + if self._start_time is None: + return self._run_time + else: + return self._run_time + (time.time() - self._start_time) + + def __str__(self): + return "%s.%s(%r)" % (self.__module__, self.__class__.__name__, + self.bind_addr) + + def _get_bind_addr(self): + return self._bind_addr + def _set_bind_addr(self, value): + if isinstance(value, tuple) and value[0] in ('', None): + # Despite the socket module docs, using '' does not + # allow AI_PASSIVE to work. Passing None instead + # returns '0.0.0.0' like we want. In other words: + # host AI_PASSIVE result + # '' Y 192.168.x.y + # '' N 192.168.x.y + # None Y 0.0.0.0 + # None N 127.0.0.1 + # But since you can get the same effect with an explicit + # '0.0.0.0', we deny both the empty string and None as values. + raise ValueError("Host values of '' or None are not allowed. " + "Use '0.0.0.0' (IPv4) or '::' (IPv6) instead " + "to listen on all active interfaces.") + self._bind_addr = value + bind_addr = property(_get_bind_addr, _set_bind_addr, + doc="""The interface on which to listen for connections. + + For TCP sockets, a (host, port) tuple. Host values may be any IPv4 + or IPv6 address, or any valid hostname. The string 'localhost' is a + synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6). + The string '0.0.0.0' is a special IPv4 entry meaning "any active + interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for + IPv6. The empty string or None are not allowed. 
+ + For UNIX sockets, supply the filename as a string.""") + + def start(self): + """Run the server forever.""" + # We don't have to trap KeyboardInterrupt or SystemExit here, + # because cherrpy.server already does so, calling self.stop() for us. + # If you're using this server with another framework, you should + # trap those exceptions in whatever code block calls start(). + self._interrupt = None + + if self.software is None: + self.software = "%s Server" % self.version + + # SSL backward compatibility + if (self.ssl_adapter is None and + getattr(self, 'ssl_certificate', None) and + getattr(self, 'ssl_private_key', None)): + warnings.warn( + "SSL attributes are deprecated in CherryPy 3.2, and will " + "be removed in CherryPy 3.3. Use an ssl_adapter attribute " + "instead.", + DeprecationWarning + ) + try: + from cherrypy.wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter + except ImportError: + pass + else: + self.ssl_adapter = pyOpenSSLAdapter( + self.ssl_certificate, self.ssl_private_key, + getattr(self, 'ssl_certificate_chain', None)) + + # Select the appropriate socket + if isinstance(self.bind_addr, basestring): + # AF_UNIX socket + + # So we can reuse the socket... + try: os.unlink(self.bind_addr) + except: pass + + # So everyone can access the socket... 
+ try: os.chmod(self.bind_addr, 511) # 0777 + except: pass + + info = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_addr)] + else: + # AF_INET or AF_INET6 socket + # Get the correct address family for our host (allows IPv6 addresses) + host, port = self.bind_addr + try: + info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM, 0, socket.AI_PASSIVE) + except socket.gaierror: + if ':' in self.bind_addr[0]: + info = [(socket.AF_INET6, socket.SOCK_STREAM, + 0, "", self.bind_addr + (0, 0))] + else: + info = [(socket.AF_INET, socket.SOCK_STREAM, + 0, "", self.bind_addr)] + + self.socket = None + msg = "No socket could be created" + for res in info: + af, socktype, proto, canonname, sa = res + try: + self.bind(af, socktype, proto) + except socket.error: + if self.socket: + self.socket.close() + self.socket = None + continue + break + if not self.socket: + raise socket.error(msg) + + # Timeout so KeyboardInterrupt can be caught on Win32 + self.socket.settimeout(1) + self.socket.listen(self.request_queue_size) + + # Create worker threads + self.requests.start() + + self.ready = True + self._start_time = time.time() + while self.ready: + try: + self.tick() + except (KeyboardInterrupt, SystemExit): + raise + except: + self.error_log("Error in HTTPServer.tick", level=logging.ERROR, + traceback=True) + + if self.interrupt: + while self.interrupt is True: + # Wait for self.stop() to complete. See _set_interrupt. 
+ time.sleep(0.1) + if self.interrupt: + raise self.interrupt + + def error_log(self, msg="", level=20, traceback=False): + # Override this in subclasses as desired + sys.stderr.write(msg + '\n') + sys.stderr.flush() + if traceback: + tblines = format_exc() + sys.stderr.write(tblines) + sys.stderr.flush() + + def bind(self, family, type, proto=0): + """Create (or recreate) the actual socket object.""" + self.socket = socket.socket(family, type, proto) + prevent_socket_inheritance(self.socket) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if self.nodelay and not isinstance(self.bind_addr, str): + self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + if self.ssl_adapter is not None: + self.socket = self.ssl_adapter.bind(self.socket) + + # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), + # activate dual-stack. See http://www.cherrypy.org/ticket/871. + if (hasattr(socket, 'AF_INET6') and family == socket.AF_INET6 + and self.bind_addr[0] in ('::', '::0', '::0.0.0.0')): + try: + self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except (AttributeError, socket.error): + # Apparently, the socket option is not available in + # this machine's TCP stack + pass + + self.socket.bind(self.bind_addr) + + def tick(self): + """Accept a new connection and put it on the Queue.""" + try: + s, addr = self.socket.accept() + if self.stats['Enabled']: + self.stats['Accepts'] += 1 + if not self.ready: + return + + prevent_socket_inheritance(s) + if hasattr(s, 'settimeout'): + s.settimeout(self.timeout) + + makefile = CP_fileobject + ssl_env = {} + # if ssl cert and key are set, we try to be a secure HTTP server + if self.ssl_adapter is not None: + try: + s, ssl_env = self.ssl_adapter.wrap(s) + except NoSSLError: + msg = ("The client sent a plain HTTP request, but " + "this server only speaks HTTPS on this port.") + buf = ["%s 400 Bad Request\r\n" % self.protocol, + "Content-Length: %s\r\n" % len(msg), + "Content-Type: 
text/plain\r\n\r\n", + msg] + + wfile = makefile(s, "wb", DEFAULT_BUFFER_SIZE) + try: + wfile.sendall("".join(buf)) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + return + if not s: + return + makefile = self.ssl_adapter.makefile + # Re-apply our timeout since we may have a new socket object + if hasattr(s, 'settimeout'): + s.settimeout(self.timeout) + + conn = self.ConnectionClass(self, s, makefile) + + if not isinstance(self.bind_addr, basestring): + # optional values + # Until we do DNS lookups, omit REMOTE_HOST + if addr is None: # sometimes this can happen + # figure out if AF_INET or AF_INET6. + if len(s.getsockname()) == 2: + # AF_INET + addr = ('0.0.0.0', 0) + else: + # AF_INET6 + addr = ('::', 0) + conn.remote_addr = addr[0] + conn.remote_port = addr[1] + + conn.ssl_env = ssl_env + + self.requests.put(conn) + except socket.timeout: + # The only reason for the timeout in start() is so we can + # notice keyboard interrupts on Win32, which don't interrupt + # accept() by default + return + except socket.error: + x = sys.exc_info()[1] + if self.stats['Enabled']: + self.stats['Socket Errors'] += 1 + if x.args[0] in socket_error_eintr: + # I *think* this is right. EINTR should occur when a signal + # is received during the accept() call; all docs say retry + # the call, and I *think* I'm reading it right that Python + # will then go ahead and poll for and handle the signal + # elsewhere. See http://www.cherrypy.org/ticket/707. + return + if x.args[0] in socket_errors_nonblocking: + # Just try again. See http://www.cherrypy.org/ticket/479. + return + if x.args[0] in socket_errors_to_ignore: + # Our socket was closed. + # See http://www.cherrypy.org/ticket/686. 
+ return + raise + + def _get_interrupt(self): + return self._interrupt + def _set_interrupt(self, interrupt): + self._interrupt = True + self.stop() + self._interrupt = interrupt + interrupt = property(_get_interrupt, _set_interrupt, + doc="Set this to an Exception instance to " + "interrupt the server.") + + def stop(self): + """Gracefully shutdown a server that is serving forever.""" + self.ready = False + if self._start_time is not None: + self._run_time += (time.time() - self._start_time) + self._start_time = None + + sock = getattr(self, "socket", None) + if sock: + if not isinstance(self.bind_addr, basestring): + # Touch our own socket to make accept() return immediately. + try: + host, port = sock.getsockname()[:2] + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + # Changed to use error code and not message + # See http://www.cherrypy.org/ticket/860. + raise + else: + # Note that we're explicitly NOT using AI_PASSIVE, + # here, because we want an actual IP to touch. + # localhost won't work if we've bound to a public IP, + # but it will if we bound to '0.0.0.0' (INADDR_ANY). + for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + s = None + try: + s = socket.socket(af, socktype, proto) + # See http://groups.google.com/group/cherrypy-users/ + # browse_frm/thread/bbfe5eb39c904fe0 + s.settimeout(1.0) + s.connect((host, port)) + s.close() + except socket.error: + if s: + s.close() + if hasattr(sock, "close"): + sock.close() + self.socket = None + + self.requests.stop(self.shutdown_timeout) + + +class Gateway(object): + """A base class to interface HTTPServer with other systems, such as WSGI.""" + + def __init__(self, req): + self.req = req + + def respond(self): + """Process the current request. 
Must be overridden in a subclass.""" + raise NotImplemented + + +# These may either be wsgiserver.SSLAdapter subclasses or the string names +# of such classes (in which case they will be lazily loaded). +ssl_adapters = { + 'builtin': 'cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter', + 'pyopenssl': 'cherrypy.wsgiserver.ssl_pyopenssl.pyOpenSSLAdapter', + } + +def get_ssl_adapter_class(name='pyopenssl'): + """Return an SSL adapter class for the given name.""" + adapter = ssl_adapters[name.lower()] + if isinstance(adapter, basestring): + last_dot = adapter.rfind(".") + attr_name = adapter[last_dot + 1:] + mod_path = adapter[:last_dot] + + try: + mod = sys.modules[mod_path] + if mod is None: + raise KeyError() + except KeyError: + # The last [''] is important. + mod = __import__(mod_path, globals(), locals(), ['']) + + # Let an AttributeError propagate outward. + try: + adapter = getattr(mod, attr_name) + except AttributeError: + raise AttributeError("'%s' object has no attribute '%s'" + % (mod_path, attr_name)) + + return adapter + +# -------------------------------- WSGI Stuff -------------------------------- # + + +class CherryPyWSGIServer(HTTPServer): + """A subclass of HTTPServer which calls a WSGI application.""" + + wsgi_version = (1, 0) + """The version of WSGI to produce.""" + + def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None, + max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5): + self.requests = ThreadPool(self, min=numthreads or 1, max=max) + self.wsgi_app = wsgi_app + self.gateway = wsgi_gateways[self.wsgi_version] + + self.bind_addr = bind_addr + if not server_name: + server_name = socket.gethostname() + self.server_name = server_name + self.request_queue_size = request_queue_size + + self.timeout = timeout + self.shutdown_timeout = shutdown_timeout + self.clear_stats() + + def _get_numthreads(self): + return self.requests.min + def _set_numthreads(self, value): + self.requests.min = value + numthreads = 
property(_get_numthreads, _set_numthreads) + + +class WSGIGateway(Gateway): + """A base class to interface HTTPServer with WSGI.""" + + def __init__(self, req): + self.req = req + self.started_response = False + self.env = self.get_environ() + self.remaining_bytes_out = None + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + raise NotImplemented + + def respond(self): + """Process the current request.""" + response = self.req.server.wsgi_app(self.env, self.start_response) + try: + for chunk in response: + # "The start_response callable must not actually transmit + # the response headers. Instead, it must store them for the + # server or gateway to transmit only after the first + # iteration of the application return value that yields + # a NON-EMPTY string, or upon the application's first + # invocation of the write() callable." (PEP 333) + if chunk: + if isinstance(chunk, unicodestr): + chunk = chunk.encode('ISO-8859-1') + self.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + + def start_response(self, status, headers, exc_info = None): + """WSGI callable to begin the HTTP response.""" + # "The application may call start_response more than once, + # if and only if the exc_info argument is provided." + if self.started_response and not exc_info: + raise AssertionError("WSGI start_response called a second " + "time with no exc_info.") + self.started_response = True + + # "if exc_info is provided, and the HTTP headers have already been + # sent, start_response must raise an error, and should raise the + # exc_info tuple." + if self.req.sent_headers: + try: + raise exc_info[0], exc_info[1], exc_info[2] + finally: + exc_info = None + + self.req.status = status + for k, v in headers: + if not isinstance(k, str): + raise TypeError("WSGI response header key %r is not of type str." % k) + if not isinstance(v, str): + raise TypeError("WSGI response header value %r is not of type str." 
% v) + if k.lower() == 'content-length': + self.remaining_bytes_out = int(v) + self.req.outheaders.extend(headers) + + return self.write + + def write(self, chunk): + """WSGI callable to write unbuffered data to the client. + + This method is also used internally by start_response (to write + data from the iterable returned by the WSGI application). + """ + if not self.started_response: + raise AssertionError("WSGI write called before start_response.") + + chunklen = len(chunk) + rbo = self.remaining_bytes_out + if rbo is not None and chunklen > rbo: + if not self.req.sent_headers: + # Whew. We can send a 500 to the client. + self.req.simple_response("500 Internal Server Error", + "The requested resource returned more bytes than the " + "declared Content-Length.") + else: + # Dang. We have probably already sent data. Truncate the chunk + # to fit (so the client doesn't hang) and raise an error later. + chunk = chunk[:rbo] + + if not self.req.sent_headers: + self.req.sent_headers = True + self.req.send_headers() + + self.req.write(chunk) + + if rbo is not None: + rbo -= chunklen + if rbo < 0: + raise ValueError( + "Response body exceeds the declared Content-Length.") + + +class WSGIGateway_10(WSGIGateway): + """A Gateway class to interface HTTPServer with WSGI 1.0.x.""" + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + req = self.req + env = { + # set a non-standard environ entry so the WSGI app can know what + # the *real* server protocol is (and what features to support). + # See http://www.faqs.org/rfcs/rfc2145.html. + 'ACTUAL_SERVER_PROTOCOL': req.server.protocol, + 'PATH_INFO': req.path, + 'QUERY_STRING': req.qs, + 'REMOTE_ADDR': req.conn.remote_addr or '', + 'REMOTE_PORT': str(req.conn.remote_port or ''), + 'REQUEST_METHOD': req.method, + 'REQUEST_URI': req.uri, + 'SCRIPT_NAME': '', + 'SERVER_NAME': req.server.server_name, + # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol. 
+ 'SERVER_PROTOCOL': req.request_protocol, + 'SERVER_SOFTWARE': req.server.software, + 'wsgi.errors': sys.stderr, + 'wsgi.input': req.rfile, + 'wsgi.multiprocess': False, + 'wsgi.multithread': True, + 'wsgi.run_once': False, + 'wsgi.url_scheme': req.scheme, + 'wsgi.version': (1, 0), + } + + if isinstance(req.server.bind_addr, basestring): + # AF_UNIX. This isn't really allowed by WSGI, which doesn't + # address unix domain sockets. But it's better than nothing. + env["SERVER_PORT"] = "" + else: + env["SERVER_PORT"] = str(req.server.bind_addr[1]) + + # Request headers + for k, v in req.inheaders.iteritems(): + env["HTTP_" + k.upper().replace("-", "_")] = v + + # CONTENT_TYPE/CONTENT_LENGTH + ct = env.pop("HTTP_CONTENT_TYPE", None) + if ct is not None: + env["CONTENT_TYPE"] = ct + cl = env.pop("HTTP_CONTENT_LENGTH", None) + if cl is not None: + env["CONTENT_LENGTH"] = cl + + if req.conn.ssl_env: + env.update(req.conn.ssl_env) + + return env + + +class WSGIGateway_u0(WSGIGateway_10): + """A Gateway class to interface HTTPServer with WSGI u.0. + + WSGI u.0 is an experimental protocol, which uses unicode for keys and values + in both Python 2 and Python 3. + """ + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + req = self.req + env_10 = WSGIGateway_10.get_environ(self) + env = dict([(k.decode('ISO-8859-1'), v) for k, v in env_10.iteritems()]) + env[u'wsgi.version'] = ('u', 0) + + # Request-URI + env.setdefault(u'wsgi.url_encoding', u'utf-8') + try: + for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]: + env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding']) + except UnicodeDecodeError: + # Fall back to latin 1 so apps can transcode if needed. 
+ env[u'wsgi.url_encoding'] = u'ISO-8859-1' + for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]: + env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding']) + + for k, v in sorted(env.items()): + if isinstance(v, str) and k not in ('REQUEST_URI', 'wsgi.input'): + env[k] = v.decode('ISO-8859-1') + + return env + +wsgi_gateways = { + (1, 0): WSGIGateway_10, + ('u', 0): WSGIGateway_u0, +} + +class WSGIPathInfoDispatcher(object): + """A WSGI dispatcher for dispatch based on the PATH_INFO. + + apps: a dict or list of (path_prefix, app) pairs. + """ + + def __init__(self, apps): + try: + apps = list(apps.items()) + except AttributeError: + pass + + # Sort the apps by len(path), descending + apps.sort(cmp=lambda x,y: cmp(len(x[0]), len(y[0]))) + apps.reverse() + + # The path_prefix strings must start, but not end, with a slash. + # Use "" instead of "/". + self.apps = [(p.rstrip("/"), a) for p, a in apps] + + def __call__(self, environ, start_response): + path = environ["PATH_INFO"] or "/" + for p, app in self.apps: + # The apps list should be sorted by length, descending. + if path.startswith(p + "/") or path == p: + environ = environ.copy() + environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p + environ["PATH_INFO"] = path[len(p):] + return app(environ, start_response) + + start_response('404 Not Found', [('Content-Type', 'text/plain'), + ('Content-Length', '0')]) + return [''] + diff --git a/cherrypy/wsgiserver/wsgiserver3.py b/cherrypy/wsgiserver/wsgiserver3.py new file mode 100644 index 00000000..62db5ffd --- /dev/null +++ b/cherrypy/wsgiserver/wsgiserver3.py @@ -0,0 +1,2040 @@ +"""A high-speed, production ready, thread pooled, generic HTTP server. 
+ +Simplest example on how to use this module directly +(without using CherryPy's application machinery):: + + from cherrypy import wsgiserver + + def my_crazy_app(environ, start_response): + status = '200 OK' + response_headers = [('Content-type','text/plain')] + start_response(status, response_headers) + return ['Hello world!'] + + server = wsgiserver.CherryPyWSGIServer( + ('0.0.0.0', 8070), my_crazy_app, + server_name='www.cherrypy.example') + server.start() + +The CherryPy WSGI server can serve as many WSGI applications +as you want in one instance by using a WSGIPathInfoDispatcher:: + + d = WSGIPathInfoDispatcher({'/': my_crazy_app, '/blog': my_blog_app}) + server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', 80), d) + +Want SSL support? Just set server.ssl_adapter to an SSLAdapter instance. + +This won't call the CherryPy engine (application side) at all, only the +HTTP server, which is independent from the rest of CherryPy. Don't +let the name "CherryPyWSGIServer" throw you; the name merely reflects +its origin, not its coupling. + +For those of you wanting to understand internals of this module, here's the +basic call flow. The server's listening thread runs a very tight loop, +sticking incoming connections onto a Queue:: + + server = CherryPyWSGIServer(...) + server.start() + while True: + tick() + # This blocks until a request comes in: + child = socket.accept() + conn = HTTPConnection(child, ...) + server.requests.put(conn) + +Worker threads are kept in a pool and poll the Queue, popping off and then +handling each connection in turn. Each connection can consist of an arbitrary +number of requests and their responses, so we run a nested loop:: + + while True: + conn = server.requests.get() + conn.communicate() + -> while True: + req = HTTPRequest(...) + req.parse_request() + -> # Read the Request-Line, e.g. "GET /page HTTP/1.1" + req.rfile.readline() + read_headers(req.rfile, req.inheaders) + req.respond() + -> response = app(...) 
+ try: + for chunk in response: + if chunk: + req.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + if req.close_connection: + return +""" + +__all__ = ['HTTPRequest', 'HTTPConnection', 'HTTPServer', + 'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile', + 'CP_makefile', + 'MaxSizeExceeded', 'NoSSLError', 'FatalSSLAlert', + 'WorkerThread', 'ThreadPool', 'SSLAdapter', + 'CherryPyWSGIServer', + 'Gateway', 'WSGIGateway', 'WSGIGateway_10', 'WSGIGateway_u0', + 'WSGIPathInfoDispatcher', 'get_ssl_adapter_class'] + +import os +try: + import queue +except: + import Queue as queue +import re +import email.utils +import socket +import sys +if 'win' in sys.platform and not hasattr(socket, 'IPPROTO_IPV6'): + socket.IPPROTO_IPV6 = 41 +if sys.version_info < (3,1): + import io +else: + import _pyio as io +DEFAULT_BUFFER_SIZE = io.DEFAULT_BUFFER_SIZE + +import threading +import time +from traceback import format_exc +from urllib.parse import unquote +from urllib.parse import urlparse +from urllib.parse import scheme_chars +import warnings + +if sys.version_info >= (3, 0): + bytestr = bytes + unicodestr = str + basestring = (bytes, str) + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given encoding.""" + # In Python 3, the native string type is unicode + return n.encode(encoding) +else: + bytestr = str + unicodestr = unicode + basestring = basestring + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given encoding.""" + # In Python 2, the native string type is bytes. Assume it's already + # in the given encoding, which for ISO-8859-1 is almost always what + # was intended. 
+ return n + +LF = ntob('\n') +CRLF = ntob('\r\n') +TAB = ntob('\t') +SPACE = ntob(' ') +COLON = ntob(':') +SEMICOLON = ntob(';') +EMPTY = ntob('') +NUMBER_SIGN = ntob('#') +QUESTION_MARK = ntob('?') +ASTERISK = ntob('*') +FORWARD_SLASH = ntob('/') +quoted_slash = re.compile(ntob("(?i)%2F")) + +import errno + +def plat_specific_errors(*errnames): + """Return error numbers for all errors in errnames on this platform. + + The 'errno' module contains different global constants depending on + the specific platform (OS). This function will return the list of + numeric values for a given list of potential names. + """ + errno_names = dir(errno) + nums = [getattr(errno, k) for k in errnames if k in errno_names] + # de-dupe the list + return list(dict.fromkeys(nums).keys()) + +socket_error_eintr = plat_specific_errors("EINTR", "WSAEINTR") + +socket_errors_to_ignore = plat_specific_errors( + "EPIPE", + "EBADF", "WSAEBADF", + "ENOTSOCK", "WSAENOTSOCK", + "ETIMEDOUT", "WSAETIMEDOUT", + "ECONNREFUSED", "WSAECONNREFUSED", + "ECONNRESET", "WSAECONNRESET", + "ECONNABORTED", "WSAECONNABORTED", + "ENETRESET", "WSAENETRESET", + "EHOSTDOWN", "EHOSTUNREACH", + ) +socket_errors_to_ignore.append("timed out") +socket_errors_to_ignore.append("The read operation timed out") + +socket_errors_nonblocking = plat_specific_errors( + 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK') + +comma_separated_headers = [ntob(h) for h in + ['Accept', 'Accept-Charset', 'Accept-Encoding', + 'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control', + 'Connection', 'Content-Encoding', 'Content-Language', 'Expect', + 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE', + 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning', + 'WWW-Authenticate']] + + +import logging +if not hasattr(logging, 'statistics'): logging.statistics = {} + + +def read_headers(rfile, hdict=None): + """Read headers from the given stream into the given header dict. 
+ + If hdict is None, a new header dict is created. Returns the populated + header dict. + + Headers which are repeated are folded together using a comma if their + specification so dictates. + + This function raises ValueError when the read bytes violate the HTTP spec. + You should probably return "400 Bad Request" if this happens. + """ + if hdict is None: + hdict = {} + + while True: + line = rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError("Illegal end of headers.") + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError("HTTP requires CRLF terminators") + + if line[0] in (SPACE, TAB): + # It's a continuation line. + v = line.strip() + else: + try: + k, v = line.split(COLON, 1) + except ValueError: + raise ValueError("Illegal header line.") + # TODO: what about TE and WWW-Authenticate? + k = k.strip().title() + v = v.strip() + hname = k + + if k in comma_separated_headers: + existing = hdict.get(hname) + if existing: + v = b", ".join((existing, v)) + hdict[hname] = v + + return hdict + + +class MaxSizeExceeded(Exception): + pass + +class SizeCheckWrapper(object): + """Wraps a file-like object, raising MaxSizeExceeded if too large.""" + + def __init__(self, rfile, maxlen): + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + + def _check_length(self): + if self.maxlen and self.bytes_read > self.maxlen: + raise MaxSizeExceeded() + + def read(self, size=None): + data = self.rfile.read(size) + self.bytes_read += len(data) + self._check_length() + return data + + def readline(self, size=None): + if size is not None: + data = self.rfile.readline(size) + self.bytes_read += len(data) + self._check_length() + return data + + # User didn't specify a size ... + # We read the line in chunks to make sure it's not a 100MB line ! 
+ res = [] + while True: + data = self.rfile.readline(256) + self.bytes_read += len(data) + self._check_length() + res.append(data) + # See http://www.cherrypy.org/ticket/421 + if len(data) < 256 or data[-1:] == "\n": + return EMPTY.join(res) + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline() + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline() + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def __next__(self): + data = next(self.rfile) + self.bytes_read += len(data) + self._check_length() + return data + + def next(self): + data = self.rfile.next() + self.bytes_read += len(data) + self._check_length() + return data + + +class KnownLengthRFile(object): + """Wraps a file-like object, returning an empty string when exhausted.""" + + def __init__(self, rfile, content_length): + self.rfile = rfile + self.remaining = content_length + + def read(self, size=None): + if self.remaining == 0: + return b'' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.read(size) + self.remaining -= len(data) + return data + + def readline(self, size=None): + if self.remaining == 0: + return b'' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.readline(size) + self.remaining -= len(data) + return data + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def __next__(self): + data = next(self.rfile) + self.remaining -= len(data) + return data + + +class ChunkedRFile(object): + """Wraps a 
file-like object, returning an empty string when exhausted. + + This class is intended to provide a conforming wsgi.input value for + request entities that have been encoded with the 'chunked' transfer + encoding. + """ + + def __init__(self, rfile, maxlen, bufsize=8192): + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + self.buffer = EMPTY + self.bufsize = bufsize + self.closed = False + + def _fetch(self): + if self.closed: + return + + line = self.rfile.readline() + self.bytes_read += len(line) + + if self.maxlen and self.bytes_read > self.maxlen: + raise MaxSizeExceeded("Request Entity Too Large", self.maxlen) + + line = line.strip().split(SEMICOLON, 1) + + try: + chunk_size = line.pop(0) + chunk_size = int(chunk_size, 16) + except ValueError: + raise ValueError("Bad chunked transfer size: " + repr(chunk_size)) + + if chunk_size <= 0: + self.closed = True + return + +## if line: chunk_extension = line[0] + + if self.maxlen and self.bytes_read + chunk_size > self.maxlen: + raise IOError("Request Entity Too Large") + + chunk = self.rfile.read(chunk_size) + self.bytes_read += len(chunk) + self.buffer += chunk + + crlf = self.rfile.read(2) + if crlf != CRLF: + raise ValueError( + "Bad chunked transfer coding (expected '\\r\\n', " + "got " + repr(crlf) + ")") + + def read(self, size=None): + data = EMPTY + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + if size: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + data += self.buffer + + def readline(self, size=None): + data = EMPTY + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + newline_pos = self.buffer.find(LF) + if size: + if newline_pos == -1: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = 
self.buffer[remaining:] + else: + remaining = min(size - len(data), newline_pos) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + if newline_pos == -1: + data += self.buffer + else: + data += self.buffer[:newline_pos] + self.buffer = self.buffer[newline_pos:] + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def read_trailer_lines(self): + if not self.closed: + raise ValueError( + "Cannot read trailers until the request body has been read.") + + while True: + line = self.rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError("Illegal end of headers.") + + self.bytes_read += len(line) + if self.maxlen and self.bytes_read > self.maxlen: + raise IOError("Request Entity Too Large") + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError("HTTP requires CRLF terminators") + + yield line + + def close(self): + self.rfile.close() + + def __iter__(self): + # Shamelessly stolen from StringIO + total = 0 + line = self.readline(sizehint) + while line: + yield line + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + + +class HTTPRequest(object): + """An HTTP Request (and response). + + A single HTTP connection may consist of multiple request/response pairs. + """ + + server = None + """The HTTPServer object which is receiving this request.""" + + conn = None + """The HTTPConnection object on which this request connected.""" + + inheaders = {} + """A dict of request headers.""" + + outheaders = [] + """A list of header tuples to write in the response.""" + + ready = False + """When True, the request has been parsed and is ready to begin generating + the response. 
When False, signals the calling Connection that the response + should not be generated and the connection should close.""" + + close_connection = False + """Signals the calling Connection that the request should close. This does + not imply an error! The client and/or server may each request that the + connection be closed.""" + + chunked_write = False + """If True, output will be encoded with the "chunked" transfer-coding. + + This value is set automatically inside send_headers.""" + + def __init__(self, server, conn): + self.server= server + self.conn = conn + + self.ready = False + self.started_request = False + self.scheme = ntob("http") + if self.server.ssl_adapter is not None: + self.scheme = ntob("https") + # Use the lowest-common protocol in case read_request_line errors. + self.response_protocol = 'HTTP/1.0' + self.inheaders = {} + + self.status = "" + self.outheaders = [] + self.sent_headers = False + self.close_connection = self.__class__.close_connection + self.chunked_read = False + self.chunked_write = self.__class__.chunked_write + + def parse_request(self): + """Parse the next HTTP request start-line and message-headers.""" + self.rfile = SizeCheckWrapper(self.conn.rfile, + self.server.max_request_header_size) + try: + success = self.read_request_line() + except MaxSizeExceeded: + self.simple_response("414 Request-URI Too Long", + "The Request-URI sent with the request exceeds the maximum " + "allowed bytes.") + return + else: + if not success: + return + + try: + success = self.read_request_headers() + except MaxSizeExceeded: + self.simple_response("413 Request Entity Too Large", + "The headers sent with the request exceed the maximum " + "allowed bytes.") + return + else: + if not success: + return + + self.ready = True + + def read_request_line(self): + # HTTP/1.1 connections are persistent by default. If a client + # requests a page, then idles (leaves the connection open), + # then rfile.readline() will raise socket.error("timed out"). 
+ # Note that it does this based on the value given to settimeout(), + # and doesn't need the client to request or acknowledge the close + # (although your TCP stack might suffer for it: cf Apache's history + # with FIN_WAIT_2). + request_line = self.rfile.readline() + + # Set started_request to True so communicate() knows to send 408 + # from here on out. + self.started_request = True + if not request_line: + return False + + if request_line == CRLF: + # RFC 2616 sec 4.1: "...if the server is reading the protocol + # stream at the beginning of a message and receives a CRLF + # first, it should ignore the CRLF." + # But only ignore one leading line! else we enable a DoS. + request_line = self.rfile.readline() + if not request_line: + return False + + if not request_line.endswith(CRLF): + self.simple_response("400 Bad Request", "HTTP requires CRLF terminators") + return False + + try: + method, uri, req_protocol = request_line.strip().split(SPACE, 2) + # The [x:y] slicing is necessary for byte strings to avoid getting ord's + rp = int(req_protocol[5:6]), int(req_protocol[7:8]) + except ValueError: + self.simple_response("400 Bad Request", "Malformed Request-Line") + return False + + self.uri = uri + self.method = method + + # uri may be an abs_path (including "http://host.domain.tld"); + scheme, authority, path = self.parse_request_uri(uri) + if NUMBER_SIGN in path: + self.simple_response("400 Bad Request", + "Illegal #fragment in Request-URI.") + return False + + if scheme: + self.scheme = scheme + + qs = EMPTY + if QUESTION_MARK in path: + path, qs = path.split(QUESTION_MARK, 1) + + # Unquote the path+params (e.g. "/this%20path" -> "/this path"). + # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2 + # + # But note that "...a URI must be separated into its components + # before the escaped characters within those components can be + # safely decoded." 
http://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2 + # Therefore, "/this%2Fpath" becomes "/this%2Fpath", not "/this/path". + try: + atoms = [self.unquote_bytes(x) for x in quoted_slash.split(path)] + except ValueError: + ex = sys.exc_info()[1] + self.simple_response("400 Bad Request", ex.args[0]) + return False + path = b"%2F".join(atoms) + self.path = path + + # Note that, like wsgiref and most other HTTP servers, + # we "% HEX HEX"-unquote the path but not the query string. + self.qs = qs + + # Compare request and server HTTP protocol versions, in case our + # server does not support the requested protocol. Limit our output + # to min(req, server). We want the following output: + # request server actual written supported response + # protocol protocol response protocol feature set + # a 1.0 1.0 1.0 1.0 + # b 1.0 1.1 1.1 1.0 + # c 1.1 1.0 1.0 1.0 + # d 1.1 1.1 1.1 1.1 + # Notice that, in (b), the response will be "HTTP/1.1" even though + # the client only understands 1.0. RFC 2616 10.5.6 says we should + # only return 505 if the _major_ version is different. + # The [x:y] slicing is necessary for byte strings to avoid getting ord's + sp = int(self.server.protocol[5:6]), int(self.server.protocol[7:8]) + + if sp[0] != rp[0]: + self.simple_response("505 HTTP Version Not Supported") + return False + + self.request_protocol = req_protocol + self.response_protocol = "HTTP/%s.%s" % min(rp, sp) + return True + + def read_request_headers(self): + """Read self.rfile into self.inheaders. 
Return success.""" + + # then all the http headers + try: + read_headers(self.rfile, self.inheaders) + except ValueError: + ex = sys.exc_info()[1] + self.simple_response("400 Bad Request", ex.args[0]) + return False + + mrbs = self.server.max_request_body_size + if mrbs and int(self.inheaders.get(b"Content-Length", 0)) > mrbs: + self.simple_response("413 Request Entity Too Large", + "The entity sent with the request exceeds the maximum " + "allowed bytes.") + return False + + # Persistent connection support + if self.response_protocol == "HTTP/1.1": + # Both server and client are HTTP/1.1 + if self.inheaders.get(b"Connection", b"") == b"close": + self.close_connection = True + else: + # Either the server or client (or both) are HTTP/1.0 + if self.inheaders.get(b"Connection", b"") != b"Keep-Alive": + self.close_connection = True + + # Transfer-Encoding support + te = None + if self.response_protocol == "HTTP/1.1": + te = self.inheaders.get(b"Transfer-Encoding") + if te: + te = [x.strip().lower() for x in te.split(b",") if x.strip()] + + self.chunked_read = False + + if te: + for enc in te: + if enc == b"chunked": + self.chunked_read = True + else: + # Note that, even if we see "chunked", we must reject + # if there is an extension we don't recognize. + self.simple_response("501 Unimplemented") + self.close_connection = True + return False + + # From PEP 333: + # "Servers and gateways that implement HTTP 1.1 must provide + # transparent support for HTTP 1.1's "expect/continue" mechanism. + # This may be done in any of several ways: + # 1. Respond to requests containing an Expect: 100-continue request + # with an immediate "100 Continue" response, and proceed normally. + # 2. Proceed with the request normally, but provide the application + # with a wsgi.input stream that will send the "100 Continue" + # response if/when the application first attempts to read from + # the input stream. The read request must then remain blocked + # until the client responds. + # 3. 
Wait until the client decides that the server does not support + # expect/continue, and sends the request body on its own. + # (This is suboptimal, and is not recommended.) + # + # We used to do 3, but are now doing 1. Maybe we'll do 2 someday, + # but it seems like it would be a big slowdown for such a rare case. + if self.inheaders.get(b"Expect", b"") == b"100-continue": + # Don't use simple_response here, because it emits headers + # we don't want. See http://www.cherrypy.org/ticket/951 + msg = self.server.protocol.encode('ascii') + b" 100 Continue\r\n\r\n" + try: + self.conn.wfile.write(msg) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + return True + + def parse_request_uri(self, uri): + """Parse a Request-URI into (scheme, authority, path). + + Note that Request-URI's must be one of:: + + Request-URI = "*" | absoluteURI | abs_path | authority + + Therefore, a Request-URI which starts with a double forward-slash + cannot be a "net_path":: + + net_path = "//" authority [ abs_path ] + + Instead, it must be interpreted as an "abs_path" with an empty first + path segment:: + + abs_path = "/" path_segments + path_segments = segment *( "/" segment ) + segment = *pchar *( ";" param ) + param = *pchar + """ + if uri == ASTERISK: + return None, None, uri + + scheme, sep, remainder = uri.partition(b'://') + if sep and QUESTION_MARK not in scheme: + # An absoluteURI. + # If there's a scheme (and it must be http or https), then: + # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query ]] + authority, path_a, path_b = remainder.partition(FORWARD_SLASH) + return scheme.lower(), authority, path_a+path_b + + if uri.startswith(FORWARD_SLASH): + # An abs_path. + return None, None, uri + else: + # An authority. 
+ return None, uri, None + + def unquote_bytes(self, path): + """takes quoted string and unquotes % encoded values""" + res = path.split(b'%') + + for i in range(1, len(res)): + item = res[i] + try: + res[i] = bytes([int(item[:2], 16)]) + item[2:] + except ValueError: + raise + return b''.join(res) + + def respond(self): + """Call the gateway and write its iterable output.""" + mrbs = self.server.max_request_body_size + if self.chunked_read: + self.rfile = ChunkedRFile(self.conn.rfile, mrbs) + else: + cl = int(self.inheaders.get(b"Content-Length", 0)) + if mrbs and mrbs < cl: + if not self.sent_headers: + self.simple_response("413 Request Entity Too Large", + "The entity sent with the request exceeds the maximum " + "allowed bytes.") + return + self.rfile = KnownLengthRFile(self.conn.rfile, cl) + + self.server.gateway(self).respond() + + if (self.ready and not self.sent_headers): + self.sent_headers = True + self.send_headers() + if self.chunked_write: + self.conn.wfile.write(b"0\r\n\r\n") + + def simple_response(self, status, msg=""): + """Write a simple response back to the client.""" + status = str(status) + buf = [bytes(self.server.protocol, "ascii") + SPACE + + bytes(status, "ISO-8859-1") + CRLF, + bytes("Content-Length: %s\r\n" % len(msg), "ISO-8859-1"), + b"Content-Type: text/plain\r\n"] + + if status[:3] in ("413", "414"): + # Request Entity Too Large / Request-URI Too Long + self.close_connection = True + if self.response_protocol == 'HTTP/1.1': + # This will not be true for 414, since read_request_line + # usually raises 414 before reading the whole line, and we + # therefore cannot know the proper response_protocol. + buf.append(b"Connection: close\r\n") + else: + # HTTP/1.0 had no 413/414 status nor Connection header. + # Emit 400 instead and trust the message body is enough. 
+ status = "400 Bad Request" + + buf.append(CRLF) + if msg: + if isinstance(msg, unicodestr): + msg = msg.encode("ISO-8859-1") + buf.append(msg) + + try: + self.conn.wfile.write(b"".join(buf)) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + + def write(self, chunk): + """Write unbuffered data to the client.""" + if self.chunked_write and chunk: + buf = [bytes(hex(len(chunk)), 'ASCII')[2:], CRLF, chunk, CRLF] + self.conn.wfile.write(EMPTY.join(buf)) + else: + self.conn.wfile.write(chunk) + + def send_headers(self): + """Assert, process, and send the HTTP response message-headers. + + You must set self.status, and self.outheaders before calling this. + """ + hkeys = [key.lower() for key, value in self.outheaders] + status = int(self.status[:3]) + + if status == 413: + # Request Entity Too Large. Close conn to avoid garbage. + self.close_connection = True + elif b"content-length" not in hkeys: + # "All 1xx (informational), 204 (no content), + # and 304 (not modified) responses MUST NOT + # include a message-body." So no point chunking. + if status < 200 or status in (204, 205, 304): + pass + else: + if (self.response_protocol == 'HTTP/1.1' + and self.method != b'HEAD'): + # Use the chunked transfer-coding + self.chunked_write = True + self.outheaders.append((b"Transfer-Encoding", b"chunked")) + else: + # Closing the conn is the only way to determine len. + self.close_connection = True + + if b"connection" not in hkeys: + if self.response_protocol == 'HTTP/1.1': + # Both server and client are HTTP/1.1 or better + if self.close_connection: + self.outheaders.append((b"Connection", b"close")) + else: + # Server and/or client are HTTP/1.0 + if not self.close_connection: + self.outheaders.append((b"Connection", b"Keep-Alive")) + + if (not self.close_connection) and (not self.chunked_read): + # Read any remaining request body data on the socket. 
+ # "If an origin server receives a request that does not include an + # Expect request-header field with the "100-continue" expectation, + # the request includes a request body, and the server responds + # with a final status code before reading the entire request body + # from the transport connection, then the server SHOULD NOT close + # the transport connection until it has read the entire request, + # or until the client closes the connection. Otherwise, the client + # might not reliably receive the response message. However, this + # requirement is not be construed as preventing a server from + # defending itself against denial-of-service attacks, or from + # badly broken client implementations." + remaining = getattr(self.rfile, 'remaining', 0) + if remaining > 0: + self.rfile.read(remaining) + + if b"date" not in hkeys: + self.outheaders.append( + (b"Date", email.utils.formatdate(usegmt=True).encode('ISO-8859-1'))) + + if b"server" not in hkeys: + self.outheaders.append( + (b"Server", self.server.server_name.encode('ISO-8859-1'))) + + buf = [self.server.protocol.encode('ascii') + SPACE + self.status + CRLF] + for k, v in self.outheaders: + buf.append(k + COLON + SPACE + v + CRLF) + buf.append(CRLF) + self.conn.wfile.write(EMPTY.join(buf)) + + +class NoSSLError(Exception): + """Exception raised when a client speaks HTTP to an HTTPS socket.""" + pass + + +class FatalSSLAlert(Exception): + """Exception raised when the SSL implementation signals a fatal alert.""" + pass + + +class CP_BufferedWriter(io.BufferedWriter): + """Faux file object attached to a socket object.""" + + def write(self, b): + self._checkClosed() + if isinstance(b, str): + raise TypeError("can't write str to binary stream") + + with self._write_lock: + self._write_buf.extend(b) + self._flush_unlocked() + return len(b) + + def _flush_unlocked(self): + self._checkClosed("flush of closed file") + while self._write_buf: + try: + # ssl sockets only except 'bytes', not bytearrays + # so perhaps we 
should conditionally wrap this for perf? + n = self.raw.write(bytes(self._write_buf)) + except io.BlockingIOError as e: + n = e.characters_written + del self._write_buf[:n] + + +def CP_makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + if 'r' in mode: + return io.BufferedReader(socket.SocketIO(sock, mode), bufsize) + else: + return CP_BufferedWriter(socket.SocketIO(sock, mode), bufsize) + +class HTTPConnection(object): + """An HTTP connection (active socket). + + server: the Server object which received this connection. + socket: the raw socket object (usually TCP) for this connection. + makefile: a fileobject class for reading from the socket. + """ + + remote_addr = None + remote_port = None + ssl_env = None + rbufsize = DEFAULT_BUFFER_SIZE + wbufsize = DEFAULT_BUFFER_SIZE + RequestHandlerClass = HTTPRequest + + def __init__(self, server, sock, makefile=CP_makefile): + self.server = server + self.socket = sock + self.rfile = makefile(sock, "rb", self.rbufsize) + self.wfile = makefile(sock, "wb", self.wbufsize) + self.requests_seen = 0 + + def communicate(self): + """Read each request and respond appropriately.""" + request_seen = False + try: + while True: + # (re)set req to None so that if something goes wrong in + # the RequestHandlerClass constructor, the error doesn't + # get written to the previous request. + req = None + req = self.RequestHandlerClass(self.server, self) + + # This order of operations should guarantee correct pipelining. + req.parse_request() + if self.server.stats['Enabled']: + self.requests_seen += 1 + if not req.ready: + # Something went wrong in the parsing (and the server has + # probably already made a simple_response). Return and + # let the conn close. 
+ return + + request_seen = True + req.respond() + if req.close_connection: + return + except socket.error: + e = sys.exc_info()[1] + errnum = e.args[0] + # sadly SSL sockets return a different (longer) time out string + if errnum == 'timed out' or errnum == 'The read operation timed out': + # Don't error if we're between requests; only error + # if 1) no request has been started at all, or 2) we're + # in the middle of a request. + # See http://www.cherrypy.org/ticket/853 + if (not request_seen) or (req and req.started_request): + # Don't bother writing the 408 if the response + # has already started being written. + if req and not req.sent_headers: + try: + req.simple_response("408 Request Timeout") + except FatalSSLAlert: + # Close the connection. + return + elif errnum not in socket_errors_to_ignore: + self.server.error_log("socket.error %s" % repr(errnum), + level=logging.WARNING, traceback=True) + if req and not req.sent_headers: + try: + req.simple_response("500 Internal Server Error") + except FatalSSLAlert: + # Close the connection. + return + return + except (KeyboardInterrupt, SystemExit): + raise + except FatalSSLAlert: + # Close the connection. + return + except NoSSLError: + if req and not req.sent_headers: + # Unwrap our wfile + self.wfile = CP_makefile(self.socket._sock, "wb", self.wbufsize) + req.simple_response("400 Bad Request", + "The client sent a plain HTTP request, but " + "this server only speaks HTTPS on this port.") + self.linger = True + except Exception: + e = sys.exc_info()[1] + self.server.error_log(repr(e), level=logging.ERROR, traceback=True) + if req and not req.sent_headers: + try: + req.simple_response("500 Internal Server Error") + except FatalSSLAlert: + # Close the connection. + return + + linger = False + + def close(self): + """Close the socket underlying this connection.""" + self.rfile.close() + + if not self.linger: + # Python's socket module does NOT call close on the kernel socket + # when you call socket.close(). 
We do so manually here because we + # want this server to send a FIN TCP segment immediately. Note this + # must be called *before* calling socket.close(), because the latter + # drops its reference to the kernel socket. + # Python 3 *probably* fixed this with socket._real_close; hard to tell. +## self.socket._sock.close() + self.socket.close() + else: + # On the other hand, sometimes we want to hang around for a bit + # to make sure the client has a chance to read our entire + # response. Skipping the close() calls here delays the FIN + # packet until the socket object is garbage-collected later. + # Someday, perhaps, we'll do the full lingering_close that + # Apache does, but not today. + pass + + +class TrueyZero(object): + """An object which equals and does math like the integer '0' but evals True.""" + def __add__(self, other): + return other + def __radd__(self, other): + return other +trueyzero = TrueyZero() + + +_SHUTDOWNREQUEST = None + +class WorkerThread(threading.Thread): + """Thread which continuously polls a Queue for Connection objects. + + Due to the timing issues of polling a Queue, a WorkerThread does not + check its own 'ready' flag after it has started. To stop the thread, + it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue + (one for each running WorkerThread). 
+ """ + + conn = None + """The current connection pulled off the Queue, or None.""" + + server = None + """The HTTP Server which spawned this thread, and which owns the + Queue and is placing active connections into it.""" + + ready = False + """A simple flag for the calling server to know when this thread + has begun polling the Queue.""" + + + def __init__(self, server): + self.ready = False + self.server = server + + self.requests_seen = 0 + self.bytes_read = 0 + self.bytes_written = 0 + self.start_time = None + self.work_time = 0 + self.stats = { + 'Requests': lambda s: self.requests_seen + ((self.start_time is None) and trueyzero or self.conn.requests_seen), + 'Bytes Read': lambda s: self.bytes_read + ((self.start_time is None) and trueyzero or self.conn.rfile.bytes_read), + 'Bytes Written': lambda s: self.bytes_written + ((self.start_time is None) and trueyzero or self.conn.wfile.bytes_written), + 'Work Time': lambda s: self.work_time + ((self.start_time is None) and trueyzero or time.time() - self.start_time), + 'Read Throughput': lambda s: s['Bytes Read'](s) / (s['Work Time'](s) or 1e-6), + 'Write Throughput': lambda s: s['Bytes Written'](s) / (s['Work Time'](s) or 1e-6), + } + threading.Thread.__init__(self) + + def run(self): + self.server.stats['Worker Threads'][self.getName()] = self.stats + try: + self.ready = True + while True: + conn = self.server.requests.get() + if conn is _SHUTDOWNREQUEST: + return + + self.conn = conn + if self.server.stats['Enabled']: + self.start_time = time.time() + try: + conn.communicate() + finally: + conn.close() + if self.server.stats['Enabled']: + self.requests_seen += self.conn.requests_seen + self.bytes_read += self.conn.rfile.bytes_read + self.bytes_written += self.conn.wfile.bytes_written + self.work_time += time.time() - self.start_time + self.start_time = None + self.conn = None + except (KeyboardInterrupt, SystemExit): + exc = sys.exc_info()[1] + self.server.interrupt = exc + + +class ThreadPool(object): + """A 
Request Queue for an HTTPServer which pools threads. + + ThreadPool objects must provide min, get(), put(obj), start() + and stop(timeout) attributes. + """ + + def __init__(self, server, min=10, max=-1): + self.server = server + self.min = min + self.max = max + self._threads = [] + self._queue = queue.Queue() + self.get = self._queue.get + + def start(self): + """Start the pool of threads.""" + for i in range(self.min): + self._threads.append(WorkerThread(self.server)) + for worker in self._threads: + worker.setName("CP Server " + worker.getName()) + worker.start() + for worker in self._threads: + while not worker.ready: + time.sleep(.1) + + def _get_idle(self): + """Number of worker threads which are idle. Read-only.""" + return len([t for t in self._threads if t.conn is None]) + idle = property(_get_idle, doc=_get_idle.__doc__) + + def put(self, obj): + self._queue.put(obj) + if obj is _SHUTDOWNREQUEST: + return + + def grow(self, amount): + """Spawn new worker threads (not above self.max).""" + for i in range(amount): + if self.max > 0 and len(self._threads) >= self.max: + break + worker = WorkerThread(self.server) + worker.setName("CP Server " + worker.getName()) + self._threads.append(worker) + worker.start() + + def shrink(self, amount): + """Kill off worker threads (not below self.min).""" + # Grow/shrink the pool if necessary. + # Remove any dead threads from our list + for t in self._threads: + if not t.isAlive(): + self._threads.remove(t) + amount -= 1 + + if amount > 0: + for i in range(min(amount, len(self._threads) - self.min)): + # Put a number of shutdown requests on the queue equal + # to 'amount'. Once each of those is processed by a worker, + # that worker will terminate and be culled from our list + # in self.put. + self._queue.put(_SHUTDOWNREQUEST) + + def stop(self, timeout=5): + # Must shut down threads here so the code that calls + # this method can know when all threads are stopped. 
+ for worker in self._threads: + self._queue.put(_SHUTDOWNREQUEST) + + # Don't join currentThread (when stop is called inside a request). + current = threading.currentThread() + if timeout and timeout >= 0: + endtime = time.time() + timeout + while self._threads: + worker = self._threads.pop() + if worker is not current and worker.isAlive(): + try: + if timeout is None or timeout < 0: + worker.join() + else: + remaining_time = endtime - time.time() + if remaining_time > 0: + worker.join(remaining_time) + if worker.isAlive(): + # We exhausted the timeout. + # Forcibly shut down the socket. + c = worker.conn + if c and not c.rfile.closed: + try: + c.socket.shutdown(socket.SHUT_RD) + except TypeError: + # pyOpenSSL sockets don't take an arg + c.socket.shutdown() + worker.join() + except (AssertionError, + # Ignore repeated Ctrl-C. + # See http://www.cherrypy.org/ticket/691. + KeyboardInterrupt): + pass + + def _get_qsize(self): + return self._queue.qsize() + qsize = property(_get_qsize) + + + +try: + import fcntl +except ImportError: + try: + from ctypes import windll, WinError + except ImportError: + def prevent_socket_inheritance(sock): + """Dummy function, since neither fcntl nor ctypes are available.""" + pass + else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (Windows).""" + if not windll.kernel32.SetHandleInformation(sock.fileno(), 1, 0): + raise WinError() +else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (POSIX).""" + fd = sock.fileno() + old_flags = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC) + + +class SSLAdapter(object): + """Base class for SSL driver library adapters. 
+ + Required methods: + + * ``wrap(sock) -> (wrapped socket, ssl environ dict)`` + * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> socket file object`` + """ + + def __init__(self, certificate, private_key, certificate_chain=None): + self.certificate = certificate + self.private_key = private_key + self.certificate_chain = certificate_chain + + def wrap(self, sock): + raise NotImplemented + + def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + raise NotImplemented + + +class HTTPServer(object): + """An HTTP server.""" + + _bind_addr = "127.0.0.1" + _interrupt = None + + gateway = None + """A Gateway instance.""" + + minthreads = None + """The minimum number of worker threads to create (default 10).""" + + maxthreads = None + """The maximum number of worker threads to create (default -1 = no limit).""" + + server_name = None + """The name of the server; defaults to socket.gethostname().""" + + protocol = "HTTP/1.1" + """The version string to write in the Status-Line of all HTTP responses. + + For example, "HTTP/1.1" is the default. This also limits the supported + features used in the response.""" + + request_queue_size = 5 + """The 'backlog' arg to socket.listen(); max queued connections (default 5).""" + + shutdown_timeout = 5 + """The total time, in seconds, to wait for worker threads to cleanly exit.""" + + timeout = 10 + """The timeout in seconds for accepted connections (default 10).""" + + version = "CherryPy/3.2.2" + """A version string for the HTTPServer.""" + + software = None + """The value to set for the SERVER_SOFTWARE entry in the WSGI environ. 
+ + If None, this defaults to ``'%s Server' % self.version``.""" + + ready = False + """An internal flag which marks whether the socket is accepting connections.""" + + max_request_header_size = 0 + """The maximum size, in bytes, for request headers, or 0 for no limit.""" + + max_request_body_size = 0 + """The maximum size, in bytes, for request bodies, or 0 for no limit.""" + + nodelay = True + """If True (the default since 3.1), sets the TCP_NODELAY socket option.""" + + ConnectionClass = HTTPConnection + """The class to use for handling HTTP connections.""" + + ssl_adapter = None + """An instance of SSLAdapter (or a subclass). + + You must have the corresponding SSL driver library installed.""" + + def __init__(self, bind_addr, gateway, minthreads=10, maxthreads=-1, + server_name=None): + self.bind_addr = bind_addr + self.gateway = gateway + + self.requests = ThreadPool(self, min=minthreads or 1, max=maxthreads) + + if not server_name: + server_name = socket.gethostname() + self.server_name = server_name + self.clear_stats() + + def clear_stats(self): + self._start_time = None + self._run_time = 0 + self.stats = { + 'Enabled': False, + 'Bind Address': lambda s: repr(self.bind_addr), + 'Run time': lambda s: (not s['Enabled']) and -1 or self.runtime(), + 'Accepts': 0, + 'Accepts/sec': lambda s: s['Accepts'] / self.runtime(), + 'Queue': lambda s: getattr(self.requests, "qsize", None), + 'Threads': lambda s: len(getattr(self.requests, "_threads", [])), + 'Threads Idle': lambda s: getattr(self.requests, "idle", None), + 'Socket Errors': 0, + 'Requests': lambda s: (not s['Enabled']) and -1 or sum([w['Requests'](w) for w + in s['Worker Threads'].values()], 0), + 'Bytes Read': lambda s: (not s['Enabled']) and -1 or sum([w['Bytes Read'](w) for w + in s['Worker Threads'].values()], 0), + 'Bytes Written': lambda s: (not s['Enabled']) and -1 or sum([w['Bytes Written'](w) for w + in s['Worker Threads'].values()], 0), + 'Work Time': lambda s: (not s['Enabled']) and -1 or 
sum([w['Work Time'](w) for w + in s['Worker Threads'].values()], 0), + 'Read Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0), + 'Write Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0), + 'Worker Threads': {}, + } + logging.statistics["CherryPy HTTPServer %d" % id(self)] = self.stats + + def runtime(self): + if self._start_time is None: + return self._run_time + else: + return self._run_time + (time.time() - self._start_time) + + def __str__(self): + return "%s.%s(%r)" % (self.__module__, self.__class__.__name__, + self.bind_addr) + + def _get_bind_addr(self): + return self._bind_addr + def _set_bind_addr(self, value): + if isinstance(value, tuple) and value[0] in ('', None): + # Despite the socket module docs, using '' does not + # allow AI_PASSIVE to work. Passing None instead + # returns '0.0.0.0' like we want. In other words: + # host AI_PASSIVE result + # '' Y 192.168.x.y + # '' N 192.168.x.y + # None Y 0.0.0.0 + # None N 127.0.0.1 + # But since you can get the same effect with an explicit + # '0.0.0.0', we deny both the empty string and None as values. + raise ValueError("Host values of '' or None are not allowed. " + "Use '0.0.0.0' (IPv4) or '::' (IPv6) instead " + "to listen on all active interfaces.") + self._bind_addr = value + bind_addr = property(_get_bind_addr, _set_bind_addr, + doc="""The interface on which to listen for connections. + + For TCP sockets, a (host, port) tuple. Host values may be any IPv4 + or IPv6 address, or any valid hostname. The string 'localhost' is a + synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6). + The string '0.0.0.0' is a special IPv4 entry meaning "any active + interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for + IPv6. The empty string or None are not allowed. 
+ + For UNIX sockets, supply the filename as a string.""") + + def start(self): + """Run the server forever.""" + # We don't have to trap KeyboardInterrupt or SystemExit here, + # because cherrpy.server already does so, calling self.stop() for us. + # If you're using this server with another framework, you should + # trap those exceptions in whatever code block calls start(). + self._interrupt = None + + if self.software is None: + self.software = "%s Server" % self.version + + # Select the appropriate socket + if isinstance(self.bind_addr, basestring): + # AF_UNIX socket + + # So we can reuse the socket... + try: os.unlink(self.bind_addr) + except: pass + + # So everyone can access the socket... + try: os.chmod(self.bind_addr, 511) # 0777 + except: pass + + info = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_addr)] + else: + # AF_INET or AF_INET6 socket + # Get the correct address family for our host (allows IPv6 addresses) + host, port = self.bind_addr + try: + info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM, 0, socket.AI_PASSIVE) + except socket.gaierror: + if ':' in self.bind_addr[0]: + info = [(socket.AF_INET6, socket.SOCK_STREAM, + 0, "", self.bind_addr + (0, 0))] + else: + info = [(socket.AF_INET, socket.SOCK_STREAM, + 0, "", self.bind_addr)] + + self.socket = None + msg = "No socket could be created" + for res in info: + af, socktype, proto, canonname, sa = res + try: + self.bind(af, socktype, proto) + except socket.error: + if self.socket: + self.socket.close() + self.socket = None + continue + break + if not self.socket: + raise socket.error(msg) + + # Timeout so KeyboardInterrupt can be caught on Win32 + self.socket.settimeout(1) + self.socket.listen(self.request_queue_size) + + # Create worker threads + self.requests.start() + + self.ready = True + self._start_time = time.time() + while self.ready: + try: + self.tick() + except (KeyboardInterrupt, SystemExit): + raise + except: + self.error_log("Error in 
HTTPServer.tick", level=logging.ERROR, + traceback=True) + if self.interrupt: + while self.interrupt is True: + # Wait for self.stop() to complete. See _set_interrupt. + time.sleep(0.1) + if self.interrupt: + raise self.interrupt + + def error_log(self, msg="", level=20, traceback=False): + # Override this in subclasses as desired + sys.stderr.write(msg + '\n') + sys.stderr.flush() + if traceback: + tblines = format_exc() + sys.stderr.write(tblines) + sys.stderr.flush() + + def bind(self, family, type, proto=0): + """Create (or recreate) the actual socket object.""" + self.socket = socket.socket(family, type, proto) + prevent_socket_inheritance(self.socket) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if self.nodelay and not isinstance(self.bind_addr, str): + self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + if self.ssl_adapter is not None: + self.socket = self.ssl_adapter.bind(self.socket) + + # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), + # activate dual-stack. See http://www.cherrypy.org/ticket/871. 
+ if (hasattr(socket, 'AF_INET6') and family == socket.AF_INET6 + and self.bind_addr[0] in ('::', '::0', '::0.0.0.0')): + try: + self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except (AttributeError, socket.error): + # Apparently, the socket option is not available in + # this machine's TCP stack + pass + + self.socket.bind(self.bind_addr) + + def tick(self): + """Accept a new connection and put it on the Queue.""" + try: + s, addr = self.socket.accept() + if self.stats['Enabled']: + self.stats['Accepts'] += 1 + if not self.ready: + return + + prevent_socket_inheritance(s) + if hasattr(s, 'settimeout'): + s.settimeout(self.timeout) + + makefile = CP_makefile + ssl_env = {} + # if ssl cert and key are set, we try to be a secure HTTP server + if self.ssl_adapter is not None: + try: + s, ssl_env = self.ssl_adapter.wrap(s) + except NoSSLError: + msg = ("The client sent a plain HTTP request, but " + "this server only speaks HTTPS on this port.") + buf = ["%s 400 Bad Request\r\n" % self.protocol, + "Content-Length: %s\r\n" % len(msg), + "Content-Type: text/plain\r\n\r\n", + msg] + + wfile = makefile(s, "wb", DEFAULT_BUFFER_SIZE) + try: + wfile.write("".join(buf).encode('ISO-8859-1')) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + return + if not s: + return + makefile = self.ssl_adapter.makefile + # Re-apply our timeout since we may have a new socket object + if hasattr(s, 'settimeout'): + s.settimeout(self.timeout) + + conn = self.ConnectionClass(self, s, makefile) + + if not isinstance(self.bind_addr, basestring): + # optional values + # Until we do DNS lookups, omit REMOTE_HOST + if addr is None: # sometimes this can happen + # figure out if AF_INET or AF_INET6. 
+ if len(s.getsockname()) == 2: + # AF_INET + addr = ('0.0.0.0', 0) + else: + # AF_INET6 + addr = ('::', 0) + conn.remote_addr = addr[0] + conn.remote_port = addr[1] + + conn.ssl_env = ssl_env + + self.requests.put(conn) + except socket.timeout: + # The only reason for the timeout in start() is so we can + # notice keyboard interrupts on Win32, which don't interrupt + # accept() by default + return + except socket.error: + x = sys.exc_info()[1] + if self.stats['Enabled']: + self.stats['Socket Errors'] += 1 + if x.args[0] in socket_error_eintr: + # I *think* this is right. EINTR should occur when a signal + # is received during the accept() call; all docs say retry + # the call, and I *think* I'm reading it right that Python + # will then go ahead and poll for and handle the signal + # elsewhere. See http://www.cherrypy.org/ticket/707. + return + if x.args[0] in socket_errors_nonblocking: + # Just try again. See http://www.cherrypy.org/ticket/479. + return + if x.args[0] in socket_errors_to_ignore: + # Our socket was closed. + # See http://www.cherrypy.org/ticket/686. + return + raise + + def _get_interrupt(self): + return self._interrupt + def _set_interrupt(self, interrupt): + self._interrupt = True + self.stop() + self._interrupt = interrupt + interrupt = property(_get_interrupt, _set_interrupt, + doc="Set this to an Exception instance to " + "interrupt the server.") + + def stop(self): + """Gracefully shutdown a server that is serving forever.""" + self.ready = False + if self._start_time is not None: + self._run_time += (time.time() - self._start_time) + self._start_time = None + + sock = getattr(self, "socket", None) + if sock: + if not isinstance(self.bind_addr, basestring): + # Touch our own socket to make accept() return immediately. 
+ try: + host, port = sock.getsockname()[:2] + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + # Changed to use error code and not message + # See http://www.cherrypy.org/ticket/860. + raise + else: + # Note that we're explicitly NOT using AI_PASSIVE, + # here, because we want an actual IP to touch. + # localhost won't work if we've bound to a public IP, + # but it will if we bound to '0.0.0.0' (INADDR_ANY). + for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + s = None + try: + s = socket.socket(af, socktype, proto) + # See http://groups.google.com/group/cherrypy-users/ + # browse_frm/thread/bbfe5eb39c904fe0 + s.settimeout(1.0) + s.connect((host, port)) + s.close() + except socket.error: + if s: + s.close() + if hasattr(sock, "close"): + sock.close() + self.socket = None + + self.requests.stop(self.shutdown_timeout) + + +class Gateway(object): + """A base class to interface HTTPServer with other systems, such as WSGI.""" + + def __init__(self, req): + self.req = req + + def respond(self): + """Process the current request. Must be overridden in a subclass.""" + raise NotImplemented + + +# These may either be wsgiserver.SSLAdapter subclasses or the string names +# of such classes (in which case they will be lazily loaded). +ssl_adapters = { + 'builtin': 'cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter', + } + +def get_ssl_adapter_class(name='builtin'): + """Return an SSL adapter class for the given name.""" + adapter = ssl_adapters[name.lower()] + if isinstance(adapter, basestring): + last_dot = adapter.rfind(".") + attr_name = adapter[last_dot + 1:] + mod_path = adapter[:last_dot] + + try: + mod = sys.modules[mod_path] + if mod is None: + raise KeyError() + except KeyError: + # The last [''] is important. + mod = __import__(mod_path, globals(), locals(), ['']) + + # Let an AttributeError propagate outward. 
+ try: + adapter = getattr(mod, attr_name) + except AttributeError: + raise AttributeError("'%s' object has no attribute '%s'" + % (mod_path, attr_name)) + + return adapter + +# -------------------------------- WSGI Stuff -------------------------------- # + + +class CherryPyWSGIServer(HTTPServer): + """A subclass of HTTPServer which calls a WSGI application.""" + + wsgi_version = (1, 0) + """The version of WSGI to produce.""" + + def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None, + max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5): + self.requests = ThreadPool(self, min=numthreads or 1, max=max) + self.wsgi_app = wsgi_app + self.gateway = wsgi_gateways[self.wsgi_version] + + self.bind_addr = bind_addr + if not server_name: + server_name = socket.gethostname() + self.server_name = server_name + self.request_queue_size = request_queue_size + + self.timeout = timeout + self.shutdown_timeout = shutdown_timeout + self.clear_stats() + + def _get_numthreads(self): + return self.requests.min + def _set_numthreads(self, value): + self.requests.min = value + numthreads = property(_get_numthreads, _set_numthreads) + + +class WSGIGateway(Gateway): + """A base class to interface HTTPServer with WSGI.""" + + def __init__(self, req): + self.req = req + self.started_response = False + self.env = self.get_environ() + self.remaining_bytes_out = None + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + raise NotImplemented + + def respond(self): + """Process the current request.""" + response = self.req.server.wsgi_app(self.env, self.start_response) + try: + for chunk in response: + # "The start_response callable must not actually transmit + # the response headers. Instead, it must store them for the + # server or gateway to transmit only after the first + # iteration of the application return value that yields + # a NON-EMPTY string, or upon the application's first + # invocation of the write() callable." 
(PEP 333) + if chunk: + if isinstance(chunk, unicodestr): + chunk = chunk.encode('ISO-8859-1') + self.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + + def start_response(self, status, headers, exc_info = None): + """WSGI callable to begin the HTTP response.""" + # "The application may call start_response more than once, + # if and only if the exc_info argument is provided." + if self.started_response and not exc_info: + raise AssertionError("WSGI start_response called a second " + "time with no exc_info.") + self.started_response = True + + # "if exc_info is provided, and the HTTP headers have already been + # sent, start_response must raise an error, and should raise the + # exc_info tuple." + if self.req.sent_headers: + try: + raise exc_info[0](exc_info[1]).with_traceback(exc_info[2]) + finally: + exc_info = None + + # According to PEP 3333, when using Python 3, the response status + # and headers must be bytes masquerading as unicode; that is, they + # must be of type "str" but are restricted to code points in the + # "latin-1" set. + if not isinstance(status, str): + raise TypeError("WSGI response status is not of type str.") + self.req.status = status.encode('ISO-8859-1') + + for k, v in headers: + if not isinstance(k, str): + raise TypeError("WSGI response header key %r is not of type str." % k) + if not isinstance(v, str): + raise TypeError("WSGI response header value %r is not of type str." % v) + if k.lower() == 'content-length': + self.remaining_bytes_out = int(v) + self.req.outheaders.append((k.encode('ISO-8859-1'), v.encode('ISO-8859-1'))) + + return self.write + + def write(self, chunk): + """WSGI callable to write unbuffered data to the client. + + This method is also used internally by start_response (to write + data from the iterable returned by the WSGI application). 
+ """ + if not self.started_response: + raise AssertionError("WSGI write called before start_response.") + + chunklen = len(chunk) + rbo = self.remaining_bytes_out + if rbo is not None and chunklen > rbo: + if not self.req.sent_headers: + # Whew. We can send a 500 to the client. + self.req.simple_response("500 Internal Server Error", + "The requested resource returned more bytes than the " + "declared Content-Length.") + else: + # Dang. We have probably already sent data. Truncate the chunk + # to fit (so the client doesn't hang) and raise an error later. + chunk = chunk[:rbo] + + if not self.req.sent_headers: + self.req.sent_headers = True + self.req.send_headers() + + self.req.write(chunk) + + if rbo is not None: + rbo -= chunklen + if rbo < 0: + raise ValueError( + "Response body exceeds the declared Content-Length.") + + +class WSGIGateway_10(WSGIGateway): + """A Gateway class to interface HTTPServer with WSGI 1.0.x.""" + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + req = self.req + env = { + # set a non-standard environ entry so the WSGI app can know what + # the *real* server protocol is (and what features to support). + # See http://www.faqs.org/rfcs/rfc2145.html. + 'ACTUAL_SERVER_PROTOCOL': req.server.protocol, + 'PATH_INFO': req.path.decode('ISO-8859-1'), + 'QUERY_STRING': req.qs.decode('ISO-8859-1'), + 'REMOTE_ADDR': req.conn.remote_addr or '', + 'REMOTE_PORT': str(req.conn.remote_port or ''), + 'REQUEST_METHOD': req.method.decode('ISO-8859-1'), + 'REQUEST_URI': req.uri, + 'SCRIPT_NAME': '', + 'SERVER_NAME': req.server.server_name, + # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol. 
+ 'SERVER_PROTOCOL': req.request_protocol.decode('ISO-8859-1'), + 'SERVER_SOFTWARE': req.server.software, + 'wsgi.errors': sys.stderr, + 'wsgi.input': req.rfile, + 'wsgi.multiprocess': False, + 'wsgi.multithread': True, + 'wsgi.run_once': False, + 'wsgi.url_scheme': req.scheme.decode('ISO-8859-1'), + 'wsgi.version': (1, 0), + } + + if isinstance(req.server.bind_addr, basestring): + # AF_UNIX. This isn't really allowed by WSGI, which doesn't + # address unix domain sockets. But it's better than nothing. + env["SERVER_PORT"] = "" + else: + env["SERVER_PORT"] = str(req.server.bind_addr[1]) + + # Request headers + for k, v in req.inheaders.items(): + k = k.decode('ISO-8859-1').upper().replace("-", "_") + env["HTTP_" + k] = v.decode('ISO-8859-1') + + # CONTENT_TYPE/CONTENT_LENGTH + ct = env.pop("HTTP_CONTENT_TYPE", None) + if ct is not None: + env["CONTENT_TYPE"] = ct + cl = env.pop("HTTP_CONTENT_LENGTH", None) + if cl is not None: + env["CONTENT_LENGTH"] = cl + + if req.conn.ssl_env: + env.update(req.conn.ssl_env) + + return env + + +class WSGIGateway_u0(WSGIGateway_10): + """A Gateway class to interface HTTPServer with WSGI u.0. + + WSGI u.0 is an experimental protocol, which uses unicode for keys and values + in both Python 2 and Python 3. + """ + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + req = self.req + env_10 = WSGIGateway_10.get_environ(self) + env = env_10.copy() + env['wsgi.version'] = ('u', 0) + + # Request-URI + env.setdefault('wsgi.url_encoding', 'utf-8') + try: + # SCRIPT_NAME is the empty string, who cares what encoding it is? + env["PATH_INFO"] = req.path.decode(env['wsgi.url_encoding']) + env["QUERY_STRING"] = req.qs.decode(env['wsgi.url_encoding']) + except UnicodeDecodeError: + # Fall back to latin 1 so apps can transcode if needed. 
+ env['wsgi.url_encoding'] = 'ISO-8859-1' + env["PATH_INFO"] = env_10["PATH_INFO"] + env["QUERY_STRING"] = env_10["QUERY_STRING"] + + return env + +wsgi_gateways = { + (1, 0): WSGIGateway_10, + ('u', 0): WSGIGateway_u0, +} + +class WSGIPathInfoDispatcher(object): + """A WSGI dispatcher for dispatch based on the PATH_INFO. + + apps: a dict or list of (path_prefix, app) pairs. + """ + + def __init__(self, apps): + try: + apps = list(apps.items()) + except AttributeError: + pass + + # Sort the apps by len(path), descending + apps.sort() + apps.reverse() + + # The path_prefix strings must start, but not end, with a slash. + # Use "" instead of "/". + self.apps = [(p.rstrip("/"), a) for p, a in apps] + + def __call__(self, environ, start_response): + path = environ["PATH_INFO"] or "/" + for p, app in self.apps: + # The apps list should be sorted by length, descending. + if path.startswith(p + "/") or path == p: + environ = environ.copy() + environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p + environ["PATH_INFO"] = path[len(p):] + return app(environ, start_response) + + start_response('404 Not Found', [('Content-Type', 'text/plain'), + ('Content-Length', '0')]) + return [''] + From d57a2a606ffd6df95fa8f7d63c5e922b3c4ffbf5 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Tue, 31 Jul 2012 15:23:41 +0530 Subject: [PATCH 15/84] Fix for librarysync unicode errors causing thread to hang --- headphones/librarysync.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/headphones/librarysync.py b/headphones/librarysync.py index 651ff7c8..f17db1ee 100644 --- a/headphones/librarysync.py +++ b/headphones/librarysync.py @@ -55,12 +55,15 @@ def libraryScan(dir=None): song = os.path.join(r, files) + # We need the unicode path to use for logging, inserting into database + unicode_song_path = song.decode(headphones.SYS_ENCODING, errors='replace') + # Try to read the metadata try: f = MediaFile(song) except: - logger.error('Cannot read file: ' + 
song.decode(headphones.SYS_ENCODING)) + logger.error('Cannot read file: ' + unicode_song_path) continue # Grab the bitrates for the auto detect bit rate option @@ -83,7 +86,7 @@ def libraryScan(dir=None): track = myDB.action('SELECT TrackID from tracks WHERE ArtistName LIKE ? AND AlbumTitle LIKE ? AND TrackTitle LIKE ?', [f_artist, f.album, f.title]).fetchone() if track: - myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [song.decode(headphones.SYS_ENCODING), f.bitrate, f.format, track['TrackID']]) + myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [unicode_song_path, f.bitrate, f.format, track['TrackID']]) continue # Try to match on mbid if available and we couldn't find a match based on metadata @@ -94,14 +97,14 @@ def libraryScan(dir=None): track = myDB.action('SELECT TrackID from tracks WHERE TrackID=?', [f.mb_trackid]).fetchone() if track: - myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [song.decode(headphones.SYS_ENCODING), f.bitrate, f.format, track['TrackID']]) + myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [unicode_song_path, f.bitrate, f.format, track['TrackID']]) continue # if we can't find a match in the database on a track level, it might be a new artist or it might be on a non-mb release new_artists.append(f_artist) # The have table will become the new database for unmatched tracks (i.e. 
tracks with no associated links in the database - myDB.action('INSERT INTO have (ArtistName, AlbumTitle, TrackNumber, TrackTitle, TrackLength, BitRate, Genre, Date, TrackID, Location, CleanName, Format) VALUES( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)', [f_artist, f.album, f.track, f.title, f.length, f.bitrate, f.genre, f.date, f.mb_trackid, song.decode(headphones.SYS_ENCODING), helpers.cleanName(f_artist+' '+f.album+' '+f.title), f.format]) + myDB.action('INSERT INTO have (ArtistName, AlbumTitle, TrackNumber, TrackTitle, TrackLength, BitRate, Genre, Date, TrackID, Location, CleanName, Format) VALUES( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)', [f_artist, f.album, f.track, f.title, f.length, f.bitrate, f.genre, f.date, f.mb_trackid, unicode_song_path, helpers.cleanName(f_artist+' '+f.album+' '+f.title), f.format]) logger.info('Completed scanning directory: %s' % dir) From 40079835b36b122c4eec9cd355d7d1a4041a05e9 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Tue, 31 Jul 2012 20:50:23 +0530 Subject: [PATCH 16/84] Fixed some imports for the musicbrainz lib --- lib/musicbrainzngs/__init__.py | 2 +- lib/musicbrainzngs/mbxml.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/musicbrainzngs/__init__.py b/lib/musicbrainzngs/__init__.py index 36962ef5..40a89036 100644 --- a/lib/musicbrainzngs/__init__.py +++ b/lib/musicbrainzngs/__init__.py @@ -1 +1 @@ -from musicbrainzngs.musicbrainz import * +from lib.musicbrainzngs.musicbrainz import * diff --git a/lib/musicbrainzngs/mbxml.py b/lib/musicbrainzngs/mbxml.py index 7f6bd9f2..137a3909 100644 --- a/lib/musicbrainzngs/mbxml.py +++ b/lib/musicbrainzngs/mbxml.py @@ -6,7 +6,7 @@ import xml.etree.ElementTree as ET import logging -from musicbrainzngs import util +from lib.musicbrainzngs import util try: from ET import fixtag From 65ae7d992a42db125f83e899ac5284b4ec3c2241 Mon Sep 17 00:00:00 2001 From: Ben Graham Date: Thu, 2 Aug 2012 14:58:17 +1000 Subject: [PATCH 17/84] initial hacking --- headphones/webfilters.py | 50 
++++++++++++++++++++++++++++++++++++++++ headphones/webserve.py | 6 +++-- 2 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 headphones/webfilters.py diff --git a/headphones/webfilters.py b/headphones/webfilters.py new file mode 100644 index 00000000..50d369c3 --- /dev/null +++ b/headphones/webfilters.py @@ -0,0 +1,50 @@ +# This file is part of Headphones. +# +# Headphones is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Headphones is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Headphones. If not, see . + +from cherrypy.filters.basefilter import BaseFilter +import cherrypy + + + +class HTTPSFilter(BaseFilter): + + """This filter is based on a guide at http://www.turbogears.org/1.0/docs/Install/RedirectHttpsRequests.html + + It's purpose is to allow Headphones to issue redirects with the + correct protocol (HTTP/HTTPS) when being served behind a + HTTPS-handling proxy. 
+ """ + + def before_request_body(self): + forwarded_ssl_triggers = { + 'X-Forwarded-Protocol': 'SSL', + 'X-Forwarded-Ssl': 'On', + } + request = cherrypy.request + headers = request.headers + forwarded_ssl = reduce( + lambda x, y: x | headers.get(y).lower() == forwarded_ssl_triggers[y].lower(), + forwarded_ssl_triggers.keys(), + False + ) + if forwarded_ssl: + # base = config.get('https_filter.secure_base_url') + # if base is None: + # if config.get('base_url_filter.use_x_forwarded_host', False): + # base = headers.get('X-Forwarded-Host', 'localhost') + # else: + # base = 'localhost' + # request.base = 'https://' + base + request.headers['X-ForwardedSslDetected'] = Yes diff --git a/headphones/webserve.py b/headphones/webserve.py index 2b2e89fd..ec2045e4 100644 --- a/headphones/webserve.py +++ b/headphones/webserve.py @@ -25,7 +25,7 @@ import threading import headphones -from headphones import logger, searcher, db, importer, mb, lastfm, librarysync +from headphones import logger, searcher, db, importer, mb, lastfm, librarysync, webfilters from headphones.helpers import checked, radio import lib.simplejson as simplejson @@ -46,7 +46,9 @@ def serve_template(templatename, **kwargs): return exceptions.html_error_template().render() class WebInterface(object): - + + _cp_filters = [webfilters.HTTPSFilter()] + def index(self): raise cherrypy.HTTPRedirect("home") index.exposed=True From 8ce64e551eeffff4b2c78856a6d6585a06c72c2e Mon Sep 17 00:00:00 2001 From: Ben Graham Date: Thu, 2 Aug 2012 17:16:28 +1000 Subject: [PATCH 18/84] OK, so that turned out to be easier than I expected --- headphones/webstart.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/headphones/webstart.py b/headphones/webstart.py index b3659f55..13465bff 100644 --- a/headphones/webstart.py +++ b/headphones/webstart.py @@ -35,7 +35,8 @@ def initialize(options={}): conf = { '/': { - 'tools.staticdir.root': os.path.join(headphones.PROG_DIR, 'data') + 'tools.staticdir.root': 
os.path.join(headphones.PROG_DIR, 'data'), + 'tools.proxy.on': True, # pay attention to X-Forwarded-Proto header }, '/interfaces':{ 'tools.staticdir.on': True, From 6c157a8c034efcde01150243e83b026aa0fdb187 Mon Sep 17 00:00:00 2001 From: Ben Graham Date: Thu, 2 Aug 2012 17:21:49 +1000 Subject: [PATCH 19/84] This file is no longer needed --- headphones/webfilters.py | 50 ---------------------------------------- 1 file changed, 50 deletions(-) delete mode 100644 headphones/webfilters.py diff --git a/headphones/webfilters.py b/headphones/webfilters.py deleted file mode 100644 index 50d369c3..00000000 --- a/headphones/webfilters.py +++ /dev/null @@ -1,50 +0,0 @@ -# This file is part of Headphones. -# -# Headphones is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Headphones is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Headphones. If not, see . - -from cherrypy.filters.basefilter import BaseFilter -import cherrypy - - - -class HTTPSFilter(BaseFilter): - - """This filter is based on a guide at http://www.turbogears.org/1.0/docs/Install/RedirectHttpsRequests.html - - It's purpose is to allow Headphones to issue redirects with the - correct protocol (HTTP/HTTPS) when being served behind a - HTTPS-handling proxy. 
- """ - - def before_request_body(self): - forwarded_ssl_triggers = { - 'X-Forwarded-Protocol': 'SSL', - 'X-Forwarded-Ssl': 'On', - } - request = cherrypy.request - headers = request.headers - forwarded_ssl = reduce( - lambda x, y: x | headers.get(y).lower() == forwarded_ssl_triggers[y].lower(), - forwarded_ssl_triggers.keys(), - False - ) - if forwarded_ssl: - # base = config.get('https_filter.secure_base_url') - # if base is None: - # if config.get('base_url_filter.use_x_forwarded_host', False): - # base = headers.get('X-Forwarded-Host', 'localhost') - # else: - # base = 'localhost' - # request.base = 'https://' + base - request.headers['X-ForwardedSslDetected'] = Yes From 18fcd4d15d5bc505599705de2ceab2b3c1ffe9b3 Mon Sep 17 00:00:00 2001 From: Ben Graham Date: Thu, 2 Aug 2012 17:23:56 +1000 Subject: [PATCH 20/84] reverse-merge earlier changes to this file, they are no longer needed --- headphones/webserve.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/headphones/webserve.py b/headphones/webserve.py index ec2045e4..2b2e89fd 100644 --- a/headphones/webserve.py +++ b/headphones/webserve.py @@ -25,7 +25,7 @@ import threading import headphones -from headphones import logger, searcher, db, importer, mb, lastfm, librarysync, webfilters +from headphones import logger, searcher, db, importer, mb, lastfm, librarysync from headphones.helpers import checked, radio import lib.simplejson as simplejson @@ -46,9 +46,7 @@ def serve_template(templatename, **kwargs): return exceptions.html_error_template().render() class WebInterface(object): - - _cp_filters = [webfilters.HTTPSFilter()] - + def index(self): raise cherrypy.HTTPRedirect("home") index.exposed=True From 04c5cc5d522fae71fc08ed331b62595ba09428b8 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Thu, 2 Aug 2012 15:59:47 +0530 Subject: [PATCH 21/84] Modified the correctMetadata function in postProcessor.py to work with the updated beets lib: candidates/out_tuples now returns extra_items & 
extra_tracks, and input to autotag.apply_metadata was reversed --- headphones/postprocessor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index 6f275716..3e4125bf 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -470,7 +470,7 @@ def correctMetadata(albumid, release, downloaded_track_list): logger.error("Beets couldn't create an Item from: " + downloaded_track + " - not a media file?" + str(e)) try: - cur_artist, cur_album, out_tuples, rec = autotag.tag_album(items, search_artist=helpers.latinToAscii(release['ArtistName']), search_album=helpers.latinToAscii(release['AlbumTitle'])) + cur_artist, cur_album, candidates, rec = autotag.tag_album(items, search_artist=helpers.latinToAscii(release['ArtistName']), search_album=helpers.latinToAscii(release['AlbumTitle'])) except Exception, e: logger.error('Error getting recommendation: %s. Not writing metadata' % e) return @@ -478,9 +478,9 @@ def correctMetadata(albumid, release, downloaded_track_list): logger.warn('No accurate album match found for %s, %s - not writing metadata' % (release['ArtistName'], release['AlbumTitle'])) return - distance, items, info = out_tuples[0] + dist, info, mapping, extra_items, extra_tracks = candidates[0] logger.debug('Beets recommendation: %s' % rec) - autotag.apply_metadata(items, info) + autotag.apply_metadata(info, mapping) if len(items) != len(downloaded_track_list): logger.warn("Mismatch between number of tracks downloaded and the metadata items, but I'll try to write it anyway") From ed5d8b7459e5d7c64181e70727d7f06ee7cd5e52 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Mon, 6 Aug 2012 16:12:30 +0530 Subject: [PATCH 22/84] Added timeouts to last.fm functions (urllib->urllib2), added some error catching when parsing data --- headphones/lastfm.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/headphones/lastfm.py 
b/headphones/lastfm.py index c770d187..a55b6c8f 100644 --- a/headphones/lastfm.py +++ b/headphones/lastfm.py @@ -13,7 +13,7 @@ # You should have received a copy of the GNU General Public License # along with Headphones. If not, see . -import urllib +import urllib, urllib2 from xml.dom import minidom from collections import defaultdict import random @@ -37,7 +37,7 @@ def getSimilar(): url = 'http://ws.audioscrobbler.com/2.0/?method=artist.getsimilar&mbid=%s&api_key=%s' % (result['ArtistID'], api_key) try: - data = urllib.urlopen(url).read() + data = urllib2.urlopen(url, timeout=20).read() except: time.sleep(1) continue @@ -45,7 +45,11 @@ def getSimilar(): if len(data) < 200: continue - d = minidom.parseString(data) + try: + d = minidom.parseString(data) + except: + logger.debug("Could not parse similar artist data from last.fm") + node = d.documentElement artists = d.getElementsByTagName("artist") @@ -93,8 +97,14 @@ def getArtists(): username = headphones.LASTFM_USERNAME url = 'http://ws.audioscrobbler.com/2.0/?method=library.getartists&limit=10000&api_key=%s&user=%s' % (api_key, username) - data = urllib.urlopen(url).read() - d = minidom.parseString(data) + data = urllib2.urlopen(url, timeout=20).read() + + try: + d = minidom.parseString(data) + except: + logger.error("Could not parse artist list from last.fm data") + return + artists = d.getElementsByTagName("artist") artistlist = [] @@ -131,7 +141,7 @@ def getAlbumDescription(rgid, artist, album): } searchURL = 'http://ws.audioscrobbler.com/2.0/?' 
+ urllib.urlencode(params) - data = urllib.urlopen(searchURL).read() + data = urllib2.urlopen(searchURL, timeout=20).read() if data == 'Album not found': return From 7bbdeeac619c33dd94cb2d871843941dc88915e5 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Tue, 7 Aug 2012 15:55:44 +0530 Subject: [PATCH 23/84] Removed autocommiting from db transactions in librarysync - only do one commit every 100 songs --- headphones/db.py | 30 ++++++++++++++++++++++++++++-- headphones/librarysync.py | 26 +++++++++++++++++++------- 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/headphones/db.py b/headphones/db.py index 601b6286..ac1a8d96 100644 --- a/headphones/db.py +++ b/headphones/db.py @@ -42,7 +42,30 @@ class DBConnection: self.connection = sqlite3.connect(dbFilename(filename), timeout=20) self.connection.row_factory = sqlite3.Row - def action(self, query, args=None): + def commit(self): + + with db_lock: + + attempt = 0 + + while attempt < 5: + try: + self.connection.commit() + break + + except sqlite3.OperationalError, e: + if "unable to open database file" in e.message or "database is locked" in e.message: + logger.warn('Database Error: %s' % e) + attempt += 1 + time.sleep(1) + else: + logger.error('Database error: %s' % e) + raise + except sqlite3.DatabaseError, e: + logger.error('Fatal Error executing %s :: %s' % (query, e)) + raise + + def action(self, query, args=None, commit=True): with db_lock: @@ -60,7 +83,10 @@ class DBConnection: else: #logger.debug(self.filename+": "+query+" with args "+str(args)) sqlResult = self.connection.execute(query, args) - self.connection.commit() + + if commit: + self.connection.commit() + break except sqlite3.OperationalError, e: if "unable to open database file" in e.message or "database is locked" in e.message: diff --git a/headphones/librarysync.py b/headphones/librarysync.py index f17db1ee..2923b919 100644 --- a/headphones/librarysync.py +++ b/headphones/librarysync.py @@ -39,7 +39,9 @@ def libraryScan(dir=None): for track 
in tracks: if not os.path.isfile(track['Location'].encode(headphones.SYS_ENCODING)): - myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [None, None, None, track['TrackID']]) + myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [None, None, None, track['TrackID']], commit=False) + + myDB.commit() logger.info('Scanning music directory: %s' % dir) @@ -50,6 +52,8 @@ def libraryScan(dir=None): for r,d,f in os.walk(dir): for files in f: + # Taking out the auto-commit for every database transaction, instead we'll commit every 100 songs. + i = 0 # MEDIA_FORMATS = music file extensions, e.g. mp3, flac, etc if any(files.lower().endswith('.' + x.lower()) for x in headphones.MEDIA_FORMATS): @@ -80,13 +84,13 @@ def libraryScan(dir=None): if f_artist and f.album and f.title: - track = myDB.action('SELECT TrackID from tracks WHERE CleanName LIKE ?', [helpers.cleanName(f_artist +' '+f.album+' '+f.title)]).fetchone() + track = myDB.action('SELECT TrackID from tracks WHERE CleanName LIKE ?', [helpers.cleanName(f_artist +' '+f.album+' '+f.title)], commit=False).fetchone() if not track: - track = myDB.action('SELECT TrackID from tracks WHERE ArtistName LIKE ? AND AlbumTitle LIKE ? AND TrackTitle LIKE ?', [f_artist, f.album, f.title]).fetchone() + track = myDB.action('SELECT TrackID from tracks WHERE ArtistName LIKE ? AND AlbumTitle LIKE ? AND TrackTitle LIKE ?', [f_artist, f.album, f.title], commit=False).fetchone() if track: - myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [unicode_song_path, f.bitrate, f.format, track['TrackID']]) + myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? 
WHERE TrackID=?', [unicode_song_path, f.bitrate, f.format, track['TrackID']], commit=False) continue # Try to match on mbid if available and we couldn't find a match based on metadata @@ -94,18 +98,26 @@ def libraryScan(dir=None): # Wondering if theres a better way to do this -> do one thing if the row exists, # do something else if it doesn't - track = myDB.action('SELECT TrackID from tracks WHERE TrackID=?', [f.mb_trackid]).fetchone() + track = myDB.action('SELECT TrackID from tracks WHERE TrackID=?', [f.mb_trackid], commit=False).fetchone() if track: - myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [unicode_song_path, f.bitrate, f.format, track['TrackID']]) + myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [unicode_song_path, f.bitrate, f.format, track['TrackID']], commit=False) continue # if we can't find a match in the database on a track level, it might be a new artist or it might be on a non-mb release new_artists.append(f_artist) # The have table will become the new database for unmatched tracks (i.e. 
tracks with no associated links in the database - myDB.action('INSERT INTO have (ArtistName, AlbumTitle, TrackNumber, TrackTitle, TrackLength, BitRate, Genre, Date, TrackID, Location, CleanName, Format) VALUES( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)', [f_artist, f.album, f.track, f.title, f.length, f.bitrate, f.genre, f.date, f.mb_trackid, unicode_song_path, helpers.cleanName(f_artist+' '+f.album+' '+f.title), f.format]) + myDB.action('INSERT INTO have (ArtistName, AlbumTitle, TrackNumber, TrackTitle, TrackLength, BitRate, Genre, Date, TrackID, Location, CleanName, Format) VALUES( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)', [f_artist, f.album, f.track, f.title, f.length, f.bitrate, f.genre, f.date, f.mb_trackid, unicode_song_path, helpers.cleanName(f_artist+' '+f.album+' '+f.title), f.format], commit=False) + ## Increment the song counter and commit every 100th song + i += 1 + if i%100 == 0: + myDB.commit() + + # Do one last commit of the changes + myDB.commit() + logger.info('Completed scanning directory: %s' % dir) # Clean up the new artist list From 967f4c610c42511bc02601036537270d9b3539e8 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Sat, 11 Aug 2012 14:02:53 +0530 Subject: [PATCH 24/84] Fixed unicodedecodeerror in music_encoder: paths were being converted to unicode which wasn't necessary --- headphones/music_encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/headphones/music_encoder.py b/headphones/music_encoder.py index 20e995e1..51864aa4 100644 --- a/headphones/music_encoder.py +++ b/headphones/music_encoder.py @@ -49,13 +49,13 @@ def encode(albumPath): if (headphones.ENCODERLOSSLESS): if (music.lower().endswith('.flac')): musicFiles.append(os.path.join(r, music)) - musicTemp = os.path.normpath(os.path.splitext(music)[0]+'.'+headphones.ENCODEROUTPUTFORMAT).encode(headphones.SYS_ENCODING) + musicTemp = os.path.normpath(os.path.splitext(music)[0]+'.'+headphones.ENCODEROUTPUTFORMAT) musicTempFiles.append(os.path.join(tempDirEncode, 
musicTemp)) else: logger.debug('Music "%s" is already encoded' % (music)) else: musicFiles.append(os.path.join(r, music)) - musicTemp = os.path.normpath(os.path.splitext(music)[0]+'.'+headphones.ENCODEROUTPUTFORMAT).encode(headphones.SYS_ENCODING) + musicTemp = os.path.normpath(os.path.splitext(music)[0]+'.'+headphones.ENCODEROUTPUTFORMAT) musicTempFiles.append(os.path.join(tempDirEncode, musicTemp)) if headphones.ENCODER=='lame': From 67d33fdc0f37f70d420b65998c7be14f7f5ffa3a Mon Sep 17 00:00:00 2001 From: rembo10 Date: Sun, 12 Aug 2012 12:39:58 +0530 Subject: [PATCH 25/84] Revert "Removed autocommiting from db transactions in librarysync - only do one commit every 100 songs" This reverts commit 7bbdeeac619c33dd94cb2d871843941dc88915e5. --- headphones/db.py | 30 ++---------------------------- headphones/librarysync.py | 26 +++++++------------------- 2 files changed, 9 insertions(+), 47 deletions(-) diff --git a/headphones/db.py b/headphones/db.py index ac1a8d96..601b6286 100644 --- a/headphones/db.py +++ b/headphones/db.py @@ -42,30 +42,7 @@ class DBConnection: self.connection = sqlite3.connect(dbFilename(filename), timeout=20) self.connection.row_factory = sqlite3.Row - def commit(self): - - with db_lock: - - attempt = 0 - - while attempt < 5: - try: - self.connection.commit() - break - - except sqlite3.OperationalError, e: - if "unable to open database file" in e.message or "database is locked" in e.message: - logger.warn('Database Error: %s' % e) - attempt += 1 - time.sleep(1) - else: - logger.error('Database error: %s' % e) - raise - except sqlite3.DatabaseError, e: - logger.error('Fatal Error executing %s :: %s' % (query, e)) - raise - - def action(self, query, args=None, commit=True): + def action(self, query, args=None): with db_lock: @@ -83,10 +60,7 @@ class DBConnection: else: #logger.debug(self.filename+": "+query+" with args "+str(args)) sqlResult = self.connection.execute(query, args) - - if commit: - self.connection.commit() - + 
self.connection.commit() break except sqlite3.OperationalError, e: if "unable to open database file" in e.message or "database is locked" in e.message: diff --git a/headphones/librarysync.py b/headphones/librarysync.py index 2923b919..f17db1ee 100644 --- a/headphones/librarysync.py +++ b/headphones/librarysync.py @@ -39,9 +39,7 @@ def libraryScan(dir=None): for track in tracks: if not os.path.isfile(track['Location'].encode(headphones.SYS_ENCODING)): - myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [None, None, None, track['TrackID']], commit=False) - - myDB.commit() + myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [None, None, None, track['TrackID']]) logger.info('Scanning music directory: %s' % dir) @@ -52,8 +50,6 @@ def libraryScan(dir=None): for r,d,f in os.walk(dir): for files in f: - # Taking out the auto-commit for every database transaction, instead we'll commit every 100 songs. - i = 0 # MEDIA_FORMATS = music file extensions, e.g. mp3, flac, etc if any(files.lower().endswith('.' + x.lower()) for x in headphones.MEDIA_FORMATS): @@ -84,13 +80,13 @@ def libraryScan(dir=None): if f_artist and f.album and f.title: - track = myDB.action('SELECT TrackID from tracks WHERE CleanName LIKE ?', [helpers.cleanName(f_artist +' '+f.album+' '+f.title)], commit=False).fetchone() + track = myDB.action('SELECT TrackID from tracks WHERE CleanName LIKE ?', [helpers.cleanName(f_artist +' '+f.album+' '+f.title)]).fetchone() if not track: - track = myDB.action('SELECT TrackID from tracks WHERE ArtistName LIKE ? AND AlbumTitle LIKE ? AND TrackTitle LIKE ?', [f_artist, f.album, f.title], commit=False).fetchone() + track = myDB.action('SELECT TrackID from tracks WHERE ArtistName LIKE ? AND AlbumTitle LIKE ? AND TrackTitle LIKE ?', [f_artist, f.album, f.title]).fetchone() if track: - myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? 
WHERE TrackID=?', [unicode_song_path, f.bitrate, f.format, track['TrackID']], commit=False) + myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [unicode_song_path, f.bitrate, f.format, track['TrackID']]) continue # Try to match on mbid if available and we couldn't find a match based on metadata @@ -98,26 +94,18 @@ def libraryScan(dir=None): # Wondering if theres a better way to do this -> do one thing if the row exists, # do something else if it doesn't - track = myDB.action('SELECT TrackID from tracks WHERE TrackID=?', [f.mb_trackid], commit=False).fetchone() + track = myDB.action('SELECT TrackID from tracks WHERE TrackID=?', [f.mb_trackid]).fetchone() if track: - myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [unicode_song_path, f.bitrate, f.format, track['TrackID']], commit=False) + myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [unicode_song_path, f.bitrate, f.format, track['TrackID']]) continue # if we can't find a match in the database on a track level, it might be a new artist or it might be on a non-mb release new_artists.append(f_artist) # The have table will become the new database for unmatched tracks (i.e. 
tracks with no associated links in the database - myDB.action('INSERT INTO have (ArtistName, AlbumTitle, TrackNumber, TrackTitle, TrackLength, BitRate, Genre, Date, TrackID, Location, CleanName, Format) VALUES( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)', [f_artist, f.album, f.track, f.title, f.length, f.bitrate, f.genre, f.date, f.mb_trackid, unicode_song_path, helpers.cleanName(f_artist+' '+f.album+' '+f.title), f.format], commit=False) + myDB.action('INSERT INTO have (ArtistName, AlbumTitle, TrackNumber, TrackTitle, TrackLength, BitRate, Genre, Date, TrackID, Location, CleanName, Format) VALUES( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)', [f_artist, f.album, f.track, f.title, f.length, f.bitrate, f.genre, f.date, f.mb_trackid, unicode_song_path, helpers.cleanName(f_artist+' '+f.album+' '+f.title), f.format]) - ## Increment the song counter and commit every 100th song - i += 1 - if i%100 == 0: - myDB.commit() - - # Do one last commit of the changes - myDB.commit() - logger.info('Completed scanning directory: %s' % dir) # Clean up the new artist list From fe574354d0eb7e302479d10e8551ab1b150c5d56 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Sun, 12 Aug 2012 15:59:43 +0530 Subject: [PATCH 26/84] Fixed decode() function for python < 2.7 --- headphones/librarysync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/headphones/librarysync.py b/headphones/librarysync.py index f17db1ee..8dfd4dd5 100644 --- a/headphones/librarysync.py +++ b/headphones/librarysync.py @@ -56,7 +56,7 @@ def libraryScan(dir=None): song = os.path.join(r, files) # We need the unicode path to use for logging, inserting into database - unicode_song_path = song.decode(headphones.SYS_ENCODING, errors='replace') + unicode_song_path = song.decode(headphones.SYS_ENCODING, 'replace') # Try to read the metadata try: From 26a5be9e6176c5c04aa0c8037b45b2b4a31eebad Mon Sep 17 00:00:00 2001 From: rembo10 Date: Sun, 12 Aug 2012 19:46:01 +0530 Subject: [PATCH 27/84] Make proxy tools optional (disabled by 
default --- Headphones.py | 1 + headphones/__init__.py | 5 ++++- headphones/webstart.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/Headphones.py b/Headphones.py index b6591d75..999c8d1f 100644 --- a/Headphones.py +++ b/Headphones.py @@ -122,6 +122,7 @@ def main(): 'http_port': http_port, 'http_host': headphones.HTTP_HOST, 'http_root': headphones.HTTP_ROOT, + 'http_proxy': headphones.HTTP_PROXY, 'http_username': headphones.HTTP_USERNAME, 'http_password': headphones.HTTP_PASSWORD, }) diff --git a/headphones/__init__.py b/headphones/__init__.py index 265a016e..f0750696 100644 --- a/headphones/__init__.py +++ b/headphones/__init__.py @@ -66,6 +66,7 @@ HTTP_HOST = None HTTP_USERNAME = None HTTP_PASSWORD = None HTTP_ROOT = None +HTTP_PROXY = False LAUNCH_BROWSER = False API_ENABLED = False @@ -236,7 +237,7 @@ def initialize(): with INIT_LOCK: global __INITIALIZED__, FULL_PATH, PROG_DIR, VERBOSE, DAEMON, DATA_DIR, CONFIG_FILE, CFG, CONFIG_VERSION, LOG_DIR, CACHE_DIR, \ - HTTP_PORT, HTTP_HOST, HTTP_USERNAME, HTTP_PASSWORD, HTTP_ROOT, LAUNCH_BROWSER, API_ENABLED, API_KEY, GIT_PATH, \ + HTTP_PORT, HTTP_HOST, HTTP_USERNAME, HTTP_PASSWORD, HTTP_ROOT, HTTP_PROXY, LAUNCH_BROWSER, API_ENABLED, API_KEY, GIT_PATH, \ CURRENT_VERSION, LATEST_VERSION, CHECK_GITHUB, CHECK_GITHUB_ON_STARTUP, CHECK_GITHUB_INTERVAL, MUSIC_DIR, DESTINATION_DIR, PREFERRED_QUALITY, PREFERRED_BITRATE, DETECT_BITRATE, \ ADD_ARTISTS, CORRECT_METADATA, MOVE_FILES, RENAME_FILES, FOLDER_FORMAT, FILE_FORMAT, CLEANUP_FILES, INCLUDE_EXTRAS, AUTOWANT_UPCOMING, AUTOWANT_ALL, \ ADD_ALBUM_ART, EMBED_ALBUM_ART, EMBED_LYRICS, DOWNLOAD_DIR, BLACKHOLE, BLACKHOLE_DIR, USENET_RETENTION, SEARCH_INTERVAL, \ @@ -281,6 +282,7 @@ def initialize(): HTTP_USERNAME = check_setting_str(CFG, 'General', 'http_username', '') HTTP_PASSWORD = check_setting_str(CFG, 'General', 'http_password', '') HTTP_ROOT = check_setting_str(CFG, 'General', 'http_root', '/') + HTTP_PROXY = bool(check_setting_int(CFG, 'General', 
'http_proxy', 0)) LAUNCH_BROWSER = bool(check_setting_int(CFG, 'General', 'launch_browser', 1)) API_ENABLED = bool(check_setting_int(CFG, 'General', 'api_enabled', 0)) API_KEY = check_setting_str(CFG, 'General', 'api_key', '') @@ -555,6 +557,7 @@ def config_write(): new_config['General']['http_username'] = HTTP_USERNAME new_config['General']['http_password'] = HTTP_PASSWORD new_config['General']['http_root'] = HTTP_ROOT + new_config['General']['http_proxy'] = int(HTTP_PROXY) new_config['General']['launch_browser'] = int(LAUNCH_BROWSER) new_config['General']['api_enabled'] = int(API_ENABLED) new_config['General']['api_key'] = API_KEY diff --git a/headphones/webstart.py b/headphones/webstart.py index 13465bff..b16d6765 100644 --- a/headphones/webstart.py +++ b/headphones/webstart.py @@ -36,7 +36,7 @@ def initialize(options={}): conf = { '/': { 'tools.staticdir.root': os.path.join(headphones.PROG_DIR, 'data'), - 'tools.proxy.on': True, # pay attention to X-Forwarded-Proto header + 'tools.proxy.on': options['http_proxy'] # pay attention to X-Forwarded-Proto header }, '/interfaces':{ 'tools.staticdir.on': True, From 865e4451c10e98be5017f9f8b6ea5b3af9df90ee Mon Sep 17 00:00:00 2001 From: rembo10 Date: Tue, 14 Aug 2012 12:20:37 +0530 Subject: [PATCH 28/84] Added Status filter to webserve --- headphones/webserve.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/headphones/webserve.py b/headphones/webserve.py index 2b2e89fd..d49f338d 100644 --- a/headphones/webserve.py +++ b/headphones/webserve.py @@ -229,9 +229,12 @@ class WebInterface(object): return serve_template(templatename="manageartists.html", title="Manage Artists", artists=artists) manageArtists.exposed = True - def manageAlbums(self): + def manageAlbums(self, Status=None): myDB = db.DBConnection() - albums = myDB.select('SELECT * from albums') + if Status: + albums = myDB.select('SELECT * from albums WHERE Status=?', [Status]) + else: + albums = myDB.select('SELECT * from albums') return 
serve_template(templatename="managealbums.html", title="Manage Albums", albums=albums) manageAlbums.exposed = True From 0f2f48543b745bb579b8c1ae4fc24fb324290e16 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Tue, 14 Aug 2012 12:39:24 +0530 Subject: [PATCH 29/84] Added dialog popup to default interface for manage albums filter --- data/interfaces/default/js/script.js | 1 + data/interfaces/default/manage.html | 21 ++++++++++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/data/interfaces/default/js/script.js b/data/interfaces/default/js/script.js index b1ac046d..3ac9c342 100644 --- a/data/interfaces/default/js/script.js +++ b/data/interfaces/default/js/script.js @@ -144,6 +144,7 @@ function initConfigCheckbox(elem) { function initActions() { $("#subhead_menu #menu_link_refresh").button({ icons: { primary: "ui-icon-refresh" } }); $("#subhead_menu #menu_link_edit").button({ icons: { primary: "ui-icon-pencil" } }); + $("#subhead_menu .menu_link_edit").button({ icons: { primary: "ui-icon-pencil" } }); $("#subhead_menu #menu_link_delete" ).button({ icons: { primary: "ui-icon-trash" } }); $("#subhead_menu #menu_link_pauze").button({ icons: { primary: "ui-icon-pause"} }); $("#subhead_menu #menu_link_resume").button({ icons: { primary: "ui-icon-play"} }); diff --git a/data/interfaces/default/manage.html b/data/interfaces/default/manage.html index ed62e445..c4786f91 100644 --- a/data/interfaces/default/manage.html +++ b/data/interfaces/default/manage.html @@ -6,10 +6,21 @@ <%def name="headerIncludes()">
@@ -95,6 +106,10 @@ <%def name="javascriptIncludes()"> From 7ea9bbf57d0a0ae0ed76b2650758d54854336463 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Wed, 15 Aug 2012 16:32:26 +0530 Subject: [PATCH 42/84] Fixed issue retrieving release if no releasedate --- headphones/mb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/headphones/mb.py b/headphones/mb.py index c98e5b07..5f9d168d 100644 --- a/headphones/mb.py +++ b/headphones/mb.py @@ -313,8 +313,8 @@ def getRelease(releaseid, include_artist_info=True): release['title'] = unicode(results['title']) release['id'] = unicode(results['id']) - release['asin'] = unicode(results['asin']) if 'asin' in results else None - release['date'] = unicode(results['date']) + release['asin'] = unicode(results['asin']) if 'asin' in results else u'None' + release['date'] = unicode(results['date']) if 'date' in results else u'None' try: release['format'] = unicode(results['medium-list'][0]['format']) except: From 56e5c7ae1da2b94c2699708fa7f80677cfe33668 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Wed, 15 Aug 2012 17:56:15 +0530 Subject: [PATCH 43/84] Added a timeout to lyrics fetching --- headphones/lyrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/headphones/lyrics.py b/headphones/lyrics.py index c8eda8a6..6899359a 100644 --- a/headphones/lyrics.py +++ b/headphones/lyrics.py @@ -14,7 +14,7 @@ # along with Headphones. If not, see . import re -import urllib +import urllib, urllib2 from xml.dom import minidom import htmlentitydefs @@ -30,7 +30,7 @@ def getLyrics(artist, song): searchURL = 'http://lyrics.wikia.com/api.php?' + urllib.urlencode(params) try: - data = urllib.urlopen(searchURL).read() + data = urllib2.urlopen(searchURL, timeout=20).read() except Exception, e: logger.warn('Error opening: %s. 
Error: %s' % (searchURL, e)) return From 87e8e35985557f55de080039cdd05a7caa3793e3 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Thu, 16 Aug 2012 18:19:32 +0530 Subject: [PATCH 44/84] Fixed bug where fetching album art from Last.FM would hang if no album art was found --- headphones/postprocessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index 3e4125bf..6a90c706 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -240,7 +240,7 @@ def doPostProcessing(albumid, albumpath, release, tracks, downloaded_track_list) if len(artwork) < 100: logger.info("No suitable album art found from Amazon. Checking Last.FM....") artwork = albumart.getCachedArt(albumid) - if len(artwork) < 100: + if not artwork or len(artwork) < 100: artwork = False logger.info("No suitable album art found from Last.FM. Not adding album art") From f2e0afac2550e9dd59ea81d96acce84d759f41b9 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Thu, 16 Aug 2012 21:40:15 +0530 Subject: [PATCH 45/84] Fire off a searcher thread when adding albums with Mark All Albums as Wanted option is set --- headphones/importer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/headphones/importer.py b/headphones/importer.py index 473ef59c..cf58fa63 100644 --- a/headphones/importer.py +++ b/headphones/importer.py @@ -338,6 +338,11 @@ def addArtisttoDB(artistid, extrasonly=False): if headphones.AUTOWANT_ALL: newValueDict['Status'] = "Wanted" + + #start a search for the album + import searcher + searcher.searchforalbum(albumid=rg['id']) + elif album['ReleaseDate'] > helpers.today() and headphones.AUTOWANT_UPCOMING: newValueDict['Status'] = "Wanted" else: @@ -522,7 +527,7 @@ def addReleaseById(rid): #start a search for the album import searcher - searcher.searchNZB(rgid, False) + searcher.searchforalbum(rgid, False) elif not rg_exists and not release_dict: logger.error("ReleaseGroup does not exist in the 
database and did not get a valid response from MB. Skipping release.") return From 625b0efd13024ee373abfc68a9a8b4e7bbdf5f1e Mon Sep 17 00:00:00 2001 From: rembo10 Date: Thu, 16 Aug 2012 22:54:33 +0530 Subject: [PATCH 46/84] A couple of changes to library sync --- data/interfaces/default/album.html | 4 +- headphones/importer.py | 4 +- headphones/librarysync.py | 244 +++++++++++++++++++++++++---- 3 files changed, 218 insertions(+), 34 deletions(-) diff --git a/data/interfaces/default/album.html b/data/interfaces/default/album.html index 20a8f659..fce97271 100644 --- a/data/interfaces/default/album.html +++ b/data/interfaces/default/album.html @@ -176,8 +176,6 @@ $('#refresh_artist').click(function() { $('#dialog').dialog("close"); }); - getAlbumInfo(); - getAlbumArt(); initActions(); setTimeout(function(){ initFancybox(); @@ -193,6 +191,8 @@ }; $(document).ready(function() { + getAlbumInfo(); + getAlbumArt(); initThisPage(); }); diff --git a/headphones/importer.py b/headphones/importer.py index cf58fa63..c92a0d49 100644 --- a/headphones/importer.py +++ b/headphones/importer.py @@ -236,7 +236,7 @@ def addArtisttoDB(artistid, extrasonly=False): newValueDict['Location'] = match['Location'] newValueDict['BitRate'] = match['BitRate'] newValueDict['Format'] = match['Format'] - myDB.action('UPDATE tracks SET Matched="True" WHERE Location=?', match['Location']) + myDB.action('UPDATE have SET Matched="True" WHERE Location=?', [match['Location']]) myDB.upsert("alltracks", newValueDict, controlValueDict) @@ -287,7 +287,7 @@ def addArtisttoDB(artistid, extrasonly=False): newValueDict['Location'] = match['Location'] newValueDict['BitRate'] = match['BitRate'] newValueDict['Format'] = match['Format'] - myDB.action('UPDATE tracks SET Matched="True" WHERE Location=?', match['Location']) + myDB.action('UPDATE have SET Matched="True" WHERE Location=?', [match['Location']]) myDB.upsert("alltracks", newValueDict, controlValueDict) diff --git a/headphones/librarysync.py 
b/headphones/librarysync.py index 8dfd4dd5..b7484b1c 100644 --- a/headphones/librarysync.py +++ b/headphones/librarysync.py @@ -45,6 +45,8 @@ def libraryScan(dir=None): new_artists = [] bitrates = [] + + song_list = [] myDB.action('DELETE from have') @@ -69,44 +71,224 @@ def libraryScan(dir=None): # Grab the bitrates for the auto detect bit rate option if f.bitrate: bitrates.append(f.bitrate) - - # Try to find a match based on artist/album/tracktitle + + # Use the album artist over the artist if available if f.albumartist: f_artist = f.albumartist elif f.artist: f_artist = f.artist else: - continue - - if f_artist and f.album and f.title: - - track = myDB.action('SELECT TrackID from tracks WHERE CleanName LIKE ?', [helpers.cleanName(f_artist +' '+f.album+' '+f.title)]).fetchone() - - if not track: - track = myDB.action('SELECT TrackID from tracks WHERE ArtistName LIKE ? AND AlbumTitle LIKE ? AND TrackTitle LIKE ?', [f_artist, f.album, f.title]).fetchone() + f_artist = None - if track: - myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? 
WHERE TrackID=?', [unicode_song_path, f.bitrate, f.format, track['TrackID']]) - continue - - # Try to match on mbid if available and we couldn't find a match based on metadata - if f.mb_trackid: + # Add the song to our song list - + # TODO: skip adding songs without the minimum requisite information (just a matter of putting together the right if statements) - # Wondering if theres a better way to do this -> do one thing if the row exists, - # do something else if it doesn't - track = myDB.action('SELECT TrackID from tracks WHERE TrackID=?', [f.mb_trackid]).fetchone() + song_dict = { 'TrackID' : f.mb_trackid, + 'ReleaseID' : f.mb_albumid, + 'ArtistName' : f_artist, + 'AlbumTitle' : f.album, + 'TrackNumber': f.track, + 'TrackLength': f.length, + 'Genre' : f.genre, + 'Date' : f.date, + 'TrackTitle' : f.title, + 'BitRate' : f.bitrate, + 'Format' : f.format, + 'Location' : unicode_song_path } + + song_list.append(song_dict) + + # Now we start track matching + total_number_of_songs = len(song_list) + logger.info("Found " + str(total_number_of_songs) + " tracks in: '" + dir + "'. Matching tracks to the appropriate releases....") + + # Sort the song_list by most vague (e.g. no trackid or releaseid) to most specific (both trackid & releaseid) + # When we insert into the database, the tracks with the most specific information will overwrite the more general matches + + song_list = helpers.multikeysort(song_list, ['ReleaseID', 'TrackID']) + + # We'll use this to give a % completion, just because the track matching might take a while + song_count = 0 + + for song in song_list: - if track: - myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [unicode_song_path, f.bitrate, f.format, track['TrackID']]) - continue - - # if we can't find a match in the database on a track level, it might be a new artist or it might be on a non-mb release - new_artists.append(f_artist) - - # The have table will become the new database for unmatched tracks (i.e. 
tracks with no associated links in the database - myDB.action('INSERT INTO have (ArtistName, AlbumTitle, TrackNumber, TrackTitle, TrackLength, BitRate, Genre, Date, TrackID, Location, CleanName, Format) VALUES( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)', [f_artist, f.album, f.track, f.title, f.length, f.bitrate, f.genre, f.date, f.mb_trackid, unicode_song_path, helpers.cleanName(f_artist+' '+f.album+' '+f.title), f.format]) + song_count += 1 + completion_percentage = float(song_count)/total_number_of_songs * 100 + + if completion_percentage%10 == 0: + logger.info("Track matching is " + str(completion_percentage) + "% complete") + + # If the track has a trackid & releaseid (beets: albumid) that the most surefire way + # of identifying a track to a specific release so we'll use that first + if song['TrackID'] and song['ReleaseID']: - logger.info('Completed scanning directory: %s' % dir) + # Check both the tracks table & alltracks table in case they haven't populated the alltracks table yet + track = myDB.action('SELECT TrackID, ReleaseID, AlbumID from alltracks WHERE TrackID=? AND ReleaseID=?', [song['TrackID'], song['ReleaseID']]).fetchone() + + # It might be the case that the alltracks table isn't populated yet, so maybe we can only find a match in the tracks table + if not track: + track = myDB.action('SELECT TrackID, ReleaseID, AlbumID from tracks WHERE TrackID=? 
AND ReleaseID=?', [song['TrackID'], song['ReleaseID']]).fetchone() + + if track: + # Use TrackID & ReleaseID here since there can only be one possible match with a TrackID & ReleaseID query combo + controlValueDict = { 'TrackID' : track['TrackID'], + 'ReleaseID' : track['ReleaseID'] } + + # Insert it into the Headphones hybrid release (ReleaseID == AlbumID) + hybridControlValueDict = { 'TrackID' : track['TrackID'], + 'ReleaseID' : track['AlbumID'] } + + newValueDict = { 'Location' : song['Location'], + 'BitRate' : song['BitRate'], + 'Format' : song['Format'] } + + # Update both the tracks table and the alltracks table using the controlValueDict and hybridControlValueDict + myDB.upsert("alltracks", newValueDict, controlValueDict) + myDB.upsert("tracks", newValueDict, controlValueDict) + + myDB.upsert("alltracks", newValueDict, hybridControlValueDict) + myDB.upsert("tracks", newValueDict, hybridControlValueDict) + + # Matched. Move on to the next one: + continue + + # If we can't find it with TrackID & ReleaseID, next most specific will be + # releaseid + tracktitle, although perhaps less reliable due to a higher + # likelihood of variations in the song title (e.g. feat. artists) + if song['ReleaseID'] and song['TrackTitle']: + + track = myDB.action('SELECT TrackID, ReleaseID, AlbumID from alltracks WHERE ReleaseID=? AND TrackTitle=?', [song['ReleaseID'], song['TrackTitle']]).fetchone() + + if not track: + track = myDB.action('SELECT TrackID, ReleaseID, AlbumID from tracks WHERE ReleaseID=? 
AND TrackTitle=?', [song['ReleaseID'], song['TrackTitle']]).fetchone() + + if track: + # There can also only be one match for this query as well (although it might be on both the tracks and alltracks table) + # So use both TrackID & ReleaseID as the control values + controlValueDict = { 'TrackID' : track['TrackID'], + 'ReleaseID' : track['ReleaseID'] } + + hybridControlValueDict = { 'TrackID' : track['TrackID'], + 'ReleaseID' : track['AlbumID'] } + + newValueDict = { 'Location' : song['Location'], + 'BitRate' : song['BitRate'], + 'Format' : song['Format'] } + + # Update both tables here as well + myDB.upsert("alltracks", newValueDict, controlValueDict) + myDB.upsert("tracks", newValueDict, controlValueDict) + + myDB.upsert("alltracks", newValueDict, hybridControlValueDict) + myDB.upsert("tracks", newValueDict, hybridControlValueDict) + + # Done + continue + + # Next most specific will be the opposite: a TrackID and an AlbumTitle + # TrackIDs span multiple releases so if something is on an official album + # and a compilation, for example, this will match it to the right one + # However - there may be multiple matches here + if song['TrackID'] and song['AlbumTitle']: + + # Even though there might be multiple matches, we just need to grab one to confirm a match + track = myDB.action('SELECT TrackID, AlbumTitle from alltracks WHERE TrackID=? AND AlbumTitle LIKE ?', [song['TrackID'], song['AlbumTitle']]).fetchone() + + if not track: + track = myDB.action('SELECT TrackID, AlbumTitle from tracks WHERE TrackID=? 
AND AlbumTitle LIKE ?', [song['TrackID'], song['AlbumTitle']]).fetchone() + + if track: + # Don't need the hybridControlValueDict here since ReleaseID is not unique + controlValueDict = { 'TrackID' : track['TrackID'], + 'AlbumTitle' : track['AlbumTitle'] } + + newValueDict = { 'Location' : song['Location'], + 'BitRate' : song['BitRate'], + 'Format' : song['Format'] } + + myDB.upsert("alltracks", newValueDict, controlValueDict) + myDB.upsert("tracks", newValueDict, controlValueDict) + + continue + + # Next most specific is the ArtistName + AlbumTitle + TrackTitle combo (but probably + # even more unreliable than the previous queries, and might span multiple releases) + if song['ArtistName'] and song['AlbumTitle'] and song['TrackTitle']: + + track = myDB.action('SELECT ArtistName, AlbumTitle, TrackTitle from alltracks WHERE ArtistName LIKE ? AND AlbumTitle LIKE ? AND TrackTitle LIKE ?', [song['ArtistName'], song['AlbumTitle'], song['TrackTitle']]).fetchone() + + if not track: + track = myDB.action('SELECT ArtistName, AlbumTitle, TrackTitle from tracks WHERE ArtistName LIKE ? AND AlbumTitle LIKE ? 
AND TrackTitle LIKE ?', [song['ArtistName'], song['AlbumTitle'], song['TrackTitle']]).fetchone() + + if track: + controlValueDict = { 'ArtistName' : track['ArtistName'], + 'AlbumTitle' : track['AlbumTitle'], + 'TrackTitle' : track['TrackTitle'] } + + newValueDict = { 'Location' : song['Location'], + 'BitRate' : song['BitRate'], + 'Format' : song['Format'] } + + myDB.upsert("alltracks", newValueDict, controlValueDict) + myDB.upsert("tracks", newValueDict, controlValueDict) + + continue + + # Use the "CleanName" (ArtistName + AlbumTitle + TrackTitle stripped of punctuation, capitalization, etc) + # This is more reliable than the former but requires some string manipulation so we'll do it only + # if we can't find a match with the original data + if song['ArtistName'] and song['AlbumTitle'] and song['TrackTitle']: + + CleanName = helpers.cleanName(song['ArtistName'] +' '+ song['AlbumTitle'] +' '+song['TrackTitle']) + + track = myDB.action('SELECT CleanName from alltracks WHERE CleanName LIKE ?', [CleanName]).fetchone() + + if not track: + track = myDB.action('SELECT CleanName from tracks WHERE CleanName LIKE ?', [CleanName]).fetchone() + + if track: + controlValueDict = { 'CleanName' : track['CleanName'] } + + newValueDict = { 'Location' : song['Location'], + 'BitRate' : song['BitRate'], + 'Format' : song['Format'] } + + myDB.upsert("alltracks", newValueDict, controlValueDict) + myDB.upsert("tracks", newValueDict, controlValueDict) + + continue + + # Match on TrackID alone if we can't find it using any of the above methods. This method is reliable + # but spans multiple releases - but that's why we're putting at the beginning as a last resort. 
If a track + # with more specific information exists in the library, it'll overwrite these values + if song['TrackID']: + + track = myDB.action('SELECT TrackID from alltracks WHERE TrackID=?', [song['TrackID']]).fetchone() + + if not track: + track = myDB.action('SELECT TrackID from tracks WHERE TrackID=?', [song['TrackID']]).fetchone() + + if track: + controlValueDict = { 'TrackID' : track['TrackID'] } + + newValueDict = { 'Location' : song['Location'], + 'BitRate' : song['BitRate'], + 'Format' : song['Format'] } + + myDB.upsert("alltracks", newValueDict, controlValueDict) + myDB.upsert("tracks", newValueDict, controlValueDict) + + continue + + # if we can't find a match in the database on a track level, it might be a new artist or it might be on a non-mb release + new_artists.append(song['ArtistName']) + + # The have table will become the new database for unmatched tracks (i.e. tracks with no associated links in the database + CleanName = helpers.cleanName(song['ArtistName'] +' '+ song['AlbumTitle'] +' '+song['TrackTitle']) + + myDB.action('INSERT INTO have (ArtistName, AlbumTitle, TrackNumber, TrackTitle, TrackLength, BitRate, Genre, Date, TrackID, Location, CleanName, Format) VALUES( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)', [song['ArtistName'], song['AlbumTitle'], song['TrackNumber'], song['TrackTitle'], song['TrackLength'], song['BitRate'], song['Genre'], song['Date'], song['TrackID'], song['Location'], CleanName, song['Format']]) + + logger.info('Completed matching tracks from directory: %s' % dir) # Clean up the new artist list unique_artists = {}.fromkeys(new_artists).keys() @@ -115,9 +297,11 @@ def libraryScan(dir=None): artist_list = [f for f in unique_artists if f.lower() not in [x[0].lower() for x in current_artists]] # Update track counts - logger.info('Updating track counts') + logger.info('Updating current artist track counts') for artist in current_artists: + # Have tracks are selected from tracks table and not all tracks because of duplicates + # We 
update the track count upon an album switch to compliment this havetracks = len(myDB.select('SELECT TrackTitle from tracks WHERE ArtistID like ? AND Location IS NOT NULL', [artist['ArtistID']])) + len(myDB.select('SELECT TrackTitle from have WHERE ArtistName like ?', [artist['ArtistName']])) myDB.action('UPDATE artists SET HaveTracks=? WHERE ArtistID=?', [havetracks, artist['ArtistID']]) From 85c1bf2a3ea71d18dc2fc0c42411e3a12582d439 Mon Sep 17 00:00:00 2001 From: Ade Date: Wed, 8 Aug 2012 20:47:16 +1200 Subject: [PATCH 47/84] Search for Torrent if NZB found but out of retention --- headphones/searcher.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/headphones/searcher.py b/headphones/searcher.py index 85b699e5..50068271 100644 --- a/headphones/searcher.py +++ b/headphones/searcher.py @@ -496,6 +496,8 @@ def searchNZB(albumid=None, new=False, losslessOnly=False): myDB.action('UPDATE albums SET status = "Snatched" WHERE AlbumID=?', [albums[2]]) myDB.action('INSERT INTO snatched VALUES( ?, ?, ?, ?, DATETIME("NOW", "localtime"), ?, ?)', [albums[2], bestqual[0], bestqual[1], bestqual[2], "Snatched", nzb_folder_name]) return "found" + else: + return "none" else: return "none" @@ -673,7 +675,9 @@ def searchTorrent(albumid=None, new=False, losslessOnly=False): data = False if data: - + + logger.info(u'Parsing results from KAT' % searchURL) + d = feedparser.parse(data) if not len(d.entries): logger.info(u"No results found from %s for %s" % (provider, term)) @@ -707,7 +711,7 @@ def searchTorrent(albumid=None, new=False, losslessOnly=False): resultlist.append((title, size, url, provider)) logger.info('Found %s. Size: %s' % (title, helpers.bytes_to_mb(size))) else: - logger.info('%s is larger than the maxsize, the wrong format or has to little seeders for this category, skipping. 
(Size: %i bytes, Seeders: %i, Format: %s)' % (title, size, int(seeders), rightformat)) + logger.info('%s is larger than the maxsize, the wrong format or has too little seeders for this category, skipping. (Size: %i bytes, Seeders: %i, Format: %s)' % (title, size, int(seeders), rightformat)) except Exception, e: logger.error(u"An unknown error occurred in the KAT parser: %s" % e) @@ -754,7 +758,9 @@ def searchTorrent(albumid=None, new=False, losslessOnly=False): data = False if data: - + + logger.info(u'Parsing results from Waffles.fm' % searchURL) + d = feedparser.parse(data) if not len(d.entries): logger.info(u"No results found from %s for %s" % (provider, term)) @@ -780,7 +786,7 @@ def searchTorrent(albumid=None, new=False, losslessOnly=False): if headphones.ISOHUNT: - provider = "ISOhunt" + provider = "isoHunt" providerurl = url_fix("http://isohunt.com/js/rss/" + term) if headphones.PREFERRED_QUALITY == 3 or losslessOnly: categories = "7" #music @@ -809,6 +815,8 @@ def searchTorrent(albumid=None, new=False, losslessOnly=False): if data: + logger.info(u'Parsing results from isoHunt' % searchURL) + d = feedparser.parse(data) if not len(d.entries): logger.info(u"No results found from %s for %s" % (provider, term)) @@ -846,10 +854,10 @@ def searchTorrent(albumid=None, new=False, losslessOnly=False): resultlist.append((title, size, url, provider)) logger.info('Found %s. Size: %s' % (title, helpers.bytes_to_mb(size))) else: - logger.info('%s is larger than the maxsize, the wrong format or has to little seeders for this category, skipping. (Size: %i bytes, Seeders: %i, Format: %s)' % (title, size, int(seeds), rightformat)) + logger.info('%s is larger than the maxsize, the wrong format or has too little seeders for this category, skipping. 
(Size: %i bytes, Seeders: %i, Format: %s)' % (title, size, int(seeds), rightformat)) except Exception, e: - logger.error(u"An unknown error occurred in the ISOhunt parser: %s" % e) + logger.error(u"An unknown error occurred in the isoHunt parser: %s" % e) if headphones.MININOVA: provider = "Mininova" @@ -877,6 +885,8 @@ def searchTorrent(albumid=None, new=False, losslessOnly=False): if data: + logger.info(u'Parsing results from Mininova' % searchURL) + d = feedparser.parse(data) if not len(d.entries): logger.info(u"No results found from %s for %s" % (provider, term)) @@ -913,10 +923,10 @@ def searchTorrent(albumid=None, new=False, losslessOnly=False): resultlist.append((title, size, url, provider)) logger.info('Found %s. Size: %s' % (title, helpers.bytes_to_mb(size))) else: - logger.info('%s is larger than the maxsize, the wrong format or has to little seeders for this category, skipping. (Size: %i bytes, Seeders: %i, Format: %s)' % (title, size, int(seeds), rightformat)) + logger.info('%s is larger than the maxsize, the wrong format or has too little seeders for this category, skipping. 
(Size: %i bytes, Seeders: %i, Format: %s)' % (title, size, int(seeds), rightformat)) except Exception, e: - logger.error(u"An unknown error occurred in the MiniNova Parser: %s" % e) + logger.error(u"An unknown error occurred in the Mininova Parser: %s" % e) @@ -983,7 +993,7 @@ def searchTorrent(albumid=None, new=False, losslessOnly=False): (data, bestqual) = preprocesstorrent(torrentlist) if data and bestqual: - logger.info(u'Found best result: %s - %s' % (bestqual[2], bestqual[0], helpers.bytes_to_mb(bestqual[1]))) + logger.info(u'Found best result from %s: %s - %s' % (bestqual[3], bestqual[2], bestqual[0], helpers.bytes_to_mb(bestqual[1]))) torrent_folder_name = '%s - %s [%s]' % (helpers.latinToAscii(albums[0]).encode('UTF-8').replace('/', '_'), helpers.latinToAscii(albums[1]).encode('UTF-8').replace('/', '_'), year) if headphones.TORRENTBLACKHOLE_DIR == "sendtracker": From ffb4798d04f9dc7d5fd81896ab40021ebcbbb5a8 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Thu, 16 Aug 2012 23:09:27 +0530 Subject: [PATCH 48/84] Fix for last.fm query hanging when updating artist info --- headphones/lastfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/headphones/lastfm.py b/headphones/lastfm.py index a55b6c8f..ce77948b 100644 --- a/headphones/lastfm.py +++ b/headphones/lastfm.py @@ -42,7 +42,7 @@ def getSimilar(): time.sleep(1) continue - if len(data) < 200: + if not data or len(data) < 200: continue try: From ffeb4383372851317262132d7d69eab0d2a96390 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Thu, 16 Aug 2012 23:50:32 +0530 Subject: [PATCH 49/84] Modified library scan to allow single directories to be appended to the library; postprocessor now used that function to update track counts after post processing is complete --- headphones/librarysync.py | 87 ++++++++++++++++++++++--------------- headphones/postprocessor.py | 49 ++------------------- 2 files changed, 54 insertions(+), 82 deletions(-) diff --git a/headphones/librarysync.py 
b/headphones/librarysync.py index b7484b1c..3f31d4c0 100644 --- a/headphones/librarysync.py +++ b/headphones/librarysync.py @@ -21,12 +21,16 @@ from lib.beets.mediafile import MediaFile import headphones from headphones import db, logger, helpers, importer -def libraryScan(dir=None): +# You can scan a single directory and append it to the current library by specifying append=True, ArtistID & ArtistName +def libraryScan(dir=None, append=False, ArtistID=None, ArtistName=None): if not dir: dir = headphones.MUSIC_DIR - - dir = dir.encode(headphones.SYS_ENCODING) + + # If we're appending a dir, it's coming from the post processor which is + # already bytestring + if not append: + dir = dir.encode(headphones.SYS_ENCODING) if not os.path.isdir(dir): logger.warn('Cannot find directory: %s. Not scanning' % dir.decode(headphones.SYS_ENCODING)) @@ -34,12 +38,15 @@ def libraryScan(dir=None): myDB = db.DBConnection() - # Clean up bad filepaths - tracks = myDB.select('SELECT Location, TrackID from tracks WHERE Location IS NOT NULL') + if not append: + # Clean up bad filepaths + tracks = myDB.select('SELECT Location, TrackID from tracks WHERE Location IS NOT NULL') - for track in tracks: - if not os.path.isfile(track['Location'].encode(headphones.SYS_ENCODING)): - myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE TrackID=?', [None, None, None, track['TrackID']]) + for track in tracks: + if not os.path.isfile(track['Location'].encode(headphones.SYS_ENCODING)): + myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? 
WHERE TrackID=?', [None, None, None, track['TrackID']]) + + myDB.action('DELETE from have') logger.info('Scanning music directory: %s' % dir) @@ -47,8 +54,6 @@ def libraryScan(dir=None): bitrates = [] song_list = [] - - myDB.action('DELETE from have') for r,d,f in os.walk(dir): for files in f: @@ -290,32 +295,42 @@ def libraryScan(dir=None): logger.info('Completed matching tracks from directory: %s' % dir) - # Clean up the new artist list - unique_artists = {}.fromkeys(new_artists).keys() - current_artists = myDB.select('SELECT ArtistName, ArtistID from artists') - artist_list = [f for f in unique_artists if f.lower() not in [x[0].lower() for x in current_artists]] - - # Update track counts - logger.info('Updating current artist track counts') - - for artist in current_artists: - # Have tracks are selected from tracks table and not all tracks because of duplicates - # We update the track count upon an album switch to compliment this - havetracks = len(myDB.select('SELECT TrackTitle from tracks WHERE ArtistID like ? AND Location IS NOT NULL', [artist['ArtistID']])) + len(myDB.select('SELECT TrackTitle from have WHERE ArtistName like ?', [artist['ArtistName']])) - myDB.action('UPDATE artists SET HaveTracks=? 
WHERE ArtistID=?', [havetracks, artist['ArtistID']]) + if not append: + # Clean up the new artist list + unique_artists = {}.fromkeys(new_artists).keys() + current_artists = myDB.select('SELECT ArtistName, ArtistID from artists') - logger.info('Found %i new artists' % len(artist_list)) - - if len(artist_list): - if headphones.ADD_ARTISTS: - logger.info('Importing %i new artists' % len(artist_list)) - importer.artistlist_to_mbids(artist_list) - else: - logger.info('To add these artists, go to Manage->Manage New Artists') - myDB.action('DELETE from newartists') - for artist in artist_list: - myDB.action('INSERT into newartists VALUES (?)', [artist]) + artist_list = [f for f in unique_artists if f.lower() not in [x[0].lower() for x in current_artists]] + + # Update track counts + logger.info('Updating current artist track counts') + + for artist in current_artists: + # Have tracks are selected from tracks table and not all tracks because of duplicates + # We update the track count upon an album switch to compliment this + havetracks = len(myDB.select('SELECT TrackTitle from tracks WHERE ArtistID=? AND Location IS NOT NULL', [artist['ArtistID']])) + len(myDB.select('SELECT TrackTitle from have WHERE ArtistName like ?', [artist['ArtistName']])) + myDB.action('UPDATE artists SET HaveTracks=? 
WHERE ArtistID=?', [havetracks, artist['ArtistID']]) + + logger.info('Found %i new artists' % len(artist_list)) + + if len(artist_list): + if headphones.ADD_ARTISTS: + logger.info('Importing %i new artists' % len(artist_list)) + importer.artistlist_to_mbids(artist_list) + else: + logger.info('To add these artists, go to Manage->Manage New Artists') + myDB.action('DELETE from newartists') + for artist in artist_list: + myDB.action('INSERT into newartists VALUES (?)', [artist]) + + if headphones.DETECT_BITRATE: + headphones.PREFERRED_BITRATE = sum(bitrates)/len(bitrates)/1000 + + else: + # If we're appending a new album to the database, update the artists total track counts + logger.info('Updating artist track counts') + + havetracks = len(myDB.select('SELECT TrackTitle from tracks WHERE ArtistID=? AND Location IS NOT NULL', [ArtistID])) + len(myDB.select('SELECT TrackTitle from have WHERE ArtistName like ?', [ArtistName])) + myDB.action('UPDATE artists SET HaveTracks=? WHERE ArtistID=?', [havetracks, ArtistID]) - if headphones.DETECT_BITRATE: - headphones.PREFERRED_BITRATE = sum(bitrates)/len(bitrates)/1000 diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index 6a90c706..3c855494 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -26,7 +26,7 @@ from lib.beets import autotag from lib.beets.mediafile import MediaFile import headphones -from headphones import db, albumart, lyrics, logger, helpers +from headphones import db, albumart, librarysync, lyrics, logger, helpers postprocessor_lock = threading.Lock() @@ -269,21 +269,8 @@ def doPostProcessing(albumid, albumpath, release, tracks, downloaded_track_list) logger.error('No DESTINATION_DIR has been set. 
Set "Destination Directory" to the parent directory you want to move the files to') pass - myDB = db.DBConnection() - # There's gotta be a better way to update the have tracks - sqlite - - trackcount = myDB.select('SELECT HaveTracks from artists WHERE ArtistID=?', [release['ArtistID']]) - - if not trackcount[0][0]: - cur_track_count = 0 - else: - cur_track_count = trackcount[0][0] - - new_track_count = cur_track_count + len(downloaded_track_list) - myDB.action('UPDATE artists SET HaveTracks=? WHERE ArtistID=?', [new_track_count, release['ArtistID']]) - myDB.action('UPDATE albums SET status = "Downloaded" WHERE AlbumID=?', [albumid]) - myDB.action('UPDATE snatched SET status = "Processed" WHERE AlbumID=?', [albumid]) - updateHave(albumpath) + # Update the have tracks + librarysync.libraryScan(dir=albumpath, append=True, ArtistID=release['ArtistID'], ArtistName=release['ArtistName']) logger.info('Post-processing for %s - %s complete' % (release['ArtistName'], release['AlbumTitle'])) @@ -577,36 +564,6 @@ def renameFiles(albumpath, downloaded_track_list, release): except Exception, e: logger.error('Error renaming file: %s. Error: %s' % (downloaded_track, e)) continue - -def updateHave(albumpath): - - results = [] - - for r,d,f in os.walk(albumpath): - for files in f: - if any(files.lower().endswith('.' + x.lower()) for x in headphones.MEDIA_FORMATS): - results.append(os.path.join(r, files)) - - if results: - - myDB = db.DBConnection() - - for song in results: - try: - f = MediaFile(song) - #logger.debug('Reading: %s' % song.decode('UTF-8')) - except: - logger.warn('Could not read file: %s' % song) - continue - else: - if f.albumartist: - artist = f.albumartist - elif f.artist: - artist = f.artist - else: - continue - - myDB.action('UPDATE tracks SET Location=?, BitRate=?, Format=? WHERE ArtistName LIKE ? AND AlbumTitle LIKE ? 
AND TrackTitle LIKE ?', [unicode(song, headphones.SYS_ENCODING, errors="replace"), f.bitrate, f.format, artist, f.album, f.title]) def renameUnprocessedFolder(albumpath): From 0039a93736c3f01cd15c373317d723431bea5e0e Mon Sep 17 00:00:00 2001 From: rembo10 Date: Fri, 17 Aug 2012 00:28:56 +0530 Subject: [PATCH 50/84] Accidentally removed marking albums as Downloaded after being post processed --- headphones/postprocessor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index 3c855494..cd7fcadc 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -269,6 +269,10 @@ def doPostProcessing(albumid, albumpath, release, tracks, downloaded_track_list) logger.error('No DESTINATION_DIR has been set. Set "Destination Directory" to the parent directory you want to move the files to') pass + myDB = db.DBConnection() + myDB.action('UPDATE albums SET status = "Downloaded" WHERE AlbumID=?', [albumid]) + myDB.action('UPDATE snatched SET status = "Processed" WHERE AlbumID=?', [albumid]) + # Update the have tracks librarysync.libraryScan(dir=albumpath, append=True, ArtistID=release['ArtistID'], ArtistName=release['ArtistName']) From e2911e4f2b8ad34219a7deb95694de429672c205 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Fri, 17 Aug 2012 01:00:28 +0530 Subject: [PATCH 51/84] Added config options for LOSSLESS_DESTINATION_DIR and DELETE_LOSSLESS_FILES --- headphones/__init__.py | 13 ++++++++++--- headphones/webserve.py | 10 +++++++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/headphones/__init__.py b/headphones/__init__.py index e61cb103..2b938e02 100644 --- a/headphones/__init__.py +++ b/headphones/__init__.py @@ -84,6 +84,7 @@ CHECK_GITHUB_INTERVAL = None MUSIC_DIR = None DESTINATION_DIR = None +LOSSLESS_DESTINATION_DIR = None FOLDER_FORMAT = None FILE_FORMAT = None PATH_TO_XML = None @@ -165,6 +166,7 @@ ENCODEROUTPUTFORMAT = None ENCODERQUALITY = None ENCODERVBRCBR = None 
ENCODERLOSSLESS = False +DELETE_LOSSLESS_FILES = False PROWL_ENABLED = True PROWL_PRIORITY = 1 PROWL_KEYS = None @@ -238,7 +240,8 @@ def initialize(): global __INITIALIZED__, FULL_PATH, PROG_DIR, VERBOSE, DAEMON, DATA_DIR, CONFIG_FILE, CFG, CONFIG_VERSION, LOG_DIR, CACHE_DIR, \ HTTP_PORT, HTTP_HOST, HTTP_USERNAME, HTTP_PASSWORD, HTTP_ROOT, HTTP_PROXY, LAUNCH_BROWSER, API_ENABLED, API_KEY, GIT_PATH, \ - CURRENT_VERSION, LATEST_VERSION, CHECK_GITHUB, CHECK_GITHUB_ON_STARTUP, CHECK_GITHUB_INTERVAL, MUSIC_DIR, DESTINATION_DIR, PREFERRED_QUALITY, PREFERRED_BITRATE, DETECT_BITRATE, \ + CURRENT_VERSION, LATEST_VERSION, CHECK_GITHUB, CHECK_GITHUB_ON_STARTUP, CHECK_GITHUB_INTERVAL, MUSIC_DIR, DESTINATION_DIR, \ + LOSSLESS_DESTINATION_DIR, PREFERRED_QUALITY, PREFERRED_BITRATE, DETECT_BITRATE, \ ADD_ARTISTS, CORRECT_METADATA, MOVE_FILES, RENAME_FILES, FOLDER_FORMAT, FILE_FORMAT, CLEANUP_FILES, INCLUDE_EXTRAS, AUTOWANT_UPCOMING, AUTOWANT_ALL, \ ADD_ALBUM_ART, EMBED_ALBUM_ART, EMBED_LYRICS, DOWNLOAD_DIR, BLACKHOLE, BLACKHOLE_DIR, USENET_RETENTION, SEARCH_INTERVAL, \ TORRENTBLACKHOLE_DIR, NUMBEROFSEEDERS, ISOHUNT, KAT, MININOVA, WAFFLES, WAFFLES_UID, WAFFLES_PASSKEY, DOWNLOAD_TORRENT_DIR, \ @@ -246,7 +249,7 @@ def initialize(): NZBMATRIX, NZBMATRIX_USERNAME, NZBMATRIX_APIKEY, NEWZNAB, NEWZNAB_HOST, NEWZNAB_APIKEY, NEWZNAB_ENABLED, EXTRA_NEWZNABS,\ NZBSORG, NZBSORG_UID, NZBSORG_HASH, NEWZBIN, NEWZBIN_UID, NEWZBIN_PASSWORD, LASTFM_USERNAME, INTERFACE, FOLDER_PERMISSIONS, \ ENCODERFOLDER, ENCODER, BITRATE, SAMPLINGFREQUENCY, MUSIC_ENCODER, ADVANCEDENCODER, ENCODEROUTPUTFORMAT, ENCODERQUALITY, ENCODERVBRCBR, \ - ENCODERLOSSLESS, PROWL_ENABLED, PROWL_PRIORITY, PROWL_KEYS, PROWL_ONSNATCH, MIRRORLIST, MIRROR, CUSTOMHOST, CUSTOMPORT, \ + ENCODERLOSSLESS, DELETE_LOSSLESS_FILES, PROWL_ENABLED, PROWL_PRIORITY, PROWL_KEYS, PROWL_ONSNATCH, MIRRORLIST, MIRROR, CUSTOMHOST, CUSTOMPORT, \ CUSTOMSLEEP, HPUSER, HPPASS, XBMC_ENABLED, XBMC_HOST, XBMC_USERNAME, XBMC_PASSWORD, XBMC_UPDATE, 
XBMC_NOTIFY, NMA_ENABLED, NMA_APIKEY, NMA_PRIORITY, SYNOINDEX_ENABLED, \ ALBUM_COMPLETION_PCT @@ -295,6 +298,7 @@ def initialize(): MUSIC_DIR = check_setting_str(CFG, 'General', 'music_dir', '') DESTINATION_DIR = check_setting_str(CFG, 'General', 'destination_dir', '') + LOSSLESS_DESTINATION_DIR = check_setting_str(CFG, 'General', 'lossless_destination_dir', '') PREFERRED_QUALITY = check_setting_int(CFG, 'General', 'preferred_quality', 0) PREFERRED_BITRATE = check_setting_int(CFG, 'General', 'preferred_bitrate', '') DETECT_BITRATE = bool(check_setting_int(CFG, 'General', 'detect_bitrate', 0)) @@ -373,6 +377,7 @@ def initialize(): ENCODERQUALITY = check_setting_int(CFG, 'General', 'encoderquality', 2) ENCODERVBRCBR = check_setting_str(CFG, 'General', 'encodervbrcbr', 'cbr') ENCODERLOSSLESS = bool(check_setting_int(CFG, 'General', 'encoderlossless', 1)) + DELETE_LOSSLESS_FILES = bool(check_setting_int(CFG, 'General', 'delete_lossless_files', 1)) PROWL_ENABLED = bool(check_setting_int(CFG, 'Prowl', 'prowl_enabled', 0)) PROWL_KEYS = check_setting_str(CFG, 'Prowl', 'prowl_keys', '') @@ -570,6 +575,7 @@ def config_write(): new_config['General']['music_dir'] = MUSIC_DIR new_config['General']['destination_dir'] = DESTINATION_DIR + new_config['General']['lossless_destination_dir'] = LOSSLESS_DESTINATION_DIR new_config['General']['preferred_quality'] = PREFERRED_QUALITY new_config['General']['preferred_bitrate'] = PREFERRED_BITRATE new_config['General']['detect_bitrate'] = int(DETECT_BITRATE) @@ -677,7 +683,8 @@ def config_write(): new_config['General']['encoderoutputformat'] = ENCODEROUTPUTFORMAT new_config['General']['encoderquality'] = ENCODERQUALITY new_config['General']['encodervbrcbr'] = ENCODERVBRCBR - new_config['General']['encoderlossless'] = ENCODERLOSSLESS + new_config['General']['encoderlossless'] = int(ENCODERLOSSLESS) + new_config['General']['delete_lossless_files'] = int(DELETE_LOSSLESS_FILES) new_config['General']['mirror'] = MIRROR 
new_config['General']['customhost'] = CUSTOMHOST diff --git a/headphones/webserve.py b/headphones/webserve.py index 2f88c041..d00c1f86 100644 --- a/headphones/webserve.py +++ b/headphones/webserve.py @@ -429,6 +429,7 @@ class WebInterface(object): "embed_album_art" : checked(headphones.EMBED_ALBUM_ART), "embed_lyrics" : checked(headphones.EMBED_LYRICS), "dest_dir" : headphones.DESTINATION_DIR, + "lossless_dest_dir" : headphones.LOSSLESS_DESTINATION_DIR, "folder_format" : headphones.FOLDER_FORMAT, "file_format" : headphones.FILE_FORMAT, "include_extras" : checked(headphones.INCLUDE_EXTRAS), @@ -446,6 +447,7 @@ class WebInterface(object): "encodervbrcbr": headphones.ENCODERVBRCBR, "encoderquality": headphones.ENCODERQUALITY, "encoderlossless": checked(headphones.ENCODERLOSSLESS), + "delete_lossless_files": checked(headphones.DELETE_LOSSLESS_FILES), "prowl_enabled": checked(headphones.PROWL_ENABLED), "prowl_onsnatch": checked(headphones.PROWL_ONSNATCH), "prowl_keys": headphones.PROWL_KEYS, @@ -477,8 +479,8 @@ class WebInterface(object): usenet_retention=None, nzbmatrix=0, nzbmatrix_username=None, nzbmatrix_apikey=None, newznab=0, newznab_host=None, newznab_apikey=None, newznab_enabled=0, nzbsorg=0, nzbsorg_uid=None, nzbsorg_hash=None, newzbin=0, newzbin_uid=None, newzbin_password=None, preferred_quality=0, preferred_bitrate=None, detect_bitrate=0, move_files=0, torrentblackhole_dir=None, download_torrent_dir=None, numberofseeders=10, use_isohunt=0, use_kat=0, use_mininova=0, waffles=0, waffles_uid=None, waffles_passkey=None, - rename_files=0, correct_metadata=0, cleanup_files=0, add_album_art=0, embed_album_art=0, embed_lyrics=0, destination_dir=None, folder_format=None, file_format=None, include_extras=0, autowant_upcoming=False, autowant_all=False, interface=None, log_dir=None, - music_encoder=0, encoder=None, bitrate=None, samplingfrequency=None, encoderfolder=None, advancedencoder=None, encoderoutputformat=None, encodervbrcbr=None, encoderquality=None, 
encoderlossless=0, + rename_files=0, correct_metadata=0, cleanup_files=0, add_album_art=0, embed_album_art=0, embed_lyrics=0, destination_dir=None, lossless_destination_dir=None, folder_format=None, file_format=None, include_extras=0, autowant_upcoming=False, autowant_all=False, interface=None, log_dir=None, + music_encoder=0, encoder=None, bitrate=None, samplingfrequency=None, encoderfolder=None, advancedencoder=None, encoderoutputformat=None, encodervbrcbr=None, encoderquality=None, encoderlossless=0, delete_lossless_files=0, prowl_enabled=0, prowl_onsnatch=0, prowl_keys=None, prowl_priority=0, xbmc_enabled=0, xbmc_host=None, xbmc_username=None, xbmc_password=None, xbmc_update=0, xbmc_notify=0, nma_enabled=False, nma_apikey=None, nma_priority=0, synoindex_enabled=False, mirror=None, customhost=None, customport=None, customsleep=None, hpuser=None, hppass=None, **kwargs): @@ -534,6 +536,7 @@ class WebInterface(object): headphones.EMBED_ALBUM_ART = embed_album_art headphones.EMBED_LYRICS = embed_lyrics headphones.DESTINATION_DIR = destination_dir + headphones.LOSSLESS_DESTINATION_DIR = lossless_destination_dir headphones.FOLDER_FORMAT = folder_format headphones.FILE_FORMAT = file_format headphones.INCLUDE_EXTRAS = include_extras @@ -550,7 +553,8 @@ class WebInterface(object): headphones.ENCODEROUTPUTFORMAT = encoderoutputformat headphones.ENCODERVBRCBR = encodervbrcbr headphones.ENCODERQUALITY = int(encoderquality) - headphones.ENCODERLOSSLESS = encoderlossless + headphones.ENCODERLOSSLESS = int(encoderlossless) + headphones.DELETE_LOSSLESS_FILES = int(delete_lossless_files) headphones.PROWL_ENABLED = prowl_enabled headphones.PROWL_ONSNATCH = prowl_onsnatch headphones.PROWL_KEYS = prowl_keys From 1ec58533f01338f34fc64098312cebae43d0dd26 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Fri, 17 Aug 2012 01:11:11 +0530 Subject: [PATCH 52/84] Added lossless_destination_dir and delete_lossless_files to the config page (default template only) --- 
data/interfaces/default/config.html | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/data/interfaces/default/config.html b/data/interfaces/default/config.html index 216e4eb6..3dfa2420 100644 --- a/data/interfaces/default/config.html +++ b/data/interfaces/default/config.html @@ -342,10 +342,15 @@ m<%inherit file="base.html"/>
- + e.g. /Users/name/Music/iTunes or /Volumes/share/music
+
+ + + Set this if you have a separate directory for lossless music +
@@ -379,8 +384,14 @@ m<%inherit file="base.html"/>
+
-
+
+
+
+ +
+
<% if config['encoder'] == 'lame': lameselect = 'selected="selected"' From f2674b55a1651735cba006c0948c3ba683023afa Mon Sep 17 00:00:00 2001 From: rembo10 Date: Fri, 17 Aug 2012 16:07:22 +0530 Subject: [PATCH 53/84] Backend stuff done to get separate lossless & lossy directories working --- headphones/helpers.py | 36 ++++- headphones/music_encoder.py | 3 +- headphones/postprocessor.py | 271 ++++++++++++++++++++++++------------ 3 files changed, 219 insertions(+), 91 deletions(-) diff --git a/headphones/helpers.py b/headphones/helpers.py index 1dbfbcda..7f5cbab9 100644 --- a/headphones/helpers.py +++ b/headphones/helpers.py @@ -13,10 +13,10 @@ # You should have received a copy of the GNU General Public License # along with Headphones. If not, see . -import time +import os, time from operator import itemgetter import datetime -import re +import re, shutil import headphones @@ -218,3 +218,35 @@ def extract_song_data(s): else: logger.info("Couldn't parse " + s + " into a valid Newbin format") return (name, album, year) + +def smartMove(src, dest, delete=True): + + source_dir = os.path.dirname(src) + filename = os.path.basename(src) + + if os.path.isfile(os.path.join(dest, filename)): + logger.info('Destination file exists: %s' % os.path.join(dest, filename).decode(headphones.SYS_ENCODING)) + title = os.path.splitext(filename)[0] + ext = os.path.splitext(filename)[1] + i = 1 + while True: + newfile = title + '(' + str(i) + ')' + ext + if os.path.isfile(os.path.join(dest, newfile)): + i += 1 + else: + logger.info('Renaming to %s' % newfile) + try: + os.rename(src, os.path.join(source_dir, newfile)) + filename = newfile + except Exception, e: + logger.warn('Error renaming %s: %s' % (src.decode(headphones.SYS_ENCODING), e)) + break + + try: + if delete: + shutil.move(os.path.join(source_dir, filename), os.path.join(dest, filename)) + else: + shutil.copy(os.path.join(source_dir, filename), os.path.join(dest, filename)) + return True + except Exception, e: + 
logger.warn('Error moving file %s: %s' % (filename.decode(headphones.SYS_ENCODING), e)) diff --git a/headphones/music_encoder.py b/headphones/music_encoder.py index 51864aa4..71da9ba5 100644 --- a/headphones/music_encoder.py +++ b/headphones/music_encoder.py @@ -138,7 +138,8 @@ def command(encoder,musicSource,musicDest,albumPath): time.sleep(10) return_code = call(cmd, shell=True) if (return_code==0) and (os.path.exists(musicDest)): - os.remove(musicSource) + if headphones.DELETE_LOSSLESS_FILES: + os.remove(musicSource) shutil.move(musicDest,albumPath) logger.info('Music "%s" encoded in %s' % (musicSource,getTimeEncode(startMusicTime))) diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index cd7fcadc..7959dd64 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -263,7 +263,7 @@ def doPostProcessing(albumid, albumpath, release, tracks, downloaded_track_list) renameFiles(albumpath, downloaded_track_list, release) if headphones.MOVE_FILES and headphones.DESTINATION_DIR: - albumpath = moveFiles(albumpath, release, tracks) + albumpaths = moveFiles(albumpath, release, tracks) if headphones.MOVE_FILES and not headphones.DESTINATION_DIR: logger.error('No DESTINATION_DIR has been set. 
Set "Destination Directory" to the parent directory you want to move the files to') @@ -273,8 +273,9 @@ def doPostProcessing(albumid, albumpath, release, tracks, downloaded_track_list) myDB.action('UPDATE albums SET status = "Downloaded" WHERE AlbumID=?', [albumid]) myDB.action('UPDATE snatched SET status = "Processed" WHERE AlbumID=?', [albumid]) - # Update the have tracks - librarysync.libraryScan(dir=albumpath, append=True, ArtistID=release['ArtistID'], ArtistName=release['ArtistName']) + # Update the have tracks for all created dirs: + for albumpath in albumpaths: + librarysync.libraryScan(dir=albumpath, append=True, ArtistID=release['ArtistID'], ArtistName=release['ArtistName']) logger.info('Post-processing for %s - %s complete' % (release['ArtistName'], release['AlbumTitle'])) @@ -297,7 +298,8 @@ def doPostProcessing(albumid, albumpath, release, tracks, downloaded_track_list) if headphones.SYNOINDEX_ENABLED: syno = notifiers.Synoindex() - syno.notify(albumpath) + for albumpath in albumpaths: + syno.notify(albumpath) def embedAlbumArt(artwork, downloaded_track_list): logger.info('Embedding album art') @@ -375,71 +377,138 @@ def moveFiles(albumpath, release, tracks): if folder.startswith('.'): folder = folder.replace(0, '_') + + # Grab our list of files early on so we can determine if we need to create + # the lossy_dest_dir, lossless_dest_dir, or both + files_to_move = [] + lossy_media = False + lossless_media = False - destination_path = os.path.normpath(os.path.join(headphones.DESTINATION_DIR, folder)).encode(headphones.SYS_ENCODING) - - last_folder = headphones.FOLDER_FORMAT.split('/')[-1] - - # Only rename the folder if they use the album name, otherwise merge into existing folder - if os.path.exists(destination_path) and 'album' in last_folder.lower(): - i = 1 - while True: - newfolder = folder + '[%i]' % i - destination_path = os.path.normpath(os.path.join(headphones.DESTINATION_DIR, newfolder)).encode(headphones.SYS_ENCODING) - if 
os.path.exists(destination_path): - i += 1 - else: - folder = newfolder - break - - logger.info('Moving files from %s to %s' % (unicode(albumpath, headphones.SYS_ENCODING, errors="replace"), unicode(destination_path, headphones.SYS_ENCODING, errors="replace"))) - - # Basically check if generic/non-album folders already exist, since we're going to merge - if not os.path.exists(destination_path): - try: - os.makedirs(destination_path) - except Exception, e: - logger.error('Could not create folder for %s. Not moving: %s' % (release['AlbumTitle'], e)) - return albumpath - - # Move files to the destination folder, renaming them if they already exist for r,d,f in os.walk(albumpath): for files in f: - if os.path.isfile(os.path.join(destination_path, files)): - logger.info('Destination file exists: %s' % os.path.join(destination_path, files)) - title = os.path.splitext(files)[0] - ext = os.path.splitext(files)[1] - i = 1 - while True: - newfile = title + '(' + str(i) + ')' + ext - if os.path.isfile(os.path.join(destination_path, newfile)): - i += 1 - else: - logger.info('Renaming to %s' % newfile) - try: - os.rename(os.path.join(r, files), os.path.join(r, newfile)) - files = newfile - except Exception, e: - logger.warn('Error renaming %s: %s' % (files, e)) - break + files_to_move.append(os.path.join(r, files)) + if any(files.lower.endswith(x) for x in headphones.LOSSY_MEDIA_FORMATS): + lossy_media = True + if any(files.lower.endswith(x) for x in headphones.LOSSLESS_MEDIA_FORMATS): + lossless_media = True + # Do some sanity checking to see what directories we need to create: + make_lossy_folder = False + make_lossless_folder = False + + lossy_destination_path = os.path.normpath(os.path.join(headphones.DESTINATION_DIR, folder)).encode(headphones.SYS_ENCODING) + lossless_destination_path = os.path.normpath(os.path.join(headphones.LOSSLESS_DESTINATION_DIR, folder)).encode(headphones.SYS_ENCODING) + + # If they set a destination dir for lossless media, only create the lossy 
folder if there is lossy media + if headphones.LOSSLESS_DESTINATION_DIR: + if lossy_media: + make_lossy_folder = True + if lossless_media: + make_lossless_folder = True + # If they haven't set a lossless dest_dir, just create the "lossy" folder + else: + make_lossy_folder = True + + last_folder = headphones.FOLDER_FORMAT.split('/')[-1] + + if make_lossless_folder: + # Only rename the folder if they use the album name, otherwise merge into existing folder + if os.path.exists(lossless_destination_path) and 'album' in last_folder.lower(): + i = 1 + while True: + newfolder = folder + '[%i]' % i + lossless_destination_path = os.path.normpath(os.path.join(headphones.LOSSLESS_DESTINATION_DIR, newfolder)).encode(headphones.SYS_ENCODING) + if os.path.exists(destination_path): + i += 1 + else: + folder = newfolder + break + + if not os.path.exists(lossless_destination_path): try: - shutil.move(os.path.join(r, files), os.path.join(destination_path, files)) + os.makedirs(lossless_destination_path) except Exception, e: - logger.warn('Error moving file %s: %s' % (files, e)) + logger.error('Could not create lossless folder for %s. (Error: %s)' % (release['AlbumTitle'], e)) + if not make_lossy_folder: + return albumpath + if make_lossy_folder: + if os.path.exists(lossy_destination_path) and 'album' in last_folder.lower(): + i = 1 + while True: + newfolder = folder + '[%i]' % i + lossy_destination_path = os.path.normpath(os.path.join(headphones.DESTINATION_DIR, newfolder)).encode(headphones.SYS_ENCODING) + if os.path.exists(lossy_destination_path): + i += 1 + else: + folder = newfolder + break + + if not os.path.exists(lossy_destination_path): + try: + os.makedirs(lossy_destination_path) + except Exception, e: + logger.error('Could not create folder for %s. 
Not moving: %s' % (release['AlbumTitle'], e)) + return albumpath + + logger.info('Checking which files we need to move.....') + + # Move files to the destination folder, renaming them if they already exist + # If we have two desination_dirs, move non-music files to both + if make_lossy_folder and make_lossless_folder: + + for file_to_move in files_to_move: + + if any(file_to_move.lower.endswith(x) for x in headphones.LOSSY_MEDIA_FORMATS): + helpers.smartMove(file_to_move, lossy_destination_path) + + elif any(files.lower.endswith(x) for x in headphones.LOSSLESS_MEDIA_FORMATS): + helpers.smartMove(file_to_move, lossless_destination_path) + + # If it's a non-music file, move it to both dirs + else: + + moved_to_lossy_folder = helpers.smartMove(file_to_move, lossy_destination_path, delete=False) + moved_to_lossless_folder = helpers.smartMove(file_to_move, lossless_destination_path, delete=False) + + if moved_to_lossy_folder or moved_to_lossless_folder: + try: + os.remove(file_to_move) + except Exception, e: + logger.error("Error deleting file '" + file_to_move.decode(headphones.SYS_ENCODING) + "' from source directory") + else: + logger.error("Error copying '" + file_to_move.decode(headphones.SYS_ENCODING) + "'. Not deleting from download directory") + + elif make_lossless_folder and not make_lossy_folder: + + for file_to_move in files_to_move: + helpers.smartMove(file_to_move, lossless_destination_path) + + else: + + for file_to_move in files_to_move: + helpers.smartMove(file_to_move, lossy_destination_path) + # Chmod the directories using the folder_format (script courtesy of premiso!) 
folder_list = folder.split('/') - temp_f = headphones.DESTINATION_DIR - - for f in folder_list: - - temp_f = os.path.join(temp_f, f) - - try: - os.chmod(os.path.normpath(temp_f).encode(headphones.SYS_ENCODING), int(headphones.FOLDER_PERMISSIONS, 8)) - except Exception, e: - logger.error("Error trying to change permissions on folder: %s" % temp_f) + temp_fs = [] + + if make_lossless_folder: + temp_fs.append(headphones.LOSSLESS_DESTINATION_DIR) + + if make_lossy_folder: + temp_fs.append(headphones.DESTINATION_DIR) + + for temp_f in temp_fs: + + for f in folder_list: + + temp_f = os.path.join(temp_f, f) + + try: + os.chmod(os.path.normpath(temp_f).encode(headphones.SYS_ENCODING), int(headphones.FOLDER_PERMISSIONS, 8)) + except Exception, e: + logger.error("Error trying to change permissions on folder: %s" % temp_f) # If we failed to move all the files out of the directory, this will fail too try: @@ -447,42 +516,68 @@ def moveFiles(albumpath, release, tracks): except Exception, e: logger.error('Could not remove directory: %s. %s' % (albumpath, e)) - return destination_path + destination_paths = [] + + if make_lossy_folder: + destination_paths.append(lossy_destination_path) + if make_lossless_folder: + destination_paths.append(lossless_destination_path) + + return destination_paths def correctMetadata(albumid, release, downloaded_track_list): - logger.info('Writing metadata') - items = [] + logger.info('Preparing to write metadata to tracks....') + lossy_items = [] + lossless_items = [] + + # Process lossless & lossy media formats separately for downloaded_track in downloaded_track_list: + + try: - try: - items.append(beets.library.Item.from_path(downloaded_track)) + if any(downloaded_track.lower().endswith('.' + x.lower()) for x in headphones.LOSSLESS_MEDIA_FORMATS): + lossless_items.append(beets.library.Item.from_path(downloaded_track)) + elif any(downloaded_track.lower().endswith('.' 
+ x.lower()) for x in headphones.LOSSY_MEDIA_FORMATS): + lossy_items.append(beets.library.Item.from_path(downloaded_track)) + else: + logger.warn("Skipping: " + downloaded_track.decode(headphones.SYS_ENCODING) + " because it is not a mutagen friendly file format") except Exception, e: - logger.error("Beets couldn't create an Item from: " + downloaded_track + " - not a media file?" + str(e)) - - try: - cur_artist, cur_album, candidates, rec = autotag.tag_album(items, search_artist=helpers.latinToAscii(release['ArtistName']), search_album=helpers.latinToAscii(release['AlbumTitle'])) - except Exception, e: - logger.error('Error getting recommendation: %s. Not writing metadata' % e) - return - if rec == 'RECOMMEND_NONE': - logger.warn('No accurate album match found for %s, %s - not writing metadata' % (release['ArtistName'], release['AlbumTitle'])) - return - - dist, info, mapping, extra_items, extra_tracks = candidates[0] - logger.debug('Beets recommendation: %s' % rec) - autotag.apply_metadata(info, mapping) - - if len(items) != len(downloaded_track_list): - logger.warn("Mismatch between number of tracks downloaded and the metadata items, but I'll try to write it anyway") - - i = 1 - for item in items: + + logger.error("Beets couldn't create an Item from: " + downloaded_track.decode(headphones.SYS_ENCODING) + " - not a media file?" + str(e)) + + for items in [lossy_items, lossless_items]: + + if not items: + continue + try: - item.write() + cur_artist, cur_album, candidates, rec = autotag.tag_album(items, search_artist=helpers.latinToAscii(release['ArtistName']), search_album=helpers.latinToAscii(release['AlbumTitle'])) except Exception, e: - logger.warn('Error writing metadata to track %i: %s' % (i,e)) - i += 1 + logger.error('Error getting recommendation: %s. 
Not writing metadata' % e) + return + if rec == 'RECOMMEND_NONE': + logger.warn('No accurate album match found for %s, %s - not writing metadata' % (release['ArtistName'], release['AlbumTitle'])) + return + + if candidates: + dist, info, mapping, extra_items, extra_tracks = candidates[0] + else: + logger.warn('No accurate album match found for %s, %s - not writing metadata' % (release['ArtistName'], release['AlbumTitle'])) + return + + logger.info('Beets recommendation for tagging items: %s' % rec) + + # TODO: Handle extra_items & extra_tracks + + autotag.apply_metadata(info, mapping) + + for item in items: + try: + item.write() + logger.info("Successfully applied metadata to: " + item.path.decode(headphones.SYS_ENCODING)) + except Exception, e: + logger.warn("Error writing metadata to " + item.path.decode(headphones.SYS_ENCODING) + ": " + str(e)) def embedLyrics(downloaded_track_list): logger.info('Adding lyrics') From 4f1052d49e3751e6089cdaf6c65f44f6e5df64a9 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Fri, 17 Aug 2012 16:28:28 +0530 Subject: [PATCH 54/84] Move getArtistBio out of init actions so it doesn't append when marking an album as wanted --- data/interfaces/default/artist.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/interfaces/default/artist.html b/data/interfaces/default/artist.html index 1e00b281..b0d8af05 100644 --- a/data/interfaces/default/artist.html +++ b/data/interfaces/default/artist.html @@ -182,7 +182,6 @@ showMsg("Getting artist information",true); %endif getArtistArt(); - getArtistBio(); getAlbumArt(); $('#album_table').dataTable({ "bDestroy": true, @@ -220,6 +219,7 @@ $(document).ready(function() { initActions(); initThisPage(); + getArtistBio(); }); From d5d8906addfe85c482b265f115838c984ef11e7e Mon Sep 17 00:00:00 2001 From: rembo10 Date: Fri, 17 Aug 2012 17:29:20 +0530 Subject: [PATCH 55/84] Fixed files.lower() function in moveFiles --- headphones/postprocessor.py | 7 +++++-- 1 file changed, 5 insertions(+), 
2 deletions(-) diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index 7959dd64..289d72c5 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -387,9 +387,9 @@ def moveFiles(albumpath, release, tracks): for r,d,f in os.walk(albumpath): for files in f: files_to_move.append(os.path.join(r, files)) - if any(files.lower.endswith(x) for x in headphones.LOSSY_MEDIA_FORMATS): + if any(files.lower().endswith('.' + x.lower()) for x in headphones.LOSSY_MEDIA_FORMATS): lossy_media = True - if any(files.lower.endswith(x) for x in headphones.LOSSLESS_MEDIA_FORMATS): + if any(files.lower().endswith('.' + x.lower()) for x in headphones.LOSSLESS_MEDIA_FORMATS): lossless_media = True # Do some sanity checking to see what directories we need to create: @@ -466,6 +466,7 @@ def moveFiles(albumpath, release, tracks): helpers.smartMove(file_to_move, lossless_destination_path) # If it's a non-music file, move it to both dirs + # TODO: Move specific-to-lossless files to the lossless dir only else: moved_to_lossy_folder = helpers.smartMove(file_to_move, lossy_destination_path, delete=False) @@ -582,6 +583,8 @@ def correctMetadata(albumid, release, downloaded_track_list): def embedLyrics(downloaded_track_list): logger.info('Adding lyrics') + # TODO: If adding lyrics for flac & lossy, only fetch the lyrics once + # and apply it to both files for downloaded_track in downloaded_track_list: try: From 95e560413976548929300aaa9fa0caf824948d4d Mon Sep 17 00:00:00 2001 From: rembo10 Date: Fri, 17 Aug 2012 17:37:24 +0530 Subject: [PATCH 56/84] Fixed 'destination_path' -> 'lossless_destination_path' variable --- headphones/postprocessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index 289d72c5..f76fbac0 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -418,7 +418,7 @@ def moveFiles(albumpath, release, tracks): while True: newfolder = folder + 
'[%i]' % i lossless_destination_path = os.path.normpath(os.path.join(headphones.LOSSLESS_DESTINATION_DIR, newfolder)).encode(headphones.SYS_ENCODING) - if os.path.exists(destination_path): + if os.path.exists(lossless_destination_path): i += 1 else: folder = newfolder From ec8c3157d9bc0b9b5515a1d7c0d616e32e87df80 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Fri, 17 Aug 2012 17:42:30 +0530 Subject: [PATCH 57/84] More bug fixes. Fixed lower() function later on in moveFiles --- headphones/postprocessor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index f76fbac0..f6fde523 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -459,10 +459,10 @@ def moveFiles(albumpath, release, tracks): for file_to_move in files_to_move: - if any(file_to_move.lower.endswith(x) for x in headphones.LOSSY_MEDIA_FORMATS): + if any(file_to_move.lower().endswith('.' + x.lower()) for x in headphones.LOSSY_MEDIA_FORMATS): helpers.smartMove(file_to_move, lossy_destination_path) - elif any(files.lower.endswith(x) for x in headphones.LOSSLESS_MEDIA_FORMATS): + elif any(file_to_move.lower().endswith('.' 
+ x.lower()) for x in headphones.LOSSLESS_MEDIA_FORMATS): helpers.smartMove(file_to_move, lossless_destination_path) # If it's a non-music file, move it to both dirs From daa9f9703f77798278e5c4f7a4d33ea06145a833 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Fri, 17 Aug 2012 18:00:55 +0530 Subject: [PATCH 58/84] Fix for renaming duplicate lossy folders if the lossless folder exists --- headphones/postprocessor.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index f6fde523..27581b93 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -414,14 +414,17 @@ def moveFiles(albumpath, release, tracks): if make_lossless_folder: # Only rename the folder if they use the album name, otherwise merge into existing folder if os.path.exists(lossless_destination_path) and 'album' in last_folder.lower(): + + temp_folder = folder + i = 1 while True: - newfolder = folder + '[%i]' % i + newfolder = temp_folder + '[%i]' % i lossless_destination_path = os.path.normpath(os.path.join(headphones.LOSSLESS_DESTINATION_DIR, newfolder)).encode(headphones.SYS_ENCODING) if os.path.exists(lossless_destination_path): i += 1 else: - folder = newfolder + temp_folder = newfolder break if not os.path.exists(lossless_destination_path): @@ -434,14 +437,17 @@ def moveFiles(albumpath, release, tracks): if make_lossy_folder: if os.path.exists(lossy_destination_path) and 'album' in last_folder.lower(): + + temp_folder = folder + i = 1 while True: - newfolder = folder + '[%i]' % i + newfolder = temp_folder + '[%i]' % i lossy_destination_path = os.path.normpath(os.path.join(headphones.DESTINATION_DIR, newfolder)).encode(headphones.SYS_ENCODING) if os.path.exists(lossy_destination_path): i += 1 else: - folder = newfolder + temp_folder = newfolder break if not os.path.exists(lossy_destination_path): From 44e67256403f0816a310ed88af38de56f2aa7015 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Fri, 17 
Aug 2012 18:27:05 +0530 Subject: [PATCH 59/84] Don't rename files that are already named correctly --- headphones/postprocessor.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index 27581b93..28be4d90 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -623,7 +623,7 @@ def renameFiles(albumpath, downloaded_track_list, release): try: f = MediaFile(downloaded_track) except: - logger.info("MediaFile couldn't parse: " + downloaded_track) + logger.info("MediaFile couldn't parse: " + downloaded_track.decode(headphones.SYS_ENCODING)) continue if not f.track: @@ -665,12 +665,16 @@ def renameFiles(albumpath, downloaded_track_list, release): new_file_name = new_file_name.replace(0, '_') new_file = os.path.join(albumpath, new_file_name) + + if downloaded_track == new_file_name: + logger.info("Renaming for: " + downloaded_track.decode(headphones.SYS_ENCODING) + " is not neccessary") + continue - logger.debug('Renaming %s ---> %s' % (downloaded_track, new_file_name)) + logger.info('Renaming %s ---> %s' % (downloaded_track.decode(headphones.SYS_ENCODING), new_file_name.decode(headphones.SYS_ENCODING))) try: os.rename(downloaded_track, new_file) except Exception, e: - logger.error('Error renaming file: %s. Error: %s' % (downloaded_track, e)) + logger.error('Error renaming file: %s. 
Error: %s' % (downloaded_track.decode(headphones.SYS_ENCODING), e)) continue def renameUnprocessedFolder(albumpath): From c57d04f97bb83f8349a396bd92ce346c13400930 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Fri, 17 Aug 2012 18:42:55 +0530 Subject: [PATCH 60/84] Changed some logging levels back to debug --- headphones/postprocessor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index 28be4d90..5f0c9842 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -667,10 +667,10 @@ def renameFiles(albumpath, downloaded_track_list, release): new_file = os.path.join(albumpath, new_file_name) if downloaded_track == new_file_name: - logger.info("Renaming for: " + downloaded_track.decode(headphones.SYS_ENCODING) + " is not neccessary") + logger.debug("Renaming for: " + downloaded_track.decode(headphones.SYS_ENCODING) + " is not neccessary") continue - logger.info('Renaming %s ---> %s' % (downloaded_track.decode(headphones.SYS_ENCODING), new_file_name.decode(headphones.SYS_ENCODING))) + logger.debug('Renaming %s ---> %s' % (downloaded_track.decode(headphones.SYS_ENCODING), new_file_name.decode(headphones.SYS_ENCODING))) try: os.rename(downloaded_track, new_file) except Exception, e: From 0bba75554fd917cd6849455f01327cf71feb261b Mon Sep 17 00:00:00 2001 From: rembo10 Date: Fri, 17 Aug 2012 19:04:25 +0530 Subject: [PATCH 61/84] Fixed some wording on the album page to differentiate between trying new versions to download and switching releases TODO: might still need to be a bit clearer what does what.... tooltips? 
--- data/interfaces/default/album.html | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data/interfaces/default/album.html b/data/interfaces/default/album.html index fce97271..6af0459d 100644 --- a/data/interfaces/default/album.html +++ b/data/interfaces/default/album.html @@ -17,14 +17,14 @@ Retry Download Try New Version %endif - Choose Alternate Version + Choose Alternate Release @@ -848,6 +850,25 @@ m<%inherit file="base.html"/> { $("#nmaoptions").slideUp(); } + }); + if ($("#preferred_bitrate").is(":checked")) + { + $("#preferred_bitrate_options").show(); + } + else + { + $("#preferred_bitrate_options").hide(); + } + + $('input[type=radio]').change(function(){ + if ($("#preferred_bitrate").is(":checked")) + { + $("#preferred_bitrate_options").slideDown("fast"); + } + else + { + $("#preferred_bitrate_options").slideUp("fast"); + } }); $("#mirror").change(handleNewSelection); diff --git a/data/interfaces/default/css/style.css b/data/interfaces/default/css/style.css index 771d4db3..e171b4db 100644 --- a/data/interfaces/default/css/style.css +++ b/data/interfaces/default/css/style.css @@ -376,6 +376,10 @@ form .radio label { padding-top: 1px; width: auto; } +.override-float { + float: none !important; + margin-bottom: 0px !important; +} form .radio input { float: left; margin-bottom: 10px; From cecda0bab993da2e3115639df3b375360253ee38 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Sun, 19 Aug 2012 17:27:56 +0530 Subject: [PATCH 74/84] Added backend stuff to searcher.py to limit sizes to a specified range with preferred quality --- headphones/searcher.py | 41 ++++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/headphones/searcher.py b/headphones/searcher.py index 50068271..7470c790 100644 --- a/headphones/searcher.py +++ b/headphones/searcher.py @@ -426,15 +426,38 @@ def searchNZB(albumid=None, new=False, losslessOnly=False): albumlength = sum([pair[0] for pair in tracks]) targetsize = 
albumlength/1000 * int(headphones.PREFERRED_BITRATE) * 128 - logger.info('Target size: %s' % helpers.bytes_to_mb(targetsize)) - - newlist = [] - - for result in resultlist: - delta = abs(targetsize - result[1]) - newlist.append((result[0], result[1], result[2], result[3], delta)) - - nzblist = sorted(newlist, key=lambda title: title[4]) + + if not targetsize: + logger.info('No track information for %s - %s. Defaulting to highest quality' % (albums[0], albums[1])) + nzblist = sorted(resultlist, key=lambda title: title[1], reverse=True) + + else: + logger.info('Target size: %s' % helpers.bytes_to_mb(targetsize)) + newlist = [] + + if headphones.PREFERRED_BITRATE_HIGH_BUFFER: + high_size_limit = targetsize * int(headphones.PREFERRED_BITRATE_HIGH_BUFFER)/100 + else: + high_size_limit = None + if headphones.PREFERRED_BITRATE_LOW_BUFFER: + low_size_limit = targetsize * int(headphones.PREFERRED_BITRATE_LOW_BUFFER)/100 + else: + low_size_limit = None + + for result in resultlist: + + if high_size_limit and (result[1] > high_size_limit): + logger.info(result[0] + "is too large for this album - not considering it. (Size: " + helpers.bytes_to_mb(result[1]) + ", Maxsize: " + helpers.bytes_to_mb(high_size_limit)) + continue + + if low_size_limit and (result[1] < low_size_limit): + logger.info(result[0] + "is too small for this album - not considering it. 
(Size: " + helpers.bytes_to_mb(result[1]) + ", Minsize: " + helpers.bytes_to_mb(low_size_limit)) + continue + + delta = abs(targetsize - result[1]) + newlist.append((result[0], result[1], result[2], result[3], delta)) + + nzblist = sorted(newlist, key=lambda title: title[4]) except Exception, e: From 6339dd8a87fbe211f1b9e5ae5fed5ab916640d75 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Sun, 19 Aug 2012 18:39:17 +0530 Subject: [PATCH 75/84] Fix for post-processor not working if the album wasn't in the database --- headphones/importer.py | 1 + headphones/postprocessor.py | 28 ++++++++++++++++------------ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/headphones/importer.py b/headphones/importer.py index a186f304..571360e9 100644 --- a/headphones/importer.py +++ b/headphones/importer.py @@ -470,6 +470,7 @@ def addReleaseById(rid): if headphones.INCLUDE_EXTRAS: newValueDict['IncludeExtras'] = 1 + newValueDict['Extras'] = headphones.EXTRAS myDB.upsert("artists", newValueDict, controlValueDict) diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index 5f0c9842..f9bfc833 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -65,18 +65,23 @@ def verify(albumid, albumpath): #TODO: odd things can happen when there are diacritic characters in the folder name, need to translate them? import mb - release_dict = None + release_list = None try: - release_dict = mb.getReleaseGroup(albumid) + release_list = mb.getReleaseGroup(albumid) except Exception, e: logger.info('Unable to get release information for manual album with rgid: %s. 
Error: %s' % (albumid, e)) return - if not release_dict: + if not release_list: logger.info('Unable to get release information for manual album with rgid: %s' % albumid) return + # Since we're just using this to create the bare minimum information to insert an artist/album combo, use the first release + releaseid = release_list[0]['id'] + + release_dict = mb.getRelease(releaseid) + logger.info(u"Now adding/updating artist: " + release_dict['artist_name']) if release_dict['artist_name'].startswith('The '): @@ -90,10 +95,12 @@ def verify(albumid, albumpath): "ArtistSortName": sortname, "DateAdded": helpers.today(), "Status": "Paused"} - logger.info("ArtistID:ArtistName: " + release_dict['artist_id'] + " : " + release_dict['artist_name']) + + logger.info("ArtistID: " + release_dict['artist_id'] + " , ArtistName: " + release_dict['artist_name']) if headphones.INCLUDE_EXTRAS: newValueDict['IncludeExtras'] = 1 + newValueDict['Extras'] = headphones.EXTRAS myDB.upsert("artists", newValueDict, controlValueDict) @@ -104,24 +111,21 @@ def verify(albumid, albumpath): "ArtistName": release_dict['artist_name'], "AlbumTitle": release_dict['title'], "AlbumASIN": release_dict['asin'], - "ReleaseDate": release_dict['releasedate'], + "ReleaseDate": release_dict['date'], "DateAdded": helpers.today(), - "Type": release_dict['type'], + "Type": release_dict['rg_type'], "Status": "Snatched" } myDB.upsert("albums", newValueDict, controlValueDict) - - # I changed the albumid from releaseid -> rgid, so might need to delete albums that have a releaseid - for rel in release_dict['releaselist']: - myDB.action('DELETE from albums WHERE AlbumID=?', [rel['releaseid']]) - myDB.action('DELETE from tracks WHERE AlbumID=?', [rel['releaseid']]) - + + # Delete existing tracks associated with this AlbumID since we're going to replace them and don't want any extras myDB.action('DELETE from tracks WHERE AlbumID=?', [albumid]) for track in release_dict['tracks']: controlValueDict = {"TrackID": track['id'], 
"AlbumID": albumid} + newValueDict = {"ArtistID": release_dict['artist_id'], "ArtistName": release_dict['artist_name'], "AlbumTitle": release_dict['title'], From cb930d5e23f68ac5f2127c51b2e36fad9c44c07a Mon Sep 17 00:00:00 2001 From: rembo10 Date: Sun, 19 Aug 2012 18:50:18 +0530 Subject: [PATCH 76/84] Fixed sqlite query when moving over Extras values from the old format to the new format --- headphones/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/headphones/__init__.py b/headphones/__init__.py index 2ed09fa8..0dfde064 100644 --- a/headphones/__init__.py +++ b/headphones/__init__.py @@ -902,7 +902,7 @@ def dbcheck(): artists = c.execute('SELECT ArtistID, IncludeExtras from artists').fetchall() for artist in artists: if artist['IncludeExtras']: - c.execute('INSERT into artists Extras="1,2,3,4,5,6,7,8" WHERE ArtistID=' + artist['ArtistID']) + c.execute('INSERT into artists (Extras) VALUES ("1,2,3,4,5,6,7,8") WHERE ArtistID=' + artist['ArtistID']) conn.commit() c.close() From 2a3352870c2a73d0090780748bee472ad490bda2 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Sun, 19 Aug 2012 19:21:46 +0530 Subject: [PATCH 77/84] Actually fix the sqlite query :-) --- headphones/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/headphones/__init__.py b/headphones/__init__.py index 0dfde064..3126e793 100644 --- a/headphones/__init__.py +++ b/headphones/__init__.py @@ -901,8 +901,8 @@ def dbcheck(): logger.info("Copying over current artist IncludeExtras information") artists = c.execute('SELECT ArtistID, IncludeExtras from artists').fetchall() for artist in artists: - if artist['IncludeExtras']: - c.execute('INSERT into artists (Extras) VALUES ("1,2,3,4,5,6,7,8") WHERE ArtistID=' + artist['ArtistID']) + if artist[1]: + c.execute('UPDATE artists SET Extras=? 
WHERE ArtistID=?', ("1,2,3,4,5,6,7,8", artist[0])) conn.commit() c.close() From 3f1ddd6489427fc24820250563899fe6227d065b Mon Sep 17 00:00:00 2001 From: rembo10 Date: Mon, 20 Aug 2012 16:53:39 +0530 Subject: [PATCH 78/84] Don't try to load the artist page until it has the bare minimum info needed to render it --- headphones/webserve.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/headphones/webserve.py b/headphones/webserve.py index 7e905a24..c1d6674b 100644 --- a/headphones/webserve.py +++ b/headphones/webserve.py @@ -62,6 +62,20 @@ class WebInterface(object): artist = myDB.action('SELECT * FROM artists WHERE ArtistID=?', [ArtistID]).fetchone() albums = myDB.select('SELECT * from albums WHERE ArtistID=? order by ReleaseDate DESC', [ArtistID]) + # Don't redirect to the artist page until it has the bare minimum info inserted + # Redirect to the home page if we still can't get it after 5 seconds + retry = 0 + + while retry < 5: + if not artist: + time.sleep(1) + retry += 1 + else: + break + + if not artist: + raise cherrypy.HTTPRedirect("home") + # Serve the extras up as a dict to make things easier for new templates extras_list = ["single", "ep", "compilation", "soundtrack", "live", "remix", "spokenword", "audiobook"] extras_dict = {} @@ -79,8 +93,6 @@ class WebInterface(object): extras_dict[extra] = "" i+=1 - if artist is None: - raise cherrypy.HTTPRedirect("home") return serve_template(templatename="artist.html", title=artist['ArtistName'], artist=artist, albums=albums, extras=extras_dict) artistPage.exposed = True From 452b83e197f3937690224a99bf8218f86ac23f70 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Mon, 20 Aug 2012 16:58:02 +0530 Subject: [PATCH 79/84] Fetch the artist info again during retries when loading artist page --- headphones/webserve.py | 1 + 1 file changed, 1 insertion(+) diff --git a/headphones/webserve.py b/headphones/webserve.py index c1d6674b..14f29a80 100644 --- a/headphones/webserve.py +++ 
b/headphones/webserve.py @@ -69,6 +69,7 @@ class WebInterface(object): while retry < 5: if not artist: time.sleep(1) + artist = myDB.action('SELECT * FROM artists WHERE ArtistID=?', [ArtistID]).fetchone() retry += 1 else: break From f9a048b8f8226abccc6d03190d600c89d8a31d27 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Mon, 20 Aug 2012 17:06:50 +0530 Subject: [PATCH 80/84] First patch to stop albums with no release date from being marked as wanted - might need some frontend work in the UI to correct the sorting --- headphones/mb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/headphones/mb.py b/headphones/mb.py index 29b5bb8a..36eeeae8 100644 --- a/headphones/mb.py +++ b/headphones/mb.py @@ -326,8 +326,8 @@ def getRelease(releaseid, include_artist_info=True): release['title'] = unicode(results['title']) release['id'] = unicode(results['id']) - release['asin'] = unicode(results['asin']) if 'asin' in results else u'None' - release['date'] = unicode(results['date']) if 'date' in results else u'None' + release['asin'] = unicode(results['asin']) if 'asin' in results else None + release['date'] = unicode(results['date']) if 'date' in results else None try: release['format'] = unicode(results['medium-list'][0]['format']) except: From 6b3e6afd6f27e9c5461aba33721ea3b100591c44 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Mon, 20 Aug 2012 17:44:05 +0530 Subject: [PATCH 81/84] Datatables sType=Date on artist page for release dates --- data/interfaces/default/artist.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/interfaces/default/artist.html b/data/interfaces/default/artist.html index 49acbef7..31299d6e 100644 --- a/data/interfaces/default/artist.html +++ b/data/interfaces/default/artist.html @@ -210,7 +210,7 @@ null, null, null, - null, + { "sType": "date" }, null, null, { "sType": "title-numeric"}, From 8b26240e5723364c4fc185364c7926149d81004d Mon Sep 17 00:00:00 2001 From: rembo10 Date: Mon, 20 Aug 2012 20:32:55 +0530 
Subject: [PATCH 82/84] Fixed prowl notifications not being sent after post processing if PROWL_ONSNATCH was not enabled --- headphones/postprocessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/headphones/postprocessor.py b/headphones/postprocessor.py index f9bfc833..9050f8b6 100644 --- a/headphones/postprocessor.py +++ b/headphones/postprocessor.py @@ -283,7 +283,7 @@ def doPostProcessing(albumid, albumpath, release, tracks, downloaded_track_list) logger.info('Post-processing for %s - %s complete' % (release['ArtistName'], release['AlbumTitle'])) - if headphones.PROWL_ONSNATCH: + if headphones.PROWL_ENABLED: pushmessage = release['ArtistName'] + ' - ' + release['AlbumTitle'] logger.info(u"Prowl request") prowl = notifiers.PROWL() From bbf74a544de8dfee4eb4ffef4b427099e251b02c Mon Sep 17 00:00:00 2001 From: rembo10 Date: Mon, 20 Aug 2012 20:59:13 +0530 Subject: [PATCH 83/84] Added NMA on snatch option --- data/interfaces/default/config.html | 5 ++++- headphones/__init__.py | 5 ++++- headphones/notifiers.py | 13 ++++++++----- headphones/sab.py | 8 ++++++-- headphones/webserve.py | 4 +++- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/data/interfaces/default/config.html b/data/interfaces/default/config.html index 75219f33..643af20e 100644 --- a/data/interfaces/default/config.html +++ b/data/interfaces/default/config.html @@ -618,7 +618,10 @@ m<%inherit file="base.html"/>
- +
+ +
+ Separate multiple api keys with commas
diff --git a/headphones/__init__.py b/headphones/__init__.py index 3126e793..0e06ba4c 100644 --- a/headphones/__init__.py +++ b/headphones/__init__.py @@ -183,6 +183,7 @@ XBMC_NOTIFY = False NMA_ENABLED = False NMA_APIKEY = None NMA_PRIORITY = None +NMA_ONSNATCH = None SYNOINDEX_ENABLED = False MIRRORLIST = ["musicbrainz.org","headphones","custom"] MIRROR = None @@ -254,7 +255,7 @@ def initialize(): ENCODERFOLDER, ENCODER, BITRATE, SAMPLINGFREQUENCY, MUSIC_ENCODER, ADVANCEDENCODER, ENCODEROUTPUTFORMAT, ENCODERQUALITY, \ ENCODERVBRCBR, ENCODERLOSSLESS, DELETE_LOSSLESS_FILES, PROWL_ENABLED, PROWL_PRIORITY, PROWL_KEYS, PROWL_ONSNATCH, MIRRORLIST, \ MIRROR, CUSTOMHOST, CUSTOMPORT, CUSTOMSLEEP, HPUSER, HPPASS, XBMC_ENABLED, XBMC_HOST, XBMC_USERNAME, XBMC_PASSWORD, XBMC_UPDATE, \ - XBMC_NOTIFY, NMA_ENABLED, NMA_APIKEY, NMA_PRIORITY, SYNOINDEX_ENABLED, ALBUM_COMPLETION_PCT, PREFERRED_BITRATE_HIGH_BUFFER, \ + XBMC_NOTIFY, NMA_ENABLED, NMA_APIKEY, NMA_PRIORITY, NMA_ONSNATCH, SYNOINDEX_ENABLED, ALBUM_COMPLETION_PCT, PREFERRED_BITRATE_HIGH_BUFFER, \ PREFERRED_BITRATE_LOW_BUFFER if __INITIALIZED__: @@ -401,6 +402,7 @@ def initialize(): NMA_ENABLED = bool(check_setting_int(CFG, 'NMA', 'nma_enabled', 0)) NMA_APIKEY = check_setting_str(CFG, 'NMA', 'nma_apikey', '') NMA_PRIORITY = check_setting_int(CFG, 'NMA', 'nma_priority', 0) + NMA_ONSNATCH = bool(check_setting_int(CFG, 'NMA', 'nma_onsnatch', 0)) SYNOINDEX_ENABLED = bool(check_setting_int(CFG, 'Synoindex', 'synoindex_enabled', 0)) @@ -676,6 +678,7 @@ def config_write(): new_config['NMA']['nma_enabled'] = int(NMA_ENABLED) new_config['NMA']['nma_apikey'] = NMA_APIKEY new_config['NMA']['nma_priority'] = NMA_PRIORITY + new_config['NMA']['nma_onsnatch'] = int(PROWL_ONSNATCH) new_config['Synoindex'] = {} new_config['Synoindex']['synoindex_enabled'] = int(SYNOINDEX_ENABLED) diff --git a/headphones/notifiers.py b/headphones/notifiers.py index 3c16b647..3df88485 100644 --- a/headphones/notifiers.py +++ b/headphones/notifiers.py @@ 
-171,14 +171,17 @@ class NMA: return response - def notify(self, artist, album): + def notify(self, artist=None, album=None, snatched_nzb=None): apikey = self.apikey priority = self.priority - event = artist + ' - ' + album + ' complete!' - - description = "Headphones has downloaded and postprocessed: " + artist + ' [' + album + ']' + if snatched_nzb: + event = snatched_nzb + " snatched!" + description = "Headphones has snatched: " + snatched_nzb + " and has sent it to SABnzbd+" + else: + event = artist + ' - ' + album + ' complete!' + description = "Headphones has downloaded and postprocessed: " + artist + ' [' + album + ']' data = { 'apikey': apikey, 'application':'Headphones', 'event': event, 'description': description, 'priority': priority} @@ -223,4 +226,4 @@ class Synoindex: def notify_multiple(self, path_list): if isinstance(path_list, list): for path in path_list: - self.notify(path) \ No newline at end of file + self.notify(path) diff --git a/headphones/sab.py b/headphones/sab.py index 6d57214c..282dd4c3 100644 --- a/headphones/sab.py +++ b/headphones/sab.py @@ -118,10 +118,14 @@ def sendNZB(nzb): if sabText == "ok": logger.info(u"NZB sent to SAB successfully") - if headphones.PROWL_ONSNATCH: - logger.info(u"Prowl request") + if headphones.PROWL_ENABLED and headphones.PROWL_ONSNATCH: + logger.info(u"Sending Prowl notification") prowl = notifiers.PROWL() prowl.notify(nzb.name,"Download started") + if headphones.NMA_ENABLED and headphones.NMA_ONSNATCH: + logger.debug(u"Sending NMA notification") + nma = notifiers.NMA() + nma.notify(snatched_nzb=nzb.name) return True elif sabText == "Missing authentication": diff --git a/headphones/webserve.py b/headphones/webserve.py index 14f29a80..799d84f3 100644 --- a/headphones/webserve.py +++ b/headphones/webserve.py @@ -511,6 +511,7 @@ class WebInterface(object): "nma_enabled": checked(headphones.NMA_ENABLED), "nma_apikey": headphones.NMA_APIKEY, "nma_priority": int(headphones.NMA_PRIORITY), + "nma_onsnatch": 
checked(headphones.NMA_ONSNATCH), "synoindex_enabled": checked(headphones.SYNOINDEX_ENABLED), "mirror_list": headphones.MIRRORLIST, "mirror": headphones.MIRROR, @@ -550,7 +551,7 @@ class WebInterface(object): interface=None, log_dir=None, music_encoder=0, encoder=None, bitrate=None, samplingfrequency=None, encoderfolder=None, advancedencoder=None, encoderoutputformat=None, encodervbrcbr=None, encoderquality=None, encoderlossless=0, delete_lossless_files=0, prowl_enabled=0, prowl_onsnatch=0, prowl_keys=None, prowl_priority=0, xbmc_enabled=0, xbmc_host=None, xbmc_username=None, xbmc_password=None, xbmc_update=0, xbmc_notify=0, nma_enabled=False, - nma_apikey=None, nma_priority=0, synoindex_enabled=False, mirror=None, customhost=None, customport=None, customsleep=None, hpuser=None, hppass=None, + nma_apikey=None, nma_priority=0, nma_onsnatch=0, synoindex_enabled=False, mirror=None, customhost=None, customport=None, customsleep=None, hpuser=None, hppass=None, preferred_bitrate_high_buffer=None, preferred_bitrate_low_buffer=None, **kwargs): headphones.HTTP_HOST = http_host @@ -639,6 +640,7 @@ class WebInterface(object): headphones.NMA_ENABLED = nma_enabled headphones.NMA_APIKEY = nma_apikey headphones.NMA_PRIORITY = nma_priority + headphones.NMA_ONSNATCH = nma_onsnatch headphones.SYNOINDEX_ENABLED = synoindex_enabled headphones.MIRROR = mirror headphones.CUSTOMHOST = customhost From 2f8abf7dca4949b2efde3464c3a1e96d42f0aa49 Mon Sep 17 00:00:00 2001 From: rembo10 Date: Mon, 20 Aug 2012 22:37:37 +0530 Subject: [PATCH 84/84] Add error checking to importer, so we don't mark an artist as being updated if not all albums were added/refreshed --- headphones/importer.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/headphones/importer.py b/headphones/importer.py index 571360e9..50c7531e 100644 --- a/headphones/importer.py +++ b/headphones/importer.py @@ -103,6 +103,9 @@ def addArtisttoDB(artistid, extrasonly=False): logger.warn('Cannot import 
Various Artists.') return + # We'll use this to see if we should update the 'LastUpdated' time stamp + errors = False + myDB = db.DBConnection() # Delete from blacklist if it's on there @@ -171,6 +174,7 @@ def addArtisttoDB(artistid, extrasonly=False): continue if not releaselist: + errors = True continue # This will be used later to build a hybrid release @@ -186,10 +190,12 @@ def addArtisttoDB(artistid, extrasonly=False): try: releasedict = mb.getRelease(releaseid, include_artist_info=False) except Exception, e: + errors = True logger.info('Unable to get release information for %s: %s' % (release['id'], e)) continue if not releasedict: + errors = True continue controlValueDict = {"ReleaseID": release['id']} @@ -412,14 +418,18 @@ def addArtisttoDB(artistid, extrasonly=False): "TotalTracks": totaltracks, "HaveTracks": havetracks} - newValueDict['LastUpdated'] = helpers.now() + if not errors: + newValueDict['LastUpdated'] = helpers.now() myDB.upsert("artists", newValueDict, controlValueDict) logger.info(u"Seeing if we need album art for: " + artist['artist_name']) cache.getThumb(ArtistID=artistid) - logger.info(u"Updating complete for: " + artist['artist_name']) + if errors: + logger.info("Finished updating artist: " + artist['artist_name'] + " but with errors, so not marking it as updated in the database") + else: + logger.info(u"Updating complete for: " + artist['artist_name']) def addReleaseById(rid):