mirror of
https://github.com/rembo10/headphones.git
synced 2026-07-03 07:44:00 +01:00
+15
-124
@@ -1,67 +1,13 @@
|
||||
# Python #
|
||||
##########
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
env/
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*,cover
|
||||
.hypothesis/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
# Compiled source #
|
||||
###################
|
||||
*.pyc
|
||||
*.py~
|
||||
*.pyproj
|
||||
*.sln
|
||||
|
||||
# Headphones files #
|
||||
####################
|
||||
######################
|
||||
*.log
|
||||
*.db*
|
||||
*.db-journal
|
||||
@@ -76,68 +22,16 @@ cache/*
|
||||
*.key
|
||||
*.csr
|
||||
|
||||
# Windows #
|
||||
##########
|
||||
# Windows image file caches
|
||||
Thumbs.db
|
||||
ehthumbs.db
|
||||
|
||||
# Folder config file
|
||||
Desktop.ini
|
||||
|
||||
# Recycle Bin used on file shares
|
||||
$RECYCLE.BIN/
|
||||
|
||||
# Windows Installer files
|
||||
*.cab
|
||||
*.msi
|
||||
*.msm
|
||||
*.msp
|
||||
|
||||
# Windows shortcuts
|
||||
*.lnk
|
||||
|
||||
# OS X #
|
||||
########
|
||||
# OS generated files #
|
||||
######################
|
||||
.DS_Store?
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
# Linux #
|
||||
#########
|
||||
*~
|
||||
|
||||
# KDE directory preferences
|
||||
.directory
|
||||
|
||||
# Linux trash folder which might appear on any partition or disk
|
||||
.Trash-*
|
||||
ehthumbs.db
|
||||
Icon?
|
||||
Thumbs.db
|
||||
|
||||
#Ignore files generated by PyCharm
|
||||
.idea/
|
||||
*.iml
|
||||
.idea/*
|
||||
|
||||
#Ignore files generated by vi
|
||||
*.swp
|
||||
@@ -159,6 +53,7 @@ Temporary Items
|
||||
*.bak
|
||||
*.cache
|
||||
*.ilk
|
||||
*.log
|
||||
[Bb]in
|
||||
[Dd]ebug*/
|
||||
*.lib
|
||||
@@ -171,7 +66,3 @@ _ReSharper*/
|
||||
/logs
|
||||
.project
|
||||
.pydevproject
|
||||
.vs/
|
||||
|
||||
#Python virtual environment
|
||||
env/
|
||||
+14
-18
@@ -1,24 +1,20 @@
|
||||
# Travis CI configuration file
|
||||
# http://about.travis-ci.org/docs/
|
||||
|
||||
language: python
|
||||
|
||||
# Available Python versions:
|
||||
# http://about.travis-ci.org/docs/user/ci-environment/#Python-VM-images
|
||||
python:
|
||||
- '2.6'
|
||||
- '2.7'
|
||||
- "2.6"
|
||||
- "2.7"
|
||||
# pylint 1.4 does not run under python 2.6
|
||||
install:
|
||||
- pip install -r requirements.txt
|
||||
- pip install pylint
|
||||
- pip install pyflakes
|
||||
- pip install pep8
|
||||
- pip install pyOpenSSL
|
||||
- pip install pylint==1.3.1
|
||||
- pip install pyflakes
|
||||
- pip install pep8
|
||||
script:
|
||||
- pep8 headphones
|
||||
- pyflakes headphones
|
||||
before_deploy:
|
||||
- pip install -r requirements.txt -t lib/
|
||||
- find . -name "*.pyc" -delete
|
||||
- find ./lib/ -path "*/*\-info/*" -delete
|
||||
- zip -r -9 headphones.zip data/ headphones/ init-scripts/ lib/ Headphones.py
|
||||
deploy:
|
||||
provider: releases
|
||||
api_key:
|
||||
secure: <GITHUB OATH TOKEN HERE>
|
||||
file: headphones.zip
|
||||
on:
|
||||
tags: true
|
||||
- nosetests headphones
|
||||
|
||||
@@ -1,5 +1,25 @@
|
||||
# Changelog
|
||||
|
||||
## v0.5.10
|
||||
Released 29 January 2016
|
||||
|
||||
Highlights:
|
||||
* Added: API option to post-process single folders
|
||||
* Added: Ability to specify extension when re-encoding
|
||||
* Added: Option to stop renaming folders
|
||||
* Fixed: Utorrent torrents not being removed (#2385)
|
||||
* Fixed: Torznab to transmission
|
||||
* Fixed: Magnet folder names in history
|
||||
* Fixed: Multiple torcache fixes
|
||||
* Fixed: Updated requests & urllib3 to latest versions to fix errors with pyOpenSSL
|
||||
* Improved: Use a temporary folder during post-processing
|
||||
* Improved: Added verify_ssl_cert option
|
||||
* Improved: Fixed track matching progress
|
||||
* Improved: pylint, pep8 & pylint fixes
|
||||
* Improved: Stop JS links from scrolling to the top of the page
|
||||
|
||||
The full list of commits can be found [here](https://github.com/rembo10/headphones/compare/v0.5.9...v0.5.10).
|
||||
|
||||
## v0.5.9
|
||||
Released 05 September 2015
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
# Headphones
|
||||
## Headphones
|
||||
|
||||
**Master Branch:** [](https://travis-ci.org/rembo10/headphones)
|
||||
**Develop Branch:** [](https://travis-ci.org/rembo10/headphones)
|
||||
|
||||
Headphones is an automated music downloader for NZB and Torrent, written in Python. It supports SABnzbd, NZBget, Transmission, µTorrent and Blackhole.
|
||||
|
||||
@@ -18,7 +21,7 @@ You are free to join the Headphones support community on IRC where you can ask q
|
||||
|
||||
1. Analyze your log, you just might find the solution yourself!
|
||||
2. You read the wiki and searched existing issues, but this is not solving your problem.
|
||||
3. Post the issue with a clear title, description and the HP log and use [proper markdown syntax](https://help.github.com/articles/github-flavored-markdown) to structure your text (code/log in code blocks).
|
||||
3. Post the issue with a clear title, description and the HP log and use [proper markdown syntax](https://help.github.com/articles/github-flavored-markdown) to structure your text (code/log in code blocks).
|
||||
4. Close your issue when it's solved! If you found the solution yourself, please comment so that others benefit from it.
|
||||
|
||||
**Feature requests** can be reported on the GitHub issue tracker too:
|
||||
|
||||
@@ -164,7 +164,7 @@ start_headphones () {
|
||||
handle_updates
|
||||
if ! is_running; then
|
||||
log_daemon_msg "Starting $DESC"
|
||||
start-stop-daemon -o -d "$APP_PATH" -c "$RUN_AS" --start "$EXTRA_SSD"_OPTS --pidfile "$PID_FILE" --exec "$DAEMON" -- "$DAEMON_OPTS"
|
||||
start-stop-daemon -o -d $APP_PATH -c $RUN_AS --start $EXTRA_SSD_OPTS --pidfile $PID_FILE --exec $DAEMON -- $DAEMON_OPTS
|
||||
check_retval
|
||||
else
|
||||
log_success_msg "$DESC: already running (pid $PID)"
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/python
|
||||
|
||||
####
|
||||
# 06/2010 Nic Wolfe <nic@wolfeden.ca>
|
||||
# 02/2006 Will Holcomb <wholcomb@gmail.com>
|
||||
#
|
||||
# This library is free software; you can redistribute it and/or
|
||||
# modify it under the terms of the GNU Lesser General Public
|
||||
# License as published by the Free Software Foundation; either
|
||||
# version 2.1 of the License, or (at your option) any later version.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
# Lesser General Public License for more details.
|
||||
#
|
||||
|
||||
import urllib
|
||||
import urllib2
|
||||
import mimetools, mimetypes
|
||||
import os, sys
|
||||
|
||||
# Controls how sequences are uncoded. If true, elements may be given multiple values by
|
||||
# assigning a sequence.
|
||||
doseq = 1
|
||||
|
||||
class MultipartPostHandler(urllib2.BaseHandler):
|
||||
handler_order = urllib2.HTTPHandler.handler_order - 10 # needs to run first
|
||||
|
||||
def http_request(self, request):
|
||||
data = request.get_data()
|
||||
if data is not None and type(data) != str:
|
||||
v_files = []
|
||||
v_vars = []
|
||||
try:
|
||||
for(key, value) in data.items():
|
||||
if type(value) in (file, list, tuple):
|
||||
v_files.append((key, value))
|
||||
else:
|
||||
v_vars.append((key, value))
|
||||
except TypeError:
|
||||
systype, value, traceback = sys.exc_info()
|
||||
raise TypeError, "not a valid non-string sequence or mapping object", traceback
|
||||
|
||||
if len(v_files) == 0:
|
||||
data = urllib.urlencode(v_vars, doseq)
|
||||
else:
|
||||
boundary, data = MultipartPostHandler.multipart_encode(v_vars, v_files)
|
||||
contenttype = 'multipart/form-data; boundary=%s' % boundary
|
||||
if(request.has_header('Content-Type')
|
||||
and request.get_header('Content-Type').find('multipart/form-data') != 0):
|
||||
print "Replacing %s with %s" % (request.get_header('content-type'), 'multipart/form-data')
|
||||
request.add_unredirected_header('Content-Type', contenttype)
|
||||
|
||||
request.add_data(data)
|
||||
return request
|
||||
|
||||
@staticmethod
|
||||
def multipart_encode(vars, files, boundary = None, buffer = None):
|
||||
if boundary is None:
|
||||
boundary = mimetools.choose_boundary()
|
||||
if buffer is None:
|
||||
buffer = ''
|
||||
for(key, value) in vars:
|
||||
buffer += '--%s\r\n' % boundary
|
||||
buffer += 'Content-Disposition: form-data; name="%s"' % key
|
||||
buffer += '\r\n\r\n' + value + '\r\n'
|
||||
for(key, fd) in files:
|
||||
|
||||
# allow them to pass in a file or a tuple with name & data
|
||||
if type(fd) == file:
|
||||
name_in = fd.name
|
||||
fd.seek(0)
|
||||
data_in = fd.read()
|
||||
elif type(fd) in (tuple, list):
|
||||
name_in, data_in = fd
|
||||
|
||||
filename = os.path.basename(name_in)
|
||||
contenttype = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
|
||||
buffer += '--%s\r\n' % boundary
|
||||
buffer += 'Content-Disposition: form-data; name="%s"; filename="%s"\r\n' % (key, filename)
|
||||
buffer += 'Content-Type: %s\r\n' % contenttype
|
||||
# buffer += 'Content-Length: %s\r\n' % file_size
|
||||
buffer += '\r\n' + data_in + '\r\n'
|
||||
buffer += '--%s--\r\n\r\n' % boundary
|
||||
return boundary, buffer
|
||||
|
||||
https_request = http_request
|
||||
@@ -0,0 +1,5 @@
|
||||
version_info = (3, 0, 1)
|
||||
version = '3.0.1'
|
||||
release = '3.0.1'
|
||||
|
||||
__version__ = release # PEP 396
|
||||
@@ -0,0 +1,73 @@
|
||||
__all__ = ('EVENT_SCHEDULER_START', 'EVENT_SCHEDULER_SHUTDOWN', 'EVENT_EXECUTOR_ADDED', 'EVENT_EXECUTOR_REMOVED',
|
||||
'EVENT_JOBSTORE_ADDED', 'EVENT_JOBSTORE_REMOVED', 'EVENT_ALL_JOBS_REMOVED', 'EVENT_JOB_ADDED',
|
||||
'EVENT_JOB_REMOVED', 'EVENT_JOB_MODIFIED', 'EVENT_JOB_EXECUTED', 'EVENT_JOB_ERROR', 'EVENT_JOB_MISSED',
|
||||
'SchedulerEvent', 'JobEvent', 'JobExecutionEvent')
|
||||
|
||||
|
||||
EVENT_SCHEDULER_START = 1
|
||||
EVENT_SCHEDULER_SHUTDOWN = 2
|
||||
EVENT_EXECUTOR_ADDED = 4
|
||||
EVENT_EXECUTOR_REMOVED = 8
|
||||
EVENT_JOBSTORE_ADDED = 16
|
||||
EVENT_JOBSTORE_REMOVED = 32
|
||||
EVENT_ALL_JOBS_REMOVED = 64
|
||||
EVENT_JOB_ADDED = 128
|
||||
EVENT_JOB_REMOVED = 256
|
||||
EVENT_JOB_MODIFIED = 512
|
||||
EVENT_JOB_EXECUTED = 1024
|
||||
EVENT_JOB_ERROR = 2048
|
||||
EVENT_JOB_MISSED = 4096
|
||||
EVENT_ALL = (EVENT_SCHEDULER_START | EVENT_SCHEDULER_SHUTDOWN | EVENT_JOBSTORE_ADDED | EVENT_JOBSTORE_REMOVED |
|
||||
EVENT_JOB_ADDED | EVENT_JOB_REMOVED | EVENT_JOB_MODIFIED | EVENT_JOB_EXECUTED |
|
||||
EVENT_JOB_ERROR | EVENT_JOB_MISSED)
|
||||
|
||||
|
||||
class SchedulerEvent(object):
|
||||
"""
|
||||
An event that concerns the scheduler itself.
|
||||
|
||||
:ivar code: the type code of this event
|
||||
:ivar alias: alias of the job store or executor that was added or removed (if applicable)
|
||||
"""
|
||||
|
||||
def __init__(self, code, alias=None):
|
||||
super(SchedulerEvent, self).__init__()
|
||||
self.code = code
|
||||
self.alias = alias
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s (code=%d)>' % (self.__class__.__name__, self.code)
|
||||
|
||||
|
||||
class JobEvent(SchedulerEvent):
|
||||
"""
|
||||
An event that concerns a job.
|
||||
|
||||
:ivar code: the type code of this event
|
||||
:ivar job_id: identifier of the job in question
|
||||
:ivar jobstore: alias of the job store containing the job in question
|
||||
"""
|
||||
|
||||
def __init__(self, code, job_id, jobstore):
|
||||
super(JobEvent, self).__init__(code)
|
||||
self.code = code
|
||||
self.job_id = job_id
|
||||
self.jobstore = jobstore
|
||||
|
||||
|
||||
class JobExecutionEvent(JobEvent):
|
||||
"""
|
||||
An event that concerns the execution of individual jobs.
|
||||
|
||||
:ivar scheduled_run_time: the time when the job was scheduled to be run
|
||||
:ivar retval: the return value of the successfully executed job
|
||||
:ivar exception: the exception raised by the job
|
||||
:ivar traceback: a formatted traceback for the exception
|
||||
"""
|
||||
|
||||
def __init__(self, code, job_id, jobstore, scheduled_run_time, retval=None, exception=None, traceback=None):
|
||||
super(JobExecutionEvent, self).__init__(code, job_id, jobstore)
|
||||
self.scheduled_run_time = scheduled_run_time
|
||||
self.retval = retval
|
||||
self.exception = exception
|
||||
self.traceback = traceback
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import absolute_import
|
||||
import sys
|
||||
|
||||
from apscheduler.executors.base import BaseExecutor, run_job
|
||||
|
||||
|
||||
class AsyncIOExecutor(BaseExecutor):
|
||||
"""
|
||||
Runs jobs in the default executor of the event loop.
|
||||
|
||||
Plugin alias: ``asyncio``
|
||||
"""
|
||||
|
||||
def start(self, scheduler, alias):
|
||||
super(AsyncIOExecutor, self).start(scheduler, alias)
|
||||
self._eventloop = scheduler._eventloop
|
||||
|
||||
def _do_submit_job(self, job, run_times):
|
||||
def callback(f):
|
||||
try:
|
||||
events = f.result()
|
||||
except:
|
||||
self._run_job_error(job.id, *sys.exc_info()[1:])
|
||||
else:
|
||||
self._run_job_success(job.id, events)
|
||||
|
||||
f = self._eventloop.run_in_executor(None, run_job, job, job._jobstore_alias, run_times, self._logger.name)
|
||||
f.add_done_callback(callback)
|
||||
@@ -0,0 +1,119 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from traceback import format_tb
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from pytz import utc
|
||||
import six
|
||||
|
||||
from apscheduler.events import JobExecutionEvent, EVENT_JOB_MISSED, EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
|
||||
|
||||
|
||||
class MaxInstancesReachedError(Exception):
|
||||
def __init__(self, job):
|
||||
super(MaxInstancesReachedError, self).__init__(
|
||||
'Job "%s" has already reached its maximum number of instances (%d)' % (job.id, job.max_instances))
|
||||
|
||||
|
||||
class BaseExecutor(six.with_metaclass(ABCMeta, object)):
|
||||
"""Abstract base class that defines the interface that every executor must implement."""
|
||||
|
||||
_scheduler = None
|
||||
_lock = None
|
||||
_logger = logging.getLogger('apscheduler.executors')
|
||||
|
||||
def __init__(self):
|
||||
super(BaseExecutor, self).__init__()
|
||||
self._instances = defaultdict(lambda: 0)
|
||||
|
||||
def start(self, scheduler, alias):
|
||||
"""
|
||||
Called by the scheduler when the scheduler is being started or when the executor is being added to an already
|
||||
running scheduler.
|
||||
|
||||
:param apscheduler.schedulers.base.BaseScheduler scheduler: the scheduler that is starting this executor
|
||||
:param str|unicode alias: alias of this executor as it was assigned to the scheduler
|
||||
"""
|
||||
|
||||
self._scheduler = scheduler
|
||||
self._lock = scheduler._create_lock()
|
||||
self._logger = logging.getLogger('apscheduler.executors.%s' % alias)
|
||||
|
||||
def shutdown(self, wait=True):
|
||||
"""
|
||||
Shuts down this executor.
|
||||
|
||||
:param bool wait: ``True`` to wait until all submitted jobs have been executed
|
||||
"""
|
||||
|
||||
def submit_job(self, job, run_times):
|
||||
"""
|
||||
Submits job for execution.
|
||||
|
||||
:param Job job: job to execute
|
||||
:param list[datetime] run_times: list of datetimes specifying when the job should have been run
|
||||
:raises MaxInstancesReachedError: if the maximum number of allowed instances for this job has been reached
|
||||
"""
|
||||
|
||||
assert self._lock is not None, 'This executor has not been started yet'
|
||||
with self._lock:
|
||||
if self._instances[job.id] >= job.max_instances:
|
||||
raise MaxInstancesReachedError(job)
|
||||
|
||||
self._do_submit_job(job, run_times)
|
||||
self._instances[job.id] += 1
|
||||
|
||||
@abstractmethod
|
||||
def _do_submit_job(self, job, run_times):
|
||||
"""Performs the actual task of scheduling `run_job` to be called."""
|
||||
|
||||
def _run_job_success(self, job_id, events):
|
||||
"""Called by the executor with the list of generated events when `run_job` has been successfully called."""
|
||||
|
||||
with self._lock:
|
||||
self._instances[job_id] -= 1
|
||||
|
||||
for event in events:
|
||||
self._scheduler._dispatch_event(event)
|
||||
|
||||
def _run_job_error(self, job_id, exc, traceback=None):
|
||||
"""Called by the executor with the exception if there is an error calling `run_job`."""
|
||||
|
||||
with self._lock:
|
||||
self._instances[job_id] -= 1
|
||||
|
||||
exc_info = (exc.__class__, exc, traceback)
|
||||
self._logger.error('Error running job %s', job_id, exc_info=exc_info)
|
||||
|
||||
|
||||
def run_job(job, jobstore_alias, run_times, logger_name):
|
||||
"""Called by executors to run the job. Returns a list of scheduler events to be dispatched by the scheduler."""
|
||||
|
||||
events = []
|
||||
logger = logging.getLogger(logger_name)
|
||||
for run_time in run_times:
|
||||
# See if the job missed its run time window, and handle possible misfires accordingly
|
||||
if job.misfire_grace_time is not None:
|
||||
difference = datetime.now(utc) - run_time
|
||||
grace_time = timedelta(seconds=job.misfire_grace_time)
|
||||
if difference > grace_time:
|
||||
events.append(JobExecutionEvent(EVENT_JOB_MISSED, job.id, jobstore_alias, run_time))
|
||||
logger.warning('Run time of job "%s" was missed by %s', job, difference)
|
||||
continue
|
||||
|
||||
logger.info('Running job "%s" (scheduled at %s)', job, run_time)
|
||||
try:
|
||||
retval = job.func(*job.args, **job.kwargs)
|
||||
except:
|
||||
exc, tb = sys.exc_info()[1:]
|
||||
formatted_tb = ''.join(format_tb(tb))
|
||||
events.append(JobExecutionEvent(EVENT_JOB_ERROR, job.id, jobstore_alias, run_time, exception=exc,
|
||||
traceback=formatted_tb))
|
||||
logger.exception('Job "%s" raised an exception', job)
|
||||
else:
|
||||
events.append(JobExecutionEvent(EVENT_JOB_EXECUTED, job.id, jobstore_alias, run_time, retval=retval))
|
||||
logger.info('Job "%s" executed successfully', job)
|
||||
|
||||
return events
|
||||
@@ -0,0 +1,19 @@
|
||||
import sys
|
||||
|
||||
from apscheduler.executors.base import BaseExecutor, run_job
|
||||
|
||||
|
||||
class DebugExecutor(BaseExecutor):
|
||||
"""
|
||||
A special executor that executes the target callable directly instead of deferring it to a thread or process.
|
||||
|
||||
Plugin alias: ``debug``
|
||||
"""
|
||||
|
||||
def _do_submit_job(self, job, run_times):
|
||||
try:
|
||||
events = run_job(job, job._jobstore_alias, run_times, self._logger.name)
|
||||
except:
|
||||
self._run_job_error(job.id, *sys.exc_info()[1:])
|
||||
else:
|
||||
self._run_job_success(job.id, events)
|
||||
@@ -0,0 +1,29 @@
|
||||
from __future__ import absolute_import
|
||||
import sys
|
||||
|
||||
from apscheduler.executors.base import BaseExecutor, run_job
|
||||
|
||||
|
||||
try:
|
||||
import gevent
|
||||
except ImportError: # pragma: nocover
|
||||
raise ImportError('GeventExecutor requires gevent installed')
|
||||
|
||||
|
||||
class GeventExecutor(BaseExecutor):
|
||||
"""
|
||||
Runs jobs as greenlets.
|
||||
|
||||
Plugin alias: ``gevent``
|
||||
"""
|
||||
|
||||
def _do_submit_job(self, job, run_times):
|
||||
def callback(greenlet):
|
||||
try:
|
||||
events = greenlet.get()
|
||||
except:
|
||||
self._run_job_error(job.id, *sys.exc_info()[1:])
|
||||
else:
|
||||
self._run_job_success(job.id, events)
|
||||
|
||||
gevent.spawn(run_job, job, job._jobstore_alias, run_times, self._logger.name).link(callback)
|
||||
@@ -0,0 +1,54 @@
|
||||
from abc import abstractmethod
|
||||
import concurrent.futures
|
||||
|
||||
from apscheduler.executors.base import BaseExecutor, run_job
|
||||
|
||||
|
||||
class BasePoolExecutor(BaseExecutor):
|
||||
@abstractmethod
|
||||
def __init__(self, pool):
|
||||
super(BasePoolExecutor, self).__init__()
|
||||
self._pool = pool
|
||||
|
||||
def _do_submit_job(self, job, run_times):
|
||||
def callback(f):
|
||||
exc, tb = (f.exception_info() if hasattr(f, 'exception_info') else
|
||||
(f.exception(), getattr(f.exception(), '__traceback__', None)))
|
||||
if exc:
|
||||
self._run_job_error(job.id, exc, tb)
|
||||
else:
|
||||
self._run_job_success(job.id, f.result())
|
||||
|
||||
f = self._pool.submit(run_job, job, job._jobstore_alias, run_times, self._logger.name)
|
||||
f.add_done_callback(callback)
|
||||
|
||||
def shutdown(self, wait=True):
|
||||
self._pool.shutdown(wait)
|
||||
|
||||
|
||||
class ThreadPoolExecutor(BasePoolExecutor):
|
||||
"""
|
||||
An executor that runs jobs in a concurrent.futures thread pool.
|
||||
|
||||
Plugin alias: ``threadpool``
|
||||
|
||||
:param max_workers: the maximum number of spawned threads.
|
||||
"""
|
||||
|
||||
def __init__(self, max_workers=10):
|
||||
pool = concurrent.futures.ThreadPoolExecutor(int(max_workers))
|
||||
super(ThreadPoolExecutor, self).__init__(pool)
|
||||
|
||||
|
||||
class ProcessPoolExecutor(BasePoolExecutor):
|
||||
"""
|
||||
An executor that runs jobs in a concurrent.futures process pool.
|
||||
|
||||
Plugin alias: ``processpool``
|
||||
|
||||
:param max_workers: the maximum number of spawned processes.
|
||||
"""
|
||||
|
||||
def __init__(self, max_workers=10):
|
||||
pool = concurrent.futures.ProcessPoolExecutor(int(max_workers))
|
||||
super(ProcessPoolExecutor, self).__init__(pool)
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from apscheduler.executors.base import BaseExecutor, run_job
|
||||
|
||||
|
||||
class TwistedExecutor(BaseExecutor):
|
||||
"""
|
||||
Runs jobs in the reactor's thread pool.
|
||||
|
||||
Plugin alias: ``twisted``
|
||||
"""
|
||||
|
||||
def start(self, scheduler, alias):
|
||||
super(TwistedExecutor, self).start(scheduler, alias)
|
||||
self._reactor = scheduler._reactor
|
||||
|
||||
def _do_submit_job(self, job, run_times):
|
||||
def callback(success, result):
|
||||
if success:
|
||||
self._run_job_success(job.id, result)
|
||||
else:
|
||||
self._run_job_error(job.id, result.value, result.tb)
|
||||
|
||||
self._reactor.getThreadPool().callInThreadWithCallback(callback, run_job, job, job._jobstore_alias, run_times,
|
||||
self._logger.name)
|
||||
@@ -0,0 +1,252 @@
|
||||
from collections import Iterable, Mapping
|
||||
from uuid import uuid4
|
||||
|
||||
import six
|
||||
|
||||
from apscheduler.triggers.base import BaseTrigger
|
||||
from apscheduler.util import ref_to_obj, obj_to_ref, datetime_repr, repr_escape, get_callable_name, check_callable_args, \
|
||||
convert_to_datetime
|
||||
|
||||
|
||||
class Job(object):
|
||||
"""
|
||||
Contains the options given when scheduling callables and its current schedule and other state.
|
||||
This class should never be instantiated by the user.
|
||||
|
||||
:var str id: the unique identifier of this job
|
||||
:var str name: the description of this job
|
||||
:var func: the callable to execute
|
||||
:var tuple|list args: positional arguments to the callable
|
||||
:var dict kwargs: keyword arguments to the callable
|
||||
:var bool coalesce: whether to only run the job once when several run times are due
|
||||
:var trigger: the trigger object that controls the schedule of this job
|
||||
:var str executor: the name of the executor that will run this job
|
||||
:var int misfire_grace_time: the time (in seconds) how much this job's execution is allowed to be late
|
||||
:var int max_instances: the maximum number of concurrently executing instances allowed for this job
|
||||
:var datetime.datetime next_run_time: the next scheduled run time of this job
|
||||
"""
|
||||
|
||||
__slots__ = ('_scheduler', '_jobstore_alias', 'id', 'trigger', 'executor', 'func', 'func_ref', 'args', 'kwargs',
|
||||
'name', 'misfire_grace_time', 'coalesce', 'max_instances', 'next_run_time')
|
||||
|
||||
def __init__(self, scheduler, id=None, **kwargs):
|
||||
super(Job, self).__init__()
|
||||
self._scheduler = scheduler
|
||||
self._jobstore_alias = None
|
||||
self._modify(id=id or uuid4().hex, **kwargs)
|
||||
|
||||
def modify(self, **changes):
|
||||
"""
|
||||
Makes the given changes to this job and saves it in the associated job store.
|
||||
Accepted keyword arguments are the same as the variables on this class.
|
||||
|
||||
.. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.modify_job`
|
||||
"""
|
||||
|
||||
self._scheduler.modify_job(self.id, self._jobstore_alias, **changes)
|
||||
|
||||
def reschedule(self, trigger, **trigger_args):
|
||||
"""
|
||||
Shortcut for switching the trigger on this job.
|
||||
|
||||
.. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.reschedule_job`
|
||||
"""
|
||||
|
||||
self._scheduler.reschedule_job(self.id, self._jobstore_alias, trigger, **trigger_args)
|
||||
|
||||
def pause(self):
|
||||
"""
|
||||
Temporarily suspend the execution of this job.
|
||||
|
||||
.. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.pause_job`
|
||||
"""
|
||||
|
||||
self._scheduler.pause_job(self.id, self._jobstore_alias)
|
||||
|
||||
def resume(self):
|
||||
"""
|
||||
Resume the schedule of this job if previously paused.
|
||||
|
||||
.. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.resume_job`
|
||||
"""
|
||||
|
||||
self._scheduler.resume_job(self.id, self._jobstore_alias)
|
||||
|
||||
def remove(self):
|
||||
"""
|
||||
Unschedules this job and removes it from its associated job store.
|
||||
|
||||
.. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.remove_job`
|
||||
"""
|
||||
|
||||
self._scheduler.remove_job(self.id, self._jobstore_alias)
|
||||
|
||||
@property
|
||||
def pending(self):
|
||||
"""Returns ``True`` if the referenced job is still waiting to be added to its designated job store."""
|
||||
|
||||
return self._jobstore_alias is None
|
||||
|
||||
#
|
||||
# Private API
|
||||
#
|
||||
|
||||
def _get_run_times(self, now):
|
||||
"""
|
||||
Computes the scheduled run times between ``next_run_time`` and ``now`` (inclusive).
|
||||
|
||||
:type now: datetime.datetime
|
||||
:rtype: list[datetime.datetime]
|
||||
"""
|
||||
|
||||
run_times = []
|
||||
next_run_time = self.next_run_time
|
||||
while next_run_time and next_run_time <= now:
|
||||
run_times.append(next_run_time)
|
||||
next_run_time = self.trigger.get_next_fire_time(next_run_time, now)
|
||||
|
||||
return run_times
|
||||
|
||||
def _modify(self, **changes):
|
||||
"""Validates the changes to the Job and makes the modifications if and only if all of them validate."""
|
||||
|
||||
approved = {}
|
||||
|
||||
if 'id' in changes:
|
||||
value = changes.pop('id')
|
||||
if not isinstance(value, six.string_types):
|
||||
raise TypeError("id must be a nonempty string")
|
||||
if hasattr(self, 'id'):
|
||||
raise ValueError('The job ID may not be changed')
|
||||
approved['id'] = value
|
||||
|
||||
if 'func' in changes or 'args' in changes or 'kwargs' in changes:
|
||||
func = changes.pop('func') if 'func' in changes else self.func
|
||||
args = changes.pop('args') if 'args' in changes else self.args
|
||||
kwargs = changes.pop('kwargs') if 'kwargs' in changes else self.kwargs
|
||||
|
||||
if isinstance(func, str):
|
||||
func_ref = func
|
||||
func = ref_to_obj(func)
|
||||
elif callable(func):
|
||||
try:
|
||||
func_ref = obj_to_ref(func)
|
||||
except ValueError:
|
||||
# If this happens, this Job won't be serializable
|
||||
func_ref = None
|
||||
else:
|
||||
raise TypeError('func must be a callable or a textual reference to one')
|
||||
|
||||
if not hasattr(self, 'name') and changes.get('name', None) is None:
|
||||
changes['name'] = get_callable_name(func)
|
||||
|
||||
if isinstance(args, six.string_types) or not isinstance(args, Iterable):
|
||||
raise TypeError('args must be a non-string iterable')
|
||||
if isinstance(kwargs, six.string_types) or not isinstance(kwargs, Mapping):
|
||||
raise TypeError('kwargs must be a dict-like object')
|
||||
|
||||
check_callable_args(func, args, kwargs)
|
||||
|
||||
approved['func'] = func
|
||||
approved['func_ref'] = func_ref
|
||||
approved['args'] = args
|
||||
approved['kwargs'] = kwargs
|
||||
|
||||
if 'name' in changes:
|
||||
value = changes.pop('name')
|
||||
if not value or not isinstance(value, six.string_types):
|
||||
raise TypeError("name must be a nonempty string")
|
||||
approved['name'] = value
|
||||
|
||||
if 'misfire_grace_time' in changes:
|
||||
value = changes.pop('misfire_grace_time')
|
||||
if value is not None and (not isinstance(value, six.integer_types) or value <= 0):
|
||||
raise TypeError('misfire_grace_time must be either None or a positive integer')
|
||||
approved['misfire_grace_time'] = value
|
||||
|
||||
if 'coalesce' in changes:
|
||||
value = bool(changes.pop('coalesce'))
|
||||
approved['coalesce'] = value
|
||||
|
||||
if 'max_instances' in changes:
|
||||
value = changes.pop('max_instances')
|
||||
if not isinstance(value, six.integer_types) or value <= 0:
|
||||
raise TypeError('max_instances must be a positive integer')
|
||||
approved['max_instances'] = value
|
||||
|
||||
if 'trigger' in changes:
|
||||
trigger = changes.pop('trigger')
|
||||
if not isinstance(trigger, BaseTrigger):
|
||||
raise TypeError('Expected a trigger instance, got %s instead' % trigger.__class__.__name__)
|
||||
|
||||
approved['trigger'] = trigger
|
||||
|
||||
if 'executor' in changes:
|
||||
value = changes.pop('executor')
|
||||
if not isinstance(value, six.string_types):
|
||||
raise TypeError('executor must be a string')
|
||||
approved['executor'] = value
|
||||
|
||||
if 'next_run_time' in changes:
|
||||
value = changes.pop('next_run_time')
|
||||
approved['next_run_time'] = convert_to_datetime(value, self._scheduler.timezone, 'next_run_time')
|
||||
|
||||
if changes:
|
||||
raise AttributeError('The following are not modifiable attributes of Job: %s' % ', '.join(changes))
|
||||
|
||||
for key, value in six.iteritems(approved):
|
||||
setattr(self, key, value)
|
||||
|
||||
def __getstate__(self):
|
||||
# Don't allow this Job to be serialized if the function reference could not be determined
|
||||
if not self.func_ref:
|
||||
raise ValueError('This Job cannot be serialized since the reference to its callable (%r) could not be '
|
||||
'determined. Consider giving a textual reference (module:function name) instead.' %
|
||||
(self.func,))
|
||||
|
||||
return {
|
||||
'version': 1,
|
||||
'id': self.id,
|
||||
'func': self.func_ref,
|
||||
'trigger': self.trigger,
|
||||
'executor': self.executor,
|
||||
'args': self.args,
|
||||
'kwargs': self.kwargs,
|
||||
'name': self.name,
|
||||
'misfire_grace_time': self.misfire_grace_time,
|
||||
'coalesce': self.coalesce,
|
||||
'max_instances': self.max_instances,
|
||||
'next_run_time': self.next_run_time
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
if state.get('version', 1) > 1:
|
||||
raise ValueError('Job has version %s, but only version 1 can be handled' % state['version'])
|
||||
|
||||
self.id = state['id']
|
||||
self.func_ref = state['func']
|
||||
self.func = ref_to_obj(self.func_ref)
|
||||
self.trigger = state['trigger']
|
||||
self.executor = state['executor']
|
||||
self.args = state['args']
|
||||
self.kwargs = state['kwargs']
|
||||
self.name = state['name']
|
||||
self.misfire_grace_time = state['misfire_grace_time']
|
||||
self.coalesce = state['coalesce']
|
||||
self.max_instances = state['max_instances']
|
||||
self.next_run_time = state['next_run_time']
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Job):
|
||||
return self.id == other.id
|
||||
return NotImplemented
|
||||
|
||||
def __repr__(self):
|
||||
return '<Job (id=%s name=%s)>' % (repr_escape(self.id), repr_escape(self.name))
|
||||
|
||||
def __str__(self):
|
||||
return '%s (trigger: %s, next run at: %s)' % (repr_escape(self.name), repr_escape(str(self.trigger)),
|
||||
datetime_repr(self.next_run_time))
|
||||
|
||||
def __unicode__(self):
|
||||
return six.u('%s (trigger: %s, next run at: %s)') % (self.name, self.trigger, datetime_repr(self.next_run_time))
|
||||
@@ -0,0 +1,127 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import logging
|
||||
|
||||
import six
|
||||
|
||||
|
||||
class JobLookupError(KeyError):
|
||||
"""Raised when the job store cannot find a job for update or removal."""
|
||||
|
||||
def __init__(self, job_id):
|
||||
super(JobLookupError, self).__init__(six.u('No job by the id of %s was found') % job_id)
|
||||
|
||||
|
||||
class ConflictingIdError(KeyError):
|
||||
"""Raised when the uniqueness of job IDs is being violated."""
|
||||
|
||||
def __init__(self, job_id):
|
||||
super(ConflictingIdError, self).__init__(six.u('Job identifier (%s) conflicts with an existing job') % job_id)
|
||||
|
||||
|
||||
class TransientJobError(ValueError):
|
||||
"""Raised when an attempt to add transient (with no func_ref) job to a persistent job store is detected."""
|
||||
|
||||
def __init__(self, job_id):
|
||||
super(TransientJobError, self).__init__(
|
||||
six.u('Job (%s) cannot be added to this job store because a reference to the callable could not be '
|
||||
'determined.') % job_id)
|
||||
|
||||
|
||||
class BaseJobStore(six.with_metaclass(ABCMeta)):
|
||||
"""Abstract base class that defines the interface that every job store must implement."""
|
||||
|
||||
_scheduler = None
|
||||
_alias = None
|
||||
_logger = logging.getLogger('apscheduler.jobstores')
|
||||
|
||||
def start(self, scheduler, alias):
|
||||
"""
|
||||
Called by the scheduler when the scheduler is being started or when the job store is being added to an already
|
||||
running scheduler.
|
||||
|
||||
:param apscheduler.schedulers.base.BaseScheduler scheduler: the scheduler that is starting this job store
|
||||
:param str|unicode alias: alias of this job store as it was assigned to the scheduler
|
||||
"""
|
||||
|
||||
self._scheduler = scheduler
|
||||
self._alias = alias
|
||||
self._logger = logging.getLogger('apscheduler.jobstores.%s' % alias)
|
||||
|
||||
def shutdown(self):
|
||||
"""Frees any resources still bound to this job store."""
|
||||
|
||||
@abstractmethod
|
||||
def lookup_job(self, job_id):
|
||||
"""
|
||||
Returns a specific job, or ``None`` if it isn't found..
|
||||
|
||||
The job store is responsible for setting the ``scheduler`` and ``jobstore`` attributes of the returned job to
|
||||
point to the scheduler and itself, respectively.
|
||||
|
||||
:param str|unicode job_id: identifier of the job
|
||||
:rtype: Job
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_due_jobs(self, now):
|
||||
"""
|
||||
Returns the list of jobs that have ``next_run_time`` earlier or equal to ``now``.
|
||||
The returned jobs must be sorted by next run time (ascending).
|
||||
|
||||
:param datetime.datetime now: the current (timezone aware) datetime
|
||||
:rtype: list[Job]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_next_run_time(self):
|
||||
"""
|
||||
Returns the earliest run time of all the jobs stored in this job store, or ``None`` if there are no active jobs.
|
||||
|
||||
:rtype: datetime.datetime
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_all_jobs(self):
|
||||
"""
|
||||
Returns a list of all jobs in this job store. The returned jobs should be sorted by next run time (ascending).
|
||||
Paused jobs (next_run_time is None) should be sorted last.
|
||||
|
||||
The job store is responsible for setting the ``scheduler`` and ``jobstore`` attributes of the returned jobs to
|
||||
point to the scheduler and itself, respectively.
|
||||
|
||||
:rtype: list[Job]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_job(self, job):
|
||||
"""
|
||||
Adds the given job to this store.
|
||||
|
||||
:param Job job: the job to add
|
||||
:raises ConflictingIdError: if there is another job in this store with the same ID
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def update_job(self, job):
|
||||
"""
|
||||
Replaces the job in the store with the given newer version.
|
||||
|
||||
:param Job job: the job to update
|
||||
:raises JobLookupError: if the job does not exist
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def remove_job(self, job_id):
|
||||
"""
|
||||
Removes the given job from this store.
|
||||
|
||||
:param str|unicode job_id: identifier of the job
|
||||
:raises JobLookupError: if the job does not exist
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def remove_all_jobs(self):
|
||||
"""Removes all jobs from this store."""
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s>' % self.__class__.__name__
|
||||
@@ -0,0 +1,107 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError
|
||||
from apscheduler.util import datetime_to_utc_timestamp
|
||||
|
||||
|
||||
class MemoryJobStore(BaseJobStore):
|
||||
"""
|
||||
Stores jobs in an array in RAM. Provides no persistence support.
|
||||
|
||||
Plugin alias: ``memory``
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(MemoryJobStore, self).__init__()
|
||||
self._jobs = [] # list of (job, timestamp), sorted by next_run_time and job id (ascending)
|
||||
self._jobs_index = {} # id -> (job, timestamp) lookup table
|
||||
|
||||
def lookup_job(self, job_id):
|
||||
return self._jobs_index.get(job_id, (None, None))[0]
|
||||
|
||||
def get_due_jobs(self, now):
|
||||
now_timestamp = datetime_to_utc_timestamp(now)
|
||||
pending = []
|
||||
for job, timestamp in self._jobs:
|
||||
if timestamp is None or timestamp > now_timestamp:
|
||||
break
|
||||
pending.append(job)
|
||||
|
||||
return pending
|
||||
|
||||
def get_next_run_time(self):
|
||||
return self._jobs[0][0].next_run_time if self._jobs else None
|
||||
|
||||
def get_all_jobs(self):
|
||||
return [j[0] for j in self._jobs]
|
||||
|
||||
def add_job(self, job):
|
||||
if job.id in self._jobs_index:
|
||||
raise ConflictingIdError(job.id)
|
||||
|
||||
timestamp = datetime_to_utc_timestamp(job.next_run_time)
|
||||
index = self._get_job_index(timestamp, job.id)
|
||||
self._jobs.insert(index, (job, timestamp))
|
||||
self._jobs_index[job.id] = (job, timestamp)
|
||||
|
||||
def update_job(self, job):
|
||||
old_job, old_timestamp = self._jobs_index.get(job.id, (None, None))
|
||||
if old_job is None:
|
||||
raise JobLookupError(job.id)
|
||||
|
||||
# If the next run time has not changed, simply replace the job in its present index.
|
||||
# Otherwise, reinsert the job to the list to preserve the ordering.
|
||||
old_index = self._get_job_index(old_timestamp, old_job.id)
|
||||
new_timestamp = datetime_to_utc_timestamp(job.next_run_time)
|
||||
if old_timestamp == new_timestamp:
|
||||
self._jobs[old_index] = (job, new_timestamp)
|
||||
else:
|
||||
del self._jobs[old_index]
|
||||
new_index = self._get_job_index(new_timestamp, job.id)
|
||||
self._jobs.insert(new_index, (job, new_timestamp))
|
||||
|
||||
self._jobs_index[old_job.id] = (job, new_timestamp)
|
||||
|
||||
def remove_job(self, job_id):
|
||||
job, timestamp = self._jobs_index.get(job_id, (None, None))
|
||||
if job is None:
|
||||
raise JobLookupError(job_id)
|
||||
|
||||
index = self._get_job_index(timestamp, job_id)
|
||||
del self._jobs[index]
|
||||
del self._jobs_index[job.id]
|
||||
|
||||
def remove_all_jobs(self):
|
||||
self._jobs = []
|
||||
self._jobs_index = {}
|
||||
|
||||
def shutdown(self):
|
||||
self.remove_all_jobs()
|
||||
|
||||
def _get_job_index(self, timestamp, job_id):
|
||||
"""
|
||||
Returns the index of the given job, or if it's not found, the index where the job should be inserted based on
|
||||
the given timestamp.
|
||||
|
||||
:type timestamp: int
|
||||
:type job_id: str
|
||||
"""
|
||||
|
||||
lo, hi = 0, len(self._jobs)
|
||||
timestamp = float('inf') if timestamp is None else timestamp
|
||||
while lo < hi:
|
||||
mid = (lo + hi) // 2
|
||||
mid_job, mid_timestamp = self._jobs[mid]
|
||||
mid_timestamp = float('inf') if mid_timestamp is None else mid_timestamp
|
||||
if mid_timestamp > timestamp:
|
||||
hi = mid
|
||||
elif mid_timestamp < timestamp:
|
||||
lo = mid + 1
|
||||
elif mid_job.id > job_id:
|
||||
hi = mid
|
||||
elif mid_job.id < job_id:
|
||||
lo = mid + 1
|
||||
else:
|
||||
return mid
|
||||
|
||||
return lo
|
||||
@@ -0,0 +1,124 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError
|
||||
from apscheduler.util import maybe_ref, datetime_to_utc_timestamp, utc_timestamp_to_datetime
|
||||
from apscheduler.job import Job
|
||||
|
||||
try:
|
||||
import cPickle as pickle
|
||||
except ImportError: # pragma: nocover
|
||||
import pickle
|
||||
|
||||
try:
|
||||
from bson.binary import Binary
|
||||
from pymongo.errors import DuplicateKeyError
|
||||
from pymongo import MongoClient, ASCENDING
|
||||
except ImportError: # pragma: nocover
|
||||
raise ImportError('MongoDBJobStore requires PyMongo installed')
|
||||
|
||||
|
||||
class MongoDBJobStore(BaseJobStore):
|
||||
"""
|
||||
Stores jobs in a MongoDB database. Any leftover keyword arguments are directly passed to pymongo's `MongoClient
|
||||
<http://api.mongodb.org/python/current/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient>`_.
|
||||
|
||||
Plugin alias: ``mongodb``
|
||||
|
||||
:param str database: database to store jobs in
|
||||
:param str collection: collection to store jobs in
|
||||
:param client: a :class:`~pymongo.mongo_client.MongoClient` instance to use instead of providing connection
|
||||
arguments
|
||||
:param int pickle_protocol: pickle protocol level to use (for serialization), defaults to the highest available
|
||||
"""
|
||||
|
||||
def __init__(self, database='apscheduler', collection='jobs', client=None,
|
||||
pickle_protocol=pickle.HIGHEST_PROTOCOL, **connect_args):
|
||||
super(MongoDBJobStore, self).__init__()
|
||||
self.pickle_protocol = pickle_protocol
|
||||
|
||||
if not database:
|
||||
raise ValueError('The "database" parameter must not be empty')
|
||||
if not collection:
|
||||
raise ValueError('The "collection" parameter must not be empty')
|
||||
|
||||
if client:
|
||||
self.connection = maybe_ref(client)
|
||||
else:
|
||||
connect_args.setdefault('w', 1)
|
||||
self.connection = MongoClient(**connect_args)
|
||||
|
||||
self.collection = self.connection[database][collection]
|
||||
self.collection.ensure_index('next_run_time', sparse=True)
|
||||
|
||||
def lookup_job(self, job_id):
|
||||
document = self.collection.find_one(job_id, ['job_state'])
|
||||
return self._reconstitute_job(document['job_state']) if document else None
|
||||
|
||||
def get_due_jobs(self, now):
|
||||
timestamp = datetime_to_utc_timestamp(now)
|
||||
return self._get_jobs({'next_run_time': {'$lte': timestamp}})
|
||||
|
||||
def get_next_run_time(self):
|
||||
document = self.collection.find_one({'next_run_time': {'$ne': None}}, fields=['next_run_time'],
|
||||
sort=[('next_run_time', ASCENDING)])
|
||||
return utc_timestamp_to_datetime(document['next_run_time']) if document else None
|
||||
|
||||
def get_all_jobs(self):
|
||||
return self._get_jobs({})
|
||||
|
||||
def add_job(self, job):
|
||||
try:
|
||||
self.collection.insert({
|
||||
'_id': job.id,
|
||||
'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
|
||||
'job_state': Binary(pickle.dumps(job.__getstate__(), self.pickle_protocol))
|
||||
})
|
||||
except DuplicateKeyError:
|
||||
raise ConflictingIdError(job.id)
|
||||
|
||||
def update_job(self, job):
|
||||
changes = {
|
||||
'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
|
||||
'job_state': Binary(pickle.dumps(job.__getstate__(), self.pickle_protocol))
|
||||
}
|
||||
result = self.collection.update({'_id': job.id}, {'$set': changes})
|
||||
if result and result['n'] == 0:
|
||||
raise JobLookupError(id)
|
||||
|
||||
def remove_job(self, job_id):
|
||||
result = self.collection.remove(job_id)
|
||||
if result and result['n'] == 0:
|
||||
raise JobLookupError(job_id)
|
||||
|
||||
def remove_all_jobs(self):
|
||||
self.collection.remove()
|
||||
|
||||
def shutdown(self):
|
||||
self.connection.disconnect()
|
||||
|
||||
def _reconstitute_job(self, job_state):
|
||||
job_state = pickle.loads(job_state)
|
||||
job = Job.__new__(Job)
|
||||
job.__setstate__(job_state)
|
||||
job._scheduler = self._scheduler
|
||||
job._jobstore_alias = self._alias
|
||||
return job
|
||||
|
||||
def _get_jobs(self, conditions):
|
||||
jobs = []
|
||||
failed_job_ids = []
|
||||
for document in self.collection.find(conditions, ['_id', 'job_state'], sort=[('next_run_time', ASCENDING)]):
|
||||
try:
|
||||
jobs.append(self._reconstitute_job(document['job_state']))
|
||||
except:
|
||||
self._logger.exception('Unable to restore job "%s" -- removing it', document['_id'])
|
||||
failed_job_ids.append(document['_id'])
|
||||
|
||||
# Remove all the jobs we failed to restore
|
||||
if failed_job_ids:
|
||||
self.collection.remove({'_id': {'$in': failed_job_ids}})
|
||||
|
||||
return jobs
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s (client=%s)>' % (self.__class__.__name__, self.connection)
|
||||
@@ -0,0 +1,138 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import six
|
||||
|
||||
from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError
|
||||
from apscheduler.util import datetime_to_utc_timestamp, utc_timestamp_to_datetime
|
||||
from apscheduler.job import Job
|
||||
|
||||
try:
|
||||
import cPickle as pickle
|
||||
except ImportError: # pragma: nocover
|
||||
import pickle
|
||||
|
||||
try:
|
||||
from redis import StrictRedis
|
||||
except ImportError: # pragma: nocover
|
||||
raise ImportError('RedisJobStore requires redis installed')
|
||||
|
||||
|
||||
class RedisJobStore(BaseJobStore):
|
||||
"""
|
||||
Stores jobs in a Redis database. Any leftover keyword arguments are directly passed to redis's StrictRedis.
|
||||
|
||||
Plugin alias: ``redis``
|
||||
|
||||
:param int db: the database number to store jobs in
|
||||
:param str jobs_key: key to store jobs in
|
||||
:param str run_times_key: key to store the jobs' run times in
|
||||
:param int pickle_protocol: pickle protocol level to use (for serialization), defaults to the highest available
|
||||
"""
|
||||
|
||||
def __init__(self, db=0, jobs_key='apscheduler.jobs', run_times_key='apscheduler.run_times',
|
||||
pickle_protocol=pickle.HIGHEST_PROTOCOL, **connect_args):
|
||||
super(RedisJobStore, self).__init__()
|
||||
|
||||
if db is None:
|
||||
raise ValueError('The "db" parameter must not be empty')
|
||||
if not jobs_key:
|
||||
raise ValueError('The "jobs_key" parameter must not be empty')
|
||||
if not run_times_key:
|
||||
raise ValueError('The "run_times_key" parameter must not be empty')
|
||||
|
||||
self.pickle_protocol = pickle_protocol
|
||||
self.jobs_key = jobs_key
|
||||
self.run_times_key = run_times_key
|
||||
self.redis = StrictRedis(db=int(db), **connect_args)
|
||||
|
||||
def lookup_job(self, job_id):
|
||||
job_state = self.redis.hget(self.jobs_key, job_id)
|
||||
return self._reconstitute_job(job_state) if job_state else None
|
||||
|
||||
def get_due_jobs(self, now):
|
||||
timestamp = datetime_to_utc_timestamp(now)
|
||||
job_ids = self.redis.zrangebyscore(self.run_times_key, 0, timestamp)
|
||||
if job_ids:
|
||||
job_states = self.redis.hmget(self.jobs_key, *job_ids)
|
||||
return self._reconstitute_jobs(six.moves.zip(job_ids, job_states))
|
||||
return []
|
||||
|
||||
def get_next_run_time(self):
|
||||
next_run_time = self.redis.zrange(self.run_times_key, 0, 0, withscores=True)
|
||||
if next_run_time:
|
||||
return utc_timestamp_to_datetime(next_run_time[0][1])
|
||||
|
||||
def get_all_jobs(self):
|
||||
job_states = self.redis.hgetall(self.jobs_key)
|
||||
jobs = self._reconstitute_jobs(six.iteritems(job_states))
|
||||
return sorted(jobs, key=lambda job: job.next_run_time)
|
||||
|
||||
def add_job(self, job):
|
||||
if self.redis.hexists(self.jobs_key, job.id):
|
||||
raise ConflictingIdError(job.id)
|
||||
|
||||
with self.redis.pipeline() as pipe:
|
||||
pipe.multi()
|
||||
pipe.hset(self.jobs_key, job.id, pickle.dumps(job.__getstate__(), self.pickle_protocol))
|
||||
pipe.zadd(self.run_times_key, datetime_to_utc_timestamp(job.next_run_time), job.id)
|
||||
pipe.execute()
|
||||
|
||||
def update_job(self, job):
|
||||
if not self.redis.hexists(self.jobs_key, job.id):
|
||||
raise JobLookupError(job.id)
|
||||
|
||||
with self.redis.pipeline() as pipe:
|
||||
pipe.hset(self.jobs_key, job.id, pickle.dumps(job.__getstate__(), self.pickle_protocol))
|
||||
if job.next_run_time:
|
||||
pipe.zadd(self.run_times_key, datetime_to_utc_timestamp(job.next_run_time), job.id)
|
||||
else:
|
||||
pipe.zrem(self.run_times_key, job.id)
|
||||
pipe.execute()
|
||||
|
||||
def remove_job(self, job_id):
|
||||
if not self.redis.hexists(self.jobs_key, job_id):
|
||||
raise JobLookupError(job_id)
|
||||
|
||||
with self.redis.pipeline() as pipe:
|
||||
pipe.hdel(self.jobs_key, job_id)
|
||||
pipe.zrem(self.run_times_key, job_id)
|
||||
pipe.execute()
|
||||
|
||||
def remove_all_jobs(self):
|
||||
with self.redis.pipeline() as pipe:
|
||||
pipe.delete(self.jobs_key)
|
||||
pipe.delete(self.run_times_key)
|
||||
pipe.execute()
|
||||
|
||||
def shutdown(self):
|
||||
self.redis.connection_pool.disconnect()
|
||||
|
||||
def _reconstitute_job(self, job_state):
|
||||
job_state = pickle.loads(job_state)
|
||||
job = Job.__new__(Job)
|
||||
job.__setstate__(job_state)
|
||||
job._scheduler = self._scheduler
|
||||
job._jobstore_alias = self._alias
|
||||
return job
|
||||
|
||||
def _reconstitute_jobs(self, job_states):
|
||||
jobs = []
|
||||
failed_job_ids = []
|
||||
for job_id, job_state in job_states:
|
||||
try:
|
||||
jobs.append(self._reconstitute_job(job_state))
|
||||
except:
|
||||
self._logger.exception('Unable to restore job "%s" -- removing it', job_id)
|
||||
failed_job_ids.append(job_id)
|
||||
|
||||
# Remove all the jobs we failed to restore
|
||||
if failed_job_ids:
|
||||
with self.redis.pipeline() as pipe:
|
||||
pipe.hdel(self.jobs_key, *failed_job_ids)
|
||||
pipe.zrem(self.run_times_key, *failed_job_ids)
|
||||
pipe.execute()
|
||||
|
||||
return jobs
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s>' % self.__class__.__name__
|
||||
@@ -0,0 +1,137 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError
|
||||
from apscheduler.util import maybe_ref, datetime_to_utc_timestamp, utc_timestamp_to_datetime
|
||||
from apscheduler.job import Job
|
||||
|
||||
try:
|
||||
import cPickle as pickle
|
||||
except ImportError: # pragma: nocover
|
||||
import pickle
|
||||
|
||||
try:
|
||||
from sqlalchemy import create_engine, Table, Column, MetaData, Unicode, Float, LargeBinary, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
except ImportError: # pragma: nocover
|
||||
raise ImportError('SQLAlchemyJobStore requires SQLAlchemy installed')
|
||||
|
||||
|
||||
class SQLAlchemyJobStore(BaseJobStore):
|
||||
"""
|
||||
Stores jobs in a database table using SQLAlchemy. The table will be created if it doesn't exist in the database.
|
||||
|
||||
Plugin alias: ``sqlalchemy``
|
||||
|
||||
:param str url: connection string (see `SQLAlchemy documentation
|
||||
<http://docs.sqlalchemy.org/en/latest/core/engines.html?highlight=create_engine#database-urls>`_
|
||||
on this)
|
||||
:param engine: an SQLAlchemy Engine to use instead of creating a new one based on ``url``
|
||||
:param str tablename: name of the table to store jobs in
|
||||
:param metadata: a :class:`~sqlalchemy.MetaData` instance to use instead of creating a new one
|
||||
:param int pickle_protocol: pickle protocol level to use (for serialization), defaults to the highest available
|
||||
"""
|
||||
|
||||
def __init__(self, url=None, engine=None, tablename='apscheduler_jobs', metadata=None,
|
||||
pickle_protocol=pickle.HIGHEST_PROTOCOL):
|
||||
super(SQLAlchemyJobStore, self).__init__()
|
||||
self.pickle_protocol = pickle_protocol
|
||||
metadata = maybe_ref(metadata) or MetaData()
|
||||
|
||||
if engine:
|
||||
self.engine = maybe_ref(engine)
|
||||
elif url:
|
||||
self.engine = create_engine(url)
|
||||
else:
|
||||
raise ValueError('Need either "engine" or "url" defined')
|
||||
|
||||
# 191 = max key length in MySQL for InnoDB/utf8mb4 tables, 25 = precision that translates to an 8-byte float
|
||||
self.jobs_t = Table(
|
||||
tablename, metadata,
|
||||
Column('id', Unicode(191, _warn_on_bytestring=False), primary_key=True),
|
||||
Column('next_run_time', Float(25), index=True),
|
||||
Column('job_state', LargeBinary, nullable=False)
|
||||
)
|
||||
|
||||
self.jobs_t.create(self.engine, True)
|
||||
|
||||
def lookup_job(self, job_id):
|
||||
selectable = select([self.jobs_t.c.job_state]).where(self.jobs_t.c.id == job_id)
|
||||
job_state = self.engine.execute(selectable).scalar()
|
||||
return self._reconstitute_job(job_state) if job_state else None
|
||||
|
||||
def get_due_jobs(self, now):
|
||||
timestamp = datetime_to_utc_timestamp(now)
|
||||
return self._get_jobs(self.jobs_t.c.next_run_time <= timestamp)
|
||||
|
||||
def get_next_run_time(self):
|
||||
selectable = select([self.jobs_t.c.next_run_time]).where(self.jobs_t.c.next_run_time != None).\
|
||||
order_by(self.jobs_t.c.next_run_time).limit(1)
|
||||
next_run_time = self.engine.execute(selectable).scalar()
|
||||
return utc_timestamp_to_datetime(next_run_time)
|
||||
|
||||
def get_all_jobs(self):
|
||||
return self._get_jobs()
|
||||
|
||||
def add_job(self, job):
|
||||
insert = self.jobs_t.insert().values(**{
|
||||
'id': job.id,
|
||||
'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
|
||||
'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol)
|
||||
})
|
||||
try:
|
||||
self.engine.execute(insert)
|
||||
except IntegrityError:
|
||||
raise ConflictingIdError(job.id)
|
||||
|
||||
def update_job(self, job):
|
||||
update = self.jobs_t.update().values(**{
|
||||
'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
|
||||
'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol)
|
||||
}).where(self.jobs_t.c.id == job.id)
|
||||
result = self.engine.execute(update)
|
||||
if result.rowcount == 0:
|
||||
raise JobLookupError(id)
|
||||
|
||||
def remove_job(self, job_id):
|
||||
delete = self.jobs_t.delete().where(self.jobs_t.c.id == job_id)
|
||||
result = self.engine.execute(delete)
|
||||
if result.rowcount == 0:
|
||||
raise JobLookupError(job_id)
|
||||
|
||||
def remove_all_jobs(self):
|
||||
delete = self.jobs_t.delete()
|
||||
self.engine.execute(delete)
|
||||
|
||||
def shutdown(self):
|
||||
self.engine.dispose()
|
||||
|
||||
def _reconstitute_job(self, job_state):
|
||||
job_state = pickle.loads(job_state)
|
||||
job_state['jobstore'] = self
|
||||
job = Job.__new__(Job)
|
||||
job.__setstate__(job_state)
|
||||
job._scheduler = self._scheduler
|
||||
job._jobstore_alias = self._alias
|
||||
return job
|
||||
|
||||
def _get_jobs(self, *conditions):
|
||||
jobs = []
|
||||
selectable = select([self.jobs_t.c.id, self.jobs_t.c.job_state]).order_by(self.jobs_t.c.next_run_time)
|
||||
selectable = selectable.where(*conditions) if conditions else selectable
|
||||
failed_job_ids = set()
|
||||
for row in self.engine.execute(selectable):
|
||||
try:
|
||||
jobs.append(self._reconstitute_job(row.job_state))
|
||||
except:
|
||||
self._logger.exception('Unable to restore job "%s" -- removing it', row.id)
|
||||
failed_job_ids.add(row.id)
|
||||
|
||||
# Remove all the jobs we failed to restore
|
||||
if failed_job_ids:
|
||||
delete = self.jobs_t.delete().where(self.jobs_t.c.id.in_(failed_job_ids))
|
||||
self.engine.execute(delete)
|
||||
|
||||
return jobs
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s (url=%s)>' % (self.__class__.__name__, self.engine.url)
|
||||
@@ -0,0 +1,12 @@
|
||||
class SchedulerAlreadyRunningError(Exception):
|
||||
"""Raised when attempting to start or configure the scheduler when it's already running."""
|
||||
|
||||
def __str__(self):
|
||||
return 'Scheduler is already running'
|
||||
|
||||
|
||||
class SchedulerNotRunningError(Exception):
|
||||
"""Raised when attempting to shutdown the scheduler when it's not running."""
|
||||
|
||||
def __str__(self):
|
||||
return 'Scheduler is not running'
|
||||
@@ -0,0 +1,68 @@
|
||||
from __future__ import absolute_import
|
||||
from functools import wraps
|
||||
|
||||
from apscheduler.schedulers.base import BaseScheduler
|
||||
from apscheduler.util import maybe_ref
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError: # pragma: nocover
|
||||
try:
|
||||
import trollius as asyncio
|
||||
except ImportError:
|
||||
raise ImportError('AsyncIOScheduler requires either Python 3.4 or the asyncio package installed')
|
||||
|
||||
|
||||
def run_in_event_loop(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
self._eventloop.call_soon_threadsafe(func, self, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
class AsyncIOScheduler(BaseScheduler):
|
||||
"""
|
||||
A scheduler that runs on an asyncio (:pep:`3156`) event loop.
|
||||
|
||||
Extra options:
|
||||
|
||||
============== =============================================================
|
||||
``event_loop`` AsyncIO event loop to use (defaults to the global event loop)
|
||||
============== =============================================================
|
||||
"""
|
||||
|
||||
_eventloop = None
|
||||
_timeout = None
|
||||
|
||||
def start(self):
|
||||
super(AsyncIOScheduler, self).start()
|
||||
self.wakeup()
|
||||
|
||||
@run_in_event_loop
|
||||
def shutdown(self, wait=True):
|
||||
super(AsyncIOScheduler, self).shutdown(wait)
|
||||
self._stop_timer()
|
||||
|
||||
def _configure(self, config):
|
||||
self._eventloop = maybe_ref(config.pop('event_loop', None)) or asyncio.get_event_loop()
|
||||
super(AsyncIOScheduler, self)._configure(config)
|
||||
|
||||
def _start_timer(self, wait_seconds):
|
||||
self._stop_timer()
|
||||
if wait_seconds is not None:
|
||||
self._timeout = self._eventloop.call_later(wait_seconds, self.wakeup)
|
||||
|
||||
def _stop_timer(self):
|
||||
if self._timeout:
|
||||
self._timeout.cancel()
|
||||
del self._timeout
|
||||
|
||||
@run_in_event_loop
|
||||
def wakeup(self):
|
||||
self._stop_timer()
|
||||
wait_seconds = self._process_jobs()
|
||||
self._start_timer(wait_seconds)
|
||||
|
||||
def _create_default_executor(self):
|
||||
from apscheduler.executors.asyncio import AsyncIOExecutor
|
||||
return AsyncIOExecutor()
|
||||
@@ -0,0 +1,39 @@
|
||||
from __future__ import absolute_import
|
||||
from threading import Thread, Event
|
||||
|
||||
from apscheduler.schedulers.base import BaseScheduler
|
||||
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||
from apscheduler.util import asbool
|
||||
|
||||
|
||||
class BackgroundScheduler(BlockingScheduler):
|
||||
"""
|
||||
A scheduler that runs in the background using a separate thread
|
||||
(:meth:`~apscheduler.schedulers.base.BaseScheduler.start` will return immediately).
|
||||
|
||||
Extra options:
|
||||
|
||||
========== ============================================================================================
|
||||
``daemon`` Set the ``daemon`` option in the background thread (defaults to ``True``,
|
||||
see `the documentation <https://docs.python.org/3.4/library/threading.html#thread-objects>`_
|
||||
for further details)
|
||||
========== ============================================================================================
|
||||
"""
|
||||
|
||||
_thread = None
|
||||
|
||||
def _configure(self, config):
|
||||
self._daemon = asbool(config.pop('daemon', True))
|
||||
super(BackgroundScheduler, self)._configure(config)
|
||||
|
||||
def start(self):
|
||||
BaseScheduler.start(self)
|
||||
self._event = Event()
|
||||
self._thread = Thread(target=self._main_loop, name='APScheduler')
|
||||
self._thread.daemon = self._daemon
|
||||
self._thread.start()
|
||||
|
||||
def shutdown(self, wait=True):
|
||||
super(BackgroundScheduler, self).shutdown(wait)
|
||||
self._thread.join()
|
||||
del self._thread
|
||||
@@ -0,0 +1,845 @@
|
||||
from __future__ import print_function
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import MutableMapping
|
||||
from threading import RLock
|
||||
from datetime import datetime
|
||||
from logging import getLogger
|
||||
import sys
|
||||
|
||||
from pkg_resources import iter_entry_points
|
||||
from tzlocal import get_localzone
|
||||
import six
|
||||
|
||||
from apscheduler.schedulers import SchedulerAlreadyRunningError, SchedulerNotRunningError
|
||||
from apscheduler.executors.base import MaxInstancesReachedError, BaseExecutor
|
||||
from apscheduler.executors.pool import ThreadPoolExecutor
|
||||
from apscheduler.jobstores.base import ConflictingIdError, JobLookupError, BaseJobStore
|
||||
from apscheduler.jobstores.memory import MemoryJobStore
|
||||
from apscheduler.job import Job
|
||||
from apscheduler.triggers.base import BaseTrigger
|
||||
from apscheduler.util import asbool, asint, astimezone, maybe_ref, timedelta_seconds, undefined
|
||||
from apscheduler.events import (
|
||||
SchedulerEvent, JobEvent, EVENT_SCHEDULER_START, EVENT_SCHEDULER_SHUTDOWN, EVENT_JOBSTORE_ADDED,
|
||||
EVENT_JOBSTORE_REMOVED, EVENT_ALL, EVENT_JOB_MODIFIED, EVENT_JOB_REMOVED, EVENT_JOB_ADDED, EVENT_EXECUTOR_ADDED,
|
||||
EVENT_EXECUTOR_REMOVED, EVENT_ALL_JOBS_REMOVED)
|
||||
|
||||
|
||||
class BaseScheduler(six.with_metaclass(ABCMeta)):
|
||||
"""
|
||||
Abstract base class for all schedulers. Takes the following keyword arguments:
|
||||
|
||||
:param str|logging.Logger logger: logger to use for the scheduler's logging (defaults to apscheduler.scheduler)
|
||||
:param str|datetime.tzinfo timezone: the default time zone (defaults to the local timezone)
|
||||
:param dict job_defaults: default values for newly added jobs
|
||||
:param dict jobstores: a dictionary of job store alias -> job store instance or configuration dict
|
||||
:param dict executors: a dictionary of executor alias -> executor instance or configuration dict
|
||||
|
||||
.. seealso:: :ref:`scheduler-config`
|
||||
"""
|
||||
|
||||
_trigger_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.triggers'))
|
||||
_trigger_classes = {}
|
||||
_executor_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.executors'))
|
||||
_executor_classes = {}
|
||||
_jobstore_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.jobstores'))
|
||||
_jobstore_classes = {}
|
||||
_stopped = True
|
||||
|
||||
#
|
||||
# Public API
|
||||
#
|
||||
|
||||
def __init__(self, gconfig={}, **options):
|
||||
super(BaseScheduler, self).__init__()
|
||||
self._executors = {}
|
||||
self._executors_lock = self._create_lock()
|
||||
self._jobstores = {}
|
||||
self._jobstores_lock = self._create_lock()
|
||||
self._listeners = []
|
||||
self._listeners_lock = self._create_lock()
|
||||
self._pending_jobs = []
|
||||
self.configure(gconfig, **options)
|
||||
|
||||
def configure(self, gconfig={}, prefix='apscheduler.', **options):
|
||||
"""
|
||||
Reconfigures the scheduler with the given options. Can only be done when the scheduler isn't running.
|
||||
|
||||
:param dict gconfig: a "global" configuration dictionary whose values can be overridden by keyword arguments to
|
||||
this method
|
||||
:param str|unicode prefix: pick only those keys from ``gconfig`` that are prefixed with this string
|
||||
(pass an empty string or ``None`` to use all keys)
|
||||
:raises SchedulerAlreadyRunningError: if the scheduler is already running
|
||||
"""
|
||||
|
||||
if self.running:
|
||||
raise SchedulerAlreadyRunningError
|
||||
|
||||
# If a non-empty prefix was given, strip it from the keys in the global configuration dict
|
||||
if prefix:
|
||||
prefixlen = len(prefix)
|
||||
gconfig = dict((key[prefixlen:], value) for key, value in six.iteritems(gconfig) if key.startswith(prefix))
|
||||
|
||||
# Create a structure from the dotted options (e.g. "a.b.c = d" -> {'a': {'b': {'c': 'd'}}})
|
||||
config = {}
|
||||
for key, value in six.iteritems(gconfig):
|
||||
parts = key.split('.')
|
||||
parent = config
|
||||
key = parts.pop(0)
|
||||
while parts:
|
||||
parent = parent.setdefault(key, {})
|
||||
key = parts.pop(0)
|
||||
parent[key] = value
|
||||
|
||||
# Override any options with explicit keyword arguments
|
||||
config.update(options)
|
||||
self._configure(config)
|
||||
|
||||
@abstractmethod
|
||||
def start(self):
|
||||
"""
|
||||
Starts the scheduler. The details of this process depend on the implementation.
|
||||
|
||||
:raises SchedulerAlreadyRunningError: if the scheduler is already running
|
||||
"""
|
||||
|
||||
if self.running:
|
||||
raise SchedulerAlreadyRunningError
|
||||
|
||||
with self._executors_lock:
|
||||
# Create a default executor if nothing else is configured
|
||||
if 'default' not in self._executors:
|
||||
self.add_executor(self._create_default_executor(), 'default')
|
||||
|
||||
# Start all the executors
|
||||
for alias, executor in six.iteritems(self._executors):
|
||||
executor.start(self, alias)
|
||||
|
||||
with self._jobstores_lock:
|
||||
# Create a default job store if nothing else is configured
|
||||
if 'default' not in self._jobstores:
|
||||
self.add_jobstore(self._create_default_jobstore(), 'default')
|
||||
|
||||
# Start all the job stores
|
||||
for alias, store in six.iteritems(self._jobstores):
|
||||
store.start(self, alias)
|
||||
|
||||
# Schedule all pending jobs
|
||||
for job, jobstore_alias, replace_existing in self._pending_jobs:
|
||||
self._real_add_job(job, jobstore_alias, replace_existing, False)
|
||||
del self._pending_jobs[:]
|
||||
|
||||
self._stopped = False
|
||||
self._logger.info('Scheduler started')
|
||||
|
||||
# Notify listeners that the scheduler has been started
|
||||
self._dispatch_event(SchedulerEvent(EVENT_SCHEDULER_START))
|
||||
|
||||
@abstractmethod
|
||||
def shutdown(self, wait=True):
|
||||
"""
|
||||
Shuts down the scheduler. Does not interrupt any currently running jobs.
|
||||
|
||||
:param bool wait: ``True`` to wait until all currently executing jobs have finished
|
||||
:raises SchedulerNotRunningError: if the scheduler has not been started yet
|
||||
"""
|
||||
|
||||
if not self.running:
|
||||
raise SchedulerNotRunningError
|
||||
|
||||
self._stopped = True
|
||||
|
||||
# Shut down all executors
|
||||
for executor in six.itervalues(self._executors):
|
||||
executor.shutdown(wait)
|
||||
|
||||
# Shut down all job stores
|
||||
for jobstore in six.itervalues(self._jobstores):
|
||||
jobstore.shutdown()
|
||||
|
||||
self._logger.info('Scheduler has been shut down')
|
||||
self._dispatch_event(SchedulerEvent(EVENT_SCHEDULER_SHUTDOWN))
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return not self._stopped
|
||||
|
||||
def add_executor(self, executor, alias='default', **executor_opts):
|
||||
"""
|
||||
Adds an executor to this scheduler. Any extra keyword arguments will be passed to the executor plugin's
|
||||
constructor, assuming that the first argument is the name of an executor plugin.
|
||||
|
||||
:param str|unicode|apscheduler.executors.base.BaseExecutor executor: either an executor instance or the name of
|
||||
an executor plugin
|
||||
:param str|unicode alias: alias for the scheduler
|
||||
:raises ValueError: if there is already an executor by the given alias
|
||||
"""
|
||||
|
||||
with self._executors_lock:
|
||||
if alias in self._executors:
|
||||
raise ValueError('This scheduler already has an executor by the alias of "%s"' % alias)
|
||||
|
||||
if isinstance(executor, BaseExecutor):
|
||||
self._executors[alias] = executor
|
||||
elif isinstance(executor, six.string_types):
|
||||
self._executors[alias] = executor = self._create_plugin_instance('executor', executor, executor_opts)
|
||||
else:
|
||||
raise TypeError('Expected an executor instance or a string, got %s instead' %
|
||||
executor.__class__.__name__)
|
||||
|
||||
# Start the executor right away if the scheduler is running
|
||||
if self.running:
|
||||
executor.start(self)
|
||||
|
||||
self._dispatch_event(SchedulerEvent(EVENT_EXECUTOR_ADDED, alias))
|
||||
|
||||
def remove_executor(self, alias, shutdown=True):
|
||||
"""
|
||||
Removes the executor by the given alias from this scheduler.
|
||||
|
||||
:param str|unicode alias: alias of the executor
|
||||
:param bool shutdown: ``True`` to shut down the executor after removing it
|
||||
"""
|
||||
|
||||
with self._jobstores_lock:
|
||||
executor = self._lookup_executor(alias)
|
||||
del self._executors[alias]
|
||||
|
||||
if shutdown:
|
||||
executor.shutdown()
|
||||
|
||||
self._dispatch_event(SchedulerEvent(EVENT_EXECUTOR_REMOVED, alias))
|
||||
|
||||
def add_jobstore(self, jobstore, alias='default', **jobstore_opts):
|
||||
"""
|
||||
Adds a job store to this scheduler. Any extra keyword arguments will be passed to the job store plugin's
|
||||
constructor, assuming that the first argument is the name of a job store plugin.
|
||||
|
||||
:param str|unicode|apscheduler.jobstores.base.BaseJobStore jobstore: job store to be added
|
||||
:param str|unicode alias: alias for the job store
|
||||
:raises ValueError: if there is already a job store by the given alias
|
||||
"""
|
||||
|
||||
with self._jobstores_lock:
|
||||
if alias in self._jobstores:
|
||||
raise ValueError('This scheduler already has a job store by the alias of "%s"' % alias)
|
||||
|
||||
if isinstance(jobstore, BaseJobStore):
|
||||
self._jobstores[alias] = jobstore
|
||||
elif isinstance(jobstore, six.string_types):
|
||||
self._jobstores[alias] = jobstore = self._create_plugin_instance('jobstore', jobstore, jobstore_opts)
|
||||
else:
|
||||
raise TypeError('Expected a job store instance or a string, got %s instead' %
|
||||
jobstore.__class__.__name__)
|
||||
|
||||
# Start the job store right away if the scheduler is running
|
||||
if self.running:
|
||||
jobstore.start(self, alias)
|
||||
|
||||
# Notify listeners that a new job store has been added
|
||||
self._dispatch_event(SchedulerEvent(EVENT_JOBSTORE_ADDED, alias))
|
||||
|
||||
# Notify the scheduler so it can scan the new job store for jobs
|
||||
if self.running:
|
||||
self.wakeup()
|
||||
|
||||
def remove_jobstore(self, alias, shutdown=True):
|
||||
"""
|
||||
Removes the job store by the given alias from this scheduler.
|
||||
|
||||
:param str|unicode alias: alias of the job store
|
||||
:param bool shutdown: ``True`` to shut down the job store after removing it
|
||||
"""
|
||||
|
||||
with self._jobstores_lock:
|
||||
jobstore = self._lookup_jobstore(alias)
|
||||
del self._jobstores[alias]
|
||||
|
||||
if shutdown:
|
||||
jobstore.shutdown()
|
||||
|
||||
self._dispatch_event(SchedulerEvent(EVENT_JOBSTORE_REMOVED, alias))
|
||||
|
||||
def add_listener(self, callback, mask=EVENT_ALL):
|
||||
"""
|
||||
add_listener(callback, mask=EVENT_ALL)
|
||||
|
||||
Adds a listener for scheduler events. When a matching event occurs, ``callback`` is executed with the event
|
||||
object as its sole argument. If the ``mask`` parameter is not provided, the callback will receive events of all
|
||||
types.
|
||||
|
||||
:param callback: any callable that takes one argument
|
||||
:param int mask: bitmask that indicates which events should be listened to
|
||||
|
||||
.. seealso:: :mod:`apscheduler.events`
|
||||
.. seealso:: :ref:`scheduler-events`
|
||||
"""
|
||||
|
||||
with self._listeners_lock:
|
||||
self._listeners.append((callback, mask))
|
||||
|
||||
def remove_listener(self, callback):
|
||||
"""Removes a previously added event listener."""
|
||||
|
||||
with self._listeners_lock:
|
||||
for i, (cb, _) in enumerate(self._listeners):
|
||||
if callback == cb:
|
||||
del self._listeners[i]
|
||||
|
||||
def add_job(self, func, trigger=None, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined,
|
||||
coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default',
|
||||
executor='default', replace_existing=False, **trigger_args):
|
||||
"""
|
||||
add_job(func, trigger=None, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined, \
|
||||
coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default', \
|
||||
executor='default', replace_existing=False, **trigger_args)
|
||||
|
||||
Adds the given job to the job list and wakes up the scheduler if it's already running.
|
||||
|
||||
Any option that defaults to ``undefined`` will be replaced with the corresponding default value when the job is
|
||||
scheduled (which happens when the scheduler is started, or immediately if the scheduler is already running).
|
||||
|
||||
The ``func`` argument can be given either as a callable object or a textual reference in the
|
||||
``package.module:some.object`` format, where the first half (separated by ``:``) is an importable module and the
|
||||
second half is a reference to the callable object, relative to the module.
|
||||
|
||||
The ``trigger`` argument can either be:
|
||||
#. the alias name of the trigger (e.g. ``date``, ``interval`` or ``cron``), in which case any extra keyword
|
||||
arguments to this method are passed on to the trigger's constructor
|
||||
#. an instance of a trigger class
|
||||
|
||||
:param func: callable (or a textual reference to one) to run at the given time
|
||||
:param str|apscheduler.triggers.base.BaseTrigger trigger: trigger that determines when ``func`` is called
|
||||
:param list|tuple args: list of positional arguments to call func with
|
||||
:param dict kwargs: dict of keyword arguments to call func with
|
||||
:param str|unicode id: explicit identifier for the job (for modifying it later)
|
||||
:param str|unicode name: textual description of the job
|
||||
:param int misfire_grace_time: seconds after the designated run time that the job is still allowed to be run
|
||||
:param bool coalesce: run once instead of many times if the scheduler determines that the job should be run more
|
||||
than once in succession
|
||||
:param int max_instances: maximum number of concurrently running instances allowed for this job
|
||||
:param datetime next_run_time: when to first run the job, regardless of the trigger (pass ``None`` to add the
|
||||
job as paused)
|
||||
:param str|unicode jobstore: alias of the job store to store the job in
|
||||
:param str|unicode executor: alias of the executor to run the job with
|
||||
:param bool replace_existing: ``True`` to replace an existing job with the same ``id`` (but retain the
|
||||
number of runs from the existing one)
|
||||
:rtype: Job
|
||||
"""
|
||||
|
||||
job_kwargs = {
|
||||
'trigger': self._create_trigger(trigger, trigger_args),
|
||||
'executor': executor,
|
||||
'func': func,
|
||||
'args': tuple(args) if args is not None else (),
|
||||
'kwargs': dict(kwargs) if kwargs is not None else {},
|
||||
'id': id,
|
||||
'name': name,
|
||||
'misfire_grace_time': misfire_grace_time,
|
||||
'coalesce': coalesce,
|
||||
'max_instances': max_instances,
|
||||
'next_run_time': next_run_time
|
||||
}
|
||||
job_kwargs = dict((key, value) for key, value in six.iteritems(job_kwargs) if value is not undefined)
|
||||
job = Job(self, **job_kwargs)
|
||||
|
||||
# Don't really add jobs to job stores before the scheduler is up and running
|
||||
with self._jobstores_lock:
|
||||
if not self.running:
|
||||
self._pending_jobs.append((job, jobstore, replace_existing))
|
||||
self._logger.info('Adding job tentatively -- it will be properly scheduled when the scheduler starts')
|
||||
else:
|
||||
self._real_add_job(job, jobstore, replace_existing, True)
|
||||
|
||||
return job
|
||||
|
||||
def scheduled_job(self, trigger, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined,
|
||||
coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default',
|
||||
executor='default', **trigger_args):
|
||||
"""
|
||||
scheduled_job(trigger, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined, \
|
||||
coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default', \
|
||||
executor='default',**trigger_args)
|
||||
|
||||
A decorator version of :meth:`add_job`, except that ``replace_existing`` is always ``True``.
|
||||
|
||||
.. important:: The ``id`` argument must be given if scheduling a job in a persistent job store. The scheduler
|
||||
cannot, however, enforce this requirement.
|
||||
"""
|
||||
|
||||
def inner(func):
|
||||
self.add_job(func, trigger, args, kwargs, id, name, misfire_grace_time, coalesce, max_instances,
|
||||
next_run_time, jobstore, executor, True, **trigger_args)
|
||||
return func
|
||||
return inner
|
||||
|
||||
def modify_job(self, job_id, jobstore=None, **changes):
|
||||
"""
|
||||
Modifies the properties of a single job. Modifications are passed to this method as extra keyword arguments.
|
||||
|
||||
:param str|unicode job_id: the identifier of the job
|
||||
:param str|unicode jobstore: alias of the job store that contains the job
|
||||
"""
|
||||
with self._jobstores_lock:
|
||||
job, jobstore = self._lookup_job(job_id, jobstore)
|
||||
job._modify(**changes)
|
||||
if jobstore:
|
||||
self._lookup_jobstore(jobstore).update_job(job)
|
||||
|
||||
self._dispatch_event(JobEvent(EVENT_JOB_MODIFIED, job_id, jobstore))
|
||||
|
||||
# Wake up the scheduler since the job's next run time may have been changed
|
||||
self.wakeup()
|
||||
|
||||
def reschedule_job(self, job_id, jobstore=None, trigger=None, **trigger_args):
|
||||
"""
|
||||
Constructs a new trigger for a job and updates its next run time.
|
||||
Extra keyword arguments are passed directly to the trigger's constructor.
|
||||
|
||||
:param str|unicode job_id: the identifier of the job
|
||||
:param str|unicode jobstore: alias of the job store that contains the job
|
||||
:param trigger: alias of the trigger type or a trigger instance
|
||||
"""
|
||||
|
||||
trigger = self._create_trigger(trigger, trigger_args)
|
||||
now = datetime.now(self.timezone)
|
||||
next_run_time = trigger.get_next_fire_time(None, now)
|
||||
self.modify_job(job_id, jobstore, trigger=trigger, next_run_time=next_run_time)
|
||||
|
||||
def pause_job(self, job_id, jobstore=None):
|
||||
"""
|
||||
Causes the given job not to be executed until it is explicitly resumed.
|
||||
|
||||
:param str|unicode job_id: the identifier of the job
|
||||
:param str|unicode jobstore: alias of the job store that contains the job
|
||||
"""
|
||||
|
||||
self.modify_job(job_id, jobstore, next_run_time=None)
|
||||
|
||||
def resume_job(self, job_id, jobstore=None):
|
||||
"""
|
||||
Resumes the schedule of the given job, or removes the job if its schedule is finished.
|
||||
|
||||
:param str|unicode job_id: the identifier of the job
|
||||
:param str|unicode jobstore: alias of the job store that contains the job
|
||||
"""
|
||||
|
||||
with self._jobstores_lock:
|
||||
job, jobstore = self._lookup_job(job_id, jobstore)
|
||||
now = datetime.now(self.timezone)
|
||||
next_run_time = job.trigger.get_next_fire_time(None, now)
|
||||
if next_run_time:
|
||||
self.modify_job(job_id, jobstore, next_run_time=next_run_time)
|
||||
else:
|
||||
self.remove_job(job.id, jobstore)
|
||||
|
||||
def get_jobs(self, jobstore=None, pending=None):
|
||||
"""
|
||||
Returns a list of pending jobs (if the scheduler hasn't been started yet) and scheduled jobs, either from a
|
||||
specific job store or from all of them.
|
||||
|
||||
:param str|unicode jobstore: alias of the job store
|
||||
:param bool pending: ``False`` to leave out pending jobs (jobs that are waiting for the scheduler start to be
|
||||
added to their respective job stores), ``True`` to only include pending jobs, anything else
|
||||
to return both
|
||||
:rtype: list[Job]
|
||||
"""
|
||||
|
||||
with self._jobstores_lock:
|
||||
jobs = []
|
||||
|
||||
if pending is not False:
|
||||
for job, alias, replace_existing in self._pending_jobs:
|
||||
if jobstore is None or alias == jobstore:
|
||||
jobs.append(job)
|
||||
|
||||
if pending is not True:
|
||||
for alias, store in six.iteritems(self._jobstores):
|
||||
if jobstore is None or alias == jobstore:
|
||||
jobs.extend(store.get_all_jobs())
|
||||
|
||||
return jobs
|
||||
|
||||
def get_job(self, job_id, jobstore=None):
|
||||
"""
|
||||
Returns the Job that matches the given ``job_id``.
|
||||
|
||||
:param str|unicode job_id: the identifier of the job
|
||||
:param str|unicode jobstore: alias of the job store that most likely contains the job
|
||||
:return: the Job by the given ID, or ``None`` if it wasn't found
|
||||
:rtype: Job
|
||||
"""
|
||||
|
||||
with self._jobstores_lock:
|
||||
try:
|
||||
return self._lookup_job(job_id, jobstore)[0]
|
||||
except JobLookupError:
|
||||
return
|
||||
|
||||
def remove_job(self, job_id, jobstore=None):
|
||||
"""
|
||||
Removes a job, preventing it from being run any more.
|
||||
|
||||
:param str|unicode job_id: the identifier of the job
|
||||
:param str|unicode jobstore: alias of the job store that contains the job
|
||||
:raises JobLookupError: if the job was not found
|
||||
"""
|
||||
|
||||
with self._jobstores_lock:
|
||||
# Check if the job is among the pending jobs
|
||||
for i, (job, jobstore_alias, replace_existing) in enumerate(self._pending_jobs):
|
||||
if job.id == job_id:
|
||||
del self._pending_jobs[i]
|
||||
jobstore = jobstore_alias
|
||||
break
|
||||
else:
|
||||
# Otherwise, try to remove it from each store until it succeeds or we run out of stores to check
|
||||
for alias, store in six.iteritems(self._jobstores):
|
||||
if jobstore in (None, alias):
|
||||
try:
|
||||
store.remove_job(job_id)
|
||||
except JobLookupError:
|
||||
continue
|
||||
|
||||
jobstore = alias
|
||||
break
|
||||
|
||||
if jobstore is None:
|
||||
raise JobLookupError(job_id)
|
||||
|
||||
# Notify listeners that a job has been removed
|
||||
event = JobEvent(EVENT_JOB_REMOVED, job_id, jobstore)
|
||||
self._dispatch_event(event)
|
||||
|
||||
self._logger.info('Removed job %s', job_id)
|
||||
|
||||
def remove_all_jobs(self, jobstore=None):
|
||||
"""
|
||||
Removes all jobs from the specified job store, or all job stores if none is given.
|
||||
|
||||
:param str|unicode jobstore: alias of the job store
|
||||
"""
|
||||
|
||||
with self._jobstores_lock:
|
||||
if jobstore:
|
||||
self._pending_jobs = [pending for pending in self._pending_jobs if pending[1] != jobstore]
|
||||
else:
|
||||
self._pending_jobs = []
|
||||
|
||||
for alias, store in six.iteritems(self._jobstores):
|
||||
if jobstore in (None, alias):
|
||||
store.remove_all_jobs()
|
||||
|
||||
self._dispatch_event(SchedulerEvent(EVENT_ALL_JOBS_REMOVED, jobstore))
|
||||
|
||||
def print_jobs(self, jobstore=None, out=None):
|
||||
"""
|
||||
print_jobs(jobstore=None, out=sys.stdout)
|
||||
|
||||
Prints out a textual listing of all jobs currently scheduled on either all job stores or just a specific one.
|
||||
|
||||
:param str|unicode jobstore: alias of the job store, ``None`` to list jobs from all stores
|
||||
:param file out: a file-like object to print to (defaults to **sys.stdout** if nothing is given)
|
||||
"""
|
||||
|
||||
out = out or sys.stdout
|
||||
with self._jobstores_lock:
|
||||
if self._pending_jobs:
|
||||
print(six.u('Pending jobs:'), file=out)
|
||||
for job, jobstore_alias, replace_existing in self._pending_jobs:
|
||||
if jobstore in (None, jobstore_alias):
|
||||
print(six.u(' %s') % job, file=out)
|
||||
|
||||
for alias, store in six.iteritems(self._jobstores):
|
||||
if jobstore in (None, alias):
|
||||
print(six.u('Jobstore %s:') % alias, file=out)
|
||||
jobs = store.get_all_jobs()
|
||||
if jobs:
|
||||
for job in jobs:
|
||||
print(six.u(' %s') % job, file=out)
|
||||
else:
|
||||
print(six.u(' No scheduled jobs'), file=out)
|
||||
|
||||
@abstractmethod
|
||||
def wakeup(self):
|
||||
"""
|
||||
Notifies the scheduler that there may be jobs due for execution.
|
||||
Triggers :meth:`_process_jobs` to be run in an implementation specific manner.
|
||||
"""
|
||||
|
||||
#
|
||||
# Private API
|
||||
#
|
||||
|
||||
def _configure(self, config):
|
||||
# Set general options
|
||||
self._logger = maybe_ref(config.pop('logger', None)) or getLogger('apscheduler.scheduler')
|
||||
self.timezone = astimezone(config.pop('timezone', None)) or get_localzone()
|
||||
|
||||
# Set the job defaults
|
||||
job_defaults = config.get('job_defaults', {})
|
||||
self._job_defaults = {
|
||||
'misfire_grace_time': asint(job_defaults.get('misfire_grace_time', 1)),
|
||||
'coalesce': asbool(job_defaults.get('coalesce', True)),
|
||||
'max_instances': asint(job_defaults.get('max_instances', 1))
|
||||
}
|
||||
|
||||
# Configure executors
|
||||
self._executors.clear()
|
||||
for alias, value in six.iteritems(config.get('executors', {})):
|
||||
if isinstance(value, BaseExecutor):
|
||||
self.add_executor(value, alias)
|
||||
elif isinstance(value, MutableMapping):
|
||||
executor_class = value.pop('class', None)
|
||||
plugin = value.pop('type', None)
|
||||
if plugin:
|
||||
executor = self._create_plugin_instance('executor', plugin, value)
|
||||
elif executor_class:
|
||||
cls = maybe_ref(executor_class)
|
||||
executor = cls(**value)
|
||||
else:
|
||||
raise ValueError('Cannot create executor "%s" -- either "type" or "class" must be defined' % alias)
|
||||
|
||||
self.add_executor(executor, alias)
|
||||
else:
|
||||
raise TypeError("Expected executor instance or dict for executors['%s'], got %s instead" % (
|
||||
alias, value.__class__.__name__))
|
||||
|
||||
# Configure job stores
|
||||
self._jobstores.clear()
|
||||
for alias, value in six.iteritems(config.get('jobstores', {})):
|
||||
if isinstance(value, BaseJobStore):
|
||||
self.add_jobstore(value, alias)
|
||||
elif isinstance(value, MutableMapping):
|
||||
jobstore_class = value.pop('class', None)
|
||||
plugin = value.pop('type', None)
|
||||
if plugin:
|
||||
jobstore = self._create_plugin_instance('jobstore', plugin, value)
|
||||
elif jobstore_class:
|
||||
cls = maybe_ref(jobstore_class)
|
||||
jobstore = cls(**value)
|
||||
else:
|
||||
raise ValueError('Cannot create job store "%s" -- either "type" or "class" must be defined' % alias)
|
||||
|
||||
self.add_jobstore(jobstore, alias)
|
||||
else:
|
||||
raise TypeError("Expected job store instance or dict for jobstores['%s'], got %s instead" % (
|
||||
alias, value.__class__.__name__))
|
||||
|
||||
def _create_default_executor(self):
|
||||
"""Creates a default executor store, specific to the particular scheduler type."""
|
||||
|
||||
return ThreadPoolExecutor()
|
||||
|
||||
def _create_default_jobstore(self):
|
||||
"""Creates a default job store, specific to the particular scheduler type."""
|
||||
|
||||
return MemoryJobStore()
|
||||
|
||||
def _lookup_executor(self, alias):
|
||||
"""
|
||||
Returns the executor instance by the given name from the list of executors that were added to this scheduler.
|
||||
|
||||
:type alias: str
|
||||
:raises KeyError: if no executor by the given alias is not found
|
||||
"""
|
||||
|
||||
try:
|
||||
return self._executors[alias]
|
||||
except KeyError:
|
||||
raise KeyError('No such executor: %s' % alias)
|
||||
|
||||
def _lookup_jobstore(self, alias):
|
||||
"""
|
||||
Returns the job store instance by the given name from the list of job stores that were added to this scheduler.
|
||||
|
||||
:type alias: str
|
||||
:raises KeyError: if no job store by the given alias is not found
|
||||
"""
|
||||
|
||||
try:
|
||||
return self._jobstores[alias]
|
||||
except KeyError:
|
||||
raise KeyError('No such job store: %s' % alias)
|
||||
|
||||
def _lookup_job(self, job_id, jobstore_alias):
|
||||
"""
|
||||
Finds a job by its ID.
|
||||
|
||||
:type job_id: str
|
||||
:param str jobstore_alias: alias of a job store to look in
|
||||
:return tuple[Job, str]: a tuple of job, jobstore alias (jobstore alias is None in case of a pending job)
|
||||
:raises JobLookupError: if no job by the given ID is found.
|
||||
"""
|
||||
|
||||
# Check if the job is among the pending jobs
|
||||
for job, alias, replace_existing in self._pending_jobs:
|
||||
if job.id == job_id:
|
||||
return job, None
|
||||
|
||||
# Look in all job stores
|
||||
for alias, store in six.iteritems(self._jobstores):
|
||||
if jobstore_alias in (None, alias):
|
||||
job = store.lookup_job(job_id)
|
||||
if job is not None:
|
||||
return job, alias
|
||||
|
||||
raise JobLookupError(job_id)
|
||||
|
||||
def _dispatch_event(self, event):
|
||||
"""
|
||||
Dispatches the given event to interested listeners.
|
||||
|
||||
:param SchedulerEvent event: the event to send
|
||||
"""
|
||||
|
||||
with self._listeners_lock:
|
||||
listeners = tuple(self._listeners)
|
||||
|
||||
for cb, mask in listeners:
|
||||
if event.code & mask:
|
||||
try:
|
||||
cb(event)
|
||||
except:
|
||||
self._logger.exception('Error notifying listener')
|
||||
|
||||
def _real_add_job(self, job, jobstore_alias, replace_existing, wakeup):
|
||||
"""
|
||||
:param Job job: the job to add
|
||||
:param bool replace_existing: ``True`` to use update_job() in case the job already exists in the store
|
||||
:param bool wakeup: ``True`` to wake up the scheduler after adding the job
|
||||
"""
|
||||
|
||||
# Fill in undefined values with defaults
|
||||
replacements = {}
|
||||
for key, value in six.iteritems(self._job_defaults):
|
||||
if not hasattr(job, key):
|
||||
replacements[key] = value
|
||||
|
||||
# Calculate the next run time if there is none defined
|
||||
if not hasattr(job, 'next_run_time'):
|
||||
now = datetime.now(self.timezone)
|
||||
replacements['next_run_time'] = job.trigger.get_next_fire_time(None, now)
|
||||
|
||||
# Apply any replacements
|
||||
job._modify(**replacements)
|
||||
|
||||
# Add the job to the given job store
|
||||
store = self._lookup_jobstore(jobstore_alias)
|
||||
try:
|
||||
store.add_job(job)
|
||||
except ConflictingIdError:
|
||||
if replace_existing:
|
||||
store.update_job(job)
|
||||
else:
|
||||
raise
|
||||
|
||||
# Mark the job as no longer pending
|
||||
job._jobstore_alias = jobstore_alias
|
||||
|
||||
# Notify listeners that a new job has been added
|
||||
event = JobEvent(EVENT_JOB_ADDED, job.id, jobstore_alias)
|
||||
self._dispatch_event(event)
|
||||
|
||||
self._logger.info('Added job "%s" to job store "%s"', job.name, jobstore_alias)
|
||||
|
||||
# Notify the scheduler about the new job
|
||||
if wakeup:
|
||||
self.wakeup()
|
||||
|
||||
def _create_plugin_instance(self, type_, alias, constructor_kwargs):
|
||||
"""Creates an instance of the given plugin type, loading the plugin first if necessary."""
|
||||
|
||||
plugin_container, class_container, base_class = {
|
||||
'trigger': (self._trigger_plugins, self._trigger_classes, BaseTrigger),
|
||||
'jobstore': (self._jobstore_plugins, self._jobstore_classes, BaseJobStore),
|
||||
'executor': (self._executor_plugins, self._executor_classes, BaseExecutor)
|
||||
}[type_]
|
||||
|
||||
try:
|
||||
plugin_cls = class_container[alias]
|
||||
except KeyError:
|
||||
if alias in plugin_container:
|
||||
plugin_cls = class_container[alias] = plugin_container[alias].load()
|
||||
if not issubclass(plugin_cls, base_class):
|
||||
raise TypeError('The {0} entry point does not point to a {0} class'.format(type_))
|
||||
else:
|
||||
raise LookupError('No {0} by the name "{1}" was found'.format(type_, alias))
|
||||
|
||||
return plugin_cls(**constructor_kwargs)
|
||||
|
||||
def _create_trigger(self, trigger, trigger_args):
|
||||
if isinstance(trigger, BaseTrigger):
|
||||
return trigger
|
||||
elif trigger is None:
|
||||
trigger = 'date'
|
||||
elif not isinstance(trigger, six.string_types):
|
||||
raise TypeError('Expected a trigger instance or string, got %s instead' % trigger.__class__.__name__)
|
||||
|
||||
# Use the scheduler's time zone if nothing else is specified
|
||||
trigger_args.setdefault('timezone', self.timezone)
|
||||
|
||||
# Instantiate the trigger class
|
||||
return self._create_plugin_instance('trigger', trigger, trigger_args)
|
||||
|
||||
def _create_lock(self):
|
||||
"""Creates a reentrant lock object."""
|
||||
|
||||
return RLock()
|
||||
|
||||
def _process_jobs(self):
|
||||
"""
|
||||
Iterates through jobs in every jobstore, starts jobs that are due and figures out how long to wait for the next
|
||||
round.
|
||||
"""
|
||||
|
||||
self._logger.debug('Looking for jobs to run')
|
||||
now = datetime.now(self.timezone)
|
||||
next_wakeup_time = None
|
||||
|
||||
with self._jobstores_lock:
|
||||
for jobstore_alias, jobstore in six.iteritems(self._jobstores):
|
||||
for job in jobstore.get_due_jobs(now):
|
||||
# Look up the job's executor
|
||||
try:
|
||||
executor = self._lookup_executor(job.executor)
|
||||
except:
|
||||
self._logger.error(
|
||||
'Executor lookup ("%s") failed for job "%s" -- removing it from the job store',
|
||||
job.executor, job)
|
||||
self.remove_job(job.id, jobstore_alias)
|
||||
continue
|
||||
|
||||
run_times = job._get_run_times(now)
|
||||
run_times = run_times[-1:] if run_times and job.coalesce else run_times
|
||||
if run_times:
|
||||
try:
|
||||
executor.submit_job(job, run_times)
|
||||
except MaxInstancesReachedError:
|
||||
self._logger.warning(
|
||||
'Execution of job "%s" skipped: maximum number of running instances reached (%d)',
|
||||
job, job.max_instances)
|
||||
except:
|
||||
self._logger.exception('Error submitting job "%s" to executor "%s"', job, job.executor)
|
||||
|
||||
# Update the job if it has a next execution time. Otherwise remove it from the job store.
|
||||
job_next_run = job.trigger.get_next_fire_time(run_times[-1], now)
|
||||
if job_next_run:
|
||||
job._modify(next_run_time=job_next_run)
|
||||
jobstore.update_job(job)
|
||||
else:
|
||||
self.remove_job(job.id, jobstore_alias)
|
||||
|
||||
# Set a new next wakeup time if there isn't one yet or the jobstore has an even earlier one
|
||||
jobstore_next_run_time = jobstore.get_next_run_time()
|
||||
if jobstore_next_run_time and (next_wakeup_time is None or jobstore_next_run_time < next_wakeup_time):
|
||||
next_wakeup_time = jobstore_next_run_time
|
||||
|
||||
# Determine the delay until this method should be called again
|
||||
if next_wakeup_time is not None:
|
||||
wait_seconds = max(timedelta_seconds(next_wakeup_time - now), 0)
|
||||
self._logger.debug('Next wakeup is due at %s (in %f seconds)', next_wakeup_time, wait_seconds)
|
||||
else:
|
||||
wait_seconds = None
|
||||
self._logger.debug('No jobs; waiting until a job is added')
|
||||
|
||||
return wait_seconds
|
||||
@@ -0,0 +1,32 @@
|
||||
from __future__ import absolute_import
|
||||
from threading import Event
|
||||
|
||||
from apscheduler.schedulers.base import BaseScheduler
|
||||
|
||||
|
||||
class BlockingScheduler(BaseScheduler):
|
||||
"""
|
||||
A scheduler that runs in the foreground (:meth:`~apscheduler.schedulers.base.BaseScheduler.start` will block).
|
||||
"""
|
||||
|
||||
MAX_WAIT_TIME = 4294967 # Maximum value accepted by Event.wait() on Windows
|
||||
|
||||
_event = None
|
||||
|
||||
def start(self):
|
||||
super(BlockingScheduler, self).start()
|
||||
self._event = Event()
|
||||
self._main_loop()
|
||||
|
||||
def shutdown(self, wait=True):
|
||||
super(BlockingScheduler, self).shutdown(wait)
|
||||
self._event.set()
|
||||
|
||||
def _main_loop(self):
|
||||
while self.running:
|
||||
wait_seconds = self._process_jobs()
|
||||
self._event.wait(wait_seconds if wait_seconds is not None else self.MAX_WAIT_TIME)
|
||||
self._event.clear()
|
||||
|
||||
def wakeup(self):
|
||||
self._event.set()
|
||||
@@ -0,0 +1,35 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||
from apscheduler.schedulers.base import BaseScheduler
|
||||
|
||||
try:
|
||||
from gevent.event import Event
|
||||
from gevent.lock import RLock
|
||||
import gevent
|
||||
except ImportError: # pragma: nocover
|
||||
raise ImportError('GeventScheduler requires gevent installed')
|
||||
|
||||
|
||||
class GeventScheduler(BlockingScheduler):
|
||||
"""A scheduler that runs as a Gevent greenlet."""
|
||||
|
||||
_greenlet = None
|
||||
|
||||
def start(self):
|
||||
BaseScheduler.start(self)
|
||||
self._event = Event()
|
||||
self._greenlet = gevent.spawn(self._main_loop)
|
||||
return self._greenlet
|
||||
|
||||
def shutdown(self, wait=True):
|
||||
super(GeventScheduler, self).shutdown(wait)
|
||||
self._greenlet.join()
|
||||
del self._greenlet
|
||||
|
||||
def _create_lock(self):
|
||||
return RLock()
|
||||
|
||||
def _create_default_executor(self):
|
||||
from apscheduler.executors.gevent import GeventExecutor
|
||||
return GeventExecutor()
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from apscheduler.schedulers.base import BaseScheduler
|
||||
|
||||
try:
|
||||
from PyQt5.QtCore import QObject, QTimer
|
||||
except ImportError: # pragma: nocover
|
||||
try:
|
||||
from PyQt4.QtCore import QObject, QTimer
|
||||
except ImportError:
|
||||
try:
|
||||
from PySide.QtCore import QObject, QTimer # flake8: noqa
|
||||
except ImportError:
|
||||
raise ImportError('QtScheduler requires either PyQt5, PyQt4 or PySide installed')
|
||||
|
||||
|
||||
class QtScheduler(BaseScheduler):
|
||||
"""A scheduler that runs in a Qt event loop."""
|
||||
|
||||
_timer = None
|
||||
|
||||
def start(self):
|
||||
super(QtScheduler, self).start()
|
||||
self.wakeup()
|
||||
|
||||
def shutdown(self, wait=True):
|
||||
super(QtScheduler, self).shutdown(wait)
|
||||
self._stop_timer()
|
||||
|
||||
def _start_timer(self, wait_seconds):
|
||||
self._stop_timer()
|
||||
if wait_seconds is not None:
|
||||
self._timer = QTimer.singleShot(wait_seconds * 1000, self._process_jobs)
|
||||
|
||||
def _stop_timer(self):
|
||||
if self._timer:
|
||||
if self._timer.isActive():
|
||||
self._timer.stop()
|
||||
del self._timer
|
||||
|
||||
def wakeup(self):
|
||||
self._start_timer(0)
|
||||
|
||||
def _process_jobs(self):
|
||||
wait_seconds = super(QtScheduler, self)._process_jobs()
|
||||
self._start_timer(wait_seconds)
|
||||
@@ -0,0 +1,60 @@
|
||||
from __future__ import absolute_import
|
||||
from datetime import timedelta
|
||||
from functools import wraps
|
||||
|
||||
from apscheduler.schedulers.base import BaseScheduler
|
||||
from apscheduler.util import maybe_ref
|
||||
|
||||
try:
|
||||
from tornado.ioloop import IOLoop
|
||||
except ImportError: # pragma: nocover
|
||||
raise ImportError('TornadoScheduler requires tornado installed')
|
||||
|
||||
|
||||
def run_in_ioloop(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
self._ioloop.add_callback(func, self, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
class TornadoScheduler(BaseScheduler):
|
||||
"""
|
||||
A scheduler that runs on a Tornado IOLoop.
|
||||
|
||||
=========== ===============================================================
|
||||
``io_loop`` Tornado IOLoop instance to use (defaults to the global IO loop)
|
||||
=========== ===============================================================
|
||||
"""
|
||||
|
||||
_ioloop = None
|
||||
_timeout = None
|
||||
|
||||
def start(self):
|
||||
super(TornadoScheduler, self).start()
|
||||
self.wakeup()
|
||||
|
||||
@run_in_ioloop
|
||||
def shutdown(self, wait=True):
|
||||
super(TornadoScheduler, self).shutdown(wait)
|
||||
self._stop_timer()
|
||||
|
||||
def _configure(self, config):
|
||||
self._ioloop = maybe_ref(config.pop('io_loop', None)) or IOLoop.current()
|
||||
super(TornadoScheduler, self)._configure(config)
|
||||
|
||||
def _start_timer(self, wait_seconds):
|
||||
self._stop_timer()
|
||||
if wait_seconds is not None:
|
||||
self._timeout = self._ioloop.add_timeout(timedelta(seconds=wait_seconds), self.wakeup)
|
||||
|
||||
def _stop_timer(self):
|
||||
if self._timeout:
|
||||
self._ioloop.remove_timeout(self._timeout)
|
||||
del self._timeout
|
||||
|
||||
@run_in_ioloop
|
||||
def wakeup(self):
|
||||
self._stop_timer()
|
||||
wait_seconds = self._process_jobs()
|
||||
self._start_timer(wait_seconds)
|
||||
@@ -0,0 +1,65 @@
|
||||
from __future__ import absolute_import
|
||||
from functools import wraps
|
||||
|
||||
from apscheduler.schedulers.base import BaseScheduler
|
||||
from apscheduler.util import maybe_ref
|
||||
|
||||
try:
|
||||
from twisted.internet import reactor as default_reactor
|
||||
except ImportError: # pragma: nocover
|
||||
raise ImportError('TwistedScheduler requires Twisted installed')
|
||||
|
||||
|
||||
def run_in_reactor(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
self._reactor.callFromThread(func, self, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
class TwistedScheduler(BaseScheduler):
|
||||
"""
|
||||
A scheduler that runs on a Twisted reactor.
|
||||
|
||||
Extra options:
|
||||
|
||||
=========== ========================================================
|
||||
``reactor`` Reactor instance to use (defaults to the global reactor)
|
||||
=========== ========================================================
|
||||
"""
|
||||
|
||||
_reactor = None
|
||||
_delayedcall = None
|
||||
|
||||
def _configure(self, config):
|
||||
self._reactor = maybe_ref(config.pop('reactor', default_reactor))
|
||||
super(TwistedScheduler, self)._configure(config)
|
||||
|
||||
def start(self):
|
||||
super(TwistedScheduler, self).start()
|
||||
self.wakeup()
|
||||
|
||||
@run_in_reactor
|
||||
def shutdown(self, wait=True):
|
||||
super(TwistedScheduler, self).shutdown(wait)
|
||||
self._stop_timer()
|
||||
|
||||
def _start_timer(self, wait_seconds):
|
||||
self._stop_timer()
|
||||
if wait_seconds is not None:
|
||||
self._delayedcall = self._reactor.callLater(wait_seconds, self.wakeup)
|
||||
|
||||
def _stop_timer(self):
|
||||
if self._delayedcall and self._delayedcall.active():
|
||||
self._delayedcall.cancel()
|
||||
del self._delayedcall
|
||||
|
||||
@run_in_reactor
|
||||
def wakeup(self):
|
||||
self._stop_timer()
|
||||
wait_seconds = self._process_jobs()
|
||||
self._start_timer(wait_seconds)
|
||||
|
||||
def _create_default_executor(self):
|
||||
from apscheduler.executors.twisted import TwistedExecutor
|
||||
return TwistedExecutor()
|
||||
@@ -0,0 +1,16 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import six
|
||||
|
||||
|
||||
class BaseTrigger(six.with_metaclass(ABCMeta)):
|
||||
"""Abstract base class that defines the interface that every trigger must implement."""
|
||||
|
||||
@abstractmethod
|
||||
def get_next_fire_time(self, previous_fire_time, now):
|
||||
"""
|
||||
Returns the next datetime to fire on, If no such datetime can be calculated, returns ``None``.
|
||||
|
||||
:param datetime.datetime previous_fire_time: the previous time the trigger was fired
|
||||
:param datetime.datetime now: current datetime
|
||||
"""
|
||||
@@ -0,0 +1,176 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from tzlocal import get_localzone
|
||||
import six
|
||||
|
||||
from apscheduler.triggers.base import BaseTrigger
|
||||
from apscheduler.triggers.cron.fields import BaseField, WeekField, DayOfMonthField, DayOfWeekField, DEFAULT_VALUES
|
||||
from apscheduler.util import datetime_ceil, convert_to_datetime, datetime_repr, astimezone
|
||||
|
||||
|
||||
class CronTrigger(BaseTrigger):
|
||||
"""
|
||||
Triggers when current time matches all specified time constraints, similarly to how the UNIX cron scheduler works.
|
||||
|
||||
:param int|str year: 4-digit year
|
||||
:param int|str month: month (1-12)
|
||||
:param int|str day: day of the (1-31)
|
||||
:param int|str week: ISO week (1-53)
|
||||
:param int|str day_of_week: number or name of weekday (0-6 or mon,tue,wed,thu,fri,sat,sun)
|
||||
:param int|str hour: hour (0-23)
|
||||
:param int|str minute: minute (0-59)
|
||||
:param int|str second: second (0-59)
|
||||
:param datetime|str start_date: earliest possible date/time to trigger on (inclusive)
|
||||
:param datetime|str end_date: latest possible date/time to trigger on (inclusive)
|
||||
:param datetime.tzinfo|str timezone: time zone to use for the date/time calculations
|
||||
(defaults to scheduler timezone)
|
||||
|
||||
.. note:: The first weekday is always **monday**.
|
||||
"""
|
||||
|
||||
FIELD_NAMES = ('year', 'month', 'day', 'week', 'day_of_week', 'hour', 'minute', 'second')
|
||||
FIELDS_MAP = {
|
||||
'year': BaseField,
|
||||
'month': BaseField,
|
||||
'week': WeekField,
|
||||
'day': DayOfMonthField,
|
||||
'day_of_week': DayOfWeekField,
|
||||
'hour': BaseField,
|
||||
'minute': BaseField,
|
||||
'second': BaseField
|
||||
}
|
||||
|
||||
__slots__ = 'timezone', 'start_date', 'end_date', 'fields'
|
||||
|
||||
def __init__(self, year=None, month=None, day=None, week=None, day_of_week=None, hour=None, minute=None,
|
||||
second=None, start_date=None, end_date=None, timezone=None):
|
||||
if timezone:
|
||||
self.timezone = astimezone(timezone)
|
||||
elif start_date and start_date.tzinfo:
|
||||
self.timezone = start_date.tzinfo
|
||||
elif end_date and end_date.tzinfo:
|
||||
self.timezone = end_date.tzinfo
|
||||
else:
|
||||
self.timezone = get_localzone()
|
||||
|
||||
self.start_date = convert_to_datetime(start_date, self.timezone, 'start_date')
|
||||
self.end_date = convert_to_datetime(end_date, self.timezone, 'end_date')
|
||||
|
||||
values = dict((key, value) for (key, value) in six.iteritems(locals())
|
||||
if key in self.FIELD_NAMES and value is not None)
|
||||
self.fields = []
|
||||
assign_defaults = False
|
||||
for field_name in self.FIELD_NAMES:
|
||||
if field_name in values:
|
||||
exprs = values.pop(field_name)
|
||||
is_default = False
|
||||
assign_defaults = not values
|
||||
elif assign_defaults:
|
||||
exprs = DEFAULT_VALUES[field_name]
|
||||
is_default = True
|
||||
else:
|
||||
exprs = '*'
|
||||
is_default = True
|
||||
|
||||
field_class = self.FIELDS_MAP[field_name]
|
||||
field = field_class(field_name, exprs, is_default)
|
||||
self.fields.append(field)
|
||||
|
||||
def _increment_field_value(self, dateval, fieldnum):
|
||||
"""
|
||||
Increments the designated field and resets all less significant fields to their minimum values.
|
||||
|
||||
:type dateval: datetime
|
||||
:type fieldnum: int
|
||||
:return: a tuple containing the new date, and the number of the field that was actually incremented
|
||||
:rtype: tuple
|
||||
"""
|
||||
|
||||
values = {}
|
||||
i = 0
|
||||
while i < len(self.fields):
|
||||
field = self.fields[i]
|
||||
if not field.REAL:
|
||||
if i == fieldnum:
|
||||
fieldnum -= 1
|
||||
i -= 1
|
||||
else:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if i < fieldnum:
|
||||
values[field.name] = field.get_value(dateval)
|
||||
i += 1
|
||||
elif i > fieldnum:
|
||||
values[field.name] = field.get_min(dateval)
|
||||
i += 1
|
||||
else:
|
||||
value = field.get_value(dateval)
|
||||
maxval = field.get_max(dateval)
|
||||
if value == maxval:
|
||||
fieldnum -= 1
|
||||
i -= 1
|
||||
else:
|
||||
values[field.name] = value + 1
|
||||
i += 1
|
||||
|
||||
difference = datetime(**values) - dateval.replace(tzinfo=None)
|
||||
return self.timezone.normalize(dateval + difference), fieldnum
|
||||
|
||||
def _set_field_value(self, dateval, fieldnum, new_value):
|
||||
values = {}
|
||||
for i, field in enumerate(self.fields):
|
||||
if field.REAL:
|
||||
if i < fieldnum:
|
||||
values[field.name] = field.get_value(dateval)
|
||||
elif i > fieldnum:
|
||||
values[field.name] = field.get_min(dateval)
|
||||
else:
|
||||
values[field.name] = new_value
|
||||
|
||||
difference = datetime(**values) - dateval.replace(tzinfo=None)
|
||||
return self.timezone.normalize(dateval + difference)
|
||||
|
||||
def get_next_fire_time(self, previous_fire_time, now):
|
||||
if previous_fire_time:
|
||||
start_date = max(now, previous_fire_time + timedelta(microseconds=1))
|
||||
else:
|
||||
start_date = max(now, self.start_date) if self.start_date else now
|
||||
|
||||
fieldnum = 0
|
||||
next_date = datetime_ceil(start_date).astimezone(self.timezone)
|
||||
while 0 <= fieldnum < len(self.fields):
|
||||
field = self.fields[fieldnum]
|
||||
curr_value = field.get_value(next_date)
|
||||
next_value = field.get_next_value(next_date)
|
||||
|
||||
if next_value is None:
|
||||
# No valid value was found
|
||||
next_date, fieldnum = self._increment_field_value(next_date, fieldnum - 1)
|
||||
elif next_value > curr_value:
|
||||
# A valid, but higher than the starting value, was found
|
||||
if field.REAL:
|
||||
next_date = self._set_field_value(next_date, fieldnum, next_value)
|
||||
fieldnum += 1
|
||||
else:
|
||||
next_date, fieldnum = self._increment_field_value(next_date, fieldnum)
|
||||
else:
|
||||
# A valid value was found, no changes necessary
|
||||
fieldnum += 1
|
||||
|
||||
# Return if the date has rolled past the end date
|
||||
if self.end_date and next_date > self.end_date:
|
||||
return None
|
||||
|
||||
if fieldnum >= 0:
|
||||
return next_date
|
||||
|
||||
def __str__(self):
|
||||
options = ["%s='%s'" % (f.name, f) for f in self.fields if not f.is_default]
|
||||
return 'cron[%s]' % (', '.join(options))
|
||||
|
||||
def __repr__(self):
|
||||
options = ["%s='%s'" % (f.name, f) for f in self.fields if not f.is_default]
|
||||
if self.start_date:
|
||||
options.append("start_date='%s'" % datetime_repr(self.start_date))
|
||||
return '<%s (%s)>' % (self.__class__.__name__, ', '.join(options))
|
||||
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
This module contains the expressions applicable for CronTrigger's fields.
|
||||
"""
|
||||
|
||||
from calendar import monthrange
|
||||
import re
|
||||
|
||||
from apscheduler.util import asint
|
||||
|
||||
__all__ = ('AllExpression', 'RangeExpression', 'WeekdayRangeExpression', 'WeekdayPositionExpression',
|
||||
'LastDayOfMonthExpression')
|
||||
|
||||
|
||||
WEEKDAYS = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun']
|
||||
|
||||
|
||||
class AllExpression(object):
|
||||
value_re = re.compile(r'\*(?:/(?P<step>\d+))?$')
|
||||
|
||||
def __init__(self, step=None):
|
||||
self.step = asint(step)
|
||||
if self.step == 0:
|
||||
raise ValueError('Increment must be higher than 0')
|
||||
|
||||
def get_next_value(self, date, field):
|
||||
start = field.get_value(date)
|
||||
minval = field.get_min(date)
|
||||
maxval = field.get_max(date)
|
||||
start = max(start, minval)
|
||||
|
||||
if not self.step:
|
||||
next = start
|
||||
else:
|
||||
distance_to_next = (self.step - (start - minval)) % self.step
|
||||
next = start + distance_to_next
|
||||
|
||||
if next <= maxval:
|
||||
return next
|
||||
|
||||
def __str__(self):
|
||||
if self.step:
|
||||
return '*/%d' % self.step
|
||||
return '*'
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%s)" % (self.__class__.__name__, self.step)
|
||||
|
||||
|
||||
class RangeExpression(AllExpression):
|
||||
value_re = re.compile(
|
||||
r'(?P<first>\d+)(?:-(?P<last>\d+))?(?:/(?P<step>\d+))?$')
|
||||
|
||||
def __init__(self, first, last=None, step=None):
|
||||
AllExpression.__init__(self, step)
|
||||
first = asint(first)
|
||||
last = asint(last)
|
||||
if last is None and step is None:
|
||||
last = first
|
||||
if last is not None and first > last:
|
||||
raise ValueError('The minimum value in a range must not be higher than the maximum')
|
||||
self.first = first
|
||||
self.last = last
|
||||
|
||||
def get_next_value(self, date, field):
|
||||
start = field.get_value(date)
|
||||
minval = field.get_min(date)
|
||||
maxval = field.get_max(date)
|
||||
|
||||
# Apply range limits
|
||||
minval = max(minval, self.first)
|
||||
if self.last is not None:
|
||||
maxval = min(maxval, self.last)
|
||||
start = max(start, minval)
|
||||
|
||||
if not self.step:
|
||||
next = start
|
||||
else:
|
||||
distance_to_next = (self.step - (start - minval)) % self.step
|
||||
next = start + distance_to_next
|
||||
|
||||
if next <= maxval:
|
||||
return next
|
||||
|
||||
def __str__(self):
|
||||
if self.last != self.first and self.last is not None:
|
||||
range = '%d-%d' % (self.first, self.last)
|
||||
else:
|
||||
range = str(self.first)
|
||||
|
||||
if self.step:
|
||||
return '%s/%d' % (range, self.step)
|
||||
return range
|
||||
|
||||
def __repr__(self):
|
||||
args = [str(self.first)]
|
||||
if self.last != self.first and self.last is not None or self.step:
|
||||
args.append(str(self.last))
|
||||
if self.step:
|
||||
args.append(str(self.step))
|
||||
return "%s(%s)" % (self.__class__.__name__, ', '.join(args))
|
||||
|
||||
|
||||
class WeekdayRangeExpression(RangeExpression):
|
||||
value_re = re.compile(r'(?P<first>[a-z]+)(?:-(?P<last>[a-z]+))?', re.IGNORECASE)
|
||||
|
||||
def __init__(self, first, last=None):
|
||||
try:
|
||||
first_num = WEEKDAYS.index(first.lower())
|
||||
except ValueError:
|
||||
raise ValueError('Invalid weekday name "%s"' % first)
|
||||
|
||||
if last:
|
||||
try:
|
||||
last_num = WEEKDAYS.index(last.lower())
|
||||
except ValueError:
|
||||
raise ValueError('Invalid weekday name "%s"' % last)
|
||||
else:
|
||||
last_num = None
|
||||
|
||||
RangeExpression.__init__(self, first_num, last_num)
|
||||
|
||||
def __str__(self):
|
||||
if self.last != self.first and self.last is not None:
|
||||
return '%s-%s' % (WEEKDAYS[self.first], WEEKDAYS[self.last])
|
||||
return WEEKDAYS[self.first]
|
||||
|
||||
def __repr__(self):
|
||||
args = ["'%s'" % WEEKDAYS[self.first]]
|
||||
if self.last != self.first and self.last is not None:
|
||||
args.append("'%s'" % WEEKDAYS[self.last])
|
||||
return "%s(%s)" % (self.__class__.__name__, ', '.join(args))
|
||||
|
||||
|
||||
class WeekdayPositionExpression(AllExpression):
|
||||
options = ['1st', '2nd', '3rd', '4th', '5th', 'last']
|
||||
value_re = re.compile(r'(?P<option_name>%s) +(?P<weekday_name>(?:\d+|\w+))' % '|'.join(options), re.IGNORECASE)
|
||||
|
||||
def __init__(self, option_name, weekday_name):
|
||||
try:
|
||||
self.option_num = self.options.index(option_name.lower())
|
||||
except ValueError:
|
||||
raise ValueError('Invalid weekday position "%s"' % option_name)
|
||||
|
||||
try:
|
||||
self.weekday = WEEKDAYS.index(weekday_name.lower())
|
||||
except ValueError:
|
||||
raise ValueError('Invalid weekday name "%s"' % weekday_name)
|
||||
|
||||
def get_next_value(self, date, field):
|
||||
# Figure out the weekday of the month's first day and the number
|
||||
# of days in that month
|
||||
first_day_wday, last_day = monthrange(date.year, date.month)
|
||||
|
||||
# Calculate which day of the month is the first of the target weekdays
|
||||
first_hit_day = self.weekday - first_day_wday + 1
|
||||
if first_hit_day <= 0:
|
||||
first_hit_day += 7
|
||||
|
||||
# Calculate what day of the month the target weekday would be
|
||||
if self.option_num < 5:
|
||||
target_day = first_hit_day + self.option_num * 7
|
||||
else:
|
||||
target_day = first_hit_day + ((last_day - first_hit_day) / 7) * 7
|
||||
|
||||
if target_day <= last_day and target_day >= date.day:
|
||||
return target_day
|
||||
|
||||
def __str__(self):
|
||||
return '%s %s' % (self.options[self.option_num], WEEKDAYS[self.weekday])
|
||||
|
||||
def __repr__(self):
|
||||
return "%s('%s', '%s')" % (self.__class__.__name__, self.options[self.option_num], WEEKDAYS[self.weekday])
|
||||
|
||||
|
||||
class LastDayOfMonthExpression(AllExpression):
|
||||
value_re = re.compile(r'last', re.IGNORECASE)
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_next_value(self, date, field):
|
||||
return monthrange(date.year, date.month)[1]
|
||||
|
||||
def __str__(self):
|
||||
return 'last'
|
||||
|
||||
def __repr__(self):
|
||||
return "%s()" % self.__class__.__name__
|
||||
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Fields represent CronTrigger options which map to :class:`~datetime.datetime`
|
||||
fields.
|
||||
"""
|
||||
|
||||
from calendar import monthrange
|
||||
|
||||
from apscheduler.triggers.cron.expressions import (
|
||||
AllExpression, RangeExpression, WeekdayPositionExpression, LastDayOfMonthExpression, WeekdayRangeExpression)
|
||||
|
||||
|
||||
__all__ = ('MIN_VALUES', 'MAX_VALUES', 'DEFAULT_VALUES', 'BaseField', 'WeekField', 'DayOfMonthField', 'DayOfWeekField')
|
||||
|
||||
|
||||
MIN_VALUES = {'year': 1970, 'month': 1, 'day': 1, 'week': 1, 'day_of_week': 0, 'hour': 0, 'minute': 0, 'second': 0}
|
||||
MAX_VALUES = {'year': 2 ** 63, 'month': 12, 'day:': 31, 'week': 53, 'day_of_week': 6, 'hour': 23, 'minute': 59,
|
||||
'second': 59}
|
||||
DEFAULT_VALUES = {'year': '*', 'month': 1, 'day': 1, 'week': '*', 'day_of_week': '*', 'hour': 0, 'minute': 0,
|
||||
'second': 0}
|
||||
|
||||
|
||||
class BaseField(object):
|
||||
REAL = True
|
||||
COMPILERS = [AllExpression, RangeExpression]
|
||||
|
||||
def __init__(self, name, exprs, is_default=False):
|
||||
self.name = name
|
||||
self.is_default = is_default
|
||||
self.compile_expressions(exprs)
|
||||
|
||||
def get_min(self, dateval):
|
||||
return MIN_VALUES[self.name]
|
||||
|
||||
def get_max(self, dateval):
|
||||
return MAX_VALUES[self.name]
|
||||
|
||||
def get_value(self, dateval):
|
||||
return getattr(dateval, self.name)
|
||||
|
||||
def get_next_value(self, dateval):
|
||||
smallest = None
|
||||
for expr in self.expressions:
|
||||
value = expr.get_next_value(dateval, self)
|
||||
if smallest is None or (value is not None and value < smallest):
|
||||
smallest = value
|
||||
|
||||
return smallest
|
||||
|
||||
def compile_expressions(self, exprs):
|
||||
self.expressions = []
|
||||
|
||||
# Split a comma-separated expression list, if any
|
||||
exprs = str(exprs).strip()
|
||||
if ',' in exprs:
|
||||
for expr in exprs.split(','):
|
||||
self.compile_expression(expr)
|
||||
else:
|
||||
self.compile_expression(exprs)
|
||||
|
||||
def compile_expression(self, expr):
|
||||
for compiler in self.COMPILERS:
|
||||
match = compiler.value_re.match(expr)
|
||||
if match:
|
||||
compiled_expr = compiler(**match.groupdict())
|
||||
self.expressions.append(compiled_expr)
|
||||
return
|
||||
|
||||
raise ValueError('Unrecognized expression "%s" for field "%s"' % (expr, self.name))
|
||||
|
||||
def __str__(self):
|
||||
expr_strings = (str(e) for e in self.expressions)
|
||||
return ','.join(expr_strings)
|
||||
|
||||
def __repr__(self):
|
||||
return "%s('%s', '%s')" % (self.__class__.__name__, self.name, self)
|
||||
|
||||
|
||||
class WeekField(BaseField):
|
||||
REAL = False
|
||||
|
||||
def get_value(self, dateval):
|
||||
return dateval.isocalendar()[1]
|
||||
|
||||
|
||||
class DayOfMonthField(BaseField):
|
||||
COMPILERS = BaseField.COMPILERS + [WeekdayPositionExpression, LastDayOfMonthExpression]
|
||||
|
||||
def get_max(self, dateval):
|
||||
return monthrange(dateval.year, dateval.month)[1]
|
||||
|
||||
|
||||
class DayOfWeekField(BaseField):
|
||||
REAL = False
|
||||
COMPILERS = BaseField.COMPILERS + [WeekdayRangeExpression]
|
||||
|
||||
def get_value(self, dateval):
|
||||
return dateval.weekday()
|
||||
@@ -0,0 +1,30 @@
|
||||
from datetime import datetime
|
||||
|
||||
from tzlocal import get_localzone
|
||||
|
||||
from apscheduler.triggers.base import BaseTrigger
|
||||
from apscheduler.util import convert_to_datetime, datetime_repr, astimezone
|
||||
|
||||
|
||||
class DateTrigger(BaseTrigger):
|
||||
"""
|
||||
Triggers once on the given datetime. If ``run_date`` is left empty, current time is used.
|
||||
|
||||
:param datetime|str run_date: the date/time to run the job at
|
||||
:param datetime.tzinfo|str timezone: time zone for ``run_date`` if it doesn't have one already
|
||||
"""
|
||||
|
||||
__slots__ = 'timezone', 'run_date'
|
||||
|
||||
def __init__(self, run_date=None, timezone=None):
|
||||
timezone = astimezone(timezone) or get_localzone()
|
||||
self.run_date = convert_to_datetime(run_date or datetime.now(), timezone, 'run_date')
|
||||
|
||||
def get_next_fire_time(self, previous_fire_time, now):
|
||||
return self.run_date if previous_fire_time is None else None
|
||||
|
||||
def __str__(self):
|
||||
return 'date[%s]' % datetime_repr(self.run_date)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s (run_date='%s')>" % (self.__class__.__name__, datetime_repr(self.run_date))
|
||||
@@ -0,0 +1,65 @@
|
||||
from datetime import timedelta, datetime
|
||||
from math import ceil
|
||||
|
||||
from tzlocal import get_localzone
|
||||
|
||||
from apscheduler.triggers.base import BaseTrigger
|
||||
from apscheduler.util import convert_to_datetime, timedelta_seconds, datetime_repr, astimezone
|
||||
|
||||
|
||||
class IntervalTrigger(BaseTrigger):
|
||||
"""
|
||||
Triggers on specified intervals, starting on ``start_date`` if specified, ``datetime.now()`` + interval
|
||||
otherwise.
|
||||
|
||||
:param int weeks: number of weeks to wait
|
||||
:param int days: number of days to wait
|
||||
:param int hours: number of hours to wait
|
||||
:param int minutes: number of minutes to wait
|
||||
:param int seconds: number of seconds to wait
|
||||
:param datetime|str start_date: starting point for the interval calculation
|
||||
:param datetime|str end_date: latest possible date/time to trigger on
|
||||
:param datetime.tzinfo|str timezone: time zone to use for the date/time calculations
|
||||
"""
|
||||
|
||||
__slots__ = 'timezone', 'start_date', 'end_date', 'interval'
|
||||
|
||||
def __init__(self, weeks=0, days=0, hours=0, minutes=0, seconds=0, start_date=None, end_date=None, timezone=None):
|
||||
self.interval = timedelta(weeks=weeks, days=days, hours=hours, minutes=minutes, seconds=seconds)
|
||||
self.interval_length = timedelta_seconds(self.interval)
|
||||
if self.interval_length == 0:
|
||||
self.interval = timedelta(seconds=1)
|
||||
self.interval_length = 1
|
||||
|
||||
if timezone:
|
||||
self.timezone = astimezone(timezone)
|
||||
elif start_date and start_date.tzinfo:
|
||||
self.timezone = start_date.tzinfo
|
||||
elif end_date and end_date.tzinfo:
|
||||
self.timezone = end_date.tzinfo
|
||||
else:
|
||||
self.timezone = get_localzone()
|
||||
|
||||
start_date = start_date or (datetime.now(self.timezone) + self.interval)
|
||||
self.start_date = convert_to_datetime(start_date, self.timezone, 'start_date')
|
||||
self.end_date = convert_to_datetime(end_date, self.timezone, 'end_date')
|
||||
|
||||
def get_next_fire_time(self, previous_fire_time, now):
|
||||
if previous_fire_time:
|
||||
next_fire_time = previous_fire_time + self.interval
|
||||
elif self.start_date > now:
|
||||
next_fire_time = self.start_date
|
||||
else:
|
||||
timediff_seconds = timedelta_seconds(now - self.start_date)
|
||||
next_interval_num = int(ceil(timediff_seconds / self.interval_length))
|
||||
next_fire_time = self.start_date + self.interval * next_interval_num
|
||||
|
||||
if not self.end_date or next_fire_time <= self.end_date:
|
||||
return self.timezone.normalize(next_fire_time)
|
||||
|
||||
def __str__(self):
|
||||
return 'interval[%s]' % str(self.interval)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s (interval=%r, start_date='%s')>" % (self.__class__.__name__, self.interval,
|
||||
datetime_repr(self.start_date))
|
||||
@@ -0,0 +1,385 @@
|
||||
"""This module contains several handy functions primarily meant for internal use."""
|
||||
|
||||
from __future__ import division
|
||||
from datetime import date, datetime, time, timedelta, tzinfo
|
||||
from inspect import isfunction, ismethod, getargspec
|
||||
from calendar import timegm
|
||||
import re
|
||||
|
||||
from pytz import timezone, utc
|
||||
import six
|
||||
|
||||
try:
|
||||
from inspect import signature
|
||||
except ImportError: # pragma: nocover
|
||||
try:
|
||||
from funcsigs import signature
|
||||
except ImportError:
|
||||
signature = None
|
||||
|
||||
__all__ = ('asint', 'asbool', 'astimezone', 'convert_to_datetime', 'datetime_to_utc_timestamp',
|
||||
'utc_timestamp_to_datetime', 'timedelta_seconds', 'datetime_ceil', 'get_callable_name', 'obj_to_ref',
|
||||
'ref_to_obj', 'maybe_ref', 'repr_escape', 'check_callable_args')
|
||||
|
||||
|
||||
class _Undefined(object):
|
||||
def __nonzero__(self):
|
||||
return False
|
||||
|
||||
def __bool__(self):
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return '<undefined>'
|
||||
|
||||
undefined = _Undefined() #: a unique object that only signifies that no value is defined
|
||||
|
||||
|
||||
def asint(text):
|
||||
"""
|
||||
Safely converts a string to an integer, returning None if the string is None.
|
||||
|
||||
:type text: str
|
||||
:rtype: int
|
||||
"""
|
||||
|
||||
if text is not None:
|
||||
return int(text)
|
||||
|
||||
|
||||
def asbool(obj):
|
||||
"""
|
||||
Interprets an object as a boolean value.
|
||||
|
||||
:rtype: bool
|
||||
"""
|
||||
|
||||
if isinstance(obj, str):
|
||||
obj = obj.strip().lower()
|
||||
if obj in ('true', 'yes', 'on', 'y', 't', '1'):
|
||||
return True
|
||||
if obj in ('false', 'no', 'off', 'n', 'f', '0'):
|
||||
return False
|
||||
raise ValueError('Unable to interpret value "%s" as boolean' % obj)
|
||||
return bool(obj)
|
||||
|
||||
|
||||
def astimezone(obj):
|
||||
"""
|
||||
Interprets an object as a timezone.
|
||||
|
||||
:rtype: tzinfo
|
||||
"""
|
||||
|
||||
if isinstance(obj, six.string_types):
|
||||
return timezone(obj)
|
||||
if isinstance(obj, tzinfo):
|
||||
if not hasattr(obj, 'localize') or not hasattr(obj, 'normalize'):
|
||||
raise TypeError('Only timezones from the pytz library are supported')
|
||||
if obj.zone == 'local':
|
||||
raise ValueError('Unable to determine the name of the local timezone -- use an explicit timezone instead')
|
||||
return obj
|
||||
if obj is not None:
|
||||
raise TypeError('Expected tzinfo, got %s instead' % obj.__class__.__name__)
|
||||
|
||||
|
||||
_DATE_REGEX = re.compile(
|
||||
r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
|
||||
r'(?: (?P<hour>\d{1,2}):(?P<minute>\d{1,2}):(?P<second>\d{1,2})'
|
||||
r'(?:\.(?P<microsecond>\d{1,6}))?)?')
|
||||
|
||||
|
||||
def convert_to_datetime(input, tz, arg_name):
|
||||
"""
|
||||
Converts the given object to a timezone aware datetime object.
|
||||
If a timezone aware datetime object is passed, it is returned unmodified.
|
||||
If a native datetime object is passed, it is given the specified timezone.
|
||||
If the input is a string, it is parsed as a datetime with the given timezone.
|
||||
|
||||
Date strings are accepted in three different forms: date only (Y-m-d),
|
||||
date with time (Y-m-d H:M:S) or with date+time with microseconds
|
||||
(Y-m-d H:M:S.micro).
|
||||
|
||||
:param str|datetime input: the datetime or string to convert to a timezone aware datetime
|
||||
:param datetime.tzinfo tz: timezone to interpret ``input`` in
|
||||
:param str arg_name: the name of the argument (used in an error message)
|
||||
:rtype: datetime
|
||||
"""
|
||||
|
||||
if input is None:
|
||||
return
|
||||
elif isinstance(input, datetime):
|
||||
datetime_ = input
|
||||
elif isinstance(input, date):
|
||||
datetime_ = datetime.combine(input, time())
|
||||
elif isinstance(input, six.string_types):
|
||||
m = _DATE_REGEX.match(input)
|
||||
if not m:
|
||||
raise ValueError('Invalid date string')
|
||||
values = [(k, int(v or 0)) for k, v in m.groupdict().items()]
|
||||
values = dict(values)
|
||||
datetime_ = datetime(**values)
|
||||
else:
|
||||
raise TypeError('Unsupported type for %s: %s' % (arg_name, input.__class__.__name__))
|
||||
|
||||
if datetime_.tzinfo is not None:
|
||||
return datetime_
|
||||
if tz is None:
|
||||
raise ValueError('The "tz" argument must be specified if %s has no timezone information' % arg_name)
|
||||
if isinstance(tz, six.string_types):
|
||||
tz = timezone(tz)
|
||||
|
||||
try:
|
||||
return tz.localize(datetime_, is_dst=None)
|
||||
except AttributeError:
|
||||
raise TypeError('Only pytz timezones are supported (need the localize() and normalize() methods)')
|
||||
|
||||
|
||||
def datetime_to_utc_timestamp(timeval):
|
||||
"""
|
||||
Converts a datetime instance to a timestamp.
|
||||
|
||||
:type timeval: datetime
|
||||
:rtype: float
|
||||
"""
|
||||
|
||||
if timeval is not None:
|
||||
return timegm(timeval.utctimetuple()) + timeval.microsecond / 1000000
|
||||
|
||||
|
||||
def utc_timestamp_to_datetime(timestamp):
|
||||
"""
|
||||
Converts the given timestamp to a datetime instance.
|
||||
|
||||
:type timestamp: float
|
||||
:rtype: datetime
|
||||
"""
|
||||
|
||||
if timestamp is not None:
|
||||
return datetime.fromtimestamp(timestamp, utc)
|
||||
|
||||
|
||||
def timedelta_seconds(delta):
|
||||
"""
|
||||
Converts the given timedelta to seconds.
|
||||
|
||||
:type delta: timedelta
|
||||
:rtype: float
|
||||
"""
|
||||
|
||||
return delta.days * 24 * 60 * 60 + delta.seconds + \
|
||||
delta.microseconds / 1000000.0
|
||||
|
||||
|
||||
def datetime_ceil(dateval):
|
||||
"""
|
||||
Rounds the given datetime object upwards.
|
||||
|
||||
:type dateval: datetime
|
||||
"""
|
||||
|
||||
if dateval.microsecond > 0:
|
||||
return dateval + timedelta(seconds=1, microseconds=-dateval.microsecond)
|
||||
return dateval
|
||||
|
||||
|
||||
def datetime_repr(dateval):
|
||||
return dateval.strftime('%Y-%m-%d %H:%M:%S %Z') if dateval else 'None'
|
||||
|
||||
|
||||
def get_callable_name(func):
|
||||
"""
|
||||
Returns the best available display name for the given function/callable.
|
||||
|
||||
:rtype: str
|
||||
"""
|
||||
|
||||
# the easy case (on Python 3.3+)
|
||||
if hasattr(func, '__qualname__'):
|
||||
return func.__qualname__
|
||||
|
||||
# class methods, bound and unbound methods
|
||||
f_self = getattr(func, '__self__', None) or getattr(func, 'im_self', None)
|
||||
if f_self and hasattr(func, '__name__'):
|
||||
f_class = f_self if isinstance(f_self, type) else f_self.__class__
|
||||
else:
|
||||
f_class = getattr(func, 'im_class', None)
|
||||
|
||||
if f_class and hasattr(func, '__name__'):
|
||||
return '%s.%s' % (f_class.__name__, func.__name__)
|
||||
|
||||
# class or class instance
|
||||
if hasattr(func, '__call__'):
|
||||
# class
|
||||
if hasattr(func, '__name__'):
|
||||
return func.__name__
|
||||
|
||||
# instance of a class with a __call__ method
|
||||
return func.__class__.__name__
|
||||
|
||||
raise TypeError('Unable to determine a name for %r -- maybe it is not a callable?' % func)
|
||||
|
||||
|
||||
def obj_to_ref(obj):
|
||||
"""
|
||||
Returns the path to the given object.
|
||||
|
||||
:rtype: str
|
||||
"""
|
||||
|
||||
try:
|
||||
ref = '%s:%s' % (obj.__module__, get_callable_name(obj))
|
||||
obj2 = ref_to_obj(ref)
|
||||
if obj != obj2:
|
||||
raise ValueError
|
||||
except Exception:
|
||||
raise ValueError('Cannot determine the reference to %r' % obj)
|
||||
|
||||
return ref
|
||||
|
||||
|
||||
def ref_to_obj(ref):
|
||||
"""
|
||||
Returns the object pointed to by ``ref``.
|
||||
|
||||
:type ref: str
|
||||
"""
|
||||
|
||||
if not isinstance(ref, six.string_types):
|
||||
raise TypeError('References must be strings')
|
||||
if ':' not in ref:
|
||||
raise ValueError('Invalid reference')
|
||||
|
||||
modulename, rest = ref.split(':', 1)
|
||||
try:
|
||||
obj = __import__(modulename)
|
||||
except ImportError:
|
||||
raise LookupError('Error resolving reference %s: could not import module' % ref)
|
||||
|
||||
try:
|
||||
for name in modulename.split('.')[1:] + rest.split('.'):
|
||||
obj = getattr(obj, name)
|
||||
return obj
|
||||
except Exception:
|
||||
raise LookupError('Error resolving reference %s: error looking up object' % ref)
|
||||
|
||||
|
||||
def maybe_ref(ref):
|
||||
"""
|
||||
Returns the object that the given reference points to, if it is indeed a reference.
|
||||
If it is not a reference, the object is returned as-is.
|
||||
"""
|
||||
|
||||
if not isinstance(ref, str):
|
||||
return ref
|
||||
return ref_to_obj(ref)
|
||||
|
||||
|
||||
if six.PY2:
|
||||
def repr_escape(string):
|
||||
if isinstance(string, six.text_type):
|
||||
return string.encode('ascii', 'backslashreplace')
|
||||
return string
|
||||
else:
|
||||
repr_escape = lambda string: string
|
||||
|
||||
|
||||
def check_callable_args(func, args, kwargs):
|
||||
"""
|
||||
Ensures that the given callable can be called with the given arguments.
|
||||
|
||||
:type args: tuple
|
||||
:type kwargs: dict
|
||||
"""
|
||||
|
||||
pos_kwargs_conflicts = [] # parameters that have a match in both args and kwargs
|
||||
positional_only_kwargs = [] # positional-only parameters that have a match in kwargs
|
||||
unsatisfied_args = [] # parameters in signature that don't have a match in args or kwargs
|
||||
unsatisfied_kwargs = [] # keyword-only arguments that don't have a match in kwargs
|
||||
unmatched_args = list(args) # args that didn't match any of the parameters in the signature
|
||||
unmatched_kwargs = list(kwargs) # kwargs that didn't match any of the parameters in the signature
|
||||
has_varargs = has_var_kwargs = False # indicates if the signature defines *args and **kwargs respectively
|
||||
|
||||
if signature:
|
||||
try:
|
||||
sig = signature(func)
|
||||
except ValueError:
|
||||
return # signature() doesn't work against every kind of callable
|
||||
|
||||
for param in six.itervalues(sig.parameters):
|
||||
if param.kind == param.POSITIONAL_OR_KEYWORD:
|
||||
if param.name in unmatched_kwargs and unmatched_args:
|
||||
pos_kwargs_conflicts.append(param.name)
|
||||
elif unmatched_args:
|
||||
del unmatched_args[0]
|
||||
elif param.name in unmatched_kwargs:
|
||||
unmatched_kwargs.remove(param.name)
|
||||
elif param.default is param.empty:
|
||||
unsatisfied_args.append(param.name)
|
||||
elif param.kind == param.POSITIONAL_ONLY:
|
||||
if unmatched_args:
|
||||
del unmatched_args[0]
|
||||
elif param.name in unmatched_kwargs:
|
||||
unmatched_kwargs.remove(param.name)
|
||||
positional_only_kwargs.append(param.name)
|
||||
elif param.default is param.empty:
|
||||
unsatisfied_args.append(param.name)
|
||||
elif param.kind == param.KEYWORD_ONLY:
|
||||
if param.name in unmatched_kwargs:
|
||||
unmatched_kwargs.remove(param.name)
|
||||
elif param.default is param.empty:
|
||||
unsatisfied_kwargs.append(param.name)
|
||||
elif param.kind == param.VAR_POSITIONAL:
|
||||
has_varargs = True
|
||||
elif param.kind == param.VAR_KEYWORD:
|
||||
has_var_kwargs = True
|
||||
else:
|
||||
if not isfunction(func) and not ismethod(func) and hasattr(func, '__call__'):
|
||||
func = func.__call__
|
||||
|
||||
try:
|
||||
argspec = getargspec(func)
|
||||
except TypeError:
|
||||
return # getargspec() doesn't work certain callables
|
||||
|
||||
argspec_args = argspec.args if not ismethod(func) else argspec.args[1:]
|
||||
has_varargs = bool(argspec.varargs)
|
||||
has_var_kwargs = bool(argspec.keywords)
|
||||
for arg, default in six.moves.zip_longest(argspec_args, argspec.defaults or (), fillvalue=undefined):
|
||||
if arg in unmatched_kwargs and unmatched_args:
|
||||
pos_kwargs_conflicts.append(arg)
|
||||
elif unmatched_args:
|
||||
del unmatched_args[0]
|
||||
elif arg in unmatched_kwargs:
|
||||
unmatched_kwargs.remove(arg)
|
||||
elif default is undefined:
|
||||
unsatisfied_args.append(arg)
|
||||
|
||||
# Make sure there are no conflicts between args and kwargs
|
||||
if pos_kwargs_conflicts:
|
||||
raise ValueError('The following arguments are supplied in both args and kwargs: %s' %
|
||||
', '.join(pos_kwargs_conflicts))
|
||||
|
||||
# Check if keyword arguments are being fed to positional-only parameters
|
||||
if positional_only_kwargs:
|
||||
raise ValueError('The following arguments cannot be given as keyword arguments: %s' %
|
||||
', '.join(positional_only_kwargs))
|
||||
|
||||
# Check that the number of positional arguments minus the number of matched kwargs matches the argspec
|
||||
if unsatisfied_args:
|
||||
raise ValueError('The following arguments have not been supplied: %s' % ', '.join(unsatisfied_args))
|
||||
|
||||
# Check that all keyword-only arguments have been supplied
|
||||
if unsatisfied_kwargs:
|
||||
raise ValueError('The following keyword-only arguments have not been supplied in kwargs: %s' %
|
||||
', '.join(unsatisfied_kwargs))
|
||||
|
||||
# Check that the callable can accept the given number of positional arguments
|
||||
if not has_varargs and unmatched_args:
|
||||
raise ValueError('The list of positional arguments is longer than the target callable can handle '
|
||||
'(allowed: %d, given in args: %d)' % (len(args) - len(unmatched_args), len(args)))
|
||||
|
||||
# Check that the callable can accept the given keyword arguments
|
||||
if not has_var_kwargs and unmatched_kwargs:
|
||||
raise ValueError('The target callable does not accept the following keyword arguments: %s' %
|
||||
', '.join(unmatched_kwargs))
|
||||
+2386
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,21 @@
|
||||
The MIT License
|
||||
|
||||
Copyright (c) 2010-2014 Adrian Sampson
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
@@ -0,0 +1,94 @@
|
||||
.. image:: https://travis-ci.org/sampsyo/beets.svg?branch=master
|
||||
:target: https://travis-ci.org/sampsyo/beets
|
||||
|
||||
.. image:: http://img.shields.io/coveralls/sampsyo/beets.svg
|
||||
:target: https://coveralls.io/r/sampsyo/beets
|
||||
|
||||
.. image:: http://img.shields.io/pypi/v/beets.svg
|
||||
:target: https://pypi.python.org/pypi/beets
|
||||
|
||||
Beets is the media library management system for obsessive-compulsive music
|
||||
geeks.
|
||||
|
||||
The purpose of beets is to get your music collection right once and for all.
|
||||
It catalogs your collection, automatically improving its metadata as it goes.
|
||||
It then provides a bouquet of tools for manipulating and accessing your music.
|
||||
|
||||
Here's an example of beets' brainy tag corrector doing its thing::
|
||||
|
||||
$ beet import ~/music/ladytron
|
||||
Tagging:
|
||||
Ladytron - Witching Hour
|
||||
(Similarity: 98.4%)
|
||||
* Last One Standing -> The Last One Standing
|
||||
* Beauty -> Beauty*2
|
||||
* White Light Generation -> Whitelightgenerator
|
||||
* All the Way -> All the Way...
|
||||
|
||||
Because beets is designed as a library, it can do almost anything you can
|
||||
imagine for your music collection. Via `plugins`_, beets becomes a panacea:
|
||||
|
||||
- Fetch or calculate all the metadata you could possibly need: `album art`_,
|
||||
`lyrics`_, `genres`_, `tempos`_, `ReplayGain`_ levels, or `acoustic
|
||||
fingerprints`_.
|
||||
- Get metadata from `MusicBrainz`_, `Discogs`_, or `Beatport`_. Or guess
|
||||
metadata using songs' filenames or their acoustic fingerprints.
|
||||
- `Transcode audio`_ to any format you like.
|
||||
- Check your library for `duplicate tracks and albums`_ or for `albums that
|
||||
are missing tracks`_.
|
||||
- Clean up crufty tags left behind by other, less-awesome tools.
|
||||
- Embed and extract album art from files' metadata.
|
||||
- Browse your music library graphically through a Web browser and play it in any
|
||||
browser that supports `HTML5 Audio`_.
|
||||
- Analyze music files' metadata from the command line.
|
||||
- Listen to your library with a music player that speaks the `MPD`_ protocol
|
||||
and works with a staggering variety of interfaces.
|
||||
|
||||
If beets doesn't do what you want yet, `writing your own plugin`_ is
|
||||
shockingly simple if you know a little Python.
|
||||
|
||||
.. _plugins: http://beets.readthedocs.org/page/plugins/
|
||||
.. _MPD: http://www.musicpd.org/
|
||||
.. _MusicBrainz music collection: http://musicbrainz.org/doc/Collections/
|
||||
.. _writing your own plugin:
|
||||
http://beets.readthedocs.org/page/dev/plugins.html
|
||||
.. _HTML5 Audio:
|
||||
http://www.w3.org/TR/html-markup/audio.html
|
||||
.. _albums that are missing tracks:
|
||||
http://beets.readthedocs.org/page/plugins/missing.html
|
||||
.. _duplicate tracks and albums:
|
||||
http://beets.readthedocs.org/page/plugins/duplicates.html
|
||||
.. _Transcode audio:
|
||||
http://beets.readthedocs.org/page/plugins/convert.html
|
||||
.. _Beatport: http://www.beatport.com/
|
||||
.. _Discogs: http://www.discogs.com/
|
||||
.. _acoustic fingerprints:
|
||||
http://beets.readthedocs.org/page/plugins/chroma.html
|
||||
.. _ReplayGain: http://beets.readthedocs.org/page/plugins/replaygain.html
|
||||
.. _tempos: http://beets.readthedocs.org/page/plugins/echonest.html
|
||||
.. _genres: http://beets.readthedocs.org/page/plugins/lastgenre.html
|
||||
.. _album art: http://beets.readthedocs.org/page/plugins/fetchart.html
|
||||
.. _lyrics: http://beets.readthedocs.org/page/plugins/lyrics.html
|
||||
.. _MusicBrainz: http://musicbrainz.org/
|
||||
|
||||
Read More
|
||||
---------
|
||||
|
||||
Learn more about beets at `its Web site`_. Follow `@b33ts`_ on Twitter for
|
||||
news and updates.
|
||||
|
||||
You can install beets by typing ``pip install beets``. Then check out the
|
||||
`Getting Started`_ guide.
|
||||
|
||||
.. _its Web site: http://beets.radbox.org/
|
||||
.. _Getting Started: http://beets.readthedocs.org/page/guides/main.html
|
||||
.. _@b33ts: http://twitter.com/b33ts/
|
||||
|
||||
Authors
|
||||
-------
|
||||
|
||||
Beets is by `Adrian Sampson`_ with a supporting cast of thousands. For help,
|
||||
please contact the `mailing list`_.
|
||||
|
||||
.. _mailing list: https://groups.google.com/forum/#!forum/beets-users
|
||||
.. _Adrian Sampson: http://homes.cs.washington.edu/~asampson/
|
||||
@@ -0,0 +1,28 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
# This particular version has been slightly modified to work with Headphones
|
||||
# https://github.com/rembo10/headphones
|
||||
|
||||
__version__ = '1.3.10-headphones'
|
||||
__author__ = 'Adrian Sampson <adrian@radbox.org>'
|
||||
|
||||
import os
|
||||
|
||||
import beets.library
|
||||
from beets.util import confit
|
||||
|
||||
Library = beets.library.Library
|
||||
|
||||
config = confit.LazyConfig(os.path.dirname(__file__), __name__)
|
||||
@@ -0,0 +1,136 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Facilities for automatically determining files' correct metadata.
|
||||
"""
|
||||
import logging
|
||||
|
||||
from beets import config
|
||||
|
||||
# Parts of external interface.
|
||||
from .hooks import AlbumInfo, TrackInfo, AlbumMatch, TrackMatch # noqa
|
||||
from .match import tag_item, tag_album # noqa
|
||||
from .match import Recommendation # noqa
|
||||
|
||||
# Global logger.
|
||||
log = logging.getLogger('beets')
|
||||
|
||||
|
||||
# Additional utilities for the main interface.
|
||||
|
||||
def apply_item_metadata(item, track_info):
|
||||
"""Set an item's metadata from its matched TrackInfo object.
|
||||
"""
|
||||
item.artist = track_info.artist
|
||||
item.artist_sort = track_info.artist_sort
|
||||
item.artist_credit = track_info.artist_credit
|
||||
item.title = track_info.title
|
||||
item.mb_trackid = track_info.track_id
|
||||
if track_info.artist_id:
|
||||
item.mb_artistid = track_info.artist_id
|
||||
# At the moment, the other metadata is left intact (including album
|
||||
# and track number). Perhaps these should be emptied?
|
||||
|
||||
|
||||
def apply_metadata(album_info, mapping):
|
||||
"""Set the items' metadata to match an AlbumInfo object using a
|
||||
mapping from Items to TrackInfo objects.
|
||||
"""
|
||||
for item, track_info in mapping.iteritems():
|
||||
# Album, artist, track count.
|
||||
if track_info.artist:
|
||||
item.artist = track_info.artist
|
||||
else:
|
||||
item.artist = album_info.artist
|
||||
item.albumartist = album_info.artist
|
||||
item.album = album_info.album
|
||||
|
||||
# Artist sort and credit names.
|
||||
item.artist_sort = track_info.artist_sort or album_info.artist_sort
|
||||
item.artist_credit = (track_info.artist_credit or
|
||||
album_info.artist_credit)
|
||||
item.albumartist_sort = album_info.artist_sort
|
||||
item.albumartist_credit = album_info.artist_credit
|
||||
|
||||
# Release date.
|
||||
for prefix in '', 'original_':
|
||||
if config['original_date'] and not prefix:
|
||||
# Ignore specific release date.
|
||||
continue
|
||||
|
||||
for suffix in 'year', 'month', 'day':
|
||||
key = prefix + suffix
|
||||
value = getattr(album_info, key) or 0
|
||||
|
||||
# If we don't even have a year, apply nothing.
|
||||
if suffix == 'year' and not value:
|
||||
break
|
||||
|
||||
# Otherwise, set the fetched value (or 0 for the month
|
||||
# and day if not available).
|
||||
item[key] = value
|
||||
|
||||
# If we're using original release date for both fields,
|
||||
# also set item.year = info.original_year, etc.
|
||||
if config['original_date']:
|
||||
item[suffix] = value
|
||||
|
||||
# Title.
|
||||
item.title = track_info.title
|
||||
|
||||
if config['per_disc_numbering']:
|
||||
item.track = track_info.medium_index or track_info.index
|
||||
item.tracktotal = track_info.medium_total or len(album_info.tracks)
|
||||
else:
|
||||
item.track = track_info.index
|
||||
item.tracktotal = len(album_info.tracks)
|
||||
|
||||
# Disc and disc count.
|
||||
item.disc = track_info.medium
|
||||
item.disctotal = album_info.mediums
|
||||
|
||||
# MusicBrainz IDs.
|
||||
item.mb_trackid = track_info.track_id
|
||||
item.mb_albumid = album_info.album_id
|
||||
if track_info.artist_id:
|
||||
item.mb_artistid = track_info.artist_id
|
||||
else:
|
||||
item.mb_artistid = album_info.artist_id
|
||||
item.mb_albumartistid = album_info.artist_id
|
||||
item.mb_releasegroupid = album_info.releasegroup_id
|
||||
|
||||
# Compilation flag.
|
||||
item.comp = album_info.va
|
||||
|
||||
# Miscellaneous metadata.
|
||||
for field in ('albumtype',
|
||||
'label',
|
||||
'asin',
|
||||
'catalognum',
|
||||
'script',
|
||||
'language',
|
||||
'country',
|
||||
'albumstatus',
|
||||
'albumdisambig'):
|
||||
value = getattr(album_info, field)
|
||||
if value is not None:
|
||||
item[field] = value
|
||||
if track_info.disctitle is not None:
|
||||
item.disctitle = track_info.disctitle
|
||||
|
||||
if track_info.media is not None:
|
||||
item.media = track_info.media
|
||||
|
||||
# Headphones seal of approval
|
||||
item.comments = 'tagged by headphones/beets'
|
||||
@@ -0,0 +1,579 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Glue between metadata sources and the matching logic."""
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
import re
|
||||
|
||||
from beets import plugins
|
||||
from beets import config
|
||||
from beets.autotag import mb
|
||||
from beets.util import levenshtein
|
||||
from unidecode import unidecode
|
||||
|
||||
log = logging.getLogger('beets')
|
||||
|
||||
|
||||
# Classes used to represent candidate options.
|
||||
|
||||
class AlbumInfo(object):
|
||||
"""Describes a canonical release that may be used to match a release
|
||||
in the library. Consists of these data members:
|
||||
|
||||
- ``album``: the release title
|
||||
- ``album_id``: MusicBrainz ID; UUID fragment only
|
||||
- ``artist``: name of the release's primary artist
|
||||
- ``artist_id``
|
||||
- ``tracks``: list of TrackInfo objects making up the release
|
||||
- ``asin``: Amazon ASIN
|
||||
- ``albumtype``: string describing the kind of release
|
||||
- ``va``: boolean: whether the release has "various artists"
|
||||
- ``year``: release year
|
||||
- ``month``: release month
|
||||
- ``day``: release day
|
||||
- ``label``: music label responsible for the release
|
||||
- ``mediums``: the number of discs in this release
|
||||
- ``artist_sort``: name of the release's artist for sorting
|
||||
- ``releasegroup_id``: MBID for the album's release group
|
||||
- ``catalognum``: the label's catalog number for the release
|
||||
- ``script``: character set used for metadata
|
||||
- ``language``: human language of the metadata
|
||||
- ``country``: the release country
|
||||
- ``albumstatus``: MusicBrainz release status (Official, etc.)
|
||||
- ``media``: delivery mechanism (Vinyl, etc.)
|
||||
- ``albumdisambig``: MusicBrainz release disambiguation comment
|
||||
- ``artist_credit``: Release-specific artist name
|
||||
- ``data_source``: The original data source (MusicBrainz, Discogs, etc.)
|
||||
- ``data_url``: The data source release URL.
|
||||
|
||||
The fields up through ``tracks`` are required. The others are
|
||||
optional and may be None.
|
||||
"""
|
||||
def __init__(self, album, album_id, artist, artist_id, tracks, asin=None,
|
||||
albumtype=None, va=False, year=None, month=None, day=None,
|
||||
label=None, mediums=None, artist_sort=None,
|
||||
releasegroup_id=None, catalognum=None, script=None,
|
||||
language=None, country=None, albumstatus=None, media=None,
|
||||
albumdisambig=None, artist_credit=None, original_year=None,
|
||||
original_month=None, original_day=None, data_source=None,
|
||||
data_url=None):
|
||||
self.album = album
|
||||
self.album_id = album_id
|
||||
self.artist = artist
|
||||
self.artist_id = artist_id
|
||||
self.tracks = tracks
|
||||
self.asin = asin
|
||||
self.albumtype = albumtype
|
||||
self.va = va
|
||||
self.year = year
|
||||
self.month = month
|
||||
self.day = day
|
||||
self.label = label
|
||||
self.mediums = mediums
|
||||
self.artist_sort = artist_sort
|
||||
self.releasegroup_id = releasegroup_id
|
||||
self.catalognum = catalognum
|
||||
self.script = script
|
||||
self.language = language
|
||||
self.country = country
|
||||
self.albumstatus = albumstatus
|
||||
self.media = media
|
||||
self.albumdisambig = albumdisambig
|
||||
self.artist_credit = artist_credit
|
||||
self.original_year = original_year
|
||||
self.original_month = original_month
|
||||
self.original_day = original_day
|
||||
self.data_source = data_source
|
||||
self.data_url = data_url
|
||||
|
||||
# Work around a bug in python-musicbrainz-ngs that causes some
|
||||
# strings to be bytes rather than Unicode.
|
||||
# https://github.com/alastair/python-musicbrainz-ngs/issues/85
|
||||
def decode(self, codec='utf8'):
|
||||
"""Ensure that all string attributes on this object, and the
|
||||
constituent `TrackInfo` objects, are decoded to Unicode.
|
||||
"""
|
||||
for fld in ['album', 'artist', 'albumtype', 'label', 'artist_sort',
|
||||
'catalognum', 'script', 'language', 'country',
|
||||
'albumstatus', 'albumdisambig', 'artist_credit', 'media']:
|
||||
value = getattr(self, fld)
|
||||
if isinstance(value, str):
|
||||
setattr(self, fld, value.decode(codec, 'ignore'))
|
||||
|
||||
if self.tracks:
|
||||
for track in self.tracks:
|
||||
track.decode(codec)
|
||||
|
||||
|
||||
class TrackInfo(object):
|
||||
"""Describes a canonical track present on a release. Appears as part
|
||||
of an AlbumInfo's ``tracks`` list. Consists of these data members:
|
||||
|
||||
- ``title``: name of the track
|
||||
- ``track_id``: MusicBrainz ID; UUID fragment only
|
||||
- ``artist``: individual track artist name
|
||||
- ``artist_id``
|
||||
- ``length``: float: duration of the track in seconds
|
||||
- ``index``: position on the entire release
|
||||
- ``media``: delivery mechanism (Vinyl, etc.)
|
||||
- ``medium``: the disc number this track appears on in the album
|
||||
- ``medium_index``: the track's position on the disc
|
||||
- ``medium_total``: the number of tracks on the item's disc
|
||||
- ``artist_sort``: name of the track artist for sorting
|
||||
- ``disctitle``: name of the individual medium (subtitle)
|
||||
- ``artist_credit``: Recording-specific artist name
|
||||
|
||||
Only ``title`` and ``track_id`` are required. The rest of the fields
|
||||
may be None. The indices ``index``, ``medium``, and ``medium_index``
|
||||
are all 1-based.
|
||||
"""
|
||||
def __init__(self, title, track_id, artist=None, artist_id=None,
|
||||
length=None, index=None, medium=None, medium_index=None,
|
||||
medium_total=None, artist_sort=None, disctitle=None,
|
||||
artist_credit=None, data_source=None, data_url=None,
|
||||
media=None):
|
||||
self.title = title
|
||||
self.track_id = track_id
|
||||
self.artist = artist
|
||||
self.artist_id = artist_id
|
||||
self.length = length
|
||||
self.index = index
|
||||
self.media = media
|
||||
self.medium = medium
|
||||
self.medium_index = medium_index
|
||||
self.medium_total = medium_total
|
||||
self.artist_sort = artist_sort
|
||||
self.disctitle = disctitle
|
||||
self.artist_credit = artist_credit
|
||||
self.data_source = data_source
|
||||
self.data_url = data_url
|
||||
|
||||
# As above, work around a bug in python-musicbrainz-ngs.
|
||||
def decode(self, codec='utf8'):
|
||||
"""Ensure that all string attributes on this object are decoded
|
||||
to Unicode.
|
||||
"""
|
||||
for fld in ['title', 'artist', 'medium', 'artist_sort', 'disctitle',
|
||||
'artist_credit', 'media']:
|
||||
value = getattr(self, fld)
|
||||
if isinstance(value, str):
|
||||
setattr(self, fld, value.decode(codec, 'ignore'))
|
||||
|
||||
|
||||
# Candidate distance scoring.
|
||||
|
||||
# Parameters for string distance function.
|
||||
# Words that can be moved to the end of a string using a comma.
|
||||
SD_END_WORDS = ['the', 'a', 'an']
|
||||
# Reduced weights for certain portions of the string.
|
||||
SD_PATTERNS = [
|
||||
(r'^the ', 0.1),
|
||||
(r'[\[\(]?(ep|single)[\]\)]?', 0.0),
|
||||
(r'[\[\(]?(featuring|feat|ft)[\. :].+', 0.1),
|
||||
(r'\(.*?\)', 0.3),
|
||||
(r'\[.*?\]', 0.3),
|
||||
(r'(, )?(pt\.|part) .+', 0.2),
|
||||
]
|
||||
# Replacements to use before testing distance.
|
||||
SD_REPLACE = [
|
||||
(r'&', 'and'),
|
||||
]
|
||||
|
||||
|
||||
def _string_dist_basic(str1, str2):
|
||||
"""Basic edit distance between two strings, ignoring
|
||||
non-alphanumeric characters and case. Comparisons are based on a
|
||||
transliteration/lowering to ASCII characters. Normalized by string
|
||||
length.
|
||||
"""
|
||||
str1 = unidecode(str1)
|
||||
str2 = unidecode(str2)
|
||||
str1 = re.sub(r'[^a-z0-9]', '', str1.lower())
|
||||
str2 = re.sub(r'[^a-z0-9]', '', str2.lower())
|
||||
if not str1 and not str2:
|
||||
return 0.0
|
||||
return levenshtein(str1, str2) / float(max(len(str1), len(str2)))
|
||||
|
||||
|
||||
def string_dist(str1, str2):
|
||||
"""Gives an "intuitive" edit distance between two strings. This is
|
||||
an edit distance, normalized by the string length, with a number of
|
||||
tweaks that reflect intuition about text.
|
||||
"""
|
||||
if str1 is None and str2 is None:
|
||||
return 0.0
|
||||
if str1 is None or str2 is None:
|
||||
return 1.0
|
||||
|
||||
str1 = str1.lower()
|
||||
str2 = str2.lower()
|
||||
|
||||
# Don't penalize strings that move certain words to the end. For
|
||||
# example, "the something" should be considered equal to
|
||||
# "something, the".
|
||||
for word in SD_END_WORDS:
|
||||
if str1.endswith(', %s' % word):
|
||||
str1 = '%s %s' % (word, str1[:-len(word) - 2])
|
||||
if str2.endswith(', %s' % word):
|
||||
str2 = '%s %s' % (word, str2[:-len(word) - 2])
|
||||
|
||||
# Perform a couple of basic normalizing substitutions.
|
||||
for pat, repl in SD_REPLACE:
|
||||
str1 = re.sub(pat, repl, str1)
|
||||
str2 = re.sub(pat, repl, str2)
|
||||
|
||||
# Change the weight for certain string portions matched by a set
|
||||
# of regular expressions. We gradually change the strings and build
|
||||
# up penalties associated with parts of the string that were
|
||||
# deleted.
|
||||
base_dist = _string_dist_basic(str1, str2)
|
||||
penalty = 0.0
|
||||
for pat, weight in SD_PATTERNS:
|
||||
# Get strings that drop the pattern.
|
||||
case_str1 = re.sub(pat, '', str1)
|
||||
case_str2 = re.sub(pat, '', str2)
|
||||
|
||||
if case_str1 != str1 or case_str2 != str2:
|
||||
# If the pattern was present (i.e., it is deleted in the
|
||||
# the current case), recalculate the distances for the
|
||||
# modified strings.
|
||||
case_dist = _string_dist_basic(case_str1, case_str2)
|
||||
case_delta = max(0.0, base_dist - case_dist)
|
||||
if case_delta == 0.0:
|
||||
continue
|
||||
|
||||
# Shift our baseline strings down (to avoid rematching the
|
||||
# same part of the string) and add a scaled distance
|
||||
# amount to the penalties.
|
||||
str1 = case_str1
|
||||
str2 = case_str2
|
||||
base_dist = case_dist
|
||||
penalty += weight * case_delta
|
||||
|
||||
return base_dist + penalty
|
||||
|
||||
|
||||
class LazyClassProperty(object):
|
||||
"""A decorator implementing a read-only property that is *lazy* in
|
||||
the sense that the getter is only invoked once. Subsequent accesses
|
||||
through *any* instance use the cached result.
|
||||
"""
|
||||
def __init__(self, getter):
|
||||
self.getter = getter
|
||||
self.computed = False
|
||||
|
||||
def __get__(self, obj, owner):
|
||||
if not self.computed:
|
||||
self.value = self.getter(owner)
|
||||
self.computed = True
|
||||
return self.value
|
||||
|
||||
|
||||
class Distance(object):
|
||||
"""Keeps track of multiple distance penalties. Provides a single
|
||||
weighted distance for all penalties as well as a weighted distance
|
||||
for each individual penalty.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._penalties = {}
|
||||
|
||||
@LazyClassProperty
|
||||
def _weights(cls):
|
||||
"""A dictionary from keys to floating-point weights.
|
||||
"""
|
||||
weights_view = config['match']['distance_weights']
|
||||
weights = {}
|
||||
for key in weights_view.keys():
|
||||
weights[key] = weights_view[key].as_number()
|
||||
return weights
|
||||
|
||||
# Access the components and their aggregates.
|
||||
|
||||
@property
|
||||
def distance(self):
|
||||
"""Return a weighted and normalized distance across all
|
||||
penalties.
|
||||
"""
|
||||
dist_max = self.max_distance
|
||||
if dist_max:
|
||||
return self.raw_distance / self.max_distance
|
||||
return 0.0
|
||||
|
||||
@property
|
||||
def max_distance(self):
|
||||
"""Return the maximum distance penalty (normalization factor).
|
||||
"""
|
||||
dist_max = 0.0
|
||||
for key, penalty in self._penalties.iteritems():
|
||||
dist_max += len(penalty) * self._weights[key]
|
||||
return dist_max
|
||||
|
||||
@property
|
||||
def raw_distance(self):
|
||||
"""Return the raw (denormalized) distance.
|
||||
"""
|
||||
dist_raw = 0.0
|
||||
for key, penalty in self._penalties.iteritems():
|
||||
dist_raw += sum(penalty) * self._weights[key]
|
||||
return dist_raw
|
||||
|
||||
def items(self):
|
||||
"""Return a list of (key, dist) pairs, with `dist` being the
|
||||
weighted distance, sorted from highest to lowest. Does not
|
||||
include penalties with a zero value.
|
||||
"""
|
||||
list_ = []
|
||||
for key in self._penalties:
|
||||
dist = self[key]
|
||||
if dist:
|
||||
list_.append((key, dist))
|
||||
# Convert distance into a negative float we can sort items in
|
||||
# ascending order (for keys, when the penalty is equal) and
|
||||
# still get the items with the biggest distance first.
|
||||
return sorted(list_, key=lambda (key, dist): (0 - dist, key))
|
||||
|
||||
# Behave like a float.
|
||||
|
||||
def __cmp__(self, other):
|
||||
return cmp(self.distance, other)
|
||||
|
||||
def __float__(self):
|
||||
return self.distance
|
||||
|
||||
def __sub__(self, other):
|
||||
return self.distance - other
|
||||
|
||||
def __rsub__(self, other):
|
||||
return other - self.distance
|
||||
|
||||
# Behave like a dict.
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Returns the weighted distance for a named penalty.
|
||||
"""
|
||||
dist = sum(self._penalties[key]) * self._weights[key]
|
||||
dist_max = self.max_distance
|
||||
if dist_max:
|
||||
return dist / dist_max
|
||||
return 0.0
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.items())
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items())
|
||||
|
||||
def keys(self):
|
||||
return [key for key, _ in self.items()]
|
||||
|
||||
def update(self, dist):
|
||||
"""Adds all the distance penalties from `dist`.
|
||||
"""
|
||||
if not isinstance(dist, Distance):
|
||||
raise ValueError(
|
||||
'`dist` must be a Distance object, not {0}'.format(type(dist))
|
||||
)
|
||||
for key, penalties in dist._penalties.iteritems():
|
||||
self._penalties.setdefault(key, []).extend(penalties)
|
||||
|
||||
# Adding components.
|
||||
|
||||
def _eq(self, value1, value2):
|
||||
"""Returns True if `value1` is equal to `value2`. `value1` may
|
||||
be a compiled regular expression, in which case it will be
|
||||
matched against `value2`.
|
||||
"""
|
||||
if isinstance(value1, re._pattern_type):
|
||||
return bool(value1.match(value2))
|
||||
return value1 == value2
|
||||
|
||||
def add(self, key, dist):
|
||||
"""Adds a distance penalty. `key` must correspond with a
|
||||
configured weight setting. `dist` must be a float between 0.0
|
||||
and 1.0, and will be added to any existing distance penalties
|
||||
for the same key.
|
||||
"""
|
||||
if not 0.0 <= dist <= 1.0:
|
||||
raise ValueError(
|
||||
'`dist` must be between 0.0 and 1.0, not {0}'.format(dist)
|
||||
)
|
||||
self._penalties.setdefault(key, []).append(dist)
|
||||
|
||||
def add_equality(self, key, value, options):
|
||||
"""Adds a distance penalty of 1.0 if `value` doesn't match any
|
||||
of the values in `options`. If an option is a compiled regular
|
||||
expression, it will be considered equal if it matches against
|
||||
`value`.
|
||||
"""
|
||||
if not isinstance(options, (list, tuple)):
|
||||
options = [options]
|
||||
for opt in options:
|
||||
if self._eq(opt, value):
|
||||
dist = 0.0
|
||||
break
|
||||
else:
|
||||
dist = 1.0
|
||||
self.add(key, dist)
|
||||
|
||||
def add_expr(self, key, expr):
|
||||
"""Adds a distance penalty of 1.0 if `expr` evaluates to True,
|
||||
or 0.0.
|
||||
"""
|
||||
if expr:
|
||||
self.add(key, 1.0)
|
||||
else:
|
||||
self.add(key, 0.0)
|
||||
|
||||
def add_number(self, key, number1, number2):
|
||||
"""Adds a distance penalty of 1.0 for each number of difference
|
||||
between `number1` and `number2`, or 0.0 when there is no
|
||||
difference. Use this when there is no upper limit on the
|
||||
difference between the two numbers.
|
||||
"""
|
||||
diff = abs(number1 - number2)
|
||||
if diff:
|
||||
for i in range(diff):
|
||||
self.add(key, 1.0)
|
||||
else:
|
||||
self.add(key, 0.0)
|
||||
|
||||
def add_priority(self, key, value, options):
|
||||
"""Adds a distance penalty that corresponds to the position at
|
||||
which `value` appears in `options`. A distance penalty of 0.0
|
||||
for the first option, or 1.0 if there is no matching option. If
|
||||
an option is a compiled regular expression, it will be
|
||||
considered equal if it matches against `value`.
|
||||
"""
|
||||
if not isinstance(options, (list, tuple)):
|
||||
options = [options]
|
||||
unit = 1.0 / (len(options) or 1)
|
||||
for i, opt in enumerate(options):
|
||||
if self._eq(opt, value):
|
||||
dist = i * unit
|
||||
break
|
||||
else:
|
||||
dist = 1.0
|
||||
self.add(key, dist)
|
||||
|
||||
def add_ratio(self, key, number1, number2):
|
||||
"""Adds a distance penalty for `number1` as a ratio of `number2`.
|
||||
`number1` is bound at 0 and `number2`.
|
||||
"""
|
||||
number = float(max(min(number1, number2), 0))
|
||||
if number2:
|
||||
dist = number / number2
|
||||
else:
|
||||
dist = 0.0
|
||||
self.add(key, dist)
|
||||
|
||||
def add_string(self, key, str1, str2):
|
||||
"""Adds a distance penalty based on the edit distance between
|
||||
`str1` and `str2`.
|
||||
"""
|
||||
dist = string_dist(str1, str2)
|
||||
self.add(key, dist)
|
||||
|
||||
|
||||
# Structures that compose all the information for a candidate match.
|
||||
|
||||
AlbumMatch = namedtuple('AlbumMatch', ['distance', 'info', 'mapping',
|
||||
'extra_items', 'extra_tracks'])
|
||||
|
||||
TrackMatch = namedtuple('TrackMatch', ['distance', 'info'])
|
||||
|
||||
|
||||
# Aggregation of sources.
|
||||
|
||||
def album_for_mbid(release_id):
|
||||
"""Get an AlbumInfo object for a MusicBrainz release ID. Return None
|
||||
if the ID is not found.
|
||||
"""
|
||||
try:
|
||||
return mb.album_for_id(release_id)
|
||||
except mb.MusicBrainzAPIError as exc:
|
||||
exc.log(log)
|
||||
|
||||
|
||||
def track_for_mbid(recording_id):
|
||||
"""Get a TrackInfo object for a MusicBrainz recording ID. Return None
|
||||
if the ID is not found.
|
||||
"""
|
||||
try:
|
||||
return mb.track_for_id(recording_id)
|
||||
except mb.MusicBrainzAPIError as exc:
|
||||
exc.log(log)
|
||||
|
||||
|
||||
def albums_for_id(album_id):
|
||||
"""Get a list of albums for an ID."""
|
||||
candidates = [album_for_mbid(album_id)]
|
||||
candidates.extend(plugins.album_for_id(album_id))
|
||||
return filter(None, candidates)
|
||||
|
||||
|
||||
def tracks_for_id(track_id):
|
||||
"""Get a list of tracks for an ID."""
|
||||
candidates = [track_for_mbid(track_id)]
|
||||
candidates.extend(plugins.track_for_id(track_id))
|
||||
return filter(None, candidates)
|
||||
|
||||
|
||||
def album_candidates(items, artist, album, va_likely):
|
||||
"""Search for album matches. ``items`` is a list of Item objects
|
||||
that make up the album. ``artist`` and ``album`` are the respective
|
||||
names (strings), which may be derived from the item list or may be
|
||||
entered by the user. ``va_likely`` is a boolean indicating whether
|
||||
the album is likely to be a "various artists" release.
|
||||
"""
|
||||
out = []
|
||||
|
||||
# Base candidates if we have album and artist to match.
|
||||
if artist and album:
|
||||
try:
|
||||
out.extend(mb.match_album(artist, album, len(items)))
|
||||
except mb.MusicBrainzAPIError as exc:
|
||||
exc.log(log)
|
||||
|
||||
# Also add VA matches from MusicBrainz where appropriate.
|
||||
if va_likely and album:
|
||||
try:
|
||||
out.extend(mb.match_album(None, album, len(items)))
|
||||
except mb.MusicBrainzAPIError as exc:
|
||||
exc.log(log)
|
||||
|
||||
# Candidates from plugins.
|
||||
out.extend(plugins.candidates(items, artist, album, va_likely))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def item_candidates(item, artist, title):
|
||||
"""Search for item matches. ``item`` is the Item to be matched.
|
||||
``artist`` and ``title`` are strings and either reflect the item or
|
||||
are specified by the user.
|
||||
"""
|
||||
out = []
|
||||
|
||||
# MusicBrainz candidates.
|
||||
if artist and title:
|
||||
try:
|
||||
out.extend(mb.match_track(artist, title))
|
||||
except mb.MusicBrainzAPIError as exc:
|
||||
exc.log(log)
|
||||
|
||||
# Plugin candidates.
|
||||
out.extend(plugins.item_candidates(item, artist, title))
|
||||
|
||||
return out
|
||||
@@ -0,0 +1,492 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Matches existing metadata with canonical information to identify
|
||||
releases and tracks.
|
||||
"""
|
||||
from __future__ import division
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import re
|
||||
from munkres import Munkres
|
||||
|
||||
from beets import plugins
|
||||
from beets import config
|
||||
from beets.util import plurality
|
||||
from beets.autotag import hooks
|
||||
from beets.util.enumeration import OrderedEnum
|
||||
|
||||
# Artist signals that indicate "various artists". These are used at the
|
||||
# album level to determine whether a given release is likely a VA
|
||||
# release and also on the track level to to remove the penalty for
|
||||
# differing artists.
|
||||
VA_ARTISTS = (u'', u'various artists', u'various', u'va', u'unknown')
|
||||
|
||||
# Global logger.
|
||||
log = logging.getLogger('beets')
|
||||
|
||||
|
||||
# Recommendation enumeration.
|
||||
|
||||
class Recommendation(OrderedEnum):
|
||||
"""Indicates a qualitative suggestion to the user about what should
|
||||
be done with a given match.
|
||||
"""
|
||||
none = 0
|
||||
low = 1
|
||||
medium = 2
|
||||
strong = 3
|
||||
|
||||
|
||||
# Primary matching functionality.
|
||||
|
||||
def current_metadata(items):
|
||||
"""Extract the likely current metadata for an album given a list of its
|
||||
items. Return two dictionaries:
|
||||
- The most common value for each field.
|
||||
- Whether each field's value was unanimous (values are booleans).
|
||||
"""
|
||||
assert items # Must be nonempty.
|
||||
|
||||
likelies = {}
|
||||
consensus = {}
|
||||
fields = ['artist', 'album', 'albumartist', 'year', 'disctotal',
|
||||
'mb_albumid', 'label', 'catalognum', 'country', 'media',
|
||||
'albumdisambig']
|
||||
for field in fields:
|
||||
values = [item[field] for item in items if item]
|
||||
likelies[field], freq = plurality(values)
|
||||
consensus[field] = (freq == len(values))
|
||||
|
||||
# If there's an album artist consensus, use this for the artist.
|
||||
if consensus['albumartist'] and likelies['albumartist']:
|
||||
likelies['artist'] = likelies['albumartist']
|
||||
|
||||
return likelies, consensus
|
||||
|
||||
|
||||
def assign_items(items, tracks):
|
||||
"""Given a list of Items and a list of TrackInfo objects, find the
|
||||
best mapping between them. Returns a mapping from Items to TrackInfo
|
||||
objects, a set of extra Items, and a set of extra TrackInfo
|
||||
objects. These "extra" objects occur when there is an unequal number
|
||||
of objects of the two types.
|
||||
"""
|
||||
# Construct the cost matrix.
|
||||
costs = []
|
||||
for item in items:
|
||||
row = []
|
||||
for i, track in enumerate(tracks):
|
||||
row.append(track_distance(item, track))
|
||||
costs.append(row)
|
||||
|
||||
# Find a minimum-cost bipartite matching.
|
||||
matching = Munkres().compute(costs)
|
||||
|
||||
# Produce the output matching.
|
||||
mapping = dict((items[i], tracks[j]) for (i, j) in matching)
|
||||
extra_items = list(set(items) - set(mapping.keys()))
|
||||
extra_items.sort(key=lambda i: (i.disc, i.track, i.title))
|
||||
extra_tracks = list(set(tracks) - set(mapping.values()))
|
||||
extra_tracks.sort(key=lambda t: (t.index, t.title))
|
||||
return mapping, extra_items, extra_tracks
|
||||
|
||||
|
||||
def track_index_changed(item, track_info):
|
||||
"""Returns True if the item and track info index is different. Tolerates
|
||||
per disc and per release numbering.
|
||||
"""
|
||||
return item.track not in (track_info.medium_index, track_info.index)
|
||||
|
||||
|
||||
def track_distance(item, track_info, incl_artist=False):
|
||||
"""Determines the significance of a track metadata change. Returns a
|
||||
Distance object. `incl_artist` indicates that a distance component should
|
||||
be included for the track artist (i.e., for various-artist releases).
|
||||
"""
|
||||
dist = hooks.Distance()
|
||||
|
||||
# Length.
|
||||
if track_info.length:
|
||||
diff = abs(item.length - track_info.length) - \
|
||||
config['match']['track_length_grace'].as_number()
|
||||
dist.add_ratio('track_length', diff,
|
||||
config['match']['track_length_max'].as_number())
|
||||
|
||||
# Title.
|
||||
dist.add_string('track_title', item.title, track_info.title)
|
||||
|
||||
# Artist. Only check if there is actually an artist in the track data.
|
||||
if incl_artist and track_info.artist and \
|
||||
item.artist.lower() not in VA_ARTISTS:
|
||||
dist.add_string('track_artist', item.artist, track_info.artist)
|
||||
|
||||
# Track index.
|
||||
if track_info.index and item.track:
|
||||
dist.add_expr('track_index', track_index_changed(item, track_info))
|
||||
|
||||
# Track ID.
|
||||
if item.mb_trackid:
|
||||
dist.add_expr('track_id', item.mb_trackid != track_info.track_id)
|
||||
|
||||
# Plugins.
|
||||
dist.update(plugins.track_distance(item, track_info))
|
||||
|
||||
return dist
|
||||
|
||||
|
||||
def distance(items, album_info, mapping):
|
||||
"""Determines how "significant" an album metadata change would be.
|
||||
Returns a Distance object. `album_info` is an AlbumInfo object
|
||||
reflecting the album to be compared. `items` is a sequence of all
|
||||
Item objects that will be matched (order is not important).
|
||||
`mapping` is a dictionary mapping Items to TrackInfo objects; the
|
||||
keys are a subset of `items` and the values are a subset of
|
||||
`album_info.tracks`.
|
||||
"""
|
||||
likelies, _ = current_metadata(items)
|
||||
|
||||
dist = hooks.Distance()
|
||||
|
||||
# Artist, if not various.
|
||||
if not album_info.va:
|
||||
dist.add_string('artist', likelies['artist'], album_info.artist)
|
||||
|
||||
# Album.
|
||||
dist.add_string('album', likelies['album'], album_info.album)
|
||||
|
||||
# Current or preferred media.
|
||||
if album_info.media:
|
||||
# Preferred media options.
|
||||
patterns = config['match']['preferred']['media'].as_str_seq()
|
||||
options = [re.compile(r'(\d+x)?(%s)' % pat, re.I) for pat in patterns]
|
||||
if options:
|
||||
dist.add_priority('media', album_info.media, options)
|
||||
# Current media.
|
||||
elif likelies['media']:
|
||||
dist.add_equality('media', album_info.media, likelies['media'])
|
||||
|
||||
# Mediums.
|
||||
if likelies['disctotal'] and album_info.mediums:
|
||||
dist.add_number('mediums', likelies['disctotal'], album_info.mediums)
|
||||
|
||||
# Prefer earliest release.
|
||||
if album_info.year and config['match']['preferred']['original_year']:
|
||||
# Assume 1889 (earliest first gramophone discs) if we don't know the
|
||||
# original year.
|
||||
original = album_info.original_year or 1889
|
||||
diff = abs(album_info.year - original)
|
||||
diff_max = abs(datetime.date.today().year - original)
|
||||
dist.add_ratio('year', diff, diff_max)
|
||||
# Year.
|
||||
elif likelies['year'] and album_info.year:
|
||||
if likelies['year'] in (album_info.year, album_info.original_year):
|
||||
# No penalty for matching release or original year.
|
||||
dist.add('year', 0.0)
|
||||
elif album_info.original_year:
|
||||
# Prefer matchest closest to the release year.
|
||||
diff = abs(likelies['year'] - album_info.year)
|
||||
diff_max = abs(datetime.date.today().year -
|
||||
album_info.original_year)
|
||||
dist.add_ratio('year', diff, diff_max)
|
||||
else:
|
||||
# Full penalty when there is no original year.
|
||||
dist.add('year', 1.0)
|
||||
|
||||
# Preferred countries.
|
||||
patterns = config['match']['preferred']['countries'].as_str_seq()
|
||||
options = [re.compile(pat, re.I) for pat in patterns]
|
||||
if album_info.country and options:
|
||||
dist.add_priority('country', album_info.country, options)
|
||||
# Country.
|
||||
elif likelies['country'] and album_info.country:
|
||||
dist.add_string('country', likelies['country'], album_info.country)
|
||||
|
||||
# Label.
|
||||
if likelies['label'] and album_info.label:
|
||||
dist.add_string('label', likelies['label'], album_info.label)
|
||||
|
||||
# Catalog number.
|
||||
if likelies['catalognum'] and album_info.catalognum:
|
||||
dist.add_string('catalognum', likelies['catalognum'],
|
||||
album_info.catalognum)
|
||||
|
||||
# Disambiguation.
|
||||
if likelies['albumdisambig'] and album_info.albumdisambig:
|
||||
dist.add_string('albumdisambig', likelies['albumdisambig'],
|
||||
album_info.albumdisambig)
|
||||
|
||||
# Album ID.
|
||||
if likelies['mb_albumid']:
|
||||
dist.add_equality('album_id', likelies['mb_albumid'],
|
||||
album_info.album_id)
|
||||
|
||||
# Tracks.
|
||||
dist.tracks = {}
|
||||
for item, track in mapping.iteritems():
|
||||
dist.tracks[track] = track_distance(item, track, album_info.va)
|
||||
dist.add('tracks', dist.tracks[track].distance)
|
||||
|
||||
# Missing tracks.
|
||||
for i in range(len(album_info.tracks) - len(mapping)):
|
||||
dist.add('missing_tracks', 1.0)
|
||||
|
||||
# Unmatched tracks.
|
||||
for i in range(len(items) - len(mapping)):
|
||||
dist.add('unmatched_tracks', 1.0)
|
||||
|
||||
# Plugins.
|
||||
dist.update(plugins.album_distance(items, album_info, mapping))
|
||||
|
||||
return dist
|
||||
|
||||
|
||||
def match_by_id(items):
|
||||
"""If the items are tagged with a MusicBrainz album ID, returns an
|
||||
AlbumInfo object for the corresponding album. Otherwise, returns
|
||||
None.
|
||||
"""
|
||||
# Is there a consensus on the MB album ID?
|
||||
albumids = [item.mb_albumid for item in items if item.mb_albumid]
|
||||
if not albumids:
|
||||
log.debug(u'No album IDs found.')
|
||||
return None
|
||||
|
||||
# If all album IDs are equal, look up the album.
|
||||
if bool(reduce(lambda x, y: x if x == y else (), albumids)):
|
||||
albumid = albumids[0]
|
||||
log.debug(u'Searching for discovered album ID: {0}'.format(albumid))
|
||||
return hooks.album_for_mbid(albumid)
|
||||
else:
|
||||
log.debug(u'No album ID consensus.')
|
||||
|
||||
|
||||
def _recommendation(results):
|
||||
"""Given a sorted list of AlbumMatch or TrackMatch objects, return a
|
||||
recommendation based on the results' distances.
|
||||
|
||||
If the recommendation is higher than the configured maximum for
|
||||
an applied penalty, the recommendation will be downgraded to the
|
||||
configured maximum for that penalty.
|
||||
"""
|
||||
if not results:
|
||||
# No candidates: no recommendation.
|
||||
return Recommendation.none
|
||||
|
||||
# Basic distance thresholding.
|
||||
min_dist = results[0].distance
|
||||
if min_dist < config['match']['strong_rec_thresh'].as_number():
|
||||
# Strong recommendation level.
|
||||
rec = Recommendation.strong
|
||||
elif min_dist <= config['match']['medium_rec_thresh'].as_number():
|
||||
# Medium recommendation level.
|
||||
rec = Recommendation.medium
|
||||
elif len(results) == 1:
|
||||
# Only a single candidate.
|
||||
rec = Recommendation.low
|
||||
elif results[1].distance - min_dist >= \
|
||||
config['match']['rec_gap_thresh'].as_number():
|
||||
# Gap between first two candidates is large.
|
||||
rec = Recommendation.low
|
||||
else:
|
||||
# No conclusion. Return immediately. Can't be downgraded any further.
|
||||
return Recommendation.none
|
||||
|
||||
# Downgrade to the max rec if it is lower than the current rec for an
|
||||
# applied penalty.
|
||||
keys = set(min_dist.keys())
|
||||
if isinstance(results[0], hooks.AlbumMatch):
|
||||
for track_dist in min_dist.tracks.values():
|
||||
keys.update(track_dist.keys())
|
||||
max_rec_view = config['match']['max_rec']
|
||||
for key in keys:
|
||||
if key in max_rec_view.keys():
|
||||
max_rec = max_rec_view[key].as_choice({
|
||||
'strong': Recommendation.strong,
|
||||
'medium': Recommendation.medium,
|
||||
'low': Recommendation.low,
|
||||
'none': Recommendation.none,
|
||||
})
|
||||
rec = min(rec, max_rec)
|
||||
|
||||
return rec
|
||||
|
||||
|
||||
def _add_candidate(items, results, info):
|
||||
"""Given a candidate AlbumInfo object, attempt to add the candidate
|
||||
to the output dictionary of AlbumMatch objects. This involves
|
||||
checking the track count, ordering the items, checking for
|
||||
duplicates, and calculating the distance.
|
||||
"""
|
||||
log.debug(u'Candidate: {0} - {1}'.format(info.artist, info.album))
|
||||
|
||||
# Discard albums with zero tracks.
|
||||
if not info.tracks:
|
||||
log.debug('No tracks.')
|
||||
return
|
||||
|
||||
# Don't duplicate.
|
||||
if info.album_id in results:
|
||||
log.debug(u'Duplicate.')
|
||||
return
|
||||
|
||||
# Discard matches without required tags.
|
||||
for req_tag in config['match']['required'].as_str_seq():
|
||||
if getattr(info, req_tag) is None:
|
||||
log.debug(u'Ignored. Missing required tag: {0}'.format(req_tag))
|
||||
return
|
||||
|
||||
# Find mapping between the items and the track info.
|
||||
mapping, extra_items, extra_tracks = assign_items(items, info.tracks)
|
||||
|
||||
# Get the change distance.
|
||||
dist = distance(items, info, mapping)
|
||||
|
||||
# Skip matches with ignored penalties.
|
||||
penalties = [key for _, key in dist]
|
||||
for penalty in config['match']['ignored'].as_str_seq():
|
||||
if penalty in penalties:
|
||||
log.debug(u'Ignored. Penalty: {0}'.format(penalty))
|
||||
return
|
||||
|
||||
log.debug(u'Success. Distance: {0}'.format(dist))
|
||||
results[info.album_id] = hooks.AlbumMatch(dist, info, mapping,
|
||||
extra_items, extra_tracks)
|
||||
|
||||
|
||||
def tag_album(items, search_artist=None, search_album=None,
|
||||
search_id=None):
|
||||
"""Return a tuple of a artist name, an album name, a list of
|
||||
`AlbumMatch` candidates from the metadata backend, and a
|
||||
`Recommendation`.
|
||||
|
||||
The artist and album are the most common values of these fields
|
||||
among `items`.
|
||||
|
||||
The `AlbumMatch` objects are generated by searching the metadata
|
||||
backends. By default, the metadata of the items is used for the
|
||||
search. This can be customized by setting the parameters. The
|
||||
`mapping` field of the album has the matched `items` as keys.
|
||||
|
||||
The recommendation is calculated from the match qualitiy of the
|
||||
candidates.
|
||||
"""
|
||||
# Get current metadata.
|
||||
likelies, consensus = current_metadata(items)
|
||||
cur_artist = likelies['artist']
|
||||
cur_album = likelies['album']
|
||||
log.debug(u'Tagging {0} - {1}'.format(cur_artist, cur_album))
|
||||
|
||||
# The output result (distance, AlbumInfo) tuples (keyed by MB album
|
||||
# ID).
|
||||
candidates = {}
|
||||
|
||||
# Search by explicit ID.
|
||||
if search_id is not None:
|
||||
log.debug(u'Searching for album ID: {0}'.format(search_id))
|
||||
search_cands = hooks.albums_for_id(search_id)
|
||||
|
||||
# Use existing metadata or text search.
|
||||
else:
|
||||
# Try search based on current ID.
|
||||
id_info = match_by_id(items)
|
||||
if id_info:
|
||||
_add_candidate(items, candidates, id_info)
|
||||
rec = _recommendation(candidates.values())
|
||||
log.debug(u'Album ID match recommendation is {0}'.format(str(rec)))
|
||||
if candidates and not config['import']['timid']:
|
||||
# If we have a very good MBID match, return immediately.
|
||||
# Otherwise, this match will compete against metadata-based
|
||||
# matches.
|
||||
if rec == Recommendation.strong:
|
||||
log.debug(u'ID match.')
|
||||
return cur_artist, cur_album, candidates.values(), rec
|
||||
|
||||
# Search terms.
|
||||
if not (search_artist and search_album):
|
||||
# No explicit search terms -- use current metadata.
|
||||
search_artist, search_album = cur_artist, cur_album
|
||||
log.debug(u'Search terms: {0} - {1}'.format(search_artist,
|
||||
search_album))
|
||||
|
||||
# Is this album likely to be a "various artist" release?
|
||||
va_likely = ((not consensus['artist']) or
|
||||
(search_artist.lower() in VA_ARTISTS) or
|
||||
any(item.comp for item in items))
|
||||
log.debug(u'Album might be VA: {0}'.format(str(va_likely)))
|
||||
|
||||
# Get the results from the data sources.
|
||||
search_cands = hooks.album_candidates(items, search_artist,
|
||||
search_album, va_likely)
|
||||
|
||||
log.debug(u'Evaluating {0} candidates.'.format(len(search_cands)))
|
||||
for info in search_cands:
|
||||
_add_candidate(items, candidates, info)
|
||||
|
||||
# Sort and get the recommendation.
|
||||
candidates = sorted(candidates.itervalues())
|
||||
rec = _recommendation(candidates)
|
||||
return cur_artist, cur_album, candidates, rec
|
||||
|
||||
|
||||
def tag_item(item, search_artist=None, search_title=None,
|
||||
search_id=None):
|
||||
"""Attempts to find metadata for a single track. Returns a
|
||||
`(candidates, recommendation)` pair where `candidates` is a list of
|
||||
TrackMatch objects. `search_artist` and `search_title` may be used
|
||||
to override the current metadata for the purposes of the MusicBrainz
|
||||
title; likewise `search_id`.
|
||||
"""
|
||||
# Holds candidates found so far: keys are MBIDs; values are
|
||||
# (distance, TrackInfo) pairs.
|
||||
candidates = {}
|
||||
|
||||
# First, try matching by MusicBrainz ID.
|
||||
trackid = search_id or item.mb_trackid
|
||||
if trackid:
|
||||
log.debug(u'Searching for track ID: {0}'.format(trackid))
|
||||
for track_info in hooks.tracks_for_id(trackid):
|
||||
dist = track_distance(item, track_info, incl_artist=True)
|
||||
candidates[track_info.track_id] = \
|
||||
hooks.TrackMatch(dist, track_info)
|
||||
# If this is a good match, then don't keep searching.
|
||||
rec = _recommendation(candidates.values())
|
||||
if rec == Recommendation.strong and not config['import']['timid']:
|
||||
log.debug(u'Track ID match.')
|
||||
return candidates.values(), rec
|
||||
|
||||
# If we're searching by ID, don't proceed.
|
||||
if search_id is not None:
|
||||
if candidates:
|
||||
return candidates.values(), rec
|
||||
else:
|
||||
return [], Recommendation.none
|
||||
|
||||
# Search terms.
|
||||
if not (search_artist and search_title):
|
||||
search_artist, search_title = item.artist, item.title
|
||||
log.debug(u'Item search terms: {0} - {1}'.format(search_artist,
|
||||
search_title))
|
||||
|
||||
# Get and evaluate candidate metadata.
|
||||
for track_info in hooks.item_candidates(item, search_artist, search_title):
|
||||
dist = track_distance(item, track_info, incl_artist=True)
|
||||
candidates[track_info.track_id] = hooks.TrackMatch(dist, track_info)
|
||||
|
||||
# Sort by distance and return with recommendation.
|
||||
log.debug(u'Found {0} candidates.'.format(len(candidates)))
|
||||
candidates = sorted(candidates.itervalues())
|
||||
rec = _recommendation(candidates)
|
||||
return candidates, rec
|
||||
@@ -0,0 +1,407 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Searches for albums in the MusicBrainz database.
|
||||
"""
|
||||
import logging
|
||||
import musicbrainzngs
|
||||
import re
|
||||
import traceback
|
||||
from urlparse import urljoin
|
||||
|
||||
import beets.autotag.hooks
|
||||
import beets
|
||||
from beets import util
|
||||
from beets import config
|
||||
|
||||
SEARCH_LIMIT = 5
|
||||
VARIOUS_ARTISTS_ID = '89ad4ac3-39f7-470e-963a-56509c546377'
|
||||
BASE_URL = 'http://musicbrainz.org/'
|
||||
|
||||
musicbrainzngs.set_useragent('beets', beets.__version__,
|
||||
'http://beets.radbox.org/')
|
||||
|
||||
|
||||
class MusicBrainzAPIError(util.HumanReadableException):
|
||||
"""An error while talking to MusicBrainz. The `query` field is the
|
||||
parameter to the action and may have any type.
|
||||
"""
|
||||
def __init__(self, reason, verb, query, tb=None):
|
||||
self.query = query
|
||||
super(MusicBrainzAPIError, self).__init__(reason, verb, tb)
|
||||
|
||||
def get_message(self):
|
||||
return u'{0} in {1} with query {2}'.format(
|
||||
self._reasonstr(), self.verb, repr(self.query)
|
||||
)
|
||||
|
||||
log = logging.getLogger('beets')
|
||||
|
||||
RELEASE_INCLUDES = ['artists', 'media', 'recordings', 'release-groups',
|
||||
'labels', 'artist-credits', 'aliases']
|
||||
TRACK_INCLUDES = ['artists', 'aliases']
|
||||
|
||||
|
||||
def track_url(trackid):
|
||||
return urljoin(BASE_URL, 'recording/' + trackid)
|
||||
|
||||
|
||||
def album_url(albumid):
|
||||
return urljoin(BASE_URL, 'release/' + albumid)
|
||||
|
||||
|
||||
def configure():
|
||||
"""Set up the python-musicbrainz-ngs module according to settings
|
||||
from the beets configuration. This should be called at startup.
|
||||
"""
|
||||
musicbrainzngs.set_hostname(config['musicbrainz']['host'].get(unicode))
|
||||
musicbrainzngs.set_rate_limit(
|
||||
config['musicbrainz']['ratelimit_interval'].as_number(),
|
||||
config['musicbrainz']['ratelimit'].get(int),
|
||||
)
|
||||
|
||||
|
||||
def _preferred_alias(aliases):
|
||||
"""Given an list of alias structures for an artist credit, select
|
||||
and return the user's preferred alias alias or None if no matching
|
||||
alias is found.
|
||||
"""
|
||||
if not aliases:
|
||||
return
|
||||
|
||||
# Only consider aliases that have locales set.
|
||||
aliases = [a for a in aliases if 'locale' in a]
|
||||
|
||||
# Search configured locales in order.
|
||||
for locale in config['import']['languages'].as_str_seq():
|
||||
# Find matching primary aliases for this locale.
|
||||
matches = [a for a in aliases
|
||||
if a['locale'] == locale and 'primary' in a]
|
||||
# Skip to the next locale if we have no matches
|
||||
if not matches:
|
||||
continue
|
||||
|
||||
return matches[0]
|
||||
|
||||
|
||||
def _flatten_artist_credit(credit):
|
||||
"""Given a list representing an ``artist-credit`` block, flatten the
|
||||
data into a triple of joined artist name strings: canonical, sort, and
|
||||
credit.
|
||||
"""
|
||||
artist_parts = []
|
||||
artist_sort_parts = []
|
||||
artist_credit_parts = []
|
||||
for el in credit:
|
||||
if isinstance(el, basestring):
|
||||
# Join phrase.
|
||||
artist_parts.append(el)
|
||||
artist_credit_parts.append(el)
|
||||
artist_sort_parts.append(el)
|
||||
|
||||
else:
|
||||
alias = _preferred_alias(el['artist'].get('alias-list', ()))
|
||||
|
||||
# An artist.
|
||||
if alias:
|
||||
cur_artist_name = alias['alias']
|
||||
else:
|
||||
cur_artist_name = el['artist']['name']
|
||||
artist_parts.append(cur_artist_name)
|
||||
|
||||
# Artist sort name.
|
||||
if alias:
|
||||
artist_sort_parts.append(alias['sort-name'])
|
||||
elif 'sort-name' in el['artist']:
|
||||
artist_sort_parts.append(el['artist']['sort-name'])
|
||||
else:
|
||||
artist_sort_parts.append(cur_artist_name)
|
||||
|
||||
# Artist credit.
|
||||
if 'name' in el:
|
||||
artist_credit_parts.append(el['name'])
|
||||
else:
|
||||
artist_credit_parts.append(cur_artist_name)
|
||||
|
||||
return (
|
||||
''.join(artist_parts),
|
||||
''.join(artist_sort_parts),
|
||||
''.join(artist_credit_parts),
|
||||
)
|
||||
|
||||
|
||||
def track_info(recording, index=None, medium=None, medium_index=None,
|
||||
medium_total=None):
|
||||
"""Translates a MusicBrainz recording result dictionary into a beets
|
||||
``TrackInfo`` object. Three parameters are optional and are used
|
||||
only for tracks that appear on releases (non-singletons): ``index``,
|
||||
the overall track number; ``medium``, the disc number;
|
||||
``medium_index``, the track's index on its medium; ``medium_total``,
|
||||
the number of tracks on the medium. Each number is a 1-based index.
|
||||
"""
|
||||
info = beets.autotag.hooks.TrackInfo(
|
||||
recording['title'],
|
||||
recording['id'],
|
||||
index=index,
|
||||
medium=medium,
|
||||
medium_index=medium_index,
|
||||
medium_total=medium_total,
|
||||
data_url=track_url(recording['id']),
|
||||
)
|
||||
|
||||
if recording.get('artist-credit'):
|
||||
# Get the artist names.
|
||||
info.artist, info.artist_sort, info.artist_credit = \
|
||||
_flatten_artist_credit(recording['artist-credit'])
|
||||
|
||||
# Get the ID and sort name of the first artist.
|
||||
artist = recording['artist-credit'][0]['artist']
|
||||
info.artist_id = artist['id']
|
||||
|
||||
if recording.get('length'):
|
||||
info.length = int(recording['length']) / (1000.0)
|
||||
|
||||
info.decode()
|
||||
return info
|
||||
|
||||
|
||||
def _set_date_str(info, date_str, original=False):
|
||||
"""Given a (possibly partial) YYYY-MM-DD string and an AlbumInfo
|
||||
object, set the object's release date fields appropriately. If
|
||||
`original`, then set the original_year, etc., fields.
|
||||
"""
|
||||
if date_str:
|
||||
date_parts = date_str.split('-')
|
||||
for key in ('year', 'month', 'day'):
|
||||
if date_parts:
|
||||
date_part = date_parts.pop(0)
|
||||
try:
|
||||
date_num = int(date_part)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if original:
|
||||
key = 'original_' + key
|
||||
setattr(info, key, date_num)
|
||||
|
||||
|
||||
def album_info(release):
|
||||
"""Takes a MusicBrainz release result dictionary and returns a beets
|
||||
AlbumInfo object containing the interesting data about that release.
|
||||
"""
|
||||
# Get artist name using join phrases.
|
||||
artist_name, artist_sort_name, artist_credit_name = \
|
||||
_flatten_artist_credit(release['artist-credit'])
|
||||
|
||||
# Basic info.
|
||||
track_infos = []
|
||||
index = 0
|
||||
for medium in release['medium-list']:
|
||||
disctitle = medium.get('title')
|
||||
format = medium.get('format')
|
||||
for track in medium['track-list']:
|
||||
# Basic information from the recording.
|
||||
index += 1
|
||||
ti = track_info(
|
||||
track['recording'],
|
||||
index,
|
||||
int(medium['position']),
|
||||
int(track['position']),
|
||||
len(medium['track-list']),
|
||||
)
|
||||
ti.disctitle = disctitle
|
||||
ti.media = format
|
||||
|
||||
# Prefer track data, where present, over recording data.
|
||||
if track.get('title'):
|
||||
ti.title = track['title']
|
||||
if track.get('artist-credit'):
|
||||
# Get the artist names.
|
||||
ti.artist, ti.artist_sort, ti.artist_credit = \
|
||||
_flatten_artist_credit(track['artist-credit'])
|
||||
ti.artist_id = track['artist-credit'][0]['artist']['id']
|
||||
if track.get('length'):
|
||||
ti.length = int(track['length']) / (1000.0)
|
||||
|
||||
track_infos.append(ti)
|
||||
|
||||
info = beets.autotag.hooks.AlbumInfo(
|
||||
release['title'],
|
||||
release['id'],
|
||||
artist_name,
|
||||
release['artist-credit'][0]['artist']['id'],
|
||||
track_infos,
|
||||
mediums=len(release['medium-list']),
|
||||
artist_sort=artist_sort_name,
|
||||
artist_credit=artist_credit_name,
|
||||
data_source='MusicBrainz',
|
||||
data_url=album_url(release['id']),
|
||||
)
|
||||
info.va = info.artist_id == VARIOUS_ARTISTS_ID
|
||||
info.asin = release.get('asin')
|
||||
info.releasegroup_id = release['release-group']['id']
|
||||
info.country = release.get('country')
|
||||
info.albumstatus = release.get('status')
|
||||
|
||||
# Build up the disambiguation string from the release group and release.
|
||||
disambig = []
|
||||
if release['release-group'].get('disambiguation'):
|
||||
disambig.append(release['release-group'].get('disambiguation'))
|
||||
if release.get('disambiguation'):
|
||||
disambig.append(release.get('disambiguation'))
|
||||
info.albumdisambig = u', '.join(disambig)
|
||||
|
||||
# Release type not always populated.
|
||||
if 'type' in release['release-group']:
|
||||
reltype = release['release-group']['type']
|
||||
if reltype:
|
||||
info.albumtype = reltype.lower()
|
||||
|
||||
# Release dates.
|
||||
release_date = release.get('date')
|
||||
release_group_date = release['release-group'].get('first-release-date')
|
||||
if not release_date:
|
||||
# Fall back if release-specific date is not available.
|
||||
release_date = release_group_date
|
||||
_set_date_str(info, release_date, False)
|
||||
_set_date_str(info, release_group_date, True)
|
||||
|
||||
# Label name.
|
||||
if release.get('label-info-list'):
|
||||
label_info = release['label-info-list'][0]
|
||||
if label_info.get('label'):
|
||||
label = label_info['label']['name']
|
||||
if label != '[no label]':
|
||||
info.label = label
|
||||
info.catalognum = label_info.get('catalog-number')
|
||||
|
||||
# Text representation data.
|
||||
if release.get('text-representation'):
|
||||
rep = release['text-representation']
|
||||
info.script = rep.get('script')
|
||||
info.language = rep.get('language')
|
||||
|
||||
# Media (format).
|
||||
if release['medium-list']:
|
||||
first_medium = release['medium-list'][0]
|
||||
info.media = first_medium.get('format')
|
||||
|
||||
info.decode()
|
||||
return info
|
||||
|
||||
|
||||
def match_album(artist, album, tracks=None, limit=SEARCH_LIMIT):
|
||||
"""Searches for a single album ("release" in MusicBrainz parlance)
|
||||
and returns an iterator over AlbumInfo objects. May raise a
|
||||
MusicBrainzAPIError.
|
||||
|
||||
The query consists of an artist name, an album name, and,
|
||||
optionally, a number of tracks on the album.
|
||||
"""
|
||||
# Build search criteria.
|
||||
criteria = {'release': album.lower().strip()}
|
||||
if artist is not None:
|
||||
criteria['artist'] = artist.lower().strip()
|
||||
else:
|
||||
# Various Artists search.
|
||||
criteria['arid'] = VARIOUS_ARTISTS_ID
|
||||
if tracks is not None:
|
||||
criteria['tracks'] = str(tracks)
|
||||
|
||||
# Abort if we have no search terms.
|
||||
if not any(criteria.itervalues()):
|
||||
return
|
||||
|
||||
try:
|
||||
res = musicbrainzngs.search_releases(limit=limit, **criteria)
|
||||
except musicbrainzngs.MusicBrainzError as exc:
|
||||
raise MusicBrainzAPIError(exc, 'release search', criteria,
|
||||
traceback.format_exc())
|
||||
for release in res['release-list']:
|
||||
# The search result is missing some data (namely, the tracks),
|
||||
# so we just use the ID and fetch the rest of the information.
|
||||
albuminfo = album_for_id(release['id'])
|
||||
if albuminfo is not None:
|
||||
yield albuminfo
|
||||
|
||||
|
||||
def match_track(artist, title, limit=SEARCH_LIMIT):
|
||||
"""Searches for a single track and returns an iterable of TrackInfo
|
||||
objects. May raise a MusicBrainzAPIError.
|
||||
"""
|
||||
criteria = {
|
||||
'artist': artist.lower().strip(),
|
||||
'recording': title.lower().strip(),
|
||||
}
|
||||
|
||||
if not any(criteria.itervalues()):
|
||||
return
|
||||
|
||||
try:
|
||||
res = musicbrainzngs.search_recordings(limit=limit, **criteria)
|
||||
except musicbrainzngs.MusicBrainzError as exc:
|
||||
raise MusicBrainzAPIError(exc, 'recording search', criteria,
|
||||
traceback.format_exc())
|
||||
for recording in res['recording-list']:
|
||||
yield track_info(recording)
|
||||
|
||||
|
||||
def _parse_id(s):
|
||||
"""Search for a MusicBrainz ID in the given string and return it. If
|
||||
no ID can be found, return None.
|
||||
"""
|
||||
# Find the first thing that looks like a UUID/MBID.
|
||||
match = re.search('[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}', s)
|
||||
if match:
|
||||
return match.group()
|
||||
|
||||
|
||||
def album_for_id(releaseid):
|
||||
"""Fetches an album by its MusicBrainz ID and returns an AlbumInfo
|
||||
object or None if the album is not found. May raise a
|
||||
MusicBrainzAPIError.
|
||||
"""
|
||||
albumid = _parse_id(releaseid)
|
||||
if not albumid:
|
||||
log.debug(u'Invalid MBID ({0}).'.format(releaseid))
|
||||
return
|
||||
try:
|
||||
res = musicbrainzngs.get_release_by_id(albumid,
|
||||
RELEASE_INCLUDES)
|
||||
except musicbrainzngs.ResponseError:
|
||||
log.debug(u'Album ID match failed.')
|
||||
return None
|
||||
except musicbrainzngs.MusicBrainzError as exc:
|
||||
raise MusicBrainzAPIError(exc, 'get release by ID', albumid,
|
||||
traceback.format_exc())
|
||||
return album_info(res['release'])
|
||||
|
||||
|
||||
def track_for_id(releaseid):
|
||||
"""Fetches a track by its MusicBrainz ID. Returns a TrackInfo object
|
||||
or None if no track is found. May raise a MusicBrainzAPIError.
|
||||
"""
|
||||
trackid = _parse_id(releaseid)
|
||||
if not trackid:
|
||||
log.debug(u'Invalid MBID ({0}).'.format(releaseid))
|
||||
return
|
||||
try:
|
||||
res = musicbrainzngs.get_recording_by_id(trackid, TRACK_INCLUDES)
|
||||
except musicbrainzngs.ResponseError:
|
||||
log.debug(u'Track ID match failed.')
|
||||
return None
|
||||
except musicbrainzngs.MusicBrainzError as exc:
|
||||
raise MusicBrainzAPIError(exc, 'get recording by ID', trackid,
|
||||
traceback.format_exc())
|
||||
return track_info(res['recording'])
|
||||
@@ -0,0 +1,109 @@
|
||||
library: library.db
|
||||
directory: ~/Music
|
||||
|
||||
import:
|
||||
write: yes
|
||||
copy: yes
|
||||
move: no
|
||||
link: no
|
||||
delete: no
|
||||
resume: ask
|
||||
incremental: no
|
||||
quiet_fallback: skip
|
||||
none_rec_action: ask
|
||||
timid: no
|
||||
log:
|
||||
autotag: yes
|
||||
quiet: no
|
||||
singletons: no
|
||||
default_action: apply
|
||||
languages: []
|
||||
detail: no
|
||||
flat: no
|
||||
group_albums: no
|
||||
pretend: false
|
||||
|
||||
clutter: ["Thumbs.DB", ".DS_Store"]
|
||||
ignore: [".*", "*~", "System Volume Information"]
|
||||
replace:
|
||||
'[\\/]': _
|
||||
'^\.': _
|
||||
'[\x00-\x1f]': _
|
||||
'[<>:"\?\*\|]': _
|
||||
'\.$': _
|
||||
'\s+$': ''
|
||||
'^\s+': ''
|
||||
path_sep_replace: _
|
||||
asciify_paths: false
|
||||
art_filename: cover
|
||||
max_filename_length: 0
|
||||
|
||||
plugins: []
|
||||
pluginpath: []
|
||||
threaded: yes
|
||||
color: yes
|
||||
timeout: 5.0
|
||||
per_disc_numbering: no
|
||||
verbose: no
|
||||
terminal_encoding: utf8
|
||||
original_date: no
|
||||
id3v23: no
|
||||
|
||||
ui:
|
||||
terminal_width: 80
|
||||
length_diff_thresh: 10.0
|
||||
|
||||
list_format_item: $artist - $album - $title
|
||||
list_format_album: $albumartist - $album
|
||||
time_format: '%Y-%m-%d %H:%M:%S'
|
||||
|
||||
sort_album: albumartist+ album+
|
||||
sort_item: artist+ album+ disc+ track+
|
||||
|
||||
paths:
|
||||
default: $albumartist/$album%aunique{}/$track $title
|
||||
singleton: Non-Album/$artist/$title
|
||||
comp: Compilations/$album%aunique{}/$track $title
|
||||
|
||||
statefile: state.pickle
|
||||
|
||||
musicbrainz:
|
||||
host: musicbrainz.org
|
||||
ratelimit: 1
|
||||
ratelimit_interval: 1.0
|
||||
|
||||
match:
|
||||
strong_rec_thresh: 0.04
|
||||
medium_rec_thresh: 0.25
|
||||
rec_gap_thresh: 0.25
|
||||
max_rec:
|
||||
missing_tracks: medium
|
||||
unmatched_tracks: medium
|
||||
distance_weights:
|
||||
source: 2.0
|
||||
artist: 3.0
|
||||
album: 3.0
|
||||
media: 1.0
|
||||
mediums: 1.0
|
||||
year: 1.0
|
||||
country: 0.5
|
||||
label: 0.5
|
||||
catalognum: 0.5
|
||||
albumdisambig: 0.5
|
||||
album_id: 5.0
|
||||
tracks: 2.0
|
||||
missing_tracks: 0.9
|
||||
unmatched_tracks: 0.6
|
||||
track_title: 3.0
|
||||
track_artist: 2.0
|
||||
track_index: 1.0
|
||||
track_length: 2.0
|
||||
track_id: 5.0
|
||||
preferred:
|
||||
countries: []
|
||||
media: []
|
||||
original_year: no
|
||||
ignored: []
|
||||
required: []
|
||||
track_length_grace: 10
|
||||
track_length_max: 30
|
||||
@@ -0,0 +1,25 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""DBCore is an abstract database package that forms the basis for beets'
|
||||
Library.
|
||||
"""
|
||||
from .db import Model, Database
|
||||
from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery
|
||||
from .types import Type
|
||||
from .queryparse import query_from_strings
|
||||
from .queryparse import sort_from_strings
|
||||
from .queryparse import parse_sorted_query
|
||||
|
||||
# flake8: noqa
|
||||
@@ -0,0 +1,820 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""The central Model and Database constructs for DBCore.
|
||||
"""
|
||||
import time
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import threading
|
||||
import sqlite3
|
||||
import contextlib
|
||||
import collections
|
||||
|
||||
import beets
|
||||
from beets.util.functemplate import Template
|
||||
from beets.dbcore import types
|
||||
from .query import MatchQuery, NullSort, TrueQuery
|
||||
|
||||
|
||||
class FormattedMapping(collections.Mapping):
|
||||
"""A `dict`-like formatted view of a model.
|
||||
|
||||
The accessor `mapping[key]` returns the formated version of
|
||||
`model[key]` as a unicode string.
|
||||
|
||||
If `for_path` is true, all path separators in the formatted values
|
||||
are replaced.
|
||||
"""
|
||||
|
||||
def __init__(self, model, for_path=False):
|
||||
self.for_path = for_path
|
||||
self.model = model
|
||||
self.model_keys = model.keys(True)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self.model_keys:
|
||||
return self._get_formatted(self.model, key)
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.model_keys)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.model_keys)
|
||||
|
||||
def get(self, key, default=None):
|
||||
if default is None:
|
||||
default = self.model._type(key).format(None)
|
||||
return super(FormattedMapping, self).get(key, default)
|
||||
|
||||
def _get_formatted(self, model, key):
|
||||
value = model._type(key).format(model.get(key))
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf8', 'ignore')
|
||||
|
||||
if self.for_path:
|
||||
sep_repl = beets.config['path_sep_replace'].get(unicode)
|
||||
for sep in (os.path.sep, os.path.altsep):
|
||||
if sep:
|
||||
value = value.replace(sep, sep_repl)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
# Abstract base for model classes.
|
||||
|
||||
class Model(object):
|
||||
"""An abstract object representing an object in the database. Model
|
||||
objects act like dictionaries (i.e., the allow subscript access like
|
||||
``obj['field']``). The same field set is available via attribute
|
||||
access as a shortcut (i.e., ``obj.field``). Three kinds of attributes are
|
||||
available:
|
||||
|
||||
* **Fixed attributes** come from a predetermined list of field
|
||||
names. These fields correspond to SQLite table columns and are
|
||||
thus fast to read, write, and query.
|
||||
* **Flexible attributes** are free-form and do not need to be listed
|
||||
ahead of time.
|
||||
* **Computed attributes** are read-only fields computed by a getter
|
||||
function provided by a plugin.
|
||||
|
||||
Access to all three field types is uniform: ``obj.field`` works the
|
||||
same regardless of whether ``field`` is fixed, flexible, or
|
||||
computed.
|
||||
|
||||
Model objects can optionally be associated with a `Library` object,
|
||||
in which case they can be loaded and stored from the database. Dirty
|
||||
flags are used to track which fields need to be stored.
|
||||
"""
|
||||
|
||||
# Abstract components (to be provided by subclasses).
|
||||
|
||||
_table = None
|
||||
"""The main SQLite table name.
|
||||
"""
|
||||
|
||||
_flex_table = None
|
||||
"""The flex field SQLite table name.
|
||||
"""
|
||||
|
||||
_fields = {}
|
||||
"""A mapping indicating available "fixed" fields on this type. The
|
||||
keys are field names and the values are `Type` objects.
|
||||
"""
|
||||
|
||||
_search_fields = ()
|
||||
"""The fields that should be queried by default by unqualified query
|
||||
terms.
|
||||
"""
|
||||
|
||||
_types = {}
|
||||
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
|
||||
"""
|
||||
|
||||
_sorts = {}
|
||||
"""Optional named sort criteria. The keys are strings and the values
|
||||
are subclasses of `Sort`.
|
||||
"""
|
||||
|
||||
_always_dirty = False
|
||||
"""By default, fields only become "dirty" when their value actually
|
||||
changes. Enabling this flag marks fields as dirty even when the new
|
||||
value is the same as the old value (e.g., `o.f = o.f`).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _getters(cls):
|
||||
"""Return a mapping from field names to getter functions.
|
||||
"""
|
||||
# We could cache this if it becomes a performance problem to
|
||||
# gather the getter mapping every time.
|
||||
raise NotImplementedError()
|
||||
|
||||
def _template_funcs(self):
|
||||
"""Return a mapping from function names to text-transformer
|
||||
functions.
|
||||
"""
|
||||
# As above: we could consider caching this result.
|
||||
raise NotImplementedError()
|
||||
|
||||
# Basic operation.
|
||||
|
||||
def __init__(self, db=None, **values):
|
||||
"""Create a new object with an optional Database association and
|
||||
initial field values.
|
||||
"""
|
||||
self._db = db
|
||||
self._dirty = set()
|
||||
self._values_fixed = {}
|
||||
self._values_flex = {}
|
||||
|
||||
# Initial contents.
|
||||
self.update(values)
|
||||
self.clear_dirty()
|
||||
|
||||
@classmethod
|
||||
def _awaken(cls, db=None, fixed_values={}, flex_values={}):
|
||||
"""Create an object with values drawn from the database.
|
||||
|
||||
This is a performance optimization: the checks involved with
|
||||
ordinary construction are bypassed.
|
||||
"""
|
||||
obj = cls(db)
|
||||
for key, value in fixed_values.iteritems():
|
||||
obj._values_fixed[key] = cls._type(key).from_sql(value)
|
||||
for key, value in flex_values.iteritems():
|
||||
obj._values_flex[key] = cls._type(key).from_sql(value)
|
||||
return obj
|
||||
|
||||
def __repr__(self):
|
||||
return '{0}({1})'.format(
|
||||
type(self).__name__,
|
||||
', '.join('{0}={1!r}'.format(k, v) for k, v in dict(self).items()),
|
||||
)
|
||||
|
||||
def clear_dirty(self):
|
||||
"""Mark all fields as *clean* (i.e., not needing to be stored to
|
||||
the database).
|
||||
"""
|
||||
self._dirty = set()
|
||||
|
||||
def _check_db(self, need_id=True):
|
||||
"""Ensure that this object is associated with a database row: it
|
||||
has a reference to a database (`_db`) and an id. A ValueError
|
||||
exception is raised otherwise.
|
||||
"""
|
||||
if not self._db:
|
||||
raise ValueError('{0} has no database'.format(type(self).__name__))
|
||||
if need_id and not self.id:
|
||||
raise ValueError('{0} has no id'.format(type(self).__name__))
|
||||
|
||||
# Essential field accessors.
|
||||
|
||||
@classmethod
|
||||
def _type(self, key):
|
||||
"""Get the type of a field, a `Type` instance.
|
||||
|
||||
If the field has no explicit type, it is given the base `Type`,
|
||||
which does no conversion.
|
||||
"""
|
||||
return self._fields.get(key) or self._types.get(key) or types.DEFAULT
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get the value for a field. Raise a KeyError if the field is
|
||||
not available.
|
||||
"""
|
||||
getters = self._getters()
|
||||
if key in getters: # Computed.
|
||||
return getters[key](self)
|
||||
elif key in self._fields: # Fixed.
|
||||
return self._values_fixed.get(key)
|
||||
elif key in self._values_flex: # Flexible.
|
||||
return self._values_flex[key]
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"""Assign the value for a field.
|
||||
"""
|
||||
# Choose where to place the value.
|
||||
if key in self._fields:
|
||||
source = self._values_fixed
|
||||
else:
|
||||
source = self._values_flex
|
||||
|
||||
# If the field has a type, filter the value.
|
||||
value = self._type(key).normalize(value)
|
||||
|
||||
# Assign value and possibly mark as dirty.
|
||||
old_value = source.get(key)
|
||||
source[key] = value
|
||||
if self._always_dirty or old_value != value:
|
||||
self._dirty.add(key)
|
||||
|
||||
def __delitem__(self, key):
|
||||
"""Remove a flexible attribute from the model.
|
||||
"""
|
||||
if key in self._values_flex: # Flexible.
|
||||
del self._values_flex[key]
|
||||
self._dirty.add(key) # Mark for dropping on store.
|
||||
elif key in self._getters(): # Computed.
|
||||
raise KeyError('computed field {0} cannot be deleted'.format(key))
|
||||
elif key in self._fields: # Fixed.
|
||||
raise KeyError('fixed field {0} cannot be deleted'.format(key))
|
||||
else:
|
||||
raise KeyError('no such field {0}'.format(key))
|
||||
|
||||
def keys(self, computed=False):
|
||||
"""Get a list of available field names for this object. The
|
||||
`computed` parameter controls whether computed (plugin-provided)
|
||||
fields are included in the key list.
|
||||
"""
|
||||
base_keys = list(self._fields) + self._values_flex.keys()
|
||||
if computed:
|
||||
return base_keys + self._getters().keys()
|
||||
else:
|
||||
return base_keys
|
||||
|
||||
# Act like a dictionary.
|
||||
|
||||
def update(self, values):
|
||||
"""Assign all values in the given dict.
|
||||
"""
|
||||
for key, value in values.items():
|
||||
self[key] = value
|
||||
|
||||
def items(self):
|
||||
"""Iterate over (key, value) pairs that this object contains.
|
||||
Computed fields are not included.
|
||||
"""
|
||||
for key in self:
|
||||
yield key, self[key]
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Get the value for a given key or `default` if it does not
|
||||
exist.
|
||||
"""
|
||||
if key in self:
|
||||
return self[key]
|
||||
else:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
"""Determine whether `key` is an attribute on this object.
|
||||
"""
|
||||
return key in self.keys(True)
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over the available field names (excluding computed
|
||||
fields).
|
||||
"""
|
||||
return iter(self.keys())
|
||||
|
||||
# Convenient attribute access.
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key.startswith('_'):
|
||||
raise AttributeError('model has no attribute {0!r}'.format(key))
|
||||
else:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError('no such field {0!r}'.format(key))
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key.startswith('_'):
|
||||
super(Model, self).__setattr__(key, value)
|
||||
else:
|
||||
self[key] = value
|
||||
|
||||
def __delattr__(self, key):
|
||||
if key.startswith('_'):
|
||||
super(Model, self).__delattr__(key)
|
||||
else:
|
||||
del self[key]
|
||||
|
||||
# Database interaction (CRUD methods).
|
||||
|
||||
def store(self):
|
||||
"""Save the object's metadata into the library database.
|
||||
"""
|
||||
self._check_db()
|
||||
|
||||
# Build assignments for query.
|
||||
assignments = []
|
||||
subvars = []
|
||||
for key in self._fields:
|
||||
if key != 'id' and key in self._dirty:
|
||||
self._dirty.remove(key)
|
||||
assignments.append(key + '=?')
|
||||
value = self._type(key).to_sql(self[key])
|
||||
subvars.append(value)
|
||||
assignments = ','.join(assignments)
|
||||
|
||||
with self._db.transaction() as tx:
|
||||
# Main table update.
|
||||
if assignments:
|
||||
query = 'UPDATE {0} SET {1} WHERE id=?'.format(
|
||||
self._table, assignments
|
||||
)
|
||||
subvars.append(self.id)
|
||||
tx.mutate(query, subvars)
|
||||
|
||||
# Modified/added flexible attributes.
|
||||
for key, value in self._values_flex.items():
|
||||
if key in self._dirty:
|
||||
self._dirty.remove(key)
|
||||
tx.mutate(
|
||||
'INSERT INTO {0} '
|
||||
'(entity_id, key, value) '
|
||||
'VALUES (?, ?, ?);'.format(self._flex_table),
|
||||
(self.id, key, value),
|
||||
)
|
||||
|
||||
# Deleted flexible attributes.
|
||||
for key in self._dirty:
|
||||
tx.mutate(
|
||||
'DELETE FROM {0} '
|
||||
'WHERE entity_id=? AND key=?'.format(self._flex_table),
|
||||
(self.id, key)
|
||||
)
|
||||
|
||||
self.clear_dirty()
|
||||
|
||||
def load(self):
|
||||
"""Refresh the object's metadata from the library database.
|
||||
"""
|
||||
self._check_db()
|
||||
stored_obj = self._db._get(type(self), self.id)
|
||||
assert stored_obj is not None, "object {0} not in DB".format(self.id)
|
||||
self._values_fixed = {}
|
||||
self._values_flex = {}
|
||||
self.update(dict(stored_obj))
|
||||
self.clear_dirty()
|
||||
|
||||
def remove(self):
|
||||
"""Remove the object's associated rows from the database.
|
||||
"""
|
||||
self._check_db()
|
||||
with self._db.transaction() as tx:
|
||||
tx.mutate(
|
||||
'DELETE FROM {0} WHERE id=?'.format(self._table),
|
||||
(self.id,)
|
||||
)
|
||||
tx.mutate(
|
||||
'DELETE FROM {0} WHERE entity_id=?'.format(self._flex_table),
|
||||
(self.id,)
|
||||
)
|
||||
|
||||
def add(self, db=None):
|
||||
"""Add the object to the library database. This object must be
|
||||
associated with a database; you can provide one via the `db`
|
||||
parameter or use the currently associated database.
|
||||
|
||||
The object's `id` and `added` fields are set along with any
|
||||
current field values.
|
||||
"""
|
||||
if db:
|
||||
self._db = db
|
||||
self._check_db(False)
|
||||
|
||||
with self._db.transaction() as tx:
|
||||
new_id = tx.mutate(
|
||||
'INSERT INTO {0} DEFAULT VALUES'.format(self._table)
|
||||
)
|
||||
self.id = new_id
|
||||
self.added = time.time()
|
||||
|
||||
# Mark every non-null field as dirty and store.
|
||||
for key in self:
|
||||
if self[key] is not None:
|
||||
self._dirty.add(key)
|
||||
self.store()
|
||||
|
||||
# Formatting and templating.
|
||||
|
||||
_formatter = FormattedMapping
|
||||
|
||||
def formatted(self, for_path=False):
|
||||
"""Get a mapping containing all values on this object formatted
|
||||
as human-readable unicode strings.
|
||||
"""
|
||||
return self._formatter(self, for_path)
|
||||
|
||||
def evaluate_template(self, template, for_path=False):
|
||||
"""Evaluate a template (a string or a `Template` object) using
|
||||
the object's fields. If `for_path` is true, then no new path
|
||||
separators will be added to the template.
|
||||
"""
|
||||
# Perform substitution.
|
||||
if isinstance(template, basestring):
|
||||
template = Template(template)
|
||||
return template.substitute(self.formatted(for_path),
|
||||
self._template_funcs())
|
||||
|
||||
# Parsing.
|
||||
|
||||
@classmethod
|
||||
def _parse(cls, key, string):
|
||||
"""Parse a string as a value for the given key.
|
||||
"""
|
||||
if not isinstance(string, basestring):
|
||||
raise TypeError("_parse() argument must be a string")
|
||||
|
||||
return cls._type(key).parse(string)
|
||||
|
||||
|
||||
# Database controller and supporting interfaces.
|
||||
|
||||
class Results(object):
|
||||
"""An item query result set. Iterating over the collection lazily
|
||||
constructs LibModel objects that reflect database rows.
|
||||
"""
|
||||
def __init__(self, model_class, rows, db, query=None, sort=None):
|
||||
"""Create a result set that will construct objects of type
|
||||
`model_class`.
|
||||
|
||||
`model_class` is a subclass of `LibModel` that will be
|
||||
constructed. `rows` is a query result: a list of mappings. The
|
||||
new objects will be associated with the database `db`.
|
||||
|
||||
If `query` is provided, it is used as a predicate to filter the
|
||||
results for a "slow query" that cannot be evaluated by the
|
||||
database directly. If `sort` is provided, it is used to sort the
|
||||
full list of results before returning. This means it is a "slow
|
||||
sort" and all objects must be built before returning the first
|
||||
one.
|
||||
"""
|
||||
self.model_class = model_class
|
||||
self.rows = rows
|
||||
self.db = db
|
||||
self.query = query
|
||||
self.sort = sort
|
||||
|
||||
# We keep a queue of rows we haven't yet consumed for
|
||||
# materialization. We preserve the original total number of
|
||||
# rows.
|
||||
self._rows = rows
|
||||
self._row_count = len(rows)
|
||||
|
||||
# The materialized objects corresponding to rows that have been
|
||||
# consumed.
|
||||
self._objects = []
|
||||
|
||||
def _get_objects(self):
|
||||
"""Construct and generate Model objects for they query. The
|
||||
objects are returned in the order emitted from the database; no
|
||||
slow sort is applied.
|
||||
|
||||
For performance, this generator caches materialized objects to
|
||||
avoid constructing them more than once. This way, iterating over
|
||||
a `Results` object a second time should be much faster than the
|
||||
first.
|
||||
"""
|
||||
index = 0 # Position in the materialized objects.
|
||||
while index < len(self._objects) or self._rows:
|
||||
# Are there previously-materialized objects to produce?
|
||||
if index < len(self._objects):
|
||||
yield self._objects[index]
|
||||
index += 1
|
||||
|
||||
# Otherwise, we consume another row, materialize its object
|
||||
# and produce it.
|
||||
else:
|
||||
while self._rows:
|
||||
row = self._rows.pop(0)
|
||||
obj = self._make_model(row)
|
||||
# If there is a slow-query predicate, ensurer that the
|
||||
# object passes it.
|
||||
if not self.query or self.query.match(obj):
|
||||
self._objects.append(obj)
|
||||
index += 1
|
||||
yield obj
|
||||
break
|
||||
|
||||
def __iter__(self):
|
||||
"""Construct and generate Model objects for all matching
|
||||
objects, in sorted order.
|
||||
"""
|
||||
if self.sort:
|
||||
# Slow sort. Must build the full list first.
|
||||
objects = self.sort.sort(list(self._get_objects()))
|
||||
return iter(objects)
|
||||
|
||||
else:
|
||||
# Objects are pre-sorted (i.e., by the database).
|
||||
return self._get_objects()
|
||||
|
||||
def _make_model(self, row):
|
||||
# Get the flexible attributes for the object.
|
||||
with self.db.transaction() as tx:
|
||||
flex_rows = tx.query(
|
||||
'SELECT * FROM {0} WHERE entity_id=?'.format(
|
||||
self.model_class._flex_table
|
||||
),
|
||||
(row['id'],)
|
||||
)
|
||||
|
||||
cols = dict(row)
|
||||
values = dict((k, v) for (k, v) in cols.items()
|
||||
if not k[:4] == 'flex')
|
||||
flex_values = dict((row['key'], row['value']) for row in flex_rows)
|
||||
|
||||
# Construct the Python object
|
||||
obj = self.model_class._awaken(self.db, values, flex_values)
|
||||
return obj
|
||||
|
||||
def __len__(self):
|
||||
"""Get the number of matching objects.
|
||||
"""
|
||||
if not self._rows:
|
||||
# Fully materialized. Just count the objects.
|
||||
return len(self._objects)
|
||||
|
||||
elif self.query:
|
||||
# A slow query. Fall back to testing every object.
|
||||
count = 0
|
||||
for obj in self:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
else:
|
||||
# A fast query. Just count the rows.
|
||||
return self._row_count
|
||||
|
||||
def __nonzero__(self):
|
||||
"""Does this result contain any objects?
|
||||
"""
|
||||
return bool(len(self))
|
||||
|
||||
def __getitem__(self, n):
|
||||
"""Get the nth item in this result set. This is inefficient: all
|
||||
items up to n are materialized and thrown away.
|
||||
"""
|
||||
if not self._rows and not self.sort:
|
||||
# Fully materialized and already in order. Just look up the
|
||||
# object.
|
||||
return self._objects[n]
|
||||
|
||||
it = iter(self)
|
||||
try:
|
||||
for i in range(n):
|
||||
it.next()
|
||||
return it.next()
|
||||
except StopIteration:
|
||||
raise IndexError('result index {0} out of range'.format(n))
|
||||
|
||||
def get(self):
|
||||
"""Return the first matching object, or None if no objects
|
||||
match.
|
||||
"""
|
||||
it = iter(self)
|
||||
try:
|
||||
return it.next()
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
"""A context manager for safe, concurrent access to the database.
|
||||
All SQL commands should be executed through a transaction.
|
||||
"""
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
|
||||
def __enter__(self):
|
||||
"""Begin a transaction. This transaction may be created while
|
||||
another is active in a different thread.
|
||||
"""
|
||||
with self.db._tx_stack() as stack:
|
||||
first = not stack
|
||||
stack.append(self)
|
||||
if first:
|
||||
# Beginning a "root" transaction, which corresponds to an
|
||||
# SQLite transaction.
|
||||
self.db._db_lock.acquire()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""Complete a transaction. This must be the most recently
|
||||
entered but not yet exited transaction. If it is the last active
|
||||
transaction, the database updates are committed.
|
||||
"""
|
||||
with self.db._tx_stack() as stack:
|
||||
assert stack.pop() is self
|
||||
empty = not stack
|
||||
if empty:
|
||||
# Ending a "root" transaction. End the SQLite transaction.
|
||||
self.db._connection().commit()
|
||||
self.db._db_lock.release()
|
||||
|
||||
def query(self, statement, subvals=()):
|
||||
"""Execute an SQL statement with substitution values and return
|
||||
a list of rows from the database.
|
||||
"""
|
||||
cursor = self.db._connection().execute(statement, subvals)
|
||||
return cursor.fetchall()
|
||||
|
||||
def mutate(self, statement, subvals=()):
|
||||
"""Execute an SQL statement with substitution values and return
|
||||
the row ID of the last affected row.
|
||||
"""
|
||||
cursor = self.db._connection().execute(statement, subvals)
|
||||
return cursor.lastrowid
|
||||
|
||||
def script(self, statements):
|
||||
"""Execute a string containing multiple SQL statements."""
|
||||
self.db._connection().executescript(statements)
|
||||
|
||||
|
||||
class Database(object):
|
||||
"""A container for Model objects that wraps an SQLite database as
|
||||
the backend.
|
||||
"""
|
||||
_models = ()
|
||||
"""The Model subclasses representing tables in this database.
|
||||
"""
|
||||
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
|
||||
self._connections = {}
|
||||
self._tx_stacks = defaultdict(list)
|
||||
|
||||
# A lock to protect the _connections and _tx_stacks maps, which
|
||||
# both map thread IDs to private resources.
|
||||
self._shared_map_lock = threading.Lock()
|
||||
|
||||
# A lock to protect access to the database itself. SQLite does
|
||||
# allow multiple threads to access the database at the same
|
||||
# time, but many users were experiencing crashes related to this
|
||||
# capability: where SQLite was compiled without HAVE_USLEEP, its
|
||||
# backoff algorithm in the case of contention was causing
|
||||
# whole-second sleeps (!) that would trigger its internal
|
||||
# timeout. Using this lock ensures only one SQLite transaction
|
||||
# is active at a time.
|
||||
self._db_lock = threading.Lock()
|
||||
|
||||
# Set up database schema.
|
||||
for model_cls in self._models:
|
||||
self._make_table(model_cls._table, model_cls._fields)
|
||||
self._make_attribute_table(model_cls._flex_table)
|
||||
|
||||
# Primitive access control: connections and transactions.
|
||||
|
||||
def _connection(self):
|
||||
"""Get a SQLite connection object to the underlying database.
|
||||
One connection object is created per thread.
|
||||
"""
|
||||
thread_id = threading.current_thread().ident
|
||||
with self._shared_map_lock:
|
||||
if thread_id in self._connections:
|
||||
return self._connections[thread_id]
|
||||
else:
|
||||
# Make a new connection.
|
||||
conn = sqlite3.connect(
|
||||
self.path,
|
||||
timeout=beets.config['timeout'].as_number(),
|
||||
)
|
||||
|
||||
# Access SELECT results like dictionaries.
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
self._connections[thread_id] = conn
|
||||
return conn
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _tx_stack(self):
|
||||
"""A context manager providing access to the current thread's
|
||||
transaction stack. The context manager synchronizes access to
|
||||
the stack map. Transactions should never migrate across threads.
|
||||
"""
|
||||
thread_id = threading.current_thread().ident
|
||||
with self._shared_map_lock:
|
||||
yield self._tx_stacks[thread_id]
|
||||
|
||||
def transaction(self):
|
||||
"""Get a :class:`Transaction` object for interacting directly
|
||||
with the underlying SQLite database.
|
||||
"""
|
||||
return Transaction(self)
|
||||
|
||||
# Schema setup and migration.
|
||||
|
||||
def _make_table(self, table, fields):
|
||||
"""Set up the schema of the database. `fields` is a mapping
|
||||
from field names to `Type`s. Columns are added if necessary.
|
||||
"""
|
||||
# Get current schema.
|
||||
with self.transaction() as tx:
|
||||
rows = tx.query('PRAGMA table_info(%s)' % table)
|
||||
current_fields = set([row[1] for row in rows])
|
||||
|
||||
field_names = set(fields.keys())
|
||||
if current_fields.issuperset(field_names):
|
||||
# Table exists and has all the required columns.
|
||||
return
|
||||
|
||||
if not current_fields:
|
||||
# No table exists.
|
||||
columns = []
|
||||
for name, typ in fields.items():
|
||||
columns.append('{0} {1}'.format(name, typ.sql))
|
||||
setup_sql = 'CREATE TABLE {0} ({1});\n'.format(table,
|
||||
', '.join(columns))
|
||||
|
||||
else:
|
||||
# Table exists does not match the field set.
|
||||
setup_sql = ''
|
||||
for name, typ in fields.items():
|
||||
if name in current_fields:
|
||||
continue
|
||||
setup_sql += 'ALTER TABLE {0} ADD COLUMN {1} {2};\n'.format(
|
||||
table, name, typ.sql
|
||||
)
|
||||
|
||||
with self.transaction() as tx:
|
||||
tx.script(setup_sql)
|
||||
|
||||
def _make_attribute_table(self, flex_table):
|
||||
"""Create a table and associated index for flexible attributes
|
||||
for the given entity (if they don't exist).
|
||||
"""
|
||||
with self.transaction() as tx:
|
||||
tx.script("""
|
||||
CREATE TABLE IF NOT EXISTS {0} (
|
||||
id INTEGER PRIMARY KEY,
|
||||
entity_id INTEGER,
|
||||
key TEXT,
|
||||
value TEXT,
|
||||
UNIQUE(entity_id, key) ON CONFLICT REPLACE);
|
||||
CREATE INDEX IF NOT EXISTS {0}_by_entity
|
||||
ON {0} (entity_id);
|
||||
""".format(flex_table))
|
||||
|
||||
# Querying.
|
||||
|
||||
def _fetch(self, model_cls, query=None, sort=None):
|
||||
"""Fetch the objects of type `model_cls` matching the given
|
||||
query. The query may be given as a string, string sequence, a
|
||||
Query object, or None (to fetch everything). `sort` is an
|
||||
`Sort` object.
|
||||
"""
|
||||
query = query or TrueQuery() # A null query.
|
||||
sort = sort or NullSort() # Unsorted.
|
||||
where, subvals = query.clause()
|
||||
order_by = sort.order_clause()
|
||||
|
||||
sql = ("SELECT * FROM {0} WHERE {1} {2}").format(
|
||||
model_cls._table,
|
||||
where or '1',
|
||||
"ORDER BY {0}".format(order_by) if order_by else '',
|
||||
)
|
||||
|
||||
with self.transaction() as tx:
|
||||
rows = tx.query(sql, subvals)
|
||||
|
||||
return Results(
|
||||
model_cls, rows, self,
|
||||
None if where else query, # Slow query component.
|
||||
sort if sort.is_slow() else None, # Slow sort component.
|
||||
)
|
||||
|
||||
def _get(self, model_cls, id):
|
||||
"""Get a Model object by its id or None if the id does not
|
||||
exist.
|
||||
"""
|
||||
return self._fetch(model_cls, MatchQuery('id', id)).get()
|
||||
@@ -0,0 +1,654 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""The Query type hierarchy for DBCore.
|
||||
"""
|
||||
import re
|
||||
from operator import attrgetter
|
||||
from beets import util
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class Query(object):
|
||||
"""An abstract class representing a query into the item database.
|
||||
"""
|
||||
def clause(self):
|
||||
"""Generate an SQLite expression implementing the query.
|
||||
Return a clause string, a sequence of substitution values for
|
||||
the clause, and a Query object representing the "remainder"
|
||||
Returns (clause, subvals) where clause is a valid sqlite
|
||||
WHERE clause implementing the query and subvals is a list of
|
||||
items to be substituted for ?s in the clause.
|
||||
"""
|
||||
return None, ()
|
||||
|
||||
def match(self, item):
|
||||
"""Check whether this query matches a given Item. Can be used to
|
||||
perform queries on arbitrary sets of Items.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FieldQuery(Query):
|
||||
"""An abstract query that searches in a specific field for a
|
||||
pattern. Subclasses must provide a `value_match` class method, which
|
||||
determines whether a certain pattern string matches a certain value
|
||||
string. Subclasses may also provide `col_clause` to implement the
|
||||
same matching functionality in SQLite.
|
||||
"""
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
self.field = field
|
||||
self.pattern = pattern
|
||||
self.fast = fast
|
||||
|
||||
def col_clause(self):
|
||||
return None, ()
|
||||
|
||||
def clause(self):
|
||||
if self.fast:
|
||||
return self.col_clause()
|
||||
else:
|
||||
# Matching a flexattr. This is a slow query.
|
||||
return None, ()
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern, value):
|
||||
"""Determine whether the value matches the pattern. Both
|
||||
arguments are strings.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def match(self, item):
|
||||
return self.value_match(self.pattern, item.get(self.field))
|
||||
|
||||
|
||||
class MatchQuery(FieldQuery):
|
||||
"""A query that looks for exact matches in an item field."""
|
||||
def col_clause(self):
|
||||
return self.field + " = ?", [self.pattern]
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern, value):
|
||||
return pattern == value
|
||||
|
||||
|
||||
class NoneQuery(FieldQuery):
|
||||
|
||||
def __init__(self, field, fast=True):
|
||||
self.field = field
|
||||
self.fast = fast
|
||||
|
||||
def col_clause(self):
|
||||
return self.field + " IS NULL", ()
|
||||
|
||||
@classmethod
|
||||
def match(self, item):
|
||||
try:
|
||||
return item[self.field] is None
|
||||
except KeyError:
|
||||
return True
|
||||
|
||||
|
||||
class StringFieldQuery(FieldQuery):
|
||||
"""A FieldQuery that converts values to strings before matching
|
||||
them.
|
||||
"""
|
||||
@classmethod
|
||||
def value_match(cls, pattern, value):
|
||||
"""Determine whether the value matches the pattern. The value
|
||||
may have any type.
|
||||
"""
|
||||
return cls.string_match(pattern, util.as_string(value))
|
||||
|
||||
@classmethod
|
||||
def string_match(cls, pattern, value):
|
||||
"""Determine whether the value matches the pattern. Both
|
||||
arguments are strings. Subclasses implement this method.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SubstringQuery(StringFieldQuery):
|
||||
"""A query that matches a substring in a specific item field."""
|
||||
def col_clause(self):
|
||||
pattern = (self.pattern
|
||||
.replace('\\', '\\\\')
|
||||
.replace('%', '\\%')
|
||||
.replace('_', '\\_'))
|
||||
search = '%' + pattern + '%'
|
||||
clause = self.field + " like ? escape '\\'"
|
||||
subvals = [search]
|
||||
return clause, subvals
|
||||
|
||||
@classmethod
|
||||
def string_match(cls, pattern, value):
|
||||
return pattern.lower() in value.lower()
|
||||
|
||||
|
||||
class RegexpQuery(StringFieldQuery):
|
||||
"""A query that matches a regular expression in a specific item
|
||||
field.
|
||||
"""
|
||||
@classmethod
|
||||
def string_match(cls, pattern, value):
|
||||
try:
|
||||
res = re.search(pattern, value)
|
||||
except re.error:
|
||||
# Invalid regular expression.
|
||||
return False
|
||||
return res is not None
|
||||
|
||||
|
||||
class BooleanQuery(MatchQuery):
|
||||
"""Matches a boolean field. Pattern should either be a boolean or a
|
||||
string reflecting a boolean.
|
||||
"""
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
super(BooleanQuery, self).__init__(field, pattern, fast)
|
||||
if isinstance(pattern, basestring):
|
||||
self.pattern = util.str2bool(pattern)
|
||||
self.pattern = int(self.pattern)
|
||||
|
||||
|
||||
class BytesQuery(MatchQuery):
|
||||
"""Match a raw bytes field (i.e., a path). This is a necessary hack
|
||||
to work around the `sqlite3` module's desire to treat `str` and
|
||||
`unicode` equivalently in Python 2. Always use this query instead of
|
||||
`MatchQuery` when matching on BLOB values.
|
||||
"""
|
||||
def __init__(self, field, pattern):
|
||||
super(BytesQuery, self).__init__(field, pattern)
|
||||
|
||||
# Use a buffer representation of the pattern for SQLite
|
||||
# matching. This instructs SQLite to treat the blob as binary
|
||||
# rather than encoded Unicode.
|
||||
if isinstance(self.pattern, basestring):
|
||||
# Implicitly coerce Unicode strings to their bytes
|
||||
# equivalents.
|
||||
if isinstance(self.pattern, unicode):
|
||||
self.pattern = self.pattern.encode('utf8')
|
||||
self.buf_pattern = buffer(self.pattern)
|
||||
elif isinstance(self.pattern, buffer):
|
||||
self.buf_pattern = self.pattern
|
||||
self.pattern = bytes(self.pattern)
|
||||
|
||||
def col_clause(self):
|
||||
return self.field + " = ?", [self.buf_pattern]
|
||||
|
||||
|
||||
class NumericQuery(FieldQuery):
|
||||
"""Matches numeric fields. A syntax using Ruby-style range ellipses
|
||||
(``..``) lets users specify one- or two-sided ranges. For example,
|
||||
``year:2001..`` finds music released since the turn of the century.
|
||||
"""
|
||||
def _convert(self, s):
|
||||
"""Convert a string to a numeric type (float or int). If the
|
||||
string cannot be converted, return None.
|
||||
"""
|
||||
# This is really just a bit of fun premature optimization.
|
||||
try:
|
||||
return int(s)
|
||||
except ValueError:
|
||||
try:
|
||||
return float(s)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
super(NumericQuery, self).__init__(field, pattern, fast)
|
||||
|
||||
parts = pattern.split('..', 1)
|
||||
if len(parts) == 1:
|
||||
# No range.
|
||||
self.point = self._convert(parts[0])
|
||||
self.rangemin = None
|
||||
self.rangemax = None
|
||||
else:
|
||||
# One- or two-sided range.
|
||||
self.point = None
|
||||
self.rangemin = self._convert(parts[0])
|
||||
self.rangemax = self._convert(parts[1])
|
||||
|
||||
def match(self, item):
|
||||
if self.field not in item:
|
||||
return False
|
||||
value = item[self.field]
|
||||
if isinstance(value, basestring):
|
||||
value = self._convert(value)
|
||||
|
||||
if self.point is not None:
|
||||
return value == self.point
|
||||
else:
|
||||
if self.rangemin is not None and value < self.rangemin:
|
||||
return False
|
||||
if self.rangemax is not None and value > self.rangemax:
|
||||
return False
|
||||
return True
|
||||
|
||||
def col_clause(self):
|
||||
if self.point is not None:
|
||||
return self.field + '=?', (self.point,)
|
||||
else:
|
||||
if self.rangemin is not None and self.rangemax is not None:
|
||||
return (u'{0} >= ? AND {0} <= ?'.format(self.field),
|
||||
(self.rangemin, self.rangemax))
|
||||
elif self.rangemin is not None:
|
||||
return u'{0} >= ?'.format(self.field), (self.rangemin,)
|
||||
elif self.rangemax is not None:
|
||||
return u'{0} <= ?'.format(self.field), (self.rangemax,)
|
||||
else:
|
||||
return '1', ()
|
||||
|
||||
|
||||
class CollectionQuery(Query):
|
||||
"""An abstract query class that aggregates other queries. Can be
|
||||
indexed like a list to access the sub-queries.
|
||||
"""
|
||||
def __init__(self, subqueries=()):
|
||||
self.subqueries = subqueries
|
||||
|
||||
# Act like a sequence.
|
||||
|
||||
def __len__(self):
|
||||
return len(self.subqueries)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.subqueries[key]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.subqueries)
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self.subqueries
|
||||
|
||||
def clause_with_joiner(self, joiner):
|
||||
"""Returns a clause created by joining together the clauses of
|
||||
all subqueries with the string joiner (padded by spaces).
|
||||
"""
|
||||
clause_parts = []
|
||||
subvals = []
|
||||
for subq in self.subqueries:
|
||||
subq_clause, subq_subvals = subq.clause()
|
||||
if not subq_clause:
|
||||
# Fall back to slow query.
|
||||
return None, ()
|
||||
clause_parts.append('(' + subq_clause + ')')
|
||||
subvals += subq_subvals
|
||||
clause = (' ' + joiner + ' ').join(clause_parts)
|
||||
return clause, subvals
|
||||
|
||||
|
||||
class AnyFieldQuery(CollectionQuery):
|
||||
"""A query that matches if a given FieldQuery subclass matches in
|
||||
any field. The individual field query class is provided to the
|
||||
constructor.
|
||||
"""
|
||||
def __init__(self, pattern, fields, cls):
|
||||
self.pattern = pattern
|
||||
self.fields = fields
|
||||
self.query_class = cls
|
||||
|
||||
subqueries = []
|
||||
for field in self.fields:
|
||||
subqueries.append(cls(field, pattern, True))
|
||||
super(AnyFieldQuery, self).__init__(subqueries)
|
||||
|
||||
def clause(self):
|
||||
return self.clause_with_joiner('or')
|
||||
|
||||
def match(self, item):
|
||||
for subq in self.subqueries:
|
||||
if subq.match(item):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class MutableCollectionQuery(CollectionQuery):
|
||||
"""A collection query whose subqueries may be modified after the
|
||||
query is initialized.
|
||||
"""
|
||||
def __setitem__(self, key, value):
|
||||
self.subqueries[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.subqueries[key]
|
||||
|
||||
|
||||
class AndQuery(MutableCollectionQuery):
|
||||
"""A conjunction of a list of other queries."""
|
||||
def clause(self):
|
||||
return self.clause_with_joiner('and')
|
||||
|
||||
def match(self, item):
|
||||
return all([q.match(item) for q in self.subqueries])
|
||||
|
||||
|
||||
class OrQuery(MutableCollectionQuery):
|
||||
"""A conjunction of a list of other queries."""
|
||||
def clause(self):
|
||||
return self.clause_with_joiner('or')
|
||||
|
||||
def match(self, item):
|
||||
return any([q.match(item) for q in self.subqueries])
|
||||
|
||||
|
||||
class TrueQuery(Query):
|
||||
"""A query that always matches."""
|
||||
def clause(self):
|
||||
return '1', ()
|
||||
|
||||
def match(self, item):
|
||||
return True
|
||||
|
||||
|
||||
class FalseQuery(Query):
|
||||
"""A query that never matches."""
|
||||
def clause(self):
|
||||
return '0', ()
|
||||
|
||||
def match(self, item):
|
||||
return False
|
||||
|
||||
|
||||
# Time/date queries.
|
||||
|
||||
def _to_epoch_time(date):
|
||||
"""Convert a `datetime` object to an integer number of seconds since
|
||||
the (local) Unix epoch.
|
||||
"""
|
||||
epoch = datetime.fromtimestamp(0)
|
||||
delta = date - epoch
|
||||
try:
|
||||
return int(delta.total_seconds())
|
||||
except AttributeError:
|
||||
# datetime.timedelta.total_seconds() is not available on Python 2.6
|
||||
return delta.seconds + delta.days * 24 * 3600
|
||||
|
||||
|
||||
def _parse_periods(pattern):
|
||||
"""Parse a string containing two dates separated by two dots (..).
|
||||
Return a pair of `Period` objects.
|
||||
"""
|
||||
parts = pattern.split('..', 1)
|
||||
if len(parts) == 1:
|
||||
instant = Period.parse(parts[0])
|
||||
return (instant, instant)
|
||||
else:
|
||||
start = Period.parse(parts[0])
|
||||
end = Period.parse(parts[1])
|
||||
return (start, end)
|
||||
|
||||
|
||||
class Period(object):
|
||||
"""A period of time given by a date, time and precision.
|
||||
|
||||
Example: 2014-01-01 10:50:30 with precision 'month' represents all
|
||||
instants of time during January 2014.
|
||||
"""
|
||||
|
||||
precisions = ('year', 'month', 'day')
|
||||
date_formats = ('%Y', '%Y-%m', '%Y-%m-%d')
|
||||
|
||||
def __init__(self, date, precision):
|
||||
"""Create a period with the given date (a `datetime` object) and
|
||||
precision (a string, one of "year", "month", or "day").
|
||||
"""
|
||||
if precision not in Period.precisions:
|
||||
raise ValueError('Invalid precision ' + str(precision))
|
||||
self.date = date
|
||||
self.precision = precision
|
||||
|
||||
@classmethod
|
||||
def parse(cls, string):
|
||||
"""Parse a date and return a `Period` object or `None` if the
|
||||
string is empty.
|
||||
"""
|
||||
if not string:
|
||||
return None
|
||||
ordinal = string.count('-')
|
||||
if ordinal >= len(cls.date_formats):
|
||||
# Too many components.
|
||||
return None
|
||||
date_format = cls.date_formats[ordinal]
|
||||
try:
|
||||
date = datetime.strptime(string, date_format)
|
||||
except ValueError:
|
||||
# Parsing failed.
|
||||
return None
|
||||
precision = cls.precisions[ordinal]
|
||||
return cls(date, precision)
|
||||
|
||||
def open_right_endpoint(self):
|
||||
"""Based on the precision, convert the period to a precise
|
||||
`datetime` for use as a right endpoint in a right-open interval.
|
||||
"""
|
||||
precision = self.precision
|
||||
date = self.date
|
||||
if 'year' == self.precision:
|
||||
return date.replace(year=date.year + 1, month=1)
|
||||
elif 'month' == precision:
|
||||
if (date.month < 12):
|
||||
return date.replace(month=date.month + 1)
|
||||
else:
|
||||
return date.replace(year=date.year + 1, month=1)
|
||||
elif 'day' == precision:
|
||||
return date + timedelta(days=1)
|
||||
else:
|
||||
raise ValueError('unhandled precision ' + str(precision))
|
||||
|
||||
|
||||
class DateInterval(object):
|
||||
"""A closed-open interval of dates.
|
||||
|
||||
A left endpoint of None means since the beginning of time.
|
||||
A right endpoint of None means towards infinity.
|
||||
"""
|
||||
|
||||
def __init__(self, start, end):
|
||||
if start is not None and end is not None and not start < end:
|
||||
raise ValueError("start date {0} is not before end date {1}"
|
||||
.format(start, end))
|
||||
self.start = start
|
||||
self.end = end
|
||||
|
||||
@classmethod
|
||||
def from_periods(cls, start, end):
|
||||
"""Create an interval with two Periods as the endpoints.
|
||||
"""
|
||||
end_date = end.open_right_endpoint() if end is not None else None
|
||||
start_date = start.date if start is not None else None
|
||||
return cls(start_date, end_date)
|
||||
|
||||
def contains(self, date):
|
||||
if self.start is not None and date < self.start:
|
||||
return False
|
||||
if self.end is not None and date >= self.end:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __str__(self):
|
||||
return'[{0}, {1})'.format(self.start, self.end)
|
||||
|
||||
|
||||
class DateQuery(FieldQuery):
|
||||
"""Matches date fields stored as seconds since Unix epoch time.
|
||||
|
||||
Dates can be specified as ``year-month-day`` strings where only year
|
||||
is mandatory.
|
||||
|
||||
The value of a date field can be matched against a date interval by
|
||||
using an ellipsis interval syntax similar to that of NumericQuery.
|
||||
"""
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
super(DateQuery, self).__init__(field, pattern, fast)
|
||||
start, end = _parse_periods(pattern)
|
||||
self.interval = DateInterval.from_periods(start, end)
|
||||
|
||||
def match(self, item):
|
||||
timestamp = float(item[self.field])
|
||||
date = datetime.utcfromtimestamp(timestamp)
|
||||
return self.interval.contains(date)
|
||||
|
||||
_clause_tmpl = "{0} {1} ?"
|
||||
|
||||
def col_clause(self):
|
||||
clause_parts = []
|
||||
subvals = []
|
||||
|
||||
if self.interval.start:
|
||||
clause_parts.append(self._clause_tmpl.format(self.field, ">="))
|
||||
subvals.append(_to_epoch_time(self.interval.start))
|
||||
|
||||
if self.interval.end:
|
||||
clause_parts.append(self._clause_tmpl.format(self.field, "<"))
|
||||
subvals.append(_to_epoch_time(self.interval.end))
|
||||
|
||||
if clause_parts:
|
||||
# One- or two-sided interval.
|
||||
clause = ' AND '.join(clause_parts)
|
||||
else:
|
||||
# Match any date.
|
||||
clause = '1'
|
||||
return clause, subvals
|
||||
|
||||
|
||||
# Sorting.
|
||||
|
||||
class Sort(object):
|
||||
"""An abstract class representing a sort operation for a query into
|
||||
the item database.
|
||||
"""
|
||||
|
||||
def order_clause(self):
|
||||
"""Generates a SQL fragment to be used in a ORDER BY clause, or
|
||||
None if no fragment is used (i.e., this is a slow sort).
|
||||
"""
|
||||
return None
|
||||
|
||||
def sort(self, items):
|
||||
"""Sort the list of objects and return a list.
|
||||
"""
|
||||
return sorted(items)
|
||||
|
||||
def is_slow(self):
|
||||
"""Indicate whether this query is *slow*, meaning that it cannot
|
||||
be executed in SQL and must be executed in Python.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
class MultipleSort(Sort):
|
||||
"""Sort that encapsulates multiple sub-sorts.
|
||||
"""
|
||||
|
||||
def __init__(self, sorts=None):
|
||||
self.sorts = sorts or []
|
||||
|
||||
def add_sort(self, sort):
|
||||
self.sorts.append(sort)
|
||||
|
||||
def _sql_sorts(self):
|
||||
"""Return the list of sub-sorts for which we can be (at least
|
||||
partially) fast.
|
||||
|
||||
A contiguous suffix of fast (SQL-capable) sub-sorts are
|
||||
executable in SQL. The remaining, even if they are fast
|
||||
independently, must be executed slowly.
|
||||
"""
|
||||
sql_sorts = []
|
||||
for sort in reversed(self.sorts):
|
||||
if not sort.order_clause() is None:
|
||||
sql_sorts.append(sort)
|
||||
else:
|
||||
break
|
||||
sql_sorts.reverse()
|
||||
return sql_sorts
|
||||
|
||||
def order_clause(self):
|
||||
order_strings = []
|
||||
for sort in self._sql_sorts():
|
||||
order = sort.order_clause()
|
||||
order_strings.append(order)
|
||||
|
||||
return ", ".join(order_strings)
|
||||
|
||||
def is_slow(self):
|
||||
for sort in self.sorts:
|
||||
if sort.is_slow():
|
||||
return True
|
||||
return False
|
||||
|
||||
def sort(self, items):
|
||||
slow_sorts = []
|
||||
switch_slow = False
|
||||
for sort in reversed(self.sorts):
|
||||
if switch_slow:
|
||||
slow_sorts.append(sort)
|
||||
elif sort.order_clause() is None:
|
||||
switch_slow = True
|
||||
slow_sorts.append(sort)
|
||||
else:
|
||||
pass
|
||||
|
||||
for sort in slow_sorts:
|
||||
items = sort.sort(items)
|
||||
return items
|
||||
|
||||
def __repr__(self):
|
||||
return u'MultipleSort({0})'.format(repr(self.sorts))
|
||||
|
||||
|
||||
class FieldSort(Sort):
|
||||
"""An abstract sort criterion that orders by a specific field (of
|
||||
any kind).
|
||||
"""
|
||||
def __init__(self, field, ascending=True):
|
||||
self.field = field
|
||||
self.ascending = ascending
|
||||
|
||||
def sort(self, objs):
|
||||
# TODO: Conversion and null-detection here. In Python 3,
|
||||
# comparisons with None fail. We should also support flexible
|
||||
# attributes with different types without falling over.
|
||||
return sorted(objs, key=attrgetter(self.field),
|
||||
reverse=not self.ascending)
|
||||
|
||||
def __repr__(self):
|
||||
return u'<{0}: {1}{2}>'.format(
|
||||
type(self).__name__,
|
||||
self.field,
|
||||
'+' if self.ascending else '-',
|
||||
)
|
||||
|
||||
|
||||
class FixedFieldSort(FieldSort):
|
||||
"""Sort object to sort on a fixed field.
|
||||
"""
|
||||
def order_clause(self):
|
||||
order = "ASC" if self.ascending else "DESC"
|
||||
return "{0} {1}".format(self.field, order)
|
||||
|
||||
|
||||
class SlowFieldSort(FieldSort):
|
||||
"""A sort criterion by some model field other than a fixed field:
|
||||
i.e., a computed or flexible field.
|
||||
"""
|
||||
def is_slow(self):
|
||||
return True
|
||||
|
||||
|
||||
class NullSort(Sort):
|
||||
"""No sorting. Leave results unsorted."""
|
||||
def sort(items):
|
||||
return items
|
||||
@@ -0,0 +1,180 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Parsing of strings into DBCore queries.
|
||||
"""
|
||||
import re
|
||||
import itertools
|
||||
from . import query
|
||||
|
||||
|
||||
PARSE_QUERY_PART_REGEX = re.compile(
|
||||
# Non-capturing optional segment for the keyword.
|
||||
r'(?:'
|
||||
r'(\S+?)' # The field key.
|
||||
r'(?<!\\):' # Unescaped :
|
||||
r')?'
|
||||
|
||||
r'(.*)', # The term itself.
|
||||
|
||||
re.I # Case-insensitive.
|
||||
)
|
||||
|
||||
|
||||
def parse_query_part(part, query_classes={}, prefixes={},
|
||||
default_class=query.SubstringQuery):
|
||||
"""Take a query in the form of a key/value pair separated by a
|
||||
colon and return a tuple of `(key, value, cls)`. `key` may be None,
|
||||
indicating that any field may be matched. `cls` is a subclass of
|
||||
`FieldQuery`.
|
||||
|
||||
The optional `query_classes` parameter maps field names to default
|
||||
query types; `default_class` is the fallback. `prefixes` is a map
|
||||
from query prefix markers and query types. Prefix-indicated queries
|
||||
take precedence over type-based queries.
|
||||
|
||||
To determine the query class, two factors are used: prefixes and
|
||||
field types. For example, the colon prefix denotes a regular
|
||||
expression query and a type map might provide a special kind of
|
||||
query for numeric values. If neither a prefix nor a specific query
|
||||
class is available, `default_class` is used.
|
||||
|
||||
For instance,
|
||||
'stapler' -> (None, 'stapler', SubstringQuery)
|
||||
'color:red' -> ('color', 'red', SubstringQuery)
|
||||
':^Quiet' -> (None, '^Quiet', RegexpQuery)
|
||||
'color::b..e' -> ('color', 'b..e', RegexpQuery)
|
||||
|
||||
Prefixes may be "escaped" with a backslash to disable the keying
|
||||
behavior.
|
||||
"""
|
||||
part = part.strip()
|
||||
match = PARSE_QUERY_PART_REGEX.match(part)
|
||||
|
||||
assert match # Regex should always match.
|
||||
key = match.group(1)
|
||||
term = match.group(2).replace('\:', ':')
|
||||
|
||||
# Match the search term against the list of prefixes.
|
||||
for pre, query_class in prefixes.items():
|
||||
if term.startswith(pre):
|
||||
return key, term[len(pre):], query_class
|
||||
|
||||
# No matching prefix: use type-based or fallback/default query.
|
||||
query_class = query_classes.get(key, default_class)
|
||||
return key, term, query_class
|
||||
|
||||
|
||||
def construct_query_part(model_cls, prefixes, query_part):
|
||||
"""Create a query from a single query component, `query_part`, for
|
||||
querying instances of `model_cls`. Return a `Query` instance.
|
||||
"""
|
||||
# Shortcut for empty query parts.
|
||||
if not query_part:
|
||||
return query.TrueQuery()
|
||||
|
||||
# Get the query classes for each possible field.
|
||||
query_classes = {}
|
||||
for k, t in itertools.chain(model_cls._fields.items(),
|
||||
model_cls._types.items()):
|
||||
query_classes[k] = t.query
|
||||
|
||||
# Parse the string.
|
||||
key, pattern, query_class = \
|
||||
parse_query_part(query_part, query_classes, prefixes)
|
||||
|
||||
# No key specified.
|
||||
if key is None:
|
||||
if issubclass(query_class, query.FieldQuery):
|
||||
# The query type matches a specific field, but none was
|
||||
# specified. So we use a version of the query that matches
|
||||
# any field.
|
||||
return query.AnyFieldQuery(pattern, model_cls._search_fields,
|
||||
query_class)
|
||||
else:
|
||||
# Other query type.
|
||||
return query_class(pattern)
|
||||
|
||||
key = key.lower()
|
||||
return query_class(key.lower(), pattern, key in model_cls._fields)
|
||||
|
||||
|
||||
def query_from_strings(query_cls, model_cls, prefixes, query_parts):
|
||||
"""Creates a collection query of type `query_cls` from a list of
|
||||
strings in the format used by parse_query_part. `model_cls`
|
||||
determines how queries are constructed from strings.
|
||||
"""
|
||||
subqueries = []
|
||||
for part in query_parts:
|
||||
subqueries.append(construct_query_part(model_cls, prefixes, part))
|
||||
if not subqueries: # No terms in query.
|
||||
subqueries = [query.TrueQuery()]
|
||||
return query_cls(subqueries)
|
||||
|
||||
|
||||
def construct_sort_part(model_cls, part):
|
||||
"""Create a `Sort` from a single string criterion.
|
||||
|
||||
`model_cls` is the `Model` being queried. `part` is a single string
|
||||
ending in ``+`` or ``-`` indicating the sort.
|
||||
"""
|
||||
assert part, "part must be a field name and + or -"
|
||||
field = part[:-1]
|
||||
assert field, "field is missing"
|
||||
direction = part[-1]
|
||||
assert direction in ('+', '-'), "part must end with + or -"
|
||||
is_ascending = direction == '+'
|
||||
|
||||
if field in model_cls._sorts:
|
||||
sort = model_cls._sorts[field](model_cls, is_ascending)
|
||||
elif field in model_cls._fields:
|
||||
sort = query.FixedFieldSort(field, is_ascending)
|
||||
else:
|
||||
# Flexible or computed.
|
||||
sort = query.SlowFieldSort(field, is_ascending)
|
||||
return sort
|
||||
|
||||
|
||||
def sort_from_strings(model_cls, sort_parts):
|
||||
"""Create a `Sort` from a list of sort criteria (strings).
|
||||
"""
|
||||
if not sort_parts:
|
||||
return query.NullSort()
|
||||
else:
|
||||
sort = query.MultipleSort()
|
||||
for part in sort_parts:
|
||||
sort.add_sort(construct_sort_part(model_cls, part))
|
||||
return sort
|
||||
|
||||
|
||||
def parse_sorted_query(model_cls, parts, prefixes={},
|
||||
query_cls=query.AndQuery):
|
||||
"""Given a list of strings, create the `Query` and `Sort` that they
|
||||
represent.
|
||||
"""
|
||||
# Separate query token and sort token.
|
||||
query_parts = []
|
||||
sort_parts = []
|
||||
for part in parts:
|
||||
if part.endswith((u'+', u'-')) and u':' not in part:
|
||||
sort_parts.append(part)
|
||||
else:
|
||||
query_parts.append(part)
|
||||
|
||||
# Parse each.
|
||||
q = query_from_strings(
|
||||
query_cls, model_cls, prefixes, query_parts
|
||||
)
|
||||
s = sort_from_strings(model_cls, sort_parts)
|
||||
return q, s
|
||||
@@ -0,0 +1,208 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Representation of type information for DBCore model fields.
|
||||
"""
|
||||
from . import query
|
||||
from beets.util import str2bool
|
||||
|
||||
|
||||
# Abstract base.
|
||||
|
||||
class Type(object):
|
||||
"""An object encapsulating the type of a model field. Includes
|
||||
information about how to store, query, format, and parse a given
|
||||
field.
|
||||
"""
|
||||
|
||||
sql = u'TEXT'
|
||||
"""The SQLite column type for the value.
|
||||
"""
|
||||
|
||||
query = query.SubstringQuery
|
||||
"""The `Query` subclass to be used when querying the field.
|
||||
"""
|
||||
|
||||
model_type = unicode
|
||||
"""The Python type that is used to represent the value in the model.
|
||||
|
||||
The model is guaranteed to return a value of this type if the field
|
||||
is accessed. To this end, the constructor is used by the `normalize`
|
||||
and `from_sql` methods and the `default` property.
|
||||
"""
|
||||
|
||||
@property
|
||||
def null(self):
|
||||
"""The value to be exposed when the underlying value is None.
|
||||
"""
|
||||
return self.model_type()
|
||||
|
||||
def format(self, value):
|
||||
"""Given a value of this type, produce a Unicode string
|
||||
representing the value. This is used in template evaluation.
|
||||
"""
|
||||
if value is None:
|
||||
value = self.null
|
||||
# `self.null` might be `None`
|
||||
if value is None:
|
||||
value = u''
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf8', 'ignore')
|
||||
|
||||
return unicode(value)
|
||||
|
||||
def parse(self, string):
|
||||
"""Parse a (possibly human-written) string and return the
|
||||
indicated value of this type.
|
||||
"""
|
||||
try:
|
||||
return self.model_type(string)
|
||||
except ValueError:
|
||||
return self.null
|
||||
|
||||
def normalize(self, value):
|
||||
"""Given a value that will be assigned into a field of this
|
||||
type, normalize the value to have the appropriate type. This
|
||||
base implementation only reinterprets `None`.
|
||||
"""
|
||||
if value is None:
|
||||
return self.null
|
||||
else:
|
||||
# TODO This should eventually be replaced by
|
||||
# `self.model_type(value)`
|
||||
return value
|
||||
|
||||
def from_sql(self, sql_value):
|
||||
"""Receives the value stored in the SQL backend and return the
|
||||
value to be stored in the model.
|
||||
|
||||
For fixed fields the type of `value` is determined by the column
|
||||
type affinity given in the `sql` property and the SQL to Python
|
||||
mapping of the database adapter. For more information see:
|
||||
http://www.sqlite.org/datatype3.html
|
||||
https://docs.python.org/2/library/sqlite3.html#sqlite-and-python-types
|
||||
|
||||
Flexible fields have the type afinity `TEXT`. This means the
|
||||
`sql_value` is either a `buffer` or a `unicode` object` and the
|
||||
method must handle these in addition.
|
||||
"""
|
||||
if isinstance(sql_value, buffer):
|
||||
sql_value = bytes(sql_value).decode('utf8', 'ignore')
|
||||
if isinstance(sql_value, unicode):
|
||||
return self.parse(sql_value)
|
||||
else:
|
||||
return self.normalize(sql_value)
|
||||
|
||||
def to_sql(self, model_value):
|
||||
"""Convert a value as stored in the model object to a value used
|
||||
by the database adapter.
|
||||
"""
|
||||
return model_value
|
||||
|
||||
|
||||
# Reusable types.
|
||||
|
||||
class Default(Type):
|
||||
null = None
|
||||
|
||||
|
||||
class Integer(Type):
|
||||
"""A basic integer type.
|
||||
"""
|
||||
sql = u'INTEGER'
|
||||
query = query.NumericQuery
|
||||
model_type = int
|
||||
|
||||
|
||||
class PaddedInt(Integer):
|
||||
"""An integer field that is formatted with a given number of digits,
|
||||
padded with zeroes.
|
||||
"""
|
||||
def __init__(self, digits):
|
||||
self.digits = digits
|
||||
|
||||
def format(self, value):
|
||||
return u'{0:0{1}d}'.format(value or 0, self.digits)
|
||||
|
||||
|
||||
class ScaledInt(Integer):
|
||||
"""An integer whose formatting operation scales the number by a
|
||||
constant and adds a suffix. Good for units with large magnitudes.
|
||||
"""
|
||||
def __init__(self, unit, suffix=u''):
|
||||
self.unit = unit
|
||||
self.suffix = suffix
|
||||
|
||||
def format(self, value):
|
||||
return u'{0}{1}'.format((value or 0) // self.unit, self.suffix)
|
||||
|
||||
|
||||
class Id(Integer):
|
||||
"""An integer used as the row id or a foreign key in a SQLite table.
|
||||
This type is nullable: None values are not translated to zero.
|
||||
"""
|
||||
null = None
|
||||
|
||||
def __init__(self, primary=True):
|
||||
if primary:
|
||||
self.sql = u'INTEGER PRIMARY KEY'
|
||||
|
||||
|
||||
class Float(Type):
|
||||
"""A basic floating-point type.
|
||||
"""
|
||||
sql = u'REAL'
|
||||
query = query.NumericQuery
|
||||
model_type = float
|
||||
|
||||
def format(self, value):
|
||||
return u'{0:.1f}'.format(value or 0.0)
|
||||
|
||||
|
||||
class NullFloat(Float):
|
||||
"""Same as `Float`, but does not normalize `None` to `0.0`.
|
||||
"""
|
||||
null = None
|
||||
|
||||
|
||||
class String(Type):
|
||||
"""A Unicode string type.
|
||||
"""
|
||||
sql = u'TEXT'
|
||||
query = query.SubstringQuery
|
||||
|
||||
|
||||
class Boolean(Type):
|
||||
"""A boolean type.
|
||||
"""
|
||||
sql = u'INTEGER'
|
||||
query = query.BooleanQuery
|
||||
model_type = bool
|
||||
|
||||
def format(self, value):
|
||||
return unicode(bool(value))
|
||||
|
||||
def parse(self, string):
|
||||
return str2bool(string)
|
||||
|
||||
|
||||
# Shared instances of common types.
|
||||
DEFAULT = Default()
|
||||
INTEGER = Integer()
|
||||
PRIMARY_ID = Id(True)
|
||||
FOREIGN_ID = Id(False)
|
||||
FLOAT = Float()
|
||||
NULL_FLOAT = NullFloat()
|
||||
STRING = String()
|
||||
BOOLEAN = Boolean()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Executable
+435
@@ -0,0 +1,435 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Support for beets plugins."""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
import inspect
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
import beets
|
||||
from beets import mediafile
|
||||
|
||||
PLUGIN_NAMESPACE = 'beetsplug'
|
||||
|
||||
# Plugins using the Last.fm API can share the same API key.
|
||||
LASTFM_KEY = '2dc3914abf35f0d9c92d97d8f8e42b43'
|
||||
|
||||
# Global logger.
|
||||
log = logging.getLogger('beets')
|
||||
|
||||
|
||||
class PluginConflictException(Exception):
|
||||
"""Indicates that the services provided by one plugin conflict with
|
||||
those of another.
|
||||
|
||||
For example two plugins may define different types for flexible fields.
|
||||
"""
|
||||
|
||||
|
||||
# Managing the plugins themselves.
|
||||
|
||||
class BeetsPlugin(object):
|
||||
"""The base class for all beets plugins. Plugins provide
|
||||
functionality by defining a subclass of BeetsPlugin and overriding
|
||||
the abstract methods defined here.
|
||||
"""
|
||||
def __init__(self, name=None):
|
||||
"""Perform one-time plugin setup.
|
||||
"""
|
||||
self.import_stages = []
|
||||
self.name = name or self.__module__.split('.')[-1]
|
||||
self.config = beets.config[self.name]
|
||||
if not self.template_funcs:
|
||||
self.template_funcs = {}
|
||||
if not self.template_fields:
|
||||
self.template_fields = {}
|
||||
if not self.album_template_fields:
|
||||
self.album_template_fields = {}
|
||||
|
||||
def commands(self):
|
||||
"""Should return a list of beets.ui.Subcommand objects for
|
||||
commands that should be added to beets' CLI.
|
||||
"""
|
||||
return ()
|
||||
|
||||
def queries(self):
|
||||
"""Should return a dict mapping prefixes to Query subclasses.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def track_distance(self, item, info):
|
||||
"""Should return a Distance object to be added to the
|
||||
distance for every track comparison.
|
||||
"""
|
||||
return beets.autotag.hooks.Distance()
|
||||
|
||||
def album_distance(self, items, album_info, mapping):
|
||||
"""Should return a Distance object to be added to the
|
||||
distance for every album-level comparison.
|
||||
"""
|
||||
return beets.autotag.hooks.Distance()
|
||||
|
||||
def candidates(self, items, artist, album, va_likely):
|
||||
"""Should return a sequence of AlbumInfo objects that match the
|
||||
album whose items are provided.
|
||||
"""
|
||||
return ()
|
||||
|
||||
def item_candidates(self, item, artist, title):
|
||||
"""Should return a sequence of TrackInfo objects that match the
|
||||
item provided.
|
||||
"""
|
||||
return ()
|
||||
|
||||
def album_for_id(self, album_id):
|
||||
"""Return an AlbumInfo object or None if no matching release was
|
||||
found.
|
||||
"""
|
||||
return None
|
||||
|
||||
def track_for_id(self, track_id):
|
||||
"""Return a TrackInfo object or None if no matching release was
|
||||
found.
|
||||
"""
|
||||
return None
|
||||
|
||||
def add_media_field(self, name, descriptor):
|
||||
"""Add a field that is synchronized between media files and items.
|
||||
|
||||
When a media field is added ``item.write()`` will set the name
|
||||
property of the item's MediaFile to ``item[name]`` and save the
|
||||
changes. Similarly ``item.read()`` will set ``item[name]`` to
|
||||
the value of the name property of the media file.
|
||||
|
||||
``descriptor`` must be an instance of ``mediafile.MediaField``.
|
||||
"""
|
||||
# Defer impor to prevent circular dependency
|
||||
from beets import library
|
||||
mediafile.MediaFile.add_field(name, descriptor)
|
||||
library.Item._media_fields.add(name)
|
||||
|
||||
listeners = None
|
||||
|
||||
@classmethod
|
||||
def register_listener(cls, event, func):
|
||||
"""Add a function as a listener for the specified event. (An
|
||||
imperative alternative to the @listen decorator.)
|
||||
"""
|
||||
if cls.listeners is None:
|
||||
cls.listeners = defaultdict(list)
|
||||
cls.listeners[event].append(func)
|
||||
|
||||
@classmethod
|
||||
def listen(cls, event):
|
||||
"""Decorator that adds a function as an event handler for the
|
||||
specified event (as a string). The parameters passed to function
|
||||
will vary depending on what event occurred.
|
||||
|
||||
The function should respond to named parameters.
|
||||
function(**kwargs) will trap all arguments in a dictionary.
|
||||
Example:
|
||||
|
||||
>>> @MyPlugin.listen("imported")
|
||||
>>> def importListener(**kwargs):
|
||||
... pass
|
||||
"""
|
||||
def helper(func):
|
||||
if cls.listeners is None:
|
||||
cls.listeners = defaultdict(list)
|
||||
cls.listeners[event].append(func)
|
||||
return func
|
||||
return helper
|
||||
|
||||
template_funcs = None
|
||||
template_fields = None
|
||||
album_template_fields = None
|
||||
|
||||
@classmethod
|
||||
def template_func(cls, name):
|
||||
"""Decorator that registers a path template function. The
|
||||
function will be invoked as ``%name{}`` from path format
|
||||
strings.
|
||||
"""
|
||||
def helper(func):
|
||||
if cls.template_funcs is None:
|
||||
cls.template_funcs = {}
|
||||
cls.template_funcs[name] = func
|
||||
return func
|
||||
return helper
|
||||
|
||||
@classmethod
|
||||
def template_field(cls, name):
|
||||
"""Decorator that registers a path template field computation.
|
||||
The value will be referenced as ``$name`` from path format
|
||||
strings. The function must accept a single parameter, the Item
|
||||
being formatted.
|
||||
"""
|
||||
def helper(func):
|
||||
if cls.template_fields is None:
|
||||
cls.template_fields = {}
|
||||
cls.template_fields[name] = func
|
||||
return func
|
||||
return helper
|
||||
|
||||
|
||||
_classes = set()
|
||||
|
||||
|
||||
def load_plugins(names=()):
|
||||
"""Imports the modules for a sequence of plugin names. Each name
|
||||
must be the name of a Python module under the "beetsplug" namespace
|
||||
package in sys.path; the module indicated should contain the
|
||||
BeetsPlugin subclasses desired.
|
||||
"""
|
||||
for name in names:
|
||||
modname = '%s.%s' % (PLUGIN_NAMESPACE, name)
|
||||
try:
|
||||
try:
|
||||
namespace = __import__(modname, None, None)
|
||||
except ImportError as exc:
|
||||
# Again, this is hacky:
|
||||
if exc.args[0].endswith(' ' + name):
|
||||
log.warn(u'** plugin {0} not found'.format(name))
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
for obj in getattr(namespace, name).__dict__.values():
|
||||
if isinstance(obj, type) and issubclass(obj, BeetsPlugin) \
|
||||
and obj != BeetsPlugin and obj not in _classes:
|
||||
_classes.add(obj)
|
||||
|
||||
except:
|
||||
log.warn(u'** error loading plugin {0}'.format(name))
|
||||
log.warn(traceback.format_exc())
|
||||
|
||||
|
||||
_instances = {}
|
||||
|
||||
|
||||
def find_plugins():
|
||||
"""Returns a list of BeetsPlugin subclass instances from all
|
||||
currently loaded beets plugins. Loads the default plugin set
|
||||
first.
|
||||
"""
|
||||
load_plugins()
|
||||
plugins = []
|
||||
for cls in _classes:
|
||||
# Only instantiate each plugin class once.
|
||||
if cls not in _instances:
|
||||
_instances[cls] = cls()
|
||||
plugins.append(_instances[cls])
|
||||
return plugins
|
||||
|
||||
|
||||
# Communication with plugins.
|
||||
|
||||
def commands():
|
||||
"""Returns a list of Subcommand objects from all loaded plugins.
|
||||
"""
|
||||
out = []
|
||||
for plugin in find_plugins():
|
||||
out += plugin.commands()
|
||||
return out
|
||||
|
||||
|
||||
def queries():
|
||||
"""Returns a dict mapping prefix strings to Query subclasses all loaded
|
||||
plugins.
|
||||
"""
|
||||
out = {}
|
||||
for plugin in find_plugins():
|
||||
out.update(plugin.queries())
|
||||
return out
|
||||
|
||||
|
||||
def types(model_cls):
|
||||
# Gives us `item_types` and `album_types`
|
||||
attr_name = '{0}_types'.format(model_cls.__name__.lower())
|
||||
types = {}
|
||||
for plugin in find_plugins():
|
||||
plugin_types = getattr(plugin, attr_name, {})
|
||||
for field in plugin_types:
|
||||
if field in types and plugin_types[field] != types[field]:
|
||||
raise PluginConflictException(
|
||||
u'Plugin {0} defines flexible field {1} '
|
||||
'which has already been defined with '
|
||||
'another type.'.format(plugin.name, field)
|
||||
)
|
||||
types.update(plugin_types)
|
||||
return types
|
||||
|
||||
|
||||
def track_distance(item, info):
|
||||
"""Gets the track distance calculated by all loaded plugins.
|
||||
Returns a Distance object.
|
||||
"""
|
||||
from beets.autotag.hooks import Distance
|
||||
dist = Distance()
|
||||
for plugin in find_plugins():
|
||||
dist.update(plugin.track_distance(item, info))
|
||||
return dist
|
||||
|
||||
|
||||
def album_distance(items, album_info, mapping):
|
||||
"""Returns the album distance calculated by plugins."""
|
||||
from beets.autotag.hooks import Distance
|
||||
dist = Distance()
|
||||
for plugin in find_plugins():
|
||||
dist.update(plugin.album_distance(items, album_info, mapping))
|
||||
return dist
|
||||
|
||||
|
||||
def candidates(items, artist, album, va_likely):
|
||||
"""Gets MusicBrainz candidates for an album from each plugin.
|
||||
"""
|
||||
out = []
|
||||
for plugin in find_plugins():
|
||||
out.extend(plugin.candidates(items, artist, album, va_likely))
|
||||
return out
|
||||
|
||||
|
||||
def item_candidates(item, artist, title):
|
||||
"""Gets MusicBrainz candidates for an item from the plugins.
|
||||
"""
|
||||
out = []
|
||||
for plugin in find_plugins():
|
||||
out.extend(plugin.item_candidates(item, artist, title))
|
||||
return out
|
||||
|
||||
|
||||
def album_for_id(album_id):
|
||||
"""Get AlbumInfo objects for a given ID string.
|
||||
"""
|
||||
out = []
|
||||
for plugin in find_plugins():
|
||||
res = plugin.album_for_id(album_id)
|
||||
if res:
|
||||
out.append(res)
|
||||
return out
|
||||
|
||||
|
||||
def track_for_id(track_id):
|
||||
"""Get TrackInfo objects for a given ID string.
|
||||
"""
|
||||
out = []
|
||||
for plugin in find_plugins():
|
||||
res = plugin.track_for_id(track_id)
|
||||
if res:
|
||||
out.append(res)
|
||||
return out
|
||||
|
||||
|
||||
def template_funcs():
|
||||
"""Get all the template functions declared by plugins as a
|
||||
dictionary.
|
||||
"""
|
||||
funcs = {}
|
||||
for plugin in find_plugins():
|
||||
if plugin.template_funcs:
|
||||
funcs.update(plugin.template_funcs)
|
||||
return funcs
|
||||
|
||||
|
||||
def import_stages():
|
||||
"""Get a list of import stage functions defined by plugins."""
|
||||
stages = []
|
||||
for plugin in find_plugins():
|
||||
if hasattr(plugin, 'import_stages'):
|
||||
stages += plugin.import_stages
|
||||
return stages
|
||||
|
||||
|
||||
# New-style (lazy) plugin-provided fields.
|
||||
|
||||
def item_field_getters():
|
||||
"""Get a dictionary mapping field names to unary functions that
|
||||
compute the field's value.
|
||||
"""
|
||||
funcs = {}
|
||||
for plugin in find_plugins():
|
||||
if plugin.template_fields:
|
||||
funcs.update(plugin.template_fields)
|
||||
return funcs
|
||||
|
||||
|
||||
def album_field_getters():
|
||||
"""As above, for album fields.
|
||||
"""
|
||||
funcs = {}
|
||||
for plugin in find_plugins():
|
||||
if plugin.album_template_fields:
|
||||
funcs.update(plugin.album_template_fields)
|
||||
return funcs
|
||||
|
||||
|
||||
# Event dispatch.
|
||||
|
||||
def event_handlers():
|
||||
"""Find all event handlers from plugins as a dictionary mapping
|
||||
event names to sequences of callables.
|
||||
"""
|
||||
all_handlers = defaultdict(list)
|
||||
for plugin in find_plugins():
|
||||
if plugin.listeners:
|
||||
for event, handlers in plugin.listeners.items():
|
||||
all_handlers[event] += handlers
|
||||
return all_handlers
|
||||
|
||||
|
||||
def send(event, **arguments):
|
||||
"""Sends an event to all assigned event listeners. Event is the
|
||||
name of the event to send, all other named arguments go to the
|
||||
event handler(s).
|
||||
|
||||
Returns a list of return values from the handlers.
|
||||
"""
|
||||
log.debug(u'Sending event: {0}'.format(event))
|
||||
for handler in event_handlers()[event]:
|
||||
# Don't break legacy plugins if we want to pass more arguments
|
||||
argspec = inspect.getargspec(handler).args
|
||||
args = dict((k, v) for k, v in arguments.items() if k in argspec)
|
||||
handler(**args)
|
||||
|
||||
|
||||
def feat_tokens(for_artist=True):
|
||||
"""Return a regular expression that matches phrases like "featuring"
|
||||
that separate a main artist or a song title from secondary artists.
|
||||
The `for_artist` option determines whether the regex should be
|
||||
suitable for matching artist fields (the default) or title fields.
|
||||
"""
|
||||
feat_words = ['ft', 'featuring', 'feat', 'feat.', 'ft.']
|
||||
if for_artist:
|
||||
feat_words += ['with', 'vs', 'and', 'con', '&']
|
||||
return '(?<=\s)(?:{0})(?=\s)'.format(
|
||||
'|'.join(re.escape(x) for x in feat_words)
|
||||
)
|
||||
|
||||
|
||||
def sanitize_choices(choices, choices_all):
|
||||
"""Clean up a stringlist configuration attribute: keep only choices
|
||||
elements present in choices_all, remove duplicate elements, expand '*'
|
||||
wildcard while keeping original stringlist order.
|
||||
"""
|
||||
seen = set()
|
||||
others = [x for x in choices_all if x not in choices]
|
||||
res = []
|
||||
for s in choices:
|
||||
if s in list(choices_all) + ['*']:
|
||||
if not (s in seen or seen.add(s)):
|
||||
res.extend(list(others) if s == '*' else [s])
|
||||
return res
|
||||
@@ -0,0 +1,970 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""This module contains all of the core logic for beets' command-line
|
||||
interface. To invoke the CLI, just call beets.ui.main(). The actual
|
||||
CLI commands are implemented in the ui.commands module.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import locale
|
||||
import optparse
|
||||
import textwrap
|
||||
import sys
|
||||
from difflib import SequenceMatcher
|
||||
import logging
|
||||
import sqlite3
|
||||
import errno
|
||||
import re
|
||||
import struct
|
||||
import traceback
|
||||
import os.path
|
||||
|
||||
from beets import library
|
||||
from beets import plugins
|
||||
from beets import util
|
||||
from beets.util.functemplate import Template
|
||||
from beets import config
|
||||
from beets.util import confit
|
||||
from beets.autotag import mb
|
||||
|
||||
# On Windows platforms, use colorama to support "ANSI" terminal colors.
|
||||
if sys.platform == 'win32':
|
||||
try:
|
||||
import colorama
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
colorama.init()
|
||||
|
||||
|
||||
log = logging.getLogger('beets')
|
||||
if not log.handlers:
|
||||
log.addHandler(logging.StreamHandler())
|
||||
log.propagate = False # Don't propagate to root handler.
|
||||
|
||||
|
||||
PF_KEY_QUERIES = {
|
||||
'comp': 'comp:true',
|
||||
'singleton': 'singleton:true',
|
||||
}
|
||||
|
||||
|
||||
class UserError(Exception):
|
||||
"""UI exception. Commands should throw this in order to display
|
||||
nonrecoverable errors to the user.
|
||||
"""
|
||||
|
||||
|
||||
# Utilities.
|
||||
|
||||
def _encoding():
|
||||
"""Tries to guess the encoding used by the terminal."""
|
||||
# Configured override?
|
||||
encoding = config['terminal_encoding'].get()
|
||||
if encoding:
|
||||
return encoding
|
||||
|
||||
# Determine from locale settings.
|
||||
try:
|
||||
return locale.getdefaultlocale()[1] or 'utf8'
|
||||
except ValueError:
|
||||
# Invalid locale environment variable setting. To avoid
|
||||
# failing entirely for no good reason, assume UTF-8.
|
||||
return 'utf8'
|
||||
|
||||
|
||||
def decargs(arglist):
|
||||
"""Given a list of command-line argument bytestrings, attempts to
|
||||
decode them to Unicode strings.
|
||||
"""
|
||||
return [s.decode(_encoding()) for s in arglist]
|
||||
|
||||
|
||||
def print_(*strings):
|
||||
"""Like print, but rather than raising an error when a character
|
||||
is not in the terminal's encoding's character set, just silently
|
||||
replaces it.
|
||||
"""
|
||||
if strings:
|
||||
if isinstance(strings[0], unicode):
|
||||
txt = u' '.join(strings)
|
||||
else:
|
||||
txt = ' '.join(strings)
|
||||
else:
|
||||
txt = u''
|
||||
if isinstance(txt, unicode):
|
||||
txt = txt.encode(_encoding(), 'replace')
|
||||
print(txt)
|
||||
|
||||
|
||||
def input_(prompt=None):
|
||||
"""Like `raw_input`, but decodes the result to a Unicode string.
|
||||
Raises a UserError if stdin is not available. The prompt is sent to
|
||||
stdout rather than stderr. A printed between the prompt and the
|
||||
input cursor.
|
||||
"""
|
||||
# raw_input incorrectly sends prompts to stderr, not stdout, so we
|
||||
# use print() explicitly to display prompts.
|
||||
# http://bugs.python.org/issue1927
|
||||
if prompt:
|
||||
if isinstance(prompt, unicode):
|
||||
prompt = prompt.encode(_encoding(), 'replace')
|
||||
print(prompt, end=' ')
|
||||
|
||||
try:
|
||||
resp = raw_input()
|
||||
except EOFError:
|
||||
raise UserError('stdin stream ended while input required')
|
||||
|
||||
return resp.decode(sys.stdin.encoding or 'utf8', 'ignore')
|
||||
|
||||
|
||||
def input_options(options, require=False, prompt=None, fallback_prompt=None,
|
||||
numrange=None, default=None, max_width=72):
|
||||
"""Prompts a user for input. The sequence of `options` defines the
|
||||
choices the user has. A single-letter shortcut is inferred for each
|
||||
option; the user's choice is returned as that single, lower-case
|
||||
letter. The options should be provided as lower-case strings unless
|
||||
a particular shortcut is desired; in that case, only that letter
|
||||
should be capitalized.
|
||||
|
||||
By default, the first option is the default. `default` can be provided to
|
||||
override this. If `require` is provided, then there is no default. The
|
||||
prompt and fallback prompt are also inferred but can be overridden.
|
||||
|
||||
If numrange is provided, it is a pair of `(high, low)` (both ints)
|
||||
indicating that, in addition to `options`, the user may enter an
|
||||
integer in that inclusive range.
|
||||
|
||||
`max_width` specifies the maximum number of columns in the
|
||||
automatically generated prompt string.
|
||||
"""
|
||||
# Assign single letters to each option. Also capitalize the options
|
||||
# to indicate the letter.
|
||||
letters = {}
|
||||
display_letters = []
|
||||
capitalized = []
|
||||
first = True
|
||||
for option in options:
|
||||
# Is a letter already capitalized?
|
||||
for letter in option:
|
||||
if letter.isalpha() and letter.upper() == letter:
|
||||
found_letter = letter
|
||||
break
|
||||
else:
|
||||
# Infer a letter.
|
||||
for letter in option:
|
||||
if not letter.isalpha():
|
||||
continue # Don't use punctuation.
|
||||
if letter not in letters:
|
||||
found_letter = letter
|
||||
break
|
||||
else:
|
||||
raise ValueError('no unambiguous lettering found')
|
||||
|
||||
letters[found_letter.lower()] = option
|
||||
index = option.index(found_letter)
|
||||
|
||||
# Mark the option's shortcut letter for display.
|
||||
if not require and (
|
||||
(default is None and not numrange and first) or
|
||||
(isinstance(default, basestring) and
|
||||
found_letter.lower() == default.lower())):
|
||||
# The first option is the default; mark it.
|
||||
show_letter = '[%s]' % found_letter.upper()
|
||||
is_default = True
|
||||
else:
|
||||
show_letter = found_letter.upper()
|
||||
is_default = False
|
||||
|
||||
# Colorize the letter shortcut.
|
||||
show_letter = colorize('turquoise' if is_default else 'blue',
|
||||
show_letter)
|
||||
|
||||
# Insert the highlighted letter back into the word.
|
||||
capitalized.append(
|
||||
option[:index] + show_letter + option[index + 1:]
|
||||
)
|
||||
display_letters.append(found_letter.upper())
|
||||
|
||||
first = False
|
||||
|
||||
# The default is just the first option if unspecified.
|
||||
if require:
|
||||
default = None
|
||||
elif default is None:
|
||||
if numrange:
|
||||
default = numrange[0]
|
||||
else:
|
||||
default = display_letters[0].lower()
|
||||
|
||||
# Make a prompt if one is not provided.
|
||||
if not prompt:
|
||||
prompt_parts = []
|
||||
prompt_part_lengths = []
|
||||
if numrange:
|
||||
if isinstance(default, int):
|
||||
default_name = str(default)
|
||||
default_name = colorize('turquoise', default_name)
|
||||
tmpl = '# selection (default %s)'
|
||||
prompt_parts.append(tmpl % default_name)
|
||||
prompt_part_lengths.append(len(tmpl % str(default)))
|
||||
else:
|
||||
prompt_parts.append('# selection')
|
||||
prompt_part_lengths.append(len(prompt_parts[-1]))
|
||||
prompt_parts += capitalized
|
||||
prompt_part_lengths += [len(s) for s in options]
|
||||
|
||||
# Wrap the query text.
|
||||
prompt = ''
|
||||
line_length = 0
|
||||
for i, (part, length) in enumerate(zip(prompt_parts,
|
||||
prompt_part_lengths)):
|
||||
# Add punctuation.
|
||||
if i == len(prompt_parts) - 1:
|
||||
part += '?'
|
||||
else:
|
||||
part += ','
|
||||
length += 1
|
||||
|
||||
# Choose either the current line or the beginning of the next.
|
||||
if line_length + length + 1 > max_width:
|
||||
prompt += '\n'
|
||||
line_length = 0
|
||||
|
||||
if line_length != 0:
|
||||
# Not the beginning of the line; need a space.
|
||||
part = ' ' + part
|
||||
length += 1
|
||||
|
||||
prompt += part
|
||||
line_length += length
|
||||
|
||||
# Make a fallback prompt too. This is displayed if the user enters
|
||||
# something that is not recognized.
|
||||
if not fallback_prompt:
|
||||
fallback_prompt = 'Enter one of '
|
||||
if numrange:
|
||||
fallback_prompt += '%i-%i, ' % numrange
|
||||
fallback_prompt += ', '.join(display_letters) + ':'
|
||||
|
||||
resp = input_(prompt)
|
||||
while True:
|
||||
resp = resp.strip().lower()
|
||||
|
||||
# Try default option.
|
||||
if default is not None and not resp:
|
||||
resp = default
|
||||
|
||||
# Try an integer input if available.
|
||||
if numrange:
|
||||
try:
|
||||
resp = int(resp)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
low, high = numrange
|
||||
if low <= resp <= high:
|
||||
return resp
|
||||
else:
|
||||
resp = None
|
||||
|
||||
# Try a normal letter input.
|
||||
if resp:
|
||||
resp = resp[0]
|
||||
if resp in letters:
|
||||
return resp
|
||||
|
||||
# Prompt for new input.
|
||||
resp = input_(fallback_prompt)
|
||||
|
||||
|
||||
def input_yn(prompt, require=False):
|
||||
"""Prompts the user for a "yes" or "no" response. The default is
|
||||
"yes" unless `require` is `True`, in which case there is no default.
|
||||
"""
|
||||
sel = input_options(
|
||||
('y', 'n'), require, prompt, 'Enter Y or N:'
|
||||
)
|
||||
return sel == 'y'
|
||||
|
||||
|
||||
def human_bytes(size):
|
||||
"""Formats size, a number of bytes, in a human-readable way."""
|
||||
suffices = ['B', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB', 'HB']
|
||||
for suffix in suffices:
|
||||
if size < 1024:
|
||||
return "%3.1f %s" % (size, suffix)
|
||||
size /= 1024.0
|
||||
return "big"
|
||||
|
||||
|
||||
def human_seconds(interval):
|
||||
"""Formats interval, a number of seconds, as a human-readable time
|
||||
interval using English words.
|
||||
"""
|
||||
units = [
|
||||
(1, 'second'),
|
||||
(60, 'minute'),
|
||||
(60, 'hour'),
|
||||
(24, 'day'),
|
||||
(7, 'week'),
|
||||
(52, 'year'),
|
||||
(10, 'decade'),
|
||||
]
|
||||
for i in range(len(units) - 1):
|
||||
increment, suffix = units[i]
|
||||
next_increment, _ = units[i + 1]
|
||||
interval /= float(increment)
|
||||
if interval < next_increment:
|
||||
break
|
||||
else:
|
||||
# Last unit.
|
||||
increment, suffix = units[-1]
|
||||
interval /= float(increment)
|
||||
|
||||
return "%3.1f %ss" % (interval, suffix)
|
||||
|
||||
|
||||
def human_seconds_short(interval):
|
||||
"""Formats a number of seconds as a short human-readable M:SS
|
||||
string.
|
||||
"""
|
||||
interval = int(interval)
|
||||
return u'%i:%02i' % (interval // 60, interval % 60)
|
||||
|
||||
|
||||
# ANSI terminal colorization code heavily inspired by pygments:
|
||||
# http://dev.pocoo.org/hg/pygments-main/file/b2deea5b5030/pygments/console.py
|
||||
# (pygments is by Tim Hatch, Armin Ronacher, et al.)
|
||||
COLOR_ESCAPE = "\x1b["
|
||||
DARK_COLORS = ["black", "darkred", "darkgreen", "brown", "darkblue",
|
||||
"purple", "teal", "lightgray"]
|
||||
LIGHT_COLORS = ["darkgray", "red", "green", "yellow", "blue",
|
||||
"fuchsia", "turquoise", "white"]
|
||||
RESET_COLOR = COLOR_ESCAPE + "39;49;00m"
|
||||
|
||||
|
||||
def _colorize(color, text):
|
||||
"""Returns a string that prints the given text in the given color
|
||||
in a terminal that is ANSI color-aware. The color must be something
|
||||
in DARK_COLORS or LIGHT_COLORS.
|
||||
"""
|
||||
if color in DARK_COLORS:
|
||||
escape = COLOR_ESCAPE + "%im" % (DARK_COLORS.index(color) + 30)
|
||||
elif color in LIGHT_COLORS:
|
||||
escape = COLOR_ESCAPE + "%i;01m" % (LIGHT_COLORS.index(color) + 30)
|
||||
else:
|
||||
raise ValueError('no such color %s', color)
|
||||
return escape + text + RESET_COLOR
|
||||
|
||||
|
||||
def colorize(color, text):
|
||||
"""Colorize text if colored output is enabled. (Like _colorize but
|
||||
conditional.)
|
||||
"""
|
||||
if config['color']:
|
||||
return _colorize(color, text)
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
def _colordiff(a, b, highlight='red', minor_highlight='lightgray'):
|
||||
"""Given two values, return the same pair of strings except with
|
||||
their differences highlighted in the specified color. Strings are
|
||||
highlighted intelligently to show differences; other values are
|
||||
stringified and highlighted in their entirety.
|
||||
"""
|
||||
if not isinstance(a, basestring) or not isinstance(b, basestring):
|
||||
# Non-strings: use ordinary equality.
|
||||
a = unicode(a)
|
||||
b = unicode(b)
|
||||
if a == b:
|
||||
return a, b
|
||||
else:
|
||||
return colorize(highlight, a), colorize(highlight, b)
|
||||
|
||||
if isinstance(a, bytes) or isinstance(b, bytes):
|
||||
# A path field.
|
||||
a = util.displayable_path(a)
|
||||
b = util.displayable_path(b)
|
||||
|
||||
a_out = []
|
||||
b_out = []
|
||||
|
||||
matcher = SequenceMatcher(lambda x: False, a, b)
|
||||
for op, a_start, a_end, b_start, b_end in matcher.get_opcodes():
|
||||
if op == 'equal':
|
||||
# In both strings.
|
||||
a_out.append(a[a_start:a_end])
|
||||
b_out.append(b[b_start:b_end])
|
||||
elif op == 'insert':
|
||||
# Right only.
|
||||
b_out.append(colorize(highlight, b[b_start:b_end]))
|
||||
elif op == 'delete':
|
||||
# Left only.
|
||||
a_out.append(colorize(highlight, a[a_start:a_end]))
|
||||
elif op == 'replace':
|
||||
# Right and left differ. Colorise with second highlight if
|
||||
# it's just a case change.
|
||||
if a[a_start:a_end].lower() != b[b_start:b_end].lower():
|
||||
color = highlight
|
||||
else:
|
||||
color = minor_highlight
|
||||
a_out.append(colorize(color, a[a_start:a_end]))
|
||||
b_out.append(colorize(color, b[b_start:b_end]))
|
||||
else:
|
||||
assert(False)
|
||||
|
||||
return u''.join(a_out), u''.join(b_out)
|
||||
|
||||
|
||||
def colordiff(a, b, highlight='red'):
|
||||
"""Colorize differences between two values if color is enabled.
|
||||
(Like _colordiff but conditional.)
|
||||
"""
|
||||
if config['color']:
|
||||
return _colordiff(a, b, highlight)
|
||||
else:
|
||||
return unicode(a), unicode(b)
|
||||
|
||||
|
||||
def get_path_formats(subview=None):
|
||||
"""Get the configuration's path formats as a list of query/template
|
||||
pairs.
|
||||
"""
|
||||
path_formats = []
|
||||
subview = subview or config['paths']
|
||||
for query, view in subview.items():
|
||||
query = PF_KEY_QUERIES.get(query, query) # Expand common queries.
|
||||
path_formats.append((query, Template(view.get(unicode))))
|
||||
return path_formats
|
||||
|
||||
|
||||
def get_replacements():
|
||||
"""Confit validation function that reads regex/string pairs.
|
||||
"""
|
||||
replacements = []
|
||||
for pattern, repl in config['replace'].get(dict).items():
|
||||
repl = repl or ''
|
||||
try:
|
||||
replacements.append((re.compile(pattern), repl))
|
||||
except re.error:
|
||||
raise UserError(
|
||||
u'malformed regular expression in replace: {0}'.format(
|
||||
pattern
|
||||
)
|
||||
)
|
||||
return replacements
|
||||
|
||||
|
||||
def _pick_format(album, fmt=None):
|
||||
"""Pick a format string for printing Album or Item objects,
|
||||
falling back to config options and defaults.
|
||||
"""
|
||||
if fmt:
|
||||
return fmt
|
||||
if album:
|
||||
return config['list_format_album'].get(unicode)
|
||||
else:
|
||||
return config['list_format_item'].get(unicode)
|
||||
|
||||
|
||||
def print_obj(obj, lib, fmt=None):
|
||||
"""Print an Album or Item object. If `fmt` is specified, use that
|
||||
format string. Otherwise, use the configured template.
|
||||
"""
|
||||
album = isinstance(obj, library.Album)
|
||||
fmt = _pick_format(album, fmt)
|
||||
if isinstance(fmt, Template):
|
||||
template = fmt
|
||||
else:
|
||||
template = Template(fmt)
|
||||
print_(obj.evaluate_template(template))
|
||||
|
||||
|
||||
def term_width():
|
||||
"""Get the width (columns) of the terminal."""
|
||||
fallback = config['ui']['terminal_width'].get(int)
|
||||
|
||||
# The fcntl and termios modules are not available on non-Unix
|
||||
# platforms, so we fall back to a constant.
|
||||
try:
|
||||
import fcntl
|
||||
import termios
|
||||
except ImportError:
|
||||
return fallback
|
||||
|
||||
try:
|
||||
buf = fcntl.ioctl(0, termios.TIOCGWINSZ, ' ' * 4)
|
||||
except IOError:
|
||||
return fallback
|
||||
try:
|
||||
height, width = struct.unpack('hh', buf)
|
||||
except struct.error:
|
||||
return fallback
|
||||
return width
|
||||
|
||||
|
||||
FLOAT_EPSILON = 0.01
|
||||
|
||||
|
||||
def _field_diff(field, old, new):
|
||||
"""Given two Model objects, format their values for `field` and
|
||||
highlight changes among them. Return a human-readable string. If the
|
||||
value has not changed, return None instead.
|
||||
"""
|
||||
oldval = old.get(field)
|
||||
newval = new.get(field)
|
||||
|
||||
# If no change, abort.
|
||||
if isinstance(oldval, float) and isinstance(newval, float) and \
|
||||
abs(oldval - newval) < FLOAT_EPSILON:
|
||||
return None
|
||||
elif oldval == newval:
|
||||
return None
|
||||
|
||||
# Get formatted values for output.
|
||||
oldstr = old.formatted().get(field, u'')
|
||||
newstr = new.formatted().get(field, u'')
|
||||
|
||||
# For strings, highlight changes. For others, colorize the whole
|
||||
# thing.
|
||||
if isinstance(oldval, basestring):
|
||||
oldstr, newstr = colordiff(oldval, newstr)
|
||||
else:
|
||||
oldstr, newstr = colorize('red', oldstr), colorize('red', newstr)
|
||||
|
||||
return u'{0} -> {1}'.format(oldstr, newstr)
|
||||
|
||||
|
||||
def show_model_changes(new, old=None, fields=None, always=False):
|
||||
"""Given a Model object, print a list of changes from its pristine
|
||||
version stored in the database. Return a boolean indicating whether
|
||||
any changes were found.
|
||||
|
||||
`old` may be the "original" object to avoid using the pristine
|
||||
version from the database. `fields` may be a list of fields to
|
||||
restrict the detection to. `always` indicates whether the object is
|
||||
always identified, regardless of whether any changes are present.
|
||||
"""
|
||||
old = old or new._db._get(type(new), new.id)
|
||||
|
||||
# Build up lines showing changed fields.
|
||||
changes = []
|
||||
for field in old:
|
||||
# Subset of the fields. Never show mtime.
|
||||
if field == 'mtime' or (fields and field not in fields):
|
||||
continue
|
||||
|
||||
# Detect and show difference for this field.
|
||||
line = _field_diff(field, old, new)
|
||||
if line:
|
||||
changes.append(u' {0}: {1}'.format(field, line))
|
||||
|
||||
# New fields.
|
||||
for field in set(new) - set(old):
|
||||
if fields and field not in fields:
|
||||
continue
|
||||
|
||||
changes.append(u' {0}: {1}'.format(
|
||||
field,
|
||||
colorize('red', new.formatted()[field])
|
||||
))
|
||||
|
||||
# Print changes.
|
||||
if changes or always:
|
||||
print_obj(old, old._db)
|
||||
if changes:
|
||||
print_(u'\n'.join(changes))
|
||||
|
||||
return bool(changes)
|
||||
|
||||
|
||||
# Subcommand parsing infrastructure.
|
||||
#
|
||||
# This is a fairly generic subcommand parser for optparse. It is
|
||||
# maintained externally here:
|
||||
# http://gist.github.com/462717
|
||||
# There you will also find a better description of the code and a more
|
||||
# succinct example program.
|
||||
|
||||
class Subcommand(object):
|
||||
"""A subcommand of a root command-line application that may be
|
||||
invoked by a SubcommandOptionParser.
|
||||
"""
|
||||
def __init__(self, name, parser=None, help='', aliases=(), hide=False):
|
||||
"""Creates a new subcommand. name is the primary way to invoke
|
||||
the subcommand; aliases are alternate names. parser is an
|
||||
OptionParser responsible for parsing the subcommand's options.
|
||||
help is a short description of the command. If no parser is
|
||||
given, it defaults to a new, empty OptionParser.
|
||||
"""
|
||||
self.name = name
|
||||
self.parser = parser or optparse.OptionParser()
|
||||
self.aliases = aliases
|
||||
self.help = help
|
||||
self.hide = hide
|
||||
self._root_parser = None
|
||||
|
||||
def print_help(self):
|
||||
self.parser.print_help()
|
||||
|
||||
def parse_args(self, args):
|
||||
return self.parser.parse_args(args)
|
||||
|
||||
@property
|
||||
def root_parser(self):
|
||||
return self._root_parser
|
||||
|
||||
@root_parser.setter
|
||||
def root_parser(self, root_parser):
|
||||
self._root_parser = root_parser
|
||||
self.parser.prog = '{0} {1}'.format(root_parser.get_prog_name(),
|
||||
self.name)
|
||||
|
||||
|
||||
class SubcommandsOptionParser(optparse.OptionParser):
|
||||
"""A variant of OptionParser that parses subcommands and their
|
||||
arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Create a new subcommand-aware option parser. All of the
|
||||
options to OptionParser.__init__ are supported in addition
|
||||
to subcommands, a sequence of Subcommand objects.
|
||||
"""
|
||||
# A more helpful default usage.
|
||||
if 'usage' not in kwargs:
|
||||
kwargs['usage'] = """
|
||||
%prog COMMAND [ARGS...]
|
||||
%prog help COMMAND"""
|
||||
kwargs['add_help_option'] = False
|
||||
|
||||
# Super constructor.
|
||||
optparse.OptionParser.__init__(self, *args, **kwargs)
|
||||
|
||||
# Our root parser needs to stop on the first unrecognized argument.
|
||||
self.disable_interspersed_args()
|
||||
|
||||
self.subcommands = []
|
||||
|
||||
def add_subcommand(self, *cmds):
|
||||
"""Adds a Subcommand object to the parser's list of commands.
|
||||
"""
|
||||
for cmd in cmds:
|
||||
cmd.root_parser = self
|
||||
self.subcommands.append(cmd)
|
||||
|
||||
# Add the list of subcommands to the help message.
|
||||
def format_help(self, formatter=None):
|
||||
# Get the original help message, to which we will append.
|
||||
out = optparse.OptionParser.format_help(self, formatter)
|
||||
if formatter is None:
|
||||
formatter = self.formatter
|
||||
|
||||
# Subcommands header.
|
||||
result = ["\n"]
|
||||
result.append(formatter.format_heading('Commands'))
|
||||
formatter.indent()
|
||||
|
||||
# Generate the display names (including aliases).
|
||||
# Also determine the help position.
|
||||
disp_names = []
|
||||
help_position = 0
|
||||
subcommands = [c for c in self.subcommands if not c.hide]
|
||||
subcommands.sort(key=lambda c: c.name)
|
||||
for subcommand in subcommands:
|
||||
name = subcommand.name
|
||||
if subcommand.aliases:
|
||||
name += ' (%s)' % ', '.join(subcommand.aliases)
|
||||
disp_names.append(name)
|
||||
|
||||
# Set the help position based on the max width.
|
||||
proposed_help_position = len(name) + formatter.current_indent + 2
|
||||
if proposed_help_position <= formatter.max_help_position:
|
||||
help_position = max(help_position, proposed_help_position)
|
||||
|
||||
# Add each subcommand to the output.
|
||||
for subcommand, name in zip(subcommands, disp_names):
|
||||
# Lifted directly from optparse.py.
|
||||
name_width = help_position - formatter.current_indent - 2
|
||||
if len(name) > name_width:
|
||||
name = "%*s%s\n" % (formatter.current_indent, "", name)
|
||||
indent_first = help_position
|
||||
else:
|
||||
name = "%*s%-*s " % (formatter.current_indent, "",
|
||||
name_width, name)
|
||||
indent_first = 0
|
||||
result.append(name)
|
||||
help_width = formatter.width - help_position
|
||||
help_lines = textwrap.wrap(subcommand.help, help_width)
|
||||
result.append("%*s%s\n" % (indent_first, "", help_lines[0]))
|
||||
result.extend(["%*s%s\n" % (help_position, "", line)
|
||||
for line in help_lines[1:]])
|
||||
formatter.dedent()
|
||||
|
||||
# Concatenate the original help message with the subcommand
|
||||
# list.
|
||||
return out + "".join(result)
|
||||
|
||||
def _subcommand_for_name(self, name):
|
||||
"""Return the subcommand in self.subcommands matching the
|
||||
given name. The name may either be the name of a subcommand or
|
||||
an alias. If no subcommand matches, returns None.
|
||||
"""
|
||||
for subcommand in self.subcommands:
|
||||
if name == subcommand.name or \
|
||||
name in subcommand.aliases:
|
||||
return subcommand
|
||||
return None
|
||||
|
||||
def parse_global_options(self, args):
|
||||
"""Parse options up to the subcommand argument. Returns a tuple
|
||||
of the options object and the remaining arguments.
|
||||
"""
|
||||
options, subargs = self.parse_args(args)
|
||||
|
||||
# Force the help command
|
||||
if options.help:
|
||||
subargs = ['help']
|
||||
elif options.version:
|
||||
subargs = ['version']
|
||||
return options, subargs
|
||||
|
||||
def parse_subcommand(self, args):
|
||||
"""Given the `args` left unused by a `parse_global_options`,
|
||||
return the invoked subcommand, the subcommand options, and the
|
||||
subcommand arguments.
|
||||
"""
|
||||
# Help is default command
|
||||
if not args:
|
||||
args = ['help']
|
||||
|
||||
cmdname = args.pop(0)
|
||||
subcommand = self._subcommand_for_name(cmdname)
|
||||
if not subcommand:
|
||||
raise UserError("unknown command '{0}'".format(cmdname))
|
||||
|
||||
suboptions, subargs = subcommand.parse_args(args)
|
||||
return subcommand, suboptions, subargs
|
||||
|
||||
|
||||
optparse.Option.ALWAYS_TYPED_ACTIONS += ('callback',)
|
||||
|
||||
|
||||
def vararg_callback(option, opt_str, value, parser):
|
||||
"""Callback for an option with variable arguments.
|
||||
Manually collect arguments right of a callback-action
|
||||
option (ie. with action="callback"), and add the resulting
|
||||
list to the destination var.
|
||||
|
||||
Usage:
|
||||
parser.add_option("-c", "--callback", dest="vararg_attr",
|
||||
action="callback", callback=vararg_callback)
|
||||
|
||||
Details:
|
||||
http://docs.python.org/2/library/optparse.html#callback-example-6-variable
|
||||
-arguments
|
||||
"""
|
||||
value = [value]
|
||||
|
||||
def floatable(str):
|
||||
try:
|
||||
float(str)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
for arg in parser.rargs:
|
||||
# stop on --foo like options
|
||||
if arg[:2] == "--" and len(arg) > 2:
|
||||
break
|
||||
# stop on -a, but not on -3 or -3.0
|
||||
if arg[:1] == "-" and len(arg) > 1 and not floatable(arg):
|
||||
break
|
||||
value.append(arg)
|
||||
|
||||
del parser.rargs[:len(value) - 1]
|
||||
setattr(parser.values, option.dest, value)
|
||||
|
||||
|
||||
# The main entry point and bootstrapping.
|
||||
|
||||
def _load_plugins(config):
|
||||
"""Load the plugins specified in the configuration.
|
||||
"""
|
||||
paths = config['pluginpath'].get(confit.StrSeq(split=False))
|
||||
paths = map(util.normpath, paths)
|
||||
|
||||
import beetsplug
|
||||
beetsplug.__path__ = paths + beetsplug.__path__
|
||||
# For backwards compatibility.
|
||||
sys.path += paths
|
||||
|
||||
plugins.load_plugins(config['plugins'].as_str_seq())
|
||||
plugins.send("pluginload")
|
||||
return plugins
|
||||
|
||||
|
||||
def _setup(options, lib=None):
|
||||
"""Prepare and global state and updates it with command line options.
|
||||
|
||||
Returns a list of subcommands, a list of plugins, and a library instance.
|
||||
"""
|
||||
# Configure the MusicBrainz API.
|
||||
mb.configure()
|
||||
|
||||
config = _configure(options)
|
||||
|
||||
plugins = _load_plugins(config)
|
||||
|
||||
# Get the default subcommands.
|
||||
from beets.ui.commands import default_commands
|
||||
|
||||
subcommands = list(default_commands)
|
||||
subcommands.extend(plugins.commands())
|
||||
|
||||
if lib is None:
|
||||
lib = _open_library(config)
|
||||
plugins.send("library_opened", lib=lib)
|
||||
library.Item._types = plugins.types(library.Item)
|
||||
library.Album._types = plugins.types(library.Album)
|
||||
|
||||
return subcommands, plugins, lib
|
||||
|
||||
|
||||
def _configure(options):
|
||||
"""Amend the global configuration object with command line options.
|
||||
"""
|
||||
# Add any additional config files specified with --config. This
|
||||
# special handling lets specified plugins get loaded before we
|
||||
# finish parsing the command line.
|
||||
if getattr(options, 'config', None) is not None:
|
||||
config_path = options.config
|
||||
del options.config
|
||||
config.set_file(config_path)
|
||||
config.set_args(options)
|
||||
|
||||
# Configure the logger.
|
||||
if config['verbose'].get(bool):
|
||||
log.setLevel(logging.DEBUG)
|
||||
else:
|
||||
log.setLevel(logging.INFO)
|
||||
|
||||
config_path = config.user_config_path()
|
||||
if os.path.isfile(config_path):
|
||||
log.debug(u'user configuration: {0}'.format(
|
||||
util.displayable_path(config_path)))
|
||||
else:
|
||||
log.debug(u'no user configuration found at {0}'.format(
|
||||
util.displayable_path(config_path)))
|
||||
|
||||
log.debug(u'data directory: {0}'
|
||||
.format(util.displayable_path(config.config_dir())))
|
||||
return config
|
||||
|
||||
|
||||
def _open_library(config):
|
||||
"""Create a new library instance from the configuration.
|
||||
"""
|
||||
dbpath = config['library'].as_filename()
|
||||
try:
|
||||
lib = library.Library(
|
||||
dbpath,
|
||||
config['directory'].as_filename(),
|
||||
get_path_formats(),
|
||||
get_replacements(),
|
||||
)
|
||||
lib.get_item(0) # Test database connection.
|
||||
except (sqlite3.OperationalError, sqlite3.DatabaseError):
|
||||
log.debug(traceback.format_exc())
|
||||
raise UserError(u"database file {0} could not be opened".format(
|
||||
util.displayable_path(dbpath)
|
||||
))
|
||||
log.debug(u'library database: {0}\n'
|
||||
u'library directory: {1}'
|
||||
.format(util.displayable_path(lib.path),
|
||||
util.displayable_path(lib.directory)))
|
||||
return lib
|
||||
|
||||
|
||||
def _raw_main(args, lib=None):
|
||||
"""A helper function for `main` without top-level exception
|
||||
handling.
|
||||
"""
|
||||
parser = SubcommandsOptionParser()
|
||||
parser.add_option('-l', '--library', dest='library',
|
||||
help='library database file to use')
|
||||
parser.add_option('-d', '--directory', dest='directory',
|
||||
help="destination music directory")
|
||||
parser.add_option('-v', '--verbose', dest='verbose', action='store_true',
|
||||
help='print debugging information')
|
||||
parser.add_option('-c', '--config', dest='config',
|
||||
help='path to configuration file')
|
||||
parser.add_option('-h', '--help', dest='help', action='store_true',
|
||||
help='how this help message and exit')
|
||||
parser.add_option('--version', dest='version', action='store_true',
|
||||
help=optparse.SUPPRESS_HELP)
|
||||
|
||||
options, subargs = parser.parse_global_options(args)
|
||||
|
||||
# Special case for the `config --edit` command: bypass _setup so
|
||||
# that an invalid configuration does not prevent the editor from
|
||||
# starting.
|
||||
if subargs[0] == 'config' and ('-e' in subargs or '--edit' in subargs):
|
||||
from beets.ui.commands import config_edit
|
||||
return config_edit()
|
||||
|
||||
subcommands, plugins, lib = _setup(options, lib)
|
||||
parser.add_subcommand(*subcommands)
|
||||
|
||||
subcommand, suboptions, subargs = parser.parse_subcommand(subargs)
|
||||
subcommand.func(lib, suboptions, subargs)
|
||||
|
||||
plugins.send('cli_exit', lib=lib)
|
||||
|
||||
|
||||
def main(args=None):
|
||||
"""Run the main command-line interface for beets. Includes top-level
|
||||
exception handlers that print friendly error messages.
|
||||
"""
|
||||
try:
|
||||
_raw_main(args)
|
||||
except UserError as exc:
|
||||
message = exc.args[0] if exc.args else None
|
||||
log.error(u'error: {0}'.format(message))
|
||||
sys.exit(1)
|
||||
except util.HumanReadableException as exc:
|
||||
exc.log(log)
|
||||
sys.exit(1)
|
||||
except library.FileOperationError as exc:
|
||||
# These errors have reasonable human-readable descriptions, but
|
||||
# we still want to log their tracebacks for debugging.
|
||||
log.debug(traceback.format_exc())
|
||||
log.error(exc)
|
||||
sys.exit(1)
|
||||
except confit.ConfigError as exc:
|
||||
log.error(u'configuration error: {0}'.format(exc))
|
||||
sys.exit(1)
|
||||
except IOError as exc:
|
||||
if exc.errno == errno.EPIPE:
|
||||
# "Broken pipe". End silently.
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
except KeyboardInterrupt:
|
||||
# Silently ignore ^C except in verbose mode.
|
||||
log.debug(traceback.format_exc())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,162 @@
|
||||
# This file is part of beets.
|
||||
# Copyright (c) 2014, Thomas Scholtes.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
|
||||
|
||||
# Completion for the `beet` command
|
||||
# =================================
|
||||
#
|
||||
# Load this script to complete beets subcommands, options, and
|
||||
# queries.
|
||||
#
|
||||
# If a beets command is found on the command line it completes filenames and
|
||||
# the subcommand's options. Otherwise it will complete global options and
|
||||
# subcommands. If the previous option on the command line expects an argument,
|
||||
# it also completes filenames or directories. Options are only
|
||||
# completed if '-' has already been typed on the command line.
|
||||
#
|
||||
# Note that completion of plugin commands only works for those plugins
|
||||
# that were enabled when running `beet completion`. It does not check
|
||||
# plugins dynamically
|
||||
#
|
||||
# Currently, only Bash 3.2 and newer is supported and the
|
||||
# `bash-completion` package is requied.
|
||||
#
|
||||
# TODO
|
||||
# ----
|
||||
#
|
||||
# * There are some issues with arguments that are quoted on the command line.
|
||||
#
|
||||
# * Complete arguments for the `--format` option by expanding field variables.
|
||||
#
|
||||
# beet ls -f "$tit[TAB]
|
||||
# beet ls -f "$title
|
||||
#
|
||||
# * Support long options with `=`, e.g. `--config=file`. Debian's bash
|
||||
# completion package can handle this.
|
||||
#
|
||||
|
||||
|
||||
# Determines the beets subcommand and dispatches the completion
|
||||
# accordingly.
|
||||
_beet_dispatch() {
|
||||
local cur prev cmd=
|
||||
|
||||
COMPREPLY=()
|
||||
_get_comp_words_by_ref -n : cur prev
|
||||
|
||||
# Look for the beets subcommand
|
||||
local arg
|
||||
for (( i=1; i < COMP_CWORD; i++ )); do
|
||||
arg="${COMP_WORDS[i]}"
|
||||
if _list_include_item "${opts___global}" $arg; then
|
||||
((i++))
|
||||
elif [[ "$arg" != -* ]]; then
|
||||
cmd="$arg"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
# Replace command shortcuts
|
||||
if [[ -n $cmd ]] && _list_include_item "$aliases" "$cmd"; then
|
||||
eval "cmd=\$alias__$cmd"
|
||||
fi
|
||||
|
||||
case $cmd in
|
||||
help)
|
||||
COMPREPLY+=( $(compgen -W "$commands" -- $cur) )
|
||||
;;
|
||||
list|remove|move|update|write|stats)
|
||||
_beet_complete_query
|
||||
;;
|
||||
"")
|
||||
_beet_complete_global
|
||||
;;
|
||||
*)
|
||||
_beet_complete
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
|
||||
# Adds option and file completion to COMPREPLY for the subcommand $cmd
|
||||
_beet_complete() {
|
||||
if [[ $cur == -* ]]; then
|
||||
local opts flags completions
|
||||
eval "opts=\$opts__$cmd"
|
||||
eval "flags=\$flags__$cmd"
|
||||
completions="${flags___common} ${opts} ${flags}"
|
||||
COMPREPLY+=( $(compgen -W "$completions" -- $cur) )
|
||||
else
|
||||
_filedir
|
||||
fi
|
||||
}
|
||||
|
||||
|
||||
# Add global options and subcommands to the completion
|
||||
_beet_complete_global() {
|
||||
case $prev in
|
||||
-h|--help)
|
||||
# Complete commands
|
||||
COMPREPLY+=( $(compgen -W "$commands" -- $cur) )
|
||||
return
|
||||
;;
|
||||
-l|--library|-c|--config)
|
||||
# Filename completion
|
||||
_filedir
|
||||
return
|
||||
;;
|
||||
-d|--directory)
|
||||
# Directory completion
|
||||
_filedir -d
|
||||
return
|
||||
;;
|
||||
esac
|
||||
|
||||
if [[ $cur == -* ]]; then
|
||||
local completions="$opts___global $flags___global"
|
||||
COMPREPLY+=( $(compgen -W "$completions" -- $cur) )
|
||||
elif [[ -n $cur ]] && _list_include_item "$aliases" "$cur"; then
|
||||
local cmd
|
||||
eval "cmd=\$alias__$cur"
|
||||
COMPREPLY+=( "$cmd" )
|
||||
else
|
||||
COMPREPLY+=( $(compgen -W "$commands" -- $cur) )
|
||||
fi
|
||||
}
|
||||
|
||||
_beet_complete_query() {
|
||||
local opts
|
||||
eval "opts=\$opts__$cmd"
|
||||
|
||||
if [[ $cur == -* ]] || _list_include_item "$opts" "$prev"; then
|
||||
_beet_complete
|
||||
elif [[ $cur != \'* && $cur != \"* &&
|
||||
$cur != *:* ]]; then
|
||||
# Do not complete quoted queries or those who already have a field
|
||||
# set.
|
||||
compopt -o nospace
|
||||
COMPREPLY+=( $(compgen -S : -W "$fields" -- $cur) )
|
||||
return 0
|
||||
fi
|
||||
}
|
||||
|
||||
# Returns true if the space separated list $1 includes $2
|
||||
_list_include_item() {
|
||||
[[ " $1 " == *[[:space:]]$2[[:space:]]* ]]
|
||||
}
|
||||
|
||||
# This is where beets dynamically adds the _beet function. This
|
||||
# function sets the variables $flags, $opts, $commands, and $aliases.
|
||||
complete -o filenames -F _beet beet
|
||||
@@ -0,0 +1,686 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Miscellaneous utility functions."""
|
||||
from __future__ import division
|
||||
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import shutil
|
||||
import fnmatch
|
||||
from collections import defaultdict
|
||||
import traceback
|
||||
import subprocess
|
||||
import platform
|
||||
|
||||
|
||||
MAX_FILENAME_LENGTH = 200
|
||||
WINDOWS_MAGIC_PREFIX = u'\\\\?\\'
|
||||
|
||||
|
||||
class HumanReadableException(Exception):
|
||||
"""An Exception that can include a human-readable error message to
|
||||
be logged without a traceback. Can preserve a traceback for
|
||||
debugging purposes as well.
|
||||
|
||||
Has at least two fields: `reason`, the underlying exception or a
|
||||
string describing the problem; and `verb`, the action being
|
||||
performed during the error.
|
||||
|
||||
If `tb` is provided, it is a string containing a traceback for the
|
||||
associated exception. (Note that this is not necessary in Python 3.x
|
||||
and should be removed when we make the transition.)
|
||||
"""
|
||||
error_kind = 'Error' # Human-readable description of error type.
|
||||
|
||||
def __init__(self, reason, verb, tb=None):
|
||||
self.reason = reason
|
||||
self.verb = verb
|
||||
self.tb = tb
|
||||
super(HumanReadableException, self).__init__(self.get_message())
|
||||
|
||||
def _gerund(self):
|
||||
"""Generate a (likely) gerund form of the English verb.
|
||||
"""
|
||||
if ' ' in self.verb:
|
||||
return self.verb
|
||||
gerund = self.verb[:-1] if self.verb.endswith('e') else self.verb
|
||||
gerund += 'ing'
|
||||
return gerund
|
||||
|
||||
def _reasonstr(self):
|
||||
"""Get the reason as a string."""
|
||||
if isinstance(self.reason, unicode):
|
||||
return self.reason
|
||||
elif isinstance(self.reason, basestring): # Byte string.
|
||||
return self.reason.decode('utf8', 'ignore')
|
||||
elif hasattr(self.reason, 'strerror'): # i.e., EnvironmentError
|
||||
return self.reason.strerror
|
||||
else:
|
||||
return u'"{0}"'.format(unicode(self.reason))
|
||||
|
||||
def get_message(self):
|
||||
"""Create the human-readable description of the error, sans
|
||||
introduction.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, logger):
|
||||
"""Log to the provided `logger` a human-readable message as an
|
||||
error and a verbose traceback as a debug message.
|
||||
"""
|
||||
if self.tb:
|
||||
logger.debug(self.tb)
|
||||
logger.error(u'{0}: {1}'.format(self.error_kind, self.args[0]))
|
||||
|
||||
|
||||
class FilesystemError(HumanReadableException):
|
||||
"""An error that occurred while performing a filesystem manipulation
|
||||
via a function in this module. The `paths` field is a sequence of
|
||||
pathnames involved in the operation.
|
||||
"""
|
||||
def __init__(self, reason, verb, paths, tb=None):
|
||||
self.paths = paths
|
||||
super(FilesystemError, self).__init__(reason, verb, tb)
|
||||
|
||||
def get_message(self):
|
||||
# Use a nicer English phrasing for some specific verbs.
|
||||
if self.verb in ('move', 'copy', 'rename'):
|
||||
clause = u'while {0} {1} to {2}'.format(
|
||||
self._gerund(),
|
||||
displayable_path(self.paths[0]),
|
||||
displayable_path(self.paths[1])
|
||||
)
|
||||
elif self.verb in ('delete', 'write', 'create', 'read'):
|
||||
clause = u'while {0} {1}'.format(
|
||||
self._gerund(),
|
||||
displayable_path(self.paths[0])
|
||||
)
|
||||
else:
|
||||
clause = u'during {0} of paths {1}'.format(
|
||||
self.verb, u', '.join(displayable_path(p) for p in self.paths)
|
||||
)
|
||||
|
||||
return u'{0} {1}'.format(self._reasonstr(), clause)
|
||||
|
||||
|
||||
def normpath(path):
|
||||
"""Provide the canonical form of the path suitable for storing in
|
||||
the database.
|
||||
"""
|
||||
path = syspath(path, prefix=False)
|
||||
path = os.path.normpath(os.path.abspath(os.path.expanduser(path)))
|
||||
return bytestring_path(path)
|
||||
|
||||
|
||||
def ancestry(path):
|
||||
"""Return a list consisting of path's parent directory, its
|
||||
grandparent, and so on. For instance:
|
||||
|
||||
>>> ancestry('/a/b/c')
|
||||
['/', '/a', '/a/b']
|
||||
|
||||
The argument should *not* be the result of a call to `syspath`.
|
||||
"""
|
||||
out = []
|
||||
last_path = None
|
||||
while path:
|
||||
path = os.path.dirname(path)
|
||||
|
||||
if path == last_path:
|
||||
break
|
||||
last_path = path
|
||||
|
||||
if path:
|
||||
# don't yield ''
|
||||
out.insert(0, path)
|
||||
return out
|
||||
|
||||
|
||||
def sorted_walk(path, ignore=(), logger=None):
|
||||
"""Like `os.walk`, but yields things in case-insensitive sorted,
|
||||
breadth-first order. Directory and file names matching any glob
|
||||
pattern in `ignore` are skipped. If `logger` is provided, then
|
||||
warning messages are logged there when a directory cannot be listed.
|
||||
"""
|
||||
# Make sure the path isn't a Unicode string.
|
||||
path = bytestring_path(path)
|
||||
|
||||
# Get all the directories and files at this level.
|
||||
try:
|
||||
contents = os.listdir(syspath(path))
|
||||
except OSError as exc:
|
||||
if logger:
|
||||
logger.warn(u'could not list directory {0}: {1}'.format(
|
||||
displayable_path(path), exc.strerror
|
||||
))
|
||||
return
|
||||
dirs = []
|
||||
files = []
|
||||
for base in contents:
|
||||
base = bytestring_path(base)
|
||||
|
||||
# Skip ignored filenames.
|
||||
skip = False
|
||||
for pat in ignore:
|
||||
if fnmatch.fnmatch(base, pat):
|
||||
skip = True
|
||||
break
|
||||
if skip:
|
||||
continue
|
||||
|
||||
# Add to output as either a file or a directory.
|
||||
cur = os.path.join(path, base)
|
||||
if os.path.isdir(syspath(cur)):
|
||||
dirs.append(base)
|
||||
else:
|
||||
files.append(base)
|
||||
|
||||
# Sort lists (case-insensitive) and yield the current level.
|
||||
dirs.sort(key=bytes.lower)
|
||||
files.sort(key=bytes.lower)
|
||||
yield (path, dirs, files)
|
||||
|
||||
# Recurse into directories.
|
||||
for base in dirs:
|
||||
cur = os.path.join(path, base)
|
||||
# yield from sorted_walk(...)
|
||||
for res in sorted_walk(cur, ignore, logger):
|
||||
yield res
|
||||
|
||||
|
||||
def mkdirall(path):
|
||||
"""Make all the enclosing directories of path (like mkdir -p on the
|
||||
parent).
|
||||
"""
|
||||
for ancestor in ancestry(path):
|
||||
if not os.path.isdir(syspath(ancestor)):
|
||||
try:
|
||||
os.mkdir(syspath(ancestor))
|
||||
except (OSError, IOError) as exc:
|
||||
raise FilesystemError(exc, 'create', (ancestor,),
|
||||
traceback.format_exc())
|
||||
|
||||
|
||||
def fnmatch_all(names, patterns):
|
||||
"""Determine whether all strings in `names` match at least one of
|
||||
the `patterns`, which should be shell glob expressions.
|
||||
"""
|
||||
for name in names:
|
||||
matches = False
|
||||
for pattern in patterns:
|
||||
matches = fnmatch.fnmatch(name, pattern)
|
||||
if matches:
|
||||
break
|
||||
if not matches:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
|
||||
"""If path is an empty directory, then remove it. Recursively remove
|
||||
path's ancestry up to root (which is never removed) where there are
|
||||
empty directories. If path is not contained in root, then nothing is
|
||||
removed. Glob patterns in clutter are ignored when determining
|
||||
emptiness. If root is not provided, then only path may be removed
|
||||
(i.e., no recursive removal).
|
||||
"""
|
||||
path = normpath(path)
|
||||
if root is not None:
|
||||
root = normpath(root)
|
||||
|
||||
ancestors = ancestry(path)
|
||||
if root is None:
|
||||
# Only remove the top directory.
|
||||
ancestors = []
|
||||
elif root in ancestors:
|
||||
# Only remove directories below the root.
|
||||
ancestors = ancestors[ancestors.index(root) + 1:]
|
||||
else:
|
||||
# Remove nothing.
|
||||
return
|
||||
|
||||
# Traverse upward from path.
|
||||
ancestors.append(path)
|
||||
ancestors.reverse()
|
||||
for directory in ancestors:
|
||||
directory = syspath(directory)
|
||||
if not os.path.exists(directory):
|
||||
# Directory gone already.
|
||||
continue
|
||||
if fnmatch_all(os.listdir(directory), clutter):
|
||||
# Directory contains only clutter (or nothing).
|
||||
try:
|
||||
shutil.rmtree(directory)
|
||||
except OSError:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def components(path):
|
||||
"""Return a list of the path components in path. For instance:
|
||||
|
||||
>>> components('/a/b/c')
|
||||
['a', 'b', 'c']
|
||||
|
||||
The argument should *not* be the result of a call to `syspath`.
|
||||
"""
|
||||
comps = []
|
||||
ances = ancestry(path)
|
||||
for anc in ances:
|
||||
comp = os.path.basename(anc)
|
||||
if comp:
|
||||
comps.append(comp)
|
||||
else: # root
|
||||
comps.append(anc)
|
||||
|
||||
last = os.path.basename(path)
|
||||
if last:
|
||||
comps.append(last)
|
||||
|
||||
return comps
|
||||
|
||||
|
||||
def _fsencoding():
|
||||
"""Get the system's filesystem encoding. On Windows, this is always
|
||||
UTF-8 (not MBCS).
|
||||
"""
|
||||
encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
|
||||
if encoding == 'mbcs':
|
||||
# On Windows, a broken encoding known to Python as "MBCS" is
|
||||
# used for the filesystem. However, we only use the Unicode API
|
||||
# for Windows paths, so the encoding is actually immaterial so
|
||||
# we can avoid dealing with this nastiness. We arbitrarily
|
||||
# choose UTF-8.
|
||||
encoding = 'utf8'
|
||||
return encoding
|
||||
|
||||
|
||||
def bytestring_path(path):
|
||||
"""Given a path, which is either a str or a unicode, returns a str
|
||||
path (ensuring that we never deal with Unicode pathnames).
|
||||
"""
|
||||
# Pass through bytestrings.
|
||||
if isinstance(path, str):
|
||||
return path
|
||||
|
||||
# On Windows, remove the magic prefix added by `syspath`. This makes
|
||||
# ``bytestring_path(syspath(X)) == X``, i.e., we can safely
|
||||
# round-trip through `syspath`.
|
||||
if os.path.__name__ == 'ntpath' and path.startswith(WINDOWS_MAGIC_PREFIX):
|
||||
path = path[len(WINDOWS_MAGIC_PREFIX):]
|
||||
|
||||
# Try to encode with default encodings, but fall back to UTF8.
|
||||
try:
|
||||
return path.encode(_fsencoding())
|
||||
except (UnicodeError, LookupError):
|
||||
return path.encode('utf8')
|
||||
|
||||
|
||||
def displayable_path(path, separator=u'; '):
|
||||
"""Attempts to decode a bytestring path to a unicode object for the
|
||||
purpose of displaying it to the user. If the `path` argument is a
|
||||
list or a tuple, the elements are joined with `separator`.
|
||||
"""
|
||||
if isinstance(path, (list, tuple)):
|
||||
return separator.join(displayable_path(p) for p in path)
|
||||
elif isinstance(path, unicode):
|
||||
return path
|
||||
elif not isinstance(path, str):
|
||||
# A non-string object: just get its unicode representation.
|
||||
return unicode(path)
|
||||
|
||||
try:
|
||||
return path.decode(_fsencoding(), 'ignore')
|
||||
except (UnicodeError, LookupError):
|
||||
return path.decode('utf8', 'ignore')
|
||||
|
||||
|
||||
def syspath(path, prefix=True):
|
||||
"""Convert a path for use by the operating system. In particular,
|
||||
paths on Windows must receive a magic prefix and must be converted
|
||||
to Unicode before they are sent to the OS. To disable the magic
|
||||
prefix on Windows, set `prefix` to False---but only do this if you
|
||||
*really* know what you're doing.
|
||||
"""
|
||||
# Don't do anything if we're not on windows
|
||||
if os.path.__name__ != 'ntpath':
|
||||
return path
|
||||
|
||||
if not isinstance(path, unicode):
|
||||
# Beets currently represents Windows paths internally with UTF-8
|
||||
# arbitrarily. But earlier versions used MBCS because it is
|
||||
# reported as the FS encoding by Windows. Try both.
|
||||
try:
|
||||
path = path.decode('utf8')
|
||||
except UnicodeError:
|
||||
# The encoding should always be MBCS, Windows' broken
|
||||
# Unicode representation.
|
||||
encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
|
||||
path = path.decode(encoding, 'replace')
|
||||
|
||||
# Add the magic prefix if it isn't already there.
|
||||
# http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
|
||||
if prefix and not path.startswith(WINDOWS_MAGIC_PREFIX):
|
||||
if path.startswith(u'\\\\'):
|
||||
# UNC path. Final path should look like \\?\UNC\...
|
||||
path = u'UNC' + path[1:]
|
||||
path = WINDOWS_MAGIC_PREFIX + path
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def samefile(p1, p2):
|
||||
"""Safer equality for paths."""
|
||||
return shutil._samefile(syspath(p1), syspath(p2))
|
||||
|
||||
|
||||
def remove(path, soft=True):
|
||||
"""Remove the file. If `soft`, then no error will be raised if the
|
||||
file does not exist.
|
||||
"""
|
||||
path = syspath(path)
|
||||
if soft and not os.path.exists(path):
|
||||
return
|
||||
try:
|
||||
os.remove(path)
|
||||
except (OSError, IOError) as exc:
|
||||
raise FilesystemError(exc, 'delete', (path,), traceback.format_exc())
|
||||
|
||||
|
||||
def copy(path, dest, replace=False):
|
||||
"""Copy a plain file. Permissions are not copied. If `dest` already
|
||||
exists, raises a FilesystemError unless `replace` is True. Has no
|
||||
effect if `path` is the same as `dest`. Paths are translated to
|
||||
system paths before the syscall.
|
||||
"""
|
||||
if samefile(path, dest):
|
||||
return
|
||||
path = syspath(path)
|
||||
dest = syspath(dest)
|
||||
if not replace and os.path.exists(dest):
|
||||
raise FilesystemError('file exists', 'copy', (path, dest))
|
||||
try:
|
||||
shutil.copyfile(path, dest)
|
||||
except (OSError, IOError) as exc:
|
||||
raise FilesystemError(exc, 'copy', (path, dest),
|
||||
traceback.format_exc())
|
||||
|
||||
|
||||
def move(path, dest, replace=False):
|
||||
"""Rename a file. `dest` may not be a directory. If `dest` already
|
||||
exists, raises an OSError unless `replace` is True. Has no effect if
|
||||
`path` is the same as `dest`. If the paths are on different
|
||||
filesystems (or the rename otherwise fails), a copy is attempted
|
||||
instead, in which case metadata will *not* be preserved. Paths are
|
||||
translated to system paths.
|
||||
"""
|
||||
if samefile(path, dest):
|
||||
return
|
||||
path = syspath(path)
|
||||
dest = syspath(dest)
|
||||
if os.path.exists(dest) and not replace:
|
||||
raise FilesystemError('file exists', 'rename', (path, dest),
|
||||
traceback.format_exc())
|
||||
|
||||
# First, try renaming the file.
|
||||
try:
|
||||
os.rename(path, dest)
|
||||
except OSError:
|
||||
# Otherwise, copy and delete the original.
|
||||
try:
|
||||
shutil.copyfile(path, dest)
|
||||
os.remove(path)
|
||||
except (OSError, IOError) as exc:
|
||||
raise FilesystemError(exc, 'move', (path, dest),
|
||||
traceback.format_exc())
|
||||
|
||||
|
||||
def link(path, dest, replace=False):
|
||||
"""Create a symbolic link from path to `dest`. Raises an OSError if
|
||||
`dest` already exists, unless `replace` is True. Does nothing if
|
||||
`path` == `dest`."""
|
||||
if (samefile(path, dest)):
|
||||
return
|
||||
|
||||
path = syspath(path)
|
||||
dest = syspath(dest)
|
||||
if os.path.exists(dest) and not replace:
|
||||
raise FilesystemError('file exists', 'rename', (path, dest),
|
||||
traceback.format_exc())
|
||||
try:
|
||||
os.symlink(path, dest)
|
||||
except OSError:
|
||||
raise FilesystemError('Operating system does not support symbolic '
|
||||
'links.', 'link', (path, dest),
|
||||
traceback.format_exc())
|
||||
|
||||
|
||||
def unique_path(path):
|
||||
"""Returns a version of ``path`` that does not exist on the
|
||||
filesystem. Specifically, if ``path` itself already exists, then
|
||||
something unique is appended to the path.
|
||||
"""
|
||||
if not os.path.exists(syspath(path)):
|
||||
return path
|
||||
|
||||
base, ext = os.path.splitext(path)
|
||||
match = re.search(r'\.(\d)+$', base)
|
||||
if match:
|
||||
num = int(match.group(1))
|
||||
base = base[:match.start()]
|
||||
else:
|
||||
num = 0
|
||||
while True:
|
||||
num += 1
|
||||
new_path = '%s.%i%s' % (base, num, ext)
|
||||
if not os.path.exists(new_path):
|
||||
return new_path
|
||||
|
||||
# Note: The Windows "reserved characters" are, of course, allowed on
|
||||
# Unix. They are forbidden here because they cause problems on Samba
|
||||
# shares, which are sufficiently common as to cause frequent problems.
|
||||
# http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
|
||||
CHAR_REPLACE = [
|
||||
(re.compile(ur'[\\/]'), u'_'), # / and \ -- forbidden everywhere.
|
||||
(re.compile(ur'^\.'), u'_'), # Leading dot (hidden files on Unix).
|
||||
(re.compile(ur'[\x00-\x1f]'), u''), # Control characters.
|
||||
(re.compile(ur'[<>:"\?\*\|]'), u'_'), # Windows "reserved characters".
|
||||
(re.compile(ur'\.$'), u'_'), # Trailing dots.
|
||||
(re.compile(ur'\s+$'), u''), # Trailing whitespace.
|
||||
]
|
||||
|
||||
|
||||
def sanitize_path(path, replacements=None):
|
||||
"""Takes a path (as a Unicode string) and makes sure that it is
|
||||
legal. Returns a new path. Only works with fragments; won't work
|
||||
reliably on Windows when a path begins with a drive letter. Path
|
||||
separators (including altsep!) should already be cleaned from the
|
||||
path components. If replacements is specified, it is used *instead*
|
||||
of the default set of replacements; it must be a list of (compiled
|
||||
regex, replacement string) pairs.
|
||||
"""
|
||||
replacements = replacements or CHAR_REPLACE
|
||||
|
||||
comps = components(path)
|
||||
if not comps:
|
||||
return ''
|
||||
for i, comp in enumerate(comps):
|
||||
for regex, repl in replacements:
|
||||
comp = regex.sub(repl, comp)
|
||||
comps[i] = comp
|
||||
return os.path.join(*comps)
|
||||
|
||||
|
||||
def truncate_path(path, length=MAX_FILENAME_LENGTH):
|
||||
"""Given a bytestring path or a Unicode path fragment, truncate the
|
||||
components to a legal length. In the last component, the extension
|
||||
is preserved.
|
||||
"""
|
||||
comps = components(path)
|
||||
|
||||
out = [c[:length] for c in comps]
|
||||
base, ext = os.path.splitext(comps[-1])
|
||||
if ext:
|
||||
# Last component has an extension.
|
||||
base = base[:length - len(ext)]
|
||||
out[-1] = base + ext
|
||||
|
||||
return os.path.join(*out)
|
||||
|
||||
|
||||
def str2bool(value):
|
||||
"""Returns a boolean reflecting a human-entered string."""
|
||||
if value.lower() in ('yes', '1', 'true', 't', 'y'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def as_string(value):
|
||||
"""Convert a value to a Unicode object for matching with a query.
|
||||
None becomes the empty string. Bytestrings are silently decoded.
|
||||
"""
|
||||
if value is None:
|
||||
return u''
|
||||
elif isinstance(value, buffer):
|
||||
return str(value).decode('utf8', 'ignore')
|
||||
elif isinstance(value, str):
|
||||
return value.decode('utf8', 'ignore')
|
||||
else:
|
||||
return unicode(value)
|
||||
|
||||
|
||||
def levenshtein(s1, s2):
|
||||
"""A nice DP edit distance implementation from Wikibooks:
|
||||
http://en.wikibooks.org/wiki/Algorithm_implementation/Strings/
|
||||
Levenshtein_distance#Python
|
||||
"""
|
||||
if len(s1) < len(s2):
|
||||
return levenshtein(s2, s1)
|
||||
if not s1:
|
||||
return len(s2)
|
||||
|
||||
previous_row = xrange(len(s2) + 1)
|
||||
for i, c1 in enumerate(s1):
|
||||
current_row = [i + 1]
|
||||
for j, c2 in enumerate(s2):
|
||||
insertions = previous_row[j + 1] + 1
|
||||
deletions = current_row[j] + 1
|
||||
substitutions = previous_row[j] + (c1 != c2)
|
||||
current_row.append(min(insertions, deletions, substitutions))
|
||||
previous_row = current_row
|
||||
|
||||
return previous_row[-1]
|
||||
|
||||
|
||||
def plurality(objs):
|
||||
"""Given a sequence of comparable objects, returns the object that
|
||||
is most common in the set and the frequency of that object. The
|
||||
sequence must contain at least one object.
|
||||
"""
|
||||
# Calculate frequencies.
|
||||
freqs = defaultdict(int)
|
||||
for obj in objs:
|
||||
freqs[obj] += 1
|
||||
|
||||
if not freqs:
|
||||
raise ValueError('sequence must be non-empty')
|
||||
|
||||
# Find object with maximum frequency.
|
||||
max_freq = 0
|
||||
res = None
|
||||
for obj, freq in freqs.items():
|
||||
if freq > max_freq:
|
||||
max_freq = freq
|
||||
res = obj
|
||||
|
||||
return res, max_freq
|
||||
|
||||
|
||||
def cpu_count():
|
||||
"""Return the number of hardware thread contexts (cores or SMT
|
||||
threads) in the system.
|
||||
"""
|
||||
# Adapted from the soundconverter project:
|
||||
# https://github.com/kassoulet/soundconverter
|
||||
if sys.platform == 'win32':
|
||||
try:
|
||||
num = int(os.environ['NUMBER_OF_PROCESSORS'])
|
||||
except (ValueError, KeyError):
|
||||
num = 0
|
||||
elif sys.platform == 'darwin':
|
||||
try:
|
||||
num = int(command_output(['sysctl', '-n', 'hw.ncpu']))
|
||||
except ValueError:
|
||||
num = 0
|
||||
else:
|
||||
try:
|
||||
num = os.sysconf('SC_NPROCESSORS_ONLN')
|
||||
except (ValueError, OSError, AttributeError):
|
||||
num = 0
|
||||
if num >= 1:
|
||||
return num
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def command_output(cmd, shell=False):
|
||||
"""Runs the command and returns its output after it has exited.
|
||||
|
||||
``cmd`` is a list of arguments starting with the command names. If
|
||||
``shell`` is true, ``cmd`` is assumed to be a string and passed to a
|
||||
shell to execute.
|
||||
|
||||
If the process exits with a non-zero return code
|
||||
``subprocess.CalledProcessError`` is raised. May also raise
|
||||
``OSError``.
|
||||
|
||||
This replaces `subprocess.check_output`, which isn't available in
|
||||
Python 2.6 and which can have problems if lots of output is sent to
|
||||
stderr.
|
||||
"""
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
close_fds=platform.system() != 'Windows',
|
||||
shell=shell
|
||||
)
|
||||
stdout, stderr = proc.communicate()
|
||||
if proc.returncode:
|
||||
raise subprocess.CalledProcessError(
|
||||
returncode=proc.returncode,
|
||||
cmd=' '.join(cmd),
|
||||
)
|
||||
return stdout
|
||||
|
||||
|
||||
def max_filename_length(path, limit=MAX_FILENAME_LENGTH):
|
||||
"""Attempt to determine the maximum filename length for the
|
||||
filesystem containing `path`. If the value is greater than `limit`,
|
||||
then `limit` is used instead (to prevent errors when a filesystem
|
||||
misreports its capacity). If it cannot be determined (e.g., on
|
||||
Windows), return `limit`.
|
||||
"""
|
||||
if hasattr(os, 'statvfs'):
|
||||
try:
|
||||
res = os.statvfs(path)
|
||||
except OSError:
|
||||
return limit
|
||||
return min(res[9], limit)
|
||||
else:
|
||||
return limit
|
||||
@@ -0,0 +1,211 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Fabrice Laporte
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Abstraction layer to resize images using PIL, ImageMagick, or a
|
||||
public resizing proxy if neither is available.
|
||||
"""
|
||||
import urllib
|
||||
import subprocess
|
||||
import os
|
||||
import re
|
||||
from tempfile import NamedTemporaryFile
|
||||
import logging
|
||||
from beets import util
|
||||
|
||||
# Resizing methods
|
||||
PIL = 1
|
||||
IMAGEMAGICK = 2
|
||||
WEBPROXY = 3
|
||||
|
||||
PROXY_URL = 'http://images.weserv.nl/'
|
||||
|
||||
log = logging.getLogger('beets')
|
||||
|
||||
|
||||
def resize_url(url, maxwidth):
|
||||
"""Return a proxied image URL that resizes the original image to
|
||||
maxwidth (preserving aspect ratio).
|
||||
"""
|
||||
return '{0}?{1}'.format(PROXY_URL, urllib.urlencode({
|
||||
'url': url.replace('http://', ''),
|
||||
'w': str(maxwidth),
|
||||
}))
|
||||
|
||||
|
||||
def temp_file_for(path):
|
||||
"""Return an unused filename with the same extension as the
|
||||
specified path.
|
||||
"""
|
||||
ext = os.path.splitext(path)[1]
|
||||
with NamedTemporaryFile(suffix=ext, delete=False) as f:
|
||||
return f.name
|
||||
|
||||
|
||||
def pil_resize(maxwidth, path_in, path_out=None):
|
||||
"""Resize using Python Imaging Library (PIL). Return the output path
|
||||
of resized image.
|
||||
"""
|
||||
path_out = path_out or temp_file_for(path_in)
|
||||
from PIL import Image
|
||||
log.debug(u'artresizer: PIL resizing {0} to {1}'.format(
|
||||
util.displayable_path(path_in), util.displayable_path(path_out)
|
||||
))
|
||||
|
||||
try:
|
||||
im = Image.open(util.syspath(path_in))
|
||||
size = maxwidth, maxwidth
|
||||
im.thumbnail(size, Image.ANTIALIAS)
|
||||
im.save(path_out)
|
||||
return path_out
|
||||
except IOError:
|
||||
log.error(u"PIL cannot create thumbnail for '{0}'".format(
|
||||
util.displayable_path(path_in)
|
||||
))
|
||||
return path_in
|
||||
|
||||
|
||||
def im_resize(maxwidth, path_in, path_out=None):
|
||||
"""Resize using ImageMagick's ``convert`` tool.
|
||||
Return the output path of resized image.
|
||||
"""
|
||||
path_out = path_out or temp_file_for(path_in)
|
||||
log.debug(u'artresizer: ImageMagick resizing {0} to {1}'.format(
|
||||
util.displayable_path(path_in), util.displayable_path(path_out)
|
||||
))
|
||||
|
||||
# "-resize widthxheight>" shrinks images with dimension(s) larger
|
||||
# than the corresponding width and/or height dimension(s). The >
|
||||
# "only shrink" flag is prefixed by ^ escape char for Windows
|
||||
# compatibility.
|
||||
try:
|
||||
util.command_output([
|
||||
'convert', util.syspath(path_in),
|
||||
'-resize', '{0}x^>'.format(maxwidth), path_out
|
||||
])
|
||||
except subprocess.CalledProcessError:
|
||||
log.warn(u'artresizer: IM convert failed for {0}'.format(
|
||||
util.displayable_path(path_in)
|
||||
))
|
||||
return path_in
|
||||
return path_out
|
||||
|
||||
|
||||
BACKEND_FUNCS = {
|
||||
PIL: pil_resize,
|
||||
IMAGEMAGICK: im_resize,
|
||||
}
|
||||
|
||||
|
||||
class Shareable(type):
|
||||
"""A pseudo-singleton metaclass that allows both shared and
|
||||
non-shared instances. The ``MyClass.shared`` property holds a
|
||||
lazily-created shared instance of ``MyClass`` while calling
|
||||
``MyClass()`` to construct a new object works as usual.
|
||||
"""
|
||||
def __init__(cls, name, bases, dict):
|
||||
super(Shareable, cls).__init__(name, bases, dict)
|
||||
cls._instance = None
|
||||
|
||||
@property
|
||||
def shared(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
|
||||
class ArtResizer(object):
|
||||
"""A singleton class that performs image resizes.
|
||||
"""
|
||||
__metaclass__ = Shareable
|
||||
|
||||
def __init__(self, method=None):
|
||||
"""Create a resizer object for the given method or, if none is
|
||||
specified, with an inferred method.
|
||||
"""
|
||||
self.method = self._check_method(method)
|
||||
log.debug(u"artresizer: method is {0}".format(self.method))
|
||||
self.can_compare = self._can_compare()
|
||||
|
||||
def resize(self, maxwidth, path_in, path_out=None):
|
||||
"""Manipulate an image file according to the method, returning a
|
||||
new path. For PIL or IMAGEMAGIC methods, resizes the image to a
|
||||
temporary file. For WEBPROXY, returns `path_in` unmodified.
|
||||
"""
|
||||
if self.local:
|
||||
func = BACKEND_FUNCS[self.method[0]]
|
||||
return func(maxwidth, path_in, path_out)
|
||||
else:
|
||||
return path_in
|
||||
|
||||
def proxy_url(self, maxwidth, url):
|
||||
"""Modifies an image URL according the method, returning a new
|
||||
URL. For WEBPROXY, a URL on the proxy server is returned.
|
||||
Otherwise, the URL is returned unmodified.
|
||||
"""
|
||||
if self.local:
|
||||
return url
|
||||
else:
|
||||
return resize_url(url, maxwidth)
|
||||
|
||||
@property
|
||||
def local(self):
|
||||
"""A boolean indicating whether the resizing method is performed
|
||||
locally (i.e., PIL or ImageMagick).
|
||||
"""
|
||||
return self.method[0] in BACKEND_FUNCS
|
||||
|
||||
def _can_compare(self):
|
||||
"""A boolean indicating whether image comparison is available"""
|
||||
|
||||
return self.method[0] == IMAGEMAGICK and self.method[1] > (6, 8, 7)
|
||||
|
||||
@staticmethod
|
||||
def _check_method(method=None):
|
||||
"""A tuple indicating whether current method is available and its
|
||||
version. If no method is given, it returns a supported one.
|
||||
"""
|
||||
# Guess available method
|
||||
if not method:
|
||||
for m in [IMAGEMAGICK, PIL]:
|
||||
_, version = ArtResizer._check_method(m)
|
||||
if version:
|
||||
return (m, version)
|
||||
return (WEBPROXY, (0))
|
||||
|
||||
if method == IMAGEMAGICK:
|
||||
|
||||
# Try invoking ImageMagick's "convert".
|
||||
try:
|
||||
out = util.command_output(['identify', '--version'])
|
||||
|
||||
if 'imagemagick' in out.lower():
|
||||
pattern = r".+ (\d+)\.(\d+)\.(\d+).*"
|
||||
match = re.search(pattern, out)
|
||||
if match:
|
||||
return (IMAGEMAGICK,
|
||||
(int(match.group(1)),
|
||||
int(match.group(2)),
|
||||
int(match.group(3))))
|
||||
return (IMAGEMAGICK, (0))
|
||||
|
||||
except (subprocess.CalledProcessError, OSError):
|
||||
return (IMAGEMAGICK, None)
|
||||
|
||||
if method == PIL:
|
||||
# Try importing PIL.
|
||||
try:
|
||||
__import__('PIL', fromlist=['Image'])
|
||||
return (PIL, (0))
|
||||
except ImportError:
|
||||
return (PIL, None)
|
||||
@@ -0,0 +1,647 @@
|
||||
"""Extremely simple pure-Python implementation of coroutine-style
|
||||
asynchronous socket I/O. Inspired by, but inferior to, Eventlet.
|
||||
Bluelet can also be thought of as a less-terrible replacement for
|
||||
asyncore.
|
||||
|
||||
Bluelet: easy concurrency without all the messy parallelism.
|
||||
"""
|
||||
import socket
|
||||
import select
|
||||
import sys
|
||||
import types
|
||||
import errno
|
||||
import traceback
|
||||
import time
|
||||
import collections
|
||||
|
||||
|
||||
# A little bit of "six" (Python 2/3 compatibility): cope with PEP 3109 syntax
|
||||
# changes.
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
if PY3:
|
||||
def _reraise(typ, exc, tb):
|
||||
raise exc.with_traceback(tb)
|
||||
else:
|
||||
exec("""
|
||||
def _reraise(typ, exc, tb):
|
||||
raise typ, exc, tb
|
||||
""")
|
||||
|
||||
|
||||
# Basic events used for thread scheduling.
|
||||
|
||||
class Event(object):
|
||||
"""Just a base class identifying Bluelet events. An event is an
|
||||
object yielded from a Bluelet thread coroutine to suspend operation
|
||||
and communicate with the scheduler.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class WaitableEvent(Event):
|
||||
"""A waitable event is one encapsulating an action that can be
|
||||
waited for using a select() call. That is, it's an event with an
|
||||
associated file descriptor.
|
||||
"""
|
||||
def waitables(self):
|
||||
"""Return "waitable" objects to pass to select(). Should return
|
||||
three iterables for input readiness, output readiness, and
|
||||
exceptional conditions (i.e., the three lists passed to
|
||||
select()).
|
||||
"""
|
||||
return (), (), ()
|
||||
|
||||
def fire(self):
|
||||
"""Called when an associated file descriptor becomes ready
|
||||
(i.e., is returned from a select() call).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ValueEvent(Event):
|
||||
"""An event that does nothing but return a fixed value."""
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
|
||||
class ExceptionEvent(Event):
|
||||
"""Raise an exception at the yield point. Used internally."""
|
||||
def __init__(self, exc_info):
|
||||
self.exc_info = exc_info
|
||||
|
||||
|
||||
class SpawnEvent(Event):
|
||||
"""Add a new coroutine thread to the scheduler."""
|
||||
def __init__(self, coro):
|
||||
self.spawned = coro
|
||||
|
||||
|
||||
class JoinEvent(Event):
|
||||
"""Suspend the thread until the specified child thread has
|
||||
completed.
|
||||
"""
|
||||
def __init__(self, child):
|
||||
self.child = child
|
||||
|
||||
|
||||
class KillEvent(Event):
|
||||
"""Unschedule a child thread."""
|
||||
def __init__(self, child):
|
||||
self.child = child
|
||||
|
||||
|
||||
class DelegationEvent(Event):
|
||||
"""Suspend execution of the current thread, start a new thread and,
|
||||
once the child thread finished, return control to the parent
|
||||
thread.
|
||||
"""
|
||||
def __init__(self, coro):
|
||||
self.spawned = coro
|
||||
|
||||
|
||||
class ReturnEvent(Event):
|
||||
"""Return a value the current thread's delegator at the point of
|
||||
delegation. Ends the current (delegate) thread.
|
||||
"""
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
|
||||
class SleepEvent(WaitableEvent):
|
||||
"""Suspend the thread for a given duration.
|
||||
"""
|
||||
def __init__(self, duration):
|
||||
self.wakeup_time = time.time() + duration
|
||||
|
||||
def time_left(self):
|
||||
return max(self.wakeup_time - time.time(), 0.0)
|
||||
|
||||
|
||||
class ReadEvent(WaitableEvent):
|
||||
"""Reads from a file-like object."""
|
||||
def __init__(self, fd, bufsize):
|
||||
self.fd = fd
|
||||
self.bufsize = bufsize
|
||||
|
||||
def waitables(self):
|
||||
return (self.fd,), (), ()
|
||||
|
||||
def fire(self):
|
||||
return self.fd.read(self.bufsize)
|
||||
|
||||
|
||||
class WriteEvent(WaitableEvent):
|
||||
"""Writes to a file-like object."""
|
||||
def __init__(self, fd, data):
|
||||
self.fd = fd
|
||||
self.data = data
|
||||
|
||||
def waitable(self):
|
||||
return (), (self.fd,), ()
|
||||
|
||||
def fire(self):
|
||||
self.fd.write(self.data)
|
||||
|
||||
|
||||
# Core logic for executing and scheduling threads.
|
||||
|
||||
def _event_select(events):
|
||||
"""Perform a select() over all the Events provided, returning the
|
||||
ones ready to be fired. Only WaitableEvents (including SleepEvents)
|
||||
matter here; all other events are ignored (and thus postponed).
|
||||
"""
|
||||
# Gather waitables and wakeup times.
|
||||
waitable_to_event = {}
|
||||
rlist, wlist, xlist = [], [], []
|
||||
earliest_wakeup = None
|
||||
for event in events:
|
||||
if isinstance(event, SleepEvent):
|
||||
if not earliest_wakeup:
|
||||
earliest_wakeup = event.wakeup_time
|
||||
else:
|
||||
earliest_wakeup = min(earliest_wakeup, event.wakeup_time)
|
||||
elif isinstance(event, WaitableEvent):
|
||||
r, w, x = event.waitables()
|
||||
rlist += r
|
||||
wlist += w
|
||||
xlist += x
|
||||
for waitable in r:
|
||||
waitable_to_event[('r', waitable)] = event
|
||||
for waitable in w:
|
||||
waitable_to_event[('w', waitable)] = event
|
||||
for waitable in x:
|
||||
waitable_to_event[('x', waitable)] = event
|
||||
|
||||
# If we have a any sleeping threads, determine how long to sleep.
|
||||
if earliest_wakeup:
|
||||
timeout = max(earliest_wakeup - time.time(), 0.0)
|
||||
else:
|
||||
timeout = None
|
||||
|
||||
# Perform select() if we have any waitables.
|
||||
if rlist or wlist or xlist:
|
||||
rready, wready, xready = select.select(rlist, wlist, xlist, timeout)
|
||||
else:
|
||||
rready, wready, xready = (), (), ()
|
||||
if timeout:
|
||||
time.sleep(timeout)
|
||||
|
||||
# Gather ready events corresponding to the ready waitables.
|
||||
ready_events = set()
|
||||
for ready in rready:
|
||||
ready_events.add(waitable_to_event[('r', ready)])
|
||||
for ready in wready:
|
||||
ready_events.add(waitable_to_event[('w', ready)])
|
||||
for ready in xready:
|
||||
ready_events.add(waitable_to_event[('x', ready)])
|
||||
|
||||
# Gather any finished sleeps.
|
||||
for event in events:
|
||||
if isinstance(event, SleepEvent) and event.time_left() == 0.0:
|
||||
ready_events.add(event)
|
||||
|
||||
return ready_events
|
||||
|
||||
|
||||
class ThreadException(Exception):
|
||||
def __init__(self, coro, exc_info):
|
||||
self.coro = coro
|
||||
self.exc_info = exc_info
|
||||
|
||||
def reraise(self):
|
||||
_reraise(self.exc_info[0], self.exc_info[1], self.exc_info[2])
|
||||
|
||||
|
||||
SUSPENDED = Event() # Special sentinel placeholder for suspended threads.
|
||||
|
||||
|
||||
class Delegated(Event):
|
||||
"""Placeholder indicating that a thread has delegated execution to a
|
||||
different thread.
|
||||
"""
|
||||
def __init__(self, child):
|
||||
self.child = child
|
||||
|
||||
|
||||
def run(root_coro):
|
||||
"""Schedules a coroutine, running it to completion. This
|
||||
encapsulates the Bluelet scheduler, which the root coroutine can
|
||||
add to by spawning new coroutines.
|
||||
"""
|
||||
# The "threads" dictionary keeps track of all the currently-
|
||||
# executing and suspended coroutines. It maps coroutines to their
|
||||
# currently "blocking" event. The event value may be SUSPENDED if
|
||||
# the coroutine is waiting on some other condition: namely, a
|
||||
# delegated coroutine or a joined coroutine. In this case, the
|
||||
# coroutine should *also* appear as a value in one of the below
|
||||
# dictionaries `delegators` or `joiners`.
|
||||
threads = {root_coro: ValueEvent(None)}
|
||||
|
||||
# Maps child coroutines to delegating parents.
|
||||
delegators = {}
|
||||
|
||||
# Maps child coroutines to joining (exit-waiting) parents.
|
||||
joiners = collections.defaultdict(list)
|
||||
|
||||
def complete_thread(coro, return_value):
|
||||
"""Remove a coroutine from the scheduling pool, awaking
|
||||
delegators and joiners as necessary and returning the specified
|
||||
value to any delegating parent.
|
||||
"""
|
||||
del threads[coro]
|
||||
|
||||
# Resume delegator.
|
||||
if coro in delegators:
|
||||
threads[delegators[coro]] = ValueEvent(return_value)
|
||||
del delegators[coro]
|
||||
|
||||
# Resume joiners.
|
||||
if coro in joiners:
|
||||
for parent in joiners[coro]:
|
||||
threads[parent] = ValueEvent(None)
|
||||
del joiners[coro]
|
||||
|
||||
def advance_thread(coro, value, is_exc=False):
|
||||
"""After an event is fired, run a given coroutine associated with
|
||||
it in the threads dict until it yields again. If the coroutine
|
||||
exits, then the thread is removed from the pool. If the coroutine
|
||||
raises an exception, it is reraised in a ThreadException. If
|
||||
is_exc is True, then the value must be an exc_info tuple and the
|
||||
exception is thrown into the coroutine.
|
||||
"""
|
||||
try:
|
||||
if is_exc:
|
||||
next_event = coro.throw(*value)
|
||||
else:
|
||||
next_event = coro.send(value)
|
||||
except StopIteration:
|
||||
# Thread is done.
|
||||
complete_thread(coro, None)
|
||||
except:
|
||||
# Thread raised some other exception.
|
||||
del threads[coro]
|
||||
raise ThreadException(coro, sys.exc_info())
|
||||
else:
|
||||
if isinstance(next_event, types.GeneratorType):
|
||||
# Automatically invoke sub-coroutines. (Shorthand for
|
||||
# explicit bluelet.call().)
|
||||
next_event = DelegationEvent(next_event)
|
||||
threads[coro] = next_event
|
||||
|
||||
def kill_thread(coro):
|
||||
"""Unschedule this thread and its (recursive) delegates.
|
||||
"""
|
||||
# Collect all coroutines in the delegation stack.
|
||||
coros = [coro]
|
||||
while isinstance(threads[coro], Delegated):
|
||||
coro = threads[coro].child
|
||||
coros.append(coro)
|
||||
|
||||
# Complete each coroutine from the top to the bottom of the
|
||||
# stack.
|
||||
for coro in reversed(coros):
|
||||
complete_thread(coro, None)
|
||||
|
||||
# Continue advancing threads until root thread exits.
|
||||
exit_te = None
|
||||
while threads:
|
||||
try:
|
||||
# Look for events that can be run immediately. Continue
|
||||
# running immediate events until nothing is ready.
|
||||
while True:
|
||||
have_ready = False
|
||||
for coro, event in list(threads.items()):
|
||||
if isinstance(event, SpawnEvent):
|
||||
threads[event.spawned] = ValueEvent(None) # Spawn.
|
||||
advance_thread(coro, None)
|
||||
have_ready = True
|
||||
elif isinstance(event, ValueEvent):
|
||||
advance_thread(coro, event.value)
|
||||
have_ready = True
|
||||
elif isinstance(event, ExceptionEvent):
|
||||
advance_thread(coro, event.exc_info, True)
|
||||
have_ready = True
|
||||
elif isinstance(event, DelegationEvent):
|
||||
threads[coro] = Delegated(event.spawned) # Suspend.
|
||||
threads[event.spawned] = ValueEvent(None) # Spawn.
|
||||
delegators[event.spawned] = coro
|
||||
have_ready = True
|
||||
elif isinstance(event, ReturnEvent):
|
||||
# Thread is done.
|
||||
complete_thread(coro, event.value)
|
||||
have_ready = True
|
||||
elif isinstance(event, JoinEvent):
|
||||
threads[coro] = SUSPENDED # Suspend.
|
||||
joiners[event.child].append(coro)
|
||||
have_ready = True
|
||||
elif isinstance(event, KillEvent):
|
||||
threads[coro] = ValueEvent(None)
|
||||
kill_thread(event.child)
|
||||
have_ready = True
|
||||
|
||||
# Only start the select when nothing else is ready.
|
||||
if not have_ready:
|
||||
break
|
||||
|
||||
# Wait and fire.
|
||||
event2coro = dict((v, k) for k, v in threads.items())
|
||||
for event in _event_select(threads.values()):
|
||||
# Run the IO operation, but catch socket errors.
|
||||
try:
|
||||
value = event.fire()
|
||||
except socket.error as exc:
|
||||
if isinstance(exc.args, tuple) and \
|
||||
exc.args[0] == errno.EPIPE:
|
||||
# Broken pipe. Remote host disconnected.
|
||||
pass
|
||||
else:
|
||||
traceback.print_exc()
|
||||
# Abort the coroutine.
|
||||
threads[event2coro[event]] = ReturnEvent(None)
|
||||
else:
|
||||
advance_thread(event2coro[event], value)
|
||||
|
||||
except ThreadException as te:
|
||||
# Exception raised from inside a thread.
|
||||
event = ExceptionEvent(te.exc_info)
|
||||
if te.coro in delegators:
|
||||
# The thread is a delegate. Raise exception in its
|
||||
# delegator.
|
||||
threads[delegators[te.coro]] = event
|
||||
del delegators[te.coro]
|
||||
else:
|
||||
# The thread is root-level. Raise in client code.
|
||||
exit_te = te
|
||||
break
|
||||
|
||||
except:
|
||||
# For instance, KeyboardInterrupt during select(). Raise
|
||||
# into root thread and terminate others.
|
||||
threads = {root_coro: ExceptionEvent(sys.exc_info())}
|
||||
|
||||
# If any threads still remain, kill them.
|
||||
for coro in threads:
|
||||
coro.close()
|
||||
|
||||
# If we're exiting with an exception, raise it in the client.
|
||||
if exit_te:
|
||||
exit_te.reraise()
|
||||
|
||||
|
||||
# Sockets and their associated events.
|
||||
|
||||
class SocketClosedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Listener(object):
|
||||
"""A socket wrapper object for listening sockets.
|
||||
"""
|
||||
def __init__(self, host, port):
|
||||
"""Create a listening socket on the given hostname and port.
|
||||
"""
|
||||
self._closed = False
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.sock.bind((host, port))
|
||||
self.sock.listen(5)
|
||||
|
||||
def accept(self):
|
||||
"""An event that waits for a connection on the listening socket.
|
||||
When a connection is made, the event returns a Connection
|
||||
object.
|
||||
"""
|
||||
if self._closed:
|
||||
raise SocketClosedError()
|
||||
return AcceptEvent(self)
|
||||
|
||||
def close(self):
|
||||
"""Immediately close the listening socket. (Not an event.)
|
||||
"""
|
||||
self._closed = True
|
||||
self.sock.close()
|
||||
|
||||
|
||||
class Connection(object):
|
||||
"""A socket wrapper object for connected sockets.
|
||||
"""
|
||||
def __init__(self, sock, addr):
|
||||
self.sock = sock
|
||||
self.addr = addr
|
||||
self._buf = b''
|
||||
self._closed = False
|
||||
|
||||
def close(self):
|
||||
"""Close the connection."""
|
||||
self._closed = True
|
||||
self.sock.close()
|
||||
|
||||
def recv(self, size):
|
||||
"""Read at most size bytes of data from the socket."""
|
||||
if self._closed:
|
||||
raise SocketClosedError()
|
||||
|
||||
if self._buf:
|
||||
# We already have data read previously.
|
||||
out = self._buf[:size]
|
||||
self._buf = self._buf[size:]
|
||||
return ValueEvent(out)
|
||||
else:
|
||||
return ReceiveEvent(self, size)
|
||||
|
||||
def send(self, data):
|
||||
"""Sends data on the socket, returning the number of bytes
|
||||
successfully sent.
|
||||
"""
|
||||
if self._closed:
|
||||
raise SocketClosedError()
|
||||
return SendEvent(self, data)
|
||||
|
||||
def sendall(self, data):
|
||||
"""Send all of data on the socket."""
|
||||
if self._closed:
|
||||
raise SocketClosedError()
|
||||
return SendEvent(self, data, True)
|
||||
|
||||
def readline(self, terminator=b"\n", bufsize=1024):
|
||||
"""Reads a line (delimited by terminator) from the socket."""
|
||||
if self._closed:
|
||||
raise SocketClosedError()
|
||||
|
||||
while True:
|
||||
if terminator in self._buf:
|
||||
line, self._buf = self._buf.split(terminator, 1)
|
||||
line += terminator
|
||||
yield ReturnEvent(line)
|
||||
break
|
||||
data = yield ReceiveEvent(self, bufsize)
|
||||
if data:
|
||||
self._buf += data
|
||||
else:
|
||||
line = self._buf
|
||||
self._buf = b''
|
||||
yield ReturnEvent(line)
|
||||
break
|
||||
|
||||
|
||||
class AcceptEvent(WaitableEvent):
|
||||
"""An event for Listener objects (listening sockets) that suspends
|
||||
execution until the socket gets a connection.
|
||||
"""
|
||||
def __init__(self, listener):
|
||||
self.listener = listener
|
||||
|
||||
def waitables(self):
|
||||
return (self.listener.sock,), (), ()
|
||||
|
||||
def fire(self):
|
||||
sock, addr = self.listener.sock.accept()
|
||||
return Connection(sock, addr)
|
||||
|
||||
|
||||
class ReceiveEvent(WaitableEvent):
|
||||
"""An event for Connection objects (connected sockets) for
|
||||
asynchronously reading data.
|
||||
"""
|
||||
def __init__(self, conn, bufsize):
|
||||
self.conn = conn
|
||||
self.bufsize = bufsize
|
||||
|
||||
def waitables(self):
|
||||
return (self.conn.sock,), (), ()
|
||||
|
||||
def fire(self):
|
||||
return self.conn.sock.recv(self.bufsize)
|
||||
|
||||
|
||||
class SendEvent(WaitableEvent):
|
||||
"""An event for Connection objects (connected sockets) for
|
||||
asynchronously writing data.
|
||||
"""
|
||||
def __init__(self, conn, data, sendall=False):
|
||||
self.conn = conn
|
||||
self.data = data
|
||||
self.sendall = sendall
|
||||
|
||||
def waitables(self):
|
||||
return (), (self.conn.sock,), ()
|
||||
|
||||
def fire(self):
|
||||
if self.sendall:
|
||||
return self.conn.sock.sendall(self.data)
|
||||
else:
|
||||
return self.conn.sock.send(self.data)
|
||||
|
||||
|
||||
# Public interface for threads; each returns an event object that
|
||||
# can immediately be "yield"ed.
|
||||
|
||||
def null():
|
||||
"""Event: yield to the scheduler without doing anything special.
|
||||
"""
|
||||
return ValueEvent(None)
|
||||
|
||||
|
||||
def spawn(coro):
|
||||
"""Event: add another coroutine to the scheduler. Both the parent
|
||||
and child coroutines run concurrently.
|
||||
"""
|
||||
if not isinstance(coro, types.GeneratorType):
|
||||
raise ValueError('%s is not a coroutine' % str(coro))
|
||||
return SpawnEvent(coro)
|
||||
|
||||
|
||||
def call(coro):
|
||||
"""Event: delegate to another coroutine. The current coroutine
|
||||
is resumed once the sub-coroutine finishes. If the sub-coroutine
|
||||
returns a value using end(), then this event returns that value.
|
||||
"""
|
||||
if not isinstance(coro, types.GeneratorType):
|
||||
raise ValueError('%s is not a coroutine' % str(coro))
|
||||
return DelegationEvent(coro)
|
||||
|
||||
|
||||
def end(value=None):
|
||||
"""Event: ends the coroutine and returns a value to its
|
||||
delegator.
|
||||
"""
|
||||
return ReturnEvent(value)
|
||||
|
||||
|
||||
def read(fd, bufsize=None):
|
||||
"""Event: read from a file descriptor asynchronously."""
|
||||
if bufsize is None:
|
||||
# Read all.
|
||||
def reader():
|
||||
buf = []
|
||||
while True:
|
||||
data = yield read(fd, 1024)
|
||||
if not data:
|
||||
break
|
||||
buf.append(data)
|
||||
yield ReturnEvent(''.join(buf))
|
||||
return DelegationEvent(reader())
|
||||
|
||||
else:
|
||||
return ReadEvent(fd, bufsize)
|
||||
|
||||
|
||||
def write(fd, data):
|
||||
"""Event: write to a file descriptor asynchronously."""
|
||||
return WriteEvent(fd, data)
|
||||
|
||||
|
||||
def connect(host, port):
|
||||
"""Event: connect to a network address and return a Connection
|
||||
object for communicating on the socket.
|
||||
"""
|
||||
addr = (host, port)
|
||||
sock = socket.create_connection(addr)
|
||||
return ValueEvent(Connection(sock, addr))
|
||||
|
||||
|
||||
def sleep(duration):
|
||||
"""Event: suspend the thread for ``duration`` seconds.
|
||||
"""
|
||||
return SleepEvent(duration)
|
||||
|
||||
|
||||
def join(coro):
|
||||
"""Suspend the thread until another, previously `spawn`ed thread
|
||||
completes.
|
||||
"""
|
||||
return JoinEvent(coro)
|
||||
|
||||
|
||||
def kill(coro):
|
||||
"""Halt the execution of a different `spawn`ed thread.
|
||||
"""
|
||||
return KillEvent(coro)
|
||||
|
||||
|
||||
# Convenience function for running socket servers.
|
||||
|
||||
def server(host, port, func):
|
||||
"""A coroutine that runs a network server. Host and port specify the
|
||||
listening address. func should be a coroutine that takes a single
|
||||
parameter, a Connection object. The coroutine is invoked for every
|
||||
incoming connection on the listening socket.
|
||||
"""
|
||||
def handler(conn):
|
||||
try:
|
||||
yield func(conn)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
listener = Listener(host, port)
|
||||
try:
|
||||
while True:
|
||||
conn = yield listener.accept()
|
||||
yield spawn(handler(conn))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
listener.close()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,40 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class OrderedEnum(Enum):
|
||||
"""
|
||||
An Enum subclass that allows comparison of members.
|
||||
"""
|
||||
def __ge__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value >= other.value
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value > other.value
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value <= other.value
|
||||
return NotImplemented
|
||||
|
||||
def __lt__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value < other.value
|
||||
return NotImplemented
|
||||
@@ -0,0 +1,571 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""This module implements a string formatter based on the standard PEP
|
||||
292 string.Template class extended with function calls. Variables, as
|
||||
with string.Template, are indicated with $ and functions are delimited
|
||||
with %.
|
||||
|
||||
This module assumes that everything is Unicode: the template and the
|
||||
substitution values. Bytestrings are not supported. Also, the templates
|
||||
always behave like the ``safe_substitute`` method in the standard
|
||||
library: unknown symbols are left intact.
|
||||
|
||||
This is sort of like a tiny, horrible degeneration of a real templating
|
||||
engine like Jinja2 or Mustache.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import ast
|
||||
import dis
|
||||
import types
|
||||
|
||||
SYMBOL_DELIM = u'$'
|
||||
FUNC_DELIM = u'%'
|
||||
GROUP_OPEN = u'{'
|
||||
GROUP_CLOSE = u'}'
|
||||
ARG_SEP = u','
|
||||
ESCAPE_CHAR = u'$'
|
||||
|
||||
VARIABLE_PREFIX = '__var_'
|
||||
FUNCTION_PREFIX = '__func_'
|
||||
|
||||
|
||||
class Environment(object):
|
||||
"""Contains the values and functions to be substituted into a
|
||||
template.
|
||||
"""
|
||||
def __init__(self, values, functions):
|
||||
self.values = values
|
||||
self.functions = functions
|
||||
|
||||
|
||||
# Code generation helpers.
|
||||
|
||||
def ex_lvalue(name):
|
||||
"""A variable load expression."""
|
||||
return ast.Name(name, ast.Store())
|
||||
|
||||
|
||||
def ex_rvalue(name):
|
||||
"""A variable store expression."""
|
||||
return ast.Name(name, ast.Load())
|
||||
|
||||
|
||||
def ex_literal(val):
|
||||
"""An int, float, long, bool, string, or None literal with the given
|
||||
value.
|
||||
"""
|
||||
if val is None:
|
||||
return ast.Name('None', ast.Load())
|
||||
elif isinstance(val, (int, float, long)):
|
||||
return ast.Num(val)
|
||||
elif isinstance(val, bool):
|
||||
return ast.Name(str(val), ast.Load())
|
||||
elif isinstance(val, basestring):
|
||||
return ast.Str(val)
|
||||
raise TypeError('no literal for {0}'.format(type(val)))
|
||||
|
||||
|
||||
def ex_varassign(name, expr):
|
||||
"""Assign an expression into a single variable. The expression may
|
||||
either be an `ast.expr` object or a value to be used as a literal.
|
||||
"""
|
||||
if not isinstance(expr, ast.expr):
|
||||
expr = ex_literal(expr)
|
||||
return ast.Assign([ex_lvalue(name)], expr)
|
||||
|
||||
|
||||
def ex_call(func, args):
|
||||
"""A function-call expression with only positional parameters. The
|
||||
function may be an expression or the name of a function. Each
|
||||
argument may be an expression or a value to be used as a literal.
|
||||
"""
|
||||
if isinstance(func, basestring):
|
||||
func = ex_rvalue(func)
|
||||
|
||||
args = list(args)
|
||||
for i in range(len(args)):
|
||||
if not isinstance(args[i], ast.expr):
|
||||
args[i] = ex_literal(args[i])
|
||||
|
||||
return ast.Call(func, args, [], None, None)
|
||||
|
||||
|
||||
def compile_func(arg_names, statements, name='_the_func', debug=False):
|
||||
"""Compile a list of statements as the body of a function and return
|
||||
the resulting Python function. If `debug`, then print out the
|
||||
bytecode of the compiled function.
|
||||
"""
|
||||
func_def = ast.FunctionDef(
|
||||
name,
|
||||
ast.arguments(
|
||||
[ast.Name(n, ast.Param()) for n in arg_names],
|
||||
None, None,
|
||||
[ex_literal(None) for _ in arg_names],
|
||||
),
|
||||
statements,
|
||||
[],
|
||||
)
|
||||
mod = ast.Module([func_def])
|
||||
ast.fix_missing_locations(mod)
|
||||
|
||||
prog = compile(mod, '<generated>', 'exec')
|
||||
|
||||
# Debug: show bytecode.
|
||||
if debug:
|
||||
dis.dis(prog)
|
||||
for const in prog.co_consts:
|
||||
if isinstance(const, types.CodeType):
|
||||
dis.dis(const)
|
||||
|
||||
the_locals = {}
|
||||
exec prog in {}, the_locals
|
||||
return the_locals[name]
|
||||
|
||||
|
||||
# AST nodes for the template language.
|
||||
|
||||
class Symbol(object):
|
||||
"""A variable-substitution symbol in a template."""
|
||||
def __init__(self, ident, original):
|
||||
self.ident = ident
|
||||
self.original = original
|
||||
|
||||
def __repr__(self):
|
||||
return u'Symbol(%s)' % repr(self.ident)
|
||||
|
||||
def evaluate(self, env):
|
||||
"""Evaluate the symbol in the environment, returning a Unicode
|
||||
string.
|
||||
"""
|
||||
if self.ident in env.values:
|
||||
# Substitute for a value.
|
||||
return env.values[self.ident]
|
||||
else:
|
||||
# Keep original text.
|
||||
return self.original
|
||||
|
||||
def translate(self):
|
||||
"""Compile the variable lookup."""
|
||||
expr = ex_rvalue(VARIABLE_PREFIX + self.ident.encode('utf8'))
|
||||
return [expr], set([self.ident.encode('utf8')]), set()
|
||||
|
||||
|
||||
class Call(object):
|
||||
"""A function call in a template."""
|
||||
def __init__(self, ident, args, original):
|
||||
self.ident = ident
|
||||
self.args = args
|
||||
self.original = original
|
||||
|
||||
def __repr__(self):
|
||||
return u'Call(%s, %s, %s)' % (repr(self.ident), repr(self.args),
|
||||
repr(self.original))
|
||||
|
||||
def evaluate(self, env):
|
||||
"""Evaluate the function call in the environment, returning a
|
||||
Unicode string.
|
||||
"""
|
||||
if self.ident in env.functions:
|
||||
arg_vals = [expr.evaluate(env) for expr in self.args]
|
||||
try:
|
||||
out = env.functions[self.ident](*arg_vals)
|
||||
except Exception as exc:
|
||||
# Function raised exception! Maybe inlining the name of
|
||||
# the exception will help debug.
|
||||
return u'<%s>' % unicode(exc)
|
||||
return unicode(out)
|
||||
else:
|
||||
return self.original
|
||||
|
||||
def translate(self):
|
||||
"""Compile the function call."""
|
||||
varnames = set()
|
||||
funcnames = set([self.ident.encode('utf8')])
|
||||
|
||||
arg_exprs = []
|
||||
for arg in self.args:
|
||||
subexprs, subvars, subfuncs = arg.translate()
|
||||
varnames.update(subvars)
|
||||
funcnames.update(subfuncs)
|
||||
|
||||
# Create a subexpression that joins the result components of
|
||||
# the arguments.
|
||||
arg_exprs.append(ex_call(
|
||||
ast.Attribute(ex_literal(u''), 'join', ast.Load()),
|
||||
[ex_call(
|
||||
'map',
|
||||
[
|
||||
ex_rvalue('unicode'),
|
||||
ast.List(subexprs, ast.Load()),
|
||||
]
|
||||
)],
|
||||
))
|
||||
|
||||
subexpr_call = ex_call(
|
||||
FUNCTION_PREFIX + self.ident.encode('utf8'),
|
||||
arg_exprs
|
||||
)
|
||||
return [subexpr_call], varnames, funcnames
|
||||
|
||||
|
||||
class Expression(object):
|
||||
"""Top-level template construct: contains a list of text blobs,
|
||||
Symbols, and Calls.
|
||||
"""
|
||||
def __init__(self, parts):
|
||||
self.parts = parts
|
||||
|
||||
def __repr__(self):
|
||||
return u'Expression(%s)' % (repr(self.parts))
|
||||
|
||||
def evaluate(self, env):
|
||||
"""Evaluate the entire expression in the environment, returning
|
||||
a Unicode string.
|
||||
"""
|
||||
out = []
|
||||
for part in self.parts:
|
||||
if isinstance(part, basestring):
|
||||
out.append(part)
|
||||
else:
|
||||
out.append(part.evaluate(env))
|
||||
return u''.join(map(unicode, out))
|
||||
|
||||
def translate(self):
|
||||
"""Compile the expression to a list of Python AST expressions, a
|
||||
set of variable names used, and a set of function names.
|
||||
"""
|
||||
expressions = []
|
||||
varnames = set()
|
||||
funcnames = set()
|
||||
for part in self.parts:
|
||||
if isinstance(part, basestring):
|
||||
expressions.append(ex_literal(part))
|
||||
else:
|
||||
e, v, f = part.translate()
|
||||
expressions.extend(e)
|
||||
varnames.update(v)
|
||||
funcnames.update(f)
|
||||
return expressions, varnames, funcnames
|
||||
|
||||
|
||||
# Parser.
|
||||
|
||||
class ParseError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Parser(object):
|
||||
"""Parses a template expression string. Instantiate the class with
|
||||
the template source and call ``parse_expression``. The ``pos`` field
|
||||
will indicate the character after the expression finished and
|
||||
``parts`` will contain a list of Unicode strings, Symbols, and Calls
|
||||
reflecting the concatenated portions of the expression.
|
||||
|
||||
This is a terrible, ad-hoc parser implementation based on a
|
||||
left-to-right scan with no lexing step to speak of; it's probably
|
||||
both inefficient and incorrect. Maybe this should eventually be
|
||||
replaced with a real, accepted parsing technique (PEG, parser
|
||||
generator, etc.).
|
||||
"""
|
||||
def __init__(self, string):
|
||||
self.string = string
|
||||
self.pos = 0
|
||||
self.parts = []
|
||||
|
||||
# Common parsing resources.
|
||||
special_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_OPEN, GROUP_CLOSE,
|
||||
ARG_SEP, ESCAPE_CHAR)
|
||||
special_char_re = re.compile(ur'[%s]|$' %
|
||||
u''.join(re.escape(c) for c in special_chars))
|
||||
|
||||
def parse_expression(self):
|
||||
"""Parse a template expression starting at ``pos``. Resulting
|
||||
components (Unicode strings, Symbols, and Calls) are added to
|
||||
the ``parts`` field, a list. The ``pos`` field is updated to be
|
||||
the next character after the expression.
|
||||
"""
|
||||
text_parts = []
|
||||
|
||||
while self.pos < len(self.string):
|
||||
char = self.string[self.pos]
|
||||
|
||||
if char not in self.special_chars:
|
||||
# A non-special character. Skip to the next special
|
||||
# character, treating the interstice as literal text.
|
||||
next_pos = (
|
||||
self.special_char_re.search(self.string[self.pos:]).start()
|
||||
+ self.pos
|
||||
)
|
||||
text_parts.append(self.string[self.pos:next_pos])
|
||||
self.pos = next_pos
|
||||
continue
|
||||
|
||||
if self.pos == len(self.string) - 1:
|
||||
# The last character can never begin a structure, so we
|
||||
# just interpret it as a literal character (unless it
|
||||
# terminates the expression, as with , and }).
|
||||
if char not in (GROUP_CLOSE, ARG_SEP):
|
||||
text_parts.append(char)
|
||||
self.pos += 1
|
||||
break
|
||||
|
||||
next_char = self.string[self.pos + 1]
|
||||
if char == ESCAPE_CHAR and next_char in \
|
||||
(SYMBOL_DELIM, FUNC_DELIM, GROUP_CLOSE, ARG_SEP):
|
||||
# An escaped special character ($$, $}, etc.). Note that
|
||||
# ${ is not an escape sequence: this is ambiguous with
|
||||
# the start of a symbol and it's not necessary (just
|
||||
# using { suffices in all cases).
|
||||
text_parts.append(next_char)
|
||||
self.pos += 2 # Skip the next character.
|
||||
continue
|
||||
|
||||
# Shift all characters collected so far into a single string.
|
||||
if text_parts:
|
||||
self.parts.append(u''.join(text_parts))
|
||||
text_parts = []
|
||||
|
||||
if char == SYMBOL_DELIM:
|
||||
# Parse a symbol.
|
||||
self.parse_symbol()
|
||||
elif char == FUNC_DELIM:
|
||||
# Parse a function call.
|
||||
self.parse_call()
|
||||
elif char in (GROUP_CLOSE, ARG_SEP):
|
||||
# Template terminated.
|
||||
break
|
||||
elif char == GROUP_OPEN:
|
||||
# Start of a group has no meaning hear; just pass
|
||||
# through the character.
|
||||
text_parts.append(char)
|
||||
self.pos += 1
|
||||
else:
|
||||
assert False
|
||||
|
||||
# If any parsed characters remain, shift them into a string.
|
||||
if text_parts:
|
||||
self.parts.append(u''.join(text_parts))
|
||||
|
||||
def parse_symbol(self):
|
||||
"""Parse a variable reference (like ``$foo`` or ``${foo}``)
|
||||
starting at ``pos``. Possibly appends a Symbol object (or,
|
||||
failing that, text) to the ``parts`` field and updates ``pos``.
|
||||
The character at ``pos`` must, as a precondition, be ``$``.
|
||||
"""
|
||||
assert self.pos < len(self.string)
|
||||
assert self.string[self.pos] == SYMBOL_DELIM
|
||||
|
||||
if self.pos == len(self.string) - 1:
|
||||
# Last character.
|
||||
self.parts.append(SYMBOL_DELIM)
|
||||
self.pos += 1
|
||||
return
|
||||
|
||||
next_char = self.string[self.pos + 1]
|
||||
start_pos = self.pos
|
||||
self.pos += 1
|
||||
|
||||
if next_char == GROUP_OPEN:
|
||||
# A symbol like ${this}.
|
||||
self.pos += 1 # Skip opening.
|
||||
closer = self.string.find(GROUP_CLOSE, self.pos)
|
||||
if closer == -1 or closer == self.pos:
|
||||
# No closing brace found or identifier is empty.
|
||||
self.parts.append(self.string[start_pos:self.pos])
|
||||
else:
|
||||
# Closer found.
|
||||
ident = self.string[self.pos:closer]
|
||||
self.pos = closer + 1
|
||||
self.parts.append(Symbol(ident,
|
||||
self.string[start_pos:self.pos]))
|
||||
|
||||
else:
|
||||
# A bare-word symbol.
|
||||
ident = self._parse_ident()
|
||||
if ident:
|
||||
# Found a real symbol.
|
||||
self.parts.append(Symbol(ident,
|
||||
self.string[start_pos:self.pos]))
|
||||
else:
|
||||
# A standalone $.
|
||||
self.parts.append(SYMBOL_DELIM)
|
||||
|
||||
def parse_call(self):
|
||||
"""Parse a function call (like ``%foo{bar,baz}``) starting at
|
||||
``pos``. Possibly appends a Call object to ``parts`` and update
|
||||
``pos``. The character at ``pos`` must be ``%``.
|
||||
"""
|
||||
assert self.pos < len(self.string)
|
||||
assert self.string[self.pos] == FUNC_DELIM
|
||||
|
||||
start_pos = self.pos
|
||||
self.pos += 1
|
||||
|
||||
ident = self._parse_ident()
|
||||
if not ident:
|
||||
# No function name.
|
||||
self.parts.append(FUNC_DELIM)
|
||||
return
|
||||
|
||||
if self.pos >= len(self.string):
|
||||
# Identifier terminates string.
|
||||
self.parts.append(self.string[start_pos:self.pos])
|
||||
return
|
||||
|
||||
if self.string[self.pos] != GROUP_OPEN:
|
||||
# Argument list not opened.
|
||||
self.parts.append(self.string[start_pos:self.pos])
|
||||
return
|
||||
|
||||
# Skip past opening brace and try to parse an argument list.
|
||||
self.pos += 1
|
||||
args = self.parse_argument_list()
|
||||
if self.pos >= len(self.string) or \
|
||||
self.string[self.pos] != GROUP_CLOSE:
|
||||
# Arguments unclosed.
|
||||
self.parts.append(self.string[start_pos:self.pos])
|
||||
return
|
||||
|
||||
self.pos += 1 # Move past closing brace.
|
||||
self.parts.append(Call(ident, args, self.string[start_pos:self.pos]))
|
||||
|
||||
def parse_argument_list(self):
|
||||
"""Parse a list of arguments starting at ``pos``, returning a
|
||||
list of Expression objects. Does not modify ``parts``. Should
|
||||
leave ``pos`` pointing to a } character or the end of the
|
||||
string.
|
||||
"""
|
||||
# Try to parse a subexpression in a subparser.
|
||||
expressions = []
|
||||
|
||||
while self.pos < len(self.string):
|
||||
subparser = Parser(self.string[self.pos:])
|
||||
subparser.parse_expression()
|
||||
|
||||
# Extract and advance past the parsed expression.
|
||||
expressions.append(Expression(subparser.parts))
|
||||
self.pos += subparser.pos
|
||||
|
||||
if self.pos >= len(self.string) or \
|
||||
self.string[self.pos] == GROUP_CLOSE:
|
||||
# Argument list terminated by EOF or closing brace.
|
||||
break
|
||||
|
||||
# Only other way to terminate an expression is with ,.
|
||||
# Continue to the next argument.
|
||||
assert self.string[self.pos] == ARG_SEP
|
||||
self.pos += 1
|
||||
|
||||
return expressions
|
||||
|
||||
def _parse_ident(self):
|
||||
"""Parse an identifier and return it (possibly an empty string).
|
||||
Updates ``pos``.
|
||||
"""
|
||||
remainder = self.string[self.pos:]
|
||||
ident = re.match(ur'\w*', remainder).group(0)
|
||||
self.pos += len(ident)
|
||||
return ident
|
||||
|
||||
|
||||
def _parse(template):
|
||||
"""Parse a top-level template string Expression. Any extraneous text
|
||||
is considered literal text.
|
||||
"""
|
||||
parser = Parser(template)
|
||||
parser.parse_expression()
|
||||
|
||||
parts = parser.parts
|
||||
remainder = parser.string[parser.pos:]
|
||||
if remainder:
|
||||
parts.append(remainder)
|
||||
return Expression(parts)
|
||||
|
||||
|
||||
# External interface.
|
||||
|
||||
class Template(object):
|
||||
"""A string template, including text, Symbols, and Calls.
|
||||
"""
|
||||
def __init__(self, template):
|
||||
self.expr = _parse(template)
|
||||
self.original = template
|
||||
self.compiled = self.translate()
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.original == other.original
|
||||
|
||||
def interpret(self, values={}, functions={}):
|
||||
"""Like `substitute`, but forces the interpreter (rather than
|
||||
the compiled version) to be used. The interpreter includes
|
||||
exception-handling code for missing variables and buggy template
|
||||
functions but is much slower.
|
||||
"""
|
||||
return self.expr.evaluate(Environment(values, functions))
|
||||
|
||||
def substitute(self, values={}, functions={}):
|
||||
"""Evaluate the template given the values and functions.
|
||||
"""
|
||||
try:
|
||||
res = self.compiled(values, functions)
|
||||
except: # Handle any exceptions thrown by compiled version.
|
||||
res = self.interpret(values, functions)
|
||||
return res
|
||||
|
||||
def translate(self):
|
||||
"""Compile the template to a Python function."""
|
||||
expressions, varnames, funcnames = self.expr.translate()
|
||||
|
||||
argnames = []
|
||||
for varname in varnames:
|
||||
argnames.append(VARIABLE_PREFIX.encode('utf8') + varname)
|
||||
for funcname in funcnames:
|
||||
argnames.append(FUNCTION_PREFIX.encode('utf8') + funcname)
|
||||
|
||||
func = compile_func(
|
||||
argnames,
|
||||
[ast.Return(ast.List(expressions, ast.Load()))],
|
||||
)
|
||||
|
||||
def wrapper_func(values={}, functions={}):
|
||||
args = {}
|
||||
for varname in varnames:
|
||||
args[VARIABLE_PREFIX + varname] = values[varname]
|
||||
for funcname in funcnames:
|
||||
args[FUNCTION_PREFIX + funcname] = functions[funcname]
|
||||
parts = func(**args)
|
||||
return u''.join(parts)
|
||||
|
||||
return wrapper_func
|
||||
|
||||
|
||||
# Performance tests.
|
||||
|
||||
if __name__ == '__main__':
|
||||
import timeit
|
||||
_tmpl = Template(u'foo $bar %baz{foozle $bar barzle} $bar')
|
||||
_vars = {'bar': 'qux'}
|
||||
_funcs = {'baz': unicode.upper}
|
||||
interp_time = timeit.timeit('_tmpl.interpret(_vars, _funcs)',
|
||||
'from __main__ import _tmpl, _vars, _funcs',
|
||||
number=10000)
|
||||
print(interp_time)
|
||||
comp_time = timeit.timeit('_tmpl.substitute(_vars, _funcs)',
|
||||
'from __main__ import _tmpl, _vars, _funcs',
|
||||
number=10000)
|
||||
print(comp_time)
|
||||
print('Speedup:', interp_time / comp_time)
|
||||
@@ -0,0 +1,517 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Simple but robust implementation of generator/coroutine-based
|
||||
pipelines in Python. The pipelines may be run either sequentially
|
||||
(single-threaded) or in parallel (one thread per pipeline stage).
|
||||
|
||||
This implementation supports pipeline bubbles (indications that the
|
||||
processing for a certain item should abort). To use them, yield the
|
||||
BUBBLE constant from any stage coroutine except the last.
|
||||
|
||||
In the parallel case, the implementation transparently handles thread
|
||||
shutdown when the processing is complete and when a stage raises an
|
||||
exception. KeyboardInterrupts (^C) are also handled.
|
||||
|
||||
When running a parallel pipeline, it is also possible to use
|
||||
multiple coroutines for the same pipeline stage; this lets you speed
|
||||
up a bottleneck stage by dividing its work among multiple threads.
|
||||
To do so, pass an iterable of coroutines to the Pipeline constructor
|
||||
in place of any single coroutine.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import Queue
|
||||
from threading import Thread, Lock
|
||||
import sys
|
||||
|
||||
BUBBLE = '__PIPELINE_BUBBLE__'
|
||||
POISON = '__PIPELINE_POISON__'
|
||||
|
||||
DEFAULT_QUEUE_SIZE = 16
|
||||
|
||||
|
||||
def _invalidate_queue(q, val=None, sync=True):
|
||||
"""Breaks a Queue such that it never blocks, always has size 1,
|
||||
and has no maximum size. get()ing from the queue returns `val`,
|
||||
which defaults to None. `sync` controls whether a lock is
|
||||
required (because it's not reentrant!).
|
||||
"""
|
||||
def _qsize(len=len):
|
||||
return 1
|
||||
|
||||
def _put(item):
|
||||
pass
|
||||
|
||||
def _get():
|
||||
return val
|
||||
|
||||
if sync:
|
||||
q.mutex.acquire()
|
||||
|
||||
try:
|
||||
q.maxsize = 0
|
||||
q._qsize = _qsize
|
||||
q._put = _put
|
||||
q._get = _get
|
||||
q.not_empty.notifyAll()
|
||||
q.not_full.notifyAll()
|
||||
|
||||
finally:
|
||||
if sync:
|
||||
q.mutex.release()
|
||||
|
||||
|
||||
class CountedQueue(Queue.Queue):
|
||||
"""A queue that keeps track of the number of threads that are
|
||||
still feeding into it. The queue is poisoned when all threads are
|
||||
finished with the queue.
|
||||
"""
|
||||
def __init__(self, maxsize=0):
|
||||
Queue.Queue.__init__(self, maxsize)
|
||||
self.nthreads = 0
|
||||
self.poisoned = False
|
||||
|
||||
def acquire(self):
|
||||
"""Indicate that a thread will start putting into this queue.
|
||||
Should not be called after the queue is already poisoned.
|
||||
"""
|
||||
with self.mutex:
|
||||
assert not self.poisoned
|
||||
assert self.nthreads >= 0
|
||||
self.nthreads += 1
|
||||
|
||||
def release(self):
|
||||
"""Indicate that a thread that was putting into this queue has
|
||||
exited. If this is the last thread using the queue, the queue
|
||||
is poisoned.
|
||||
"""
|
||||
with self.mutex:
|
||||
self.nthreads -= 1
|
||||
assert self.nthreads >= 0
|
||||
if self.nthreads == 0:
|
||||
# All threads are done adding to this queue. Poison it
|
||||
# when it becomes empty.
|
||||
self.poisoned = True
|
||||
|
||||
# Replacement _get invalidates when no items remain.
|
||||
_old_get = self._get
|
||||
|
||||
def _get():
|
||||
out = _old_get()
|
||||
if not self.queue:
|
||||
_invalidate_queue(self, POISON, False)
|
||||
return out
|
||||
|
||||
if self.queue:
|
||||
# Items remain.
|
||||
self._get = _get
|
||||
else:
|
||||
# No items. Invalidate immediately.
|
||||
_invalidate_queue(self, POISON, False)
|
||||
|
||||
|
||||
class MultiMessage(object):
|
||||
"""A message yielded by a pipeline stage encapsulating multiple
|
||||
values to be sent to the next stage.
|
||||
"""
|
||||
def __init__(self, messages):
|
||||
self.messages = messages
|
||||
|
||||
|
||||
def multiple(messages):
|
||||
"""Yield multiple([message, ..]) from a pipeline stage to send
|
||||
multiple values to the next pipeline stage.
|
||||
"""
|
||||
return MultiMessage(messages)
|
||||
|
||||
|
||||
def stage(func):
|
||||
"""Decorate a function to become a simple stage.
|
||||
|
||||
>>> @stage
|
||||
... def add(n, i):
|
||||
... return i + n
|
||||
>>> pipe = Pipeline([
|
||||
... iter([1, 2, 3]),
|
||||
... add(2),
|
||||
... ])
|
||||
>>> list(pipe.pull())
|
||||
[3, 4, 5]
|
||||
"""
|
||||
|
||||
def coro(*args):
|
||||
task = None
|
||||
while True:
|
||||
task = yield task
|
||||
task = func(*(args + (task,)))
|
||||
return coro
|
||||
|
||||
|
||||
def mutator_stage(func):
|
||||
"""Decorate a function that manipulates items in a coroutine to
|
||||
become a simple stage.
|
||||
|
||||
>>> @mutator_stage
|
||||
... def setkey(key, item):
|
||||
... item[key] = True
|
||||
>>> pipe = Pipeline([
|
||||
... iter([{'x': False}, {'a': False}]),
|
||||
... setkey('x'),
|
||||
... ])
|
||||
>>> list(pipe.pull())
|
||||
[{'x': True}, {'a': False, 'x': True}]
|
||||
"""
|
||||
|
||||
def coro(*args):
|
||||
task = None
|
||||
while True:
|
||||
task = yield task
|
||||
func(*(args + (task,)))
|
||||
return coro
|
||||
|
||||
|
||||
def _allmsgs(obj):
|
||||
"""Returns a list of all the messages encapsulated in obj. If obj
|
||||
is a MultiMessage, returns its enclosed messages. If obj is BUBBLE,
|
||||
returns an empty list. Otherwise, returns a list containing obj.
|
||||
"""
|
||||
if isinstance(obj, MultiMessage):
|
||||
return obj.messages
|
||||
elif obj == BUBBLE:
|
||||
return []
|
||||
else:
|
||||
return [obj]
|
||||
|
||||
|
||||
class PipelineThread(Thread):
|
||||
"""Abstract base class for pipeline-stage threads."""
|
||||
def __init__(self, all_threads):
|
||||
super(PipelineThread, self).__init__()
|
||||
self.abort_lock = Lock()
|
||||
self.abort_flag = False
|
||||
self.all_threads = all_threads
|
||||
self.exc_info = None
|
||||
|
||||
def abort(self):
|
||||
"""Shut down the thread at the next chance possible.
|
||||
"""
|
||||
with self.abort_lock:
|
||||
self.abort_flag = True
|
||||
|
||||
# Ensure that we are not blocking on a queue read or write.
|
||||
if hasattr(self, 'in_queue'):
|
||||
_invalidate_queue(self.in_queue, POISON)
|
||||
if hasattr(self, 'out_queue'):
|
||||
_invalidate_queue(self.out_queue, POISON)
|
||||
|
||||
def abort_all(self, exc_info):
|
||||
"""Abort all other threads in the system for an exception.
|
||||
"""
|
||||
self.exc_info = exc_info
|
||||
for thread in self.all_threads:
|
||||
thread.abort()
|
||||
|
||||
|
||||
class FirstPipelineThread(PipelineThread):
|
||||
"""The thread running the first stage in a parallel pipeline setup.
|
||||
The coroutine should just be a generator.
|
||||
"""
|
||||
def __init__(self, coro, out_queue, all_threads):
|
||||
super(FirstPipelineThread, self).__init__(all_threads)
|
||||
self.coro = coro
|
||||
self.out_queue = out_queue
|
||||
self.out_queue.acquire()
|
||||
|
||||
self.abort_lock = Lock()
|
||||
self.abort_flag = False
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
while True:
|
||||
with self.abort_lock:
|
||||
if self.abort_flag:
|
||||
return
|
||||
|
||||
# Get the value from the generator.
|
||||
try:
|
||||
msg = self.coro.next()
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
# Send messages to the next stage.
|
||||
for msg in _allmsgs(msg):
|
||||
with self.abort_lock:
|
||||
if self.abort_flag:
|
||||
return
|
||||
self.out_queue.put(msg)
|
||||
|
||||
except:
|
||||
self.abort_all(sys.exc_info())
|
||||
return
|
||||
|
||||
# Generator finished; shut down the pipeline.
|
||||
self.out_queue.release()
|
||||
|
||||
|
||||
class MiddlePipelineThread(PipelineThread):
|
||||
"""A thread running any stage in the pipeline except the first or
|
||||
last.
|
||||
"""
|
||||
def __init__(self, coro, in_queue, out_queue, all_threads):
|
||||
super(MiddlePipelineThread, self).__init__(all_threads)
|
||||
self.coro = coro
|
||||
self.in_queue = in_queue
|
||||
self.out_queue = out_queue
|
||||
self.out_queue.acquire()
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
# Prime the coroutine.
|
||||
self.coro.next()
|
||||
|
||||
while True:
|
||||
with self.abort_lock:
|
||||
if self.abort_flag:
|
||||
return
|
||||
|
||||
# Get the message from the previous stage.
|
||||
msg = self.in_queue.get()
|
||||
if msg is POISON:
|
||||
break
|
||||
|
||||
with self.abort_lock:
|
||||
if self.abort_flag:
|
||||
return
|
||||
|
||||
# Invoke the current stage.
|
||||
out = self.coro.send(msg)
|
||||
|
||||
# Send messages to next stage.
|
||||
for msg in _allmsgs(out):
|
||||
with self.abort_lock:
|
||||
if self.abort_flag:
|
||||
return
|
||||
self.out_queue.put(msg)
|
||||
|
||||
except:
|
||||
self.abort_all(sys.exc_info())
|
||||
return
|
||||
|
||||
# Pipeline is shutting down normally.
|
||||
self.out_queue.release()
|
||||
|
||||
|
||||
class LastPipelineThread(PipelineThread):
|
||||
"""A thread running the last stage in a pipeline. The coroutine
|
||||
should yield nothing.
|
||||
"""
|
||||
def __init__(self, coro, in_queue, all_threads):
|
||||
super(LastPipelineThread, self).__init__(all_threads)
|
||||
self.coro = coro
|
||||
self.in_queue = in_queue
|
||||
|
||||
def run(self):
|
||||
# Prime the coroutine.
|
||||
self.coro.next()
|
||||
|
||||
try:
|
||||
while True:
|
||||
with self.abort_lock:
|
||||
if self.abort_flag:
|
||||
return
|
||||
|
||||
# Get the message from the previous stage.
|
||||
msg = self.in_queue.get()
|
||||
if msg is POISON:
|
||||
break
|
||||
|
||||
with self.abort_lock:
|
||||
if self.abort_flag:
|
||||
return
|
||||
|
||||
# Send to consumer.
|
||||
self.coro.send(msg)
|
||||
|
||||
except:
|
||||
self.abort_all(sys.exc_info())
|
||||
return
|
||||
|
||||
|
||||
class Pipeline(object):
|
||||
"""Represents a staged pattern of work. Each stage in the pipeline
|
||||
is a coroutine that receives messages from the previous stage and
|
||||
yields messages to be sent to the next stage.
|
||||
"""
|
||||
def __init__(self, stages):
|
||||
"""Makes a new pipeline from a list of coroutines. There must
|
||||
be at least two stages.
|
||||
"""
|
||||
if len(stages) < 2:
|
||||
raise ValueError('pipeline must have at least two stages')
|
||||
self.stages = []
|
||||
for stage in stages:
|
||||
if isinstance(stage, (list, tuple)):
|
||||
self.stages.append(stage)
|
||||
else:
|
||||
# Default to one thread per stage.
|
||||
self.stages.append((stage,))
|
||||
|
||||
def run_sequential(self):
|
||||
"""Run the pipeline sequentially in the current thread. The
|
||||
stages are run one after the other. Only the first coroutine
|
||||
in each stage is used.
|
||||
"""
|
||||
list(self.pull())
|
||||
|
||||
def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE):
|
||||
"""Run the pipeline in parallel using one thread per stage. The
|
||||
messages between the stages are stored in queues of the given
|
||||
size.
|
||||
"""
|
||||
queue_count = len(self.stages) - 1
|
||||
queues = [CountedQueue(queue_size) for i in range(queue_count)]
|
||||
threads = []
|
||||
|
||||
# Set up first stage.
|
||||
for coro in self.stages[0]:
|
||||
threads.append(FirstPipelineThread(coro, queues[0], threads))
|
||||
|
||||
# Middle stages.
|
||||
for i in range(1, queue_count):
|
||||
for coro in self.stages[i]:
|
||||
threads.append(MiddlePipelineThread(
|
||||
coro, queues[i - 1], queues[i], threads
|
||||
))
|
||||
|
||||
# Last stage.
|
||||
for coro in self.stages[-1]:
|
||||
threads.append(
|
||||
LastPipelineThread(coro, queues[-1], threads)
|
||||
)
|
||||
|
||||
# Start threads.
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
# Wait for termination. The final thread lasts the longest.
|
||||
try:
|
||||
# Using a timeout allows us to receive KeyboardInterrupt
|
||||
# exceptions during the join().
|
||||
while threads[-1].isAlive():
|
||||
threads[-1].join(1)
|
||||
|
||||
except:
|
||||
# Stop all the threads immediately.
|
||||
for thread in threads:
|
||||
thread.abort()
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Make completely sure that all the threads have finished
|
||||
# before we return. They should already be either finished,
|
||||
# in normal operation, or aborted, in case of an exception.
|
||||
for thread in threads[:-1]:
|
||||
thread.join()
|
||||
|
||||
for thread in threads:
|
||||
exc_info = thread.exc_info
|
||||
if exc_info:
|
||||
# Make the exception appear as it was raised originally.
|
||||
raise exc_info[0], exc_info[1], exc_info[2]
|
||||
|
||||
def pull(self):
|
||||
"""Yield elements from the end of the pipeline. Runs the stages
|
||||
sequentially until the last yields some messages. Each of the messages
|
||||
is then yielded by ``pulled.next()``. If the pipeline has a consumer,
|
||||
that is the last stage does not yield any messages, then pull will not
|
||||
yield any messages. Only the first coroutine in each stage is used
|
||||
"""
|
||||
coros = [stage[0] for stage in self.stages]
|
||||
|
||||
# "Prime" the coroutines.
|
||||
for coro in coros[1:]:
|
||||
coro.next()
|
||||
|
||||
# Begin the pipeline.
|
||||
for out in coros[0]:
|
||||
msgs = _allmsgs(out)
|
||||
for coro in coros[1:]:
|
||||
next_msgs = []
|
||||
for msg in msgs:
|
||||
out = coro.send(msg)
|
||||
next_msgs.extend(_allmsgs(out))
|
||||
msgs = next_msgs
|
||||
for msg in msgs:
|
||||
yield msg
|
||||
|
||||
# Smoke test.
|
||||
if __name__ == '__main__':
|
||||
import time
|
||||
|
||||
# Test a normally-terminating pipeline both in sequence and
|
||||
# in parallel.
|
||||
def produce():
|
||||
for i in range(5):
|
||||
print('generating %i' % i)
|
||||
time.sleep(1)
|
||||
yield i
|
||||
|
||||
def work():
|
||||
num = yield
|
||||
while True:
|
||||
print('processing %i' % num)
|
||||
time.sleep(2)
|
||||
num = yield num * 2
|
||||
|
||||
def consume():
|
||||
while True:
|
||||
num = yield
|
||||
time.sleep(1)
|
||||
print('received %i' % num)
|
||||
|
||||
ts_start = time.time()
|
||||
Pipeline([produce(), work(), consume()]).run_sequential()
|
||||
ts_seq = time.time()
|
||||
Pipeline([produce(), work(), consume()]).run_parallel()
|
||||
ts_par = time.time()
|
||||
Pipeline([produce(), (work(), work()), consume()]).run_parallel()
|
||||
ts_end = time.time()
|
||||
print('Sequential time:', ts_seq - ts_start)
|
||||
print('Parallel time:', ts_par - ts_seq)
|
||||
print('Multiply-parallel time:', ts_end - ts_par)
|
||||
print()
|
||||
|
||||
# Test a pipeline that raises an exception.
|
||||
def exc_produce():
|
||||
for i in range(10):
|
||||
print('generating %i' % i)
|
||||
time.sleep(1)
|
||||
yield i
|
||||
|
||||
def exc_work():
|
||||
num = yield
|
||||
while True:
|
||||
print('processing %i' % num)
|
||||
time.sleep(3)
|
||||
if num == 3:
|
||||
raise Exception()
|
||||
num = yield num * 2
|
||||
|
||||
def exc_consume():
|
||||
while True:
|
||||
num = yield
|
||||
print('received %i' % num)
|
||||
|
||||
Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)
|
||||
@@ -0,0 +1,50 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""A simple utility for constructing filesystem-like trees from beets
|
||||
libraries.
|
||||
"""
|
||||
from collections import namedtuple
|
||||
from beets import util
|
||||
|
||||
Node = namedtuple('Node', ['files', 'dirs'])
|
||||
|
||||
|
||||
def _insert(node, path, itemid):
|
||||
"""Insert an item into a virtual filesystem node."""
|
||||
if len(path) == 1:
|
||||
# Last component. Insert file.
|
||||
node.files[path[0]] = itemid
|
||||
else:
|
||||
# In a directory.
|
||||
dirname = path[0]
|
||||
rest = path[1:]
|
||||
if dirname not in node.dirs:
|
||||
node.dirs[dirname] = Node({}, {})
|
||||
_insert(node.dirs[dirname], rest, itemid)
|
||||
|
||||
|
||||
def libtree(lib):
|
||||
"""Generates a filesystem-like directory tree for the files
|
||||
contained in `lib`. Filesystem nodes are (files, dirs) named
|
||||
tuples in which both components are dictionaries. The first
|
||||
maps filenames to Item ids. The second maps directory names to
|
||||
child node tuples.
|
||||
"""
|
||||
root = Node({}, {})
|
||||
for item in lib.items():
|
||||
dest = item.destination(fragment=True)
|
||||
parts = util.components(dest)
|
||||
_insert(root, parts, item.id)
|
||||
return root
|
||||
@@ -0,0 +1,19 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""A namespace package for beets plugins."""
|
||||
|
||||
# Make this a namespace package.
|
||||
from pkgutil import extend_path
|
||||
__path__ = extend_path(__path__, __name__)
|
||||
@@ -0,0 +1,280 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Allows beets to embed album art into file metadata."""
|
||||
import os.path
|
||||
import logging
|
||||
import imghdr
|
||||
import subprocess
|
||||
import platform
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
from beets.plugins import BeetsPlugin
|
||||
from beets import mediafile
|
||||
from beets import ui
|
||||
from beets.ui import decargs
|
||||
from beets.util import syspath, normpath, displayable_path
|
||||
from beets.util.artresizer import ArtResizer
|
||||
from beets import config
|
||||
|
||||
|
||||
log = logging.getLogger('beets')
|
||||
|
||||
|
||||
class EmbedCoverArtPlugin(BeetsPlugin):
|
||||
"""Allows albumart to be embedded into the actual files.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(EmbedCoverArtPlugin, self).__init__()
|
||||
self.config.add({
|
||||
'maxwidth': 0,
|
||||
'auto': True,
|
||||
'compare_threshold': 0,
|
||||
'ifempty': False,
|
||||
})
|
||||
|
||||
if self.config['maxwidth'].get(int) and not ArtResizer.shared.local:
|
||||
self.config['maxwidth'] = 0
|
||||
log.warn(u"embedart: ImageMagick or PIL not found; "
|
||||
u"'maxwidth' option ignored")
|
||||
if self.config['compare_threshold'].get(int) and not \
|
||||
ArtResizer.shared.can_compare:
|
||||
self.config['compare_threshold'] = 0
|
||||
log.warn(u"embedart: ImageMagick 6.8.7 or higher not installed; "
|
||||
u"'compare_threshold' option ignored")
|
||||
|
||||
def commands(self):
|
||||
# Embed command.
|
||||
embed_cmd = ui.Subcommand(
|
||||
'embedart', help='embed image files into file metadata'
|
||||
)
|
||||
embed_cmd.parser.add_option(
|
||||
'-f', '--file', metavar='PATH', help='the image file to embed'
|
||||
)
|
||||
maxwidth = config['embedart']['maxwidth'].get(int)
|
||||
compare_threshold = config['embedart']['compare_threshold'].get(int)
|
||||
ifempty = config['embedart']['ifempty'].get(bool)
|
||||
|
||||
def embed_func(lib, opts, args):
|
||||
if opts.file:
|
||||
imagepath = normpath(opts.file)
|
||||
for item in lib.items(decargs(args)):
|
||||
embed_item(item, imagepath, maxwidth, None,
|
||||
compare_threshold, ifempty)
|
||||
else:
|
||||
for album in lib.albums(decargs(args)):
|
||||
embed_album(album, maxwidth)
|
||||
|
||||
embed_cmd.func = embed_func
|
||||
|
||||
# Extract command.
|
||||
extract_cmd = ui.Subcommand('extractart',
|
||||
help='extract an image from file metadata')
|
||||
extract_cmd.parser.add_option('-o', dest='outpath',
|
||||
help='image output file')
|
||||
|
||||
def extract_func(lib, opts, args):
|
||||
outpath = normpath(opts.outpath or 'cover')
|
||||
item = lib.items(decargs(args)).get()
|
||||
extract(outpath, item)
|
||||
extract_cmd.func = extract_func
|
||||
|
||||
# Clear command.
|
||||
clear_cmd = ui.Subcommand('clearart',
|
||||
help='remove images from file metadata')
|
||||
|
||||
def clear_func(lib, opts, args):
|
||||
clear(lib, decargs(args))
|
||||
clear_cmd.func = clear_func
|
||||
|
||||
return [embed_cmd, extract_cmd, clear_cmd]
|
||||
|
||||
|
||||
@EmbedCoverArtPlugin.listen('album_imported')
|
||||
def album_imported(lib, album):
|
||||
"""Automatically embed art into imported albums.
|
||||
"""
|
||||
if album.artpath and config['embedart']['auto']:
|
||||
embed_album(album, config['embedart']['maxwidth'].get(int), True)
|
||||
|
||||
|
||||
def embed_item(item, imagepath, maxwidth=None, itempath=None,
|
||||
compare_threshold=0, ifempty=False, as_album=False):
|
||||
"""Embed an image into the item's media file.
|
||||
"""
|
||||
if compare_threshold:
|
||||
if not check_art_similarity(item, imagepath, compare_threshold):
|
||||
log.warn(u'Image not similar; skipping.')
|
||||
return
|
||||
if ifempty:
|
||||
art = get_art(item)
|
||||
if not art:
|
||||
pass
|
||||
else:
|
||||
log.debug(u'embedart: media file contained art already {0}'.format(
|
||||
displayable_path(imagepath)
|
||||
))
|
||||
return
|
||||
if maxwidth and not as_album:
|
||||
imagepath = resize_image(imagepath, maxwidth)
|
||||
|
||||
try:
|
||||
log.debug(u'embedart: embedding {0}'.format(
|
||||
displayable_path(imagepath)
|
||||
))
|
||||
item['images'] = [_mediafile_image(imagepath, maxwidth)]
|
||||
except IOError as exc:
|
||||
log.error(u'embedart: could not read image file: {0}'.format(exc))
|
||||
else:
|
||||
# We don't want to store the image in the database.
|
||||
item.try_write(itempath)
|
||||
del item['images']
|
||||
|
||||
|
||||
def embed_album(album, maxwidth=None, quiet=False):
|
||||
"""Embed album art into all of the album's items.
|
||||
"""
|
||||
imagepath = album.artpath
|
||||
if not imagepath:
|
||||
log.info(u'No album art present: {0} - {1}'.
|
||||
format(album.albumartist, album.album))
|
||||
return
|
||||
if not os.path.isfile(syspath(imagepath)):
|
||||
log.error(u'Album art not found at {0}'
|
||||
.format(displayable_path(imagepath)))
|
||||
return
|
||||
if maxwidth:
|
||||
imagepath = resize_image(imagepath, maxwidth)
|
||||
|
||||
log.log(
|
||||
logging.DEBUG if quiet else logging.INFO,
|
||||
u'Embedding album art into {0.albumartist} - {0.album}.'.format(album),
|
||||
)
|
||||
|
||||
for item in album.items():
|
||||
embed_item(item, imagepath, maxwidth, None,
|
||||
config['embedart']['compare_threshold'].get(int),
|
||||
config['embedart']['ifempty'].get(bool), as_album=True)
|
||||
|
||||
|
||||
def resize_image(imagepath, maxwidth):
|
||||
"""Returns path to an image resized to maxwidth.
|
||||
"""
|
||||
log.info(u'Resizing album art to {0} pixels wide'
|
||||
.format(maxwidth))
|
||||
imagepath = ArtResizer.shared.resize(maxwidth, syspath(imagepath))
|
||||
return imagepath
|
||||
|
||||
|
||||
def check_art_similarity(item, imagepath, compare_threshold):
|
||||
"""A boolean indicating if an image is similar to embedded item art.
|
||||
"""
|
||||
with NamedTemporaryFile(delete=True) as f:
|
||||
art = extract(f.name, item)
|
||||
|
||||
if art:
|
||||
# Converting images to grayscale tends to minimize the weight
|
||||
# of colors in the diff score
|
||||
cmd = 'convert {0} {1} -colorspace gray MIFF:- | ' \
|
||||
'compare -metric PHASH - null:'.format(syspath(imagepath),
|
||||
syspath(art))
|
||||
|
||||
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
close_fds=platform.system() != 'Windows',
|
||||
shell=True)
|
||||
stdout, stderr = proc.communicate()
|
||||
if proc.returncode:
|
||||
if proc.returncode != 1:
|
||||
log.warn(u'embedart: IM phashes compare failed for {0}, \
|
||||
{1}'.format(displayable_path(imagepath),
|
||||
displayable_path(art)))
|
||||
return
|
||||
phashDiff = float(stderr)
|
||||
else:
|
||||
phashDiff = float(stdout)
|
||||
|
||||
log.info(u'embedart: compare PHASH score is {0}'.format(phashDiff))
|
||||
if phashDiff > compare_threshold:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _mediafile_image(image_path, maxwidth=None):
|
||||
"""Return a `mediafile.Image` object for the path.
|
||||
"""
|
||||
|
||||
with open(syspath(image_path), 'rb') as f:
|
||||
data = f.read()
|
||||
return mediafile.Image(data, type=mediafile.ImageType.front)
|
||||
|
||||
|
||||
def get_art(item):
|
||||
# Extract the art.
|
||||
try:
|
||||
mf = mediafile.MediaFile(syspath(item.path))
|
||||
except mediafile.UnreadableFileError as exc:
|
||||
log.error(u'Could not extract art from {0}: {1}'.format(
|
||||
displayable_path(item.path), exc
|
||||
))
|
||||
return
|
||||
|
||||
return mf.art
|
||||
|
||||
# 'extractart' command.
|
||||
|
||||
|
||||
def extract(outpath, item):
|
||||
if not item:
|
||||
log.error(u'No item matches query.')
|
||||
return
|
||||
|
||||
art = get_art(item)
|
||||
|
||||
if not art:
|
||||
log.error(u'No album art present in {0} - {1}.'
|
||||
.format(item.artist, item.title))
|
||||
return
|
||||
|
||||
# Add an extension to the filename.
|
||||
ext = imghdr.what(None, h=art)
|
||||
if not ext:
|
||||
log.error(u'Unknown image type.')
|
||||
return
|
||||
outpath += '.' + ext
|
||||
|
||||
log.info(u'Extracting album art from: {0.artist} - {0.title} '
|
||||
u'to: {1}'.format(item, displayable_path(outpath)))
|
||||
with open(syspath(outpath), 'wb') as f:
|
||||
f.write(art)
|
||||
return outpath
|
||||
|
||||
|
||||
# 'clearart' command.
|
||||
|
||||
def clear(lib, query):
|
||||
log.info(u'Clearing album art from items:')
|
||||
for item in lib.items(query):
|
||||
log.info(u'{0} - {1}'.format(item.artist, item.title))
|
||||
try:
|
||||
mf = mediafile.MediaFile(syspath(item.path),
|
||||
config['id3v23'].get(bool))
|
||||
except mediafile.UnreadableFileError as exc:
|
||||
log.error(u'Could not clear art from {0}: {1}'.format(
|
||||
displayable_path(item.path), exc
|
||||
))
|
||||
continue
|
||||
del mf.art
|
||||
mf.save()
|
||||
@@ -0,0 +1,396 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Fetches album art.
|
||||
"""
|
||||
from contextlib import closing
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import requests
|
||||
|
||||
from beets import plugins
|
||||
from beets import importer
|
||||
from beets import ui
|
||||
from beets import util
|
||||
from beets import config
|
||||
from beets.util.artresizer import ArtResizer
|
||||
|
||||
try:
|
||||
import itunes
|
||||
HAVE_ITUNES = True
|
||||
except ImportError:
|
||||
HAVE_ITUNES = False
|
||||
|
||||
IMAGE_EXTENSIONS = ['png', 'jpg', 'jpeg']
|
||||
CONTENT_TYPES = ('image/jpeg',)
|
||||
DOWNLOAD_EXTENSION = '.jpg'
|
||||
|
||||
log = logging.getLogger('beets')
|
||||
|
||||
requests_session = requests.Session()
|
||||
requests_session.headers = {'User-Agent': 'beets'}
|
||||
|
||||
|
||||
def _fetch_image(url):
|
||||
"""Downloads an image from a URL and checks whether it seems to
|
||||
actually be an image. If so, returns a path to the downloaded image.
|
||||
Otherwise, returns None.
|
||||
"""
|
||||
log.debug(u'fetchart: downloading art: {0}'.format(url))
|
||||
try:
|
||||
with closing(requests_session.get(url, stream=True)) as resp:
|
||||
if 'Content-Type' not in resp.headers \
|
||||
or resp.headers['Content-Type'] not in CONTENT_TYPES:
|
||||
log.debug(u'fetchart: not an image')
|
||||
return
|
||||
|
||||
# Generate a temporary file with the correct extension.
|
||||
with NamedTemporaryFile(suffix=DOWNLOAD_EXTENSION, delete=False) \
|
||||
as fh:
|
||||
for chunk in resp.iter_content():
|
||||
fh.write(chunk)
|
||||
log.debug(u'fetchart: downloaded art to: {0}'.format(
|
||||
util.displayable_path(fh.name)
|
||||
))
|
||||
return fh.name
|
||||
except (IOError, requests.RequestException):
|
||||
log.debug(u'fetchart: error fetching art')
|
||||
|
||||
|
||||
# ART SOURCES ################################################################
|
||||
|
||||
# Cover Art Archive.
|
||||
|
||||
CAA_URL = 'http://coverartarchive.org/release/{mbid}/front-500.jpg'
|
||||
CAA_GROUP_URL = 'http://coverartarchive.org/release-group/{mbid}/front-500.jpg'
|
||||
|
||||
|
||||
def caa_art(album):
|
||||
"""Return the Cover Art Archive and Cover Art Archive release group URLs
|
||||
using album MusicBrainz release ID and release group ID.
|
||||
"""
|
||||
if album.mb_albumid:
|
||||
yield CAA_URL.format(mbid=album.mb_albumid)
|
||||
if album.mb_releasegroupid:
|
||||
yield CAA_GROUP_URL.format(mbid=album.mb_releasegroupid)
|
||||
|
||||
|
||||
# Art from Amazon.
|
||||
|
||||
AMAZON_URL = 'http://images.amazon.com/images/P/%s.%02i.LZZZZZZZ.jpg'
|
||||
AMAZON_INDICES = (1, 2)
|
||||
|
||||
|
||||
def art_for_asin(album):
|
||||
"""Generate URLs using Amazon ID (ASIN) string.
|
||||
"""
|
||||
if album.asin:
|
||||
for index in AMAZON_INDICES:
|
||||
yield AMAZON_URL % (album.asin, index)
|
||||
|
||||
|
||||
# AlbumArt.org scraper.
|
||||
|
||||
AAO_URL = 'http://www.albumart.org/index_detail.php'
|
||||
AAO_PAT = r'href\s*=\s*"([^>"]*)"[^>]*title\s*=\s*"View larger image"'
|
||||
|
||||
|
||||
def aao_art(album):
|
||||
"""Return art URL from AlbumArt.org using album ASIN.
|
||||
"""
|
||||
if not album.asin:
|
||||
return
|
||||
# Get the page from albumart.org.
|
||||
try:
|
||||
resp = requests_session.get(AAO_URL, params={'asin': album.asin})
|
||||
log.debug(u'fetchart: scraped art URL: {0}'.format(resp.url))
|
||||
except requests.RequestException:
|
||||
log.debug(u'fetchart: error scraping art page')
|
||||
return
|
||||
|
||||
# Search the page for the image URL.
|
||||
m = re.search(AAO_PAT, resp.text)
|
||||
if m:
|
||||
image_url = m.group(1)
|
||||
yield image_url
|
||||
else:
|
||||
log.debug(u'fetchart: no image found on page')
|
||||
|
||||
|
||||
# Google Images scraper.
|
||||
|
||||
GOOGLE_URL = 'https://ajax.googleapis.com/ajax/services/search/images'
|
||||
|
||||
|
||||
def google_art(album):
|
||||
"""Return art URL from google.org given an album title and
|
||||
interpreter.
|
||||
"""
|
||||
if not (album.albumartist and album.album):
|
||||
return
|
||||
search_string = (album.albumartist + ',' + album.album).encode('utf-8')
|
||||
response = requests_session.get(GOOGLE_URL, params={
|
||||
'v': '1.0',
|
||||
'q': search_string,
|
||||
'start': '0',
|
||||
})
|
||||
|
||||
# Get results using JSON.
|
||||
try:
|
||||
results = response.json()
|
||||
data = results['responseData']
|
||||
dataInfo = data['results']
|
||||
for myUrl in dataInfo:
|
||||
yield myUrl['unescapedUrl']
|
||||
except:
|
||||
log.debug(u'fetchart: error scraping art page')
|
||||
return
|
||||
|
||||
|
||||
# Art from the iTunes Store.
|
||||
|
||||
def itunes_art(album):
|
||||
"""Return art URL from iTunes Store given an album title.
|
||||
"""
|
||||
search_string = (album.albumartist + ' ' + album.album).encode('utf-8')
|
||||
try:
|
||||
# Isolate bugs in the iTunes library while searching.
|
||||
try:
|
||||
itunes_album = itunes.search_album(search_string)[0]
|
||||
except Exception as exc:
|
||||
log.debug('fetchart: iTunes search failed: {0}'.format(exc))
|
||||
return
|
||||
|
||||
if itunes_album.get_artwork()['100']:
|
||||
small_url = itunes_album.get_artwork()['100']
|
||||
big_url = small_url.replace('100x100', '1200x1200')
|
||||
yield big_url
|
||||
else:
|
||||
log.debug(u'fetchart: album has no artwork in iTunes Store')
|
||||
except IndexError:
|
||||
log.debug(u'fetchart: album not found in iTunes Store')
|
||||
|
||||
|
||||
# Art from the filesystem.
|
||||
|
||||
|
||||
def filename_priority(filename, cover_names):
|
||||
"""Sort order for image names.
|
||||
|
||||
Return indexes of cover names found in the image filename. This
|
||||
means that images with lower-numbered and more keywords will have higher
|
||||
priority.
|
||||
"""
|
||||
return [idx for (idx, x) in enumerate(cover_names) if x in filename]
|
||||
|
||||
|
||||
def art_in_path(path, cover_names, cautious):
|
||||
"""Look for album art files in a specified directory.
|
||||
"""
|
||||
if not os.path.isdir(path):
|
||||
return
|
||||
|
||||
# Find all files that look like images in the directory.
|
||||
images = []
|
||||
for fn in os.listdir(path):
|
||||
for ext in IMAGE_EXTENSIONS:
|
||||
if fn.lower().endswith('.' + ext):
|
||||
images.append(fn)
|
||||
|
||||
# Look for "preferred" filenames.
|
||||
images = sorted(images, key=lambda x: filename_priority(x, cover_names))
|
||||
cover_pat = r"(\b|_)({0})(\b|_)".format('|'.join(cover_names))
|
||||
for fn in images:
|
||||
if re.search(cover_pat, os.path.splitext(fn)[0], re.I):
|
||||
log.debug(u'fetchart: using well-named art file {0}'.format(
|
||||
util.displayable_path(fn)
|
||||
))
|
||||
return os.path.join(path, fn)
|
||||
|
||||
# Fall back to any image in the folder.
|
||||
if images and not cautious:
|
||||
log.debug(u'fetchart: using fallback art file {0}'.format(
|
||||
util.displayable_path(images[0])
|
||||
))
|
||||
return os.path.join(path, images[0])
|
||||
|
||||
|
||||
# Try each source in turn.
|
||||
|
||||
SOURCES_ALL = [u'coverart', u'itunes', u'amazon', u'albumart', u'google']
|
||||
|
||||
ART_FUNCS = {
|
||||
u'coverart': caa_art,
|
||||
u'itunes': itunes_art,
|
||||
u'albumart': aao_art,
|
||||
u'amazon': art_for_asin,
|
||||
u'google': google_art,
|
||||
}
|
||||
|
||||
|
||||
def _source_urls(album, sources=SOURCES_ALL):
|
||||
"""Generate possible source URLs for an album's art. The URLs are
|
||||
not guaranteed to work so they each need to be attempted in turn.
|
||||
This allows the main `art_for_album` function to abort iteration
|
||||
through this sequence early to avoid the cost of scraping when not
|
||||
necessary.
|
||||
"""
|
||||
for s in sources:
|
||||
urls = ART_FUNCS[s](album)
|
||||
for url in urls:
|
||||
yield url
|
||||
|
||||
|
||||
def art_for_album(album, paths, maxwidth=None, local_only=False):
|
||||
"""Given an Album object, returns a path to downloaded art for the
|
||||
album (or None if no art is found). If `maxwidth`, then images are
|
||||
resized to this maximum pixel size. If `local_only`, then only local
|
||||
image files from the filesystem are returned; no network requests
|
||||
are made.
|
||||
"""
|
||||
out = None
|
||||
|
||||
# Local art.
|
||||
cover_names = config['fetchart']['cover_names'].as_str_seq()
|
||||
cover_names = map(util.bytestring_path, cover_names)
|
||||
cautious = config['fetchart']['cautious'].get(bool)
|
||||
if paths:
|
||||
for path in paths:
|
||||
out = art_in_path(path, cover_names, cautious)
|
||||
if out:
|
||||
break
|
||||
|
||||
# Web art sources.
|
||||
remote_priority = config['fetchart']['remote_priority'].get(bool)
|
||||
if not local_only and (remote_priority or not out):
|
||||
for url in _source_urls(album,
|
||||
config['fetchart']['sources'].as_str_seq()):
|
||||
if maxwidth:
|
||||
url = ArtResizer.shared.proxy_url(maxwidth, url)
|
||||
candidate = _fetch_image(url)
|
||||
if candidate:
|
||||
out = candidate
|
||||
break
|
||||
|
||||
if maxwidth and out:
|
||||
out = ArtResizer.shared.resize(maxwidth, out)
|
||||
return out
|
||||
|
||||
|
||||
# PLUGIN LOGIC ###############################################################
|
||||
|
||||
|
||||
def batch_fetch_art(lib, albums, force, maxwidth=None):
|
||||
"""Fetch album art for each of the albums. This implements the manual
|
||||
fetchart CLI command.
|
||||
"""
|
||||
for album in albums:
|
||||
if album.artpath and not force:
|
||||
message = 'has album art'
|
||||
else:
|
||||
# In ordinary invocations, look for images on the
|
||||
# filesystem. When forcing, however, always go to the Web
|
||||
# sources.
|
||||
local_paths = None if force else [album.path]
|
||||
|
||||
path = art_for_album(album, local_paths, maxwidth)
|
||||
if path:
|
||||
album.set_art(path, False)
|
||||
album.store()
|
||||
message = ui.colorize('green', 'found album art')
|
||||
else:
|
||||
message = ui.colorize('red', 'no art found')
|
||||
|
||||
log.info(u'{0} - {1}: {2}'.format(album.albumartist, album.album,
|
||||
message))
|
||||
|
||||
|
||||
class FetchArtPlugin(plugins.BeetsPlugin):
|
||||
def __init__(self):
|
||||
super(FetchArtPlugin, self).__init__()
|
||||
|
||||
self.config.add({
|
||||
'auto': True,
|
||||
'maxwidth': 0,
|
||||
'remote_priority': False,
|
||||
'cautious': False,
|
||||
'google_search': False,
|
||||
'cover_names': ['cover', 'front', 'art', 'album', 'folder'],
|
||||
'sources': SOURCES_ALL,
|
||||
})
|
||||
|
||||
# Holds paths to downloaded images between fetching them and
|
||||
# placing them in the filesystem.
|
||||
self.art_paths = {}
|
||||
|
||||
self.maxwidth = self.config['maxwidth'].get(int)
|
||||
if self.config['auto']:
|
||||
# Enable two import hooks when fetching is enabled.
|
||||
self.import_stages = [self.fetch_art]
|
||||
self.register_listener('import_task_files', self.assign_art)
|
||||
|
||||
available_sources = list(SOURCES_ALL)
|
||||
if not HAVE_ITUNES and u'itunes' in available_sources:
|
||||
available_sources.remove(u'itunes')
|
||||
self.config['sources'] = plugins.sanitize_choices(
|
||||
self.config['sources'].as_str_seq(), available_sources)
|
||||
|
||||
# Asynchronous; after music is added to the library.
|
||||
def fetch_art(self, session, task):
|
||||
"""Find art for the album being imported."""
|
||||
if task.is_album: # Only fetch art for full albums.
|
||||
if task.choice_flag == importer.action.ASIS:
|
||||
# For as-is imports, don't search Web sources for art.
|
||||
local = True
|
||||
elif task.choice_flag == importer.action.APPLY:
|
||||
# Search everywhere for art.
|
||||
local = False
|
||||
else:
|
||||
# For any other choices (e.g., TRACKS), do nothing.
|
||||
return
|
||||
|
||||
path = art_for_album(task.album, task.paths, self.maxwidth, local)
|
||||
|
||||
if path:
|
||||
self.art_paths[task] = path
|
||||
|
||||
# Synchronous; after music files are put in place.
|
||||
def assign_art(self, session, task):
|
||||
"""Place the discovered art in the filesystem."""
|
||||
if task in self.art_paths:
|
||||
path = self.art_paths.pop(task)
|
||||
|
||||
album = task.album
|
||||
src_removed = (config['import']['delete'].get(bool) or
|
||||
config['import']['move'].get(bool))
|
||||
album.set_art(path, not src_removed)
|
||||
album.store()
|
||||
if src_removed:
|
||||
task.prune(path)
|
||||
|
||||
# Manual album art fetching.
|
||||
def commands(self):
|
||||
cmd = ui.Subcommand('fetchart', help='download album art')
|
||||
cmd.parser.add_option('-f', '--force', dest='force',
|
||||
action='store_true', default=False,
|
||||
help='re-download art when already present')
|
||||
|
||||
def func(lib, opts, args):
|
||||
batch_fetch_art(lib, lib.albums(ui.decargs(args)), opts.force,
|
||||
self.maxwidth)
|
||||
cmd.func = func
|
||||
return [cmd]
|
||||
@@ -0,0 +1,544 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Fetches, embeds, and displays lyrics.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import logging
|
||||
import requests
|
||||
import json
|
||||
import unicodedata
|
||||
import urllib
|
||||
import difflib
|
||||
import itertools
|
||||
from HTMLParser import HTMLParseError
|
||||
|
||||
from beets import plugins
|
||||
from beets import config, ui
|
||||
|
||||
|
||||
# Global logger.
|
||||
|
||||
log = logging.getLogger('beets')
|
||||
|
||||
DIV_RE = re.compile(r'<(/?)div>?', re.I)
|
||||
COMMENT_RE = re.compile(r'<!--.*-->', re.S)
|
||||
TAG_RE = re.compile(r'<[^>]*>')
|
||||
BREAK_RE = re.compile(r'\n?\s*<br([\s|/][^>]*)*>\s*\n?', re.I)
|
||||
URL_CHARACTERS = {
|
||||
u'\u2018': u"'",
|
||||
u'\u2019': u"'",
|
||||
u'\u201c': u'"',
|
||||
u'\u201d': u'"',
|
||||
u'\u2010': u'-',
|
||||
u'\u2011': u'-',
|
||||
u'\u2012': u'-',
|
||||
u'\u2013': u'-',
|
||||
u'\u2014': u'-',
|
||||
u'\u2015': u'-',
|
||||
u'\u2016': u'-',
|
||||
u'\u2026': u'...',
|
||||
}
|
||||
|
||||
|
||||
# Utilities.
|
||||
|
||||
def fetch_url(url):
|
||||
"""Retrieve the content at a given URL, or return None if the source
|
||||
is unreachable.
|
||||
"""
|
||||
try:
|
||||
r = requests.get(url, verify=False)
|
||||
except requests.RequestException as exc:
|
||||
log.debug(u'lyrics request failed: {0}'.format(exc))
|
||||
return
|
||||
if r.status_code == requests.codes.ok:
|
||||
return r.text
|
||||
else:
|
||||
log.debug(u'failed to fetch: {0} ({1})'.format(url, r.status_code))
|
||||
|
||||
|
||||
def unescape(text):
|
||||
"""Resolves &#xxx; HTML entities (and some others)."""
|
||||
if isinstance(text, str):
|
||||
text = text.decode('utf8', 'ignore')
|
||||
out = text.replace(u' ', u' ')
|
||||
|
||||
def replchar(m):
|
||||
num = m.group(1)
|
||||
return unichr(int(num))
|
||||
out = re.sub(u"&#(\d+);", replchar, out)
|
||||
return out
|
||||
|
||||
|
||||
def extract_text_between(html, start_marker, end_marker):
|
||||
try:
|
||||
_, html = html.split(start_marker, 1)
|
||||
html, _ = html.split(end_marker, 1)
|
||||
except ValueError:
|
||||
return u''
|
||||
return html
|
||||
|
||||
|
||||
def extract_text_in(html, starttag):
|
||||
"""Extract the text from a <DIV> tag in the HTML starting with
|
||||
``starttag``. Returns None if parsing fails.
|
||||
"""
|
||||
|
||||
# Strip off the leading text before opening tag.
|
||||
try:
|
||||
_, html = html.split(starttag, 1)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
# Walk through balanced DIV tags.
|
||||
level = 0
|
||||
parts = []
|
||||
pos = 0
|
||||
for match in DIV_RE.finditer(html):
|
||||
if match.group(1): # Closing tag.
|
||||
level -= 1
|
||||
if level == 0:
|
||||
pos = match.end()
|
||||
else: # Opening tag.
|
||||
if level == 0:
|
||||
parts.append(html[pos:match.start()])
|
||||
level += 1
|
||||
|
||||
if level == -1:
|
||||
parts.append(html[pos:match.start()])
|
||||
break
|
||||
else:
|
||||
print('no closing tag found!')
|
||||
return
|
||||
return u''.join(parts)
|
||||
|
||||
|
||||
def search_pairs(item):
|
||||
"""Yield a pairs of artists and titles to search for.
|
||||
|
||||
The first item in the pair is the name of the artist, the second
|
||||
item is a list of song names.
|
||||
|
||||
In addition to the artist and title obtained from the `item` the
|
||||
method tries to strip extra information like paranthesized suffixes
|
||||
and featured artists from the strings and add them as candidates.
|
||||
The method also tries to split multiple titles separated with `/`.
|
||||
"""
|
||||
|
||||
title, artist = item.title, item.artist
|
||||
titles = [title]
|
||||
artists = [artist]
|
||||
|
||||
# Remove any featuring artists from the artists name
|
||||
pattern = r"(.*?) {0}".format(plugins.feat_tokens())
|
||||
match = re.search(pattern, artist, re.IGNORECASE)
|
||||
if match:
|
||||
artists.append(match.group(1))
|
||||
|
||||
# Remove a parenthesized suffix from a title string. Common
|
||||
# examples include (live), (remix), and (acoustic).
|
||||
pattern = r"(.+?)\s+[(].*[)]$"
|
||||
match = re.search(pattern, title, re.IGNORECASE)
|
||||
if match:
|
||||
titles.append(match.group(1))
|
||||
|
||||
# Remove any featuring artists from the title
|
||||
pattern = r"(.*?) {0}".format(plugins.feat_tokens(for_artist=False))
|
||||
for title in titles[:]:
|
||||
match = re.search(pattern, title, re.IGNORECASE)
|
||||
if match:
|
||||
titles.append(match.group(1))
|
||||
|
||||
# Check for a dual song (e.g. Pink Floyd - Speak to Me / Breathe)
|
||||
# and each of them.
|
||||
multi_titles = []
|
||||
for title in titles:
|
||||
multi_titles.append([title])
|
||||
if '/' in title:
|
||||
multi_titles.append([x.strip() for x in title.split('/')])
|
||||
|
||||
return itertools.product(artists, multi_titles)
|
||||
|
||||
|
||||
def _encode(s):
|
||||
"""Encode the string for inclusion in a URL (common to both
|
||||
LyricsWiki and Lyrics.com).
|
||||
"""
|
||||
if isinstance(s, unicode):
|
||||
for char, repl in URL_CHARACTERS.items():
|
||||
s = s.replace(char, repl)
|
||||
s = s.encode('utf8', 'ignore')
|
||||
return urllib.quote(s)
|
||||
|
||||
# Musixmatch
|
||||
|
||||
MUSIXMATCH_URL_PATTERN = 'https://www.musixmatch.com/lyrics/%s/%s'
|
||||
|
||||
|
||||
def fetch_musixmatch(artist, title):
|
||||
url = MUSIXMATCH_URL_PATTERN % (_lw_encode(artist.title()),
|
||||
_lw_encode(title.title()))
|
||||
html = fetch_url(url)
|
||||
if not html:
|
||||
return
|
||||
lyrics = extract_text_between(html, '"lyrics_body":', '"lyrics_language":')
|
||||
return lyrics.strip(',"').replace('\\n', '\n')
|
||||
|
||||
# LyricsWiki.
|
||||
|
||||
LYRICSWIKI_URL_PATTERN = 'http://lyrics.wikia.com/%s:%s'
|
||||
|
||||
|
||||
def _lw_encode(s):
|
||||
s = re.sub(r'\s+', '_', s)
|
||||
s = s.replace("<", "Less_Than")
|
||||
s = s.replace(">", "Greater_Than")
|
||||
s = s.replace("#", "Number_")
|
||||
s = re.sub(r'[\[\{]', '(', s)
|
||||
s = re.sub(r'[\]\}]', ')', s)
|
||||
return _encode(s)
|
||||
|
||||
|
||||
def fetch_lyricswiki(artist, title):
|
||||
"""Fetch lyrics from LyricsWiki."""
|
||||
url = LYRICSWIKI_URL_PATTERN % (_lw_encode(artist), _lw_encode(title))
|
||||
html = fetch_url(url)
|
||||
if not html:
|
||||
return
|
||||
|
||||
lyrics = extract_text_in(html, u"<div class='lyricbox'>")
|
||||
if lyrics and 'Unfortunately, we are not licensed' not in lyrics:
|
||||
return lyrics
|
||||
|
||||
|
||||
# Lyrics.com.
|
||||
|
||||
LYRICSCOM_URL_PATTERN = 'http://www.lyrics.com/%s-lyrics-%s.html'
|
||||
LYRICSCOM_NOT_FOUND = (
|
||||
'Sorry, we do not have the lyric',
|
||||
'Submit Lyrics',
|
||||
)
|
||||
|
||||
|
||||
def _lc_encode(s):
|
||||
s = re.sub(r'[^\w\s-]', '', s)
|
||||
s = re.sub(r'\s+', '-', s)
|
||||
return _encode(s).lower()
|
||||
|
||||
|
||||
def fetch_lyricscom(artist, title):
|
||||
"""Fetch lyrics from Lyrics.com."""
|
||||
url = LYRICSCOM_URL_PATTERN % (_lc_encode(title), _lc_encode(artist))
|
||||
html = fetch_url(url)
|
||||
if not html:
|
||||
return
|
||||
lyrics = extract_text_between(html, '<div id="lyrics" class="SCREENONLY" '
|
||||
'itemprop="description">', '</div>')
|
||||
if not lyrics:
|
||||
return
|
||||
for not_found_str in LYRICSCOM_NOT_FOUND:
|
||||
if not_found_str in lyrics:
|
||||
return
|
||||
|
||||
parts = lyrics.split('\n---\nLyrics powered by', 1)
|
||||
if parts:
|
||||
return parts[0]
|
||||
|
||||
|
||||
# Optional Google custom search API backend.
|
||||
|
||||
def slugify(text):
|
||||
"""Normalize a string and remove non-alphanumeric characters.
|
||||
"""
|
||||
text = re.sub(r"[-'_\s]", '_', text)
|
||||
text = re.sub(r"_+", '_', text).strip('_')
|
||||
pat = "([^,\(]*)\((.*?)\)" # Remove content within parentheses
|
||||
text = re.sub(pat, '\g<1>', text).strip()
|
||||
try:
|
||||
text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore')
|
||||
text = unicode(re.sub('[-\s]+', ' ', text))
|
||||
except UnicodeDecodeError:
|
||||
log.exception(u"Failing to normalize '{0}'".format(text))
|
||||
return text
|
||||
|
||||
|
||||
BY_TRANS = ['by', 'par', 'de', 'von']
|
||||
LYRICS_TRANS = ['lyrics', 'paroles', 'letras', 'liedtexte']
|
||||
|
||||
|
||||
def is_page_candidate(urlLink, urlTitle, title, artist):
|
||||
"""Return True if the URL title makes it a good candidate to be a
|
||||
page that contains lyrics of title by artist.
|
||||
"""
|
||||
title = slugify(title.lower())
|
||||
artist = slugify(artist.lower())
|
||||
sitename = re.search(u"//([^/]+)/.*", slugify(urlLink.lower())).group(1)
|
||||
urlTitle = slugify(urlTitle.lower())
|
||||
# Check if URL title contains song title (exact match)
|
||||
if urlTitle.find(title) != -1:
|
||||
return True
|
||||
# or try extracting song title from URL title and check if
|
||||
# they are close enough
|
||||
tokens = [by + '_' + artist for by in BY_TRANS] + \
|
||||
[artist, sitename, sitename.replace('www.', '')] + LYRICS_TRANS
|
||||
songTitle = re.sub(u'(%s)' % u'|'.join(tokens), u'', urlTitle)
|
||||
songTitle = songTitle.strip('_|')
|
||||
typoRatio = .9
|
||||
return difflib.SequenceMatcher(None, songTitle, title).ratio() >= typoRatio
|
||||
|
||||
|
||||
def remove_credits(text):
|
||||
"""Remove first/last line of text if it contains the word 'lyrics'
|
||||
eg 'Lyrics by songsdatabase.com'
|
||||
"""
|
||||
textlines = text.split('\n')
|
||||
credits = None
|
||||
for i in (0, -1):
|
||||
if textlines and 'lyrics' in textlines[i].lower():
|
||||
credits = textlines.pop(i)
|
||||
if credits:
|
||||
text = '\n'.join(textlines)
|
||||
return text
|
||||
|
||||
|
||||
def is_lyrics(text, artist=None):
|
||||
"""Determine whether the text seems to be valid lyrics.
|
||||
"""
|
||||
if not text:
|
||||
return False
|
||||
badTriggersOcc = []
|
||||
nbLines = text.count('\n')
|
||||
if nbLines <= 1:
|
||||
log.debug(u"Ignoring too short lyrics '{0}'".format(text))
|
||||
return False
|
||||
elif nbLines < 5:
|
||||
badTriggersOcc.append('too_short')
|
||||
else:
|
||||
# Lyrics look legit, remove credits to avoid being penalized further
|
||||
# down
|
||||
text = remove_credits(text)
|
||||
|
||||
badTriggers = ['lyrics', 'copyright', 'property', 'links']
|
||||
if artist:
|
||||
badTriggersOcc += [artist]
|
||||
|
||||
for item in badTriggers:
|
||||
badTriggersOcc += [item] * len(re.findall(r'\W%s\W' % item,
|
||||
text, re.I))
|
||||
|
||||
if badTriggersOcc:
|
||||
log.debug(u'Bad triggers detected: {0}'.format(badTriggersOcc))
|
||||
return len(badTriggersOcc) < 2
|
||||
|
||||
|
||||
def _scrape_strip_cruft(html, plain_text_out=False):
|
||||
"""Clean up HTML
|
||||
"""
|
||||
html = unescape(html)
|
||||
|
||||
html = html.replace('\r', '\n') # Normalize EOL.
|
||||
html = re.sub(r' +', ' ', html) # Whitespaces collapse.
|
||||
html = BREAK_RE.sub('\n', html) # <br> eats up surrounding '\n'.
|
||||
html = re.sub(r'<(script).*?</\1>(?s)', '', html) # Strip script tags.
|
||||
|
||||
if plain_text_out: # Strip remaining HTML tags
|
||||
html = COMMENT_RE.sub('', html)
|
||||
html = TAG_RE.sub('', html)
|
||||
|
||||
html = '\n'.join([x.strip() for x in html.strip().split('\n')])
|
||||
html = re.sub(r'\n{3,}', r'\n\n', html)
|
||||
return html
|
||||
|
||||
|
||||
def _scrape_merge_paragraphs(html):
|
||||
html = re.sub(r'</p>\s*<p(\s*[^>]*)>', '\n', html)
|
||||
return re.sub(r'<div .*>\s*</div>', '\n', html)
|
||||
|
||||
|
||||
def scrape_lyrics_from_html(html):
|
||||
"""Scrape lyrics from a URL. If no lyrics can be found, return None
|
||||
instead.
|
||||
"""
|
||||
from bs4 import SoupStrainer, BeautifulSoup
|
||||
|
||||
if not html:
|
||||
return None
|
||||
|
||||
def is_text_notcode(text):
|
||||
length = len(text)
|
||||
return (length > 20 and
|
||||
text.count(' ') > length / 25 and
|
||||
(text.find('{') == -1 or text.find(';') == -1))
|
||||
html = _scrape_strip_cruft(html)
|
||||
html = _scrape_merge_paragraphs(html)
|
||||
|
||||
# extract all long text blocks that are not code
|
||||
try:
|
||||
soup = BeautifulSoup(html, "html.parser",
|
||||
parse_only=SoupStrainer(text=is_text_notcode))
|
||||
except HTMLParseError:
|
||||
return None
|
||||
soup = sorted(soup.stripped_strings, key=len)[-1]
|
||||
return soup
|
||||
|
||||
|
||||
def fetch_google(artist, title):
|
||||
"""Fetch lyrics from Google search results.
|
||||
"""
|
||||
query = u"%s %s" % (artist, title)
|
||||
api_key = config['lyrics']['google_API_key'].get(unicode)
|
||||
engine_id = config['lyrics']['google_engine_ID'].get(unicode)
|
||||
url = u'https://www.googleapis.com/customsearch/v1?key=%s&cx=%s&q=%s' % \
|
||||
(api_key, engine_id, urllib.quote(query.encode('utf8')))
|
||||
|
||||
data = urllib.urlopen(url)
|
||||
data = json.load(data)
|
||||
if 'error' in data:
|
||||
reason = data['error']['errors'][0]['reason']
|
||||
log.debug(u'google lyrics backend error: {0}'.format(reason))
|
||||
return
|
||||
|
||||
if 'items' in data.keys():
|
||||
for item in data['items']:
|
||||
urlLink = item['link']
|
||||
urlTitle = item.get('title', u'')
|
||||
if not is_page_candidate(urlLink, urlTitle, title, artist):
|
||||
continue
|
||||
html = fetch_url(urlLink)
|
||||
lyrics = scrape_lyrics_from_html(html)
|
||||
if not lyrics:
|
||||
continue
|
||||
|
||||
if is_lyrics(lyrics, artist):
|
||||
log.debug(u'got lyrics from {0}'.format(item['displayLink']))
|
||||
return lyrics
|
||||
|
||||
|
||||
# Plugin logic.
|
||||
|
||||
SOURCES = ['google', 'lyricwiki', 'lyrics.com', 'musixmatch']
|
||||
SOURCE_BACKENDS = {
|
||||
'google': fetch_google,
|
||||
'lyricwiki': fetch_lyricswiki,
|
||||
'lyrics.com': fetch_lyricscom,
|
||||
'musixmatch': fetch_musixmatch,
|
||||
}
|
||||
|
||||
|
||||
class LyricsPlugin(plugins.BeetsPlugin):
|
||||
def __init__(self):
|
||||
super(LyricsPlugin, self).__init__()
|
||||
self.import_stages = [self.imported]
|
||||
self.config.add({
|
||||
'auto': True,
|
||||
'google_API_key': None,
|
||||
'google_engine_ID': u'009217259823014548361:lndtuqkycfu',
|
||||
'fallback': None,
|
||||
'force': False,
|
||||
'sources': SOURCES,
|
||||
})
|
||||
|
||||
available_sources = list(SOURCES)
|
||||
if not self.config['google_API_key'].get() and \
|
||||
'google' in SOURCES:
|
||||
available_sources.remove('google')
|
||||
self.config['sources'] = plugins.sanitize_choices(
|
||||
self.config['sources'].as_str_seq(), available_sources)
|
||||
self.backends = []
|
||||
for key in self.config['sources'].as_str_seq():
|
||||
self.backends.append(SOURCE_BACKENDS[key])
|
||||
|
||||
def commands(self):
|
||||
cmd = ui.Subcommand('lyrics', help='fetch song lyrics')
|
||||
cmd.parser.add_option('-p', '--print', dest='printlyr',
|
||||
action='store_true', default=False,
|
||||
help='print lyrics to console')
|
||||
cmd.parser.add_option('-f', '--force', dest='force_refetch',
|
||||
action='store_true', default=False,
|
||||
help='always re-download lyrics')
|
||||
|
||||
def func(lib, opts, args):
|
||||
# The "write to files" option corresponds to the
|
||||
# import_write config value.
|
||||
write = config['import']['write'].get(bool)
|
||||
for item in lib.items(ui.decargs(args)):
|
||||
self.fetch_item_lyrics(
|
||||
lib, logging.INFO, item, write,
|
||||
opts.force_refetch or self.config['force'],
|
||||
)
|
||||
if opts.printlyr and item.lyrics:
|
||||
ui.print_(item.lyrics)
|
||||
|
||||
cmd.func = func
|
||||
return [cmd]
|
||||
|
||||
def imported(self, session, task):
|
||||
"""Import hook for fetching lyrics automatically.
|
||||
"""
|
||||
if self.config['auto']:
|
||||
for item in task.imported_items():
|
||||
self.fetch_item_lyrics(session.lib, logging.DEBUG, item,
|
||||
False, self.config['force'])
|
||||
|
||||
def fetch_item_lyrics(self, lib, loglevel, item, write, force):
|
||||
"""Fetch and store lyrics for a single item. If ``write``, then the
|
||||
lyrics will also be written to the file itself. The ``loglevel``
|
||||
parameter controls the visibility of the function's status log
|
||||
messages.
|
||||
"""
|
||||
# Skip if the item already has lyrics.
|
||||
if not force and item.lyrics:
|
||||
log.log(loglevel, u'lyrics already present: {0} - {1}'
|
||||
.format(item.artist, item.title))
|
||||
return
|
||||
|
||||
lyrics = None
|
||||
for artist, titles in search_pairs(item):
|
||||
lyrics = [self.get_lyrics(artist, title) for title in titles]
|
||||
if any(lyrics):
|
||||
break
|
||||
|
||||
lyrics = u"\n\n---\n\n".join([l for l in lyrics if l])
|
||||
|
||||
if lyrics:
|
||||
log.log(loglevel, u'fetched lyrics: {0} - {1}'
|
||||
.format(item.artist, item.title))
|
||||
else:
|
||||
log.log(loglevel, u'lyrics not found: {0} - {1}'
|
||||
.format(item.artist, item.title))
|
||||
fallback = self.config['fallback'].get()
|
||||
if fallback:
|
||||
lyrics = fallback
|
||||
else:
|
||||
return
|
||||
|
||||
item.lyrics = lyrics
|
||||
|
||||
if write:
|
||||
item.try_write()
|
||||
item.store()
|
||||
|
||||
def get_lyrics(self, artist, title):
|
||||
"""Fetch lyrics, trying each source in turn. Return a string or
|
||||
None if no lyrics were found.
|
||||
"""
|
||||
for backend in self.backends:
|
||||
lyrics = backend(artist, title)
|
||||
if lyrics:
|
||||
log.debug(u'got lyrics from backend: {0}'
|
||||
.format(backend.__name__))
|
||||
return _scrape_strip_cruft(lyrics, True)
|
||||
@@ -14,9 +14,6 @@
|
||||
|
||||
# Minor modifications made by Andrew Resch to replace the BTFailure errors with Exceptions
|
||||
|
||||
from types import StringType, IntType, LongType, DictType, ListType, TupleType
|
||||
|
||||
|
||||
def decode_int(x, f):
|
||||
f += 1
|
||||
newf = x.index('e', f)
|
||||
@@ -24,36 +21,32 @@ def decode_int(x, f):
|
||||
if x[f] == '-':
|
||||
if x[f + 1] == '0':
|
||||
raise ValueError
|
||||
elif x[f] == '0' and newf != f + 1:
|
||||
elif x[f] == '0' and newf != f+1:
|
||||
raise ValueError
|
||||
return (n, newf + 1)
|
||||
|
||||
return (n, newf+1)
|
||||
|
||||
def decode_string(x, f):
|
||||
colon = x.index(':', f)
|
||||
n = int(x[f:colon])
|
||||
if x[f] == '0' and colon != f + 1:
|
||||
if x[f] == '0' and colon != f+1:
|
||||
raise ValueError
|
||||
colon += 1
|
||||
return (x[colon:colon + n], colon + n)
|
||||
|
||||
return (x[colon:colon+n], colon+n)
|
||||
|
||||
def decode_list(x, f):
|
||||
r, f = [], f + 1
|
||||
r, f = [], f+1
|
||||
while x[f] != 'e':
|
||||
v, f = decode_func[x[f]](x, f)
|
||||
r.append(v)
|
||||
return (r, f + 1)
|
||||
|
||||
|
||||
def decode_dict(x, f):
|
||||
r, f = {}, f + 1
|
||||
r, f = {}, f+1
|
||||
while x[f] != 'e':
|
||||
k, f = decode_string(x, f)
|
||||
r[k], f = decode_func[x[f]](x, f)
|
||||
return (r, f + 1)
|
||||
|
||||
|
||||
decode_func = {}
|
||||
decode_func['l'] = decode_list
|
||||
decode_func['d'] = decode_dict
|
||||
@@ -69,7 +62,6 @@ decode_func['7'] = decode_string
|
||||
decode_func['8'] = decode_string
|
||||
decode_func['9'] = decode_string
|
||||
|
||||
|
||||
def bdecode(x):
|
||||
try:
|
||||
r, l = decode_func[x[0]](x, 0)
|
||||
@@ -78,41 +70,38 @@ def bdecode(x):
|
||||
|
||||
return r
|
||||
|
||||
from types import StringType, IntType, LongType, DictType, ListType, TupleType
|
||||
|
||||
|
||||
class Bencached(object):
|
||||
|
||||
__slots__ = ['bencoded']
|
||||
|
||||
def __init__(self, s):
|
||||
self.bencoded = s
|
||||
|
||||
|
||||
def encode_bencached(x, r):
|
||||
def encode_bencached(x,r):
|
||||
r.append(x.bencoded)
|
||||
|
||||
|
||||
def encode_int(x, r):
|
||||
r.extend(('i', str(x), 'e'))
|
||||
|
||||
|
||||
def encode_bool(x, r):
|
||||
if x:
|
||||
encode_int(1, r)
|
||||
else:
|
||||
encode_int(0, r)
|
||||
|
||||
|
||||
def encode_string(x, r):
|
||||
r.extend((str(len(x)), ':', x))
|
||||
|
||||
|
||||
def encode_list(x, r):
|
||||
r.append('l')
|
||||
for i in x:
|
||||
encode_func[type(i)](i, r)
|
||||
r.append('e')
|
||||
|
||||
|
||||
def encode_dict(x, r):
|
||||
def encode_dict(x,r):
|
||||
r.append('d')
|
||||
ilist = x.items()
|
||||
ilist.sort()
|
||||
@@ -121,7 +110,6 @@ def encode_dict(x, r):
|
||||
encode_func[type(v)](v, r)
|
||||
r.append('e')
|
||||
|
||||
|
||||
encode_func = {}
|
||||
encode_func[Bencached] = encode_bencached
|
||||
encode_func[IntType] = encode_int
|
||||
@@ -133,12 +121,10 @@ encode_func[DictType] = encode_dict
|
||||
|
||||
try:
|
||||
from types import BooleanType
|
||||
|
||||
encode_func[BooleanType] = encode_bool
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def bencode(x):
|
||||
r = []
|
||||
encode_func[type(x)](x, r)
|
||||
Executable
+803
@@ -0,0 +1,803 @@
|
||||
"""biplist -- a library for reading and writing binary property list files.
|
||||
|
||||
Binary Property List (plist) files provide a faster and smaller serialization
|
||||
format for property lists on OS X. This is a library for generating binary
|
||||
plists which can be read by OS X, iOS, or other clients.
|
||||
|
||||
The API models the plistlib API, and will call through to plistlib when
|
||||
XML serialization or deserialization is required.
|
||||
|
||||
To generate plists with UID values, wrap the values with the Uid object. The
|
||||
value must be an int.
|
||||
|
||||
To generate plists with NSData/CFData values, wrap the values with the
|
||||
Data object. The value must be a string.
|
||||
|
||||
Date values can only be datetime.datetime objects.
|
||||
|
||||
The exceptions InvalidPlistException and NotBinaryPlistException may be
|
||||
thrown to indicate that the data cannot be serialized or deserialized as
|
||||
a binary plist.
|
||||
|
||||
Plist generation example:
|
||||
|
||||
from biplist import *
|
||||
from datetime import datetime
|
||||
plist = {'aKey':'aValue',
|
||||
'0':1.322,
|
||||
'now':datetime.now(),
|
||||
'list':[1,2,3],
|
||||
'tuple':('a','b','c')
|
||||
}
|
||||
try:
|
||||
writePlist(plist, "example.plist")
|
||||
except (InvalidPlistException, NotBinaryPlistException), e:
|
||||
print "Something bad happened:", e
|
||||
|
||||
Plist parsing example:
|
||||
|
||||
from biplist import *
|
||||
try:
|
||||
plist = readPlist("example.plist")
|
||||
print plist
|
||||
except (InvalidPlistException, NotBinaryPlistException), e:
|
||||
print "Not a plist:", e
|
||||
"""
|
||||
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
import datetime
|
||||
import io
|
||||
import math
|
||||
import plistlib
|
||||
from struct import pack, unpack
|
||||
from struct import error as struct_error
|
||||
import sys
|
||||
import time
|
||||
|
||||
try:
|
||||
unicode
|
||||
unicodeEmpty = r''
|
||||
except NameError:
|
||||
unicode = str
|
||||
unicodeEmpty = ''
|
||||
try:
|
||||
long
|
||||
except NameError:
|
||||
long = int
|
||||
try:
|
||||
{}.iteritems
|
||||
iteritems = lambda x: x.iteritems()
|
||||
except AttributeError:
|
||||
iteritems = lambda x: x.items()
|
||||
|
||||
__all__ = [
|
||||
'Uid', 'Data', 'readPlist', 'writePlist', 'readPlistFromString',
|
||||
'writePlistToString', 'InvalidPlistException', 'NotBinaryPlistException'
|
||||
]
|
||||
|
||||
# Apple uses Jan 1, 2001 as a base for all plist date/times.
|
||||
apple_reference_date = datetime.datetime.utcfromtimestamp(978307200)
|
||||
|
||||
class Uid(int):
|
||||
"""Wrapper around integers for representing UID values. This
|
||||
is used in keyed archiving."""
|
||||
def __repr__(self):
|
||||
return "Uid(%d)" % self
|
||||
|
||||
class Data(bytes):
|
||||
"""Wrapper around str types for representing Data values."""
|
||||
pass
|
||||
|
||||
class InvalidPlistException(Exception):
|
||||
"""Raised when the plist is incorrectly formatted."""
|
||||
pass
|
||||
|
||||
class NotBinaryPlistException(Exception):
|
||||
"""Raised when a binary plist was expected but not encountered."""
|
||||
pass
|
||||
|
||||
def readPlist(pathOrFile):
|
||||
"""Raises NotBinaryPlistException, InvalidPlistException"""
|
||||
didOpen = False
|
||||
result = None
|
||||
if isinstance(pathOrFile, (bytes, unicode)):
|
||||
pathOrFile = open(pathOrFile, 'rb')
|
||||
didOpen = True
|
||||
try:
|
||||
reader = PlistReader(pathOrFile)
|
||||
result = reader.parse()
|
||||
except NotBinaryPlistException as e:
|
||||
try:
|
||||
pathOrFile.seek(0)
|
||||
result = None
|
||||
if hasattr(plistlib, 'loads'):
|
||||
contents = None
|
||||
if isinstance(pathOrFile, (bytes, unicode)):
|
||||
with open(pathOrFile, 'rb') as f:
|
||||
contents = f.read()
|
||||
else:
|
||||
contents = pathOrFile.read()
|
||||
result = plistlib.loads(contents)
|
||||
else:
|
||||
result = plistlib.readPlist(pathOrFile)
|
||||
result = wrapDataObject(result, for_binary=True)
|
||||
except Exception as e:
|
||||
raise InvalidPlistException(e)
|
||||
finally:
|
||||
if didOpen:
|
||||
pathOrFile.close()
|
||||
return result
|
||||
|
||||
def wrapDataObject(o, for_binary=False):
|
||||
if isinstance(o, Data) and not for_binary:
|
||||
v = sys.version_info
|
||||
if not (v[0] >= 3 and v[1] >= 4):
|
||||
o = plistlib.Data(o)
|
||||
elif isinstance(o, (bytes, plistlib.Data)) and for_binary:
|
||||
if hasattr(o, 'data'):
|
||||
o = Data(o.data)
|
||||
elif isinstance(o, tuple):
|
||||
o = wrapDataObject(list(o), for_binary)
|
||||
o = tuple(o)
|
||||
elif isinstance(o, list):
|
||||
for i in range(len(o)):
|
||||
o[i] = wrapDataObject(o[i], for_binary)
|
||||
elif isinstance(o, dict):
|
||||
for k in o:
|
||||
o[k] = wrapDataObject(o[k], for_binary)
|
||||
return o
|
||||
|
||||
def writePlist(rootObject, pathOrFile, binary=True):
|
||||
if not binary:
|
||||
rootObject = wrapDataObject(rootObject, binary)
|
||||
if hasattr(plistlib, "dump"):
|
||||
if isinstance(pathOrFile, (bytes, unicode)):
|
||||
with open(pathOrFile, 'wb') as f:
|
||||
return plistlib.dump(rootObject, f)
|
||||
else:
|
||||
return plistlib.dump(rootObject, pathOrFile)
|
||||
else:
|
||||
return plistlib.writePlist(rootObject, pathOrFile)
|
||||
else:
|
||||
didOpen = False
|
||||
if isinstance(pathOrFile, (bytes, unicode)):
|
||||
pathOrFile = open(pathOrFile, 'wb')
|
||||
didOpen = True
|
||||
writer = PlistWriter(pathOrFile)
|
||||
result = writer.writeRoot(rootObject)
|
||||
if didOpen:
|
||||
pathOrFile.close()
|
||||
return result
|
||||
|
||||
def readPlistFromString(data):
|
||||
return readPlist(io.BytesIO(data))
|
||||
|
||||
def writePlistToString(rootObject, binary=True):
|
||||
if not binary:
|
||||
rootObject = wrapDataObject(rootObject, binary)
|
||||
if hasattr(plistlib, "dumps"):
|
||||
return plistlib.dumps(rootObject)
|
||||
elif hasattr(plistlib, "writePlistToBytes"):
|
||||
return plistlib.writePlistToBytes(rootObject)
|
||||
else:
|
||||
return plistlib.writePlistToString(rootObject)
|
||||
else:
|
||||
ioObject = io.BytesIO()
|
||||
writer = PlistWriter(ioObject)
|
||||
writer.writeRoot(rootObject)
|
||||
return ioObject.getvalue()
|
||||
|
||||
def is_stream_binary_plist(stream):
|
||||
stream.seek(0)
|
||||
header = stream.read(7)
|
||||
if header == b'bplist0':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
PlistTrailer = namedtuple('PlistTrailer', 'offsetSize, objectRefSize, offsetCount, topLevelObjectNumber, offsetTableOffset')
|
||||
PlistByteCounts = namedtuple('PlistByteCounts', 'nullBytes, boolBytes, intBytes, realBytes, dateBytes, dataBytes, stringBytes, uidBytes, arrayBytes, setBytes, dictBytes')
|
||||
|
||||
class PlistReader(object):
|
||||
file = None
|
||||
contents = ''
|
||||
offsets = None
|
||||
trailer = None
|
||||
currentOffset = 0
|
||||
|
||||
def __init__(self, fileOrStream):
|
||||
"""Raises NotBinaryPlistException."""
|
||||
self.reset()
|
||||
self.file = fileOrStream
|
||||
|
||||
def parse(self):
|
||||
return self.readRoot()
|
||||
|
||||
def reset(self):
|
||||
self.trailer = None
|
||||
self.contents = ''
|
||||
self.offsets = []
|
||||
self.currentOffset = 0
|
||||
|
||||
def readRoot(self):
|
||||
result = None
|
||||
self.reset()
|
||||
# Get the header, make sure it's a valid file.
|
||||
if not is_stream_binary_plist(self.file):
|
||||
raise NotBinaryPlistException()
|
||||
self.file.seek(0)
|
||||
self.contents = self.file.read()
|
||||
if len(self.contents) < 32:
|
||||
raise InvalidPlistException("File is too short.")
|
||||
trailerContents = self.contents[-32:]
|
||||
try:
|
||||
self.trailer = PlistTrailer._make(unpack("!xxxxxxBBQQQ", trailerContents))
|
||||
offset_size = self.trailer.offsetSize * self.trailer.offsetCount
|
||||
offset = self.trailer.offsetTableOffset
|
||||
offset_contents = self.contents[offset:offset+offset_size]
|
||||
offset_i = 0
|
||||
while offset_i < self.trailer.offsetCount:
|
||||
begin = self.trailer.offsetSize*offset_i
|
||||
tmp_contents = offset_contents[begin:begin+self.trailer.offsetSize]
|
||||
tmp_sized = self.getSizedInteger(tmp_contents, self.trailer.offsetSize)
|
||||
self.offsets.append(tmp_sized)
|
||||
offset_i += 1
|
||||
self.setCurrentOffsetToObjectNumber(self.trailer.topLevelObjectNumber)
|
||||
result = self.readObject()
|
||||
except TypeError as e:
|
||||
raise InvalidPlistException(e)
|
||||
return result
|
||||
|
||||
def setCurrentOffsetToObjectNumber(self, objectNumber):
|
||||
self.currentOffset = self.offsets[objectNumber]
|
||||
|
||||
def readObject(self):
|
||||
result = None
|
||||
tmp_byte = self.contents[self.currentOffset:self.currentOffset+1]
|
||||
marker_byte = unpack("!B", tmp_byte)[0]
|
||||
format = (marker_byte >> 4) & 0x0f
|
||||
extra = marker_byte & 0x0f
|
||||
self.currentOffset += 1
|
||||
|
||||
def proc_extra(extra):
|
||||
if extra == 0b1111:
|
||||
#self.currentOffset += 1
|
||||
extra = self.readObject()
|
||||
return extra
|
||||
|
||||
# bool, null, or fill byte
|
||||
if format == 0b0000:
|
||||
if extra == 0b0000:
|
||||
result = None
|
||||
elif extra == 0b1000:
|
||||
result = False
|
||||
elif extra == 0b1001:
|
||||
result = True
|
||||
elif extra == 0b1111:
|
||||
pass # fill byte
|
||||
else:
|
||||
raise InvalidPlistException("Invalid object found at offset: %d" % (self.currentOffset - 1))
|
||||
# int
|
||||
elif format == 0b0001:
|
||||
extra = proc_extra(extra)
|
||||
result = self.readInteger(pow(2, extra))
|
||||
# real
|
||||
elif format == 0b0010:
|
||||
extra = proc_extra(extra)
|
||||
result = self.readReal(extra)
|
||||
# date
|
||||
elif format == 0b0011 and extra == 0b0011:
|
||||
result = self.readDate()
|
||||
# data
|
||||
elif format == 0b0100:
|
||||
extra = proc_extra(extra)
|
||||
result = self.readData(extra)
|
||||
# ascii string
|
||||
elif format == 0b0101:
|
||||
extra = proc_extra(extra)
|
||||
result = self.readAsciiString(extra)
|
||||
# Unicode string
|
||||
elif format == 0b0110:
|
||||
extra = proc_extra(extra)
|
||||
result = self.readUnicode(extra)
|
||||
# uid
|
||||
elif format == 0b1000:
|
||||
result = self.readUid(extra)
|
||||
# array
|
||||
elif format == 0b1010:
|
||||
extra = proc_extra(extra)
|
||||
result = self.readArray(extra)
|
||||
# set
|
||||
elif format == 0b1100:
|
||||
extra = proc_extra(extra)
|
||||
result = set(self.readArray(extra))
|
||||
# dict
|
||||
elif format == 0b1101:
|
||||
extra = proc_extra(extra)
|
||||
result = self.readDict(extra)
|
||||
else:
|
||||
raise InvalidPlistException("Invalid object found: {format: %s, extra: %s}" % (bin(format), bin(extra)))
|
||||
return result
|
||||
|
||||
def readInteger(self, byteSize):
|
||||
result = 0
|
||||
original_offset = self.currentOffset
|
||||
data = self.contents[self.currentOffset:self.currentOffset + byteSize]
|
||||
result = self.getSizedInteger(data, byteSize, as_number=True)
|
||||
self.currentOffset = original_offset + byteSize
|
||||
return result
|
||||
|
||||
def readReal(self, length):
|
||||
result = 0.0
|
||||
to_read = pow(2, length)
|
||||
data = self.contents[self.currentOffset:self.currentOffset+to_read]
|
||||
if length == 2: # 4 bytes
|
||||
result = unpack('>f', data)[0]
|
||||
elif length == 3: # 8 bytes
|
||||
result = unpack('>d', data)[0]
|
||||
else:
|
||||
raise InvalidPlistException("Unknown real of length %d bytes" % to_read)
|
||||
return result
|
||||
|
||||
def readRefs(self, count):
|
||||
refs = []
|
||||
i = 0
|
||||
while i < count:
|
||||
fragment = self.contents[self.currentOffset:self.currentOffset+self.trailer.objectRefSize]
|
||||
ref = self.getSizedInteger(fragment, len(fragment))
|
||||
refs.append(ref)
|
||||
self.currentOffset += self.trailer.objectRefSize
|
||||
i += 1
|
||||
return refs
|
||||
|
||||
def readArray(self, count):
|
||||
result = []
|
||||
values = self.readRefs(count)
|
||||
i = 0
|
||||
while i < len(values):
|
||||
self.setCurrentOffsetToObjectNumber(values[i])
|
||||
value = self.readObject()
|
||||
result.append(value)
|
||||
i += 1
|
||||
return result
|
||||
|
||||
def readDict(self, count):
|
||||
result = {}
|
||||
keys = self.readRefs(count)
|
||||
values = self.readRefs(count)
|
||||
i = 0
|
||||
while i < len(keys):
|
||||
self.setCurrentOffsetToObjectNumber(keys[i])
|
||||
key = self.readObject()
|
||||
self.setCurrentOffsetToObjectNumber(values[i])
|
||||
value = self.readObject()
|
||||
result[key] = value
|
||||
i += 1
|
||||
return result
|
||||
|
||||
def readAsciiString(self, length):
|
||||
result = unpack("!%ds" % length, self.contents[self.currentOffset:self.currentOffset+length])[0]
|
||||
self.currentOffset += length
|
||||
return result
|
||||
|
||||
def readUnicode(self, length):
|
||||
actual_length = length*2
|
||||
data = self.contents[self.currentOffset:self.currentOffset+actual_length]
|
||||
# unpack not needed?!! data = unpack(">%ds" % (actual_length), data)[0]
|
||||
self.currentOffset += actual_length
|
||||
return data.decode('utf_16_be')
|
||||
|
||||
def readDate(self):
|
||||
result = unpack(">d", self.contents[self.currentOffset:self.currentOffset+8])[0]
|
||||
# Use timedelta to workaround time_t size limitation on 32-bit python.
|
||||
result = datetime.timedelta(seconds=result) + apple_reference_date
|
||||
self.currentOffset += 8
|
||||
return result
|
||||
|
||||
def readData(self, length):
|
||||
result = self.contents[self.currentOffset:self.currentOffset+length]
|
||||
self.currentOffset += length
|
||||
return Data(result)
|
||||
|
||||
def readUid(self, length):
|
||||
return Uid(self.readInteger(length+1))
|
||||
|
||||
def getSizedInteger(self, data, byteSize, as_number=False):
|
||||
"""Numbers of 8 bytes are signed integers when they refer to numbers, but unsigned otherwise."""
|
||||
result = 0
|
||||
# 1, 2, and 4 byte integers are unsigned
|
||||
if byteSize == 1:
|
||||
result = unpack('>B', data)[0]
|
||||
elif byteSize == 2:
|
||||
result = unpack('>H', data)[0]
|
||||
elif byteSize == 4:
|
||||
result = unpack('>L', data)[0]
|
||||
elif byteSize == 8:
|
||||
if as_number:
|
||||
result = unpack('>q', data)[0]
|
||||
else:
|
||||
result = unpack('>Q', data)[0]
|
||||
elif byteSize <= 16:
|
||||
# Handle odd-sized or integers larger than 8 bytes
|
||||
# Don't naively go over 16 bytes, in order to prevent infinite loops.
|
||||
result = 0
|
||||
if hasattr(int, 'from_bytes'):
|
||||
result = int.from_bytes(data, 'big')
|
||||
else:
|
||||
for byte in data:
|
||||
result = (result << 8) | unpack('>B', byte)[0]
|
||||
else:
|
||||
raise InvalidPlistException("Encountered integer longer than 16 bytes.")
|
||||
return result
|
||||
|
||||
class HashableWrapper(object):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
def __repr__(self):
|
||||
return "<HashableWrapper: %s>" % [self.value]
|
||||
|
||||
class BoolWrapper(object):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
def __repr__(self):
|
||||
return "<BoolWrapper: %s>" % self.value
|
||||
|
||||
class FloatWrapper(object):
|
||||
_instances = {}
|
||||
def __new__(klass, value):
|
||||
# Ensure FloatWrapper(x) for a given float x is always the same object
|
||||
wrapper = klass._instances.get(value)
|
||||
if wrapper is None:
|
||||
wrapper = object.__new__(klass)
|
||||
wrapper.value = value
|
||||
klass._instances[value] = wrapper
|
||||
return wrapper
|
||||
def __repr__(self):
|
||||
return "<FloatWrapper: %s>" % self.value
|
||||
|
||||
class PlistWriter(object):
|
||||
header = b'bplist00bybiplist1.0'
|
||||
file = None
|
||||
byteCounts = None
|
||||
trailer = None
|
||||
computedUniques = None
|
||||
writtenReferences = None
|
||||
referencePositions = None
|
||||
wrappedTrue = None
|
||||
wrappedFalse = None
|
||||
|
||||
def __init__(self, file):
|
||||
self.reset()
|
||||
self.file = file
|
||||
self.wrappedTrue = BoolWrapper(True)
|
||||
self.wrappedFalse = BoolWrapper(False)
|
||||
|
||||
def reset(self):
|
||||
self.byteCounts = PlistByteCounts(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||
self.trailer = PlistTrailer(0, 0, 0, 0, 0)
|
||||
|
||||
# A set of all the uniques which have been computed.
|
||||
self.computedUniques = set()
|
||||
# A list of all the uniques which have been written.
|
||||
self.writtenReferences = {}
|
||||
# A dict of the positions of the written uniques.
|
||||
self.referencePositions = {}
|
||||
|
||||
def positionOfObjectReference(self, obj):
|
||||
"""If the given object has been written already, return its
|
||||
position in the offset table. Otherwise, return None."""
|
||||
return self.writtenReferences.get(obj)
|
||||
|
||||
def writeRoot(self, root):
|
||||
"""
|
||||
Strategy is:
|
||||
- write header
|
||||
- wrap root object so everything is hashable
|
||||
- compute size of objects which will be written
|
||||
- need to do this in order to know how large the object refs
|
||||
will be in the list/dict/set reference lists
|
||||
- write objects
|
||||
- keep objects in writtenReferences
|
||||
- keep positions of object references in referencePositions
|
||||
- write object references with the length computed previously
|
||||
- computer object reference length
|
||||
- write object reference positions
|
||||
- write trailer
|
||||
"""
|
||||
output = self.header
|
||||
wrapped_root = self.wrapRoot(root)
|
||||
should_reference_root = True#not isinstance(wrapped_root, HashableWrapper)
|
||||
self.computeOffsets(wrapped_root, asReference=should_reference_root, isRoot=True)
|
||||
self.trailer = self.trailer._replace(**{'objectRefSize':self.intSize(len(self.computedUniques))})
|
||||
(_, output) = self.writeObjectReference(wrapped_root, output)
|
||||
output = self.writeObject(wrapped_root, output, setReferencePosition=True)
|
||||
|
||||
# output size at this point is an upper bound on how big the
|
||||
# object reference offsets need to be.
|
||||
self.trailer = self.trailer._replace(**{
|
||||
'offsetSize':self.intSize(len(output)),
|
||||
'offsetCount':len(self.computedUniques),
|
||||
'offsetTableOffset':len(output),
|
||||
'topLevelObjectNumber':0
|
||||
})
|
||||
|
||||
output = self.writeOffsetTable(output)
|
||||
output += pack('!xxxxxxBBQQQ', *self.trailer)
|
||||
self.file.write(output)
|
||||
|
||||
def wrapRoot(self, root):
|
||||
if isinstance(root, bool):
|
||||
if root is True:
|
||||
return self.wrappedTrue
|
||||
else:
|
||||
return self.wrappedFalse
|
||||
elif isinstance(root, float):
|
||||
return FloatWrapper(root)
|
||||
elif isinstance(root, set):
|
||||
n = set()
|
||||
for value in root:
|
||||
n.add(self.wrapRoot(value))
|
||||
return HashableWrapper(n)
|
||||
elif isinstance(root, dict):
|
||||
n = {}
|
||||
for key, value in iteritems(root):
|
||||
n[self.wrapRoot(key)] = self.wrapRoot(value)
|
||||
return HashableWrapper(n)
|
||||
elif isinstance(root, list):
|
||||
n = []
|
||||
for value in root:
|
||||
n.append(self.wrapRoot(value))
|
||||
return HashableWrapper(n)
|
||||
elif isinstance(root, tuple):
|
||||
n = tuple([self.wrapRoot(value) for value in root])
|
||||
return HashableWrapper(n)
|
||||
else:
|
||||
return root
|
||||
|
||||
def incrementByteCount(self, field, incr=1):
|
||||
self.byteCounts = self.byteCounts._replace(**{field:self.byteCounts.__getattribute__(field) + incr})
|
||||
|
||||
def computeOffsets(self, obj, asReference=False, isRoot=False):
|
||||
def check_key(key):
|
||||
if key is None:
|
||||
raise InvalidPlistException('Dictionary keys cannot be null in plists.')
|
||||
elif isinstance(key, Data):
|
||||
raise InvalidPlistException('Data cannot be dictionary keys in plists.')
|
||||
elif not isinstance(key, (bytes, unicode)):
|
||||
raise InvalidPlistException('Keys must be strings.')
|
||||
|
||||
def proc_size(size):
|
||||
if size > 0b1110:
|
||||
size += self.intSize(size)
|
||||
return size
|
||||
# If this should be a reference, then we keep a record of it in the
|
||||
# uniques table.
|
||||
if asReference:
|
||||
if obj in self.computedUniques:
|
||||
return
|
||||
else:
|
||||
self.computedUniques.add(obj)
|
||||
|
||||
if obj is None:
|
||||
self.incrementByteCount('nullBytes')
|
||||
elif isinstance(obj, BoolWrapper):
|
||||
self.incrementByteCount('boolBytes')
|
||||
elif isinstance(obj, Uid):
|
||||
size = self.intSize(obj)
|
||||
self.incrementByteCount('uidBytes', incr=1+size)
|
||||
elif isinstance(obj, (int, long)):
|
||||
size = self.intSize(obj)
|
||||
self.incrementByteCount('intBytes', incr=1+size)
|
||||
elif isinstance(obj, FloatWrapper):
|
||||
size = self.realSize(obj)
|
||||
self.incrementByteCount('realBytes', incr=1+size)
|
||||
elif isinstance(obj, datetime.datetime):
|
||||
self.incrementByteCount('dateBytes', incr=2)
|
||||
elif isinstance(obj, Data):
|
||||
size = proc_size(len(obj))
|
||||
self.incrementByteCount('dataBytes', incr=1+size)
|
||||
elif isinstance(obj, (unicode, bytes)):
|
||||
size = proc_size(len(obj))
|
||||
self.incrementByteCount('stringBytes', incr=1+size)
|
||||
elif isinstance(obj, HashableWrapper):
|
||||
obj = obj.value
|
||||
if isinstance(obj, set):
|
||||
size = proc_size(len(obj))
|
||||
self.incrementByteCount('setBytes', incr=1+size)
|
||||
for value in obj:
|
||||
self.computeOffsets(value, asReference=True)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
size = proc_size(len(obj))
|
||||
self.incrementByteCount('arrayBytes', incr=1+size)
|
||||
for value in obj:
|
||||
asRef = True
|
||||
self.computeOffsets(value, asReference=True)
|
||||
elif isinstance(obj, dict):
|
||||
size = proc_size(len(obj))
|
||||
self.incrementByteCount('dictBytes', incr=1+size)
|
||||
for key, value in iteritems(obj):
|
||||
check_key(key)
|
||||
self.computeOffsets(key, asReference=True)
|
||||
self.computeOffsets(value, asReference=True)
|
||||
else:
|
||||
raise InvalidPlistException("Unknown object type.")
|
||||
|
||||
def writeObjectReference(self, obj, output):
|
||||
"""Tries to write an object reference, adding it to the references
|
||||
table. Does not write the actual object bytes or set the reference
|
||||
position. Returns a tuple of whether the object was a new reference
|
||||
(True if it was, False if it already was in the reference table)
|
||||
and the new output.
|
||||
"""
|
||||
position = self.positionOfObjectReference(obj)
|
||||
if position is None:
|
||||
self.writtenReferences[obj] = len(self.writtenReferences)
|
||||
output += self.binaryInt(len(self.writtenReferences) - 1, byteSize=self.trailer.objectRefSize)
|
||||
return (True, output)
|
||||
else:
|
||||
output += self.binaryInt(position, byteSize=self.trailer.objectRefSize)
|
||||
return (False, output)
|
||||
|
||||
def writeObject(self, obj, output, setReferencePosition=False):
|
||||
"""Serializes the given object to the output. Returns output.
|
||||
If setReferencePosition is True, will set the position the
|
||||
object was written.
|
||||
"""
|
||||
def proc_variable_length(format, length):
|
||||
result = b''
|
||||
if length > 0b1110:
|
||||
result += pack('!B', (format << 4) | 0b1111)
|
||||
result = self.writeObject(length, result)
|
||||
else:
|
||||
result += pack('!B', (format << 4) | length)
|
||||
return result
|
||||
|
||||
if isinstance(obj, (str, unicode)) and obj == unicodeEmpty:
|
||||
# The Apple Plist decoder can't decode a zero length Unicode string.
|
||||
obj = b''
|
||||
|
||||
if setReferencePosition:
|
||||
self.referencePositions[obj] = len(output)
|
||||
|
||||
if obj is None:
|
||||
output += pack('!B', 0b00000000)
|
||||
elif isinstance(obj, BoolWrapper):
|
||||
if obj.value is False:
|
||||
output += pack('!B', 0b00001000)
|
||||
else:
|
||||
output += pack('!B', 0b00001001)
|
||||
elif isinstance(obj, Uid):
|
||||
size = self.intSize(obj)
|
||||
output += pack('!B', (0b1000 << 4) | size - 1)
|
||||
output += self.binaryInt(obj)
|
||||
elif isinstance(obj, (int, long)):
|
||||
byteSize = self.intSize(obj)
|
||||
root = math.log(byteSize, 2)
|
||||
output += pack('!B', (0b0001 << 4) | int(root))
|
||||
output += self.binaryInt(obj, as_number=True)
|
||||
elif isinstance(obj, FloatWrapper):
|
||||
# just use doubles
|
||||
output += pack('!B', (0b0010 << 4) | 3)
|
||||
output += self.binaryReal(obj)
|
||||
elif isinstance(obj, datetime.datetime):
|
||||
timestamp = (obj - apple_reference_date).total_seconds()
|
||||
output += pack('!B', 0b00110011)
|
||||
output += pack('!d', float(timestamp))
|
||||
elif isinstance(obj, Data):
|
||||
output += proc_variable_length(0b0100, len(obj))
|
||||
output += obj
|
||||
elif isinstance(obj, unicode):
|
||||
byteData = obj.encode('utf_16_be')
|
||||
output += proc_variable_length(0b0110, len(byteData)//2)
|
||||
output += byteData
|
||||
elif isinstance(obj, bytes):
|
||||
output += proc_variable_length(0b0101, len(obj))
|
||||
output += obj
|
||||
elif isinstance(obj, HashableWrapper):
|
||||
obj = obj.value
|
||||
if isinstance(obj, (set, list, tuple)):
|
||||
if isinstance(obj, set):
|
||||
output += proc_variable_length(0b1100, len(obj))
|
||||
else:
|
||||
output += proc_variable_length(0b1010, len(obj))
|
||||
|
||||
objectsToWrite = []
|
||||
for objRef in obj:
|
||||
(isNew, output) = self.writeObjectReference(objRef, output)
|
||||
if isNew:
|
||||
objectsToWrite.append(objRef)
|
||||
for objRef in objectsToWrite:
|
||||
output = self.writeObject(objRef, output, setReferencePosition=True)
|
||||
elif isinstance(obj, dict):
|
||||
output += proc_variable_length(0b1101, len(obj))
|
||||
keys = []
|
||||
values = []
|
||||
objectsToWrite = []
|
||||
for key, value in iteritems(obj):
|
||||
keys.append(key)
|
||||
values.append(value)
|
||||
for key in keys:
|
||||
(isNew, output) = self.writeObjectReference(key, output)
|
||||
if isNew:
|
||||
objectsToWrite.append(key)
|
||||
for value in values:
|
||||
(isNew, output) = self.writeObjectReference(value, output)
|
||||
if isNew:
|
||||
objectsToWrite.append(value)
|
||||
for objRef in objectsToWrite:
|
||||
output = self.writeObject(objRef, output, setReferencePosition=True)
|
||||
return output
|
||||
|
||||
def writeOffsetTable(self, output):
|
||||
"""Writes all of the object reference offsets."""
|
||||
all_positions = []
|
||||
writtenReferences = list(self.writtenReferences.items())
|
||||
writtenReferences.sort(key=lambda x: x[1])
|
||||
for obj,order in writtenReferences:
|
||||
# Porting note: Elsewhere we deliberately replace empty unicdoe strings
|
||||
# with empty binary strings, but the empty unicode string
|
||||
# goes into writtenReferences. This isn't an issue in Py2
|
||||
# because u'' and b'' have the same hash; but it is in
|
||||
# Py3, where they don't.
|
||||
if bytes != str and obj == unicodeEmpty:
|
||||
obj = b''
|
||||
position = self.referencePositions.get(obj)
|
||||
if position is None:
|
||||
raise InvalidPlistException("Error while writing offsets table. Object not found. %s" % obj)
|
||||
output += self.binaryInt(position, self.trailer.offsetSize)
|
||||
all_positions.append(position)
|
||||
return output
|
||||
|
||||
def binaryReal(self, obj):
|
||||
# just use doubles
|
||||
result = pack('>d', obj.value)
|
||||
return result
|
||||
|
||||
def binaryInt(self, obj, byteSize=None, as_number=False):
|
||||
result = b''
|
||||
if byteSize is None:
|
||||
byteSize = self.intSize(obj)
|
||||
if byteSize == 1:
|
||||
result += pack('>B', obj)
|
||||
elif byteSize == 2:
|
||||
result += pack('>H', obj)
|
||||
elif byteSize == 4:
|
||||
result += pack('>L', obj)
|
||||
elif byteSize == 8:
|
||||
if as_number:
|
||||
result += pack('>q', obj)
|
||||
else:
|
||||
result += pack('>Q', obj)
|
||||
elif byteSize <= 16:
|
||||
try:
|
||||
result = pack('>Q', 0) + pack('>Q', obj)
|
||||
except struct_error as e:
|
||||
raise InvalidPlistException("Unable to pack integer %d: %s" % (obj, e))
|
||||
else:
|
||||
raise InvalidPlistException("Core Foundation can't handle integers with size greater than 16 bytes.")
|
||||
return result
|
||||
|
||||
def intSize(self, obj):
|
||||
"""Returns the number of bytes necessary to store the given integer."""
|
||||
# SIGNED
|
||||
if obj < 0: # Signed integer, always 8 bytes
|
||||
return 8
|
||||
# UNSIGNED
|
||||
elif obj <= 0xFF: # 1 byte
|
||||
return 1
|
||||
elif obj <= 0xFFFF: # 2 bytes
|
||||
return 2
|
||||
elif obj <= 0xFFFFFFFF: # 4 bytes
|
||||
return 4
|
||||
# SIGNED
|
||||
# 0x7FFFFFFFFFFFFFFF is the max.
|
||||
elif obj <= 0x7FFFFFFFFFFFFFFF: # 8 bytes signed
|
||||
return 8
|
||||
elif obj <= 0xffffffffffffffff: # 8 bytes unsigned
|
||||
return 16
|
||||
else:
|
||||
raise InvalidPlistException("Core Foundation can't handle integers with size greater than 8 bytes.")
|
||||
|
||||
def realSize(self, obj):
|
||||
return 8
|
||||
@@ -0,0 +1,468 @@
|
||||
"""Beautiful Soup
|
||||
Elixir and Tonic
|
||||
"The Screen-Scraper's Friend"
|
||||
http://www.crummy.com/software/BeautifulSoup/
|
||||
|
||||
Beautiful Soup uses a pluggable XML or HTML parser to parse a
|
||||
(possibly invalid) document into a tree representation. Beautiful Soup
|
||||
provides provides methods and Pythonic idioms that make it easy to
|
||||
navigate, search, and modify the parse tree.
|
||||
|
||||
Beautiful Soup works with Python 2.6 and up. It works better if lxml
|
||||
and/or html5lib is installed.
|
||||
|
||||
For more than you ever wanted to know about Beautiful Soup, see the
|
||||
documentation:
|
||||
http://www.crummy.com/software/BeautifulSoup/bs4/doc/
|
||||
"""
|
||||
|
||||
__author__ = "Leonard Richardson (leonardr@segfault.org)"
|
||||
__version__ = "4.4.0"
|
||||
__copyright__ = "Copyright (c) 2004-2015 Leonard Richardson"
|
||||
__license__ = "MIT"
|
||||
|
||||
__all__ = ['BeautifulSoup']
|
||||
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
|
||||
from .builder import builder_registry, ParserRejectedMarkup
|
||||
from .dammit import UnicodeDammit
|
||||
from .element import (
|
||||
CData,
|
||||
Comment,
|
||||
DEFAULT_OUTPUT_ENCODING,
|
||||
Declaration,
|
||||
Doctype,
|
||||
NavigableString,
|
||||
PageElement,
|
||||
ProcessingInstruction,
|
||||
ResultSet,
|
||||
SoupStrainer,
|
||||
Tag,
|
||||
)
|
||||
|
||||
# The very first thing we do is give a useful error if someone is
|
||||
# running this code under Python 3 without converting it.
|
||||
'You are trying to run the Python 2 version of Beautiful Soup under Python 3. This will not work.'<>'You need to convert the code, either by installing it (`python setup.py install`) or by running 2to3 (`2to3 -w bs4`).'
|
||||
|
||||
class BeautifulSoup(Tag):
|
||||
"""
|
||||
This class defines the basic interface called by the tree builders.
|
||||
|
||||
These methods will be called by the parser:
|
||||
reset()
|
||||
feed(markup)
|
||||
|
||||
The tree builder may call these methods from its feed() implementation:
|
||||
handle_starttag(name, attrs) # See note about return value
|
||||
handle_endtag(name)
|
||||
handle_data(data) # Appends to the current data node
|
||||
endData(containerClass=NavigableString) # Ends the current data node
|
||||
|
||||
No matter how complicated the underlying parser is, you should be
|
||||
able to build a tree using 'start tag' events, 'end tag' events,
|
||||
'data' events, and "done with data" events.
|
||||
|
||||
If you encounter an empty-element tag (aka a self-closing tag,
|
||||
like HTML's <br> tag), call handle_starttag and then
|
||||
handle_endtag.
|
||||
"""
|
||||
ROOT_TAG_NAME = u'[document]'
|
||||
|
||||
# If the end-user gives no indication which tree builder they
|
||||
# want, look for one with these features.
|
||||
DEFAULT_BUILDER_FEATURES = ['html', 'fast']
|
||||
|
||||
ASCII_SPACES = '\x20\x0a\x09\x0c\x0d'
|
||||
|
||||
NO_PARSER_SPECIFIED_WARNING = "No parser was explicitly specified, so I'm using the best available %(markup_type)s parser for this system (\"%(parser)s\"). This usually isn't a problem, but if you run this code on another system, or in a different virtual environment, it may use a different parser and behave differently.\n\nTo get rid of this warning, change this:\n\n BeautifulSoup([your markup])\n\nto this:\n\n BeautifulSoup([your markup], \"%(parser)s\")\n"
|
||||
|
||||
def __init__(self, markup="", features=None, builder=None,
|
||||
parse_only=None, from_encoding=None, exclude_encodings=None,
|
||||
**kwargs):
|
||||
"""The Soup object is initialized as the 'root tag', and the
|
||||
provided markup (which can be a string or a file-like object)
|
||||
is fed into the underlying parser."""
|
||||
|
||||
if 'convertEntities' in kwargs:
|
||||
warnings.warn(
|
||||
"BS4 does not respect the convertEntities argument to the "
|
||||
"BeautifulSoup constructor. Entities are always converted "
|
||||
"to Unicode characters.")
|
||||
|
||||
if 'markupMassage' in kwargs:
|
||||
del kwargs['markupMassage']
|
||||
warnings.warn(
|
||||
"BS4 does not respect the markupMassage argument to the "
|
||||
"BeautifulSoup constructor. The tree builder is responsible "
|
||||
"for any necessary markup massage.")
|
||||
|
||||
if 'smartQuotesTo' in kwargs:
|
||||
del kwargs['smartQuotesTo']
|
||||
warnings.warn(
|
||||
"BS4 does not respect the smartQuotesTo argument to the "
|
||||
"BeautifulSoup constructor. Smart quotes are always converted "
|
||||
"to Unicode characters.")
|
||||
|
||||
if 'selfClosingTags' in kwargs:
|
||||
del kwargs['selfClosingTags']
|
||||
warnings.warn(
|
||||
"BS4 does not respect the selfClosingTags argument to the "
|
||||
"BeautifulSoup constructor. The tree builder is responsible "
|
||||
"for understanding self-closing tags.")
|
||||
|
||||
if 'isHTML' in kwargs:
|
||||
del kwargs['isHTML']
|
||||
warnings.warn(
|
||||
"BS4 does not respect the isHTML argument to the "
|
||||
"BeautifulSoup constructor. Suggest you use "
|
||||
"features='lxml' for HTML and features='lxml-xml' for "
|
||||
"XML.")
|
||||
|
||||
def deprecated_argument(old_name, new_name):
|
||||
if old_name in kwargs:
|
||||
warnings.warn(
|
||||
'The "%s" argument to the BeautifulSoup constructor '
|
||||
'has been renamed to "%s."' % (old_name, new_name))
|
||||
value = kwargs[old_name]
|
||||
del kwargs[old_name]
|
||||
return value
|
||||
return None
|
||||
|
||||
parse_only = parse_only or deprecated_argument(
|
||||
"parseOnlyThese", "parse_only")
|
||||
|
||||
from_encoding = from_encoding or deprecated_argument(
|
||||
"fromEncoding", "from_encoding")
|
||||
|
||||
if len(kwargs) > 0:
|
||||
arg = kwargs.keys().pop()
|
||||
raise TypeError(
|
||||
"__init__() got an unexpected keyword argument '%s'" % arg)
|
||||
|
||||
if builder is None:
|
||||
original_features = features
|
||||
if isinstance(features, basestring):
|
||||
features = [features]
|
||||
if features is None or len(features) == 0:
|
||||
features = self.DEFAULT_BUILDER_FEATURES
|
||||
builder_class = builder_registry.lookup(*features)
|
||||
if builder_class is None:
|
||||
raise FeatureNotFound(
|
||||
"Couldn't find a tree builder with the features you "
|
||||
"requested: %s. Do you need to install a parser library?"
|
||||
% ",".join(features))
|
||||
builder = builder_class()
|
||||
if not (original_features == builder.NAME or
|
||||
original_features in builder.ALTERNATE_NAMES):
|
||||
if builder.is_xml:
|
||||
markup_type = "XML"
|
||||
else:
|
||||
markup_type = "HTML"
|
||||
warnings.warn(self.NO_PARSER_SPECIFIED_WARNING % dict(
|
||||
parser=builder.NAME,
|
||||
markup_type=markup_type))
|
||||
|
||||
self.builder = builder
|
||||
self.is_xml = builder.is_xml
|
||||
self.builder.soup = self
|
||||
|
||||
self.parse_only = parse_only
|
||||
|
||||
if hasattr(markup, 'read'): # It's a file-type object.
|
||||
markup = markup.read()
|
||||
elif len(markup) <= 256:
|
||||
# Print out warnings for a couple beginner problems
|
||||
# involving passing non-markup to Beautiful Soup.
|
||||
# Beautiful Soup will still parse the input as markup,
|
||||
# just in case that's what the user really wants.
|
||||
if (isinstance(markup, unicode)
|
||||
and not os.path.supports_unicode_filenames):
|
||||
possible_filename = markup.encode("utf8")
|
||||
else:
|
||||
possible_filename = markup
|
||||
is_file = False
|
||||
try:
|
||||
is_file = os.path.exists(possible_filename)
|
||||
except Exception, e:
|
||||
# This is almost certainly a problem involving
|
||||
# characters not valid in filenames on this
|
||||
# system. Just let it go.
|
||||
pass
|
||||
if is_file:
|
||||
if isinstance(markup, unicode):
|
||||
markup = markup.encode("utf8")
|
||||
warnings.warn(
|
||||
'"%s" looks like a filename, not markup. You should probably open this file and pass the filehandle into Beautiful Soup.' % markup)
|
||||
if markup[:5] == "http:" or markup[:6] == "https:":
|
||||
# TODO: This is ugly but I couldn't get it to work in
|
||||
# Python 3 otherwise.
|
||||
if ((isinstance(markup, bytes) and not b' ' in markup)
|
||||
or (isinstance(markup, unicode) and not u' ' in markup)):
|
||||
if isinstance(markup, unicode):
|
||||
markup = markup.encode("utf8")
|
||||
warnings.warn(
|
||||
'"%s" looks like a URL. Beautiful Soup is not an HTTP client. You should probably use an HTTP client to get the document behind the URL, and feed that document to Beautiful Soup.' % markup)
|
||||
|
||||
for (self.markup, self.original_encoding, self.declared_html_encoding,
|
||||
self.contains_replacement_characters) in (
|
||||
self.builder.prepare_markup(
|
||||
markup, from_encoding, exclude_encodings=exclude_encodings)):
|
||||
self.reset()
|
||||
try:
|
||||
self._feed()
|
||||
break
|
||||
except ParserRejectedMarkup:
|
||||
pass
|
||||
|
||||
# Clear out the markup and remove the builder's circular
|
||||
# reference to this object.
|
||||
self.markup = None
|
||||
self.builder.soup = None
|
||||
|
||||
def __copy__(self):
|
||||
return type(self)(self.encode(), builder=self.builder)
|
||||
|
||||
def __getstate__(self):
|
||||
# Frequently a tree builder can't be pickled.
|
||||
d = dict(self.__dict__)
|
||||
if 'builder' in d and not self.builder.picklable:
|
||||
del d['builder']
|
||||
return d
|
||||
|
||||
def _feed(self):
|
||||
# Convert the document to Unicode.
|
||||
self.builder.reset()
|
||||
|
||||
self.builder.feed(self.markup)
|
||||
# Close out any unfinished strings and close all the open tags.
|
||||
self.endData()
|
||||
while self.currentTag.name != self.ROOT_TAG_NAME:
|
||||
self.popTag()
|
||||
|
||||
def reset(self):
|
||||
Tag.__init__(self, self, self.builder, self.ROOT_TAG_NAME)
|
||||
self.hidden = 1
|
||||
self.builder.reset()
|
||||
self.current_data = []
|
||||
self.currentTag = None
|
||||
self.tagStack = []
|
||||
self.preserve_whitespace_tag_stack = []
|
||||
self.pushTag(self)
|
||||
|
||||
def new_tag(self, name, namespace=None, nsprefix=None, **attrs):
|
||||
"""Create a new tag associated with this soup."""
|
||||
return Tag(None, self.builder, name, namespace, nsprefix, attrs)
|
||||
|
||||
def new_string(self, s, subclass=NavigableString):
|
||||
"""Create a new NavigableString associated with this soup."""
|
||||
return subclass(s)
|
||||
|
||||
def insert_before(self, successor):
|
||||
raise NotImplementedError("BeautifulSoup objects don't support insert_before().")
|
||||
|
||||
def insert_after(self, successor):
|
||||
raise NotImplementedError("BeautifulSoup objects don't support insert_after().")
|
||||
|
||||
def popTag(self):
|
||||
tag = self.tagStack.pop()
|
||||
if self.preserve_whitespace_tag_stack and tag == self.preserve_whitespace_tag_stack[-1]:
|
||||
self.preserve_whitespace_tag_stack.pop()
|
||||
#print "Pop", tag.name
|
||||
if self.tagStack:
|
||||
self.currentTag = self.tagStack[-1]
|
||||
return self.currentTag
|
||||
|
||||
def pushTag(self, tag):
|
||||
#print "Push", tag.name
|
||||
if self.currentTag:
|
||||
self.currentTag.contents.append(tag)
|
||||
self.tagStack.append(tag)
|
||||
self.currentTag = self.tagStack[-1]
|
||||
if tag.name in self.builder.preserve_whitespace_tags:
|
||||
self.preserve_whitespace_tag_stack.append(tag)
|
||||
|
||||
def endData(self, containerClass=NavigableString):
|
||||
if self.current_data:
|
||||
current_data = u''.join(self.current_data)
|
||||
# If whitespace is not preserved, and this string contains
|
||||
# nothing but ASCII spaces, replace it with a single space
|
||||
# or newline.
|
||||
if not self.preserve_whitespace_tag_stack:
|
||||
strippable = True
|
||||
for i in current_data:
|
||||
if i not in self.ASCII_SPACES:
|
||||
strippable = False
|
||||
break
|
||||
if strippable:
|
||||
if '\n' in current_data:
|
||||
current_data = '\n'
|
||||
else:
|
||||
current_data = ' '
|
||||
|
||||
# Reset the data collector.
|
||||
self.current_data = []
|
||||
|
||||
# Should we add this string to the tree at all?
|
||||
if self.parse_only and len(self.tagStack) <= 1 and \
|
||||
(not self.parse_only.text or \
|
||||
not self.parse_only.search(current_data)):
|
||||
return
|
||||
|
||||
o = containerClass(current_data)
|
||||
self.object_was_parsed(o)
|
||||
|
||||
def object_was_parsed(self, o, parent=None, most_recent_element=None):
|
||||
"""Add an object to the parse tree."""
|
||||
parent = parent or self.currentTag
|
||||
previous_element = most_recent_element or self._most_recent_element
|
||||
|
||||
next_element = previous_sibling = next_sibling = None
|
||||
if isinstance(o, Tag):
|
||||
next_element = o.next_element
|
||||
next_sibling = o.next_sibling
|
||||
previous_sibling = o.previous_sibling
|
||||
if not previous_element:
|
||||
previous_element = o.previous_element
|
||||
|
||||
o.setup(parent, previous_element, next_element, previous_sibling, next_sibling)
|
||||
|
||||
self._most_recent_element = o
|
||||
parent.contents.append(o)
|
||||
|
||||
if parent.next_sibling:
|
||||
# This node is being inserted into an element that has
|
||||
# already been parsed. Deal with any dangling references.
|
||||
index = parent.contents.index(o)
|
||||
if index == 0:
|
||||
previous_element = parent
|
||||
previous_sibling = None
|
||||
else:
|
||||
previous_element = previous_sibling = parent.contents[index-1]
|
||||
if index == len(parent.contents)-1:
|
||||
next_element = parent.next_sibling
|
||||
next_sibling = None
|
||||
else:
|
||||
next_element = next_sibling = parent.contents[index+1]
|
||||
|
||||
o.previous_element = previous_element
|
||||
if previous_element:
|
||||
previous_element.next_element = o
|
||||
o.next_element = next_element
|
||||
if next_element:
|
||||
next_element.previous_element = o
|
||||
o.next_sibling = next_sibling
|
||||
if next_sibling:
|
||||
next_sibling.previous_sibling = o
|
||||
o.previous_sibling = previous_sibling
|
||||
if previous_sibling:
|
||||
previous_sibling.next_sibling = o
|
||||
|
||||
def _popToTag(self, name, nsprefix=None, inclusivePop=True):
|
||||
"""Pops the tag stack up to and including the most recent
|
||||
instance of the given tag. If inclusivePop is false, pops the tag
|
||||
stack up to but *not* including the most recent instqance of
|
||||
the given tag."""
|
||||
#print "Popping to %s" % name
|
||||
if name == self.ROOT_TAG_NAME:
|
||||
# The BeautifulSoup object itself can never be popped.
|
||||
return
|
||||
|
||||
most_recently_popped = None
|
||||
|
||||
stack_size = len(self.tagStack)
|
||||
for i in range(stack_size - 1, 0, -1):
|
||||
t = self.tagStack[i]
|
||||
if (name == t.name and nsprefix == t.prefix):
|
||||
if inclusivePop:
|
||||
most_recently_popped = self.popTag()
|
||||
break
|
||||
most_recently_popped = self.popTag()
|
||||
|
||||
return most_recently_popped
|
||||
|
||||
def handle_starttag(self, name, namespace, nsprefix, attrs):
|
||||
"""Push a start tag on to the stack.
|
||||
|
||||
If this method returns None, the tag was rejected by the
|
||||
SoupStrainer. You should proceed as if the tag had not occured
|
||||
in the document. For instance, if this was a self-closing tag,
|
||||
don't call handle_endtag.
|
||||
"""
|
||||
|
||||
# print "Start tag %s: %s" % (name, attrs)
|
||||
self.endData()
|
||||
|
||||
if (self.parse_only and len(self.tagStack) <= 1
|
||||
and (self.parse_only.text
|
||||
or not self.parse_only.search_tag(name, attrs))):
|
||||
return None
|
||||
|
||||
tag = Tag(self, self.builder, name, namespace, nsprefix, attrs,
|
||||
self.currentTag, self._most_recent_element)
|
||||
if tag is None:
|
||||
return tag
|
||||
if self._most_recent_element:
|
||||
self._most_recent_element.next_element = tag
|
||||
self._most_recent_element = tag
|
||||
self.pushTag(tag)
|
||||
return tag
|
||||
|
||||
def handle_endtag(self, name, nsprefix=None):
|
||||
#print "End tag: " + name
|
||||
self.endData()
|
||||
self._popToTag(name, nsprefix)
|
||||
|
||||
def handle_data(self, data):
|
||||
self.current_data.append(data)
|
||||
|
||||
def decode(self, pretty_print=False,
|
||||
eventual_encoding=DEFAULT_OUTPUT_ENCODING,
|
||||
formatter="minimal"):
|
||||
"""Returns a string or Unicode representation of this document.
|
||||
To get Unicode, pass None for encoding."""
|
||||
|
||||
if self.is_xml:
|
||||
# Print the XML declaration
|
||||
encoding_part = ''
|
||||
if eventual_encoding != None:
|
||||
encoding_part = ' encoding="%s"' % eventual_encoding
|
||||
prefix = u'<?xml version="1.0"%s?>\n' % encoding_part
|
||||
else:
|
||||
prefix = u''
|
||||
if not pretty_print:
|
||||
indent_level = None
|
||||
else:
|
||||
indent_level = 0
|
||||
return prefix + super(BeautifulSoup, self).decode(
|
||||
indent_level, eventual_encoding, formatter)
|
||||
|
||||
# Alias to make it easier to type import: 'from bs4 import _soup'
|
||||
_s = BeautifulSoup
|
||||
_soup = BeautifulSoup
|
||||
|
||||
class BeautifulStoneSoup(BeautifulSoup):
|
||||
"""Deprecated interface to an XML parser."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['features'] = 'xml'
|
||||
warnings.warn(
|
||||
'The BeautifulStoneSoup class is deprecated. Instead of using '
|
||||
'it, pass features="xml" into the BeautifulSoup constructor.')
|
||||
super(BeautifulStoneSoup, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class StopParsing(Exception):
|
||||
pass
|
||||
|
||||
class FeatureNotFound(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
#By default, act as an HTML pretty-printer.
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
soup = BeautifulSoup(sys.stdin)
|
||||
print soup.prettify()
|
||||
@@ -0,0 +1,324 @@
|
||||
from collections import defaultdict
|
||||
import itertools
|
||||
import sys
|
||||
from bs4.element import (
|
||||
CharsetMetaAttributeValue,
|
||||
ContentMetaAttributeValue,
|
||||
whitespace_re
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'HTMLTreeBuilder',
|
||||
'SAXTreeBuilder',
|
||||
'TreeBuilder',
|
||||
'TreeBuilderRegistry',
|
||||
]
|
||||
|
||||
# Some useful features for a TreeBuilder to have.
|
||||
FAST = 'fast'
|
||||
PERMISSIVE = 'permissive'
|
||||
STRICT = 'strict'
|
||||
XML = 'xml'
|
||||
HTML = 'html'
|
||||
HTML_5 = 'html5'
|
||||
|
||||
|
||||
class TreeBuilderRegistry(object):
|
||||
|
||||
def __init__(self):
|
||||
self.builders_for_feature = defaultdict(list)
|
||||
self.builders = []
|
||||
|
||||
def register(self, treebuilder_class):
|
||||
"""Register a treebuilder based on its advertised features."""
|
||||
for feature in treebuilder_class.features:
|
||||
self.builders_for_feature[feature].insert(0, treebuilder_class)
|
||||
self.builders.insert(0, treebuilder_class)
|
||||
|
||||
def lookup(self, *features):
|
||||
if len(self.builders) == 0:
|
||||
# There are no builders at all.
|
||||
return None
|
||||
|
||||
if len(features) == 0:
|
||||
# They didn't ask for any features. Give them the most
|
||||
# recently registered builder.
|
||||
return self.builders[0]
|
||||
|
||||
# Go down the list of features in order, and eliminate any builders
|
||||
# that don't match every feature.
|
||||
features = list(features)
|
||||
features.reverse()
|
||||
candidates = None
|
||||
candidate_set = None
|
||||
while len(features) > 0:
|
||||
feature = features.pop()
|
||||
we_have_the_feature = self.builders_for_feature.get(feature, [])
|
||||
if len(we_have_the_feature) > 0:
|
||||
if candidates is None:
|
||||
candidates = we_have_the_feature
|
||||
candidate_set = set(candidates)
|
||||
else:
|
||||
# Eliminate any candidates that don't have this feature.
|
||||
candidate_set = candidate_set.intersection(
|
||||
set(we_have_the_feature))
|
||||
|
||||
# The only valid candidates are the ones in candidate_set.
|
||||
# Go through the original list of candidates and pick the first one
|
||||
# that's in candidate_set.
|
||||
if candidate_set is None:
|
||||
return None
|
||||
for candidate in candidates:
|
||||
if candidate in candidate_set:
|
||||
return candidate
|
||||
return None
|
||||
|
||||
# The BeautifulSoup class will take feature lists from developers and use them
|
||||
# to look up builders in this registry.
|
||||
builder_registry = TreeBuilderRegistry()
|
||||
|
||||
class TreeBuilder(object):
|
||||
"""Turn a document into a Beautiful Soup object tree."""
|
||||
|
||||
NAME = "[Unknown tree builder]"
|
||||
ALTERNATE_NAMES = []
|
||||
features = []
|
||||
|
||||
is_xml = False
|
||||
picklable = False
|
||||
preserve_whitespace_tags = set()
|
||||
empty_element_tags = None # A tag will be considered an empty-element
|
||||
# tag when and only when it has no contents.
|
||||
|
||||
# A value for these tag/attribute combinations is a space- or
|
||||
# comma-separated list of CDATA, rather than a single CDATA.
|
||||
cdata_list_attributes = {}
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.soup = None
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def can_be_empty_element(self, tag_name):
|
||||
"""Might a tag with this name be an empty-element tag?
|
||||
|
||||
The final markup may or may not actually present this tag as
|
||||
self-closing.
|
||||
|
||||
For instance: an HTMLBuilder does not consider a <p> tag to be
|
||||
an empty-element tag (it's not in
|
||||
HTMLBuilder.empty_element_tags). This means an empty <p> tag
|
||||
will be presented as "<p></p>", not "<p />".
|
||||
|
||||
The default implementation has no opinion about which tags are
|
||||
empty-element tags, so a tag will be presented as an
|
||||
empty-element tag if and only if it has no contents.
|
||||
"<foo></foo>" will become "<foo />", and "<foo>bar</foo>" will
|
||||
be left alone.
|
||||
"""
|
||||
if self.empty_element_tags is None:
|
||||
return True
|
||||
return tag_name in self.empty_element_tags
|
||||
|
||||
def feed(self, markup):
|
||||
raise NotImplementedError()
|
||||
|
||||
def prepare_markup(self, markup, user_specified_encoding=None,
|
||||
document_declared_encoding=None):
|
||||
return markup, None, None, False
|
||||
|
||||
def test_fragment_to_document(self, fragment):
|
||||
"""Wrap an HTML fragment to make it look like a document.
|
||||
|
||||
Different parsers do this differently. For instance, lxml
|
||||
introduces an empty <head> tag, and html5lib
|
||||
doesn't. Abstracting this away lets us write simple tests
|
||||
which run HTML fragments through the parser and compare the
|
||||
results against other HTML fragments.
|
||||
|
||||
This method should not be used outside of tests.
|
||||
"""
|
||||
return fragment
|
||||
|
||||
def set_up_substitutions(self, tag):
|
||||
return False
|
||||
|
||||
def _replace_cdata_list_attribute_values(self, tag_name, attrs):
|
||||
"""Replaces class="foo bar" with class=["foo", "bar"]
|
||||
|
||||
Modifies its input in place.
|
||||
"""
|
||||
if not attrs:
|
||||
return attrs
|
||||
if self.cdata_list_attributes:
|
||||
universal = self.cdata_list_attributes.get('*', [])
|
||||
tag_specific = self.cdata_list_attributes.get(
|
||||
tag_name.lower(), None)
|
||||
for attr in attrs.keys():
|
||||
if attr in universal or (tag_specific and attr in tag_specific):
|
||||
# We have a "class"-type attribute whose string
|
||||
# value is a whitespace-separated list of
|
||||
# values. Split it into a list.
|
||||
value = attrs[attr]
|
||||
if isinstance(value, basestring):
|
||||
values = whitespace_re.split(value)
|
||||
else:
|
||||
# html5lib sometimes calls setAttributes twice
|
||||
# for the same tag when rearranging the parse
|
||||
# tree. On the second call the attribute value
|
||||
# here is already a list. If this happens,
|
||||
# leave the value alone rather than trying to
|
||||
# split it again.
|
||||
values = value
|
||||
attrs[attr] = values
|
||||
return attrs
|
||||
|
||||
class SAXTreeBuilder(TreeBuilder):
|
||||
"""A Beautiful Soup treebuilder that listens for SAX events."""
|
||||
|
||||
def feed(self, markup):
|
||||
raise NotImplementedError()
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def startElement(self, name, attrs):
|
||||
attrs = dict((key[1], value) for key, value in list(attrs.items()))
|
||||
#print "Start %s, %r" % (name, attrs)
|
||||
self.soup.handle_starttag(name, attrs)
|
||||
|
||||
def endElement(self, name):
|
||||
#print "End %s" % name
|
||||
self.soup.handle_endtag(name)
|
||||
|
||||
def startElementNS(self, nsTuple, nodeName, attrs):
|
||||
# Throw away (ns, nodeName) for now.
|
||||
self.startElement(nodeName, attrs)
|
||||
|
||||
def endElementNS(self, nsTuple, nodeName):
|
||||
# Throw away (ns, nodeName) for now.
|
||||
self.endElement(nodeName)
|
||||
#handler.endElementNS((ns, node.nodeName), node.nodeName)
|
||||
|
||||
def startPrefixMapping(self, prefix, nodeValue):
|
||||
# Ignore the prefix for now.
|
||||
pass
|
||||
|
||||
def endPrefixMapping(self, prefix):
|
||||
# Ignore the prefix for now.
|
||||
# handler.endPrefixMapping(prefix)
|
||||
pass
|
||||
|
||||
def characters(self, content):
|
||||
self.soup.handle_data(content)
|
||||
|
||||
def startDocument(self):
|
||||
pass
|
||||
|
||||
def endDocument(self):
|
||||
pass
|
||||
|
||||
|
||||
class HTMLTreeBuilder(TreeBuilder):
|
||||
"""This TreeBuilder knows facts about HTML.
|
||||
|
||||
Such as which tags are empty-element tags.
|
||||
"""
|
||||
|
||||
preserve_whitespace_tags = set(['pre', 'textarea'])
|
||||
empty_element_tags = set(['br' , 'hr', 'input', 'img', 'meta',
|
||||
'spacer', 'link', 'frame', 'base'])
|
||||
|
||||
# The HTML standard defines these attributes as containing a
|
||||
# space-separated list of values, not a single value. That is,
|
||||
# class="foo bar" means that the 'class' attribute has two values,
|
||||
# 'foo' and 'bar', not the single value 'foo bar'. When we
|
||||
# encounter one of these attributes, we will parse its value into
|
||||
# a list of values if possible. Upon output, the list will be
|
||||
# converted back into a string.
|
||||
cdata_list_attributes = {
|
||||
"*" : ['class', 'accesskey', 'dropzone'],
|
||||
"a" : ['rel', 'rev'],
|
||||
"link" : ['rel', 'rev'],
|
||||
"td" : ["headers"],
|
||||
"th" : ["headers"],
|
||||
"td" : ["headers"],
|
||||
"form" : ["accept-charset"],
|
||||
"object" : ["archive"],
|
||||
|
||||
# These are HTML5 specific, as are *.accesskey and *.dropzone above.
|
||||
"area" : ["rel"],
|
||||
"icon" : ["sizes"],
|
||||
"iframe" : ["sandbox"],
|
||||
"output" : ["for"],
|
||||
}
|
||||
|
||||
def set_up_substitutions(self, tag):
|
||||
# We are only interested in <meta> tags
|
||||
if tag.name != 'meta':
|
||||
return False
|
||||
|
||||
http_equiv = tag.get('http-equiv')
|
||||
content = tag.get('content')
|
||||
charset = tag.get('charset')
|
||||
|
||||
# We are interested in <meta> tags that say what encoding the
|
||||
# document was originally in. This means HTML 5-style <meta>
|
||||
# tags that provide the "charset" attribute. It also means
|
||||
# HTML 4-style <meta> tags that provide the "content"
|
||||
# attribute and have "http-equiv" set to "content-type".
|
||||
#
|
||||
# In both cases we will replace the value of the appropriate
|
||||
# attribute with a standin object that can take on any
|
||||
# encoding.
|
||||
meta_encoding = None
|
||||
if charset is not None:
|
||||
# HTML 5 style:
|
||||
# <meta charset="utf8">
|
||||
meta_encoding = charset
|
||||
tag['charset'] = CharsetMetaAttributeValue(charset)
|
||||
|
||||
elif (content is not None and http_equiv is not None
|
||||
and http_equiv.lower() == 'content-type'):
|
||||
# HTML 4 style:
|
||||
# <meta http-equiv="content-type" content="text/html; charset=utf8">
|
||||
tag['content'] = ContentMetaAttributeValue(content)
|
||||
|
||||
return (meta_encoding is not None)
|
||||
|
||||
def register_treebuilders_from(module):
|
||||
"""Copy TreeBuilders from the given module into this module."""
|
||||
# I'm fairly sure this is not the best way to do this.
|
||||
this_module = sys.modules['bs4.builder']
|
||||
for name in module.__all__:
|
||||
obj = getattr(module, name)
|
||||
|
||||
if issubclass(obj, TreeBuilder):
|
||||
setattr(this_module, name, obj)
|
||||
this_module.__all__.append(name)
|
||||
# Register the builder while we're at it.
|
||||
this_module.builder_registry.register(obj)
|
||||
|
||||
class ParserRejectedMarkup(Exception):
|
||||
pass
|
||||
|
||||
# Builders are registered in reverse order of priority, so that custom
|
||||
# builder registrations will take precedence. In general, we want lxml
|
||||
# to take precedence over html5lib, because it's faster. And we only
|
||||
# want to use HTMLParser as a last result.
|
||||
from . import _htmlparser
|
||||
register_treebuilders_from(_htmlparser)
|
||||
try:
|
||||
from . import _html5lib
|
||||
register_treebuilders_from(_html5lib)
|
||||
except ImportError:
|
||||
# They don't have html5lib installed.
|
||||
pass
|
||||
try:
|
||||
from . import _lxml
|
||||
register_treebuilders_from(_lxml)
|
||||
except ImportError:
|
||||
# They don't have lxml installed.
|
||||
pass
|
||||
@@ -0,0 +1,329 @@
|
||||
__all__ = [
|
||||
'HTML5TreeBuilder',
|
||||
]
|
||||
|
||||
from pdb import set_trace
|
||||
import warnings
|
||||
from bs4.builder import (
|
||||
PERMISSIVE,
|
||||
HTML,
|
||||
HTML_5,
|
||||
HTMLTreeBuilder,
|
||||
)
|
||||
from bs4.element import (
|
||||
NamespacedAttribute,
|
||||
whitespace_re,
|
||||
)
|
||||
import html5lib
|
||||
from html5lib.constants import namespaces
|
||||
from bs4.element import (
|
||||
Comment,
|
||||
Doctype,
|
||||
NavigableString,
|
||||
Tag,
|
||||
)
|
||||
|
||||
class HTML5TreeBuilder(HTMLTreeBuilder):
|
||||
"""Use html5lib to build a tree."""
|
||||
|
||||
NAME = "html5lib"
|
||||
|
||||
features = [NAME, PERMISSIVE, HTML_5, HTML]
|
||||
|
||||
def prepare_markup(self, markup, user_specified_encoding,
|
||||
document_declared_encoding=None, exclude_encodings=None):
|
||||
# Store the user-specified encoding for use later on.
|
||||
self.user_specified_encoding = user_specified_encoding
|
||||
|
||||
# document_declared_encoding and exclude_encodings aren't used
|
||||
# ATM because the html5lib TreeBuilder doesn't use
|
||||
# UnicodeDammit.
|
||||
if exclude_encodings:
|
||||
warnings.warn("You provided a value for exclude_encoding, but the html5lib tree builder doesn't support exclude_encoding.")
|
||||
yield (markup, None, None, False)
|
||||
|
||||
# These methods are defined by Beautiful Soup.
|
||||
def feed(self, markup):
|
||||
if self.soup.parse_only is not None:
|
||||
warnings.warn("You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.")
|
||||
parser = html5lib.HTMLParser(tree=self.create_treebuilder)
|
||||
doc = parser.parse(markup, encoding=self.user_specified_encoding)
|
||||
|
||||
# Set the character encoding detected by the tokenizer.
|
||||
if isinstance(markup, unicode):
|
||||
# We need to special-case this because html5lib sets
|
||||
# charEncoding to UTF-8 if it gets Unicode input.
|
||||
doc.original_encoding = None
|
||||
else:
|
||||
doc.original_encoding = parser.tokenizer.stream.charEncoding[0]
|
||||
|
||||
def create_treebuilder(self, namespaceHTMLElements):
|
||||
self.underlying_builder = TreeBuilderForHtml5lib(
|
||||
self.soup, namespaceHTMLElements)
|
||||
return self.underlying_builder
|
||||
|
||||
def test_fragment_to_document(self, fragment):
|
||||
"""See `TreeBuilder`."""
|
||||
return u'<html><head></head><body>%s</body></html>' % fragment
|
||||
|
||||
|
||||
class TreeBuilderForHtml5lib(html5lib.treebuilders._base.TreeBuilder):
|
||||
|
||||
def __init__(self, soup, namespaceHTMLElements):
|
||||
self.soup = soup
|
||||
super(TreeBuilderForHtml5lib, self).__init__(namespaceHTMLElements)
|
||||
|
||||
def documentClass(self):
|
||||
self.soup.reset()
|
||||
return Element(self.soup, self.soup, None)
|
||||
|
||||
def insertDoctype(self, token):
|
||||
name = token["name"]
|
||||
publicId = token["publicId"]
|
||||
systemId = token["systemId"]
|
||||
|
||||
doctype = Doctype.for_name_and_ids(name, publicId, systemId)
|
||||
self.soup.object_was_parsed(doctype)
|
||||
|
||||
def elementClass(self, name, namespace):
|
||||
tag = self.soup.new_tag(name, namespace)
|
||||
return Element(tag, self.soup, namespace)
|
||||
|
||||
def commentClass(self, data):
|
||||
return TextNode(Comment(data), self.soup)
|
||||
|
||||
def fragmentClass(self):
|
||||
self.soup = BeautifulSoup("")
|
||||
self.soup.name = "[document_fragment]"
|
||||
return Element(self.soup, self.soup, None)
|
||||
|
||||
def appendChild(self, node):
|
||||
# XXX This code is not covered by the BS4 tests.
|
||||
self.soup.append(node.element)
|
||||
|
||||
def getDocument(self):
|
||||
return self.soup
|
||||
|
||||
def getFragment(self):
|
||||
return html5lib.treebuilders._base.TreeBuilder.getFragment(self).element
|
||||
|
||||
class AttrList(object):
|
||||
def __init__(self, element):
|
||||
self.element = element
|
||||
self.attrs = dict(self.element.attrs)
|
||||
def __iter__(self):
|
||||
return list(self.attrs.items()).__iter__()
|
||||
def __setitem__(self, name, value):
|
||||
# If this attribute is a multi-valued attribute for this element,
|
||||
# turn its value into a list.
|
||||
list_attr = HTML5TreeBuilder.cdata_list_attributes
|
||||
if (name in list_attr['*']
|
||||
or (self.element.name in list_attr
|
||||
and name in list_attr[self.element.name])):
|
||||
value = whitespace_re.split(value)
|
||||
self.element[name] = value
|
||||
def items(self):
|
||||
return list(self.attrs.items())
|
||||
def keys(self):
|
||||
return list(self.attrs.keys())
|
||||
def __len__(self):
|
||||
return len(self.attrs)
|
||||
def __getitem__(self, name):
|
||||
return self.attrs[name]
|
||||
def __contains__(self, name):
|
||||
return name in list(self.attrs.keys())
|
||||
|
||||
|
||||
class Element(html5lib.treebuilders._base.Node):
|
||||
def __init__(self, element, soup, namespace):
|
||||
html5lib.treebuilders._base.Node.__init__(self, element.name)
|
||||
self.element = element
|
||||
self.soup = soup
|
||||
self.namespace = namespace
|
||||
|
||||
def appendChild(self, node):
|
||||
string_child = child = None
|
||||
if isinstance(node, basestring):
|
||||
# Some other piece of code decided to pass in a string
|
||||
# instead of creating a TextElement object to contain the
|
||||
# string.
|
||||
string_child = child = node
|
||||
elif isinstance(node, Tag):
|
||||
# Some other piece of code decided to pass in a Tag
|
||||
# instead of creating an Element object to contain the
|
||||
# Tag.
|
||||
child = node
|
||||
elif node.element.__class__ == NavigableString:
|
||||
string_child = child = node.element
|
||||
else:
|
||||
child = node.element
|
||||
|
||||
if not isinstance(child, basestring) and child.parent is not None:
|
||||
node.element.extract()
|
||||
|
||||
if (string_child and self.element.contents
|
||||
and self.element.contents[-1].__class__ == NavigableString):
|
||||
# We are appending a string onto another string.
|
||||
# TODO This has O(n^2) performance, for input like
|
||||
# "a</a>a</a>a</a>..."
|
||||
old_element = self.element.contents[-1]
|
||||
new_element = self.soup.new_string(old_element + string_child)
|
||||
old_element.replace_with(new_element)
|
||||
self.soup._most_recent_element = new_element
|
||||
else:
|
||||
if isinstance(node, basestring):
|
||||
# Create a brand new NavigableString from this string.
|
||||
child = self.soup.new_string(node)
|
||||
|
||||
# Tell Beautiful Soup to act as if it parsed this element
|
||||
# immediately after the parent's last descendant. (Or
|
||||
# immediately after the parent, if it has no children.)
|
||||
if self.element.contents:
|
||||
most_recent_element = self.element._last_descendant(False)
|
||||
elif self.element.next_element is not None:
|
||||
# Something from further ahead in the parse tree is
|
||||
# being inserted into this earlier element. This is
|
||||
# very annoying because it means an expensive search
|
||||
# for the last element in the tree.
|
||||
most_recent_element = self.soup._last_descendant()
|
||||
else:
|
||||
most_recent_element = self.element
|
||||
|
||||
self.soup.object_was_parsed(
|
||||
child, parent=self.element,
|
||||
most_recent_element=most_recent_element)
|
||||
|
||||
def getAttributes(self):
|
||||
return AttrList(self.element)
|
||||
|
||||
def setAttributes(self, attributes):
|
||||
|
||||
if attributes is not None and len(attributes) > 0:
|
||||
|
||||
converted_attributes = []
|
||||
for name, value in list(attributes.items()):
|
||||
if isinstance(name, tuple):
|
||||
new_name = NamespacedAttribute(*name)
|
||||
del attributes[name]
|
||||
attributes[new_name] = value
|
||||
|
||||
self.soup.builder._replace_cdata_list_attribute_values(
|
||||
self.name, attributes)
|
||||
for name, value in attributes.items():
|
||||
self.element[name] = value
|
||||
|
||||
# The attributes may contain variables that need substitution.
|
||||
# Call set_up_substitutions manually.
|
||||
#
|
||||
# The Tag constructor called this method when the Tag was created,
|
||||
# but we just set/changed the attributes, so call it again.
|
||||
self.soup.builder.set_up_substitutions(self.element)
|
||||
attributes = property(getAttributes, setAttributes)
|
||||
|
||||
def insertText(self, data, insertBefore=None):
|
||||
if insertBefore:
|
||||
text = TextNode(self.soup.new_string(data), self.soup)
|
||||
self.insertBefore(data, insertBefore)
|
||||
else:
|
||||
self.appendChild(data)
|
||||
|
||||
def insertBefore(self, node, refNode):
|
||||
index = self.element.index(refNode.element)
|
||||
if (node.element.__class__ == NavigableString and self.element.contents
|
||||
and self.element.contents[index-1].__class__ == NavigableString):
|
||||
# (See comments in appendChild)
|
||||
old_node = self.element.contents[index-1]
|
||||
new_str = self.soup.new_string(old_node + node.element)
|
||||
old_node.replace_with(new_str)
|
||||
else:
|
||||
self.element.insert(index, node.element)
|
||||
node.parent = self
|
||||
|
||||
def removeChild(self, node):
|
||||
node.element.extract()
|
||||
|
||||
def reparentChildren(self, new_parent):
|
||||
"""Move all of this tag's children into another tag."""
|
||||
# print "MOVE", self.element.contents
|
||||
# print "FROM", self.element
|
||||
# print "TO", new_parent.element
|
||||
element = self.element
|
||||
new_parent_element = new_parent.element
|
||||
# Determine what this tag's next_element will be once all the children
|
||||
# are removed.
|
||||
final_next_element = element.next_sibling
|
||||
|
||||
new_parents_last_descendant = new_parent_element._last_descendant(False, False)
|
||||
if len(new_parent_element.contents) > 0:
|
||||
# The new parent already contains children. We will be
|
||||
# appending this tag's children to the end.
|
||||
new_parents_last_child = new_parent_element.contents[-1]
|
||||
new_parents_last_descendant_next_element = new_parents_last_descendant.next_element
|
||||
else:
|
||||
# The new parent contains no children.
|
||||
new_parents_last_child = None
|
||||
new_parents_last_descendant_next_element = new_parent_element.next_element
|
||||
|
||||
to_append = element.contents
|
||||
append_after = new_parent_element.contents
|
||||
if len(to_append) > 0:
|
||||
# Set the first child's previous_element and previous_sibling
|
||||
# to elements within the new parent
|
||||
first_child = to_append[0]
|
||||
if new_parents_last_descendant:
|
||||
first_child.previous_element = new_parents_last_descendant
|
||||
else:
|
||||
first_child.previous_element = new_parent_element
|
||||
first_child.previous_sibling = new_parents_last_child
|
||||
if new_parents_last_descendant:
|
||||
new_parents_last_descendant.next_element = first_child
|
||||
else:
|
||||
new_parent_element.next_element = first_child
|
||||
if new_parents_last_child:
|
||||
new_parents_last_child.next_sibling = first_child
|
||||
|
||||
# Fix the last child's next_element and next_sibling
|
||||
last_child = to_append[-1]
|
||||
last_child.next_element = new_parents_last_descendant_next_element
|
||||
if new_parents_last_descendant_next_element:
|
||||
new_parents_last_descendant_next_element.previous_element = last_child
|
||||
last_child.next_sibling = None
|
||||
|
||||
for child in to_append:
|
||||
child.parent = new_parent_element
|
||||
new_parent_element.contents.append(child)
|
||||
|
||||
# Now that this element has no children, change its .next_element.
|
||||
element.contents = []
|
||||
element.next_element = final_next_element
|
||||
|
||||
# print "DONE WITH MOVE"
|
||||
# print "FROM", self.element
|
||||
# print "TO", new_parent_element
|
||||
|
||||
def cloneNode(self):
|
||||
tag = self.soup.new_tag(self.element.name, self.namespace)
|
||||
node = Element(tag, self.soup, self.namespace)
|
||||
for key,value in self.attributes:
|
||||
node.attributes[key] = value
|
||||
return node
|
||||
|
||||
def hasContent(self):
|
||||
return self.element.contents
|
||||
|
||||
def getNameTuple(self):
|
||||
if self.namespace == None:
|
||||
return namespaces["html"], self.name
|
||||
else:
|
||||
return self.namespace, self.name
|
||||
|
||||
nameTuple = property(getNameTuple)
|
||||
|
||||
class TextNode(Element):
|
||||
def __init__(self, element, soup):
|
||||
html5lib.treebuilders._base.Node.__init__(self, None)
|
||||
self.element = element
|
||||
self.soup = soup
|
||||
|
||||
def cloneNode(self):
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,262 @@
|
||||
"""Use the HTMLParser library to parse HTML files that aren't too bad."""
|
||||
|
||||
__all__ = [
|
||||
'HTMLParserTreeBuilder',
|
||||
]
|
||||
|
||||
from HTMLParser import HTMLParser
|
||||
|
||||
try:
|
||||
from HTMLParser import HTMLParseError
|
||||
except ImportError, e:
|
||||
# HTMLParseError is removed in Python 3.5. Since it can never be
|
||||
# thrown in 3.5, we can just define our own class as a placeholder.
|
||||
class HTMLParseError(Exception):
|
||||
pass
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
# Starting in Python 3.2, the HTMLParser constructor takes a 'strict'
|
||||
# argument, which we'd like to set to False. Unfortunately,
|
||||
# http://bugs.python.org/issue13273 makes strict=True a better bet
|
||||
# before Python 3.2.3.
|
||||
#
|
||||
# At the end of this file, we monkeypatch HTMLParser so that
|
||||
# strict=True works well on Python 3.2.2.
|
||||
major, minor, release = sys.version_info[:3]
|
||||
CONSTRUCTOR_TAKES_STRICT = major == 3 and minor == 2 and release >= 3
|
||||
CONSTRUCTOR_STRICT_IS_DEPRECATED = major == 3 and minor == 3
|
||||
CONSTRUCTOR_TAKES_CONVERT_CHARREFS = major == 3 and minor >= 4
|
||||
|
||||
|
||||
from bs4.element import (
|
||||
CData,
|
||||
Comment,
|
||||
Declaration,
|
||||
Doctype,
|
||||
ProcessingInstruction,
|
||||
)
|
||||
from bs4.dammit import EntitySubstitution, UnicodeDammit
|
||||
|
||||
from bs4.builder import (
|
||||
HTML,
|
||||
HTMLTreeBuilder,
|
||||
STRICT,
|
||||
)
|
||||
|
||||
|
||||
HTMLPARSER = 'html.parser'
|
||||
|
||||
class BeautifulSoupHTMLParser(HTMLParser):
|
||||
def handle_starttag(self, name, attrs):
|
||||
# XXX namespace
|
||||
attr_dict = {}
|
||||
for key, value in attrs:
|
||||
# Change None attribute values to the empty string
|
||||
# for consistency with the other tree builders.
|
||||
if value is None:
|
||||
value = ''
|
||||
attr_dict[key] = value
|
||||
attrvalue = '""'
|
||||
self.soup.handle_starttag(name, None, None, attr_dict)
|
||||
|
||||
def handle_endtag(self, name):
|
||||
self.soup.handle_endtag(name)
|
||||
|
||||
def handle_data(self, data):
|
||||
self.soup.handle_data(data)
|
||||
|
||||
def handle_charref(self, name):
|
||||
# XXX workaround for a bug in HTMLParser. Remove this once
|
||||
# it's fixed in all supported versions.
|
||||
# http://bugs.python.org/issue13633
|
||||
if name.startswith('x'):
|
||||
real_name = int(name.lstrip('x'), 16)
|
||||
elif name.startswith('X'):
|
||||
real_name = int(name.lstrip('X'), 16)
|
||||
else:
|
||||
real_name = int(name)
|
||||
|
||||
try:
|
||||
data = unichr(real_name)
|
||||
except (ValueError, OverflowError), e:
|
||||
data = u"\N{REPLACEMENT CHARACTER}"
|
||||
|
||||
self.handle_data(data)
|
||||
|
||||
def handle_entityref(self, name):
|
||||
character = EntitySubstitution.HTML_ENTITY_TO_CHARACTER.get(name)
|
||||
if character is not None:
|
||||
data = character
|
||||
else:
|
||||
data = "&%s;" % name
|
||||
self.handle_data(data)
|
||||
|
||||
def handle_comment(self, data):
|
||||
self.soup.endData()
|
||||
self.soup.handle_data(data)
|
||||
self.soup.endData(Comment)
|
||||
|
||||
def handle_decl(self, data):
|
||||
self.soup.endData()
|
||||
if data.startswith("DOCTYPE "):
|
||||
data = data[len("DOCTYPE "):]
|
||||
elif data == 'DOCTYPE':
|
||||
# i.e. "<!DOCTYPE>"
|
||||
data = ''
|
||||
self.soup.handle_data(data)
|
||||
self.soup.endData(Doctype)
|
||||
|
||||
def unknown_decl(self, data):
|
||||
if data.upper().startswith('CDATA['):
|
||||
cls = CData
|
||||
data = data[len('CDATA['):]
|
||||
else:
|
||||
cls = Declaration
|
||||
self.soup.endData()
|
||||
self.soup.handle_data(data)
|
||||
self.soup.endData(cls)
|
||||
|
||||
def handle_pi(self, data):
|
||||
self.soup.endData()
|
||||
self.soup.handle_data(data)
|
||||
self.soup.endData(ProcessingInstruction)
|
||||
|
||||
|
||||
class HTMLParserTreeBuilder(HTMLTreeBuilder):
|
||||
|
||||
is_xml = False
|
||||
picklable = True
|
||||
NAME = HTMLPARSER
|
||||
features = [NAME, HTML, STRICT]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if CONSTRUCTOR_TAKES_STRICT and not CONSTRUCTOR_STRICT_IS_DEPRECATED:
|
||||
kwargs['strict'] = False
|
||||
if CONSTRUCTOR_TAKES_CONVERT_CHARREFS:
|
||||
kwargs['convert_charrefs'] = False
|
||||
self.parser_args = (args, kwargs)
|
||||
|
||||
def prepare_markup(self, markup, user_specified_encoding=None,
|
||||
document_declared_encoding=None, exclude_encodings=None):
|
||||
"""
|
||||
:return: A 4-tuple (markup, original encoding, encoding
|
||||
declared within markup, whether any characters had to be
|
||||
replaced with REPLACEMENT CHARACTER).
|
||||
"""
|
||||
if isinstance(markup, unicode):
|
||||
yield (markup, None, None, False)
|
||||
return
|
||||
|
||||
try_encodings = [user_specified_encoding, document_declared_encoding]
|
||||
dammit = UnicodeDammit(markup, try_encodings, is_html=True,
|
||||
exclude_encodings=exclude_encodings)
|
||||
yield (dammit.markup, dammit.original_encoding,
|
||||
dammit.declared_html_encoding,
|
||||
dammit.contains_replacement_characters)
|
||||
|
||||
def feed(self, markup):
|
||||
args, kwargs = self.parser_args
|
||||
parser = BeautifulSoupHTMLParser(*args, **kwargs)
|
||||
parser.soup = self.soup
|
||||
try:
|
||||
parser.feed(markup)
|
||||
except HTMLParseError, e:
|
||||
warnings.warn(RuntimeWarning(
|
||||
"Python's built-in HTMLParser cannot parse the given document. This is not a bug in Beautiful Soup. The best solution is to install an external parser (lxml or html5lib), and use Beautiful Soup with that parser. See http://www.crummy.com/software/BeautifulSoup/bs4/doc/#installing-a-parser for help."))
|
||||
raise e
|
||||
|
||||
# Patch 3.2 versions of HTMLParser earlier than 3.2.3 to use some
|
||||
# 3.2.3 code. This ensures they don't treat markup like <p></p> as a
|
||||
# string.
|
||||
#
|
||||
# XXX This code can be removed once most Python 3 users are on 3.2.3.
|
||||
if major == 3 and minor == 2 and not CONSTRUCTOR_TAKES_STRICT:
|
||||
import re
|
||||
attrfind_tolerant = re.compile(
|
||||
r'\s*((?<=[\'"\s])[^\s/>][^\s/=>]*)(\s*=+\s*'
|
||||
r'(\'[^\']*\'|"[^"]*"|(?![\'"])[^>\s]*))?')
|
||||
HTMLParserTreeBuilder.attrfind_tolerant = attrfind_tolerant
|
||||
|
||||
locatestarttagend = re.compile(r"""
|
||||
<[a-zA-Z][-.a-zA-Z0-9:_]* # tag name
|
||||
(?:\s+ # whitespace before attribute name
|
||||
(?:[a-zA-Z_][-.:a-zA-Z0-9_]* # attribute name
|
||||
(?:\s*=\s* # value indicator
|
||||
(?:'[^']*' # LITA-enclosed value
|
||||
|\"[^\"]*\" # LIT-enclosed value
|
||||
|[^'\">\s]+ # bare value
|
||||
)
|
||||
)?
|
||||
)
|
||||
)*
|
||||
\s* # trailing whitespace
|
||||
""", re.VERBOSE)
|
||||
BeautifulSoupHTMLParser.locatestarttagend = locatestarttagend
|
||||
|
||||
from html.parser import tagfind, attrfind
|
||||
|
||||
def parse_starttag(self, i):
|
||||
self.__starttag_text = None
|
||||
endpos = self.check_for_whole_start_tag(i)
|
||||
if endpos < 0:
|
||||
return endpos
|
||||
rawdata = self.rawdata
|
||||
self.__starttag_text = rawdata[i:endpos]
|
||||
|
||||
# Now parse the data between i+1 and j into a tag and attrs
|
||||
attrs = []
|
||||
match = tagfind.match(rawdata, i+1)
|
||||
assert match, 'unexpected call to parse_starttag()'
|
||||
k = match.end()
|
||||
self.lasttag = tag = rawdata[i+1:k].lower()
|
||||
while k < endpos:
|
||||
if self.strict:
|
||||
m = attrfind.match(rawdata, k)
|
||||
else:
|
||||
m = attrfind_tolerant.match(rawdata, k)
|
||||
if not m:
|
||||
break
|
||||
attrname, rest, attrvalue = m.group(1, 2, 3)
|
||||
if not rest:
|
||||
attrvalue = None
|
||||
elif attrvalue[:1] == '\'' == attrvalue[-1:] or \
|
||||
attrvalue[:1] == '"' == attrvalue[-1:]:
|
||||
attrvalue = attrvalue[1:-1]
|
||||
if attrvalue:
|
||||
attrvalue = self.unescape(attrvalue)
|
||||
attrs.append((attrname.lower(), attrvalue))
|
||||
k = m.end()
|
||||
|
||||
end = rawdata[k:endpos].strip()
|
||||
if end not in (">", "/>"):
|
||||
lineno, offset = self.getpos()
|
||||
if "\n" in self.__starttag_text:
|
||||
lineno = lineno + self.__starttag_text.count("\n")
|
||||
offset = len(self.__starttag_text) \
|
||||
- self.__starttag_text.rfind("\n")
|
||||
else:
|
||||
offset = offset + len(self.__starttag_text)
|
||||
if self.strict:
|
||||
self.error("junk characters in start tag: %r"
|
||||
% (rawdata[k:endpos][:20],))
|
||||
self.handle_data(rawdata[i:endpos])
|
||||
return endpos
|
||||
if end.endswith('/>'):
|
||||
# XHTML-style empty tag: <span attr="value" />
|
||||
self.handle_startendtag(tag, attrs)
|
||||
else:
|
||||
self.handle_starttag(tag, attrs)
|
||||
if tag in self.CDATA_CONTENT_ELEMENTS:
|
||||
self.set_cdata_mode(tag)
|
||||
return endpos
|
||||
|
||||
def set_cdata_mode(self, elem):
|
||||
self.cdata_elem = elem.lower()
|
||||
self.interesting = re.compile(r'</\s*%s\s*>' % self.cdata_elem, re.I)
|
||||
|
||||
BeautifulSoupHTMLParser.parse_starttag = parse_starttag
|
||||
BeautifulSoupHTMLParser.set_cdata_mode = set_cdata_mode
|
||||
|
||||
CONSTRUCTOR_TAKES_STRICT = True
|
||||
@@ -0,0 +1,248 @@
|
||||
__all__ = [
|
||||
'LXMLTreeBuilderForXML',
|
||||
'LXMLTreeBuilder',
|
||||
]
|
||||
|
||||
from io import BytesIO
|
||||
from StringIO import StringIO
|
||||
import collections
|
||||
from lxml import etree
|
||||
from bs4.element import (
|
||||
Comment,
|
||||
Doctype,
|
||||
NamespacedAttribute,
|
||||
ProcessingInstruction,
|
||||
)
|
||||
from bs4.builder import (
|
||||
FAST,
|
||||
HTML,
|
||||
HTMLTreeBuilder,
|
||||
PERMISSIVE,
|
||||
ParserRejectedMarkup,
|
||||
TreeBuilder,
|
||||
XML)
|
||||
from bs4.dammit import EncodingDetector
|
||||
|
||||
LXML = 'lxml'
|
||||
|
||||
class LXMLTreeBuilderForXML(TreeBuilder):
|
||||
DEFAULT_PARSER_CLASS = etree.XMLParser
|
||||
|
||||
is_xml = True
|
||||
|
||||
NAME = "lxml-xml"
|
||||
ALTERNATE_NAMES = ["xml"]
|
||||
|
||||
# Well, it's permissive by XML parser standards.
|
||||
features = [NAME, LXML, XML, FAST, PERMISSIVE]
|
||||
|
||||
CHUNK_SIZE = 512
|
||||
|
||||
# This namespace mapping is specified in the XML Namespace
|
||||
# standard.
|
||||
DEFAULT_NSMAPS = {'http://www.w3.org/XML/1998/namespace' : "xml"}
|
||||
|
||||
def default_parser(self, encoding):
|
||||
# This can either return a parser object or a class, which
|
||||
# will be instantiated with default arguments.
|
||||
if self._default_parser is not None:
|
||||
return self._default_parser
|
||||
return etree.XMLParser(
|
||||
target=self, strip_cdata=False, recover=True, encoding=encoding)
|
||||
|
||||
def parser_for(self, encoding):
|
||||
# Use the default parser.
|
||||
parser = self.default_parser(encoding)
|
||||
|
||||
if isinstance(parser, collections.Callable):
|
||||
# Instantiate the parser with default arguments
|
||||
parser = parser(target=self, strip_cdata=False, encoding=encoding)
|
||||
return parser
|
||||
|
||||
def __init__(self, parser=None, empty_element_tags=None):
|
||||
# TODO: Issue a warning if parser is present but not a
|
||||
# callable, since that means there's no way to create new
|
||||
# parsers for different encodings.
|
||||
self._default_parser = parser
|
||||
if empty_element_tags is not None:
|
||||
self.empty_element_tags = set(empty_element_tags)
|
||||
self.soup = None
|
||||
self.nsmaps = [self.DEFAULT_NSMAPS]
|
||||
|
||||
def _getNsTag(self, tag):
|
||||
# Split the namespace URL out of a fully-qualified lxml tag
|
||||
# name. Copied from lxml's src/lxml/sax.py.
|
||||
if tag[0] == '{':
|
||||
return tuple(tag[1:].split('}', 1))
|
||||
else:
|
||||
return (None, tag)
|
||||
|
||||
def prepare_markup(self, markup, user_specified_encoding=None,
|
||||
exclude_encodings=None,
|
||||
document_declared_encoding=None):
|
||||
"""
|
||||
:yield: A series of 4-tuples.
|
||||
(markup, encoding, declared encoding,
|
||||
has undergone character replacement)
|
||||
|
||||
Each 4-tuple represents a strategy for parsing the document.
|
||||
"""
|
||||
if isinstance(markup, unicode):
|
||||
# We were given Unicode. Maybe lxml can parse Unicode on
|
||||
# this system?
|
||||
yield markup, None, document_declared_encoding, False
|
||||
|
||||
if isinstance(markup, unicode):
|
||||
# No, apparently not. Convert the Unicode to UTF-8 and
|
||||
# tell lxml to parse it as UTF-8.
|
||||
yield (markup.encode("utf8"), "utf8",
|
||||
document_declared_encoding, False)
|
||||
|
||||
# Instead of using UnicodeDammit to convert the bytestring to
|
||||
# Unicode using different encodings, use EncodingDetector to
|
||||
# iterate over the encodings, and tell lxml to try to parse
|
||||
# the document as each one in turn.
|
||||
is_html = not self.is_xml
|
||||
try_encodings = [user_specified_encoding, document_declared_encoding]
|
||||
detector = EncodingDetector(
|
||||
markup, try_encodings, is_html, exclude_encodings)
|
||||
for encoding in detector.encodings:
|
||||
yield (detector.markup, encoding, document_declared_encoding, False)
|
||||
|
||||
def feed(self, markup):
|
||||
if isinstance(markup, bytes):
|
||||
markup = BytesIO(markup)
|
||||
elif isinstance(markup, unicode):
|
||||
markup = StringIO(markup)
|
||||
|
||||
# Call feed() at least once, even if the markup is empty,
|
||||
# or the parser won't be initialized.
|
||||
data = markup.read(self.CHUNK_SIZE)
|
||||
try:
|
||||
self.parser = self.parser_for(self.soup.original_encoding)
|
||||
self.parser.feed(data)
|
||||
while len(data) != 0:
|
||||
# Now call feed() on the rest of the data, chunk by chunk.
|
||||
data = markup.read(self.CHUNK_SIZE)
|
||||
if len(data) != 0:
|
||||
self.parser.feed(data)
|
||||
self.parser.close()
|
||||
except (UnicodeDecodeError, LookupError, etree.ParserError), e:
|
||||
raise ParserRejectedMarkup(str(e))
|
||||
|
||||
def close(self):
|
||||
self.nsmaps = [self.DEFAULT_NSMAPS]
|
||||
|
||||
def start(self, name, attrs, nsmap={}):
|
||||
# Make sure attrs is a mutable dict--lxml may send an immutable dictproxy.
|
||||
attrs = dict(attrs)
|
||||
nsprefix = None
|
||||
# Invert each namespace map as it comes in.
|
||||
if len(self.nsmaps) > 1:
|
||||
# There are no new namespaces for this tag, but
|
||||
# non-default namespaces are in play, so we need a
|
||||
# separate tag stack to know when they end.
|
||||
self.nsmaps.append(None)
|
||||
elif len(nsmap) > 0:
|
||||
# A new namespace mapping has come into play.
|
||||
inverted_nsmap = dict((value, key) for key, value in nsmap.items())
|
||||
self.nsmaps.append(inverted_nsmap)
|
||||
# Also treat the namespace mapping as a set of attributes on the
|
||||
# tag, so we can recreate it later.
|
||||
attrs = attrs.copy()
|
||||
for prefix, namespace in nsmap.items():
|
||||
attribute = NamespacedAttribute(
|
||||
"xmlns", prefix, "http://www.w3.org/2000/xmlns/")
|
||||
attrs[attribute] = namespace
|
||||
|
||||
# Namespaces are in play. Find any attributes that came in
|
||||
# from lxml with namespaces attached to their names, and
|
||||
# turn then into NamespacedAttribute objects.
|
||||
new_attrs = {}
|
||||
for attr, value in attrs.items():
|
||||
namespace, attr = self._getNsTag(attr)
|
||||
if namespace is None:
|
||||
new_attrs[attr] = value
|
||||
else:
|
||||
nsprefix = self._prefix_for_namespace(namespace)
|
||||
attr = NamespacedAttribute(nsprefix, attr, namespace)
|
||||
new_attrs[attr] = value
|
||||
attrs = new_attrs
|
||||
|
||||
namespace, name = self._getNsTag(name)
|
||||
nsprefix = self._prefix_for_namespace(namespace)
|
||||
self.soup.handle_starttag(name, namespace, nsprefix, attrs)
|
||||
|
||||
def _prefix_for_namespace(self, namespace):
|
||||
"""Find the currently active prefix for the given namespace."""
|
||||
if namespace is None:
|
||||
return None
|
||||
for inverted_nsmap in reversed(self.nsmaps):
|
||||
if inverted_nsmap is not None and namespace in inverted_nsmap:
|
||||
return inverted_nsmap[namespace]
|
||||
return None
|
||||
|
||||
def end(self, name):
|
||||
self.soup.endData()
|
||||
completed_tag = self.soup.tagStack[-1]
|
||||
namespace, name = self._getNsTag(name)
|
||||
nsprefix = None
|
||||
if namespace is not None:
|
||||
for inverted_nsmap in reversed(self.nsmaps):
|
||||
if inverted_nsmap is not None and namespace in inverted_nsmap:
|
||||
nsprefix = inverted_nsmap[namespace]
|
||||
break
|
||||
self.soup.handle_endtag(name, nsprefix)
|
||||
if len(self.nsmaps) > 1:
|
||||
# This tag, or one of its parents, introduced a namespace
|
||||
# mapping, so pop it off the stack.
|
||||
self.nsmaps.pop()
|
||||
|
||||
def pi(self, target, data):
|
||||
self.soup.endData()
|
||||
self.soup.handle_data(target + ' ' + data)
|
||||
self.soup.endData(ProcessingInstruction)
|
||||
|
||||
def data(self, content):
|
||||
self.soup.handle_data(content)
|
||||
|
||||
def doctype(self, name, pubid, system):
|
||||
self.soup.endData()
|
||||
doctype = Doctype.for_name_and_ids(name, pubid, system)
|
||||
self.soup.object_was_parsed(doctype)
|
||||
|
||||
def comment(self, content):
|
||||
"Handle comments as Comment objects."
|
||||
self.soup.endData()
|
||||
self.soup.handle_data(content)
|
||||
self.soup.endData(Comment)
|
||||
|
||||
def test_fragment_to_document(self, fragment):
|
||||
"""See `TreeBuilder`."""
|
||||
return u'<?xml version="1.0" encoding="utf-8"?>\n%s' % fragment
|
||||
|
||||
|
||||
class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML):
|
||||
|
||||
NAME = LXML
|
||||
ALTERNATE_NAMES = ["lxml-html"]
|
||||
|
||||
features = ALTERNATE_NAMES + [NAME, HTML, FAST, PERMISSIVE]
|
||||
is_xml = False
|
||||
|
||||
def default_parser(self, encoding):
|
||||
return etree.HTMLParser
|
||||
|
||||
def feed(self, markup):
|
||||
encoding = self.soup.original_encoding
|
||||
try:
|
||||
self.parser = self.parser_for(encoding)
|
||||
self.parser.feed(markup)
|
||||
self.parser.close()
|
||||
except (UnicodeDecodeError, LookupError, etree.ParserError), e:
|
||||
raise ParserRejectedMarkup(str(e))
|
||||
|
||||
|
||||
def test_fragment_to_document(self, fragment):
|
||||
"""See `TreeBuilder`."""
|
||||
return u'<html><body>%s</body></html>' % fragment
|
||||
@@ -0,0 +1,839 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Beautiful Soup bonus library: Unicode, Dammit
|
||||
|
||||
This library converts a bytestream to Unicode through any means
|
||||
necessary. It is heavily based on code from Mark Pilgrim's Universal
|
||||
Feed Parser. It works best on XML and HTML, but it does not rewrite the
|
||||
XML or HTML to reflect a new encoding; that's the tree builder's job.
|
||||
"""
|
||||
|
||||
from pdb import set_trace
|
||||
import codecs
|
||||
from htmlentitydefs import codepoint2name
|
||||
import re
|
||||
import logging
|
||||
import string
|
||||
|
||||
# Import a library to autodetect character encodings.
|
||||
chardet_type = None
|
||||
try:
|
||||
# First try the fast C implementation.
|
||||
# PyPI package: cchardet
|
||||
import cchardet
|
||||
def chardet_dammit(s):
|
||||
return cchardet.detect(s)['encoding']
|
||||
except ImportError:
|
||||
try:
|
||||
# Fall back to the pure Python implementation
|
||||
# Debian package: python-chardet
|
||||
# PyPI package: chardet
|
||||
import chardet
|
||||
def chardet_dammit(s):
|
||||
return chardet.detect(s)['encoding']
|
||||
#import chardet.constants
|
||||
#chardet.constants._debug = 1
|
||||
except ImportError:
|
||||
# No chardet available.
|
||||
def chardet_dammit(s):
|
||||
return None
|
||||
|
||||
# Available from http://cjkpython.i18n.org/.
|
||||
try:
|
||||
import iconv_codec
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
xml_encoding_re = re.compile(
|
||||
'^<\?.*encoding=[\'"](.*?)[\'"].*\?>'.encode(), re.I)
|
||||
html_meta_re = re.compile(
|
||||
'<\s*meta[^>]+charset\s*=\s*["\']?([^>]*?)[ /;\'">]'.encode(), re.I)
|
||||
|
||||
class EntitySubstitution(object):
|
||||
|
||||
"""Substitute XML or HTML entities for the corresponding characters."""
|
||||
|
||||
def _populate_class_variables():
|
||||
lookup = {}
|
||||
reverse_lookup = {}
|
||||
characters_for_re = []
|
||||
for codepoint, name in list(codepoint2name.items()):
|
||||
character = unichr(codepoint)
|
||||
if codepoint != 34:
|
||||
# There's no point in turning the quotation mark into
|
||||
# ", unless it happens within an attribute value, which
|
||||
# is handled elsewhere.
|
||||
characters_for_re.append(character)
|
||||
lookup[character] = name
|
||||
# But we do want to turn " into the quotation mark.
|
||||
reverse_lookup[name] = character
|
||||
re_definition = "[%s]" % "".join(characters_for_re)
|
||||
return lookup, reverse_lookup, re.compile(re_definition)
|
||||
(CHARACTER_TO_HTML_ENTITY, HTML_ENTITY_TO_CHARACTER,
|
||||
CHARACTER_TO_HTML_ENTITY_RE) = _populate_class_variables()
|
||||
|
||||
CHARACTER_TO_XML_ENTITY = {
|
||||
"'": "apos",
|
||||
'"': "quot",
|
||||
"&": "amp",
|
||||
"<": "lt",
|
||||
">": "gt",
|
||||
}
|
||||
|
||||
BARE_AMPERSAND_OR_BRACKET = re.compile("([<>]|"
|
||||
"&(?!#\d+;|#x[0-9a-fA-F]+;|\w+;)"
|
||||
")")
|
||||
|
||||
AMPERSAND_OR_BRACKET = re.compile("([<>&])")
|
||||
|
||||
@classmethod
|
||||
def _substitute_html_entity(cls, matchobj):
|
||||
entity = cls.CHARACTER_TO_HTML_ENTITY.get(matchobj.group(0))
|
||||
return "&%s;" % entity
|
||||
|
||||
@classmethod
|
||||
def _substitute_xml_entity(cls, matchobj):
|
||||
"""Used with a regular expression to substitute the
|
||||
appropriate XML entity for an XML special character."""
|
||||
entity = cls.CHARACTER_TO_XML_ENTITY[matchobj.group(0)]
|
||||
return "&%s;" % entity
|
||||
|
||||
@classmethod
|
||||
def quoted_attribute_value(self, value):
|
||||
"""Make a value into a quoted XML attribute, possibly escaping it.
|
||||
|
||||
Most strings will be quoted using double quotes.
|
||||
|
||||
Bob's Bar -> "Bob's Bar"
|
||||
|
||||
If a string contains double quotes, it will be quoted using
|
||||
single quotes.
|
||||
|
||||
Welcome to "my bar" -> 'Welcome to "my bar"'
|
||||
|
||||
If a string contains both single and double quotes, the
|
||||
double quotes will be escaped, and the string will be quoted
|
||||
using double quotes.
|
||||
|
||||
Welcome to "Bob's Bar" -> "Welcome to "Bob's bar"
|
||||
"""
|
||||
quote_with = '"'
|
||||
if '"' in value:
|
||||
if "'" in value:
|
||||
# The string contains both single and double
|
||||
# quotes. Turn the double quotes into
|
||||
# entities. We quote the double quotes rather than
|
||||
# the single quotes because the entity name is
|
||||
# """ whether this is HTML or XML. If we
|
||||
# quoted the single quotes, we'd have to decide
|
||||
# between ' and &squot;.
|
||||
replace_with = """
|
||||
value = value.replace('"', replace_with)
|
||||
else:
|
||||
# There are double quotes but no single quotes.
|
||||
# We can use single quotes to quote the attribute.
|
||||
quote_with = "'"
|
||||
return quote_with + value + quote_with
|
||||
|
||||
@classmethod
|
||||
def substitute_xml(cls, value, make_quoted_attribute=False):
|
||||
"""Substitute XML entities for special XML characters.
|
||||
|
||||
:param value: A string to be substituted. The less-than sign
|
||||
will become <, the greater-than sign will become >,
|
||||
and any ampersands will become &. If you want ampersands
|
||||
that appear to be part of an entity definition to be left
|
||||
alone, use substitute_xml_containing_entities() instead.
|
||||
|
||||
:param make_quoted_attribute: If True, then the string will be
|
||||
quoted, as befits an attribute value.
|
||||
"""
|
||||
# Escape angle brackets and ampersands.
|
||||
value = cls.AMPERSAND_OR_BRACKET.sub(
|
||||
cls._substitute_xml_entity, value)
|
||||
|
||||
if make_quoted_attribute:
|
||||
value = cls.quoted_attribute_value(value)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def substitute_xml_containing_entities(
|
||||
cls, value, make_quoted_attribute=False):
|
||||
"""Substitute XML entities for special XML characters.
|
||||
|
||||
:param value: A string to be substituted. The less-than sign will
|
||||
become <, the greater-than sign will become >, and any
|
||||
ampersands that are not part of an entity defition will
|
||||
become &.
|
||||
|
||||
:param make_quoted_attribute: If True, then the string will be
|
||||
quoted, as befits an attribute value.
|
||||
"""
|
||||
# Escape angle brackets, and ampersands that aren't part of
|
||||
# entities.
|
||||
value = cls.BARE_AMPERSAND_OR_BRACKET.sub(
|
||||
cls._substitute_xml_entity, value)
|
||||
|
||||
if make_quoted_attribute:
|
||||
value = cls.quoted_attribute_value(value)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def substitute_html(cls, s):
|
||||
"""Replace certain Unicode characters with named HTML entities.
|
||||
|
||||
This differs from data.encode(encoding, 'xmlcharrefreplace')
|
||||
in that the goal is to make the result more readable (to those
|
||||
with ASCII displays) rather than to recover from
|
||||
errors. There's absolutely nothing wrong with a UTF-8 string
|
||||
containg a LATIN SMALL LETTER E WITH ACUTE, but replacing that
|
||||
character with "é" will make it more readable to some
|
||||
people.
|
||||
"""
|
||||
return cls.CHARACTER_TO_HTML_ENTITY_RE.sub(
|
||||
cls._substitute_html_entity, s)
|
||||
|
||||
|
||||
class EncodingDetector:
|
||||
"""Suggests a number of possible encodings for a bytestring.
|
||||
|
||||
Order of precedence:
|
||||
|
||||
1. Encodings you specifically tell EncodingDetector to try first
|
||||
(the override_encodings argument to the constructor).
|
||||
|
||||
2. An encoding declared within the bytestring itself, either in an
|
||||
XML declaration (if the bytestring is to be interpreted as an XML
|
||||
document), or in a <meta> tag (if the bytestring is to be
|
||||
interpreted as an HTML document.)
|
||||
|
||||
3. An encoding detected through textual analysis by chardet,
|
||||
cchardet, or a similar external library.
|
||||
|
||||
4. UTF-8.
|
||||
|
||||
5. Windows-1252.
|
||||
"""
|
||||
def __init__(self, markup, override_encodings=None, is_html=False,
|
||||
exclude_encodings=None):
|
||||
self.override_encodings = override_encodings or []
|
||||
exclude_encodings = exclude_encodings or []
|
||||
self.exclude_encodings = set([x.lower() for x in exclude_encodings])
|
||||
self.chardet_encoding = None
|
||||
self.is_html = is_html
|
||||
self.declared_encoding = None
|
||||
|
||||
# First order of business: strip a byte-order mark.
|
||||
self.markup, self.sniffed_encoding = self.strip_byte_order_mark(markup)
|
||||
|
||||
def _usable(self, encoding, tried):
|
||||
if encoding is not None:
|
||||
encoding = encoding.lower()
|
||||
if encoding in self.exclude_encodings:
|
||||
return False
|
||||
if encoding not in tried:
|
||||
tried.add(encoding)
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def encodings(self):
|
||||
"""Yield a number of encodings that might work for this markup."""
|
||||
tried = set()
|
||||
for e in self.override_encodings:
|
||||
if self._usable(e, tried):
|
||||
yield e
|
||||
|
||||
# Did the document originally start with a byte-order mark
|
||||
# that indicated its encoding?
|
||||
if self._usable(self.sniffed_encoding, tried):
|
||||
yield self.sniffed_encoding
|
||||
|
||||
# Look within the document for an XML or HTML encoding
|
||||
# declaration.
|
||||
if self.declared_encoding is None:
|
||||
self.declared_encoding = self.find_declared_encoding(
|
||||
self.markup, self.is_html)
|
||||
if self._usable(self.declared_encoding, tried):
|
||||
yield self.declared_encoding
|
||||
|
||||
# Use third-party character set detection to guess at the
|
||||
# encoding.
|
||||
if self.chardet_encoding is None:
|
||||
self.chardet_encoding = chardet_dammit(self.markup)
|
||||
if self._usable(self.chardet_encoding, tried):
|
||||
yield self.chardet_encoding
|
||||
|
||||
# As a last-ditch effort, try utf-8 and windows-1252.
|
||||
for e in ('utf-8', 'windows-1252'):
|
||||
if self._usable(e, tried):
|
||||
yield e
|
||||
|
||||
@classmethod
|
||||
def strip_byte_order_mark(cls, data):
|
||||
"""If a byte-order mark is present, strip it and return the encoding it implies."""
|
||||
encoding = None
|
||||
if isinstance(data, unicode):
|
||||
# Unicode data cannot have a byte-order mark.
|
||||
return data, encoding
|
||||
if (len(data) >= 4) and (data[:2] == b'\xfe\xff') \
|
||||
and (data[2:4] != '\x00\x00'):
|
||||
encoding = 'utf-16be'
|
||||
data = data[2:]
|
||||
elif (len(data) >= 4) and (data[:2] == b'\xff\xfe') \
|
||||
and (data[2:4] != '\x00\x00'):
|
||||
encoding = 'utf-16le'
|
||||
data = data[2:]
|
||||
elif data[:3] == b'\xef\xbb\xbf':
|
||||
encoding = 'utf-8'
|
||||
data = data[3:]
|
||||
elif data[:4] == b'\x00\x00\xfe\xff':
|
||||
encoding = 'utf-32be'
|
||||
data = data[4:]
|
||||
elif data[:4] == b'\xff\xfe\x00\x00':
|
||||
encoding = 'utf-32le'
|
||||
data = data[4:]
|
||||
return data, encoding
|
||||
|
||||
@classmethod
|
||||
def find_declared_encoding(cls, markup, is_html=False, search_entire_document=False):
|
||||
"""Given a document, tries to find its declared encoding.
|
||||
|
||||
An XML encoding is declared at the beginning of the document.
|
||||
|
||||
An HTML encoding is declared in a <meta> tag, hopefully near the
|
||||
beginning of the document.
|
||||
"""
|
||||
if search_entire_document:
|
||||
xml_endpos = html_endpos = len(markup)
|
||||
else:
|
||||
xml_endpos = 1024
|
||||
html_endpos = max(2048, int(len(markup) * 0.05))
|
||||
|
||||
declared_encoding = None
|
||||
declared_encoding_match = xml_encoding_re.search(markup, endpos=xml_endpos)
|
||||
if not declared_encoding_match and is_html:
|
||||
declared_encoding_match = html_meta_re.search(markup, endpos=html_endpos)
|
||||
if declared_encoding_match is not None:
|
||||
declared_encoding = declared_encoding_match.groups()[0].decode(
|
||||
'ascii', 'replace')
|
||||
if declared_encoding:
|
||||
return declared_encoding.lower()
|
||||
return None
|
||||
|
||||
class UnicodeDammit:
|
||||
"""A class for detecting the encoding of a *ML document and
|
||||
converting it to a Unicode string. If the source encoding is
|
||||
windows-1252, can replace MS smart quotes with their HTML or XML
|
||||
equivalents."""
|
||||
|
||||
# This dictionary maps commonly seen values for "charset" in HTML
|
||||
# meta tags to the corresponding Python codec names. It only covers
|
||||
# values that aren't in Python's aliases and can't be determined
|
||||
# by the heuristics in find_codec.
|
||||
CHARSET_ALIASES = {"macintosh": "mac-roman",
|
||||
"x-sjis": "shift-jis"}
|
||||
|
||||
ENCODINGS_WITH_SMART_QUOTES = [
|
||||
"windows-1252",
|
||||
"iso-8859-1",
|
||||
"iso-8859-2",
|
||||
]
|
||||
|
||||
def __init__(self, markup, override_encodings=[],
|
||||
smart_quotes_to=None, is_html=False, exclude_encodings=[]):
|
||||
self.smart_quotes_to = smart_quotes_to
|
||||
self.tried_encodings = []
|
||||
self.contains_replacement_characters = False
|
||||
self.is_html = is_html
|
||||
|
||||
self.detector = EncodingDetector(
|
||||
markup, override_encodings, is_html, exclude_encodings)
|
||||
|
||||
# Short-circuit if the data is in Unicode to begin with.
|
||||
if isinstance(markup, unicode) or markup == '':
|
||||
self.markup = markup
|
||||
self.unicode_markup = unicode(markup)
|
||||
self.original_encoding = None
|
||||
return
|
||||
|
||||
# The encoding detector may have stripped a byte-order mark.
|
||||
# Use the stripped markup from this point on.
|
||||
self.markup = self.detector.markup
|
||||
|
||||
u = None
|
||||
for encoding in self.detector.encodings:
|
||||
markup = self.detector.markup
|
||||
u = self._convert_from(encoding)
|
||||
if u is not None:
|
||||
break
|
||||
|
||||
if not u:
|
||||
# None of the encodings worked. As an absolute last resort,
|
||||
# try them again with character replacement.
|
||||
|
||||
for encoding in self.detector.encodings:
|
||||
if encoding != "ascii":
|
||||
u = self._convert_from(encoding, "replace")
|
||||
if u is not None:
|
||||
logging.warning(
|
||||
"Some characters could not be decoded, and were "
|
||||
"replaced with REPLACEMENT CHARACTER.")
|
||||
self.contains_replacement_characters = True
|
||||
break
|
||||
|
||||
# If none of that worked, we could at this point force it to
|
||||
# ASCII, but that would destroy so much data that I think
|
||||
# giving up is better.
|
||||
self.unicode_markup = u
|
||||
if not u:
|
||||
self.original_encoding = None
|
||||
|
||||
def _sub_ms_char(self, match):
|
||||
"""Changes a MS smart quote character to an XML or HTML
|
||||
entity, or an ASCII character."""
|
||||
orig = match.group(1)
|
||||
if self.smart_quotes_to == 'ascii':
|
||||
sub = self.MS_CHARS_TO_ASCII.get(orig).encode()
|
||||
else:
|
||||
sub = self.MS_CHARS.get(orig)
|
||||
if type(sub) == tuple:
|
||||
if self.smart_quotes_to == 'xml':
|
||||
sub = '&#x'.encode() + sub[1].encode() + ';'.encode()
|
||||
else:
|
||||
sub = '&'.encode() + sub[0].encode() + ';'.encode()
|
||||
else:
|
||||
sub = sub.encode()
|
||||
return sub
|
||||
|
||||
def _convert_from(self, proposed, errors="strict"):
|
||||
proposed = self.find_codec(proposed)
|
||||
if not proposed or (proposed, errors) in self.tried_encodings:
|
||||
return None
|
||||
self.tried_encodings.append((proposed, errors))
|
||||
markup = self.markup
|
||||
# Convert smart quotes to HTML if coming from an encoding
|
||||
# that might have them.
|
||||
if (self.smart_quotes_to is not None
|
||||
and proposed in self.ENCODINGS_WITH_SMART_QUOTES):
|
||||
smart_quotes_re = b"([\x80-\x9f])"
|
||||
smart_quotes_compiled = re.compile(smart_quotes_re)
|
||||
markup = smart_quotes_compiled.sub(self._sub_ms_char, markup)
|
||||
|
||||
try:
|
||||
#print "Trying to convert document to %s (errors=%s)" % (
|
||||
# proposed, errors)
|
||||
u = self._to_unicode(markup, proposed, errors)
|
||||
self.markup = u
|
||||
self.original_encoding = proposed
|
||||
except Exception as e:
|
||||
#print "That didn't work!"
|
||||
#print e
|
||||
return None
|
||||
#print "Correct encoding: %s" % proposed
|
||||
return self.markup
|
||||
|
||||
def _to_unicode(self, data, encoding, errors="strict"):
|
||||
'''Given a string and its encoding, decodes the string into Unicode.
|
||||
%encoding is a string recognized by encodings.aliases'''
|
||||
return unicode(data, encoding, errors)
|
||||
|
||||
@property
|
||||
def declared_html_encoding(self):
|
||||
if not self.is_html:
|
||||
return None
|
||||
return self.detector.declared_encoding
|
||||
|
||||
def find_codec(self, charset):
|
||||
value = (self._codec(self.CHARSET_ALIASES.get(charset, charset))
|
||||
or (charset and self._codec(charset.replace("-", "")))
|
||||
or (charset and self._codec(charset.replace("-", "_")))
|
||||
or (charset and charset.lower())
|
||||
or charset
|
||||
)
|
||||
if value:
|
||||
return value.lower()
|
||||
return None
|
||||
|
||||
def _codec(self, charset):
|
||||
if not charset:
|
||||
return charset
|
||||
codec = None
|
||||
try:
|
||||
codecs.lookup(charset)
|
||||
codec = charset
|
||||
except (LookupError, ValueError):
|
||||
pass
|
||||
return codec
|
||||
|
||||
|
||||
# A partial mapping of ISO-Latin-1 to HTML entities/XML numeric entities.
|
||||
MS_CHARS = {b'\x80': ('euro', '20AC'),
|
||||
b'\x81': ' ',
|
||||
b'\x82': ('sbquo', '201A'),
|
||||
b'\x83': ('fnof', '192'),
|
||||
b'\x84': ('bdquo', '201E'),
|
||||
b'\x85': ('hellip', '2026'),
|
||||
b'\x86': ('dagger', '2020'),
|
||||
b'\x87': ('Dagger', '2021'),
|
||||
b'\x88': ('circ', '2C6'),
|
||||
b'\x89': ('permil', '2030'),
|
||||
b'\x8A': ('Scaron', '160'),
|
||||
b'\x8B': ('lsaquo', '2039'),
|
||||
b'\x8C': ('OElig', '152'),
|
||||
b'\x8D': '?',
|
||||
b'\x8E': ('#x17D', '17D'),
|
||||
b'\x8F': '?',
|
||||
b'\x90': '?',
|
||||
b'\x91': ('lsquo', '2018'),
|
||||
b'\x92': ('rsquo', '2019'),
|
||||
b'\x93': ('ldquo', '201C'),
|
||||
b'\x94': ('rdquo', '201D'),
|
||||
b'\x95': ('bull', '2022'),
|
||||
b'\x96': ('ndash', '2013'),
|
||||
b'\x97': ('mdash', '2014'),
|
||||
b'\x98': ('tilde', '2DC'),
|
||||
b'\x99': ('trade', '2122'),
|
||||
b'\x9a': ('scaron', '161'),
|
||||
b'\x9b': ('rsaquo', '203A'),
|
||||
b'\x9c': ('oelig', '153'),
|
||||
b'\x9d': '?',
|
||||
b'\x9e': ('#x17E', '17E'),
|
||||
b'\x9f': ('Yuml', ''),}
|
||||
|
||||
# A parochial partial mapping of ISO-Latin-1 to ASCII. Contains
|
||||
# horrors like stripping diacritical marks to turn á into a, but also
|
||||
# contains non-horrors like turning “ into ".
|
||||
MS_CHARS_TO_ASCII = {
|
||||
b'\x80' : 'EUR',
|
||||
b'\x81' : ' ',
|
||||
b'\x82' : ',',
|
||||
b'\x83' : 'f',
|
||||
b'\x84' : ',,',
|
||||
b'\x85' : '...',
|
||||
b'\x86' : '+',
|
||||
b'\x87' : '++',
|
||||
b'\x88' : '^',
|
||||
b'\x89' : '%',
|
||||
b'\x8a' : 'S',
|
||||
b'\x8b' : '<',
|
||||
b'\x8c' : 'OE',
|
||||
b'\x8d' : '?',
|
||||
b'\x8e' : 'Z',
|
||||
b'\x8f' : '?',
|
||||
b'\x90' : '?',
|
||||
b'\x91' : "'",
|
||||
b'\x92' : "'",
|
||||
b'\x93' : '"',
|
||||
b'\x94' : '"',
|
||||
b'\x95' : '*',
|
||||
b'\x96' : '-',
|
||||
b'\x97' : '--',
|
||||
b'\x98' : '~',
|
||||
b'\x99' : '(TM)',
|
||||
b'\x9a' : 's',
|
||||
b'\x9b' : '>',
|
||||
b'\x9c' : 'oe',
|
||||
b'\x9d' : '?',
|
||||
b'\x9e' : 'z',
|
||||
b'\x9f' : 'Y',
|
||||
b'\xa0' : ' ',
|
||||
b'\xa1' : '!',
|
||||
b'\xa2' : 'c',
|
||||
b'\xa3' : 'GBP',
|
||||
b'\xa4' : '$', #This approximation is especially parochial--this is the
|
||||
#generic currency symbol.
|
||||
b'\xa5' : 'YEN',
|
||||
b'\xa6' : '|',
|
||||
b'\xa7' : 'S',
|
||||
b'\xa8' : '..',
|
||||
b'\xa9' : '',
|
||||
b'\xaa' : '(th)',
|
||||
b'\xab' : '<<',
|
||||
b'\xac' : '!',
|
||||
b'\xad' : ' ',
|
||||
b'\xae' : '(R)',
|
||||
b'\xaf' : '-',
|
||||
b'\xb0' : 'o',
|
||||
b'\xb1' : '+-',
|
||||
b'\xb2' : '2',
|
||||
b'\xb3' : '3',
|
||||
b'\xb4' : ("'", 'acute'),
|
||||
b'\xb5' : 'u',
|
||||
b'\xb6' : 'P',
|
||||
b'\xb7' : '*',
|
||||
b'\xb8' : ',',
|
||||
b'\xb9' : '1',
|
||||
b'\xba' : '(th)',
|
||||
b'\xbb' : '>>',
|
||||
b'\xbc' : '1/4',
|
||||
b'\xbd' : '1/2',
|
||||
b'\xbe' : '3/4',
|
||||
b'\xbf' : '?',
|
||||
b'\xc0' : 'A',
|
||||
b'\xc1' : 'A',
|
||||
b'\xc2' : 'A',
|
||||
b'\xc3' : 'A',
|
||||
b'\xc4' : 'A',
|
||||
b'\xc5' : 'A',
|
||||
b'\xc6' : 'AE',
|
||||
b'\xc7' : 'C',
|
||||
b'\xc8' : 'E',
|
||||
b'\xc9' : 'E',
|
||||
b'\xca' : 'E',
|
||||
b'\xcb' : 'E',
|
||||
b'\xcc' : 'I',
|
||||
b'\xcd' : 'I',
|
||||
b'\xce' : 'I',
|
||||
b'\xcf' : 'I',
|
||||
b'\xd0' : 'D',
|
||||
b'\xd1' : 'N',
|
||||
b'\xd2' : 'O',
|
||||
b'\xd3' : 'O',
|
||||
b'\xd4' : 'O',
|
||||
b'\xd5' : 'O',
|
||||
b'\xd6' : 'O',
|
||||
b'\xd7' : '*',
|
||||
b'\xd8' : 'O',
|
||||
b'\xd9' : 'U',
|
||||
b'\xda' : 'U',
|
||||
b'\xdb' : 'U',
|
||||
b'\xdc' : 'U',
|
||||
b'\xdd' : 'Y',
|
||||
b'\xde' : 'b',
|
||||
b'\xdf' : 'B',
|
||||
b'\xe0' : 'a',
|
||||
b'\xe1' : 'a',
|
||||
b'\xe2' : 'a',
|
||||
b'\xe3' : 'a',
|
||||
b'\xe4' : 'a',
|
||||
b'\xe5' : 'a',
|
||||
b'\xe6' : 'ae',
|
||||
b'\xe7' : 'c',
|
||||
b'\xe8' : 'e',
|
||||
b'\xe9' : 'e',
|
||||
b'\xea' : 'e',
|
||||
b'\xeb' : 'e',
|
||||
b'\xec' : 'i',
|
||||
b'\xed' : 'i',
|
||||
b'\xee' : 'i',
|
||||
b'\xef' : 'i',
|
||||
b'\xf0' : 'o',
|
||||
b'\xf1' : 'n',
|
||||
b'\xf2' : 'o',
|
||||
b'\xf3' : 'o',
|
||||
b'\xf4' : 'o',
|
||||
b'\xf5' : 'o',
|
||||
b'\xf6' : 'o',
|
||||
b'\xf7' : '/',
|
||||
b'\xf8' : 'o',
|
||||
b'\xf9' : 'u',
|
||||
b'\xfa' : 'u',
|
||||
b'\xfb' : 'u',
|
||||
b'\xfc' : 'u',
|
||||
b'\xfd' : 'y',
|
||||
b'\xfe' : 'b',
|
||||
b'\xff' : 'y',
|
||||
}
|
||||
|
||||
# A map used when removing rogue Windows-1252/ISO-8859-1
|
||||
# characters in otherwise UTF-8 documents.
|
||||
#
|
||||
# Note that \x81, \x8d, \x8f, \x90, and \x9d are undefined in
|
||||
# Windows-1252.
|
||||
WINDOWS_1252_TO_UTF8 = {
|
||||
0x80 : b'\xe2\x82\xac', # €
|
||||
0x82 : b'\xe2\x80\x9a', # ‚
|
||||
0x83 : b'\xc6\x92', # ƒ
|
||||
0x84 : b'\xe2\x80\x9e', # „
|
||||
0x85 : b'\xe2\x80\xa6', # …
|
||||
0x86 : b'\xe2\x80\xa0', # †
|
||||
0x87 : b'\xe2\x80\xa1', # ‡
|
||||
0x88 : b'\xcb\x86', # ˆ
|
||||
0x89 : b'\xe2\x80\xb0', # ‰
|
||||
0x8a : b'\xc5\xa0', # Š
|
||||
0x8b : b'\xe2\x80\xb9', # ‹
|
||||
0x8c : b'\xc5\x92', # Œ
|
||||
0x8e : b'\xc5\xbd', # Ž
|
||||
0x91 : b'\xe2\x80\x98', # ‘
|
||||
0x92 : b'\xe2\x80\x99', # ’
|
||||
0x93 : b'\xe2\x80\x9c', # “
|
||||
0x94 : b'\xe2\x80\x9d', # ”
|
||||
0x95 : b'\xe2\x80\xa2', # •
|
||||
0x96 : b'\xe2\x80\x93', # –
|
||||
0x97 : b'\xe2\x80\x94', # —
|
||||
0x98 : b'\xcb\x9c', # ˜
|
||||
0x99 : b'\xe2\x84\xa2', # ™
|
||||
0x9a : b'\xc5\xa1', # š
|
||||
0x9b : b'\xe2\x80\xba', # ›
|
||||
0x9c : b'\xc5\x93', # œ
|
||||
0x9e : b'\xc5\xbe', # ž
|
||||
0x9f : b'\xc5\xb8', # Ÿ
|
||||
0xa0 : b'\xc2\xa0', #
|
||||
0xa1 : b'\xc2\xa1', # ¡
|
||||
0xa2 : b'\xc2\xa2', # ¢
|
||||
0xa3 : b'\xc2\xa3', # £
|
||||
0xa4 : b'\xc2\xa4', # ¤
|
||||
0xa5 : b'\xc2\xa5', # ¥
|
||||
0xa6 : b'\xc2\xa6', # ¦
|
||||
0xa7 : b'\xc2\xa7', # §
|
||||
0xa8 : b'\xc2\xa8', # ¨
|
||||
0xa9 : b'\xc2\xa9', # ©
|
||||
0xaa : b'\xc2\xaa', # ª
|
||||
0xab : b'\xc2\xab', # «
|
||||
0xac : b'\xc2\xac', # ¬
|
||||
0xad : b'\xc2\xad', #
|
||||
0xae : b'\xc2\xae', # ®
|
||||
0xaf : b'\xc2\xaf', # ¯
|
||||
0xb0 : b'\xc2\xb0', # °
|
||||
0xb1 : b'\xc2\xb1', # ±
|
||||
0xb2 : b'\xc2\xb2', # ²
|
||||
0xb3 : b'\xc2\xb3', # ³
|
||||
0xb4 : b'\xc2\xb4', # ´
|
||||
0xb5 : b'\xc2\xb5', # µ
|
||||
0xb6 : b'\xc2\xb6', # ¶
|
||||
0xb7 : b'\xc2\xb7', # ·
|
||||
0xb8 : b'\xc2\xb8', # ¸
|
||||
0xb9 : b'\xc2\xb9', # ¹
|
||||
0xba : b'\xc2\xba', # º
|
||||
0xbb : b'\xc2\xbb', # »
|
||||
0xbc : b'\xc2\xbc', # ¼
|
||||
0xbd : b'\xc2\xbd', # ½
|
||||
0xbe : b'\xc2\xbe', # ¾
|
||||
0xbf : b'\xc2\xbf', # ¿
|
||||
0xc0 : b'\xc3\x80', # À
|
||||
0xc1 : b'\xc3\x81', # Á
|
||||
0xc2 : b'\xc3\x82', # Â
|
||||
0xc3 : b'\xc3\x83', # Ã
|
||||
0xc4 : b'\xc3\x84', # Ä
|
||||
0xc5 : b'\xc3\x85', # Å
|
||||
0xc6 : b'\xc3\x86', # Æ
|
||||
0xc7 : b'\xc3\x87', # Ç
|
||||
0xc8 : b'\xc3\x88', # È
|
||||
0xc9 : b'\xc3\x89', # É
|
||||
0xca : b'\xc3\x8a', # Ê
|
||||
0xcb : b'\xc3\x8b', # Ë
|
||||
0xcc : b'\xc3\x8c', # Ì
|
||||
0xcd : b'\xc3\x8d', # Í
|
||||
0xce : b'\xc3\x8e', # Î
|
||||
0xcf : b'\xc3\x8f', # Ï
|
||||
0xd0 : b'\xc3\x90', # Ð
|
||||
0xd1 : b'\xc3\x91', # Ñ
|
||||
0xd2 : b'\xc3\x92', # Ò
|
||||
0xd3 : b'\xc3\x93', # Ó
|
||||
0xd4 : b'\xc3\x94', # Ô
|
||||
0xd5 : b'\xc3\x95', # Õ
|
||||
0xd6 : b'\xc3\x96', # Ö
|
||||
0xd7 : b'\xc3\x97', # ×
|
||||
0xd8 : b'\xc3\x98', # Ø
|
||||
0xd9 : b'\xc3\x99', # Ù
|
||||
0xda : b'\xc3\x9a', # Ú
|
||||
0xdb : b'\xc3\x9b', # Û
|
||||
0xdc : b'\xc3\x9c', # Ü
|
||||
0xdd : b'\xc3\x9d', # Ý
|
||||
0xde : b'\xc3\x9e', # Þ
|
||||
0xdf : b'\xc3\x9f', # ß
|
||||
0xe0 : b'\xc3\xa0', # à
|
||||
0xe1 : b'\xa1', # á
|
||||
0xe2 : b'\xc3\xa2', # â
|
||||
0xe3 : b'\xc3\xa3', # ã
|
||||
0xe4 : b'\xc3\xa4', # ä
|
||||
0xe5 : b'\xc3\xa5', # å
|
||||
0xe6 : b'\xc3\xa6', # æ
|
||||
0xe7 : b'\xc3\xa7', # ç
|
||||
0xe8 : b'\xc3\xa8', # è
|
||||
0xe9 : b'\xc3\xa9', # é
|
||||
0xea : b'\xc3\xaa', # ê
|
||||
0xeb : b'\xc3\xab', # ë
|
||||
0xec : b'\xc3\xac', # ì
|
||||
0xed : b'\xc3\xad', # í
|
||||
0xee : b'\xc3\xae', # î
|
||||
0xef : b'\xc3\xaf', # ï
|
||||
0xf0 : b'\xc3\xb0', # ð
|
||||
0xf1 : b'\xc3\xb1', # ñ
|
||||
0xf2 : b'\xc3\xb2', # ò
|
||||
0xf3 : b'\xc3\xb3', # ó
|
||||
0xf4 : b'\xc3\xb4', # ô
|
||||
0xf5 : b'\xc3\xb5', # õ
|
||||
0xf6 : b'\xc3\xb6', # ö
|
||||
0xf7 : b'\xc3\xb7', # ÷
|
||||
0xf8 : b'\xc3\xb8', # ø
|
||||
0xf9 : b'\xc3\xb9', # ù
|
||||
0xfa : b'\xc3\xba', # ú
|
||||
0xfb : b'\xc3\xbb', # û
|
||||
0xfc : b'\xc3\xbc', # ü
|
||||
0xfd : b'\xc3\xbd', # ý
|
||||
0xfe : b'\xc3\xbe', # þ
|
||||
}
|
||||
|
||||
MULTIBYTE_MARKERS_AND_SIZES = [
|
||||
(0xc2, 0xdf, 2), # 2-byte characters start with a byte C2-DF
|
||||
(0xe0, 0xef, 3), # 3-byte characters start with E0-EF
|
||||
(0xf0, 0xf4, 4), # 4-byte characters start with F0-F4
|
||||
]
|
||||
|
||||
FIRST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[0][0]
|
||||
LAST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[-1][1]
|
||||
|
||||
@classmethod
|
||||
def detwingle(cls, in_bytes, main_encoding="utf8",
|
||||
embedded_encoding="windows-1252"):
|
||||
"""Fix characters from one encoding embedded in some other encoding.
|
||||
|
||||
Currently the only situation supported is Windows-1252 (or its
|
||||
subset ISO-8859-1), embedded in UTF-8.
|
||||
|
||||
The input must be a bytestring. If you've already converted
|
||||
the document to Unicode, you're too late.
|
||||
|
||||
The output is a bytestring in which `embedded_encoding`
|
||||
characters have been converted to their `main_encoding`
|
||||
equivalents.
|
||||
"""
|
||||
if embedded_encoding.replace('_', '-').lower() not in (
|
||||
'windows-1252', 'windows_1252'):
|
||||
raise NotImplementedError(
|
||||
"Windows-1252 and ISO-8859-1 are the only currently supported "
|
||||
"embedded encodings.")
|
||||
|
||||
if main_encoding.lower() not in ('utf8', 'utf-8'):
|
||||
raise NotImplementedError(
|
||||
"UTF-8 is the only currently supported main encoding.")
|
||||
|
||||
byte_chunks = []
|
||||
|
||||
chunk_start = 0
|
||||
pos = 0
|
||||
while pos < len(in_bytes):
|
||||
byte = in_bytes[pos]
|
||||
if not isinstance(byte, int):
|
||||
# Python 2.x
|
||||
byte = ord(byte)
|
||||
if (byte >= cls.FIRST_MULTIBYTE_MARKER
|
||||
and byte <= cls.LAST_MULTIBYTE_MARKER):
|
||||
# This is the start of a UTF-8 multibyte character. Skip
|
||||
# to the end.
|
||||
for start, end, size in cls.MULTIBYTE_MARKERS_AND_SIZES:
|
||||
if byte >= start and byte <= end:
|
||||
pos += size
|
||||
break
|
||||
elif byte >= 0x80 and byte in cls.WINDOWS_1252_TO_UTF8:
|
||||
# We found a Windows-1252 character!
|
||||
# Save the string up to this point as a chunk.
|
||||
byte_chunks.append(in_bytes[chunk_start:pos])
|
||||
|
||||
# Now translate the Windows-1252 character into UTF-8
|
||||
# and add it as another, one-byte chunk.
|
||||
byte_chunks.append(cls.WINDOWS_1252_TO_UTF8[byte])
|
||||
pos += 1
|
||||
chunk_start = pos
|
||||
else:
|
||||
# Go on to the next character.
|
||||
pos += 1
|
||||
if chunk_start == 0:
|
||||
# The string is unchanged.
|
||||
return in_bytes
|
||||
else:
|
||||
# Store the final chunk.
|
||||
byte_chunks.append(in_bytes[chunk_start:])
|
||||
return b''.join(byte_chunks)
|
||||
|
||||
@@ -0,0 +1,213 @@
|
||||
"""Diagnostic functions, mainly for use when doing tech support."""
|
||||
import cProfile
|
||||
from StringIO import StringIO
|
||||
from HTMLParser import HTMLParser
|
||||
import bs4
|
||||
from bs4 import BeautifulSoup, __version__
|
||||
from bs4.builder import builder_registry
|
||||
|
||||
import os
|
||||
import pstats
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
import sys
|
||||
import cProfile
|
||||
|
||||
def diagnose(data):
|
||||
"""Diagnostic suite for isolating common problems."""
|
||||
print "Diagnostic running on Beautiful Soup %s" % __version__
|
||||
print "Python version %s" % sys.version
|
||||
|
||||
basic_parsers = ["html.parser", "html5lib", "lxml"]
|
||||
for name in basic_parsers:
|
||||
for builder in builder_registry.builders:
|
||||
if name in builder.features:
|
||||
break
|
||||
else:
|
||||
basic_parsers.remove(name)
|
||||
print (
|
||||
"I noticed that %s is not installed. Installing it may help." %
|
||||
name)
|
||||
|
||||
if 'lxml' in basic_parsers:
|
||||
basic_parsers.append(["lxml", "xml"])
|
||||
try:
|
||||
from lxml import etree
|
||||
print "Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION))
|
||||
except ImportError, e:
|
||||
print (
|
||||
"lxml is not installed or couldn't be imported.")
|
||||
|
||||
|
||||
if 'html5lib' in basic_parsers:
|
||||
try:
|
||||
import html5lib
|
||||
print "Found html5lib version %s" % html5lib.__version__
|
||||
except ImportError, e:
|
||||
print (
|
||||
"html5lib is not installed or couldn't be imported.")
|
||||
|
||||
if hasattr(data, 'read'):
|
||||
data = data.read()
|
||||
elif os.path.exists(data):
|
||||
print '"%s" looks like a filename. Reading data from the file.' % data
|
||||
data = open(data).read()
|
||||
elif data.startswith("http:") or data.startswith("https:"):
|
||||
print '"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data
|
||||
print "You need to use some other library to get the document behind the URL, and feed that document to Beautiful Soup."
|
||||
return
|
||||
print
|
||||
|
||||
for parser in basic_parsers:
|
||||
print "Trying to parse your markup with %s" % parser
|
||||
success = False
|
||||
try:
|
||||
soup = BeautifulSoup(data, parser)
|
||||
success = True
|
||||
except Exception, e:
|
||||
print "%s could not parse the markup." % parser
|
||||
traceback.print_exc()
|
||||
if success:
|
||||
print "Here's what %s did with the markup:" % parser
|
||||
print soup.prettify()
|
||||
|
||||
print "-" * 80
|
||||
|
||||
def lxml_trace(data, html=True, **kwargs):
|
||||
"""Print out the lxml events that occur during parsing.
|
||||
|
||||
This lets you see how lxml parses a document when no Beautiful
|
||||
Soup code is running.
|
||||
"""
|
||||
from lxml import etree
|
||||
for event, element in etree.iterparse(StringIO(data), html=html, **kwargs):
|
||||
print("%s, %4s, %s" % (event, element.tag, element.text))
|
||||
|
||||
class AnnouncingParser(HTMLParser):
|
||||
"""Announces HTMLParser parse events, without doing anything else."""
|
||||
|
||||
def _p(self, s):
|
||||
print(s)
|
||||
|
||||
def handle_starttag(self, name, attrs):
|
||||
self._p("%s START" % name)
|
||||
|
||||
def handle_endtag(self, name):
|
||||
self._p("%s END" % name)
|
||||
|
||||
def handle_data(self, data):
|
||||
self._p("%s DATA" % data)
|
||||
|
||||
def handle_charref(self, name):
|
||||
self._p("%s CHARREF" % name)
|
||||
|
||||
def handle_entityref(self, name):
|
||||
self._p("%s ENTITYREF" % name)
|
||||
|
||||
def handle_comment(self, data):
|
||||
self._p("%s COMMENT" % data)
|
||||
|
||||
def handle_decl(self, data):
|
||||
self._p("%s DECL" % data)
|
||||
|
||||
def unknown_decl(self, data):
|
||||
self._p("%s UNKNOWN-DECL" % data)
|
||||
|
||||
def handle_pi(self, data):
|
||||
self._p("%s PI" % data)
|
||||
|
||||
def htmlparser_trace(data):
|
||||
"""Print out the HTMLParser events that occur during parsing.
|
||||
|
||||
This lets you see how HTMLParser parses a document when no
|
||||
Beautiful Soup code is running.
|
||||
"""
|
||||
parser = AnnouncingParser()
|
||||
parser.feed(data)
|
||||
|
||||
_vowels = "aeiou"
|
||||
_consonants = "bcdfghjklmnpqrstvwxyz"
|
||||
|
||||
def rword(length=5):
|
||||
"Generate a random word-like string."
|
||||
s = ''
|
||||
for i in range(length):
|
||||
if i % 2 == 0:
|
||||
t = _consonants
|
||||
else:
|
||||
t = _vowels
|
||||
s += random.choice(t)
|
||||
return s
|
||||
|
||||
def rsentence(length=4):
|
||||
"Generate a random sentence-like string."
|
||||
return " ".join(rword(random.randint(4,9)) for i in range(length))
|
||||
|
||||
def rdoc(num_elements=1000):
|
||||
"""Randomly generate an invalid HTML document."""
|
||||
tag_names = ['p', 'div', 'span', 'i', 'b', 'script', 'table']
|
||||
elements = []
|
||||
for i in range(num_elements):
|
||||
choice = random.randint(0,3)
|
||||
if choice == 0:
|
||||
# New tag.
|
||||
tag_name = random.choice(tag_names)
|
||||
elements.append("<%s>" % tag_name)
|
||||
elif choice == 1:
|
||||
elements.append(rsentence(random.randint(1,4)))
|
||||
elif choice == 2:
|
||||
# Close a tag.
|
||||
tag_name = random.choice(tag_names)
|
||||
elements.append("</%s>" % tag_name)
|
||||
return "<html>" + "\n".join(elements) + "</html>"
|
||||
|
||||
def benchmark_parsers(num_elements=100000):
|
||||
"""Very basic head-to-head performance benchmark."""
|
||||
print "Comparative parser benchmark on Beautiful Soup %s" % __version__
|
||||
data = rdoc(num_elements)
|
||||
print "Generated a large invalid HTML document (%d bytes)." % len(data)
|
||||
|
||||
for parser in ["lxml", ["lxml", "html"], "html5lib", "html.parser"]:
|
||||
success = False
|
||||
try:
|
||||
a = time.time()
|
||||
soup = BeautifulSoup(data, parser)
|
||||
b = time.time()
|
||||
success = True
|
||||
except Exception, e:
|
||||
print "%s could not parse the markup." % parser
|
||||
traceback.print_exc()
|
||||
if success:
|
||||
print "BS4+%s parsed the markup in %.2fs." % (parser, b-a)
|
||||
|
||||
from lxml import etree
|
||||
a = time.time()
|
||||
etree.HTML(data)
|
||||
b = time.time()
|
||||
print "Raw lxml parsed the markup in %.2fs." % (b-a)
|
||||
|
||||
import html5lib
|
||||
parser = html5lib.HTMLParser()
|
||||
a = time.time()
|
||||
parser.parse(data)
|
||||
b = time.time()
|
||||
print "Raw html5lib parsed the markup in %.2fs." % (b-a)
|
||||
|
||||
def profile(num_elements=100000, parser="lxml"):
|
||||
|
||||
filehandle = tempfile.NamedTemporaryFile()
|
||||
filename = filehandle.name
|
||||
|
||||
data = rdoc(num_elements)
|
||||
vars = dict(bs4=bs4, data=data, parser=parser)
|
||||
cProfile.runctx('bs4.BeautifulSoup(data, parser)' , vars, vars, filename)
|
||||
|
||||
stats = pstats.Stats(filename)
|
||||
# stats.strip_dirs()
|
||||
stats.sort_stats("cumulative")
|
||||
stats.print_stats('_html5lib|bs4', 50)
|
||||
|
||||
if __name__ == '__main__':
|
||||
diagnose(sys.stdin.read())
|
||||
+1713
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,680 @@
|
||||
"""Helper classes for tests."""
|
||||
|
||||
import pickle
|
||||
import copy
|
||||
import functools
|
||||
import unittest
|
||||
from unittest import TestCase
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4.element import (
|
||||
CharsetMetaAttributeValue,
|
||||
Comment,
|
||||
ContentMetaAttributeValue,
|
||||
Doctype,
|
||||
SoupStrainer,
|
||||
)
|
||||
|
||||
from bs4.builder import HTMLParserTreeBuilder
|
||||
default_builder = HTMLParserTreeBuilder
|
||||
|
||||
|
||||
class SoupTest(unittest.TestCase):
|
||||
|
||||
@property
|
||||
def default_builder(self):
|
||||
return default_builder()
|
||||
|
||||
def soup(self, markup, **kwargs):
|
||||
"""Build a Beautiful Soup object from markup."""
|
||||
builder = kwargs.pop('builder', self.default_builder)
|
||||
return BeautifulSoup(markup, builder=builder, **kwargs)
|
||||
|
||||
def document_for(self, markup):
|
||||
"""Turn an HTML fragment into a document.
|
||||
|
||||
The details depend on the builder.
|
||||
"""
|
||||
return self.default_builder.test_fragment_to_document(markup)
|
||||
|
||||
def assertSoupEquals(self, to_parse, compare_parsed_to=None):
|
||||
builder = self.default_builder
|
||||
obj = BeautifulSoup(to_parse, builder=builder)
|
||||
if compare_parsed_to is None:
|
||||
compare_parsed_to = to_parse
|
||||
|
||||
self.assertEqual(obj.decode(), self.document_for(compare_parsed_to))
|
||||
|
||||
def assertConnectedness(self, element):
|
||||
"""Ensure that next_element and previous_element are properly
|
||||
set for all descendants of the given element.
|
||||
"""
|
||||
earlier = None
|
||||
for e in element.descendants:
|
||||
if earlier:
|
||||
self.assertEqual(e, earlier.next_element)
|
||||
self.assertEqual(earlier, e.previous_element)
|
||||
earlier = e
|
||||
|
||||
class HTMLTreeBuilderSmokeTest(object):
|
||||
|
||||
"""A basic test of a treebuilder's competence.
|
||||
|
||||
Any HTML treebuilder, present or future, should be able to pass
|
||||
these tests. With invalid markup, there's room for interpretation,
|
||||
and different parsers can handle it differently. But with the
|
||||
markup in these tests, there's not much room for interpretation.
|
||||
"""
|
||||
|
||||
def test_pickle_and_unpickle_identity(self):
|
||||
# Pickling a tree, then unpickling it, yields a tree identical
|
||||
# to the original.
|
||||
tree = self.soup("<a><b>foo</a>")
|
||||
dumped = pickle.dumps(tree, 2)
|
||||
loaded = pickle.loads(dumped)
|
||||
self.assertEqual(loaded.__class__, BeautifulSoup)
|
||||
self.assertEqual(loaded.decode(), tree.decode())
|
||||
|
||||
def assertDoctypeHandled(self, doctype_fragment):
|
||||
"""Assert that a given doctype string is handled correctly."""
|
||||
doctype_str, soup = self._document_with_doctype(doctype_fragment)
|
||||
|
||||
# Make sure a Doctype object was created.
|
||||
doctype = soup.contents[0]
|
||||
self.assertEqual(doctype.__class__, Doctype)
|
||||
self.assertEqual(doctype, doctype_fragment)
|
||||
self.assertEqual(str(soup)[:len(doctype_str)], doctype_str)
|
||||
|
||||
# Make sure that the doctype was correctly associated with the
|
||||
# parse tree and that the rest of the document parsed.
|
||||
self.assertEqual(soup.p.contents[0], 'foo')
|
||||
|
||||
def _document_with_doctype(self, doctype_fragment):
|
||||
"""Generate and parse a document with the given doctype."""
|
||||
doctype = '<!DOCTYPE %s>' % doctype_fragment
|
||||
markup = doctype + '\n<p>foo</p>'
|
||||
soup = self.soup(markup)
|
||||
return doctype, soup
|
||||
|
||||
def test_normal_doctypes(self):
|
||||
"""Make sure normal, everyday HTML doctypes are handled correctly."""
|
||||
self.assertDoctypeHandled("html")
|
||||
self.assertDoctypeHandled(
|
||||
'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"')
|
||||
|
||||
def test_empty_doctype(self):
|
||||
soup = self.soup("<!DOCTYPE>")
|
||||
doctype = soup.contents[0]
|
||||
self.assertEqual("", doctype.strip())
|
||||
|
||||
def test_public_doctype_with_url(self):
|
||||
doctype = 'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"'
|
||||
self.assertDoctypeHandled(doctype)
|
||||
|
||||
def test_system_doctype(self):
|
||||
self.assertDoctypeHandled('foo SYSTEM "http://www.example.com/"')
|
||||
|
||||
def test_namespaced_system_doctype(self):
|
||||
# We can handle a namespaced doctype with a system ID.
|
||||
self.assertDoctypeHandled('xsl:stylesheet SYSTEM "htmlent.dtd"')
|
||||
|
||||
def test_namespaced_public_doctype(self):
|
||||
# Test a namespaced doctype with a public id.
|
||||
self.assertDoctypeHandled('xsl:stylesheet PUBLIC "htmlent.dtd"')
|
||||
|
||||
def test_real_xhtml_document(self):
|
||||
"""A real XHTML document should come out more or less the same as it went in."""
|
||||
markup = b"""<?xml version="1.0" encoding="utf-8"?>
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head><title>Hello.</title></head>
|
||||
<body>Goodbye.</body>
|
||||
</html>"""
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(
|
||||
soup.encode("utf-8").replace(b"\n", b""),
|
||||
markup.replace(b"\n", b""))
|
||||
|
||||
def test_processing_instruction(self):
|
||||
markup = b"""<?PITarget PIContent?>"""
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(markup, soup.encode("utf8"))
|
||||
|
||||
def test_deepcopy(self):
|
||||
"""Make sure you can copy the tree builder.
|
||||
|
||||
This is important because the builder is part of a
|
||||
BeautifulSoup object, and we want to be able to copy that.
|
||||
"""
|
||||
copy.deepcopy(self.default_builder)
|
||||
|
||||
def test_p_tag_is_never_empty_element(self):
|
||||
"""A <p> tag is never designated as an empty-element tag.
|
||||
|
||||
Even if the markup shows it as an empty-element tag, it
|
||||
shouldn't be presented that way.
|
||||
"""
|
||||
soup = self.soup("<p/>")
|
||||
self.assertFalse(soup.p.is_empty_element)
|
||||
self.assertEqual(str(soup.p), "<p></p>")
|
||||
|
||||
def test_unclosed_tags_get_closed(self):
|
||||
"""A tag that's not closed by the end of the document should be closed.
|
||||
|
||||
This applies to all tags except empty-element tags.
|
||||
"""
|
||||
self.assertSoupEquals("<p>", "<p></p>")
|
||||
self.assertSoupEquals("<b>", "<b></b>")
|
||||
|
||||
self.assertSoupEquals("<br>", "<br/>")
|
||||
|
||||
def test_br_is_always_empty_element_tag(self):
|
||||
"""A <br> tag is designated as an empty-element tag.
|
||||
|
||||
Some parsers treat <br></br> as one <br/> tag, some parsers as
|
||||
two tags, but it should always be an empty-element tag.
|
||||
"""
|
||||
soup = self.soup("<br></br>")
|
||||
self.assertTrue(soup.br.is_empty_element)
|
||||
self.assertEqual(str(soup.br), "<br/>")
|
||||
|
||||
def test_nested_formatting_elements(self):
|
||||
self.assertSoupEquals("<em><em></em></em>")
|
||||
|
||||
def test_double_head(self):
|
||||
html = '''<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Ordinary HEAD element test</title>
|
||||
</head>
|
||||
<script type="text/javascript">
|
||||
alert("Help!");
|
||||
</script>
|
||||
<body>
|
||||
Hello, world!
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
soup = self.soup(html)
|
||||
self.assertEqual("text/javascript", soup.find('script')['type'])
|
||||
|
||||
def test_comment(self):
|
||||
# Comments are represented as Comment objects.
|
||||
markup = "<p>foo<!--foobar-->baz</p>"
|
||||
self.assertSoupEquals(markup)
|
||||
|
||||
soup = self.soup(markup)
|
||||
comment = soup.find(text="foobar")
|
||||
self.assertEqual(comment.__class__, Comment)
|
||||
|
||||
# The comment is properly integrated into the tree.
|
||||
foo = soup.find(text="foo")
|
||||
self.assertEqual(comment, foo.next_element)
|
||||
baz = soup.find(text="baz")
|
||||
self.assertEqual(comment, baz.previous_element)
|
||||
|
||||
def test_preserved_whitespace_in_pre_and_textarea(self):
|
||||
"""Whitespace must be preserved in <pre> and <textarea> tags."""
|
||||
self.assertSoupEquals("<pre> </pre>")
|
||||
self.assertSoupEquals("<textarea> woo </textarea>")
|
||||
|
||||
def test_nested_inline_elements(self):
|
||||
"""Inline elements can be nested indefinitely."""
|
||||
b_tag = "<b>Inside a B tag</b>"
|
||||
self.assertSoupEquals(b_tag)
|
||||
|
||||
nested_b_tag = "<p>A <i>nested <b>tag</b></i></p>"
|
||||
self.assertSoupEquals(nested_b_tag)
|
||||
|
||||
double_nested_b_tag = "<p>A <a>doubly <i>nested <b>tag</b></i></a></p>"
|
||||
self.assertSoupEquals(nested_b_tag)
|
||||
|
||||
def test_nested_block_level_elements(self):
|
||||
"""Block elements can be nested."""
|
||||
soup = self.soup('<blockquote><p><b>Foo</b></p></blockquote>')
|
||||
blockquote = soup.blockquote
|
||||
self.assertEqual(blockquote.p.b.string, 'Foo')
|
||||
self.assertEqual(blockquote.b.string, 'Foo')
|
||||
|
||||
def test_correctly_nested_tables(self):
|
||||
"""One table can go inside another one."""
|
||||
markup = ('<table id="1">'
|
||||
'<tr>'
|
||||
"<td>Here's another table:"
|
||||
'<table id="2">'
|
||||
'<tr><td>foo</td></tr>'
|
||||
'</table></td>')
|
||||
|
||||
self.assertSoupEquals(
|
||||
markup,
|
||||
'<table id="1"><tr><td>Here\'s another table:'
|
||||
'<table id="2"><tr><td>foo</td></tr></table>'
|
||||
'</td></tr></table>')
|
||||
|
||||
self.assertSoupEquals(
|
||||
"<table><thead><tr><td>Foo</td></tr></thead>"
|
||||
"<tbody><tr><td>Bar</td></tr></tbody>"
|
||||
"<tfoot><tr><td>Baz</td></tr></tfoot></table>")
|
||||
|
||||
def test_deeply_nested_multivalued_attribute(self):
|
||||
# html5lib can set the attributes of the same tag many times
|
||||
# as it rearranges the tree. This has caused problems with
|
||||
# multivalued attributes.
|
||||
markup = '<table><div><div class="css"></div></div></table>'
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(["css"], soup.div.div['class'])
|
||||
|
||||
def test_multivalued_attribute_on_html(self):
|
||||
# html5lib uses a different API to set the attributes ot the
|
||||
# <html> tag. This has caused problems with multivalued
|
||||
# attributes.
|
||||
markup = '<html class="a b"></html>'
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(["a", "b"], soup.html['class'])
|
||||
|
||||
def test_angle_brackets_in_attribute_values_are_escaped(self):
|
||||
self.assertSoupEquals('<a b="<a>"></a>', '<a b="<a>"></a>')
|
||||
|
||||
def test_entities_in_attributes_converted_to_unicode(self):
|
||||
expect = u'<p id="pi\N{LATIN SMALL LETTER N WITH TILDE}ata"></p>'
|
||||
self.assertSoupEquals('<p id="piñata"></p>', expect)
|
||||
self.assertSoupEquals('<p id="piñata"></p>', expect)
|
||||
self.assertSoupEquals('<p id="piñata"></p>', expect)
|
||||
self.assertSoupEquals('<p id="piñata"></p>', expect)
|
||||
|
||||
def test_entities_in_text_converted_to_unicode(self):
|
||||
expect = u'<p>pi\N{LATIN SMALL LETTER N WITH TILDE}ata</p>'
|
||||
self.assertSoupEquals("<p>piñata</p>", expect)
|
||||
self.assertSoupEquals("<p>piñata</p>", expect)
|
||||
self.assertSoupEquals("<p>piñata</p>", expect)
|
||||
self.assertSoupEquals("<p>piñata</p>", expect)
|
||||
|
||||
def test_quot_entity_converted_to_quotation_mark(self):
|
||||
self.assertSoupEquals("<p>I said "good day!"</p>",
|
||||
'<p>I said "good day!"</p>')
|
||||
|
||||
def test_out_of_range_entity(self):
|
||||
expect = u"\N{REPLACEMENT CHARACTER}"
|
||||
self.assertSoupEquals("�", expect)
|
||||
self.assertSoupEquals("�", expect)
|
||||
self.assertSoupEquals("�", expect)
|
||||
|
||||
def test_multipart_strings(self):
|
||||
"Mostly to prevent a recurrence of a bug in the html5lib treebuilder."
|
||||
soup = self.soup("<html><h2>\nfoo</h2><p></p></html>")
|
||||
self.assertEqual("p", soup.h2.string.next_element.name)
|
||||
self.assertEqual("p", soup.p.name)
|
||||
self.assertConnectedness(soup)
|
||||
|
||||
def test_head_tag_between_head_and_body(self):
|
||||
"Prevent recurrence of a bug in the html5lib treebuilder."
|
||||
content = """<html><head></head>
|
||||
<link></link>
|
||||
<body>foo</body>
|
||||
</html>
|
||||
"""
|
||||
soup = self.soup(content)
|
||||
self.assertNotEqual(None, soup.html.body)
|
||||
self.assertConnectedness(soup)
|
||||
|
||||
def test_multiple_copies_of_a_tag(self):
|
||||
"Prevent recurrence of a bug in the html5lib treebuilder."
|
||||
content = """<!DOCTYPE html>
|
||||
<html>
|
||||
<body>
|
||||
<article id="a" >
|
||||
<div><a href="1"></div>
|
||||
<footer>
|
||||
<a href="2"></a>
|
||||
</footer>
|
||||
</article>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
soup = self.soup(content)
|
||||
self.assertConnectedness(soup.article)
|
||||
|
||||
def test_basic_namespaces(self):
|
||||
"""Parsers don't need to *understand* namespaces, but at the
|
||||
very least they should not choke on namespaces or lose
|
||||
data."""
|
||||
|
||||
markup = b'<html xmlns="http://www.w3.org/1999/xhtml" xmlns:mathml="http://www.w3.org/1998/Math/MathML" xmlns:svg="http://www.w3.org/2000/svg"><head></head><body><mathml:msqrt>4</mathml:msqrt><b svg:fill="red"></b></body></html>'
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(markup, soup.encode())
|
||||
html = soup.html
|
||||
self.assertEqual('http://www.w3.org/1999/xhtml', soup.html['xmlns'])
|
||||
self.assertEqual(
|
||||
'http://www.w3.org/1998/Math/MathML', soup.html['xmlns:mathml'])
|
||||
self.assertEqual(
|
||||
'http://www.w3.org/2000/svg', soup.html['xmlns:svg'])
|
||||
|
||||
def test_multivalued_attribute_value_becomes_list(self):
|
||||
markup = b'<a class="foo bar">'
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(['foo', 'bar'], soup.a['class'])
|
||||
|
||||
#
|
||||
# Generally speaking, tests below this point are more tests of
|
||||
# Beautiful Soup than tests of the tree builders. But parsers are
|
||||
# weird, so we run these tests separately for every tree builder
|
||||
# to detect any differences between them.
|
||||
#
|
||||
|
||||
def test_can_parse_unicode_document(self):
|
||||
# A seemingly innocuous document... but it's in Unicode! And
|
||||
# it contains characters that can't be represented in the
|
||||
# encoding found in the declaration! The horror!
|
||||
markup = u'<html><head><meta encoding="euc-jp"></head><body>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</body>'
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(u'Sacr\xe9 bleu!', soup.body.string)
|
||||
|
||||
def test_soupstrainer(self):
|
||||
"""Parsers should be able to work with SoupStrainers."""
|
||||
strainer = SoupStrainer("b")
|
||||
soup = self.soup("A <b>bold</b> <meta/> <i>statement</i>",
|
||||
parse_only=strainer)
|
||||
self.assertEqual(soup.decode(), "<b>bold</b>")
|
||||
|
||||
def test_single_quote_attribute_values_become_double_quotes(self):
|
||||
self.assertSoupEquals("<foo attr='bar'></foo>",
|
||||
'<foo attr="bar"></foo>')
|
||||
|
||||
def test_attribute_values_with_nested_quotes_are_left_alone(self):
|
||||
text = """<foo attr='bar "brawls" happen'>a</foo>"""
|
||||
self.assertSoupEquals(text)
|
||||
|
||||
def test_attribute_values_with_double_nested_quotes_get_quoted(self):
|
||||
text = """<foo attr='bar "brawls" happen'>a</foo>"""
|
||||
soup = self.soup(text)
|
||||
soup.foo['attr'] = 'Brawls happen at "Bob\'s Bar"'
|
||||
self.assertSoupEquals(
|
||||
soup.foo.decode(),
|
||||
"""<foo attr="Brawls happen at "Bob\'s Bar"">a</foo>""")
|
||||
|
||||
def test_ampersand_in_attribute_value_gets_escaped(self):
|
||||
self.assertSoupEquals('<this is="really messed up & stuff"></this>',
|
||||
'<this is="really messed up & stuff"></this>')
|
||||
|
||||
self.assertSoupEquals(
|
||||
'<a href="http://example.org?a=1&b=2;3">foo</a>',
|
||||
'<a href="http://example.org?a=1&b=2;3">foo</a>')
|
||||
|
||||
def test_escaped_ampersand_in_attribute_value_is_left_alone(self):
|
||||
self.assertSoupEquals('<a href="http://example.org?a=1&b=2;3"></a>')
|
||||
|
||||
def test_entities_in_strings_converted_during_parsing(self):
|
||||
# Both XML and HTML entities are converted to Unicode characters
|
||||
# during parsing.
|
||||
text = "<p><<sacré bleu!>></p>"
|
||||
expected = u"<p><<sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></p>"
|
||||
self.assertSoupEquals(text, expected)
|
||||
|
||||
def test_smart_quotes_converted_on_the_way_in(self):
|
||||
# Microsoft smart quotes are converted to Unicode characters during
|
||||
# parsing.
|
||||
quote = b"<p>\x91Foo\x92</p>"
|
||||
soup = self.soup(quote)
|
||||
self.assertEqual(
|
||||
soup.p.string,
|
||||
u"\N{LEFT SINGLE QUOTATION MARK}Foo\N{RIGHT SINGLE QUOTATION MARK}")
|
||||
|
||||
def test_non_breaking_spaces_converted_on_the_way_in(self):
|
||||
soup = self.soup("<a> </a>")
|
||||
self.assertEqual(soup.a.string, u"\N{NO-BREAK SPACE}" * 2)
|
||||
|
||||
def test_entities_converted_on_the_way_out(self):
|
||||
text = "<p><<sacré bleu!>></p>"
|
||||
expected = u"<p><<sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></p>".encode("utf-8")
|
||||
soup = self.soup(text)
|
||||
self.assertEqual(soup.p.encode("utf-8"), expected)
|
||||
|
||||
def test_real_iso_latin_document(self):
|
||||
# Smoke test of interrelated functionality, using an
|
||||
# easy-to-understand document.
|
||||
|
||||
# Here it is in Unicode. Note that it claims to be in ISO-Latin-1.
|
||||
unicode_html = u'<html><head><meta content="text/html; charset=ISO-Latin-1" http-equiv="Content-type"/></head><body><p>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</p></body></html>'
|
||||
|
||||
# That's because we're going to encode it into ISO-Latin-1, and use
|
||||
# that to test.
|
||||
iso_latin_html = unicode_html.encode("iso-8859-1")
|
||||
|
||||
# Parse the ISO-Latin-1 HTML.
|
||||
soup = self.soup(iso_latin_html)
|
||||
# Encode it to UTF-8.
|
||||
result = soup.encode("utf-8")
|
||||
|
||||
# What do we expect the result to look like? Well, it would
|
||||
# look like unicode_html, except that the META tag would say
|
||||
# UTF-8 instead of ISO-Latin-1.
|
||||
expected = unicode_html.replace("ISO-Latin-1", "utf-8")
|
||||
|
||||
# And, of course, it would be in UTF-8, not Unicode.
|
||||
expected = expected.encode("utf-8")
|
||||
|
||||
# Ta-da!
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_real_shift_jis_document(self):
|
||||
# Smoke test to make sure the parser can handle a document in
|
||||
# Shift-JIS encoding, without choking.
|
||||
shift_jis_html = (
|
||||
b'<html><head></head><body><pre>'
|
||||
b'\x82\xb1\x82\xea\x82\xcdShift-JIS\x82\xc5\x83R\x81[\x83f'
|
||||
b'\x83B\x83\x93\x83O\x82\xb3\x82\xea\x82\xbd\x93\xfa\x96{\x8c'
|
||||
b'\xea\x82\xcc\x83t\x83@\x83C\x83\x8b\x82\xc5\x82\xb7\x81B'
|
||||
b'</pre></body></html>')
|
||||
unicode_html = shift_jis_html.decode("shift-jis")
|
||||
soup = self.soup(unicode_html)
|
||||
|
||||
# Make sure the parse tree is correctly encoded to various
|
||||
# encodings.
|
||||
self.assertEqual(soup.encode("utf-8"), unicode_html.encode("utf-8"))
|
||||
self.assertEqual(soup.encode("euc_jp"), unicode_html.encode("euc_jp"))
|
||||
|
||||
def test_real_hebrew_document(self):
|
||||
# A real-world test to make sure we can convert ISO-8859-9 (a
|
||||
# Hebrew encoding) to UTF-8.
|
||||
hebrew_document = b'<html><head><title>Hebrew (ISO 8859-8) in Visual Directionality</title></head><body><h1>Hebrew (ISO 8859-8) in Visual Directionality</h1>\xed\xe5\xec\xf9</body></html>'
|
||||
soup = self.soup(
|
||||
hebrew_document, from_encoding="iso8859-8")
|
||||
self.assertEqual(soup.original_encoding, 'iso8859-8')
|
||||
self.assertEqual(
|
||||
soup.encode('utf-8'),
|
||||
hebrew_document.decode("iso8859-8").encode("utf-8"))
|
||||
|
||||
def test_meta_tag_reflects_current_encoding(self):
|
||||
# Here's the <meta> tag saying that a document is
|
||||
# encoded in Shift-JIS.
|
||||
meta_tag = ('<meta content="text/html; charset=x-sjis" '
|
||||
'http-equiv="Content-type"/>')
|
||||
|
||||
# Here's a document incorporating that meta tag.
|
||||
shift_jis_html = (
|
||||
'<html><head>\n%s\n'
|
||||
'<meta http-equiv="Content-language" content="ja"/>'
|
||||
'</head><body>Shift-JIS markup goes here.') % meta_tag
|
||||
soup = self.soup(shift_jis_html)
|
||||
|
||||
# Parse the document, and the charset is seemingly unaffected.
|
||||
parsed_meta = soup.find('meta', {'http-equiv': 'Content-type'})
|
||||
content = parsed_meta['content']
|
||||
self.assertEqual('text/html; charset=x-sjis', content)
|
||||
|
||||
# But that value is actually a ContentMetaAttributeValue object.
|
||||
self.assertTrue(isinstance(content, ContentMetaAttributeValue))
|
||||
|
||||
# And it will take on a value that reflects its current
|
||||
# encoding.
|
||||
self.assertEqual('text/html; charset=utf8', content.encode("utf8"))
|
||||
|
||||
# For the rest of the story, see TestSubstitutions in
|
||||
# test_tree.py.
|
||||
|
||||
def test_html5_style_meta_tag_reflects_current_encoding(self):
|
||||
# Here's the <meta> tag saying that a document is
|
||||
# encoded in Shift-JIS.
|
||||
meta_tag = ('<meta id="encoding" charset="x-sjis" />')
|
||||
|
||||
# Here's a document incorporating that meta tag.
|
||||
shift_jis_html = (
|
||||
'<html><head>\n%s\n'
|
||||
'<meta http-equiv="Content-language" content="ja"/>'
|
||||
'</head><body>Shift-JIS markup goes here.') % meta_tag
|
||||
soup = self.soup(shift_jis_html)
|
||||
|
||||
# Parse the document, and the charset is seemingly unaffected.
|
||||
parsed_meta = soup.find('meta', id="encoding")
|
||||
charset = parsed_meta['charset']
|
||||
self.assertEqual('x-sjis', charset)
|
||||
|
||||
# But that value is actually a CharsetMetaAttributeValue object.
|
||||
self.assertTrue(isinstance(charset, CharsetMetaAttributeValue))
|
||||
|
||||
# And it will take on a value that reflects its current
|
||||
# encoding.
|
||||
self.assertEqual('utf8', charset.encode("utf8"))
|
||||
|
||||
def test_tag_with_no_attributes_can_have_attributes_added(self):
|
||||
data = self.soup("<a>text</a>")
|
||||
data.a['foo'] = 'bar'
|
||||
self.assertEqual('<a foo="bar">text</a>', data.a.decode())
|
||||
|
||||
class XMLTreeBuilderSmokeTest(object):
|
||||
|
||||
def test_pickle_and_unpickle_identity(self):
|
||||
# Pickling a tree, then unpickling it, yields a tree identical
|
||||
# to the original.
|
||||
tree = self.soup("<a><b>foo</a>")
|
||||
dumped = pickle.dumps(tree, 2)
|
||||
loaded = pickle.loads(dumped)
|
||||
self.assertEqual(loaded.__class__, BeautifulSoup)
|
||||
self.assertEqual(loaded.decode(), tree.decode())
|
||||
|
||||
def test_docstring_generated(self):
|
||||
soup = self.soup("<root/>")
|
||||
self.assertEqual(
|
||||
soup.encode(), b'<?xml version="1.0" encoding="utf-8"?>\n<root/>')
|
||||
|
||||
def test_real_xhtml_document(self):
|
||||
"""A real XHTML document should come out *exactly* the same as it went in."""
|
||||
markup = b"""<?xml version="1.0" encoding="utf-8"?>
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head><title>Hello.</title></head>
|
||||
<body>Goodbye.</body>
|
||||
</html>"""
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(
|
||||
soup.encode("utf-8"), markup)
|
||||
|
||||
def test_formatter_processes_script_tag_for_xml_documents(self):
|
||||
doc = """
|
||||
<script type="text/javascript">
|
||||
</script>
|
||||
"""
|
||||
soup = BeautifulSoup(doc, "lxml-xml")
|
||||
# lxml would have stripped this while parsing, but we can add
|
||||
# it later.
|
||||
soup.script.string = 'console.log("< < hey > > ");'
|
||||
encoded = soup.encode()
|
||||
self.assertTrue(b"< < hey > >" in encoded)
|
||||
|
||||
def test_can_parse_unicode_document(self):
|
||||
markup = u'<?xml version="1.0" encoding="euc-jp"><root>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</root>'
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(u'Sacr\xe9 bleu!', soup.root.string)
|
||||
|
||||
def test_popping_namespaced_tag(self):
|
||||
markup = '<rss xmlns:dc="foo"><dc:creator>b</dc:creator><dc:date>2012-07-02T20:33:42Z</dc:date><dc:rights>c</dc:rights><image>d</image></rss>'
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(
|
||||
unicode(soup.rss), markup)
|
||||
|
||||
def test_docstring_includes_correct_encoding(self):
|
||||
soup = self.soup("<root/>")
|
||||
self.assertEqual(
|
||||
soup.encode("latin1"),
|
||||
b'<?xml version="1.0" encoding="latin1"?>\n<root/>')
|
||||
|
||||
def test_large_xml_document(self):
|
||||
"""A large XML document should come out the same as it went in."""
|
||||
markup = (b'<?xml version="1.0" encoding="utf-8"?>\n<root>'
|
||||
+ b'0' * (2**12)
|
||||
+ b'</root>')
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(soup.encode("utf-8"), markup)
|
||||
|
||||
|
||||
def test_tags_are_empty_element_if_and_only_if_they_are_empty(self):
|
||||
self.assertSoupEquals("<p>", "<p/>")
|
||||
self.assertSoupEquals("<p>foo</p>")
|
||||
|
||||
def test_namespaces_are_preserved(self):
|
||||
markup = '<root xmlns:a="http://example.com/" xmlns:b="http://example.net/"><a:foo>This tag is in the a namespace</a:foo><b:foo>This tag is in the b namespace</b:foo></root>'
|
||||
soup = self.soup(markup)
|
||||
root = soup.root
|
||||
self.assertEqual("http://example.com/", root['xmlns:a'])
|
||||
self.assertEqual("http://example.net/", root['xmlns:b'])
|
||||
|
||||
def test_closing_namespaced_tag(self):
|
||||
markup = '<p xmlns:dc="http://purl.org/dc/elements/1.1/"><dc:date>20010504</dc:date></p>'
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(unicode(soup.p), markup)
|
||||
|
||||
def test_namespaced_attributes(self):
|
||||
markup = '<foo xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"><bar xsi:schemaLocation="http://www.example.com"/></foo>'
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(unicode(soup.foo), markup)
|
||||
|
||||
def test_namespaced_attributes_xml_namespace(self):
|
||||
markup = '<foo xml:lang="fr">bar</foo>'
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(unicode(soup.foo), markup)
|
||||
|
||||
class HTML5TreeBuilderSmokeTest(HTMLTreeBuilderSmokeTest):
|
||||
"""Smoke test for a tree builder that supports HTML5."""
|
||||
|
||||
def test_real_xhtml_document(self):
|
||||
# Since XHTML is not HTML5, HTML5 parsers are not tested to handle
|
||||
# XHTML documents in any particular way.
|
||||
pass
|
||||
|
||||
def test_html_tags_have_namespace(self):
|
||||
markup = "<a>"
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual("http://www.w3.org/1999/xhtml", soup.a.namespace)
|
||||
|
||||
def test_svg_tags_have_namespace(self):
|
||||
markup = '<svg><circle/></svg>'
|
||||
soup = self.soup(markup)
|
||||
namespace = "http://www.w3.org/2000/svg"
|
||||
self.assertEqual(namespace, soup.svg.namespace)
|
||||
self.assertEqual(namespace, soup.circle.namespace)
|
||||
|
||||
|
||||
def test_mathml_tags_have_namespace(self):
|
||||
markup = '<math><msqrt>5</msqrt></math>'
|
||||
soup = self.soup(markup)
|
||||
namespace = 'http://www.w3.org/1998/Math/MathML'
|
||||
self.assertEqual(namespace, soup.math.namespace)
|
||||
self.assertEqual(namespace, soup.msqrt.namespace)
|
||||
|
||||
def test_xml_declaration_becomes_comment(self):
|
||||
markup = '<?xml version="1.0" encoding="utf-8"?><html></html>'
|
||||
soup = self.soup(markup)
|
||||
self.assertTrue(isinstance(soup.contents[0], Comment))
|
||||
self.assertEqual(soup.contents[0], '?xml version="1.0" encoding="utf-8"?')
|
||||
self.assertEqual("html", soup.contents[0].next_element.name)
|
||||
|
||||
def skipIf(condition, reason):
|
||||
def nothing(test, *args, **kwargs):
|
||||
return None
|
||||
|
||||
def decorator(test_item):
|
||||
if condition:
|
||||
return nothing
|
||||
else:
|
||||
return test_item
|
||||
|
||||
return decorator
|
||||
@@ -8,9 +8,8 @@
|
||||
Certificate generation module.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from OpenSSL import crypto
|
||||
import time
|
||||
|
||||
TYPE_RSA = crypto.TYPE_RSA
|
||||
TYPE_DSA = crypto.TYPE_DSA
|
||||
@@ -30,7 +29,6 @@ def createKeyPair(type, bits):
|
||||
pkey.generate_key(type, bits)
|
||||
return pkey
|
||||
|
||||
|
||||
def createCertRequest(pkey, digest="md5", **name):
|
||||
"""
|
||||
Create a certificate request.
|
||||
@@ -51,14 +49,13 @@ def createCertRequest(pkey, digest="md5", **name):
|
||||
req = crypto.X509Req()
|
||||
subj = req.get_subject()
|
||||
|
||||
for (key, value) in name.items():
|
||||
for (key,value) in name.items():
|
||||
setattr(subj, key, value)
|
||||
|
||||
req.set_pubkey(pkey)
|
||||
req.sign(pkey, digest)
|
||||
return req
|
||||
|
||||
|
||||
def createCertificate(req, (issuerCert, issuerKey), serial, (notBefore, notAfter), digest="md5"):
|
||||
"""
|
||||
Generate a certificate given a certificate request.
|
||||
@@ -0,0 +1,25 @@
|
||||
Copyright (c) 2004-2011, CherryPy Team (team@cherrypy.org)
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice,
|
||||
this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
* Neither the name of the CherryPy Team nor the names of its contributors
|
||||
may be used to endorse or promote products derived from this software
|
||||
without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
@@ -0,0 +1,652 @@
|
||||
"""CherryPy is a pythonic, object-oriented HTTP framework.
|
||||
|
||||
|
||||
CherryPy consists of not one, but four separate API layers.
|
||||
|
||||
The APPLICATION LAYER is the simplest. CherryPy applications are written as
|
||||
a tree of classes and methods, where each branch in the tree corresponds to
|
||||
a branch in the URL path. Each method is a 'page handler', which receives
|
||||
GET and POST params as keyword arguments, and returns or yields the (HTML)
|
||||
body of the response. The special method name 'index' is used for paths
|
||||
that end in a slash, and the special method name 'default' is used to
|
||||
handle multiple paths via a single handler. This layer also includes:
|
||||
|
||||
* the 'exposed' attribute (and cherrypy.expose)
|
||||
* cherrypy.quickstart()
|
||||
* _cp_config attributes
|
||||
* cherrypy.tools (including cherrypy.session)
|
||||
* cherrypy.url()
|
||||
|
||||
The ENVIRONMENT LAYER is used by developers at all levels. It provides
|
||||
information about the current request and response, plus the application
|
||||
and server environment, via a (default) set of top-level objects:
|
||||
|
||||
* cherrypy.request
|
||||
* cherrypy.response
|
||||
* cherrypy.engine
|
||||
* cherrypy.server
|
||||
* cherrypy.tree
|
||||
* cherrypy.config
|
||||
* cherrypy.thread_data
|
||||
* cherrypy.log
|
||||
* cherrypy.HTTPError, NotFound, and HTTPRedirect
|
||||
* cherrypy.lib
|
||||
|
||||
The EXTENSION LAYER allows advanced users to construct and share their own
|
||||
plugins. It consists of:
|
||||
|
||||
* Hook API
|
||||
* Tool API
|
||||
* Toolbox API
|
||||
* Dispatch API
|
||||
* Config Namespace API
|
||||
|
||||
Finally, there is the CORE LAYER, which uses the core API's to construct
|
||||
the default components which are available at higher layers. You can think
|
||||
of the default components as the 'reference implementation' for CherryPy.
|
||||
Megaframeworks (and advanced users) may replace the default components
|
||||
with customized or extended components. The core API's are:
|
||||
|
||||
* Application API
|
||||
* Engine API
|
||||
* Request API
|
||||
* Server API
|
||||
* WSGI API
|
||||
|
||||
These API's are described in the `CherryPy specification <https://bitbucket.org/cherrypy/cherrypy/wiki/CherryPySpec>`_.
|
||||
"""
|
||||
|
||||
__version__ = "3.6.0"
|
||||
|
||||
from cherrypy._cpcompat import urljoin as _urljoin, urlencode as _urlencode
|
||||
from cherrypy._cpcompat import basestring, unicodestr, set
|
||||
|
||||
from cherrypy._cperror import HTTPError, HTTPRedirect, InternalRedirect
|
||||
from cherrypy._cperror import NotFound, CherryPyException, TimeoutError
|
||||
|
||||
from cherrypy import _cpdispatch as dispatch
|
||||
|
||||
from cherrypy import _cptools
|
||||
tools = _cptools.default_toolbox
|
||||
Tool = _cptools.Tool
|
||||
|
||||
from cherrypy import _cprequest
|
||||
from cherrypy.lib import httputil as _httputil
|
||||
|
||||
from cherrypy import _cptree
|
||||
tree = _cptree.Tree()
|
||||
from cherrypy._cptree import Application
|
||||
from cherrypy import _cpwsgi as wsgi
|
||||
|
||||
from cherrypy import process
|
||||
try:
|
||||
from cherrypy.process import win32
|
||||
engine = win32.Win32Bus()
|
||||
engine.console_control_handler = win32.ConsoleCtrlHandler(engine)
|
||||
del win32
|
||||
except ImportError:
|
||||
engine = process.bus
|
||||
|
||||
|
||||
# Timeout monitor. We add two channels to the engine
|
||||
# to which cherrypy.Application will publish.
|
||||
engine.listeners['before_request'] = set()
|
||||
engine.listeners['after_request'] = set()
|
||||
|
||||
|
||||
class _TimeoutMonitor(process.plugins.Monitor):
|
||||
|
||||
def __init__(self, bus):
|
||||
self.servings = []
|
||||
process.plugins.Monitor.__init__(self, bus, self.run)
|
||||
|
||||
def before_request(self):
|
||||
self.servings.append((serving.request, serving.response))
|
||||
|
||||
def after_request(self):
|
||||
try:
|
||||
self.servings.remove((serving.request, serving.response))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def run(self):
|
||||
"""Check timeout on all responses. (Internal)"""
|
||||
for req, resp in self.servings:
|
||||
resp.check_timeout()
|
||||
engine.timeout_monitor = _TimeoutMonitor(engine)
|
||||
engine.timeout_monitor.subscribe()
|
||||
|
||||
engine.autoreload = process.plugins.Autoreloader(engine)
|
||||
engine.autoreload.subscribe()
|
||||
|
||||
engine.thread_manager = process.plugins.ThreadManager(engine)
|
||||
engine.thread_manager.subscribe()
|
||||
|
||||
engine.signal_handler = process.plugins.SignalHandler(engine)
|
||||
|
||||
|
||||
class _HandleSignalsPlugin(object):
|
||||
|
||||
"""Handle signals from other processes based on the configured
|
||||
platform handlers above."""
|
||||
|
||||
def __init__(self, bus):
|
||||
self.bus = bus
|
||||
|
||||
def subscribe(self):
|
||||
"""Add the handlers based on the platform"""
|
||||
if hasattr(self.bus, "signal_handler"):
|
||||
self.bus.signal_handler.subscribe()
|
||||
if hasattr(self.bus, "console_control_handler"):
|
||||
self.bus.console_control_handler.subscribe()
|
||||
|
||||
engine.signals = _HandleSignalsPlugin(engine)
|
||||
|
||||
|
||||
from cherrypy import _cpserver
|
||||
server = _cpserver.Server()
|
||||
server.subscribe()
|
||||
|
||||
|
||||
def quickstart(root=None, script_name="", config=None):
|
||||
"""Mount the given root, start the builtin server (and engine), then block.
|
||||
|
||||
root: an instance of a "controller class" (a collection of page handler
|
||||
methods) which represents the root of the application.
|
||||
script_name: a string containing the "mount point" of the application.
|
||||
This should start with a slash, and be the path portion of the URL
|
||||
at which to mount the given root. For example, if root.index() will
|
||||
handle requests to "http://www.example.com:8080/dept/app1/", then
|
||||
the script_name argument would be "/dept/app1".
|
||||
|
||||
It MUST NOT end in a slash. If the script_name refers to the root
|
||||
of the URI, it MUST be an empty string (not "/").
|
||||
config: a file or dict containing application config. If this contains
|
||||
a [global] section, those entries will be used in the global
|
||||
(site-wide) config.
|
||||
"""
|
||||
if config:
|
||||
_global_conf_alias.update(config)
|
||||
|
||||
tree.mount(root, script_name, config)
|
||||
|
||||
engine.signals.subscribe()
|
||||
engine.start()
|
||||
engine.block()
|
||||
|
||||
|
||||
from cherrypy._cpcompat import threadlocal as _local
|
||||
|
||||
|
||||
class _Serving(_local):
|
||||
|
||||
"""An interface for registering request and response objects.
|
||||
|
||||
Rather than have a separate "thread local" object for the request and
|
||||
the response, this class works as a single threadlocal container for
|
||||
both objects (and any others which developers wish to define). In this
|
||||
way, we can easily dump those objects when we stop/start a new HTTP
|
||||
conversation, yet still refer to them as module-level globals in a
|
||||
thread-safe way.
|
||||
"""
|
||||
|
||||
request = _cprequest.Request(_httputil.Host("127.0.0.1", 80),
|
||||
_httputil.Host("127.0.0.1", 1111))
|
||||
"""
|
||||
The request object for the current thread. In the main thread,
|
||||
and any threads which are not receiving HTTP requests, this is None."""
|
||||
|
||||
response = _cprequest.Response()
|
||||
"""
|
||||
The response object for the current thread. In the main thread,
|
||||
and any threads which are not receiving HTTP requests, this is None."""
|
||||
|
||||
def load(self, request, response):
|
||||
self.request = request
|
||||
self.response = response
|
||||
|
||||
def clear(self):
|
||||
"""Remove all attributes of self."""
|
||||
self.__dict__.clear()
|
||||
|
||||
serving = _Serving()
|
||||
|
||||
|
||||
class _ThreadLocalProxy(object):
|
||||
|
||||
__slots__ = ['__attrname__', '__dict__']
|
||||
|
||||
def __init__(self, attrname):
|
||||
self.__attrname__ = attrname
|
||||
|
||||
def __getattr__(self, name):
|
||||
child = getattr(serving, self.__attrname__)
|
||||
return getattr(child, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in ("__attrname__", ):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
child = getattr(serving, self.__attrname__)
|
||||
setattr(child, name, value)
|
||||
|
||||
def __delattr__(self, name):
|
||||
child = getattr(serving, self.__attrname__)
|
||||
delattr(child, name)
|
||||
|
||||
def _get_dict(self):
|
||||
child = getattr(serving, self.__attrname__)
|
||||
d = child.__class__.__dict__.copy()
|
||||
d.update(child.__dict__)
|
||||
return d
|
||||
__dict__ = property(_get_dict)
|
||||
|
||||
def __getitem__(self, key):
|
||||
child = getattr(serving, self.__attrname__)
|
||||
return child[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
child = getattr(serving, self.__attrname__)
|
||||
child[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
child = getattr(serving, self.__attrname__)
|
||||
del child[key]
|
||||
|
||||
def __contains__(self, key):
|
||||
child = getattr(serving, self.__attrname__)
|
||||
return key in child
|
||||
|
||||
def __len__(self):
|
||||
child = getattr(serving, self.__attrname__)
|
||||
return len(child)
|
||||
|
||||
def __nonzero__(self):
|
||||
child = getattr(serving, self.__attrname__)
|
||||
return bool(child)
|
||||
# Python 3
|
||||
__bool__ = __nonzero__
|
||||
|
||||
# Create request and response object (the same objects will be used
|
||||
# throughout the entire life of the webserver, but will redirect
|
||||
# to the "serving" object)
|
||||
request = _ThreadLocalProxy('request')
|
||||
response = _ThreadLocalProxy('response')
|
||||
|
||||
# Create thread_data object as a thread-specific all-purpose storage
|
||||
|
||||
|
||||
class _ThreadData(_local):
|
||||
|
||||
"""A container for thread-specific data."""
|
||||
thread_data = _ThreadData()
|
||||
|
||||
|
||||
# Monkeypatch pydoc to allow help() to go through the threadlocal proxy.
|
||||
# Jan 2007: no Googleable examples of anyone else replacing pydoc.resolve.
|
||||
# The only other way would be to change what is returned from type(request)
|
||||
# and that's not possible in pure Python (you'd have to fake ob_type).
|
||||
def _cherrypy_pydoc_resolve(thing, forceload=0):
|
||||
"""Given an object or a path to an object, get the object and its name."""
|
||||
if isinstance(thing, _ThreadLocalProxy):
|
||||
thing = getattr(serving, thing.__attrname__)
|
||||
return _pydoc._builtin_resolve(thing, forceload)
|
||||
|
||||
try:
|
||||
import pydoc as _pydoc
|
||||
_pydoc._builtin_resolve = _pydoc.resolve
|
||||
_pydoc.resolve = _cherrypy_pydoc_resolve
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
from cherrypy import _cplogging
|
||||
|
||||
|
||||
class _GlobalLogManager(_cplogging.LogManager):
|
||||
|
||||
"""A site-wide LogManager; routes to app.log or global log as appropriate.
|
||||
|
||||
This :class:`LogManager<cherrypy._cplogging.LogManager>` implements
|
||||
cherrypy.log() and cherrypy.log.access(). If either
|
||||
function is called during a request, the message will be sent to the
|
||||
logger for the current Application. If they are called outside of a
|
||||
request, the message will be sent to the site-wide logger.
|
||||
"""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Log the given message to the app.log or global log as appropriate.
|
||||
"""
|
||||
# Do NOT use try/except here. See
|
||||
# https://bitbucket.org/cherrypy/cherrypy/issue/945
|
||||
if hasattr(request, 'app') and hasattr(request.app, 'log'):
|
||||
log = request.app.log
|
||||
else:
|
||||
log = self
|
||||
return log.error(*args, **kwargs)
|
||||
|
||||
def access(self):
|
||||
"""Log an access message to the app.log or global log as appropriate.
|
||||
"""
|
||||
try:
|
||||
return request.app.log.access()
|
||||
except AttributeError:
|
||||
return _cplogging.LogManager.access(self)
|
||||
|
||||
|
||||
log = _GlobalLogManager()
|
||||
# Set a default screen handler on the global log.
|
||||
log.screen = True
|
||||
log.error_file = ''
|
||||
# Using an access file makes CP about 10% slower. Leave off by default.
|
||||
log.access_file = ''
|
||||
|
||||
|
||||
def _buslog(msg, level):
|
||||
log.error(msg, 'ENGINE', severity=level)
|
||||
engine.subscribe('log', _buslog)
|
||||
|
||||
# Helper functions for CP apps #
|
||||
|
||||
|
||||
def expose(func=None, alias=None):
|
||||
"""Expose the function, optionally providing an alias or set of aliases."""
|
||||
def expose_(func):
|
||||
func.exposed = True
|
||||
if alias is not None:
|
||||
if isinstance(alias, basestring):
|
||||
parents[alias.replace(".", "_")] = func
|
||||
else:
|
||||
for a in alias:
|
||||
parents[a.replace(".", "_")] = func
|
||||
return func
|
||||
|
||||
import sys
|
||||
import types
|
||||
if isinstance(func, (types.FunctionType, types.MethodType)):
|
||||
if alias is None:
|
||||
# @expose
|
||||
func.exposed = True
|
||||
return func
|
||||
else:
|
||||
# func = expose(func, alias)
|
||||
parents = sys._getframe(1).f_locals
|
||||
return expose_(func)
|
||||
elif func is None:
|
||||
if alias is None:
|
||||
# @expose()
|
||||
parents = sys._getframe(1).f_locals
|
||||
return expose_
|
||||
else:
|
||||
# @expose(alias="alias") or
|
||||
# @expose(alias=["alias1", "alias2"])
|
||||
parents = sys._getframe(1).f_locals
|
||||
return expose_
|
||||
else:
|
||||
# @expose("alias") or
|
||||
# @expose(["alias1", "alias2"])
|
||||
parents = sys._getframe(1).f_locals
|
||||
alias = func
|
||||
return expose_
|
||||
|
||||
|
||||
def popargs(*args, **kwargs):
|
||||
"""A decorator for _cp_dispatch
|
||||
(cherrypy.dispatch.Dispatcher.dispatch_method_name).
|
||||
|
||||
Optional keyword argument: handler=(Object or Function)
|
||||
|
||||
Provides a _cp_dispatch function that pops off path segments into
|
||||
cherrypy.request.params under the names specified. The dispatch
|
||||
is then forwarded on to the next vpath element.
|
||||
|
||||
Note that any existing (and exposed) member function of the class that
|
||||
popargs is applied to will override that value of the argument. For
|
||||
instance, if you have a method named "list" on the class decorated with
|
||||
popargs, then accessing "/list" will call that function instead of popping
|
||||
it off as the requested parameter. This restriction applies to all
|
||||
_cp_dispatch functions. The only way around this restriction is to create
|
||||
a "blank class" whose only function is to provide _cp_dispatch.
|
||||
|
||||
If there are path elements after the arguments, or more arguments
|
||||
are requested than are available in the vpath, then the 'handler'
|
||||
keyword argument specifies the next object to handle the parameterized
|
||||
request. If handler is not specified or is None, then self is used.
|
||||
If handler is a function rather than an instance, then that function
|
||||
will be called with the args specified and the return value from that
|
||||
function used as the next object INSTEAD of adding the parameters to
|
||||
cherrypy.request.args.
|
||||
|
||||
This decorator may be used in one of two ways:
|
||||
|
||||
As a class decorator:
|
||||
@cherrypy.popargs('year', 'month', 'day')
|
||||
class Blog:
|
||||
def index(self, year=None, month=None, day=None):
|
||||
#Process the parameters here; any url like
|
||||
#/, /2009, /2009/12, or /2009/12/31
|
||||
#will fill in the appropriate parameters.
|
||||
|
||||
def create(self):
|
||||
#This link will still be available at /create. Defined functions
|
||||
#take precedence over arguments.
|
||||
|
||||
Or as a member of a class:
|
||||
class Blog:
|
||||
_cp_dispatch = cherrypy.popargs('year', 'month', 'day')
|
||||
#...
|
||||
|
||||
The handler argument may be used to mix arguments with built in functions.
|
||||
For instance, the following setup allows different activities at the
|
||||
day, month, and year level:
|
||||
|
||||
class DayHandler:
|
||||
def index(self, year, month, day):
|
||||
#Do something with this day; probably list entries
|
||||
|
||||
def delete(self, year, month, day):
|
||||
#Delete all entries for this day
|
||||
|
||||
@cherrypy.popargs('day', handler=DayHandler())
|
||||
class MonthHandler:
|
||||
def index(self, year, month):
|
||||
#Do something with this month; probably list entries
|
||||
|
||||
def delete(self, year, month):
|
||||
#Delete all entries for this month
|
||||
|
||||
@cherrypy.popargs('month', handler=MonthHandler())
|
||||
class YearHandler:
|
||||
def index(self, year):
|
||||
#Do something with this year
|
||||
|
||||
#...
|
||||
|
||||
@cherrypy.popargs('year', handler=YearHandler())
|
||||
class Root:
|
||||
def index(self):
|
||||
#...
|
||||
|
||||
"""
|
||||
|
||||
# Since keyword arg comes after *args, we have to process it ourselves
|
||||
# for lower versions of python.
|
||||
|
||||
handler = None
|
||||
handler_call = False
|
||||
for k, v in kwargs.items():
|
||||
if k == 'handler':
|
||||
handler = v
|
||||
else:
|
||||
raise TypeError(
|
||||
"cherrypy.popargs() got an unexpected keyword argument '{0}'"
|
||||
.format(k)
|
||||
)
|
||||
|
||||
import inspect
|
||||
|
||||
if handler is not None \
|
||||
and (hasattr(handler, '__call__') or inspect.isclass(handler)):
|
||||
handler_call = True
|
||||
|
||||
def decorated(cls_or_self=None, vpath=None):
|
||||
if inspect.isclass(cls_or_self):
|
||||
# cherrypy.popargs is a class decorator
|
||||
cls = cls_or_self
|
||||
setattr(cls, dispatch.Dispatcher.dispatch_method_name, decorated)
|
||||
return cls
|
||||
|
||||
# We're in the actual function
|
||||
self = cls_or_self
|
||||
parms = {}
|
||||
for arg in args:
|
||||
if not vpath:
|
||||
break
|
||||
parms[arg] = vpath.pop(0)
|
||||
|
||||
if handler is not None:
|
||||
if handler_call:
|
||||
return handler(**parms)
|
||||
else:
|
||||
request.params.update(parms)
|
||||
return handler
|
||||
|
||||
request.params.update(parms)
|
||||
|
||||
# If we are the ultimate handler, then to prevent our _cp_dispatch
|
||||
# from being called again, we will resolve remaining elements through
|
||||
# getattr() directly.
|
||||
if vpath:
|
||||
return getattr(self, vpath.pop(0), None)
|
||||
else:
|
||||
return self
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def url(path="", qs="", script_name=None, base=None, relative=None):
|
||||
"""Create an absolute URL for the given path.
|
||||
|
||||
If 'path' starts with a slash ('/'), this will return
|
||||
(base + script_name + path + qs).
|
||||
If it does not start with a slash, this returns
|
||||
(base + script_name [+ request.path_info] + path + qs).
|
||||
|
||||
If script_name is None, cherrypy.request will be used
|
||||
to find a script_name, if available.
|
||||
|
||||
If base is None, cherrypy.request.base will be used (if available).
|
||||
Note that you can use cherrypy.tools.proxy to change this.
|
||||
|
||||
Finally, note that this function can be used to obtain an absolute URL
|
||||
for the current request path (minus the querystring) by passing no args.
|
||||
If you call url(qs=cherrypy.request.query_string), you should get the
|
||||
original browser URL (assuming no internal redirections).
|
||||
|
||||
If relative is None or not provided, request.app.relative_urls will
|
||||
be used (if available, else False). If False, the output will be an
|
||||
absolute URL (including the scheme, host, vhost, and script_name).
|
||||
If True, the output will instead be a URL that is relative to the
|
||||
current request path, perhaps including '..' atoms. If relative is
|
||||
the string 'server', the output will instead be a URL that is
|
||||
relative to the server root; i.e., it will start with a slash.
|
||||
"""
|
||||
if isinstance(qs, (tuple, list, dict)):
|
||||
qs = _urlencode(qs)
|
||||
if qs:
|
||||
qs = '?' + qs
|
||||
|
||||
if request.app:
|
||||
if not path.startswith("/"):
|
||||
# Append/remove trailing slash from path_info as needed
|
||||
# (this is to support mistyped URL's without redirecting;
|
||||
# if you want to redirect, use tools.trailing_slash).
|
||||
pi = request.path_info
|
||||
if request.is_index is True:
|
||||
if not pi.endswith('/'):
|
||||
pi = pi + '/'
|
||||
elif request.is_index is False:
|
||||
if pi.endswith('/') and pi != '/':
|
||||
pi = pi[:-1]
|
||||
|
||||
if path == "":
|
||||
path = pi
|
||||
else:
|
||||
path = _urljoin(pi, path)
|
||||
|
||||
if script_name is None:
|
||||
script_name = request.script_name
|
||||
if base is None:
|
||||
base = request.base
|
||||
|
||||
newurl = base + script_name + path + qs
|
||||
else:
|
||||
# No request.app (we're being called outside a request).
|
||||
# We'll have to guess the base from server.* attributes.
|
||||
# This will produce very different results from the above
|
||||
# if you're using vhosts or tools.proxy.
|
||||
if base is None:
|
||||
base = server.base()
|
||||
|
||||
path = (script_name or "") + path
|
||||
newurl = base + path + qs
|
||||
|
||||
if './' in newurl:
|
||||
# Normalize the URL by removing ./ and ../
|
||||
atoms = []
|
||||
for atom in newurl.split('/'):
|
||||
if atom == '.':
|
||||
pass
|
||||
elif atom == '..':
|
||||
atoms.pop()
|
||||
else:
|
||||
atoms.append(atom)
|
||||
newurl = '/'.join(atoms)
|
||||
|
||||
# At this point, we should have a fully-qualified absolute URL.
|
||||
|
||||
if relative is None:
|
||||
relative = getattr(request.app, "relative_urls", False)
|
||||
|
||||
# See http://www.ietf.org/rfc/rfc2396.txt
|
||||
if relative == 'server':
|
||||
# "A relative reference beginning with a single slash character is
|
||||
# termed an absolute-path reference, as defined by <abs_path>..."
|
||||
# This is also sometimes called "server-relative".
|
||||
newurl = '/' + '/'.join(newurl.split('/', 3)[3:])
|
||||
elif relative:
|
||||
# "A relative reference that does not begin with a scheme name
|
||||
# or a slash character is termed a relative-path reference."
|
||||
old = url(relative=False).split('/')[:-1]
|
||||
new = newurl.split('/')
|
||||
while old and new:
|
||||
a, b = old[0], new[0]
|
||||
if a != b:
|
||||
break
|
||||
old.pop(0)
|
||||
new.pop(0)
|
||||
new = (['..'] * len(old)) + new
|
||||
newurl = '/'.join(new)
|
||||
|
||||
return newurl
|
||||
|
||||
|
||||
# import _cpconfig last so it can reference other top-level objects
|
||||
from cherrypy import _cpconfig
|
||||
# Use _global_conf_alias so quickstart can use 'config' as an arg
|
||||
# without shadowing cherrypy.config.
|
||||
config = _global_conf_alias = _cpconfig.Config()
|
||||
config.defaults = {
|
||||
'tools.log_tracebacks.on': True,
|
||||
'tools.log_headers.on': True,
|
||||
'tools.trailing_slash.on': True,
|
||||
'tools.encode.on': True
|
||||
}
|
||||
config.namespaces["log"] = lambda k, v: setattr(log, k, v)
|
||||
config.namespaces["checker"] = lambda k, v: setattr(checker, k, v)
|
||||
# Must reset to get our defaults applied.
|
||||
config.reset()
|
||||
|
||||
from cherrypy import _cpchecker
|
||||
checker = _cpchecker.Checker()
|
||||
engine.subscribe('start', checker)
|
||||
@@ -0,0 +1,332 @@
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import cherrypy
|
||||
from cherrypy._cpcompat import iteritems, copykeys, builtins
|
||||
|
||||
|
||||
class Checker(object):
|
||||
|
||||
"""A checker for CherryPy sites and their mounted applications.
|
||||
|
||||
When this object is called at engine startup, it executes each
|
||||
of its own methods whose names start with ``check_``. If you wish
|
||||
to disable selected checks, simply add a line in your global
|
||||
config which sets the appropriate method to False::
|
||||
|
||||
[global]
|
||||
checker.check_skipped_app_config = False
|
||||
|
||||
You may also dynamically add or replace ``check_*`` methods in this way.
|
||||
"""
|
||||
|
||||
on = True
|
||||
"""If True (the default), run all checks; if False, turn off all checks."""
|
||||
|
||||
def __init__(self):
|
||||
self._populate_known_types()
|
||||
|
||||
def __call__(self):
|
||||
"""Run all check_* methods."""
|
||||
if self.on:
|
||||
oldformatwarning = warnings.formatwarning
|
||||
warnings.formatwarning = self.formatwarning
|
||||
try:
|
||||
for name in dir(self):
|
||||
if name.startswith("check_"):
|
||||
method = getattr(self, name)
|
||||
if method and hasattr(method, '__call__'):
|
||||
method()
|
||||
finally:
|
||||
warnings.formatwarning = oldformatwarning
|
||||
|
||||
def formatwarning(self, message, category, filename, lineno, line=None):
|
||||
"""Function to format a warning."""
|
||||
return "CherryPy Checker:\n%s\n\n" % message
|
||||
|
||||
# This value should be set inside _cpconfig.
|
||||
global_config_contained_paths = False
|
||||
|
||||
def check_app_config_entries_dont_start_with_script_name(self):
|
||||
"""Check for Application config with sections that repeat script_name.
|
||||
"""
|
||||
for sn, app in cherrypy.tree.apps.items():
|
||||
if not isinstance(app, cherrypy.Application):
|
||||
continue
|
||||
if not app.config:
|
||||
continue
|
||||
if sn == '':
|
||||
continue
|
||||
sn_atoms = sn.strip("/").split("/")
|
||||
for key in app.config.keys():
|
||||
key_atoms = key.strip("/").split("/")
|
||||
if key_atoms[:len(sn_atoms)] == sn_atoms:
|
||||
warnings.warn(
|
||||
"The application mounted at %r has config "
|
||||
"entries that start with its script name: %r" % (sn,
|
||||
key))
|
||||
|
||||
def check_site_config_entries_in_app_config(self):
|
||||
"""Check for mounted Applications that have site-scoped config."""
|
||||
for sn, app in iteritems(cherrypy.tree.apps):
|
||||
if not isinstance(app, cherrypy.Application):
|
||||
continue
|
||||
|
||||
msg = []
|
||||
for section, entries in iteritems(app.config):
|
||||
if section.startswith('/'):
|
||||
for key, value in iteritems(entries):
|
||||
for n in ("engine.", "server.", "tree.", "checker."):
|
||||
if key.startswith(n):
|
||||
msg.append("[%s] %s = %s" %
|
||||
(section, key, value))
|
||||
if msg:
|
||||
msg.insert(0,
|
||||
"The application mounted at %r contains the "
|
||||
"following config entries, which are only allowed "
|
||||
"in site-wide config. Move them to a [global] "
|
||||
"section and pass them to cherrypy.config.update() "
|
||||
"instead of tree.mount()." % sn)
|
||||
warnings.warn(os.linesep.join(msg))
|
||||
|
||||
def check_skipped_app_config(self):
|
||||
"""Check for mounted Applications that have no config."""
|
||||
for sn, app in cherrypy.tree.apps.items():
|
||||
if not isinstance(app, cherrypy.Application):
|
||||
continue
|
||||
if not app.config:
|
||||
msg = "The Application mounted at %r has an empty config." % sn
|
||||
if self.global_config_contained_paths:
|
||||
msg += (" It looks like the config you passed to "
|
||||
"cherrypy.config.update() contains application-"
|
||||
"specific sections. You must explicitly pass "
|
||||
"application config via "
|
||||
"cherrypy.tree.mount(..., config=app_config)")
|
||||
warnings.warn(msg)
|
||||
return
|
||||
|
||||
def check_app_config_brackets(self):
|
||||
"""Check for Application config with extraneous brackets in section
|
||||
names.
|
||||
"""
|
||||
for sn, app in cherrypy.tree.apps.items():
|
||||
if not isinstance(app, cherrypy.Application):
|
||||
continue
|
||||
if not app.config:
|
||||
continue
|
||||
for key in app.config.keys():
|
||||
if key.startswith("[") or key.endswith("]"):
|
||||
warnings.warn(
|
||||
"The application mounted at %r has config "
|
||||
"section names with extraneous brackets: %r. "
|
||||
"Config *files* need brackets; config *dicts* "
|
||||
"(e.g. passed to tree.mount) do not." % (sn, key))
|
||||
|
||||
def check_static_paths(self):
|
||||
"""Check Application config for incorrect static paths."""
|
||||
# Use the dummy Request object in the main thread.
|
||||
request = cherrypy.request
|
||||
for sn, app in cherrypy.tree.apps.items():
|
||||
if not isinstance(app, cherrypy.Application):
|
||||
continue
|
||||
request.app = app
|
||||
for section in app.config:
|
||||
# get_resource will populate request.config
|
||||
request.get_resource(section + "/dummy.html")
|
||||
conf = request.config.get
|
||||
|
||||
if conf("tools.staticdir.on", False):
|
||||
msg = ""
|
||||
root = conf("tools.staticdir.root")
|
||||
dir = conf("tools.staticdir.dir")
|
||||
if dir is None:
|
||||
msg = "tools.staticdir.dir is not set."
|
||||
else:
|
||||
fulldir = ""
|
||||
if os.path.isabs(dir):
|
||||
fulldir = dir
|
||||
if root:
|
||||
msg = ("dir is an absolute path, even "
|
||||
"though a root is provided.")
|
||||
testdir = os.path.join(root, dir[1:])
|
||||
if os.path.exists(testdir):
|
||||
msg += (
|
||||
"\nIf you meant to serve the "
|
||||
"filesystem folder at %r, remove the "
|
||||
"leading slash from dir." % (testdir,))
|
||||
else:
|
||||
if not root:
|
||||
msg = (
|
||||
"dir is a relative path and "
|
||||
"no root provided.")
|
||||
else:
|
||||
fulldir = os.path.join(root, dir)
|
||||
if not os.path.isabs(fulldir):
|
||||
msg = ("%r is not an absolute path." % (
|
||||
fulldir,))
|
||||
|
||||
if fulldir and not os.path.exists(fulldir):
|
||||
if msg:
|
||||
msg += "\n"
|
||||
msg += ("%r (root + dir) is not an existing "
|
||||
"filesystem path." % fulldir)
|
||||
|
||||
if msg:
|
||||
warnings.warn("%s\nsection: [%s]\nroot: %r\ndir: %r"
|
||||
% (msg, section, root, dir))
|
||||
|
||||
# -------------------------- Compatibility -------------------------- #
|
||||
obsolete = {
|
||||
'server.default_content_type': 'tools.response_headers.headers',
|
||||
'log_access_file': 'log.access_file',
|
||||
'log_config_options': None,
|
||||
'log_file': 'log.error_file',
|
||||
'log_file_not_found': None,
|
||||
'log_request_headers': 'tools.log_headers.on',
|
||||
'log_to_screen': 'log.screen',
|
||||
'show_tracebacks': 'request.show_tracebacks',
|
||||
'throw_errors': 'request.throw_errors',
|
||||
'profiler.on': ('cherrypy.tree.mount(profiler.make_app('
|
||||
'cherrypy.Application(Root())))'),
|
||||
}
|
||||
|
||||
deprecated = {}
|
||||
|
||||
def _compat(self, config):
|
||||
"""Process config and warn on each obsolete or deprecated entry."""
|
||||
for section, conf in config.items():
|
||||
if isinstance(conf, dict):
|
||||
for k, v in conf.items():
|
||||
if k in self.obsolete:
|
||||
warnings.warn("%r is obsolete. Use %r instead.\n"
|
||||
"section: [%s]" %
|
||||
(k, self.obsolete[k], section))
|
||||
elif k in self.deprecated:
|
||||
warnings.warn("%r is deprecated. Use %r instead.\n"
|
||||
"section: [%s]" %
|
||||
(k, self.deprecated[k], section))
|
||||
else:
|
||||
if section in self.obsolete:
|
||||
warnings.warn("%r is obsolete. Use %r instead."
|
||||
% (section, self.obsolete[section]))
|
||||
elif section in self.deprecated:
|
||||
warnings.warn("%r is deprecated. Use %r instead."
|
||||
% (section, self.deprecated[section]))
|
||||
|
||||
def check_compatibility(self):
|
||||
"""Process config and warn on each obsolete or deprecated entry."""
|
||||
self._compat(cherrypy.config)
|
||||
for sn, app in cherrypy.tree.apps.items():
|
||||
if not isinstance(app, cherrypy.Application):
|
||||
continue
|
||||
self._compat(app.config)
|
||||
|
||||
# ------------------------ Known Namespaces ------------------------ #
|
||||
extra_config_namespaces = []
|
||||
|
||||
def _known_ns(self, app):
|
||||
ns = ["wsgi"]
|
||||
ns.extend(copykeys(app.toolboxes))
|
||||
ns.extend(copykeys(app.namespaces))
|
||||
ns.extend(copykeys(app.request_class.namespaces))
|
||||
ns.extend(copykeys(cherrypy.config.namespaces))
|
||||
ns += self.extra_config_namespaces
|
||||
|
||||
for section, conf in app.config.items():
|
||||
is_path_section = section.startswith("/")
|
||||
if is_path_section and isinstance(conf, dict):
|
||||
for k, v in conf.items():
|
||||
atoms = k.split(".")
|
||||
if len(atoms) > 1:
|
||||
if atoms[0] not in ns:
|
||||
# Spit out a special warning if a known
|
||||
# namespace is preceded by "cherrypy."
|
||||
if atoms[0] == "cherrypy" and atoms[1] in ns:
|
||||
msg = (
|
||||
"The config entry %r is invalid; "
|
||||
"try %r instead.\nsection: [%s]"
|
||||
% (k, ".".join(atoms[1:]), section))
|
||||
else:
|
||||
msg = (
|
||||
"The config entry %r is invalid, "
|
||||
"because the %r config namespace "
|
||||
"is unknown.\n"
|
||||
"section: [%s]" % (k, atoms[0], section))
|
||||
warnings.warn(msg)
|
||||
elif atoms[0] == "tools":
|
||||
if atoms[1] not in dir(cherrypy.tools):
|
||||
msg = (
|
||||
"The config entry %r may be invalid, "
|
||||
"because the %r tool was not found.\n"
|
||||
"section: [%s]" % (k, atoms[1], section))
|
||||
warnings.warn(msg)
|
||||
|
||||
def check_config_namespaces(self):
|
||||
"""Process config and warn on each unknown config namespace."""
|
||||
for sn, app in cherrypy.tree.apps.items():
|
||||
if not isinstance(app, cherrypy.Application):
|
||||
continue
|
||||
self._known_ns(app)
|
||||
|
||||
# -------------------------- Config Types -------------------------- #
|
||||
known_config_types = {}
|
||||
|
||||
def _populate_known_types(self):
|
||||
b = [x for x in vars(builtins).values()
|
||||
if type(x) is type(str)]
|
||||
|
||||
def traverse(obj, namespace):
|
||||
for name in dir(obj):
|
||||
# Hack for 3.2's warning about body_params
|
||||
if name == 'body_params':
|
||||
continue
|
||||
vtype = type(getattr(obj, name, None))
|
||||
if vtype in b:
|
||||
self.known_config_types[namespace + "." + name] = vtype
|
||||
|
||||
traverse(cherrypy.request, "request")
|
||||
traverse(cherrypy.response, "response")
|
||||
traverse(cherrypy.server, "server")
|
||||
traverse(cherrypy.engine, "engine")
|
||||
traverse(cherrypy.log, "log")
|
||||
|
||||
def _known_types(self, config):
|
||||
msg = ("The config entry %r in section %r is of type %r, "
|
||||
"which does not match the expected type %r.")
|
||||
|
||||
for section, conf in config.items():
|
||||
if isinstance(conf, dict):
|
||||
for k, v in conf.items():
|
||||
if v is not None:
|
||||
expected_type = self.known_config_types.get(k, None)
|
||||
vtype = type(v)
|
||||
if expected_type and vtype != expected_type:
|
||||
warnings.warn(msg % (k, section, vtype.__name__,
|
||||
expected_type.__name__))
|
||||
else:
|
||||
k, v = section, conf
|
||||
if v is not None:
|
||||
expected_type = self.known_config_types.get(k, None)
|
||||
vtype = type(v)
|
||||
if expected_type and vtype != expected_type:
|
||||
warnings.warn(msg % (k, section, vtype.__name__,
|
||||
expected_type.__name__))
|
||||
|
||||
def check_config_types(self):
|
||||
"""Assert that config values are of the same type as default values."""
|
||||
self._known_types(cherrypy.config)
|
||||
for sn, app in cherrypy.tree.apps.items():
|
||||
if not isinstance(app, cherrypy.Application):
|
||||
continue
|
||||
self._known_types(app.config)
|
||||
|
||||
# -------------------- Specific config warnings -------------------- #
|
||||
def check_localhost(self):
|
||||
"""Warn if any socket_host is 'localhost'. See #711."""
|
||||
for k, v in cherrypy.config.items():
|
||||
if k == 'server.socket_host' and v == 'localhost':
|
||||
warnings.warn("The use of 'localhost' as a socket host can "
|
||||
"cause problems on newer systems, since "
|
||||
"'localhost' can map to either an IPv4 or an "
|
||||
"IPv6 address. You should use '127.0.0.1' "
|
||||
"or '[::1]' instead.")
|
||||
@@ -0,0 +1,383 @@
|
||||
"""Compatibility code for using CherryPy with various versions of Python.
|
||||
|
||||
CherryPy 3.2 is compatible with Python versions 2.3+. This module provides a
|
||||
useful abstraction over the differences between Python versions, sometimes by
|
||||
preferring a newer idiom, sometimes an older one, and sometimes a custom one.
|
||||
|
||||
In particular, Python 2 uses str and '' for byte strings, while Python 3
|
||||
uses str and '' for unicode strings. We will call each of these the 'native
|
||||
string' type for each version. Because of this major difference, this module
|
||||
provides new 'bytestr', 'unicodestr', and 'nativestr' attributes, as well as
|
||||
two functions: 'ntob', which translates native strings (of type 'str') into
|
||||
byte strings regardless of Python version, and 'ntou', which translates native
|
||||
strings to unicode strings. This also provides a 'BytesIO' name for dealing
|
||||
specifically with bytes, and a 'StringIO' name for dealing with native strings.
|
||||
It also provides a 'base64_decode' function with native strings as input and
|
||||
output.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
|
||||
if sys.version_info >= (3, 0):
|
||||
py3k = True
|
||||
bytestr = bytes
|
||||
unicodestr = str
|
||||
nativestr = unicodestr
|
||||
basestring = (bytes, str)
|
||||
|
||||
def ntob(n, encoding='ISO-8859-1'):
|
||||
"""Return the given native string as a byte string in the given
|
||||
encoding.
|
||||
"""
|
||||
assert_native(n)
|
||||
# In Python 3, the native string type is unicode
|
||||
return n.encode(encoding)
|
||||
|
||||
def ntou(n, encoding='ISO-8859-1'):
|
||||
"""Return the given native string as a unicode string with the given
|
||||
encoding.
|
||||
"""
|
||||
assert_native(n)
|
||||
# In Python 3, the native string type is unicode
|
||||
return n
|
||||
|
||||
def tonative(n, encoding='ISO-8859-1'):
|
||||
"""Return the given string as a native string in the given encoding."""
|
||||
# In Python 3, the native string type is unicode
|
||||
if isinstance(n, bytes):
|
||||
return n.decode(encoding)
|
||||
return n
|
||||
# type("")
|
||||
from io import StringIO
|
||||
# bytes:
|
||||
from io import BytesIO as BytesIO
|
||||
else:
|
||||
# Python 2
|
||||
py3k = False
|
||||
bytestr = str
|
||||
unicodestr = unicode
|
||||
nativestr = bytestr
|
||||
basestring = basestring
|
||||
|
||||
def ntob(n, encoding='ISO-8859-1'):
|
||||
"""Return the given native string as a byte string in the given
|
||||
encoding.
|
||||
"""
|
||||
assert_native(n)
|
||||
# In Python 2, the native string type is bytes. Assume it's already
|
||||
# in the given encoding, which for ISO-8859-1 is almost always what
|
||||
# was intended.
|
||||
return n
|
||||
|
||||
def ntou(n, encoding='ISO-8859-1'):
|
||||
"""Return the given native string as a unicode string with the given
|
||||
encoding.
|
||||
"""
|
||||
assert_native(n)
|
||||
# In Python 2, the native string type is bytes.
|
||||
# First, check for the special encoding 'escape'. The test suite uses
|
||||
# this to signal that it wants to pass a string with embedded \uXXXX
|
||||
# escapes, but without having to prefix it with u'' for Python 2,
|
||||
# but no prefix for Python 3.
|
||||
if encoding == 'escape':
|
||||
return unicode(
|
||||
re.sub(r'\\u([0-9a-zA-Z]{4})',
|
||||
lambda m: unichr(int(m.group(1), 16)),
|
||||
n.decode('ISO-8859-1')))
|
||||
# Assume it's already in the given encoding, which for ISO-8859-1
|
||||
# is almost always what was intended.
|
||||
return n.decode(encoding)
|
||||
|
||||
def tonative(n, encoding='ISO-8859-1'):
|
||||
"""Return the given string as a native string in the given encoding."""
|
||||
# In Python 2, the native string type is bytes.
|
||||
if isinstance(n, unicode):
|
||||
return n.encode(encoding)
|
||||
return n
|
||||
try:
|
||||
# type("")
|
||||
from cStringIO import StringIO
|
||||
except ImportError:
|
||||
# type("")
|
||||
from StringIO import StringIO
|
||||
# bytes:
|
||||
BytesIO = StringIO
|
||||
|
||||
|
||||
def assert_native(n):
|
||||
if not isinstance(n, nativestr):
|
||||
raise TypeError("n must be a native str (got %s)" % type(n).__name__)
|
||||
|
||||
try:
|
||||
set = set
|
||||
except NameError:
|
||||
from sets import Set as set
|
||||
|
||||
try:
|
||||
# Python 3.1+
|
||||
from base64 import decodebytes as _base64_decodebytes
|
||||
except ImportError:
|
||||
# Python 3.0-
|
||||
# since CherryPy claims compability with Python 2.3, we must use
|
||||
# the legacy API of base64
|
||||
from base64 import decodestring as _base64_decodebytes
|
||||
|
||||
|
||||
def base64_decode(n, encoding='ISO-8859-1'):
|
||||
"""Return the native string base64-decoded (as a native string)."""
|
||||
if isinstance(n, unicodestr):
|
||||
b = n.encode(encoding)
|
||||
else:
|
||||
b = n
|
||||
b = _base64_decodebytes(b)
|
||||
if nativestr is unicodestr:
|
||||
return b.decode(encoding)
|
||||
else:
|
||||
return b
|
||||
|
||||
try:
|
||||
# Python 2.5+
|
||||
from hashlib import md5
|
||||
except ImportError:
|
||||
from md5 import new as md5
|
||||
|
||||
try:
|
||||
# Python 2.5+
|
||||
from hashlib import sha1 as sha
|
||||
except ImportError:
|
||||
from sha import new as sha
|
||||
|
||||
try:
|
||||
sorted = sorted
|
||||
except NameError:
|
||||
def sorted(i):
|
||||
i = i[:]
|
||||
i.sort()
|
||||
return i
|
||||
|
||||
try:
|
||||
reversed = reversed
|
||||
except NameError:
|
||||
def reversed(x):
|
||||
i = len(x)
|
||||
while i > 0:
|
||||
i -= 1
|
||||
yield x[i]
|
||||
|
||||
try:
|
||||
# Python 3
|
||||
from urllib.parse import urljoin, urlencode
|
||||
from urllib.parse import quote, quote_plus
|
||||
from urllib.request import unquote, urlopen
|
||||
from urllib.request import parse_http_list, parse_keqv_list
|
||||
except ImportError:
|
||||
# Python 2
|
||||
from urlparse import urljoin
|
||||
from urllib import urlencode, urlopen
|
||||
from urllib import quote, quote_plus
|
||||
from urllib import unquote
|
||||
from urllib2 import parse_http_list, parse_keqv_list
|
||||
|
||||
try:
|
||||
from threading import local as threadlocal
|
||||
except ImportError:
|
||||
from cherrypy._cpthreadinglocal import local as threadlocal
|
||||
|
||||
try:
|
||||
dict.iteritems
|
||||
# Python 2
|
||||
iteritems = lambda d: d.iteritems()
|
||||
copyitems = lambda d: d.items()
|
||||
except AttributeError:
|
||||
# Python 3
|
||||
iteritems = lambda d: d.items()
|
||||
copyitems = lambda d: list(d.items())
|
||||
|
||||
try:
|
||||
dict.iterkeys
|
||||
# Python 2
|
||||
iterkeys = lambda d: d.iterkeys()
|
||||
copykeys = lambda d: d.keys()
|
||||
except AttributeError:
|
||||
# Python 3
|
||||
iterkeys = lambda d: d.keys()
|
||||
copykeys = lambda d: list(d.keys())
|
||||
|
||||
try:
|
||||
dict.itervalues
|
||||
# Python 2
|
||||
itervalues = lambda d: d.itervalues()
|
||||
copyvalues = lambda d: d.values()
|
||||
except AttributeError:
|
||||
# Python 3
|
||||
itervalues = lambda d: d.values()
|
||||
copyvalues = lambda d: list(d.values())
|
||||
|
||||
try:
|
||||
# Python 3
|
||||
import builtins
|
||||
except ImportError:
|
||||
# Python 2
|
||||
import __builtin__ as builtins
|
||||
|
||||
try:
|
||||
# Python 2. We try Python 2 first clients on Python 2
|
||||
# don't try to import the 'http' module from cherrypy.lib
|
||||
from Cookie import SimpleCookie, CookieError
|
||||
from httplib import BadStatusLine, HTTPConnection, IncompleteRead
|
||||
from httplib import NotConnected
|
||||
from BaseHTTPServer import BaseHTTPRequestHandler
|
||||
except ImportError:
|
||||
# Python 3
|
||||
from http.cookies import SimpleCookie, CookieError
|
||||
from http.client import BadStatusLine, HTTPConnection, IncompleteRead
|
||||
from http.client import NotConnected
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
|
||||
# Some platforms don't expose HTTPSConnection, so handle it separately
|
||||
if py3k:
|
||||
try:
|
||||
from http.client import HTTPSConnection
|
||||
except ImportError:
|
||||
# Some platforms which don't have SSL don't expose HTTPSConnection
|
||||
HTTPSConnection = None
|
||||
else:
|
||||
try:
|
||||
from httplib import HTTPSConnection
|
||||
except ImportError:
|
||||
HTTPSConnection = None
|
||||
|
||||
try:
|
||||
# Python 2
|
||||
xrange = xrange
|
||||
except NameError:
|
||||
# Python 3
|
||||
xrange = range
|
||||
|
||||
import threading
|
||||
if hasattr(threading.Thread, "daemon"):
|
||||
# Python 2.6+
|
||||
def get_daemon(t):
|
||||
return t.daemon
|
||||
|
||||
def set_daemon(t, val):
|
||||
t.daemon = val
|
||||
else:
|
||||
def get_daemon(t):
|
||||
return t.isDaemon()
|
||||
|
||||
def set_daemon(t, val):
|
||||
t.setDaemon(val)
|
||||
|
||||
try:
|
||||
from email.utils import formatdate
|
||||
|
||||
def HTTPDate(timeval=None):
|
||||
return formatdate(timeval, usegmt=True)
|
||||
except ImportError:
|
||||
from rfc822 import formatdate as HTTPDate
|
||||
|
||||
try:
|
||||
# Python 3
|
||||
from urllib.parse import unquote as parse_unquote
|
||||
|
||||
def unquote_qs(atom, encoding, errors='strict'):
|
||||
return parse_unquote(
|
||||
atom.replace('+', ' '),
|
||||
encoding=encoding,
|
||||
errors=errors)
|
||||
except ImportError:
|
||||
# Python 2
|
||||
from urllib import unquote as parse_unquote
|
||||
|
||||
def unquote_qs(atom, encoding, errors='strict'):
|
||||
return parse_unquote(atom.replace('+', ' ')).decode(encoding, errors)
|
||||
|
||||
try:
|
||||
# Prefer simplejson, which is usually more advanced than the builtin
|
||||
# module.
|
||||
import simplejson as json
|
||||
json_decode = json.JSONDecoder().decode
|
||||
_json_encode = json.JSONEncoder().iterencode
|
||||
except ImportError:
|
||||
if sys.version_info >= (2, 6):
|
||||
# Python >=2.6 : json is part of the standard library
|
||||
import json
|
||||
json_decode = json.JSONDecoder().decode
|
||||
_json_encode = json.JSONEncoder().iterencode
|
||||
else:
|
||||
json = None
|
||||
|
||||
def json_decode(s):
|
||||
raise ValueError('No JSON library is available')
|
||||
|
||||
def _json_encode(s):
|
||||
raise ValueError('No JSON library is available')
|
||||
finally:
|
||||
if json and py3k:
|
||||
# The two Python 3 implementations (simplejson/json)
|
||||
# outputs str. We need bytes.
|
||||
def json_encode(value):
|
||||
for chunk in _json_encode(value):
|
||||
yield chunk.encode('utf8')
|
||||
else:
|
||||
json_encode = _json_encode
|
||||
|
||||
|
||||
try:
|
||||
import cPickle as pickle
|
||||
except ImportError:
|
||||
# In Python 2, pickle is a Python version.
|
||||
# In Python 3, pickle is the sped-up C version.
|
||||
import pickle
|
||||
|
||||
try:
|
||||
os.urandom(20)
|
||||
import binascii
|
||||
|
||||
def random20():
|
||||
return binascii.hexlify(os.urandom(20)).decode('ascii')
|
||||
except (AttributeError, NotImplementedError):
|
||||
import random
|
||||
# os.urandom not available until Python 2.4. Fall back to random.random.
|
||||
|
||||
def random20():
|
||||
return sha('%s' % random.random()).hexdigest()
|
||||
|
||||
try:
|
||||
from _thread import get_ident as get_thread_ident
|
||||
except ImportError:
|
||||
from thread import get_ident as get_thread_ident
|
||||
|
||||
try:
|
||||
# Python 3
|
||||
next = next
|
||||
except NameError:
|
||||
# Python 2
|
||||
def next(i):
|
||||
return i.next()
|
||||
|
||||
if sys.version_info >= (3, 3):
|
||||
Timer = threading.Timer
|
||||
Event = threading.Event
|
||||
else:
|
||||
# Python 3.2 and earlier
|
||||
Timer = threading._Timer
|
||||
Event = threading._Event
|
||||
|
||||
# Prior to Python 2.6, the Thread class did not have a .daemon property.
|
||||
# This mix-in adds that property.
|
||||
|
||||
|
||||
class SetDaemonProperty:
|
||||
|
||||
def __get_daemon(self):
|
||||
return self.isDaemon()
|
||||
|
||||
def __set_daemon(self, daemon):
|
||||
self.setDaemon(daemon)
|
||||
|
||||
if sys.version_info < (2, 6):
|
||||
daemon = property(__get_daemon, __set_daemon)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Configuration system for CherryPy.
|
||||
|
||||
Configuration in CherryPy is implemented via dictionaries. Keys are strings
|
||||
which name the mapped value, which may be of any type.
|
||||
|
||||
|
||||
Architecture
|
||||
------------
|
||||
|
||||
CherryPy Requests are part of an Application, which runs in a global context,
|
||||
and configuration data may apply to any of those three scopes:
|
||||
|
||||
Global
|
||||
Configuration entries which apply everywhere are stored in
|
||||
cherrypy.config.
|
||||
|
||||
Application
|
||||
Entries which apply to each mounted application are stored
|
||||
on the Application object itself, as 'app.config'. This is a two-level
|
||||
dict where each key is a path, or "relative URL" (for example, "/" or
|
||||
"/path/to/my/page"), and each value is a config dict. Usually, this
|
||||
data is provided in the call to tree.mount(root(), config=conf),
|
||||
although you may also use app.merge(conf).
|
||||
|
||||
Request
|
||||
Each Request object possesses a single 'Request.config' dict.
|
||||
Early in the request process, this dict is populated by merging global
|
||||
config entries, Application entries (whose path equals or is a parent
|
||||
of Request.path_info), and any config acquired while looking up the
|
||||
page handler (see next).
|
||||
|
||||
|
||||
Declaration
|
||||
-----------
|
||||
|
||||
Configuration data may be supplied as a Python dictionary, as a filename,
|
||||
or as an open file object. When you supply a filename or file, CherryPy
|
||||
uses Python's builtin ConfigParser; you declare Application config by
|
||||
writing each path as a section header::
|
||||
|
||||
[/path/to/my/page]
|
||||
request.stream = True
|
||||
|
||||
To declare global configuration entries, place them in a [global] section.
|
||||
|
||||
You may also declare config entries directly on the classes and methods
|
||||
(page handlers) that make up your CherryPy application via the ``_cp_config``
|
||||
attribute. For example::
|
||||
|
||||
class Demo:
|
||||
_cp_config = {'tools.gzip.on': True}
|
||||
|
||||
def index(self):
|
||||
return "Hello world"
|
||||
index.exposed = True
|
||||
index._cp_config = {'request.show_tracebacks': False}
|
||||
|
||||
.. note::
|
||||
|
||||
This behavior is only guaranteed for the default dispatcher.
|
||||
Other dispatchers may have different restrictions on where
|
||||
you can attach _cp_config attributes.
|
||||
|
||||
|
||||
Namespaces
|
||||
----------
|
||||
|
||||
Configuration keys are separated into namespaces by the first "." in the key.
|
||||
Current namespaces:
|
||||
|
||||
engine
|
||||
Controls the 'application engine', including autoreload.
|
||||
These can only be declared in the global config.
|
||||
|
||||
tree
|
||||
Grafts cherrypy.Application objects onto cherrypy.tree.
|
||||
These can only be declared in the global config.
|
||||
|
||||
hooks
|
||||
Declares additional request-processing functions.
|
||||
|
||||
log
|
||||
Configures the logging for each application.
|
||||
These can only be declared in the global or / config.
|
||||
|
||||
request
|
||||
Adds attributes to each Request.
|
||||
|
||||
response
|
||||
Adds attributes to each Response.
|
||||
|
||||
server
|
||||
Controls the default HTTP server via cherrypy.server.
|
||||
These can only be declared in the global config.
|
||||
|
||||
tools
|
||||
Runs and configures additional request-processing packages.
|
||||
|
||||
wsgi
|
||||
Adds WSGI middleware to an Application's "pipeline".
|
||||
These can only be declared in the app's root config ("/").
|
||||
|
||||
checker
|
||||
Controls the 'checker', which looks for common errors in
|
||||
app state (including config) when the engine starts.
|
||||
Global config only.
|
||||
|
||||
The only key that does not exist in a namespace is the "environment" entry.
|
||||
This special entry 'imports' other config entries from a template stored in
|
||||
cherrypy._cpconfig.environments[environment]. It only applies to the global
|
||||
config, and only when you use cherrypy.config.update.
|
||||
|
||||
You can define your own namespaces to be called at the Global, Application,
|
||||
or Request level, by adding a named handler to cherrypy.config.namespaces,
|
||||
app.namespaces, or app.request_class.namespaces. The name can
|
||||
be any string, and the handler must be either a callable or a (Python 2.5
|
||||
style) context manager.
|
||||
"""
|
||||
|
||||
import cherrypy
|
||||
from cherrypy._cpcompat import set, basestring
|
||||
from cherrypy.lib import reprconf
|
||||
|
||||
# Deprecated in CherryPy 3.2--remove in 3.3
|
||||
NamespaceSet = reprconf.NamespaceSet
|
||||
|
||||
|
||||
def merge(base, other):
|
||||
"""Merge one app config (from a dict, file, or filename) into another.
|
||||
|
||||
If the given config is a filename, it will be appended to
|
||||
the list of files to monitor for "autoreload" changes.
|
||||
"""
|
||||
if isinstance(other, basestring):
|
||||
cherrypy.engine.autoreload.files.add(other)
|
||||
|
||||
# Load other into base
|
||||
for section, value_map in reprconf.as_dict(other).items():
|
||||
if not isinstance(value_map, dict):
|
||||
raise ValueError(
|
||||
"Application config must include section headers, but the "
|
||||
"config you tried to merge doesn't have any sections. "
|
||||
"Wrap your config in another dict with paths as section "
|
||||
"headers, for example: {'/': config}.")
|
||||
base.setdefault(section, {}).update(value_map)
|
||||
|
||||
|
||||
class Config(reprconf.Config):
|
||||
|
||||
"""The 'global' configuration data for the entire CherryPy process."""
|
||||
|
||||
def update(self, config):
|
||||
"""Update self from a dict, file or filename."""
|
||||
if isinstance(config, basestring):
|
||||
# Filename
|
||||
cherrypy.engine.autoreload.files.add(config)
|
||||
reprconf.Config.update(self, config)
|
||||
|
||||
def _apply(self, config):
|
||||
"""Update self from a dict."""
|
||||
if isinstance(config.get("global"), dict):
|
||||
if len(config) > 1:
|
||||
cherrypy.checker.global_config_contained_paths = True
|
||||
config = config["global"]
|
||||
if 'tools.staticdir.dir' in config:
|
||||
config['tools.staticdir.section'] = "global"
|
||||
reprconf.Config._apply(self, config)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Decorator for page handlers to set _cp_config."""
|
||||
if args:
|
||||
raise TypeError(
|
||||
"The cherrypy.config decorator does not accept positional "
|
||||
"arguments; you must use keyword arguments.")
|
||||
|
||||
def tool_decorator(f):
|
||||
if not hasattr(f, "_cp_config"):
|
||||
f._cp_config = {}
|
||||
for k, v in kwargs.items():
|
||||
f._cp_config[k] = v
|
||||
return f
|
||||
return tool_decorator
|
||||
|
||||
|
||||
# Sphinx begin config.environments
|
||||
Config.environments = environments = {
|
||||
"staging": {
|
||||
'engine.autoreload.on': False,
|
||||
'checker.on': False,
|
||||
'tools.log_headers.on': False,
|
||||
'request.show_tracebacks': False,
|
||||
'request.show_mismatched_params': False,
|
||||
},
|
||||
"production": {
|
||||
'engine.autoreload.on': False,
|
||||
'checker.on': False,
|
||||
'tools.log_headers.on': False,
|
||||
'request.show_tracebacks': False,
|
||||
'request.show_mismatched_params': False,
|
||||
'log.screen': False,
|
||||
},
|
||||
"embedded": {
|
||||
# For use with CherryPy embedded in another deployment stack.
|
||||
'engine.autoreload.on': False,
|
||||
'checker.on': False,
|
||||
'tools.log_headers.on': False,
|
||||
'request.show_tracebacks': False,
|
||||
'request.show_mismatched_params': False,
|
||||
'log.screen': False,
|
||||
'engine.SIGHUP': None,
|
||||
'engine.SIGTERM': None,
|
||||
},
|
||||
"test_suite": {
|
||||
'engine.autoreload.on': False,
|
||||
'checker.on': False,
|
||||
'tools.log_headers.on': False,
|
||||
'request.show_tracebacks': True,
|
||||
'request.show_mismatched_params': True,
|
||||
'log.screen': False,
|
||||
},
|
||||
}
|
||||
# Sphinx end config.environments
|
||||
|
||||
|
||||
def _server_namespace_handler(k, v):
|
||||
"""Config handler for the "server" namespace."""
|
||||
atoms = k.split(".", 1)
|
||||
if len(atoms) > 1:
|
||||
# Special-case config keys of the form 'server.servername.socket_port'
|
||||
# to configure additional HTTP servers.
|
||||
if not hasattr(cherrypy, "servers"):
|
||||
cherrypy.servers = {}
|
||||
|
||||
servername, k = atoms
|
||||
if servername not in cherrypy.servers:
|
||||
from cherrypy import _cpserver
|
||||
cherrypy.servers[servername] = _cpserver.Server()
|
||||
# On by default, but 'on = False' can unsubscribe it (see below).
|
||||
cherrypy.servers[servername].subscribe()
|
||||
|
||||
if k == 'on':
|
||||
if v:
|
||||
cherrypy.servers[servername].subscribe()
|
||||
else:
|
||||
cherrypy.servers[servername].unsubscribe()
|
||||
else:
|
||||
setattr(cherrypy.servers[servername], k, v)
|
||||
else:
|
||||
setattr(cherrypy.server, k, v)
|
||||
Config.namespaces["server"] = _server_namespace_handler
|
||||
|
||||
|
||||
def _engine_namespace_handler(k, v):
|
||||
"""Backward compatibility handler for the "engine" namespace."""
|
||||
engine = cherrypy.engine
|
||||
|
||||
deprecated = {
|
||||
'autoreload_on': 'autoreload.on',
|
||||
'autoreload_frequency': 'autoreload.frequency',
|
||||
'autoreload_match': 'autoreload.match',
|
||||
'reload_files': 'autoreload.files',
|
||||
'deadlock_poll_freq': 'timeout_monitor.frequency'
|
||||
}
|
||||
|
||||
if k in deprecated:
|
||||
engine.log(
|
||||
'WARNING: Use of engine.%s is deprecated and will be removed in a '
|
||||
'future version. Use engine.%s instead.' % (k, deprecated[k]))
|
||||
|
||||
if k == 'autoreload_on':
|
||||
if v:
|
||||
engine.autoreload.subscribe()
|
||||
else:
|
||||
engine.autoreload.unsubscribe()
|
||||
elif k == 'autoreload_frequency':
|
||||
engine.autoreload.frequency = v
|
||||
elif k == 'autoreload_match':
|
||||
engine.autoreload.match = v
|
||||
elif k == 'reload_files':
|
||||
engine.autoreload.files = set(v)
|
||||
elif k == 'deadlock_poll_freq':
|
||||
engine.timeout_monitor.frequency = v
|
||||
elif k == 'SIGHUP':
|
||||
engine.listeners['SIGHUP'] = set([v])
|
||||
elif k == 'SIGTERM':
|
||||
engine.listeners['SIGTERM'] = set([v])
|
||||
elif "." in k:
|
||||
plugin, attrname = k.split(".", 1)
|
||||
plugin = getattr(engine, plugin)
|
||||
if attrname == 'on':
|
||||
if v and hasattr(getattr(plugin, 'subscribe', None), '__call__'):
|
||||
plugin.subscribe()
|
||||
return
|
||||
elif (
|
||||
(not v) and
|
||||
hasattr(getattr(plugin, 'unsubscribe', None), '__call__')
|
||||
):
|
||||
plugin.unsubscribe()
|
||||
return
|
||||
setattr(plugin, attrname, v)
|
||||
else:
|
||||
setattr(engine, k, v)
|
||||
Config.namespaces["engine"] = _engine_namespace_handler
|
||||
|
||||
|
||||
def _tree_namespace_handler(k, v):
|
||||
"""Namespace handler for the 'tree' config namespace."""
|
||||
if isinstance(v, dict):
|
||||
for script_name, app in v.items():
|
||||
cherrypy.tree.graft(app, script_name)
|
||||
cherrypy.engine.log("Mounted: %s on %s" %
|
||||
(app, script_name or "/"))
|
||||
else:
|
||||
cherrypy.tree.graft(v, v.script_name)
|
||||
cherrypy.engine.log("Mounted: %s on %s" % (v, v.script_name or "/"))
|
||||
Config.namespaces["tree"] = _tree_namespace_handler
|
||||
@@ -0,0 +1,686 @@
|
||||
"""CherryPy dispatchers.
|
||||
|
||||
A 'dispatcher' is the object which looks up the 'page handler' callable
|
||||
and collects config for the current request based on the path_info, other
|
||||
request attributes, and the application architecture. The core calls the
|
||||
dispatcher as early as possible, passing it a 'path_info' argument.
|
||||
|
||||
The default dispatcher discovers the page handler by matching path_info
|
||||
to a hierarchical arrangement of objects, starting at request.app.root.
|
||||
"""
|
||||
|
||||
import string
|
||||
import sys
|
||||
import types
|
||||
try:
|
||||
classtype = (type, types.ClassType)
|
||||
except AttributeError:
|
||||
classtype = type
|
||||
|
||||
import cherrypy
|
||||
from cherrypy._cpcompat import set
|
||||
|
||||
|
||||
class PageHandler(object):
|
||||
|
||||
"""Callable which sets response.body."""
|
||||
|
||||
def __init__(self, callable, *args, **kwargs):
|
||||
self.callable = callable
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def get_args(self):
|
||||
return cherrypy.serving.request.args
|
||||
|
||||
def set_args(self, args):
|
||||
cherrypy.serving.request.args = args
|
||||
return cherrypy.serving.request.args
|
||||
|
||||
args = property(
|
||||
get_args,
|
||||
set_args,
|
||||
doc="The ordered args should be accessible from post dispatch hooks"
|
||||
)
|
||||
|
||||
def get_kwargs(self):
|
||||
return cherrypy.serving.request.kwargs
|
||||
|
||||
def set_kwargs(self, kwargs):
|
||||
cherrypy.serving.request.kwargs = kwargs
|
||||
return cherrypy.serving.request.kwargs
|
||||
|
||||
kwargs = property(
|
||||
get_kwargs,
|
||||
set_kwargs,
|
||||
doc="The named kwargs should be accessible from post dispatch hooks"
|
||||
)
|
||||
|
||||
def __call__(self):
|
||||
try:
|
||||
return self.callable(*self.args, **self.kwargs)
|
||||
except TypeError:
|
||||
x = sys.exc_info()[1]
|
||||
try:
|
||||
test_callable_spec(self.callable, self.args, self.kwargs)
|
||||
except cherrypy.HTTPError:
|
||||
raise sys.exc_info()[1]
|
||||
except:
|
||||
raise x
|
||||
raise
|
||||
|
||||
|
||||
def test_callable_spec(callable, callable_args, callable_kwargs):
|
||||
"""
|
||||
Inspect callable and test to see if the given args are suitable for it.
|
||||
|
||||
When an error occurs during the handler's invoking stage there are 2
|
||||
erroneous cases:
|
||||
1. Too many parameters passed to a function which doesn't define
|
||||
one of *args or **kwargs.
|
||||
2. Too little parameters are passed to the function.
|
||||
|
||||
There are 3 sources of parameters to a cherrypy handler.
|
||||
1. query string parameters are passed as keyword parameters to the
|
||||
handler.
|
||||
2. body parameters are also passed as keyword parameters.
|
||||
3. when partial matching occurs, the final path atoms are passed as
|
||||
positional args.
|
||||
Both the query string and path atoms are part of the URI. If they are
|
||||
incorrect, then a 404 Not Found should be raised. Conversely the body
|
||||
parameters are part of the request; if they are invalid a 400 Bad Request.
|
||||
"""
|
||||
show_mismatched_params = getattr(
|
||||
cherrypy.serving.request, 'show_mismatched_params', False)
|
||||
try:
|
||||
(args, varargs, varkw, defaults) = getargspec(callable)
|
||||
except TypeError:
|
||||
if isinstance(callable, object) and hasattr(callable, '__call__'):
|
||||
(args, varargs, varkw,
|
||||
defaults) = getargspec(callable.__call__)
|
||||
else:
|
||||
# If it wasn't one of our own types, re-raise
|
||||
# the original error
|
||||
raise
|
||||
|
||||
if args and args[0] == 'self':
|
||||
args = args[1:]
|
||||
|
||||
arg_usage = dict([(arg, 0,) for arg in args])
|
||||
vararg_usage = 0
|
||||
varkw_usage = 0
|
||||
extra_kwargs = set()
|
||||
|
||||
for i, value in enumerate(callable_args):
|
||||
try:
|
||||
arg_usage[args[i]] += 1
|
||||
except IndexError:
|
||||
vararg_usage += 1
|
||||
|
||||
for key in callable_kwargs.keys():
|
||||
try:
|
||||
arg_usage[key] += 1
|
||||
except KeyError:
|
||||
varkw_usage += 1
|
||||
extra_kwargs.add(key)
|
||||
|
||||
# figure out which args have defaults.
|
||||
args_with_defaults = args[-len(defaults or []):]
|
||||
for i, val in enumerate(defaults or []):
|
||||
# Defaults take effect only when the arg hasn't been used yet.
|
||||
if arg_usage[args_with_defaults[i]] == 0:
|
||||
arg_usage[args_with_defaults[i]] += 1
|
||||
|
||||
missing_args = []
|
||||
multiple_args = []
|
||||
for key, usage in arg_usage.items():
|
||||
if usage == 0:
|
||||
missing_args.append(key)
|
||||
elif usage > 1:
|
||||
multiple_args.append(key)
|
||||
|
||||
if missing_args:
|
||||
# In the case where the method allows body arguments
|
||||
# there are 3 potential errors:
|
||||
# 1. not enough query string parameters -> 404
|
||||
# 2. not enough body parameters -> 400
|
||||
# 3. not enough path parts (partial matches) -> 404
|
||||
#
|
||||
# We can't actually tell which case it is,
|
||||
# so I'm raising a 404 because that covers 2/3 of the
|
||||
# possibilities
|
||||
#
|
||||
# In the case where the method does not allow body
|
||||
# arguments it's definitely a 404.
|
||||
message = None
|
||||
if show_mismatched_params:
|
||||
message = "Missing parameters: %s" % ",".join(missing_args)
|
||||
raise cherrypy.HTTPError(404, message=message)
|
||||
|
||||
# the extra positional arguments come from the path - 404 Not Found
|
||||
if not varargs and vararg_usage > 0:
|
||||
raise cherrypy.HTTPError(404)
|
||||
|
||||
body_params = cherrypy.serving.request.body.params or {}
|
||||
body_params = set(body_params.keys())
|
||||
qs_params = set(callable_kwargs.keys()) - body_params
|
||||
|
||||
if multiple_args:
|
||||
if qs_params.intersection(set(multiple_args)):
|
||||
# If any of the multiple parameters came from the query string then
|
||||
# it's a 404 Not Found
|
||||
error = 404
|
||||
else:
|
||||
# Otherwise it's a 400 Bad Request
|
||||
error = 400
|
||||
|
||||
message = None
|
||||
if show_mismatched_params:
|
||||
message = "Multiple values for parameters: "\
|
||||
"%s" % ",".join(multiple_args)
|
||||
raise cherrypy.HTTPError(error, message=message)
|
||||
|
||||
if not varkw and varkw_usage > 0:
|
||||
|
||||
# If there were extra query string parameters, it's a 404 Not Found
|
||||
extra_qs_params = set(qs_params).intersection(extra_kwargs)
|
||||
if extra_qs_params:
|
||||
message = None
|
||||
if show_mismatched_params:
|
||||
message = "Unexpected query string "\
|
||||
"parameters: %s" % ", ".join(extra_qs_params)
|
||||
raise cherrypy.HTTPError(404, message=message)
|
||||
|
||||
# If there were any extra body parameters, it's a 400 Not Found
|
||||
extra_body_params = set(body_params).intersection(extra_kwargs)
|
||||
if extra_body_params:
|
||||
message = None
|
||||
if show_mismatched_params:
|
||||
message = "Unexpected body parameters: "\
|
||||
"%s" % ", ".join(extra_body_params)
|
||||
raise cherrypy.HTTPError(400, message=message)
|
||||
|
||||
|
||||
try:
|
||||
import inspect
|
||||
except ImportError:
|
||||
test_callable_spec = lambda callable, args, kwargs: None
|
||||
else:
|
||||
getargspec = inspect.getargspec
|
||||
# Python 3 requires using getfullargspec if keyword-only arguments are present
|
||||
if hasattr(inspect, 'getfullargspec'):
|
||||
def getargspec(callable):
|
||||
return inspect.getfullargspec(callable)[:4]
|
||||
|
||||
|
||||
class LateParamPageHandler(PageHandler):
|
||||
|
||||
"""When passing cherrypy.request.params to the page handler, we do not
|
||||
want to capture that dict too early; we want to give tools like the
|
||||
decoding tool a chance to modify the params dict in-between the lookup
|
||||
of the handler and the actual calling of the handler. This subclass
|
||||
takes that into account, and allows request.params to be 'bound late'
|
||||
(it's more complicated than that, but that's the effect).
|
||||
"""
|
||||
|
||||
def _get_kwargs(self):
|
||||
kwargs = cherrypy.serving.request.params.copy()
|
||||
if self._kwargs:
|
||||
kwargs.update(self._kwargs)
|
||||
return kwargs
|
||||
|
||||
def _set_kwargs(self, kwargs):
|
||||
cherrypy.serving.request.kwargs = kwargs
|
||||
self._kwargs = kwargs
|
||||
|
||||
kwargs = property(_get_kwargs, _set_kwargs,
|
||||
doc='page handler kwargs (with '
|
||||
'cherrypy.request.params copied in)')
|
||||
|
||||
|
||||
if sys.version_info < (3, 0):
|
||||
punctuation_to_underscores = string.maketrans(
|
||||
string.punctuation, '_' * len(string.punctuation))
|
||||
|
||||
def validate_translator(t):
|
||||
if not isinstance(t, str) or len(t) != 256:
|
||||
raise ValueError(
|
||||
"The translate argument must be a str of len 256.")
|
||||
else:
|
||||
punctuation_to_underscores = str.maketrans(
|
||||
string.punctuation, '_' * len(string.punctuation))
|
||||
|
||||
def validate_translator(t):
|
||||
if not isinstance(t, dict):
|
||||
raise ValueError("The translate argument must be a dict.")
|
||||
|
||||
|
||||
class Dispatcher(object):
|
||||
|
||||
"""CherryPy Dispatcher which walks a tree of objects to find a handler.
|
||||
|
||||
The tree is rooted at cherrypy.request.app.root, and each hierarchical
|
||||
component in the path_info argument is matched to a corresponding nested
|
||||
attribute of the root object. Matching handlers must have an 'exposed'
|
||||
attribute which evaluates to True. The special method name "index"
|
||||
matches a URI which ends in a slash ("/"). The special method name
|
||||
"default" may match a portion of the path_info (but only when no longer
|
||||
substring of the path_info matches some other object).
|
||||
|
||||
This is the default, built-in dispatcher for CherryPy.
|
||||
"""
|
||||
|
||||
dispatch_method_name = '_cp_dispatch'
|
||||
"""
|
||||
The name of the dispatch method that nodes may optionally implement
|
||||
to provide their own dynamic dispatch algorithm.
|
||||
"""
|
||||
|
||||
def __init__(self, dispatch_method_name=None,
|
||||
translate=punctuation_to_underscores):
|
||||
validate_translator(translate)
|
||||
self.translate = translate
|
||||
if dispatch_method_name:
|
||||
self.dispatch_method_name = dispatch_method_name
|
||||
|
||||
def __call__(self, path_info):
|
||||
"""Set handler and config for the current request."""
|
||||
request = cherrypy.serving.request
|
||||
func, vpath = self.find_handler(path_info)
|
||||
|
||||
if func:
|
||||
# Decode any leftover %2F in the virtual_path atoms.
|
||||
vpath = [x.replace("%2F", "/") for x in vpath]
|
||||
request.handler = LateParamPageHandler(func, *vpath)
|
||||
else:
|
||||
request.handler = cherrypy.NotFound()
|
||||
|
||||
def find_handler(self, path):
|
||||
"""Return the appropriate page handler, plus any virtual path.
|
||||
|
||||
This will return two objects. The first will be a callable,
|
||||
which can be used to generate page output. Any parameters from
|
||||
the query string or request body will be sent to that callable
|
||||
as keyword arguments.
|
||||
|
||||
The callable is found by traversing the application's tree,
|
||||
starting from cherrypy.request.app.root, and matching path
|
||||
components to successive objects in the tree. For example, the
|
||||
URL "/path/to/handler" might return root.path.to.handler.
|
||||
|
||||
The second object returned will be a list of names which are
|
||||
'virtual path' components: parts of the URL which are dynamic,
|
||||
and were not used when looking up the handler.
|
||||
These virtual path components are passed to the handler as
|
||||
positional arguments.
|
||||
"""
|
||||
request = cherrypy.serving.request
|
||||
app = request.app
|
||||
root = app.root
|
||||
dispatch_name = self.dispatch_method_name
|
||||
|
||||
# Get config for the root object/path.
|
||||
fullpath = [x for x in path.strip('/').split('/') if x] + ['index']
|
||||
fullpath_len = len(fullpath)
|
||||
segleft = fullpath_len
|
||||
nodeconf = {}
|
||||
if hasattr(root, "_cp_config"):
|
||||
nodeconf.update(root._cp_config)
|
||||
if "/" in app.config:
|
||||
nodeconf.update(app.config["/"])
|
||||
object_trail = [['root', root, nodeconf, segleft]]
|
||||
|
||||
node = root
|
||||
iternames = fullpath[:]
|
||||
while iternames:
|
||||
name = iternames[0]
|
||||
# map to legal Python identifiers (e.g. replace '.' with '_')
|
||||
objname = name.translate(self.translate)
|
||||
|
||||
nodeconf = {}
|
||||
subnode = getattr(node, objname, None)
|
||||
pre_len = len(iternames)
|
||||
if subnode is None:
|
||||
dispatch = getattr(node, dispatch_name, None)
|
||||
if dispatch and hasattr(dispatch, '__call__') and not \
|
||||
getattr(dispatch, 'exposed', False) and \
|
||||
pre_len > 1:
|
||||
# Don't expose the hidden 'index' token to _cp_dispatch
|
||||
# We skip this if pre_len == 1 since it makes no sense
|
||||
# to call a dispatcher when we have no tokens left.
|
||||
index_name = iternames.pop()
|
||||
subnode = dispatch(vpath=iternames)
|
||||
iternames.append(index_name)
|
||||
else:
|
||||
# We didn't find a path, but keep processing in case there
|
||||
# is a default() handler.
|
||||
iternames.pop(0)
|
||||
else:
|
||||
# We found the path, remove the vpath entry
|
||||
iternames.pop(0)
|
||||
segleft = len(iternames)
|
||||
if segleft > pre_len:
|
||||
# No path segment was removed. Raise an error.
|
||||
raise cherrypy.CherryPyException(
|
||||
"A vpath segment was added. Custom dispatchers may only "
|
||||
+ "remove elements. While trying to process "
|
||||
+ "{0} in {1}".format(name, fullpath)
|
||||
)
|
||||
elif segleft == pre_len:
|
||||
# Assume that the handler used the current path segment, but
|
||||
# did not pop it. This allows things like
|
||||
# return getattr(self, vpath[0], None)
|
||||
iternames.pop(0)
|
||||
segleft -= 1
|
||||
node = subnode
|
||||
|
||||
if node is not None:
|
||||
# Get _cp_config attached to this node.
|
||||
if hasattr(node, "_cp_config"):
|
||||
nodeconf.update(node._cp_config)
|
||||
|
||||
# Mix in values from app.config for this path.
|
||||
existing_len = fullpath_len - pre_len
|
||||
if existing_len != 0:
|
||||
curpath = '/' + '/'.join(fullpath[0:existing_len])
|
||||
else:
|
||||
curpath = ''
|
||||
new_segs = fullpath[fullpath_len - pre_len:fullpath_len - segleft]
|
||||
for seg in new_segs:
|
||||
curpath += '/' + seg
|
||||
if curpath in app.config:
|
||||
nodeconf.update(app.config[curpath])
|
||||
|
||||
object_trail.append([name, node, nodeconf, segleft])
|
||||
|
||||
def set_conf():
|
||||
"""Collapse all object_trail config into cherrypy.request.config.
|
||||
"""
|
||||
base = cherrypy.config.copy()
|
||||
# Note that we merge the config from each node
|
||||
# even if that node was None.
|
||||
for name, obj, conf, segleft in object_trail:
|
||||
base.update(conf)
|
||||
if 'tools.staticdir.dir' in conf:
|
||||
base['tools.staticdir.section'] = '/' + \
|
||||
'/'.join(fullpath[0:fullpath_len - segleft])
|
||||
return base
|
||||
|
||||
# Try successive objects (reverse order)
|
||||
num_candidates = len(object_trail) - 1
|
||||
for i in range(num_candidates, -1, -1):
|
||||
|
||||
name, candidate, nodeconf, segleft = object_trail[i]
|
||||
if candidate is None:
|
||||
continue
|
||||
|
||||
# Try a "default" method on the current leaf.
|
||||
if hasattr(candidate, "default"):
|
||||
defhandler = candidate.default
|
||||
if getattr(defhandler, 'exposed', False):
|
||||
# Insert any extra _cp_config from the default handler.
|
||||
conf = getattr(defhandler, "_cp_config", {})
|
||||
object_trail.insert(
|
||||
i + 1, ["default", defhandler, conf, segleft])
|
||||
request.config = set_conf()
|
||||
# See https://bitbucket.org/cherrypy/cherrypy/issue/613
|
||||
request.is_index = path.endswith("/")
|
||||
return defhandler, fullpath[fullpath_len - segleft:-1]
|
||||
|
||||
# Uncomment the next line to restrict positional params to
|
||||
# "default".
|
||||
# if i < num_candidates - 2: continue
|
||||
|
||||
# Try the current leaf.
|
||||
if getattr(candidate, 'exposed', False):
|
||||
request.config = set_conf()
|
||||
if i == num_candidates:
|
||||
# We found the extra ".index". Mark request so tools
|
||||
# can redirect if path_info has no trailing slash.
|
||||
request.is_index = True
|
||||
else:
|
||||
# We're not at an 'index' handler. Mark request so tools
|
||||
# can redirect if path_info has NO trailing slash.
|
||||
# Note that this also includes handlers which take
|
||||
# positional parameters (virtual paths).
|
||||
request.is_index = False
|
||||
return candidate, fullpath[fullpath_len - segleft:-1]
|
||||
|
||||
# We didn't find anything
|
||||
request.config = set_conf()
|
||||
return None, []
|
||||
|
||||
|
||||
class MethodDispatcher(Dispatcher):
|
||||
|
||||
"""Additional dispatch based on cherrypy.request.method.upper().
|
||||
|
||||
Methods named GET, POST, etc will be called on an exposed class.
|
||||
The method names must be all caps; the appropriate Allow header
|
||||
will be output showing all capitalized method names as allowable
|
||||
HTTP verbs.
|
||||
|
||||
Note that the containing class must be exposed, not the methods.
|
||||
"""
|
||||
|
||||
def __call__(self, path_info):
|
||||
"""Set handler and config for the current request."""
|
||||
request = cherrypy.serving.request
|
||||
resource, vpath = self.find_handler(path_info)
|
||||
|
||||
if resource:
|
||||
# Set Allow header
|
||||
avail = [m for m in dir(resource) if m.isupper()]
|
||||
if "GET" in avail and "HEAD" not in avail:
|
||||
avail.append("HEAD")
|
||||
avail.sort()
|
||||
cherrypy.serving.response.headers['Allow'] = ", ".join(avail)
|
||||
|
||||
# Find the subhandler
|
||||
meth = request.method.upper()
|
||||
func = getattr(resource, meth, None)
|
||||
if func is None and meth == "HEAD":
|
||||
func = getattr(resource, "GET", None)
|
||||
if func:
|
||||
# Grab any _cp_config on the subhandler.
|
||||
if hasattr(func, "_cp_config"):
|
||||
request.config.update(func._cp_config)
|
||||
|
||||
# Decode any leftover %2F in the virtual_path atoms.
|
||||
vpath = [x.replace("%2F", "/") for x in vpath]
|
||||
request.handler = LateParamPageHandler(func, *vpath)
|
||||
else:
|
||||
request.handler = cherrypy.HTTPError(405)
|
||||
else:
|
||||
request.handler = cherrypy.NotFound()
|
||||
|
||||
|
||||
class RoutesDispatcher(object):
|
||||
|
||||
"""A Routes based dispatcher for CherryPy."""
|
||||
|
||||
def __init__(self, full_result=False, **mapper_options):
|
||||
"""
|
||||
Routes dispatcher
|
||||
|
||||
Set full_result to True if you wish the controller
|
||||
and the action to be passed on to the page handler
|
||||
parameters. By default they won't be.
|
||||
"""
|
||||
import routes
|
||||
self.full_result = full_result
|
||||
self.controllers = {}
|
||||
self.mapper = routes.Mapper(**mapper_options)
|
||||
self.mapper.controller_scan = self.controllers.keys
|
||||
|
||||
def connect(self, name, route, controller, **kwargs):
|
||||
self.controllers[name] = controller
|
||||
self.mapper.connect(name, route, controller=name, **kwargs)
|
||||
|
||||
def redirect(self, url):
|
||||
raise cherrypy.HTTPRedirect(url)
|
||||
|
||||
def __call__(self, path_info):
|
||||
"""Set handler and config for the current request."""
|
||||
func = self.find_handler(path_info)
|
||||
if func:
|
||||
cherrypy.serving.request.handler = LateParamPageHandler(func)
|
||||
else:
|
||||
cherrypy.serving.request.handler = cherrypy.NotFound()
|
||||
|
||||
def find_handler(self, path_info):
|
||||
"""Find the right page handler, and set request.config."""
|
||||
import routes
|
||||
|
||||
request = cherrypy.serving.request
|
||||
|
||||
config = routes.request_config()
|
||||
config.mapper = self.mapper
|
||||
if hasattr(request, 'wsgi_environ'):
|
||||
config.environ = request.wsgi_environ
|
||||
config.host = request.headers.get('Host', None)
|
||||
config.protocol = request.scheme
|
||||
config.redirect = self.redirect
|
||||
|
||||
result = self.mapper.match(path_info)
|
||||
|
||||
config.mapper_dict = result
|
||||
params = {}
|
||||
if result:
|
||||
params = result.copy()
|
||||
if not self.full_result:
|
||||
params.pop('controller', None)
|
||||
params.pop('action', None)
|
||||
request.params.update(params)
|
||||
|
||||
# Get config for the root object/path.
|
||||
request.config = base = cherrypy.config.copy()
|
||||
curpath = ""
|
||||
|
||||
def merge(nodeconf):
|
||||
if 'tools.staticdir.dir' in nodeconf:
|
||||
nodeconf['tools.staticdir.section'] = curpath or "/"
|
||||
base.update(nodeconf)
|
||||
|
||||
app = request.app
|
||||
root = app.root
|
||||
if hasattr(root, "_cp_config"):
|
||||
merge(root._cp_config)
|
||||
if "/" in app.config:
|
||||
merge(app.config["/"])
|
||||
|
||||
# Mix in values from app.config.
|
||||
atoms = [x for x in path_info.split("/") if x]
|
||||
if atoms:
|
||||
last = atoms.pop()
|
||||
else:
|
||||
last = None
|
||||
for atom in atoms:
|
||||
curpath = "/".join((curpath, atom))
|
||||
if curpath in app.config:
|
||||
merge(app.config[curpath])
|
||||
|
||||
handler = None
|
||||
if result:
|
||||
controller = result.get('controller')
|
||||
controller = self.controllers.get(controller, controller)
|
||||
if controller:
|
||||
if isinstance(controller, classtype):
|
||||
controller = controller()
|
||||
# Get config from the controller.
|
||||
if hasattr(controller, "_cp_config"):
|
||||
merge(controller._cp_config)
|
||||
|
||||
action = result.get('action')
|
||||
if action is not None:
|
||||
handler = getattr(controller, action, None)
|
||||
# Get config from the handler
|
||||
if hasattr(handler, "_cp_config"):
|
||||
merge(handler._cp_config)
|
||||
else:
|
||||
handler = controller
|
||||
|
||||
# Do the last path atom here so it can
|
||||
# override the controller's _cp_config.
|
||||
if last:
|
||||
curpath = "/".join((curpath, last))
|
||||
if curpath in app.config:
|
||||
merge(app.config[curpath])
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def XMLRPCDispatcher(next_dispatcher=Dispatcher()):
|
||||
from cherrypy.lib import xmlrpcutil
|
||||
|
||||
def xmlrpc_dispatch(path_info):
|
||||
path_info = xmlrpcutil.patched_path(path_info)
|
||||
return next_dispatcher(path_info)
|
||||
return xmlrpc_dispatch
|
||||
|
||||
|
||||
def VirtualHost(next_dispatcher=Dispatcher(), use_x_forwarded_host=True,
|
||||
**domains):
|
||||
"""
|
||||
Select a different handler based on the Host header.
|
||||
|
||||
This can be useful when running multiple sites within one CP server.
|
||||
It allows several domains to point to different parts of a single
|
||||
website structure. For example::
|
||||
|
||||
http://www.domain.example -> root
|
||||
http://www.domain2.example -> root/domain2/
|
||||
http://www.domain2.example:443 -> root/secure
|
||||
|
||||
can be accomplished via the following config::
|
||||
|
||||
[/]
|
||||
request.dispatch = cherrypy.dispatch.VirtualHost(
|
||||
**{'www.domain2.example': '/domain2',
|
||||
'www.domain2.example:443': '/secure',
|
||||
})
|
||||
|
||||
next_dispatcher
|
||||
The next dispatcher object in the dispatch chain.
|
||||
The VirtualHost dispatcher adds a prefix to the URL and calls
|
||||
another dispatcher. Defaults to cherrypy.dispatch.Dispatcher().
|
||||
|
||||
use_x_forwarded_host
|
||||
If True (the default), any "X-Forwarded-Host"
|
||||
request header will be used instead of the "Host" header. This
|
||||
is commonly added by HTTP servers (such as Apache) when proxying.
|
||||
|
||||
``**domains``
|
||||
A dict of {host header value: virtual prefix} pairs.
|
||||
The incoming "Host" request header is looked up in this dict,
|
||||
and, if a match is found, the corresponding "virtual prefix"
|
||||
value will be prepended to the URL path before calling the
|
||||
next dispatcher. Note that you often need separate entries
|
||||
for "example.com" and "www.example.com". In addition, "Host"
|
||||
headers may contain the port number.
|
||||
"""
|
||||
from cherrypy.lib import httputil
|
||||
|
||||
def vhost_dispatch(path_info):
|
||||
request = cherrypy.serving.request
|
||||
header = request.headers.get
|
||||
|
||||
domain = header('Host', '')
|
||||
if use_x_forwarded_host:
|
||||
domain = header("X-Forwarded-Host", domain)
|
||||
|
||||
prefix = domains.get(domain, "")
|
||||
if prefix:
|
||||
path_info = httputil.urljoin(prefix, path_info)
|
||||
|
||||
result = next_dispatcher(path_info)
|
||||
|
||||
# Touch up staticdir config. See
|
||||
# https://bitbucket.org/cherrypy/cherrypy/issue/614.
|
||||
section = request.config.get('tools.staticdir.section')
|
||||
if section:
|
||||
section = section[len(prefix):]
|
||||
request.config['tools.staticdir.section'] = section
|
||||
|
||||
return result
|
||||
return vhost_dispatch
|
||||
@@ -0,0 +1,609 @@
|
||||
"""Exception classes for CherryPy.
|
||||
|
||||
CherryPy provides (and uses) exceptions for declaring that the HTTP response
|
||||
should be a status other than the default "200 OK". You can ``raise`` them like
|
||||
normal Python exceptions. You can also call them and they will raise
|
||||
themselves; this means you can set an
|
||||
:class:`HTTPError<cherrypy._cperror.HTTPError>`
|
||||
or :class:`HTTPRedirect<cherrypy._cperror.HTTPRedirect>` as the
|
||||
:attr:`request.handler<cherrypy._cprequest.Request.handler>`.
|
||||
|
||||
.. _redirectingpost:
|
||||
|
||||
Redirecting POST
|
||||
================
|
||||
|
||||
When you GET a resource and are redirected by the server to another Location,
|
||||
there's generally no problem since GET is both a "safe method" (there should
|
||||
be no side-effects) and an "idempotent method" (multiple calls are no different
|
||||
than a single call).
|
||||
|
||||
POST, however, is neither safe nor idempotent--if you
|
||||
charge a credit card, you don't want to be charged twice by a redirect!
|
||||
|
||||
For this reason, *none* of the 3xx responses permit a user-agent (browser) to
|
||||
resubmit a POST on redirection without first confirming the action with the
|
||||
user:
|
||||
|
||||
===== ================================= ===========
|
||||
300 Multiple Choices Confirm with the user
|
||||
301 Moved Permanently Confirm with the user
|
||||
302 Found (Object moved temporarily) Confirm with the user
|
||||
303 See Other GET the new URI--no confirmation
|
||||
304 Not modified (for conditional GET only--POST should not raise this error)
|
||||
305 Use Proxy Confirm with the user
|
||||
307 Temporary Redirect Confirm with the user
|
||||
===== ================================= ===========
|
||||
|
||||
However, browsers have historically implemented these restrictions poorly;
|
||||
in particular, many browsers do not force the user to confirm 301, 302
|
||||
or 307 when redirecting POST. For this reason, CherryPy defaults to 303,
|
||||
which most user-agents appear to have implemented correctly. Therefore, if
|
||||
you raise HTTPRedirect for a POST request, the user-agent will most likely
|
||||
attempt to GET the new URI (without asking for confirmation from the user).
|
||||
We realize this is confusing for developers, but it's the safest thing we
|
||||
could do. You are of course free to raise ``HTTPRedirect(uri, status=302)``
|
||||
or any other 3xx status if you know what you're doing, but given the
|
||||
environment, we couldn't let any of those be the default.
|
||||
|
||||
Custom Error Handling
|
||||
=====================
|
||||
|
||||
.. image:: /refman/cperrors.gif
|
||||
|
||||
Anticipated HTTP responses
|
||||
--------------------------
|
||||
|
||||
The 'error_page' config namespace can be used to provide custom HTML output for
|
||||
expected responses (like 404 Not Found). Supply a filename from which the
|
||||
output will be read. The contents will be interpolated with the values
|
||||
%(status)s, %(message)s, %(traceback)s, and %(version)s using plain old Python
|
||||
`string formatting <http://docs.python.org/2/library/stdtypes.html#string-formatting-operations>`_.
|
||||
|
||||
::
|
||||
|
||||
_cp_config = {
|
||||
'error_page.404': os.path.join(localDir, "static/index.html")
|
||||
}
|
||||
|
||||
|
||||
Beginning in version 3.1, you may also provide a function or other callable as
|
||||
an error_page entry. It will be passed the same status, message, traceback and
|
||||
version arguments that are interpolated into templates::
|
||||
|
||||
def error_page_402(status, message, traceback, version):
|
||||
return "Error %s - Well, I'm very sorry but you haven't paid!" % status
|
||||
cherrypy.config.update({'error_page.402': error_page_402})
|
||||
|
||||
Also in 3.1, in addition to the numbered error codes, you may also supply
|
||||
"error_page.default" to handle all codes which do not have their own error_page
|
||||
entry.
|
||||
|
||||
|
||||
|
||||
Unanticipated errors
|
||||
--------------------
|
||||
|
||||
CherryPy also has a generic error handling mechanism: whenever an unanticipated
|
||||
error occurs in your code, it will call
|
||||
:func:`Request.error_response<cherrypy._cprequest.Request.error_response>` to
|
||||
set the response status, headers, and body. By default, this is the same
|
||||
output as
|
||||
:class:`HTTPError(500) <cherrypy._cperror.HTTPError>`. If you want to provide
|
||||
some other behavior, you generally replace "request.error_response".
|
||||
|
||||
Here is some sample code that shows how to display a custom error message and
|
||||
send an e-mail containing the error::
|
||||
|
||||
from cherrypy import _cperror
|
||||
|
||||
def handle_error():
|
||||
cherrypy.response.status = 500
|
||||
cherrypy.response.body = [
|
||||
"<html><body>Sorry, an error occured</body></html>"
|
||||
]
|
||||
sendMail('error@domain.com',
|
||||
'Error in your web app',
|
||||
_cperror.format_exc())
|
||||
|
||||
class Root:
|
||||
_cp_config = {'request.error_response': handle_error}
|
||||
|
||||
|
||||
Note that you have to explicitly set
|
||||
:attr:`response.body <cherrypy._cprequest.Response.body>`
|
||||
and not simply return an error message as a result.
|
||||
"""
|
||||
|
||||
from cgi import escape as _escape
|
||||
from sys import exc_info as _exc_info
|
||||
from traceback import format_exception as _format_exception
|
||||
from cherrypy._cpcompat import basestring, bytestr, iteritems, ntob
|
||||
from cherrypy._cpcompat import tonative, urljoin as _urljoin
|
||||
from cherrypy.lib import httputil as _httputil
|
||||
|
||||
|
||||
class CherryPyException(Exception):
|
||||
|
||||
"""A base class for CherryPy exceptions."""
|
||||
pass
|
||||
|
||||
|
||||
class TimeoutError(CherryPyException):
|
||||
|
||||
"""Exception raised when Response.timed_out is detected."""
|
||||
pass
|
||||
|
||||
|
||||
class InternalRedirect(CherryPyException):
|
||||
|
||||
"""Exception raised to switch to the handler for a different URL.
|
||||
|
||||
This exception will redirect processing to another path within the site
|
||||
(without informing the client). Provide the new path as an argument when
|
||||
raising the exception. Provide any params in the querystring for the new
|
||||
URL.
|
||||
"""
|
||||
|
||||
def __init__(self, path, query_string=""):
|
||||
import cherrypy
|
||||
self.request = cherrypy.serving.request
|
||||
|
||||
self.query_string = query_string
|
||||
if "?" in path:
|
||||
# Separate any params included in the path
|
||||
path, self.query_string = path.split("?", 1)
|
||||
|
||||
# Note that urljoin will "do the right thing" whether url is:
|
||||
# 1. a URL relative to root (e.g. "/dummy")
|
||||
# 2. a URL relative to the current path
|
||||
# Note that any query string will be discarded.
|
||||
path = _urljoin(self.request.path_info, path)
|
||||
|
||||
# Set a 'path' member attribute so that code which traps this
|
||||
# error can have access to it.
|
||||
self.path = path
|
||||
|
||||
CherryPyException.__init__(self, path, self.query_string)
|
||||
|
||||
|
||||
class HTTPRedirect(CherryPyException):
|
||||
|
||||
"""Exception raised when the request should be redirected.
|
||||
|
||||
This exception will force a HTTP redirect to the URL or URL's you give it.
|
||||
The new URL must be passed as the first argument to the Exception,
|
||||
e.g., HTTPRedirect(newUrl). Multiple URLs are allowed in a list.
|
||||
If a URL is absolute, it will be used as-is. If it is relative, it is
|
||||
assumed to be relative to the current cherrypy.request.path_info.
|
||||
|
||||
If one of the provided URL is a unicode object, it will be encoded
|
||||
using the default encoding or the one passed in parameter.
|
||||
|
||||
There are multiple types of redirect, from which you can select via the
|
||||
``status`` argument. If you do not provide a ``status`` arg, it defaults to
|
||||
303 (or 302 if responding with HTTP/1.0).
|
||||
|
||||
Examples::
|
||||
|
||||
raise cherrypy.HTTPRedirect("")
|
||||
raise cherrypy.HTTPRedirect("/abs/path", 307)
|
||||
raise cherrypy.HTTPRedirect(["path1", "path2?a=1&b=2"], 301)
|
||||
|
||||
See :ref:`redirectingpost` for additional caveats.
|
||||
"""
|
||||
|
||||
status = None
|
||||
"""The integer HTTP status code to emit."""
|
||||
|
||||
urls = None
|
||||
"""The list of URL's to emit."""
|
||||
|
||||
encoding = 'utf-8'
|
||||
"""The encoding when passed urls are not native strings"""
|
||||
|
||||
def __init__(self, urls, status=None, encoding=None):
|
||||
import cherrypy
|
||||
request = cherrypy.serving.request
|
||||
|
||||
if isinstance(urls, basestring):
|
||||
urls = [urls]
|
||||
|
||||
abs_urls = []
|
||||
for url in urls:
|
||||
url = tonative(url, encoding or self.encoding)
|
||||
|
||||
# Note that urljoin will "do the right thing" whether url is:
|
||||
# 1. a complete URL with host (e.g. "http://www.example.com/test")
|
||||
# 2. a URL relative to root (e.g. "/dummy")
|
||||
# 3. a URL relative to the current path
|
||||
# Note that any query string in cherrypy.request is discarded.
|
||||
url = _urljoin(cherrypy.url(), url)
|
||||
abs_urls.append(url)
|
||||
self.urls = abs_urls
|
||||
|
||||
# RFC 2616 indicates a 301 response code fits our goal; however,
|
||||
# browser support for 301 is quite messy. Do 302/303 instead. See
|
||||
# http://www.alanflavell.org.uk/www/post-redirect.html
|
||||
if status is None:
|
||||
if request.protocol >= (1, 1):
|
||||
status = 303
|
||||
else:
|
||||
status = 302
|
||||
else:
|
||||
status = int(status)
|
||||
if status < 300 or status > 399:
|
||||
raise ValueError("status must be between 300 and 399.")
|
||||
|
||||
self.status = status
|
||||
CherryPyException.__init__(self, abs_urls, status)
|
||||
|
||||
def set_response(self):
|
||||
"""Modify cherrypy.response status, headers, and body to represent
|
||||
self.
|
||||
|
||||
CherryPy uses this internally, but you can also use it to create an
|
||||
HTTPRedirect object and set its output without *raising* the exception.
|
||||
"""
|
||||
import cherrypy
|
||||
response = cherrypy.serving.response
|
||||
response.status = status = self.status
|
||||
|
||||
if status in (300, 301, 302, 303, 307):
|
||||
response.headers['Content-Type'] = "text/html;charset=utf-8"
|
||||
# "The ... URI SHOULD be given by the Location field
|
||||
# in the response."
|
||||
response.headers['Location'] = self.urls[0]
|
||||
|
||||
# "Unless the request method was HEAD, the entity of the response
|
||||
# SHOULD contain a short hypertext note with a hyperlink to the
|
||||
# new URI(s)."
|
||||
msg = {
|
||||
300: "This resource can be found at ",
|
||||
301: "This resource has permanently moved to ",
|
||||
302: "This resource resides temporarily at ",
|
||||
303: "This resource can be found at ",
|
||||
307: "This resource has moved temporarily to ",
|
||||
}[status]
|
||||
msg += '<a href=%s>%s</a>.'
|
||||
from xml.sax import saxutils
|
||||
msgs = [msg % (saxutils.quoteattr(u), u) for u in self.urls]
|
||||
response.body = ntob("<br />\n".join(msgs), 'utf-8')
|
||||
# Previous code may have set C-L, so we have to reset it
|
||||
# (allow finalize to set it).
|
||||
response.headers.pop('Content-Length', None)
|
||||
elif status == 304:
|
||||
# Not Modified.
|
||||
# "The response MUST include the following header fields:
|
||||
# Date, unless its omission is required by section 14.18.1"
|
||||
# The "Date" header should have been set in Response.__init__
|
||||
|
||||
# "...the response SHOULD NOT include other entity-headers."
|
||||
for key in ('Allow', 'Content-Encoding', 'Content-Language',
|
||||
'Content-Length', 'Content-Location', 'Content-MD5',
|
||||
'Content-Range', 'Content-Type', 'Expires',
|
||||
'Last-Modified'):
|
||||
if key in response.headers:
|
||||
del response.headers[key]
|
||||
|
||||
# "The 304 response MUST NOT contain a message-body."
|
||||
response.body = None
|
||||
# Previous code may have set C-L, so we have to reset it.
|
||||
response.headers.pop('Content-Length', None)
|
||||
elif status == 305:
|
||||
# Use Proxy.
|
||||
# self.urls[0] should be the URI of the proxy.
|
||||
response.headers['Location'] = self.urls[0]
|
||||
response.body = None
|
||||
# Previous code may have set C-L, so we have to reset it.
|
||||
response.headers.pop('Content-Length', None)
|
||||
else:
|
||||
raise ValueError("The %s status code is unknown." % status)
|
||||
|
||||
def __call__(self):
|
||||
"""Use this exception as a request.handler (raise self)."""
|
||||
raise self
|
||||
|
||||
|
||||
def clean_headers(status):
|
||||
"""Remove any headers which should not apply to an error response."""
|
||||
import cherrypy
|
||||
|
||||
response = cherrypy.serving.response
|
||||
|
||||
# Remove headers which applied to the original content,
|
||||
# but do not apply to the error page.
|
||||
respheaders = response.headers
|
||||
for key in ["Accept-Ranges", "Age", "ETag", "Location", "Retry-After",
|
||||
"Vary", "Content-Encoding", "Content-Length", "Expires",
|
||||
"Content-Location", "Content-MD5", "Last-Modified"]:
|
||||
if key in respheaders:
|
||||
del respheaders[key]
|
||||
|
||||
if status != 416:
|
||||
# A server sending a response with status code 416 (Requested
|
||||
# range not satisfiable) SHOULD include a Content-Range field
|
||||
# with a byte-range-resp-spec of "*". The instance-length
|
||||
# specifies the current length of the selected resource.
|
||||
# A response with status code 206 (Partial Content) MUST NOT
|
||||
# include a Content-Range field with a byte-range- resp-spec of "*".
|
||||
if "Content-Range" in respheaders:
|
||||
del respheaders["Content-Range"]
|
||||
|
||||
|
||||
class HTTPError(CherryPyException):
|
||||
|
||||
"""Exception used to return an HTTP error code (4xx-5xx) to the client.
|
||||
|
||||
This exception can be used to automatically send a response using a
|
||||
http status code, with an appropriate error page. It takes an optional
|
||||
``status`` argument (which must be between 400 and 599); it defaults to 500
|
||||
("Internal Server Error"). It also takes an optional ``message`` argument,
|
||||
which will be returned in the response body. See
|
||||
`RFC2616 <http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4>`_
|
||||
for a complete list of available error codes and when to use them.
|
||||
|
||||
Examples::
|
||||
|
||||
raise cherrypy.HTTPError(403)
|
||||
raise cherrypy.HTTPError(
|
||||
"403 Forbidden", "You are not allowed to access this resource.")
|
||||
"""
|
||||
|
||||
status = None
|
||||
"""The HTTP status code. May be of type int or str (with a Reason-Phrase).
|
||||
"""
|
||||
|
||||
code = None
|
||||
"""The integer HTTP status code."""
|
||||
|
||||
reason = None
|
||||
"""The HTTP Reason-Phrase string."""
|
||||
|
||||
def __init__(self, status=500, message=None):
|
||||
self.status = status
|
||||
try:
|
||||
self.code, self.reason, defaultmsg = _httputil.valid_status(status)
|
||||
except ValueError:
|
||||
raise self.__class__(500, _exc_info()[1].args[0])
|
||||
|
||||
if self.code < 400 or self.code > 599:
|
||||
raise ValueError("status must be between 400 and 599.")
|
||||
|
||||
# See http://www.python.org/dev/peps/pep-0352/
|
||||
# self.message = message
|
||||
self._message = message or defaultmsg
|
||||
CherryPyException.__init__(self, status, message)
|
||||
|
||||
def set_response(self):
|
||||
"""Modify cherrypy.response status, headers, and body to represent
|
||||
self.
|
||||
|
||||
CherryPy uses this internally, but you can also use it to create an
|
||||
HTTPError object and set its output without *raising* the exception.
|
||||
"""
|
||||
import cherrypy
|
||||
|
||||
response = cherrypy.serving.response
|
||||
|
||||
clean_headers(self.code)
|
||||
|
||||
# In all cases, finalize will be called after this method,
|
||||
# so don't bother cleaning up response values here.
|
||||
response.status = self.status
|
||||
tb = None
|
||||
if cherrypy.serving.request.show_tracebacks:
|
||||
tb = format_exc()
|
||||
|
||||
response.headers.pop('Content-Length', None)
|
||||
|
||||
content = self.get_error_page(self.status, traceback=tb,
|
||||
message=self._message)
|
||||
response.body = content
|
||||
|
||||
_be_ie_unfriendly(self.code)
|
||||
|
||||
def get_error_page(self, *args, **kwargs):
|
||||
return get_error_page(*args, **kwargs)
|
||||
|
||||
def __call__(self):
|
||||
"""Use this exception as a request.handler (raise self)."""
|
||||
raise self
|
||||
|
||||
|
||||
class NotFound(HTTPError):
|
||||
|
||||
"""Exception raised when a URL could not be mapped to any handler (404).
|
||||
|
||||
This is equivalent to raising
|
||||
:class:`HTTPError("404 Not Found") <cherrypy._cperror.HTTPError>`.
|
||||
"""
|
||||
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
import cherrypy
|
||||
request = cherrypy.serving.request
|
||||
path = request.script_name + request.path_info
|
||||
self.args = (path,)
|
||||
HTTPError.__init__(self, 404, "The path '%s' was not found." % path)
|
||||
|
||||
|
||||
_HTTPErrorTemplate = '''<!DOCTYPE html PUBLIC
|
||||
"-//W3C//DTD XHTML 1.0 Transitional//EN"
|
||||
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"></meta>
|
||||
<title>%(status)s</title>
|
||||
<style type="text/css">
|
||||
#powered_by {
|
||||
margin-top: 20px;
|
||||
border-top: 2px solid black;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
#traceback {
|
||||
color: red;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h2>%(status)s</h2>
|
||||
<p>%(message)s</p>
|
||||
<pre id="traceback">%(traceback)s</pre>
|
||||
<div id="powered_by">
|
||||
<span>
|
||||
Powered by <a href="http://www.cherrypy.org">CherryPy %(version)s</a>
|
||||
</span>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
|
||||
|
||||
def get_error_page(status, **kwargs):
|
||||
"""Return an HTML page, containing a pretty error response.
|
||||
|
||||
status should be an int or a str.
|
||||
kwargs will be interpolated into the page template.
|
||||
"""
|
||||
import cherrypy
|
||||
|
||||
try:
|
||||
code, reason, message = _httputil.valid_status(status)
|
||||
except ValueError:
|
||||
raise cherrypy.HTTPError(500, _exc_info()[1].args[0])
|
||||
|
||||
# We can't use setdefault here, because some
|
||||
# callers send None for kwarg values.
|
||||
if kwargs.get('status') is None:
|
||||
kwargs['status'] = "%s %s" % (code, reason)
|
||||
if kwargs.get('message') is None:
|
||||
kwargs['message'] = message
|
||||
if kwargs.get('traceback') is None:
|
||||
kwargs['traceback'] = ''
|
||||
if kwargs.get('version') is None:
|
||||
kwargs['version'] = cherrypy.__version__
|
||||
|
||||
for k, v in iteritems(kwargs):
|
||||
if v is None:
|
||||
kwargs[k] = ""
|
||||
else:
|
||||
kwargs[k] = _escape(kwargs[k])
|
||||
|
||||
# Use a custom template or callable for the error page?
|
||||
pages = cherrypy.serving.request.error_page
|
||||
error_page = pages.get(code) or pages.get('default')
|
||||
|
||||
# Default template, can be overridden below.
|
||||
template = _HTTPErrorTemplate
|
||||
if error_page:
|
||||
try:
|
||||
if hasattr(error_page, '__call__'):
|
||||
# The caller function may be setting headers manually,
|
||||
# so we delegate to it completely. We may be returning
|
||||
# an iterator as well as a string here.
|
||||
#
|
||||
# We *must* make sure any content is not unicode.
|
||||
result = error_page(**kwargs)
|
||||
if cherrypy.lib.is_iterator(result):
|
||||
from cherrypy.lib.encoding import UTF8StreamEncoder
|
||||
return UTF8StreamEncoder(result)
|
||||
elif isinstance(result, cherrypy._cpcompat.unicodestr):
|
||||
return result.encode('utf-8')
|
||||
else:
|
||||
if not isinstance(result, cherrypy._cpcompat.bytestr):
|
||||
raise ValueError('error page function did not '
|
||||
'return a bytestring, unicodestring or an '
|
||||
'iterator - returned object of type %s.'
|
||||
% (type(result).__name__))
|
||||
return result
|
||||
else:
|
||||
# Load the template from this path.
|
||||
template = tonative(open(error_page, 'rb').read())
|
||||
except:
|
||||
e = _format_exception(*_exc_info())[-1]
|
||||
m = kwargs['message']
|
||||
if m:
|
||||
m += "<br />"
|
||||
m += "In addition, the custom error page failed:\n<br />%s" % e
|
||||
kwargs['message'] = m
|
||||
|
||||
response = cherrypy.serving.response
|
||||
response.headers['Content-Type'] = "text/html;charset=utf-8"
|
||||
result = template % kwargs
|
||||
return result.encode('utf-8')
|
||||
|
||||
|
||||
|
||||
_ie_friendly_error_sizes = {
|
||||
400: 512, 403: 256, 404: 512, 405: 256,
|
||||
406: 512, 408: 512, 409: 512, 410: 256,
|
||||
500: 512, 501: 512, 505: 512,
|
||||
}
|
||||
|
||||
|
||||
def _be_ie_unfriendly(status):
|
||||
import cherrypy
|
||||
response = cherrypy.serving.response
|
||||
|
||||
# For some statuses, Internet Explorer 5+ shows "friendly error
|
||||
# messages" instead of our response.body if the body is smaller
|
||||
# than a given size. Fix this by returning a body over that size
|
||||
# (by adding whitespace).
|
||||
# See http://support.microsoft.com/kb/q218155/
|
||||
s = _ie_friendly_error_sizes.get(status, 0)
|
||||
if s:
|
||||
s += 1
|
||||
# Since we are issuing an HTTP error status, we assume that
|
||||
# the entity is short, and we should just collapse it.
|
||||
content = response.collapse_body()
|
||||
l = len(content)
|
||||
if l and l < s:
|
||||
# IN ADDITION: the response must be written to IE
|
||||
# in one chunk or it will still get replaced! Bah.
|
||||
content = content + (ntob(" ") * (s - l))
|
||||
response.body = content
|
||||
response.headers['Content-Length'] = str(len(content))
|
||||
|
||||
|
||||
def format_exc(exc=None):
|
||||
"""Return exc (or sys.exc_info if None), formatted."""
|
||||
try:
|
||||
if exc is None:
|
||||
exc = _exc_info()
|
||||
if exc == (None, None, None):
|
||||
return ""
|
||||
import traceback
|
||||
return "".join(traceback.format_exception(*exc))
|
||||
finally:
|
||||
del exc
|
||||
|
||||
|
||||
def bare_error(extrabody=None):
|
||||
"""Produce status, headers, body for a critical error.
|
||||
|
||||
Returns a triple without calling any other questionable functions,
|
||||
so it should be as error-free as possible. Call it from an HTTP server
|
||||
if you get errors outside of the request.
|
||||
|
||||
If extrabody is None, a friendly but rather unhelpful error message
|
||||
is set in the body. If extrabody is a string, it will be appended
|
||||
as-is to the body.
|
||||
"""
|
||||
|
||||
# The whole point of this function is to be a last line-of-defense
|
||||
# in handling errors. That is, it must not raise any errors itself;
|
||||
# it cannot be allowed to fail. Therefore, don't add to it!
|
||||
# In particular, don't call any other CP functions.
|
||||
|
||||
body = ntob("Unrecoverable error in the server.")
|
||||
if extrabody is not None:
|
||||
if not isinstance(extrabody, bytestr):
|
||||
extrabody = extrabody.encode('utf-8')
|
||||
body += ntob("\n") + extrabody
|
||||
|
||||
return (ntob("500 Internal Server Error"),
|
||||
[(ntob('Content-Type'), ntob('text/plain')),
|
||||
(ntob('Content-Length'), ntob(str(len(body)), 'ISO-8859-1'))],
|
||||
[body])
|
||||
@@ -0,0 +1,459 @@
|
||||
"""
|
||||
Simple config
|
||||
=============
|
||||
|
||||
Although CherryPy uses the :mod:`Python logging module <logging>`, it does so
|
||||
behind the scenes so that simple logging is simple, but complicated logging
|
||||
is still possible. "Simple" logging means that you can log to the screen
|
||||
(i.e. console/stdout) or to a file, and that you can easily have separate
|
||||
error and access log files.
|
||||
|
||||
Here are the simplified logging settings. You use these by adding lines to
|
||||
your config file or dict. You should set these at either the global level or
|
||||
per application (see next), but generally not both.
|
||||
|
||||
* ``log.screen``: Set this to True to have both "error" and "access" messages
|
||||
printed to stdout.
|
||||
* ``log.access_file``: Set this to an absolute filename where you want
|
||||
"access" messages written.
|
||||
* ``log.error_file``: Set this to an absolute filename where you want "error"
|
||||
messages written.
|
||||
|
||||
Many events are automatically logged; to log your own application events, call
|
||||
:func:`cherrypy.log`.
|
||||
|
||||
Architecture
|
||||
============
|
||||
|
||||
Separate scopes
|
||||
---------------
|
||||
|
||||
CherryPy provides log managers at both the global and application layers.
|
||||
This means you can have one set of logging rules for your entire site,
|
||||
and another set of rules specific to each application. The global log
|
||||
manager is found at :func:`cherrypy.log`, and the log manager for each
|
||||
application is found at :attr:`app.log<cherrypy._cptree.Application.log>`.
|
||||
If you're inside a request, the latter is reachable from
|
||||
``cherrypy.request.app.log``; if you're outside a request, you'll have to
|
||||
obtain a reference to the ``app``: either the return value of
|
||||
:func:`tree.mount()<cherrypy._cptree.Tree.mount>` or, if you used
|
||||
:func:`quickstart()<cherrypy.quickstart>` instead, via
|
||||
``cherrypy.tree.apps['/']``.
|
||||
|
||||
By default, the global logs are named "cherrypy.error" and "cherrypy.access",
|
||||
and the application logs are named "cherrypy.error.2378745" and
|
||||
"cherrypy.access.2378745" (the number is the id of the Application object).
|
||||
This means that the application logs "bubble up" to the site logs, so if your
|
||||
application has no log handlers, the site-level handlers will still log the
|
||||
messages.
|
||||
|
||||
Errors vs. Access
|
||||
-----------------
|
||||
|
||||
Each log manager handles both "access" messages (one per HTTP request) and
|
||||
"error" messages (everything else). Note that the "error" log is not just for
|
||||
errors! The format of access messages is highly formalized, but the error log
|
||||
isn't--it receives messages from a variety of sources (including full error
|
||||
tracebacks, if enabled).
|
||||
|
||||
If you are logging the access log and error log to the same source, then there
|
||||
is a possibility that a specially crafted error message may replicate an access
|
||||
log message as described in CWE-117. In this case it is the application
|
||||
developer's responsibility to manually escape data before using CherryPy's log()
|
||||
functionality, or they may create an application that is vulnerable to CWE-117.
|
||||
This would be achieved by using a custom handler escape any special characters,
|
||||
and attached as described below.
|
||||
|
||||
Custom Handlers
|
||||
===============
|
||||
|
||||
The simple settings above work by manipulating Python's standard :mod:`logging`
|
||||
module. So when you need something more complex, the full power of the standard
|
||||
module is yours to exploit. You can borrow or create custom handlers, formats,
|
||||
filters, and much more. Here's an example that skips the standard FileHandler
|
||||
and uses a RotatingFileHandler instead:
|
||||
|
||||
::
|
||||
|
||||
#python
|
||||
log = app.log
|
||||
|
||||
# Remove the default FileHandlers if present.
|
||||
log.error_file = ""
|
||||
log.access_file = ""
|
||||
|
||||
maxBytes = getattr(log, "rot_maxBytes", 10000000)
|
||||
backupCount = getattr(log, "rot_backupCount", 1000)
|
||||
|
||||
# Make a new RotatingFileHandler for the error log.
|
||||
fname = getattr(log, "rot_error_file", "error.log")
|
||||
h = handlers.RotatingFileHandler(fname, 'a', maxBytes, backupCount)
|
||||
h.setLevel(DEBUG)
|
||||
h.setFormatter(_cplogging.logfmt)
|
||||
log.error_log.addHandler(h)
|
||||
|
||||
# Make a new RotatingFileHandler for the access log.
|
||||
fname = getattr(log, "rot_access_file", "access.log")
|
||||
h = handlers.RotatingFileHandler(fname, 'a', maxBytes, backupCount)
|
||||
h.setLevel(DEBUG)
|
||||
h.setFormatter(_cplogging.logfmt)
|
||||
log.access_log.addHandler(h)
|
||||
|
||||
|
||||
The ``rot_*`` attributes are pulled straight from the application log object.
|
||||
Since "log.*" config entries simply set attributes on the log object, you can
|
||||
add custom attributes to your heart's content. Note that these handlers are
|
||||
used ''instead'' of the default, simple handlers outlined above (so don't set
|
||||
the "log.error_file" config entry, for example).
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
# Silence the no-handlers "warning" (stderr write!) in stdlib logging
|
||||
logging.Logger.manager.emittedNoHandlerWarning = 1
|
||||
logfmt = logging.Formatter("%(message)s")
|
||||
import os
|
||||
import sys
|
||||
|
||||
import cherrypy
|
||||
from cherrypy import _cperror
|
||||
from cherrypy._cpcompat import ntob, py3k
|
||||
|
||||
|
||||
class NullHandler(logging.Handler):
|
||||
|
||||
"""A no-op logging handler to silence the logging.lastResort handler."""
|
||||
|
||||
def handle(self, record):
|
||||
pass
|
||||
|
||||
def emit(self, record):
|
||||
pass
|
||||
|
||||
def createLock(self):
|
||||
self.lock = None
|
||||
|
||||
|
||||
class LogManager(object):
|
||||
|
||||
"""An object to assist both simple and advanced logging.
|
||||
|
||||
``cherrypy.log`` is an instance of this class.
|
||||
"""
|
||||
|
||||
appid = None
|
||||
"""The id() of the Application object which owns this log manager. If this
|
||||
is a global log manager, appid is None."""
|
||||
|
||||
error_log = None
|
||||
"""The actual :class:`logging.Logger` instance for error messages."""
|
||||
|
||||
access_log = None
|
||||
"""The actual :class:`logging.Logger` instance for access messages."""
|
||||
|
||||
if py3k:
|
||||
access_log_format = \
|
||||
'{h} {l} {u} {t} "{r}" {s} {b} "{f}" "{a}"'
|
||||
else:
|
||||
access_log_format = \
|
||||
'%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
|
||||
|
||||
logger_root = None
|
||||
"""The "top-level" logger name.
|
||||
|
||||
This string will be used as the first segment in the Logger names.
|
||||
The default is "cherrypy", for example, in which case the Logger names
|
||||
will be of the form::
|
||||
|
||||
cherrypy.error.<appid>
|
||||
cherrypy.access.<appid>
|
||||
"""
|
||||
|
||||
def __init__(self, appid=None, logger_root="cherrypy"):
|
||||
self.logger_root = logger_root
|
||||
self.appid = appid
|
||||
if appid is None:
|
||||
self.error_log = logging.getLogger("%s.error" % logger_root)
|
||||
self.access_log = logging.getLogger("%s.access" % logger_root)
|
||||
else:
|
||||
self.error_log = logging.getLogger(
|
||||
"%s.error.%s" % (logger_root, appid))
|
||||
self.access_log = logging.getLogger(
|
||||
"%s.access.%s" % (logger_root, appid))
|
||||
self.error_log.setLevel(logging.INFO)
|
||||
self.access_log.setLevel(logging.INFO)
|
||||
|
||||
# Silence the no-handlers "warning" (stderr write!) in stdlib logging
|
||||
self.error_log.addHandler(NullHandler())
|
||||
self.access_log.addHandler(NullHandler())
|
||||
|
||||
cherrypy.engine.subscribe('graceful', self.reopen_files)
|
||||
|
||||
def reopen_files(self):
|
||||
"""Close and reopen all file handlers."""
|
||||
for log in (self.error_log, self.access_log):
|
||||
for h in log.handlers:
|
||||
if isinstance(h, logging.FileHandler):
|
||||
h.acquire()
|
||||
h.stream.close()
|
||||
h.stream = open(h.baseFilename, h.mode)
|
||||
h.release()
|
||||
|
||||
def error(self, msg='', context='', severity=logging.INFO,
|
||||
traceback=False):
|
||||
"""Write the given ``msg`` to the error log.
|
||||
|
||||
This is not just for errors! Applications may call this at any time
|
||||
to log application-specific information.
|
||||
|
||||
If ``traceback`` is True, the traceback of the current exception
|
||||
(if any) will be appended to ``msg``.
|
||||
"""
|
||||
if traceback:
|
||||
msg += _cperror.format_exc()
|
||||
self.error_log.log(severity, ' '.join((self.time(), context, msg)))
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""An alias for ``error``."""
|
||||
return self.error(*args, **kwargs)
|
||||
|
||||
def access(self):
|
||||
"""Write to the access log (in Apache/NCSA Combined Log format).
|
||||
|
||||
See the
|
||||
`apache documentation <http://httpd.apache.org/docs/current/logs.html#combined>`_
|
||||
for format details.
|
||||
|
||||
CherryPy calls this automatically for you. Note there are no arguments;
|
||||
it collects the data itself from
|
||||
:class:`cherrypy.request<cherrypy._cprequest.Request>`.
|
||||
|
||||
Like Apache started doing in 2.0.46, non-printable and other special
|
||||
characters in %r (and we expand that to all parts) are escaped using
|
||||
\\xhh sequences, where hh stands for the hexadecimal representation
|
||||
of the raw byte. Exceptions from this rule are " and \\, which are
|
||||
escaped by prepending a backslash, and all whitespace characters,
|
||||
which are written in their C-style notation (\\n, \\t, etc).
|
||||
"""
|
||||
request = cherrypy.serving.request
|
||||
remote = request.remote
|
||||
response = cherrypy.serving.response
|
||||
outheaders = response.headers
|
||||
inheaders = request.headers
|
||||
if response.output_status is None:
|
||||
status = "-"
|
||||
else:
|
||||
status = response.output_status.split(ntob(" "), 1)[0]
|
||||
if py3k:
|
||||
status = status.decode('ISO-8859-1')
|
||||
|
||||
atoms = {'h': remote.name or remote.ip,
|
||||
'l': '-',
|
||||
'u': getattr(request, "login", None) or "-",
|
||||
't': self.time(),
|
||||
'r': request.request_line,
|
||||
's': status,
|
||||
'b': dict.get(outheaders, 'Content-Length', '') or "-",
|
||||
'f': dict.get(inheaders, 'Referer', ''),
|
||||
'a': dict.get(inheaders, 'User-Agent', ''),
|
||||
'o': dict.get(inheaders, 'Host', '-'),
|
||||
}
|
||||
if py3k:
|
||||
for k, v in atoms.items():
|
||||
if not isinstance(v, str):
|
||||
v = str(v)
|
||||
v = v.replace('"', '\\"').encode('utf8')
|
||||
# Fortunately, repr(str) escapes unprintable chars, \n, \t, etc
|
||||
# and backslash for us. All we have to do is strip the quotes.
|
||||
v = repr(v)[2:-1]
|
||||
|
||||
# in python 3.0 the repr of bytes (as returned by encode)
|
||||
# uses double \'s. But then the logger escapes them yet, again
|
||||
# resulting in quadruple slashes. Remove the extra one here.
|
||||
v = v.replace('\\\\', '\\')
|
||||
|
||||
# Escape double-quote.
|
||||
atoms[k] = v
|
||||
|
||||
try:
|
||||
self.access_log.log(
|
||||
logging.INFO, self.access_log_format.format(**atoms))
|
||||
except:
|
||||
self(traceback=True)
|
||||
else:
|
||||
for k, v in atoms.items():
|
||||
if isinstance(v, unicode):
|
||||
v = v.encode('utf8')
|
||||
elif not isinstance(v, str):
|
||||
v = str(v)
|
||||
# Fortunately, repr(str) escapes unprintable chars, \n, \t, etc
|
||||
# and backslash for us. All we have to do is strip the quotes.
|
||||
v = repr(v)[1:-1]
|
||||
# Escape double-quote.
|
||||
atoms[k] = v.replace('"', '\\"')
|
||||
|
||||
try:
|
||||
self.access_log.log(
|
||||
logging.INFO, self.access_log_format % atoms)
|
||||
except:
|
||||
self(traceback=True)
|
||||
|
||||
def time(self):
|
||||
"""Return now() in Apache Common Log Format (no timezone)."""
|
||||
now = datetime.datetime.now()
|
||||
monthnames = ['jan', 'feb', 'mar', 'apr', 'may', 'jun',
|
||||
'jul', 'aug', 'sep', 'oct', 'nov', 'dec']
|
||||
month = monthnames[now.month - 1].capitalize()
|
||||
return ('[%02d/%s/%04d:%02d:%02d:%02d]' %
|
||||
(now.day, month, now.year, now.hour, now.minute, now.second))
|
||||
|
||||
def _get_builtin_handler(self, log, key):
|
||||
for h in log.handlers:
|
||||
if getattr(h, "_cpbuiltin", None) == key:
|
||||
return h
|
||||
|
||||
# ------------------------- Screen handlers ------------------------- #
|
||||
def _set_screen_handler(self, log, enable, stream=None):
|
||||
h = self._get_builtin_handler(log, "screen")
|
||||
if enable:
|
||||
if not h:
|
||||
if stream is None:
|
||||
stream = sys.stderr
|
||||
h = logging.StreamHandler(stream)
|
||||
h.setFormatter(logfmt)
|
||||
h._cpbuiltin = "screen"
|
||||
log.addHandler(h)
|
||||
elif h:
|
||||
log.handlers.remove(h)
|
||||
|
||||
def _get_screen(self):
|
||||
h = self._get_builtin_handler
|
||||
has_h = h(self.error_log, "screen") or h(self.access_log, "screen")
|
||||
return bool(has_h)
|
||||
|
||||
def _set_screen(self, newvalue):
|
||||
self._set_screen_handler(self.error_log, newvalue, stream=sys.stderr)
|
||||
self._set_screen_handler(self.access_log, newvalue, stream=sys.stdout)
|
||||
screen = property(_get_screen, _set_screen,
|
||||
doc="""Turn stderr/stdout logging on or off.
|
||||
|
||||
If you set this to True, it'll add the appropriate StreamHandler for
|
||||
you. If you set it to False, it will remove the handler.
|
||||
""")
|
||||
|
||||
# -------------------------- File handlers -------------------------- #
|
||||
|
||||
def _add_builtin_file_handler(self, log, fname):
|
||||
h = logging.FileHandler(fname)
|
||||
h.setFormatter(logfmt)
|
||||
h._cpbuiltin = "file"
|
||||
log.addHandler(h)
|
||||
|
||||
def _set_file_handler(self, log, filename):
|
||||
h = self._get_builtin_handler(log, "file")
|
||||
if filename:
|
||||
if h:
|
||||
if h.baseFilename != os.path.abspath(filename):
|
||||
h.close()
|
||||
log.handlers.remove(h)
|
||||
self._add_builtin_file_handler(log, filename)
|
||||
else:
|
||||
self._add_builtin_file_handler(log, filename)
|
||||
else:
|
||||
if h:
|
||||
h.close()
|
||||
log.handlers.remove(h)
|
||||
|
||||
def _get_error_file(self):
|
||||
h = self._get_builtin_handler(self.error_log, "file")
|
||||
if h:
|
||||
return h.baseFilename
|
||||
return ''
|
||||
|
||||
def _set_error_file(self, newvalue):
|
||||
self._set_file_handler(self.error_log, newvalue)
|
||||
error_file = property(_get_error_file, _set_error_file,
|
||||
doc="""The filename for self.error_log.
|
||||
|
||||
If you set this to a string, it'll add the appropriate FileHandler for
|
||||
you. If you set it to ``None`` or ``''``, it will remove the handler.
|
||||
""")
|
||||
|
||||
def _get_access_file(self):
|
||||
h = self._get_builtin_handler(self.access_log, "file")
|
||||
if h:
|
||||
return h.baseFilename
|
||||
return ''
|
||||
|
||||
def _set_access_file(self, newvalue):
|
||||
self._set_file_handler(self.access_log, newvalue)
|
||||
access_file = property(_get_access_file, _set_access_file,
|
||||
doc="""The filename for self.access_log.
|
||||
|
||||
If you set this to a string, it'll add the appropriate FileHandler for
|
||||
you. If you set it to ``None`` or ``''``, it will remove the handler.
|
||||
""")
|
||||
|
||||
# ------------------------- WSGI handlers ------------------------- #
|
||||
|
||||
def _set_wsgi_handler(self, log, enable):
|
||||
h = self._get_builtin_handler(log, "wsgi")
|
||||
if enable:
|
||||
if not h:
|
||||
h = WSGIErrorHandler()
|
||||
h.setFormatter(logfmt)
|
||||
h._cpbuiltin = "wsgi"
|
||||
log.addHandler(h)
|
||||
elif h:
|
||||
log.handlers.remove(h)
|
||||
|
||||
def _get_wsgi(self):
|
||||
return bool(self._get_builtin_handler(self.error_log, "wsgi"))
|
||||
|
||||
def _set_wsgi(self, newvalue):
|
||||
self._set_wsgi_handler(self.error_log, newvalue)
|
||||
wsgi = property(_get_wsgi, _set_wsgi,
|
||||
doc="""Write errors to wsgi.errors.
|
||||
|
||||
If you set this to True, it'll add the appropriate
|
||||
:class:`WSGIErrorHandler<cherrypy._cplogging.WSGIErrorHandler>` for you
|
||||
(which writes errors to ``wsgi.errors``).
|
||||
If you set it to False, it will remove the handler.
|
||||
""")
|
||||
|
||||
|
||||
class WSGIErrorHandler(logging.Handler):
|
||||
|
||||
"A handler class which writes logging records to environ['wsgi.errors']."
|
||||
|
||||
def flush(self):
|
||||
"""Flushes the stream."""
|
||||
try:
|
||||
stream = cherrypy.serving.request.wsgi_environ.get('wsgi.errors')
|
||||
except (AttributeError, KeyError):
|
||||
pass
|
||||
else:
|
||||
stream.flush()
|
||||
|
||||
def emit(self, record):
|
||||
"""Emit a record."""
|
||||
try:
|
||||
stream = cherrypy.serving.request.wsgi_environ.get('wsgi.errors')
|
||||
except (AttributeError, KeyError):
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
msg = self.format(record)
|
||||
fs = "%s\n"
|
||||
import types
|
||||
# if no unicode support...
|
||||
if not hasattr(types, "UnicodeType"):
|
||||
stream.write(fs % msg)
|
||||
else:
|
||||
try:
|
||||
stream.write(fs % msg)
|
||||
except UnicodeError:
|
||||
stream.write(fs % msg.encode("UTF-8"))
|
||||
self.flush()
|
||||
except:
|
||||
self.handleError(record)
|
||||
@@ -0,0 +1,353 @@
|
||||
"""Native adapter for serving CherryPy via mod_python
|
||||
|
||||
Basic usage:
|
||||
|
||||
##########################################
|
||||
# Application in a module called myapp.py
|
||||
##########################################
|
||||
|
||||
import cherrypy
|
||||
|
||||
class Root:
|
||||
@cherrypy.expose
|
||||
def index(self):
|
||||
return 'Hi there, Ho there, Hey there'
|
||||
|
||||
|
||||
# We will use this method from the mod_python configuration
|
||||
# as the entry point to our application
|
||||
def setup_server():
|
||||
cherrypy.tree.mount(Root())
|
||||
cherrypy.config.update({'environment': 'production',
|
||||
'log.screen': False,
|
||||
'show_tracebacks': False})
|
||||
|
||||
##########################################
|
||||
# mod_python settings for apache2
|
||||
# This should reside in your httpd.conf
|
||||
# or a file that will be loaded at
|
||||
# apache startup
|
||||
##########################################
|
||||
|
||||
# Start
|
||||
DocumentRoot "/"
|
||||
Listen 8080
|
||||
LoadModule python_module /usr/lib/apache2/modules/mod_python.so
|
||||
|
||||
<Location "/">
|
||||
PythonPath "sys.path+['/path/to/my/application']"
|
||||
SetHandler python-program
|
||||
PythonHandler cherrypy._cpmodpy::handler
|
||||
PythonOption cherrypy.setup myapp::setup_server
|
||||
PythonDebug On
|
||||
</Location>
|
||||
# End
|
||||
|
||||
The actual path to your mod_python.so is dependent on your
|
||||
environment. In this case we suppose a global mod_python
|
||||
installation on a Linux distribution such as Ubuntu.
|
||||
|
||||
We do set the PythonPath configuration setting so that
|
||||
your application can be found by from the user running
|
||||
the apache2 instance. Of course if your application
|
||||
resides in the global site-package this won't be needed.
|
||||
|
||||
Then restart apache2 and access http://127.0.0.1:8080
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import cherrypy
|
||||
from cherrypy._cpcompat import BytesIO, copyitems, ntob
|
||||
from cherrypy._cperror import format_exc, bare_error
|
||||
from cherrypy.lib import httputil
|
||||
|
||||
|
||||
# ------------------------------ Request-handling
|
||||
|
||||
|
||||
def setup(req):
|
||||
from mod_python import apache
|
||||
|
||||
# Run any setup functions defined by a "PythonOption cherrypy.setup"
|
||||
# directive.
|
||||
options = req.get_options()
|
||||
if 'cherrypy.setup' in options:
|
||||
for function in options['cherrypy.setup'].split():
|
||||
atoms = function.split('::', 1)
|
||||
if len(atoms) == 1:
|
||||
mod = __import__(atoms[0], globals(), locals())
|
||||
else:
|
||||
modname, fname = atoms
|
||||
mod = __import__(modname, globals(), locals(), [fname])
|
||||
func = getattr(mod, fname)
|
||||
func()
|
||||
|
||||
cherrypy.config.update({'log.screen': False,
|
||||
"tools.ignore_headers.on": True,
|
||||
"tools.ignore_headers.headers": ['Range'],
|
||||
})
|
||||
|
||||
engine = cherrypy.engine
|
||||
if hasattr(engine, "signal_handler"):
|
||||
engine.signal_handler.unsubscribe()
|
||||
if hasattr(engine, "console_control_handler"):
|
||||
engine.console_control_handler.unsubscribe()
|
||||
engine.autoreload.unsubscribe()
|
||||
cherrypy.server.unsubscribe()
|
||||
|
||||
def _log(msg, level):
|
||||
newlevel = apache.APLOG_ERR
|
||||
if logging.DEBUG >= level:
|
||||
newlevel = apache.APLOG_DEBUG
|
||||
elif logging.INFO >= level:
|
||||
newlevel = apache.APLOG_INFO
|
||||
elif logging.WARNING >= level:
|
||||
newlevel = apache.APLOG_WARNING
|
||||
# On Windows, req.server is required or the msg will vanish. See
|
||||
# http://www.modpython.org/pipermail/mod_python/2003-October/014291.html
|
||||
# Also, "When server is not specified...LogLevel does not apply..."
|
||||
apache.log_error(msg, newlevel, req.server)
|
||||
engine.subscribe('log', _log)
|
||||
|
||||
engine.start()
|
||||
|
||||
def cherrypy_cleanup(data):
|
||||
engine.exit()
|
||||
try:
|
||||
# apache.register_cleanup wasn't available until 3.1.4.
|
||||
apache.register_cleanup(cherrypy_cleanup)
|
||||
except AttributeError:
|
||||
req.server.register_cleanup(req, cherrypy_cleanup)
|
||||
|
||||
|
||||
class _ReadOnlyRequest:
|
||||
expose = ('read', 'readline', 'readlines')
|
||||
|
||||
def __init__(self, req):
|
||||
for method in self.expose:
|
||||
self.__dict__[method] = getattr(req, method)
|
||||
|
||||
|
||||
recursive = False
|
||||
|
||||
_isSetUp = False
|
||||
|
||||
|
||||
def handler(req):
|
||||
from mod_python import apache
|
||||
try:
|
||||
global _isSetUp
|
||||
if not _isSetUp:
|
||||
setup(req)
|
||||
_isSetUp = True
|
||||
|
||||
# Obtain a Request object from CherryPy
|
||||
local = req.connection.local_addr
|
||||
local = httputil.Host(
|
||||
local[0], local[1], req.connection.local_host or "")
|
||||
remote = req.connection.remote_addr
|
||||
remote = httputil.Host(
|
||||
remote[0], remote[1], req.connection.remote_host or "")
|
||||
|
||||
scheme = req.parsed_uri[0] or 'http'
|
||||
req.get_basic_auth_pw()
|
||||
|
||||
try:
|
||||
# apache.mpm_query only became available in mod_python 3.1
|
||||
q = apache.mpm_query
|
||||
threaded = q(apache.AP_MPMQ_IS_THREADED)
|
||||
forked = q(apache.AP_MPMQ_IS_FORKED)
|
||||
except AttributeError:
|
||||
bad_value = ("You must provide a PythonOption '%s', "
|
||||
"either 'on' or 'off', when running a version "
|
||||
"of mod_python < 3.1")
|
||||
|
||||
threaded = options.get('multithread', '').lower()
|
||||
if threaded == 'on':
|
||||
threaded = True
|
||||
elif threaded == 'off':
|
||||
threaded = False
|
||||
else:
|
||||
raise ValueError(bad_value % "multithread")
|
||||
|
||||
forked = options.get('multiprocess', '').lower()
|
||||
if forked == 'on':
|
||||
forked = True
|
||||
elif forked == 'off':
|
||||
forked = False
|
||||
else:
|
||||
raise ValueError(bad_value % "multiprocess")
|
||||
|
||||
sn = cherrypy.tree.script_name(req.uri or "/")
|
||||
if sn is None:
|
||||
send_response(req, '404 Not Found', [], '')
|
||||
else:
|
||||
app = cherrypy.tree.apps[sn]
|
||||
method = req.method
|
||||
path = req.uri
|
||||
qs = req.args or ""
|
||||
reqproto = req.protocol
|
||||
headers = copyitems(req.headers_in)
|
||||
rfile = _ReadOnlyRequest(req)
|
||||
prev = None
|
||||
|
||||
try:
|
||||
redirections = []
|
||||
while True:
|
||||
request, response = app.get_serving(local, remote, scheme,
|
||||
"HTTP/1.1")
|
||||
request.login = req.user
|
||||
request.multithread = bool(threaded)
|
||||
request.multiprocess = bool(forked)
|
||||
request.app = app
|
||||
request.prev = prev
|
||||
|
||||
# Run the CherryPy Request object and obtain the response
|
||||
try:
|
||||
request.run(method, path, qs, reqproto, headers, rfile)
|
||||
break
|
||||
except cherrypy.InternalRedirect:
|
||||
ir = sys.exc_info()[1]
|
||||
app.release_serving()
|
||||
prev = request
|
||||
|
||||
if not recursive:
|
||||
if ir.path in redirections:
|
||||
raise RuntimeError(
|
||||
"InternalRedirector visited the same URL "
|
||||
"twice: %r" % ir.path)
|
||||
else:
|
||||
# Add the *previous* path_info + qs to
|
||||
# redirections.
|
||||
if qs:
|
||||
qs = "?" + qs
|
||||
redirections.append(sn + path + qs)
|
||||
|
||||
# Munge environment and try again.
|
||||
method = "GET"
|
||||
path = ir.path
|
||||
qs = ir.query_string
|
||||
rfile = BytesIO()
|
||||
|
||||
send_response(
|
||||
req, response.output_status, response.header_list,
|
||||
response.body, response.stream)
|
||||
finally:
|
||||
app.release_serving()
|
||||
except:
|
||||
tb = format_exc()
|
||||
cherrypy.log(tb, 'MOD_PYTHON', severity=logging.ERROR)
|
||||
s, h, b = bare_error()
|
||||
send_response(req, s, h, b)
|
||||
return apache.OK
|
||||
|
||||
|
||||
def send_response(req, status, headers, body, stream=False):
|
||||
# Set response status
|
||||
req.status = int(status[:3])
|
||||
|
||||
# Set response headers
|
||||
req.content_type = "text/plain"
|
||||
for header, value in headers:
|
||||
if header.lower() == 'content-type':
|
||||
req.content_type = value
|
||||
continue
|
||||
req.headers_out.add(header, value)
|
||||
|
||||
if stream:
|
||||
# Flush now so the status and headers are sent immediately.
|
||||
req.flush()
|
||||
|
||||
# Set response body
|
||||
if isinstance(body, basestring):
|
||||
req.write(body)
|
||||
else:
|
||||
for seg in body:
|
||||
req.write(seg)
|
||||
|
||||
|
||||
# --------------- Startup tools for CherryPy + mod_python --------------- #
|
||||
import os
|
||||
import re
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
def popen(fullcmd):
|
||||
p = subprocess.Popen(fullcmd, shell=True,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
close_fds=True)
|
||||
return p.stdout
|
||||
except ImportError:
|
||||
def popen(fullcmd):
|
||||
pipein, pipeout = os.popen4(fullcmd)
|
||||
return pipeout
|
||||
|
||||
|
||||
def read_process(cmd, args=""):
|
||||
fullcmd = "%s %s" % (cmd, args)
|
||||
pipeout = popen(fullcmd)
|
||||
try:
|
||||
firstline = pipeout.readline()
|
||||
cmd_not_found = re.search(
|
||||
ntob("(not recognized|No such file|not found)"),
|
||||
firstline,
|
||||
re.IGNORECASE
|
||||
)
|
||||
if cmd_not_found:
|
||||
raise IOError('%s must be on your system path.' % cmd)
|
||||
output = firstline + pipeout.read()
|
||||
finally:
|
||||
pipeout.close()
|
||||
return output
|
||||
|
||||
|
||||
class ModPythonServer(object):
|
||||
|
||||
template = """
|
||||
# Apache2 server configuration file for running CherryPy with mod_python.
|
||||
|
||||
DocumentRoot "/"
|
||||
Listen %(port)s
|
||||
LoadModule python_module modules/mod_python.so
|
||||
|
||||
<Location %(loc)s>
|
||||
SetHandler python-program
|
||||
PythonHandler %(handler)s
|
||||
PythonDebug On
|
||||
%(opts)s
|
||||
</Location>
|
||||
"""
|
||||
|
||||
def __init__(self, loc="/", port=80, opts=None, apache_path="apache",
|
||||
handler="cherrypy._cpmodpy::handler"):
|
||||
self.loc = loc
|
||||
self.port = port
|
||||
self.opts = opts
|
||||
self.apache_path = apache_path
|
||||
self.handler = handler
|
||||
|
||||
def start(self):
|
||||
opts = "".join([" PythonOption %s %s\n" % (k, v)
|
||||
for k, v in self.opts])
|
||||
conf_data = self.template % {"port": self.port,
|
||||
"loc": self.loc,
|
||||
"opts": opts,
|
||||
"handler": self.handler,
|
||||
}
|
||||
|
||||
mpconf = os.path.join(os.path.dirname(__file__), "cpmodpy.conf")
|
||||
f = open(mpconf, 'wb')
|
||||
try:
|
||||
f.write(conf_data)
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
response = read_process(self.apache_path, "-k start -f %s" % mpconf)
|
||||
self.ready = True
|
||||
return response
|
||||
|
||||
def stop(self):
|
||||
os.popen("apache -k stop")
|
||||
self.ready = False
|
||||
@@ -0,0 +1,154 @@
|
||||
"""Native adapter for serving CherryPy via its builtin server."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import cherrypy
|
||||
from cherrypy._cpcompat import BytesIO
|
||||
from cherrypy._cperror import format_exc, bare_error
|
||||
from cherrypy.lib import httputil
|
||||
from cherrypy import wsgiserver
|
||||
|
||||
|
||||
class NativeGateway(wsgiserver.Gateway):
|
||||
|
||||
recursive = False
|
||||
|
||||
def respond(self):
|
||||
req = self.req
|
||||
try:
|
||||
# Obtain a Request object from CherryPy
|
||||
local = req.server.bind_addr
|
||||
local = httputil.Host(local[0], local[1], "")
|
||||
remote = req.conn.remote_addr, req.conn.remote_port
|
||||
remote = httputil.Host(remote[0], remote[1], "")
|
||||
|
||||
scheme = req.scheme
|
||||
sn = cherrypy.tree.script_name(req.uri or "/")
|
||||
if sn is None:
|
||||
self.send_response('404 Not Found', [], [''])
|
||||
else:
|
||||
app = cherrypy.tree.apps[sn]
|
||||
method = req.method
|
||||
path = req.path
|
||||
qs = req.qs or ""
|
||||
headers = req.inheaders.items()
|
||||
rfile = req.rfile
|
||||
prev = None
|
||||
|
||||
try:
|
||||
redirections = []
|
||||
while True:
|
||||
request, response = app.get_serving(
|
||||
local, remote, scheme, "HTTP/1.1")
|
||||
request.multithread = True
|
||||
request.multiprocess = False
|
||||
request.app = app
|
||||
request.prev = prev
|
||||
|
||||
# Run the CherryPy Request object and obtain the
|
||||
# response
|
||||
try:
|
||||
request.run(method, path, qs,
|
||||
req.request_protocol, headers, rfile)
|
||||
break
|
||||
except cherrypy.InternalRedirect:
|
||||
ir = sys.exc_info()[1]
|
||||
app.release_serving()
|
||||
prev = request
|
||||
|
||||
if not self.recursive:
|
||||
if ir.path in redirections:
|
||||
raise RuntimeError(
|
||||
"InternalRedirector visited the same "
|
||||
"URL twice: %r" % ir.path)
|
||||
else:
|
||||
# Add the *previous* path_info + qs to
|
||||
# redirections.
|
||||
if qs:
|
||||
qs = "?" + qs
|
||||
redirections.append(sn + path + qs)
|
||||
|
||||
# Munge environment and try again.
|
||||
method = "GET"
|
||||
path = ir.path
|
||||
qs = ir.query_string
|
||||
rfile = BytesIO()
|
||||
|
||||
self.send_response(
|
||||
response.output_status, response.header_list,
|
||||
response.body)
|
||||
finally:
|
||||
app.release_serving()
|
||||
except:
|
||||
tb = format_exc()
|
||||
# print tb
|
||||
cherrypy.log(tb, 'NATIVE_ADAPTER', severity=logging.ERROR)
|
||||
s, h, b = bare_error()
|
||||
self.send_response(s, h, b)
|
||||
|
||||
def send_response(self, status, headers, body):
|
||||
req = self.req
|
||||
|
||||
# Set response status
|
||||
req.status = str(status or "500 Server Error")
|
||||
|
||||
# Set response headers
|
||||
for header, value in headers:
|
||||
req.outheaders.append((header, value))
|
||||
if (req.ready and not req.sent_headers):
|
||||
req.sent_headers = True
|
||||
req.send_headers()
|
||||
|
||||
# Set response body
|
||||
for seg in body:
|
||||
req.write(seg)
|
||||
|
||||
|
||||
class CPHTTPServer(wsgiserver.HTTPServer):
|
||||
|
||||
"""Wrapper for wsgiserver.HTTPServer.
|
||||
|
||||
wsgiserver has been designed to not reference CherryPy in any way,
|
||||
so that it can be used in other frameworks and applications.
|
||||
Therefore, we wrap it here, so we can apply some attributes
|
||||
from config -> cherrypy.server -> HTTPServer.
|
||||
"""
|
||||
|
||||
def __init__(self, server_adapter=cherrypy.server):
|
||||
self.server_adapter = server_adapter
|
||||
|
||||
server_name = (self.server_adapter.socket_host or
|
||||
self.server_adapter.socket_file or
|
||||
None)
|
||||
|
||||
wsgiserver.HTTPServer.__init__(
|
||||
self, server_adapter.bind_addr, NativeGateway,
|
||||
minthreads=server_adapter.thread_pool,
|
||||
maxthreads=server_adapter.thread_pool_max,
|
||||
server_name=server_name)
|
||||
|
||||
self.max_request_header_size = (
|
||||
self.server_adapter.max_request_header_size or 0)
|
||||
self.max_request_body_size = (
|
||||
self.server_adapter.max_request_body_size or 0)
|
||||
self.request_queue_size = self.server_adapter.socket_queue_size
|
||||
self.timeout = self.server_adapter.socket_timeout
|
||||
self.shutdown_timeout = self.server_adapter.shutdown_timeout
|
||||
self.protocol = self.server_adapter.protocol_version
|
||||
self.nodelay = self.server_adapter.nodelay
|
||||
|
||||
ssl_module = self.server_adapter.ssl_module or 'pyopenssl'
|
||||
if self.server_adapter.ssl_context:
|
||||
adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module)
|
||||
self.ssl_adapter = adapter_class(
|
||||
self.server_adapter.ssl_certificate,
|
||||
self.server_adapter.ssl_private_key,
|
||||
self.server_adapter.ssl_certificate_chain)
|
||||
self.ssl_adapter.context = self.server_adapter.ssl_context
|
||||
elif self.server_adapter.ssl_certificate:
|
||||
adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module)
|
||||
self.ssl_adapter = adapter_class(
|
||||
self.server_adapter.ssl_certificate,
|
||||
self.server_adapter.ssl_private_key,
|
||||
self.server_adapter.ssl_certificate_chain)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,973 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import cherrypy
|
||||
from cherrypy._cpcompat import basestring, copykeys, ntob, unicodestr
|
||||
from cherrypy._cpcompat import SimpleCookie, CookieError, py3k
|
||||
from cherrypy import _cpreqbody, _cpconfig
|
||||
from cherrypy._cperror import format_exc, bare_error
|
||||
from cherrypy.lib import httputil, file_generator
|
||||
|
||||
|
||||
class Hook(object):
|
||||
|
||||
"""A callback and its metadata: failsafe, priority, and kwargs."""
|
||||
|
||||
callback = None
|
||||
"""
|
||||
The bare callable that this Hook object is wrapping, which will
|
||||
be called when the Hook is called."""
|
||||
|
||||
failsafe = False
|
||||
"""
|
||||
If True, the callback is guaranteed to run even if other callbacks
|
||||
from the same call point raise exceptions."""
|
||||
|
||||
priority = 50
|
||||
"""
|
||||
Defines the order of execution for a list of Hooks. Priority numbers
|
||||
should be limited to the closed interval [0, 100], but values outside
|
||||
this range are acceptable, as are fractional values."""
|
||||
|
||||
kwargs = {}
|
||||
"""
|
||||
A set of keyword arguments that will be passed to the
|
||||
callable on each call."""
|
||||
|
||||
def __init__(self, callback, failsafe=None, priority=None, **kwargs):
|
||||
self.callback = callback
|
||||
|
||||
if failsafe is None:
|
||||
failsafe = getattr(callback, "failsafe", False)
|
||||
self.failsafe = failsafe
|
||||
|
||||
if priority is None:
|
||||
priority = getattr(callback, "priority", 50)
|
||||
self.priority = priority
|
||||
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __lt__(self, other):
|
||||
# Python 3
|
||||
return self.priority < other.priority
|
||||
|
||||
def __cmp__(self, other):
|
||||
# Python 2
|
||||
return cmp(self.priority, other.priority)
|
||||
|
||||
def __call__(self):
|
||||
"""Run self.callback(**self.kwargs)."""
|
||||
return self.callback(**self.kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
cls = self.__class__
|
||||
return ("%s.%s(callback=%r, failsafe=%r, priority=%r, %s)"
|
||||
% (cls.__module__, cls.__name__, self.callback,
|
||||
self.failsafe, self.priority,
|
||||
", ".join(['%s=%r' % (k, v)
|
||||
for k, v in self.kwargs.items()])))
|
||||
|
||||
|
||||
class HookMap(dict):
|
||||
|
||||
"""A map of call points to lists of callbacks (Hook objects)."""
|
||||
|
||||
def __new__(cls, points=None):
|
||||
d = dict.__new__(cls)
|
||||
for p in points or []:
|
||||
d[p] = []
|
||||
return d
|
||||
|
||||
def __init__(self, *a, **kw):
|
||||
pass
|
||||
|
||||
def attach(self, point, callback, failsafe=None, priority=None, **kwargs):
|
||||
"""Append a new Hook made from the supplied arguments."""
|
||||
self[point].append(Hook(callback, failsafe, priority, **kwargs))
|
||||
|
||||
def run(self, point):
|
||||
"""Execute all registered Hooks (callbacks) for the given point."""
|
||||
exc = None
|
||||
hooks = self[point]
|
||||
hooks.sort()
|
||||
for hook in hooks:
|
||||
# Some hooks are guaranteed to run even if others at
|
||||
# the same hookpoint fail. We will still log the failure,
|
||||
# but proceed on to the next hook. The only way
|
||||
# to stop all processing from one of these hooks is
|
||||
# to raise SystemExit and stop the whole server.
|
||||
if exc is None or hook.failsafe:
|
||||
try:
|
||||
hook()
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
raise
|
||||
except (cherrypy.HTTPError, cherrypy.HTTPRedirect,
|
||||
cherrypy.InternalRedirect):
|
||||
exc = sys.exc_info()[1]
|
||||
except:
|
||||
exc = sys.exc_info()[1]
|
||||
cherrypy.log(traceback=True, severity=40)
|
||||
if exc:
|
||||
raise exc
|
||||
|
||||
def __copy__(self):
|
||||
newmap = self.__class__()
|
||||
# We can't just use 'update' because we want copies of the
|
||||
# mutable values (each is a list) as well.
|
||||
for k, v in self.items():
|
||||
newmap[k] = v[:]
|
||||
return newmap
|
||||
copy = __copy__
|
||||
|
||||
def __repr__(self):
|
||||
cls = self.__class__
|
||||
return "%s.%s(points=%r)" % (
|
||||
cls.__module__,
|
||||
cls.__name__,
|
||||
copykeys(self)
|
||||
)
|
||||
|
||||
|
||||
# Config namespace handlers
|
||||
|
||||
def hooks_namespace(k, v):
|
||||
"""Attach bare hooks declared in config."""
|
||||
# Use split again to allow multiple hooks for a single
|
||||
# hookpoint per path (e.g. "hooks.before_handler.1").
|
||||
# Little-known fact you only get from reading source ;)
|
||||
hookpoint = k.split(".", 1)[0]
|
||||
if isinstance(v, basestring):
|
||||
v = cherrypy.lib.attributes(v)
|
||||
if not isinstance(v, Hook):
|
||||
v = Hook(v)
|
||||
cherrypy.serving.request.hooks[hookpoint].append(v)
|
||||
|
||||
|
||||
def request_namespace(k, v):
|
||||
"""Attach request attributes declared in config."""
|
||||
# Provides config entries to set request.body attrs (like
|
||||
# attempt_charsets).
|
||||
if k[:5] == 'body.':
|
||||
setattr(cherrypy.serving.request.body, k[5:], v)
|
||||
else:
|
||||
setattr(cherrypy.serving.request, k, v)
|
||||
|
||||
|
||||
def response_namespace(k, v):
|
||||
"""Attach response attributes declared in config."""
|
||||
# Provides config entries to set default response headers
|
||||
# http://cherrypy.org/ticket/889
|
||||
if k[:8] == 'headers.':
|
||||
cherrypy.serving.response.headers[k.split('.', 1)[1]] = v
|
||||
else:
|
||||
setattr(cherrypy.serving.response, k, v)
|
||||
|
||||
|
||||
def error_page_namespace(k, v):
|
||||
"""Attach error pages declared in config."""
|
||||
if k != 'default':
|
||||
k = int(k)
|
||||
cherrypy.serving.request.error_page[k] = v
|
||||
|
||||
|
||||
hookpoints = ['on_start_resource', 'before_request_body',
|
||||
'before_handler', 'before_finalize',
|
||||
'on_end_resource', 'on_end_request',
|
||||
'before_error_response', 'after_error_response']
|
||||
|
||||
|
||||
class Request(object):
|
||||
|
||||
"""An HTTP request.
|
||||
|
||||
This object represents the metadata of an HTTP request message;
|
||||
that is, it contains attributes which describe the environment
|
||||
in which the request URL, headers, and body were sent (if you
|
||||
want tools to interpret the headers and body, those are elsewhere,
|
||||
mostly in Tools). This 'metadata' consists of socket data,
|
||||
transport characteristics, and the Request-Line. This object
|
||||
also contains data regarding the configuration in effect for
|
||||
the given URL, and the execution plan for generating a response.
|
||||
"""
|
||||
|
||||
prev = None
|
||||
"""
|
||||
The previous Request object (if any). This should be None
|
||||
unless we are processing an InternalRedirect."""
|
||||
|
||||
# Conversation/connection attributes
|
||||
local = httputil.Host("127.0.0.1", 80)
|
||||
"An httputil.Host(ip, port, hostname) object for the server socket."
|
||||
|
||||
remote = httputil.Host("127.0.0.1", 1111)
|
||||
"An httputil.Host(ip, port, hostname) object for the client socket."
|
||||
|
||||
scheme = "http"
|
||||
"""
|
||||
The protocol used between client and server. In most cases,
|
||||
this will be either 'http' or 'https'."""
|
||||
|
||||
server_protocol = "HTTP/1.1"
|
||||
"""
|
||||
The HTTP version for which the HTTP server is at least
|
||||
conditionally compliant."""
|
||||
|
||||
base = ""
|
||||
"""The (scheme://host) portion of the requested URL.
|
||||
In some cases (e.g. when proxying via mod_rewrite), this may contain
|
||||
path segments which cherrypy.url uses when constructing url's, but
|
||||
which otherwise are ignored by CherryPy. Regardless, this value
|
||||
MUST NOT end in a slash."""
|
||||
|
||||
# Request-Line attributes
|
||||
request_line = ""
|
||||
"""
|
||||
The complete Request-Line received from the client. This is a
|
||||
single string consisting of the request method, URI, and protocol
|
||||
version (joined by spaces). Any final CRLF is removed."""
|
||||
|
||||
method = "GET"
|
||||
"""
|
||||
Indicates the HTTP method to be performed on the resource identified
|
||||
by the Request-URI. Common methods include GET, HEAD, POST, PUT, and
|
||||
DELETE. CherryPy allows any extension method; however, various HTTP
|
||||
servers and gateways may restrict the set of allowable methods.
|
||||
CherryPy applications SHOULD restrict the set (on a per-URI basis)."""
|
||||
|
||||
query_string = ""
|
||||
"""
|
||||
The query component of the Request-URI, a string of information to be
|
||||
interpreted by the resource. The query portion of a URI follows the
|
||||
path component, and is separated by a '?'. For example, the URI
|
||||
'http://www.cherrypy.org/wiki?a=3&b=4' has the query component,
|
||||
'a=3&b=4'."""
|
||||
|
||||
query_string_encoding = 'utf8'
|
||||
"""
|
||||
The encoding expected for query string arguments after % HEX HEX decoding).
|
||||
If a query string is provided that cannot be decoded with this encoding,
|
||||
404 is raised (since technically it's a different URI). If you want
|
||||
arbitrary encodings to not error, set this to 'Latin-1'; you can then
|
||||
encode back to bytes and re-decode to whatever encoding you like later.
|
||||
"""
|
||||
|
||||
protocol = (1, 1)
|
||||
"""The HTTP protocol version corresponding to the set
|
||||
of features which should be allowed in the response. If BOTH
|
||||
the client's request message AND the server's level of HTTP
|
||||
compliance is HTTP/1.1, this attribute will be the tuple (1, 1).
|
||||
If either is 1.0, this attribute will be the tuple (1, 0).
|
||||
Lower HTTP protocol versions are not explicitly supported."""
|
||||
|
||||
params = {}
|
||||
"""
|
||||
A dict which combines query string (GET) and request entity (POST)
|
||||
variables. This is populated in two stages: GET params are added
|
||||
before the 'on_start_resource' hook, and POST params are added
|
||||
between the 'before_request_body' and 'before_handler' hooks."""
|
||||
|
||||
# Message attributes
|
||||
header_list = []
|
||||
"""
|
||||
A list of the HTTP request headers as (name, value) tuples.
|
||||
In general, you should use request.headers (a dict) instead."""
|
||||
|
||||
headers = httputil.HeaderMap()
|
||||
"""
|
||||
A dict-like object containing the request headers. Keys are header
|
||||
names (in Title-Case format); however, you may get and set them in
|
||||
a case-insensitive manner. That is, headers['Content-Type'] and
|
||||
headers['content-type'] refer to the same value. Values are header
|
||||
values (decoded according to :rfc:`2047` if necessary). See also:
|
||||
httputil.HeaderMap, httputil.HeaderElement."""
|
||||
|
||||
cookie = SimpleCookie()
|
||||
"""See help(Cookie)."""
|
||||
|
||||
rfile = None
|
||||
"""
|
||||
If the request included an entity (body), it will be available
|
||||
as a stream in this attribute. However, the rfile will normally
|
||||
be read for you between the 'before_request_body' hook and the
|
||||
'before_handler' hook, and the resulting string is placed into
|
||||
either request.params or the request.body attribute.
|
||||
|
||||
You may disable the automatic consumption of the rfile by setting
|
||||
request.process_request_body to False, either in config for the desired
|
||||
path, or in an 'on_start_resource' or 'before_request_body' hook.
|
||||
|
||||
WARNING: In almost every case, you should not attempt to read from the
|
||||
rfile stream after CherryPy's automatic mechanism has read it. If you
|
||||
turn off the automatic parsing of rfile, you should read exactly the
|
||||
number of bytes specified in request.headers['Content-Length'].
|
||||
Ignoring either of these warnings may result in a hung request thread
|
||||
or in corruption of the next (pipelined) request.
|
||||
"""
|
||||
|
||||
process_request_body = True
|
||||
"""
|
||||
If True, the rfile (if any) is automatically read and parsed,
|
||||
and the result placed into request.params or request.body."""
|
||||
|
||||
methods_with_bodies = ("POST", "PUT")
|
||||
"""
|
||||
A sequence of HTTP methods for which CherryPy will automatically
|
||||
attempt to read a body from the rfile. If you are going to change
|
||||
this property, modify it on the configuration (recommended)
|
||||
or on the "hook point" `on_start_resource`.
|
||||
"""
|
||||
|
||||
body = None
|
||||
"""
|
||||
If the request Content-Type is 'application/x-www-form-urlencoded'
|
||||
or multipart, this will be None. Otherwise, this will be an instance
|
||||
of :class:`RequestBody<cherrypy._cpreqbody.RequestBody>` (which you
|
||||
can .read()); this value is set between the 'before_request_body' and
|
||||
'before_handler' hooks (assuming that process_request_body is True)."""
|
||||
|
||||
# Dispatch attributes
|
||||
dispatch = cherrypy.dispatch.Dispatcher()
|
||||
"""
|
||||
The object which looks up the 'page handler' callable and collects
|
||||
config for the current request based on the path_info, other
|
||||
request attributes, and the application architecture. The core
|
||||
calls the dispatcher as early as possible, passing it a 'path_info'
|
||||
argument.
|
||||
|
||||
The default dispatcher discovers the page handler by matching path_info
|
||||
to a hierarchical arrangement of objects, starting at request.app.root.
|
||||
See help(cherrypy.dispatch) for more information."""
|
||||
|
||||
script_name = ""
|
||||
"""
|
||||
The 'mount point' of the application which is handling this request.
|
||||
|
||||
This attribute MUST NOT end in a slash. If the script_name refers to
|
||||
the root of the URI, it MUST be an empty string (not "/").
|
||||
"""
|
||||
|
||||
path_info = "/"
|
||||
"""
|
||||
The 'relative path' portion of the Request-URI. This is relative
|
||||
to the script_name ('mount point') of the application which is
|
||||
handling this request."""
|
||||
|
||||
login = None
|
||||
"""
|
||||
When authentication is used during the request processing this is
|
||||
set to 'False' if it failed and to the 'username' value if it succeeded.
|
||||
The default 'None' implies that no authentication happened."""
|
||||
|
||||
# Note that cherrypy.url uses "if request.app:" to determine whether
|
||||
# the call is during a real HTTP request or not. So leave this None.
|
||||
app = None
|
||||
"""The cherrypy.Application object which is handling this request."""
|
||||
|
||||
handler = None
|
||||
"""
|
||||
The function, method, or other callable which CherryPy will call to
|
||||
produce the response. The discovery of the handler and the arguments
|
||||
it will receive are determined by the request.dispatch object.
|
||||
By default, the handler is discovered by walking a tree of objects
|
||||
starting at request.app.root, and is then passed all HTTP params
|
||||
(from the query string and POST body) as keyword arguments."""
|
||||
|
||||
toolmaps = {}
|
||||
"""
|
||||
A nested dict of all Toolboxes and Tools in effect for this request,
|
||||
of the form: {Toolbox.namespace: {Tool.name: config dict}}."""
|
||||
|
||||
config = None
|
||||
"""
|
||||
A flat dict of all configuration entries which apply to the
|
||||
current request. These entries are collected from global config,
|
||||
application config (based on request.path_info), and from handler
|
||||
config (exactly how is governed by the request.dispatch object in
|
||||
effect for this request; by default, handler config can be attached
|
||||
anywhere in the tree between request.app.root and the final handler,
|
||||
and inherits downward)."""
|
||||
|
||||
is_index = None
|
||||
"""
|
||||
This will be True if the current request is mapped to an 'index'
|
||||
resource handler (also, a 'default' handler if path_info ends with
|
||||
a slash). The value may be used to automatically redirect the
|
||||
user-agent to a 'more canonical' URL which either adds or removes
|
||||
the trailing slash. See cherrypy.tools.trailing_slash."""
|
||||
|
||||
hooks = HookMap(hookpoints)
|
||||
"""
|
||||
A HookMap (dict-like object) of the form: {hookpoint: [hook, ...]}.
|
||||
Each key is a str naming the hook point, and each value is a list
|
||||
of hooks which will be called at that hook point during this request.
|
||||
The list of hooks is generally populated as early as possible (mostly
|
||||
from Tools specified in config), but may be extended at any time.
|
||||
See also: _cprequest.Hook, _cprequest.HookMap, and cherrypy.tools."""
|
||||
|
||||
error_response = cherrypy.HTTPError(500).set_response
|
||||
"""
|
||||
The no-arg callable which will handle unexpected, untrapped errors
|
||||
during request processing. This is not used for expected exceptions
|
||||
(like NotFound, HTTPError, or HTTPRedirect) which are raised in
|
||||
response to expected conditions (those should be customized either
|
||||
via request.error_page or by overriding HTTPError.set_response).
|
||||
By default, error_response uses HTTPError(500) to return a generic
|
||||
error response to the user-agent."""
|
||||
|
||||
error_page = {}
|
||||
"""
|
||||
A dict of {error code: response filename or callable} pairs.
|
||||
|
||||
The error code must be an int representing a given HTTP error code,
|
||||
or the string 'default', which will be used if no matching entry
|
||||
is found for a given numeric code.
|
||||
|
||||
If a filename is provided, the file should contain a Python string-
|
||||
formatting template, and can expect by default to receive format
|
||||
values with the mapping keys %(status)s, %(message)s, %(traceback)s,
|
||||
and %(version)s. The set of format mappings can be extended by
|
||||
overriding HTTPError.set_response.
|
||||
|
||||
If a callable is provided, it will be called by default with keyword
|
||||
arguments 'status', 'message', 'traceback', and 'version', as for a
|
||||
string-formatting template. The callable must return a string or
|
||||
iterable of strings which will be set to response.body. It may also
|
||||
override headers or perform any other processing.
|
||||
|
||||
If no entry is given for an error code, and no 'default' entry exists,
|
||||
a default template will be used.
|
||||
"""
|
||||
|
||||
show_tracebacks = True
|
||||
"""
|
||||
If True, unexpected errors encountered during request processing will
|
||||
include a traceback in the response body."""
|
||||
|
||||
show_mismatched_params = True
|
||||
"""
|
||||
If True, mismatched parameters encountered during PageHandler invocation
|
||||
processing will be included in the response body."""
|
||||
|
||||
throws = (KeyboardInterrupt, SystemExit, cherrypy.InternalRedirect)
|
||||
"""The sequence of exceptions which Request.run does not trap."""
|
||||
|
||||
throw_errors = False
|
||||
"""
|
||||
If True, Request.run will not trap any errors (except HTTPRedirect and
|
||||
HTTPError, which are more properly called 'exceptions', not errors)."""
|
||||
|
||||
closed = False
|
||||
"""True once the close method has been called, False otherwise."""
|
||||
|
||||
stage = None
|
||||
"""
|
||||
A string containing the stage reached in the request-handling process.
|
||||
This is useful when debugging a live server with hung requests."""
|
||||
|
||||
namespaces = _cpconfig.NamespaceSet(
|
||||
**{"hooks": hooks_namespace,
|
||||
"request": request_namespace,
|
||||
"response": response_namespace,
|
||||
"error_page": error_page_namespace,
|
||||
"tools": cherrypy.tools,
|
||||
})
|
||||
|
||||
def __init__(self, local_host, remote_host, scheme="http",
|
||||
server_protocol="HTTP/1.1"):
|
||||
"""Populate a new Request object.
|
||||
|
||||
local_host should be an httputil.Host object with the server info.
|
||||
remote_host should be an httputil.Host object with the client info.
|
||||
scheme should be a string, either "http" or "https".
|
||||
"""
|
||||
self.local = local_host
|
||||
self.remote = remote_host
|
||||
self.scheme = scheme
|
||||
self.server_protocol = server_protocol
|
||||
|
||||
self.closed = False
|
||||
|
||||
# Put a *copy* of the class error_page into self.
|
||||
self.error_page = self.error_page.copy()
|
||||
|
||||
# Put a *copy* of the class namespaces into self.
|
||||
self.namespaces = self.namespaces.copy()
|
||||
|
||||
self.stage = None
|
||||
|
||||
def close(self):
|
||||
"""Run cleanup code. (Core)"""
|
||||
if not self.closed:
|
||||
self.closed = True
|
||||
self.stage = 'on_end_request'
|
||||
self.hooks.run('on_end_request')
|
||||
self.stage = 'close'
|
||||
|
||||
def run(self, method, path, query_string, req_protocol, headers, rfile):
|
||||
r"""Process the Request. (Core)
|
||||
|
||||
method, path, query_string, and req_protocol should be pulled directly
|
||||
from the Request-Line (e.g. "GET /path?key=val HTTP/1.0").
|
||||
|
||||
path
|
||||
This should be %XX-unquoted, but query_string should not be.
|
||||
|
||||
When using Python 2, they both MUST be byte strings,
|
||||
not unicode strings.
|
||||
|
||||
When using Python 3, they both MUST be unicode strings,
|
||||
not byte strings, and preferably not bytes \x00-\xFF
|
||||
disguised as unicode.
|
||||
|
||||
headers
|
||||
A list of (name, value) tuples.
|
||||
|
||||
rfile
|
||||
A file-like object containing the HTTP request entity.
|
||||
|
||||
When run() is done, the returned object should have 3 attributes:
|
||||
|
||||
* status, e.g. "200 OK"
|
||||
* header_list, a list of (name, value) tuples
|
||||
* body, an iterable yielding strings
|
||||
|
||||
Consumer code (HTTP servers) should then access these response
|
||||
attributes to build the outbound stream.
|
||||
|
||||
"""
|
||||
response = cherrypy.serving.response
|
||||
self.stage = 'run'
|
||||
try:
|
||||
self.error_response = cherrypy.HTTPError(500).set_response
|
||||
|
||||
self.method = method
|
||||
path = path or "/"
|
||||
self.query_string = query_string or ''
|
||||
self.params = {}
|
||||
|
||||
# Compare request and server HTTP protocol versions, in case our
|
||||
# server does not support the requested protocol. Limit our output
|
||||
# to min(req, server). We want the following output:
|
||||
# request server actual written supported response
|
||||
# protocol protocol response protocol feature set
|
||||
# a 1.0 1.0 1.0 1.0
|
||||
# b 1.0 1.1 1.1 1.0
|
||||
# c 1.1 1.0 1.0 1.0
|
||||
# d 1.1 1.1 1.1 1.1
|
||||
# Notice that, in (b), the response will be "HTTP/1.1" even though
|
||||
# the client only understands 1.0. RFC 2616 10.5.6 says we should
|
||||
# only return 505 if the _major_ version is different.
|
||||
rp = int(req_protocol[5]), int(req_protocol[7])
|
||||
sp = int(self.server_protocol[5]), int(self.server_protocol[7])
|
||||
self.protocol = min(rp, sp)
|
||||
response.headers.protocol = self.protocol
|
||||
|
||||
# Rebuild first line of the request (e.g. "GET /path HTTP/1.0").
|
||||
url = path
|
||||
if query_string:
|
||||
url += '?' + query_string
|
||||
self.request_line = '%s %s %s' % (method, url, req_protocol)
|
||||
|
||||
self.header_list = list(headers)
|
||||
self.headers = httputil.HeaderMap()
|
||||
|
||||
self.rfile = rfile
|
||||
self.body = None
|
||||
|
||||
self.cookie = SimpleCookie()
|
||||
self.handler = None
|
||||
|
||||
# path_info should be the path from the
|
||||
# app root (script_name) to the handler.
|
||||
self.script_name = self.app.script_name
|
||||
self.path_info = pi = path[len(self.script_name):]
|
||||
|
||||
self.stage = 'respond'
|
||||
self.respond(pi)
|
||||
|
||||
except self.throws:
|
||||
raise
|
||||
except:
|
||||
if self.throw_errors:
|
||||
raise
|
||||
else:
|
||||
# Failure in setup, error handler or finalize. Bypass them.
|
||||
# Can't use handle_error because we may not have hooks yet.
|
||||
cherrypy.log(traceback=True, severity=40)
|
||||
if self.show_tracebacks:
|
||||
body = format_exc()
|
||||
else:
|
||||
body = ""
|
||||
r = bare_error(body)
|
||||
response.output_status, response.header_list, response.body = r
|
||||
|
||||
if self.method == "HEAD":
|
||||
# HEAD requests MUST NOT return a message-body in the response.
|
||||
response.body = []
|
||||
|
||||
try:
|
||||
cherrypy.log.access()
|
||||
except:
|
||||
cherrypy.log.error(traceback=True)
|
||||
|
||||
if response.timed_out:
|
||||
raise cherrypy.TimeoutError()
|
||||
|
||||
return response
|
||||
|
||||
# Uncomment for stage debugging
|
||||
# stage = property(lambda self: self._stage, lambda self, v: print(v))
|
||||
|
||||
def respond(self, path_info):
|
||||
"""Generate a response for the resource at self.path_info. (Core)"""
|
||||
response = cherrypy.serving.response
|
||||
try:
|
||||
try:
|
||||
try:
|
||||
if self.app is None:
|
||||
raise cherrypy.NotFound()
|
||||
|
||||
# Get the 'Host' header, so we can HTTPRedirect properly.
|
||||
self.stage = 'process_headers'
|
||||
self.process_headers()
|
||||
|
||||
# Make a copy of the class hooks
|
||||
self.hooks = self.__class__.hooks.copy()
|
||||
self.toolmaps = {}
|
||||
|
||||
self.stage = 'get_resource'
|
||||
self.get_resource(path_info)
|
||||
|
||||
self.body = _cpreqbody.RequestBody(
|
||||
self.rfile, self.headers, request_params=self.params)
|
||||
|
||||
self.namespaces(self.config)
|
||||
|
||||
self.stage = 'on_start_resource'
|
||||
self.hooks.run('on_start_resource')
|
||||
|
||||
# Parse the querystring
|
||||
self.stage = 'process_query_string'
|
||||
self.process_query_string()
|
||||
|
||||
# Process the body
|
||||
if self.process_request_body:
|
||||
if self.method not in self.methods_with_bodies:
|
||||
self.process_request_body = False
|
||||
self.stage = 'before_request_body'
|
||||
self.hooks.run('before_request_body')
|
||||
if self.process_request_body:
|
||||
self.body.process()
|
||||
|
||||
# Run the handler
|
||||
self.stage = 'before_handler'
|
||||
self.hooks.run('before_handler')
|
||||
if self.handler:
|
||||
self.stage = 'handler'
|
||||
response.body = self.handler()
|
||||
|
||||
# Finalize
|
||||
self.stage = 'before_finalize'
|
||||
self.hooks.run('before_finalize')
|
||||
response.finalize()
|
||||
except (cherrypy.HTTPRedirect, cherrypy.HTTPError):
|
||||
inst = sys.exc_info()[1]
|
||||
inst.set_response()
|
||||
self.stage = 'before_finalize (HTTPError)'
|
||||
self.hooks.run('before_finalize')
|
||||
response.finalize()
|
||||
finally:
|
||||
self.stage = 'on_end_resource'
|
||||
self.hooks.run('on_end_resource')
|
||||
except self.throws:
|
||||
raise
|
||||
except:
|
||||
if self.throw_errors:
|
||||
raise
|
||||
self.handle_error()
|
||||
|
||||
def process_query_string(self):
|
||||
"""Parse the query string into Python structures. (Core)"""
|
||||
try:
|
||||
p = httputil.parse_query_string(
|
||||
self.query_string, encoding=self.query_string_encoding)
|
||||
except UnicodeDecodeError:
|
||||
raise cherrypy.HTTPError(
|
||||
404, "The given query string could not be processed. Query "
|
||||
"strings for this resource must be encoded with %r." %
|
||||
self.query_string_encoding)
|
||||
|
||||
# Python 2 only: keyword arguments must be byte strings (type 'str').
|
||||
if not py3k:
|
||||
for key, value in p.items():
|
||||
if isinstance(key, unicode):
|
||||
del p[key]
|
||||
p[key.encode(self.query_string_encoding)] = value
|
||||
self.params.update(p)
|
||||
|
||||
def process_headers(self):
|
||||
"""Parse HTTP header data into Python structures. (Core)"""
|
||||
# Process the headers into self.headers
|
||||
headers = self.headers
|
||||
for name, value in self.header_list:
|
||||
# Call title() now (and use dict.__method__(headers))
|
||||
# so title doesn't have to be called twice.
|
||||
name = name.title()
|
||||
value = value.strip()
|
||||
|
||||
# Warning: if there is more than one header entry for cookies
|
||||
# (AFAIK, only Konqueror does that), only the last one will
|
||||
# remain in headers (but they will be correctly stored in
|
||||
# request.cookie).
|
||||
if "=?" in value:
|
||||
dict.__setitem__(headers, name, httputil.decode_TEXT(value))
|
||||
else:
|
||||
dict.__setitem__(headers, name, value)
|
||||
|
||||
# Handle cookies differently because on Konqueror, multiple
|
||||
# cookies come on different lines with the same key
|
||||
if name == 'Cookie':
|
||||
try:
|
||||
self.cookie.load(value)
|
||||
except CookieError:
|
||||
msg = "Illegal cookie name %s" % value.split('=')[0]
|
||||
raise cherrypy.HTTPError(400, msg)
|
||||
|
||||
if not dict.__contains__(headers, 'Host'):
|
||||
# All Internet-based HTTP/1.1 servers MUST respond with a 400
|
||||
# (Bad Request) status code to any HTTP/1.1 request message
|
||||
# which lacks a Host header field.
|
||||
if self.protocol >= (1, 1):
|
||||
msg = "HTTP/1.1 requires a 'Host' request header."
|
||||
raise cherrypy.HTTPError(400, msg)
|
||||
host = dict.get(headers, 'Host')
|
||||
if not host:
|
||||
host = self.local.name or self.local.ip
|
||||
self.base = "%s://%s" % (self.scheme, host)
|
||||
|
||||
def get_resource(self, path):
|
||||
"""Call a dispatcher (which sets self.handler and .config). (Core)"""
|
||||
# First, see if there is a custom dispatch at this URI. Custom
|
||||
# dispatchers can only be specified in app.config, not in _cp_config
|
||||
# (since custom dispatchers may not even have an app.root).
|
||||
dispatch = self.app.find_config(
|
||||
path, "request.dispatch", self.dispatch)
|
||||
|
||||
# dispatch() should set self.handler and self.config
|
||||
dispatch(path)
|
||||
|
||||
def handle_error(self):
|
||||
"""Handle the last unanticipated exception. (Core)"""
|
||||
try:
|
||||
self.hooks.run("before_error_response")
|
||||
if self.error_response:
|
||||
self.error_response()
|
||||
self.hooks.run("after_error_response")
|
||||
cherrypy.serving.response.finalize()
|
||||
except cherrypy.HTTPRedirect:
|
||||
inst = sys.exc_info()[1]
|
||||
inst.set_response()
|
||||
cherrypy.serving.response.finalize()
|
||||
|
||||
# ------------------------- Properties ------------------------- #
|
||||
|
||||
def _get_body_params(self):
|
||||
warnings.warn(
|
||||
"body_params is deprecated in CherryPy 3.2, will be removed in "
|
||||
"CherryPy 3.3.",
|
||||
DeprecationWarning
|
||||
)
|
||||
return self.body.params
|
||||
body_params = property(_get_body_params,
|
||||
doc="""
|
||||
If the request Content-Type is 'application/x-www-form-urlencoded' or
|
||||
multipart, this will be a dict of the params pulled from the entity
|
||||
body; that is, it will be the portion of request.params that come
|
||||
from the message body (sometimes called "POST params", although they
|
||||
can be sent with various HTTP method verbs). This value is set between
|
||||
the 'before_request_body' and 'before_handler' hooks (assuming that
|
||||
process_request_body is True).
|
||||
|
||||
Deprecated in 3.2, will be removed for 3.3 in favor of
|
||||
:attr:`request.body.params<cherrypy._cprequest.RequestBody.params>`.""")
|
||||
|
||||
|
||||
class ResponseBody(object):
|
||||
|
||||
"""The body of the HTTP response (the response entity)."""
|
||||
|
||||
if py3k:
|
||||
unicode_err = ("Page handlers MUST return bytes. Use tools.encode "
|
||||
"if you wish to return unicode.")
|
||||
|
||||
def __get__(self, obj, objclass=None):
|
||||
if obj is None:
|
||||
# When calling on the class instead of an instance...
|
||||
return self
|
||||
else:
|
||||
return obj._body
|
||||
|
||||
def __set__(self, obj, value):
|
||||
# Convert the given value to an iterable object.
|
||||
if py3k and isinstance(value, str):
|
||||
raise ValueError(self.unicode_err)
|
||||
|
||||
if isinstance(value, basestring):
|
||||
# strings get wrapped in a list because iterating over a single
|
||||
# item list is much faster than iterating over every character
|
||||
# in a long string.
|
||||
if value:
|
||||
value = [value]
|
||||
else:
|
||||
# [''] doesn't evaluate to False, so replace it with [].
|
||||
value = []
|
||||
elif py3k and isinstance(value, list):
|
||||
# every item in a list must be bytes...
|
||||
for i, item in enumerate(value):
|
||||
if isinstance(item, str):
|
||||
raise ValueError(self.unicode_err)
|
||||
# Don't use isinstance here; io.IOBase which has an ABC takes
|
||||
# 1000 times as long as, say, isinstance(value, str)
|
||||
elif hasattr(value, 'read'):
|
||||
value = file_generator(value)
|
||||
elif value is None:
|
||||
value = []
|
||||
obj._body = value
|
||||
|
||||
|
||||
class Response(object):
|
||||
|
||||
"""An HTTP Response, including status, headers, and body."""
|
||||
|
||||
status = ""
|
||||
"""The HTTP Status-Code and Reason-Phrase."""
|
||||
|
||||
header_list = []
|
||||
"""
|
||||
A list of the HTTP response headers as (name, value) tuples.
|
||||
In general, you should use response.headers (a dict) instead. This
|
||||
attribute is generated from response.headers and is not valid until
|
||||
after the finalize phase."""
|
||||
|
||||
headers = httputil.HeaderMap()
|
||||
"""
|
||||
A dict-like object containing the response headers. Keys are header
|
||||
names (in Title-Case format); however, you may get and set them in
|
||||
a case-insensitive manner. That is, headers['Content-Type'] and
|
||||
headers['content-type'] refer to the same value. Values are header
|
||||
values (decoded according to :rfc:`2047` if necessary).
|
||||
|
||||
.. seealso:: classes :class:`HeaderMap`, :class:`HeaderElement`
|
||||
"""
|
||||
|
||||
cookie = SimpleCookie()
|
||||
"""See help(Cookie)."""
|
||||
|
||||
body = ResponseBody()
|
||||
"""The body (entity) of the HTTP response."""
|
||||
|
||||
time = None
|
||||
"""The value of time.time() when created. Use in HTTP dates."""
|
||||
|
||||
timeout = 300
|
||||
"""Seconds after which the response will be aborted."""
|
||||
|
||||
timed_out = False
|
||||
"""
|
||||
Flag to indicate the response should be aborted, because it has
|
||||
exceeded its timeout."""
|
||||
|
||||
stream = False
|
||||
"""If False, buffer the response body."""
|
||||
|
||||
def __init__(self):
|
||||
self.status = None
|
||||
self.header_list = None
|
||||
self._body = []
|
||||
self.time = time.time()
|
||||
|
||||
self.headers = httputil.HeaderMap()
|
||||
# Since we know all our keys are titled strings, we can
|
||||
# bypass HeaderMap.update and get a big speed boost.
|
||||
dict.update(self.headers, {
|
||||
"Content-Type": 'text/html',
|
||||
"Server": "CherryPy/" + cherrypy.__version__,
|
||||
"Date": httputil.HTTPDate(self.time),
|
||||
})
|
||||
self.cookie = SimpleCookie()
|
||||
|
||||
def collapse_body(self):
|
||||
"""Collapse self.body to a single string; replace it and return it."""
|
||||
if isinstance(self.body, basestring):
|
||||
return self.body
|
||||
|
||||
newbody = []
|
||||
for chunk in self.body:
|
||||
if py3k and not isinstance(chunk, bytes):
|
||||
raise TypeError("Chunk %s is not of type 'bytes'." %
|
||||
repr(chunk))
|
||||
newbody.append(chunk)
|
||||
newbody = ntob('').join(newbody)
|
||||
|
||||
self.body = newbody
|
||||
return newbody
|
||||
|
||||
def finalize(self):
|
||||
"""Transform headers (and cookies) into self.header_list. (Core)"""
|
||||
try:
|
||||
code, reason, _ = httputil.valid_status(self.status)
|
||||
except ValueError:
|
||||
raise cherrypy.HTTPError(500, sys.exc_info()[1].args[0])
|
||||
|
||||
headers = self.headers
|
||||
|
||||
self.status = "%s %s" % (code, reason)
|
||||
self.output_status = ntob(str(code), 'ascii') + \
|
||||
ntob(" ") + headers.encode(reason)
|
||||
|
||||
if self.stream:
|
||||
# The upshot: wsgiserver will chunk the response if
|
||||
# you pop Content-Length (or set it explicitly to None).
|
||||
# Note that lib.static sets C-L to the file's st_size.
|
||||
if dict.get(headers, 'Content-Length') is None:
|
||||
dict.pop(headers, 'Content-Length', None)
|
||||
elif code < 200 or code in (204, 205, 304):
|
||||
# "All 1xx (informational), 204 (no content),
|
||||
# and 304 (not modified) responses MUST NOT
|
||||
# include a message-body."
|
||||
dict.pop(headers, 'Content-Length', None)
|
||||
self.body = ntob("")
|
||||
else:
|
||||
# Responses which are not streamed should have a Content-Length,
|
||||
# but allow user code to set Content-Length if desired.
|
||||
if dict.get(headers, 'Content-Length') is None:
|
||||
content = self.collapse_body()
|
||||
dict.__setitem__(headers, 'Content-Length', len(content))
|
||||
|
||||
# Transform our header dict into a list of tuples.
|
||||
self.header_list = h = headers.output()
|
||||
|
||||
cookie = self.cookie.output()
|
||||
if cookie:
|
||||
for line in cookie.split("\n"):
|
||||
if line.endswith("\r"):
|
||||
# Python 2.4 emits cookies joined by LF but 2.5+ by CRLF.
|
||||
line = line[:-1]
|
||||
name, value = line.split(": ", 1)
|
||||
if isinstance(name, unicodestr):
|
||||
name = name.encode("ISO-8859-1")
|
||||
if isinstance(value, unicodestr):
|
||||
value = headers.encode(value)
|
||||
h.append((name, value))
|
||||
|
||||
def check_timeout(self):
|
||||
"""If now > self.time + self.timeout, set self.timed_out.
|
||||
|
||||
This purposefully sets a flag, rather than raising an error,
|
||||
so that a monitor thread can interrupt the Response thread.
|
||||
"""
|
||||
if time.time() > self.time + self.timeout:
|
||||
self.timed_out = True
|
||||
@@ -0,0 +1,226 @@
|
||||
"""Manage HTTP servers with CherryPy."""
|
||||
|
||||
import warnings
|
||||
|
||||
import cherrypy
|
||||
from cherrypy.lib import attributes
|
||||
from cherrypy._cpcompat import basestring, py3k
|
||||
|
||||
# We import * because we want to export check_port
|
||||
# et al as attributes of this module.
|
||||
from cherrypy.process.servers import *
|
||||
|
||||
|
||||
class Server(ServerAdapter):
|
||||
|
||||
"""An adapter for an HTTP server.
|
||||
|
||||
You can set attributes (like socket_host and socket_port)
|
||||
on *this* object (which is probably cherrypy.server), and call
|
||||
quickstart. For example::
|
||||
|
||||
cherrypy.server.socket_port = 80
|
||||
cherrypy.quickstart()
|
||||
"""
|
||||
|
||||
socket_port = 8080
|
||||
"""The TCP port on which to listen for connections."""
|
||||
|
||||
_socket_host = '127.0.0.1'
|
||||
|
||||
def _get_socket_host(self):
|
||||
return self._socket_host
|
||||
|
||||
def _set_socket_host(self, value):
|
||||
if value == '':
|
||||
raise ValueError("The empty string ('') is not an allowed value. "
|
||||
"Use '0.0.0.0' instead to listen on all active "
|
||||
"interfaces (INADDR_ANY).")
|
||||
self._socket_host = value
|
||||
socket_host = property(
|
||||
_get_socket_host,
|
||||
_set_socket_host,
|
||||
doc="""The hostname or IP address on which to listen for connections.
|
||||
|
||||
Host values may be any IPv4 or IPv6 address, or any valid hostname.
|
||||
The string 'localhost' is a synonym for '127.0.0.1' (or '::1', if
|
||||
your hosts file prefers IPv6). The string '0.0.0.0' is a special
|
||||
IPv4 entry meaning "any active interface" (INADDR_ANY), and '::'
|
||||
is the similar IN6ADDR_ANY for IPv6. The empty string or None are
|
||||
not allowed.""")
|
||||
|
||||
socket_file = None
|
||||
"""If given, the name of the UNIX socket to use instead of TCP/IP.
|
||||
|
||||
When this option is not None, the `socket_host` and `socket_port` options
|
||||
are ignored."""
|
||||
|
||||
socket_queue_size = 5
|
||||
"""The 'backlog' argument to socket.listen(); specifies the maximum number
|
||||
of queued connections (default 5)."""
|
||||
|
||||
socket_timeout = 10
|
||||
"""The timeout in seconds for accepted connections (default 10)."""
|
||||
|
||||
accepted_queue_size = -1
|
||||
"""The maximum number of requests which will be queued up before
|
||||
the server refuses to accept it (default -1, meaning no limit)."""
|
||||
|
||||
accepted_queue_timeout = 10
|
||||
"""The timeout in seconds for attempting to add a request to the
|
||||
queue when the queue is full (default 10)."""
|
||||
|
||||
shutdown_timeout = 5
|
||||
"""The time to wait for HTTP worker threads to clean up."""
|
||||
|
||||
protocol_version = 'HTTP/1.1'
|
||||
"""The version string to write in the Status-Line of all HTTP responses,
|
||||
for example, "HTTP/1.1" (the default). Depending on the HTTP server used,
|
||||
this should also limit the supported features used in the response."""
|
||||
|
||||
thread_pool = 10
|
||||
"""The number of worker threads to start up in the pool."""
|
||||
|
||||
thread_pool_max = -1
|
||||
"""The maximum size of the worker-thread pool. Use -1 to indicate no limit.
|
||||
"""
|
||||
|
||||
max_request_header_size = 500 * 1024
|
||||
"""The maximum number of bytes allowable in the request headers.
|
||||
If exceeded, the HTTP server should return "413 Request Entity Too Large".
|
||||
"""
|
||||
|
||||
max_request_body_size = 100 * 1024 * 1024
|
||||
"""The maximum number of bytes allowable in the request body. If exceeded,
|
||||
the HTTP server should return "413 Request Entity Too Large"."""
|
||||
|
||||
instance = None
|
||||
"""If not None, this should be an HTTP server instance (such as
|
||||
CPWSGIServer) which cherrypy.server will control. Use this when you need
|
||||
more control over object instantiation than is available in the various
|
||||
configuration options."""
|
||||
|
||||
ssl_context = None
|
||||
"""When using PyOpenSSL, an instance of SSL.Context."""
|
||||
|
||||
ssl_certificate = None
|
||||
"""The filename of the SSL certificate to use."""
|
||||
|
||||
ssl_certificate_chain = None
|
||||
"""When using PyOpenSSL, the certificate chain to pass to
|
||||
Context.load_verify_locations."""
|
||||
|
||||
ssl_private_key = None
|
||||
"""The filename of the private key to use with SSL."""
|
||||
|
||||
if py3k:
|
||||
ssl_module = 'builtin'
|
||||
"""The name of a registered SSL adaptation module to use with
|
||||
the builtin WSGI server. Builtin options are: 'builtin' (to
|
||||
use the SSL library built into recent versions of Python).
|
||||
You may also register your own classes in the
|
||||
wsgiserver.ssl_adapters dict."""
|
||||
else:
|
||||
ssl_module = 'pyopenssl'
|
||||
"""The name of a registered SSL adaptation module to use with the
|
||||
builtin WSGI server. Builtin options are 'builtin' (to use the SSL
|
||||
library built into recent versions of Python) and 'pyopenssl' (to
|
||||
use the PyOpenSSL project, which you must install separately). You
|
||||
may also register your own classes in the wsgiserver.ssl_adapters
|
||||
dict."""
|
||||
|
||||
statistics = False
|
||||
"""Turns statistics-gathering on or off for aware HTTP servers."""
|
||||
|
||||
nodelay = True
|
||||
"""If True (the default since 3.1), sets the TCP_NODELAY socket option."""
|
||||
|
||||
wsgi_version = (1, 0)
|
||||
"""The WSGI version tuple to use with the builtin WSGI server.
|
||||
The provided options are (1, 0) [which includes support for PEP 3333,
|
||||
which declares it covers WSGI version 1.0.1 but still mandates the
|
||||
wsgi.version (1, 0)] and ('u', 0), an experimental unicode version.
|
||||
You may create and register your own experimental versions of the WSGI
|
||||
protocol by adding custom classes to the wsgiserver.wsgi_gateways dict."""
|
||||
|
||||
def __init__(self):
|
||||
self.bus = cherrypy.engine
|
||||
self.httpserver = None
|
||||
self.interrupt = None
|
||||
self.running = False
|
||||
|
||||
def httpserver_from_self(self, httpserver=None):
|
||||
"""Return a (httpserver, bind_addr) pair based on self attributes."""
|
||||
if httpserver is None:
|
||||
httpserver = self.instance
|
||||
if httpserver is None:
|
||||
from cherrypy import _cpwsgi_server
|
||||
httpserver = _cpwsgi_server.CPWSGIServer(self)
|
||||
if isinstance(httpserver, basestring):
|
||||
# Is anyone using this? Can I add an arg?
|
||||
httpserver = attributes(httpserver)(self)
|
||||
return httpserver, self.bind_addr
|
||||
|
||||
def start(self):
|
||||
"""Start the HTTP server."""
|
||||
if not self.httpserver:
|
||||
self.httpserver, self.bind_addr = self.httpserver_from_self()
|
||||
ServerAdapter.start(self)
|
||||
start.priority = 75
|
||||
|
||||
def _get_bind_addr(self):
|
||||
if self.socket_file:
|
||||
return self.socket_file
|
||||
if self.socket_host is None and self.socket_port is None:
|
||||
return None
|
||||
return (self.socket_host, self.socket_port)
|
||||
|
||||
def _set_bind_addr(self, value):
|
||||
if value is None:
|
||||
self.socket_file = None
|
||||
self.socket_host = None
|
||||
self.socket_port = None
|
||||
elif isinstance(value, basestring):
|
||||
self.socket_file = value
|
||||
self.socket_host = None
|
||||
self.socket_port = None
|
||||
else:
|
||||
try:
|
||||
self.socket_host, self.socket_port = value
|
||||
self.socket_file = None
|
||||
except ValueError:
|
||||
raise ValueError("bind_addr must be a (host, port) tuple "
|
||||
"(for TCP sockets) or a string (for Unix "
|
||||
"domain sockets), not %r" % value)
|
||||
bind_addr = property(
|
||||
_get_bind_addr,
|
||||
_set_bind_addr,
|
||||
doc='A (host, port) tuple for TCP sockets or '
|
||||
'a str for Unix domain sockets.')
|
||||
|
||||
def base(self):
|
||||
"""Return the base (scheme://host[:port] or sock file) for this server.
|
||||
"""
|
||||
if self.socket_file:
|
||||
return self.socket_file
|
||||
|
||||
host = self.socket_host
|
||||
if host in ('0.0.0.0', '::'):
|
||||
# 0.0.0.0 is INADDR_ANY and :: is IN6ADDR_ANY.
|
||||
# Look up the host name, which should be the
|
||||
# safest thing to spit out in a URL.
|
||||
import socket
|
||||
host = socket.gethostname()
|
||||
|
||||
port = self.socket_port
|
||||
|
||||
if self.ssl_certificate:
|
||||
scheme = "https"
|
||||
if port != 443:
|
||||
host += ":%s" % port
|
||||
else:
|
||||
scheme = "http"
|
||||
if port != 80:
|
||||
host += ":%s" % port
|
||||
|
||||
return "%s://%s" % (scheme, host)
|
||||
@@ -0,0 +1,241 @@
|
||||
# This is a backport of Python-2.4's threading.local() implementation
|
||||
|
||||
"""Thread-local objects
|
||||
|
||||
(Note that this module provides a Python version of thread
|
||||
threading.local class. Depending on the version of Python you're
|
||||
using, there may be a faster one available. You should always import
|
||||
the local class from threading.)
|
||||
|
||||
Thread-local objects support the management of thread-local data.
|
||||
If you have data that you want to be local to a thread, simply create
|
||||
a thread-local object and use its attributes:
|
||||
|
||||
>>> mydata = local()
|
||||
>>> mydata.number = 42
|
||||
>>> mydata.number
|
||||
42
|
||||
|
||||
You can also access the local-object's dictionary:
|
||||
|
||||
>>> mydata.__dict__
|
||||
{'number': 42}
|
||||
>>> mydata.__dict__.setdefault('widgets', [])
|
||||
[]
|
||||
>>> mydata.widgets
|
||||
[]
|
||||
|
||||
What's important about thread-local objects is that their data are
|
||||
local to a thread. If we access the data in a different thread:
|
||||
|
||||
>>> log = []
|
||||
>>> def f():
|
||||
... items = mydata.__dict__.items()
|
||||
... items.sort()
|
||||
... log.append(items)
|
||||
... mydata.number = 11
|
||||
... log.append(mydata.number)
|
||||
|
||||
>>> import threading
|
||||
>>> thread = threading.Thread(target=f)
|
||||
>>> thread.start()
|
||||
>>> thread.join()
|
||||
>>> log
|
||||
[[], 11]
|
||||
|
||||
we get different data. Furthermore, changes made in the other thread
|
||||
don't affect data seen in this thread:
|
||||
|
||||
>>> mydata.number
|
||||
42
|
||||
|
||||
Of course, values you get from a local object, including a __dict__
|
||||
attribute, are for whatever thread was current at the time the
|
||||
attribute was read. For that reason, you generally don't want to save
|
||||
these values across threads, as they apply only to the thread they
|
||||
came from.
|
||||
|
||||
You can create custom local objects by subclassing the local class:
|
||||
|
||||
>>> class MyLocal(local):
|
||||
... number = 2
|
||||
... initialized = False
|
||||
... def __init__(self, **kw):
|
||||
... if self.initialized:
|
||||
... raise SystemError('__init__ called too many times')
|
||||
... self.initialized = True
|
||||
... self.__dict__.update(kw)
|
||||
... def squared(self):
|
||||
... return self.number ** 2
|
||||
|
||||
This can be useful to support default values, methods and
|
||||
initialization. Note that if you define an __init__ method, it will be
|
||||
called each time the local object is used in a separate thread. This
|
||||
is necessary to initialize each thread's dictionary.
|
||||
|
||||
Now if we create a local object:
|
||||
|
||||
>>> mydata = MyLocal(color='red')
|
||||
|
||||
Now we have a default number:
|
||||
|
||||
>>> mydata.number
|
||||
2
|
||||
|
||||
an initial color:
|
||||
|
||||
>>> mydata.color
|
||||
'red'
|
||||
>>> del mydata.color
|
||||
|
||||
And a method that operates on the data:
|
||||
|
||||
>>> mydata.squared()
|
||||
4
|
||||
|
||||
As before, we can access the data in a separate thread:
|
||||
|
||||
>>> log = []
|
||||
>>> thread = threading.Thread(target=f)
|
||||
>>> thread.start()
|
||||
>>> thread.join()
|
||||
>>> log
|
||||
[[('color', 'red'), ('initialized', True)], 11]
|
||||
|
||||
without affecting this thread's data:
|
||||
|
||||
>>> mydata.number
|
||||
2
|
||||
>>> mydata.color
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AttributeError: 'MyLocal' object has no attribute 'color'
|
||||
|
||||
Note that subclasses can define slots, but they are not thread
|
||||
local. They are shared across threads:
|
||||
|
||||
>>> class MyLocal(local):
|
||||
... __slots__ = 'number'
|
||||
|
||||
>>> mydata = MyLocal()
|
||||
>>> mydata.number = 42
|
||||
>>> mydata.color = 'red'
|
||||
|
||||
So, the separate thread:
|
||||
|
||||
>>> thread = threading.Thread(target=f)
|
||||
>>> thread.start()
|
||||
>>> thread.join()
|
||||
|
||||
affects what we see:
|
||||
|
||||
>>> mydata.number
|
||||
11
|
||||
|
||||
>>> del mydata
|
||||
"""
|
||||
|
||||
# Threading import is at end
|
||||
|
||||
|
||||
class _localbase(object):
|
||||
__slots__ = '_local__key', '_local__args', '_local__lock'
|
||||
|
||||
def __new__(cls, *args, **kw):
|
||||
self = object.__new__(cls)
|
||||
key = 'thread.local.' + str(id(self))
|
||||
object.__setattr__(self, '_local__key', key)
|
||||
object.__setattr__(self, '_local__args', (args, kw))
|
||||
object.__setattr__(self, '_local__lock', RLock())
|
||||
|
||||
if args or kw and (cls.__init__ is object.__init__):
|
||||
raise TypeError("Initialization arguments are not supported")
|
||||
|
||||
# We need to create the thread dict in anticipation of
|
||||
# __init__ being called, to make sure we don't call it
|
||||
# again ourselves.
|
||||
dict = object.__getattribute__(self, '__dict__')
|
||||
currentThread().__dict__[key] = dict
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def _patch(self):
|
||||
key = object.__getattribute__(self, '_local__key')
|
||||
d = currentThread().__dict__.get(key)
|
||||
if d is None:
|
||||
d = {}
|
||||
currentThread().__dict__[key] = d
|
||||
object.__setattr__(self, '__dict__', d)
|
||||
|
||||
# we have a new instance dict, so call out __init__ if we have
|
||||
# one
|
||||
cls = type(self)
|
||||
if cls.__init__ is not object.__init__:
|
||||
args, kw = object.__getattribute__(self, '_local__args')
|
||||
cls.__init__(self, *args, **kw)
|
||||
else:
|
||||
object.__setattr__(self, '__dict__', d)
|
||||
|
||||
|
||||
class local(_localbase):
|
||||
|
||||
def __getattribute__(self, name):
|
||||
lock = object.__getattribute__(self, '_local__lock')
|
||||
lock.acquire()
|
||||
try:
|
||||
_patch(self)
|
||||
return object.__getattribute__(self, name)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
lock = object.__getattribute__(self, '_local__lock')
|
||||
lock.acquire()
|
||||
try:
|
||||
_patch(self)
|
||||
return object.__setattr__(self, name, value)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
def __delattr__(self, name):
|
||||
lock = object.__getattribute__(self, '_local__lock')
|
||||
lock.acquire()
|
||||
try:
|
||||
_patch(self)
|
||||
return object.__delattr__(self, name)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
def __del__():
|
||||
threading_enumerate = enumerate
|
||||
__getattribute__ = object.__getattribute__
|
||||
|
||||
def __del__(self):
|
||||
key = __getattribute__(self, '_local__key')
|
||||
|
||||
try:
|
||||
threads = list(threading_enumerate())
|
||||
except:
|
||||
# if enumerate fails, as it seems to do during
|
||||
# shutdown, we'll skip cleanup under the assumption
|
||||
# that there is nothing to clean up
|
||||
return
|
||||
|
||||
for thread in threads:
|
||||
try:
|
||||
__dict__ = thread.__dict__
|
||||
except AttributeError:
|
||||
# Thread is dying, rest in peace
|
||||
continue
|
||||
|
||||
if key in __dict__:
|
||||
try:
|
||||
del __dict__[key]
|
||||
except KeyError:
|
||||
pass # didn't have anything in this thread
|
||||
|
||||
return __del__
|
||||
__del__ = __del__()
|
||||
|
||||
from threading import currentThread, enumerate, RLock
|
||||
@@ -0,0 +1,529 @@
|
||||
"""CherryPy tools. A "tool" is any helper, adapted to CP.
|
||||
|
||||
Tools are usually designed to be used in a variety of ways (although some
|
||||
may only offer one if they choose):
|
||||
|
||||
Library calls
|
||||
All tools are callables that can be used wherever needed.
|
||||
The arguments are straightforward and should be detailed within the
|
||||
docstring.
|
||||
|
||||
Function decorators
|
||||
All tools, when called, may be used as decorators which configure
|
||||
individual CherryPy page handlers (methods on the CherryPy tree).
|
||||
That is, "@tools.anytool()" should "turn on" the tool via the
|
||||
decorated function's _cp_config attribute.
|
||||
|
||||
CherryPy config
|
||||
If a tool exposes a "_setup" callable, it will be called
|
||||
once per Request (if the feature is "turned on" via config).
|
||||
|
||||
Tools may be implemented as any object with a namespace. The builtins
|
||||
are generally either modules or instances of the tools.Tool class.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
import cherrypy
|
||||
|
||||
|
||||
def _getargs(func):
|
||||
"""Return the names of all static arguments to the given function."""
|
||||
# Use this instead of importing inspect for less mem overhead.
|
||||
import types
|
||||
if sys.version_info >= (3, 0):
|
||||
if isinstance(func, types.MethodType):
|
||||
func = func.__func__
|
||||
co = func.__code__
|
||||
else:
|
||||
if isinstance(func, types.MethodType):
|
||||
func = func.im_func
|
||||
co = func.func_code
|
||||
return co.co_varnames[:co.co_argcount]
|
||||
|
||||
|
||||
_attr_error = (
|
||||
"CherryPy Tools cannot be turned on directly. Instead, turn them "
|
||||
"on via config, or use them as decorators on your page handlers."
|
||||
)
|
||||
|
||||
|
||||
class Tool(object):
|
||||
|
||||
"""A registered function for use with CherryPy request-processing hooks.
|
||||
|
||||
help(tool.callable) should give you more information about this Tool.
|
||||
"""
|
||||
|
||||
namespace = "tools"
|
||||
|
||||
def __init__(self, point, callable, name=None, priority=50):
|
||||
self._point = point
|
||||
self.callable = callable
|
||||
self._name = name
|
||||
self._priority = priority
|
||||
self.__doc__ = self.callable.__doc__
|
||||
self._setargs()
|
||||
|
||||
def _get_on(self):
|
||||
raise AttributeError(_attr_error)
|
||||
|
||||
def _set_on(self, value):
|
||||
raise AttributeError(_attr_error)
|
||||
on = property(_get_on, _set_on)
|
||||
|
||||
def _setargs(self):
|
||||
"""Copy func parameter names to obj attributes."""
|
||||
try:
|
||||
for arg in _getargs(self.callable):
|
||||
setattr(self, arg, None)
|
||||
except (TypeError, AttributeError):
|
||||
if hasattr(self.callable, "__call__"):
|
||||
for arg in _getargs(self.callable.__call__):
|
||||
setattr(self, arg, None)
|
||||
# IronPython 1.0 raises NotImplementedError because
|
||||
# inspect.getargspec tries to access Python bytecode
|
||||
# in co_code attribute.
|
||||
except NotImplementedError:
|
||||
pass
|
||||
# IronPython 1B1 may raise IndexError in some cases,
|
||||
# but if we trap it here it doesn't prevent CP from working.
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
def _merged_args(self, d=None):
|
||||
"""Return a dict of configuration entries for this Tool."""
|
||||
if d:
|
||||
conf = d.copy()
|
||||
else:
|
||||
conf = {}
|
||||
|
||||
tm = cherrypy.serving.request.toolmaps[self.namespace]
|
||||
if self._name in tm:
|
||||
conf.update(tm[self._name])
|
||||
|
||||
if "on" in conf:
|
||||
del conf["on"]
|
||||
|
||||
return conf
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Compile-time decorator (turn on the tool in config).
|
||||
|
||||
For example::
|
||||
|
||||
@tools.proxy()
|
||||
def whats_my_base(self):
|
||||
return cherrypy.request.base
|
||||
whats_my_base.exposed = True
|
||||
"""
|
||||
if args:
|
||||
raise TypeError("The %r Tool does not accept positional "
|
||||
"arguments; you must use keyword arguments."
|
||||
% self._name)
|
||||
|
||||
def tool_decorator(f):
|
||||
if not hasattr(f, "_cp_config"):
|
||||
f._cp_config = {}
|
||||
subspace = self.namespace + "." + self._name + "."
|
||||
f._cp_config[subspace + "on"] = True
|
||||
for k, v in kwargs.items():
|
||||
f._cp_config[subspace + k] = v
|
||||
return f
|
||||
return tool_decorator
|
||||
|
||||
def _setup(self):
|
||||
"""Hook this tool into cherrypy.request.
|
||||
|
||||
The standard CherryPy request object will automatically call this
|
||||
method when the tool is "turned on" in config.
|
||||
"""
|
||||
conf = self._merged_args()
|
||||
p = conf.pop("priority", None)
|
||||
if p is None:
|
||||
p = getattr(self.callable, "priority", self._priority)
|
||||
cherrypy.serving.request.hooks.attach(self._point, self.callable,
|
||||
priority=p, **conf)
|
||||
|
||||
|
||||
class HandlerTool(Tool):
|
||||
|
||||
"""Tool which is called 'before main', that may skip normal handlers.
|
||||
|
||||
If the tool successfully handles the request (by setting response.body),
|
||||
if should return True. This will cause CherryPy to skip any 'normal' page
|
||||
handler. If the tool did not handle the request, it should return False
|
||||
to tell CherryPy to continue on and call the normal page handler. If the
|
||||
tool is declared AS a page handler (see the 'handler' method), returning
|
||||
False will raise NotFound.
|
||||
"""
|
||||
|
||||
def __init__(self, callable, name=None):
|
||||
Tool.__init__(self, 'before_handler', callable, name)
|
||||
|
||||
def handler(self, *args, **kwargs):
|
||||
"""Use this tool as a CherryPy page handler.
|
||||
|
||||
For example::
|
||||
|
||||
class Root:
|
||||
nav = tools.staticdir.handler(section="/nav", dir="nav",
|
||||
root=absDir)
|
||||
"""
|
||||
def handle_func(*a, **kw):
|
||||
handled = self.callable(*args, **self._merged_args(kwargs))
|
||||
if not handled:
|
||||
raise cherrypy.NotFound()
|
||||
return cherrypy.serving.response.body
|
||||
handle_func.exposed = True
|
||||
return handle_func
|
||||
|
||||
def _wrapper(self, **kwargs):
|
||||
if self.callable(**kwargs):
|
||||
cherrypy.serving.request.handler = None
|
||||
|
||||
def _setup(self):
|
||||
"""Hook this tool into cherrypy.request.
|
||||
|
||||
The standard CherryPy request object will automatically call this
|
||||
method when the tool is "turned on" in config.
|
||||
"""
|
||||
conf = self._merged_args()
|
||||
p = conf.pop("priority", None)
|
||||
if p is None:
|
||||
p = getattr(self.callable, "priority", self._priority)
|
||||
cherrypy.serving.request.hooks.attach(self._point, self._wrapper,
|
||||
priority=p, **conf)
|
||||
|
||||
|
||||
class HandlerWrapperTool(Tool):
|
||||
|
||||
"""Tool which wraps request.handler in a provided wrapper function.
|
||||
|
||||
The 'newhandler' arg must be a handler wrapper function that takes a
|
||||
'next_handler' argument, plus ``*args`` and ``**kwargs``. Like all
|
||||
page handler
|
||||
functions, it must return an iterable for use as cherrypy.response.body.
|
||||
|
||||
For example, to allow your 'inner' page handlers to return dicts
|
||||
which then get interpolated into a template::
|
||||
|
||||
def interpolator(next_handler, *args, **kwargs):
|
||||
filename = cherrypy.request.config.get('template')
|
||||
cherrypy.response.template = env.get_template(filename)
|
||||
response_dict = next_handler(*args, **kwargs)
|
||||
return cherrypy.response.template.render(**response_dict)
|
||||
cherrypy.tools.jinja = HandlerWrapperTool(interpolator)
|
||||
"""
|
||||
|
||||
def __init__(self, newhandler, point='before_handler', name=None,
|
||||
priority=50):
|
||||
self.newhandler = newhandler
|
||||
self._point = point
|
||||
self._name = name
|
||||
self._priority = priority
|
||||
|
||||
def callable(self, *args, **kwargs):
|
||||
innerfunc = cherrypy.serving.request.handler
|
||||
|
||||
def wrap(*args, **kwargs):
|
||||
return self.newhandler(innerfunc, *args, **kwargs)
|
||||
cherrypy.serving.request.handler = wrap
|
||||
|
||||
|
||||
class ErrorTool(Tool):
|
||||
|
||||
"""Tool which is used to replace the default request.error_response."""
|
||||
|
||||
def __init__(self, callable, name=None):
|
||||
Tool.__init__(self, None, callable, name)
|
||||
|
||||
def _wrapper(self):
|
||||
self.callable(**self._merged_args())
|
||||
|
||||
def _setup(self):
|
||||
"""Hook this tool into cherrypy.request.
|
||||
|
||||
The standard CherryPy request object will automatically call this
|
||||
method when the tool is "turned on" in config.
|
||||
"""
|
||||
cherrypy.serving.request.error_response = self._wrapper
|
||||
|
||||
|
||||
# Builtin tools #
|
||||
|
||||
from cherrypy.lib import cptools, encoding, auth, static, jsontools
|
||||
from cherrypy.lib import sessions as _sessions, xmlrpcutil as _xmlrpc
|
||||
from cherrypy.lib import caching as _caching
|
||||
from cherrypy.lib import auth_basic, auth_digest
|
||||
|
||||
|
||||
class SessionTool(Tool):
|
||||
|
||||
"""Session Tool for CherryPy.
|
||||
|
||||
sessions.locking
|
||||
When 'implicit' (the default), the session will be locked for you,
|
||||
just before running the page handler.
|
||||
|
||||
When 'early', the session will be locked before reading the request
|
||||
body. This is off by default for safety reasons; for example,
|
||||
a large upload would block the session, denying an AJAX
|
||||
progress meter
|
||||
(`issue <https://bitbucket.org/cherrypy/cherrypy/issue/630>`_).
|
||||
|
||||
When 'explicit' (or any other value), you need to call
|
||||
cherrypy.session.acquire_lock() yourself before using
|
||||
session data.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# _sessions.init must be bound after headers are read
|
||||
Tool.__init__(self, 'before_request_body', _sessions.init)
|
||||
|
||||
def _lock_session(self):
|
||||
cherrypy.serving.session.acquire_lock()
|
||||
|
||||
def _setup(self):
|
||||
"""Hook this tool into cherrypy.request.
|
||||
|
||||
The standard CherryPy request object will automatically call this
|
||||
method when the tool is "turned on" in config.
|
||||
"""
|
||||
hooks = cherrypy.serving.request.hooks
|
||||
|
||||
conf = self._merged_args()
|
||||
|
||||
p = conf.pop("priority", None)
|
||||
if p is None:
|
||||
p = getattr(self.callable, "priority", self._priority)
|
||||
|
||||
hooks.attach(self._point, self.callable, priority=p, **conf)
|
||||
|
||||
locking = conf.pop('locking', 'implicit')
|
||||
if locking == 'implicit':
|
||||
hooks.attach('before_handler', self._lock_session)
|
||||
elif locking == 'early':
|
||||
# Lock before the request body (but after _sessions.init runs!)
|
||||
hooks.attach('before_request_body', self._lock_session,
|
||||
priority=60)
|
||||
else:
|
||||
# Don't lock
|
||||
pass
|
||||
|
||||
hooks.attach('before_finalize', _sessions.save)
|
||||
hooks.attach('on_end_request', _sessions.close)
|
||||
|
||||
def regenerate(self):
|
||||
"""Drop the current session and make a new one (with a new id)."""
|
||||
sess = cherrypy.serving.session
|
||||
sess.regenerate()
|
||||
|
||||
# Grab cookie-relevant tool args
|
||||
conf = dict([(k, v) for k, v in self._merged_args().items()
|
||||
if k in ('path', 'path_header', 'name', 'timeout',
|
||||
'domain', 'secure')])
|
||||
_sessions.set_response_cookie(**conf)
|
||||
|
||||
|
||||
class XMLRPCController(object):
|
||||
|
||||
"""A Controller (page handler collection) for XML-RPC.
|
||||
|
||||
To use it, have your controllers subclass this base class (it will
|
||||
turn on the tool for you).
|
||||
|
||||
You can also supply the following optional config entries::
|
||||
|
||||
tools.xmlrpc.encoding: 'utf-8'
|
||||
tools.xmlrpc.allow_none: 0
|
||||
|
||||
XML-RPC is a rather discontinuous layer over HTTP; dispatching to the
|
||||
appropriate handler must first be performed according to the URL, and
|
||||
then a second dispatch step must take place according to the RPC method
|
||||
specified in the request body. It also allows a superfluous "/RPC2"
|
||||
prefix in the URL, supplies its own handler args in the body, and
|
||||
requires a 200 OK "Fault" response instead of 404 when the desired
|
||||
method is not found.
|
||||
|
||||
Therefore, XML-RPC cannot be implemented for CherryPy via a Tool alone.
|
||||
This Controller acts as the dispatch target for the first half (based
|
||||
on the URL); it then reads the RPC method from the request body and
|
||||
does its own second dispatch step based on that method. It also reads
|
||||
body params, and returns a Fault on error.
|
||||
|
||||
The XMLRPCDispatcher strips any /RPC2 prefix; if you aren't using /RPC2
|
||||
in your URL's, you can safely skip turning on the XMLRPCDispatcher.
|
||||
Otherwise, you need to use declare it in config::
|
||||
|
||||
request.dispatch: cherrypy.dispatch.XMLRPCDispatcher()
|
||||
"""
|
||||
|
||||
# Note we're hard-coding this into the 'tools' namespace. We could do
|
||||
# a huge amount of work to make it relocatable, but the only reason why
|
||||
# would be if someone actually disabled the default_toolbox. Meh.
|
||||
_cp_config = {'tools.xmlrpc.on': True}
|
||||
|
||||
def default(self, *vpath, **params):
|
||||
rpcparams, rpcmethod = _xmlrpc.process_body()
|
||||
|
||||
subhandler = self
|
||||
for attr in str(rpcmethod).split('.'):
|
||||
subhandler = getattr(subhandler, attr, None)
|
||||
|
||||
if subhandler and getattr(subhandler, "exposed", False):
|
||||
body = subhandler(*(vpath + rpcparams), **params)
|
||||
|
||||
else:
|
||||
# https://bitbucket.org/cherrypy/cherrypy/issue/533
|
||||
# if a method is not found, an xmlrpclib.Fault should be returned
|
||||
# raising an exception here will do that; see
|
||||
# cherrypy.lib.xmlrpcutil.on_error
|
||||
raise Exception('method "%s" is not supported' % attr)
|
||||
|
||||
conf = cherrypy.serving.request.toolmaps['tools'].get("xmlrpc", {})
|
||||
_xmlrpc.respond(body,
|
||||
conf.get('encoding', 'utf-8'),
|
||||
conf.get('allow_none', 0))
|
||||
return cherrypy.serving.response.body
|
||||
default.exposed = True
|
||||
|
||||
|
||||
class SessionAuthTool(HandlerTool):
|
||||
|
||||
def _setargs(self):
|
||||
for name in dir(cptools.SessionAuth):
|
||||
if not name.startswith("__"):
|
||||
setattr(self, name, None)
|
||||
|
||||
|
||||
class CachingTool(Tool):
|
||||
|
||||
"""Caching Tool for CherryPy."""
|
||||
|
||||
def _wrapper(self, **kwargs):
|
||||
request = cherrypy.serving.request
|
||||
if _caching.get(**kwargs):
|
||||
request.handler = None
|
||||
else:
|
||||
if request.cacheable:
|
||||
# Note the devious technique here of adding hooks on the fly
|
||||
request.hooks.attach('before_finalize', _caching.tee_output,
|
||||
priority=90)
|
||||
_wrapper.priority = 20
|
||||
|
||||
def _setup(self):
|
||||
"""Hook caching into cherrypy.request."""
|
||||
conf = self._merged_args()
|
||||
|
||||
p = conf.pop("priority", None)
|
||||
cherrypy.serving.request.hooks.attach('before_handler', self._wrapper,
|
||||
priority=p, **conf)
|
||||
|
||||
|
||||
class Toolbox(object):
|
||||
|
||||
"""A collection of Tools.
|
||||
|
||||
This object also functions as a config namespace handler for itself.
|
||||
Custom toolboxes should be added to each Application's toolboxes dict.
|
||||
"""
|
||||
|
||||
def __init__(self, namespace):
|
||||
self.namespace = namespace
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
# If the Tool._name is None, supply it from the attribute name.
|
||||
if isinstance(value, Tool):
|
||||
if value._name is None:
|
||||
value._name = name
|
||||
value.namespace = self.namespace
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
def __enter__(self):
|
||||
"""Populate request.toolmaps from tools specified in config."""
|
||||
cherrypy.serving.request.toolmaps[self.namespace] = map = {}
|
||||
|
||||
def populate(k, v):
|
||||
toolname, arg = k.split(".", 1)
|
||||
bucket = map.setdefault(toolname, {})
|
||||
bucket[arg] = v
|
||||
return populate
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Run tool._setup() for each tool in our toolmap."""
|
||||
map = cherrypy.serving.request.toolmaps.get(self.namespace)
|
||||
if map:
|
||||
for name, settings in map.items():
|
||||
if settings.get("on", False):
|
||||
tool = getattr(self, name)
|
||||
tool._setup()
|
||||
|
||||
|
||||
class DeprecatedTool(Tool):
|
||||
|
||||
_name = None
|
||||
warnmsg = "This Tool is deprecated."
|
||||
|
||||
def __init__(self, point, warnmsg=None):
|
||||
self.point = point
|
||||
if warnmsg is not None:
|
||||
self.warnmsg = warnmsg
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
warnings.warn(self.warnmsg)
|
||||
|
||||
def tool_decorator(f):
|
||||
return f
|
||||
return tool_decorator
|
||||
|
||||
def _setup(self):
|
||||
warnings.warn(self.warnmsg)
|
||||
|
||||
|
||||
default_toolbox = _d = Toolbox("tools")
|
||||
_d.session_auth = SessionAuthTool(cptools.session_auth)
|
||||
_d.allow = Tool('on_start_resource', cptools.allow)
|
||||
_d.proxy = Tool('before_request_body', cptools.proxy, priority=30)
|
||||
_d.response_headers = Tool('on_start_resource', cptools.response_headers)
|
||||
_d.log_tracebacks = Tool('before_error_response', cptools.log_traceback)
|
||||
_d.log_headers = Tool('before_error_response', cptools.log_request_headers)
|
||||
_d.log_hooks = Tool('on_end_request', cptools.log_hooks, priority=100)
|
||||
_d.err_redirect = ErrorTool(cptools.redirect)
|
||||
_d.etags = Tool('before_finalize', cptools.validate_etags, priority=75)
|
||||
_d.decode = Tool('before_request_body', encoding.decode)
|
||||
# the order of encoding, gzip, caching is important
|
||||
_d.encode = Tool('before_handler', encoding.ResponseEncoder, priority=70)
|
||||
_d.gzip = Tool('before_finalize', encoding.gzip, priority=80)
|
||||
_d.staticdir = HandlerTool(static.staticdir)
|
||||
_d.staticfile = HandlerTool(static.staticfile)
|
||||
_d.sessions = SessionTool()
|
||||
_d.xmlrpc = ErrorTool(_xmlrpc.on_error)
|
||||
_d.caching = CachingTool('before_handler', _caching.get, 'caching')
|
||||
_d.expires = Tool('before_finalize', _caching.expires)
|
||||
_d.tidy = DeprecatedTool(
|
||||
'before_finalize',
|
||||
"The tidy tool has been removed from the standard distribution of "
|
||||
"CherryPy. The most recent version can be found at "
|
||||
"http://tools.cherrypy.org/browser.")
|
||||
_d.nsgmls = DeprecatedTool(
|
||||
'before_finalize',
|
||||
"The nsgmls tool has been removed from the standard distribution of "
|
||||
"CherryPy. The most recent version can be found at "
|
||||
"http://tools.cherrypy.org/browser.")
|
||||
_d.ignore_headers = Tool('before_request_body', cptools.ignore_headers)
|
||||
_d.referer = Tool('before_request_body', cptools.referer)
|
||||
_d.basic_auth = Tool('on_start_resource', auth.basic_auth)
|
||||
_d.digest_auth = Tool('on_start_resource', auth.digest_auth)
|
||||
_d.trailing_slash = Tool('before_handler', cptools.trailing_slash, priority=60)
|
||||
_d.flatten = Tool('before_finalize', cptools.flatten)
|
||||
_d.accept = Tool('on_start_resource', cptools.accept)
|
||||
_d.redirect = Tool('on_start_resource', cptools.redirect)
|
||||
_d.autovary = Tool('on_start_resource', cptools.autovary, priority=0)
|
||||
_d.json_in = Tool('before_request_body', jsontools.json_in, priority=30)
|
||||
_d.json_out = Tool('before_handler', jsontools.json_out, priority=30)
|
||||
_d.auth_basic = Tool('before_handler', auth_basic.basic_auth, priority=1)
|
||||
_d.auth_digest = Tool('before_handler', auth_digest.digest_auth, priority=1)
|
||||
|
||||
del _d, cptools, encoding, auth, static
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user