diff --git a/.gitignore b/.gitignore index ec6cbed8..2c038baf 100644 --- a/.gitignore +++ b/.gitignore @@ -171,7 +171,3 @@ _ReSharper*/ /logs .project .pydevproject -.vs/ - -#Python virtual environment -env/ \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index cc0fb495..ffed186b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,24 +1,21 @@ +# Travis CI configuration file +# http://about.travis-ci.org/docs/ + language: python + +# Available Python versions: +# http://about.travis-ci.org/docs/user/ci-environment/#Python-VM-images python: - - '2.6' - - '2.7' + - "2.6" + - "2.7" +# pylint 1.4 does not run under python 2.6 install: - - pip install -r requirements.txt - - pip install pylint - - pip install pyflakes - - pip install pep8 + - pip install pyOpenSSL + - pip install pylint==1.3.1 + - pip install pyflakes + - pip install pep8 script: - pep8 headphones + - pylint --rcfile=pylintrc headphones - pyflakes headphones -before_deploy: - - pip install -r requirements.txt -t lib/ - - find . -name "*.pyc" -delete - - find ./lib/ -path "*/*\-info/*" -delete - - zip -r -9 headphones.zip data/ headphones/ init-scripts/ lib/ Headphones.py -deploy: - provider: releases - api_key: - secure: - file: headphones.zip - on: - tags: true + - nosetests headphones diff --git a/lib/MultipartPostHandler.py b/lib/MultipartPostHandler.py new file mode 100644 index 00000000..ef63afa2 --- /dev/null +++ b/lib/MultipartPostHandler.py @@ -0,0 +1,88 @@ +#!/usr/bin/python + +#### +# 06/2010 Nic Wolfe +# 02/2006 Will Holcomb +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# + +import urllib +import urllib2 +import mimetools, mimetypes +import os, sys + +# Controls how sequences are uncoded. If true, elements may be given multiple values by +# assigning a sequence. +doseq = 1 + +class MultipartPostHandler(urllib2.BaseHandler): + handler_order = urllib2.HTTPHandler.handler_order - 10 # needs to run first + + def http_request(self, request): + data = request.get_data() + if data is not None and type(data) != str: + v_files = [] + v_vars = [] + try: + for(key, value) in data.items(): + if type(value) in (file, list, tuple): + v_files.append((key, value)) + else: + v_vars.append((key, value)) + except TypeError: + systype, value, traceback = sys.exc_info() + raise TypeError, "not a valid non-string sequence or mapping object", traceback + + if len(v_files) == 0: + data = urllib.urlencode(v_vars, doseq) + else: + boundary, data = MultipartPostHandler.multipart_encode(v_vars, v_files) + contenttype = 'multipart/form-data; boundary=%s' % boundary + if(request.has_header('Content-Type') + and request.get_header('Content-Type').find('multipart/form-data') != 0): + print "Replacing %s with %s" % (request.get_header('content-type'), 'multipart/form-data') + request.add_unredirected_header('Content-Type', contenttype) + + request.add_data(data) + return request + + @staticmethod + def multipart_encode(vars, files, boundary = None, buffer = None): + if boundary is None: + boundary = mimetools.choose_boundary() + if buffer is None: + buffer = '' + for(key, value) in vars: + buffer += '--%s\r\n' % boundary + buffer += 'Content-Disposition: form-data; name="%s"' % key + buffer += '\r\n\r\n' + value + '\r\n' + for(key, fd) in files: + + # allow them to pass in a file or a tuple with name & data + if type(fd) == file: + name_in = fd.name + fd.seek(0) + data_in = fd.read() + elif type(fd) in (tuple, list): + name_in, data_in = fd + + filename = os.path.basename(name_in) + contenttype = mimetypes.guess_type(filename)[0] or 'application/octet-stream' + buffer += '--%s\r\n' % boundary + buffer += 'Content-Disposition: form-data; name="%s"; filename="%s"\r\n' % (key, filename) + buffer += 'Content-Type: %s\r\n' % contenttype + # buffer += 'Content-Length: %s\r\n' % file_size + buffer += '\r\n' + data_in + '\r\n' + buffer += '--%s--\r\n\r\n' % boundary + return boundary, buffer + + https_request = http_request \ No newline at end of file diff --git a/lib/apscheduler/__init__.py b/lib/apscheduler/__init__.py new file mode 100644 index 00000000..283002b7 --- /dev/null +++ b/lib/apscheduler/__init__.py @@ -0,0 +1,5 @@ +version_info = (3, 0, 1) +version = '3.0.1' +release = '3.0.1' + +__version__ = release # PEP 396 diff --git a/lib/apscheduler/events.py b/lib/apscheduler/events.py new file mode 100644 index 00000000..9418263d --- /dev/null +++ b/lib/apscheduler/events.py @@ -0,0 +1,73 @@ +__all__ = ('EVENT_SCHEDULER_START', 'EVENT_SCHEDULER_SHUTDOWN', 'EVENT_EXECUTOR_ADDED', 'EVENT_EXECUTOR_REMOVED', + 'EVENT_JOBSTORE_ADDED', 'EVENT_JOBSTORE_REMOVED', 'EVENT_ALL_JOBS_REMOVED', 'EVENT_JOB_ADDED', + 'EVENT_JOB_REMOVED', 'EVENT_JOB_MODIFIED', 'EVENT_JOB_EXECUTED', 'EVENT_JOB_ERROR', 'EVENT_JOB_MISSED', + 'SchedulerEvent', 'JobEvent', 'JobExecutionEvent') + + +EVENT_SCHEDULER_START = 1 +EVENT_SCHEDULER_SHUTDOWN = 2 +EVENT_EXECUTOR_ADDED = 4 +EVENT_EXECUTOR_REMOVED = 8 +EVENT_JOBSTORE_ADDED = 16 +EVENT_JOBSTORE_REMOVED = 32 +EVENT_ALL_JOBS_REMOVED = 64 +EVENT_JOB_ADDED = 128 +EVENT_JOB_REMOVED = 256 +EVENT_JOB_MODIFIED = 512 +EVENT_JOB_EXECUTED = 1024 +EVENT_JOB_ERROR = 2048 +EVENT_JOB_MISSED = 4096 +EVENT_ALL = (EVENT_SCHEDULER_START | EVENT_SCHEDULER_SHUTDOWN | EVENT_JOBSTORE_ADDED | EVENT_JOBSTORE_REMOVED | + EVENT_JOB_ADDED | EVENT_JOB_REMOVED | EVENT_JOB_MODIFIED | EVENT_JOB_EXECUTED | + EVENT_JOB_ERROR | EVENT_JOB_MISSED) + + +class SchedulerEvent(object): + """ + An event that concerns the scheduler itself. + + :ivar code: the type code of this event + :ivar alias: alias of the job store or executor that was added or removed (if applicable) + """ + + def __init__(self, code, alias=None): + super(SchedulerEvent, self).__init__() + self.code = code + self.alias = alias + + def __repr__(self): + return '<%s (code=%d)>' % (self.__class__.__name__, self.code) + + +class JobEvent(SchedulerEvent): + """ + An event that concerns a job. + + :ivar code: the type code of this event + :ivar job_id: identifier of the job in question + :ivar jobstore: alias of the job store containing the job in question + """ + + def __init__(self, code, job_id, jobstore): + super(JobEvent, self).__init__(code) + self.code = code + self.job_id = job_id + self.jobstore = jobstore + + +class JobExecutionEvent(JobEvent): + """ + An event that concerns the execution of individual jobs. + + :ivar scheduled_run_time: the time when the job was scheduled to be run + :ivar retval: the return value of the successfully executed job + :ivar exception: the exception raised by the job + :ivar traceback: a formatted traceback for the exception + """ + + def __init__(self, code, job_id, jobstore, scheduled_run_time, retval=None, exception=None, traceback=None): + super(JobExecutionEvent, self).__init__(code, job_id, jobstore) + self.scheduled_run_time = scheduled_run_time + self.retval = retval + self.exception = exception + self.traceback = traceback diff --git a/lib/apscheduler/executors/__init__.py b/lib/apscheduler/executors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/apscheduler/executors/asyncio.py b/lib/apscheduler/executors/asyncio.py new file mode 100644 index 00000000..fade99f8 --- /dev/null +++ b/lib/apscheduler/executors/asyncio.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import +import sys + +from apscheduler.executors.base import BaseExecutor, run_job + + +class AsyncIOExecutor(BaseExecutor): + """ + Runs jobs in the default executor of the event loop. + + Plugin alias: ``asyncio`` + """ + + def start(self, scheduler, alias): + super(AsyncIOExecutor, self).start(scheduler, alias) + self._eventloop = scheduler._eventloop + + def _do_submit_job(self, job, run_times): + def callback(f): + try: + events = f.result() + except: + self._run_job_error(job.id, *sys.exc_info()[1:]) + else: + self._run_job_success(job.id, events) + + f = self._eventloop.run_in_executor(None, run_job, job, job._jobstore_alias, run_times, self._logger.name) + f.add_done_callback(callback) diff --git a/lib/apscheduler/executors/base.py b/lib/apscheduler/executors/base.py new file mode 100644 index 00000000..5a0a19eb --- /dev/null +++ b/lib/apscheduler/executors/base.py @@ -0,0 +1,119 @@ +from abc import ABCMeta, abstractmethod +from collections import defaultdict +from datetime import datetime, timedelta +from traceback import format_tb +import logging +import sys + +from pytz import utc +import six + +from apscheduler.events import JobExecutionEvent, EVENT_JOB_MISSED, EVENT_JOB_ERROR, EVENT_JOB_EXECUTED + + +class MaxInstancesReachedError(Exception): + def __init__(self, job): + super(MaxInstancesReachedError, self).__init__( + 'Job "%s" has already reached its maximum number of instances (%d)' % (job.id, job.max_instances)) + + +class BaseExecutor(six.with_metaclass(ABCMeta, object)): + """Abstract base class that defines the interface that every executor must implement.""" + + _scheduler = None + _lock = None + _logger = logging.getLogger('apscheduler.executors') + + def __init__(self): + super(BaseExecutor, self).__init__() + self._instances = defaultdict(lambda: 0) + + def start(self, scheduler, alias): + """ + Called by the scheduler when the scheduler is being started or when the executor is being added to an already + running scheduler. + + :param apscheduler.schedulers.base.BaseScheduler scheduler: the scheduler that is starting this executor + :param str|unicode alias: alias of this executor as it was assigned to the scheduler + """ + + self._scheduler = scheduler + self._lock = scheduler._create_lock() + self._logger = logging.getLogger('apscheduler.executors.%s' % alias) + + def shutdown(self, wait=True): + """ + Shuts down this executor. + + :param bool wait: ``True`` to wait until all submitted jobs have been executed + """ + + def submit_job(self, job, run_times): + """ + Submits job for execution. + + :param Job job: job to execute + :param list[datetime] run_times: list of datetimes specifying when the job should have been run + :raises MaxInstancesReachedError: if the maximum number of allowed instances for this job has been reached + """ + + assert self._lock is not None, 'This executor has not been started yet' + with self._lock: + if self._instances[job.id] >= job.max_instances: + raise MaxInstancesReachedError(job) + + self._do_submit_job(job, run_times) + self._instances[job.id] += 1 + + @abstractmethod + def _do_submit_job(self, job, run_times): + """Performs the actual task of scheduling `run_job` to be called.""" + + def _run_job_success(self, job_id, events): + """Called by the executor with the list of generated events when `run_job` has been successfully called.""" + + with self._lock: + self._instances[job_id] -= 1 + + for event in events: + self._scheduler._dispatch_event(event) + + def _run_job_error(self, job_id, exc, traceback=None): + """Called by the executor with the exception if there is an error calling `run_job`.""" + + with self._lock: + self._instances[job_id] -= 1 + + exc_info = (exc.__class__, exc, traceback) + self._logger.error('Error running job %s', job_id, exc_info=exc_info) + + +def run_job(job, jobstore_alias, run_times, logger_name): + """Called by executors to run the job. Returns a list of scheduler events to be dispatched by the scheduler.""" + + events = [] + logger = logging.getLogger(logger_name) + for run_time in run_times: + # See if the job missed its run time window, and handle possible misfires accordingly + if job.misfire_grace_time is not None: + difference = datetime.now(utc) - run_time + grace_time = timedelta(seconds=job.misfire_grace_time) + if difference > grace_time: + events.append(JobExecutionEvent(EVENT_JOB_MISSED, job.id, jobstore_alias, run_time)) + logger.warning('Run time of job "%s" was missed by %s', job, difference) + continue + + logger.info('Running job "%s" (scheduled at %s)', job, run_time) + try: + retval = job.func(*job.args, **job.kwargs) + except: + exc, tb = sys.exc_info()[1:] + formatted_tb = ''.join(format_tb(tb)) + events.append(JobExecutionEvent(EVENT_JOB_ERROR, job.id, jobstore_alias, run_time, exception=exc, + traceback=formatted_tb)) + logger.exception('Job "%s" raised an exception', job) + else: + events.append(JobExecutionEvent(EVENT_JOB_EXECUTED, job.id, jobstore_alias, run_time, retval=retval)) + logger.info('Job "%s" executed successfully', job) + + return events diff --git a/lib/apscheduler/executors/debug.py b/lib/apscheduler/executors/debug.py new file mode 100644 index 00000000..1f6f6b1a --- /dev/null +++ b/lib/apscheduler/executors/debug.py @@ -0,0 +1,19 @@ +import sys + +from apscheduler.executors.base import BaseExecutor, run_job + + +class DebugExecutor(BaseExecutor): + """ + A special executor that executes the target callable directly instead of deferring it to a thread or process. + + Plugin alias: ``debug`` + """ + + def _do_submit_job(self, job, run_times): + try: + events = run_job(job, job._jobstore_alias, run_times, self._logger.name) + except: + self._run_job_error(job.id, *sys.exc_info()[1:]) + else: + self._run_job_success(job.id, events) diff --git a/lib/apscheduler/executors/gevent.py b/lib/apscheduler/executors/gevent.py new file mode 100644 index 00000000..9f4db2fc --- /dev/null +++ b/lib/apscheduler/executors/gevent.py @@ -0,0 +1,29 @@ +from __future__ import absolute_import +import sys + +from apscheduler.executors.base import BaseExecutor, run_job + + +try: + import gevent +except ImportError: # pragma: nocover + raise ImportError('GeventExecutor requires gevent installed') + + +class GeventExecutor(BaseExecutor): + """ + Runs jobs as greenlets. + + Plugin alias: ``gevent`` + """ + + def _do_submit_job(self, job, run_times): + def callback(greenlet): + try: + events = greenlet.get() + except: + self._run_job_error(job.id, *sys.exc_info()[1:]) + else: + self._run_job_success(job.id, events) + + gevent.spawn(run_job, job, job._jobstore_alias, run_times, self._logger.name).link(callback) diff --git a/lib/apscheduler/executors/pool.py b/lib/apscheduler/executors/pool.py new file mode 100644 index 00000000..2f4ef455 --- /dev/null +++ b/lib/apscheduler/executors/pool.py @@ -0,0 +1,54 @@ +from abc import abstractmethod +import concurrent.futures + +from apscheduler.executors.base import BaseExecutor, run_job + + +class BasePoolExecutor(BaseExecutor): + @abstractmethod + def __init__(self, pool): + super(BasePoolExecutor, self).__init__() + self._pool = pool + + def _do_submit_job(self, job, run_times): + def callback(f): + exc, tb = (f.exception_info() if hasattr(f, 'exception_info') else + (f.exception(), getattr(f.exception(), '__traceback__', None))) + if exc: + self._run_job_error(job.id, exc, tb) + else: + self._run_job_success(job.id, f.result()) + + f = self._pool.submit(run_job, job, job._jobstore_alias, run_times, self._logger.name) + f.add_done_callback(callback) + + def shutdown(self, wait=True): + self._pool.shutdown(wait) + + +class ThreadPoolExecutor(BasePoolExecutor): + """ + An executor that runs jobs in a concurrent.futures thread pool. + + Plugin alias: ``threadpool`` + + :param max_workers: the maximum number of spawned threads. + """ + + def __init__(self, max_workers=10): + pool = concurrent.futures.ThreadPoolExecutor(int(max_workers)) + super(ThreadPoolExecutor, self).__init__(pool) + + +class ProcessPoolExecutor(BasePoolExecutor): + """ + An executor that runs jobs in a concurrent.futures process pool. + + Plugin alias: ``processpool`` + + :param max_workers: the maximum number of spawned processes. + """ + + def __init__(self, max_workers=10): + pool = concurrent.futures.ProcessPoolExecutor(int(max_workers)) + super(ProcessPoolExecutor, self).__init__(pool) diff --git a/lib/apscheduler/executors/twisted.py b/lib/apscheduler/executors/twisted.py new file mode 100644 index 00000000..29217221 --- /dev/null +++ b/lib/apscheduler/executors/twisted.py @@ -0,0 +1,25 @@ +from __future__ import absolute_import + +from apscheduler.executors.base import BaseExecutor, run_job + + +class TwistedExecutor(BaseExecutor): + """ + Runs jobs in the reactor's thread pool. + + Plugin alias: ``twisted`` + """ + + def start(self, scheduler, alias): + super(TwistedExecutor, self).start(scheduler, alias) + self._reactor = scheduler._reactor + + def _do_submit_job(self, job, run_times): + def callback(success, result): + if success: + self._run_job_success(job.id, result) + else: + self._run_job_error(job.id, result.value, result.tb) + + self._reactor.getThreadPool().callInThreadWithCallback(callback, run_job, job, job._jobstore_alias, run_times, + self._logger.name) diff --git a/lib/apscheduler/job.py b/lib/apscheduler/job.py new file mode 100644 index 00000000..f5639dae --- /dev/null +++ b/lib/apscheduler/job.py @@ -0,0 +1,252 @@ +from collections import Iterable, Mapping +from uuid import uuid4 + +import six + +from apscheduler.triggers.base import BaseTrigger +from apscheduler.util import ref_to_obj, obj_to_ref, datetime_repr, repr_escape, get_callable_name, check_callable_args, \ + convert_to_datetime + + +class Job(object): + """ + Contains the options given when scheduling callables and its current schedule and other state. + This class should never be instantiated by the user. + + :var str id: the unique identifier of this job + :var str name: the description of this job + :var func: the callable to execute + :var tuple|list args: positional arguments to the callable + :var dict kwargs: keyword arguments to the callable + :var bool coalesce: whether to only run the job once when several run times are due + :var trigger: the trigger object that controls the schedule of this job + :var str executor: the name of the executor that will run this job + :var int misfire_grace_time: the time (in seconds) how much this job's execution is allowed to be late + :var int max_instances: the maximum number of concurrently executing instances allowed for this job + :var datetime.datetime next_run_time: the next scheduled run time of this job + """ + + __slots__ = ('_scheduler', '_jobstore_alias', 'id', 'trigger', 'executor', 'func', 'func_ref', 'args', 'kwargs', + 'name', 'misfire_grace_time', 'coalesce', 'max_instances', 'next_run_time') + + def __init__(self, scheduler, id=None, **kwargs): + super(Job, self).__init__() + self._scheduler = scheduler + self._jobstore_alias = None + self._modify(id=id or uuid4().hex, **kwargs) + + def modify(self, **changes): + """ + Makes the given changes to this job and saves it in the associated job store. + Accepted keyword arguments are the same as the variables on this class. + + .. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.modify_job` + """ + + self._scheduler.modify_job(self.id, self._jobstore_alias, **changes) + + def reschedule(self, trigger, **trigger_args): + """ + Shortcut for switching the trigger on this job. + + .. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.reschedule_job` + """ + + self._scheduler.reschedule_job(self.id, self._jobstore_alias, trigger, **trigger_args) + + def pause(self): + """ + Temporarily suspend the execution of this job. + + .. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.pause_job` + """ + + self._scheduler.pause_job(self.id, self._jobstore_alias) + + def resume(self): + """ + Resume the schedule of this job if previously paused. + + .. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.resume_job` + """ + + self._scheduler.resume_job(self.id, self._jobstore_alias) + + def remove(self): + """ + Unschedules this job and removes it from its associated job store. + + .. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.remove_job` + """ + + self._scheduler.remove_job(self.id, self._jobstore_alias) + + @property + def pending(self): + """Returns ``True`` if the referenced job is still waiting to be added to its designated job store.""" + + return self._jobstore_alias is None + + # + # Private API + # + + def _get_run_times(self, now): + """ + Computes the scheduled run times between ``next_run_time`` and ``now`` (inclusive). + + :type now: datetime.datetime + :rtype: list[datetime.datetime] + """ + + run_times = [] + next_run_time = self.next_run_time + while next_run_time and next_run_time <= now: + run_times.append(next_run_time) + next_run_time = self.trigger.get_next_fire_time(next_run_time, now) + + return run_times + + def _modify(self, **changes): + """Validates the changes to the Job and makes the modifications if and only if all of them validate.""" + + approved = {} + + if 'id' in changes: + value = changes.pop('id') + if not isinstance(value, six.string_types): + raise TypeError("id must be a nonempty string") + if hasattr(self, 'id'): + raise ValueError('The job ID may not be changed') + approved['id'] = value + + if 'func' in changes or 'args' in changes or 'kwargs' in changes: + func = changes.pop('func') if 'func' in changes else self.func + args = changes.pop('args') if 'args' in changes else self.args + kwargs = changes.pop('kwargs') if 'kwargs' in changes else self.kwargs + + if isinstance(func, str): + func_ref = func + func = ref_to_obj(func) + elif callable(func): + try: + func_ref = obj_to_ref(func) + except ValueError: + # If this happens, this Job won't be serializable + func_ref = None + else: + raise TypeError('func must be a callable or a textual reference to one') + + if not hasattr(self, 'name') and changes.get('name', None) is None: + changes['name'] = get_callable_name(func) + + if isinstance(args, six.string_types) or not isinstance(args, Iterable): + raise TypeError('args must be a non-string iterable') + if isinstance(kwargs, six.string_types) or not isinstance(kwargs, Mapping): + raise TypeError('kwargs must be a dict-like object') + + check_callable_args(func, args, kwargs) + + approved['func'] = func + approved['func_ref'] = func_ref + approved['args'] = args + approved['kwargs'] = kwargs + + if 'name' in changes: + value = changes.pop('name') + if not value or not isinstance(value, six.string_types): + raise TypeError("name must be a nonempty string") + approved['name'] = value + + if 'misfire_grace_time' in changes: + value = changes.pop('misfire_grace_time') + if value is not None and (not isinstance(value, six.integer_types) or value <= 0): + raise TypeError('misfire_grace_time must be either None or a positive integer') + approved['misfire_grace_time'] = value + + if 'coalesce' in changes: + value = bool(changes.pop('coalesce')) + approved['coalesce'] = value + + if 'max_instances' in changes: + value = changes.pop('max_instances') + if not isinstance(value, six.integer_types) or value <= 0: + raise TypeError('max_instances must be a positive integer') + approved['max_instances'] = value + + if 'trigger' in changes: + trigger = changes.pop('trigger') + if not isinstance(trigger, BaseTrigger): + raise TypeError('Expected a trigger instance, got %s instead' % trigger.__class__.__name__) + + approved['trigger'] = trigger + + if 'executor' in changes: + value = changes.pop('executor') + if not isinstance(value, six.string_types): + raise TypeError('executor must be a string') + approved['executor'] = value + + if 'next_run_time' in changes: + value = changes.pop('next_run_time') + approved['next_run_time'] = convert_to_datetime(value, self._scheduler.timezone, 'next_run_time') + + if changes: + raise AttributeError('The following are not modifiable attributes of Job: %s' % ', '.join(changes)) + + for key, value in six.iteritems(approved): + setattr(self, key, value) + + def __getstate__(self): + # Don't allow this Job to be serialized if the function reference could not be determined + if not self.func_ref: + raise ValueError('This Job cannot be serialized since the reference to its callable (%r) could not be ' + 'determined. Consider giving a textual reference (module:function name) instead.' % + (self.func,)) + + return { + 'version': 1, + 'id': self.id, + 'func': self.func_ref, + 'trigger': self.trigger, + 'executor': self.executor, + 'args': self.args, + 'kwargs': self.kwargs, + 'name': self.name, + 'misfire_grace_time': self.misfire_grace_time, + 'coalesce': self.coalesce, + 'max_instances': self.max_instances, + 'next_run_time': self.next_run_time + } + + def __setstate__(self, state): + if state.get('version', 1) > 1: + raise ValueError('Job has version %s, but only version 1 can be handled' % state['version']) + + self.id = state['id'] + self.func_ref = state['func'] + self.func = ref_to_obj(self.func_ref) + self.trigger = state['trigger'] + self.executor = state['executor'] + self.args = state['args'] + self.kwargs = state['kwargs'] + self.name = state['name'] + self.misfire_grace_time = state['misfire_grace_time'] + self.coalesce = state['coalesce'] + self.max_instances = state['max_instances'] + self.next_run_time = state['next_run_time'] + + def __eq__(self, other): + if isinstance(other, Job): + return self.id == other.id + return NotImplemented + + def __repr__(self): + return '' % (repr_escape(self.id), repr_escape(self.name)) + + def __str__(self): + return '%s (trigger: %s, next run at: %s)' % (repr_escape(self.name), repr_escape(str(self.trigger)), + datetime_repr(self.next_run_time)) + + def __unicode__(self): + return six.u('%s (trigger: %s, next run at: %s)') % (self.name, self.trigger, datetime_repr(self.next_run_time)) diff --git a/lib/apscheduler/jobstores/__init__.py b/lib/apscheduler/jobstores/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/apscheduler/jobstores/base.py b/lib/apscheduler/jobstores/base.py new file mode 100644 index 00000000..d9f7147b --- /dev/null +++ b/lib/apscheduler/jobstores/base.py @@ -0,0 +1,127 @@ +from abc import ABCMeta, abstractmethod +import logging + +import six + + +class JobLookupError(KeyError): + """Raised when the job store cannot find a job for update or removal.""" + + def __init__(self, job_id): + super(JobLookupError, self).__init__(six.u('No job by the id of %s was found') % job_id) + + +class ConflictingIdError(KeyError): + """Raised when the uniqueness of job IDs is being violated.""" + + def __init__(self, job_id): + super(ConflictingIdError, self).__init__(six.u('Job identifier (%s) conflicts with an existing job') % job_id) + + +class TransientJobError(ValueError): + """Raised when an attempt to add transient (with no func_ref) job to a persistent job store is detected.""" + + def __init__(self, job_id): + super(TransientJobError, self).__init__( + six.u('Job (%s) cannot be added to this job store because a reference to the callable could not be ' + 'determined.') % job_id) + + +class BaseJobStore(six.with_metaclass(ABCMeta)): + """Abstract base class that defines the interface that every job store must implement.""" + + _scheduler = None + _alias = None + _logger = logging.getLogger('apscheduler.jobstores') + + def start(self, scheduler, alias): + """ + Called by the scheduler when the scheduler is being started or when the job store is being added to an already + running scheduler. + + :param apscheduler.schedulers.base.BaseScheduler scheduler: the scheduler that is starting this job store + :param str|unicode alias: alias of this job store as it was assigned to the scheduler + """ + + self._scheduler = scheduler + self._alias = alias + self._logger = logging.getLogger('apscheduler.jobstores.%s' % alias) + + def shutdown(self): + """Frees any resources still bound to this job store.""" + + @abstractmethod + def lookup_job(self, job_id): + """ + Returns a specific job, or ``None`` if it isn't found.. + + The job store is responsible for setting the ``scheduler`` and ``jobstore`` attributes of the returned job to + point to the scheduler and itself, respectively. + + :param str|unicode job_id: identifier of the job + :rtype: Job + """ + + @abstractmethod + def get_due_jobs(self, now): + """ + Returns the list of jobs that have ``next_run_time`` earlier or equal to ``now``. + The returned jobs must be sorted by next run time (ascending). + + :param datetime.datetime now: the current (timezone aware) datetime + :rtype: list[Job] + """ + + @abstractmethod + def get_next_run_time(self): + """ + Returns the earliest run time of all the jobs stored in this job store, or ``None`` if there are no active jobs. + + :rtype: datetime.datetime + """ + + @abstractmethod + def get_all_jobs(self): + """ + Returns a list of all jobs in this job store. The returned jobs should be sorted by next run time (ascending). + Paused jobs (next_run_time is None) should be sorted last. + + The job store is responsible for setting the ``scheduler`` and ``jobstore`` attributes of the returned jobs to + point to the scheduler and itself, respectively. + + :rtype: list[Job] + """ + + @abstractmethod + def add_job(self, job): + """ + Adds the given job to this store. + + :param Job job: the job to add + :raises ConflictingIdError: if there is another job in this store with the same ID + """ + + @abstractmethod + def update_job(self, job): + """ + Replaces the job in the store with the given newer version. + + :param Job job: the job to update + :raises JobLookupError: if the job does not exist + """ + + @abstractmethod + def remove_job(self, job_id): + """ + Removes the given job from this store. + + :param str|unicode job_id: identifier of the job + :raises JobLookupError: if the job does not exist + """ + + @abstractmethod + def remove_all_jobs(self): + """Removes all jobs from this store.""" + + def __repr__(self): + return '<%s>' % self.__class__.__name__ diff --git a/lib/apscheduler/jobstores/memory.py b/lib/apscheduler/jobstores/memory.py new file mode 100644 index 00000000..645391f3 --- /dev/null +++ b/lib/apscheduler/jobstores/memory.py @@ -0,0 +1,107 @@ +from __future__ import absolute_import + +from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError +from apscheduler.util import datetime_to_utc_timestamp + + +class MemoryJobStore(BaseJobStore): + """ + Stores jobs in an array in RAM. Provides no persistence support. + + Plugin alias: ``memory`` + """ + + def __init__(self): + super(MemoryJobStore, self).__init__() + self._jobs = [] # list of (job, timestamp), sorted by next_run_time and job id (ascending) + self._jobs_index = {} # id -> (job, timestamp) lookup table + + def lookup_job(self, job_id): + return self._jobs_index.get(job_id, (None, None))[0] + + def get_due_jobs(self, now): + now_timestamp = datetime_to_utc_timestamp(now) + pending = [] + for job, timestamp in self._jobs: + if timestamp is None or timestamp > now_timestamp: + break + pending.append(job) + + return pending + + def get_next_run_time(self): + return self._jobs[0][0].next_run_time if self._jobs else None + + def get_all_jobs(self): + return [j[0] for j in self._jobs] + + def add_job(self, job): + if job.id in self._jobs_index: + raise ConflictingIdError(job.id) + + timestamp = datetime_to_utc_timestamp(job.next_run_time) + index = self._get_job_index(timestamp, job.id) + self._jobs.insert(index, (job, timestamp)) + self._jobs_index[job.id] = (job, timestamp) + + def update_job(self, job): + old_job, old_timestamp = self._jobs_index.get(job.id, (None, None)) + if old_job is None: + raise JobLookupError(job.id) + + # If the next run time has not changed, simply replace the job in its present index. + # Otherwise, reinsert the job to the list to preserve the ordering. + old_index = self._get_job_index(old_timestamp, old_job.id) + new_timestamp = datetime_to_utc_timestamp(job.next_run_time) + if old_timestamp == new_timestamp: + self._jobs[old_index] = (job, new_timestamp) + else: + del self._jobs[old_index] + new_index = self._get_job_index(new_timestamp, job.id) + self._jobs.insert(new_index, (job, new_timestamp)) + + self._jobs_index[old_job.id] = (job, new_timestamp) + + def remove_job(self, job_id): + job, timestamp = self._jobs_index.get(job_id, (None, None)) + if job is None: + raise JobLookupError(job_id) + + index = self._get_job_index(timestamp, job_id) + del self._jobs[index] + del self._jobs_index[job.id] + + def remove_all_jobs(self): + self._jobs = [] + self._jobs_index = {} + + def shutdown(self): + self.remove_all_jobs() + + def _get_job_index(self, timestamp, job_id): + """ + Returns the index of the given job, or if it's not found, the index where the job should be inserted based on + the given timestamp. + + :type timestamp: int + :type job_id: str + """ + + lo, hi = 0, len(self._jobs) + timestamp = float('inf') if timestamp is None else timestamp + while lo < hi: + mid = (lo + hi) // 2 + mid_job, mid_timestamp = self._jobs[mid] + mid_timestamp = float('inf') if mid_timestamp is None else mid_timestamp + if mid_timestamp > timestamp: + hi = mid + elif mid_timestamp < timestamp: + lo = mid + 1 + elif mid_job.id > job_id: + hi = mid + elif mid_job.id < job_id: + lo = mid + 1 + else: + return mid + + return lo diff --git a/lib/apscheduler/jobstores/mongodb.py b/lib/apscheduler/jobstores/mongodb.py new file mode 100644 index 00000000..ff762f7e --- /dev/null +++ b/lib/apscheduler/jobstores/mongodb.py @@ -0,0 +1,124 @@ +from __future__ import absolute_import + +from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError +from apscheduler.util import maybe_ref, datetime_to_utc_timestamp, utc_timestamp_to_datetime +from apscheduler.job import Job + +try: + import cPickle as pickle +except ImportError: # pragma: nocover + import pickle + +try: + from bson.binary import Binary + from pymongo.errors import DuplicateKeyError + from pymongo import MongoClient, ASCENDING +except ImportError: # pragma: nocover + raise ImportError('MongoDBJobStore requires PyMongo installed') + + +class MongoDBJobStore(BaseJobStore): + """ + Stores jobs in a MongoDB database. Any leftover keyword arguments are directly passed to pymongo's `MongoClient + `_. + + Plugin alias: ``mongodb`` + + :param str database: database to store jobs in + :param str collection: collection to store jobs in + :param client: a :class:`~pymongo.mongo_client.MongoClient` instance to use instead of providing connection + arguments + :param int pickle_protocol: pickle protocol level to use (for serialization), defaults to the highest available + """ + + def __init__(self, database='apscheduler', collection='jobs', client=None, + pickle_protocol=pickle.HIGHEST_PROTOCOL, **connect_args): + super(MongoDBJobStore, self).__init__() + self.pickle_protocol = pickle_protocol + + if not database: + raise ValueError('The "database" parameter must not be empty') + if not collection: + raise ValueError('The "collection" parameter must not be empty') + + if client: + self.connection = maybe_ref(client) + else: + connect_args.setdefault('w', 1) + self.connection = MongoClient(**connect_args) + + self.collection = self.connection[database][collection] + self.collection.ensure_index('next_run_time', sparse=True) + + def lookup_job(self, job_id): + document = self.collection.find_one(job_id, ['job_state']) + return self._reconstitute_job(document['job_state']) if document else None + + def get_due_jobs(self, now): + timestamp = datetime_to_utc_timestamp(now) + return self._get_jobs({'next_run_time': {'$lte': timestamp}}) + + def get_next_run_time(self): + document = self.collection.find_one({'next_run_time': {'$ne': None}}, fields=['next_run_time'], + sort=[('next_run_time', ASCENDING)]) + return utc_timestamp_to_datetime(document['next_run_time']) if document else None + + def get_all_jobs(self): + return self._get_jobs({}) + + def add_job(self, job): + try: + self.collection.insert({ + '_id': job.id, + 'next_run_time': datetime_to_utc_timestamp(job.next_run_time), + 'job_state': Binary(pickle.dumps(job.__getstate__(), self.pickle_protocol)) + }) + except DuplicateKeyError: + raise ConflictingIdError(job.id) + + def update_job(self, job): + changes = { + 'next_run_time': datetime_to_utc_timestamp(job.next_run_time), + 'job_state': Binary(pickle.dumps(job.__getstate__(), self.pickle_protocol)) + } + result = self.collection.update({'_id': job.id}, {'$set': changes}) + if result and result['n'] == 0: + raise JobLookupError(id) + + def remove_job(self, job_id): + result = self.collection.remove(job_id) + if result and result['n'] == 0: + raise JobLookupError(job_id) + + def remove_all_jobs(self): + self.collection.remove() + + def shutdown(self): + self.connection.disconnect() + + def _reconstitute_job(self, job_state): + job_state = pickle.loads(job_state) + job = Job.__new__(Job) + job.__setstate__(job_state) + job._scheduler = self._scheduler + job._jobstore_alias = self._alias + return job + + def _get_jobs(self, conditions): + jobs = [] + failed_job_ids = [] + for document in self.collection.find(conditions, ['_id', 'job_state'], sort=[('next_run_time', ASCENDING)]): + try: + jobs.append(self._reconstitute_job(document['job_state'])) + except: + self._logger.exception('Unable to restore job "%s" -- removing it', document['_id']) + failed_job_ids.append(document['_id']) + + # Remove all the jobs we failed to restore + if failed_job_ids: + self.collection.remove({'_id': {'$in': failed_job_ids}}) + + return jobs + + def __repr__(self): + return '<%s (client=%s)>' % (self.__class__.__name__, self.connection) diff --git a/lib/apscheduler/jobstores/redis.py b/lib/apscheduler/jobstores/redis.py new file mode 100644 index 00000000..2b4ffd52 --- /dev/null +++ b/lib/apscheduler/jobstores/redis.py @@ -0,0 +1,138 @@ +from __future__ import absolute_import + +import six + +from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError +from apscheduler.util import datetime_to_utc_timestamp, utc_timestamp_to_datetime +from apscheduler.job import Job + +try: + import cPickle as pickle +except ImportError: # pragma: nocover + import pickle + +try: + from redis import StrictRedis +except ImportError: # pragma: nocover + raise ImportError('RedisJobStore requires redis installed') + + +class RedisJobStore(BaseJobStore): + """ + Stores jobs in a Redis database. Any leftover keyword arguments are directly passed to redis's StrictRedis. + + Plugin alias: ``redis`` + + :param int db: the database number to store jobs in + :param str jobs_key: key to store jobs in + :param str run_times_key: key to store the jobs' run times in + :param int pickle_protocol: pickle protocol level to use (for serialization), defaults to the highest available + """ + + def __init__(self, db=0, jobs_key='apscheduler.jobs', run_times_key='apscheduler.run_times', + pickle_protocol=pickle.HIGHEST_PROTOCOL, **connect_args): + super(RedisJobStore, self).__init__() + + if db is None: + raise ValueError('The "db" parameter must not be empty') + if not jobs_key: + raise ValueError('The "jobs_key" parameter must not be empty') + if not run_times_key: + raise ValueError('The "run_times_key" parameter must not be empty') + + self.pickle_protocol = pickle_protocol + self.jobs_key = jobs_key + self.run_times_key = run_times_key + self.redis = StrictRedis(db=int(db), **connect_args) + + def lookup_job(self, job_id): + job_state = self.redis.hget(self.jobs_key, job_id) + return self._reconstitute_job(job_state) if job_state else None + + def get_due_jobs(self, now): + timestamp = datetime_to_utc_timestamp(now) + job_ids = self.redis.zrangebyscore(self.run_times_key, 0, timestamp) + if job_ids: + job_states = self.redis.hmget(self.jobs_key, *job_ids) + return self._reconstitute_jobs(six.moves.zip(job_ids, job_states)) + return [] + + def get_next_run_time(self): + next_run_time = self.redis.zrange(self.run_times_key, 0, 0, withscores=True) + if next_run_time: + return utc_timestamp_to_datetime(next_run_time[0][1]) + + def get_all_jobs(self): + job_states = self.redis.hgetall(self.jobs_key) + jobs = self._reconstitute_jobs(six.iteritems(job_states)) + return sorted(jobs, key=lambda job: job.next_run_time) + + def add_job(self, job): + if self.redis.hexists(self.jobs_key, job.id): + raise ConflictingIdError(job.id) + + with self.redis.pipeline() as pipe: + pipe.multi() + pipe.hset(self.jobs_key, job.id, pickle.dumps(job.__getstate__(), self.pickle_protocol)) + pipe.zadd(self.run_times_key, datetime_to_utc_timestamp(job.next_run_time), job.id) + pipe.execute() + + def update_job(self, job): + if not self.redis.hexists(self.jobs_key, job.id): + raise JobLookupError(job.id) + + with self.redis.pipeline() as pipe: + pipe.hset(self.jobs_key, job.id, pickle.dumps(job.__getstate__(), self.pickle_protocol)) + if job.next_run_time: + pipe.zadd(self.run_times_key, datetime_to_utc_timestamp(job.next_run_time), job.id) + else: + pipe.zrem(self.run_times_key, job.id) + pipe.execute() + + def remove_job(self, job_id): + if not self.redis.hexists(self.jobs_key, job_id): + raise JobLookupError(job_id) + + with self.redis.pipeline() as pipe: + pipe.hdel(self.jobs_key, job_id) + pipe.zrem(self.run_times_key, job_id) + pipe.execute() + + def remove_all_jobs(self): + with self.redis.pipeline() as pipe: + pipe.delete(self.jobs_key) + pipe.delete(self.run_times_key) + pipe.execute() + + def shutdown(self): + self.redis.connection_pool.disconnect() + + def _reconstitute_job(self, job_state): + job_state = pickle.loads(job_state) + job = Job.__new__(Job) + job.__setstate__(job_state) + job._scheduler = self._scheduler + job._jobstore_alias = self._alias + return job + + def _reconstitute_jobs(self, job_states): + jobs = [] + failed_job_ids = [] + for job_id, job_state in job_states: + try: + jobs.append(self._reconstitute_job(job_state)) + except: + self._logger.exception('Unable to restore job "%s" -- removing it', job_id) + failed_job_ids.append(job_id) + + # Remove all the jobs we failed to restore + if failed_job_ids: + with self.redis.pipeline() as pipe: + pipe.hdel(self.jobs_key, *failed_job_ids) + pipe.zrem(self.run_times_key, *failed_job_ids) + pipe.execute() + + return jobs + + def __repr__(self): + return '<%s>' % self.__class__.__name__ diff --git a/lib/apscheduler/jobstores/sqlalchemy.py b/lib/apscheduler/jobstores/sqlalchemy.py new file mode 100644 index 00000000..f8a3c151 --- /dev/null +++ b/lib/apscheduler/jobstores/sqlalchemy.py @@ -0,0 +1,137 @@ +from __future__ import absolute_import + +from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError +from apscheduler.util import maybe_ref, datetime_to_utc_timestamp, utc_timestamp_to_datetime +from apscheduler.job import Job + +try: + import cPickle as pickle +except ImportError: # pragma: nocover + import pickle + +try: + from sqlalchemy import create_engine, Table, Column, MetaData, Unicode, Float, LargeBinary, select + from sqlalchemy.exc import IntegrityError +except ImportError: # pragma: nocover + raise ImportError('SQLAlchemyJobStore requires SQLAlchemy installed') + + +class SQLAlchemyJobStore(BaseJobStore): + """ + Stores jobs in a database table using SQLAlchemy. The table will be created if it doesn't exist in the database. + + Plugin alias: ``sqlalchemy`` + + :param str url: connection string (see `SQLAlchemy documentation + `_ + on this) + :param engine: an SQLAlchemy Engine to use instead of creating a new one based on ``url`` + :param str tablename: name of the table to store jobs in + :param metadata: a :class:`~sqlalchemy.MetaData` instance to use instead of creating a new one + :param int pickle_protocol: pickle protocol level to use (for serialization), defaults to the highest available + """ + + def __init__(self, url=None, engine=None, tablename='apscheduler_jobs', metadata=None, + pickle_protocol=pickle.HIGHEST_PROTOCOL): + super(SQLAlchemyJobStore, self).__init__() + self.pickle_protocol = pickle_protocol + metadata = maybe_ref(metadata) or MetaData() + + if engine: + self.engine = maybe_ref(engine) + elif url: + self.engine = create_engine(url) + else: + raise ValueError('Need either "engine" or "url" defined') + + # 191 = max key length in MySQL for InnoDB/utf8mb4 tables, 25 = precision that translates to an 8-byte float + self.jobs_t = Table( + tablename, metadata, + Column('id', Unicode(191, _warn_on_bytestring=False), primary_key=True), + Column('next_run_time', Float(25), index=True), + Column('job_state', LargeBinary, nullable=False) + ) + + self.jobs_t.create(self.engine, True) + + def lookup_job(self, job_id): + selectable = select([self.jobs_t.c.job_state]).where(self.jobs_t.c.id == job_id) + job_state = self.engine.execute(selectable).scalar() + return self._reconstitute_job(job_state) if job_state else None + + def get_due_jobs(self, now): + timestamp = datetime_to_utc_timestamp(now) + return self._get_jobs(self.jobs_t.c.next_run_time <= timestamp) + + def get_next_run_time(self): + selectable = select([self.jobs_t.c.next_run_time]).where(self.jobs_t.c.next_run_time != None).\ + order_by(self.jobs_t.c.next_run_time).limit(1) + next_run_time = self.engine.execute(selectable).scalar() + return utc_timestamp_to_datetime(next_run_time) + + def get_all_jobs(self): + return self._get_jobs() + + def add_job(self, job): + insert = self.jobs_t.insert().values(**{ + 'id': job.id, + 'next_run_time': datetime_to_utc_timestamp(job.next_run_time), + 'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol) + }) + try: + self.engine.execute(insert) + except IntegrityError: + raise ConflictingIdError(job.id) + + def update_job(self, job): + update = self.jobs_t.update().values(**{ + 'next_run_time': datetime_to_utc_timestamp(job.next_run_time), + 'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol) + }).where(self.jobs_t.c.id == job.id) + result = self.engine.execute(update) + if result.rowcount == 0: + raise JobLookupError(id) + + def remove_job(self, job_id): + delete = self.jobs_t.delete().where(self.jobs_t.c.id == job_id) + result = self.engine.execute(delete) + if result.rowcount == 0: + raise JobLookupError(job_id) + + def remove_all_jobs(self): + delete = self.jobs_t.delete() + self.engine.execute(delete) + + def shutdown(self): + self.engine.dispose() + + def _reconstitute_job(self, job_state): + job_state = pickle.loads(job_state) + job_state['jobstore'] = self + job = Job.__new__(Job) + job.__setstate__(job_state) + job._scheduler = self._scheduler + job._jobstore_alias = self._alias + return job + + def _get_jobs(self, *conditions): + jobs = [] + selectable = select([self.jobs_t.c.id, self.jobs_t.c.job_state]).order_by(self.jobs_t.c.next_run_time) + selectable = selectable.where(*conditions) if conditions else selectable + failed_job_ids = set() + for row in self.engine.execute(selectable): + try: + jobs.append(self._reconstitute_job(row.job_state)) + except: + self._logger.exception('Unable to restore job "%s" -- removing it', row.id) + failed_job_ids.add(row.id) + + # Remove all the jobs we failed to restore + if failed_job_ids: + delete = self.jobs_t.delete().where(self.jobs_t.c.id.in_(failed_job_ids)) + self.engine.execute(delete) + + return jobs + + def __repr__(self): + return '<%s (url=%s)>' % (self.__class__.__name__, self.engine.url) diff --git a/lib/apscheduler/schedulers/__init__.py b/lib/apscheduler/schedulers/__init__.py new file mode 100644 index 00000000..bd8a7900 --- /dev/null +++ b/lib/apscheduler/schedulers/__init__.py @@ -0,0 +1,12 @@ +class SchedulerAlreadyRunningError(Exception): + """Raised when attempting to start or configure the scheduler when it's already running.""" + + def __str__(self): + return 'Scheduler is already running' + + +class SchedulerNotRunningError(Exception): + """Raised when attempting to shutdown the scheduler when it's not running.""" + + def __str__(self): + return 'Scheduler is not running' diff --git a/lib/apscheduler/schedulers/asyncio.py b/lib/apscheduler/schedulers/asyncio.py new file mode 100644 index 00000000..b91ee97a --- /dev/null +++ b/lib/apscheduler/schedulers/asyncio.py @@ -0,0 +1,68 @@ +from __future__ import absolute_import +from functools import wraps + +from apscheduler.schedulers.base import BaseScheduler +from apscheduler.util import maybe_ref + +try: + import asyncio +except ImportError: # pragma: nocover + try: + import trollius as asyncio + except ImportError: + raise ImportError('AsyncIOScheduler requires either Python 3.4 or the asyncio package installed') + + +def run_in_event_loop(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + self._eventloop.call_soon_threadsafe(func, self, *args, **kwargs) + return wrapper + + +class AsyncIOScheduler(BaseScheduler): + """ + A scheduler that runs on an asyncio (:pep:`3156`) event loop. + + Extra options: + + ============== ============================================================= + ``event_loop`` AsyncIO event loop to use (defaults to the global event loop) + ============== ============================================================= + """ + + _eventloop = None + _timeout = None + + def start(self): + super(AsyncIOScheduler, self).start() + self.wakeup() + + @run_in_event_loop + def shutdown(self, wait=True): + super(AsyncIOScheduler, self).shutdown(wait) + self._stop_timer() + + def _configure(self, config): + self._eventloop = maybe_ref(config.pop('event_loop', None)) or asyncio.get_event_loop() + super(AsyncIOScheduler, self)._configure(config) + + def _start_timer(self, wait_seconds): + self._stop_timer() + if wait_seconds is not None: + self._timeout = self._eventloop.call_later(wait_seconds, self.wakeup) + + def _stop_timer(self): + if self._timeout: + self._timeout.cancel() + del self._timeout + + @run_in_event_loop + def wakeup(self): + self._stop_timer() + wait_seconds = self._process_jobs() + self._start_timer(wait_seconds) + + def _create_default_executor(self): + from apscheduler.executors.asyncio import AsyncIOExecutor + return AsyncIOExecutor() diff --git a/lib/apscheduler/schedulers/background.py b/lib/apscheduler/schedulers/background.py new file mode 100644 index 00000000..86ff2ba3 --- /dev/null +++ b/lib/apscheduler/schedulers/background.py @@ -0,0 +1,39 @@ +from __future__ import absolute_import +from threading import Thread, Event + +from apscheduler.schedulers.base import BaseScheduler +from apscheduler.schedulers.blocking import BlockingScheduler +from apscheduler.util import asbool + + +class BackgroundScheduler(BlockingScheduler): + """ + A scheduler that runs in the background using a separate thread + (:meth:`~apscheduler.schedulers.base.BaseScheduler.start` will return immediately). + + Extra options: + + ========== ============================================================================================ + ``daemon`` Set the ``daemon`` option in the background thread (defaults to ``True``, + see `the documentation `_ + for further details) + ========== ============================================================================================ + """ + + _thread = None + + def _configure(self, config): + self._daemon = asbool(config.pop('daemon', True)) + super(BackgroundScheduler, self)._configure(config) + + def start(self): + BaseScheduler.start(self) + self._event = Event() + self._thread = Thread(target=self._main_loop, name='APScheduler') + self._thread.daemon = self._daemon + self._thread.start() + + def shutdown(self, wait=True): + super(BackgroundScheduler, self).shutdown(wait) + self._thread.join() + del self._thread diff --git a/lib/apscheduler/schedulers/base.py b/lib/apscheduler/schedulers/base.py new file mode 100644 index 00000000..cdc0bf06 --- /dev/null +++ b/lib/apscheduler/schedulers/base.py @@ -0,0 +1,845 @@ +from __future__ import print_function +from abc import ABCMeta, abstractmethod +from collections import MutableMapping +from threading import RLock +from datetime import datetime +from logging import getLogger +import sys + +from pkg_resources import iter_entry_points +from tzlocal import get_localzone +import six + +from apscheduler.schedulers import SchedulerAlreadyRunningError, SchedulerNotRunningError +from apscheduler.executors.base import MaxInstancesReachedError, BaseExecutor +from apscheduler.executors.pool import ThreadPoolExecutor +from apscheduler.jobstores.base import ConflictingIdError, JobLookupError, BaseJobStore +from apscheduler.jobstores.memory import MemoryJobStore +from apscheduler.job import Job +from apscheduler.triggers.base import BaseTrigger +from apscheduler.util import asbool, asint, astimezone, maybe_ref, timedelta_seconds, undefined +from apscheduler.events import ( + SchedulerEvent, JobEvent, EVENT_SCHEDULER_START, EVENT_SCHEDULER_SHUTDOWN, EVENT_JOBSTORE_ADDED, + EVENT_JOBSTORE_REMOVED, EVENT_ALL, EVENT_JOB_MODIFIED, EVENT_JOB_REMOVED, EVENT_JOB_ADDED, EVENT_EXECUTOR_ADDED, + EVENT_EXECUTOR_REMOVED, EVENT_ALL_JOBS_REMOVED) + + +class BaseScheduler(six.with_metaclass(ABCMeta)): + """ + Abstract base class for all schedulers. Takes the following keyword arguments: + + :param str|logging.Logger logger: logger to use for the scheduler's logging (defaults to apscheduler.scheduler) + :param str|datetime.tzinfo timezone: the default time zone (defaults to the local timezone) + :param dict job_defaults: default values for newly added jobs + :param dict jobstores: a dictionary of job store alias -> job store instance or configuration dict + :param dict executors: a dictionary of executor alias -> executor instance or configuration dict + + .. seealso:: :ref:`scheduler-config` + """ + + _trigger_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.triggers')) + _trigger_classes = {} + _executor_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.executors')) + _executor_classes = {} + _jobstore_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.jobstores')) + _jobstore_classes = {} + _stopped = True + + # + # Public API + # + + def __init__(self, gconfig={}, **options): + super(BaseScheduler, self).__init__() + self._executors = {} + self._executors_lock = self._create_lock() + self._jobstores = {} + self._jobstores_lock = self._create_lock() + self._listeners = [] + self._listeners_lock = self._create_lock() + self._pending_jobs = [] + self.configure(gconfig, **options) + + def configure(self, gconfig={}, prefix='apscheduler.', **options): + """ + Reconfigures the scheduler with the given options. Can only be done when the scheduler isn't running. + + :param dict gconfig: a "global" configuration dictionary whose values can be overridden by keyword arguments to + this method + :param str|unicode prefix: pick only those keys from ``gconfig`` that are prefixed with this string + (pass an empty string or ``None`` to use all keys) + :raises SchedulerAlreadyRunningError: if the scheduler is already running + """ + + if self.running: + raise SchedulerAlreadyRunningError + + # If a non-empty prefix was given, strip it from the keys in the global configuration dict + if prefix: + prefixlen = len(prefix) + gconfig = dict((key[prefixlen:], value) for key, value in six.iteritems(gconfig) if key.startswith(prefix)) + + # Create a structure from the dotted options (e.g. "a.b.c = d" -> {'a': {'b': {'c': 'd'}}}) + config = {} + for key, value in six.iteritems(gconfig): + parts = key.split('.') + parent = config + key = parts.pop(0) + while parts: + parent = parent.setdefault(key, {}) + key = parts.pop(0) + parent[key] = value + + # Override any options with explicit keyword arguments + config.update(options) + self._configure(config) + + @abstractmethod + def start(self): + """ + Starts the scheduler. The details of this process depend on the implementation. + + :raises SchedulerAlreadyRunningError: if the scheduler is already running + """ + + if self.running: + raise SchedulerAlreadyRunningError + + with self._executors_lock: + # Create a default executor if nothing else is configured + if 'default' not in self._executors: + self.add_executor(self._create_default_executor(), 'default') + + # Start all the executors + for alias, executor in six.iteritems(self._executors): + executor.start(self, alias) + + with self._jobstores_lock: + # Create a default job store if nothing else is configured + if 'default' not in self._jobstores: + self.add_jobstore(self._create_default_jobstore(), 'default') + + # Start all the job stores + for alias, store in six.iteritems(self._jobstores): + store.start(self, alias) + + # Schedule all pending jobs + for job, jobstore_alias, replace_existing in self._pending_jobs: + self._real_add_job(job, jobstore_alias, replace_existing, False) + del self._pending_jobs[:] + + self._stopped = False + self._logger.info('Scheduler started') + + # Notify listeners that the scheduler has been started + self._dispatch_event(SchedulerEvent(EVENT_SCHEDULER_START)) + + @abstractmethod + def shutdown(self, wait=True): + """ + Shuts down the scheduler. Does not interrupt any currently running jobs. + + :param bool wait: ``True`` to wait until all currently executing jobs have finished + :raises SchedulerNotRunningError: if the scheduler has not been started yet + """ + + if not self.running: + raise SchedulerNotRunningError + + self._stopped = True + + # Shut down all executors + for executor in six.itervalues(self._executors): + executor.shutdown(wait) + + # Shut down all job stores + for jobstore in six.itervalues(self._jobstores): + jobstore.shutdown() + + self._logger.info('Scheduler has been shut down') + self._dispatch_event(SchedulerEvent(EVENT_SCHEDULER_SHUTDOWN)) + + @property + def running(self): + return not self._stopped + + def add_executor(self, executor, alias='default', **executor_opts): + """ + Adds an executor to this scheduler. Any extra keyword arguments will be passed to the executor plugin's + constructor, assuming that the first argument is the name of an executor plugin. + + :param str|unicode|apscheduler.executors.base.BaseExecutor executor: either an executor instance or the name of + an executor plugin + :param str|unicode alias: alias for the scheduler + :raises ValueError: if there is already an executor by the given alias + """ + + with self._executors_lock: + if alias in self._executors: + raise ValueError('This scheduler already has an executor by the alias of "%s"' % alias) + + if isinstance(executor, BaseExecutor): + self._executors[alias] = executor + elif isinstance(executor, six.string_types): + self._executors[alias] = executor = self._create_plugin_instance('executor', executor, executor_opts) + else: + raise TypeError('Expected an executor instance or a string, got %s instead' % + executor.__class__.__name__) + + # Start the executor right away if the scheduler is running + if self.running: + executor.start(self) + + self._dispatch_event(SchedulerEvent(EVENT_EXECUTOR_ADDED, alias)) + + def remove_executor(self, alias, shutdown=True): + """ + Removes the executor by the given alias from this scheduler. + + :param str|unicode alias: alias of the executor + :param bool shutdown: ``True`` to shut down the executor after removing it + """ + + with self._jobstores_lock: + executor = self._lookup_executor(alias) + del self._executors[alias] + + if shutdown: + executor.shutdown() + + self._dispatch_event(SchedulerEvent(EVENT_EXECUTOR_REMOVED, alias)) + + def add_jobstore(self, jobstore, alias='default', **jobstore_opts): + """ + Adds a job store to this scheduler. Any extra keyword arguments will be passed to the job store plugin's + constructor, assuming that the first argument is the name of a job store plugin. + + :param str|unicode|apscheduler.jobstores.base.BaseJobStore jobstore: job store to be added + :param str|unicode alias: alias for the job store + :raises ValueError: if there is already a job store by the given alias + """ + + with self._jobstores_lock: + if alias in self._jobstores: + raise ValueError('This scheduler already has a job store by the alias of "%s"' % alias) + + if isinstance(jobstore, BaseJobStore): + self._jobstores[alias] = jobstore + elif isinstance(jobstore, six.string_types): + self._jobstores[alias] = jobstore = self._create_plugin_instance('jobstore', jobstore, jobstore_opts) + else: + raise TypeError('Expected a job store instance or a string, got %s instead' % + jobstore.__class__.__name__) + + # Start the job store right away if the scheduler is running + if self.running: + jobstore.start(self, alias) + + # Notify listeners that a new job store has been added + self._dispatch_event(SchedulerEvent(EVENT_JOBSTORE_ADDED, alias)) + + # Notify the scheduler so it can scan the new job store for jobs + if self.running: + self.wakeup() + + def remove_jobstore(self, alias, shutdown=True): + """ + Removes the job store by the given alias from this scheduler. + + :param str|unicode alias: alias of the job store + :param bool shutdown: ``True`` to shut down the job store after removing it + """ + + with self._jobstores_lock: + jobstore = self._lookup_jobstore(alias) + del self._jobstores[alias] + + if shutdown: + jobstore.shutdown() + + self._dispatch_event(SchedulerEvent(EVENT_JOBSTORE_REMOVED, alias)) + + def add_listener(self, callback, mask=EVENT_ALL): + """ + add_listener(callback, mask=EVENT_ALL) + + Adds a listener for scheduler events. When a matching event occurs, ``callback`` is executed with the event + object as its sole argument. If the ``mask`` parameter is not provided, the callback will receive events of all + types. + + :param callback: any callable that takes one argument + :param int mask: bitmask that indicates which events should be listened to + + .. seealso:: :mod:`apscheduler.events` + .. seealso:: :ref:`scheduler-events` + """ + + with self._listeners_lock: + self._listeners.append((callback, mask)) + + def remove_listener(self, callback): + """Removes a previously added event listener.""" + + with self._listeners_lock: + for i, (cb, _) in enumerate(self._listeners): + if callback == cb: + del self._listeners[i] + + def add_job(self, func, trigger=None, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined, + coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default', + executor='default', replace_existing=False, **trigger_args): + """ + add_job(func, trigger=None, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined, \ + coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default', \ + executor='default', replace_existing=False, **trigger_args) + + Adds the given job to the job list and wakes up the scheduler if it's already running. + + Any option that defaults to ``undefined`` will be replaced with the corresponding default value when the job is + scheduled (which happens when the scheduler is started, or immediately if the scheduler is already running). + + The ``func`` argument can be given either as a callable object or a textual reference in the + ``package.module:some.object`` format, where the first half (separated by ``:``) is an importable module and the + second half is a reference to the callable object, relative to the module. + + The ``trigger`` argument can either be: + #. the alias name of the trigger (e.g. ``date``, ``interval`` or ``cron``), in which case any extra keyword + arguments to this method are passed on to the trigger's constructor + #. an instance of a trigger class + + :param func: callable (or a textual reference to one) to run at the given time + :param str|apscheduler.triggers.base.BaseTrigger trigger: trigger that determines when ``func`` is called + :param list|tuple args: list of positional arguments to call func with + :param dict kwargs: dict of keyword arguments to call func with + :param str|unicode id: explicit identifier for the job (for modifying it later) + :param str|unicode name: textual description of the job + :param int misfire_grace_time: seconds after the designated run time that the job is still allowed to be run + :param bool coalesce: run once instead of many times if the scheduler determines that the job should be run more + than once in succession + :param int max_instances: maximum number of concurrently running instances allowed for this job + :param datetime next_run_time: when to first run the job, regardless of the trigger (pass ``None`` to add the + job as paused) + :param str|unicode jobstore: alias of the job store to store the job in + :param str|unicode executor: alias of the executor to run the job with + :param bool replace_existing: ``True`` to replace an existing job with the same ``id`` (but retain the + number of runs from the existing one) + :rtype: Job + """ + + job_kwargs = { + 'trigger': self._create_trigger(trigger, trigger_args), + 'executor': executor, + 'func': func, + 'args': tuple(args) if args is not None else (), + 'kwargs': dict(kwargs) if kwargs is not None else {}, + 'id': id, + 'name': name, + 'misfire_grace_time': misfire_grace_time, + 'coalesce': coalesce, + 'max_instances': max_instances, + 'next_run_time': next_run_time + } + job_kwargs = dict((key, value) for key, value in six.iteritems(job_kwargs) if value is not undefined) + job = Job(self, **job_kwargs) + + # Don't really add jobs to job stores before the scheduler is up and running + with self._jobstores_lock: + if not self.running: + self._pending_jobs.append((job, jobstore, replace_existing)) + self._logger.info('Adding job tentatively -- it will be properly scheduled when the scheduler starts') + else: + self._real_add_job(job, jobstore, replace_existing, True) + + return job + + def scheduled_job(self, trigger, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined, + coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default', + executor='default', **trigger_args): + """ + scheduled_job(trigger, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined, \ + coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default', \ + executor='default',**trigger_args) + + A decorator version of :meth:`add_job`, except that ``replace_existing`` is always ``True``. + + .. important:: The ``id`` argument must be given if scheduling a job in a persistent job store. The scheduler + cannot, however, enforce this requirement. + """ + + def inner(func): + self.add_job(func, trigger, args, kwargs, id, name, misfire_grace_time, coalesce, max_instances, + next_run_time, jobstore, executor, True, **trigger_args) + return func + return inner + + def modify_job(self, job_id, jobstore=None, **changes): + """ + Modifies the properties of a single job. Modifications are passed to this method as extra keyword arguments. + + :param str|unicode job_id: the identifier of the job + :param str|unicode jobstore: alias of the job store that contains the job + """ + with self._jobstores_lock: + job, jobstore = self._lookup_job(job_id, jobstore) + job._modify(**changes) + if jobstore: + self._lookup_jobstore(jobstore).update_job(job) + + self._dispatch_event(JobEvent(EVENT_JOB_MODIFIED, job_id, jobstore)) + + # Wake up the scheduler since the job's next run time may have been changed + self.wakeup() + + def reschedule_job(self, job_id, jobstore=None, trigger=None, **trigger_args): + """ + Constructs a new trigger for a job and updates its next run time. + Extra keyword arguments are passed directly to the trigger's constructor. + + :param str|unicode job_id: the identifier of the job + :param str|unicode jobstore: alias of the job store that contains the job + :param trigger: alias of the trigger type or a trigger instance + """ + + trigger = self._create_trigger(trigger, trigger_args) + now = datetime.now(self.timezone) + next_run_time = trigger.get_next_fire_time(None, now) + self.modify_job(job_id, jobstore, trigger=trigger, next_run_time=next_run_time) + + def pause_job(self, job_id, jobstore=None): + """ + Causes the given job not to be executed until it is explicitly resumed. + + :param str|unicode job_id: the identifier of the job + :param str|unicode jobstore: alias of the job store that contains the job + """ + + self.modify_job(job_id, jobstore, next_run_time=None) + + def resume_job(self, job_id, jobstore=None): + """ + Resumes the schedule of the given job, or removes the job if its schedule is finished. + + :param str|unicode job_id: the identifier of the job + :param str|unicode jobstore: alias of the job store that contains the job + """ + + with self._jobstores_lock: + job, jobstore = self._lookup_job(job_id, jobstore) + now = datetime.now(self.timezone) + next_run_time = job.trigger.get_next_fire_time(None, now) + if next_run_time: + self.modify_job(job_id, jobstore, next_run_time=next_run_time) + else: + self.remove_job(job.id, jobstore) + + def get_jobs(self, jobstore=None, pending=None): + """ + Returns a list of pending jobs (if the scheduler hasn't been started yet) and scheduled jobs, either from a + specific job store or from all of them. + + :param str|unicode jobstore: alias of the job store + :param bool pending: ``False`` to leave out pending jobs (jobs that are waiting for the scheduler start to be + added to their respective job stores), ``True`` to only include pending jobs, anything else + to return both + :rtype: list[Job] + """ + + with self._jobstores_lock: + jobs = [] + + if pending is not False: + for job, alias, replace_existing in self._pending_jobs: + if jobstore is None or alias == jobstore: + jobs.append(job) + + if pending is not True: + for alias, store in six.iteritems(self._jobstores): + if jobstore is None or alias == jobstore: + jobs.extend(store.get_all_jobs()) + + return jobs + + def get_job(self, job_id, jobstore=None): + """ + Returns the Job that matches the given ``job_id``. + + :param str|unicode job_id: the identifier of the job + :param str|unicode jobstore: alias of the job store that most likely contains the job + :return: the Job by the given ID, or ``None`` if it wasn't found + :rtype: Job + """ + + with self._jobstores_lock: + try: + return self._lookup_job(job_id, jobstore)[0] + except JobLookupError: + return + + def remove_job(self, job_id, jobstore=None): + """ + Removes a job, preventing it from being run any more. + + :param str|unicode job_id: the identifier of the job + :param str|unicode jobstore: alias of the job store that contains the job + :raises JobLookupError: if the job was not found + """ + + with self._jobstores_lock: + # Check if the job is among the pending jobs + for i, (job, jobstore_alias, replace_existing) in enumerate(self._pending_jobs): + if job.id == job_id: + del self._pending_jobs[i] + jobstore = jobstore_alias + break + else: + # Otherwise, try to remove it from each store until it succeeds or we run out of stores to check + for alias, store in six.iteritems(self._jobstores): + if jobstore in (None, alias): + try: + store.remove_job(job_id) + except JobLookupError: + continue + + jobstore = alias + break + + if jobstore is None: + raise JobLookupError(job_id) + + # Notify listeners that a job has been removed + event = JobEvent(EVENT_JOB_REMOVED, job_id, jobstore) + self._dispatch_event(event) + + self._logger.info('Removed job %s', job_id) + + def remove_all_jobs(self, jobstore=None): + """ + Removes all jobs from the specified job store, or all job stores if none is given. + + :param str|unicode jobstore: alias of the job store + """ + + with self._jobstores_lock: + if jobstore: + self._pending_jobs = [pending for pending in self._pending_jobs if pending[1] != jobstore] + else: + self._pending_jobs = [] + + for alias, store in six.iteritems(self._jobstores): + if jobstore in (None, alias): + store.remove_all_jobs() + + self._dispatch_event(SchedulerEvent(EVENT_ALL_JOBS_REMOVED, jobstore)) + + def print_jobs(self, jobstore=None, out=None): + """ + print_jobs(jobstore=None, out=sys.stdout) + + Prints out a textual listing of all jobs currently scheduled on either all job stores or just a specific one. + + :param str|unicode jobstore: alias of the job store, ``None`` to list jobs from all stores + :param file out: a file-like object to print to (defaults to **sys.stdout** if nothing is given) + """ + + out = out or sys.stdout + with self._jobstores_lock: + if self._pending_jobs: + print(six.u('Pending jobs:'), file=out) + for job, jobstore_alias, replace_existing in self._pending_jobs: + if jobstore in (None, jobstore_alias): + print(six.u(' %s') % job, file=out) + + for alias, store in six.iteritems(self._jobstores): + if jobstore in (None, alias): + print(six.u('Jobstore %s:') % alias, file=out) + jobs = store.get_all_jobs() + if jobs: + for job in jobs: + print(six.u(' %s') % job, file=out) + else: + print(six.u(' No scheduled jobs'), file=out) + + @abstractmethod + def wakeup(self): + """ + Notifies the scheduler that there may be jobs due for execution. + Triggers :meth:`_process_jobs` to be run in an implementation specific manner. + """ + + # + # Private API + # + + def _configure(self, config): + # Set general options + self._logger = maybe_ref(config.pop('logger', None)) or getLogger('apscheduler.scheduler') + self.timezone = astimezone(config.pop('timezone', None)) or get_localzone() + + # Set the job defaults + job_defaults = config.get('job_defaults', {}) + self._job_defaults = { + 'misfire_grace_time': asint(job_defaults.get('misfire_grace_time', 1)), + 'coalesce': asbool(job_defaults.get('coalesce', True)), + 'max_instances': asint(job_defaults.get('max_instances', 1)) + } + + # Configure executors + self._executors.clear() + for alias, value in six.iteritems(config.get('executors', {})): + if isinstance(value, BaseExecutor): + self.add_executor(value, alias) + elif isinstance(value, MutableMapping): + executor_class = value.pop('class', None) + plugin = value.pop('type', None) + if plugin: + executor = self._create_plugin_instance('executor', plugin, value) + elif executor_class: + cls = maybe_ref(executor_class) + executor = cls(**value) + else: + raise ValueError('Cannot create executor "%s" -- either "type" or "class" must be defined' % alias) + + self.add_executor(executor, alias) + else: + raise TypeError("Expected executor instance or dict for executors['%s'], got %s instead" % ( + alias, value.__class__.__name__)) + + # Configure job stores + self._jobstores.clear() + for alias, value in six.iteritems(config.get('jobstores', {})): + if isinstance(value, BaseJobStore): + self.add_jobstore(value, alias) + elif isinstance(value, MutableMapping): + jobstore_class = value.pop('class', None) + plugin = value.pop('type', None) + if plugin: + jobstore = self._create_plugin_instance('jobstore', plugin, value) + elif jobstore_class: + cls = maybe_ref(jobstore_class) + jobstore = cls(**value) + else: + raise ValueError('Cannot create job store "%s" -- either "type" or "class" must be defined' % alias) + + self.add_jobstore(jobstore, alias) + else: + raise TypeError("Expected job store instance or dict for jobstores['%s'], got %s instead" % ( + alias, value.__class__.__name__)) + + def _create_default_executor(self): + """Creates a default executor store, specific to the particular scheduler type.""" + + return ThreadPoolExecutor() + + def _create_default_jobstore(self): + """Creates a default job store, specific to the particular scheduler type.""" + + return MemoryJobStore() + + def _lookup_executor(self, alias): + """ + Returns the executor instance by the given name from the list of executors that were added to this scheduler. + + :type alias: str + :raises KeyError: if no executor by the given alias is not found + """ + + try: + return self._executors[alias] + except KeyError: + raise KeyError('No such executor: %s' % alias) + + def _lookup_jobstore(self, alias): + """ + Returns the job store instance by the given name from the list of job stores that were added to this scheduler. + + :type alias: str + :raises KeyError: if no job store by the given alias is not found + """ + + try: + return self._jobstores[alias] + except KeyError: + raise KeyError('No such job store: %s' % alias) + + def _lookup_job(self, job_id, jobstore_alias): + """ + Finds a job by its ID. + + :type job_id: str + :param str jobstore_alias: alias of a job store to look in + :return tuple[Job, str]: a tuple of job, jobstore alias (jobstore alias is None in case of a pending job) + :raises JobLookupError: if no job by the given ID is found. + """ + + # Check if the job is among the pending jobs + for job, alias, replace_existing in self._pending_jobs: + if job.id == job_id: + return job, None + + # Look in all job stores + for alias, store in six.iteritems(self._jobstores): + if jobstore_alias in (None, alias): + job = store.lookup_job(job_id) + if job is not None: + return job, alias + + raise JobLookupError(job_id) + + def _dispatch_event(self, event): + """ + Dispatches the given event to interested listeners. + + :param SchedulerEvent event: the event to send + """ + + with self._listeners_lock: + listeners = tuple(self._listeners) + + for cb, mask in listeners: + if event.code & mask: + try: + cb(event) + except: + self._logger.exception('Error notifying listener') + + def _real_add_job(self, job, jobstore_alias, replace_existing, wakeup): + """ + :param Job job: the job to add + :param bool replace_existing: ``True`` to use update_job() in case the job already exists in the store + :param bool wakeup: ``True`` to wake up the scheduler after adding the job + """ + + # Fill in undefined values with defaults + replacements = {} + for key, value in six.iteritems(self._job_defaults): + if not hasattr(job, key): + replacements[key] = value + + # Calculate the next run time if there is none defined + if not hasattr(job, 'next_run_time'): + now = datetime.now(self.timezone) + replacements['next_run_time'] = job.trigger.get_next_fire_time(None, now) + + # Apply any replacements + job._modify(**replacements) + + # Add the job to the given job store + store = self._lookup_jobstore(jobstore_alias) + try: + store.add_job(job) + except ConflictingIdError: + if replace_existing: + store.update_job(job) + else: + raise + + # Mark the job as no longer pending + job._jobstore_alias = jobstore_alias + + # Notify listeners that a new job has been added + event = JobEvent(EVENT_JOB_ADDED, job.id, jobstore_alias) + self._dispatch_event(event) + + self._logger.info('Added job "%s" to job store "%s"', job.name, jobstore_alias) + + # Notify the scheduler about the new job + if wakeup: + self.wakeup() + + def _create_plugin_instance(self, type_, alias, constructor_kwargs): + """Creates an instance of the given plugin type, loading the plugin first if necessary.""" + + plugin_container, class_container, base_class = { + 'trigger': (self._trigger_plugins, self._trigger_classes, BaseTrigger), + 'jobstore': (self._jobstore_plugins, self._jobstore_classes, BaseJobStore), + 'executor': (self._executor_plugins, self._executor_classes, BaseExecutor) + }[type_] + + try: + plugin_cls = class_container[alias] + except KeyError: + if alias in plugin_container: + plugin_cls = class_container[alias] = plugin_container[alias].load() + if not issubclass(plugin_cls, base_class): + raise TypeError('The {0} entry point does not point to a {0} class'.format(type_)) + else: + raise LookupError('No {0} by the name "{1}" was found'.format(type_, alias)) + + return plugin_cls(**constructor_kwargs) + + def _create_trigger(self, trigger, trigger_args): + if isinstance(trigger, BaseTrigger): + return trigger + elif trigger is None: + trigger = 'date' + elif not isinstance(trigger, six.string_types): + raise TypeError('Expected a trigger instance or string, got %s instead' % trigger.__class__.__name__) + + # Use the scheduler's time zone if nothing else is specified + trigger_args.setdefault('timezone', self.timezone) + + # Instantiate the trigger class + return self._create_plugin_instance('trigger', trigger, trigger_args) + + def _create_lock(self): + """Creates a reentrant lock object.""" + + return RLock() + + def _process_jobs(self): + """ + Iterates through jobs in every jobstore, starts jobs that are due and figures out how long to wait for the next + round. + """ + + self._logger.debug('Looking for jobs to run') + now = datetime.now(self.timezone) + next_wakeup_time = None + + with self._jobstores_lock: + for jobstore_alias, jobstore in six.iteritems(self._jobstores): + for job in jobstore.get_due_jobs(now): + # Look up the job's executor + try: + executor = self._lookup_executor(job.executor) + except: + self._logger.error( + 'Executor lookup ("%s") failed for job "%s" -- removing it from the job store', + job.executor, job) + self.remove_job(job.id, jobstore_alias) + continue + + run_times = job._get_run_times(now) + run_times = run_times[-1:] if run_times and job.coalesce else run_times + if run_times: + try: + executor.submit_job(job, run_times) + except MaxInstancesReachedError: + self._logger.warning( + 'Execution of job "%s" skipped: maximum number of running instances reached (%d)', + job, job.max_instances) + except: + self._logger.exception('Error submitting job "%s" to executor "%s"', job, job.executor) + + # Update the job if it has a next execution time. Otherwise remove it from the job store. + job_next_run = job.trigger.get_next_fire_time(run_times[-1], now) + if job_next_run: + job._modify(next_run_time=job_next_run) + jobstore.update_job(job) + else: + self.remove_job(job.id, jobstore_alias) + + # Set a new next wakeup time if there isn't one yet or the jobstore has an even earlier one + jobstore_next_run_time = jobstore.get_next_run_time() + if jobstore_next_run_time and (next_wakeup_time is None or jobstore_next_run_time < next_wakeup_time): + next_wakeup_time = jobstore_next_run_time + + # Determine the delay until this method should be called again + if next_wakeup_time is not None: + wait_seconds = max(timedelta_seconds(next_wakeup_time - now), 0) + self._logger.debug('Next wakeup is due at %s (in %f seconds)', next_wakeup_time, wait_seconds) + else: + wait_seconds = None + self._logger.debug('No jobs; waiting until a job is added') + + return wait_seconds diff --git a/lib/apscheduler/schedulers/blocking.py b/lib/apscheduler/schedulers/blocking.py new file mode 100644 index 00000000..2720822c --- /dev/null +++ b/lib/apscheduler/schedulers/blocking.py @@ -0,0 +1,32 @@ +from __future__ import absolute_import +from threading import Event + +from apscheduler.schedulers.base import BaseScheduler + + +class BlockingScheduler(BaseScheduler): + """ + A scheduler that runs in the foreground (:meth:`~apscheduler.schedulers.base.BaseScheduler.start` will block). + """ + + MAX_WAIT_TIME = 4294967 # Maximum value accepted by Event.wait() on Windows + + _event = None + + def start(self): + super(BlockingScheduler, self).start() + self._event = Event() + self._main_loop() + + def shutdown(self, wait=True): + super(BlockingScheduler, self).shutdown(wait) + self._event.set() + + def _main_loop(self): + while self.running: + wait_seconds = self._process_jobs() + self._event.wait(wait_seconds if wait_seconds is not None else self.MAX_WAIT_TIME) + self._event.clear() + + def wakeup(self): + self._event.set() diff --git a/lib/apscheduler/schedulers/gevent.py b/lib/apscheduler/schedulers/gevent.py new file mode 100644 index 00000000..9cce6589 --- /dev/null +++ b/lib/apscheduler/schedulers/gevent.py @@ -0,0 +1,35 @@ +from __future__ import absolute_import + +from apscheduler.schedulers.blocking import BlockingScheduler +from apscheduler.schedulers.base import BaseScheduler + +try: + from gevent.event import Event + from gevent.lock import RLock + import gevent +except ImportError: # pragma: nocover + raise ImportError('GeventScheduler requires gevent installed') + + +class GeventScheduler(BlockingScheduler): + """A scheduler that runs as a Gevent greenlet.""" + + _greenlet = None + + def start(self): + BaseScheduler.start(self) + self._event = Event() + self._greenlet = gevent.spawn(self._main_loop) + return self._greenlet + + def shutdown(self, wait=True): + super(GeventScheduler, self).shutdown(wait) + self._greenlet.join() + del self._greenlet + + def _create_lock(self): + return RLock() + + def _create_default_executor(self): + from apscheduler.executors.gevent import GeventExecutor + return GeventExecutor() diff --git a/lib/apscheduler/schedulers/qt.py b/lib/apscheduler/schedulers/qt.py new file mode 100644 index 00000000..dde5afaa --- /dev/null +++ b/lib/apscheduler/schedulers/qt.py @@ -0,0 +1,46 @@ +from __future__ import absolute_import + +from apscheduler.schedulers.base import BaseScheduler + +try: + from PyQt5.QtCore import QObject, QTimer +except ImportError: # pragma: nocover + try: + from PyQt4.QtCore import QObject, QTimer + except ImportError: + try: + from PySide.QtCore import QObject, QTimer # flake8: noqa + except ImportError: + raise ImportError('QtScheduler requires either PyQt5, PyQt4 or PySide installed') + + +class QtScheduler(BaseScheduler): + """A scheduler that runs in a Qt event loop.""" + + _timer = None + + def start(self): + super(QtScheduler, self).start() + self.wakeup() + + def shutdown(self, wait=True): + super(QtScheduler, self).shutdown(wait) + self._stop_timer() + + def _start_timer(self, wait_seconds): + self._stop_timer() + if wait_seconds is not None: + self._timer = QTimer.singleShot(wait_seconds * 1000, self._process_jobs) + + def _stop_timer(self): + if self._timer: + if self._timer.isActive(): + self._timer.stop() + del self._timer + + def wakeup(self): + self._start_timer(0) + + def _process_jobs(self): + wait_seconds = super(QtScheduler, self)._process_jobs() + self._start_timer(wait_seconds) diff --git a/lib/apscheduler/schedulers/tornado.py b/lib/apscheduler/schedulers/tornado.py new file mode 100644 index 00000000..78093308 --- /dev/null +++ b/lib/apscheduler/schedulers/tornado.py @@ -0,0 +1,60 @@ +from __future__ import absolute_import +from datetime import timedelta +from functools import wraps + +from apscheduler.schedulers.base import BaseScheduler +from apscheduler.util import maybe_ref + +try: + from tornado.ioloop import IOLoop +except ImportError: # pragma: nocover + raise ImportError('TornadoScheduler requires tornado installed') + + +def run_in_ioloop(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + self._ioloop.add_callback(func, self, *args, **kwargs) + return wrapper + + +class TornadoScheduler(BaseScheduler): + """ + A scheduler that runs on a Tornado IOLoop. + + =========== =============================================================== + ``io_loop`` Tornado IOLoop instance to use (defaults to the global IO loop) + =========== =============================================================== + """ + + _ioloop = None + _timeout = None + + def start(self): + super(TornadoScheduler, self).start() + self.wakeup() + + @run_in_ioloop + def shutdown(self, wait=True): + super(TornadoScheduler, self).shutdown(wait) + self._stop_timer() + + def _configure(self, config): + self._ioloop = maybe_ref(config.pop('io_loop', None)) or IOLoop.current() + super(TornadoScheduler, self)._configure(config) + + def _start_timer(self, wait_seconds): + self._stop_timer() + if wait_seconds is not None: + self._timeout = self._ioloop.add_timeout(timedelta(seconds=wait_seconds), self.wakeup) + + def _stop_timer(self): + if self._timeout: + self._ioloop.remove_timeout(self._timeout) + del self._timeout + + @run_in_ioloop + def wakeup(self): + self._stop_timer() + wait_seconds = self._process_jobs() + self._start_timer(wait_seconds) diff --git a/lib/apscheduler/schedulers/twisted.py b/lib/apscheduler/schedulers/twisted.py new file mode 100644 index 00000000..166b6130 --- /dev/null +++ b/lib/apscheduler/schedulers/twisted.py @@ -0,0 +1,65 @@ +from __future__ import absolute_import +from functools import wraps + +from apscheduler.schedulers.base import BaseScheduler +from apscheduler.util import maybe_ref + +try: + from twisted.internet import reactor as default_reactor +except ImportError: # pragma: nocover + raise ImportError('TwistedScheduler requires Twisted installed') + + +def run_in_reactor(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + self._reactor.callFromThread(func, self, *args, **kwargs) + return wrapper + + +class TwistedScheduler(BaseScheduler): + """ + A scheduler that runs on a Twisted reactor. + + Extra options: + + =========== ======================================================== + ``reactor`` Reactor instance to use (defaults to the global reactor) + =========== ======================================================== + """ + + _reactor = None + _delayedcall = None + + def _configure(self, config): + self._reactor = maybe_ref(config.pop('reactor', default_reactor)) + super(TwistedScheduler, self)._configure(config) + + def start(self): + super(TwistedScheduler, self).start() + self.wakeup() + + @run_in_reactor + def shutdown(self, wait=True): + super(TwistedScheduler, self).shutdown(wait) + self._stop_timer() + + def _start_timer(self, wait_seconds): + self._stop_timer() + if wait_seconds is not None: + self._delayedcall = self._reactor.callLater(wait_seconds, self.wakeup) + + def _stop_timer(self): + if self._delayedcall and self._delayedcall.active(): + self._delayedcall.cancel() + del self._delayedcall + + @run_in_reactor + def wakeup(self): + self._stop_timer() + wait_seconds = self._process_jobs() + self._start_timer(wait_seconds) + + def _create_default_executor(self): + from apscheduler.executors.twisted import TwistedExecutor + return TwistedExecutor() diff --git a/lib/apscheduler/triggers/__init__.py b/lib/apscheduler/triggers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/apscheduler/triggers/base.py b/lib/apscheduler/triggers/base.py new file mode 100644 index 00000000..3520d316 --- /dev/null +++ b/lib/apscheduler/triggers/base.py @@ -0,0 +1,16 @@ +from abc import ABCMeta, abstractmethod + +import six + + +class BaseTrigger(six.with_metaclass(ABCMeta)): + """Abstract base class that defines the interface that every trigger must implement.""" + + @abstractmethod + def get_next_fire_time(self, previous_fire_time, now): + """ + Returns the next datetime to fire on, If no such datetime can be calculated, returns ``None``. + + :param datetime.datetime previous_fire_time: the previous time the trigger was fired + :param datetime.datetime now: current datetime + """ diff --git a/lib/apscheduler/triggers/cron/__init__.py b/lib/apscheduler/triggers/cron/__init__.py new file mode 100644 index 00000000..8df901ea --- /dev/null +++ b/lib/apscheduler/triggers/cron/__init__.py @@ -0,0 +1,176 @@ +from datetime import datetime, timedelta + +from tzlocal import get_localzone +import six + +from apscheduler.triggers.base import BaseTrigger +from apscheduler.triggers.cron.fields import BaseField, WeekField, DayOfMonthField, DayOfWeekField, DEFAULT_VALUES +from apscheduler.util import datetime_ceil, convert_to_datetime, datetime_repr, astimezone + + +class CronTrigger(BaseTrigger): + """ + Triggers when current time matches all specified time constraints, similarly to how the UNIX cron scheduler works. + + :param int|str year: 4-digit year + :param int|str month: month (1-12) + :param int|str day: day of the (1-31) + :param int|str week: ISO week (1-53) + :param int|str day_of_week: number or name of weekday (0-6 or mon,tue,wed,thu,fri,sat,sun) + :param int|str hour: hour (0-23) + :param int|str minute: minute (0-59) + :param int|str second: second (0-59) + :param datetime|str start_date: earliest possible date/time to trigger on (inclusive) + :param datetime|str end_date: latest possible date/time to trigger on (inclusive) + :param datetime.tzinfo|str timezone: time zone to use for the date/time calculations + (defaults to scheduler timezone) + + .. note:: The first weekday is always **monday**. + """ + + FIELD_NAMES = ('year', 'month', 'day', 'week', 'day_of_week', 'hour', 'minute', 'second') + FIELDS_MAP = { + 'year': BaseField, + 'month': BaseField, + 'week': WeekField, + 'day': DayOfMonthField, + 'day_of_week': DayOfWeekField, + 'hour': BaseField, + 'minute': BaseField, + 'second': BaseField + } + + __slots__ = 'timezone', 'start_date', 'end_date', 'fields' + + def __init__(self, year=None, month=None, day=None, week=None, day_of_week=None, hour=None, minute=None, + second=None, start_date=None, end_date=None, timezone=None): + if timezone: + self.timezone = astimezone(timezone) + elif start_date and start_date.tzinfo: + self.timezone = start_date.tzinfo + elif end_date and end_date.tzinfo: + self.timezone = end_date.tzinfo + else: + self.timezone = get_localzone() + + self.start_date = convert_to_datetime(start_date, self.timezone, 'start_date') + self.end_date = convert_to_datetime(end_date, self.timezone, 'end_date') + + values = dict((key, value) for (key, value) in six.iteritems(locals()) + if key in self.FIELD_NAMES and value is not None) + self.fields = [] + assign_defaults = False + for field_name in self.FIELD_NAMES: + if field_name in values: + exprs = values.pop(field_name) + is_default = False + assign_defaults = not values + elif assign_defaults: + exprs = DEFAULT_VALUES[field_name] + is_default = True + else: + exprs = '*' + is_default = True + + field_class = self.FIELDS_MAP[field_name] + field = field_class(field_name, exprs, is_default) + self.fields.append(field) + + def _increment_field_value(self, dateval, fieldnum): + """ + Increments the designated field and resets all less significant fields to their minimum values. + + :type dateval: datetime + :type fieldnum: int + :return: a tuple containing the new date, and the number of the field that was actually incremented + :rtype: tuple + """ + + values = {} + i = 0 + while i < len(self.fields): + field = self.fields[i] + if not field.REAL: + if i == fieldnum: + fieldnum -= 1 + i -= 1 + else: + i += 1 + continue + + if i < fieldnum: + values[field.name] = field.get_value(dateval) + i += 1 + elif i > fieldnum: + values[field.name] = field.get_min(dateval) + i += 1 + else: + value = field.get_value(dateval) + maxval = field.get_max(dateval) + if value == maxval: + fieldnum -= 1 + i -= 1 + else: + values[field.name] = value + 1 + i += 1 + + difference = datetime(**values) - dateval.replace(tzinfo=None) + return self.timezone.normalize(dateval + difference), fieldnum + + def _set_field_value(self, dateval, fieldnum, new_value): + values = {} + for i, field in enumerate(self.fields): + if field.REAL: + if i < fieldnum: + values[field.name] = field.get_value(dateval) + elif i > fieldnum: + values[field.name] = field.get_min(dateval) + else: + values[field.name] = new_value + + difference = datetime(**values) - dateval.replace(tzinfo=None) + return self.timezone.normalize(dateval + difference) + + def get_next_fire_time(self, previous_fire_time, now): + if previous_fire_time: + start_date = max(now, previous_fire_time + timedelta(microseconds=1)) + else: + start_date = max(now, self.start_date) if self.start_date else now + + fieldnum = 0 + next_date = datetime_ceil(start_date).astimezone(self.timezone) + while 0 <= fieldnum < len(self.fields): + field = self.fields[fieldnum] + curr_value = field.get_value(next_date) + next_value = field.get_next_value(next_date) + + if next_value is None: + # No valid value was found + next_date, fieldnum = self._increment_field_value(next_date, fieldnum - 1) + elif next_value > curr_value: + # A valid, but higher than the starting value, was found + if field.REAL: + next_date = self._set_field_value(next_date, fieldnum, next_value) + fieldnum += 1 + else: + next_date, fieldnum = self._increment_field_value(next_date, fieldnum) + else: + # A valid value was found, no changes necessary + fieldnum += 1 + + # Return if the date has rolled past the end date + if self.end_date and next_date > self.end_date: + return None + + if fieldnum >= 0: + return next_date + + def __str__(self): + options = ["%s='%s'" % (f.name, f) for f in self.fields if not f.is_default] + return 'cron[%s]' % (', '.join(options)) + + def __repr__(self): + options = ["%s='%s'" % (f.name, f) for f in self.fields if not f.is_default] + if self.start_date: + options.append("start_date='%s'" % datetime_repr(self.start_date)) + return '<%s (%s)>' % (self.__class__.__name__, ', '.join(options)) diff --git a/lib/apscheduler/triggers/cron/expressions.py b/lib/apscheduler/triggers/cron/expressions.py new file mode 100644 index 00000000..55272db4 --- /dev/null +++ b/lib/apscheduler/triggers/cron/expressions.py @@ -0,0 +1,188 @@ +""" +This module contains the expressions applicable for CronTrigger's fields. +""" + +from calendar import monthrange +import re + +from apscheduler.util import asint + +__all__ = ('AllExpression', 'RangeExpression', 'WeekdayRangeExpression', 'WeekdayPositionExpression', + 'LastDayOfMonthExpression') + + +WEEKDAYS = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun'] + + +class AllExpression(object): + value_re = re.compile(r'\*(?:/(?P\d+))?$') + + def __init__(self, step=None): + self.step = asint(step) + if self.step == 0: + raise ValueError('Increment must be higher than 0') + + def get_next_value(self, date, field): + start = field.get_value(date) + minval = field.get_min(date) + maxval = field.get_max(date) + start = max(start, minval) + + if not self.step: + next = start + else: + distance_to_next = (self.step - (start - minval)) % self.step + next = start + distance_to_next + + if next <= maxval: + return next + + def __str__(self): + if self.step: + return '*/%d' % self.step + return '*' + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.step) + + +class RangeExpression(AllExpression): + value_re = re.compile( + r'(?P\d+)(?:-(?P\d+))?(?:/(?P\d+))?$') + + def __init__(self, first, last=None, step=None): + AllExpression.__init__(self, step) + first = asint(first) + last = asint(last) + if last is None and step is None: + last = first + if last is not None and first > last: + raise ValueError('The minimum value in a range must not be higher than the maximum') + self.first = first + self.last = last + + def get_next_value(self, date, field): + start = field.get_value(date) + minval = field.get_min(date) + maxval = field.get_max(date) + + # Apply range limits + minval = max(minval, self.first) + if self.last is not None: + maxval = min(maxval, self.last) + start = max(start, minval) + + if not self.step: + next = start + else: + distance_to_next = (self.step - (start - minval)) % self.step + next = start + distance_to_next + + if next <= maxval: + return next + + def __str__(self): + if self.last != self.first and self.last is not None: + range = '%d-%d' % (self.first, self.last) + else: + range = str(self.first) + + if self.step: + return '%s/%d' % (range, self.step) + return range + + def __repr__(self): + args = [str(self.first)] + if self.last != self.first and self.last is not None or self.step: + args.append(str(self.last)) + if self.step: + args.append(str(self.step)) + return "%s(%s)" % (self.__class__.__name__, ', '.join(args)) + + +class WeekdayRangeExpression(RangeExpression): + value_re = re.compile(r'(?P[a-z]+)(?:-(?P[a-z]+))?', re.IGNORECASE) + + def __init__(self, first, last=None): + try: + first_num = WEEKDAYS.index(first.lower()) + except ValueError: + raise ValueError('Invalid weekday name "%s"' % first) + + if last: + try: + last_num = WEEKDAYS.index(last.lower()) + except ValueError: + raise ValueError('Invalid weekday name "%s"' % last) + else: + last_num = None + + RangeExpression.__init__(self, first_num, last_num) + + def __str__(self): + if self.last != self.first and self.last is not None: + return '%s-%s' % (WEEKDAYS[self.first], WEEKDAYS[self.last]) + return WEEKDAYS[self.first] + + def __repr__(self): + args = ["'%s'" % WEEKDAYS[self.first]] + if self.last != self.first and self.last is not None: + args.append("'%s'" % WEEKDAYS[self.last]) + return "%s(%s)" % (self.__class__.__name__, ', '.join(args)) + + +class WeekdayPositionExpression(AllExpression): + options = ['1st', '2nd', '3rd', '4th', '5th', 'last'] + value_re = re.compile(r'(?P%s) +(?P(?:\d+|\w+))' % '|'.join(options), re.IGNORECASE) + + def __init__(self, option_name, weekday_name): + try: + self.option_num = self.options.index(option_name.lower()) + except ValueError: + raise ValueError('Invalid weekday position "%s"' % option_name) + + try: + self.weekday = WEEKDAYS.index(weekday_name.lower()) + except ValueError: + raise ValueError('Invalid weekday name "%s"' % weekday_name) + + def get_next_value(self, date, field): + # Figure out the weekday of the month's first day and the number + # of days in that month + first_day_wday, last_day = monthrange(date.year, date.month) + + # Calculate which day of the month is the first of the target weekdays + first_hit_day = self.weekday - first_day_wday + 1 + if first_hit_day <= 0: + first_hit_day += 7 + + # Calculate what day of the month the target weekday would be + if self.option_num < 5: + target_day = first_hit_day + self.option_num * 7 + else: + target_day = first_hit_day + ((last_day - first_hit_day) / 7) * 7 + + if target_day <= last_day and target_day >= date.day: + return target_day + + def __str__(self): + return '%s %s' % (self.options[self.option_num], WEEKDAYS[self.weekday]) + + def __repr__(self): + return "%s('%s', '%s')" % (self.__class__.__name__, self.options[self.option_num], WEEKDAYS[self.weekday]) + + +class LastDayOfMonthExpression(AllExpression): + value_re = re.compile(r'last', re.IGNORECASE) + + def __init__(self): + pass + + def get_next_value(self, date, field): + return monthrange(date.year, date.month)[1] + + def __str__(self): + return 'last' + + def __repr__(self): + return "%s()" % self.__class__.__name__ diff --git a/lib/apscheduler/triggers/cron/fields.py b/lib/apscheduler/triggers/cron/fields.py new file mode 100644 index 00000000..e220599f --- /dev/null +++ b/lib/apscheduler/triggers/cron/fields.py @@ -0,0 +1,97 @@ +""" +Fields represent CronTrigger options which map to :class:`~datetime.datetime` +fields. +""" + +from calendar import monthrange + +from apscheduler.triggers.cron.expressions import ( + AllExpression, RangeExpression, WeekdayPositionExpression, LastDayOfMonthExpression, WeekdayRangeExpression) + + +__all__ = ('MIN_VALUES', 'MAX_VALUES', 'DEFAULT_VALUES', 'BaseField', 'WeekField', 'DayOfMonthField', 'DayOfWeekField') + + +MIN_VALUES = {'year': 1970, 'month': 1, 'day': 1, 'week': 1, 'day_of_week': 0, 'hour': 0, 'minute': 0, 'second': 0} +MAX_VALUES = {'year': 2 ** 63, 'month': 12, 'day:': 31, 'week': 53, 'day_of_week': 6, 'hour': 23, 'minute': 59, + 'second': 59} +DEFAULT_VALUES = {'year': '*', 'month': 1, 'day': 1, 'week': '*', 'day_of_week': '*', 'hour': 0, 'minute': 0, + 'second': 0} + + +class BaseField(object): + REAL = True + COMPILERS = [AllExpression, RangeExpression] + + def __init__(self, name, exprs, is_default=False): + self.name = name + self.is_default = is_default + self.compile_expressions(exprs) + + def get_min(self, dateval): + return MIN_VALUES[self.name] + + def get_max(self, dateval): + return MAX_VALUES[self.name] + + def get_value(self, dateval): + return getattr(dateval, self.name) + + def get_next_value(self, dateval): + smallest = None + for expr in self.expressions: + value = expr.get_next_value(dateval, self) + if smallest is None or (value is not None and value < smallest): + smallest = value + + return smallest + + def compile_expressions(self, exprs): + self.expressions = [] + + # Split a comma-separated expression list, if any + exprs = str(exprs).strip() + if ',' in exprs: + for expr in exprs.split(','): + self.compile_expression(expr) + else: + self.compile_expression(exprs) + + def compile_expression(self, expr): + for compiler in self.COMPILERS: + match = compiler.value_re.match(expr) + if match: + compiled_expr = compiler(**match.groupdict()) + self.expressions.append(compiled_expr) + return + + raise ValueError('Unrecognized expression "%s" for field "%s"' % (expr, self.name)) + + def __str__(self): + expr_strings = (str(e) for e in self.expressions) + return ','.join(expr_strings) + + def __repr__(self): + return "%s('%s', '%s')" % (self.__class__.__name__, self.name, self) + + +class WeekField(BaseField): + REAL = False + + def get_value(self, dateval): + return dateval.isocalendar()[1] + + +class DayOfMonthField(BaseField): + COMPILERS = BaseField.COMPILERS + [WeekdayPositionExpression, LastDayOfMonthExpression] + + def get_max(self, dateval): + return monthrange(dateval.year, dateval.month)[1] + + +class DayOfWeekField(BaseField): + REAL = False + COMPILERS = BaseField.COMPILERS + [WeekdayRangeExpression] + + def get_value(self, dateval): + return dateval.weekday() diff --git a/lib/apscheduler/triggers/date.py b/lib/apscheduler/triggers/date.py new file mode 100644 index 00000000..237e6b4e --- /dev/null +++ b/lib/apscheduler/triggers/date.py @@ -0,0 +1,30 @@ +from datetime import datetime + +from tzlocal import get_localzone + +from apscheduler.triggers.base import BaseTrigger +from apscheduler.util import convert_to_datetime, datetime_repr, astimezone + + +class DateTrigger(BaseTrigger): + """ + Triggers once on the given datetime. If ``run_date`` is left empty, current time is used. + + :param datetime|str run_date: the date/time to run the job at + :param datetime.tzinfo|str timezone: time zone for ``run_date`` if it doesn't have one already + """ + + __slots__ = 'timezone', 'run_date' + + def __init__(self, run_date=None, timezone=None): + timezone = astimezone(timezone) or get_localzone() + self.run_date = convert_to_datetime(run_date or datetime.now(), timezone, 'run_date') + + def get_next_fire_time(self, previous_fire_time, now): + return self.run_date if previous_fire_time is None else None + + def __str__(self): + return 'date[%s]' % datetime_repr(self.run_date) + + def __repr__(self): + return "<%s (run_date='%s')>" % (self.__class__.__name__, datetime_repr(self.run_date)) diff --git a/lib/apscheduler/triggers/interval.py b/lib/apscheduler/triggers/interval.py new file mode 100644 index 00000000..df9e6fe7 --- /dev/null +++ b/lib/apscheduler/triggers/interval.py @@ -0,0 +1,65 @@ +from datetime import timedelta, datetime +from math import ceil + +from tzlocal import get_localzone + +from apscheduler.triggers.base import BaseTrigger +from apscheduler.util import convert_to_datetime, timedelta_seconds, datetime_repr, astimezone + + +class IntervalTrigger(BaseTrigger): + """ + Triggers on specified intervals, starting on ``start_date`` if specified, ``datetime.now()`` + interval + otherwise. + + :param int weeks: number of weeks to wait + :param int days: number of days to wait + :param int hours: number of hours to wait + :param int minutes: number of minutes to wait + :param int seconds: number of seconds to wait + :param datetime|str start_date: starting point for the interval calculation + :param datetime|str end_date: latest possible date/time to trigger on + :param datetime.tzinfo|str timezone: time zone to use for the date/time calculations + """ + + __slots__ = 'timezone', 'start_date', 'end_date', 'interval' + + def __init__(self, weeks=0, days=0, hours=0, minutes=0, seconds=0, start_date=None, end_date=None, timezone=None): + self.interval = timedelta(weeks=weeks, days=days, hours=hours, minutes=minutes, seconds=seconds) + self.interval_length = timedelta_seconds(self.interval) + if self.interval_length == 0: + self.interval = timedelta(seconds=1) + self.interval_length = 1 + + if timezone: + self.timezone = astimezone(timezone) + elif start_date and start_date.tzinfo: + self.timezone = start_date.tzinfo + elif end_date and end_date.tzinfo: + self.timezone = end_date.tzinfo + else: + self.timezone = get_localzone() + + start_date = start_date or (datetime.now(self.timezone) + self.interval) + self.start_date = convert_to_datetime(start_date, self.timezone, 'start_date') + self.end_date = convert_to_datetime(end_date, self.timezone, 'end_date') + + def get_next_fire_time(self, previous_fire_time, now): + if previous_fire_time: + next_fire_time = previous_fire_time + self.interval + elif self.start_date > now: + next_fire_time = self.start_date + else: + timediff_seconds = timedelta_seconds(now - self.start_date) + next_interval_num = int(ceil(timediff_seconds / self.interval_length)) + next_fire_time = self.start_date + self.interval * next_interval_num + + if not self.end_date or next_fire_time <= self.end_date: + return self.timezone.normalize(next_fire_time) + + def __str__(self): + return 'interval[%s]' % str(self.interval) + + def __repr__(self): + return "<%s (interval=%r, start_date='%s')>" % (self.__class__.__name__, self.interval, + datetime_repr(self.start_date)) diff --git a/lib/apscheduler/util.py b/lib/apscheduler/util.py new file mode 100644 index 00000000..988f9427 --- /dev/null +++ b/lib/apscheduler/util.py @@ -0,0 +1,385 @@ +"""This module contains several handy functions primarily meant for internal use.""" + +from __future__ import division +from datetime import date, datetime, time, timedelta, tzinfo +from inspect import isfunction, ismethod, getargspec +from calendar import timegm +import re + +from pytz import timezone, utc +import six + +try: + from inspect import signature +except ImportError: # pragma: nocover + try: + from funcsigs import signature + except ImportError: + signature = None + +__all__ = ('asint', 'asbool', 'astimezone', 'convert_to_datetime', 'datetime_to_utc_timestamp', + 'utc_timestamp_to_datetime', 'timedelta_seconds', 'datetime_ceil', 'get_callable_name', 'obj_to_ref', + 'ref_to_obj', 'maybe_ref', 'repr_escape', 'check_callable_args') + + +class _Undefined(object): + def __nonzero__(self): + return False + + def __bool__(self): + return False + + def __repr__(self): + return '' + +undefined = _Undefined() #: a unique object that only signifies that no value is defined + + +def asint(text): + """ + Safely converts a string to an integer, returning None if the string is None. + + :type text: str + :rtype: int + """ + + if text is not None: + return int(text) + + +def asbool(obj): + """ + Interprets an object as a boolean value. + + :rtype: bool + """ + + if isinstance(obj, str): + obj = obj.strip().lower() + if obj in ('true', 'yes', 'on', 'y', 't', '1'): + return True + if obj in ('false', 'no', 'off', 'n', 'f', '0'): + return False + raise ValueError('Unable to interpret value "%s" as boolean' % obj) + return bool(obj) + + +def astimezone(obj): + """ + Interprets an object as a timezone. + + :rtype: tzinfo + """ + + if isinstance(obj, six.string_types): + return timezone(obj) + if isinstance(obj, tzinfo): + if not hasattr(obj, 'localize') or not hasattr(obj, 'normalize'): + raise TypeError('Only timezones from the pytz library are supported') + if obj.zone == 'local': + raise ValueError('Unable to determine the name of the local timezone -- use an explicit timezone instead') + return obj + if obj is not None: + raise TypeError('Expected tzinfo, got %s instead' % obj.__class__.__name__) + + +_DATE_REGEX = re.compile( + r'(?P\d{4})-(?P\d{1,2})-(?P\d{1,2})' + r'(?: (?P\d{1,2}):(?P\d{1,2}):(?P\d{1,2})' + r'(?:\.(?P\d{1,6}))?)?') + + +def convert_to_datetime(input, tz, arg_name): + """ + Converts the given object to a timezone aware datetime object. + If a timezone aware datetime object is passed, it is returned unmodified. + If a native datetime object is passed, it is given the specified timezone. + If the input is a string, it is parsed as a datetime with the given timezone. + + Date strings are accepted in three different forms: date only (Y-m-d), + date with time (Y-m-d H:M:S) or with date+time with microseconds + (Y-m-d H:M:S.micro). + + :param str|datetime input: the datetime or string to convert to a timezone aware datetime + :param datetime.tzinfo tz: timezone to interpret ``input`` in + :param str arg_name: the name of the argument (used in an error message) + :rtype: datetime + """ + + if input is None: + return + elif isinstance(input, datetime): + datetime_ = input + elif isinstance(input, date): + datetime_ = datetime.combine(input, time()) + elif isinstance(input, six.string_types): + m = _DATE_REGEX.match(input) + if not m: + raise ValueError('Invalid date string') + values = [(k, int(v or 0)) for k, v in m.groupdict().items()] + values = dict(values) + datetime_ = datetime(**values) + else: + raise TypeError('Unsupported type for %s: %s' % (arg_name, input.__class__.__name__)) + + if datetime_.tzinfo is not None: + return datetime_ + if tz is None: + raise ValueError('The "tz" argument must be specified if %s has no timezone information' % arg_name) + if isinstance(tz, six.string_types): + tz = timezone(tz) + + try: + return tz.localize(datetime_, is_dst=None) + except AttributeError: + raise TypeError('Only pytz timezones are supported (need the localize() and normalize() methods)') + + +def datetime_to_utc_timestamp(timeval): + """ + Converts a datetime instance to a timestamp. + + :type timeval: datetime + :rtype: float + """ + + if timeval is not None: + return timegm(timeval.utctimetuple()) + timeval.microsecond / 1000000 + + +def utc_timestamp_to_datetime(timestamp): + """ + Converts the given timestamp to a datetime instance. + + :type timestamp: float + :rtype: datetime + """ + + if timestamp is not None: + return datetime.fromtimestamp(timestamp, utc) + + +def timedelta_seconds(delta): + """ + Converts the given timedelta to seconds. + + :type delta: timedelta + :rtype: float + """ + + return delta.days * 24 * 60 * 60 + delta.seconds + \ + delta.microseconds / 1000000.0 + + +def datetime_ceil(dateval): + """ + Rounds the given datetime object upwards. + + :type dateval: datetime + """ + + if dateval.microsecond > 0: + return dateval + timedelta(seconds=1, microseconds=-dateval.microsecond) + return dateval + + +def datetime_repr(dateval): + return dateval.strftime('%Y-%m-%d %H:%M:%S %Z') if dateval else 'None' + + +def get_callable_name(func): + """ + Returns the best available display name for the given function/callable. + + :rtype: str + """ + + # the easy case (on Python 3.3+) + if hasattr(func, '__qualname__'): + return func.__qualname__ + + # class methods, bound and unbound methods + f_self = getattr(func, '__self__', None) or getattr(func, 'im_self', None) + if f_self and hasattr(func, '__name__'): + f_class = f_self if isinstance(f_self, type) else f_self.__class__ + else: + f_class = getattr(func, 'im_class', None) + + if f_class and hasattr(func, '__name__'): + return '%s.%s' % (f_class.__name__, func.__name__) + + # class or class instance + if hasattr(func, '__call__'): + # class + if hasattr(func, '__name__'): + return func.__name__ + + # instance of a class with a __call__ method + return func.__class__.__name__ + + raise TypeError('Unable to determine a name for %r -- maybe it is not a callable?' % func) + + +def obj_to_ref(obj): + """ + Returns the path to the given object. + + :rtype: str + """ + + try: + ref = '%s:%s' % (obj.__module__, get_callable_name(obj)) + obj2 = ref_to_obj(ref) + if obj != obj2: + raise ValueError + except Exception: + raise ValueError('Cannot determine the reference to %r' % obj) + + return ref + + +def ref_to_obj(ref): + """ + Returns the object pointed to by ``ref``. + + :type ref: str + """ + + if not isinstance(ref, six.string_types): + raise TypeError('References must be strings') + if ':' not in ref: + raise ValueError('Invalid reference') + + modulename, rest = ref.split(':', 1) + try: + obj = __import__(modulename) + except ImportError: + raise LookupError('Error resolving reference %s: could not import module' % ref) + + try: + for name in modulename.split('.')[1:] + rest.split('.'): + obj = getattr(obj, name) + return obj + except Exception: + raise LookupError('Error resolving reference %s: error looking up object' % ref) + + +def maybe_ref(ref): + """ + Returns the object that the given reference points to, if it is indeed a reference. + If it is not a reference, the object is returned as-is. + """ + + if not isinstance(ref, str): + return ref + return ref_to_obj(ref) + + +if six.PY2: + def repr_escape(string): + if isinstance(string, six.text_type): + return string.encode('ascii', 'backslashreplace') + return string +else: + repr_escape = lambda string: string + + +def check_callable_args(func, args, kwargs): + """ + Ensures that the given callable can be called with the given arguments. + + :type args: tuple + :type kwargs: dict + """ + + pos_kwargs_conflicts = [] # parameters that have a match in both args and kwargs + positional_only_kwargs = [] # positional-only parameters that have a match in kwargs + unsatisfied_args = [] # parameters in signature that don't have a match in args or kwargs + unsatisfied_kwargs = [] # keyword-only arguments that don't have a match in kwargs + unmatched_args = list(args) # args that didn't match any of the parameters in the signature + unmatched_kwargs = list(kwargs) # kwargs that didn't match any of the parameters in the signature + has_varargs = has_var_kwargs = False # indicates if the signature defines *args and **kwargs respectively + + if signature: + try: + sig = signature(func) + except ValueError: + return # signature() doesn't work against every kind of callable + + for param in six.itervalues(sig.parameters): + if param.kind == param.POSITIONAL_OR_KEYWORD: + if param.name in unmatched_kwargs and unmatched_args: + pos_kwargs_conflicts.append(param.name) + elif unmatched_args: + del unmatched_args[0] + elif param.name in unmatched_kwargs: + unmatched_kwargs.remove(param.name) + elif param.default is param.empty: + unsatisfied_args.append(param.name) + elif param.kind == param.POSITIONAL_ONLY: + if unmatched_args: + del unmatched_args[0] + elif param.name in unmatched_kwargs: + unmatched_kwargs.remove(param.name) + positional_only_kwargs.append(param.name) + elif param.default is param.empty: + unsatisfied_args.append(param.name) + elif param.kind == param.KEYWORD_ONLY: + if param.name in unmatched_kwargs: + unmatched_kwargs.remove(param.name) + elif param.default is param.empty: + unsatisfied_kwargs.append(param.name) + elif param.kind == param.VAR_POSITIONAL: + has_varargs = True + elif param.kind == param.VAR_KEYWORD: + has_var_kwargs = True + else: + if not isfunction(func) and not ismethod(func) and hasattr(func, '__call__'): + func = func.__call__ + + try: + argspec = getargspec(func) + except TypeError: + return # getargspec() doesn't work certain callables + + argspec_args = argspec.args if not ismethod(func) else argspec.args[1:] + has_varargs = bool(argspec.varargs) + has_var_kwargs = bool(argspec.keywords) + for arg, default in six.moves.zip_longest(argspec_args, argspec.defaults or (), fillvalue=undefined): + if arg in unmatched_kwargs and unmatched_args: + pos_kwargs_conflicts.append(arg) + elif unmatched_args: + del unmatched_args[0] + elif arg in unmatched_kwargs: + unmatched_kwargs.remove(arg) + elif default is undefined: + unsatisfied_args.append(arg) + + # Make sure there are no conflicts between args and kwargs + if pos_kwargs_conflicts: + raise ValueError('The following arguments are supplied in both args and kwargs: %s' % + ', '.join(pos_kwargs_conflicts)) + + # Check if keyword arguments are being fed to positional-only parameters + if positional_only_kwargs: + raise ValueError('The following arguments cannot be given as keyword arguments: %s' % + ', '.join(positional_only_kwargs)) + + # Check that the number of positional arguments minus the number of matched kwargs matches the argspec + if unsatisfied_args: + raise ValueError('The following arguments have not been supplied: %s' % ', '.join(unsatisfied_args)) + + # Check that all keyword-only arguments have been supplied + if unsatisfied_kwargs: + raise ValueError('The following keyword-only arguments have not been supplied in kwargs: %s' % + ', '.join(unsatisfied_kwargs)) + + # Check that the callable can accept the given number of positional arguments + if not has_varargs and unmatched_args: + raise ValueError('The list of positional arguments is longer than the target callable can handle ' + '(allowed: %d, given in args: %d)' % (len(args) - len(unmatched_args), len(args))) + + # Check that the callable can accept the given keyword arguments + if not has_var_kwargs and unmatched_kwargs: + raise ValueError('The target callable does not accept the following keyword arguments: %s' % + ', '.join(unmatched_kwargs)) diff --git a/lib/argparse.py b/lib/argparse.py new file mode 100644 index 00000000..f0cfe27e --- /dev/null +++ b/lib/argparse.py @@ -0,0 +1,2386 @@ +# Author: Steven J. Bethard . + +"""Command-line parsing library + +This module is an optparse-inspired command-line parsing library that: + + - handles both optional and positional arguments + - produces highly informative usage messages + - supports parsers that dispatch to sub-parsers + +The following is a simple usage example that sums integers from the +command-line and writes the result to a file:: + + parser = argparse.ArgumentParser( + description='sum the integers at the command line') + parser.add_argument( + 'integers', metavar='int', nargs='+', type=int, + help='an integer to be summed') + parser.add_argument( + '--log', default=sys.stdout, type=argparse.FileType('w'), + help='the file where the sum should be written') + args = parser.parse_args() + args.log.write('%s' % sum(args.integers)) + args.log.close() + +The module contains the following public classes: + + - ArgumentParser -- The main entry point for command-line parsing. As the + example above shows, the add_argument() method is used to populate + the parser with actions for optional and positional arguments. Then + the parse_args() method is invoked to convert the args at the + command-line into an object with attributes. + + - ArgumentError -- The exception raised by ArgumentParser objects when + there are errors with the parser's actions. Errors raised while + parsing the command-line are caught by ArgumentParser and emitted + as command-line messages. + + - FileType -- A factory for defining types of files to be created. As the + example above shows, instances of FileType are typically passed as + the type= argument of add_argument() calls. + + - Action -- The base class for parser actions. Typically actions are + selected by passing strings like 'store_true' or 'append_const' to + the action= argument of add_argument(). However, for greater + customization of ArgumentParser actions, subclasses of Action may + be defined and passed as the action= argument. + + - HelpFormatter, RawDescriptionHelpFormatter, RawTextHelpFormatter, + ArgumentDefaultsHelpFormatter -- Formatter classes which + may be passed as the formatter_class= argument to the + ArgumentParser constructor. HelpFormatter is the default, + RawDescriptionHelpFormatter and RawTextHelpFormatter tell the parser + not to change the formatting for help text, and + ArgumentDefaultsHelpFormatter adds information about argument defaults + to the help. + +All other classes in this module are considered implementation details. +(Also note that HelpFormatter and RawDescriptionHelpFormatter are only +considered public as object names -- the API of the formatter objects is +still considered an implementation detail.) +""" + +__version__ = '1.1' +__all__ = [ + 'ArgumentParser', + 'ArgumentError', + 'ArgumentTypeError', + 'FileType', + 'HelpFormatter', + 'ArgumentDefaultsHelpFormatter', + 'RawDescriptionHelpFormatter', + 'RawTextHelpFormatter', + 'MetavarTypeHelpFormatter', + 'Namespace', + 'Action', + 'ONE_OR_MORE', + 'OPTIONAL', + 'PARSER', + 'REMAINDER', + 'SUPPRESS', + 'ZERO_OR_MORE', +] + + +import collections as _collections +import copy as _copy +import os as _os +import re as _re +import sys as _sys +import textwrap as _textwrap + +from gettext import gettext as _, ngettext + + +def _callable(obj): + return hasattr(obj, '__call__') or hasattr(obj, '__bases__') + + +SUPPRESS = '==SUPPRESS==' + +OPTIONAL = '?' +ZERO_OR_MORE = '*' +ONE_OR_MORE = '+' +PARSER = 'A...' +REMAINDER = '...' +_UNRECOGNIZED_ARGS_ATTR = '_unrecognized_args' + +# ============================= +# Utility functions and classes +# ============================= + +class _AttributeHolder(object): + """Abstract base class that provides __repr__. + + The __repr__ method returns a string in the format:: + ClassName(attr=name, attr=name, ...) + The attributes are determined either by a class-level attribute, + '_kwarg_names', or by inspecting the instance __dict__. + """ + + def __repr__(self): + type_name = type(self).__name__ + arg_strings = [] + for arg in self._get_args(): + arg_strings.append(repr(arg)) + for name, value in self._get_kwargs(): + arg_strings.append('%s=%r' % (name, value)) + return '%s(%s)' % (type_name, ', '.join(arg_strings)) + + def _get_kwargs(self): + return sorted(self.__dict__.items()) + + def _get_args(self): + return [] + + +def _ensure_value(namespace, name, value): + if getattr(namespace, name, None) is None: + setattr(namespace, name, value) + return getattr(namespace, name) + + +# =============== +# Formatting Help +# =============== + +class HelpFormatter(object): + """Formatter for generating usage messages and argument help strings. + + Only the name of this class is considered a public API. All the methods + provided by the class are considered an implementation detail. + """ + + def __init__(self, + prog, + indent_increment=2, + max_help_position=24, + width=None): + + # default setting for width + if width is None: + try: + width = int(_os.environ['COLUMNS']) + except (KeyError, ValueError): + width = 80 + width -= 2 + + self._prog = prog + self._indent_increment = indent_increment + self._max_help_position = max_help_position + self._width = width + + self._current_indent = 0 + self._level = 0 + self._action_max_length = 0 + + self._root_section = self._Section(self, None) + self._current_section = self._root_section + + self._whitespace_matcher = _re.compile(r'\s+') + self._long_break_matcher = _re.compile(r'\n\n\n+') + + # =============================== + # Section and indentation methods + # =============================== + def _indent(self): + self._current_indent += self._indent_increment + self._level += 1 + + def _dedent(self): + self._current_indent -= self._indent_increment + assert self._current_indent >= 0, 'Indent decreased below 0.' + self._level -= 1 + + class _Section(object): + + def __init__(self, formatter, parent, heading=None): + self.formatter = formatter + self.parent = parent + self.heading = heading + self.items = [] + + def format_help(self): + # format the indented section + if self.parent is not None: + self.formatter._indent() + join = self.formatter._join_parts + for func, args in self.items: + func(*args) + item_help = join([func(*args) for func, args in self.items]) + if self.parent is not None: + self.formatter._dedent() + + # return nothing if the section was empty + if not item_help: + return '' + + # add the heading if the section was non-empty + if self.heading is not SUPPRESS and self.heading is not None: + current_indent = self.formatter._current_indent + heading = '%*s%s:\n' % (current_indent, '', self.heading) + else: + heading = '' + + # join the section-initial newline, the heading and the help + return join(['\n', heading, item_help, '\n']) + + def _add_item(self, func, args): + self._current_section.items.append((func, args)) + + # ======================== + # Message building methods + # ======================== + def start_section(self, heading): + self._indent() + section = self._Section(self, self._current_section, heading) + self._add_item(section.format_help, []) + self._current_section = section + + def end_section(self): + self._current_section = self._current_section.parent + self._dedent() + + def add_text(self, text): + if text is not SUPPRESS and text is not None: + self._add_item(self._format_text, [text]) + + def add_usage(self, usage, actions, groups, prefix=None): + if usage is not SUPPRESS: + args = usage, actions, groups, prefix + self._add_item(self._format_usage, args) + + def add_argument(self, action): + if action.help is not SUPPRESS: + + # find all invocations + get_invocation = self._format_action_invocation + invocations = [get_invocation(action)] + for subaction in self._iter_indented_subactions(action): + invocations.append(get_invocation(subaction)) + + # update the maximum item length + invocation_length = max([len(s) for s in invocations]) + action_length = invocation_length + self._current_indent + self._action_max_length = max(self._action_max_length, + action_length) + + # add the item to the list + self._add_item(self._format_action, [action]) + + def add_arguments(self, actions): + for action in actions: + self.add_argument(action) + + # ======================= + # Help-formatting methods + # ======================= + def format_help(self): + help = self._root_section.format_help() + if help: + help = self._long_break_matcher.sub('\n\n', help) + help = help.strip('\n') + '\n' + return help + + def _join_parts(self, part_strings): + return ''.join([part + for part in part_strings + if part and part is not SUPPRESS]) + + def _format_usage(self, usage, actions, groups, prefix): + if prefix is None: + prefix = _('usage: ') + + # if usage is specified, use that + if usage is not None: + usage = usage % dict(prog=self._prog) + + # if no optionals or positionals are available, usage is just prog + elif usage is None and not actions: + usage = '%(prog)s' % dict(prog=self._prog) + + # if optionals and positionals are available, calculate usage + elif usage is None: + prog = '%(prog)s' % dict(prog=self._prog) + + # split optionals from positionals + optionals = [] + positionals = [] + for action in actions: + if action.option_strings: + optionals.append(action) + else: + positionals.append(action) + + # build full usage string + format = self._format_actions_usage + action_usage = format(optionals + positionals, groups) + usage = ' '.join([s for s in [prog, action_usage] if s]) + + # wrap the usage parts if it's too long + text_width = self._width - self._current_indent + if len(prefix) + len(usage) > text_width: + + # break usage into wrappable parts + part_regexp = r'\(.*?\)+|\[.*?\]+|\S+' + opt_usage = format(optionals, groups) + pos_usage = format(positionals, groups) + opt_parts = _re.findall(part_regexp, opt_usage) + pos_parts = _re.findall(part_regexp, pos_usage) + assert ' '.join(opt_parts) == opt_usage + assert ' '.join(pos_parts) == pos_usage + + # helper for wrapping lines + def get_lines(parts, indent, prefix=None): + lines = [] + line = [] + if prefix is not None: + line_len = len(prefix) - 1 + else: + line_len = len(indent) - 1 + for part in parts: + if line_len + 1 + len(part) > text_width: + lines.append(indent + ' '.join(line)) + line = [] + line_len = len(indent) - 1 + line.append(part) + line_len += len(part) + 1 + if line: + lines.append(indent + ' '.join(line)) + if prefix is not None: + lines[0] = lines[0][len(indent):] + return lines + + # if prog is short, follow it with optionals or positionals + if len(prefix) + len(prog) <= 0.75 * text_width: + indent = ' ' * (len(prefix) + len(prog) + 1) + if opt_parts: + lines = get_lines([prog] + opt_parts, indent, prefix) + lines.extend(get_lines(pos_parts, indent)) + elif pos_parts: + lines = get_lines([prog] + pos_parts, indent, prefix) + else: + lines = [prog] + + # if prog is long, put it on its own line + else: + indent = ' ' * len(prefix) + parts = opt_parts + pos_parts + lines = get_lines(parts, indent) + if len(lines) > 1: + lines = [] + lines.extend(get_lines(opt_parts, indent)) + lines.extend(get_lines(pos_parts, indent)) + lines = [prog] + lines + + # join lines into usage + usage = '\n'.join(lines) + + # prefix with 'usage:' + return '%s%s\n\n' % (prefix, usage) + + def _format_actions_usage(self, actions, groups): + # find group indices and identify actions in groups + group_actions = set() + inserts = {} + for group in groups: + try: + start = actions.index(group._group_actions[0]) + except ValueError: + continue + else: + end = start + len(group._group_actions) + if actions[start:end] == group._group_actions: + for action in group._group_actions: + group_actions.add(action) + if not group.required: + if start in inserts: + inserts[start] += ' [' + else: + inserts[start] = '[' + inserts[end] = ']' + else: + if start in inserts: + inserts[start] += ' (' + else: + inserts[start] = '(' + inserts[end] = ')' + for i in range(start + 1, end): + inserts[i] = '|' + + # collect all actions format strings + parts = [] + for i, action in enumerate(actions): + + # suppressed arguments are marked with None + # remove | separators for suppressed arguments + if action.help is SUPPRESS: + parts.append(None) + if inserts.get(i) == '|': + inserts.pop(i) + elif inserts.get(i + 1) == '|': + inserts.pop(i + 1) + + # produce all arg strings + elif not action.option_strings: + default = self._get_default_metavar_for_positional(action) + part = self._format_args(action, default) + + # if it's in a group, strip the outer [] + if action in group_actions: + if part[0] == '[' and part[-1] == ']': + part = part[1:-1] + + # add the action string to the list + parts.append(part) + + # produce the first way to invoke the option in brackets + else: + option_string = action.option_strings[0] + + # if the Optional doesn't take a value, format is: + # -s or --long + if action.nargs == 0: + part = '%s' % option_string + + # if the Optional takes a value, format is: + # -s ARGS or --long ARGS + else: + default = self._get_default_metavar_for_optional(action) + args_string = self._format_args(action, default) + part = '%s %s' % (option_string, args_string) + + # make it look optional if it's not required or in a group + if not action.required and action not in group_actions: + part = '[%s]' % part + + # add the action string to the list + parts.append(part) + + # insert things at the necessary indices + for i in sorted(inserts, reverse=True): + parts[i:i] = [inserts[i]] + + # join all the action items with spaces + text = ' '.join([item for item in parts if item is not None]) + + # clean up separators for mutually exclusive groups + open = r'[\[(]' + close = r'[\])]' + text = _re.sub(r'(%s) ' % open, r'\1', text) + text = _re.sub(r' (%s)' % close, r'\1', text) + text = _re.sub(r'%s *%s' % (open, close), r'', text) + text = _re.sub(r'\(([^|]*)\)', r'\1', text) + text = text.strip() + + # return the text + return text + + def _format_text(self, text): + if '%(prog)' in text: + text = text % dict(prog=self._prog) + text_width = self._width - self._current_indent + indent = ' ' * self._current_indent + return self._fill_text(text, text_width, indent) + '\n\n' + + def _format_action(self, action): + # determine the required width and the entry label + help_position = min(self._action_max_length + 2, + self._max_help_position) + help_width = self._width - help_position + action_width = help_position - self._current_indent - 2 + action_header = self._format_action_invocation(action) + + # ho nelp; start on same line and add a final newline + if not action.help: + tup = self._current_indent, '', action_header + action_header = '%*s%s\n' % tup + + # short action name; start on the same line and pad two spaces + elif len(action_header) <= action_width: + tup = self._current_indent, '', action_width, action_header + action_header = '%*s%-*s ' % tup + indent_first = 0 + + # long action name; start on the next line + else: + tup = self._current_indent, '', action_header + action_header = '%*s%s\n' % tup + indent_first = help_position + + # collect the pieces of the action help + parts = [action_header] + + # if there was help for the action, add lines of help text + if action.help: + help_text = self._expand_help(action) + help_lines = self._split_lines(help_text, help_width) + parts.append('%*s%s\n' % (indent_first, '', help_lines[0])) + for line in help_lines[1:]: + parts.append('%*s%s\n' % (help_position, '', line)) + + # or add a newline if the description doesn't end with one + elif not action_header.endswith('\n'): + parts.append('\n') + + # if there are any sub-actions, add their help as well + for subaction in self._iter_indented_subactions(action): + parts.append(self._format_action(subaction)) + + # return a single string + return self._join_parts(parts) + + def _format_action_invocation(self, action): + if not action.option_strings: + default = self._get_default_metavar_for_positional(action) + metavar, = self._metavar_formatter(action, default)(1) + return metavar + + else: + parts = [] + + # if the Optional doesn't take a value, format is: + # -s, --long + if action.nargs == 0: + parts.extend(action.option_strings) + + # if the Optional takes a value, format is: + # -s ARGS, --long ARGS + else: + default = self._get_default_metavar_for_optional(action) + args_string = self._format_args(action, default) + for option_string in action.option_strings: + parts.append('%s %s' % (option_string, args_string)) + + return ', '.join(parts) + + def _metavar_formatter(self, action, default_metavar): + if action.metavar is not None: + result = action.metavar + elif action.choices is not None: + choice_strs = [str(choice) for choice in action.choices] + result = '{%s}' % ','.join(choice_strs) + else: + result = default_metavar + + def format(tuple_size): + if isinstance(result, tuple): + return result + else: + return (result, ) * tuple_size + return format + + def _format_args(self, action, default_metavar): + get_metavar = self._metavar_formatter(action, default_metavar) + if action.nargs is None: + result = '%s' % get_metavar(1) + elif action.nargs == OPTIONAL: + result = '[%s]' % get_metavar(1) + elif action.nargs == ZERO_OR_MORE: + result = '[%s [%s ...]]' % get_metavar(2) + elif action.nargs == ONE_OR_MORE: + result = '%s [%s ...]' % get_metavar(2) + elif action.nargs == REMAINDER: + result = '...' + elif action.nargs == PARSER: + result = '%s ...' % get_metavar(1) + else: + formats = ['%s' for _ in range(action.nargs)] + result = ' '.join(formats) % get_metavar(action.nargs) + return result + + def _expand_help(self, action): + params = dict(vars(action), prog=self._prog) + for name in list(params): + if params[name] is SUPPRESS: + del params[name] + for name in list(params): + if hasattr(params[name], '__name__'): + params[name] = params[name].__name__ + if params.get('choices') is not None: + choices_str = ', '.join([str(c) for c in params['choices']]) + params['choices'] = choices_str + return self._get_help_string(action) % params + + def _iter_indented_subactions(self, action): + try: + get_subactions = action._get_subactions + except AttributeError: + pass + else: + self._indent() + for subaction in get_subactions(): + yield subaction + self._dedent() + + def _split_lines(self, text, width): + text = self._whitespace_matcher.sub(' ', text).strip() + return _textwrap.wrap(text, width) + + def _fill_text(self, text, width, indent): + text = self._whitespace_matcher.sub(' ', text).strip() + return _textwrap.fill(text, width, initial_indent=indent, + subsequent_indent=indent) + + def _get_help_string(self, action): + return action.help + + def _get_default_metavar_for_optional(self, action): + return action.dest.upper() + + def _get_default_metavar_for_positional(self, action): + return action.dest + + +class RawDescriptionHelpFormatter(HelpFormatter): + """Help message formatter which retains any formatting in descriptions. + + Only the name of this class is considered a public API. All the methods + provided by the class are considered an implementation detail. + """ + + def _fill_text(self, text, width, indent): + return ''.join([indent + line for line in text.splitlines(True)]) + + +class RawTextHelpFormatter(RawDescriptionHelpFormatter): + """Help message formatter which retains formatting of all help text. + + Only the name of this class is considered a public API. All the methods + provided by the class are considered an implementation detail. + """ + + def _split_lines(self, text, width): + return text.splitlines() + + +class ArgumentDefaultsHelpFormatter(HelpFormatter): + """Help message formatter which adds default values to argument help. + + Only the name of this class is considered a public API. All the methods + provided by the class are considered an implementation detail. + """ + + def _get_help_string(self, action): + help = action.help + if '%(default)' not in action.help: + if action.default is not SUPPRESS: + defaulting_nargs = [OPTIONAL, ZERO_OR_MORE] + if action.option_strings or action.nargs in defaulting_nargs: + help += ' (default: %(default)s)' + return help + + +class MetavarTypeHelpFormatter(HelpFormatter): + """Help message formatter which uses the argument 'type' as the default + metavar value (instead of the argument 'dest') + + Only the name of this class is considered a public API. All the methods + provided by the class are considered an implementation detail. + """ + + def _get_default_metavar_for_optional(self, action): + return action.type.__name__ + + def _get_default_metavar_for_positional(self, action): + return action.type.__name__ + + + +# ===================== +# Options and Arguments +# ===================== + +def _get_action_name(argument): + if argument is None: + return None + elif argument.option_strings: + return '/'.join(argument.option_strings) + elif argument.metavar not in (None, SUPPRESS): + return argument.metavar + elif argument.dest not in (None, SUPPRESS): + return argument.dest + else: + return None + + +class ArgumentError(Exception): + """An error from creating or using an argument (optional or positional). + + The string value of this exception is the message, augmented with + information about the argument that caused it. + """ + + def __init__(self, argument, message): + self.argument_name = _get_action_name(argument) + self.message = message + + def __str__(self): + if self.argument_name is None: + format = '%(message)s' + else: + format = 'argument %(argument_name)s: %(message)s' + return format % dict(message=self.message, + argument_name=self.argument_name) + + +class ArgumentTypeError(Exception): + """An error from trying to convert a command line string to a type.""" + pass + + +# ============== +# Action classes +# ============== + +class Action(_AttributeHolder): + """Information about how to convert command line strings to Python objects. + + Action objects are used by an ArgumentParser to represent the information + needed to parse a single argument from one or more strings from the + command line. The keyword arguments to the Action constructor are also + all attributes of Action instances. + + Keyword Arguments: + + - option_strings -- A list of command-line option strings which + should be associated with this action. + + - dest -- The name of the attribute to hold the created object(s) + + - nargs -- The number of command-line arguments that should be + consumed. By default, one argument will be consumed and a single + value will be produced. Other values include: + - N (an integer) consumes N arguments (and produces a list) + - '?' consumes zero or one arguments + - '*' consumes zero or more arguments (and produces a list) + - '+' consumes one or more arguments (and produces a list) + Note that the difference between the default and nargs=1 is that + with the default, a single value will be produced, while with + nargs=1, a list containing a single value will be produced. + + - const -- The value to be produced if the option is specified and the + option uses an action that takes no values. + + - default -- The value to be produced if the option is not specified. + + - type -- The type which the command-line arguments should be converted + to, should be one of 'string', 'int', 'float', 'complex' or a + callable object that accepts a single string argument. If None, + 'string' is assumed. + + - choices -- A container of values that should be allowed. If not None, + after a command-line argument has been converted to the appropriate + type, an exception will be raised if it is not a member of this + collection. + + - required -- True if the action must always be specified at the + command line. This is only meaningful for optional command-line + arguments. + + - help -- The help string describing the argument. + + - metavar -- The name to be used for the option's argument with the + help string. If None, the 'dest' value will be used as the name. + """ + + def __init__(self, + option_strings, + dest, + nargs=None, + const=None, + default=None, + type=None, + choices=None, + required=False, + help=None, + metavar=None): + self.option_strings = option_strings + self.dest = dest + self.nargs = nargs + self.const = const + self.default = default + self.type = type + self.choices = choices + self.required = required + self.help = help + self.metavar = metavar + + def _get_kwargs(self): + names = [ + 'option_strings', + 'dest', + 'nargs', + 'const', + 'default', + 'type', + 'choices', + 'help', + 'metavar', + ] + return [(name, getattr(self, name)) for name in names] + + def __call__(self, parser, namespace, values, option_string=None): + raise NotImplementedError(_('.__call__() not defined')) + + +class _StoreAction(Action): + + def __init__(self, + option_strings, + dest, + nargs=None, + const=None, + default=None, + type=None, + choices=None, + required=False, + help=None, + metavar=None): + if nargs == 0: + raise ValueError('nargs for store actions must be > 0; if you ' + 'have nothing to store, actions such as store ' + 'true or store const may be more appropriate') + if const is not None and nargs != OPTIONAL: + raise ValueError('nargs must be %r to supply const' % OPTIONAL) + super(_StoreAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=nargs, + const=const, + default=default, + type=type, + choices=choices, + required=required, + help=help, + metavar=metavar) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + + +class _StoreConstAction(Action): + + def __init__(self, + option_strings, + dest, + const, + default=None, + required=False, + help=None, + metavar=None): + super(_StoreConstAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=const, + default=default, + required=required, + help=help) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, self.const) + + +class _StoreTrueAction(_StoreConstAction): + + def __init__(self, + option_strings, + dest, + default=False, + required=False, + help=None): + super(_StoreTrueAction, self).__init__( + option_strings=option_strings, + dest=dest, + const=True, + default=default, + required=required, + help=help) + + +class _StoreFalseAction(_StoreConstAction): + + def __init__(self, + option_strings, + dest, + default=True, + required=False, + help=None): + super(_StoreFalseAction, self).__init__( + option_strings=option_strings, + dest=dest, + const=False, + default=default, + required=required, + help=help) + + +class _AppendAction(Action): + + def __init__(self, + option_strings, + dest, + nargs=None, + const=None, + default=None, + type=None, + choices=None, + required=False, + help=None, + metavar=None): + if nargs == 0: + raise ValueError('nargs for append actions must be > 0; if arg ' + 'strings are not supplying the value to append, ' + 'the append const action may be more appropriate') + if const is not None and nargs != OPTIONAL: + raise ValueError('nargs must be %r to supply const' % OPTIONAL) + super(_AppendAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=nargs, + const=const, + default=default, + type=type, + choices=choices, + required=required, + help=help, + metavar=metavar) + + def __call__(self, parser, namespace, values, option_string=None): + items = _copy.copy(_ensure_value(namespace, self.dest, [])) + items.append(values) + setattr(namespace, self.dest, items) + + +class _AppendConstAction(Action): + + def __init__(self, + option_strings, + dest, + const, + default=None, + required=False, + help=None, + metavar=None): + super(_AppendConstAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=const, + default=default, + required=required, + help=help, + metavar=metavar) + + def __call__(self, parser, namespace, values, option_string=None): + items = _copy.copy(_ensure_value(namespace, self.dest, [])) + items.append(self.const) + setattr(namespace, self.dest, items) + + +class _CountAction(Action): + + def __init__(self, + option_strings, + dest, + default=None, + required=False, + help=None): + super(_CountAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + default=default, + required=required, + help=help) + + def __call__(self, parser, namespace, values, option_string=None): + new_count = _ensure_value(namespace, self.dest, 0) + 1 + setattr(namespace, self.dest, new_count) + + +class _HelpAction(Action): + + def __init__(self, + option_strings, + dest=SUPPRESS, + default=SUPPRESS, + help=None): + super(_HelpAction, self).__init__( + option_strings=option_strings, + dest=dest, + default=default, + nargs=0, + help=help) + + def __call__(self, parser, namespace, values, option_string=None): + parser.print_help() + parser.exit() + + +class _VersionAction(Action): + + def __init__(self, + option_strings, + version=None, + dest=SUPPRESS, + default=SUPPRESS, + help="show program's version number and exit"): + super(_VersionAction, self).__init__( + option_strings=option_strings, + dest=dest, + default=default, + nargs=0, + help=help) + self.version = version + + def __call__(self, parser, namespace, values, option_string=None): + version = self.version + if version is None: + version = parser.version + formatter = parser._get_formatter() + formatter.add_text(version) + parser.exit(message=formatter.format_help()) + + +class _SubParsersAction(Action): + + class _ChoicesPseudoAction(Action): + + def __init__(self, name, aliases, help): + metavar = dest = name + if aliases: + metavar += ' (%s)' % ', '.join(aliases) + sup = super(_SubParsersAction._ChoicesPseudoAction, self) + sup.__init__(option_strings=[], dest=dest, help=help, + metavar=metavar) + + def __init__(self, + option_strings, + prog, + parser_class, + dest=SUPPRESS, + help=None, + metavar=None): + + self._prog_prefix = prog + self._parser_class = parser_class + self._name_parser_map = _collections.OrderedDict() + self._choices_actions = [] + + super(_SubParsersAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=PARSER, + choices=self._name_parser_map, + help=help, + metavar=metavar) + + def add_parser(self, name, **kwargs): + # set prog from the existing prefix + if kwargs.get('prog') is None: + kwargs['prog'] = '%s %s' % (self._prog_prefix, name) + + aliases = kwargs.pop('aliases', ()) + + # create a pseudo-action to hold the choice help + if 'help' in kwargs: + help = kwargs.pop('help') + choice_action = self._ChoicesPseudoAction(name, aliases, help) + self._choices_actions.append(choice_action) + + # create the parser and add it to the map + parser = self._parser_class(**kwargs) + self._name_parser_map[name] = parser + + # make parser available under aliases also + for alias in aliases: + self._name_parser_map[alias] = parser + + return parser + + def _get_subactions(self): + return self._choices_actions + + def __call__(self, parser, namespace, values, option_string=None): + parser_name = values[0] + arg_strings = values[1:] + + # set the parser name if requested + if self.dest is not SUPPRESS: + setattr(namespace, self.dest, parser_name) + + # select the parser + try: + parser = self._name_parser_map[parser_name] + except KeyError: + args = {'parser_name': parser_name, + 'choices': ', '.join(self._name_parser_map)} + msg = _('unknown parser %(parser_name)r (choices: %(choices)s)') % args + raise ArgumentError(self, msg) + + # parse all the remaining options into the namespace + # store any unrecognized options on the object, so that the top + # level parser can decide what to do with them + namespace, arg_strings = parser.parse_known_args(arg_strings, namespace) + if arg_strings: + vars(namespace).setdefault(_UNRECOGNIZED_ARGS_ATTR, []) + getattr(namespace, _UNRECOGNIZED_ARGS_ATTR).extend(arg_strings) + + +# ============== +# Type classes +# ============== + +class FileType(object): + """Factory for creating file object types + + Instances of FileType are typically passed as type= arguments to the + ArgumentParser add_argument() method. + + Keyword Arguments: + - mode -- A string indicating how the file is to be opened. Accepts the + same values as the builtin open() function. + - bufsize -- The file's desired buffer size. Accepts the same values as + the builtin open() function. + """ + + def __init__(self, mode='r', bufsize=-1): + self._mode = mode + self._bufsize = bufsize + + def __call__(self, string): + # the special argument "-" means sys.std{in,out} + if string == '-': + if 'r' in self._mode: + return _sys.stdin + elif 'w' in self._mode: + return _sys.stdout + else: + msg = _('argument "-" with mode %r') % self._mode + raise ValueError(msg) + + # all other arguments are used as file names + try: + return open(string, self._mode, self._bufsize) + except IOError as e: + message = _("can't open '%s': %s") + raise ArgumentTypeError(message % (string, e)) + + def __repr__(self): + args = self._mode, self._bufsize + args_str = ', '.join(repr(arg) for arg in args if arg != -1) + return '%s(%s)' % (type(self).__name__, args_str) + +# =========================== +# Optional and Positional Parsing +# =========================== + +class Namespace(_AttributeHolder): + """Simple object for storing attributes. + + Implements equality by attribute names and values, and provides a simple + string representation. + """ + + def __init__(self, **kwargs): + for name in kwargs: + setattr(self, name, kwargs[name]) + + def __eq__(self, other): + return vars(self) == vars(other) + + def __ne__(self, other): + return not (self == other) + + def __contains__(self, key): + return key in self.__dict__ + + +class _ActionsContainer(object): + + def __init__(self, + description, + prefix_chars, + argument_default, + conflict_handler): + super(_ActionsContainer, self).__init__() + + self.description = description + self.argument_default = argument_default + self.prefix_chars = prefix_chars + self.conflict_handler = conflict_handler + + # set up registries + self._registries = {} + + # register actions + self.register('action', None, _StoreAction) + self.register('action', 'store', _StoreAction) + self.register('action', 'store_const', _StoreConstAction) + self.register('action', 'store_true', _StoreTrueAction) + self.register('action', 'store_false', _StoreFalseAction) + self.register('action', 'append', _AppendAction) + self.register('action', 'append_const', _AppendConstAction) + self.register('action', 'count', _CountAction) + self.register('action', 'help', _HelpAction) + self.register('action', 'version', _VersionAction) + self.register('action', 'parsers', _SubParsersAction) + + # raise an exception if the conflict handler is invalid + self._get_handler() + + # action storage + self._actions = [] + self._option_string_actions = {} + + # groups + self._action_groups = [] + self._mutually_exclusive_groups = [] + + # defaults storage + self._defaults = {} + + # determines whether an "option" looks like a negative number + self._negative_number_matcher = _re.compile(r'^-\d+$|^-\d*\.\d+$') + + # whether or not there are any optionals that look like negative + # numbers -- uses a list so it can be shared and edited + self._has_negative_number_optionals = [] + + # ==================== + # Registration methods + # ==================== + def register(self, registry_name, value, object): + registry = self._registries.setdefault(registry_name, {}) + registry[value] = object + + def _registry_get(self, registry_name, value, default=None): + return self._registries[registry_name].get(value, default) + + # ================================== + # Namespace default accessor methods + # ================================== + def set_defaults(self, **kwargs): + self._defaults.update(kwargs) + + # if these defaults match any existing arguments, replace + # the previous default on the object with the new one + for action in self._actions: + if action.dest in kwargs: + action.default = kwargs[action.dest] + + def get_default(self, dest): + for action in self._actions: + if action.dest == dest and action.default is not None: + return action.default + return self._defaults.get(dest, None) + + + # ======================= + # Adding argument actions + # ======================= + def add_argument(self, *args, **kwargs): + """ + add_argument(dest, ..., name=value, ...) + add_argument(option_string, option_string, ..., name=value, ...) + """ + + # if no positional args are supplied or only one is supplied and + # it doesn't look like an option string, parse a positional + # argument + chars = self.prefix_chars + if not args or len(args) == 1 and args[0][0] not in chars: + if args and 'dest' in kwargs: + raise ValueError('dest supplied twice for positional argument') + kwargs = self._get_positional_kwargs(*args, **kwargs) + + # otherwise, we're adding an optional argument + else: + kwargs = self._get_optional_kwargs(*args, **kwargs) + + # if no default was supplied, use the parser-level default + if 'default' not in kwargs: + dest = kwargs['dest'] + if dest in self._defaults: + kwargs['default'] = self._defaults[dest] + elif self.argument_default is not None: + kwargs['default'] = self.argument_default + + # create the action object, and add it to the parser + action_class = self._pop_action_class(kwargs) + if not _callable(action_class): + raise ValueError('unknown action "%s"' % (action_class,)) + action = action_class(**kwargs) + + # raise an error if the action type is not callable + type_func = self._registry_get('type', action.type, action.type) + if not _callable(type_func): + raise ValueError('%r is not callable' % (type_func,)) + + # raise an error if the metavar does not match the type + if hasattr(self, "_get_formatter"): + try: + self._get_formatter()._format_args(action, None) + except TypeError: + raise ValueError("length of metavar tuple does not match nargs") + + return self._add_action(action) + + def add_argument_group(self, *args, **kwargs): + group = _ArgumentGroup(self, *args, **kwargs) + self._action_groups.append(group) + return group + + def add_mutually_exclusive_group(self, **kwargs): + group = _MutuallyExclusiveGroup(self, **kwargs) + self._mutually_exclusive_groups.append(group) + return group + + def _add_action(self, action): + # resolve any conflicts + self._check_conflict(action) + + # add to actions list + self._actions.append(action) + action.container = self + + # index the action by any option strings it has + for option_string in action.option_strings: + self._option_string_actions[option_string] = action + + # set the flag if any option strings look like negative numbers + for option_string in action.option_strings: + if self._negative_number_matcher.match(option_string): + if not self._has_negative_number_optionals: + self._has_negative_number_optionals.append(True) + + # return the created action + return action + + def _remove_action(self, action): + self._actions.remove(action) + + def _add_container_actions(self, container): + # collect groups by titles + title_group_map = {} + for group in self._action_groups: + if group.title in title_group_map: + msg = _('cannot merge actions - two groups are named %r') + raise ValueError(msg % (group.title)) + title_group_map[group.title] = group + + # map each action to its group + group_map = {} + for group in container._action_groups: + + # if a group with the title exists, use that, otherwise + # create a new group matching the container's group + if group.title not in title_group_map: + title_group_map[group.title] = self.add_argument_group( + title=group.title, + description=group.description, + conflict_handler=group.conflict_handler) + + # map the actions to their new group + for action in group._group_actions: + group_map[action] = title_group_map[group.title] + + # add container's mutually exclusive groups + # NOTE: if add_mutually_exclusive_group ever gains title= and + # description= then this code will need to be expanded as above + for group in container._mutually_exclusive_groups: + mutex_group = self.add_mutually_exclusive_group( + required=group.required) + + # map the actions to their new mutex group + for action in group._group_actions: + group_map[action] = mutex_group + + # add all actions to this container or their group + for action in container._actions: + group_map.get(action, self)._add_action(action) + + def _get_positional_kwargs(self, dest, **kwargs): + # make sure required is not specified + if 'required' in kwargs: + msg = _("'required' is an invalid argument for positionals") + raise TypeError(msg) + + # mark positional arguments as required if at least one is + # always required + if kwargs.get('nargs') not in [OPTIONAL, ZERO_OR_MORE]: + kwargs['required'] = True + if kwargs.get('nargs') == ZERO_OR_MORE and 'default' not in kwargs: + kwargs['required'] = True + + # return the keyword arguments with no option strings + return dict(kwargs, dest=dest, option_strings=[]) + + def _get_optional_kwargs(self, *args, **kwargs): + # determine short and long option strings + option_strings = [] + long_option_strings = [] + for option_string in args: + # error on strings that don't start with an appropriate prefix + if not option_string[0] in self.prefix_chars: + args = {'option': option_string, + 'prefix_chars': self.prefix_chars} + msg = _('invalid option string %(option)r: ' + 'must start with a character %(prefix_chars)r') + raise ValueError(msg % args) + + # strings starting with two prefix characters are long options + option_strings.append(option_string) + if option_string[0] in self.prefix_chars: + if len(option_string) > 1: + if option_string[1] in self.prefix_chars: + long_option_strings.append(option_string) + + # infer destination, '--foo-bar' -> 'foo_bar' and '-x' -> 'x' + dest = kwargs.pop('dest', None) + if dest is None: + if long_option_strings: + dest_option_string = long_option_strings[0] + else: + dest_option_string = option_strings[0] + dest = dest_option_string.lstrip(self.prefix_chars) + if not dest: + msg = _('dest= is required for options like %r') + raise ValueError(msg % option_string) + dest = dest.replace('-', '_') + + # return the updated keyword arguments + return dict(kwargs, dest=dest, option_strings=option_strings) + + def _pop_action_class(self, kwargs, default=None): + action = kwargs.pop('action', default) + return self._registry_get('action', action, action) + + def _get_handler(self): + # determine function from conflict handler string + handler_func_name = '_handle_conflict_%s' % self.conflict_handler + try: + return getattr(self, handler_func_name) + except AttributeError: + msg = _('invalid conflict_resolution value: %r') + raise ValueError(msg % self.conflict_handler) + + def _check_conflict(self, action): + + # find all options that conflict with this option + confl_optionals = [] + for option_string in action.option_strings: + if option_string in self._option_string_actions: + confl_optional = self._option_string_actions[option_string] + confl_optionals.append((option_string, confl_optional)) + + # resolve any conflicts + if confl_optionals: + conflict_handler = self._get_handler() + conflict_handler(action, confl_optionals) + + def _handle_conflict_error(self, action, conflicting_actions): + message = ngettext('conflicting option string: %s', + 'conflicting option strings: %s', + len(conflicting_actions)) + conflict_string = ', '.join([option_string + for option_string, action + in conflicting_actions]) + raise ArgumentError(action, message % conflict_string) + + def _handle_conflict_resolve(self, action, conflicting_actions): + + # remove all conflicting options + for option_string, action in conflicting_actions: + + # remove the conflicting option + action.option_strings.remove(option_string) + self._option_string_actions.pop(option_string, None) + + # if the option now has no option string, remove it from the + # container holding it + if not action.option_strings: + action.container._remove_action(action) + + +class _ArgumentGroup(_ActionsContainer): + + def __init__(self, container, title=None, description=None, **kwargs): + # add any missing keyword arguments by checking the container + update = kwargs.setdefault + update('conflict_handler', container.conflict_handler) + update('prefix_chars', container.prefix_chars) + update('argument_default', container.argument_default) + super_init = super(_ArgumentGroup, self).__init__ + super_init(description=description, **kwargs) + + # group attributes + self.title = title + self._group_actions = [] + + # share most attributes with the container + self._registries = container._registries + self._actions = container._actions + self._option_string_actions = container._option_string_actions + self._defaults = container._defaults + self._has_negative_number_optionals = \ + container._has_negative_number_optionals + self._mutually_exclusive_groups = container._mutually_exclusive_groups + + def _add_action(self, action): + action = super(_ArgumentGroup, self)._add_action(action) + self._group_actions.append(action) + return action + + def _remove_action(self, action): + super(_ArgumentGroup, self)._remove_action(action) + self._group_actions.remove(action) + + +class _MutuallyExclusiveGroup(_ArgumentGroup): + + def __init__(self, container, required=False): + super(_MutuallyExclusiveGroup, self).__init__(container) + self.required = required + self._container = container + + def _add_action(self, action): + if action.required: + msg = _('mutually exclusive arguments must be optional') + raise ValueError(msg) + action = self._container._add_action(action) + self._group_actions.append(action) + return action + + def _remove_action(self, action): + self._container._remove_action(action) + self._group_actions.remove(action) + + +class ArgumentParser(_AttributeHolder, _ActionsContainer): + """Object for parsing command line strings into Python objects. + + Keyword Arguments: + - prog -- The name of the program (default: sys.argv[0]) + - usage -- A usage message (default: auto-generated from arguments) + - description -- A description of what the program does + - epilog -- Text following the argument descriptions + - parents -- Parsers whose arguments should be copied into this one + - formatter_class -- HelpFormatter class for printing help messages + - prefix_chars -- Characters that prefix optional arguments + - fromfile_prefix_chars -- Characters that prefix files containing + additional arguments + - argument_default -- The default value for all arguments + - conflict_handler -- String indicating how to handle conflicts + - add_help -- Add a -h/-help option + """ + + def __init__(self, + prog=None, + usage=None, + description=None, + epilog=None, + version=None, + parents=[], + formatter_class=HelpFormatter, + prefix_chars='-', + fromfile_prefix_chars=None, + argument_default=None, + conflict_handler='error', + add_help=True): + + if version is not None: + import warnings + warnings.warn( + """The "version" argument to ArgumentParser is deprecated. """ + """Please use """ + """"add_argument(..., action='version', version="N", ...)" """ + """instead""", DeprecationWarning) + + superinit = super(ArgumentParser, self).__init__ + superinit(description=description, + prefix_chars=prefix_chars, + argument_default=argument_default, + conflict_handler=conflict_handler) + + # default setting for prog + if prog is None: + prog = _os.path.basename(_sys.argv[0]) + + self.prog = prog + self.usage = usage + self.epilog = epilog + self.version = version + self.formatter_class = formatter_class + self.fromfile_prefix_chars = fromfile_prefix_chars + self.add_help = add_help + + add_group = self.add_argument_group + self._positionals = add_group(_('positional arguments')) + self._optionals = add_group(_('optional arguments')) + self._subparsers = None + + # register types + def identity(string): + return string + self.register('type', None, identity) + + # add help and version arguments if necessary + # (using explicit default to override global argument_default) + default_prefix = '-' if '-' in prefix_chars else prefix_chars[0] + if self.add_help: + self.add_argument( + default_prefix+'h', default_prefix*2+'help', + action='help', default=SUPPRESS, + help=_('show this help message and exit')) + if self.version: + self.add_argument( + default_prefix+'v', default_prefix*2+'version', + action='version', default=SUPPRESS, + version=self.version, + help=_("show program's version number and exit")) + + # add parent arguments and defaults + for parent in parents: + self._add_container_actions(parent) + try: + defaults = parent._defaults + except AttributeError: + pass + else: + self._defaults.update(defaults) + + # ======================= + # Pretty __repr__ methods + # ======================= + def _get_kwargs(self): + names = [ + 'prog', + 'usage', + 'description', + 'version', + 'formatter_class', + 'conflict_handler', + 'add_help', + ] + return [(name, getattr(self, name)) for name in names] + + # ================================== + # Optional/Positional adding methods + # ================================== + def add_subparsers(self, **kwargs): + if self._subparsers is not None: + self.error(_('cannot have multiple subparser arguments')) + + # add the parser class to the arguments if it's not present + kwargs.setdefault('parser_class', type(self)) + + if 'title' in kwargs or 'description' in kwargs: + title = _(kwargs.pop('title', 'subcommands')) + description = _(kwargs.pop('description', None)) + self._subparsers = self.add_argument_group(title, description) + else: + self._subparsers = self._positionals + + # prog defaults to the usage message of this parser, skipping + # optional arguments and with no "usage:" prefix + if kwargs.get('prog') is None: + formatter = self._get_formatter() + positionals = self._get_positional_actions() + groups = self._mutually_exclusive_groups + formatter.add_usage(self.usage, positionals, groups, '') + kwargs['prog'] = formatter.format_help().strip() + + # create the parsers action and add it to the positionals list + parsers_class = self._pop_action_class(kwargs, 'parsers') + action = parsers_class(option_strings=[], **kwargs) + self._subparsers._add_action(action) + + # return the created parsers action + return action + + def _add_action(self, action): + if action.option_strings: + self._optionals._add_action(action) + else: + self._positionals._add_action(action) + return action + + def _get_optional_actions(self): + return [action + for action in self._actions + if action.option_strings] + + def _get_positional_actions(self): + return [action + for action in self._actions + if not action.option_strings] + + # ===================================== + # Command line argument parsing methods + # ===================================== + def parse_args(self, args=None, namespace=None): + args, argv = self.parse_known_args(args, namespace) + if argv: + msg = _('unrecognized arguments: %s') + self.error(msg % ' '.join(argv)) + return args + + def parse_known_args(self, args=None, namespace=None): + # args default to the system args + if args is None: + args = _sys.argv[1:] + + # default Namespace built from parser defaults + if namespace is None: + namespace = Namespace() + + # add any action defaults that aren't present + for action in self._actions: + if action.dest is not SUPPRESS: + if not hasattr(namespace, action.dest): + if action.default is not SUPPRESS: + default = action.default + if isinstance(action.default, str): + default = self._get_value(action, default) + setattr(namespace, action.dest, default) + + # add any parser defaults that aren't present + for dest in self._defaults: + if not hasattr(namespace, dest): + setattr(namespace, dest, self._defaults[dest]) + + # parse the arguments and exit if there are any errors + try: + namespace, args = self._parse_known_args(args, namespace) + if hasattr(namespace, _UNRECOGNIZED_ARGS_ATTR): + args.extend(getattr(namespace, _UNRECOGNIZED_ARGS_ATTR)) + delattr(namespace, _UNRECOGNIZED_ARGS_ATTR) + return namespace, args + except ArgumentError: + err = _sys.exc_info()[1] + self.error(str(err)) + + def _parse_known_args(self, arg_strings, namespace): + # replace arg strings that are file references + if self.fromfile_prefix_chars is not None: + arg_strings = self._read_args_from_files(arg_strings) + + # map all mutually exclusive arguments to the other arguments + # they can't occur with + action_conflicts = {} + for mutex_group in self._mutually_exclusive_groups: + group_actions = mutex_group._group_actions + for i, mutex_action in enumerate(mutex_group._group_actions): + conflicts = action_conflicts.setdefault(mutex_action, []) + conflicts.extend(group_actions[:i]) + conflicts.extend(group_actions[i + 1:]) + + # find all option indices, and determine the arg_string_pattern + # which has an 'O' if there is an option at an index, + # an 'A' if there is an argument, or a '-' if there is a '--' + option_string_indices = {} + arg_string_pattern_parts = [] + arg_strings_iter = iter(arg_strings) + for i, arg_string in enumerate(arg_strings_iter): + + # all args after -- are non-options + if arg_string == '--': + arg_string_pattern_parts.append('-') + for arg_string in arg_strings_iter: + arg_string_pattern_parts.append('A') + + # otherwise, add the arg to the arg strings + # and note the index if it was an option + else: + option_tuple = self._parse_optional(arg_string) + if option_tuple is None: + pattern = 'A' + else: + option_string_indices[i] = option_tuple + pattern = 'O' + arg_string_pattern_parts.append(pattern) + + # join the pieces together to form the pattern + arg_strings_pattern = ''.join(arg_string_pattern_parts) + + # converts arg strings to the appropriate and then takes the action + seen_actions = set() + seen_non_default_actions = set() + + def take_action(action, argument_strings, option_string=None): + seen_actions.add(action) + argument_values = self._get_values(action, argument_strings) + + # error if this argument is not allowed with other previously + # seen arguments, assuming that actions that use the default + # value don't really count as "present" + if argument_values is not action.default: + seen_non_default_actions.add(action) + for conflict_action in action_conflicts.get(action, []): + if conflict_action in seen_non_default_actions: + msg = _('not allowed with argument %s') + action_name = _get_action_name(conflict_action) + raise ArgumentError(action, msg % action_name) + + # take the action if we didn't receive a SUPPRESS value + # (e.g. from a default) + if argument_values is not SUPPRESS: + action(self, namespace, argument_values, option_string) + + # function to convert arg_strings into an optional action + def consume_optional(start_index): + + # get the optional identified at this index + option_tuple = option_string_indices[start_index] + action, option_string, explicit_arg = option_tuple + + # identify additional optionals in the same arg string + # (e.g. -xyz is the same as -x -y -z if no args are required) + match_argument = self._match_argument + action_tuples = [] + while True: + + # if we found no optional action, skip it + if action is None: + extras.append(arg_strings[start_index]) + return start_index + 1 + + # if there is an explicit argument, try to match the + # optional's string arguments to only this + if explicit_arg is not None: + arg_count = match_argument(action, 'A') + + # if the action is a single-dash option and takes no + # arguments, try to parse more single-dash options out + # of the tail of the option string + chars = self.prefix_chars + if arg_count == 0 and option_string[1] not in chars: + action_tuples.append((action, [], option_string)) + char = option_string[0] + option_string = char + explicit_arg[0] + new_explicit_arg = explicit_arg[1:] or None + optionals_map = self._option_string_actions + if option_string in optionals_map: + action = optionals_map[option_string] + explicit_arg = new_explicit_arg + else: + msg = _('ignored explicit argument %r') + raise ArgumentError(action, msg % explicit_arg) + + # if the action expect exactly one argument, we've + # successfully matched the option; exit the loop + elif arg_count == 1: + stop = start_index + 1 + args = [explicit_arg] + action_tuples.append((action, args, option_string)) + break + + # error if a double-dash option did not use the + # explicit argument + else: + msg = _('ignored explicit argument %r') + raise ArgumentError(action, msg % explicit_arg) + + # if there is no explicit argument, try to match the + # optional's string arguments with the following strings + # if successful, exit the loop + else: + start = start_index + 1 + selected_patterns = arg_strings_pattern[start:] + arg_count = match_argument(action, selected_patterns) + stop = start + arg_count + args = arg_strings[start:stop] + action_tuples.append((action, args, option_string)) + break + + # add the Optional to the list and return the index at which + # the Optional's string args stopped + assert action_tuples + for action, args, option_string in action_tuples: + take_action(action, args, option_string) + return stop + + # the list of Positionals left to be parsed; this is modified + # by consume_positionals() + positionals = self._get_positional_actions() + + # function to convert arg_strings into positional actions + def consume_positionals(start_index): + # match as many Positionals as possible + match_partial = self._match_arguments_partial + selected_pattern = arg_strings_pattern[start_index:] + arg_counts = match_partial(positionals, selected_pattern) + + # slice off the appropriate arg strings for each Positional + # and add the Positional and its args to the list + for action, arg_count in zip(positionals, arg_counts): + args = arg_strings[start_index: start_index + arg_count] + start_index += arg_count + take_action(action, args) + + # slice off the Positionals that we just parsed and return the + # index at which the Positionals' string args stopped + positionals[:] = positionals[len(arg_counts):] + return start_index + + # consume Positionals and Optionals alternately, until we have + # passed the last option string + extras = [] + start_index = 0 + if option_string_indices: + max_option_string_index = max(option_string_indices) + else: + max_option_string_index = -1 + while start_index <= max_option_string_index: + + # consume any Positionals preceding the next option + next_option_string_index = min([ + index + for index in option_string_indices + if index >= start_index]) + if start_index != next_option_string_index: + positionals_end_index = consume_positionals(start_index) + + # only try to parse the next optional if we didn't consume + # the option string during the positionals parsing + if positionals_end_index > start_index: + start_index = positionals_end_index + continue + else: + start_index = positionals_end_index + + # if we consumed all the positionals we could and we're not + # at the index of an option string, there were extra arguments + if start_index not in option_string_indices: + strings = arg_strings[start_index:next_option_string_index] + extras.extend(strings) + start_index = next_option_string_index + + # consume the next optional and any arguments for it + start_index = consume_optional(start_index) + + # consume any positionals following the last Optional + stop_index = consume_positionals(start_index) + + # if we didn't consume all the argument strings, there were extras + extras.extend(arg_strings[stop_index:]) + + # make sure all required actions were present + required_actions = [_get_action_name(action) for action in self._actions + if action.required and action not in seen_actions] + if required_actions: + self.error(_('the following arguments are required: %s') % + ', '.join(required_actions)) + + # make sure all required groups had one option present + for group in self._mutually_exclusive_groups: + if group.required: + for action in group._group_actions: + if action in seen_non_default_actions: + break + + # if no actions were used, report the error + else: + names = [_get_action_name(action) + for action in group._group_actions + if action.help is not SUPPRESS] + msg = _('one of the arguments %s is required') + self.error(msg % ' '.join(names)) + + # return the updated namespace and the extra arguments + return namespace, extras + + def _read_args_from_files(self, arg_strings): + # expand arguments referencing files + new_arg_strings = [] + for arg_string in arg_strings: + + # for regular arguments, just add them back into the list + if arg_string[0] not in self.fromfile_prefix_chars: + new_arg_strings.append(arg_string) + + # replace arguments referencing files with the file content + else: + try: + args_file = open(arg_string[1:]) + try: + arg_strings = [] + for arg_line in args_file.read().splitlines(): + for arg in self.convert_arg_line_to_args(arg_line): + arg_strings.append(arg) + arg_strings = self._read_args_from_files(arg_strings) + new_arg_strings.extend(arg_strings) + finally: + args_file.close() + except IOError: + err = _sys.exc_info()[1] + self.error(str(err)) + + # return the modified argument list + return new_arg_strings + + def convert_arg_line_to_args(self, arg_line): + return [arg_line] + + def _match_argument(self, action, arg_strings_pattern): + # match the pattern for this action to the arg strings + nargs_pattern = self._get_nargs_pattern(action) + match = _re.match(nargs_pattern, arg_strings_pattern) + + # raise an exception if we weren't able to find a match + if match is None: + nargs_errors = { + None: _('expected one argument'), + OPTIONAL: _('expected at most one argument'), + ONE_OR_MORE: _('expected at least one argument'), + } + default = ngettext('expected %s argument', + 'expected %s arguments', + action.nargs) % action.nargs + msg = nargs_errors.get(action.nargs, default) + raise ArgumentError(action, msg) + + # return the number of arguments matched + return len(match.group(1)) + + def _match_arguments_partial(self, actions, arg_strings_pattern): + # progressively shorten the actions list by slicing off the + # final actions until we find a match + result = [] + for i in range(len(actions), 0, -1): + actions_slice = actions[:i] + pattern = ''.join([self._get_nargs_pattern(action) + for action in actions_slice]) + match = _re.match(pattern, arg_strings_pattern) + if match is not None: + result.extend([len(string) for string in match.groups()]) + break + + # return the list of arg string counts + return result + + def _parse_optional(self, arg_string): + # if it's an empty string, it was meant to be a positional + if not arg_string: + return None + + # if it doesn't start with a prefix, it was meant to be positional + if not arg_string[0] in self.prefix_chars: + return None + + # if the option string is present in the parser, return the action + if arg_string in self._option_string_actions: + action = self._option_string_actions[arg_string] + return action, arg_string, None + + # if it's just a single character, it was meant to be positional + if len(arg_string) == 1: + return None + + # if the option string before the "=" is present, return the action + if '=' in arg_string: + option_string, explicit_arg = arg_string.split('=', 1) + if option_string in self._option_string_actions: + action = self._option_string_actions[option_string] + return action, option_string, explicit_arg + + # search through all possible prefixes of the option string + # and all actions in the parser for possible interpretations + option_tuples = self._get_option_tuples(arg_string) + + # if multiple actions match, the option string was ambiguous + if len(option_tuples) > 1: + options = ', '.join([option_string + for action, option_string, explicit_arg in option_tuples]) + args = {'option': arg_string, 'matches': options} + msg = _('ambiguous option: %(option)s could match %(matches)s') + self.error(msg % args) + + # if exactly one action matched, this segmentation is good, + # so return the parsed action + elif len(option_tuples) == 1: + option_tuple, = option_tuples + return option_tuple + + # if it was not found as an option, but it looks like a negative + # number, it was meant to be positional + # unless there are negative-number-like options + if self._negative_number_matcher.match(arg_string): + if not self._has_negative_number_optionals: + return None + + # if it contains a space, it was meant to be a positional + if ' ' in arg_string: + return None + + # it was meant to be an optional but there is no such option + # in this parser (though it might be a valid option in a subparser) + return None, arg_string, None + + def _get_option_tuples(self, option_string): + result = [] + + # option strings starting with two prefix characters are only + # split at the '=' + chars = self.prefix_chars + if option_string[0] in chars and option_string[1] in chars: + if '=' in option_string: + option_prefix, explicit_arg = option_string.split('=', 1) + else: + option_prefix = option_string + explicit_arg = None + for option_string in self._option_string_actions: + if option_string.startswith(option_prefix): + action = self._option_string_actions[option_string] + tup = action, option_string, explicit_arg + result.append(tup) + + # single character options can be concatenated with their arguments + # but multiple character options always have to have their argument + # separate + elif option_string[0] in chars and option_string[1] not in chars: + option_prefix = option_string + explicit_arg = None + short_option_prefix = option_string[:2] + short_explicit_arg = option_string[2:] + + for option_string in self._option_string_actions: + if option_string == short_option_prefix: + action = self._option_string_actions[option_string] + tup = action, option_string, short_explicit_arg + result.append(tup) + elif option_string.startswith(option_prefix): + action = self._option_string_actions[option_string] + tup = action, option_string, explicit_arg + result.append(tup) + + # shouldn't ever get here + else: + self.error(_('unexpected option string: %s') % option_string) + + # return the collected option tuples + return result + + def _get_nargs_pattern(self, action): + # in all examples below, we have to allow for '--' args + # which are represented as '-' in the pattern + nargs = action.nargs + + # the default (None) is assumed to be a single argument + if nargs is None: + nargs_pattern = '(-*A-*)' + + # allow zero or one arguments + elif nargs == OPTIONAL: + nargs_pattern = '(-*A?-*)' + + # allow zero or more arguments + elif nargs == ZERO_OR_MORE: + nargs_pattern = '(-*[A-]*)' + + # allow one or more arguments + elif nargs == ONE_OR_MORE: + nargs_pattern = '(-*A[A-]*)' + + # allow any number of options or arguments + elif nargs == REMAINDER: + nargs_pattern = '([-AO]*)' + + # allow one argument followed by any number of options or arguments + elif nargs == PARSER: + nargs_pattern = '(-*A[-AO]*)' + + # all others should be integers + else: + nargs_pattern = '(-*%s-*)' % '-*'.join('A' * nargs) + + # if this is an optional action, -- is not allowed + if action.option_strings: + nargs_pattern = nargs_pattern.replace('-*', '') + nargs_pattern = nargs_pattern.replace('-', '') + + # return the pattern + return nargs_pattern + + # ======================== + # Value conversion methods + # ======================== + def _get_values(self, action, arg_strings): + # for everything but PARSER args, strip out '--' + if action.nargs not in [PARSER, REMAINDER]: + arg_strings = [s for s in arg_strings if s != '--'] + + # optional argument produces a default when not present + if not arg_strings and action.nargs == OPTIONAL: + if action.option_strings: + value = action.const + else: + value = action.default + if isinstance(value, str): + value = self._get_value(action, value) + self._check_value(action, value) + + # when nargs='*' on a positional, if there were no command-line + # args, use the default if it is anything other than None + elif (not arg_strings and action.nargs == ZERO_OR_MORE and + not action.option_strings): + if action.default is not None: + value = action.default + else: + value = arg_strings + self._check_value(action, value) + + # single argument or optional argument produces a single value + elif len(arg_strings) == 1 and action.nargs in [None, OPTIONAL]: + arg_string, = arg_strings + value = self._get_value(action, arg_string) + self._check_value(action, value) + + # REMAINDER arguments convert all values, checking none + elif action.nargs == REMAINDER: + value = [self._get_value(action, v) for v in arg_strings] + + # PARSER arguments convert all values, but check only the first + elif action.nargs == PARSER: + value = [self._get_value(action, v) for v in arg_strings] + self._check_value(action, value[0]) + + # all other types of nargs produce a list + else: + value = [self._get_value(action, v) for v in arg_strings] + for v in value: + self._check_value(action, v) + + # return the converted value + return value + + def _get_value(self, action, arg_string): + type_func = self._registry_get('type', action.type, action.type) + if not _callable(type_func): + msg = _('%r is not callable') + raise ArgumentError(action, msg % type_func) + + # convert the value to the appropriate type + try: + result = type_func(arg_string) + + # ArgumentTypeErrors indicate errors + except ArgumentTypeError: + name = getattr(action.type, '__name__', repr(action.type)) + msg = str(_sys.exc_info()[1]) + raise ArgumentError(action, msg) + + # TypeErrors or ValueErrors also indicate errors + except (TypeError, ValueError): + name = getattr(action.type, '__name__', repr(action.type)) + args = {'type': name, 'value': arg_string} + msg = _('invalid %(type)s value: %(value)r') + raise ArgumentError(action, msg % args) + + # return the converted value + return result + + def _check_value(self, action, value): + # converted value must be one of the choices (if specified) + if action.choices is not None and value not in action.choices: + args = {'value': value, + 'choices': ', '.join(map(repr, action.choices))} + msg = _('invalid choice: %(value)r (choose from %(choices)s)') + raise ArgumentError(action, msg % args) + + # ======================= + # Help-formatting methods + # ======================= + def format_usage(self): + formatter = self._get_formatter() + formatter.add_usage(self.usage, self._actions, + self._mutually_exclusive_groups) + return formatter.format_help() + + def format_help(self): + formatter = self._get_formatter() + + # usage + formatter.add_usage(self.usage, self._actions, + self._mutually_exclusive_groups) + + # description + formatter.add_text(self.description) + + # positionals, optionals and user-defined groups + for action_group in self._action_groups: + formatter.start_section(action_group.title) + formatter.add_text(action_group.description) + formatter.add_arguments(action_group._group_actions) + formatter.end_section() + + # epilog + formatter.add_text(self.epilog) + + # determine help from format above + return formatter.format_help() + + def format_version(self): + import warnings + warnings.warn( + 'The format_version method is deprecated -- the "version" ' + 'argument to ArgumentParser is no longer supported.', + DeprecationWarning) + formatter = self._get_formatter() + formatter.add_text(self.version) + return formatter.format_help() + + def _get_formatter(self): + return self.formatter_class(prog=self.prog) + + # ===================== + # Help-printing methods + # ===================== + def print_usage(self, file=None): + if file is None: + file = _sys.stdout + self._print_message(self.format_usage(), file) + + def print_help(self, file=None): + if file is None: + file = _sys.stdout + self._print_message(self.format_help(), file) + + def print_version(self, file=None): + import warnings + warnings.warn( + 'The print_version method is deprecated -- the "version" ' + 'argument to ArgumentParser is no longer supported.', + DeprecationWarning) + self._print_message(self.format_version(), file) + + def _print_message(self, message, file=None): + if message: + if file is None: + file = _sys.stderr + file.write(message) + + # =============== + # Exiting methods + # =============== + def exit(self, status=0, message=None): + if message: + self._print_message(message, _sys.stderr) + _sys.exit(status) + + def error(self, message): + """error(message: string) + + Prints a usage message incorporating the message to stderr and + exits. + + If you override this in a subclass, it should not return -- it + should either exit or raise an exception. + """ + self.print_usage(_sys.stderr) + args = {'prog': self.prog, 'message': message} + self.exit(2, _('%(prog)s: error: %(message)s\n') % args) diff --git a/lib/beets/LICENSE b/lib/beets/LICENSE new file mode 100644 index 00000000..cddcf990 --- /dev/null +++ b/lib/beets/LICENSE @@ -0,0 +1,21 @@ +The MIT License + +Copyright (c) 2010-2014 Adrian Sampson + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/lib/beets/README.rst b/lib/beets/README.rst new file mode 100644 index 00000000..8c64e244 --- /dev/null +++ b/lib/beets/README.rst @@ -0,0 +1,94 @@ +.. image:: https://travis-ci.org/sampsyo/beets.svg?branch=master + :target: https://travis-ci.org/sampsyo/beets + +.. image:: http://img.shields.io/coveralls/sampsyo/beets.svg + :target: https://coveralls.io/r/sampsyo/beets + +.. image:: http://img.shields.io/pypi/v/beets.svg + :target: https://pypi.python.org/pypi/beets + +Beets is the media library management system for obsessive-compulsive music +geeks. + +The purpose of beets is to get your music collection right once and for all. +It catalogs your collection, automatically improving its metadata as it goes. +It then provides a bouquet of tools for manipulating and accessing your music. + +Here's an example of beets' brainy tag corrector doing its thing:: + + $ beet import ~/music/ladytron + Tagging: + Ladytron - Witching Hour + (Similarity: 98.4%) + * Last One Standing -> The Last One Standing + * Beauty -> Beauty*2 + * White Light Generation -> Whitelightgenerator + * All the Way -> All the Way... + +Because beets is designed as a library, it can do almost anything you can +imagine for your music collection. Via `plugins`_, beets becomes a panacea: + +- Fetch or calculate all the metadata you could possibly need: `album art`_, + `lyrics`_, `genres`_, `tempos`_, `ReplayGain`_ levels, or `acoustic + fingerprints`_. +- Get metadata from `MusicBrainz`_, `Discogs`_, or `Beatport`_. Or guess + metadata using songs' filenames or their acoustic fingerprints. +- `Transcode audio`_ to any format you like. +- Check your library for `duplicate tracks and albums`_ or for `albums that + are missing tracks`_. +- Clean up crufty tags left behind by other, less-awesome tools. +- Embed and extract album art from files' metadata. +- Browse your music library graphically through a Web browser and play it in any + browser that supports `HTML5 Audio`_. +- Analyze music files' metadata from the command line. +- Listen to your library with a music player that speaks the `MPD`_ protocol + and works with a staggering variety of interfaces. + +If beets doesn't do what you want yet, `writing your own plugin`_ is +shockingly simple if you know a little Python. + +.. _plugins: http://beets.readthedocs.org/page/plugins/ +.. _MPD: http://www.musicpd.org/ +.. _MusicBrainz music collection: http://musicbrainz.org/doc/Collections/ +.. _writing your own plugin: + http://beets.readthedocs.org/page/dev/plugins.html +.. _HTML5 Audio: + http://www.w3.org/TR/html-markup/audio.html +.. _albums that are missing tracks: + http://beets.readthedocs.org/page/plugins/missing.html +.. _duplicate tracks and albums: + http://beets.readthedocs.org/page/plugins/duplicates.html +.. _Transcode audio: + http://beets.readthedocs.org/page/plugins/convert.html +.. _Beatport: http://www.beatport.com/ +.. _Discogs: http://www.discogs.com/ +.. _acoustic fingerprints: + http://beets.readthedocs.org/page/plugins/chroma.html +.. _ReplayGain: http://beets.readthedocs.org/page/plugins/replaygain.html +.. _tempos: http://beets.readthedocs.org/page/plugins/echonest.html +.. _genres: http://beets.readthedocs.org/page/plugins/lastgenre.html +.. _album art: http://beets.readthedocs.org/page/plugins/fetchart.html +.. _lyrics: http://beets.readthedocs.org/page/plugins/lyrics.html +.. _MusicBrainz: http://musicbrainz.org/ + +Read More +--------- + +Learn more about beets at `its Web site`_. Follow `@b33ts`_ on Twitter for +news and updates. + +You can install beets by typing ``pip install beets``. Then check out the +`Getting Started`_ guide. + +.. _its Web site: http://beets.radbox.org/ +.. _Getting Started: http://beets.readthedocs.org/page/guides/main.html +.. _@b33ts: http://twitter.com/b33ts/ + +Authors +------- + +Beets is by `Adrian Sampson`_ with a supporting cast of thousands. For help, +please contact the `mailing list`_. + +.. _mailing list: https://groups.google.com/forum/#!forum/beets-users +.. _Adrian Sampson: http://homes.cs.washington.edu/~asampson/ diff --git a/lib/beets/__init__.py b/lib/beets/__init__.py new file mode 100644 index 00000000..d050a028 --- /dev/null +++ b/lib/beets/__init__.py @@ -0,0 +1,28 @@ +# This file is part of beets. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +# This particular version has been slightly modified to work with Headphones +# https://github.com/rembo10/headphones + +__version__ = '1.3.10-headphones' +__author__ = 'Adrian Sampson ' + +import os + +import beets.library +from beets.util import confit + +Library = beets.library.Library + +config = confit.LazyConfig(os.path.dirname(__file__), __name__) diff --git a/lib/beets/autotag/__init__.py b/lib/beets/autotag/__init__.py new file mode 100644 index 00000000..7c517c60 --- /dev/null +++ b/lib/beets/autotag/__init__.py @@ -0,0 +1,136 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Facilities for automatically determining files' correct metadata. +""" +import logging + +from beets import config + +# Parts of external interface. +from .hooks import AlbumInfo, TrackInfo, AlbumMatch, TrackMatch # noqa +from .match import tag_item, tag_album # noqa +from .match import Recommendation # noqa + +# Global logger. +log = logging.getLogger('beets') + + +# Additional utilities for the main interface. + +def apply_item_metadata(item, track_info): + """Set an item's metadata from its matched TrackInfo object. + """ + item.artist = track_info.artist + item.artist_sort = track_info.artist_sort + item.artist_credit = track_info.artist_credit + item.title = track_info.title + item.mb_trackid = track_info.track_id + if track_info.artist_id: + item.mb_artistid = track_info.artist_id + # At the moment, the other metadata is left intact (including album + # and track number). Perhaps these should be emptied? + + +def apply_metadata(album_info, mapping): + """Set the items' metadata to match an AlbumInfo object using a + mapping from Items to TrackInfo objects. + """ + for item, track_info in mapping.iteritems(): + # Album, artist, track count. + if track_info.artist: + item.artist = track_info.artist + else: + item.artist = album_info.artist + item.albumartist = album_info.artist + item.album = album_info.album + + # Artist sort and credit names. + item.artist_sort = track_info.artist_sort or album_info.artist_sort + item.artist_credit = (track_info.artist_credit or + album_info.artist_credit) + item.albumartist_sort = album_info.artist_sort + item.albumartist_credit = album_info.artist_credit + + # Release date. + for prefix in '', 'original_': + if config['original_date'] and not prefix: + # Ignore specific release date. + continue + + for suffix in 'year', 'month', 'day': + key = prefix + suffix + value = getattr(album_info, key) or 0 + + # If we don't even have a year, apply nothing. + if suffix == 'year' and not value: + break + + # Otherwise, set the fetched value (or 0 for the month + # and day if not available). + item[key] = value + + # If we're using original release date for both fields, + # also set item.year = info.original_year, etc. + if config['original_date']: + item[suffix] = value + + # Title. + item.title = track_info.title + + if config['per_disc_numbering']: + item.track = track_info.medium_index or track_info.index + item.tracktotal = track_info.medium_total or len(album_info.tracks) + else: + item.track = track_info.index + item.tracktotal = len(album_info.tracks) + + # Disc and disc count. + item.disc = track_info.medium + item.disctotal = album_info.mediums + + # MusicBrainz IDs. + item.mb_trackid = track_info.track_id + item.mb_albumid = album_info.album_id + if track_info.artist_id: + item.mb_artistid = track_info.artist_id + else: + item.mb_artistid = album_info.artist_id + item.mb_albumartistid = album_info.artist_id + item.mb_releasegroupid = album_info.releasegroup_id + + # Compilation flag. + item.comp = album_info.va + + # Miscellaneous metadata. + for field in ('albumtype', + 'label', + 'asin', + 'catalognum', + 'script', + 'language', + 'country', + 'albumstatus', + 'albumdisambig'): + value = getattr(album_info, field) + if value is not None: + item[field] = value + if track_info.disctitle is not None: + item.disctitle = track_info.disctitle + + if track_info.media is not None: + item.media = track_info.media + + # Headphones seal of approval + item.comments = 'tagged by headphones/beets' \ No newline at end of file diff --git a/lib/beets/autotag/hooks.py b/lib/beets/autotag/hooks.py new file mode 100644 index 00000000..beb3bd91 --- /dev/null +++ b/lib/beets/autotag/hooks.py @@ -0,0 +1,579 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Glue between metadata sources and the matching logic.""" +import logging +from collections import namedtuple +import re + +from beets import plugins +from beets import config +from beets.autotag import mb +from beets.util import levenshtein +from unidecode import unidecode + +log = logging.getLogger('beets') + + +# Classes used to represent candidate options. + +class AlbumInfo(object): + """Describes a canonical release that may be used to match a release + in the library. Consists of these data members: + + - ``album``: the release title + - ``album_id``: MusicBrainz ID; UUID fragment only + - ``artist``: name of the release's primary artist + - ``artist_id`` + - ``tracks``: list of TrackInfo objects making up the release + - ``asin``: Amazon ASIN + - ``albumtype``: string describing the kind of release + - ``va``: boolean: whether the release has "various artists" + - ``year``: release year + - ``month``: release month + - ``day``: release day + - ``label``: music label responsible for the release + - ``mediums``: the number of discs in this release + - ``artist_sort``: name of the release's artist for sorting + - ``releasegroup_id``: MBID for the album's release group + - ``catalognum``: the label's catalog number for the release + - ``script``: character set used for metadata + - ``language``: human language of the metadata + - ``country``: the release country + - ``albumstatus``: MusicBrainz release status (Official, etc.) + - ``media``: delivery mechanism (Vinyl, etc.) + - ``albumdisambig``: MusicBrainz release disambiguation comment + - ``artist_credit``: Release-specific artist name + - ``data_source``: The original data source (MusicBrainz, Discogs, etc.) + - ``data_url``: The data source release URL. + + The fields up through ``tracks`` are required. The others are + optional and may be None. + """ + def __init__(self, album, album_id, artist, artist_id, tracks, asin=None, + albumtype=None, va=False, year=None, month=None, day=None, + label=None, mediums=None, artist_sort=None, + releasegroup_id=None, catalognum=None, script=None, + language=None, country=None, albumstatus=None, media=None, + albumdisambig=None, artist_credit=None, original_year=None, + original_month=None, original_day=None, data_source=None, + data_url=None): + self.album = album + self.album_id = album_id + self.artist = artist + self.artist_id = artist_id + self.tracks = tracks + self.asin = asin + self.albumtype = albumtype + self.va = va + self.year = year + self.month = month + self.day = day + self.label = label + self.mediums = mediums + self.artist_sort = artist_sort + self.releasegroup_id = releasegroup_id + self.catalognum = catalognum + self.script = script + self.language = language + self.country = country + self.albumstatus = albumstatus + self.media = media + self.albumdisambig = albumdisambig + self.artist_credit = artist_credit + self.original_year = original_year + self.original_month = original_month + self.original_day = original_day + self.data_source = data_source + self.data_url = data_url + + # Work around a bug in python-musicbrainz-ngs that causes some + # strings to be bytes rather than Unicode. + # https://github.com/alastair/python-musicbrainz-ngs/issues/85 + def decode(self, codec='utf8'): + """Ensure that all string attributes on this object, and the + constituent `TrackInfo` objects, are decoded to Unicode. + """ + for fld in ['album', 'artist', 'albumtype', 'label', 'artist_sort', + 'catalognum', 'script', 'language', 'country', + 'albumstatus', 'albumdisambig', 'artist_credit', 'media']: + value = getattr(self, fld) + if isinstance(value, str): + setattr(self, fld, value.decode(codec, 'ignore')) + + if self.tracks: + for track in self.tracks: + track.decode(codec) + + +class TrackInfo(object): + """Describes a canonical track present on a release. Appears as part + of an AlbumInfo's ``tracks`` list. Consists of these data members: + + - ``title``: name of the track + - ``track_id``: MusicBrainz ID; UUID fragment only + - ``artist``: individual track artist name + - ``artist_id`` + - ``length``: float: duration of the track in seconds + - ``index``: position on the entire release + - ``media``: delivery mechanism (Vinyl, etc.) + - ``medium``: the disc number this track appears on in the album + - ``medium_index``: the track's position on the disc + - ``medium_total``: the number of tracks on the item's disc + - ``artist_sort``: name of the track artist for sorting + - ``disctitle``: name of the individual medium (subtitle) + - ``artist_credit``: Recording-specific artist name + + Only ``title`` and ``track_id`` are required. The rest of the fields + may be None. The indices ``index``, ``medium``, and ``medium_index`` + are all 1-based. + """ + def __init__(self, title, track_id, artist=None, artist_id=None, + length=None, index=None, medium=None, medium_index=None, + medium_total=None, artist_sort=None, disctitle=None, + artist_credit=None, data_source=None, data_url=None, + media=None): + self.title = title + self.track_id = track_id + self.artist = artist + self.artist_id = artist_id + self.length = length + self.index = index + self.media = media + self.medium = medium + self.medium_index = medium_index + self.medium_total = medium_total + self.artist_sort = artist_sort + self.disctitle = disctitle + self.artist_credit = artist_credit + self.data_source = data_source + self.data_url = data_url + + # As above, work around a bug in python-musicbrainz-ngs. + def decode(self, codec='utf8'): + """Ensure that all string attributes on this object are decoded + to Unicode. + """ + for fld in ['title', 'artist', 'medium', 'artist_sort', 'disctitle', + 'artist_credit', 'media']: + value = getattr(self, fld) + if isinstance(value, str): + setattr(self, fld, value.decode(codec, 'ignore')) + + +# Candidate distance scoring. + +# Parameters for string distance function. +# Words that can be moved to the end of a string using a comma. +SD_END_WORDS = ['the', 'a', 'an'] +# Reduced weights for certain portions of the string. +SD_PATTERNS = [ + (r'^the ', 0.1), + (r'[\[\(]?(ep|single)[\]\)]?', 0.0), + (r'[\[\(]?(featuring|feat|ft)[\. :].+', 0.1), + (r'\(.*?\)', 0.3), + (r'\[.*?\]', 0.3), + (r'(, )?(pt\.|part) .+', 0.2), +] +# Replacements to use before testing distance. +SD_REPLACE = [ + (r'&', 'and'), +] + + +def _string_dist_basic(str1, str2): + """Basic edit distance between two strings, ignoring + non-alphanumeric characters and case. Comparisons are based on a + transliteration/lowering to ASCII characters. Normalized by string + length. + """ + str1 = unidecode(str1) + str2 = unidecode(str2) + str1 = re.sub(r'[^a-z0-9]', '', str1.lower()) + str2 = re.sub(r'[^a-z0-9]', '', str2.lower()) + if not str1 and not str2: + return 0.0 + return levenshtein(str1, str2) / float(max(len(str1), len(str2))) + + +def string_dist(str1, str2): + """Gives an "intuitive" edit distance between two strings. This is + an edit distance, normalized by the string length, with a number of + tweaks that reflect intuition about text. + """ + if str1 is None and str2 is None: + return 0.0 + if str1 is None or str2 is None: + return 1.0 + + str1 = str1.lower() + str2 = str2.lower() + + # Don't penalize strings that move certain words to the end. For + # example, "the something" should be considered equal to + # "something, the". + for word in SD_END_WORDS: + if str1.endswith(', %s' % word): + str1 = '%s %s' % (word, str1[:-len(word) - 2]) + if str2.endswith(', %s' % word): + str2 = '%s %s' % (word, str2[:-len(word) - 2]) + + # Perform a couple of basic normalizing substitutions. + for pat, repl in SD_REPLACE: + str1 = re.sub(pat, repl, str1) + str2 = re.sub(pat, repl, str2) + + # Change the weight for certain string portions matched by a set + # of regular expressions. We gradually change the strings and build + # up penalties associated with parts of the string that were + # deleted. + base_dist = _string_dist_basic(str1, str2) + penalty = 0.0 + for pat, weight in SD_PATTERNS: + # Get strings that drop the pattern. + case_str1 = re.sub(pat, '', str1) + case_str2 = re.sub(pat, '', str2) + + if case_str1 != str1 or case_str2 != str2: + # If the pattern was present (i.e., it is deleted in the + # the current case), recalculate the distances for the + # modified strings. + case_dist = _string_dist_basic(case_str1, case_str2) + case_delta = max(0.0, base_dist - case_dist) + if case_delta == 0.0: + continue + + # Shift our baseline strings down (to avoid rematching the + # same part of the string) and add a scaled distance + # amount to the penalties. + str1 = case_str1 + str2 = case_str2 + base_dist = case_dist + penalty += weight * case_delta + + return base_dist + penalty + + +class LazyClassProperty(object): + """A decorator implementing a read-only property that is *lazy* in + the sense that the getter is only invoked once. Subsequent accesses + through *any* instance use the cached result. + """ + def __init__(self, getter): + self.getter = getter + self.computed = False + + def __get__(self, obj, owner): + if not self.computed: + self.value = self.getter(owner) + self.computed = True + return self.value + + +class Distance(object): + """Keeps track of multiple distance penalties. Provides a single + weighted distance for all penalties as well as a weighted distance + for each individual penalty. + """ + def __init__(self): + self._penalties = {} + + @LazyClassProperty + def _weights(cls): + """A dictionary from keys to floating-point weights. + """ + weights_view = config['match']['distance_weights'] + weights = {} + for key in weights_view.keys(): + weights[key] = weights_view[key].as_number() + return weights + + # Access the components and their aggregates. + + @property + def distance(self): + """Return a weighted and normalized distance across all + penalties. + """ + dist_max = self.max_distance + if dist_max: + return self.raw_distance / self.max_distance + return 0.0 + + @property + def max_distance(self): + """Return the maximum distance penalty (normalization factor). + """ + dist_max = 0.0 + for key, penalty in self._penalties.iteritems(): + dist_max += len(penalty) * self._weights[key] + return dist_max + + @property + def raw_distance(self): + """Return the raw (denormalized) distance. + """ + dist_raw = 0.0 + for key, penalty in self._penalties.iteritems(): + dist_raw += sum(penalty) * self._weights[key] + return dist_raw + + def items(self): + """Return a list of (key, dist) pairs, with `dist` being the + weighted distance, sorted from highest to lowest. Does not + include penalties with a zero value. + """ + list_ = [] + for key in self._penalties: + dist = self[key] + if dist: + list_.append((key, dist)) + # Convert distance into a negative float we can sort items in + # ascending order (for keys, when the penalty is equal) and + # still get the items with the biggest distance first. + return sorted(list_, key=lambda (key, dist): (0 - dist, key)) + + # Behave like a float. + + def __cmp__(self, other): + return cmp(self.distance, other) + + def __float__(self): + return self.distance + + def __sub__(self, other): + return self.distance - other + + def __rsub__(self, other): + return other - self.distance + + # Behave like a dict. + + def __getitem__(self, key): + """Returns the weighted distance for a named penalty. + """ + dist = sum(self._penalties[key]) * self._weights[key] + dist_max = self.max_distance + if dist_max: + return dist / dist_max + return 0.0 + + def __iter__(self): + return iter(self.items()) + + def __len__(self): + return len(self.items()) + + def keys(self): + return [key for key, _ in self.items()] + + def update(self, dist): + """Adds all the distance penalties from `dist`. + """ + if not isinstance(dist, Distance): + raise ValueError( + '`dist` must be a Distance object, not {0}'.format(type(dist)) + ) + for key, penalties in dist._penalties.iteritems(): + self._penalties.setdefault(key, []).extend(penalties) + + # Adding components. + + def _eq(self, value1, value2): + """Returns True if `value1` is equal to `value2`. `value1` may + be a compiled regular expression, in which case it will be + matched against `value2`. + """ + if isinstance(value1, re._pattern_type): + return bool(value1.match(value2)) + return value1 == value2 + + def add(self, key, dist): + """Adds a distance penalty. `key` must correspond with a + configured weight setting. `dist` must be a float between 0.0 + and 1.0, and will be added to any existing distance penalties + for the same key. + """ + if not 0.0 <= dist <= 1.0: + raise ValueError( + '`dist` must be between 0.0 and 1.0, not {0}'.format(dist) + ) + self._penalties.setdefault(key, []).append(dist) + + def add_equality(self, key, value, options): + """Adds a distance penalty of 1.0 if `value` doesn't match any + of the values in `options`. If an option is a compiled regular + expression, it will be considered equal if it matches against + `value`. + """ + if not isinstance(options, (list, tuple)): + options = [options] + for opt in options: + if self._eq(opt, value): + dist = 0.0 + break + else: + dist = 1.0 + self.add(key, dist) + + def add_expr(self, key, expr): + """Adds a distance penalty of 1.0 if `expr` evaluates to True, + or 0.0. + """ + if expr: + self.add(key, 1.0) + else: + self.add(key, 0.0) + + def add_number(self, key, number1, number2): + """Adds a distance penalty of 1.0 for each number of difference + between `number1` and `number2`, or 0.0 when there is no + difference. Use this when there is no upper limit on the + difference between the two numbers. + """ + diff = abs(number1 - number2) + if diff: + for i in range(diff): + self.add(key, 1.0) + else: + self.add(key, 0.0) + + def add_priority(self, key, value, options): + """Adds a distance penalty that corresponds to the position at + which `value` appears in `options`. A distance penalty of 0.0 + for the first option, or 1.0 if there is no matching option. If + an option is a compiled regular expression, it will be + considered equal if it matches against `value`. + """ + if not isinstance(options, (list, tuple)): + options = [options] + unit = 1.0 / (len(options) or 1) + for i, opt in enumerate(options): + if self._eq(opt, value): + dist = i * unit + break + else: + dist = 1.0 + self.add(key, dist) + + def add_ratio(self, key, number1, number2): + """Adds a distance penalty for `number1` as a ratio of `number2`. + `number1` is bound at 0 and `number2`. + """ + number = float(max(min(number1, number2), 0)) + if number2: + dist = number / number2 + else: + dist = 0.0 + self.add(key, dist) + + def add_string(self, key, str1, str2): + """Adds a distance penalty based on the edit distance between + `str1` and `str2`. + """ + dist = string_dist(str1, str2) + self.add(key, dist) + + +# Structures that compose all the information for a candidate match. + +AlbumMatch = namedtuple('AlbumMatch', ['distance', 'info', 'mapping', + 'extra_items', 'extra_tracks']) + +TrackMatch = namedtuple('TrackMatch', ['distance', 'info']) + + +# Aggregation of sources. + +def album_for_mbid(release_id): + """Get an AlbumInfo object for a MusicBrainz release ID. Return None + if the ID is not found. + """ + try: + return mb.album_for_id(release_id) + except mb.MusicBrainzAPIError as exc: + exc.log(log) + + +def track_for_mbid(recording_id): + """Get a TrackInfo object for a MusicBrainz recording ID. Return None + if the ID is not found. + """ + try: + return mb.track_for_id(recording_id) + except mb.MusicBrainzAPIError as exc: + exc.log(log) + + +def albums_for_id(album_id): + """Get a list of albums for an ID.""" + candidates = [album_for_mbid(album_id)] + candidates.extend(plugins.album_for_id(album_id)) + return filter(None, candidates) + + +def tracks_for_id(track_id): + """Get a list of tracks for an ID.""" + candidates = [track_for_mbid(track_id)] + candidates.extend(plugins.track_for_id(track_id)) + return filter(None, candidates) + + +def album_candidates(items, artist, album, va_likely): + """Search for album matches. ``items`` is a list of Item objects + that make up the album. ``artist`` and ``album`` are the respective + names (strings), which may be derived from the item list or may be + entered by the user. ``va_likely`` is a boolean indicating whether + the album is likely to be a "various artists" release. + """ + out = [] + + # Base candidates if we have album and artist to match. + if artist and album: + try: + out.extend(mb.match_album(artist, album, len(items))) + except mb.MusicBrainzAPIError as exc: + exc.log(log) + + # Also add VA matches from MusicBrainz where appropriate. + if va_likely and album: + try: + out.extend(mb.match_album(None, album, len(items))) + except mb.MusicBrainzAPIError as exc: + exc.log(log) + + # Candidates from plugins. + out.extend(plugins.candidates(items, artist, album, va_likely)) + + return out + + +def item_candidates(item, artist, title): + """Search for item matches. ``item`` is the Item to be matched. + ``artist`` and ``title`` are strings and either reflect the item or + are specified by the user. + """ + out = [] + + # MusicBrainz candidates. + if artist and title: + try: + out.extend(mb.match_track(artist, title)) + except mb.MusicBrainzAPIError as exc: + exc.log(log) + + # Plugin candidates. + out.extend(plugins.item_candidates(item, artist, title)) + + return out diff --git a/lib/beets/autotag/match.py b/lib/beets/autotag/match.py new file mode 100644 index 00000000..2d1f2007 --- /dev/null +++ b/lib/beets/autotag/match.py @@ -0,0 +1,492 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Matches existing metadata with canonical information to identify +releases and tracks. +""" +from __future__ import division + +import datetime +import logging +import re +from munkres import Munkres + +from beets import plugins +from beets import config +from beets.util import plurality +from beets.autotag import hooks +from beets.util.enumeration import OrderedEnum + +# Artist signals that indicate "various artists". These are used at the +# album level to determine whether a given release is likely a VA +# release and also on the track level to to remove the penalty for +# differing artists. +VA_ARTISTS = (u'', u'various artists', u'various', u'va', u'unknown') + +# Global logger. +log = logging.getLogger('beets') + + +# Recommendation enumeration. + +class Recommendation(OrderedEnum): + """Indicates a qualitative suggestion to the user about what should + be done with a given match. + """ + none = 0 + low = 1 + medium = 2 + strong = 3 + + +# Primary matching functionality. + +def current_metadata(items): + """Extract the likely current metadata for an album given a list of its + items. Return two dictionaries: + - The most common value for each field. + - Whether each field's value was unanimous (values are booleans). + """ + assert items # Must be nonempty. + + likelies = {} + consensus = {} + fields = ['artist', 'album', 'albumartist', 'year', 'disctotal', + 'mb_albumid', 'label', 'catalognum', 'country', 'media', + 'albumdisambig'] + for field in fields: + values = [item[field] for item in items if item] + likelies[field], freq = plurality(values) + consensus[field] = (freq == len(values)) + + # If there's an album artist consensus, use this for the artist. + if consensus['albumartist'] and likelies['albumartist']: + likelies['artist'] = likelies['albumartist'] + + return likelies, consensus + + +def assign_items(items, tracks): + """Given a list of Items and a list of TrackInfo objects, find the + best mapping between them. Returns a mapping from Items to TrackInfo + objects, a set of extra Items, and a set of extra TrackInfo + objects. These "extra" objects occur when there is an unequal number + of objects of the two types. + """ + # Construct the cost matrix. + costs = [] + for item in items: + row = [] + for i, track in enumerate(tracks): + row.append(track_distance(item, track)) + costs.append(row) + + # Find a minimum-cost bipartite matching. + matching = Munkres().compute(costs) + + # Produce the output matching. + mapping = dict((items[i], tracks[j]) for (i, j) in matching) + extra_items = list(set(items) - set(mapping.keys())) + extra_items.sort(key=lambda i: (i.disc, i.track, i.title)) + extra_tracks = list(set(tracks) - set(mapping.values())) + extra_tracks.sort(key=lambda t: (t.index, t.title)) + return mapping, extra_items, extra_tracks + + +def track_index_changed(item, track_info): + """Returns True if the item and track info index is different. Tolerates + per disc and per release numbering. + """ + return item.track not in (track_info.medium_index, track_info.index) + + +def track_distance(item, track_info, incl_artist=False): + """Determines the significance of a track metadata change. Returns a + Distance object. `incl_artist` indicates that a distance component should + be included for the track artist (i.e., for various-artist releases). + """ + dist = hooks.Distance() + + # Length. + if track_info.length: + diff = abs(item.length - track_info.length) - \ + config['match']['track_length_grace'].as_number() + dist.add_ratio('track_length', diff, + config['match']['track_length_max'].as_number()) + + # Title. + dist.add_string('track_title', item.title, track_info.title) + + # Artist. Only check if there is actually an artist in the track data. + if incl_artist and track_info.artist and \ + item.artist.lower() not in VA_ARTISTS: + dist.add_string('track_artist', item.artist, track_info.artist) + + # Track index. + if track_info.index and item.track: + dist.add_expr('track_index', track_index_changed(item, track_info)) + + # Track ID. + if item.mb_trackid: + dist.add_expr('track_id', item.mb_trackid != track_info.track_id) + + # Plugins. + dist.update(plugins.track_distance(item, track_info)) + + return dist + + +def distance(items, album_info, mapping): + """Determines how "significant" an album metadata change would be. + Returns a Distance object. `album_info` is an AlbumInfo object + reflecting the album to be compared. `items` is a sequence of all + Item objects that will be matched (order is not important). + `mapping` is a dictionary mapping Items to TrackInfo objects; the + keys are a subset of `items` and the values are a subset of + `album_info.tracks`. + """ + likelies, _ = current_metadata(items) + + dist = hooks.Distance() + + # Artist, if not various. + if not album_info.va: + dist.add_string('artist', likelies['artist'], album_info.artist) + + # Album. + dist.add_string('album', likelies['album'], album_info.album) + + # Current or preferred media. + if album_info.media: + # Preferred media options. + patterns = config['match']['preferred']['media'].as_str_seq() + options = [re.compile(r'(\d+x)?(%s)' % pat, re.I) for pat in patterns] + if options: + dist.add_priority('media', album_info.media, options) + # Current media. + elif likelies['media']: + dist.add_equality('media', album_info.media, likelies['media']) + + # Mediums. + if likelies['disctotal'] and album_info.mediums: + dist.add_number('mediums', likelies['disctotal'], album_info.mediums) + + # Prefer earliest release. + if album_info.year and config['match']['preferred']['original_year']: + # Assume 1889 (earliest first gramophone discs) if we don't know the + # original year. + original = album_info.original_year or 1889 + diff = abs(album_info.year - original) + diff_max = abs(datetime.date.today().year - original) + dist.add_ratio('year', diff, diff_max) + # Year. + elif likelies['year'] and album_info.year: + if likelies['year'] in (album_info.year, album_info.original_year): + # No penalty for matching release or original year. + dist.add('year', 0.0) + elif album_info.original_year: + # Prefer matchest closest to the release year. + diff = abs(likelies['year'] - album_info.year) + diff_max = abs(datetime.date.today().year - + album_info.original_year) + dist.add_ratio('year', diff, diff_max) + else: + # Full penalty when there is no original year. + dist.add('year', 1.0) + + # Preferred countries. + patterns = config['match']['preferred']['countries'].as_str_seq() + options = [re.compile(pat, re.I) for pat in patterns] + if album_info.country and options: + dist.add_priority('country', album_info.country, options) + # Country. + elif likelies['country'] and album_info.country: + dist.add_string('country', likelies['country'], album_info.country) + + # Label. + if likelies['label'] and album_info.label: + dist.add_string('label', likelies['label'], album_info.label) + + # Catalog number. + if likelies['catalognum'] and album_info.catalognum: + dist.add_string('catalognum', likelies['catalognum'], + album_info.catalognum) + + # Disambiguation. + if likelies['albumdisambig'] and album_info.albumdisambig: + dist.add_string('albumdisambig', likelies['albumdisambig'], + album_info.albumdisambig) + + # Album ID. + if likelies['mb_albumid']: + dist.add_equality('album_id', likelies['mb_albumid'], + album_info.album_id) + + # Tracks. + dist.tracks = {} + for item, track in mapping.iteritems(): + dist.tracks[track] = track_distance(item, track, album_info.va) + dist.add('tracks', dist.tracks[track].distance) + + # Missing tracks. + for i in range(len(album_info.tracks) - len(mapping)): + dist.add('missing_tracks', 1.0) + + # Unmatched tracks. + for i in range(len(items) - len(mapping)): + dist.add('unmatched_tracks', 1.0) + + # Plugins. + dist.update(plugins.album_distance(items, album_info, mapping)) + + return dist + + +def match_by_id(items): + """If the items are tagged with a MusicBrainz album ID, returns an + AlbumInfo object for the corresponding album. Otherwise, returns + None. + """ + # Is there a consensus on the MB album ID? + albumids = [item.mb_albumid for item in items if item.mb_albumid] + if not albumids: + log.debug(u'No album IDs found.') + return None + + # If all album IDs are equal, look up the album. + if bool(reduce(lambda x, y: x if x == y else (), albumids)): + albumid = albumids[0] + log.debug(u'Searching for discovered album ID: {0}'.format(albumid)) + return hooks.album_for_mbid(albumid) + else: + log.debug(u'No album ID consensus.') + + +def _recommendation(results): + """Given a sorted list of AlbumMatch or TrackMatch objects, return a + recommendation based on the results' distances. + + If the recommendation is higher than the configured maximum for + an applied penalty, the recommendation will be downgraded to the + configured maximum for that penalty. + """ + if not results: + # No candidates: no recommendation. + return Recommendation.none + + # Basic distance thresholding. + min_dist = results[0].distance + if min_dist < config['match']['strong_rec_thresh'].as_number(): + # Strong recommendation level. + rec = Recommendation.strong + elif min_dist <= config['match']['medium_rec_thresh'].as_number(): + # Medium recommendation level. + rec = Recommendation.medium + elif len(results) == 1: + # Only a single candidate. + rec = Recommendation.low + elif results[1].distance - min_dist >= \ + config['match']['rec_gap_thresh'].as_number(): + # Gap between first two candidates is large. + rec = Recommendation.low + else: + # No conclusion. Return immediately. Can't be downgraded any further. + return Recommendation.none + + # Downgrade to the max rec if it is lower than the current rec for an + # applied penalty. + keys = set(min_dist.keys()) + if isinstance(results[0], hooks.AlbumMatch): + for track_dist in min_dist.tracks.values(): + keys.update(track_dist.keys()) + max_rec_view = config['match']['max_rec'] + for key in keys: + if key in max_rec_view.keys(): + max_rec = max_rec_view[key].as_choice({ + 'strong': Recommendation.strong, + 'medium': Recommendation.medium, + 'low': Recommendation.low, + 'none': Recommendation.none, + }) + rec = min(rec, max_rec) + + return rec + + +def _add_candidate(items, results, info): + """Given a candidate AlbumInfo object, attempt to add the candidate + to the output dictionary of AlbumMatch objects. This involves + checking the track count, ordering the items, checking for + duplicates, and calculating the distance. + """ + log.debug(u'Candidate: {0} - {1}'.format(info.artist, info.album)) + + # Discard albums with zero tracks. + if not info.tracks: + log.debug('No tracks.') + return + + # Don't duplicate. + if info.album_id in results: + log.debug(u'Duplicate.') + return + + # Discard matches without required tags. + for req_tag in config['match']['required'].as_str_seq(): + if getattr(info, req_tag) is None: + log.debug(u'Ignored. Missing required tag: {0}'.format(req_tag)) + return + + # Find mapping between the items and the track info. + mapping, extra_items, extra_tracks = assign_items(items, info.tracks) + + # Get the change distance. + dist = distance(items, info, mapping) + + # Skip matches with ignored penalties. + penalties = [key for _, key in dist] + for penalty in config['match']['ignored'].as_str_seq(): + if penalty in penalties: + log.debug(u'Ignored. Penalty: {0}'.format(penalty)) + return + + log.debug(u'Success. Distance: {0}'.format(dist)) + results[info.album_id] = hooks.AlbumMatch(dist, info, mapping, + extra_items, extra_tracks) + + +def tag_album(items, search_artist=None, search_album=None, + search_id=None): + """Return a tuple of a artist name, an album name, a list of + `AlbumMatch` candidates from the metadata backend, and a + `Recommendation`. + + The artist and album are the most common values of these fields + among `items`. + + The `AlbumMatch` objects are generated by searching the metadata + backends. By default, the metadata of the items is used for the + search. This can be customized by setting the parameters. The + `mapping` field of the album has the matched `items` as keys. + + The recommendation is calculated from the match qualitiy of the + candidates. + """ + # Get current metadata. + likelies, consensus = current_metadata(items) + cur_artist = likelies['artist'] + cur_album = likelies['album'] + log.debug(u'Tagging {0} - {1}'.format(cur_artist, cur_album)) + + # The output result (distance, AlbumInfo) tuples (keyed by MB album + # ID). + candidates = {} + + # Search by explicit ID. + if search_id is not None: + log.debug(u'Searching for album ID: {0}'.format(search_id)) + search_cands = hooks.albums_for_id(search_id) + + # Use existing metadata or text search. + else: + # Try search based on current ID. + id_info = match_by_id(items) + if id_info: + _add_candidate(items, candidates, id_info) + rec = _recommendation(candidates.values()) + log.debug(u'Album ID match recommendation is {0}'.format(str(rec))) + if candidates and not config['import']['timid']: + # If we have a very good MBID match, return immediately. + # Otherwise, this match will compete against metadata-based + # matches. + if rec == Recommendation.strong: + log.debug(u'ID match.') + return cur_artist, cur_album, candidates.values(), rec + + # Search terms. + if not (search_artist and search_album): + # No explicit search terms -- use current metadata. + search_artist, search_album = cur_artist, cur_album + log.debug(u'Search terms: {0} - {1}'.format(search_artist, + search_album)) + + # Is this album likely to be a "various artist" release? + va_likely = ((not consensus['artist']) or + (search_artist.lower() in VA_ARTISTS) or + any(item.comp for item in items)) + log.debug(u'Album might be VA: {0}'.format(str(va_likely))) + + # Get the results from the data sources. + search_cands = hooks.album_candidates(items, search_artist, + search_album, va_likely) + + log.debug(u'Evaluating {0} candidates.'.format(len(search_cands))) + for info in search_cands: + _add_candidate(items, candidates, info) + + # Sort and get the recommendation. + candidates = sorted(candidates.itervalues()) + rec = _recommendation(candidates) + return cur_artist, cur_album, candidates, rec + + +def tag_item(item, search_artist=None, search_title=None, + search_id=None): + """Attempts to find metadata for a single track. Returns a + `(candidates, recommendation)` pair where `candidates` is a list of + TrackMatch objects. `search_artist` and `search_title` may be used + to override the current metadata for the purposes of the MusicBrainz + title; likewise `search_id`. + """ + # Holds candidates found so far: keys are MBIDs; values are + # (distance, TrackInfo) pairs. + candidates = {} + + # First, try matching by MusicBrainz ID. + trackid = search_id or item.mb_trackid + if trackid: + log.debug(u'Searching for track ID: {0}'.format(trackid)) + for track_info in hooks.tracks_for_id(trackid): + dist = track_distance(item, track_info, incl_artist=True) + candidates[track_info.track_id] = \ + hooks.TrackMatch(dist, track_info) + # If this is a good match, then don't keep searching. + rec = _recommendation(candidates.values()) + if rec == Recommendation.strong and not config['import']['timid']: + log.debug(u'Track ID match.') + return candidates.values(), rec + + # If we're searching by ID, don't proceed. + if search_id is not None: + if candidates: + return candidates.values(), rec + else: + return [], Recommendation.none + + # Search terms. + if not (search_artist and search_title): + search_artist, search_title = item.artist, item.title + log.debug(u'Item search terms: {0} - {1}'.format(search_artist, + search_title)) + + # Get and evaluate candidate metadata. + for track_info in hooks.item_candidates(item, search_artist, search_title): + dist = track_distance(item, track_info, incl_artist=True) + candidates[track_info.track_id] = hooks.TrackMatch(dist, track_info) + + # Sort by distance and return with recommendation. + log.debug(u'Found {0} candidates.'.format(len(candidates))) + candidates = sorted(candidates.itervalues()) + rec = _recommendation(candidates) + return candidates, rec diff --git a/lib/beets/autotag/mb.py b/lib/beets/autotag/mb.py new file mode 100644 index 00000000..d063f627 --- /dev/null +++ b/lib/beets/autotag/mb.py @@ -0,0 +1,407 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Searches for albums in the MusicBrainz database. +""" +import logging +import musicbrainzngs +import re +import traceback +from urlparse import urljoin + +import beets.autotag.hooks +import beets +from beets import util +from beets import config + +SEARCH_LIMIT = 5 +VARIOUS_ARTISTS_ID = '89ad4ac3-39f7-470e-963a-56509c546377' +BASE_URL = 'http://musicbrainz.org/' + +musicbrainzngs.set_useragent('beets', beets.__version__, + 'http://beets.radbox.org/') + + +class MusicBrainzAPIError(util.HumanReadableException): + """An error while talking to MusicBrainz. The `query` field is the + parameter to the action and may have any type. + """ + def __init__(self, reason, verb, query, tb=None): + self.query = query + super(MusicBrainzAPIError, self).__init__(reason, verb, tb) + + def get_message(self): + return u'{0} in {1} with query {2}'.format( + self._reasonstr(), self.verb, repr(self.query) + ) + +log = logging.getLogger('beets') + +RELEASE_INCLUDES = ['artists', 'media', 'recordings', 'release-groups', + 'labels', 'artist-credits', 'aliases'] +TRACK_INCLUDES = ['artists', 'aliases'] + + +def track_url(trackid): + return urljoin(BASE_URL, 'recording/' + trackid) + + +def album_url(albumid): + return urljoin(BASE_URL, 'release/' + albumid) + + +def configure(): + """Set up the python-musicbrainz-ngs module according to settings + from the beets configuration. This should be called at startup. + """ + musicbrainzngs.set_hostname(config['musicbrainz']['host'].get(unicode)) + musicbrainzngs.set_rate_limit( + config['musicbrainz']['ratelimit_interval'].as_number(), + config['musicbrainz']['ratelimit'].get(int), + ) + + +def _preferred_alias(aliases): + """Given an list of alias structures for an artist credit, select + and return the user's preferred alias alias or None if no matching + alias is found. + """ + if not aliases: + return + + # Only consider aliases that have locales set. + aliases = [a for a in aliases if 'locale' in a] + + # Search configured locales in order. + for locale in config['import']['languages'].as_str_seq(): + # Find matching primary aliases for this locale. + matches = [a for a in aliases + if a['locale'] == locale and 'primary' in a] + # Skip to the next locale if we have no matches + if not matches: + continue + + return matches[0] + + +def _flatten_artist_credit(credit): + """Given a list representing an ``artist-credit`` block, flatten the + data into a triple of joined artist name strings: canonical, sort, and + credit. + """ + artist_parts = [] + artist_sort_parts = [] + artist_credit_parts = [] + for el in credit: + if isinstance(el, basestring): + # Join phrase. + artist_parts.append(el) + artist_credit_parts.append(el) + artist_sort_parts.append(el) + + else: + alias = _preferred_alias(el['artist'].get('alias-list', ())) + + # An artist. + if alias: + cur_artist_name = alias['alias'] + else: + cur_artist_name = el['artist']['name'] + artist_parts.append(cur_artist_name) + + # Artist sort name. + if alias: + artist_sort_parts.append(alias['sort-name']) + elif 'sort-name' in el['artist']: + artist_sort_parts.append(el['artist']['sort-name']) + else: + artist_sort_parts.append(cur_artist_name) + + # Artist credit. + if 'name' in el: + artist_credit_parts.append(el['name']) + else: + artist_credit_parts.append(cur_artist_name) + + return ( + ''.join(artist_parts), + ''.join(artist_sort_parts), + ''.join(artist_credit_parts), + ) + + +def track_info(recording, index=None, medium=None, medium_index=None, + medium_total=None): + """Translates a MusicBrainz recording result dictionary into a beets + ``TrackInfo`` object. Three parameters are optional and are used + only for tracks that appear on releases (non-singletons): ``index``, + the overall track number; ``medium``, the disc number; + ``medium_index``, the track's index on its medium; ``medium_total``, + the number of tracks on the medium. Each number is a 1-based index. + """ + info = beets.autotag.hooks.TrackInfo( + recording['title'], + recording['id'], + index=index, + medium=medium, + medium_index=medium_index, + medium_total=medium_total, + data_url=track_url(recording['id']), + ) + + if recording.get('artist-credit'): + # Get the artist names. + info.artist, info.artist_sort, info.artist_credit = \ + _flatten_artist_credit(recording['artist-credit']) + + # Get the ID and sort name of the first artist. + artist = recording['artist-credit'][0]['artist'] + info.artist_id = artist['id'] + + if recording.get('length'): + info.length = int(recording['length']) / (1000.0) + + info.decode() + return info + + +def _set_date_str(info, date_str, original=False): + """Given a (possibly partial) YYYY-MM-DD string and an AlbumInfo + object, set the object's release date fields appropriately. If + `original`, then set the original_year, etc., fields. + """ + if date_str: + date_parts = date_str.split('-') + for key in ('year', 'month', 'day'): + if date_parts: + date_part = date_parts.pop(0) + try: + date_num = int(date_part) + except ValueError: + continue + + if original: + key = 'original_' + key + setattr(info, key, date_num) + + +def album_info(release): + """Takes a MusicBrainz release result dictionary and returns a beets + AlbumInfo object containing the interesting data about that release. + """ + # Get artist name using join phrases. + artist_name, artist_sort_name, artist_credit_name = \ + _flatten_artist_credit(release['artist-credit']) + + # Basic info. + track_infos = [] + index = 0 + for medium in release['medium-list']: + disctitle = medium.get('title') + format = medium.get('format') + for track in medium['track-list']: + # Basic information from the recording. + index += 1 + ti = track_info( + track['recording'], + index, + int(medium['position']), + int(track['position']), + len(medium['track-list']), + ) + ti.disctitle = disctitle + ti.media = format + + # Prefer track data, where present, over recording data. + if track.get('title'): + ti.title = track['title'] + if track.get('artist-credit'): + # Get the artist names. + ti.artist, ti.artist_sort, ti.artist_credit = \ + _flatten_artist_credit(track['artist-credit']) + ti.artist_id = track['artist-credit'][0]['artist']['id'] + if track.get('length'): + ti.length = int(track['length']) / (1000.0) + + track_infos.append(ti) + + info = beets.autotag.hooks.AlbumInfo( + release['title'], + release['id'], + artist_name, + release['artist-credit'][0]['artist']['id'], + track_infos, + mediums=len(release['medium-list']), + artist_sort=artist_sort_name, + artist_credit=artist_credit_name, + data_source='MusicBrainz', + data_url=album_url(release['id']), + ) + info.va = info.artist_id == VARIOUS_ARTISTS_ID + info.asin = release.get('asin') + info.releasegroup_id = release['release-group']['id'] + info.country = release.get('country') + info.albumstatus = release.get('status') + + # Build up the disambiguation string from the release group and release. + disambig = [] + if release['release-group'].get('disambiguation'): + disambig.append(release['release-group'].get('disambiguation')) + if release.get('disambiguation'): + disambig.append(release.get('disambiguation')) + info.albumdisambig = u', '.join(disambig) + + # Release type not always populated. + if 'type' in release['release-group']: + reltype = release['release-group']['type'] + if reltype: + info.albumtype = reltype.lower() + + # Release dates. + release_date = release.get('date') + release_group_date = release['release-group'].get('first-release-date') + if not release_date: + # Fall back if release-specific date is not available. + release_date = release_group_date + _set_date_str(info, release_date, False) + _set_date_str(info, release_group_date, True) + + # Label name. + if release.get('label-info-list'): + label_info = release['label-info-list'][0] + if label_info.get('label'): + label = label_info['label']['name'] + if label != '[no label]': + info.label = label + info.catalognum = label_info.get('catalog-number') + + # Text representation data. + if release.get('text-representation'): + rep = release['text-representation'] + info.script = rep.get('script') + info.language = rep.get('language') + + # Media (format). + if release['medium-list']: + first_medium = release['medium-list'][0] + info.media = first_medium.get('format') + + info.decode() + return info + + +def match_album(artist, album, tracks=None, limit=SEARCH_LIMIT): + """Searches for a single album ("release" in MusicBrainz parlance) + and returns an iterator over AlbumInfo objects. May raise a + MusicBrainzAPIError. + + The query consists of an artist name, an album name, and, + optionally, a number of tracks on the album. + """ + # Build search criteria. + criteria = {'release': album.lower().strip()} + if artist is not None: + criteria['artist'] = artist.lower().strip() + else: + # Various Artists search. + criteria['arid'] = VARIOUS_ARTISTS_ID + if tracks is not None: + criteria['tracks'] = str(tracks) + + # Abort if we have no search terms. + if not any(criteria.itervalues()): + return + + try: + res = musicbrainzngs.search_releases(limit=limit, **criteria) + except musicbrainzngs.MusicBrainzError as exc: + raise MusicBrainzAPIError(exc, 'release search', criteria, + traceback.format_exc()) + for release in res['release-list']: + # The search result is missing some data (namely, the tracks), + # so we just use the ID and fetch the rest of the information. + albuminfo = album_for_id(release['id']) + if albuminfo is not None: + yield albuminfo + + +def match_track(artist, title, limit=SEARCH_LIMIT): + """Searches for a single track and returns an iterable of TrackInfo + objects. May raise a MusicBrainzAPIError. + """ + criteria = { + 'artist': artist.lower().strip(), + 'recording': title.lower().strip(), + } + + if not any(criteria.itervalues()): + return + + try: + res = musicbrainzngs.search_recordings(limit=limit, **criteria) + except musicbrainzngs.MusicBrainzError as exc: + raise MusicBrainzAPIError(exc, 'recording search', criteria, + traceback.format_exc()) + for recording in res['recording-list']: + yield track_info(recording) + + +def _parse_id(s): + """Search for a MusicBrainz ID in the given string and return it. If + no ID can be found, return None. + """ + # Find the first thing that looks like a UUID/MBID. + match = re.search('[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}', s) + if match: + return match.group() + + +def album_for_id(releaseid): + """Fetches an album by its MusicBrainz ID and returns an AlbumInfo + object or None if the album is not found. May raise a + MusicBrainzAPIError. + """ + albumid = _parse_id(releaseid) + if not albumid: + log.debug(u'Invalid MBID ({0}).'.format(releaseid)) + return + try: + res = musicbrainzngs.get_release_by_id(albumid, + RELEASE_INCLUDES) + except musicbrainzngs.ResponseError: + log.debug(u'Album ID match failed.') + return None + except musicbrainzngs.MusicBrainzError as exc: + raise MusicBrainzAPIError(exc, 'get release by ID', albumid, + traceback.format_exc()) + return album_info(res['release']) + + +def track_for_id(releaseid): + """Fetches a track by its MusicBrainz ID. Returns a TrackInfo object + or None if no track is found. May raise a MusicBrainzAPIError. + """ + trackid = _parse_id(releaseid) + if not trackid: + log.debug(u'Invalid MBID ({0}).'.format(releaseid)) + return + try: + res = musicbrainzngs.get_recording_by_id(trackid, TRACK_INCLUDES) + except musicbrainzngs.ResponseError: + log.debug(u'Track ID match failed.') + return None + except musicbrainzngs.MusicBrainzError as exc: + raise MusicBrainzAPIError(exc, 'get recording by ID', trackid, + traceback.format_exc()) + return track_info(res['recording']) diff --git a/lib/beets/config_default.yaml b/lib/beets/config_default.yaml new file mode 100644 index 00000000..78f16d05 --- /dev/null +++ b/lib/beets/config_default.yaml @@ -0,0 +1,109 @@ +library: library.db +directory: ~/Music + +import: + write: yes + copy: yes + move: no + link: no + delete: no + resume: ask + incremental: no + quiet_fallback: skip + none_rec_action: ask + timid: no + log: + autotag: yes + quiet: no + singletons: no + default_action: apply + languages: [] + detail: no + flat: no + group_albums: no + pretend: false + +clutter: ["Thumbs.DB", ".DS_Store"] +ignore: [".*", "*~", "System Volume Information"] +replace: + '[\\/]': _ + '^\.': _ + '[\x00-\x1f]': _ + '[<>:"\?\*\|]': _ + '\.$': _ + '\s+$': '' + '^\s+': '' +path_sep_replace: _ +asciify_paths: false +art_filename: cover +max_filename_length: 0 + +plugins: [] +pluginpath: [] +threaded: yes +color: yes +timeout: 5.0 +per_disc_numbering: no +verbose: no +terminal_encoding: utf8 +original_date: no +id3v23: no + +ui: + terminal_width: 80 + length_diff_thresh: 10.0 + +list_format_item: $artist - $album - $title +list_format_album: $albumartist - $album +time_format: '%Y-%m-%d %H:%M:%S' + +sort_album: albumartist+ album+ +sort_item: artist+ album+ disc+ track+ + +paths: + default: $albumartist/$album%aunique{}/$track $title + singleton: Non-Album/$artist/$title + comp: Compilations/$album%aunique{}/$track $title + +statefile: state.pickle + +musicbrainz: + host: musicbrainz.org + ratelimit: 1 + ratelimit_interval: 1.0 + +match: + strong_rec_thresh: 0.04 + medium_rec_thresh: 0.25 + rec_gap_thresh: 0.25 + max_rec: + missing_tracks: medium + unmatched_tracks: medium + distance_weights: + source: 2.0 + artist: 3.0 + album: 3.0 + media: 1.0 + mediums: 1.0 + year: 1.0 + country: 0.5 + label: 0.5 + catalognum: 0.5 + albumdisambig: 0.5 + album_id: 5.0 + tracks: 2.0 + missing_tracks: 0.9 + unmatched_tracks: 0.6 + track_title: 3.0 + track_artist: 2.0 + track_index: 1.0 + track_length: 2.0 + track_id: 5.0 + preferred: + countries: [] + media: [] + original_year: no + ignored: [] + required: [] + track_length_grace: 10 + track_length_max: 30 diff --git a/lib/beets/dbcore/__init__.py b/lib/beets/dbcore/__init__.py new file mode 100644 index 00000000..c364fdfc --- /dev/null +++ b/lib/beets/dbcore/__init__.py @@ -0,0 +1,25 @@ +# This file is part of beets. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""DBCore is an abstract database package that forms the basis for beets' +Library. +""" +from .db import Model, Database +from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery +from .types import Type +from .queryparse import query_from_strings +from .queryparse import sort_from_strings +from .queryparse import parse_sorted_query + +# flake8: noqa diff --git a/lib/beets/dbcore/db.py b/lib/beets/dbcore/db.py new file mode 100644 index 00000000..0c786daa --- /dev/null +++ b/lib/beets/dbcore/db.py @@ -0,0 +1,820 @@ +# This file is part of beets. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""The central Model and Database constructs for DBCore. +""" +import time +import os +from collections import defaultdict +import threading +import sqlite3 +import contextlib +import collections + +import beets +from beets.util.functemplate import Template +from beets.dbcore import types +from .query import MatchQuery, NullSort, TrueQuery + + +class FormattedMapping(collections.Mapping): + """A `dict`-like formatted view of a model. + + The accessor `mapping[key]` returns the formated version of + `model[key]` as a unicode string. + + If `for_path` is true, all path separators in the formatted values + are replaced. + """ + + def __init__(self, model, for_path=False): + self.for_path = for_path + self.model = model + self.model_keys = model.keys(True) + + def __getitem__(self, key): + if key in self.model_keys: + return self._get_formatted(self.model, key) + else: + raise KeyError(key) + + def __iter__(self): + return iter(self.model_keys) + + def __len__(self): + return len(self.model_keys) + + def get(self, key, default=None): + if default is None: + default = self.model._type(key).format(None) + return super(FormattedMapping, self).get(key, default) + + def _get_formatted(self, model, key): + value = model._type(key).format(model.get(key)) + if isinstance(value, bytes): + value = value.decode('utf8', 'ignore') + + if self.for_path: + sep_repl = beets.config['path_sep_replace'].get(unicode) + for sep in (os.path.sep, os.path.altsep): + if sep: + value = value.replace(sep, sep_repl) + + return value + + +# Abstract base for model classes. + +class Model(object): + """An abstract object representing an object in the database. Model + objects act like dictionaries (i.e., the allow subscript access like + ``obj['field']``). The same field set is available via attribute + access as a shortcut (i.e., ``obj.field``). Three kinds of attributes are + available: + + * **Fixed attributes** come from a predetermined list of field + names. These fields correspond to SQLite table columns and are + thus fast to read, write, and query. + * **Flexible attributes** are free-form and do not need to be listed + ahead of time. + * **Computed attributes** are read-only fields computed by a getter + function provided by a plugin. + + Access to all three field types is uniform: ``obj.field`` works the + same regardless of whether ``field`` is fixed, flexible, or + computed. + + Model objects can optionally be associated with a `Library` object, + in which case they can be loaded and stored from the database. Dirty + flags are used to track which fields need to be stored. + """ + + # Abstract components (to be provided by subclasses). + + _table = None + """The main SQLite table name. + """ + + _flex_table = None + """The flex field SQLite table name. + """ + + _fields = {} + """A mapping indicating available "fixed" fields on this type. The + keys are field names and the values are `Type` objects. + """ + + _search_fields = () + """The fields that should be queried by default by unqualified query + terms. + """ + + _types = {} + """Optional Types for non-fixed (i.e., flexible and computed) fields. + """ + + _sorts = {} + """Optional named sort criteria. The keys are strings and the values + are subclasses of `Sort`. + """ + + _always_dirty = False + """By default, fields only become "dirty" when their value actually + changes. Enabling this flag marks fields as dirty even when the new + value is the same as the old value (e.g., `o.f = o.f`). + """ + + @classmethod + def _getters(cls): + """Return a mapping from field names to getter functions. + """ + # We could cache this if it becomes a performance problem to + # gather the getter mapping every time. + raise NotImplementedError() + + def _template_funcs(self): + """Return a mapping from function names to text-transformer + functions. + """ + # As above: we could consider caching this result. + raise NotImplementedError() + + # Basic operation. + + def __init__(self, db=None, **values): + """Create a new object with an optional Database association and + initial field values. + """ + self._db = db + self._dirty = set() + self._values_fixed = {} + self._values_flex = {} + + # Initial contents. + self.update(values) + self.clear_dirty() + + @classmethod + def _awaken(cls, db=None, fixed_values={}, flex_values={}): + """Create an object with values drawn from the database. + + This is a performance optimization: the checks involved with + ordinary construction are bypassed. + """ + obj = cls(db) + for key, value in fixed_values.iteritems(): + obj._values_fixed[key] = cls._type(key).from_sql(value) + for key, value in flex_values.iteritems(): + obj._values_flex[key] = cls._type(key).from_sql(value) + return obj + + def __repr__(self): + return '{0}({1})'.format( + type(self).__name__, + ', '.join('{0}={1!r}'.format(k, v) for k, v in dict(self).items()), + ) + + def clear_dirty(self): + """Mark all fields as *clean* (i.e., not needing to be stored to + the database). + """ + self._dirty = set() + + def _check_db(self, need_id=True): + """Ensure that this object is associated with a database row: it + has a reference to a database (`_db`) and an id. A ValueError + exception is raised otherwise. + """ + if not self._db: + raise ValueError('{0} has no database'.format(type(self).__name__)) + if need_id and not self.id: + raise ValueError('{0} has no id'.format(type(self).__name__)) + + # Essential field accessors. + + @classmethod + def _type(self, key): + """Get the type of a field, a `Type` instance. + + If the field has no explicit type, it is given the base `Type`, + which does no conversion. + """ + return self._fields.get(key) or self._types.get(key) or types.DEFAULT + + def __getitem__(self, key): + """Get the value for a field. Raise a KeyError if the field is + not available. + """ + getters = self._getters() + if key in getters: # Computed. + return getters[key](self) + elif key in self._fields: # Fixed. + return self._values_fixed.get(key) + elif key in self._values_flex: # Flexible. + return self._values_flex[key] + else: + raise KeyError(key) + + def __setitem__(self, key, value): + """Assign the value for a field. + """ + # Choose where to place the value. + if key in self._fields: + source = self._values_fixed + else: + source = self._values_flex + + # If the field has a type, filter the value. + value = self._type(key).normalize(value) + + # Assign value and possibly mark as dirty. + old_value = source.get(key) + source[key] = value + if self._always_dirty or old_value != value: + self._dirty.add(key) + + def __delitem__(self, key): + """Remove a flexible attribute from the model. + """ + if key in self._values_flex: # Flexible. + del self._values_flex[key] + self._dirty.add(key) # Mark for dropping on store. + elif key in self._getters(): # Computed. + raise KeyError('computed field {0} cannot be deleted'.format(key)) + elif key in self._fields: # Fixed. + raise KeyError('fixed field {0} cannot be deleted'.format(key)) + else: + raise KeyError('no such field {0}'.format(key)) + + def keys(self, computed=False): + """Get a list of available field names for this object. The + `computed` parameter controls whether computed (plugin-provided) + fields are included in the key list. + """ + base_keys = list(self._fields) + self._values_flex.keys() + if computed: + return base_keys + self._getters().keys() + else: + return base_keys + + # Act like a dictionary. + + def update(self, values): + """Assign all values in the given dict. + """ + for key, value in values.items(): + self[key] = value + + def items(self): + """Iterate over (key, value) pairs that this object contains. + Computed fields are not included. + """ + for key in self: + yield key, self[key] + + def get(self, key, default=None): + """Get the value for a given key or `default` if it does not + exist. + """ + if key in self: + return self[key] + else: + return default + + def __contains__(self, key): + """Determine whether `key` is an attribute on this object. + """ + return key in self.keys(True) + + def __iter__(self): + """Iterate over the available field names (excluding computed + fields). + """ + return iter(self.keys()) + + # Convenient attribute access. + + def __getattr__(self, key): + if key.startswith('_'): + raise AttributeError('model has no attribute {0!r}'.format(key)) + else: + try: + return self[key] + except KeyError: + raise AttributeError('no such field {0!r}'.format(key)) + + def __setattr__(self, key, value): + if key.startswith('_'): + super(Model, self).__setattr__(key, value) + else: + self[key] = value + + def __delattr__(self, key): + if key.startswith('_'): + super(Model, self).__delattr__(key) + else: + del self[key] + + # Database interaction (CRUD methods). + + def store(self): + """Save the object's metadata into the library database. + """ + self._check_db() + + # Build assignments for query. + assignments = [] + subvars = [] + for key in self._fields: + if key != 'id' and key in self._dirty: + self._dirty.remove(key) + assignments.append(key + '=?') + value = self._type(key).to_sql(self[key]) + subvars.append(value) + assignments = ','.join(assignments) + + with self._db.transaction() as tx: + # Main table update. + if assignments: + query = 'UPDATE {0} SET {1} WHERE id=?'.format( + self._table, assignments + ) + subvars.append(self.id) + tx.mutate(query, subvars) + + # Modified/added flexible attributes. + for key, value in self._values_flex.items(): + if key in self._dirty: + self._dirty.remove(key) + tx.mutate( + 'INSERT INTO {0} ' + '(entity_id, key, value) ' + 'VALUES (?, ?, ?);'.format(self._flex_table), + (self.id, key, value), + ) + + # Deleted flexible attributes. + for key in self._dirty: + tx.mutate( + 'DELETE FROM {0} ' + 'WHERE entity_id=? AND key=?'.format(self._flex_table), + (self.id, key) + ) + + self.clear_dirty() + + def load(self): + """Refresh the object's metadata from the library database. + """ + self._check_db() + stored_obj = self._db._get(type(self), self.id) + assert stored_obj is not None, "object {0} not in DB".format(self.id) + self._values_fixed = {} + self._values_flex = {} + self.update(dict(stored_obj)) + self.clear_dirty() + + def remove(self): + """Remove the object's associated rows from the database. + """ + self._check_db() + with self._db.transaction() as tx: + tx.mutate( + 'DELETE FROM {0} WHERE id=?'.format(self._table), + (self.id,) + ) + tx.mutate( + 'DELETE FROM {0} WHERE entity_id=?'.format(self._flex_table), + (self.id,) + ) + + def add(self, db=None): + """Add the object to the library database. This object must be + associated with a database; you can provide one via the `db` + parameter or use the currently associated database. + + The object's `id` and `added` fields are set along with any + current field values. + """ + if db: + self._db = db + self._check_db(False) + + with self._db.transaction() as tx: + new_id = tx.mutate( + 'INSERT INTO {0} DEFAULT VALUES'.format(self._table) + ) + self.id = new_id + self.added = time.time() + + # Mark every non-null field as dirty and store. + for key in self: + if self[key] is not None: + self._dirty.add(key) + self.store() + + # Formatting and templating. + + _formatter = FormattedMapping + + def formatted(self, for_path=False): + """Get a mapping containing all values on this object formatted + as human-readable unicode strings. + """ + return self._formatter(self, for_path) + + def evaluate_template(self, template, for_path=False): + """Evaluate a template (a string or a `Template` object) using + the object's fields. If `for_path` is true, then no new path + separators will be added to the template. + """ + # Perform substitution. + if isinstance(template, basestring): + template = Template(template) + return template.substitute(self.formatted(for_path), + self._template_funcs()) + + # Parsing. + + @classmethod + def _parse(cls, key, string): + """Parse a string as a value for the given key. + """ + if not isinstance(string, basestring): + raise TypeError("_parse() argument must be a string") + + return cls._type(key).parse(string) + + +# Database controller and supporting interfaces. + +class Results(object): + """An item query result set. Iterating over the collection lazily + constructs LibModel objects that reflect database rows. + """ + def __init__(self, model_class, rows, db, query=None, sort=None): + """Create a result set that will construct objects of type + `model_class`. + + `model_class` is a subclass of `LibModel` that will be + constructed. `rows` is a query result: a list of mappings. The + new objects will be associated with the database `db`. + + If `query` is provided, it is used as a predicate to filter the + results for a "slow query" that cannot be evaluated by the + database directly. If `sort` is provided, it is used to sort the + full list of results before returning. This means it is a "slow + sort" and all objects must be built before returning the first + one. + """ + self.model_class = model_class + self.rows = rows + self.db = db + self.query = query + self.sort = sort + + # We keep a queue of rows we haven't yet consumed for + # materialization. We preserve the original total number of + # rows. + self._rows = rows + self._row_count = len(rows) + + # The materialized objects corresponding to rows that have been + # consumed. + self._objects = [] + + def _get_objects(self): + """Construct and generate Model objects for they query. The + objects are returned in the order emitted from the database; no + slow sort is applied. + + For performance, this generator caches materialized objects to + avoid constructing them more than once. This way, iterating over + a `Results` object a second time should be much faster than the + first. + """ + index = 0 # Position in the materialized objects. + while index < len(self._objects) or self._rows: + # Are there previously-materialized objects to produce? + if index < len(self._objects): + yield self._objects[index] + index += 1 + + # Otherwise, we consume another row, materialize its object + # and produce it. + else: + while self._rows: + row = self._rows.pop(0) + obj = self._make_model(row) + # If there is a slow-query predicate, ensurer that the + # object passes it. + if not self.query or self.query.match(obj): + self._objects.append(obj) + index += 1 + yield obj + break + + def __iter__(self): + """Construct and generate Model objects for all matching + objects, in sorted order. + """ + if self.sort: + # Slow sort. Must build the full list first. + objects = self.sort.sort(list(self._get_objects())) + return iter(objects) + + else: + # Objects are pre-sorted (i.e., by the database). + return self._get_objects() + + def _make_model(self, row): + # Get the flexible attributes for the object. + with self.db.transaction() as tx: + flex_rows = tx.query( + 'SELECT * FROM {0} WHERE entity_id=?'.format( + self.model_class._flex_table + ), + (row['id'],) + ) + + cols = dict(row) + values = dict((k, v) for (k, v) in cols.items() + if not k[:4] == 'flex') + flex_values = dict((row['key'], row['value']) for row in flex_rows) + + # Construct the Python object + obj = self.model_class._awaken(self.db, values, flex_values) + return obj + + def __len__(self): + """Get the number of matching objects. + """ + if not self._rows: + # Fully materialized. Just count the objects. + return len(self._objects) + + elif self.query: + # A slow query. Fall back to testing every object. + count = 0 + for obj in self: + count += 1 + return count + + else: + # A fast query. Just count the rows. + return self._row_count + + def __nonzero__(self): + """Does this result contain any objects? + """ + return bool(len(self)) + + def __getitem__(self, n): + """Get the nth item in this result set. This is inefficient: all + items up to n are materialized and thrown away. + """ + if not self._rows and not self.sort: + # Fully materialized and already in order. Just look up the + # object. + return self._objects[n] + + it = iter(self) + try: + for i in range(n): + it.next() + return it.next() + except StopIteration: + raise IndexError('result index {0} out of range'.format(n)) + + def get(self): + """Return the first matching object, or None if no objects + match. + """ + it = iter(self) + try: + return it.next() + except StopIteration: + return None + + +class Transaction(object): + """A context manager for safe, concurrent access to the database. + All SQL commands should be executed through a transaction. + """ + def __init__(self, db): + self.db = db + + def __enter__(self): + """Begin a transaction. This transaction may be created while + another is active in a different thread. + """ + with self.db._tx_stack() as stack: + first = not stack + stack.append(self) + if first: + # Beginning a "root" transaction, which corresponds to an + # SQLite transaction. + self.db._db_lock.acquire() + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Complete a transaction. This must be the most recently + entered but not yet exited transaction. If it is the last active + transaction, the database updates are committed. + """ + with self.db._tx_stack() as stack: + assert stack.pop() is self + empty = not stack + if empty: + # Ending a "root" transaction. End the SQLite transaction. + self.db._connection().commit() + self.db._db_lock.release() + + def query(self, statement, subvals=()): + """Execute an SQL statement with substitution values and return + a list of rows from the database. + """ + cursor = self.db._connection().execute(statement, subvals) + return cursor.fetchall() + + def mutate(self, statement, subvals=()): + """Execute an SQL statement with substitution values and return + the row ID of the last affected row. + """ + cursor = self.db._connection().execute(statement, subvals) + return cursor.lastrowid + + def script(self, statements): + """Execute a string containing multiple SQL statements.""" + self.db._connection().executescript(statements) + + +class Database(object): + """A container for Model objects that wraps an SQLite database as + the backend. + """ + _models = () + """The Model subclasses representing tables in this database. + """ + + def __init__(self, path): + self.path = path + + self._connections = {} + self._tx_stacks = defaultdict(list) + + # A lock to protect the _connections and _tx_stacks maps, which + # both map thread IDs to private resources. + self._shared_map_lock = threading.Lock() + + # A lock to protect access to the database itself. SQLite does + # allow multiple threads to access the database at the same + # time, but many users were experiencing crashes related to this + # capability: where SQLite was compiled without HAVE_USLEEP, its + # backoff algorithm in the case of contention was causing + # whole-second sleeps (!) that would trigger its internal + # timeout. Using this lock ensures only one SQLite transaction + # is active at a time. + self._db_lock = threading.Lock() + + # Set up database schema. + for model_cls in self._models: + self._make_table(model_cls._table, model_cls._fields) + self._make_attribute_table(model_cls._flex_table) + + # Primitive access control: connections and transactions. + + def _connection(self): + """Get a SQLite connection object to the underlying database. + One connection object is created per thread. + """ + thread_id = threading.current_thread().ident + with self._shared_map_lock: + if thread_id in self._connections: + return self._connections[thread_id] + else: + # Make a new connection. + conn = sqlite3.connect( + self.path, + timeout=beets.config['timeout'].as_number(), + ) + + # Access SELECT results like dictionaries. + conn.row_factory = sqlite3.Row + + self._connections[thread_id] = conn + return conn + + @contextlib.contextmanager + def _tx_stack(self): + """A context manager providing access to the current thread's + transaction stack. The context manager synchronizes access to + the stack map. Transactions should never migrate across threads. + """ + thread_id = threading.current_thread().ident + with self._shared_map_lock: + yield self._tx_stacks[thread_id] + + def transaction(self): + """Get a :class:`Transaction` object for interacting directly + with the underlying SQLite database. + """ + return Transaction(self) + + # Schema setup and migration. + + def _make_table(self, table, fields): + """Set up the schema of the database. `fields` is a mapping + from field names to `Type`s. Columns are added if necessary. + """ + # Get current schema. + with self.transaction() as tx: + rows = tx.query('PRAGMA table_info(%s)' % table) + current_fields = set([row[1] for row in rows]) + + field_names = set(fields.keys()) + if current_fields.issuperset(field_names): + # Table exists and has all the required columns. + return + + if not current_fields: + # No table exists. + columns = [] + for name, typ in fields.items(): + columns.append('{0} {1}'.format(name, typ.sql)) + setup_sql = 'CREATE TABLE {0} ({1});\n'.format(table, + ', '.join(columns)) + + else: + # Table exists does not match the field set. + setup_sql = '' + for name, typ in fields.items(): + if name in current_fields: + continue + setup_sql += 'ALTER TABLE {0} ADD COLUMN {1} {2};\n'.format( + table, name, typ.sql + ) + + with self.transaction() as tx: + tx.script(setup_sql) + + def _make_attribute_table(self, flex_table): + """Create a table and associated index for flexible attributes + for the given entity (if they don't exist). + """ + with self.transaction() as tx: + tx.script(""" + CREATE TABLE IF NOT EXISTS {0} ( + id INTEGER PRIMARY KEY, + entity_id INTEGER, + key TEXT, + value TEXT, + UNIQUE(entity_id, key) ON CONFLICT REPLACE); + CREATE INDEX IF NOT EXISTS {0}_by_entity + ON {0} (entity_id); + """.format(flex_table)) + + # Querying. + + def _fetch(self, model_cls, query=None, sort=None): + """Fetch the objects of type `model_cls` matching the given + query. The query may be given as a string, string sequence, a + Query object, or None (to fetch everything). `sort` is an + `Sort` object. + """ + query = query or TrueQuery() # A null query. + sort = sort or NullSort() # Unsorted. + where, subvals = query.clause() + order_by = sort.order_clause() + + sql = ("SELECT * FROM {0} WHERE {1} {2}").format( + model_cls._table, + where or '1', + "ORDER BY {0}".format(order_by) if order_by else '', + ) + + with self.transaction() as tx: + rows = tx.query(sql, subvals) + + return Results( + model_cls, rows, self, + None if where else query, # Slow query component. + sort if sort.is_slow() else None, # Slow sort component. + ) + + def _get(self, model_cls, id): + """Get a Model object by its id or None if the id does not + exist. + """ + return self._fetch(model_cls, MatchQuery('id', id)).get() diff --git a/lib/beets/dbcore/query.py b/lib/beets/dbcore/query.py new file mode 100644 index 00000000..5a116eb2 --- /dev/null +++ b/lib/beets/dbcore/query.py @@ -0,0 +1,654 @@ +# This file is part of beets. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""The Query type hierarchy for DBCore. +""" +import re +from operator import attrgetter +from beets import util +from datetime import datetime, timedelta + + +class Query(object): + """An abstract class representing a query into the item database. + """ + def clause(self): + """Generate an SQLite expression implementing the query. + Return a clause string, a sequence of substitution values for + the clause, and a Query object representing the "remainder" + Returns (clause, subvals) where clause is a valid sqlite + WHERE clause implementing the query and subvals is a list of + items to be substituted for ?s in the clause. + """ + return None, () + + def match(self, item): + """Check whether this query matches a given Item. Can be used to + perform queries on arbitrary sets of Items. + """ + raise NotImplementedError + + +class FieldQuery(Query): + """An abstract query that searches in a specific field for a + pattern. Subclasses must provide a `value_match` class method, which + determines whether a certain pattern string matches a certain value + string. Subclasses may also provide `col_clause` to implement the + same matching functionality in SQLite. + """ + def __init__(self, field, pattern, fast=True): + self.field = field + self.pattern = pattern + self.fast = fast + + def col_clause(self): + return None, () + + def clause(self): + if self.fast: + return self.col_clause() + else: + # Matching a flexattr. This is a slow query. + return None, () + + @classmethod + def value_match(cls, pattern, value): + """Determine whether the value matches the pattern. Both + arguments are strings. + """ + raise NotImplementedError() + + def match(self, item): + return self.value_match(self.pattern, item.get(self.field)) + + +class MatchQuery(FieldQuery): + """A query that looks for exact matches in an item field.""" + def col_clause(self): + return self.field + " = ?", [self.pattern] + + @classmethod + def value_match(cls, pattern, value): + return pattern == value + + +class NoneQuery(FieldQuery): + + def __init__(self, field, fast=True): + self.field = field + self.fast = fast + + def col_clause(self): + return self.field + " IS NULL", () + + @classmethod + def match(self, item): + try: + return item[self.field] is None + except KeyError: + return True + + +class StringFieldQuery(FieldQuery): + """A FieldQuery that converts values to strings before matching + them. + """ + @classmethod + def value_match(cls, pattern, value): + """Determine whether the value matches the pattern. The value + may have any type. + """ + return cls.string_match(pattern, util.as_string(value)) + + @classmethod + def string_match(cls, pattern, value): + """Determine whether the value matches the pattern. Both + arguments are strings. Subclasses implement this method. + """ + raise NotImplementedError() + + +class SubstringQuery(StringFieldQuery): + """A query that matches a substring in a specific item field.""" + def col_clause(self): + pattern = (self.pattern + .replace('\\', '\\\\') + .replace('%', '\\%') + .replace('_', '\\_')) + search = '%' + pattern + '%' + clause = self.field + " like ? escape '\\'" + subvals = [search] + return clause, subvals + + @classmethod + def string_match(cls, pattern, value): + return pattern.lower() in value.lower() + + +class RegexpQuery(StringFieldQuery): + """A query that matches a regular expression in a specific item + field. + """ + @classmethod + def string_match(cls, pattern, value): + try: + res = re.search(pattern, value) + except re.error: + # Invalid regular expression. + return False + return res is not None + + +class BooleanQuery(MatchQuery): + """Matches a boolean field. Pattern should either be a boolean or a + string reflecting a boolean. + """ + def __init__(self, field, pattern, fast=True): + super(BooleanQuery, self).__init__(field, pattern, fast) + if isinstance(pattern, basestring): + self.pattern = util.str2bool(pattern) + self.pattern = int(self.pattern) + + +class BytesQuery(MatchQuery): + """Match a raw bytes field (i.e., a path). This is a necessary hack + to work around the `sqlite3` module's desire to treat `str` and + `unicode` equivalently in Python 2. Always use this query instead of + `MatchQuery` when matching on BLOB values. + """ + def __init__(self, field, pattern): + super(BytesQuery, self).__init__(field, pattern) + + # Use a buffer representation of the pattern for SQLite + # matching. This instructs SQLite to treat the blob as binary + # rather than encoded Unicode. + if isinstance(self.pattern, basestring): + # Implicitly coerce Unicode strings to their bytes + # equivalents. + if isinstance(self.pattern, unicode): + self.pattern = self.pattern.encode('utf8') + self.buf_pattern = buffer(self.pattern) + elif isinstance(self.pattern, buffer): + self.buf_pattern = self.pattern + self.pattern = bytes(self.pattern) + + def col_clause(self): + return self.field + " = ?", [self.buf_pattern] + + +class NumericQuery(FieldQuery): + """Matches numeric fields. A syntax using Ruby-style range ellipses + (``..``) lets users specify one- or two-sided ranges. For example, + ``year:2001..`` finds music released since the turn of the century. + """ + def _convert(self, s): + """Convert a string to a numeric type (float or int). If the + string cannot be converted, return None. + """ + # This is really just a bit of fun premature optimization. + try: + return int(s) + except ValueError: + try: + return float(s) + except ValueError: + return None + + def __init__(self, field, pattern, fast=True): + super(NumericQuery, self).__init__(field, pattern, fast) + + parts = pattern.split('..', 1) + if len(parts) == 1: + # No range. + self.point = self._convert(parts[0]) + self.rangemin = None + self.rangemax = None + else: + # One- or two-sided range. + self.point = None + self.rangemin = self._convert(parts[0]) + self.rangemax = self._convert(parts[1]) + + def match(self, item): + if self.field not in item: + return False + value = item[self.field] + if isinstance(value, basestring): + value = self._convert(value) + + if self.point is not None: + return value == self.point + else: + if self.rangemin is not None and value < self.rangemin: + return False + if self.rangemax is not None and value > self.rangemax: + return False + return True + + def col_clause(self): + if self.point is not None: + return self.field + '=?', (self.point,) + else: + if self.rangemin is not None and self.rangemax is not None: + return (u'{0} >= ? AND {0} <= ?'.format(self.field), + (self.rangemin, self.rangemax)) + elif self.rangemin is not None: + return u'{0} >= ?'.format(self.field), (self.rangemin,) + elif self.rangemax is not None: + return u'{0} <= ?'.format(self.field), (self.rangemax,) + else: + return '1', () + + +class CollectionQuery(Query): + """An abstract query class that aggregates other queries. Can be + indexed like a list to access the sub-queries. + """ + def __init__(self, subqueries=()): + self.subqueries = subqueries + + # Act like a sequence. + + def __len__(self): + return len(self.subqueries) + + def __getitem__(self, key): + return self.subqueries[key] + + def __iter__(self): + return iter(self.subqueries) + + def __contains__(self, item): + return item in self.subqueries + + def clause_with_joiner(self, joiner): + """Returns a clause created by joining together the clauses of + all subqueries with the string joiner (padded by spaces). + """ + clause_parts = [] + subvals = [] + for subq in self.subqueries: + subq_clause, subq_subvals = subq.clause() + if not subq_clause: + # Fall back to slow query. + return None, () + clause_parts.append('(' + subq_clause + ')') + subvals += subq_subvals + clause = (' ' + joiner + ' ').join(clause_parts) + return clause, subvals + + +class AnyFieldQuery(CollectionQuery): + """A query that matches if a given FieldQuery subclass matches in + any field. The individual field query class is provided to the + constructor. + """ + def __init__(self, pattern, fields, cls): + self.pattern = pattern + self.fields = fields + self.query_class = cls + + subqueries = [] + for field in self.fields: + subqueries.append(cls(field, pattern, True)) + super(AnyFieldQuery, self).__init__(subqueries) + + def clause(self): + return self.clause_with_joiner('or') + + def match(self, item): + for subq in self.subqueries: + if subq.match(item): + return True + return False + + +class MutableCollectionQuery(CollectionQuery): + """A collection query whose subqueries may be modified after the + query is initialized. + """ + def __setitem__(self, key, value): + self.subqueries[key] = value + + def __delitem__(self, key): + del self.subqueries[key] + + +class AndQuery(MutableCollectionQuery): + """A conjunction of a list of other queries.""" + def clause(self): + return self.clause_with_joiner('and') + + def match(self, item): + return all([q.match(item) for q in self.subqueries]) + + +class OrQuery(MutableCollectionQuery): + """A conjunction of a list of other queries.""" + def clause(self): + return self.clause_with_joiner('or') + + def match(self, item): + return any([q.match(item) for q in self.subqueries]) + + +class TrueQuery(Query): + """A query that always matches.""" + def clause(self): + return '1', () + + def match(self, item): + return True + + +class FalseQuery(Query): + """A query that never matches.""" + def clause(self): + return '0', () + + def match(self, item): + return False + + +# Time/date queries. + +def _to_epoch_time(date): + """Convert a `datetime` object to an integer number of seconds since + the (local) Unix epoch. + """ + epoch = datetime.fromtimestamp(0) + delta = date - epoch + try: + return int(delta.total_seconds()) + except AttributeError: + # datetime.timedelta.total_seconds() is not available on Python 2.6 + return delta.seconds + delta.days * 24 * 3600 + + +def _parse_periods(pattern): + """Parse a string containing two dates separated by two dots (..). + Return a pair of `Period` objects. + """ + parts = pattern.split('..', 1) + if len(parts) == 1: + instant = Period.parse(parts[0]) + return (instant, instant) + else: + start = Period.parse(parts[0]) + end = Period.parse(parts[1]) + return (start, end) + + +class Period(object): + """A period of time given by a date, time and precision. + + Example: 2014-01-01 10:50:30 with precision 'month' represents all + instants of time during January 2014. + """ + + precisions = ('year', 'month', 'day') + date_formats = ('%Y', '%Y-%m', '%Y-%m-%d') + + def __init__(self, date, precision): + """Create a period with the given date (a `datetime` object) and + precision (a string, one of "year", "month", or "day"). + """ + if precision not in Period.precisions: + raise ValueError('Invalid precision ' + str(precision)) + self.date = date + self.precision = precision + + @classmethod + def parse(cls, string): + """Parse a date and return a `Period` object or `None` if the + string is empty. + """ + if not string: + return None + ordinal = string.count('-') + if ordinal >= len(cls.date_formats): + # Too many components. + return None + date_format = cls.date_formats[ordinal] + try: + date = datetime.strptime(string, date_format) + except ValueError: + # Parsing failed. + return None + precision = cls.precisions[ordinal] + return cls(date, precision) + + def open_right_endpoint(self): + """Based on the precision, convert the period to a precise + `datetime` for use as a right endpoint in a right-open interval. + """ + precision = self.precision + date = self.date + if 'year' == self.precision: + return date.replace(year=date.year + 1, month=1) + elif 'month' == precision: + if (date.month < 12): + return date.replace(month=date.month + 1) + else: + return date.replace(year=date.year + 1, month=1) + elif 'day' == precision: + return date + timedelta(days=1) + else: + raise ValueError('unhandled precision ' + str(precision)) + + +class DateInterval(object): + """A closed-open interval of dates. + + A left endpoint of None means since the beginning of time. + A right endpoint of None means towards infinity. + """ + + def __init__(self, start, end): + if start is not None and end is not None and not start < end: + raise ValueError("start date {0} is not before end date {1}" + .format(start, end)) + self.start = start + self.end = end + + @classmethod + def from_periods(cls, start, end): + """Create an interval with two Periods as the endpoints. + """ + end_date = end.open_right_endpoint() if end is not None else None + start_date = start.date if start is not None else None + return cls(start_date, end_date) + + def contains(self, date): + if self.start is not None and date < self.start: + return False + if self.end is not None and date >= self.end: + return False + return True + + def __str__(self): + return'[{0}, {1})'.format(self.start, self.end) + + +class DateQuery(FieldQuery): + """Matches date fields stored as seconds since Unix epoch time. + + Dates can be specified as ``year-month-day`` strings where only year + is mandatory. + + The value of a date field can be matched against a date interval by + using an ellipsis interval syntax similar to that of NumericQuery. + """ + def __init__(self, field, pattern, fast=True): + super(DateQuery, self).__init__(field, pattern, fast) + start, end = _parse_periods(pattern) + self.interval = DateInterval.from_periods(start, end) + + def match(self, item): + timestamp = float(item[self.field]) + date = datetime.utcfromtimestamp(timestamp) + return self.interval.contains(date) + + _clause_tmpl = "{0} {1} ?" + + def col_clause(self): + clause_parts = [] + subvals = [] + + if self.interval.start: + clause_parts.append(self._clause_tmpl.format(self.field, ">=")) + subvals.append(_to_epoch_time(self.interval.start)) + + if self.interval.end: + clause_parts.append(self._clause_tmpl.format(self.field, "<")) + subvals.append(_to_epoch_time(self.interval.end)) + + if clause_parts: + # One- or two-sided interval. + clause = ' AND '.join(clause_parts) + else: + # Match any date. + clause = '1' + return clause, subvals + + +# Sorting. + +class Sort(object): + """An abstract class representing a sort operation for a query into + the item database. + """ + + def order_clause(self): + """Generates a SQL fragment to be used in a ORDER BY clause, or + None if no fragment is used (i.e., this is a slow sort). + """ + return None + + def sort(self, items): + """Sort the list of objects and return a list. + """ + return sorted(items) + + def is_slow(self): + """Indicate whether this query is *slow*, meaning that it cannot + be executed in SQL and must be executed in Python. + """ + return False + + +class MultipleSort(Sort): + """Sort that encapsulates multiple sub-sorts. + """ + + def __init__(self, sorts=None): + self.sorts = sorts or [] + + def add_sort(self, sort): + self.sorts.append(sort) + + def _sql_sorts(self): + """Return the list of sub-sorts for which we can be (at least + partially) fast. + + A contiguous suffix of fast (SQL-capable) sub-sorts are + executable in SQL. The remaining, even if they are fast + independently, must be executed slowly. + """ + sql_sorts = [] + for sort in reversed(self.sorts): + if not sort.order_clause() is None: + sql_sorts.append(sort) + else: + break + sql_sorts.reverse() + return sql_sorts + + def order_clause(self): + order_strings = [] + for sort in self._sql_sorts(): + order = sort.order_clause() + order_strings.append(order) + + return ", ".join(order_strings) + + def is_slow(self): + for sort in self.sorts: + if sort.is_slow(): + return True + return False + + def sort(self, items): + slow_sorts = [] + switch_slow = False + for sort in reversed(self.sorts): + if switch_slow: + slow_sorts.append(sort) + elif sort.order_clause() is None: + switch_slow = True + slow_sorts.append(sort) + else: + pass + + for sort in slow_sorts: + items = sort.sort(items) + return items + + def __repr__(self): + return u'MultipleSort({0})'.format(repr(self.sorts)) + + +class FieldSort(Sort): + """An abstract sort criterion that orders by a specific field (of + any kind). + """ + def __init__(self, field, ascending=True): + self.field = field + self.ascending = ascending + + def sort(self, objs): + # TODO: Conversion and null-detection here. In Python 3, + # comparisons with None fail. We should also support flexible + # attributes with different types without falling over. + return sorted(objs, key=attrgetter(self.field), + reverse=not self.ascending) + + def __repr__(self): + return u'<{0}: {1}{2}>'.format( + type(self).__name__, + self.field, + '+' if self.ascending else '-', + ) + + +class FixedFieldSort(FieldSort): + """Sort object to sort on a fixed field. + """ + def order_clause(self): + order = "ASC" if self.ascending else "DESC" + return "{0} {1}".format(self.field, order) + + +class SlowFieldSort(FieldSort): + """A sort criterion by some model field other than a fixed field: + i.e., a computed or flexible field. + """ + def is_slow(self): + return True + + +class NullSort(Sort): + """No sorting. Leave results unsorted.""" + def sort(items): + return items diff --git a/lib/beets/dbcore/queryparse.py b/lib/beets/dbcore/queryparse.py new file mode 100644 index 00000000..90963696 --- /dev/null +++ b/lib/beets/dbcore/queryparse.py @@ -0,0 +1,180 @@ +# This file is part of beets. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Parsing of strings into DBCore queries. +""" +import re +import itertools +from . import query + + +PARSE_QUERY_PART_REGEX = re.compile( + # Non-capturing optional segment for the keyword. + r'(?:' + r'(\S+?)' # The field key. + r'(? (None, 'stapler', SubstringQuery) + 'color:red' -> ('color', 'red', SubstringQuery) + ':^Quiet' -> (None, '^Quiet', RegexpQuery) + 'color::b..e' -> ('color', 'b..e', RegexpQuery) + + Prefixes may be "escaped" with a backslash to disable the keying + behavior. + """ + part = part.strip() + match = PARSE_QUERY_PART_REGEX.match(part) + + assert match # Regex should always match. + key = match.group(1) + term = match.group(2).replace('\:', ':') + + # Match the search term against the list of prefixes. + for pre, query_class in prefixes.items(): + if term.startswith(pre): + return key, term[len(pre):], query_class + + # No matching prefix: use type-based or fallback/default query. + query_class = query_classes.get(key, default_class) + return key, term, query_class + + +def construct_query_part(model_cls, prefixes, query_part): + """Create a query from a single query component, `query_part`, for + querying instances of `model_cls`. Return a `Query` instance. + """ + # Shortcut for empty query parts. + if not query_part: + return query.TrueQuery() + + # Get the query classes for each possible field. + query_classes = {} + for k, t in itertools.chain(model_cls._fields.items(), + model_cls._types.items()): + query_classes[k] = t.query + + # Parse the string. + key, pattern, query_class = \ + parse_query_part(query_part, query_classes, prefixes) + + # No key specified. + if key is None: + if issubclass(query_class, query.FieldQuery): + # The query type matches a specific field, but none was + # specified. So we use a version of the query that matches + # any field. + return query.AnyFieldQuery(pattern, model_cls._search_fields, + query_class) + else: + # Other query type. + return query_class(pattern) + + key = key.lower() + return query_class(key.lower(), pattern, key in model_cls._fields) + + +def query_from_strings(query_cls, model_cls, prefixes, query_parts): + """Creates a collection query of type `query_cls` from a list of + strings in the format used by parse_query_part. `model_cls` + determines how queries are constructed from strings. + """ + subqueries = [] + for part in query_parts: + subqueries.append(construct_query_part(model_cls, prefixes, part)) + if not subqueries: # No terms in query. + subqueries = [query.TrueQuery()] + return query_cls(subqueries) + + +def construct_sort_part(model_cls, part): + """Create a `Sort` from a single string criterion. + + `model_cls` is the `Model` being queried. `part` is a single string + ending in ``+`` or ``-`` indicating the sort. + """ + assert part, "part must be a field name and + or -" + field = part[:-1] + assert field, "field is missing" + direction = part[-1] + assert direction in ('+', '-'), "part must end with + or -" + is_ascending = direction == '+' + + if field in model_cls._sorts: + sort = model_cls._sorts[field](model_cls, is_ascending) + elif field in model_cls._fields: + sort = query.FixedFieldSort(field, is_ascending) + else: + # Flexible or computed. + sort = query.SlowFieldSort(field, is_ascending) + return sort + + +def sort_from_strings(model_cls, sort_parts): + """Create a `Sort` from a list of sort criteria (strings). + """ + if not sort_parts: + return query.NullSort() + else: + sort = query.MultipleSort() + for part in sort_parts: + sort.add_sort(construct_sort_part(model_cls, part)) + return sort + + +def parse_sorted_query(model_cls, parts, prefixes={}, + query_cls=query.AndQuery): + """Given a list of strings, create the `Query` and `Sort` that they + represent. + """ + # Separate query token and sort token. + query_parts = [] + sort_parts = [] + for part in parts: + if part.endswith((u'+', u'-')) and u':' not in part: + sort_parts.append(part) + else: + query_parts.append(part) + + # Parse each. + q = query_from_strings( + query_cls, model_cls, prefixes, query_parts + ) + s = sort_from_strings(model_cls, sort_parts) + return q, s diff --git a/lib/beets/dbcore/types.py b/lib/beets/dbcore/types.py new file mode 100644 index 00000000..82346e70 --- /dev/null +++ b/lib/beets/dbcore/types.py @@ -0,0 +1,208 @@ +# This file is part of beets. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Representation of type information for DBCore model fields. +""" +from . import query +from beets.util import str2bool + + +# Abstract base. + +class Type(object): + """An object encapsulating the type of a model field. Includes + information about how to store, query, format, and parse a given + field. + """ + + sql = u'TEXT' + """The SQLite column type for the value. + """ + + query = query.SubstringQuery + """The `Query` subclass to be used when querying the field. + """ + + model_type = unicode + """The Python type that is used to represent the value in the model. + + The model is guaranteed to return a value of this type if the field + is accessed. To this end, the constructor is used by the `normalize` + and `from_sql` methods and the `default` property. + """ + + @property + def null(self): + """The value to be exposed when the underlying value is None. + """ + return self.model_type() + + def format(self, value): + """Given a value of this type, produce a Unicode string + representing the value. This is used in template evaluation. + """ + if value is None: + value = self.null + # `self.null` might be `None` + if value is None: + value = u'' + if isinstance(value, bytes): + value = value.decode('utf8', 'ignore') + + return unicode(value) + + def parse(self, string): + """Parse a (possibly human-written) string and return the + indicated value of this type. + """ + try: + return self.model_type(string) + except ValueError: + return self.null + + def normalize(self, value): + """Given a value that will be assigned into a field of this + type, normalize the value to have the appropriate type. This + base implementation only reinterprets `None`. + """ + if value is None: + return self.null + else: + # TODO This should eventually be replaced by + # `self.model_type(value)` + return value + + def from_sql(self, sql_value): + """Receives the value stored in the SQL backend and return the + value to be stored in the model. + + For fixed fields the type of `value` is determined by the column + type affinity given in the `sql` property and the SQL to Python + mapping of the database adapter. For more information see: + http://www.sqlite.org/datatype3.html + https://docs.python.org/2/library/sqlite3.html#sqlite-and-python-types + + Flexible fields have the type afinity `TEXT`. This means the + `sql_value` is either a `buffer` or a `unicode` object` and the + method must handle these in addition. + """ + if isinstance(sql_value, buffer): + sql_value = bytes(sql_value).decode('utf8', 'ignore') + if isinstance(sql_value, unicode): + return self.parse(sql_value) + else: + return self.normalize(sql_value) + + def to_sql(self, model_value): + """Convert a value as stored in the model object to a value used + by the database adapter. + """ + return model_value + + +# Reusable types. + +class Default(Type): + null = None + + +class Integer(Type): + """A basic integer type. + """ + sql = u'INTEGER' + query = query.NumericQuery + model_type = int + + +class PaddedInt(Integer): + """An integer field that is formatted with a given number of digits, + padded with zeroes. + """ + def __init__(self, digits): + self.digits = digits + + def format(self, value): + return u'{0:0{1}d}'.format(value or 0, self.digits) + + +class ScaledInt(Integer): + """An integer whose formatting operation scales the number by a + constant and adds a suffix. Good for units with large magnitudes. + """ + def __init__(self, unit, suffix=u''): + self.unit = unit + self.suffix = suffix + + def format(self, value): + return u'{0}{1}'.format((value or 0) // self.unit, self.suffix) + + +class Id(Integer): + """An integer used as the row id or a foreign key in a SQLite table. + This type is nullable: None values are not translated to zero. + """ + null = None + + def __init__(self, primary=True): + if primary: + self.sql = u'INTEGER PRIMARY KEY' + + +class Float(Type): + """A basic floating-point type. + """ + sql = u'REAL' + query = query.NumericQuery + model_type = float + + def format(self, value): + return u'{0:.1f}'.format(value or 0.0) + + +class NullFloat(Float): + """Same as `Float`, but does not normalize `None` to `0.0`. + """ + null = None + + +class String(Type): + """A Unicode string type. + """ + sql = u'TEXT' + query = query.SubstringQuery + + +class Boolean(Type): + """A boolean type. + """ + sql = u'INTEGER' + query = query.BooleanQuery + model_type = bool + + def format(self, value): + return unicode(bool(value)) + + def parse(self, string): + return str2bool(string) + + +# Shared instances of common types. +DEFAULT = Default() +INTEGER = Integer() +PRIMARY_ID = Id(True) +FOREIGN_ID = Id(False) +FLOAT = Float() +NULL_FLOAT = NullFloat() +STRING = String() +BOOLEAN = Boolean() diff --git a/lib/beets/importer.py b/lib/beets/importer.py new file mode 100644 index 00000000..4a7bd997 --- /dev/null +++ b/lib/beets/importer.py @@ -0,0 +1,1428 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Provides the basic, interface-agnostic workflow for importing and +autotagging music files. +""" +from __future__ import print_function + +import os +import re +import logging +import pickle +import itertools +from collections import defaultdict +from tempfile import mkdtemp +from bisect import insort, bisect_left +from contextlib import contextmanager +import shutil + +from beets import autotag +from beets import library +from beets import dbcore +from beets import plugins +from beets import util +from beets import config +from beets.util import pipeline, sorted_walk, ancestry +from beets.util import syspath, normpath, displayable_path +from enum import Enum +from beets import mediafile + +action = Enum('action', + ['SKIP', 'ASIS', 'TRACKS', 'MANUAL', 'APPLY', 'MANUAL_ID', + 'ALBUMS']) + +QUEUE_SIZE = 128 +SINGLE_ARTIST_THRESH = 0.25 +VARIOUS_ARTISTS = u'Various Artists' +PROGRESS_KEY = 'tagprogress' +HISTORY_KEY = 'taghistory' + +# Global logger. +log = logging.getLogger('beets') + + +class ImportAbort(Exception): + """Raised when the user aborts the tagging operation. + """ + pass + + +# Utilities. + +def _open_state(): + """Reads the state file, returning a dictionary.""" + try: + with open(config['statefile'].as_filename()) as f: + return pickle.load(f) + except Exception as exc: + # The `pickle` module can emit all sorts of exceptions during + # unpickling, including ImportError. We use a catch-all + # exception to avoid enumerating them all (the docs don't even have a + # full list!). + log.debug(u'state file could not be read: {0}'.format(exc)) + return {} + + +def _save_state(state): + """Writes the state dictionary out to disk.""" + try: + with open(config['statefile'].as_filename(), 'w') as f: + pickle.dump(state, f) + except IOError as exc: + log.error(u'state file could not be written: {0}'.format(exc)) + + +# Utilities for reading and writing the beets progress file, which +# allows long tagging tasks to be resumed when they pause (or crash). + +def progress_read(): + state = _open_state() + return state.setdefault(PROGRESS_KEY, {}) + + +@contextmanager +def progress_write(): + state = _open_state() + progress = state.setdefault(PROGRESS_KEY, {}) + yield progress + _save_state(state) + + +def progress_add(toppath, *paths): + """Record that the files under all of the `paths` have been imported + under `toppath`. + """ + with progress_write() as state: + imported = state.setdefault(toppath, []) + for path in paths: + # Normally `progress_add` will be called with the path + # argument increasing. This is because of the ordering in + # `albums_in_dir`. We take advantage of that to make the + # code faster + if imported and imported[len(imported) - 1] <= path: + imported.append(path) + else: + insort(imported, path) + + +def progress_element(toppath, path): + """Return whether `path` has been imported in `toppath`. + """ + state = progress_read() + if toppath not in state: + return False + imported = state[toppath] + i = bisect_left(imported, path) + return i != len(imported) and imported[i] == path + + +def has_progress(toppath): + """Return `True` if there exist paths that have already been + imported under `toppath`. + """ + state = progress_read() + return toppath in state + + +def progress_reset(toppath): + with progress_write() as state: + if toppath in state: + del state[toppath] + + +# Similarly, utilities for manipulating the "incremental" import log. +# This keeps track of all directories that were ever imported, which +# allows the importer to only import new stuff. + +def history_add(paths): + """Indicate that the import of the album in `paths` is completed and + should not be repeated in incremental imports. + """ + state = _open_state() + if HISTORY_KEY not in state: + state[HISTORY_KEY] = set() + + state[HISTORY_KEY].add(tuple(paths)) + + _save_state(state) + + +def history_get(): + """Get the set of completed path tuples in incremental imports. + """ + state = _open_state() + if HISTORY_KEY not in state: + return set() + return state[HISTORY_KEY] + + +# Abstract session class. + +class ImportSession(object): + """Controls an import action. Subclasses should implement methods to + communicate with the user or otherwise make decisions. + """ + def __init__(self, lib, logfile, paths, query): + """Create a session. `lib` is a Library object. `logfile` is a + file-like object open for writing or None if no logging is to be + performed. Either `paths` or `query` is non-null and indicates + the source of files to be imported. + """ + self.lib = lib + self.logfile = logfile + self.paths = paths + self.query = query + self.seen_idents = set() + self._is_resuming = dict() + + # Normalize the paths. + if self.paths: + self.paths = map(normpath, self.paths) + + def set_config(self, config): + """Set `config` property from global import config and make + implied changes. + """ + # FIXME: Maybe this function should not exist and should instead + # provide "decision wrappers" like "should_resume()", etc. + iconfig = dict(config) + self.config = iconfig + + # Incremental and progress are mutually exclusive. + if iconfig['incremental']: + iconfig['resume'] = False + + # When based on a query instead of directories, never + # save progress or try to resume. + if self.query is not None: + iconfig['resume'] = False + iconfig['incremental'] = False + + # Copy, move, and link are mutually exclusive. + if iconfig['move']: + iconfig['copy'] = False + iconfig['link'] = False + elif iconfig['link']: + iconfig['copy'] = False + iconfig['move'] = False + + # Only delete when copying. + if not iconfig['copy']: + iconfig['delete'] = False + + self.want_resume = config['resume'].as_choice([True, False, 'ask']) + + def tag_log(self, status, paths): + """Log a message about a given album to logfile. The status should + reflect the reason the album couldn't be tagged. + """ + if self.logfile: + print(u'{0} {1}'.format(status, displayable_path(paths)), + file=self.logfile) + self.logfile.flush() + + def log_choice(self, task, duplicate=False): + """Logs the task's current choice if it should be logged. If + ``duplicate``, then this is a secondary choice after a duplicate was + detected and a decision was made. + """ + paths = task.paths + if duplicate: + # Duplicate: log all three choices (skip, keep both, and trump). + if task.should_remove_duplicates: + self.tag_log('duplicate-replace', paths) + elif task.choice_flag in (action.ASIS, action.APPLY): + self.tag_log('duplicate-keep', paths) + elif task.choice_flag is (action.SKIP): + self.tag_log('duplicate-skip', paths) + else: + # Non-duplicate: log "skip" and "asis" choices. + if task.choice_flag is action.ASIS: + self.tag_log('asis', paths) + elif task.choice_flag is action.SKIP: + self.tag_log('skip', paths) + + def should_resume(self, path): + raise NotImplementedError + + def choose_match(self, task): + raise NotImplementedError + + def resolve_duplicate(self, task, found_duplicates): + raise NotImplementedError + + def choose_item(self, task): + raise NotImplementedError + + def run(self): + """Run the import task. + """ + self.set_config(config['import']) + + # Set up the pipeline. + if self.query is None: + stages = [read_tasks(self)] + else: + stages = [query_tasks(self)] + + if self.config['pretend']: + # Only log the imported files and end the pipeline + stages += [log_files(self)] + else: + if self.config['group_albums'] and \ + not self.config['singletons']: + # Split directory tasks into one task for each album + stages += [group_albums(self)] + if self.config['autotag']: + # FIXME We should also resolve duplicates when not + # autotagging. This is currently handled in `user_query` + stages += [lookup_candidates(self), user_query(self)] + else: + stages += [import_asis(self)] + stages += [apply_choices(self)] + + for stage_func in plugins.import_stages(): + stages.append(plugin_stage(self, stage_func)) + stages += [manipulate_files(self)] + pl = pipeline.Pipeline(stages) + + # Run the pipeline. + plugins.send('import_begin', session=self) + try: + if config['threaded']: + pl.run_parallel(QUEUE_SIZE) + else: + pl.run_sequential() + except ImportAbort: + # User aborted operation. Silently stop. + pass + + # Incremental and resumed imports + + def already_imported(self, toppath, paths): + """Returns true if the files belonging to this task have already + been imported in a previous session. + """ + if self.is_resuming(toppath) \ + and all(map(lambda p: progress_element(toppath, p), paths)): + return True + if self.config['incremental'] \ + and tuple(paths) in self.history_dirs: + return True + + return False + + @property + def history_dirs(self): + if not hasattr(self, '_history_dirs'): + self._history_dirs = history_get() + return self._history_dirs + + def is_resuming(self, toppath): + """Return `True` if user wants to resume import of this path. + + You have to call `ask_resume` first to determine the return value. + """ + return self._is_resuming.get(toppath, False) + + def ask_resume(self, toppath): + """If import of `toppath` was aborted in an earlier session, ask + user if she wants to resume the import. + + Determines the return value of `is_resuming(toppath)`. + """ + if self.want_resume and has_progress(toppath): + # Either accept immediately or prompt for input to decide. + if self.want_resume is True or \ + self.should_resume(toppath): + log.warn(u'Resuming interrupted import of {0}'.format( + util.displayable_path(toppath))) + self._is_resuming[toppath] = True + else: + # Clear progress; we're starting from the top. + progress_reset(toppath) + + +# The importer task class. + +class ImportTask(object): + """Represents a single set of items to be imported along with its + intermediate state. May represent an album or a single item. + + The import session and stages call the following methods in the + given order. + + * `lookup_candidates()` Sets the `common_artist`, `common_album`, + `candidates`, and `rec` attributes. `candidates` is a list of + `AlbumMatch` objects. + + * `choose_match()` Uses the session to set the `match` attribute + from the `candidates` list. + + * `find_duplicates()` Returns a list of albums from `lib` with the + same artist and album name as the task. + + * `apply_metadata()` Sets the attributes of the items from the + task's `match` attribute. + + * `add()` Add the imported items and album to the database. + + * `manipulate_files()` Copy, move, and write files depending on the + session configuration. + + * `finalize()` Update the import progress and cleanup the file + system. + """ + def __init__(self, toppath=None, paths=None, items=None): + self.toppath = toppath + self.paths = paths + self.items = items + self.choice_flag = None + + self.cur_album = None + self.cur_artist = None + self.candidates = [] + self.rec = None + # TODO remove this eventually + self.should_remove_duplicates = False + self.is_album = True + + def set_choice(self, choice): + """Given an AlbumMatch or TrackMatch object or an action constant, + indicates that an action has been selected for this task. + """ + # Not part of the task structure: + assert choice not in (action.MANUAL, action.MANUAL_ID) + assert choice != action.APPLY # Only used internally. + if choice in (action.SKIP, action.ASIS, action.TRACKS, action.ALBUMS): + self.choice_flag = choice + self.match = None + else: + self.choice_flag = action.APPLY # Implicit choice. + self.match = choice + + def save_progress(self): + """Updates the progress state to indicate that this album has + finished. + """ + if self.toppath: + progress_add(self.toppath, *self.paths) + + def save_history(self): + """Save the directory in the history for incremental imports. + """ + if self.paths: + history_add(self.paths) + + # Logical decisions. + + @property + def apply(self): + return self.choice_flag == action.APPLY + + @property + def skip(self): + return self.choice_flag == action.SKIP + + # Convenient data. + + def chosen_ident(self): + """Returns identifying metadata about the current choice. For + albums, this is an (artist, album) pair. For items, this is + (artist, title). May only be called when the choice flag is ASIS + (in which case the data comes from the files' current metadata) + or APPLY (data comes from the choice). + """ + if self.choice_flag is action.ASIS: + return (self.cur_artist, self.cur_album) + elif self.choice_flag is action.APPLY: + return (self.match.info.artist, self.match.info.album) + + def imported_items(self): + """Return a list of Items that should be added to the library. + + If the tasks applies an album match the method only returns the + matched items. + """ + if self.choice_flag == action.ASIS: + return list(self.items) + # FIXME this should be a simple attribute. There should be no + # need to retrieve the keys of `match.mapping`. This requires + # that we remove unmatched items from the list. + elif self.choice_flag == action.APPLY: + return self.match.mapping.keys() + else: + assert False + + def apply_metadata(self): + """Copy metadata from match info to the items. + """ + # TODO call should be more descriptive like + # apply_metadata(self.match, self.items) + autotag.apply_metadata(self.match.info, self.match.mapping) + + def duplicate_items(self, lib): + duplicate_items = [] + for album in self.find_duplicates(lib): + duplicate_items += album.items() + return duplicate_items + + def remove_duplicates(self, lib): + duplicate_items = self.duplicate_items(lib) + log.debug(u'removing {0} old duplicated items' + .format(len(duplicate_items))) + for item in duplicate_items: + item.remove() + if lib.directory in util.ancestry(item.path): + log.debug(u'deleting duplicate {0}' + .format(util.displayable_path(item.path))) + util.remove(item.path) + util.prune_dirs(os.path.dirname(item.path), + lib.directory) + + def finalize(self, session): + """Save progress, clean up files, and emit plugin event. + """ + # FIXME the session argument is unfortunate. It should be + # present as an attribute of the task. + + # Update progress. + if session.want_resume: + self.save_progress() + if session.config['incremental']: + self.save_history() + + self.cleanup(copy=session.config['copy'], + delete=session.config['delete'], + move=session.config['move']) + + if not self.skip: + self._emit_imported(session.lib) + + def cleanup(self, copy=False, delete=False, move=False): + """Remove and prune imported paths. + """ + # FIXME Maybe the keywords should be task properties. + + # Do not delete any files or prune directories when skipping. + if self.skip: + return + + items = self.imported_items() + + # When copying and deleting originals, delete old files. + if copy and delete: + new_paths = [os.path.realpath(item.path) for item in items] + for old_path in self.old_paths: + # Only delete files that were actually copied. + if old_path not in new_paths: + util.remove(syspath(old_path), False) + self.prune(old_path) + + # When moving, prune empty directories containing the original files. + elif move: + for old_path in self.old_paths: + self.prune(old_path) + + def _emit_imported(self, lib): + # FIXME This shouldn't be here. Skipping should be handled in + # the stages. + if self.skip: + return + plugins.send('album_imported', lib=lib, album=self.album) + + def lookup_candidates(self): + """Retrieve and store candidates for this album. + """ + artist, album, candidates, recommendation = \ + autotag.tag_album(self.items) + self.cur_artist = artist + self.cur_album = album + self.candidates = candidates + self.rec = recommendation + + def find_duplicates(self, lib): + """Return a list of albums from `lib` with the same artist and + album name as the task. + """ + artist, album = self.chosen_ident() + + if artist is None: + # As-is import with no artist. Skip check. + return [] + + duplicates = [] + task_paths = set(i.path for i in self.items if i) + duplicate_query = dbcore.AndQuery(( + dbcore.MatchQuery('albumartist', artist), + dbcore.MatchQuery('album', album), + )) + + for album in lib.albums(duplicate_query): + # Check whether the album is identical in contents, in which + # case it is not a duplicate (will be replaced). + album_paths = set(i.path for i in album.items()) + if album_paths != task_paths: + duplicates.append(album) + return duplicates + + def align_album_level_fields(self): + """Make the some album fields equal across `self.items` + """ + changes = {} + + if self.choice_flag == action.ASIS: + # Taking metadata "as-is". Guess whether this album is VA. + plur_albumartist, freq = util.plurality( + [i.albumartist or i.artist for i in self.items] + ) + if freq == len(self.items) or \ + (freq > 1 and + float(freq) / len(self.items) >= SINGLE_ARTIST_THRESH): + # Single-artist album. + changes['albumartist'] = plur_albumartist + changes['comp'] = False + else: + # VA. + changes['albumartist'] = VARIOUS_ARTISTS + changes['comp'] = True + + elif self.choice_flag == action.APPLY: + # Applying autotagged metadata. Just get AA from the first + # item. + if not self.items[0].albumartist: + changes['albumartist'] = self.items[0].artist + if not self.items[0].mb_albumartistid: + changes['mb_albumartistid'] = self.items[0].mb_artistid + + # Apply new metadata. + for item in self.items: + item.update(changes) + + def manipulate_files(self, move=False, copy=False, write=False, + link=False, session=None): + items = self.imported_items() + # Save the original paths of all items for deletion and pruning + # in the next step (finalization). + self.old_paths = [item.path for item in items] + for item in items: + if move or copy or link: + # In copy and link modes, treat re-imports specially: + # move in-library files. (Out-of-library files are + # copied/moved as usual). + old_path = item.path + if (copy or link) and self.replaced_items[item] and \ + session.lib.directory in util.ancestry(old_path): + item.move() + # We moved the item, so remove the + # now-nonexistent file from old_paths. + self.old_paths.remove(old_path) + else: + # A normal import. Just copy files and keep track of + # old paths. + item.move(copy, link) + + if write and self.apply: + item.try_write() + + with session.lib.transaction(): + for item in self.imported_items(): + item.store() + + plugins.send('import_task_files', session=session, task=self) + + def add(self, lib): + """Add the items as an album to the library and remove replaced items. + """ + self.align_album_level_fields() + with lib.transaction(): + self.record_replaced(lib) + self.remove_replaced(lib) + self.album = lib.add_album(self.imported_items()) + self.reimport_metadata(lib) + + def record_replaced(self, lib): + """Records the replaced items and albums in the `replaced_items` + and `replaced_albums` dictionaries. + """ + self.replaced_items = defaultdict(list) + self.replaced_albums = defaultdict(list) + replaced_album_ids = set() + for item in self.imported_items(): + dup_items = list(lib.items( + dbcore.query.BytesQuery('path', item.path) + )) + self.replaced_items[item] = dup_items + for dup_item in dup_items: + if (not dup_item.album_id or + dup_item.album_id in replaced_album_ids): + continue + replaced_album = dup_item.get_album() + if replaced_album: + replaced_album_ids.add(dup_item.album_id) + self.replaced_albums[replaced_album.path] = replaced_album + + def reimport_metadata(self, lib): + """For reimports, preserves metadata for reimported items and + albums. + """ + if self.is_album: + replaced_album = self.replaced_albums.get(self.album.path) + if replaced_album: + self.album.added = replaced_album.added + self.album.update(replaced_album._values_flex) + self.album.store() + log.debug( + u'Reimported album: added {0}, flexible ' + u'attributes {1} from album {2} for {3}'.format( + self.album.added, + replaced_album._values_flex.keys(), + replaced_album.id, + displayable_path(self.album.path), + ) + ) + + for item in self.imported_items(): + dup_items = self.replaced_items[item] + for dup_item in dup_items: + if dup_item.added and dup_item.added != item.added: + item.added = dup_item.added + log.debug( + u'Reimported item added {0} ' + u'from item {1} for {2}'.format( + item.added, + dup_item.id, + displayable_path(item.path), + ) + ) + item.update(dup_item._values_flex) + log.debug( + u'Reimported item flexible attributes {0} ' + u'from item {1} for {2}'.format( + dup_item._values_flex.keys(), + dup_item.id, + displayable_path(item.path), + ) + ) + item.store() + + def remove_replaced(self, lib): + """Removes all the items from the library that have the same + path as an item from this task. + """ + for item in self.imported_items(): + for dup_item in self.replaced_items[item]: + log.debug(u'Replacing item {0}: {1}' + .format(dup_item.id, + displayable_path(item.path))) + dup_item.remove() + log.debug(u'{0} of {1} items replaced' + .format(sum(bool(l) for l in self.replaced_items.values()), + len(self.imported_items()))) + + def choose_match(self, session): + """Ask the session which match should apply and apply it. + """ + choice = session.choose_match(self) + self.set_choice(choice) + session.log_choice(self) + + def reload(self): + """Reload albums and items from the database. + """ + for item in self.imported_items(): + item.load() + self.album.load() + + # Utilities. + + def prune(self, filename): + """Prune any empty directories above the given file. If this + task has no `toppath` or the file path provided is not within + the `toppath`, then this function has no effect. Similarly, if + the file still exists, no pruning is performed, so it's safe to + call when the file in question may not have been removed. + """ + if self.toppath and not os.path.exists(filename): + util.prune_dirs(os.path.dirname(filename), + self.toppath, + clutter=config['clutter'].as_str_seq()) + + +class SingletonImportTask(ImportTask): + """ImportTask for a single track that is not associated to an album. + """ + + def __init__(self, toppath, item): + super(SingletonImportTask, self).__init__(toppath, [item.path]) + self.item = item + self.is_album = False + self.paths = [item.path] + + def chosen_ident(self): + assert self.choice_flag in (action.ASIS, action.APPLY) + if self.choice_flag is action.ASIS: + return (self.item.artist, self.item.title) + elif self.choice_flag is action.APPLY: + return (self.match.info.artist, self.match.info.title) + + def imported_items(self): + return [self.item] + + def apply_metadata(self): + autotag.apply_item_metadata(self.item, self.match.info) + + def _emit_imported(self, lib): + # FIXME This shouldn't be here. Skipped tasks should be removed from + # the pipeline. + if self.skip: + return + for item in self.imported_items(): + plugins.send('item_imported', lib=lib, item=item) + + def lookup_candidates(self): + candidates, recommendation = autotag.tag_item(self.item) + self.candidates = candidates + self.rec = recommendation + + def find_duplicates(self, lib): + """Return a list of items from `lib` that have the same artist + and title as the task. + """ + artist, title = self.chosen_ident() + + found_items = [] + query = dbcore.AndQuery(( + dbcore.MatchQuery('artist', artist), + dbcore.MatchQuery('title', title), + )) + for other_item in lib.items(query): + # Existing items not considered duplicates. + if other_item.path != self.item.path: + found_items.append(other_item) + return found_items + + duplicate_items = find_duplicates + + def add(self, lib): + with lib.transaction(): + self.record_replaced(lib) + self.remove_replaced(lib) + lib.add(self.item) + self.reimport_metadata(lib) + + def infer_album_fields(self): + raise NotImplementedError + + def choose_match(self, session): + """Ask the session which match should apply and apply it. + """ + choice = session.choose_item(self) + self.set_choice(choice) + session.log_choice(self) + + def reload(self): + self.item.load() + + +# FIXME The inheritance relationships are inverted. This is why there +# are so many methods which pass. We should introduce a new +# BaseImportTask class. +class SentinelImportTask(ImportTask): + """This class marks the progress of an import and does not import + any items itself. + + If only `toppath` is set the task indicats the end of a top-level + directory import. If the `paths` argument is givent, too, the task + indicates the progress in the `toppath` import. + """ + + def __init__(self, toppath=None, paths=None): + self.toppath = toppath + self.paths = paths + # TODO Remove the remaining attributes eventually + self.items = None + self.should_remove_duplicates = False + self.is_album = True + self.choice_flag = None + + def save_history(self): + pass + + def save_progress(self): + if self.paths is None: + # "Done" sentinel. + progress_reset(self.toppath) + else: + # "Directory progress" sentinel for singletons + progress_add(self.toppath, *self.paths) + + def skip(self): + return True + + def set_choice(self, choice): + raise NotImplementedError + + def cleanup(self, **kwargs): + pass + + def _emit_imported(self, session): + pass + + +class ArchiveImportTask(SentinelImportTask): + """Additional methods for handling archives. + + Use when `toppath` points to a `zip`, `tar`, or `rar` archive. + """ + + def __init__(self, toppath): + super(ArchiveImportTask, self).__init__(toppath) + self.extracted = False + + @classmethod + def is_archive(cls, path): + """Returns true if the given path points to an archive that can + be handled. + """ + if not os.path.isfile(path): + return False + + for path_test, _ in cls.handlers(): + if path_test(path): + return True + return False + + @classmethod + def handlers(cls): + """Returns a list of archive handlers. + + Each handler is a `(path_test, ArchiveClass)` tuple. `path_test` + is a function that returns `True` if the given path can be + handled by `ArchiveClass`. `ArchiveClass` is a class that + implements the same interface as `tarfile.TarFile`. + """ + if not hasattr(cls, '_handlers'): + cls._handlers = [] + from zipfile import is_zipfile, ZipFile + cls._handlers.append((is_zipfile, ZipFile)) + from tarfile import is_tarfile, TarFile + cls._handlers.append((is_tarfile, TarFile)) + try: + from rarfile import is_rarfile, RarFile + except ImportError: + pass + else: + cls._handlers.append((is_rarfile, RarFile)) + + return cls._handlers + + def cleanup(self, **kwargs): + """Removes the temporary directory the archive was extracted to. + """ + if self.extracted: + shutil.rmtree(self.toppath) + + def extract(self): + """Extracts the archive to a temporary directory and sets + `toppath` to that directory. + """ + for path_test, handler_class in self.handlers(): + if path_test(self.toppath): + break + + try: + extract_to = mkdtemp() + archive = handler_class(self.toppath, mode='r') + archive.extractall(extract_to) + finally: + archive.close() + self.extracted = True + self.toppath = extract_to + + +class ImportTaskFactory(object): + """Create album and singleton import tasks for all media files in a + directory or path. + + Depending on the session's 'flat' and 'singleton' configuration, it + groups all media files contained in `toppath` into singleton or + album import tasks. + """ + def __init__(self, toppath, session): + self.toppath = toppath + self.session = session + self.skipped = 0 + + def tasks(self): + """Yield all import tasks for `self.toppath`. + + The behavior is configured by the session's 'flat', and + 'singleton' flags. + """ + for dirs, paths in self.paths(): + if self.session.config['singletons']: + for path in paths: + task = self.singleton(path) + if task: + yield task + yield self.sentinel(dirs) + + else: + task = self.album(paths, dirs) + if task: + yield task + + def paths(self): + """Walk `self.toppath` and yield pairs of directory lists and + path lists. + """ + if not os.path.isdir(syspath(self.toppath)): + yield ([self.toppath], [self.toppath]) + elif self.session.config['flat']: + paths = [] + for dirs, paths_in_dir in albums_in_dir(self.toppath): + paths += paths_in_dir + yield ([self.toppath], paths) + else: + for dirs, paths in albums_in_dir(self.toppath): + yield (dirs, paths) + + def singleton(self, path): + if self.session.already_imported(self.toppath, [path]): + log.debug(u'Skipping previously-imported path: {0}' + .format(displayable_path(path))) + self.skipped += 1 + return None + + item = self.read_item(path) + if item: + return SingletonImportTask(self.toppath, item) + else: + return None + + def album(self, paths, dirs=None): + """Return `ImportTask` with all media files from paths. + + `dirs` is a list of parent directories used to record already + imported albums. + """ + if not paths: + return None + + if dirs is None: + dirs = list(set(os.path.dirname(p) for p in paths)) + + if self.session.already_imported(self.toppath, dirs): + log.debug(u'Skipping previously-imported path: {0}' + .format(displayable_path(dirs))) + self.skipped += 1 + return None + + items = map(self.read_item, paths) + items = [item for item in items if item] + + if items: + return ImportTask(self.toppath, dirs, items) + else: + return None + + def sentinel(self, paths=None): + return SentinelImportTask(self.toppath, paths) + + def read_item(self, path): + """Return an item created from the path. + + If an item could not be read it returns None and logs an error. + """ + # TODO remove this method. Should be handled in ImportTask creation. + try: + return library.Item.from_path(path) + except library.ReadError as exc: + if isinstance(exc.reason, mediafile.FileTypeError): + # Silently ignore non-music files. + pass + elif isinstance(exc.reason, mediafile.UnreadableFileError): + log.warn(u'unreadable file: {0}'.format( + displayable_path(path)) + ) + else: + log.error(u'error reading {0}: {1}'.format( + displayable_path(path), + exc, + )) + + +# Full-album pipeline stages. + +def read_tasks(session): + """A generator yielding all the albums (as ImportTask objects) found + in the user-specified list of paths. In the case of a singleton + import, yields single-item tasks instead. + """ + skipped = 0 + for toppath in session.paths: + # Determine if we want to resume import of the toppath + session.ask_resume(toppath) + user_toppath = toppath + + # Extract archives. + archive_task = None + if ArchiveImportTask.is_archive(syspath(toppath)): + if not (session.config['move'] or session.config['copy']): + log.warn(u"Archive importing requires either " + "'copy' or 'move' to be enabled.") + continue + + log.debug(u'extracting archive {0}' + .format(displayable_path(toppath))) + archive_task = ArchiveImportTask(toppath) + try: + archive_task.extract() + except Exception as exc: + log.error(u'extraction failed: {0}'.format(exc)) + continue + + # Continue reading albums from the extracted directory. + toppath = archive_task.toppath + + task_factory = ImportTaskFactory(toppath, session) + imported = False + for t in task_factory.tasks(): + imported |= not t.skip + yield t + + # Indicate the directory is finished. + # FIXME hack to delete extracted archives + if archive_task is None: + yield task_factory.sentinel() + else: + yield archive_task + + if not imported: + log.warn(u'No files imported from {0}' + .format(displayable_path(user_toppath))) + + # Show skipped directories. + if skipped: + log.info(u'Skipped {0} directories.'.format(skipped)) + + +def query_tasks(session): + """A generator that works as a drop-in-replacement for read_tasks. + Instead of finding files from the filesystem, a query is used to + match items from the library. + """ + if session.config['singletons']: + # Search for items. + for item in session.lib.items(session.query): + yield SingletonImportTask(None, item) + + else: + # Search for albums. + for album in session.lib.albums(session.query): + log.debug(u'yielding album {0}: {1} - {2}' + .format(album.id, album.albumartist, album.album)) + items = list(album.items()) + + # Clear IDs from re-tagged items so they appear "fresh" when + # we add them back to the library. + for item in items: + item.id = None + item.album_id = None + + yield ImportTask(None, [album.item_dir()], items) + + +@pipeline.mutator_stage +def lookup_candidates(session, task): + """A coroutine for performing the initial MusicBrainz lookup for an + album. It accepts lists of Items and yields + (items, cur_artist, cur_album, candidates, rec) tuples. If no match + is found, all of the yielded parameters (except items) are None. + """ + if task.skip: + # FIXME This gets duplicated a lot. We need a better + # abstraction. + return + + plugins.send('import_task_start', session=session, task=task) + log.debug(u'Looking up: {0}'.format(displayable_path(task.paths))) + task.lookup_candidates() + + +@pipeline.stage +def user_query(session, task): + """A coroutine for interfacing with the user about the tagging + process. + + The coroutine accepts an ImportTask objects. It uses the + session's `choose_match` method to determine the `action` for + this task. Depending on the action additional stages are exectuted + and the processed task is yielded. + + It emits the ``import_task_choice`` event for plugins. Plugins have + acces to the choice via the ``taks.choice_flag`` property and may + choose to change it. + """ + if task.skip: + return task + + # Ask the user for a choice. + task.choose_match(session) + plugins.send('import_task_choice', session=session, task=task) + + # As-tracks: transition to singleton workflow. + if task.choice_flag is action.TRACKS: + # Set up a little pipeline for dealing with the singletons. + def emitter(task): + for item in task.items: + yield SingletonImportTask(task.toppath, item) + yield SentinelImportTask(task.toppath, task.paths) + + ipl = pipeline.Pipeline([ + emitter(task), + lookup_candidates(session), + user_query(session), + ]) + return pipeline.multiple(ipl.pull()) + + # As albums: group items by albums and create task for each album + if task.choice_flag is action.ALBUMS: + ipl = pipeline.Pipeline([ + iter([task]), + group_albums(session), + lookup_candidates(session), + user_query(session) + ]) + return pipeline.multiple(ipl.pull()) + + resolve_duplicates(session, task) + return task + + +def resolve_duplicates(session, task): + """Check if a task conflicts with items or albums already imported + and ask the session to resolve this. + """ + if task.choice_flag in (action.ASIS, action.APPLY): + ident = task.chosen_ident() + found_duplicates = task.find_duplicates(session.lib) + if ident in session.seen_idents or found_duplicates: + session.resolve_duplicate(task, found_duplicates) + session.log_choice(task, True) + session.seen_idents.add(ident) + + +@pipeline.mutator_stage +def import_asis(session, task): + """Select the `action.ASIS` choice for all tasks. + + This stage replaces the initial_lookup and user_query stages + when the importer is run without autotagging. + """ + if task.skip: + return + + log.info(displayable_path(task.paths)) + task.set_choice(action.ASIS) + + +@pipeline.mutator_stage +def apply_choices(session, task): + """A coroutine for applying changes to albums and singletons during + the autotag process. + """ + if task.skip: + return + + # Change metadata. + if task.apply: + task.apply_metadata() + plugins.send('import_task_apply', session=session, task=task) + + task.add(session.lib) + + +@pipeline.mutator_stage +def plugin_stage(session, func, task): + """A coroutine (pipeline stage) that calls the given function with + each non-skipped import task. These stages occur between applying + metadata changes and moving/copying/writing files. + """ + if task.skip: + return + + func(session, task) + + # Stage may modify DB, so re-load cached item data. + # FIXME Importer plugins should not modify the database but instead + # the albums and items attached to tasks. + task.reload() + + +@pipeline.stage +def manipulate_files(session, task): + """A coroutine (pipeline stage) that performs necessary file + manipulations *after* items have been added to the library and + finalizes each task. + """ + if not task.skip: + if task.should_remove_duplicates: + task.remove_duplicates(session.lib) + + task.manipulate_files( + move=session.config['move'], + copy=session.config['copy'], + write=session.config['write'], + link=session.config['link'], + session=session, + ) + + # Progress, cleanup, and event. + task.finalize(session) + + +@pipeline.stage +def log_files(session, task): + """A coroutine (pipeline stage) to log each file which will be imported + """ + if isinstance(task, SingletonImportTask): + log.info( + 'Singleton: {0}'.format(displayable_path(task.item['path']))) + elif task.items: + log.info('Album {0}'.format(displayable_path(task.paths[0]))) + for item in task.items: + log.info(' {0}'.format(displayable_path(item['path']))) + + +def group_albums(session): + """Group the items of a task by albumartist and album name and create a new + task for each album. Yield the tasks as a multi message. + """ + def group(item): + return (item.albumartist or item.artist, item.album) + + task = None + while True: + task = yield task + if task.skip: + continue + tasks = [] + for _, items in itertools.groupby(task.items, group): + tasks.append(ImportTask(items=list(items))) + tasks.append(SentinelImportTask(task.toppath, task.paths)) + + task = pipeline.multiple(tasks) + + +MULTIDISC_MARKERS = (r'dis[ck]', r'cd') +MULTIDISC_PAT_FMT = r'^(.*%s[\W_]*)\d' + + +def albums_in_dir(path): + """Recursively searches the given directory and returns an iterable + of (paths, items) where paths is a list of directories and items is + a list of Items that is probably an album. Specifically, any folder + containing any media files is an album. + """ + collapse_pat = collapse_paths = collapse_items = None + ignore = config['ignore'].as_str_seq() + + for root, dirs, files in sorted_walk(path, ignore=ignore, logger=log): + items = [os.path.join(root, f) for f in files] + # If we're currently collapsing the constituent directories in a + # multi-disc album, check whether we should continue collapsing + # and add the current directory. If so, just add the directory + # and move on to the next directory. If not, stop collapsing. + if collapse_paths: + if (not collapse_pat and collapse_paths[0] in ancestry(root)) or \ + (collapse_pat and + collapse_pat.match(os.path.basename(root))): + # Still collapsing. + collapse_paths.append(root) + collapse_items += items + continue + else: + # Collapse finished. Yield the collapsed directory and + # proceed to process the current one. + if collapse_items: + yield collapse_paths, collapse_items + collapse_pat = collapse_paths = collapse_items = None + + # Check whether this directory looks like the *first* directory + # in a multi-disc sequence. There are two indicators: the file + # is named like part of a multi-disc sequence (e.g., "Title Disc + # 1") or it contains no items but only directories that are + # named in this way. + start_collapsing = False + for marker in MULTIDISC_MARKERS: + marker_pat = re.compile(MULTIDISC_PAT_FMT % marker, re.I) + match = marker_pat.match(os.path.basename(root)) + + # Is this directory the root of a nested multi-disc album? + if dirs and not items: + # Check whether all subdirectories have the same prefix. + start_collapsing = True + subdir_pat = None + for subdir in dirs: + # The first directory dictates the pattern for + # the remaining directories. + if not subdir_pat: + match = marker_pat.match(subdir) + if match: + subdir_pat = re.compile( + r'^%s\d' % re.escape(match.group(1)), re.I + ) + else: + start_collapsing = False + break + + # Subsequent directories must match the pattern. + elif not subdir_pat.match(subdir): + start_collapsing = False + break + + # If all subdirectories match, don't check other + # markers. + if start_collapsing: + break + + # Is this directory the first in a flattened multi-disc album? + elif match: + start_collapsing = True + # Set the current pattern to match directories with the same + # prefix as this one, followed by a digit. + collapse_pat = re.compile( + r'^%s\d' % re.escape(match.group(1)), re.I + ) + break + + # If either of the above heuristics indicated that this is the + # beginning of a multi-disc album, initialize the collapsed + # directory and item lists and check the next directory. + if start_collapsing: + # Start collapsing; continue to the next iteration. + collapse_paths = [root] + collapse_items = items + continue + + # If it's nonempty, yield it. + if items: + yield [root], items + + # Clear out any unfinished collapse. + if collapse_paths and collapse_items: + yield collapse_paths, collapse_items diff --git a/lib/beets/library.py b/lib/beets/library.py new file mode 100644 index 00000000..1de1bba5 --- /dev/null +++ b/lib/beets/library.py @@ -0,0 +1,1299 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""The core data store and collection logic for beets. +""" +import os +import sys +import logging +import shlex +import unicodedata +import time +import re +from unidecode import unidecode +from beets.mediafile import MediaFile, MutagenError, UnreadableFileError +from beets import plugins +from beets import util +from beets.util import bytestring_path, syspath, normpath, samefile +from beets.util.functemplate import Template +from beets import dbcore +from beets.dbcore import types +import beets + + +log = logging.getLogger('beets') + + +# Library-specific query types. + +class PathQuery(dbcore.FieldQuery): + """A query that matches all items under a given path.""" + + escape_re = re.compile(r'[\\_%]') + escape_char = '\\' + + def __init__(self, field, pattern, fast=True): + super(PathQuery, self).__init__(field, pattern, fast) + + # Match the path as a single file. + self.file_path = util.bytestring_path(util.normpath(pattern)) + # As a directory (prefix). + self.dir_path = util.bytestring_path(os.path.join(self.file_path, '')) + + def match(self, item): + return (item.path == self.file_path) or \ + item.path.startswith(self.dir_path) + + def clause(self): + escape = lambda m: self.escape_char + m.group(0) + dir_pattern = self.escape_re.sub(escape, self.dir_path) + dir_pattern = buffer(dir_pattern + '%') + file_blob = buffer(self.file_path) + return '({0} = ?) || ({0} LIKE ? ESCAPE ?)'.format(self.field), \ + (file_blob, dir_pattern, self.escape_char) + + +# Library-specific field types. + +class DateType(types.Float): + # TODO representation should be `datetime` object + # TODO distinguish beetween date and time types + query = dbcore.query.DateQuery + + def format(self, value): + return time.strftime(beets.config['time_format'].get(unicode), + time.localtime(value or 0)) + + def parse(self, string): + try: + # Try a formatted date string. + return time.mktime( + time.strptime(string, beets.config['time_format'].get(unicode)) + ) + except ValueError: + # Fall back to a plain timestamp number. + try: + return float(string) + except ValueError: + return self.null + + +class PathType(types.Type): + sql = u'BLOB' + query = PathQuery + model_type = bytes + + def format(self, value): + return util.displayable_path(value) + + def parse(self, string): + return normpath(bytestring_path(string)) + + def normalize(self, value): + if isinstance(value, unicode): + # Paths stored internally as encoded bytes. + return bytestring_path(value) + + elif isinstance(value, buffer): + # SQLite must store bytestings as buffers to avoid decoding. + # We unwrap buffers to bytes. + return bytes(value) + + else: + return value + + def from_sql(self, sql_value): + return self.normalize(sql_value) + + def to_sql(self, value): + if isinstance(value, str): + value = buffer(value) + return value + + +class MusicalKey(types.String): + """String representing the musical key of a song. + + The standard format is C, Cm, C#, C#m, etc. + """ + ENHARMONIC = { + r'db': 'c#', + r'eb': 'd#', + r'gb': 'f#', + r'ab': 'g#', + r'bb': 'a#', + } + + def parse(self, key): + key = key.lower() + for flat, sharp in self.ENHARMONIC.items(): + key = re.sub(flat, sharp, key) + key = re.sub(r'[\W\s]+minor', 'm', key) + return key.capitalize() + + def normalize(self, key): + if key is None: + return None + else: + return self.parse(key) + + +# Library-specific sort types. + +class SmartArtistSort(dbcore.query.Sort): + """Sort by artist (either album artist or track artist), + prioritizing the sort field over the raw field. + """ + def __init__(self, model_cls, ascending=True): + self.album = model_cls is Album + self.ascending = ascending + + def order_clause(self): + order = "ASC" if self.ascending else "DESC" + if self.album: + field = 'albumartist' + else: + field = 'artist' + return ('(CASE {0}_sort WHEN NULL THEN {0} ' + 'WHEN "" THEN {0} ' + 'ELSE {0}_sort END) {1}').format(field, order) + + def sort(self, objs): + if self.album: + key = lambda a: a.albumartist_sort or a.albumartist + else: + key = lambda i: i.artist_sort or i.artist + return sorted(objs, key=key, reverse=not self.ascending) + + +# Special path format key. +PF_KEY_DEFAULT = 'default' + + +# Exceptions. + +class FileOperationError(Exception): + """Indicates an error when interacting with a file on disk. + Possibilities include an unsupported media type, a permissions + error, and an unhandled Mutagen exception. + """ + def __init__(self, path, reason): + """Create an exception describing an operation on the file at + `path` with the underlying (chained) exception `reason`. + """ + super(FileOperationError, self).__init__(path, reason) + self.path = path + self.reason = reason + + def __unicode__(self): + """Get a string representing the error. Describes both the + underlying reason and the file path in question. + """ + return u'{0}: {1}'.format( + util.displayable_path(self.path), + unicode(self.reason) + ) + + def __str__(self): + return unicode(self).encode('utf8') + + +class ReadError(FileOperationError): + """An error while reading a file (i.e. in `Item.read`). + """ + def __unicode__(self): + return u'error reading ' + super(ReadError, self).__unicode__() + + +class WriteError(FileOperationError): + """An error while writing a file (i.e. in `Item.write`). + """ + def __unicode__(self): + return u'error writing ' + super(WriteError, self).__unicode__() + + +# Item and Album model classes. + +class LibModel(dbcore.Model): + """Shared concrete functionality for Items and Albums. + """ + + def _template_funcs(self): + funcs = DefaultTemplateFunctions(self, self._db).functions() + funcs.update(plugins.template_funcs()) + return funcs + + def store(self): + super(LibModel, self).store() + plugins.send('database_change', lib=self._db) + + def remove(self): + super(LibModel, self).remove() + plugins.send('database_change', lib=self._db) + + def add(self, lib=None): + super(LibModel, self).add(lib) + plugins.send('database_change', lib=self._db) + + +class FormattedItemMapping(dbcore.db.FormattedMapping): + """Add lookup for album-level fields. + + Album-level fields take precedence if `for_path` is true. + """ + + def __init__(self, item, for_path=False): + super(FormattedItemMapping, self).__init__(item, for_path) + self.album = item.get_album() + self.album_keys = [] + if self.album: + for key in self.album.keys(True): + if key in Album.item_keys or key not in item._fields.keys(): + self.album_keys.append(key) + self.all_keys = set(self.model_keys).union(self.album_keys) + + def _get(self, key): + """Get the value for a key, either from the album or the item. + Raise a KeyError for invalid keys. + """ + if self.for_path and key in self.album_keys: + return self._get_formatted(self.album, key) + elif key in self.model_keys: + return self._get_formatted(self.model, key) + elif key in self.album_keys: + return self._get_formatted(self.album, key) + else: + raise KeyError(key) + + def __getitem__(self, key): + """Get the value for a key. Certain unset values are remapped. + """ + value = self._get(key) + + # `artist` and `albumartist` fields fall back to one another. + # This is helpful in path formats when the album artist is unset + # on as-is imports. + if key == 'artist' and not value: + return self._get('albumartist') + elif key == 'albumartist' and not value: + return self._get('artist') + else: + return value + + def __iter__(self): + return iter(self.all_keys) + + def __len__(self): + return len(self.all_keys) + + +class Item(LibModel): + _table = 'items' + _flex_table = 'item_attributes' + _fields = { + 'id': types.PRIMARY_ID, + 'path': PathType(), + 'album_id': types.FOREIGN_ID, + + 'title': types.STRING, + 'artist': types.STRING, + 'artist_sort': types.STRING, + 'artist_credit': types.STRING, + 'album': types.STRING, + 'albumartist': types.STRING, + 'albumartist_sort': types.STRING, + 'albumartist_credit': types.STRING, + 'genre': types.STRING, + 'composer': types.STRING, + 'grouping': types.STRING, + 'year': types.PaddedInt(4), + 'month': types.PaddedInt(2), + 'day': types.PaddedInt(2), + 'track': types.PaddedInt(2), + 'tracktotal': types.PaddedInt(2), + 'disc': types.PaddedInt(2), + 'disctotal': types.PaddedInt(2), + 'lyrics': types.STRING, + 'comments': types.STRING, + 'bpm': types.INTEGER, + 'comp': types.BOOLEAN, + 'mb_trackid': types.STRING, + 'mb_albumid': types.STRING, + 'mb_artistid': types.STRING, + 'mb_albumartistid': types.STRING, + 'albumtype': types.STRING, + 'label': types.STRING, + 'acoustid_fingerprint': types.STRING, + 'acoustid_id': types.STRING, + 'mb_releasegroupid': types.STRING, + 'asin': types.STRING, + 'catalognum': types.STRING, + 'script': types.STRING, + 'language': types.STRING, + 'country': types.STRING, + 'albumstatus': types.STRING, + 'media': types.STRING, + 'albumdisambig': types.STRING, + 'disctitle': types.STRING, + 'encoder': types.STRING, + 'rg_track_gain': types.NULL_FLOAT, + 'rg_track_peak': types.NULL_FLOAT, + 'rg_album_gain': types.NULL_FLOAT, + 'rg_album_peak': types.NULL_FLOAT, + 'original_year': types.PaddedInt(4), + 'original_month': types.PaddedInt(2), + 'original_day': types.PaddedInt(2), + 'initial_key': MusicalKey(), + + 'length': types.FLOAT, + 'bitrate': types.ScaledInt(1000, u'kbps'), + 'format': types.STRING, + 'samplerate': types.ScaledInt(1000, u'kHz'), + 'bitdepth': types.INTEGER, + 'channels': types.INTEGER, + 'mtime': DateType(), + 'added': DateType(), + } + + _search_fields = ('artist', 'title', 'comments', + 'album', 'albumartist', 'genre') + + _media_fields = set(MediaFile.readable_fields()) \ + .intersection(_fields.keys()) + """Set of item fields that are backed by `MediaFile` fields. + + Any kind of field (fixed, flexible, and computed) may be a media + field. Only these fields are read from disk in `read` and written in + `write`. + """ + + _formatter = FormattedItemMapping + + _sorts = {'artist': SmartArtistSort} + + @classmethod + def _getters(cls): + getters = plugins.item_field_getters() + getters['singleton'] = lambda i: i.album_id is None + return getters + + @classmethod + def from_path(cls, path): + """Creates a new item from the media file at the specified path. + """ + # Initiate with values that aren't read from files. + i = cls(album_id=None) + i.read(path) + i.mtime = i.current_mtime() # Initial mtime. + return i + + def __setitem__(self, key, value): + """Set the item's value for a standard field or a flexattr. + """ + # Encode unicode paths and read buffers. + if key == 'path': + if isinstance(value, unicode): + value = bytestring_path(value) + elif isinstance(value, buffer): + value = str(value) + + if key in MediaFile.fields(): + self.mtime = 0 # Reset mtime on dirty. + + super(Item, self).__setitem__(key, value) + + def update(self, values): + """Set all key/value pairs in the mapping. If mtime is + specified, it is not reset (as it might otherwise be). + """ + super(Item, self).update(values) + if self.mtime == 0 and 'mtime' in values: + self.mtime = values['mtime'] + + def get_album(self): + """Get the Album object that this item belongs to, if any, or + None if the item is a singleton or is not associated with a + library. + """ + if not self._db: + return None + return self._db.get_album(self) + + # Interaction with file metadata. + + def read(self, read_path=None): + """Read the metadata from the associated file. + + If `read_path` is specified, read metadata from that file + instead. Updates all the properties in `_media_fields` + from the media file. + + Raises a `ReadError` if the file could not be read. + """ + if read_path is None: + read_path = self.path + else: + read_path = normpath(read_path) + try: + mediafile = MediaFile(syspath(read_path)) + except (OSError, IOError, UnreadableFileError) as exc: + raise ReadError(read_path, exc) + + for key in self._media_fields: + value = getattr(mediafile, key) + if isinstance(value, (int, long)): + # Filter values wider than 64 bits (in signed representation). + # SQLite cannot store them. py26: Post transition, we can use: + # value.bit_length() > 63 + if abs(value) >= 2 ** 63: + value = 0 + self[key] = value + + # Database's mtime should now reflect the on-disk value. + if read_path == self.path: + self.mtime = self.current_mtime() + + self.path = read_path + + def write(self, path=None): + """Write the item's metadata to a media file. + + All fields in `_media_fields` are written to disk according to + the values on this object. + + Can raise either a `ReadError` or a `WriteError`. + """ + if path is None: + path = self.path + else: + path = normpath(path) + + tags = dict(self) + plugins.send('write', item=self, path=path, tags=tags) + + try: + mediafile = MediaFile(syspath(path), + id3v23=beets.config['id3v23'].get(bool)) + except (OSError, IOError, UnreadableFileError) as exc: + raise ReadError(self.path, exc) + + mediafile.update(tags) + try: + mediafile.save() + except (OSError, IOError, MutagenError) as exc: + raise WriteError(self.path, exc) + + # The file has a new mtime. + if path == self.path: + self.mtime = self.current_mtime() + plugins.send('after_write', item=self, path=path) + + def try_write(self, path=None): + """Calls `write()` but catches and logs `FileOperationError` + exceptions. + + Returns `False` an exception was caught and `True` otherwise. + """ + try: + self.write(path) + return True + except FileOperationError as exc: + log.error(exc) + return False + + def try_sync(self, write=None): + """Synchronize the item with the database and the media file + tags, updating them with this object's current state. + + By default, the current `path` for the item is used to write + tags. If `write` is `False`, no tags are written. If `write` is + a path, tags are written to that file instead. + + Similar to calling :meth:`write` and :meth:`store`. + """ + if write is True: + write = None + if write is not False: + self.try_write(path=write) + self.store() + + # Files themselves. + + def move_file(self, dest, copy=False, link=False): + """Moves or copies the item's file, updating the path value if + the move succeeds. If a file exists at ``dest``, then it is + slightly modified to be unique. + """ + if not util.samefile(self.path, dest): + dest = util.unique_path(dest) + if copy: + util.copy(self.path, dest) + plugins.send("item_copied", item=self, source=self.path, + destination=dest) + elif link: + util.link(self.path, dest) + plugins.send("item_linked", item=self, source=self.path, + destination=dest) + else: + plugins.send("before_item_moved", item=self, source=self.path, + destination=dest) + util.move(self.path, dest) + plugins.send("item_moved", item=self, source=self.path, + destination=dest) + + # Either copying or moving succeeded, so update the stored path. + self.path = dest + + def current_mtime(self): + """Returns the current mtime of the file, rounded to the nearest + integer. + """ + return int(os.path.getmtime(syspath(self.path))) + + # Model methods. + + def remove(self, delete=False, with_album=True): + """Removes the item. If `delete`, then the associated file is + removed from disk. If `with_album`, then the item's album (if + any) is removed if it the item was the last in the album. + """ + super(Item, self).remove() + + # Remove the album if it is empty. + if with_album: + album = self.get_album() + if album and not album.items(): + album.remove(delete, False) + + # Send a 'item_removed' signal to plugins + plugins.send('item_removed', item=self) + + # Delete the associated file. + if delete: + util.remove(self.path) + util.prune_dirs(os.path.dirname(self.path), self._db.directory) + + self._db._memotable = {} + + def move(self, copy=False, link=False, basedir=None, with_album=True): + """Move the item to its designated location within the library + directory (provided by destination()). Subdirectories are + created as needed. If the operation succeeds, the item's path + field is updated to reflect the new location. + + If `copy` is true, moving the file is copied rather than moved. + Similarly, `link` creates a symlink instead. + + basedir overrides the library base directory for the + destination. + + If the item is in an album, the album is given an opportunity to + move its art. (This can be disabled by passing + with_album=False.) + + The item is stored to the database if it is in the database, so + any dirty fields prior to the move() call will be written as a + side effect. You probably want to call save() to commit the DB + transaction. + """ + self._check_db() + dest = self.destination(basedir=basedir) + + # Create necessary ancestry for the move. + util.mkdirall(dest) + + # Perform the move and store the change. + old_path = self.path + self.move_file(dest, copy, link) + self.store() + + # If this item is in an album, move its art. + if with_album: + album = self.get_album() + if album: + album.move_art(copy) + album.store() + + # Prune vacated directory. + if not copy: + util.prune_dirs(os.path.dirname(old_path), self._db.directory) + + # Templating. + + def destination(self, fragment=False, basedir=None, platform=None, + path_formats=None): + """Returns the path in the library directory designated for the + item (i.e., where the file ought to be). fragment makes this + method return just the path fragment underneath the root library + directory; the path is also returned as Unicode instead of + encoded as a bytestring. basedir can override the library's base + directory for the destination. + """ + self._check_db() + platform = platform or sys.platform + basedir = basedir or self._db.directory + path_formats = path_formats or self._db.path_formats + + # Use a path format based on a query, falling back on the + # default. + for query, path_format in path_formats: + if query == PF_KEY_DEFAULT: + continue + query, _ = parse_query_string(query, type(self)) + if query.match(self): + # The query matches the item! Use the corresponding path + # format. + break + else: + # No query matched; fall back to default. + for query, path_format in path_formats: + if query == PF_KEY_DEFAULT: + break + else: + assert False, "no default path format" + if isinstance(path_format, Template): + subpath_tmpl = path_format + else: + subpath_tmpl = Template(path_format) + + # Evaluate the selected template. + subpath = self.evaluate_template(subpath_tmpl, True) + + # Prepare path for output: normalize Unicode characters. + if platform == 'darwin': + subpath = unicodedata.normalize('NFD', subpath) + else: + subpath = unicodedata.normalize('NFC', subpath) + + if beets.config['asciify_paths']: + subpath = unidecode(subpath) + + # Truncate components and remove forbidden characters. + subpath = util.sanitize_path(subpath, self._db.replacements) + + # Encode for the filesystem. + if not fragment: + subpath = bytestring_path(subpath) + + # Preserve extension. + _, extension = os.path.splitext(self.path) + if fragment: + # Outputting Unicode. + extension = extension.decode('utf8', 'ignore') + subpath += extension.lower() + + # Truncate too-long components. + maxlen = beets.config['max_filename_length'].get(int) + if not maxlen: + # When zero, try to determine from filesystem. + maxlen = util.max_filename_length(self._db.directory) + subpath = util.truncate_path(subpath, maxlen) + + if fragment: + return subpath + else: + return normpath(os.path.join(basedir, subpath)) + + +class Album(LibModel): + """Provides access to information about albums stored in a + library. Reflects the library's "albums" table, including album + art. + """ + _table = 'albums' + _flex_table = 'album_attributes' + _always_dirty = True + _fields = { + 'id': types.PRIMARY_ID, + 'artpath': PathType(), + 'added': DateType(), + + 'albumartist': types.STRING, + 'albumartist_sort': types.STRING, + 'albumartist_credit': types.STRING, + 'album': types.STRING, + 'genre': types.STRING, + 'year': types.PaddedInt(4), + 'month': types.PaddedInt(2), + 'day': types.PaddedInt(2), + 'tracktotal': types.PaddedInt(2), + 'disctotal': types.PaddedInt(2), + 'comp': types.BOOLEAN, + 'mb_albumid': types.STRING, + 'mb_albumartistid': types.STRING, + 'albumtype': types.STRING, + 'label': types.STRING, + 'mb_releasegroupid': types.STRING, + 'asin': types.STRING, + 'catalognum': types.STRING, + 'script': types.STRING, + 'language': types.STRING, + 'country': types.STRING, + 'albumstatus': types.STRING, + 'albumdisambig': types.STRING, + 'rg_album_gain': types.NULL_FLOAT, + 'rg_album_peak': types.NULL_FLOAT, + 'original_year': types.PaddedInt(4), + 'original_month': types.PaddedInt(2), + 'original_day': types.PaddedInt(2), + } + + _search_fields = ('album', 'albumartist', 'genre') + + _sorts = { + 'albumartist': SmartArtistSort, + 'artist': SmartArtistSort, + } + + item_keys = [ + 'added', + 'albumartist', + 'albumartist_sort', + 'albumartist_credit', + 'album', + 'genre', + 'year', + 'month', + 'day', + 'tracktotal', + 'disctotal', + 'comp', + 'mb_albumid', + 'mb_albumartistid', + 'albumtype', + 'label', + 'mb_releasegroupid', + 'asin', + 'catalognum', + 'script', + 'language', + 'country', + 'albumstatus', + 'albumdisambig', + 'rg_album_gain', + 'rg_album_peak', + 'original_year', + 'original_month', + 'original_day', + ] + """List of keys that are set on an album's items. + """ + + @classmethod + def _getters(cls): + # In addition to plugin-provided computed fields, also expose + # the album's directory as `path`. + getters = plugins.album_field_getters() + getters['path'] = Album.item_dir + return getters + + def items(self): + """Returns an iterable over the items associated with this + album. + """ + return self._db.items(dbcore.MatchQuery('album_id', self.id)) + + def remove(self, delete=False, with_items=True): + """Removes this album and all its associated items from the + library. If delete, then the items' files are also deleted + from disk, along with any album art. The directories + containing the album are also removed (recursively) if empty. + Set with_items to False to avoid removing the album's items. + """ + super(Album, self).remove() + + # Delete art file. + if delete: + artpath = self.artpath + if artpath: + util.remove(artpath) + + # Remove (and possibly delete) the constituent items. + if with_items: + for item in self.items(): + item.remove(delete, False) + + def move_art(self, copy=False, link=False): + """Move or copy any existing album art so that it remains in the + same directory as the items. + """ + old_art = self.artpath + if not old_art: + return + + new_art = self.art_destination(old_art) + if new_art == old_art: + return + + new_art = util.unique_path(new_art) + log.debug(u'moving album art {0} to {1}' + .format(util.displayable_path(old_art), + util.displayable_path(new_art))) + if copy: + util.copy(old_art, new_art) + elif link: + util.link(old_art, new_art) + else: + util.move(old_art, new_art) + self.artpath = new_art + + # Prune old path when moving. + if not copy: + util.prune_dirs(os.path.dirname(old_art), + self._db.directory) + + def move(self, copy=False, link=False, basedir=None): + """Moves (or copies) all items to their destination. Any album + art moves along with them. basedir overrides the library base + directory for the destination. The album is stored to the + database, persisting any modifications to its metadata. + """ + basedir = basedir or self._db.directory + + # Ensure new metadata is available to items for destination + # computation. + self.store() + + # Move items. + items = list(self.items()) + for item in items: + item.move(copy, link, basedir=basedir, with_album=False) + + # Move art. + self.move_art(copy, link) + self.store() + + def item_dir(self): + """Returns the directory containing the album's first item, + provided that such an item exists. + """ + item = self.items().get() + if not item: + raise ValueError('empty album') + return os.path.dirname(item.path) + + def art_destination(self, image, item_dir=None): + """Returns a path to the destination for the album art image + for the album. `image` is the path of the image that will be + moved there (used for its extension). + + The path construction uses the existing path of the album's + items, so the album must contain at least one item or + item_dir must be provided. + """ + image = bytestring_path(image) + item_dir = item_dir or self.item_dir() + + filename_tmpl = Template(beets.config['art_filename'].get(unicode)) + subpath = self.evaluate_template(filename_tmpl, True) + if beets.config['asciify_paths']: + subpath = unidecode(subpath) + subpath = util.sanitize_path(subpath, + replacements=self._db.replacements) + subpath = bytestring_path(subpath) + + _, ext = os.path.splitext(image) + dest = os.path.join(item_dir, subpath + ext) + + return bytestring_path(dest) + + def set_art(self, path, copy=True): + """Sets the album's cover art to the image at the given path. + The image is copied (or moved) into place, replacing any + existing art. + """ + path = bytestring_path(path) + oldart = self.artpath + artdest = self.art_destination(path) + + if oldart and samefile(path, oldart): + # Art already set. + return + elif samefile(path, artdest): + # Art already in place. + self.artpath = path + return + + # Normal operation. + if oldart == artdest: + util.remove(oldart) + artdest = util.unique_path(artdest) + if copy: + util.copy(path, artdest) + else: + util.move(path, artdest) + self.artpath = artdest + + def store(self): + """Update the database with the album information. The album's + tracks are also updated. + """ + # Get modified track fields. + track_updates = {} + for key in self.item_keys: + if key in self._dirty: + track_updates[key] = self[key] + + with self._db.transaction(): + super(Album, self).store() + if track_updates: + for item in self.items(): + for key, value in track_updates.items(): + item[key] = value + item.store() + + def try_sync(self, write=True): + """Synchronize the album and its items with the database and + their files by updating them with this object's current state. + + `write` indicates whether to write tags to the item files. + """ + self.store() + for item in self.items(): + item.try_sync(bool(write)) + + +# Query construction helpers. + +def parse_query_parts(parts, model_cls): + """Given a beets query string as a list of components, return the + `Query` and `Sort` they represent. + + Like `dbcore.parse_sorted_query`, with beets query prefixes and + special path query detection. + """ + # Get query types and their prefix characters. + prefixes = {':': dbcore.query.RegexpQuery} + prefixes.update(plugins.queries()) + + # Special-case path-like queries, which are non-field queries + # containing path separators (/). + if 'path' in model_cls._fields: + path_parts = [] + non_path_parts = [] + for s in parts: + if s.find(os.sep, 0, s.find(':')) != -1: + # Separator precedes colon. + path_parts.append(s) + else: + non_path_parts.append(s) + else: + path_parts = () + non_path_parts = parts + + query, sort = dbcore.parse_sorted_query( + model_cls, non_path_parts, prefixes + ) + + # Add path queries to aggregate query. + if path_parts: + query.subqueries += [PathQuery('path', s) for s in path_parts] + return query, sort + + +def parse_query_string(s, model_cls): + """Given a beets query string, return the `Query` and `Sort` they + represent. + + The string is split into components using shell-like syntax. + """ + # A bug in Python < 2.7.3 prevents correct shlex splitting of + # Unicode strings. + # http://bugs.python.org/issue6988 + if isinstance(s, unicode): + s = s.encode('utf8') + parts = [p.decode('utf8') for p in shlex.split(s)] + return parse_query_parts(parts, model_cls) + + +# The Library: interface to the database. + +class Library(dbcore.Database): + """A database of music containing songs and albums. + """ + _models = (Item, Album) + + def __init__(self, path='library.blb', + directory='~/Music', + path_formats=((PF_KEY_DEFAULT, + '$artist/$album/$track $title'),), + replacements=None): + if path != ':memory:': + self.path = bytestring_path(normpath(path)) + super(Library, self).__init__(path) + + self.directory = bytestring_path(normpath(directory)) + self.path_formats = path_formats + self.replacements = replacements + + self._memotable = {} # Used for template substitution performance. + + # Adding objects to the database. + + def add(self, obj): + """Add the :class:`Item` or :class:`Album` object to the library + database. Return the object's new id. + """ + obj.add(self) + self._memotable = {} + return obj.id + + def add_album(self, items): + """Create a new album consisting of a list of items. + + The items are added to the database if they don't yet have an + ID. Return a new :class:`Album` object. The list items must not + be empty. + """ + if not items: + raise ValueError(u'need at least one item') + + # Create the album structure using metadata from the first item. + values = dict((key, items[0][key]) for key in Album.item_keys) + album = Album(self, **values) + + # Add the album structure and set the items' album_id fields. + # Store or add the items. + with self.transaction(): + album.add(self) + for item in items: + item.album_id = album.id + if item.id is None: + item.add(self) + else: + item.store() + + return album + + # Querying. + + def _fetch(self, model_cls, query, sort=None): + """Parse a query and fetch. If a order specification is present + in the query string the `sort` argument is ignored. + """ + # Parse the query, if necessary. + parsed_sort = None + if isinstance(query, basestring): + query, parsed_sort = parse_query_string(query, model_cls) + elif isinstance(query, (list, tuple)): + query, parsed_sort = parse_query_parts(query, model_cls) + + # Any non-null sort specified by the parsed query overrides the + # provided sort. + if parsed_sort and not isinstance(parsed_sort, dbcore.query.NullSort): + sort = parsed_sort + + return super(Library, self)._fetch( + model_cls, query, sort + ) + + def albums(self, query=None, sort=None): + """Get :class:`Album` objects matching the query. + """ + sort = sort or dbcore.sort_from_strings( + Album, beets.config['sort_album'].as_str_seq() + ) + return self._fetch(Album, query, sort) + + def items(self, query=None, sort=None): + """Get :class:`Item` objects matching the query. + """ + sort = sort or dbcore.sort_from_strings( + Item, beets.config['sort_item'].as_str_seq() + ) + return self._fetch(Item, query, sort) + + # Convenience accessors. + + def get_item(self, id): + """Fetch an :class:`Item` by its ID. Returns `None` if no match is + found. + """ + return self._get(Item, id) + + def get_album(self, item_or_id): + """Given an album ID or an item associated with an album, return + an :class:`Album` object for the album. If no such album exists, + returns `None`. + """ + if isinstance(item_or_id, int): + album_id = item_or_id + else: + album_id = item_or_id.album_id + if album_id is None: + return None + return self._get(Album, album_id) + + +# Default path template resources. + +def _int_arg(s): + """Convert a string argument to an integer for use in a template + function. May raise a ValueError. + """ + return int(s.strip()) + + +class DefaultTemplateFunctions(object): + """A container class for the default functions provided to path + templates. These functions are contained in an object to provide + additional context to the functions -- specifically, the Item being + evaluated. + """ + _prefix = 'tmpl_' + + def __init__(self, item=None, lib=None): + """Paramaterize the functions. If `item` or `lib` is None, then + some functions (namely, ``aunique``) will always evaluate to the + empty string. + """ + self.item = item + self.lib = lib + + def functions(self): + """Returns a dictionary containing the functions defined in this + object. The keys are function names (as exposed in templates) + and the values are Python functions. + """ + out = {} + for key in self._func_names: + out[key[len(self._prefix):]] = getattr(self, key) + return out + + @staticmethod + def tmpl_lower(s): + """Convert a string to lower case.""" + return s.lower() + + @staticmethod + def tmpl_upper(s): + """Covert a string to upper case.""" + return s.upper() + + @staticmethod + def tmpl_title(s): + """Convert a string to title case.""" + return s.title() + + @staticmethod + def tmpl_left(s, chars): + """Get the leftmost characters of a string.""" + return s[0:_int_arg(chars)] + + @staticmethod + def tmpl_right(s, chars): + """Get the rightmost characters of a string.""" + return s[-_int_arg(chars):] + + @staticmethod + def tmpl_if(condition, trueval, falseval=u''): + """If ``condition`` is nonempty and nonzero, emit ``trueval``; + otherwise, emit ``falseval`` (if provided). + """ + try: + int_condition = _int_arg(condition) + except ValueError: + if condition.lower() == "false": + return falseval + else: + condition = int_condition + + if condition: + return trueval + else: + return falseval + + @staticmethod + def tmpl_asciify(s): + """Translate non-ASCII characters to their ASCII equivalents. + """ + return unidecode(s) + + @staticmethod + def tmpl_time(s, format): + """Format a time value using `strftime`. + """ + cur_fmt = beets.config['time_format'].get(unicode) + return time.strftime(format, time.strptime(s, cur_fmt)) + + def tmpl_aunique(self, keys=None, disam=None): + """Generate a string that is guaranteed to be unique among all + albums in the library who share the same set of keys. A fields + from "disam" is used in the string if one is sufficient to + disambiguate the albums. Otherwise, a fallback opaque value is + used. Both "keys" and "disam" should be given as + whitespace-separated lists of field names. + """ + # Fast paths: no album, no item or library, or memoized value. + if not self.item or not self.lib: + return u'' + if self.item.album_id is None: + return u'' + memokey = ('aunique', keys, disam, self.item.album_id) + memoval = self.lib._memotable.get(memokey) + if memoval is not None: + return memoval + + keys = keys or 'albumartist album' + disam = disam or 'albumtype year label catalognum albumdisambig' + keys = keys.split() + disam = disam.split() + + album = self.lib.get_album(self.item) + if not album: + # Do nothing for singletons. + self.lib._memotable[memokey] = u'' + return u'' + + # Find matching albums to disambiguate with. + subqueries = [] + for key in keys: + value = getattr(album, key) + subqueries.append(dbcore.MatchQuery(key, value)) + albums = self.lib.albums(dbcore.AndQuery(subqueries)) + + # If there's only one album to matching these details, then do + # nothing. + if len(albums) == 1: + self.lib._memotable[memokey] = u'' + return u'' + + # Find the first disambiguator that distinguishes the albums. + for disambiguator in disam: + # Get the value for each album for the current field. + disam_values = set([getattr(a, disambiguator) for a in albums]) + + # If the set of unique values is equal to the number of + # albums in the disambiguation set, we're done -- this is + # sufficient disambiguation. + if len(disam_values) == len(albums): + break + + else: + # No disambiguator distinguished all fields. + res = u' {0}'.format(album.id) + self.lib._memotable[memokey] = res + return res + + # Flatten disambiguation value into a string. + disam_value = album.formatted(True).get(disambiguator) + res = u' [{0}]'.format(disam_value) + self.lib._memotable[memokey] = res + return res + + +# Get the name of tmpl_* functions in the above class. +DefaultTemplateFunctions._func_names = \ + [s for s in dir(DefaultTemplateFunctions) + if s.startswith(DefaultTemplateFunctions._prefix)] diff --git a/lib/beets/mediafile.py b/lib/beets/mediafile.py new file mode 100644 index 00000000..49ef1037 --- /dev/null +++ b/lib/beets/mediafile.py @@ -0,0 +1,1929 @@ +# This file is part of beets. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Handles low-level interfacing for files' tags. Wraps Mutagen to +automatically detect file types and provide a unified interface for a +useful subset of music files' tags. + +Usage: + + >>> f = MediaFile('Lucy.mp3') + >>> f.title + u'Lucy in the Sky with Diamonds' + >>> f.artist = 'The Beatles' + >>> f.save() + +A field will always return a reasonable value of the correct type, even +if no tag is present. If no value is available, the value will be false +(e.g., zero or the empty string). + +Internally ``MediaFile`` uses ``MediaField`` descriptors to access the +data from the tags. In turn ``MediaField`` uses a number of +``StorageStyle`` strategies to handle format specific logic. +""" +import mutagen +import mutagen.mp3 +import mutagen.oggopus +import mutagen.oggvorbis +import mutagen.mp4 +import mutagen.flac +import mutagen.monkeysaudio +import mutagen.asf +import mutagen.aiff +import datetime +import re +import base64 +import math +import struct +import imghdr +import os +import logging +import traceback +import enum + +from beets.util import displayable_path + + +__all__ = ['UnreadableFileError', 'FileTypeError', 'MediaFile'] + +log = logging.getLogger('beets') + +# Human-readable type names. +TYPES = { + 'mp3': 'MP3', + 'aac': 'AAC', + 'alac': 'ALAC', + 'ogg': 'OGG', + 'opus': 'Opus', + 'flac': 'FLAC', + 'ape': 'APE', + 'wv': 'WavPack', + 'mpc': 'Musepack', + 'asf': 'Windows Media', + 'aiff': 'AIFF', +} + + +# Exceptions. + +class UnreadableFileError(Exception): + """Mutagen is not able to extract information from the file. + """ + def __init__(self, path): + Exception.__init__(self, displayable_path(path)) + + +class FileTypeError(UnreadableFileError): + """Reading this type of file is not supported. + + If passed the `mutagen_type` argument this indicates that the + mutagen type is not supported by `Mediafile`. + """ + def __init__(self, path, mutagen_type=None): + path = displayable_path(path) + if mutagen_type is None: + msg = path + else: + msg = u'{0}: of mutagen type {1}'.format(path, mutagen_type) + Exception.__init__(self, msg) + + +class MutagenError(UnreadableFileError): + """Raised when Mutagen fails unexpectedly---probably due to a bug. + """ + def __init__(self, path, mutagen_exc): + msg = u'{0}: {1}'.format(displayable_path(path), mutagen_exc) + Exception.__init__(self, msg) + + +# Utility. + +def _safe_cast(out_type, val): + """Try to covert val to out_type but never raise an exception. If + the value can't be converted, then a sensible default value is + returned. out_type should be bool, int, or unicode; otherwise, the + value is just passed through. + """ + if val is None: + return None + + if out_type == int: + if isinstance(val, int) or isinstance(val, float): + # Just a number. + return int(val) + else: + # Process any other type as a string. + if not isinstance(val, basestring): + val = unicode(val) + # Get a number from the front of the string. + val = re.match(r'[0-9]*', val.strip()).group(0) + if not val: + return 0 + else: + return int(val) + + elif out_type == bool: + try: + # Should work for strings, bools, ints: + return bool(int(val)) + except ValueError: + return False + + elif out_type == unicode: + if isinstance(val, str): + return val.decode('utf8', 'ignore') + elif isinstance(val, unicode): + return val + else: + return unicode(val) + + elif out_type == float: + if isinstance(val, int) or isinstance(val, float): + return float(val) + else: + if not isinstance(val, basestring): + val = unicode(val) + match = re.match(r'[\+-]?[0-9\.]+', val.strip()) + if match: + val = match.group(0) + if val: + return float(val) + return 0.0 + + else: + return val + + +# Image coding for ASF/WMA. + +def _unpack_asf_image(data): + """Unpack image data from a WM/Picture tag. Return a tuple + containing the MIME type, the raw image data, a type indicator, and + the image's description. + + This function is treated as "untrusted" and could throw all manner + of exceptions (out-of-bounds, etc.). We should clean this up + sometime so that the failure modes are well-defined. + """ + type, size = struct.unpack_from(" 0: + gain = math.log10(maxgain / 1000.0) * -10 + else: + # Invalid gain value found. + gain = 0.0 + + # SoundCheck stores peak values as the actual value of the sample, + # and again separately for the left and right channels. We need to + # convert this to a percentage of full scale, which is 32768 for a + # 16 bit sample. Once again, we play it safe by using the larger of + # the two values. + peak = max(soundcheck[6:8]) / 32768.0 + + return round(gain, 2), round(peak, 6) + + +def _sc_encode(gain, peak): + """Encode ReplayGain gain/peak values as a Sound Check string. + """ + # SoundCheck stores the peak value as the actual value of the + # sample, rather than the percentage of full scale that RG uses, so + # we do a simple conversion assuming 16 bit samples. + peak *= 32768.0 + + # SoundCheck stores absolute RMS values in some unknown units rather + # than the dB values RG uses. We can calculate these absolute values + # from the gain ratio using a reference value of 1000 units. We also + # enforce the maximum value here, which is equivalent to about + # -18.2dB. + g1 = min(round((10 ** (gain / -10)) * 1000), 65534) + # Same as above, except our reference level is 2500 units. + g2 = min(round((10 ** (gain / -10)) * 2500), 65534) + + # The purpose of these values are unknown, but they also seem to be + # unused so we just use zero. + uk = 0 + values = (g1, g1, g2, g2, uk, uk, peak, peak, uk, uk) + return (u' %08X' * 10) % values + + +# Cover art and other images. + +def _image_mime_type(data): + """Return the MIME type of the image data (a bytestring). + """ + kind = imghdr.what(None, h=data) + if kind in ['gif', 'jpeg', 'png', 'tiff', 'bmp']: + return 'image/{0}'.format(kind) + elif kind == 'pgm': + return 'image/x-portable-graymap' + elif kind == 'pbm': + return 'image/x-portable-bitmap' + elif kind == 'ppm': + return 'image/x-portable-pixmap' + elif kind == 'xbm': + return 'image/x-xbitmap' + else: + return 'image/x-{0}'.format(kind) + + +class ImageType(enum.Enum): + """Indicates the kind of an `Image` stored in a file's tag. + """ + other = 0 + icon = 1 + other_icon = 2 + front = 3 + back = 4 + leaflet = 5 + media = 6 + lead_artist = 7 + artist = 8 + conductor = 9 + group = 10 + composer = 11 + lyricist = 12 + recording_location = 13 + recording_session = 14 + performance = 15 + screen_capture = 16 + fish = 17 + illustration = 18 + artist_logo = 19 + publisher_logo = 20 + + +class Image(object): + """Strucuture representing image data and metadata that can be + stored and retrieved from tags. + + The structure has four properties. + * ``data`` The binary data of the image + * ``desc`` An optional descritpion of the image + * ``type`` An instance of `ImageType` indicating the kind of image + * ``mime_type`` Read-only property that contains the mime type of + the binary data + """ + def __init__(self, data, desc=None, type=None): + self.data = data + self.desc = desc + if isinstance(type, int): + type = list(ImageType)[type] + self.type = type + + @property + def mime_type(self): + if self.data: + return _image_mime_type(self.data) + + @property + def type_index(self): + if self.type is None: + # This method is used when a tag format requires the type + # index to be set, so we return "other" as the default value. + return 0 + return self.type.value + + +# StorageStyle classes describe strategies for accessing values in +# Mutagen file objects. + +class StorageStyle(object): + """A strategy for storing a value for a certain tag format (or set + of tag formats). This basic StorageStyle describes simple 1:1 + mapping from raw values to keys in a Mutagen file object; subclasses + describe more sophisticated translations or format-specific access + strategies. + + MediaFile uses a StorageStyle via three methods: ``get()``, + ``set()``, and ``delete()``. It passes a Mutagen file object to + each. + + Internally, the StorageStyle implements ``get()`` and ``set()`` + using two steps that may be overridden by subtypes. To get a value, + the StorageStyle first calls ``fetch()`` to retrieve the value + corresponding to a key and then ``deserialize()`` to convert the raw + Mutagen value to a consumable Python value. Similarly, to set a + field, we call ``serialize()`` to encode the value and then + ``store()`` to assign the result into the Mutagen object. + + Each StorageStyle type has a class-level `formats` attribute that is + a list of strings indicating the formats that the style applies to. + MediaFile only uses StorageStyles that apply to the correct type for + a given audio file. + """ + + formats = ['FLAC', 'OggOpus', 'OggTheora', 'OggSpeex', 'OggVorbis', + 'OggFlac', 'APEv2File', 'WavPack', 'Musepack', 'MonkeysAudio'] + """List of mutagen classes the StorageStyle can handle. + """ + + def __init__(self, key, as_type=unicode, suffix=None, float_places=2): + """Create a basic storage strategy. Parameters: + + - `key`: The key on the Mutagen file object used to access the + field's data. + - `as_type`: The Python type that the value is stored as + internally (`unicode`, `int`, `bool`, or `bytes`). + - `suffix`: When `as_type` is a string type, append this before + storing the value. + - `float_places`: When the value is a floating-point number and + encoded as a string, the number of digits to store after the + decimal point. + """ + self.key = key + self.as_type = as_type + self.suffix = suffix + self.float_places = float_places + + # Convert suffix to correct string type. + if self.suffix and self.as_type is unicode: + self.suffix = self.as_type(self.suffix) + + # Getter. + + def get(self, mutagen_file): + """Get the value for the field using this style. + """ + return self.deserialize(self.fetch(mutagen_file)) + + def fetch(self, mutagen_file): + """Retrieve the raw value of for this tag from the Mutagen file + object. + """ + try: + return mutagen_file[self.key][0] + except (KeyError, IndexError): + return None + + def deserialize(self, mutagen_value): + """Given a raw value stored on a Mutagen object, decode and + return the represented value. + """ + if self.suffix and isinstance(mutagen_value, unicode) \ + and mutagen_value.endswith(self.suffix): + return mutagen_value[:-len(self.suffix)] + else: + return mutagen_value + + # Setter. + + def set(self, mutagen_file, value): + """Assign the value for the field using this style. + """ + self.store(mutagen_file, self.serialize(value)) + + def store(self, mutagen_file, value): + """Store a serialized value in the Mutagen file object. + """ + mutagen_file[self.key] = [value] + + def serialize(self, value): + """Convert the external Python value to a type that is suitable for + storing in a Mutagen file object. + """ + if isinstance(value, float) and self.as_type is unicode: + value = u'{0:.{1}f}'.format(value, self.float_places) + value = self.as_type(value) + elif self.as_type is unicode: + if isinstance(value, bool): + # Store bools as 1/0 instead of True/False. + value = unicode(int(bool(value))) + elif isinstance(value, str): + value = value.decode('utf8', 'ignore') + else: + value = unicode(value) + else: + value = self.as_type(value) + + if self.suffix: + value += self.suffix + + return value + + def delete(self, mutagen_file): + """Remove the tag from the file. + """ + if self.key in mutagen_file: + del mutagen_file[self.key] + + +class ListStorageStyle(StorageStyle): + """Abstract storage style that provides access to lists. + + The ListMediaField descriptor uses a ListStorageStyle via two + methods: ``get_list()`` and ``set_list()``. It passes a Mutagen file + object to each. + + Subclasses may overwrite ``fetch`` and ``store``. ``fetch`` must + return a (possibly empty) list and ``store`` receives a serialized + list of values as the second argument. + + The `serialize` and `deserialize` methods (from the base + `StorageStyle`) are still called with individual values. This class + handles packing and unpacking the values into lists. + """ + def get(self, mutagen_file): + """Get the first value in the field's value list. + """ + try: + return self.get_list(mutagen_file)[0] + except IndexError: + return None + + def get_list(self, mutagen_file): + """Get a list of all values for the field using this style. + """ + return [self.deserialize(item) for item in self.fetch(mutagen_file)] + + def fetch(self, mutagen_file): + """Get the list of raw (serialized) values. + """ + try: + return mutagen_file[self.key] + except KeyError: + return [] + + def set(self, mutagen_file, value): + """Set an individual value as the only value for the field using + this style. + """ + self.set_list(mutagen_file, [value]) + + def set_list(self, mutagen_file, values): + """Set all values for the field using this style. `values` + should be an iterable. + """ + self.store(mutagen_file, [self.serialize(value) for value in values]) + + def store(self, mutagen_file, values): + """Set the list of all raw (serialized) values for this field. + """ + mutagen_file[self.key] = values + + +class SoundCheckStorageStyleMixin(object): + """A mixin for storage styles that read and write iTunes SoundCheck + analysis values. The object must have an `index` field that + indicates which half of the gain/peak pair---0 or 1---the field + represents. + """ + def get(self, mutagen_file): + data = self.fetch(mutagen_file) + if data is not None: + return _sc_decode(data)[self.index] + + def set(self, mutagen_file, value): + data = self.fetch(mutagen_file) + if data is None: + gain_peak = [0, 0] + else: + gain_peak = list(_sc_decode(data)) + gain_peak[self.index] = value or 0 + data = self.serialize(_sc_encode(*gain_peak)) + self.store(mutagen_file, data) + + +class ASFStorageStyle(ListStorageStyle): + """A general storage style for Windows Media/ASF files. + """ + formats = ['ASF'] + + def deserialize(self, data): + if isinstance(data, mutagen.asf.ASFBaseAttribute): + data = data.value + return data + + +class MP4StorageStyle(StorageStyle): + """A general storage style for MPEG-4 tags. + """ + formats = ['MP4'] + + def serialize(self, value): + value = super(MP4StorageStyle, self).serialize(value) + if self.key.startswith('----:') and isinstance(value, unicode): + value = value.encode('utf8') + return value + + +class MP4TupleStorageStyle(MP4StorageStyle): + """A style for storing values as part of a pair of numbers in an + MPEG-4 file. + """ + def __init__(self, key, index=0, **kwargs): + super(MP4TupleStorageStyle, self).__init__(key, **kwargs) + self.index = index + + def deserialize(self, mutagen_value): + items = mutagen_value or [] + packing_length = 2 + return list(items) + [0] * (packing_length - len(items)) + + def get(self, mutagen_file): + value = super(MP4TupleStorageStyle, self).get(mutagen_file)[self.index] + if value == 0: + # The values are always present and saved as integers. So we + # assume that "0" indicates it is not set. + return None + else: + return value + + def set(self, mutagen_file, value): + if value is None: + value = 0 + items = self.deserialize(self.fetch(mutagen_file)) + items[self.index] = int(value) + self.store(mutagen_file, items) + + def delete(self, mutagen_file): + if self.index == 0: + super(MP4TupleStorageStyle, self).delete(mutagen_file) + else: + self.set(mutagen_file, None) + + +class MP4ListStorageStyle(ListStorageStyle, MP4StorageStyle): + pass + + +class MP4SoundCheckStorageStyle(SoundCheckStorageStyleMixin, MP4StorageStyle): + def __init__(self, key, index=0, **kwargs): + super(MP4SoundCheckStorageStyle, self).__init__(key, **kwargs) + self.index = index + + +class MP4BoolStorageStyle(MP4StorageStyle): + """A style for booleans in MPEG-4 files. (MPEG-4 has an atom type + specifically for representing booleans.) + """ + def get(self, mutagen_file): + try: + return mutagen_file[self.key] + except KeyError: + return None + + def get_list(self, mutagen_file): + raise NotImplementedError('MP4 bool storage does not support lists') + + def set(self, mutagen_file, value): + mutagen_file[self.key] = value + + def set_list(self, mutagen_file, values): + raise NotImplementedError('MP4 bool storage does not support lists') + + +class MP4ImageStorageStyle(MP4ListStorageStyle): + """Store images as MPEG-4 image atoms. Values are `Image` objects. + """ + def __init__(self, **kwargs): + super(MP4ImageStorageStyle, self).__init__(key='covr', **kwargs) + + def deserialize(self, data): + return Image(data) + + def serialize(self, image): + if image.mime_type == 'image/png': + kind = mutagen.mp4.MP4Cover.FORMAT_PNG + elif image.mime_type == 'image/jpeg': + kind = mutagen.mp4.MP4Cover.FORMAT_JPEG + else: + raise ValueError('MP4 files only supports PNG and JPEG images') + return mutagen.mp4.MP4Cover(image.data, kind) + + +class MP3StorageStyle(StorageStyle): + """Store data in ID3 frames. + """ + formats = ['MP3', 'AIFF'] + + def __init__(self, key, id3_lang=None, **kwargs): + """Create a new ID3 storage style. `id3_lang` is the value for + the language field of newly created frames. + """ + self.id3_lang = id3_lang + super(MP3StorageStyle, self).__init__(key, **kwargs) + + def fetch(self, mutagen_file): + try: + return mutagen_file[self.key].text[0] + except (KeyError, IndexError): + return None + + def store(self, mutagen_file, value): + frame = mutagen.id3.Frames[self.key](encoding=3, text=[value]) + mutagen_file.tags.setall(self.key, [frame]) + + +class MP3ListStorageStyle(ListStorageStyle, MP3StorageStyle): + """Store lists of data in multiple ID3 frames. + """ + def fetch(self, mutagen_file): + try: + return mutagen_file[self.key].text + except KeyError: + return [] + + def store(self, mutagen_file, values): + frame = mutagen.id3.Frames[self.key](encoding=3, text=values) + mutagen_file.tags.setall(self.key, [frame]) + + +class MP3UFIDStorageStyle(MP3StorageStyle): + """Store data in a UFID ID3 frame with a particular owner. + """ + def __init__(self, owner, **kwargs): + self.owner = owner + super(MP3UFIDStorageStyle, self).__init__('UFID:' + owner, **kwargs) + + def fetch(self, mutagen_file): + try: + return mutagen_file[self.key].data + except KeyError: + return None + + def store(self, mutagen_file, value): + frames = mutagen_file.tags.getall(self.key) + for frame in frames: + # Replace existing frame data. + if frame.owner == self.owner: + frame.data = value + else: + # New frame. + frame = mutagen.id3.UFID(owner=self.owner, data=value) + mutagen_file.tags.setall(self.key, [frame]) + + +class MP3DescStorageStyle(MP3StorageStyle): + """Store data in a TXXX (or similar) ID3 frame. The frame is + selected based its ``desc`` field. + """ + def __init__(self, desc=u'', key='TXXX', **kwargs): + self.description = desc + super(MP3DescStorageStyle, self).__init__(key=key, **kwargs) + + def store(self, mutagen_file, value): + frames = mutagen_file.tags.getall(self.key) + if self.key != 'USLT': + value = [value] + + # try modifying in place + found = False + for frame in frames: + if frame.desc.lower() == self.description.lower(): + frame.text = value + found = True + + # need to make a new frame? + if not found: + frame = mutagen.id3.Frames[self.key]( + desc=str(self.description), + text=value, + encoding=3 + ) + if self.id3_lang: + frame.lang = self.id3_lang + mutagen_file.tags.add(frame) + + def fetch(self, mutagen_file): + for frame in mutagen_file.tags.getall(self.key): + if frame.desc.lower() == self.description.lower(): + if self.key == 'USLT': + return frame.text + try: + return frame.text[0] + except IndexError: + return None + + def delete(self, mutagen_file): + found_frame = None + for frame in mutagen_file.tags.getall(self.key): + if frame.desc.lower() == self.description.lower(): + found_frame = frame + break + if found_frame is not None: + del mutagen_file[frame.HashKey] + + +class MP3SlashPackStorageStyle(MP3StorageStyle): + """Store value as part of pair that is serialized as a slash- + separated string. + """ + def __init__(self, key, pack_pos=0, **kwargs): + super(MP3SlashPackStorageStyle, self).__init__(key, **kwargs) + self.pack_pos = pack_pos + + def _fetch_unpacked(self, mutagen_file): + data = self.fetch(mutagen_file) + if data: + items = unicode(data).split('/') + else: + items = [] + packing_length = 2 + return list(items) + [None] * (packing_length - len(items)) + + def get(self, mutagen_file): + return self._fetch_unpacked(mutagen_file)[self.pack_pos] + + def set(self, mutagen_file, value): + items = self._fetch_unpacked(mutagen_file) + items[self.pack_pos] = value + if items[0] is None: + items[0] = '' + if items[1] is None: + items.pop() # Do not store last value + self.store(mutagen_file, '/'.join(map(unicode, items))) + + def delete(self, mutagen_file): + if self.pack_pos == 0: + super(MP3SlashPackStorageStyle, self).delete(mutagen_file) + else: + self.set(mutagen_file, None) + + +class MP3ImageStorageStyle(ListStorageStyle, MP3StorageStyle): + """Converts between APIC frames and ``Image`` instances. + + The `get_list` method inherited from ``ListStorageStyle`` returns a + list of ``Image``s. Similarly, the `set_list` method accepts a + list of ``Image``s as its ``values`` arguemnt. + """ + def __init__(self): + super(MP3ImageStorageStyle, self).__init__(key='APIC') + self.as_type = str + + def deserialize(self, apic_frame): + """Convert APIC frame into Image.""" + return Image(data=apic_frame.data, desc=apic_frame.desc, + type=apic_frame.type) + + def fetch(self, mutagen_file): + return mutagen_file.tags.getall(self.key) + + def store(self, mutagen_file, frames): + mutagen_file.tags.setall(self.key, frames) + + def delete(self, mutagen_file): + mutagen_file.tags.delall(self.key) + + def serialize(self, image): + """Return an APIC frame populated with data from ``image``. + """ + assert isinstance(image, Image) + frame = mutagen.id3.Frames[self.key]() + frame.data = image.data + frame.mime = image.mime_type + frame.desc = (image.desc or u'').encode('utf8') + frame.encoding = 3 # UTF-8 encoding of desc + frame.type = image.type_index + return frame + + +class MP3SoundCheckStorageStyle(SoundCheckStorageStyleMixin, + MP3DescStorageStyle): + def __init__(self, index=0, **kwargs): + super(MP3SoundCheckStorageStyle, self).__init__(**kwargs) + self.index = index + + +class ASFImageStorageStyle(ListStorageStyle): + """Store images packed into Windows Media/ASF byte array attributes. + Values are `Image` objects. + """ + formats = ['ASF'] + + def __init__(self): + super(ASFImageStorageStyle, self).__init__(key='WM/Picture') + + def deserialize(self, asf_picture): + mime, data, type, desc = _unpack_asf_image(asf_picture.value) + return Image(data, desc=desc, type=type) + + def serialize(self, image): + pic = mutagen.asf.ASFByteArrayAttribute() + pic.value = _pack_asf_image(image.mime_type, image.data, + type=image.type_index, + description=image.desc or u'') + return pic + + +class VorbisImageStorageStyle(ListStorageStyle): + """Store images in Vorbis comments. Both legacy COVERART fields and + modern METADATA_BLOCK_PICTURE tags are supported. Data is + base64-encoded. Values are `Image` objects. + """ + formats = ['OggOpus', 'OggTheora', 'OggSpeex', 'OggVorbis', + 'OggFlac'] + + def __init__(self): + super(VorbisImageStorageStyle, self).__init__( + key='metadata_block_picture' + ) + self.as_type = str + + def fetch(self, mutagen_file): + images = [] + if 'metadata_block_picture' not in mutagen_file: + # Try legacy COVERART tags. + if 'coverart' in mutagen_file: + for data in mutagen_file['coverart']: + images.append(Image(base64.b64decode(data))) + return images + for data in mutagen_file["metadata_block_picture"]: + try: + pic = mutagen.flac.Picture(base64.b64decode(data)) + except (TypeError, AttributeError): + continue + images.append(Image(data=pic.data, desc=pic.desc, + type=pic.type)) + return images + + def store(self, mutagen_file, image_data): + # Strip all art, including legacy COVERART. + if 'coverart' in mutagen_file: + del mutagen_file['coverart'] + if 'coverartmime' in mutagen_file: + del mutagen_file['coverartmime'] + super(VorbisImageStorageStyle, self).store(mutagen_file, image_data) + + def serialize(self, image): + """Turn a Image into a base64 encoded FLAC picture block. + """ + pic = mutagen.flac.Picture() + pic.data = image.data + pic.type = image.type_index + pic.mime = image.mime_type + pic.desc = image.desc or u'' + return base64.b64encode(pic.write()) + + +class FlacImageStorageStyle(ListStorageStyle): + """Converts between ``mutagen.flac.Picture`` and ``Image`` instances. + """ + formats = ['FLAC'] + + def __init__(self): + super(FlacImageStorageStyle, self).__init__(key='') + + def fetch(self, mutagen_file): + return mutagen_file.pictures + + def deserialize(self, flac_picture): + return Image(data=flac_picture.data, desc=flac_picture.desc, + type=flac_picture.type) + + def store(self, mutagen_file, pictures): + """``pictures`` is a list of mutagen.flac.Picture instances. + """ + mutagen_file.clear_pictures() + for pic in pictures: + mutagen_file.add_picture(pic) + + def serialize(self, image): + """Turn a Image into a mutagen.flac.Picture. + """ + pic = mutagen.flac.Picture() + pic.data = image.data + pic.type = image.type_index + pic.mime = image.mime_type + pic.desc = image.desc or u'' + return pic + + def delete(self, mutagen_file): + """Remove all images from the file. + """ + mutagen_file.clear_pictures() + + +class APEv2ImageStorageStyle(ListStorageStyle): + """Store images in APEv2 tags. Values are `Image` objects. + """ + formats = ['APEv2File', 'WavPack', 'Musepack', 'MonkeysAudio', 'OptimFROG'] + + TAG_NAMES = { + ImageType.other: 'Cover Art (other)', + ImageType.icon: 'Cover Art (icon)', + ImageType.other_icon: 'Cover Art (other icon)', + ImageType.front: 'Cover Art (front)', + ImageType.back: 'Cover Art (back)', + ImageType.leaflet: 'Cover Art (leaflet)', + ImageType.media: 'Cover Art (media)', + ImageType.lead_artist: 'Cover Art (lead)', + ImageType.artist: 'Cover Art (artist)', + ImageType.conductor: 'Cover Art (conductor)', + ImageType.group: 'Cover Art (band)', + ImageType.composer: 'Cover Art (composer)', + ImageType.lyricist: 'Cover Art (lyricist)', + ImageType.recording_location: 'Cover Art (studio)', + ImageType.recording_session: 'Cover Art (recording)', + ImageType.performance: 'Cover Art (performance)', + ImageType.screen_capture: 'Cover Art (movie scene)', + ImageType.fish: 'Cover Art (colored fish)', + ImageType.illustration: 'Cover Art (illustration)', + ImageType.artist_logo: 'Cover Art (band logo)', + ImageType.publisher_logo: 'Cover Art (publisher logo)', + } + + def __init__(self): + super(APEv2ImageStorageStyle, self).__init__(key='') + + def fetch(self, mutagen_file): + images = [] + for cover_type, cover_tag in self.TAG_NAMES.items(): + try: + frame = mutagen_file[cover_tag] + text_delimiter_index = frame.value.find('\x00') + comment = frame.value[0:text_delimiter_index] \ + if text_delimiter_index > 0 else None + image_data = frame.value[text_delimiter_index + 1:] + images.append(Image(data=image_data, type=cover_type, + desc=comment)) + except KeyError: + pass + + return images + + def set_list(self, mutagen_file, values): + self.delete(mutagen_file) + + for image in values: + image_type = image.type or ImageType.other + comment = image.desc or '' + image_data = comment + "\x00" + image.data + cover_tag = self.TAG_NAMES[image_type] + mutagen_file[cover_tag] = image_data + + def delete(self, mutagen_file): + """Remove all images from the file. + """ + for cover_tag in self.TAG_NAMES.values(): + try: + del mutagen_file[cover_tag] + except KeyError: + pass + + +# MediaField is a descriptor that represents a single logical field. It +# aggregates several StorageStyles describing how to access the data for +# each file type. + +class MediaField(object): + """A descriptor providing access to a particular (abstract) metadata + field. + """ + def __init__(self, *styles, **kwargs): + """Creates a new MediaField. + + :param styles: `StorageStyle` instances that describe the strategy + for reading and writing the field in particular + formats. There must be at least one style for + each possible file format. + + :param out_type: the type of the value that should be returned when + getting this property. + + """ + self.out_type = kwargs.get('out_type', unicode) + self._styles = styles + + def styles(self, mutagen_file): + """Yields the list of storage styles of this field that can + handle the MediaFile's format. + """ + for style in self._styles: + if mutagen_file.__class__.__name__ in style.formats: + yield style + + def __get__(self, mediafile, owner=None): + out = None + for style in self.styles(mediafile.mgfile): + out = style.get(mediafile.mgfile) + if out: + break + return _safe_cast(self.out_type, out) + + def __set__(self, mediafile, value): + if value is None: + value = self._none_value() + for style in self.styles(mediafile.mgfile): + style.set(mediafile.mgfile, value) + + def __delete__(self, mediafile): + for style in self.styles(mediafile.mgfile): + style.delete(mediafile.mgfile) + + def _none_value(self): + """Get an appropriate "null" value for this field's type. This + is used internally when setting the field to None. + """ + if self.out_type == int: + return 0 + elif self.out_type == float: + return 0.0 + elif self.out_type == bool: + return False + elif self.out_type == unicode: + return u'' + + +class ListMediaField(MediaField): + """Property descriptor that retrieves a list of multiple values from + a tag. + + Uses ``get_list`` and set_list`` methods of its ``StorageStyle`` + strategies to do the actual work. + """ + def __get__(self, mediafile, _): + values = [] + for style in self.styles(mediafile.mgfile): + values.extend(style.get_list(mediafile.mgfile)) + return [_safe_cast(self.out_type, value) for value in values] + + def __set__(self, mediafile, values): + for style in self.styles(mediafile.mgfile): + style.set_list(mediafile.mgfile, values) + + def single_field(self): + """Returns a ``MediaField`` descriptor that gets and sets the + first item. + """ + options = {'out_type': self.out_type} + return MediaField(*self._styles, **options) + + +class DateField(MediaField): + """Descriptor that handles serializing and deserializing dates + + The getter parses value from tags into a ``datetime.date`` instance + and setter serializes such an instance into a string. + + For granular access to year, month, and day, use the ``*_field`` + methods to create corresponding `DateItemField`s. + """ + def __init__(self, *date_styles, **kwargs): + """``date_styles`` is a list of ``StorageStyle``s to store and + retrieve the whole date from. The ``year`` option is an + additional list of fallback styles for the year. The year is + always set on this style, but is only retrieved if the main + storage styles do not return a value. + """ + super(DateField, self).__init__(*date_styles) + year_style = kwargs.get('year', None) + if year_style: + self._year_field = MediaField(*year_style) + + def __get__(self, mediafile, owner=None): + year, month, day = self._get_date_tuple(mediafile) + if not year: + return None + try: + return datetime.date( + year, + month or 1, + day or 1 + ) + except ValueError: # Out of range values. + return None + + def __set__(self, mediafile, date): + if date is None: + self._set_date_tuple(mediafile, None, None, None) + else: + self._set_date_tuple(mediafile, date.year, date.month, date.day) + + def __delete__(self, mediafile): + super(DateField, self).__delete__(mediafile) + if hasattr(self, '_year_field'): + self._year_field.__delete__(mediafile) + + def _get_date_tuple(self, mediafile): + """Get a 3-item sequence representing the date consisting of a + year, month, and day number. Each number is either an integer or + None. + """ + # Get the underlying data and split on hyphens and slashes. + datestring = super(DateField, self).__get__(mediafile, None) + if isinstance(datestring, basestring): + datestring = re.sub(r'[Tt ].*$', '', unicode(datestring)) + items = re.split('[-/]', unicode(datestring)) + else: + items = [] + + # Ensure that we have exactly 3 components, possibly by + # truncating or padding. + items = items[:3] + if len(items) < 3: + items += [None] * (3 - len(items)) + + # Use year field if year is missing. + if not items[0] and hasattr(self, '_year_field'): + items[0] = self._year_field.__get__(mediafile) + + # Convert each component to an integer if possible. + items_ = [] + for item in items: + try: + items_.append(int(item)) + except: + items_.append(None) + return items_ + + def _set_date_tuple(self, mediafile, year, month=None, day=None): + """Set the value of the field given a year, month, and day + number. Each number can be an integer or None to indicate an + unset component. + """ + if year is None: + self.__delete__(mediafile) + return + + date = [u'{0:04d}'.format(int(year))] + if month: + date.append(u'{0:02d}'.format(int(month))) + if month and day: + date.append(u'{0:02d}'.format(int(day))) + date = map(unicode, date) + super(DateField, self).__set__(mediafile, u'-'.join(date)) + + if hasattr(self, '_year_field'): + self._year_field.__set__(mediafile, year) + + def year_field(self): + return DateItemField(self, 0) + + def month_field(self): + return DateItemField(self, 1) + + def day_field(self): + return DateItemField(self, 2) + + +class DateItemField(MediaField): + """Descriptor that gets and sets constituent parts of a `DateField`: + the month, day, or year. + """ + def __init__(self, date_field, item_pos): + self.date_field = date_field + self.item_pos = item_pos + + def __get__(self, mediafile, _): + return self.date_field._get_date_tuple(mediafile)[self.item_pos] + + def __set__(self, mediafile, value): + items = self.date_field._get_date_tuple(mediafile) + items[self.item_pos] = value + self.date_field._set_date_tuple(mediafile, *items) + + def __delete__(self, mediafile): + self.__set__(mediafile, None) + + +class CoverArtField(MediaField): + """A descriptor that provides access to the *raw image data* for the + first image on a file. This is used for backwards compatibility: the + full `ImageListField` provides richer `Image` objects. + """ + def __init__(self): + pass + + def __get__(self, mediafile, _): + try: + return mediafile.images[0].data + except IndexError: + return None + + def __set__(self, mediafile, data): + if data: + mediafile.images = [Image(data=data)] + else: + mediafile.images = [] + + def __delete__(self, mediafile): + delattr(mediafile, 'images') + + +class ImageListField(ListMediaField): + """Descriptor to access the list of images embedded in tags. + + The getter returns a list of `Image` instances obtained from + the tags. The setter accepts a list of `Image` instances to be + written to the tags. + """ + def __init__(self): + # The storage styles used here must implement the + # `ListStorageStyle` interface and get and set lists of + # `Image`s. + super(ImageListField, self).__init__( + MP3ImageStorageStyle(), + MP4ImageStorageStyle(), + ASFImageStorageStyle(), + VorbisImageStorageStyle(), + FlacImageStorageStyle(), + APEv2ImageStorageStyle(), + out_type=Image, + ) + + +# MediaFile is a collection of fields. + +class MediaFile(object): + """Represents a multimedia file on disk and provides access to its + metadata. + """ + def __init__(self, path, id3v23=False): + """Constructs a new `MediaFile` reflecting the file at path. May + throw `UnreadableFileError`. + + By default, MP3 files are saved with ID3v2.4 tags. You can use + the older ID3v2.3 standard by specifying the `id3v23` option. + """ + self.path = path + + unreadable_exc = ( + mutagen.mp3.error, + mutagen.id3.error, + mutagen.flac.error, + mutagen.monkeysaudio.MonkeysAudioHeaderError, + mutagen.mp4.error, + mutagen.oggopus.error, + mutagen.oggvorbis.error, + mutagen.ogg.error, + mutagen.asf.error, + mutagen.apev2.error, + mutagen.aiff.error, + ) + try: + self.mgfile = mutagen.File(path) + except unreadable_exc as exc: + log.debug(u'header parsing failed: {0}'.format(unicode(exc))) + raise UnreadableFileError(path) + except IOError as exc: + if type(exc) == IOError: + # This is a base IOError, not a subclass from Mutagen or + # anywhere else. + raise + else: + log.debug(traceback.format_exc()) + raise MutagenError(path, exc) + except Exception as exc: + # Isolate bugs in Mutagen. + log.debug(traceback.format_exc()) + log.error(u'uncaught Mutagen exception in open: {0}'.format(exc)) + raise MutagenError(path, exc) + + if self.mgfile is None: + # Mutagen couldn't guess the type + raise FileTypeError(path) + elif (type(self.mgfile).__name__ == 'M4A' or + type(self.mgfile).__name__ == 'MP4'): + info = self.mgfile.info + if hasattr(info, 'codec'): + if info.codec and info.codec.startswith('alac'): + self.type = 'alac' + else: + self.type = 'aac' + else: + # This hack differentiates AAC and ALAC on versions of + # Mutagen < 1.26. Once Mutagen > 1.26 is out and + # required by beets, we can remove this. + if hasattr(self.mgfile.info, 'bitrate') and \ + self.mgfile.info.bitrate > 0: + self.type = 'aac' + else: + self.type = 'alac' + elif (type(self.mgfile).__name__ == 'ID3' or + type(self.mgfile).__name__ == 'MP3'): + self.type = 'mp3' + elif type(self.mgfile).__name__ == 'FLAC': + self.type = 'flac' + elif type(self.mgfile).__name__ == 'OggOpus': + self.type = 'opus' + elif type(self.mgfile).__name__ == 'OggVorbis': + self.type = 'ogg' + elif type(self.mgfile).__name__ == 'MonkeysAudio': + self.type = 'ape' + elif type(self.mgfile).__name__ == 'WavPack': + self.type = 'wv' + elif type(self.mgfile).__name__ == 'Musepack': + self.type = 'mpc' + elif type(self.mgfile).__name__ == 'ASF': + self.type = 'asf' + elif type(self.mgfile).__name__ == 'AIFF': + self.type = 'aiff' + else: + raise FileTypeError(path, type(self.mgfile).__name__) + + # Add a set of tags if it's missing. + if self.mgfile.tags is None: + self.mgfile.add_tags() + + # Set the ID3v2.3 flag only for MP3s. + self.id3v23 = id3v23 and self.type == 'mp3' + + def save(self): + """Write the object's tags back to the file. + """ + # Possibly save the tags to ID3v2.3. + kwargs = {} + if self.id3v23: + id3 = self.mgfile + if hasattr(id3, 'tags'): + # In case this is an MP3 object, not an ID3 object. + id3 = id3.tags + id3.update_to_v23() + kwargs['v2_version'] = 3 + + # Isolate bugs in Mutagen. + try: + self.mgfile.save(**kwargs) + except (IOError, OSError): + # Propagate these through: they don't represent Mutagen bugs. + raise + except Exception as exc: + log.debug(traceback.format_exc()) + log.error(u'uncaught Mutagen exception in save: {0}'.format(exc)) + raise MutagenError(self.path, exc) + + def delete(self): + """Remove the current metadata tag from the file. + """ + try: + self.mgfile.delete() + except NotImplementedError: + # For Mutagen types that don't support deletion (notably, + # ASF), just delete each tag individually. + for tag in self.mgfile.keys(): + del self.mgfile[tag] + + # Convenient access to the set of available fields. + + @classmethod + def fields(cls): + """Get the names of all writable properties that reflect + metadata tags (i.e., those that are instances of + :class:`MediaField`). + """ + for property, descriptor in cls.__dict__.items(): + if isinstance(descriptor, MediaField): + yield property + + @classmethod + def readable_fields(cls): + """Get all metadata fields: the writable ones from + :meth:`fields` and also other audio properties. + """ + for property in cls.fields(): + yield property + for property in ('length', 'samplerate', 'bitdepth', 'bitrate', + 'channels', 'format'): + yield property + + @classmethod + def add_field(cls, name, descriptor): + """Add a field to store custom tags. + + :param name: the name of the property the field is accessed + through. It must not already exist on this class. + + :param descriptor: an instance of :class:`MediaField`. + """ + if not isinstance(descriptor, MediaField): + raise ValueError( + u'{0} must be an instance of MediaField'.format(descriptor)) + if name in cls.__dict__: + raise ValueError( + u'property "{0}" already exists on MediaField'.format(name)) + setattr(cls, name, descriptor) + + def update(self, dict): + """Set all field values from a dictionary. + + For any key in `dict` that is also a field to store tags the + method retrieves the corresponding value from `dict` and updates + the `MediaFile`. If a key has the value `None`, the + corresponding property is deleted from the `MediaFile`. + """ + for field in self.fields(): + if field in dict: + if dict[field] is None: + delattr(self, field) + else: + setattr(self, field, dict[field]) + + # Field definitions. + + title = MediaField( + MP3StorageStyle('TIT2'), + MP4StorageStyle("\xa9nam"), + StorageStyle('TITLE'), + ASFStorageStyle('Title'), + ) + artist = MediaField( + MP3StorageStyle('TPE1'), + MP4StorageStyle("\xa9ART"), + StorageStyle('ARTIST'), + ASFStorageStyle('Author'), + ) + album = MediaField( + MP3StorageStyle('TALB'), + MP4StorageStyle("\xa9alb"), + StorageStyle('ALBUM'), + ASFStorageStyle('WM/AlbumTitle'), + ) + genres = ListMediaField( + MP3ListStorageStyle('TCON'), + MP4ListStorageStyle("\xa9gen"), + ListStorageStyle('GENRE'), + ASFStorageStyle('WM/Genre'), + ) + genre = genres.single_field() + + composer = MediaField( + MP3StorageStyle('TCOM'), + MP4StorageStyle("\xa9wrt"), + StorageStyle('COMPOSER'), + ASFStorageStyle('WM/Composer'), + ) + grouping = MediaField( + MP3StorageStyle('TIT1'), + MP4StorageStyle("\xa9grp"), + StorageStyle('GROUPING'), + ASFStorageStyle('WM/ContentGroupDescription'), + ) + track = MediaField( + MP3SlashPackStorageStyle('TRCK', pack_pos=0), + MP4TupleStorageStyle('trkn', index=0), + StorageStyle('TRACK'), + StorageStyle('TRACKNUMBER'), + ASFStorageStyle('WM/TrackNumber'), + out_type=int, + ) + tracktotal = MediaField( + MP3SlashPackStorageStyle('TRCK', pack_pos=1), + MP4TupleStorageStyle('trkn', index=1), + StorageStyle('TRACKTOTAL'), + StorageStyle('TRACKC'), + StorageStyle('TOTALTRACKS'), + ASFStorageStyle('TotalTracks'), + out_type=int, + ) + disc = MediaField( + MP3SlashPackStorageStyle('TPOS', pack_pos=0), + MP4TupleStorageStyle('disk', index=0), + StorageStyle('DISC'), + StorageStyle('DISCNUMBER'), + ASFStorageStyle('WM/PartOfSet'), + out_type=int, + ) + disctotal = MediaField( + MP3SlashPackStorageStyle('TPOS', pack_pos=1), + MP4TupleStorageStyle('disk', index=1), + StorageStyle('DISCTOTAL'), + StorageStyle('DISCC'), + StorageStyle('TOTALDISCS'), + ASFStorageStyle('TotalDiscs'), + out_type=int, + ) + lyrics = MediaField( + MP3DescStorageStyle(key='USLT'), + MP4StorageStyle("\xa9lyr"), + StorageStyle('LYRICS'), + ASFStorageStyle('WM/Lyrics'), + ) + comments = MediaField( + MP3DescStorageStyle(key='COMM'), + MP4StorageStyle("\xa9cmt"), + StorageStyle('DESCRIPTION'), + StorageStyle('COMMENT'), + ASFStorageStyle('WM/Comments'), + ASFStorageStyle('Description') + ) + bpm = MediaField( + MP3StorageStyle('TBPM'), + MP4StorageStyle('tmpo', as_type=int), + StorageStyle('BPM'), + ASFStorageStyle('WM/BeatsPerMinute'), + out_type=int, + ) + comp = MediaField( + MP3StorageStyle('TCMP'), + MP4BoolStorageStyle('cpil'), + StorageStyle('COMPILATION'), + ASFStorageStyle('WM/IsCompilation', as_type=bool), + out_type=bool, + ) + albumartist = MediaField( + MP3StorageStyle('TPE2'), + MP4StorageStyle('aART'), + StorageStyle('ALBUM ARTIST'), + StorageStyle('ALBUMARTIST'), + ASFStorageStyle('WM/AlbumArtist'), + ) + albumtype = MediaField( + MP3DescStorageStyle(u'MusicBrainz Album Type'), + MP4StorageStyle('----:com.apple.iTunes:MusicBrainz Album Type'), + StorageStyle('MUSICBRAINZ_ALBUMTYPE'), + ASFStorageStyle('MusicBrainz/Album Type'), + ) + label = MediaField( + MP3StorageStyle('TPUB'), + MP4StorageStyle('----:com.apple.iTunes:Label'), + MP4StorageStyle('----:com.apple.iTunes:publisher'), + StorageStyle('LABEL'), + StorageStyle('PUBLISHER'), # Traktor + ASFStorageStyle('WM/Publisher'), + ) + artist_sort = MediaField( + MP3StorageStyle('TSOP'), + MP4StorageStyle("soar"), + StorageStyle('ARTISTSORT'), + ASFStorageStyle('WM/ArtistSortOrder'), + ) + albumartist_sort = MediaField( + MP3DescStorageStyle(u'ALBUMARTISTSORT'), + MP4StorageStyle("soaa"), + StorageStyle('ALBUMARTISTSORT'), + ASFStorageStyle('WM/AlbumArtistSortOrder'), + ) + asin = MediaField( + MP3DescStorageStyle(u'ASIN'), + MP4StorageStyle("----:com.apple.iTunes:ASIN"), + StorageStyle('ASIN'), + ASFStorageStyle('MusicBrainz/ASIN'), + ) + catalognum = MediaField( + MP3DescStorageStyle(u'CATALOGNUMBER'), + MP4StorageStyle("----:com.apple.iTunes:CATALOGNUMBER"), + StorageStyle('CATALOGNUMBER'), + ASFStorageStyle('WM/CatalogNo'), + ) + disctitle = MediaField( + MP3StorageStyle('TSST'), + MP4StorageStyle("----:com.apple.iTunes:DISCSUBTITLE"), + StorageStyle('DISCSUBTITLE'), + ASFStorageStyle('WM/SetSubTitle'), + ) + encoder = MediaField( + MP3StorageStyle('TENC'), + MP4StorageStyle("\xa9too"), + StorageStyle('ENCODEDBY'), + StorageStyle('ENCODER'), + ASFStorageStyle('WM/EncodedBy'), + ) + script = MediaField( + MP3DescStorageStyle(u'Script'), + MP4StorageStyle("----:com.apple.iTunes:SCRIPT"), + StorageStyle('SCRIPT'), + ASFStorageStyle('WM/Script'), + ) + language = MediaField( + MP3StorageStyle('TLAN'), + MP4StorageStyle("----:com.apple.iTunes:LANGUAGE"), + StorageStyle('LANGUAGE'), + ASFStorageStyle('WM/Language'), + ) + country = MediaField( + MP3DescStorageStyle('MusicBrainz Album Release Country'), + MP4StorageStyle("----:com.apple.iTunes:MusicBrainz " + "Album Release Country"), + StorageStyle('RELEASECOUNTRY'), + ASFStorageStyle('MusicBrainz/Album Release Country'), + ) + albumstatus = MediaField( + MP3DescStorageStyle(u'MusicBrainz Album Status'), + MP4StorageStyle("----:com.apple.iTunes:MusicBrainz Album Status"), + StorageStyle('MUSICBRAINZ_ALBUMSTATUS'), + ASFStorageStyle('MusicBrainz/Album Status'), + ) + media = MediaField( + MP3StorageStyle('TMED'), + MP4StorageStyle("----:com.apple.iTunes:MEDIA"), + StorageStyle('MEDIA'), + ASFStorageStyle('WM/Media'), + ) + albumdisambig = MediaField( + # This tag mapping was invented for beets (not used by Picard, etc). + MP3DescStorageStyle(u'MusicBrainz Album Comment'), + MP4StorageStyle("----:com.apple.iTunes:MusicBrainz Album Comment"), + StorageStyle('MUSICBRAINZ_ALBUMCOMMENT'), + ASFStorageStyle('MusicBrainz/Album Comment'), + ) + + # Release date. + date = DateField( + MP3StorageStyle('TDRC'), + MP4StorageStyle("\xa9day"), + StorageStyle('DATE'), + ASFStorageStyle('WM/Year'), + year=(StorageStyle('YEAR'),)) + + year = date.year_field() + month = date.month_field() + day = date.day_field() + + # *Original* release date. + original_date = DateField( + MP3StorageStyle('TDOR'), + MP4StorageStyle('----:com.apple.iTunes:ORIGINAL YEAR'), + StorageStyle('ORIGINALDATE'), + ASFStorageStyle('WM/OriginalReleaseYear')) + + original_year = original_date.year_field() + original_month = original_date.month_field() + original_day = original_date.day_field() + + # Nonstandard metadata. + artist_credit = MediaField( + MP3DescStorageStyle(u'Artist Credit'), + MP4StorageStyle("----:com.apple.iTunes:Artist Credit"), + StorageStyle('ARTIST_CREDIT'), + ASFStorageStyle('beets/Artist Credit'), + ) + albumartist_credit = MediaField( + MP3DescStorageStyle(u'Album Artist Credit'), + MP4StorageStyle("----:com.apple.iTunes:Album Artist Credit"), + StorageStyle('ALBUMARTIST_CREDIT'), + ASFStorageStyle('beets/Album Artist Credit'), + ) + + # Legacy album art field + art = CoverArtField() + + # Image list + images = ImageListField() + + # MusicBrainz IDs. + mb_trackid = MediaField( + MP3UFIDStorageStyle(owner='http://musicbrainz.org'), + MP4StorageStyle('----:com.apple.iTunes:MusicBrainz Track Id'), + StorageStyle('MUSICBRAINZ_TRACKID'), + ASFStorageStyle('MusicBrainz/Track Id'), + ) + mb_albumid = MediaField( + MP3DescStorageStyle(u'MusicBrainz Album Id'), + MP4StorageStyle('----:com.apple.iTunes:MusicBrainz Album Id'), + StorageStyle('MUSICBRAINZ_ALBUMID'), + ASFStorageStyle('MusicBrainz/Album Id'), + ) + mb_artistid = MediaField( + MP3DescStorageStyle(u'MusicBrainz Artist Id'), + MP4StorageStyle('----:com.apple.iTunes:MusicBrainz Artist Id'), + StorageStyle('MUSICBRAINZ_ARTISTID'), + ASFStorageStyle('MusicBrainz/Artist Id'), + ) + mb_albumartistid = MediaField( + MP3DescStorageStyle(u'MusicBrainz Album Artist Id'), + MP4StorageStyle('----:com.apple.iTunes:MusicBrainz Album Artist Id'), + StorageStyle('MUSICBRAINZ_ALBUMARTISTID'), + ASFStorageStyle('MusicBrainz/Album Artist Id'), + ) + mb_releasegroupid = MediaField( + MP3DescStorageStyle(u'MusicBrainz Release Group Id'), + MP4StorageStyle('----:com.apple.iTunes:MusicBrainz Release Group Id'), + StorageStyle('MUSICBRAINZ_RELEASEGROUPID'), + ASFStorageStyle('MusicBrainz/Release Group Id'), + ) + + # Acoustid fields. + acoustid_fingerprint = MediaField( + MP3DescStorageStyle(u'Acoustid Fingerprint'), + MP4StorageStyle('----:com.apple.iTunes:Acoustid Fingerprint'), + StorageStyle('ACOUSTID_FINGERPRINT'), + ASFStorageStyle('Acoustid/Fingerprint'), + ) + acoustid_id = MediaField( + MP3DescStorageStyle(u'Acoustid Id'), + MP4StorageStyle('----:com.apple.iTunes:Acoustid Id'), + StorageStyle('ACOUSTID_ID'), + ASFStorageStyle('Acoustid/Id'), + ) + + # ReplayGain fields. + rg_track_gain = MediaField( + MP3DescStorageStyle( + u'REPLAYGAIN_TRACK_GAIN', + float_places=2, suffix=u' dB' + ), + MP3DescStorageStyle( + u'replaygain_track_gain', + float_places=2, suffix=u' dB' + ), + MP3SoundCheckStorageStyle( + key='COMM', + index=0, desc=u'iTunNORM', + id3_lang='eng' + ), + MP4StorageStyle( + '----:com.apple.iTunes:replaygain_track_gain', + float_places=2, suffix=b' dB' + ), + MP4SoundCheckStorageStyle( + '----:com.apple.iTunes:iTunNORM', + index=0 + ), + StorageStyle( + u'REPLAYGAIN_TRACK_GAIN', + float_places=2, suffix=u' dB' + ), + ASFStorageStyle( + u'replaygain_track_gain', + float_places=2, suffix=u' dB' + ), + out_type=float + ) + rg_album_gain = MediaField( + MP3DescStorageStyle( + u'REPLAYGAIN_ALBUM_GAIN', + float_places=2, suffix=u' dB' + ), + MP3DescStorageStyle( + u'replaygain_album_gain', + float_places=2, suffix=u' dB' + ), + MP4SoundCheckStorageStyle( + '----:com.apple.iTunes:iTunNORM', + index=1 + ), + StorageStyle( + u'REPLAYGAIN_ALBUM_GAIN', + float_places=2, suffix=u' dB' + ), + ASFStorageStyle( + u'replaygain_album_gain', + float_places=2, suffix=u' dB' + ), + out_type=float + ) + rg_track_peak = MediaField( + MP3DescStorageStyle( + u'REPLAYGAIN_TRACK_PEAK', + float_places=6 + ), + MP3DescStorageStyle( + u'replaygain_track_peak', + float_places=6 + ), + MP3SoundCheckStorageStyle( + key=u'COMM', + index=1, desc=u'iTunNORM', + id3_lang='eng' + ), + MP4StorageStyle( + '----:com.apple.iTunes:replaygain_track_peak', + float_places=6 + ), + MP4SoundCheckStorageStyle( + '----:com.apple.iTunes:iTunNORM', + index=1 + ), + StorageStyle(u'REPLAYGAIN_TRACK_PEAK', float_places=6), + ASFStorageStyle(u'replaygain_track_peak', float_places=6), + out_type=float, + ) + rg_album_peak = MediaField( + MP3DescStorageStyle( + u'REPLAYGAIN_ALBUM_PEAK', + float_places=6 + ), + MP3DescStorageStyle( + u'replaygain_album_peak', + float_places=6 + ), + MP4StorageStyle( + '----:com.apple.iTunes:replaygain_album_peak', + float_places=6 + ), + StorageStyle(u'REPLAYGAIN_ALBUM_PEAK', float_places=6), + ASFStorageStyle(u'replaygain_album_peak', float_places=6), + out_type=float, + ) + + initial_key = MediaField( + MP3StorageStyle('TKEY'), + MP4StorageStyle('----:com.apple.iTunes:initialkey'), + StorageStyle('INITIALKEY'), + ASFStorageStyle('INITIALKEY'), + ) + + @property + def length(self): + """The duration of the audio in seconds (a float).""" + return self.mgfile.info.length + + @property + def samplerate(self): + """The audio's sample rate (an int).""" + if hasattr(self.mgfile.info, 'sample_rate'): + return self.mgfile.info.sample_rate + elif self.type == 'opus': + # Opus is always 48kHz internally. + return 48000 + return 0 + + @property + def bitdepth(self): + """The number of bits per sample in the audio encoding (an int). + Only available for certain file formats (zero where + unavailable). + """ + if hasattr(self.mgfile.info, 'bits_per_sample'): + return self.mgfile.info.bits_per_sample + return 0 + + @property + def channels(self): + """The number of channels in the audio (an int).""" + if isinstance(self.mgfile.info, mutagen.mp3.MPEGInfo): + return { + mutagen.mp3.STEREO: 2, + mutagen.mp3.JOINTSTEREO: 2, + mutagen.mp3.DUALCHANNEL: 2, + mutagen.mp3.MONO: 1, + }[self.mgfile.info.mode] + if hasattr(self.mgfile.info, 'channels'): + return self.mgfile.info.channels + return 0 + + @property + def bitrate(self): + """The number of bits per seconds used in the audio coding (an + int). If this is provided explicitly by the compressed file + format, this is a precise reflection of the encoding. Otherwise, + it is estimated from the on-disk file size. In this case, some + imprecision is possible because the file header is incorporated + in the file size. + """ + if hasattr(self.mgfile.info, 'bitrate') and self.mgfile.info.bitrate: + # Many formats provide it explicitly. + return self.mgfile.info.bitrate + else: + # Otherwise, we calculate bitrate from the file size. (This + # is the case for all of the lossless formats.) + if not self.length: + # Avoid division by zero if length is not available. + return 0 + size = os.path.getsize(self.path) + return int(size * 8 / self.length) + + @property + def format(self): + """A string describing the file format/codec.""" + return TYPES[self.type] diff --git a/lib/beets/plugins.py b/lib/beets/plugins.py new file mode 100755 index 00000000..8611b92a --- /dev/null +++ b/lib/beets/plugins.py @@ -0,0 +1,435 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Support for beets plugins.""" + +import logging +import traceback +import inspect +import re +from collections import defaultdict + + +import beets +from beets import mediafile + +PLUGIN_NAMESPACE = 'beetsplug' + +# Plugins using the Last.fm API can share the same API key. +LASTFM_KEY = '2dc3914abf35f0d9c92d97d8f8e42b43' + +# Global logger. +log = logging.getLogger('beets') + + +class PluginConflictException(Exception): + """Indicates that the services provided by one plugin conflict with + those of another. + + For example two plugins may define different types for flexible fields. + """ + + +# Managing the plugins themselves. + +class BeetsPlugin(object): + """The base class for all beets plugins. Plugins provide + functionality by defining a subclass of BeetsPlugin and overriding + the abstract methods defined here. + """ + def __init__(self, name=None): + """Perform one-time plugin setup. + """ + self.import_stages = [] + self.name = name or self.__module__.split('.')[-1] + self.config = beets.config[self.name] + if not self.template_funcs: + self.template_funcs = {} + if not self.template_fields: + self.template_fields = {} + if not self.album_template_fields: + self.album_template_fields = {} + + def commands(self): + """Should return a list of beets.ui.Subcommand objects for + commands that should be added to beets' CLI. + """ + return () + + def queries(self): + """Should return a dict mapping prefixes to Query subclasses. + """ + return {} + + def track_distance(self, item, info): + """Should return a Distance object to be added to the + distance for every track comparison. + """ + return beets.autotag.hooks.Distance() + + def album_distance(self, items, album_info, mapping): + """Should return a Distance object to be added to the + distance for every album-level comparison. + """ + return beets.autotag.hooks.Distance() + + def candidates(self, items, artist, album, va_likely): + """Should return a sequence of AlbumInfo objects that match the + album whose items are provided. + """ + return () + + def item_candidates(self, item, artist, title): + """Should return a sequence of TrackInfo objects that match the + item provided. + """ + return () + + def album_for_id(self, album_id): + """Return an AlbumInfo object or None if no matching release was + found. + """ + return None + + def track_for_id(self, track_id): + """Return a TrackInfo object or None if no matching release was + found. + """ + return None + + def add_media_field(self, name, descriptor): + """Add a field that is synchronized between media files and items. + + When a media field is added ``item.write()`` will set the name + property of the item's MediaFile to ``item[name]`` and save the + changes. Similarly ``item.read()`` will set ``item[name]`` to + the value of the name property of the media file. + + ``descriptor`` must be an instance of ``mediafile.MediaField``. + """ + # Defer impor to prevent circular dependency + from beets import library + mediafile.MediaFile.add_field(name, descriptor) + library.Item._media_fields.add(name) + + listeners = None + + @classmethod + def register_listener(cls, event, func): + """Add a function as a listener for the specified event. (An + imperative alternative to the @listen decorator.) + """ + if cls.listeners is None: + cls.listeners = defaultdict(list) + cls.listeners[event].append(func) + + @classmethod + def listen(cls, event): + """Decorator that adds a function as an event handler for the + specified event (as a string). The parameters passed to function + will vary depending on what event occurred. + + The function should respond to named parameters. + function(**kwargs) will trap all arguments in a dictionary. + Example: + + >>> @MyPlugin.listen("imported") + >>> def importListener(**kwargs): + ... pass + """ + def helper(func): + if cls.listeners is None: + cls.listeners = defaultdict(list) + cls.listeners[event].append(func) + return func + return helper + + template_funcs = None + template_fields = None + album_template_fields = None + + @classmethod + def template_func(cls, name): + """Decorator that registers a path template function. The + function will be invoked as ``%name{}`` from path format + strings. + """ + def helper(func): + if cls.template_funcs is None: + cls.template_funcs = {} + cls.template_funcs[name] = func + return func + return helper + + @classmethod + def template_field(cls, name): + """Decorator that registers a path template field computation. + The value will be referenced as ``$name`` from path format + strings. The function must accept a single parameter, the Item + being formatted. + """ + def helper(func): + if cls.template_fields is None: + cls.template_fields = {} + cls.template_fields[name] = func + return func + return helper + + +_classes = set() + + +def load_plugins(names=()): + """Imports the modules for a sequence of plugin names. Each name + must be the name of a Python module under the "beetsplug" namespace + package in sys.path; the module indicated should contain the + BeetsPlugin subclasses desired. + """ + for name in names: + modname = '%s.%s' % (PLUGIN_NAMESPACE, name) + try: + try: + namespace = __import__(modname, None, None) + except ImportError as exc: + # Again, this is hacky: + if exc.args[0].endswith(' ' + name): + log.warn(u'** plugin {0} not found'.format(name)) + else: + raise + else: + for obj in getattr(namespace, name).__dict__.values(): + if isinstance(obj, type) and issubclass(obj, BeetsPlugin) \ + and obj != BeetsPlugin and obj not in _classes: + _classes.add(obj) + + except: + log.warn(u'** error loading plugin {0}'.format(name)) + log.warn(traceback.format_exc()) + + +_instances = {} + + +def find_plugins(): + """Returns a list of BeetsPlugin subclass instances from all + currently loaded beets plugins. Loads the default plugin set + first. + """ + load_plugins() + plugins = [] + for cls in _classes: + # Only instantiate each plugin class once. + if cls not in _instances: + _instances[cls] = cls() + plugins.append(_instances[cls]) + return plugins + + +# Communication with plugins. + +def commands(): + """Returns a list of Subcommand objects from all loaded plugins. + """ + out = [] + for plugin in find_plugins(): + out += plugin.commands() + return out + + +def queries(): + """Returns a dict mapping prefix strings to Query subclasses all loaded + plugins. + """ + out = {} + for plugin in find_plugins(): + out.update(plugin.queries()) + return out + + +def types(model_cls): + # Gives us `item_types` and `album_types` + attr_name = '{0}_types'.format(model_cls.__name__.lower()) + types = {} + for plugin in find_plugins(): + plugin_types = getattr(plugin, attr_name, {}) + for field in plugin_types: + if field in types and plugin_types[field] != types[field]: + raise PluginConflictException( + u'Plugin {0} defines flexible field {1} ' + 'which has already been defined with ' + 'another type.'.format(plugin.name, field) + ) + types.update(plugin_types) + return types + + +def track_distance(item, info): + """Gets the track distance calculated by all loaded plugins. + Returns a Distance object. + """ + from beets.autotag.hooks import Distance + dist = Distance() + for plugin in find_plugins(): + dist.update(plugin.track_distance(item, info)) + return dist + + +def album_distance(items, album_info, mapping): + """Returns the album distance calculated by plugins.""" + from beets.autotag.hooks import Distance + dist = Distance() + for plugin in find_plugins(): + dist.update(plugin.album_distance(items, album_info, mapping)) + return dist + + +def candidates(items, artist, album, va_likely): + """Gets MusicBrainz candidates for an album from each plugin. + """ + out = [] + for plugin in find_plugins(): + out.extend(plugin.candidates(items, artist, album, va_likely)) + return out + + +def item_candidates(item, artist, title): + """Gets MusicBrainz candidates for an item from the plugins. + """ + out = [] + for plugin in find_plugins(): + out.extend(plugin.item_candidates(item, artist, title)) + return out + + +def album_for_id(album_id): + """Get AlbumInfo objects for a given ID string. + """ + out = [] + for plugin in find_plugins(): + res = plugin.album_for_id(album_id) + if res: + out.append(res) + return out + + +def track_for_id(track_id): + """Get TrackInfo objects for a given ID string. + """ + out = [] + for plugin in find_plugins(): + res = plugin.track_for_id(track_id) + if res: + out.append(res) + return out + + +def template_funcs(): + """Get all the template functions declared by plugins as a + dictionary. + """ + funcs = {} + for plugin in find_plugins(): + if plugin.template_funcs: + funcs.update(plugin.template_funcs) + return funcs + + +def import_stages(): + """Get a list of import stage functions defined by plugins.""" + stages = [] + for plugin in find_plugins(): + if hasattr(plugin, 'import_stages'): + stages += plugin.import_stages + return stages + + +# New-style (lazy) plugin-provided fields. + +def item_field_getters(): + """Get a dictionary mapping field names to unary functions that + compute the field's value. + """ + funcs = {} + for plugin in find_plugins(): + if plugin.template_fields: + funcs.update(plugin.template_fields) + return funcs + + +def album_field_getters(): + """As above, for album fields. + """ + funcs = {} + for plugin in find_plugins(): + if plugin.album_template_fields: + funcs.update(plugin.album_template_fields) + return funcs + + +# Event dispatch. + +def event_handlers(): + """Find all event handlers from plugins as a dictionary mapping + event names to sequences of callables. + """ + all_handlers = defaultdict(list) + for plugin in find_plugins(): + if plugin.listeners: + for event, handlers in plugin.listeners.items(): + all_handlers[event] += handlers + return all_handlers + + +def send(event, **arguments): + """Sends an event to all assigned event listeners. Event is the + name of the event to send, all other named arguments go to the + event handler(s). + + Returns a list of return values from the handlers. + """ + log.debug(u'Sending event: {0}'.format(event)) + for handler in event_handlers()[event]: + # Don't break legacy plugins if we want to pass more arguments + argspec = inspect.getargspec(handler).args + args = dict((k, v) for k, v in arguments.items() if k in argspec) + handler(**args) + + +def feat_tokens(for_artist=True): + """Return a regular expression that matches phrases like "featuring" + that separate a main artist or a song title from secondary artists. + The `for_artist` option determines whether the regex should be + suitable for matching artist fields (the default) or title fields. + """ + feat_words = ['ft', 'featuring', 'feat', 'feat.', 'ft.'] + if for_artist: + feat_words += ['with', 'vs', 'and', 'con', '&'] + return '(?<=\s)(?:{0})(?=\s)'.format( + '|'.join(re.escape(x) for x in feat_words) + ) + + +def sanitize_choices(choices, choices_all): + """Clean up a stringlist configuration attribute: keep only choices + elements present in choices_all, remove duplicate elements, expand '*' + wildcard while keeping original stringlist order. + """ + seen = set() + others = [x for x in choices_all if x not in choices] + res = [] + for s in choices: + if s in list(choices_all) + ['*']: + if not (s in seen or seen.add(s)): + res.extend(list(others) if s == '*' else [s]) + return res diff --git a/lib/beets/ui/__init__.py b/lib/beets/ui/__init__.py new file mode 100644 index 00000000..8978ff54 --- /dev/null +++ b/lib/beets/ui/__init__.py @@ -0,0 +1,970 @@ +# This file is part of beets. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""This module contains all of the core logic for beets' command-line +interface. To invoke the CLI, just call beets.ui.main(). The actual +CLI commands are implemented in the ui.commands module. +""" +from __future__ import print_function + +import locale +import optparse +import textwrap +import sys +from difflib import SequenceMatcher +import logging +import sqlite3 +import errno +import re +import struct +import traceback +import os.path + +from beets import library +from beets import plugins +from beets import util +from beets.util.functemplate import Template +from beets import config +from beets.util import confit +from beets.autotag import mb + +# On Windows platforms, use colorama to support "ANSI" terminal colors. +if sys.platform == 'win32': + try: + import colorama + except ImportError: + pass + else: + colorama.init() + + +log = logging.getLogger('beets') +if not log.handlers: + log.addHandler(logging.StreamHandler()) +log.propagate = False # Don't propagate to root handler. + + +PF_KEY_QUERIES = { + 'comp': 'comp:true', + 'singleton': 'singleton:true', +} + + +class UserError(Exception): + """UI exception. Commands should throw this in order to display + nonrecoverable errors to the user. + """ + + +# Utilities. + +def _encoding(): + """Tries to guess the encoding used by the terminal.""" + # Configured override? + encoding = config['terminal_encoding'].get() + if encoding: + return encoding + + # Determine from locale settings. + try: + return locale.getdefaultlocale()[1] or 'utf8' + except ValueError: + # Invalid locale environment variable setting. To avoid + # failing entirely for no good reason, assume UTF-8. + return 'utf8' + + +def decargs(arglist): + """Given a list of command-line argument bytestrings, attempts to + decode them to Unicode strings. + """ + return [s.decode(_encoding()) for s in arglist] + + +def print_(*strings): + """Like print, but rather than raising an error when a character + is not in the terminal's encoding's character set, just silently + replaces it. + """ + if strings: + if isinstance(strings[0], unicode): + txt = u' '.join(strings) + else: + txt = ' '.join(strings) + else: + txt = u'' + if isinstance(txt, unicode): + txt = txt.encode(_encoding(), 'replace') + print(txt) + + +def input_(prompt=None): + """Like `raw_input`, but decodes the result to a Unicode string. + Raises a UserError if stdin is not available. The prompt is sent to + stdout rather than stderr. A printed between the prompt and the + input cursor. + """ + # raw_input incorrectly sends prompts to stderr, not stdout, so we + # use print() explicitly to display prompts. + # http://bugs.python.org/issue1927 + if prompt: + if isinstance(prompt, unicode): + prompt = prompt.encode(_encoding(), 'replace') + print(prompt, end=' ') + + try: + resp = raw_input() + except EOFError: + raise UserError('stdin stream ended while input required') + + return resp.decode(sys.stdin.encoding or 'utf8', 'ignore') + + +def input_options(options, require=False, prompt=None, fallback_prompt=None, + numrange=None, default=None, max_width=72): + """Prompts a user for input. The sequence of `options` defines the + choices the user has. A single-letter shortcut is inferred for each + option; the user's choice is returned as that single, lower-case + letter. The options should be provided as lower-case strings unless + a particular shortcut is desired; in that case, only that letter + should be capitalized. + + By default, the first option is the default. `default` can be provided to + override this. If `require` is provided, then there is no default. The + prompt and fallback prompt are also inferred but can be overridden. + + If numrange is provided, it is a pair of `(high, low)` (both ints) + indicating that, in addition to `options`, the user may enter an + integer in that inclusive range. + + `max_width` specifies the maximum number of columns in the + automatically generated prompt string. + """ + # Assign single letters to each option. Also capitalize the options + # to indicate the letter. + letters = {} + display_letters = [] + capitalized = [] + first = True + for option in options: + # Is a letter already capitalized? + for letter in option: + if letter.isalpha() and letter.upper() == letter: + found_letter = letter + break + else: + # Infer a letter. + for letter in option: + if not letter.isalpha(): + continue # Don't use punctuation. + if letter not in letters: + found_letter = letter + break + else: + raise ValueError('no unambiguous lettering found') + + letters[found_letter.lower()] = option + index = option.index(found_letter) + + # Mark the option's shortcut letter for display. + if not require and ( + (default is None and not numrange and first) or + (isinstance(default, basestring) and + found_letter.lower() == default.lower())): + # The first option is the default; mark it. + show_letter = '[%s]' % found_letter.upper() + is_default = True + else: + show_letter = found_letter.upper() + is_default = False + + # Colorize the letter shortcut. + show_letter = colorize('turquoise' if is_default else 'blue', + show_letter) + + # Insert the highlighted letter back into the word. + capitalized.append( + option[:index] + show_letter + option[index + 1:] + ) + display_letters.append(found_letter.upper()) + + first = False + + # The default is just the first option if unspecified. + if require: + default = None + elif default is None: + if numrange: + default = numrange[0] + else: + default = display_letters[0].lower() + + # Make a prompt if one is not provided. + if not prompt: + prompt_parts = [] + prompt_part_lengths = [] + if numrange: + if isinstance(default, int): + default_name = str(default) + default_name = colorize('turquoise', default_name) + tmpl = '# selection (default %s)' + prompt_parts.append(tmpl % default_name) + prompt_part_lengths.append(len(tmpl % str(default))) + else: + prompt_parts.append('# selection') + prompt_part_lengths.append(len(prompt_parts[-1])) + prompt_parts += capitalized + prompt_part_lengths += [len(s) for s in options] + + # Wrap the query text. + prompt = '' + line_length = 0 + for i, (part, length) in enumerate(zip(prompt_parts, + prompt_part_lengths)): + # Add punctuation. + if i == len(prompt_parts) - 1: + part += '?' + else: + part += ',' + length += 1 + + # Choose either the current line or the beginning of the next. + if line_length + length + 1 > max_width: + prompt += '\n' + line_length = 0 + + if line_length != 0: + # Not the beginning of the line; need a space. + part = ' ' + part + length += 1 + + prompt += part + line_length += length + + # Make a fallback prompt too. This is displayed if the user enters + # something that is not recognized. + if not fallback_prompt: + fallback_prompt = 'Enter one of ' + if numrange: + fallback_prompt += '%i-%i, ' % numrange + fallback_prompt += ', '.join(display_letters) + ':' + + resp = input_(prompt) + while True: + resp = resp.strip().lower() + + # Try default option. + if default is not None and not resp: + resp = default + + # Try an integer input if available. + if numrange: + try: + resp = int(resp) + except ValueError: + pass + else: + low, high = numrange + if low <= resp <= high: + return resp + else: + resp = None + + # Try a normal letter input. + if resp: + resp = resp[0] + if resp in letters: + return resp + + # Prompt for new input. + resp = input_(fallback_prompt) + + +def input_yn(prompt, require=False): + """Prompts the user for a "yes" or "no" response. The default is + "yes" unless `require` is `True`, in which case there is no default. + """ + sel = input_options( + ('y', 'n'), require, prompt, 'Enter Y or N:' + ) + return sel == 'y' + + +def human_bytes(size): + """Formats size, a number of bytes, in a human-readable way.""" + suffices = ['B', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB', 'HB'] + for suffix in suffices: + if size < 1024: + return "%3.1f %s" % (size, suffix) + size /= 1024.0 + return "big" + + +def human_seconds(interval): + """Formats interval, a number of seconds, as a human-readable time + interval using English words. + """ + units = [ + (1, 'second'), + (60, 'minute'), + (60, 'hour'), + (24, 'day'), + (7, 'week'), + (52, 'year'), + (10, 'decade'), + ] + for i in range(len(units) - 1): + increment, suffix = units[i] + next_increment, _ = units[i + 1] + interval /= float(increment) + if interval < next_increment: + break + else: + # Last unit. + increment, suffix = units[-1] + interval /= float(increment) + + return "%3.1f %ss" % (interval, suffix) + + +def human_seconds_short(interval): + """Formats a number of seconds as a short human-readable M:SS + string. + """ + interval = int(interval) + return u'%i:%02i' % (interval // 60, interval % 60) + + +# ANSI terminal colorization code heavily inspired by pygments: +# http://dev.pocoo.org/hg/pygments-main/file/b2deea5b5030/pygments/console.py +# (pygments is by Tim Hatch, Armin Ronacher, et al.) +COLOR_ESCAPE = "\x1b[" +DARK_COLORS = ["black", "darkred", "darkgreen", "brown", "darkblue", + "purple", "teal", "lightgray"] +LIGHT_COLORS = ["darkgray", "red", "green", "yellow", "blue", + "fuchsia", "turquoise", "white"] +RESET_COLOR = COLOR_ESCAPE + "39;49;00m" + + +def _colorize(color, text): + """Returns a string that prints the given text in the given color + in a terminal that is ANSI color-aware. The color must be something + in DARK_COLORS or LIGHT_COLORS. + """ + if color in DARK_COLORS: + escape = COLOR_ESCAPE + "%im" % (DARK_COLORS.index(color) + 30) + elif color in LIGHT_COLORS: + escape = COLOR_ESCAPE + "%i;01m" % (LIGHT_COLORS.index(color) + 30) + else: + raise ValueError('no such color %s', color) + return escape + text + RESET_COLOR + + +def colorize(color, text): + """Colorize text if colored output is enabled. (Like _colorize but + conditional.) + """ + if config['color']: + return _colorize(color, text) + else: + return text + + +def _colordiff(a, b, highlight='red', minor_highlight='lightgray'): + """Given two values, return the same pair of strings except with + their differences highlighted in the specified color. Strings are + highlighted intelligently to show differences; other values are + stringified and highlighted in their entirety. + """ + if not isinstance(a, basestring) or not isinstance(b, basestring): + # Non-strings: use ordinary equality. + a = unicode(a) + b = unicode(b) + if a == b: + return a, b + else: + return colorize(highlight, a), colorize(highlight, b) + + if isinstance(a, bytes) or isinstance(b, bytes): + # A path field. + a = util.displayable_path(a) + b = util.displayable_path(b) + + a_out = [] + b_out = [] + + matcher = SequenceMatcher(lambda x: False, a, b) + for op, a_start, a_end, b_start, b_end in matcher.get_opcodes(): + if op == 'equal': + # In both strings. + a_out.append(a[a_start:a_end]) + b_out.append(b[b_start:b_end]) + elif op == 'insert': + # Right only. + b_out.append(colorize(highlight, b[b_start:b_end])) + elif op == 'delete': + # Left only. + a_out.append(colorize(highlight, a[a_start:a_end])) + elif op == 'replace': + # Right and left differ. Colorise with second highlight if + # it's just a case change. + if a[a_start:a_end].lower() != b[b_start:b_end].lower(): + color = highlight + else: + color = minor_highlight + a_out.append(colorize(color, a[a_start:a_end])) + b_out.append(colorize(color, b[b_start:b_end])) + else: + assert(False) + + return u''.join(a_out), u''.join(b_out) + + +def colordiff(a, b, highlight='red'): + """Colorize differences between two values if color is enabled. + (Like _colordiff but conditional.) + """ + if config['color']: + return _colordiff(a, b, highlight) + else: + return unicode(a), unicode(b) + + +def get_path_formats(subview=None): + """Get the configuration's path formats as a list of query/template + pairs. + """ + path_formats = [] + subview = subview or config['paths'] + for query, view in subview.items(): + query = PF_KEY_QUERIES.get(query, query) # Expand common queries. + path_formats.append((query, Template(view.get(unicode)))) + return path_formats + + +def get_replacements(): + """Confit validation function that reads regex/string pairs. + """ + replacements = [] + for pattern, repl in config['replace'].get(dict).items(): + repl = repl or '' + try: + replacements.append((re.compile(pattern), repl)) + except re.error: + raise UserError( + u'malformed regular expression in replace: {0}'.format( + pattern + ) + ) + return replacements + + +def _pick_format(album, fmt=None): + """Pick a format string for printing Album or Item objects, + falling back to config options and defaults. + """ + if fmt: + return fmt + if album: + return config['list_format_album'].get(unicode) + else: + return config['list_format_item'].get(unicode) + + +def print_obj(obj, lib, fmt=None): + """Print an Album or Item object. If `fmt` is specified, use that + format string. Otherwise, use the configured template. + """ + album = isinstance(obj, library.Album) + fmt = _pick_format(album, fmt) + if isinstance(fmt, Template): + template = fmt + else: + template = Template(fmt) + print_(obj.evaluate_template(template)) + + +def term_width(): + """Get the width (columns) of the terminal.""" + fallback = config['ui']['terminal_width'].get(int) + + # The fcntl and termios modules are not available on non-Unix + # platforms, so we fall back to a constant. + try: + import fcntl + import termios + except ImportError: + return fallback + + try: + buf = fcntl.ioctl(0, termios.TIOCGWINSZ, ' ' * 4) + except IOError: + return fallback + try: + height, width = struct.unpack('hh', buf) + except struct.error: + return fallback + return width + + +FLOAT_EPSILON = 0.01 + + +def _field_diff(field, old, new): + """Given two Model objects, format their values for `field` and + highlight changes among them. Return a human-readable string. If the + value has not changed, return None instead. + """ + oldval = old.get(field) + newval = new.get(field) + + # If no change, abort. + if isinstance(oldval, float) and isinstance(newval, float) and \ + abs(oldval - newval) < FLOAT_EPSILON: + return None + elif oldval == newval: + return None + + # Get formatted values for output. + oldstr = old.formatted().get(field, u'') + newstr = new.formatted().get(field, u'') + + # For strings, highlight changes. For others, colorize the whole + # thing. + if isinstance(oldval, basestring): + oldstr, newstr = colordiff(oldval, newstr) + else: + oldstr, newstr = colorize('red', oldstr), colorize('red', newstr) + + return u'{0} -> {1}'.format(oldstr, newstr) + + +def show_model_changes(new, old=None, fields=None, always=False): + """Given a Model object, print a list of changes from its pristine + version stored in the database. Return a boolean indicating whether + any changes were found. + + `old` may be the "original" object to avoid using the pristine + version from the database. `fields` may be a list of fields to + restrict the detection to. `always` indicates whether the object is + always identified, regardless of whether any changes are present. + """ + old = old or new._db._get(type(new), new.id) + + # Build up lines showing changed fields. + changes = [] + for field in old: + # Subset of the fields. Never show mtime. + if field == 'mtime' or (fields and field not in fields): + continue + + # Detect and show difference for this field. + line = _field_diff(field, old, new) + if line: + changes.append(u' {0}: {1}'.format(field, line)) + + # New fields. + for field in set(new) - set(old): + if fields and field not in fields: + continue + + changes.append(u' {0}: {1}'.format( + field, + colorize('red', new.formatted()[field]) + )) + + # Print changes. + if changes or always: + print_obj(old, old._db) + if changes: + print_(u'\n'.join(changes)) + + return bool(changes) + + +# Subcommand parsing infrastructure. +# +# This is a fairly generic subcommand parser for optparse. It is +# maintained externally here: +# http://gist.github.com/462717 +# There you will also find a better description of the code and a more +# succinct example program. + +class Subcommand(object): + """A subcommand of a root command-line application that may be + invoked by a SubcommandOptionParser. + """ + def __init__(self, name, parser=None, help='', aliases=(), hide=False): + """Creates a new subcommand. name is the primary way to invoke + the subcommand; aliases are alternate names. parser is an + OptionParser responsible for parsing the subcommand's options. + help is a short description of the command. If no parser is + given, it defaults to a new, empty OptionParser. + """ + self.name = name + self.parser = parser or optparse.OptionParser() + self.aliases = aliases + self.help = help + self.hide = hide + self._root_parser = None + + def print_help(self): + self.parser.print_help() + + def parse_args(self, args): + return self.parser.parse_args(args) + + @property + def root_parser(self): + return self._root_parser + + @root_parser.setter + def root_parser(self, root_parser): + self._root_parser = root_parser + self.parser.prog = '{0} {1}'.format(root_parser.get_prog_name(), + self.name) + + +class SubcommandsOptionParser(optparse.OptionParser): + """A variant of OptionParser that parses subcommands and their + arguments. + """ + + def __init__(self, *args, **kwargs): + """Create a new subcommand-aware option parser. All of the + options to OptionParser.__init__ are supported in addition + to subcommands, a sequence of Subcommand objects. + """ + # A more helpful default usage. + if 'usage' not in kwargs: + kwargs['usage'] = """ + %prog COMMAND [ARGS...] + %prog help COMMAND""" + kwargs['add_help_option'] = False + + # Super constructor. + optparse.OptionParser.__init__(self, *args, **kwargs) + + # Our root parser needs to stop on the first unrecognized argument. + self.disable_interspersed_args() + + self.subcommands = [] + + def add_subcommand(self, *cmds): + """Adds a Subcommand object to the parser's list of commands. + """ + for cmd in cmds: + cmd.root_parser = self + self.subcommands.append(cmd) + + # Add the list of subcommands to the help message. + def format_help(self, formatter=None): + # Get the original help message, to which we will append. + out = optparse.OptionParser.format_help(self, formatter) + if formatter is None: + formatter = self.formatter + + # Subcommands header. + result = ["\n"] + result.append(formatter.format_heading('Commands')) + formatter.indent() + + # Generate the display names (including aliases). + # Also determine the help position. + disp_names = [] + help_position = 0 + subcommands = [c for c in self.subcommands if not c.hide] + subcommands.sort(key=lambda c: c.name) + for subcommand in subcommands: + name = subcommand.name + if subcommand.aliases: + name += ' (%s)' % ', '.join(subcommand.aliases) + disp_names.append(name) + + # Set the help position based on the max width. + proposed_help_position = len(name) + formatter.current_indent + 2 + if proposed_help_position <= formatter.max_help_position: + help_position = max(help_position, proposed_help_position) + + # Add each subcommand to the output. + for subcommand, name in zip(subcommands, disp_names): + # Lifted directly from optparse.py. + name_width = help_position - formatter.current_indent - 2 + if len(name) > name_width: + name = "%*s%s\n" % (formatter.current_indent, "", name) + indent_first = help_position + else: + name = "%*s%-*s " % (formatter.current_indent, "", + name_width, name) + indent_first = 0 + result.append(name) + help_width = formatter.width - help_position + help_lines = textwrap.wrap(subcommand.help, help_width) + result.append("%*s%s\n" % (indent_first, "", help_lines[0])) + result.extend(["%*s%s\n" % (help_position, "", line) + for line in help_lines[1:]]) + formatter.dedent() + + # Concatenate the original help message with the subcommand + # list. + return out + "".join(result) + + def _subcommand_for_name(self, name): + """Return the subcommand in self.subcommands matching the + given name. The name may either be the name of a subcommand or + an alias. If no subcommand matches, returns None. + """ + for subcommand in self.subcommands: + if name == subcommand.name or \ + name in subcommand.aliases: + return subcommand + return None + + def parse_global_options(self, args): + """Parse options up to the subcommand argument. Returns a tuple + of the options object and the remaining arguments. + """ + options, subargs = self.parse_args(args) + + # Force the help command + if options.help: + subargs = ['help'] + elif options.version: + subargs = ['version'] + return options, subargs + + def parse_subcommand(self, args): + """Given the `args` left unused by a `parse_global_options`, + return the invoked subcommand, the subcommand options, and the + subcommand arguments. + """ + # Help is default command + if not args: + args = ['help'] + + cmdname = args.pop(0) + subcommand = self._subcommand_for_name(cmdname) + if not subcommand: + raise UserError("unknown command '{0}'".format(cmdname)) + + suboptions, subargs = subcommand.parse_args(args) + return subcommand, suboptions, subargs + + +optparse.Option.ALWAYS_TYPED_ACTIONS += ('callback',) + + +def vararg_callback(option, opt_str, value, parser): + """Callback for an option with variable arguments. + Manually collect arguments right of a callback-action + option (ie. with action="callback"), and add the resulting + list to the destination var. + + Usage: + parser.add_option("-c", "--callback", dest="vararg_attr", + action="callback", callback=vararg_callback) + + Details: + http://docs.python.org/2/library/optparse.html#callback-example-6-variable + -arguments + """ + value = [value] + + def floatable(str): + try: + float(str) + return True + except ValueError: + return False + + for arg in parser.rargs: + # stop on --foo like options + if arg[:2] == "--" and len(arg) > 2: + break + # stop on -a, but not on -3 or -3.0 + if arg[:1] == "-" and len(arg) > 1 and not floatable(arg): + break + value.append(arg) + + del parser.rargs[:len(value) - 1] + setattr(parser.values, option.dest, value) + + +# The main entry point and bootstrapping. + +def _load_plugins(config): + """Load the plugins specified in the configuration. + """ + paths = config['pluginpath'].get(confit.StrSeq(split=False)) + paths = map(util.normpath, paths) + + import beetsplug + beetsplug.__path__ = paths + beetsplug.__path__ + # For backwards compatibility. + sys.path += paths + + plugins.load_plugins(config['plugins'].as_str_seq()) + plugins.send("pluginload") + return plugins + + +def _setup(options, lib=None): + """Prepare and global state and updates it with command line options. + + Returns a list of subcommands, a list of plugins, and a library instance. + """ + # Configure the MusicBrainz API. + mb.configure() + + config = _configure(options) + + plugins = _load_plugins(config) + + # Get the default subcommands. + from beets.ui.commands import default_commands + + subcommands = list(default_commands) + subcommands.extend(plugins.commands()) + + if lib is None: + lib = _open_library(config) + plugins.send("library_opened", lib=lib) + library.Item._types = plugins.types(library.Item) + library.Album._types = plugins.types(library.Album) + + return subcommands, plugins, lib + + +def _configure(options): + """Amend the global configuration object with command line options. + """ + # Add any additional config files specified with --config. This + # special handling lets specified plugins get loaded before we + # finish parsing the command line. + if getattr(options, 'config', None) is not None: + config_path = options.config + del options.config + config.set_file(config_path) + config.set_args(options) + + # Configure the logger. + if config['verbose'].get(bool): + log.setLevel(logging.DEBUG) + else: + log.setLevel(logging.INFO) + + config_path = config.user_config_path() + if os.path.isfile(config_path): + log.debug(u'user configuration: {0}'.format( + util.displayable_path(config_path))) + else: + log.debug(u'no user configuration found at {0}'.format( + util.displayable_path(config_path))) + + log.debug(u'data directory: {0}' + .format(util.displayable_path(config.config_dir()))) + return config + + +def _open_library(config): + """Create a new library instance from the configuration. + """ + dbpath = config['library'].as_filename() + try: + lib = library.Library( + dbpath, + config['directory'].as_filename(), + get_path_formats(), + get_replacements(), + ) + lib.get_item(0) # Test database connection. + except (sqlite3.OperationalError, sqlite3.DatabaseError): + log.debug(traceback.format_exc()) + raise UserError(u"database file {0} could not be opened".format( + util.displayable_path(dbpath) + )) + log.debug(u'library database: {0}\n' + u'library directory: {1}' + .format(util.displayable_path(lib.path), + util.displayable_path(lib.directory))) + return lib + + +def _raw_main(args, lib=None): + """A helper function for `main` without top-level exception + handling. + """ + parser = SubcommandsOptionParser() + parser.add_option('-l', '--library', dest='library', + help='library database file to use') + parser.add_option('-d', '--directory', dest='directory', + help="destination music directory") + parser.add_option('-v', '--verbose', dest='verbose', action='store_true', + help='print debugging information') + parser.add_option('-c', '--config', dest='config', + help='path to configuration file') + parser.add_option('-h', '--help', dest='help', action='store_true', + help='how this help message and exit') + parser.add_option('--version', dest='version', action='store_true', + help=optparse.SUPPRESS_HELP) + + options, subargs = parser.parse_global_options(args) + + # Special case for the `config --edit` command: bypass _setup so + # that an invalid configuration does not prevent the editor from + # starting. + if subargs[0] == 'config' and ('-e' in subargs or '--edit' in subargs): + from beets.ui.commands import config_edit + return config_edit() + + subcommands, plugins, lib = _setup(options, lib) + parser.add_subcommand(*subcommands) + + subcommand, suboptions, subargs = parser.parse_subcommand(subargs) + subcommand.func(lib, suboptions, subargs) + + plugins.send('cli_exit', lib=lib) + + +def main(args=None): + """Run the main command-line interface for beets. Includes top-level + exception handlers that print friendly error messages. + """ + try: + _raw_main(args) + except UserError as exc: + message = exc.args[0] if exc.args else None + log.error(u'error: {0}'.format(message)) + sys.exit(1) + except util.HumanReadableException as exc: + exc.log(log) + sys.exit(1) + except library.FileOperationError as exc: + # These errors have reasonable human-readable descriptions, but + # we still want to log their tracebacks for debugging. + log.debug(traceback.format_exc()) + log.error(exc) + sys.exit(1) + except confit.ConfigError as exc: + log.error(u'configuration error: {0}'.format(exc)) + sys.exit(1) + except IOError as exc: + if exc.errno == errno.EPIPE: + # "Broken pipe". End silently. + pass + else: + raise + except KeyboardInterrupt: + # Silently ignore ^C except in verbose mode. + log.debug(traceback.format_exc()) diff --git a/lib/beets/ui/commands.py b/lib/beets/ui/commands.py new file mode 100644 index 00000000..4dfac11c --- /dev/null +++ b/lib/beets/ui/commands.py @@ -0,0 +1,1646 @@ +# This file is part of beets. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""This module provides the default commands for beets' command-line +interface. +""" +from __future__ import print_function + +import logging +import os +import time +import codecs +import platform +import re +import shlex + +import beets +from beets import ui +from beets.ui import print_, input_, decargs +from beets import autotag +from beets.autotag import Recommendation +from beets.autotag import hooks +from beets import plugins +from beets import importer +from beets import util +from beets.util import syspath, normpath, ancestry, displayable_path +from beets.util.functemplate import Template +from beets import library +from beets import config +from beets.util.confit import _package_path + +VARIOUS_ARTISTS = u'Various Artists' + +# Global logger. +log = logging.getLogger('beets') + +# The list of default subcommands. This is populated with Subcommand +# objects that can be fed to a SubcommandsOptionParser. +default_commands = [] + + +# Utilities. + +def _do_query(lib, query, album, also_items=True): + """For commands that operate on matched items, performs a query + and returns a list of matching items and a list of matching + albums. (The latter is only nonempty when album is True.) Raises + a UserError if no items match. also_items controls whether, when + fetching albums, the associated items should be fetched also. + """ + if album: + albums = list(lib.albums(query)) + items = [] + if also_items: + for al in albums: + items += al.items() + + else: + albums = [] + items = list(lib.items(query)) + + if album and not albums: + raise ui.UserError('No matching albums found.') + elif not album and not items: + raise ui.UserError('No matching items found.') + + return items, albums + + +# fields: Shows a list of available fields for queries and format strings. + +def fields_func(lib, opts, args): + def _print_rows(names): + print(" " + "\n ".join(names)) + + def _show_plugin_fields(album): + plugin_fields = [] + for plugin in plugins.find_plugins(): + if album: + fdict = plugin.album_template_fields + else: + fdict = plugin.template_fields + plugin_fields += fdict.keys() + if plugin_fields: + print("Template fields from plugins:") + _print_rows(plugin_fields) + + print("Item fields:") + _print_rows(library.Item._fields.keys()) + _show_plugin_fields(False) + + print("\nAlbum fields:") + _print_rows(library.Album._fields.keys()) + _show_plugin_fields(True) + + +fields_cmd = ui.Subcommand( + 'fields', + help='show fields available for queries and format strings' +) +fields_cmd.func = fields_func +default_commands.append(fields_cmd) + + +# help: Print help text for commands + +class HelpCommand(ui.Subcommand): + + def __init__(self): + super(HelpCommand, self).__init__( + 'help', aliases=('?',), + help='give detailed help on a specific sub-command', + ) + + def func(self, lib, opts, args): + if args: + cmdname = args[0] + helpcommand = self.root_parser._subcommand_for_name(cmdname) + if not helpcommand: + raise ui.UserError("unknown command '{0}'".format(cmdname)) + helpcommand.print_help() + else: + self.root_parser.print_help() + + +default_commands.append(HelpCommand()) + + +# import: Autotagger and importer. + +# Importer utilities and support. + +def disambig_string(info): + """Generate a string for an AlbumInfo or TrackInfo object that + provides context that helps disambiguate similar-looking albums and + tracks. + """ + disambig = [] + if info.data_source and info.data_source != 'MusicBrainz': + disambig.append(info.data_source) + + if isinstance(info, hooks.AlbumInfo): + if info.media: + if info.mediums > 1: + disambig.append(u'{0}x{1}'.format( + info.mediums, info.media + )) + else: + disambig.append(info.media) + if info.year: + disambig.append(unicode(info.year)) + if info.country: + disambig.append(info.country) + if info.label: + disambig.append(info.label) + if info.albumdisambig: + disambig.append(info.albumdisambig) + + if disambig: + return u', '.join(disambig) + + +def dist_string(dist): + """Formats a distance (a float) as a colorized similarity percentage + string. + """ + out = '%.1f%%' % ((1 - dist) * 100) + if dist <= config['match']['strong_rec_thresh'].as_number(): + out = ui.colorize('green', out) + elif dist <= config['match']['medium_rec_thresh'].as_number(): + out = ui.colorize('yellow', out) + else: + out = ui.colorize('red', out) + return out + + +def penalty_string(distance, limit=None): + """Returns a colorized string that indicates all the penalties + applied to a distance object. + """ + penalties = [] + for key in distance.keys(): + key = key.replace('album_', '') + key = key.replace('track_', '') + key = key.replace('_', ' ') + penalties.append(key) + if penalties: + if limit and len(penalties) > limit: + penalties = penalties[:limit] + ['...'] + return ui.colorize('yellow', '(%s)' % ', '.join(penalties)) + + +def show_change(cur_artist, cur_album, match): + """Print out a representation of the changes that will be made if an + album's tags are changed according to `match`, which must be an AlbumMatch + object. + """ + def show_album(artist, album): + if artist: + album_description = u' %s - %s' % (artist, album) + elif album: + album_description = u' %s' % album + else: + album_description = u' (unknown album)' + print_(album_description) + + def format_index(track_info): + """Return a string representing the track index of the given + TrackInfo or Item object. + """ + if isinstance(track_info, hooks.TrackInfo): + index = track_info.index + medium_index = track_info.medium_index + medium = track_info.medium + mediums = match.info.mediums + else: + index = medium_index = track_info.track + medium = track_info.disc + mediums = track_info.disctotal + if config['per_disc_numbering']: + if mediums > 1: + return u'{0}-{1}'.format(medium, medium_index) + else: + return unicode(medium_index) + else: + return unicode(index) + + # Identify the album in question. + if cur_artist != match.info.artist or \ + (cur_album != match.info.album and + match.info.album != VARIOUS_ARTISTS): + artist_l, artist_r = cur_artist or '', match.info.artist + album_l, album_r = cur_album or '', match.info.album + if artist_r == VARIOUS_ARTISTS: + # Hide artists for VA releases. + artist_l, artist_r = u'', u'' + + artist_l, artist_r = ui.colordiff(artist_l, artist_r) + album_l, album_r = ui.colordiff(album_l, album_r) + + print_("Correcting tags from:") + show_album(artist_l, album_l) + print_("To:") + show_album(artist_r, album_r) + else: + print_(u"Tagging:\n {0.artist} - {0.album}".format(match.info)) + + # Data URL. + if match.info.data_url: + print_('URL:\n %s' % match.info.data_url) + + # Info line. + info = [] + # Similarity. + info.append('(Similarity: %s)' % dist_string(match.distance)) + # Penalties. + penalties = penalty_string(match.distance) + if penalties: + info.append(penalties) + # Disambiguation. + disambig = disambig_string(match.info) + if disambig: + info.append(ui.colorize('lightgray', '(%s)' % disambig)) + print_(' '.join(info)) + + # Tracks. + pairs = match.mapping.items() + pairs.sort(key=lambda (_, track_info): track_info.index) + + # Build up LHS and RHS for track difference display. The `lines` list + # contains ``(lhs, rhs, width)`` tuples where `width` is the length (in + # characters) of the uncolorized LHS. + lines = [] + medium = disctitle = None + for item, track_info in pairs: + + # Medium number and title. + if medium != track_info.medium or disctitle != track_info.disctitle: + media = match.info.media or 'Media' + if match.info.mediums > 1 and track_info.disctitle: + lhs = '%s %s: %s' % (media, track_info.medium, + track_info.disctitle) + elif match.info.mediums > 1: + lhs = '%s %s' % (media, track_info.medium) + elif track_info.disctitle: + lhs = '%s: %s' % (media, track_info.disctitle) + else: + lhs = None + if lhs: + lines.append((lhs, '', 0)) + medium, disctitle = track_info.medium, track_info.disctitle + + # Titles. + new_title = track_info.title + if not item.title.strip(): + # If there's no title, we use the filename. + cur_title = displayable_path(os.path.basename(item.path)) + lhs, rhs = cur_title, new_title + else: + cur_title = item.title.strip() + lhs, rhs = ui.colordiff(cur_title, new_title) + lhs_width = len(cur_title) + + # Track number change. + cur_track, new_track = format_index(item), format_index(track_info) + if cur_track != new_track: + if item.track in (track_info.index, track_info.medium_index): + color = 'lightgray' + else: + color = 'red' + templ = ui.colorize(color, u' (#{0})') + lhs += templ.format(cur_track) + rhs += templ.format(new_track) + lhs_width += len(cur_track) + 4 + + # Length change. + if item.length and track_info.length and \ + abs(item.length - track_info.length) > \ + config['ui']['length_diff_thresh'].as_number(): + cur_length = ui.human_seconds_short(item.length) + new_length = ui.human_seconds_short(track_info.length) + templ = ui.colorize('red', u' ({0})') + lhs += templ.format(cur_length) + rhs += templ.format(new_length) + lhs_width += len(cur_length) + 3 + + # Penalties. + penalties = penalty_string(match.distance.tracks[track_info]) + if penalties: + rhs += ' %s' % penalties + + if lhs != rhs: + lines.append((' * %s' % lhs, rhs, lhs_width)) + elif config['import']['detail']: + lines.append((' * %s' % lhs, '', lhs_width)) + + # Print each track in two columns, or across two lines. + col_width = (ui.term_width() - len(''.join([' * ', ' -> ']))) // 2 + if lines: + max_width = max(w for _, _, w in lines) + for lhs, rhs, lhs_width in lines: + if not rhs: + print_(lhs) + elif max_width > col_width: + print_(u'%s ->\n %s' % (lhs, rhs)) + else: + pad = max_width - lhs_width + print_(u'%s%s -> %s' % (lhs, ' ' * pad, rhs)) + + # Missing and unmatched tracks. + if match.extra_tracks: + print_('Missing tracks:') + for track_info in match.extra_tracks: + line = ' ! %s (#%s)' % (track_info.title, format_index(track_info)) + if track_info.length: + line += ' (%s)' % ui.human_seconds_short(track_info.length) + print_(ui.colorize('yellow', line)) + if match.extra_items: + print_('Unmatched tracks:') + for item in match.extra_items: + line = ' ! %s (#%s)' % (item.title, format_index(item)) + if item.length: + line += ' (%s)' % ui.human_seconds_short(item.length) + print_(ui.colorize('yellow', line)) + + +def show_item_change(item, match): + """Print out the change that would occur by tagging `item` with the + metadata from `match`, a TrackMatch object. + """ + cur_artist, new_artist = item.artist, match.info.artist + cur_title, new_title = item.title, match.info.title + + if cur_artist != new_artist or cur_title != new_title: + cur_artist, new_artist = ui.colordiff(cur_artist, new_artist) + cur_title, new_title = ui.colordiff(cur_title, new_title) + + print_("Correcting track tags from:") + print_(" %s - %s" % (cur_artist, cur_title)) + print_("To:") + print_(" %s - %s" % (new_artist, new_title)) + + else: + print_("Tagging track: %s - %s" % (cur_artist, cur_title)) + + # Data URL. + if match.info.data_url: + print_('URL:\n %s' % match.info.data_url) + + # Info line. + info = [] + # Similarity. + info.append('(Similarity: %s)' % dist_string(match.distance)) + # Penalties. + penalties = penalty_string(match.distance) + if penalties: + info.append(penalties) + # Disambiguation. + disambig = disambig_string(match.info) + if disambig: + info.append(ui.colorize('lightgray', '(%s)' % disambig)) + print_(' '.join(info)) + + +def summarize_items(items, singleton): + """Produces a brief summary line describing a set of items. Used for + manually resolving duplicates during import. + + `items` is a list of `Item` objects. `singleton` indicates whether + this is an album or single-item import (if the latter, them `items` + should only have one element). + """ + summary_parts = [] + if not singleton: + summary_parts.append("{0} items".format(len(items))) + + format_counts = {} + for item in items: + format_counts[item.format] = format_counts.get(item.format, 0) + 1 + if len(format_counts) == 1: + # A single format. + summary_parts.append(items[0].format) + else: + # Enumerate all the formats. + for format, count in format_counts.iteritems(): + summary_parts.append('{0} {1}'.format(format, count)) + + average_bitrate = sum([item.bitrate for item in items]) / len(items) + total_duration = sum([item.length for item in items]) + summary_parts.append('{0}kbps'.format(int(average_bitrate / 1000))) + summary_parts.append(ui.human_seconds_short(total_duration)) + + return ', '.join(summary_parts) + + +def _summary_judment(rec): + """Determines whether a decision should be made without even asking + the user. This occurs in quiet mode and when an action is chosen for + NONE recommendations. Return an action or None if the user should be + queried. May also print to the console if a summary judgment is + made. + """ + if config['import']['quiet']: + if rec == Recommendation.strong: + return importer.action.APPLY + else: + action = config['import']['quiet_fallback'].as_choice({ + 'skip': importer.action.SKIP, + 'asis': importer.action.ASIS, + }) + + elif rec == Recommendation.none: + action = config['import']['none_rec_action'].as_choice({ + 'skip': importer.action.SKIP, + 'asis': importer.action.ASIS, + 'ask': None, + }) + + else: + return None + + if action == importer.action.SKIP: + print_('Skipping.') + elif action == importer.action.ASIS: + print_('Importing as-is.') + return action + + +def choose_candidate(candidates, singleton, rec, cur_artist=None, + cur_album=None, item=None, itemcount=None): + """Given a sorted list of candidates, ask the user for a selection + of which candidate to use. Applies to both full albums and + singletons (tracks). Candidates are either AlbumMatch or TrackMatch + objects depending on `singleton`. for albums, `cur_artist`, + `cur_album`, and `itemcount` must be provided. For singletons, + `item` must be provided. + + Returns the result of the choice, which may SKIP, ASIS, TRACKS, or + MANUAL or a candidate (an AlbumMatch/TrackMatch object). + """ + # Sanity check. + if singleton: + assert item is not None + else: + assert cur_artist is not None + assert cur_album is not None + + # Zero candidates. + if not candidates: + if singleton: + print_("No matching recordings found.") + opts = ('Use as-is', 'Skip', 'Enter search', 'enter Id', + 'aBort') + else: + print_("No matching release found for {0} tracks." + .format(itemcount)) + print_('For help, see: ' + 'http://beets.readthedocs.org/en/latest/faq.html#nomatch') + opts = ('Use as-is', 'as Tracks', 'Group albums', 'Skip', + 'Enter search', 'enter Id', 'aBort') + sel = ui.input_options(opts) + if sel == 'u': + return importer.action.ASIS + elif sel == 't': + assert not singleton + return importer.action.TRACKS + elif sel == 'e': + return importer.action.MANUAL + elif sel == 's': + return importer.action.SKIP + elif sel == 'b': + raise importer.ImportAbort() + elif sel == 'i': + return importer.action.MANUAL_ID + elif sel == 'g': + return importer.action.ALBUMS + else: + assert False + + # Is the change good enough? + bypass_candidates = False + if rec != Recommendation.none: + match = candidates[0] + bypass_candidates = True + + while True: + # Display and choose from candidates. + require = rec <= Recommendation.low + + if not bypass_candidates: + # Display list of candidates. + print_(u'Finding tags for {0} "{1} - {2}".'.format( + u'track' if singleton else u'album', + item.artist if singleton else cur_artist, + item.title if singleton else cur_album, + )) + + print_(u'Candidates:') + for i, match in enumerate(candidates): + # Index, metadata, and distance. + line = [ + u'{0}.'.format(i + 1), + u'{0} - {1}'.format( + match.info.artist, + match.info.title if singleton else match.info.album, + ), + u'({0})'.format(dist_string(match.distance)), + ] + + # Penalties. + penalties = penalty_string(match.distance, 3) + if penalties: + line.append(penalties) + + # Disambiguation + disambig = disambig_string(match.info) + if disambig: + line.append(ui.colorize('lightgray', '(%s)' % disambig)) + + print_(' '.join(line)) + + # Ask the user for a choice. + if singleton: + opts = ('Skip', 'Use as-is', 'Enter search', 'enter Id', + 'aBort') + else: + opts = ('Skip', 'Use as-is', 'as Tracks', 'Group albums', + 'Enter search', 'enter Id', 'aBort') + sel = ui.input_options(opts, numrange=(1, len(candidates))) + if sel == 's': + return importer.action.SKIP + elif sel == 'u': + return importer.action.ASIS + elif sel == 'm': + pass + elif sel == 'e': + return importer.action.MANUAL + elif sel == 't': + assert not singleton + return importer.action.TRACKS + elif sel == 'b': + raise importer.ImportAbort() + elif sel == 'i': + return importer.action.MANUAL_ID + elif sel == 'g': + return importer.action.ALBUMS + else: # Numerical selection. + match = candidates[sel - 1] + if sel != 1: + # When choosing anything but the first match, + # disable the default action. + require = True + bypass_candidates = False + + # Show what we're about to do. + if singleton: + show_item_change(item, match) + else: + show_change(cur_artist, cur_album, match) + + # Exact match => tag automatically if we're not in timid mode. + if rec == Recommendation.strong and not config['import']['timid']: + return match + + # Ask for confirmation. + if singleton: + opts = ('Apply', 'More candidates', 'Skip', 'Use as-is', + 'Enter search', 'enter Id', 'aBort') + else: + opts = ('Apply', 'More candidates', 'Skip', 'Use as-is', + 'as Tracks', 'Group albums', 'Enter search', 'enter Id', + 'aBort') + default = config['import']['default_action'].as_choice({ + 'apply': 'a', + 'skip': 's', + 'asis': 'u', + 'none': None, + }) + if default is None: + require = True + sel = ui.input_options(opts, require=require, default=default) + if sel == 'a': + return match + elif sel == 'g': + return importer.action.ALBUMS + elif sel == 's': + return importer.action.SKIP + elif sel == 'u': + return importer.action.ASIS + elif sel == 't': + assert not singleton + return importer.action.TRACKS + elif sel == 'e': + return importer.action.MANUAL + elif sel == 'b': + raise importer.ImportAbort() + elif sel == 'i': + return importer.action.MANUAL_ID + + +def manual_search(singleton): + """Input either an artist and album (for full albums) or artist and + track name (for singletons) for manual search. + """ + artist = input_('Artist:') + name = input_('Track:' if singleton else 'Album:') + return artist.strip(), name.strip() + + +def manual_id(singleton): + """Input an ID, either for an album ("release") or a track ("recording"). + """ + prompt = u'Enter {0} ID:'.format('recording' if singleton else 'release') + return input_(prompt).strip() + + +class TerminalImportSession(importer.ImportSession): + """An import session that runs in a terminal. + """ + def choose_match(self, task): + """Given an initial autotagging of items, go through an interactive + dance with the user to ask for a choice of metadata. Returns an + AlbumMatch object, ASIS, or SKIP. + """ + # Show what we're tagging. + print_() + print_(displayable_path(task.paths, u'\n') + + u' ({0} items)'.format(len(task.items))) + + # Take immediate action if appropriate. + action = _summary_judment(task.rec) + if action == importer.action.APPLY: + match = task.candidates[0] + show_change(task.cur_artist, task.cur_album, match) + return match + elif action is not None: + return action + + # Loop until we have a choice. + candidates, rec = task.candidates, task.rec + while True: + # Ask for a choice from the user. + choice = choose_candidate( + candidates, False, rec, task.cur_artist, task.cur_album, + itemcount=len(task.items) + ) + + # Choose which tags to use. + if choice in (importer.action.SKIP, importer.action.ASIS, + importer.action.TRACKS, importer.action.ALBUMS): + # Pass selection to main control flow. + return choice + elif choice is importer.action.MANUAL: + # Try again with manual search terms. + search_artist, search_album = manual_search(False) + _, _, candidates, rec = autotag.tag_album( + task.items, search_artist, search_album + ) + elif choice is importer.action.MANUAL_ID: + # Try a manually-entered ID. + search_id = manual_id(False) + if search_id: + _, _, candidates, rec = autotag.tag_album( + task.items, search_id=search_id + ) + else: + # We have a candidate! Finish tagging. Here, choice is an + # AlbumMatch object. + assert isinstance(choice, autotag.AlbumMatch) + return choice + + def choose_item(self, task): + """Ask the user for a choice about tagging a single item. Returns + either an action constant or a TrackMatch object. + """ + print_() + print_(task.item.path) + candidates, rec = task.candidates, task.rec + + # Take immediate action if appropriate. + action = _summary_judment(task.rec) + if action == importer.action.APPLY: + match = candidates[0] + show_item_change(task.item, match) + return match + elif action is not None: + return action + + while True: + # Ask for a choice. + choice = choose_candidate(candidates, True, rec, item=task.item) + + if choice in (importer.action.SKIP, importer.action.ASIS): + return choice + elif choice == importer.action.TRACKS: + assert False # TRACKS is only legal for albums. + elif choice == importer.action.MANUAL: + # Continue in the loop with a new set of candidates. + search_artist, search_title = manual_search(True) + candidates, rec = autotag.tag_item(task.item, search_artist, + search_title) + elif choice == importer.action.MANUAL_ID: + # Ask for a track ID. + search_id = manual_id(True) + if search_id: + candidates, rec = autotag.tag_item(task.item, + search_id=search_id) + else: + # Chose a candidate. + assert isinstance(choice, autotag.TrackMatch) + return choice + + def resolve_duplicate(self, task, found_duplicates): + """Decide what to do when a new album or item seems similar to one + that's already in the library. + """ + log.warn(u"This {0} is already in the library!" + .format("album" if task.is_album else "item")) + + if config['import']['quiet']: + # In quiet mode, don't prompt -- just skip. + log.info(u'Skipping.') + sel = 's' + else: + # Print some detail about the existing and new items so the + # user can make an informed decision. + for duplicate in found_duplicates: + print("Old: " + summarize_items( + list(duplicate.items()) if task.is_album else [duplicate], + not task.is_album, + )) + print("New: " + summarize_items( + task.imported_items(), + not task.is_album, + )) + + sel = ui.input_options( + ('Skip new', 'Keep both', 'Remove old') + ) + + if sel == 's': + # Skip new. + task.set_choice(importer.action.SKIP) + elif sel == 'k': + # Keep both. Do nothing; leave the choice intact. + pass + elif sel == 'r': + # Remove old. + task.should_remove_duplicates = True + else: + assert False + + def should_resume(self, path): + return ui.input_yn(u"Import of the directory:\n{0}\n" + "was interrupted. Resume (Y/n)?" + .format(displayable_path(path))) + +# The import command. + + +def import_files(lib, paths, query): + """Import the files in the given list of paths or matching the + query. + """ + # Check the user-specified directories. + for path in paths: + if not os.path.exists(syspath(normpath(path))): + raise ui.UserError(u'no such file or directory: {0}'.format( + displayable_path(path))) + + # Check parameter consistency. + if config['import']['quiet'] and config['import']['timid']: + raise ui.UserError("can't be both quiet and timid") + + # Open the log. + if config['import']['log'].get() is not None: + logpath = config['import']['log'].as_filename() + try: + logfile = codecs.open(syspath(logpath), 'a', 'utf8') + except IOError: + raise ui.UserError(u"could not open log file for writing: %s" % + displayable_path(logpath)) + print(u'import started', time.asctime(), file=logfile) + else: + logfile = None + + # Never ask for input in quiet mode. + if config['import']['resume'].get() == 'ask' and \ + config['import']['quiet']: + config['import']['resume'] = False + + session = TerminalImportSession(lib, logfile, paths, query) + try: + session.run() + finally: + # If we were logging, close the file. + if logfile: + print(u'', file=logfile) + logfile.close() + + # Emit event. + plugins.send('import', lib=lib, paths=paths) + + +def import_func(lib, opts, args): + config['import'].set_args(opts) + + # Special case: --copy flag suppresses import_move (which would + # otherwise take precedence). + if opts.copy: + config['import']['move'] = False + + if opts.library: + query = decargs(args) + paths = [] + else: + query = None + paths = args + if not paths: + raise ui.UserError('no path specified') + + import_files(lib, paths, query) + + +import_cmd = ui.Subcommand( + 'import', help='import new music', aliases=('imp', 'im') +) +import_cmd.parser.add_option( + '-c', '--copy', action='store_true', default=None, + help="copy tracks into library directory (default)" +) +import_cmd.parser.add_option( + '-C', '--nocopy', action='store_false', dest='copy', + help="don't copy tracks (opposite of -c)" +) +import_cmd.parser.add_option( + '-w', '--write', action='store_true', default=None, + help="write new metadata to files' tags (default)" +) +import_cmd.parser.add_option( + '-W', '--nowrite', action='store_false', dest='write', + help="don't write metadata (opposite of -w)" +) +import_cmd.parser.add_option( + '-a', '--autotag', action='store_true', dest='autotag', + help="infer tags for imported files (default)" +) +import_cmd.parser.add_option( + '-A', '--noautotag', action='store_false', dest='autotag', + help="don't infer tags for imported files (opposite of -a)" +) +import_cmd.parser.add_option( + '-p', '--resume', action='store_true', default=None, + help="resume importing if interrupted" +) +import_cmd.parser.add_option( + '-P', '--noresume', action='store_false', dest='resume', + help="do not try to resume importing" +) +import_cmd.parser.add_option( + '-q', '--quiet', action='store_true', dest='quiet', + help="never prompt for input: skip albums instead" +) +import_cmd.parser.add_option( + '-l', '--log', dest='log', + help='file to log untaggable albums for later review' +) +import_cmd.parser.add_option( + '-s', '--singletons', action='store_true', + help='import individual tracks instead of full albums' +) +import_cmd.parser.add_option( + '-t', '--timid', dest='timid', action='store_true', + help='always confirm all actions' +) +import_cmd.parser.add_option( + '-L', '--library', dest='library', action='store_true', + help='retag items matching a query' +) +import_cmd.parser.add_option( + '-i', '--incremental', dest='incremental', action='store_true', + help='skip already-imported directories' +) +import_cmd.parser.add_option( + '-I', '--noincremental', dest='incremental', action='store_false', + help='do not skip already-imported directories' +) +import_cmd.parser.add_option( + '--flat', dest='flat', action='store_true', + help='import an entire tree as a single album' +) +import_cmd.parser.add_option( + '-g', '--group-albums', dest='group_albums', action='store_true', + help='group tracks in a folder into separate albums' +) +import_cmd.parser.add_option( + '--pretend', dest='pretend', action='store_true', + help='just print the files to import' +) +import_cmd.func = import_func +default_commands.append(import_cmd) + + +# list: Query and show library contents. + +def list_items(lib, query, album, fmt): + """Print out items in lib matching query. If album, then search for + albums instead of single items. + """ + tmpl = Template(ui._pick_format(album, fmt)) + if album: + for album in lib.albums(query): + ui.print_obj(album, lib, tmpl) + else: + for item in lib.items(query): + ui.print_obj(item, lib, tmpl) + + +def list_func(lib, opts, args): + if opts.path: + fmt = '$path' + else: + fmt = opts.format + list_items(lib, decargs(args), opts.album, fmt) + + +list_cmd = ui.Subcommand('list', help='query the library', aliases=('ls',)) +list_cmd.parser.add_option( + '-a', '--album', action='store_true', + help='show matching albums instead of tracks' +) +list_cmd.parser.add_option( + '-p', '--path', action='store_true', + help='print paths for matched items or albums' +) +list_cmd.parser.add_option( + '-f', '--format', action='store', + help='print with custom format', default=None +) +list_cmd.func = list_func +default_commands.append(list_cmd) + + +# update: Update library contents according to on-disk tags. + +def update_items(lib, query, album, move, pretend): + """For all the items matched by the query, update the library to + reflect the item's embedded tags. + """ + with lib.transaction(): + items, _ = _do_query(lib, query, album) + + # Walk through the items and pick up their changes. + affected_albums = set() + for item in items: + # Item deleted? + if not os.path.exists(syspath(item.path)): + ui.print_obj(item, lib) + ui.print_(ui.colorize('red', u' deleted')) + if not pretend: + item.remove(True) + affected_albums.add(item.album_id) + continue + + # Did the item change since last checked? + if item.current_mtime() <= item.mtime: + log.debug(u'skipping {0} because mtime is up to date ({1})' + .format(displayable_path(item.path), item.mtime)) + continue + + # Read new data. + try: + item.read() + except library.ReadError as exc: + log.error(u'error reading {0}: {1}'.format( + displayable_path(item.path), exc)) + continue + + # Special-case album artist when it matches track artist. (Hacky + # but necessary for preserving album-level metadata for non- + # autotagged imports.) + if not item.albumartist: + old_item = lib.get_item(item.id) + if old_item.albumartist == old_item.artist == item.artist: + item.albumartist = old_item.albumartist + item._dirty.discard('albumartist') + + # Check for and display changes. + changed = ui.show_model_changes(item, + fields=library.Item._media_fields) + + # Save changes. + if not pretend: + if changed: + # Move the item if it's in the library. + if move and lib.directory in ancestry(item.path): + item.move() + + item.store() + affected_albums.add(item.album_id) + else: + # The file's mtime was different, but there were no + # changes to the metadata. Store the new mtime, + # which is set in the call to read(), so we don't + # check this again in the future. + item.store() + + # Skip album changes while pretending. + if pretend: + return + + # Modify affected albums to reflect changes in their items. + for album_id in affected_albums: + if album_id is None: # Singletons. + continue + album = lib.get_album(album_id) + if not album: # Empty albums have already been removed. + log.debug(u'emptied album {0}'.format(album_id)) + continue + first_item = album.items().get() + + # Update album structure to reflect an item in it. + for key in library.Album.item_keys: + album[key] = first_item[key] + album.store() + + # Move album art (and any inconsistent items). + if move and lib.directory in ancestry(first_item.path): + log.debug(u'moving album {0}'.format(album_id)) + album.move() + + +def update_func(lib, opts, args): + update_items(lib, decargs(args), opts.album, opts.move, opts.pretend) + + +update_cmd = ui.Subcommand( + 'update', help='update the library', aliases=('upd', 'up',) +) +update_cmd.parser.add_option( + '-a', '--album', action='store_true', + help='match albums instead of tracks' +) +update_cmd.parser.add_option( + '-M', '--nomove', action='store_false', default=True, dest='move', + help="don't move files in library" +) +update_cmd.parser.add_option( + '-p', '--pretend', action='store_true', + help="show all changes but do nothing" +) +update_cmd.parser.add_option( + '-f', '--format', action='store', + help='print with custom format', default=None +) +update_cmd.func = update_func +default_commands.append(update_cmd) + + +# remove: Remove items from library, delete files. + +def remove_items(lib, query, album, delete): + """Remove items matching query from lib. If album, then match and + remove whole albums. If delete, also remove files from disk. + """ + # Get the matching items. + items, albums = _do_query(lib, query, album) + + # Prepare confirmation with user. + print_() + if delete: + fmt = u'$path - $title' + prompt = 'Really DELETE %i files (y/n)?' % len(items) + else: + fmt = None + prompt = 'Really remove %i items from the library (y/n)?' % \ + len(items) + + # Show all the items. + for item in items: + ui.print_obj(item, lib, fmt) + + # Confirm with user. + if not ui.input_yn(prompt, True): + return + + # Remove (and possibly delete) items. + with lib.transaction(): + for obj in (albums if album else items): + obj.remove(delete) + + +def remove_func(lib, opts, args): + remove_items(lib, decargs(args), opts.album, opts.delete) + + +remove_cmd = ui.Subcommand( + 'remove', help='remove matching items from the library', aliases=('rm',) +) +remove_cmd.parser.add_option( + "-d", "--delete", action="store_true", + help="also remove files from disk" +) +remove_cmd.parser.add_option( + '-a', '--album', action='store_true', + help='match albums instead of tracks' +) +remove_cmd.func = remove_func +default_commands.append(remove_cmd) + + +# stats: Show library/query statistics. + +def show_stats(lib, query, exact): + """Shows some statistics about the matched items.""" + items = lib.items(query) + + total_size = 0 + total_time = 0.0 + total_items = 0 + artists = set() + albums = set() + album_artists = set() + + for item in items: + if exact: + total_size += os.path.getsize(item.path) + else: + total_size += int(item.length * item.bitrate / 8) + total_time += item.length + total_items += 1 + artists.add(item.artist) + album_artists.add(item.albumartist) + if item.album_id: + albums.add(item.album_id) + + size_str = '' + ui.human_bytes(total_size) + if exact: + size_str += ' ({0} bytes)'.format(total_size) + + print_("""Tracks: {0} +Total time: {1}{2} +{3}: {4} +Artists: {5} +Albums: {6} +Album artists: {7}""".format( + total_items, + ui.human_seconds(total_time), + ' ({0:.2f} seconds)'.format(total_time) if exact else '', + 'Total size' if exact else 'Approximate total size', + size_str, + len(artists), + len(albums), + len(album_artists)), + ) + + +def stats_func(lib, opts, args): + show_stats(lib, decargs(args), opts.exact) + + +stats_cmd = ui.Subcommand( + 'stats', help='show statistics about the library or a query' +) +stats_cmd.parser.add_option( + '-e', '--exact', action='store_true', + help='exact size and time' +) +stats_cmd.func = stats_func +default_commands.append(stats_cmd) + + +# version: Show current beets version. + +def show_version(lib, opts, args): + print_('beets version %s' % beets.__version__) + # Show plugins. + names = [p.name for p in plugins.find_plugins()] + if names: + print_('plugins:', ', '.join(names)) + else: + print_('no plugins loaded') + + +version_cmd = ui.Subcommand( + 'version', help='output version information' +) +version_cmd.func = show_version +default_commands.append(version_cmd) + + +# modify: Declaratively change metadata. + +def modify_items(lib, mods, dels, query, write, move, album, confirm): + """Modifies matching items according to user-specified assignments and + deletions. + + `mods` is a dictionary of field and value pairse indicating + assignments. `dels` is a list of fields to be deleted. + """ + # Parse key=value specifications into a dictionary. + model_cls = library.Album if album else library.Item + + for key, value in mods.items(): + mods[key] = model_cls._parse(key, value) + + # Get the items to modify. + items, albums = _do_query(lib, query, album, False) + objs = albums if album else items + + # Apply changes *temporarily*, preview them, and collect modified + # objects. + print_('Modifying {0} {1}s.' + .format(len(objs), 'album' if album else 'item')) + changed = set() + for obj in objs: + obj.update(mods) + for field in dels: + try: + del obj[field] + except KeyError: + pass + if ui.show_model_changes(obj): + changed.add(obj) + + # Still something to do? + if not changed: + print_('No changes to make.') + return + + # Confirm action. + if confirm: + if write and move: + extra = ', move and write tags' + elif write: + extra = ' and write tags' + elif move: + extra = ' and move' + else: + extra = '' + + if not ui.input_yn('Really modify%s (Y/n)?' % extra): + return + + # Apply changes to database and files + with lib.transaction(): + for obj in changed: + if move: + cur_path = obj.path + if lib.directory in ancestry(cur_path): # In library? + log.debug(u'moving object {0}' + .format(displayable_path(cur_path))) + obj.move() + + obj.try_sync(write) + + +def modify_parse_args(args): + """Split the arguments for the modify subcommand into query parts, + assignments (field=value), and deletions (field!). Returns the result as + a three-tuple in that order. + """ + mods = {} + dels = [] + query = [] + for arg in args: + if arg.endswith('!') and '=' not in arg and ':' not in arg: + dels.append(arg[:-1]) # Strip trailing !. + elif '=' in arg and ':' not in arg.split('=', 1)[0]: + key, val = arg.split('=', 1) + mods[key] = val + else: + query.append(arg) + return query, mods, dels + + +def modify_func(lib, opts, args): + query, mods, dels = modify_parse_args(decargs(args)) + if not mods and not dels: + raise ui.UserError('no modifications specified') + write = opts.write if opts.write is not None else \ + config['import']['write'].get(bool) + modify_items(lib, mods, dels, query, write, opts.move, opts.album, + not opts.yes) + + +modify_cmd = ui.Subcommand( + 'modify', help='change metadata fields', aliases=('mod',) +) +modify_cmd.parser.add_option( + '-M', '--nomove', action='store_false', default=True, dest='move', + help="don't move files in library" +) +modify_cmd.parser.add_option( + '-w', '--write', action='store_true', default=None, + help="write new metadata to files' tags (default)" +) +modify_cmd.parser.add_option( + '-W', '--nowrite', action='store_false', dest='write', + help="don't write metadata (opposite of -w)" +) +modify_cmd.parser.add_option( + '-a', '--album', action='store_true', + help='modify whole albums instead of tracks' +) +modify_cmd.parser.add_option( + '-y', '--yes', action='store_true', + help='skip confirmation' +) +modify_cmd.parser.add_option( + '-f', '--format', action='store', + help='print with custom format', default=None +) +modify_cmd.func = modify_func +default_commands.append(modify_cmd) + + +# move: Move/copy files to the library or a new base directory. + +def move_items(lib, dest, query, copy, album): + """Moves or copies items to a new base directory, given by dest. If + dest is None, then the library's base directory is used, making the + command "consolidate" files. + """ + items, albums = _do_query(lib, query, album, False) + objs = albums if album else items + + action = 'Copying' if copy else 'Moving' + entity = 'album' if album else 'item' + log.info(u'{0} {1} {2}s.'.format(action, len(objs), entity)) + for obj in objs: + log.debug(u'moving: {0}'.format(util.displayable_path(obj.path))) + + obj.move(copy, basedir=dest) + obj.store() + + +def move_func(lib, opts, args): + dest = opts.dest + if dest is not None: + dest = normpath(dest) + if not os.path.isdir(dest): + raise ui.UserError('no such directory: %s' % dest) + + move_items(lib, dest, decargs(args), opts.copy, opts.album) + + +move_cmd = ui.Subcommand( + 'move', help='move or copy items', aliases=('mv',) +) +move_cmd.parser.add_option( + '-d', '--dest', metavar='DIR', dest='dest', + help='destination directory' +) +move_cmd.parser.add_option( + '-c', '--copy', default=False, action='store_true', + help='copy instead of moving' +) +move_cmd.parser.add_option( + '-a', '--album', default=False, action='store_true', + help='match whole albums instead of tracks' +) +move_cmd.func = move_func +default_commands.append(move_cmd) + + +# write: Write tags into files. + +def write_items(lib, query, pretend, force): + """Write tag information from the database to the respective files + in the filesystem. + """ + items, albums = _do_query(lib, query, False, False) + + for item in items: + # Item deleted? + if not os.path.exists(syspath(item.path)): + log.info(u'missing file: {0}'.format( + util.displayable_path(item.path) + )) + continue + + # Get an Item object reflecting the "clean" (on-disk) state. + try: + clean_item = library.Item.from_path(item.path) + except library.ReadError as exc: + log.error(u'error reading {0}: {1}'.format( + displayable_path(item.path), exc + )) + continue + + # Check for and display changes. + changed = ui.show_model_changes(item, clean_item, + library.Item._media_fields, force) + if (changed or force) and not pretend: + item.try_sync() + + +def write_func(lib, opts, args): + write_items(lib, decargs(args), opts.pretend, opts.force) + + +write_cmd = ui.Subcommand('write', help='write tag information to files') +write_cmd.parser.add_option( + '-p', '--pretend', action='store_true', + help="show all changes but do nothing" +) +write_cmd.parser.add_option( + '-f', '--force', action='store_true', + help="write tags even if the existing tags match the database" +) +write_cmd.func = write_func +default_commands.append(write_cmd) + + +# config: Show and edit user configuration. + +def config_func(lib, opts, args): + # Make sure lazy configuration is loaded + config.resolve() + + # Print paths. + if opts.paths: + filenames = [] + for source in config.sources: + if not opts.defaults and source.default: + continue + if source.filename: + filenames.append(source.filename) + + # In case the user config file does not exist, prepend it to the + # list. + user_path = config.user_config_path() + if user_path not in filenames: + filenames.insert(0, user_path) + + for filename in filenames: + print(filename) + + # Open in editor. + elif opts.edit: + config_edit() + + # Dump configuration. + else: + print(config.dump(full=opts.defaults)) + + +def config_edit(): + """Open a program to edit the user configuration. + """ + path = config.user_config_path() + + if 'EDITOR' in os.environ: + editor = os.environ['EDITOR'] + try: + editor = shlex.split(editor) + except ValueError: # Malformed shell tokens. + editor = [editor] + args = editor + [path] + args.insert(1, args[0]) + elif platform.system() == 'Darwin': + args = ['open', 'open', '-n', path] + elif platform.system() == 'Windows': + # On windows we can execute arbitrary files. The os will + # take care of starting an appropriate application + args = [path, path] + else: + # Assume Unix + args = ['xdg-open', 'xdg-open', path] + + try: + os.execlp(*args) + except OSError: + raise ui.UserError("Could not edit configuration. Please " + "set the EDITOR environment variable.") + + +config_cmd = ui.Subcommand('config', + help='show or edit the user configuration') +config_cmd.parser.add_option( + '-p', '--paths', action='store_true', + help='show files that configuration was loaded from' +) +config_cmd.parser.add_option( + '-e', '--edit', action='store_true', + help='edit user configuration with $EDITOR' +) +config_cmd.parser.add_option( + '-d', '--defaults', action='store_true', + help='include the default configuration' +) +config_cmd.func = config_func +default_commands.append(config_cmd) + + +# completion: print completion script + +def print_completion(*args): + for line in completion_script(default_commands + plugins.commands()): + print(line, end='') + if not any(map(os.path.isfile, BASH_COMPLETION_PATHS)): + log.warn(u'Warning: Unable to find the bash-completion package. ' + u'Command line completion might not work.') + +BASH_COMPLETION_PATHS = map(syspath, [ + u'/etc/bash_completion', + u'/usr/share/bash-completion/bash_completion', + u'/usr/share/local/bash-completion/bash_completion', + u'/opt/local/share/bash-completion/bash_completion', # SmartOS + u'/usr/local/etc/bash_completion', # Homebrew +]) + + +def completion_script(commands): + """Yield the full completion shell script as strings. + + ``commands`` is alist of ``ui.Subcommand`` instances to generate + completion data for. + """ + base_script = os.path.join(_package_path('beets.ui'), 'completion_base.sh') + with open(base_script, 'r') as base_script: + yield base_script.read() + + options = {} + aliases = {} + command_names = [] + + # Collect subcommands + for cmd in commands: + name = cmd.name + command_names.append(name) + + for alias in cmd.aliases: + if re.match(r'^\w+$', alias): + aliases[alias] = name + + options[name] = {'flags': [], 'opts': []} + for opts in cmd.parser._get_all_options()[1:]: + if opts.action in ('store_true', 'store_false'): + option_type = 'flags' + else: + option_type = 'opts' + + options[name][option_type].extend( + opts._short_opts + opts._long_opts + ) + + # Add global options + options['_global'] = { + 'flags': ['-v', '--verbose'], + 'opts': '-l --library -c --config -d --directory -h --help'.split(' ') + } + + # Add flags common to all commands + options['_common'] = { + 'flags': ['-h', '--help'] + } + + # Start generating the script + yield "_beet() {\n" + + # Command names + yield " local commands='%s'\n" % ' '.join(command_names) + yield "\n" + + # Command aliases + yield " local aliases='%s'\n" % ' '.join(aliases.keys()) + for alias, cmd in aliases.items(): + yield " local alias__%s=%s\n" % (alias, cmd) + yield '\n' + + # Fields + yield " fields='%s'\n" % ' '.join( + set(library.Item._fields.keys() + library.Album._fields.keys()) + ) + + # Command options + for cmd, opts in options.items(): + for option_type, option_list in opts.items(): + if option_list: + option_list = ' '.join(option_list) + yield " local %s__%s='%s'\n" % (option_type, cmd, option_list) + + yield ' _beet_dispatch\n' + yield '}\n' + + +completion_cmd = ui.Subcommand( + 'completion', + help='print shell script that provides command line completion' +) +completion_cmd.func = print_completion +completion_cmd.hide = True +default_commands.append(completion_cmd) diff --git a/lib/beets/ui/completion_base.sh b/lib/beets/ui/completion_base.sh new file mode 100644 index 00000000..ce3fb6e2 --- /dev/null +++ b/lib/beets/ui/completion_base.sh @@ -0,0 +1,162 @@ +# This file is part of beets. +# Copyright (c) 2014, Thomas Scholtes. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + + + +# Completion for the `beet` command +# ================================= +# +# Load this script to complete beets subcommands, options, and +# queries. +# +# If a beets command is found on the command line it completes filenames and +# the subcommand's options. Otherwise it will complete global options and +# subcommands. If the previous option on the command line expects an argument, +# it also completes filenames or directories. Options are only +# completed if '-' has already been typed on the command line. +# +# Note that completion of plugin commands only works for those plugins +# that were enabled when running `beet completion`. It does not check +# plugins dynamically +# +# Currently, only Bash 3.2 and newer is supported and the +# `bash-completion` package is requied. +# +# TODO +# ---- +# +# * There are some issues with arguments that are quoted on the command line. +# +# * Complete arguments for the `--format` option by expanding field variables. +# +# beet ls -f "$tit[TAB] +# beet ls -f "$title +# +# * Support long options with `=`, e.g. `--config=file`. Debian's bash +# completion package can handle this. +# + + +# Determines the beets subcommand and dispatches the completion +# accordingly. +_beet_dispatch() { + local cur prev cmd= + + COMPREPLY=() + _get_comp_words_by_ref -n : cur prev + + # Look for the beets subcommand + local arg + for (( i=1; i < COMP_CWORD; i++ )); do + arg="${COMP_WORDS[i]}" + if _list_include_item "${opts___global}" $arg; then + ((i++)) + elif [[ "$arg" != -* ]]; then + cmd="$arg" + break + fi + done + + # Replace command shortcuts + if [[ -n $cmd ]] && _list_include_item "$aliases" "$cmd"; then + eval "cmd=\$alias__$cmd" + fi + + case $cmd in + help) + COMPREPLY+=( $(compgen -W "$commands" -- $cur) ) + ;; + list|remove|move|update|write|stats) + _beet_complete_query + ;; + "") + _beet_complete_global + ;; + *) + _beet_complete + ;; + esac +} + + +# Adds option and file completion to COMPREPLY for the subcommand $cmd +_beet_complete() { + if [[ $cur == -* ]]; then + local opts flags completions + eval "opts=\$opts__$cmd" + eval "flags=\$flags__$cmd" + completions="${flags___common} ${opts} ${flags}" + COMPREPLY+=( $(compgen -W "$completions" -- $cur) ) + else + _filedir + fi +} + + +# Add global options and subcommands to the completion +_beet_complete_global() { + case $prev in + -h|--help) + # Complete commands + COMPREPLY+=( $(compgen -W "$commands" -- $cur) ) + return + ;; + -l|--library|-c|--config) + # Filename completion + _filedir + return + ;; + -d|--directory) + # Directory completion + _filedir -d + return + ;; + esac + + if [[ $cur == -* ]]; then + local completions="$opts___global $flags___global" + COMPREPLY+=( $(compgen -W "$completions" -- $cur) ) + elif [[ -n $cur ]] && _list_include_item "$aliases" "$cur"; then + local cmd + eval "cmd=\$alias__$cur" + COMPREPLY+=( "$cmd" ) + else + COMPREPLY+=( $(compgen -W "$commands" -- $cur) ) + fi +} + +_beet_complete_query() { + local opts + eval "opts=\$opts__$cmd" + + if [[ $cur == -* ]] || _list_include_item "$opts" "$prev"; then + _beet_complete + elif [[ $cur != \'* && $cur != \"* && + $cur != *:* ]]; then + # Do not complete quoted queries or those who already have a field + # set. + compopt -o nospace + COMPREPLY+=( $(compgen -S : -W "$fields" -- $cur) ) + return 0 + fi +} + +# Returns true if the space separated list $1 includes $2 +_list_include_item() { + [[ " $1 " == *[[:space:]]$2[[:space:]]* ]] +} + +# This is where beets dynamically adds the _beet function. This +# function sets the variables $flags, $opts, $commands, and $aliases. +complete -o filenames -F _beet beet diff --git a/lib/beets/util/__init__.py b/lib/beets/util/__init__.py new file mode 100644 index 00000000..38cecd70 --- /dev/null +++ b/lib/beets/util/__init__.py @@ -0,0 +1,686 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Miscellaneous utility functions.""" +from __future__ import division + +import os +import sys +import re +import shutil +import fnmatch +from collections import defaultdict +import traceback +import subprocess +import platform + + +MAX_FILENAME_LENGTH = 200 +WINDOWS_MAGIC_PREFIX = u'\\\\?\\' + + +class HumanReadableException(Exception): + """An Exception that can include a human-readable error message to + be logged without a traceback. Can preserve a traceback for + debugging purposes as well. + + Has at least two fields: `reason`, the underlying exception or a + string describing the problem; and `verb`, the action being + performed during the error. + + If `tb` is provided, it is a string containing a traceback for the + associated exception. (Note that this is not necessary in Python 3.x + and should be removed when we make the transition.) + """ + error_kind = 'Error' # Human-readable description of error type. + + def __init__(self, reason, verb, tb=None): + self.reason = reason + self.verb = verb + self.tb = tb + super(HumanReadableException, self).__init__(self.get_message()) + + def _gerund(self): + """Generate a (likely) gerund form of the English verb. + """ + if ' ' in self.verb: + return self.verb + gerund = self.verb[:-1] if self.verb.endswith('e') else self.verb + gerund += 'ing' + return gerund + + def _reasonstr(self): + """Get the reason as a string.""" + if isinstance(self.reason, unicode): + return self.reason + elif isinstance(self.reason, basestring): # Byte string. + return self.reason.decode('utf8', 'ignore') + elif hasattr(self.reason, 'strerror'): # i.e., EnvironmentError + return self.reason.strerror + else: + return u'"{0}"'.format(unicode(self.reason)) + + def get_message(self): + """Create the human-readable description of the error, sans + introduction. + """ + raise NotImplementedError + + def log(self, logger): + """Log to the provided `logger` a human-readable message as an + error and a verbose traceback as a debug message. + """ + if self.tb: + logger.debug(self.tb) + logger.error(u'{0}: {1}'.format(self.error_kind, self.args[0])) + + +class FilesystemError(HumanReadableException): + """An error that occurred while performing a filesystem manipulation + via a function in this module. The `paths` field is a sequence of + pathnames involved in the operation. + """ + def __init__(self, reason, verb, paths, tb=None): + self.paths = paths + super(FilesystemError, self).__init__(reason, verb, tb) + + def get_message(self): + # Use a nicer English phrasing for some specific verbs. + if self.verb in ('move', 'copy', 'rename'): + clause = u'while {0} {1} to {2}'.format( + self._gerund(), + displayable_path(self.paths[0]), + displayable_path(self.paths[1]) + ) + elif self.verb in ('delete', 'write', 'create', 'read'): + clause = u'while {0} {1}'.format( + self._gerund(), + displayable_path(self.paths[0]) + ) + else: + clause = u'during {0} of paths {1}'.format( + self.verb, u', '.join(displayable_path(p) for p in self.paths) + ) + + return u'{0} {1}'.format(self._reasonstr(), clause) + + +def normpath(path): + """Provide the canonical form of the path suitable for storing in + the database. + """ + path = syspath(path, prefix=False) + path = os.path.normpath(os.path.abspath(os.path.expanduser(path))) + return bytestring_path(path) + + +def ancestry(path): + """Return a list consisting of path's parent directory, its + grandparent, and so on. For instance: + + >>> ancestry('/a/b/c') + ['/', '/a', '/a/b'] + + The argument should *not* be the result of a call to `syspath`. + """ + out = [] + last_path = None + while path: + path = os.path.dirname(path) + + if path == last_path: + break + last_path = path + + if path: + # don't yield '' + out.insert(0, path) + return out + + +def sorted_walk(path, ignore=(), logger=None): + """Like `os.walk`, but yields things in case-insensitive sorted, + breadth-first order. Directory and file names matching any glob + pattern in `ignore` are skipped. If `logger` is provided, then + warning messages are logged there when a directory cannot be listed. + """ + # Make sure the path isn't a Unicode string. + path = bytestring_path(path) + + # Get all the directories and files at this level. + try: + contents = os.listdir(syspath(path)) + except OSError as exc: + if logger: + logger.warn(u'could not list directory {0}: {1}'.format( + displayable_path(path), exc.strerror + )) + return + dirs = [] + files = [] + for base in contents: + base = bytestring_path(base) + + # Skip ignored filenames. + skip = False + for pat in ignore: + if fnmatch.fnmatch(base, pat): + skip = True + break + if skip: + continue + + # Add to output as either a file or a directory. + cur = os.path.join(path, base) + if os.path.isdir(syspath(cur)): + dirs.append(base) + else: + files.append(base) + + # Sort lists (case-insensitive) and yield the current level. + dirs.sort(key=bytes.lower) + files.sort(key=bytes.lower) + yield (path, dirs, files) + + # Recurse into directories. + for base in dirs: + cur = os.path.join(path, base) + # yield from sorted_walk(...) + for res in sorted_walk(cur, ignore, logger): + yield res + + +def mkdirall(path): + """Make all the enclosing directories of path (like mkdir -p on the + parent). + """ + for ancestor in ancestry(path): + if not os.path.isdir(syspath(ancestor)): + try: + os.mkdir(syspath(ancestor)) + except (OSError, IOError) as exc: + raise FilesystemError(exc, 'create', (ancestor,), + traceback.format_exc()) + + +def fnmatch_all(names, patterns): + """Determine whether all strings in `names` match at least one of + the `patterns`, which should be shell glob expressions. + """ + for name in names: + matches = False + for pattern in patterns: + matches = fnmatch.fnmatch(name, pattern) + if matches: + break + if not matches: + return False + return True + + +def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')): + """If path is an empty directory, then remove it. Recursively remove + path's ancestry up to root (which is never removed) where there are + empty directories. If path is not contained in root, then nothing is + removed. Glob patterns in clutter are ignored when determining + emptiness. If root is not provided, then only path may be removed + (i.e., no recursive removal). + """ + path = normpath(path) + if root is not None: + root = normpath(root) + + ancestors = ancestry(path) + if root is None: + # Only remove the top directory. + ancestors = [] + elif root in ancestors: + # Only remove directories below the root. + ancestors = ancestors[ancestors.index(root) + 1:] + else: + # Remove nothing. + return + + # Traverse upward from path. + ancestors.append(path) + ancestors.reverse() + for directory in ancestors: + directory = syspath(directory) + if not os.path.exists(directory): + # Directory gone already. + continue + if fnmatch_all(os.listdir(directory), clutter): + # Directory contains only clutter (or nothing). + try: + shutil.rmtree(directory) + except OSError: + break + else: + break + + +def components(path): + """Return a list of the path components in path. For instance: + + >>> components('/a/b/c') + ['a', 'b', 'c'] + + The argument should *not* be the result of a call to `syspath`. + """ + comps = [] + ances = ancestry(path) + for anc in ances: + comp = os.path.basename(anc) + if comp: + comps.append(comp) + else: # root + comps.append(anc) + + last = os.path.basename(path) + if last: + comps.append(last) + + return comps + + +def _fsencoding(): + """Get the system's filesystem encoding. On Windows, this is always + UTF-8 (not MBCS). + """ + encoding = sys.getfilesystemencoding() or sys.getdefaultencoding() + if encoding == 'mbcs': + # On Windows, a broken encoding known to Python as "MBCS" is + # used for the filesystem. However, we only use the Unicode API + # for Windows paths, so the encoding is actually immaterial so + # we can avoid dealing with this nastiness. We arbitrarily + # choose UTF-8. + encoding = 'utf8' + return encoding + + +def bytestring_path(path): + """Given a path, which is either a str or a unicode, returns a str + path (ensuring that we never deal with Unicode pathnames). + """ + # Pass through bytestrings. + if isinstance(path, str): + return path + + # On Windows, remove the magic prefix added by `syspath`. This makes + # ``bytestring_path(syspath(X)) == X``, i.e., we can safely + # round-trip through `syspath`. + if os.path.__name__ == 'ntpath' and path.startswith(WINDOWS_MAGIC_PREFIX): + path = path[len(WINDOWS_MAGIC_PREFIX):] + + # Try to encode with default encodings, but fall back to UTF8. + try: + return path.encode(_fsencoding()) + except (UnicodeError, LookupError): + return path.encode('utf8') + + +def displayable_path(path, separator=u'; '): + """Attempts to decode a bytestring path to a unicode object for the + purpose of displaying it to the user. If the `path` argument is a + list or a tuple, the elements are joined with `separator`. + """ + if isinstance(path, (list, tuple)): + return separator.join(displayable_path(p) for p in path) + elif isinstance(path, unicode): + return path + elif not isinstance(path, str): + # A non-string object: just get its unicode representation. + return unicode(path) + + try: + return path.decode(_fsencoding(), 'ignore') + except (UnicodeError, LookupError): + return path.decode('utf8', 'ignore') + + +def syspath(path, prefix=True): + """Convert a path for use by the operating system. In particular, + paths on Windows must receive a magic prefix and must be converted + to Unicode before they are sent to the OS. To disable the magic + prefix on Windows, set `prefix` to False---but only do this if you + *really* know what you're doing. + """ + # Don't do anything if we're not on windows + if os.path.__name__ != 'ntpath': + return path + + if not isinstance(path, unicode): + # Beets currently represents Windows paths internally with UTF-8 + # arbitrarily. But earlier versions used MBCS because it is + # reported as the FS encoding by Windows. Try both. + try: + path = path.decode('utf8') + except UnicodeError: + # The encoding should always be MBCS, Windows' broken + # Unicode representation. + encoding = sys.getfilesystemencoding() or sys.getdefaultencoding() + path = path.decode(encoding, 'replace') + + # Add the magic prefix if it isn't already there. + # http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx + if prefix and not path.startswith(WINDOWS_MAGIC_PREFIX): + if path.startswith(u'\\\\'): + # UNC path. Final path should look like \\?\UNC\... + path = u'UNC' + path[1:] + path = WINDOWS_MAGIC_PREFIX + path + + return path + + +def samefile(p1, p2): + """Safer equality for paths.""" + return shutil._samefile(syspath(p1), syspath(p2)) + + +def remove(path, soft=True): + """Remove the file. If `soft`, then no error will be raised if the + file does not exist. + """ + path = syspath(path) + if soft and not os.path.exists(path): + return + try: + os.remove(path) + except (OSError, IOError) as exc: + raise FilesystemError(exc, 'delete', (path,), traceback.format_exc()) + + +def copy(path, dest, replace=False): + """Copy a plain file. Permissions are not copied. If `dest` already + exists, raises a FilesystemError unless `replace` is True. Has no + effect if `path` is the same as `dest`. Paths are translated to + system paths before the syscall. + """ + if samefile(path, dest): + return + path = syspath(path) + dest = syspath(dest) + if not replace and os.path.exists(dest): + raise FilesystemError('file exists', 'copy', (path, dest)) + try: + shutil.copyfile(path, dest) + except (OSError, IOError) as exc: + raise FilesystemError(exc, 'copy', (path, dest), + traceback.format_exc()) + + +def move(path, dest, replace=False): + """Rename a file. `dest` may not be a directory. If `dest` already + exists, raises an OSError unless `replace` is True. Has no effect if + `path` is the same as `dest`. If the paths are on different + filesystems (or the rename otherwise fails), a copy is attempted + instead, in which case metadata will *not* be preserved. Paths are + translated to system paths. + """ + if samefile(path, dest): + return + path = syspath(path) + dest = syspath(dest) + if os.path.exists(dest) and not replace: + raise FilesystemError('file exists', 'rename', (path, dest), + traceback.format_exc()) + + # First, try renaming the file. + try: + os.rename(path, dest) + except OSError: + # Otherwise, copy and delete the original. + try: + shutil.copyfile(path, dest) + os.remove(path) + except (OSError, IOError) as exc: + raise FilesystemError(exc, 'move', (path, dest), + traceback.format_exc()) + + +def link(path, dest, replace=False): + """Create a symbolic link from path to `dest`. Raises an OSError if + `dest` already exists, unless `replace` is True. Does nothing if + `path` == `dest`.""" + if (samefile(path, dest)): + return + + path = syspath(path) + dest = syspath(dest) + if os.path.exists(dest) and not replace: + raise FilesystemError('file exists', 'rename', (path, dest), + traceback.format_exc()) + try: + os.symlink(path, dest) + except OSError: + raise FilesystemError('Operating system does not support symbolic ' + 'links.', 'link', (path, dest), + traceback.format_exc()) + + +def unique_path(path): + """Returns a version of ``path`` that does not exist on the + filesystem. Specifically, if ``path` itself already exists, then + something unique is appended to the path. + """ + if not os.path.exists(syspath(path)): + return path + + base, ext = os.path.splitext(path) + match = re.search(r'\.(\d)+$', base) + if match: + num = int(match.group(1)) + base = base[:match.start()] + else: + num = 0 + while True: + num += 1 + new_path = '%s.%i%s' % (base, num, ext) + if not os.path.exists(new_path): + return new_path + +# Note: The Windows "reserved characters" are, of course, allowed on +# Unix. They are forbidden here because they cause problems on Samba +# shares, which are sufficiently common as to cause frequent problems. +# http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx +CHAR_REPLACE = [ + (re.compile(ur'[\\/]'), u'_'), # / and \ -- forbidden everywhere. + (re.compile(ur'^\.'), u'_'), # Leading dot (hidden files on Unix). + (re.compile(ur'[\x00-\x1f]'), u''), # Control characters. + (re.compile(ur'[<>:"\?\*\|]'), u'_'), # Windows "reserved characters". + (re.compile(ur'\.$'), u'_'), # Trailing dots. + (re.compile(ur'\s+$'), u''), # Trailing whitespace. +] + + +def sanitize_path(path, replacements=None): + """Takes a path (as a Unicode string) and makes sure that it is + legal. Returns a new path. Only works with fragments; won't work + reliably on Windows when a path begins with a drive letter. Path + separators (including altsep!) should already be cleaned from the + path components. If replacements is specified, it is used *instead* + of the default set of replacements; it must be a list of (compiled + regex, replacement string) pairs. + """ + replacements = replacements or CHAR_REPLACE + + comps = components(path) + if not comps: + return '' + for i, comp in enumerate(comps): + for regex, repl in replacements: + comp = regex.sub(repl, comp) + comps[i] = comp + return os.path.join(*comps) + + +def truncate_path(path, length=MAX_FILENAME_LENGTH): + """Given a bytestring path or a Unicode path fragment, truncate the + components to a legal length. In the last component, the extension + is preserved. + """ + comps = components(path) + + out = [c[:length] for c in comps] + base, ext = os.path.splitext(comps[-1]) + if ext: + # Last component has an extension. + base = base[:length - len(ext)] + out[-1] = base + ext + + return os.path.join(*out) + + +def str2bool(value): + """Returns a boolean reflecting a human-entered string.""" + if value.lower() in ('yes', '1', 'true', 't', 'y'): + return True + else: + return False + + +def as_string(value): + """Convert a value to a Unicode object for matching with a query. + None becomes the empty string. Bytestrings are silently decoded. + """ + if value is None: + return u'' + elif isinstance(value, buffer): + return str(value).decode('utf8', 'ignore') + elif isinstance(value, str): + return value.decode('utf8', 'ignore') + else: + return unicode(value) + + +def levenshtein(s1, s2): + """A nice DP edit distance implementation from Wikibooks: + http://en.wikibooks.org/wiki/Algorithm_implementation/Strings/ + Levenshtein_distance#Python + """ + if len(s1) < len(s2): + return levenshtein(s2, s1) + if not s1: + return len(s2) + + previous_row = xrange(len(s2) + 1) + for i, c1 in enumerate(s1): + current_row = [i + 1] + for j, c2 in enumerate(s2): + insertions = previous_row[j + 1] + 1 + deletions = current_row[j] + 1 + substitutions = previous_row[j] + (c1 != c2) + current_row.append(min(insertions, deletions, substitutions)) + previous_row = current_row + + return previous_row[-1] + + +def plurality(objs): + """Given a sequence of comparable objects, returns the object that + is most common in the set and the frequency of that object. The + sequence must contain at least one object. + """ + # Calculate frequencies. + freqs = defaultdict(int) + for obj in objs: + freqs[obj] += 1 + + if not freqs: + raise ValueError('sequence must be non-empty') + + # Find object with maximum frequency. + max_freq = 0 + res = None + for obj, freq in freqs.items(): + if freq > max_freq: + max_freq = freq + res = obj + + return res, max_freq + + +def cpu_count(): + """Return the number of hardware thread contexts (cores or SMT + threads) in the system. + """ + # Adapted from the soundconverter project: + # https://github.com/kassoulet/soundconverter + if sys.platform == 'win32': + try: + num = int(os.environ['NUMBER_OF_PROCESSORS']) + except (ValueError, KeyError): + num = 0 + elif sys.platform == 'darwin': + try: + num = int(command_output(['sysctl', '-n', 'hw.ncpu'])) + except ValueError: + num = 0 + else: + try: + num = os.sysconf('SC_NPROCESSORS_ONLN') + except (ValueError, OSError, AttributeError): + num = 0 + if num >= 1: + return num + else: + return 1 + + +def command_output(cmd, shell=False): + """Runs the command and returns its output after it has exited. + + ``cmd`` is a list of arguments starting with the command names. If + ``shell`` is true, ``cmd`` is assumed to be a string and passed to a + shell to execute. + + If the process exits with a non-zero return code + ``subprocess.CalledProcessError`` is raised. May also raise + ``OSError``. + + This replaces `subprocess.check_output`, which isn't available in + Python 2.6 and which can have problems if lots of output is sent to + stderr. + """ + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=platform.system() != 'Windows', + shell=shell + ) + stdout, stderr = proc.communicate() + if proc.returncode: + raise subprocess.CalledProcessError( + returncode=proc.returncode, + cmd=' '.join(cmd), + ) + return stdout + + +def max_filename_length(path, limit=MAX_FILENAME_LENGTH): + """Attempt to determine the maximum filename length for the + filesystem containing `path`. If the value is greater than `limit`, + then `limit` is used instead (to prevent errors when a filesystem + misreports its capacity). If it cannot be determined (e.g., on + Windows), return `limit`. + """ + if hasattr(os, 'statvfs'): + try: + res = os.statvfs(path) + except OSError: + return limit + return min(res[9], limit) + else: + return limit diff --git a/lib/beets/util/artresizer.py b/lib/beets/util/artresizer.py new file mode 100644 index 00000000..f17fdc5b --- /dev/null +++ b/lib/beets/util/artresizer.py @@ -0,0 +1,211 @@ +# This file is part of beets. +# Copyright 2014, Fabrice Laporte +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Abstraction layer to resize images using PIL, ImageMagick, or a +public resizing proxy if neither is available. +""" +import urllib +import subprocess +import os +import re +from tempfile import NamedTemporaryFile +import logging +from beets import util + +# Resizing methods +PIL = 1 +IMAGEMAGICK = 2 +WEBPROXY = 3 + +PROXY_URL = 'http://images.weserv.nl/' + +log = logging.getLogger('beets') + + +def resize_url(url, maxwidth): + """Return a proxied image URL that resizes the original image to + maxwidth (preserving aspect ratio). + """ + return '{0}?{1}'.format(PROXY_URL, urllib.urlencode({ + 'url': url.replace('http://', ''), + 'w': str(maxwidth), + })) + + +def temp_file_for(path): + """Return an unused filename with the same extension as the + specified path. + """ + ext = os.path.splitext(path)[1] + with NamedTemporaryFile(suffix=ext, delete=False) as f: + return f.name + + +def pil_resize(maxwidth, path_in, path_out=None): + """Resize using Python Imaging Library (PIL). Return the output path + of resized image. + """ + path_out = path_out or temp_file_for(path_in) + from PIL import Image + log.debug(u'artresizer: PIL resizing {0} to {1}'.format( + util.displayable_path(path_in), util.displayable_path(path_out) + )) + + try: + im = Image.open(util.syspath(path_in)) + size = maxwidth, maxwidth + im.thumbnail(size, Image.ANTIALIAS) + im.save(path_out) + return path_out + except IOError: + log.error(u"PIL cannot create thumbnail for '{0}'".format( + util.displayable_path(path_in) + )) + return path_in + + +def im_resize(maxwidth, path_in, path_out=None): + """Resize using ImageMagick's ``convert`` tool. + Return the output path of resized image. + """ + path_out = path_out or temp_file_for(path_in) + log.debug(u'artresizer: ImageMagick resizing {0} to {1}'.format( + util.displayable_path(path_in), util.displayable_path(path_out) + )) + + # "-resize widthxheight>" shrinks images with dimension(s) larger + # than the corresponding width and/or height dimension(s). The > + # "only shrink" flag is prefixed by ^ escape char for Windows + # compatibility. + try: + util.command_output([ + 'convert', util.syspath(path_in), + '-resize', '{0}x^>'.format(maxwidth), path_out + ]) + except subprocess.CalledProcessError: + log.warn(u'artresizer: IM convert failed for {0}'.format( + util.displayable_path(path_in) + )) + return path_in + return path_out + + +BACKEND_FUNCS = { + PIL: pil_resize, + IMAGEMAGICK: im_resize, +} + + +class Shareable(type): + """A pseudo-singleton metaclass that allows both shared and + non-shared instances. The ``MyClass.shared`` property holds a + lazily-created shared instance of ``MyClass`` while calling + ``MyClass()`` to construct a new object works as usual. + """ + def __init__(cls, name, bases, dict): + super(Shareable, cls).__init__(name, bases, dict) + cls._instance = None + + @property + def shared(cls): + if cls._instance is None: + cls._instance = cls() + return cls._instance + + +class ArtResizer(object): + """A singleton class that performs image resizes. + """ + __metaclass__ = Shareable + + def __init__(self, method=None): + """Create a resizer object for the given method or, if none is + specified, with an inferred method. + """ + self.method = self._check_method(method) + log.debug(u"artresizer: method is {0}".format(self.method)) + self.can_compare = self._can_compare() + + def resize(self, maxwidth, path_in, path_out=None): + """Manipulate an image file according to the method, returning a + new path. For PIL or IMAGEMAGIC methods, resizes the image to a + temporary file. For WEBPROXY, returns `path_in` unmodified. + """ + if self.local: + func = BACKEND_FUNCS[self.method[0]] + return func(maxwidth, path_in, path_out) + else: + return path_in + + def proxy_url(self, maxwidth, url): + """Modifies an image URL according the method, returning a new + URL. For WEBPROXY, a URL on the proxy server is returned. + Otherwise, the URL is returned unmodified. + """ + if self.local: + return url + else: + return resize_url(url, maxwidth) + + @property + def local(self): + """A boolean indicating whether the resizing method is performed + locally (i.e., PIL or ImageMagick). + """ + return self.method[0] in BACKEND_FUNCS + + def _can_compare(self): + """A boolean indicating whether image comparison is available""" + + return self.method[0] == IMAGEMAGICK and self.method[1] > (6, 8, 7) + + @staticmethod + def _check_method(method=None): + """A tuple indicating whether current method is available and its + version. If no method is given, it returns a supported one. + """ + # Guess available method + if not method: + for m in [IMAGEMAGICK, PIL]: + _, version = ArtResizer._check_method(m) + if version: + return (m, version) + return (WEBPROXY, (0)) + + if method == IMAGEMAGICK: + + # Try invoking ImageMagick's "convert". + try: + out = util.command_output(['identify', '--version']) + + if 'imagemagick' in out.lower(): + pattern = r".+ (\d+)\.(\d+)\.(\d+).*" + match = re.search(pattern, out) + if match: + return (IMAGEMAGICK, + (int(match.group(1)), + int(match.group(2)), + int(match.group(3)))) + return (IMAGEMAGICK, (0)) + + except (subprocess.CalledProcessError, OSError): + return (IMAGEMAGICK, None) + + if method == PIL: + # Try importing PIL. + try: + __import__('PIL', fromlist=['Image']) + return (PIL, (0)) + except ImportError: + return (PIL, None) diff --git a/lib/beets/util/bluelet.py b/lib/beets/util/bluelet.py new file mode 100644 index 00000000..a12ec945 --- /dev/null +++ b/lib/beets/util/bluelet.py @@ -0,0 +1,647 @@ +"""Extremely simple pure-Python implementation of coroutine-style +asynchronous socket I/O. Inspired by, but inferior to, Eventlet. +Bluelet can also be thought of as a less-terrible replacement for +asyncore. + +Bluelet: easy concurrency without all the messy parallelism. +""" +import socket +import select +import sys +import types +import errno +import traceback +import time +import collections + + +# A little bit of "six" (Python 2/3 compatibility): cope with PEP 3109 syntax +# changes. + +PY3 = sys.version_info[0] == 3 +if PY3: + def _reraise(typ, exc, tb): + raise exc.with_traceback(tb) +else: + exec(""" +def _reraise(typ, exc, tb): + raise typ, exc, tb +""") + + +# Basic events used for thread scheduling. + +class Event(object): + """Just a base class identifying Bluelet events. An event is an + object yielded from a Bluelet thread coroutine to suspend operation + and communicate with the scheduler. + """ + pass + + +class WaitableEvent(Event): + """A waitable event is one encapsulating an action that can be + waited for using a select() call. That is, it's an event with an + associated file descriptor. + """ + def waitables(self): + """Return "waitable" objects to pass to select(). Should return + three iterables for input readiness, output readiness, and + exceptional conditions (i.e., the three lists passed to + select()). + """ + return (), (), () + + def fire(self): + """Called when an associated file descriptor becomes ready + (i.e., is returned from a select() call). + """ + pass + + +class ValueEvent(Event): + """An event that does nothing but return a fixed value.""" + def __init__(self, value): + self.value = value + + +class ExceptionEvent(Event): + """Raise an exception at the yield point. Used internally.""" + def __init__(self, exc_info): + self.exc_info = exc_info + + +class SpawnEvent(Event): + """Add a new coroutine thread to the scheduler.""" + def __init__(self, coro): + self.spawned = coro + + +class JoinEvent(Event): + """Suspend the thread until the specified child thread has + completed. + """ + def __init__(self, child): + self.child = child + + +class KillEvent(Event): + """Unschedule a child thread.""" + def __init__(self, child): + self.child = child + + +class DelegationEvent(Event): + """Suspend execution of the current thread, start a new thread and, + once the child thread finished, return control to the parent + thread. + """ + def __init__(self, coro): + self.spawned = coro + + +class ReturnEvent(Event): + """Return a value the current thread's delegator at the point of + delegation. Ends the current (delegate) thread. + """ + def __init__(self, value): + self.value = value + + +class SleepEvent(WaitableEvent): + """Suspend the thread for a given duration. + """ + def __init__(self, duration): + self.wakeup_time = time.time() + duration + + def time_left(self): + return max(self.wakeup_time - time.time(), 0.0) + + +class ReadEvent(WaitableEvent): + """Reads from a file-like object.""" + def __init__(self, fd, bufsize): + self.fd = fd + self.bufsize = bufsize + + def waitables(self): + return (self.fd,), (), () + + def fire(self): + return self.fd.read(self.bufsize) + + +class WriteEvent(WaitableEvent): + """Writes to a file-like object.""" + def __init__(self, fd, data): + self.fd = fd + self.data = data + + def waitable(self): + return (), (self.fd,), () + + def fire(self): + self.fd.write(self.data) + + +# Core logic for executing and scheduling threads. + +def _event_select(events): + """Perform a select() over all the Events provided, returning the + ones ready to be fired. Only WaitableEvents (including SleepEvents) + matter here; all other events are ignored (and thus postponed). + """ + # Gather waitables and wakeup times. + waitable_to_event = {} + rlist, wlist, xlist = [], [], [] + earliest_wakeup = None + for event in events: + if isinstance(event, SleepEvent): + if not earliest_wakeup: + earliest_wakeup = event.wakeup_time + else: + earliest_wakeup = min(earliest_wakeup, event.wakeup_time) + elif isinstance(event, WaitableEvent): + r, w, x = event.waitables() + rlist += r + wlist += w + xlist += x + for waitable in r: + waitable_to_event[('r', waitable)] = event + for waitable in w: + waitable_to_event[('w', waitable)] = event + for waitable in x: + waitable_to_event[('x', waitable)] = event + + # If we have a any sleeping threads, determine how long to sleep. + if earliest_wakeup: + timeout = max(earliest_wakeup - time.time(), 0.0) + else: + timeout = None + + # Perform select() if we have any waitables. + if rlist or wlist or xlist: + rready, wready, xready = select.select(rlist, wlist, xlist, timeout) + else: + rready, wready, xready = (), (), () + if timeout: + time.sleep(timeout) + + # Gather ready events corresponding to the ready waitables. + ready_events = set() + for ready in rready: + ready_events.add(waitable_to_event[('r', ready)]) + for ready in wready: + ready_events.add(waitable_to_event[('w', ready)]) + for ready in xready: + ready_events.add(waitable_to_event[('x', ready)]) + + # Gather any finished sleeps. + for event in events: + if isinstance(event, SleepEvent) and event.time_left() == 0.0: + ready_events.add(event) + + return ready_events + + +class ThreadException(Exception): + def __init__(self, coro, exc_info): + self.coro = coro + self.exc_info = exc_info + + def reraise(self): + _reraise(self.exc_info[0], self.exc_info[1], self.exc_info[2]) + + +SUSPENDED = Event() # Special sentinel placeholder for suspended threads. + + +class Delegated(Event): + """Placeholder indicating that a thread has delegated execution to a + different thread. + """ + def __init__(self, child): + self.child = child + + +def run(root_coro): + """Schedules a coroutine, running it to completion. This + encapsulates the Bluelet scheduler, which the root coroutine can + add to by spawning new coroutines. + """ + # The "threads" dictionary keeps track of all the currently- + # executing and suspended coroutines. It maps coroutines to their + # currently "blocking" event. The event value may be SUSPENDED if + # the coroutine is waiting on some other condition: namely, a + # delegated coroutine or a joined coroutine. In this case, the + # coroutine should *also* appear as a value in one of the below + # dictionaries `delegators` or `joiners`. + threads = {root_coro: ValueEvent(None)} + + # Maps child coroutines to delegating parents. + delegators = {} + + # Maps child coroutines to joining (exit-waiting) parents. + joiners = collections.defaultdict(list) + + def complete_thread(coro, return_value): + """Remove a coroutine from the scheduling pool, awaking + delegators and joiners as necessary and returning the specified + value to any delegating parent. + """ + del threads[coro] + + # Resume delegator. + if coro in delegators: + threads[delegators[coro]] = ValueEvent(return_value) + del delegators[coro] + + # Resume joiners. + if coro in joiners: + for parent in joiners[coro]: + threads[parent] = ValueEvent(None) + del joiners[coro] + + def advance_thread(coro, value, is_exc=False): + """After an event is fired, run a given coroutine associated with + it in the threads dict until it yields again. If the coroutine + exits, then the thread is removed from the pool. If the coroutine + raises an exception, it is reraised in a ThreadException. If + is_exc is True, then the value must be an exc_info tuple and the + exception is thrown into the coroutine. + """ + try: + if is_exc: + next_event = coro.throw(*value) + else: + next_event = coro.send(value) + except StopIteration: + # Thread is done. + complete_thread(coro, None) + except: + # Thread raised some other exception. + del threads[coro] + raise ThreadException(coro, sys.exc_info()) + else: + if isinstance(next_event, types.GeneratorType): + # Automatically invoke sub-coroutines. (Shorthand for + # explicit bluelet.call().) + next_event = DelegationEvent(next_event) + threads[coro] = next_event + + def kill_thread(coro): + """Unschedule this thread and its (recursive) delegates. + """ + # Collect all coroutines in the delegation stack. + coros = [coro] + while isinstance(threads[coro], Delegated): + coro = threads[coro].child + coros.append(coro) + + # Complete each coroutine from the top to the bottom of the + # stack. + for coro in reversed(coros): + complete_thread(coro, None) + + # Continue advancing threads until root thread exits. + exit_te = None + while threads: + try: + # Look for events that can be run immediately. Continue + # running immediate events until nothing is ready. + while True: + have_ready = False + for coro, event in list(threads.items()): + if isinstance(event, SpawnEvent): + threads[event.spawned] = ValueEvent(None) # Spawn. + advance_thread(coro, None) + have_ready = True + elif isinstance(event, ValueEvent): + advance_thread(coro, event.value) + have_ready = True + elif isinstance(event, ExceptionEvent): + advance_thread(coro, event.exc_info, True) + have_ready = True + elif isinstance(event, DelegationEvent): + threads[coro] = Delegated(event.spawned) # Suspend. + threads[event.spawned] = ValueEvent(None) # Spawn. + delegators[event.spawned] = coro + have_ready = True + elif isinstance(event, ReturnEvent): + # Thread is done. + complete_thread(coro, event.value) + have_ready = True + elif isinstance(event, JoinEvent): + threads[coro] = SUSPENDED # Suspend. + joiners[event.child].append(coro) + have_ready = True + elif isinstance(event, KillEvent): + threads[coro] = ValueEvent(None) + kill_thread(event.child) + have_ready = True + + # Only start the select when nothing else is ready. + if not have_ready: + break + + # Wait and fire. + event2coro = dict((v, k) for k, v in threads.items()) + for event in _event_select(threads.values()): + # Run the IO operation, but catch socket errors. + try: + value = event.fire() + except socket.error as exc: + if isinstance(exc.args, tuple) and \ + exc.args[0] == errno.EPIPE: + # Broken pipe. Remote host disconnected. + pass + else: + traceback.print_exc() + # Abort the coroutine. + threads[event2coro[event]] = ReturnEvent(None) + else: + advance_thread(event2coro[event], value) + + except ThreadException as te: + # Exception raised from inside a thread. + event = ExceptionEvent(te.exc_info) + if te.coro in delegators: + # The thread is a delegate. Raise exception in its + # delegator. + threads[delegators[te.coro]] = event + del delegators[te.coro] + else: + # The thread is root-level. Raise in client code. + exit_te = te + break + + except: + # For instance, KeyboardInterrupt during select(). Raise + # into root thread and terminate others. + threads = {root_coro: ExceptionEvent(sys.exc_info())} + + # If any threads still remain, kill them. + for coro in threads: + coro.close() + + # If we're exiting with an exception, raise it in the client. + if exit_te: + exit_te.reraise() + + +# Sockets and their associated events. + +class SocketClosedError(Exception): + pass + + +class Listener(object): + """A socket wrapper object for listening sockets. + """ + def __init__(self, host, port): + """Create a listening socket on the given hostname and port. + """ + self._closed = False + self.host = host + self.port = port + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.sock.bind((host, port)) + self.sock.listen(5) + + def accept(self): + """An event that waits for a connection on the listening socket. + When a connection is made, the event returns a Connection + object. + """ + if self._closed: + raise SocketClosedError() + return AcceptEvent(self) + + def close(self): + """Immediately close the listening socket. (Not an event.) + """ + self._closed = True + self.sock.close() + + +class Connection(object): + """A socket wrapper object for connected sockets. + """ + def __init__(self, sock, addr): + self.sock = sock + self.addr = addr + self._buf = b'' + self._closed = False + + def close(self): + """Close the connection.""" + self._closed = True + self.sock.close() + + def recv(self, size): + """Read at most size bytes of data from the socket.""" + if self._closed: + raise SocketClosedError() + + if self._buf: + # We already have data read previously. + out = self._buf[:size] + self._buf = self._buf[size:] + return ValueEvent(out) + else: + return ReceiveEvent(self, size) + + def send(self, data): + """Sends data on the socket, returning the number of bytes + successfully sent. + """ + if self._closed: + raise SocketClosedError() + return SendEvent(self, data) + + def sendall(self, data): + """Send all of data on the socket.""" + if self._closed: + raise SocketClosedError() + return SendEvent(self, data, True) + + def readline(self, terminator=b"\n", bufsize=1024): + """Reads a line (delimited by terminator) from the socket.""" + if self._closed: + raise SocketClosedError() + + while True: + if terminator in self._buf: + line, self._buf = self._buf.split(terminator, 1) + line += terminator + yield ReturnEvent(line) + break + data = yield ReceiveEvent(self, bufsize) + if data: + self._buf += data + else: + line = self._buf + self._buf = b'' + yield ReturnEvent(line) + break + + +class AcceptEvent(WaitableEvent): + """An event for Listener objects (listening sockets) that suspends + execution until the socket gets a connection. + """ + def __init__(self, listener): + self.listener = listener + + def waitables(self): + return (self.listener.sock,), (), () + + def fire(self): + sock, addr = self.listener.sock.accept() + return Connection(sock, addr) + + +class ReceiveEvent(WaitableEvent): + """An event for Connection objects (connected sockets) for + asynchronously reading data. + """ + def __init__(self, conn, bufsize): + self.conn = conn + self.bufsize = bufsize + + def waitables(self): + return (self.conn.sock,), (), () + + def fire(self): + return self.conn.sock.recv(self.bufsize) + + +class SendEvent(WaitableEvent): + """An event for Connection objects (connected sockets) for + asynchronously writing data. + """ + def __init__(self, conn, data, sendall=False): + self.conn = conn + self.data = data + self.sendall = sendall + + def waitables(self): + return (), (self.conn.sock,), () + + def fire(self): + if self.sendall: + return self.conn.sock.sendall(self.data) + else: + return self.conn.sock.send(self.data) + + +# Public interface for threads; each returns an event object that +# can immediately be "yield"ed. + +def null(): + """Event: yield to the scheduler without doing anything special. + """ + return ValueEvent(None) + + +def spawn(coro): + """Event: add another coroutine to the scheduler. Both the parent + and child coroutines run concurrently. + """ + if not isinstance(coro, types.GeneratorType): + raise ValueError('%s is not a coroutine' % str(coro)) + return SpawnEvent(coro) + + +def call(coro): + """Event: delegate to another coroutine. The current coroutine + is resumed once the sub-coroutine finishes. If the sub-coroutine + returns a value using end(), then this event returns that value. + """ + if not isinstance(coro, types.GeneratorType): + raise ValueError('%s is not a coroutine' % str(coro)) + return DelegationEvent(coro) + + +def end(value=None): + """Event: ends the coroutine and returns a value to its + delegator. + """ + return ReturnEvent(value) + + +def read(fd, bufsize=None): + """Event: read from a file descriptor asynchronously.""" + if bufsize is None: + # Read all. + def reader(): + buf = [] + while True: + data = yield read(fd, 1024) + if not data: + break + buf.append(data) + yield ReturnEvent(''.join(buf)) + return DelegationEvent(reader()) + + else: + return ReadEvent(fd, bufsize) + + +def write(fd, data): + """Event: write to a file descriptor asynchronously.""" + return WriteEvent(fd, data) + + +def connect(host, port): + """Event: connect to a network address and return a Connection + object for communicating on the socket. + """ + addr = (host, port) + sock = socket.create_connection(addr) + return ValueEvent(Connection(sock, addr)) + + +def sleep(duration): + """Event: suspend the thread for ``duration`` seconds. + """ + return SleepEvent(duration) + + +def join(coro): + """Suspend the thread until another, previously `spawn`ed thread + completes. + """ + return JoinEvent(coro) + + +def kill(coro): + """Halt the execution of a different `spawn`ed thread. + """ + return KillEvent(coro) + + +# Convenience function for running socket servers. + +def server(host, port, func): + """A coroutine that runs a network server. Host and port specify the + listening address. func should be a coroutine that takes a single + parameter, a Connection object. The coroutine is invoked for every + incoming connection on the listening socket. + """ + def handler(conn): + try: + yield func(conn) + finally: + conn.close() + + listener = Listener(host, port) + try: + while True: + conn = yield listener.accept() + yield spawn(handler(conn)) + except KeyboardInterrupt: + pass + finally: + listener.close() diff --git a/lib/beets/util/confit.py b/lib/beets/util/confit.py new file mode 100644 index 00000000..de22e0ad --- /dev/null +++ b/lib/beets/util/confit.py @@ -0,0 +1,1281 @@ +# This file is part of Confit. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Worry-free YAML configuration files. +""" +from __future__ import unicode_literals +import platform +import os +import pkgutil +import sys +import yaml +import types +import collections +import re +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict + +UNIX_DIR_VAR = 'XDG_CONFIG_HOME' +UNIX_DIR_FALLBACK = '~/.config' +WINDOWS_DIR_VAR = 'APPDATA' +WINDOWS_DIR_FALLBACK = '~\\AppData\\Roaming' +MAC_DIR = '~/Library/Application Support' + +CONFIG_FILENAME = 'config.yaml' +DEFAULT_FILENAME = 'config_default.yaml' +ROOT_NAME = 'root' + +YAML_TAB_PROBLEM = "found character '\\t' that cannot start any token" + + +# Utilities. + +PY3 = sys.version_info[0] == 3 +STRING = str if PY3 else unicode +BASESTRING = str if PY3 else basestring +NUMERIC_TYPES = (int, float) if PY3 else (int, float, long) +TYPE_TYPES = (type,) if PY3 else (type, types.ClassType) + + +def iter_first(sequence): + """Get the first element from an iterable or raise a ValueError if + the iterator generates no values. + """ + it = iter(sequence) + try: + if PY3: + return next(it) + else: + return it.next() + except StopIteration: + raise ValueError() + + +# Exceptions. + +class ConfigError(Exception): + """Base class for exceptions raised when querying a configuration. + """ + + +class NotFoundError(ConfigError): + """A requested value could not be found in the configuration trees. + """ + + +class ConfigValueError(ConfigError): + """The value in the configuration is illegal.""" + + +class ConfigTypeError(ConfigValueError): + """The value in the configuration did not match the expected type. + """ + + +class ConfigTemplateError(ConfigError): + """Base class for exceptions raised because of an invalid template. + """ + + +class ConfigReadError(ConfigError): + """A configuration file could not be read.""" + def __init__(self, filename, reason=None): + self.filename = filename + self.reason = reason + + message = 'file {0} could not be read'.format(filename) + if isinstance(reason, yaml.scanner.ScannerError) and \ + reason.problem == YAML_TAB_PROBLEM: + # Special-case error message for tab indentation in YAML markup. + message += ': found tab character at line {0}, column {1}'.format( + reason.problem_mark.line + 1, + reason.problem_mark.column + 1, + ) + elif reason: + # Generic error message uses exception's message. + message += ': {0}'.format(reason) + + super(ConfigReadError, self).__init__(message) + + +# Views and sources. + +class ConfigSource(dict): + """A dictionary augmented with metadata about the source of the + configuration. + """ + def __init__(self, value, filename=None, default=False): + super(ConfigSource, self).__init__(value) + if filename is not None and not isinstance(filename, BASESTRING): + raise TypeError('filename must be a string or None') + self.filename = filename + self.default = default + + def __repr__(self): + return 'ConfigSource({0}, {1}, {2})'.format( + super(ConfigSource, self).__repr__(), + repr(self.filename), + repr(self.default) + ) + + @classmethod + def of(self, value): + """Given either a dictionary or a `ConfigSource` object, return + a `ConfigSource` object. This lets a function accept either type + of object as an argument. + """ + if isinstance(value, ConfigSource): + return value + elif isinstance(value, dict): + return ConfigSource(value) + else: + raise TypeError('source value must be a dict') + + +class ConfigView(object): + """A configuration "view" is a query into a program's configuration + data. A view represents a hypothetical location in the configuration + tree; to extract the data from the location, a client typically + calls the ``view.get()`` method. The client can access children in + the tree (subviews) by subscripting the parent view (i.e., + ``view[key]``). + """ + + name = None + """The name of the view, depicting the path taken through the + configuration in Python-like syntax (e.g., ``foo['bar'][42]``). + """ + + def resolve(self): + """The core (internal) data retrieval method. Generates (value, + source) pairs for each source that contains a value for this + view. May raise ConfigTypeError if a type error occurs while + traversing a source. + """ + raise NotImplementedError + + def first(self): + """Return a (value, source) pair for the first object found for + this view. This amounts to the first element returned by + `resolve`. If no values are available, a NotFoundError is + raised. + """ + pairs = self.resolve() + try: + return iter_first(pairs) + except ValueError: + raise NotFoundError("{0} not found".format(self.name)) + + def exists(self): + """Determine whether the view has a setting in any source. + """ + try: + self.first() + except NotFoundError: + return False + return True + + def add(self, value): + """Set the *default* value for this configuration view. The + specified value is added as the lowest-priority configuration + data source. + """ + raise NotImplementedError + + def set(self, value): + """*Override* the value for this configuration view. The + specified value is added as the highest-priority configuration + data source. + """ + raise NotImplementedError + + def root(self): + """The RootView object from which this view is descended. + """ + raise NotImplementedError + + def __repr__(self): + return '' % self.name + + def __getitem__(self, key): + """Get a subview of this view.""" + return Subview(self, key) + + def __setitem__(self, key, value): + """Create an overlay source to assign a given key under this + view. + """ + self.set({key: value}) + + def __contains__(self, key): + return self[key].exists() + + def set_args(self, namespace): + """Overlay parsed command-line arguments, generated by a library + like argparse or optparse, onto this view's value. + """ + args = {} + for key, value in namespace.__dict__.items(): + if value is not None: # Avoid unset options. + args[key] = value + self.set(args) + + # Magical conversions. These special methods make it possible to use + # View objects somewhat transparently in certain circumstances. For + # example, rather than using ``view.get(bool)``, it's possible to + # just say ``bool(view)`` or use ``view`` in a conditional. + + def __str__(self): + """Gets the value for this view as a byte string.""" + return str(self.get()) + + def __unicode__(self): + """Gets the value for this view as a unicode string. (Python 2 + only.) + """ + return unicode(self.get()) + + def __nonzero__(self): + """Gets the value for this view as a boolean. (Python 2 only.) + """ + return self.__bool__() + + def __bool__(self): + """Gets the value for this view as a boolean. (Python 3 only.) + """ + return bool(self.get()) + + # Dictionary emulation methods. + + def keys(self): + """Returns a list containing all the keys available as subviews + of the current views. This enumerates all the keys in *all* + dictionaries matching the current view, in contrast to + ``view.get(dict).keys()``, which gets all the keys for the + *first* dict matching the view. If the object for this view in + any source is not a dict, then a ConfigTypeError is raised. The + keys are ordered according to how they appear in each source. + """ + keys = [] + + for dic, _ in self.resolve(): + try: + cur_keys = dic.keys() + except AttributeError: + raise ConfigTypeError( + '{0} must be a dict, not {1}'.format( + self.name, type(dic).__name__ + ) + ) + + for key in cur_keys: + if key not in keys: + keys.append(key) + + return keys + + def items(self): + """Iterates over (key, subview) pairs contained in dictionaries + from *all* sources at this view. If the object for this view in + any source is not a dict, then a ConfigTypeError is raised. + """ + for key in self.keys(): + yield key, self[key] + + def values(self): + """Iterates over all the subviews contained in dictionaries from + *all* sources at this view. If the object for this view in any + source is not a dict, then a ConfigTypeError is raised. + """ + for key in self.keys(): + yield self[key] + + # List/sequence emulation. + + def all_contents(self): + """Iterates over all subviews from collections at this view from + *all* sources. If the object for this view in any source is not + iterable, then a ConfigTypeError is raised. This method is + intended to be used when the view indicates a list; this method + will concatenate the contents of the list from all sources. + """ + for collection, _ in self.resolve(): + try: + it = iter(collection) + except TypeError: + raise ConfigTypeError( + '{0} must be an iterable, not {1}'.format( + self.name, type(collection).__name__ + ) + ) + for value in it: + yield value + + # Validation and conversion. + + def flatten(self): + """Create a hierarchy of OrderedDicts containing the data from + this view, recursively reifying all views to get their + represented values. + """ + od = OrderedDict() + for key, view in self.items(): + try: + od[key] = view.flatten() + except ConfigTypeError: + od[key] = view.get() + return od + + def get(self, template=None): + """Retrieve the value for this view according to the template. + + The `template` against which the values are checked can be + anything convertible to a `Template` using `as_template`. This + means you can pass in a default integer or string value, for + example, or a type to just check that something matches the type + you expect. + + May raise a `ConfigValueError` (or its subclass, + `ConfigTypeError`) or a `NotFoundError` when the configuration + doesn't satisfy the template. + """ + return as_template(template).value(self, template) + + # Old validation methods (deprecated). + + def as_filename(self): + return self.get(Filename()) + + def as_choice(self, choices): + return self.get(Choice(choices)) + + def as_number(self): + return self.get(Number()) + + def as_str_seq(self): + return self.get(StrSeq()) + + +class RootView(ConfigView): + """The base of a view hierarchy. This view keeps track of the + sources that may be accessed by subviews. + """ + def __init__(self, sources): + """Create a configuration hierarchy for a list of sources. At + least one source must be provided. The first source in the list + has the highest priority. + """ + self.sources = list(sources) + self.name = ROOT_NAME + + def add(self, obj): + self.sources.append(ConfigSource.of(obj)) + + def set(self, value): + self.sources.insert(0, ConfigSource.of(value)) + + def resolve(self): + return ((dict(s), s) for s in self.sources) + + def clear(self): + """Remove all sources from this configuration.""" + del self.sources[:] + + def root(self): + return self + + +class Subview(ConfigView): + """A subview accessed via a subscript of a parent view.""" + def __init__(self, parent, key): + """Make a subview of a parent view for a given subscript key. + """ + self.parent = parent + self.key = key + + # Choose a human-readable name for this view. + if isinstance(self.parent, RootView): + self.name = '' + else: + self.name = self.parent.name + if not isinstance(self.key, int): + self.name += '.' + if isinstance(self.key, int): + self.name += '#{0}'.format(self.key) + elif isinstance(self.key, BASESTRING): + self.name += '{0}'.format(self.key) + else: + self.name += '{0}'.format(repr(self.key)) + + def resolve(self): + for collection, source in self.parent.resolve(): + try: + value = collection[self.key] + except IndexError: + # List index out of bounds. + continue + except KeyError: + # Dict key does not exist. + continue + except TypeError: + # Not subscriptable. + raise ConfigTypeError( + "{0} must be a collection, not {1}".format( + self.parent.name, type(collection).__name__ + ) + ) + yield value, source + + def set(self, value): + self.parent.set({self.key: value}) + + def add(self, value): + self.parent.add({self.key: value}) + + def root(self): + return self.parent.root() + + +# Config file paths, including platform-specific paths and in-package +# defaults. + +# Based on get_root_path from Flask by Armin Ronacher. +def _package_path(name): + """Returns the path to the package containing the named module or + None if the path could not be identified (e.g., if + ``name == "__main__"``). + """ + loader = pkgutil.get_loader(name) + if loader is None or name == '__main__': + return None + + if hasattr(loader, 'get_filename'): + filepath = loader.get_filename(name) + else: + # Fall back to importing the specified module. + __import__(name) + filepath = sys.modules[name].__file__ + + return os.path.dirname(os.path.abspath(filepath)) + + +def config_dirs(): + """Return a platform-specific list of candidates for user + configuration directories on the system. + + The candidates are in order of priority, from highest to lowest. The + last element is the "fallback" location to be used when no + higher-priority config file exists. + """ + paths = [] + + if platform.system() == 'Darwin': + paths.append(MAC_DIR) + paths.append(UNIX_DIR_FALLBACK) + if UNIX_DIR_VAR in os.environ: + paths.append(os.environ[UNIX_DIR_VAR]) + + elif platform.system() == 'Windows': + paths.append(WINDOWS_DIR_FALLBACK) + if WINDOWS_DIR_VAR in os.environ: + paths.append(os.environ[WINDOWS_DIR_VAR]) + + else: + # Assume Unix. + paths.append(UNIX_DIR_FALLBACK) + if UNIX_DIR_VAR in os.environ: + paths.append(os.environ[UNIX_DIR_VAR]) + + # Expand and deduplicate paths. + out = [] + for path in paths: + path = os.path.abspath(os.path.expanduser(path)) + if path not in out: + out.append(path) + return out + + +# YAML loading. + +class Loader(yaml.SafeLoader): + """A customized YAML loader. This loader deviates from the official + YAML spec in a few convenient ways: + + - All strings as are Unicode objects. + - All maps are OrderedDicts. + - Strings can begin with % without quotation. + """ + # All strings should be Unicode objects, regardless of contents. + def _construct_unicode(self, node): + return self.construct_scalar(node) + + # Use ordered dictionaries for every YAML map. + # From https://gist.github.com/844388 + def construct_yaml_map(self, node): + data = OrderedDict() + yield data + value = self.construct_mapping(node) + data.update(value) + + def construct_mapping(self, node, deep=False): + if isinstance(node, yaml.MappingNode): + self.flatten_mapping(node) + else: + raise yaml.constructor.ConstructorError( + None, None, + 'expected a mapping node, but found %s' % node.id, + node.start_mark + ) + + mapping = OrderedDict() + for key_node, value_node in node.value: + key = self.construct_object(key_node, deep=deep) + try: + hash(key) + except TypeError as exc: + raise yaml.constructor.ConstructorError( + 'while constructing a mapping', + node.start_mark, 'found unacceptable key (%s)' % exc, + key_node.start_mark + ) + value = self.construct_object(value_node, deep=deep) + mapping[key] = value + return mapping + + # Allow bare strings to begin with %. Directives are still detected. + def check_plain(self): + plain = super(Loader, self).check_plain() + return plain or self.peek() == '%' + + +Loader.add_constructor('tag:yaml.org,2002:str', Loader._construct_unicode) +Loader.add_constructor('tag:yaml.org,2002:map', Loader.construct_yaml_map) +Loader.add_constructor('tag:yaml.org,2002:omap', Loader.construct_yaml_map) + + +def load_yaml(filename): + """Read a YAML document from a file. If the file cannot be read or + parsed, a ConfigReadError is raised. + """ + try: + with open(filename, 'r') as f: + return yaml.load(f, Loader=Loader) + except (IOError, yaml.error.YAMLError) as exc: + raise ConfigReadError(filename, exc) + + +# YAML dumping. + +class Dumper(yaml.SafeDumper): + """A PyYAML Dumper that represents OrderedDicts as ordinary mappings + (in order, of course). + """ + # From http://pyyaml.org/attachment/ticket/161/use_ordered_dict.py + def represent_mapping(self, tag, mapping, flow_style=None): + value = [] + node = yaml.MappingNode(tag, value, flow_style=flow_style) + if self.alias_key is not None: + self.represented_objects[self.alias_key] = node + best_style = False + if hasattr(mapping, 'items'): + mapping = list(mapping.items()) + for item_key, item_value in mapping: + node_key = self.represent_data(item_key) + node_value = self.represent_data(item_value) + if not (isinstance(node_key, yaml.ScalarNode) + and not node_key.style): + best_style = False + if not (isinstance(node_value, yaml.ScalarNode) + and not node_value.style): + best_style = False + value.append((node_key, node_value)) + if flow_style is None: + if self.default_flow_style is not None: + node.flow_style = self.default_flow_style + else: + node.flow_style = best_style + return node + + def represent_list(self, data): + """If a list has less than 4 items, represent it in inline style + (i.e. comma separated, within square brackets). + """ + node = super(Dumper, self).represent_list(data) + length = len(data) + if self.default_flow_style is None and length < 4: + node.flow_style = True + elif self.default_flow_style is None: + node.flow_style = False + return node + + def represent_bool(self, data): + """Represent bool as 'yes' or 'no' instead of 'true' or 'false'. + """ + if data: + value = 'yes' + else: + value = 'no' + return self.represent_scalar('tag:yaml.org,2002:bool', value) + + def represent_none(self, data): + """Represent a None value with nothing instead of 'none'. + """ + return self.represent_scalar('tag:yaml.org,2002:null', '') + + +Dumper.add_representer(OrderedDict, Dumper.represent_dict) +Dumper.add_representer(bool, Dumper.represent_bool) +Dumper.add_representer(type(None), Dumper.represent_none) +Dumper.add_representer(list, Dumper.represent_list) + + +def restore_yaml_comments(data, default_data): + """Scan default_data for comments (we include empty lines in our + definition of comments) and place them before the same keys in data. + Only works with comments that are on one or more own lines, i.e. + not next to a yaml mapping. + """ + comment_map = dict() + default_lines = iter(default_data.splitlines()) + for line in default_lines: + if not line: + comment = "\n" + elif line.startswith("#"): + comment = "{0}\n".format(line) + else: + continue + while True: + line = next(default_lines) + if line and not line.startswith("#"): + break + comment += "{0}\n".format(line) + key = line.split(':')[0].strip() + comment_map[key] = comment + out_lines = iter(data.splitlines()) + out_data = "" + for line in out_lines: + key = line.split(':')[0].strip() + if key in comment_map: + out_data += comment_map[key] + out_data += "{0}\n".format(line) + return out_data + + +# Main interface. + +class Configuration(RootView): + def __init__(self, appname, modname=None, read=True): + """Create a configuration object by reading the + automatically-discovered config files for the application for a + given name. If `modname` is specified, it should be the import + name of a module whose package will be searched for a default + config file. (Otherwise, no defaults are used.) Pass `False` for + `read` to disable automatic reading of all discovered + configuration files. Use this when creating a configuration + object at module load time and then call the `read` method + later. + """ + super(Configuration, self).__init__([]) + self.appname = appname + self.modname = modname + + self._env_var = '{0}DIR'.format(self.appname.upper()) + + if read: + self.read() + + def user_config_path(self): + """Points to the location of the user configuration. + + The file may not exist. + """ + return os.path.join(self.config_dir(), CONFIG_FILENAME) + + def _add_user_source(self): + """Add the configuration options from the YAML file in the + user's configuration directory (given by `config_dir`) if it + exists. + """ + filename = self.user_config_path() + if os.path.isfile(filename): + self.add(ConfigSource(load_yaml(filename) or {}, filename)) + + def _add_default_source(self): + """Add the package's default configuration settings. This looks + for a YAML file located inside the package for the module + `modname` if it was given. + """ + if self.modname: + pkg_path = _package_path(self.modname) + if pkg_path: + filename = os.path.join(pkg_path, DEFAULT_FILENAME) + if os.path.isfile(filename): + self.add(ConfigSource(load_yaml(filename), filename, True)) + + def read(self, user=True, defaults=True): + """Find and read the files for this configuration and set them + as the sources for this configuration. To disable either + discovered user configuration files or the in-package defaults, + set `user` or `defaults` to `False`. + """ + if user: + self._add_user_source() + if defaults: + self._add_default_source() + + def config_dir(self): + """Get the path to the user configuration directory. The + directory is guaranteed to exist as a postcondition (one may be + created if none exist). + + If the application's ``...DIR`` environment variable is set, it + is used as the configuration directory. Otherwise, + platform-specific standard configuration locations are searched + for a ``config.yaml`` file. If no configuration file is found, a + fallback path is used. + """ + # If environment variable is set, use it. + if self._env_var in os.environ: + appdir = os.environ[self._env_var] + appdir = os.path.abspath(os.path.expanduser(appdir)) + if os.path.isfile(appdir): + raise ConfigError('{0} must be a directory'.format( + self._env_var + )) + + else: + # Search platform-specific locations. If no config file is + # found, fall back to the final directory in the list. + for confdir in config_dirs(): + appdir = os.path.join(confdir, self.appname) + if os.path.isfile(os.path.join(appdir, CONFIG_FILENAME)): + break + + # Ensure that the directory exists. + if not os.path.isdir(appdir): + os.makedirs(appdir) + return appdir + + def set_file(self, filename): + """Parses the file as YAML and inserts it into the configuration + sources with highest priority. + """ + filename = os.path.abspath(filename) + self.set(ConfigSource(load_yaml(filename), filename)) + + def dump(self, full=True): + """Dump the Configuration object to a YAML file. + + The order of the keys is determined from the default + configuration file. All keys not in the default configuration + will be appended to the end of the file. + + :param filename: The file to dump the configuration to, or None + if the YAML string should be returned instead + :type filename: unicode + :param full: Dump settings that don't differ from the defaults + as well + """ + if full: + out_dict = self.flatten() + else: + # Exclude defaults when flattening. + sources = [s for s in self.sources if not s.default] + out_dict = RootView(sources).flatten() + + yaml_out = yaml.dump(out_dict, Dumper=Dumper, + default_flow_style=None, indent=4, + width=1000) + + # Restore comments to the YAML text. + default_source = None + for source in self.sources: + if source.default: + default_source = source + break + if default_source: + with open(default_source.filename, 'r') as fp: + default_data = fp.read() + yaml_out = restore_yaml_comments(yaml_out, default_data) + + return yaml_out + + +class LazyConfig(Configuration): + """A Configuration at reads files on demand when it is first + accessed. This is appropriate for using as a global config object at + the module level. + """ + def __init__(self, appname, modname=None): + super(LazyConfig, self).__init__(appname, modname, False) + self._materialized = False # Have we read the files yet? + self._lazy_prefix = [] # Pre-materialization calls to set(). + self._lazy_suffix = [] # Calls to add(). + + def read(self, user=True, defaults=True): + self._materialized = True + super(LazyConfig, self).read(user, defaults) + + def resolve(self): + if not self._materialized: + # Read files and unspool buffers. + self.read() + self.sources += self._lazy_suffix + self.sources[:0] = self._lazy_prefix + return super(LazyConfig, self).resolve() + + def add(self, value): + super(LazyConfig, self).add(value) + if not self._materialized: + # Buffer additions to end. + self._lazy_suffix += self.sources + del self.sources[:] + + def set(self, value): + super(LazyConfig, self).set(value) + if not self._materialized: + # Buffer additions to beginning. + self._lazy_prefix[:0] = self.sources + del self.sources[:] + + def clear(self): + """Remove all sources from this configuration.""" + del self.sources[:] + self._lazy_suffix = [] + self._lazy_prefix = [] + + +# "Validated" configuration views: experimental! + + +REQUIRED = object() +"""A sentinel indicating that there is no default value and an exception +should be raised when the value is missing. +""" + + +class Template(object): + """A value template for configuration fields. + + The template works like a type and instructs Confit about how to + interpret a deserialized YAML value. This includes type conversions, + providing a default value, and validating for errors. For example, a + filepath type might expand tildes and check that the file exists. + """ + def __init__(self, default=REQUIRED): + """Create a template with a given default value. + + If `default` is the sentinel `REQUIRED` (as it is by default), + then an error will be raised when a value is missing. Otherwise, + missing values will instead return `default`. + """ + self.default = default + + def __call__(self, view): + """Invoking a template on a view gets the view's value according + to the template. + """ + return self.value(view, self) + + def value(self, view, template=None): + """Get the value for a `ConfigView`. + + May raise a `NotFoundError` if the value is missing (and the + template requires it) or a `ConfigValueError` for invalid values. + """ + if view.exists(): + value, _ = view.first() + return self.convert(value, view) + elif self.default is REQUIRED: + # Missing required value. This is an error. + raise NotFoundError("{0} not found".format(view.name)) + else: + # Missing value, but not required. + return self.default + + def convert(self, value, view): + """Convert the YAML-deserialized value to a value of the desired + type. + + Subclasses should override this to provide useful conversions. + May raise a `ConfigValueError` when the configuration is wrong. + """ + # Default implementation does no conversion. + return value + + def fail(self, message, view, type_error=False): + """Raise an exception indicating that a value cannot be + accepted. + + `type_error` indicates whether the error is due to a type + mismatch rather than a malformed value. In this case, a more + specific exception is raised. + """ + exc_class = ConfigTypeError if type_error else ConfigValueError + raise exc_class( + '{0}: {1}'.format(view.name, message) + ) + + def __repr__(self): + return '{0}({1})'.format( + type(self).__name__, + '' if self.default is REQUIRED else repr(self.default), + ) + + +class Integer(Template): + """An integer configuration value template. + """ + def convert(self, value, view): + """Check that the value is an integer. Floats are rounded. + """ + if isinstance(value, int): + return value + elif isinstance(value, float): + return int(value) + else: + self.fail('must be a number', view, True) + + +class Number(Template): + """A numeric type: either an integer or a floating-point number. + """ + def convert(self, value, view): + """Check that the value is an int or a float. + """ + if isinstance(value, NUMERIC_TYPES): + return value + else: + self.fail( + 'must be numeric, not {0}'.format(type(value).__name__), + view, + True + ) + + +class MappingTemplate(Template): + """A template that uses a dictionary to specify other types for the + values for a set of keys and produce a validated `AttrDict`. + """ + def __init__(self, mapping): + """Create a template according to a dict (mapping). The + mapping's values should themselves either be Types or + convertible to Types. + """ + subtemplates = {} + for key, typ in mapping.items(): + subtemplates[key] = as_template(typ) + self.subtemplates = subtemplates + + def value(self, view, template=None): + """Get a dict with the same keys as the template and values + validated according to the value types. + """ + out = AttrDict() + for key, typ in self.subtemplates.items(): + out[key] = typ.value(view[key], self) + return out + + def __repr__(self): + return 'MappingTemplate({0})'.format(repr(self.subtemplates)) + + +class String(Template): + """A string configuration value template. + """ + def __init__(self, default=REQUIRED, pattern=None): + """Create a template with the added optional `pattern` argument, + a regular expression string that the value should match. + """ + super(String, self).__init__(default) + self.pattern = pattern + if pattern: + self.regex = re.compile(pattern) + + def convert(self, value, view): + """Check that the value is a string and matches the pattern. + """ + if isinstance(value, BASESTRING): + if self.pattern and not self.regex.match(value): + self.fail( + "must match the pattern {0}".format(self.pattern), + view + ) + return value + else: + self.fail('must be a string', view, True) + + +class Choice(Template): + """A template that permits values from a sequence of choices. + """ + def __init__(self, choices): + """Create a template that validates any of the values from the + iterable `choices`. + + If `choices` is a map, then the corresponding value is emitted. + Otherwise, the value itself is emitted. + """ + self.choices = choices + + def convert(self, value, view): + """Ensure that the value is among the choices (and remap if the + choices are a mapping). + """ + if value not in self.choices: + self.fail( + 'must be one of {0}, not {1}'.format( + repr(list(self.choices)), repr(value) + ), + view + ) + + if isinstance(self.choices, collections.Mapping): + return self.choices[value] + else: + return value + + def __repr__(self): + return 'Choice({0!r})'.format(self.choices) + + +class StrSeq(Template): + """A template for values that are lists of strings. + + Validates both actual YAML string lists and single strings. Strings + can optionally be split on whitespace. + """ + def __init__(self, split=True): + """Create a new template. + + `split` indicates whether, when the underlying value is a single + string, it should be split on whitespace. Otherwise, the + resulting value is a list containing a single string. + """ + super(StrSeq, self).__init__() + self.split = split + + def convert(self, value, view): + if isinstance(value, bytes): + value = value.decode('utf8', 'ignore') + + if isinstance(value, STRING): + if self.split: + return value.split() + else: + return [value] + + try: + value = list(value) + except TypeError: + self.fail('must be a whitespace-separated string or a list', + view, True) + + def convert(x): + if isinstance(x, unicode): + return x + elif isinstance(x, BASESTRING): + return x.decode('utf8', 'ignore') + else: + self.fail('must be a list of strings', view, True) + return map(convert, value) + + +class Filename(Template): + """A template that validates strings as filenames. + + Filenames are returned as absolute, tilde-free paths. + + Relative paths are relative to the template's `cwd` argument + when it is specified, then the configuration directory (see + the `config_dir` method) if they come from a file. Otherwise, + they are relative to the current working directory. This helps + attain the expected behavior when using command-line options. + """ + def __init__(self, default=REQUIRED, cwd=None, relative_to=None, + in_app_dir=False): + """ `relative_to` is the name of a sibling value that is + being validated at the same time. + + `in_app_dir` indicates whether the path should be resolved + inside the application's config directory (even when the setting + does not come from a file). + """ + super(Filename, self).__init__(default) + self.cwd = cwd + self.relative_to = relative_to + self.in_app_dir = in_app_dir + + def __repr__(self): + args = [] + + if self.default is not REQUIRED: + args.append(repr(self.default)) + + if self.cwd is not None: + args.append('cwd=' + repr(self.cwd)) + + if self.relative_to is not None: + args.append('relative_to=' + repr(self.relative_to)) + + if self.in_app_dir: + args.append('in_app_dir=True') + + return 'Filename({0})'.format(', '.join(args)) + + def resolve_relative_to(self, view, template): + if not isinstance(template, (collections.Mapping, MappingTemplate)): + # disallow config.get(Filename(relative_to='foo')) + raise ConfigTemplateError( + 'relative_to may only be used when getting multiple values.' + ) + + elif self.relative_to == view.key: + raise ConfigTemplateError( + '{0} is relative to itself'.format(view.name) + ) + + elif self.relative_to not in view.parent.keys(): + # self.relative_to is not in the config + self.fail( + ( + 'needs sibling value "{0}" to expand relative path' + ).format(self.relative_to), + view + ) + + old_template = {} + old_template.update(template.subtemplates) + + # save time by skipping MappingTemplate's init loop + next_template = MappingTemplate({}) + next_relative = self.relative_to + + # gather all the needed templates and nothing else + while next_relative is not None: + try: + # pop to avoid infinite loop because of recursive + # relative paths + rel_to_template = old_template.pop(next_relative) + except KeyError: + if next_relative in template.subtemplates: + # we encountered this config key previously + raise ConfigTemplateError(( + '{0} and {1} are recursively relative' + ).format(view.name, self.relative_to)) + else: + raise ConfigTemplateError(( + 'missing template for {0}, needed to expand {1}\'s' + + 'relative path' + ).format(self.relative_to, view.name)) + + next_template.subtemplates[next_relative] = rel_to_template + next_relative = rel_to_template.relative_to + + return view.parent.get(next_template)[self.relative_to] + + def value(self, view, template=None): + path, source = view.first() + if not isinstance(path, BASESTRING): + self.fail( + 'must be a filename, not {0}'.format(type(path).__name__), + view, + True + ) + path = os.path.expanduser(STRING(path)) + + if not os.path.isabs(path): + if self.cwd is not None: + # relative to the template's argument + path = os.path.join(self.cwd, path) + + elif self.relative_to is not None: + path = os.path.join( + self.resolve_relative_to(view, template), + path, + ) + + elif source.filename or self.in_app_dir: + # From defaults: relative to the app's directory. + path = os.path.join(view.root().config_dir(), path) + + return os.path.abspath(path) + + +class TypeTemplate(Template): + """A simple template that checks that a value is an instance of a + desired Python type. + """ + def __init__(self, typ, default=REQUIRED): + """Create a template that checks that the value is an instance + of `typ`. + """ + super(TypeTemplate, self).__init__(default) + self.typ = typ + + def convert(self, value, view): + if not isinstance(value, self.typ): + self.fail( + 'must be a {0}, not {1}'.format( + self.typ.__name__, + type(value).__name__, + ), + view, + True + ) + return value + + +class AttrDict(dict): + """A `dict` subclass that can be accessed via attributes (dot + notation) for convenience. + """ + def __getattr__(self, key): + if key in self: + return self[key] + else: + raise AttributeError(key) + + +def as_template(value): + """Convert a simple "shorthand" Python value to a `Template`. + """ + if isinstance(value, Template): + # If it's already a Template, pass it through. + return value + elif isinstance(value, collections.Mapping): + # Dictionaries work as templates. + return MappingTemplate(value) + elif value is int: + return Integer() + elif isinstance(value, int): + return Integer(value) + elif isinstance(value, type) and issubclass(value, BASESTRING): + return String() + elif isinstance(value, BASESTRING): + return String(value) + elif value is float: + return Number() + elif value is None: + return Template() + elif value is dict: + return TypeTemplate(collections.Mapping) + elif value is list: + return TypeTemplate(collections.Sequence) + elif isinstance(value, type): + return TypeTemplate(value) + else: + raise ValueError('cannot convert to template: {0!r}'.format(value)) diff --git a/lib/beets/util/enumeration.py b/lib/beets/util/enumeration.py new file mode 100644 index 00000000..e8cd0fe1 --- /dev/null +++ b/lib/beets/util/enumeration.py @@ -0,0 +1,40 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +from enum import Enum + + +class OrderedEnum(Enum): + """ + An Enum subclass that allows comparison of members. + """ + def __ge__(self, other): + if self.__class__ is other.__class__: + return self.value >= other.value + return NotImplemented + + def __gt__(self, other): + if self.__class__ is other.__class__: + return self.value > other.value + return NotImplemented + + def __le__(self, other): + if self.__class__ is other.__class__: + return self.value <= other.value + return NotImplemented + + def __lt__(self, other): + if self.__class__ is other.__class__: + return self.value < other.value + return NotImplemented diff --git a/lib/beets/util/functemplate.py b/lib/beets/util/functemplate.py new file mode 100644 index 00000000..03e57c61 --- /dev/null +++ b/lib/beets/util/functemplate.py @@ -0,0 +1,571 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""This module implements a string formatter based on the standard PEP +292 string.Template class extended with function calls. Variables, as +with string.Template, are indicated with $ and functions are delimited +with %. + +This module assumes that everything is Unicode: the template and the +substitution values. Bytestrings are not supported. Also, the templates +always behave like the ``safe_substitute`` method in the standard +library: unknown symbols are left intact. + +This is sort of like a tiny, horrible degeneration of a real templating +engine like Jinja2 or Mustache. +""" +from __future__ import print_function + +import re +import ast +import dis +import types + +SYMBOL_DELIM = u'$' +FUNC_DELIM = u'%' +GROUP_OPEN = u'{' +GROUP_CLOSE = u'}' +ARG_SEP = u',' +ESCAPE_CHAR = u'$' + +VARIABLE_PREFIX = '__var_' +FUNCTION_PREFIX = '__func_' + + +class Environment(object): + """Contains the values and functions to be substituted into a + template. + """ + def __init__(self, values, functions): + self.values = values + self.functions = functions + + +# Code generation helpers. + +def ex_lvalue(name): + """A variable load expression.""" + return ast.Name(name, ast.Store()) + + +def ex_rvalue(name): + """A variable store expression.""" + return ast.Name(name, ast.Load()) + + +def ex_literal(val): + """An int, float, long, bool, string, or None literal with the given + value. + """ + if val is None: + return ast.Name('None', ast.Load()) + elif isinstance(val, (int, float, long)): + return ast.Num(val) + elif isinstance(val, bool): + return ast.Name(str(val), ast.Load()) + elif isinstance(val, basestring): + return ast.Str(val) + raise TypeError('no literal for {0}'.format(type(val))) + + +def ex_varassign(name, expr): + """Assign an expression into a single variable. The expression may + either be an `ast.expr` object or a value to be used as a literal. + """ + if not isinstance(expr, ast.expr): + expr = ex_literal(expr) + return ast.Assign([ex_lvalue(name)], expr) + + +def ex_call(func, args): + """A function-call expression with only positional parameters. The + function may be an expression or the name of a function. Each + argument may be an expression or a value to be used as a literal. + """ + if isinstance(func, basestring): + func = ex_rvalue(func) + + args = list(args) + for i in range(len(args)): + if not isinstance(args[i], ast.expr): + args[i] = ex_literal(args[i]) + + return ast.Call(func, args, [], None, None) + + +def compile_func(arg_names, statements, name='_the_func', debug=False): + """Compile a list of statements as the body of a function and return + the resulting Python function. If `debug`, then print out the + bytecode of the compiled function. + """ + func_def = ast.FunctionDef( + name, + ast.arguments( + [ast.Name(n, ast.Param()) for n in arg_names], + None, None, + [ex_literal(None) for _ in arg_names], + ), + statements, + [], + ) + mod = ast.Module([func_def]) + ast.fix_missing_locations(mod) + + prog = compile(mod, '', 'exec') + + # Debug: show bytecode. + if debug: + dis.dis(prog) + for const in prog.co_consts: + if isinstance(const, types.CodeType): + dis.dis(const) + + the_locals = {} + exec prog in {}, the_locals + return the_locals[name] + + +# AST nodes for the template language. + +class Symbol(object): + """A variable-substitution symbol in a template.""" + def __init__(self, ident, original): + self.ident = ident + self.original = original + + def __repr__(self): + return u'Symbol(%s)' % repr(self.ident) + + def evaluate(self, env): + """Evaluate the symbol in the environment, returning a Unicode + string. + """ + if self.ident in env.values: + # Substitute for a value. + return env.values[self.ident] + else: + # Keep original text. + return self.original + + def translate(self): + """Compile the variable lookup.""" + expr = ex_rvalue(VARIABLE_PREFIX + self.ident.encode('utf8')) + return [expr], set([self.ident.encode('utf8')]), set() + + +class Call(object): + """A function call in a template.""" + def __init__(self, ident, args, original): + self.ident = ident + self.args = args + self.original = original + + def __repr__(self): + return u'Call(%s, %s, %s)' % (repr(self.ident), repr(self.args), + repr(self.original)) + + def evaluate(self, env): + """Evaluate the function call in the environment, returning a + Unicode string. + """ + if self.ident in env.functions: + arg_vals = [expr.evaluate(env) for expr in self.args] + try: + out = env.functions[self.ident](*arg_vals) + except Exception as exc: + # Function raised exception! Maybe inlining the name of + # the exception will help debug. + return u'<%s>' % unicode(exc) + return unicode(out) + else: + return self.original + + def translate(self): + """Compile the function call.""" + varnames = set() + funcnames = set([self.ident.encode('utf8')]) + + arg_exprs = [] + for arg in self.args: + subexprs, subvars, subfuncs = arg.translate() + varnames.update(subvars) + funcnames.update(subfuncs) + + # Create a subexpression that joins the result components of + # the arguments. + arg_exprs.append(ex_call( + ast.Attribute(ex_literal(u''), 'join', ast.Load()), + [ex_call( + 'map', + [ + ex_rvalue('unicode'), + ast.List(subexprs, ast.Load()), + ] + )], + )) + + subexpr_call = ex_call( + FUNCTION_PREFIX + self.ident.encode('utf8'), + arg_exprs + ) + return [subexpr_call], varnames, funcnames + + +class Expression(object): + """Top-level template construct: contains a list of text blobs, + Symbols, and Calls. + """ + def __init__(self, parts): + self.parts = parts + + def __repr__(self): + return u'Expression(%s)' % (repr(self.parts)) + + def evaluate(self, env): + """Evaluate the entire expression in the environment, returning + a Unicode string. + """ + out = [] + for part in self.parts: + if isinstance(part, basestring): + out.append(part) + else: + out.append(part.evaluate(env)) + return u''.join(map(unicode, out)) + + def translate(self): + """Compile the expression to a list of Python AST expressions, a + set of variable names used, and a set of function names. + """ + expressions = [] + varnames = set() + funcnames = set() + for part in self.parts: + if isinstance(part, basestring): + expressions.append(ex_literal(part)) + else: + e, v, f = part.translate() + expressions.extend(e) + varnames.update(v) + funcnames.update(f) + return expressions, varnames, funcnames + + +# Parser. + +class ParseError(Exception): + pass + + +class Parser(object): + """Parses a template expression string. Instantiate the class with + the template source and call ``parse_expression``. The ``pos`` field + will indicate the character after the expression finished and + ``parts`` will contain a list of Unicode strings, Symbols, and Calls + reflecting the concatenated portions of the expression. + + This is a terrible, ad-hoc parser implementation based on a + left-to-right scan with no lexing step to speak of; it's probably + both inefficient and incorrect. Maybe this should eventually be + replaced with a real, accepted parsing technique (PEG, parser + generator, etc.). + """ + def __init__(self, string): + self.string = string + self.pos = 0 + self.parts = [] + + # Common parsing resources. + special_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_OPEN, GROUP_CLOSE, + ARG_SEP, ESCAPE_CHAR) + special_char_re = re.compile(ur'[%s]|$' % + u''.join(re.escape(c) for c in special_chars)) + + def parse_expression(self): + """Parse a template expression starting at ``pos``. Resulting + components (Unicode strings, Symbols, and Calls) are added to + the ``parts`` field, a list. The ``pos`` field is updated to be + the next character after the expression. + """ + text_parts = [] + + while self.pos < len(self.string): + char = self.string[self.pos] + + if char not in self.special_chars: + # A non-special character. Skip to the next special + # character, treating the interstice as literal text. + next_pos = ( + self.special_char_re.search(self.string[self.pos:]).start() + + self.pos + ) + text_parts.append(self.string[self.pos:next_pos]) + self.pos = next_pos + continue + + if self.pos == len(self.string) - 1: + # The last character can never begin a structure, so we + # just interpret it as a literal character (unless it + # terminates the expression, as with , and }). + if char not in (GROUP_CLOSE, ARG_SEP): + text_parts.append(char) + self.pos += 1 + break + + next_char = self.string[self.pos + 1] + if char == ESCAPE_CHAR and next_char in \ + (SYMBOL_DELIM, FUNC_DELIM, GROUP_CLOSE, ARG_SEP): + # An escaped special character ($$, $}, etc.). Note that + # ${ is not an escape sequence: this is ambiguous with + # the start of a symbol and it's not necessary (just + # using { suffices in all cases). + text_parts.append(next_char) + self.pos += 2 # Skip the next character. + continue + + # Shift all characters collected so far into a single string. + if text_parts: + self.parts.append(u''.join(text_parts)) + text_parts = [] + + if char == SYMBOL_DELIM: + # Parse a symbol. + self.parse_symbol() + elif char == FUNC_DELIM: + # Parse a function call. + self.parse_call() + elif char in (GROUP_CLOSE, ARG_SEP): + # Template terminated. + break + elif char == GROUP_OPEN: + # Start of a group has no meaning hear; just pass + # through the character. + text_parts.append(char) + self.pos += 1 + else: + assert False + + # If any parsed characters remain, shift them into a string. + if text_parts: + self.parts.append(u''.join(text_parts)) + + def parse_symbol(self): + """Parse a variable reference (like ``$foo`` or ``${foo}``) + starting at ``pos``. Possibly appends a Symbol object (or, + failing that, text) to the ``parts`` field and updates ``pos``. + The character at ``pos`` must, as a precondition, be ``$``. + """ + assert self.pos < len(self.string) + assert self.string[self.pos] == SYMBOL_DELIM + + if self.pos == len(self.string) - 1: + # Last character. + self.parts.append(SYMBOL_DELIM) + self.pos += 1 + return + + next_char = self.string[self.pos + 1] + start_pos = self.pos + self.pos += 1 + + if next_char == GROUP_OPEN: + # A symbol like ${this}. + self.pos += 1 # Skip opening. + closer = self.string.find(GROUP_CLOSE, self.pos) + if closer == -1 or closer == self.pos: + # No closing brace found or identifier is empty. + self.parts.append(self.string[start_pos:self.pos]) + else: + # Closer found. + ident = self.string[self.pos:closer] + self.pos = closer + 1 + self.parts.append(Symbol(ident, + self.string[start_pos:self.pos])) + + else: + # A bare-word symbol. + ident = self._parse_ident() + if ident: + # Found a real symbol. + self.parts.append(Symbol(ident, + self.string[start_pos:self.pos])) + else: + # A standalone $. + self.parts.append(SYMBOL_DELIM) + + def parse_call(self): + """Parse a function call (like ``%foo{bar,baz}``) starting at + ``pos``. Possibly appends a Call object to ``parts`` and update + ``pos``. The character at ``pos`` must be ``%``. + """ + assert self.pos < len(self.string) + assert self.string[self.pos] == FUNC_DELIM + + start_pos = self.pos + self.pos += 1 + + ident = self._parse_ident() + if not ident: + # No function name. + self.parts.append(FUNC_DELIM) + return + + if self.pos >= len(self.string): + # Identifier terminates string. + self.parts.append(self.string[start_pos:self.pos]) + return + + if self.string[self.pos] != GROUP_OPEN: + # Argument list not opened. + self.parts.append(self.string[start_pos:self.pos]) + return + + # Skip past opening brace and try to parse an argument list. + self.pos += 1 + args = self.parse_argument_list() + if self.pos >= len(self.string) or \ + self.string[self.pos] != GROUP_CLOSE: + # Arguments unclosed. + self.parts.append(self.string[start_pos:self.pos]) + return + + self.pos += 1 # Move past closing brace. + self.parts.append(Call(ident, args, self.string[start_pos:self.pos])) + + def parse_argument_list(self): + """Parse a list of arguments starting at ``pos``, returning a + list of Expression objects. Does not modify ``parts``. Should + leave ``pos`` pointing to a } character or the end of the + string. + """ + # Try to parse a subexpression in a subparser. + expressions = [] + + while self.pos < len(self.string): + subparser = Parser(self.string[self.pos:]) + subparser.parse_expression() + + # Extract and advance past the parsed expression. + expressions.append(Expression(subparser.parts)) + self.pos += subparser.pos + + if self.pos >= len(self.string) or \ + self.string[self.pos] == GROUP_CLOSE: + # Argument list terminated by EOF or closing brace. + break + + # Only other way to terminate an expression is with ,. + # Continue to the next argument. + assert self.string[self.pos] == ARG_SEP + self.pos += 1 + + return expressions + + def _parse_ident(self): + """Parse an identifier and return it (possibly an empty string). + Updates ``pos``. + """ + remainder = self.string[self.pos:] + ident = re.match(ur'\w*', remainder).group(0) + self.pos += len(ident) + return ident + + +def _parse(template): + """Parse a top-level template string Expression. Any extraneous text + is considered literal text. + """ + parser = Parser(template) + parser.parse_expression() + + parts = parser.parts + remainder = parser.string[parser.pos:] + if remainder: + parts.append(remainder) + return Expression(parts) + + +# External interface. + +class Template(object): + """A string template, including text, Symbols, and Calls. + """ + def __init__(self, template): + self.expr = _parse(template) + self.original = template + self.compiled = self.translate() + + def __eq__(self, other): + return self.original == other.original + + def interpret(self, values={}, functions={}): + """Like `substitute`, but forces the interpreter (rather than + the compiled version) to be used. The interpreter includes + exception-handling code for missing variables and buggy template + functions but is much slower. + """ + return self.expr.evaluate(Environment(values, functions)) + + def substitute(self, values={}, functions={}): + """Evaluate the template given the values and functions. + """ + try: + res = self.compiled(values, functions) + except: # Handle any exceptions thrown by compiled version. + res = self.interpret(values, functions) + return res + + def translate(self): + """Compile the template to a Python function.""" + expressions, varnames, funcnames = self.expr.translate() + + argnames = [] + for varname in varnames: + argnames.append(VARIABLE_PREFIX.encode('utf8') + varname) + for funcname in funcnames: + argnames.append(FUNCTION_PREFIX.encode('utf8') + funcname) + + func = compile_func( + argnames, + [ast.Return(ast.List(expressions, ast.Load()))], + ) + + def wrapper_func(values={}, functions={}): + args = {} + for varname in varnames: + args[VARIABLE_PREFIX + varname] = values[varname] + for funcname in funcnames: + args[FUNCTION_PREFIX + funcname] = functions[funcname] + parts = func(**args) + return u''.join(parts) + + return wrapper_func + + +# Performance tests. + +if __name__ == '__main__': + import timeit + _tmpl = Template(u'foo $bar %baz{foozle $bar barzle} $bar') + _vars = {'bar': 'qux'} + _funcs = {'baz': unicode.upper} + interp_time = timeit.timeit('_tmpl.interpret(_vars, _funcs)', + 'from __main__ import _tmpl, _vars, _funcs', + number=10000) + print(interp_time) + comp_time = timeit.timeit('_tmpl.substitute(_vars, _funcs)', + 'from __main__ import _tmpl, _vars, _funcs', + number=10000) + print(comp_time) + print('Speedup:', interp_time / comp_time) diff --git a/lib/beets/util/pipeline.py b/lib/beets/util/pipeline.py new file mode 100644 index 00000000..d267789c --- /dev/null +++ b/lib/beets/util/pipeline.py @@ -0,0 +1,517 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Simple but robust implementation of generator/coroutine-based +pipelines in Python. The pipelines may be run either sequentially +(single-threaded) or in parallel (one thread per pipeline stage). + +This implementation supports pipeline bubbles (indications that the +processing for a certain item should abort). To use them, yield the +BUBBLE constant from any stage coroutine except the last. + +In the parallel case, the implementation transparently handles thread +shutdown when the processing is complete and when a stage raises an +exception. KeyboardInterrupts (^C) are also handled. + +When running a parallel pipeline, it is also possible to use +multiple coroutines for the same pipeline stage; this lets you speed +up a bottleneck stage by dividing its work among multiple threads. +To do so, pass an iterable of coroutines to the Pipeline constructor +in place of any single coroutine. +""" +from __future__ import print_function + +import Queue +from threading import Thread, Lock +import sys + +BUBBLE = '__PIPELINE_BUBBLE__' +POISON = '__PIPELINE_POISON__' + +DEFAULT_QUEUE_SIZE = 16 + + +def _invalidate_queue(q, val=None, sync=True): + """Breaks a Queue such that it never blocks, always has size 1, + and has no maximum size. get()ing from the queue returns `val`, + which defaults to None. `sync` controls whether a lock is + required (because it's not reentrant!). + """ + def _qsize(len=len): + return 1 + + def _put(item): + pass + + def _get(): + return val + + if sync: + q.mutex.acquire() + + try: + q.maxsize = 0 + q._qsize = _qsize + q._put = _put + q._get = _get + q.not_empty.notifyAll() + q.not_full.notifyAll() + + finally: + if sync: + q.mutex.release() + + +class CountedQueue(Queue.Queue): + """A queue that keeps track of the number of threads that are + still feeding into it. The queue is poisoned when all threads are + finished with the queue. + """ + def __init__(self, maxsize=0): + Queue.Queue.__init__(self, maxsize) + self.nthreads = 0 + self.poisoned = False + + def acquire(self): + """Indicate that a thread will start putting into this queue. + Should not be called after the queue is already poisoned. + """ + with self.mutex: + assert not self.poisoned + assert self.nthreads >= 0 + self.nthreads += 1 + + def release(self): + """Indicate that a thread that was putting into this queue has + exited. If this is the last thread using the queue, the queue + is poisoned. + """ + with self.mutex: + self.nthreads -= 1 + assert self.nthreads >= 0 + if self.nthreads == 0: + # All threads are done adding to this queue. Poison it + # when it becomes empty. + self.poisoned = True + + # Replacement _get invalidates when no items remain. + _old_get = self._get + + def _get(): + out = _old_get() + if not self.queue: + _invalidate_queue(self, POISON, False) + return out + + if self.queue: + # Items remain. + self._get = _get + else: + # No items. Invalidate immediately. + _invalidate_queue(self, POISON, False) + + +class MultiMessage(object): + """A message yielded by a pipeline stage encapsulating multiple + values to be sent to the next stage. + """ + def __init__(self, messages): + self.messages = messages + + +def multiple(messages): + """Yield multiple([message, ..]) from a pipeline stage to send + multiple values to the next pipeline stage. + """ + return MultiMessage(messages) + + +def stage(func): + """Decorate a function to become a simple stage. + + >>> @stage + ... def add(n, i): + ... return i + n + >>> pipe = Pipeline([ + ... iter([1, 2, 3]), + ... add(2), + ... ]) + >>> list(pipe.pull()) + [3, 4, 5] + """ + + def coro(*args): + task = None + while True: + task = yield task + task = func(*(args + (task,))) + return coro + + +def mutator_stage(func): + """Decorate a function that manipulates items in a coroutine to + become a simple stage. + + >>> @mutator_stage + ... def setkey(key, item): + ... item[key] = True + >>> pipe = Pipeline([ + ... iter([{'x': False}, {'a': False}]), + ... setkey('x'), + ... ]) + >>> list(pipe.pull()) + [{'x': True}, {'a': False, 'x': True}] + """ + + def coro(*args): + task = None + while True: + task = yield task + func(*(args + (task,))) + return coro + + +def _allmsgs(obj): + """Returns a list of all the messages encapsulated in obj. If obj + is a MultiMessage, returns its enclosed messages. If obj is BUBBLE, + returns an empty list. Otherwise, returns a list containing obj. + """ + if isinstance(obj, MultiMessage): + return obj.messages + elif obj == BUBBLE: + return [] + else: + return [obj] + + +class PipelineThread(Thread): + """Abstract base class for pipeline-stage threads.""" + def __init__(self, all_threads): + super(PipelineThread, self).__init__() + self.abort_lock = Lock() + self.abort_flag = False + self.all_threads = all_threads + self.exc_info = None + + def abort(self): + """Shut down the thread at the next chance possible. + """ + with self.abort_lock: + self.abort_flag = True + + # Ensure that we are not blocking on a queue read or write. + if hasattr(self, 'in_queue'): + _invalidate_queue(self.in_queue, POISON) + if hasattr(self, 'out_queue'): + _invalidate_queue(self.out_queue, POISON) + + def abort_all(self, exc_info): + """Abort all other threads in the system for an exception. + """ + self.exc_info = exc_info + for thread in self.all_threads: + thread.abort() + + +class FirstPipelineThread(PipelineThread): + """The thread running the first stage in a parallel pipeline setup. + The coroutine should just be a generator. + """ + def __init__(self, coro, out_queue, all_threads): + super(FirstPipelineThread, self).__init__(all_threads) + self.coro = coro + self.out_queue = out_queue + self.out_queue.acquire() + + self.abort_lock = Lock() + self.abort_flag = False + + def run(self): + try: + while True: + with self.abort_lock: + if self.abort_flag: + return + + # Get the value from the generator. + try: + msg = self.coro.next() + except StopIteration: + break + + # Send messages to the next stage. + for msg in _allmsgs(msg): + with self.abort_lock: + if self.abort_flag: + return + self.out_queue.put(msg) + + except: + self.abort_all(sys.exc_info()) + return + + # Generator finished; shut down the pipeline. + self.out_queue.release() + + +class MiddlePipelineThread(PipelineThread): + """A thread running any stage in the pipeline except the first or + last. + """ + def __init__(self, coro, in_queue, out_queue, all_threads): + super(MiddlePipelineThread, self).__init__(all_threads) + self.coro = coro + self.in_queue = in_queue + self.out_queue = out_queue + self.out_queue.acquire() + + def run(self): + try: + # Prime the coroutine. + self.coro.next() + + while True: + with self.abort_lock: + if self.abort_flag: + return + + # Get the message from the previous stage. + msg = self.in_queue.get() + if msg is POISON: + break + + with self.abort_lock: + if self.abort_flag: + return + + # Invoke the current stage. + out = self.coro.send(msg) + + # Send messages to next stage. + for msg in _allmsgs(out): + with self.abort_lock: + if self.abort_flag: + return + self.out_queue.put(msg) + + except: + self.abort_all(sys.exc_info()) + return + + # Pipeline is shutting down normally. + self.out_queue.release() + + +class LastPipelineThread(PipelineThread): + """A thread running the last stage in a pipeline. The coroutine + should yield nothing. + """ + def __init__(self, coro, in_queue, all_threads): + super(LastPipelineThread, self).__init__(all_threads) + self.coro = coro + self.in_queue = in_queue + + def run(self): + # Prime the coroutine. + self.coro.next() + + try: + while True: + with self.abort_lock: + if self.abort_flag: + return + + # Get the message from the previous stage. + msg = self.in_queue.get() + if msg is POISON: + break + + with self.abort_lock: + if self.abort_flag: + return + + # Send to consumer. + self.coro.send(msg) + + except: + self.abort_all(sys.exc_info()) + return + + +class Pipeline(object): + """Represents a staged pattern of work. Each stage in the pipeline + is a coroutine that receives messages from the previous stage and + yields messages to be sent to the next stage. + """ + def __init__(self, stages): + """Makes a new pipeline from a list of coroutines. There must + be at least two stages. + """ + if len(stages) < 2: + raise ValueError('pipeline must have at least two stages') + self.stages = [] + for stage in stages: + if isinstance(stage, (list, tuple)): + self.stages.append(stage) + else: + # Default to one thread per stage. + self.stages.append((stage,)) + + def run_sequential(self): + """Run the pipeline sequentially in the current thread. The + stages are run one after the other. Only the first coroutine + in each stage is used. + """ + list(self.pull()) + + def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE): + """Run the pipeline in parallel using one thread per stage. The + messages between the stages are stored in queues of the given + size. + """ + queue_count = len(self.stages) - 1 + queues = [CountedQueue(queue_size) for i in range(queue_count)] + threads = [] + + # Set up first stage. + for coro in self.stages[0]: + threads.append(FirstPipelineThread(coro, queues[0], threads)) + + # Middle stages. + for i in range(1, queue_count): + for coro in self.stages[i]: + threads.append(MiddlePipelineThread( + coro, queues[i - 1], queues[i], threads + )) + + # Last stage. + for coro in self.stages[-1]: + threads.append( + LastPipelineThread(coro, queues[-1], threads) + ) + + # Start threads. + for thread in threads: + thread.start() + + # Wait for termination. The final thread lasts the longest. + try: + # Using a timeout allows us to receive KeyboardInterrupt + # exceptions during the join(). + while threads[-1].isAlive(): + threads[-1].join(1) + + except: + # Stop all the threads immediately. + for thread in threads: + thread.abort() + raise + + finally: + # Make completely sure that all the threads have finished + # before we return. They should already be either finished, + # in normal operation, or aborted, in case of an exception. + for thread in threads[:-1]: + thread.join() + + for thread in threads: + exc_info = thread.exc_info + if exc_info: + # Make the exception appear as it was raised originally. + raise exc_info[0], exc_info[1], exc_info[2] + + def pull(self): + """Yield elements from the end of the pipeline. Runs the stages + sequentially until the last yields some messages. Each of the messages + is then yielded by ``pulled.next()``. If the pipeline has a consumer, + that is the last stage does not yield any messages, then pull will not + yield any messages. Only the first coroutine in each stage is used + """ + coros = [stage[0] for stage in self.stages] + + # "Prime" the coroutines. + for coro in coros[1:]: + coro.next() + + # Begin the pipeline. + for out in coros[0]: + msgs = _allmsgs(out) + for coro in coros[1:]: + next_msgs = [] + for msg in msgs: + out = coro.send(msg) + next_msgs.extend(_allmsgs(out)) + msgs = next_msgs + for msg in msgs: + yield msg + +# Smoke test. +if __name__ == '__main__': + import time + + # Test a normally-terminating pipeline both in sequence and + # in parallel. + def produce(): + for i in range(5): + print('generating %i' % i) + time.sleep(1) + yield i + + def work(): + num = yield + while True: + print('processing %i' % num) + time.sleep(2) + num = yield num * 2 + + def consume(): + while True: + num = yield + time.sleep(1) + print('received %i' % num) + + ts_start = time.time() + Pipeline([produce(), work(), consume()]).run_sequential() + ts_seq = time.time() + Pipeline([produce(), work(), consume()]).run_parallel() + ts_par = time.time() + Pipeline([produce(), (work(), work()), consume()]).run_parallel() + ts_end = time.time() + print('Sequential time:', ts_seq - ts_start) + print('Parallel time:', ts_par - ts_seq) + print('Multiply-parallel time:', ts_end - ts_par) + print() + + # Test a pipeline that raises an exception. + def exc_produce(): + for i in range(10): + print('generating %i' % i) + time.sleep(1) + yield i + + def exc_work(): + num = yield + while True: + print('processing %i' % num) + time.sleep(3) + if num == 3: + raise Exception() + num = yield num * 2 + + def exc_consume(): + while True: + num = yield + print('received %i' % num) + + Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1) diff --git a/lib/beets/vfs.py b/lib/beets/vfs.py new file mode 100644 index 00000000..e940e21f --- /dev/null +++ b/lib/beets/vfs.py @@ -0,0 +1,50 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""A simple utility for constructing filesystem-like trees from beets +libraries. +""" +from collections import namedtuple +from beets import util + +Node = namedtuple('Node', ['files', 'dirs']) + + +def _insert(node, path, itemid): + """Insert an item into a virtual filesystem node.""" + if len(path) == 1: + # Last component. Insert file. + node.files[path[0]] = itemid + else: + # In a directory. + dirname = path[0] + rest = path[1:] + if dirname not in node.dirs: + node.dirs[dirname] = Node({}, {}) + _insert(node.dirs[dirname], rest, itemid) + + +def libtree(lib): + """Generates a filesystem-like directory tree for the files + contained in `lib`. Filesystem nodes are (files, dirs) named + tuples in which both components are dictionaries. The first + maps filenames to Item ids. The second maps directory names to + child node tuples. + """ + root = Node({}, {}) + for item in lib.items(): + dest = item.destination(fragment=True) + parts = util.components(dest) + _insert(root, parts, item.id) + return root diff --git a/lib/beetsplug/__init__.py b/lib/beetsplug/__init__.py new file mode 100644 index 00000000..98a7ffd5 --- /dev/null +++ b/lib/beetsplug/__init__.py @@ -0,0 +1,19 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""A namespace package for beets plugins.""" + +# Make this a namespace package. +from pkgutil import extend_path +__path__ = extend_path(__path__, __name__) diff --git a/lib/beetsplug/embedart.py b/lib/beetsplug/embedart.py new file mode 100644 index 00000000..49ed4792 --- /dev/null +++ b/lib/beetsplug/embedart.py @@ -0,0 +1,280 @@ +# This file is part of beets. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Allows beets to embed album art into file metadata.""" +import os.path +import logging +import imghdr +import subprocess +import platform +from tempfile import NamedTemporaryFile + +from beets.plugins import BeetsPlugin +from beets import mediafile +from beets import ui +from beets.ui import decargs +from beets.util import syspath, normpath, displayable_path +from beets.util.artresizer import ArtResizer +from beets import config + + +log = logging.getLogger('beets') + + +class EmbedCoverArtPlugin(BeetsPlugin): + """Allows albumart to be embedded into the actual files. + """ + def __init__(self): + super(EmbedCoverArtPlugin, self).__init__() + self.config.add({ + 'maxwidth': 0, + 'auto': True, + 'compare_threshold': 0, + 'ifempty': False, + }) + + if self.config['maxwidth'].get(int) and not ArtResizer.shared.local: + self.config['maxwidth'] = 0 + log.warn(u"embedart: ImageMagick or PIL not found; " + u"'maxwidth' option ignored") + if self.config['compare_threshold'].get(int) and not \ + ArtResizer.shared.can_compare: + self.config['compare_threshold'] = 0 + log.warn(u"embedart: ImageMagick 6.8.7 or higher not installed; " + u"'compare_threshold' option ignored") + + def commands(self): + # Embed command. + embed_cmd = ui.Subcommand( + 'embedart', help='embed image files into file metadata' + ) + embed_cmd.parser.add_option( + '-f', '--file', metavar='PATH', help='the image file to embed' + ) + maxwidth = config['embedart']['maxwidth'].get(int) + compare_threshold = config['embedart']['compare_threshold'].get(int) + ifempty = config['embedart']['ifempty'].get(bool) + + def embed_func(lib, opts, args): + if opts.file: + imagepath = normpath(opts.file) + for item in lib.items(decargs(args)): + embed_item(item, imagepath, maxwidth, None, + compare_threshold, ifempty) + else: + for album in lib.albums(decargs(args)): + embed_album(album, maxwidth) + + embed_cmd.func = embed_func + + # Extract command. + extract_cmd = ui.Subcommand('extractart', + help='extract an image from file metadata') + extract_cmd.parser.add_option('-o', dest='outpath', + help='image output file') + + def extract_func(lib, opts, args): + outpath = normpath(opts.outpath or 'cover') + item = lib.items(decargs(args)).get() + extract(outpath, item) + extract_cmd.func = extract_func + + # Clear command. + clear_cmd = ui.Subcommand('clearart', + help='remove images from file metadata') + + def clear_func(lib, opts, args): + clear(lib, decargs(args)) + clear_cmd.func = clear_func + + return [embed_cmd, extract_cmd, clear_cmd] + + +@EmbedCoverArtPlugin.listen('album_imported') +def album_imported(lib, album): + """Automatically embed art into imported albums. + """ + if album.artpath and config['embedart']['auto']: + embed_album(album, config['embedart']['maxwidth'].get(int), True) + + +def embed_item(item, imagepath, maxwidth=None, itempath=None, + compare_threshold=0, ifempty=False, as_album=False): + """Embed an image into the item's media file. + """ + if compare_threshold: + if not check_art_similarity(item, imagepath, compare_threshold): + log.warn(u'Image not similar; skipping.') + return + if ifempty: + art = get_art(item) + if not art: + pass + else: + log.debug(u'embedart: media file contained art already {0}'.format( + displayable_path(imagepath) + )) + return + if maxwidth and not as_album: + imagepath = resize_image(imagepath, maxwidth) + + try: + log.debug(u'embedart: embedding {0}'.format( + displayable_path(imagepath) + )) + item['images'] = [_mediafile_image(imagepath, maxwidth)] + except IOError as exc: + log.error(u'embedart: could not read image file: {0}'.format(exc)) + else: + # We don't want to store the image in the database. + item.try_write(itempath) + del item['images'] + + +def embed_album(album, maxwidth=None, quiet=False): + """Embed album art into all of the album's items. + """ + imagepath = album.artpath + if not imagepath: + log.info(u'No album art present: {0} - {1}'. + format(album.albumartist, album.album)) + return + if not os.path.isfile(syspath(imagepath)): + log.error(u'Album art not found at {0}' + .format(displayable_path(imagepath))) + return + if maxwidth: + imagepath = resize_image(imagepath, maxwidth) + + log.log( + logging.DEBUG if quiet else logging.INFO, + u'Embedding album art into {0.albumartist} - {0.album}.'.format(album), + ) + + for item in album.items(): + embed_item(item, imagepath, maxwidth, None, + config['embedart']['compare_threshold'].get(int), + config['embedart']['ifempty'].get(bool), as_album=True) + + +def resize_image(imagepath, maxwidth): + """Returns path to an image resized to maxwidth. + """ + log.info(u'Resizing album art to {0} pixels wide' + .format(maxwidth)) + imagepath = ArtResizer.shared.resize(maxwidth, syspath(imagepath)) + return imagepath + + +def check_art_similarity(item, imagepath, compare_threshold): + """A boolean indicating if an image is similar to embedded item art. + """ + with NamedTemporaryFile(delete=True) as f: + art = extract(f.name, item) + + if art: + # Converting images to grayscale tends to minimize the weight + # of colors in the diff score + cmd = 'convert {0} {1} -colorspace gray MIFF:- | ' \ + 'compare -metric PHASH - null:'.format(syspath(imagepath), + syspath(art)) + + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=platform.system() != 'Windows', + shell=True) + stdout, stderr = proc.communicate() + if proc.returncode: + if proc.returncode != 1: + log.warn(u'embedart: IM phashes compare failed for {0}, \ + {1}'.format(displayable_path(imagepath), + displayable_path(art))) + return + phashDiff = float(stderr) + else: + phashDiff = float(stdout) + + log.info(u'embedart: compare PHASH score is {0}'.format(phashDiff)) + if phashDiff > compare_threshold: + return False + + return True + + +def _mediafile_image(image_path, maxwidth=None): + """Return a `mediafile.Image` object for the path. + """ + + with open(syspath(image_path), 'rb') as f: + data = f.read() + return mediafile.Image(data, type=mediafile.ImageType.front) + + +def get_art(item): + # Extract the art. + try: + mf = mediafile.MediaFile(syspath(item.path)) + except mediafile.UnreadableFileError as exc: + log.error(u'Could not extract art from {0}: {1}'.format( + displayable_path(item.path), exc + )) + return + + return mf.art + +# 'extractart' command. + + +def extract(outpath, item): + if not item: + log.error(u'No item matches query.') + return + + art = get_art(item) + + if not art: + log.error(u'No album art present in {0} - {1}.' + .format(item.artist, item.title)) + return + + # Add an extension to the filename. + ext = imghdr.what(None, h=art) + if not ext: + log.error(u'Unknown image type.') + return + outpath += '.' + ext + + log.info(u'Extracting album art from: {0.artist} - {0.title} ' + u'to: {1}'.format(item, displayable_path(outpath))) + with open(syspath(outpath), 'wb') as f: + f.write(art) + return outpath + + +# 'clearart' command. + +def clear(lib, query): + log.info(u'Clearing album art from items:') + for item in lib.items(query): + log.info(u'{0} - {1}'.format(item.artist, item.title)) + try: + mf = mediafile.MediaFile(syspath(item.path), + config['id3v23'].get(bool)) + except mediafile.UnreadableFileError as exc: + log.error(u'Could not clear art from {0}: {1}'.format( + displayable_path(item.path), exc + )) + continue + del mf.art + mf.save() diff --git a/lib/beetsplug/fetchart.py b/lib/beetsplug/fetchart.py new file mode 100644 index 00000000..b2a4620b --- /dev/null +++ b/lib/beetsplug/fetchart.py @@ -0,0 +1,396 @@ +# This file is part of beets. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Fetches album art. +""" +from contextlib import closing +import logging +import os +import re +from tempfile import NamedTemporaryFile + +import requests + +from beets import plugins +from beets import importer +from beets import ui +from beets import util +from beets import config +from beets.util.artresizer import ArtResizer + +try: + import itunes + HAVE_ITUNES = True +except ImportError: + HAVE_ITUNES = False + +IMAGE_EXTENSIONS = ['png', 'jpg', 'jpeg'] +CONTENT_TYPES = ('image/jpeg',) +DOWNLOAD_EXTENSION = '.jpg' + +log = logging.getLogger('beets') + +requests_session = requests.Session() +requests_session.headers = {'User-Agent': 'beets'} + + +def _fetch_image(url): + """Downloads an image from a URL and checks whether it seems to + actually be an image. If so, returns a path to the downloaded image. + Otherwise, returns None. + """ + log.debug(u'fetchart: downloading art: {0}'.format(url)) + try: + with closing(requests_session.get(url, stream=True)) as resp: + if 'Content-Type' not in resp.headers \ + or resp.headers['Content-Type'] not in CONTENT_TYPES: + log.debug(u'fetchart: not an image') + return + + # Generate a temporary file with the correct extension. + with NamedTemporaryFile(suffix=DOWNLOAD_EXTENSION, delete=False) \ + as fh: + for chunk in resp.iter_content(): + fh.write(chunk) + log.debug(u'fetchart: downloaded art to: {0}'.format( + util.displayable_path(fh.name) + )) + return fh.name + except (IOError, requests.RequestException): + log.debug(u'fetchart: error fetching art') + + +# ART SOURCES ################################################################ + +# Cover Art Archive. + +CAA_URL = 'http://coverartarchive.org/release/{mbid}/front-500.jpg' +CAA_GROUP_URL = 'http://coverartarchive.org/release-group/{mbid}/front-500.jpg' + + +def caa_art(album): + """Return the Cover Art Archive and Cover Art Archive release group URLs + using album MusicBrainz release ID and release group ID. + """ + if album.mb_albumid: + yield CAA_URL.format(mbid=album.mb_albumid) + if album.mb_releasegroupid: + yield CAA_GROUP_URL.format(mbid=album.mb_releasegroupid) + + +# Art from Amazon. + +AMAZON_URL = 'http://images.amazon.com/images/P/%s.%02i.LZZZZZZZ.jpg' +AMAZON_INDICES = (1, 2) + + +def art_for_asin(album): + """Generate URLs using Amazon ID (ASIN) string. + """ + if album.asin: + for index in AMAZON_INDICES: + yield AMAZON_URL % (album.asin, index) + + +# AlbumArt.org scraper. + +AAO_URL = 'http://www.albumart.org/index_detail.php' +AAO_PAT = r'href\s*=\s*"([^>"]*)"[^>]*title\s*=\s*"View larger image"' + + +def aao_art(album): + """Return art URL from AlbumArt.org using album ASIN. + """ + if not album.asin: + return + # Get the page from albumart.org. + try: + resp = requests_session.get(AAO_URL, params={'asin': album.asin}) + log.debug(u'fetchart: scraped art URL: {0}'.format(resp.url)) + except requests.RequestException: + log.debug(u'fetchart: error scraping art page') + return + + # Search the page for the image URL. + m = re.search(AAO_PAT, resp.text) + if m: + image_url = m.group(1) + yield image_url + else: + log.debug(u'fetchart: no image found on page') + + +# Google Images scraper. + +GOOGLE_URL = 'https://ajax.googleapis.com/ajax/services/search/images' + + +def google_art(album): + """Return art URL from google.org given an album title and + interpreter. + """ + if not (album.albumartist and album.album): + return + search_string = (album.albumartist + ',' + album.album).encode('utf-8') + response = requests_session.get(GOOGLE_URL, params={ + 'v': '1.0', + 'q': search_string, + 'start': '0', + }) + + # Get results using JSON. + try: + results = response.json() + data = results['responseData'] + dataInfo = data['results'] + for myUrl in dataInfo: + yield myUrl['unescapedUrl'] + except: + log.debug(u'fetchart: error scraping art page') + return + + +# Art from the iTunes Store. + +def itunes_art(album): + """Return art URL from iTunes Store given an album title. + """ + search_string = (album.albumartist + ' ' + album.album).encode('utf-8') + try: + # Isolate bugs in the iTunes library while searching. + try: + itunes_album = itunes.search_album(search_string)[0] + except Exception as exc: + log.debug('fetchart: iTunes search failed: {0}'.format(exc)) + return + + if itunes_album.get_artwork()['100']: + small_url = itunes_album.get_artwork()['100'] + big_url = small_url.replace('100x100', '1200x1200') + yield big_url + else: + log.debug(u'fetchart: album has no artwork in iTunes Store') + except IndexError: + log.debug(u'fetchart: album not found in iTunes Store') + + +# Art from the filesystem. + + +def filename_priority(filename, cover_names): + """Sort order for image names. + + Return indexes of cover names found in the image filename. This + means that images with lower-numbered and more keywords will have higher + priority. + """ + return [idx for (idx, x) in enumerate(cover_names) if x in filename] + + +def art_in_path(path, cover_names, cautious): + """Look for album art files in a specified directory. + """ + if not os.path.isdir(path): + return + + # Find all files that look like images in the directory. + images = [] + for fn in os.listdir(path): + for ext in IMAGE_EXTENSIONS: + if fn.lower().endswith('.' + ext): + images.append(fn) + + # Look for "preferred" filenames. + images = sorted(images, key=lambda x: filename_priority(x, cover_names)) + cover_pat = r"(\b|_)({0})(\b|_)".format('|'.join(cover_names)) + for fn in images: + if re.search(cover_pat, os.path.splitext(fn)[0], re.I): + log.debug(u'fetchart: using well-named art file {0}'.format( + util.displayable_path(fn) + )) + return os.path.join(path, fn) + + # Fall back to any image in the folder. + if images and not cautious: + log.debug(u'fetchart: using fallback art file {0}'.format( + util.displayable_path(images[0]) + )) + return os.path.join(path, images[0]) + + +# Try each source in turn. + +SOURCES_ALL = [u'coverart', u'itunes', u'amazon', u'albumart', u'google'] + +ART_FUNCS = { + u'coverart': caa_art, + u'itunes': itunes_art, + u'albumart': aao_art, + u'amazon': art_for_asin, + u'google': google_art, +} + + +def _source_urls(album, sources=SOURCES_ALL): + """Generate possible source URLs for an album's art. The URLs are + not guaranteed to work so they each need to be attempted in turn. + This allows the main `art_for_album` function to abort iteration + through this sequence early to avoid the cost of scraping when not + necessary. + """ + for s in sources: + urls = ART_FUNCS[s](album) + for url in urls: + yield url + + +def art_for_album(album, paths, maxwidth=None, local_only=False): + """Given an Album object, returns a path to downloaded art for the + album (or None if no art is found). If `maxwidth`, then images are + resized to this maximum pixel size. If `local_only`, then only local + image files from the filesystem are returned; no network requests + are made. + """ + out = None + + # Local art. + cover_names = config['fetchart']['cover_names'].as_str_seq() + cover_names = map(util.bytestring_path, cover_names) + cautious = config['fetchart']['cautious'].get(bool) + if paths: + for path in paths: + out = art_in_path(path, cover_names, cautious) + if out: + break + + # Web art sources. + remote_priority = config['fetchart']['remote_priority'].get(bool) + if not local_only and (remote_priority or not out): + for url in _source_urls(album, + config['fetchart']['sources'].as_str_seq()): + if maxwidth: + url = ArtResizer.shared.proxy_url(maxwidth, url) + candidate = _fetch_image(url) + if candidate: + out = candidate + break + + if maxwidth and out: + out = ArtResizer.shared.resize(maxwidth, out) + return out + + +# PLUGIN LOGIC ############################################################### + + +def batch_fetch_art(lib, albums, force, maxwidth=None): + """Fetch album art for each of the albums. This implements the manual + fetchart CLI command. + """ + for album in albums: + if album.artpath and not force: + message = 'has album art' + else: + # In ordinary invocations, look for images on the + # filesystem. When forcing, however, always go to the Web + # sources. + local_paths = None if force else [album.path] + + path = art_for_album(album, local_paths, maxwidth) + if path: + album.set_art(path, False) + album.store() + message = ui.colorize('green', 'found album art') + else: + message = ui.colorize('red', 'no art found') + + log.info(u'{0} - {1}: {2}'.format(album.albumartist, album.album, + message)) + + +class FetchArtPlugin(plugins.BeetsPlugin): + def __init__(self): + super(FetchArtPlugin, self).__init__() + + self.config.add({ + 'auto': True, + 'maxwidth': 0, + 'remote_priority': False, + 'cautious': False, + 'google_search': False, + 'cover_names': ['cover', 'front', 'art', 'album', 'folder'], + 'sources': SOURCES_ALL, + }) + + # Holds paths to downloaded images between fetching them and + # placing them in the filesystem. + self.art_paths = {} + + self.maxwidth = self.config['maxwidth'].get(int) + if self.config['auto']: + # Enable two import hooks when fetching is enabled. + self.import_stages = [self.fetch_art] + self.register_listener('import_task_files', self.assign_art) + + available_sources = list(SOURCES_ALL) + if not HAVE_ITUNES and u'itunes' in available_sources: + available_sources.remove(u'itunes') + self.config['sources'] = plugins.sanitize_choices( + self.config['sources'].as_str_seq(), available_sources) + + # Asynchronous; after music is added to the library. + def fetch_art(self, session, task): + """Find art for the album being imported.""" + if task.is_album: # Only fetch art for full albums. + if task.choice_flag == importer.action.ASIS: + # For as-is imports, don't search Web sources for art. + local = True + elif task.choice_flag == importer.action.APPLY: + # Search everywhere for art. + local = False + else: + # For any other choices (e.g., TRACKS), do nothing. + return + + path = art_for_album(task.album, task.paths, self.maxwidth, local) + + if path: + self.art_paths[task] = path + + # Synchronous; after music files are put in place. + def assign_art(self, session, task): + """Place the discovered art in the filesystem.""" + if task in self.art_paths: + path = self.art_paths.pop(task) + + album = task.album + src_removed = (config['import']['delete'].get(bool) or + config['import']['move'].get(bool)) + album.set_art(path, not src_removed) + album.store() + if src_removed: + task.prune(path) + + # Manual album art fetching. + def commands(self): + cmd = ui.Subcommand('fetchart', help='download album art') + cmd.parser.add_option('-f', '--force', dest='force', + action='store_true', default=False, + help='re-download art when already present') + + def func(lib, opts, args): + batch_fetch_art(lib, lib.albums(ui.decargs(args)), opts.force, + self.maxwidth) + cmd.func = func + return [cmd] diff --git a/lib/beetsplug/lyrics.py b/lib/beetsplug/lyrics.py new file mode 100644 index 00000000..a2ebe7c3 --- /dev/null +++ b/lib/beetsplug/lyrics.py @@ -0,0 +1,544 @@ +# This file is part of beets. +# Copyright 2014, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Fetches, embeds, and displays lyrics. +""" +from __future__ import print_function + +import re +import logging +import requests +import json +import unicodedata +import urllib +import difflib +import itertools +from HTMLParser import HTMLParseError + +from beets import plugins +from beets import config, ui + + +# Global logger. + +log = logging.getLogger('beets') + +DIV_RE = re.compile(r'<(/?)div>?', re.I) +COMMENT_RE = re.compile(r'', re.S) +TAG_RE = re.compile(r'<[^>]*>') +BREAK_RE = re.compile(r'\n?\s*]*)*>\s*\n?', re.I) +URL_CHARACTERS = { + u'\u2018': u"'", + u'\u2019': u"'", + u'\u201c': u'"', + u'\u201d': u'"', + u'\u2010': u'-', + u'\u2011': u'-', + u'\u2012': u'-', + u'\u2013': u'-', + u'\u2014': u'-', + u'\u2015': u'-', + u'\u2016': u'-', + u'\u2026': u'...', +} + + +# Utilities. + +def fetch_url(url): + """Retrieve the content at a given URL, or return None if the source + is unreachable. + """ + try: + r = requests.get(url, verify=False) + except requests.RequestException as exc: + log.debug(u'lyrics request failed: {0}'.format(exc)) + return + if r.status_code == requests.codes.ok: + return r.text + else: + log.debug(u'failed to fetch: {0} ({1})'.format(url, r.status_code)) + + +def unescape(text): + """Resolves &#xxx; HTML entities (and some others).""" + if isinstance(text, str): + text = text.decode('utf8', 'ignore') + out = text.replace(u' ', u' ') + + def replchar(m): + num = m.group(1) + return unichr(int(num)) + out = re.sub(u"&#(\d+);", replchar, out) + return out + + +def extract_text_between(html, start_marker, end_marker): + try: + _, html = html.split(start_marker, 1) + html, _ = html.split(end_marker, 1) + except ValueError: + return u'' + return html + + +def extract_text_in(html, starttag): + """Extract the text from a
tag in the HTML starting with + ``starttag``. Returns None if parsing fails. + """ + + # Strip off the leading text before opening tag. + try: + _, html = html.split(starttag, 1) + except ValueError: + return + + # Walk through balanced DIV tags. + level = 0 + parts = [] + pos = 0 + for match in DIV_RE.finditer(html): + if match.group(1): # Closing tag. + level -= 1 + if level == 0: + pos = match.end() + else: # Opening tag. + if level == 0: + parts.append(html[pos:match.start()]) + level += 1 + + if level == -1: + parts.append(html[pos:match.start()]) + break + else: + print('no closing tag found!') + return + return u''.join(parts) + + +def search_pairs(item): + """Yield a pairs of artists and titles to search for. + + The first item in the pair is the name of the artist, the second + item is a list of song names. + + In addition to the artist and title obtained from the `item` the + method tries to strip extra information like paranthesized suffixes + and featured artists from the strings and add them as candidates. + The method also tries to split multiple titles separated with `/`. + """ + + title, artist = item.title, item.artist + titles = [title] + artists = [artist] + + # Remove any featuring artists from the artists name + pattern = r"(.*?) {0}".format(plugins.feat_tokens()) + match = re.search(pattern, artist, re.IGNORECASE) + if match: + artists.append(match.group(1)) + + # Remove a parenthesized suffix from a title string. Common + # examples include (live), (remix), and (acoustic). + pattern = r"(.+?)\s+[(].*[)]$" + match = re.search(pattern, title, re.IGNORECASE) + if match: + titles.append(match.group(1)) + + # Remove any featuring artists from the title + pattern = r"(.*?) {0}".format(plugins.feat_tokens(for_artist=False)) + for title in titles[:]: + match = re.search(pattern, title, re.IGNORECASE) + if match: + titles.append(match.group(1)) + + # Check for a dual song (e.g. Pink Floyd - Speak to Me / Breathe) + # and each of them. + multi_titles = [] + for title in titles: + multi_titles.append([title]) + if '/' in title: + multi_titles.append([x.strip() for x in title.split('/')]) + + return itertools.product(artists, multi_titles) + + +def _encode(s): + """Encode the string for inclusion in a URL (common to both + LyricsWiki and Lyrics.com). + """ + if isinstance(s, unicode): + for char, repl in URL_CHARACTERS.items(): + s = s.replace(char, repl) + s = s.encode('utf8', 'ignore') + return urllib.quote(s) + +# Musixmatch + +MUSIXMATCH_URL_PATTERN = 'https://www.musixmatch.com/lyrics/%s/%s' + + +def fetch_musixmatch(artist, title): + url = MUSIXMATCH_URL_PATTERN % (_lw_encode(artist.title()), + _lw_encode(title.title())) + html = fetch_url(url) + if not html: + return + lyrics = extract_text_between(html, '"lyrics_body":', '"lyrics_language":') + return lyrics.strip(',"').replace('\\n', '\n') + +# LyricsWiki. + +LYRICSWIKI_URL_PATTERN = 'http://lyrics.wikia.com/%s:%s' + + +def _lw_encode(s): + s = re.sub(r'\s+', '_', s) + s = s.replace("<", "Less_Than") + s = s.replace(">", "Greater_Than") + s = s.replace("#", "Number_") + s = re.sub(r'[\[\{]', '(', s) + s = re.sub(r'[\]\}]', ')', s) + return _encode(s) + + +def fetch_lyricswiki(artist, title): + """Fetch lyrics from LyricsWiki.""" + url = LYRICSWIKI_URL_PATTERN % (_lw_encode(artist), _lw_encode(title)) + html = fetch_url(url) + if not html: + return + + lyrics = extract_text_in(html, u"
") + if lyrics and 'Unfortunately, we are not licensed' not in lyrics: + return lyrics + + +# Lyrics.com. + +LYRICSCOM_URL_PATTERN = 'http://www.lyrics.com/%s-lyrics-%s.html' +LYRICSCOM_NOT_FOUND = ( + 'Sorry, we do not have the lyric', + 'Submit Lyrics', +) + + +def _lc_encode(s): + s = re.sub(r'[^\w\s-]', '', s) + s = re.sub(r'\s+', '-', s) + return _encode(s).lower() + + +def fetch_lyricscom(artist, title): + """Fetch lyrics from Lyrics.com.""" + url = LYRICSCOM_URL_PATTERN % (_lc_encode(title), _lc_encode(artist)) + html = fetch_url(url) + if not html: + return + lyrics = extract_text_between(html, '
', '
') + if not lyrics: + return + for not_found_str in LYRICSCOM_NOT_FOUND: + if not_found_str in lyrics: + return + + parts = lyrics.split('\n---\nLyrics powered by', 1) + if parts: + return parts[0] + + +# Optional Google custom search API backend. + +def slugify(text): + """Normalize a string and remove non-alphanumeric characters. + """ + text = re.sub(r"[-'_\s]", '_', text) + text = re.sub(r"_+", '_', text).strip('_') + pat = "([^,\(]*)\((.*?)\)" # Remove content within parentheses + text = re.sub(pat, '\g<1>', text).strip() + try: + text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore') + text = unicode(re.sub('[-\s]+', ' ', text)) + except UnicodeDecodeError: + log.exception(u"Failing to normalize '{0}'".format(text)) + return text + + +BY_TRANS = ['by', 'par', 'de', 'von'] +LYRICS_TRANS = ['lyrics', 'paroles', 'letras', 'liedtexte'] + + +def is_page_candidate(urlLink, urlTitle, title, artist): + """Return True if the URL title makes it a good candidate to be a + page that contains lyrics of title by artist. + """ + title = slugify(title.lower()) + artist = slugify(artist.lower()) + sitename = re.search(u"//([^/]+)/.*", slugify(urlLink.lower())).group(1) + urlTitle = slugify(urlTitle.lower()) + # Check if URL title contains song title (exact match) + if urlTitle.find(title) != -1: + return True + # or try extracting song title from URL title and check if + # they are close enough + tokens = [by + '_' + artist for by in BY_TRANS] + \ + [artist, sitename, sitename.replace('www.', '')] + LYRICS_TRANS + songTitle = re.sub(u'(%s)' % u'|'.join(tokens), u'', urlTitle) + songTitle = songTitle.strip('_|') + typoRatio = .9 + return difflib.SequenceMatcher(None, songTitle, title).ratio() >= typoRatio + + +def remove_credits(text): + """Remove first/last line of text if it contains the word 'lyrics' + eg 'Lyrics by songsdatabase.com' + """ + textlines = text.split('\n') + credits = None + for i in (0, -1): + if textlines and 'lyrics' in textlines[i].lower(): + credits = textlines.pop(i) + if credits: + text = '\n'.join(textlines) + return text + + +def is_lyrics(text, artist=None): + """Determine whether the text seems to be valid lyrics. + """ + if not text: + return False + badTriggersOcc = [] + nbLines = text.count('\n') + if nbLines <= 1: + log.debug(u"Ignoring too short lyrics '{0}'".format(text)) + return False + elif nbLines < 5: + badTriggersOcc.append('too_short') + else: + # Lyrics look legit, remove credits to avoid being penalized further + # down + text = remove_credits(text) + + badTriggers = ['lyrics', 'copyright', 'property', 'links'] + if artist: + badTriggersOcc += [artist] + + for item in badTriggers: + badTriggersOcc += [item] * len(re.findall(r'\W%s\W' % item, + text, re.I)) + + if badTriggersOcc: + log.debug(u'Bad triggers detected: {0}'.format(badTriggersOcc)) + return len(badTriggersOcc) < 2 + + +def _scrape_strip_cruft(html, plain_text_out=False): + """Clean up HTML + """ + html = unescape(html) + + html = html.replace('\r', '\n') # Normalize EOL. + html = re.sub(r' +', ' ', html) # Whitespaces collapse. + html = BREAK_RE.sub('\n', html) #
eats up surrounding '\n'. + html = re.sub(r'<(script).*?(?s)', '', html) # Strip script tags. + + if plain_text_out: # Strip remaining HTML tags + html = COMMENT_RE.sub('', html) + html = TAG_RE.sub('', html) + + html = '\n'.join([x.strip() for x in html.strip().split('\n')]) + html = re.sub(r'\n{3,}', r'\n\n', html) + return html + + +def _scrape_merge_paragraphs(html): + html = re.sub(r'

\s*]*)>', '\n', html) + return re.sub(r'
\s*
', '\n', html) + + +def scrape_lyrics_from_html(html): + """Scrape lyrics from a URL. If no lyrics can be found, return None + instead. + """ + from bs4 import SoupStrainer, BeautifulSoup + + if not html: + return None + + def is_text_notcode(text): + length = len(text) + return (length > 20 and + text.count(' ') > length / 25 and + (text.find('{') == -1 or text.find(';') == -1)) + html = _scrape_strip_cruft(html) + html = _scrape_merge_paragraphs(html) + + # extract all long text blocks that are not code + try: + soup = BeautifulSoup(html, "html.parser", + parse_only=SoupStrainer(text=is_text_notcode)) + except HTMLParseError: + return None + soup = sorted(soup.stripped_strings, key=len)[-1] + return soup + + +def fetch_google(artist, title): + """Fetch lyrics from Google search results. + """ + query = u"%s %s" % (artist, title) + api_key = config['lyrics']['google_API_key'].get(unicode) + engine_id = config['lyrics']['google_engine_ID'].get(unicode) + url = u'https://www.googleapis.com/customsearch/v1?key=%s&cx=%s&q=%s' % \ + (api_key, engine_id, urllib.quote(query.encode('utf8'))) + + data = urllib.urlopen(url) + data = json.load(data) + if 'error' in data: + reason = data['error']['errors'][0]['reason'] + log.debug(u'google lyrics backend error: {0}'.format(reason)) + return + + if 'items' in data.keys(): + for item in data['items']: + urlLink = item['link'] + urlTitle = item.get('title', u'') + if not is_page_candidate(urlLink, urlTitle, title, artist): + continue + html = fetch_url(urlLink) + lyrics = scrape_lyrics_from_html(html) + if not lyrics: + continue + + if is_lyrics(lyrics, artist): + log.debug(u'got lyrics from {0}'.format(item['displayLink'])) + return lyrics + + +# Plugin logic. + +SOURCES = ['google', 'lyricwiki', 'lyrics.com', 'musixmatch'] +SOURCE_BACKENDS = { + 'google': fetch_google, + 'lyricwiki': fetch_lyricswiki, + 'lyrics.com': fetch_lyricscom, + 'musixmatch': fetch_musixmatch, +} + + +class LyricsPlugin(plugins.BeetsPlugin): + def __init__(self): + super(LyricsPlugin, self).__init__() + self.import_stages = [self.imported] + self.config.add({ + 'auto': True, + 'google_API_key': None, + 'google_engine_ID': u'009217259823014548361:lndtuqkycfu', + 'fallback': None, + 'force': False, + 'sources': SOURCES, + }) + + available_sources = list(SOURCES) + if not self.config['google_API_key'].get() and \ + 'google' in SOURCES: + available_sources.remove('google') + self.config['sources'] = plugins.sanitize_choices( + self.config['sources'].as_str_seq(), available_sources) + self.backends = [] + for key in self.config['sources'].as_str_seq(): + self.backends.append(SOURCE_BACKENDS[key]) + + def commands(self): + cmd = ui.Subcommand('lyrics', help='fetch song lyrics') + cmd.parser.add_option('-p', '--print', dest='printlyr', + action='store_true', default=False, + help='print lyrics to console') + cmd.parser.add_option('-f', '--force', dest='force_refetch', + action='store_true', default=False, + help='always re-download lyrics') + + def func(lib, opts, args): + # The "write to files" option corresponds to the + # import_write config value. + write = config['import']['write'].get(bool) + for item in lib.items(ui.decargs(args)): + self.fetch_item_lyrics( + lib, logging.INFO, item, write, + opts.force_refetch or self.config['force'], + ) + if opts.printlyr and item.lyrics: + ui.print_(item.lyrics) + + cmd.func = func + return [cmd] + + def imported(self, session, task): + """Import hook for fetching lyrics automatically. + """ + if self.config['auto']: + for item in task.imported_items(): + self.fetch_item_lyrics(session.lib, logging.DEBUG, item, + False, self.config['force']) + + def fetch_item_lyrics(self, lib, loglevel, item, write, force): + """Fetch and store lyrics for a single item. If ``write``, then the + lyrics will also be written to the file itself. The ``loglevel`` + parameter controls the visibility of the function's status log + messages. + """ + # Skip if the item already has lyrics. + if not force and item.lyrics: + log.log(loglevel, u'lyrics already present: {0} - {1}' + .format(item.artist, item.title)) + return + + lyrics = None + for artist, titles in search_pairs(item): + lyrics = [self.get_lyrics(artist, title) for title in titles] + if any(lyrics): + break + + lyrics = u"\n\n---\n\n".join([l for l in lyrics if l]) + + if lyrics: + log.log(loglevel, u'fetched lyrics: {0} - {1}' + .format(item.artist, item.title)) + else: + log.log(loglevel, u'lyrics not found: {0} - {1}' + .format(item.artist, item.title)) + fallback = self.config['fallback'].get() + if fallback: + lyrics = fallback + else: + return + + item.lyrics = lyrics + + if write: + item.try_write() + item.store() + + def get_lyrics(self, artist, title): + """Fetch lyrics, trying each source in turn. Return a string or + None if no lyrics were found. + """ + for backend in self.backends: + lyrics = backend(artist, title) + if lyrics: + log.debug(u'got lyrics from backend: {0}' + .format(backend.__name__)) + return _scrape_strip_cruft(lyrics, True) diff --git a/headphones/bencode.py b/lib/bencode.py similarity index 92% rename from headphones/bencode.py rename to lib/bencode.py index 4fa1db71..3181ee1c 100644 --- a/headphones/bencode.py +++ b/lib/bencode.py @@ -14,9 +14,6 @@ # Minor modifications made by Andrew Resch to replace the BTFailure errors with Exceptions -from types import StringType, IntType, LongType, DictType, ListType, TupleType - - def decode_int(x, f): f += 1 newf = x.index('e', f) @@ -24,36 +21,32 @@ def decode_int(x, f): if x[f] == '-': if x[f + 1] == '0': raise ValueError - elif x[f] == '0' and newf != f + 1: + elif x[f] == '0' and newf != f+1: raise ValueError - return (n, newf + 1) - + return (n, newf+1) def decode_string(x, f): colon = x.index(':', f) n = int(x[f:colon]) - if x[f] == '0' and colon != f + 1: + if x[f] == '0' and colon != f+1: raise ValueError colon += 1 - return (x[colon:colon + n], colon + n) - + return (x[colon:colon+n], colon+n) def decode_list(x, f): - r, f = [], f + 1 + r, f = [], f+1 while x[f] != 'e': v, f = decode_func[x[f]](x, f) r.append(v) return (r, f + 1) - def decode_dict(x, f): - r, f = {}, f + 1 + r, f = {}, f+1 while x[f] != 'e': k, f = decode_string(x, f) r[k], f = decode_func[x[f]](x, f) return (r, f + 1) - decode_func = {} decode_func['l'] = decode_list decode_func['d'] = decode_dict @@ -69,7 +62,6 @@ decode_func['7'] = decode_string decode_func['8'] = decode_string decode_func['9'] = decode_string - def bdecode(x): try: r, l = decode_func[x[0]](x, 0) @@ -78,41 +70,38 @@ def bdecode(x): return r +from types import StringType, IntType, LongType, DictType, ListType, TupleType + class Bencached(object): + __slots__ = ['bencoded'] def __init__(self, s): self.bencoded = s - -def encode_bencached(x, r): +def encode_bencached(x,r): r.append(x.bencoded) - def encode_int(x, r): r.extend(('i', str(x), 'e')) - def encode_bool(x, r): if x: encode_int(1, r) else: encode_int(0, r) - def encode_string(x, r): r.extend((str(len(x)), ':', x)) - def encode_list(x, r): r.append('l') for i in x: encode_func[type(i)](i, r) r.append('e') - -def encode_dict(x, r): +def encode_dict(x,r): r.append('d') ilist = x.items() ilist.sort() @@ -121,7 +110,6 @@ def encode_dict(x, r): encode_func[type(v)](v, r) r.append('e') - encode_func = {} encode_func[Bencached] = encode_bencached encode_func[IntType] = encode_int @@ -133,12 +121,10 @@ encode_func[DictType] = encode_dict try: from types import BooleanType - encode_func[BooleanType] = encode_bool except ImportError: pass - def bencode(x): r = [] encode_func[type(x)](x, r) diff --git a/lib/biplist/__init__.py b/lib/biplist/__init__.py new file mode 100755 index 00000000..fe5b711b --- /dev/null +++ b/lib/biplist/__init__.py @@ -0,0 +1,803 @@ +"""biplist -- a library for reading and writing binary property list files. + +Binary Property List (plist) files provide a faster and smaller serialization +format for property lists on OS X. This is a library for generating binary +plists which can be read by OS X, iOS, or other clients. + +The API models the plistlib API, and will call through to plistlib when +XML serialization or deserialization is required. + +To generate plists with UID values, wrap the values with the Uid object. The +value must be an int. + +To generate plists with NSData/CFData values, wrap the values with the +Data object. The value must be a string. + +Date values can only be datetime.datetime objects. + +The exceptions InvalidPlistException and NotBinaryPlistException may be +thrown to indicate that the data cannot be serialized or deserialized as +a binary plist. + +Plist generation example: + + from biplist import * + from datetime import datetime + plist = {'aKey':'aValue', + '0':1.322, + 'now':datetime.now(), + 'list':[1,2,3], + 'tuple':('a','b','c') + } + try: + writePlist(plist, "example.plist") + except (InvalidPlistException, NotBinaryPlistException), e: + print "Something bad happened:", e + +Plist parsing example: + + from biplist import * + try: + plist = readPlist("example.plist") + print plist + except (InvalidPlistException, NotBinaryPlistException), e: + print "Not a plist:", e +""" + +import sys +from collections import namedtuple +import datetime +import io +import math +import plistlib +from struct import pack, unpack +from struct import error as struct_error +import sys +import time + +try: + unicode + unicodeEmpty = r'' +except NameError: + unicode = str + unicodeEmpty = '' +try: + long +except NameError: + long = int +try: + {}.iteritems + iteritems = lambda x: x.iteritems() +except AttributeError: + iteritems = lambda x: x.items() + +__all__ = [ + 'Uid', 'Data', 'readPlist', 'writePlist', 'readPlistFromString', + 'writePlistToString', 'InvalidPlistException', 'NotBinaryPlistException' +] + +# Apple uses Jan 1, 2001 as a base for all plist date/times. +apple_reference_date = datetime.datetime.utcfromtimestamp(978307200) + +class Uid(int): + """Wrapper around integers for representing UID values. This + is used in keyed archiving.""" + def __repr__(self): + return "Uid(%d)" % self + +class Data(bytes): + """Wrapper around str types for representing Data values.""" + pass + +class InvalidPlistException(Exception): + """Raised when the plist is incorrectly formatted.""" + pass + +class NotBinaryPlistException(Exception): + """Raised when a binary plist was expected but not encountered.""" + pass + +def readPlist(pathOrFile): + """Raises NotBinaryPlistException, InvalidPlistException""" + didOpen = False + result = None + if isinstance(pathOrFile, (bytes, unicode)): + pathOrFile = open(pathOrFile, 'rb') + didOpen = True + try: + reader = PlistReader(pathOrFile) + result = reader.parse() + except NotBinaryPlistException as e: + try: + pathOrFile.seek(0) + result = None + if hasattr(plistlib, 'loads'): + contents = None + if isinstance(pathOrFile, (bytes, unicode)): + with open(pathOrFile, 'rb') as f: + contents = f.read() + else: + contents = pathOrFile.read() + result = plistlib.loads(contents) + else: + result = plistlib.readPlist(pathOrFile) + result = wrapDataObject(result, for_binary=True) + except Exception as e: + raise InvalidPlistException(e) + finally: + if didOpen: + pathOrFile.close() + return result + +def wrapDataObject(o, for_binary=False): + if isinstance(o, Data) and not for_binary: + v = sys.version_info + if not (v[0] >= 3 and v[1] >= 4): + o = plistlib.Data(o) + elif isinstance(o, (bytes, plistlib.Data)) and for_binary: + if hasattr(o, 'data'): + o = Data(o.data) + elif isinstance(o, tuple): + o = wrapDataObject(list(o), for_binary) + o = tuple(o) + elif isinstance(o, list): + for i in range(len(o)): + o[i] = wrapDataObject(o[i], for_binary) + elif isinstance(o, dict): + for k in o: + o[k] = wrapDataObject(o[k], for_binary) + return o + +def writePlist(rootObject, pathOrFile, binary=True): + if not binary: + rootObject = wrapDataObject(rootObject, binary) + if hasattr(plistlib, "dump"): + if isinstance(pathOrFile, (bytes, unicode)): + with open(pathOrFile, 'wb') as f: + return plistlib.dump(rootObject, f) + else: + return plistlib.dump(rootObject, pathOrFile) + else: + return plistlib.writePlist(rootObject, pathOrFile) + else: + didOpen = False + if isinstance(pathOrFile, (bytes, unicode)): + pathOrFile = open(pathOrFile, 'wb') + didOpen = True + writer = PlistWriter(pathOrFile) + result = writer.writeRoot(rootObject) + if didOpen: + pathOrFile.close() + return result + +def readPlistFromString(data): + return readPlist(io.BytesIO(data)) + +def writePlistToString(rootObject, binary=True): + if not binary: + rootObject = wrapDataObject(rootObject, binary) + if hasattr(plistlib, "dumps"): + return plistlib.dumps(rootObject) + elif hasattr(plistlib, "writePlistToBytes"): + return plistlib.writePlistToBytes(rootObject) + else: + return plistlib.writePlistToString(rootObject) + else: + ioObject = io.BytesIO() + writer = PlistWriter(ioObject) + writer.writeRoot(rootObject) + return ioObject.getvalue() + +def is_stream_binary_plist(stream): + stream.seek(0) + header = stream.read(7) + if header == b'bplist0': + return True + else: + return False + +PlistTrailer = namedtuple('PlistTrailer', 'offsetSize, objectRefSize, offsetCount, topLevelObjectNumber, offsetTableOffset') +PlistByteCounts = namedtuple('PlistByteCounts', 'nullBytes, boolBytes, intBytes, realBytes, dateBytes, dataBytes, stringBytes, uidBytes, arrayBytes, setBytes, dictBytes') + +class PlistReader(object): + file = None + contents = '' + offsets = None + trailer = None + currentOffset = 0 + + def __init__(self, fileOrStream): + """Raises NotBinaryPlistException.""" + self.reset() + self.file = fileOrStream + + def parse(self): + return self.readRoot() + + def reset(self): + self.trailer = None + self.contents = '' + self.offsets = [] + self.currentOffset = 0 + + def readRoot(self): + result = None + self.reset() + # Get the header, make sure it's a valid file. + if not is_stream_binary_plist(self.file): + raise NotBinaryPlistException() + self.file.seek(0) + self.contents = self.file.read() + if len(self.contents) < 32: + raise InvalidPlistException("File is too short.") + trailerContents = self.contents[-32:] + try: + self.trailer = PlistTrailer._make(unpack("!xxxxxxBBQQQ", trailerContents)) + offset_size = self.trailer.offsetSize * self.trailer.offsetCount + offset = self.trailer.offsetTableOffset + offset_contents = self.contents[offset:offset+offset_size] + offset_i = 0 + while offset_i < self.trailer.offsetCount: + begin = self.trailer.offsetSize*offset_i + tmp_contents = offset_contents[begin:begin+self.trailer.offsetSize] + tmp_sized = self.getSizedInteger(tmp_contents, self.trailer.offsetSize) + self.offsets.append(tmp_sized) + offset_i += 1 + self.setCurrentOffsetToObjectNumber(self.trailer.topLevelObjectNumber) + result = self.readObject() + except TypeError as e: + raise InvalidPlistException(e) + return result + + def setCurrentOffsetToObjectNumber(self, objectNumber): + self.currentOffset = self.offsets[objectNumber] + + def readObject(self): + result = None + tmp_byte = self.contents[self.currentOffset:self.currentOffset+1] + marker_byte = unpack("!B", tmp_byte)[0] + format = (marker_byte >> 4) & 0x0f + extra = marker_byte & 0x0f + self.currentOffset += 1 + + def proc_extra(extra): + if extra == 0b1111: + #self.currentOffset += 1 + extra = self.readObject() + return extra + + # bool, null, or fill byte + if format == 0b0000: + if extra == 0b0000: + result = None + elif extra == 0b1000: + result = False + elif extra == 0b1001: + result = True + elif extra == 0b1111: + pass # fill byte + else: + raise InvalidPlistException("Invalid object found at offset: %d" % (self.currentOffset - 1)) + # int + elif format == 0b0001: + extra = proc_extra(extra) + result = self.readInteger(pow(2, extra)) + # real + elif format == 0b0010: + extra = proc_extra(extra) + result = self.readReal(extra) + # date + elif format == 0b0011 and extra == 0b0011: + result = self.readDate() + # data + elif format == 0b0100: + extra = proc_extra(extra) + result = self.readData(extra) + # ascii string + elif format == 0b0101: + extra = proc_extra(extra) + result = self.readAsciiString(extra) + # Unicode string + elif format == 0b0110: + extra = proc_extra(extra) + result = self.readUnicode(extra) + # uid + elif format == 0b1000: + result = self.readUid(extra) + # array + elif format == 0b1010: + extra = proc_extra(extra) + result = self.readArray(extra) + # set + elif format == 0b1100: + extra = proc_extra(extra) + result = set(self.readArray(extra)) + # dict + elif format == 0b1101: + extra = proc_extra(extra) + result = self.readDict(extra) + else: + raise InvalidPlistException("Invalid object found: {format: %s, extra: %s}" % (bin(format), bin(extra))) + return result + + def readInteger(self, byteSize): + result = 0 + original_offset = self.currentOffset + data = self.contents[self.currentOffset:self.currentOffset + byteSize] + result = self.getSizedInteger(data, byteSize, as_number=True) + self.currentOffset = original_offset + byteSize + return result + + def readReal(self, length): + result = 0.0 + to_read = pow(2, length) + data = self.contents[self.currentOffset:self.currentOffset+to_read] + if length == 2: # 4 bytes + result = unpack('>f', data)[0] + elif length == 3: # 8 bytes + result = unpack('>d', data)[0] + else: + raise InvalidPlistException("Unknown real of length %d bytes" % to_read) + return result + + def readRefs(self, count): + refs = [] + i = 0 + while i < count: + fragment = self.contents[self.currentOffset:self.currentOffset+self.trailer.objectRefSize] + ref = self.getSizedInteger(fragment, len(fragment)) + refs.append(ref) + self.currentOffset += self.trailer.objectRefSize + i += 1 + return refs + + def readArray(self, count): + result = [] + values = self.readRefs(count) + i = 0 + while i < len(values): + self.setCurrentOffsetToObjectNumber(values[i]) + value = self.readObject() + result.append(value) + i += 1 + return result + + def readDict(self, count): + result = {} + keys = self.readRefs(count) + values = self.readRefs(count) + i = 0 + while i < len(keys): + self.setCurrentOffsetToObjectNumber(keys[i]) + key = self.readObject() + self.setCurrentOffsetToObjectNumber(values[i]) + value = self.readObject() + result[key] = value + i += 1 + return result + + def readAsciiString(self, length): + result = unpack("!%ds" % length, self.contents[self.currentOffset:self.currentOffset+length])[0] + self.currentOffset += length + return result + + def readUnicode(self, length): + actual_length = length*2 + data = self.contents[self.currentOffset:self.currentOffset+actual_length] + # unpack not needed?!! data = unpack(">%ds" % (actual_length), data)[0] + self.currentOffset += actual_length + return data.decode('utf_16_be') + + def readDate(self): + result = unpack(">d", self.contents[self.currentOffset:self.currentOffset+8])[0] + # Use timedelta to workaround time_t size limitation on 32-bit python. + result = datetime.timedelta(seconds=result) + apple_reference_date + self.currentOffset += 8 + return result + + def readData(self, length): + result = self.contents[self.currentOffset:self.currentOffset+length] + self.currentOffset += length + return Data(result) + + def readUid(self, length): + return Uid(self.readInteger(length+1)) + + def getSizedInteger(self, data, byteSize, as_number=False): + """Numbers of 8 bytes are signed integers when they refer to numbers, but unsigned otherwise.""" + result = 0 + # 1, 2, and 4 byte integers are unsigned + if byteSize == 1: + result = unpack('>B', data)[0] + elif byteSize == 2: + result = unpack('>H', data)[0] + elif byteSize == 4: + result = unpack('>L', data)[0] + elif byteSize == 8: + if as_number: + result = unpack('>q', data)[0] + else: + result = unpack('>Q', data)[0] + elif byteSize <= 16: + # Handle odd-sized or integers larger than 8 bytes + # Don't naively go over 16 bytes, in order to prevent infinite loops. + result = 0 + if hasattr(int, 'from_bytes'): + result = int.from_bytes(data, 'big') + else: + for byte in data: + result = (result << 8) | unpack('>B', byte)[0] + else: + raise InvalidPlistException("Encountered integer longer than 16 bytes.") + return result + +class HashableWrapper(object): + def __init__(self, value): + self.value = value + def __repr__(self): + return "" % [self.value] + +class BoolWrapper(object): + def __init__(self, value): + self.value = value + def __repr__(self): + return "" % self.value + +class FloatWrapper(object): + _instances = {} + def __new__(klass, value): + # Ensure FloatWrapper(x) for a given float x is always the same object + wrapper = klass._instances.get(value) + if wrapper is None: + wrapper = object.__new__(klass) + wrapper.value = value + klass._instances[value] = wrapper + return wrapper + def __repr__(self): + return "" % self.value + +class PlistWriter(object): + header = b'bplist00bybiplist1.0' + file = None + byteCounts = None + trailer = None + computedUniques = None + writtenReferences = None + referencePositions = None + wrappedTrue = None + wrappedFalse = None + + def __init__(self, file): + self.reset() + self.file = file + self.wrappedTrue = BoolWrapper(True) + self.wrappedFalse = BoolWrapper(False) + + def reset(self): + self.byteCounts = PlistByteCounts(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + self.trailer = PlistTrailer(0, 0, 0, 0, 0) + + # A set of all the uniques which have been computed. + self.computedUniques = set() + # A list of all the uniques which have been written. + self.writtenReferences = {} + # A dict of the positions of the written uniques. + self.referencePositions = {} + + def positionOfObjectReference(self, obj): + """If the given object has been written already, return its + position in the offset table. Otherwise, return None.""" + return self.writtenReferences.get(obj) + + def writeRoot(self, root): + """ + Strategy is: + - write header + - wrap root object so everything is hashable + - compute size of objects which will be written + - need to do this in order to know how large the object refs + will be in the list/dict/set reference lists + - write objects + - keep objects in writtenReferences + - keep positions of object references in referencePositions + - write object references with the length computed previously + - computer object reference length + - write object reference positions + - write trailer + """ + output = self.header + wrapped_root = self.wrapRoot(root) + should_reference_root = True#not isinstance(wrapped_root, HashableWrapper) + self.computeOffsets(wrapped_root, asReference=should_reference_root, isRoot=True) + self.trailer = self.trailer._replace(**{'objectRefSize':self.intSize(len(self.computedUniques))}) + (_, output) = self.writeObjectReference(wrapped_root, output) + output = self.writeObject(wrapped_root, output, setReferencePosition=True) + + # output size at this point is an upper bound on how big the + # object reference offsets need to be. + self.trailer = self.trailer._replace(**{ + 'offsetSize':self.intSize(len(output)), + 'offsetCount':len(self.computedUniques), + 'offsetTableOffset':len(output), + 'topLevelObjectNumber':0 + }) + + output = self.writeOffsetTable(output) + output += pack('!xxxxxxBBQQQ', *self.trailer) + self.file.write(output) + + def wrapRoot(self, root): + if isinstance(root, bool): + if root is True: + return self.wrappedTrue + else: + return self.wrappedFalse + elif isinstance(root, float): + return FloatWrapper(root) + elif isinstance(root, set): + n = set() + for value in root: + n.add(self.wrapRoot(value)) + return HashableWrapper(n) + elif isinstance(root, dict): + n = {} + for key, value in iteritems(root): + n[self.wrapRoot(key)] = self.wrapRoot(value) + return HashableWrapper(n) + elif isinstance(root, list): + n = [] + for value in root: + n.append(self.wrapRoot(value)) + return HashableWrapper(n) + elif isinstance(root, tuple): + n = tuple([self.wrapRoot(value) for value in root]) + return HashableWrapper(n) + else: + return root + + def incrementByteCount(self, field, incr=1): + self.byteCounts = self.byteCounts._replace(**{field:self.byteCounts.__getattribute__(field) + incr}) + + def computeOffsets(self, obj, asReference=False, isRoot=False): + def check_key(key): + if key is None: + raise InvalidPlistException('Dictionary keys cannot be null in plists.') + elif isinstance(key, Data): + raise InvalidPlistException('Data cannot be dictionary keys in plists.') + elif not isinstance(key, (bytes, unicode)): + raise InvalidPlistException('Keys must be strings.') + + def proc_size(size): + if size > 0b1110: + size += self.intSize(size) + return size + # If this should be a reference, then we keep a record of it in the + # uniques table. + if asReference: + if obj in self.computedUniques: + return + else: + self.computedUniques.add(obj) + + if obj is None: + self.incrementByteCount('nullBytes') + elif isinstance(obj, BoolWrapper): + self.incrementByteCount('boolBytes') + elif isinstance(obj, Uid): + size = self.intSize(obj) + self.incrementByteCount('uidBytes', incr=1+size) + elif isinstance(obj, (int, long)): + size = self.intSize(obj) + self.incrementByteCount('intBytes', incr=1+size) + elif isinstance(obj, FloatWrapper): + size = self.realSize(obj) + self.incrementByteCount('realBytes', incr=1+size) + elif isinstance(obj, datetime.datetime): + self.incrementByteCount('dateBytes', incr=2) + elif isinstance(obj, Data): + size = proc_size(len(obj)) + self.incrementByteCount('dataBytes', incr=1+size) + elif isinstance(obj, (unicode, bytes)): + size = proc_size(len(obj)) + self.incrementByteCount('stringBytes', incr=1+size) + elif isinstance(obj, HashableWrapper): + obj = obj.value + if isinstance(obj, set): + size = proc_size(len(obj)) + self.incrementByteCount('setBytes', incr=1+size) + for value in obj: + self.computeOffsets(value, asReference=True) + elif isinstance(obj, (list, tuple)): + size = proc_size(len(obj)) + self.incrementByteCount('arrayBytes', incr=1+size) + for value in obj: + asRef = True + self.computeOffsets(value, asReference=True) + elif isinstance(obj, dict): + size = proc_size(len(obj)) + self.incrementByteCount('dictBytes', incr=1+size) + for key, value in iteritems(obj): + check_key(key) + self.computeOffsets(key, asReference=True) + self.computeOffsets(value, asReference=True) + else: + raise InvalidPlistException("Unknown object type.") + + def writeObjectReference(self, obj, output): + """Tries to write an object reference, adding it to the references + table. Does not write the actual object bytes or set the reference + position. Returns a tuple of whether the object was a new reference + (True if it was, False if it already was in the reference table) + and the new output. + """ + position = self.positionOfObjectReference(obj) + if position is None: + self.writtenReferences[obj] = len(self.writtenReferences) + output += self.binaryInt(len(self.writtenReferences) - 1, byteSize=self.trailer.objectRefSize) + return (True, output) + else: + output += self.binaryInt(position, byteSize=self.trailer.objectRefSize) + return (False, output) + + def writeObject(self, obj, output, setReferencePosition=False): + """Serializes the given object to the output. Returns output. + If setReferencePosition is True, will set the position the + object was written. + """ + def proc_variable_length(format, length): + result = b'' + if length > 0b1110: + result += pack('!B', (format << 4) | 0b1111) + result = self.writeObject(length, result) + else: + result += pack('!B', (format << 4) | length) + return result + + if isinstance(obj, (str, unicode)) and obj == unicodeEmpty: + # The Apple Plist decoder can't decode a zero length Unicode string. + obj = b'' + + if setReferencePosition: + self.referencePositions[obj] = len(output) + + if obj is None: + output += pack('!B', 0b00000000) + elif isinstance(obj, BoolWrapper): + if obj.value is False: + output += pack('!B', 0b00001000) + else: + output += pack('!B', 0b00001001) + elif isinstance(obj, Uid): + size = self.intSize(obj) + output += pack('!B', (0b1000 << 4) | size - 1) + output += self.binaryInt(obj) + elif isinstance(obj, (int, long)): + byteSize = self.intSize(obj) + root = math.log(byteSize, 2) + output += pack('!B', (0b0001 << 4) | int(root)) + output += self.binaryInt(obj, as_number=True) + elif isinstance(obj, FloatWrapper): + # just use doubles + output += pack('!B', (0b0010 << 4) | 3) + output += self.binaryReal(obj) + elif isinstance(obj, datetime.datetime): + timestamp = (obj - apple_reference_date).total_seconds() + output += pack('!B', 0b00110011) + output += pack('!d', float(timestamp)) + elif isinstance(obj, Data): + output += proc_variable_length(0b0100, len(obj)) + output += obj + elif isinstance(obj, unicode): + byteData = obj.encode('utf_16_be') + output += proc_variable_length(0b0110, len(byteData)//2) + output += byteData + elif isinstance(obj, bytes): + output += proc_variable_length(0b0101, len(obj)) + output += obj + elif isinstance(obj, HashableWrapper): + obj = obj.value + if isinstance(obj, (set, list, tuple)): + if isinstance(obj, set): + output += proc_variable_length(0b1100, len(obj)) + else: + output += proc_variable_length(0b1010, len(obj)) + + objectsToWrite = [] + for objRef in obj: + (isNew, output) = self.writeObjectReference(objRef, output) + if isNew: + objectsToWrite.append(objRef) + for objRef in objectsToWrite: + output = self.writeObject(objRef, output, setReferencePosition=True) + elif isinstance(obj, dict): + output += proc_variable_length(0b1101, len(obj)) + keys = [] + values = [] + objectsToWrite = [] + for key, value in iteritems(obj): + keys.append(key) + values.append(value) + for key in keys: + (isNew, output) = self.writeObjectReference(key, output) + if isNew: + objectsToWrite.append(key) + for value in values: + (isNew, output) = self.writeObjectReference(value, output) + if isNew: + objectsToWrite.append(value) + for objRef in objectsToWrite: + output = self.writeObject(objRef, output, setReferencePosition=True) + return output + + def writeOffsetTable(self, output): + """Writes all of the object reference offsets.""" + all_positions = [] + writtenReferences = list(self.writtenReferences.items()) + writtenReferences.sort(key=lambda x: x[1]) + for obj,order in writtenReferences: + # Porting note: Elsewhere we deliberately replace empty unicdoe strings + # with empty binary strings, but the empty unicode string + # goes into writtenReferences. This isn't an issue in Py2 + # because u'' and b'' have the same hash; but it is in + # Py3, where they don't. + if bytes != str and obj == unicodeEmpty: + obj = b'' + position = self.referencePositions.get(obj) + if position is None: + raise InvalidPlistException("Error while writing offsets table. Object not found. %s" % obj) + output += self.binaryInt(position, self.trailer.offsetSize) + all_positions.append(position) + return output + + def binaryReal(self, obj): + # just use doubles + result = pack('>d', obj.value) + return result + + def binaryInt(self, obj, byteSize=None, as_number=False): + result = b'' + if byteSize is None: + byteSize = self.intSize(obj) + if byteSize == 1: + result += pack('>B', obj) + elif byteSize == 2: + result += pack('>H', obj) + elif byteSize == 4: + result += pack('>L', obj) + elif byteSize == 8: + if as_number: + result += pack('>q', obj) + else: + result += pack('>Q', obj) + elif byteSize <= 16: + try: + result = pack('>Q', 0) + pack('>Q', obj) + except struct_error as e: + raise InvalidPlistException("Unable to pack integer %d: %s" % (obj, e)) + else: + raise InvalidPlistException("Core Foundation can't handle integers with size greater than 16 bytes.") + return result + + def intSize(self, obj): + """Returns the number of bytes necessary to store the given integer.""" + # SIGNED + if obj < 0: # Signed integer, always 8 bytes + return 8 + # UNSIGNED + elif obj <= 0xFF: # 1 byte + return 1 + elif obj <= 0xFFFF: # 2 bytes + return 2 + elif obj <= 0xFFFFFFFF: # 4 bytes + return 4 + # SIGNED + # 0x7FFFFFFFFFFFFFFF is the max. + elif obj <= 0x7FFFFFFFFFFFFFFF: # 8 bytes signed + return 8 + elif obj <= 0xffffffffffffffff: # 8 bytes unsigned + return 16 + else: + raise InvalidPlistException("Core Foundation can't handle integers with size greater than 8 bytes.") + + def realSize(self, obj): + return 8 diff --git a/lib/bs4/__init__.py b/lib/bs4/__init__.py new file mode 100644 index 00000000..d35f765b --- /dev/null +++ b/lib/bs4/__init__.py @@ -0,0 +1,468 @@ +"""Beautiful Soup +Elixir and Tonic +"The Screen-Scraper's Friend" +http://www.crummy.com/software/BeautifulSoup/ + +Beautiful Soup uses a pluggable XML or HTML parser to parse a +(possibly invalid) document into a tree representation. Beautiful Soup +provides provides methods and Pythonic idioms that make it easy to +navigate, search, and modify the parse tree. + +Beautiful Soup works with Python 2.6 and up. It works better if lxml +and/or html5lib is installed. + +For more than you ever wanted to know about Beautiful Soup, see the +documentation: +http://www.crummy.com/software/BeautifulSoup/bs4/doc/ +""" + +__author__ = "Leonard Richardson (leonardr@segfault.org)" +__version__ = "4.4.0" +__copyright__ = "Copyright (c) 2004-2015 Leonard Richardson" +__license__ = "MIT" + +__all__ = ['BeautifulSoup'] + +import os +import re +import warnings + +from .builder import builder_registry, ParserRejectedMarkup +from .dammit import UnicodeDammit +from .element import ( + CData, + Comment, + DEFAULT_OUTPUT_ENCODING, + Declaration, + Doctype, + NavigableString, + PageElement, + ProcessingInstruction, + ResultSet, + SoupStrainer, + Tag, + ) + +# The very first thing we do is give a useful error if someone is +# running this code under Python 3 without converting it. +'You are trying to run the Python 2 version of Beautiful Soup under Python 3. This will not work.'<>'You need to convert the code, either by installing it (`python setup.py install`) or by running 2to3 (`2to3 -w bs4`).' + +class BeautifulSoup(Tag): + """ + This class defines the basic interface called by the tree builders. + + These methods will be called by the parser: + reset() + feed(markup) + + The tree builder may call these methods from its feed() implementation: + handle_starttag(name, attrs) # See note about return value + handle_endtag(name) + handle_data(data) # Appends to the current data node + endData(containerClass=NavigableString) # Ends the current data node + + No matter how complicated the underlying parser is, you should be + able to build a tree using 'start tag' events, 'end tag' events, + 'data' events, and "done with data" events. + + If you encounter an empty-element tag (aka a self-closing tag, + like HTML's
tag), call handle_starttag and then + handle_endtag. + """ + ROOT_TAG_NAME = u'[document]' + + # If the end-user gives no indication which tree builder they + # want, look for one with these features. + DEFAULT_BUILDER_FEATURES = ['html', 'fast'] + + ASCII_SPACES = '\x20\x0a\x09\x0c\x0d' + + NO_PARSER_SPECIFIED_WARNING = "No parser was explicitly specified, so I'm using the best available %(markup_type)s parser for this system (\"%(parser)s\"). This usually isn't a problem, but if you run this code on another system, or in a different virtual environment, it may use a different parser and behave differently.\n\nTo get rid of this warning, change this:\n\n BeautifulSoup([your markup])\n\nto this:\n\n BeautifulSoup([your markup], \"%(parser)s\")\n" + + def __init__(self, markup="", features=None, builder=None, + parse_only=None, from_encoding=None, exclude_encodings=None, + **kwargs): + """The Soup object is initialized as the 'root tag', and the + provided markup (which can be a string or a file-like object) + is fed into the underlying parser.""" + + if 'convertEntities' in kwargs: + warnings.warn( + "BS4 does not respect the convertEntities argument to the " + "BeautifulSoup constructor. Entities are always converted " + "to Unicode characters.") + + if 'markupMassage' in kwargs: + del kwargs['markupMassage'] + warnings.warn( + "BS4 does not respect the markupMassage argument to the " + "BeautifulSoup constructor. The tree builder is responsible " + "for any necessary markup massage.") + + if 'smartQuotesTo' in kwargs: + del kwargs['smartQuotesTo'] + warnings.warn( + "BS4 does not respect the smartQuotesTo argument to the " + "BeautifulSoup constructor. Smart quotes are always converted " + "to Unicode characters.") + + if 'selfClosingTags' in kwargs: + del kwargs['selfClosingTags'] + warnings.warn( + "BS4 does not respect the selfClosingTags argument to the " + "BeautifulSoup constructor. The tree builder is responsible " + "for understanding self-closing tags.") + + if 'isHTML' in kwargs: + del kwargs['isHTML'] + warnings.warn( + "BS4 does not respect the isHTML argument to the " + "BeautifulSoup constructor. Suggest you use " + "features='lxml' for HTML and features='lxml-xml' for " + "XML.") + + def deprecated_argument(old_name, new_name): + if old_name in kwargs: + warnings.warn( + 'The "%s" argument to the BeautifulSoup constructor ' + 'has been renamed to "%s."' % (old_name, new_name)) + value = kwargs[old_name] + del kwargs[old_name] + return value + return None + + parse_only = parse_only or deprecated_argument( + "parseOnlyThese", "parse_only") + + from_encoding = from_encoding or deprecated_argument( + "fromEncoding", "from_encoding") + + if len(kwargs) > 0: + arg = kwargs.keys().pop() + raise TypeError( + "__init__() got an unexpected keyword argument '%s'" % arg) + + if builder is None: + original_features = features + if isinstance(features, basestring): + features = [features] + if features is None or len(features) == 0: + features = self.DEFAULT_BUILDER_FEATURES + builder_class = builder_registry.lookup(*features) + if builder_class is None: + raise FeatureNotFound( + "Couldn't find a tree builder with the features you " + "requested: %s. Do you need to install a parser library?" + % ",".join(features)) + builder = builder_class() + if not (original_features == builder.NAME or + original_features in builder.ALTERNATE_NAMES): + if builder.is_xml: + markup_type = "XML" + else: + markup_type = "HTML" + warnings.warn(self.NO_PARSER_SPECIFIED_WARNING % dict( + parser=builder.NAME, + markup_type=markup_type)) + + self.builder = builder + self.is_xml = builder.is_xml + self.builder.soup = self + + self.parse_only = parse_only + + if hasattr(markup, 'read'): # It's a file-type object. + markup = markup.read() + elif len(markup) <= 256: + # Print out warnings for a couple beginner problems + # involving passing non-markup to Beautiful Soup. + # Beautiful Soup will still parse the input as markup, + # just in case that's what the user really wants. + if (isinstance(markup, unicode) + and not os.path.supports_unicode_filenames): + possible_filename = markup.encode("utf8") + else: + possible_filename = markup + is_file = False + try: + is_file = os.path.exists(possible_filename) + except Exception, e: + # This is almost certainly a problem involving + # characters not valid in filenames on this + # system. Just let it go. + pass + if is_file: + if isinstance(markup, unicode): + markup = markup.encode("utf8") + warnings.warn( + '"%s" looks like a filename, not markup. You should probably open this file and pass the filehandle into Beautiful Soup.' % markup) + if markup[:5] == "http:" or markup[:6] == "https:": + # TODO: This is ugly but I couldn't get it to work in + # Python 3 otherwise. + if ((isinstance(markup, bytes) and not b' ' in markup) + or (isinstance(markup, unicode) and not u' ' in markup)): + if isinstance(markup, unicode): + markup = markup.encode("utf8") + warnings.warn( + '"%s" looks like a URL. Beautiful Soup is not an HTTP client. You should probably use an HTTP client to get the document behind the URL, and feed that document to Beautiful Soup.' % markup) + + for (self.markup, self.original_encoding, self.declared_html_encoding, + self.contains_replacement_characters) in ( + self.builder.prepare_markup( + markup, from_encoding, exclude_encodings=exclude_encodings)): + self.reset() + try: + self._feed() + break + except ParserRejectedMarkup: + pass + + # Clear out the markup and remove the builder's circular + # reference to this object. + self.markup = None + self.builder.soup = None + + def __copy__(self): + return type(self)(self.encode(), builder=self.builder) + + def __getstate__(self): + # Frequently a tree builder can't be pickled. + d = dict(self.__dict__) + if 'builder' in d and not self.builder.picklable: + del d['builder'] + return d + + def _feed(self): + # Convert the document to Unicode. + self.builder.reset() + + self.builder.feed(self.markup) + # Close out any unfinished strings and close all the open tags. + self.endData() + while self.currentTag.name != self.ROOT_TAG_NAME: + self.popTag() + + def reset(self): + Tag.__init__(self, self, self.builder, self.ROOT_TAG_NAME) + self.hidden = 1 + self.builder.reset() + self.current_data = [] + self.currentTag = None + self.tagStack = [] + self.preserve_whitespace_tag_stack = [] + self.pushTag(self) + + def new_tag(self, name, namespace=None, nsprefix=None, **attrs): + """Create a new tag associated with this soup.""" + return Tag(None, self.builder, name, namespace, nsprefix, attrs) + + def new_string(self, s, subclass=NavigableString): + """Create a new NavigableString associated with this soup.""" + return subclass(s) + + def insert_before(self, successor): + raise NotImplementedError("BeautifulSoup objects don't support insert_before().") + + def insert_after(self, successor): + raise NotImplementedError("BeautifulSoup objects don't support insert_after().") + + def popTag(self): + tag = self.tagStack.pop() + if self.preserve_whitespace_tag_stack and tag == self.preserve_whitespace_tag_stack[-1]: + self.preserve_whitespace_tag_stack.pop() + #print "Pop", tag.name + if self.tagStack: + self.currentTag = self.tagStack[-1] + return self.currentTag + + def pushTag(self, tag): + #print "Push", tag.name + if self.currentTag: + self.currentTag.contents.append(tag) + self.tagStack.append(tag) + self.currentTag = self.tagStack[-1] + if tag.name in self.builder.preserve_whitespace_tags: + self.preserve_whitespace_tag_stack.append(tag) + + def endData(self, containerClass=NavigableString): + if self.current_data: + current_data = u''.join(self.current_data) + # If whitespace is not preserved, and this string contains + # nothing but ASCII spaces, replace it with a single space + # or newline. + if not self.preserve_whitespace_tag_stack: + strippable = True + for i in current_data: + if i not in self.ASCII_SPACES: + strippable = False + break + if strippable: + if '\n' in current_data: + current_data = '\n' + else: + current_data = ' ' + + # Reset the data collector. + self.current_data = [] + + # Should we add this string to the tree at all? + if self.parse_only and len(self.tagStack) <= 1 and \ + (not self.parse_only.text or \ + not self.parse_only.search(current_data)): + return + + o = containerClass(current_data) + self.object_was_parsed(o) + + def object_was_parsed(self, o, parent=None, most_recent_element=None): + """Add an object to the parse tree.""" + parent = parent or self.currentTag + previous_element = most_recent_element or self._most_recent_element + + next_element = previous_sibling = next_sibling = None + if isinstance(o, Tag): + next_element = o.next_element + next_sibling = o.next_sibling + previous_sibling = o.previous_sibling + if not previous_element: + previous_element = o.previous_element + + o.setup(parent, previous_element, next_element, previous_sibling, next_sibling) + + self._most_recent_element = o + parent.contents.append(o) + + if parent.next_sibling: + # This node is being inserted into an element that has + # already been parsed. Deal with any dangling references. + index = parent.contents.index(o) + if index == 0: + previous_element = parent + previous_sibling = None + else: + previous_element = previous_sibling = parent.contents[index-1] + if index == len(parent.contents)-1: + next_element = parent.next_sibling + next_sibling = None + else: + next_element = next_sibling = parent.contents[index+1] + + o.previous_element = previous_element + if previous_element: + previous_element.next_element = o + o.next_element = next_element + if next_element: + next_element.previous_element = o + o.next_sibling = next_sibling + if next_sibling: + next_sibling.previous_sibling = o + o.previous_sibling = previous_sibling + if previous_sibling: + previous_sibling.next_sibling = o + + def _popToTag(self, name, nsprefix=None, inclusivePop=True): + """Pops the tag stack up to and including the most recent + instance of the given tag. If inclusivePop is false, pops the tag + stack up to but *not* including the most recent instqance of + the given tag.""" + #print "Popping to %s" % name + if name == self.ROOT_TAG_NAME: + # The BeautifulSoup object itself can never be popped. + return + + most_recently_popped = None + + stack_size = len(self.tagStack) + for i in range(stack_size - 1, 0, -1): + t = self.tagStack[i] + if (name == t.name and nsprefix == t.prefix): + if inclusivePop: + most_recently_popped = self.popTag() + break + most_recently_popped = self.popTag() + + return most_recently_popped + + def handle_starttag(self, name, namespace, nsprefix, attrs): + """Push a start tag on to the stack. + + If this method returns None, the tag was rejected by the + SoupStrainer. You should proceed as if the tag had not occured + in the document. For instance, if this was a self-closing tag, + don't call handle_endtag. + """ + + # print "Start tag %s: %s" % (name, attrs) + self.endData() + + if (self.parse_only and len(self.tagStack) <= 1 + and (self.parse_only.text + or not self.parse_only.search_tag(name, attrs))): + return None + + tag = Tag(self, self.builder, name, namespace, nsprefix, attrs, + self.currentTag, self._most_recent_element) + if tag is None: + return tag + if self._most_recent_element: + self._most_recent_element.next_element = tag + self._most_recent_element = tag + self.pushTag(tag) + return tag + + def handle_endtag(self, name, nsprefix=None): + #print "End tag: " + name + self.endData() + self._popToTag(name, nsprefix) + + def handle_data(self, data): + self.current_data.append(data) + + def decode(self, pretty_print=False, + eventual_encoding=DEFAULT_OUTPUT_ENCODING, + formatter="minimal"): + """Returns a string or Unicode representation of this document. + To get Unicode, pass None for encoding.""" + + if self.is_xml: + # Print the XML declaration + encoding_part = '' + if eventual_encoding != None: + encoding_part = ' encoding="%s"' % eventual_encoding + prefix = u'\n' % encoding_part + else: + prefix = u'' + if not pretty_print: + indent_level = None + else: + indent_level = 0 + return prefix + super(BeautifulSoup, self).decode( + indent_level, eventual_encoding, formatter) + +# Alias to make it easier to type import: 'from bs4 import _soup' +_s = BeautifulSoup +_soup = BeautifulSoup + +class BeautifulStoneSoup(BeautifulSoup): + """Deprecated interface to an XML parser.""" + + def __init__(self, *args, **kwargs): + kwargs['features'] = 'xml' + warnings.warn( + 'The BeautifulStoneSoup class is deprecated. Instead of using ' + 'it, pass features="xml" into the BeautifulSoup constructor.') + super(BeautifulStoneSoup, self).__init__(*args, **kwargs) + + +class StopParsing(Exception): + pass + +class FeatureNotFound(ValueError): + pass + + +#By default, act as an HTML pretty-printer. +if __name__ == '__main__': + import sys + soup = BeautifulSoup(sys.stdin) + print soup.prettify() diff --git a/lib/bs4/builder/__init__.py b/lib/bs4/builder/__init__.py new file mode 100644 index 00000000..f8fce568 --- /dev/null +++ b/lib/bs4/builder/__init__.py @@ -0,0 +1,324 @@ +from collections import defaultdict +import itertools +import sys +from bs4.element import ( + CharsetMetaAttributeValue, + ContentMetaAttributeValue, + whitespace_re + ) + +__all__ = [ + 'HTMLTreeBuilder', + 'SAXTreeBuilder', + 'TreeBuilder', + 'TreeBuilderRegistry', + ] + +# Some useful features for a TreeBuilder to have. +FAST = 'fast' +PERMISSIVE = 'permissive' +STRICT = 'strict' +XML = 'xml' +HTML = 'html' +HTML_5 = 'html5' + + +class TreeBuilderRegistry(object): + + def __init__(self): + self.builders_for_feature = defaultdict(list) + self.builders = [] + + def register(self, treebuilder_class): + """Register a treebuilder based on its advertised features.""" + for feature in treebuilder_class.features: + self.builders_for_feature[feature].insert(0, treebuilder_class) + self.builders.insert(0, treebuilder_class) + + def lookup(self, *features): + if len(self.builders) == 0: + # There are no builders at all. + return None + + if len(features) == 0: + # They didn't ask for any features. Give them the most + # recently registered builder. + return self.builders[0] + + # Go down the list of features in order, and eliminate any builders + # that don't match every feature. + features = list(features) + features.reverse() + candidates = None + candidate_set = None + while len(features) > 0: + feature = features.pop() + we_have_the_feature = self.builders_for_feature.get(feature, []) + if len(we_have_the_feature) > 0: + if candidates is None: + candidates = we_have_the_feature + candidate_set = set(candidates) + else: + # Eliminate any candidates that don't have this feature. + candidate_set = candidate_set.intersection( + set(we_have_the_feature)) + + # The only valid candidates are the ones in candidate_set. + # Go through the original list of candidates and pick the first one + # that's in candidate_set. + if candidate_set is None: + return None + for candidate in candidates: + if candidate in candidate_set: + return candidate + return None + +# The BeautifulSoup class will take feature lists from developers and use them +# to look up builders in this registry. +builder_registry = TreeBuilderRegistry() + +class TreeBuilder(object): + """Turn a document into a Beautiful Soup object tree.""" + + NAME = "[Unknown tree builder]" + ALTERNATE_NAMES = [] + features = [] + + is_xml = False + picklable = False + preserve_whitespace_tags = set() + empty_element_tags = None # A tag will be considered an empty-element + # tag when and only when it has no contents. + + # A value for these tag/attribute combinations is a space- or + # comma-separated list of CDATA, rather than a single CDATA. + cdata_list_attributes = {} + + + def __init__(self): + self.soup = None + + def reset(self): + pass + + def can_be_empty_element(self, tag_name): + """Might a tag with this name be an empty-element tag? + + The final markup may or may not actually present this tag as + self-closing. + + For instance: an HTMLBuilder does not consider a

tag to be + an empty-element tag (it's not in + HTMLBuilder.empty_element_tags). This means an empty

tag + will be presented as "

", not "

". + + The default implementation has no opinion about which tags are + empty-element tags, so a tag will be presented as an + empty-element tag if and only if it has no contents. + "" will become "", and "bar" will + be left alone. + """ + if self.empty_element_tags is None: + return True + return tag_name in self.empty_element_tags + + def feed(self, markup): + raise NotImplementedError() + + def prepare_markup(self, markup, user_specified_encoding=None, + document_declared_encoding=None): + return markup, None, None, False + + def test_fragment_to_document(self, fragment): + """Wrap an HTML fragment to make it look like a document. + + Different parsers do this differently. For instance, lxml + introduces an empty tag, and html5lib + doesn't. Abstracting this away lets us write simple tests + which run HTML fragments through the parser and compare the + results against other HTML fragments. + + This method should not be used outside of tests. + """ + return fragment + + def set_up_substitutions(self, tag): + return False + + def _replace_cdata_list_attribute_values(self, tag_name, attrs): + """Replaces class="foo bar" with class=["foo", "bar"] + + Modifies its input in place. + """ + if not attrs: + return attrs + if self.cdata_list_attributes: + universal = self.cdata_list_attributes.get('*', []) + tag_specific = self.cdata_list_attributes.get( + tag_name.lower(), None) + for attr in attrs.keys(): + if attr in universal or (tag_specific and attr in tag_specific): + # We have a "class"-type attribute whose string + # value is a whitespace-separated list of + # values. Split it into a list. + value = attrs[attr] + if isinstance(value, basestring): + values = whitespace_re.split(value) + else: + # html5lib sometimes calls setAttributes twice + # for the same tag when rearranging the parse + # tree. On the second call the attribute value + # here is already a list. If this happens, + # leave the value alone rather than trying to + # split it again. + values = value + attrs[attr] = values + return attrs + +class SAXTreeBuilder(TreeBuilder): + """A Beautiful Soup treebuilder that listens for SAX events.""" + + def feed(self, markup): + raise NotImplementedError() + + def close(self): + pass + + def startElement(self, name, attrs): + attrs = dict((key[1], value) for key, value in list(attrs.items())) + #print "Start %s, %r" % (name, attrs) + self.soup.handle_starttag(name, attrs) + + def endElement(self, name): + #print "End %s" % name + self.soup.handle_endtag(name) + + def startElementNS(self, nsTuple, nodeName, attrs): + # Throw away (ns, nodeName) for now. + self.startElement(nodeName, attrs) + + def endElementNS(self, nsTuple, nodeName): + # Throw away (ns, nodeName) for now. + self.endElement(nodeName) + #handler.endElementNS((ns, node.nodeName), node.nodeName) + + def startPrefixMapping(self, prefix, nodeValue): + # Ignore the prefix for now. + pass + + def endPrefixMapping(self, prefix): + # Ignore the prefix for now. + # handler.endPrefixMapping(prefix) + pass + + def characters(self, content): + self.soup.handle_data(content) + + def startDocument(self): + pass + + def endDocument(self): + pass + + +class HTMLTreeBuilder(TreeBuilder): + """This TreeBuilder knows facts about HTML. + + Such as which tags are empty-element tags. + """ + + preserve_whitespace_tags = set(['pre', 'textarea']) + empty_element_tags = set(['br' , 'hr', 'input', 'img', 'meta', + 'spacer', 'link', 'frame', 'base']) + + # The HTML standard defines these attributes as containing a + # space-separated list of values, not a single value. That is, + # class="foo bar" means that the 'class' attribute has two values, + # 'foo' and 'bar', not the single value 'foo bar'. When we + # encounter one of these attributes, we will parse its value into + # a list of values if possible. Upon output, the list will be + # converted back into a string. + cdata_list_attributes = { + "*" : ['class', 'accesskey', 'dropzone'], + "a" : ['rel', 'rev'], + "link" : ['rel', 'rev'], + "td" : ["headers"], + "th" : ["headers"], + "td" : ["headers"], + "form" : ["accept-charset"], + "object" : ["archive"], + + # These are HTML5 specific, as are *.accesskey and *.dropzone above. + "area" : ["rel"], + "icon" : ["sizes"], + "iframe" : ["sandbox"], + "output" : ["for"], + } + + def set_up_substitutions(self, tag): + # We are only interested in tags + if tag.name != 'meta': + return False + + http_equiv = tag.get('http-equiv') + content = tag.get('content') + charset = tag.get('charset') + + # We are interested in tags that say what encoding the + # document was originally in. This means HTML 5-style + # tags that provide the "charset" attribute. It also means + # HTML 4-style tags that provide the "content" + # attribute and have "http-equiv" set to "content-type". + # + # In both cases we will replace the value of the appropriate + # attribute with a standin object that can take on any + # encoding. + meta_encoding = None + if charset is not None: + # HTML 5 style: + # + meta_encoding = charset + tag['charset'] = CharsetMetaAttributeValue(charset) + + elif (content is not None and http_equiv is not None + and http_equiv.lower() == 'content-type'): + # HTML 4 style: + # + tag['content'] = ContentMetaAttributeValue(content) + + return (meta_encoding is not None) + +def register_treebuilders_from(module): + """Copy TreeBuilders from the given module into this module.""" + # I'm fairly sure this is not the best way to do this. + this_module = sys.modules['bs4.builder'] + for name in module.__all__: + obj = getattr(module, name) + + if issubclass(obj, TreeBuilder): + setattr(this_module, name, obj) + this_module.__all__.append(name) + # Register the builder while we're at it. + this_module.builder_registry.register(obj) + +class ParserRejectedMarkup(Exception): + pass + +# Builders are registered in reverse order of priority, so that custom +# builder registrations will take precedence. In general, we want lxml +# to take precedence over html5lib, because it's faster. And we only +# want to use HTMLParser as a last result. +from . import _htmlparser +register_treebuilders_from(_htmlparser) +try: + from . import _html5lib + register_treebuilders_from(_html5lib) +except ImportError: + # They don't have html5lib installed. + pass +try: + from . import _lxml + register_treebuilders_from(_lxml) +except ImportError: + # They don't have lxml installed. + pass diff --git a/lib/bs4/builder/_html5lib.py b/lib/bs4/builder/_html5lib.py new file mode 100644 index 00000000..ab5793c1 --- /dev/null +++ b/lib/bs4/builder/_html5lib.py @@ -0,0 +1,329 @@ +__all__ = [ + 'HTML5TreeBuilder', + ] + +from pdb import set_trace +import warnings +from bs4.builder import ( + PERMISSIVE, + HTML, + HTML_5, + HTMLTreeBuilder, + ) +from bs4.element import ( + NamespacedAttribute, + whitespace_re, +) +import html5lib +from html5lib.constants import namespaces +from bs4.element import ( + Comment, + Doctype, + NavigableString, + Tag, + ) + +class HTML5TreeBuilder(HTMLTreeBuilder): + """Use html5lib to build a tree.""" + + NAME = "html5lib" + + features = [NAME, PERMISSIVE, HTML_5, HTML] + + def prepare_markup(self, markup, user_specified_encoding, + document_declared_encoding=None, exclude_encodings=None): + # Store the user-specified encoding for use later on. + self.user_specified_encoding = user_specified_encoding + + # document_declared_encoding and exclude_encodings aren't used + # ATM because the html5lib TreeBuilder doesn't use + # UnicodeDammit. + if exclude_encodings: + warnings.warn("You provided a value for exclude_encoding, but the html5lib tree builder doesn't support exclude_encoding.") + yield (markup, None, None, False) + + # These methods are defined by Beautiful Soup. + def feed(self, markup): + if self.soup.parse_only is not None: + warnings.warn("You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.") + parser = html5lib.HTMLParser(tree=self.create_treebuilder) + doc = parser.parse(markup, encoding=self.user_specified_encoding) + + # Set the character encoding detected by the tokenizer. + if isinstance(markup, unicode): + # We need to special-case this because html5lib sets + # charEncoding to UTF-8 if it gets Unicode input. + doc.original_encoding = None + else: + doc.original_encoding = parser.tokenizer.stream.charEncoding[0] + + def create_treebuilder(self, namespaceHTMLElements): + self.underlying_builder = TreeBuilderForHtml5lib( + self.soup, namespaceHTMLElements) + return self.underlying_builder + + def test_fragment_to_document(self, fragment): + """See `TreeBuilder`.""" + return u'%s' % fragment + + +class TreeBuilderForHtml5lib(html5lib.treebuilders._base.TreeBuilder): + + def __init__(self, soup, namespaceHTMLElements): + self.soup = soup + super(TreeBuilderForHtml5lib, self).__init__(namespaceHTMLElements) + + def documentClass(self): + self.soup.reset() + return Element(self.soup, self.soup, None) + + def insertDoctype(self, token): + name = token["name"] + publicId = token["publicId"] + systemId = token["systemId"] + + doctype = Doctype.for_name_and_ids(name, publicId, systemId) + self.soup.object_was_parsed(doctype) + + def elementClass(self, name, namespace): + tag = self.soup.new_tag(name, namespace) + return Element(tag, self.soup, namespace) + + def commentClass(self, data): + return TextNode(Comment(data), self.soup) + + def fragmentClass(self): + self.soup = BeautifulSoup("") + self.soup.name = "[document_fragment]" + return Element(self.soup, self.soup, None) + + def appendChild(self, node): + # XXX This code is not covered by the BS4 tests. + self.soup.append(node.element) + + def getDocument(self): + return self.soup + + def getFragment(self): + return html5lib.treebuilders._base.TreeBuilder.getFragment(self).element + +class AttrList(object): + def __init__(self, element): + self.element = element + self.attrs = dict(self.element.attrs) + def __iter__(self): + return list(self.attrs.items()).__iter__() + def __setitem__(self, name, value): + # If this attribute is a multi-valued attribute for this element, + # turn its value into a list. + list_attr = HTML5TreeBuilder.cdata_list_attributes + if (name in list_attr['*'] + or (self.element.name in list_attr + and name in list_attr[self.element.name])): + value = whitespace_re.split(value) + self.element[name] = value + def items(self): + return list(self.attrs.items()) + def keys(self): + return list(self.attrs.keys()) + def __len__(self): + return len(self.attrs) + def __getitem__(self, name): + return self.attrs[name] + def __contains__(self, name): + return name in list(self.attrs.keys()) + + +class Element(html5lib.treebuilders._base.Node): + def __init__(self, element, soup, namespace): + html5lib.treebuilders._base.Node.__init__(self, element.name) + self.element = element + self.soup = soup + self.namespace = namespace + + def appendChild(self, node): + string_child = child = None + if isinstance(node, basestring): + # Some other piece of code decided to pass in a string + # instead of creating a TextElement object to contain the + # string. + string_child = child = node + elif isinstance(node, Tag): + # Some other piece of code decided to pass in a Tag + # instead of creating an Element object to contain the + # Tag. + child = node + elif node.element.__class__ == NavigableString: + string_child = child = node.element + else: + child = node.element + + if not isinstance(child, basestring) and child.parent is not None: + node.element.extract() + + if (string_child and self.element.contents + and self.element.contents[-1].__class__ == NavigableString): + # We are appending a string onto another string. + # TODO This has O(n^2) performance, for input like + # "aaa..." + old_element = self.element.contents[-1] + new_element = self.soup.new_string(old_element + string_child) + old_element.replace_with(new_element) + self.soup._most_recent_element = new_element + else: + if isinstance(node, basestring): + # Create a brand new NavigableString from this string. + child = self.soup.new_string(node) + + # Tell Beautiful Soup to act as if it parsed this element + # immediately after the parent's last descendant. (Or + # immediately after the parent, if it has no children.) + if self.element.contents: + most_recent_element = self.element._last_descendant(False) + elif self.element.next_element is not None: + # Something from further ahead in the parse tree is + # being inserted into this earlier element. This is + # very annoying because it means an expensive search + # for the last element in the tree. + most_recent_element = self.soup._last_descendant() + else: + most_recent_element = self.element + + self.soup.object_was_parsed( + child, parent=self.element, + most_recent_element=most_recent_element) + + def getAttributes(self): + return AttrList(self.element) + + def setAttributes(self, attributes): + + if attributes is not None and len(attributes) > 0: + + converted_attributes = [] + for name, value in list(attributes.items()): + if isinstance(name, tuple): + new_name = NamespacedAttribute(*name) + del attributes[name] + attributes[new_name] = value + + self.soup.builder._replace_cdata_list_attribute_values( + self.name, attributes) + for name, value in attributes.items(): + self.element[name] = value + + # The attributes may contain variables that need substitution. + # Call set_up_substitutions manually. + # + # The Tag constructor called this method when the Tag was created, + # but we just set/changed the attributes, so call it again. + self.soup.builder.set_up_substitutions(self.element) + attributes = property(getAttributes, setAttributes) + + def insertText(self, data, insertBefore=None): + if insertBefore: + text = TextNode(self.soup.new_string(data), self.soup) + self.insertBefore(data, insertBefore) + else: + self.appendChild(data) + + def insertBefore(self, node, refNode): + index = self.element.index(refNode.element) + if (node.element.__class__ == NavigableString and self.element.contents + and self.element.contents[index-1].__class__ == NavigableString): + # (See comments in appendChild) + old_node = self.element.contents[index-1] + new_str = self.soup.new_string(old_node + node.element) + old_node.replace_with(new_str) + else: + self.element.insert(index, node.element) + node.parent = self + + def removeChild(self, node): + node.element.extract() + + def reparentChildren(self, new_parent): + """Move all of this tag's children into another tag.""" + # print "MOVE", self.element.contents + # print "FROM", self.element + # print "TO", new_parent.element + element = self.element + new_parent_element = new_parent.element + # Determine what this tag's next_element will be once all the children + # are removed. + final_next_element = element.next_sibling + + new_parents_last_descendant = new_parent_element._last_descendant(False, False) + if len(new_parent_element.contents) > 0: + # The new parent already contains children. We will be + # appending this tag's children to the end. + new_parents_last_child = new_parent_element.contents[-1] + new_parents_last_descendant_next_element = new_parents_last_descendant.next_element + else: + # The new parent contains no children. + new_parents_last_child = None + new_parents_last_descendant_next_element = new_parent_element.next_element + + to_append = element.contents + append_after = new_parent_element.contents + if len(to_append) > 0: + # Set the first child's previous_element and previous_sibling + # to elements within the new parent + first_child = to_append[0] + if new_parents_last_descendant: + first_child.previous_element = new_parents_last_descendant + else: + first_child.previous_element = new_parent_element + first_child.previous_sibling = new_parents_last_child + if new_parents_last_descendant: + new_parents_last_descendant.next_element = first_child + else: + new_parent_element.next_element = first_child + if new_parents_last_child: + new_parents_last_child.next_sibling = first_child + + # Fix the last child's next_element and next_sibling + last_child = to_append[-1] + last_child.next_element = new_parents_last_descendant_next_element + if new_parents_last_descendant_next_element: + new_parents_last_descendant_next_element.previous_element = last_child + last_child.next_sibling = None + + for child in to_append: + child.parent = new_parent_element + new_parent_element.contents.append(child) + + # Now that this element has no children, change its .next_element. + element.contents = [] + element.next_element = final_next_element + + # print "DONE WITH MOVE" + # print "FROM", self.element + # print "TO", new_parent_element + + def cloneNode(self): + tag = self.soup.new_tag(self.element.name, self.namespace) + node = Element(tag, self.soup, self.namespace) + for key,value in self.attributes: + node.attributes[key] = value + return node + + def hasContent(self): + return self.element.contents + + def getNameTuple(self): + if self.namespace == None: + return namespaces["html"], self.name + else: + return self.namespace, self.name + + nameTuple = property(getNameTuple) + +class TextNode(Element): + def __init__(self, element, soup): + html5lib.treebuilders._base.Node.__init__(self, None) + self.element = element + self.soup = soup + + def cloneNode(self): + raise NotImplementedError diff --git a/lib/bs4/builder/_htmlparser.py b/lib/bs4/builder/_htmlparser.py new file mode 100644 index 00000000..0101d647 --- /dev/null +++ b/lib/bs4/builder/_htmlparser.py @@ -0,0 +1,262 @@ +"""Use the HTMLParser library to parse HTML files that aren't too bad.""" + +__all__ = [ + 'HTMLParserTreeBuilder', + ] + +from HTMLParser import HTMLParser + +try: + from HTMLParser import HTMLParseError +except ImportError, e: + # HTMLParseError is removed in Python 3.5. Since it can never be + # thrown in 3.5, we can just define our own class as a placeholder. + class HTMLParseError(Exception): + pass + +import sys +import warnings + +# Starting in Python 3.2, the HTMLParser constructor takes a 'strict' +# argument, which we'd like to set to False. Unfortunately, +# http://bugs.python.org/issue13273 makes strict=True a better bet +# before Python 3.2.3. +# +# At the end of this file, we monkeypatch HTMLParser so that +# strict=True works well on Python 3.2.2. +major, minor, release = sys.version_info[:3] +CONSTRUCTOR_TAKES_STRICT = major == 3 and minor == 2 and release >= 3 +CONSTRUCTOR_STRICT_IS_DEPRECATED = major == 3 and minor == 3 +CONSTRUCTOR_TAKES_CONVERT_CHARREFS = major == 3 and minor >= 4 + + +from bs4.element import ( + CData, + Comment, + Declaration, + Doctype, + ProcessingInstruction, + ) +from bs4.dammit import EntitySubstitution, UnicodeDammit + +from bs4.builder import ( + HTML, + HTMLTreeBuilder, + STRICT, + ) + + +HTMLPARSER = 'html.parser' + +class BeautifulSoupHTMLParser(HTMLParser): + def handle_starttag(self, name, attrs): + # XXX namespace + attr_dict = {} + for key, value in attrs: + # Change None attribute values to the empty string + # for consistency with the other tree builders. + if value is None: + value = '' + attr_dict[key] = value + attrvalue = '""' + self.soup.handle_starttag(name, None, None, attr_dict) + + def handle_endtag(self, name): + self.soup.handle_endtag(name) + + def handle_data(self, data): + self.soup.handle_data(data) + + def handle_charref(self, name): + # XXX workaround for a bug in HTMLParser. Remove this once + # it's fixed in all supported versions. + # http://bugs.python.org/issue13633 + if name.startswith('x'): + real_name = int(name.lstrip('x'), 16) + elif name.startswith('X'): + real_name = int(name.lstrip('X'), 16) + else: + real_name = int(name) + + try: + data = unichr(real_name) + except (ValueError, OverflowError), e: + data = u"\N{REPLACEMENT CHARACTER}" + + self.handle_data(data) + + def handle_entityref(self, name): + character = EntitySubstitution.HTML_ENTITY_TO_CHARACTER.get(name) + if character is not None: + data = character + else: + data = "&%s;" % name + self.handle_data(data) + + def handle_comment(self, data): + self.soup.endData() + self.soup.handle_data(data) + self.soup.endData(Comment) + + def handle_decl(self, data): + self.soup.endData() + if data.startswith("DOCTYPE "): + data = data[len("DOCTYPE "):] + elif data == 'DOCTYPE': + # i.e. "" + data = '' + self.soup.handle_data(data) + self.soup.endData(Doctype) + + def unknown_decl(self, data): + if data.upper().startswith('CDATA['): + cls = CData + data = data[len('CDATA['):] + else: + cls = Declaration + self.soup.endData() + self.soup.handle_data(data) + self.soup.endData(cls) + + def handle_pi(self, data): + self.soup.endData() + self.soup.handle_data(data) + self.soup.endData(ProcessingInstruction) + + +class HTMLParserTreeBuilder(HTMLTreeBuilder): + + is_xml = False + picklable = True + NAME = HTMLPARSER + features = [NAME, HTML, STRICT] + + def __init__(self, *args, **kwargs): + if CONSTRUCTOR_TAKES_STRICT and not CONSTRUCTOR_STRICT_IS_DEPRECATED: + kwargs['strict'] = False + if CONSTRUCTOR_TAKES_CONVERT_CHARREFS: + kwargs['convert_charrefs'] = False + self.parser_args = (args, kwargs) + + def prepare_markup(self, markup, user_specified_encoding=None, + document_declared_encoding=None, exclude_encodings=None): + """ + :return: A 4-tuple (markup, original encoding, encoding + declared within markup, whether any characters had to be + replaced with REPLACEMENT CHARACTER). + """ + if isinstance(markup, unicode): + yield (markup, None, None, False) + return + + try_encodings = [user_specified_encoding, document_declared_encoding] + dammit = UnicodeDammit(markup, try_encodings, is_html=True, + exclude_encodings=exclude_encodings) + yield (dammit.markup, dammit.original_encoding, + dammit.declared_html_encoding, + dammit.contains_replacement_characters) + + def feed(self, markup): + args, kwargs = self.parser_args + parser = BeautifulSoupHTMLParser(*args, **kwargs) + parser.soup = self.soup + try: + parser.feed(markup) + except HTMLParseError, e: + warnings.warn(RuntimeWarning( + "Python's built-in HTMLParser cannot parse the given document. This is not a bug in Beautiful Soup. The best solution is to install an external parser (lxml or html5lib), and use Beautiful Soup with that parser. See http://www.crummy.com/software/BeautifulSoup/bs4/doc/#installing-a-parser for help.")) + raise e + +# Patch 3.2 versions of HTMLParser earlier than 3.2.3 to use some +# 3.2.3 code. This ensures they don't treat markup like

as a +# string. +# +# XXX This code can be removed once most Python 3 users are on 3.2.3. +if major == 3 and minor == 2 and not CONSTRUCTOR_TAKES_STRICT: + import re + attrfind_tolerant = re.compile( + r'\s*((?<=[\'"\s])[^\s/>][^\s/=>]*)(\s*=+\s*' + r'(\'[^\']*\'|"[^"]*"|(?![\'"])[^>\s]*))?') + HTMLParserTreeBuilder.attrfind_tolerant = attrfind_tolerant + + locatestarttagend = re.compile(r""" + <[a-zA-Z][-.a-zA-Z0-9:_]* # tag name + (?:\s+ # whitespace before attribute name + (?:[a-zA-Z_][-.:a-zA-Z0-9_]* # attribute name + (?:\s*=\s* # value indicator + (?:'[^']*' # LITA-enclosed value + |\"[^\"]*\" # LIT-enclosed value + |[^'\">\s]+ # bare value + ) + )? + ) + )* + \s* # trailing whitespace +""", re.VERBOSE) + BeautifulSoupHTMLParser.locatestarttagend = locatestarttagend + + from html.parser import tagfind, attrfind + + def parse_starttag(self, i): + self.__starttag_text = None + endpos = self.check_for_whole_start_tag(i) + if endpos < 0: + return endpos + rawdata = self.rawdata + self.__starttag_text = rawdata[i:endpos] + + # Now parse the data between i+1 and j into a tag and attrs + attrs = [] + match = tagfind.match(rawdata, i+1) + assert match, 'unexpected call to parse_starttag()' + k = match.end() + self.lasttag = tag = rawdata[i+1:k].lower() + while k < endpos: + if self.strict: + m = attrfind.match(rawdata, k) + else: + m = attrfind_tolerant.match(rawdata, k) + if not m: + break + attrname, rest, attrvalue = m.group(1, 2, 3) + if not rest: + attrvalue = None + elif attrvalue[:1] == '\'' == attrvalue[-1:] or \ + attrvalue[:1] == '"' == attrvalue[-1:]: + attrvalue = attrvalue[1:-1] + if attrvalue: + attrvalue = self.unescape(attrvalue) + attrs.append((attrname.lower(), attrvalue)) + k = m.end() + + end = rawdata[k:endpos].strip() + if end not in (">", "/>"): + lineno, offset = self.getpos() + if "\n" in self.__starttag_text: + lineno = lineno + self.__starttag_text.count("\n") + offset = len(self.__starttag_text) \ + - self.__starttag_text.rfind("\n") + else: + offset = offset + len(self.__starttag_text) + if self.strict: + self.error("junk characters in start tag: %r" + % (rawdata[k:endpos][:20],)) + self.handle_data(rawdata[i:endpos]) + return endpos + if end.endswith('/>'): + # XHTML-style empty tag: + self.handle_startendtag(tag, attrs) + else: + self.handle_starttag(tag, attrs) + if tag in self.CDATA_CONTENT_ELEMENTS: + self.set_cdata_mode(tag) + return endpos + + def set_cdata_mode(self, elem): + self.cdata_elem = elem.lower() + self.interesting = re.compile(r'' % self.cdata_elem, re.I) + + BeautifulSoupHTMLParser.parse_starttag = parse_starttag + BeautifulSoupHTMLParser.set_cdata_mode = set_cdata_mode + + CONSTRUCTOR_TAKES_STRICT = True diff --git a/lib/bs4/builder/_lxml.py b/lib/bs4/builder/_lxml.py new file mode 100644 index 00000000..9e8f88fb --- /dev/null +++ b/lib/bs4/builder/_lxml.py @@ -0,0 +1,248 @@ +__all__ = [ + 'LXMLTreeBuilderForXML', + 'LXMLTreeBuilder', + ] + +from io import BytesIO +from StringIO import StringIO +import collections +from lxml import etree +from bs4.element import ( + Comment, + Doctype, + NamespacedAttribute, + ProcessingInstruction, +) +from bs4.builder import ( + FAST, + HTML, + HTMLTreeBuilder, + PERMISSIVE, + ParserRejectedMarkup, + TreeBuilder, + XML) +from bs4.dammit import EncodingDetector + +LXML = 'lxml' + +class LXMLTreeBuilderForXML(TreeBuilder): + DEFAULT_PARSER_CLASS = etree.XMLParser + + is_xml = True + + NAME = "lxml-xml" + ALTERNATE_NAMES = ["xml"] + + # Well, it's permissive by XML parser standards. + features = [NAME, LXML, XML, FAST, PERMISSIVE] + + CHUNK_SIZE = 512 + + # This namespace mapping is specified in the XML Namespace + # standard. + DEFAULT_NSMAPS = {'http://www.w3.org/XML/1998/namespace' : "xml"} + + def default_parser(self, encoding): + # This can either return a parser object or a class, which + # will be instantiated with default arguments. + if self._default_parser is not None: + return self._default_parser + return etree.XMLParser( + target=self, strip_cdata=False, recover=True, encoding=encoding) + + def parser_for(self, encoding): + # Use the default parser. + parser = self.default_parser(encoding) + + if isinstance(parser, collections.Callable): + # Instantiate the parser with default arguments + parser = parser(target=self, strip_cdata=False, encoding=encoding) + return parser + + def __init__(self, parser=None, empty_element_tags=None): + # TODO: Issue a warning if parser is present but not a + # callable, since that means there's no way to create new + # parsers for different encodings. + self._default_parser = parser + if empty_element_tags is not None: + self.empty_element_tags = set(empty_element_tags) + self.soup = None + self.nsmaps = [self.DEFAULT_NSMAPS] + + def _getNsTag(self, tag): + # Split the namespace URL out of a fully-qualified lxml tag + # name. Copied from lxml's src/lxml/sax.py. + if tag[0] == '{': + return tuple(tag[1:].split('}', 1)) + else: + return (None, tag) + + def prepare_markup(self, markup, user_specified_encoding=None, + exclude_encodings=None, + document_declared_encoding=None): + """ + :yield: A series of 4-tuples. + (markup, encoding, declared encoding, + has undergone character replacement) + + Each 4-tuple represents a strategy for parsing the document. + """ + if isinstance(markup, unicode): + # We were given Unicode. Maybe lxml can parse Unicode on + # this system? + yield markup, None, document_declared_encoding, False + + if isinstance(markup, unicode): + # No, apparently not. Convert the Unicode to UTF-8 and + # tell lxml to parse it as UTF-8. + yield (markup.encode("utf8"), "utf8", + document_declared_encoding, False) + + # Instead of using UnicodeDammit to convert the bytestring to + # Unicode using different encodings, use EncodingDetector to + # iterate over the encodings, and tell lxml to try to parse + # the document as each one in turn. + is_html = not self.is_xml + try_encodings = [user_specified_encoding, document_declared_encoding] + detector = EncodingDetector( + markup, try_encodings, is_html, exclude_encodings) + for encoding in detector.encodings: + yield (detector.markup, encoding, document_declared_encoding, False) + + def feed(self, markup): + if isinstance(markup, bytes): + markup = BytesIO(markup) + elif isinstance(markup, unicode): + markup = StringIO(markup) + + # Call feed() at least once, even if the markup is empty, + # or the parser won't be initialized. + data = markup.read(self.CHUNK_SIZE) + try: + self.parser = self.parser_for(self.soup.original_encoding) + self.parser.feed(data) + while len(data) != 0: + # Now call feed() on the rest of the data, chunk by chunk. + data = markup.read(self.CHUNK_SIZE) + if len(data) != 0: + self.parser.feed(data) + self.parser.close() + except (UnicodeDecodeError, LookupError, etree.ParserError), e: + raise ParserRejectedMarkup(str(e)) + + def close(self): + self.nsmaps = [self.DEFAULT_NSMAPS] + + def start(self, name, attrs, nsmap={}): + # Make sure attrs is a mutable dict--lxml may send an immutable dictproxy. + attrs = dict(attrs) + nsprefix = None + # Invert each namespace map as it comes in. + if len(self.nsmaps) > 1: + # There are no new namespaces for this tag, but + # non-default namespaces are in play, so we need a + # separate tag stack to know when they end. + self.nsmaps.append(None) + elif len(nsmap) > 0: + # A new namespace mapping has come into play. + inverted_nsmap = dict((value, key) for key, value in nsmap.items()) + self.nsmaps.append(inverted_nsmap) + # Also treat the namespace mapping as a set of attributes on the + # tag, so we can recreate it later. + attrs = attrs.copy() + for prefix, namespace in nsmap.items(): + attribute = NamespacedAttribute( + "xmlns", prefix, "http://www.w3.org/2000/xmlns/") + attrs[attribute] = namespace + + # Namespaces are in play. Find any attributes that came in + # from lxml with namespaces attached to their names, and + # turn then into NamespacedAttribute objects. + new_attrs = {} + for attr, value in attrs.items(): + namespace, attr = self._getNsTag(attr) + if namespace is None: + new_attrs[attr] = value + else: + nsprefix = self._prefix_for_namespace(namespace) + attr = NamespacedAttribute(nsprefix, attr, namespace) + new_attrs[attr] = value + attrs = new_attrs + + namespace, name = self._getNsTag(name) + nsprefix = self._prefix_for_namespace(namespace) + self.soup.handle_starttag(name, namespace, nsprefix, attrs) + + def _prefix_for_namespace(self, namespace): + """Find the currently active prefix for the given namespace.""" + if namespace is None: + return None + for inverted_nsmap in reversed(self.nsmaps): + if inverted_nsmap is not None and namespace in inverted_nsmap: + return inverted_nsmap[namespace] + return None + + def end(self, name): + self.soup.endData() + completed_tag = self.soup.tagStack[-1] + namespace, name = self._getNsTag(name) + nsprefix = None + if namespace is not None: + for inverted_nsmap in reversed(self.nsmaps): + if inverted_nsmap is not None and namespace in inverted_nsmap: + nsprefix = inverted_nsmap[namespace] + break + self.soup.handle_endtag(name, nsprefix) + if len(self.nsmaps) > 1: + # This tag, or one of its parents, introduced a namespace + # mapping, so pop it off the stack. + self.nsmaps.pop() + + def pi(self, target, data): + self.soup.endData() + self.soup.handle_data(target + ' ' + data) + self.soup.endData(ProcessingInstruction) + + def data(self, content): + self.soup.handle_data(content) + + def doctype(self, name, pubid, system): + self.soup.endData() + doctype = Doctype.for_name_and_ids(name, pubid, system) + self.soup.object_was_parsed(doctype) + + def comment(self, content): + "Handle comments as Comment objects." + self.soup.endData() + self.soup.handle_data(content) + self.soup.endData(Comment) + + def test_fragment_to_document(self, fragment): + """See `TreeBuilder`.""" + return u'\n%s' % fragment + + +class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML): + + NAME = LXML + ALTERNATE_NAMES = ["lxml-html"] + + features = ALTERNATE_NAMES + [NAME, HTML, FAST, PERMISSIVE] + is_xml = False + + def default_parser(self, encoding): + return etree.HTMLParser + + def feed(self, markup): + encoding = self.soup.original_encoding + try: + self.parser = self.parser_for(encoding) + self.parser.feed(markup) + self.parser.close() + except (UnicodeDecodeError, LookupError, etree.ParserError), e: + raise ParserRejectedMarkup(str(e)) + + + def test_fragment_to_document(self, fragment): + """See `TreeBuilder`.""" + return u'%s' % fragment diff --git a/lib/bs4/dammit.py b/lib/bs4/dammit.py new file mode 100644 index 00000000..317ad6d7 --- /dev/null +++ b/lib/bs4/dammit.py @@ -0,0 +1,839 @@ +# -*- coding: utf-8 -*- +"""Beautiful Soup bonus library: Unicode, Dammit + +This library converts a bytestream to Unicode through any means +necessary. It is heavily based on code from Mark Pilgrim's Universal +Feed Parser. It works best on XML and HTML, but it does not rewrite the +XML or HTML to reflect a new encoding; that's the tree builder's job. +""" + +from pdb import set_trace +import codecs +from htmlentitydefs import codepoint2name +import re +import logging +import string + +# Import a library to autodetect character encodings. +chardet_type = None +try: + # First try the fast C implementation. + # PyPI package: cchardet + import cchardet + def chardet_dammit(s): + return cchardet.detect(s)['encoding'] +except ImportError: + try: + # Fall back to the pure Python implementation + # Debian package: python-chardet + # PyPI package: chardet + import chardet + def chardet_dammit(s): + return chardet.detect(s)['encoding'] + #import chardet.constants + #chardet.constants._debug = 1 + except ImportError: + # No chardet available. + def chardet_dammit(s): + return None + +# Available from http://cjkpython.i18n.org/. +try: + import iconv_codec +except ImportError: + pass + +xml_encoding_re = re.compile( + '^<\?.*encoding=[\'"](.*?)[\'"].*\?>'.encode(), re.I) +html_meta_re = re.compile( + '<\s*meta[^>]+charset\s*=\s*["\']?([^>]*?)[ /;\'">]'.encode(), re.I) + +class EntitySubstitution(object): + + """Substitute XML or HTML entities for the corresponding characters.""" + + def _populate_class_variables(): + lookup = {} + reverse_lookup = {} + characters_for_re = [] + for codepoint, name in list(codepoint2name.items()): + character = unichr(codepoint) + if codepoint != 34: + # There's no point in turning the quotation mark into + # ", unless it happens within an attribute value, which + # is handled elsewhere. + characters_for_re.append(character) + lookup[character] = name + # But we do want to turn " into the quotation mark. + reverse_lookup[name] = character + re_definition = "[%s]" % "".join(characters_for_re) + return lookup, reverse_lookup, re.compile(re_definition) + (CHARACTER_TO_HTML_ENTITY, HTML_ENTITY_TO_CHARACTER, + CHARACTER_TO_HTML_ENTITY_RE) = _populate_class_variables() + + CHARACTER_TO_XML_ENTITY = { + "'": "apos", + '"': "quot", + "&": "amp", + "<": "lt", + ">": "gt", + } + + BARE_AMPERSAND_OR_BRACKET = re.compile("([<>]|" + "&(?!#\d+;|#x[0-9a-fA-F]+;|\w+;)" + ")") + + AMPERSAND_OR_BRACKET = re.compile("([<>&])") + + @classmethod + def _substitute_html_entity(cls, matchobj): + entity = cls.CHARACTER_TO_HTML_ENTITY.get(matchobj.group(0)) + return "&%s;" % entity + + @classmethod + def _substitute_xml_entity(cls, matchobj): + """Used with a regular expression to substitute the + appropriate XML entity for an XML special character.""" + entity = cls.CHARACTER_TO_XML_ENTITY[matchobj.group(0)] + return "&%s;" % entity + + @classmethod + def quoted_attribute_value(self, value): + """Make a value into a quoted XML attribute, possibly escaping it. + + Most strings will be quoted using double quotes. + + Bob's Bar -> "Bob's Bar" + + If a string contains double quotes, it will be quoted using + single quotes. + + Welcome to "my bar" -> 'Welcome to "my bar"' + + If a string contains both single and double quotes, the + double quotes will be escaped, and the string will be quoted + using double quotes. + + Welcome to "Bob's Bar" -> "Welcome to "Bob's bar" + """ + quote_with = '"' + if '"' in value: + if "'" in value: + # The string contains both single and double + # quotes. Turn the double quotes into + # entities. We quote the double quotes rather than + # the single quotes because the entity name is + # """ whether this is HTML or XML. If we + # quoted the single quotes, we'd have to decide + # between ' and &squot;. + replace_with = """ + value = value.replace('"', replace_with) + else: + # There are double quotes but no single quotes. + # We can use single quotes to quote the attribute. + quote_with = "'" + return quote_with + value + quote_with + + @classmethod + def substitute_xml(cls, value, make_quoted_attribute=False): + """Substitute XML entities for special XML characters. + + :param value: A string to be substituted. The less-than sign + will become <, the greater-than sign will become >, + and any ampersands will become &. If you want ampersands + that appear to be part of an entity definition to be left + alone, use substitute_xml_containing_entities() instead. + + :param make_quoted_attribute: If True, then the string will be + quoted, as befits an attribute value. + """ + # Escape angle brackets and ampersands. + value = cls.AMPERSAND_OR_BRACKET.sub( + cls._substitute_xml_entity, value) + + if make_quoted_attribute: + value = cls.quoted_attribute_value(value) + return value + + @classmethod + def substitute_xml_containing_entities( + cls, value, make_quoted_attribute=False): + """Substitute XML entities for special XML characters. + + :param value: A string to be substituted. The less-than sign will + become <, the greater-than sign will become >, and any + ampersands that are not part of an entity defition will + become &. + + :param make_quoted_attribute: If True, then the string will be + quoted, as befits an attribute value. + """ + # Escape angle brackets, and ampersands that aren't part of + # entities. + value = cls.BARE_AMPERSAND_OR_BRACKET.sub( + cls._substitute_xml_entity, value) + + if make_quoted_attribute: + value = cls.quoted_attribute_value(value) + return value + + @classmethod + def substitute_html(cls, s): + """Replace certain Unicode characters with named HTML entities. + + This differs from data.encode(encoding, 'xmlcharrefreplace') + in that the goal is to make the result more readable (to those + with ASCII displays) rather than to recover from + errors. There's absolutely nothing wrong with a UTF-8 string + containg a LATIN SMALL LETTER E WITH ACUTE, but replacing that + character with "é" will make it more readable to some + people. + """ + return cls.CHARACTER_TO_HTML_ENTITY_RE.sub( + cls._substitute_html_entity, s) + + +class EncodingDetector: + """Suggests a number of possible encodings for a bytestring. + + Order of precedence: + + 1. Encodings you specifically tell EncodingDetector to try first + (the override_encodings argument to the constructor). + + 2. An encoding declared within the bytestring itself, either in an + XML declaration (if the bytestring is to be interpreted as an XML + document), or in a tag (if the bytestring is to be + interpreted as an HTML document.) + + 3. An encoding detected through textual analysis by chardet, + cchardet, or a similar external library. + + 4. UTF-8. + + 5. Windows-1252. + """ + def __init__(self, markup, override_encodings=None, is_html=False, + exclude_encodings=None): + self.override_encodings = override_encodings or [] + exclude_encodings = exclude_encodings or [] + self.exclude_encodings = set([x.lower() for x in exclude_encodings]) + self.chardet_encoding = None + self.is_html = is_html + self.declared_encoding = None + + # First order of business: strip a byte-order mark. + self.markup, self.sniffed_encoding = self.strip_byte_order_mark(markup) + + def _usable(self, encoding, tried): + if encoding is not None: + encoding = encoding.lower() + if encoding in self.exclude_encodings: + return False + if encoding not in tried: + tried.add(encoding) + return True + return False + + @property + def encodings(self): + """Yield a number of encodings that might work for this markup.""" + tried = set() + for e in self.override_encodings: + if self._usable(e, tried): + yield e + + # Did the document originally start with a byte-order mark + # that indicated its encoding? + if self._usable(self.sniffed_encoding, tried): + yield self.sniffed_encoding + + # Look within the document for an XML or HTML encoding + # declaration. + if self.declared_encoding is None: + self.declared_encoding = self.find_declared_encoding( + self.markup, self.is_html) + if self._usable(self.declared_encoding, tried): + yield self.declared_encoding + + # Use third-party character set detection to guess at the + # encoding. + if self.chardet_encoding is None: + self.chardet_encoding = chardet_dammit(self.markup) + if self._usable(self.chardet_encoding, tried): + yield self.chardet_encoding + + # As a last-ditch effort, try utf-8 and windows-1252. + for e in ('utf-8', 'windows-1252'): + if self._usable(e, tried): + yield e + + @classmethod + def strip_byte_order_mark(cls, data): + """If a byte-order mark is present, strip it and return the encoding it implies.""" + encoding = None + if isinstance(data, unicode): + # Unicode data cannot have a byte-order mark. + return data, encoding + if (len(data) >= 4) and (data[:2] == b'\xfe\xff') \ + and (data[2:4] != '\x00\x00'): + encoding = 'utf-16be' + data = data[2:] + elif (len(data) >= 4) and (data[:2] == b'\xff\xfe') \ + and (data[2:4] != '\x00\x00'): + encoding = 'utf-16le' + data = data[2:] + elif data[:3] == b'\xef\xbb\xbf': + encoding = 'utf-8' + data = data[3:] + elif data[:4] == b'\x00\x00\xfe\xff': + encoding = 'utf-32be' + data = data[4:] + elif data[:4] == b'\xff\xfe\x00\x00': + encoding = 'utf-32le' + data = data[4:] + return data, encoding + + @classmethod + def find_declared_encoding(cls, markup, is_html=False, search_entire_document=False): + """Given a document, tries to find its declared encoding. + + An XML encoding is declared at the beginning of the document. + + An HTML encoding is declared in a tag, hopefully near the + beginning of the document. + """ + if search_entire_document: + xml_endpos = html_endpos = len(markup) + else: + xml_endpos = 1024 + html_endpos = max(2048, int(len(markup) * 0.05)) + + declared_encoding = None + declared_encoding_match = xml_encoding_re.search(markup, endpos=xml_endpos) + if not declared_encoding_match and is_html: + declared_encoding_match = html_meta_re.search(markup, endpos=html_endpos) + if declared_encoding_match is not None: + declared_encoding = declared_encoding_match.groups()[0].decode( + 'ascii', 'replace') + if declared_encoding: + return declared_encoding.lower() + return None + +class UnicodeDammit: + """A class for detecting the encoding of a *ML document and + converting it to a Unicode string. If the source encoding is + windows-1252, can replace MS smart quotes with their HTML or XML + equivalents.""" + + # This dictionary maps commonly seen values for "charset" in HTML + # meta tags to the corresponding Python codec names. It only covers + # values that aren't in Python's aliases and can't be determined + # by the heuristics in find_codec. + CHARSET_ALIASES = {"macintosh": "mac-roman", + "x-sjis": "shift-jis"} + + ENCODINGS_WITH_SMART_QUOTES = [ + "windows-1252", + "iso-8859-1", + "iso-8859-2", + ] + + def __init__(self, markup, override_encodings=[], + smart_quotes_to=None, is_html=False, exclude_encodings=[]): + self.smart_quotes_to = smart_quotes_to + self.tried_encodings = [] + self.contains_replacement_characters = False + self.is_html = is_html + + self.detector = EncodingDetector( + markup, override_encodings, is_html, exclude_encodings) + + # Short-circuit if the data is in Unicode to begin with. + if isinstance(markup, unicode) or markup == '': + self.markup = markup + self.unicode_markup = unicode(markup) + self.original_encoding = None + return + + # The encoding detector may have stripped a byte-order mark. + # Use the stripped markup from this point on. + self.markup = self.detector.markup + + u = None + for encoding in self.detector.encodings: + markup = self.detector.markup + u = self._convert_from(encoding) + if u is not None: + break + + if not u: + # None of the encodings worked. As an absolute last resort, + # try them again with character replacement. + + for encoding in self.detector.encodings: + if encoding != "ascii": + u = self._convert_from(encoding, "replace") + if u is not None: + logging.warning( + "Some characters could not be decoded, and were " + "replaced with REPLACEMENT CHARACTER.") + self.contains_replacement_characters = True + break + + # If none of that worked, we could at this point force it to + # ASCII, but that would destroy so much data that I think + # giving up is better. + self.unicode_markup = u + if not u: + self.original_encoding = None + + def _sub_ms_char(self, match): + """Changes a MS smart quote character to an XML or HTML + entity, or an ASCII character.""" + orig = match.group(1) + if self.smart_quotes_to == 'ascii': + sub = self.MS_CHARS_TO_ASCII.get(orig).encode() + else: + sub = self.MS_CHARS.get(orig) + if type(sub) == tuple: + if self.smart_quotes_to == 'xml': + sub = '&#x'.encode() + sub[1].encode() + ';'.encode() + else: + sub = '&'.encode() + sub[0].encode() + ';'.encode() + else: + sub = sub.encode() + return sub + + def _convert_from(self, proposed, errors="strict"): + proposed = self.find_codec(proposed) + if not proposed or (proposed, errors) in self.tried_encodings: + return None + self.tried_encodings.append((proposed, errors)) + markup = self.markup + # Convert smart quotes to HTML if coming from an encoding + # that might have them. + if (self.smart_quotes_to is not None + and proposed in self.ENCODINGS_WITH_SMART_QUOTES): + smart_quotes_re = b"([\x80-\x9f])" + smart_quotes_compiled = re.compile(smart_quotes_re) + markup = smart_quotes_compiled.sub(self._sub_ms_char, markup) + + try: + #print "Trying to convert document to %s (errors=%s)" % ( + # proposed, errors) + u = self._to_unicode(markup, proposed, errors) + self.markup = u + self.original_encoding = proposed + except Exception as e: + #print "That didn't work!" + #print e + return None + #print "Correct encoding: %s" % proposed + return self.markup + + def _to_unicode(self, data, encoding, errors="strict"): + '''Given a string and its encoding, decodes the string into Unicode. + %encoding is a string recognized by encodings.aliases''' + return unicode(data, encoding, errors) + + @property + def declared_html_encoding(self): + if not self.is_html: + return None + return self.detector.declared_encoding + + def find_codec(self, charset): + value = (self._codec(self.CHARSET_ALIASES.get(charset, charset)) + or (charset and self._codec(charset.replace("-", ""))) + or (charset and self._codec(charset.replace("-", "_"))) + or (charset and charset.lower()) + or charset + ) + if value: + return value.lower() + return None + + def _codec(self, charset): + if not charset: + return charset + codec = None + try: + codecs.lookup(charset) + codec = charset + except (LookupError, ValueError): + pass + return codec + + + # A partial mapping of ISO-Latin-1 to HTML entities/XML numeric entities. + MS_CHARS = {b'\x80': ('euro', '20AC'), + b'\x81': ' ', + b'\x82': ('sbquo', '201A'), + b'\x83': ('fnof', '192'), + b'\x84': ('bdquo', '201E'), + b'\x85': ('hellip', '2026'), + b'\x86': ('dagger', '2020'), + b'\x87': ('Dagger', '2021'), + b'\x88': ('circ', '2C6'), + b'\x89': ('permil', '2030'), + b'\x8A': ('Scaron', '160'), + b'\x8B': ('lsaquo', '2039'), + b'\x8C': ('OElig', '152'), + b'\x8D': '?', + b'\x8E': ('#x17D', '17D'), + b'\x8F': '?', + b'\x90': '?', + b'\x91': ('lsquo', '2018'), + b'\x92': ('rsquo', '2019'), + b'\x93': ('ldquo', '201C'), + b'\x94': ('rdquo', '201D'), + b'\x95': ('bull', '2022'), + b'\x96': ('ndash', '2013'), + b'\x97': ('mdash', '2014'), + b'\x98': ('tilde', '2DC'), + b'\x99': ('trade', '2122'), + b'\x9a': ('scaron', '161'), + b'\x9b': ('rsaquo', '203A'), + b'\x9c': ('oelig', '153'), + b'\x9d': '?', + b'\x9e': ('#x17E', '17E'), + b'\x9f': ('Yuml', ''),} + + # A parochial partial mapping of ISO-Latin-1 to ASCII. Contains + # horrors like stripping diacritical marks to turn á into a, but also + # contains non-horrors like turning “ into ". + MS_CHARS_TO_ASCII = { + b'\x80' : 'EUR', + b'\x81' : ' ', + b'\x82' : ',', + b'\x83' : 'f', + b'\x84' : ',,', + b'\x85' : '...', + b'\x86' : '+', + b'\x87' : '++', + b'\x88' : '^', + b'\x89' : '%', + b'\x8a' : 'S', + b'\x8b' : '<', + b'\x8c' : 'OE', + b'\x8d' : '?', + b'\x8e' : 'Z', + b'\x8f' : '?', + b'\x90' : '?', + b'\x91' : "'", + b'\x92' : "'", + b'\x93' : '"', + b'\x94' : '"', + b'\x95' : '*', + b'\x96' : '-', + b'\x97' : '--', + b'\x98' : '~', + b'\x99' : '(TM)', + b'\x9a' : 's', + b'\x9b' : '>', + b'\x9c' : 'oe', + b'\x9d' : '?', + b'\x9e' : 'z', + b'\x9f' : 'Y', + b'\xa0' : ' ', + b'\xa1' : '!', + b'\xa2' : 'c', + b'\xa3' : 'GBP', + b'\xa4' : '$', #This approximation is especially parochial--this is the + #generic currency symbol. + b'\xa5' : 'YEN', + b'\xa6' : '|', + b'\xa7' : 'S', + b'\xa8' : '..', + b'\xa9' : '', + b'\xaa' : '(th)', + b'\xab' : '<<', + b'\xac' : '!', + b'\xad' : ' ', + b'\xae' : '(R)', + b'\xaf' : '-', + b'\xb0' : 'o', + b'\xb1' : '+-', + b'\xb2' : '2', + b'\xb3' : '3', + b'\xb4' : ("'", 'acute'), + b'\xb5' : 'u', + b'\xb6' : 'P', + b'\xb7' : '*', + b'\xb8' : ',', + b'\xb9' : '1', + b'\xba' : '(th)', + b'\xbb' : '>>', + b'\xbc' : '1/4', + b'\xbd' : '1/2', + b'\xbe' : '3/4', + b'\xbf' : '?', + b'\xc0' : 'A', + b'\xc1' : 'A', + b'\xc2' : 'A', + b'\xc3' : 'A', + b'\xc4' : 'A', + b'\xc5' : 'A', + b'\xc6' : 'AE', + b'\xc7' : 'C', + b'\xc8' : 'E', + b'\xc9' : 'E', + b'\xca' : 'E', + b'\xcb' : 'E', + b'\xcc' : 'I', + b'\xcd' : 'I', + b'\xce' : 'I', + b'\xcf' : 'I', + b'\xd0' : 'D', + b'\xd1' : 'N', + b'\xd2' : 'O', + b'\xd3' : 'O', + b'\xd4' : 'O', + b'\xd5' : 'O', + b'\xd6' : 'O', + b'\xd7' : '*', + b'\xd8' : 'O', + b'\xd9' : 'U', + b'\xda' : 'U', + b'\xdb' : 'U', + b'\xdc' : 'U', + b'\xdd' : 'Y', + b'\xde' : 'b', + b'\xdf' : 'B', + b'\xe0' : 'a', + b'\xe1' : 'a', + b'\xe2' : 'a', + b'\xe3' : 'a', + b'\xe4' : 'a', + b'\xe5' : 'a', + b'\xe6' : 'ae', + b'\xe7' : 'c', + b'\xe8' : 'e', + b'\xe9' : 'e', + b'\xea' : 'e', + b'\xeb' : 'e', + b'\xec' : 'i', + b'\xed' : 'i', + b'\xee' : 'i', + b'\xef' : 'i', + b'\xf0' : 'o', + b'\xf1' : 'n', + b'\xf2' : 'o', + b'\xf3' : 'o', + b'\xf4' : 'o', + b'\xf5' : 'o', + b'\xf6' : 'o', + b'\xf7' : '/', + b'\xf8' : 'o', + b'\xf9' : 'u', + b'\xfa' : 'u', + b'\xfb' : 'u', + b'\xfc' : 'u', + b'\xfd' : 'y', + b'\xfe' : 'b', + b'\xff' : 'y', + } + + # A map used when removing rogue Windows-1252/ISO-8859-1 + # characters in otherwise UTF-8 documents. + # + # Note that \x81, \x8d, \x8f, \x90, and \x9d are undefined in + # Windows-1252. + WINDOWS_1252_TO_UTF8 = { + 0x80 : b'\xe2\x82\xac', # € + 0x82 : b'\xe2\x80\x9a', # ‚ + 0x83 : b'\xc6\x92', # ƒ + 0x84 : b'\xe2\x80\x9e', # „ + 0x85 : b'\xe2\x80\xa6', # … + 0x86 : b'\xe2\x80\xa0', # † + 0x87 : b'\xe2\x80\xa1', # ‡ + 0x88 : b'\xcb\x86', # ˆ + 0x89 : b'\xe2\x80\xb0', # ‰ + 0x8a : b'\xc5\xa0', # Š + 0x8b : b'\xe2\x80\xb9', # ‹ + 0x8c : b'\xc5\x92', # Œ + 0x8e : b'\xc5\xbd', # Ž + 0x91 : b'\xe2\x80\x98', # ‘ + 0x92 : b'\xe2\x80\x99', # ’ + 0x93 : b'\xe2\x80\x9c', # “ + 0x94 : b'\xe2\x80\x9d', # ” + 0x95 : b'\xe2\x80\xa2', # • + 0x96 : b'\xe2\x80\x93', # – + 0x97 : b'\xe2\x80\x94', # — + 0x98 : b'\xcb\x9c', # ˜ + 0x99 : b'\xe2\x84\xa2', # ™ + 0x9a : b'\xc5\xa1', # š + 0x9b : b'\xe2\x80\xba', # › + 0x9c : b'\xc5\x93', # œ + 0x9e : b'\xc5\xbe', # ž + 0x9f : b'\xc5\xb8', # Ÿ + 0xa0 : b'\xc2\xa0', #   + 0xa1 : b'\xc2\xa1', # ¡ + 0xa2 : b'\xc2\xa2', # ¢ + 0xa3 : b'\xc2\xa3', # £ + 0xa4 : b'\xc2\xa4', # ¤ + 0xa5 : b'\xc2\xa5', # ¥ + 0xa6 : b'\xc2\xa6', # ¦ + 0xa7 : b'\xc2\xa7', # § + 0xa8 : b'\xc2\xa8', # ¨ + 0xa9 : b'\xc2\xa9', # © + 0xaa : b'\xc2\xaa', # ª + 0xab : b'\xc2\xab', # « + 0xac : b'\xc2\xac', # ¬ + 0xad : b'\xc2\xad', # ­ + 0xae : b'\xc2\xae', # ® + 0xaf : b'\xc2\xaf', # ¯ + 0xb0 : b'\xc2\xb0', # ° + 0xb1 : b'\xc2\xb1', # ± + 0xb2 : b'\xc2\xb2', # ² + 0xb3 : b'\xc2\xb3', # ³ + 0xb4 : b'\xc2\xb4', # ´ + 0xb5 : b'\xc2\xb5', # µ + 0xb6 : b'\xc2\xb6', # ¶ + 0xb7 : b'\xc2\xb7', # · + 0xb8 : b'\xc2\xb8', # ¸ + 0xb9 : b'\xc2\xb9', # ¹ + 0xba : b'\xc2\xba', # º + 0xbb : b'\xc2\xbb', # » + 0xbc : b'\xc2\xbc', # ¼ + 0xbd : b'\xc2\xbd', # ½ + 0xbe : b'\xc2\xbe', # ¾ + 0xbf : b'\xc2\xbf', # ¿ + 0xc0 : b'\xc3\x80', # À + 0xc1 : b'\xc3\x81', # Á + 0xc2 : b'\xc3\x82', #  + 0xc3 : b'\xc3\x83', # à + 0xc4 : b'\xc3\x84', # Ä + 0xc5 : b'\xc3\x85', # Å + 0xc6 : b'\xc3\x86', # Æ + 0xc7 : b'\xc3\x87', # Ç + 0xc8 : b'\xc3\x88', # È + 0xc9 : b'\xc3\x89', # É + 0xca : b'\xc3\x8a', # Ê + 0xcb : b'\xc3\x8b', # Ë + 0xcc : b'\xc3\x8c', # Ì + 0xcd : b'\xc3\x8d', # Í + 0xce : b'\xc3\x8e', # Î + 0xcf : b'\xc3\x8f', # Ï + 0xd0 : b'\xc3\x90', # Ð + 0xd1 : b'\xc3\x91', # Ñ + 0xd2 : b'\xc3\x92', # Ò + 0xd3 : b'\xc3\x93', # Ó + 0xd4 : b'\xc3\x94', # Ô + 0xd5 : b'\xc3\x95', # Õ + 0xd6 : b'\xc3\x96', # Ö + 0xd7 : b'\xc3\x97', # × + 0xd8 : b'\xc3\x98', # Ø + 0xd9 : b'\xc3\x99', # Ù + 0xda : b'\xc3\x9a', # Ú + 0xdb : b'\xc3\x9b', # Û + 0xdc : b'\xc3\x9c', # Ü + 0xdd : b'\xc3\x9d', # Ý + 0xde : b'\xc3\x9e', # Þ + 0xdf : b'\xc3\x9f', # ß + 0xe0 : b'\xc3\xa0', # à + 0xe1 : b'\xa1', # á + 0xe2 : b'\xc3\xa2', # â + 0xe3 : b'\xc3\xa3', # ã + 0xe4 : b'\xc3\xa4', # ä + 0xe5 : b'\xc3\xa5', # å + 0xe6 : b'\xc3\xa6', # æ + 0xe7 : b'\xc3\xa7', # ç + 0xe8 : b'\xc3\xa8', # è + 0xe9 : b'\xc3\xa9', # é + 0xea : b'\xc3\xaa', # ê + 0xeb : b'\xc3\xab', # ë + 0xec : b'\xc3\xac', # ì + 0xed : b'\xc3\xad', # í + 0xee : b'\xc3\xae', # î + 0xef : b'\xc3\xaf', # ï + 0xf0 : b'\xc3\xb0', # ð + 0xf1 : b'\xc3\xb1', # ñ + 0xf2 : b'\xc3\xb2', # ò + 0xf3 : b'\xc3\xb3', # ó + 0xf4 : b'\xc3\xb4', # ô + 0xf5 : b'\xc3\xb5', # õ + 0xf6 : b'\xc3\xb6', # ö + 0xf7 : b'\xc3\xb7', # ÷ + 0xf8 : b'\xc3\xb8', # ø + 0xf9 : b'\xc3\xb9', # ù + 0xfa : b'\xc3\xba', # ú + 0xfb : b'\xc3\xbb', # û + 0xfc : b'\xc3\xbc', # ü + 0xfd : b'\xc3\xbd', # ý + 0xfe : b'\xc3\xbe', # þ + } + + MULTIBYTE_MARKERS_AND_SIZES = [ + (0xc2, 0xdf, 2), # 2-byte characters start with a byte C2-DF + (0xe0, 0xef, 3), # 3-byte characters start with E0-EF + (0xf0, 0xf4, 4), # 4-byte characters start with F0-F4 + ] + + FIRST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[0][0] + LAST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[-1][1] + + @classmethod + def detwingle(cls, in_bytes, main_encoding="utf8", + embedded_encoding="windows-1252"): + """Fix characters from one encoding embedded in some other encoding. + + Currently the only situation supported is Windows-1252 (or its + subset ISO-8859-1), embedded in UTF-8. + + The input must be a bytestring. If you've already converted + the document to Unicode, you're too late. + + The output is a bytestring in which `embedded_encoding` + characters have been converted to their `main_encoding` + equivalents. + """ + if embedded_encoding.replace('_', '-').lower() not in ( + 'windows-1252', 'windows_1252'): + raise NotImplementedError( + "Windows-1252 and ISO-8859-1 are the only currently supported " + "embedded encodings.") + + if main_encoding.lower() not in ('utf8', 'utf-8'): + raise NotImplementedError( + "UTF-8 is the only currently supported main encoding.") + + byte_chunks = [] + + chunk_start = 0 + pos = 0 + while pos < len(in_bytes): + byte = in_bytes[pos] + if not isinstance(byte, int): + # Python 2.x + byte = ord(byte) + if (byte >= cls.FIRST_MULTIBYTE_MARKER + and byte <= cls.LAST_MULTIBYTE_MARKER): + # This is the start of a UTF-8 multibyte character. Skip + # to the end. + for start, end, size in cls.MULTIBYTE_MARKERS_AND_SIZES: + if byte >= start and byte <= end: + pos += size + break + elif byte >= 0x80 and byte in cls.WINDOWS_1252_TO_UTF8: + # We found a Windows-1252 character! + # Save the string up to this point as a chunk. + byte_chunks.append(in_bytes[chunk_start:pos]) + + # Now translate the Windows-1252 character into UTF-8 + # and add it as another, one-byte chunk. + byte_chunks.append(cls.WINDOWS_1252_TO_UTF8[byte]) + pos += 1 + chunk_start = pos + else: + # Go on to the next character. + pos += 1 + if chunk_start == 0: + # The string is unchanged. + return in_bytes + else: + # Store the final chunk. + byte_chunks.append(in_bytes[chunk_start:]) + return b''.join(byte_chunks) + diff --git a/lib/bs4/diagnose.py b/lib/bs4/diagnose.py new file mode 100644 index 00000000..1b719830 --- /dev/null +++ b/lib/bs4/diagnose.py @@ -0,0 +1,213 @@ +"""Diagnostic functions, mainly for use when doing tech support.""" +import cProfile +from StringIO import StringIO +from HTMLParser import HTMLParser +import bs4 +from bs4 import BeautifulSoup, __version__ +from bs4.builder import builder_registry + +import os +import pstats +import random +import tempfile +import time +import traceback +import sys +import cProfile + +def diagnose(data): + """Diagnostic suite for isolating common problems.""" + print "Diagnostic running on Beautiful Soup %s" % __version__ + print "Python version %s" % sys.version + + basic_parsers = ["html.parser", "html5lib", "lxml"] + for name in basic_parsers: + for builder in builder_registry.builders: + if name in builder.features: + break + else: + basic_parsers.remove(name) + print ( + "I noticed that %s is not installed. Installing it may help." % + name) + + if 'lxml' in basic_parsers: + basic_parsers.append(["lxml", "xml"]) + try: + from lxml import etree + print "Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION)) + except ImportError, e: + print ( + "lxml is not installed or couldn't be imported.") + + + if 'html5lib' in basic_parsers: + try: + import html5lib + print "Found html5lib version %s" % html5lib.__version__ + except ImportError, e: + print ( + "html5lib is not installed or couldn't be imported.") + + if hasattr(data, 'read'): + data = data.read() + elif os.path.exists(data): + print '"%s" looks like a filename. Reading data from the file.' % data + data = open(data).read() + elif data.startswith("http:") or data.startswith("https:"): + print '"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data + print "You need to use some other library to get the document behind the URL, and feed that document to Beautiful Soup." + return + print + + for parser in basic_parsers: + print "Trying to parse your markup with %s" % parser + success = False + try: + soup = BeautifulSoup(data, parser) + success = True + except Exception, e: + print "%s could not parse the markup." % parser + traceback.print_exc() + if success: + print "Here's what %s did with the markup:" % parser + print soup.prettify() + + print "-" * 80 + +def lxml_trace(data, html=True, **kwargs): + """Print out the lxml events that occur during parsing. + + This lets you see how lxml parses a document when no Beautiful + Soup code is running. + """ + from lxml import etree + for event, element in etree.iterparse(StringIO(data), html=html, **kwargs): + print("%s, %4s, %s" % (event, element.tag, element.text)) + +class AnnouncingParser(HTMLParser): + """Announces HTMLParser parse events, without doing anything else.""" + + def _p(self, s): + print(s) + + def handle_starttag(self, name, attrs): + self._p("%s START" % name) + + def handle_endtag(self, name): + self._p("%s END" % name) + + def handle_data(self, data): + self._p("%s DATA" % data) + + def handle_charref(self, name): + self._p("%s CHARREF" % name) + + def handle_entityref(self, name): + self._p("%s ENTITYREF" % name) + + def handle_comment(self, data): + self._p("%s COMMENT" % data) + + def handle_decl(self, data): + self._p("%s DECL" % data) + + def unknown_decl(self, data): + self._p("%s UNKNOWN-DECL" % data) + + def handle_pi(self, data): + self._p("%s PI" % data) + +def htmlparser_trace(data): + """Print out the HTMLParser events that occur during parsing. + + This lets you see how HTMLParser parses a document when no + Beautiful Soup code is running. + """ + parser = AnnouncingParser() + parser.feed(data) + +_vowels = "aeiou" +_consonants = "bcdfghjklmnpqrstvwxyz" + +def rword(length=5): + "Generate a random word-like string." + s = '' + for i in range(length): + if i % 2 == 0: + t = _consonants + else: + t = _vowels + s += random.choice(t) + return s + +def rsentence(length=4): + "Generate a random sentence-like string." + return " ".join(rword(random.randint(4,9)) for i in range(length)) + +def rdoc(num_elements=1000): + """Randomly generate an invalid HTML document.""" + tag_names = ['p', 'div', 'span', 'i', 'b', 'script', 'table'] + elements = [] + for i in range(num_elements): + choice = random.randint(0,3) + if choice == 0: + # New tag. + tag_name = random.choice(tag_names) + elements.append("<%s>" % tag_name) + elif choice == 1: + elements.append(rsentence(random.randint(1,4))) + elif choice == 2: + # Close a tag. + tag_name = random.choice(tag_names) + elements.append("" % tag_name) + return "" + "\n".join(elements) + "" + +def benchmark_parsers(num_elements=100000): + """Very basic head-to-head performance benchmark.""" + print "Comparative parser benchmark on Beautiful Soup %s" % __version__ + data = rdoc(num_elements) + print "Generated a large invalid HTML document (%d bytes)." % len(data) + + for parser in ["lxml", ["lxml", "html"], "html5lib", "html.parser"]: + success = False + try: + a = time.time() + soup = BeautifulSoup(data, parser) + b = time.time() + success = True + except Exception, e: + print "%s could not parse the markup." % parser + traceback.print_exc() + if success: + print "BS4+%s parsed the markup in %.2fs." % (parser, b-a) + + from lxml import etree + a = time.time() + etree.HTML(data) + b = time.time() + print "Raw lxml parsed the markup in %.2fs." % (b-a) + + import html5lib + parser = html5lib.HTMLParser() + a = time.time() + parser.parse(data) + b = time.time() + print "Raw html5lib parsed the markup in %.2fs." % (b-a) + +def profile(num_elements=100000, parser="lxml"): + + filehandle = tempfile.NamedTemporaryFile() + filename = filehandle.name + + data = rdoc(num_elements) + vars = dict(bs4=bs4, data=data, parser=parser) + cProfile.runctx('bs4.BeautifulSoup(data, parser)' , vars, vars, filename) + + stats = pstats.Stats(filename) + # stats.strip_dirs() + stats.sort_stats("cumulative") + stats.print_stats('_html5lib|bs4', 50) + +if __name__ == '__main__': + diagnose(sys.stdin.read()) diff --git a/lib/bs4/element.py b/lib/bs4/element.py new file mode 100644 index 00000000..c70ad5a0 --- /dev/null +++ b/lib/bs4/element.py @@ -0,0 +1,1713 @@ +from pdb import set_trace +import collections +import re +import sys +import warnings +from bs4.dammit import EntitySubstitution + +DEFAULT_OUTPUT_ENCODING = "utf-8" +PY3K = (sys.version_info[0] > 2) + +whitespace_re = re.compile("\s+") + +def _alias(attr): + """Alias one attribute name to another for backward compatibility""" + @property + def alias(self): + return getattr(self, attr) + + @alias.setter + def alias(self): + return setattr(self, attr) + return alias + + +class NamespacedAttribute(unicode): + + def __new__(cls, prefix, name, namespace=None): + if name is None: + obj = unicode.__new__(cls, prefix) + elif prefix is None: + # Not really namespaced. + obj = unicode.__new__(cls, name) + else: + obj = unicode.__new__(cls, prefix + ":" + name) + obj.prefix = prefix + obj.name = name + obj.namespace = namespace + return obj + +class AttributeValueWithCharsetSubstitution(unicode): + """A stand-in object for a character encoding specified in HTML.""" + +class CharsetMetaAttributeValue(AttributeValueWithCharsetSubstitution): + """A generic stand-in for the value of a meta tag's 'charset' attribute. + + When Beautiful Soup parses the markup '', the + value of the 'charset' attribute will be one of these objects. + """ + + def __new__(cls, original_value): + obj = unicode.__new__(cls, original_value) + obj.original_value = original_value + return obj + + def encode(self, encoding): + return encoding + + +class ContentMetaAttributeValue(AttributeValueWithCharsetSubstitution): + """A generic stand-in for the value of a meta tag's 'content' attribute. + + When Beautiful Soup parses the markup: + + + The value of the 'content' attribute will be one of these objects. + """ + + CHARSET_RE = re.compile("((^|;)\s*charset=)([^;]*)", re.M) + + def __new__(cls, original_value): + match = cls.CHARSET_RE.search(original_value) + if match is None: + # No substitution necessary. + return unicode.__new__(unicode, original_value) + + obj = unicode.__new__(cls, original_value) + obj.original_value = original_value + return obj + + def encode(self, encoding): + def rewrite(match): + return match.group(1) + encoding + return self.CHARSET_RE.sub(rewrite, self.original_value) + +class HTMLAwareEntitySubstitution(EntitySubstitution): + + """Entity substitution rules that are aware of some HTML quirks. + + Specifically, the contents of + +Hello, world! + + +''' + soup = self.soup(html) + self.assertEqual("text/javascript", soup.find('script')['type']) + + def test_comment(self): + # Comments are represented as Comment objects. + markup = "

foobaz

" + self.assertSoupEquals(markup) + + soup = self.soup(markup) + comment = soup.find(text="foobar") + self.assertEqual(comment.__class__, Comment) + + # The comment is properly integrated into the tree. + foo = soup.find(text="foo") + self.assertEqual(comment, foo.next_element) + baz = soup.find(text="baz") + self.assertEqual(comment, baz.previous_element) + + def test_preserved_whitespace_in_pre_and_textarea(self): + """Whitespace must be preserved in
 and ")
+
+    def test_nested_inline_elements(self):
+        """Inline elements can be nested indefinitely."""
+        b_tag = "Inside a B tag"
+        self.assertSoupEquals(b_tag)
+
+        nested_b_tag = "

A nested tag

" + self.assertSoupEquals(nested_b_tag) + + double_nested_b_tag = "

A doubly nested tag

" + self.assertSoupEquals(nested_b_tag) + + def test_nested_block_level_elements(self): + """Block elements can be nested.""" + soup = self.soup('

Foo

') + blockquote = soup.blockquote + self.assertEqual(blockquote.p.b.string, 'Foo') + self.assertEqual(blockquote.b.string, 'Foo') + + def test_correctly_nested_tables(self): + """One table can go inside another one.""" + markup = ('' + '' + "') + + self.assertSoupEquals( + markup, + '
Here's another table:" + '' + '' + '
foo
Here\'s another table:' + '
foo
' + '
') + + self.assertSoupEquals( + "" + "" + "
Foo
Bar
Baz
") + + def test_deeply_nested_multivalued_attribute(self): + # html5lib can set the attributes of the same tag many times + # as it rearranges the tree. This has caused problems with + # multivalued attributes. + markup = '
' + soup = self.soup(markup) + self.assertEqual(["css"], soup.div.div['class']) + + def test_multivalued_attribute_on_html(self): + # html5lib uses a different API to set the attributes ot the + # tag. This has caused problems with multivalued + # attributes. + markup = '' + soup = self.soup(markup) + self.assertEqual(["a", "b"], soup.html['class']) + + def test_angle_brackets_in_attribute_values_are_escaped(self): + self.assertSoupEquals('', '') + + def test_entities_in_attributes_converted_to_unicode(self): + expect = u'

' + self.assertSoupEquals('

', expect) + self.assertSoupEquals('

', expect) + self.assertSoupEquals('

', expect) + self.assertSoupEquals('

', expect) + + def test_entities_in_text_converted_to_unicode(self): + expect = u'

pi\N{LATIN SMALL LETTER N WITH TILDE}ata

' + self.assertSoupEquals("

piñata

", expect) + self.assertSoupEquals("

piñata

", expect) + self.assertSoupEquals("

piñata

", expect) + self.assertSoupEquals("

piñata

", expect) + + def test_quot_entity_converted_to_quotation_mark(self): + self.assertSoupEquals("

I said "good day!"

", + '

I said "good day!"

') + + def test_out_of_range_entity(self): + expect = u"\N{REPLACEMENT CHARACTER}" + self.assertSoupEquals("�", expect) + self.assertSoupEquals("�", expect) + self.assertSoupEquals("�", expect) + + def test_multipart_strings(self): + "Mostly to prevent a recurrence of a bug in the html5lib treebuilder." + soup = self.soup("

\nfoo

") + self.assertEqual("p", soup.h2.string.next_element.name) + self.assertEqual("p", soup.p.name) + self.assertConnectedness(soup) + + def test_head_tag_between_head_and_body(self): + "Prevent recurrence of a bug in the html5lib treebuilder." + content = """ + + foo + +""" + soup = self.soup(content) + self.assertNotEqual(None, soup.html.body) + self.assertConnectedness(soup) + + def test_multiple_copies_of_a_tag(self): + "Prevent recurrence of a bug in the html5lib treebuilder." + content = """ + + + + + +""" + soup = self.soup(content) + self.assertConnectedness(soup.article) + + def test_basic_namespaces(self): + """Parsers don't need to *understand* namespaces, but at the + very least they should not choke on namespaces or lose + data.""" + + markup = b'4' + soup = self.soup(markup) + self.assertEqual(markup, soup.encode()) + html = soup.html + self.assertEqual('http://www.w3.org/1999/xhtml', soup.html['xmlns']) + self.assertEqual( + 'http://www.w3.org/1998/Math/MathML', soup.html['xmlns:mathml']) + self.assertEqual( + 'http://www.w3.org/2000/svg', soup.html['xmlns:svg']) + + def test_multivalued_attribute_value_becomes_list(self): + markup = b'' + soup = self.soup(markup) + self.assertEqual(['foo', 'bar'], soup.a['class']) + + # + # Generally speaking, tests below this point are more tests of + # Beautiful Soup than tests of the tree builders. But parsers are + # weird, so we run these tests separately for every tree builder + # to detect any differences between them. + # + + def test_can_parse_unicode_document(self): + # A seemingly innocuous document... but it's in Unicode! And + # it contains characters that can't be represented in the + # encoding found in the declaration! The horror! + markup = u'Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!' + soup = self.soup(markup) + self.assertEqual(u'Sacr\xe9 bleu!', soup.body.string) + + def test_soupstrainer(self): + """Parsers should be able to work with SoupStrainers.""" + strainer = SoupStrainer("b") + soup = self.soup("A bold statement", + parse_only=strainer) + self.assertEqual(soup.decode(), "bold") + + def test_single_quote_attribute_values_become_double_quotes(self): + self.assertSoupEquals("", + '') + + def test_attribute_values_with_nested_quotes_are_left_alone(self): + text = """a""" + self.assertSoupEquals(text) + + def test_attribute_values_with_double_nested_quotes_get_quoted(self): + text = """a""" + soup = self.soup(text) + soup.foo['attr'] = 'Brawls happen at "Bob\'s Bar"' + self.assertSoupEquals( + soup.foo.decode(), + """a""") + + def test_ampersand_in_attribute_value_gets_escaped(self): + self.assertSoupEquals('', + '') + + self.assertSoupEquals( + 'foo', + 'foo') + + def test_escaped_ampersand_in_attribute_value_is_left_alone(self): + self.assertSoupEquals('') + + def test_entities_in_strings_converted_during_parsing(self): + # Both XML and HTML entities are converted to Unicode characters + # during parsing. + text = "

<<sacré bleu!>>

" + expected = u"

<<sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>>

" + self.assertSoupEquals(text, expected) + + def test_smart_quotes_converted_on_the_way_in(self): + # Microsoft smart quotes are converted to Unicode characters during + # parsing. + quote = b"

\x91Foo\x92

" + soup = self.soup(quote) + self.assertEqual( + soup.p.string, + u"\N{LEFT SINGLE QUOTATION MARK}Foo\N{RIGHT SINGLE QUOTATION MARK}") + + def test_non_breaking_spaces_converted_on_the_way_in(self): + soup = self.soup("  ") + self.assertEqual(soup.a.string, u"\N{NO-BREAK SPACE}" * 2) + + def test_entities_converted_on_the_way_out(self): + text = "

<<sacré bleu!>>

" + expected = u"

<<sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>>

".encode("utf-8") + soup = self.soup(text) + self.assertEqual(soup.p.encode("utf-8"), expected) + + def test_real_iso_latin_document(self): + # Smoke test of interrelated functionality, using an + # easy-to-understand document. + + # Here it is in Unicode. Note that it claims to be in ISO-Latin-1. + unicode_html = u'

Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!

' + + # That's because we're going to encode it into ISO-Latin-1, and use + # that to test. + iso_latin_html = unicode_html.encode("iso-8859-1") + + # Parse the ISO-Latin-1 HTML. + soup = self.soup(iso_latin_html) + # Encode it to UTF-8. + result = soup.encode("utf-8") + + # What do we expect the result to look like? Well, it would + # look like unicode_html, except that the META tag would say + # UTF-8 instead of ISO-Latin-1. + expected = unicode_html.replace("ISO-Latin-1", "utf-8") + + # And, of course, it would be in UTF-8, not Unicode. + expected = expected.encode("utf-8") + + # Ta-da! + self.assertEqual(result, expected) + + def test_real_shift_jis_document(self): + # Smoke test to make sure the parser can handle a document in + # Shift-JIS encoding, without choking. + shift_jis_html = ( + b'
'
+            b'\x82\xb1\x82\xea\x82\xcdShift-JIS\x82\xc5\x83R\x81[\x83f'
+            b'\x83B\x83\x93\x83O\x82\xb3\x82\xea\x82\xbd\x93\xfa\x96{\x8c'
+            b'\xea\x82\xcc\x83t\x83@\x83C\x83\x8b\x82\xc5\x82\xb7\x81B'
+            b'
') + unicode_html = shift_jis_html.decode("shift-jis") + soup = self.soup(unicode_html) + + # Make sure the parse tree is correctly encoded to various + # encodings. + self.assertEqual(soup.encode("utf-8"), unicode_html.encode("utf-8")) + self.assertEqual(soup.encode("euc_jp"), unicode_html.encode("euc_jp")) + + def test_real_hebrew_document(self): + # A real-world test to make sure we can convert ISO-8859-9 (a + # Hebrew encoding) to UTF-8. + hebrew_document = b'Hebrew (ISO 8859-8) in Visual Directionality

Hebrew (ISO 8859-8) in Visual Directionality

\xed\xe5\xec\xf9' + soup = self.soup( + hebrew_document, from_encoding="iso8859-8") + self.assertEqual(soup.original_encoding, 'iso8859-8') + self.assertEqual( + soup.encode('utf-8'), + hebrew_document.decode("iso8859-8").encode("utf-8")) + + def test_meta_tag_reflects_current_encoding(self): + # Here's the tag saying that a document is + # encoded in Shift-JIS. + meta_tag = ('') + + # Here's a document incorporating that meta tag. + shift_jis_html = ( + '\n%s\n' + '' + 'Shift-JIS markup goes here.') % meta_tag + soup = self.soup(shift_jis_html) + + # Parse the document, and the charset is seemingly unaffected. + parsed_meta = soup.find('meta', {'http-equiv': 'Content-type'}) + content = parsed_meta['content'] + self.assertEqual('text/html; charset=x-sjis', content) + + # But that value is actually a ContentMetaAttributeValue object. + self.assertTrue(isinstance(content, ContentMetaAttributeValue)) + + # And it will take on a value that reflects its current + # encoding. + self.assertEqual('text/html; charset=utf8', content.encode("utf8")) + + # For the rest of the story, see TestSubstitutions in + # test_tree.py. + + def test_html5_style_meta_tag_reflects_current_encoding(self): + # Here's the tag saying that a document is + # encoded in Shift-JIS. + meta_tag = ('') + + # Here's a document incorporating that meta tag. + shift_jis_html = ( + '\n%s\n' + '' + 'Shift-JIS markup goes here.') % meta_tag + soup = self.soup(shift_jis_html) + + # Parse the document, and the charset is seemingly unaffected. + parsed_meta = soup.find('meta', id="encoding") + charset = parsed_meta['charset'] + self.assertEqual('x-sjis', charset) + + # But that value is actually a CharsetMetaAttributeValue object. + self.assertTrue(isinstance(charset, CharsetMetaAttributeValue)) + + # And it will take on a value that reflects its current + # encoding. + self.assertEqual('utf8', charset.encode("utf8")) + + def test_tag_with_no_attributes_can_have_attributes_added(self): + data = self.soup("text") + data.a['foo'] = 'bar' + self.assertEqual('text', data.a.decode()) + +class XMLTreeBuilderSmokeTest(object): + + def test_pickle_and_unpickle_identity(self): + # Pickling a tree, then unpickling it, yields a tree identical + # to the original. + tree = self.soup("foo") + dumped = pickle.dumps(tree, 2) + loaded = pickle.loads(dumped) + self.assertEqual(loaded.__class__, BeautifulSoup) + self.assertEqual(loaded.decode(), tree.decode()) + + def test_docstring_generated(self): + soup = self.soup("") + self.assertEqual( + soup.encode(), b'\n') + + def test_real_xhtml_document(self): + """A real XHTML document should come out *exactly* the same as it went in.""" + markup = b""" + + +Hello. +Goodbye. +""" + soup = self.soup(markup) + self.assertEqual( + soup.encode("utf-8"), markup) + + def test_formatter_processes_script_tag_for_xml_documents(self): + doc = """ + +""" + soup = BeautifulSoup(doc, "lxml-xml") + # lxml would have stripped this while parsing, but we can add + # it later. + soup.script.string = 'console.log("< < hey > > ");' + encoded = soup.encode() + self.assertTrue(b"< < hey > >" in encoded) + + def test_can_parse_unicode_document(self): + markup = u'Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!' + soup = self.soup(markup) + self.assertEqual(u'Sacr\xe9 bleu!', soup.root.string) + + def test_popping_namespaced_tag(self): + markup = 'b2012-07-02T20:33:42Zcd' + soup = self.soup(markup) + self.assertEqual( + unicode(soup.rss), markup) + + def test_docstring_includes_correct_encoding(self): + soup = self.soup("") + self.assertEqual( + soup.encode("latin1"), + b'\n') + + def test_large_xml_document(self): + """A large XML document should come out the same as it went in.""" + markup = (b'\n' + + b'0' * (2**12) + + b'') + soup = self.soup(markup) + self.assertEqual(soup.encode("utf-8"), markup) + + + def test_tags_are_empty_element_if_and_only_if_they_are_empty(self): + self.assertSoupEquals("

", "

") + self.assertSoupEquals("

foo

") + + def test_namespaces_are_preserved(self): + markup = 'This tag is in the a namespaceThis tag is in the b namespace' + soup = self.soup(markup) + root = soup.root + self.assertEqual("http://example.com/", root['xmlns:a']) + self.assertEqual("http://example.net/", root['xmlns:b']) + + def test_closing_namespaced_tag(self): + markup = '

20010504

' + soup = self.soup(markup) + self.assertEqual(unicode(soup.p), markup) + + def test_namespaced_attributes(self): + markup = '' + soup = self.soup(markup) + self.assertEqual(unicode(soup.foo), markup) + + def test_namespaced_attributes_xml_namespace(self): + markup = 'bar' + soup = self.soup(markup) + self.assertEqual(unicode(soup.foo), markup) + +class HTML5TreeBuilderSmokeTest(HTMLTreeBuilderSmokeTest): + """Smoke test for a tree builder that supports HTML5.""" + + def test_real_xhtml_document(self): + # Since XHTML is not HTML5, HTML5 parsers are not tested to handle + # XHTML documents in any particular way. + pass + + def test_html_tags_have_namespace(self): + markup = "" + soup = self.soup(markup) + self.assertEqual("http://www.w3.org/1999/xhtml", soup.a.namespace) + + def test_svg_tags_have_namespace(self): + markup = '' + soup = self.soup(markup) + namespace = "http://www.w3.org/2000/svg" + self.assertEqual(namespace, soup.svg.namespace) + self.assertEqual(namespace, soup.circle.namespace) + + + def test_mathml_tags_have_namespace(self): + markup = '5' + soup = self.soup(markup) + namespace = 'http://www.w3.org/1998/Math/MathML' + self.assertEqual(namespace, soup.math.namespace) + self.assertEqual(namespace, soup.msqrt.namespace) + + def test_xml_declaration_becomes_comment(self): + markup = '' + soup = self.soup(markup) + self.assertTrue(isinstance(soup.contents[0], Comment)) + self.assertEqual(soup.contents[0], '?xml version="1.0" encoding="utf-8"?') + self.assertEqual("html", soup.contents[0].next_element.name) + +def skipIf(condition, reason): + def nothing(test, *args, **kwargs): + return None + + def decorator(test_item): + if condition: + return nothing + else: + return test_item + + return decorator diff --git a/headphones/certgen.py b/lib/certgen.py similarity index 98% rename from headphones/certgen.py rename to lib/certgen.py index 079424b8..1b941161 100644 --- a/headphones/certgen.py +++ b/lib/certgen.py @@ -8,9 +8,8 @@ Certificate generation module. """ -import time - from OpenSSL import crypto +import time TYPE_RSA = crypto.TYPE_RSA TYPE_DSA = crypto.TYPE_DSA @@ -30,7 +29,6 @@ def createKeyPair(type, bits): pkey.generate_key(type, bits) return pkey - def createCertRequest(pkey, digest="md5", **name): """ Create a certificate request. @@ -51,14 +49,13 @@ def createCertRequest(pkey, digest="md5", **name): req = crypto.X509Req() subj = req.get_subject() - for (key, value) in name.items(): + for (key,value) in name.items(): setattr(subj, key, value) req.set_pubkey(pkey) req.sign(pkey, digest) return req - def createCertificate(req, (issuerCert, issuerKey), serial, (notBefore, notAfter), digest="md5"): """ Generate a certificate given a certificate request. diff --git a/lib/cherrypy/LICENSE.txt b/lib/cherrypy/LICENSE.txt new file mode 100644 index 00000000..3816f933 --- /dev/null +++ b/lib/cherrypy/LICENSE.txt @@ -0,0 +1,25 @@ +Copyright (c) 2004-2011, CherryPy Team (team@cherrypy.org) +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the CherryPy Team nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/lib/cherrypy/__init__.py b/lib/cherrypy/__init__.py new file mode 100644 index 00000000..7870ef0d --- /dev/null +++ b/lib/cherrypy/__init__.py @@ -0,0 +1,652 @@ +"""CherryPy is a pythonic, object-oriented HTTP framework. + + +CherryPy consists of not one, but four separate API layers. + +The APPLICATION LAYER is the simplest. CherryPy applications are written as +a tree of classes and methods, where each branch in the tree corresponds to +a branch in the URL path. Each method is a 'page handler', which receives +GET and POST params as keyword arguments, and returns or yields the (HTML) +body of the response. The special method name 'index' is used for paths +that end in a slash, and the special method name 'default' is used to +handle multiple paths via a single handler. This layer also includes: + + * the 'exposed' attribute (and cherrypy.expose) + * cherrypy.quickstart() + * _cp_config attributes + * cherrypy.tools (including cherrypy.session) + * cherrypy.url() + +The ENVIRONMENT LAYER is used by developers at all levels. It provides +information about the current request and response, plus the application +and server environment, via a (default) set of top-level objects: + + * cherrypy.request + * cherrypy.response + * cherrypy.engine + * cherrypy.server + * cherrypy.tree + * cherrypy.config + * cherrypy.thread_data + * cherrypy.log + * cherrypy.HTTPError, NotFound, and HTTPRedirect + * cherrypy.lib + +The EXTENSION LAYER allows advanced users to construct and share their own +plugins. It consists of: + + * Hook API + * Tool API + * Toolbox API + * Dispatch API + * Config Namespace API + +Finally, there is the CORE LAYER, which uses the core API's to construct +the default components which are available at higher layers. You can think +of the default components as the 'reference implementation' for CherryPy. +Megaframeworks (and advanced users) may replace the default components +with customized or extended components. The core API's are: + + * Application API + * Engine API + * Request API + * Server API + * WSGI API + +These API's are described in the `CherryPy specification `_. +""" + +__version__ = "3.6.0" + +from cherrypy._cpcompat import urljoin as _urljoin, urlencode as _urlencode +from cherrypy._cpcompat import basestring, unicodestr, set + +from cherrypy._cperror import HTTPError, HTTPRedirect, InternalRedirect +from cherrypy._cperror import NotFound, CherryPyException, TimeoutError + +from cherrypy import _cpdispatch as dispatch + +from cherrypy import _cptools +tools = _cptools.default_toolbox +Tool = _cptools.Tool + +from cherrypy import _cprequest +from cherrypy.lib import httputil as _httputil + +from cherrypy import _cptree +tree = _cptree.Tree() +from cherrypy._cptree import Application +from cherrypy import _cpwsgi as wsgi + +from cherrypy import process +try: + from cherrypy.process import win32 + engine = win32.Win32Bus() + engine.console_control_handler = win32.ConsoleCtrlHandler(engine) + del win32 +except ImportError: + engine = process.bus + + +# Timeout monitor. We add two channels to the engine +# to which cherrypy.Application will publish. +engine.listeners['before_request'] = set() +engine.listeners['after_request'] = set() + + +class _TimeoutMonitor(process.plugins.Monitor): + + def __init__(self, bus): + self.servings = [] + process.plugins.Monitor.__init__(self, bus, self.run) + + def before_request(self): + self.servings.append((serving.request, serving.response)) + + def after_request(self): + try: + self.servings.remove((serving.request, serving.response)) + except ValueError: + pass + + def run(self): + """Check timeout on all responses. (Internal)""" + for req, resp in self.servings: + resp.check_timeout() +engine.timeout_monitor = _TimeoutMonitor(engine) +engine.timeout_monitor.subscribe() + +engine.autoreload = process.plugins.Autoreloader(engine) +engine.autoreload.subscribe() + +engine.thread_manager = process.plugins.ThreadManager(engine) +engine.thread_manager.subscribe() + +engine.signal_handler = process.plugins.SignalHandler(engine) + + +class _HandleSignalsPlugin(object): + + """Handle signals from other processes based on the configured + platform handlers above.""" + + def __init__(self, bus): + self.bus = bus + + def subscribe(self): + """Add the handlers based on the platform""" + if hasattr(self.bus, "signal_handler"): + self.bus.signal_handler.subscribe() + if hasattr(self.bus, "console_control_handler"): + self.bus.console_control_handler.subscribe() + +engine.signals = _HandleSignalsPlugin(engine) + + +from cherrypy import _cpserver +server = _cpserver.Server() +server.subscribe() + + +def quickstart(root=None, script_name="", config=None): + """Mount the given root, start the builtin server (and engine), then block. + + root: an instance of a "controller class" (a collection of page handler + methods) which represents the root of the application. + script_name: a string containing the "mount point" of the application. + This should start with a slash, and be the path portion of the URL + at which to mount the given root. For example, if root.index() will + handle requests to "http://www.example.com:8080/dept/app1/", then + the script_name argument would be "/dept/app1". + + It MUST NOT end in a slash. If the script_name refers to the root + of the URI, it MUST be an empty string (not "/"). + config: a file or dict containing application config. If this contains + a [global] section, those entries will be used in the global + (site-wide) config. + """ + if config: + _global_conf_alias.update(config) + + tree.mount(root, script_name, config) + + engine.signals.subscribe() + engine.start() + engine.block() + + +from cherrypy._cpcompat import threadlocal as _local + + +class _Serving(_local): + + """An interface for registering request and response objects. + + Rather than have a separate "thread local" object for the request and + the response, this class works as a single threadlocal container for + both objects (and any others which developers wish to define). In this + way, we can easily dump those objects when we stop/start a new HTTP + conversation, yet still refer to them as module-level globals in a + thread-safe way. + """ + + request = _cprequest.Request(_httputil.Host("127.0.0.1", 80), + _httputil.Host("127.0.0.1", 1111)) + """ + The request object for the current thread. In the main thread, + and any threads which are not receiving HTTP requests, this is None.""" + + response = _cprequest.Response() + """ + The response object for the current thread. In the main thread, + and any threads which are not receiving HTTP requests, this is None.""" + + def load(self, request, response): + self.request = request + self.response = response + + def clear(self): + """Remove all attributes of self.""" + self.__dict__.clear() + +serving = _Serving() + + +class _ThreadLocalProxy(object): + + __slots__ = ['__attrname__', '__dict__'] + + def __init__(self, attrname): + self.__attrname__ = attrname + + def __getattr__(self, name): + child = getattr(serving, self.__attrname__) + return getattr(child, name) + + def __setattr__(self, name, value): + if name in ("__attrname__", ): + object.__setattr__(self, name, value) + else: + child = getattr(serving, self.__attrname__) + setattr(child, name, value) + + def __delattr__(self, name): + child = getattr(serving, self.__attrname__) + delattr(child, name) + + def _get_dict(self): + child = getattr(serving, self.__attrname__) + d = child.__class__.__dict__.copy() + d.update(child.__dict__) + return d + __dict__ = property(_get_dict) + + def __getitem__(self, key): + child = getattr(serving, self.__attrname__) + return child[key] + + def __setitem__(self, key, value): + child = getattr(serving, self.__attrname__) + child[key] = value + + def __delitem__(self, key): + child = getattr(serving, self.__attrname__) + del child[key] + + def __contains__(self, key): + child = getattr(serving, self.__attrname__) + return key in child + + def __len__(self): + child = getattr(serving, self.__attrname__) + return len(child) + + def __nonzero__(self): + child = getattr(serving, self.__attrname__) + return bool(child) + # Python 3 + __bool__ = __nonzero__ + +# Create request and response object (the same objects will be used +# throughout the entire life of the webserver, but will redirect +# to the "serving" object) +request = _ThreadLocalProxy('request') +response = _ThreadLocalProxy('response') + +# Create thread_data object as a thread-specific all-purpose storage + + +class _ThreadData(_local): + + """A container for thread-specific data.""" +thread_data = _ThreadData() + + +# Monkeypatch pydoc to allow help() to go through the threadlocal proxy. +# Jan 2007: no Googleable examples of anyone else replacing pydoc.resolve. +# The only other way would be to change what is returned from type(request) +# and that's not possible in pure Python (you'd have to fake ob_type). +def _cherrypy_pydoc_resolve(thing, forceload=0): + """Given an object or a path to an object, get the object and its name.""" + if isinstance(thing, _ThreadLocalProxy): + thing = getattr(serving, thing.__attrname__) + return _pydoc._builtin_resolve(thing, forceload) + +try: + import pydoc as _pydoc + _pydoc._builtin_resolve = _pydoc.resolve + _pydoc.resolve = _cherrypy_pydoc_resolve +except ImportError: + pass + + +from cherrypy import _cplogging + + +class _GlobalLogManager(_cplogging.LogManager): + + """A site-wide LogManager; routes to app.log or global log as appropriate. + + This :class:`LogManager` implements + cherrypy.log() and cherrypy.log.access(). If either + function is called during a request, the message will be sent to the + logger for the current Application. If they are called outside of a + request, the message will be sent to the site-wide logger. + """ + + def __call__(self, *args, **kwargs): + """Log the given message to the app.log or global log as appropriate. + """ + # Do NOT use try/except here. See + # https://bitbucket.org/cherrypy/cherrypy/issue/945 + if hasattr(request, 'app') and hasattr(request.app, 'log'): + log = request.app.log + else: + log = self + return log.error(*args, **kwargs) + + def access(self): + """Log an access message to the app.log or global log as appropriate. + """ + try: + return request.app.log.access() + except AttributeError: + return _cplogging.LogManager.access(self) + + +log = _GlobalLogManager() +# Set a default screen handler on the global log. +log.screen = True +log.error_file = '' +# Using an access file makes CP about 10% slower. Leave off by default. +log.access_file = '' + + +def _buslog(msg, level): + log.error(msg, 'ENGINE', severity=level) +engine.subscribe('log', _buslog) + +# Helper functions for CP apps # + + +def expose(func=None, alias=None): + """Expose the function, optionally providing an alias or set of aliases.""" + def expose_(func): + func.exposed = True + if alias is not None: + if isinstance(alias, basestring): + parents[alias.replace(".", "_")] = func + else: + for a in alias: + parents[a.replace(".", "_")] = func + return func + + import sys + import types + if isinstance(func, (types.FunctionType, types.MethodType)): + if alias is None: + # @expose + func.exposed = True + return func + else: + # func = expose(func, alias) + parents = sys._getframe(1).f_locals + return expose_(func) + elif func is None: + if alias is None: + # @expose() + parents = sys._getframe(1).f_locals + return expose_ + else: + # @expose(alias="alias") or + # @expose(alias=["alias1", "alias2"]) + parents = sys._getframe(1).f_locals + return expose_ + else: + # @expose("alias") or + # @expose(["alias1", "alias2"]) + parents = sys._getframe(1).f_locals + alias = func + return expose_ + + +def popargs(*args, **kwargs): + """A decorator for _cp_dispatch + (cherrypy.dispatch.Dispatcher.dispatch_method_name). + + Optional keyword argument: handler=(Object or Function) + + Provides a _cp_dispatch function that pops off path segments into + cherrypy.request.params under the names specified. The dispatch + is then forwarded on to the next vpath element. + + Note that any existing (and exposed) member function of the class that + popargs is applied to will override that value of the argument. For + instance, if you have a method named "list" on the class decorated with + popargs, then accessing "/list" will call that function instead of popping + it off as the requested parameter. This restriction applies to all + _cp_dispatch functions. The only way around this restriction is to create + a "blank class" whose only function is to provide _cp_dispatch. + + If there are path elements after the arguments, or more arguments + are requested than are available in the vpath, then the 'handler' + keyword argument specifies the next object to handle the parameterized + request. If handler is not specified or is None, then self is used. + If handler is a function rather than an instance, then that function + will be called with the args specified and the return value from that + function used as the next object INSTEAD of adding the parameters to + cherrypy.request.args. + + This decorator may be used in one of two ways: + + As a class decorator: + @cherrypy.popargs('year', 'month', 'day') + class Blog: + def index(self, year=None, month=None, day=None): + #Process the parameters here; any url like + #/, /2009, /2009/12, or /2009/12/31 + #will fill in the appropriate parameters. + + def create(self): + #This link will still be available at /create. Defined functions + #take precedence over arguments. + + Or as a member of a class: + class Blog: + _cp_dispatch = cherrypy.popargs('year', 'month', 'day') + #... + + The handler argument may be used to mix arguments with built in functions. + For instance, the following setup allows different activities at the + day, month, and year level: + + class DayHandler: + def index(self, year, month, day): + #Do something with this day; probably list entries + + def delete(self, year, month, day): + #Delete all entries for this day + + @cherrypy.popargs('day', handler=DayHandler()) + class MonthHandler: + def index(self, year, month): + #Do something with this month; probably list entries + + def delete(self, year, month): + #Delete all entries for this month + + @cherrypy.popargs('month', handler=MonthHandler()) + class YearHandler: + def index(self, year): + #Do something with this year + + #... + + @cherrypy.popargs('year', handler=YearHandler()) + class Root: + def index(self): + #... + + """ + + # Since keyword arg comes after *args, we have to process it ourselves + # for lower versions of python. + + handler = None + handler_call = False + for k, v in kwargs.items(): + if k == 'handler': + handler = v + else: + raise TypeError( + "cherrypy.popargs() got an unexpected keyword argument '{0}'" + .format(k) + ) + + import inspect + + if handler is not None \ + and (hasattr(handler, '__call__') or inspect.isclass(handler)): + handler_call = True + + def decorated(cls_or_self=None, vpath=None): + if inspect.isclass(cls_or_self): + # cherrypy.popargs is a class decorator + cls = cls_or_self + setattr(cls, dispatch.Dispatcher.dispatch_method_name, decorated) + return cls + + # We're in the actual function + self = cls_or_self + parms = {} + for arg in args: + if not vpath: + break + parms[arg] = vpath.pop(0) + + if handler is not None: + if handler_call: + return handler(**parms) + else: + request.params.update(parms) + return handler + + request.params.update(parms) + + # If we are the ultimate handler, then to prevent our _cp_dispatch + # from being called again, we will resolve remaining elements through + # getattr() directly. + if vpath: + return getattr(self, vpath.pop(0), None) + else: + return self + + return decorated + + +def url(path="", qs="", script_name=None, base=None, relative=None): + """Create an absolute URL for the given path. + + If 'path' starts with a slash ('/'), this will return + (base + script_name + path + qs). + If it does not start with a slash, this returns + (base + script_name [+ request.path_info] + path + qs). + + If script_name is None, cherrypy.request will be used + to find a script_name, if available. + + If base is None, cherrypy.request.base will be used (if available). + Note that you can use cherrypy.tools.proxy to change this. + + Finally, note that this function can be used to obtain an absolute URL + for the current request path (minus the querystring) by passing no args. + If you call url(qs=cherrypy.request.query_string), you should get the + original browser URL (assuming no internal redirections). + + If relative is None or not provided, request.app.relative_urls will + be used (if available, else False). If False, the output will be an + absolute URL (including the scheme, host, vhost, and script_name). + If True, the output will instead be a URL that is relative to the + current request path, perhaps including '..' atoms. If relative is + the string 'server', the output will instead be a URL that is + relative to the server root; i.e., it will start with a slash. + """ + if isinstance(qs, (tuple, list, dict)): + qs = _urlencode(qs) + if qs: + qs = '?' + qs + + if request.app: + if not path.startswith("/"): + # Append/remove trailing slash from path_info as needed + # (this is to support mistyped URL's without redirecting; + # if you want to redirect, use tools.trailing_slash). + pi = request.path_info + if request.is_index is True: + if not pi.endswith('/'): + pi = pi + '/' + elif request.is_index is False: + if pi.endswith('/') and pi != '/': + pi = pi[:-1] + + if path == "": + path = pi + else: + path = _urljoin(pi, path) + + if script_name is None: + script_name = request.script_name + if base is None: + base = request.base + + newurl = base + script_name + path + qs + else: + # No request.app (we're being called outside a request). + # We'll have to guess the base from server.* attributes. + # This will produce very different results from the above + # if you're using vhosts or tools.proxy. + if base is None: + base = server.base() + + path = (script_name or "") + path + newurl = base + path + qs + + if './' in newurl: + # Normalize the URL by removing ./ and ../ + atoms = [] + for atom in newurl.split('/'): + if atom == '.': + pass + elif atom == '..': + atoms.pop() + else: + atoms.append(atom) + newurl = '/'.join(atoms) + + # At this point, we should have a fully-qualified absolute URL. + + if relative is None: + relative = getattr(request.app, "relative_urls", False) + + # See http://www.ietf.org/rfc/rfc2396.txt + if relative == 'server': + # "A relative reference beginning with a single slash character is + # termed an absolute-path reference, as defined by ..." + # This is also sometimes called "server-relative". + newurl = '/' + '/'.join(newurl.split('/', 3)[3:]) + elif relative: + # "A relative reference that does not begin with a scheme name + # or a slash character is termed a relative-path reference." + old = url(relative=False).split('/')[:-1] + new = newurl.split('/') + while old and new: + a, b = old[0], new[0] + if a != b: + break + old.pop(0) + new.pop(0) + new = (['..'] * len(old)) + new + newurl = '/'.join(new) + + return newurl + + +# import _cpconfig last so it can reference other top-level objects +from cherrypy import _cpconfig +# Use _global_conf_alias so quickstart can use 'config' as an arg +# without shadowing cherrypy.config. +config = _global_conf_alias = _cpconfig.Config() +config.defaults = { + 'tools.log_tracebacks.on': True, + 'tools.log_headers.on': True, + 'tools.trailing_slash.on': True, + 'tools.encode.on': True +} +config.namespaces["log"] = lambda k, v: setattr(log, k, v) +config.namespaces["checker"] = lambda k, v: setattr(checker, k, v) +# Must reset to get our defaults applied. +config.reset() + +from cherrypy import _cpchecker +checker = _cpchecker.Checker() +engine.subscribe('start', checker) diff --git a/lib/cherrypy/_cpchecker.py b/lib/cherrypy/_cpchecker.py new file mode 100644 index 00000000..4ef82597 --- /dev/null +++ b/lib/cherrypy/_cpchecker.py @@ -0,0 +1,332 @@ +import os +import warnings + +import cherrypy +from cherrypy._cpcompat import iteritems, copykeys, builtins + + +class Checker(object): + + """A checker for CherryPy sites and their mounted applications. + + When this object is called at engine startup, it executes each + of its own methods whose names start with ``check_``. If you wish + to disable selected checks, simply add a line in your global + config which sets the appropriate method to False:: + + [global] + checker.check_skipped_app_config = False + + You may also dynamically add or replace ``check_*`` methods in this way. + """ + + on = True + """If True (the default), run all checks; if False, turn off all checks.""" + + def __init__(self): + self._populate_known_types() + + def __call__(self): + """Run all check_* methods.""" + if self.on: + oldformatwarning = warnings.formatwarning + warnings.formatwarning = self.formatwarning + try: + for name in dir(self): + if name.startswith("check_"): + method = getattr(self, name) + if method and hasattr(method, '__call__'): + method() + finally: + warnings.formatwarning = oldformatwarning + + def formatwarning(self, message, category, filename, lineno, line=None): + """Function to format a warning.""" + return "CherryPy Checker:\n%s\n\n" % message + + # This value should be set inside _cpconfig. + global_config_contained_paths = False + + def check_app_config_entries_dont_start_with_script_name(self): + """Check for Application config with sections that repeat script_name. + """ + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + if not app.config: + continue + if sn == '': + continue + sn_atoms = sn.strip("/").split("/") + for key in app.config.keys(): + key_atoms = key.strip("/").split("/") + if key_atoms[:len(sn_atoms)] == sn_atoms: + warnings.warn( + "The application mounted at %r has config " + "entries that start with its script name: %r" % (sn, + key)) + + def check_site_config_entries_in_app_config(self): + """Check for mounted Applications that have site-scoped config.""" + for sn, app in iteritems(cherrypy.tree.apps): + if not isinstance(app, cherrypy.Application): + continue + + msg = [] + for section, entries in iteritems(app.config): + if section.startswith('/'): + for key, value in iteritems(entries): + for n in ("engine.", "server.", "tree.", "checker."): + if key.startswith(n): + msg.append("[%s] %s = %s" % + (section, key, value)) + if msg: + msg.insert(0, + "The application mounted at %r contains the " + "following config entries, which are only allowed " + "in site-wide config. Move them to a [global] " + "section and pass them to cherrypy.config.update() " + "instead of tree.mount()." % sn) + warnings.warn(os.linesep.join(msg)) + + def check_skipped_app_config(self): + """Check for mounted Applications that have no config.""" + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + if not app.config: + msg = "The Application mounted at %r has an empty config." % sn + if self.global_config_contained_paths: + msg += (" It looks like the config you passed to " + "cherrypy.config.update() contains application-" + "specific sections. You must explicitly pass " + "application config via " + "cherrypy.tree.mount(..., config=app_config)") + warnings.warn(msg) + return + + def check_app_config_brackets(self): + """Check for Application config with extraneous brackets in section + names. + """ + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + if not app.config: + continue + for key in app.config.keys(): + if key.startswith("[") or key.endswith("]"): + warnings.warn( + "The application mounted at %r has config " + "section names with extraneous brackets: %r. " + "Config *files* need brackets; config *dicts* " + "(e.g. passed to tree.mount) do not." % (sn, key)) + + def check_static_paths(self): + """Check Application config for incorrect static paths.""" + # Use the dummy Request object in the main thread. + request = cherrypy.request + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + request.app = app + for section in app.config: + # get_resource will populate request.config + request.get_resource(section + "/dummy.html") + conf = request.config.get + + if conf("tools.staticdir.on", False): + msg = "" + root = conf("tools.staticdir.root") + dir = conf("tools.staticdir.dir") + if dir is None: + msg = "tools.staticdir.dir is not set." + else: + fulldir = "" + if os.path.isabs(dir): + fulldir = dir + if root: + msg = ("dir is an absolute path, even " + "though a root is provided.") + testdir = os.path.join(root, dir[1:]) + if os.path.exists(testdir): + msg += ( + "\nIf you meant to serve the " + "filesystem folder at %r, remove the " + "leading slash from dir." % (testdir,)) + else: + if not root: + msg = ( + "dir is a relative path and " + "no root provided.") + else: + fulldir = os.path.join(root, dir) + if not os.path.isabs(fulldir): + msg = ("%r is not an absolute path." % ( + fulldir,)) + + if fulldir and not os.path.exists(fulldir): + if msg: + msg += "\n" + msg += ("%r (root + dir) is not an existing " + "filesystem path." % fulldir) + + if msg: + warnings.warn("%s\nsection: [%s]\nroot: %r\ndir: %r" + % (msg, section, root, dir)) + + # -------------------------- Compatibility -------------------------- # + obsolete = { + 'server.default_content_type': 'tools.response_headers.headers', + 'log_access_file': 'log.access_file', + 'log_config_options': None, + 'log_file': 'log.error_file', + 'log_file_not_found': None, + 'log_request_headers': 'tools.log_headers.on', + 'log_to_screen': 'log.screen', + 'show_tracebacks': 'request.show_tracebacks', + 'throw_errors': 'request.throw_errors', + 'profiler.on': ('cherrypy.tree.mount(profiler.make_app(' + 'cherrypy.Application(Root())))'), + } + + deprecated = {} + + def _compat(self, config): + """Process config and warn on each obsolete or deprecated entry.""" + for section, conf in config.items(): + if isinstance(conf, dict): + for k, v in conf.items(): + if k in self.obsolete: + warnings.warn("%r is obsolete. Use %r instead.\n" + "section: [%s]" % + (k, self.obsolete[k], section)) + elif k in self.deprecated: + warnings.warn("%r is deprecated. Use %r instead.\n" + "section: [%s]" % + (k, self.deprecated[k], section)) + else: + if section in self.obsolete: + warnings.warn("%r is obsolete. Use %r instead." + % (section, self.obsolete[section])) + elif section in self.deprecated: + warnings.warn("%r is deprecated. Use %r instead." + % (section, self.deprecated[section])) + + def check_compatibility(self): + """Process config and warn on each obsolete or deprecated entry.""" + self._compat(cherrypy.config) + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + self._compat(app.config) + + # ------------------------ Known Namespaces ------------------------ # + extra_config_namespaces = [] + + def _known_ns(self, app): + ns = ["wsgi"] + ns.extend(copykeys(app.toolboxes)) + ns.extend(copykeys(app.namespaces)) + ns.extend(copykeys(app.request_class.namespaces)) + ns.extend(copykeys(cherrypy.config.namespaces)) + ns += self.extra_config_namespaces + + for section, conf in app.config.items(): + is_path_section = section.startswith("/") + if is_path_section and isinstance(conf, dict): + for k, v in conf.items(): + atoms = k.split(".") + if len(atoms) > 1: + if atoms[0] not in ns: + # Spit out a special warning if a known + # namespace is preceded by "cherrypy." + if atoms[0] == "cherrypy" and atoms[1] in ns: + msg = ( + "The config entry %r is invalid; " + "try %r instead.\nsection: [%s]" + % (k, ".".join(atoms[1:]), section)) + else: + msg = ( + "The config entry %r is invalid, " + "because the %r config namespace " + "is unknown.\n" + "section: [%s]" % (k, atoms[0], section)) + warnings.warn(msg) + elif atoms[0] == "tools": + if atoms[1] not in dir(cherrypy.tools): + msg = ( + "The config entry %r may be invalid, " + "because the %r tool was not found.\n" + "section: [%s]" % (k, atoms[1], section)) + warnings.warn(msg) + + def check_config_namespaces(self): + """Process config and warn on each unknown config namespace.""" + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + self._known_ns(app) + + # -------------------------- Config Types -------------------------- # + known_config_types = {} + + def _populate_known_types(self): + b = [x for x in vars(builtins).values() + if type(x) is type(str)] + + def traverse(obj, namespace): + for name in dir(obj): + # Hack for 3.2's warning about body_params + if name == 'body_params': + continue + vtype = type(getattr(obj, name, None)) + if vtype in b: + self.known_config_types[namespace + "." + name] = vtype + + traverse(cherrypy.request, "request") + traverse(cherrypy.response, "response") + traverse(cherrypy.server, "server") + traverse(cherrypy.engine, "engine") + traverse(cherrypy.log, "log") + + def _known_types(self, config): + msg = ("The config entry %r in section %r is of type %r, " + "which does not match the expected type %r.") + + for section, conf in config.items(): + if isinstance(conf, dict): + for k, v in conf.items(): + if v is not None: + expected_type = self.known_config_types.get(k, None) + vtype = type(v) + if expected_type and vtype != expected_type: + warnings.warn(msg % (k, section, vtype.__name__, + expected_type.__name__)) + else: + k, v = section, conf + if v is not None: + expected_type = self.known_config_types.get(k, None) + vtype = type(v) + if expected_type and vtype != expected_type: + warnings.warn(msg % (k, section, vtype.__name__, + expected_type.__name__)) + + def check_config_types(self): + """Assert that config values are of the same type as default values.""" + self._known_types(cherrypy.config) + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + self._known_types(app.config) + + # -------------------- Specific config warnings -------------------- # + def check_localhost(self): + """Warn if any socket_host is 'localhost'. See #711.""" + for k, v in cherrypy.config.items(): + if k == 'server.socket_host' and v == 'localhost': + warnings.warn("The use of 'localhost' as a socket host can " + "cause problems on newer systems, since " + "'localhost' can map to either an IPv4 or an " + "IPv6 address. You should use '127.0.0.1' " + "or '[::1]' instead.") diff --git a/lib/cherrypy/_cpcompat.py b/lib/cherrypy/_cpcompat.py new file mode 100644 index 00000000..8a98b38b --- /dev/null +++ b/lib/cherrypy/_cpcompat.py @@ -0,0 +1,383 @@ +"""Compatibility code for using CherryPy with various versions of Python. + +CherryPy 3.2 is compatible with Python versions 2.3+. This module provides a +useful abstraction over the differences between Python versions, sometimes by +preferring a newer idiom, sometimes an older one, and sometimes a custom one. + +In particular, Python 2 uses str and '' for byte strings, while Python 3 +uses str and '' for unicode strings. We will call each of these the 'native +string' type for each version. Because of this major difference, this module +provides new 'bytestr', 'unicodestr', and 'nativestr' attributes, as well as +two functions: 'ntob', which translates native strings (of type 'str') into +byte strings regardless of Python version, and 'ntou', which translates native +strings to unicode strings. This also provides a 'BytesIO' name for dealing +specifically with bytes, and a 'StringIO' name for dealing with native strings. +It also provides a 'base64_decode' function with native strings as input and +output. +""" +import os +import re +import sys +import threading + +if sys.version_info >= (3, 0): + py3k = True + bytestr = bytes + unicodestr = str + nativestr = unicodestr + basestring = (bytes, str) + + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given + encoding. + """ + assert_native(n) + # In Python 3, the native string type is unicode + return n.encode(encoding) + + def ntou(n, encoding='ISO-8859-1'): + """Return the given native string as a unicode string with the given + encoding. + """ + assert_native(n) + # In Python 3, the native string type is unicode + return n + + def tonative(n, encoding='ISO-8859-1'): + """Return the given string as a native string in the given encoding.""" + # In Python 3, the native string type is unicode + if isinstance(n, bytes): + return n.decode(encoding) + return n + # type("") + from io import StringIO + # bytes: + from io import BytesIO as BytesIO +else: + # Python 2 + py3k = False + bytestr = str + unicodestr = unicode + nativestr = bytestr + basestring = basestring + + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given + encoding. + """ + assert_native(n) + # In Python 2, the native string type is bytes. Assume it's already + # in the given encoding, which for ISO-8859-1 is almost always what + # was intended. + return n + + def ntou(n, encoding='ISO-8859-1'): + """Return the given native string as a unicode string with the given + encoding. + """ + assert_native(n) + # In Python 2, the native string type is bytes. + # First, check for the special encoding 'escape'. The test suite uses + # this to signal that it wants to pass a string with embedded \uXXXX + # escapes, but without having to prefix it with u'' for Python 2, + # but no prefix for Python 3. + if encoding == 'escape': + return unicode( + re.sub(r'\\u([0-9a-zA-Z]{4})', + lambda m: unichr(int(m.group(1), 16)), + n.decode('ISO-8859-1'))) + # Assume it's already in the given encoding, which for ISO-8859-1 + # is almost always what was intended. + return n.decode(encoding) + + def tonative(n, encoding='ISO-8859-1'): + """Return the given string as a native string in the given encoding.""" + # In Python 2, the native string type is bytes. + if isinstance(n, unicode): + return n.encode(encoding) + return n + try: + # type("") + from cStringIO import StringIO + except ImportError: + # type("") + from StringIO import StringIO + # bytes: + BytesIO = StringIO + + +def assert_native(n): + if not isinstance(n, nativestr): + raise TypeError("n must be a native str (got %s)" % type(n).__name__) + +try: + set = set +except NameError: + from sets import Set as set + +try: + # Python 3.1+ + from base64 import decodebytes as _base64_decodebytes +except ImportError: + # Python 3.0- + # since CherryPy claims compability with Python 2.3, we must use + # the legacy API of base64 + from base64 import decodestring as _base64_decodebytes + + +def base64_decode(n, encoding='ISO-8859-1'): + """Return the native string base64-decoded (as a native string).""" + if isinstance(n, unicodestr): + b = n.encode(encoding) + else: + b = n + b = _base64_decodebytes(b) + if nativestr is unicodestr: + return b.decode(encoding) + else: + return b + +try: + # Python 2.5+ + from hashlib import md5 +except ImportError: + from md5 import new as md5 + +try: + # Python 2.5+ + from hashlib import sha1 as sha +except ImportError: + from sha import new as sha + +try: + sorted = sorted +except NameError: + def sorted(i): + i = i[:] + i.sort() + return i + +try: + reversed = reversed +except NameError: + def reversed(x): + i = len(x) + while i > 0: + i -= 1 + yield x[i] + +try: + # Python 3 + from urllib.parse import urljoin, urlencode + from urllib.parse import quote, quote_plus + from urllib.request import unquote, urlopen + from urllib.request import parse_http_list, parse_keqv_list +except ImportError: + # Python 2 + from urlparse import urljoin + from urllib import urlencode, urlopen + from urllib import quote, quote_plus + from urllib import unquote + from urllib2 import parse_http_list, parse_keqv_list + +try: + from threading import local as threadlocal +except ImportError: + from cherrypy._cpthreadinglocal import local as threadlocal + +try: + dict.iteritems + # Python 2 + iteritems = lambda d: d.iteritems() + copyitems = lambda d: d.items() +except AttributeError: + # Python 3 + iteritems = lambda d: d.items() + copyitems = lambda d: list(d.items()) + +try: + dict.iterkeys + # Python 2 + iterkeys = lambda d: d.iterkeys() + copykeys = lambda d: d.keys() +except AttributeError: + # Python 3 + iterkeys = lambda d: d.keys() + copykeys = lambda d: list(d.keys()) + +try: + dict.itervalues + # Python 2 + itervalues = lambda d: d.itervalues() + copyvalues = lambda d: d.values() +except AttributeError: + # Python 3 + itervalues = lambda d: d.values() + copyvalues = lambda d: list(d.values()) + +try: + # Python 3 + import builtins +except ImportError: + # Python 2 + import __builtin__ as builtins + +try: + # Python 2. We try Python 2 first clients on Python 2 + # don't try to import the 'http' module from cherrypy.lib + from Cookie import SimpleCookie, CookieError + from httplib import BadStatusLine, HTTPConnection, IncompleteRead + from httplib import NotConnected + from BaseHTTPServer import BaseHTTPRequestHandler +except ImportError: + # Python 3 + from http.cookies import SimpleCookie, CookieError + from http.client import BadStatusLine, HTTPConnection, IncompleteRead + from http.client import NotConnected + from http.server import BaseHTTPRequestHandler + +# Some platforms don't expose HTTPSConnection, so handle it separately +if py3k: + try: + from http.client import HTTPSConnection + except ImportError: + # Some platforms which don't have SSL don't expose HTTPSConnection + HTTPSConnection = None +else: + try: + from httplib import HTTPSConnection + except ImportError: + HTTPSConnection = None + +try: + # Python 2 + xrange = xrange +except NameError: + # Python 3 + xrange = range + +import threading +if hasattr(threading.Thread, "daemon"): + # Python 2.6+ + def get_daemon(t): + return t.daemon + + def set_daemon(t, val): + t.daemon = val +else: + def get_daemon(t): + return t.isDaemon() + + def set_daemon(t, val): + t.setDaemon(val) + +try: + from email.utils import formatdate + + def HTTPDate(timeval=None): + return formatdate(timeval, usegmt=True) +except ImportError: + from rfc822 import formatdate as HTTPDate + +try: + # Python 3 + from urllib.parse import unquote as parse_unquote + + def unquote_qs(atom, encoding, errors='strict'): + return parse_unquote( + atom.replace('+', ' '), + encoding=encoding, + errors=errors) +except ImportError: + # Python 2 + from urllib import unquote as parse_unquote + + def unquote_qs(atom, encoding, errors='strict'): + return parse_unquote(atom.replace('+', ' ')).decode(encoding, errors) + +try: + # Prefer simplejson, which is usually more advanced than the builtin + # module. + import simplejson as json + json_decode = json.JSONDecoder().decode + _json_encode = json.JSONEncoder().iterencode +except ImportError: + if sys.version_info >= (2, 6): + # Python >=2.6 : json is part of the standard library + import json + json_decode = json.JSONDecoder().decode + _json_encode = json.JSONEncoder().iterencode + else: + json = None + + def json_decode(s): + raise ValueError('No JSON library is available') + + def _json_encode(s): + raise ValueError('No JSON library is available') +finally: + if json and py3k: + # The two Python 3 implementations (simplejson/json) + # outputs str. We need bytes. + def json_encode(value): + for chunk in _json_encode(value): + yield chunk.encode('utf8') + else: + json_encode = _json_encode + + +try: + import cPickle as pickle +except ImportError: + # In Python 2, pickle is a Python version. + # In Python 3, pickle is the sped-up C version. + import pickle + +try: + os.urandom(20) + import binascii + + def random20(): + return binascii.hexlify(os.urandom(20)).decode('ascii') +except (AttributeError, NotImplementedError): + import random + # os.urandom not available until Python 2.4. Fall back to random.random. + + def random20(): + return sha('%s' % random.random()).hexdigest() + +try: + from _thread import get_ident as get_thread_ident +except ImportError: + from thread import get_ident as get_thread_ident + +try: + # Python 3 + next = next +except NameError: + # Python 2 + def next(i): + return i.next() + +if sys.version_info >= (3, 3): + Timer = threading.Timer + Event = threading.Event +else: + # Python 3.2 and earlier + Timer = threading._Timer + Event = threading._Event + +# Prior to Python 2.6, the Thread class did not have a .daemon property. +# This mix-in adds that property. + + +class SetDaemonProperty: + + def __get_daemon(self): + return self.isDaemon() + + def __set_daemon(self, daemon): + self.setDaemon(daemon) + + if sys.version_info < (2, 6): + daemon = property(__get_daemon, __set_daemon) diff --git a/lib/cherrypy/_cpcompat_subprocess.py b/lib/cherrypy/_cpcompat_subprocess.py new file mode 100644 index 00000000..478f4a74 --- /dev/null +++ b/lib/cherrypy/_cpcompat_subprocess.py @@ -0,0 +1,1544 @@ +# subprocess - Subprocesses with accessible I/O streams +# +# For more information about this module, see PEP 324. +# +# This module should remain compatible with Python 2.2, see PEP 291. +# +# Copyright (c) 2003-2005 by Peter Astrand +# +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/2.4/license for licensing details. + +r"""subprocess - Subprocesses with accessible I/O streams + +This module allows you to spawn processes, connect to their +input/output/error pipes, and obtain their return codes. This module +intends to replace several other, older modules and functions, like: + +os.system +os.spawn* +os.popen* +popen2.* +commands.* + +Information about how the subprocess module can be used to replace these +modules and functions can be found below. + + + +Using the subprocess module +=========================== +This module defines one class called Popen: + +class Popen(args, bufsize=0, executable=None, + stdin=None, stdout=None, stderr=None, + preexec_fn=None, close_fds=False, shell=False, + cwd=None, env=None, universal_newlines=False, + startupinfo=None, creationflags=0): + + +Arguments are: + +args should be a string, or a sequence of program arguments. The +program to execute is normally the first item in the args sequence or +string, but can be explicitly set by using the executable argument. + +On UNIX, with shell=False (default): In this case, the Popen class +uses os.execvp() to execute the child program. args should normally +be a sequence. A string will be treated as a sequence with the string +as the only item (the program to execute). + +On UNIX, with shell=True: If args is a string, it specifies the +command string to execute through the shell. If args is a sequence, +the first item specifies the command string, and any additional items +will be treated as additional shell arguments. + +On Windows: the Popen class uses CreateProcess() to execute the child +program, which operates on strings. If args is a sequence, it will be +converted to a string using the list2cmdline method. Please note that +not all MS Windows applications interpret the command line the same +way: The list2cmdline is designed for applications using the same +rules as the MS C runtime. + +bufsize, if given, has the same meaning as the corresponding argument +to the built-in open() function: 0 means unbuffered, 1 means line +buffered, any other positive value means use a buffer of +(approximately) that size. A negative bufsize means to use the system +default, which usually means fully buffered. The default value for +bufsize is 0 (unbuffered). + +stdin, stdout and stderr specify the executed programs' standard +input, standard output and standard error file handles, respectively. +Valid values are PIPE, an existing file descriptor (a positive +integer), an existing file object, and None. PIPE indicates that a +new pipe to the child should be created. With None, no redirection +will occur; the child's file handles will be inherited from the +parent. Additionally, stderr can be STDOUT, which indicates that the +stderr data from the applications should be captured into the same +file handle as for stdout. + +If preexec_fn is set to a callable object, this object will be called +in the child process just before the child is executed. + +If close_fds is true, all file descriptors except 0, 1 and 2 will be +closed before the child process is executed. + +if shell is true, the specified command will be executed through the +shell. + +If cwd is not None, the current directory will be changed to cwd +before the child is executed. + +If env is not None, it defines the environment variables for the new +process. + +If universal_newlines is true, the file objects stdout and stderr are +opened as a text files, but lines may be terminated by any of '\n', +the Unix end-of-line convention, '\r', the Macintosh convention or +'\r\n', the Windows convention. All of these external representations +are seen as '\n' by the Python program. Note: This feature is only +available if Python is built with universal newline support (the +default). Also, the newlines attribute of the file objects stdout, +stdin and stderr are not updated by the communicate() method. + +The startupinfo and creationflags, if given, will be passed to the +underlying CreateProcess() function. They can specify things such as +appearance of the main window and priority for the new process. +(Windows only) + + +This module also defines some shortcut functions: + +call(*popenargs, **kwargs): + Run command with arguments. Wait for command to complete, then + return the returncode attribute. + + The arguments are the same as for the Popen constructor. Example: + + retcode = call(["ls", "-l"]) + +check_call(*popenargs, **kwargs): + Run command with arguments. Wait for command to complete. If the + exit code was zero then return, otherwise raise + CalledProcessError. The CalledProcessError object will have the + return code in the returncode attribute. + + The arguments are the same as for the Popen constructor. Example: + + check_call(["ls", "-l"]) + +check_output(*popenargs, **kwargs): + Run command with arguments and return its output as a byte string. + + If the exit code was non-zero it raises a CalledProcessError. The + CalledProcessError object will have the return code in the returncode + attribute and output in the output attribute. + + The arguments are the same as for the Popen constructor. Example: + + output = check_output(["ls", "-l", "/dev/null"]) + + +Exceptions +---------- +Exceptions raised in the child process, before the new program has +started to execute, will be re-raised in the parent. Additionally, +the exception object will have one extra attribute called +'child_traceback', which is a string containing traceback information +from the childs point of view. + +The most common exception raised is OSError. This occurs, for +example, when trying to execute a non-existent file. Applications +should prepare for OSErrors. + +A ValueError will be raised if Popen is called with invalid arguments. + +check_call() and check_output() will raise CalledProcessError, if the +called process returns a non-zero return code. + + +Security +-------- +Unlike some other popen functions, this implementation will never call +/bin/sh implicitly. This means that all characters, including shell +metacharacters, can safely be passed to child processes. + + +Popen objects +============= +Instances of the Popen class have the following methods: + +poll() + Check if child process has terminated. Returns returncode + attribute. + +wait() + Wait for child process to terminate. Returns returncode attribute. + +communicate(input=None) + Interact with process: Send data to stdin. Read data from stdout + and stderr, until end-of-file is reached. Wait for process to + terminate. The optional input argument should be a string to be + sent to the child process, or None, if no data should be sent to + the child. + + communicate() returns a tuple (stdout, stderr). + + Note: The data read is buffered in memory, so do not use this + method if the data size is large or unlimited. + +The following attributes are also available: + +stdin + If the stdin argument is PIPE, this attribute is a file object + that provides input to the child process. Otherwise, it is None. + +stdout + If the stdout argument is PIPE, this attribute is a file object + that provides output from the child process. Otherwise, it is + None. + +stderr + If the stderr argument is PIPE, this attribute is file object that + provides error output from the child process. Otherwise, it is + None. + +pid + The process ID of the child process. + +returncode + The child return code. A None value indicates that the process + hasn't terminated yet. A negative value -N indicates that the + child was terminated by signal N (UNIX only). + + +Replacing older functions with the subprocess module +==================================================== +In this section, "a ==> b" means that b can be used as a replacement +for a. + +Note: All functions in this section fail (more or less) silently if +the executed program cannot be found; this module raises an OSError +exception. + +In the following examples, we assume that the subprocess module is +imported with "from subprocess import *". + + +Replacing /bin/sh shell backquote +--------------------------------- +output=`mycmd myarg` +==> +output = Popen(["mycmd", "myarg"], stdout=PIPE).communicate()[0] + + +Replacing shell pipe line +------------------------- +output=`dmesg | grep hda` +==> +p1 = Popen(["dmesg"], stdout=PIPE) +p2 = Popen(["grep", "hda"], stdin=p1.stdout, stdout=PIPE) +output = p2.communicate()[0] + + +Replacing os.system() +--------------------- +sts = os.system("mycmd" + " myarg") +==> +p = Popen("mycmd" + " myarg", shell=True) +pid, sts = os.waitpid(p.pid, 0) + +Note: + +* Calling the program through the shell is usually not required. + +* It's easier to look at the returncode attribute than the + exitstatus. + +A more real-world example would look like this: + +try: + retcode = call("mycmd" + " myarg", shell=True) + if retcode < 0: + print >>sys.stderr, "Child was terminated by signal", -retcode + else: + print >>sys.stderr, "Child returned", retcode +except OSError, e: + print >>sys.stderr, "Execution failed:", e + + +Replacing os.spawn* +------------------- +P_NOWAIT example: + +pid = os.spawnlp(os.P_NOWAIT, "/bin/mycmd", "mycmd", "myarg") +==> +pid = Popen(["/bin/mycmd", "myarg"]).pid + + +P_WAIT example: + +retcode = os.spawnlp(os.P_WAIT, "/bin/mycmd", "mycmd", "myarg") +==> +retcode = call(["/bin/mycmd", "myarg"]) + + +Vector example: + +os.spawnvp(os.P_NOWAIT, path, args) +==> +Popen([path] + args[1:]) + + +Environment example: + +os.spawnlpe(os.P_NOWAIT, "/bin/mycmd", "mycmd", "myarg", env) +==> +Popen(["/bin/mycmd", "myarg"], env={"PATH": "/usr/bin"}) + + +Replacing os.popen* +------------------- +pipe = os.popen("cmd", mode='r', bufsize) +==> +pipe = Popen("cmd", shell=True, bufsize=bufsize, stdout=PIPE).stdout + +pipe = os.popen("cmd", mode='w', bufsize) +==> +pipe = Popen("cmd", shell=True, bufsize=bufsize, stdin=PIPE).stdin + + +(child_stdin, child_stdout) = os.popen2("cmd", mode, bufsize) +==> +p = Popen("cmd", shell=True, bufsize=bufsize, + stdin=PIPE, stdout=PIPE, close_fds=True) +(child_stdin, child_stdout) = (p.stdin, p.stdout) + + +(child_stdin, + child_stdout, + child_stderr) = os.popen3("cmd", mode, bufsize) +==> +p = Popen("cmd", shell=True, bufsize=bufsize, + stdin=PIPE, stdout=PIPE, stderr=PIPE, close_fds=True) +(child_stdin, + child_stdout, + child_stderr) = (p.stdin, p.stdout, p.stderr) + + +(child_stdin, child_stdout_and_stderr) = os.popen4("cmd", mode, + bufsize) +==> +p = Popen("cmd", shell=True, bufsize=bufsize, + stdin=PIPE, stdout=PIPE, stderr=STDOUT, close_fds=True) +(child_stdin, child_stdout_and_stderr) = (p.stdin, p.stdout) + +On Unix, os.popen2, os.popen3 and os.popen4 also accept a sequence as +the command to execute, in which case arguments will be passed +directly to the program without shell intervention. This usage can be +replaced as follows: + +(child_stdin, child_stdout) = os.popen2(["/bin/ls", "-l"], mode, + bufsize) +==> +p = Popen(["/bin/ls", "-l"], bufsize=bufsize, stdin=PIPE, stdout=PIPE) +(child_stdin, child_stdout) = (p.stdin, p.stdout) + +Return code handling translates as follows: + +pipe = os.popen("cmd", 'w') +... +rc = pipe.close() +if rc is not None and rc % 256: + print "There were some errors" +==> +process = Popen("cmd", 'w', shell=True, stdin=PIPE) +... +process.stdin.close() +if process.wait() != 0: + print "There were some errors" + + +Replacing popen2.* +------------------ +(child_stdout, child_stdin) = popen2.popen2("somestring", bufsize, mode) +==> +p = Popen(["somestring"], shell=True, bufsize=bufsize + stdin=PIPE, stdout=PIPE, close_fds=True) +(child_stdout, child_stdin) = (p.stdout, p.stdin) + +On Unix, popen2 also accepts a sequence as the command to execute, in +which case arguments will be passed directly to the program without +shell intervention. This usage can be replaced as follows: + +(child_stdout, child_stdin) = popen2.popen2(["mycmd", "myarg"], bufsize, + mode) +==> +p = Popen(["mycmd", "myarg"], bufsize=bufsize, + stdin=PIPE, stdout=PIPE, close_fds=True) +(child_stdout, child_stdin) = (p.stdout, p.stdin) + +The popen2.Popen3 and popen2.Popen4 basically works as subprocess.Popen, +except that: + +* subprocess.Popen raises an exception if the execution fails +* the capturestderr argument is replaced with the stderr argument. +* stdin=PIPE and stdout=PIPE must be specified. +* popen2 closes all filedescriptors by default, but you have to specify + close_fds=True with subprocess.Popen. +""" + +import sys +mswindows = (sys.platform == "win32") + +import os +import types +import traceback +import gc +import signal +import errno + +try: + set +except NameError: + from sets import Set as set + +# Exception classes used by this module. + + +class CalledProcessError(Exception): + + """This exception is raised when a process run by check_call() or + check_output() returns a non-zero exit status. + The exit status will be stored in the returncode attribute; + check_output() will also store the output in the output attribute. + """ + + def __init__(self, returncode, cmd, output=None): + self.returncode = returncode + self.cmd = cmd + self.output = output + + def __str__(self): + return "Command '%s' returned non-zero exit status %d" % ( + self.cmd, self.returncode) + + +if mswindows: + import threading + import msvcrt + import _subprocess + + class STARTUPINFO: + dwFlags = 0 + hStdInput = None + hStdOutput = None + hStdError = None + wShowWindow = 0 + + class pywintypes: + error = IOError +else: + import select + _has_poll = hasattr(select, 'poll') + import fcntl + import pickle + + # When select or poll has indicated that the file is writable, + # we can write up to _PIPE_BUF bytes without risk of blocking. + # POSIX defines PIPE_BUF as >= 512. + _PIPE_BUF = getattr(select, 'PIPE_BUF', 512) + + +__all__ = ["Popen", "PIPE", "STDOUT", "call", "check_call", + "check_output", "CalledProcessError"] + +if mswindows: + from _subprocess import CREATE_NEW_CONSOLE, CREATE_NEW_PROCESS_GROUP, \ + STD_INPUT_HANDLE, STD_OUTPUT_HANDLE, \ + STD_ERROR_HANDLE, SW_HIDE, \ + STARTF_USESTDHANDLES, STARTF_USESHOWWINDOW + + __all__.extend(["CREATE_NEW_CONSOLE", "CREATE_NEW_PROCESS_GROUP", + "STD_INPUT_HANDLE", "STD_OUTPUT_HANDLE", + "STD_ERROR_HANDLE", "SW_HIDE", + "STARTF_USESTDHANDLES", "STARTF_USESHOWWINDOW"]) +try: + MAXFD = os.sysconf("SC_OPEN_MAX") +except: + MAXFD = 256 + +_active = [] + + +def _cleanup(): + for inst in _active[:]: + res = inst._internal_poll(_deadstate=sys.maxint) + if res is not None: + try: + _active.remove(inst) + except ValueError: + # This can happen if two threads create a new Popen instance. + # It's harmless that it was already removed, so ignore. + pass + +PIPE = -1 +STDOUT = -2 + + +def _eintr_retry_call(func, *args): + while True: + try: + return func(*args) + except (OSError, IOError), e: + if e.errno == errno.EINTR: + continue + raise + + +def call(*popenargs, **kwargs): + """Run command with arguments. Wait for command to complete, then + return the returncode attribute. + + The arguments are the same as for the Popen constructor. Example: + + retcode = call(["ls", "-l"]) + """ + return Popen(*popenargs, **kwargs).wait() + + +def check_call(*popenargs, **kwargs): + """Run command with arguments. Wait for command to complete. If + the exit code was zero then return, otherwise raise + CalledProcessError. The CalledProcessError object will have the + return code in the returncode attribute. + + The arguments are the same as for the Popen constructor. Example: + + check_call(["ls", "-l"]) + """ + retcode = call(*popenargs, **kwargs) + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise CalledProcessError(retcode, cmd) + return 0 + + +def check_output(*popenargs, **kwargs): + r"""Run command with arguments and return its output as a byte string. + + If the exit code was non-zero it raises a CalledProcessError. The + CalledProcessError object will have the return code in the returncode + attribute and output in the output attribute. + + The arguments are the same as for the Popen constructor. Example: + + >>> check_output(["ls", "-l", "/dev/null"]) + 'crw-rw-rw- 1 root root 1, 3 Oct 18 2007 /dev/null\n' + + The stdout argument is not allowed as it is used internally. + To capture standard error in the result, use stderr=STDOUT. + + >>> check_output(["/bin/sh", "-c", + ... "ls -l non_existent_file ; exit 0"], + ... stderr=STDOUT) + 'ls: non_existent_file: No such file or directory\n' + """ + if 'stdout' in kwargs: + raise ValueError('stdout argument not allowed, it will be overridden.') + process = Popen(stdout=PIPE, *popenargs, **kwargs) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise CalledProcessError(retcode, cmd, output=output) + return output + + +def list2cmdline(seq): + """ + Translate a sequence of arguments into a command line + string, using the same rules as the MS C runtime: + + 1) Arguments are delimited by white space, which is either a + space or a tab. + + 2) A string surrounded by double quotation marks is + interpreted as a single argument, regardless of white space + contained within. A quoted string can be embedded in an + argument. + + 3) A double quotation mark preceded by a backslash is + interpreted as a literal double quotation mark. + + 4) Backslashes are interpreted literally, unless they + immediately precede a double quotation mark. + + 5) If backslashes immediately precede a double quotation mark, + every pair of backslashes is interpreted as a literal + backslash. If the number of backslashes is odd, the last + backslash escapes the next double quotation mark as + described in rule 3. + """ + + # See + # http://msdn.microsoft.com/en-us/library/17w5ykft.aspx + # or search http://msdn.microsoft.com for + # "Parsing C++ Command-Line Arguments" + result = [] + needquote = False + for arg in seq: + bs_buf = [] + + # Add a space to separate this argument from the others + if result: + result.append(' ') + + needquote = (" " in arg) or ("\t" in arg) or not arg + if needquote: + result.append('"') + + for c in arg: + if c == '\\': + # Don't know if we need to double yet. + bs_buf.append(c) + elif c == '"': + # Double backslashes. + result.append('\\' * len(bs_buf) * 2) + bs_buf = [] + result.append('\\"') + else: + # Normal char + if bs_buf: + result.extend(bs_buf) + bs_buf = [] + result.append(c) + + # Add remaining backslashes, if any. + if bs_buf: + result.extend(bs_buf) + + if needquote: + result.extend(bs_buf) + result.append('"') + + return ''.join(result) + + +class Popen(object): + + def __init__(self, args, bufsize=0, executable=None, + stdin=None, stdout=None, stderr=None, + preexec_fn=None, close_fds=False, shell=False, + cwd=None, env=None, universal_newlines=False, + startupinfo=None, creationflags=0): + """Create new Popen instance.""" + _cleanup() + + self._child_created = False + if not isinstance(bufsize, (int, long)): + raise TypeError("bufsize must be an integer") + + if mswindows: + if preexec_fn is not None: + raise ValueError("preexec_fn is not supported on Windows " + "platforms") + if close_fds and (stdin is not None or stdout is not None or + stderr is not None): + raise ValueError("close_fds is not supported on Windows " + "platforms if you redirect " + "stdin/stdout/stderr") + else: + # POSIX + if startupinfo is not None: + raise ValueError("startupinfo is only supported on Windows " + "platforms") + if creationflags != 0: + raise ValueError("creationflags is only supported on Windows " + "platforms") + + self.stdin = None + self.stdout = None + self.stderr = None + self.pid = None + self.returncode = None + self.universal_newlines = universal_newlines + + # Input and output objects. The general principle is like + # this: + # + # Parent Child + # ------ ----- + # p2cwrite ---stdin---> p2cread + # c2pread <--stdout--- c2pwrite + # errread <--stderr--- errwrite + # + # On POSIX, the child objects are file descriptors. On + # Windows, these are Windows file handles. The parent objects + # are file descriptors on both platforms. The parent objects + # are None when not using PIPEs. The child objects are None + # when not redirecting. + + (p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite) = self._get_handles(stdin, stdout, stderr) + + self._execute_child(args, executable, preexec_fn, close_fds, + cwd, env, universal_newlines, + startupinfo, creationflags, shell, + p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite) + + if mswindows: + if p2cwrite is not None: + p2cwrite = msvcrt.open_osfhandle(p2cwrite.Detach(), 0) + if c2pread is not None: + c2pread = msvcrt.open_osfhandle(c2pread.Detach(), 0) + if errread is not None: + errread = msvcrt.open_osfhandle(errread.Detach(), 0) + + if p2cwrite is not None: + self.stdin = os.fdopen(p2cwrite, 'wb', bufsize) + if c2pread is not None: + if universal_newlines: + self.stdout = os.fdopen(c2pread, 'rU', bufsize) + else: + self.stdout = os.fdopen(c2pread, 'rb', bufsize) + if errread is not None: + if universal_newlines: + self.stderr = os.fdopen(errread, 'rU', bufsize) + else: + self.stderr = os.fdopen(errread, 'rb', bufsize) + + def _translate_newlines(self, data): + data = data.replace("\r\n", "\n") + data = data.replace("\r", "\n") + return data + + def __del__(self, _maxint=sys.maxint, _active=_active): + # If __init__ hasn't had a chance to execute (e.g. if it + # was passed an undeclared keyword argument), we don't + # have a _child_created attribute at all. + if not getattr(self, '_child_created', False): + # We didn't get to successfully create a child process. + return + # In case the child hasn't been waited on, check if it's done. + self._internal_poll(_deadstate=_maxint) + if self.returncode is None and _active is not None: + # Child is still running, keep us alive until we can wait on it. + _active.append(self) + + def communicate(self, input=None): + """Interact with process: Send data to stdin. Read data from + stdout and stderr, until end-of-file is reached. Wait for + process to terminate. The optional input argument should be a + string to be sent to the child process, or None, if no data + should be sent to the child. + + communicate() returns a tuple (stdout, stderr).""" + + # Optimization: If we are only using one pipe, or no pipe at + # all, using select() or threads is unnecessary. + if [self.stdin, self.stdout, self.stderr].count(None) >= 2: + stdout = None + stderr = None + if self.stdin: + if input: + try: + self.stdin.write(input) + except IOError, e: + if e.errno != errno.EPIPE and e.errno != errno.EINVAL: + raise + self.stdin.close() + elif self.stdout: + stdout = _eintr_retry_call(self.stdout.read) + self.stdout.close() + elif self.stderr: + stderr = _eintr_retry_call(self.stderr.read) + self.stderr.close() + self.wait() + return (stdout, stderr) + + return self._communicate(input) + + def poll(self): + return self._internal_poll() + + if mswindows: + # + # Windows methods + # + def _get_handles(self, stdin, stdout, stderr): + """Construct and return tuple with IO objects: + p2cread, p2cwrite, c2pread, c2pwrite, errread, errwrite + """ + if stdin is None and stdout is None and stderr is None: + return (None, None, None, None, None, None) + + p2cread, p2cwrite = None, None + c2pread, c2pwrite = None, None + errread, errwrite = None, None + + if stdin is None: + p2cread = _subprocess.GetStdHandle( + _subprocess.STD_INPUT_HANDLE) + if p2cread is None: + p2cread, _ = _subprocess.CreatePipe(None, 0) + elif stdin == PIPE: + p2cread, p2cwrite = _subprocess.CreatePipe(None, 0) + elif isinstance(stdin, int): + p2cread = msvcrt.get_osfhandle(stdin) + else: + # Assuming file-like object + p2cread = msvcrt.get_osfhandle(stdin.fileno()) + p2cread = self._make_inheritable(p2cread) + + if stdout is None: + c2pwrite = _subprocess.GetStdHandle( + _subprocess.STD_OUTPUT_HANDLE) + if c2pwrite is None: + _, c2pwrite = _subprocess.CreatePipe(None, 0) + elif stdout == PIPE: + c2pread, c2pwrite = _subprocess.CreatePipe(None, 0) + elif isinstance(stdout, int): + c2pwrite = msvcrt.get_osfhandle(stdout) + else: + # Assuming file-like object + c2pwrite = msvcrt.get_osfhandle(stdout.fileno()) + c2pwrite = self._make_inheritable(c2pwrite) + + if stderr is None: + errwrite = _subprocess.GetStdHandle( + _subprocess.STD_ERROR_HANDLE) + if errwrite is None: + _, errwrite = _subprocess.CreatePipe(None, 0) + elif stderr == PIPE: + errread, errwrite = _subprocess.CreatePipe(None, 0) + elif stderr == STDOUT: + errwrite = c2pwrite + elif isinstance(stderr, int): + errwrite = msvcrt.get_osfhandle(stderr) + else: + # Assuming file-like object + errwrite = msvcrt.get_osfhandle(stderr.fileno()) + errwrite = self._make_inheritable(errwrite) + + return (p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite) + + def _make_inheritable(self, handle): + """Return a duplicate of handle, which is inheritable""" + return _subprocess.DuplicateHandle( + _subprocess.GetCurrentProcess(), + handle, + _subprocess.GetCurrentProcess(), + 0, + 1, + _subprocess.DUPLICATE_SAME_ACCESS + ) + + def _find_w9xpopen(self): + """Find and return absolut path to w9xpopen.exe""" + w9xpopen = os.path.join( + os.path.dirname(_subprocess.GetModuleFileName(0)), + "w9xpopen.exe") + if not os.path.exists(w9xpopen): + # Eeek - file-not-found - possibly an embedding + # situation - see if we can locate it in sys.exec_prefix + w9xpopen = os.path.join(os.path.dirname(sys.exec_prefix), + "w9xpopen.exe") + if not os.path.exists(w9xpopen): + raise RuntimeError("Cannot locate w9xpopen.exe, which is " + "needed for Popen to work with your " + "shell or platform.") + return w9xpopen + + def _execute_child(self, args, executable, preexec_fn, close_fds, + cwd, env, universal_newlines, + startupinfo, creationflags, shell, + p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite): + """Execute program (MS Windows version)""" + + if not isinstance(args, types.StringTypes): + args = list2cmdline(args) + + # Process startup details + if startupinfo is None: + startupinfo = STARTUPINFO() + if None not in (p2cread, c2pwrite, errwrite): + startupinfo.dwFlags |= _subprocess.STARTF_USESTDHANDLES + startupinfo.hStdInput = p2cread + startupinfo.hStdOutput = c2pwrite + startupinfo.hStdError = errwrite + + if shell: + startupinfo.dwFlags |= _subprocess.STARTF_USESHOWWINDOW + startupinfo.wShowWindow = _subprocess.SW_HIDE + comspec = os.environ.get("COMSPEC", "cmd.exe") + args = '{} /c "{}"'.format(comspec, args) + if (_subprocess.GetVersion() >= 0x80000000 or + os.path.basename(comspec).lower() == "command.com"): + # Win9x, or using command.com on NT. We need to + # use the w9xpopen intermediate program. For more + # information, see KB Q150956 + # (http://web.archive.org/web/20011105084002/http://support.microsoft.com/support/kb/articles/Q150/9/56.asp) + w9xpopen = self._find_w9xpopen() + args = '"%s" %s' % (w9xpopen, args) + # Not passing CREATE_NEW_CONSOLE has been known to + # cause random failures on win9x. Specifically a + # dialog: "Your program accessed mem currently in + # use at xxx" and a hopeful warning about the + # stability of your system. Cost is Ctrl+C wont + # kill children. + creationflags |= _subprocess.CREATE_NEW_CONSOLE + + # Start the process + try: + try: + hp, ht, pid, tid = _subprocess.CreateProcess( + executable, args, + # no special + # security + None, None, + int(not close_fds), + creationflags, + env, + cwd, + startupinfo) + except pywintypes.error, e: + # Translate pywintypes.error to WindowsError, which is + # a subclass of OSError. FIXME: We should really + # translate errno using _sys_errlist (or similar), but + # how can this be done from Python? + raise WindowsError(*e.args) + finally: + # Child is launched. Close the parent's copy of those pipe + # handles that only the child should have open. You need + # to make sure that no handles to the write end of the + # output pipe are maintained in this process or else the + # pipe will not close when the child process exits and the + # ReadFile will hang. + if p2cread is not None: + p2cread.Close() + if c2pwrite is not None: + c2pwrite.Close() + if errwrite is not None: + errwrite.Close() + + # Retain the process handle, but close the thread handle + self._child_created = True + self._handle = hp + self.pid = pid + ht.Close() + + def _internal_poll( + self, _deadstate=None, + _WaitForSingleObject=_subprocess.WaitForSingleObject, + _WAIT_OBJECT_0=_subprocess.WAIT_OBJECT_0, + _GetExitCodeProcess=_subprocess.GetExitCodeProcess + ): + """Check if child process has terminated. Returns returncode + attribute. + + This method is called by __del__, so it can only refer to objects + in its local scope. + + """ + if self.returncode is None: + if _WaitForSingleObject(self._handle, 0) == _WAIT_OBJECT_0: + self.returncode = _GetExitCodeProcess(self._handle) + return self.returncode + + def wait(self): + """Wait for child process to terminate. Returns returncode + attribute.""" + if self.returncode is None: + _subprocess.WaitForSingleObject(self._handle, + _subprocess.INFINITE) + self.returncode = _subprocess.GetExitCodeProcess(self._handle) + return self.returncode + + def _readerthread(self, fh, buffer): + buffer.append(fh.read()) + + def _communicate(self, input): + stdout = None # Return + stderr = None # Return + + if self.stdout: + stdout = [] + stdout_thread = threading.Thread(target=self._readerthread, + args=(self.stdout, stdout)) + stdout_thread.setDaemon(True) + stdout_thread.start() + if self.stderr: + stderr = [] + stderr_thread = threading.Thread(target=self._readerthread, + args=(self.stderr, stderr)) + stderr_thread.setDaemon(True) + stderr_thread.start() + + if self.stdin: + if input is not None: + try: + self.stdin.write(input) + except IOError, e: + if e.errno != errno.EPIPE: + raise + self.stdin.close() + + if self.stdout: + stdout_thread.join() + if self.stderr: + stderr_thread.join() + + # All data exchanged. Translate lists into strings. + if stdout is not None: + stdout = stdout[0] + if stderr is not None: + stderr = stderr[0] + + # Translate newlines, if requested. We cannot let the file + # object do the translation: It is based on stdio, which is + # impossible to combine with select (unless forcing no + # buffering). + if self.universal_newlines and hasattr(file, 'newlines'): + if stdout: + stdout = self._translate_newlines(stdout) + if stderr: + stderr = self._translate_newlines(stderr) + + self.wait() + return (stdout, stderr) + + def send_signal(self, sig): + """Send a signal to the process + """ + if sig == signal.SIGTERM: + self.terminate() + elif sig == signal.CTRL_C_EVENT: + os.kill(self.pid, signal.CTRL_C_EVENT) + elif sig == signal.CTRL_BREAK_EVENT: + os.kill(self.pid, signal.CTRL_BREAK_EVENT) + else: + raise ValueError("Unsupported signal: {}".format(sig)) + + def terminate(self): + """Terminates the process + """ + _subprocess.TerminateProcess(self._handle, 1) + + kill = terminate + + else: + # + # POSIX methods + # + def _get_handles(self, stdin, stdout, stderr): + """Construct and return tuple with IO objects: + p2cread, p2cwrite, c2pread, c2pwrite, errread, errwrite + """ + p2cread, p2cwrite = None, None + c2pread, c2pwrite = None, None + errread, errwrite = None, None + + if stdin is None: + pass + elif stdin == PIPE: + p2cread, p2cwrite = self.pipe_cloexec() + elif isinstance(stdin, int): + p2cread = stdin + else: + # Assuming file-like object + p2cread = stdin.fileno() + + if stdout is None: + pass + elif stdout == PIPE: + c2pread, c2pwrite = self.pipe_cloexec() + elif isinstance(stdout, int): + c2pwrite = stdout + else: + # Assuming file-like object + c2pwrite = stdout.fileno() + + if stderr is None: + pass + elif stderr == PIPE: + errread, errwrite = self.pipe_cloexec() + elif stderr == STDOUT: + errwrite = c2pwrite + elif isinstance(stderr, int): + errwrite = stderr + else: + # Assuming file-like object + errwrite = stderr.fileno() + + return (p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite) + + def _set_cloexec_flag(self, fd, cloexec=True): + try: + cloexec_flag = fcntl.FD_CLOEXEC + except AttributeError: + cloexec_flag = 1 + + old = fcntl.fcntl(fd, fcntl.F_GETFD) + if cloexec: + fcntl.fcntl(fd, fcntl.F_SETFD, old | cloexec_flag) + else: + fcntl.fcntl(fd, fcntl.F_SETFD, old & ~cloexec_flag) + + def pipe_cloexec(self): + """Create a pipe with FDs set CLOEXEC.""" + # Pipes' FDs are set CLOEXEC by default because we don't want them + # to be inherited by other subprocesses: the CLOEXEC flag is + # removed from the child's FDs by _dup2(), between fork() and + # exec(). + # This is not atomic: we would need the pipe2() syscall for that. + r, w = os.pipe() + self._set_cloexec_flag(r) + self._set_cloexec_flag(w) + return r, w + + def _close_fds(self, but): + if hasattr(os, 'closerange'): + os.closerange(3, but) + os.closerange(but + 1, MAXFD) + else: + for i in xrange(3, MAXFD): + if i == but: + continue + try: + os.close(i) + except: + pass + + def _execute_child(self, args, executable, preexec_fn, close_fds, + cwd, env, universal_newlines, + startupinfo, creationflags, shell, + p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite): + """Execute program (POSIX version)""" + + if isinstance(args, types.StringTypes): + args = [args] + else: + args = list(args) + + if shell: + args = ["/bin/sh", "-c"] + args + if executable: + args[0] = executable + + if executable is None: + executable = args[0] + + # For transferring possible exec failure from child to parent + # The first char specifies the exception type: 0 means + # OSError, 1 means some other error. + errpipe_read, errpipe_write = self.pipe_cloexec() + try: + try: + gc_was_enabled = gc.isenabled() + # Disable gc to avoid bug where gc -> file_dealloc -> + # write to stderr -> hang. + # http://bugs.python.org/issue1336 + gc.disable() + try: + self.pid = os.fork() + except: + if gc_was_enabled: + gc.enable() + raise + self._child_created = True + if self.pid == 0: + # Child + try: + # Close parent's pipe ends + if p2cwrite is not None: + os.close(p2cwrite) + if c2pread is not None: + os.close(c2pread) + if errread is not None: + os.close(errread) + os.close(errpipe_read) + + # When duping fds, if there arises a situation + # where one of the fds is either 0, 1 or 2, it + # is possible that it is overwritten (#12607). + if c2pwrite == 0: + c2pwrite = os.dup(c2pwrite) + if errwrite == 0 or errwrite == 1: + errwrite = os.dup(errwrite) + + # Dup fds for child + def _dup2(a, b): + # dup2() removes the CLOEXEC flag but + # we must do it ourselves if dup2() + # would be a no-op (issue #10806). + if a == b: + self._set_cloexec_flag(a, False) + elif a is not None: + os.dup2(a, b) + _dup2(p2cread, 0) + _dup2(c2pwrite, 1) + _dup2(errwrite, 2) + + # Close pipe fds. Make sure we don't close the + # same fd more than once, or standard fds. + closed = set([None]) + for fd in [p2cread, c2pwrite, errwrite]: + if fd not in closed and fd > 2: + os.close(fd) + closed.add(fd) + + # Close all other fds, if asked for + if close_fds: + self._close_fds(but=errpipe_write) + + if cwd is not None: + os.chdir(cwd) + + if preexec_fn: + preexec_fn() + + if env is None: + os.execvp(executable, args) + else: + os.execvpe(executable, args, env) + + except: + exc_type, exc_value, tb = sys.exc_info() + # Save the traceback and attach it to the exception + # object + exc_lines = traceback.format_exception(exc_type, + exc_value, + tb) + exc_value.child_traceback = ''.join(exc_lines) + os.write(errpipe_write, pickle.dumps(exc_value)) + + # This exitcode won't be reported to applications, + # so it really doesn't matter what we return. + os._exit(255) + + # Parent + if gc_was_enabled: + gc.enable() + finally: + # be sure the FD is closed no matter what + os.close(errpipe_write) + + if p2cread is not None and p2cwrite is not None: + os.close(p2cread) + if c2pwrite is not None and c2pread is not None: + os.close(c2pwrite) + if errwrite is not None and errread is not None: + os.close(errwrite) + + # Wait for exec to fail or succeed; possibly raising exception + # Exception limited to 1M + data = _eintr_retry_call(os.read, errpipe_read, 1048576) + finally: + # be sure the FD is closed no matter what + os.close(errpipe_read) + + if data != "": + try: + _eintr_retry_call(os.waitpid, self.pid, 0) + except OSError, e: + if e.errno != errno.ECHILD: + raise + child_exception = pickle.loads(data) + for fd in (p2cwrite, c2pread, errread): + if fd is not None: + os.close(fd) + raise child_exception + + def _handle_exitstatus(self, sts, _WIFSIGNALED=os.WIFSIGNALED, + _WTERMSIG=os.WTERMSIG, _WIFEXITED=os.WIFEXITED, + _WEXITSTATUS=os.WEXITSTATUS): + # This method is called (indirectly) by __del__, so it cannot + # refer to anything outside of its local scope.""" + if _WIFSIGNALED(sts): + self.returncode = -_WTERMSIG(sts) + elif _WIFEXITED(sts): + self.returncode = _WEXITSTATUS(sts) + else: + # Should never happen + raise RuntimeError("Unknown child exit status!") + + def _internal_poll(self, _deadstate=None, _waitpid=os.waitpid, + _WNOHANG=os.WNOHANG, _os_error=os.error): + """Check if child process has terminated. Returns returncode + attribute. + + This method is called by __del__, so it cannot reference anything + outside of the local scope (nor can any methods it calls). + + """ + if self.returncode is None: + try: + pid, sts = _waitpid(self.pid, _WNOHANG) + if pid == self.pid: + self._handle_exitstatus(sts) + except _os_error: + if _deadstate is not None: + self.returncode = _deadstate + return self.returncode + + def wait(self): + """Wait for child process to terminate. Returns returncode + attribute.""" + if self.returncode is None: + try: + pid, sts = _eintr_retry_call(os.waitpid, self.pid, 0) + except OSError, e: + if e.errno != errno.ECHILD: + raise + # This happens if SIGCLD is set to be ignored or waiting + # for child processes has otherwise been disabled for our + # process. This child is dead, we can't get the status. + sts = 0 + self._handle_exitstatus(sts) + return self.returncode + + def _communicate(self, input): + if self.stdin: + # Flush stdio buffer. This might block, if the user has + # been writing to .stdin in an uncontrolled fashion. + self.stdin.flush() + if not input: + self.stdin.close() + + if _has_poll: + stdout, stderr = self._communicate_with_poll(input) + else: + stdout, stderr = self._communicate_with_select(input) + + # All data exchanged. Translate lists into strings. + if stdout is not None: + stdout = ''.join(stdout) + if stderr is not None: + stderr = ''.join(stderr) + + # Translate newlines, if requested. We cannot let the file + # object do the translation: It is based on stdio, which is + # impossible to combine with select (unless forcing no + # buffering). + if self.universal_newlines and hasattr(file, 'newlines'): + if stdout: + stdout = self._translate_newlines(stdout) + if stderr: + stderr = self._translate_newlines(stderr) + + self.wait() + return (stdout, stderr) + + def _communicate_with_poll(self, input): + stdout = None # Return + stderr = None # Return + fd2file = {} + fd2output = {} + + poller = select.poll() + + def register_and_append(file_obj, eventmask): + poller.register(file_obj.fileno(), eventmask) + fd2file[file_obj.fileno()] = file_obj + + def close_unregister_and_remove(fd): + poller.unregister(fd) + fd2file[fd].close() + fd2file.pop(fd) + + if self.stdin and input: + register_and_append(self.stdin, select.POLLOUT) + + select_POLLIN_POLLPRI = select.POLLIN | select.POLLPRI + if self.stdout: + register_and_append(self.stdout, select_POLLIN_POLLPRI) + fd2output[self.stdout.fileno()] = stdout = [] + if self.stderr: + register_and_append(self.stderr, select_POLLIN_POLLPRI) + fd2output[self.stderr.fileno()] = stderr = [] + + input_offset = 0 + while fd2file: + try: + ready = poller.poll() + except select.error, e: + if e.args[0] == errno.EINTR: + continue + raise + + for fd, mode in ready: + if mode & select.POLLOUT: + chunk = input[input_offset: input_offset + _PIPE_BUF] + try: + input_offset += os.write(fd, chunk) + except OSError, e: + if e.errno == errno.EPIPE: + close_unregister_and_remove(fd) + else: + raise + else: + if input_offset >= len(input): + close_unregister_and_remove(fd) + elif mode & select_POLLIN_POLLPRI: + data = os.read(fd, 4096) + if not data: + close_unregister_and_remove(fd) + fd2output[fd].append(data) + else: + # Ignore hang up or errors. + close_unregister_and_remove(fd) + + return (stdout, stderr) + + def _communicate_with_select(self, input): + read_set = [] + write_set = [] + stdout = None # Return + stderr = None # Return + + if self.stdin and input: + write_set.append(self.stdin) + if self.stdout: + read_set.append(self.stdout) + stdout = [] + if self.stderr: + read_set.append(self.stderr) + stderr = [] + + input_offset = 0 + while read_set or write_set: + try: + rlist, wlist, xlist = select.select( + read_set, write_set, []) + except select.error, e: + if e.args[0] == errno.EINTR: + continue + raise + + if self.stdin in wlist: + chunk = input[input_offset: input_offset + _PIPE_BUF] + try: + bytes_written = os.write(self.stdin.fileno(), chunk) + except OSError, e: + if e.errno == errno.EPIPE: + self.stdin.close() + write_set.remove(self.stdin) + else: + raise + else: + input_offset += bytes_written + if input_offset >= len(input): + self.stdin.close() + write_set.remove(self.stdin) + + if self.stdout in rlist: + data = os.read(self.stdout.fileno(), 1024) + if data == "": + self.stdout.close() + read_set.remove(self.stdout) + stdout.append(data) + + if self.stderr in rlist: + data = os.read(self.stderr.fileno(), 1024) + if data == "": + self.stderr.close() + read_set.remove(self.stderr) + stderr.append(data) + + return (stdout, stderr) + + def send_signal(self, sig): + """Send a signal to the process + """ + os.kill(self.pid, sig) + + def terminate(self): + """Terminate the process with SIGTERM + """ + self.send_signal(signal.SIGTERM) + + def kill(self): + """Kill the process with SIGKILL + """ + self.send_signal(signal.SIGKILL) + + +def _demo_posix(): + # + # Example 1: Simple redirection: Get process list + # + plist = Popen(["ps"], stdout=PIPE).communicate()[0] + print "Process list:" + print plist + + # + # Example 2: Change uid before executing child + # + if os.getuid() == 0: + p = Popen(["id"], preexec_fn=lambda: os.setuid(100)) + p.wait() + + # + # Example 3: Connecting several subprocesses + # + print "Looking for 'hda'..." + p1 = Popen(["dmesg"], stdout=PIPE) + p2 = Popen(["grep", "hda"], stdin=p1.stdout, stdout=PIPE) + print repr(p2.communicate()[0]) + + # + # Example 4: Catch execution error + # + print + print "Trying a weird file..." + try: + print Popen(["/this/path/does/not/exist"]).communicate() + except OSError, e: + if e.errno == errno.ENOENT: + print "The file didn't exist. I thought so..." + print "Child traceback:" + print e.child_traceback + else: + print "Error", e.errno + else: + print >>sys.stderr, "Gosh. No error." + + +def _demo_windows(): + # + # Example 1: Connecting several subprocesses + # + print "Looking for 'PROMPT' in set output..." + p1 = Popen("set", stdout=PIPE, shell=True) + p2 = Popen('find "PROMPT"', stdin=p1.stdout, stdout=PIPE) + print repr(p2.communicate()[0]) + + # + # Example 2: Simple execution of program + # + print "Executing calc..." + p = Popen("calc") + p.wait() + + +if __name__ == "__main__": + if mswindows: + _demo_windows() + else: + _demo_posix() diff --git a/lib/cherrypy/_cpconfig.py b/lib/cherrypy/_cpconfig.py new file mode 100644 index 00000000..c11bc1d1 --- /dev/null +++ b/lib/cherrypy/_cpconfig.py @@ -0,0 +1,317 @@ +""" +Configuration system for CherryPy. + +Configuration in CherryPy is implemented via dictionaries. Keys are strings +which name the mapped value, which may be of any type. + + +Architecture +------------ + +CherryPy Requests are part of an Application, which runs in a global context, +and configuration data may apply to any of those three scopes: + +Global + Configuration entries which apply everywhere are stored in + cherrypy.config. + +Application + Entries which apply to each mounted application are stored + on the Application object itself, as 'app.config'. This is a two-level + dict where each key is a path, or "relative URL" (for example, "/" or + "/path/to/my/page"), and each value is a config dict. Usually, this + data is provided in the call to tree.mount(root(), config=conf), + although you may also use app.merge(conf). + +Request + Each Request object possesses a single 'Request.config' dict. + Early in the request process, this dict is populated by merging global + config entries, Application entries (whose path equals or is a parent + of Request.path_info), and any config acquired while looking up the + page handler (see next). + + +Declaration +----------- + +Configuration data may be supplied as a Python dictionary, as a filename, +or as an open file object. When you supply a filename or file, CherryPy +uses Python's builtin ConfigParser; you declare Application config by +writing each path as a section header:: + + [/path/to/my/page] + request.stream = True + +To declare global configuration entries, place them in a [global] section. + +You may also declare config entries directly on the classes and methods +(page handlers) that make up your CherryPy application via the ``_cp_config`` +attribute. For example:: + + class Demo: + _cp_config = {'tools.gzip.on': True} + + def index(self): + return "Hello world" + index.exposed = True + index._cp_config = {'request.show_tracebacks': False} + +.. note:: + + This behavior is only guaranteed for the default dispatcher. + Other dispatchers may have different restrictions on where + you can attach _cp_config attributes. + + +Namespaces +---------- + +Configuration keys are separated into namespaces by the first "." in the key. +Current namespaces: + +engine + Controls the 'application engine', including autoreload. + These can only be declared in the global config. + +tree + Grafts cherrypy.Application objects onto cherrypy.tree. + These can only be declared in the global config. + +hooks + Declares additional request-processing functions. + +log + Configures the logging for each application. + These can only be declared in the global or / config. + +request + Adds attributes to each Request. + +response + Adds attributes to each Response. + +server + Controls the default HTTP server via cherrypy.server. + These can only be declared in the global config. + +tools + Runs and configures additional request-processing packages. + +wsgi + Adds WSGI middleware to an Application's "pipeline". + These can only be declared in the app's root config ("/"). + +checker + Controls the 'checker', which looks for common errors in + app state (including config) when the engine starts. + Global config only. + +The only key that does not exist in a namespace is the "environment" entry. +This special entry 'imports' other config entries from a template stored in +cherrypy._cpconfig.environments[environment]. It only applies to the global +config, and only when you use cherrypy.config.update. + +You can define your own namespaces to be called at the Global, Application, +or Request level, by adding a named handler to cherrypy.config.namespaces, +app.namespaces, or app.request_class.namespaces. The name can +be any string, and the handler must be either a callable or a (Python 2.5 +style) context manager. +""" + +import cherrypy +from cherrypy._cpcompat import set, basestring +from cherrypy.lib import reprconf + +# Deprecated in CherryPy 3.2--remove in 3.3 +NamespaceSet = reprconf.NamespaceSet + + +def merge(base, other): + """Merge one app config (from a dict, file, or filename) into another. + + If the given config is a filename, it will be appended to + the list of files to monitor for "autoreload" changes. + """ + if isinstance(other, basestring): + cherrypy.engine.autoreload.files.add(other) + + # Load other into base + for section, value_map in reprconf.as_dict(other).items(): + if not isinstance(value_map, dict): + raise ValueError( + "Application config must include section headers, but the " + "config you tried to merge doesn't have any sections. " + "Wrap your config in another dict with paths as section " + "headers, for example: {'/': config}.") + base.setdefault(section, {}).update(value_map) + + +class Config(reprconf.Config): + + """The 'global' configuration data for the entire CherryPy process.""" + + def update(self, config): + """Update self from a dict, file or filename.""" + if isinstance(config, basestring): + # Filename + cherrypy.engine.autoreload.files.add(config) + reprconf.Config.update(self, config) + + def _apply(self, config): + """Update self from a dict.""" + if isinstance(config.get("global"), dict): + if len(config) > 1: + cherrypy.checker.global_config_contained_paths = True + config = config["global"] + if 'tools.staticdir.dir' in config: + config['tools.staticdir.section'] = "global" + reprconf.Config._apply(self, config) + + def __call__(self, *args, **kwargs): + """Decorator for page handlers to set _cp_config.""" + if args: + raise TypeError( + "The cherrypy.config decorator does not accept positional " + "arguments; you must use keyword arguments.") + + def tool_decorator(f): + if not hasattr(f, "_cp_config"): + f._cp_config = {} + for k, v in kwargs.items(): + f._cp_config[k] = v + return f + return tool_decorator + + +# Sphinx begin config.environments +Config.environments = environments = { + "staging": { + 'engine.autoreload.on': False, + 'checker.on': False, + 'tools.log_headers.on': False, + 'request.show_tracebacks': False, + 'request.show_mismatched_params': False, + }, + "production": { + 'engine.autoreload.on': False, + 'checker.on': False, + 'tools.log_headers.on': False, + 'request.show_tracebacks': False, + 'request.show_mismatched_params': False, + 'log.screen': False, + }, + "embedded": { + # For use with CherryPy embedded in another deployment stack. + 'engine.autoreload.on': False, + 'checker.on': False, + 'tools.log_headers.on': False, + 'request.show_tracebacks': False, + 'request.show_mismatched_params': False, + 'log.screen': False, + 'engine.SIGHUP': None, + 'engine.SIGTERM': None, + }, + "test_suite": { + 'engine.autoreload.on': False, + 'checker.on': False, + 'tools.log_headers.on': False, + 'request.show_tracebacks': True, + 'request.show_mismatched_params': True, + 'log.screen': False, + }, +} +# Sphinx end config.environments + + +def _server_namespace_handler(k, v): + """Config handler for the "server" namespace.""" + atoms = k.split(".", 1) + if len(atoms) > 1: + # Special-case config keys of the form 'server.servername.socket_port' + # to configure additional HTTP servers. + if not hasattr(cherrypy, "servers"): + cherrypy.servers = {} + + servername, k = atoms + if servername not in cherrypy.servers: + from cherrypy import _cpserver + cherrypy.servers[servername] = _cpserver.Server() + # On by default, but 'on = False' can unsubscribe it (see below). + cherrypy.servers[servername].subscribe() + + if k == 'on': + if v: + cherrypy.servers[servername].subscribe() + else: + cherrypy.servers[servername].unsubscribe() + else: + setattr(cherrypy.servers[servername], k, v) + else: + setattr(cherrypy.server, k, v) +Config.namespaces["server"] = _server_namespace_handler + + +def _engine_namespace_handler(k, v): + """Backward compatibility handler for the "engine" namespace.""" + engine = cherrypy.engine + + deprecated = { + 'autoreload_on': 'autoreload.on', + 'autoreload_frequency': 'autoreload.frequency', + 'autoreload_match': 'autoreload.match', + 'reload_files': 'autoreload.files', + 'deadlock_poll_freq': 'timeout_monitor.frequency' + } + + if k in deprecated: + engine.log( + 'WARNING: Use of engine.%s is deprecated and will be removed in a ' + 'future version. Use engine.%s instead.' % (k, deprecated[k])) + + if k == 'autoreload_on': + if v: + engine.autoreload.subscribe() + else: + engine.autoreload.unsubscribe() + elif k == 'autoreload_frequency': + engine.autoreload.frequency = v + elif k == 'autoreload_match': + engine.autoreload.match = v + elif k == 'reload_files': + engine.autoreload.files = set(v) + elif k == 'deadlock_poll_freq': + engine.timeout_monitor.frequency = v + elif k == 'SIGHUP': + engine.listeners['SIGHUP'] = set([v]) + elif k == 'SIGTERM': + engine.listeners['SIGTERM'] = set([v]) + elif "." in k: + plugin, attrname = k.split(".", 1) + plugin = getattr(engine, plugin) + if attrname == 'on': + if v and hasattr(getattr(plugin, 'subscribe', None), '__call__'): + plugin.subscribe() + return + elif ( + (not v) and + hasattr(getattr(plugin, 'unsubscribe', None), '__call__') + ): + plugin.unsubscribe() + return + setattr(plugin, attrname, v) + else: + setattr(engine, k, v) +Config.namespaces["engine"] = _engine_namespace_handler + + +def _tree_namespace_handler(k, v): + """Namespace handler for the 'tree' config namespace.""" + if isinstance(v, dict): + for script_name, app in v.items(): + cherrypy.tree.graft(app, script_name) + cherrypy.engine.log("Mounted: %s on %s" % + (app, script_name or "/")) + else: + cherrypy.tree.graft(v, v.script_name) + cherrypy.engine.log("Mounted: %s on %s" % (v, v.script_name or "/")) +Config.namespaces["tree"] = _tree_namespace_handler diff --git a/lib/cherrypy/_cpdispatch.py b/lib/cherrypy/_cpdispatch.py new file mode 100644 index 00000000..1c2d7df8 --- /dev/null +++ b/lib/cherrypy/_cpdispatch.py @@ -0,0 +1,686 @@ +"""CherryPy dispatchers. + +A 'dispatcher' is the object which looks up the 'page handler' callable +and collects config for the current request based on the path_info, other +request attributes, and the application architecture. The core calls the +dispatcher as early as possible, passing it a 'path_info' argument. + +The default dispatcher discovers the page handler by matching path_info +to a hierarchical arrangement of objects, starting at request.app.root. +""" + +import string +import sys +import types +try: + classtype = (type, types.ClassType) +except AttributeError: + classtype = type + +import cherrypy +from cherrypy._cpcompat import set + + +class PageHandler(object): + + """Callable which sets response.body.""" + + def __init__(self, callable, *args, **kwargs): + self.callable = callable + self.args = args + self.kwargs = kwargs + + def get_args(self): + return cherrypy.serving.request.args + + def set_args(self, args): + cherrypy.serving.request.args = args + return cherrypy.serving.request.args + + args = property( + get_args, + set_args, + doc="The ordered args should be accessible from post dispatch hooks" + ) + + def get_kwargs(self): + return cherrypy.serving.request.kwargs + + def set_kwargs(self, kwargs): + cherrypy.serving.request.kwargs = kwargs + return cherrypy.serving.request.kwargs + + kwargs = property( + get_kwargs, + set_kwargs, + doc="The named kwargs should be accessible from post dispatch hooks" + ) + + def __call__(self): + try: + return self.callable(*self.args, **self.kwargs) + except TypeError: + x = sys.exc_info()[1] + try: + test_callable_spec(self.callable, self.args, self.kwargs) + except cherrypy.HTTPError: + raise sys.exc_info()[1] + except: + raise x + raise + + +def test_callable_spec(callable, callable_args, callable_kwargs): + """ + Inspect callable and test to see if the given args are suitable for it. + + When an error occurs during the handler's invoking stage there are 2 + erroneous cases: + 1. Too many parameters passed to a function which doesn't define + one of *args or **kwargs. + 2. Too little parameters are passed to the function. + + There are 3 sources of parameters to a cherrypy handler. + 1. query string parameters are passed as keyword parameters to the + handler. + 2. body parameters are also passed as keyword parameters. + 3. when partial matching occurs, the final path atoms are passed as + positional args. + Both the query string and path atoms are part of the URI. If they are + incorrect, then a 404 Not Found should be raised. Conversely the body + parameters are part of the request; if they are invalid a 400 Bad Request. + """ + show_mismatched_params = getattr( + cherrypy.serving.request, 'show_mismatched_params', False) + try: + (args, varargs, varkw, defaults) = getargspec(callable) + except TypeError: + if isinstance(callable, object) and hasattr(callable, '__call__'): + (args, varargs, varkw, + defaults) = getargspec(callable.__call__) + else: + # If it wasn't one of our own types, re-raise + # the original error + raise + + if args and args[0] == 'self': + args = args[1:] + + arg_usage = dict([(arg, 0,) for arg in args]) + vararg_usage = 0 + varkw_usage = 0 + extra_kwargs = set() + + for i, value in enumerate(callable_args): + try: + arg_usage[args[i]] += 1 + except IndexError: + vararg_usage += 1 + + for key in callable_kwargs.keys(): + try: + arg_usage[key] += 1 + except KeyError: + varkw_usage += 1 + extra_kwargs.add(key) + + # figure out which args have defaults. + args_with_defaults = args[-len(defaults or []):] + for i, val in enumerate(defaults or []): + # Defaults take effect only when the arg hasn't been used yet. + if arg_usage[args_with_defaults[i]] == 0: + arg_usage[args_with_defaults[i]] += 1 + + missing_args = [] + multiple_args = [] + for key, usage in arg_usage.items(): + if usage == 0: + missing_args.append(key) + elif usage > 1: + multiple_args.append(key) + + if missing_args: + # In the case where the method allows body arguments + # there are 3 potential errors: + # 1. not enough query string parameters -> 404 + # 2. not enough body parameters -> 400 + # 3. not enough path parts (partial matches) -> 404 + # + # We can't actually tell which case it is, + # so I'm raising a 404 because that covers 2/3 of the + # possibilities + # + # In the case where the method does not allow body + # arguments it's definitely a 404. + message = None + if show_mismatched_params: + message = "Missing parameters: %s" % ",".join(missing_args) + raise cherrypy.HTTPError(404, message=message) + + # the extra positional arguments come from the path - 404 Not Found + if not varargs and vararg_usage > 0: + raise cherrypy.HTTPError(404) + + body_params = cherrypy.serving.request.body.params or {} + body_params = set(body_params.keys()) + qs_params = set(callable_kwargs.keys()) - body_params + + if multiple_args: + if qs_params.intersection(set(multiple_args)): + # If any of the multiple parameters came from the query string then + # it's a 404 Not Found + error = 404 + else: + # Otherwise it's a 400 Bad Request + error = 400 + + message = None + if show_mismatched_params: + message = "Multiple values for parameters: "\ + "%s" % ",".join(multiple_args) + raise cherrypy.HTTPError(error, message=message) + + if not varkw and varkw_usage > 0: + + # If there were extra query string parameters, it's a 404 Not Found + extra_qs_params = set(qs_params).intersection(extra_kwargs) + if extra_qs_params: + message = None + if show_mismatched_params: + message = "Unexpected query string "\ + "parameters: %s" % ", ".join(extra_qs_params) + raise cherrypy.HTTPError(404, message=message) + + # If there were any extra body parameters, it's a 400 Not Found + extra_body_params = set(body_params).intersection(extra_kwargs) + if extra_body_params: + message = None + if show_mismatched_params: + message = "Unexpected body parameters: "\ + "%s" % ", ".join(extra_body_params) + raise cherrypy.HTTPError(400, message=message) + + +try: + import inspect +except ImportError: + test_callable_spec = lambda callable, args, kwargs: None +else: + getargspec = inspect.getargspec + # Python 3 requires using getfullargspec if keyword-only arguments are present + if hasattr(inspect, 'getfullargspec'): + def getargspec(callable): + return inspect.getfullargspec(callable)[:4] + + +class LateParamPageHandler(PageHandler): + + """When passing cherrypy.request.params to the page handler, we do not + want to capture that dict too early; we want to give tools like the + decoding tool a chance to modify the params dict in-between the lookup + of the handler and the actual calling of the handler. This subclass + takes that into account, and allows request.params to be 'bound late' + (it's more complicated than that, but that's the effect). + """ + + def _get_kwargs(self): + kwargs = cherrypy.serving.request.params.copy() + if self._kwargs: + kwargs.update(self._kwargs) + return kwargs + + def _set_kwargs(self, kwargs): + cherrypy.serving.request.kwargs = kwargs + self._kwargs = kwargs + + kwargs = property(_get_kwargs, _set_kwargs, + doc='page handler kwargs (with ' + 'cherrypy.request.params copied in)') + + +if sys.version_info < (3, 0): + punctuation_to_underscores = string.maketrans( + string.punctuation, '_' * len(string.punctuation)) + + def validate_translator(t): + if not isinstance(t, str) or len(t) != 256: + raise ValueError( + "The translate argument must be a str of len 256.") +else: + punctuation_to_underscores = str.maketrans( + string.punctuation, '_' * len(string.punctuation)) + + def validate_translator(t): + if not isinstance(t, dict): + raise ValueError("The translate argument must be a dict.") + + +class Dispatcher(object): + + """CherryPy Dispatcher which walks a tree of objects to find a handler. + + The tree is rooted at cherrypy.request.app.root, and each hierarchical + component in the path_info argument is matched to a corresponding nested + attribute of the root object. Matching handlers must have an 'exposed' + attribute which evaluates to True. The special method name "index" + matches a URI which ends in a slash ("/"). The special method name + "default" may match a portion of the path_info (but only when no longer + substring of the path_info matches some other object). + + This is the default, built-in dispatcher for CherryPy. + """ + + dispatch_method_name = '_cp_dispatch' + """ + The name of the dispatch method that nodes may optionally implement + to provide their own dynamic dispatch algorithm. + """ + + def __init__(self, dispatch_method_name=None, + translate=punctuation_to_underscores): + validate_translator(translate) + self.translate = translate + if dispatch_method_name: + self.dispatch_method_name = dispatch_method_name + + def __call__(self, path_info): + """Set handler and config for the current request.""" + request = cherrypy.serving.request + func, vpath = self.find_handler(path_info) + + if func: + # Decode any leftover %2F in the virtual_path atoms. + vpath = [x.replace("%2F", "/") for x in vpath] + request.handler = LateParamPageHandler(func, *vpath) + else: + request.handler = cherrypy.NotFound() + + def find_handler(self, path): + """Return the appropriate page handler, plus any virtual path. + + This will return two objects. The first will be a callable, + which can be used to generate page output. Any parameters from + the query string or request body will be sent to that callable + as keyword arguments. + + The callable is found by traversing the application's tree, + starting from cherrypy.request.app.root, and matching path + components to successive objects in the tree. For example, the + URL "/path/to/handler" might return root.path.to.handler. + + The second object returned will be a list of names which are + 'virtual path' components: parts of the URL which are dynamic, + and were not used when looking up the handler. + These virtual path components are passed to the handler as + positional arguments. + """ + request = cherrypy.serving.request + app = request.app + root = app.root + dispatch_name = self.dispatch_method_name + + # Get config for the root object/path. + fullpath = [x for x in path.strip('/').split('/') if x] + ['index'] + fullpath_len = len(fullpath) + segleft = fullpath_len + nodeconf = {} + if hasattr(root, "_cp_config"): + nodeconf.update(root._cp_config) + if "/" in app.config: + nodeconf.update(app.config["/"]) + object_trail = [['root', root, nodeconf, segleft]] + + node = root + iternames = fullpath[:] + while iternames: + name = iternames[0] + # map to legal Python identifiers (e.g. replace '.' with '_') + objname = name.translate(self.translate) + + nodeconf = {} + subnode = getattr(node, objname, None) + pre_len = len(iternames) + if subnode is None: + dispatch = getattr(node, dispatch_name, None) + if dispatch and hasattr(dispatch, '__call__') and not \ + getattr(dispatch, 'exposed', False) and \ + pre_len > 1: + # Don't expose the hidden 'index' token to _cp_dispatch + # We skip this if pre_len == 1 since it makes no sense + # to call a dispatcher when we have no tokens left. + index_name = iternames.pop() + subnode = dispatch(vpath=iternames) + iternames.append(index_name) + else: + # We didn't find a path, but keep processing in case there + # is a default() handler. + iternames.pop(0) + else: + # We found the path, remove the vpath entry + iternames.pop(0) + segleft = len(iternames) + if segleft > pre_len: + # No path segment was removed. Raise an error. + raise cherrypy.CherryPyException( + "A vpath segment was added. Custom dispatchers may only " + + "remove elements. While trying to process " + + "{0} in {1}".format(name, fullpath) + ) + elif segleft == pre_len: + # Assume that the handler used the current path segment, but + # did not pop it. This allows things like + # return getattr(self, vpath[0], None) + iternames.pop(0) + segleft -= 1 + node = subnode + + if node is not None: + # Get _cp_config attached to this node. + if hasattr(node, "_cp_config"): + nodeconf.update(node._cp_config) + + # Mix in values from app.config for this path. + existing_len = fullpath_len - pre_len + if existing_len != 0: + curpath = '/' + '/'.join(fullpath[0:existing_len]) + else: + curpath = '' + new_segs = fullpath[fullpath_len - pre_len:fullpath_len - segleft] + for seg in new_segs: + curpath += '/' + seg + if curpath in app.config: + nodeconf.update(app.config[curpath]) + + object_trail.append([name, node, nodeconf, segleft]) + + def set_conf(): + """Collapse all object_trail config into cherrypy.request.config. + """ + base = cherrypy.config.copy() + # Note that we merge the config from each node + # even if that node was None. + for name, obj, conf, segleft in object_trail: + base.update(conf) + if 'tools.staticdir.dir' in conf: + base['tools.staticdir.section'] = '/' + \ + '/'.join(fullpath[0:fullpath_len - segleft]) + return base + + # Try successive objects (reverse order) + num_candidates = len(object_trail) - 1 + for i in range(num_candidates, -1, -1): + + name, candidate, nodeconf, segleft = object_trail[i] + if candidate is None: + continue + + # Try a "default" method on the current leaf. + if hasattr(candidate, "default"): + defhandler = candidate.default + if getattr(defhandler, 'exposed', False): + # Insert any extra _cp_config from the default handler. + conf = getattr(defhandler, "_cp_config", {}) + object_trail.insert( + i + 1, ["default", defhandler, conf, segleft]) + request.config = set_conf() + # See https://bitbucket.org/cherrypy/cherrypy/issue/613 + request.is_index = path.endswith("/") + return defhandler, fullpath[fullpath_len - segleft:-1] + + # Uncomment the next line to restrict positional params to + # "default". + # if i < num_candidates - 2: continue + + # Try the current leaf. + if getattr(candidate, 'exposed', False): + request.config = set_conf() + if i == num_candidates: + # We found the extra ".index". Mark request so tools + # can redirect if path_info has no trailing slash. + request.is_index = True + else: + # We're not at an 'index' handler. Mark request so tools + # can redirect if path_info has NO trailing slash. + # Note that this also includes handlers which take + # positional parameters (virtual paths). + request.is_index = False + return candidate, fullpath[fullpath_len - segleft:-1] + + # We didn't find anything + request.config = set_conf() + return None, [] + + +class MethodDispatcher(Dispatcher): + + """Additional dispatch based on cherrypy.request.method.upper(). + + Methods named GET, POST, etc will be called on an exposed class. + The method names must be all caps; the appropriate Allow header + will be output showing all capitalized method names as allowable + HTTP verbs. + + Note that the containing class must be exposed, not the methods. + """ + + def __call__(self, path_info): + """Set handler and config for the current request.""" + request = cherrypy.serving.request + resource, vpath = self.find_handler(path_info) + + if resource: + # Set Allow header + avail = [m for m in dir(resource) if m.isupper()] + if "GET" in avail and "HEAD" not in avail: + avail.append("HEAD") + avail.sort() + cherrypy.serving.response.headers['Allow'] = ", ".join(avail) + + # Find the subhandler + meth = request.method.upper() + func = getattr(resource, meth, None) + if func is None and meth == "HEAD": + func = getattr(resource, "GET", None) + if func: + # Grab any _cp_config on the subhandler. + if hasattr(func, "_cp_config"): + request.config.update(func._cp_config) + + # Decode any leftover %2F in the virtual_path atoms. + vpath = [x.replace("%2F", "/") for x in vpath] + request.handler = LateParamPageHandler(func, *vpath) + else: + request.handler = cherrypy.HTTPError(405) + else: + request.handler = cherrypy.NotFound() + + +class RoutesDispatcher(object): + + """A Routes based dispatcher for CherryPy.""" + + def __init__(self, full_result=False, **mapper_options): + """ + Routes dispatcher + + Set full_result to True if you wish the controller + and the action to be passed on to the page handler + parameters. By default they won't be. + """ + import routes + self.full_result = full_result + self.controllers = {} + self.mapper = routes.Mapper(**mapper_options) + self.mapper.controller_scan = self.controllers.keys + + def connect(self, name, route, controller, **kwargs): + self.controllers[name] = controller + self.mapper.connect(name, route, controller=name, **kwargs) + + def redirect(self, url): + raise cherrypy.HTTPRedirect(url) + + def __call__(self, path_info): + """Set handler and config for the current request.""" + func = self.find_handler(path_info) + if func: + cherrypy.serving.request.handler = LateParamPageHandler(func) + else: + cherrypy.serving.request.handler = cherrypy.NotFound() + + def find_handler(self, path_info): + """Find the right page handler, and set request.config.""" + import routes + + request = cherrypy.serving.request + + config = routes.request_config() + config.mapper = self.mapper + if hasattr(request, 'wsgi_environ'): + config.environ = request.wsgi_environ + config.host = request.headers.get('Host', None) + config.protocol = request.scheme + config.redirect = self.redirect + + result = self.mapper.match(path_info) + + config.mapper_dict = result + params = {} + if result: + params = result.copy() + if not self.full_result: + params.pop('controller', None) + params.pop('action', None) + request.params.update(params) + + # Get config for the root object/path. + request.config = base = cherrypy.config.copy() + curpath = "" + + def merge(nodeconf): + if 'tools.staticdir.dir' in nodeconf: + nodeconf['tools.staticdir.section'] = curpath or "/" + base.update(nodeconf) + + app = request.app + root = app.root + if hasattr(root, "_cp_config"): + merge(root._cp_config) + if "/" in app.config: + merge(app.config["/"]) + + # Mix in values from app.config. + atoms = [x for x in path_info.split("/") if x] + if atoms: + last = atoms.pop() + else: + last = None + for atom in atoms: + curpath = "/".join((curpath, atom)) + if curpath in app.config: + merge(app.config[curpath]) + + handler = None + if result: + controller = result.get('controller') + controller = self.controllers.get(controller, controller) + if controller: + if isinstance(controller, classtype): + controller = controller() + # Get config from the controller. + if hasattr(controller, "_cp_config"): + merge(controller._cp_config) + + action = result.get('action') + if action is not None: + handler = getattr(controller, action, None) + # Get config from the handler + if hasattr(handler, "_cp_config"): + merge(handler._cp_config) + else: + handler = controller + + # Do the last path atom here so it can + # override the controller's _cp_config. + if last: + curpath = "/".join((curpath, last)) + if curpath in app.config: + merge(app.config[curpath]) + + return handler + + +def XMLRPCDispatcher(next_dispatcher=Dispatcher()): + from cherrypy.lib import xmlrpcutil + + def xmlrpc_dispatch(path_info): + path_info = xmlrpcutil.patched_path(path_info) + return next_dispatcher(path_info) + return xmlrpc_dispatch + + +def VirtualHost(next_dispatcher=Dispatcher(), use_x_forwarded_host=True, + **domains): + """ + Select a different handler based on the Host header. + + This can be useful when running multiple sites within one CP server. + It allows several domains to point to different parts of a single + website structure. For example:: + + http://www.domain.example -> root + http://www.domain2.example -> root/domain2/ + http://www.domain2.example:443 -> root/secure + + can be accomplished via the following config:: + + [/] + request.dispatch = cherrypy.dispatch.VirtualHost( + **{'www.domain2.example': '/domain2', + 'www.domain2.example:443': '/secure', + }) + + next_dispatcher + The next dispatcher object in the dispatch chain. + The VirtualHost dispatcher adds a prefix to the URL and calls + another dispatcher. Defaults to cherrypy.dispatch.Dispatcher(). + + use_x_forwarded_host + If True (the default), any "X-Forwarded-Host" + request header will be used instead of the "Host" header. This + is commonly added by HTTP servers (such as Apache) when proxying. + + ``**domains`` + A dict of {host header value: virtual prefix} pairs. + The incoming "Host" request header is looked up in this dict, + and, if a match is found, the corresponding "virtual prefix" + value will be prepended to the URL path before calling the + next dispatcher. Note that you often need separate entries + for "example.com" and "www.example.com". In addition, "Host" + headers may contain the port number. + """ + from cherrypy.lib import httputil + + def vhost_dispatch(path_info): + request = cherrypy.serving.request + header = request.headers.get + + domain = header('Host', '') + if use_x_forwarded_host: + domain = header("X-Forwarded-Host", domain) + + prefix = domains.get(domain, "") + if prefix: + path_info = httputil.urljoin(prefix, path_info) + + result = next_dispatcher(path_info) + + # Touch up staticdir config. See + # https://bitbucket.org/cherrypy/cherrypy/issue/614. + section = request.config.get('tools.staticdir.section') + if section: + section = section[len(prefix):] + request.config['tools.staticdir.section'] = section + + return result + return vhost_dispatch diff --git a/lib/cherrypy/_cperror.py b/lib/cherrypy/_cperror.py new file mode 100644 index 00000000..6256595b --- /dev/null +++ b/lib/cherrypy/_cperror.py @@ -0,0 +1,609 @@ +"""Exception classes for CherryPy. + +CherryPy provides (and uses) exceptions for declaring that the HTTP response +should be a status other than the default "200 OK". You can ``raise`` them like +normal Python exceptions. You can also call them and they will raise +themselves; this means you can set an +:class:`HTTPError` +or :class:`HTTPRedirect` as the +:attr:`request.handler`. + +.. _redirectingpost: + +Redirecting POST +================ + +When you GET a resource and are redirected by the server to another Location, +there's generally no problem since GET is both a "safe method" (there should +be no side-effects) and an "idempotent method" (multiple calls are no different +than a single call). + +POST, however, is neither safe nor idempotent--if you +charge a credit card, you don't want to be charged twice by a redirect! + +For this reason, *none* of the 3xx responses permit a user-agent (browser) to +resubmit a POST on redirection without first confirming the action with the +user: + +===== ================================= =========== +300 Multiple Choices Confirm with the user +301 Moved Permanently Confirm with the user +302 Found (Object moved temporarily) Confirm with the user +303 See Other GET the new URI--no confirmation +304 Not modified (for conditional GET only--POST should not raise this error) +305 Use Proxy Confirm with the user +307 Temporary Redirect Confirm with the user +===== ================================= =========== + +However, browsers have historically implemented these restrictions poorly; +in particular, many browsers do not force the user to confirm 301, 302 +or 307 when redirecting POST. For this reason, CherryPy defaults to 303, +which most user-agents appear to have implemented correctly. Therefore, if +you raise HTTPRedirect for a POST request, the user-agent will most likely +attempt to GET the new URI (without asking for confirmation from the user). +We realize this is confusing for developers, but it's the safest thing we +could do. You are of course free to raise ``HTTPRedirect(uri, status=302)`` +or any other 3xx status if you know what you're doing, but given the +environment, we couldn't let any of those be the default. + +Custom Error Handling +===================== + +.. image:: /refman/cperrors.gif + +Anticipated HTTP responses +-------------------------- + +The 'error_page' config namespace can be used to provide custom HTML output for +expected responses (like 404 Not Found). Supply a filename from which the +output will be read. The contents will be interpolated with the values +%(status)s, %(message)s, %(traceback)s, and %(version)s using plain old Python +`string formatting `_. + +:: + + _cp_config = { + 'error_page.404': os.path.join(localDir, "static/index.html") + } + + +Beginning in version 3.1, you may also provide a function or other callable as +an error_page entry. It will be passed the same status, message, traceback and +version arguments that are interpolated into templates:: + + def error_page_402(status, message, traceback, version): + return "Error %s - Well, I'm very sorry but you haven't paid!" % status + cherrypy.config.update({'error_page.402': error_page_402}) + +Also in 3.1, in addition to the numbered error codes, you may also supply +"error_page.default" to handle all codes which do not have their own error_page +entry. + + + +Unanticipated errors +-------------------- + +CherryPy also has a generic error handling mechanism: whenever an unanticipated +error occurs in your code, it will call +:func:`Request.error_response` to +set the response status, headers, and body. By default, this is the same +output as +:class:`HTTPError(500) `. If you want to provide +some other behavior, you generally replace "request.error_response". + +Here is some sample code that shows how to display a custom error message and +send an e-mail containing the error:: + + from cherrypy import _cperror + + def handle_error(): + cherrypy.response.status = 500 + cherrypy.response.body = [ + "Sorry, an error occured" + ] + sendMail('error@domain.com', + 'Error in your web app', + _cperror.format_exc()) + + class Root: + _cp_config = {'request.error_response': handle_error} + + +Note that you have to explicitly set +:attr:`response.body ` +and not simply return an error message as a result. +""" + +from cgi import escape as _escape +from sys import exc_info as _exc_info +from traceback import format_exception as _format_exception +from cherrypy._cpcompat import basestring, bytestr, iteritems, ntob +from cherrypy._cpcompat import tonative, urljoin as _urljoin +from cherrypy.lib import httputil as _httputil + + +class CherryPyException(Exception): + + """A base class for CherryPy exceptions.""" + pass + + +class TimeoutError(CherryPyException): + + """Exception raised when Response.timed_out is detected.""" + pass + + +class InternalRedirect(CherryPyException): + + """Exception raised to switch to the handler for a different URL. + + This exception will redirect processing to another path within the site + (without informing the client). Provide the new path as an argument when + raising the exception. Provide any params in the querystring for the new + URL. + """ + + def __init__(self, path, query_string=""): + import cherrypy + self.request = cherrypy.serving.request + + self.query_string = query_string + if "?" in path: + # Separate any params included in the path + path, self.query_string = path.split("?", 1) + + # Note that urljoin will "do the right thing" whether url is: + # 1. a URL relative to root (e.g. "/dummy") + # 2. a URL relative to the current path + # Note that any query string will be discarded. + path = _urljoin(self.request.path_info, path) + + # Set a 'path' member attribute so that code which traps this + # error can have access to it. + self.path = path + + CherryPyException.__init__(self, path, self.query_string) + + +class HTTPRedirect(CherryPyException): + + """Exception raised when the request should be redirected. + + This exception will force a HTTP redirect to the URL or URL's you give it. + The new URL must be passed as the first argument to the Exception, + e.g., HTTPRedirect(newUrl). Multiple URLs are allowed in a list. + If a URL is absolute, it will be used as-is. If it is relative, it is + assumed to be relative to the current cherrypy.request.path_info. + + If one of the provided URL is a unicode object, it will be encoded + using the default encoding or the one passed in parameter. + + There are multiple types of redirect, from which you can select via the + ``status`` argument. If you do not provide a ``status`` arg, it defaults to + 303 (or 302 if responding with HTTP/1.0). + + Examples:: + + raise cherrypy.HTTPRedirect("") + raise cherrypy.HTTPRedirect("/abs/path", 307) + raise cherrypy.HTTPRedirect(["path1", "path2?a=1&b=2"], 301) + + See :ref:`redirectingpost` for additional caveats. + """ + + status = None + """The integer HTTP status code to emit.""" + + urls = None + """The list of URL's to emit.""" + + encoding = 'utf-8' + """The encoding when passed urls are not native strings""" + + def __init__(self, urls, status=None, encoding=None): + import cherrypy + request = cherrypy.serving.request + + if isinstance(urls, basestring): + urls = [urls] + + abs_urls = [] + for url in urls: + url = tonative(url, encoding or self.encoding) + + # Note that urljoin will "do the right thing" whether url is: + # 1. a complete URL with host (e.g. "http://www.example.com/test") + # 2. a URL relative to root (e.g. "/dummy") + # 3. a URL relative to the current path + # Note that any query string in cherrypy.request is discarded. + url = _urljoin(cherrypy.url(), url) + abs_urls.append(url) + self.urls = abs_urls + + # RFC 2616 indicates a 301 response code fits our goal; however, + # browser support for 301 is quite messy. Do 302/303 instead. See + # http://www.alanflavell.org.uk/www/post-redirect.html + if status is None: + if request.protocol >= (1, 1): + status = 303 + else: + status = 302 + else: + status = int(status) + if status < 300 or status > 399: + raise ValueError("status must be between 300 and 399.") + + self.status = status + CherryPyException.__init__(self, abs_urls, status) + + def set_response(self): + """Modify cherrypy.response status, headers, and body to represent + self. + + CherryPy uses this internally, but you can also use it to create an + HTTPRedirect object and set its output without *raising* the exception. + """ + import cherrypy + response = cherrypy.serving.response + response.status = status = self.status + + if status in (300, 301, 302, 303, 307): + response.headers['Content-Type'] = "text/html;charset=utf-8" + # "The ... URI SHOULD be given by the Location field + # in the response." + response.headers['Location'] = self.urls[0] + + # "Unless the request method was HEAD, the entity of the response + # SHOULD contain a short hypertext note with a hyperlink to the + # new URI(s)." + msg = { + 300: "This resource can be found at ", + 301: "This resource has permanently moved to ", + 302: "This resource resides temporarily at ", + 303: "This resource can be found at ", + 307: "This resource has moved temporarily to ", + }[status] + msg += '%s.' + from xml.sax import saxutils + msgs = [msg % (saxutils.quoteattr(u), u) for u in self.urls] + response.body = ntob("
\n".join(msgs), 'utf-8') + # Previous code may have set C-L, so we have to reset it + # (allow finalize to set it). + response.headers.pop('Content-Length', None) + elif status == 304: + # Not Modified. + # "The response MUST include the following header fields: + # Date, unless its omission is required by section 14.18.1" + # The "Date" header should have been set in Response.__init__ + + # "...the response SHOULD NOT include other entity-headers." + for key in ('Allow', 'Content-Encoding', 'Content-Language', + 'Content-Length', 'Content-Location', 'Content-MD5', + 'Content-Range', 'Content-Type', 'Expires', + 'Last-Modified'): + if key in response.headers: + del response.headers[key] + + # "The 304 response MUST NOT contain a message-body." + response.body = None + # Previous code may have set C-L, so we have to reset it. + response.headers.pop('Content-Length', None) + elif status == 305: + # Use Proxy. + # self.urls[0] should be the URI of the proxy. + response.headers['Location'] = self.urls[0] + response.body = None + # Previous code may have set C-L, so we have to reset it. + response.headers.pop('Content-Length', None) + else: + raise ValueError("The %s status code is unknown." % status) + + def __call__(self): + """Use this exception as a request.handler (raise self).""" + raise self + + +def clean_headers(status): + """Remove any headers which should not apply to an error response.""" + import cherrypy + + response = cherrypy.serving.response + + # Remove headers which applied to the original content, + # but do not apply to the error page. + respheaders = response.headers + for key in ["Accept-Ranges", "Age", "ETag", "Location", "Retry-After", + "Vary", "Content-Encoding", "Content-Length", "Expires", + "Content-Location", "Content-MD5", "Last-Modified"]: + if key in respheaders: + del respheaders[key] + + if status != 416: + # A server sending a response with status code 416 (Requested + # range not satisfiable) SHOULD include a Content-Range field + # with a byte-range-resp-spec of "*". The instance-length + # specifies the current length of the selected resource. + # A response with status code 206 (Partial Content) MUST NOT + # include a Content-Range field with a byte-range- resp-spec of "*". + if "Content-Range" in respheaders: + del respheaders["Content-Range"] + + +class HTTPError(CherryPyException): + + """Exception used to return an HTTP error code (4xx-5xx) to the client. + + This exception can be used to automatically send a response using a + http status code, with an appropriate error page. It takes an optional + ``status`` argument (which must be between 400 and 599); it defaults to 500 + ("Internal Server Error"). It also takes an optional ``message`` argument, + which will be returned in the response body. See + `RFC2616 `_ + for a complete list of available error codes and when to use them. + + Examples:: + + raise cherrypy.HTTPError(403) + raise cherrypy.HTTPError( + "403 Forbidden", "You are not allowed to access this resource.") + """ + + status = None + """The HTTP status code. May be of type int or str (with a Reason-Phrase). + """ + + code = None + """The integer HTTP status code.""" + + reason = None + """The HTTP Reason-Phrase string.""" + + def __init__(self, status=500, message=None): + self.status = status + try: + self.code, self.reason, defaultmsg = _httputil.valid_status(status) + except ValueError: + raise self.__class__(500, _exc_info()[1].args[0]) + + if self.code < 400 or self.code > 599: + raise ValueError("status must be between 400 and 599.") + + # See http://www.python.org/dev/peps/pep-0352/ + # self.message = message + self._message = message or defaultmsg + CherryPyException.__init__(self, status, message) + + def set_response(self): + """Modify cherrypy.response status, headers, and body to represent + self. + + CherryPy uses this internally, but you can also use it to create an + HTTPError object and set its output without *raising* the exception. + """ + import cherrypy + + response = cherrypy.serving.response + + clean_headers(self.code) + + # In all cases, finalize will be called after this method, + # so don't bother cleaning up response values here. + response.status = self.status + tb = None + if cherrypy.serving.request.show_tracebacks: + tb = format_exc() + + response.headers.pop('Content-Length', None) + + content = self.get_error_page(self.status, traceback=tb, + message=self._message) + response.body = content + + _be_ie_unfriendly(self.code) + + def get_error_page(self, *args, **kwargs): + return get_error_page(*args, **kwargs) + + def __call__(self): + """Use this exception as a request.handler (raise self).""" + raise self + + +class NotFound(HTTPError): + + """Exception raised when a URL could not be mapped to any handler (404). + + This is equivalent to raising + :class:`HTTPError("404 Not Found") `. + """ + + def __init__(self, path=None): + if path is None: + import cherrypy + request = cherrypy.serving.request + path = request.script_name + request.path_info + self.args = (path,) + HTTPError.__init__(self, 404, "The path '%s' was not found." % path) + + +_HTTPErrorTemplate = ''' + + + + %(status)s + + + +

%(status)s

+

%(message)s

+
%(traceback)s
+
+ + Powered by CherryPy %(version)s + +
+ + +''' + + +def get_error_page(status, **kwargs): + """Return an HTML page, containing a pretty error response. + + status should be an int or a str. + kwargs will be interpolated into the page template. + """ + import cherrypy + + try: + code, reason, message = _httputil.valid_status(status) + except ValueError: + raise cherrypy.HTTPError(500, _exc_info()[1].args[0]) + + # We can't use setdefault here, because some + # callers send None for kwarg values. + if kwargs.get('status') is None: + kwargs['status'] = "%s %s" % (code, reason) + if kwargs.get('message') is None: + kwargs['message'] = message + if kwargs.get('traceback') is None: + kwargs['traceback'] = '' + if kwargs.get('version') is None: + kwargs['version'] = cherrypy.__version__ + + for k, v in iteritems(kwargs): + if v is None: + kwargs[k] = "" + else: + kwargs[k] = _escape(kwargs[k]) + + # Use a custom template or callable for the error page? + pages = cherrypy.serving.request.error_page + error_page = pages.get(code) or pages.get('default') + + # Default template, can be overridden below. + template = _HTTPErrorTemplate + if error_page: + try: + if hasattr(error_page, '__call__'): + # The caller function may be setting headers manually, + # so we delegate to it completely. We may be returning + # an iterator as well as a string here. + # + # We *must* make sure any content is not unicode. + result = error_page(**kwargs) + if cherrypy.lib.is_iterator(result): + from cherrypy.lib.encoding import UTF8StreamEncoder + return UTF8StreamEncoder(result) + elif isinstance(result, cherrypy._cpcompat.unicodestr): + return result.encode('utf-8') + else: + if not isinstance(result, cherrypy._cpcompat.bytestr): + raise ValueError('error page function did not ' + 'return a bytestring, unicodestring or an ' + 'iterator - returned object of type %s.' + % (type(result).__name__)) + return result + else: + # Load the template from this path. + template = tonative(open(error_page, 'rb').read()) + except: + e = _format_exception(*_exc_info())[-1] + m = kwargs['message'] + if m: + m += "
" + m += "In addition, the custom error page failed:\n
%s" % e + kwargs['message'] = m + + response = cherrypy.serving.response + response.headers['Content-Type'] = "text/html;charset=utf-8" + result = template % kwargs + return result.encode('utf-8') + + + +_ie_friendly_error_sizes = { + 400: 512, 403: 256, 404: 512, 405: 256, + 406: 512, 408: 512, 409: 512, 410: 256, + 500: 512, 501: 512, 505: 512, +} + + +def _be_ie_unfriendly(status): + import cherrypy + response = cherrypy.serving.response + + # For some statuses, Internet Explorer 5+ shows "friendly error + # messages" instead of our response.body if the body is smaller + # than a given size. Fix this by returning a body over that size + # (by adding whitespace). + # See http://support.microsoft.com/kb/q218155/ + s = _ie_friendly_error_sizes.get(status, 0) + if s: + s += 1 + # Since we are issuing an HTTP error status, we assume that + # the entity is short, and we should just collapse it. + content = response.collapse_body() + l = len(content) + if l and l < s: + # IN ADDITION: the response must be written to IE + # in one chunk or it will still get replaced! Bah. + content = content + (ntob(" ") * (s - l)) + response.body = content + response.headers['Content-Length'] = str(len(content)) + + +def format_exc(exc=None): + """Return exc (or sys.exc_info if None), formatted.""" + try: + if exc is None: + exc = _exc_info() + if exc == (None, None, None): + return "" + import traceback + return "".join(traceback.format_exception(*exc)) + finally: + del exc + + +def bare_error(extrabody=None): + """Produce status, headers, body for a critical error. + + Returns a triple without calling any other questionable functions, + so it should be as error-free as possible. Call it from an HTTP server + if you get errors outside of the request. + + If extrabody is None, a friendly but rather unhelpful error message + is set in the body. If extrabody is a string, it will be appended + as-is to the body. + """ + + # The whole point of this function is to be a last line-of-defense + # in handling errors. That is, it must not raise any errors itself; + # it cannot be allowed to fail. Therefore, don't add to it! + # In particular, don't call any other CP functions. + + body = ntob("Unrecoverable error in the server.") + if extrabody is not None: + if not isinstance(extrabody, bytestr): + extrabody = extrabody.encode('utf-8') + body += ntob("\n") + extrabody + + return (ntob("500 Internal Server Error"), + [(ntob('Content-Type'), ntob('text/plain')), + (ntob('Content-Length'), ntob(str(len(body)), 'ISO-8859-1'))], + [body]) diff --git a/lib/cherrypy/_cplogging.py b/lib/cherrypy/_cplogging.py new file mode 100644 index 00000000..554fd7ef --- /dev/null +++ b/lib/cherrypy/_cplogging.py @@ -0,0 +1,459 @@ +""" +Simple config +============= + +Although CherryPy uses the :mod:`Python logging module `, it does so +behind the scenes so that simple logging is simple, but complicated logging +is still possible. "Simple" logging means that you can log to the screen +(i.e. console/stdout) or to a file, and that you can easily have separate +error and access log files. + +Here are the simplified logging settings. You use these by adding lines to +your config file or dict. You should set these at either the global level or +per application (see next), but generally not both. + + * ``log.screen``: Set this to True to have both "error" and "access" messages + printed to stdout. + * ``log.access_file``: Set this to an absolute filename where you want + "access" messages written. + * ``log.error_file``: Set this to an absolute filename where you want "error" + messages written. + +Many events are automatically logged; to log your own application events, call +:func:`cherrypy.log`. + +Architecture +============ + +Separate scopes +--------------- + +CherryPy provides log managers at both the global and application layers. +This means you can have one set of logging rules for your entire site, +and another set of rules specific to each application. The global log +manager is found at :func:`cherrypy.log`, and the log manager for each +application is found at :attr:`app.log`. +If you're inside a request, the latter is reachable from +``cherrypy.request.app.log``; if you're outside a request, you'll have to +obtain a reference to the ``app``: either the return value of +:func:`tree.mount()` or, if you used +:func:`quickstart()` instead, via +``cherrypy.tree.apps['/']``. + +By default, the global logs are named "cherrypy.error" and "cherrypy.access", +and the application logs are named "cherrypy.error.2378745" and +"cherrypy.access.2378745" (the number is the id of the Application object). +This means that the application logs "bubble up" to the site logs, so if your +application has no log handlers, the site-level handlers will still log the +messages. + +Errors vs. Access +----------------- + +Each log manager handles both "access" messages (one per HTTP request) and +"error" messages (everything else). Note that the "error" log is not just for +errors! The format of access messages is highly formalized, but the error log +isn't--it receives messages from a variety of sources (including full error +tracebacks, if enabled). + +If you are logging the access log and error log to the same source, then there +is a possibility that a specially crafted error message may replicate an access +log message as described in CWE-117. In this case it is the application +developer's responsibility to manually escape data before using CherryPy's log() +functionality, or they may create an application that is vulnerable to CWE-117. +This would be achieved by using a custom handler escape any special characters, +and attached as described below. + +Custom Handlers +=============== + +The simple settings above work by manipulating Python's standard :mod:`logging` +module. So when you need something more complex, the full power of the standard +module is yours to exploit. You can borrow or create custom handlers, formats, +filters, and much more. Here's an example that skips the standard FileHandler +and uses a RotatingFileHandler instead: + +:: + + #python + log = app.log + + # Remove the default FileHandlers if present. + log.error_file = "" + log.access_file = "" + + maxBytes = getattr(log, "rot_maxBytes", 10000000) + backupCount = getattr(log, "rot_backupCount", 1000) + + # Make a new RotatingFileHandler for the error log. + fname = getattr(log, "rot_error_file", "error.log") + h = handlers.RotatingFileHandler(fname, 'a', maxBytes, backupCount) + h.setLevel(DEBUG) + h.setFormatter(_cplogging.logfmt) + log.error_log.addHandler(h) + + # Make a new RotatingFileHandler for the access log. + fname = getattr(log, "rot_access_file", "access.log") + h = handlers.RotatingFileHandler(fname, 'a', maxBytes, backupCount) + h.setLevel(DEBUG) + h.setFormatter(_cplogging.logfmt) + log.access_log.addHandler(h) + + +The ``rot_*`` attributes are pulled straight from the application log object. +Since "log.*" config entries simply set attributes on the log object, you can +add custom attributes to your heart's content. Note that these handlers are +used ''instead'' of the default, simple handlers outlined above (so don't set +the "log.error_file" config entry, for example). +""" + +import datetime +import logging +# Silence the no-handlers "warning" (stderr write!) in stdlib logging +logging.Logger.manager.emittedNoHandlerWarning = 1 +logfmt = logging.Formatter("%(message)s") +import os +import sys + +import cherrypy +from cherrypy import _cperror +from cherrypy._cpcompat import ntob, py3k + + +class NullHandler(logging.Handler): + + """A no-op logging handler to silence the logging.lastResort handler.""" + + def handle(self, record): + pass + + def emit(self, record): + pass + + def createLock(self): + self.lock = None + + +class LogManager(object): + + """An object to assist both simple and advanced logging. + + ``cherrypy.log`` is an instance of this class. + """ + + appid = None + """The id() of the Application object which owns this log manager. If this + is a global log manager, appid is None.""" + + error_log = None + """The actual :class:`logging.Logger` instance for error messages.""" + + access_log = None + """The actual :class:`logging.Logger` instance for access messages.""" + + if py3k: + access_log_format = \ + '{h} {l} {u} {t} "{r}" {s} {b} "{f}" "{a}"' + else: + access_log_format = \ + '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"' + + logger_root = None + """The "top-level" logger name. + + This string will be used as the first segment in the Logger names. + The default is "cherrypy", for example, in which case the Logger names + will be of the form:: + + cherrypy.error. + cherrypy.access. + """ + + def __init__(self, appid=None, logger_root="cherrypy"): + self.logger_root = logger_root + self.appid = appid + if appid is None: + self.error_log = logging.getLogger("%s.error" % logger_root) + self.access_log = logging.getLogger("%s.access" % logger_root) + else: + self.error_log = logging.getLogger( + "%s.error.%s" % (logger_root, appid)) + self.access_log = logging.getLogger( + "%s.access.%s" % (logger_root, appid)) + self.error_log.setLevel(logging.INFO) + self.access_log.setLevel(logging.INFO) + + # Silence the no-handlers "warning" (stderr write!) in stdlib logging + self.error_log.addHandler(NullHandler()) + self.access_log.addHandler(NullHandler()) + + cherrypy.engine.subscribe('graceful', self.reopen_files) + + def reopen_files(self): + """Close and reopen all file handlers.""" + for log in (self.error_log, self.access_log): + for h in log.handlers: + if isinstance(h, logging.FileHandler): + h.acquire() + h.stream.close() + h.stream = open(h.baseFilename, h.mode) + h.release() + + def error(self, msg='', context='', severity=logging.INFO, + traceback=False): + """Write the given ``msg`` to the error log. + + This is not just for errors! Applications may call this at any time + to log application-specific information. + + If ``traceback`` is True, the traceback of the current exception + (if any) will be appended to ``msg``. + """ + if traceback: + msg += _cperror.format_exc() + self.error_log.log(severity, ' '.join((self.time(), context, msg))) + + def __call__(self, *args, **kwargs): + """An alias for ``error``.""" + return self.error(*args, **kwargs) + + def access(self): + """Write to the access log (in Apache/NCSA Combined Log format). + + See the + `apache documentation `_ + for format details. + + CherryPy calls this automatically for you. Note there are no arguments; + it collects the data itself from + :class:`cherrypy.request`. + + Like Apache started doing in 2.0.46, non-printable and other special + characters in %r (and we expand that to all parts) are escaped using + \\xhh sequences, where hh stands for the hexadecimal representation + of the raw byte. Exceptions from this rule are " and \\, which are + escaped by prepending a backslash, and all whitespace characters, + which are written in their C-style notation (\\n, \\t, etc). + """ + request = cherrypy.serving.request + remote = request.remote + response = cherrypy.serving.response + outheaders = response.headers + inheaders = request.headers + if response.output_status is None: + status = "-" + else: + status = response.output_status.split(ntob(" "), 1)[0] + if py3k: + status = status.decode('ISO-8859-1') + + atoms = {'h': remote.name or remote.ip, + 'l': '-', + 'u': getattr(request, "login", None) or "-", + 't': self.time(), + 'r': request.request_line, + 's': status, + 'b': dict.get(outheaders, 'Content-Length', '') or "-", + 'f': dict.get(inheaders, 'Referer', ''), + 'a': dict.get(inheaders, 'User-Agent', ''), + 'o': dict.get(inheaders, 'Host', '-'), + } + if py3k: + for k, v in atoms.items(): + if not isinstance(v, str): + v = str(v) + v = v.replace('"', '\\"').encode('utf8') + # Fortunately, repr(str) escapes unprintable chars, \n, \t, etc + # and backslash for us. All we have to do is strip the quotes. + v = repr(v)[2:-1] + + # in python 3.0 the repr of bytes (as returned by encode) + # uses double \'s. But then the logger escapes them yet, again + # resulting in quadruple slashes. Remove the extra one here. + v = v.replace('\\\\', '\\') + + # Escape double-quote. + atoms[k] = v + + try: + self.access_log.log( + logging.INFO, self.access_log_format.format(**atoms)) + except: + self(traceback=True) + else: + for k, v in atoms.items(): + if isinstance(v, unicode): + v = v.encode('utf8') + elif not isinstance(v, str): + v = str(v) + # Fortunately, repr(str) escapes unprintable chars, \n, \t, etc + # and backslash for us. All we have to do is strip the quotes. + v = repr(v)[1:-1] + # Escape double-quote. + atoms[k] = v.replace('"', '\\"') + + try: + self.access_log.log( + logging.INFO, self.access_log_format % atoms) + except: + self(traceback=True) + + def time(self): + """Return now() in Apache Common Log Format (no timezone).""" + now = datetime.datetime.now() + monthnames = ['jan', 'feb', 'mar', 'apr', 'may', 'jun', + 'jul', 'aug', 'sep', 'oct', 'nov', 'dec'] + month = monthnames[now.month - 1].capitalize() + return ('[%02d/%s/%04d:%02d:%02d:%02d]' % + (now.day, month, now.year, now.hour, now.minute, now.second)) + + def _get_builtin_handler(self, log, key): + for h in log.handlers: + if getattr(h, "_cpbuiltin", None) == key: + return h + + # ------------------------- Screen handlers ------------------------- # + def _set_screen_handler(self, log, enable, stream=None): + h = self._get_builtin_handler(log, "screen") + if enable: + if not h: + if stream is None: + stream = sys.stderr + h = logging.StreamHandler(stream) + h.setFormatter(logfmt) + h._cpbuiltin = "screen" + log.addHandler(h) + elif h: + log.handlers.remove(h) + + def _get_screen(self): + h = self._get_builtin_handler + has_h = h(self.error_log, "screen") or h(self.access_log, "screen") + return bool(has_h) + + def _set_screen(self, newvalue): + self._set_screen_handler(self.error_log, newvalue, stream=sys.stderr) + self._set_screen_handler(self.access_log, newvalue, stream=sys.stdout) + screen = property(_get_screen, _set_screen, + doc="""Turn stderr/stdout logging on or off. + + If you set this to True, it'll add the appropriate StreamHandler for + you. If you set it to False, it will remove the handler. + """) + + # -------------------------- File handlers -------------------------- # + + def _add_builtin_file_handler(self, log, fname): + h = logging.FileHandler(fname) + h.setFormatter(logfmt) + h._cpbuiltin = "file" + log.addHandler(h) + + def _set_file_handler(self, log, filename): + h = self._get_builtin_handler(log, "file") + if filename: + if h: + if h.baseFilename != os.path.abspath(filename): + h.close() + log.handlers.remove(h) + self._add_builtin_file_handler(log, filename) + else: + self._add_builtin_file_handler(log, filename) + else: + if h: + h.close() + log.handlers.remove(h) + + def _get_error_file(self): + h = self._get_builtin_handler(self.error_log, "file") + if h: + return h.baseFilename + return '' + + def _set_error_file(self, newvalue): + self._set_file_handler(self.error_log, newvalue) + error_file = property(_get_error_file, _set_error_file, + doc="""The filename for self.error_log. + + If you set this to a string, it'll add the appropriate FileHandler for + you. If you set it to ``None`` or ``''``, it will remove the handler. + """) + + def _get_access_file(self): + h = self._get_builtin_handler(self.access_log, "file") + if h: + return h.baseFilename + return '' + + def _set_access_file(self, newvalue): + self._set_file_handler(self.access_log, newvalue) + access_file = property(_get_access_file, _set_access_file, + doc="""The filename for self.access_log. + + If you set this to a string, it'll add the appropriate FileHandler for + you. If you set it to ``None`` or ``''``, it will remove the handler. + """) + + # ------------------------- WSGI handlers ------------------------- # + + def _set_wsgi_handler(self, log, enable): + h = self._get_builtin_handler(log, "wsgi") + if enable: + if not h: + h = WSGIErrorHandler() + h.setFormatter(logfmt) + h._cpbuiltin = "wsgi" + log.addHandler(h) + elif h: + log.handlers.remove(h) + + def _get_wsgi(self): + return bool(self._get_builtin_handler(self.error_log, "wsgi")) + + def _set_wsgi(self, newvalue): + self._set_wsgi_handler(self.error_log, newvalue) + wsgi = property(_get_wsgi, _set_wsgi, + doc="""Write errors to wsgi.errors. + + If you set this to True, it'll add the appropriate + :class:`WSGIErrorHandler` for you + (which writes errors to ``wsgi.errors``). + If you set it to False, it will remove the handler. + """) + + +class WSGIErrorHandler(logging.Handler): + + "A handler class which writes logging records to environ['wsgi.errors']." + + def flush(self): + """Flushes the stream.""" + try: + stream = cherrypy.serving.request.wsgi_environ.get('wsgi.errors') + except (AttributeError, KeyError): + pass + else: + stream.flush() + + def emit(self, record): + """Emit a record.""" + try: + stream = cherrypy.serving.request.wsgi_environ.get('wsgi.errors') + except (AttributeError, KeyError): + pass + else: + try: + msg = self.format(record) + fs = "%s\n" + import types + # if no unicode support... + if not hasattr(types, "UnicodeType"): + stream.write(fs % msg) + else: + try: + stream.write(fs % msg) + except UnicodeError: + stream.write(fs % msg.encode("UTF-8")) + self.flush() + except: + self.handleError(record) diff --git a/lib/cherrypy/_cpmodpy.py b/lib/cherrypy/_cpmodpy.py new file mode 100644 index 00000000..02154d69 --- /dev/null +++ b/lib/cherrypy/_cpmodpy.py @@ -0,0 +1,353 @@ +"""Native adapter for serving CherryPy via mod_python + +Basic usage: + +########################################## +# Application in a module called myapp.py +########################################## + +import cherrypy + +class Root: + @cherrypy.expose + def index(self): + return 'Hi there, Ho there, Hey there' + + +# We will use this method from the mod_python configuration +# as the entry point to our application +def setup_server(): + cherrypy.tree.mount(Root()) + cherrypy.config.update({'environment': 'production', + 'log.screen': False, + 'show_tracebacks': False}) + +########################################## +# mod_python settings for apache2 +# This should reside in your httpd.conf +# or a file that will be loaded at +# apache startup +########################################## + +# Start +DocumentRoot "/" +Listen 8080 +LoadModule python_module /usr/lib/apache2/modules/mod_python.so + + + PythonPath "sys.path+['/path/to/my/application']" + SetHandler python-program + PythonHandler cherrypy._cpmodpy::handler + PythonOption cherrypy.setup myapp::setup_server + PythonDebug On + +# End + +The actual path to your mod_python.so is dependent on your +environment. In this case we suppose a global mod_python +installation on a Linux distribution such as Ubuntu. + +We do set the PythonPath configuration setting so that +your application can be found by from the user running +the apache2 instance. Of course if your application +resides in the global site-package this won't be needed. + +Then restart apache2 and access http://127.0.0.1:8080 +""" + +import logging +import sys + +import cherrypy +from cherrypy._cpcompat import BytesIO, copyitems, ntob +from cherrypy._cperror import format_exc, bare_error +from cherrypy.lib import httputil + + +# ------------------------------ Request-handling + + +def setup(req): + from mod_python import apache + + # Run any setup functions defined by a "PythonOption cherrypy.setup" + # directive. + options = req.get_options() + if 'cherrypy.setup' in options: + for function in options['cherrypy.setup'].split(): + atoms = function.split('::', 1) + if len(atoms) == 1: + mod = __import__(atoms[0], globals(), locals()) + else: + modname, fname = atoms + mod = __import__(modname, globals(), locals(), [fname]) + func = getattr(mod, fname) + func() + + cherrypy.config.update({'log.screen': False, + "tools.ignore_headers.on": True, + "tools.ignore_headers.headers": ['Range'], + }) + + engine = cherrypy.engine + if hasattr(engine, "signal_handler"): + engine.signal_handler.unsubscribe() + if hasattr(engine, "console_control_handler"): + engine.console_control_handler.unsubscribe() + engine.autoreload.unsubscribe() + cherrypy.server.unsubscribe() + + def _log(msg, level): + newlevel = apache.APLOG_ERR + if logging.DEBUG >= level: + newlevel = apache.APLOG_DEBUG + elif logging.INFO >= level: + newlevel = apache.APLOG_INFO + elif logging.WARNING >= level: + newlevel = apache.APLOG_WARNING + # On Windows, req.server is required or the msg will vanish. See + # http://www.modpython.org/pipermail/mod_python/2003-October/014291.html + # Also, "When server is not specified...LogLevel does not apply..." + apache.log_error(msg, newlevel, req.server) + engine.subscribe('log', _log) + + engine.start() + + def cherrypy_cleanup(data): + engine.exit() + try: + # apache.register_cleanup wasn't available until 3.1.4. + apache.register_cleanup(cherrypy_cleanup) + except AttributeError: + req.server.register_cleanup(req, cherrypy_cleanup) + + +class _ReadOnlyRequest: + expose = ('read', 'readline', 'readlines') + + def __init__(self, req): + for method in self.expose: + self.__dict__[method] = getattr(req, method) + + +recursive = False + +_isSetUp = False + + +def handler(req): + from mod_python import apache + try: + global _isSetUp + if not _isSetUp: + setup(req) + _isSetUp = True + + # Obtain a Request object from CherryPy + local = req.connection.local_addr + local = httputil.Host( + local[0], local[1], req.connection.local_host or "") + remote = req.connection.remote_addr + remote = httputil.Host( + remote[0], remote[1], req.connection.remote_host or "") + + scheme = req.parsed_uri[0] or 'http' + req.get_basic_auth_pw() + + try: + # apache.mpm_query only became available in mod_python 3.1 + q = apache.mpm_query + threaded = q(apache.AP_MPMQ_IS_THREADED) + forked = q(apache.AP_MPMQ_IS_FORKED) + except AttributeError: + bad_value = ("You must provide a PythonOption '%s', " + "either 'on' or 'off', when running a version " + "of mod_python < 3.1") + + threaded = options.get('multithread', '').lower() + if threaded == 'on': + threaded = True + elif threaded == 'off': + threaded = False + else: + raise ValueError(bad_value % "multithread") + + forked = options.get('multiprocess', '').lower() + if forked == 'on': + forked = True + elif forked == 'off': + forked = False + else: + raise ValueError(bad_value % "multiprocess") + + sn = cherrypy.tree.script_name(req.uri or "/") + if sn is None: + send_response(req, '404 Not Found', [], '') + else: + app = cherrypy.tree.apps[sn] + method = req.method + path = req.uri + qs = req.args or "" + reqproto = req.protocol + headers = copyitems(req.headers_in) + rfile = _ReadOnlyRequest(req) + prev = None + + try: + redirections = [] + while True: + request, response = app.get_serving(local, remote, scheme, + "HTTP/1.1") + request.login = req.user + request.multithread = bool(threaded) + request.multiprocess = bool(forked) + request.app = app + request.prev = prev + + # Run the CherryPy Request object and obtain the response + try: + request.run(method, path, qs, reqproto, headers, rfile) + break + except cherrypy.InternalRedirect: + ir = sys.exc_info()[1] + app.release_serving() + prev = request + + if not recursive: + if ir.path in redirections: + raise RuntimeError( + "InternalRedirector visited the same URL " + "twice: %r" % ir.path) + else: + # Add the *previous* path_info + qs to + # redirections. + if qs: + qs = "?" + qs + redirections.append(sn + path + qs) + + # Munge environment and try again. + method = "GET" + path = ir.path + qs = ir.query_string + rfile = BytesIO() + + send_response( + req, response.output_status, response.header_list, + response.body, response.stream) + finally: + app.release_serving() + except: + tb = format_exc() + cherrypy.log(tb, 'MOD_PYTHON', severity=logging.ERROR) + s, h, b = bare_error() + send_response(req, s, h, b) + return apache.OK + + +def send_response(req, status, headers, body, stream=False): + # Set response status + req.status = int(status[:3]) + + # Set response headers + req.content_type = "text/plain" + for header, value in headers: + if header.lower() == 'content-type': + req.content_type = value + continue + req.headers_out.add(header, value) + + if stream: + # Flush now so the status and headers are sent immediately. + req.flush() + + # Set response body + if isinstance(body, basestring): + req.write(body) + else: + for seg in body: + req.write(seg) + + +# --------------- Startup tools for CherryPy + mod_python --------------- # +import os +import re +try: + import subprocess + + def popen(fullcmd): + p = subprocess.Popen(fullcmd, shell=True, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + close_fds=True) + return p.stdout +except ImportError: + def popen(fullcmd): + pipein, pipeout = os.popen4(fullcmd) + return pipeout + + +def read_process(cmd, args=""): + fullcmd = "%s %s" % (cmd, args) + pipeout = popen(fullcmd) + try: + firstline = pipeout.readline() + cmd_not_found = re.search( + ntob("(not recognized|No such file|not found)"), + firstline, + re.IGNORECASE + ) + if cmd_not_found: + raise IOError('%s must be on your system path.' % cmd) + output = firstline + pipeout.read() + finally: + pipeout.close() + return output + + +class ModPythonServer(object): + + template = """ +# Apache2 server configuration file for running CherryPy with mod_python. + +DocumentRoot "/" +Listen %(port)s +LoadModule python_module modules/mod_python.so + + + SetHandler python-program + PythonHandler %(handler)s + PythonDebug On +%(opts)s + +""" + + def __init__(self, loc="/", port=80, opts=None, apache_path="apache", + handler="cherrypy._cpmodpy::handler"): + self.loc = loc + self.port = port + self.opts = opts + self.apache_path = apache_path + self.handler = handler + + def start(self): + opts = "".join([" PythonOption %s %s\n" % (k, v) + for k, v in self.opts]) + conf_data = self.template % {"port": self.port, + "loc": self.loc, + "opts": opts, + "handler": self.handler, + } + + mpconf = os.path.join(os.path.dirname(__file__), "cpmodpy.conf") + f = open(mpconf, 'wb') + try: + f.write(conf_data) + finally: + f.close() + + response = read_process(self.apache_path, "-k start -f %s" % mpconf) + self.ready = True + return response + + def stop(self): + os.popen("apache -k stop") + self.ready = False diff --git a/lib/cherrypy/_cpnative_server.py b/lib/cherrypy/_cpnative_server.py new file mode 100644 index 00000000..e303573d --- /dev/null +++ b/lib/cherrypy/_cpnative_server.py @@ -0,0 +1,154 @@ +"""Native adapter for serving CherryPy via its builtin server.""" + +import logging +import sys + +import cherrypy +from cherrypy._cpcompat import BytesIO +from cherrypy._cperror import format_exc, bare_error +from cherrypy.lib import httputil +from cherrypy import wsgiserver + + +class NativeGateway(wsgiserver.Gateway): + + recursive = False + + def respond(self): + req = self.req + try: + # Obtain a Request object from CherryPy + local = req.server.bind_addr + local = httputil.Host(local[0], local[1], "") + remote = req.conn.remote_addr, req.conn.remote_port + remote = httputil.Host(remote[0], remote[1], "") + + scheme = req.scheme + sn = cherrypy.tree.script_name(req.uri or "/") + if sn is None: + self.send_response('404 Not Found', [], ['']) + else: + app = cherrypy.tree.apps[sn] + method = req.method + path = req.path + qs = req.qs or "" + headers = req.inheaders.items() + rfile = req.rfile + prev = None + + try: + redirections = [] + while True: + request, response = app.get_serving( + local, remote, scheme, "HTTP/1.1") + request.multithread = True + request.multiprocess = False + request.app = app + request.prev = prev + + # Run the CherryPy Request object and obtain the + # response + try: + request.run(method, path, qs, + req.request_protocol, headers, rfile) + break + except cherrypy.InternalRedirect: + ir = sys.exc_info()[1] + app.release_serving() + prev = request + + if not self.recursive: + if ir.path in redirections: + raise RuntimeError( + "InternalRedirector visited the same " + "URL twice: %r" % ir.path) + else: + # Add the *previous* path_info + qs to + # redirections. + if qs: + qs = "?" + qs + redirections.append(sn + path + qs) + + # Munge environment and try again. + method = "GET" + path = ir.path + qs = ir.query_string + rfile = BytesIO() + + self.send_response( + response.output_status, response.header_list, + response.body) + finally: + app.release_serving() + except: + tb = format_exc() + # print tb + cherrypy.log(tb, 'NATIVE_ADAPTER', severity=logging.ERROR) + s, h, b = bare_error() + self.send_response(s, h, b) + + def send_response(self, status, headers, body): + req = self.req + + # Set response status + req.status = str(status or "500 Server Error") + + # Set response headers + for header, value in headers: + req.outheaders.append((header, value)) + if (req.ready and not req.sent_headers): + req.sent_headers = True + req.send_headers() + + # Set response body + for seg in body: + req.write(seg) + + +class CPHTTPServer(wsgiserver.HTTPServer): + + """Wrapper for wsgiserver.HTTPServer. + + wsgiserver has been designed to not reference CherryPy in any way, + so that it can be used in other frameworks and applications. + Therefore, we wrap it here, so we can apply some attributes + from config -> cherrypy.server -> HTTPServer. + """ + + def __init__(self, server_adapter=cherrypy.server): + self.server_adapter = server_adapter + + server_name = (self.server_adapter.socket_host or + self.server_adapter.socket_file or + None) + + wsgiserver.HTTPServer.__init__( + self, server_adapter.bind_addr, NativeGateway, + minthreads=server_adapter.thread_pool, + maxthreads=server_adapter.thread_pool_max, + server_name=server_name) + + self.max_request_header_size = ( + self.server_adapter.max_request_header_size or 0) + self.max_request_body_size = ( + self.server_adapter.max_request_body_size or 0) + self.request_queue_size = self.server_adapter.socket_queue_size + self.timeout = self.server_adapter.socket_timeout + self.shutdown_timeout = self.server_adapter.shutdown_timeout + self.protocol = self.server_adapter.protocol_version + self.nodelay = self.server_adapter.nodelay + + ssl_module = self.server_adapter.ssl_module or 'pyopenssl' + if self.server_adapter.ssl_context: + adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module) + self.ssl_adapter = adapter_class( + self.server_adapter.ssl_certificate, + self.server_adapter.ssl_private_key, + self.server_adapter.ssl_certificate_chain) + self.ssl_adapter.context = self.server_adapter.ssl_context + elif self.server_adapter.ssl_certificate: + adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module) + self.ssl_adapter = adapter_class( + self.server_adapter.ssl_certificate, + self.server_adapter.ssl_private_key, + self.server_adapter.ssl_certificate_chain) diff --git a/lib/cherrypy/_cpreqbody.py b/lib/cherrypy/_cpreqbody.py new file mode 100644 index 00000000..d2dbbc92 --- /dev/null +++ b/lib/cherrypy/_cpreqbody.py @@ -0,0 +1,1013 @@ +"""Request body processing for CherryPy. + +.. versionadded:: 3.2 + +Application authors have complete control over the parsing of HTTP request +entities. In short, +:attr:`cherrypy.request.body` +is now always set to an instance of +:class:`RequestBody`, +and *that* class is a subclass of :class:`Entity`. + +When an HTTP request includes an entity body, it is often desirable to +provide that information to applications in a form other than the raw bytes. +Different content types demand different approaches. Examples: + + * For a GIF file, we want the raw bytes in a stream. + * An HTML form is better parsed into its component fields, and each text field + decoded from bytes to unicode. + * A JSON body should be deserialized into a Python dict or list. + +When the request contains a Content-Type header, the media type is used as a +key to look up a value in the +:attr:`request.body.processors` dict. +If the full media +type is not found, then the major type is tried; for example, if no processor +is found for the 'image/jpeg' type, then we look for a processor for the +'image' types altogether. If neither the full type nor the major type has a +matching processor, then a default processor is used +(:func:`default_proc`). For most +types, this means no processing is done, and the body is left unread as a +raw byte stream. Processors are configurable in an 'on_start_resource' hook. + +Some processors, especially those for the 'text' types, attempt to decode bytes +to unicode. If the Content-Type request header includes a 'charset' parameter, +this is used to decode the entity. Otherwise, one or more default charsets may +be attempted, although this decision is up to each processor. If a processor +successfully decodes an Entity or Part, it should set the +:attr:`charset` attribute +on the Entity or Part to the name of the successful charset, so that +applications can easily re-encode or transcode the value if they wish. + +If the Content-Type of the request entity is of major type 'multipart', then +the above parsing process, and possibly a decoding process, is performed for +each part. + +For both the full entity and multipart parts, a Content-Disposition header may +be used to fill :attr:`name` and +:attr:`filename` attributes on the +request.body or the Part. + +.. _custombodyprocessors: + +Custom Processors +================= + +You can add your own processors for any specific or major MIME type. Simply add +it to the :attr:`processors` dict in a +hook/tool that runs at ``on_start_resource`` or ``before_request_body``. +Here's the built-in JSON tool for an example:: + + def json_in(force=True, debug=False): + request = cherrypy.serving.request + def json_processor(entity): + \"""Read application/json data into request.json.\""" + if not entity.headers.get("Content-Length", ""): + raise cherrypy.HTTPError(411) + + body = entity.fp.read() + try: + request.json = json_decode(body) + except ValueError: + raise cherrypy.HTTPError(400, 'Invalid JSON document') + if force: + request.body.processors.clear() + request.body.default_proc = cherrypy.HTTPError( + 415, 'Expected an application/json content type') + request.body.processors['application/json'] = json_processor + +We begin by defining a new ``json_processor`` function to stick in the +``processors`` dictionary. All processor functions take a single argument, +the ``Entity`` instance they are to process. It will be called whenever a +request is received (for those URI's where the tool is turned on) which +has a ``Content-Type`` of "application/json". + +First, it checks for a valid ``Content-Length`` (raising 411 if not valid), +then reads the remaining bytes on the socket. The ``fp`` object knows its +own length, so it won't hang waiting for data that never arrives. It will +return when all data has been read. Then, we decode those bytes using +Python's built-in ``json`` module, and stick the decoded result onto +``request.json`` . If it cannot be decoded, we raise 400. + +If the "force" argument is True (the default), the ``Tool`` clears the +``processors`` dict so that request entities of other ``Content-Types`` +aren't parsed at all. Since there's no entry for those invalid MIME +types, the ``default_proc`` method of ``cherrypy.request.body`` is +called. But this does nothing by default (usually to provide the page +handler an opportunity to handle it.) +But in our case, we want to raise 415, so we replace +``request.body.default_proc`` +with the error (``HTTPError`` instances, when called, raise themselves). + +If we were defining a custom processor, we can do so without making a ``Tool``. +Just add the config entry:: + + request.body.processors = {'application/json': json_processor} + +Note that you can only replace the ``processors`` dict wholesale this way, +not update the existing one. +""" + +try: + from io import DEFAULT_BUFFER_SIZE +except ImportError: + DEFAULT_BUFFER_SIZE = 8192 +import re +import sys +import tempfile +try: + from urllib import unquote_plus +except ImportError: + def unquote_plus(bs): + """Bytes version of urllib.parse.unquote_plus.""" + bs = bs.replace(ntob('+'), ntob(' ')) + atoms = bs.split(ntob('%')) + for i in range(1, len(atoms)): + item = atoms[i] + try: + pct = int(item[:2], 16) + atoms[i] = bytes([pct]) + item[2:] + except ValueError: + pass + return ntob('').join(atoms) + +import cherrypy +from cherrypy._cpcompat import basestring, ntob, ntou +from cherrypy.lib import httputil + + +# ------------------------------- Processors -------------------------------- # + +def process_urlencoded(entity): + """Read application/x-www-form-urlencoded data into entity.params.""" + qs = entity.fp.read() + for charset in entity.attempt_charsets: + try: + params = {} + for aparam in qs.split(ntob('&')): + for pair in aparam.split(ntob(';')): + if not pair: + continue + + atoms = pair.split(ntob('='), 1) + if len(atoms) == 1: + atoms.append(ntob('')) + + key = unquote_plus(atoms[0]).decode(charset) + value = unquote_plus(atoms[1]).decode(charset) + + if key in params: + if not isinstance(params[key], list): + params[key] = [params[key]] + params[key].append(value) + else: + params[key] = value + except UnicodeDecodeError: + pass + else: + entity.charset = charset + break + else: + raise cherrypy.HTTPError( + 400, "The request entity could not be decoded. The following " + "charsets were attempted: %s" % repr(entity.attempt_charsets)) + + # Now that all values have been successfully parsed and decoded, + # apply them to the entity.params dict. + for key, value in params.items(): + if key in entity.params: + if not isinstance(entity.params[key], list): + entity.params[key] = [entity.params[key]] + entity.params[key].append(value) + else: + entity.params[key] = value + + +def process_multipart(entity): + """Read all multipart parts into entity.parts.""" + ib = "" + if 'boundary' in entity.content_type.params: + # http://tools.ietf.org/html/rfc2046#section-5.1.1 + # "The grammar for parameters on the Content-type field is such that it + # is often necessary to enclose the boundary parameter values in quotes + # on the Content-type line" + ib = entity.content_type.params['boundary'].strip('"') + + if not re.match("^[ -~]{0,200}[!-~]$", ib): + raise ValueError('Invalid boundary in multipart form: %r' % (ib,)) + + ib = ('--' + ib).encode('ascii') + + # Find the first marker + while True: + b = entity.readline() + if not b: + return + + b = b.strip() + if b == ib: + break + + # Read all parts + while True: + part = entity.part_class.from_fp(entity.fp, ib) + entity.parts.append(part) + part.process() + if part.fp.done: + break + + +def process_multipart_form_data(entity): + """Read all multipart/form-data parts into entity.parts or entity.params. + """ + process_multipart(entity) + + kept_parts = [] + for part in entity.parts: + if part.name is None: + kept_parts.append(part) + else: + if part.filename is None: + # It's a regular field + value = part.fullvalue() + else: + # It's a file upload. Retain the whole part so consumer code + # has access to its .file and .filename attributes. + value = part + + if part.name in entity.params: + if not isinstance(entity.params[part.name], list): + entity.params[part.name] = [entity.params[part.name]] + entity.params[part.name].append(value) + else: + entity.params[part.name] = value + + entity.parts = kept_parts + + +def _old_process_multipart(entity): + """The behavior of 3.2 and lower. Deprecated and will be changed in 3.3.""" + process_multipart(entity) + + params = entity.params + + for part in entity.parts: + if part.name is None: + key = ntou('parts') + else: + key = part.name + + if part.filename is None: + # It's a regular field + value = part.fullvalue() + else: + # It's a file upload. Retain the whole part so consumer code + # has access to its .file and .filename attributes. + value = part + + if key in params: + if not isinstance(params[key], list): + params[key] = [params[key]] + params[key].append(value) + else: + params[key] = value + + +# -------------------------------- Entities --------------------------------- # +class Entity(object): + + """An HTTP request body, or MIME multipart body. + + This class collects information about the HTTP request entity. When a + given entity is of MIME type "multipart", each part is parsed into its own + Entity instance, and the set of parts stored in + :attr:`entity.parts`. + + Between the ``before_request_body`` and ``before_handler`` tools, CherryPy + tries to process the request body (if any) by calling + :func:`request.body.process`. + This uses the ``content_type`` of the Entity to look up a suitable + processor in + :attr:`Entity.processors`, + a dict. + If a matching processor cannot be found for the complete Content-Type, + it tries again using the major type. For example, if a request with an + entity of type "image/jpeg" arrives, but no processor can be found for + that complete type, then one is sought for the major type "image". If a + processor is still not found, then the + :func:`default_proc` method + of the Entity is called (which does nothing by default; you can + override this too). + + CherryPy includes processors for the "application/x-www-form-urlencoded" + type, the "multipart/form-data" type, and the "multipart" major type. + CherryPy 3.2 processes these types almost exactly as older versions. + Parts are passed as arguments to the page handler using their + ``Content-Disposition.name`` if given, otherwise in a generic "parts" + argument. Each such part is either a string, or the + :class:`Part` itself if it's a file. (In this + case it will have ``file`` and ``filename`` attributes, or possibly a + ``value`` attribute). Each Part is itself a subclass of + Entity, and has its own ``process`` method and ``processors`` dict. + + There is a separate processor for the "multipart" major type which is more + flexible, and simply stores all multipart parts in + :attr:`request.body.parts`. You can + enable it with:: + + cherrypy.request.body.processors['multipart'] = _cpreqbody.process_multipart + + in an ``on_start_resource`` tool. + """ + + # http://tools.ietf.org/html/rfc2046#section-4.1.2: + # "The default character set, which must be assumed in the + # absence of a charset parameter, is US-ASCII." + # However, many browsers send data in utf-8 with no charset. + attempt_charsets = ['utf-8'] + """A list of strings, each of which should be a known encoding. + + When the Content-Type of the request body warrants it, each of the given + encodings will be tried in order. The first one to successfully decode the + entity without raising an error is stored as + :attr:`entity.charset`. This defaults + to ``['utf-8']`` (plus 'ISO-8859-1' for "text/\*" types, as required by + `HTTP/1.1 `_), + but ``['us-ascii', 'utf-8']`` for multipart parts. + """ + + charset = None + """The successful decoding; see "attempt_charsets" above.""" + + content_type = None + """The value of the Content-Type request header. + + If the Entity is part of a multipart payload, this will be the Content-Type + given in the MIME headers for this part. + """ + + default_content_type = 'application/x-www-form-urlencoded' + """This defines a default ``Content-Type`` to use if no Content-Type header + is given. The empty string is used for RequestBody, which results in the + request body not being read or parsed at all. This is by design; a missing + ``Content-Type`` header in the HTTP request entity is an error at best, + and a security hole at worst. For multipart parts, however, the MIME spec + declares that a part with no Content-Type defaults to "text/plain" + (see :class:`Part`). + """ + + filename = None + """The ``Content-Disposition.filename`` header, if available.""" + + fp = None + """The readable socket file object.""" + + headers = None + """A dict of request/multipart header names and values. + + This is a copy of the ``request.headers`` for the ``request.body``; + for multipart parts, it is the set of headers for that part. + """ + + length = None + """The value of the ``Content-Length`` header, if provided.""" + + name = None + """The "name" parameter of the ``Content-Disposition`` header, if any.""" + + params = None + """ + If the request Content-Type is 'application/x-www-form-urlencoded' or + multipart, this will be a dict of the params pulled from the entity + body; that is, it will be the portion of request.params that come + from the message body (sometimes called "POST params", although they + can be sent with various HTTP method verbs). This value is set between + the 'before_request_body' and 'before_handler' hooks (assuming that + process_request_body is True).""" + + processors = {'application/x-www-form-urlencoded': process_urlencoded, + 'multipart/form-data': process_multipart_form_data, + 'multipart': process_multipart, + } + """A dict of Content-Type names to processor methods.""" + + parts = None + """A list of Part instances if ``Content-Type`` is of major type + "multipart".""" + + part_class = None + """The class used for multipart parts. + + You can replace this with custom subclasses to alter the processing of + multipart parts. + """ + + def __init__(self, fp, headers, params=None, parts=None): + # Make an instance-specific copy of the class processors + # so Tools, etc. can replace them per-request. + self.processors = self.processors.copy() + + self.fp = fp + self.headers = headers + + if params is None: + params = {} + self.params = params + + if parts is None: + parts = [] + self.parts = parts + + # Content-Type + self.content_type = headers.elements('Content-Type') + if self.content_type: + self.content_type = self.content_type[0] + else: + self.content_type = httputil.HeaderElement.from_str( + self.default_content_type) + + # Copy the class 'attempt_charsets', prepending any Content-Type + # charset + dec = self.content_type.params.get("charset", None) + if dec: + self.attempt_charsets = [dec] + [c for c in self.attempt_charsets + if c != dec] + else: + self.attempt_charsets = self.attempt_charsets[:] + + # Length + self.length = None + clen = headers.get('Content-Length', None) + # If Transfer-Encoding is 'chunked', ignore any Content-Length. + if ( + clen is not None and + 'chunked' not in headers.get('Transfer-Encoding', '') + ): + try: + self.length = int(clen) + except ValueError: + pass + + # Content-Disposition + self.name = None + self.filename = None + disp = headers.elements('Content-Disposition') + if disp: + disp = disp[0] + if 'name' in disp.params: + self.name = disp.params['name'] + if self.name.startswith('"') and self.name.endswith('"'): + self.name = self.name[1:-1] + if 'filename' in disp.params: + self.filename = disp.params['filename'] + if ( + self.filename.startswith('"') and + self.filename.endswith('"') + ): + self.filename = self.filename[1:-1] + + # The 'type' attribute is deprecated in 3.2; remove it in 3.3. + type = property( + lambda self: self.content_type, + doc="A deprecated alias for " + ":attr:`content_type`." + ) + + def read(self, size=None, fp_out=None): + return self.fp.read(size, fp_out) + + def readline(self, size=None): + return self.fp.readline(size) + + def readlines(self, sizehint=None): + return self.fp.readlines(sizehint) + + def __iter__(self): + return self + + def __next__(self): + line = self.readline() + if not line: + raise StopIteration + return line + + def next(self): + return self.__next__() + + def read_into_file(self, fp_out=None): + """Read the request body into fp_out (or make_file() if None). + + Return fp_out. + """ + if fp_out is None: + fp_out = self.make_file() + self.read(fp_out=fp_out) + return fp_out + + def make_file(self): + """Return a file-like object into which the request body will be read. + + By default, this will return a TemporaryFile. Override as needed. + See also :attr:`cherrypy._cpreqbody.Part.maxrambytes`.""" + return tempfile.TemporaryFile() + + def fullvalue(self): + """Return this entity as a string, whether stored in a file or not.""" + if self.file: + # It was stored in a tempfile. Read it. + self.file.seek(0) + value = self.file.read() + self.file.seek(0) + else: + value = self.value + return value + + def process(self): + """Execute the best-match processor for the given media type.""" + proc = None + ct = self.content_type.value + try: + proc = self.processors[ct] + except KeyError: + toptype = ct.split('/', 1)[0] + try: + proc = self.processors[toptype] + except KeyError: + pass + if proc is None: + self.default_proc() + else: + proc(self) + + def default_proc(self): + """Called if a more-specific processor is not found for the + ``Content-Type``. + """ + # Leave the fp alone for someone else to read. This works fine + # for request.body, but the Part subclasses need to override this + # so they can move on to the next part. + pass + + +class Part(Entity): + + """A MIME part entity, part of a multipart entity.""" + + # "The default character set, which must be assumed in the absence of a + # charset parameter, is US-ASCII." + attempt_charsets = ['us-ascii', 'utf-8'] + """A list of strings, each of which should be a known encoding. + + When the Content-Type of the request body warrants it, each of the given + encodings will be tried in order. The first one to successfully decode the + entity without raising an error is stored as + :attr:`entity.charset`. This defaults + to ``['utf-8']`` (plus 'ISO-8859-1' for "text/\*" types, as required by + `HTTP/1.1 `_), + but ``['us-ascii', 'utf-8']`` for multipart parts. + """ + + boundary = None + """The MIME multipart boundary.""" + + default_content_type = 'text/plain' + """This defines a default ``Content-Type`` to use if no Content-Type header + is given. The empty string is used for RequestBody, which results in the + request body not being read or parsed at all. This is by design; a missing + ``Content-Type`` header in the HTTP request entity is an error at best, + and a security hole at worst. For multipart parts, however (this class), + the MIME spec declares that a part with no Content-Type defaults to + "text/plain". + """ + + # This is the default in stdlib cgi. We may want to increase it. + maxrambytes = 1000 + """The threshold of bytes after which point the ``Part`` will store + its data in a file (generated by + :func:`make_file`) + instead of a string. Defaults to 1000, just like the :mod:`cgi` + module in Python's standard library. + """ + + def __init__(self, fp, headers, boundary): + Entity.__init__(self, fp, headers) + self.boundary = boundary + self.file = None + self.value = None + + def from_fp(cls, fp, boundary): + headers = cls.read_headers(fp) + return cls(fp, headers, boundary) + from_fp = classmethod(from_fp) + + def read_headers(cls, fp): + headers = httputil.HeaderMap() + while True: + line = fp.readline() + if not line: + # No more data--illegal end of headers + raise EOFError("Illegal end of headers.") + + if line == ntob('\r\n'): + # Normal end of headers + break + if not line.endswith(ntob('\r\n')): + raise ValueError("MIME requires CRLF terminators: %r" % line) + + if line[0] in ntob(' \t'): + # It's a continuation line. + v = line.strip().decode('ISO-8859-1') + else: + k, v = line.split(ntob(":"), 1) + k = k.strip().decode('ISO-8859-1') + v = v.strip().decode('ISO-8859-1') + + existing = headers.get(k) + if existing: + v = ", ".join((existing, v)) + headers[k] = v + + return headers + read_headers = classmethod(read_headers) + + def read_lines_to_boundary(self, fp_out=None): + """Read bytes from self.fp and return or write them to a file. + + If the 'fp_out' argument is None (the default), all bytes read are + returned in a single byte string. + + If the 'fp_out' argument is not None, it must be a file-like + object that supports the 'write' method; all bytes read will be + written to the fp, and that fp is returned. + """ + endmarker = self.boundary + ntob("--") + delim = ntob("") + prev_lf = True + lines = [] + seen = 0 + while True: + line = self.fp.readline(1 << 16) + if not line: + raise EOFError("Illegal end of multipart body.") + if line.startswith(ntob("--")) and prev_lf: + strippedline = line.strip() + if strippedline == self.boundary: + break + if strippedline == endmarker: + self.fp.finish() + break + + line = delim + line + + if line.endswith(ntob("\r\n")): + delim = ntob("\r\n") + line = line[:-2] + prev_lf = True + elif line.endswith(ntob("\n")): + delim = ntob("\n") + line = line[:-1] + prev_lf = True + else: + delim = ntob("") + prev_lf = False + + if fp_out is None: + lines.append(line) + seen += len(line) + if seen > self.maxrambytes: + fp_out = self.make_file() + for line in lines: + fp_out.write(line) + else: + fp_out.write(line) + + if fp_out is None: + result = ntob('').join(lines) + for charset in self.attempt_charsets: + try: + result = result.decode(charset) + except UnicodeDecodeError: + pass + else: + self.charset = charset + return result + else: + raise cherrypy.HTTPError( + 400, + "The request entity could not be decoded. The following " + "charsets were attempted: %s" % repr(self.attempt_charsets) + ) + else: + fp_out.seek(0) + return fp_out + + def default_proc(self): + """Called if a more-specific processor is not found for the + ``Content-Type``. + """ + if self.filename: + # Always read into a file if a .filename was given. + self.file = self.read_into_file() + else: + result = self.read_lines_to_boundary() + if isinstance(result, basestring): + self.value = result + else: + self.file = result + + def read_into_file(self, fp_out=None): + """Read the request body into fp_out (or make_file() if None). + + Return fp_out. + """ + if fp_out is None: + fp_out = self.make_file() + self.read_lines_to_boundary(fp_out=fp_out) + return fp_out + +Entity.part_class = Part + +try: + inf = float('inf') +except ValueError: + # Python 2.4 and lower + class Infinity(object): + + def __cmp__(self, other): + return 1 + + def __sub__(self, other): + return self + inf = Infinity() + + +comma_separated_headers = [ + 'Accept', 'Accept-Charset', 'Accept-Encoding', + 'Accept-Language', 'Accept-Ranges', 'Allow', + 'Cache-Control', 'Connection', 'Content-Encoding', + 'Content-Language', 'Expect', 'If-Match', + 'If-None-Match', 'Pragma', 'Proxy-Authenticate', + 'Te', 'Trailer', 'Transfer-Encoding', 'Upgrade', + 'Vary', 'Via', 'Warning', 'Www-Authenticate' +] + + +class SizedReader: + + def __init__(self, fp, length, maxbytes, bufsize=DEFAULT_BUFFER_SIZE, + has_trailers=False): + # Wrap our fp in a buffer so peek() works + self.fp = fp + self.length = length + self.maxbytes = maxbytes + self.buffer = ntob('') + self.bufsize = bufsize + self.bytes_read = 0 + self.done = False + self.has_trailers = has_trailers + + def read(self, size=None, fp_out=None): + """Read bytes from the request body and return or write them to a file. + + A number of bytes less than or equal to the 'size' argument are read + off the socket. The actual number of bytes read are tracked in + self.bytes_read. The number may be smaller than 'size' when 1) the + client sends fewer bytes, 2) the 'Content-Length' request header + specifies fewer bytes than requested, or 3) the number of bytes read + exceeds self.maxbytes (in which case, 413 is raised). + + If the 'fp_out' argument is None (the default), all bytes read are + returned in a single byte string. + + If the 'fp_out' argument is not None, it must be a file-like + object that supports the 'write' method; all bytes read will be + written to the fp, and None is returned. + """ + + if self.length is None: + if size is None: + remaining = inf + else: + remaining = size + else: + remaining = self.length - self.bytes_read + if size and size < remaining: + remaining = size + if remaining == 0: + self.finish() + if fp_out is None: + return ntob('') + else: + return None + + chunks = [] + + # Read bytes from the buffer. + if self.buffer: + if remaining is inf: + data = self.buffer + self.buffer = ntob('') + else: + data = self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + datalen = len(data) + remaining -= datalen + + # Check lengths. + self.bytes_read += datalen + if self.maxbytes and self.bytes_read > self.maxbytes: + raise cherrypy.HTTPError(413) + + # Store the data. + if fp_out is None: + chunks.append(data) + else: + fp_out.write(data) + + # Read bytes from the socket. + while remaining > 0: + chunksize = min(remaining, self.bufsize) + try: + data = self.fp.read(chunksize) + except Exception: + e = sys.exc_info()[1] + if e.__class__.__name__ == 'MaxSizeExceeded': + # Post data is too big + raise cherrypy.HTTPError( + 413, "Maximum request length: %r" % e.args[1]) + else: + raise + if not data: + self.finish() + break + datalen = len(data) + remaining -= datalen + + # Check lengths. + self.bytes_read += datalen + if self.maxbytes and self.bytes_read > self.maxbytes: + raise cherrypy.HTTPError(413) + + # Store the data. + if fp_out is None: + chunks.append(data) + else: + fp_out.write(data) + + if fp_out is None: + return ntob('').join(chunks) + + def readline(self, size=None): + """Read a line from the request body and return it.""" + chunks = [] + while size is None or size > 0: + chunksize = self.bufsize + if size is not None and size < self.bufsize: + chunksize = size + data = self.read(chunksize) + if not data: + break + pos = data.find(ntob('\n')) + 1 + if pos: + chunks.append(data[:pos]) + remainder = data[pos:] + self.buffer += remainder + self.bytes_read -= len(remainder) + break + else: + chunks.append(data) + return ntob('').join(chunks) + + def readlines(self, sizehint=None): + """Read lines from the request body and return them.""" + if self.length is not None: + if sizehint is None: + sizehint = self.length - self.bytes_read + else: + sizehint = min(sizehint, self.length - self.bytes_read) + + lines = [] + seen = 0 + while True: + line = self.readline() + if not line: + break + lines.append(line) + seen += len(line) + if seen >= sizehint: + break + return lines + + def finish(self): + self.done = True + if self.has_trailers and hasattr(self.fp, 'read_trailer_lines'): + self.trailers = {} + + try: + for line in self.fp.read_trailer_lines(): + if line[0] in ntob(' \t'): + # It's a continuation line. + v = line.strip() + else: + try: + k, v = line.split(ntob(":"), 1) + except ValueError: + raise ValueError("Illegal header line.") + k = k.strip().title() + v = v.strip() + + if k in comma_separated_headers: + existing = self.trailers.get(envname) + if existing: + v = ntob(", ").join((existing, v)) + self.trailers[k] = v + except Exception: + e = sys.exc_info()[1] + if e.__class__.__name__ == 'MaxSizeExceeded': + # Post data is too big + raise cherrypy.HTTPError( + 413, "Maximum request length: %r" % e.args[1]) + else: + raise + + +class RequestBody(Entity): + + """The entity of the HTTP request.""" + + bufsize = 8 * 1024 + """The buffer size used when reading the socket.""" + + # Don't parse the request body at all if the client didn't provide + # a Content-Type header. See + # https://bitbucket.org/cherrypy/cherrypy/issue/790 + default_content_type = '' + """This defines a default ``Content-Type`` to use if no Content-Type header + is given. The empty string is used for RequestBody, which results in the + request body not being read or parsed at all. This is by design; a missing + ``Content-Type`` header in the HTTP request entity is an error at best, + and a security hole at worst. For multipart parts, however, the MIME spec + declares that a part with no Content-Type defaults to "text/plain" + (see :class:`Part`). + """ + + maxbytes = None + """Raise ``MaxSizeExceeded`` if more bytes than this are read from + the socket. + """ + + def __init__(self, fp, headers, params=None, request_params=None): + Entity.__init__(self, fp, headers, params) + + # http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7.1 + # When no explicit charset parameter is provided by the + # sender, media subtypes of the "text" type are defined + # to have a default charset value of "ISO-8859-1" when + # received via HTTP. + if self.content_type.value.startswith('text/'): + for c in ('ISO-8859-1', 'iso-8859-1', 'Latin-1', 'latin-1'): + if c in self.attempt_charsets: + break + else: + self.attempt_charsets.append('ISO-8859-1') + + # Temporary fix while deprecating passing .parts as .params. + self.processors['multipart'] = _old_process_multipart + + if request_params is None: + request_params = {} + self.request_params = request_params + + def process(self): + """Process the request entity based on its Content-Type.""" + # "The presence of a message-body in a request is signaled by the + # inclusion of a Content-Length or Transfer-Encoding header field in + # the request's message-headers." + # It is possible to send a POST request with no body, for example; + # however, app developers are responsible in that case to set + # cherrypy.request.process_body to False so this method isn't called. + h = cherrypy.serving.request.headers + if 'Content-Length' not in h and 'Transfer-Encoding' not in h: + raise cherrypy.HTTPError(411) + + self.fp = SizedReader(self.fp, self.length, + self.maxbytes, bufsize=self.bufsize, + has_trailers='Trailer' in h) + super(RequestBody, self).process() + + # Body params should also be a part of the request_params + # add them in here. + request_params = self.request_params + for key, value in self.params.items(): + # Python 2 only: keyword arguments must be byte strings (type + # 'str'). + if sys.version_info < (3, 0): + if isinstance(key, unicode): + key = key.encode('ISO-8859-1') + + if key in request_params: + if not isinstance(request_params[key], list): + request_params[key] = [request_params[key]] + request_params[key].append(value) + else: + request_params[key] = value diff --git a/lib/cherrypy/_cprequest.py b/lib/cherrypy/_cprequest.py new file mode 100644 index 00000000..290bd2eb --- /dev/null +++ b/lib/cherrypy/_cprequest.py @@ -0,0 +1,973 @@ + +import os +import sys +import time +import warnings + +import cherrypy +from cherrypy._cpcompat import basestring, copykeys, ntob, unicodestr +from cherrypy._cpcompat import SimpleCookie, CookieError, py3k +from cherrypy import _cpreqbody, _cpconfig +from cherrypy._cperror import format_exc, bare_error +from cherrypy.lib import httputil, file_generator + + +class Hook(object): + + """A callback and its metadata: failsafe, priority, and kwargs.""" + + callback = None + """ + The bare callable that this Hook object is wrapping, which will + be called when the Hook is called.""" + + failsafe = False + """ + If True, the callback is guaranteed to run even if other callbacks + from the same call point raise exceptions.""" + + priority = 50 + """ + Defines the order of execution for a list of Hooks. Priority numbers + should be limited to the closed interval [0, 100], but values outside + this range are acceptable, as are fractional values.""" + + kwargs = {} + """ + A set of keyword arguments that will be passed to the + callable on each call.""" + + def __init__(self, callback, failsafe=None, priority=None, **kwargs): + self.callback = callback + + if failsafe is None: + failsafe = getattr(callback, "failsafe", False) + self.failsafe = failsafe + + if priority is None: + priority = getattr(callback, "priority", 50) + self.priority = priority + + self.kwargs = kwargs + + def __lt__(self, other): + # Python 3 + return self.priority < other.priority + + def __cmp__(self, other): + # Python 2 + return cmp(self.priority, other.priority) + + def __call__(self): + """Run self.callback(**self.kwargs).""" + return self.callback(**self.kwargs) + + def __repr__(self): + cls = self.__class__ + return ("%s.%s(callback=%r, failsafe=%r, priority=%r, %s)" + % (cls.__module__, cls.__name__, self.callback, + self.failsafe, self.priority, + ", ".join(['%s=%r' % (k, v) + for k, v in self.kwargs.items()]))) + + +class HookMap(dict): + + """A map of call points to lists of callbacks (Hook objects).""" + + def __new__(cls, points=None): + d = dict.__new__(cls) + for p in points or []: + d[p] = [] + return d + + def __init__(self, *a, **kw): + pass + + def attach(self, point, callback, failsafe=None, priority=None, **kwargs): + """Append a new Hook made from the supplied arguments.""" + self[point].append(Hook(callback, failsafe, priority, **kwargs)) + + def run(self, point): + """Execute all registered Hooks (callbacks) for the given point.""" + exc = None + hooks = self[point] + hooks.sort() + for hook in hooks: + # Some hooks are guaranteed to run even if others at + # the same hookpoint fail. We will still log the failure, + # but proceed on to the next hook. The only way + # to stop all processing from one of these hooks is + # to raise SystemExit and stop the whole server. + if exc is None or hook.failsafe: + try: + hook() + except (KeyboardInterrupt, SystemExit): + raise + except (cherrypy.HTTPError, cherrypy.HTTPRedirect, + cherrypy.InternalRedirect): + exc = sys.exc_info()[1] + except: + exc = sys.exc_info()[1] + cherrypy.log(traceback=True, severity=40) + if exc: + raise exc + + def __copy__(self): + newmap = self.__class__() + # We can't just use 'update' because we want copies of the + # mutable values (each is a list) as well. + for k, v in self.items(): + newmap[k] = v[:] + return newmap + copy = __copy__ + + def __repr__(self): + cls = self.__class__ + return "%s.%s(points=%r)" % ( + cls.__module__, + cls.__name__, + copykeys(self) + ) + + +# Config namespace handlers + +def hooks_namespace(k, v): + """Attach bare hooks declared in config.""" + # Use split again to allow multiple hooks for a single + # hookpoint per path (e.g. "hooks.before_handler.1"). + # Little-known fact you only get from reading source ;) + hookpoint = k.split(".", 1)[0] + if isinstance(v, basestring): + v = cherrypy.lib.attributes(v) + if not isinstance(v, Hook): + v = Hook(v) + cherrypy.serving.request.hooks[hookpoint].append(v) + + +def request_namespace(k, v): + """Attach request attributes declared in config.""" + # Provides config entries to set request.body attrs (like + # attempt_charsets). + if k[:5] == 'body.': + setattr(cherrypy.serving.request.body, k[5:], v) + else: + setattr(cherrypy.serving.request, k, v) + + +def response_namespace(k, v): + """Attach response attributes declared in config.""" + # Provides config entries to set default response headers + # http://cherrypy.org/ticket/889 + if k[:8] == 'headers.': + cherrypy.serving.response.headers[k.split('.', 1)[1]] = v + else: + setattr(cherrypy.serving.response, k, v) + + +def error_page_namespace(k, v): + """Attach error pages declared in config.""" + if k != 'default': + k = int(k) + cherrypy.serving.request.error_page[k] = v + + +hookpoints = ['on_start_resource', 'before_request_body', + 'before_handler', 'before_finalize', + 'on_end_resource', 'on_end_request', + 'before_error_response', 'after_error_response'] + + +class Request(object): + + """An HTTP request. + + This object represents the metadata of an HTTP request message; + that is, it contains attributes which describe the environment + in which the request URL, headers, and body were sent (if you + want tools to interpret the headers and body, those are elsewhere, + mostly in Tools). This 'metadata' consists of socket data, + transport characteristics, and the Request-Line. This object + also contains data regarding the configuration in effect for + the given URL, and the execution plan for generating a response. + """ + + prev = None + """ + The previous Request object (if any). This should be None + unless we are processing an InternalRedirect.""" + + # Conversation/connection attributes + local = httputil.Host("127.0.0.1", 80) + "An httputil.Host(ip, port, hostname) object for the server socket." + + remote = httputil.Host("127.0.0.1", 1111) + "An httputil.Host(ip, port, hostname) object for the client socket." + + scheme = "http" + """ + The protocol used between client and server. In most cases, + this will be either 'http' or 'https'.""" + + server_protocol = "HTTP/1.1" + """ + The HTTP version for which the HTTP server is at least + conditionally compliant.""" + + base = "" + """The (scheme://host) portion of the requested URL. + In some cases (e.g. when proxying via mod_rewrite), this may contain + path segments which cherrypy.url uses when constructing url's, but + which otherwise are ignored by CherryPy. Regardless, this value + MUST NOT end in a slash.""" + + # Request-Line attributes + request_line = "" + """ + The complete Request-Line received from the client. This is a + single string consisting of the request method, URI, and protocol + version (joined by spaces). Any final CRLF is removed.""" + + method = "GET" + """ + Indicates the HTTP method to be performed on the resource identified + by the Request-URI. Common methods include GET, HEAD, POST, PUT, and + DELETE. CherryPy allows any extension method; however, various HTTP + servers and gateways may restrict the set of allowable methods. + CherryPy applications SHOULD restrict the set (on a per-URI basis).""" + + query_string = "" + """ + The query component of the Request-URI, a string of information to be + interpreted by the resource. The query portion of a URI follows the + path component, and is separated by a '?'. For example, the URI + 'http://www.cherrypy.org/wiki?a=3&b=4' has the query component, + 'a=3&b=4'.""" + + query_string_encoding = 'utf8' + """ + The encoding expected for query string arguments after % HEX HEX decoding). + If a query string is provided that cannot be decoded with this encoding, + 404 is raised (since technically it's a different URI). If you want + arbitrary encodings to not error, set this to 'Latin-1'; you can then + encode back to bytes and re-decode to whatever encoding you like later. + """ + + protocol = (1, 1) + """The HTTP protocol version corresponding to the set + of features which should be allowed in the response. If BOTH + the client's request message AND the server's level of HTTP + compliance is HTTP/1.1, this attribute will be the tuple (1, 1). + If either is 1.0, this attribute will be the tuple (1, 0). + Lower HTTP protocol versions are not explicitly supported.""" + + params = {} + """ + A dict which combines query string (GET) and request entity (POST) + variables. This is populated in two stages: GET params are added + before the 'on_start_resource' hook, and POST params are added + between the 'before_request_body' and 'before_handler' hooks.""" + + # Message attributes + header_list = [] + """ + A list of the HTTP request headers as (name, value) tuples. + In general, you should use request.headers (a dict) instead.""" + + headers = httputil.HeaderMap() + """ + A dict-like object containing the request headers. Keys are header + names (in Title-Case format); however, you may get and set them in + a case-insensitive manner. That is, headers['Content-Type'] and + headers['content-type'] refer to the same value. Values are header + values (decoded according to :rfc:`2047` if necessary). See also: + httputil.HeaderMap, httputil.HeaderElement.""" + + cookie = SimpleCookie() + """See help(Cookie).""" + + rfile = None + """ + If the request included an entity (body), it will be available + as a stream in this attribute. However, the rfile will normally + be read for you between the 'before_request_body' hook and the + 'before_handler' hook, and the resulting string is placed into + either request.params or the request.body attribute. + + You may disable the automatic consumption of the rfile by setting + request.process_request_body to False, either in config for the desired + path, or in an 'on_start_resource' or 'before_request_body' hook. + + WARNING: In almost every case, you should not attempt to read from the + rfile stream after CherryPy's automatic mechanism has read it. If you + turn off the automatic parsing of rfile, you should read exactly the + number of bytes specified in request.headers['Content-Length']. + Ignoring either of these warnings may result in a hung request thread + or in corruption of the next (pipelined) request. + """ + + process_request_body = True + """ + If True, the rfile (if any) is automatically read and parsed, + and the result placed into request.params or request.body.""" + + methods_with_bodies = ("POST", "PUT") + """ + A sequence of HTTP methods for which CherryPy will automatically + attempt to read a body from the rfile. If you are going to change + this property, modify it on the configuration (recommended) + or on the "hook point" `on_start_resource`. + """ + + body = None + """ + If the request Content-Type is 'application/x-www-form-urlencoded' + or multipart, this will be None. Otherwise, this will be an instance + of :class:`RequestBody` (which you + can .read()); this value is set between the 'before_request_body' and + 'before_handler' hooks (assuming that process_request_body is True).""" + + # Dispatch attributes + dispatch = cherrypy.dispatch.Dispatcher() + """ + The object which looks up the 'page handler' callable and collects + config for the current request based on the path_info, other + request attributes, and the application architecture. The core + calls the dispatcher as early as possible, passing it a 'path_info' + argument. + + The default dispatcher discovers the page handler by matching path_info + to a hierarchical arrangement of objects, starting at request.app.root. + See help(cherrypy.dispatch) for more information.""" + + script_name = "" + """ + The 'mount point' of the application which is handling this request. + + This attribute MUST NOT end in a slash. If the script_name refers to + the root of the URI, it MUST be an empty string (not "/"). + """ + + path_info = "/" + """ + The 'relative path' portion of the Request-URI. This is relative + to the script_name ('mount point') of the application which is + handling this request.""" + + login = None + """ + When authentication is used during the request processing this is + set to 'False' if it failed and to the 'username' value if it succeeded. + The default 'None' implies that no authentication happened.""" + + # Note that cherrypy.url uses "if request.app:" to determine whether + # the call is during a real HTTP request or not. So leave this None. + app = None + """The cherrypy.Application object which is handling this request.""" + + handler = None + """ + The function, method, or other callable which CherryPy will call to + produce the response. The discovery of the handler and the arguments + it will receive are determined by the request.dispatch object. + By default, the handler is discovered by walking a tree of objects + starting at request.app.root, and is then passed all HTTP params + (from the query string and POST body) as keyword arguments.""" + + toolmaps = {} + """ + A nested dict of all Toolboxes and Tools in effect for this request, + of the form: {Toolbox.namespace: {Tool.name: config dict}}.""" + + config = None + """ + A flat dict of all configuration entries which apply to the + current request. These entries are collected from global config, + application config (based on request.path_info), and from handler + config (exactly how is governed by the request.dispatch object in + effect for this request; by default, handler config can be attached + anywhere in the tree between request.app.root and the final handler, + and inherits downward).""" + + is_index = None + """ + This will be True if the current request is mapped to an 'index' + resource handler (also, a 'default' handler if path_info ends with + a slash). The value may be used to automatically redirect the + user-agent to a 'more canonical' URL which either adds or removes + the trailing slash. See cherrypy.tools.trailing_slash.""" + + hooks = HookMap(hookpoints) + """ + A HookMap (dict-like object) of the form: {hookpoint: [hook, ...]}. + Each key is a str naming the hook point, and each value is a list + of hooks which will be called at that hook point during this request. + The list of hooks is generally populated as early as possible (mostly + from Tools specified in config), but may be extended at any time. + See also: _cprequest.Hook, _cprequest.HookMap, and cherrypy.tools.""" + + error_response = cherrypy.HTTPError(500).set_response + """ + The no-arg callable which will handle unexpected, untrapped errors + during request processing. This is not used for expected exceptions + (like NotFound, HTTPError, or HTTPRedirect) which are raised in + response to expected conditions (those should be customized either + via request.error_page or by overriding HTTPError.set_response). + By default, error_response uses HTTPError(500) to return a generic + error response to the user-agent.""" + + error_page = {} + """ + A dict of {error code: response filename or callable} pairs. + + The error code must be an int representing a given HTTP error code, + or the string 'default', which will be used if no matching entry + is found for a given numeric code. + + If a filename is provided, the file should contain a Python string- + formatting template, and can expect by default to receive format + values with the mapping keys %(status)s, %(message)s, %(traceback)s, + and %(version)s. The set of format mappings can be extended by + overriding HTTPError.set_response. + + If a callable is provided, it will be called by default with keyword + arguments 'status', 'message', 'traceback', and 'version', as for a + string-formatting template. The callable must return a string or + iterable of strings which will be set to response.body. It may also + override headers or perform any other processing. + + If no entry is given for an error code, and no 'default' entry exists, + a default template will be used. + """ + + show_tracebacks = True + """ + If True, unexpected errors encountered during request processing will + include a traceback in the response body.""" + + show_mismatched_params = True + """ + If True, mismatched parameters encountered during PageHandler invocation + processing will be included in the response body.""" + + throws = (KeyboardInterrupt, SystemExit, cherrypy.InternalRedirect) + """The sequence of exceptions which Request.run does not trap.""" + + throw_errors = False + """ + If True, Request.run will not trap any errors (except HTTPRedirect and + HTTPError, which are more properly called 'exceptions', not errors).""" + + closed = False + """True once the close method has been called, False otherwise.""" + + stage = None + """ + A string containing the stage reached in the request-handling process. + This is useful when debugging a live server with hung requests.""" + + namespaces = _cpconfig.NamespaceSet( + **{"hooks": hooks_namespace, + "request": request_namespace, + "response": response_namespace, + "error_page": error_page_namespace, + "tools": cherrypy.tools, + }) + + def __init__(self, local_host, remote_host, scheme="http", + server_protocol="HTTP/1.1"): + """Populate a new Request object. + + local_host should be an httputil.Host object with the server info. + remote_host should be an httputil.Host object with the client info. + scheme should be a string, either "http" or "https". + """ + self.local = local_host + self.remote = remote_host + self.scheme = scheme + self.server_protocol = server_protocol + + self.closed = False + + # Put a *copy* of the class error_page into self. + self.error_page = self.error_page.copy() + + # Put a *copy* of the class namespaces into self. + self.namespaces = self.namespaces.copy() + + self.stage = None + + def close(self): + """Run cleanup code. (Core)""" + if not self.closed: + self.closed = True + self.stage = 'on_end_request' + self.hooks.run('on_end_request') + self.stage = 'close' + + def run(self, method, path, query_string, req_protocol, headers, rfile): + r"""Process the Request. (Core) + + method, path, query_string, and req_protocol should be pulled directly + from the Request-Line (e.g. "GET /path?key=val HTTP/1.0"). + + path + This should be %XX-unquoted, but query_string should not be. + + When using Python 2, they both MUST be byte strings, + not unicode strings. + + When using Python 3, they both MUST be unicode strings, + not byte strings, and preferably not bytes \x00-\xFF + disguised as unicode. + + headers + A list of (name, value) tuples. + + rfile + A file-like object containing the HTTP request entity. + + When run() is done, the returned object should have 3 attributes: + + * status, e.g. "200 OK" + * header_list, a list of (name, value) tuples + * body, an iterable yielding strings + + Consumer code (HTTP servers) should then access these response + attributes to build the outbound stream. + + """ + response = cherrypy.serving.response + self.stage = 'run' + try: + self.error_response = cherrypy.HTTPError(500).set_response + + self.method = method + path = path or "/" + self.query_string = query_string or '' + self.params = {} + + # Compare request and server HTTP protocol versions, in case our + # server does not support the requested protocol. Limit our output + # to min(req, server). We want the following output: + # request server actual written supported response + # protocol protocol response protocol feature set + # a 1.0 1.0 1.0 1.0 + # b 1.0 1.1 1.1 1.0 + # c 1.1 1.0 1.0 1.0 + # d 1.1 1.1 1.1 1.1 + # Notice that, in (b), the response will be "HTTP/1.1" even though + # the client only understands 1.0. RFC 2616 10.5.6 says we should + # only return 505 if the _major_ version is different. + rp = int(req_protocol[5]), int(req_protocol[7]) + sp = int(self.server_protocol[5]), int(self.server_protocol[7]) + self.protocol = min(rp, sp) + response.headers.protocol = self.protocol + + # Rebuild first line of the request (e.g. "GET /path HTTP/1.0"). + url = path + if query_string: + url += '?' + query_string + self.request_line = '%s %s %s' % (method, url, req_protocol) + + self.header_list = list(headers) + self.headers = httputil.HeaderMap() + + self.rfile = rfile + self.body = None + + self.cookie = SimpleCookie() + self.handler = None + + # path_info should be the path from the + # app root (script_name) to the handler. + self.script_name = self.app.script_name + self.path_info = pi = path[len(self.script_name):] + + self.stage = 'respond' + self.respond(pi) + + except self.throws: + raise + except: + if self.throw_errors: + raise + else: + # Failure in setup, error handler or finalize. Bypass them. + # Can't use handle_error because we may not have hooks yet. + cherrypy.log(traceback=True, severity=40) + if self.show_tracebacks: + body = format_exc() + else: + body = "" + r = bare_error(body) + response.output_status, response.header_list, response.body = r + + if self.method == "HEAD": + # HEAD requests MUST NOT return a message-body in the response. + response.body = [] + + try: + cherrypy.log.access() + except: + cherrypy.log.error(traceback=True) + + if response.timed_out: + raise cherrypy.TimeoutError() + + return response + + # Uncomment for stage debugging + # stage = property(lambda self: self._stage, lambda self, v: print(v)) + + def respond(self, path_info): + """Generate a response for the resource at self.path_info. (Core)""" + response = cherrypy.serving.response + try: + try: + try: + if self.app is None: + raise cherrypy.NotFound() + + # Get the 'Host' header, so we can HTTPRedirect properly. + self.stage = 'process_headers' + self.process_headers() + + # Make a copy of the class hooks + self.hooks = self.__class__.hooks.copy() + self.toolmaps = {} + + self.stage = 'get_resource' + self.get_resource(path_info) + + self.body = _cpreqbody.RequestBody( + self.rfile, self.headers, request_params=self.params) + + self.namespaces(self.config) + + self.stage = 'on_start_resource' + self.hooks.run('on_start_resource') + + # Parse the querystring + self.stage = 'process_query_string' + self.process_query_string() + + # Process the body + if self.process_request_body: + if self.method not in self.methods_with_bodies: + self.process_request_body = False + self.stage = 'before_request_body' + self.hooks.run('before_request_body') + if self.process_request_body: + self.body.process() + + # Run the handler + self.stage = 'before_handler' + self.hooks.run('before_handler') + if self.handler: + self.stage = 'handler' + response.body = self.handler() + + # Finalize + self.stage = 'before_finalize' + self.hooks.run('before_finalize') + response.finalize() + except (cherrypy.HTTPRedirect, cherrypy.HTTPError): + inst = sys.exc_info()[1] + inst.set_response() + self.stage = 'before_finalize (HTTPError)' + self.hooks.run('before_finalize') + response.finalize() + finally: + self.stage = 'on_end_resource' + self.hooks.run('on_end_resource') + except self.throws: + raise + except: + if self.throw_errors: + raise + self.handle_error() + + def process_query_string(self): + """Parse the query string into Python structures. (Core)""" + try: + p = httputil.parse_query_string( + self.query_string, encoding=self.query_string_encoding) + except UnicodeDecodeError: + raise cherrypy.HTTPError( + 404, "The given query string could not be processed. Query " + "strings for this resource must be encoded with %r." % + self.query_string_encoding) + + # Python 2 only: keyword arguments must be byte strings (type 'str'). + if not py3k: + for key, value in p.items(): + if isinstance(key, unicode): + del p[key] + p[key.encode(self.query_string_encoding)] = value + self.params.update(p) + + def process_headers(self): + """Parse HTTP header data into Python structures. (Core)""" + # Process the headers into self.headers + headers = self.headers + for name, value in self.header_list: + # Call title() now (and use dict.__method__(headers)) + # so title doesn't have to be called twice. + name = name.title() + value = value.strip() + + # Warning: if there is more than one header entry for cookies + # (AFAIK, only Konqueror does that), only the last one will + # remain in headers (but they will be correctly stored in + # request.cookie). + if "=?" in value: + dict.__setitem__(headers, name, httputil.decode_TEXT(value)) + else: + dict.__setitem__(headers, name, value) + + # Handle cookies differently because on Konqueror, multiple + # cookies come on different lines with the same key + if name == 'Cookie': + try: + self.cookie.load(value) + except CookieError: + msg = "Illegal cookie name %s" % value.split('=')[0] + raise cherrypy.HTTPError(400, msg) + + if not dict.__contains__(headers, 'Host'): + # All Internet-based HTTP/1.1 servers MUST respond with a 400 + # (Bad Request) status code to any HTTP/1.1 request message + # which lacks a Host header field. + if self.protocol >= (1, 1): + msg = "HTTP/1.1 requires a 'Host' request header." + raise cherrypy.HTTPError(400, msg) + host = dict.get(headers, 'Host') + if not host: + host = self.local.name or self.local.ip + self.base = "%s://%s" % (self.scheme, host) + + def get_resource(self, path): + """Call a dispatcher (which sets self.handler and .config). (Core)""" + # First, see if there is a custom dispatch at this URI. Custom + # dispatchers can only be specified in app.config, not in _cp_config + # (since custom dispatchers may not even have an app.root). + dispatch = self.app.find_config( + path, "request.dispatch", self.dispatch) + + # dispatch() should set self.handler and self.config + dispatch(path) + + def handle_error(self): + """Handle the last unanticipated exception. (Core)""" + try: + self.hooks.run("before_error_response") + if self.error_response: + self.error_response() + self.hooks.run("after_error_response") + cherrypy.serving.response.finalize() + except cherrypy.HTTPRedirect: + inst = sys.exc_info()[1] + inst.set_response() + cherrypy.serving.response.finalize() + + # ------------------------- Properties ------------------------- # + + def _get_body_params(self): + warnings.warn( + "body_params is deprecated in CherryPy 3.2, will be removed in " + "CherryPy 3.3.", + DeprecationWarning + ) + return self.body.params + body_params = property(_get_body_params, + doc=""" + If the request Content-Type is 'application/x-www-form-urlencoded' or + multipart, this will be a dict of the params pulled from the entity + body; that is, it will be the portion of request.params that come + from the message body (sometimes called "POST params", although they + can be sent with various HTTP method verbs). This value is set between + the 'before_request_body' and 'before_handler' hooks (assuming that + process_request_body is True). + + Deprecated in 3.2, will be removed for 3.3 in favor of + :attr:`request.body.params`.""") + + +class ResponseBody(object): + + """The body of the HTTP response (the response entity).""" + + if py3k: + unicode_err = ("Page handlers MUST return bytes. Use tools.encode " + "if you wish to return unicode.") + + def __get__(self, obj, objclass=None): + if obj is None: + # When calling on the class instead of an instance... + return self + else: + return obj._body + + def __set__(self, obj, value): + # Convert the given value to an iterable object. + if py3k and isinstance(value, str): + raise ValueError(self.unicode_err) + + if isinstance(value, basestring): + # strings get wrapped in a list because iterating over a single + # item list is much faster than iterating over every character + # in a long string. + if value: + value = [value] + else: + # [''] doesn't evaluate to False, so replace it with []. + value = [] + elif py3k and isinstance(value, list): + # every item in a list must be bytes... + for i, item in enumerate(value): + if isinstance(item, str): + raise ValueError(self.unicode_err) + # Don't use isinstance here; io.IOBase which has an ABC takes + # 1000 times as long as, say, isinstance(value, str) + elif hasattr(value, 'read'): + value = file_generator(value) + elif value is None: + value = [] + obj._body = value + + +class Response(object): + + """An HTTP Response, including status, headers, and body.""" + + status = "" + """The HTTP Status-Code and Reason-Phrase.""" + + header_list = [] + """ + A list of the HTTP response headers as (name, value) tuples. + In general, you should use response.headers (a dict) instead. This + attribute is generated from response.headers and is not valid until + after the finalize phase.""" + + headers = httputil.HeaderMap() + """ + A dict-like object containing the response headers. Keys are header + names (in Title-Case format); however, you may get and set them in + a case-insensitive manner. That is, headers['Content-Type'] and + headers['content-type'] refer to the same value. Values are header + values (decoded according to :rfc:`2047` if necessary). + + .. seealso:: classes :class:`HeaderMap`, :class:`HeaderElement` + """ + + cookie = SimpleCookie() + """See help(Cookie).""" + + body = ResponseBody() + """The body (entity) of the HTTP response.""" + + time = None + """The value of time.time() when created. Use in HTTP dates.""" + + timeout = 300 + """Seconds after which the response will be aborted.""" + + timed_out = False + """ + Flag to indicate the response should be aborted, because it has + exceeded its timeout.""" + + stream = False + """If False, buffer the response body.""" + + def __init__(self): + self.status = None + self.header_list = None + self._body = [] + self.time = time.time() + + self.headers = httputil.HeaderMap() + # Since we know all our keys are titled strings, we can + # bypass HeaderMap.update and get a big speed boost. + dict.update(self.headers, { + "Content-Type": 'text/html', + "Server": "CherryPy/" + cherrypy.__version__, + "Date": httputil.HTTPDate(self.time), + }) + self.cookie = SimpleCookie() + + def collapse_body(self): + """Collapse self.body to a single string; replace it and return it.""" + if isinstance(self.body, basestring): + return self.body + + newbody = [] + for chunk in self.body: + if py3k and not isinstance(chunk, bytes): + raise TypeError("Chunk %s is not of type 'bytes'." % + repr(chunk)) + newbody.append(chunk) + newbody = ntob('').join(newbody) + + self.body = newbody + return newbody + + def finalize(self): + """Transform headers (and cookies) into self.header_list. (Core)""" + try: + code, reason, _ = httputil.valid_status(self.status) + except ValueError: + raise cherrypy.HTTPError(500, sys.exc_info()[1].args[0]) + + headers = self.headers + + self.status = "%s %s" % (code, reason) + self.output_status = ntob(str(code), 'ascii') + \ + ntob(" ") + headers.encode(reason) + + if self.stream: + # The upshot: wsgiserver will chunk the response if + # you pop Content-Length (or set it explicitly to None). + # Note that lib.static sets C-L to the file's st_size. + if dict.get(headers, 'Content-Length') is None: + dict.pop(headers, 'Content-Length', None) + elif code < 200 or code in (204, 205, 304): + # "All 1xx (informational), 204 (no content), + # and 304 (not modified) responses MUST NOT + # include a message-body." + dict.pop(headers, 'Content-Length', None) + self.body = ntob("") + else: + # Responses which are not streamed should have a Content-Length, + # but allow user code to set Content-Length if desired. + if dict.get(headers, 'Content-Length') is None: + content = self.collapse_body() + dict.__setitem__(headers, 'Content-Length', len(content)) + + # Transform our header dict into a list of tuples. + self.header_list = h = headers.output() + + cookie = self.cookie.output() + if cookie: + for line in cookie.split("\n"): + if line.endswith("\r"): + # Python 2.4 emits cookies joined by LF but 2.5+ by CRLF. + line = line[:-1] + name, value = line.split(": ", 1) + if isinstance(name, unicodestr): + name = name.encode("ISO-8859-1") + if isinstance(value, unicodestr): + value = headers.encode(value) + h.append((name, value)) + + def check_timeout(self): + """If now > self.time + self.timeout, set self.timed_out. + + This purposefully sets a flag, rather than raising an error, + so that a monitor thread can interrupt the Response thread. + """ + if time.time() > self.time + self.timeout: + self.timed_out = True diff --git a/lib/cherrypy/_cpserver.py b/lib/cherrypy/_cpserver.py new file mode 100644 index 00000000..a31e7428 --- /dev/null +++ b/lib/cherrypy/_cpserver.py @@ -0,0 +1,226 @@ +"""Manage HTTP servers with CherryPy.""" + +import warnings + +import cherrypy +from cherrypy.lib import attributes +from cherrypy._cpcompat import basestring, py3k + +# We import * because we want to export check_port +# et al as attributes of this module. +from cherrypy.process.servers import * + + +class Server(ServerAdapter): + + """An adapter for an HTTP server. + + You can set attributes (like socket_host and socket_port) + on *this* object (which is probably cherrypy.server), and call + quickstart. For example:: + + cherrypy.server.socket_port = 80 + cherrypy.quickstart() + """ + + socket_port = 8080 + """The TCP port on which to listen for connections.""" + + _socket_host = '127.0.0.1' + + def _get_socket_host(self): + return self._socket_host + + def _set_socket_host(self, value): + if value == '': + raise ValueError("The empty string ('') is not an allowed value. " + "Use '0.0.0.0' instead to listen on all active " + "interfaces (INADDR_ANY).") + self._socket_host = value + socket_host = property( + _get_socket_host, + _set_socket_host, + doc="""The hostname or IP address on which to listen for connections. + + Host values may be any IPv4 or IPv6 address, or any valid hostname. + The string 'localhost' is a synonym for '127.0.0.1' (or '::1', if + your hosts file prefers IPv6). The string '0.0.0.0' is a special + IPv4 entry meaning "any active interface" (INADDR_ANY), and '::' + is the similar IN6ADDR_ANY for IPv6. The empty string or None are + not allowed.""") + + socket_file = None + """If given, the name of the UNIX socket to use instead of TCP/IP. + + When this option is not None, the `socket_host` and `socket_port` options + are ignored.""" + + socket_queue_size = 5 + """The 'backlog' argument to socket.listen(); specifies the maximum number + of queued connections (default 5).""" + + socket_timeout = 10 + """The timeout in seconds for accepted connections (default 10).""" + + accepted_queue_size = -1 + """The maximum number of requests which will be queued up before + the server refuses to accept it (default -1, meaning no limit).""" + + accepted_queue_timeout = 10 + """The timeout in seconds for attempting to add a request to the + queue when the queue is full (default 10).""" + + shutdown_timeout = 5 + """The time to wait for HTTP worker threads to clean up.""" + + protocol_version = 'HTTP/1.1' + """The version string to write in the Status-Line of all HTTP responses, + for example, "HTTP/1.1" (the default). Depending on the HTTP server used, + this should also limit the supported features used in the response.""" + + thread_pool = 10 + """The number of worker threads to start up in the pool.""" + + thread_pool_max = -1 + """The maximum size of the worker-thread pool. Use -1 to indicate no limit. + """ + + max_request_header_size = 500 * 1024 + """The maximum number of bytes allowable in the request headers. + If exceeded, the HTTP server should return "413 Request Entity Too Large". + """ + + max_request_body_size = 100 * 1024 * 1024 + """The maximum number of bytes allowable in the request body. If exceeded, + the HTTP server should return "413 Request Entity Too Large".""" + + instance = None + """If not None, this should be an HTTP server instance (such as + CPWSGIServer) which cherrypy.server will control. Use this when you need + more control over object instantiation than is available in the various + configuration options.""" + + ssl_context = None + """When using PyOpenSSL, an instance of SSL.Context.""" + + ssl_certificate = None + """The filename of the SSL certificate to use.""" + + ssl_certificate_chain = None + """When using PyOpenSSL, the certificate chain to pass to + Context.load_verify_locations.""" + + ssl_private_key = None + """The filename of the private key to use with SSL.""" + + if py3k: + ssl_module = 'builtin' + """The name of a registered SSL adaptation module to use with + the builtin WSGI server. Builtin options are: 'builtin' (to + use the SSL library built into recent versions of Python). + You may also register your own classes in the + wsgiserver.ssl_adapters dict.""" + else: + ssl_module = 'pyopenssl' + """The name of a registered SSL adaptation module to use with the + builtin WSGI server. Builtin options are 'builtin' (to use the SSL + library built into recent versions of Python) and 'pyopenssl' (to + use the PyOpenSSL project, which you must install separately). You + may also register your own classes in the wsgiserver.ssl_adapters + dict.""" + + statistics = False + """Turns statistics-gathering on or off for aware HTTP servers.""" + + nodelay = True + """If True (the default since 3.1), sets the TCP_NODELAY socket option.""" + + wsgi_version = (1, 0) + """The WSGI version tuple to use with the builtin WSGI server. + The provided options are (1, 0) [which includes support for PEP 3333, + which declares it covers WSGI version 1.0.1 but still mandates the + wsgi.version (1, 0)] and ('u', 0), an experimental unicode version. + You may create and register your own experimental versions of the WSGI + protocol by adding custom classes to the wsgiserver.wsgi_gateways dict.""" + + def __init__(self): + self.bus = cherrypy.engine + self.httpserver = None + self.interrupt = None + self.running = False + + def httpserver_from_self(self, httpserver=None): + """Return a (httpserver, bind_addr) pair based on self attributes.""" + if httpserver is None: + httpserver = self.instance + if httpserver is None: + from cherrypy import _cpwsgi_server + httpserver = _cpwsgi_server.CPWSGIServer(self) + if isinstance(httpserver, basestring): + # Is anyone using this? Can I add an arg? + httpserver = attributes(httpserver)(self) + return httpserver, self.bind_addr + + def start(self): + """Start the HTTP server.""" + if not self.httpserver: + self.httpserver, self.bind_addr = self.httpserver_from_self() + ServerAdapter.start(self) + start.priority = 75 + + def _get_bind_addr(self): + if self.socket_file: + return self.socket_file + if self.socket_host is None and self.socket_port is None: + return None + return (self.socket_host, self.socket_port) + + def _set_bind_addr(self, value): + if value is None: + self.socket_file = None + self.socket_host = None + self.socket_port = None + elif isinstance(value, basestring): + self.socket_file = value + self.socket_host = None + self.socket_port = None + else: + try: + self.socket_host, self.socket_port = value + self.socket_file = None + except ValueError: + raise ValueError("bind_addr must be a (host, port) tuple " + "(for TCP sockets) or a string (for Unix " + "domain sockets), not %r" % value) + bind_addr = property( + _get_bind_addr, + _set_bind_addr, + doc='A (host, port) tuple for TCP sockets or ' + 'a str for Unix domain sockets.') + + def base(self): + """Return the base (scheme://host[:port] or sock file) for this server. + """ + if self.socket_file: + return self.socket_file + + host = self.socket_host + if host in ('0.0.0.0', '::'): + # 0.0.0.0 is INADDR_ANY and :: is IN6ADDR_ANY. + # Look up the host name, which should be the + # safest thing to spit out in a URL. + import socket + host = socket.gethostname() + + port = self.socket_port + + if self.ssl_certificate: + scheme = "https" + if port != 443: + host += ":%s" % port + else: + scheme = "http" + if port != 80: + host += ":%s" % port + + return "%s://%s" % (scheme, host) diff --git a/lib/cherrypy/_cpthreadinglocal.py b/lib/cherrypy/_cpthreadinglocal.py new file mode 100644 index 00000000..238c3224 --- /dev/null +++ b/lib/cherrypy/_cpthreadinglocal.py @@ -0,0 +1,241 @@ +# This is a backport of Python-2.4's threading.local() implementation + +"""Thread-local objects + +(Note that this module provides a Python version of thread + threading.local class. Depending on the version of Python you're + using, there may be a faster one available. You should always import + the local class from threading.) + +Thread-local objects support the management of thread-local data. +If you have data that you want to be local to a thread, simply create +a thread-local object and use its attributes: + + >>> mydata = local() + >>> mydata.number = 42 + >>> mydata.number + 42 + +You can also access the local-object's dictionary: + + >>> mydata.__dict__ + {'number': 42} + >>> mydata.__dict__.setdefault('widgets', []) + [] + >>> mydata.widgets + [] + +What's important about thread-local objects is that their data are +local to a thread. If we access the data in a different thread: + + >>> log = [] + >>> def f(): + ... items = mydata.__dict__.items() + ... items.sort() + ... log.append(items) + ... mydata.number = 11 + ... log.append(mydata.number) + + >>> import threading + >>> thread = threading.Thread(target=f) + >>> thread.start() + >>> thread.join() + >>> log + [[], 11] + +we get different data. Furthermore, changes made in the other thread +don't affect data seen in this thread: + + >>> mydata.number + 42 + +Of course, values you get from a local object, including a __dict__ +attribute, are for whatever thread was current at the time the +attribute was read. For that reason, you generally don't want to save +these values across threads, as they apply only to the thread they +came from. + +You can create custom local objects by subclassing the local class: + + >>> class MyLocal(local): + ... number = 2 + ... initialized = False + ... def __init__(self, **kw): + ... if self.initialized: + ... raise SystemError('__init__ called too many times') + ... self.initialized = True + ... self.__dict__.update(kw) + ... def squared(self): + ... return self.number ** 2 + +This can be useful to support default values, methods and +initialization. Note that if you define an __init__ method, it will be +called each time the local object is used in a separate thread. This +is necessary to initialize each thread's dictionary. + +Now if we create a local object: + + >>> mydata = MyLocal(color='red') + +Now we have a default number: + + >>> mydata.number + 2 + +an initial color: + + >>> mydata.color + 'red' + >>> del mydata.color + +And a method that operates on the data: + + >>> mydata.squared() + 4 + +As before, we can access the data in a separate thread: + + >>> log = [] + >>> thread = threading.Thread(target=f) + >>> thread.start() + >>> thread.join() + >>> log + [[('color', 'red'), ('initialized', True)], 11] + +without affecting this thread's data: + + >>> mydata.number + 2 + >>> mydata.color + Traceback (most recent call last): + ... + AttributeError: 'MyLocal' object has no attribute 'color' + +Note that subclasses can define slots, but they are not thread +local. They are shared across threads: + + >>> class MyLocal(local): + ... __slots__ = 'number' + + >>> mydata = MyLocal() + >>> mydata.number = 42 + >>> mydata.color = 'red' + +So, the separate thread: + + >>> thread = threading.Thread(target=f) + >>> thread.start() + >>> thread.join() + +affects what we see: + + >>> mydata.number + 11 + +>>> del mydata +""" + +# Threading import is at end + + +class _localbase(object): + __slots__ = '_local__key', '_local__args', '_local__lock' + + def __new__(cls, *args, **kw): + self = object.__new__(cls) + key = 'thread.local.' + str(id(self)) + object.__setattr__(self, '_local__key', key) + object.__setattr__(self, '_local__args', (args, kw)) + object.__setattr__(self, '_local__lock', RLock()) + + if args or kw and (cls.__init__ is object.__init__): + raise TypeError("Initialization arguments are not supported") + + # We need to create the thread dict in anticipation of + # __init__ being called, to make sure we don't call it + # again ourselves. + dict = object.__getattribute__(self, '__dict__') + currentThread().__dict__[key] = dict + + return self + + +def _patch(self): + key = object.__getattribute__(self, '_local__key') + d = currentThread().__dict__.get(key) + if d is None: + d = {} + currentThread().__dict__[key] = d + object.__setattr__(self, '__dict__', d) + + # we have a new instance dict, so call out __init__ if we have + # one + cls = type(self) + if cls.__init__ is not object.__init__: + args, kw = object.__getattribute__(self, '_local__args') + cls.__init__(self, *args, **kw) + else: + object.__setattr__(self, '__dict__', d) + + +class local(_localbase): + + def __getattribute__(self, name): + lock = object.__getattribute__(self, '_local__lock') + lock.acquire() + try: + _patch(self) + return object.__getattribute__(self, name) + finally: + lock.release() + + def __setattr__(self, name, value): + lock = object.__getattribute__(self, '_local__lock') + lock.acquire() + try: + _patch(self) + return object.__setattr__(self, name, value) + finally: + lock.release() + + def __delattr__(self, name): + lock = object.__getattribute__(self, '_local__lock') + lock.acquire() + try: + _patch(self) + return object.__delattr__(self, name) + finally: + lock.release() + + def __del__(): + threading_enumerate = enumerate + __getattribute__ = object.__getattribute__ + + def __del__(self): + key = __getattribute__(self, '_local__key') + + try: + threads = list(threading_enumerate()) + except: + # if enumerate fails, as it seems to do during + # shutdown, we'll skip cleanup under the assumption + # that there is nothing to clean up + return + + for thread in threads: + try: + __dict__ = thread.__dict__ + except AttributeError: + # Thread is dying, rest in peace + continue + + if key in __dict__: + try: + del __dict__[key] + except KeyError: + pass # didn't have anything in this thread + + return __del__ + __del__ = __del__() + +from threading import currentThread, enumerate, RLock diff --git a/lib/cherrypy/_cptools.py b/lib/cherrypy/_cptools.py new file mode 100644 index 00000000..06a56e87 --- /dev/null +++ b/lib/cherrypy/_cptools.py @@ -0,0 +1,529 @@ +"""CherryPy tools. A "tool" is any helper, adapted to CP. + +Tools are usually designed to be used in a variety of ways (although some +may only offer one if they choose): + + Library calls + All tools are callables that can be used wherever needed. + The arguments are straightforward and should be detailed within the + docstring. + + Function decorators + All tools, when called, may be used as decorators which configure + individual CherryPy page handlers (methods on the CherryPy tree). + That is, "@tools.anytool()" should "turn on" the tool via the + decorated function's _cp_config attribute. + + CherryPy config + If a tool exposes a "_setup" callable, it will be called + once per Request (if the feature is "turned on" via config). + +Tools may be implemented as any object with a namespace. The builtins +are generally either modules or instances of the tools.Tool class. +""" + +import sys +import warnings + +import cherrypy + + +def _getargs(func): + """Return the names of all static arguments to the given function.""" + # Use this instead of importing inspect for less mem overhead. + import types + if sys.version_info >= (3, 0): + if isinstance(func, types.MethodType): + func = func.__func__ + co = func.__code__ + else: + if isinstance(func, types.MethodType): + func = func.im_func + co = func.func_code + return co.co_varnames[:co.co_argcount] + + +_attr_error = ( + "CherryPy Tools cannot be turned on directly. Instead, turn them " + "on via config, or use them as decorators on your page handlers." +) + + +class Tool(object): + + """A registered function for use with CherryPy request-processing hooks. + + help(tool.callable) should give you more information about this Tool. + """ + + namespace = "tools" + + def __init__(self, point, callable, name=None, priority=50): + self._point = point + self.callable = callable + self._name = name + self._priority = priority + self.__doc__ = self.callable.__doc__ + self._setargs() + + def _get_on(self): + raise AttributeError(_attr_error) + + def _set_on(self, value): + raise AttributeError(_attr_error) + on = property(_get_on, _set_on) + + def _setargs(self): + """Copy func parameter names to obj attributes.""" + try: + for arg in _getargs(self.callable): + setattr(self, arg, None) + except (TypeError, AttributeError): + if hasattr(self.callable, "__call__"): + for arg in _getargs(self.callable.__call__): + setattr(self, arg, None) + # IronPython 1.0 raises NotImplementedError because + # inspect.getargspec tries to access Python bytecode + # in co_code attribute. + except NotImplementedError: + pass + # IronPython 1B1 may raise IndexError in some cases, + # but if we trap it here it doesn't prevent CP from working. + except IndexError: + pass + + def _merged_args(self, d=None): + """Return a dict of configuration entries for this Tool.""" + if d: + conf = d.copy() + else: + conf = {} + + tm = cherrypy.serving.request.toolmaps[self.namespace] + if self._name in tm: + conf.update(tm[self._name]) + + if "on" in conf: + del conf["on"] + + return conf + + def __call__(self, *args, **kwargs): + """Compile-time decorator (turn on the tool in config). + + For example:: + + @tools.proxy() + def whats_my_base(self): + return cherrypy.request.base + whats_my_base.exposed = True + """ + if args: + raise TypeError("The %r Tool does not accept positional " + "arguments; you must use keyword arguments." + % self._name) + + def tool_decorator(f): + if not hasattr(f, "_cp_config"): + f._cp_config = {} + subspace = self.namespace + "." + self._name + "." + f._cp_config[subspace + "on"] = True + for k, v in kwargs.items(): + f._cp_config[subspace + k] = v + return f + return tool_decorator + + def _setup(self): + """Hook this tool into cherrypy.request. + + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + conf = self._merged_args() + p = conf.pop("priority", None) + if p is None: + p = getattr(self.callable, "priority", self._priority) + cherrypy.serving.request.hooks.attach(self._point, self.callable, + priority=p, **conf) + + +class HandlerTool(Tool): + + """Tool which is called 'before main', that may skip normal handlers. + + If the tool successfully handles the request (by setting response.body), + if should return True. This will cause CherryPy to skip any 'normal' page + handler. If the tool did not handle the request, it should return False + to tell CherryPy to continue on and call the normal page handler. If the + tool is declared AS a page handler (see the 'handler' method), returning + False will raise NotFound. + """ + + def __init__(self, callable, name=None): + Tool.__init__(self, 'before_handler', callable, name) + + def handler(self, *args, **kwargs): + """Use this tool as a CherryPy page handler. + + For example:: + + class Root: + nav = tools.staticdir.handler(section="/nav", dir="nav", + root=absDir) + """ + def handle_func(*a, **kw): + handled = self.callable(*args, **self._merged_args(kwargs)) + if not handled: + raise cherrypy.NotFound() + return cherrypy.serving.response.body + handle_func.exposed = True + return handle_func + + def _wrapper(self, **kwargs): + if self.callable(**kwargs): + cherrypy.serving.request.handler = None + + def _setup(self): + """Hook this tool into cherrypy.request. + + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + conf = self._merged_args() + p = conf.pop("priority", None) + if p is None: + p = getattr(self.callable, "priority", self._priority) + cherrypy.serving.request.hooks.attach(self._point, self._wrapper, + priority=p, **conf) + + +class HandlerWrapperTool(Tool): + + """Tool which wraps request.handler in a provided wrapper function. + + The 'newhandler' arg must be a handler wrapper function that takes a + 'next_handler' argument, plus ``*args`` and ``**kwargs``. Like all + page handler + functions, it must return an iterable for use as cherrypy.response.body. + + For example, to allow your 'inner' page handlers to return dicts + which then get interpolated into a template:: + + def interpolator(next_handler, *args, **kwargs): + filename = cherrypy.request.config.get('template') + cherrypy.response.template = env.get_template(filename) + response_dict = next_handler(*args, **kwargs) + return cherrypy.response.template.render(**response_dict) + cherrypy.tools.jinja = HandlerWrapperTool(interpolator) + """ + + def __init__(self, newhandler, point='before_handler', name=None, + priority=50): + self.newhandler = newhandler + self._point = point + self._name = name + self._priority = priority + + def callable(self, *args, **kwargs): + innerfunc = cherrypy.serving.request.handler + + def wrap(*args, **kwargs): + return self.newhandler(innerfunc, *args, **kwargs) + cherrypy.serving.request.handler = wrap + + +class ErrorTool(Tool): + + """Tool which is used to replace the default request.error_response.""" + + def __init__(self, callable, name=None): + Tool.__init__(self, None, callable, name) + + def _wrapper(self): + self.callable(**self._merged_args()) + + def _setup(self): + """Hook this tool into cherrypy.request. + + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + cherrypy.serving.request.error_response = self._wrapper + + +# Builtin tools # + +from cherrypy.lib import cptools, encoding, auth, static, jsontools +from cherrypy.lib import sessions as _sessions, xmlrpcutil as _xmlrpc +from cherrypy.lib import caching as _caching +from cherrypy.lib import auth_basic, auth_digest + + +class SessionTool(Tool): + + """Session Tool for CherryPy. + + sessions.locking + When 'implicit' (the default), the session will be locked for you, + just before running the page handler. + + When 'early', the session will be locked before reading the request + body. This is off by default for safety reasons; for example, + a large upload would block the session, denying an AJAX + progress meter + (`issue `_). + + When 'explicit' (or any other value), you need to call + cherrypy.session.acquire_lock() yourself before using + session data. + """ + + def __init__(self): + # _sessions.init must be bound after headers are read + Tool.__init__(self, 'before_request_body', _sessions.init) + + def _lock_session(self): + cherrypy.serving.session.acquire_lock() + + def _setup(self): + """Hook this tool into cherrypy.request. + + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + hooks = cherrypy.serving.request.hooks + + conf = self._merged_args() + + p = conf.pop("priority", None) + if p is None: + p = getattr(self.callable, "priority", self._priority) + + hooks.attach(self._point, self.callable, priority=p, **conf) + + locking = conf.pop('locking', 'implicit') + if locking == 'implicit': + hooks.attach('before_handler', self._lock_session) + elif locking == 'early': + # Lock before the request body (but after _sessions.init runs!) + hooks.attach('before_request_body', self._lock_session, + priority=60) + else: + # Don't lock + pass + + hooks.attach('before_finalize', _sessions.save) + hooks.attach('on_end_request', _sessions.close) + + def regenerate(self): + """Drop the current session and make a new one (with a new id).""" + sess = cherrypy.serving.session + sess.regenerate() + + # Grab cookie-relevant tool args + conf = dict([(k, v) for k, v in self._merged_args().items() + if k in ('path', 'path_header', 'name', 'timeout', + 'domain', 'secure')]) + _sessions.set_response_cookie(**conf) + + +class XMLRPCController(object): + + """A Controller (page handler collection) for XML-RPC. + + To use it, have your controllers subclass this base class (it will + turn on the tool for you). + + You can also supply the following optional config entries:: + + tools.xmlrpc.encoding: 'utf-8' + tools.xmlrpc.allow_none: 0 + + XML-RPC is a rather discontinuous layer over HTTP; dispatching to the + appropriate handler must first be performed according to the URL, and + then a second dispatch step must take place according to the RPC method + specified in the request body. It also allows a superfluous "/RPC2" + prefix in the URL, supplies its own handler args in the body, and + requires a 200 OK "Fault" response instead of 404 when the desired + method is not found. + + Therefore, XML-RPC cannot be implemented for CherryPy via a Tool alone. + This Controller acts as the dispatch target for the first half (based + on the URL); it then reads the RPC method from the request body and + does its own second dispatch step based on that method. It also reads + body params, and returns a Fault on error. + + The XMLRPCDispatcher strips any /RPC2 prefix; if you aren't using /RPC2 + in your URL's, you can safely skip turning on the XMLRPCDispatcher. + Otherwise, you need to use declare it in config:: + + request.dispatch: cherrypy.dispatch.XMLRPCDispatcher() + """ + + # Note we're hard-coding this into the 'tools' namespace. We could do + # a huge amount of work to make it relocatable, but the only reason why + # would be if someone actually disabled the default_toolbox. Meh. + _cp_config = {'tools.xmlrpc.on': True} + + def default(self, *vpath, **params): + rpcparams, rpcmethod = _xmlrpc.process_body() + + subhandler = self + for attr in str(rpcmethod).split('.'): + subhandler = getattr(subhandler, attr, None) + + if subhandler and getattr(subhandler, "exposed", False): + body = subhandler(*(vpath + rpcparams), **params) + + else: + # https://bitbucket.org/cherrypy/cherrypy/issue/533 + # if a method is not found, an xmlrpclib.Fault should be returned + # raising an exception here will do that; see + # cherrypy.lib.xmlrpcutil.on_error + raise Exception('method "%s" is not supported' % attr) + + conf = cherrypy.serving.request.toolmaps['tools'].get("xmlrpc", {}) + _xmlrpc.respond(body, + conf.get('encoding', 'utf-8'), + conf.get('allow_none', 0)) + return cherrypy.serving.response.body + default.exposed = True + + +class SessionAuthTool(HandlerTool): + + def _setargs(self): + for name in dir(cptools.SessionAuth): + if not name.startswith("__"): + setattr(self, name, None) + + +class CachingTool(Tool): + + """Caching Tool for CherryPy.""" + + def _wrapper(self, **kwargs): + request = cherrypy.serving.request + if _caching.get(**kwargs): + request.handler = None + else: + if request.cacheable: + # Note the devious technique here of adding hooks on the fly + request.hooks.attach('before_finalize', _caching.tee_output, + priority=90) + _wrapper.priority = 20 + + def _setup(self): + """Hook caching into cherrypy.request.""" + conf = self._merged_args() + + p = conf.pop("priority", None) + cherrypy.serving.request.hooks.attach('before_handler', self._wrapper, + priority=p, **conf) + + +class Toolbox(object): + + """A collection of Tools. + + This object also functions as a config namespace handler for itself. + Custom toolboxes should be added to each Application's toolboxes dict. + """ + + def __init__(self, namespace): + self.namespace = namespace + + def __setattr__(self, name, value): + # If the Tool._name is None, supply it from the attribute name. + if isinstance(value, Tool): + if value._name is None: + value._name = name + value.namespace = self.namespace + object.__setattr__(self, name, value) + + def __enter__(self): + """Populate request.toolmaps from tools specified in config.""" + cherrypy.serving.request.toolmaps[self.namespace] = map = {} + + def populate(k, v): + toolname, arg = k.split(".", 1) + bucket = map.setdefault(toolname, {}) + bucket[arg] = v + return populate + + def __exit__(self, exc_type, exc_val, exc_tb): + """Run tool._setup() for each tool in our toolmap.""" + map = cherrypy.serving.request.toolmaps.get(self.namespace) + if map: + for name, settings in map.items(): + if settings.get("on", False): + tool = getattr(self, name) + tool._setup() + + +class DeprecatedTool(Tool): + + _name = None + warnmsg = "This Tool is deprecated." + + def __init__(self, point, warnmsg=None): + self.point = point + if warnmsg is not None: + self.warnmsg = warnmsg + + def __call__(self, *args, **kwargs): + warnings.warn(self.warnmsg) + + def tool_decorator(f): + return f + return tool_decorator + + def _setup(self): + warnings.warn(self.warnmsg) + + +default_toolbox = _d = Toolbox("tools") +_d.session_auth = SessionAuthTool(cptools.session_auth) +_d.allow = Tool('on_start_resource', cptools.allow) +_d.proxy = Tool('before_request_body', cptools.proxy, priority=30) +_d.response_headers = Tool('on_start_resource', cptools.response_headers) +_d.log_tracebacks = Tool('before_error_response', cptools.log_traceback) +_d.log_headers = Tool('before_error_response', cptools.log_request_headers) +_d.log_hooks = Tool('on_end_request', cptools.log_hooks, priority=100) +_d.err_redirect = ErrorTool(cptools.redirect) +_d.etags = Tool('before_finalize', cptools.validate_etags, priority=75) +_d.decode = Tool('before_request_body', encoding.decode) +# the order of encoding, gzip, caching is important +_d.encode = Tool('before_handler', encoding.ResponseEncoder, priority=70) +_d.gzip = Tool('before_finalize', encoding.gzip, priority=80) +_d.staticdir = HandlerTool(static.staticdir) +_d.staticfile = HandlerTool(static.staticfile) +_d.sessions = SessionTool() +_d.xmlrpc = ErrorTool(_xmlrpc.on_error) +_d.caching = CachingTool('before_handler', _caching.get, 'caching') +_d.expires = Tool('before_finalize', _caching.expires) +_d.tidy = DeprecatedTool( + 'before_finalize', + "The tidy tool has been removed from the standard distribution of " + "CherryPy. The most recent version can be found at " + "http://tools.cherrypy.org/browser.") +_d.nsgmls = DeprecatedTool( + 'before_finalize', + "The nsgmls tool has been removed from the standard distribution of " + "CherryPy. The most recent version can be found at " + "http://tools.cherrypy.org/browser.") +_d.ignore_headers = Tool('before_request_body', cptools.ignore_headers) +_d.referer = Tool('before_request_body', cptools.referer) +_d.basic_auth = Tool('on_start_resource', auth.basic_auth) +_d.digest_auth = Tool('on_start_resource', auth.digest_auth) +_d.trailing_slash = Tool('before_handler', cptools.trailing_slash, priority=60) +_d.flatten = Tool('before_finalize', cptools.flatten) +_d.accept = Tool('on_start_resource', cptools.accept) +_d.redirect = Tool('on_start_resource', cptools.redirect) +_d.autovary = Tool('on_start_resource', cptools.autovary, priority=0) +_d.json_in = Tool('before_request_body', jsontools.json_in, priority=30) +_d.json_out = Tool('before_handler', jsontools.json_out, priority=30) +_d.auth_basic = Tool('before_handler', auth_basic.basic_auth, priority=1) +_d.auth_digest = Tool('before_handler', auth_digest.digest_auth, priority=1) + +del _d, cptools, encoding, auth, static diff --git a/lib/cherrypy/_cptree.py b/lib/cherrypy/_cptree.py new file mode 100644 index 00000000..a31b2793 --- /dev/null +++ b/lib/cherrypy/_cptree.py @@ -0,0 +1,299 @@ +"""CherryPy Application and Tree objects.""" + +import os + +import cherrypy +from cherrypy._cpcompat import ntou, py3k +from cherrypy import _cpconfig, _cplogging, _cprequest, _cpwsgi, tools +from cherrypy.lib import httputil + + +class Application(object): + + """A CherryPy Application. + + Servers and gateways should not instantiate Request objects directly. + Instead, they should ask an Application object for a request object. + + An instance of this class may also be used as a WSGI callable + (WSGI application object) for itself. + """ + + root = None + """The top-most container of page handlers for this app. Handlers should + be arranged in a hierarchy of attributes, matching the expected URI + hierarchy; the default dispatcher then searches this hierarchy for a + matching handler. When using a dispatcher other than the default, + this value may be None.""" + + config = {} + """A dict of {path: pathconf} pairs, where 'pathconf' is itself a dict + of {key: value} pairs.""" + + namespaces = _cpconfig.NamespaceSet() + toolboxes = {'tools': cherrypy.tools} + + log = None + """A LogManager instance. See _cplogging.""" + + wsgiapp = None + """A CPWSGIApp instance. See _cpwsgi.""" + + request_class = _cprequest.Request + response_class = _cprequest.Response + + relative_urls = False + + def __init__(self, root, script_name="", config=None): + self.log = _cplogging.LogManager(id(self), cherrypy.log.logger_root) + self.root = root + self.script_name = script_name + self.wsgiapp = _cpwsgi.CPWSGIApp(self) + + self.namespaces = self.namespaces.copy() + self.namespaces["log"] = lambda k, v: setattr(self.log, k, v) + self.namespaces["wsgi"] = self.wsgiapp.namespace_handler + + self.config = self.__class__.config.copy() + if config: + self.merge(config) + + def __repr__(self): + return "%s.%s(%r, %r)" % (self.__module__, self.__class__.__name__, + self.root, self.script_name) + + script_name_doc = """The URI "mount point" for this app. A mount point + is that portion of the URI which is constant for all URIs that are + serviced by this application; it does not include scheme, host, or proxy + ("virtual host") portions of the URI. + + For example, if script_name is "/my/cool/app", then the URL + "http://www.example.com/my/cool/app/page1" might be handled by a + "page1" method on the root object. + + The value of script_name MUST NOT end in a slash. If the script_name + refers to the root of the URI, it MUST be an empty string (not "/"). + + If script_name is explicitly set to None, then the script_name will be + provided for each call from request.wsgi_environ['SCRIPT_NAME']. + """ + + def _get_script_name(self): + if self._script_name is not None: + return self._script_name + + # A `_script_name` with a value of None signals that the script name + # should be pulled from WSGI environ. + return cherrypy.serving.request.wsgi_environ['SCRIPT_NAME'].rstrip("/") + + def _set_script_name(self, value): + if value: + value = value.rstrip("/") + self._script_name = value + script_name = property(fget=_get_script_name, fset=_set_script_name, + doc=script_name_doc) + + def merge(self, config): + """Merge the given config into self.config.""" + _cpconfig.merge(self.config, config) + + # Handle namespaces specified in config. + self.namespaces(self.config.get("/", {})) + + def find_config(self, path, key, default=None): + """Return the most-specific value for key along path, or default.""" + trail = path or "/" + while trail: + nodeconf = self.config.get(trail, {}) + + if key in nodeconf: + return nodeconf[key] + + lastslash = trail.rfind("/") + if lastslash == -1: + break + elif lastslash == 0 and trail != "/": + trail = "/" + else: + trail = trail[:lastslash] + + return default + + def get_serving(self, local, remote, scheme, sproto): + """Create and return a Request and Response object.""" + req = self.request_class(local, remote, scheme, sproto) + req.app = self + + for name, toolbox in self.toolboxes.items(): + req.namespaces[name] = toolbox + + resp = self.response_class() + cherrypy.serving.load(req, resp) + cherrypy.engine.publish('acquire_thread') + cherrypy.engine.publish('before_request') + + return req, resp + + def release_serving(self): + """Release the current serving (request and response).""" + req = cherrypy.serving.request + + cherrypy.engine.publish('after_request') + + try: + req.close() + except: + cherrypy.log(traceback=True, severity=40) + + cherrypy.serving.clear() + + def __call__(self, environ, start_response): + return self.wsgiapp(environ, start_response) + + +class Tree(object): + + """A registry of CherryPy applications, mounted at diverse points. + + An instance of this class may also be used as a WSGI callable + (WSGI application object), in which case it dispatches to all + mounted apps. + """ + + apps = {} + """ + A dict of the form {script name: application}, where "script name" + is a string declaring the URI mount point (no trailing slash), and + "application" is an instance of cherrypy.Application (or an arbitrary + WSGI callable if you happen to be using a WSGI server).""" + + def __init__(self): + self.apps = {} + + def mount(self, root, script_name="", config=None): + """Mount a new app from a root object, script_name, and config. + + root + An instance of a "controller class" (a collection of page + handler methods) which represents the root of the application. + This may also be an Application instance, or None if using + a dispatcher other than the default. + + script_name + A string containing the "mount point" of the application. + This should start with a slash, and be the path portion of the + URL at which to mount the given root. For example, if root.index() + will handle requests to "http://www.example.com:8080/dept/app1/", + then the script_name argument would be "/dept/app1". + + It MUST NOT end in a slash. If the script_name refers to the + root of the URI, it MUST be an empty string (not "/"). + + config + A file or dict containing application config. + """ + if script_name is None: + raise TypeError( + "The 'script_name' argument may not be None. Application " + "objects may, however, possess a script_name of None (in " + "order to inpect the WSGI environ for SCRIPT_NAME upon each " + "request). You cannot mount such Applications on this Tree; " + "you must pass them to a WSGI server interface directly.") + + # Next line both 1) strips trailing slash and 2) maps "/" -> "". + script_name = script_name.rstrip("/") + + if isinstance(root, Application): + app = root + if script_name != "" and script_name != app.script_name: + raise ValueError( + "Cannot specify a different script name and pass an " + "Application instance to cherrypy.mount") + script_name = app.script_name + else: + app = Application(root, script_name) + + # If mounted at "", add favicon.ico + if (script_name == "" and root is not None + and not hasattr(root, "favicon_ico")): + favicon = os.path.join(os.getcwd(), os.path.dirname(__file__), + "favicon.ico") + root.favicon_ico = tools.staticfile.handler(favicon) + + if config: + app.merge(config) + + self.apps[script_name] = app + + return app + + def graft(self, wsgi_callable, script_name=""): + """Mount a wsgi callable at the given script_name.""" + # Next line both 1) strips trailing slash and 2) maps "/" -> "". + script_name = script_name.rstrip("/") + self.apps[script_name] = wsgi_callable + + def script_name(self, path=None): + """The script_name of the app at the given path, or None. + + If path is None, cherrypy.request is used. + """ + if path is None: + try: + request = cherrypy.serving.request + path = httputil.urljoin(request.script_name, + request.path_info) + except AttributeError: + return None + + while True: + if path in self.apps: + return path + + if path == "": + return None + + # Move one node up the tree and try again. + path = path[:path.rfind("/")] + + def __call__(self, environ, start_response): + # If you're calling this, then you're probably setting SCRIPT_NAME + # to '' (some WSGI servers always set SCRIPT_NAME to ''). + # Try to look up the app using the full path. + env1x = environ + if environ.get(ntou('wsgi.version')) == (ntou('u'), 0): + env1x = _cpwsgi.downgrade_wsgi_ux_to_1x(environ) + path = httputil.urljoin(env1x.get('SCRIPT_NAME', ''), + env1x.get('PATH_INFO', '')) + sn = self.script_name(path or "/") + if sn is None: + start_response('404 Not Found', []) + return [] + + app = self.apps[sn] + + # Correct the SCRIPT_NAME and PATH_INFO environ entries. + environ = environ.copy() + if not py3k: + if environ.get(ntou('wsgi.version')) == (ntou('u'), 0): + # Python 2/WSGI u.0: all strings MUST be of type unicode + enc = environ[ntou('wsgi.url_encoding')] + environ[ntou('SCRIPT_NAME')] = sn.decode(enc) + environ[ntou('PATH_INFO')] = path[ + len(sn.rstrip("/")):].decode(enc) + else: + # Python 2/WSGI 1.x: all strings MUST be of type str + environ['SCRIPT_NAME'] = sn + environ['PATH_INFO'] = path[len(sn.rstrip("/")):] + else: + if environ.get(ntou('wsgi.version')) == (ntou('u'), 0): + # Python 3/WSGI u.0: all strings MUST be full unicode + environ['SCRIPT_NAME'] = sn + environ['PATH_INFO'] = path[len(sn.rstrip("/")):] + else: + # Python 3/WSGI 1.x: all strings MUST be ISO-8859-1 str + environ['SCRIPT_NAME'] = sn.encode( + 'utf-8').decode('ISO-8859-1') + environ['PATH_INFO'] = path[ + len(sn.rstrip("/")):].encode('utf-8').decode('ISO-8859-1') + return app(environ, start_response) diff --git a/lib/cherrypy/_cpwsgi.py b/lib/cherrypy/_cpwsgi.py new file mode 100644 index 00000000..f6db68b0 --- /dev/null +++ b/lib/cherrypy/_cpwsgi.py @@ -0,0 +1,438 @@ +"""WSGI interface (see PEP 333 and 3333). + +Note that WSGI environ keys and values are 'native strings'; that is, +whatever the type of "" is. For Python 2, that's a byte string; for Python 3, +it's a unicode string. But PEP 3333 says: "even if Python's str type is +actually Unicode "under the hood", the content of native strings must +still be translatable to bytes via the Latin-1 encoding!" +""" + +import sys as _sys + +import cherrypy as _cherrypy +from cherrypy._cpcompat import BytesIO, bytestr, ntob, ntou, py3k, unicodestr +from cherrypy import _cperror +from cherrypy.lib import httputil +from cherrypy.lib import is_closable_iterator + +def downgrade_wsgi_ux_to_1x(environ): + """Return a new environ dict for WSGI 1.x from the given WSGI u.x environ. + """ + env1x = {} + + url_encoding = environ[ntou('wsgi.url_encoding')] + for k, v in list(environ.items()): + if k in [ntou('PATH_INFO'), ntou('SCRIPT_NAME'), ntou('QUERY_STRING')]: + v = v.encode(url_encoding) + elif isinstance(v, unicodestr): + v = v.encode('ISO-8859-1') + env1x[k.encode('ISO-8859-1')] = v + + return env1x + + +class VirtualHost(object): + + """Select a different WSGI application based on the Host header. + + This can be useful when running multiple sites within one CP server. + It allows several domains to point to different applications. For example:: + + root = Root() + RootApp = cherrypy.Application(root) + Domain2App = cherrypy.Application(root) + SecureApp = cherrypy.Application(Secure()) + + vhost = cherrypy._cpwsgi.VirtualHost(RootApp, + domains={'www.domain2.example': Domain2App, + 'www.domain2.example:443': SecureApp, + }) + + cherrypy.tree.graft(vhost) + """ + default = None + """Required. The default WSGI application.""" + + use_x_forwarded_host = True + """If True (the default), any "X-Forwarded-Host" + request header will be used instead of the "Host" header. This + is commonly added by HTTP servers (such as Apache) when proxying.""" + + domains = {} + """A dict of {host header value: application} pairs. + The incoming "Host" request header is looked up in this dict, + and, if a match is found, the corresponding WSGI application + will be called instead of the default. Note that you often need + separate entries for "example.com" and "www.example.com". + In addition, "Host" headers may contain the port number. + """ + + def __init__(self, default, domains=None, use_x_forwarded_host=True): + self.default = default + self.domains = domains or {} + self.use_x_forwarded_host = use_x_forwarded_host + + def __call__(self, environ, start_response): + domain = environ.get('HTTP_HOST', '') + if self.use_x_forwarded_host: + domain = environ.get("HTTP_X_FORWARDED_HOST", domain) + + nextapp = self.domains.get(domain) + if nextapp is None: + nextapp = self.default + return nextapp(environ, start_response) + + +class InternalRedirector(object): + + """WSGI middleware that handles raised cherrypy.InternalRedirect.""" + + def __init__(self, nextapp, recursive=False): + self.nextapp = nextapp + self.recursive = recursive + + def __call__(self, environ, start_response): + redirections = [] + while True: + environ = environ.copy() + try: + return self.nextapp(environ, start_response) + except _cherrypy.InternalRedirect: + ir = _sys.exc_info()[1] + sn = environ.get('SCRIPT_NAME', '') + path = environ.get('PATH_INFO', '') + qs = environ.get('QUERY_STRING', '') + + # Add the *previous* path_info + qs to redirections. + old_uri = sn + path + if qs: + old_uri += "?" + qs + redirections.append(old_uri) + + if not self.recursive: + # Check to see if the new URI has been redirected to + # already + new_uri = sn + ir.path + if ir.query_string: + new_uri += "?" + ir.query_string + if new_uri in redirections: + ir.request.close() + raise RuntimeError("InternalRedirector visited the " + "same URL twice: %r" % new_uri) + + # Munge the environment and try again. + environ['REQUEST_METHOD'] = "GET" + environ['PATH_INFO'] = ir.path + environ['QUERY_STRING'] = ir.query_string + environ['wsgi.input'] = BytesIO() + environ['CONTENT_LENGTH'] = "0" + environ['cherrypy.previous_request'] = ir.request + + +class ExceptionTrapper(object): + + """WSGI middleware that traps exceptions.""" + + def __init__(self, nextapp, throws=(KeyboardInterrupt, SystemExit)): + self.nextapp = nextapp + self.throws = throws + + def __call__(self, environ, start_response): + return _TrappedResponse( + self.nextapp, + environ, + start_response, + self.throws + ) + + +class _TrappedResponse(object): + + response = iter([]) + + def __init__(self, nextapp, environ, start_response, throws): + self.nextapp = nextapp + self.environ = environ + self.start_response = start_response + self.throws = throws + self.started_response = False + self.response = self.trap( + self.nextapp, self.environ, self.start_response) + self.iter_response = iter(self.response) + + def __iter__(self): + self.started_response = True + return self + + if py3k: + def __next__(self): + return self.trap(next, self.iter_response) + else: + def next(self): + return self.trap(self.iter_response.next) + + def close(self): + if hasattr(self.response, 'close'): + self.response.close() + + def trap(self, func, *args, **kwargs): + try: + return func(*args, **kwargs) + except self.throws: + raise + except StopIteration: + raise + except: + tb = _cperror.format_exc() + #print('trapped (started %s):' % self.started_response, tb) + _cherrypy.log(tb, severity=40) + if not _cherrypy.request.show_tracebacks: + tb = "" + s, h, b = _cperror.bare_error(tb) + if py3k: + # What fun. + s = s.decode('ISO-8859-1') + h = [(k.decode('ISO-8859-1'), v.decode('ISO-8859-1')) + for k, v in h] + if self.started_response: + # Empty our iterable (so future calls raise StopIteration) + self.iter_response = iter([]) + else: + self.iter_response = iter(b) + + try: + self.start_response(s, h, _sys.exc_info()) + except: + # "The application must not trap any exceptions raised by + # start_response, if it called start_response with exc_info. + # Instead, it should allow such exceptions to propagate + # back to the server or gateway." + # But we still log and call close() to clean up ourselves. + _cherrypy.log(traceback=True, severity=40) + raise + + if self.started_response: + return ntob("").join(b) + else: + return b + + +# WSGI-to-CP Adapter # + + +class AppResponse(object): + + """WSGI response iterable for CherryPy applications.""" + + def __init__(self, environ, start_response, cpapp): + self.cpapp = cpapp + try: + if not py3k: + if environ.get(ntou('wsgi.version')) == (ntou('u'), 0): + environ = downgrade_wsgi_ux_to_1x(environ) + self.environ = environ + self.run() + + r = _cherrypy.serving.response + + outstatus = r.output_status + if not isinstance(outstatus, bytestr): + raise TypeError("response.output_status is not a byte string.") + + outheaders = [] + for k, v in r.header_list: + if not isinstance(k, bytestr): + raise TypeError( + "response.header_list key %r is not a byte string." % + k) + if not isinstance(v, bytestr): + raise TypeError( + "response.header_list value %r is not a byte string." % + v) + outheaders.append((k, v)) + + if py3k: + # According to PEP 3333, when using Python 3, the response + # status and headers must be bytes masquerading as unicode; + # that is, they must be of type "str" but are restricted to + # code points in the "latin-1" set. + outstatus = outstatus.decode('ISO-8859-1') + outheaders = [(k.decode('ISO-8859-1'), v.decode('ISO-8859-1')) + for k, v in outheaders] + + self.iter_response = iter(r.body) + self.write = start_response(outstatus, outheaders) + except: + self.close() + raise + + def __iter__(self): + return self + + if py3k: + def __next__(self): + return next(self.iter_response) + else: + def next(self): + return self.iter_response.next() + + def close(self): + """Close and de-reference the current request and response. (Core)""" + streaming = _cherrypy.serving.response.stream + self.cpapp.release_serving() + + # We avoid the expense of examining the iterator to see if it's + # closable unless we are streaming the response, as that's the + # only situation where we are going to have an iterator which + # may not have been exhausted yet. + if streaming and is_closable_iterator(self.iter_response): + iter_close = self.iter_response.close + try: + iter_close() + except Exception: + _cherrypy.log(traceback=True, severity=40) + + def run(self): + """Create a Request object using environ.""" + env = self.environ.get + + local = httputil.Host('', int(env('SERVER_PORT', 80)), + env('SERVER_NAME', '')) + remote = httputil.Host(env('REMOTE_ADDR', ''), + int(env('REMOTE_PORT', -1) or -1), + env('REMOTE_HOST', '')) + scheme = env('wsgi.url_scheme') + sproto = env('ACTUAL_SERVER_PROTOCOL', "HTTP/1.1") + request, resp = self.cpapp.get_serving(local, remote, scheme, sproto) + + # LOGON_USER is served by IIS, and is the name of the + # user after having been mapped to a local account. + # Both IIS and Apache set REMOTE_USER, when possible. + request.login = env('LOGON_USER') or env('REMOTE_USER') or None + request.multithread = self.environ['wsgi.multithread'] + request.multiprocess = self.environ['wsgi.multiprocess'] + request.wsgi_environ = self.environ + request.prev = env('cherrypy.previous_request', None) + + meth = self.environ['REQUEST_METHOD'] + + path = httputil.urljoin(self.environ.get('SCRIPT_NAME', ''), + self.environ.get('PATH_INFO', '')) + qs = self.environ.get('QUERY_STRING', '') + + if py3k: + # This isn't perfect; if the given PATH_INFO is in the + # wrong encoding, it may fail to match the appropriate config + # section URI. But meh. + old_enc = self.environ.get('wsgi.url_encoding', 'ISO-8859-1') + new_enc = self.cpapp.find_config(self.environ.get('PATH_INFO', ''), + "request.uri_encoding", 'utf-8') + if new_enc.lower() != old_enc.lower(): + # Even though the path and qs are unicode, the WSGI server + # is required by PEP 3333 to coerce them to ISO-8859-1 + # masquerading as unicode. So we have to encode back to + # bytes and then decode again using the "correct" encoding. + try: + u_path = path.encode(old_enc).decode(new_enc) + u_qs = qs.encode(old_enc).decode(new_enc) + except (UnicodeEncodeError, UnicodeDecodeError): + # Just pass them through without transcoding and hope. + pass + else: + # Only set transcoded values if they both succeed. + path = u_path + qs = u_qs + + rproto = self.environ.get('SERVER_PROTOCOL') + headers = self.translate_headers(self.environ) + rfile = self.environ['wsgi.input'] + request.run(meth, path, qs, rproto, headers, rfile) + + headerNames = {'HTTP_CGI_AUTHORIZATION': 'Authorization', + 'CONTENT_LENGTH': 'Content-Length', + 'CONTENT_TYPE': 'Content-Type', + 'REMOTE_HOST': 'Remote-Host', + 'REMOTE_ADDR': 'Remote-Addr', + } + + def translate_headers(self, environ): + """Translate CGI-environ header names to HTTP header names.""" + for cgiName in environ: + # We assume all incoming header keys are uppercase already. + if cgiName in self.headerNames: + yield self.headerNames[cgiName], environ[cgiName] + elif cgiName[:5] == "HTTP_": + # Hackish attempt at recovering original header names. + translatedHeader = cgiName[5:].replace("_", "-") + yield translatedHeader, environ[cgiName] + + +class CPWSGIApp(object): + + """A WSGI application object for a CherryPy Application.""" + + pipeline = [('ExceptionTrapper', ExceptionTrapper), + ('InternalRedirector', InternalRedirector), + ] + """A list of (name, wsgiapp) pairs. Each 'wsgiapp' MUST be a + constructor that takes an initial, positional 'nextapp' argument, + plus optional keyword arguments, and returns a WSGI application + (that takes environ and start_response arguments). The 'name' can + be any you choose, and will correspond to keys in self.config.""" + + head = None + """Rather than nest all apps in the pipeline on each call, it's only + done the first time, and the result is memoized into self.head. Set + this to None again if you change self.pipeline after calling self.""" + + config = {} + """A dict whose keys match names listed in the pipeline. Each + value is a further dict which will be passed to the corresponding + named WSGI callable (from the pipeline) as keyword arguments.""" + + response_class = AppResponse + """The class to instantiate and return as the next app in the WSGI chain. + """ + + def __init__(self, cpapp, pipeline=None): + self.cpapp = cpapp + self.pipeline = self.pipeline[:] + if pipeline: + self.pipeline.extend(pipeline) + self.config = self.config.copy() + + def tail(self, environ, start_response): + """WSGI application callable for the actual CherryPy application. + + You probably shouldn't call this; call self.__call__ instead, + so that any WSGI middleware in self.pipeline can run first. + """ + return self.response_class(environ, start_response, self.cpapp) + + def __call__(self, environ, start_response): + head = self.head + if head is None: + # Create and nest the WSGI apps in our pipeline (in reverse order). + # Then memoize the result in self.head. + head = self.tail + for name, callable in self.pipeline[::-1]: + conf = self.config.get(name, {}) + head = callable(head, **conf) + self.head = head + return head(environ, start_response) + + def namespace_handler(self, k, v): + """Config handler for the 'wsgi' namespace.""" + if k == "pipeline": + # Note this allows multiple 'wsgi.pipeline' config entries + # (but each entry will be processed in a 'random' order). + # It should also allow developers to set default middleware + # in code (passed to self.__init__) that deployers can add to + # (but not remove) via config. + self.pipeline.extend(v) + elif k == "response_class": + self.response_class = v + else: + name, arg = k.split(".", 1) + bucket = self.config.setdefault(name, {}) + bucket[arg] = v diff --git a/lib/cherrypy/_cpwsgi_server.py b/lib/cherrypy/_cpwsgi_server.py new file mode 100644 index 00000000..874e2e9f --- /dev/null +++ b/lib/cherrypy/_cpwsgi_server.py @@ -0,0 +1,70 @@ +"""WSGI server interface (see PEP 333). This adds some CP-specific bits to +the framework-agnostic wsgiserver package. +""" +import sys + +import cherrypy +from cherrypy import wsgiserver + + +class CPWSGIServer(wsgiserver.CherryPyWSGIServer): + + """Wrapper for wsgiserver.CherryPyWSGIServer. + + wsgiserver has been designed to not reference CherryPy in any way, + so that it can be used in other frameworks and applications. Therefore, + we wrap it here, so we can set our own mount points from cherrypy.tree + and apply some attributes from config -> cherrypy.server -> wsgiserver. + """ + + def __init__(self, server_adapter=cherrypy.server): + self.server_adapter = server_adapter + self.max_request_header_size = ( + self.server_adapter.max_request_header_size or 0 + ) + self.max_request_body_size = ( + self.server_adapter.max_request_body_size or 0 + ) + + server_name = (self.server_adapter.socket_host or + self.server_adapter.socket_file or + None) + + self.wsgi_version = self.server_adapter.wsgi_version + s = wsgiserver.CherryPyWSGIServer + s.__init__(self, server_adapter.bind_addr, cherrypy.tree, + self.server_adapter.thread_pool, + server_name, + max=self.server_adapter.thread_pool_max, + request_queue_size=self.server_adapter.socket_queue_size, + timeout=self.server_adapter.socket_timeout, + shutdown_timeout=self.server_adapter.shutdown_timeout, + accepted_queue_size=self.server_adapter.accepted_queue_size, + accepted_queue_timeout=self.server_adapter.accepted_queue_timeout, + ) + self.protocol = self.server_adapter.protocol_version + self.nodelay = self.server_adapter.nodelay + + if sys.version_info >= (3, 0): + ssl_module = self.server_adapter.ssl_module or 'builtin' + else: + ssl_module = self.server_adapter.ssl_module or 'pyopenssl' + if self.server_adapter.ssl_context: + adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module) + self.ssl_adapter = adapter_class( + self.server_adapter.ssl_certificate, + self.server_adapter.ssl_private_key, + self.server_adapter.ssl_certificate_chain) + self.ssl_adapter.context = self.server_adapter.ssl_context + elif self.server_adapter.ssl_certificate: + adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module) + self.ssl_adapter = adapter_class( + self.server_adapter.ssl_certificate, + self.server_adapter.ssl_private_key, + self.server_adapter.ssl_certificate_chain) + + self.stats['Enabled'] = getattr( + self.server_adapter, 'statistics', False) + + def error_log(self, msg="", level=20, traceback=False): + cherrypy.engine.log(msg, level, traceback) diff --git a/lib/cherrypy/cherryd b/lib/cherrypy/cherryd new file mode 100644 index 00000000..5afb27ad --- /dev/null +++ b/lib/cherrypy/cherryd @@ -0,0 +1,110 @@ +#! /usr/bin/env python +"""The CherryPy daemon.""" + +import sys + +import cherrypy +from cherrypy.process import plugins, servers +from cherrypy import Application + + +def start(configfiles=None, daemonize=False, environment=None, + fastcgi=False, scgi=False, pidfile=None, imports=None, + cgi=False): + """Subscribe all engine plugins and start the engine.""" + sys.path = [''] + sys.path + for i in imports or []: + exec("import %s" % i) + + for c in configfiles or []: + cherrypy.config.update(c) + # If there's only one app mounted, merge config into it. + if len(cherrypy.tree.apps) == 1: + for app in cherrypy.tree.apps.values(): + if isinstance(app, Application): + app.merge(c) + + engine = cherrypy.engine + + if environment is not None: + cherrypy.config.update({'environment': environment}) + + # Only daemonize if asked to. + if daemonize: + # Don't print anything to stdout/sterr. + cherrypy.config.update({'log.screen': False}) + plugins.Daemonizer(engine).subscribe() + + if pidfile: + plugins.PIDFile(engine, pidfile).subscribe() + + if hasattr(engine, "signal_handler"): + engine.signal_handler.subscribe() + if hasattr(engine, "console_control_handler"): + engine.console_control_handler.subscribe() + + if (fastcgi and (scgi or cgi)) or (scgi and cgi): + cherrypy.log.error("You may only specify one of the cgi, fastcgi, and " + "scgi options.", 'ENGINE') + sys.exit(1) + elif fastcgi or scgi or cgi: + # Turn off autoreload when using *cgi. + cherrypy.config.update({'engine.autoreload_on': False}) + # Turn off the default HTTP server (which is subscribed by default). + cherrypy.server.unsubscribe() + + addr = cherrypy.server.bind_addr + if fastcgi: + f = servers.FlupFCGIServer(application=cherrypy.tree, + bindAddress=addr) + elif scgi: + f = servers.FlupSCGIServer(application=cherrypy.tree, + bindAddress=addr) + else: + f = servers.FlupCGIServer(application=cherrypy.tree, + bindAddress=addr) + s = servers.ServerAdapter(engine, httpserver=f, bind_addr=addr) + s.subscribe() + + # Always start the engine; this will start all other services + try: + engine.start() + except: + # Assume the error has been logged already via bus.log. + sys.exit(1) + else: + engine.block() + + +if __name__ == '__main__': + from optparse import OptionParser + + p = OptionParser() + p.add_option('-c', '--config', action="append", dest='config', + help="specify config file(s)") + p.add_option('-d', action="store_true", dest='daemonize', + help="run the server as a daemon") + p.add_option('-e', '--environment', dest='environment', default=None, + help="apply the given config environment") + p.add_option('-f', action="store_true", dest='fastcgi', + help="start a fastcgi server instead of the default HTTP " + "server") + p.add_option('-s', action="store_true", dest='scgi', + help="start a scgi server instead of the default HTTP server") + p.add_option('-x', action="store_true", dest='cgi', + help="start a cgi server instead of the default HTTP server") + p.add_option('-i', '--import', action="append", dest='imports', + help="specify modules to import") + p.add_option('-p', '--pidfile', dest='pidfile', default=None, + help="store the process id in the given file") + p.add_option('-P', '--Path', action="append", dest='Path', + help="add the given paths to sys.path") + options, args = p.parse_args() + + if options.Path: + for p in options.Path: + sys.path.insert(0, p) + + start(options.config, options.daemonize, + options.environment, options.fastcgi, options.scgi, + options.pidfile, options.imports, options.cgi) diff --git a/lib/cherrypy/favicon.ico b/lib/cherrypy/favicon.ico new file mode 100644 index 00000000..f0d7e61b Binary files /dev/null and b/lib/cherrypy/favicon.ico differ diff --git a/lib/cherrypy/lib/__init__.py b/lib/cherrypy/lib/__init__.py new file mode 100644 index 00000000..a75a53da --- /dev/null +++ b/lib/cherrypy/lib/__init__.py @@ -0,0 +1,85 @@ +"""CherryPy Library""" + +# Deprecated in CherryPy 3.2 -- remove in CherryPy 3.3 +from cherrypy.lib.reprconf import unrepr, modules, attributes + +def is_iterator(obj): + '''Returns a boolean indicating if the object provided implements + the iterator protocol (i.e. like a generator). This will return + false for objects which iterable, but not iterators themselves.''' + from types import GeneratorType + if isinstance(obj, GeneratorType): + return True + elif not hasattr(obj, '__iter__'): + return False + else: + # Types which implement the protocol must return themselves when + # invoking 'iter' upon them. + return iter(obj) is obj + +def is_closable_iterator(obj): + + # Not an iterator. + if not is_iterator(obj): + return False + + # A generator - the easiest thing to deal with. + import inspect + if inspect.isgenerator(obj): + return True + + # A custom iterator. Look for a close method... + if not (hasattr(obj, 'close') and callable(obj.close)): + return False + + # ... which doesn't require any arguments. + try: + inspect.getcallargs(obj.close) + except TypeError: + return False + else: + return True + +class file_generator(object): + + """Yield the given input (a file object) in chunks (default 64k). (Core)""" + + def __init__(self, input, chunkSize=65536): + self.input = input + self.chunkSize = chunkSize + + def __iter__(self): + return self + + def __next__(self): + chunk = self.input.read(self.chunkSize) + if chunk: + return chunk + else: + if hasattr(self.input, 'close'): + self.input.close() + raise StopIteration() + next = __next__ + + +def file_generator_limited(fileobj, count, chunk_size=65536): + """Yield the given file object in chunks, stopping after `count` + bytes has been emitted. Default chunk size is 64kB. (Core) + """ + remaining = count + while remaining > 0: + chunk = fileobj.read(min(chunk_size, remaining)) + chunklen = len(chunk) + if chunklen == 0: + return + remaining -= chunklen + yield chunk + + +def set_vary_header(response, header_name): + "Add a Vary header to a response" + varies = response.headers.get("Vary", "") + varies = [x.strip() for x in varies.split(",") if x.strip()] + if header_name not in varies: + varies.append(header_name) + response.headers['Vary'] = ", ".join(varies) diff --git a/lib/cherrypy/lib/auth.py b/lib/cherrypy/lib/auth.py new file mode 100644 index 00000000..71591aaa --- /dev/null +++ b/lib/cherrypy/lib/auth.py @@ -0,0 +1,97 @@ +import cherrypy +from cherrypy.lib import httpauth + + +def check_auth(users, encrypt=None, realm=None): + """If an authorization header contains credentials, return True or False. + """ + request = cherrypy.serving.request + if 'authorization' in request.headers: + # make sure the provided credentials are correctly set + ah = httpauth.parseAuthorization(request.headers['authorization']) + if ah is None: + raise cherrypy.HTTPError(400, 'Bad Request') + + if not encrypt: + encrypt = httpauth.DIGEST_AUTH_ENCODERS[httpauth.MD5] + + if hasattr(users, '__call__'): + try: + # backward compatibility + users = users() # expect it to return a dictionary + + if not isinstance(users, dict): + raise ValueError( + "Authentication users must be a dictionary") + + # fetch the user password + password = users.get(ah["username"], None) + except TypeError: + # returns a password (encrypted or clear text) + password = users(ah["username"]) + else: + if not isinstance(users, dict): + raise ValueError("Authentication users must be a dictionary") + + # fetch the user password + password = users.get(ah["username"], None) + + # validate the authorization by re-computing it here + # and compare it with what the user-agent provided + if httpauth.checkResponse(ah, password, method=request.method, + encrypt=encrypt, realm=realm): + request.login = ah["username"] + return True + + request.login = False + return False + + +def basic_auth(realm, users, encrypt=None, debug=False): + """If auth fails, raise 401 with a basic authentication header. + + realm + A string containing the authentication realm. + + users + A dict of the form: {username: password} or a callable returning + a dict. + + encrypt + callable used to encrypt the password returned from the user-agent. + if None it defaults to a md5 encryption. + + """ + if check_auth(users, encrypt): + if debug: + cherrypy.log('Auth successful', 'TOOLS.BASIC_AUTH') + return + + # inform the user-agent this path is protected + cherrypy.serving.response.headers[ + 'www-authenticate'] = httpauth.basicAuth(realm) + + raise cherrypy.HTTPError( + 401, "You are not authorized to access that resource") + + +def digest_auth(realm, users, debug=False): + """If auth fails, raise 401 with a digest authentication header. + + realm + A string containing the authentication realm. + users + A dict of the form: {username: password} or a callable returning + a dict. + """ + if check_auth(users, realm=realm): + if debug: + cherrypy.log('Auth successful', 'TOOLS.DIGEST_AUTH') + return + + # inform the user-agent this path is protected + cherrypy.serving.response.headers[ + 'www-authenticate'] = httpauth.digestAuth(realm) + + raise cherrypy.HTTPError( + 401, "You are not authorized to access that resource") diff --git a/lib/cherrypy/lib/auth_basic.py b/lib/cherrypy/lib/auth_basic.py new file mode 100644 index 00000000..5ba16f7f --- /dev/null +++ b/lib/cherrypy/lib/auth_basic.py @@ -0,0 +1,90 @@ +# This file is part of CherryPy +# -*- coding: utf-8 -*- +# vim:ts=4:sw=4:expandtab:fileencoding=utf-8 + +__doc__ = """This module provides a CherryPy 3.x tool which implements +the server-side of HTTP Basic Access Authentication, as described in +:rfc:`2617`. + +Example usage, using the built-in checkpassword_dict function which uses a dict +as the credentials store:: + + userpassdict = {'bird' : 'bebop', 'ornette' : 'wayout'} + checkpassword = cherrypy.lib.auth_basic.checkpassword_dict(userpassdict) + basic_auth = {'tools.auth_basic.on': True, + 'tools.auth_basic.realm': 'earth', + 'tools.auth_basic.checkpassword': checkpassword, + } + app_config = { '/' : basic_auth } + +""" + +__author__ = 'visteya' +__date__ = 'April 2009' + +import binascii +from cherrypy._cpcompat import base64_decode +import cherrypy + + +def checkpassword_dict(user_password_dict): + """Returns a checkpassword function which checks credentials + against a dictionary of the form: {username : password}. + + If you want a simple dictionary-based authentication scheme, use + checkpassword_dict(my_credentials_dict) as the value for the + checkpassword argument to basic_auth(). + """ + def checkpassword(realm, user, password): + p = user_password_dict.get(user) + return p and p == password or False + + return checkpassword + + +def basic_auth(realm, checkpassword, debug=False): + """A CherryPy tool which hooks at before_handler to perform + HTTP Basic Access Authentication, as specified in :rfc:`2617`. + + If the request has an 'authorization' header with a 'Basic' scheme, this + tool attempts to authenticate the credentials supplied in that header. If + the request has no 'authorization' header, or if it does but the scheme is + not 'Basic', or if authentication fails, the tool sends a 401 response with + a 'WWW-Authenticate' Basic header. + + realm + A string containing the authentication realm. + + checkpassword + A callable which checks the authentication credentials. + Its signature is checkpassword(realm, username, password). where + username and password are the values obtained from the request's + 'authorization' header. If authentication succeeds, checkpassword + returns True, else it returns False. + + """ + + if '"' in realm: + raise ValueError('Realm cannot contain the " (quote) character.') + request = cherrypy.serving.request + + auth_header = request.headers.get('authorization') + if auth_header is not None: + try: + scheme, params = auth_header.split(' ', 1) + if scheme.lower() == 'basic': + username, password = base64_decode(params).split(':', 1) + if checkpassword(realm, username, password): + if debug: + cherrypy.log('Auth succeeded', 'TOOLS.AUTH_BASIC') + request.login = username + return # successful authentication + # split() error, base64.decodestring() error + except (ValueError, binascii.Error): + raise cherrypy.HTTPError(400, 'Bad Request') + + # Respond with 401 status and a WWW-Authenticate header + cherrypy.serving.response.headers[ + 'www-authenticate'] = 'Basic realm="%s"' % realm + raise cherrypy.HTTPError( + 401, "You are not authorized to access that resource") diff --git a/lib/cherrypy/lib/auth_digest.py b/lib/cherrypy/lib/auth_digest.py new file mode 100644 index 00000000..e06535dc --- /dev/null +++ b/lib/cherrypy/lib/auth_digest.py @@ -0,0 +1,390 @@ +# This file is part of CherryPy +# -*- coding: utf-8 -*- +# vim:ts=4:sw=4:expandtab:fileencoding=utf-8 + +__doc__ = """An implementation of the server-side of HTTP Digest Access +Authentication, which is described in :rfc:`2617`. + +Example usage, using the built-in get_ha1_dict_plain function which uses a dict +of plaintext passwords as the credentials store:: + + userpassdict = {'alice' : '4x5istwelve'} + get_ha1 = cherrypy.lib.auth_digest.get_ha1_dict_plain(userpassdict) + digest_auth = {'tools.auth_digest.on': True, + 'tools.auth_digest.realm': 'wonderland', + 'tools.auth_digest.get_ha1': get_ha1, + 'tools.auth_digest.key': 'a565c27146791cfb', + } + app_config = { '/' : digest_auth } +""" + +__author__ = 'visteya' +__date__ = 'April 2009' + + +import time +from cherrypy._cpcompat import parse_http_list, parse_keqv_list + +import cherrypy +from cherrypy._cpcompat import md5, ntob +md5_hex = lambda s: md5(ntob(s)).hexdigest() + +qop_auth = 'auth' +qop_auth_int = 'auth-int' +valid_qops = (qop_auth, qop_auth_int) + +valid_algorithms = ('MD5', 'MD5-sess') + + +def TRACE(msg): + cherrypy.log(msg, context='TOOLS.AUTH_DIGEST') + +# Three helper functions for users of the tool, providing three variants +# of get_ha1() functions for three different kinds of credential stores. + + +def get_ha1_dict_plain(user_password_dict): + """Returns a get_ha1 function which obtains a plaintext password from a + dictionary of the form: {username : password}. + + If you want a simple dictionary-based authentication scheme, with plaintext + passwords, use get_ha1_dict_plain(my_userpass_dict) as the value for the + get_ha1 argument to digest_auth(). + """ + def get_ha1(realm, username): + password = user_password_dict.get(username) + if password: + return md5_hex('%s:%s:%s' % (username, realm, password)) + return None + + return get_ha1 + + +def get_ha1_dict(user_ha1_dict): + """Returns a get_ha1 function which obtains a HA1 password hash from a + dictionary of the form: {username : HA1}. + + If you want a dictionary-based authentication scheme, but with + pre-computed HA1 hashes instead of plain-text passwords, use + get_ha1_dict(my_userha1_dict) as the value for the get_ha1 + argument to digest_auth(). + """ + def get_ha1(realm, username): + return user_ha1_dict.get(username) + + return get_ha1 + + +def get_ha1_file_htdigest(filename): + """Returns a get_ha1 function which obtains a HA1 password hash from a + flat file with lines of the same format as that produced by the Apache + htdigest utility. For example, for realm 'wonderland', username 'alice', + and password '4x5istwelve', the htdigest line would be:: + + alice:wonderland:3238cdfe91a8b2ed8e39646921a02d4c + + If you want to use an Apache htdigest file as the credentials store, + then use get_ha1_file_htdigest(my_htdigest_file) as the value for the + get_ha1 argument to digest_auth(). It is recommended that the filename + argument be an absolute path, to avoid problems. + """ + def get_ha1(realm, username): + result = None + f = open(filename, 'r') + for line in f: + u, r, ha1 = line.rstrip().split(':') + if u == username and r == realm: + result = ha1 + break + f.close() + return result + + return get_ha1 + + +def synthesize_nonce(s, key, timestamp=None): + """Synthesize a nonce value which resists spoofing and can be checked + for staleness. Returns a string suitable as the value for 'nonce' in + the www-authenticate header. + + s + A string related to the resource, such as the hostname of the server. + + key + A secret string known only to the server. + + timestamp + An integer seconds-since-the-epoch timestamp + + """ + if timestamp is None: + timestamp = int(time.time()) + h = md5_hex('%s:%s:%s' % (timestamp, s, key)) + nonce = '%s:%s' % (timestamp, h) + return nonce + + +def H(s): + """The hash function H""" + return md5_hex(s) + + +class HttpDigestAuthorization (object): + + """Class to parse a Digest Authorization header and perform re-calculation + of the digest. + """ + + def errmsg(self, s): + return 'Digest Authorization header: %s' % s + + def __init__(self, auth_header, http_method, debug=False): + self.http_method = http_method + self.debug = debug + scheme, params = auth_header.split(" ", 1) + self.scheme = scheme.lower() + if self.scheme != 'digest': + raise ValueError('Authorization scheme is not "Digest"') + + self.auth_header = auth_header + + # make a dict of the params + items = parse_http_list(params) + paramsd = parse_keqv_list(items) + + self.realm = paramsd.get('realm') + self.username = paramsd.get('username') + self.nonce = paramsd.get('nonce') + self.uri = paramsd.get('uri') + self.method = paramsd.get('method') + self.response = paramsd.get('response') # the response digest + self.algorithm = paramsd.get('algorithm', 'MD5').upper() + self.cnonce = paramsd.get('cnonce') + self.opaque = paramsd.get('opaque') + self.qop = paramsd.get('qop') # qop + self.nc = paramsd.get('nc') # nonce count + + # perform some correctness checks + if self.algorithm not in valid_algorithms: + raise ValueError( + self.errmsg("Unsupported value for algorithm: '%s'" % + self.algorithm)) + + has_reqd = ( + self.username and + self.realm and + self.nonce and + self.uri and + self.response + ) + if not has_reqd: + raise ValueError( + self.errmsg("Not all required parameters are present.")) + + if self.qop: + if self.qop not in valid_qops: + raise ValueError( + self.errmsg("Unsupported value for qop: '%s'" % self.qop)) + if not (self.cnonce and self.nc): + raise ValueError( + self.errmsg("If qop is sent then " + "cnonce and nc MUST be present")) + else: + if self.cnonce or self.nc: + raise ValueError( + self.errmsg("If qop is not sent, " + "neither cnonce nor nc can be present")) + + def __str__(self): + return 'authorization : %s' % self.auth_header + + def validate_nonce(self, s, key): + """Validate the nonce. + Returns True if nonce was generated by synthesize_nonce() and the + timestamp is not spoofed, else returns False. + + s + A string related to the resource, such as the hostname of + the server. + + key + A secret string known only to the server. + + Both s and key must be the same values which were used to synthesize + the nonce we are trying to validate. + """ + try: + timestamp, hashpart = self.nonce.split(':', 1) + s_timestamp, s_hashpart = synthesize_nonce( + s, key, timestamp).split(':', 1) + is_valid = s_hashpart == hashpart + if self.debug: + TRACE('validate_nonce: %s' % is_valid) + return is_valid + except ValueError: # split() error + pass + return False + + def is_nonce_stale(self, max_age_seconds=600): + """Returns True if a validated nonce is stale. The nonce contains a + timestamp in plaintext and also a secure hash of the timestamp. + You should first validate the nonce to ensure the plaintext + timestamp is not spoofed. + """ + try: + timestamp, hashpart = self.nonce.split(':', 1) + if int(timestamp) + max_age_seconds > int(time.time()): + return False + except ValueError: # int() error + pass + if self.debug: + TRACE("nonce is stale") + return True + + def HA2(self, entity_body=''): + """Returns the H(A2) string. See :rfc:`2617` section 3.2.2.3.""" + # RFC 2617 3.2.2.3 + # If the "qop" directive's value is "auth" or is unspecified, + # then A2 is: + # A2 = method ":" digest-uri-value + # + # If the "qop" value is "auth-int", then A2 is: + # A2 = method ":" digest-uri-value ":" H(entity-body) + if self.qop is None or self.qop == "auth": + a2 = '%s:%s' % (self.http_method, self.uri) + elif self.qop == "auth-int": + a2 = "%s:%s:%s" % (self.http_method, self.uri, H(entity_body)) + else: + # in theory, this should never happen, since I validate qop in + # __init__() + raise ValueError(self.errmsg("Unrecognized value for qop!")) + return H(a2) + + def request_digest(self, ha1, entity_body=''): + """Calculates the Request-Digest. See :rfc:`2617` section 3.2.2.1. + + ha1 + The HA1 string obtained from the credentials store. + + entity_body + If 'qop' is set to 'auth-int', then A2 includes a hash + of the "entity body". The entity body is the part of the + message which follows the HTTP headers. See :rfc:`2617` section + 4.3. This refers to the entity the user agent sent in the + request which has the Authorization header. Typically GET + requests don't have an entity, and POST requests do. + + """ + ha2 = self.HA2(entity_body) + # Request-Digest -- RFC 2617 3.2.2.1 + if self.qop: + req = "%s:%s:%s:%s:%s" % ( + self.nonce, self.nc, self.cnonce, self.qop, ha2) + else: + req = "%s:%s" % (self.nonce, ha2) + + # RFC 2617 3.2.2.2 + # + # If the "algorithm" directive's value is "MD5" or is unspecified, + # then A1 is: + # A1 = unq(username-value) ":" unq(realm-value) ":" passwd + # + # If the "algorithm" directive's value is "MD5-sess", then A1 is + # calculated only once - on the first request by the client following + # receipt of a WWW-Authenticate challenge from the server. + # A1 = H( unq(username-value) ":" unq(realm-value) ":" passwd ) + # ":" unq(nonce-value) ":" unq(cnonce-value) + if self.algorithm == 'MD5-sess': + ha1 = H('%s:%s:%s' % (ha1, self.nonce, self.cnonce)) + + digest = H('%s:%s' % (ha1, req)) + return digest + + +def www_authenticate(realm, key, algorithm='MD5', nonce=None, qop=qop_auth, + stale=False): + """Constructs a WWW-Authenticate header for Digest authentication.""" + if qop not in valid_qops: + raise ValueError("Unsupported value for qop: '%s'" % qop) + if algorithm not in valid_algorithms: + raise ValueError("Unsupported value for algorithm: '%s'" % algorithm) + + if nonce is None: + nonce = synthesize_nonce(realm, key) + s = 'Digest realm="%s", nonce="%s", algorithm="%s", qop="%s"' % ( + realm, nonce, algorithm, qop) + if stale: + s += ', stale="true"' + return s + + +def digest_auth(realm, get_ha1, key, debug=False): + """A CherryPy tool which hooks at before_handler to perform + HTTP Digest Access Authentication, as specified in :rfc:`2617`. + + If the request has an 'authorization' header with a 'Digest' scheme, + this tool authenticates the credentials supplied in that header. + If the request has no 'authorization' header, or if it does but the + scheme is not "Digest", or if authentication fails, the tool sends + a 401 response with a 'WWW-Authenticate' Digest header. + + realm + A string containing the authentication realm. + + get_ha1 + A callable which looks up a username in a credentials store + and returns the HA1 string, which is defined in the RFC to be + MD5(username : realm : password). The function's signature is: + ``get_ha1(realm, username)`` + where username is obtained from the request's 'authorization' header. + If username is not found in the credentials store, get_ha1() returns + None. + + key + A secret string known only to the server, used in the synthesis + of nonces. + + """ + request = cherrypy.serving.request + + auth_header = request.headers.get('authorization') + nonce_is_stale = False + if auth_header is not None: + try: + auth = HttpDigestAuthorization( + auth_header, request.method, debug=debug) + except ValueError: + raise cherrypy.HTTPError( + 400, "The Authorization header could not be parsed.") + + if debug: + TRACE(str(auth)) + + if auth.validate_nonce(realm, key): + ha1 = get_ha1(realm, auth.username) + if ha1 is not None: + # note that for request.body to be available we need to + # hook in at before_handler, not on_start_resource like + # 3.1.x digest_auth does. + digest = auth.request_digest(ha1, entity_body=request.body) + if digest == auth.response: # authenticated + if debug: + TRACE("digest matches auth.response") + # Now check if nonce is stale. + # The choice of ten minutes' lifetime for nonce is somewhat + # arbitrary + nonce_is_stale = auth.is_nonce_stale(max_age_seconds=600) + if not nonce_is_stale: + request.login = auth.username + if debug: + TRACE("authentication of %s successful" % + auth.username) + return + + # Respond with 401 status and a WWW-Authenticate header + header = www_authenticate(realm, key, stale=nonce_is_stale) + if debug: + TRACE(header) + cherrypy.serving.response.headers['WWW-Authenticate'] = header + raise cherrypy.HTTPError( + 401, "You are not authorized to access that resource") diff --git a/lib/cherrypy/lib/caching.py b/lib/cherrypy/lib/caching.py new file mode 100644 index 00000000..fab6b569 --- /dev/null +++ b/lib/cherrypy/lib/caching.py @@ -0,0 +1,470 @@ +""" +CherryPy implements a simple caching system as a pluggable Tool. This tool +tries to be an (in-process) HTTP/1.1-compliant cache. It's not quite there +yet, but it's probably good enough for most sites. + +In general, GET responses are cached (along with selecting headers) and, if +another request arrives for the same resource, the caching Tool will return 304 +Not Modified if possible, or serve the cached response otherwise. It also sets +request.cached to True if serving a cached representation, and sets +request.cacheable to False (so it doesn't get cached again). + +If POST, PUT, or DELETE requests are made for a cached resource, they +invalidate (delete) any cached response. + +Usage +===== + +Configuration file example:: + + [/] + tools.caching.on = True + tools.caching.delay = 3600 + +You may use a class other than the default +:class:`MemoryCache` by supplying the config +entry ``cache_class``; supply the full dotted name of the replacement class +as the config value. It must implement the basic methods ``get``, ``put``, +``delete``, and ``clear``. + +You may set any attribute, including overriding methods, on the cache +instance by providing them in config. The above sets the +:attr:`delay` attribute, for example. +""" + +import datetime +import sys +import threading +import time + +import cherrypy +from cherrypy.lib import cptools, httputil +from cherrypy._cpcompat import copyitems, ntob, set_daemon, sorted, Event + + +class Cache(object): + + """Base class for Cache implementations.""" + + def get(self): + """Return the current variant if in the cache, else None.""" + raise NotImplemented + + def put(self, obj, size): + """Store the current variant in the cache.""" + raise NotImplemented + + def delete(self): + """Remove ALL cached variants of the current resource.""" + raise NotImplemented + + def clear(self): + """Reset the cache to its initial, empty state.""" + raise NotImplemented + + +# ------------------------------ Memory Cache ------------------------------- # +class AntiStampedeCache(dict): + + """A storage system for cached items which reduces stampede collisions.""" + + def wait(self, key, timeout=5, debug=False): + """Return the cached value for the given key, or None. + + If timeout is not None, and the value is already + being calculated by another thread, wait until the given timeout has + elapsed. If the value is available before the timeout expires, it is + returned. If not, None is returned, and a sentinel placed in the cache + to signal other threads to wait. + + If timeout is None, no waiting is performed nor sentinels used. + """ + value = self.get(key) + if isinstance(value, Event): + if timeout is None: + # Ignore the other thread and recalc it ourselves. + if debug: + cherrypy.log('No timeout', 'TOOLS.CACHING') + return None + + # Wait until it's done or times out. + if debug: + cherrypy.log('Waiting up to %s seconds' % + timeout, 'TOOLS.CACHING') + value.wait(timeout) + if value.result is not None: + # The other thread finished its calculation. Use it. + if debug: + cherrypy.log('Result!', 'TOOLS.CACHING') + return value.result + # Timed out. Stick an Event in the slot so other threads wait + # on this one to finish calculating the value. + if debug: + cherrypy.log('Timed out', 'TOOLS.CACHING') + e = threading.Event() + e.result = None + dict.__setitem__(self, key, e) + + return None + elif value is None: + # Stick an Event in the slot so other threads wait + # on this one to finish calculating the value. + if debug: + cherrypy.log('Timed out', 'TOOLS.CACHING') + e = threading.Event() + e.result = None + dict.__setitem__(self, key, e) + return value + + def __setitem__(self, key, value): + """Set the cached value for the given key.""" + existing = self.get(key) + dict.__setitem__(self, key, value) + if isinstance(existing, Event): + # Set Event.result so other threads waiting on it have + # immediate access without needing to poll the cache again. + existing.result = value + existing.set() + + +class MemoryCache(Cache): + + """An in-memory cache for varying response content. + + Each key in self.store is a URI, and each value is an AntiStampedeCache. + The response for any given URI may vary based on the values of + "selecting request headers"; that is, those named in the Vary + response header. We assume the list of header names to be constant + for each URI throughout the lifetime of the application, and store + that list in ``self.store[uri].selecting_headers``. + + The items contained in ``self.store[uri]`` have keys which are tuples of + request header values (in the same order as the names in its + selecting_headers), and values which are the actual responses. + """ + + maxobjects = 1000 + """The maximum number of cached objects; defaults to 1000.""" + + maxobj_size = 100000 + """The maximum size of each cached object in bytes; defaults to 100 KB.""" + + maxsize = 10000000 + """The maximum size of the entire cache in bytes; defaults to 10 MB.""" + + delay = 600 + """Seconds until the cached content expires; defaults to 600 (10 minutes). + """ + + antistampede_timeout = 5 + """Seconds to wait for other threads to release a cache lock.""" + + expire_freq = 0.1 + """Seconds to sleep between cache expiration sweeps.""" + + debug = False + + def __init__(self): + self.clear() + + # Run self.expire_cache in a separate daemon thread. + t = threading.Thread(target=self.expire_cache, name='expire_cache') + self.expiration_thread = t + set_daemon(t, True) + t.start() + + def clear(self): + """Reset the cache to its initial, empty state.""" + self.store = {} + self.expirations = {} + self.tot_puts = 0 + self.tot_gets = 0 + self.tot_hist = 0 + self.tot_expires = 0 + self.tot_non_modified = 0 + self.cursize = 0 + + def expire_cache(self): + """Continuously examine cached objects, expiring stale ones. + + This function is designed to be run in its own daemon thread, + referenced at ``self.expiration_thread``. + """ + # It's possible that "time" will be set to None + # arbitrarily, so we check "while time" to avoid exceptions. + # See tickets #99 and #180 for more information. + while time: + now = time.time() + # Must make a copy of expirations so it doesn't change size + # during iteration + for expiration_time, objects in copyitems(self.expirations): + if expiration_time <= now: + for obj_size, uri, sel_header_values in objects: + try: + del self.store[uri][tuple(sel_header_values)] + self.tot_expires += 1 + self.cursize -= obj_size + except KeyError: + # the key may have been deleted elsewhere + pass + del self.expirations[expiration_time] + time.sleep(self.expire_freq) + + def get(self): + """Return the current variant if in the cache, else None.""" + request = cherrypy.serving.request + self.tot_gets += 1 + + uri = cherrypy.url(qs=request.query_string) + uricache = self.store.get(uri) + if uricache is None: + return None + + header_values = [request.headers.get(h, '') + for h in uricache.selecting_headers] + variant = uricache.wait(key=tuple(sorted(header_values)), + timeout=self.antistampede_timeout, + debug=self.debug) + if variant is not None: + self.tot_hist += 1 + return variant + + def put(self, variant, size): + """Store the current variant in the cache.""" + request = cherrypy.serving.request + response = cherrypy.serving.response + + uri = cherrypy.url(qs=request.query_string) + uricache = self.store.get(uri) + if uricache is None: + uricache = AntiStampedeCache() + uricache.selecting_headers = [ + e.value for e in response.headers.elements('Vary')] + self.store[uri] = uricache + + if len(self.store) < self.maxobjects: + total_size = self.cursize + size + + # checks if there's space for the object + if (size < self.maxobj_size and total_size < self.maxsize): + # add to the expirations list + expiration_time = response.time + self.delay + bucket = self.expirations.setdefault(expiration_time, []) + bucket.append((size, uri, uricache.selecting_headers)) + + # add to the cache + header_values = [request.headers.get(h, '') + for h in uricache.selecting_headers] + uricache[tuple(sorted(header_values))] = variant + self.tot_puts += 1 + self.cursize = total_size + + def delete(self): + """Remove ALL cached variants of the current resource.""" + uri = cherrypy.url(qs=cherrypy.serving.request.query_string) + self.store.pop(uri, None) + + +def get(invalid_methods=("POST", "PUT", "DELETE"), debug=False, **kwargs): + """Try to obtain cached output. If fresh enough, raise HTTPError(304). + + If POST, PUT, or DELETE: + * invalidates (deletes) any cached response for this resource + * sets request.cached = False + * sets request.cacheable = False + + else if a cached copy exists: + * sets request.cached = True + * sets request.cacheable = False + * sets response.headers to the cached values + * checks the cached Last-Modified response header against the + current If-(Un)Modified-Since request headers; raises 304 + if necessary. + * sets response.status and response.body to the cached values + * returns True + + otherwise: + * sets request.cached = False + * sets request.cacheable = True + * returns False + """ + request = cherrypy.serving.request + response = cherrypy.serving.response + + if not hasattr(cherrypy, "_cache"): + # Make a process-wide Cache object. + cherrypy._cache = kwargs.pop("cache_class", MemoryCache)() + + # Take all remaining kwargs and set them on the Cache object. + for k, v in kwargs.items(): + setattr(cherrypy._cache, k, v) + cherrypy._cache.debug = debug + + # POST, PUT, DELETE should invalidate (delete) the cached copy. + # See http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.10. + if request.method in invalid_methods: + if debug: + cherrypy.log('request.method %r in invalid_methods %r' % + (request.method, invalid_methods), 'TOOLS.CACHING') + cherrypy._cache.delete() + request.cached = False + request.cacheable = False + return False + + if 'no-cache' in [e.value for e in request.headers.elements('Pragma')]: + request.cached = False + request.cacheable = True + return False + + cache_data = cherrypy._cache.get() + request.cached = bool(cache_data) + request.cacheable = not request.cached + if request.cached: + # Serve the cached copy. + max_age = cherrypy._cache.delay + for v in [e.value for e in request.headers.elements('Cache-Control')]: + atoms = v.split('=', 1) + directive = atoms.pop(0) + if directive == 'max-age': + if len(atoms) != 1 or not atoms[0].isdigit(): + raise cherrypy.HTTPError( + 400, "Invalid Cache-Control header") + max_age = int(atoms[0]) + break + elif directive == 'no-cache': + if debug: + cherrypy.log( + 'Ignoring cache due to Cache-Control: no-cache', + 'TOOLS.CACHING') + request.cached = False + request.cacheable = True + return False + + if debug: + cherrypy.log('Reading response from cache', 'TOOLS.CACHING') + s, h, b, create_time = cache_data + age = int(response.time - create_time) + if (age > max_age): + if debug: + cherrypy.log('Ignoring cache due to age > %d' % max_age, + 'TOOLS.CACHING') + request.cached = False + request.cacheable = True + return False + + # Copy the response headers. See + # https://bitbucket.org/cherrypy/cherrypy/issue/721. + response.headers = rh = httputil.HeaderMap() + for k in h: + dict.__setitem__(rh, k, dict.__getitem__(h, k)) + + # Add the required Age header + response.headers["Age"] = str(age) + + try: + # Note that validate_since depends on a Last-Modified header; + # this was put into the cached copy, and should have been + # resurrected just above (response.headers = cache_data[1]). + cptools.validate_since() + except cherrypy.HTTPRedirect: + x = sys.exc_info()[1] + if x.status == 304: + cherrypy._cache.tot_non_modified += 1 + raise + + # serve it & get out from the request + response.status = s + response.body = b + else: + if debug: + cherrypy.log('request is not cached', 'TOOLS.CACHING') + return request.cached + + +def tee_output(): + """Tee response output to cache storage. Internal.""" + # Used by CachingTool by attaching to request.hooks + + request = cherrypy.serving.request + if 'no-store' in request.headers.values('Cache-Control'): + return + + def tee(body): + """Tee response.body into a list.""" + if ('no-cache' in response.headers.values('Pragma') or + 'no-store' in response.headers.values('Cache-Control')): + for chunk in body: + yield chunk + return + + output = [] + for chunk in body: + output.append(chunk) + yield chunk + + # save the cache data + body = ntob('').join(output) + cherrypy._cache.put((response.status, response.headers or {}, + body, response.time), len(body)) + + response = cherrypy.serving.response + response.body = tee(response.body) + + +def expires(secs=0, force=False, debug=False): + """Tool for influencing cache mechanisms using the 'Expires' header. + + secs + Must be either an int or a datetime.timedelta, and indicates the + number of seconds between response.time and when the response should + expire. The 'Expires' header will be set to response.time + secs. + If secs is zero, the 'Expires' header is set one year in the past, and + the following "cache prevention" headers are also set: + + * Pragma: no-cache + * Cache-Control': no-cache, must-revalidate + + force + If False, the following headers are checked: + + * Etag + * Last-Modified + * Age + * Expires + + If any are already present, none of the above response headers are set. + + """ + + response = cherrypy.serving.response + headers = response.headers + + cacheable = False + if not force: + # some header names that indicate that the response can be cached + for indicator in ('Etag', 'Last-Modified', 'Age', 'Expires'): + if indicator in headers: + cacheable = True + break + + if not cacheable and not force: + if debug: + cherrypy.log('request is not cacheable', 'TOOLS.EXPIRES') + else: + if debug: + cherrypy.log('request is cacheable', 'TOOLS.EXPIRES') + if isinstance(secs, datetime.timedelta): + secs = (86400 * secs.days) + secs.seconds + + if secs == 0: + if force or ("Pragma" not in headers): + headers["Pragma"] = "no-cache" + if cherrypy.serving.request.protocol >= (1, 1): + if force or "Cache-Control" not in headers: + headers["Cache-Control"] = "no-cache, must-revalidate" + # Set an explicit Expires date in the past. + expiry = httputil.HTTPDate(1169942400.0) + else: + expiry = httputil.HTTPDate(response.time + secs) + if force or "Expires" not in headers: + headers["Expires"] = expiry diff --git a/lib/cherrypy/lib/covercp.py b/lib/cherrypy/lib/covercp.py new file mode 100644 index 00000000..a74ec342 --- /dev/null +++ b/lib/cherrypy/lib/covercp.py @@ -0,0 +1,387 @@ +"""Code-coverage tools for CherryPy. + +To use this module, or the coverage tools in the test suite, +you need to download 'coverage.py', either Gareth Rees' `original +implementation `_ +or Ned Batchelder's `enhanced version: +`_ + +To turn on coverage tracing, use the following code:: + + cherrypy.engine.subscribe('start', covercp.start) + +DO NOT subscribe anything on the 'start_thread' channel, as previously +recommended. Calling start once in the main thread should be sufficient +to start coverage on all threads. Calling start again in each thread +effectively clears any coverage data gathered up to that point. + +Run your code, then use the ``covercp.serve()`` function to browse the +results in a web browser. If you run this module from the command line, +it will call ``serve()`` for you. +""" + +import re +import sys +import cgi +from cherrypy._cpcompat import quote_plus +import os +import os.path +localFile = os.path.join(os.path.dirname(__file__), "coverage.cache") + +the_coverage = None +try: + from coverage import coverage + the_coverage = coverage(data_file=localFile) + + def start(): + the_coverage.start() +except ImportError: + # Setting the_coverage to None will raise errors + # that need to be trapped downstream. + the_coverage = None + + import warnings + warnings.warn( + "No code coverage will be performed; " + "coverage.py could not be imported.") + + def start(): + pass +start.priority = 20 + +TEMPLATE_MENU = """ + + CherryPy Coverage Menu + + + +

CherryPy Coverage

""" + +TEMPLATE_FORM = """ +
+
+ + Show percentages +
+ Hide files over + %%
+ Exclude files matching
+ +
+ + +
+
""" + +TEMPLATE_FRAMESET = """ +CherryPy coverage data + + + + + +""" + +TEMPLATE_COVERAGE = """ + + Coverage for %(name)s + + + +

%(name)s

+

%(fullpath)s

+

Coverage: %(pc)s%%

""" + +TEMPLATE_LOC_COVERED = """ + %s  + %s +\n""" +TEMPLATE_LOC_NOT_COVERED = """ + %s  + %s +\n""" +TEMPLATE_LOC_EXCLUDED = """ + %s  + %s +\n""" + +TEMPLATE_ITEM = ( + "%s%s%s\n" +) + + +def _percent(statements, missing): + s = len(statements) + e = s - len(missing) + if s > 0: + return int(round(100.0 * e / s)) + return 0 + + +def _show_branch(root, base, path, pct=0, showpct=False, exclude="", + coverage=the_coverage): + + # Show the directory name and any of our children + dirs = [k for k, v in root.items() if v] + dirs.sort() + for name in dirs: + newpath = os.path.join(path, name) + + if newpath.lower().startswith(base): + relpath = newpath[len(base):] + yield "| " * relpath.count(os.sep) + yield ( + "%s\n" % + (newpath, quote_plus(exclude), name) + ) + + for chunk in _show_branch( + root[name], base, newpath, pct, showpct, + exclude, coverage=coverage + ): + yield chunk + + # Now list the files + if path.lower().startswith(base): + relpath = path[len(base):] + files = [k for k, v in root.items() if not v] + files.sort() + for name in files: + newpath = os.path.join(path, name) + + pc_str = "" + if showpct: + try: + _, statements, _, missing, _ = coverage.analysis2(newpath) + except: + # Yes, we really want to pass on all errors. + pass + else: + pc = _percent(statements, missing) + pc_str = ("%3d%% " % pc).replace(' ', ' ') + if pc < float(pct) or pc == -1: + pc_str = "%s" % pc_str + else: + pc_str = "%s" % pc_str + + yield TEMPLATE_ITEM % ("| " * (relpath.count(os.sep) + 1), + pc_str, newpath, name) + + +def _skip_file(path, exclude): + if exclude: + return bool(re.search(exclude, path)) + + +def _graft(path, tree): + d = tree + + p = path + atoms = [] + while True: + p, tail = os.path.split(p) + if not tail: + break + atoms.append(tail) + atoms.append(p) + if p != "/": + atoms.append("/") + + atoms.reverse() + for node in atoms: + if node: + d = d.setdefault(node, {}) + + +def get_tree(base, exclude, coverage=the_coverage): + """Return covered module names as a nested dict.""" + tree = {} + runs = coverage.data.executed_files() + for path in runs: + if not _skip_file(path, exclude) and not os.path.isdir(path): + _graft(path, tree) + return tree + + +class CoverStats(object): + + def __init__(self, coverage, root=None): + self.coverage = coverage + if root is None: + # Guess initial depth. Files outside this path will not be + # reachable from the web interface. + import cherrypy + root = os.path.dirname(cherrypy.__file__) + self.root = root + + def index(self): + return TEMPLATE_FRAMESET % self.root.lower() + index.exposed = True + + def menu(self, base="/", pct="50", showpct="", + exclude=r'python\d\.\d|test|tut\d|tutorial'): + + # The coverage module uses all-lower-case names. + base = base.lower().rstrip(os.sep) + + yield TEMPLATE_MENU + yield TEMPLATE_FORM % locals() + + # Start by showing links for parent paths + yield "
" + path = "" + atoms = base.split(os.sep) + atoms.pop() + for atom in atoms: + path += atom + os.sep + yield ("%s %s" + % (path, quote_plus(exclude), atom, os.sep)) + yield "
" + + yield "
" + + # Then display the tree + tree = get_tree(base, exclude, self.coverage) + if not tree: + yield "

No modules covered.

" + else: + for chunk in _show_branch(tree, base, "/", pct, + showpct == 'checked', exclude, + coverage=self.coverage): + yield chunk + + yield "
" + yield "" + menu.exposed = True + + def annotated_file(self, filename, statements, excluded, missing): + source = open(filename, 'r') + buffer = [] + for lineno, line in enumerate(source.readlines()): + lineno += 1 + line = line.strip("\n\r") + empty_the_buffer = True + if lineno in excluded: + template = TEMPLATE_LOC_EXCLUDED + elif lineno in missing: + template = TEMPLATE_LOC_NOT_COVERED + elif lineno in statements: + template = TEMPLATE_LOC_COVERED + else: + empty_the_buffer = False + buffer.append((lineno, line)) + if empty_the_buffer: + for lno, pastline in buffer: + yield template % (lno, cgi.escape(pastline)) + buffer = [] + yield template % (lineno, cgi.escape(line)) + + def report(self, name): + filename, statements, excluded, missing, _ = self.coverage.analysis2( + name) + pc = _percent(statements, missing) + yield TEMPLATE_COVERAGE % dict(name=os.path.basename(name), + fullpath=name, + pc=pc) + yield '\n' + for line in self.annotated_file(filename, statements, excluded, + missing): + yield line + yield '
' + yield '' + yield '' + report.exposed = True + + +def serve(path=localFile, port=8080, root=None): + if coverage is None: + raise ImportError("The coverage module could not be imported.") + from coverage import coverage + cov = coverage(data_file=path) + cov.load() + + import cherrypy + cherrypy.config.update({'server.socket_port': int(port), + 'server.thread_pool': 10, + 'environment': "production", + }) + cherrypy.quickstart(CoverStats(cov, root)) + +if __name__ == "__main__": + serve(*tuple(sys.argv[1:])) diff --git a/lib/cherrypy/lib/cpstats.py b/lib/cherrypy/lib/cpstats.py new file mode 100644 index 00000000..a8661a14 --- /dev/null +++ b/lib/cherrypy/lib/cpstats.py @@ -0,0 +1,687 @@ +"""CPStats, a package for collecting and reporting on program statistics. + +Overview +======== + +Statistics about program operation are an invaluable monitoring and debugging +tool. Unfortunately, the gathering and reporting of these critical values is +usually ad-hoc. This package aims to add a centralized place for gathering +statistical performance data, a structure for recording that data which +provides for extrapolation of that data into more useful information, +and a method of serving that data to both human investigators and +monitoring software. Let's examine each of those in more detail. + +Data Gathering +-------------- + +Just as Python's `logging` module provides a common importable for gathering +and sending messages, performance statistics would benefit from a similar +common mechanism, and one that does *not* require each package which wishes +to collect stats to import a third-party module. Therefore, we choose to +re-use the `logging` module by adding a `statistics` object to it. + +That `logging.statistics` object is a nested dict. It is not a custom class, +because that would: + + 1. require libraries and applications to import a third-party module in + order to participate + 2. inhibit innovation in extrapolation approaches and in reporting tools, and + 3. be slow. + +There are, however, some specifications regarding the structure of the dict.:: + + { + +----"SQLAlchemy": { + | "Inserts": 4389745, + | "Inserts per Second": + | lambda s: s["Inserts"] / (time() - s["Start"]), + | C +---"Table Statistics": { + | o | "widgets": {-----------+ + N | l | "Rows": 1.3M, | Record + a | l | "Inserts": 400, | + m | e | },---------------------+ + e | c | "froobles": { + s | t | "Rows": 7845, + p | i | "Inserts": 0, + a | o | }, + c | n +---}, + e | "Slow Queries": + | [{"Query": "SELECT * FROM widgets;", + | "Processing Time": 47.840923343, + | }, + | ], + +----}, + } + +The `logging.statistics` dict has four levels. The topmost level is nothing +more than a set of names to introduce modularity, usually along the lines of +package names. If the SQLAlchemy project wanted to participate, for example, +it might populate the item `logging.statistics['SQLAlchemy']`, whose value +would be a second-layer dict we call a "namespace". Namespaces help multiple +packages to avoid collisions over key names, and make reports easier to read, +to boot. The maintainers of SQLAlchemy should feel free to use more than one +namespace if needed (such as 'SQLAlchemy ORM'). Note that there are no case +or other syntax constraints on the namespace names; they should be chosen +to be maximally readable by humans (neither too short nor too long). + +Each namespace, then, is a dict of named statistical values, such as +'Requests/sec' or 'Uptime'. You should choose names which will look +good on a report: spaces and capitalization are just fine. + +In addition to scalars, values in a namespace MAY be a (third-layer) +dict, or a list, called a "collection". For example, the CherryPy +:class:`StatsTool` keeps track of what each request is doing (or has most +recently done) in a 'Requests' collection, where each key is a thread ID; each +value in the subdict MUST be a fourth dict (whew!) of statistical data about +each thread. We call each subdict in the collection a "record". Similarly, +the :class:`StatsTool` also keeps a list of slow queries, where each record +contains data about each slow query, in order. + +Values in a namespace or record may also be functions, which brings us to: + +Extrapolation +------------- + +The collection of statistical data needs to be fast, as close to unnoticeable +as possible to the host program. That requires us to minimize I/O, for example, +but in Python it also means we need to minimize function calls. So when you +are designing your namespace and record values, try to insert the most basic +scalar values you already have on hand. + +When it comes time to report on the gathered data, however, we usually have +much more freedom in what we can calculate. Therefore, whenever reporting +tools (like the provided :class:`StatsPage` CherryPy class) fetch the contents +of `logging.statistics` for reporting, they first call +`extrapolate_statistics` (passing the whole `statistics` dict as the only +argument). This makes a deep copy of the statistics dict so that the +reporting tool can both iterate over it and even change it without harming +the original. But it also expands any functions in the dict by calling them. +For example, you might have a 'Current Time' entry in the namespace with the +value "lambda scope: time.time()". The "scope" parameter is the current +namespace dict (or record, if we're currently expanding one of those +instead), allowing you access to existing static entries. If you're truly +evil, you can even modify more than one entry at a time. + +However, don't try to calculate an entry and then use its value in further +extrapolations; the order in which the functions are called is not guaranteed. +This can lead to a certain amount of duplicated work (or a redesign of your +schema), but that's better than complicating the spec. + +After the whole thing has been extrapolated, it's time for: + +Reporting +--------- + +The :class:`StatsPage` class grabs the `logging.statistics` dict, extrapolates +it all, and then transforms it to HTML for easy viewing. Each namespace gets +its own header and attribute table, plus an extra table for each collection. +This is NOT part of the statistics specification; other tools can format how +they like. + +You can control which columns are output and how they are formatted by updating +StatsPage.formatting, which is a dict that mirrors the keys and nesting of +`logging.statistics`. The difference is that, instead of data values, it has +formatting values. Use None for a given key to indicate to the StatsPage that a +given column should not be output. Use a string with formatting +(such as '%.3f') to interpolate the value(s), or use a callable (such as +lambda v: v.isoformat()) for more advanced formatting. Any entry which is not +mentioned in the formatting dict is output unchanged. + +Monitoring +---------- + +Although the HTML output takes pains to assign unique id's to each with +statistical data, you're probably better off fetching /cpstats/data, which +outputs the whole (extrapolated) `logging.statistics` dict in JSON format. +That is probably easier to parse, and doesn't have any formatting controls, +so you get the "original" data in a consistently-serialized format. +Note: there's no treatment yet for datetime objects. Try time.time() instead +for now if you can. Nagios will probably thank you. + +Turning Collection Off +---------------------- + +It is recommended each namespace have an "Enabled" item which, if False, +stops collection (but not reporting) of statistical data. Applications +SHOULD provide controls to pause and resume collection by setting these +entries to False or True, if present. + + +Usage +===== + +To collect statistics on CherryPy applications:: + + from cherrypy.lib import cpstats + appconfig['/']['tools.cpstats.on'] = True + +To collect statistics on your own code:: + + import logging + # Initialize the repository + if not hasattr(logging, 'statistics'): logging.statistics = {} + # Initialize my namespace + mystats = logging.statistics.setdefault('My Stuff', {}) + # Initialize my namespace's scalars and collections + mystats.update({ + 'Enabled': True, + 'Start Time': time.time(), + 'Important Events': 0, + 'Events/Second': lambda s: ( + (s['Important Events'] / (time.time() - s['Start Time']))), + }) + ... + for event in events: + ... + # Collect stats + if mystats.get('Enabled', False): + mystats['Important Events'] += 1 + +To report statistics:: + + root.cpstats = cpstats.StatsPage() + +To format statistics reports:: + + See 'Reporting', above. + +""" + +# ------------------------------- Statistics -------------------------------- # + +import logging +if not hasattr(logging, 'statistics'): + logging.statistics = {} + + +def extrapolate_statistics(scope): + """Return an extrapolated copy of the given scope.""" + c = {} + for k, v in list(scope.items()): + if isinstance(v, dict): + v = extrapolate_statistics(v) + elif isinstance(v, (list, tuple)): + v = [extrapolate_statistics(record) for record in v] + elif hasattr(v, '__call__'): + v = v(scope) + c[k] = v + return c + + +# -------------------- CherryPy Applications Statistics --------------------- # + +import threading +import time + +import cherrypy + +appstats = logging.statistics.setdefault('CherryPy Applications', {}) +appstats.update({ + 'Enabled': True, + 'Bytes Read/Request': lambda s: ( + s['Total Requests'] and + (s['Total Bytes Read'] / float(s['Total Requests'])) or + 0.0 + ), + 'Bytes Read/Second': lambda s: s['Total Bytes Read'] / s['Uptime'](s), + 'Bytes Written/Request': lambda s: ( + s['Total Requests'] and + (s['Total Bytes Written'] / float(s['Total Requests'])) or + 0.0 + ), + 'Bytes Written/Second': lambda s: ( + s['Total Bytes Written'] / s['Uptime'](s) + ), + 'Current Time': lambda s: time.time(), + 'Current Requests': 0, + 'Requests/Second': lambda s: float(s['Total Requests']) / s['Uptime'](s), + 'Server Version': cherrypy.__version__, + 'Start Time': time.time(), + 'Total Bytes Read': 0, + 'Total Bytes Written': 0, + 'Total Requests': 0, + 'Total Time': 0, + 'Uptime': lambda s: time.time() - s['Start Time'], + 'Requests': {}, +}) + +proc_time = lambda s: time.time() - s['Start Time'] + + +class ByteCountWrapper(object): + + """Wraps a file-like object, counting the number of bytes read.""" + + def __init__(self, rfile): + self.rfile = rfile + self.bytes_read = 0 + + def read(self, size=-1): + data = self.rfile.read(size) + self.bytes_read += len(data) + return data + + def readline(self, size=-1): + data = self.rfile.readline(size) + self.bytes_read += len(data) + return data + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline() + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline() + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def next(self): + data = self.rfile.next() + self.bytes_read += len(data) + return data + + +average_uriset_time = lambda s: s['Count'] and (s['Sum'] / s['Count']) or 0 + + +class StatsTool(cherrypy.Tool): + + """Record various information about the current request.""" + + def __init__(self): + cherrypy.Tool.__init__(self, 'on_end_request', self.record_stop) + + def _setup(self): + """Hook this tool into cherrypy.request. + + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + if appstats.get('Enabled', False): + cherrypy.Tool._setup(self) + self.record_start() + + def record_start(self): + """Record the beginning of a request.""" + request = cherrypy.serving.request + if not hasattr(request.rfile, 'bytes_read'): + request.rfile = ByteCountWrapper(request.rfile) + request.body.fp = request.rfile + + r = request.remote + + appstats['Current Requests'] += 1 + appstats['Total Requests'] += 1 + appstats['Requests'][threading._get_ident()] = { + 'Bytes Read': None, + 'Bytes Written': None, + # Use a lambda so the ip gets updated by tools.proxy later + 'Client': lambda s: '%s:%s' % (r.ip, r.port), + 'End Time': None, + 'Processing Time': proc_time, + 'Request-Line': request.request_line, + 'Response Status': None, + 'Start Time': time.time(), + } + + def record_stop( + self, uriset=None, slow_queries=1.0, slow_queries_count=100, + debug=False, **kwargs): + """Record the end of a request.""" + resp = cherrypy.serving.response + w = appstats['Requests'][threading._get_ident()] + + r = cherrypy.request.rfile.bytes_read + w['Bytes Read'] = r + appstats['Total Bytes Read'] += r + + if resp.stream: + w['Bytes Written'] = 'chunked' + else: + cl = int(resp.headers.get('Content-Length', 0)) + w['Bytes Written'] = cl + appstats['Total Bytes Written'] += cl + + w['Response Status'] = getattr( + resp, 'output_status', None) or resp.status + + w['End Time'] = time.time() + p = w['End Time'] - w['Start Time'] + w['Processing Time'] = p + appstats['Total Time'] += p + + appstats['Current Requests'] -= 1 + + if debug: + cherrypy.log('Stats recorded: %s' % repr(w), 'TOOLS.CPSTATS') + + if uriset: + rs = appstats.setdefault('URI Set Tracking', {}) + r = rs.setdefault(uriset, { + 'Min': None, 'Max': None, 'Count': 0, 'Sum': 0, + 'Avg': average_uriset_time}) + if r['Min'] is None or p < r['Min']: + r['Min'] = p + if r['Max'] is None or p > r['Max']: + r['Max'] = p + r['Count'] += 1 + r['Sum'] += p + + if slow_queries and p > slow_queries: + sq = appstats.setdefault('Slow Queries', []) + sq.append(w.copy()) + if len(sq) > slow_queries_count: + sq.pop(0) + + +import cherrypy +cherrypy.tools.cpstats = StatsTool() + + +# ---------------------- CherryPy Statistics Reporting ---------------------- # + +import os +thisdir = os.path.abspath(os.path.dirname(__file__)) + +try: + import json +except ImportError: + try: + import simplejson as json + except ImportError: + json = None + + +missing = object() + +locale_date = lambda v: time.strftime('%c', time.gmtime(v)) +iso_format = lambda v: time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(v)) + + +def pause_resume(ns): + def _pause_resume(enabled): + pause_disabled = '' + resume_disabled = '' + if enabled: + resume_disabled = 'disabled="disabled" ' + else: + pause_disabled = 'disabled="disabled" ' + return """ +
+ + +
+
+ + +
+ """ % (ns, pause_disabled, ns, resume_disabled) + return _pause_resume + + +class StatsPage(object): + + formatting = { + 'CherryPy Applications': { + 'Enabled': pause_resume('CherryPy Applications'), + 'Bytes Read/Request': '%.3f', + 'Bytes Read/Second': '%.3f', + 'Bytes Written/Request': '%.3f', + 'Bytes Written/Second': '%.3f', + 'Current Time': iso_format, + 'Requests/Second': '%.3f', + 'Start Time': iso_format, + 'Total Time': '%.3f', + 'Uptime': '%.3f', + 'Slow Queries': { + 'End Time': None, + 'Processing Time': '%.3f', + 'Start Time': iso_format, + }, + 'URI Set Tracking': { + 'Avg': '%.3f', + 'Max': '%.3f', + 'Min': '%.3f', + 'Sum': '%.3f', + }, + 'Requests': { + 'Bytes Read': '%s', + 'Bytes Written': '%s', + 'End Time': None, + 'Processing Time': '%.3f', + 'Start Time': None, + }, + }, + 'CherryPy WSGIServer': { + 'Enabled': pause_resume('CherryPy WSGIServer'), + 'Connections/second': '%.3f', + 'Start time': iso_format, + }, + } + + def index(self): + # Transform the raw data into pretty output for HTML + yield """ + + + Statistics + + + +""" + for title, scalars, collections in self.get_namespaces(): + yield """ +

%s

+ + + +""" % title + for i, (key, value) in enumerate(scalars): + colnum = i % 3 + if colnum == 0: + yield """ + """ + yield ( + """ + """ % + vars() + ) + if colnum == 2: + yield """ + """ + + if colnum == 0: + yield """ + + + """ + elif colnum == 1: + yield """ + + """ + yield """ + +
%(key)s%(value)s
""" + + for subtitle, headers, subrows in collections: + yield """ +

%s

+ + + """ % subtitle + for key in headers: + yield """ + """ % key + yield """ + + + """ + for subrow in subrows: + yield """ + """ + for value in subrow: + yield """ + """ % value + yield """ + """ + yield """ + +
%s
%s
""" + yield """ + + +""" + index.exposed = True + + def get_namespaces(self): + """Yield (title, scalars, collections) for each namespace.""" + s = extrapolate_statistics(logging.statistics) + for title, ns in sorted(s.items()): + scalars = [] + collections = [] + ns_fmt = self.formatting.get(title, {}) + for k, v in sorted(ns.items()): + fmt = ns_fmt.get(k, {}) + if isinstance(v, dict): + headers, subrows = self.get_dict_collection(v, fmt) + collections.append((k, ['ID'] + headers, subrows)) + elif isinstance(v, (list, tuple)): + headers, subrows = self.get_list_collection(v, fmt) + collections.append((k, headers, subrows)) + else: + format = ns_fmt.get(k, missing) + if format is None: + # Don't output this column. + continue + if hasattr(format, '__call__'): + v = format(v) + elif format is not missing: + v = format % v + scalars.append((k, v)) + yield title, scalars, collections + + def get_dict_collection(self, v, formatting): + """Return ([headers], [rows]) for the given collection.""" + # E.g., the 'Requests' dict. + headers = [] + for record in v.itervalues(): + for k3 in record: + format = formatting.get(k3, missing) + if format is None: + # Don't output this column. + continue + if k3 not in headers: + headers.append(k3) + headers.sort() + + subrows = [] + for k2, record in sorted(v.items()): + subrow = [k2] + for k3 in headers: + v3 = record.get(k3, '') + format = formatting.get(k3, missing) + if format is None: + # Don't output this column. + continue + if hasattr(format, '__call__'): + v3 = format(v3) + elif format is not missing: + v3 = format % v3 + subrow.append(v3) + subrows.append(subrow) + + return headers, subrows + + def get_list_collection(self, v, formatting): + """Return ([headers], [subrows]) for the given collection.""" + # E.g., the 'Slow Queries' list. + headers = [] + for record in v: + for k3 in record: + format = formatting.get(k3, missing) + if format is None: + # Don't output this column. + continue + if k3 not in headers: + headers.append(k3) + headers.sort() + + subrows = [] + for record in v: + subrow = [] + for k3 in headers: + v3 = record.get(k3, '') + format = formatting.get(k3, missing) + if format is None: + # Don't output this column. + continue + if hasattr(format, '__call__'): + v3 = format(v3) + elif format is not missing: + v3 = format % v3 + subrow.append(v3) + subrows.append(subrow) + + return headers, subrows + + if json is not None: + def data(self): + s = extrapolate_statistics(logging.statistics) + cherrypy.response.headers['Content-Type'] = 'application/json' + return json.dumps(s, sort_keys=True, indent=4) + data.exposed = True + + def pause(self, namespace): + logging.statistics.get(namespace, {})['Enabled'] = False + raise cherrypy.HTTPRedirect('./') + pause.exposed = True + pause.cp_config = {'tools.allow.on': True, + 'tools.allow.methods': ['POST']} + + def resume(self, namespace): + logging.statistics.get(namespace, {})['Enabled'] = True + raise cherrypy.HTTPRedirect('./') + resume.exposed = True + resume.cp_config = {'tools.allow.on': True, + 'tools.allow.methods': ['POST']} diff --git a/lib/cherrypy/lib/cptools.py b/lib/cherrypy/lib/cptools.py new file mode 100644 index 00000000..f376282c --- /dev/null +++ b/lib/cherrypy/lib/cptools.py @@ -0,0 +1,630 @@ +"""Functions for builtin CherryPy tools.""" + +import logging +import re + +import cherrypy +from cherrypy._cpcompat import basestring, md5, set, unicodestr +from cherrypy.lib import httputil as _httputil +from cherrypy.lib import is_iterator + + +# Conditional HTTP request support # + +def validate_etags(autotags=False, debug=False): + """Validate the current ETag against If-Match, If-None-Match headers. + + If autotags is True, an ETag response-header value will be provided + from an MD5 hash of the response body (unless some other code has + already provided an ETag header). If False (the default), the ETag + will not be automatic. + + WARNING: the autotags feature is not designed for URL's which allow + methods other than GET. For example, if a POST to the same URL returns + no content, the automatic ETag will be incorrect, breaking a fundamental + use for entity tags in a possibly destructive fashion. Likewise, if you + raise 304 Not Modified, the response body will be empty, the ETag hash + will be incorrect, and your application will break. + See :rfc:`2616` Section 14.24. + """ + response = cherrypy.serving.response + + # Guard against being run twice. + if hasattr(response, "ETag"): + return + + status, reason, msg = _httputil.valid_status(response.status) + + etag = response.headers.get('ETag') + + # Automatic ETag generation. See warning in docstring. + if etag: + if debug: + cherrypy.log('ETag already set: %s' % etag, 'TOOLS.ETAGS') + elif not autotags: + if debug: + cherrypy.log('Autotags off', 'TOOLS.ETAGS') + elif status != 200: + if debug: + cherrypy.log('Status not 200', 'TOOLS.ETAGS') + else: + etag = response.collapse_body() + etag = '"%s"' % md5(etag).hexdigest() + if debug: + cherrypy.log('Setting ETag: %s' % etag, 'TOOLS.ETAGS') + response.headers['ETag'] = etag + + response.ETag = etag + + # "If the request would, without the If-Match header field, result in + # anything other than a 2xx or 412 status, then the If-Match header + # MUST be ignored." + if debug: + cherrypy.log('Status: %s' % status, 'TOOLS.ETAGS') + if status >= 200 and status <= 299: + request = cherrypy.serving.request + + conditions = request.headers.elements('If-Match') or [] + conditions = [str(x) for x in conditions] + if debug: + cherrypy.log('If-Match conditions: %s' % repr(conditions), + 'TOOLS.ETAGS') + if conditions and not (conditions == ["*"] or etag in conditions): + raise cherrypy.HTTPError(412, "If-Match failed: ETag %r did " + "not match %r" % (etag, conditions)) + + conditions = request.headers.elements('If-None-Match') or [] + conditions = [str(x) for x in conditions] + if debug: + cherrypy.log('If-None-Match conditions: %s' % repr(conditions), + 'TOOLS.ETAGS') + if conditions == ["*"] or etag in conditions: + if debug: + cherrypy.log('request.method: %s' % + request.method, 'TOOLS.ETAGS') + if request.method in ("GET", "HEAD"): + raise cherrypy.HTTPRedirect([], 304) + else: + raise cherrypy.HTTPError(412, "If-None-Match failed: ETag %r " + "matched %r" % (etag, conditions)) + + +def validate_since(): + """Validate the current Last-Modified against If-Modified-Since headers. + + If no code has set the Last-Modified response header, then no validation + will be performed. + """ + response = cherrypy.serving.response + lastmod = response.headers.get('Last-Modified') + if lastmod: + status, reason, msg = _httputil.valid_status(response.status) + + request = cherrypy.serving.request + + since = request.headers.get('If-Unmodified-Since') + if since and since != lastmod: + if (status >= 200 and status <= 299) or status == 412: + raise cherrypy.HTTPError(412) + + since = request.headers.get('If-Modified-Since') + if since and since == lastmod: + if (status >= 200 and status <= 299) or status == 304: + if request.method in ("GET", "HEAD"): + raise cherrypy.HTTPRedirect([], 304) + else: + raise cherrypy.HTTPError(412) + + +# Tool code # + +def allow(methods=None, debug=False): + """Raise 405 if request.method not in methods (default ['GET', 'HEAD']). + + The given methods are case-insensitive, and may be in any order. + If only one method is allowed, you may supply a single string; + if more than one, supply a list of strings. + + Regardless of whether the current method is allowed or not, this + also emits an 'Allow' response header, containing the given methods. + """ + if not isinstance(methods, (tuple, list)): + methods = [methods] + methods = [m.upper() for m in methods if m] + if not methods: + methods = ['GET', 'HEAD'] + elif 'GET' in methods and 'HEAD' not in methods: + methods.append('HEAD') + + cherrypy.response.headers['Allow'] = ', '.join(methods) + if cherrypy.request.method not in methods: + if debug: + cherrypy.log('request.method %r not in methods %r' % + (cherrypy.request.method, methods), 'TOOLS.ALLOW') + raise cherrypy.HTTPError(405) + else: + if debug: + cherrypy.log('request.method %r in methods %r' % + (cherrypy.request.method, methods), 'TOOLS.ALLOW') + + +def proxy(base=None, local='X-Forwarded-Host', remote='X-Forwarded-For', + scheme='X-Forwarded-Proto', debug=False): + """Change the base URL (scheme://host[:port][/path]). + + For running a CP server behind Apache, lighttpd, or other HTTP server. + + For Apache and lighttpd, you should leave the 'local' argument at the + default value of 'X-Forwarded-Host'. For Squid, you probably want to set + tools.proxy.local = 'Origin'. + + If you want the new request.base to include path info (not just the host), + you must explicitly set base to the full base path, and ALSO set 'local' + to '', so that the X-Forwarded-Host request header (which never includes + path info) does not override it. Regardless, the value for 'base' MUST + NOT end in a slash. + + cherrypy.request.remote.ip (the IP address of the client) will be + rewritten if the header specified by the 'remote' arg is valid. + By default, 'remote' is set to 'X-Forwarded-For'. If you do not + want to rewrite remote.ip, set the 'remote' arg to an empty string. + """ + + request = cherrypy.serving.request + + if scheme: + s = request.headers.get(scheme, None) + if debug: + cherrypy.log('Testing scheme %r:%r' % (scheme, s), 'TOOLS.PROXY') + if s == 'on' and 'ssl' in scheme.lower(): + # This handles e.g. webfaction's 'X-Forwarded-Ssl: on' header + scheme = 'https' + else: + # This is for lighttpd/pound/Mongrel's 'X-Forwarded-Proto: https' + scheme = s + if not scheme: + scheme = request.base[:request.base.find("://")] + + if local: + lbase = request.headers.get(local, None) + if debug: + cherrypy.log('Testing local %r:%r' % (local, lbase), 'TOOLS.PROXY') + if lbase is not None: + base = lbase.split(',')[0] + if not base: + port = request.local.port + if port == 80: + base = '127.0.0.1' + else: + base = '127.0.0.1:%s' % port + + if base.find("://") == -1: + # add http:// or https:// if needed + base = scheme + "://" + base + + request.base = base + + if remote: + xff = request.headers.get(remote) + if debug: + cherrypy.log('Testing remote %r:%r' % (remote, xff), 'TOOLS.PROXY') + if xff: + if remote == 'X-Forwarded-For': + #Bug #1268 + xff = xff.split(',')[0].strip() + request.remote.ip = xff + + +def ignore_headers(headers=('Range',), debug=False): + """Delete request headers whose field names are included in 'headers'. + + This is a useful tool for working behind certain HTTP servers; + for example, Apache duplicates the work that CP does for 'Range' + headers, and will doubly-truncate the response. + """ + request = cherrypy.serving.request + for name in headers: + if name in request.headers: + if debug: + cherrypy.log('Ignoring request header %r' % name, + 'TOOLS.IGNORE_HEADERS') + del request.headers[name] + + +def response_headers(headers=None, debug=False): + """Set headers on the response.""" + if debug: + cherrypy.log('Setting response headers: %s' % repr(headers), + 'TOOLS.RESPONSE_HEADERS') + for name, value in (headers or []): + cherrypy.serving.response.headers[name] = value +response_headers.failsafe = True + + +def referer(pattern, accept=True, accept_missing=False, error=403, + message='Forbidden Referer header.', debug=False): + """Raise HTTPError if Referer header does/does not match the given pattern. + + pattern + A regular expression pattern to test against the Referer. + + accept + If True, the Referer must match the pattern; if False, + the Referer must NOT match the pattern. + + accept_missing + If True, permit requests with no Referer header. + + error + The HTTP error code to return to the client on failure. + + message + A string to include in the response body on failure. + + """ + try: + ref = cherrypy.serving.request.headers['Referer'] + match = bool(re.match(pattern, ref)) + if debug: + cherrypy.log('Referer %r matches %r' % (ref, pattern), + 'TOOLS.REFERER') + if accept == match: + return + except KeyError: + if debug: + cherrypy.log('No Referer header', 'TOOLS.REFERER') + if accept_missing: + return + + raise cherrypy.HTTPError(error, message) + + +class SessionAuth(object): + + """Assert that the user is logged in.""" + + session_key = "username" + debug = False + + def check_username_and_password(self, username, password): + pass + + def anonymous(self): + """Provide a temporary user name for anonymous users.""" + pass + + def on_login(self, username): + pass + + def on_logout(self, username): + pass + + def on_check(self, username): + pass + + def login_screen(self, from_page='..', username='', error_msg='', + **kwargs): + return (unicodestr(""" +Message: %(error_msg)s +
+ Login: +
+ Password: +
+ +
+ +
+""") % vars()).encode("utf-8") + + def do_login(self, username, password, from_page='..', **kwargs): + """Login. May raise redirect, or return True if request handled.""" + response = cherrypy.serving.response + error_msg = self.check_username_and_password(username, password) + if error_msg: + body = self.login_screen(from_page, username, error_msg) + response.body = body + if "Content-Length" in response.headers: + # Delete Content-Length header so finalize() recalcs it. + del response.headers["Content-Length"] + return True + else: + cherrypy.serving.request.login = username + cherrypy.session[self.session_key] = username + self.on_login(username) + raise cherrypy.HTTPRedirect(from_page or "/") + + def do_logout(self, from_page='..', **kwargs): + """Logout. May raise redirect, or return True if request handled.""" + sess = cherrypy.session + username = sess.get(self.session_key) + sess[self.session_key] = None + if username: + cherrypy.serving.request.login = None + self.on_logout(username) + raise cherrypy.HTTPRedirect(from_page) + + def do_check(self): + """Assert username. Raise redirect, or return True if request handled. + """ + sess = cherrypy.session + request = cherrypy.serving.request + response = cherrypy.serving.response + + username = sess.get(self.session_key) + if not username: + sess[self.session_key] = username = self.anonymous() + self._debug_message('No session[username], trying anonymous') + if not username: + url = cherrypy.url(qs=request.query_string) + self._debug_message( + 'No username, routing to login_screen with from_page %(url)r', + locals(), + ) + response.body = self.login_screen(url) + if "Content-Length" in response.headers: + # Delete Content-Length header so finalize() recalcs it. + del response.headers["Content-Length"] + return True + self._debug_message('Setting request.login to %(username)r', locals()) + request.login = username + self.on_check(username) + + def _debug_message(self, template, context={}): + if not self.debug: + return + cherrypy.log(template % context, 'TOOLS.SESSAUTH') + + def run(self): + request = cherrypy.serving.request + response = cherrypy.serving.response + + path = request.path_info + if path.endswith('login_screen'): + self._debug_message('routing %(path)r to login_screen', locals()) + response.body = self.login_screen() + return True + elif path.endswith('do_login'): + if request.method != 'POST': + response.headers['Allow'] = "POST" + self._debug_message('do_login requires POST') + raise cherrypy.HTTPError(405) + self._debug_message('routing %(path)r to do_login', locals()) + return self.do_login(**request.params) + elif path.endswith('do_logout'): + if request.method != 'POST': + response.headers['Allow'] = "POST" + raise cherrypy.HTTPError(405) + self._debug_message('routing %(path)r to do_logout', locals()) + return self.do_logout(**request.params) + else: + self._debug_message('No special path, running do_check') + return self.do_check() + + +def session_auth(**kwargs): + sa = SessionAuth() + for k, v in kwargs.items(): + setattr(sa, k, v) + return sa.run() +session_auth.__doc__ = """Session authentication hook. + +Any attribute of the SessionAuth class may be overridden via a keyword arg +to this function: + +""" + "\n".join(["%s: %s" % (k, type(getattr(SessionAuth, k)).__name__) + for k in dir(SessionAuth) if not k.startswith("__")]) + + +def log_traceback(severity=logging.ERROR, debug=False): + """Write the last error's traceback to the cherrypy error log.""" + cherrypy.log("", "HTTP", severity=severity, traceback=True) + + +def log_request_headers(debug=False): + """Write request headers to the cherrypy error log.""" + h = [" %s: %s" % (k, v) for k, v in cherrypy.serving.request.header_list] + cherrypy.log('\nRequest Headers:\n' + '\n'.join(h), "HTTP") + + +def log_hooks(debug=False): + """Write request.hooks to the cherrypy error log.""" + request = cherrypy.serving.request + + msg = [] + # Sort by the standard points if possible. + from cherrypy import _cprequest + points = _cprequest.hookpoints + for k in request.hooks.keys(): + if k not in points: + points.append(k) + + for k in points: + msg.append(" %s:" % k) + v = request.hooks.get(k, []) + v.sort() + for h in v: + msg.append(" %r" % h) + cherrypy.log('\nRequest Hooks for ' + cherrypy.url() + + ':\n' + '\n'.join(msg), "HTTP") + + +def redirect(url='', internal=True, debug=False): + """Raise InternalRedirect or HTTPRedirect to the given url.""" + if debug: + cherrypy.log('Redirecting %sto: %s' % + ({True: 'internal ', False: ''}[internal], url), + 'TOOLS.REDIRECT') + if internal: + raise cherrypy.InternalRedirect(url) + else: + raise cherrypy.HTTPRedirect(url) + + +def trailing_slash(missing=True, extra=False, status=None, debug=False): + """Redirect if path_info has (missing|extra) trailing slash.""" + request = cherrypy.serving.request + pi = request.path_info + + if debug: + cherrypy.log('is_index: %r, missing: %r, extra: %r, path_info: %r' % + (request.is_index, missing, extra, pi), + 'TOOLS.TRAILING_SLASH') + if request.is_index is True: + if missing: + if not pi.endswith('/'): + new_url = cherrypy.url(pi + '/', request.query_string) + raise cherrypy.HTTPRedirect(new_url, status=status or 301) + elif request.is_index is False: + if extra: + # If pi == '/', don't redirect to ''! + if pi.endswith('/') and pi != '/': + new_url = cherrypy.url(pi[:-1], request.query_string) + raise cherrypy.HTTPRedirect(new_url, status=status or 301) + + +def flatten(debug=False): + """Wrap response.body in a generator that recursively iterates over body. + + This allows cherrypy.response.body to consist of 'nested generators'; + that is, a set of generators that yield generators. + """ + def flattener(input): + numchunks = 0 + for x in input: + if not is_iterator(x): + numchunks += 1 + yield x + else: + for y in flattener(x): + numchunks += 1 + yield y + if debug: + cherrypy.log('Flattened %d chunks' % numchunks, 'TOOLS.FLATTEN') + response = cherrypy.serving.response + response.body = flattener(response.body) + + +def accept(media=None, debug=False): + """Return the client's preferred media-type (from the given Content-Types). + + If 'media' is None (the default), no test will be performed. + + If 'media' is provided, it should be the Content-Type value (as a string) + or values (as a list or tuple of strings) which the current resource + can emit. The client's acceptable media ranges (as declared in the + Accept request header) will be matched in order to these Content-Type + values; the first such string is returned. That is, the return value + will always be one of the strings provided in the 'media' arg (or None + if 'media' is None). + + If no match is found, then HTTPError 406 (Not Acceptable) is raised. + Note that most web browsers send */* as a (low-quality) acceptable + media range, which should match any Content-Type. In addition, "...if + no Accept header field is present, then it is assumed that the client + accepts all media types." + + Matching types are checked in order of client preference first, + and then in the order of the given 'media' values. + + Note that this function does not honor accept-params (other than "q"). + """ + if not media: + return + if isinstance(media, basestring): + media = [media] + request = cherrypy.serving.request + + # Parse the Accept request header, and try to match one + # of the requested media-ranges (in order of preference). + ranges = request.headers.elements('Accept') + if not ranges: + # Any media type is acceptable. + if debug: + cherrypy.log('No Accept header elements', 'TOOLS.ACCEPT') + return media[0] + else: + # Note that 'ranges' is sorted in order of preference + for element in ranges: + if element.qvalue > 0: + if element.value == "*/*": + # Matches any type or subtype + if debug: + cherrypy.log('Match due to */*', 'TOOLS.ACCEPT') + return media[0] + elif element.value.endswith("/*"): + # Matches any subtype + mtype = element.value[:-1] # Keep the slash + for m in media: + if m.startswith(mtype): + if debug: + cherrypy.log('Match due to %s' % element.value, + 'TOOLS.ACCEPT') + return m + else: + # Matches exact value + if element.value in media: + if debug: + cherrypy.log('Match due to %s' % element.value, + 'TOOLS.ACCEPT') + return element.value + + # No suitable media-range found. + ah = request.headers.get('Accept') + if ah is None: + msg = "Your client did not send an Accept header." + else: + msg = "Your client sent this Accept header: %s." % ah + msg += (" But this resource only emits these media types: %s." % + ", ".join(media)) + raise cherrypy.HTTPError(406, msg) + + +class MonitoredHeaderMap(_httputil.HeaderMap): + + def __init__(self): + self.accessed_headers = set() + + def __getitem__(self, key): + self.accessed_headers.add(key) + return _httputil.HeaderMap.__getitem__(self, key) + + def __contains__(self, key): + self.accessed_headers.add(key) + return _httputil.HeaderMap.__contains__(self, key) + + def get(self, key, default=None): + self.accessed_headers.add(key) + return _httputil.HeaderMap.get(self, key, default=default) + + if hasattr({}, 'has_key'): + # Python 2 + def has_key(self, key): + self.accessed_headers.add(key) + return _httputil.HeaderMap.has_key(self, key) + + +def autovary(ignore=None, debug=False): + """Auto-populate the Vary response header based on request.header access. + """ + request = cherrypy.serving.request + + req_h = request.headers + request.headers = MonitoredHeaderMap() + request.headers.update(req_h) + if ignore is None: + ignore = set(['Content-Disposition', 'Content-Length', 'Content-Type']) + + def set_response_header(): + resp_h = cherrypy.serving.response.headers + v = set([e.value for e in resp_h.elements('Vary')]) + if debug: + cherrypy.log( + 'Accessed headers: %s' % request.headers.accessed_headers, + 'TOOLS.AUTOVARY') + v = v.union(request.headers.accessed_headers) + v = v.difference(ignore) + v = list(v) + v.sort() + resp_h['Vary'] = ', '.join(v) + request.hooks.attach('before_finalize', set_response_header, 95) diff --git a/lib/cherrypy/lib/encoding.py b/lib/cherrypy/lib/encoding.py new file mode 100644 index 00000000..a4c2cbd6 --- /dev/null +++ b/lib/cherrypy/lib/encoding.py @@ -0,0 +1,421 @@ +import struct +import time + +import cherrypy +from cherrypy._cpcompat import basestring, BytesIO, ntob, set, unicodestr +from cherrypy.lib import file_generator +from cherrypy.lib import is_closable_iterator +from cherrypy.lib import set_vary_header + + +def decode(encoding=None, default_encoding='utf-8'): + """Replace or extend the list of charsets used to decode a request entity. + + Either argument may be a single string or a list of strings. + + encoding + If not None, restricts the set of charsets attempted while decoding + a request entity to the given set (even if a different charset is + given in the Content-Type request header). + + default_encoding + Only in effect if the 'encoding' argument is not given. + If given, the set of charsets attempted while decoding a request + entity is *extended* with the given value(s). + + """ + body = cherrypy.request.body + if encoding is not None: + if not isinstance(encoding, list): + encoding = [encoding] + body.attempt_charsets = encoding + elif default_encoding: + if not isinstance(default_encoding, list): + default_encoding = [default_encoding] + body.attempt_charsets = body.attempt_charsets + default_encoding + +class UTF8StreamEncoder: + def __init__(self, iterator): + self._iterator = iterator + + def __iter__(self): + return self + + def next(self): + return self.__next__() + + def __next__(self): + res = next(self._iterator) + if isinstance(res, unicodestr): + res = res.encode('utf-8') + return res + + def close(self): + if is_closable_iterator(self._iterator): + self._iterator.close() + + def __getattr__(self, attr): + if attr.startswith('__'): + raise AttributeError(self, attr) + return getattr(self._iterator, attr) + + +class ResponseEncoder: + + default_encoding = 'utf-8' + failmsg = "Response body could not be encoded with %r." + encoding = None + errors = 'strict' + text_only = True + add_charset = True + debug = False + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + self.attempted_charsets = set() + request = cherrypy.serving.request + if request.handler is not None: + # Replace request.handler with self + if self.debug: + cherrypy.log('Replacing request.handler', 'TOOLS.ENCODE') + self.oldhandler = request.handler + request.handler = self + + def encode_stream(self, encoding): + """Encode a streaming response body. + + Use a generator wrapper, and just pray it works as the stream is + being written out. + """ + if encoding in self.attempted_charsets: + return False + self.attempted_charsets.add(encoding) + + def encoder(body): + for chunk in body: + if isinstance(chunk, unicodestr): + chunk = chunk.encode(encoding, self.errors) + yield chunk + self.body = encoder(self.body) + return True + + def encode_string(self, encoding): + """Encode a buffered response body.""" + if encoding in self.attempted_charsets: + return False + self.attempted_charsets.add(encoding) + body = [] + for chunk in self.body: + if isinstance(chunk, unicodestr): + try: + chunk = chunk.encode(encoding, self.errors) + except (LookupError, UnicodeError): + return False + body.append(chunk) + self.body = body + return True + + def find_acceptable_charset(self): + request = cherrypy.serving.request + response = cherrypy.serving.response + + if self.debug: + cherrypy.log('response.stream %r' % + response.stream, 'TOOLS.ENCODE') + if response.stream: + encoder = self.encode_stream + else: + encoder = self.encode_string + if "Content-Length" in response.headers: + # Delete Content-Length header so finalize() recalcs it. + # Encoded strings may be of different lengths from their + # unicode equivalents, and even from each other. For example: + # >>> t = u"\u7007\u3040" + # >>> len(t) + # 2 + # >>> len(t.encode("UTF-8")) + # 6 + # >>> len(t.encode("utf7")) + # 8 + del response.headers["Content-Length"] + + # Parse the Accept-Charset request header, and try to provide one + # of the requested charsets (in order of user preference). + encs = request.headers.elements('Accept-Charset') + charsets = [enc.value.lower() for enc in encs] + if self.debug: + cherrypy.log('charsets %s' % repr(charsets), 'TOOLS.ENCODE') + + if self.encoding is not None: + # If specified, force this encoding to be used, or fail. + encoding = self.encoding.lower() + if self.debug: + cherrypy.log('Specified encoding %r' % + encoding, 'TOOLS.ENCODE') + if (not charsets) or "*" in charsets or encoding in charsets: + if self.debug: + cherrypy.log('Attempting encoding %r' % + encoding, 'TOOLS.ENCODE') + if encoder(encoding): + return encoding + else: + if not encs: + if self.debug: + cherrypy.log('Attempting default encoding %r' % + self.default_encoding, 'TOOLS.ENCODE') + # Any character-set is acceptable. + if encoder(self.default_encoding): + return self.default_encoding + else: + raise cherrypy.HTTPError(500, self.failmsg % + self.default_encoding) + else: + for element in encs: + if element.qvalue > 0: + if element.value == "*": + # Matches any charset. Try our default. + if self.debug: + cherrypy.log('Attempting default encoding due ' + 'to %r' % element, 'TOOLS.ENCODE') + if encoder(self.default_encoding): + return self.default_encoding + else: + encoding = element.value + if self.debug: + cherrypy.log('Attempting encoding %s (qvalue >' + '0)' % element, 'TOOLS.ENCODE') + if encoder(encoding): + return encoding + + if "*" not in charsets: + # If no "*" is present in an Accept-Charset field, then all + # character sets not explicitly mentioned get a quality + # value of 0, except for ISO-8859-1, which gets a quality + # value of 1 if not explicitly mentioned. + iso = 'iso-8859-1' + if iso not in charsets: + if self.debug: + cherrypy.log('Attempting ISO-8859-1 encoding', + 'TOOLS.ENCODE') + if encoder(iso): + return iso + + # No suitable encoding found. + ac = request.headers.get('Accept-Charset') + if ac is None: + msg = "Your client did not send an Accept-Charset header." + else: + msg = "Your client sent this Accept-Charset header: %s." % ac + _charsets = ", ".join(sorted(self.attempted_charsets)) + msg += " We tried these charsets: %s." % (_charsets,) + raise cherrypy.HTTPError(406, msg) + + def __call__(self, *args, **kwargs): + response = cherrypy.serving.response + self.body = self.oldhandler(*args, **kwargs) + + if isinstance(self.body, basestring): + # strings get wrapped in a list because iterating over a single + # item list is much faster than iterating over every character + # in a long string. + if self.body: + self.body = [self.body] + else: + # [''] doesn't evaluate to False, so replace it with []. + self.body = [] + elif hasattr(self.body, 'read'): + self.body = file_generator(self.body) + elif self.body is None: + self.body = [] + + ct = response.headers.elements("Content-Type") + if self.debug: + cherrypy.log('Content-Type: %r' % [str(h) + for h in ct], 'TOOLS.ENCODE') + if ct and self.add_charset: + ct = ct[0] + if self.text_only: + if ct.value.lower().startswith("text/"): + if self.debug: + cherrypy.log( + 'Content-Type %s starts with "text/"' % ct, + 'TOOLS.ENCODE') + do_find = True + else: + if self.debug: + cherrypy.log('Not finding because Content-Type %s ' + 'does not start with "text/"' % ct, + 'TOOLS.ENCODE') + do_find = False + else: + if self.debug: + cherrypy.log('Finding because not text_only', + 'TOOLS.ENCODE') + do_find = True + + if do_find: + # Set "charset=..." param on response Content-Type header + ct.params['charset'] = self.find_acceptable_charset() + if self.debug: + cherrypy.log('Setting Content-Type %s' % ct, + 'TOOLS.ENCODE') + response.headers["Content-Type"] = str(ct) + + return self.body + +# GZIP + + +def compress(body, compress_level): + """Compress 'body' at the given compress_level.""" + import zlib + + # See http://www.gzip.org/zlib/rfc-gzip.html + yield ntob('\x1f\x8b') # ID1 and ID2: gzip marker + yield ntob('\x08') # CM: compression method + yield ntob('\x00') # FLG: none set + # MTIME: 4 bytes + yield struct.pack(" 0 is present + * The 'identity' value is given with a qvalue > 0. + + """ + request = cherrypy.serving.request + response = cherrypy.serving.response + + set_vary_header(response, "Accept-Encoding") + + if not response.body: + # Response body is empty (might be a 304 for instance) + if debug: + cherrypy.log('No response body', context='TOOLS.GZIP') + return + + # If returning cached content (which should already have been gzipped), + # don't re-zip. + if getattr(request, "cached", False): + if debug: + cherrypy.log('Not gzipping cached response', context='TOOLS.GZIP') + return + + acceptable = request.headers.elements('Accept-Encoding') + if not acceptable: + # If no Accept-Encoding field is present in a request, + # the server MAY assume that the client will accept any + # content coding. In this case, if "identity" is one of + # the available content-codings, then the server SHOULD use + # the "identity" content-coding, unless it has additional + # information that a different content-coding is meaningful + # to the client. + if debug: + cherrypy.log('No Accept-Encoding', context='TOOLS.GZIP') + return + + ct = response.headers.get('Content-Type', '').split(';')[0] + for coding in acceptable: + if coding.value == 'identity' and coding.qvalue != 0: + if debug: + cherrypy.log('Non-zero identity qvalue: %s' % coding, + context='TOOLS.GZIP') + return + if coding.value in ('gzip', 'x-gzip'): + if coding.qvalue == 0: + if debug: + cherrypy.log('Zero gzip qvalue: %s' % coding, + context='TOOLS.GZIP') + return + + if ct not in mime_types: + # If the list of provided mime-types contains tokens + # such as 'text/*' or 'application/*+xml', + # we go through them and find the most appropriate one + # based on the given content-type. + # The pattern matching is only caring about the most + # common cases, as stated above, and doesn't support + # for extra parameters. + found = False + if '/' in ct: + ct_media_type, ct_sub_type = ct.split('/') + for mime_type in mime_types: + if '/' in mime_type: + media_type, sub_type = mime_type.split('/') + if ct_media_type == media_type: + if sub_type == '*': + found = True + break + elif '+' in sub_type and '+' in ct_sub_type: + ct_left, ct_right = ct_sub_type.split('+') + left, right = sub_type.split('+') + if left == '*' and ct_right == right: + found = True + break + + if not found: + if debug: + cherrypy.log('Content-Type %s not in mime_types %r' % + (ct, mime_types), context='TOOLS.GZIP') + return + + if debug: + cherrypy.log('Gzipping', context='TOOLS.GZIP') + # Return a generator that compresses the page + response.headers['Content-Encoding'] = 'gzip' + response.body = compress(response.body, compress_level) + if "Content-Length" in response.headers: + # Delete Content-Length header so finalize() recalcs it. + del response.headers["Content-Length"] + + return + + if debug: + cherrypy.log('No acceptable encoding found.', context='GZIP') + cherrypy.HTTPError(406, "identity, gzip").set_response() diff --git a/lib/cherrypy/lib/gctools.py b/lib/cherrypy/lib/gctools.py new file mode 100644 index 00000000..4b616c59 --- /dev/null +++ b/lib/cherrypy/lib/gctools.py @@ -0,0 +1,217 @@ +import gc +import inspect +import os +import sys +import time + +try: + import objgraph +except ImportError: + objgraph = None + +import cherrypy +from cherrypy import _cprequest, _cpwsgi +from cherrypy.process.plugins import SimplePlugin + + +class ReferrerTree(object): + + """An object which gathers all referrers of an object to a given depth.""" + + peek_length = 40 + + def __init__(self, ignore=None, maxdepth=2, maxparents=10): + self.ignore = ignore or [] + self.ignore.append(inspect.currentframe().f_back) + self.maxdepth = maxdepth + self.maxparents = maxparents + + def ascend(self, obj, depth=1): + """Return a nested list containing referrers of the given object.""" + depth += 1 + parents = [] + + # Gather all referrers in one step to minimize + # cascading references due to repr() logic. + refs = gc.get_referrers(obj) + self.ignore.append(refs) + if len(refs) > self.maxparents: + return [("[%s referrers]" % len(refs), [])] + + try: + ascendcode = self.ascend.__code__ + except AttributeError: + ascendcode = self.ascend.im_func.func_code + for parent in refs: + if inspect.isframe(parent) and parent.f_code is ascendcode: + continue + if parent in self.ignore: + continue + if depth <= self.maxdepth: + parents.append((parent, self.ascend(parent, depth))) + else: + parents.append((parent, [])) + + return parents + + def peek(self, s): + """Return s, restricted to a sane length.""" + if len(s) > (self.peek_length + 3): + half = self.peek_length // 2 + return s[:half] + '...' + s[-half:] + else: + return s + + def _format(self, obj, descend=True): + """Return a string representation of a single object.""" + if inspect.isframe(obj): + filename, lineno, func, context, index = inspect.getframeinfo(obj) + return "" % func + + if not descend: + return self.peek(repr(obj)) + + if isinstance(obj, dict): + return "{" + ", ".join(["%s: %s" % (self._format(k, descend=False), + self._format(v, descend=False)) + for k, v in obj.items()]) + "}" + elif isinstance(obj, list): + return "[" + ", ".join([self._format(item, descend=False) + for item in obj]) + "]" + elif isinstance(obj, tuple): + return "(" + ", ".join([self._format(item, descend=False) + for item in obj]) + ")" + + r = self.peek(repr(obj)) + if isinstance(obj, (str, int, float)): + return r + return "%s: %s" % (type(obj), r) + + def format(self, tree): + """Return a list of string reprs from a nested list of referrers.""" + output = [] + + def ascend(branch, depth=1): + for parent, grandparents in branch: + output.append((" " * depth) + self._format(parent)) + if grandparents: + ascend(grandparents, depth + 1) + ascend(tree) + return output + + +def get_instances(cls): + return [x for x in gc.get_objects() if isinstance(x, cls)] + + +class RequestCounter(SimplePlugin): + + def start(self): + self.count = 0 + + def before_request(self): + self.count += 1 + + def after_request(self): + self.count -= 1 +request_counter = RequestCounter(cherrypy.engine) +request_counter.subscribe() + + +def get_context(obj): + if isinstance(obj, _cprequest.Request): + return "path=%s;stage=%s" % (obj.path_info, obj.stage) + elif isinstance(obj, _cprequest.Response): + return "status=%s" % obj.status + elif isinstance(obj, _cpwsgi.AppResponse): + return "PATH_INFO=%s" % obj.environ.get('PATH_INFO', '') + elif hasattr(obj, "tb_lineno"): + return "tb_lineno=%s" % obj.tb_lineno + return "" + + +class GCRoot(object): + + """A CherryPy page handler for testing reference leaks.""" + + classes = [ + (_cprequest.Request, 2, 2, + "Should be 1 in this request thread and 1 in the main thread."), + (_cprequest.Response, 2, 2, + "Should be 1 in this request thread and 1 in the main thread."), + (_cpwsgi.AppResponse, 1, 1, + "Should be 1 in this request thread only."), + ] + + def index(self): + return "Hello, world!" + index.exposed = True + + def stats(self): + output = ["Statistics:"] + + for trial in range(10): + if request_counter.count > 0: + break + time.sleep(0.5) + else: + output.append("\nNot all requests closed properly.") + + # gc_collect isn't perfectly synchronous, because it may + # break reference cycles that then take time to fully + # finalize. Call it thrice and hope for the best. + gc.collect() + gc.collect() + unreachable = gc.collect() + if unreachable: + if objgraph is not None: + final = objgraph.by_type('Nondestructible') + if final: + objgraph.show_backrefs(final, filename='finalizers.png') + + trash = {} + for x in gc.garbage: + trash[type(x)] = trash.get(type(x), 0) + 1 + if trash: + output.insert(0, "\n%s unreachable objects:" % unreachable) + trash = [(v, k) for k, v in trash.items()] + trash.sort() + for pair in trash: + output.append(" " + repr(pair)) + + # Check declared classes to verify uncollected instances. + # These don't have to be part of a cycle; they can be + # any objects that have unanticipated referrers that keep + # them from being collected. + allobjs = {} + for cls, minobj, maxobj, msg in self.classes: + allobjs[cls] = get_instances(cls) + + for cls, minobj, maxobj, msg in self.classes: + objs = allobjs[cls] + lenobj = len(objs) + if lenobj < minobj or lenobj > maxobj: + if minobj == maxobj: + output.append( + "\nExpected %s %r references, got %s." % + (minobj, cls, lenobj)) + else: + output.append( + "\nExpected %s to %s %r references, got %s." % + (minobj, maxobj, cls, lenobj)) + + for obj in objs: + if objgraph is not None: + ig = [id(objs), id(inspect.currentframe())] + fname = "graph_%s_%s.png" % (cls.__name__, id(obj)) + objgraph.show_backrefs( + obj, extra_ignore=ig, max_depth=4, too_many=20, + filename=fname, extra_info=get_context) + output.append("\nReferrers for %s (refcount=%s):" % + (repr(obj), sys.getrefcount(obj))) + t = ReferrerTree(ignore=[objs], maxdepth=3) + tree = t.ascend(obj) + output.extend(t.format(tree)) + + return "\n".join(output) + stats.exposed = True diff --git a/lib/cherrypy/lib/http.py b/lib/cherrypy/lib/http.py new file mode 100644 index 00000000..12043ad1 --- /dev/null +++ b/lib/cherrypy/lib/http.py @@ -0,0 +1,6 @@ +import warnings +warnings.warn('cherrypy.lib.http has been deprecated and will be removed ' + 'in CherryPy 3.3 use cherrypy.lib.httputil instead.', + DeprecationWarning) + +from cherrypy.lib.httputil import * diff --git a/lib/cherrypy/lib/httpauth.py b/lib/cherrypy/lib/httpauth.py new file mode 100644 index 00000000..0897ea2a --- /dev/null +++ b/lib/cherrypy/lib/httpauth.py @@ -0,0 +1,371 @@ +""" +This module defines functions to implement HTTP Digest Authentication +(:rfc:`2617`). +This has full compliance with 'Digest' and 'Basic' authentication methods. In +'Digest' it supports both MD5 and MD5-sess algorithms. + +Usage: + First use 'doAuth' to request the client authentication for a + certain resource. You should send an httplib.UNAUTHORIZED response to the + client so he knows he has to authenticate itself. + + Then use 'parseAuthorization' to retrieve the 'auth_map' used in + 'checkResponse'. + + To use 'checkResponse' you must have already verified the password + associated with the 'username' key in 'auth_map' dict. Then you use the + 'checkResponse' function to verify if the password matches the one sent + by the client. + +SUPPORTED_ALGORITHM - list of supported 'Digest' algorithms +SUPPORTED_QOP - list of supported 'Digest' 'qop'. +""" +__version__ = 1, 0, 1 +__author__ = "Tiago Cogumbreiro " +__credits__ = """ + Peter van Kampen for its recipe which implement most of Digest + authentication: + http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/302378 +""" + +__license__ = """ +Copyright (c) 2005, Tiago Cogumbreiro +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of Sylvain Hellegouarch nor the names of his + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +__all__ = ("digestAuth", "basicAuth", "doAuth", "checkResponse", + "parseAuthorization", "SUPPORTED_ALGORITHM", "md5SessionKey", + "calculateNonce", "SUPPORTED_QOP") + +########################################################################## +import time +from cherrypy._cpcompat import base64_decode, ntob, md5 +from cherrypy._cpcompat import parse_http_list, parse_keqv_list + +MD5 = "MD5" +MD5_SESS = "MD5-sess" +AUTH = "auth" +AUTH_INT = "auth-int" + +SUPPORTED_ALGORITHM = (MD5, MD5_SESS) +SUPPORTED_QOP = (AUTH, AUTH_INT) + +########################################################################## +# doAuth +# +DIGEST_AUTH_ENCODERS = { + MD5: lambda val: md5(ntob(val)).hexdigest(), + MD5_SESS: lambda val: md5(ntob(val)).hexdigest(), + # SHA: lambda val: sha.new(ntob(val)).hexdigest (), +} + + +def calculateNonce(realm, algorithm=MD5): + """This is an auxaliary function that calculates 'nonce' value. It is used + to handle sessions.""" + + global SUPPORTED_ALGORITHM, DIGEST_AUTH_ENCODERS + assert algorithm in SUPPORTED_ALGORITHM + + try: + encoder = DIGEST_AUTH_ENCODERS[algorithm] + except KeyError: + raise NotImplementedError("The chosen algorithm (%s) does not have " + "an implementation yet" % algorithm) + + return encoder("%d:%s" % (time.time(), realm)) + + +def digestAuth(realm, algorithm=MD5, nonce=None, qop=AUTH): + """Challenges the client for a Digest authentication.""" + global SUPPORTED_ALGORITHM, DIGEST_AUTH_ENCODERS, SUPPORTED_QOP + assert algorithm in SUPPORTED_ALGORITHM + assert qop in SUPPORTED_QOP + + if nonce is None: + nonce = calculateNonce(realm, algorithm) + + return 'Digest realm="%s", nonce="%s", algorithm="%s", qop="%s"' % ( + realm, nonce, algorithm, qop + ) + + +def basicAuth(realm): + """Challengenes the client for a Basic authentication.""" + assert '"' not in realm, "Realms cannot contain the \" (quote) character." + + return 'Basic realm="%s"' % realm + + +def doAuth(realm): + """'doAuth' function returns the challenge string b giving priority over + Digest and fallback to Basic authentication when the browser doesn't + support the first one. + + This should be set in the HTTP header under the key 'WWW-Authenticate'.""" + + return digestAuth(realm) + " " + basicAuth(realm) + + +########################################################################## +# Parse authorization parameters +# +def _parseDigestAuthorization(auth_params): + # Convert the auth params to a dict + items = parse_http_list(auth_params) + params = parse_keqv_list(items) + + # Now validate the params + + # Check for required parameters + required = ["username", "realm", "nonce", "uri", "response"] + for k in required: + if k not in params: + return None + + # If qop is sent then cnonce and nc MUST be present + if "qop" in params and not ("cnonce" in params + and "nc" in params): + return None + + # If qop is not sent, neither cnonce nor nc can be present + if ("cnonce" in params or "nc" in params) and \ + "qop" not in params: + return None + + return params + + +def _parseBasicAuthorization(auth_params): + username, password = base64_decode(auth_params).split(":", 1) + return {"username": username, "password": password} + +AUTH_SCHEMES = { + "basic": _parseBasicAuthorization, + "digest": _parseDigestAuthorization, +} + + +def parseAuthorization(credentials): + """parseAuthorization will convert the value of the 'Authorization' key in + the HTTP header to a map itself. If the parsing fails 'None' is returned. + """ + + global AUTH_SCHEMES + + auth_scheme, auth_params = credentials.split(" ", 1) + auth_scheme = auth_scheme.lower() + + parser = AUTH_SCHEMES[auth_scheme] + params = parser(auth_params) + + if params is None: + return + + assert "auth_scheme" not in params + params["auth_scheme"] = auth_scheme + return params + + +########################################################################## +# Check provided response for a valid password +# +def md5SessionKey(params, password): + """ + If the "algorithm" directive's value is "MD5-sess", then A1 + [the session key] is calculated only once - on the first request by the + client following receipt of a WWW-Authenticate challenge from the server. + + This creates a 'session key' for the authentication of subsequent + requests and responses which is different for each "authentication + session", thus limiting the amount of material hashed with any one + key. + + Because the server need only use the hash of the user + credentials in order to create the A1 value, this construction could + be used in conjunction with a third party authentication service so + that the web server would not need the actual password value. The + specification of such a protocol is beyond the scope of this + specification. +""" + + keys = ("username", "realm", "nonce", "cnonce") + params_copy = {} + for key in keys: + params_copy[key] = params[key] + + params_copy["algorithm"] = MD5_SESS + return _A1(params_copy, password) + + +def _A1(params, password): + algorithm = params.get("algorithm", MD5) + H = DIGEST_AUTH_ENCODERS[algorithm] + + if algorithm == MD5: + # If the "algorithm" directive's value is "MD5" or is + # unspecified, then A1 is: + # A1 = unq(username-value) ":" unq(realm-value) ":" passwd + return "%s:%s:%s" % (params["username"], params["realm"], password) + + elif algorithm == MD5_SESS: + + # This is A1 if qop is set + # A1 = H( unq(username-value) ":" unq(realm-value) ":" passwd ) + # ":" unq(nonce-value) ":" unq(cnonce-value) + h_a1 = H("%s:%s:%s" % (params["username"], params["realm"], password)) + return "%s:%s:%s" % (h_a1, params["nonce"], params["cnonce"]) + + +def _A2(params, method, kwargs): + # If the "qop" directive's value is "auth" or is unspecified, then A2 is: + # A2 = Method ":" digest-uri-value + + qop = params.get("qop", "auth") + if qop == "auth": + return method + ":" + params["uri"] + elif qop == "auth-int": + # If the "qop" value is "auth-int", then A2 is: + # A2 = Method ":" digest-uri-value ":" H(entity-body) + entity_body = kwargs.get("entity_body", "") + H = kwargs["H"] + + return "%s:%s:%s" % ( + method, + params["uri"], + H(entity_body) + ) + + else: + raise NotImplementedError("The 'qop' method is unknown: %s" % qop) + + +def _computeDigestResponse(auth_map, password, method="GET", A1=None, + **kwargs): + """ + Generates a response respecting the algorithm defined in RFC 2617 + """ + params = auth_map + + algorithm = params.get("algorithm", MD5) + + H = DIGEST_AUTH_ENCODERS[algorithm] + KD = lambda secret, data: H(secret + ":" + data) + + qop = params.get("qop", None) + + H_A2 = H(_A2(params, method, kwargs)) + + if algorithm == MD5_SESS and A1 is not None: + H_A1 = H(A1) + else: + H_A1 = H(_A1(params, password)) + + if qop in ("auth", "auth-int"): + # If the "qop" value is "auth" or "auth-int": + # request-digest = <"> < KD ( H(A1), unq(nonce-value) + # ":" nc-value + # ":" unq(cnonce-value) + # ":" unq(qop-value) + # ":" H(A2) + # ) <"> + request = "%s:%s:%s:%s:%s" % ( + params["nonce"], + params["nc"], + params["cnonce"], + params["qop"], + H_A2, + ) + elif qop is None: + # If the "qop" directive is not present (this construction is + # for compatibility with RFC 2069): + # request-digest = + # <"> < KD ( H(A1), unq(nonce-value) ":" H(A2) ) > <"> + request = "%s:%s" % (params["nonce"], H_A2) + + return KD(H_A1, request) + + +def _checkDigestResponse(auth_map, password, method="GET", A1=None, **kwargs): + """This function is used to verify the response given by the client when + he tries to authenticate. + Optional arguments: + entity_body - when 'qop' is set to 'auth-int' you MUST provide the + raw data you are going to send to the client (usually the + HTML page. + request_uri - the uri from the request line compared with the 'uri' + directive of the authorization map. They must represent + the same resource (unused at this time). + """ + + if auth_map['realm'] != kwargs.get('realm', None): + return False + + response = _computeDigestResponse( + auth_map, password, method, A1, **kwargs) + + return response == auth_map["response"] + + +def _checkBasicResponse(auth_map, password, method='GET', encrypt=None, + **kwargs): + # Note that the Basic response doesn't provide the realm value so we cannot + # test it + pass_through = lambda password, username=None: password + encrypt = encrypt or pass_through + try: + candidate = encrypt(auth_map["password"], auth_map["username"]) + except TypeError: + # if encrypt only takes one parameter, it's the password + candidate = encrypt(auth_map["password"]) + return candidate == password + +AUTH_RESPONSES = { + "basic": _checkBasicResponse, + "digest": _checkDigestResponse, +} + + +def checkResponse(auth_map, password, method="GET", encrypt=None, **kwargs): + """'checkResponse' compares the auth_map with the password and optionally + other arguments that each implementation might need. + + If the response is of type 'Basic' then the function has the following + signature:: + + checkBasicResponse(auth_map, password) -> bool + + If the response is of type 'Digest' then the function has the following + signature:: + + checkDigestResponse(auth_map, password, method='GET', A1=None) -> bool + + The 'A1' argument is only used in MD5_SESS algorithm based responses. + Check md5SessionKey() for more info. + """ + checker = AUTH_RESPONSES[auth_map["auth_scheme"]] + return checker(auth_map, password, method=method, encrypt=encrypt, + **kwargs) diff --git a/lib/cherrypy/lib/httputil.py b/lib/cherrypy/lib/httputil.py new file mode 100644 index 00000000..69a18d45 --- /dev/null +++ b/lib/cherrypy/lib/httputil.py @@ -0,0 +1,536 @@ +"""HTTP library functions. + +This module contains functions for building an HTTP application +framework: any one, not just one whose name starts with "Ch". ;) If you +reference any modules from some popular framework inside *this* module, +FuManChu will personally hang you up by your thumbs and submit you +to a public caning. +""" + +from binascii import b2a_base64 +from cherrypy._cpcompat import BaseHTTPRequestHandler, HTTPDate, ntob, ntou +from cherrypy._cpcompat import basestring, bytestr, iteritems, nativestr +from cherrypy._cpcompat import reversed, sorted, unicodestr, unquote_qs +response_codes = BaseHTTPRequestHandler.responses.copy() + +# From https://bitbucket.org/cherrypy/cherrypy/issue/361 +response_codes[500] = ('Internal Server Error', + 'The server encountered an unexpected condition ' + 'which prevented it from fulfilling the request.') +response_codes[503] = ('Service Unavailable', + 'The server is currently unable to handle the ' + 'request due to a temporary overloading or ' + 'maintenance of the server.') + +import re +import urllib + + +def urljoin(*atoms): + """Return the given path \*atoms, joined into a single URL. + + This will correctly join a SCRIPT_NAME and PATH_INFO into the + original URL, even if either atom is blank. + """ + url = "/".join([x for x in atoms if x]) + while "//" in url: + url = url.replace("//", "/") + # Special-case the final url of "", and return "/" instead. + return url or "/" + + +def urljoin_bytes(*atoms): + """Return the given path *atoms, joined into a single URL. + + This will correctly join a SCRIPT_NAME and PATH_INFO into the + original URL, even if either atom is blank. + """ + url = ntob("/").join([x for x in atoms if x]) + while ntob("//") in url: + url = url.replace(ntob("//"), ntob("/")) + # Special-case the final url of "", and return "/" instead. + return url or ntob("/") + + +def protocol_from_http(protocol_str): + """Return a protocol tuple from the given 'HTTP/x.y' string.""" + return int(protocol_str[5]), int(protocol_str[7]) + + +def get_ranges(headervalue, content_length): + """Return a list of (start, stop) indices from a Range header, or None. + + Each (start, stop) tuple will be composed of two ints, which are suitable + for use in a slicing operation. That is, the header "Range: bytes=3-6", + if applied against a Python string, is requesting resource[3:7]. This + function will return the list [(3, 7)]. + + If this function returns an empty list, you should return HTTP 416. + """ + + if not headervalue: + return None + + result = [] + bytesunit, byteranges = headervalue.split("=", 1) + for brange in byteranges.split(","): + start, stop = [x.strip() for x in brange.split("-", 1)] + if start: + if not stop: + stop = content_length - 1 + start, stop = int(start), int(stop) + if start >= content_length: + # From rfc 2616 sec 14.16: + # "If the server receives a request (other than one + # including an If-Range request-header field) with an + # unsatisfiable Range request-header field (that is, + # all of whose byte-range-spec values have a first-byte-pos + # value greater than the current length of the selected + # resource), it SHOULD return a response code of 416 + # (Requested range not satisfiable)." + continue + if stop < start: + # From rfc 2616 sec 14.16: + # "If the server ignores a byte-range-spec because it + # is syntactically invalid, the server SHOULD treat + # the request as if the invalid Range header field + # did not exist. (Normally, this means return a 200 + # response containing the full entity)." + return None + result.append((start, stop + 1)) + else: + if not stop: + # See rfc quote above. + return None + # Negative subscript (last N bytes) + # + # RFC 2616 Section 14.35.1: + # If the entity is shorter than the specified suffix-length, + # the entire entity-body is used. + if int(stop) > content_length: + result.append((0, content_length)) + else: + result.append((content_length - int(stop), content_length)) + + return result + + +class HeaderElement(object): + + """An element (with parameters) from an HTTP header's element list.""" + + def __init__(self, value, params=None): + self.value = value + if params is None: + params = {} + self.params = params + + def __cmp__(self, other): + return cmp(self.value, other.value) + + def __lt__(self, other): + return self.value < other.value + + def __str__(self): + p = [";%s=%s" % (k, v) for k, v in iteritems(self.params)] + return str("%s%s" % (self.value, "".join(p))) + + def __bytes__(self): + return ntob(self.__str__()) + + def __unicode__(self): + return ntou(self.__str__()) + + def parse(elementstr): + """Transform 'token;key=val' to ('token', {'key': 'val'}).""" + # Split the element into a value and parameters. The 'value' may + # be of the form, "token=token", but we don't split that here. + atoms = [x.strip() for x in elementstr.split(";") if x.strip()] + if not atoms: + initial_value = '' + else: + initial_value = atoms.pop(0).strip() + params = {} + for atom in atoms: + atom = [x.strip() for x in atom.split("=", 1) if x.strip()] + key = atom.pop(0) + if atom: + val = atom[0] + else: + val = "" + params[key] = val + return initial_value, params + parse = staticmethod(parse) + + def from_str(cls, elementstr): + """Construct an instance from a string of the form 'token;key=val'.""" + ival, params = cls.parse(elementstr) + return cls(ival, params) + from_str = classmethod(from_str) + + +q_separator = re.compile(r'; *q *=') + + +class AcceptElement(HeaderElement): + + """An element (with parameters) from an Accept* header's element list. + + AcceptElement objects are comparable; the more-preferred object will be + "less than" the less-preferred object. They are also therefore sortable; + if you sort a list of AcceptElement objects, they will be listed in + priority order; the most preferred value will be first. Yes, it should + have been the other way around, but it's too late to fix now. + """ + + def from_str(cls, elementstr): + qvalue = None + # The first "q" parameter (if any) separates the initial + # media-range parameter(s) (if any) from the accept-params. + atoms = q_separator.split(elementstr, 1) + media_range = atoms.pop(0).strip() + if atoms: + # The qvalue for an Accept header can have extensions. The other + # headers cannot, but it's easier to parse them as if they did. + qvalue = HeaderElement.from_str(atoms[0].strip()) + + media_type, params = cls.parse(media_range) + if qvalue is not None: + params["q"] = qvalue + return cls(media_type, params) + from_str = classmethod(from_str) + + def qvalue(self): + val = self.params.get("q", "1") + if isinstance(val, HeaderElement): + val = val.value + return float(val) + qvalue = property(qvalue, doc="The qvalue, or priority, of this value.") + + def __cmp__(self, other): + diff = cmp(self.qvalue, other.qvalue) + if diff == 0: + diff = cmp(str(self), str(other)) + return diff + + def __lt__(self, other): + if self.qvalue == other.qvalue: + return str(self) < str(other) + else: + return self.qvalue < other.qvalue + +RE_HEADER_SPLIT = re.compile(',(?=(?:[^"]*"[^"]*")*[^"]*$)') +def header_elements(fieldname, fieldvalue): + """Return a sorted HeaderElement list from a comma-separated header string. + """ + if not fieldvalue: + return [] + + result = [] + for element in RE_HEADER_SPLIT.split(fieldvalue): + if fieldname.startswith("Accept") or fieldname == 'TE': + hv = AcceptElement.from_str(element) + else: + hv = HeaderElement.from_str(element) + result.append(hv) + + return list(reversed(sorted(result))) + + +def decode_TEXT(value): + r"""Decode :rfc:`2047` TEXT (e.g. "=?utf-8?q?f=C3=BCr?=" -> "f\xfcr").""" + try: + # Python 3 + from email.header import decode_header + except ImportError: + from email.Header import decode_header + atoms = decode_header(value) + decodedvalue = "" + for atom, charset in atoms: + if charset is not None: + atom = atom.decode(charset) + decodedvalue += atom + return decodedvalue + + +def valid_status(status): + """Return legal HTTP status Code, Reason-phrase and Message. + + The status arg must be an int, or a str that begins with an int. + + If status is an int, or a str and no reason-phrase is supplied, + a default reason-phrase will be provided. + """ + + if not status: + status = 200 + + status = str(status) + parts = status.split(" ", 1) + if len(parts) == 1: + # No reason supplied. + code, = parts + reason = None + else: + code, reason = parts + reason = reason.strip() + + try: + code = int(code) + except ValueError: + raise ValueError("Illegal response status from server " + "(%s is non-numeric)." % repr(code)) + + if code < 100 or code > 599: + raise ValueError("Illegal response status from server " + "(%s is out of range)." % repr(code)) + + if code not in response_codes: + # code is unknown but not illegal + default_reason, message = "", "" + else: + default_reason, message = response_codes[code] + + if reason is None: + reason = default_reason + + return code, reason, message + + +# NOTE: the parse_qs functions that follow are modified version of those +# in the python3.0 source - we need to pass through an encoding to the unquote +# method, but the default parse_qs function doesn't allow us to. These do. + +def _parse_qs(qs, keep_blank_values=0, strict_parsing=0, encoding='utf-8'): + """Parse a query given as a string argument. + + Arguments: + + qs: URL-encoded query string to be parsed + + keep_blank_values: flag indicating whether blank values in + URL encoded queries should be treated as blank strings. A + true value indicates that blanks should be retained as blank + strings. The default false value indicates that blank values + are to be ignored and treated as if they were not included. + + strict_parsing: flag indicating what to do with parsing errors. If + false (the default), errors are silently ignored. If true, + errors raise a ValueError exception. + + Returns a dict, as G-d intended. + """ + pairs = [s2 for s1 in qs.split('&') for s2 in s1.split(';')] + d = {} + for name_value in pairs: + if not name_value and not strict_parsing: + continue + nv = name_value.split('=', 1) + if len(nv) != 2: + if strict_parsing: + raise ValueError("bad query field: %r" % (name_value,)) + # Handle case of a control-name with no equal sign + if keep_blank_values: + nv.append('') + else: + continue + if len(nv[1]) or keep_blank_values: + name = unquote_qs(nv[0], encoding) + value = unquote_qs(nv[1], encoding) + if name in d: + if not isinstance(d[name], list): + d[name] = [d[name]] + d[name].append(value) + else: + d[name] = value + return d + + +image_map_pattern = re.compile(r"[0-9]+,[0-9]+") + + +def parse_query_string(query_string, keep_blank_values=True, encoding='utf-8'): + """Build a params dictionary from a query_string. + + Duplicate key/value pairs in the provided query_string will be + returned as {'key': [val1, val2, ...]}. Single key/values will + be returned as strings: {'key': 'value'}. + """ + if image_map_pattern.match(query_string): + # Server-side image map. Map the coords to 'x' and 'y' + # (like CGI::Request does). + pm = query_string.split(",") + pm = {'x': int(pm[0]), 'y': int(pm[1])} + else: + pm = _parse_qs(query_string, keep_blank_values, encoding=encoding) + return pm + + +class CaseInsensitiveDict(dict): + + """A case-insensitive dict subclass. + + Each key is changed on entry to str(key).title(). + """ + + def __getitem__(self, key): + return dict.__getitem__(self, str(key).title()) + + def __setitem__(self, key, value): + dict.__setitem__(self, str(key).title(), value) + + def __delitem__(self, key): + dict.__delitem__(self, str(key).title()) + + def __contains__(self, key): + return dict.__contains__(self, str(key).title()) + + def get(self, key, default=None): + return dict.get(self, str(key).title(), default) + + if hasattr({}, 'has_key'): + def has_key(self, key): + return str(key).title() in self + + def update(self, E): + for k in E.keys(): + self[str(k).title()] = E[k] + + def fromkeys(cls, seq, value=None): + newdict = cls() + for k in seq: + newdict[str(k).title()] = value + return newdict + fromkeys = classmethod(fromkeys) + + def setdefault(self, key, x=None): + key = str(key).title() + try: + return self[key] + except KeyError: + self[key] = x + return x + + def pop(self, key, default): + return dict.pop(self, str(key).title(), default) + + +# TEXT = +# +# A CRLF is allowed in the definition of TEXT only as part of a header +# field continuation. It is expected that the folding LWS will be +# replaced with a single SP before interpretation of the TEXT value." +if nativestr == bytestr: + header_translate_table = ''.join([chr(i) for i in xrange(256)]) + header_translate_deletechars = ''.join( + [chr(i) for i in xrange(32)]) + chr(127) +else: + header_translate_table = None + header_translate_deletechars = bytes(range(32)) + bytes([127]) + + +class HeaderMap(CaseInsensitiveDict): + + """A dict subclass for HTTP request and response headers. + + Each key is changed on entry to str(key).title(). This allows headers + to be case-insensitive and avoid duplicates. + + Values are header values (decoded according to :rfc:`2047` if necessary). + """ + + protocol = (1, 1) + encodings = ["ISO-8859-1"] + + # Someday, when http-bis is done, this will probably get dropped + # since few servers, clients, or intermediaries do it. But until then, + # we're going to obey the spec as is. + # "Words of *TEXT MAY contain characters from character sets other than + # ISO-8859-1 only when encoded according to the rules of RFC 2047." + use_rfc_2047 = True + + def elements(self, key): + """Return a sorted list of HeaderElements for the given header.""" + key = str(key).title() + value = self.get(key) + return header_elements(key, value) + + def values(self, key): + """Return a sorted list of HeaderElement.value for the given header.""" + return [e.value for e in self.elements(key)] + + def output(self): + """Transform self into a list of (name, value) tuples.""" + return list(self.encode_header_items(self.items())) + + def encode_header_items(cls, header_items): + """ + Prepare the sequence of name, value tuples into a form suitable for + transmitting on the wire for HTTP. + """ + for k, v in header_items: + if isinstance(k, unicodestr): + k = cls.encode(k) + + if not isinstance(v, basestring): + v = str(v) + + if isinstance(v, unicodestr): + v = cls.encode(v) + + # See header_translate_* constants above. + # Replace only if you really know what you're doing. + k = k.translate(header_translate_table, + header_translate_deletechars) + v = v.translate(header_translate_table, + header_translate_deletechars) + + yield (k, v) + encode_header_items = classmethod(encode_header_items) + + def encode(cls, v): + """Return the given header name or value, encoded for HTTP output.""" + for enc in cls.encodings: + try: + return v.encode(enc) + except UnicodeEncodeError: + continue + + if cls.protocol == (1, 1) and cls.use_rfc_2047: + # Encode RFC-2047 TEXT + # (e.g. u"\u8200" -> "=?utf-8?b?6IiA?="). + # We do our own here instead of using the email module + # because we never want to fold lines--folding has + # been deprecated by the HTTP working group. + v = b2a_base64(v.encode('utf-8')) + return (ntob('=?utf-8?b?') + v.strip(ntob('\n')) + ntob('?=')) + + raise ValueError("Could not encode header part %r using " + "any of the encodings %r." % + (v, cls.encodings)) + encode = classmethod(encode) + + +class Host(object): + + """An internet address. + + name + Should be the client's host name. If not available (because no DNS + lookup is performed), the IP address should be used instead. + + """ + + ip = "0.0.0.0" + port = 80 + name = "unknown.tld" + + def __init__(self, ip, port, name=None): + self.ip = ip + self.port = port + if name is None: + name = ip + self.name = name + + def __repr__(self): + return "httputil.Host(%r, %r, %r)" % (self.ip, self.port, self.name) diff --git a/lib/cherrypy/lib/jsontools.py b/lib/cherrypy/lib/jsontools.py new file mode 100644 index 00000000..90b3ff8a --- /dev/null +++ b/lib/cherrypy/lib/jsontools.py @@ -0,0 +1,96 @@ +import cherrypy +from cherrypy._cpcompat import basestring, ntou, json_encode, json_decode + + +def json_processor(entity): + """Read application/json data into request.json.""" + if not entity.headers.get(ntou("Content-Length"), ntou("")): + raise cherrypy.HTTPError(411) + + body = entity.fp.read() + try: + cherrypy.serving.request.json = json_decode(body.decode('utf-8')) + except ValueError: + raise cherrypy.HTTPError(400, 'Invalid JSON document') + + +def json_in(content_type=[ntou('application/json'), ntou('text/javascript')], + force=True, debug=False, processor=json_processor): + """Add a processor to parse JSON request entities: + The default processor places the parsed data into request.json. + + Incoming request entities which match the given content_type(s) will + be deserialized from JSON to the Python equivalent, and the result + stored at cherrypy.request.json. The 'content_type' argument may + be a Content-Type string or a list of allowable Content-Type strings. + + If the 'force' argument is True (the default), then entities of other + content types will not be allowed; "415 Unsupported Media Type" is + raised instead. + + Supply your own processor to use a custom decoder, or to handle the parsed + data differently. The processor can be configured via + tools.json_in.processor or via the decorator method. + + Note that the deserializer requires the client send a Content-Length + request header, or it will raise "411 Length Required". If for any + other reason the request entity cannot be deserialized from JSON, + it will raise "400 Bad Request: Invalid JSON document". + + You must be using Python 2.6 or greater, or have the 'simplejson' + package importable; otherwise, ValueError is raised during processing. + """ + request = cherrypy.serving.request + if isinstance(content_type, basestring): + content_type = [content_type] + + if force: + if debug: + cherrypy.log('Removing body processors %s' % + repr(request.body.processors.keys()), 'TOOLS.JSON_IN') + request.body.processors.clear() + request.body.default_proc = cherrypy.HTTPError( + 415, 'Expected an entity of content type %s' % + ', '.join(content_type)) + + for ct in content_type: + if debug: + cherrypy.log('Adding body processor for %s' % ct, 'TOOLS.JSON_IN') + request.body.processors[ct] = processor + + +def json_handler(*args, **kwargs): + value = cherrypy.serving.request._json_inner_handler(*args, **kwargs) + return json_encode(value) + + +def json_out(content_type='application/json', debug=False, + handler=json_handler): + """Wrap request.handler to serialize its output to JSON. Sets Content-Type. + + If the given content_type is None, the Content-Type response header + is not set. + + Provide your own handler to use a custom encoder. For example + cherrypy.config['tools.json_out.handler'] = , or + @json_out(handler=function). + + You must be using Python 2.6 or greater, or have the 'simplejson' + package importable; otherwise, ValueError is raised during processing. + """ + request = cherrypy.serving.request + # request.handler may be set to None by e.g. the caching tool + # to signal to all components that a response body has already + # been attached, in which case we don't need to wrap anything. + if request.handler is None: + return + if debug: + cherrypy.log('Replacing %s with JSON handler' % request.handler, + 'TOOLS.JSON_OUT') + request._json_inner_handler = request.handler + request.handler = handler + if content_type is not None: + if debug: + cherrypy.log('Setting Content-Type to %s' % + content_type, 'TOOLS.JSON_OUT') + cherrypy.serving.response.headers['Content-Type'] = content_type diff --git a/lib/cherrypy/lib/lockfile.py b/lib/cherrypy/lib/lockfile.py new file mode 100644 index 00000000..4cf7b1b6 --- /dev/null +++ b/lib/cherrypy/lib/lockfile.py @@ -0,0 +1,147 @@ +""" +Platform-independent file locking. Inspired by and modeled after zc.lockfile. +""" + +import os + +try: + import msvcrt +except ImportError: + pass + +try: + import fcntl +except ImportError: + pass + + +class LockError(Exception): + + "Could not obtain a lock" + + msg = "Unable to lock %r" + + def __init__(self, path): + super(LockError, self).__init__(self.msg % path) + + +class UnlockError(LockError): + + "Could not release a lock" + + msg = "Unable to unlock %r" + + +# first, a default, naive locking implementation +class LockFile(object): + + """ + A default, naive locking implementation. Always fails if the file + already exists. + """ + + def __init__(self, path): + self.path = path + try: + fd = os.open(path, os.O_CREAT | os.O_WRONLY | os.O_EXCL) + except OSError: + raise LockError(self.path) + os.close(fd) + + def release(self): + os.remove(self.path) + + def remove(self): + pass + + +class SystemLockFile(object): + + """ + An abstract base class for platform-specific locking. + """ + + def __init__(self, path): + self.path = path + + try: + # Open lockfile for writing without truncation: + self.fp = open(path, 'r+') + except IOError: + # If the file doesn't exist, IOError is raised; Use a+ instead. + # Note that there may be a race here. Multiple processes + # could fail on the r+ open and open the file a+, but only + # one will get the the lock and write a pid. + self.fp = open(path, 'a+') + + try: + self._lock_file() + except: + self.fp.seek(1) + self.fp.close() + del self.fp + raise + + self.fp.write(" %s\n" % os.getpid()) + self.fp.truncate() + self.fp.flush() + + def release(self): + if not hasattr(self, 'fp'): + return + self._unlock_file() + self.fp.close() + del self.fp + + def remove(self): + """ + Attempt to remove the file + """ + try: + os.remove(self.path) + except: + pass + + #@abc.abstract_method + # def _lock_file(self): + # """Attempt to obtain the lock on self.fp. Raise LockError if not + # acquired.""" + + def _unlock_file(self): + """Attempt to obtain the lock on self.fp. Raise UnlockError if not + released.""" + + +class WindowsLockFile(SystemLockFile): + + def _lock_file(self): + # Lock just the first byte + try: + msvcrt.locking(self.fp.fileno(), msvcrt.LK_NBLCK, 1) + except IOError: + raise LockError(self.fp.name) + + def _unlock_file(self): + try: + self.fp.seek(0) + msvcrt.locking(self.fp.fileno(), msvcrt.LK_UNLCK, 1) + except IOError: + raise UnlockError(self.fp.name) + +if 'msvcrt' in globals(): + LockFile = WindowsLockFile + + +class UnixLockFile(SystemLockFile): + + def _lock_file(self): + flags = fcntl.LOCK_EX | fcntl.LOCK_NB + try: + fcntl.flock(self.fp.fileno(), flags) + except IOError: + raise LockError(self.fp.name) + + # no need to implement _unlock_file, it will be unlocked on close() + +if 'fcntl' in globals(): + LockFile = UnixLockFile diff --git a/lib/cherrypy/lib/locking.py b/lib/cherrypy/lib/locking.py new file mode 100644 index 00000000..72dda9b3 --- /dev/null +++ b/lib/cherrypy/lib/locking.py @@ -0,0 +1,47 @@ +import datetime + + +class NeverExpires(object): + def expired(self): + return False + + +class Timer(object): + """ + A simple timer that will indicate when an expiration time has passed. + """ + def __init__(self, expiration): + "Create a timer that expires at `expiration` (UTC datetime)" + self.expiration = expiration + + @classmethod + def after(cls, elapsed): + """ + Return a timer that will expire after `elapsed` passes. + """ + return cls(datetime.datetime.utcnow() + elapsed) + + def expired(self): + return datetime.datetime.utcnow() >= self.expiration + + +class LockTimeout(Exception): + "An exception when a lock could not be acquired before a timeout period" + + +class LockChecker(object): + """ + Keep track of the time and detect if a timeout has expired + """ + def __init__(self, session_id, timeout): + self.session_id = session_id + if timeout: + self.timer = Timer.after(timeout) + else: + self.timer = NeverExpires() + + def expired(self): + if self.timer.expired(): + raise LockTimeout( + "Timeout acquiring lock for %(session_id)s" % vars(self)) + return False diff --git a/lib/cherrypy/lib/profiler.py b/lib/cherrypy/lib/profiler.py new file mode 100644 index 00000000..5dac386e --- /dev/null +++ b/lib/cherrypy/lib/profiler.py @@ -0,0 +1,216 @@ +"""Profiler tools for CherryPy. + +CherryPy users +============== + +You can profile any of your pages as follows:: + + from cherrypy.lib import profiler + + class Root: + p = profile.Profiler("/path/to/profile/dir") + + def index(self): + self.p.run(self._index) + index.exposed = True + + def _index(self): + return "Hello, world!" + + cherrypy.tree.mount(Root()) + +You can also turn on profiling for all requests +using the ``make_app`` function as WSGI middleware. + +CherryPy developers +=================== + +This module can be used whenever you make changes to CherryPy, +to get a quick sanity-check on overall CP performance. Use the +``--profile`` flag when running the test suite. Then, use the ``serve()`` +function to browse the results in a web browser. If you run this +module from the command line, it will call ``serve()`` for you. + +""" + + +def new_func_strip_path(func_name): + """Make profiler output more readable by adding `__init__` modules' parents + """ + filename, line, name = func_name + if filename.endswith("__init__.py"): + return os.path.basename(filename[:-12]) + filename[-12:], line, name + return os.path.basename(filename), line, name + +try: + import profile + import pstats + pstats.func_strip_path = new_func_strip_path +except ImportError: + profile = None + pstats = None + +import os +import os.path +import sys +import warnings + +from cherrypy._cpcompat import StringIO + +_count = 0 + + +class Profiler(object): + + def __init__(self, path=None): + if not path: + path = os.path.join(os.path.dirname(__file__), "profile") + self.path = path + if not os.path.exists(path): + os.makedirs(path) + + def run(self, func, *args, **params): + """Dump profile data into self.path.""" + global _count + c = _count = _count + 1 + path = os.path.join(self.path, "cp_%04d.prof" % c) + prof = profile.Profile() + result = prof.runcall(func, *args, **params) + prof.dump_stats(path) + return result + + def statfiles(self): + """:rtype: list of available profiles. + """ + return [f for f in os.listdir(self.path) + if f.startswith("cp_") and f.endswith(".prof")] + + def stats(self, filename, sortby='cumulative'): + """:rtype stats(index): output of print_stats() for the given profile. + """ + sio = StringIO() + if sys.version_info >= (2, 5): + s = pstats.Stats(os.path.join(self.path, filename), stream=sio) + s.strip_dirs() + s.sort_stats(sortby) + s.print_stats() + else: + # pstats.Stats before Python 2.5 didn't take a 'stream' arg, + # but just printed to stdout. So re-route stdout. + s = pstats.Stats(os.path.join(self.path, filename)) + s.strip_dirs() + s.sort_stats(sortby) + oldout = sys.stdout + try: + sys.stdout = sio + s.print_stats() + finally: + sys.stdout = oldout + response = sio.getvalue() + sio.close() + return response + + def index(self): + return """ + CherryPy profile data + + + + + + """ + index.exposed = True + + def menu(self): + yield "

Profiling runs

" + yield "

Click on one of the runs below to see profiling data.

" + runs = self.statfiles() + runs.sort() + for i in runs: + yield "%s
" % ( + i, i) + menu.exposed = True + + def report(self, filename): + import cherrypy + cherrypy.response.headers['Content-Type'] = 'text/plain' + return self.stats(filename) + report.exposed = True + + +class ProfileAggregator(Profiler): + + def __init__(self, path=None): + Profiler.__init__(self, path) + global _count + self.count = _count = _count + 1 + self.profiler = profile.Profile() + + def run(self, func, *args, **params): + path = os.path.join(self.path, "cp_%04d.prof" % self.count) + result = self.profiler.runcall(func, *args, **params) + self.profiler.dump_stats(path) + return result + + +class make_app: + + def __init__(self, nextapp, path=None, aggregate=False): + """Make a WSGI middleware app which wraps 'nextapp' with profiling. + + nextapp + the WSGI application to wrap, usually an instance of + cherrypy.Application. + + path + where to dump the profiling output. + + aggregate + if True, profile data for all HTTP requests will go in + a single file. If False (the default), each HTTP request will + dump its profile data into a separate file. + + """ + if profile is None or pstats is None: + msg = ("Your installation of Python does not have a profile " + "module. If you're on Debian, try " + "`sudo apt-get install python-profiler`. " + "See http://www.cherrypy.org/wiki/ProfilingOnDebian " + "for details.") + warnings.warn(msg) + + self.nextapp = nextapp + self.aggregate = aggregate + if aggregate: + self.profiler = ProfileAggregator(path) + else: + self.profiler = Profiler(path) + + def __call__(self, environ, start_response): + def gather(): + result = [] + for line in self.nextapp(environ, start_response): + result.append(line) + return result + return self.profiler.run(gather) + + +def serve(path=None, port=8080): + if profile is None or pstats is None: + msg = ("Your installation of Python does not have a profile module. " + "If you're on Debian, try " + "`sudo apt-get install python-profiler`. " + "See http://www.cherrypy.org/wiki/ProfilingOnDebian " + "for details.") + warnings.warn(msg) + + import cherrypy + cherrypy.config.update({'server.socket_port': int(port), + 'server.thread_pool': 10, + 'environment': "production", + }) + cherrypy.quickstart(Profiler(path)) + + +if __name__ == "__main__": + serve(*tuple(sys.argv[1:])) diff --git a/lib/cherrypy/lib/reprconf.py b/lib/cherrypy/lib/reprconf.py new file mode 100644 index 00000000..6e70b5ec --- /dev/null +++ b/lib/cherrypy/lib/reprconf.py @@ -0,0 +1,503 @@ +"""Generic configuration system using unrepr. + +Configuration data may be supplied as a Python dictionary, as a filename, +or as an open file object. When you supply a filename or file, Python's +builtin ConfigParser is used (with some extensions). + +Namespaces +---------- + +Configuration keys are separated into namespaces by the first "." in the key. + +The only key that cannot exist in a namespace is the "environment" entry. +This special entry 'imports' other config entries from a template stored in +the Config.environments dict. + +You can define your own namespaces to be called when new config is merged +by adding a named handler to Config.namespaces. The name can be any string, +and the handler must be either a callable or a context manager. +""" + +try: + # Python 3.0+ + from configparser import ConfigParser +except ImportError: + from ConfigParser import ConfigParser + +try: + set +except NameError: + from sets import Set as set + +try: + basestring +except NameError: + basestring = str + +try: + # Python 3 + import builtins +except ImportError: + # Python 2 + import __builtin__ as builtins + +import operator as _operator +import sys + + +def as_dict(config): + """Return a dict from 'config' whether it is a dict, file, or filename.""" + if isinstance(config, basestring): + config = Parser().dict_from_file(config) + elif hasattr(config, 'read'): + config = Parser().dict_from_file(config) + return config + + +class NamespaceSet(dict): + + """A dict of config namespace names and handlers. + + Each config entry should begin with a namespace name; the corresponding + namespace handler will be called once for each config entry in that + namespace, and will be passed two arguments: the config key (with the + namespace removed) and the config value. + + Namespace handlers may be any Python callable; they may also be + Python 2.5-style 'context managers', in which case their __enter__ + method should return a callable to be used as the handler. + See cherrypy.tools (the Toolbox class) for an example. + """ + + def __call__(self, config): + """Iterate through config and pass it to each namespace handler. + + config + A flat dict, where keys use dots to separate + namespaces, and values are arbitrary. + + The first name in each config key is used to look up the corresponding + namespace handler. For example, a config entry of {'tools.gzip.on': v} + will call the 'tools' namespace handler with the args: ('gzip.on', v) + """ + # Separate the given config into namespaces + ns_confs = {} + for k in config: + if "." in k: + ns, name = k.split(".", 1) + bucket = ns_confs.setdefault(ns, {}) + bucket[name] = config[k] + + # I chose __enter__ and __exit__ so someday this could be + # rewritten using Python 2.5's 'with' statement: + # for ns, handler in self.iteritems(): + # with handler as callable: + # for k, v in ns_confs.get(ns, {}).iteritems(): + # callable(k, v) + for ns, handler in self.items(): + exit = getattr(handler, "__exit__", None) + if exit: + callable = handler.__enter__() + no_exc = True + try: + try: + for k, v in ns_confs.get(ns, {}).items(): + callable(k, v) + except: + # The exceptional case is handled here + no_exc = False + if exit is None: + raise + if not exit(*sys.exc_info()): + raise + # The exception is swallowed if exit() returns true + finally: + # The normal and non-local-goto cases are handled here + if no_exc and exit: + exit(None, None, None) + else: + for k, v in ns_confs.get(ns, {}).items(): + handler(k, v) + + def __repr__(self): + return "%s.%s(%s)" % (self.__module__, self.__class__.__name__, + dict.__repr__(self)) + + def __copy__(self): + newobj = self.__class__() + newobj.update(self) + return newobj + copy = __copy__ + + +class Config(dict): + + """A dict-like set of configuration data, with defaults and namespaces. + + May take a file, filename, or dict. + """ + + defaults = {} + environments = {} + namespaces = NamespaceSet() + + def __init__(self, file=None, **kwargs): + self.reset() + if file is not None: + self.update(file) + if kwargs: + self.update(kwargs) + + def reset(self): + """Reset self to default values.""" + self.clear() + dict.update(self, self.defaults) + + def update(self, config): + """Update self from a dict, file or filename.""" + if isinstance(config, basestring): + # Filename + config = Parser().dict_from_file(config) + elif hasattr(config, 'read'): + # Open file object + config = Parser().dict_from_file(config) + else: + config = config.copy() + self._apply(config) + + def _apply(self, config): + """Update self from a dict.""" + which_env = config.get('environment') + if which_env: + env = self.environments[which_env] + for k in env: + if k not in config: + config[k] = env[k] + + dict.update(self, config) + self.namespaces(config) + + def __setitem__(self, k, v): + dict.__setitem__(self, k, v) + self.namespaces({k: v}) + + +class Parser(ConfigParser): + + """Sub-class of ConfigParser that keeps the case of options and that + raises an exception if the file cannot be read. + """ + + def optionxform(self, optionstr): + return optionstr + + def read(self, filenames): + if isinstance(filenames, basestring): + filenames = [filenames] + for filename in filenames: + # try: + # fp = open(filename) + # except IOError: + # continue + fp = open(filename) + try: + self._read(fp, filename) + finally: + fp.close() + + def as_dict(self, raw=False, vars=None): + """Convert an INI file to a dictionary""" + # Load INI file into a dict + result = {} + for section in self.sections(): + if section not in result: + result[section] = {} + for option in self.options(section): + value = self.get(section, option, raw=raw, vars=vars) + try: + value = unrepr(value) + except Exception: + x = sys.exc_info()[1] + msg = ("Config error in section: %r, option: %r, " + "value: %r. Config values must be valid Python." % + (section, option, value)) + raise ValueError(msg, x.__class__.__name__, x.args) + result[section][option] = value + return result + + def dict_from_file(self, file): + if hasattr(file, 'read'): + self.readfp(file) + else: + self.read(file) + return self.as_dict() + + +# public domain "unrepr" implementation, found on the web and then improved. + + +class _Builder2: + + def build(self, o): + m = getattr(self, 'build_' + o.__class__.__name__, None) + if m is None: + raise TypeError("unrepr does not recognize %s" % + repr(o.__class__.__name__)) + return m(o) + + def astnode(self, s): + """Return a Python2 ast Node compiled from a string.""" + try: + import compiler + except ImportError: + # Fallback to eval when compiler package is not available, + # e.g. IronPython 1.0. + return eval(s) + + p = compiler.parse("__tempvalue__ = " + s) + return p.getChildren()[1].getChildren()[0].getChildren()[1] + + def build_Subscript(self, o): + expr, flags, subs = o.getChildren() + expr = self.build(expr) + subs = self.build(subs) + return expr[subs] + + def build_CallFunc(self, o): + children = o.getChildren() + # Build callee from first child + callee = self.build(children[0]) + # Build args and kwargs from remaining children + args = [] + kwargs = {} + for child in children[1:]: + class_name = child.__class__.__name__ + # None is ignored + if class_name == 'NoneType': + continue + # Keywords become kwargs + if class_name == 'Keyword': + kwargs.update(self.build(child)) + # Everything else becomes args + else : + args.append(self.build(child)) + return callee(*args, **kwargs) + + def build_Keyword(self, o): + key, value_obj = o.getChildren() + value = self.build(value_obj) + kw_dict = {key: value} + return kw_dict + + def build_List(self, o): + return map(self.build, o.getChildren()) + + def build_Const(self, o): + return o.value + + def build_Dict(self, o): + d = {} + i = iter(map(self.build, o.getChildren())) + for el in i: + d[el] = i.next() + return d + + def build_Tuple(self, o): + return tuple(self.build_List(o)) + + def build_Name(self, o): + name = o.name + if name == 'None': + return None + if name == 'True': + return True + if name == 'False': + return False + + # See if the Name is a package or module. If it is, import it. + try: + return modules(name) + except ImportError: + pass + + # See if the Name is in builtins. + try: + return getattr(builtins, name) + except AttributeError: + pass + + raise TypeError("unrepr could not resolve the name %s" % repr(name)) + + def build_Add(self, o): + left, right = map(self.build, o.getChildren()) + return left + right + + def build_Mul(self, o): + left, right = map(self.build, o.getChildren()) + return left * right + + def build_Getattr(self, o): + parent = self.build(o.expr) + return getattr(parent, o.attrname) + + def build_NoneType(self, o): + return None + + def build_UnarySub(self, o): + return -self.build(o.getChildren()[0]) + + def build_UnaryAdd(self, o): + return self.build(o.getChildren()[0]) + + +class _Builder3: + + def build(self, o): + m = getattr(self, 'build_' + o.__class__.__name__, None) + if m is None: + raise TypeError("unrepr does not recognize %s" % + repr(o.__class__.__name__)) + return m(o) + + def astnode(self, s): + """Return a Python3 ast Node compiled from a string.""" + try: + import ast + except ImportError: + # Fallback to eval when ast package is not available, + # e.g. IronPython 1.0. + return eval(s) + + p = ast.parse("__tempvalue__ = " + s) + return p.body[0].value + + def build_Subscript(self, o): + return self.build(o.value)[self.build(o.slice)] + + def build_Index(self, o): + return self.build(o.value) + + def build_Call(self, o): + callee = self.build(o.func) + + if o.args is None: + args = () + else: + args = tuple([self.build(a) for a in o.args]) + + if o.starargs is None: + starargs = () + else: + starargs = self.build(o.starargs) + + if o.kwargs is None: + kwargs = {} + else: + kwargs = self.build(o.kwargs) + + return callee(*(args + starargs), **kwargs) + + def build_List(self, o): + return list(map(self.build, o.elts)) + + def build_Str(self, o): + return o.s + + def build_Num(self, o): + return o.n + + def build_Dict(self, o): + return dict([(self.build(k), self.build(v)) + for k, v in zip(o.keys, o.values)]) + + def build_Tuple(self, o): + return tuple(self.build_List(o)) + + def build_Name(self, o): + name = o.id + if name == 'None': + return None + if name == 'True': + return True + if name == 'False': + return False + + # See if the Name is a package or module. If it is, import it. + try: + return modules(name) + except ImportError: + pass + + # See if the Name is in builtins. + try: + import builtins + return getattr(builtins, name) + except AttributeError: + pass + + raise TypeError("unrepr could not resolve the name %s" % repr(name)) + + def build_NameConstant(self, o): + return o.value + + def build_UnaryOp(self, o): + op, operand = map(self.build, [o.op, o.operand]) + return op(operand) + + def build_BinOp(self, o): + left, op, right = map(self.build, [o.left, o.op, o.right]) + return op(left, right) + + def build_Add(self, o): + return _operator.add + + def build_Mult(self, o): + return _operator.mul + + def build_USub(self, o): + return _operator.neg + + def build_Attribute(self, o): + parent = self.build(o.value) + return getattr(parent, o.attr) + + def build_NoneType(self, o): + return None + + +def unrepr(s): + """Return a Python object compiled from a string.""" + if not s: + return s + if sys.version_info < (3, 0): + b = _Builder2() + else: + b = _Builder3() + obj = b.astnode(s) + return b.build(obj) + + +def modules(modulePath): + """Load a module and retrieve a reference to that module.""" + __import__(modulePath) + return sys.modules[modulePath] + + +def attributes(full_attribute_name): + """Load a module and retrieve an attribute of that module.""" + + # Parse out the path, module, and attribute + last_dot = full_attribute_name.rfind(".") + attr_name = full_attribute_name[last_dot + 1:] + mod_path = full_attribute_name[:last_dot] + + mod = modules(mod_path) + # Let an AttributeError propagate outward. + try: + attr = getattr(mod, attr_name) + except AttributeError: + raise AttributeError("'%s' object has no attribute '%s'" + % (mod_path, attr_name)) + + # Return a reference to the attribute. + return attr diff --git a/lib/cherrypy/lib/sessions.py b/lib/cherrypy/lib/sessions.py new file mode 100644 index 00000000..37556363 --- /dev/null +++ b/lib/cherrypy/lib/sessions.py @@ -0,0 +1,968 @@ +"""Session implementation for CherryPy. + +You need to edit your config file to use sessions. Here's an example:: + + [/] + tools.sessions.on = True + tools.sessions.storage_type = "file" + tools.sessions.storage_path = "/home/site/sessions" + tools.sessions.timeout = 60 + +This sets the session to be stored in files in the directory +/home/site/sessions, and the session timeout to 60 minutes. If you omit +``storage_type`` the sessions will be saved in RAM. +``tools.sessions.on`` is the only required line for working sessions, +the rest are optional. + +By default, the session ID is passed in a cookie, so the client's browser must +have cookies enabled for your site. + +To set data for the current session, use +``cherrypy.session['fieldname'] = 'fieldvalue'``; +to get data use ``cherrypy.session.get('fieldname')``. + +================ +Locking sessions +================ + +By default, the ``'locking'`` mode of sessions is ``'implicit'``, which means +the session is locked early and unlocked late. Be mindful of this default mode +for any requests that take a long time to process (streaming responses, +expensive calculations, database lookups, API calls, etc), as other concurrent +requests that also utilize sessions will hang until the session is unlocked. + +If you want to control when the session data is locked and unlocked, +set ``tools.sessions.locking = 'explicit'``. Then call +``cherrypy.session.acquire_lock()`` and ``cherrypy.session.release_lock()``. +Regardless of which mode you use, the session is guaranteed to be unlocked when +the request is complete. + +================= +Expiring Sessions +================= + +You can force a session to expire with :func:`cherrypy.lib.sessions.expire`. +Simply call that function at the point you want the session to expire, and it +will cause the session cookie to expire client-side. + +=========================== +Session Fixation Protection +=========================== + +If CherryPy receives, via a request cookie, a session id that it does not +recognize, it will reject that id and create a new one to return in the +response cookie. This `helps prevent session fixation attacks +`_. +However, CherryPy "recognizes" a session id by looking up the saved session +data for that id. Therefore, if you never save any session data, +**you will get a new session id for every request**. + +================ +Sharing Sessions +================ + +If you run multiple instances of CherryPy (for example via mod_python behind +Apache prefork), you most likely cannot use the RAM session backend, since each +instance of CherryPy will have its own memory space. Use a different backend +instead, and verify that all instances are pointing at the same file or db +location. Alternately, you might try a load balancer which makes sessions +"sticky". Google is your friend, there. + +================ +Expiration Dates +================ + +The response cookie will possess an expiration date to inform the client at +which point to stop sending the cookie back in requests. If the server time +and client time differ, expect sessions to be unreliable. **Make sure the +system time of your server is accurate**. + +CherryPy defaults to a 60-minute session timeout, which also applies to the +cookie which is sent to the client. Unfortunately, some versions of Safari +("4 public beta" on Windows XP at least) appear to have a bug in their parsing +of the GMT expiration date--they appear to interpret the date as one hour in +the past. Sixty minutes minus one hour is pretty close to zero, so you may +experience this bug as a new session id for every request, unless the requests +are less than one second apart. To fix, try increasing the session.timeout. + +On the other extreme, some users report Firefox sending cookies after their +expiration date, although this was on a system with an inaccurate system time. +Maybe FF doesn't trust system time. +""" +import sys +import datetime +import os +import time +import threading +import types + +import cherrypy +from cherrypy._cpcompat import copyitems, pickle, random20, unicodestr +from cherrypy.lib import httputil +from cherrypy.lib import lockfile +from cherrypy.lib import locking +from cherrypy.lib import is_iterator + +missing = object() + + +class Session(object): + + """A CherryPy dict-like Session object (one per request).""" + + _id = None + + id_observers = None + "A list of callbacks to which to pass new id's." + + def _get_id(self): + return self._id + + def _set_id(self, value): + self._id = value + for o in self.id_observers: + o(value) + id = property(_get_id, _set_id, doc="The current session ID.") + + timeout = 60 + "Number of minutes after which to delete session data." + + locked = False + """ + If True, this session instance has exclusive read/write access + to session data.""" + + loaded = False + """ + If True, data has been retrieved from storage. This should happen + automatically on the first attempt to access session data.""" + + clean_thread = None + "Class-level Monitor which calls self.clean_up." + + clean_freq = 5 + "The poll rate for expired session cleanup in minutes." + + originalid = None + "The session id passed by the client. May be missing or unsafe." + + missing = False + "True if the session requested by the client did not exist." + + regenerated = False + """ + True if the application called session.regenerate(). This is not set by + internal calls to regenerate the session id.""" + + debug = False + "If True, log debug information." + + # --------------------- Session management methods --------------------- # + + def __init__(self, id=None, **kwargs): + self.id_observers = [] + self._data = {} + + for k, v in kwargs.items(): + setattr(self, k, v) + + self.originalid = id + self.missing = False + if id is None: + if self.debug: + cherrypy.log('No id given; making a new one', 'TOOLS.SESSIONS') + self._regenerate() + else: + self.id = id + if self._exists(): + if self.debug: + cherrypy.log('Set id to %s.' % id, 'TOOLS.SESSIONS') + else: + if self.debug: + cherrypy.log('Expired or malicious session %r; ' + 'making a new one' % id, 'TOOLS.SESSIONS') + # Expired or malicious session. Make a new one. + # See https://bitbucket.org/cherrypy/cherrypy/issue/709. + self.id = None + self.missing = True + self._regenerate() + + def now(self): + """Generate the session specific concept of 'now'. + + Other session providers can override this to use alternative, + possibly timezone aware, versions of 'now'. + """ + return datetime.datetime.now() + + def regenerate(self): + """Replace the current session (with a new id).""" + self.regenerated = True + self._regenerate() + + def _regenerate(self): + if self.id is not None: + if self.debug: + cherrypy.log( + 'Deleting the existing session %r before ' + 'regeneration.' % self.id, + 'TOOLS.SESSIONS') + self.delete() + + old_session_was_locked = self.locked + if old_session_was_locked: + self.release_lock() + if self.debug: + cherrypy.log('Old lock released.', 'TOOLS.SESSIONS') + + self.id = None + while self.id is None: + self.id = self.generate_id() + # Assert that the generated id is not already stored. + if self._exists(): + self.id = None + if self.debug: + cherrypy.log('Set id to generated %s.' % self.id, + 'TOOLS.SESSIONS') + + if old_session_was_locked: + self.acquire_lock() + if self.debug: + cherrypy.log('Regenerated lock acquired.', 'TOOLS.SESSIONS') + + def clean_up(self): + """Clean up expired sessions.""" + pass + + def generate_id(self): + """Return a new session id.""" + return random20() + + def save(self): + """Save session data.""" + try: + # If session data has never been loaded then it's never been + # accessed: no need to save it + if self.loaded: + t = datetime.timedelta(seconds=self.timeout * 60) + expiration_time = self.now() + t + if self.debug: + cherrypy.log('Saving session %r with expiry %s' % + (self.id, expiration_time), + 'TOOLS.SESSIONS') + self._save(expiration_time) + else: + if self.debug: + cherrypy.log( + 'Skipping save of session %r (no session loaded).' % + self.id, 'TOOLS.SESSIONS') + finally: + if self.locked: + # Always release the lock if the user didn't release it + self.release_lock() + if self.debug: + cherrypy.log('Lock released after save.', 'TOOLS.SESSIONS') + + def load(self): + """Copy stored session data into this session instance.""" + data = self._load() + # data is either None or a tuple (session_data, expiration_time) + if data is None or data[1] < self.now(): + if self.debug: + cherrypy.log('Expired session %r, flushing data.' % self.id, + 'TOOLS.SESSIONS') + self._data = {} + else: + if self.debug: + cherrypy.log('Data loaded for session %r.' % self.id, + 'TOOLS.SESSIONS') + self._data = data[0] + self.loaded = True + + # Stick the clean_thread in the class, not the instance. + # The instances are created and destroyed per-request. + cls = self.__class__ + if self.clean_freq and not cls.clean_thread: + # clean_up is an instancemethod and not a classmethod, + # so that tool config can be accessed inside the method. + t = cherrypy.process.plugins.Monitor( + cherrypy.engine, self.clean_up, self.clean_freq * 60, + name='Session cleanup') + t.subscribe() + cls.clean_thread = t + t.start() + if self.debug: + cherrypy.log('Started cleanup thread.', 'TOOLS.SESSIONS') + + def delete(self): + """Delete stored session data.""" + self._delete() + if self.debug: + cherrypy.log('Deleted session %s.' % self.id, + 'TOOLS.SESSIONS') + + # -------------------- Application accessor methods -------------------- # + + def __getitem__(self, key): + if not self.loaded: + self.load() + return self._data[key] + + def __setitem__(self, key, value): + if not self.loaded: + self.load() + self._data[key] = value + + def __delitem__(self, key): + if not self.loaded: + self.load() + del self._data[key] + + def pop(self, key, default=missing): + """Remove the specified key and return the corresponding value. + If key is not found, default is returned if given, + otherwise KeyError is raised. + """ + if not self.loaded: + self.load() + if default is missing: + return self._data.pop(key) + else: + return self._data.pop(key, default) + + def __contains__(self, key): + if not self.loaded: + self.load() + return key in self._data + + if hasattr({}, 'has_key'): + def has_key(self, key): + """D.has_key(k) -> True if D has a key k, else False.""" + if not self.loaded: + self.load() + return key in self._data + + def get(self, key, default=None): + """D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.""" + if not self.loaded: + self.load() + return self._data.get(key, default) + + def update(self, d): + """D.update(E) -> None. Update D from E: for k in E: D[k] = E[k].""" + if not self.loaded: + self.load() + self._data.update(d) + + def setdefault(self, key, default=None): + """D.setdefault(k[,d]) -> D.get(k,d), also set D[k]=d if k not in D.""" + if not self.loaded: + self.load() + return self._data.setdefault(key, default) + + def clear(self): + """D.clear() -> None. Remove all items from D.""" + if not self.loaded: + self.load() + self._data.clear() + + def keys(self): + """D.keys() -> list of D's keys.""" + if not self.loaded: + self.load() + return self._data.keys() + + def items(self): + """D.items() -> list of D's (key, value) pairs, as 2-tuples.""" + if not self.loaded: + self.load() + return self._data.items() + + def values(self): + """D.values() -> list of D's values.""" + if not self.loaded: + self.load() + return self._data.values() + + +class RamSession(Session): + + # Class-level objects. Don't rebind these! + cache = {} + locks = {} + + def clean_up(self): + """Clean up expired sessions.""" + + now = self.now() + for _id, (data, expiration_time) in copyitems(self.cache): + if expiration_time <= now: + try: + del self.cache[_id] + except KeyError: + pass + try: + if self.locks[_id].acquire(blocking=False): + lock = self.locks.pop(_id) + lock.release() + except KeyError: + pass + + # added to remove obsolete lock objects + for _id in list(self.locks): + if _id not in self.cache and self.locks[_id].acquire(blocking=False): + lock = self.locks.pop(_id) + lock.release() + + def _exists(self): + return self.id in self.cache + + def _load(self): + return self.cache.get(self.id) + + def _save(self, expiration_time): + self.cache[self.id] = (self._data, expiration_time) + + def _delete(self): + self.cache.pop(self.id, None) + + def acquire_lock(self): + """Acquire an exclusive lock on the currently-loaded session data.""" + self.locked = True + self.locks.setdefault(self.id, threading.RLock()).acquire() + + def release_lock(self): + """Release the lock on the currently-loaded session data.""" + self.locks[self.id].release() + self.locked = False + + def __len__(self): + """Return the number of active sessions.""" + return len(self.cache) + + +class FileSession(Session): + + """Implementation of the File backend for sessions + + storage_path + The folder where session data will be saved. Each session + will be saved as pickle.dump(data, expiration_time) in its own file; + the filename will be self.SESSION_PREFIX + self.id. + + lock_timeout + A timedelta or numeric seconds indicating how long + to block acquiring a lock. If None (default), acquiring a lock + will block indefinitely. + """ + + SESSION_PREFIX = 'session-' + LOCK_SUFFIX = '.lock' + pickle_protocol = pickle.HIGHEST_PROTOCOL + + def __init__(self, id=None, **kwargs): + # The 'storage_path' arg is required for file-based sessions. + kwargs['storage_path'] = os.path.abspath(kwargs['storage_path']) + kwargs.setdefault('lock_timeout', None) + + Session.__init__(self, id=id, **kwargs) + + # validate self.lock_timeout + if isinstance(self.lock_timeout, (int, float)): + self.lock_timeout = datetime.timedelta(seconds=self.lock_timeout) + if not isinstance(self.lock_timeout, (datetime.timedelta, type(None))): + raise ValueError("Lock timeout must be numeric seconds or " + "a timedelta instance.") + + def setup(cls, **kwargs): + """Set up the storage system for file-based sessions. + + This should only be called once per process; this will be done + automatically when using sessions.init (as the built-in Tool does). + """ + # The 'storage_path' arg is required for file-based sessions. + kwargs['storage_path'] = os.path.abspath(kwargs['storage_path']) + + for k, v in kwargs.items(): + setattr(cls, k, v) + setup = classmethod(setup) + + def _get_file_path(self): + f = os.path.join(self.storage_path, self.SESSION_PREFIX + self.id) + if not os.path.abspath(f).startswith(self.storage_path): + raise cherrypy.HTTPError(400, "Invalid session id in cookie.") + return f + + def _exists(self): + path = self._get_file_path() + return os.path.exists(path) + + def _load(self, path=None): + assert self.locked, ("The session load without being locked. " + "Check your tools' priority levels.") + if path is None: + path = self._get_file_path() + try: + f = open(path, "rb") + try: + return pickle.load(f) + finally: + f.close() + except (IOError, EOFError): + e = sys.exc_info()[1] + if self.debug: + cherrypy.log("Error loading the session pickle: %s" % + e, 'TOOLS.SESSIONS') + return None + + def _save(self, expiration_time): + assert self.locked, ("The session was saved without being locked. " + "Check your tools' priority levels.") + f = open(self._get_file_path(), "wb") + try: + pickle.dump((self._data, expiration_time), f, self.pickle_protocol) + finally: + f.close() + + def _delete(self): + assert self.locked, ("The session deletion without being locked. " + "Check your tools' priority levels.") + try: + os.unlink(self._get_file_path()) + except OSError: + pass + + def acquire_lock(self, path=None): + """Acquire an exclusive lock on the currently-loaded session data.""" + if path is None: + path = self._get_file_path() + path += self.LOCK_SUFFIX + checker = locking.LockChecker(self.id, self.lock_timeout) + while not checker.expired(): + try: + self.lock = lockfile.LockFile(path) + except lockfile.LockError: + time.sleep(0.1) + else: + break + self.locked = True + if self.debug: + cherrypy.log('Lock acquired.', 'TOOLS.SESSIONS') + + def release_lock(self, path=None): + """Release the lock on the currently-loaded session data.""" + self.lock.release() + self.lock.remove() + self.locked = False + + def clean_up(self): + """Clean up expired sessions.""" + now = self.now() + # Iterate over all session files in self.storage_path + for fname in os.listdir(self.storage_path): + if (fname.startswith(self.SESSION_PREFIX) + and not fname.endswith(self.LOCK_SUFFIX)): + # We have a session file: lock and load it and check + # if it's expired. If it fails, nevermind. + path = os.path.join(self.storage_path, fname) + self.acquire_lock(path) + if self.debug: + # This is a bit of a hack, since we're calling clean_up + # on the first instance rather than the entire class, + # so depending on whether you have "debug" set on the + # path of the first session called, this may not run. + cherrypy.log('Cleanup lock acquired.', 'TOOLS.SESSIONS') + + try: + contents = self._load(path) + # _load returns None on IOError + if contents is not None: + data, expiration_time = contents + if expiration_time < now: + # Session expired: deleting it + os.unlink(path) + finally: + self.release_lock(path) + + def __len__(self): + """Return the number of active sessions.""" + return len([fname for fname in os.listdir(self.storage_path) + if (fname.startswith(self.SESSION_PREFIX) + and not fname.endswith(self.LOCK_SUFFIX))]) + + +class PostgresqlSession(Session): + + """ Implementation of the PostgreSQL backend for sessions. It assumes + a table like this:: + + create table session ( + id varchar(40), + data text, + expiration_time timestamp + ) + + You must provide your own get_db function. + """ + + pickle_protocol = pickle.HIGHEST_PROTOCOL + + def __init__(self, id=None, **kwargs): + Session.__init__(self, id, **kwargs) + self.cursor = self.db.cursor() + + def setup(cls, **kwargs): + """Set up the storage system for Postgres-based sessions. + + This should only be called once per process; this will be done + automatically when using sessions.init (as the built-in Tool does). + """ + for k, v in kwargs.items(): + setattr(cls, k, v) + + self.db = self.get_db() + setup = classmethod(setup) + + def __del__(self): + if self.cursor: + self.cursor.close() + self.db.commit() + + def _exists(self): + # Select session data from table + self.cursor.execute('select data, expiration_time from session ' + 'where id=%s', (self.id,)) + rows = self.cursor.fetchall() + return bool(rows) + + def _load(self): + # Select session data from table + self.cursor.execute('select data, expiration_time from session ' + 'where id=%s', (self.id,)) + rows = self.cursor.fetchall() + if not rows: + return None + + pickled_data, expiration_time = rows[0] + data = pickle.loads(pickled_data) + return data, expiration_time + + def _save(self, expiration_time): + pickled_data = pickle.dumps(self._data, self.pickle_protocol) + self.cursor.execute('update session set data = %s, ' + 'expiration_time = %s where id = %s', + (pickled_data, expiration_time, self.id)) + + def _delete(self): + self.cursor.execute('delete from session where id=%s', (self.id,)) + + def acquire_lock(self): + """Acquire an exclusive lock on the currently-loaded session data.""" + # We use the "for update" clause to lock the row + self.locked = True + self.cursor.execute('select id from session where id=%s for update', + (self.id,)) + if self.debug: + cherrypy.log('Lock acquired.', 'TOOLS.SESSIONS') + + def release_lock(self): + """Release the lock on the currently-loaded session data.""" + # We just close the cursor and that will remove the lock + # introduced by the "for update" clause + self.cursor.close() + self.locked = False + + def clean_up(self): + """Clean up expired sessions.""" + self.cursor.execute('delete from session where expiration_time < %s', + (self.now(),)) + + +class MemcachedSession(Session): + + # The most popular memcached client for Python isn't thread-safe. + # Wrap all .get and .set operations in a single lock. + mc_lock = threading.RLock() + + # This is a seperate set of locks per session id. + locks = {} + + servers = ['127.0.0.1:11211'] + + def setup(cls, **kwargs): + """Set up the storage system for memcached-based sessions. + + This should only be called once per process; this will be done + automatically when using sessions.init (as the built-in Tool does). + """ + for k, v in kwargs.items(): + setattr(cls, k, v) + + import memcache + cls.cache = memcache.Client(cls.servers) + setup = classmethod(setup) + + def _get_id(self): + return self._id + + def _set_id(self, value): + # This encode() call is where we differ from the superclass. + # Memcache keys MUST be byte strings, not unicode. + if isinstance(value, unicodestr): + value = value.encode('utf-8') + + self._id = value + for o in self.id_observers: + o(value) + id = property(_get_id, _set_id, doc="The current session ID.") + + def _exists(self): + self.mc_lock.acquire() + try: + return bool(self.cache.get(self.id)) + finally: + self.mc_lock.release() + + def _load(self): + self.mc_lock.acquire() + try: + return self.cache.get(self.id) + finally: + self.mc_lock.release() + + def _save(self, expiration_time): + # Send the expiration time as "Unix time" (seconds since 1/1/1970) + td = int(time.mktime(expiration_time.timetuple())) + self.mc_lock.acquire() + try: + if not self.cache.set(self.id, (self._data, expiration_time), td): + raise AssertionError( + "Session data for id %r not set." % self.id) + finally: + self.mc_lock.release() + + def _delete(self): + self.cache.delete(self.id) + + def acquire_lock(self): + """Acquire an exclusive lock on the currently-loaded session data.""" + self.locked = True + self.locks.setdefault(self.id, threading.RLock()).acquire() + if self.debug: + cherrypy.log('Lock acquired.', 'TOOLS.SESSIONS') + + def release_lock(self): + """Release the lock on the currently-loaded session data.""" + self.locks[self.id].release() + self.locked = False + + def __len__(self): + """Return the number of active sessions.""" + raise NotImplementedError + + +# Hook functions (for CherryPy tools) + +def save(): + """Save any changed session data.""" + + if not hasattr(cherrypy.serving, "session"): + return + request = cherrypy.serving.request + response = cherrypy.serving.response + + # Guard against running twice + if hasattr(request, "_sessionsaved"): + return + request._sessionsaved = True + + if response.stream: + # If the body is being streamed, we have to save the data + # *after* the response has been written out + request.hooks.attach('on_end_request', cherrypy.session.save) + else: + # If the body is not being streamed, we save the data now + # (so we can release the lock). + if is_iterator(response.body): + response.collapse_body() + cherrypy.session.save() +save.failsafe = True + + +def close(): + """Close the session object for this request.""" + sess = getattr(cherrypy.serving, "session", None) + if getattr(sess, "locked", False): + # If the session is still locked we release the lock + sess.release_lock() + if sess.debug: + cherrypy.log('Lock released on close.', 'TOOLS.SESSIONS') +close.failsafe = True +close.priority = 90 + + +def init(storage_type='ram', path=None, path_header=None, name='session_id', + timeout=60, domain=None, secure=False, clean_freq=5, + persistent=True, httponly=False, debug=False, **kwargs): + """Initialize session object (using cookies). + + storage_type + One of 'ram', 'file', 'postgresql', 'memcached'. This will be + used to look up the corresponding class in cherrypy.lib.sessions + globals. For example, 'file' will use the FileSession class. + + path + The 'path' value to stick in the response cookie metadata. + + path_header + If 'path' is None (the default), then the response + cookie 'path' will be pulled from request.headers[path_header]. + + name + The name of the cookie. + + timeout + The expiration timeout (in minutes) for the stored session data. + If 'persistent' is True (the default), this is also the timeout + for the cookie. + + domain + The cookie domain. + + secure + If False (the default) the cookie 'secure' value will not + be set. If True, the cookie 'secure' value will be set (to 1). + + clean_freq (minutes) + The poll rate for expired session cleanup. + + persistent + If True (the default), the 'timeout' argument will be used + to expire the cookie. If False, the cookie will not have an expiry, + and the cookie will be a "session cookie" which expires when the + browser is closed. + + httponly + If False (the default) the cookie 'httponly' value will not be set. + If True, the cookie 'httponly' value will be set (to 1). + + Any additional kwargs will be bound to the new Session instance, + and may be specific to the storage type. See the subclass of Session + you're using for more information. + """ + + request = cherrypy.serving.request + + # Guard against running twice + if hasattr(request, "_session_init_flag"): + return + request._session_init_flag = True + + # Check if request came with a session ID + id = None + if name in request.cookie: + id = request.cookie[name].value + if debug: + cherrypy.log('ID obtained from request.cookie: %r' % id, + 'TOOLS.SESSIONS') + + # Find the storage class and call setup (first time only). + storage_class = storage_type.title() + 'Session' + storage_class = globals()[storage_class] + if not hasattr(cherrypy, "session"): + if hasattr(storage_class, "setup"): + storage_class.setup(**kwargs) + + # Create and attach a new Session instance to cherrypy.serving. + # It will possess a reference to (and lock, and lazily load) + # the requested session data. + kwargs['timeout'] = timeout + kwargs['clean_freq'] = clean_freq + cherrypy.serving.session = sess = storage_class(id, **kwargs) + sess.debug = debug + + def update_cookie(id): + """Update the cookie every time the session id changes.""" + cherrypy.serving.response.cookie[name] = id + sess.id_observers.append(update_cookie) + + # Create cherrypy.session which will proxy to cherrypy.serving.session + if not hasattr(cherrypy, "session"): + cherrypy.session = cherrypy._ThreadLocalProxy('session') + + if persistent: + cookie_timeout = timeout + else: + # See http://support.microsoft.com/kb/223799/EN-US/ + # and http://support.mozilla.com/en-US/kb/Cookies + cookie_timeout = None + set_response_cookie(path=path, path_header=path_header, name=name, + timeout=cookie_timeout, domain=domain, secure=secure, + httponly=httponly) + + +def set_response_cookie(path=None, path_header=None, name='session_id', + timeout=60, domain=None, secure=False, httponly=False): + """Set a response cookie for the client. + + path + the 'path' value to stick in the response cookie metadata. + + path_header + if 'path' is None (the default), then the response + cookie 'path' will be pulled from request.headers[path_header]. + + name + the name of the cookie. + + timeout + the expiration timeout for the cookie. If 0 or other boolean + False, no 'expires' param will be set, and the cookie will be a + "session cookie" which expires when the browser is closed. + + domain + the cookie domain. + + secure + if False (the default) the cookie 'secure' value will not + be set. If True, the cookie 'secure' value will be set (to 1). + + httponly + If False (the default) the cookie 'httponly' value will not be set. + If True, the cookie 'httponly' value will be set (to 1). + + """ + # Set response cookie + cookie = cherrypy.serving.response.cookie + cookie[name] = cherrypy.serving.session.id + cookie[name]['path'] = ( + path or + cherrypy.serving.request.headers.get(path_header) or + '/' + ) + + # We'd like to use the "max-age" param as indicated in + # http://www.faqs.org/rfcs/rfc2109.html but IE doesn't + # save it to disk and the session is lost if people close + # the browser. So we have to use the old "expires" ... sigh ... +## cookie[name]['max-age'] = timeout * 60 + if timeout: + e = time.time() + (timeout * 60) + cookie[name]['expires'] = httputil.HTTPDate(e) + if domain is not None: + cookie[name]['domain'] = domain + if secure: + cookie[name]['secure'] = 1 + if httponly: + if not cookie[name].isReservedKey('httponly'): + raise ValueError("The httponly cookie token is not supported.") + cookie[name]['httponly'] = 1 + + +def expire(): + """Expire the current session cookie.""" + name = cherrypy.serving.request.config.get( + 'tools.sessions.name', 'session_id') + one_year = 60 * 60 * 24 * 365 + e = time.time() - one_year + cherrypy.serving.response.cookie[name]['expires'] = httputil.HTTPDate(e) diff --git a/lib/cherrypy/lib/static.py b/lib/cherrypy/lib/static.py new file mode 100644 index 00000000..a630dae6 --- /dev/null +++ b/lib/cherrypy/lib/static.py @@ -0,0 +1,378 @@ +import os +import re +import stat +import mimetypes + +try: + from io import UnsupportedOperation +except ImportError: + UnsupportedOperation = object() + +import cherrypy +from cherrypy._cpcompat import ntob, unquote +from cherrypy.lib import cptools, httputil, file_generator_limited + + +mimetypes.init() +mimetypes.types_map['.dwg'] = 'image/x-dwg' +mimetypes.types_map['.ico'] = 'image/x-icon' +mimetypes.types_map['.bz2'] = 'application/x-bzip2' +mimetypes.types_map['.gz'] = 'application/x-gzip' + + +def serve_file(path, content_type=None, disposition=None, name=None, + debug=False): + """Set status, headers, and body in order to serve the given path. + + The Content-Type header will be set to the content_type arg, if provided. + If not provided, the Content-Type will be guessed by the file extension + of the 'path' argument. + + If disposition is not None, the Content-Disposition header will be set + to "; filename=". If name is None, it will be set + to the basename of path. If disposition is None, no Content-Disposition + header will be written. + """ + + response = cherrypy.serving.response + + # If path is relative, users should fix it by making path absolute. + # That is, CherryPy should not guess where the application root is. + # It certainly should *not* use cwd (since CP may be invoked from a + # variety of paths). If using tools.staticdir, you can make your relative + # paths become absolute by supplying a value for "tools.staticdir.root". + if not os.path.isabs(path): + msg = "'%s' is not an absolute path." % path + if debug: + cherrypy.log(msg, 'TOOLS.STATICFILE') + raise ValueError(msg) + + try: + st = os.stat(path) + except OSError: + if debug: + cherrypy.log('os.stat(%r) failed' % path, 'TOOLS.STATIC') + raise cherrypy.NotFound() + + # Check if path is a directory. + if stat.S_ISDIR(st.st_mode): + # Let the caller deal with it as they like. + if debug: + cherrypy.log('%r is a directory' % path, 'TOOLS.STATIC') + raise cherrypy.NotFound() + + # Set the Last-Modified response header, so that + # modified-since validation code can work. + response.headers['Last-Modified'] = httputil.HTTPDate(st.st_mtime) + cptools.validate_since() + + if content_type is None: + # Set content-type based on filename extension + ext = "" + i = path.rfind('.') + if i != -1: + ext = path[i:].lower() + content_type = mimetypes.types_map.get(ext, None) + if content_type is not None: + response.headers['Content-Type'] = content_type + if debug: + cherrypy.log('Content-Type: %r' % content_type, 'TOOLS.STATIC') + + cd = None + if disposition is not None: + if name is None: + name = os.path.basename(path) + cd = '%s; filename="%s"' % (disposition, name) + response.headers["Content-Disposition"] = cd + if debug: + cherrypy.log('Content-Disposition: %r' % cd, 'TOOLS.STATIC') + + # Set Content-Length and use an iterable (file object) + # this way CP won't load the whole file in memory + content_length = st.st_size + fileobj = open(path, 'rb') + return _serve_fileobj(fileobj, content_type, content_length, debug=debug) + + +def serve_fileobj(fileobj, content_type=None, disposition=None, name=None, + debug=False): + """Set status, headers, and body in order to serve the given file object. + + The Content-Type header will be set to the content_type arg, if provided. + + If disposition is not None, the Content-Disposition header will be set + to "; filename=". If name is None, 'filename' will + not be set. If disposition is None, no Content-Disposition header will + be written. + + CAUTION: If the request contains a 'Range' header, one or more seek()s will + be performed on the file object. This may cause undesired behavior if + the file object is not seekable. It could also produce undesired results + if the caller set the read position of the file object prior to calling + serve_fileobj(), expecting that the data would be served starting from that + position. + """ + + response = cherrypy.serving.response + + try: + st = os.fstat(fileobj.fileno()) + except AttributeError: + if debug: + cherrypy.log('os has no fstat attribute', 'TOOLS.STATIC') + content_length = None + except UnsupportedOperation: + content_length = None + else: + # Set the Last-Modified response header, so that + # modified-since validation code can work. + response.headers['Last-Modified'] = httputil.HTTPDate(st.st_mtime) + cptools.validate_since() + content_length = st.st_size + + if content_type is not None: + response.headers['Content-Type'] = content_type + if debug: + cherrypy.log('Content-Type: %r' % content_type, 'TOOLS.STATIC') + + cd = None + if disposition is not None: + if name is None: + cd = disposition + else: + cd = '%s; filename="%s"' % (disposition, name) + response.headers["Content-Disposition"] = cd + if debug: + cherrypy.log('Content-Disposition: %r' % cd, 'TOOLS.STATIC') + + return _serve_fileobj(fileobj, content_type, content_length, debug=debug) + + +def _serve_fileobj(fileobj, content_type, content_length, debug=False): + """Internal. Set response.body to the given file object, perhaps ranged.""" + response = cherrypy.serving.response + + # HTTP/1.0 didn't have Range/Accept-Ranges headers, or the 206 code + request = cherrypy.serving.request + if request.protocol >= (1, 1): + response.headers["Accept-Ranges"] = "bytes" + r = httputil.get_ranges(request.headers.get('Range'), content_length) + if r == []: + response.headers['Content-Range'] = "bytes */%s" % content_length + message = ("Invalid Range (first-byte-pos greater than " + "Content-Length)") + if debug: + cherrypy.log(message, 'TOOLS.STATIC') + raise cherrypy.HTTPError(416, message) + + if r: + if len(r) == 1: + # Return a single-part response. + start, stop = r[0] + if stop > content_length: + stop = content_length + r_len = stop - start + if debug: + cherrypy.log( + 'Single part; start: %r, stop: %r' % (start, stop), + 'TOOLS.STATIC') + response.status = "206 Partial Content" + response.headers['Content-Range'] = ( + "bytes %s-%s/%s" % (start, stop - 1, content_length)) + response.headers['Content-Length'] = r_len + fileobj.seek(start) + response.body = file_generator_limited(fileobj, r_len) + else: + # Return a multipart/byteranges response. + response.status = "206 Partial Content" + try: + # Python 3 + from email.generator import _make_boundary as make_boundary + except ImportError: + # Python 2 + from mimetools import choose_boundary as make_boundary + boundary = make_boundary() + ct = "multipart/byteranges; boundary=%s" % boundary + response.headers['Content-Type'] = ct + if "Content-Length" in response.headers: + # Delete Content-Length header so finalize() recalcs it. + del response.headers["Content-Length"] + + def file_ranges(): + # Apache compatibility: + yield ntob("\r\n") + + for start, stop in r: + if debug: + cherrypy.log( + 'Multipart; start: %r, stop: %r' % ( + start, stop), + 'TOOLS.STATIC') + yield ntob("--" + boundary, 'ascii') + yield ntob("\r\nContent-type: %s" % content_type, + 'ascii') + yield ntob( + "\r\nContent-range: bytes %s-%s/%s\r\n\r\n" % ( + start, stop - 1, content_length), + 'ascii') + fileobj.seek(start) + gen = file_generator_limited(fileobj, stop - start) + for chunk in gen: + yield chunk + yield ntob("\r\n") + # Final boundary + yield ntob("--" + boundary + "--", 'ascii') + + # Apache compatibility: + yield ntob("\r\n") + response.body = file_ranges() + return response.body + else: + if debug: + cherrypy.log('No byteranges requested', 'TOOLS.STATIC') + + # Set Content-Length and use an iterable (file object) + # this way CP won't load the whole file in memory + response.headers['Content-Length'] = content_length + response.body = fileobj + return response.body + + +def serve_download(path, name=None): + """Serve 'path' as an application/x-download attachment.""" + # This is such a common idiom I felt it deserved its own wrapper. + return serve_file(path, "application/x-download", "attachment", name) + + +def _attempt(filename, content_types, debug=False): + if debug: + cherrypy.log('Attempting %r (content_types %r)' % + (filename, content_types), 'TOOLS.STATICDIR') + try: + # you can set the content types for a + # complete directory per extension + content_type = None + if content_types: + r, ext = os.path.splitext(filename) + content_type = content_types.get(ext[1:], None) + serve_file(filename, content_type=content_type, debug=debug) + return True + except cherrypy.NotFound: + # If we didn't find the static file, continue handling the + # request. We might find a dynamic handler instead. + if debug: + cherrypy.log('NotFound', 'TOOLS.STATICFILE') + return False + + +def staticdir(section, dir, root="", match="", content_types=None, index="", + debug=False): + """Serve a static resource from the given (root +) dir. + + match + If given, request.path_info will be searched for the given + regular expression before attempting to serve static content. + + content_types + If given, it should be a Python dictionary of + {file-extension: content-type} pairs, where 'file-extension' is + a string (e.g. "gif") and 'content-type' is the value to write + out in the Content-Type response header (e.g. "image/gif"). + + index + If provided, it should be the (relative) name of a file to + serve for directory requests. For example, if the dir argument is + '/home/me', the Request-URI is 'myapp', and the index arg is + 'index.html', the file '/home/me/myapp/index.html' will be sought. + """ + request = cherrypy.serving.request + if request.method not in ('GET', 'HEAD'): + if debug: + cherrypy.log('request.method not GET or HEAD', 'TOOLS.STATICDIR') + return False + + if match and not re.search(match, request.path_info): + if debug: + cherrypy.log('request.path_info %r does not match pattern %r' % + (request.path_info, match), 'TOOLS.STATICDIR') + return False + + # Allow the use of '~' to refer to a user's home directory. + dir = os.path.expanduser(dir) + + # If dir is relative, make absolute using "root". + if not os.path.isabs(dir): + if not root: + msg = "Static dir requires an absolute dir (or root)." + if debug: + cherrypy.log(msg, 'TOOLS.STATICDIR') + raise ValueError(msg) + dir = os.path.join(root, dir) + + # Determine where we are in the object tree relative to 'section' + # (where the static tool was defined). + if section == 'global': + section = "/" + section = section.rstrip(r"\/") + branch = request.path_info[len(section) + 1:] + branch = unquote(branch.lstrip(r"\/")) + + # If branch is "", filename will end in a slash + filename = os.path.join(dir, branch) + if debug: + cherrypy.log('Checking file %r to fulfill %r' % + (filename, request.path_info), 'TOOLS.STATICDIR') + + # There's a chance that the branch pulled from the URL might + # have ".." or similar uplevel attacks in it. Check that the final + # filename is a child of dir. + if not os.path.normpath(filename).startswith(os.path.normpath(dir)): + raise cherrypy.HTTPError(403) # Forbidden + + handled = _attempt(filename, content_types) + if not handled: + # Check for an index file if a folder was requested. + if index: + handled = _attempt(os.path.join(filename, index), content_types) + if handled: + request.is_index = filename[-1] in (r"\/") + return handled + + +def staticfile(filename, root=None, match="", content_types=None, debug=False): + """Serve a static resource from the given (root +) filename. + + match + If given, request.path_info will be searched for the given + regular expression before attempting to serve static content. + + content_types + If given, it should be a Python dictionary of + {file-extension: content-type} pairs, where 'file-extension' is + a string (e.g. "gif") and 'content-type' is the value to write + out in the Content-Type response header (e.g. "image/gif"). + + """ + request = cherrypy.serving.request + if request.method not in ('GET', 'HEAD'): + if debug: + cherrypy.log('request.method not GET or HEAD', 'TOOLS.STATICFILE') + return False + + if match and not re.search(match, request.path_info): + if debug: + cherrypy.log('request.path_info %r does not match pattern %r' % + (request.path_info, match), 'TOOLS.STATICFILE') + return False + + # If filename is relative, make absolute using "root". + if not os.path.isabs(filename): + if not root: + msg = "Static tool requires an absolute filename (got '%s')." % ( + filename,) + if debug: + cherrypy.log(msg, 'TOOLS.STATICFILE') + raise ValueError(msg) + filename = os.path.join(root, filename) + + return _attempt(filename, content_types, debug=debug) diff --git a/lib/cherrypy/lib/xmlrpcutil.py b/lib/cherrypy/lib/xmlrpcutil.py new file mode 100644 index 00000000..9fc9564f --- /dev/null +++ b/lib/cherrypy/lib/xmlrpcutil.py @@ -0,0 +1,57 @@ +import sys + +import cherrypy +from cherrypy._cpcompat import ntob + + +def get_xmlrpclib(): + try: + import xmlrpc.client as x + except ImportError: + import xmlrpclib as x + return x + + +def process_body(): + """Return (params, method) from request body.""" + try: + return get_xmlrpclib().loads(cherrypy.request.body.read()) + except Exception: + return ('ERROR PARAMS', ), 'ERRORMETHOD' + + +def patched_path(path): + """Return 'path', doctored for RPC.""" + if not path.endswith('/'): + path += '/' + if path.startswith('/RPC2/'): + # strip the first /rpc2 + path = path[5:] + return path + + +def _set_response(body): + # The XML-RPC spec (http://www.xmlrpc.com/spec) says: + # "Unless there's a lower-level error, always return 200 OK." + # Since Python's xmlrpclib interprets a non-200 response + # as a "Protocol Error", we'll just return 200 every time. + response = cherrypy.response + response.status = '200 OK' + response.body = ntob(body, 'utf-8') + response.headers['Content-Type'] = 'text/xml' + response.headers['Content-Length'] = len(body) + + +def respond(body, encoding='utf-8', allow_none=0): + xmlrpclib = get_xmlrpclib() + if not isinstance(body, xmlrpclib.Fault): + body = (body,) + _set_response(xmlrpclib.dumps(body, methodresponse=1, + encoding=encoding, + allow_none=allow_none)) + + +def on_error(*args, **kwargs): + body = str(sys.exc_info()[1]) + xmlrpclib = get_xmlrpclib() + _set_response(xmlrpclib.dumps(xmlrpclib.Fault(1, body))) diff --git a/lib/cherrypy/process/__init__.py b/lib/cherrypy/process/__init__.py new file mode 100644 index 00000000..f15b1237 --- /dev/null +++ b/lib/cherrypy/process/__init__.py @@ -0,0 +1,14 @@ +"""Site container for an HTTP server. + +A Web Site Process Bus object is used to connect applications, servers, +and frameworks with site-wide services such as daemonization, process +reload, signal handling, drop privileges, PID file management, logging +for all of these, and many more. + +The 'plugins' module defines a few abstract and concrete services for +use with the bus. Some use tool-specific channels; see the documentation +for each class. +""" + +from cherrypy.process.wspbus import bus +from cherrypy.process import plugins, servers diff --git a/lib/cherrypy/process/plugins.py b/lib/cherrypy/process/plugins.py new file mode 100644 index 00000000..c787ba92 --- /dev/null +++ b/lib/cherrypy/process/plugins.py @@ -0,0 +1,717 @@ +"""Site services for use with a Web Site Process Bus.""" + +import os +import re +import signal as _signal +import sys +import time +import threading + +from cherrypy._cpcompat import basestring, get_daemon, get_thread_ident +from cherrypy._cpcompat import ntob, set, Timer, SetDaemonProperty + +# _module__file__base is used by Autoreload to make +# absolute any filenames retrieved from sys.modules which are not +# already absolute paths. This is to work around Python's quirk +# of importing the startup script and using a relative filename +# for it in sys.modules. +# +# Autoreload examines sys.modules afresh every time it runs. If an application +# changes the current directory by executing os.chdir(), then the next time +# Autoreload runs, it will not be able to find any filenames which are +# not absolute paths, because the current directory is not the same as when the +# module was first imported. Autoreload will then wrongly conclude the file +# has "changed", and initiate the shutdown/re-exec sequence. +# See ticket #917. +# For this workaround to have a decent probability of success, this module +# needs to be imported as early as possible, before the app has much chance +# to change the working directory. +_module__file__base = os.getcwd() + + +class SimplePlugin(object): + + """Plugin base class which auto-subscribes methods for known channels.""" + + bus = None + """A :class:`Bus `, usually cherrypy.engine. + """ + + def __init__(self, bus): + self.bus = bus + + def subscribe(self): + """Register this object as a (multi-channel) listener on the bus.""" + for channel in self.bus.listeners: + # Subscribe self.start, self.exit, etc. if present. + method = getattr(self, channel, None) + if method is not None: + self.bus.subscribe(channel, method) + + def unsubscribe(self): + """Unregister this object as a listener on the bus.""" + for channel in self.bus.listeners: + # Unsubscribe self.start, self.exit, etc. if present. + method = getattr(self, channel, None) + if method is not None: + self.bus.unsubscribe(channel, method) + + +class SignalHandler(object): + + """Register bus channels (and listeners) for system signals. + + You can modify what signals your application listens for, and what it does + when it receives signals, by modifying :attr:`SignalHandler.handlers`, + a dict of {signal name: callback} pairs. The default set is:: + + handlers = {'SIGTERM': self.bus.exit, + 'SIGHUP': self.handle_SIGHUP, + 'SIGUSR1': self.bus.graceful, + } + + The :func:`SignalHandler.handle_SIGHUP`` method calls + :func:`bus.restart()` + if the process is daemonized, but + :func:`bus.exit()` + if the process is attached to a TTY. This is because Unix window + managers tend to send SIGHUP to terminal windows when the user closes them. + + Feel free to add signals which are not available on every platform. + The :class:`SignalHandler` will ignore errors raised from attempting + to register handlers for unknown signals. + """ + + handlers = {} + """A map from signal names (e.g. 'SIGTERM') to handlers (e.g. bus.exit).""" + + signals = {} + """A map from signal numbers to names.""" + + for k, v in vars(_signal).items(): + if k.startswith('SIG') and not k.startswith('SIG_'): + signals[v] = k + del k, v + + def __init__(self, bus): + self.bus = bus + # Set default handlers + self.handlers = {'SIGTERM': self.bus.exit, + 'SIGHUP': self.handle_SIGHUP, + 'SIGUSR1': self.bus.graceful, + } + + if sys.platform[:4] == 'java': + del self.handlers['SIGUSR1'] + self.handlers['SIGUSR2'] = self.bus.graceful + self.bus.log("SIGUSR1 cannot be set on the JVM platform. " + "Using SIGUSR2 instead.") + self.handlers['SIGINT'] = self._jython_SIGINT_handler + + self._previous_handlers = {} + + def _jython_SIGINT_handler(self, signum=None, frame=None): + # See http://bugs.jython.org/issue1313 + self.bus.log('Keyboard Interrupt: shutting down bus') + self.bus.exit() + + def subscribe(self): + """Subscribe self.handlers to signals.""" + for sig, func in self.handlers.items(): + try: + self.set_handler(sig, func) + except ValueError: + pass + + def unsubscribe(self): + """Unsubscribe self.handlers from signals.""" + for signum, handler in self._previous_handlers.items(): + signame = self.signals[signum] + + if handler is None: + self.bus.log("Restoring %s handler to SIG_DFL." % signame) + handler = _signal.SIG_DFL + else: + self.bus.log("Restoring %s handler %r." % (signame, handler)) + + try: + our_handler = _signal.signal(signum, handler) + if our_handler is None: + self.bus.log("Restored old %s handler %r, but our " + "handler was not registered." % + (signame, handler), level=30) + except ValueError: + self.bus.log("Unable to restore %s handler %r." % + (signame, handler), level=40, traceback=True) + + def set_handler(self, signal, listener=None): + """Subscribe a handler for the given signal (number or name). + + If the optional 'listener' argument is provided, it will be + subscribed as a listener for the given signal's channel. + + If the given signal name or number is not available on the current + platform, ValueError is raised. + """ + if isinstance(signal, basestring): + signum = getattr(_signal, signal, None) + if signum is None: + raise ValueError("No such signal: %r" % signal) + signame = signal + else: + try: + signame = self.signals[signal] + except KeyError: + raise ValueError("No such signal: %r" % signal) + signum = signal + + prev = _signal.signal(signum, self._handle_signal) + self._previous_handlers[signum] = prev + + if listener is not None: + self.bus.log("Listening for %s." % signame) + self.bus.subscribe(signame, listener) + + def _handle_signal(self, signum=None, frame=None): + """Python signal handler (self.set_handler subscribes it for you).""" + signame = self.signals[signum] + self.bus.log("Caught signal %s." % signame) + self.bus.publish(signame) + + def handle_SIGHUP(self): + """Restart if daemonized, else exit.""" + if os.isatty(sys.stdin.fileno()): + # not daemonized (may be foreground or background) + self.bus.log("SIGHUP caught but not daemonized. Exiting.") + self.bus.exit() + else: + self.bus.log("SIGHUP caught while daemonized. Restarting.") + self.bus.restart() + + +try: + import pwd + import grp +except ImportError: + pwd, grp = None, None + + +class DropPrivileges(SimplePlugin): + + """Drop privileges. uid/gid arguments not available on Windows. + + Special thanks to `Gavin Baker `_ + """ + + def __init__(self, bus, umask=None, uid=None, gid=None): + SimplePlugin.__init__(self, bus) + self.finalized = False + self.uid = uid + self.gid = gid + self.umask = umask + + def _get_uid(self): + return self._uid + + def _set_uid(self, val): + if val is not None: + if pwd is None: + self.bus.log("pwd module not available; ignoring uid.", + level=30) + val = None + elif isinstance(val, basestring): + val = pwd.getpwnam(val)[2] + self._uid = val + uid = property(_get_uid, _set_uid, + doc="The uid under which to run. Availability: Unix.") + + def _get_gid(self): + return self._gid + + def _set_gid(self, val): + if val is not None: + if grp is None: + self.bus.log("grp module not available; ignoring gid.", + level=30) + val = None + elif isinstance(val, basestring): + val = grp.getgrnam(val)[2] + self._gid = val + gid = property(_get_gid, _set_gid, + doc="The gid under which to run. Availability: Unix.") + + def _get_umask(self): + return self._umask + + def _set_umask(self, val): + if val is not None: + try: + os.umask + except AttributeError: + self.bus.log("umask function not available; ignoring umask.", + level=30) + val = None + self._umask = val + umask = property( + _get_umask, + _set_umask, + doc="""The default permission mode for newly created files and + directories. + + Usually expressed in octal format, for example, ``0644``. + Availability: Unix, Windows. + """) + + def start(self): + # uid/gid + def current_ids(): + """Return the current (uid, gid) if available.""" + name, group = None, None + if pwd: + name = pwd.getpwuid(os.getuid())[0] + if grp: + group = grp.getgrgid(os.getgid())[0] + return name, group + + if self.finalized: + if not (self.uid is None and self.gid is None): + self.bus.log('Already running as uid: %r gid: %r' % + current_ids()) + else: + if self.uid is None and self.gid is None: + if pwd or grp: + self.bus.log('uid/gid not set', level=30) + else: + self.bus.log('Started as uid: %r gid: %r' % current_ids()) + if self.gid is not None: + os.setgid(self.gid) + os.setgroups([]) + if self.uid is not None: + os.setuid(self.uid) + self.bus.log('Running as uid: %r gid: %r' % current_ids()) + + # umask + if self.finalized: + if self.umask is not None: + self.bus.log('umask already set to: %03o' % self.umask) + else: + if self.umask is None: + self.bus.log('umask not set', level=30) + else: + old_umask = os.umask(self.umask) + self.bus.log('umask old: %03o, new: %03o' % + (old_umask, self.umask)) + + self.finalized = True + # This is slightly higher than the priority for server.start + # in order to facilitate the most common use: starting on a low + # port (which requires root) and then dropping to another user. + start.priority = 77 + + +class Daemonizer(SimplePlugin): + + """Daemonize the running script. + + Use this with a Web Site Process Bus via:: + + Daemonizer(bus).subscribe() + + When this component finishes, the process is completely decoupled from + the parent environment. Please note that when this component is used, + the return code from the parent process will still be 0 if a startup + error occurs in the forked children. Errors in the initial daemonizing + process still return proper exit codes. Therefore, if you use this + plugin to daemonize, don't use the return code as an accurate indicator + of whether the process fully started. In fact, that return code only + indicates if the process succesfully finished the first fork. + """ + + def __init__(self, bus, stdin='/dev/null', stdout='/dev/null', + stderr='/dev/null'): + SimplePlugin.__init__(self, bus) + self.stdin = stdin + self.stdout = stdout + self.stderr = stderr + self.finalized = False + + def start(self): + if self.finalized: + self.bus.log('Already deamonized.') + + # forking has issues with threads: + # http://www.opengroup.org/onlinepubs/000095399/functions/fork.html + # "The general problem with making fork() work in a multi-threaded + # world is what to do with all of the threads..." + # So we check for active threads: + if threading.activeCount() != 1: + self.bus.log('There are %r active threads. ' + 'Daemonizing now may cause strange failures.' % + threading.enumerate(), level=30) + + # See http://www.erlenstar.demon.co.uk/unix/faq_2.html#SEC16 + # (or http://www.faqs.org/faqs/unix-faq/programmer/faq/ section 1.7) + # and http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66012 + + # Finish up with the current stdout/stderr + sys.stdout.flush() + sys.stderr.flush() + + # Do first fork. + try: + pid = os.fork() + if pid == 0: + # This is the child process. Continue. + pass + else: + # This is the first parent. Exit, now that we've forked. + self.bus.log('Forking once.') + os._exit(0) + except OSError: + # Python raises OSError rather than returning negative numbers. + exc = sys.exc_info()[1] + sys.exit("%s: fork #1 failed: (%d) %s\n" + % (sys.argv[0], exc.errno, exc.strerror)) + + os.setsid() + + # Do second fork + try: + pid = os.fork() + if pid > 0: + self.bus.log('Forking twice.') + os._exit(0) # Exit second parent + except OSError: + exc = sys.exc_info()[1] + sys.exit("%s: fork #2 failed: (%d) %s\n" + % (sys.argv[0], exc.errno, exc.strerror)) + + os.chdir("/") + os.umask(0) + + si = open(self.stdin, "r") + so = open(self.stdout, "a+") + se = open(self.stderr, "a+") + + # os.dup2(fd, fd2) will close fd2 if necessary, + # so we don't explicitly close stdin/out/err. + # See http://docs.python.org/lib/os-fd-ops.html + os.dup2(si.fileno(), sys.stdin.fileno()) + os.dup2(so.fileno(), sys.stdout.fileno()) + os.dup2(se.fileno(), sys.stderr.fileno()) + + self.bus.log('Daemonized to PID: %s' % os.getpid()) + self.finalized = True + start.priority = 65 + + +class PIDFile(SimplePlugin): + + """Maintain a PID file via a WSPBus.""" + + def __init__(self, bus, pidfile): + SimplePlugin.__init__(self, bus) + self.pidfile = pidfile + self.finalized = False + + def start(self): + pid = os.getpid() + if self.finalized: + self.bus.log('PID %r already written to %r.' % (pid, self.pidfile)) + else: + open(self.pidfile, "wb").write(ntob("%s\n" % pid, 'utf8')) + self.bus.log('PID %r written to %r.' % (pid, self.pidfile)) + self.finalized = True + start.priority = 70 + + def exit(self): + try: + os.remove(self.pidfile) + self.bus.log('PID file removed: %r.' % self.pidfile) + except (KeyboardInterrupt, SystemExit): + raise + except: + pass + + +class PerpetualTimer(Timer): + + """A responsive subclass of threading.Timer whose run() method repeats. + + Use this timer only when you really need a very interruptible timer; + this checks its 'finished' condition up to 20 times a second, which can + results in pretty high CPU usage + """ + + def __init__(self, *args, **kwargs): + "Override parent constructor to allow 'bus' to be provided." + self.bus = kwargs.pop('bus', None) + super(PerpetualTimer, self).__init__(*args, **kwargs) + + def run(self): + while True: + self.finished.wait(self.interval) + if self.finished.isSet(): + return + try: + self.function(*self.args, **self.kwargs) + except Exception: + if self.bus: + self.bus.log( + "Error in perpetual timer thread function %r." % + self.function, level=40, traceback=True) + # Quit on first error to avoid massive logs. + raise + + +class BackgroundTask(SetDaemonProperty, threading.Thread): + + """A subclass of threading.Thread whose run() method repeats. + + Use this class for most repeating tasks. It uses time.sleep() to wait + for each interval, which isn't very responsive; that is, even if you call + self.cancel(), you'll have to wait until the sleep() call finishes before + the thread stops. To compensate, it defaults to being daemonic, which means + it won't delay stopping the whole process. + """ + + def __init__(self, interval, function, args=[], kwargs={}, bus=None): + threading.Thread.__init__(self) + self.interval = interval + self.function = function + self.args = args + self.kwargs = kwargs + self.running = False + self.bus = bus + + # default to daemonic + self.daemon = True + + def cancel(self): + self.running = False + + def run(self): + self.running = True + while self.running: + time.sleep(self.interval) + if not self.running: + return + try: + self.function(*self.args, **self.kwargs) + except Exception: + if self.bus: + self.bus.log("Error in background task thread function %r." + % self.function, level=40, traceback=True) + # Quit on first error to avoid massive logs. + raise + + +class Monitor(SimplePlugin): + + """WSPBus listener to periodically run a callback in its own thread.""" + + callback = None + """The function to call at intervals.""" + + frequency = 60 + """The time in seconds between callback runs.""" + + thread = None + """A :class:`BackgroundTask` + thread. + """ + + def __init__(self, bus, callback, frequency=60, name=None): + SimplePlugin.__init__(self, bus) + self.callback = callback + self.frequency = frequency + self.thread = None + self.name = name + + def start(self): + """Start our callback in its own background thread.""" + if self.frequency > 0: + threadname = self.name or self.__class__.__name__ + if self.thread is None: + self.thread = BackgroundTask(self.frequency, self.callback, + bus=self.bus) + self.thread.setName(threadname) + self.thread.start() + self.bus.log("Started monitor thread %r." % threadname) + else: + self.bus.log("Monitor thread %r already started." % threadname) + start.priority = 70 + + def stop(self): + """Stop our callback's background task thread.""" + if self.thread is None: + self.bus.log("No thread running for %s." % + self.name or self.__class__.__name__) + else: + if self.thread is not threading.currentThread(): + name = self.thread.getName() + self.thread.cancel() + if not get_daemon(self.thread): + self.bus.log("Joining %r" % name) + self.thread.join() + self.bus.log("Stopped thread %r." % name) + self.thread = None + + def graceful(self): + """Stop the callback's background task thread and restart it.""" + self.stop() + self.start() + + +class Autoreloader(Monitor): + + """Monitor which re-executes the process when files change. + + This :ref:`plugin` restarts the process (via :func:`os.execv`) + if any of the files it monitors change (or is deleted). By default, the + autoreloader monitors all imported modules; you can add to the + set by adding to ``autoreload.files``:: + + cherrypy.engine.autoreload.files.add(myFile) + + If there are imported files you do *not* wish to monitor, you can + adjust the ``match`` attribute, a regular expression. For example, + to stop monitoring cherrypy itself:: + + cherrypy.engine.autoreload.match = r'^(?!cherrypy).+' + + Like all :class:`Monitor` plugins, + the autoreload plugin takes a ``frequency`` argument. The default is + 1 second; that is, the autoreloader will examine files once each second. + """ + + files = None + """The set of files to poll for modifications.""" + + frequency = 1 + """The interval in seconds at which to poll for modified files.""" + + match = '.*' + """A regular expression by which to match filenames.""" + + def __init__(self, bus, frequency=1, match='.*'): + self.mtimes = {} + self.files = set() + self.match = match + Monitor.__init__(self, bus, self.run, frequency) + + def start(self): + """Start our own background task thread for self.run.""" + if self.thread is None: + self.mtimes = {} + Monitor.start(self) + start.priority = 70 + + def sysfiles(self): + """Return a Set of sys.modules filenames to monitor.""" + files = set() + for k, m in list(sys.modules.items()): + if re.match(self.match, k): + if ( + hasattr(m, '__loader__') and + hasattr(m.__loader__, 'archive') + ): + f = m.__loader__.archive + else: + f = getattr(m, '__file__', None) + if f is not None and not os.path.isabs(f): + # ensure absolute paths so a os.chdir() in the app + # doesn't break me + f = os.path.normpath( + os.path.join(_module__file__base, f)) + files.add(f) + return files + + def run(self): + """Reload the process if registered files have been modified.""" + for filename in self.sysfiles() | self.files: + if filename: + if filename.endswith('.pyc'): + filename = filename[:-1] + + oldtime = self.mtimes.get(filename, 0) + if oldtime is None: + # Module with no .py file. Skip it. + continue + + try: + mtime = os.stat(filename).st_mtime + except OSError: + # Either a module with no .py file, or it's been deleted. + mtime = None + + if filename not in self.mtimes: + # If a module has no .py file, this will be None. + self.mtimes[filename] = mtime + else: + if mtime is None or mtime > oldtime: + # The file has been deleted or modified. + self.bus.log("Restarting because %s changed." % + filename) + self.thread.cancel() + self.bus.log("Stopped thread %r." % + self.thread.getName()) + self.bus.restart() + return + + +class ThreadManager(SimplePlugin): + + """Manager for HTTP request threads. + + If you have control over thread creation and destruction, publish to + the 'acquire_thread' and 'release_thread' channels (for each thread). + This will register/unregister the current thread and publish to + 'start_thread' and 'stop_thread' listeners in the bus as needed. + + If threads are created and destroyed by code you do not control + (e.g., Apache), then, at the beginning of every HTTP request, + publish to 'acquire_thread' only. You should not publish to + 'release_thread' in this case, since you do not know whether + the thread will be re-used or not. The bus will call + 'stop_thread' listeners for you when it stops. + """ + + threads = None + """A map of {thread ident: index number} pairs.""" + + def __init__(self, bus): + self.threads = {} + SimplePlugin.__init__(self, bus) + self.bus.listeners.setdefault('acquire_thread', set()) + self.bus.listeners.setdefault('start_thread', set()) + self.bus.listeners.setdefault('release_thread', set()) + self.bus.listeners.setdefault('stop_thread', set()) + + def acquire_thread(self): + """Run 'start_thread' listeners for the current thread. + + If the current thread has already been seen, any 'start_thread' + listeners will not be run again. + """ + thread_ident = get_thread_ident() + if thread_ident not in self.threads: + # We can't just use get_ident as the thread ID + # because some platforms reuse thread ID's. + i = len(self.threads) + 1 + self.threads[thread_ident] = i + self.bus.publish('start_thread', i) + + def release_thread(self): + """Release the current thread and run 'stop_thread' listeners.""" + thread_ident = get_thread_ident() + i = self.threads.pop(thread_ident, None) + if i is not None: + self.bus.publish('stop_thread', i) + + def stop(self): + """Release all threads and run all 'stop_thread' listeners.""" + for thread_ident, i in self.threads.items(): + self.bus.publish('stop_thread', i) + self.threads.clear() + graceful = stop diff --git a/lib/cherrypy/process/servers.py b/lib/cherrypy/process/servers.py new file mode 100644 index 00000000..fef37f77 --- /dev/null +++ b/lib/cherrypy/process/servers.py @@ -0,0 +1,465 @@ +""" +Starting in CherryPy 3.1, cherrypy.server is implemented as an +:ref:`Engine Plugin`. It's an instance of +:class:`cherrypy._cpserver.Server`, which is a subclass of +:class:`cherrypy.process.servers.ServerAdapter`. The ``ServerAdapter`` class +is designed to control other servers, as well. + +Multiple servers/ports +====================== + +If you need to start more than one HTTP server (to serve on multiple ports, or +protocols, etc.), you can manually register each one and then start them all +with engine.start:: + + s1 = ServerAdapter(cherrypy.engine, MyWSGIServer(host='0.0.0.0', port=80)) + s2 = ServerAdapter(cherrypy.engine, + another.HTTPServer(host='127.0.0.1', + SSL=True)) + s1.subscribe() + s2.subscribe() + cherrypy.engine.start() + +.. index:: SCGI + +FastCGI/SCGI +============ + +There are also Flup\ **F**\ CGIServer and Flup\ **S**\ CGIServer classes in +:mod:`cherrypy.process.servers`. To start an fcgi server, for example, +wrap an instance of it in a ServerAdapter:: + + addr = ('0.0.0.0', 4000) + f = servers.FlupFCGIServer(application=cherrypy.tree, bindAddress=addr) + s = servers.ServerAdapter(cherrypy.engine, httpserver=f, bind_addr=addr) + s.subscribe() + +The :doc:`cherryd` startup script will do the above for +you via its `-f` flag. +Note that you need to download and install `flup `_ +yourself, whether you use ``cherryd`` or not. + +.. _fastcgi: +.. index:: FastCGI + +FastCGI +------- + +A very simple setup lets your cherry run with FastCGI. +You just need the flup library, +plus a running Apache server (with ``mod_fastcgi``) or lighttpd server. + +CherryPy code +^^^^^^^^^^^^^ + +hello.py:: + + #!/usr/bin/python + import cherrypy + + class HelloWorld: + \"""Sample request handler class.\""" + def index(self): + return "Hello world!" + index.exposed = True + + cherrypy.tree.mount(HelloWorld()) + # CherryPy autoreload must be disabled for the flup server to work + cherrypy.config.update({'engine.autoreload.on':False}) + +Then run :doc:`/deployguide/cherryd` with the '-f' arg:: + + cherryd -c -d -f -i hello.py + +Apache +^^^^^^ + +At the top level in httpd.conf:: + + FastCgiIpcDir /tmp + FastCgiServer /path/to/cherry.fcgi -idle-timeout 120 -processes 4 + +And inside the relevant VirtualHost section:: + + # FastCGI config + AddHandler fastcgi-script .fcgi + ScriptAliasMatch (.*$) /path/to/cherry.fcgi$1 + +Lighttpd +^^^^^^^^ + +For `Lighttpd `_ you can follow these +instructions. Within ``lighttpd.conf`` make sure ``mod_fastcgi`` is +active within ``server.modules``. Then, within your ``$HTTP["host"]`` +directive, configure your fastcgi script like the following:: + + $HTTP["url"] =~ "" { + fastcgi.server = ( + "/" => ( + "script.fcgi" => ( + "bin-path" => "/path/to/your/script.fcgi", + "socket" => "/tmp/script.sock", + "check-local" => "disable", + "disable-time" => 1, + "min-procs" => 1, + "max-procs" => 1, # adjust as needed + ), + ), + ) + } # end of $HTTP["url"] =~ "^/" + +Please see `Lighttpd FastCGI Docs +`_ for +an explanation of the possible configuration options. +""" + +import sys +import time +import warnings + + +class ServerAdapter(object): + + """Adapter for an HTTP server. + + If you need to start more than one HTTP server (to serve on multiple + ports, or protocols, etc.), you can manually register each one and then + start them all with bus.start: + + s1 = ServerAdapter(bus, MyWSGIServer(host='0.0.0.0', port=80)) + s2 = ServerAdapter(bus, another.HTTPServer(host='127.0.0.1', SSL=True)) + s1.subscribe() + s2.subscribe() + bus.start() + """ + + def __init__(self, bus, httpserver=None, bind_addr=None): + self.bus = bus + self.httpserver = httpserver + self.bind_addr = bind_addr + self.interrupt = None + self.running = False + + def subscribe(self): + self.bus.subscribe('start', self.start) + self.bus.subscribe('stop', self.stop) + + def unsubscribe(self): + self.bus.unsubscribe('start', self.start) + self.bus.unsubscribe('stop', self.stop) + + def start(self): + """Start the HTTP server.""" + if self.bind_addr is None: + on_what = "unknown interface (dynamic?)" + elif isinstance(self.bind_addr, tuple): + on_what = self._get_base() + else: + on_what = "socket file: %s" % self.bind_addr + + if self.running: + self.bus.log("Already serving on %s" % on_what) + return + + self.interrupt = None + if not self.httpserver: + raise ValueError("No HTTP server has been created.") + + # Start the httpserver in a new thread. + if isinstance(self.bind_addr, tuple): + wait_for_free_port(*self.bind_addr) + + import threading + t = threading.Thread(target=self._start_http_thread) + t.setName("HTTPServer " + t.getName()) + t.start() + + self.wait() + self.running = True + self.bus.log("Serving on %s" % on_what) + start.priority = 75 + + def _get_base(self): + if not self.httpserver: + return '' + host, port = self.bind_addr + if getattr(self.httpserver, 'ssl_certificate', None): + scheme = "https" + if port != 443: + host += ":%s" % port + else: + scheme = "http" + if port != 80: + host += ":%s" % port + + return "%s://%s" % (scheme, host) + + def _start_http_thread(self): + """HTTP servers MUST be running in new threads, so that the + main thread persists to receive KeyboardInterrupt's. If an + exception is raised in the httpserver's thread then it's + trapped here, and the bus (and therefore our httpserver) + are shut down. + """ + try: + self.httpserver.start() + except KeyboardInterrupt: + self.bus.log(" hit: shutting down HTTP server") + self.interrupt = sys.exc_info()[1] + self.bus.exit() + except SystemExit: + self.bus.log("SystemExit raised: shutting down HTTP server") + self.interrupt = sys.exc_info()[1] + self.bus.exit() + raise + except: + self.interrupt = sys.exc_info()[1] + self.bus.log("Error in HTTP server: shutting down", + traceback=True, level=40) + self.bus.exit() + raise + + def wait(self): + """Wait until the HTTP server is ready to receive requests.""" + while not getattr(self.httpserver, "ready", False): + if self.interrupt: + raise self.interrupt + time.sleep(.1) + + # Wait for port to be occupied + if isinstance(self.bind_addr, tuple): + host, port = self.bind_addr + wait_for_occupied_port(host, port) + + def stop(self): + """Stop the HTTP server.""" + if self.running: + # stop() MUST block until the server is *truly* stopped. + self.httpserver.stop() + # Wait for the socket to be truly freed. + if isinstance(self.bind_addr, tuple): + wait_for_free_port(*self.bind_addr) + self.running = False + self.bus.log("HTTP Server %s shut down" % self.httpserver) + else: + self.bus.log("HTTP Server %s already shut down" % self.httpserver) + stop.priority = 25 + + def restart(self): + """Restart the HTTP server.""" + self.stop() + self.start() + + +class FlupCGIServer(object): + + """Adapter for a flup.server.cgi.WSGIServer.""" + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.ready = False + + def start(self): + """Start the CGI server.""" + # We have to instantiate the server class here because its __init__ + # starts a threadpool. If we do it too early, daemonize won't work. + from flup.server.cgi import WSGIServer + + self.cgiserver = WSGIServer(*self.args, **self.kwargs) + self.ready = True + self.cgiserver.run() + + def stop(self): + """Stop the HTTP server.""" + self.ready = False + + +class FlupFCGIServer(object): + + """Adapter for a flup.server.fcgi.WSGIServer.""" + + def __init__(self, *args, **kwargs): + if kwargs.get('bindAddress', None) is None: + import socket + if not hasattr(socket, 'fromfd'): + raise ValueError( + 'Dynamic FCGI server not available on this platform. ' + 'You must use a static or external one by providing a ' + 'legal bindAddress.') + self.args = args + self.kwargs = kwargs + self.ready = False + + def start(self): + """Start the FCGI server.""" + # We have to instantiate the server class here because its __init__ + # starts a threadpool. If we do it too early, daemonize won't work. + from flup.server.fcgi import WSGIServer + self.fcgiserver = WSGIServer(*self.args, **self.kwargs) + # TODO: report this bug upstream to flup. + # If we don't set _oldSIGs on Windows, we get: + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 108, in run + # self._restoreSignalHandlers() + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 156, in _restoreSignalHandlers + # for signum,handler in self._oldSIGs: + # AttributeError: 'WSGIServer' object has no attribute '_oldSIGs' + self.fcgiserver._installSignalHandlers = lambda: None + self.fcgiserver._oldSIGs = [] + self.ready = True + self.fcgiserver.run() + + def stop(self): + """Stop the HTTP server.""" + # Forcibly stop the fcgi server main event loop. + self.fcgiserver._keepGoing = False + # Force all worker threads to die off. + self.fcgiserver._threadPool.maxSpare = ( + self.fcgiserver._threadPool._idleCount) + self.ready = False + + +class FlupSCGIServer(object): + + """Adapter for a flup.server.scgi.WSGIServer.""" + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.ready = False + + def start(self): + """Start the SCGI server.""" + # We have to instantiate the server class here because its __init__ + # starts a threadpool. If we do it too early, daemonize won't work. + from flup.server.scgi import WSGIServer + self.scgiserver = WSGIServer(*self.args, **self.kwargs) + # TODO: report this bug upstream to flup. + # If we don't set _oldSIGs on Windows, we get: + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 108, in run + # self._restoreSignalHandlers() + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 156, in _restoreSignalHandlers + # for signum,handler in self._oldSIGs: + # AttributeError: 'WSGIServer' object has no attribute '_oldSIGs' + self.scgiserver._installSignalHandlers = lambda: None + self.scgiserver._oldSIGs = [] + self.ready = True + self.scgiserver.run() + + def stop(self): + """Stop the HTTP server.""" + self.ready = False + # Forcibly stop the scgi server main event loop. + self.scgiserver._keepGoing = False + # Force all worker threads to die off. + self.scgiserver._threadPool.maxSpare = 0 + + +def client_host(server_host): + """Return the host on which a client can connect to the given listener.""" + if server_host == '0.0.0.0': + # 0.0.0.0 is INADDR_ANY, which should answer on localhost. + return '127.0.0.1' + if server_host in ('::', '::0', '::0.0.0.0'): + # :: is IN6ADDR_ANY, which should answer on localhost. + # ::0 and ::0.0.0.0 are non-canonical but common + # ways to write IN6ADDR_ANY. + return '::1' + return server_host + + +def check_port(host, port, timeout=1.0): + """Raise an error if the given port is not free on the given host.""" + if not host: + raise ValueError("Host values of '' or None are not allowed.") + host = client_host(host) + port = int(port) + + import socket + + # AF_INET or AF_INET6 socket + # Get the correct address family for our host (allows IPv6 addresses) + try: + info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM) + except socket.gaierror: + if ':' in host: + info = [( + socket.AF_INET6, socket.SOCK_STREAM, 0, "", (host, port, 0, 0) + )] + else: + info = [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (host, port))] + + for res in info: + af, socktype, proto, canonname, sa = res + s = None + try: + s = socket.socket(af, socktype, proto) + # See http://groups.google.com/group/cherrypy-users/ + # browse_frm/thread/bbfe5eb39c904fe0 + s.settimeout(timeout) + s.connect((host, port)) + s.close() + except socket.error: + if s: + s.close() + else: + raise IOError("Port %s is in use on %s; perhaps the previous " + "httpserver did not shut down properly." % + (repr(port), repr(host))) + + +# Feel free to increase these defaults on slow systems: +free_port_timeout = 0.1 +occupied_port_timeout = 1.0 + + +def wait_for_free_port(host, port, timeout=None): + """Wait for the specified port to become free (drop requests).""" + if not host: + raise ValueError("Host values of '' or None are not allowed.") + if timeout is None: + timeout = free_port_timeout + + for trial in range(50): + try: + # we are expecting a free port, so reduce the timeout + check_port(host, port, timeout=timeout) + except IOError: + # Give the old server thread time to free the port. + time.sleep(timeout) + else: + return + + raise IOError("Port %r not free on %r" % (port, host)) + + +def wait_for_occupied_port(host, port, timeout=None): + """Wait for the specified port to become active (receive requests).""" + if not host: + raise ValueError("Host values of '' or None are not allowed.") + if timeout is None: + timeout = occupied_port_timeout + + for trial in range(50): + try: + check_port(host, port, timeout=timeout) + except IOError: + # port is occupied + return + else: + time.sleep(timeout) + + if host == client_host(host): + raise IOError("Port %r not bound on %r" % (port, host)) + + # On systems where a loopback interface is not available and the + # server is bound to all interfaces, it's difficult to determine + # whether the server is in fact occupying the port. In this case, + # just issue a warning and move on. See issue #1100. + msg = "Unable to verify that the server is bound on %r" % port + warnings.warn(msg) diff --git a/lib/cherrypy/process/win32.py b/lib/cherrypy/process/win32.py new file mode 100644 index 00000000..4afd3f14 --- /dev/null +++ b/lib/cherrypy/process/win32.py @@ -0,0 +1,180 @@ +"""Windows service. Requires pywin32.""" + +import os +import win32api +import win32con +import win32event +import win32service +import win32serviceutil + +from cherrypy.process import wspbus, plugins + + +class ConsoleCtrlHandler(plugins.SimplePlugin): + + """A WSPBus plugin for handling Win32 console events (like Ctrl-C).""" + + def __init__(self, bus): + self.is_set = False + plugins.SimplePlugin.__init__(self, bus) + + def start(self): + if self.is_set: + self.bus.log('Handler for console events already set.', level=40) + return + + result = win32api.SetConsoleCtrlHandler(self.handle, 1) + if result == 0: + self.bus.log('Could not SetConsoleCtrlHandler (error %r)' % + win32api.GetLastError(), level=40) + else: + self.bus.log('Set handler for console events.', level=40) + self.is_set = True + + def stop(self): + if not self.is_set: + self.bus.log('Handler for console events already off.', level=40) + return + + try: + result = win32api.SetConsoleCtrlHandler(self.handle, 0) + except ValueError: + # "ValueError: The object has not been registered" + result = 1 + + if result == 0: + self.bus.log('Could not remove SetConsoleCtrlHandler (error %r)' % + win32api.GetLastError(), level=40) + else: + self.bus.log('Removed handler for console events.', level=40) + self.is_set = False + + def handle(self, event): + """Handle console control events (like Ctrl-C).""" + if event in (win32con.CTRL_C_EVENT, win32con.CTRL_LOGOFF_EVENT, + win32con.CTRL_BREAK_EVENT, win32con.CTRL_SHUTDOWN_EVENT, + win32con.CTRL_CLOSE_EVENT): + self.bus.log('Console event %s: shutting down bus' % event) + + # Remove self immediately so repeated Ctrl-C doesn't re-call it. + try: + self.stop() + except ValueError: + pass + + self.bus.exit() + # 'First to return True stops the calls' + return 1 + return 0 + + +class Win32Bus(wspbus.Bus): + + """A Web Site Process Bus implementation for Win32. + + Instead of time.sleep, this bus blocks using native win32event objects. + """ + + def __init__(self): + self.events = {} + wspbus.Bus.__init__(self) + + def _get_state_event(self, state): + """Return a win32event for the given state (creating it if needed).""" + try: + return self.events[state] + except KeyError: + event = win32event.CreateEvent(None, 0, 0, + "WSPBus %s Event (pid=%r)" % + (state.name, os.getpid())) + self.events[state] = event + return event + + def _get_state(self): + return self._state + + def _set_state(self, value): + self._state = value + event = self._get_state_event(value) + win32event.PulseEvent(event) + state = property(_get_state, _set_state) + + def wait(self, state, interval=0.1, channel=None): + """Wait for the given state(s), KeyboardInterrupt or SystemExit. + + Since this class uses native win32event objects, the interval + argument is ignored. + """ + if isinstance(state, (tuple, list)): + # Don't wait for an event that beat us to the punch ;) + if self.state not in state: + events = tuple([self._get_state_event(s) for s in state]) + win32event.WaitForMultipleObjects( + events, 0, win32event.INFINITE) + else: + # Don't wait for an event that beat us to the punch ;) + if self.state != state: + event = self._get_state_event(state) + win32event.WaitForSingleObject(event, win32event.INFINITE) + + +class _ControlCodes(dict): + + """Control codes used to "signal" a service via ControlService. + + User-defined control codes are in the range 128-255. We generally use + the standard Python value for the Linux signal and add 128. Example: + + >>> signal.SIGUSR1 + 10 + control_codes['graceful'] = 128 + 10 + """ + + def key_for(self, obj): + """For the given value, return its corresponding key.""" + for key, val in self.items(): + if val is obj: + return key + raise ValueError("The given object could not be found: %r" % obj) + +control_codes = _ControlCodes({'graceful': 138}) + + +def signal_child(service, command): + if command == 'stop': + win32serviceutil.StopService(service) + elif command == 'restart': + win32serviceutil.RestartService(service) + else: + win32serviceutil.ControlService(service, control_codes[command]) + + +class PyWebService(win32serviceutil.ServiceFramework): + + """Python Web Service.""" + + _svc_name_ = "Python Web Service" + _svc_display_name_ = "Python Web Service" + _svc_deps_ = None # sequence of service names on which this depends + _exe_name_ = "pywebsvc" + _exe_args_ = None # Default to no arguments + + # Only exists on Windows 2000 or later, ignored on windows NT + _svc_description_ = "Python Web Service" + + def SvcDoRun(self): + from cherrypy import process + process.bus.start() + process.bus.block() + + def SvcStop(self): + from cherrypy import process + self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING) + process.bus.exit() + + def SvcOther(self, control): + process.bus.publish(control_codes.key_for(control)) + + +if __name__ == '__main__': + win32serviceutil.HandleCommandLine(PyWebService) diff --git a/lib/cherrypy/process/wspbus.py b/lib/cherrypy/process/wspbus.py new file mode 100644 index 00000000..5409d038 --- /dev/null +++ b/lib/cherrypy/process/wspbus.py @@ -0,0 +1,448 @@ +"""An implementation of the Web Site Process Bus. + +This module is completely standalone, depending only on the stdlib. + +Web Site Process Bus +-------------------- + +A Bus object is used to contain and manage site-wide behavior: +daemonization, HTTP server start/stop, process reload, signal handling, +drop privileges, PID file management, logging for all of these, +and many more. + +In addition, a Bus object provides a place for each web framework +to register code that runs in response to site-wide events (like +process start and stop), or which controls or otherwise interacts with +the site-wide components mentioned above. For example, a framework which +uses file-based templates would add known template filenames to an +autoreload component. + +Ideally, a Bus object will be flexible enough to be useful in a variety +of invocation scenarios: + + 1. The deployer starts a site from the command line via a + framework-neutral deployment script; applications from multiple frameworks + are mixed in a single site. Command-line arguments and configuration + files are used to define site-wide components such as the HTTP server, + WSGI component graph, autoreload behavior, signal handling, etc. + 2. The deployer starts a site via some other process, such as Apache; + applications from multiple frameworks are mixed in a single site. + Autoreload and signal handling (from Python at least) are disabled. + 3. The deployer starts a site via a framework-specific mechanism; + for example, when running tests, exploring tutorials, or deploying + single applications from a single framework. The framework controls + which site-wide components are enabled as it sees fit. + +The Bus object in this package uses topic-based publish-subscribe +messaging to accomplish all this. A few topic channels are built in +('start', 'stop', 'exit', 'graceful', 'log', and 'main'). Frameworks and +site containers are free to define their own. If a message is sent to a +channel that has not been defined or has no listeners, there is no effect. + +In general, there should only ever be a single Bus object per process. +Frameworks and site containers share a single Bus object by publishing +messages and subscribing listeners. + +The Bus object works as a finite state machine which models the current +state of the process. Bus methods move it from one state to another; +those methods then publish to subscribed listeners on the channel for +the new state.:: + + O + | + V + STOPPING --> STOPPED --> EXITING -> X + A A | + | \___ | + | \ | + | V V + STARTED <-- STARTING + +""" + +import atexit +import os +import sys +import threading +import time +import traceback as _traceback +import warnings + +from cherrypy._cpcompat import set + +# Here I save the value of os.getcwd(), which, if I am imported early enough, +# will be the directory from which the startup script was run. This is needed +# by _do_execv(), to change back to the original directory before execv()ing a +# new process. This is a defense against the application having changed the +# current working directory (which could make sys.executable "not found" if +# sys.executable is a relative-path, and/or cause other problems). +_startup_cwd = os.getcwd() + + +class ChannelFailures(Exception): + + """Exception raised when errors occur in a listener during Bus.publish(). + """ + delimiter = '\n' + + def __init__(self, *args, **kwargs): + # Don't use 'super' here; Exceptions are old-style in Py2.4 + # See https://bitbucket.org/cherrypy/cherrypy/issue/959 + Exception.__init__(self, *args, **kwargs) + self._exceptions = list() + + def handle_exception(self): + """Append the current exception to self.""" + self._exceptions.append(sys.exc_info()[1]) + + def get_instances(self): + """Return a list of seen exception instances.""" + return self._exceptions[:] + + def __str__(self): + exception_strings = map(repr, self.get_instances()) + return self.delimiter.join(exception_strings) + + __repr__ = __str__ + + def __bool__(self): + return bool(self._exceptions) + __nonzero__ = __bool__ + +# Use a flag to indicate the state of the bus. + + +class _StateEnum(object): + + class State(object): + name = None + + def __repr__(self): + return "states.%s" % self.name + + def __setattr__(self, key, value): + if isinstance(value, self.State): + value.name = key + object.__setattr__(self, key, value) +states = _StateEnum() +states.STOPPED = states.State() +states.STARTING = states.State() +states.STARTED = states.State() +states.STOPPING = states.State() +states.EXITING = states.State() + + +try: + import fcntl +except ImportError: + max_files = 0 +else: + try: + max_files = os.sysconf('SC_OPEN_MAX') + except AttributeError: + max_files = 1024 + + +class Bus(object): + + """Process state-machine and messenger for HTTP site deployment. + + All listeners for a given channel are guaranteed to be called even + if others at the same channel fail. Each failure is logged, but + execution proceeds on to the next listener. The only way to stop all + processing from inside a listener is to raise SystemExit and stop the + whole server. + """ + + states = states + state = states.STOPPED + execv = False + max_cloexec_files = max_files + + def __init__(self): + self.execv = False + self.state = states.STOPPED + self.listeners = dict( + [(channel, set()) for channel + in ('start', 'stop', 'exit', 'graceful', 'log', 'main')]) + self._priorities = {} + + def subscribe(self, channel, callback, priority=None): + """Add the given callback at the given channel (if not present).""" + if channel not in self.listeners: + self.listeners[channel] = set() + self.listeners[channel].add(callback) + + if priority is None: + priority = getattr(callback, 'priority', 50) + self._priorities[(channel, callback)] = priority + + def unsubscribe(self, channel, callback): + """Discard the given callback (if present).""" + listeners = self.listeners.get(channel) + if listeners and callback in listeners: + listeners.discard(callback) + del self._priorities[(channel, callback)] + + def publish(self, channel, *args, **kwargs): + """Return output of all subscribers for the given channel.""" + if channel not in self.listeners: + return [] + + exc = ChannelFailures() + output = [] + + items = [(self._priorities[(channel, listener)], listener) + for listener in self.listeners[channel]] + try: + items.sort(key=lambda item: item[0]) + except TypeError: + # Python 2.3 had no 'key' arg, but that doesn't matter + # since it could sort dissimilar types just fine. + items.sort() + for priority, listener in items: + try: + output.append(listener(*args, **kwargs)) + except KeyboardInterrupt: + raise + except SystemExit: + e = sys.exc_info()[1] + # If we have previous errors ensure the exit code is non-zero + if exc and e.code == 0: + e.code = 1 + raise + except: + exc.handle_exception() + if channel == 'log': + # Assume any further messages to 'log' will fail. + pass + else: + self.log("Error in %r listener %r" % (channel, listener), + level=40, traceback=True) + if exc: + raise exc + return output + + def _clean_exit(self): + """An atexit handler which asserts the Bus is not running.""" + if self.state != states.EXITING: + warnings.warn( + "The main thread is exiting, but the Bus is in the %r state; " + "shutting it down automatically now. You must either call " + "bus.block() after start(), or call bus.exit() before the " + "main thread exits." % self.state, RuntimeWarning) + self.exit() + + def start(self): + """Start all services.""" + atexit.register(self._clean_exit) + + self.state = states.STARTING + self.log('Bus STARTING') + try: + self.publish('start') + self.state = states.STARTED + self.log('Bus STARTED') + except (KeyboardInterrupt, SystemExit): + raise + except: + self.log("Shutting down due to error in start listener:", + level=40, traceback=True) + e_info = sys.exc_info()[1] + try: + self.exit() + except: + # Any stop/exit errors will be logged inside publish(). + pass + # Re-raise the original error + raise e_info + + def exit(self): + """Stop all services and prepare to exit the process.""" + exitstate = self.state + try: + self.stop() + + self.state = states.EXITING + self.log('Bus EXITING') + self.publish('exit') + # This isn't strictly necessary, but it's better than seeing + # "Waiting for child threads to terminate..." and then nothing. + self.log('Bus EXITED') + except: + # This method is often called asynchronously (whether thread, + # signal handler, console handler, or atexit handler), so we + # can't just let exceptions propagate out unhandled. + # Assume it's been logged and just die. + os._exit(70) # EX_SOFTWARE + + if exitstate == states.STARTING: + # exit() was called before start() finished, possibly due to + # Ctrl-C because a start listener got stuck. In this case, + # we could get stuck in a loop where Ctrl-C never exits the + # process, so we just call os.exit here. + os._exit(70) # EX_SOFTWARE + + def restart(self): + """Restart the process (may close connections). + + This method does not restart the process from the calling thread; + instead, it stops the bus and asks the main thread to call execv. + """ + self.execv = True + self.exit() + + def graceful(self): + """Advise all services to reload.""" + self.log('Bus graceful') + self.publish('graceful') + + def block(self, interval=0.1): + """Wait for the EXITING state, KeyboardInterrupt or SystemExit. + + This function is intended to be called only by the main thread. + After waiting for the EXITING state, it also waits for all threads + to terminate, and then calls os.execv if self.execv is True. This + design allows another thread to call bus.restart, yet have the main + thread perform the actual execv call (required on some platforms). + """ + try: + self.wait(states.EXITING, interval=interval, channel='main') + except (KeyboardInterrupt, IOError): + # The time.sleep call might raise + # "IOError: [Errno 4] Interrupted function call" on KBInt. + self.log('Keyboard Interrupt: shutting down bus') + self.exit() + except SystemExit: + self.log('SystemExit raised: shutting down bus') + self.exit() + raise + + # Waiting for ALL child threads to finish is necessary on OS X. + # See https://bitbucket.org/cherrypy/cherrypy/issue/581. + # It's also good to let them all shut down before allowing + # the main thread to call atexit handlers. + # See https://bitbucket.org/cherrypy/cherrypy/issue/751. + self.log("Waiting for child threads to terminate...") + for t in threading.enumerate(): + # Validate the we're not trying to join the MainThread + # that will cause a deadlock and the case exist when + # implemented as a windows service and in any other case + # that another thread executes cherrypy.engine.exit() + if ( + t != threading.currentThread() and + t.isAlive() and + not isinstance(t, threading._MainThread) + ): + # Note that any dummy (external) threads are always daemonic. + if hasattr(threading.Thread, "daemon"): + # Python 2.6+ + d = t.daemon + else: + d = t.isDaemon() + if not d: + self.log("Waiting for thread %s." % t.getName()) + t.join() + + if self.execv: + self._do_execv() + + def wait(self, state, interval=0.1, channel=None): + """Poll for the given state(s) at intervals; publish to channel.""" + if isinstance(state, (tuple, list)): + states = state + else: + states = [state] + + def _wait(): + while self.state not in states: + time.sleep(interval) + self.publish(channel) + + # From http://psyco.sourceforge.net/psycoguide/bugs.html: + # "The compiled machine code does not include the regular polling + # done by Python, meaning that a KeyboardInterrupt will not be + # detected before execution comes back to the regular Python + # interpreter. Your program cannot be interrupted if caught + # into an infinite Psyco-compiled loop." + try: + sys.modules['psyco'].cannotcompile(_wait) + except (KeyError, AttributeError): + pass + + _wait() + + def _do_execv(self): + """Re-execute the current process. + + This must be called from the main thread, because certain platforms + (OS X) don't allow execv to be called in a child thread very well. + """ + args = sys.argv[:] + self.log('Re-spawning %s' % ' '.join(args)) + + if sys.platform[:4] == 'java': + from _systemrestart import SystemRestart + raise SystemRestart + else: + args.insert(0, sys.executable) + if sys.platform == 'win32': + args = ['"%s"' % arg for arg in args] + + os.chdir(_startup_cwd) + if self.max_cloexec_files: + self._set_cloexec() + os.execv(sys.executable, args) + + def _set_cloexec(self): + """Set the CLOEXEC flag on all open files (except stdin/out/err). + + If self.max_cloexec_files is an integer (the default), then on + platforms which support it, it represents the max open files setting + for the operating system. This function will be called just before + the process is restarted via os.execv() to prevent open files + from persisting into the new process. + + Set self.max_cloexec_files to 0 to disable this behavior. + """ + for fd in range(3, self.max_cloexec_files): # skip stdin/out/err + try: + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + except IOError: + continue + fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) + + def stop(self): + """Stop all services.""" + self.state = states.STOPPING + self.log('Bus STOPPING') + self.publish('stop') + self.state = states.STOPPED + self.log('Bus STOPPED') + + def start_with_callback(self, func, args=None, kwargs=None): + """Start 'func' in a new thread T, then start self (and return T).""" + if args is None: + args = () + if kwargs is None: + kwargs = {} + args = (func,) + args + + def _callback(func, *a, **kw): + self.wait(states.STARTED) + func(*a, **kw) + t = threading.Thread(target=_callback, args=args, kwargs=kwargs) + t.setName('Bus Callback ' + t.getName()) + t.start() + + self.start() + + return t + + def log(self, msg="", level=20, traceback=False): + """Log the given message. Append the last traceback if requested.""" + if traceback: + msg += "\n" + "".join(_traceback.format_exception(*sys.exc_info())) + self.publish('log', msg, level) + +bus = Bus() diff --git a/lib/cherrypy/scaffold/__init__.py b/lib/cherrypy/scaffold/__init__.py new file mode 100644 index 00000000..50de34bb --- /dev/null +++ b/lib/cherrypy/scaffold/__init__.py @@ -0,0 +1,61 @@ +""", a CherryPy application. + +Use this as a base for creating new CherryPy applications. When you want +to make a new app, copy and paste this folder to some other location +(maybe site-packages) and rename it to the name of your project, +then tweak as desired. + +Even before any tweaking, this should serve a few demonstration pages. +Change to this directory and run: + + ../cherryd -c site.conf + +""" + +import cherrypy +from cherrypy import tools, url + +import os +local_dir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + + +class Root: + + _cp_config = {'tools.log_tracebacks.on': True, + } + + def index(self): + return """ +Try some other path, +or a default path.
+Or, just look at the pretty picture:
+ +""" % (url("other"), url("else"), + url("files/made_with_cherrypy_small.png")) + index.exposed = True + + def default(self, *args, **kwargs): + return "args: %s kwargs: %s" % (args, kwargs) + default.exposed = True + + def other(self, a=2, b='bananas', c=None): + cherrypy.response.headers['Content-Type'] = 'text/plain' + if c is None: + return "Have %d %s." % (int(a), b) + else: + return "Have %d %s, %s." % (int(a), b, c) + other.exposed = True + + files = cherrypy.tools.staticdir.handler( + section="/files", + dir=os.path.join(local_dir, "static"), + # Ignore .php files, etc. + match=r'\.(css|gif|html?|ico|jpe?g|js|png|swf|xml)$', + ) + + +root = Root() + +# Uncomment the following to use your own favicon instead of CP's default. +#favicon_path = os.path.join(local_dir, "favicon.ico") +#root.favicon_ico = tools.staticfile.handler(filename=favicon_path) diff --git a/lib/cherrypy/scaffold/apache-fcgi.conf b/lib/cherrypy/scaffold/apache-fcgi.conf new file mode 100644 index 00000000..922398ea --- /dev/null +++ b/lib/cherrypy/scaffold/apache-fcgi.conf @@ -0,0 +1,22 @@ +# Apache2 server conf file for using CherryPy with mod_fcgid. + +# This doesn't have to be "C:/", but it has to be a directory somewhere, and +# MUST match the directory used in the FastCgiExternalServer directive, below. +DocumentRoot "C:/" + +ServerName 127.0.0.1 +Listen 80 +LoadModule fastcgi_module modules/mod_fastcgi.dll +LoadModule rewrite_module modules/mod_rewrite.so + +Options ExecCGI +SetHandler fastcgi-script +RewriteEngine On +# Send requests for any URI to our fastcgi handler. +RewriteRule ^(.*)$ /fastcgi.pyc [L] + +# The FastCgiExternalServer directive defines filename as an external FastCGI application. +# If filename does not begin with a slash (/) then it is assumed to be relative to the ServerRoot. +# The filename does not have to exist in the local filesystem. URIs that Apache resolves to this +# filename will be handled by this external FastCGI application. +FastCgiExternalServer "C:/fastcgi.pyc" -host 127.0.0.1:8088 \ No newline at end of file diff --git a/lib/cherrypy/scaffold/example.conf b/lib/cherrypy/scaffold/example.conf new file mode 100644 index 00000000..93a6e53c --- /dev/null +++ b/lib/cherrypy/scaffold/example.conf @@ -0,0 +1,3 @@ +[/] +log.error_file: "error.log" +log.access_file: "access.log" \ No newline at end of file diff --git a/lib/cherrypy/scaffold/site.conf b/lib/cherrypy/scaffold/site.conf new file mode 100644 index 00000000..6ed38983 --- /dev/null +++ b/lib/cherrypy/scaffold/site.conf @@ -0,0 +1,14 @@ +[global] +# Uncomment this when you're done developing +#environment: "production" + +server.socket_host: "0.0.0.0" +server.socket_port: 8088 + +# Uncomment the following lines to run on HTTPS at the same time +#server.2.socket_host: "0.0.0.0" +#server.2.socket_port: 8433 +#server.2.ssl_certificate: '../test/test.pem' +#server.2.ssl_private_key: '../test/test.pem' + +tree.myapp: cherrypy.Application(scaffold.root, "/", "example.conf") diff --git a/lib/cherrypy/scaffold/static/made_with_cherrypy_small.png b/lib/cherrypy/scaffold/static/made_with_cherrypy_small.png new file mode 100644 index 00000000..c3aafeed Binary files /dev/null and b/lib/cherrypy/scaffold/static/made_with_cherrypy_small.png differ diff --git a/lib/cherrypy/wsgiserver/__init__.py b/lib/cherrypy/wsgiserver/__init__.py new file mode 100644 index 00000000..ee6190fe --- /dev/null +++ b/lib/cherrypy/wsgiserver/__init__.py @@ -0,0 +1,14 @@ +__all__ = ['HTTPRequest', 'HTTPConnection', 'HTTPServer', + 'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile', + 'MaxSizeExceeded', 'NoSSLError', 'FatalSSLAlert', + 'WorkerThread', 'ThreadPool', 'SSLAdapter', + 'CherryPyWSGIServer', + 'Gateway', 'WSGIGateway', 'WSGIGateway_10', 'WSGIGateway_u0', + 'WSGIPathInfoDispatcher', 'get_ssl_adapter_class'] + +import sys +if sys.version_info < (3, 0): + from wsgiserver2 import * +else: + # Le sigh. Boo for backward-incompatible syntax. + exec('from .wsgiserver3 import *') diff --git a/lib/cherrypy/wsgiserver/ssl_builtin.py b/lib/cherrypy/wsgiserver/ssl_builtin.py new file mode 100644 index 00000000..2c74ad84 --- /dev/null +++ b/lib/cherrypy/wsgiserver/ssl_builtin.py @@ -0,0 +1,92 @@ +"""A library for integrating Python's builtin ``ssl`` library with CherryPy. + +The ssl module must be importable for SSL functionality. + +To use this module, set ``CherryPyWSGIServer.ssl_adapter`` to an instance of +``BuiltinSSLAdapter``. +""" + +try: + import ssl +except ImportError: + ssl = None + +try: + from _pyio import DEFAULT_BUFFER_SIZE +except ImportError: + try: + from io import DEFAULT_BUFFER_SIZE + except ImportError: + DEFAULT_BUFFER_SIZE = -1 + +import sys + +from cherrypy import wsgiserver + + +class BuiltinSSLAdapter(wsgiserver.SSLAdapter): + + """A wrapper for integrating Python's builtin ssl module with CherryPy.""" + + certificate = None + """The filename of the server SSL certificate.""" + + private_key = None + """The filename of the server's private key file.""" + + def __init__(self, certificate, private_key, certificate_chain=None): + if ssl is None: + raise ImportError("You must install the ssl module to use HTTPS.") + self.certificate = certificate + self.private_key = private_key + self.certificate_chain = certificate_chain + + def bind(self, sock): + """Wrap and return the given socket.""" + return sock + + def wrap(self, sock): + """Wrap and return the given socket, plus WSGI environ entries.""" + try: + s = ssl.wrap_socket(sock, do_handshake_on_connect=True, + server_side=True, certfile=self.certificate, + keyfile=self.private_key, + ssl_version=ssl.PROTOCOL_SSLv23) + except ssl.SSLError: + e = sys.exc_info()[1] + if e.errno == ssl.SSL_ERROR_EOF: + # This is almost certainly due to the cherrypy engine + # 'pinging' the socket to assert it's connectable; + # the 'ping' isn't SSL. + return None, {} + elif e.errno == ssl.SSL_ERROR_SSL: + if e.args[1].endswith('http request'): + # The client is speaking HTTP to an HTTPS server. + raise wsgiserver.NoSSLError + elif e.args[1].endswith('unknown protocol'): + # The client is speaking some non-HTTP protocol. + # Drop the conn. + return None, {} + raise + return s, self.get_environ(s) + + # TODO: fill this out more with mod ssl env + def get_environ(self, sock): + """Create WSGI environ entries to be merged into each request.""" + cipher = sock.cipher() + ssl_environ = { + "wsgi.url_scheme": "https", + "HTTPS": "on", + 'SSL_PROTOCOL': cipher[1], + 'SSL_CIPHER': cipher[0] + # SSL_VERSION_INTERFACE string The mod_ssl program version + # SSL_VERSION_LIBRARY string The OpenSSL program version + } + return ssl_environ + + if sys.version_info >= (3, 0): + def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + return wsgiserver.CP_makefile(sock, mode, bufsize) + else: + def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + return wsgiserver.CP_fileobject(sock, mode, bufsize) diff --git a/lib/cherrypy/wsgiserver/ssl_pyopenssl.py b/lib/cherrypy/wsgiserver/ssl_pyopenssl.py new file mode 100644 index 00000000..f8f2dafe --- /dev/null +++ b/lib/cherrypy/wsgiserver/ssl_pyopenssl.py @@ -0,0 +1,253 @@ +"""A library for integrating pyOpenSSL with CherryPy. + +The OpenSSL module must be importable for SSL functionality. +You can obtain it from `here `_. + +To use this module, set CherryPyWSGIServer.ssl_adapter to an instance of +SSLAdapter. There are two ways to use SSL: + +Method One +---------- + + * ``ssl_adapter.context``: an instance of SSL.Context. + +If this is not None, it is assumed to be an SSL.Context instance, +and will be passed to SSL.Connection on bind(). The developer is +responsible for forming a valid Context object. This approach is +to be preferred for more flexibility, e.g. if the cert and key are +streams instead of files, or need decryption, or SSL.SSLv3_METHOD +is desired instead of the default SSL.SSLv23_METHOD, etc. Consult +the pyOpenSSL documentation for complete options. + +Method Two (shortcut) +--------------------- + + * ``ssl_adapter.certificate``: the filename of the server SSL certificate. + * ``ssl_adapter.private_key``: the filename of the server's private key file. + +Both are None by default. If ssl_adapter.context is None, but .private_key +and .certificate are both given and valid, they will be read, and the +context will be automatically created from them. +""" + +import socket +import threading +import time + +from cherrypy import wsgiserver + +try: + from OpenSSL import SSL + from OpenSSL import crypto +except ImportError: + SSL = None + + +class SSL_fileobject(wsgiserver.CP_fileobject): + + """SSL file object attached to a socket object.""" + + ssl_timeout = 3 + ssl_retry = .01 + + def _safe_call(self, is_reader, call, *args, **kwargs): + """Wrap the given call with SSL error-trapping. + + is_reader: if False EOF errors will be raised. If True, EOF errors + will return "" (to emulate normal sockets). + """ + start = time.time() + while True: + try: + return call(*args, **kwargs) + except SSL.WantReadError: + # Sleep and try again. This is dangerous, because it means + # the rest of the stack has no way of differentiating + # between a "new handshake" error and "client dropped". + # Note this isn't an endless loop: there's a timeout below. + time.sleep(self.ssl_retry) + except SSL.WantWriteError: + time.sleep(self.ssl_retry) + except SSL.SysCallError, e: + if is_reader and e.args == (-1, 'Unexpected EOF'): + return "" + + errnum = e.args[0] + if is_reader and errnum in wsgiserver.socket_errors_to_ignore: + return "" + raise socket.error(errnum) + except SSL.Error, e: + if is_reader and e.args == (-1, 'Unexpected EOF'): + return "" + + thirdarg = None + try: + thirdarg = e.args[0][0][2] + except IndexError: + pass + + if thirdarg == 'http request': + # The client is talking HTTP to an HTTPS server. + raise wsgiserver.NoSSLError() + + raise wsgiserver.FatalSSLAlert(*e.args) + except: + raise + + if time.time() - start > self.ssl_timeout: + raise socket.timeout("timed out") + + def recv(self, size): + return self._safe_call(True, super(SSL_fileobject, self).recv, size) + + def sendall(self, *args, **kwargs): + return self._safe_call(False, super(SSL_fileobject, self).sendall, + *args, **kwargs) + + def send(self, *args, **kwargs): + return self._safe_call(False, super(SSL_fileobject, self).send, + *args, **kwargs) + + +class SSLConnection: + + """A thread-safe wrapper for an SSL.Connection. + + ``*args``: the arguments to create the wrapped ``SSL.Connection(*args)``. + """ + + def __init__(self, *args): + self._ssl_conn = SSL.Connection(*args) + self._lock = threading.RLock() + + for f in ('get_context', 'pending', 'send', 'write', 'recv', 'read', + 'renegotiate', 'bind', 'listen', 'connect', 'accept', + 'setblocking', 'fileno', 'close', 'get_cipher_list', + 'getpeername', 'getsockname', 'getsockopt', 'setsockopt', + 'makefile', 'get_app_data', 'set_app_data', 'state_string', + 'sock_shutdown', 'get_peer_certificate', 'want_read', + 'want_write', 'set_connect_state', 'set_accept_state', + 'connect_ex', 'sendall', 'settimeout', 'gettimeout'): + exec("""def %s(self, *args): + self._lock.acquire() + try: + return self._ssl_conn.%s(*args) + finally: + self._lock.release() +""" % (f, f)) + + def shutdown(self, *args): + self._lock.acquire() + try: + # pyOpenSSL.socket.shutdown takes no args + return self._ssl_conn.shutdown() + finally: + self._lock.release() + + +class pyOpenSSLAdapter(wsgiserver.SSLAdapter): + + """A wrapper for integrating pyOpenSSL with CherryPy.""" + + context = None + """An instance of SSL.Context.""" + + certificate = None + """The filename of the server SSL certificate.""" + + private_key = None + """The filename of the server's private key file.""" + + certificate_chain = None + """Optional. The filename of CA's intermediate certificate bundle. + + This is needed for cheaper "chained root" SSL certificates, and should be + left as None if not required.""" + + def __init__(self, certificate, private_key, certificate_chain=None): + if SSL is None: + raise ImportError("You must install pyOpenSSL to use HTTPS.") + + self.context = None + self.certificate = certificate + self.private_key = private_key + self.certificate_chain = certificate_chain + self._environ = None + + def bind(self, sock): + """Wrap and return the given socket.""" + if self.context is None: + self.context = self.get_context() + conn = SSLConnection(self.context, sock) + self._environ = self.get_environ() + return conn + + def wrap(self, sock): + """Wrap and return the given socket, plus WSGI environ entries.""" + return sock, self._environ.copy() + + def get_context(self): + """Return an SSL.Context from self attributes.""" + # See http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/442473 + c = SSL.Context(SSL.SSLv23_METHOD) + c.use_privatekey_file(self.private_key) + if self.certificate_chain: + c.load_verify_locations(self.certificate_chain) + c.use_certificate_file(self.certificate) + return c + + def get_environ(self): + """Return WSGI environ entries to be merged into each request.""" + ssl_environ = { + "HTTPS": "on", + # pyOpenSSL doesn't provide access to any of these AFAICT + # 'SSL_PROTOCOL': 'SSLv2', + # SSL_CIPHER string The cipher specification name + # SSL_VERSION_INTERFACE string The mod_ssl program version + # SSL_VERSION_LIBRARY string The OpenSSL program version + } + + if self.certificate: + # Server certificate attributes + cert = open(self.certificate, 'rb').read() + cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + ssl_environ.update({ + 'SSL_SERVER_M_VERSION': cert.get_version(), + 'SSL_SERVER_M_SERIAL': cert.get_serial_number(), + # 'SSL_SERVER_V_START': + # Validity of server's certificate (start time), + # 'SSL_SERVER_V_END': + # Validity of server's certificate (end time), + }) + + for prefix, dn in [("I", cert.get_issuer()), + ("S", cert.get_subject())]: + # X509Name objects don't seem to have a way to get the + # complete DN string. Use str() and slice it instead, + # because str(dn) == "" + dnstr = str(dn)[18:-2] + + wsgikey = 'SSL_SERVER_%s_DN' % prefix + ssl_environ[wsgikey] = dnstr + + # The DN should be of the form: /k1=v1/k2=v2, but we must allow + # for any value to contain slashes itself (in a URL). + while dnstr: + pos = dnstr.rfind("=") + dnstr, value = dnstr[:pos], dnstr[pos + 1:] + pos = dnstr.rfind("/") + dnstr, key = dnstr[:pos], dnstr[pos + 1:] + if key and value: + wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key) + ssl_environ[wsgikey] = value + + return ssl_environ + + def makefile(self, sock, mode='r', bufsize=-1): + if SSL and isinstance(sock, SSL.ConnectionType): + timeout = sock.gettimeout() + f = SSL_fileobject(sock, mode, bufsize) + f.ssl_timeout = timeout + return f + else: + return wsgiserver.CP_fileobject(sock, mode, bufsize) diff --git a/lib/cherrypy/wsgiserver/wsgiserver2.py b/lib/cherrypy/wsgiserver/wsgiserver2.py new file mode 100644 index 00000000..b4919cf5 --- /dev/null +++ b/lib/cherrypy/wsgiserver/wsgiserver2.py @@ -0,0 +1,2481 @@ +"""A high-speed, production ready, thread pooled, generic HTTP server. + +Simplest example on how to use this module directly +(without using CherryPy's application machinery):: + + from cherrypy import wsgiserver + + def my_crazy_app(environ, start_response): + status = '200 OK' + response_headers = [('Content-type','text/plain')] + start_response(status, response_headers) + return ['Hello world!'] + + server = wsgiserver.CherryPyWSGIServer( + ('0.0.0.0', 8070), my_crazy_app, + server_name='www.cherrypy.example') + server.start() + +The CherryPy WSGI server can serve as many WSGI applications +as you want in one instance by using a WSGIPathInfoDispatcher:: + + d = WSGIPathInfoDispatcher({'/': my_crazy_app, '/blog': my_blog_app}) + server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', 80), d) + +Want SSL support? Just set server.ssl_adapter to an SSLAdapter instance. + +This won't call the CherryPy engine (application side) at all, only the +HTTP server, which is independent from the rest of CherryPy. Don't +let the name "CherryPyWSGIServer" throw you; the name merely reflects +its origin, not its coupling. + +For those of you wanting to understand internals of this module, here's the +basic call flow. The server's listening thread runs a very tight loop, +sticking incoming connections onto a Queue:: + + server = CherryPyWSGIServer(...) + server.start() + while True: + tick() + # This blocks until a request comes in: + child = socket.accept() + conn = HTTPConnection(child, ...) + server.requests.put(conn) + +Worker threads are kept in a pool and poll the Queue, popping off and then +handling each connection in turn. Each connection can consist of an arbitrary +number of requests and their responses, so we run a nested loop:: + + while True: + conn = server.requests.get() + conn.communicate() + -> while True: + req = HTTPRequest(...) + req.parse_request() + -> # Read the Request-Line, e.g. "GET /page HTTP/1.1" + req.rfile.readline() + read_headers(req.rfile, req.inheaders) + req.respond() + -> response = app(...) + try: + for chunk in response: + if chunk: + req.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + if req.close_connection: + return +""" + +__all__ = ['HTTPRequest', 'HTTPConnection', 'HTTPServer', + 'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile', + 'CP_fileobject', + 'MaxSizeExceeded', 'NoSSLError', 'FatalSSLAlert', + 'WorkerThread', 'ThreadPool', 'SSLAdapter', + 'CherryPyWSGIServer', + 'Gateway', 'WSGIGateway', 'WSGIGateway_10', 'WSGIGateway_u0', + 'WSGIPathInfoDispatcher', 'get_ssl_adapter_class'] + +import os +try: + import queue +except: + import Queue as queue +import re +import rfc822 +import socket +import sys +if 'win' in sys.platform and hasattr(socket, "AF_INET6"): + if not hasattr(socket, 'IPPROTO_IPV6'): + socket.IPPROTO_IPV6 = 41 + if not hasattr(socket, 'IPV6_V6ONLY'): + socket.IPV6_V6ONLY = 27 +try: + import cStringIO as StringIO +except ImportError: + import StringIO +DEFAULT_BUFFER_SIZE = -1 + + +class FauxSocket(object): + + """Faux socket with the minimal interface required by pypy""" + + def _reuse(self): + pass + +_fileobject_uses_str_type = isinstance( + socket._fileobject(FauxSocket())._rbuf, basestring) +del FauxSocket # this class is not longer required for anything. + +import threading +import time +import traceback + + +def format_exc(limit=None): + """Like print_exc() but return a string. Backport for Python 2.3.""" + try: + etype, value, tb = sys.exc_info() + return ''.join(traceback.format_exception(etype, value, tb, limit)) + finally: + etype = value = tb = None + +import operator + +from urllib import unquote +import warnings + +if sys.version_info >= (3, 0): + bytestr = bytes + unicodestr = str + basestring = (bytes, str) + + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given + encoding. + """ + # In Python 3, the native string type is unicode + return n.encode(encoding) +else: + bytestr = str + unicodestr = unicode + basestring = basestring + + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given + encoding. + """ + # In Python 2, the native string type is bytes. Assume it's already + # in the given encoding, which for ISO-8859-1 is almost always what + # was intended. + return n + +LF = ntob('\n') +CRLF = ntob('\r\n') +TAB = ntob('\t') +SPACE = ntob(' ') +COLON = ntob(':') +SEMICOLON = ntob(';') +EMPTY = ntob('') +NUMBER_SIGN = ntob('#') +QUESTION_MARK = ntob('?') +ASTERISK = ntob('*') +FORWARD_SLASH = ntob('/') +quoted_slash = re.compile(ntob("(?i)%2F")) + +import errno + + +def plat_specific_errors(*errnames): + """Return error numbers for all errors in errnames on this platform. + + The 'errno' module contains different global constants depending on + the specific platform (OS). This function will return the list of + numeric values for a given list of potential names. + """ + errno_names = dir(errno) + nums = [getattr(errno, k) for k in errnames if k in errno_names] + # de-dupe the list + return list(dict.fromkeys(nums).keys()) + +socket_error_eintr = plat_specific_errors("EINTR", "WSAEINTR") + +socket_errors_to_ignore = plat_specific_errors( + "EPIPE", + "EBADF", "WSAEBADF", + "ENOTSOCK", "WSAENOTSOCK", + "ETIMEDOUT", "WSAETIMEDOUT", + "ECONNREFUSED", "WSAECONNREFUSED", + "ECONNRESET", "WSAECONNRESET", + "ECONNABORTED", "WSAECONNABORTED", + "ENETRESET", "WSAENETRESET", + "EHOSTDOWN", "EHOSTUNREACH", +) +socket_errors_to_ignore.append("timed out") +socket_errors_to_ignore.append("The read operation timed out") + +socket_errors_nonblocking = plat_specific_errors( + 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK') + +comma_separated_headers = [ + ntob(h) for h in + ['Accept', 'Accept-Charset', 'Accept-Encoding', + 'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control', + 'Connection', 'Content-Encoding', 'Content-Language', 'Expect', + 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE', + 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning', + 'WWW-Authenticate'] +] + + +import logging +if not hasattr(logging, 'statistics'): + logging.statistics = {} + + +def read_headers(rfile, hdict=None): + """Read headers from the given stream into the given header dict. + + If hdict is None, a new header dict is created. Returns the populated + header dict. + + Headers which are repeated are folded together using a comma if their + specification so dictates. + + This function raises ValueError when the read bytes violate the HTTP spec. + You should probably return "400 Bad Request" if this happens. + """ + if hdict is None: + hdict = {} + + while True: + line = rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError("Illegal end of headers.") + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError("HTTP requires CRLF terminators") + + if line[0] in (SPACE, TAB): + # It's a continuation line. + v = line.strip() + else: + try: + k, v = line.split(COLON, 1) + except ValueError: + raise ValueError("Illegal header line.") + # TODO: what about TE and WWW-Authenticate? + k = k.strip().title() + v = v.strip() + hname = k + + if k in comma_separated_headers: + existing = hdict.get(hname) + if existing: + v = ", ".join((existing, v)) + hdict[hname] = v + + return hdict + + +class MaxSizeExceeded(Exception): + pass + + +class SizeCheckWrapper(object): + + """Wraps a file-like object, raising MaxSizeExceeded if too large.""" + + def __init__(self, rfile, maxlen): + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + + def _check_length(self): + if self.maxlen and self.bytes_read > self.maxlen: + raise MaxSizeExceeded() + + def read(self, size=None): + data = self.rfile.read(size) + self.bytes_read += len(data) + self._check_length() + return data + + def readline(self, size=None): + if size is not None: + data = self.rfile.readline(size) + self.bytes_read += len(data) + self._check_length() + return data + + # User didn't specify a size ... + # We read the line in chunks to make sure it's not a 100MB line ! + res = [] + while True: + data = self.rfile.readline(256) + self.bytes_read += len(data) + self._check_length() + res.append(data) + # See https://bitbucket.org/cherrypy/cherrypy/issue/421 + if len(data) < 256 or data[-1:] == LF: + return EMPTY.join(res) + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline() + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline() + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def __next__(self): + data = next(self.rfile) + self.bytes_read += len(data) + self._check_length() + return data + + def next(self): + data = self.rfile.next() + self.bytes_read += len(data) + self._check_length() + return data + + +class KnownLengthRFile(object): + + """Wraps a file-like object, returning an empty string when exhausted.""" + + def __init__(self, rfile, content_length): + self.rfile = rfile + self.remaining = content_length + + def read(self, size=None): + if self.remaining == 0: + return '' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.read(size) + self.remaining -= len(data) + return data + + def readline(self, size=None): + if self.remaining == 0: + return '' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.readline(size) + self.remaining -= len(data) + return data + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def __next__(self): + data = next(self.rfile) + self.remaining -= len(data) + return data + + +class ChunkedRFile(object): + + """Wraps a file-like object, returning an empty string when exhausted. + + This class is intended to provide a conforming wsgi.input value for + request entities that have been encoded with the 'chunked' transfer + encoding. + """ + + def __init__(self, rfile, maxlen, bufsize=8192): + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + self.buffer = EMPTY + self.bufsize = bufsize + self.closed = False + + def _fetch(self): + if self.closed: + return + + line = self.rfile.readline() + self.bytes_read += len(line) + + if self.maxlen and self.bytes_read > self.maxlen: + raise MaxSizeExceeded("Request Entity Too Large", self.maxlen) + + line = line.strip().split(SEMICOLON, 1) + + try: + chunk_size = line.pop(0) + chunk_size = int(chunk_size, 16) + except ValueError: + raise ValueError("Bad chunked transfer size: " + repr(chunk_size)) + + if chunk_size <= 0: + self.closed = True + return + +## if line: chunk_extension = line[0] + + if self.maxlen and self.bytes_read + chunk_size > self.maxlen: + raise IOError("Request Entity Too Large") + + chunk = self.rfile.read(chunk_size) + self.bytes_read += len(chunk) + self.buffer += chunk + + crlf = self.rfile.read(2) + if crlf != CRLF: + raise ValueError( + "Bad chunked transfer coding (expected '\\r\\n', " + "got " + repr(crlf) + ")") + + def read(self, size=None): + data = EMPTY + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + if size: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + data += self.buffer + + def readline(self, size=None): + data = EMPTY + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + newline_pos = self.buffer.find(LF) + if size: + if newline_pos == -1: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + remaining = min(size - len(data), newline_pos) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + if newline_pos == -1: + data += self.buffer + else: + data += self.buffer[:newline_pos] + self.buffer = self.buffer[newline_pos:] + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def read_trailer_lines(self): + if not self.closed: + raise ValueError( + "Cannot read trailers until the request body has been read.") + + while True: + line = self.rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError("Illegal end of headers.") + + self.bytes_read += len(line) + if self.maxlen and self.bytes_read > self.maxlen: + raise IOError("Request Entity Too Large") + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError("HTTP requires CRLF terminators") + + yield line + + def close(self): + self.rfile.close() + + def __iter__(self): + # Shamelessly stolen from StringIO + total = 0 + line = self.readline(sizehint) + while line: + yield line + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + + +class HTTPRequest(object): + + """An HTTP Request (and response). + + A single HTTP connection may consist of multiple request/response pairs. + """ + + server = None + """The HTTPServer object which is receiving this request.""" + + conn = None + """The HTTPConnection object on which this request connected.""" + + inheaders = {} + """A dict of request headers.""" + + outheaders = [] + """A list of header tuples to write in the response.""" + + ready = False + """When True, the request has been parsed and is ready to begin generating + the response. When False, signals the calling Connection that the response + should not be generated and the connection should close.""" + + close_connection = False + """Signals the calling Connection that the request should close. This does + not imply an error! The client and/or server may each request that the + connection be closed.""" + + chunked_write = False + """If True, output will be encoded with the "chunked" transfer-coding. + + This value is set automatically inside send_headers.""" + + def __init__(self, server, conn): + self.server = server + self.conn = conn + + self.ready = False + self.started_request = False + self.scheme = ntob("http") + if self.server.ssl_adapter is not None: + self.scheme = ntob("https") + # Use the lowest-common protocol in case read_request_line errors. + self.response_protocol = 'HTTP/1.0' + self.inheaders = {} + + self.status = "" + self.outheaders = [] + self.sent_headers = False + self.close_connection = self.__class__.close_connection + self.chunked_read = False + self.chunked_write = self.__class__.chunked_write + + def parse_request(self): + """Parse the next HTTP request start-line and message-headers.""" + self.rfile = SizeCheckWrapper(self.conn.rfile, + self.server.max_request_header_size) + try: + success = self.read_request_line() + except MaxSizeExceeded: + self.simple_response( + "414 Request-URI Too Long", + "The Request-URI sent with the request exceeds the maximum " + "allowed bytes.") + return + else: + if not success: + return + + try: + success = self.read_request_headers() + except MaxSizeExceeded: + self.simple_response( + "413 Request Entity Too Large", + "The headers sent with the request exceed the maximum " + "allowed bytes.") + return + else: + if not success: + return + + self.ready = True + + def read_request_line(self): + # HTTP/1.1 connections are persistent by default. If a client + # requests a page, then idles (leaves the connection open), + # then rfile.readline() will raise socket.error("timed out"). + # Note that it does this based on the value given to settimeout(), + # and doesn't need the client to request or acknowledge the close + # (although your TCP stack might suffer for it: cf Apache's history + # with FIN_WAIT_2). + request_line = self.rfile.readline() + + # Set started_request to True so communicate() knows to send 408 + # from here on out. + self.started_request = True + if not request_line: + return False + + if request_line == CRLF: + # RFC 2616 sec 4.1: "...if the server is reading the protocol + # stream at the beginning of a message and receives a CRLF + # first, it should ignore the CRLF." + # But only ignore one leading line! else we enable a DoS. + request_line = self.rfile.readline() + if not request_line: + return False + + if not request_line.endswith(CRLF): + self.simple_response( + "400 Bad Request", "HTTP requires CRLF terminators") + return False + + try: + method, uri, req_protocol = request_line.strip().split(SPACE, 2) + rp = int(req_protocol[5]), int(req_protocol[7]) + except (ValueError, IndexError): + self.simple_response("400 Bad Request", "Malformed Request-Line") + return False + + self.uri = uri + self.method = method + + # uri may be an abs_path (including "http://host.domain.tld"); + scheme, authority, path = self.parse_request_uri(uri) + if NUMBER_SIGN in path: + self.simple_response("400 Bad Request", + "Illegal #fragment in Request-URI.") + return False + + if scheme: + self.scheme = scheme + + qs = EMPTY + if QUESTION_MARK in path: + path, qs = path.split(QUESTION_MARK, 1) + + # Unquote the path+params (e.g. "/this%20path" -> "/this path"). + # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2 + # + # But note that "...a URI must be separated into its components + # before the escaped characters within those components can be + # safely decoded." http://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2 + # Therefore, "/this%2Fpath" becomes "/this%2Fpath", not "/this/path". + try: + atoms = [unquote(x) for x in quoted_slash.split(path)] + except ValueError: + ex = sys.exc_info()[1] + self.simple_response("400 Bad Request", ex.args[0]) + return False + path = "%2F".join(atoms) + self.path = path + + # Note that, like wsgiref and most other HTTP servers, + # we "% HEX HEX"-unquote the path but not the query string. + self.qs = qs + + # Compare request and server HTTP protocol versions, in case our + # server does not support the requested protocol. Limit our output + # to min(req, server). We want the following output: + # request server actual written supported response + # protocol protocol response protocol feature set + # a 1.0 1.0 1.0 1.0 + # b 1.0 1.1 1.1 1.0 + # c 1.1 1.0 1.0 1.0 + # d 1.1 1.1 1.1 1.1 + # Notice that, in (b), the response will be "HTTP/1.1" even though + # the client only understands 1.0. RFC 2616 10.5.6 says we should + # only return 505 if the _major_ version is different. + sp = int(self.server.protocol[5]), int(self.server.protocol[7]) + + if sp[0] != rp[0]: + self.simple_response("505 HTTP Version Not Supported") + return False + + self.request_protocol = req_protocol + self.response_protocol = "HTTP/%s.%s" % min(rp, sp) + + return True + + def read_request_headers(self): + """Read self.rfile into self.inheaders. Return success.""" + + # then all the http headers + try: + read_headers(self.rfile, self.inheaders) + except ValueError: + ex = sys.exc_info()[1] + self.simple_response("400 Bad Request", ex.args[0]) + return False + + mrbs = self.server.max_request_body_size + if mrbs and int(self.inheaders.get("Content-Length", 0)) > mrbs: + self.simple_response( + "413 Request Entity Too Large", + "The entity sent with the request exceeds the maximum " + "allowed bytes.") + return False + + # Persistent connection support + if self.response_protocol == "HTTP/1.1": + # Both server and client are HTTP/1.1 + if self.inheaders.get("Connection", "") == "close": + self.close_connection = True + else: + # Either the server or client (or both) are HTTP/1.0 + if self.inheaders.get("Connection", "") != "Keep-Alive": + self.close_connection = True + + # Transfer-Encoding support + te = None + if self.response_protocol == "HTTP/1.1": + te = self.inheaders.get("Transfer-Encoding") + if te: + te = [x.strip().lower() for x in te.split(",") if x.strip()] + + self.chunked_read = False + + if te: + for enc in te: + if enc == "chunked": + self.chunked_read = True + else: + # Note that, even if we see "chunked", we must reject + # if there is an extension we don't recognize. + self.simple_response("501 Unimplemented") + self.close_connection = True + return False + + # From PEP 333: + # "Servers and gateways that implement HTTP 1.1 must provide + # transparent support for HTTP 1.1's "expect/continue" mechanism. + # This may be done in any of several ways: + # 1. Respond to requests containing an Expect: 100-continue request + # with an immediate "100 Continue" response, and proceed normally. + # 2. Proceed with the request normally, but provide the application + # with a wsgi.input stream that will send the "100 Continue" + # response if/when the application first attempts to read from + # the input stream. The read request must then remain blocked + # until the client responds. + # 3. Wait until the client decides that the server does not support + # expect/continue, and sends the request body on its own. + # (This is suboptimal, and is not recommended.) + # + # We used to do 3, but are now doing 1. Maybe we'll do 2 someday, + # but it seems like it would be a big slowdown for such a rare case. + if self.inheaders.get("Expect", "") == "100-continue": + # Don't use simple_response here, because it emits headers + # we don't want. See + # https://bitbucket.org/cherrypy/cherrypy/issue/951 + msg = self.server.protocol + " 100 Continue\r\n\r\n" + try: + self.conn.wfile.sendall(msg) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + return True + + def parse_request_uri(self, uri): + """Parse a Request-URI into (scheme, authority, path). + + Note that Request-URI's must be one of:: + + Request-URI = "*" | absoluteURI | abs_path | authority + + Therefore, a Request-URI which starts with a double forward-slash + cannot be a "net_path":: + + net_path = "//" authority [ abs_path ] + + Instead, it must be interpreted as an "abs_path" with an empty first + path segment:: + + abs_path = "/" path_segments + path_segments = segment *( "/" segment ) + segment = *pchar *( ";" param ) + param = *pchar + """ + if uri == ASTERISK: + return None, None, uri + + i = uri.find('://') + if i > 0 and QUESTION_MARK not in uri[:i]: + # An absoluteURI. + # If there's a scheme (and it must be http or https), then: + # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query + # ]] + scheme, remainder = uri[:i].lower(), uri[i + 3:] + authority, path = remainder.split(FORWARD_SLASH, 1) + path = FORWARD_SLASH + path + return scheme, authority, path + + if uri.startswith(FORWARD_SLASH): + # An abs_path. + return None, None, uri + else: + # An authority. + return None, uri, None + + def respond(self): + """Call the gateway and write its iterable output.""" + mrbs = self.server.max_request_body_size + if self.chunked_read: + self.rfile = ChunkedRFile(self.conn.rfile, mrbs) + else: + cl = int(self.inheaders.get("Content-Length", 0)) + if mrbs and mrbs < cl: + if not self.sent_headers: + self.simple_response( + "413 Request Entity Too Large", + "The entity sent with the request exceeds the maximum " + "allowed bytes.") + return + self.rfile = KnownLengthRFile(self.conn.rfile, cl) + + self.server.gateway(self).respond() + + if (self.ready and not self.sent_headers): + self.sent_headers = True + self.send_headers() + if self.chunked_write: + self.conn.wfile.sendall("0\r\n\r\n") + + def simple_response(self, status, msg=""): + """Write a simple response back to the client.""" + status = str(status) + buf = [self.server.protocol + SPACE + + status + CRLF, + "Content-Length: %s\r\n" % len(msg), + "Content-Type: text/plain\r\n"] + + if status[:3] in ("413", "414"): + # Request Entity Too Large / Request-URI Too Long + self.close_connection = True + if self.response_protocol == 'HTTP/1.1': + # This will not be true for 414, since read_request_line + # usually raises 414 before reading the whole line, and we + # therefore cannot know the proper response_protocol. + buf.append("Connection: close\r\n") + else: + # HTTP/1.0 had no 413/414 status nor Connection header. + # Emit 400 instead and trust the message body is enough. + status = "400 Bad Request" + + buf.append(CRLF) + if msg: + if isinstance(msg, unicodestr): + msg = msg.encode("ISO-8859-1") + buf.append(msg) + + try: + self.conn.wfile.sendall("".join(buf)) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + + def write(self, chunk): + """Write unbuffered data to the client.""" + if self.chunked_write and chunk: + buf = [hex(len(chunk))[2:], CRLF, chunk, CRLF] + self.conn.wfile.sendall(EMPTY.join(buf)) + else: + self.conn.wfile.sendall(chunk) + + def send_headers(self): + """Assert, process, and send the HTTP response message-headers. + + You must set self.status, and self.outheaders before calling this. + """ + hkeys = [key.lower() for key, value in self.outheaders] + status = int(self.status[:3]) + + if status == 413: + # Request Entity Too Large. Close conn to avoid garbage. + self.close_connection = True + elif "content-length" not in hkeys: + # "All 1xx (informational), 204 (no content), + # and 304 (not modified) responses MUST NOT + # include a message-body." So no point chunking. + if status < 200 or status in (204, 205, 304): + pass + else: + if (self.response_protocol == 'HTTP/1.1' + and self.method != 'HEAD'): + # Use the chunked transfer-coding + self.chunked_write = True + self.outheaders.append(("Transfer-Encoding", "chunked")) + else: + # Closing the conn is the only way to determine len. + self.close_connection = True + + if "connection" not in hkeys: + if self.response_protocol == 'HTTP/1.1': + # Both server and client are HTTP/1.1 or better + if self.close_connection: + self.outheaders.append(("Connection", "close")) + else: + # Server and/or client are HTTP/1.0 + if not self.close_connection: + self.outheaders.append(("Connection", "Keep-Alive")) + + if (not self.close_connection) and (not self.chunked_read): + # Read any remaining request body data on the socket. + # "If an origin server receives a request that does not include an + # Expect request-header field with the "100-continue" expectation, + # the request includes a request body, and the server responds + # with a final status code before reading the entire request body + # from the transport connection, then the server SHOULD NOT close + # the transport connection until it has read the entire request, + # or until the client closes the connection. Otherwise, the client + # might not reliably receive the response message. However, this + # requirement is not be construed as preventing a server from + # defending itself against denial-of-service attacks, or from + # badly broken client implementations." + remaining = getattr(self.rfile, 'remaining', 0) + if remaining > 0: + self.rfile.read(remaining) + + if "date" not in hkeys: + self.outheaders.append(("Date", rfc822.formatdate())) + + if "server" not in hkeys: + self.outheaders.append(("Server", self.server.server_name)) + + buf = [self.server.protocol + SPACE + self.status + CRLF] + for k, v in self.outheaders: + buf.append(k + COLON + SPACE + v + CRLF) + buf.append(CRLF) + self.conn.wfile.sendall(EMPTY.join(buf)) + + +class NoSSLError(Exception): + + """Exception raised when a client speaks HTTP to an HTTPS socket.""" + pass + + +class FatalSSLAlert(Exception): + + """Exception raised when the SSL implementation signals a fatal alert.""" + pass + + +class CP_fileobject(socket._fileobject): + + """Faux file object attached to a socket object.""" + + def __init__(self, *args, **kwargs): + self.bytes_read = 0 + self.bytes_written = 0 + socket._fileobject.__init__(self, *args, **kwargs) + + def sendall(self, data): + """Sendall for non-blocking sockets.""" + while data: + try: + bytes_sent = self.send(data) + data = data[bytes_sent:] + except socket.error, e: + if e.args[0] not in socket_errors_nonblocking: + raise + + def send(self, data): + bytes_sent = self._sock.send(data) + self.bytes_written += bytes_sent + return bytes_sent + + def flush(self): + if self._wbuf: + buffer = "".join(self._wbuf) + self._wbuf = [] + self.sendall(buffer) + + def recv(self, size): + while True: + try: + data = self._sock.recv(size) + self.bytes_read += len(data) + return data + except socket.error, e: + if (e.args[0] not in socket_errors_nonblocking + and e.args[0] not in socket_error_eintr): + raise + + if not _fileobject_uses_str_type: + def read(self, size=-1): + # Use max, disallow tiny reads in a loop as they are very + # inefficient. + # We never leave read() with any leftover data from a new recv() + # call in our internal buffer. + rbufsize = max(self._rbufsize, self.default_bufsize) + # Our use of StringIO rather than lists of string objects returned + # by recv() minimizes memory usage and fragmentation that occurs + # when rbufsize is large compared to the typical return value of + # recv(). + buf = self._rbuf + buf.seek(0, 2) # seek end + if size < 0: + # Read until EOF + # reset _rbuf. we consume it via buf. + self._rbuf = StringIO.StringIO() + while True: + data = self.recv(rbufsize) + if not data: + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or EOF seen, whichever comes first + buf_len = buf.tell() + if buf_len >= size: + # Already have size bytes in our buffer? Extract and + # return. + buf.seek(0) + rv = buf.read(size) + self._rbuf = StringIO.StringIO() + self._rbuf.write(buf.read()) + return rv + + # reset _rbuf. we consume it via buf. + self._rbuf = StringIO.StringIO() + while True: + left = size - buf_len + # recv() will malloc the amount of memory given as its + # parameter even though it often returns much less data + # than that. The returned data string is short lived + # as we copy it into a StringIO and free it. This avoids + # fragmentation issues on many platforms. + data = self.recv(left) + if not data: + break + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid buffer data copies when: + # - We have no data in our buffer. + # AND + # - Our call to recv returned exactly the + # number of bytes we were asked to read. + return data + if n == left: + buf.write(data) + del data # explicit free + break + assert n <= left, "recv(%d) returned %d bytes" % (left, n) + buf.write(data) + buf_len += n + del data # explicit free + #assert buf_len == buf.tell() + return buf.getvalue() + + def readline(self, size=-1): + buf = self._rbuf + buf.seek(0, 2) # seek end + if buf.tell() > 0: + # check if we already have it in our buffer + buf.seek(0) + bline = buf.readline(size) + if bline.endswith('\n') or len(bline) == size: + self._rbuf = StringIO.StringIO() + self._rbuf.write(buf.read()) + return bline + del bline + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + buf.seek(0) + buffers = [buf.read()] + # reset _rbuf. we consume it via buf. + self._rbuf = StringIO.StringIO() + data = None + recv = self.recv + while data != "\n": + data = recv(1) + if not data: + break + buffers.append(data) + return "".join(buffers) + + buf.seek(0, 2) # seek end + # reset _rbuf. we consume it via buf. + self._rbuf = StringIO.StringIO() + while True: + data = self.recv(self._rbufsize) + if not data: + break + nl = data.find('\n') + if nl >= 0: + nl += 1 + buf.write(data[:nl]) + self._rbuf.write(data[nl:]) + del data + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or \n or EOF seen, whichever comes + # first + buf.seek(0, 2) # seek end + buf_len = buf.tell() + if buf_len >= size: + buf.seek(0) + rv = buf.read(size) + self._rbuf = StringIO.StringIO() + self._rbuf.write(buf.read()) + return rv + # reset _rbuf. we consume it via buf. + self._rbuf = StringIO.StringIO() + while True: + data = self.recv(self._rbufsize) + if not data: + break + left = size - buf_len + # did we just receive a newline? + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + # save the excess data to _rbuf + self._rbuf.write(data[nl:]) + if buf_len: + buf.write(data[:nl]) + break + else: + # Shortcut. Avoid data copy through buf when + # returning a substring of our first recv(). + return data[:nl] + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid data copy through buf when + # returning exactly all of our first recv(). + return data + if n >= left: + buf.write(data[:left]) + self._rbuf.write(data[left:]) + break + buf.write(data) + buf_len += n + #assert buf_len == buf.tell() + return buf.getvalue() + else: + def read(self, size=-1): + if size < 0: + # Read until EOF + buffers = [self._rbuf] + self._rbuf = "" + if self._rbufsize <= 1: + recv_size = self.default_bufsize + else: + recv_size = self._rbufsize + + while True: + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + return "".join(buffers) + else: + # Read until size bytes or EOF seen, whichever comes first + data = self._rbuf + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + left = size - buf_len + recv_size = max(self._rbufsize, left) + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + def readline(self, size=-1): + data = self._rbuf + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + assert data == "" + buffers = [] + while data != "\n": + data = self.recv(1) + if not data: + break + buffers.append(data) + return "".join(buffers) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + return "".join(buffers) + else: + # Read until size bytes or \n or EOF seen, whichever comes + # first + nl = data.find('\n', 0, size) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + left = size - buf_len + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + +class HTTPConnection(object): + + """An HTTP connection (active socket). + + server: the Server object which received this connection. + socket: the raw socket object (usually TCP) for this connection. + makefile: a fileobject class for reading from the socket. + """ + + remote_addr = None + remote_port = None + ssl_env = None + rbufsize = DEFAULT_BUFFER_SIZE + wbufsize = DEFAULT_BUFFER_SIZE + RequestHandlerClass = HTTPRequest + + def __init__(self, server, sock, makefile=CP_fileobject): + self.server = server + self.socket = sock + self.rfile = makefile(sock, "rb", self.rbufsize) + self.wfile = makefile(sock, "wb", self.wbufsize) + self.requests_seen = 0 + + def communicate(self): + """Read each request and respond appropriately.""" + request_seen = False + try: + while True: + # (re)set req to None so that if something goes wrong in + # the RequestHandlerClass constructor, the error doesn't + # get written to the previous request. + req = None + req = self.RequestHandlerClass(self.server, self) + + # This order of operations should guarantee correct pipelining. + req.parse_request() + if self.server.stats['Enabled']: + self.requests_seen += 1 + if not req.ready: + # Something went wrong in the parsing (and the server has + # probably already made a simple_response). Return and + # let the conn close. + return + + request_seen = True + req.respond() + if req.close_connection: + return + except socket.error: + e = sys.exc_info()[1] + errnum = e.args[0] + # sadly SSL sockets return a different (longer) time out string + if ( + errnum == 'timed out' or + errnum == 'The read operation timed out' + ): + # Don't error if we're between requests; only error + # if 1) no request has been started at all, or 2) we're + # in the middle of a request. + # See https://bitbucket.org/cherrypy/cherrypy/issue/853 + if (not request_seen) or (req and req.started_request): + # Don't bother writing the 408 if the response + # has already started being written. + if req and not req.sent_headers: + try: + req.simple_response("408 Request Timeout") + except FatalSSLAlert: + # Close the connection. + return + elif errnum not in socket_errors_to_ignore: + self.server.error_log("socket.error %s" % repr(errnum), + level=logging.WARNING, traceback=True) + if req and not req.sent_headers: + try: + req.simple_response("500 Internal Server Error") + except FatalSSLAlert: + # Close the connection. + return + return + except (KeyboardInterrupt, SystemExit): + raise + except FatalSSLAlert: + # Close the connection. + return + except NoSSLError: + if req and not req.sent_headers: + # Unwrap our wfile + self.wfile = CP_fileobject( + self.socket._sock, "wb", self.wbufsize) + req.simple_response( + "400 Bad Request", + "The client sent a plain HTTP request, but " + "this server only speaks HTTPS on this port.") + self.linger = True + except Exception: + e = sys.exc_info()[1] + self.server.error_log(repr(e), level=logging.ERROR, traceback=True) + if req and not req.sent_headers: + try: + req.simple_response("500 Internal Server Error") + except FatalSSLAlert: + # Close the connection. + return + + linger = False + + def close(self): + """Close the socket underlying this connection.""" + self.rfile.close() + + if not self.linger: + # Python's socket module does NOT call close on the kernel + # socket when you call socket.close(). We do so manually here + # because we want this server to send a FIN TCP segment + # immediately. Note this must be called *before* calling + # socket.close(), because the latter drops its reference to + # the kernel socket. + if hasattr(self.socket, '_sock'): + self.socket._sock.close() + self.socket.close() + else: + # On the other hand, sometimes we want to hang around for a bit + # to make sure the client has a chance to read our entire + # response. Skipping the close() calls here delays the FIN + # packet until the socket object is garbage-collected later. + # Someday, perhaps, we'll do the full lingering_close that + # Apache does, but not today. + pass + + +class TrueyZero(object): + + """An object which equals and does math like the integer 0 but evals True. + """ + + def __add__(self, other): + return other + + def __radd__(self, other): + return other +trueyzero = TrueyZero() + + +_SHUTDOWNREQUEST = None + + +class WorkerThread(threading.Thread): + + """Thread which continuously polls a Queue for Connection objects. + + Due to the timing issues of polling a Queue, a WorkerThread does not + check its own 'ready' flag after it has started. To stop the thread, + it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue + (one for each running WorkerThread). + """ + + conn = None + """The current connection pulled off the Queue, or None.""" + + server = None + """The HTTP Server which spawned this thread, and which owns the + Queue and is placing active connections into it.""" + + ready = False + """A simple flag for the calling server to know when this thread + has begun polling the Queue.""" + + def __init__(self, server): + self.ready = False + self.server = server + + self.requests_seen = 0 + self.bytes_read = 0 + self.bytes_written = 0 + self.start_time = None + self.work_time = 0 + self.stats = { + 'Requests': lambda s: self.requests_seen + ( + (self.start_time is None) and + trueyzero or + self.conn.requests_seen + ), + 'Bytes Read': lambda s: self.bytes_read + ( + (self.start_time is None) and + trueyzero or + self.conn.rfile.bytes_read + ), + 'Bytes Written': lambda s: self.bytes_written + ( + (self.start_time is None) and + trueyzero or + self.conn.wfile.bytes_written + ), + 'Work Time': lambda s: self.work_time + ( + (self.start_time is None) and + trueyzero or + time.time() - self.start_time + ), + 'Read Throughput': lambda s: s['Bytes Read'](s) / ( + s['Work Time'](s) or 1e-6), + 'Write Throughput': lambda s: s['Bytes Written'](s) / ( + s['Work Time'](s) or 1e-6), + } + threading.Thread.__init__(self) + + def run(self): + self.server.stats['Worker Threads'][self.getName()] = self.stats + try: + self.ready = True + while True: + conn = self.server.requests.get() + if conn is _SHUTDOWNREQUEST: + return + + self.conn = conn + if self.server.stats['Enabled']: + self.start_time = time.time() + try: + conn.communicate() + finally: + conn.close() + if self.server.stats['Enabled']: + self.requests_seen += self.conn.requests_seen + self.bytes_read += self.conn.rfile.bytes_read + self.bytes_written += self.conn.wfile.bytes_written + self.work_time += time.time() - self.start_time + self.start_time = None + self.conn = None + except (KeyboardInterrupt, SystemExit): + exc = sys.exc_info()[1] + self.server.interrupt = exc + + +class ThreadPool(object): + + """A Request Queue for an HTTPServer which pools threads. + + ThreadPool objects must provide min, get(), put(obj), start() + and stop(timeout) attributes. + """ + + def __init__(self, server, min=10, max=-1, + accepted_queue_size=-1, accepted_queue_timeout=10): + self.server = server + self.min = min + self.max = max + self._threads = [] + self._queue = queue.Queue(maxsize=accepted_queue_size) + self._queue_put_timeout = accepted_queue_timeout + self.get = self._queue.get + + def start(self): + """Start the pool of threads.""" + for i in range(self.min): + self._threads.append(WorkerThread(self.server)) + for worker in self._threads: + worker.setName("CP Server " + worker.getName()) + worker.start() + for worker in self._threads: + while not worker.ready: + time.sleep(.1) + + def _get_idle(self): + """Number of worker threads which are idle. Read-only.""" + return len([t for t in self._threads if t.conn is None]) + idle = property(_get_idle, doc=_get_idle.__doc__) + + def put(self, obj): + self._queue.put(obj, block=True, timeout=self._queue_put_timeout) + if obj is _SHUTDOWNREQUEST: + return + + def grow(self, amount): + """Spawn new worker threads (not above self.max).""" + if self.max > 0: + budget = max(self.max - len(self._threads), 0) + else: + # self.max <= 0 indicates no maximum + budget = float('inf') + + n_new = min(amount, budget) + + workers = [self._spawn_worker() for i in range(n_new)] + while not self._all(operator.attrgetter('ready'), workers): + time.sleep(.1) + self._threads.extend(workers) + + def _spawn_worker(self): + worker = WorkerThread(self.server) + worker.setName("CP Server " + worker.getName()) + worker.start() + return worker + + def _all(func, items): + results = [func(item) for item in items] + return reduce(operator.and_, results, True) + _all = staticmethod(_all) + + def shrink(self, amount): + """Kill off worker threads (not below self.min).""" + # Grow/shrink the pool if necessary. + # Remove any dead threads from our list + for t in self._threads: + if not t.isAlive(): + self._threads.remove(t) + amount -= 1 + + # calculate the number of threads above the minimum + n_extra = max(len(self._threads) - self.min, 0) + + # don't remove more than amount + n_to_remove = min(amount, n_extra) + + # put shutdown requests on the queue equal to the number of threads + # to remove. As each request is processed by a worker, that worker + # will terminate and be culled from the list. + for n in range(n_to_remove): + self._queue.put(_SHUTDOWNREQUEST) + + def stop(self, timeout=5): + # Must shut down threads here so the code that calls + # this method can know when all threads are stopped. + for worker in self._threads: + self._queue.put(_SHUTDOWNREQUEST) + + # Don't join currentThread (when stop is called inside a request). + current = threading.currentThread() + if timeout and timeout >= 0: + endtime = time.time() + timeout + while self._threads: + worker = self._threads.pop() + if worker is not current and worker.isAlive(): + try: + if timeout is None or timeout < 0: + worker.join() + else: + remaining_time = endtime - time.time() + if remaining_time > 0: + worker.join(remaining_time) + if worker.isAlive(): + # We exhausted the timeout. + # Forcibly shut down the socket. + c = worker.conn + if c and not c.rfile.closed: + try: + c.socket.shutdown(socket.SHUT_RD) + except TypeError: + # pyOpenSSL sockets don't take an arg + c.socket.shutdown() + worker.join() + except (AssertionError, + # Ignore repeated Ctrl-C. + # See + # https://bitbucket.org/cherrypy/cherrypy/issue/691. + KeyboardInterrupt): + pass + + def _get_qsize(self): + return self._queue.qsize() + qsize = property(_get_qsize) + + +try: + import fcntl +except ImportError: + try: + from ctypes import windll, WinError + import ctypes.wintypes + _SetHandleInformation = windll.kernel32.SetHandleInformation + _SetHandleInformation.argtypes = [ + ctypes.wintypes.HANDLE, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + ] + _SetHandleInformation.restype = ctypes.wintypes.BOOL + except ImportError: + def prevent_socket_inheritance(sock): + """Dummy function, since neither fcntl nor ctypes are available.""" + pass + else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (Windows).""" + if not _SetHandleInformation(sock.fileno(), 1, 0): + raise WinError() +else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (POSIX).""" + fd = sock.fileno() + old_flags = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC) + + +class SSLAdapter(object): + + """Base class for SSL driver library adapters. + + Required methods: + + * ``wrap(sock) -> (wrapped socket, ssl environ dict)`` + * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> + socket file object`` + """ + + def __init__(self, certificate, private_key, certificate_chain=None): + self.certificate = certificate + self.private_key = private_key + self.certificate_chain = certificate_chain + + def wrap(self, sock): + raise NotImplemented + + def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + raise NotImplemented + + +class HTTPServer(object): + + """An HTTP server.""" + + _bind_addr = "127.0.0.1" + _interrupt = None + + gateway = None + """A Gateway instance.""" + + minthreads = None + """The minimum number of worker threads to create (default 10).""" + + maxthreads = None + """The maximum number of worker threads to create (default -1 = no limit). + """ + + server_name = None + """The name of the server; defaults to socket.gethostname().""" + + protocol = "HTTP/1.1" + """The version string to write in the Status-Line of all HTTP responses. + + For example, "HTTP/1.1" is the default. This also limits the supported + features used in the response.""" + + request_queue_size = 5 + """The 'backlog' arg to socket.listen(); max queued connections + (default 5). + """ + + shutdown_timeout = 5 + """The total time, in seconds, to wait for worker threads to cleanly exit. + """ + + timeout = 10 + """The timeout in seconds for accepted connections (default 10).""" + + version = "CherryPy/3.6.0" + """A version string for the HTTPServer.""" + + software = None + """The value to set for the SERVER_SOFTWARE entry in the WSGI environ. + + If None, this defaults to ``'%s Server' % self.version``.""" + + ready = False + """An internal flag which marks whether the socket is accepting connections + """ + + max_request_header_size = 0 + """The maximum size, in bytes, for request headers, or 0 for no limit.""" + + max_request_body_size = 0 + """The maximum size, in bytes, for request bodies, or 0 for no limit.""" + + nodelay = True + """If True (the default since 3.1), sets the TCP_NODELAY socket option.""" + + ConnectionClass = HTTPConnection + """The class to use for handling HTTP connections.""" + + ssl_adapter = None + """An instance of SSLAdapter (or a subclass). + + You must have the corresponding SSL driver library installed.""" + + def __init__(self, bind_addr, gateway, minthreads=10, maxthreads=-1, + server_name=None): + self.bind_addr = bind_addr + self.gateway = gateway + + self.requests = ThreadPool(self, min=minthreads or 1, max=maxthreads) + + if not server_name: + server_name = socket.gethostname() + self.server_name = server_name + self.clear_stats() + + def clear_stats(self): + self._start_time = None + self._run_time = 0 + self.stats = { + 'Enabled': False, + 'Bind Address': lambda s: repr(self.bind_addr), + 'Run time': lambda s: (not s['Enabled']) and -1 or self.runtime(), + 'Accepts': 0, + 'Accepts/sec': lambda s: s['Accepts'] / self.runtime(), + 'Queue': lambda s: getattr(self.requests, "qsize", None), + 'Threads': lambda s: len(getattr(self.requests, "_threads", [])), + 'Threads Idle': lambda s: getattr(self.requests, "idle", None), + 'Socket Errors': 0, + 'Requests': lambda s: (not s['Enabled']) and -1 or sum( + [w['Requests'](w) for w in s['Worker Threads'].values()], 0), + 'Bytes Read': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Read'](w) for w in s['Worker Threads'].values()], 0), + 'Bytes Written': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Written'](w) for w in s['Worker Threads'].values()], + 0), + 'Work Time': lambda s: (not s['Enabled']) and -1 or sum( + [w['Work Time'](w) for w in s['Worker Threads'].values()], 0), + 'Read Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0), + 'Write Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0), + 'Worker Threads': {}, + } + logging.statistics["CherryPy HTTPServer %d" % id(self)] = self.stats + + def runtime(self): + if self._start_time is None: + return self._run_time + else: + return self._run_time + (time.time() - self._start_time) + + def __str__(self): + return "%s.%s(%r)" % (self.__module__, self.__class__.__name__, + self.bind_addr) + + def _get_bind_addr(self): + return self._bind_addr + + def _set_bind_addr(self, value): + if isinstance(value, tuple) and value[0] in ('', None): + # Despite the socket module docs, using '' does not + # allow AI_PASSIVE to work. Passing None instead + # returns '0.0.0.0' like we want. In other words: + # host AI_PASSIVE result + # '' Y 192.168.x.y + # '' N 192.168.x.y + # None Y 0.0.0.0 + # None N 127.0.0.1 + # But since you can get the same effect with an explicit + # '0.0.0.0', we deny both the empty string and None as values. + raise ValueError("Host values of '' or None are not allowed. " + "Use '0.0.0.0' (IPv4) or '::' (IPv6) instead " + "to listen on all active interfaces.") + self._bind_addr = value + bind_addr = property( + _get_bind_addr, + _set_bind_addr, + doc="""The interface on which to listen for connections. + + For TCP sockets, a (host, port) tuple. Host values may be any IPv4 + or IPv6 address, or any valid hostname. The string 'localhost' is a + synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6). + The string '0.0.0.0' is a special IPv4 entry meaning "any active + interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for + IPv6. The empty string or None are not allowed. + + For UNIX sockets, supply the filename as a string.""") + + def start(self): + """Run the server forever.""" + # We don't have to trap KeyboardInterrupt or SystemExit here, + # because cherrpy.server already does so, calling self.stop() for us. + # If you're using this server with another framework, you should + # trap those exceptions in whatever code block calls start(). + self._interrupt = None + + if self.software is None: + self.software = "%s Server" % self.version + + # SSL backward compatibility + if (self.ssl_adapter is None and + getattr(self, 'ssl_certificate', None) and + getattr(self, 'ssl_private_key', None)): + warnings.warn( + "SSL attributes are deprecated in CherryPy 3.2, and will " + "be removed in CherryPy 3.3. Use an ssl_adapter attribute " + "instead.", + DeprecationWarning + ) + try: + from cherrypy.wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter + except ImportError: + pass + else: + self.ssl_adapter = pyOpenSSLAdapter( + self.ssl_certificate, self.ssl_private_key, + getattr(self, 'ssl_certificate_chain', None)) + + # Select the appropriate socket + if isinstance(self.bind_addr, basestring): + # AF_UNIX socket + + # So we can reuse the socket... + try: + os.unlink(self.bind_addr) + except: + pass + + # So everyone can access the socket... + try: + os.chmod(self.bind_addr, 511) # 0777 + except: + pass + + info = [ + (socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_addr)] + else: + # AF_INET or AF_INET6 socket + # Get the correct address family for our host (allows IPv6 + # addresses) + host, port = self.bind_addr + try: + info = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM, 0, socket.AI_PASSIVE) + except socket.gaierror: + if ':' in self.bind_addr[0]: + info = [(socket.AF_INET6, socket.SOCK_STREAM, + 0, "", self.bind_addr + (0, 0))] + else: + info = [(socket.AF_INET, socket.SOCK_STREAM, + 0, "", self.bind_addr)] + + self.socket = None + msg = "No socket could be created" + for res in info: + af, socktype, proto, canonname, sa = res + try: + self.bind(af, socktype, proto) + except socket.error, serr: + msg = "%s -- (%s: %s)" % (msg, sa, serr) + if self.socket: + self.socket.close() + self.socket = None + continue + break + if not self.socket: + raise socket.error(msg) + + # Timeout so KeyboardInterrupt can be caught on Win32 + self.socket.settimeout(1) + self.socket.listen(self.request_queue_size) + + # Create worker threads + self.requests.start() + + self.ready = True + self._start_time = time.time() + while self.ready: + try: + self.tick() + except (KeyboardInterrupt, SystemExit): + raise + except: + self.error_log("Error in HTTPServer.tick", level=logging.ERROR, + traceback=True) + + if self.interrupt: + while self.interrupt is True: + # Wait for self.stop() to complete. See _set_interrupt. + time.sleep(0.1) + if self.interrupt: + raise self.interrupt + + def error_log(self, msg="", level=20, traceback=False): + # Override this in subclasses as desired + sys.stderr.write(msg + '\n') + sys.stderr.flush() + if traceback: + tblines = format_exc() + sys.stderr.write(tblines) + sys.stderr.flush() + + def bind(self, family, type, proto=0): + """Create (or recreate) the actual socket object.""" + self.socket = socket.socket(family, type, proto) + prevent_socket_inheritance(self.socket) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if self.nodelay and not isinstance(self.bind_addr, str): + self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + if self.ssl_adapter is not None: + self.socket = self.ssl_adapter.bind(self.socket) + + # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), + # activate dual-stack. See + # https://bitbucket.org/cherrypy/cherrypy/issue/871. + if (hasattr(socket, 'AF_INET6') and family == socket.AF_INET6 + and self.bind_addr[0] in ('::', '::0', '::0.0.0.0')): + try: + self.socket.setsockopt( + socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except (AttributeError, socket.error): + # Apparently, the socket option is not available in + # this machine's TCP stack + pass + + self.socket.bind(self.bind_addr) + + def tick(self): + """Accept a new connection and put it on the Queue.""" + try: + s, addr = self.socket.accept() + if self.stats['Enabled']: + self.stats['Accepts'] += 1 + if not self.ready: + return + + prevent_socket_inheritance(s) + if hasattr(s, 'settimeout'): + s.settimeout(self.timeout) + + makefile = CP_fileobject + ssl_env = {} + # if ssl cert and key are set, we try to be a secure HTTP server + if self.ssl_adapter is not None: + try: + s, ssl_env = self.ssl_adapter.wrap(s) + except NoSSLError: + msg = ("The client sent a plain HTTP request, but " + "this server only speaks HTTPS on this port.") + buf = ["%s 400 Bad Request\r\n" % self.protocol, + "Content-Length: %s\r\n" % len(msg), + "Content-Type: text/plain\r\n\r\n", + msg] + + wfile = makefile(s._sock, "wb", DEFAULT_BUFFER_SIZE) + try: + wfile.sendall("".join(buf)) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + return + if not s: + return + makefile = self.ssl_adapter.makefile + # Re-apply our timeout since we may have a new socket object + if hasattr(s, 'settimeout'): + s.settimeout(self.timeout) + + conn = self.ConnectionClass(self, s, makefile) + + if not isinstance(self.bind_addr, basestring): + # optional values + # Until we do DNS lookups, omit REMOTE_HOST + if addr is None: # sometimes this can happen + # figure out if AF_INET or AF_INET6. + if len(s.getsockname()) == 2: + # AF_INET + addr = ('0.0.0.0', 0) + else: + # AF_INET6 + addr = ('::', 0) + conn.remote_addr = addr[0] + conn.remote_port = addr[1] + + conn.ssl_env = ssl_env + + try: + self.requests.put(conn) + except queue.Full: + # Just drop the conn. TODO: write 503 back? + conn.close() + return + except socket.timeout: + # The only reason for the timeout in start() is so we can + # notice keyboard interrupts on Win32, which don't interrupt + # accept() by default + return + except socket.error: + x = sys.exc_info()[1] + if self.stats['Enabled']: + self.stats['Socket Errors'] += 1 + if x.args[0] in socket_error_eintr: + # I *think* this is right. EINTR should occur when a signal + # is received during the accept() call; all docs say retry + # the call, and I *think* I'm reading it right that Python + # will then go ahead and poll for and handle the signal + # elsewhere. See + # https://bitbucket.org/cherrypy/cherrypy/issue/707. + return + if x.args[0] in socket_errors_nonblocking: + # Just try again. See + # https://bitbucket.org/cherrypy/cherrypy/issue/479. + return + if x.args[0] in socket_errors_to_ignore: + # Our socket was closed. + # See https://bitbucket.org/cherrypy/cherrypy/issue/686. + return + raise + + def _get_interrupt(self): + return self._interrupt + + def _set_interrupt(self, interrupt): + self._interrupt = True + self.stop() + self._interrupt = interrupt + interrupt = property(_get_interrupt, _set_interrupt, + doc="Set this to an Exception instance to " + "interrupt the server.") + + def stop(self): + """Gracefully shutdown a server that is serving forever.""" + self.ready = False + if self._start_time is not None: + self._run_time += (time.time() - self._start_time) + self._start_time = None + + sock = getattr(self, "socket", None) + if sock: + if not isinstance(self.bind_addr, basestring): + # Touch our own socket to make accept() return immediately. + try: + host, port = sock.getsockname()[:2] + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + # Changed to use error code and not message + # See + # https://bitbucket.org/cherrypy/cherrypy/issue/860. + raise + else: + # Note that we're explicitly NOT using AI_PASSIVE, + # here, because we want an actual IP to touch. + # localhost won't work if we've bound to a public IP, + # but it will if we bound to '0.0.0.0' (INADDR_ANY). + for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + s = None + try: + s = socket.socket(af, socktype, proto) + # See + # http://groups.google.com/group/cherrypy-users/ + # browse_frm/thread/bbfe5eb39c904fe0 + s.settimeout(1.0) + s.connect((host, port)) + s.close() + except socket.error: + if s: + s.close() + if hasattr(sock, "close"): + sock.close() + self.socket = None + + self.requests.stop(self.shutdown_timeout) + + +class Gateway(object): + + """A base class to interface HTTPServer with other systems, such as WSGI. + """ + + def __init__(self, req): + self.req = req + + def respond(self): + """Process the current request. Must be overridden in a subclass.""" + raise NotImplemented + + +# These may either be wsgiserver.SSLAdapter subclasses or the string names +# of such classes (in which case they will be lazily loaded). +ssl_adapters = { + 'builtin': 'cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter', + 'pyopenssl': 'cherrypy.wsgiserver.ssl_pyopenssl.pyOpenSSLAdapter', +} + + +def get_ssl_adapter_class(name='pyopenssl'): + """Return an SSL adapter class for the given name.""" + adapter = ssl_adapters[name.lower()] + if isinstance(adapter, basestring): + last_dot = adapter.rfind(".") + attr_name = adapter[last_dot + 1:] + mod_path = adapter[:last_dot] + + try: + mod = sys.modules[mod_path] + if mod is None: + raise KeyError() + except KeyError: + # The last [''] is important. + mod = __import__(mod_path, globals(), locals(), ['']) + + # Let an AttributeError propagate outward. + try: + adapter = getattr(mod, attr_name) + except AttributeError: + raise AttributeError("'%s' object has no attribute '%s'" + % (mod_path, attr_name)) + + return adapter + +# ------------------------------- WSGI Stuff -------------------------------- # + + +class CherryPyWSGIServer(HTTPServer): + + """A subclass of HTTPServer which calls a WSGI application.""" + + wsgi_version = (1, 0) + """The version of WSGI to produce.""" + + def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None, + max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5, + accepted_queue_size=-1, accepted_queue_timeout=10): + self.requests = ThreadPool(self, min=numthreads or 1, max=max, + accepted_queue_size=accepted_queue_size, + accepted_queue_timeout=accepted_queue_timeout) + self.wsgi_app = wsgi_app + self.gateway = wsgi_gateways[self.wsgi_version] + + self.bind_addr = bind_addr + if not server_name: + server_name = socket.gethostname() + self.server_name = server_name + self.request_queue_size = request_queue_size + + self.timeout = timeout + self.shutdown_timeout = shutdown_timeout + self.clear_stats() + + def _get_numthreads(self): + return self.requests.min + + def _set_numthreads(self, value): + self.requests.min = value + numthreads = property(_get_numthreads, _set_numthreads) + + +class WSGIGateway(Gateway): + + """A base class to interface HTTPServer with WSGI.""" + + def __init__(self, req): + self.req = req + self.started_response = False + self.env = self.get_environ() + self.remaining_bytes_out = None + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + raise NotImplemented + + def respond(self): + """Process the current request.""" + response = self.req.server.wsgi_app(self.env, self.start_response) + try: + for chunk in response: + # "The start_response callable must not actually transmit + # the response headers. Instead, it must store them for the + # server or gateway to transmit only after the first + # iteration of the application return value that yields + # a NON-EMPTY string, or upon the application's first + # invocation of the write() callable." (PEP 333) + if chunk: + if isinstance(chunk, unicodestr): + chunk = chunk.encode('ISO-8859-1') + self.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + + def start_response(self, status, headers, exc_info=None): + """WSGI callable to begin the HTTP response.""" + # "The application may call start_response more than once, + # if and only if the exc_info argument is provided." + if self.started_response and not exc_info: + raise AssertionError("WSGI start_response called a second " + "time with no exc_info.") + self.started_response = True + + # "if exc_info is provided, and the HTTP headers have already been + # sent, start_response must raise an error, and should raise the + # exc_info tuple." + if self.req.sent_headers: + try: + raise exc_info[0], exc_info[1], exc_info[2] + finally: + exc_info = None + + self.req.status = status + for k, v in headers: + if not isinstance(k, str): + raise TypeError( + "WSGI response header key %r is not of type str." % k) + if not isinstance(v, str): + raise TypeError( + "WSGI response header value %r is not of type str." % v) + if k.lower() == 'content-length': + self.remaining_bytes_out = int(v) + self.req.outheaders.extend(headers) + + return self.write + + def write(self, chunk): + """WSGI callable to write unbuffered data to the client. + + This method is also used internally by start_response (to write + data from the iterable returned by the WSGI application). + """ + if not self.started_response: + raise AssertionError("WSGI write called before start_response.") + + chunklen = len(chunk) + rbo = self.remaining_bytes_out + if rbo is not None and chunklen > rbo: + if not self.req.sent_headers: + # Whew. We can send a 500 to the client. + self.req.simple_response( + "500 Internal Server Error", + "The requested resource returned more bytes than the " + "declared Content-Length.") + else: + # Dang. We have probably already sent data. Truncate the chunk + # to fit (so the client doesn't hang) and raise an error later. + chunk = chunk[:rbo] + + if not self.req.sent_headers: + self.req.sent_headers = True + self.req.send_headers() + + self.req.write(chunk) + + if rbo is not None: + rbo -= chunklen + if rbo < 0: + raise ValueError( + "Response body exceeds the declared Content-Length.") + + +class WSGIGateway_10(WSGIGateway): + + """A Gateway class to interface HTTPServer with WSGI 1.0.x.""" + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + req = self.req + env = { + # set a non-standard environ entry so the WSGI app can know what + # the *real* server protocol is (and what features to support). + # See http://www.faqs.org/rfcs/rfc2145.html. + 'ACTUAL_SERVER_PROTOCOL': req.server.protocol, + 'PATH_INFO': req.path, + 'QUERY_STRING': req.qs, + 'REMOTE_ADDR': req.conn.remote_addr or '', + 'REMOTE_PORT': str(req.conn.remote_port or ''), + 'REQUEST_METHOD': req.method, + 'REQUEST_URI': req.uri, + 'SCRIPT_NAME': '', + 'SERVER_NAME': req.server.server_name, + # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol. + 'SERVER_PROTOCOL': req.request_protocol, + 'SERVER_SOFTWARE': req.server.software, + 'wsgi.errors': sys.stderr, + 'wsgi.input': req.rfile, + 'wsgi.multiprocess': False, + 'wsgi.multithread': True, + 'wsgi.run_once': False, + 'wsgi.url_scheme': req.scheme, + 'wsgi.version': (1, 0), + } + + if isinstance(req.server.bind_addr, basestring): + # AF_UNIX. This isn't really allowed by WSGI, which doesn't + # address unix domain sockets. But it's better than nothing. + env["SERVER_PORT"] = "" + else: + env["SERVER_PORT"] = str(req.server.bind_addr[1]) + + # Request headers + for k, v in req.inheaders.iteritems(): + env["HTTP_" + k.upper().replace("-", "_")] = v + + # CONTENT_TYPE/CONTENT_LENGTH + ct = env.pop("HTTP_CONTENT_TYPE", None) + if ct is not None: + env["CONTENT_TYPE"] = ct + cl = env.pop("HTTP_CONTENT_LENGTH", None) + if cl is not None: + env["CONTENT_LENGTH"] = cl + + if req.conn.ssl_env: + env.update(req.conn.ssl_env) + + return env + + +class WSGIGateway_u0(WSGIGateway_10): + + """A Gateway class to interface HTTPServer with WSGI u.0. + + WSGI u.0 is an experimental protocol, which uses unicode for keys and + values in both Python 2 and Python 3. + """ + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + req = self.req + env_10 = WSGIGateway_10.get_environ(self) + env = dict([(k.decode('ISO-8859-1'), v) + for k, v in env_10.iteritems()]) + env[u'wsgi.version'] = ('u', 0) + + # Request-URI + env.setdefault(u'wsgi.url_encoding', u'utf-8') + try: + for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]: + env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding']) + except UnicodeDecodeError: + # Fall back to latin 1 so apps can transcode if needed. + env[u'wsgi.url_encoding'] = u'ISO-8859-1' + for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]: + env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding']) + + for k, v in sorted(env.items()): + if isinstance(v, str) and k not in ('REQUEST_URI', 'wsgi.input'): + env[k] = v.decode('ISO-8859-1') + + return env + +wsgi_gateways = { + (1, 0): WSGIGateway_10, + ('u', 0): WSGIGateway_u0, +} + + +class WSGIPathInfoDispatcher(object): + + """A WSGI dispatcher for dispatch based on the PATH_INFO. + + apps: a dict or list of (path_prefix, app) pairs. + """ + + def __init__(self, apps): + try: + apps = list(apps.items()) + except AttributeError: + pass + + # Sort the apps by len(path), descending + apps.sort(cmp=lambda x, y: cmp(len(x[0]), len(y[0]))) + apps.reverse() + + # The path_prefix strings must start, but not end, with a slash. + # Use "" instead of "/". + self.apps = [(p.rstrip("/"), a) for p, a in apps] + + def __call__(self, environ, start_response): + path = environ["PATH_INFO"] or "/" + for p, app in self.apps: + # The apps list should be sorted by length, descending. + if path.startswith(p + "/") or path == p: + environ = environ.copy() + environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p + environ["PATH_INFO"] = path[len(p):] + return app(environ, start_response) + + start_response('404 Not Found', [('Content-Type', 'text/plain'), + ('Content-Length', '0')]) + return [''] diff --git a/lib/cherrypy/wsgiserver/wsgiserver3.py b/lib/cherrypy/wsgiserver/wsgiserver3.py new file mode 100644 index 00000000..8bba6261 --- /dev/null +++ b/lib/cherrypy/wsgiserver/wsgiserver3.py @@ -0,0 +1,2176 @@ +"""A high-speed, production ready, thread pooled, generic HTTP server. + +Simplest example on how to use this module directly +(without using CherryPy's application machinery):: + + from cherrypy import wsgiserver + + def my_crazy_app(environ, start_response): + status = '200 OK' + response_headers = [('Content-type','text/plain')] + start_response(status, response_headers) + return ['Hello world!'] + + server = wsgiserver.CherryPyWSGIServer( + ('0.0.0.0', 8070), my_crazy_app, + server_name='www.cherrypy.example') + server.start() + +The CherryPy WSGI server can serve as many WSGI applications +as you want in one instance by using a WSGIPathInfoDispatcher:: + + d = WSGIPathInfoDispatcher({'/': my_crazy_app, '/blog': my_blog_app}) + server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', 80), d) + +Want SSL support? Just set server.ssl_adapter to an SSLAdapter instance. + +This won't call the CherryPy engine (application side) at all, only the +HTTP server, which is independent from the rest of CherryPy. Don't +let the name "CherryPyWSGIServer" throw you; the name merely reflects +its origin, not its coupling. + +For those of you wanting to understand internals of this module, here's the +basic call flow. The server's listening thread runs a very tight loop, +sticking incoming connections onto a Queue:: + + server = CherryPyWSGIServer(...) + server.start() + while True: + tick() + # This blocks until a request comes in: + child = socket.accept() + conn = HTTPConnection(child, ...) + server.requests.put(conn) + +Worker threads are kept in a pool and poll the Queue, popping off and then +handling each connection in turn. Each connection can consist of an arbitrary +number of requests and their responses, so we run a nested loop:: + + while True: + conn = server.requests.get() + conn.communicate() + -> while True: + req = HTTPRequest(...) + req.parse_request() + -> # Read the Request-Line, e.g. "GET /page HTTP/1.1" + req.rfile.readline() + read_headers(req.rfile, req.inheaders) + req.respond() + -> response = app(...) + try: + for chunk in response: + if chunk: + req.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + if req.close_connection: + return +""" + +__all__ = ['HTTPRequest', 'HTTPConnection', 'HTTPServer', + 'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile', + 'CP_makefile', + 'MaxSizeExceeded', 'NoSSLError', 'FatalSSLAlert', + 'WorkerThread', 'ThreadPool', 'SSLAdapter', + 'CherryPyWSGIServer', + 'Gateway', 'WSGIGateway', 'WSGIGateway_10', 'WSGIGateway_u0', + 'WSGIPathInfoDispatcher', 'get_ssl_adapter_class'] + +import os +try: + import queue +except: + import Queue as queue +import re +import email.utils +import socket +import sys +if 'win' in sys.platform and hasattr(socket, "AF_INET6"): + if not hasattr(socket, 'IPPROTO_IPV6'): + socket.IPPROTO_IPV6 = 41 + if not hasattr(socket, 'IPV6_V6ONLY'): + socket.IPV6_V6ONLY = 27 +if sys.version_info < (3, 1): + import io +else: + import _pyio as io +DEFAULT_BUFFER_SIZE = io.DEFAULT_BUFFER_SIZE + +import threading +import time +from traceback import format_exc + +if sys.version_info >= (3, 0): + bytestr = bytes + unicodestr = str + basestring = (bytes, str) + + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given + encoding. + """ + # In Python 3, the native string type is unicode + return n.encode(encoding) +else: + bytestr = str + unicodestr = unicode + basestring = basestring + + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given + encoding. + """ + # In Python 2, the native string type is bytes. Assume it's already + # in the given encoding, which for ISO-8859-1 is almost always what + # was intended. + return n + +LF = ntob('\n') +CRLF = ntob('\r\n') +TAB = ntob('\t') +SPACE = ntob(' ') +COLON = ntob(':') +SEMICOLON = ntob(';') +EMPTY = ntob('') +NUMBER_SIGN = ntob('#') +QUESTION_MARK = ntob('?') +ASTERISK = ntob('*') +FORWARD_SLASH = ntob('/') +quoted_slash = re.compile(ntob("(?i)%2F")) + +import errno + + +def plat_specific_errors(*errnames): + """Return error numbers for all errors in errnames on this platform. + + The 'errno' module contains different global constants depending on + the specific platform (OS). This function will return the list of + numeric values for a given list of potential names. + """ + errno_names = dir(errno) + nums = [getattr(errno, k) for k in errnames if k in errno_names] + # de-dupe the list + return list(dict.fromkeys(nums).keys()) + +socket_error_eintr = plat_specific_errors("EINTR", "WSAEINTR") + +socket_errors_to_ignore = plat_specific_errors( + "EPIPE", + "EBADF", "WSAEBADF", + "ENOTSOCK", "WSAENOTSOCK", + "ETIMEDOUT", "WSAETIMEDOUT", + "ECONNREFUSED", "WSAECONNREFUSED", + "ECONNRESET", "WSAECONNRESET", + "ECONNABORTED", "WSAECONNABORTED", + "ENETRESET", "WSAENETRESET", + "EHOSTDOWN", "EHOSTUNREACH", +) +socket_errors_to_ignore.append("timed out") +socket_errors_to_ignore.append("The read operation timed out") + +socket_errors_nonblocking = plat_specific_errors( + 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK') + +comma_separated_headers = [ + ntob(h) for h in + ['Accept', 'Accept-Charset', 'Accept-Encoding', + 'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control', + 'Connection', 'Content-Encoding', 'Content-Language', 'Expect', + 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE', + 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning', + 'WWW-Authenticate'] +] + + +import logging +if not hasattr(logging, 'statistics'): + logging.statistics = {} + + +def read_headers(rfile, hdict=None): + """Read headers from the given stream into the given header dict. + + If hdict is None, a new header dict is created. Returns the populated + header dict. + + Headers which are repeated are folded together using a comma if their + specification so dictates. + + This function raises ValueError when the read bytes violate the HTTP spec. + You should probably return "400 Bad Request" if this happens. + """ + if hdict is None: + hdict = {} + + while True: + line = rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError("Illegal end of headers.") + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError("HTTP requires CRLF terminators") + + if line[0] in (SPACE, TAB): + # It's a continuation line. + v = line.strip() + else: + try: + k, v = line.split(COLON, 1) + except ValueError: + raise ValueError("Illegal header line.") + # TODO: what about TE and WWW-Authenticate? + k = k.strip().title() + v = v.strip() + hname = k + + if k in comma_separated_headers: + existing = hdict.get(hname) + if existing: + v = b", ".join((existing, v)) + hdict[hname] = v + + return hdict + + +class MaxSizeExceeded(Exception): + pass + + +class SizeCheckWrapper(object): + + """Wraps a file-like object, raising MaxSizeExceeded if too large.""" + + def __init__(self, rfile, maxlen): + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + + def _check_length(self): + if self.maxlen and self.bytes_read > self.maxlen: + raise MaxSizeExceeded() + + def read(self, size=None): + data = self.rfile.read(size) + self.bytes_read += len(data) + self._check_length() + return data + + def readline(self, size=None): + if size is not None: + data = self.rfile.readline(size) + self.bytes_read += len(data) + self._check_length() + return data + + # User didn't specify a size ... + # We read the line in chunks to make sure it's not a 100MB line ! + res = [] + while True: + data = self.rfile.readline(256) + self.bytes_read += len(data) + self._check_length() + res.append(data) + # See https://bitbucket.org/cherrypy/cherrypy/issue/421 + if len(data) < 256 or data[-1:] == LF: + return EMPTY.join(res) + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline() + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline() + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def __next__(self): + data = next(self.rfile) + self.bytes_read += len(data) + self._check_length() + return data + + def next(self): + data = self.rfile.next() + self.bytes_read += len(data) + self._check_length() + return data + + +class KnownLengthRFile(object): + + """Wraps a file-like object, returning an empty string when exhausted.""" + + def __init__(self, rfile, content_length): + self.rfile = rfile + self.remaining = content_length + + def read(self, size=None): + if self.remaining == 0: + return b'' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.read(size) + self.remaining -= len(data) + return data + + def readline(self, size=None): + if self.remaining == 0: + return b'' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.readline(size) + self.remaining -= len(data) + return data + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def __next__(self): + data = next(self.rfile) + self.remaining -= len(data) + return data + + +class ChunkedRFile(object): + + """Wraps a file-like object, returning an empty string when exhausted. + + This class is intended to provide a conforming wsgi.input value for + request entities that have been encoded with the 'chunked' transfer + encoding. + """ + + def __init__(self, rfile, maxlen, bufsize=8192): + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + self.buffer = EMPTY + self.bufsize = bufsize + self.closed = False + + def _fetch(self): + if self.closed: + return + + line = self.rfile.readline() + self.bytes_read += len(line) + + if self.maxlen and self.bytes_read > self.maxlen: + raise MaxSizeExceeded("Request Entity Too Large", self.maxlen) + + line = line.strip().split(SEMICOLON, 1) + + try: + chunk_size = line.pop(0) + chunk_size = int(chunk_size, 16) + except ValueError: + raise ValueError("Bad chunked transfer size: " + repr(chunk_size)) + + if chunk_size <= 0: + self.closed = True + return + +## if line: chunk_extension = line[0] + + if self.maxlen and self.bytes_read + chunk_size > self.maxlen: + raise IOError("Request Entity Too Large") + + chunk = self.rfile.read(chunk_size) + self.bytes_read += len(chunk) + self.buffer += chunk + + crlf = self.rfile.read(2) + if crlf != CRLF: + raise ValueError( + "Bad chunked transfer coding (expected '\\r\\n', " + "got " + repr(crlf) + ")") + + def read(self, size=None): + data = EMPTY + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + if size: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + data += self.buffer + + def readline(self, size=None): + data = EMPTY + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + newline_pos = self.buffer.find(LF) + if size: + if newline_pos == -1: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + remaining = min(size - len(data), newline_pos) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + if newline_pos == -1: + data += self.buffer + else: + data += self.buffer[:newline_pos] + self.buffer = self.buffer[newline_pos:] + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def read_trailer_lines(self): + if not self.closed: + raise ValueError( + "Cannot read trailers until the request body has been read.") + + while True: + line = self.rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError("Illegal end of headers.") + + self.bytes_read += len(line) + if self.maxlen and self.bytes_read > self.maxlen: + raise IOError("Request Entity Too Large") + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError("HTTP requires CRLF terminators") + + yield line + + def close(self): + self.rfile.close() + + def __iter__(self): + # Shamelessly stolen from StringIO + total = 0 + line = self.readline(sizehint) + while line: + yield line + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + + +class HTTPRequest(object): + + """An HTTP Request (and response). + + A single HTTP connection may consist of multiple request/response pairs. + """ + + server = None + """The HTTPServer object which is receiving this request.""" + + conn = None + """The HTTPConnection object on which this request connected.""" + + inheaders = {} + """A dict of request headers.""" + + outheaders = [] + """A list of header tuples to write in the response.""" + + ready = False + """When True, the request has been parsed and is ready to begin generating + the response. When False, signals the calling Connection that the response + should not be generated and the connection should close.""" + + close_connection = False + """Signals the calling Connection that the request should close. This does + not imply an error! The client and/or server may each request that the + connection be closed.""" + + chunked_write = False + """If True, output will be encoded with the "chunked" transfer-coding. + + This value is set automatically inside send_headers.""" + + def __init__(self, server, conn): + self.server = server + self.conn = conn + + self.ready = False + self.started_request = False + self.scheme = ntob("http") + if self.server.ssl_adapter is not None: + self.scheme = ntob("https") + # Use the lowest-common protocol in case read_request_line errors. + self.response_protocol = 'HTTP/1.0' + self.inheaders = {} + + self.status = "" + self.outheaders = [] + self.sent_headers = False + self.close_connection = self.__class__.close_connection + self.chunked_read = False + self.chunked_write = self.__class__.chunked_write + + def parse_request(self): + """Parse the next HTTP request start-line and message-headers.""" + self.rfile = SizeCheckWrapper(self.conn.rfile, + self.server.max_request_header_size) + try: + success = self.read_request_line() + except MaxSizeExceeded: + self.simple_response( + "414 Request-URI Too Long", + "The Request-URI sent with the request exceeds the maximum " + "allowed bytes.") + return + else: + if not success: + return + + try: + success = self.read_request_headers() + except MaxSizeExceeded: + self.simple_response( + "413 Request Entity Too Large", + "The headers sent with the request exceed the maximum " + "allowed bytes.") + return + else: + if not success: + return + + self.ready = True + + def read_request_line(self): + # HTTP/1.1 connections are persistent by default. If a client + # requests a page, then idles (leaves the connection open), + # then rfile.readline() will raise socket.error("timed out"). + # Note that it does this based on the value given to settimeout(), + # and doesn't need the client to request or acknowledge the close + # (although your TCP stack might suffer for it: cf Apache's history + # with FIN_WAIT_2). + request_line = self.rfile.readline() + + # Set started_request to True so communicate() knows to send 408 + # from here on out. + self.started_request = True + if not request_line: + return False + + if request_line == CRLF: + # RFC 2616 sec 4.1: "...if the server is reading the protocol + # stream at the beginning of a message and receives a CRLF + # first, it should ignore the CRLF." + # But only ignore one leading line! else we enable a DoS. + request_line = self.rfile.readline() + if not request_line: + return False + + if not request_line.endswith(CRLF): + self.simple_response( + "400 Bad Request", "HTTP requires CRLF terminators") + return False + + try: + method, uri, req_protocol = request_line.strip().split(SPACE, 2) + # The [x:y] slicing is necessary for byte strings to avoid getting + # ord's + rp = int(req_protocol[5:6]), int(req_protocol[7:8]) + except ValueError: + self.simple_response("400 Bad Request", "Malformed Request-Line") + return False + + self.uri = uri + self.method = method + + # uri may be an abs_path (including "http://host.domain.tld"); + scheme, authority, path = self.parse_request_uri(uri) + if NUMBER_SIGN in path: + self.simple_response("400 Bad Request", + "Illegal #fragment in Request-URI.") + return False + + if scheme: + self.scheme = scheme + + qs = EMPTY + if QUESTION_MARK in path: + path, qs = path.split(QUESTION_MARK, 1) + + # Unquote the path+params (e.g. "/this%20path" -> "/this path"). + # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2 + # + # But note that "...a URI must be separated into its components + # before the escaped characters within those components can be + # safely decoded." http://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2 + # Therefore, "/this%2Fpath" becomes "/this%2Fpath", not "/this/path". + try: + atoms = [self.unquote_bytes(x) for x in quoted_slash.split(path)] + except ValueError: + ex = sys.exc_info()[1] + self.simple_response("400 Bad Request", ex.args[0]) + return False + path = b"%2F".join(atoms) + self.path = path + + # Note that, like wsgiref and most other HTTP servers, + # we "% HEX HEX"-unquote the path but not the query string. + self.qs = qs + + # Compare request and server HTTP protocol versions, in case our + # server does not support the requested protocol. Limit our output + # to min(req, server). We want the following output: + # request server actual written supported response + # protocol protocol response protocol feature set + # a 1.0 1.0 1.0 1.0 + # b 1.0 1.1 1.1 1.0 + # c 1.1 1.0 1.0 1.0 + # d 1.1 1.1 1.1 1.1 + # Notice that, in (b), the response will be "HTTP/1.1" even though + # the client only understands 1.0. RFC 2616 10.5.6 says we should + # only return 505 if the _major_ version is different. + # The [x:y] slicing is necessary for byte strings to avoid getting + # ord's + sp = int(self.server.protocol[5:6]), int(self.server.protocol[7:8]) + + if sp[0] != rp[0]: + self.simple_response("505 HTTP Version Not Supported") + return False + + self.request_protocol = req_protocol + self.response_protocol = "HTTP/%s.%s" % min(rp, sp) + return True + + def read_request_headers(self): + """Read self.rfile into self.inheaders. Return success.""" + + # then all the http headers + try: + read_headers(self.rfile, self.inheaders) + except ValueError: + ex = sys.exc_info()[1] + self.simple_response("400 Bad Request", ex.args[0]) + return False + + mrbs = self.server.max_request_body_size + if mrbs and int(self.inheaders.get(b"Content-Length", 0)) > mrbs: + self.simple_response( + "413 Request Entity Too Large", + "The entity sent with the request exceeds the maximum " + "allowed bytes.") + return False + + # Persistent connection support + if self.response_protocol == "HTTP/1.1": + # Both server and client are HTTP/1.1 + if self.inheaders.get(b"Connection", b"") == b"close": + self.close_connection = True + else: + # Either the server or client (or both) are HTTP/1.0 + if self.inheaders.get(b"Connection", b"") != b"Keep-Alive": + self.close_connection = True + + # Transfer-Encoding support + te = None + if self.response_protocol == "HTTP/1.1": + te = self.inheaders.get(b"Transfer-Encoding") + if te: + te = [x.strip().lower() for x in te.split(b",") if x.strip()] + + self.chunked_read = False + + if te: + for enc in te: + if enc == b"chunked": + self.chunked_read = True + else: + # Note that, even if we see "chunked", we must reject + # if there is an extension we don't recognize. + self.simple_response("501 Unimplemented") + self.close_connection = True + return False + + # From PEP 333: + # "Servers and gateways that implement HTTP 1.1 must provide + # transparent support for HTTP 1.1's "expect/continue" mechanism. + # This may be done in any of several ways: + # 1. Respond to requests containing an Expect: 100-continue request + # with an immediate "100 Continue" response, and proceed normally. + # 2. Proceed with the request normally, but provide the application + # with a wsgi.input stream that will send the "100 Continue" + # response if/when the application first attempts to read from + # the input stream. The read request must then remain blocked + # until the client responds. + # 3. Wait until the client decides that the server does not support + # expect/continue, and sends the request body on its own. + # (This is suboptimal, and is not recommended.) + # + # We used to do 3, but are now doing 1. Maybe we'll do 2 someday, + # but it seems like it would be a big slowdown for such a rare case. + if self.inheaders.get(b"Expect", b"") == b"100-continue": + # Don't use simple_response here, because it emits headers + # we don't want. See + # https://bitbucket.org/cherrypy/cherrypy/issue/951 + msg = self.server.protocol.encode( + 'ascii') + b" 100 Continue\r\n\r\n" + try: + self.conn.wfile.write(msg) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + return True + + def parse_request_uri(self, uri): + """Parse a Request-URI into (scheme, authority, path). + + Note that Request-URI's must be one of:: + + Request-URI = "*" | absoluteURI | abs_path | authority + + Therefore, a Request-URI which starts with a double forward-slash + cannot be a "net_path":: + + net_path = "//" authority [ abs_path ] + + Instead, it must be interpreted as an "abs_path" with an empty first + path segment:: + + abs_path = "/" path_segments + path_segments = segment *( "/" segment ) + segment = *pchar *( ";" param ) + param = *pchar + """ + if uri == ASTERISK: + return None, None, uri + + scheme, sep, remainder = uri.partition(b'://') + if sep and QUESTION_MARK not in scheme: + # An absoluteURI. + # If there's a scheme (and it must be http or https), then: + # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query + # ]] + authority, path_a, path_b = remainder.partition(FORWARD_SLASH) + return scheme.lower(), authority, path_a + path_b + + if uri.startswith(FORWARD_SLASH): + # An abs_path. + return None, None, uri + else: + # An authority. + return None, uri, None + + def unquote_bytes(self, path): + """takes quoted string and unquotes % encoded values""" + res = path.split(b'%') + + for i in range(1, len(res)): + item = res[i] + try: + res[i] = bytes([int(item[:2], 16)]) + item[2:] + except ValueError: + raise + return b''.join(res) + + def respond(self): + """Call the gateway and write its iterable output.""" + mrbs = self.server.max_request_body_size + if self.chunked_read: + self.rfile = ChunkedRFile(self.conn.rfile, mrbs) + else: + cl = int(self.inheaders.get(b"Content-Length", 0)) + if mrbs and mrbs < cl: + if not self.sent_headers: + self.simple_response( + "413 Request Entity Too Large", + "The entity sent with the request exceeds the " + "maximum allowed bytes.") + return + self.rfile = KnownLengthRFile(self.conn.rfile, cl) + + self.server.gateway(self).respond() + + if (self.ready and not self.sent_headers): + self.sent_headers = True + self.send_headers() + if self.chunked_write: + self.conn.wfile.write(b"0\r\n\r\n") + + def simple_response(self, status, msg=""): + """Write a simple response back to the client.""" + status = str(status) + buf = [bytes(self.server.protocol, "ascii") + SPACE + + bytes(status, "ISO-8859-1") + CRLF, + bytes("Content-Length: %s\r\n" % len(msg), "ISO-8859-1"), + b"Content-Type: text/plain\r\n"] + + if status[:3] in ("413", "414"): + # Request Entity Too Large / Request-URI Too Long + self.close_connection = True + if self.response_protocol == 'HTTP/1.1': + # This will not be true for 414, since read_request_line + # usually raises 414 before reading the whole line, and we + # therefore cannot know the proper response_protocol. + buf.append(b"Connection: close\r\n") + else: + # HTTP/1.0 had no 413/414 status nor Connection header. + # Emit 400 instead and trust the message body is enough. + status = "400 Bad Request" + + buf.append(CRLF) + if msg: + if isinstance(msg, unicodestr): + msg = msg.encode("ISO-8859-1") + buf.append(msg) + + try: + self.conn.wfile.write(b"".join(buf)) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + + def write(self, chunk): + """Write unbuffered data to the client.""" + if self.chunked_write and chunk: + buf = [bytes(hex(len(chunk)), 'ASCII')[2:], CRLF, chunk, CRLF] + self.conn.wfile.write(EMPTY.join(buf)) + else: + self.conn.wfile.write(chunk) + + def send_headers(self): + """Assert, process, and send the HTTP response message-headers. + + You must set self.status, and self.outheaders before calling this. + """ + hkeys = [key.lower() for key, value in self.outheaders] + status = int(self.status[:3]) + + if status == 413: + # Request Entity Too Large. Close conn to avoid garbage. + self.close_connection = True + elif b"content-length" not in hkeys: + # "All 1xx (informational), 204 (no content), + # and 304 (not modified) responses MUST NOT + # include a message-body." So no point chunking. + if status < 200 or status in (204, 205, 304): + pass + else: + if (self.response_protocol == 'HTTP/1.1' + and self.method != b'HEAD'): + # Use the chunked transfer-coding + self.chunked_write = True + self.outheaders.append((b"Transfer-Encoding", b"chunked")) + else: + # Closing the conn is the only way to determine len. + self.close_connection = True + + if b"connection" not in hkeys: + if self.response_protocol == 'HTTP/1.1': + # Both server and client are HTTP/1.1 or better + if self.close_connection: + self.outheaders.append((b"Connection", b"close")) + else: + # Server and/or client are HTTP/1.0 + if not self.close_connection: + self.outheaders.append((b"Connection", b"Keep-Alive")) + + if (not self.close_connection) and (not self.chunked_read): + # Read any remaining request body data on the socket. + # "If an origin server receives a request that does not include an + # Expect request-header field with the "100-continue" expectation, + # the request includes a request body, and the server responds + # with a final status code before reading the entire request body + # from the transport connection, then the server SHOULD NOT close + # the transport connection until it has read the entire request, + # or until the client closes the connection. Otherwise, the client + # might not reliably receive the response message. However, this + # requirement is not be construed as preventing a server from + # defending itself against denial-of-service attacks, or from + # badly broken client implementations." + remaining = getattr(self.rfile, 'remaining', 0) + if remaining > 0: + self.rfile.read(remaining) + + if b"date" not in hkeys: + self.outheaders.append(( + b"Date", + email.utils.formatdate(usegmt=True).encode('ISO-8859-1') + )) + + if b"server" not in hkeys: + self.outheaders.append( + (b"Server", self.server.server_name.encode('ISO-8859-1'))) + + buf = [self.server.protocol.encode( + 'ascii') + SPACE + self.status + CRLF] + for k, v in self.outheaders: + buf.append(k + COLON + SPACE + v + CRLF) + buf.append(CRLF) + self.conn.wfile.write(EMPTY.join(buf)) + + +class NoSSLError(Exception): + + """Exception raised when a client speaks HTTP to an HTTPS socket.""" + pass + + +class FatalSSLAlert(Exception): + + """Exception raised when the SSL implementation signals a fatal alert.""" + pass + + +class CP_BufferedWriter(io.BufferedWriter): + + """Faux file object attached to a socket object.""" + + def write(self, b): + self._checkClosed() + if isinstance(b, str): + raise TypeError("can't write str to binary stream") + + with self._write_lock: + self._write_buf.extend(b) + self._flush_unlocked() + return len(b) + + def _flush_unlocked(self): + self._checkClosed("flush of closed file") + while self._write_buf: + try: + # ssl sockets only except 'bytes', not bytearrays + # so perhaps we should conditionally wrap this for perf? + n = self.raw.write(bytes(self._write_buf)) + except io.BlockingIOError as e: + n = e.characters_written + del self._write_buf[:n] + + +def CP_makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + if 'r' in mode: + return io.BufferedReader(socket.SocketIO(sock, mode), bufsize) + else: + return CP_BufferedWriter(socket.SocketIO(sock, mode), bufsize) + + +class HTTPConnection(object): + + """An HTTP connection (active socket). + + server: the Server object which received this connection. + socket: the raw socket object (usually TCP) for this connection. + makefile: a fileobject class for reading from the socket. + """ + + remote_addr = None + remote_port = None + ssl_env = None + rbufsize = DEFAULT_BUFFER_SIZE + wbufsize = DEFAULT_BUFFER_SIZE + RequestHandlerClass = HTTPRequest + + def __init__(self, server, sock, makefile=CP_makefile): + self.server = server + self.socket = sock + self.rfile = makefile(sock, "rb", self.rbufsize) + self.wfile = makefile(sock, "wb", self.wbufsize) + self.requests_seen = 0 + + def communicate(self): + """Read each request and respond appropriately.""" + request_seen = False + try: + while True: + # (re)set req to None so that if something goes wrong in + # the RequestHandlerClass constructor, the error doesn't + # get written to the previous request. + req = None + req = self.RequestHandlerClass(self.server, self) + + # This order of operations should guarantee correct pipelining. + req.parse_request() + if self.server.stats['Enabled']: + self.requests_seen += 1 + if not req.ready: + # Something went wrong in the parsing (and the server has + # probably already made a simple_response). Return and + # let the conn close. + return + + request_seen = True + req.respond() + if req.close_connection: + return + except socket.error: + e = sys.exc_info()[1] + errnum = e.args[0] + # sadly SSL sockets return a different (longer) time out string + if ( + errnum == 'timed out' or + errnum == 'The read operation timed out' + ): + # Don't error if we're between requests; only error + # if 1) no request has been started at all, or 2) we're + # in the middle of a request. + # See https://bitbucket.org/cherrypy/cherrypy/issue/853 + if (not request_seen) or (req and req.started_request): + # Don't bother writing the 408 if the response + # has already started being written. + if req and not req.sent_headers: + try: + req.simple_response("408 Request Timeout") + except FatalSSLAlert: + # Close the connection. + return + elif errnum not in socket_errors_to_ignore: + self.server.error_log("socket.error %s" % repr(errnum), + level=logging.WARNING, traceback=True) + if req and not req.sent_headers: + try: + req.simple_response("500 Internal Server Error") + except FatalSSLAlert: + # Close the connection. + return + return + except (KeyboardInterrupt, SystemExit): + raise + except FatalSSLAlert: + # Close the connection. + return + except NoSSLError: + if req and not req.sent_headers: + # Unwrap our wfile + self.wfile = CP_makefile( + self.socket._sock, "wb", self.wbufsize) + req.simple_response( + "400 Bad Request", + "The client sent a plain HTTP request, but this server " + "only speaks HTTPS on this port.") + self.linger = True + except Exception: + e = sys.exc_info()[1] + self.server.error_log(repr(e), level=logging.ERROR, traceback=True) + if req and not req.sent_headers: + try: + req.simple_response("500 Internal Server Error") + except FatalSSLAlert: + # Close the connection. + return + + linger = False + + def close(self): + """Close the socket underlying this connection.""" + self.rfile.close() + + if not self.linger: + # Python's socket module does NOT call close on the kernel + # socket when you call socket.close(). We do so manually here + # because we want this server to send a FIN TCP segment + # immediately. Note this must be called *before* calling + # socket.close(), because the latter drops its reference to + # the kernel socket. + # Python 3 *probably* fixed this with socket._real_close; + # hard to tell. +# self.socket._sock.close() + self.socket.close() + else: + # On the other hand, sometimes we want to hang around for a bit + # to make sure the client has a chance to read our entire + # response. Skipping the close() calls here delays the FIN + # packet until the socket object is garbage-collected later. + # Someday, perhaps, we'll do the full lingering_close that + # Apache does, but not today. + pass + + +class TrueyZero(object): + + """An object which equals and does math like the integer 0 but evals True. + """ + + def __add__(self, other): + return other + + def __radd__(self, other): + return other +trueyzero = TrueyZero() + + +_SHUTDOWNREQUEST = None + + +class WorkerThread(threading.Thread): + + """Thread which continuously polls a Queue for Connection objects. + + Due to the timing issues of polling a Queue, a WorkerThread does not + check its own 'ready' flag after it has started. To stop the thread, + it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue + (one for each running WorkerThread). + """ + + conn = None + """The current connection pulled off the Queue, or None.""" + + server = None + """The HTTP Server which spawned this thread, and which owns the + Queue and is placing active connections into it.""" + + ready = False + """A simple flag for the calling server to know when this thread + has begun polling the Queue.""" + + def __init__(self, server): + self.ready = False + self.server = server + + self.requests_seen = 0 + self.bytes_read = 0 + self.bytes_written = 0 + self.start_time = None + self.work_time = 0 + self.stats = { + 'Requests': lambda s: self.requests_seen + ( + (self.start_time is None) and + trueyzero or + self.conn.requests_seen + ), + 'Bytes Read': lambda s: self.bytes_read + ( + (self.start_time is None) and + trueyzero or + self.conn.rfile.bytes_read + ), + 'Bytes Written': lambda s: self.bytes_written + ( + (self.start_time is None) and + trueyzero or + self.conn.wfile.bytes_written + ), + 'Work Time': lambda s: self.work_time + ( + (self.start_time is None) and + trueyzero or + time.time() - self.start_time + ), + 'Read Throughput': lambda s: s['Bytes Read'](s) / ( + s['Work Time'](s) or 1e-6), + 'Write Throughput': lambda s: s['Bytes Written'](s) / ( + s['Work Time'](s) or 1e-6), + } + threading.Thread.__init__(self) + + def run(self): + self.server.stats['Worker Threads'][self.getName()] = self.stats + try: + self.ready = True + while True: + conn = self.server.requests.get() + if conn is _SHUTDOWNREQUEST: + return + + self.conn = conn + if self.server.stats['Enabled']: + self.start_time = time.time() + try: + conn.communicate() + finally: + conn.close() + if self.server.stats['Enabled']: + self.requests_seen += self.conn.requests_seen + self.bytes_read += self.conn.rfile.bytes_read + self.bytes_written += self.conn.wfile.bytes_written + self.work_time += time.time() - self.start_time + self.start_time = None + self.conn = None + except (KeyboardInterrupt, SystemExit): + exc = sys.exc_info()[1] + self.server.interrupt = exc + + +class ThreadPool(object): + + """A Request Queue for an HTTPServer which pools threads. + + ThreadPool objects must provide min, get(), put(obj), start() + and stop(timeout) attributes. + """ + + def __init__(self, server, min=10, max=-1, + accepted_queue_size=-1, accepted_queue_timeout=10): + self.server = server + self.min = min + self.max = max + self._threads = [] + self._queue = queue.Queue(maxsize=accepted_queue_size) + self._queue_put_timeout = accepted_queue_timeout + self.get = self._queue.get + + def start(self): + """Start the pool of threads.""" + for i in range(self.min): + self._threads.append(WorkerThread(self.server)) + for worker in self._threads: + worker.setName("CP Server " + worker.getName()) + worker.start() + for worker in self._threads: + while not worker.ready: + time.sleep(.1) + + def _get_idle(self): + """Number of worker threads which are idle. Read-only.""" + return len([t for t in self._threads if t.conn is None]) + idle = property(_get_idle, doc=_get_idle.__doc__) + + def put(self, obj): + self._queue.put(obj, block=True, timeout=self._queue_put_timeout) + if obj is _SHUTDOWNREQUEST: + return + + def grow(self, amount): + """Spawn new worker threads (not above self.max).""" + if self.max > 0: + budget = max(self.max - len(self._threads), 0) + else: + # self.max <= 0 indicates no maximum + budget = float('inf') + + n_new = min(amount, budget) + + workers = [self._spawn_worker() for i in range(n_new)] + while not all(worker.ready for worker in workers): + time.sleep(.1) + self._threads.extend(workers) + + def _spawn_worker(self): + worker = WorkerThread(self.server) + worker.setName("CP Server " + worker.getName()) + worker.start() + return worker + + def shrink(self, amount): + """Kill off worker threads (not below self.min).""" + # Grow/shrink the pool if necessary. + # Remove any dead threads from our list + for t in self._threads: + if not t.isAlive(): + self._threads.remove(t) + amount -= 1 + + # calculate the number of threads above the minimum + n_extra = max(len(self._threads) - self.min, 0) + + # don't remove more than amount + n_to_remove = min(amount, n_extra) + + # put shutdown requests on the queue equal to the number of threads + # to remove. As each request is processed by a worker, that worker + # will terminate and be culled from the list. + for n in range(n_to_remove): + self._queue.put(_SHUTDOWNREQUEST) + + def stop(self, timeout=5): + # Must shut down threads here so the code that calls + # this method can know when all threads are stopped. + for worker in self._threads: + self._queue.put(_SHUTDOWNREQUEST) + + # Don't join currentThread (when stop is called inside a request). + current = threading.currentThread() + if timeout and timeout >= 0: + endtime = time.time() + timeout + while self._threads: + worker = self._threads.pop() + if worker is not current and worker.isAlive(): + try: + if timeout is None or timeout < 0: + worker.join() + else: + remaining_time = endtime - time.time() + if remaining_time > 0: + worker.join(remaining_time) + if worker.isAlive(): + # We exhausted the timeout. + # Forcibly shut down the socket. + c = worker.conn + if c and not c.rfile.closed: + try: + c.socket.shutdown(socket.SHUT_RD) + except TypeError: + # pyOpenSSL sockets don't take an arg + c.socket.shutdown() + worker.join() + except (AssertionError, + # Ignore repeated Ctrl-C. + # See + # https://bitbucket.org/cherrypy/cherrypy/issue/691. + KeyboardInterrupt): + pass + + def _get_qsize(self): + return self._queue.qsize() + qsize = property(_get_qsize) + + +try: + import fcntl +except ImportError: + try: + from ctypes import windll, WinError + import ctypes.wintypes + _SetHandleInformation = windll.kernel32.SetHandleInformation + _SetHandleInformation.argtypes = [ + ctypes.wintypes.HANDLE, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + ] + _SetHandleInformation.restype = ctypes.wintypes.BOOL + except ImportError: + def prevent_socket_inheritance(sock): + """Dummy function, since neither fcntl nor ctypes are available.""" + pass + else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (Windows).""" + if not _SetHandleInformation(sock.fileno(), 1, 0): + raise WinError() +else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (POSIX).""" + fd = sock.fileno() + old_flags = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC) + + +class SSLAdapter(object): + + """Base class for SSL driver library adapters. + + Required methods: + + * ``wrap(sock) -> (wrapped socket, ssl environ dict)`` + * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> + socket file object`` + """ + + def __init__(self, certificate, private_key, certificate_chain=None): + self.certificate = certificate + self.private_key = private_key + self.certificate_chain = certificate_chain + + def wrap(self, sock): + raise NotImplemented + + def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + raise NotImplemented + + +class HTTPServer(object): + + """An HTTP server.""" + + _bind_addr = "127.0.0.1" + _interrupt = None + + gateway = None + """A Gateway instance.""" + + minthreads = None + """The minimum number of worker threads to create (default 10).""" + + maxthreads = None + """The maximum number of worker threads to create (default -1 = no limit). + """ + + server_name = None + """The name of the server; defaults to socket.gethostname().""" + + protocol = "HTTP/1.1" + """The version string to write in the Status-Line of all HTTP responses. + + For example, "HTTP/1.1" is the default. This also limits the supported + features used in the response.""" + + request_queue_size = 5 + """The 'backlog' arg to socket.listen(); max queued connections + (default 5). + """ + + shutdown_timeout = 5 + """The total time, in seconds, to wait for worker threads to cleanly exit. + """ + + timeout = 10 + """The timeout in seconds for accepted connections (default 10).""" + + version = "CherryPy/3.6.0" + """A version string for the HTTPServer.""" + + software = None + """The value to set for the SERVER_SOFTWARE entry in the WSGI environ. + + If None, this defaults to ``'%s Server' % self.version``.""" + + ready = False + """An internal flag which marks whether the socket is accepting + connections. + """ + + max_request_header_size = 0 + """The maximum size, in bytes, for request headers, or 0 for no limit.""" + + max_request_body_size = 0 + """The maximum size, in bytes, for request bodies, or 0 for no limit.""" + + nodelay = True + """If True (the default since 3.1), sets the TCP_NODELAY socket option.""" + + ConnectionClass = HTTPConnection + """The class to use for handling HTTP connections.""" + + ssl_adapter = None + """An instance of SSLAdapter (or a subclass). + + You must have the corresponding SSL driver library installed.""" + + def __init__(self, bind_addr, gateway, minthreads=10, maxthreads=-1, + server_name=None): + self.bind_addr = bind_addr + self.gateway = gateway + + self.requests = ThreadPool(self, min=minthreads or 1, max=maxthreads) + + if not server_name: + server_name = socket.gethostname() + self.server_name = server_name + self.clear_stats() + + def clear_stats(self): + self._start_time = None + self._run_time = 0 + self.stats = { + 'Enabled': False, + 'Bind Address': lambda s: repr(self.bind_addr), + 'Run time': lambda s: (not s['Enabled']) and -1 or self.runtime(), + 'Accepts': 0, + 'Accepts/sec': lambda s: s['Accepts'] / self.runtime(), + 'Queue': lambda s: getattr(self.requests, "qsize", None), + 'Threads': lambda s: len(getattr(self.requests, "_threads", [])), + 'Threads Idle': lambda s: getattr(self.requests, "idle", None), + 'Socket Errors': 0, + 'Requests': lambda s: (not s['Enabled']) and -1 or sum( + [w['Requests'](w) for w in s['Worker Threads'].values()], 0), + 'Bytes Read': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Read'](w) for w in s['Worker Threads'].values()], 0), + 'Bytes Written': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Written'](w) for w in s['Worker Threads'].values()], + 0), + 'Work Time': lambda s: (not s['Enabled']) and -1 or sum( + [w['Work Time'](w) for w in s['Worker Threads'].values()], 0), + 'Read Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0), + 'Write Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0), + 'Worker Threads': {}, + } + logging.statistics["CherryPy HTTPServer %d" % id(self)] = self.stats + + def runtime(self): + if self._start_time is None: + return self._run_time + else: + return self._run_time + (time.time() - self._start_time) + + def __str__(self): + return "%s.%s(%r)" % (self.__module__, self.__class__.__name__, + self.bind_addr) + + def _get_bind_addr(self): + return self._bind_addr + + def _set_bind_addr(self, value): + if isinstance(value, tuple) and value[0] in ('', None): + # Despite the socket module docs, using '' does not + # allow AI_PASSIVE to work. Passing None instead + # returns '0.0.0.0' like we want. In other words: + # host AI_PASSIVE result + # '' Y 192.168.x.y + # '' N 192.168.x.y + # None Y 0.0.0.0 + # None N 127.0.0.1 + # But since you can get the same effect with an explicit + # '0.0.0.0', we deny both the empty string and None as values. + raise ValueError("Host values of '' or None are not allowed. " + "Use '0.0.0.0' (IPv4) or '::' (IPv6) instead " + "to listen on all active interfaces.") + self._bind_addr = value + bind_addr = property( + _get_bind_addr, + _set_bind_addr, + doc="""The interface on which to listen for connections. + + For TCP sockets, a (host, port) tuple. Host values may be any IPv4 + or IPv6 address, or any valid hostname. The string 'localhost' is a + synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6). + The string '0.0.0.0' is a special IPv4 entry meaning "any active + interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for + IPv6. The empty string or None are not allowed. + + For UNIX sockets, supply the filename as a string.""") + + def start(self): + """Run the server forever.""" + # We don't have to trap KeyboardInterrupt or SystemExit here, + # because cherrpy.server already does so, calling self.stop() for us. + # If you're using this server with another framework, you should + # trap those exceptions in whatever code block calls start(). + self._interrupt = None + + if self.software is None: + self.software = "%s Server" % self.version + + # Select the appropriate socket + if isinstance(self.bind_addr, basestring): + # AF_UNIX socket + + # So we can reuse the socket... + try: + os.unlink(self.bind_addr) + except: + pass + + # So everyone can access the socket... + try: + os.chmod(self.bind_addr, 511) # 0777 + except: + pass + + info = [ + (socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_addr)] + else: + # AF_INET or AF_INET6 socket + # Get the correct address family for our host (allows IPv6 + # addresses) + host, port = self.bind_addr + try: + info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM, 0, + socket.AI_PASSIVE) + except socket.gaierror: + if ':' in self.bind_addr[0]: + info = [(socket.AF_INET6, socket.SOCK_STREAM, + 0, "", self.bind_addr + (0, 0))] + else: + info = [(socket.AF_INET, socket.SOCK_STREAM, + 0, "", self.bind_addr)] + + self.socket = None + msg = "No socket could be created" + for res in info: + af, socktype, proto, canonname, sa = res + try: + self.bind(af, socktype, proto) + except socket.error as serr: + msg = "%s -- (%s: %s)" % (msg, sa, serr) + if self.socket: + self.socket.close() + self.socket = None + continue + break + if not self.socket: + raise socket.error(msg) + + # Timeout so KeyboardInterrupt can be caught on Win32 + self.socket.settimeout(1) + self.socket.listen(self.request_queue_size) + + # Create worker threads + self.requests.start() + + self.ready = True + self._start_time = time.time() + while self.ready: + try: + self.tick() + except (KeyboardInterrupt, SystemExit): + raise + except: + self.error_log("Error in HTTPServer.tick", level=logging.ERROR, + traceback=True) + if self.interrupt: + while self.interrupt is True: + # Wait for self.stop() to complete. See _set_interrupt. + time.sleep(0.1) + if self.interrupt: + raise self.interrupt + + def error_log(self, msg="", level=20, traceback=False): + # Override this in subclasses as desired + sys.stderr.write(msg + '\n') + sys.stderr.flush() + if traceback: + tblines = format_exc() + sys.stderr.write(tblines) + sys.stderr.flush() + + def bind(self, family, type, proto=0): + """Create (or recreate) the actual socket object.""" + self.socket = socket.socket(family, type, proto) + prevent_socket_inheritance(self.socket) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if self.nodelay and not isinstance(self.bind_addr, str): + self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + if self.ssl_adapter is not None: + self.socket = self.ssl_adapter.bind(self.socket) + + # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), + # activate dual-stack. See + # https://bitbucket.org/cherrypy/cherrypy/issue/871. + if (hasattr(socket, 'AF_INET6') and family == socket.AF_INET6 + and self.bind_addr[0] in ('::', '::0', '::0.0.0.0')): + try: + self.socket.setsockopt( + socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except (AttributeError, socket.error): + # Apparently, the socket option is not available in + # this machine's TCP stack + pass + + self.socket.bind(self.bind_addr) + + def tick(self): + """Accept a new connection and put it on the Queue.""" + try: + s, addr = self.socket.accept() + if self.stats['Enabled']: + self.stats['Accepts'] += 1 + if not self.ready: + return + + prevent_socket_inheritance(s) + if hasattr(s, 'settimeout'): + s.settimeout(self.timeout) + + makefile = CP_makefile + ssl_env = {} + # if ssl cert and key are set, we try to be a secure HTTP server + if self.ssl_adapter is not None: + try: + s, ssl_env = self.ssl_adapter.wrap(s) + except NoSSLError: + msg = ("The client sent a plain HTTP request, but " + "this server only speaks HTTPS on this port.") + buf = ["%s 400 Bad Request\r\n" % self.protocol, + "Content-Length: %s\r\n" % len(msg), + "Content-Type: text/plain\r\n\r\n", + msg] + + wfile = makefile(s, "wb", DEFAULT_BUFFER_SIZE) + try: + wfile.write("".join(buf).encode('ISO-8859-1')) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + return + if not s: + return + makefile = self.ssl_adapter.makefile + # Re-apply our timeout since we may have a new socket object + if hasattr(s, 'settimeout'): + s.settimeout(self.timeout) + + conn = self.ConnectionClass(self, s, makefile) + + if not isinstance(self.bind_addr, basestring): + # optional values + # Until we do DNS lookups, omit REMOTE_HOST + if addr is None: # sometimes this can happen + # figure out if AF_INET or AF_INET6. + if len(s.getsockname()) == 2: + # AF_INET + addr = ('0.0.0.0', 0) + else: + # AF_INET6 + addr = ('::', 0) + conn.remote_addr = addr[0] + conn.remote_port = addr[1] + + conn.ssl_env = ssl_env + + try: + self.requests.put(conn) + except queue.Full: + # Just drop the conn. TODO: write 503 back? + conn.close() + return + except socket.timeout: + # The only reason for the timeout in start() is so we can + # notice keyboard interrupts on Win32, which don't interrupt + # accept() by default + return + except socket.error: + x = sys.exc_info()[1] + if self.stats['Enabled']: + self.stats['Socket Errors'] += 1 + if x.args[0] in socket_error_eintr: + # I *think* this is right. EINTR should occur when a signal + # is received during the accept() call; all docs say retry + # the call, and I *think* I'm reading it right that Python + # will then go ahead and poll for and handle the signal + # elsewhere. See + # https://bitbucket.org/cherrypy/cherrypy/issue/707. + return + if x.args[0] in socket_errors_nonblocking: + # Just try again. See + # https://bitbucket.org/cherrypy/cherrypy/issue/479. + return + if x.args[0] in socket_errors_to_ignore: + # Our socket was closed. + # See https://bitbucket.org/cherrypy/cherrypy/issue/686. + return + raise + + def _get_interrupt(self): + return self._interrupt + + def _set_interrupt(self, interrupt): + self._interrupt = True + self.stop() + self._interrupt = interrupt + interrupt = property(_get_interrupt, _set_interrupt, + doc="Set this to an Exception instance to " + "interrupt the server.") + + def stop(self): + """Gracefully shutdown a server that is serving forever.""" + self.ready = False + if self._start_time is not None: + self._run_time += (time.time() - self._start_time) + self._start_time = None + + sock = getattr(self, "socket", None) + if sock: + if not isinstance(self.bind_addr, basestring): + # Touch our own socket to make accept() return immediately. + try: + host, port = sock.getsockname()[:2] + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + # Changed to use error code and not message + # See + # https://bitbucket.org/cherrypy/cherrypy/issue/860. + raise + else: + # Note that we're explicitly NOT using AI_PASSIVE, + # here, because we want an actual IP to touch. + # localhost won't work if we've bound to a public IP, + # but it will if we bound to '0.0.0.0' (INADDR_ANY). + for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + s = None + try: + s = socket.socket(af, socktype, proto) + # See + # http://groups.google.com/group/cherrypy-users/ + # browse_frm/thread/bbfe5eb39c904fe0 + s.settimeout(1.0) + s.connect((host, port)) + s.close() + except socket.error: + if s: + s.close() + if hasattr(sock, "close"): + sock.close() + self.socket = None + + self.requests.stop(self.shutdown_timeout) + + +class Gateway(object): + + """A base class to interface HTTPServer with other systems, such as WSGI. + """ + + def __init__(self, req): + self.req = req + + def respond(self): + """Process the current request. Must be overridden in a subclass.""" + raise NotImplemented + + +# These may either be wsgiserver.SSLAdapter subclasses or the string names +# of such classes (in which case they will be lazily loaded). +ssl_adapters = { + 'builtin': 'cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter', +} + + +def get_ssl_adapter_class(name='builtin'): + """Return an SSL adapter class for the given name.""" + adapter = ssl_adapters[name.lower()] + if isinstance(adapter, basestring): + last_dot = adapter.rfind(".") + attr_name = adapter[last_dot + 1:] + mod_path = adapter[:last_dot] + + try: + mod = sys.modules[mod_path] + if mod is None: + raise KeyError() + except KeyError: + # The last [''] is important. + mod = __import__(mod_path, globals(), locals(), ['']) + + # Let an AttributeError propagate outward. + try: + adapter = getattr(mod, attr_name) + except AttributeError: + raise AttributeError("'%s' object has no attribute '%s'" + % (mod_path, attr_name)) + + return adapter + +# ------------------------------- WSGI Stuff -------------------------------- # + + +class CherryPyWSGIServer(HTTPServer): + + """A subclass of HTTPServer which calls a WSGI application.""" + + wsgi_version = (1, 0) + """The version of WSGI to produce.""" + + def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None, + max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5, + accepted_queue_size=-1, accepted_queue_timeout=10): + self.requests = ThreadPool(self, min=numthreads or 1, max=max, + accepted_queue_size=accepted_queue_size, + accepted_queue_timeout=accepted_queue_timeout) + self.wsgi_app = wsgi_app + self.gateway = wsgi_gateways[self.wsgi_version] + + self.bind_addr = bind_addr + if not server_name: + server_name = socket.gethostname() + self.server_name = server_name + self.request_queue_size = request_queue_size + + self.timeout = timeout + self.shutdown_timeout = shutdown_timeout + self.clear_stats() + + def _get_numthreads(self): + return self.requests.min + + def _set_numthreads(self, value): + self.requests.min = value + numthreads = property(_get_numthreads, _set_numthreads) + + +class WSGIGateway(Gateway): + + """A base class to interface HTTPServer with WSGI.""" + + def __init__(self, req): + self.req = req + self.started_response = False + self.env = self.get_environ() + self.remaining_bytes_out = None + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + raise NotImplemented + + def respond(self): + """Process the current request.""" + response = self.req.server.wsgi_app(self.env, self.start_response) + try: + for chunk in response: + # "The start_response callable must not actually transmit + # the response headers. Instead, it must store them for the + # server or gateway to transmit only after the first + # iteration of the application return value that yields + # a NON-EMPTY string, or upon the application's first + # invocation of the write() callable." (PEP 333) + if chunk: + if isinstance(chunk, unicodestr): + chunk = chunk.encode('ISO-8859-1') + self.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + + def start_response(self, status, headers, exc_info=None): + """WSGI callable to begin the HTTP response.""" + # "The application may call start_response more than once, + # if and only if the exc_info argument is provided." + if self.started_response and not exc_info: + raise AssertionError("WSGI start_response called a second " + "time with no exc_info.") + self.started_response = True + + # "if exc_info is provided, and the HTTP headers have already been + # sent, start_response must raise an error, and should raise the + # exc_info tuple." + if self.req.sent_headers: + try: + raise exc_info[0](exc_info[1]).with_traceback(exc_info[2]) + finally: + exc_info = None + + # According to PEP 3333, when using Python 3, the response status + # and headers must be bytes masquerading as unicode; that is, they + # must be of type "str" but are restricted to code points in the + # "latin-1" set. + if not isinstance(status, str): + raise TypeError("WSGI response status is not of type str.") + self.req.status = status.encode('ISO-8859-1') + + for k, v in headers: + if not isinstance(k, str): + raise TypeError( + "WSGI response header key %r is not of type str." % k) + if not isinstance(v, str): + raise TypeError( + "WSGI response header value %r is not of type str." % v) + if k.lower() == 'content-length': + self.remaining_bytes_out = int(v) + self.req.outheaders.append( + (k.encode('ISO-8859-1'), v.encode('ISO-8859-1'))) + + return self.write + + def write(self, chunk): + """WSGI callable to write unbuffered data to the client. + + This method is also used internally by start_response (to write + data from the iterable returned by the WSGI application). + """ + if not self.started_response: + raise AssertionError("WSGI write called before start_response.") + + chunklen = len(chunk) + rbo = self.remaining_bytes_out + if rbo is not None and chunklen > rbo: + if not self.req.sent_headers: + # Whew. We can send a 500 to the client. + self.req.simple_response("500 Internal Server Error", + "The requested resource returned " + "more bytes than the declared " + "Content-Length.") + else: + # Dang. We have probably already sent data. Truncate the chunk + # to fit (so the client doesn't hang) and raise an error later. + chunk = chunk[:rbo] + + if not self.req.sent_headers: + self.req.sent_headers = True + self.req.send_headers() + + self.req.write(chunk) + + if rbo is not None: + rbo -= chunklen + if rbo < 0: + raise ValueError( + "Response body exceeds the declared Content-Length.") + + +class WSGIGateway_10(WSGIGateway): + + """A Gateway class to interface HTTPServer with WSGI 1.0.x.""" + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + req = self.req + env = { + # set a non-standard environ entry so the WSGI app can know what + # the *real* server protocol is (and what features to support). + # See http://www.faqs.org/rfcs/rfc2145.html. + 'ACTUAL_SERVER_PROTOCOL': req.server.protocol, + 'PATH_INFO': req.path.decode('ISO-8859-1'), + 'QUERY_STRING': req.qs.decode('ISO-8859-1'), + 'REMOTE_ADDR': req.conn.remote_addr or '', + 'REMOTE_PORT': str(req.conn.remote_port or ''), + 'REQUEST_METHOD': req.method.decode('ISO-8859-1'), + 'REQUEST_URI': req.uri.decode('ISO-8859-1'), + 'SCRIPT_NAME': '', + 'SERVER_NAME': req.server.server_name, + # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol. + 'SERVER_PROTOCOL': req.request_protocol.decode('ISO-8859-1'), + 'SERVER_SOFTWARE': req.server.software, + 'wsgi.errors': sys.stderr, + 'wsgi.input': req.rfile, + 'wsgi.multiprocess': False, + 'wsgi.multithread': True, + 'wsgi.run_once': False, + 'wsgi.url_scheme': req.scheme.decode('ISO-8859-1'), + 'wsgi.version': (1, 0), + } + if isinstance(req.server.bind_addr, basestring): + # AF_UNIX. This isn't really allowed by WSGI, which doesn't + # address unix domain sockets. But it's better than nothing. + env["SERVER_PORT"] = "" + else: + env["SERVER_PORT"] = str(req.server.bind_addr[1]) + + # Request headers + for k, v in req.inheaders.items(): + k = k.decode('ISO-8859-1').upper().replace("-", "_") + env["HTTP_" + k] = v.decode('ISO-8859-1') + + # CONTENT_TYPE/CONTENT_LENGTH + ct = env.pop("HTTP_CONTENT_TYPE", None) + if ct is not None: + env["CONTENT_TYPE"] = ct + cl = env.pop("HTTP_CONTENT_LENGTH", None) + if cl is not None: + env["CONTENT_LENGTH"] = cl + + if req.conn.ssl_env: + env.update(req.conn.ssl_env) + + return env + + +class WSGIGateway_u0(WSGIGateway_10): + + """A Gateway class to interface HTTPServer with WSGI u.0. + + WSGI u.0 is an experimental protocol, which uses unicode for keys + and values in both Python 2 and Python 3. + """ + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + req = self.req + env_10 = WSGIGateway_10.get_environ(self) + env = env_10.copy() + env['wsgi.version'] = ('u', 0) + + # Request-URI + env.setdefault('wsgi.url_encoding', 'utf-8') + try: + # SCRIPT_NAME is the empty string, who cares what encoding it is? + env["PATH_INFO"] = req.path.decode(env['wsgi.url_encoding']) + env["QUERY_STRING"] = req.qs.decode(env['wsgi.url_encoding']) + except UnicodeDecodeError: + # Fall back to latin 1 so apps can transcode if needed. + env['wsgi.url_encoding'] = 'ISO-8859-1' + env["PATH_INFO"] = env_10["PATH_INFO"] + env["QUERY_STRING"] = env_10["QUERY_STRING"] + + return env + +wsgi_gateways = { + (1, 0): WSGIGateway_10, + ('u', 0): WSGIGateway_u0, +} + + +class WSGIPathInfoDispatcher(object): + + """A WSGI dispatcher for dispatch based on the PATH_INFO. + + apps: a dict or list of (path_prefix, app) pairs. + """ + + def __init__(self, apps): + try: + apps = list(apps.items()) + except AttributeError: + pass + + # Sort the apps by len(path), descending + apps.sort() + apps.reverse() + + # The path_prefix strings must start, but not end, with a slash. + # Use "" instead of "/". + self.apps = [(p.rstrip("/"), a) for p, a in apps] + + def __call__(self, environ, start_response): + path = environ["PATH_INFO"] or "/" + for p, app in self.apps: + # The apps list should be sorted by length, descending. + if path.startswith(p + "/") or path == p: + environ = environ.copy() + environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p + environ["PATH_INFO"] = path[len(p):] + return app(environ, start_response) + + start_response('404 Not Found', [('Content-Type', 'text/plain'), + ('Content-Length', '0')]) + return [''] diff --git a/lib/concurrent/LICENSE b/lib/concurrent/LICENSE new file mode 100644 index 00000000..c430db0f --- /dev/null +++ b/lib/concurrent/LICENSE @@ -0,0 +1,21 @@ +Copyright 2009 Brian Quinlan. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY BRIAN QUINLAN "AS IS" AND ANY EXPRESS OR IMPLIED +WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT +HALL THE FREEBSD PROJECT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/lib/concurrent/__init__.py b/lib/concurrent/__init__.py new file mode 100644 index 00000000..b36383a6 --- /dev/null +++ b/lib/concurrent/__init__.py @@ -0,0 +1,3 @@ +from pkgutil import extend_path + +__path__ = extend_path(__path__, __name__) diff --git a/lib/concurrent/futures/__init__.py b/lib/concurrent/futures/__init__.py new file mode 100644 index 00000000..fef52819 --- /dev/null +++ b/lib/concurrent/futures/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2009 Brian Quinlan. All Rights Reserved. +# Licensed to PSF under a Contributor Agreement. + +"""Execute computations asynchronously using threads or processes.""" + +__author__ = 'Brian Quinlan (brian@sweetapp.com)' + +from concurrent.futures._base import (FIRST_COMPLETED, + FIRST_EXCEPTION, + ALL_COMPLETED, + CancelledError, + TimeoutError, + Future, + Executor, + wait, + as_completed) +from concurrent.futures.thread import ThreadPoolExecutor + +# Jython doesn't have multiprocessing +try: + from concurrent.futures.process import ProcessPoolExecutor +except ImportError: + pass diff --git a/lib/concurrent/futures/_base.py b/lib/concurrent/futures/_base.py new file mode 100644 index 00000000..6f0c0f3b --- /dev/null +++ b/lib/concurrent/futures/_base.py @@ -0,0 +1,605 @@ +# Copyright 2009 Brian Quinlan. All Rights Reserved. +# Licensed to PSF under a Contributor Agreement. + +from __future__ import with_statement +import logging +import threading +import time + +from concurrent.futures._compat import reraise + +try: + from collections import namedtuple +except ImportError: + from concurrent.futures._compat import namedtuple + +__author__ = 'Brian Quinlan (brian@sweetapp.com)' + +FIRST_COMPLETED = 'FIRST_COMPLETED' +FIRST_EXCEPTION = 'FIRST_EXCEPTION' +ALL_COMPLETED = 'ALL_COMPLETED' +_AS_COMPLETED = '_AS_COMPLETED' + +# Possible future states (for internal use by the futures package). +PENDING = 'PENDING' +RUNNING = 'RUNNING' +# The future was cancelled by the user... +CANCELLED = 'CANCELLED' +# ...and _Waiter.add_cancelled() was called by a worker. +CANCELLED_AND_NOTIFIED = 'CANCELLED_AND_NOTIFIED' +FINISHED = 'FINISHED' + +_FUTURE_STATES = [ + PENDING, + RUNNING, + CANCELLED, + CANCELLED_AND_NOTIFIED, + FINISHED +] + +_STATE_TO_DESCRIPTION_MAP = { + PENDING: "pending", + RUNNING: "running", + CANCELLED: "cancelled", + CANCELLED_AND_NOTIFIED: "cancelled", + FINISHED: "finished" +} + +# Logger for internal use by the futures package. +LOGGER = logging.getLogger("concurrent.futures") + +class Error(Exception): + """Base class for all future-related exceptions.""" + pass + +class CancelledError(Error): + """The Future was cancelled.""" + pass + +class TimeoutError(Error): + """The operation exceeded the given deadline.""" + pass + +class _Waiter(object): + """Provides the event that wait() and as_completed() block on.""" + def __init__(self): + self.event = threading.Event() + self.finished_futures = [] + + def add_result(self, future): + self.finished_futures.append(future) + + def add_exception(self, future): + self.finished_futures.append(future) + + def add_cancelled(self, future): + self.finished_futures.append(future) + +class _AsCompletedWaiter(_Waiter): + """Used by as_completed().""" + + def __init__(self): + super(_AsCompletedWaiter, self).__init__() + self.lock = threading.Lock() + + def add_result(self, future): + with self.lock: + super(_AsCompletedWaiter, self).add_result(future) + self.event.set() + + def add_exception(self, future): + with self.lock: + super(_AsCompletedWaiter, self).add_exception(future) + self.event.set() + + def add_cancelled(self, future): + with self.lock: + super(_AsCompletedWaiter, self).add_cancelled(future) + self.event.set() + +class _FirstCompletedWaiter(_Waiter): + """Used by wait(return_when=FIRST_COMPLETED).""" + + def add_result(self, future): + super(_FirstCompletedWaiter, self).add_result(future) + self.event.set() + + def add_exception(self, future): + super(_FirstCompletedWaiter, self).add_exception(future) + self.event.set() + + def add_cancelled(self, future): + super(_FirstCompletedWaiter, self).add_cancelled(future) + self.event.set() + +class _AllCompletedWaiter(_Waiter): + """Used by wait(return_when=FIRST_EXCEPTION and ALL_COMPLETED).""" + + def __init__(self, num_pending_calls, stop_on_exception): + self.num_pending_calls = num_pending_calls + self.stop_on_exception = stop_on_exception + self.lock = threading.Lock() + super(_AllCompletedWaiter, self).__init__() + + def _decrement_pending_calls(self): + with self.lock: + self.num_pending_calls -= 1 + if not self.num_pending_calls: + self.event.set() + + def add_result(self, future): + super(_AllCompletedWaiter, self).add_result(future) + self._decrement_pending_calls() + + def add_exception(self, future): + super(_AllCompletedWaiter, self).add_exception(future) + if self.stop_on_exception: + self.event.set() + else: + self._decrement_pending_calls() + + def add_cancelled(self, future): + super(_AllCompletedWaiter, self).add_cancelled(future) + self._decrement_pending_calls() + +class _AcquireFutures(object): + """A context manager that does an ordered acquire of Future conditions.""" + + def __init__(self, futures): + self.futures = sorted(futures, key=id) + + def __enter__(self): + for future in self.futures: + future._condition.acquire() + + def __exit__(self, *args): + for future in self.futures: + future._condition.release() + +def _create_and_install_waiters(fs, return_when): + if return_when == _AS_COMPLETED: + waiter = _AsCompletedWaiter() + elif return_when == FIRST_COMPLETED: + waiter = _FirstCompletedWaiter() + else: + pending_count = sum( + f._state not in [CANCELLED_AND_NOTIFIED, FINISHED] for f in fs) + + if return_when == FIRST_EXCEPTION: + waiter = _AllCompletedWaiter(pending_count, stop_on_exception=True) + elif return_when == ALL_COMPLETED: + waiter = _AllCompletedWaiter(pending_count, stop_on_exception=False) + else: + raise ValueError("Invalid return condition: %r" % return_when) + + for f in fs: + f._waiters.append(waiter) + + return waiter + +def as_completed(fs, timeout=None): + """An iterator over the given futures that yields each as it completes. + + Args: + fs: The sequence of Futures (possibly created by different Executors) to + iterate over. + timeout: The maximum number of seconds to wait. If None, then there + is no limit on the wait time. + + Returns: + An iterator that yields the given Futures as they complete (finished or + cancelled). + + Raises: + TimeoutError: If the entire result iterator could not be generated + before the given timeout. + """ + if timeout is not None: + end_time = timeout + time.time() + + with _AcquireFutures(fs): + finished = set( + f for f in fs + if f._state in [CANCELLED_AND_NOTIFIED, FINISHED]) + pending = set(fs) - finished + waiter = _create_and_install_waiters(fs, _AS_COMPLETED) + + try: + for future in finished: + yield future + + while pending: + if timeout is None: + wait_timeout = None + else: + wait_timeout = end_time - time.time() + if wait_timeout < 0: + raise TimeoutError( + '%d (of %d) futures unfinished' % ( + len(pending), len(fs))) + + waiter.event.wait(wait_timeout) + + with waiter.lock: + finished = waiter.finished_futures + waiter.finished_futures = [] + waiter.event.clear() + + for future in finished: + yield future + pending.remove(future) + + finally: + for f in fs: + f._waiters.remove(waiter) + +DoneAndNotDoneFutures = namedtuple( + 'DoneAndNotDoneFutures', 'done not_done') +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the futures in the given sequence to complete. + + Args: + fs: The sequence of Futures (possibly created by different Executors) to + wait upon. + timeout: The maximum number of seconds to wait. If None, then there + is no limit on the wait time. + return_when: Indicates when this function should return. The options + are: + + FIRST_COMPLETED - Return when any future finishes or is + cancelled. + FIRST_EXCEPTION - Return when any future finishes by raising an + exception. If no future raises an exception + then it is equivalent to ALL_COMPLETED. + ALL_COMPLETED - Return when all futures finish or are cancelled. + + Returns: + A named 2-tuple of sets. The first set, named 'done', contains the + futures that completed (is finished or cancelled) before the wait + completed. The second set, named 'not_done', contains uncompleted + futures. + """ + with _AcquireFutures(fs): + done = set(f for f in fs + if f._state in [CANCELLED_AND_NOTIFIED, FINISHED]) + not_done = set(fs) - done + + if (return_when == FIRST_COMPLETED) and done: + return DoneAndNotDoneFutures(done, not_done) + elif (return_when == FIRST_EXCEPTION) and done: + if any(f for f in done + if not f.cancelled() and f.exception() is not None): + return DoneAndNotDoneFutures(done, not_done) + + if len(done) == len(fs): + return DoneAndNotDoneFutures(done, not_done) + + waiter = _create_and_install_waiters(fs, return_when) + + waiter.event.wait(timeout) + for f in fs: + f._waiters.remove(waiter) + + done.update(waiter.finished_futures) + return DoneAndNotDoneFutures(done, set(fs) - done) + +class Future(object): + """Represents the result of an asynchronous computation.""" + + def __init__(self): + """Initializes the future. Should not be called by clients.""" + self._condition = threading.Condition() + self._state = PENDING + self._result = None + self._exception = None + self._traceback = None + self._waiters = [] + self._done_callbacks = [] + + def _invoke_callbacks(self): + for callback in self._done_callbacks: + try: + callback(self) + except Exception: + LOGGER.exception('exception calling callback for %r', self) + + def __repr__(self): + with self._condition: + if self._state == FINISHED: + if self._exception: + return '' % ( + hex(id(self)), + _STATE_TO_DESCRIPTION_MAP[self._state], + self._exception.__class__.__name__) + else: + return '' % ( + hex(id(self)), + _STATE_TO_DESCRIPTION_MAP[self._state], + self._result.__class__.__name__) + return '' % ( + hex(id(self)), + _STATE_TO_DESCRIPTION_MAP[self._state]) + + def cancel(self): + """Cancel the future if possible. + + Returns True if the future was cancelled, False otherwise. A future + cannot be cancelled if it is running or has already completed. + """ + with self._condition: + if self._state in [RUNNING, FINISHED]: + return False + + if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]: + return True + + self._state = CANCELLED + self._condition.notify_all() + + self._invoke_callbacks() + return True + + def cancelled(self): + """Return True if the future has cancelled.""" + with self._condition: + return self._state in [CANCELLED, CANCELLED_AND_NOTIFIED] + + def running(self): + """Return True if the future is currently executing.""" + with self._condition: + return self._state == RUNNING + + def done(self): + """Return True of the future was cancelled or finished executing.""" + with self._condition: + return self._state in [CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED] + + def __get_result(self): + if self._exception: + reraise(self._exception, self._traceback) + else: + return self._result + + def add_done_callback(self, fn): + """Attaches a callable that will be called when the future finishes. + + Args: + fn: A callable that will be called with this future as its only + argument when the future completes or is cancelled. The callable + will always be called by a thread in the same process in which + it was added. If the future has already completed or been + cancelled then the callable will be called immediately. These + callables are called in the order that they were added. + """ + with self._condition: + if self._state not in [CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED]: + self._done_callbacks.append(fn) + return + fn(self) + + def result(self, timeout=None): + """Return the result of the call that the future represents. + + Args: + timeout: The number of seconds to wait for the result if the future + isn't done. If None, then there is no limit on the wait time. + + Returns: + The result of the call that the future represents. + + Raises: + CancelledError: If the future was cancelled. + TimeoutError: If the future didn't finish executing before the given + timeout. + Exception: If the call raised then that exception will be raised. + """ + with self._condition: + if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]: + raise CancelledError() + elif self._state == FINISHED: + return self.__get_result() + + self._condition.wait(timeout) + + if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]: + raise CancelledError() + elif self._state == FINISHED: + return self.__get_result() + else: + raise TimeoutError() + + def exception_info(self, timeout=None): + """Return a tuple of (exception, traceback) raised by the call that the + future represents. + + Args: + timeout: The number of seconds to wait for the exception if the + future isn't done. If None, then there is no limit on the wait + time. + + Returns: + The exception raised by the call that the future represents or None + if the call completed without raising. + + Raises: + CancelledError: If the future was cancelled. + TimeoutError: If the future didn't finish executing before the given + timeout. + """ + with self._condition: + if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]: + raise CancelledError() + elif self._state == FINISHED: + return self._exception, self._traceback + + self._condition.wait(timeout) + + if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]: + raise CancelledError() + elif self._state == FINISHED: + return self._exception, self._traceback + else: + raise TimeoutError() + + def exception(self, timeout=None): + """Return the exception raised by the call that the future represents. + + Args: + timeout: The number of seconds to wait for the exception if the + future isn't done. If None, then there is no limit on the wait + time. + + Returns: + The exception raised by the call that the future represents or None + if the call completed without raising. + + Raises: + CancelledError: If the future was cancelled. + TimeoutError: If the future didn't finish executing before the given + timeout. + """ + return self.exception_info(timeout)[0] + + # The following methods should only be used by Executors and in tests. + def set_running_or_notify_cancel(self): + """Mark the future as running or process any cancel notifications. + + Should only be used by Executor implementations and unit tests. + + If the future has been cancelled (cancel() was called and returned + True) then any threads waiting on the future completing (though calls + to as_completed() or wait()) are notified and False is returned. + + If the future was not cancelled then it is put in the running state + (future calls to running() will return True) and True is returned. + + This method should be called by Executor implementations before + executing the work associated with this future. If this method returns + False then the work should not be executed. + + Returns: + False if the Future was cancelled, True otherwise. + + Raises: + RuntimeError: if this method was already called or if set_result() + or set_exception() was called. + """ + with self._condition: + if self._state == CANCELLED: + self._state = CANCELLED_AND_NOTIFIED + for waiter in self._waiters: + waiter.add_cancelled(self) + # self._condition.notify_all() is not necessary because + # self.cancel() triggers a notification. + return False + elif self._state == PENDING: + self._state = RUNNING + return True + else: + LOGGER.critical('Future %s in unexpected state: %s', + id(self.future), + self.future._state) + raise RuntimeError('Future in unexpected state') + + def set_result(self, result): + """Sets the return value of work associated with the future. + + Should only be used by Executor implementations and unit tests. + """ + with self._condition: + self._result = result + self._state = FINISHED + for waiter in self._waiters: + waiter.add_result(self) + self._condition.notify_all() + self._invoke_callbacks() + + def set_exception_info(self, exception, traceback): + """Sets the result of the future as being the given exception + and traceback. + + Should only be used by Executor implementations and unit tests. + """ + with self._condition: + self._exception = exception + self._traceback = traceback + self._state = FINISHED + for waiter in self._waiters: + waiter.add_exception(self) + self._condition.notify_all() + self._invoke_callbacks() + + def set_exception(self, exception): + """Sets the result of the future as being the given exception. + + Should only be used by Executor implementations and unit tests. + """ + self.set_exception_info(exception, None) + +class Executor(object): + """This is an abstract base class for concrete asynchronous executors.""" + + def submit(self, fn, *args, **kwargs): + """Submits a callable to be executed with the given arguments. + + Schedules the callable to be executed as fn(*args, **kwargs) and returns + a Future instance representing the execution of the callable. + + Returns: + A Future representing the given call. + """ + raise NotImplementedError() + + def map(self, fn, *iterables, **kwargs): + """Returns a iterator equivalent to map(fn, iter). + + Args: + fn: A callable that will take as many arguments as there are + passed iterables. + timeout: The maximum number of seconds to wait. If None, then there + is no limit on the wait time. + + Returns: + An iterator equivalent to: map(func, *iterables) but the calls may + be evaluated out-of-order. + + Raises: + TimeoutError: If the entire result iterator could not be generated + before the given timeout. + Exception: If fn(*args) raises for any values. + """ + timeout = kwargs.get('timeout') + if timeout is not None: + end_time = timeout + time.time() + + fs = [self.submit(fn, *args) for args in zip(*iterables)] + + try: + for future in fs: + if timeout is None: + yield future.result() + else: + yield future.result(end_time - time.time()) + finally: + for future in fs: + future.cancel() + + def shutdown(self, wait=True): + """Clean-up the resources associated with the Executor. + + It is safe to call this method several times. Otherwise, no other + methods can be called after this one. + + Args: + wait: If True then shutdown will not return until all running + futures have finished executing and the resources used by the + executor have been reclaimed. + """ + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.shutdown(wait=True) + return False diff --git a/lib/concurrent/futures/_compat.py b/lib/concurrent/futures/_compat.py new file mode 100644 index 00000000..e77cf0e5 --- /dev/null +++ b/lib/concurrent/futures/_compat.py @@ -0,0 +1,111 @@ +from keyword import iskeyword as _iskeyword +from operator import itemgetter as _itemgetter +import sys as _sys + + +def namedtuple(typename, field_names): + """Returns a new subclass of tuple with named fields. + + >>> Point = namedtuple('Point', 'x y') + >>> Point.__doc__ # docstring for the new class + 'Point(x, y)' + >>> p = Point(11, y=22) # instantiate with positional args or keywords + >>> p[0] + p[1] # indexable like a plain tuple + 33 + >>> x, y = p # unpack like a regular tuple + >>> x, y + (11, 22) + >>> p.x + p.y # fields also accessable by name + 33 + >>> d = p._asdict() # convert to a dictionary + >>> d['x'] + 11 + >>> Point(**d) # convert from a dictionary + Point(x=11, y=22) + >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields + Point(x=100, y=22) + + """ + + # Parse and validate the field names. Validation serves two purposes, + # generating informative error messages and preventing template injection attacks. + if isinstance(field_names, basestring): + field_names = field_names.replace(',', ' ').split() # names separated by whitespace and/or commas + field_names = tuple(map(str, field_names)) + for name in (typename,) + field_names: + if not all(c.isalnum() or c=='_' for c in name): + raise ValueError('Type names and field names can only contain alphanumeric characters and underscores: %r' % name) + if _iskeyword(name): + raise ValueError('Type names and field names cannot be a keyword: %r' % name) + if name[0].isdigit(): + raise ValueError('Type names and field names cannot start with a number: %r' % name) + seen_names = set() + for name in field_names: + if name.startswith('_'): + raise ValueError('Field names cannot start with an underscore: %r' % name) + if name in seen_names: + raise ValueError('Encountered duplicate field name: %r' % name) + seen_names.add(name) + + # Create and fill-in the class template + numfields = len(field_names) + argtxt = repr(field_names).replace("'", "")[1:-1] # tuple repr without parens or quotes + reprtxt = ', '.join('%s=%%r' % name for name in field_names) + dicttxt = ', '.join('%r: t[%d]' % (name, pos) for pos, name in enumerate(field_names)) + template = '''class %(typename)s(tuple): + '%(typename)s(%(argtxt)s)' \n + __slots__ = () \n + _fields = %(field_names)r \n + def __new__(_cls, %(argtxt)s): + return _tuple.__new__(_cls, (%(argtxt)s)) \n + @classmethod + def _make(cls, iterable, new=tuple.__new__, len=len): + 'Make a new %(typename)s object from a sequence or iterable' + result = new(cls, iterable) + if len(result) != %(numfields)d: + raise TypeError('Expected %(numfields)d arguments, got %%d' %% len(result)) + return result \n + def __repr__(self): + return '%(typename)s(%(reprtxt)s)' %% self \n + def _asdict(t): + 'Return a new dict which maps field names to their values' + return {%(dicttxt)s} \n + def _replace(_self, **kwds): + 'Return a new %(typename)s object replacing specified fields with new values' + result = _self._make(map(kwds.pop, %(field_names)r, _self)) + if kwds: + raise ValueError('Got unexpected field names: %%r' %% kwds.keys()) + return result \n + def __getnewargs__(self): + return tuple(self) \n\n''' % locals() + for i, name in enumerate(field_names): + template += ' %s = _property(_itemgetter(%d))\n' % (name, i) + + # Execute the template string in a temporary namespace and + # support tracing utilities by setting a value for frame.f_globals['__name__'] + namespace = dict(_itemgetter=_itemgetter, __name__='namedtuple_%s' % typename, + _property=property, _tuple=tuple) + try: + exec(template, namespace) + except SyntaxError: + e = _sys.exc_info()[1] + raise SyntaxError(e.message + ':\n' + template) + result = namespace[typename] + + # For pickling to work, the __module__ variable needs to be set to the frame + # where the named tuple is created. Bypass this step in enviroments where + # sys._getframe is not defined (Jython for example). + if hasattr(_sys, '_getframe'): + result.__module__ = _sys._getframe(1).f_globals.get('__name__', '__main__') + + return result + + +if _sys.version_info[0] < 3: + def reraise(exc, traceback): + locals_ = {'exc_type': type(exc), 'exc_value': exc, 'traceback': traceback} + exec('raise exc_type, exc_value, traceback', {}, locals_) +else: + def reraise(exc, traceback): + # Tracebacks are embedded in exceptions in Python 3 + raise exc diff --git a/lib/concurrent/futures/process.py b/lib/concurrent/futures/process.py new file mode 100644 index 00000000..98684f8e --- /dev/null +++ b/lib/concurrent/futures/process.py @@ -0,0 +1,363 @@ +# Copyright 2009 Brian Quinlan. All Rights Reserved. +# Licensed to PSF under a Contributor Agreement. + +"""Implements ProcessPoolExecutor. + +The follow diagram and text describe the data-flow through the system: + +|======================= In-process =====================|== Out-of-process ==| + ++----------+ +----------+ +--------+ +-----------+ +---------+ +| | => | Work Ids | => | | => | Call Q | => | | +| | +----------+ | | +-----------+ | | +| | | ... | | | | ... | | | +| | | 6 | | | | 5, call() | | | +| | | 7 | | | | ... | | | +| Process | | ... | | Local | +-----------+ | Process | +| Pool | +----------+ | Worker | | #1..n | +| Executor | | Thread | | | +| | +----------- + | | +-----------+ | | +| | <=> | Work Items | <=> | | <= | Result Q | <= | | +| | +------------+ | | +-----------+ | | +| | | 6: call() | | | | ... | | | +| | | future | | | | 4, result | | | +| | | ... | | | | 3, except | | | ++----------+ +------------+ +--------+ +-----------+ +---------+ + +Executor.submit() called: +- creates a uniquely numbered _WorkItem and adds it to the "Work Items" dict +- adds the id of the _WorkItem to the "Work Ids" queue + +Local worker thread: +- reads work ids from the "Work Ids" queue and looks up the corresponding + WorkItem from the "Work Items" dict: if the work item has been cancelled then + it is simply removed from the dict, otherwise it is repackaged as a + _CallItem and put in the "Call Q". New _CallItems are put in the "Call Q" + until "Call Q" is full. NOTE: the size of the "Call Q" is kept small because + calls placed in the "Call Q" can no longer be cancelled with Future.cancel(). +- reads _ResultItems from "Result Q", updates the future stored in the + "Work Items" dict and deletes the dict entry + +Process #1..n: +- reads _CallItems from "Call Q", executes the calls, and puts the resulting + _ResultItems in "Request Q" +""" + +from __future__ import with_statement +import atexit +import multiprocessing +import threading +import weakref +import sys + +from concurrent.futures import _base + +try: + import queue +except ImportError: + import Queue as queue + +__author__ = 'Brian Quinlan (brian@sweetapp.com)' + +# Workers are created as daemon threads and processes. This is done to allow the +# interpreter to exit when there are still idle processes in a +# ProcessPoolExecutor's process pool (i.e. shutdown() was not called). However, +# allowing workers to die with the interpreter has two undesirable properties: +# - The workers would still be running during interpretor shutdown, +# meaning that they would fail in unpredictable ways. +# - The workers could be killed while evaluating a work item, which could +# be bad if the callable being evaluated has external side-effects e.g. +# writing to a file. +# +# To work around this problem, an exit handler is installed which tells the +# workers to exit when their work queues are empty and then waits until the +# threads/processes finish. + +_threads_queues = weakref.WeakKeyDictionary() +_shutdown = False + +def _python_exit(): + global _shutdown + _shutdown = True + items = list(_threads_queues.items()) + for t, q in items: + q.put(None) + for t, q in items: + t.join() + +# Controls how many more calls than processes will be queued in the call queue. +# A smaller number will mean that processes spend more time idle waiting for +# work while a larger number will make Future.cancel() succeed less frequently +# (Futures in the call queue cannot be cancelled). +EXTRA_QUEUED_CALLS = 1 + +class _WorkItem(object): + def __init__(self, future, fn, args, kwargs): + self.future = future + self.fn = fn + self.args = args + self.kwargs = kwargs + +class _ResultItem(object): + def __init__(self, work_id, exception=None, result=None): + self.work_id = work_id + self.exception = exception + self.result = result + +class _CallItem(object): + def __init__(self, work_id, fn, args, kwargs): + self.work_id = work_id + self.fn = fn + self.args = args + self.kwargs = kwargs + +def _process_worker(call_queue, result_queue): + """Evaluates calls from call_queue and places the results in result_queue. + + This worker is run in a separate process. + + Args: + call_queue: A multiprocessing.Queue of _CallItems that will be read and + evaluated by the worker. + result_queue: A multiprocessing.Queue of _ResultItems that will written + to by the worker. + shutdown: A multiprocessing.Event that will be set as a signal to the + worker that it should exit when call_queue is empty. + """ + while True: + call_item = call_queue.get(block=True) + if call_item is None: + # Wake up queue management thread + result_queue.put(None) + return + try: + r = call_item.fn(*call_item.args, **call_item.kwargs) + except BaseException: + e = sys.exc_info()[1] + result_queue.put(_ResultItem(call_item.work_id, + exception=e)) + else: + result_queue.put(_ResultItem(call_item.work_id, + result=r)) + +def _add_call_item_to_queue(pending_work_items, + work_ids, + call_queue): + """Fills call_queue with _WorkItems from pending_work_items. + + This function never blocks. + + Args: + pending_work_items: A dict mapping work ids to _WorkItems e.g. + {5: <_WorkItem...>, 6: <_WorkItem...>, ...} + work_ids: A queue.Queue of work ids e.g. Queue([5, 6, ...]). Work ids + are consumed and the corresponding _WorkItems from + pending_work_items are transformed into _CallItems and put in + call_queue. + call_queue: A multiprocessing.Queue that will be filled with _CallItems + derived from _WorkItems. + """ + while True: + if call_queue.full(): + return + try: + work_id = work_ids.get(block=False) + except queue.Empty: + return + else: + work_item = pending_work_items[work_id] + + if work_item.future.set_running_or_notify_cancel(): + call_queue.put(_CallItem(work_id, + work_item.fn, + work_item.args, + work_item.kwargs), + block=True) + else: + del pending_work_items[work_id] + continue + +def _queue_management_worker(executor_reference, + processes, + pending_work_items, + work_ids_queue, + call_queue, + result_queue): + """Manages the communication between this process and the worker processes. + + This function is run in a local thread. + + Args: + executor_reference: A weakref.ref to the ProcessPoolExecutor that owns + this thread. Used to determine if the ProcessPoolExecutor has been + garbage collected and that this function can exit. + process: A list of the multiprocessing.Process instances used as + workers. + pending_work_items: A dict mapping work ids to _WorkItems e.g. + {5: <_WorkItem...>, 6: <_WorkItem...>, ...} + work_ids_queue: A queue.Queue of work ids e.g. Queue([5, 6, ...]). + call_queue: A multiprocessing.Queue that will be filled with _CallItems + derived from _WorkItems for processing by the process workers. + result_queue: A multiprocessing.Queue of _ResultItems generated by the + process workers. + """ + nb_shutdown_processes = [0] + def shutdown_one_process(): + """Tell a worker to terminate, which will in turn wake us again""" + call_queue.put(None) + nb_shutdown_processes[0] += 1 + while True: + _add_call_item_to_queue(pending_work_items, + work_ids_queue, + call_queue) + + result_item = result_queue.get(block=True) + if result_item is not None: + work_item = pending_work_items[result_item.work_id] + del pending_work_items[result_item.work_id] + + if result_item.exception: + work_item.future.set_exception(result_item.exception) + else: + work_item.future.set_result(result_item.result) + # Check whether we should start shutting down. + executor = executor_reference() + # No more work items can be added if: + # - The interpreter is shutting down OR + # - The executor that owns this worker has been collected OR + # - The executor that owns this worker has been shutdown. + if _shutdown or executor is None or executor._shutdown_thread: + # Since no new work items can be added, it is safe to shutdown + # this thread if there are no pending work items. + if not pending_work_items: + while nb_shutdown_processes[0] < len(processes): + shutdown_one_process() + # If .join() is not called on the created processes then + # some multiprocessing.Queue methods may deadlock on Mac OS + # X. + for p in processes: + p.join() + call_queue.close() + return + del executor + +_system_limits_checked = False +_system_limited = None +def _check_system_limits(): + global _system_limits_checked, _system_limited + if _system_limits_checked: + if _system_limited: + raise NotImplementedError(_system_limited) + _system_limits_checked = True + try: + import os + nsems_max = os.sysconf("SC_SEM_NSEMS_MAX") + except (AttributeError, ValueError): + # sysconf not available or setting not available + return + if nsems_max == -1: + # indetermine limit, assume that limit is determined + # by available memory only + return + if nsems_max >= 256: + # minimum number of semaphores available + # according to POSIX + return + _system_limited = "system provides too few semaphores (%d available, 256 necessary)" % nsems_max + raise NotImplementedError(_system_limited) + +class ProcessPoolExecutor(_base.Executor): + def __init__(self, max_workers=None): + """Initializes a new ProcessPoolExecutor instance. + + Args: + max_workers: The maximum number of processes that can be used to + execute the given calls. If None or not given then as many + worker processes will be created as the machine has processors. + """ + _check_system_limits() + + if max_workers is None: + self._max_workers = multiprocessing.cpu_count() + else: + self._max_workers = max_workers + + # Make the call queue slightly larger than the number of processes to + # prevent the worker processes from idling. But don't make it too big + # because futures in the call queue cannot be cancelled. + self._call_queue = multiprocessing.Queue(self._max_workers + + EXTRA_QUEUED_CALLS) + self._result_queue = multiprocessing.Queue() + self._work_ids = queue.Queue() + self._queue_management_thread = None + self._processes = set() + + # Shutdown is a two-step process. + self._shutdown_thread = False + self._shutdown_lock = threading.Lock() + self._queue_count = 0 + self._pending_work_items = {} + + def _start_queue_management_thread(self): + # When the executor gets lost, the weakref callback will wake up + # the queue management thread. + def weakref_cb(_, q=self._result_queue): + q.put(None) + if self._queue_management_thread is None: + self._queue_management_thread = threading.Thread( + target=_queue_management_worker, + args=(weakref.ref(self, weakref_cb), + self._processes, + self._pending_work_items, + self._work_ids, + self._call_queue, + self._result_queue)) + self._queue_management_thread.daemon = True + self._queue_management_thread.start() + _threads_queues[self._queue_management_thread] = self._result_queue + + def _adjust_process_count(self): + for _ in range(len(self._processes), self._max_workers): + p = multiprocessing.Process( + target=_process_worker, + args=(self._call_queue, + self._result_queue)) + p.start() + self._processes.add(p) + + def submit(self, fn, *args, **kwargs): + with self._shutdown_lock: + if self._shutdown_thread: + raise RuntimeError('cannot schedule new futures after shutdown') + + f = _base.Future() + w = _WorkItem(f, fn, args, kwargs) + + self._pending_work_items[self._queue_count] = w + self._work_ids.put(self._queue_count) + self._queue_count += 1 + # Wake up queue management thread + self._result_queue.put(None) + + self._start_queue_management_thread() + self._adjust_process_count() + return f + submit.__doc__ = _base.Executor.submit.__doc__ + + def shutdown(self, wait=True): + with self._shutdown_lock: + self._shutdown_thread = True + if self._queue_management_thread: + # Wake up queue management thread + self._result_queue.put(None) + if wait: + self._queue_management_thread.join() + # To reduce the risk of openning too many files, remove references to + # objects that use file descriptors. + self._queue_management_thread = None + self._call_queue = None + self._result_queue = None + self._processes = None + shutdown.__doc__ = _base.Executor.shutdown.__doc__ + +atexit.register(_python_exit) diff --git a/lib/concurrent/futures/thread.py b/lib/concurrent/futures/thread.py new file mode 100644 index 00000000..930d1673 --- /dev/null +++ b/lib/concurrent/futures/thread.py @@ -0,0 +1,138 @@ +# Copyright 2009 Brian Quinlan. All Rights Reserved. +# Licensed to PSF under a Contributor Agreement. + +"""Implements ThreadPoolExecutor.""" + +from __future__ import with_statement +import atexit +import threading +import weakref +import sys + +from concurrent.futures import _base + +try: + import queue +except ImportError: + import Queue as queue + +__author__ = 'Brian Quinlan (brian@sweetapp.com)' + +# Workers are created as daemon threads. This is done to allow the interpreter +# to exit when there are still idle threads in a ThreadPoolExecutor's thread +# pool (i.e. shutdown() was not called). However, allowing workers to die with +# the interpreter has two undesirable properties: +# - The workers would still be running during interpretor shutdown, +# meaning that they would fail in unpredictable ways. +# - The workers could be killed while evaluating a work item, which could +# be bad if the callable being evaluated has external side-effects e.g. +# writing to a file. +# +# To work around this problem, an exit handler is installed which tells the +# workers to exit when their work queues are empty and then waits until the +# threads finish. + +_threads_queues = weakref.WeakKeyDictionary() +_shutdown = False + +def _python_exit(): + global _shutdown + _shutdown = True + items = list(_threads_queues.items()) + for t, q in items: + q.put(None) + for t, q in items: + t.join() + +atexit.register(_python_exit) + +class _WorkItem(object): + def __init__(self, future, fn, args, kwargs): + self.future = future + self.fn = fn + self.args = args + self.kwargs = kwargs + + def run(self): + if not self.future.set_running_or_notify_cancel(): + return + + try: + result = self.fn(*self.args, **self.kwargs) + except BaseException: + e, tb = sys.exc_info()[1:] + self.future.set_exception_info(e, tb) + else: + self.future.set_result(result) + +def _worker(executor_reference, work_queue): + try: + while True: + work_item = work_queue.get(block=True) + if work_item is not None: + work_item.run() + continue + executor = executor_reference() + # Exit if: + # - The interpreter is shutting down OR + # - The executor that owns the worker has been collected OR + # - The executor that owns the worker has been shutdown. + if _shutdown or executor is None or executor._shutdown: + # Notice other workers + work_queue.put(None) + return + del executor + except BaseException: + _base.LOGGER.critical('Exception in worker', exc_info=True) + +class ThreadPoolExecutor(_base.Executor): + def __init__(self, max_workers): + """Initializes a new ThreadPoolExecutor instance. + + Args: + max_workers: The maximum number of threads that can be used to + execute the given calls. + """ + self._max_workers = max_workers + self._work_queue = queue.Queue() + self._threads = set() + self._shutdown = False + self._shutdown_lock = threading.Lock() + + def submit(self, fn, *args, **kwargs): + with self._shutdown_lock: + if self._shutdown: + raise RuntimeError('cannot schedule new futures after shutdown') + + f = _base.Future() + w = _WorkItem(f, fn, args, kwargs) + + self._work_queue.put(w) + self._adjust_thread_count() + return f + submit.__doc__ = _base.Executor.submit.__doc__ + + def _adjust_thread_count(self): + # When the executor gets lost, the weakref callback will wake up + # the worker threads. + def weakref_cb(_, q=self._work_queue): + q.put(None) + # TODO(bquinlan): Should avoid creating new threads if there are more + # idle threads than items in the work queue. + if len(self._threads) < self._max_workers: + t = threading.Thread(target=_worker, + args=(weakref.ref(self, weakref_cb), + self._work_queue)) + t.daemon = True + t.start() + self._threads.add(t) + _threads_queues[t] = self._work_queue + + def shutdown(self, wait=True): + with self._shutdown_lock: + self._shutdown = True + self._work_queue.put(None) + if wait: + for t in self._threads: + t.join() + shutdown.__doc__ = _base.Executor.shutdown.__doc__ diff --git a/lib/configobj.py b/lib/configobj.py new file mode 100644 index 00000000..eae556a7 --- /dev/null +++ b/lib/configobj.py @@ -0,0 +1,2468 @@ +# configobj.py +# A config file reader/writer that supports nested sections in config files. +# Copyright (C) 2005-2010 Michael Foord, Nicola Larosa +# E-mail: fuzzyman AT voidspace DOT org DOT uk +# nico AT tekNico DOT net + +# ConfigObj 4 +# http://www.voidspace.org.uk/python/configobj.html + +# Released subject to the BSD License +# Please see http://www.voidspace.org.uk/python/license.shtml + +# Scripts maintained at http://www.voidspace.org.uk/python/index.shtml +# For information about bugfixes, updates and support, please join the +# ConfigObj mailing list: +# http://lists.sourceforge.net/lists/listinfo/configobj-develop +# Comments, suggestions and bug reports welcome. + +from __future__ import generators + +import os +import re +import sys + +from codecs import BOM_UTF8, BOM_UTF16, BOM_UTF16_BE, BOM_UTF16_LE + + +# imported lazily to avoid startup performance hit if it isn't used +compiler = None + +# A dictionary mapping BOM to +# the encoding to decode with, and what to set the +# encoding attribute to. +BOMS = { + BOM_UTF8: ('utf_8', None), + BOM_UTF16_BE: ('utf16_be', 'utf_16'), + BOM_UTF16_LE: ('utf16_le', 'utf_16'), + BOM_UTF16: ('utf_16', 'utf_16'), + } +# All legal variants of the BOM codecs. +# TODO: the list of aliases is not meant to be exhaustive, is there a +# better way ? +BOM_LIST = { + 'utf_16': 'utf_16', + 'u16': 'utf_16', + 'utf16': 'utf_16', + 'utf-16': 'utf_16', + 'utf16_be': 'utf16_be', + 'utf_16_be': 'utf16_be', + 'utf-16be': 'utf16_be', + 'utf16_le': 'utf16_le', + 'utf_16_le': 'utf16_le', + 'utf-16le': 'utf16_le', + 'utf_8': 'utf_8', + 'u8': 'utf_8', + 'utf': 'utf_8', + 'utf8': 'utf_8', + 'utf-8': 'utf_8', + } + +# Map of encodings to the BOM to write. +BOM_SET = { + 'utf_8': BOM_UTF8, + 'utf_16': BOM_UTF16, + 'utf16_be': BOM_UTF16_BE, + 'utf16_le': BOM_UTF16_LE, + None: BOM_UTF8 + } + + +def match_utf8(encoding): + return BOM_LIST.get(encoding.lower()) == 'utf_8' + + +# Quote strings used for writing values +squot = "'%s'" +dquot = '"%s"' +noquot = "%s" +wspace_plus = ' \r\n\v\t\'"' +tsquot = '"""%s"""' +tdquot = "'''%s'''" + +# Sentinel for use in getattr calls to replace hasattr +MISSING = object() + +__version__ = '4.7.2' + +try: + any +except NameError: + def any(iterable): + for entry in iterable: + if entry: + return True + return False + + +__all__ = ( + '__version__', + 'DEFAULT_INDENT_TYPE', + 'DEFAULT_INTERPOLATION', + 'ConfigObjError', + 'NestingError', + 'ParseError', + 'DuplicateError', + 'ConfigspecError', + 'ConfigObj', + 'SimpleVal', + 'InterpolationError', + 'InterpolationLoopError', + 'MissingInterpolationOption', + 'RepeatSectionError', + 'ReloadError', + 'UnreprError', + 'UnknownType', + 'flatten_errors', + 'get_extra_values' +) + +DEFAULT_INTERPOLATION = 'configparser' +DEFAULT_INDENT_TYPE = ' ' +MAX_INTERPOL_DEPTH = 10 + +OPTION_DEFAULTS = { + 'interpolation': True, + 'raise_errors': False, + 'list_values': True, + 'create_empty': False, + 'file_error': False, + 'configspec': None, + 'stringify': True, + # option may be set to one of ('', ' ', '\t') + 'indent_type': None, + 'encoding': None, + 'default_encoding': None, + 'unrepr': False, + 'write_empty_values': False, +} + + + +def getObj(s): + global compiler + if compiler is None: + import compiler + s = "a=" + s + p = compiler.parse(s) + return p.getChildren()[1].getChildren()[0].getChildren()[1] + + +class UnknownType(Exception): + pass + + +class Builder(object): + + def build(self, o): + m = getattr(self, 'build_' + o.__class__.__name__, None) + if m is None: + raise UnknownType(o.__class__.__name__) + return m(o) + + def build_List(self, o): + return map(self.build, o.getChildren()) + + def build_Const(self, o): + return o.value + + def build_Dict(self, o): + d = {} + i = iter(map(self.build, o.getChildren())) + for el in i: + d[el] = i.next() + return d + + def build_Tuple(self, o): + return tuple(self.build_List(o)) + + def build_Name(self, o): + if o.name == 'None': + return None + if o.name == 'True': + return True + if o.name == 'False': + return False + + # An undefined Name + raise UnknownType('Undefined Name') + + def build_Add(self, o): + real, imag = map(self.build_Const, o.getChildren()) + try: + real = float(real) + except TypeError: + raise UnknownType('Add') + if not isinstance(imag, complex) or imag.real != 0.0: + raise UnknownType('Add') + return real+imag + + def build_Getattr(self, o): + parent = self.build(o.expr) + return getattr(parent, o.attrname) + + def build_UnarySub(self, o): + return -self.build_Const(o.getChildren()[0]) + + def build_UnaryAdd(self, o): + return self.build_Const(o.getChildren()[0]) + + +_builder = Builder() + + +def unrepr(s): + if not s: + return s + return _builder.build(getObj(s)) + + + +class ConfigObjError(SyntaxError): + """ + This is the base class for all errors that ConfigObj raises. + It is a subclass of SyntaxError. + """ + def __init__(self, message='', line_number=None, line=''): + self.line = line + self.line_number = line_number + SyntaxError.__init__(self, message) + + +class NestingError(ConfigObjError): + """ + This error indicates a level of nesting that doesn't match. + """ + + +class ParseError(ConfigObjError): + """ + This error indicates that a line is badly written. + It is neither a valid ``key = value`` line, + nor a valid section marker line. + """ + + +class ReloadError(IOError): + """ + A 'reload' operation failed. + This exception is a subclass of ``IOError``. + """ + def __init__(self): + IOError.__init__(self, 'reload failed, filename is not set.') + + +class DuplicateError(ConfigObjError): + """ + The keyword or section specified already exists. + """ + + +class ConfigspecError(ConfigObjError): + """ + An error occured whilst parsing a configspec. + """ + + +class InterpolationError(ConfigObjError): + """Base class for the two interpolation errors.""" + + +class InterpolationLoopError(InterpolationError): + """Maximum interpolation depth exceeded in string interpolation.""" + + def __init__(self, option): + InterpolationError.__init__( + self, + 'interpolation loop detected in value "%s".' % option) + + +class RepeatSectionError(ConfigObjError): + """ + This error indicates additional sections in a section with a + ``__many__`` (repeated) section. + """ + + +class MissingInterpolationOption(InterpolationError): + """A value specified for interpolation was missing.""" + def __init__(self, option): + msg = 'missing option "%s" in interpolation.' % option + InterpolationError.__init__(self, msg) + + +class UnreprError(ConfigObjError): + """An error parsing in unrepr mode.""" + + + +class InterpolationEngine(object): + """ + A helper class to help perform string interpolation. + + This class is an abstract base class; its descendants perform + the actual work. + """ + + # compiled regexp to use in self.interpolate() + _KEYCRE = re.compile(r"%\(([^)]*)\)s") + _cookie = '%' + + def __init__(self, section): + # the Section instance that "owns" this engine + self.section = section + + + def interpolate(self, key, value): + # short-cut + if not self._cookie in value: + return value + + def recursive_interpolate(key, value, section, backtrail): + """The function that does the actual work. + + ``value``: the string we're trying to interpolate. + ``section``: the section in which that string was found + ``backtrail``: a dict to keep track of where we've been, + to detect and prevent infinite recursion loops + + This is similar to a depth-first-search algorithm. + """ + # Have we been here already? + if (key, section.name) in backtrail: + # Yes - infinite loop detected + raise InterpolationLoopError(key) + # Place a marker on our backtrail so we won't come back here again + backtrail[(key, section.name)] = 1 + + # Now start the actual work + match = self._KEYCRE.search(value) + while match: + # The actual parsing of the match is implementation-dependent, + # so delegate to our helper function + k, v, s = self._parse_match(match) + if k is None: + # That's the signal that no further interpolation is needed + replacement = v + else: + # Further interpolation may be needed to obtain final value + replacement = recursive_interpolate(k, v, s, backtrail) + # Replace the matched string with its final value + start, end = match.span() + value = ''.join((value[:start], replacement, value[end:])) + new_search_start = start + len(replacement) + # Pick up the next interpolation key, if any, for next time + # through the while loop + match = self._KEYCRE.search(value, new_search_start) + + # Now safe to come back here again; remove marker from backtrail + del backtrail[(key, section.name)] + + return value + + # Back in interpolate(), all we have to do is kick off the recursive + # function with appropriate starting values + value = recursive_interpolate(key, value, self.section, {}) + return value + + + def _fetch(self, key): + """Helper function to fetch values from owning section. + + Returns a 2-tuple: the value, and the section where it was found. + """ + # switch off interpolation before we try and fetch anything ! + save_interp = self.section.main.interpolation + self.section.main.interpolation = False + + # Start at section that "owns" this InterpolationEngine + current_section = self.section + while True: + # try the current section first + val = current_section.get(key) + if val is not None and not isinstance(val, Section): + break + # try "DEFAULT" next + val = current_section.get('DEFAULT', {}).get(key) + if val is not None and not isinstance(val, Section): + break + # move up to parent and try again + # top-level's parent is itself + if current_section.parent is current_section: + # reached top level, time to give up + break + current_section = current_section.parent + + # restore interpolation to previous value before returning + self.section.main.interpolation = save_interp + if val is None: + raise MissingInterpolationOption(key) + return val, current_section + + + def _parse_match(self, match): + """Implementation-dependent helper function. + + Will be passed a match object corresponding to the interpolation + key we just found (e.g., "%(foo)s" or "$foo"). Should look up that + key in the appropriate config file section (using the ``_fetch()`` + helper function) and return a 3-tuple: (key, value, section) + + ``key`` is the name of the key we're looking for + ``value`` is the value found for that key + ``section`` is a reference to the section where it was found + + ``key`` and ``section`` should be None if no further + interpolation should be performed on the resulting value + (e.g., if we interpolated "$$" and returned "$"). + """ + raise NotImplementedError() + + + +class ConfigParserInterpolation(InterpolationEngine): + """Behaves like ConfigParser.""" + _cookie = '%' + _KEYCRE = re.compile(r"%\(([^)]*)\)s") + + def _parse_match(self, match): + key = match.group(1) + value, section = self._fetch(key) + return key, value, section + + + +class TemplateInterpolation(InterpolationEngine): + """Behaves like string.Template.""" + _cookie = '$' + _delimiter = '$' + _KEYCRE = re.compile(r""" + \$(?: + (?P\$) | # Two $ signs + (?P[_a-z][_a-z0-9]*) | # $name format + {(?P[^}]*)} # ${name} format + ) + """, re.IGNORECASE | re.VERBOSE) + + def _parse_match(self, match): + # Valid name (in or out of braces): fetch value from section + key = match.group('named') or match.group('braced') + if key is not None: + value, section = self._fetch(key) + return key, value, section + # Escaped delimiter (e.g., $$): return single delimiter + if match.group('escaped') is not None: + # Return None for key and section to indicate it's time to stop + return None, self._delimiter, None + # Anything else: ignore completely, just return it unchanged + return None, match.group(), None + + +interpolation_engines = { + 'configparser': ConfigParserInterpolation, + 'template': TemplateInterpolation, +} + + +def __newobj__(cls, *args): + # Hack for pickle + return cls.__new__(cls, *args) + +class Section(dict): + """ + A dictionary-like object that represents a section in a config file. + + It does string interpolation if the 'interpolation' attribute + of the 'main' object is set to True. + + Interpolation is tried first from this object, then from the 'DEFAULT' + section of this object, next from the parent and its 'DEFAULT' section, + and so on until the main object is reached. + + A Section will behave like an ordered dictionary - following the + order of the ``scalars`` and ``sections`` attributes. + You can use this to change the order of members. + + Iteration follows the order: scalars, then sections. + """ + + + def __setstate__(self, state): + dict.update(self, state[0]) + self.__dict__.update(state[1]) + + def __reduce__(self): + state = (dict(self), self.__dict__) + return (__newobj__, (self.__class__,), state) + + + def __init__(self, parent, depth, main, indict=None, name=None): + """ + * parent is the section above + * depth is the depth level of this section + * main is the main ConfigObj + * indict is a dictionary to initialise the section with + """ + if indict is None: + indict = {} + dict.__init__(self) + # used for nesting level *and* interpolation + self.parent = parent + # used for the interpolation attribute + self.main = main + # level of nesting depth of this Section + self.depth = depth + # purely for information + self.name = name + # + self._initialise() + # we do this explicitly so that __setitem__ is used properly + # (rather than just passing to ``dict.__init__``) + for entry, value in indict.iteritems(): + self[entry] = value + + + def _initialise(self): + # the sequence of scalar values in this Section + self.scalars = [] + # the sequence of sections in this Section + self.sections = [] + # for comments :-) + self.comments = {} + self.inline_comments = {} + # the configspec + self.configspec = None + # for defaults + self.defaults = [] + self.default_values = {} + self.extra_values = [] + self._created = False + + + def _interpolate(self, key, value): + try: + # do we already have an interpolation engine? + engine = self._interpolation_engine + except AttributeError: + # not yet: first time running _interpolate(), so pick the engine + name = self.main.interpolation + if name == True: # note that "if name:" would be incorrect here + # backwards-compatibility: interpolation=True means use default + name = DEFAULT_INTERPOLATION + name = name.lower() # so that "Template", "template", etc. all work + class_ = interpolation_engines.get(name, None) + if class_ is None: + # invalid value for self.main.interpolation + self.main.interpolation = False + return value + else: + # save reference to engine so we don't have to do this again + engine = self._interpolation_engine = class_(self) + # let the engine do the actual work + return engine.interpolate(key, value) + + + def __getitem__(self, key): + """Fetch the item and do string interpolation.""" + val = dict.__getitem__(self, key) + if self.main.interpolation: + if isinstance(val, basestring): + return self._interpolate(key, val) + if isinstance(val, list): + def _check(entry): + if isinstance(entry, basestring): + return self._interpolate(key, entry) + return entry + new = [_check(entry) for entry in val] + if new != val: + return new + return val + + + def __setitem__(self, key, value, unrepr=False): + """ + Correctly set a value. + + Making dictionary values Section instances. + (We have to special case 'Section' instances - which are also dicts) + + Keys must be strings. + Values need only be strings (or lists of strings) if + ``main.stringify`` is set. + + ``unrepr`` must be set when setting a value to a dictionary, without + creating a new sub-section. + """ + if not isinstance(key, basestring): + raise ValueError('The key "%s" is not a string.' % key) + + # add the comment + if key not in self.comments: + self.comments[key] = [] + self.inline_comments[key] = '' + # remove the entry from defaults + if key in self.defaults: + self.defaults.remove(key) + # + if isinstance(value, Section): + if key not in self: + self.sections.append(key) + dict.__setitem__(self, key, value) + elif isinstance(value, dict) and not unrepr: + # First create the new depth level, + # then create the section + if key not in self: + self.sections.append(key) + new_depth = self.depth + 1 + dict.__setitem__( + self, + key, + Section( + self, + new_depth, + self.main, + indict=value, + name=key)) + else: + if key not in self: + self.scalars.append(key) + if not self.main.stringify: + if isinstance(value, basestring): + pass + elif isinstance(value, (list, tuple)): + for entry in value: + if not isinstance(entry, basestring): + raise TypeError('Value is not a string "%s".' % entry) + else: + raise TypeError('Value is not a string "%s".' % value) + dict.__setitem__(self, key, value) + + + def __delitem__(self, key): + """Remove items from the sequence when deleting.""" + dict. __delitem__(self, key) + if key in self.scalars: + self.scalars.remove(key) + else: + self.sections.remove(key) + del self.comments[key] + del self.inline_comments[key] + + + def get(self, key, default=None): + """A version of ``get`` that doesn't bypass string interpolation.""" + try: + return self[key] + except KeyError: + return default + + + def update(self, indict): + """ + A version of update that uses our ``__setitem__``. + """ + for entry in indict: + self[entry] = indict[entry] + + + def pop(self, key, default=MISSING): + """ + 'D.pop(k[,d]) -> v, remove specified key and return the corresponding value. + If key is not found, d is returned if given, otherwise KeyError is raised' + """ + try: + val = self[key] + except KeyError: + if default is MISSING: + raise + val = default + else: + del self[key] + return val + + + def popitem(self): + """Pops the first (key,val)""" + sequence = (self.scalars + self.sections) + if not sequence: + raise KeyError(": 'popitem(): dictionary is empty'") + key = sequence[0] + val = self[key] + del self[key] + return key, val + + + def clear(self): + """ + A version of clear that also affects scalars/sections + Also clears comments and configspec. + + Leaves other attributes alone : + depth/main/parent are not affected + """ + dict.clear(self) + self.scalars = [] + self.sections = [] + self.comments = {} + self.inline_comments = {} + self.configspec = None + self.defaults = [] + self.extra_values = [] + + + def setdefault(self, key, default=None): + """A version of setdefault that sets sequence if appropriate.""" + try: + return self[key] + except KeyError: + self[key] = default + return self[key] + + + def items(self): + """D.items() -> list of D's (key, value) pairs, as 2-tuples""" + return zip((self.scalars + self.sections), self.values()) + + + def keys(self): + """D.keys() -> list of D's keys""" + return (self.scalars + self.sections) + + + def values(self): + """D.values() -> list of D's values""" + return [self[key] for key in (self.scalars + self.sections)] + + + def iteritems(self): + """D.iteritems() -> an iterator over the (key, value) items of D""" + return iter(self.items()) + + + def iterkeys(self): + """D.iterkeys() -> an iterator over the keys of D""" + return iter((self.scalars + self.sections)) + + __iter__ = iterkeys + + + def itervalues(self): + """D.itervalues() -> an iterator over the values of D""" + return iter(self.values()) + + + def __repr__(self): + """x.__repr__() <==> repr(x)""" + def _getval(key): + try: + return self[key] + except MissingInterpolationOption: + return dict.__getitem__(self, key) + return '{%s}' % ', '.join([('%s: %s' % (repr(key), repr(_getval(key)))) + for key in (self.scalars + self.sections)]) + + __str__ = __repr__ + __str__.__doc__ = "x.__str__() <==> str(x)" + + + # Extra methods - not in a normal dictionary + + def dict(self): + """ + Return a deepcopy of self as a dictionary. + + All members that are ``Section`` instances are recursively turned to + ordinary dictionaries - by calling their ``dict`` method. + + >>> n = a.dict() + >>> n == a + 1 + >>> n is a + 0 + """ + newdict = {} + for entry in self: + this_entry = self[entry] + if isinstance(this_entry, Section): + this_entry = this_entry.dict() + elif isinstance(this_entry, list): + # create a copy rather than a reference + this_entry = list(this_entry) + elif isinstance(this_entry, tuple): + # create a copy rather than a reference + this_entry = tuple(this_entry) + newdict[entry] = this_entry + return newdict + + + def merge(self, indict): + """ + A recursive update - useful for merging config files. + + >>> a = '''[section1] + ... option1 = True + ... [[subsection]] + ... more_options = False + ... # end of file'''.splitlines() + >>> b = '''# File is user.ini + ... [section1] + ... option1 = False + ... # end of file'''.splitlines() + >>> c1 = ConfigObj(b) + >>> c2 = ConfigObj(a) + >>> c2.merge(c1) + >>> c2 + ConfigObj({'section1': {'option1': 'False', 'subsection': {'more_options': 'False'}}}) + """ + for key, val in indict.items(): + if (key in self and isinstance(self[key], dict) and + isinstance(val, dict)): + self[key].merge(val) + else: + self[key] = val + + + def rename(self, oldkey, newkey): + """ + Change a keyname to another, without changing position in sequence. + + Implemented so that transformations can be made on keys, + as well as on values. (used by encode and decode) + + Also renames comments. + """ + if oldkey in self.scalars: + the_list = self.scalars + elif oldkey in self.sections: + the_list = self.sections + else: + raise KeyError('Key "%s" not found.' % oldkey) + pos = the_list.index(oldkey) + # + val = self[oldkey] + dict.__delitem__(self, oldkey) + dict.__setitem__(self, newkey, val) + the_list.remove(oldkey) + the_list.insert(pos, newkey) + comm = self.comments[oldkey] + inline_comment = self.inline_comments[oldkey] + del self.comments[oldkey] + del self.inline_comments[oldkey] + self.comments[newkey] = comm + self.inline_comments[newkey] = inline_comment + + + def walk(self, function, raise_errors=True, + call_on_sections=False, **keywargs): + """ + Walk every member and call a function on the keyword and value. + + Return a dictionary of the return values + + If the function raises an exception, raise the errror + unless ``raise_errors=False``, in which case set the return value to + ``False``. + + Any unrecognised keyword arguments you pass to walk, will be pased on + to the function you pass in. + + Note: if ``call_on_sections`` is ``True`` then - on encountering a + subsection, *first* the function is called for the *whole* subsection, + and then recurses into it's members. This means your function must be + able to handle strings, dictionaries and lists. This allows you + to change the key of subsections as well as for ordinary members. The + return value when called on the whole subsection has to be discarded. + + See the encode and decode methods for examples, including functions. + + .. admonition:: caution + + You can use ``walk`` to transform the names of members of a section + but you mustn't add or delete members. + + >>> config = '''[XXXXsection] + ... XXXXkey = XXXXvalue'''.splitlines() + >>> cfg = ConfigObj(config) + >>> cfg + ConfigObj({'XXXXsection': {'XXXXkey': 'XXXXvalue'}}) + >>> def transform(section, key): + ... val = section[key] + ... newkey = key.replace('XXXX', 'CLIENT1') + ... section.rename(key, newkey) + ... if isinstance(val, (tuple, list, dict)): + ... pass + ... else: + ... val = val.replace('XXXX', 'CLIENT1') + ... section[newkey] = val + >>> cfg.walk(transform, call_on_sections=True) + {'CLIENT1section': {'CLIENT1key': None}} + >>> cfg + ConfigObj({'CLIENT1section': {'CLIENT1key': 'CLIENT1value'}}) + """ + out = {} + # scalars first + for i in range(len(self.scalars)): + entry = self.scalars[i] + try: + val = function(self, entry, **keywargs) + # bound again in case name has changed + entry = self.scalars[i] + out[entry] = val + except Exception: + if raise_errors: + raise + else: + entry = self.scalars[i] + out[entry] = False + # then sections + for i in range(len(self.sections)): + entry = self.sections[i] + if call_on_sections: + try: + function(self, entry, **keywargs) + except Exception: + if raise_errors: + raise + else: + entry = self.sections[i] + out[entry] = False + # bound again in case name has changed + entry = self.sections[i] + # previous result is discarded + out[entry] = self[entry].walk( + function, + raise_errors=raise_errors, + call_on_sections=call_on_sections, + **keywargs) + return out + + + def as_bool(self, key): + """ + Accepts a key as input. The corresponding value must be a string or + the objects (``True`` or 1) or (``False`` or 0). We allow 0 and 1 to + retain compatibility with Python 2.2. + + If the string is one of ``True``, ``On``, ``Yes``, or ``1`` it returns + ``True``. + + If the string is one of ``False``, ``Off``, ``No``, or ``0`` it returns + ``False``. + + ``as_bool`` is not case sensitive. + + Any other input will raise a ``ValueError``. + + >>> a = ConfigObj() + >>> a['a'] = 'fish' + >>> a.as_bool('a') + Traceback (most recent call last): + ValueError: Value "fish" is neither True nor False + >>> a['b'] = 'True' + >>> a.as_bool('b') + 1 + >>> a['b'] = 'off' + >>> a.as_bool('b') + 0 + """ + val = self[key] + if val == True: + return True + elif val == False: + return False + else: + try: + if not isinstance(val, basestring): + # TODO: Why do we raise a KeyError here? + raise KeyError() + else: + return self.main._bools[val.lower()] + except KeyError: + raise ValueError('Value "%s" is neither True nor False' % val) + + + def as_int(self, key): + """ + A convenience method which coerces the specified value to an integer. + + If the value is an invalid literal for ``int``, a ``ValueError`` will + be raised. + + >>> a = ConfigObj() + >>> a['a'] = 'fish' + >>> a.as_int('a') + Traceback (most recent call last): + ValueError: invalid literal for int() with base 10: 'fish' + >>> a['b'] = '1' + >>> a.as_int('b') + 1 + >>> a['b'] = '3.2' + >>> a.as_int('b') + Traceback (most recent call last): + ValueError: invalid literal for int() with base 10: '3.2' + """ + return int(self[key]) + + + def as_float(self, key): + """ + A convenience method which coerces the specified value to a float. + + If the value is an invalid literal for ``float``, a ``ValueError`` will + be raised. + + >>> a = ConfigObj() + >>> a['a'] = 'fish' + >>> a.as_float('a') + Traceback (most recent call last): + ValueError: invalid literal for float(): fish + >>> a['b'] = '1' + >>> a.as_float('b') + 1.0 + >>> a['b'] = '3.2' + >>> a.as_float('b') + 3.2000000000000002 + """ + return float(self[key]) + + + def as_list(self, key): + """ + A convenience method which fetches the specified value, guaranteeing + that it is a list. + + >>> a = ConfigObj() + >>> a['a'] = 1 + >>> a.as_list('a') + [1] + >>> a['a'] = (1,) + >>> a.as_list('a') + [1] + >>> a['a'] = [1] + >>> a.as_list('a') + [1] + """ + result = self[key] + if isinstance(result, (tuple, list)): + return list(result) + return [result] + + + def restore_default(self, key): + """ + Restore (and return) default value for the specified key. + + This method will only work for a ConfigObj that was created + with a configspec and has been validated. + + If there is no default value for this key, ``KeyError`` is raised. + """ + default = self.default_values[key] + dict.__setitem__(self, key, default) + if key not in self.defaults: + self.defaults.append(key) + return default + + + def restore_defaults(self): + """ + Recursively restore default values to all members + that have them. + + This method will only work for a ConfigObj that was created + with a configspec and has been validated. + + It doesn't delete or modify entries without default values. + """ + for key in self.default_values: + self.restore_default(key) + + for section in self.sections: + self[section].restore_defaults() + + +class ConfigObj(Section): + """An object to read, create, and write config files.""" + + _keyword = re.compile(r'''^ # line start + (\s*) # indentation + ( # keyword + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'"=].*?) # no quotes + ) + \s*=\s* # divider + (.*) # value (including list values and comments) + $ # line end + ''', + re.VERBOSE) + + _sectionmarker = re.compile(r'''^ + (\s*) # 1: indentation + ((?:\[\s*)+) # 2: section marker open + ( # 3: section name open + (?:"\s*\S.*?\s*")| # at least one non-space with double quotes + (?:'\s*\S.*?\s*')| # at least one non-space with single quotes + (?:[^'"\s].*?) # at least one non-space unquoted + ) # section name close + ((?:\s*\])+) # 4: section marker close + \s*(\#.*)? # 5: optional comment + $''', + re.VERBOSE) + + # this regexp pulls list values out as a single string + # or single values and comments + # FIXME: this regex adds a '' to the end of comma terminated lists + # workaround in ``_handle_value`` + _valueexp = re.compile(r'''^ + (?: + (?: + ( + (?: + (?: + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'",\#][^,\#]*?) # unquoted + ) + \s*,\s* # comma + )* # match all list items ending in a comma (if any) + ) + ( + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'",\#\s][^,]*?)| # unquoted + (?:(? 1: + msg = "Parsing failed with several errors.\nFirst error %s" % info + error = ConfigObjError(msg) + else: + error = self._errors[0] + # set the errors attribute; it's a list of tuples: + # (error_type, message, line_number) + error.errors = self._errors + # set the config attribute + error.config = self + raise error + # delete private attributes + del self._errors + + if configspec is None: + self.configspec = None + else: + self._handle_configspec(configspec) + + + def _initialise(self, options=None): + if options is None: + options = OPTION_DEFAULTS + + # initialise a few variables + self.filename = None + self._errors = [] + self.raise_errors = options['raise_errors'] + self.interpolation = options['interpolation'] + self.list_values = options['list_values'] + self.create_empty = options['create_empty'] + self.file_error = options['file_error'] + self.stringify = options['stringify'] + self.indent_type = options['indent_type'] + self.encoding = options['encoding'] + self.default_encoding = options['default_encoding'] + self.BOM = False + self.newlines = None + self.write_empty_values = options['write_empty_values'] + self.unrepr = options['unrepr'] + + self.initial_comment = [] + self.final_comment = [] + self.configspec = None + + if self._inspec: + self.list_values = False + + # Clear section attributes as well + Section._initialise(self) + + + def __repr__(self): + def _getval(key): + try: + return self[key] + except MissingInterpolationOption: + return dict.__getitem__(self, key) + return ('ConfigObj({%s})' % + ', '.join([('%s: %s' % (repr(key), repr(_getval(key)))) + for key in (self.scalars + self.sections)])) + + + def _handle_bom(self, infile): + """ + Handle any BOM, and decode if necessary. + + If an encoding is specified, that *must* be used - but the BOM should + still be removed (and the BOM attribute set). + + (If the encoding is wrongly specified, then a BOM for an alternative + encoding won't be discovered or removed.) + + If an encoding is not specified, UTF8 or UTF16 BOM will be detected and + removed. The BOM attribute will be set. UTF16 will be decoded to + unicode. + + NOTE: This method must not be called with an empty ``infile``. + + Specifying the *wrong* encoding is likely to cause a + ``UnicodeDecodeError``. + + ``infile`` must always be returned as a list of lines, but may be + passed in as a single string. + """ + if ((self.encoding is not None) and + (self.encoding.lower() not in BOM_LIST)): + # No need to check for a BOM + # the encoding specified doesn't have one + # just decode + return self._decode(infile, self.encoding) + + if isinstance(infile, (list, tuple)): + line = infile[0] + else: + line = infile + if self.encoding is not None: + # encoding explicitly supplied + # And it could have an associated BOM + # TODO: if encoding is just UTF16 - we ought to check for both + # TODO: big endian and little endian versions. + enc = BOM_LIST[self.encoding.lower()] + if enc == 'utf_16': + # For UTF16 we try big endian and little endian + for BOM, (encoding, final_encoding) in BOMS.items(): + if not final_encoding: + # skip UTF8 + continue + if infile.startswith(BOM): + ### BOM discovered + ##self.BOM = True + # Don't need to remove BOM + return self._decode(infile, encoding) + + # If we get this far, will *probably* raise a DecodeError + # As it doesn't appear to start with a BOM + return self._decode(infile, self.encoding) + + # Must be UTF8 + BOM = BOM_SET[enc] + if not line.startswith(BOM): + return self._decode(infile, self.encoding) + + newline = line[len(BOM):] + + # BOM removed + if isinstance(infile, (list, tuple)): + infile[0] = newline + else: + infile = newline + self.BOM = True + return self._decode(infile, self.encoding) + + # No encoding specified - so we need to check for UTF8/UTF16 + for BOM, (encoding, final_encoding) in BOMS.items(): + if not line.startswith(BOM): + continue + else: + # BOM discovered + self.encoding = final_encoding + if not final_encoding: + self.BOM = True + # UTF8 + # remove BOM + newline = line[len(BOM):] + if isinstance(infile, (list, tuple)): + infile[0] = newline + else: + infile = newline + # UTF8 - don't decode + if isinstance(infile, basestring): + return infile.splitlines(True) + else: + return infile + # UTF16 - have to decode + return self._decode(infile, encoding) + + # No BOM discovered and no encoding specified, just return + if isinstance(infile, basestring): + # infile read from a file will be a single string + return infile.splitlines(True) + return infile + + + def _a_to_u(self, aString): + """Decode ASCII strings to unicode if a self.encoding is specified.""" + if self.encoding: + return aString.decode('ascii') + else: + return aString + + + def _decode(self, infile, encoding): + """ + Decode infile to unicode. Using the specified encoding. + + if is a string, it also needs converting to a list. + """ + if isinstance(infile, basestring): + # can't be unicode + # NOTE: Could raise a ``UnicodeDecodeError`` + return infile.decode(encoding).splitlines(True) + for i, line in enumerate(infile): + if not isinstance(line, unicode): + # NOTE: The isinstance test here handles mixed lists of unicode/string + # NOTE: But the decode will break on any non-string values + # NOTE: Or could raise a ``UnicodeDecodeError`` + infile[i] = line.decode(encoding) + return infile + + + def _decode_element(self, line): + """Decode element to unicode if necessary.""" + if not self.encoding: + return line + if isinstance(line, str) and self.default_encoding: + return line.decode(self.default_encoding) + return line + + + def _str(self, value): + """ + Used by ``stringify`` within validate, to turn non-string values + into strings. + """ + if not isinstance(value, basestring): + return str(value) + else: + return value + + + def _parse(self, infile): + """Actually parse the config file.""" + temp_list_values = self.list_values + if self.unrepr: + self.list_values = False + + comment_list = [] + done_start = False + this_section = self + maxline = len(infile) - 1 + cur_index = -1 + reset_comment = False + + while cur_index < maxline: + if reset_comment: + comment_list = [] + cur_index += 1 + line = infile[cur_index] + sline = line.strip() + # do we have anything on the line ? + if not sline or sline.startswith('#'): + reset_comment = False + comment_list.append(line) + continue + + if not done_start: + # preserve initial comment + self.initial_comment = comment_list + comment_list = [] + done_start = True + + reset_comment = True + # first we check if it's a section marker + mat = self._sectionmarker.match(line) + if mat is not None: + # is a section line + (indent, sect_open, sect_name, sect_close, comment) = mat.groups() + if indent and (self.indent_type is None): + self.indent_type = indent + cur_depth = sect_open.count('[') + if cur_depth != sect_close.count(']'): + self._handle_error("Cannot compute the section depth at line %s.", + NestingError, infile, cur_index) + continue + + if cur_depth < this_section.depth: + # the new section is dropping back to a previous level + try: + parent = self._match_depth(this_section, + cur_depth).parent + except SyntaxError: + self._handle_error("Cannot compute nesting level at line %s.", + NestingError, infile, cur_index) + continue + elif cur_depth == this_section.depth: + # the new section is a sibling of the current section + parent = this_section.parent + elif cur_depth == this_section.depth + 1: + # the new section is a child the current section + parent = this_section + else: + self._handle_error("Section too nested at line %s.", + NestingError, infile, cur_index) + + sect_name = self._unquote(sect_name) + if sect_name in parent: + self._handle_error('Duplicate section name at line %s.', + DuplicateError, infile, cur_index) + continue + + # create the new section + this_section = Section( + parent, + cur_depth, + self, + name=sect_name) + parent[sect_name] = this_section + parent.inline_comments[sect_name] = comment + parent.comments[sect_name] = comment_list + continue + # + # it's not a section marker, + # so it should be a valid ``key = value`` line + mat = self._keyword.match(line) + if mat is None: + # it neither matched as a keyword + # or a section marker + self._handle_error( + 'Invalid line at line "%s".', + ParseError, infile, cur_index) + else: + # is a keyword value + # value will include any inline comment + (indent, key, value) = mat.groups() + if indent and (self.indent_type is None): + self.indent_type = indent + # check for a multiline value + if value[:3] in ['"""', "'''"]: + try: + value, comment, cur_index = self._multiline( + value, infile, cur_index, maxline) + except SyntaxError: + self._handle_error( + 'Parse error in value at line %s.', + ParseError, infile, cur_index) + continue + else: + if self.unrepr: + comment = '' + try: + value = unrepr(value) + except Exception, e: + if type(e) == UnknownType: + msg = 'Unknown name or type in value at line %s.' + else: + msg = 'Parse error in value at line %s.' + self._handle_error(msg, UnreprError, infile, + cur_index) + continue + else: + if self.unrepr: + comment = '' + try: + value = unrepr(value) + except Exception, e: + if isinstance(e, UnknownType): + msg = 'Unknown name or type in value at line %s.' + else: + msg = 'Parse error in value at line %s.' + self._handle_error(msg, UnreprError, infile, + cur_index) + continue + else: + # extract comment and lists + try: + (value, comment) = self._handle_value(value) + except SyntaxError: + self._handle_error( + 'Parse error in value at line %s.', + ParseError, infile, cur_index) + continue + # + key = self._unquote(key) + if key in this_section: + self._handle_error( + 'Duplicate keyword name at line %s.', + DuplicateError, infile, cur_index) + continue + # add the key. + # we set unrepr because if we have got this far we will never + # be creating a new section + this_section.__setitem__(key, value, unrepr=True) + this_section.inline_comments[key] = comment + this_section.comments[key] = comment_list + continue + # + if self.indent_type is None: + # no indentation used, set the type accordingly + self.indent_type = '' + + # preserve the final comment + if not self and not self.initial_comment: + self.initial_comment = comment_list + elif not reset_comment: + self.final_comment = comment_list + self.list_values = temp_list_values + + + def _match_depth(self, sect, depth): + """ + Given a section and a depth level, walk back through the sections + parents to see if the depth level matches a previous section. + + Return a reference to the right section, + or raise a SyntaxError. + """ + while depth < sect.depth: + if sect is sect.parent: + # we've reached the top level already + raise SyntaxError() + sect = sect.parent + if sect.depth == depth: + return sect + # shouldn't get here + raise SyntaxError() + + + def _handle_error(self, text, ErrorClass, infile, cur_index): + """ + Handle an error according to the error settings. + + Either raise the error or store it. + The error will have occured at ``cur_index`` + """ + line = infile[cur_index] + cur_index += 1 + message = text % cur_index + error = ErrorClass(message, cur_index, line) + if self.raise_errors: + # raise the error - parsing stops here + raise error + # store the error + # reraise when parsing has finished + self._errors.append(error) + + + def _unquote(self, value): + """Return an unquoted version of a value""" + if not value: + # should only happen during parsing of lists + raise SyntaxError + if (value[0] == value[-1]) and (value[0] in ('"', "'")): + value = value[1:-1] + return value + + + def _quote(self, value, multiline=True): + """ + Return a safely quoted version of a value. + + Raise a ConfigObjError if the value cannot be safely quoted. + If multiline is ``True`` (default) then use triple quotes + if necessary. + + * Don't quote values that don't need it. + * Recursively quote members of a list and return a comma joined list. + * Multiline is ``False`` for lists. + * Obey list syntax for empty and single member lists. + + If ``list_values=False`` then the value is only quoted if it contains + a ``\\n`` (is multiline) or '#'. + + If ``write_empty_values`` is set, and the value is an empty string, it + won't be quoted. + """ + if multiline and self.write_empty_values and value == '': + # Only if multiline is set, so that it is used for values not + # keys, and not values that are part of a list + return '' + + if multiline and isinstance(value, (list, tuple)): + if not value: + return ',' + elif len(value) == 1: + return self._quote(value[0], multiline=False) + ',' + return ', '.join([self._quote(val, multiline=False) + for val in value]) + if not isinstance(value, basestring): + if self.stringify: + value = str(value) + else: + raise TypeError('Value "%s" is not a string.' % value) + + if not value: + return '""' + + no_lists_no_quotes = not self.list_values and '\n' not in value and '#' not in value + need_triple = multiline and ((("'" in value) and ('"' in value)) or ('\n' in value )) + hash_triple_quote = multiline and not need_triple and ("'" in value) and ('"' in value) and ('#' in value) + check_for_single = (no_lists_no_quotes or not need_triple) and not hash_triple_quote + + if check_for_single: + if not self.list_values: + # we don't quote if ``list_values=False`` + quot = noquot + # for normal values either single or double quotes will do + elif '\n' in value: + # will only happen if multiline is off - e.g. '\n' in key + raise ConfigObjError('Value "%s" cannot be safely quoted.' % value) + elif ((value[0] not in wspace_plus) and + (value[-1] not in wspace_plus) and + (',' not in value)): + quot = noquot + else: + quot = self._get_single_quote(value) + else: + # if value has '\n' or "'" *and* '"', it will need triple quotes + quot = self._get_triple_quote(value) + + if quot == noquot and '#' in value and self.list_values: + quot = self._get_single_quote(value) + + return quot % value + + + def _get_single_quote(self, value): + if ("'" in value) and ('"' in value): + raise ConfigObjError('Value "%s" cannot be safely quoted.' % value) + elif '"' in value: + quot = squot + else: + quot = dquot + return quot + + + def _get_triple_quote(self, value): + if (value.find('"""') != -1) and (value.find("'''") != -1): + raise ConfigObjError('Value "%s" cannot be safely quoted.' % value) + if value.find('"""') == -1: + quot = tdquot + else: + quot = tsquot + return quot + + + def _handle_value(self, value): + """ + Given a value string, unquote, remove comment, + handle lists. (including empty and single member lists) + """ + if self._inspec: + # Parsing a configspec so don't handle comments + return (value, '') + # do we look for lists in values ? + if not self.list_values: + mat = self._nolistvalue.match(value) + if mat is None: + raise SyntaxError() + # NOTE: we don't unquote here + return mat.groups() + # + mat = self._valueexp.match(value) + if mat is None: + # the value is badly constructed, probably badly quoted, + # or an invalid list + raise SyntaxError() + (list_values, single, empty_list, comment) = mat.groups() + if (list_values == '') and (single is None): + # change this if you want to accept empty values + raise SyntaxError() + # NOTE: note there is no error handling from here if the regex + # is wrong: then incorrect values will slip through + if empty_list is not None: + # the single comma - meaning an empty list + return ([], comment) + if single is not None: + # handle empty values + if list_values and not single: + # FIXME: the '' is a workaround because our regex now matches + # '' at the end of a list if it has a trailing comma + single = None + else: + single = single or '""' + single = self._unquote(single) + if list_values == '': + # not a list value + return (single, comment) + the_list = self._listvalueexp.findall(list_values) + the_list = [self._unquote(val) for val in the_list] + if single is not None: + the_list += [single] + return (the_list, comment) + + + def _multiline(self, value, infile, cur_index, maxline): + """Extract the value, where we are in a multiline situation.""" + quot = value[:3] + newvalue = value[3:] + single_line = self._triple_quote[quot][0] + multi_line = self._triple_quote[quot][1] + mat = single_line.match(value) + if mat is not None: + retval = list(mat.groups()) + retval.append(cur_index) + return retval + elif newvalue.find(quot) != -1: + # somehow the triple quote is missing + raise SyntaxError() + # + while cur_index < maxline: + cur_index += 1 + newvalue += '\n' + line = infile[cur_index] + if line.find(quot) == -1: + newvalue += line + else: + # end of multiline, process it + break + else: + # we've got to the end of the config, oops... + raise SyntaxError() + mat = multi_line.match(line) + if mat is None: + # a badly formed line + raise SyntaxError() + (value, comment) = mat.groups() + return (newvalue + value, comment, cur_index) + + + def _handle_configspec(self, configspec): + """Parse the configspec.""" + # FIXME: Should we check that the configspec was created with the + # correct settings ? (i.e. ``list_values=False``) + if not isinstance(configspec, ConfigObj): + try: + configspec = ConfigObj(configspec, + raise_errors=True, + file_error=True, + _inspec=True) + except ConfigObjError, e: + # FIXME: Should these errors have a reference + # to the already parsed ConfigObj ? + raise ConfigspecError('Parsing configspec failed: %s' % e) + except IOError, e: + raise IOError('Reading configspec failed: %s' % e) + + self.configspec = configspec + + + + def _set_configspec(self, section, copy): + """ + Called by validate. Handles setting the configspec on subsections + including sections to be validated by __many__ + """ + configspec = section.configspec + many = configspec.get('__many__') + if isinstance(many, dict): + for entry in section.sections: + if entry not in configspec: + section[entry].configspec = many + + for entry in configspec.sections: + if entry == '__many__': + continue + if entry not in section: + section[entry] = {} + section[entry]._created = True + if copy: + # copy comments + section.comments[entry] = configspec.comments.get(entry, []) + section.inline_comments[entry] = configspec.inline_comments.get(entry, '') + + # Could be a scalar when we expect a section + if isinstance(section[entry], Section): + section[entry].configspec = configspec[entry] + + + def _write_line(self, indent_string, entry, this_entry, comment): + """Write an individual line, for the write method""" + # NOTE: the calls to self._quote here handles non-StringType values. + if not self.unrepr: + val = self._decode_element(self._quote(this_entry)) + else: + val = repr(this_entry) + return '%s%s%s%s%s' % (indent_string, + self._decode_element(self._quote(entry, multiline=False)), + self._a_to_u(' = '), + val, + self._decode_element(comment)) + + + def _write_marker(self, indent_string, depth, entry, comment): + """Write a section marker line""" + return '%s%s%s%s%s' % (indent_string, + self._a_to_u('[' * depth), + self._quote(self._decode_element(entry), multiline=False), + self._a_to_u(']' * depth), + self._decode_element(comment)) + + + def _handle_comment(self, comment): + """Deal with a comment.""" + if not comment: + return '' + start = self.indent_type + if not comment.startswith('#'): + start += self._a_to_u(' # ') + return (start + comment) + + + # Public methods + + def write(self, outfile=None, section=None): + """ + Write the current ConfigObj as a file + + tekNico: FIXME: use StringIO instead of real files + + >>> filename = a.filename + >>> a.filename = 'test.ini' + >>> a.write() + >>> a.filename = filename + >>> a == ConfigObj('test.ini', raise_errors=True) + 1 + >>> import os + >>> os.remove('test.ini') + """ + if self.indent_type is None: + # this can be true if initialised from a dictionary + self.indent_type = DEFAULT_INDENT_TYPE + + out = [] + cs = self._a_to_u('#') + csp = self._a_to_u('# ') + if section is None: + int_val = self.interpolation + self.interpolation = False + section = self + for line in self.initial_comment: + line = self._decode_element(line) + stripped_line = line.strip() + if stripped_line and not stripped_line.startswith(cs): + line = csp + line + out.append(line) + + indent_string = self.indent_type * section.depth + for entry in (section.scalars + section.sections): + if entry in section.defaults: + # don't write out default values + continue + for comment_line in section.comments[entry]: + comment_line = self._decode_element(comment_line.lstrip()) + if comment_line and not comment_line.startswith(cs): + comment_line = csp + comment_line + out.append(indent_string + comment_line) + this_entry = section[entry] + comment = self._handle_comment(section.inline_comments[entry]) + + if isinstance(this_entry, dict): + # a section + out.append(self._write_marker( + indent_string, + this_entry.depth, + entry, + comment)) + out.extend(self.write(section=this_entry)) + else: + out.append(self._write_line( + indent_string, + entry, + this_entry, + comment)) + + if section is self: + for line in self.final_comment: + line = self._decode_element(line) + stripped_line = line.strip() + if stripped_line and not stripped_line.startswith(cs): + line = csp + line + out.append(line) + self.interpolation = int_val + + if section is not self: + return out + + if (self.filename is None) and (outfile is None): + # output a list of lines + # might need to encode + # NOTE: This will *screw* UTF16, each line will start with the BOM + if self.encoding: + out = [l.encode(self.encoding) for l in out] + if (self.BOM and ((self.encoding is None) or + (BOM_LIST.get(self.encoding.lower()) == 'utf_8'))): + # Add the UTF8 BOM + if not out: + out.append('') + out[0] = BOM_UTF8 + out[0] + return out + + # Turn the list to a string, joined with correct newlines + newline = self.newlines or os.linesep + if (getattr(outfile, 'mode', None) is not None and outfile.mode == 'w' + and sys.platform == 'win32' and newline == '\r\n'): + # Windows specific hack to avoid writing '\r\r\n' + newline = '\n' + output = self._a_to_u(newline).join(out) + if self.encoding: + output = output.encode(self.encoding) + if self.BOM and ((self.encoding is None) or match_utf8(self.encoding)): + # Add the UTF8 BOM + output = BOM_UTF8 + output + + if not output.endswith(newline): + output += newline + if outfile is not None: + outfile.write(output) + else: + h = open(self.filename, 'wb') + h.write(output) + h.close() + + + def validate(self, validator, preserve_errors=False, copy=False, + section=None): + """ + Test the ConfigObj against a configspec. + + It uses the ``validator`` object from *validate.py*. + + To run ``validate`` on the current ConfigObj, call: :: + + test = config.validate(validator) + + (Normally having previously passed in the configspec when the ConfigObj + was created - you can dynamically assign a dictionary of checks to the + ``configspec`` attribute of a section though). + + It returns ``True`` if everything passes, or a dictionary of + pass/fails (True/False). If every member of a subsection passes, it + will just have the value ``True``. (It also returns ``False`` if all + members fail). + + In addition, it converts the values from strings to their native + types if their checks pass (and ``stringify`` is set). + + If ``preserve_errors`` is ``True`` (``False`` is default) then instead + of a marking a fail with a ``False``, it will preserve the actual + exception object. This can contain info about the reason for failure. + For example the ``VdtValueTooSmallError`` indicates that the value + supplied was too small. If a value (or section) is missing it will + still be marked as ``False``. + + You must have the validate module to use ``preserve_errors=True``. + + You can then use the ``flatten_errors`` function to turn your nested + results dictionary into a flattened list of failures - useful for + displaying meaningful error messages. + """ + if section is None: + if self.configspec is None: + raise ValueError('No configspec supplied.') + if preserve_errors: + # We do this once to remove a top level dependency on the validate module + # Which makes importing configobj faster + from validate import VdtMissingValue + self._vdtMissingValue = VdtMissingValue + + section = self + + if copy: + section.initial_comment = section.configspec.initial_comment + section.final_comment = section.configspec.final_comment + section.encoding = section.configspec.encoding + section.BOM = section.configspec.BOM + section.newlines = section.configspec.newlines + section.indent_type = section.configspec.indent_type + + # + # section.default_values.clear() #?? + configspec = section.configspec + self._set_configspec(section, copy) + + + def validate_entry(entry, spec, val, missing, ret_true, ret_false): + section.default_values.pop(entry, None) + + try: + section.default_values[entry] = validator.get_default_value(configspec[entry]) + except (KeyError, AttributeError, validator.baseErrorClass): + # No default, bad default or validator has no 'get_default_value' + # (e.g. SimpleVal) + pass + + try: + check = validator.check(spec, + val, + missing=missing + ) + except validator.baseErrorClass, e: + if not preserve_errors or isinstance(e, self._vdtMissingValue): + out[entry] = False + else: + # preserve the error + out[entry] = e + ret_false = False + ret_true = False + else: + ret_false = False + out[entry] = True + if self.stringify or missing: + # if we are doing type conversion + # or the value is a supplied default + if not self.stringify: + if isinstance(check, (list, tuple)): + # preserve lists + check = [self._str(item) for item in check] + elif missing and check is None: + # convert the None from a default to a '' + check = '' + else: + check = self._str(check) + if (check != val) or missing: + section[entry] = check + if not copy and missing and entry not in section.defaults: + section.defaults.append(entry) + return ret_true, ret_false + + # + out = {} + ret_true = True + ret_false = True + + unvalidated = [k for k in section.scalars if k not in configspec] + incorrect_sections = [k for k in configspec.sections if k in section.scalars] + incorrect_scalars = [k for k in configspec.scalars if k in section.sections] + + for entry in configspec.scalars: + if entry in ('__many__', '___many___'): + # reserved names + continue + if (not entry in section.scalars) or (entry in section.defaults): + # missing entries + # or entries from defaults + missing = True + val = None + if copy and entry not in section.scalars: + # copy comments + section.comments[entry] = ( + configspec.comments.get(entry, [])) + section.inline_comments[entry] = ( + configspec.inline_comments.get(entry, '')) + # + else: + missing = False + val = section[entry] + + ret_true, ret_false = validate_entry(entry, configspec[entry], val, + missing, ret_true, ret_false) + + many = None + if '__many__' in configspec.scalars: + many = configspec['__many__'] + elif '___many___' in configspec.scalars: + many = configspec['___many___'] + + if many is not None: + for entry in unvalidated: + val = section[entry] + ret_true, ret_false = validate_entry(entry, many, val, False, + ret_true, ret_false) + unvalidated = [] + + for entry in incorrect_scalars: + ret_true = False + if not preserve_errors: + out[entry] = False + else: + ret_false = False + msg = 'Value %r was provided as a section' % entry + out[entry] = validator.baseErrorClass(msg) + for entry in incorrect_sections: + ret_true = False + if not preserve_errors: + out[entry] = False + else: + ret_false = False + msg = 'Section %r was provided as a single value' % entry + out[entry] = validator.baseErrorClass(msg) + + # Missing sections will have been created as empty ones when the + # configspec was read. + for entry in section.sections: + # FIXME: this means DEFAULT is not copied in copy mode + if section is self and entry == 'DEFAULT': + continue + if section[entry].configspec is None: + unvalidated.append(entry) + continue + if copy: + section.comments[entry] = configspec.comments.get(entry, []) + section.inline_comments[entry] = configspec.inline_comments.get(entry, '') + check = self.validate(validator, preserve_errors=preserve_errors, copy=copy, section=section[entry]) + out[entry] = check + if check == False: + ret_true = False + elif check == True: + ret_false = False + else: + ret_true = False + + section.extra_values = unvalidated + if preserve_errors and not section._created: + # If the section wasn't created (i.e. it wasn't missing) + # then we can't return False, we need to preserve errors + ret_false = False + # + if ret_false and preserve_errors and out: + # If we are preserving errors, but all + # the failures are from missing sections / values + # then we can return False. Otherwise there is a + # real failure that we need to preserve. + ret_false = not any(out.values()) + if ret_true: + return True + elif ret_false: + return False + return out + + + def reset(self): + """Clear ConfigObj instance and restore to 'freshly created' state.""" + self.clear() + self._initialise() + # FIXME: Should be done by '_initialise', but ConfigObj constructor (and reload) + # requires an empty dictionary + self.configspec = None + # Just to be sure ;-) + self._original_configspec = None + + + def reload(self): + """ + Reload a ConfigObj from file. + + This method raises a ``ReloadError`` if the ConfigObj doesn't have + a filename attribute pointing to a file. + """ + if not isinstance(self.filename, basestring): + raise ReloadError() + + filename = self.filename + current_options = {} + for entry in OPTION_DEFAULTS: + if entry == 'configspec': + continue + current_options[entry] = getattr(self, entry) + + configspec = self._original_configspec + current_options['configspec'] = configspec + + self.clear() + self._initialise(current_options) + self._load(filename, configspec) + + + +class SimpleVal(object): + """ + A simple validator. + Can be used to check that all members expected are present. + + To use it, provide a configspec with all your members in (the value given + will be ignored). Pass an instance of ``SimpleVal`` to the ``validate`` + method of your ``ConfigObj``. ``validate`` will return ``True`` if all + members are present, or a dictionary with True/False meaning + present/missing. (Whole missing sections will be replaced with ``False``) + """ + + def __init__(self): + self.baseErrorClass = ConfigObjError + + def check(self, check, member, missing=False): + """A dummy check method, always returns the value unchanged.""" + if missing: + raise self.baseErrorClass() + return member + + +def flatten_errors(cfg, res, levels=None, results=None): + """ + An example function that will turn a nested dictionary of results + (as returned by ``ConfigObj.validate``) into a flat list. + + ``cfg`` is the ConfigObj instance being checked, ``res`` is the results + dictionary returned by ``validate``. + + (This is a recursive function, so you shouldn't use the ``levels`` or + ``results`` arguments - they are used by the function.) + + Returns a list of keys that failed. Each member of the list is a tuple:: + + ([list of sections...], key, result) + + If ``validate`` was called with ``preserve_errors=False`` (the default) + then ``result`` will always be ``False``. + + *list of sections* is a flattened list of sections that the key was found + in. + + If the section was missing (or a section was expected and a scalar provided + - or vice-versa) then key will be ``None``. + + If the value (or section) was missing then ``result`` will be ``False``. + + If ``validate`` was called with ``preserve_errors=True`` and a value + was present, but failed the check, then ``result`` will be the exception + object returned. You can use this as a string that describes the failure. + + For example *The value "3" is of the wrong type*. + """ + if levels is None: + # first time called + levels = [] + results = [] + if res == True: + return results + if res == False or isinstance(res, Exception): + results.append((levels[:], None, res)) + if levels: + levels.pop() + return results + for (key, val) in res.items(): + if val == True: + continue + if isinstance(cfg.get(key), dict): + # Go down one level + levels.append(key) + flatten_errors(cfg[key], val, levels, results) + continue + results.append((levels[:], key, val)) + # + # Go up one level + if levels: + levels.pop() + # + return results + + +def get_extra_values(conf, _prepend=()): + """ + Find all the values and sections not in the configspec from a validated + ConfigObj. + + ``get_extra_values`` returns a list of tuples where each tuple represents + either an extra section, or an extra value. + + The tuples contain two values, a tuple representing the section the value + is in and the name of the extra values. For extra values in the top level + section the first member will be an empty tuple. For values in the 'foo' + section the first member will be ``('foo',)``. For members in the 'bar' + subsection of the 'foo' section the first member will be ``('foo', 'bar')``. + + NOTE: If you call ``get_extra_values`` on a ConfigObj instance that hasn't + been validated it will return an empty list. + """ + out = [] + + out.extend([(_prepend, name) for name in conf.extra_values]) + for name in conf.sections: + if name not in conf.extra_values: + out.extend(get_extra_values(conf[name], _prepend + (name,))) + return out + + +"""*A programming language is a medium of expression.* - Paul Graham""" diff --git a/lib/enum/LICENSE b/lib/enum/LICENSE new file mode 100644 index 00000000..9003b885 --- /dev/null +++ b/lib/enum/LICENSE @@ -0,0 +1,32 @@ +Copyright (c) 2013, Ethan Furman. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + Redistributions of source code must retain the above + copyright notice, this list of conditions and the + following disclaimer. + + Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials + provided with the distribution. + + Neither the name Ethan Furman nor the names of any + contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/lib/enum/README b/lib/enum/README new file mode 100644 index 00000000..511af984 --- /dev/null +++ b/lib/enum/README @@ -0,0 +1,2 @@ +enum34 is the new Python stdlib enum module available in Python 3.4 +backported for previous versions of Python from 2.4 to 3.3. diff --git a/lib/enum/__init__.py b/lib/enum/__init__.py new file mode 100644 index 00000000..6a327a8a --- /dev/null +++ b/lib/enum/__init__.py @@ -0,0 +1,790 @@ +"""Python Enumerations""" + +import sys as _sys + +__all__ = ['Enum', 'IntEnum', 'unique'] + +version = 1, 0, 4 + +pyver = float('%s.%s' % _sys.version_info[:2]) + +try: + any +except NameError: + def any(iterable): + for element in iterable: + if element: + return True + return False + +try: + from collections import OrderedDict +except ImportError: + OrderedDict = None + +try: + basestring +except NameError: + # In Python 2 basestring is the ancestor of both str and unicode + # in Python 3 it's just str, but was missing in 3.1 + basestring = str + +try: + unicode +except NameError: + # In Python 3 unicode no longer exists (it's just str) + unicode = str + +class _RouteClassAttributeToGetattr(object): + """Route attribute access on a class to __getattr__. + + This is a descriptor, used to define attributes that act differently when + accessed through an instance and through a class. Instance access remains + normal, but access to an attribute through a class will be routed to the + class's __getattr__ method; this is done by raising AttributeError. + + """ + def __init__(self, fget=None): + self.fget = fget + + def __get__(self, instance, ownerclass=None): + if instance is None: + raise AttributeError() + return self.fget(instance) + + def __set__(self, instance, value): + raise AttributeError("can't set attribute") + + def __delete__(self, instance): + raise AttributeError("can't delete attribute") + + +def _is_descriptor(obj): + """Returns True if obj is a descriptor, False otherwise.""" + return ( + hasattr(obj, '__get__') or + hasattr(obj, '__set__') or + hasattr(obj, '__delete__')) + + +def _is_dunder(name): + """Returns True if a __dunder__ name, False otherwise.""" + return (name[:2] == name[-2:] == '__' and + name[2:3] != '_' and + name[-3:-2] != '_' and + len(name) > 4) + + +def _is_sunder(name): + """Returns True if a _sunder_ name, False otherwise.""" + return (name[0] == name[-1] == '_' and + name[1:2] != '_' and + name[-2:-1] != '_' and + len(name) > 2) + + +def _make_class_unpicklable(cls): + """Make the given class un-picklable.""" + def _break_on_call_reduce(self, protocol=None): + raise TypeError('%r cannot be pickled' % self) + cls.__reduce_ex__ = _break_on_call_reduce + cls.__module__ = '' + + +class _EnumDict(dict): + """Track enum member order and ensure member names are not reused. + + EnumMeta will use the names found in self._member_names as the + enumeration member names. + + """ + def __init__(self): + super(_EnumDict, self).__init__() + self._member_names = [] + + def __setitem__(self, key, value): + """Changes anything not dundered or not a descriptor. + + If a descriptor is added with the same name as an enum member, the name + is removed from _member_names (this may leave a hole in the numerical + sequence of values). + + If an enum member name is used twice, an error is raised; duplicate + values are not checked for. + + Single underscore (sunder) names are reserved. + + Note: in 3.x __order__ is simply discarded as a not necessary piece + leftover from 2.x + + """ + if pyver >= 3.0 and key == '__order__': + return + if _is_sunder(key): + raise ValueError('_names_ are reserved for future Enum use') + elif _is_dunder(key): + pass + elif key in self._member_names: + # descriptor overwriting an enum? + raise TypeError('Attempted to reuse key: %r' % key) + elif not _is_descriptor(value): + if key in self: + # enum overwriting a descriptor? + raise TypeError('Key already defined as: %r' % self[key]) + self._member_names.append(key) + super(_EnumDict, self).__setitem__(key, value) + + +# Dummy value for Enum as EnumMeta explicity checks for it, but of course until +# EnumMeta finishes running the first time the Enum class doesn't exist. This +# is also why there are checks in EnumMeta like `if Enum is not None` +Enum = None + + +class EnumMeta(type): + """Metaclass for Enum""" + @classmethod + def __prepare__(metacls, cls, bases): + return _EnumDict() + + def __new__(metacls, cls, bases, classdict): + # an Enum class is final once enumeration items have been defined; it + # cannot be mixed with other types (int, float, etc.) if it has an + # inherited __new__ unless a new __new__ is defined (or the resulting + # class will fail). + if type(classdict) is dict: + original_dict = classdict + classdict = _EnumDict() + for k, v in original_dict.items(): + classdict[k] = v + + member_type, first_enum = metacls._get_mixins_(bases) + __new__, save_new, use_args = metacls._find_new_(classdict, member_type, + first_enum) + # save enum items into separate mapping so they don't get baked into + # the new class + members = dict((k, classdict[k]) for k in classdict._member_names) + for name in classdict._member_names: + del classdict[name] + + # py2 support for definition order + __order__ = classdict.get('__order__') + if __order__ is None: + if pyver < 3.0: + try: + __order__ = [name for (name, value) in sorted(members.items(), key=lambda item: item[1])] + except TypeError: + __order__ = [name for name in sorted(members.keys())] + else: + __order__ = classdict._member_names + else: + del classdict['__order__'] + if pyver < 3.0: + __order__ = __order__.replace(',', ' ').split() + aliases = [name for name in members if name not in __order__] + __order__ += aliases + + # check for illegal enum names (any others?) + invalid_names = set(members) & set(['mro']) + if invalid_names: + raise ValueError('Invalid enum member name(s): %s' % ( + ', '.join(invalid_names), )) + + # create our new Enum type + enum_class = super(EnumMeta, metacls).__new__(metacls, cls, bases, classdict) + enum_class._member_names_ = [] # names in random order + if OrderedDict is not None: + enum_class._member_map_ = OrderedDict() + else: + enum_class._member_map_ = {} # name->value map + enum_class._member_type_ = member_type + + # Reverse value->name map for hashable values. + enum_class._value2member_map_ = {} + + # instantiate them, checking for duplicates as we go + # we instantiate first instead of checking for duplicates first in case + # a custom __new__ is doing something funky with the values -- such as + # auto-numbering ;) + if __new__ is None: + __new__ = enum_class.__new__ + for member_name in __order__: + value = members[member_name] + if not isinstance(value, tuple): + args = (value, ) + else: + args = value + if member_type is tuple: # special case for tuple enums + args = (args, ) # wrap it one more time + if not use_args or not args: + enum_member = __new__(enum_class) + if not hasattr(enum_member, '_value_'): + enum_member._value_ = value + else: + enum_member = __new__(enum_class, *args) + if not hasattr(enum_member, '_value_'): + enum_member._value_ = member_type(*args) + value = enum_member._value_ + enum_member._name_ = member_name + enum_member.__objclass__ = enum_class + enum_member.__init__(*args) + # If another member with the same value was already defined, the + # new member becomes an alias to the existing one. + for name, canonical_member in enum_class._member_map_.items(): + if canonical_member.value == enum_member._value_: + enum_member = canonical_member + break + else: + # Aliases don't appear in member names (only in __members__). + enum_class._member_names_.append(member_name) + enum_class._member_map_[member_name] = enum_member + try: + # This may fail if value is not hashable. We can't add the value + # to the map, and by-value lookups for this value will be + # linear. + enum_class._value2member_map_[value] = enum_member + except TypeError: + pass + + + # If a custom type is mixed into the Enum, and it does not know how + # to pickle itself, pickle.dumps will succeed but pickle.loads will + # fail. Rather than have the error show up later and possibly far + # from the source, sabotage the pickle protocol for this class so + # that pickle.dumps also fails. + # + # However, if the new class implements its own __reduce_ex__, do not + # sabotage -- it's on them to make sure it works correctly. We use + # __reduce_ex__ instead of any of the others as it is preferred by + # pickle over __reduce__, and it handles all pickle protocols. + unpicklable = False + if '__reduce_ex__' not in classdict: + if member_type is not object: + methods = ('__getnewargs_ex__', '__getnewargs__', + '__reduce_ex__', '__reduce__') + if not any(m in member_type.__dict__ for m in methods): + _make_class_unpicklable(enum_class) + unpicklable = True + + + # double check that repr and friends are not the mixin's or various + # things break (such as pickle) + for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): + class_method = getattr(enum_class, name) + obj_method = getattr(member_type, name, None) + enum_method = getattr(first_enum, name, None) + if name not in classdict and class_method is not enum_method: + if name == '__reduce_ex__' and unpicklable: + continue + setattr(enum_class, name, enum_method) + + # method resolution and int's are not playing nice + # Python's less than 2.6 use __cmp__ + + if pyver < 2.6: + + if issubclass(enum_class, int): + setattr(enum_class, '__cmp__', getattr(int, '__cmp__')) + + elif pyver < 3.0: + + if issubclass(enum_class, int): + for method in ( + '__le__', + '__lt__', + '__gt__', + '__ge__', + '__eq__', + '__ne__', + '__hash__', + ): + setattr(enum_class, method, getattr(int, method)) + + # replace any other __new__ with our own (as long as Enum is not None, + # anyway) -- again, this is to support pickle + if Enum is not None: + # if the user defined their own __new__, save it before it gets + # clobbered in case they subclass later + if save_new: + setattr(enum_class, '__member_new__', enum_class.__dict__['__new__']) + setattr(enum_class, '__new__', Enum.__dict__['__new__']) + return enum_class + + def __call__(cls, value, names=None, module=None, type=None): + """Either returns an existing member, or creates a new enum class. + + This method is used both when an enum class is given a value to match + to an enumeration member (i.e. Color(3)) and for the functional API + (i.e. Color = Enum('Color', names='red green blue')). + + When used for the functional API: `module`, if set, will be stored in + the new class' __module__ attribute; `type`, if set, will be mixed in + as the first base class. + + Note: if `module` is not set this routine will attempt to discover the + calling module by walking the frame stack; if this is unsuccessful + the resulting class will not be pickleable. + + """ + if names is None: # simple value lookup + return cls.__new__(cls, value) + # otherwise, functional API: we're creating a new Enum type + return cls._create_(value, names, module=module, type=type) + + def __contains__(cls, member): + return isinstance(member, cls) and member.name in cls._member_map_ + + def __delattr__(cls, attr): + # nicer error message when someone tries to delete an attribute + # (see issue19025). + if attr in cls._member_map_: + raise AttributeError( + "%s: cannot delete Enum member." % cls.__name__) + super(EnumMeta, cls).__delattr__(attr) + + def __dir__(self): + return (['__class__', '__doc__', '__members__', '__module__'] + + self._member_names_) + + @property + def __members__(cls): + """Returns a mapping of member name->value. + + This mapping lists all enum members, including aliases. Note that this + is a copy of the internal mapping. + + """ + return cls._member_map_.copy() + + def __getattr__(cls, name): + """Return the enum member matching `name` + + We use __getattr__ instead of descriptors or inserting into the enum + class' __dict__ in order to support `name` and `value` being both + properties for enum members (which live in the class' __dict__) and + enum members themselves. + + """ + if _is_dunder(name): + raise AttributeError(name) + try: + return cls._member_map_[name] + except KeyError: + raise AttributeError(name) + + def __getitem__(cls, name): + return cls._member_map_[name] + + def __iter__(cls): + return (cls._member_map_[name] for name in cls._member_names_) + + def __reversed__(cls): + return (cls._member_map_[name] for name in reversed(cls._member_names_)) + + def __len__(cls): + return len(cls._member_names_) + + def __repr__(cls): + return "" % cls.__name__ + + def __setattr__(cls, name, value): + """Block attempts to reassign Enum members. + + A simple assignment to the class namespace only changes one of the + several possible ways to get an Enum member from the Enum class, + resulting in an inconsistent Enumeration. + + """ + member_map = cls.__dict__.get('_member_map_', {}) + if name in member_map: + raise AttributeError('Cannot reassign members.') + super(EnumMeta, cls).__setattr__(name, value) + + def _create_(cls, class_name, names=None, module=None, type=None): + """Convenience method to create a new Enum class. + + `names` can be: + + * A string containing member names, separated either with spaces or + commas. Values are auto-numbered from 1. + * An iterable of member names. Values are auto-numbered from 1. + * An iterable of (member name, value) pairs. + * A mapping of member name -> value. + + """ + if pyver < 3.0: + # if class_name is unicode, attempt a conversion to ASCII + if isinstance(class_name, unicode): + try: + class_name = class_name.encode('ascii') + except UnicodeEncodeError: + raise TypeError('%r is not representable in ASCII' % class_name) + metacls = cls.__class__ + if type is None: + bases = (cls, ) + else: + bases = (type, cls) + classdict = metacls.__prepare__(class_name, bases) + __order__ = [] + + # special processing needed for names? + if isinstance(names, basestring): + names = names.replace(',', ' ').split() + if isinstance(names, (tuple, list)) and isinstance(names[0], basestring): + names = [(e, i+1) for (i, e) in enumerate(names)] + + # Here, names is either an iterable of (name, value) or a mapping. + for item in names: + if isinstance(item, basestring): + member_name, member_value = item, names[item] + else: + member_name, member_value = item + classdict[member_name] = member_value + __order__.append(member_name) + # only set __order__ in classdict if name/value was not from a mapping + if not isinstance(item, basestring): + classdict['__order__'] = ' '.join(__order__) + enum_class = metacls.__new__(metacls, class_name, bases, classdict) + + # TODO: replace the frame hack if a blessed way to know the calling + # module is ever developed + if module is None: + try: + module = _sys._getframe(2).f_globals['__name__'] + except (AttributeError, ValueError): + pass + if module is None: + _make_class_unpicklable(enum_class) + else: + enum_class.__module__ = module + + return enum_class + + @staticmethod + def _get_mixins_(bases): + """Returns the type for creating enum members, and the first inherited + enum class. + + bases: the tuple of bases that was given to __new__ + + """ + if not bases or Enum is None: + return object, Enum + + + # double check that we are not subclassing a class with existing + # enumeration members; while we're at it, see if any other data + # type has been mixed in so we can use the correct __new__ + member_type = first_enum = None + for base in bases: + if (base is not Enum and + issubclass(base, Enum) and + base._member_names_): + raise TypeError("Cannot extend enumerations") + # base is now the last base in bases + if not issubclass(base, Enum): + raise TypeError("new enumerations must be created as " + "`ClassName([mixin_type,] enum_type)`") + + # get correct mix-in type (either mix-in type of Enum subclass, or + # first base if last base is Enum) + if not issubclass(bases[0], Enum): + member_type = bases[0] # first data type + first_enum = bases[-1] # enum type + else: + for base in bases[0].__mro__: + # most common: (IntEnum, int, Enum, object) + # possible: (, , + # , , + # ) + if issubclass(base, Enum): + if first_enum is None: + first_enum = base + else: + if member_type is None: + member_type = base + + return member_type, first_enum + + if pyver < 3.0: + @staticmethod + def _find_new_(classdict, member_type, first_enum): + """Returns the __new__ to be used for creating the enum members. + + classdict: the class dictionary given to __new__ + member_type: the data type whose __new__ will be used by default + first_enum: enumeration to check for an overriding __new__ + + """ + # now find the correct __new__, checking to see of one was defined + # by the user; also check earlier enum classes in case a __new__ was + # saved as __member_new__ + __new__ = classdict.get('__new__', None) + if __new__: + return None, True, True # __new__, save_new, use_args + + N__new__ = getattr(None, '__new__') + O__new__ = getattr(object, '__new__') + if Enum is None: + E__new__ = N__new__ + else: + E__new__ = Enum.__dict__['__new__'] + # check all possibles for __member_new__ before falling back to + # __new__ + for method in ('__member_new__', '__new__'): + for possible in (member_type, first_enum): + try: + target = possible.__dict__[method] + except (AttributeError, KeyError): + target = getattr(possible, method, None) + if target not in [ + None, + N__new__, + O__new__, + E__new__, + ]: + if method == '__member_new__': + classdict['__new__'] = target + return None, False, True + if isinstance(target, staticmethod): + target = target.__get__(member_type) + __new__ = target + break + if __new__ is not None: + break + else: + __new__ = object.__new__ + + # if a non-object.__new__ is used then whatever value/tuple was + # assigned to the enum member name will be passed to __new__ and to the + # new enum member's __init__ + if __new__ is object.__new__: + use_args = False + else: + use_args = True + + return __new__, False, use_args + else: + @staticmethod + def _find_new_(classdict, member_type, first_enum): + """Returns the __new__ to be used for creating the enum members. + + classdict: the class dictionary given to __new__ + member_type: the data type whose __new__ will be used by default + first_enum: enumeration to check for an overriding __new__ + + """ + # now find the correct __new__, checking to see of one was defined + # by the user; also check earlier enum classes in case a __new__ was + # saved as __member_new__ + __new__ = classdict.get('__new__', None) + + # should __new__ be saved as __member_new__ later? + save_new = __new__ is not None + + if __new__ is None: + # check all possibles for __member_new__ before falling back to + # __new__ + for method in ('__member_new__', '__new__'): + for possible in (member_type, first_enum): + target = getattr(possible, method, None) + if target not in ( + None, + None.__new__, + object.__new__, + Enum.__new__, + ): + __new__ = target + break + if __new__ is not None: + break + else: + __new__ = object.__new__ + + # if a non-object.__new__ is used then whatever value/tuple was + # assigned to the enum member name will be passed to __new__ and to the + # new enum member's __init__ + if __new__ is object.__new__: + use_args = False + else: + use_args = True + + return __new__, save_new, use_args + + +######################################################## +# In order to support Python 2 and 3 with a single +# codebase we have to create the Enum methods separately +# and then use the `type(name, bases, dict)` method to +# create the class. +######################################################## +temp_enum_dict = {} +temp_enum_dict['__doc__'] = "Generic enumeration.\n\n Derive from this class to define new enumerations.\n\n" + +def __new__(cls, value): + # all enum instances are actually created during class construction + # without calling this method; this method is called by the metaclass' + # __call__ (i.e. Color(3) ), and by pickle + if type(value) is cls: + # For lookups like Color(Color.red) + value = value.value + #return value + # by-value search for a matching enum member + # see if it's in the reverse mapping (for hashable values) + try: + if value in cls._value2member_map_: + return cls._value2member_map_[value] + except TypeError: + # not there, now do long search -- O(n) behavior + for member in cls._member_map_.values(): + if member.value == value: + return member + raise ValueError("%s is not a valid %s" % (value, cls.__name__)) +temp_enum_dict['__new__'] = __new__ +del __new__ + +def __repr__(self): + return "<%s.%s: %r>" % ( + self.__class__.__name__, self._name_, self._value_) +temp_enum_dict['__repr__'] = __repr__ +del __repr__ + +def __str__(self): + return "%s.%s" % (self.__class__.__name__, self._name_) +temp_enum_dict['__str__'] = __str__ +del __str__ + +def __dir__(self): + added_behavior = [ + m + for cls in self.__class__.mro() + for m in cls.__dict__ + if m[0] != '_' + ] + return (['__class__', '__doc__', '__module__', ] + added_behavior) +temp_enum_dict['__dir__'] = __dir__ +del __dir__ + +def __format__(self, format_spec): + # mixed-in Enums should use the mixed-in type's __format__, otherwise + # we can get strange results with the Enum name showing up instead of + # the value + + # pure Enum branch + if self._member_type_ is object: + cls = str + val = str(self) + # mix-in branch + else: + cls = self._member_type_ + val = self.value + return cls.__format__(val, format_spec) +temp_enum_dict['__format__'] = __format__ +del __format__ + + +#################################### +# Python's less than 2.6 use __cmp__ + +if pyver < 2.6: + + def __cmp__(self, other): + if type(other) is self.__class__: + if self is other: + return 0 + return -1 + return NotImplemented + raise TypeError("unorderable types: %s() and %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__cmp__'] = __cmp__ + del __cmp__ + +else: + + def __le__(self, other): + raise TypeError("unorderable types: %s() <= %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__le__'] = __le__ + del __le__ + + def __lt__(self, other): + raise TypeError("unorderable types: %s() < %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__lt__'] = __lt__ + del __lt__ + + def __ge__(self, other): + raise TypeError("unorderable types: %s() >= %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__ge__'] = __ge__ + del __ge__ + + def __gt__(self, other): + raise TypeError("unorderable types: %s() > %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__gt__'] = __gt__ + del __gt__ + + +def __eq__(self, other): + if type(other) is self.__class__: + return self is other + return NotImplemented +temp_enum_dict['__eq__'] = __eq__ +del __eq__ + +def __ne__(self, other): + if type(other) is self.__class__: + return self is not other + return NotImplemented +temp_enum_dict['__ne__'] = __ne__ +del __ne__ + +def __hash__(self): + return hash(self._name_) +temp_enum_dict['__hash__'] = __hash__ +del __hash__ + +def __reduce_ex__(self, proto): + return self.__class__, (self._value_, ) +temp_enum_dict['__reduce_ex__'] = __reduce_ex__ +del __reduce_ex__ + +# _RouteClassAttributeToGetattr is used to provide access to the `name` +# and `value` properties of enum members while keeping some measure of +# protection from modification, while still allowing for an enumeration +# to have members named `name` and `value`. This works because enumeration +# members are not set directly on the enum class -- __getattr__ is +# used to look them up. + +@_RouteClassAttributeToGetattr +def name(self): + return self._name_ +temp_enum_dict['name'] = name +del name + +@_RouteClassAttributeToGetattr +def value(self): + return self._value_ +temp_enum_dict['value'] = value +del value + +Enum = EnumMeta('Enum', (object, ), temp_enum_dict) +del temp_enum_dict + +# Enum has now been created +########################### + +class IntEnum(int, Enum): + """Enum where members are also (and must be) ints""" + + +def unique(enumeration): + """Class decorator that ensures only unique members exist in an enumeration.""" + duplicates = [] + for name, member in enumeration.__members__.items(): + if name != member.name: + duplicates.append((name, member.name)) + if duplicates: + duplicate_names = ', '.join( + ["%s -> %s" % (alias, name) for (alias, name) in duplicates] + ) + raise ValueError('duplicate names found in %r: %s' % + (enumeration, duplicate_names) + ) + return enumeration diff --git a/lib/enum/doc/enum.rst b/lib/enum/doc/enum.rst new file mode 100644 index 00000000..0d429bfc --- /dev/null +++ b/lib/enum/doc/enum.rst @@ -0,0 +1,725 @@ +``enum`` --- support for enumerations +======================================== + +.. :synopsis: enumerations are sets of symbolic names bound to unique, constant + values. +.. :moduleauthor:: Ethan Furman +.. :sectionauthor:: Barry Warsaw , +.. :sectionauthor:: Eli Bendersky , +.. :sectionauthor:: Ethan Furman + +---------------- + +An enumeration is a set of symbolic names (members) bound to unique, constant +values. Within an enumeration, the members can be compared by identity, and +the enumeration itself can be iterated over. + + +Module Contents +--------------- + +This module defines two enumeration classes that can be used to define unique +sets of names and values: ``Enum`` and ``IntEnum``. It also defines +one decorator, ``unique``. + +``Enum`` + +Base class for creating enumerated constants. See section `Functional API`_ +for an alternate construction syntax. + +``IntEnum`` + +Base class for creating enumerated constants that are also subclasses of ``int``. + +``unique`` + +Enum class decorator that ensures only one name is bound to any one value. + + +Creating an Enum +---------------- + +Enumerations are created using the ``class`` syntax, which makes them +easy to read and write. An alternative creation method is described in +`Functional API`_. To define an enumeration, subclass ``Enum`` as +follows:: + + >>> from enum import Enum + >>> class Color(Enum): + ... red = 1 + ... green = 2 + ... blue = 3 + +Note: Nomenclature + + - The class ``Color`` is an *enumeration* (or *enum*) + - The attributes ``Color.red``, ``Color.green``, etc., are + *enumeration members* (or *enum members*). + - The enum members have *names* and *values* (the name of + ``Color.red`` is ``red``, the value of ``Color.blue`` is + ``3``, etc.) + +Note: + + Even though we use the ``class`` syntax to create Enums, Enums + are not normal Python classes. See `How are Enums different?`_ for + more details. + +Enumeration members have human readable string representations:: + + >>> print(Color.red) + Color.red + +...while their ``repr`` has more information:: + + >>> print(repr(Color.red)) + + +The *type* of an enumeration member is the enumeration it belongs to:: + + >>> type(Color.red) + + >>> isinstance(Color.green, Color) + True + >>> + +Enum members also have a property that contains just their item name:: + + >>> print(Color.red.name) + red + +Enumerations support iteration. In Python 3.x definition order is used; in +Python 2.x the definition order is not available, but class attribute +``__order__`` is supported; otherwise, value order is used:: + + >>> class Shake(Enum): + ... __order__ = 'vanilla chocolate cookies mint' # only needed in 2.x + ... vanilla = 7 + ... chocolate = 4 + ... cookies = 9 + ... mint = 3 + ... + >>> for shake in Shake: + ... print(shake) + ... + Shake.vanilla + Shake.chocolate + Shake.cookies + Shake.mint + +The ``__order__`` attribute is always removed, and in 3.x it is also ignored +(order is definition order); however, in the stdlib version it will be ignored +but not removed. + +Enumeration members are hashable, so they can be used in dictionaries and sets:: + + >>> apples = {} + >>> apples[Color.red] = 'red delicious' + >>> apples[Color.green] = 'granny smith' + >>> apples == {Color.red: 'red delicious', Color.green: 'granny smith'} + True + + +Programmatic access to enumeration members and their attributes +--------------------------------------------------------------- + +Sometimes it's useful to access members in enumerations programmatically (i.e. +situations where ``Color.red`` won't do because the exact color is not known +at program-writing time). ``Enum`` allows such access:: + + >>> Color(1) + + >>> Color(3) + + +If you want to access enum members by *name*, use item access:: + + >>> Color['red'] + + >>> Color['green'] + + +If have an enum member and need its ``name`` or ``value``:: + + >>> member = Color.red + >>> member.name + 'red' + >>> member.value + 1 + + +Duplicating enum members and values +----------------------------------- + +Having two enum members (or any other attribute) with the same name is invalid; +in Python 3.x this would raise an error, but in Python 2.x the second member +simply overwrites the first:: + + >>> # python 2.x + >>> class Shape(Enum): + ... square = 2 + ... square = 3 + ... + >>> Shape.square + + + >>> # python 3.x + >>> class Shape(Enum): + ... square = 2 + ... square = 3 + Traceback (most recent call last): + ... + TypeError: Attempted to reuse key: 'square' + +However, two enum members are allowed to have the same value. Given two members +A and B with the same value (and A defined first), B is an alias to A. By-value +lookup of the value of A and B will return A. By-name lookup of B will also +return A:: + + >>> class Shape(Enum): + ... __order__ = 'square diamond circle alias_for_square' # only needed in 2.x + ... square = 2 + ... diamond = 1 + ... circle = 3 + ... alias_for_square = 2 + ... + >>> Shape.square + + >>> Shape.alias_for_square + + >>> Shape(2) + + + +Allowing aliases is not always desirable. ``unique`` can be used to ensure +that none exist in a particular enumeration:: + + >>> from enum import unique + >>> @unique + ... class Mistake(Enum): + ... __order__ = 'one two three four' # only needed in 2.x + ... one = 1 + ... two = 2 + ... three = 3 + ... four = 3 + Traceback (most recent call last): + ... + ValueError: duplicate names found in : four -> three + +Iterating over the members of an enum does not provide the aliases:: + + >>> list(Shape) + [, , ] + +The special attribute ``__members__`` is a dictionary mapping names to members. +It includes all names defined in the enumeration, including the aliases:: + + >>> for name, member in sorted(Shape.__members__.items()): + ... name, member + ... + ('alias_for_square', ) + ('circle', ) + ('diamond', ) + ('square', ) + +The ``__members__`` attribute can be used for detailed programmatic access to +the enumeration members. For example, finding all the aliases:: + + >>> [name for name, member in Shape.__members__.items() if member.name != name] + ['alias_for_square'] + +Comparisons +----------- + +Enumeration members are compared by identity:: + + >>> Color.red is Color.red + True + >>> Color.red is Color.blue + False + >>> Color.red is not Color.blue + True + +Ordered comparisons between enumeration values are *not* supported. Enum +members are not integers (but see `IntEnum`_ below):: + + >>> Color.red < Color.blue + Traceback (most recent call last): + File "", line 1, in + TypeError: unorderable types: Color() < Color() + +.. warning:: + + In Python 2 *everything* is ordered, even though the ordering may not + make sense. If you want your enumerations to have a sensible ordering + check out the `OrderedEnum`_ recipe below. + + +Equality comparisons are defined though:: + + >>> Color.blue == Color.red + False + >>> Color.blue != Color.red + True + >>> Color.blue == Color.blue + True + +Comparisons against non-enumeration values will always compare not equal +(again, ``IntEnum`` was explicitly designed to behave differently, see +below):: + + >>> Color.blue == 2 + False + + +Allowed members and attributes of enumerations +---------------------------------------------- + +The examples above use integers for enumeration values. Using integers is +short and handy (and provided by default by the `Functional API`_), but not +strictly enforced. In the vast majority of use-cases, one doesn't care what +the actual value of an enumeration is. But if the value *is* important, +enumerations can have arbitrary values. + +Enumerations are Python classes, and can have methods and special methods as +usual. If we have this enumeration:: + + >>> class Mood(Enum): + ... funky = 1 + ... happy = 3 + ... + ... def describe(self): + ... # self is the member here + ... return self.name, self.value + ... + ... def __str__(self): + ... return 'my custom str! {0}'.format(self.value) + ... + ... @classmethod + ... def favorite_mood(cls): + ... # cls here is the enumeration + ... return cls.happy + +Then:: + + >>> Mood.favorite_mood() + + >>> Mood.happy.describe() + ('happy', 3) + >>> str(Mood.funky) + 'my custom str! 1' + +The rules for what is allowed are as follows: _sunder_ names (starting and +ending with a single underscore) are reserved by enum and cannot be used; +all other attributes defined within an enumeration will become members of this +enumeration, with the exception of *__dunder__* names and descriptors (methods +are also descriptors). + +Note: + + If your enumeration defines ``__new__`` and/or ``__init__`` then + whatever value(s) were given to the enum member will be passed into + those methods. See `Planet`_ for an example. + + +Restricted subclassing of enumerations +-------------------------------------- + +Subclassing an enumeration is allowed only if the enumeration does not define +any members. So this is forbidden:: + + >>> class MoreColor(Color): + ... pink = 17 + Traceback (most recent call last): + ... + TypeError: Cannot extend enumerations + +But this is allowed:: + + >>> class Foo(Enum): + ... def some_behavior(self): + ... pass + ... + >>> class Bar(Foo): + ... happy = 1 + ... sad = 2 + ... + +Allowing subclassing of enums that define members would lead to a violation of +some important invariants of types and instances. On the other hand, it makes +sense to allow sharing some common behavior between a group of enumerations. +(See `OrderedEnum`_ for an example.) + + +Pickling +-------- + +Enumerations can be pickled and unpickled:: + + >>> from enum.test_enum import Fruit + >>> from pickle import dumps, loads + >>> Fruit.tomato is loads(dumps(Fruit.tomato, 2)) + True + +The usual restrictions for pickling apply: picklable enums must be defined in +the top level of a module, since unpickling requires them to be importable +from that module. + +Note: + + With pickle protocol version 4 (introduced in Python 3.4) it is possible + to easily pickle enums nested in other classes. + + + +Functional API +-------------- + +The ``Enum`` class is callable, providing the following functional API:: + + >>> Animal = Enum('Animal', 'ant bee cat dog') + >>> Animal + + >>> Animal.ant + + >>> Animal.ant.value + 1 + >>> list(Animal) + [, , , ] + +The semantics of this API resemble ``namedtuple``. The first argument +of the call to ``Enum`` is the name of the enumeration. + +The second argument is the *source* of enumeration member names. It can be a +whitespace-separated string of names, a sequence of names, a sequence of +2-tuples with key/value pairs, or a mapping (e.g. dictionary) of names to +values. The last two options enable assigning arbitrary values to +enumerations; the others auto-assign increasing integers starting with 1. A +new class derived from ``Enum`` is returned. In other words, the above +assignment to ``Animal`` is equivalent to:: + + >>> class Animals(Enum): + ... ant = 1 + ... bee = 2 + ... cat = 3 + ... dog = 4 + +Pickling enums created with the functional API can be tricky as frame stack +implementation details are used to try and figure out which module the +enumeration is being created in (e.g. it will fail if you use a utility +function in separate module, and also may not work on IronPython or Jython). +The solution is to specify the module name explicitly as follows:: + + >>> Animals = Enum('Animals', 'ant bee cat dog', module=__name__) + +Derived Enumerations +-------------------- + +IntEnum +^^^^^^^ + +A variation of ``Enum`` is provided which is also a subclass of +``int``. Members of an ``IntEnum`` can be compared to integers; +by extension, integer enumerations of different types can also be compared +to each other:: + + >>> from enum import IntEnum + >>> class Shape(IntEnum): + ... circle = 1 + ... square = 2 + ... + >>> class Request(IntEnum): + ... post = 1 + ... get = 2 + ... + >>> Shape == 1 + False + >>> Shape.circle == 1 + True + >>> Shape.circle == Request.post + True + +However, they still can't be compared to standard ``Enum`` enumerations:: + + >>> class Shape(IntEnum): + ... circle = 1 + ... square = 2 + ... + >>> class Color(Enum): + ... red = 1 + ... green = 2 + ... + >>> Shape.circle == Color.red + False + +``IntEnum`` values behave like integers in other ways you'd expect:: + + >>> int(Shape.circle) + 1 + >>> ['a', 'b', 'c'][Shape.circle] + 'b' + >>> [i for i in range(Shape.square)] + [0, 1] + +For the vast majority of code, ``Enum`` is strongly recommended, +since ``IntEnum`` breaks some semantic promises of an enumeration (by +being comparable to integers, and thus by transitivity to other +unrelated enumerations). It should be used only in special cases where +there's no other choice; for example, when integer constants are +replaced with enumerations and backwards compatibility is required with code +that still expects integers. + + +Others +^^^^^^ + +While ``IntEnum`` is part of the ``enum`` module, it would be very +simple to implement independently:: + + class IntEnum(int, Enum): + pass + +This demonstrates how similar derived enumerations can be defined; for example +a ``StrEnum`` that mixes in ``str`` instead of ``int``. + +Some rules: + +1. When subclassing ``Enum``, mix-in types must appear before + ``Enum`` itself in the sequence of bases, as in the ``IntEnum`` + example above. +2. While ``Enum`` can have members of any type, once you mix in an + additional type, all the members must have values of that type, e.g. + ``int`` above. This restriction does not apply to mix-ins which only + add methods and don't specify another data type such as ``int`` or + ``str``. +3. When another data type is mixed in, the ``value`` attribute is *not the + same* as the enum member itself, although it is equivalant and will compare + equal. +4. %-style formatting: ``%s`` and ``%r`` call ``Enum``'s ``__str__`` and + ``__repr__`` respectively; other codes (such as ``%i`` or ``%h`` for + IntEnum) treat the enum member as its mixed-in type. + + Note: Prior to Python 3.4 there is a bug in ``str``'s %-formatting: ``int`` + subclasses are printed as strings and not numbers when the ``%d``, ``%i``, + or ``%u`` codes are used. +5. ``str.__format__`` (or ``format``) will use the mixed-in + type's ``__format__``. If the ``Enum``'s ``str`` or + ``repr`` is desired use the ``!s`` or ``!r`` ``str`` format codes. + + +Decorators +---------- + +unique +^^^^^^ + +A ``class`` decorator specifically for enumerations. It searches an +enumeration's ``__members__`` gathering any aliases it finds; if any are +found ``ValueError`` is raised with the details:: + + >>> @unique + ... class NoDupes(Enum): + ... first = 'one' + ... second = 'two' + ... third = 'two' + Traceback (most recent call last): + ... + ValueError: duplicate names found in : third -> second + + +Interesting examples +-------------------- + +While ``Enum`` and ``IntEnum`` are expected to cover the majority of +use-cases, they cannot cover them all. Here are recipes for some different +types of enumerations that can be used directly, or as examples for creating +one's own. + + +AutoNumber +^^^^^^^^^^ + +Avoids having to specify the value for each enumeration member:: + + >>> class AutoNumber(Enum): + ... def __new__(cls): + ... value = len(cls.__members__) + 1 + ... obj = object.__new__(cls) + ... obj._value_ = value + ... return obj + ... + >>> class Color(AutoNumber): + ... __order__ = "red green blue" # only needed in 2.x + ... red = () + ... green = () + ... blue = () + ... + >>> Color.green.value == 2 + True + +Note: + + The `__new__` method, if defined, is used during creation of the Enum + members; it is then replaced by Enum's `__new__` which is used after + class creation for lookup of existing members. Due to the way Enums are + supposed to behave, there is no way to customize Enum's `__new__`. + + +UniqueEnum +^^^^^^^^^^ + +Raises an error if a duplicate member name is found instead of creating an +alias:: + + >>> class UniqueEnum(Enum): + ... def __init__(self, *args): + ... cls = self.__class__ + ... if any(self.value == e.value for e in cls): + ... a = self.name + ... e = cls(self.value).name + ... raise ValueError( + ... "aliases not allowed in UniqueEnum: %r --> %r" + ... % (a, e)) + ... + >>> class Color(UniqueEnum): + ... red = 1 + ... green = 2 + ... blue = 3 + ... grene = 2 + Traceback (most recent call last): + ... + ValueError: aliases not allowed in UniqueEnum: 'grene' --> 'green' + + +OrderedEnum +^^^^^^^^^^^ + +An ordered enumeration that is not based on ``IntEnum`` and so maintains +the normal ``Enum`` invariants (such as not being comparable to other +enumerations):: + + >>> class OrderedEnum(Enum): + ... def __ge__(self, other): + ... if self.__class__ is other.__class__: + ... return self._value_ >= other._value_ + ... return NotImplemented + ... def __gt__(self, other): + ... if self.__class__ is other.__class__: + ... return self._value_ > other._value_ + ... return NotImplemented + ... def __le__(self, other): + ... if self.__class__ is other.__class__: + ... return self._value_ <= other._value_ + ... return NotImplemented + ... def __lt__(self, other): + ... if self.__class__ is other.__class__: + ... return self._value_ < other._value_ + ... return NotImplemented + ... + >>> class Grade(OrderedEnum): + ... __ordered__ = 'A B C D F' + ... A = 5 + ... B = 4 + ... C = 3 + ... D = 2 + ... F = 1 + ... + >>> Grade.C < Grade.A + True + + +Planet +^^^^^^ + +If ``__new__`` or ``__init__`` is defined the value of the enum member +will be passed to those methods:: + + >>> class Planet(Enum): + ... MERCURY = (3.303e+23, 2.4397e6) + ... VENUS = (4.869e+24, 6.0518e6) + ... EARTH = (5.976e+24, 6.37814e6) + ... MARS = (6.421e+23, 3.3972e6) + ... JUPITER = (1.9e+27, 7.1492e7) + ... SATURN = (5.688e+26, 6.0268e7) + ... URANUS = (8.686e+25, 2.5559e7) + ... NEPTUNE = (1.024e+26, 2.4746e7) + ... def __init__(self, mass, radius): + ... self.mass = mass # in kilograms + ... self.radius = radius # in meters + ... @property + ... def surface_gravity(self): + ... # universal gravitational constant (m3 kg-1 s-2) + ... G = 6.67300E-11 + ... return G * self.mass / (self.radius * self.radius) + ... + >>> Planet.EARTH.value + (5.976e+24, 6378140.0) + >>> Planet.EARTH.surface_gravity + 9.802652743337129 + + +How are Enums different? +------------------------ + +Enums have a custom metaclass that affects many aspects of both derived Enum +classes and their instances (members). + + +Enum Classes +^^^^^^^^^^^^ + +The ``EnumMeta`` metaclass is responsible for providing the +``__contains__``, ``__dir__``, ``__iter__`` and other methods that +allow one to do things with an ``Enum`` class that fail on a typical +class, such as ``list(Color)`` or ``some_var in Color``. ``EnumMeta`` is +responsible for ensuring that various other methods on the final ``Enum`` +class are correct (such as ``__new__``, ``__getnewargs__``, +``__str__`` and ``__repr__``) + + +Enum Members (aka instances) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The most interesting thing about Enum members is that they are singletons. +``EnumMeta`` creates them all while it is creating the ``Enum`` +class itself, and then puts a custom ``__new__`` in place to ensure +that no new ones are ever instantiated by returning only the existing +member instances. + + +Finer Points +^^^^^^^^^^^^ + +Enum members are instances of an Enum class, and even though they are +accessible as ``EnumClass.member``, they are not accessible directly from +the member:: + + >>> Color.red + + >>> Color.red.blue + Traceback (most recent call last): + ... + AttributeError: 'Color' object has no attribute 'blue' + +Likewise, ``__members__`` is only available on the class. + +In Python 3.x ``__members__`` is always an ``OrderedDict``, with the order being +the definition order. In Python 2.7 ``__members__`` is an ``OrderedDict`` if +``__order__`` was specified, and a plain ``dict`` otherwise. In all other Python +2.x versions ``__members__`` is a plain ``dict`` even if ``__order__`` was specified +as the ``OrderedDict`` type didn't exist yet. + +If you give your ``Enum`` subclass extra methods, like the `Planet`_ +class above, those methods will show up in a `dir` of the member, +but not of the class:: + + >>> dir(Planet) + ['EARTH', 'JUPITER', 'MARS', 'MERCURY', 'NEPTUNE', 'SATURN', 'URANUS', + 'VENUS', '__class__', '__doc__', '__members__', '__module__'] + >>> dir(Planet.EARTH) + ['__class__', '__doc__', '__module__', 'name', 'surface_gravity', 'value'] + +A ``__new__`` method will only be used for the creation of the +``Enum`` members -- after that it is replaced. This means if you wish to +change how ``Enum`` members are looked up you either have to write a +helper function or a ``classmethod``. diff --git a/lib/enum/enum.py b/lib/enum/enum.py new file mode 100644 index 00000000..6a327a8a --- /dev/null +++ b/lib/enum/enum.py @@ -0,0 +1,790 @@ +"""Python Enumerations""" + +import sys as _sys + +__all__ = ['Enum', 'IntEnum', 'unique'] + +version = 1, 0, 4 + +pyver = float('%s.%s' % _sys.version_info[:2]) + +try: + any +except NameError: + def any(iterable): + for element in iterable: + if element: + return True + return False + +try: + from collections import OrderedDict +except ImportError: + OrderedDict = None + +try: + basestring +except NameError: + # In Python 2 basestring is the ancestor of both str and unicode + # in Python 3 it's just str, but was missing in 3.1 + basestring = str + +try: + unicode +except NameError: + # In Python 3 unicode no longer exists (it's just str) + unicode = str + +class _RouteClassAttributeToGetattr(object): + """Route attribute access on a class to __getattr__. + + This is a descriptor, used to define attributes that act differently when + accessed through an instance and through a class. Instance access remains + normal, but access to an attribute through a class will be routed to the + class's __getattr__ method; this is done by raising AttributeError. + + """ + def __init__(self, fget=None): + self.fget = fget + + def __get__(self, instance, ownerclass=None): + if instance is None: + raise AttributeError() + return self.fget(instance) + + def __set__(self, instance, value): + raise AttributeError("can't set attribute") + + def __delete__(self, instance): + raise AttributeError("can't delete attribute") + + +def _is_descriptor(obj): + """Returns True if obj is a descriptor, False otherwise.""" + return ( + hasattr(obj, '__get__') or + hasattr(obj, '__set__') or + hasattr(obj, '__delete__')) + + +def _is_dunder(name): + """Returns True if a __dunder__ name, False otherwise.""" + return (name[:2] == name[-2:] == '__' and + name[2:3] != '_' and + name[-3:-2] != '_' and + len(name) > 4) + + +def _is_sunder(name): + """Returns True if a _sunder_ name, False otherwise.""" + return (name[0] == name[-1] == '_' and + name[1:2] != '_' and + name[-2:-1] != '_' and + len(name) > 2) + + +def _make_class_unpicklable(cls): + """Make the given class un-picklable.""" + def _break_on_call_reduce(self, protocol=None): + raise TypeError('%r cannot be pickled' % self) + cls.__reduce_ex__ = _break_on_call_reduce + cls.__module__ = '' + + +class _EnumDict(dict): + """Track enum member order and ensure member names are not reused. + + EnumMeta will use the names found in self._member_names as the + enumeration member names. + + """ + def __init__(self): + super(_EnumDict, self).__init__() + self._member_names = [] + + def __setitem__(self, key, value): + """Changes anything not dundered or not a descriptor. + + If a descriptor is added with the same name as an enum member, the name + is removed from _member_names (this may leave a hole in the numerical + sequence of values). + + If an enum member name is used twice, an error is raised; duplicate + values are not checked for. + + Single underscore (sunder) names are reserved. + + Note: in 3.x __order__ is simply discarded as a not necessary piece + leftover from 2.x + + """ + if pyver >= 3.0 and key == '__order__': + return + if _is_sunder(key): + raise ValueError('_names_ are reserved for future Enum use') + elif _is_dunder(key): + pass + elif key in self._member_names: + # descriptor overwriting an enum? + raise TypeError('Attempted to reuse key: %r' % key) + elif not _is_descriptor(value): + if key in self: + # enum overwriting a descriptor? + raise TypeError('Key already defined as: %r' % self[key]) + self._member_names.append(key) + super(_EnumDict, self).__setitem__(key, value) + + +# Dummy value for Enum as EnumMeta explicity checks for it, but of course until +# EnumMeta finishes running the first time the Enum class doesn't exist. This +# is also why there are checks in EnumMeta like `if Enum is not None` +Enum = None + + +class EnumMeta(type): + """Metaclass for Enum""" + @classmethod + def __prepare__(metacls, cls, bases): + return _EnumDict() + + def __new__(metacls, cls, bases, classdict): + # an Enum class is final once enumeration items have been defined; it + # cannot be mixed with other types (int, float, etc.) if it has an + # inherited __new__ unless a new __new__ is defined (or the resulting + # class will fail). + if type(classdict) is dict: + original_dict = classdict + classdict = _EnumDict() + for k, v in original_dict.items(): + classdict[k] = v + + member_type, first_enum = metacls._get_mixins_(bases) + __new__, save_new, use_args = metacls._find_new_(classdict, member_type, + first_enum) + # save enum items into separate mapping so they don't get baked into + # the new class + members = dict((k, classdict[k]) for k in classdict._member_names) + for name in classdict._member_names: + del classdict[name] + + # py2 support for definition order + __order__ = classdict.get('__order__') + if __order__ is None: + if pyver < 3.0: + try: + __order__ = [name for (name, value) in sorted(members.items(), key=lambda item: item[1])] + except TypeError: + __order__ = [name for name in sorted(members.keys())] + else: + __order__ = classdict._member_names + else: + del classdict['__order__'] + if pyver < 3.0: + __order__ = __order__.replace(',', ' ').split() + aliases = [name for name in members if name not in __order__] + __order__ += aliases + + # check for illegal enum names (any others?) + invalid_names = set(members) & set(['mro']) + if invalid_names: + raise ValueError('Invalid enum member name(s): %s' % ( + ', '.join(invalid_names), )) + + # create our new Enum type + enum_class = super(EnumMeta, metacls).__new__(metacls, cls, bases, classdict) + enum_class._member_names_ = [] # names in random order + if OrderedDict is not None: + enum_class._member_map_ = OrderedDict() + else: + enum_class._member_map_ = {} # name->value map + enum_class._member_type_ = member_type + + # Reverse value->name map for hashable values. + enum_class._value2member_map_ = {} + + # instantiate them, checking for duplicates as we go + # we instantiate first instead of checking for duplicates first in case + # a custom __new__ is doing something funky with the values -- such as + # auto-numbering ;) + if __new__ is None: + __new__ = enum_class.__new__ + for member_name in __order__: + value = members[member_name] + if not isinstance(value, tuple): + args = (value, ) + else: + args = value + if member_type is tuple: # special case for tuple enums + args = (args, ) # wrap it one more time + if not use_args or not args: + enum_member = __new__(enum_class) + if not hasattr(enum_member, '_value_'): + enum_member._value_ = value + else: + enum_member = __new__(enum_class, *args) + if not hasattr(enum_member, '_value_'): + enum_member._value_ = member_type(*args) + value = enum_member._value_ + enum_member._name_ = member_name + enum_member.__objclass__ = enum_class + enum_member.__init__(*args) + # If another member with the same value was already defined, the + # new member becomes an alias to the existing one. + for name, canonical_member in enum_class._member_map_.items(): + if canonical_member.value == enum_member._value_: + enum_member = canonical_member + break + else: + # Aliases don't appear in member names (only in __members__). + enum_class._member_names_.append(member_name) + enum_class._member_map_[member_name] = enum_member + try: + # This may fail if value is not hashable. We can't add the value + # to the map, and by-value lookups for this value will be + # linear. + enum_class._value2member_map_[value] = enum_member + except TypeError: + pass + + + # If a custom type is mixed into the Enum, and it does not know how + # to pickle itself, pickle.dumps will succeed but pickle.loads will + # fail. Rather than have the error show up later and possibly far + # from the source, sabotage the pickle protocol for this class so + # that pickle.dumps also fails. + # + # However, if the new class implements its own __reduce_ex__, do not + # sabotage -- it's on them to make sure it works correctly. We use + # __reduce_ex__ instead of any of the others as it is preferred by + # pickle over __reduce__, and it handles all pickle protocols. + unpicklable = False + if '__reduce_ex__' not in classdict: + if member_type is not object: + methods = ('__getnewargs_ex__', '__getnewargs__', + '__reduce_ex__', '__reduce__') + if not any(m in member_type.__dict__ for m in methods): + _make_class_unpicklable(enum_class) + unpicklable = True + + + # double check that repr and friends are not the mixin's or various + # things break (such as pickle) + for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): + class_method = getattr(enum_class, name) + obj_method = getattr(member_type, name, None) + enum_method = getattr(first_enum, name, None) + if name not in classdict and class_method is not enum_method: + if name == '__reduce_ex__' and unpicklable: + continue + setattr(enum_class, name, enum_method) + + # method resolution and int's are not playing nice + # Python's less than 2.6 use __cmp__ + + if pyver < 2.6: + + if issubclass(enum_class, int): + setattr(enum_class, '__cmp__', getattr(int, '__cmp__')) + + elif pyver < 3.0: + + if issubclass(enum_class, int): + for method in ( + '__le__', + '__lt__', + '__gt__', + '__ge__', + '__eq__', + '__ne__', + '__hash__', + ): + setattr(enum_class, method, getattr(int, method)) + + # replace any other __new__ with our own (as long as Enum is not None, + # anyway) -- again, this is to support pickle + if Enum is not None: + # if the user defined their own __new__, save it before it gets + # clobbered in case they subclass later + if save_new: + setattr(enum_class, '__member_new__', enum_class.__dict__['__new__']) + setattr(enum_class, '__new__', Enum.__dict__['__new__']) + return enum_class + + def __call__(cls, value, names=None, module=None, type=None): + """Either returns an existing member, or creates a new enum class. + + This method is used both when an enum class is given a value to match + to an enumeration member (i.e. Color(3)) and for the functional API + (i.e. Color = Enum('Color', names='red green blue')). + + When used for the functional API: `module`, if set, will be stored in + the new class' __module__ attribute; `type`, if set, will be mixed in + as the first base class. + + Note: if `module` is not set this routine will attempt to discover the + calling module by walking the frame stack; if this is unsuccessful + the resulting class will not be pickleable. + + """ + if names is None: # simple value lookup + return cls.__new__(cls, value) + # otherwise, functional API: we're creating a new Enum type + return cls._create_(value, names, module=module, type=type) + + def __contains__(cls, member): + return isinstance(member, cls) and member.name in cls._member_map_ + + def __delattr__(cls, attr): + # nicer error message when someone tries to delete an attribute + # (see issue19025). + if attr in cls._member_map_: + raise AttributeError( + "%s: cannot delete Enum member." % cls.__name__) + super(EnumMeta, cls).__delattr__(attr) + + def __dir__(self): + return (['__class__', '__doc__', '__members__', '__module__'] + + self._member_names_) + + @property + def __members__(cls): + """Returns a mapping of member name->value. + + This mapping lists all enum members, including aliases. Note that this + is a copy of the internal mapping. + + """ + return cls._member_map_.copy() + + def __getattr__(cls, name): + """Return the enum member matching `name` + + We use __getattr__ instead of descriptors or inserting into the enum + class' __dict__ in order to support `name` and `value` being both + properties for enum members (which live in the class' __dict__) and + enum members themselves. + + """ + if _is_dunder(name): + raise AttributeError(name) + try: + return cls._member_map_[name] + except KeyError: + raise AttributeError(name) + + def __getitem__(cls, name): + return cls._member_map_[name] + + def __iter__(cls): + return (cls._member_map_[name] for name in cls._member_names_) + + def __reversed__(cls): + return (cls._member_map_[name] for name in reversed(cls._member_names_)) + + def __len__(cls): + return len(cls._member_names_) + + def __repr__(cls): + return "" % cls.__name__ + + def __setattr__(cls, name, value): + """Block attempts to reassign Enum members. + + A simple assignment to the class namespace only changes one of the + several possible ways to get an Enum member from the Enum class, + resulting in an inconsistent Enumeration. + + """ + member_map = cls.__dict__.get('_member_map_', {}) + if name in member_map: + raise AttributeError('Cannot reassign members.') + super(EnumMeta, cls).__setattr__(name, value) + + def _create_(cls, class_name, names=None, module=None, type=None): + """Convenience method to create a new Enum class. + + `names` can be: + + * A string containing member names, separated either with spaces or + commas. Values are auto-numbered from 1. + * An iterable of member names. Values are auto-numbered from 1. + * An iterable of (member name, value) pairs. + * A mapping of member name -> value. + + """ + if pyver < 3.0: + # if class_name is unicode, attempt a conversion to ASCII + if isinstance(class_name, unicode): + try: + class_name = class_name.encode('ascii') + except UnicodeEncodeError: + raise TypeError('%r is not representable in ASCII' % class_name) + metacls = cls.__class__ + if type is None: + bases = (cls, ) + else: + bases = (type, cls) + classdict = metacls.__prepare__(class_name, bases) + __order__ = [] + + # special processing needed for names? + if isinstance(names, basestring): + names = names.replace(',', ' ').split() + if isinstance(names, (tuple, list)) and isinstance(names[0], basestring): + names = [(e, i+1) for (i, e) in enumerate(names)] + + # Here, names is either an iterable of (name, value) or a mapping. + for item in names: + if isinstance(item, basestring): + member_name, member_value = item, names[item] + else: + member_name, member_value = item + classdict[member_name] = member_value + __order__.append(member_name) + # only set __order__ in classdict if name/value was not from a mapping + if not isinstance(item, basestring): + classdict['__order__'] = ' '.join(__order__) + enum_class = metacls.__new__(metacls, class_name, bases, classdict) + + # TODO: replace the frame hack if a blessed way to know the calling + # module is ever developed + if module is None: + try: + module = _sys._getframe(2).f_globals['__name__'] + except (AttributeError, ValueError): + pass + if module is None: + _make_class_unpicklable(enum_class) + else: + enum_class.__module__ = module + + return enum_class + + @staticmethod + def _get_mixins_(bases): + """Returns the type for creating enum members, and the first inherited + enum class. + + bases: the tuple of bases that was given to __new__ + + """ + if not bases or Enum is None: + return object, Enum + + + # double check that we are not subclassing a class with existing + # enumeration members; while we're at it, see if any other data + # type has been mixed in so we can use the correct __new__ + member_type = first_enum = None + for base in bases: + if (base is not Enum and + issubclass(base, Enum) and + base._member_names_): + raise TypeError("Cannot extend enumerations") + # base is now the last base in bases + if not issubclass(base, Enum): + raise TypeError("new enumerations must be created as " + "`ClassName([mixin_type,] enum_type)`") + + # get correct mix-in type (either mix-in type of Enum subclass, or + # first base if last base is Enum) + if not issubclass(bases[0], Enum): + member_type = bases[0] # first data type + first_enum = bases[-1] # enum type + else: + for base in bases[0].__mro__: + # most common: (IntEnum, int, Enum, object) + # possible: (, , + # , , + # ) + if issubclass(base, Enum): + if first_enum is None: + first_enum = base + else: + if member_type is None: + member_type = base + + return member_type, first_enum + + if pyver < 3.0: + @staticmethod + def _find_new_(classdict, member_type, first_enum): + """Returns the __new__ to be used for creating the enum members. + + classdict: the class dictionary given to __new__ + member_type: the data type whose __new__ will be used by default + first_enum: enumeration to check for an overriding __new__ + + """ + # now find the correct __new__, checking to see of one was defined + # by the user; also check earlier enum classes in case a __new__ was + # saved as __member_new__ + __new__ = classdict.get('__new__', None) + if __new__: + return None, True, True # __new__, save_new, use_args + + N__new__ = getattr(None, '__new__') + O__new__ = getattr(object, '__new__') + if Enum is None: + E__new__ = N__new__ + else: + E__new__ = Enum.__dict__['__new__'] + # check all possibles for __member_new__ before falling back to + # __new__ + for method in ('__member_new__', '__new__'): + for possible in (member_type, first_enum): + try: + target = possible.__dict__[method] + except (AttributeError, KeyError): + target = getattr(possible, method, None) + if target not in [ + None, + N__new__, + O__new__, + E__new__, + ]: + if method == '__member_new__': + classdict['__new__'] = target + return None, False, True + if isinstance(target, staticmethod): + target = target.__get__(member_type) + __new__ = target + break + if __new__ is not None: + break + else: + __new__ = object.__new__ + + # if a non-object.__new__ is used then whatever value/tuple was + # assigned to the enum member name will be passed to __new__ and to the + # new enum member's __init__ + if __new__ is object.__new__: + use_args = False + else: + use_args = True + + return __new__, False, use_args + else: + @staticmethod + def _find_new_(classdict, member_type, first_enum): + """Returns the __new__ to be used for creating the enum members. + + classdict: the class dictionary given to __new__ + member_type: the data type whose __new__ will be used by default + first_enum: enumeration to check for an overriding __new__ + + """ + # now find the correct __new__, checking to see of one was defined + # by the user; also check earlier enum classes in case a __new__ was + # saved as __member_new__ + __new__ = classdict.get('__new__', None) + + # should __new__ be saved as __member_new__ later? + save_new = __new__ is not None + + if __new__ is None: + # check all possibles for __member_new__ before falling back to + # __new__ + for method in ('__member_new__', '__new__'): + for possible in (member_type, first_enum): + target = getattr(possible, method, None) + if target not in ( + None, + None.__new__, + object.__new__, + Enum.__new__, + ): + __new__ = target + break + if __new__ is not None: + break + else: + __new__ = object.__new__ + + # if a non-object.__new__ is used then whatever value/tuple was + # assigned to the enum member name will be passed to __new__ and to the + # new enum member's __init__ + if __new__ is object.__new__: + use_args = False + else: + use_args = True + + return __new__, save_new, use_args + + +######################################################## +# In order to support Python 2 and 3 with a single +# codebase we have to create the Enum methods separately +# and then use the `type(name, bases, dict)` method to +# create the class. +######################################################## +temp_enum_dict = {} +temp_enum_dict['__doc__'] = "Generic enumeration.\n\n Derive from this class to define new enumerations.\n\n" + +def __new__(cls, value): + # all enum instances are actually created during class construction + # without calling this method; this method is called by the metaclass' + # __call__ (i.e. Color(3) ), and by pickle + if type(value) is cls: + # For lookups like Color(Color.red) + value = value.value + #return value + # by-value search for a matching enum member + # see if it's in the reverse mapping (for hashable values) + try: + if value in cls._value2member_map_: + return cls._value2member_map_[value] + except TypeError: + # not there, now do long search -- O(n) behavior + for member in cls._member_map_.values(): + if member.value == value: + return member + raise ValueError("%s is not a valid %s" % (value, cls.__name__)) +temp_enum_dict['__new__'] = __new__ +del __new__ + +def __repr__(self): + return "<%s.%s: %r>" % ( + self.__class__.__name__, self._name_, self._value_) +temp_enum_dict['__repr__'] = __repr__ +del __repr__ + +def __str__(self): + return "%s.%s" % (self.__class__.__name__, self._name_) +temp_enum_dict['__str__'] = __str__ +del __str__ + +def __dir__(self): + added_behavior = [ + m + for cls in self.__class__.mro() + for m in cls.__dict__ + if m[0] != '_' + ] + return (['__class__', '__doc__', '__module__', ] + added_behavior) +temp_enum_dict['__dir__'] = __dir__ +del __dir__ + +def __format__(self, format_spec): + # mixed-in Enums should use the mixed-in type's __format__, otherwise + # we can get strange results with the Enum name showing up instead of + # the value + + # pure Enum branch + if self._member_type_ is object: + cls = str + val = str(self) + # mix-in branch + else: + cls = self._member_type_ + val = self.value + return cls.__format__(val, format_spec) +temp_enum_dict['__format__'] = __format__ +del __format__ + + +#################################### +# Python's less than 2.6 use __cmp__ + +if pyver < 2.6: + + def __cmp__(self, other): + if type(other) is self.__class__: + if self is other: + return 0 + return -1 + return NotImplemented + raise TypeError("unorderable types: %s() and %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__cmp__'] = __cmp__ + del __cmp__ + +else: + + def __le__(self, other): + raise TypeError("unorderable types: %s() <= %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__le__'] = __le__ + del __le__ + + def __lt__(self, other): + raise TypeError("unorderable types: %s() < %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__lt__'] = __lt__ + del __lt__ + + def __ge__(self, other): + raise TypeError("unorderable types: %s() >= %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__ge__'] = __ge__ + del __ge__ + + def __gt__(self, other): + raise TypeError("unorderable types: %s() > %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__gt__'] = __gt__ + del __gt__ + + +def __eq__(self, other): + if type(other) is self.__class__: + return self is other + return NotImplemented +temp_enum_dict['__eq__'] = __eq__ +del __eq__ + +def __ne__(self, other): + if type(other) is self.__class__: + return self is not other + return NotImplemented +temp_enum_dict['__ne__'] = __ne__ +del __ne__ + +def __hash__(self): + return hash(self._name_) +temp_enum_dict['__hash__'] = __hash__ +del __hash__ + +def __reduce_ex__(self, proto): + return self.__class__, (self._value_, ) +temp_enum_dict['__reduce_ex__'] = __reduce_ex__ +del __reduce_ex__ + +# _RouteClassAttributeToGetattr is used to provide access to the `name` +# and `value` properties of enum members while keeping some measure of +# protection from modification, while still allowing for an enumeration +# to have members named `name` and `value`. This works because enumeration +# members are not set directly on the enum class -- __getattr__ is +# used to look them up. + +@_RouteClassAttributeToGetattr +def name(self): + return self._name_ +temp_enum_dict['name'] = name +del name + +@_RouteClassAttributeToGetattr +def value(self): + return self._value_ +temp_enum_dict['value'] = value +del value + +Enum = EnumMeta('Enum', (object, ), temp_enum_dict) +del temp_enum_dict + +# Enum has now been created +########################### + +class IntEnum(int, Enum): + """Enum where members are also (and must be) ints""" + + +def unique(enumeration): + """Class decorator that ensures only unique members exist in an enumeration.""" + duplicates = [] + for name, member in enumeration.__members__.items(): + if name != member.name: + duplicates.append((name, member.name)) + if duplicates: + duplicate_names = ', '.join( + ["%s -> %s" % (alias, name) for (alias, name) in duplicates] + ) + raise ValueError('duplicate names found in %r: %s' % + (enumeration, duplicate_names) + ) + return enumeration diff --git a/lib/enum/test_enum.py b/lib/enum/test_enum.py new file mode 100644 index 00000000..d7a97942 --- /dev/null +++ b/lib/enum/test_enum.py @@ -0,0 +1,1690 @@ +import enum +import sys +import unittest +from enum import Enum, IntEnum, unique, EnumMeta +from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL + +pyver = float('%s.%s' % sys.version_info[:2]) + +try: + any +except NameError: + def any(iterable): + for element in iterable: + if element: + return True + return False + +try: + unicode +except NameError: + unicode = str + +try: + from collections import OrderedDict +except ImportError: + OrderedDict = None + +# for pickle tests +try: + class Stooges(Enum): + LARRY = 1 + CURLY = 2 + MOE = 3 +except Exception: + Stooges = sys.exc_info()[1] + +try: + class IntStooges(int, Enum): + LARRY = 1 + CURLY = 2 + MOE = 3 +except Exception: + IntStooges = sys.exc_info()[1] + +try: + class FloatStooges(float, Enum): + LARRY = 1.39 + CURLY = 2.72 + MOE = 3.142596 +except Exception: + FloatStooges = sys.exc_info()[1] + +# for pickle test and subclass tests +try: + class StrEnum(str, Enum): + 'accepts only string values' + class Name(StrEnum): + BDFL = 'Guido van Rossum' + FLUFL = 'Barry Warsaw' +except Exception: + Name = sys.exc_info()[1] + +try: + Question = Enum('Question', 'who what when where why', module=__name__) +except Exception: + Question = sys.exc_info()[1] + +try: + Answer = Enum('Answer', 'him this then there because') +except Exception: + Answer = sys.exc_info()[1] + +try: + Theory = Enum('Theory', 'rule law supposition', qualname='spanish_inquisition') +except Exception: + Theory = sys.exc_info()[1] + +# for doctests +try: + class Fruit(Enum): + tomato = 1 + banana = 2 + cherry = 3 +except Exception: + pass + +def test_pickle_dump_load(assertion, source, target=None, + protocol=(0, HIGHEST_PROTOCOL)): + start, stop = protocol + failures = [] + for protocol in range(start, stop+1): + try: + if target is None: + assertion(loads(dumps(source, protocol=protocol)) is source) + else: + assertion(loads(dumps(source, protocol=protocol)), target) + except Exception: + exc, tb = sys.exc_info()[1:] + failures.append('%2d: %s' %(protocol, exc)) + if failures: + raise ValueError('Failed with protocols: %s' % ', '.join(failures)) + +def test_pickle_exception(assertion, exception, obj, + protocol=(0, HIGHEST_PROTOCOL)): + start, stop = protocol + failures = [] + for protocol in range(start, stop+1): + try: + assertion(exception, dumps, obj, protocol=protocol) + except Exception: + exc = sys.exc_info()[1] + failures.append('%d: %s %s' % (protocol, exc.__class__.__name__, exc)) + if failures: + raise ValueError('Failed with protocols: %s' % ', '.join(failures)) + + +class TestHelpers(unittest.TestCase): + # _is_descriptor, _is_sunder, _is_dunder + + def test_is_descriptor(self): + class foo: + pass + for attr in ('__get__','__set__','__delete__'): + obj = foo() + self.assertFalse(enum._is_descriptor(obj)) + setattr(obj, attr, 1) + self.assertTrue(enum._is_descriptor(obj)) + + def test_is_sunder(self): + for s in ('_a_', '_aa_'): + self.assertTrue(enum._is_sunder(s)) + + for s in ('a', 'a_', '_a', '__a', 'a__', '__a__', '_a__', '__a_', '_', + '__', '___', '____', '_____',): + self.assertFalse(enum._is_sunder(s)) + + def test_is_dunder(self): + for s in ('__a__', '__aa__'): + self.assertTrue(enum._is_dunder(s)) + for s in ('a', 'a_', '_a', '__a', 'a__', '_a_', '_a__', '__a_', '_', + '__', '___', '____', '_____',): + self.assertFalse(enum._is_dunder(s)) + + +class TestEnum(unittest.TestCase): + def setUp(self): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = 3 + WINTER = 4 + self.Season = Season + + class Konstants(float, Enum): + E = 2.7182818 + PI = 3.1415926 + TAU = 2 * PI + self.Konstants = Konstants + + class Grades(IntEnum): + A = 5 + B = 4 + C = 3 + D = 2 + F = 0 + self.Grades = Grades + + class Directional(str, Enum): + EAST = 'east' + WEST = 'west' + NORTH = 'north' + SOUTH = 'south' + self.Directional = Directional + + from datetime import date + class Holiday(date, Enum): + NEW_YEAR = 2013, 1, 1 + IDES_OF_MARCH = 2013, 3, 15 + self.Holiday = Holiday + + if pyver >= 2.6: # cannot specify custom `dir` on previous versions + def test_dir_on_class(self): + Season = self.Season + self.assertEqual( + set(dir(Season)), + set(['__class__', '__doc__', '__members__', '__module__', + 'SPRING', 'SUMMER', 'AUTUMN', 'WINTER']), + ) + + def test_dir_on_item(self): + Season = self.Season + self.assertEqual( + set(dir(Season.WINTER)), + set(['__class__', '__doc__', '__module__', 'name', 'value']), + ) + + def test_dir_on_sub_with_behavior_on_super(self): + # see issue22506 + class SuperEnum(Enum): + def invisible(self): + return "did you see me?" + class SubEnum(SuperEnum): + sample = 5 + self.assertEqual( + set(dir(SubEnum.sample)), + set(['__class__', '__doc__', '__module__', 'name', 'value', 'invisible']), + ) + + if pyver >= 2.7: # OrderedDict first available here + def test_members_is_ordereddict_if_ordered(self): + class Ordered(Enum): + __order__ = 'first second third' + first = 'bippity' + second = 'boppity' + third = 'boo' + self.assertTrue(type(Ordered.__members__) is OrderedDict) + + def test_members_is_ordereddict_if_not_ordered(self): + class Unordered(Enum): + this = 'that' + these = 'those' + self.assertTrue(type(Unordered.__members__) is OrderedDict) + + if pyver >= 3.0: # all objects are ordered in Python 2.x + def test_members_is_always_ordered(self): + class AlwaysOrdered(Enum): + first = 1 + second = 2 + third = 3 + self.assertTrue(type(AlwaysOrdered.__members__) is OrderedDict) + + def test_comparisons(self): + def bad_compare(): + Season.SPRING > 4 + Season = self.Season + self.assertNotEqual(Season.SPRING, 1) + self.assertRaises(TypeError, bad_compare) + + class Part(Enum): + SPRING = 1 + CLIP = 2 + BARREL = 3 + + self.assertNotEqual(Season.SPRING, Part.SPRING) + def bad_compare(): + Season.SPRING < Part.CLIP + self.assertRaises(TypeError, bad_compare) + + def test_enum_in_enum_out(self): + Season = self.Season + self.assertTrue(Season(Season.WINTER) is Season.WINTER) + + def test_enum_value(self): + Season = self.Season + self.assertEqual(Season.SPRING.value, 1) + + def test_intenum_value(self): + self.assertEqual(IntStooges.CURLY.value, 2) + + def test_enum(self): + Season = self.Season + lst = list(Season) + self.assertEqual(len(lst), len(Season)) + self.assertEqual(len(Season), 4, Season) + self.assertEqual( + [Season.SPRING, Season.SUMMER, Season.AUTUMN, Season.WINTER], lst) + + for i, season in enumerate('SPRING SUMMER AUTUMN WINTER'.split()): + i += 1 + e = Season(i) + self.assertEqual(e, getattr(Season, season)) + self.assertEqual(e.value, i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, season) + self.assertTrue(e in Season) + self.assertTrue(type(e) is Season) + self.assertTrue(isinstance(e, Season)) + self.assertEqual(str(e), 'Season.' + season) + self.assertEqual( + repr(e), + '' % (season, i), + ) + + def test_value_name(self): + Season = self.Season + self.assertEqual(Season.SPRING.name, 'SPRING') + self.assertEqual(Season.SPRING.value, 1) + def set_name(obj, new_value): + obj.name = new_value + def set_value(obj, new_value): + obj.value = new_value + self.assertRaises(AttributeError, set_name, Season.SPRING, 'invierno', ) + self.assertRaises(AttributeError, set_value, Season.SPRING, 2) + + def test_attribute_deletion(self): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = 3 + WINTER = 4 + + def spam(cls): + pass + + self.assertTrue(hasattr(Season, 'spam')) + del Season.spam + self.assertFalse(hasattr(Season, 'spam')) + + self.assertRaises(AttributeError, delattr, Season, 'SPRING') + self.assertRaises(AttributeError, delattr, Season, 'DRY') + self.assertRaises(AttributeError, delattr, Season.SPRING, 'name') + + def test_invalid_names(self): + def create_bad_class_1(): + class Wrong(Enum): + mro = 9 + def create_bad_class_2(): + class Wrong(Enum): + _reserved_ = 3 + self.assertRaises(ValueError, create_bad_class_1) + self.assertRaises(ValueError, create_bad_class_2) + + def test_contains(self): + Season = self.Season + self.assertTrue(Season.AUTUMN in Season) + self.assertTrue(3 not in Season) + + val = Season(3) + self.assertTrue(val in Season) + + class OtherEnum(Enum): + one = 1; two = 2 + self.assertTrue(OtherEnum.two not in Season) + + if pyver >= 2.6: # when `format` came into being + + def test_format_enum(self): + Season = self.Season + self.assertEqual('{0}'.format(Season.SPRING), + '{0}'.format(str(Season.SPRING))) + self.assertEqual( '{0:}'.format(Season.SPRING), + '{0:}'.format(str(Season.SPRING))) + self.assertEqual('{0:20}'.format(Season.SPRING), + '{0:20}'.format(str(Season.SPRING))) + self.assertEqual('{0:^20}'.format(Season.SPRING), + '{0:^20}'.format(str(Season.SPRING))) + self.assertEqual('{0:>20}'.format(Season.SPRING), + '{0:>20}'.format(str(Season.SPRING))) + self.assertEqual('{0:<20}'.format(Season.SPRING), + '{0:<20}'.format(str(Season.SPRING))) + + def test_format_enum_custom(self): + class TestFloat(float, Enum): + one = 1.0 + two = 2.0 + def __format__(self, spec): + return 'TestFloat success!' + self.assertEqual('{0}'.format(TestFloat.one), 'TestFloat success!') + + def assertFormatIsValue(self, spec, member): + self.assertEqual(spec.format(member), spec.format(member.value)) + + def test_format_enum_date(self): + Holiday = self.Holiday + self.assertFormatIsValue('{0}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:20}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:^20}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:>20}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:<20}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:%Y %m}', Holiday.IDES_OF_MARCH) + self.assertFormatIsValue('{0:%Y %m %M:00}', Holiday.IDES_OF_MARCH) + + def test_format_enum_float(self): + Konstants = self.Konstants + self.assertFormatIsValue('{0}', Konstants.TAU) + self.assertFormatIsValue('{0:}', Konstants.TAU) + self.assertFormatIsValue('{0:20}', Konstants.TAU) + self.assertFormatIsValue('{0:^20}', Konstants.TAU) + self.assertFormatIsValue('{0:>20}', Konstants.TAU) + self.assertFormatIsValue('{0:<20}', Konstants.TAU) + self.assertFormatIsValue('{0:n}', Konstants.TAU) + self.assertFormatIsValue('{0:5.2}', Konstants.TAU) + self.assertFormatIsValue('{0:f}', Konstants.TAU) + + def test_format_enum_int(self): + Grades = self.Grades + self.assertFormatIsValue('{0}', Grades.C) + self.assertFormatIsValue('{0:}', Grades.C) + self.assertFormatIsValue('{0:20}', Grades.C) + self.assertFormatIsValue('{0:^20}', Grades.C) + self.assertFormatIsValue('{0:>20}', Grades.C) + self.assertFormatIsValue('{0:<20}', Grades.C) + self.assertFormatIsValue('{0:+}', Grades.C) + self.assertFormatIsValue('{0:08X}', Grades.C) + self.assertFormatIsValue('{0:b}', Grades.C) + + def test_format_enum_str(self): + Directional = self.Directional + self.assertFormatIsValue('{0}', Directional.WEST) + self.assertFormatIsValue('{0:}', Directional.WEST) + self.assertFormatIsValue('{0:20}', Directional.WEST) + self.assertFormatIsValue('{0:^20}', Directional.WEST) + self.assertFormatIsValue('{0:>20}', Directional.WEST) + self.assertFormatIsValue('{0:<20}', Directional.WEST) + + def test_hash(self): + Season = self.Season + dates = {} + dates[Season.WINTER] = '1225' + dates[Season.SPRING] = '0315' + dates[Season.SUMMER] = '0704' + dates[Season.AUTUMN] = '1031' + self.assertEqual(dates[Season.AUTUMN], '1031') + + def test_enum_duplicates(self): + __order__ = "SPRING SUMMER AUTUMN WINTER" + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = FALL = 3 + WINTER = 4 + ANOTHER_SPRING = 1 + lst = list(Season) + self.assertEqual( + lst, + [Season.SPRING, Season.SUMMER, + Season.AUTUMN, Season.WINTER, + ]) + self.assertTrue(Season.FALL is Season.AUTUMN) + self.assertEqual(Season.FALL.value, 3) + self.assertEqual(Season.AUTUMN.value, 3) + self.assertTrue(Season(3) is Season.AUTUMN) + self.assertTrue(Season(1) is Season.SPRING) + self.assertEqual(Season.FALL.name, 'AUTUMN') + self.assertEqual( + set([k for k,v in Season.__members__.items() if v.name != k]), + set(['FALL', 'ANOTHER_SPRING']), + ) + + if pyver >= 3.0: + cls = vars() + result = {'Enum':Enum} + exec("""def test_duplicate_name(self): + with self.assertRaises(TypeError): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + red = 4 + + with self.assertRaises(TypeError): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + def red(self): + return 'red' + + with self.assertRaises(TypeError): + class Color(Enum): + @property + + def red(self): + return 'redder' + red = 1 + green = 2 + blue = 3""", + result) + cls['test_duplicate_name'] = result['test_duplicate_name'] + + def test_enum_with_value_name(self): + class Huh(Enum): + name = 1 + value = 2 + self.assertEqual( + list(Huh), + [Huh.name, Huh.value], + ) + self.assertTrue(type(Huh.name) is Huh) + self.assertEqual(Huh.name.name, 'name') + self.assertEqual(Huh.name.value, 1) + + def test_intenum_from_scratch(self): + class phy(int, Enum): + pi = 3 + tau = 2 * pi + self.assertTrue(phy.pi < phy.tau) + + def test_intenum_inherited(self): + class IntEnum(int, Enum): + pass + class phy(IntEnum): + pi = 3 + tau = 2 * pi + self.assertTrue(phy.pi < phy.tau) + + def test_floatenum_from_scratch(self): + class phy(float, Enum): + pi = 3.1415926 + tau = 2 * pi + self.assertTrue(phy.pi < phy.tau) + + def test_floatenum_inherited(self): + class FloatEnum(float, Enum): + pass + class phy(FloatEnum): + pi = 3.1415926 + tau = 2 * pi + self.assertTrue(phy.pi < phy.tau) + + def test_strenum_from_scratch(self): + class phy(str, Enum): + pi = 'Pi' + tau = 'Tau' + self.assertTrue(phy.pi < phy.tau) + + def test_strenum_inherited(self): + class StrEnum(str, Enum): + pass + class phy(StrEnum): + pi = 'Pi' + tau = 'Tau' + self.assertTrue(phy.pi < phy.tau) + + def test_intenum(self): + class WeekDay(IntEnum): + SUNDAY = 1 + MONDAY = 2 + TUESDAY = 3 + WEDNESDAY = 4 + THURSDAY = 5 + FRIDAY = 6 + SATURDAY = 7 + + self.assertEqual(['a', 'b', 'c'][WeekDay.MONDAY], 'c') + self.assertEqual([i for i in range(WeekDay.TUESDAY)], [0, 1, 2]) + + lst = list(WeekDay) + self.assertEqual(len(lst), len(WeekDay)) + self.assertEqual(len(WeekDay), 7) + target = 'SUNDAY MONDAY TUESDAY WEDNESDAY THURSDAY FRIDAY SATURDAY' + target = target.split() + for i, weekday in enumerate(target): + i += 1 + e = WeekDay(i) + self.assertEqual(e, i) + self.assertEqual(int(e), i) + self.assertEqual(e.name, weekday) + self.assertTrue(e in WeekDay) + self.assertEqual(lst.index(e)+1, i) + self.assertTrue(0 < e < 8) + self.assertTrue(type(e) is WeekDay) + self.assertTrue(isinstance(e, int)) + self.assertTrue(isinstance(e, Enum)) + + def test_intenum_duplicates(self): + class WeekDay(IntEnum): + __order__ = 'SUNDAY MONDAY TUESDAY WEDNESDAY THURSDAY FRIDAY SATURDAY' + SUNDAY = 1 + MONDAY = 2 + TUESDAY = TEUSDAY = 3 + WEDNESDAY = 4 + THURSDAY = 5 + FRIDAY = 6 + SATURDAY = 7 + self.assertTrue(WeekDay.TEUSDAY is WeekDay.TUESDAY) + self.assertEqual(WeekDay(3).name, 'TUESDAY') + self.assertEqual([k for k,v in WeekDay.__members__.items() + if v.name != k], ['TEUSDAY', ]) + + def test_pickle_enum(self): + if isinstance(Stooges, Exception): + raise Stooges + test_pickle_dump_load(self.assertTrue, Stooges.CURLY) + test_pickle_dump_load(self.assertTrue, Stooges) + + def test_pickle_int(self): + if isinstance(IntStooges, Exception): + raise IntStooges + test_pickle_dump_load(self.assertTrue, IntStooges.CURLY) + test_pickle_dump_load(self.assertTrue, IntStooges) + + def test_pickle_float(self): + if isinstance(FloatStooges, Exception): + raise FloatStooges + test_pickle_dump_load(self.assertTrue, FloatStooges.CURLY) + test_pickle_dump_load(self.assertTrue, FloatStooges) + + def test_pickle_enum_function(self): + if isinstance(Answer, Exception): + raise Answer + test_pickle_dump_load(self.assertTrue, Answer.him) + test_pickle_dump_load(self.assertTrue, Answer) + + def test_pickle_enum_function_with_module(self): + if isinstance(Question, Exception): + raise Question + test_pickle_dump_load(self.assertTrue, Question.who) + test_pickle_dump_load(self.assertTrue, Question) + + if pyver >= 3.4: + def test_class_nested_enum_and_pickle_protocol_four(self): + # would normally just have this directly in the class namespace + class NestedEnum(Enum): + twigs = 'common' + shiny = 'rare' + + self.__class__.NestedEnum = NestedEnum + self.NestedEnum.__qualname__ = '%s.NestedEnum' % self.__class__.__name__ + test_pickle_exception( + self.assertRaises, PicklingError, self.NestedEnum.twigs, + protocol=(0, 3)) + test_pickle_dump_load(self.assertTrue, self.NestedEnum.twigs, + protocol=(4, HIGHEST_PROTOCOL)) + + def test_exploding_pickle(self): + BadPickle = Enum('BadPickle', 'dill sweet bread-n-butter') + enum._make_class_unpicklable(BadPickle) + globals()['BadPickle'] = BadPickle + test_pickle_exception(self.assertRaises, TypeError, BadPickle.dill) + test_pickle_exception(self.assertRaises, PicklingError, BadPickle) + + def test_string_enum(self): + class SkillLevel(str, Enum): + master = 'what is the sound of one hand clapping?' + journeyman = 'why did the chicken cross the road?' + apprentice = 'knock, knock!' + self.assertEqual(SkillLevel.apprentice, 'knock, knock!') + + def test_getattr_getitem(self): + class Period(Enum): + morning = 1 + noon = 2 + evening = 3 + night = 4 + self.assertTrue(Period(2) is Period.noon) + self.assertTrue(getattr(Period, 'night') is Period.night) + self.assertTrue(Period['morning'] is Period.morning) + + def test_getattr_dunder(self): + Season = self.Season + self.assertTrue(getattr(Season, '__hash__')) + + def test_iteration_order(self): + class Season(Enum): + __order__ = 'SUMMER WINTER AUTUMN SPRING' + SUMMER = 2 + WINTER = 4 + AUTUMN = 3 + SPRING = 1 + self.assertEqual( + list(Season), + [Season.SUMMER, Season.WINTER, Season.AUTUMN, Season.SPRING], + ) + + def test_iteration_order_with_unorderable_values(self): + class Complex(Enum): + a = complex(7, 9) + b = complex(3.14, 2) + c = complex(1, -1) + d = complex(-77, 32) + self.assertEqual( + list(Complex), + [Complex.a, Complex.b, Complex.c, Complex.d], + ) + + def test_programatic_function_string(self): + SummerMonth = Enum('SummerMonth', 'june july august') + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_string_list(self): + SummerMonth = Enum('SummerMonth', ['june', 'july', 'august']) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_iterable(self): + SummerMonth = Enum( + 'SummerMonth', + (('june', 1), ('july', 2), ('august', 3)) + ) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_from_dict(self): + SummerMonth = Enum( + 'SummerMonth', + dict((('june', 1), ('july', 2), ('august', 3))) + ) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + if pyver < 3.0: + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_type(self): + SummerMonth = Enum('SummerMonth', 'june july august', type=int) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_type_from_subclass(self): + SummerMonth = IntEnum('SummerMonth', 'june july august') + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_unicode(self): + SummerMonth = Enum('SummerMonth', unicode('june july august')) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_unicode_list(self): + SummerMonth = Enum('SummerMonth', [unicode('june'), unicode('july'), unicode('august')]) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_unicode_iterable(self): + SummerMonth = Enum( + 'SummerMonth', + ((unicode('june'), 1), (unicode('july'), 2), (unicode('august'), 3)) + ) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_from_unicode_dict(self): + SummerMonth = Enum( + 'SummerMonth', + dict(((unicode('june'), 1), (unicode('july'), 2), (unicode('august'), 3))) + ) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + if pyver < 3.0: + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_unicode_type(self): + SummerMonth = Enum('SummerMonth', unicode('june july august'), type=int) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programatic_function_unicode_type_from_subclass(self): + SummerMonth = IntEnum('SummerMonth', unicode('june july august')) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_programmatic_function_unicode_class(self): + if pyver < 3.0: + class_names = unicode('SummerMonth'), 'S\xfcmm\xe9rM\xf6nth'.decode('latin1') + else: + class_names = 'SummerMonth', 'S\xfcmm\xe9rM\xf6nth' + for i, class_name in enumerate(class_names): + if pyver < 3.0 and i == 1: + self.assertRaises(TypeError, Enum, class_name, unicode('june july august')) + else: + SummerMonth = Enum(class_name, unicode('june july august')) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate(unicode('june july august').split()): + i += 1 + e = SummerMonth(i) + self.assertEqual(e.value, i) + self.assertEqual(e.name, month) + self.assertTrue(e in SummerMonth) + self.assertTrue(type(e) is SummerMonth) + + def test_subclassing(self): + if isinstance(Name, Exception): + raise Name + self.assertEqual(Name.BDFL, 'Guido van Rossum') + self.assertTrue(Name.BDFL, Name('Guido van Rossum')) + self.assertTrue(Name.BDFL is getattr(Name, 'BDFL')) + test_pickle_dump_load(self.assertTrue, Name.BDFL) + + def test_extending(self): + def bad_extension(): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + class MoreColor(Color): + cyan = 4 + magenta = 5 + yellow = 6 + self.assertRaises(TypeError, bad_extension) + + def test_exclude_methods(self): + class whatever(Enum): + this = 'that' + these = 'those' + def really(self): + return 'no, not %s' % self.value + self.assertFalse(type(whatever.really) is whatever) + self.assertEqual(whatever.this.really(), 'no, not that') + + def test_wrong_inheritance_order(self): + def wrong_inherit(): + class Wrong(Enum, str): + NotHere = 'error before this point' + self.assertRaises(TypeError, wrong_inherit) + + def test_intenum_transitivity(self): + class number(IntEnum): + one = 1 + two = 2 + three = 3 + class numero(IntEnum): + uno = 1 + dos = 2 + tres = 3 + self.assertEqual(number.one, numero.uno) + self.assertEqual(number.two, numero.dos) + self.assertEqual(number.three, numero.tres) + + def test_introspection(self): + class Number(IntEnum): + one = 100 + two = 200 + self.assertTrue(Number.one._member_type_ is int) + self.assertTrue(Number._member_type_ is int) + class String(str, Enum): + yarn = 'soft' + rope = 'rough' + wire = 'hard' + self.assertTrue(String.yarn._member_type_ is str) + self.assertTrue(String._member_type_ is str) + class Plain(Enum): + vanilla = 'white' + one = 1 + self.assertTrue(Plain.vanilla._member_type_ is object) + self.assertTrue(Plain._member_type_ is object) + + def test_wrong_enum_in_call(self): + class Monochrome(Enum): + black = 0 + white = 1 + class Gender(Enum): + male = 0 + female = 1 + self.assertRaises(ValueError, Monochrome, Gender.male) + + def test_wrong_enum_in_mixed_call(self): + class Monochrome(IntEnum): + black = 0 + white = 1 + class Gender(Enum): + male = 0 + female = 1 + self.assertRaises(ValueError, Monochrome, Gender.male) + + def test_mixed_enum_in_call_1(self): + class Monochrome(IntEnum): + black = 0 + white = 1 + class Gender(IntEnum): + male = 0 + female = 1 + self.assertTrue(Monochrome(Gender.female) is Monochrome.white) + + def test_mixed_enum_in_call_2(self): + class Monochrome(Enum): + black = 0 + white = 1 + class Gender(IntEnum): + male = 0 + female = 1 + self.assertTrue(Monochrome(Gender.male) is Monochrome.black) + + def test_flufl_enum(self): + class Fluflnum(Enum): + def __int__(self): + return int(self.value) + class MailManOptions(Fluflnum): + option1 = 1 + option2 = 2 + option3 = 3 + self.assertEqual(int(MailManOptions.option1), 1) + + def test_no_such_enum_member(self): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + self.assertRaises(ValueError, Color, 4) + self.assertRaises(KeyError, Color.__getitem__, 'chartreuse') + + def test_new_repr(self): + class Color(Enum): + red = 1 + green = 2 + blue = 3 + def __repr__(self): + return "don't you just love shades of %s?" % self.name + self.assertEqual( + repr(Color.blue), + "don't you just love shades of blue?", + ) + + def test_inherited_repr(self): + class MyEnum(Enum): + def __repr__(self): + return "My name is %s." % self.name + class MyIntEnum(int, MyEnum): + this = 1 + that = 2 + theother = 3 + self.assertEqual(repr(MyIntEnum.that), "My name is that.") + + def test_multiple_mixin_mro(self): + class auto_enum(EnumMeta): + def __new__(metacls, cls, bases, classdict): + original_dict = classdict + classdict = enum._EnumDict() + for k, v in original_dict.items(): + classdict[k] = v + temp = type(classdict)() + names = set(classdict._member_names) + i = 0 + for k in classdict._member_names: + v = classdict[k] + if v == (): + v = i + else: + i = v + i += 1 + temp[k] = v + for k, v in classdict.items(): + if k not in names: + temp[k] = v + return super(auto_enum, metacls).__new__( + metacls, cls, bases, temp) + + AutoNumberedEnum = auto_enum('AutoNumberedEnum', (Enum,), {}) + + AutoIntEnum = auto_enum('AutoIntEnum', (IntEnum,), {}) + + class TestAutoNumber(AutoNumberedEnum): + a = () + b = 3 + c = () + + class TestAutoInt(AutoIntEnum): + a = () + b = 3 + c = () + + def test_subclasses_with_getnewargs(self): + class NamedInt(int): + __qualname__ = 'NamedInt' # needed for pickle protocol 4 + def __new__(cls, *args): + _args = args + if len(args) < 1: + raise TypeError("name and value must be specified") + name, args = args[0], args[1:] + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __getnewargs__(self): + return self._args + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "%s(%r, %s)" % (type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '(%s + %s)' % (self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' # needed for pickle protocol 4 + x = ('the-x', 1) + y = ('the-y', 2) + + self.assertTrue(NEI.__new__ is Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + test_pickle_dump_load(self.assertTrue, NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertTrue, NEI.y) + + if pyver >= 3.4: + def test_subclasses_with_getnewargs_ex(self): + class NamedInt(int): + __qualname__ = 'NamedInt' # needed for pickle protocol 4 + def __new__(cls, *args): + _args = args + if len(args) < 2: + raise TypeError("name and value must be specified") + name, args = args[0], args[1:] + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __getnewargs_ex__(self): + return self._args, {} + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "{}({!r}, {})".format(type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '({0} + {1})'.format(self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' # needed for pickle protocol 4 + x = ('the-x', 1) + y = ('the-y', 2) + + + self.assertIs(NEI.__new__, Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + test_pickle_dump_load(self.assertEqual, NI5, 5, protocol=(4, HIGHEST_PROTOCOL)) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertTrue, NEI.y, protocol=(4, HIGHEST_PROTOCOL)) + + def test_subclasses_with_reduce(self): + class NamedInt(int): + __qualname__ = 'NamedInt' # needed for pickle protocol 4 + def __new__(cls, *args): + _args = args + if len(args) < 1: + raise TypeError("name and value must be specified") + name, args = args[0], args[1:] + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __reduce__(self): + return self.__class__, self._args + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "%s(%r, %s)" % (type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '(%s + %s)' % (self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' # needed for pickle protocol 4 + x = ('the-x', 1) + y = ('the-y', 2) + + + self.assertTrue(NEI.__new__ is Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + test_pickle_dump_load(self.assertEqual, NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertTrue, NEI.y) + + def test_subclasses_with_reduce_ex(self): + class NamedInt(int): + __qualname__ = 'NamedInt' # needed for pickle protocol 4 + def __new__(cls, *args): + _args = args + if len(args) < 1: + raise TypeError("name and value must be specified") + name, args = args[0], args[1:] + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + def __reduce_ex__(self, proto): + return self.__class__, self._args + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "%s(%r, %s)" % (type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '(%s + %s)' % (self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' # needed for pickle protocol 4 + x = ('the-x', 1) + y = ('the-y', 2) + + + self.assertTrue(NEI.__new__ is Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + test_pickle_dump_load(self.assertEqual, NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertTrue, NEI.y) + + def test_subclasses_without_direct_pickle_support(self): + class NamedInt(int): + __qualname__ = 'NamedInt' + def __new__(cls, *args): + _args = args + name, args = args[0], args[1:] + if len(args) == 0: + raise TypeError("name and value must be specified") + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "%s(%r, %s)" % (type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '(%s + %s)' % (self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' + x = ('the-x', 1) + y = ('the-y', 2) + + self.assertTrue(NEI.__new__ is Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_exception(self.assertRaises, TypeError, NEI.x) + test_pickle_exception(self.assertRaises, PicklingError, NEI) + + def test_subclasses_without_direct_pickle_support_using_name(self): + class NamedInt(int): + __qualname__ = 'NamedInt' + def __new__(cls, *args): + _args = args + name, args = args[0], args[1:] + if len(args) == 0: + raise TypeError("name and value must be specified") + self = int.__new__(cls, *args) + self._intname = name + self._args = _args + return self + @property + def __name__(self): + return self._intname + def __repr__(self): + # repr() is updated to include the name and type info + return "%s(%r, %s)" % (type(self).__name__, + self.__name__, + int.__repr__(self)) + def __str__(self): + # str() is unchanged, even if it relies on the repr() fallback + base = int + base_str = base.__str__ + if base_str.__objclass__ is object: + return base.__repr__(self) + return base_str(self) + # for simplicity, we only define one operator that + # propagates expressions + def __add__(self, other): + temp = int(self) + int( other) + if isinstance(self, NamedInt) and isinstance(other, NamedInt): + return NamedInt( + '(%s + %s)' % (self.__name__, other.__name__), + temp ) + else: + return temp + + class NEI(NamedInt, Enum): + __qualname__ = 'NEI' + x = ('the-x', 1) + y = ('the-y', 2) + def __reduce_ex__(self, proto): + return getattr, (self.__class__, self._name_) + + self.assertTrue(NEI.__new__ is Enum.__new__) + self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") + globals()['NamedInt'] = NamedInt + globals()['NEI'] = NEI + NI5 = NamedInt('test', 5) + self.assertEqual(NI5, 5) + self.assertEqual(NEI.y.value, 2) + test_pickle_dump_load(self.assertTrue, NEI.y) + test_pickle_dump_load(self.assertTrue, NEI) + + def test_tuple_subclass(self): + class SomeTuple(tuple, Enum): + __qualname__ = 'SomeTuple' + first = (1, 'for the money') + second = (2, 'for the show') + third = (3, 'for the music') + self.assertTrue(type(SomeTuple.first) is SomeTuple) + self.assertTrue(isinstance(SomeTuple.second, tuple)) + self.assertEqual(SomeTuple.third, (3, 'for the music')) + globals()['SomeTuple'] = SomeTuple + test_pickle_dump_load(self.assertTrue, SomeTuple.first) + + def test_duplicate_values_give_unique_enum_items(self): + class AutoNumber(Enum): + __order__ = 'enum_m enum_d enum_y' + enum_m = () + enum_d = () + enum_y = () + def __new__(cls): + value = len(cls.__members__) + 1 + obj = object.__new__(cls) + obj._value_ = value + return obj + def __int__(self): + return int(self._value_) + self.assertEqual(int(AutoNumber.enum_d), 2) + self.assertEqual(AutoNumber.enum_y.value, 3) + self.assertTrue(AutoNumber(1) is AutoNumber.enum_m) + self.assertEqual( + list(AutoNumber), + [AutoNumber.enum_m, AutoNumber.enum_d, AutoNumber.enum_y], + ) + + def test_inherited_new_from_enhanced_enum(self): + class AutoNumber2(Enum): + def __new__(cls): + value = len(cls.__members__) + 1 + obj = object.__new__(cls) + obj._value_ = value + return obj + def __int__(self): + return int(self._value_) + class Color(AutoNumber2): + __order__ = 'red green blue' + red = () + green = () + blue = () + self.assertEqual(len(Color), 3, "wrong number of elements: %d (should be %d)" % (len(Color), 3)) + self.assertEqual(list(Color), [Color.red, Color.green, Color.blue]) + if pyver >= 3.0: + self.assertEqual(list(map(int, Color)), [1, 2, 3]) + + def test_inherited_new_from_mixed_enum(self): + class AutoNumber3(IntEnum): + def __new__(cls): + value = len(cls.__members__) + 1 + obj = int.__new__(cls, value) + obj._value_ = value + return obj + class Color(AutoNumber3): + red = () + green = () + blue = () + self.assertEqual(len(Color), 3, "wrong number of elements: %d (should be %d)" % (len(Color), 3)) + Color.red + Color.green + Color.blue + + def test_ordered_mixin(self): + class OrderedEnum(Enum): + def __ge__(self, other): + if self.__class__ is other.__class__: + return self._value_ >= other._value_ + return NotImplemented + def __gt__(self, other): + if self.__class__ is other.__class__: + return self._value_ > other._value_ + return NotImplemented + def __le__(self, other): + if self.__class__ is other.__class__: + return self._value_ <= other._value_ + return NotImplemented + def __lt__(self, other): + if self.__class__ is other.__class__: + return self._value_ < other._value_ + return NotImplemented + class Grade(OrderedEnum): + __order__ = 'A B C D F' + A = 5 + B = 4 + C = 3 + D = 2 + F = 1 + self.assertEqual(list(Grade), [Grade.A, Grade.B, Grade.C, Grade.D, Grade.F]) + self.assertTrue(Grade.A > Grade.B) + self.assertTrue(Grade.F <= Grade.C) + self.assertTrue(Grade.D < Grade.A) + self.assertTrue(Grade.B >= Grade.B) + + def test_extending2(self): + def bad_extension(): + class Shade(Enum): + def shade(self): + print(self.name) + class Color(Shade): + red = 1 + green = 2 + blue = 3 + class MoreColor(Color): + cyan = 4 + magenta = 5 + yellow = 6 + self.assertRaises(TypeError, bad_extension) + + def test_extending3(self): + class Shade(Enum): + def shade(self): + return self.name + class Color(Shade): + def hex(self): + return '%s hexlified!' % self.value + class MoreColor(Color): + cyan = 4 + magenta = 5 + yellow = 6 + self.assertEqual(MoreColor.magenta.hex(), '5 hexlified!') + + def test_no_duplicates(self): + def bad_duplicates(): + class UniqueEnum(Enum): + def __init__(self, *args): + cls = self.__class__ + if any(self.value == e.value for e in cls): + a = self.name + e = cls(self.value).name + raise ValueError( + "aliases not allowed in UniqueEnum: %r --> %r" + % (a, e) + ) + class Color(UniqueEnum): + red = 1 + green = 2 + blue = 3 + class Color(UniqueEnum): + red = 1 + green = 2 + blue = 3 + grene = 2 + self.assertRaises(ValueError, bad_duplicates) + + def test_reversed(self): + self.assertEqual( + list(reversed(self.Season)), + [self.Season.WINTER, self.Season.AUTUMN, self.Season.SUMMER, + self.Season.SPRING] + ) + + def test_init(self): + class Planet(Enum): + MERCURY = (3.303e+23, 2.4397e6) + VENUS = (4.869e+24, 6.0518e6) + EARTH = (5.976e+24, 6.37814e6) + MARS = (6.421e+23, 3.3972e6) + JUPITER = (1.9e+27, 7.1492e7) + SATURN = (5.688e+26, 6.0268e7) + URANUS = (8.686e+25, 2.5559e7) + NEPTUNE = (1.024e+26, 2.4746e7) + def __init__(self, mass, radius): + self.mass = mass # in kilograms + self.radius = radius # in meters + @property + def surface_gravity(self): + # universal gravitational constant (m3 kg-1 s-2) + G = 6.67300E-11 + return G * self.mass / (self.radius * self.radius) + self.assertEqual(round(Planet.EARTH.surface_gravity, 2), 9.80) + self.assertEqual(Planet.EARTH.value, (5.976e+24, 6.37814e6)) + + def test_nonhash_value(self): + class AutoNumberInAList(Enum): + def __new__(cls): + value = [len(cls.__members__) + 1] + obj = object.__new__(cls) + obj._value_ = value + return obj + class ColorInAList(AutoNumberInAList): + __order__ = 'red green blue' + red = () + green = () + blue = () + self.assertEqual(list(ColorInAList), [ColorInAList.red, ColorInAList.green, ColorInAList.blue]) + self.assertEqual(ColorInAList.red.value, [1]) + self.assertEqual(ColorInAList([1]), ColorInAList.red) + + def test_conflicting_types_resolved_in_new(self): + class LabelledIntEnum(int, Enum): + def __new__(cls, *args): + value, label = args + obj = int.__new__(cls, value) + obj.label = label + obj._value_ = value + return obj + + class LabelledList(LabelledIntEnum): + unprocessed = (1, "Unprocessed") + payment_complete = (2, "Payment Complete") + + self.assertEqual(list(LabelledList), [LabelledList.unprocessed, LabelledList.payment_complete]) + self.assertEqual(LabelledList.unprocessed, 1) + self.assertEqual(LabelledList(1), LabelledList.unprocessed) + +class TestUnique(unittest.TestCase): + """2.4 doesn't allow class decorators, use function syntax.""" + + def test_unique_clean(self): + class Clean(Enum): + one = 1 + two = 'dos' + tres = 4.0 + unique(Clean) + class Cleaner(IntEnum): + single = 1 + double = 2 + triple = 3 + unique(Cleaner) + + def test_unique_dirty(self): + try: + class Dirty(Enum): + __order__ = 'one two tres' + one = 1 + two = 'dos' + tres = 1 + unique(Dirty) + except ValueError: + exc = sys.exc_info()[1] + message = exc.args[0] + self.assertTrue('tres -> one' in message) + + try: + class Dirtier(IntEnum): + __order__ = 'single double triple turkey' + single = 1 + double = 1 + triple = 3 + turkey = 3 + unique(Dirtier) + except ValueError: + exc = sys.exc_info()[1] + message = exc.args[0] + self.assertTrue('double -> single' in message) + self.assertTrue('turkey -> triple' in message) + + +class TestMe(unittest.TestCase): + + pass + +if __name__ == '__main__': + unittest.main() diff --git a/lib/feedparser.py b/lib/feedparser.py new file mode 100644 index 00000000..15fdc95b --- /dev/null +++ b/lib/feedparser.py @@ -0,0 +1,3909 @@ +#!/usr/bin/env python +"""Universal feed parser + +Handles RSS 0.9x, RSS 1.0, RSS 2.0, CDF, Atom 0.3, and Atom 1.0 feeds + +Visit http://feedparser.org/ for the latest version +Visit http://feedparser.org/docs/ for the latest documentation + +Required: Python 2.4 or later +Recommended: CJKCodecs and iconv_codec +""" + +__version__ = "5.0.1" +__license__ = """Copyright (c) 2002-2008, Mark Pilgrim, All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS' +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE.""" +__author__ = "Mark Pilgrim " +__contributors__ = ["Jason Diamond ", + "John Beimler ", + "Fazal Majid ", + "Aaron Swartz ", + "Kevin Marks ", + "Sam Ruby ", + "Ade Oshineye ", + "Martin Pool ", + "Kurt McKee "] +_debug = 0 + +# HTTP "User-Agent" header to send to servers when downloading feeds. +# If you are embedding feedparser in a larger application, you should +# change this to your application name and URL. +USER_AGENT = "UniversalFeedParser/%s +http://feedparser.org/" % __version__ + +# HTTP "Accept" header to send to servers when downloading feeds. If you don't +# want to send an Accept header, set this to None. +ACCEPT_HEADER = "application/atom+xml,application/rdf+xml,application/rss+xml,application/x-netcdf,application/xml;q=0.9,text/xml;q=0.2,*/*;q=0.1" + +# List of preferred XML parsers, by SAX driver name. These will be tried first, +# but if they're not installed, Python will keep searching through its own list +# of pre-installed parsers until it finds one that supports everything we need. +PREFERRED_XML_PARSERS = ["drv_libxml2"] + +# If you want feedparser to automatically run HTML markup through HTML Tidy, set +# this to 1. Requires mxTidy +# or utidylib . +TIDY_MARKUP = 0 + +# List of Python interfaces for HTML Tidy, in order of preference. Only useful +# if TIDY_MARKUP = 1 +PREFERRED_TIDY_INTERFACES = ["uTidy", "mxTidy"] + +# If you want feedparser to automatically resolve all relative URIs, set this +# to 1. +RESOLVE_RELATIVE_URIS = 1 + +# If you want feedparser to automatically sanitize all potentially unsafe +# HTML content, set this to 1. +SANITIZE_HTML = 1 + +# ---------- Python 3 modules (make it work if possible) ---------- +try: + import rfc822 +except ImportError: + from email import _parseaddr as rfc822 + +try: + # Python 3.1 introduces bytes.maketrans and simultaneously + # deprecates string.maketrans; use bytes.maketrans if possible + _maketrans = bytes.maketrans +except (NameError, AttributeError): + import string + _maketrans = string.maketrans + +# base64 support for Atom feeds that contain embedded binary data +try: + import base64, binascii + # Python 3.1 deprecates decodestring in favor of decodebytes + _base64decode = getattr(base64, 'decodebytes', base64.decodestring) +except: + base64 = binascii = None + +def _s2bytes(s): + # Convert a UTF-8 str to bytes if the interpreter is Python 3 + try: + return bytes(s, 'utf8') + except (NameError, TypeError): + # In Python 2.5 and below, bytes doesn't exist (NameError) + # In Python 2.6 and above, bytes and str are the same (TypeError) + return s + +def _l2bytes(l): + # Convert a list of ints to bytes if the interpreter is Python 3 + try: + if bytes is not str: + # In Python 2.6 and above, this call won't raise an exception + # but it will return bytes([65]) as '[65]' instead of 'A' + return bytes(l) + raise NameError + except NameError: + return ''.join(map(chr, l)) + +# If you want feedparser to allow all URL schemes, set this to () +# List culled from Python's urlparse documentation at: +# http://docs.python.org/library/urlparse.html +# as well as from "URI scheme" at Wikipedia: +# https://secure.wikimedia.org/wikipedia/en/wiki/URI_scheme +# Many more will likely need to be added! +ACCEPTABLE_URI_SCHEMES = ( + 'file', 'ftp', 'gopher', 'h323', 'hdl', 'http', 'https', 'imap', 'mailto', + 'mms', 'news', 'nntp', 'prospero', 'rsync', 'rtsp', 'rtspu', 'sftp', + 'shttp', 'sip', 'sips', 'snews', 'svn', 'svn+ssh', 'telnet', 'wais', + # Additional common-but-unofficial schemes + 'aim', 'callto', 'cvs', 'facetime', 'feed', 'git', 'gtalk', 'irc', 'ircs', + 'irc6', 'itms', 'mms', 'msnim', 'skype', 'ssh', 'smb', 'svn', 'ymsg', +) +#ACCEPTABLE_URI_SCHEMES = () + +# ---------- required modules (should come with any Python distribution) ---------- +import sgmllib, re, sys, copy, urlparse, time, types, cgi, urllib, urllib2, datetime +try: + from io import BytesIO as _StringIO +except ImportError: + try: + from cStringIO import StringIO as _StringIO + except: + from StringIO import StringIO as _StringIO + +# ---------- optional modules (feedparser will work without these, but with reduced functionality) ---------- + +# gzip is included with most Python distributions, but may not be available if you compiled your own +try: + import gzip +except: + gzip = None +try: + import zlib +except: + zlib = None + +# If a real XML parser is available, feedparser will attempt to use it. feedparser has +# been tested with the built-in SAX parser, PyXML, and libxml2. On platforms where the +# Python distribution does not come with an XML parser (such as Mac OS X 10.2 and some +# versions of FreeBSD), feedparser will quietly fall back on regex-based parsing. +try: + import xml.sax + xml.sax.make_parser(PREFERRED_XML_PARSERS) # test for valid parsers + from xml.sax.saxutils import escape as _xmlescape + _XML_AVAILABLE = 1 +except: + _XML_AVAILABLE = 0 + def _xmlescape(data,entities={}): + data = data.replace('&', '&') + data = data.replace('>', '>') + data = data.replace('<', '<') + for char, entity in entities: + data = data.replace(char, entity) + return data + +# cjkcodecs and iconv_codec provide support for more character encodings. +# Both are available from http://cjkpython.i18n.org/ +try: + import cjkcodecs.aliases +except: + pass +try: + import iconv_codec +except: + pass + +# chardet library auto-detects character encodings +# Download from http://chardet.feedparser.org/ +try: + import chardet + if _debug: + import chardet.constants + chardet.constants._debug = 1 +except: + chardet = None + +# reversable htmlentitydefs mappings for Python 2.2 +try: + from htmlentitydefs import name2codepoint, codepoint2name +except: + import htmlentitydefs + name2codepoint={} + codepoint2name={} + for (name,codepoint) in htmlentitydefs.entitydefs.iteritems(): + if codepoint.startswith('&#'): codepoint=unichr(int(codepoint[2:-1])) + name2codepoint[name]=ord(codepoint) + codepoint2name[ord(codepoint)]=name + +# BeautifulSoup parser used for parsing microformats from embedded HTML content +# http://www.crummy.com/software/BeautifulSoup/ +# feedparser is tested with BeautifulSoup 3.0.x, but it might work with the +# older 2.x series. If it doesn't, and you can figure out why, I'll accept a +# patch and modify the compatibility statement accordingly. +try: + import BeautifulSoup +except: + BeautifulSoup = None + +# ---------- don't touch these ---------- +class ThingsNobodyCaresAboutButMe(Exception): pass +class CharacterEncodingOverride(ThingsNobodyCaresAboutButMe): pass +class CharacterEncodingUnknown(ThingsNobodyCaresAboutButMe): pass +class NonXMLContentType(ThingsNobodyCaresAboutButMe): pass +class UndeclaredNamespace(Exception): pass + +sgmllib.tagfind = re.compile('[a-zA-Z][-_.:a-zA-Z0-9]*') +sgmllib.special = re.compile(']|"[^"]*"(?=>|/|\s|\w+=)|'[^']*'(?=>|/|\s|\w+=))*(?=[<>])|.*?(?=[<>])''') + def search(self,string,index=0): + match = self.endbracket.match(string,index) + if match is not None: + # Returning a new object in the calling thread's context + # resolves a thread-safety. + return EndBracketMatch(match) + return None + class EndBracketMatch: + def __init__(self, match): + self.match = match + def start(self, n): + return self.match.end(n) + sgmllib.endbracket = EndBracketRegEx() + +SUPPORTED_VERSIONS = {'': 'unknown', + 'rss090': 'RSS 0.90', + 'rss091n': 'RSS 0.91 (Netscape)', + 'rss091u': 'RSS 0.91 (Userland)', + 'rss092': 'RSS 0.92', + 'rss093': 'RSS 0.93', + 'rss094': 'RSS 0.94', + 'rss20': 'RSS 2.0', + 'rss10': 'RSS 1.0', + 'rss': 'RSS (unknown version)', + 'atom01': 'Atom 0.1', + 'atom02': 'Atom 0.2', + 'atom03': 'Atom 0.3', + 'atom10': 'Atom 1.0', + 'atom': 'Atom (unknown version)', + 'cdf': 'CDF', + 'hotrss': 'Hot RSS' + } + +try: + UserDict = dict +except NameError: + # Python 2.1 does not have dict + from UserDict import UserDict + def dict(aList): + rc = {} + for k, v in aList: + rc[k] = v + return rc + +class FeedParserDict(UserDict): + keymap = {'channel': 'feed', + 'items': 'entries', + 'guid': 'id', + 'date': 'updated', + 'date_parsed': 'updated_parsed', + 'description': ['summary', 'subtitle'], + 'url': ['href'], + 'modified': 'updated', + 'modified_parsed': 'updated_parsed', + 'issued': 'published', + 'issued_parsed': 'published_parsed', + 'copyright': 'rights', + 'copyright_detail': 'rights_detail', + 'tagline': 'subtitle', + 'tagline_detail': 'subtitle_detail'} + def __getitem__(self, key): + if key == 'category': + return UserDict.__getitem__(self, 'tags')[0]['term'] + if key == 'enclosures': + norel = lambda link: FeedParserDict([(name,value) for (name,value) in link.items() if name!='rel']) + return [norel(link) for link in UserDict.__getitem__(self, 'links') if link['rel']=='enclosure'] + if key == 'license': + for link in UserDict.__getitem__(self, 'links'): + if link['rel']=='license' and link.has_key('href'): + return link['href'] + if key == 'categories': + return [(tag['scheme'], tag['term']) for tag in UserDict.__getitem__(self, 'tags')] + realkey = self.keymap.get(key, key) + if type(realkey) == types.ListType: + for k in realkey: + if UserDict.__contains__(self, k): + return UserDict.__getitem__(self, k) + if UserDict.__contains__(self, key): + return UserDict.__getitem__(self, key) + return UserDict.__getitem__(self, realkey) + + def __setitem__(self, key, value): + for k in self.keymap.keys(): + if key == k: + key = self.keymap[k] + if type(key) == types.ListType: + key = key[0] + return UserDict.__setitem__(self, key, value) + + def get(self, key, default=None): + if self.has_key(key): + return self[key] + else: + return default + + def setdefault(self, key, value): + if not self.has_key(key): + self[key] = value + return self[key] + + def has_key(self, key): + try: + return hasattr(self, key) or UserDict.__contains__(self, key) + except AttributeError: + return False + # This alias prevents the 2to3 tool from changing the semantics of the + # __contains__ function below and exhausting the maximum recursion depth + __has_key = has_key + + def __getattr__(self, key): + try: + return self.__dict__[key] + except KeyError: + pass + try: + assert not key.startswith('_') + return self.__getitem__(key) + except: + raise AttributeError, "object has no attribute '%s'" % key + + def __setattr__(self, key, value): + if key.startswith('_') or key == 'data': + self.__dict__[key] = value + else: + return self.__setitem__(key, value) + + def __contains__(self, key): + return self.__has_key(key) + +def zopeCompatibilityHack(): + global FeedParserDict + del FeedParserDict + def FeedParserDict(aDict=None): + rc = {} + if aDict: + rc.update(aDict) + return rc + +_ebcdic_to_ascii_map = None +def _ebcdic_to_ascii(s): + global _ebcdic_to_ascii_map + if not _ebcdic_to_ascii_map: + emap = ( + 0,1,2,3,156,9,134,127,151,141,142,11,12,13,14,15, + 16,17,18,19,157,133,8,135,24,25,146,143,28,29,30,31, + 128,129,130,131,132,10,23,27,136,137,138,139,140,5,6,7, + 144,145,22,147,148,149,150,4,152,153,154,155,20,21,158,26, + 32,160,161,162,163,164,165,166,167,168,91,46,60,40,43,33, + 38,169,170,171,172,173,174,175,176,177,93,36,42,41,59,94, + 45,47,178,179,180,181,182,183,184,185,124,44,37,95,62,63, + 186,187,188,189,190,191,192,193,194,96,58,35,64,39,61,34, + 195,97,98,99,100,101,102,103,104,105,196,197,198,199,200,201, + 202,106,107,108,109,110,111,112,113,114,203,204,205,206,207,208, + 209,126,115,116,117,118,119,120,121,122,210,211,212,213,214,215, + 216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231, + 123,65,66,67,68,69,70,71,72,73,232,233,234,235,236,237, + 125,74,75,76,77,78,79,80,81,82,238,239,240,241,242,243, + 92,159,83,84,85,86,87,88,89,90,244,245,246,247,248,249, + 48,49,50,51,52,53,54,55,56,57,250,251,252,253,254,255 + ) + _ebcdic_to_ascii_map = _maketrans( \ + _l2bytes(range(256)), _l2bytes(emap)) + return s.translate(_ebcdic_to_ascii_map) + +_cp1252 = { + unichr(128): unichr(8364), # euro sign + unichr(130): unichr(8218), # single low-9 quotation mark + unichr(131): unichr( 402), # latin small letter f with hook + unichr(132): unichr(8222), # double low-9 quotation mark + unichr(133): unichr(8230), # horizontal ellipsis + unichr(134): unichr(8224), # dagger + unichr(135): unichr(8225), # double dagger + unichr(136): unichr( 710), # modifier letter circumflex accent + unichr(137): unichr(8240), # per mille sign + unichr(138): unichr( 352), # latin capital letter s with caron + unichr(139): unichr(8249), # single left-pointing angle quotation mark + unichr(140): unichr( 338), # latin capital ligature oe + unichr(142): unichr( 381), # latin capital letter z with caron + unichr(145): unichr(8216), # left single quotation mark + unichr(146): unichr(8217), # right single quotation mark + unichr(147): unichr(8220), # left double quotation mark + unichr(148): unichr(8221), # right double quotation mark + unichr(149): unichr(8226), # bullet + unichr(150): unichr(8211), # en dash + unichr(151): unichr(8212), # em dash + unichr(152): unichr( 732), # small tilde + unichr(153): unichr(8482), # trade mark sign + unichr(154): unichr( 353), # latin small letter s with caron + unichr(155): unichr(8250), # single right-pointing angle quotation mark + unichr(156): unichr( 339), # latin small ligature oe + unichr(158): unichr( 382), # latin small letter z with caron + unichr(159): unichr( 376)} # latin capital letter y with diaeresis + +_urifixer = re.compile('^([A-Za-z][A-Za-z0-9+-.]*://)(/*)(.*?)') +def _urljoin(base, uri): + uri = _urifixer.sub(r'\1\3', uri) + try: + return urlparse.urljoin(base, uri) + except: + uri = urlparse.urlunparse([urllib.quote(part) for part in urlparse.urlparse(uri)]) + return urlparse.urljoin(base, uri) + +class _FeedParserMixin: + namespaces = {'': '', + 'http://backend.userland.com/rss': '', + 'http://blogs.law.harvard.edu/tech/rss': '', + 'http://purl.org/rss/1.0/': '', + 'http://my.netscape.com/rdf/simple/0.9/': '', + 'http://example.com/newformat#': '', + 'http://example.com/necho': '', + 'http://purl.org/echo/': '', + 'uri/of/echo/namespace#': '', + 'http://purl.org/pie/': '', + 'http://purl.org/atom/ns#': '', + 'http://www.w3.org/2005/Atom': '', + 'http://purl.org/rss/1.0/modules/rss091#': '', + + 'http://webns.net/mvcb/': 'admin', + 'http://purl.org/rss/1.0/modules/aggregation/': 'ag', + 'http://purl.org/rss/1.0/modules/annotate/': 'annotate', + 'http://media.tangent.org/rss/1.0/': 'audio', + 'http://backend.userland.com/blogChannelModule': 'blogChannel', + 'http://web.resource.org/cc/': 'cc', + 'http://backend.userland.com/creativeCommonsRssModule': 'creativeCommons', + 'http://purl.org/rss/1.0/modules/company': 'co', + 'http://purl.org/rss/1.0/modules/content/': 'content', + 'http://my.theinfo.org/changed/1.0/rss/': 'cp', + 'http://purl.org/dc/elements/1.1/': 'dc', + 'http://purl.org/dc/terms/': 'dcterms', + 'http://purl.org/rss/1.0/modules/email/': 'email', + 'http://purl.org/rss/1.0/modules/event/': 'ev', + 'http://rssnamespace.org/feedburner/ext/1.0': 'feedburner', + 'http://freshmeat.net/rss/fm/': 'fm', + 'http://xmlns.com/foaf/0.1/': 'foaf', + 'http://www.w3.org/2003/01/geo/wgs84_pos#': 'geo', + 'http://postneo.com/icbm/': 'icbm', + 'http://purl.org/rss/1.0/modules/image/': 'image', + 'http://www.itunes.com/DTDs/PodCast-1.0.dtd': 'itunes', + 'http://example.com/DTDs/PodCast-1.0.dtd': 'itunes', + 'http://purl.org/rss/1.0/modules/link/': 'l', + 'http://search.yahoo.com/mrss': 'media', + #Version 1.1.2 of the Media RSS spec added the trailing slash on the namespace + 'http://search.yahoo.com/mrss/': 'media', + 'http://madskills.com/public/xml/rss/module/pingback/': 'pingback', + 'http://prismstandard.org/namespaces/1.2/basic/': 'prism', + 'http://www.w3.org/1999/02/22-rdf-syntax-ns#': 'rdf', + 'http://www.w3.org/2000/01/rdf-schema#': 'rdfs', + 'http://purl.org/rss/1.0/modules/reference/': 'ref', + 'http://purl.org/rss/1.0/modules/richequiv/': 'reqv', + 'http://purl.org/rss/1.0/modules/search/': 'search', + 'http://purl.org/rss/1.0/modules/slash/': 'slash', + 'http://schemas.xmlsoap.org/soap/envelope/': 'soap', + 'http://purl.org/rss/1.0/modules/servicestatus/': 'ss', + 'http://hacks.benhammersley.com/rss/streaming/': 'str', + 'http://purl.org/rss/1.0/modules/subscription/': 'sub', + 'http://purl.org/rss/1.0/modules/syndication/': 'sy', + 'http://schemas.pocketsoap.com/rss/myDescModule/': 'szf', + 'http://purl.org/rss/1.0/modules/taxonomy/': 'taxo', + 'http://purl.org/rss/1.0/modules/threading/': 'thr', + 'http://purl.org/rss/1.0/modules/textinput/': 'ti', + 'http://madskills.com/public/xml/rss/module/trackback/':'trackback', + 'http://wellformedweb.org/commentAPI/': 'wfw', + 'http://purl.org/rss/1.0/modules/wiki/': 'wiki', + 'http://www.w3.org/1999/xhtml': 'xhtml', + 'http://www.w3.org/1999/xlink': 'xlink', + 'http://www.w3.org/XML/1998/namespace': 'xml' +} + _matchnamespaces = {} + + can_be_relative_uri = ['link', 'id', 'wfw_comment', 'wfw_commentrss', 'docs', 'url', 'href', 'comments', 'icon', 'logo'] + can_contain_relative_uris = ['content', 'title', 'summary', 'info', 'tagline', 'subtitle', 'copyright', 'rights', 'description'] + can_contain_dangerous_markup = ['content', 'title', 'summary', 'info', 'tagline', 'subtitle', 'copyright', 'rights', 'description'] + html_types = ['text/html', 'application/xhtml+xml'] + + def __init__(self, baseuri=None, baselang=None, encoding='utf-8'): + if _debug: sys.stderr.write('initializing FeedParser\n') + if not self._matchnamespaces: + for k, v in self.namespaces.items(): + self._matchnamespaces[k.lower()] = v + self.feeddata = FeedParserDict() # feed-level data + self.encoding = encoding # character encoding + self.entries = [] # list of entry-level data + self.version = '' # feed type/version, see SUPPORTED_VERSIONS + self.namespacesInUse = {} # dictionary of namespaces defined by the feed + + # the following are used internally to track state; + # this is really out of control and should be refactored + self.infeed = 0 + self.inentry = 0 + self.incontent = 0 + self.intextinput = 0 + self.inimage = 0 + self.inauthor = 0 + self.incontributor = 0 + self.inpublisher = 0 + self.insource = 0 + self.sourcedata = FeedParserDict() + self.contentparams = FeedParserDict() + self._summaryKey = None + self.namespacemap = {} + self.elementstack = [] + self.basestack = [] + self.langstack = [] + self.baseuri = baseuri or '' + self.lang = baselang or None + self.svgOK = 0 + self.hasTitle = 0 + if baselang: + self.feeddata['language'] = baselang.replace('_','-') + + def unknown_starttag(self, tag, attrs): + if _debug: sys.stderr.write('start %s with %s\n' % (tag, attrs)) + # normalize attrs + attrs = [(k.lower(), v) for k, v in attrs] + attrs = [(k, k in ('rel', 'type') and v.lower() or v) for k, v in attrs] + # the sgml parser doesn't handle entities in attributes, but + # strict xml parsers do -- account for this difference + if isinstance(self, _LooseFeedParser): + attrs = [(k, v.replace('&', '&')) for k, v in attrs] + + # track xml:base and xml:lang + attrsD = dict(attrs) + baseuri = attrsD.get('xml:base', attrsD.get('base')) or self.baseuri + if type(baseuri) != type(u''): + try: + baseuri = unicode(baseuri, self.encoding) + except: + baseuri = unicode(baseuri, 'iso-8859-1') + # ensure that self.baseuri is always an absolute URI that + # uses a whitelisted URI scheme (e.g. not `javscript:`) + if self.baseuri: + self.baseuri = _makeSafeAbsoluteURI(self.baseuri, baseuri) or self.baseuri + else: + self.baseuri = _urljoin(self.baseuri, baseuri) + lang = attrsD.get('xml:lang', attrsD.get('lang')) + if lang == '': + # xml:lang could be explicitly set to '', we need to capture that + lang = None + elif lang is None: + # if no xml:lang is specified, use parent lang + lang = self.lang + if lang: + if tag in ('feed', 'rss', 'rdf:RDF'): + self.feeddata['language'] = lang.replace('_','-') + self.lang = lang + self.basestack.append(self.baseuri) + self.langstack.append(lang) + + # track namespaces + for prefix, uri in attrs: + if prefix.startswith('xmlns:'): + self.trackNamespace(prefix[6:], uri) + elif prefix == 'xmlns': + self.trackNamespace(None, uri) + + # track inline content + if self.incontent and self.contentparams.has_key('type') and not self.contentparams.get('type', 'xml').endswith('xml'): + if tag in ['xhtml:div', 'div']: return # typepad does this 10/2007 + # element declared itself as escaped markup, but it isn't really + self.contentparams['type'] = 'application/xhtml+xml' + if self.incontent and self.contentparams.get('type') == 'application/xhtml+xml': + if tag.find(':') <> -1: + prefix, tag = tag.split(':', 1) + namespace = self.namespacesInUse.get(prefix, '') + if tag=='math' and namespace=='http://www.w3.org/1998/Math/MathML': + attrs.append(('xmlns',namespace)) + if tag=='svg' and namespace=='http://www.w3.org/2000/svg': + attrs.append(('xmlns',namespace)) + if tag == 'svg': self.svgOK += 1 + return self.handle_data('<%s%s>' % (tag, self.strattrs(attrs)), escape=0) + + # match namespaces + if tag.find(':') <> -1: + prefix, suffix = tag.split(':', 1) + else: + prefix, suffix = '', tag + prefix = self.namespacemap.get(prefix, prefix) + if prefix: + prefix = prefix + '_' + + # special hack for better tracking of empty textinput/image elements in illformed feeds + if (not prefix) and tag not in ('title', 'link', 'description', 'name'): + self.intextinput = 0 + if (not prefix) and tag not in ('title', 'link', 'description', 'url', 'href', 'width', 'height'): + self.inimage = 0 + + # call special handler (if defined) or default handler + methodname = '_start_' + prefix + suffix + try: + method = getattr(self, methodname) + return method(attrsD) + except AttributeError: + # Since there's no handler or something has gone wrong we explicitly add the element and its attributes + unknown_tag = prefix + suffix + if len(attrsD) == 0: + # No attributes so merge it into the encosing dictionary + return self.push(unknown_tag, 1) + else: + # Has attributes so create it in its own dictionary + context = self._getContext() + context[unknown_tag] = attrsD + + def unknown_endtag(self, tag): + if _debug: sys.stderr.write('end %s\n' % tag) + # match namespaces + if tag.find(':') <> -1: + prefix, suffix = tag.split(':', 1) + else: + prefix, suffix = '', tag + prefix = self.namespacemap.get(prefix, prefix) + if prefix: + prefix = prefix + '_' + if suffix == 'svg' and self.svgOK: self.svgOK -= 1 + + # call special handler (if defined) or default handler + methodname = '_end_' + prefix + suffix + try: + if self.svgOK: raise AttributeError() + method = getattr(self, methodname) + method() + except AttributeError: + self.pop(prefix + suffix) + + # track inline content + if self.incontent and self.contentparams.has_key('type') and not self.contentparams.get('type', 'xml').endswith('xml'): + # element declared itself as escaped markup, but it isn't really + if tag in ['xhtml:div', 'div']: return # typepad does this 10/2007 + self.contentparams['type'] = 'application/xhtml+xml' + if self.incontent and self.contentparams.get('type') == 'application/xhtml+xml': + tag = tag.split(':')[-1] + self.handle_data('' % tag, escape=0) + + # track xml:base and xml:lang going out of scope + if self.basestack: + self.basestack.pop() + if self.basestack and self.basestack[-1]: + self.baseuri = self.basestack[-1] + if self.langstack: + self.langstack.pop() + if self.langstack: # and (self.langstack[-1] is not None): + self.lang = self.langstack[-1] + + def handle_charref(self, ref): + # called for each character reference, e.g. for ' ', ref will be '160' + if not self.elementstack: return + ref = ref.lower() + if ref in ('34', '38', '39', '60', '62', 'x22', 'x26', 'x27', 'x3c', 'x3e'): + text = '&#%s;' % ref + else: + if ref[0] == 'x': + c = int(ref[1:], 16) + else: + c = int(ref) + text = unichr(c).encode('utf-8') + self.elementstack[-1][2].append(text) + + def handle_entityref(self, ref): + # called for each entity reference, e.g. for '©', ref will be 'copy' + if not self.elementstack: return + if _debug: sys.stderr.write('entering handle_entityref with %s\n' % ref) + if ref in ('lt', 'gt', 'quot', 'amp', 'apos'): + text = '&%s;' % ref + elif ref in self.entities.keys(): + text = self.entities[ref] + if text.startswith('&#') and text.endswith(';'): + return self.handle_entityref(text) + else: + try: name2codepoint[ref] + except KeyError: text = '&%s;' % ref + else: text = unichr(name2codepoint[ref]).encode('utf-8') + self.elementstack[-1][2].append(text) + + def handle_data(self, text, escape=1): + # called for each block of plain text, i.e. outside of any tag and + # not containing any character or entity references + if not self.elementstack: return + if escape and self.contentparams.get('type') == 'application/xhtml+xml': + text = _xmlescape(text) + self.elementstack[-1][2].append(text) + + def handle_comment(self, text): + # called for each comment, e.g. + pass + + def handle_pi(self, text): + # called for each processing instruction, e.g. + pass + + def handle_decl(self, text): + pass + + def parse_declaration(self, i): + # override internal declaration handler to handle CDATA blocks + if _debug: sys.stderr.write('entering parse_declaration\n') + if self.rawdata[i:i+9] == '', i) + if k == -1: + # CDATA block began but didn't finish + k = len(self.rawdata) + return k + self.handle_data(_xmlescape(self.rawdata[i+9:k]), 0) + return k+3 + else: + k = self.rawdata.find('>', i) + if k >= 0: + return k+1 + else: + # We have an incomplete CDATA block. + return k + + def mapContentType(self, contentType): + contentType = contentType.lower() + if contentType == 'text' or contentType == 'plain': + contentType = 'text/plain' + elif contentType == 'html': + contentType = 'text/html' + elif contentType == 'xhtml': + contentType = 'application/xhtml+xml' + return contentType + + def trackNamespace(self, prefix, uri): + loweruri = uri.lower() + if (prefix, loweruri) == (None, 'http://my.netscape.com/rdf/simple/0.9/') and not self.version: + self.version = 'rss090' + if loweruri == 'http://purl.org/rss/1.0/' and not self.version: + self.version = 'rss10' + if loweruri == 'http://www.w3.org/2005/atom' and not self.version: + self.version = 'atom10' + if loweruri.find('backend.userland.com/rss') <> -1: + # match any backend.userland.com namespace + uri = 'http://backend.userland.com/rss' + loweruri = uri + if self._matchnamespaces.has_key(loweruri): + self.namespacemap[prefix] = self._matchnamespaces[loweruri] + self.namespacesInUse[self._matchnamespaces[loweruri]] = uri + else: + self.namespacesInUse[prefix or ''] = uri + + def resolveURI(self, uri): + return _urljoin(self.baseuri or '', uri) + + def decodeEntities(self, element, data): + return data + + def strattrs(self, attrs): + return ''.join([' %s="%s"' % (t[0],_xmlescape(t[1],{'"':'"'})) for t in attrs]) + + def push(self, element, expectingText): + self.elementstack.append([element, expectingText, []]) + + def pop(self, element, stripWhitespace=1): + if not self.elementstack: return + if self.elementstack[-1][0] != element: return + + element, expectingText, pieces = self.elementstack.pop() + + if self.version == 'atom10' and self.contentparams.get('type','text') == 'application/xhtml+xml': + # remove enclosing child element, but only if it is a
and + # only if all the remaining content is nested underneath it. + # This means that the divs would be retained in the following: + #
foo
bar
+ while pieces and len(pieces)>1 and not pieces[-1].strip(): + del pieces[-1] + while pieces and len(pieces)>1 and not pieces[0].strip(): + del pieces[0] + if pieces and (pieces[0] == '
' or pieces[0].startswith('
': + depth = 0 + for piece in pieces[:-1]: + if piece.startswith(''): + depth += 1 + else: + pieces = pieces[1:-1] + + # Ensure each piece is a str for Python 3 + for (i, v) in enumerate(pieces): + if not isinstance(v, basestring): + pieces[i] = v.decode('utf-8') + + output = ''.join(pieces) + if stripWhitespace: + output = output.strip() + if not expectingText: return output + + # decode base64 content + if base64 and self.contentparams.get('base64', 0): + try: + output = _base64decode(output) + except binascii.Error: + pass + except binascii.Incomplete: + pass + except TypeError: + # In Python 3, base64 takes and outputs bytes, not str + # This may not be the most correct way to accomplish this + output = _base64decode(output.encode('utf-8')).decode('utf-8') + + # resolve relative URIs + if (element in self.can_be_relative_uri) and output: + output = self.resolveURI(output) + + # decode entities within embedded markup + if not self.contentparams.get('base64', 0): + output = self.decodeEntities(element, output) + + if self.lookslikehtml(output): + self.contentparams['type']='text/html' + + # remove temporary cruft from contentparams + try: + del self.contentparams['mode'] + except KeyError: + pass + try: + del self.contentparams['base64'] + except KeyError: + pass + + is_htmlish = self.mapContentType(self.contentparams.get('type', 'text/html')) in self.html_types + # resolve relative URIs within embedded markup + if is_htmlish and RESOLVE_RELATIVE_URIS: + if element in self.can_contain_relative_uris: + output = _resolveRelativeURIs(output, self.baseuri, self.encoding, self.contentparams.get('type', 'text/html')) + + # parse microformats + # (must do this before sanitizing because some microformats + # rely on elements that we sanitize) + if is_htmlish and element in ['content', 'description', 'summary']: + mfresults = _parseMicroformats(output, self.baseuri, self.encoding) + if mfresults: + for tag in mfresults.get('tags', []): + self._addTag(tag['term'], tag['scheme'], tag['label']) + for enclosure in mfresults.get('enclosures', []): + self._start_enclosure(enclosure) + for xfn in mfresults.get('xfn', []): + self._addXFN(xfn['relationships'], xfn['href'], xfn['name']) + vcard = mfresults.get('vcard') + if vcard: + self._getContext()['vcard'] = vcard + + # sanitize embedded markup + if is_htmlish and SANITIZE_HTML: + if element in self.can_contain_dangerous_markup: + output = _sanitizeHTML(output, self.encoding, self.contentparams.get('type', 'text/html')) + + if self.encoding and type(output) != type(u''): + try: + output = unicode(output, self.encoding) + except: + pass + + # address common error where people take data that is already + # utf-8, presume that it is iso-8859-1, and re-encode it. + if self.encoding in ('utf-8', 'utf-8_INVALID_PYTHON_3') and type(output) == type(u''): + try: + output = unicode(output.encode('iso-8859-1'), 'utf-8') + except: + pass + + # map win-1252 extensions to the proper code points + if type(output) == type(u''): + output = u''.join([c in _cp1252.keys() and _cp1252[c] or c for c in output]) + + # categories/tags/keywords/whatever are handled in _end_category + if element == 'category': + return output + + if element == 'title' and self.hasTitle: + return output + + # store output in appropriate place(s) + if self.inentry and not self.insource: + if element == 'content': + self.entries[-1].setdefault(element, []) + contentparams = copy.deepcopy(self.contentparams) + contentparams['value'] = output + self.entries[-1][element].append(contentparams) + elif element == 'link': + if not self.inimage: + # query variables in urls in link elements are improperly + # converted from `?a=1&b=2` to `?a=1&b;=2` as if they're + # unhandled character references. fix this special case. + output = re.sub("&([A-Za-z0-9_]+);", "&\g<1>", output) + self.entries[-1][element] = output + if output: + self.entries[-1]['links'][-1]['href'] = output + else: + if element == 'description': + element = 'summary' + self.entries[-1][element] = output + if self.incontent: + contentparams = copy.deepcopy(self.contentparams) + contentparams['value'] = output + self.entries[-1][element + '_detail'] = contentparams + elif (self.infeed or self.insource):# and (not self.intextinput) and (not self.inimage): + context = self._getContext() + if element == 'description': + element = 'subtitle' + context[element] = output + if element == 'link': + # fix query variables; see above for the explanation + output = re.sub("&([A-Za-z0-9_]+);", "&\g<1>", output) + context[element] = output + context['links'][-1]['href'] = output + elif self.incontent: + contentparams = copy.deepcopy(self.contentparams) + contentparams['value'] = output + context[element + '_detail'] = contentparams + return output + + def pushContent(self, tag, attrsD, defaultContentType, expectingText): + self.incontent += 1 + if self.lang: self.lang=self.lang.replace('_','-') + self.contentparams = FeedParserDict({ + 'type': self.mapContentType(attrsD.get('type', defaultContentType)), + 'language': self.lang, + 'base': self.baseuri}) + self.contentparams['base64'] = self._isBase64(attrsD, self.contentparams) + self.push(tag, expectingText) + + def popContent(self, tag): + value = self.pop(tag) + self.incontent -= 1 + self.contentparams.clear() + return value + + # a number of elements in a number of RSS variants are nominally plain + # text, but this is routinely ignored. This is an attempt to detect + # the most common cases. As false positives often result in silent + # data loss, this function errs on the conservative side. + def lookslikehtml(self, s): + if self.version.startswith('atom'): return + if self.contentparams.get('type','text/html') != 'text/plain': return + + # must have a close tag or a entity reference to qualify + if not (re.search(r'',s) or re.search("&#?\w+;",s)): return + + # all tags must be in a restricted subset of valid HTML tags + if filter(lambda t: t.lower() not in _HTMLSanitizer.acceptable_elements, + re.findall(r' -1: + prefix = name[:colonpos] + suffix = name[colonpos+1:] + prefix = self.namespacemap.get(prefix, prefix) + name = prefix + ':' + suffix + return name + + def _getAttribute(self, attrsD, name): + return attrsD.get(self._mapToStandardPrefix(name)) + + def _isBase64(self, attrsD, contentparams): + if attrsD.get('mode', '') == 'base64': + return 1 + if self.contentparams['type'].startswith('text/'): + return 0 + if self.contentparams['type'].endswith('+xml'): + return 0 + if self.contentparams['type'].endswith('/xml'): + return 0 + return 1 + + def _itsAnHrefDamnIt(self, attrsD): + href = attrsD.get('url', attrsD.get('uri', attrsD.get('href', None))) + if href: + try: + del attrsD['url'] + except KeyError: + pass + try: + del attrsD['uri'] + except KeyError: + pass + attrsD['href'] = href + return attrsD + + def _save(self, key, value, overwrite=False): + context = self._getContext() + if overwrite: + context[key] = value + else: + context.setdefault(key, value) + + def _start_rss(self, attrsD): + versionmap = {'0.91': 'rss091u', + '0.92': 'rss092', + '0.93': 'rss093', + '0.94': 'rss094'} + #If we're here then this is an RSS feed. + #If we don't have a version or have a version that starts with something + #other than RSS then there's been a mistake. Correct it. + if not self.version or not self.version.startswith('rss'): + attr_version = attrsD.get('version', '') + version = versionmap.get(attr_version) + if version: + self.version = version + elif attr_version.startswith('2.'): + self.version = 'rss20' + else: + self.version = 'rss' + + def _start_dlhottitles(self, attrsD): + self.version = 'hotrss' + + def _start_channel(self, attrsD): + self.infeed = 1 + self._cdf_common(attrsD) + _start_feedinfo = _start_channel + + def _cdf_common(self, attrsD): + if attrsD.has_key('lastmod'): + self._start_modified({}) + self.elementstack[-1][-1] = attrsD['lastmod'] + self._end_modified() + if attrsD.has_key('href'): + self._start_link({}) + self.elementstack[-1][-1] = attrsD['href'] + self._end_link() + + def _start_feed(self, attrsD): + self.infeed = 1 + versionmap = {'0.1': 'atom01', + '0.2': 'atom02', + '0.3': 'atom03'} + if not self.version: + attr_version = attrsD.get('version') + version = versionmap.get(attr_version) + if version: + self.version = version + else: + self.version = 'atom' + + def _end_channel(self): + self.infeed = 0 + _end_feed = _end_channel + + def _start_image(self, attrsD): + context = self._getContext() + if not self.inentry: + context.setdefault('image', FeedParserDict()) + self.inimage = 1 + self.hasTitle = 0 + self.push('image', 0) + + def _end_image(self): + self.pop('image') + self.inimage = 0 + + def _start_textinput(self, attrsD): + context = self._getContext() + context.setdefault('textinput', FeedParserDict()) + self.intextinput = 1 + self.hasTitle = 0 + self.push('textinput', 0) + _start_textInput = _start_textinput + + def _end_textinput(self): + self.pop('textinput') + self.intextinput = 0 + _end_textInput = _end_textinput + + def _start_author(self, attrsD): + self.inauthor = 1 + self.push('author', 1) + # Append a new FeedParserDict when expecting an author + context = self._getContext() + context.setdefault('authors', []) + context['authors'].append(FeedParserDict()) + _start_managingeditor = _start_author + _start_dc_author = _start_author + _start_dc_creator = _start_author + _start_itunes_author = _start_author + + def _end_author(self): + self.pop('author') + self.inauthor = 0 + self._sync_author_detail() + _end_managingeditor = _end_author + _end_dc_author = _end_author + _end_dc_creator = _end_author + _end_itunes_author = _end_author + + def _start_itunes_owner(self, attrsD): + self.inpublisher = 1 + self.push('publisher', 0) + + def _end_itunes_owner(self): + self.pop('publisher') + self.inpublisher = 0 + self._sync_author_detail('publisher') + + def _start_contributor(self, attrsD): + self.incontributor = 1 + context = self._getContext() + context.setdefault('contributors', []) + context['contributors'].append(FeedParserDict()) + self.push('contributor', 0) + + def _end_contributor(self): + self.pop('contributor') + self.incontributor = 0 + + def _start_dc_contributor(self, attrsD): + self.incontributor = 1 + context = self._getContext() + context.setdefault('contributors', []) + context['contributors'].append(FeedParserDict()) + self.push('name', 0) + + def _end_dc_contributor(self): + self._end_name() + self.incontributor = 0 + + def _start_name(self, attrsD): + self.push('name', 0) + _start_itunes_name = _start_name + + def _end_name(self): + value = self.pop('name') + if self.inpublisher: + self._save_author('name', value, 'publisher') + elif self.inauthor: + self._save_author('name', value) + elif self.incontributor: + self._save_contributor('name', value) + elif self.intextinput: + context = self._getContext() + context['name'] = value + _end_itunes_name = _end_name + + def _start_width(self, attrsD): + self.push('width', 0) + + def _end_width(self): + value = self.pop('width') + try: + value = int(value) + except: + value = 0 + if self.inimage: + context = self._getContext() + context['width'] = value + + def _start_height(self, attrsD): + self.push('height', 0) + + def _end_height(self): + value = self.pop('height') + try: + value = int(value) + except: + value = 0 + if self.inimage: + context = self._getContext() + context['height'] = value + + def _start_url(self, attrsD): + self.push('href', 1) + _start_homepage = _start_url + _start_uri = _start_url + + def _end_url(self): + value = self.pop('href') + if self.inauthor: + self._save_author('href', value) + elif self.incontributor: + self._save_contributor('href', value) + _end_homepage = _end_url + _end_uri = _end_url + + def _start_email(self, attrsD): + self.push('email', 0) + _start_itunes_email = _start_email + + def _end_email(self): + value = self.pop('email') + if self.inpublisher: + self._save_author('email', value, 'publisher') + elif self.inauthor: + self._save_author('email', value) + elif self.incontributor: + self._save_contributor('email', value) + _end_itunes_email = _end_email + + def _getContext(self): + if self.insource: + context = self.sourcedata + elif self.inimage and self.feeddata.has_key('image'): + context = self.feeddata['image'] + elif self.intextinput: + context = self.feeddata['textinput'] + elif self.inentry: + context = self.entries[-1] + else: + context = self.feeddata + return context + + def _save_author(self, key, value, prefix='author'): + context = self._getContext() + context.setdefault(prefix + '_detail', FeedParserDict()) + context[prefix + '_detail'][key] = value + self._sync_author_detail() + context.setdefault('authors', [FeedParserDict()]) + context['authors'][-1][key] = value + + def _save_contributor(self, key, value): + context = self._getContext() + context.setdefault('contributors', [FeedParserDict()]) + context['contributors'][-1][key] = value + + def _sync_author_detail(self, key='author'): + context = self._getContext() + detail = context.get('%s_detail' % key) + if detail: + name = detail.get('name') + email = detail.get('email') + if name and email: + context[key] = '%s (%s)' % (name, email) + elif name: + context[key] = name + elif email: + context[key] = email + else: + author, email = context.get(key), None + if not author: return + emailmatch = re.search(r'''(([a-zA-Z0-9\_\-\.\+]+)@((\[[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.)|(([a-zA-Z0-9\-]+\.)+))([a-zA-Z]{2,4}|[0-9]{1,3})(\]?))(\?subject=\S+)?''', author) + if emailmatch: + email = emailmatch.group(0) + # probably a better way to do the following, but it passes all the tests + author = author.replace(email, '') + author = author.replace('()', '') + author = author.replace('<>', '') + author = author.replace('<>', '') + author = author.strip() + if author and (author[0] == '('): + author = author[1:] + if author and (author[-1] == ')'): + author = author[:-1] + author = author.strip() + if author or email: + context.setdefault('%s_detail' % key, FeedParserDict()) + if author: + context['%s_detail' % key]['name'] = author + if email: + context['%s_detail' % key]['email'] = email + + def _start_subtitle(self, attrsD): + self.pushContent('subtitle', attrsD, 'text/plain', 1) + _start_tagline = _start_subtitle + _start_itunes_subtitle = _start_subtitle + + def _end_subtitle(self): + self.popContent('subtitle') + _end_tagline = _end_subtitle + _end_itunes_subtitle = _end_subtitle + + def _start_rights(self, attrsD): + self.pushContent('rights', attrsD, 'text/plain', 1) + _start_dc_rights = _start_rights + _start_copyright = _start_rights + + def _end_rights(self): + self.popContent('rights') + _end_dc_rights = _end_rights + _end_copyright = _end_rights + + def _start_item(self, attrsD): + self.entries.append(FeedParserDict()) + self.push('item', 0) + self.inentry = 1 + self.guidislink = 0 + self.hasTitle = 0 + id = self._getAttribute(attrsD, 'rdf:about') + if id: + context = self._getContext() + context['id'] = id + self._cdf_common(attrsD) + _start_entry = _start_item + _start_product = _start_item + + def _end_item(self): + self.pop('item') + self.inentry = 0 + _end_entry = _end_item + + def _start_dc_language(self, attrsD): + self.push('language', 1) + _start_language = _start_dc_language + + def _end_dc_language(self): + self.lang = self.pop('language') + _end_language = _end_dc_language + + def _start_dc_publisher(self, attrsD): + self.push('publisher', 1) + _start_webmaster = _start_dc_publisher + + def _end_dc_publisher(self): + self.pop('publisher') + self._sync_author_detail('publisher') + _end_webmaster = _end_dc_publisher + + def _start_published(self, attrsD): + self.push('published', 1) + _start_dcterms_issued = _start_published + _start_issued = _start_published + + def _end_published(self): + value = self.pop('published') + self._save('published_parsed', _parse_date(value), overwrite=True) + _end_dcterms_issued = _end_published + _end_issued = _end_published + + def _start_updated(self, attrsD): + self.push('updated', 1) + _start_modified = _start_updated + _start_dcterms_modified = _start_updated + _start_pubdate = _start_updated + _start_dc_date = _start_updated + _start_lastbuilddate = _start_updated + + def _end_updated(self): + value = self.pop('updated') + parsed_value = _parse_date(value) + self._save('updated_parsed', parsed_value, overwrite=True) + _end_modified = _end_updated + _end_dcterms_modified = _end_updated + _end_pubdate = _end_updated + _end_dc_date = _end_updated + _end_lastbuilddate = _end_updated + + def _start_created(self, attrsD): + self.push('created', 1) + _start_dcterms_created = _start_created + + def _end_created(self): + value = self.pop('created') + self._save('created_parsed', _parse_date(value), overwrite=True) + _end_dcterms_created = _end_created + + def _start_expirationdate(self, attrsD): + self.push('expired', 1) + + def _end_expirationdate(self): + self._save('expired_parsed', _parse_date(self.pop('expired')), overwrite=True) + + def _start_cc_license(self, attrsD): + context = self._getContext() + value = self._getAttribute(attrsD, 'rdf:resource') + attrsD = FeedParserDict() + attrsD['rel']='license' + if value: attrsD['href']=value + context.setdefault('links', []).append(attrsD) + + def _start_creativecommons_license(self, attrsD): + self.push('license', 1) + _start_creativeCommons_license = _start_creativecommons_license + + def _end_creativecommons_license(self): + value = self.pop('license') + context = self._getContext() + attrsD = FeedParserDict() + attrsD['rel']='license' + if value: attrsD['href']=value + context.setdefault('links', []).append(attrsD) + del context['license'] + _end_creativeCommons_license = _end_creativecommons_license + + def _addXFN(self, relationships, href, name): + context = self._getContext() + xfn = context.setdefault('xfn', []) + value = FeedParserDict({'relationships': relationships, 'href': href, 'name': name}) + if value not in xfn: + xfn.append(value) + + def _addTag(self, term, scheme, label): + context = self._getContext() + tags = context.setdefault('tags', []) + if (not term) and (not scheme) and (not label): return + value = FeedParserDict({'term': term, 'scheme': scheme, 'label': label}) + if value not in tags: + tags.append(value) + + def _start_category(self, attrsD): + if _debug: sys.stderr.write('entering _start_category with %s\n' % repr(attrsD)) + term = attrsD.get('term') + scheme = attrsD.get('scheme', attrsD.get('domain')) + label = attrsD.get('label') + self._addTag(term, scheme, label) + self.push('category', 1) + _start_dc_subject = _start_category + _start_keywords = _start_category + + def _start_media_category(self, attrsD): + attrsD.setdefault('scheme', 'http://search.yahoo.com/mrss/category_schema') + self._start_category(attrsD) + + def _end_itunes_keywords(self): + for term in self.pop('itunes_keywords').split(): + self._addTag(term, 'http://www.itunes.com/', None) + + def _start_itunes_category(self, attrsD): + self._addTag(attrsD.get('text'), 'http://www.itunes.com/', None) + self.push('category', 1) + + def _end_category(self): + value = self.pop('category') + if not value: return + context = self._getContext() + tags = context['tags'] + if value and len(tags) and not tags[-1]['term']: + tags[-1]['term'] = value + else: + self._addTag(value, None, None) + _end_dc_subject = _end_category + _end_keywords = _end_category + _end_itunes_category = _end_category + _end_media_category = _end_category + + def _start_cloud(self, attrsD): + self._getContext()['cloud'] = FeedParserDict(attrsD) + + def _start_link(self, attrsD): + attrsD.setdefault('rel', 'alternate') + if attrsD['rel'] == 'self': + attrsD.setdefault('type', 'application/atom+xml') + else: + attrsD.setdefault('type', 'text/html') + context = self._getContext() + attrsD = self._itsAnHrefDamnIt(attrsD) + if attrsD.has_key('href'): + attrsD['href'] = self.resolveURI(attrsD['href']) + expectingText = self.infeed or self.inentry or self.insource + context.setdefault('links', []) + if not (self.inentry and self.inimage): + context['links'].append(FeedParserDict(attrsD)) + if attrsD.has_key('href'): + expectingText = 0 + if (attrsD.get('rel') == 'alternate') and (self.mapContentType(attrsD.get('type')) in self.html_types): + context['link'] = attrsD['href'] + else: + self.push('link', expectingText) + _start_producturl = _start_link + + def _end_link(self): + value = self.pop('link') + context = self._getContext() + _end_producturl = _end_link + + def _start_guid(self, attrsD): + self.guidislink = (attrsD.get('ispermalink', 'true') == 'true') + self.push('id', 1) + + def _end_guid(self): + value = self.pop('id') + self._save('guidislink', self.guidislink and not self._getContext().has_key('link')) + if self.guidislink: + # guid acts as link, but only if 'ispermalink' is not present or is 'true', + # and only if the item doesn't already have a link element + self._save('link', value) + + def _start_title(self, attrsD): + if self.svgOK: return self.unknown_starttag('title', attrsD.items()) + self.pushContent('title', attrsD, 'text/plain', self.infeed or self.inentry or self.insource) + _start_dc_title = _start_title + _start_media_title = _start_title + + def _end_title(self): + if self.svgOK: return + value = self.popContent('title') + if not value: return + context = self._getContext() + self.hasTitle = 1 + _end_dc_title = _end_title + + def _end_media_title(self): + hasTitle = self.hasTitle + self._end_title() + self.hasTitle = hasTitle + + def _start_description(self, attrsD): + context = self._getContext() + if context.has_key('summary'): + self._summaryKey = 'content' + self._start_content(attrsD) + else: + self.pushContent('description', attrsD, 'text/html', self.infeed or self.inentry or self.insource) + _start_dc_description = _start_description + + def _start_abstract(self, attrsD): + self.pushContent('description', attrsD, 'text/plain', self.infeed or self.inentry or self.insource) + + def _end_description(self): + if self._summaryKey == 'content': + self._end_content() + else: + value = self.popContent('description') + self._summaryKey = None + _end_abstract = _end_description + _end_dc_description = _end_description + + def _start_info(self, attrsD): + self.pushContent('info', attrsD, 'text/plain', 1) + _start_feedburner_browserfriendly = _start_info + + def _end_info(self): + self.popContent('info') + _end_feedburner_browserfriendly = _end_info + + def _start_generator(self, attrsD): + if attrsD: + attrsD = self._itsAnHrefDamnIt(attrsD) + if attrsD.has_key('href'): + attrsD['href'] = self.resolveURI(attrsD['href']) + self._getContext()['generator_detail'] = FeedParserDict(attrsD) + self.push('generator', 1) + + def _end_generator(self): + value = self.pop('generator') + context = self._getContext() + if context.has_key('generator_detail'): + context['generator_detail']['name'] = value + + def _start_admin_generatoragent(self, attrsD): + self.push('generator', 1) + value = self._getAttribute(attrsD, 'rdf:resource') + if value: + self.elementstack[-1][2].append(value) + self.pop('generator') + self._getContext()['generator_detail'] = FeedParserDict({'href': value}) + + def _start_admin_errorreportsto(self, attrsD): + self.push('errorreportsto', 1) + value = self._getAttribute(attrsD, 'rdf:resource') + if value: + self.elementstack[-1][2].append(value) + self.pop('errorreportsto') + + def _start_summary(self, attrsD): + context = self._getContext() + if context.has_key('summary'): + self._summaryKey = 'content' + self._start_content(attrsD) + else: + self._summaryKey = 'summary' + self.pushContent(self._summaryKey, attrsD, 'text/plain', 1) + _start_itunes_summary = _start_summary + + def _end_summary(self): + if self._summaryKey == 'content': + self._end_content() + else: + self.popContent(self._summaryKey or 'summary') + self._summaryKey = None + _end_itunes_summary = _end_summary + + def _start_enclosure(self, attrsD): + attrsD = self._itsAnHrefDamnIt(attrsD) + context = self._getContext() + attrsD['rel']='enclosure' + context.setdefault('links', []).append(FeedParserDict(attrsD)) + + def _start_source(self, attrsD): + if 'url' in attrsD: + # This means that we're processing a source element from an RSS 2.0 feed + self.sourcedata['href'] = attrsD[u'url'] + self.push('source', 1) + self.insource = 1 + self.hasTitle = 0 + + def _end_source(self): + self.insource = 0 + value = self.pop('source') + if value: + self.sourcedata['title'] = value + self._getContext()['source'] = copy.deepcopy(self.sourcedata) + self.sourcedata.clear() + + def _start_content(self, attrsD): + self.pushContent('content', attrsD, 'text/plain', 1) + src = attrsD.get('src') + if src: + self.contentparams['src'] = src + self.push('content', 1) + + def _start_prodlink(self, attrsD): + self.pushContent('content', attrsD, 'text/html', 1) + + def _start_body(self, attrsD): + self.pushContent('content', attrsD, 'application/xhtml+xml', 1) + _start_xhtml_body = _start_body + + def _start_content_encoded(self, attrsD): + self.pushContent('content', attrsD, 'text/html', 1) + _start_fullitem = _start_content_encoded + + def _end_content(self): + copyToSummary = self.mapContentType(self.contentparams.get('type')) in (['text/plain'] + self.html_types) + value = self.popContent('content') + if copyToSummary: + self._save('summary', value) + + _end_body = _end_content + _end_xhtml_body = _end_content + _end_content_encoded = _end_content + _end_fullitem = _end_content + _end_prodlink = _end_content + + def _start_itunes_image(self, attrsD): + self.push('itunes_image', 0) + if attrsD.get('href'): + self._getContext()['image'] = FeedParserDict({'href': attrsD.get('href')}) + _start_itunes_link = _start_itunes_image + + def _end_itunes_block(self): + value = self.pop('itunes_block', 0) + self._getContext()['itunes_block'] = (value == 'yes') and 1 or 0 + + def _end_itunes_explicit(self): + value = self.pop('itunes_explicit', 0) + # Convert 'yes' -> True, 'clean' to False, and any other value to None + # False and None both evaluate as False, so the difference can be ignored + # by applications that only need to know if the content is explicit. + self._getContext()['itunes_explicit'] = (None, False, True)[(value == 'yes' and 2) or value == 'clean' or 0] + + def _start_media_content(self, attrsD): + context = self._getContext() + context.setdefault('media_content', []) + context['media_content'].append(attrsD) + + def _start_media_thumbnail(self, attrsD): + context = self._getContext() + context.setdefault('media_thumbnail', []) + self.push('url', 1) # new + context['media_thumbnail'].append(attrsD) + + def _end_media_thumbnail(self): + url = self.pop('url') + context = self._getContext() + if url != None and len(url.strip()) != 0: + if not context['media_thumbnail'][-1].has_key('url'): + context['media_thumbnail'][-1]['url'] = url + + def _start_media_player(self, attrsD): + self.push('media_player', 0) + self._getContext()['media_player'] = FeedParserDict(attrsD) + + def _end_media_player(self): + value = self.pop('media_player') + context = self._getContext() + context['media_player']['content'] = value + + def _start_newlocation(self, attrsD): + self.push('newlocation', 1) + + def _end_newlocation(self): + url = self.pop('newlocation') + context = self._getContext() + # don't set newlocation if the context isn't right + if context is not self.feeddata: + return + context['newlocation'] = _makeSafeAbsoluteURI(self.baseuri, url.strip()) + +if _XML_AVAILABLE: + class _StrictFeedParser(_FeedParserMixin, xml.sax.handler.ContentHandler): + def __init__(self, baseuri, baselang, encoding): + if _debug: sys.stderr.write('trying StrictFeedParser\n') + xml.sax.handler.ContentHandler.__init__(self) + _FeedParserMixin.__init__(self, baseuri, baselang, encoding) + self.bozo = 0 + self.exc = None + self.decls = {} + + def startPrefixMapping(self, prefix, uri): + self.trackNamespace(prefix, uri) + if uri == 'http://www.w3.org/1999/xlink': + self.decls['xmlns:'+prefix] = uri + + def startElementNS(self, name, qname, attrs): + namespace, localname = name + lowernamespace = str(namespace or '').lower() + if lowernamespace.find('backend.userland.com/rss') <> -1: + # match any backend.userland.com namespace + namespace = 'http://backend.userland.com/rss' + lowernamespace = namespace + if qname and qname.find(':') > 0: + givenprefix = qname.split(':')[0] + else: + givenprefix = None + prefix = self._matchnamespaces.get(lowernamespace, givenprefix) + if givenprefix and (prefix is None or (prefix == '' and lowernamespace == '')) and not self.namespacesInUse.has_key(givenprefix): + raise UndeclaredNamespace, "'%s' is not associated with a namespace" % givenprefix + localname = str(localname).lower() + + # qname implementation is horribly broken in Python 2.1 (it + # doesn't report any), and slightly broken in Python 2.2 (it + # doesn't report the xml: namespace). So we match up namespaces + # with a known list first, and then possibly override them with + # the qnames the SAX parser gives us (if indeed it gives us any + # at all). Thanks to MatejC for helping me test this and + # tirelessly telling me that it didn't work yet. + attrsD, self.decls = self.decls, {} + if localname=='math' and namespace=='http://www.w3.org/1998/Math/MathML': + attrsD['xmlns']=namespace + if localname=='svg' and namespace=='http://www.w3.org/2000/svg': + attrsD['xmlns']=namespace + + if prefix: + localname = prefix.lower() + ':' + localname + elif namespace and not qname: #Expat + for name,value in self.namespacesInUse.items(): + if name and value == namespace: + localname = name + ':' + localname + break + if _debug: sys.stderr.write('startElementNS: qname = %s, namespace = %s, givenprefix = %s, prefix = %s, attrs = %s, localname = %s\n' % (qname, namespace, givenprefix, prefix, attrs.items(), localname)) + + for (namespace, attrlocalname), attrvalue in attrs._attrs.items(): + lowernamespace = (namespace or '').lower() + prefix = self._matchnamespaces.get(lowernamespace, '') + if prefix: + attrlocalname = prefix + ':' + attrlocalname + attrsD[str(attrlocalname).lower()] = attrvalue + for qname in attrs.getQNames(): + attrsD[str(qname).lower()] = attrs.getValueByQName(qname) + self.unknown_starttag(localname, attrsD.items()) + + def characters(self, text): + self.handle_data(text) + + def endElementNS(self, name, qname): + namespace, localname = name + lowernamespace = str(namespace or '').lower() + if qname and qname.find(':') > 0: + givenprefix = qname.split(':')[0] + else: + givenprefix = '' + prefix = self._matchnamespaces.get(lowernamespace, givenprefix) + if prefix: + localname = prefix + ':' + localname + elif namespace and not qname: #Expat + for name,value in self.namespacesInUse.items(): + if name and value == namespace: + localname = name + ':' + localname + break + localname = str(localname).lower() + self.unknown_endtag(localname) + + def error(self, exc): + self.bozo = 1 + self.exc = exc + + def fatalError(self, exc): + self.error(exc) + raise exc + +class _BaseHTMLProcessor(sgmllib.SGMLParser): + special = re.compile('''[<>'"]''') + bare_ampersand = re.compile("&(?!#\d+;|#x[0-9a-fA-F]+;|\w+;)") + elements_no_end_tag = [ + 'area', 'base', 'basefont', 'br', 'col', 'command', 'embed', 'frame', + 'hr', 'img', 'input', 'isindex', 'keygen', 'link', 'meta', 'param', + 'source', 'track', 'wbr' + ] + + def __init__(self, encoding, _type): + self.encoding = encoding + self._type = _type + if _debug: sys.stderr.write('entering BaseHTMLProcessor, encoding=%s\n' % self.encoding) + sgmllib.SGMLParser.__init__(self) + + def reset(self): + self.pieces = [] + sgmllib.SGMLParser.reset(self) + + def _shorttag_replace(self, match): + tag = match.group(1) + if tag in self.elements_no_end_tag: + return '<' + tag + ' />' + else: + return '<' + tag + '>' + + def parse_starttag(self,i): + j=sgmllib.SGMLParser.parse_starttag(self, i) + if self._type == 'application/xhtml+xml': + if j>2 and self.rawdata[j-2:j]=='/>': + self.unknown_endtag(self.lasttag) + return j + + def feed(self, data): + data = re.compile(r'', self._shorttag_replace, data) # bug [ 1399464 ] Bad regexp for _shorttag_replace + data = re.sub(r'<([^<>\s]+?)\s*/>', self._shorttag_replace, data) + data = data.replace(''', "'") + data = data.replace('"', '"') + try: + bytes + if bytes is str: + raise NameError + self.encoding = self.encoding + '_INVALID_PYTHON_3' + except NameError: + if self.encoding and type(data) == type(u''): + data = data.encode(self.encoding) + sgmllib.SGMLParser.feed(self, data) + sgmllib.SGMLParser.close(self) + + def normalize_attrs(self, attrs): + if not attrs: return attrs + # utility method to be called by descendants + attrs = dict([(k.lower(), v) for k, v in attrs]).items() + attrs = [(k, k in ('rel', 'type') and v.lower() or v) for k, v in attrs] + attrs.sort() + return attrs + + def unknown_starttag(self, tag, attrs): + # called for each start tag + # attrs is a list of (attr, value) tuples + # e.g. for
, tag='pre', attrs=[('class', 'screen')]
+        if _debug: sys.stderr.write('_BaseHTMLProcessor, unknown_starttag, tag=%s\n' % tag)
+        uattrs = []
+        strattrs=''
+        if attrs:
+            for key, value in attrs:
+                value=value.replace('>','>').replace('<','<').replace('"','"')
+                value = self.bare_ampersand.sub("&", value)
+                # thanks to Kevin Marks for this breathtaking hack to deal with (valid) high-bit attribute values in UTF-8 feeds
+                if type(value) != type(u''):
+                    try:
+                        value = unicode(value, self.encoding)
+                    except:
+                        value = unicode(value, 'iso-8859-1')
+                try:
+                    # Currently, in Python 3 the key is already a str, and cannot be decoded again
+                    uattrs.append((unicode(key, self.encoding), value))
+                except TypeError:
+                    uattrs.append((key, value))
+            strattrs = u''.join([u' %s="%s"' % (key, value) for key, value in uattrs])
+            if self.encoding:
+                try:
+                    strattrs=strattrs.encode(self.encoding)
+                except:
+                    pass
+        if tag in self.elements_no_end_tag:
+            self.pieces.append('<%(tag)s%(strattrs)s />' % locals())
+        else:
+            self.pieces.append('<%(tag)s%(strattrs)s>' % locals())
+
+    def unknown_endtag(self, tag):
+        # called for each end tag, e.g. for 
, tag will be 'pre' + # Reconstruct the original end tag. + if tag not in self.elements_no_end_tag: + self.pieces.append("" % locals()) + + def handle_charref(self, ref): + # called for each character reference, e.g. for ' ', ref will be '160' + # Reconstruct the original character reference. + if ref.startswith('x'): + value = unichr(int(ref[1:],16)) + else: + value = unichr(int(ref)) + + if value in _cp1252.keys(): + self.pieces.append('&#%s;' % hex(ord(_cp1252[value]))[1:]) + else: + self.pieces.append('&#%(ref)s;' % locals()) + + def handle_entityref(self, ref): + # called for each entity reference, e.g. for '©', ref will be 'copy' + # Reconstruct the original entity reference. + if name2codepoint.has_key(ref): + self.pieces.append('&%(ref)s;' % locals()) + else: + self.pieces.append('&%(ref)s' % locals()) + + def handle_data(self, text): + # called for each block of plain text, i.e. outside of any tag and + # not containing any character or entity references + # Store the original text verbatim. + if _debug: sys.stderr.write('_BaseHTMLProcessor, handle_data, text=%s\n' % text) + self.pieces.append(text) + + def handle_comment(self, text): + # called for each HTML comment, e.g. + # Reconstruct the original comment. + self.pieces.append('' % locals()) + + def handle_pi(self, text): + # called for each processing instruction, e.g. + # Reconstruct original processing instruction. + self.pieces.append('' % locals()) + + def handle_decl(self, text): + # called for the DOCTYPE, if present, e.g. + # + # Reconstruct original DOCTYPE + self.pieces.append('' % locals()) + + _new_declname_match = re.compile(r'[a-zA-Z][-_.a-zA-Z0-9:]*\s*').match + def _scan_name(self, i, declstartpos): + rawdata = self.rawdata + n = len(rawdata) + if i == n: + return None, -1 + m = self._new_declname_match(rawdata, i) + if m: + s = m.group() + name = s.strip() + if (i + len(s)) == n: + return None, -1 # end of buffer + return name.lower(), m.end() + else: + self.handle_data(rawdata) +# self.updatepos(declstartpos, i) + return None, -1 + + def convert_charref(self, name): + return '&#%s;' % name + + def convert_entityref(self, name): + return '&%s;' % name + + def output(self): + '''Return processed HTML as a single string''' + return ''.join([str(p) for p in self.pieces]) + + def parse_declaration(self, i): + try: + return sgmllib.SGMLParser.parse_declaration(self, i) + except sgmllib.SGMLParseError: + # escape the doctype declaration and continue parsing + self.handle_data('<') + return i+1 + +class _LooseFeedParser(_FeedParserMixin, _BaseHTMLProcessor): + def __init__(self, baseuri, baselang, encoding, entities): + sgmllib.SGMLParser.__init__(self) + _FeedParserMixin.__init__(self, baseuri, baselang, encoding) + _BaseHTMLProcessor.__init__(self, encoding, 'application/xhtml+xml') + self.entities=entities + + def decodeEntities(self, element, data): + data = data.replace('<', '<') + data = data.replace('<', '<') + data = data.replace('<', '<') + data = data.replace('>', '>') + data = data.replace('>', '>') + data = data.replace('>', '>') + data = data.replace('&', '&') + data = data.replace('&', '&') + data = data.replace('"', '"') + data = data.replace('"', '"') + data = data.replace(''', ''') + data = data.replace(''', ''') + if self.contentparams.has_key('type') and not self.contentparams.get('type', 'xml').endswith('xml'): + data = data.replace('<', '<') + data = data.replace('>', '>') + data = data.replace('&', '&') + data = data.replace('"', '"') + data = data.replace(''', "'") + return data + + def strattrs(self, attrs): + return ''.join([' %s="%s"' % (n,v.replace('"','"')) for n,v in attrs]) + +class _MicroformatsParser: + STRING = 1 + DATE = 2 + URI = 3 + NODE = 4 + EMAIL = 5 + + known_xfn_relationships = ['contact', 'acquaintance', 'friend', 'met', 'co-worker', 'coworker', 'colleague', 'co-resident', 'coresident', 'neighbor', 'child', 'parent', 'sibling', 'brother', 'sister', 'spouse', 'wife', 'husband', 'kin', 'relative', 'muse', 'crush', 'date', 'sweetheart', 'me'] + known_binary_extensions = ['zip','rar','exe','gz','tar','tgz','tbz2','bz2','z','7z','dmg','img','sit','sitx','hqx','deb','rpm','bz2','jar','rar','iso','bin','msi','mp2','mp3','ogg','ogm','mp4','m4v','m4a','avi','wma','wmv'] + + def __init__(self, data, baseuri, encoding): + self.document = BeautifulSoup.BeautifulSoup(data) + self.baseuri = baseuri + self.encoding = encoding + if type(data) == type(u''): + data = data.encode(encoding) + self.tags = [] + self.enclosures = [] + self.xfn = [] + self.vcard = None + + def vcardEscape(self, s): + if type(s) in (type(''), type(u'')): + s = s.replace(',', '\\,').replace(';', '\\;').replace('\n', '\\n') + return s + + def vcardFold(self, s): + s = re.sub(';+$', '', s) + sFolded = '' + iMax = 75 + sPrefix = '' + while len(s) > iMax: + sFolded += sPrefix + s[:iMax] + '\n' + s = s[iMax:] + sPrefix = ' ' + iMax = 74 + sFolded += sPrefix + s + return sFolded + + def normalize(self, s): + return re.sub(r'\s+', ' ', s).strip() + + def unique(self, aList): + results = [] + for element in aList: + if element not in results: + results.append(element) + return results + + def toISO8601(self, dt): + return time.strftime('%Y-%m-%dT%H:%M:%SZ', dt) + + def getPropertyValue(self, elmRoot, sProperty, iPropertyType=4, bAllowMultiple=0, bAutoEscape=0): + all = lambda x: 1 + sProperty = sProperty.lower() + bFound = 0 + bNormalize = 1 + propertyMatch = {'class': re.compile(r'\b%s\b' % sProperty)} + if bAllowMultiple and (iPropertyType != self.NODE): + snapResults = [] + containers = elmRoot(['ul', 'ol'], propertyMatch) + for container in containers: + snapResults.extend(container('li')) + bFound = (len(snapResults) != 0) + if not bFound: + snapResults = elmRoot(all, propertyMatch) + bFound = (len(snapResults) != 0) + if (not bFound) and (sProperty == 'value'): + snapResults = elmRoot('pre') + bFound = (len(snapResults) != 0) + bNormalize = not bFound + if not bFound: + snapResults = [elmRoot] + bFound = (len(snapResults) != 0) + arFilter = [] + if sProperty == 'vcard': + snapFilter = elmRoot(all, propertyMatch) + for node in snapFilter: + if node.findParent(all, propertyMatch): + arFilter.append(node) + arResults = [] + for node in snapResults: + if node not in arFilter: + arResults.append(node) + bFound = (len(arResults) != 0) + if not bFound: + if bAllowMultiple: return [] + elif iPropertyType == self.STRING: return '' + elif iPropertyType == self.DATE: return None + elif iPropertyType == self.URI: return '' + elif iPropertyType == self.NODE: return None + else: return None + arValues = [] + for elmResult in arResults: + sValue = None + if iPropertyType == self.NODE: + if bAllowMultiple: + arValues.append(elmResult) + continue + else: + return elmResult + sNodeName = elmResult.name.lower() + if (iPropertyType == self.EMAIL) and (sNodeName == 'a'): + sValue = (elmResult.get('href') or '').split('mailto:').pop().split('?')[0] + if sValue: + sValue = bNormalize and self.normalize(sValue) or sValue.strip() + if (not sValue) and (sNodeName == 'abbr'): + sValue = elmResult.get('title') + if sValue: + sValue = bNormalize and self.normalize(sValue) or sValue.strip() + if (not sValue) and (iPropertyType == self.URI): + if sNodeName == 'a': sValue = elmResult.get('href') + elif sNodeName == 'img': sValue = elmResult.get('src') + elif sNodeName == 'object': sValue = elmResult.get('data') + if sValue: + sValue = bNormalize and self.normalize(sValue) or sValue.strip() + if (not sValue) and (sNodeName == 'img'): + sValue = elmResult.get('alt') + if sValue: + sValue = bNormalize and self.normalize(sValue) or sValue.strip() + if not sValue: + sValue = elmResult.renderContents() + sValue = re.sub(r'<\S[^>]*>', '', sValue) + sValue = sValue.replace('\r\n', '\n') + sValue = sValue.replace('\r', '\n') + if sValue: + sValue = bNormalize and self.normalize(sValue) or sValue.strip() + if not sValue: continue + if iPropertyType == self.DATE: + sValue = _parse_date_iso8601(sValue) + if bAllowMultiple: + arValues.append(bAutoEscape and self.vcardEscape(sValue) or sValue) + else: + return bAutoEscape and self.vcardEscape(sValue) or sValue + return arValues + + def findVCards(self, elmRoot, bAgentParsing=0): + sVCards = '' + + if not bAgentParsing: + arCards = self.getPropertyValue(elmRoot, 'vcard', bAllowMultiple=1) + else: + arCards = [elmRoot] + + for elmCard in arCards: + arLines = [] + + def processSingleString(sProperty): + sValue = self.getPropertyValue(elmCard, sProperty, self.STRING, bAutoEscape=1).decode(self.encoding) + if sValue: + arLines.append(self.vcardFold(sProperty.upper() + ':' + sValue)) + return sValue or u'' + + def processSingleURI(sProperty): + sValue = self.getPropertyValue(elmCard, sProperty, self.URI) + if sValue: + sContentType = '' + sEncoding = '' + sValueKey = '' + if sValue.startswith('data:'): + sEncoding = ';ENCODING=b' + sContentType = sValue.split(';')[0].split('/').pop() + sValue = sValue.split(',', 1).pop() + else: + elmValue = self.getPropertyValue(elmCard, sProperty) + if elmValue: + if sProperty != 'url': + sValueKey = ';VALUE=uri' + sContentType = elmValue.get('type', '').strip().split('/').pop().strip() + sContentType = sContentType.upper() + if sContentType == 'OCTET-STREAM': + sContentType = '' + if sContentType: + sContentType = ';TYPE=' + sContentType.upper() + arLines.append(self.vcardFold(sProperty.upper() + sEncoding + sContentType + sValueKey + ':' + sValue)) + + def processTypeValue(sProperty, arDefaultType, arForceType=None): + arResults = self.getPropertyValue(elmCard, sProperty, bAllowMultiple=1) + for elmResult in arResults: + arType = self.getPropertyValue(elmResult, 'type', self.STRING, 1, 1) + if arForceType: + arType = self.unique(arForceType + arType) + if not arType: + arType = arDefaultType + sValue = self.getPropertyValue(elmResult, 'value', self.EMAIL, 0) + if sValue: + arLines.append(self.vcardFold(sProperty.upper() + ';TYPE=' + ','.join(arType) + ':' + sValue)) + + # AGENT + # must do this before all other properties because it is destructive + # (removes nested class="vcard" nodes so they don't interfere with + # this vcard's other properties) + arAgent = self.getPropertyValue(elmCard, 'agent', bAllowMultiple=1) + for elmAgent in arAgent: + if re.compile(r'\bvcard\b').search(elmAgent.get('class')): + sAgentValue = self.findVCards(elmAgent, 1) + '\n' + sAgentValue = sAgentValue.replace('\n', '\\n') + sAgentValue = sAgentValue.replace(';', '\\;') + if sAgentValue: + arLines.append(self.vcardFold('AGENT:' + sAgentValue)) + # Completely remove the agent element from the parse tree + elmAgent.extract() + else: + sAgentValue = self.getPropertyValue(elmAgent, 'value', self.URI, bAutoEscape=1); + if sAgentValue: + arLines.append(self.vcardFold('AGENT;VALUE=uri:' + sAgentValue)) + + # FN (full name) + sFN = processSingleString('fn') + + # N (name) + elmName = self.getPropertyValue(elmCard, 'n') + if elmName: + sFamilyName = self.getPropertyValue(elmName, 'family-name', self.STRING, bAutoEscape=1) + sGivenName = self.getPropertyValue(elmName, 'given-name', self.STRING, bAutoEscape=1) + arAdditionalNames = self.getPropertyValue(elmName, 'additional-name', self.STRING, 1, 1) + self.getPropertyValue(elmName, 'additional-names', self.STRING, 1, 1) + arHonorificPrefixes = self.getPropertyValue(elmName, 'honorific-prefix', self.STRING, 1, 1) + self.getPropertyValue(elmName, 'honorific-prefixes', self.STRING, 1, 1) + arHonorificSuffixes = self.getPropertyValue(elmName, 'honorific-suffix', self.STRING, 1, 1) + self.getPropertyValue(elmName, 'honorific-suffixes', self.STRING, 1, 1) + arLines.append(self.vcardFold('N:' + sFamilyName + ';' + + sGivenName + ';' + + ','.join(arAdditionalNames) + ';' + + ','.join(arHonorificPrefixes) + ';' + + ','.join(arHonorificSuffixes))) + elif sFN: + # implied "N" optimization + # http://microformats.org/wiki/hcard#Implied_.22N.22_Optimization + arNames = self.normalize(sFN).split() + if len(arNames) == 2: + bFamilyNameFirst = (arNames[0].endswith(',') or + len(arNames[1]) == 1 or + ((len(arNames[1]) == 2) and (arNames[1].endswith('.')))) + if bFamilyNameFirst: + arLines.append(self.vcardFold('N:' + arNames[0] + ';' + arNames[1])) + else: + arLines.append(self.vcardFold('N:' + arNames[1] + ';' + arNames[0])) + + # SORT-STRING + sSortString = self.getPropertyValue(elmCard, 'sort-string', self.STRING, bAutoEscape=1) + if sSortString: + arLines.append(self.vcardFold('SORT-STRING:' + sSortString)) + + # NICKNAME + arNickname = self.getPropertyValue(elmCard, 'nickname', self.STRING, 1, 1) + if arNickname: + arLines.append(self.vcardFold('NICKNAME:' + ','.join(arNickname))) + + # PHOTO + processSingleURI('photo') + + # BDAY + dtBday = self.getPropertyValue(elmCard, 'bday', self.DATE) + if dtBday: + arLines.append(self.vcardFold('BDAY:' + self.toISO8601(dtBday))) + + # ADR (address) + arAdr = self.getPropertyValue(elmCard, 'adr', bAllowMultiple=1) + for elmAdr in arAdr: + arType = self.getPropertyValue(elmAdr, 'type', self.STRING, 1, 1) + if not arType: + arType = ['intl','postal','parcel','work'] # default adr types, see RFC 2426 section 3.2.1 + sPostOfficeBox = self.getPropertyValue(elmAdr, 'post-office-box', self.STRING, 0, 1) + sExtendedAddress = self.getPropertyValue(elmAdr, 'extended-address', self.STRING, 0, 1) + sStreetAddress = self.getPropertyValue(elmAdr, 'street-address', self.STRING, 0, 1) + sLocality = self.getPropertyValue(elmAdr, 'locality', self.STRING, 0, 1) + sRegion = self.getPropertyValue(elmAdr, 'region', self.STRING, 0, 1) + sPostalCode = self.getPropertyValue(elmAdr, 'postal-code', self.STRING, 0, 1) + sCountryName = self.getPropertyValue(elmAdr, 'country-name', self.STRING, 0, 1) + arLines.append(self.vcardFold('ADR;TYPE=' + ','.join(arType) + ':' + + sPostOfficeBox + ';' + + sExtendedAddress + ';' + + sStreetAddress + ';' + + sLocality + ';' + + sRegion + ';' + + sPostalCode + ';' + + sCountryName)) + + # LABEL + processTypeValue('label', ['intl','postal','parcel','work']) + + # TEL (phone number) + processTypeValue('tel', ['voice']) + + # EMAIL + processTypeValue('email', ['internet'], ['internet']) + + # MAILER + processSingleString('mailer') + + # TZ (timezone) + processSingleString('tz') + + # GEO (geographical information) + elmGeo = self.getPropertyValue(elmCard, 'geo') + if elmGeo: + sLatitude = self.getPropertyValue(elmGeo, 'latitude', self.STRING, 0, 1) + sLongitude = self.getPropertyValue(elmGeo, 'longitude', self.STRING, 0, 1) + arLines.append(self.vcardFold('GEO:' + sLatitude + ';' + sLongitude)) + + # TITLE + processSingleString('title') + + # ROLE + processSingleString('role') + + # LOGO + processSingleURI('logo') + + # ORG (organization) + elmOrg = self.getPropertyValue(elmCard, 'org') + if elmOrg: + sOrganizationName = self.getPropertyValue(elmOrg, 'organization-name', self.STRING, 0, 1) + if not sOrganizationName: + # implied "organization-name" optimization + # http://microformats.org/wiki/hcard#Implied_.22organization-name.22_Optimization + sOrganizationName = self.getPropertyValue(elmCard, 'org', self.STRING, 0, 1) + if sOrganizationName: + arLines.append(self.vcardFold('ORG:' + sOrganizationName)) + else: + arOrganizationUnit = self.getPropertyValue(elmOrg, 'organization-unit', self.STRING, 1, 1) + arLines.append(self.vcardFold('ORG:' + sOrganizationName + ';' + ';'.join(arOrganizationUnit))) + + # CATEGORY + arCategory = self.getPropertyValue(elmCard, 'category', self.STRING, 1, 1) + self.getPropertyValue(elmCard, 'categories', self.STRING, 1, 1) + if arCategory: + arLines.append(self.vcardFold('CATEGORIES:' + ','.join(arCategory))) + + # NOTE + processSingleString('note') + + # REV + processSingleString('rev') + + # SOUND + processSingleURI('sound') + + # UID + processSingleString('uid') + + # URL + processSingleURI('url') + + # CLASS + processSingleString('class') + + # KEY + processSingleURI('key') + + if arLines: + arLines = [u'BEGIN:vCard',u'VERSION:3.0'] + arLines + [u'END:vCard'] + sVCards += u'\n'.join(arLines) + u'\n' + + return sVCards.strip() + + def isProbablyDownloadable(self, elm): + attrsD = elm.attrMap + if not attrsD.has_key('href'): return 0 + linktype = attrsD.get('type', '').strip() + if linktype.startswith('audio/') or \ + linktype.startswith('video/') or \ + (linktype.startswith('application/') and not linktype.endswith('xml')): + return 1 + path = urlparse.urlparse(attrsD['href'])[2] + if path.find('.') == -1: return 0 + fileext = path.split('.').pop().lower() + return fileext in self.known_binary_extensions + + def findTags(self): + all = lambda x: 1 + for elm in self.document(all, {'rel': re.compile(r'\btag\b')}): + href = elm.get('href') + if not href: continue + urlscheme, domain, path, params, query, fragment = \ + urlparse.urlparse(_urljoin(self.baseuri, href)) + segments = path.split('/') + tag = segments.pop() + if not tag: + tag = segments.pop() + tagscheme = urlparse.urlunparse((urlscheme, domain, '/'.join(segments), '', '', '')) + if not tagscheme.endswith('/'): + tagscheme += '/' + self.tags.append(FeedParserDict({"term": tag, "scheme": tagscheme, "label": elm.string or ''})) + + def findEnclosures(self): + all = lambda x: 1 + enclosure_match = re.compile(r'\benclosure\b') + for elm in self.document(all, {'href': re.compile(r'.+')}): + if not enclosure_match.search(elm.get('rel', '')) and not self.isProbablyDownloadable(elm): continue + if elm.attrMap not in self.enclosures: + self.enclosures.append(elm.attrMap) + if elm.string and not elm.get('title'): + self.enclosures[-1]['title'] = elm.string + + def findXFN(self): + all = lambda x: 1 + for elm in self.document(all, {'rel': re.compile('.+'), 'href': re.compile('.+')}): + rels = elm.get('rel', '').split() + xfn_rels = [] + for rel in rels: + if rel in self.known_xfn_relationships: + xfn_rels.append(rel) + if xfn_rels: + self.xfn.append({"relationships": xfn_rels, "href": elm.get('href', ''), "name": elm.string}) + +def _parseMicroformats(htmlSource, baseURI, encoding): + if not BeautifulSoup: return + if _debug: sys.stderr.write('entering _parseMicroformats\n') + try: + p = _MicroformatsParser(htmlSource, baseURI, encoding) + except UnicodeEncodeError: + # sgmllib throws this exception when performing lookups of tags + # with non-ASCII characters in them. + return + p.vcard = p.findVCards(p.document) + p.findTags() + p.findEnclosures() + p.findXFN() + return {"tags": p.tags, "enclosures": p.enclosures, "xfn": p.xfn, "vcard": p.vcard} + +class _RelativeURIResolver(_BaseHTMLProcessor): + relative_uris = [('a', 'href'), + ('applet', 'codebase'), + ('area', 'href'), + ('blockquote', 'cite'), + ('body', 'background'), + ('del', 'cite'), + ('form', 'action'), + ('frame', 'longdesc'), + ('frame', 'src'), + ('iframe', 'longdesc'), + ('iframe', 'src'), + ('head', 'profile'), + ('img', 'longdesc'), + ('img', 'src'), + ('img', 'usemap'), + ('input', 'src'), + ('input', 'usemap'), + ('ins', 'cite'), + ('link', 'href'), + ('object', 'classid'), + ('object', 'codebase'), + ('object', 'data'), + ('object', 'usemap'), + ('q', 'cite'), + ('script', 'src')] + + def __init__(self, baseuri, encoding, _type): + _BaseHTMLProcessor.__init__(self, encoding, _type) + self.baseuri = baseuri + + def resolveURI(self, uri): + return _makeSafeAbsoluteURI(_urljoin(self.baseuri, uri.strip())) + + def unknown_starttag(self, tag, attrs): + if _debug: + sys.stderr.write('tag: [%s] with attributes: [%s]\n' % (tag, str(attrs))) + attrs = self.normalize_attrs(attrs) + attrs = [(key, ((tag, key) in self.relative_uris) and self.resolveURI(value) or value) for key, value in attrs] + _BaseHTMLProcessor.unknown_starttag(self, tag, attrs) + +def _resolveRelativeURIs(htmlSource, baseURI, encoding, _type): + if _debug: + sys.stderr.write('entering _resolveRelativeURIs\n') + + p = _RelativeURIResolver(baseURI, encoding, _type) + p.feed(htmlSource) + return p.output() + +def _makeSafeAbsoluteURI(base, rel=None): + # bail if ACCEPTABLE_URI_SCHEMES is empty + if not ACCEPTABLE_URI_SCHEMES: + return _urljoin(base, rel or u'') + if not base: + return rel or u'' + if not rel: + scheme = urlparse.urlparse(base)[0] + if not scheme or scheme in ACCEPTABLE_URI_SCHEMES: + return base + return u'' + uri = _urljoin(base, rel) + if uri.strip().split(':', 1)[0] not in ACCEPTABLE_URI_SCHEMES: + return u'' + return uri + +class _HTMLSanitizer(_BaseHTMLProcessor): + acceptable_elements = ['a', 'abbr', 'acronym', 'address', 'area', + 'article', 'aside', 'audio', 'b', 'big', 'blockquote', 'br', 'button', + 'canvas', 'caption', 'center', 'cite', 'code', 'col', 'colgroup', + 'command', 'datagrid', 'datalist', 'dd', 'del', 'details', 'dfn', + 'dialog', 'dir', 'div', 'dl', 'dt', 'em', 'event-source', 'fieldset', + 'figcaption', 'figure', 'footer', 'font', 'form', 'header', 'h1', + 'h2', 'h3', 'h4', 'h5', 'h6', 'hr', 'i', 'img', 'input', 'ins', + 'keygen', 'kbd', 'label', 'legend', 'li', 'm', 'map', 'menu', 'meter', + 'multicol', 'nav', 'nextid', 'ol', 'output', 'optgroup', 'option', + 'p', 'pre', 'progress', 'q', 's', 'samp', 'section', 'select', + 'small', 'sound', 'source', 'spacer', 'span', 'strike', 'strong', + 'sub', 'sup', 'table', 'tbody', 'td', 'textarea', 'time', 'tfoot', + 'th', 'thead', 'tr', 'tt', 'u', 'ul', 'var', 'video', 'noscript'] + + acceptable_attributes = ['abbr', 'accept', 'accept-charset', 'accesskey', + 'action', 'align', 'alt', 'autocomplete', 'autofocus', 'axis', + 'background', 'balance', 'bgcolor', 'bgproperties', 'border', + 'bordercolor', 'bordercolordark', 'bordercolorlight', 'bottompadding', + 'cellpadding', 'cellspacing', 'ch', 'challenge', 'char', 'charoff', + 'choff', 'charset', 'checked', 'cite', 'class', 'clear', 'color', 'cols', + 'colspan', 'compact', 'contenteditable', 'controls', 'coords', 'data', + 'datafld', 'datapagesize', 'datasrc', 'datetime', 'default', 'delay', + 'dir', 'disabled', 'draggable', 'dynsrc', 'enctype', 'end', 'face', 'for', + 'form', 'frame', 'galleryimg', 'gutter', 'headers', 'height', 'hidefocus', + 'hidden', 'high', 'href', 'hreflang', 'hspace', 'icon', 'id', 'inputmode', + 'ismap', 'keytype', 'label', 'leftspacing', 'lang', 'list', 'longdesc', + 'loop', 'loopcount', 'loopend', 'loopstart', 'low', 'lowsrc', 'max', + 'maxlength', 'media', 'method', 'min', 'multiple', 'name', 'nohref', + 'noshade', 'nowrap', 'open', 'optimum', 'pattern', 'ping', 'point-size', + 'prompt', 'pqg', 'radiogroup', 'readonly', 'rel', 'repeat-max', + 'repeat-min', 'replace', 'required', 'rev', 'rightspacing', 'rows', + 'rowspan', 'rules', 'scope', 'selected', 'shape', 'size', 'span', 'src', + 'start', 'step', 'summary', 'suppress', 'tabindex', 'target', 'template', + 'title', 'toppadding', 'type', 'unselectable', 'usemap', 'urn', 'valign', + 'value', 'variable', 'volume', 'vspace', 'vrml', 'width', 'wrap', + 'xml:lang'] + + unacceptable_elements_with_end_tag = ['script', 'applet', 'style'] + + acceptable_css_properties = ['azimuth', 'background-color', + 'border-bottom-color', 'border-collapse', 'border-color', + 'border-left-color', 'border-right-color', 'border-top-color', 'clear', + 'color', 'cursor', 'direction', 'display', 'elevation', 'float', 'font', + 'font-family', 'font-size', 'font-style', 'font-variant', 'font-weight', + 'height', 'letter-spacing', 'line-height', 'overflow', 'pause', + 'pause-after', 'pause-before', 'pitch', 'pitch-range', 'richness', + 'speak', 'speak-header', 'speak-numeral', 'speak-punctuation', + 'speech-rate', 'stress', 'text-align', 'text-decoration', 'text-indent', + 'unicode-bidi', 'vertical-align', 'voice-family', 'volume', + 'white-space', 'width'] + + # survey of common keywords found in feeds + acceptable_css_keywords = ['auto', 'aqua', 'black', 'block', 'blue', + 'bold', 'both', 'bottom', 'brown', 'center', 'collapse', 'dashed', + 'dotted', 'fuchsia', 'gray', 'green', '!important', 'italic', 'left', + 'lime', 'maroon', 'medium', 'none', 'navy', 'normal', 'nowrap', 'olive', + 'pointer', 'purple', 'red', 'right', 'solid', 'silver', 'teal', 'top', + 'transparent', 'underline', 'white', 'yellow'] + + valid_css_values = re.compile('^(#[0-9a-f]+|rgb\(\d+%?,\d*%?,?\d*%?\)?|' + + '\d{0,2}\.?\d{0,2}(cm|em|ex|in|mm|pc|pt|px|%|,|\))?)$') + + mathml_elements = ['annotation', 'annotation-xml', 'maction', 'math', + 'merror', 'mfenced', 'mfrac', 'mi', 'mmultiscripts', 'mn', 'mo', 'mover', 'mpadded', + 'mphantom', 'mprescripts', 'mroot', 'mrow', 'mspace', 'msqrt', 'mstyle', + 'msub', 'msubsup', 'msup', 'mtable', 'mtd', 'mtext', 'mtr', 'munder', + 'munderover', 'none', 'semantics'] + + mathml_attributes = ['actiontype', 'align', 'columnalign', 'columnalign', + 'columnalign', 'close', 'columnlines', 'columnspacing', 'columnspan', 'depth', + 'display', 'displaystyle', 'encoding', 'equalcolumns', 'equalrows', + 'fence', 'fontstyle', 'fontweight', 'frame', 'height', 'linethickness', + 'lspace', 'mathbackground', 'mathcolor', 'mathvariant', 'mathvariant', + 'maxsize', 'minsize', 'open', 'other', 'rowalign', 'rowalign', 'rowalign', + 'rowlines', 'rowspacing', 'rowspan', 'rspace', 'scriptlevel', 'selection', + 'separator', 'separators', 'stretchy', 'width', 'width', 'xlink:href', + 'xlink:show', 'xlink:type', 'xmlns', 'xmlns:xlink'] + + # svgtiny - foreignObject + linearGradient + radialGradient + stop + svg_elements = ['a', 'animate', 'animateColor', 'animateMotion', + 'animateTransform', 'circle', 'defs', 'desc', 'ellipse', 'foreignObject', + 'font-face', 'font-face-name', 'font-face-src', 'g', 'glyph', 'hkern', + 'linearGradient', 'line', 'marker', 'metadata', 'missing-glyph', 'mpath', + 'path', 'polygon', 'polyline', 'radialGradient', 'rect', 'set', 'stop', + 'svg', 'switch', 'text', 'title', 'tspan', 'use'] + + # svgtiny + class + opacity + offset + xmlns + xmlns:xlink + svg_attributes = ['accent-height', 'accumulate', 'additive', 'alphabetic', + 'arabic-form', 'ascent', 'attributeName', 'attributeType', + 'baseProfile', 'bbox', 'begin', 'by', 'calcMode', 'cap-height', + 'class', 'color', 'color-rendering', 'content', 'cx', 'cy', 'd', 'dx', + 'dy', 'descent', 'display', 'dur', 'end', 'fill', 'fill-opacity', + 'fill-rule', 'font-family', 'font-size', 'font-stretch', 'font-style', + 'font-variant', 'font-weight', 'from', 'fx', 'fy', 'g1', 'g2', + 'glyph-name', 'gradientUnits', 'hanging', 'height', 'horiz-adv-x', + 'horiz-origin-x', 'id', 'ideographic', 'k', 'keyPoints', 'keySplines', + 'keyTimes', 'lang', 'mathematical', 'marker-end', 'marker-mid', + 'marker-start', 'markerHeight', 'markerUnits', 'markerWidth', 'max', + 'min', 'name', 'offset', 'opacity', 'orient', 'origin', + 'overline-position', 'overline-thickness', 'panose-1', 'path', + 'pathLength', 'points', 'preserveAspectRatio', 'r', 'refX', 'refY', + 'repeatCount', 'repeatDur', 'requiredExtensions', 'requiredFeatures', + 'restart', 'rotate', 'rx', 'ry', 'slope', 'stemh', 'stemv', + 'stop-color', 'stop-opacity', 'strikethrough-position', + 'strikethrough-thickness', 'stroke', 'stroke-dasharray', + 'stroke-dashoffset', 'stroke-linecap', 'stroke-linejoin', + 'stroke-miterlimit', 'stroke-opacity', 'stroke-width', 'systemLanguage', + 'target', 'text-anchor', 'to', 'transform', 'type', 'u1', 'u2', + 'underline-position', 'underline-thickness', 'unicode', 'unicode-range', + 'units-per-em', 'values', 'version', 'viewBox', 'visibility', 'width', + 'widths', 'x', 'x-height', 'x1', 'x2', 'xlink:actuate', 'xlink:arcrole', + 'xlink:href', 'xlink:role', 'xlink:show', 'xlink:title', 'xlink:type', + 'xml:base', 'xml:lang', 'xml:space', 'xmlns', 'xmlns:xlink', 'y', 'y1', + 'y2', 'zoomAndPan'] + + svg_attr_map = None + svg_elem_map = None + + acceptable_svg_properties = [ 'fill', 'fill-opacity', 'fill-rule', + 'stroke', 'stroke-width', 'stroke-linecap', 'stroke-linejoin', + 'stroke-opacity'] + + def reset(self): + _BaseHTMLProcessor.reset(self) + self.unacceptablestack = 0 + self.mathmlOK = 0 + self.svgOK = 0 + + def unknown_starttag(self, tag, attrs): + acceptable_attributes = self.acceptable_attributes + keymap = {} + if not tag in self.acceptable_elements or self.svgOK: + if tag in self.unacceptable_elements_with_end_tag: + self.unacceptablestack += 1 + + # add implicit namespaces to html5 inline svg/mathml + if self._type.endswith('html'): + if not dict(attrs).get('xmlns'): + if tag=='svg': + attrs.append( ('xmlns','http://www.w3.org/2000/svg') ) + if tag=='math': + attrs.append( ('xmlns','http://www.w3.org/1998/Math/MathML') ) + + # not otherwise acceptable, perhaps it is MathML or SVG? + if tag=='math' and ('xmlns','http://www.w3.org/1998/Math/MathML') in attrs: + self.mathmlOK += 1 + if tag=='svg' and ('xmlns','http://www.w3.org/2000/svg') in attrs: + self.svgOK += 1 + + # chose acceptable attributes based on tag class, else bail + if self.mathmlOK and tag in self.mathml_elements: + acceptable_attributes = self.mathml_attributes + elif self.svgOK and tag in self.svg_elements: + # for most vocabularies, lowercasing is a good idea. Many + # svg elements, however, are camel case + if not self.svg_attr_map: + lower=[attr.lower() for attr in self.svg_attributes] + mix=[a for a in self.svg_attributes if a not in lower] + self.svg_attributes = lower + self.svg_attr_map = dict([(a.lower(),a) for a in mix]) + + lower=[attr.lower() for attr in self.svg_elements] + mix=[a for a in self.svg_elements if a not in lower] + self.svg_elements = lower + self.svg_elem_map = dict([(a.lower(),a) for a in mix]) + acceptable_attributes = self.svg_attributes + tag = self.svg_elem_map.get(tag,tag) + keymap = self.svg_attr_map + elif not tag in self.acceptable_elements: + return + + # declare xlink namespace, if needed + if self.mathmlOK or self.svgOK: + if filter(lambda (n,v): n.startswith('xlink:'),attrs): + if not ('xmlns:xlink','http://www.w3.org/1999/xlink') in attrs: + attrs.append(('xmlns:xlink','http://www.w3.org/1999/xlink')) + + clean_attrs = [] + for key, value in self.normalize_attrs(attrs): + if key in acceptable_attributes: + key=keymap.get(key,key) + # make sure the uri uses an acceptable uri scheme + if key == u'href': + value = _makeSafeAbsoluteURI(value) + clean_attrs.append((key,value)) + elif key=='style': + clean_value = self.sanitize_style(value) + if clean_value: clean_attrs.append((key,clean_value)) + _BaseHTMLProcessor.unknown_starttag(self, tag, clean_attrs) + + def unknown_endtag(self, tag): + if not tag in self.acceptable_elements: + if tag in self.unacceptable_elements_with_end_tag: + self.unacceptablestack -= 1 + if self.mathmlOK and tag in self.mathml_elements: + if tag == 'math' and self.mathmlOK: self.mathmlOK -= 1 + elif self.svgOK and tag in self.svg_elements: + tag = self.svg_elem_map.get(tag,tag) + if tag == 'svg' and self.svgOK: self.svgOK -= 1 + else: + return + _BaseHTMLProcessor.unknown_endtag(self, tag) + + def handle_pi(self, text): + pass + + def handle_decl(self, text): + pass + + def handle_data(self, text): + if not self.unacceptablestack: + _BaseHTMLProcessor.handle_data(self, text) + + def sanitize_style(self, style): + # disallow urls + style=re.compile('url\s*\(\s*[^\s)]+?\s*\)\s*').sub(' ',style) + + # gauntlet + if not re.match("""^([:,;#%.\sa-zA-Z0-9!]|\w-\w|'[\s\w]+'|"[\s\w]+"|\([\d,\s]+\))*$""", style): return '' + # This replaced a regexp that used re.match and was prone to pathological back-tracking. + if re.sub("\s*[-\w]+\s*:\s*[^:;]*;?", '', style).strip(): return '' + + clean = [] + for prop,value in re.findall("([-\w]+)\s*:\s*([^:;]*)",style): + if not value: continue + if prop.lower() in self.acceptable_css_properties: + clean.append(prop + ': ' + value + ';') + elif prop.split('-')[0].lower() in ['background','border','margin','padding']: + for keyword in value.split(): + if not keyword in self.acceptable_css_keywords and \ + not self.valid_css_values.match(keyword): + break + else: + clean.append(prop + ': ' + value + ';') + elif self.svgOK and prop.lower() in self.acceptable_svg_properties: + clean.append(prop + ': ' + value + ';') + + return ' '.join(clean) + + def parse_comment(self, i, report=1): + ret = _BaseHTMLProcessor.parse_comment(self, i, report) + if ret >= 0: + return ret + # if ret == -1, this may be a malicious attempt to circumvent + # sanitization, or a page-destroying unclosed comment + match = re.compile(r'--[^>]*>').search(self.rawdata, i+4) + if match: + return match.end() + # unclosed comment; deliberately fail to handle_data() + return len(self.rawdata) + + +def _sanitizeHTML(htmlSource, encoding, _type): + p = _HTMLSanitizer(encoding, _type) + htmlSource = htmlSource.replace(''): + data = data.split('>', 1)[1] + if data.count('= '2.3.3' + assert base64 != None + user, passw = _base64decode(req.headers['Authorization'].split(' ')[1]).split(':') + realm = re.findall('realm="([^"]*)"', headers['WWW-Authenticate'])[0] + self.add_password(realm, host, user, passw) + retry = self.http_error_auth_reqed('www-authenticate', host, req, headers) + self.reset_retry_count() + return retry + except: + return self.http_error_default(req, fp, code, msg, headers) + +def _open_resource(url_file_stream_or_string, etag, modified, agent, referrer, handlers, request_headers): + """URL, filename, or string --> stream + + This function lets you define parsers that take any input source + (URL, pathname to local or network file, or actual data as a string) + and deal with it in a uniform manner. Returned object is guaranteed + to have all the basic stdio read methods (read, readline, readlines). + Just .close() the object when you're done with it. + + If the etag argument is supplied, it will be used as the value of an + If-None-Match request header. + + If the modified argument is supplied, it can be a tuple of 9 integers + (as returned by gmtime() in the standard Python time module) or a date + string in any format supported by feedparser. Regardless, it MUST + be in GMT (Greenwich Mean Time). It will be reformatted into an + RFC 1123-compliant date and used as the value of an If-Modified-Since + request header. + + If the agent argument is supplied, it will be used as the value of a + User-Agent request header. + + If the referrer argument is supplied, it will be used as the value of a + Referer[sic] request header. + + If handlers is supplied, it is a list of handlers used to build a + urllib2 opener. + + if request_headers is supplied it is a dictionary of HTTP request headers + that will override the values generated by FeedParser. + """ + + if hasattr(url_file_stream_or_string, 'read'): + return url_file_stream_or_string + + if url_file_stream_or_string == '-': + return sys.stdin + + if urlparse.urlparse(url_file_stream_or_string)[0] in ('http', 'https', 'ftp', 'file', 'feed'): + # Deal with the feed URI scheme + if url_file_stream_or_string.startswith('feed:http'): + url_file_stream_or_string = url_file_stream_or_string[5:] + elif url_file_stream_or_string.startswith('feed:'): + url_file_stream_or_string = 'http:' + url_file_stream_or_string[5:] + if not agent: + agent = USER_AGENT + # test for inline user:password for basic auth + auth = None + if base64: + urltype, rest = urllib.splittype(url_file_stream_or_string) + realhost, rest = urllib.splithost(rest) + if realhost: + user_passwd, realhost = urllib.splituser(realhost) + if user_passwd: + url_file_stream_or_string = '%s://%s%s' % (urltype, realhost, rest) + auth = base64.standard_b64encode(user_passwd).strip() + + # iri support + try: + if isinstance(url_file_stream_or_string,unicode): + url_file_stream_or_string = url_file_stream_or_string.encode('idna').decode('utf-8') + else: + url_file_stream_or_string = url_file_stream_or_string.decode('utf-8').encode('idna').decode('utf-8') + except: + pass + + # try to open with urllib2 (to use optional headers) + request = _build_urllib2_request(url_file_stream_or_string, agent, etag, modified, referrer, auth, request_headers) + opener = apply(urllib2.build_opener, tuple(handlers + [_FeedURLHandler()])) + opener.addheaders = [] # RMK - must clear so we only send our custom User-Agent + try: + return opener.open(request, timeout=15) + finally: + opener.close() # JohnD + + # try to open with native open function (if url_file_stream_or_string is a filename) + try: + return open(url_file_stream_or_string, 'rb') + except: + pass + + # treat url_file_stream_or_string as string + return _StringIO(str(url_file_stream_or_string)) + +def _build_urllib2_request(url, agent, etag, modified, referrer, auth, request_headers): + request = urllib2.Request(url) + request.add_header('User-Agent', agent) + if etag: + request.add_header('If-None-Match', etag) + if type(modified) == type(''): + modified = _parse_date(modified) + elif isinstance(modified, datetime.datetime): + modified = modified.utctimetuple() + if modified: + # format into an RFC 1123-compliant timestamp. We can't use + # time.strftime() since the %a and %b directives can be affected + # by the current locale, but RFC 2616 states that dates must be + # in English. + short_weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] + months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] + request.add_header('If-Modified-Since', '%s, %02d %s %04d %02d:%02d:%02d GMT' % (short_weekdays[modified[6]], modified[2], months[modified[1] - 1], modified[0], modified[3], modified[4], modified[5])) + if referrer: + request.add_header('Referer', referrer) + if gzip and zlib: + request.add_header('Accept-encoding', 'gzip, deflate') + elif gzip: + request.add_header('Accept-encoding', 'gzip') + elif zlib: + request.add_header('Accept-encoding', 'deflate') + else: + request.add_header('Accept-encoding', '') + if auth: + request.add_header('Authorization', 'Basic %s' % auth) + if ACCEPT_HEADER: + request.add_header('Accept', ACCEPT_HEADER) + # use this for whatever -- cookies, special headers, etc + # [('Cookie','Something'),('x-special-header','Another Value')] + for header_name, header_value in request_headers.items(): + request.add_header(header_name, header_value) + request.add_header('A-IM', 'feed') # RFC 3229 support + return request + +_date_handlers = [] +def registerDateHandler(func): + '''Register a date handler function (takes string, returns 9-tuple date in GMT)''' + _date_handlers.insert(0, func) + +# ISO-8601 date parsing routines written by Fazal Majid. +# The ISO 8601 standard is very convoluted and irregular - a full ISO 8601 +# parser is beyond the scope of feedparser and would be a worthwhile addition +# to the Python library. +# A single regular expression cannot parse ISO 8601 date formats into groups +# as the standard is highly irregular (for instance is 030104 2003-01-04 or +# 0301-04-01), so we use templates instead. +# Please note the order in templates is significant because we need a +# greedy match. +_iso8601_tmpl = ['YYYY-?MM-?DD', 'YYYY-0MM?-?DD', 'YYYY-MM', 'YYYY-?OOO', + 'YY-?MM-?DD', 'YY-?OOO', 'YYYY', + '-YY-?MM', '-OOO', '-YY', + '--MM-?DD', '--MM', + '---DD', + 'CC', ''] +_iso8601_re = [ + tmpl.replace( + 'YYYY', r'(?P\d{4})').replace( + 'YY', r'(?P\d\d)').replace( + 'MM', r'(?P[01]\d)').replace( + 'DD', r'(?P[0123]\d)').replace( + 'OOO', r'(?P[0123]\d\d)').replace( + 'CC', r'(?P\d\d$)') + + r'(T?(?P\d{2}):(?P\d{2})' + + r'(:(?P\d{2}))?' + + r'(\.(?P\d+))?' + + r'(?P[+-](?P\d{2})(:(?P\d{2}))?|Z)?)?' + for tmpl in _iso8601_tmpl] +try: + del tmpl +except NameError: + pass +_iso8601_matches = [re.compile(regex).match for regex in _iso8601_re] +try: + del regex +except NameError: + pass +def _parse_date_iso8601(dateString): + '''Parse a variety of ISO-8601-compatible formats like 20040105''' + m = None + for _iso8601_match in _iso8601_matches: + m = _iso8601_match(dateString) + if m: break + if not m: return + if m.span() == (0, 0): return + params = m.groupdict() + ordinal = params.get('ordinal', 0) + if ordinal: + ordinal = int(ordinal) + else: + ordinal = 0 + year = params.get('year', '--') + if not year or year == '--': + year = time.gmtime()[0] + elif len(year) == 2: + # ISO 8601 assumes current century, i.e. 93 -> 2093, NOT 1993 + year = 100 * int(time.gmtime()[0] / 100) + int(year) + else: + year = int(year) + month = params.get('month', '-') + if not month or month == '-': + # ordinals are NOT normalized by mktime, we simulate them + # by setting month=1, day=ordinal + if ordinal: + month = 1 + else: + month = time.gmtime()[1] + month = int(month) + day = params.get('day', 0) + if not day: + # see above + if ordinal: + day = ordinal + elif params.get('century', 0) or \ + params.get('year', 0) or params.get('month', 0): + day = 1 + else: + day = time.gmtime()[2] + else: + day = int(day) + # special case of the century - is the first year of the 21st century + # 2000 or 2001 ? The debate goes on... + if 'century' in params.keys(): + year = (int(params['century']) - 1) * 100 + 1 + # in ISO 8601 most fields are optional + for field in ['hour', 'minute', 'second', 'tzhour', 'tzmin']: + if not params.get(field, None): + params[field] = 0 + hour = int(params.get('hour', 0)) + minute = int(params.get('minute', 0)) + second = int(float(params.get('second', 0))) + # weekday is normalized by mktime(), we can ignore it + weekday = 0 + daylight_savings_flag = -1 + tm = [year, month, day, hour, minute, second, weekday, + ordinal, daylight_savings_flag] + # ISO 8601 time zone adjustments + tz = params.get('tz') + if tz and tz != 'Z': + if tz[0] == '-': + tm[3] += int(params.get('tzhour', 0)) + tm[4] += int(params.get('tzmin', 0)) + elif tz[0] == '+': + tm[3] -= int(params.get('tzhour', 0)) + tm[4] -= int(params.get('tzmin', 0)) + else: + return None + # Python's time.mktime() is a wrapper around the ANSI C mktime(3c) + # which is guaranteed to normalize d/m/y/h/m/s. + # Many implementations have bugs, but we'll pretend they don't. + return time.localtime(time.mktime(tuple(tm))) +registerDateHandler(_parse_date_iso8601) + +# 8-bit date handling routines written by ytrewq1. +_korean_year = u'\ub144' # b3e2 in euc-kr +_korean_month = u'\uc6d4' # bff9 in euc-kr +_korean_day = u'\uc77c' # c0cf in euc-kr +_korean_am = u'\uc624\uc804' # bfc0 c0fc in euc-kr +_korean_pm = u'\uc624\ud6c4' # bfc0 c8c4 in euc-kr + +_korean_onblog_date_re = \ + re.compile('(\d{4})%s\s+(\d{2})%s\s+(\d{2})%s\s+(\d{2}):(\d{2}):(\d{2})' % \ + (_korean_year, _korean_month, _korean_day)) +_korean_nate_date_re = \ + re.compile(u'(\d{4})-(\d{2})-(\d{2})\s+(%s|%s)\s+(\d{,2}):(\d{,2}):(\d{,2})' % \ + (_korean_am, _korean_pm)) +def _parse_date_onblog(dateString): + '''Parse a string according to the OnBlog 8-bit date format''' + m = _korean_onblog_date_re.match(dateString) + if not m: return + w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s:%(second)s%(zonediff)s' % \ + {'year': m.group(1), 'month': m.group(2), 'day': m.group(3),\ + 'hour': m.group(4), 'minute': m.group(5), 'second': m.group(6),\ + 'zonediff': '+09:00'} + if _debug: sys.stderr.write('OnBlog date parsed as: %s\n' % w3dtfdate) + return _parse_date_w3dtf(w3dtfdate) +registerDateHandler(_parse_date_onblog) + +def _parse_date_nate(dateString): + '''Parse a string according to the Nate 8-bit date format''' + m = _korean_nate_date_re.match(dateString) + if not m: return + hour = int(m.group(5)) + ampm = m.group(4) + if (ampm == _korean_pm): + hour += 12 + hour = str(hour) + if len(hour) == 1: + hour = '0' + hour + w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s:%(second)s%(zonediff)s' % \ + {'year': m.group(1), 'month': m.group(2), 'day': m.group(3),\ + 'hour': hour, 'minute': m.group(6), 'second': m.group(7),\ + 'zonediff': '+09:00'} + if _debug: sys.stderr.write('Nate date parsed as: %s\n' % w3dtfdate) + return _parse_date_w3dtf(w3dtfdate) +registerDateHandler(_parse_date_nate) + +_mssql_date_re = \ + re.compile('(\d{4})-(\d{2})-(\d{2})\s+(\d{2}):(\d{2}):(\d{2})(\.\d+)?') +def _parse_date_mssql(dateString): + '''Parse a string according to the MS SQL date format''' + m = _mssql_date_re.match(dateString) + if not m: return + w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s:%(second)s%(zonediff)s' % \ + {'year': m.group(1), 'month': m.group(2), 'day': m.group(3),\ + 'hour': m.group(4), 'minute': m.group(5), 'second': m.group(6),\ + 'zonediff': '+09:00'} + if _debug: sys.stderr.write('MS SQL date parsed as: %s\n' % w3dtfdate) + return _parse_date_w3dtf(w3dtfdate) +registerDateHandler(_parse_date_mssql) + +# Unicode strings for Greek date strings +_greek_months = \ + { \ + u'\u0399\u03b1\u03bd': u'Jan', # c9e1ed in iso-8859-7 + u'\u03a6\u03b5\u03b2': u'Feb', # d6e5e2 in iso-8859-7 + u'\u039c\u03ac\u03ce': u'Mar', # ccdcfe in iso-8859-7 + u'\u039c\u03b1\u03ce': u'Mar', # cce1fe in iso-8859-7 + u'\u0391\u03c0\u03c1': u'Apr', # c1f0f1 in iso-8859-7 + u'\u039c\u03ac\u03b9': u'May', # ccdce9 in iso-8859-7 + u'\u039c\u03b1\u03ca': u'May', # cce1fa in iso-8859-7 + u'\u039c\u03b1\u03b9': u'May', # cce1e9 in iso-8859-7 + u'\u0399\u03bf\u03cd\u03bd': u'Jun', # c9effded in iso-8859-7 + u'\u0399\u03bf\u03bd': u'Jun', # c9efed in iso-8859-7 + u'\u0399\u03bf\u03cd\u03bb': u'Jul', # c9effdeb in iso-8859-7 + u'\u0399\u03bf\u03bb': u'Jul', # c9f9eb in iso-8859-7 + u'\u0391\u03cd\u03b3': u'Aug', # c1fde3 in iso-8859-7 + u'\u0391\u03c5\u03b3': u'Aug', # c1f5e3 in iso-8859-7 + u'\u03a3\u03b5\u03c0': u'Sep', # d3e5f0 in iso-8859-7 + u'\u039f\u03ba\u03c4': u'Oct', # cfeaf4 in iso-8859-7 + u'\u039d\u03bf\u03ad': u'Nov', # cdefdd in iso-8859-7 + u'\u039d\u03bf\u03b5': u'Nov', # cdefe5 in iso-8859-7 + u'\u0394\u03b5\u03ba': u'Dec', # c4e5ea in iso-8859-7 + } + +_greek_wdays = \ + { \ + u'\u039a\u03c5\u03c1': u'Sun', # caf5f1 in iso-8859-7 + u'\u0394\u03b5\u03c5': u'Mon', # c4e5f5 in iso-8859-7 + u'\u03a4\u03c1\u03b9': u'Tue', # d4f1e9 in iso-8859-7 + u'\u03a4\u03b5\u03c4': u'Wed', # d4e5f4 in iso-8859-7 + u'\u03a0\u03b5\u03bc': u'Thu', # d0e5ec in iso-8859-7 + u'\u03a0\u03b1\u03c1': u'Fri', # d0e1f1 in iso-8859-7 + u'\u03a3\u03b1\u03b2': u'Sat', # d3e1e2 in iso-8859-7 + } + +_greek_date_format_re = \ + re.compile(u'([^,]+),\s+(\d{2})\s+([^\s]+)\s+(\d{4})\s+(\d{2}):(\d{2}):(\d{2})\s+([^\s]+)') + +def _parse_date_greek(dateString): + '''Parse a string according to a Greek 8-bit date format.''' + m = _greek_date_format_re.match(dateString) + if not m: return + try: + wday = _greek_wdays[m.group(1)] + month = _greek_months[m.group(3)] + except: + return + rfc822date = '%(wday)s, %(day)s %(month)s %(year)s %(hour)s:%(minute)s:%(second)s %(zonediff)s' % \ + {'wday': wday, 'day': m.group(2), 'month': month, 'year': m.group(4),\ + 'hour': m.group(5), 'minute': m.group(6), 'second': m.group(7),\ + 'zonediff': m.group(8)} + if _debug: sys.stderr.write('Greek date parsed as: %s\n' % rfc822date) + return _parse_date_rfc822(rfc822date) +registerDateHandler(_parse_date_greek) + +# Unicode strings for Hungarian date strings +_hungarian_months = \ + { \ + u'janu\u00e1r': u'01', # e1 in iso-8859-2 + u'febru\u00e1ri': u'02', # e1 in iso-8859-2 + u'm\u00e1rcius': u'03', # e1 in iso-8859-2 + u'\u00e1prilis': u'04', # e1 in iso-8859-2 + u'm\u00e1ujus': u'05', # e1 in iso-8859-2 + u'j\u00fanius': u'06', # fa in iso-8859-2 + u'j\u00falius': u'07', # fa in iso-8859-2 + u'augusztus': u'08', + u'szeptember': u'09', + u'okt\u00f3ber': u'10', # f3 in iso-8859-2 + u'november': u'11', + u'december': u'12', + } + +_hungarian_date_format_re = \ + re.compile(u'(\d{4})-([^-]+)-(\d{,2})T(\d{,2}):(\d{2})((\+|-)(\d{,2}:\d{2}))') + +def _parse_date_hungarian(dateString): + '''Parse a string according to a Hungarian 8-bit date format.''' + m = _hungarian_date_format_re.match(dateString) + if not m: return + try: + month = _hungarian_months[m.group(2)] + day = m.group(3) + if len(day) == 1: + day = '0' + day + hour = m.group(4) + if len(hour) == 1: + hour = '0' + hour + except: + return + w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s%(zonediff)s' % \ + {'year': m.group(1), 'month': month, 'day': day,\ + 'hour': hour, 'minute': m.group(5),\ + 'zonediff': m.group(6)} + if _debug: sys.stderr.write('Hungarian date parsed as: %s\n' % w3dtfdate) + return _parse_date_w3dtf(w3dtfdate) +registerDateHandler(_parse_date_hungarian) + +# W3DTF-style date parsing adapted from PyXML xml.utils.iso8601, written by +# Drake and licensed under the Python license. Removed all range checking +# for month, day, hour, minute, and second, since mktime will normalize +# these later +def _parse_date_w3dtf(dateString): + def __extract_date(m): + year = int(m.group('year')) + if year < 100: + year = 100 * int(time.gmtime()[0] / 100) + int(year) + if year < 1000: + return 0, 0, 0 + julian = m.group('julian') + if julian: + julian = int(julian) + month = julian / 30 + 1 + day = julian % 30 + 1 + jday = None + while jday != julian: + t = time.mktime((year, month, day, 0, 0, 0, 0, 0, 0)) + jday = time.gmtime(t)[-2] + diff = abs(jday - julian) + if jday > julian: + if diff < day: + day = day - diff + else: + month = month - 1 + day = 31 + elif jday < julian: + if day + diff < 28: + day = day + diff + else: + month = month + 1 + return year, month, day + month = m.group('month') + day = 1 + if month is None: + month = 1 + else: + month = int(month) + day = m.group('day') + if day: + day = int(day) + else: + day = 1 + return year, month, day + + def __extract_time(m): + if not m: + return 0, 0, 0 + hours = m.group('hours') + if not hours: + return 0, 0, 0 + hours = int(hours) + minutes = int(m.group('minutes')) + seconds = m.group('seconds') + if seconds: + seconds = int(seconds) + else: + seconds = 0 + return hours, minutes, seconds + + def __extract_tzd(m): + '''Return the Time Zone Designator as an offset in seconds from UTC.''' + if not m: + return 0 + tzd = m.group('tzd') + if not tzd: + return 0 + if tzd == 'Z': + return 0 + hours = int(m.group('tzdhours')) + minutes = m.group('tzdminutes') + if minutes: + minutes = int(minutes) + else: + minutes = 0 + offset = (hours*60 + minutes) * 60 + if tzd[0] == '+': + return -offset + return offset + + __date_re = ('(?P\d\d\d\d)' + '(?:(?P-|)' + '(?:(?P\d\d)(?:(?P=dsep)(?P\d\d))?' + '|(?P\d\d\d)))?') + __tzd_re = '(?P[-+](?P\d\d)(?::?(?P\d\d))|Z)' + __tzd_rx = re.compile(__tzd_re) + __time_re = ('(?P\d\d)(?P:|)(?P\d\d)' + '(?:(?P=tsep)(?P\d\d)(?:[.,]\d+)?)?' + + __tzd_re) + __datetime_re = '%s(?:T%s)?' % (__date_re, __time_re) + __datetime_rx = re.compile(__datetime_re) + m = __datetime_rx.match(dateString) + if (m is None) or (m.group() != dateString): return + gmt = __extract_date(m) + __extract_time(m) + (0, 0, 0) + if gmt[0] == 0: return + return time.gmtime(time.mktime(gmt) + __extract_tzd(m) - time.timezone) +registerDateHandler(_parse_date_w3dtf) + +def _parse_date_rfc822(dateString): + '''Parse an RFC822, RFC1123, RFC2822, or asctime-style date''' + data = dateString.split() + if data[0][-1] in (',', '.') or data[0].lower() in rfc822._daynames: + del data[0] + if len(data) == 4: + s = data[3] + i = s.find('+') + if i > 0: + data[3:] = [s[:i], s[i+1:]] + else: + data.append('') + dateString = " ".join(data) + # Account for the Etc/GMT timezone by stripping 'Etc/' + elif len(data) == 5 and data[4].lower().startswith('etc/'): + data[4] = data[4][4:] + dateString = " ".join(data) + if len(data) < 5: + dateString += ' 00:00:00 GMT' + tm = rfc822.parsedate_tz(dateString) + if tm: + return time.gmtime(rfc822.mktime_tz(tm)) +# rfc822.py defines several time zones, but we define some extra ones. +# 'ET' is equivalent to 'EST', etc. +_additional_timezones = {'AT': -400, 'ET': -500, 'CT': -600, 'MT': -700, 'PT': -800} +rfc822._timezones.update(_additional_timezones) +registerDateHandler(_parse_date_rfc822) + +def _parse_date_perforce(aDateString): + """parse a date in yyyy/mm/dd hh:mm:ss TTT format""" + # Fri, 2006/09/15 08:19:53 EDT + _my_date_pattern = re.compile( \ + r'(\w{,3}), (\d{,4})/(\d{,2})/(\d{2}) (\d{,2}):(\d{2}):(\d{2}) (\w{,3})') + + dow, year, month, day, hour, minute, second, tz = \ + _my_date_pattern.search(aDateString).groups() + months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] + dateString = "%s, %s %s %s %s:%s:%s %s" % (dow, day, months[int(month) - 1], year, hour, minute, second, tz) + tm = rfc822.parsedate_tz(dateString) + if tm: + return time.gmtime(rfc822.mktime_tz(tm)) +registerDateHandler(_parse_date_perforce) + +def _parse_date(dateString): + '''Parses a variety of date formats into a 9-tuple in GMT''' + for handler in _date_handlers: + try: + date9tuple = handler(dateString) + if not date9tuple: continue + if len(date9tuple) != 9: + if _debug: sys.stderr.write('date handler function must return 9-tuple\n') + raise ValueError + map(int, date9tuple) + return date9tuple + except Exception, e: + if _debug: sys.stderr.write('%s raised %s\n' % (handler.__name__, repr(e))) + pass + return None + +def _getCharacterEncoding(http_headers, xml_data): + '''Get the character encoding of the XML document + + http_headers is a dictionary + xml_data is a raw string (not Unicode) + + This is so much trickier than it sounds, it's not even funny. + According to RFC 3023 ('XML Media Types'), if the HTTP Content-Type + is application/xml, application/*+xml, + application/xml-external-parsed-entity, or application/xml-dtd, + the encoding given in the charset parameter of the HTTP Content-Type + takes precedence over the encoding given in the XML prefix within the + document, and defaults to 'utf-8' if neither are specified. But, if + the HTTP Content-Type is text/xml, text/*+xml, or + text/xml-external-parsed-entity, the encoding given in the XML prefix + within the document is ALWAYS IGNORED and only the encoding given in + the charset parameter of the HTTP Content-Type header should be + respected, and it defaults to 'us-ascii' if not specified. + + Furthermore, discussion on the atom-syntax mailing list with the + author of RFC 3023 leads me to the conclusion that any document + served with a Content-Type of text/* and no charset parameter + must be treated as us-ascii. (We now do this.) And also that it + must always be flagged as non-well-formed. (We now do this too.) + + If Content-Type is unspecified (input was local file or non-HTTP source) + or unrecognized (server just got it totally wrong), then go by the + encoding given in the XML prefix of the document and default to + 'iso-8859-1' as per the HTTP specification (RFC 2616). + + Then, assuming we didn't find a character encoding in the HTTP headers + (and the HTTP Content-type allowed us to look in the body), we need + to sniff the first few bytes of the XML data and try to determine + whether the encoding is ASCII-compatible. Section F of the XML + specification shows the way here: + http://www.w3.org/TR/REC-xml/#sec-guessing-no-ext-info + + If the sniffed encoding is not ASCII-compatible, we need to make it + ASCII compatible so that we can sniff further into the XML declaration + to find the encoding attribute, which will tell us the true encoding. + + Of course, none of this guarantees that we will be able to parse the + feed in the declared character encoding (assuming it was declared + correctly, which many are not). CJKCodecs and iconv_codec help a lot; + you should definitely install them if you can. + http://cjkpython.i18n.org/ + ''' + + def _parseHTTPContentType(content_type): + '''takes HTTP Content-Type header and returns (content type, charset) + + If no charset is specified, returns (content type, '') + If no content type is specified, returns ('', '') + Both return parameters are guaranteed to be lowercase strings + ''' + content_type = content_type or '' + content_type, params = cgi.parse_header(content_type) + return content_type, params.get('charset', '').replace("'", '') + + sniffed_xml_encoding = '' + xml_encoding = '' + true_encoding = '' + http_content_type, http_encoding = _parseHTTPContentType(http_headers.get('content-type', http_headers.get('Content-type'))) + # Must sniff for non-ASCII-compatible character encodings before + # searching for XML declaration. This heuristic is defined in + # section F of the XML specification: + # http://www.w3.org/TR/REC-xml/#sec-guessing-no-ext-info + try: + if xml_data[:4] == _l2bytes([0x4c, 0x6f, 0xa7, 0x94]): + # EBCDIC + xml_data = _ebcdic_to_ascii(xml_data) + elif xml_data[:4] == _l2bytes([0x00, 0x3c, 0x00, 0x3f]): + # UTF-16BE + sniffed_xml_encoding = 'utf-16be' + xml_data = unicode(xml_data, 'utf-16be').encode('utf-8') + elif (len(xml_data) >= 4) and (xml_data[:2] == _l2bytes([0xfe, 0xff])) and (xml_data[2:4] != _l2bytes([0x00, 0x00])): + # UTF-16BE with BOM + sniffed_xml_encoding = 'utf-16be' + xml_data = unicode(xml_data[2:], 'utf-16be').encode('utf-8') + elif xml_data[:4] == _l2bytes([0x3c, 0x00, 0x3f, 0x00]): + # UTF-16LE + sniffed_xml_encoding = 'utf-16le' + xml_data = unicode(xml_data, 'utf-16le').encode('utf-8') + elif (len(xml_data) >= 4) and (xml_data[:2] == _l2bytes([0xff, 0xfe])) and (xml_data[2:4] != _l2bytes([0x00, 0x00])): + # UTF-16LE with BOM + sniffed_xml_encoding = 'utf-16le' + xml_data = unicode(xml_data[2:], 'utf-16le').encode('utf-8') + elif xml_data[:4] == _l2bytes([0x00, 0x00, 0x00, 0x3c]): + # UTF-32BE + sniffed_xml_encoding = 'utf-32be' + xml_data = unicode(xml_data, 'utf-32be').encode('utf-8') + elif xml_data[:4] == _l2bytes([0x3c, 0x00, 0x00, 0x00]): + # UTF-32LE + sniffed_xml_encoding = 'utf-32le' + xml_data = unicode(xml_data, 'utf-32le').encode('utf-8') + elif xml_data[:4] == _l2bytes([0x00, 0x00, 0xfe, 0xff]): + # UTF-32BE with BOM + sniffed_xml_encoding = 'utf-32be' + xml_data = unicode(xml_data[4:], 'utf-32be').encode('utf-8') + elif xml_data[:4] == _l2bytes([0xff, 0xfe, 0x00, 0x00]): + # UTF-32LE with BOM + sniffed_xml_encoding = 'utf-32le' + xml_data = unicode(xml_data[4:], 'utf-32le').encode('utf-8') + elif xml_data[:3] == _l2bytes([0xef, 0xbb, 0xbf]): + # UTF-8 with BOM + sniffed_xml_encoding = 'utf-8' + xml_data = unicode(xml_data[3:], 'utf-8').encode('utf-8') + else: + # ASCII-compatible + pass + xml_encoding_match = re.compile(_s2bytes('^<\?.*encoding=[\'"](.*?)[\'"].*\?>')).match(xml_data) + except: + xml_encoding_match = None + if xml_encoding_match: + xml_encoding = xml_encoding_match.groups()[0].decode('utf-8').lower() + if sniffed_xml_encoding and (xml_encoding in ('iso-10646-ucs-2', 'ucs-2', 'csunicode', 'iso-10646-ucs-4', 'ucs-4', 'csucs4', 'utf-16', 'utf-32', 'utf_16', 'utf_32', 'utf16', 'u16')): + xml_encoding = sniffed_xml_encoding + acceptable_content_type = 0 + application_content_types = ('application/xml', 'application/xml-dtd', 'application/xml-external-parsed-entity') + text_content_types = ('text/xml', 'text/xml-external-parsed-entity') + if (http_content_type in application_content_types) or \ + (http_content_type.startswith('application/') and http_content_type.endswith('+xml')): + acceptable_content_type = 1 + true_encoding = http_encoding or xml_encoding or 'utf-8' + elif (http_content_type in text_content_types) or \ + (http_content_type.startswith('text/')) and http_content_type.endswith('+xml'): + acceptable_content_type = 1 + true_encoding = http_encoding or 'us-ascii' + elif http_content_type.startswith('text/'): + true_encoding = http_encoding or 'us-ascii' + elif http_headers and (not (http_headers.has_key('content-type') or http_headers.has_key('Content-type'))): + true_encoding = xml_encoding or 'iso-8859-1' + else: + true_encoding = xml_encoding or 'utf-8' + # some feeds claim to be gb2312 but are actually gb18030. + # apparently MSIE and Firefox both do the following switch: + if true_encoding.lower() == 'gb2312': + true_encoding = 'gb18030' + return true_encoding, http_encoding, xml_encoding, sniffed_xml_encoding, acceptable_content_type + +def _toUTF8(data, encoding): + '''Changes an XML data stream on the fly to specify a new encoding + + data is a raw sequence of bytes (not Unicode) that is presumed to be in %encoding already + encoding is a string recognized by encodings.aliases + ''' + if _debug: sys.stderr.write('entering _toUTF8, trying encoding %s\n' % encoding) + # strip Byte Order Mark (if present) + if (len(data) >= 4) and (data[:2] == _l2bytes([0xfe, 0xff])) and (data[2:4] != _l2bytes([0x00, 0x00])): + if _debug: + sys.stderr.write('stripping BOM\n') + if encoding != 'utf-16be': + sys.stderr.write('trying utf-16be instead\n') + encoding = 'utf-16be' + data = data[2:] + elif (len(data) >= 4) and (data[:2] == _l2bytes([0xff, 0xfe])) and (data[2:4] != _l2bytes([0x00, 0x00])): + if _debug: + sys.stderr.write('stripping BOM\n') + if encoding != 'utf-16le': + sys.stderr.write('trying utf-16le instead\n') + encoding = 'utf-16le' + data = data[2:] + elif data[:3] == _l2bytes([0xef, 0xbb, 0xbf]): + if _debug: + sys.stderr.write('stripping BOM\n') + if encoding != 'utf-8': + sys.stderr.write('trying utf-8 instead\n') + encoding = 'utf-8' + data = data[3:] + elif data[:4] == _l2bytes([0x00, 0x00, 0xfe, 0xff]): + if _debug: + sys.stderr.write('stripping BOM\n') + if encoding != 'utf-32be': + sys.stderr.write('trying utf-32be instead\n') + encoding = 'utf-32be' + data = data[4:] + elif data[:4] == _l2bytes([0xff, 0xfe, 0x00, 0x00]): + if _debug: + sys.stderr.write('stripping BOM\n') + if encoding != 'utf-32le': + sys.stderr.write('trying utf-32le instead\n') + encoding = 'utf-32le' + data = data[4:] + newdata = unicode(data, encoding) + if _debug: sys.stderr.write('successfully converted %s data to unicode\n' % encoding) + declmatch = re.compile('^<\?xml[^>]*?>') + newdecl = '''''' + if declmatch.search(newdata): + newdata = declmatch.sub(newdecl, newdata) + else: + newdata = newdecl + u'\n' + newdata + return newdata.encode('utf-8') + +def _stripDoctype(data): + '''Strips DOCTYPE from XML document, returns (rss_version, stripped_data) + + rss_version may be 'rss091n' or None + stripped_data is the same XML document, minus the DOCTYPE + ''' + start = re.search(_s2bytes('<\w'), data) + start = start and start.start() or -1 + head,data = data[:start+1], data[start+1:] + + entity_pattern = re.compile(_s2bytes(r'^\s*]*?)>'), re.MULTILINE) + entity_results=entity_pattern.findall(head) + head = entity_pattern.sub(_s2bytes(''), head) + doctype_pattern = re.compile(_s2bytes(r'^\s*]*?)>'), re.MULTILINE) + doctype_results = doctype_pattern.findall(head) + doctype = doctype_results and doctype_results[0] or _s2bytes('') + if doctype.lower().count(_s2bytes('netscape')): + version = 'rss091n' + else: + version = None + + # only allow in 'safe' inline entity definitions + replacement=_s2bytes('') + if len(doctype_results)==1 and entity_results: + safe_pattern=re.compile(_s2bytes('\s+(\w+)\s+"(&#\w+;|[^&"]*)"')) + safe_entities=filter(lambda e: safe_pattern.match(e),entity_results) + if safe_entities: + replacement=_s2bytes('\n \n]>') + data = doctype_pattern.sub(replacement, head) + data + + return version, data, dict(replacement and [(k.decode('utf-8'), v.decode('utf-8')) for k, v in safe_pattern.findall(replacement)]) + +def parse(url_file_stream_or_string, etag=None, modified=None, agent=None, referrer=None, handlers=[], request_headers={}, response_headers={}): + '''Parse a feed from a URL, file, stream, or string. + + request_headers, if given, is a dict from http header name to value to add + to the request; this overrides internally generated values. + ''' + result = FeedParserDict() + result['feed'] = FeedParserDict() + result['entries'] = [] + if _XML_AVAILABLE: + result['bozo'] = 0 + if not isinstance(handlers, list): + handlers = [handlers] + try: + f = _open_resource(url_file_stream_or_string, etag, modified, agent, referrer, handlers, request_headers) + data = f.read() + except Exception, e: + result['bozo'] = 1 + result['bozo_exception'] = e + data = None + f = None + + if hasattr(f, 'headers'): + result['headers'] = dict(f.headers) + # overwrite existing headers using response_headers + if 'headers' in result: + result['headers'].update(response_headers) + elif response_headers: + result['headers'] = copy.deepcopy(response_headers) + + # if feed is gzip-compressed, decompress it + if f and data and 'headers' in result: + if gzip and result['headers'].get('content-encoding') == 'gzip': + try: + data = gzip.GzipFile(fileobj=_StringIO(data)).read() + except Exception, e: + # Some feeds claim to be gzipped but they're not, so + # we get garbage. Ideally, we should re-request the + # feed without the 'Accept-encoding: gzip' header, + # but we don't. + result['bozo'] = 1 + result['bozo_exception'] = e + data = '' + elif zlib and result['headers'].get('content-encoding') == 'deflate': + try: + data = zlib.decompress(data, -zlib.MAX_WBITS) + except Exception, e: + result['bozo'] = 1 + result['bozo_exception'] = e + data = '' + + # save HTTP headers + if 'headers' in result: + if 'etag' in result['headers'] or 'ETag' in result['headers']: + etag = result['headers'].get('etag', result['headers'].get('ETag')) + if etag: + result['etag'] = etag + if 'last-modified' in result['headers'] or 'Last-Modified' in result['headers']: + modified = result['headers'].get('last-modified', result['headers'].get('Last-Modified')) + if modified: + result['modified'] = _parse_date(modified) + if hasattr(f, 'url'): + result['href'] = f.url + result['status'] = 200 + if hasattr(f, 'status'): + result['status'] = f.status + if hasattr(f, 'close'): + f.close() + + # there are four encodings to keep track of: + # - http_encoding is the encoding declared in the Content-Type HTTP header + # - xml_encoding is the encoding declared in the [:][ :.] +GNTP_INFO_LINE = re.compile( + 'GNTP/(?P\d+\.\d+) (?PREGISTER|NOTIFY|SUBSCRIBE|\-OK|\-ERROR)' + + ' (?P[A-Z0-9]+(:(?P[A-F0-9]+))?) ?' + + '((?P[A-Z0-9]+):(?P[A-F0-9]+).(?P[A-F0-9]+))?\r\n', + re.IGNORECASE +) + +GNTP_INFO_LINE_SHORT = re.compile( + 'GNTP/(?P\d+\.\d+) (?PREGISTER|NOTIFY|SUBSCRIBE|\-OK|\-ERROR)', + re.IGNORECASE +) + +GNTP_HEADER = re.compile('([\w-]+):(.+)') + +GNTP_EOL = gntp.shim.b('\r\n') +GNTP_SEP = gntp.shim.b(': ') + + +class _GNTPBuffer(gntp.shim.StringIO): + """GNTP Buffer class""" + def writeln(self, value=None): + if value: + self.write(gntp.shim.b(value)) + self.write(GNTP_EOL) + + def writeheader(self, key, value): + if not isinstance(value, str): + value = str(value) + self.write(gntp.shim.b(key)) + self.write(GNTP_SEP) + self.write(gntp.shim.b(value)) + self.write(GNTP_EOL) + + +class _GNTPBase(object): + """Base initilization + + :param string messagetype: GNTP Message type + :param string version: GNTP Protocol version + :param string encription: Encryption protocol + """ + def __init__(self, messagetype=None, version='1.0', encryption=None): + self.info = { + 'version': version, + 'messagetype': messagetype, + 'encryptionAlgorithmID': encryption + } + self.hash_algo = { + 'MD5': hashlib.md5, + 'SHA1': hashlib.sha1, + 'SHA256': hashlib.sha256, + 'SHA512': hashlib.sha512, + } + self.headers = {} + self.resources = {} + + def __str__(self): + return self.encode() + + def _parse_info(self, data): + """Parse the first line of a GNTP message to get security and other info values + + :param string data: GNTP Message + :return dict: Parsed GNTP Info line + """ + + match = GNTP_INFO_LINE.match(data) + + if not match: + raise errors.ParseError('ERROR_PARSING_INFO_LINE') + + info = match.groupdict() + if info['encryptionAlgorithmID'] == 'NONE': + info['encryptionAlgorithmID'] = None + + return info + + def set_password(self, password, encryptAlgo='MD5'): + """Set a password for a GNTP Message + + :param string password: Null to clear password + :param string encryptAlgo: Supports MD5, SHA1, SHA256, SHA512 + """ + if not password: + self.info['encryptionAlgorithmID'] = None + self.info['keyHashAlgorithm'] = None + return + + self.password = gntp.shim.b(password) + self.encryptAlgo = encryptAlgo.upper() + + if not self.encryptAlgo in self.hash_algo: + raise errors.UnsupportedError('INVALID HASH "%s"' % self.encryptAlgo) + + hashfunction = self.hash_algo.get(self.encryptAlgo) + + password = password.encode('utf8') + seed = time.ctime().encode('utf8') + salt = hashfunction(seed).hexdigest() + saltHash = hashfunction(seed).digest() + keyBasis = password + saltHash + key = hashfunction(keyBasis).digest() + keyHash = hashfunction(key).hexdigest() + + self.info['keyHashAlgorithmID'] = self.encryptAlgo + self.info['keyHash'] = keyHash.upper() + self.info['salt'] = salt.upper() + + def _decode_hex(self, value): + """Helper function to decode hex string to `proper` hex string + + :param string value: Human readable hex string + :return string: Hex string + """ + result = '' + for i in range(0, len(value), 2): + tmp = int(value[i:i + 2], 16) + result += chr(tmp) + return result + + def _decode_binary(self, rawIdentifier, identifier): + rawIdentifier += '\r\n\r\n' + dataLength = int(identifier['Length']) + pointerStart = self.raw.find(rawIdentifier) + len(rawIdentifier) + pointerEnd = pointerStart + dataLength + data = self.raw[pointerStart:pointerEnd] + if not len(data) == dataLength: + raise errors.ParseError('INVALID_DATA_LENGTH Expected: %s Recieved %s' % (dataLength, len(data))) + return data + + def _validate_password(self, password): + """Validate GNTP Message against stored password""" + self.password = password + if password is None: + raise errors.AuthError('Missing password') + keyHash = self.info.get('keyHash', None) + if keyHash is None and self.password is None: + return True + if keyHash is None: + raise errors.AuthError('Invalid keyHash') + if self.password is None: + raise errors.AuthError('Missing password') + + keyHashAlgorithmID = self.info.get('keyHashAlgorithmID','MD5') + + password = self.password.encode('utf8') + saltHash = self._decode_hex(self.info['salt']) + + keyBasis = password + saltHash + self.key = self.hash_algo[keyHashAlgorithmID](keyBasis).digest() + keyHash = self.hash_algo[keyHashAlgorithmID](self.key).hexdigest() + + if not keyHash.upper() == self.info['keyHash'].upper(): + raise errors.AuthError('Invalid Hash') + return True + + def validate(self): + """Verify required headers""" + for header in self._requiredHeaders: + if not self.headers.get(header, False): + raise errors.ParseError('Missing Notification Header: ' + header) + + def _format_info(self): + """Generate info line for GNTP Message + + :return string: + """ + info = 'GNTP/%s %s' % ( + self.info.get('version'), + self.info.get('messagetype'), + ) + if self.info.get('encryptionAlgorithmID', None): + info += ' %s:%s' % ( + self.info.get('encryptionAlgorithmID'), + self.info.get('ivValue'), + ) + else: + info += ' NONE' + + if self.info.get('keyHashAlgorithmID', None): + info += ' %s:%s.%s' % ( + self.info.get('keyHashAlgorithmID'), + self.info.get('keyHash'), + self.info.get('salt') + ) + + return info + + def _parse_dict(self, data): + """Helper function to parse blocks of GNTP headers into a dictionary + + :param string data: + :return dict: Dictionary of parsed GNTP Headers + """ + d = {} + for line in data.split('\r\n'): + match = GNTP_HEADER.match(line) + if not match: + continue + + key = match.group(1).strip() + val = match.group(2).strip() + d[key] = val + return d + + def add_header(self, key, value): + self.headers[key] = value + + def add_resource(self, data): + """Add binary resource + + :param string data: Binary Data + """ + data = gntp.shim.b(data) + identifier = hashlib.md5(data).hexdigest() + self.resources[identifier] = data + return 'x-growl-resource://%s' % identifier + + def decode(self, data, password=None): + """Decode GNTP Message + + :param string data: + """ + self.password = password + self.raw = gntp.shim.u(data) + parts = self.raw.split('\r\n\r\n') + self.info = self._parse_info(self.raw) + self.headers = self._parse_dict(parts[0]) + + def encode(self): + """Encode a generic GNTP Message + + :return string: GNTP Message ready to be sent. Returned as a byte string + """ + + buff = _GNTPBuffer() + + buff.writeln(self._format_info()) + + #Headers + for k, v in self.headers.items(): + buff.writeheader(k, v) + buff.writeln() + + #Resources + for resource, data in self.resources.items(): + buff.writeheader('Identifier', resource) + buff.writeheader('Length', len(data)) + buff.writeln() + buff.write(data) + buff.writeln() + buff.writeln() + + return buff.getvalue() + + +class GNTPRegister(_GNTPBase): + """Represents a GNTP Registration Command + + :param string data: (Optional) See decode() + :param string password: (Optional) Password to use while encoding/decoding messages + """ + _requiredHeaders = [ + 'Application-Name', + 'Notifications-Count' + ] + _requiredNotificationHeaders = ['Notification-Name'] + + def __init__(self, data=None, password=None): + _GNTPBase.__init__(self, 'REGISTER') + self.notifications = [] + + if data: + self.decode(data, password) + else: + self.set_password(password) + self.add_header('Application-Name', 'pygntp') + self.add_header('Notifications-Count', 0) + + def validate(self): + '''Validate required headers and validate notification headers''' + for header in self._requiredHeaders: + if not self.headers.get(header, False): + raise errors.ParseError('Missing Registration Header: ' + header) + for notice in self.notifications: + for header in self._requiredNotificationHeaders: + if not notice.get(header, False): + raise errors.ParseError('Missing Notification Header: ' + header) + + def decode(self, data, password): + """Decode existing GNTP Registration message + + :param string data: Message to decode + """ + self.raw = gntp.shim.u(data) + parts = self.raw.split('\r\n\r\n') + self.info = self._parse_info(self.raw) + self._validate_password(password) + self.headers = self._parse_dict(parts[0]) + + for i, part in enumerate(parts): + if i == 0: + continue # Skip Header + if part.strip() == '': + continue + notice = self._parse_dict(part) + if notice.get('Notification-Name', False): + self.notifications.append(notice) + elif notice.get('Identifier', False): + notice['Data'] = self._decode_binary(part, notice) + #open('register.png','wblol').write(notice['Data']) + self.resources[notice.get('Identifier')] = notice + + def add_notification(self, name, enabled=True): + """Add new Notification to Registration message + + :param string name: Notification Name + :param boolean enabled: Enable this notification by default + """ + notice = {} + notice['Notification-Name'] = name + notice['Notification-Enabled'] = enabled + + self.notifications.append(notice) + self.add_header('Notifications-Count', len(self.notifications)) + + def encode(self): + """Encode a GNTP Registration Message + + :return string: Encoded GNTP Registration message. Returned as a byte string + """ + + buff = _GNTPBuffer() + + buff.writeln(self._format_info()) + + #Headers + for k, v in self.headers.items(): + buff.writeheader(k, v) + buff.writeln() + + #Notifications + if len(self.notifications) > 0: + for notice in self.notifications: + for k, v in notice.items(): + buff.writeheader(k, v) + buff.writeln() + + #Resources + for resource, data in self.resources.items(): + buff.writeheader('Identifier', resource) + buff.writeheader('Length', len(data)) + buff.writeln() + buff.write(data) + buff.writeln() + buff.writeln() + + return buff.getvalue() + + +class GNTPNotice(_GNTPBase): + """Represents a GNTP Notification Command + + :param string data: (Optional) See decode() + :param string app: (Optional) Set Application-Name + :param string name: (Optional) Set Notification-Name + :param string title: (Optional) Set Notification Title + :param string password: (Optional) Password to use while encoding/decoding messages + """ + _requiredHeaders = [ + 'Application-Name', + 'Notification-Name', + 'Notification-Title' + ] + + def __init__(self, data=None, app=None, name=None, title=None, password=None): + _GNTPBase.__init__(self, 'NOTIFY') + + if data: + self.decode(data, password) + else: + self.set_password(password) + if app: + self.add_header('Application-Name', app) + if name: + self.add_header('Notification-Name', name) + if title: + self.add_header('Notification-Title', title) + + def decode(self, data, password): + """Decode existing GNTP Notification message + + :param string data: Message to decode. + """ + self.raw = gntp.shim.u(data) + parts = self.raw.split('\r\n\r\n') + self.info = self._parse_info(self.raw) + self._validate_password(password) + self.headers = self._parse_dict(parts[0]) + + for i, part in enumerate(parts): + if i == 0: + continue # Skip Header + if part.strip() == '': + continue + notice = self._parse_dict(part) + if notice.get('Identifier', False): + notice['Data'] = self._decode_binary(part, notice) + #open('notice.png','wblol').write(notice['Data']) + self.resources[notice.get('Identifier')] = notice + + +class GNTPSubscribe(_GNTPBase): + """Represents a GNTP Subscribe Command + + :param string data: (Optional) See decode() + :param string password: (Optional) Password to use while encoding/decoding messages + """ + _requiredHeaders = [ + 'Subscriber-ID', + 'Subscriber-Name', + ] + + def __init__(self, data=None, password=None): + _GNTPBase.__init__(self, 'SUBSCRIBE') + if data: + self.decode(data, password) + else: + self.set_password(password) + + +class GNTPOK(_GNTPBase): + """Represents a GNTP OK Response + + :param string data: (Optional) See _GNTPResponse.decode() + :param string action: (Optional) Set type of action the OK Response is for + """ + _requiredHeaders = ['Response-Action'] + + def __init__(self, data=None, action=None): + _GNTPBase.__init__(self, '-OK') + if data: + self.decode(data) + if action: + self.add_header('Response-Action', action) + + +class GNTPError(_GNTPBase): + """Represents a GNTP Error response + + :param string data: (Optional) See _GNTPResponse.decode() + :param string errorcode: (Optional) Error code + :param string errordesc: (Optional) Error Description + """ + _requiredHeaders = ['Error-Code', 'Error-Description'] + + def __init__(self, data=None, errorcode=None, errordesc=None): + _GNTPBase.__init__(self, '-ERROR') + if data: + self.decode(data) + if errorcode: + self.add_header('Error-Code', errorcode) + self.add_header('Error-Description', errordesc) + + def error(self): + return (self.headers.get('Error-Code', None), + self.headers.get('Error-Description', None)) + + +def parse_gntp(data, password=None): + """Attempt to parse a message as a GNTP message + + :param string data: Message to be parsed + :param string password: Optional password to be used to verify the message + """ + data = gntp.shim.u(data) + match = GNTP_INFO_LINE_SHORT.match(data) + if not match: + raise errors.ParseError('INVALID_GNTP_INFO') + info = match.groupdict() + if info['messagetype'] == 'REGISTER': + return GNTPRegister(data, password=password) + elif info['messagetype'] == 'NOTIFY': + return GNTPNotice(data, password=password) + elif info['messagetype'] == 'SUBSCRIBE': + return GNTPSubscribe(data, password=password) + elif info['messagetype'] == '-OK': + return GNTPOK(data) + elif info['messagetype'] == '-ERROR': + return GNTPError(data) + raise errors.ParseError('INVALID_GNTP_MESSAGE') diff --git a/lib/gntp/errors.py b/lib/gntp/errors.py new file mode 100644 index 00000000..c006fd68 --- /dev/null +++ b/lib/gntp/errors.py @@ -0,0 +1,25 @@ +# Copyright: 2013 Paul Traylor +# These sources are released under the terms of the MIT license: see LICENSE + +class BaseError(Exception): + pass + + +class ParseError(BaseError): + errorcode = 500 + errordesc = 'Error parsing the message' + + +class AuthError(BaseError): + errorcode = 400 + errordesc = 'Error with authorization' + + +class UnsupportedError(BaseError): + errorcode = 500 + errordesc = 'Currently unsupported by gntp.py' + + +class NetworkError(BaseError): + errorcode = 500 + errordesc = "Error connecting to growl server" diff --git a/lib/gntp/notifier.py b/lib/gntp/notifier.py new file mode 100644 index 00000000..1719ecdf --- /dev/null +++ b/lib/gntp/notifier.py @@ -0,0 +1,265 @@ +# Copyright: 2013 Paul Traylor +# These sources are released under the terms of the MIT license: see LICENSE + +""" +The gntp.notifier module is provided as a simple way to send notifications +using GNTP + +.. note:: + This class is intended to mostly mirror the older Python bindings such + that you should be able to replace instances of the old bindings with + this class. + `Original Python bindings `_ + +""" +import logging +import platform +import socket +import sys + +from gntp.version import __version__ +import gntp.core +import gntp.errors as errors +import gntp.shim + +__all__ = [ + 'mini', + 'GrowlNotifier', +] + +logger = logging.getLogger(__name__) + + +class GrowlNotifier(object): + """Helper class to simplfy sending Growl messages + + :param string applicationName: Sending application name + :param list notification: List of valid notifications + :param list defaultNotifications: List of notifications that should be enabled + by default + :param string applicationIcon: Icon URL + :param string hostname: Remote host + :param integer port: Remote port + """ + + passwordHash = 'MD5' + socketTimeout = 3 + + def __init__(self, applicationName='Python GNTP', notifications=[], + defaultNotifications=None, applicationIcon=None, hostname='localhost', + password=None, port=23053): + + self.applicationName = applicationName + self.notifications = list(notifications) + if defaultNotifications: + self.defaultNotifications = list(defaultNotifications) + else: + self.defaultNotifications = self.notifications + self.applicationIcon = applicationIcon + + self.password = password + self.hostname = hostname + self.port = int(port) + + def _checkIcon(self, data): + ''' + Check the icon to see if it's valid + + If it's a simple URL icon, then we return True. If it's a data icon + then we return False + ''' + logger.info('Checking icon') + return gntp.shim.u(data).startswith('http') + + def register(self): + """Send GNTP Registration + + .. warning:: + Before sending notifications to Growl, you need to have + sent a registration message at least once + """ + logger.info('Sending registration to %s:%s', self.hostname, self.port) + register = gntp.core.GNTPRegister() + register.add_header('Application-Name', self.applicationName) + for notification in self.notifications: + enabled = notification in self.defaultNotifications + register.add_notification(notification, enabled) + if self.applicationIcon: + if self._checkIcon(self.applicationIcon): + register.add_header('Application-Icon', self.applicationIcon) + else: + resource = register.add_resource(self.applicationIcon) + register.add_header('Application-Icon', resource) + if self.password: + register.set_password(self.password, self.passwordHash) + self.add_origin_info(register) + self.register_hook(register) + return self._send('register', register) + + def notify(self, noteType, title, description, icon=None, sticky=False, + priority=None, callback=None, identifier=None, custom={}): + """Send a GNTP notifications + + .. warning:: + Must have registered with growl beforehand or messages will be ignored + + :param string noteType: One of the notification names registered earlier + :param string title: Notification title (usually displayed on the notification) + :param string description: The main content of the notification + :param string icon: Icon URL path + :param boolean sticky: Sticky notification + :param integer priority: Message priority level from -2 to 2 + :param string callback: URL callback + :param dict custom: Custom attributes. Key names should be prefixed with X- + according to the spec but this is not enforced by this class + + .. warning:: + For now, only URL callbacks are supported. In the future, the + callback argument will also support a function + """ + logger.info('Sending notification [%s] to %s:%s', noteType, self.hostname, self.port) + assert noteType in self.notifications + notice = gntp.core.GNTPNotice() + notice.add_header('Application-Name', self.applicationName) + notice.add_header('Notification-Name', noteType) + notice.add_header('Notification-Title', title) + if self.password: + notice.set_password(self.password, self.passwordHash) + if sticky: + notice.add_header('Notification-Sticky', sticky) + if priority: + notice.add_header('Notification-Priority', priority) + if icon: + if self._checkIcon(icon): + notice.add_header('Notification-Icon', icon) + else: + resource = notice.add_resource(icon) + notice.add_header('Notification-Icon', resource) + + if description: + notice.add_header('Notification-Text', description) + if callback: + notice.add_header('Notification-Callback-Target', callback) + if identifier: + notice.add_header('Notification-Coalescing-ID', identifier) + + for key in custom: + notice.add_header(key, custom[key]) + + self.add_origin_info(notice) + self.notify_hook(notice) + + return self._send('notify', notice) + + def subscribe(self, id, name, port): + """Send a Subscribe request to a remote machine""" + sub = gntp.core.GNTPSubscribe() + sub.add_header('Subscriber-ID', id) + sub.add_header('Subscriber-Name', name) + sub.add_header('Subscriber-Port', port) + if self.password: + sub.set_password(self.password, self.passwordHash) + + self.add_origin_info(sub) + self.subscribe_hook(sub) + + return self._send('subscribe', sub) + + def add_origin_info(self, packet): + """Add optional Origin headers to message""" + packet.add_header('Origin-Machine-Name', platform.node()) + packet.add_header('Origin-Software-Name', 'gntp.py') + packet.add_header('Origin-Software-Version', __version__) + packet.add_header('Origin-Platform-Name', platform.system()) + packet.add_header('Origin-Platform-Version', platform.platform()) + + def register_hook(self, packet): + pass + + def notify_hook(self, packet): + pass + + def subscribe_hook(self, packet): + pass + + def _send(self, messagetype, packet): + """Send the GNTP Packet""" + + packet.validate() + data = packet.encode() + + logger.debug('To : %s:%s <%s>\n%s', self.hostname, self.port, packet.__class__, data) + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(self.socketTimeout) + try: + s.connect((self.hostname, self.port)) + s.send(data) + recv_data = s.recv(1024) + while not recv_data.endswith(gntp.shim.b("\r\n\r\n")): + recv_data += s.recv(1024) + except socket.error: + # Python2.5 and Python3 compatibile exception + exc = sys.exc_info()[1] + raise errors.NetworkError(exc) + + response = gntp.core.parse_gntp(recv_data) + s.close() + + logger.debug('From : %s:%s <%s>\n%s', self.hostname, self.port, response.__class__, response) + + if type(response) == gntp.core.GNTPOK: + return True + logger.error('Invalid response: %s', response.error()) + return response.error() + + +def mini(description, applicationName='PythonMini', noteType="Message", + title="Mini Message", applicationIcon=None, hostname='localhost', + password=None, port=23053, sticky=False, priority=None, + callback=None, notificationIcon=None, identifier=None, + notifierFactory=GrowlNotifier): + """Single notification function + + Simple notification function in one line. Has only one required parameter + and attempts to use reasonable defaults for everything else + :param string description: Notification message + + .. warning:: + For now, only URL callbacks are supported. In the future, the + callback argument will also support a function + """ + try: + growl = notifierFactory( + applicationName=applicationName, + notifications=[noteType], + defaultNotifications=[noteType], + applicationIcon=applicationIcon, + hostname=hostname, + password=password, + port=port, + ) + result = growl.register() + if result is not True: + return result + + return growl.notify( + noteType=noteType, + title=title, + description=description, + icon=notificationIcon, + sticky=sticky, + priority=priority, + callback=callback, + identifier=identifier, + ) + except Exception: + # We want the "mini" function to be simple and swallow Exceptions + # in order to be less invasive + logger.exception("Growl error") + +if __name__ == '__main__': + # If we're running this module directly we're likely running it as a test + # so extra debugging is useful + logging.basicConfig(level=logging.INFO) + mini('Testing mini notification') diff --git a/lib/gntp/shim.py b/lib/gntp/shim.py new file mode 100644 index 00000000..3a387828 --- /dev/null +++ b/lib/gntp/shim.py @@ -0,0 +1,45 @@ +# Copyright: 2013 Paul Traylor +# These sources are released under the terms of the MIT license: see LICENSE + +""" +Python2.5 and Python3.3 compatibility shim + +Heavily inspirted by the "six" library. +https://pypi.python.org/pypi/six +""" + +import sys + +PY3 = sys.version_info[0] == 3 + +if PY3: + def b(s): + if isinstance(s, bytes): + return s + return s.encode('utf8', 'replace') + + def u(s): + if isinstance(s, bytes): + return s.decode('utf8', 'replace') + return s + + from io import BytesIO as StringIO + from configparser import RawConfigParser +else: + def b(s): + if isinstance(s, unicode): + return s.encode('utf8', 'replace') + return s + + def u(s): + if isinstance(s, unicode): + return s + if isinstance(s, int): + s = str(s) + return unicode(s, "utf8", "replace") + + from StringIO import StringIO + from ConfigParser import RawConfigParser + +b.__doc__ = "Ensure we have a byte string" +u.__doc__ = "Ensure we have a unicode string" diff --git a/lib/gntp/version.py b/lib/gntp/version.py new file mode 100644 index 00000000..2166aaca --- /dev/null +++ b/lib/gntp/version.py @@ -0,0 +1,4 @@ +# Copyright: 2013 Paul Traylor +# These sources are released under the terms of the MIT license: see LICENSE + +__version__ = '1.0.2' diff --git a/lib/html5lib/__init__.py b/lib/html5lib/__init__.py new file mode 100644 index 00000000..3ba1163c --- /dev/null +++ b/lib/html5lib/__init__.py @@ -0,0 +1,25 @@ +""" +HTML parsing library based on the WHATWG "HTML5" +specification. The parser is designed to be compatible with existing +HTML found in the wild and implements well-defined error recovery that +is largely compatible with modern desktop web browsers. + +Example usage: + +import html5lib +f = open("my_document.html") +tree = html5lib.parse(f) +""" + +from __future__ import absolute_import, division, unicode_literals + +from .html5parser import HTMLParser, parse, parseFragment +from .treebuilders import getTreeBuilder +from .treewalkers import getTreeWalker +from .serializer import serialize + +__all__ = ["HTMLParser", "parse", "parseFragment", "getTreeBuilder", + "getTreeWalker", "serialize"] + +# this has to be at the top level, see how setup.py parses this +__version__ = "0.999999" diff --git a/lib/html5lib/constants.py b/lib/html5lib/constants.py new file mode 100644 index 00000000..d938e0ae --- /dev/null +++ b/lib/html5lib/constants.py @@ -0,0 +1,3102 @@ +from __future__ import absolute_import, division, unicode_literals + +import string + +EOF = None + +E = { + "null-character": + "Null character in input stream, replaced with U+FFFD.", + "invalid-codepoint": + "Invalid codepoint in stream.", + "incorrectly-placed-solidus": + "Solidus (/) incorrectly placed in tag.", + "incorrect-cr-newline-entity": + "Incorrect CR newline entity, replaced with LF.", + "illegal-windows-1252-entity": + "Entity used with illegal number (windows-1252 reference).", + "cant-convert-numeric-entity": + "Numeric entity couldn't be converted to character " + "(codepoint U+%(charAsInt)08x).", + "illegal-codepoint-for-numeric-entity": + "Numeric entity represents an illegal codepoint: " + "U+%(charAsInt)08x.", + "numeric-entity-without-semicolon": + "Numeric entity didn't end with ';'.", + "expected-numeric-entity-but-got-eof": + "Numeric entity expected. Got end of file instead.", + "expected-numeric-entity": + "Numeric entity expected but none found.", + "named-entity-without-semicolon": + "Named entity didn't end with ';'.", + "expected-named-entity": + "Named entity expected. Got none.", + "attributes-in-end-tag": + "End tag contains unexpected attributes.", + 'self-closing-flag-on-end-tag': + "End tag contains unexpected self-closing flag.", + "expected-tag-name-but-got-right-bracket": + "Expected tag name. Got '>' instead.", + "expected-tag-name-but-got-question-mark": + "Expected tag name. Got '?' instead. (HTML doesn't " + "support processing instructions.)", + "expected-tag-name": + "Expected tag name. Got something else instead", + "expected-closing-tag-but-got-right-bracket": + "Expected closing tag. Got '>' instead. Ignoring ''.", + "expected-closing-tag-but-got-eof": + "Expected closing tag. Unexpected end of file.", + "expected-closing-tag-but-got-char": + "Expected closing tag. Unexpected character '%(data)s' found.", + "eof-in-tag-name": + "Unexpected end of file in the tag name.", + "expected-attribute-name-but-got-eof": + "Unexpected end of file. Expected attribute name instead.", + "eof-in-attribute-name": + "Unexpected end of file in attribute name.", + "invalid-character-in-attribute-name": + "Invalid character in attribute name", + "duplicate-attribute": + "Dropped duplicate attribute on tag.", + "expected-end-of-tag-name-but-got-eof": + "Unexpected end of file. Expected = or end of tag.", + "expected-attribute-value-but-got-eof": + "Unexpected end of file. Expected attribute value.", + "expected-attribute-value-but-got-right-bracket": + "Expected attribute value. Got '>' instead.", + 'equals-in-unquoted-attribute-value': + "Unexpected = in unquoted attribute", + 'unexpected-character-in-unquoted-attribute-value': + "Unexpected character in unquoted attribute", + "invalid-character-after-attribute-name": + "Unexpected character after attribute name.", + "unexpected-character-after-attribute-value": + "Unexpected character after attribute value.", + "eof-in-attribute-value-double-quote": + "Unexpected end of file in attribute value (\").", + "eof-in-attribute-value-single-quote": + "Unexpected end of file in attribute value (').", + "eof-in-attribute-value-no-quotes": + "Unexpected end of file in attribute value.", + "unexpected-EOF-after-solidus-in-tag": + "Unexpected end of file in tag. Expected >", + "unexpected-character-after-solidus-in-tag": + "Unexpected character after / in tag. Expected >", + "expected-dashes-or-doctype": + "Expected '--' or 'DOCTYPE'. Not found.", + "unexpected-bang-after-double-dash-in-comment": + "Unexpected ! after -- in comment", + "unexpected-space-after-double-dash-in-comment": + "Unexpected space after -- in comment", + "incorrect-comment": + "Incorrect comment.", + "eof-in-comment": + "Unexpected end of file in comment.", + "eof-in-comment-end-dash": + "Unexpected end of file in comment (-)", + "unexpected-dash-after-double-dash-in-comment": + "Unexpected '-' after '--' found in comment.", + "eof-in-comment-double-dash": + "Unexpected end of file in comment (--).", + "eof-in-comment-end-space-state": + "Unexpected end of file in comment.", + "eof-in-comment-end-bang-state": + "Unexpected end of file in comment.", + "unexpected-char-in-comment": + "Unexpected character in comment found.", + "need-space-after-doctype": + "No space after literal string 'DOCTYPE'.", + "expected-doctype-name-but-got-right-bracket": + "Unexpected > character. Expected DOCTYPE name.", + "expected-doctype-name-but-got-eof": + "Unexpected end of file. Expected DOCTYPE name.", + "eof-in-doctype-name": + "Unexpected end of file in DOCTYPE name.", + "eof-in-doctype": + "Unexpected end of file in DOCTYPE.", + "expected-space-or-right-bracket-in-doctype": + "Expected space or '>'. Got '%(data)s'", + "unexpected-end-of-doctype": + "Unexpected end of DOCTYPE.", + "unexpected-char-in-doctype": + "Unexpected character in DOCTYPE.", + "eof-in-innerhtml": + "XXX innerHTML EOF", + "unexpected-doctype": + "Unexpected DOCTYPE. Ignored.", + "non-html-root": + "html needs to be the first start tag.", + "expected-doctype-but-got-eof": + "Unexpected End of file. Expected DOCTYPE.", + "unknown-doctype": + "Erroneous DOCTYPE.", + "expected-doctype-but-got-chars": + "Unexpected non-space characters. Expected DOCTYPE.", + "expected-doctype-but-got-start-tag": + "Unexpected start tag (%(name)s). Expected DOCTYPE.", + "expected-doctype-but-got-end-tag": + "Unexpected end tag (%(name)s). Expected DOCTYPE.", + "end-tag-after-implied-root": + "Unexpected end tag (%(name)s) after the (implied) root element.", + "expected-named-closing-tag-but-got-eof": + "Unexpected end of file. Expected end tag (%(name)s).", + "two-heads-are-not-better-than-one": + "Unexpected start tag head in existing head. Ignored.", + "unexpected-end-tag": + "Unexpected end tag (%(name)s). Ignored.", + "unexpected-start-tag-out-of-my-head": + "Unexpected start tag (%(name)s) that can be in head. Moved.", + "unexpected-start-tag": + "Unexpected start tag (%(name)s).", + "missing-end-tag": + "Missing end tag (%(name)s).", + "missing-end-tags": + "Missing end tags (%(name)s).", + "unexpected-start-tag-implies-end-tag": + "Unexpected start tag (%(startName)s) " + "implies end tag (%(endName)s).", + "unexpected-start-tag-treated-as": + "Unexpected start tag (%(originalName)s). Treated as %(newName)s.", + "deprecated-tag": + "Unexpected start tag %(name)s. Don't use it!", + "unexpected-start-tag-ignored": + "Unexpected start tag %(name)s. Ignored.", + "expected-one-end-tag-but-got-another": + "Unexpected end tag (%(gotName)s). " + "Missing end tag (%(expectedName)s).", + "end-tag-too-early": + "End tag (%(name)s) seen too early. Expected other end tag.", + "end-tag-too-early-named": + "Unexpected end tag (%(gotName)s). Expected end tag (%(expectedName)s).", + "end-tag-too-early-ignored": + "End tag (%(name)s) seen too early. Ignored.", + "adoption-agency-1.1": + "End tag (%(name)s) violates step 1, " + "paragraph 1 of the adoption agency algorithm.", + "adoption-agency-1.2": + "End tag (%(name)s) violates step 1, " + "paragraph 2 of the adoption agency algorithm.", + "adoption-agency-1.3": + "End tag (%(name)s) violates step 1, " + "paragraph 3 of the adoption agency algorithm.", + "adoption-agency-4.4": + "End tag (%(name)s) violates step 4, " + "paragraph 4 of the adoption agency algorithm.", + "unexpected-end-tag-treated-as": + "Unexpected end tag (%(originalName)s). Treated as %(newName)s.", + "no-end-tag": + "This element (%(name)s) has no end tag.", + "unexpected-implied-end-tag-in-table": + "Unexpected implied end tag (%(name)s) in the table phase.", + "unexpected-implied-end-tag-in-table-body": + "Unexpected implied end tag (%(name)s) in the table body phase.", + "unexpected-char-implies-table-voodoo": + "Unexpected non-space characters in " + "table context caused voodoo mode.", + "unexpected-hidden-input-in-table": + "Unexpected input with type hidden in table context.", + "unexpected-form-in-table": + "Unexpected form in table context.", + "unexpected-start-tag-implies-table-voodoo": + "Unexpected start tag (%(name)s) in " + "table context caused voodoo mode.", + "unexpected-end-tag-implies-table-voodoo": + "Unexpected end tag (%(name)s) in " + "table context caused voodoo mode.", + "unexpected-cell-in-table-body": + "Unexpected table cell start tag (%(name)s) " + "in the table body phase.", + "unexpected-cell-end-tag": + "Got table cell end tag (%(name)s) " + "while required end tags are missing.", + "unexpected-end-tag-in-table-body": + "Unexpected end tag (%(name)s) in the table body phase. Ignored.", + "unexpected-implied-end-tag-in-table-row": + "Unexpected implied end tag (%(name)s) in the table row phase.", + "unexpected-end-tag-in-table-row": + "Unexpected end tag (%(name)s) in the table row phase. Ignored.", + "unexpected-select-in-select": + "Unexpected select start tag in the select phase " + "treated as select end tag.", + "unexpected-input-in-select": + "Unexpected input start tag in the select phase.", + "unexpected-start-tag-in-select": + "Unexpected start tag token (%(name)s in the select phase. " + "Ignored.", + "unexpected-end-tag-in-select": + "Unexpected end tag (%(name)s) in the select phase. Ignored.", + "unexpected-table-element-start-tag-in-select-in-table": + "Unexpected table element start tag (%(name)s) in the select in table phase.", + "unexpected-table-element-end-tag-in-select-in-table": + "Unexpected table element end tag (%(name)s) in the select in table phase.", + "unexpected-char-after-body": + "Unexpected non-space characters in the after body phase.", + "unexpected-start-tag-after-body": + "Unexpected start tag token (%(name)s)" + " in the after body phase.", + "unexpected-end-tag-after-body": + "Unexpected end tag token (%(name)s)" + " in the after body phase.", + "unexpected-char-in-frameset": + "Unexpected characters in the frameset phase. Characters ignored.", + "unexpected-start-tag-in-frameset": + "Unexpected start tag token (%(name)s)" + " in the frameset phase. Ignored.", + "unexpected-frameset-in-frameset-innerhtml": + "Unexpected end tag token (frameset) " + "in the frameset phase (innerHTML).", + "unexpected-end-tag-in-frameset": + "Unexpected end tag token (%(name)s)" + " in the frameset phase. Ignored.", + "unexpected-char-after-frameset": + "Unexpected non-space characters in the " + "after frameset phase. Ignored.", + "unexpected-start-tag-after-frameset": + "Unexpected start tag (%(name)s)" + " in the after frameset phase. Ignored.", + "unexpected-end-tag-after-frameset": + "Unexpected end tag (%(name)s)" + " in the after frameset phase. Ignored.", + "unexpected-end-tag-after-body-innerhtml": + "Unexpected end tag after body(innerHtml)", + "expected-eof-but-got-char": + "Unexpected non-space characters. Expected end of file.", + "expected-eof-but-got-start-tag": + "Unexpected start tag (%(name)s)" + ". Expected end of file.", + "expected-eof-but-got-end-tag": + "Unexpected end tag (%(name)s)" + ". Expected end of file.", + "eof-in-table": + "Unexpected end of file. Expected table content.", + "eof-in-select": + "Unexpected end of file. Expected select content.", + "eof-in-frameset": + "Unexpected end of file. Expected frameset content.", + "eof-in-script-in-script": + "Unexpected end of file. Expected script content.", + "eof-in-foreign-lands": + "Unexpected end of file. Expected foreign content", + "non-void-element-with-trailing-solidus": + "Trailing solidus not allowed on element %(name)s", + "unexpected-html-element-in-foreign-content": + "Element %(name)s not allowed in a non-html context", + "unexpected-end-tag-before-html": + "Unexpected end tag (%(name)s) before html.", + "XXX-undefined-error": + "Undefined error (this sucks and should be fixed)", +} + +namespaces = { + "html": "http://www.w3.org/1999/xhtml", + "mathml": "http://www.w3.org/1998/Math/MathML", + "svg": "http://www.w3.org/2000/svg", + "xlink": "http://www.w3.org/1999/xlink", + "xml": "http://www.w3.org/XML/1998/namespace", + "xmlns": "http://www.w3.org/2000/xmlns/" +} + +scopingElements = frozenset([ + (namespaces["html"], "applet"), + (namespaces["html"], "caption"), + (namespaces["html"], "html"), + (namespaces["html"], "marquee"), + (namespaces["html"], "object"), + (namespaces["html"], "table"), + (namespaces["html"], "td"), + (namespaces["html"], "th"), + (namespaces["mathml"], "mi"), + (namespaces["mathml"], "mo"), + (namespaces["mathml"], "mn"), + (namespaces["mathml"], "ms"), + (namespaces["mathml"], "mtext"), + (namespaces["mathml"], "annotation-xml"), + (namespaces["svg"], "foreignObject"), + (namespaces["svg"], "desc"), + (namespaces["svg"], "title"), +]) + +formattingElements = frozenset([ + (namespaces["html"], "a"), + (namespaces["html"], "b"), + (namespaces["html"], "big"), + (namespaces["html"], "code"), + (namespaces["html"], "em"), + (namespaces["html"], "font"), + (namespaces["html"], "i"), + (namespaces["html"], "nobr"), + (namespaces["html"], "s"), + (namespaces["html"], "small"), + (namespaces["html"], "strike"), + (namespaces["html"], "strong"), + (namespaces["html"], "tt"), + (namespaces["html"], "u") +]) + +specialElements = frozenset([ + (namespaces["html"], "address"), + (namespaces["html"], "applet"), + (namespaces["html"], "area"), + (namespaces["html"], "article"), + (namespaces["html"], "aside"), + (namespaces["html"], "base"), + (namespaces["html"], "basefont"), + (namespaces["html"], "bgsound"), + (namespaces["html"], "blockquote"), + (namespaces["html"], "body"), + (namespaces["html"], "br"), + (namespaces["html"], "button"), + (namespaces["html"], "caption"), + (namespaces["html"], "center"), + (namespaces["html"], "col"), + (namespaces["html"], "colgroup"), + (namespaces["html"], "command"), + (namespaces["html"], "dd"), + (namespaces["html"], "details"), + (namespaces["html"], "dir"), + (namespaces["html"], "div"), + (namespaces["html"], "dl"), + (namespaces["html"], "dt"), + (namespaces["html"], "embed"), + (namespaces["html"], "fieldset"), + (namespaces["html"], "figure"), + (namespaces["html"], "footer"), + (namespaces["html"], "form"), + (namespaces["html"], "frame"), + (namespaces["html"], "frameset"), + (namespaces["html"], "h1"), + (namespaces["html"], "h2"), + (namespaces["html"], "h3"), + (namespaces["html"], "h4"), + (namespaces["html"], "h5"), + (namespaces["html"], "h6"), + (namespaces["html"], "head"), + (namespaces["html"], "header"), + (namespaces["html"], "hr"), + (namespaces["html"], "html"), + (namespaces["html"], "iframe"), + # Note that image is commented out in the spec as "this isn't an + # element that can end up on the stack, so it doesn't matter," + (namespaces["html"], "image"), + (namespaces["html"], "img"), + (namespaces["html"], "input"), + (namespaces["html"], "isindex"), + (namespaces["html"], "li"), + (namespaces["html"], "link"), + (namespaces["html"], "listing"), + (namespaces["html"], "marquee"), + (namespaces["html"], "menu"), + (namespaces["html"], "meta"), + (namespaces["html"], "nav"), + (namespaces["html"], "noembed"), + (namespaces["html"], "noframes"), + (namespaces["html"], "noscript"), + (namespaces["html"], "object"), + (namespaces["html"], "ol"), + (namespaces["html"], "p"), + (namespaces["html"], "param"), + (namespaces["html"], "plaintext"), + (namespaces["html"], "pre"), + (namespaces["html"], "script"), + (namespaces["html"], "section"), + (namespaces["html"], "select"), + (namespaces["html"], "style"), + (namespaces["html"], "table"), + (namespaces["html"], "tbody"), + (namespaces["html"], "td"), + (namespaces["html"], "textarea"), + (namespaces["html"], "tfoot"), + (namespaces["html"], "th"), + (namespaces["html"], "thead"), + (namespaces["html"], "title"), + (namespaces["html"], "tr"), + (namespaces["html"], "ul"), + (namespaces["html"], "wbr"), + (namespaces["html"], "xmp"), + (namespaces["svg"], "foreignObject") +]) + +htmlIntegrationPointElements = frozenset([ + (namespaces["mathml"], "annotaion-xml"), + (namespaces["svg"], "foreignObject"), + (namespaces["svg"], "desc"), + (namespaces["svg"], "title") +]) + +mathmlTextIntegrationPointElements = frozenset([ + (namespaces["mathml"], "mi"), + (namespaces["mathml"], "mo"), + (namespaces["mathml"], "mn"), + (namespaces["mathml"], "ms"), + (namespaces["mathml"], "mtext") +]) + +adjustForeignAttributes = { + "xlink:actuate": ("xlink", "actuate", namespaces["xlink"]), + "xlink:arcrole": ("xlink", "arcrole", namespaces["xlink"]), + "xlink:href": ("xlink", "href", namespaces["xlink"]), + "xlink:role": ("xlink", "role", namespaces["xlink"]), + "xlink:show": ("xlink", "show", namespaces["xlink"]), + "xlink:title": ("xlink", "title", namespaces["xlink"]), + "xlink:type": ("xlink", "type", namespaces["xlink"]), + "xml:base": ("xml", "base", namespaces["xml"]), + "xml:lang": ("xml", "lang", namespaces["xml"]), + "xml:space": ("xml", "space", namespaces["xml"]), + "xmlns": (None, "xmlns", namespaces["xmlns"]), + "xmlns:xlink": ("xmlns", "xlink", namespaces["xmlns"]) +} + +unadjustForeignAttributes = dict([((ns, local), qname) for qname, (prefix, local, ns) in + adjustForeignAttributes.items()]) + +spaceCharacters = frozenset([ + "\t", + "\n", + "\u000C", + " ", + "\r" +]) + +tableInsertModeElements = frozenset([ + "table", + "tbody", + "tfoot", + "thead", + "tr" +]) + +asciiLowercase = frozenset(string.ascii_lowercase) +asciiUppercase = frozenset(string.ascii_uppercase) +asciiLetters = frozenset(string.ascii_letters) +digits = frozenset(string.digits) +hexDigits = frozenset(string.hexdigits) + +asciiUpper2Lower = dict([(ord(c), ord(c.lower())) + for c in string.ascii_uppercase]) + +# Heading elements need to be ordered +headingElements = ( + "h1", + "h2", + "h3", + "h4", + "h5", + "h6" +) + +voidElements = frozenset([ + "base", + "command", + "event-source", + "link", + "meta", + "hr", + "br", + "img", + "embed", + "param", + "area", + "col", + "input", + "source", + "track" +]) + +cdataElements = frozenset(['title', 'textarea']) + +rcdataElements = frozenset([ + 'style', + 'script', + 'xmp', + 'iframe', + 'noembed', + 'noframes', + 'noscript' +]) + +booleanAttributes = { + "": frozenset(["irrelevant"]), + "style": frozenset(["scoped"]), + "img": frozenset(["ismap"]), + "audio": frozenset(["autoplay", "controls"]), + "video": frozenset(["autoplay", "controls"]), + "script": frozenset(["defer", "async"]), + "details": frozenset(["open"]), + "datagrid": frozenset(["multiple", "disabled"]), + "command": frozenset(["hidden", "disabled", "checked", "default"]), + "hr": frozenset(["noshade"]), + "menu": frozenset(["autosubmit"]), + "fieldset": frozenset(["disabled", "readonly"]), + "option": frozenset(["disabled", "readonly", "selected"]), + "optgroup": frozenset(["disabled", "readonly"]), + "button": frozenset(["disabled", "autofocus"]), + "input": frozenset(["disabled", "readonly", "required", "autofocus", "checked", "ismap"]), + "select": frozenset(["disabled", "readonly", "autofocus", "multiple"]), + "output": frozenset(["disabled", "readonly"]), +} + +# entitiesWindows1252 has to be _ordered_ and needs to have an index. It +# therefore can't be a frozenset. +entitiesWindows1252 = ( + 8364, # 0x80 0x20AC EURO SIGN + 65533, # 0x81 UNDEFINED + 8218, # 0x82 0x201A SINGLE LOW-9 QUOTATION MARK + 402, # 0x83 0x0192 LATIN SMALL LETTER F WITH HOOK + 8222, # 0x84 0x201E DOUBLE LOW-9 QUOTATION MARK + 8230, # 0x85 0x2026 HORIZONTAL ELLIPSIS + 8224, # 0x86 0x2020 DAGGER + 8225, # 0x87 0x2021 DOUBLE DAGGER + 710, # 0x88 0x02C6 MODIFIER LETTER CIRCUMFLEX ACCENT + 8240, # 0x89 0x2030 PER MILLE SIGN + 352, # 0x8A 0x0160 LATIN CAPITAL LETTER S WITH CARON + 8249, # 0x8B 0x2039 SINGLE LEFT-POINTING ANGLE QUOTATION MARK + 338, # 0x8C 0x0152 LATIN CAPITAL LIGATURE OE + 65533, # 0x8D UNDEFINED + 381, # 0x8E 0x017D LATIN CAPITAL LETTER Z WITH CARON + 65533, # 0x8F UNDEFINED + 65533, # 0x90 UNDEFINED + 8216, # 0x91 0x2018 LEFT SINGLE QUOTATION MARK + 8217, # 0x92 0x2019 RIGHT SINGLE QUOTATION MARK + 8220, # 0x93 0x201C LEFT DOUBLE QUOTATION MARK + 8221, # 0x94 0x201D RIGHT DOUBLE QUOTATION MARK + 8226, # 0x95 0x2022 BULLET + 8211, # 0x96 0x2013 EN DASH + 8212, # 0x97 0x2014 EM DASH + 732, # 0x98 0x02DC SMALL TILDE + 8482, # 0x99 0x2122 TRADE MARK SIGN + 353, # 0x9A 0x0161 LATIN SMALL LETTER S WITH CARON + 8250, # 0x9B 0x203A SINGLE RIGHT-POINTING ANGLE QUOTATION MARK + 339, # 0x9C 0x0153 LATIN SMALL LIGATURE OE + 65533, # 0x9D UNDEFINED + 382, # 0x9E 0x017E LATIN SMALL LETTER Z WITH CARON + 376 # 0x9F 0x0178 LATIN CAPITAL LETTER Y WITH DIAERESIS +) + +xmlEntities = frozenset(['lt;', 'gt;', 'amp;', 'apos;', 'quot;']) + +entities = { + "AElig": "\xc6", + "AElig;": "\xc6", + "AMP": "&", + "AMP;": "&", + "Aacute": "\xc1", + "Aacute;": "\xc1", + "Abreve;": "\u0102", + "Acirc": "\xc2", + "Acirc;": "\xc2", + "Acy;": "\u0410", + "Afr;": "\U0001d504", + "Agrave": "\xc0", + "Agrave;": "\xc0", + "Alpha;": "\u0391", + "Amacr;": "\u0100", + "And;": "\u2a53", + "Aogon;": "\u0104", + "Aopf;": "\U0001d538", + "ApplyFunction;": "\u2061", + "Aring": "\xc5", + "Aring;": "\xc5", + "Ascr;": "\U0001d49c", + "Assign;": "\u2254", + "Atilde": "\xc3", + "Atilde;": "\xc3", + "Auml": "\xc4", + "Auml;": "\xc4", + "Backslash;": "\u2216", + "Barv;": "\u2ae7", + "Barwed;": "\u2306", + "Bcy;": "\u0411", + "Because;": "\u2235", + "Bernoullis;": "\u212c", + "Beta;": "\u0392", + "Bfr;": "\U0001d505", + "Bopf;": "\U0001d539", + "Breve;": "\u02d8", + "Bscr;": "\u212c", + "Bumpeq;": "\u224e", + "CHcy;": "\u0427", + "COPY": "\xa9", + "COPY;": "\xa9", + "Cacute;": "\u0106", + "Cap;": "\u22d2", + "CapitalDifferentialD;": "\u2145", + "Cayleys;": "\u212d", + "Ccaron;": "\u010c", + "Ccedil": "\xc7", + "Ccedil;": "\xc7", + "Ccirc;": "\u0108", + "Cconint;": "\u2230", + "Cdot;": "\u010a", + "Cedilla;": "\xb8", + "CenterDot;": "\xb7", + "Cfr;": "\u212d", + "Chi;": "\u03a7", + "CircleDot;": "\u2299", + "CircleMinus;": "\u2296", + "CirclePlus;": "\u2295", + "CircleTimes;": "\u2297", + "ClockwiseContourIntegral;": "\u2232", + "CloseCurlyDoubleQuote;": "\u201d", + "CloseCurlyQuote;": "\u2019", + "Colon;": "\u2237", + "Colone;": "\u2a74", + "Congruent;": "\u2261", + "Conint;": "\u222f", + "ContourIntegral;": "\u222e", + "Copf;": "\u2102", + "Coproduct;": "\u2210", + "CounterClockwiseContourIntegral;": "\u2233", + "Cross;": "\u2a2f", + "Cscr;": "\U0001d49e", + "Cup;": "\u22d3", + "CupCap;": "\u224d", + "DD;": "\u2145", + "DDotrahd;": "\u2911", + "DJcy;": "\u0402", + "DScy;": "\u0405", + "DZcy;": "\u040f", + "Dagger;": "\u2021", + "Darr;": "\u21a1", + "Dashv;": "\u2ae4", + "Dcaron;": "\u010e", + "Dcy;": "\u0414", + "Del;": "\u2207", + "Delta;": "\u0394", + "Dfr;": "\U0001d507", + "DiacriticalAcute;": "\xb4", + "DiacriticalDot;": "\u02d9", + "DiacriticalDoubleAcute;": "\u02dd", + "DiacriticalGrave;": "`", + "DiacriticalTilde;": "\u02dc", + "Diamond;": "\u22c4", + "DifferentialD;": "\u2146", + "Dopf;": "\U0001d53b", + "Dot;": "\xa8", + "DotDot;": "\u20dc", + "DotEqual;": "\u2250", + "DoubleContourIntegral;": "\u222f", + "DoubleDot;": "\xa8", + "DoubleDownArrow;": "\u21d3", + "DoubleLeftArrow;": "\u21d0", + "DoubleLeftRightArrow;": "\u21d4", + "DoubleLeftTee;": "\u2ae4", + "DoubleLongLeftArrow;": "\u27f8", + "DoubleLongLeftRightArrow;": "\u27fa", + "DoubleLongRightArrow;": "\u27f9", + "DoubleRightArrow;": "\u21d2", + "DoubleRightTee;": "\u22a8", + "DoubleUpArrow;": "\u21d1", + "DoubleUpDownArrow;": "\u21d5", + "DoubleVerticalBar;": "\u2225", + "DownArrow;": "\u2193", + "DownArrowBar;": "\u2913", + "DownArrowUpArrow;": "\u21f5", + "DownBreve;": "\u0311", + "DownLeftRightVector;": "\u2950", + "DownLeftTeeVector;": "\u295e", + "DownLeftVector;": "\u21bd", + "DownLeftVectorBar;": "\u2956", + "DownRightTeeVector;": "\u295f", + "DownRightVector;": "\u21c1", + "DownRightVectorBar;": "\u2957", + "DownTee;": "\u22a4", + "DownTeeArrow;": "\u21a7", + "Downarrow;": "\u21d3", + "Dscr;": "\U0001d49f", + "Dstrok;": "\u0110", + "ENG;": "\u014a", + "ETH": "\xd0", + "ETH;": "\xd0", + "Eacute": "\xc9", + "Eacute;": "\xc9", + "Ecaron;": "\u011a", + "Ecirc": "\xca", + "Ecirc;": "\xca", + "Ecy;": "\u042d", + "Edot;": "\u0116", + "Efr;": "\U0001d508", + "Egrave": "\xc8", + "Egrave;": "\xc8", + "Element;": "\u2208", + "Emacr;": "\u0112", + "EmptySmallSquare;": "\u25fb", + "EmptyVerySmallSquare;": "\u25ab", + "Eogon;": "\u0118", + "Eopf;": "\U0001d53c", + "Epsilon;": "\u0395", + "Equal;": "\u2a75", + "EqualTilde;": "\u2242", + "Equilibrium;": "\u21cc", + "Escr;": "\u2130", + "Esim;": "\u2a73", + "Eta;": "\u0397", + "Euml": "\xcb", + "Euml;": "\xcb", + "Exists;": "\u2203", + "ExponentialE;": "\u2147", + "Fcy;": "\u0424", + "Ffr;": "\U0001d509", + "FilledSmallSquare;": "\u25fc", + "FilledVerySmallSquare;": "\u25aa", + "Fopf;": "\U0001d53d", + "ForAll;": "\u2200", + "Fouriertrf;": "\u2131", + "Fscr;": "\u2131", + "GJcy;": "\u0403", + "GT": ">", + "GT;": ">", + "Gamma;": "\u0393", + "Gammad;": "\u03dc", + "Gbreve;": "\u011e", + "Gcedil;": "\u0122", + "Gcirc;": "\u011c", + "Gcy;": "\u0413", + "Gdot;": "\u0120", + "Gfr;": "\U0001d50a", + "Gg;": "\u22d9", + "Gopf;": "\U0001d53e", + "GreaterEqual;": "\u2265", + "GreaterEqualLess;": "\u22db", + "GreaterFullEqual;": "\u2267", + "GreaterGreater;": "\u2aa2", + "GreaterLess;": "\u2277", + "GreaterSlantEqual;": "\u2a7e", + "GreaterTilde;": "\u2273", + "Gscr;": "\U0001d4a2", + "Gt;": "\u226b", + "HARDcy;": "\u042a", + "Hacek;": "\u02c7", + "Hat;": "^", + "Hcirc;": "\u0124", + "Hfr;": "\u210c", + "HilbertSpace;": "\u210b", + "Hopf;": "\u210d", + "HorizontalLine;": "\u2500", + "Hscr;": "\u210b", + "Hstrok;": "\u0126", + "HumpDownHump;": "\u224e", + "HumpEqual;": "\u224f", + "IEcy;": "\u0415", + "IJlig;": "\u0132", + "IOcy;": "\u0401", + "Iacute": "\xcd", + "Iacute;": "\xcd", + "Icirc": "\xce", + "Icirc;": "\xce", + "Icy;": "\u0418", + "Idot;": "\u0130", + "Ifr;": "\u2111", + "Igrave": "\xcc", + "Igrave;": "\xcc", + "Im;": "\u2111", + "Imacr;": "\u012a", + "ImaginaryI;": "\u2148", + "Implies;": "\u21d2", + "Int;": "\u222c", + "Integral;": "\u222b", + "Intersection;": "\u22c2", + "InvisibleComma;": "\u2063", + "InvisibleTimes;": "\u2062", + "Iogon;": "\u012e", + "Iopf;": "\U0001d540", + "Iota;": "\u0399", + "Iscr;": "\u2110", + "Itilde;": "\u0128", + "Iukcy;": "\u0406", + "Iuml": "\xcf", + "Iuml;": "\xcf", + "Jcirc;": "\u0134", + "Jcy;": "\u0419", + "Jfr;": "\U0001d50d", + "Jopf;": "\U0001d541", + "Jscr;": "\U0001d4a5", + "Jsercy;": "\u0408", + "Jukcy;": "\u0404", + "KHcy;": "\u0425", + "KJcy;": "\u040c", + "Kappa;": "\u039a", + "Kcedil;": "\u0136", + "Kcy;": "\u041a", + "Kfr;": "\U0001d50e", + "Kopf;": "\U0001d542", + "Kscr;": "\U0001d4a6", + "LJcy;": "\u0409", + "LT": "<", + "LT;": "<", + "Lacute;": "\u0139", + "Lambda;": "\u039b", + "Lang;": "\u27ea", + "Laplacetrf;": "\u2112", + "Larr;": "\u219e", + "Lcaron;": "\u013d", + "Lcedil;": "\u013b", + "Lcy;": "\u041b", + "LeftAngleBracket;": "\u27e8", + "LeftArrow;": "\u2190", + "LeftArrowBar;": "\u21e4", + "LeftArrowRightArrow;": "\u21c6", + "LeftCeiling;": "\u2308", + "LeftDoubleBracket;": "\u27e6", + "LeftDownTeeVector;": "\u2961", + "LeftDownVector;": "\u21c3", + "LeftDownVectorBar;": "\u2959", + "LeftFloor;": "\u230a", + "LeftRightArrow;": "\u2194", + "LeftRightVector;": "\u294e", + "LeftTee;": "\u22a3", + "LeftTeeArrow;": "\u21a4", + "LeftTeeVector;": "\u295a", + "LeftTriangle;": "\u22b2", + "LeftTriangleBar;": "\u29cf", + "LeftTriangleEqual;": "\u22b4", + "LeftUpDownVector;": "\u2951", + "LeftUpTeeVector;": "\u2960", + "LeftUpVector;": "\u21bf", + "LeftUpVectorBar;": "\u2958", + "LeftVector;": "\u21bc", + "LeftVectorBar;": "\u2952", + "Leftarrow;": "\u21d0", + "Leftrightarrow;": "\u21d4", + "LessEqualGreater;": "\u22da", + "LessFullEqual;": "\u2266", + "LessGreater;": "\u2276", + "LessLess;": "\u2aa1", + "LessSlantEqual;": "\u2a7d", + "LessTilde;": "\u2272", + "Lfr;": "\U0001d50f", + "Ll;": "\u22d8", + "Lleftarrow;": "\u21da", + "Lmidot;": "\u013f", + "LongLeftArrow;": "\u27f5", + "LongLeftRightArrow;": "\u27f7", + "LongRightArrow;": "\u27f6", + "Longleftarrow;": "\u27f8", + "Longleftrightarrow;": "\u27fa", + "Longrightarrow;": "\u27f9", + "Lopf;": "\U0001d543", + "LowerLeftArrow;": "\u2199", + "LowerRightArrow;": "\u2198", + "Lscr;": "\u2112", + "Lsh;": "\u21b0", + "Lstrok;": "\u0141", + "Lt;": "\u226a", + "Map;": "\u2905", + "Mcy;": "\u041c", + "MediumSpace;": "\u205f", + "Mellintrf;": "\u2133", + "Mfr;": "\U0001d510", + "MinusPlus;": "\u2213", + "Mopf;": "\U0001d544", + "Mscr;": "\u2133", + "Mu;": "\u039c", + "NJcy;": "\u040a", + "Nacute;": "\u0143", + "Ncaron;": "\u0147", + "Ncedil;": "\u0145", + "Ncy;": "\u041d", + "NegativeMediumSpace;": "\u200b", + "NegativeThickSpace;": "\u200b", + "NegativeThinSpace;": "\u200b", + "NegativeVeryThinSpace;": "\u200b", + "NestedGreaterGreater;": "\u226b", + "NestedLessLess;": "\u226a", + "NewLine;": "\n", + "Nfr;": "\U0001d511", + "NoBreak;": "\u2060", + "NonBreakingSpace;": "\xa0", + "Nopf;": "\u2115", + "Not;": "\u2aec", + "NotCongruent;": "\u2262", + "NotCupCap;": "\u226d", + "NotDoubleVerticalBar;": "\u2226", + "NotElement;": "\u2209", + "NotEqual;": "\u2260", + "NotEqualTilde;": "\u2242\u0338", + "NotExists;": "\u2204", + "NotGreater;": "\u226f", + "NotGreaterEqual;": "\u2271", + "NotGreaterFullEqual;": "\u2267\u0338", + "NotGreaterGreater;": "\u226b\u0338", + "NotGreaterLess;": "\u2279", + "NotGreaterSlantEqual;": "\u2a7e\u0338", + "NotGreaterTilde;": "\u2275", + "NotHumpDownHump;": "\u224e\u0338", + "NotHumpEqual;": "\u224f\u0338", + "NotLeftTriangle;": "\u22ea", + "NotLeftTriangleBar;": "\u29cf\u0338", + "NotLeftTriangleEqual;": "\u22ec", + "NotLess;": "\u226e", + "NotLessEqual;": "\u2270", + "NotLessGreater;": "\u2278", + "NotLessLess;": "\u226a\u0338", + "NotLessSlantEqual;": "\u2a7d\u0338", + "NotLessTilde;": "\u2274", + "NotNestedGreaterGreater;": "\u2aa2\u0338", + "NotNestedLessLess;": "\u2aa1\u0338", + "NotPrecedes;": "\u2280", + "NotPrecedesEqual;": "\u2aaf\u0338", + "NotPrecedesSlantEqual;": "\u22e0", + "NotReverseElement;": "\u220c", + "NotRightTriangle;": "\u22eb", + "NotRightTriangleBar;": "\u29d0\u0338", + "NotRightTriangleEqual;": "\u22ed", + "NotSquareSubset;": "\u228f\u0338", + "NotSquareSubsetEqual;": "\u22e2", + "NotSquareSuperset;": "\u2290\u0338", + "NotSquareSupersetEqual;": "\u22e3", + "NotSubset;": "\u2282\u20d2", + "NotSubsetEqual;": "\u2288", + "NotSucceeds;": "\u2281", + "NotSucceedsEqual;": "\u2ab0\u0338", + "NotSucceedsSlantEqual;": "\u22e1", + "NotSucceedsTilde;": "\u227f\u0338", + "NotSuperset;": "\u2283\u20d2", + "NotSupersetEqual;": "\u2289", + "NotTilde;": "\u2241", + "NotTildeEqual;": "\u2244", + "NotTildeFullEqual;": "\u2247", + "NotTildeTilde;": "\u2249", + "NotVerticalBar;": "\u2224", + "Nscr;": "\U0001d4a9", + "Ntilde": "\xd1", + "Ntilde;": "\xd1", + "Nu;": "\u039d", + "OElig;": "\u0152", + "Oacute": "\xd3", + "Oacute;": "\xd3", + "Ocirc": "\xd4", + "Ocirc;": "\xd4", + "Ocy;": "\u041e", + "Odblac;": "\u0150", + "Ofr;": "\U0001d512", + "Ograve": "\xd2", + "Ograve;": "\xd2", + "Omacr;": "\u014c", + "Omega;": "\u03a9", + "Omicron;": "\u039f", + "Oopf;": "\U0001d546", + "OpenCurlyDoubleQuote;": "\u201c", + "OpenCurlyQuote;": "\u2018", + "Or;": "\u2a54", + "Oscr;": "\U0001d4aa", + "Oslash": "\xd8", + "Oslash;": "\xd8", + "Otilde": "\xd5", + "Otilde;": "\xd5", + "Otimes;": "\u2a37", + "Ouml": "\xd6", + "Ouml;": "\xd6", + "OverBar;": "\u203e", + "OverBrace;": "\u23de", + "OverBracket;": "\u23b4", + "OverParenthesis;": "\u23dc", + "PartialD;": "\u2202", + "Pcy;": "\u041f", + "Pfr;": "\U0001d513", + "Phi;": "\u03a6", + "Pi;": "\u03a0", + "PlusMinus;": "\xb1", + "Poincareplane;": "\u210c", + "Popf;": "\u2119", + "Pr;": "\u2abb", + "Precedes;": "\u227a", + "PrecedesEqual;": "\u2aaf", + "PrecedesSlantEqual;": "\u227c", + "PrecedesTilde;": "\u227e", + "Prime;": "\u2033", + "Product;": "\u220f", + "Proportion;": "\u2237", + "Proportional;": "\u221d", + "Pscr;": "\U0001d4ab", + "Psi;": "\u03a8", + "QUOT": "\"", + "QUOT;": "\"", + "Qfr;": "\U0001d514", + "Qopf;": "\u211a", + "Qscr;": "\U0001d4ac", + "RBarr;": "\u2910", + "REG": "\xae", + "REG;": "\xae", + "Racute;": "\u0154", + "Rang;": "\u27eb", + "Rarr;": "\u21a0", + "Rarrtl;": "\u2916", + "Rcaron;": "\u0158", + "Rcedil;": "\u0156", + "Rcy;": "\u0420", + "Re;": "\u211c", + "ReverseElement;": "\u220b", + "ReverseEquilibrium;": "\u21cb", + "ReverseUpEquilibrium;": "\u296f", + "Rfr;": "\u211c", + "Rho;": "\u03a1", + "RightAngleBracket;": "\u27e9", + "RightArrow;": "\u2192", + "RightArrowBar;": "\u21e5", + "RightArrowLeftArrow;": "\u21c4", + "RightCeiling;": "\u2309", + "RightDoubleBracket;": "\u27e7", + "RightDownTeeVector;": "\u295d", + "RightDownVector;": "\u21c2", + "RightDownVectorBar;": "\u2955", + "RightFloor;": "\u230b", + "RightTee;": "\u22a2", + "RightTeeArrow;": "\u21a6", + "RightTeeVector;": "\u295b", + "RightTriangle;": "\u22b3", + "RightTriangleBar;": "\u29d0", + "RightTriangleEqual;": "\u22b5", + "RightUpDownVector;": "\u294f", + "RightUpTeeVector;": "\u295c", + "RightUpVector;": "\u21be", + "RightUpVectorBar;": "\u2954", + "RightVector;": "\u21c0", + "RightVectorBar;": "\u2953", + "Rightarrow;": "\u21d2", + "Ropf;": "\u211d", + "RoundImplies;": "\u2970", + "Rrightarrow;": "\u21db", + "Rscr;": "\u211b", + "Rsh;": "\u21b1", + "RuleDelayed;": "\u29f4", + "SHCHcy;": "\u0429", + "SHcy;": "\u0428", + "SOFTcy;": "\u042c", + "Sacute;": "\u015a", + "Sc;": "\u2abc", + "Scaron;": "\u0160", + "Scedil;": "\u015e", + "Scirc;": "\u015c", + "Scy;": "\u0421", + "Sfr;": "\U0001d516", + "ShortDownArrow;": "\u2193", + "ShortLeftArrow;": "\u2190", + "ShortRightArrow;": "\u2192", + "ShortUpArrow;": "\u2191", + "Sigma;": "\u03a3", + "SmallCircle;": "\u2218", + "Sopf;": "\U0001d54a", + "Sqrt;": "\u221a", + "Square;": "\u25a1", + "SquareIntersection;": "\u2293", + "SquareSubset;": "\u228f", + "SquareSubsetEqual;": "\u2291", + "SquareSuperset;": "\u2290", + "SquareSupersetEqual;": "\u2292", + "SquareUnion;": "\u2294", + "Sscr;": "\U0001d4ae", + "Star;": "\u22c6", + "Sub;": "\u22d0", + "Subset;": "\u22d0", + "SubsetEqual;": "\u2286", + "Succeeds;": "\u227b", + "SucceedsEqual;": "\u2ab0", + "SucceedsSlantEqual;": "\u227d", + "SucceedsTilde;": "\u227f", + "SuchThat;": "\u220b", + "Sum;": "\u2211", + "Sup;": "\u22d1", + "Superset;": "\u2283", + "SupersetEqual;": "\u2287", + "Supset;": "\u22d1", + "THORN": "\xde", + "THORN;": "\xde", + "TRADE;": "\u2122", + "TSHcy;": "\u040b", + "TScy;": "\u0426", + "Tab;": "\t", + "Tau;": "\u03a4", + "Tcaron;": "\u0164", + "Tcedil;": "\u0162", + "Tcy;": "\u0422", + "Tfr;": "\U0001d517", + "Therefore;": "\u2234", + "Theta;": "\u0398", + "ThickSpace;": "\u205f\u200a", + "ThinSpace;": "\u2009", + "Tilde;": "\u223c", + "TildeEqual;": "\u2243", + "TildeFullEqual;": "\u2245", + "TildeTilde;": "\u2248", + "Topf;": "\U0001d54b", + "TripleDot;": "\u20db", + "Tscr;": "\U0001d4af", + "Tstrok;": "\u0166", + "Uacute": "\xda", + "Uacute;": "\xda", + "Uarr;": "\u219f", + "Uarrocir;": "\u2949", + "Ubrcy;": "\u040e", + "Ubreve;": "\u016c", + "Ucirc": "\xdb", + "Ucirc;": "\xdb", + "Ucy;": "\u0423", + "Udblac;": "\u0170", + "Ufr;": "\U0001d518", + "Ugrave": "\xd9", + "Ugrave;": "\xd9", + "Umacr;": "\u016a", + "UnderBar;": "_", + "UnderBrace;": "\u23df", + "UnderBracket;": "\u23b5", + "UnderParenthesis;": "\u23dd", + "Union;": "\u22c3", + "UnionPlus;": "\u228e", + "Uogon;": "\u0172", + "Uopf;": "\U0001d54c", + "UpArrow;": "\u2191", + "UpArrowBar;": "\u2912", + "UpArrowDownArrow;": "\u21c5", + "UpDownArrow;": "\u2195", + "UpEquilibrium;": "\u296e", + "UpTee;": "\u22a5", + "UpTeeArrow;": "\u21a5", + "Uparrow;": "\u21d1", + "Updownarrow;": "\u21d5", + "UpperLeftArrow;": "\u2196", + "UpperRightArrow;": "\u2197", + "Upsi;": "\u03d2", + "Upsilon;": "\u03a5", + "Uring;": "\u016e", + "Uscr;": "\U0001d4b0", + "Utilde;": "\u0168", + "Uuml": "\xdc", + "Uuml;": "\xdc", + "VDash;": "\u22ab", + "Vbar;": "\u2aeb", + "Vcy;": "\u0412", + "Vdash;": "\u22a9", + "Vdashl;": "\u2ae6", + "Vee;": "\u22c1", + "Verbar;": "\u2016", + "Vert;": "\u2016", + "VerticalBar;": "\u2223", + "VerticalLine;": "|", + "VerticalSeparator;": "\u2758", + "VerticalTilde;": "\u2240", + "VeryThinSpace;": "\u200a", + "Vfr;": "\U0001d519", + "Vopf;": "\U0001d54d", + "Vscr;": "\U0001d4b1", + "Vvdash;": "\u22aa", + "Wcirc;": "\u0174", + "Wedge;": "\u22c0", + "Wfr;": "\U0001d51a", + "Wopf;": "\U0001d54e", + "Wscr;": "\U0001d4b2", + "Xfr;": "\U0001d51b", + "Xi;": "\u039e", + "Xopf;": "\U0001d54f", + "Xscr;": "\U0001d4b3", + "YAcy;": "\u042f", + "YIcy;": "\u0407", + "YUcy;": "\u042e", + "Yacute": "\xdd", + "Yacute;": "\xdd", + "Ycirc;": "\u0176", + "Ycy;": "\u042b", + "Yfr;": "\U0001d51c", + "Yopf;": "\U0001d550", + "Yscr;": "\U0001d4b4", + "Yuml;": "\u0178", + "ZHcy;": "\u0416", + "Zacute;": "\u0179", + "Zcaron;": "\u017d", + "Zcy;": "\u0417", + "Zdot;": "\u017b", + "ZeroWidthSpace;": "\u200b", + "Zeta;": "\u0396", + "Zfr;": "\u2128", + "Zopf;": "\u2124", + "Zscr;": "\U0001d4b5", + "aacute": "\xe1", + "aacute;": "\xe1", + "abreve;": "\u0103", + "ac;": "\u223e", + "acE;": "\u223e\u0333", + "acd;": "\u223f", + "acirc": "\xe2", + "acirc;": "\xe2", + "acute": "\xb4", + "acute;": "\xb4", + "acy;": "\u0430", + "aelig": "\xe6", + "aelig;": "\xe6", + "af;": "\u2061", + "afr;": "\U0001d51e", + "agrave": "\xe0", + "agrave;": "\xe0", + "alefsym;": "\u2135", + "aleph;": "\u2135", + "alpha;": "\u03b1", + "amacr;": "\u0101", + "amalg;": "\u2a3f", + "amp": "&", + "amp;": "&", + "and;": "\u2227", + "andand;": "\u2a55", + "andd;": "\u2a5c", + "andslope;": "\u2a58", + "andv;": "\u2a5a", + "ang;": "\u2220", + "ange;": "\u29a4", + "angle;": "\u2220", + "angmsd;": "\u2221", + "angmsdaa;": "\u29a8", + "angmsdab;": "\u29a9", + "angmsdac;": "\u29aa", + "angmsdad;": "\u29ab", + "angmsdae;": "\u29ac", + "angmsdaf;": "\u29ad", + "angmsdag;": "\u29ae", + "angmsdah;": "\u29af", + "angrt;": "\u221f", + "angrtvb;": "\u22be", + "angrtvbd;": "\u299d", + "angsph;": "\u2222", + "angst;": "\xc5", + "angzarr;": "\u237c", + "aogon;": "\u0105", + "aopf;": "\U0001d552", + "ap;": "\u2248", + "apE;": "\u2a70", + "apacir;": "\u2a6f", + "ape;": "\u224a", + "apid;": "\u224b", + "apos;": "'", + "approx;": "\u2248", + "approxeq;": "\u224a", + "aring": "\xe5", + "aring;": "\xe5", + "ascr;": "\U0001d4b6", + "ast;": "*", + "asymp;": "\u2248", + "asympeq;": "\u224d", + "atilde": "\xe3", + "atilde;": "\xe3", + "auml": "\xe4", + "auml;": "\xe4", + "awconint;": "\u2233", + "awint;": "\u2a11", + "bNot;": "\u2aed", + "backcong;": "\u224c", + "backepsilon;": "\u03f6", + "backprime;": "\u2035", + "backsim;": "\u223d", + "backsimeq;": "\u22cd", + "barvee;": "\u22bd", + "barwed;": "\u2305", + "barwedge;": "\u2305", + "bbrk;": "\u23b5", + "bbrktbrk;": "\u23b6", + "bcong;": "\u224c", + "bcy;": "\u0431", + "bdquo;": "\u201e", + "becaus;": "\u2235", + "because;": "\u2235", + "bemptyv;": "\u29b0", + "bepsi;": "\u03f6", + "bernou;": "\u212c", + "beta;": "\u03b2", + "beth;": "\u2136", + "between;": "\u226c", + "bfr;": "\U0001d51f", + "bigcap;": "\u22c2", + "bigcirc;": "\u25ef", + "bigcup;": "\u22c3", + "bigodot;": "\u2a00", + "bigoplus;": "\u2a01", + "bigotimes;": "\u2a02", + "bigsqcup;": "\u2a06", + "bigstar;": "\u2605", + "bigtriangledown;": "\u25bd", + "bigtriangleup;": "\u25b3", + "biguplus;": "\u2a04", + "bigvee;": "\u22c1", + "bigwedge;": "\u22c0", + "bkarow;": "\u290d", + "blacklozenge;": "\u29eb", + "blacksquare;": "\u25aa", + "blacktriangle;": "\u25b4", + "blacktriangledown;": "\u25be", + "blacktriangleleft;": "\u25c2", + "blacktriangleright;": "\u25b8", + "blank;": "\u2423", + "blk12;": "\u2592", + "blk14;": "\u2591", + "blk34;": "\u2593", + "block;": "\u2588", + "bne;": "=\u20e5", + "bnequiv;": "\u2261\u20e5", + "bnot;": "\u2310", + "bopf;": "\U0001d553", + "bot;": "\u22a5", + "bottom;": "\u22a5", + "bowtie;": "\u22c8", + "boxDL;": "\u2557", + "boxDR;": "\u2554", + "boxDl;": "\u2556", + "boxDr;": "\u2553", + "boxH;": "\u2550", + "boxHD;": "\u2566", + "boxHU;": "\u2569", + "boxHd;": "\u2564", + "boxHu;": "\u2567", + "boxUL;": "\u255d", + "boxUR;": "\u255a", + "boxUl;": "\u255c", + "boxUr;": "\u2559", + "boxV;": "\u2551", + "boxVH;": "\u256c", + "boxVL;": "\u2563", + "boxVR;": "\u2560", + "boxVh;": "\u256b", + "boxVl;": "\u2562", + "boxVr;": "\u255f", + "boxbox;": "\u29c9", + "boxdL;": "\u2555", + "boxdR;": "\u2552", + "boxdl;": "\u2510", + "boxdr;": "\u250c", + "boxh;": "\u2500", + "boxhD;": "\u2565", + "boxhU;": "\u2568", + "boxhd;": "\u252c", + "boxhu;": "\u2534", + "boxminus;": "\u229f", + "boxplus;": "\u229e", + "boxtimes;": "\u22a0", + "boxuL;": "\u255b", + "boxuR;": "\u2558", + "boxul;": "\u2518", + "boxur;": "\u2514", + "boxv;": "\u2502", + "boxvH;": "\u256a", + "boxvL;": "\u2561", + "boxvR;": "\u255e", + "boxvh;": "\u253c", + "boxvl;": "\u2524", + "boxvr;": "\u251c", + "bprime;": "\u2035", + "breve;": "\u02d8", + "brvbar": "\xa6", + "brvbar;": "\xa6", + "bscr;": "\U0001d4b7", + "bsemi;": "\u204f", + "bsim;": "\u223d", + "bsime;": "\u22cd", + "bsol;": "\\", + "bsolb;": "\u29c5", + "bsolhsub;": "\u27c8", + "bull;": "\u2022", + "bullet;": "\u2022", + "bump;": "\u224e", + "bumpE;": "\u2aae", + "bumpe;": "\u224f", + "bumpeq;": "\u224f", + "cacute;": "\u0107", + "cap;": "\u2229", + "capand;": "\u2a44", + "capbrcup;": "\u2a49", + "capcap;": "\u2a4b", + "capcup;": "\u2a47", + "capdot;": "\u2a40", + "caps;": "\u2229\ufe00", + "caret;": "\u2041", + "caron;": "\u02c7", + "ccaps;": "\u2a4d", + "ccaron;": "\u010d", + "ccedil": "\xe7", + "ccedil;": "\xe7", + "ccirc;": "\u0109", + "ccups;": "\u2a4c", + "ccupssm;": "\u2a50", + "cdot;": "\u010b", + "cedil": "\xb8", + "cedil;": "\xb8", + "cemptyv;": "\u29b2", + "cent": "\xa2", + "cent;": "\xa2", + "centerdot;": "\xb7", + "cfr;": "\U0001d520", + "chcy;": "\u0447", + "check;": "\u2713", + "checkmark;": "\u2713", + "chi;": "\u03c7", + "cir;": "\u25cb", + "cirE;": "\u29c3", + "circ;": "\u02c6", + "circeq;": "\u2257", + "circlearrowleft;": "\u21ba", + "circlearrowright;": "\u21bb", + "circledR;": "\xae", + "circledS;": "\u24c8", + "circledast;": "\u229b", + "circledcirc;": "\u229a", + "circleddash;": "\u229d", + "cire;": "\u2257", + "cirfnint;": "\u2a10", + "cirmid;": "\u2aef", + "cirscir;": "\u29c2", + "clubs;": "\u2663", + "clubsuit;": "\u2663", + "colon;": ":", + "colone;": "\u2254", + "coloneq;": "\u2254", + "comma;": ",", + "commat;": "@", + "comp;": "\u2201", + "compfn;": "\u2218", + "complement;": "\u2201", + "complexes;": "\u2102", + "cong;": "\u2245", + "congdot;": "\u2a6d", + "conint;": "\u222e", + "copf;": "\U0001d554", + "coprod;": "\u2210", + "copy": "\xa9", + "copy;": "\xa9", + "copysr;": "\u2117", + "crarr;": "\u21b5", + "cross;": "\u2717", + "cscr;": "\U0001d4b8", + "csub;": "\u2acf", + "csube;": "\u2ad1", + "csup;": "\u2ad0", + "csupe;": "\u2ad2", + "ctdot;": "\u22ef", + "cudarrl;": "\u2938", + "cudarrr;": "\u2935", + "cuepr;": "\u22de", + "cuesc;": "\u22df", + "cularr;": "\u21b6", + "cularrp;": "\u293d", + "cup;": "\u222a", + "cupbrcap;": "\u2a48", + "cupcap;": "\u2a46", + "cupcup;": "\u2a4a", + "cupdot;": "\u228d", + "cupor;": "\u2a45", + "cups;": "\u222a\ufe00", + "curarr;": "\u21b7", + "curarrm;": "\u293c", + "curlyeqprec;": "\u22de", + "curlyeqsucc;": "\u22df", + "curlyvee;": "\u22ce", + "curlywedge;": "\u22cf", + "curren": "\xa4", + "curren;": "\xa4", + "curvearrowleft;": "\u21b6", + "curvearrowright;": "\u21b7", + "cuvee;": "\u22ce", + "cuwed;": "\u22cf", + "cwconint;": "\u2232", + "cwint;": "\u2231", + "cylcty;": "\u232d", + "dArr;": "\u21d3", + "dHar;": "\u2965", + "dagger;": "\u2020", + "daleth;": "\u2138", + "darr;": "\u2193", + "dash;": "\u2010", + "dashv;": "\u22a3", + "dbkarow;": "\u290f", + "dblac;": "\u02dd", + "dcaron;": "\u010f", + "dcy;": "\u0434", + "dd;": "\u2146", + "ddagger;": "\u2021", + "ddarr;": "\u21ca", + "ddotseq;": "\u2a77", + "deg": "\xb0", + "deg;": "\xb0", + "delta;": "\u03b4", + "demptyv;": "\u29b1", + "dfisht;": "\u297f", + "dfr;": "\U0001d521", + "dharl;": "\u21c3", + "dharr;": "\u21c2", + "diam;": "\u22c4", + "diamond;": "\u22c4", + "diamondsuit;": "\u2666", + "diams;": "\u2666", + "die;": "\xa8", + "digamma;": "\u03dd", + "disin;": "\u22f2", + "div;": "\xf7", + "divide": "\xf7", + "divide;": "\xf7", + "divideontimes;": "\u22c7", + "divonx;": "\u22c7", + "djcy;": "\u0452", + "dlcorn;": "\u231e", + "dlcrop;": "\u230d", + "dollar;": "$", + "dopf;": "\U0001d555", + "dot;": "\u02d9", + "doteq;": "\u2250", + "doteqdot;": "\u2251", + "dotminus;": "\u2238", + "dotplus;": "\u2214", + "dotsquare;": "\u22a1", + "doublebarwedge;": "\u2306", + "downarrow;": "\u2193", + "downdownarrows;": "\u21ca", + "downharpoonleft;": "\u21c3", + "downharpoonright;": "\u21c2", + "drbkarow;": "\u2910", + "drcorn;": "\u231f", + "drcrop;": "\u230c", + "dscr;": "\U0001d4b9", + "dscy;": "\u0455", + "dsol;": "\u29f6", + "dstrok;": "\u0111", + "dtdot;": "\u22f1", + "dtri;": "\u25bf", + "dtrif;": "\u25be", + "duarr;": "\u21f5", + "duhar;": "\u296f", + "dwangle;": "\u29a6", + "dzcy;": "\u045f", + "dzigrarr;": "\u27ff", + "eDDot;": "\u2a77", + "eDot;": "\u2251", + "eacute": "\xe9", + "eacute;": "\xe9", + "easter;": "\u2a6e", + "ecaron;": "\u011b", + "ecir;": "\u2256", + "ecirc": "\xea", + "ecirc;": "\xea", + "ecolon;": "\u2255", + "ecy;": "\u044d", + "edot;": "\u0117", + "ee;": "\u2147", + "efDot;": "\u2252", + "efr;": "\U0001d522", + "eg;": "\u2a9a", + "egrave": "\xe8", + "egrave;": "\xe8", + "egs;": "\u2a96", + "egsdot;": "\u2a98", + "el;": "\u2a99", + "elinters;": "\u23e7", + "ell;": "\u2113", + "els;": "\u2a95", + "elsdot;": "\u2a97", + "emacr;": "\u0113", + "empty;": "\u2205", + "emptyset;": "\u2205", + "emptyv;": "\u2205", + "emsp13;": "\u2004", + "emsp14;": "\u2005", + "emsp;": "\u2003", + "eng;": "\u014b", + "ensp;": "\u2002", + "eogon;": "\u0119", + "eopf;": "\U0001d556", + "epar;": "\u22d5", + "eparsl;": "\u29e3", + "eplus;": "\u2a71", + "epsi;": "\u03b5", + "epsilon;": "\u03b5", + "epsiv;": "\u03f5", + "eqcirc;": "\u2256", + "eqcolon;": "\u2255", + "eqsim;": "\u2242", + "eqslantgtr;": "\u2a96", + "eqslantless;": "\u2a95", + "equals;": "=", + "equest;": "\u225f", + "equiv;": "\u2261", + "equivDD;": "\u2a78", + "eqvparsl;": "\u29e5", + "erDot;": "\u2253", + "erarr;": "\u2971", + "escr;": "\u212f", + "esdot;": "\u2250", + "esim;": "\u2242", + "eta;": "\u03b7", + "eth": "\xf0", + "eth;": "\xf0", + "euml": "\xeb", + "euml;": "\xeb", + "euro;": "\u20ac", + "excl;": "!", + "exist;": "\u2203", + "expectation;": "\u2130", + "exponentiale;": "\u2147", + "fallingdotseq;": "\u2252", + "fcy;": "\u0444", + "female;": "\u2640", + "ffilig;": "\ufb03", + "fflig;": "\ufb00", + "ffllig;": "\ufb04", + "ffr;": "\U0001d523", + "filig;": "\ufb01", + "fjlig;": "fj", + "flat;": "\u266d", + "fllig;": "\ufb02", + "fltns;": "\u25b1", + "fnof;": "\u0192", + "fopf;": "\U0001d557", + "forall;": "\u2200", + "fork;": "\u22d4", + "forkv;": "\u2ad9", + "fpartint;": "\u2a0d", + "frac12": "\xbd", + "frac12;": "\xbd", + "frac13;": "\u2153", + "frac14": "\xbc", + "frac14;": "\xbc", + "frac15;": "\u2155", + "frac16;": "\u2159", + "frac18;": "\u215b", + "frac23;": "\u2154", + "frac25;": "\u2156", + "frac34": "\xbe", + "frac34;": "\xbe", + "frac35;": "\u2157", + "frac38;": "\u215c", + "frac45;": "\u2158", + "frac56;": "\u215a", + "frac58;": "\u215d", + "frac78;": "\u215e", + "frasl;": "\u2044", + "frown;": "\u2322", + "fscr;": "\U0001d4bb", + "gE;": "\u2267", + "gEl;": "\u2a8c", + "gacute;": "\u01f5", + "gamma;": "\u03b3", + "gammad;": "\u03dd", + "gap;": "\u2a86", + "gbreve;": "\u011f", + "gcirc;": "\u011d", + "gcy;": "\u0433", + "gdot;": "\u0121", + "ge;": "\u2265", + "gel;": "\u22db", + "geq;": "\u2265", + "geqq;": "\u2267", + "geqslant;": "\u2a7e", + "ges;": "\u2a7e", + "gescc;": "\u2aa9", + "gesdot;": "\u2a80", + "gesdoto;": "\u2a82", + "gesdotol;": "\u2a84", + "gesl;": "\u22db\ufe00", + "gesles;": "\u2a94", + "gfr;": "\U0001d524", + "gg;": "\u226b", + "ggg;": "\u22d9", + "gimel;": "\u2137", + "gjcy;": "\u0453", + "gl;": "\u2277", + "glE;": "\u2a92", + "gla;": "\u2aa5", + "glj;": "\u2aa4", + "gnE;": "\u2269", + "gnap;": "\u2a8a", + "gnapprox;": "\u2a8a", + "gne;": "\u2a88", + "gneq;": "\u2a88", + "gneqq;": "\u2269", + "gnsim;": "\u22e7", + "gopf;": "\U0001d558", + "grave;": "`", + "gscr;": "\u210a", + "gsim;": "\u2273", + "gsime;": "\u2a8e", + "gsiml;": "\u2a90", + "gt": ">", + "gt;": ">", + "gtcc;": "\u2aa7", + "gtcir;": "\u2a7a", + "gtdot;": "\u22d7", + "gtlPar;": "\u2995", + "gtquest;": "\u2a7c", + "gtrapprox;": "\u2a86", + "gtrarr;": "\u2978", + "gtrdot;": "\u22d7", + "gtreqless;": "\u22db", + "gtreqqless;": "\u2a8c", + "gtrless;": "\u2277", + "gtrsim;": "\u2273", + "gvertneqq;": "\u2269\ufe00", + "gvnE;": "\u2269\ufe00", + "hArr;": "\u21d4", + "hairsp;": "\u200a", + "half;": "\xbd", + "hamilt;": "\u210b", + "hardcy;": "\u044a", + "harr;": "\u2194", + "harrcir;": "\u2948", + "harrw;": "\u21ad", + "hbar;": "\u210f", + "hcirc;": "\u0125", + "hearts;": "\u2665", + "heartsuit;": "\u2665", + "hellip;": "\u2026", + "hercon;": "\u22b9", + "hfr;": "\U0001d525", + "hksearow;": "\u2925", + "hkswarow;": "\u2926", + "hoarr;": "\u21ff", + "homtht;": "\u223b", + "hookleftarrow;": "\u21a9", + "hookrightarrow;": "\u21aa", + "hopf;": "\U0001d559", + "horbar;": "\u2015", + "hscr;": "\U0001d4bd", + "hslash;": "\u210f", + "hstrok;": "\u0127", + "hybull;": "\u2043", + "hyphen;": "\u2010", + "iacute": "\xed", + "iacute;": "\xed", + "ic;": "\u2063", + "icirc": "\xee", + "icirc;": "\xee", + "icy;": "\u0438", + "iecy;": "\u0435", + "iexcl": "\xa1", + "iexcl;": "\xa1", + "iff;": "\u21d4", + "ifr;": "\U0001d526", + "igrave": "\xec", + "igrave;": "\xec", + "ii;": "\u2148", + "iiiint;": "\u2a0c", + "iiint;": "\u222d", + "iinfin;": "\u29dc", + "iiota;": "\u2129", + "ijlig;": "\u0133", + "imacr;": "\u012b", + "image;": "\u2111", + "imagline;": "\u2110", + "imagpart;": "\u2111", + "imath;": "\u0131", + "imof;": "\u22b7", + "imped;": "\u01b5", + "in;": "\u2208", + "incare;": "\u2105", + "infin;": "\u221e", + "infintie;": "\u29dd", + "inodot;": "\u0131", + "int;": "\u222b", + "intcal;": "\u22ba", + "integers;": "\u2124", + "intercal;": "\u22ba", + "intlarhk;": "\u2a17", + "intprod;": "\u2a3c", + "iocy;": "\u0451", + "iogon;": "\u012f", + "iopf;": "\U0001d55a", + "iota;": "\u03b9", + "iprod;": "\u2a3c", + "iquest": "\xbf", + "iquest;": "\xbf", + "iscr;": "\U0001d4be", + "isin;": "\u2208", + "isinE;": "\u22f9", + "isindot;": "\u22f5", + "isins;": "\u22f4", + "isinsv;": "\u22f3", + "isinv;": "\u2208", + "it;": "\u2062", + "itilde;": "\u0129", + "iukcy;": "\u0456", + "iuml": "\xef", + "iuml;": "\xef", + "jcirc;": "\u0135", + "jcy;": "\u0439", + "jfr;": "\U0001d527", + "jmath;": "\u0237", + "jopf;": "\U0001d55b", + "jscr;": "\U0001d4bf", + "jsercy;": "\u0458", + "jukcy;": "\u0454", + "kappa;": "\u03ba", + "kappav;": "\u03f0", + "kcedil;": "\u0137", + "kcy;": "\u043a", + "kfr;": "\U0001d528", + "kgreen;": "\u0138", + "khcy;": "\u0445", + "kjcy;": "\u045c", + "kopf;": "\U0001d55c", + "kscr;": "\U0001d4c0", + "lAarr;": "\u21da", + "lArr;": "\u21d0", + "lAtail;": "\u291b", + "lBarr;": "\u290e", + "lE;": "\u2266", + "lEg;": "\u2a8b", + "lHar;": "\u2962", + "lacute;": "\u013a", + "laemptyv;": "\u29b4", + "lagran;": "\u2112", + "lambda;": "\u03bb", + "lang;": "\u27e8", + "langd;": "\u2991", + "langle;": "\u27e8", + "lap;": "\u2a85", + "laquo": "\xab", + "laquo;": "\xab", + "larr;": "\u2190", + "larrb;": "\u21e4", + "larrbfs;": "\u291f", + "larrfs;": "\u291d", + "larrhk;": "\u21a9", + "larrlp;": "\u21ab", + "larrpl;": "\u2939", + "larrsim;": "\u2973", + "larrtl;": "\u21a2", + "lat;": "\u2aab", + "latail;": "\u2919", + "late;": "\u2aad", + "lates;": "\u2aad\ufe00", + "lbarr;": "\u290c", + "lbbrk;": "\u2772", + "lbrace;": "{", + "lbrack;": "[", + "lbrke;": "\u298b", + "lbrksld;": "\u298f", + "lbrkslu;": "\u298d", + "lcaron;": "\u013e", + "lcedil;": "\u013c", + "lceil;": "\u2308", + "lcub;": "{", + "lcy;": "\u043b", + "ldca;": "\u2936", + "ldquo;": "\u201c", + "ldquor;": "\u201e", + "ldrdhar;": "\u2967", + "ldrushar;": "\u294b", + "ldsh;": "\u21b2", + "le;": "\u2264", + "leftarrow;": "\u2190", + "leftarrowtail;": "\u21a2", + "leftharpoondown;": "\u21bd", + "leftharpoonup;": "\u21bc", + "leftleftarrows;": "\u21c7", + "leftrightarrow;": "\u2194", + "leftrightarrows;": "\u21c6", + "leftrightharpoons;": "\u21cb", + "leftrightsquigarrow;": "\u21ad", + "leftthreetimes;": "\u22cb", + "leg;": "\u22da", + "leq;": "\u2264", + "leqq;": "\u2266", + "leqslant;": "\u2a7d", + "les;": "\u2a7d", + "lescc;": "\u2aa8", + "lesdot;": "\u2a7f", + "lesdoto;": "\u2a81", + "lesdotor;": "\u2a83", + "lesg;": "\u22da\ufe00", + "lesges;": "\u2a93", + "lessapprox;": "\u2a85", + "lessdot;": "\u22d6", + "lesseqgtr;": "\u22da", + "lesseqqgtr;": "\u2a8b", + "lessgtr;": "\u2276", + "lesssim;": "\u2272", + "lfisht;": "\u297c", + "lfloor;": "\u230a", + "lfr;": "\U0001d529", + "lg;": "\u2276", + "lgE;": "\u2a91", + "lhard;": "\u21bd", + "lharu;": "\u21bc", + "lharul;": "\u296a", + "lhblk;": "\u2584", + "ljcy;": "\u0459", + "ll;": "\u226a", + "llarr;": "\u21c7", + "llcorner;": "\u231e", + "llhard;": "\u296b", + "lltri;": "\u25fa", + "lmidot;": "\u0140", + "lmoust;": "\u23b0", + "lmoustache;": "\u23b0", + "lnE;": "\u2268", + "lnap;": "\u2a89", + "lnapprox;": "\u2a89", + "lne;": "\u2a87", + "lneq;": "\u2a87", + "lneqq;": "\u2268", + "lnsim;": "\u22e6", + "loang;": "\u27ec", + "loarr;": "\u21fd", + "lobrk;": "\u27e6", + "longleftarrow;": "\u27f5", + "longleftrightarrow;": "\u27f7", + "longmapsto;": "\u27fc", + "longrightarrow;": "\u27f6", + "looparrowleft;": "\u21ab", + "looparrowright;": "\u21ac", + "lopar;": "\u2985", + "lopf;": "\U0001d55d", + "loplus;": "\u2a2d", + "lotimes;": "\u2a34", + "lowast;": "\u2217", + "lowbar;": "_", + "loz;": "\u25ca", + "lozenge;": "\u25ca", + "lozf;": "\u29eb", + "lpar;": "(", + "lparlt;": "\u2993", + "lrarr;": "\u21c6", + "lrcorner;": "\u231f", + "lrhar;": "\u21cb", + "lrhard;": "\u296d", + "lrm;": "\u200e", + "lrtri;": "\u22bf", + "lsaquo;": "\u2039", + "lscr;": "\U0001d4c1", + "lsh;": "\u21b0", + "lsim;": "\u2272", + "lsime;": "\u2a8d", + "lsimg;": "\u2a8f", + "lsqb;": "[", + "lsquo;": "\u2018", + "lsquor;": "\u201a", + "lstrok;": "\u0142", + "lt": "<", + "lt;": "<", + "ltcc;": "\u2aa6", + "ltcir;": "\u2a79", + "ltdot;": "\u22d6", + "lthree;": "\u22cb", + "ltimes;": "\u22c9", + "ltlarr;": "\u2976", + "ltquest;": "\u2a7b", + "ltrPar;": "\u2996", + "ltri;": "\u25c3", + "ltrie;": "\u22b4", + "ltrif;": "\u25c2", + "lurdshar;": "\u294a", + "luruhar;": "\u2966", + "lvertneqq;": "\u2268\ufe00", + "lvnE;": "\u2268\ufe00", + "mDDot;": "\u223a", + "macr": "\xaf", + "macr;": "\xaf", + "male;": "\u2642", + "malt;": "\u2720", + "maltese;": "\u2720", + "map;": "\u21a6", + "mapsto;": "\u21a6", + "mapstodown;": "\u21a7", + "mapstoleft;": "\u21a4", + "mapstoup;": "\u21a5", + "marker;": "\u25ae", + "mcomma;": "\u2a29", + "mcy;": "\u043c", + "mdash;": "\u2014", + "measuredangle;": "\u2221", + "mfr;": "\U0001d52a", + "mho;": "\u2127", + "micro": "\xb5", + "micro;": "\xb5", + "mid;": "\u2223", + "midast;": "*", + "midcir;": "\u2af0", + "middot": "\xb7", + "middot;": "\xb7", + "minus;": "\u2212", + "minusb;": "\u229f", + "minusd;": "\u2238", + "minusdu;": "\u2a2a", + "mlcp;": "\u2adb", + "mldr;": "\u2026", + "mnplus;": "\u2213", + "models;": "\u22a7", + "mopf;": "\U0001d55e", + "mp;": "\u2213", + "mscr;": "\U0001d4c2", + "mstpos;": "\u223e", + "mu;": "\u03bc", + "multimap;": "\u22b8", + "mumap;": "\u22b8", + "nGg;": "\u22d9\u0338", + "nGt;": "\u226b\u20d2", + "nGtv;": "\u226b\u0338", + "nLeftarrow;": "\u21cd", + "nLeftrightarrow;": "\u21ce", + "nLl;": "\u22d8\u0338", + "nLt;": "\u226a\u20d2", + "nLtv;": "\u226a\u0338", + "nRightarrow;": "\u21cf", + "nVDash;": "\u22af", + "nVdash;": "\u22ae", + "nabla;": "\u2207", + "nacute;": "\u0144", + "nang;": "\u2220\u20d2", + "nap;": "\u2249", + "napE;": "\u2a70\u0338", + "napid;": "\u224b\u0338", + "napos;": "\u0149", + "napprox;": "\u2249", + "natur;": "\u266e", + "natural;": "\u266e", + "naturals;": "\u2115", + "nbsp": "\xa0", + "nbsp;": "\xa0", + "nbump;": "\u224e\u0338", + "nbumpe;": "\u224f\u0338", + "ncap;": "\u2a43", + "ncaron;": "\u0148", + "ncedil;": "\u0146", + "ncong;": "\u2247", + "ncongdot;": "\u2a6d\u0338", + "ncup;": "\u2a42", + "ncy;": "\u043d", + "ndash;": "\u2013", + "ne;": "\u2260", + "neArr;": "\u21d7", + "nearhk;": "\u2924", + "nearr;": "\u2197", + "nearrow;": "\u2197", + "nedot;": "\u2250\u0338", + "nequiv;": "\u2262", + "nesear;": "\u2928", + "nesim;": "\u2242\u0338", + "nexist;": "\u2204", + "nexists;": "\u2204", + "nfr;": "\U0001d52b", + "ngE;": "\u2267\u0338", + "nge;": "\u2271", + "ngeq;": "\u2271", + "ngeqq;": "\u2267\u0338", + "ngeqslant;": "\u2a7e\u0338", + "nges;": "\u2a7e\u0338", + "ngsim;": "\u2275", + "ngt;": "\u226f", + "ngtr;": "\u226f", + "nhArr;": "\u21ce", + "nharr;": "\u21ae", + "nhpar;": "\u2af2", + "ni;": "\u220b", + "nis;": "\u22fc", + "nisd;": "\u22fa", + "niv;": "\u220b", + "njcy;": "\u045a", + "nlArr;": "\u21cd", + "nlE;": "\u2266\u0338", + "nlarr;": "\u219a", + "nldr;": "\u2025", + "nle;": "\u2270", + "nleftarrow;": "\u219a", + "nleftrightarrow;": "\u21ae", + "nleq;": "\u2270", + "nleqq;": "\u2266\u0338", + "nleqslant;": "\u2a7d\u0338", + "nles;": "\u2a7d\u0338", + "nless;": "\u226e", + "nlsim;": "\u2274", + "nlt;": "\u226e", + "nltri;": "\u22ea", + "nltrie;": "\u22ec", + "nmid;": "\u2224", + "nopf;": "\U0001d55f", + "not": "\xac", + "not;": "\xac", + "notin;": "\u2209", + "notinE;": "\u22f9\u0338", + "notindot;": "\u22f5\u0338", + "notinva;": "\u2209", + "notinvb;": "\u22f7", + "notinvc;": "\u22f6", + "notni;": "\u220c", + "notniva;": "\u220c", + "notnivb;": "\u22fe", + "notnivc;": "\u22fd", + "npar;": "\u2226", + "nparallel;": "\u2226", + "nparsl;": "\u2afd\u20e5", + "npart;": "\u2202\u0338", + "npolint;": "\u2a14", + "npr;": "\u2280", + "nprcue;": "\u22e0", + "npre;": "\u2aaf\u0338", + "nprec;": "\u2280", + "npreceq;": "\u2aaf\u0338", + "nrArr;": "\u21cf", + "nrarr;": "\u219b", + "nrarrc;": "\u2933\u0338", + "nrarrw;": "\u219d\u0338", + "nrightarrow;": "\u219b", + "nrtri;": "\u22eb", + "nrtrie;": "\u22ed", + "nsc;": "\u2281", + "nsccue;": "\u22e1", + "nsce;": "\u2ab0\u0338", + "nscr;": "\U0001d4c3", + "nshortmid;": "\u2224", + "nshortparallel;": "\u2226", + "nsim;": "\u2241", + "nsime;": "\u2244", + "nsimeq;": "\u2244", + "nsmid;": "\u2224", + "nspar;": "\u2226", + "nsqsube;": "\u22e2", + "nsqsupe;": "\u22e3", + "nsub;": "\u2284", + "nsubE;": "\u2ac5\u0338", + "nsube;": "\u2288", + "nsubset;": "\u2282\u20d2", + "nsubseteq;": "\u2288", + "nsubseteqq;": "\u2ac5\u0338", + "nsucc;": "\u2281", + "nsucceq;": "\u2ab0\u0338", + "nsup;": "\u2285", + "nsupE;": "\u2ac6\u0338", + "nsupe;": "\u2289", + "nsupset;": "\u2283\u20d2", + "nsupseteq;": "\u2289", + "nsupseteqq;": "\u2ac6\u0338", + "ntgl;": "\u2279", + "ntilde": "\xf1", + "ntilde;": "\xf1", + "ntlg;": "\u2278", + "ntriangleleft;": "\u22ea", + "ntrianglelefteq;": "\u22ec", + "ntriangleright;": "\u22eb", + "ntrianglerighteq;": "\u22ed", + "nu;": "\u03bd", + "num;": "#", + "numero;": "\u2116", + "numsp;": "\u2007", + "nvDash;": "\u22ad", + "nvHarr;": "\u2904", + "nvap;": "\u224d\u20d2", + "nvdash;": "\u22ac", + "nvge;": "\u2265\u20d2", + "nvgt;": ">\u20d2", + "nvinfin;": "\u29de", + "nvlArr;": "\u2902", + "nvle;": "\u2264\u20d2", + "nvlt;": "<\u20d2", + "nvltrie;": "\u22b4\u20d2", + "nvrArr;": "\u2903", + "nvrtrie;": "\u22b5\u20d2", + "nvsim;": "\u223c\u20d2", + "nwArr;": "\u21d6", + "nwarhk;": "\u2923", + "nwarr;": "\u2196", + "nwarrow;": "\u2196", + "nwnear;": "\u2927", + "oS;": "\u24c8", + "oacute": "\xf3", + "oacute;": "\xf3", + "oast;": "\u229b", + "ocir;": "\u229a", + "ocirc": "\xf4", + "ocirc;": "\xf4", + "ocy;": "\u043e", + "odash;": "\u229d", + "odblac;": "\u0151", + "odiv;": "\u2a38", + "odot;": "\u2299", + "odsold;": "\u29bc", + "oelig;": "\u0153", + "ofcir;": "\u29bf", + "ofr;": "\U0001d52c", + "ogon;": "\u02db", + "ograve": "\xf2", + "ograve;": "\xf2", + "ogt;": "\u29c1", + "ohbar;": "\u29b5", + "ohm;": "\u03a9", + "oint;": "\u222e", + "olarr;": "\u21ba", + "olcir;": "\u29be", + "olcross;": "\u29bb", + "oline;": "\u203e", + "olt;": "\u29c0", + "omacr;": "\u014d", + "omega;": "\u03c9", + "omicron;": "\u03bf", + "omid;": "\u29b6", + "ominus;": "\u2296", + "oopf;": "\U0001d560", + "opar;": "\u29b7", + "operp;": "\u29b9", + "oplus;": "\u2295", + "or;": "\u2228", + "orarr;": "\u21bb", + "ord;": "\u2a5d", + "order;": "\u2134", + "orderof;": "\u2134", + "ordf": "\xaa", + "ordf;": "\xaa", + "ordm": "\xba", + "ordm;": "\xba", + "origof;": "\u22b6", + "oror;": "\u2a56", + "orslope;": "\u2a57", + "orv;": "\u2a5b", + "oscr;": "\u2134", + "oslash": "\xf8", + "oslash;": "\xf8", + "osol;": "\u2298", + "otilde": "\xf5", + "otilde;": "\xf5", + "otimes;": "\u2297", + "otimesas;": "\u2a36", + "ouml": "\xf6", + "ouml;": "\xf6", + "ovbar;": "\u233d", + "par;": "\u2225", + "para": "\xb6", + "para;": "\xb6", + "parallel;": "\u2225", + "parsim;": "\u2af3", + "parsl;": "\u2afd", + "part;": "\u2202", + "pcy;": "\u043f", + "percnt;": "%", + "period;": ".", + "permil;": "\u2030", + "perp;": "\u22a5", + "pertenk;": "\u2031", + "pfr;": "\U0001d52d", + "phi;": "\u03c6", + "phiv;": "\u03d5", + "phmmat;": "\u2133", + "phone;": "\u260e", + "pi;": "\u03c0", + "pitchfork;": "\u22d4", + "piv;": "\u03d6", + "planck;": "\u210f", + "planckh;": "\u210e", + "plankv;": "\u210f", + "plus;": "+", + "plusacir;": "\u2a23", + "plusb;": "\u229e", + "pluscir;": "\u2a22", + "plusdo;": "\u2214", + "plusdu;": "\u2a25", + "pluse;": "\u2a72", + "plusmn": "\xb1", + "plusmn;": "\xb1", + "plussim;": "\u2a26", + "plustwo;": "\u2a27", + "pm;": "\xb1", + "pointint;": "\u2a15", + "popf;": "\U0001d561", + "pound": "\xa3", + "pound;": "\xa3", + "pr;": "\u227a", + "prE;": "\u2ab3", + "prap;": "\u2ab7", + "prcue;": "\u227c", + "pre;": "\u2aaf", + "prec;": "\u227a", + "precapprox;": "\u2ab7", + "preccurlyeq;": "\u227c", + "preceq;": "\u2aaf", + "precnapprox;": "\u2ab9", + "precneqq;": "\u2ab5", + "precnsim;": "\u22e8", + "precsim;": "\u227e", + "prime;": "\u2032", + "primes;": "\u2119", + "prnE;": "\u2ab5", + "prnap;": "\u2ab9", + "prnsim;": "\u22e8", + "prod;": "\u220f", + "profalar;": "\u232e", + "profline;": "\u2312", + "profsurf;": "\u2313", + "prop;": "\u221d", + "propto;": "\u221d", + "prsim;": "\u227e", + "prurel;": "\u22b0", + "pscr;": "\U0001d4c5", + "psi;": "\u03c8", + "puncsp;": "\u2008", + "qfr;": "\U0001d52e", + "qint;": "\u2a0c", + "qopf;": "\U0001d562", + "qprime;": "\u2057", + "qscr;": "\U0001d4c6", + "quaternions;": "\u210d", + "quatint;": "\u2a16", + "quest;": "?", + "questeq;": "\u225f", + "quot": "\"", + "quot;": "\"", + "rAarr;": "\u21db", + "rArr;": "\u21d2", + "rAtail;": "\u291c", + "rBarr;": "\u290f", + "rHar;": "\u2964", + "race;": "\u223d\u0331", + "racute;": "\u0155", + "radic;": "\u221a", + "raemptyv;": "\u29b3", + "rang;": "\u27e9", + "rangd;": "\u2992", + "range;": "\u29a5", + "rangle;": "\u27e9", + "raquo": "\xbb", + "raquo;": "\xbb", + "rarr;": "\u2192", + "rarrap;": "\u2975", + "rarrb;": "\u21e5", + "rarrbfs;": "\u2920", + "rarrc;": "\u2933", + "rarrfs;": "\u291e", + "rarrhk;": "\u21aa", + "rarrlp;": "\u21ac", + "rarrpl;": "\u2945", + "rarrsim;": "\u2974", + "rarrtl;": "\u21a3", + "rarrw;": "\u219d", + "ratail;": "\u291a", + "ratio;": "\u2236", + "rationals;": "\u211a", + "rbarr;": "\u290d", + "rbbrk;": "\u2773", + "rbrace;": "}", + "rbrack;": "]", + "rbrke;": "\u298c", + "rbrksld;": "\u298e", + "rbrkslu;": "\u2990", + "rcaron;": "\u0159", + "rcedil;": "\u0157", + "rceil;": "\u2309", + "rcub;": "}", + "rcy;": "\u0440", + "rdca;": "\u2937", + "rdldhar;": "\u2969", + "rdquo;": "\u201d", + "rdquor;": "\u201d", + "rdsh;": "\u21b3", + "real;": "\u211c", + "realine;": "\u211b", + "realpart;": "\u211c", + "reals;": "\u211d", + "rect;": "\u25ad", + "reg": "\xae", + "reg;": "\xae", + "rfisht;": "\u297d", + "rfloor;": "\u230b", + "rfr;": "\U0001d52f", + "rhard;": "\u21c1", + "rharu;": "\u21c0", + "rharul;": "\u296c", + "rho;": "\u03c1", + "rhov;": "\u03f1", + "rightarrow;": "\u2192", + "rightarrowtail;": "\u21a3", + "rightharpoondown;": "\u21c1", + "rightharpoonup;": "\u21c0", + "rightleftarrows;": "\u21c4", + "rightleftharpoons;": "\u21cc", + "rightrightarrows;": "\u21c9", + "rightsquigarrow;": "\u219d", + "rightthreetimes;": "\u22cc", + "ring;": "\u02da", + "risingdotseq;": "\u2253", + "rlarr;": "\u21c4", + "rlhar;": "\u21cc", + "rlm;": "\u200f", + "rmoust;": "\u23b1", + "rmoustache;": "\u23b1", + "rnmid;": "\u2aee", + "roang;": "\u27ed", + "roarr;": "\u21fe", + "robrk;": "\u27e7", + "ropar;": "\u2986", + "ropf;": "\U0001d563", + "roplus;": "\u2a2e", + "rotimes;": "\u2a35", + "rpar;": ")", + "rpargt;": "\u2994", + "rppolint;": "\u2a12", + "rrarr;": "\u21c9", + "rsaquo;": "\u203a", + "rscr;": "\U0001d4c7", + "rsh;": "\u21b1", + "rsqb;": "]", + "rsquo;": "\u2019", + "rsquor;": "\u2019", + "rthree;": "\u22cc", + "rtimes;": "\u22ca", + "rtri;": "\u25b9", + "rtrie;": "\u22b5", + "rtrif;": "\u25b8", + "rtriltri;": "\u29ce", + "ruluhar;": "\u2968", + "rx;": "\u211e", + "sacute;": "\u015b", + "sbquo;": "\u201a", + "sc;": "\u227b", + "scE;": "\u2ab4", + "scap;": "\u2ab8", + "scaron;": "\u0161", + "sccue;": "\u227d", + "sce;": "\u2ab0", + "scedil;": "\u015f", + "scirc;": "\u015d", + "scnE;": "\u2ab6", + "scnap;": "\u2aba", + "scnsim;": "\u22e9", + "scpolint;": "\u2a13", + "scsim;": "\u227f", + "scy;": "\u0441", + "sdot;": "\u22c5", + "sdotb;": "\u22a1", + "sdote;": "\u2a66", + "seArr;": "\u21d8", + "searhk;": "\u2925", + "searr;": "\u2198", + "searrow;": "\u2198", + "sect": "\xa7", + "sect;": "\xa7", + "semi;": ";", + "seswar;": "\u2929", + "setminus;": "\u2216", + "setmn;": "\u2216", + "sext;": "\u2736", + "sfr;": "\U0001d530", + "sfrown;": "\u2322", + "sharp;": "\u266f", + "shchcy;": "\u0449", + "shcy;": "\u0448", + "shortmid;": "\u2223", + "shortparallel;": "\u2225", + "shy": "\xad", + "shy;": "\xad", + "sigma;": "\u03c3", + "sigmaf;": "\u03c2", + "sigmav;": "\u03c2", + "sim;": "\u223c", + "simdot;": "\u2a6a", + "sime;": "\u2243", + "simeq;": "\u2243", + "simg;": "\u2a9e", + "simgE;": "\u2aa0", + "siml;": "\u2a9d", + "simlE;": "\u2a9f", + "simne;": "\u2246", + "simplus;": "\u2a24", + "simrarr;": "\u2972", + "slarr;": "\u2190", + "smallsetminus;": "\u2216", + "smashp;": "\u2a33", + "smeparsl;": "\u29e4", + "smid;": "\u2223", + "smile;": "\u2323", + "smt;": "\u2aaa", + "smte;": "\u2aac", + "smtes;": "\u2aac\ufe00", + "softcy;": "\u044c", + "sol;": "/", + "solb;": "\u29c4", + "solbar;": "\u233f", + "sopf;": "\U0001d564", + "spades;": "\u2660", + "spadesuit;": "\u2660", + "spar;": "\u2225", + "sqcap;": "\u2293", + "sqcaps;": "\u2293\ufe00", + "sqcup;": "\u2294", + "sqcups;": "\u2294\ufe00", + "sqsub;": "\u228f", + "sqsube;": "\u2291", + "sqsubset;": "\u228f", + "sqsubseteq;": "\u2291", + "sqsup;": "\u2290", + "sqsupe;": "\u2292", + "sqsupset;": "\u2290", + "sqsupseteq;": "\u2292", + "squ;": "\u25a1", + "square;": "\u25a1", + "squarf;": "\u25aa", + "squf;": "\u25aa", + "srarr;": "\u2192", + "sscr;": "\U0001d4c8", + "ssetmn;": "\u2216", + "ssmile;": "\u2323", + "sstarf;": "\u22c6", + "star;": "\u2606", + "starf;": "\u2605", + "straightepsilon;": "\u03f5", + "straightphi;": "\u03d5", + "strns;": "\xaf", + "sub;": "\u2282", + "subE;": "\u2ac5", + "subdot;": "\u2abd", + "sube;": "\u2286", + "subedot;": "\u2ac3", + "submult;": "\u2ac1", + "subnE;": "\u2acb", + "subne;": "\u228a", + "subplus;": "\u2abf", + "subrarr;": "\u2979", + "subset;": "\u2282", + "subseteq;": "\u2286", + "subseteqq;": "\u2ac5", + "subsetneq;": "\u228a", + "subsetneqq;": "\u2acb", + "subsim;": "\u2ac7", + "subsub;": "\u2ad5", + "subsup;": "\u2ad3", + "succ;": "\u227b", + "succapprox;": "\u2ab8", + "succcurlyeq;": "\u227d", + "succeq;": "\u2ab0", + "succnapprox;": "\u2aba", + "succneqq;": "\u2ab6", + "succnsim;": "\u22e9", + "succsim;": "\u227f", + "sum;": "\u2211", + "sung;": "\u266a", + "sup1": "\xb9", + "sup1;": "\xb9", + "sup2": "\xb2", + "sup2;": "\xb2", + "sup3": "\xb3", + "sup3;": "\xb3", + "sup;": "\u2283", + "supE;": "\u2ac6", + "supdot;": "\u2abe", + "supdsub;": "\u2ad8", + "supe;": "\u2287", + "supedot;": "\u2ac4", + "suphsol;": "\u27c9", + "suphsub;": "\u2ad7", + "suplarr;": "\u297b", + "supmult;": "\u2ac2", + "supnE;": "\u2acc", + "supne;": "\u228b", + "supplus;": "\u2ac0", + "supset;": "\u2283", + "supseteq;": "\u2287", + "supseteqq;": "\u2ac6", + "supsetneq;": "\u228b", + "supsetneqq;": "\u2acc", + "supsim;": "\u2ac8", + "supsub;": "\u2ad4", + "supsup;": "\u2ad6", + "swArr;": "\u21d9", + "swarhk;": "\u2926", + "swarr;": "\u2199", + "swarrow;": "\u2199", + "swnwar;": "\u292a", + "szlig": "\xdf", + "szlig;": "\xdf", + "target;": "\u2316", + "tau;": "\u03c4", + "tbrk;": "\u23b4", + "tcaron;": "\u0165", + "tcedil;": "\u0163", + "tcy;": "\u0442", + "tdot;": "\u20db", + "telrec;": "\u2315", + "tfr;": "\U0001d531", + "there4;": "\u2234", + "therefore;": "\u2234", + "theta;": "\u03b8", + "thetasym;": "\u03d1", + "thetav;": "\u03d1", + "thickapprox;": "\u2248", + "thicksim;": "\u223c", + "thinsp;": "\u2009", + "thkap;": "\u2248", + "thksim;": "\u223c", + "thorn": "\xfe", + "thorn;": "\xfe", + "tilde;": "\u02dc", + "times": "\xd7", + "times;": "\xd7", + "timesb;": "\u22a0", + "timesbar;": "\u2a31", + "timesd;": "\u2a30", + "tint;": "\u222d", + "toea;": "\u2928", + "top;": "\u22a4", + "topbot;": "\u2336", + "topcir;": "\u2af1", + "topf;": "\U0001d565", + "topfork;": "\u2ada", + "tosa;": "\u2929", + "tprime;": "\u2034", + "trade;": "\u2122", + "triangle;": "\u25b5", + "triangledown;": "\u25bf", + "triangleleft;": "\u25c3", + "trianglelefteq;": "\u22b4", + "triangleq;": "\u225c", + "triangleright;": "\u25b9", + "trianglerighteq;": "\u22b5", + "tridot;": "\u25ec", + "trie;": "\u225c", + "triminus;": "\u2a3a", + "triplus;": "\u2a39", + "trisb;": "\u29cd", + "tritime;": "\u2a3b", + "trpezium;": "\u23e2", + "tscr;": "\U0001d4c9", + "tscy;": "\u0446", + "tshcy;": "\u045b", + "tstrok;": "\u0167", + "twixt;": "\u226c", + "twoheadleftarrow;": "\u219e", + "twoheadrightarrow;": "\u21a0", + "uArr;": "\u21d1", + "uHar;": "\u2963", + "uacute": "\xfa", + "uacute;": "\xfa", + "uarr;": "\u2191", + "ubrcy;": "\u045e", + "ubreve;": "\u016d", + "ucirc": "\xfb", + "ucirc;": "\xfb", + "ucy;": "\u0443", + "udarr;": "\u21c5", + "udblac;": "\u0171", + "udhar;": "\u296e", + "ufisht;": "\u297e", + "ufr;": "\U0001d532", + "ugrave": "\xf9", + "ugrave;": "\xf9", + "uharl;": "\u21bf", + "uharr;": "\u21be", + "uhblk;": "\u2580", + "ulcorn;": "\u231c", + "ulcorner;": "\u231c", + "ulcrop;": "\u230f", + "ultri;": "\u25f8", + "umacr;": "\u016b", + "uml": "\xa8", + "uml;": "\xa8", + "uogon;": "\u0173", + "uopf;": "\U0001d566", + "uparrow;": "\u2191", + "updownarrow;": "\u2195", + "upharpoonleft;": "\u21bf", + "upharpoonright;": "\u21be", + "uplus;": "\u228e", + "upsi;": "\u03c5", + "upsih;": "\u03d2", + "upsilon;": "\u03c5", + "upuparrows;": "\u21c8", + "urcorn;": "\u231d", + "urcorner;": "\u231d", + "urcrop;": "\u230e", + "uring;": "\u016f", + "urtri;": "\u25f9", + "uscr;": "\U0001d4ca", + "utdot;": "\u22f0", + "utilde;": "\u0169", + "utri;": "\u25b5", + "utrif;": "\u25b4", + "uuarr;": "\u21c8", + "uuml": "\xfc", + "uuml;": "\xfc", + "uwangle;": "\u29a7", + "vArr;": "\u21d5", + "vBar;": "\u2ae8", + "vBarv;": "\u2ae9", + "vDash;": "\u22a8", + "vangrt;": "\u299c", + "varepsilon;": "\u03f5", + "varkappa;": "\u03f0", + "varnothing;": "\u2205", + "varphi;": "\u03d5", + "varpi;": "\u03d6", + "varpropto;": "\u221d", + "varr;": "\u2195", + "varrho;": "\u03f1", + "varsigma;": "\u03c2", + "varsubsetneq;": "\u228a\ufe00", + "varsubsetneqq;": "\u2acb\ufe00", + "varsupsetneq;": "\u228b\ufe00", + "varsupsetneqq;": "\u2acc\ufe00", + "vartheta;": "\u03d1", + "vartriangleleft;": "\u22b2", + "vartriangleright;": "\u22b3", + "vcy;": "\u0432", + "vdash;": "\u22a2", + "vee;": "\u2228", + "veebar;": "\u22bb", + "veeeq;": "\u225a", + "vellip;": "\u22ee", + "verbar;": "|", + "vert;": "|", + "vfr;": "\U0001d533", + "vltri;": "\u22b2", + "vnsub;": "\u2282\u20d2", + "vnsup;": "\u2283\u20d2", + "vopf;": "\U0001d567", + "vprop;": "\u221d", + "vrtri;": "\u22b3", + "vscr;": "\U0001d4cb", + "vsubnE;": "\u2acb\ufe00", + "vsubne;": "\u228a\ufe00", + "vsupnE;": "\u2acc\ufe00", + "vsupne;": "\u228b\ufe00", + "vzigzag;": "\u299a", + "wcirc;": "\u0175", + "wedbar;": "\u2a5f", + "wedge;": "\u2227", + "wedgeq;": "\u2259", + "weierp;": "\u2118", + "wfr;": "\U0001d534", + "wopf;": "\U0001d568", + "wp;": "\u2118", + "wr;": "\u2240", + "wreath;": "\u2240", + "wscr;": "\U0001d4cc", + "xcap;": "\u22c2", + "xcirc;": "\u25ef", + "xcup;": "\u22c3", + "xdtri;": "\u25bd", + "xfr;": "\U0001d535", + "xhArr;": "\u27fa", + "xharr;": "\u27f7", + "xi;": "\u03be", + "xlArr;": "\u27f8", + "xlarr;": "\u27f5", + "xmap;": "\u27fc", + "xnis;": "\u22fb", + "xodot;": "\u2a00", + "xopf;": "\U0001d569", + "xoplus;": "\u2a01", + "xotime;": "\u2a02", + "xrArr;": "\u27f9", + "xrarr;": "\u27f6", + "xscr;": "\U0001d4cd", + "xsqcup;": "\u2a06", + "xuplus;": "\u2a04", + "xutri;": "\u25b3", + "xvee;": "\u22c1", + "xwedge;": "\u22c0", + "yacute": "\xfd", + "yacute;": "\xfd", + "yacy;": "\u044f", + "ycirc;": "\u0177", + "ycy;": "\u044b", + "yen": "\xa5", + "yen;": "\xa5", + "yfr;": "\U0001d536", + "yicy;": "\u0457", + "yopf;": "\U0001d56a", + "yscr;": "\U0001d4ce", + "yucy;": "\u044e", + "yuml": "\xff", + "yuml;": "\xff", + "zacute;": "\u017a", + "zcaron;": "\u017e", + "zcy;": "\u0437", + "zdot;": "\u017c", + "zeetrf;": "\u2128", + "zeta;": "\u03b6", + "zfr;": "\U0001d537", + "zhcy;": "\u0436", + "zigrarr;": "\u21dd", + "zopf;": "\U0001d56b", + "zscr;": "\U0001d4cf", + "zwj;": "\u200d", + "zwnj;": "\u200c", +} + +replacementCharacters = { + 0x0: "\uFFFD", + 0x0d: "\u000D", + 0x80: "\u20AC", + 0x81: "\u0081", + 0x81: "\u0081", + 0x82: "\u201A", + 0x83: "\u0192", + 0x84: "\u201E", + 0x85: "\u2026", + 0x86: "\u2020", + 0x87: "\u2021", + 0x88: "\u02C6", + 0x89: "\u2030", + 0x8A: "\u0160", + 0x8B: "\u2039", + 0x8C: "\u0152", + 0x8D: "\u008D", + 0x8E: "\u017D", + 0x8F: "\u008F", + 0x90: "\u0090", + 0x91: "\u2018", + 0x92: "\u2019", + 0x93: "\u201C", + 0x94: "\u201D", + 0x95: "\u2022", + 0x96: "\u2013", + 0x97: "\u2014", + 0x98: "\u02DC", + 0x99: "\u2122", + 0x9A: "\u0161", + 0x9B: "\u203A", + 0x9C: "\u0153", + 0x9D: "\u009D", + 0x9E: "\u017E", + 0x9F: "\u0178", +} + +encodings = { + '437': 'cp437', + '850': 'cp850', + '852': 'cp852', + '855': 'cp855', + '857': 'cp857', + '860': 'cp860', + '861': 'cp861', + '862': 'cp862', + '863': 'cp863', + '865': 'cp865', + '866': 'cp866', + '869': 'cp869', + 'ansix341968': 'ascii', + 'ansix341986': 'ascii', + 'arabic': 'iso8859-6', + 'ascii': 'ascii', + 'asmo708': 'iso8859-6', + 'big5': 'big5', + 'big5hkscs': 'big5hkscs', + 'chinese': 'gbk', + 'cp037': 'cp037', + 'cp1026': 'cp1026', + 'cp154': 'ptcp154', + 'cp367': 'ascii', + 'cp424': 'cp424', + 'cp437': 'cp437', + 'cp500': 'cp500', + 'cp775': 'cp775', + 'cp819': 'windows-1252', + 'cp850': 'cp850', + 'cp852': 'cp852', + 'cp855': 'cp855', + 'cp857': 'cp857', + 'cp860': 'cp860', + 'cp861': 'cp861', + 'cp862': 'cp862', + 'cp863': 'cp863', + 'cp864': 'cp864', + 'cp865': 'cp865', + 'cp866': 'cp866', + 'cp869': 'cp869', + 'cp936': 'gbk', + 'cpgr': 'cp869', + 'cpis': 'cp861', + 'csascii': 'ascii', + 'csbig5': 'big5', + 'cseuckr': 'cp949', + 'cseucpkdfmtjapanese': 'euc_jp', + 'csgb2312': 'gbk', + 'cshproman8': 'hp-roman8', + 'csibm037': 'cp037', + 'csibm1026': 'cp1026', + 'csibm424': 'cp424', + 'csibm500': 'cp500', + 'csibm855': 'cp855', + 'csibm857': 'cp857', + 'csibm860': 'cp860', + 'csibm861': 'cp861', + 'csibm863': 'cp863', + 'csibm864': 'cp864', + 'csibm865': 'cp865', + 'csibm866': 'cp866', + 'csibm869': 'cp869', + 'csiso2022jp': 'iso2022_jp', + 'csiso2022jp2': 'iso2022_jp_2', + 'csiso2022kr': 'iso2022_kr', + 'csiso58gb231280': 'gbk', + 'csisolatin1': 'windows-1252', + 'csisolatin2': 'iso8859-2', + 'csisolatin3': 'iso8859-3', + 'csisolatin4': 'iso8859-4', + 'csisolatin5': 'windows-1254', + 'csisolatin6': 'iso8859-10', + 'csisolatinarabic': 'iso8859-6', + 'csisolatincyrillic': 'iso8859-5', + 'csisolatingreek': 'iso8859-7', + 'csisolatinhebrew': 'iso8859-8', + 'cskoi8r': 'koi8-r', + 'csksc56011987': 'cp949', + 'cspc775baltic': 'cp775', + 'cspc850multilingual': 'cp850', + 'cspc862latinhebrew': 'cp862', + 'cspc8codepage437': 'cp437', + 'cspcp852': 'cp852', + 'csptcp154': 'ptcp154', + 'csshiftjis': 'shift_jis', + 'csunicode11utf7': 'utf-7', + 'cyrillic': 'iso8859-5', + 'cyrillicasian': 'ptcp154', + 'ebcdiccpbe': 'cp500', + 'ebcdiccpca': 'cp037', + 'ebcdiccpch': 'cp500', + 'ebcdiccphe': 'cp424', + 'ebcdiccpnl': 'cp037', + 'ebcdiccpus': 'cp037', + 'ebcdiccpwt': 'cp037', + 'ecma114': 'iso8859-6', + 'ecma118': 'iso8859-7', + 'elot928': 'iso8859-7', + 'eucjp': 'euc_jp', + 'euckr': 'cp949', + 'extendedunixcodepackedformatforjapanese': 'euc_jp', + 'gb18030': 'gb18030', + 'gb2312': 'gbk', + 'gb231280': 'gbk', + 'gbk': 'gbk', + 'greek': 'iso8859-7', + 'greek8': 'iso8859-7', + 'hebrew': 'iso8859-8', + 'hproman8': 'hp-roman8', + 'hzgb2312': 'hz', + 'ibm037': 'cp037', + 'ibm1026': 'cp1026', + 'ibm367': 'ascii', + 'ibm424': 'cp424', + 'ibm437': 'cp437', + 'ibm500': 'cp500', + 'ibm775': 'cp775', + 'ibm819': 'windows-1252', + 'ibm850': 'cp850', + 'ibm852': 'cp852', + 'ibm855': 'cp855', + 'ibm857': 'cp857', + 'ibm860': 'cp860', + 'ibm861': 'cp861', + 'ibm862': 'cp862', + 'ibm863': 'cp863', + 'ibm864': 'cp864', + 'ibm865': 'cp865', + 'ibm866': 'cp866', + 'ibm869': 'cp869', + 'iso2022jp': 'iso2022_jp', + 'iso2022jp2': 'iso2022_jp_2', + 'iso2022kr': 'iso2022_kr', + 'iso646irv1991': 'ascii', + 'iso646us': 'ascii', + 'iso88591': 'windows-1252', + 'iso885910': 'iso8859-10', + 'iso8859101992': 'iso8859-10', + 'iso885911987': 'windows-1252', + 'iso885913': 'iso8859-13', + 'iso885914': 'iso8859-14', + 'iso8859141998': 'iso8859-14', + 'iso885915': 'iso8859-15', + 'iso885916': 'iso8859-16', + 'iso8859162001': 'iso8859-16', + 'iso88592': 'iso8859-2', + 'iso885921987': 'iso8859-2', + 'iso88593': 'iso8859-3', + 'iso885931988': 'iso8859-3', + 'iso88594': 'iso8859-4', + 'iso885941988': 'iso8859-4', + 'iso88595': 'iso8859-5', + 'iso885951988': 'iso8859-5', + 'iso88596': 'iso8859-6', + 'iso885961987': 'iso8859-6', + 'iso88597': 'iso8859-7', + 'iso885971987': 'iso8859-7', + 'iso88598': 'iso8859-8', + 'iso885981988': 'iso8859-8', + 'iso88599': 'windows-1254', + 'iso885991989': 'windows-1254', + 'isoceltic': 'iso8859-14', + 'isoir100': 'windows-1252', + 'isoir101': 'iso8859-2', + 'isoir109': 'iso8859-3', + 'isoir110': 'iso8859-4', + 'isoir126': 'iso8859-7', + 'isoir127': 'iso8859-6', + 'isoir138': 'iso8859-8', + 'isoir144': 'iso8859-5', + 'isoir148': 'windows-1254', + 'isoir149': 'cp949', + 'isoir157': 'iso8859-10', + 'isoir199': 'iso8859-14', + 'isoir226': 'iso8859-16', + 'isoir58': 'gbk', + 'isoir6': 'ascii', + 'koi8r': 'koi8-r', + 'koi8u': 'koi8-u', + 'korean': 'cp949', + 'ksc5601': 'cp949', + 'ksc56011987': 'cp949', + 'ksc56011989': 'cp949', + 'l1': 'windows-1252', + 'l10': 'iso8859-16', + 'l2': 'iso8859-2', + 'l3': 'iso8859-3', + 'l4': 'iso8859-4', + 'l5': 'windows-1254', + 'l6': 'iso8859-10', + 'l8': 'iso8859-14', + 'latin1': 'windows-1252', + 'latin10': 'iso8859-16', + 'latin2': 'iso8859-2', + 'latin3': 'iso8859-3', + 'latin4': 'iso8859-4', + 'latin5': 'windows-1254', + 'latin6': 'iso8859-10', + 'latin8': 'iso8859-14', + 'latin9': 'iso8859-15', + 'ms936': 'gbk', + 'mskanji': 'shift_jis', + 'pt154': 'ptcp154', + 'ptcp154': 'ptcp154', + 'r8': 'hp-roman8', + 'roman8': 'hp-roman8', + 'shiftjis': 'shift_jis', + 'tis620': 'cp874', + 'unicode11utf7': 'utf-7', + 'us': 'ascii', + 'usascii': 'ascii', + 'utf16': 'utf-16', + 'utf16be': 'utf-16-be', + 'utf16le': 'utf-16-le', + 'utf8': 'utf-8', + 'windows1250': 'cp1250', + 'windows1251': 'cp1251', + 'windows1252': 'cp1252', + 'windows1253': 'cp1253', + 'windows1254': 'cp1254', + 'windows1255': 'cp1255', + 'windows1256': 'cp1256', + 'windows1257': 'cp1257', + 'windows1258': 'cp1258', + 'windows936': 'gbk', + 'x-x-big5': 'big5'} + +tokenTypes = { + "Doctype": 0, + "Characters": 1, + "SpaceCharacters": 2, + "StartTag": 3, + "EndTag": 4, + "EmptyTag": 5, + "Comment": 6, + "ParseError": 7 +} + +tagTokenTypes = frozenset([tokenTypes["StartTag"], tokenTypes["EndTag"], + tokenTypes["EmptyTag"]]) + + +prefixes = dict([(v, k) for k, v in namespaces.items()]) +prefixes["http://www.w3.org/1998/Math/MathML"] = "math" + + +class DataLossWarning(UserWarning): + pass + + +class ReparseException(Exception): + pass diff --git a/lib/html5lib/filters/__init__.py b/lib/html5lib/filters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/html5lib/filters/_base.py b/lib/html5lib/filters/_base.py new file mode 100644 index 00000000..c7dbaed0 --- /dev/null +++ b/lib/html5lib/filters/_base.py @@ -0,0 +1,12 @@ +from __future__ import absolute_import, division, unicode_literals + + +class Filter(object): + def __init__(self, source): + self.source = source + + def __iter__(self): + return iter(self.source) + + def __getattr__(self, name): + return getattr(self.source, name) diff --git a/lib/html5lib/filters/alphabeticalattributes.py b/lib/html5lib/filters/alphabeticalattributes.py new file mode 100644 index 00000000..fed6996c --- /dev/null +++ b/lib/html5lib/filters/alphabeticalattributes.py @@ -0,0 +1,20 @@ +from __future__ import absolute_import, division, unicode_literals + +from . import _base + +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict + + +class Filter(_base.Filter): + def __iter__(self): + for token in _base.Filter.__iter__(self): + if token["type"] in ("StartTag", "EmptyTag"): + attrs = OrderedDict() + for name, value in sorted(token["data"].items(), + key=lambda x: x[0]): + attrs[name] = value + token["data"] = attrs + yield token diff --git a/lib/html5lib/filters/inject_meta_charset.py b/lib/html5lib/filters/inject_meta_charset.py new file mode 100644 index 00000000..ca33b70b --- /dev/null +++ b/lib/html5lib/filters/inject_meta_charset.py @@ -0,0 +1,65 @@ +from __future__ import absolute_import, division, unicode_literals + +from . import _base + + +class Filter(_base.Filter): + def __init__(self, source, encoding): + _base.Filter.__init__(self, source) + self.encoding = encoding + + def __iter__(self): + state = "pre_head" + meta_found = (self.encoding is None) + pending = [] + + for token in _base.Filter.__iter__(self): + type = token["type"] + if type == "StartTag": + if token["name"].lower() == "head": + state = "in_head" + + elif type == "EmptyTag": + if token["name"].lower() == "meta": + # replace charset with actual encoding + has_http_equiv_content_type = False + for (namespace, name), value in token["data"].items(): + if namespace is not None: + continue + elif name.lower() == 'charset': + token["data"][(namespace, name)] = self.encoding + meta_found = True + break + elif name == 'http-equiv' and value.lower() == 'content-type': + has_http_equiv_content_type = True + else: + if has_http_equiv_content_type and (None, "content") in token["data"]: + token["data"][(None, "content")] = 'text/html; charset=%s' % self.encoding + meta_found = True + + elif token["name"].lower() == "head" and not meta_found: + # insert meta into empty head + yield {"type": "StartTag", "name": "head", + "data": token["data"]} + yield {"type": "EmptyTag", "name": "meta", + "data": {(None, "charset"): self.encoding}} + yield {"type": "EndTag", "name": "head"} + meta_found = True + continue + + elif type == "EndTag": + if token["name"].lower() == "head" and pending: + # insert meta into head (if necessary) and flush pending queue + yield pending.pop(0) + if not meta_found: + yield {"type": "EmptyTag", "name": "meta", + "data": {(None, "charset"): self.encoding}} + while pending: + yield pending.pop(0) + meta_found = True + state = "post_head" + + if state == "in_head": + pending.append(token) + else: + yield token diff --git a/lib/html5lib/filters/lint.py b/lib/html5lib/filters/lint.py new file mode 100644 index 00000000..8884696d --- /dev/null +++ b/lib/html5lib/filters/lint.py @@ -0,0 +1,90 @@ +from __future__ import absolute_import, division, unicode_literals + +from . import _base +from ..constants import cdataElements, rcdataElements, voidElements + +from ..constants import spaceCharacters +spaceCharacters = "".join(spaceCharacters) + + +class LintError(Exception): + pass + + +class Filter(_base.Filter): + def __iter__(self): + open_elements = [] + contentModelFlag = "PCDATA" + for token in _base.Filter.__iter__(self): + type = token["type"] + if type in ("StartTag", "EmptyTag"): + name = token["name"] + if contentModelFlag != "PCDATA": + raise LintError("StartTag not in PCDATA content model flag: %(tag)s" % {"tag": name}) + if not isinstance(name, str): + raise LintError("Tag name is not a string: %(tag)r" % {"tag": name}) + if not name: + raise LintError("Empty tag name") + if type == "StartTag" and name in voidElements: + raise LintError("Void element reported as StartTag token: %(tag)s" % {"tag": name}) + elif type == "EmptyTag" and name not in voidElements: + raise LintError("Non-void element reported as EmptyTag token: %(tag)s" % {"tag": token["name"]}) + if type == "StartTag": + open_elements.append(name) + for name, value in token["data"]: + if not isinstance(name, str): + raise LintError("Attribute name is not a string: %(name)r" % {"name": name}) + if not name: + raise LintError("Empty attribute name") + if not isinstance(value, str): + raise LintError("Attribute value is not a string: %(value)r" % {"value": value}) + if name in cdataElements: + contentModelFlag = "CDATA" + elif name in rcdataElements: + contentModelFlag = "RCDATA" + elif name == "plaintext": + contentModelFlag = "PLAINTEXT" + + elif type == "EndTag": + name = token["name"] + if not isinstance(name, str): + raise LintError("Tag name is not a string: %(tag)r" % {"tag": name}) + if not name: + raise LintError("Empty tag name") + if name in voidElements: + raise LintError("Void element reported as EndTag token: %(tag)s" % {"tag": name}) + start_name = open_elements.pop() + if start_name != name: + raise LintError("EndTag (%(end)s) does not match StartTag (%(start)s)" % {"end": name, "start": start_name}) + contentModelFlag = "PCDATA" + + elif type == "Comment": + if contentModelFlag != "PCDATA": + raise LintError("Comment not in PCDATA content model flag") + + elif type in ("Characters", "SpaceCharacters"): + data = token["data"] + if not isinstance(data, str): + raise LintError("Attribute name is not a string: %(name)r" % {"name": data}) + if not data: + raise LintError("%(type)s token with empty data" % {"type": type}) + if type == "SpaceCharacters": + data = data.strip(spaceCharacters) + if data: + raise LintError("Non-space character(s) found in SpaceCharacters token: %(token)r" % {"token": data}) + + elif type == "Doctype": + name = token["name"] + if contentModelFlag != "PCDATA": + raise LintError("Doctype not in PCDATA content model flag: %(name)s" % {"name": name}) + if not isinstance(name, str): + raise LintError("Tag name is not a string: %(tag)r" % {"tag": name}) + # XXX: what to do with token["data"] ? + + elif type in ("ParseError", "SerializeError"): + pass + + else: + raise LintError("Unknown token type: %(type)s" % {"type": type}) + + yield token diff --git a/lib/html5lib/filters/optionaltags.py b/lib/html5lib/filters/optionaltags.py new file mode 100644 index 00000000..fefe0b30 --- /dev/null +++ b/lib/html5lib/filters/optionaltags.py @@ -0,0 +1,205 @@ +from __future__ import absolute_import, division, unicode_literals + +from . import _base + + +class Filter(_base.Filter): + def slider(self): + previous1 = previous2 = None + for token in self.source: + if previous1 is not None: + yield previous2, previous1, token + previous2 = previous1 + previous1 = token + yield previous2, previous1, None + + def __iter__(self): + for previous, token, next in self.slider(): + type = token["type"] + if type == "StartTag": + if (token["data"] or + not self.is_optional_start(token["name"], previous, next)): + yield token + elif type == "EndTag": + if not self.is_optional_end(token["name"], next): + yield token + else: + yield token + + def is_optional_start(self, tagname, previous, next): + type = next and next["type"] or None + if tagname in 'html': + # An html element's start tag may be omitted if the first thing + # inside the html element is not a space character or a comment. + return type not in ("Comment", "SpaceCharacters") + elif tagname == 'head': + # A head element's start tag may be omitted if the first thing + # inside the head element is an element. + # XXX: we also omit the start tag if the head element is empty + if type in ("StartTag", "EmptyTag"): + return True + elif type == "EndTag": + return next["name"] == "head" + elif tagname == 'body': + # A body element's start tag may be omitted if the first thing + # inside the body element is not a space character or a comment, + # except if the first thing inside the body element is a script + # or style element and the node immediately preceding the body + # element is a head element whose end tag has been omitted. + if type in ("Comment", "SpaceCharacters"): + return False + elif type == "StartTag": + # XXX: we do not look at the preceding event, so we never omit + # the body element's start tag if it's followed by a script or + # a style element. + return next["name"] not in ('script', 'style') + else: + return True + elif tagname == 'colgroup': + # A colgroup element's start tag may be omitted if the first thing + # inside the colgroup element is a col element, and if the element + # is not immediately preceeded by another colgroup element whose + # end tag has been omitted. + if type in ("StartTag", "EmptyTag"): + # XXX: we do not look at the preceding event, so instead we never + # omit the colgroup element's end tag when it is immediately + # followed by another colgroup element. See is_optional_end. + return next["name"] == "col" + else: + return False + elif tagname == 'tbody': + # A tbody element's start tag may be omitted if the first thing + # inside the tbody element is a tr element, and if the element is + # not immediately preceeded by a tbody, thead, or tfoot element + # whose end tag has been omitted. + if type == "StartTag": + # omit the thead and tfoot elements' end tag when they are + # immediately followed by a tbody element. See is_optional_end. + if previous and previous['type'] == 'EndTag' and \ + previous['name'] in ('tbody', 'thead', 'tfoot'): + return False + return next["name"] == 'tr' + else: + return False + return False + + def is_optional_end(self, tagname, next): + type = next and next["type"] or None + if tagname in ('html', 'head', 'body'): + # An html element's end tag may be omitted if the html element + # is not immediately followed by a space character or a comment. + return type not in ("Comment", "SpaceCharacters") + elif tagname in ('li', 'optgroup', 'tr'): + # A li element's end tag may be omitted if the li element is + # immediately followed by another li element or if there is + # no more content in the parent element. + # An optgroup element's end tag may be omitted if the optgroup + # element is immediately followed by another optgroup element, + # or if there is no more content in the parent element. + # A tr element's end tag may be omitted if the tr element is + # immediately followed by another tr element, or if there is + # no more content in the parent element. + if type == "StartTag": + return next["name"] == tagname + else: + return type == "EndTag" or type is None + elif tagname in ('dt', 'dd'): + # A dt element's end tag may be omitted if the dt element is + # immediately followed by another dt element or a dd element. + # A dd element's end tag may be omitted if the dd element is + # immediately followed by another dd element or a dt element, + # or if there is no more content in the parent element. + if type == "StartTag": + return next["name"] in ('dt', 'dd') + elif tagname == 'dd': + return type == "EndTag" or type is None + else: + return False + elif tagname == 'p': + # A p element's end tag may be omitted if the p element is + # immediately followed by an address, article, aside, + # blockquote, datagrid, dialog, dir, div, dl, fieldset, + # footer, form, h1, h2, h3, h4, h5, h6, header, hr, menu, + # nav, ol, p, pre, section, table, or ul, element, or if + # there is no more content in the parent element. + if type in ("StartTag", "EmptyTag"): + return next["name"] in ('address', 'article', 'aside', + 'blockquote', 'datagrid', 'dialog', + 'dir', 'div', 'dl', 'fieldset', 'footer', + 'form', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', + 'header', 'hr', 'menu', 'nav', 'ol', + 'p', 'pre', 'section', 'table', 'ul') + else: + return type == "EndTag" or type is None + elif tagname == 'option': + # An option element's end tag may be omitted if the option + # element is immediately followed by another option element, + # or if it is immediately followed by an optgroup + # element, or if there is no more content in the parent + # element. + if type == "StartTag": + return next["name"] in ('option', 'optgroup') + else: + return type == "EndTag" or type is None + elif tagname in ('rt', 'rp'): + # An rt element's end tag may be omitted if the rt element is + # immediately followed by an rt or rp element, or if there is + # no more content in the parent element. + # An rp element's end tag may be omitted if the rp element is + # immediately followed by an rt or rp element, or if there is + # no more content in the parent element. + if type == "StartTag": + return next["name"] in ('rt', 'rp') + else: + return type == "EndTag" or type is None + elif tagname == 'colgroup': + # A colgroup element's end tag may be omitted if the colgroup + # element is not immediately followed by a space character or + # a comment. + if type in ("Comment", "SpaceCharacters"): + return False + elif type == "StartTag": + # XXX: we also look for an immediately following colgroup + # element. See is_optional_start. + return next["name"] != 'colgroup' + else: + return True + elif tagname in ('thead', 'tbody'): + # A thead element's end tag may be omitted if the thead element + # is immediately followed by a tbody or tfoot element. + # A tbody element's end tag may be omitted if the tbody element + # is immediately followed by a tbody or tfoot element, or if + # there is no more content in the parent element. + # A tfoot element's end tag may be omitted if the tfoot element + # is immediately followed by a tbody element, or if there is no + # more content in the parent element. + # XXX: we never omit the end tag when the following element is + # a tbody. See is_optional_start. + if type == "StartTag": + return next["name"] in ['tbody', 'tfoot'] + elif tagname == 'tbody': + return type == "EndTag" or type is None + else: + return False + elif tagname == 'tfoot': + # A tfoot element's end tag may be omitted if the tfoot element + # is immediately followed by a tbody element, or if there is no + # more content in the parent element. + # XXX: we never omit the end tag when the following element is + # a tbody. See is_optional_start. + if type == "StartTag": + return next["name"] == 'tbody' + else: + return type == "EndTag" or type is None + elif tagname in ('td', 'th'): + # A td element's end tag may be omitted if the td element is + # immediately followed by a td or th element, or if there is + # no more content in the parent element. + # A th element's end tag may be omitted if the th element is + # immediately followed by a td or th element, or if there is + # no more content in the parent element. + if type == "StartTag": + return next["name"] in ('td', 'th') + else: + return type == "EndTag" or type is None + return False diff --git a/lib/html5lib/filters/sanitizer.py b/lib/html5lib/filters/sanitizer.py new file mode 100644 index 00000000..b206b54e --- /dev/null +++ b/lib/html5lib/filters/sanitizer.py @@ -0,0 +1,12 @@ +from __future__ import absolute_import, division, unicode_literals + +from . import _base +from ..sanitizer import HTMLSanitizerMixin + + +class Filter(_base.Filter, HTMLSanitizerMixin): + def __iter__(self): + for token in _base.Filter.__iter__(self): + token = self.sanitize_token(token) + if token: + yield token diff --git a/lib/html5lib/filters/whitespace.py b/lib/html5lib/filters/whitespace.py new file mode 100644 index 00000000..dfc60eeb --- /dev/null +++ b/lib/html5lib/filters/whitespace.py @@ -0,0 +1,38 @@ +from __future__ import absolute_import, division, unicode_literals + +import re + +from . import _base +from ..constants import rcdataElements, spaceCharacters +spaceCharacters = "".join(spaceCharacters) + +SPACES_REGEX = re.compile("[%s]+" % spaceCharacters) + + +class Filter(_base.Filter): + + spacePreserveElements = frozenset(["pre", "textarea"] + list(rcdataElements)) + + def __iter__(self): + preserve = 0 + for token in _base.Filter.__iter__(self): + type = token["type"] + if type == "StartTag" \ + and (preserve or token["name"] in self.spacePreserveElements): + preserve += 1 + + elif type == "EndTag" and preserve: + preserve -= 1 + + elif not preserve and type == "SpaceCharacters" and token["data"]: + # Test on token["data"] above to not introduce spaces where there were not + token["data"] = " " + + elif not preserve and type == "Characters": + token["data"] = collapse_spaces(token["data"]) + + yield token + + +def collapse_spaces(text): + return SPACES_REGEX.sub(' ', text) diff --git a/lib/html5lib/html5parser.py b/lib/html5lib/html5parser.py new file mode 100644 index 00000000..12aa6a35 --- /dev/null +++ b/lib/html5lib/html5parser.py @@ -0,0 +1,2724 @@ +from __future__ import absolute_import, division, unicode_literals +from six import with_metaclass + +import types + +from . import inputstream +from . import tokenizer + +from . import treebuilders +from .treebuilders._base import Marker + +from . import utils +from . import constants +from .constants import spaceCharacters, asciiUpper2Lower +from .constants import specialElements +from .constants import headingElements +from .constants import cdataElements, rcdataElements +from .constants import tokenTypes, ReparseException, namespaces +from .constants import htmlIntegrationPointElements, mathmlTextIntegrationPointElements +from .constants import adjustForeignAttributes as adjustForeignAttributesMap +from .constants import E + + +def parse(doc, treebuilder="etree", encoding=None, + namespaceHTMLElements=True): + """Parse a string or file-like object into a tree""" + tb = treebuilders.getTreeBuilder(treebuilder) + p = HTMLParser(tb, namespaceHTMLElements=namespaceHTMLElements) + return p.parse(doc, encoding=encoding) + + +def parseFragment(doc, container="div", treebuilder="etree", encoding=None, + namespaceHTMLElements=True): + tb = treebuilders.getTreeBuilder(treebuilder) + p = HTMLParser(tb, namespaceHTMLElements=namespaceHTMLElements) + return p.parseFragment(doc, container=container, encoding=encoding) + + +def method_decorator_metaclass(function): + class Decorated(type): + def __new__(meta, classname, bases, classDict): + for attributeName, attribute in classDict.items(): + if isinstance(attribute, types.FunctionType): + attribute = function(attribute) + + classDict[attributeName] = attribute + return type.__new__(meta, classname, bases, classDict) + return Decorated + + +class HTMLParser(object): + """HTML parser. Generates a tree structure from a stream of (possibly + malformed) HTML""" + + def __init__(self, tree=None, tokenizer=tokenizer.HTMLTokenizer, + strict=False, namespaceHTMLElements=True, debug=False): + """ + strict - raise an exception when a parse error is encountered + + tree - a treebuilder class controlling the type of tree that will be + returned. Built in treebuilders can be accessed through + html5lib.treebuilders.getTreeBuilder(treeType) + + tokenizer - a class that provides a stream of tokens to the treebuilder. + This may be replaced for e.g. a sanitizer which converts some tags to + text + """ + + # Raise an exception on the first error encountered + self.strict = strict + + if tree is None: + tree = treebuilders.getTreeBuilder("etree") + self.tree = tree(namespaceHTMLElements) + self.tokenizer_class = tokenizer + self.errors = [] + + self.phases = dict([(name, cls(self, self.tree)) for name, cls in + getPhases(debug).items()]) + + def _parse(self, stream, innerHTML=False, container="div", + encoding=None, parseMeta=True, useChardet=True, **kwargs): + + self.innerHTMLMode = innerHTML + self.container = container + self.tokenizer = self.tokenizer_class(stream, encoding=encoding, + parseMeta=parseMeta, + useChardet=useChardet, + parser=self, **kwargs) + self.reset() + + while True: + try: + self.mainLoop() + break + except ReparseException: + self.reset() + + def reset(self): + self.tree.reset() + self.firstStartTag = False + self.errors = [] + self.log = [] # only used with debug mode + # "quirks" / "limited quirks" / "no quirks" + self.compatMode = "no quirks" + + if self.innerHTMLMode: + self.innerHTML = self.container.lower() + + if self.innerHTML in cdataElements: + self.tokenizer.state = self.tokenizer.rcdataState + elif self.innerHTML in rcdataElements: + self.tokenizer.state = self.tokenizer.rawtextState + elif self.innerHTML == 'plaintext': + self.tokenizer.state = self.tokenizer.plaintextState + else: + # state already is data state + # self.tokenizer.state = self.tokenizer.dataState + pass + self.phase = self.phases["beforeHtml"] + self.phase.insertHtmlElement() + self.resetInsertionMode() + else: + self.innerHTML = False + self.phase = self.phases["initial"] + + self.lastPhase = None + + self.beforeRCDataPhase = None + + self.framesetOK = True + + @property + def documentEncoding(self): + """The name of the character encoding + that was used to decode the input stream, + or :obj:`None` if that is not determined yet. + + """ + if not hasattr(self, 'tokenizer'): + return None + return self.tokenizer.stream.charEncoding[0] + + def isHTMLIntegrationPoint(self, element): + if (element.name == "annotation-xml" and + element.namespace == namespaces["mathml"]): + return ("encoding" in element.attributes and + element.attributes["encoding"].translate( + asciiUpper2Lower) in + ("text/html", "application/xhtml+xml")) + else: + return (element.namespace, element.name) in htmlIntegrationPointElements + + def isMathMLTextIntegrationPoint(self, element): + return (element.namespace, element.name) in mathmlTextIntegrationPointElements + + def mainLoop(self): + CharactersToken = tokenTypes["Characters"] + SpaceCharactersToken = tokenTypes["SpaceCharacters"] + StartTagToken = tokenTypes["StartTag"] + EndTagToken = tokenTypes["EndTag"] + CommentToken = tokenTypes["Comment"] + DoctypeToken = tokenTypes["Doctype"] + ParseErrorToken = tokenTypes["ParseError"] + + for token in self.normalizedTokens(): + new_token = token + while new_token is not None: + currentNode = self.tree.openElements[-1] if self.tree.openElements else None + currentNodeNamespace = currentNode.namespace if currentNode else None + currentNodeName = currentNode.name if currentNode else None + + type = new_token["type"] + + if type == ParseErrorToken: + self.parseError(new_token["data"], new_token.get("datavars", {})) + new_token = None + else: + if (len(self.tree.openElements) == 0 or + currentNodeNamespace == self.tree.defaultNamespace or + (self.isMathMLTextIntegrationPoint(currentNode) and + ((type == StartTagToken and + token["name"] not in frozenset(["mglyph", "malignmark"])) or + type in (CharactersToken, SpaceCharactersToken))) or + (currentNodeNamespace == namespaces["mathml"] and + currentNodeName == "annotation-xml" and + token["name"] == "svg") or + (self.isHTMLIntegrationPoint(currentNode) and + type in (StartTagToken, CharactersToken, SpaceCharactersToken))): + phase = self.phase + else: + phase = self.phases["inForeignContent"] + + if type == CharactersToken: + new_token = phase.processCharacters(new_token) + elif type == SpaceCharactersToken: + new_token = phase.processSpaceCharacters(new_token) + elif type == StartTagToken: + new_token = phase.processStartTag(new_token) + elif type == EndTagToken: + new_token = phase.processEndTag(new_token) + elif type == CommentToken: + new_token = phase.processComment(new_token) + elif type == DoctypeToken: + new_token = phase.processDoctype(new_token) + + if (type == StartTagToken and token["selfClosing"] + and not token["selfClosingAcknowledged"]): + self.parseError("non-void-element-with-trailing-solidus", + {"name": token["name"]}) + + # When the loop finishes it's EOF + reprocess = True + phases = [] + while reprocess: + phases.append(self.phase) + reprocess = self.phase.processEOF() + if reprocess: + assert self.phase not in phases + + def normalizedTokens(self): + for token in self.tokenizer: + yield self.normalizeToken(token) + + def parse(self, stream, encoding=None, parseMeta=True, useChardet=True): + """Parse a HTML document into a well-formed tree + + stream - a filelike object or string containing the HTML to be parsed + + The optional encoding parameter must be a string that indicates + the encoding. If specified, that encoding will be used, + regardless of any BOM or later declaration (such as in a meta + element) + """ + self._parse(stream, innerHTML=False, encoding=encoding, + parseMeta=parseMeta, useChardet=useChardet) + return self.tree.getDocument() + + def parseFragment(self, stream, container="div", encoding=None, + parseMeta=False, useChardet=True): + """Parse a HTML fragment into a well-formed tree fragment + + container - name of the element we're setting the innerHTML property + if set to None, default to 'div' + + stream - a filelike object or string containing the HTML to be parsed + + The optional encoding parameter must be a string that indicates + the encoding. If specified, that encoding will be used, + regardless of any BOM or later declaration (such as in a meta + element) + """ + self._parse(stream, True, container=container, encoding=encoding) + return self.tree.getFragment() + + def parseError(self, errorcode="XXX-undefined-error", datavars={}): + # XXX The idea is to make errorcode mandatory. + self.errors.append((self.tokenizer.stream.position(), errorcode, datavars)) + if self.strict: + raise ParseError(E[errorcode] % datavars) + + def normalizeToken(self, token): + """ HTML5 specific normalizations to the token stream """ + + if token["type"] == tokenTypes["StartTag"]: + token["data"] = dict(token["data"][::-1]) + + return token + + def adjustMathMLAttributes(self, token): + replacements = {"definitionurl": "definitionURL"} + for k, v in replacements.items(): + if k in token["data"]: + token["data"][v] = token["data"][k] + del token["data"][k] + + def adjustSVGAttributes(self, token): + replacements = { + "attributename": "attributeName", + "attributetype": "attributeType", + "basefrequency": "baseFrequency", + "baseprofile": "baseProfile", + "calcmode": "calcMode", + "clippathunits": "clipPathUnits", + "contentscripttype": "contentScriptType", + "contentstyletype": "contentStyleType", + "diffuseconstant": "diffuseConstant", + "edgemode": "edgeMode", + "externalresourcesrequired": "externalResourcesRequired", + "filterres": "filterRes", + "filterunits": "filterUnits", + "glyphref": "glyphRef", + "gradienttransform": "gradientTransform", + "gradientunits": "gradientUnits", + "kernelmatrix": "kernelMatrix", + "kernelunitlength": "kernelUnitLength", + "keypoints": "keyPoints", + "keysplines": "keySplines", + "keytimes": "keyTimes", + "lengthadjust": "lengthAdjust", + "limitingconeangle": "limitingConeAngle", + "markerheight": "markerHeight", + "markerunits": "markerUnits", + "markerwidth": "markerWidth", + "maskcontentunits": "maskContentUnits", + "maskunits": "maskUnits", + "numoctaves": "numOctaves", + "pathlength": "pathLength", + "patterncontentunits": "patternContentUnits", + "patterntransform": "patternTransform", + "patternunits": "patternUnits", + "pointsatx": "pointsAtX", + "pointsaty": "pointsAtY", + "pointsatz": "pointsAtZ", + "preservealpha": "preserveAlpha", + "preserveaspectratio": "preserveAspectRatio", + "primitiveunits": "primitiveUnits", + "refx": "refX", + "refy": "refY", + "repeatcount": "repeatCount", + "repeatdur": "repeatDur", + "requiredextensions": "requiredExtensions", + "requiredfeatures": "requiredFeatures", + "specularconstant": "specularConstant", + "specularexponent": "specularExponent", + "spreadmethod": "spreadMethod", + "startoffset": "startOffset", + "stddeviation": "stdDeviation", + "stitchtiles": "stitchTiles", + "surfacescale": "surfaceScale", + "systemlanguage": "systemLanguage", + "tablevalues": "tableValues", + "targetx": "targetX", + "targety": "targetY", + "textlength": "textLength", + "viewbox": "viewBox", + "viewtarget": "viewTarget", + "xchannelselector": "xChannelSelector", + "ychannelselector": "yChannelSelector", + "zoomandpan": "zoomAndPan" + } + for originalName in list(token["data"].keys()): + if originalName in replacements: + svgName = replacements[originalName] + token["data"][svgName] = token["data"][originalName] + del token["data"][originalName] + + def adjustForeignAttributes(self, token): + replacements = adjustForeignAttributesMap + + for originalName in token["data"].keys(): + if originalName in replacements: + foreignName = replacements[originalName] + token["data"][foreignName] = token["data"][originalName] + del token["data"][originalName] + + def reparseTokenNormal(self, token): + self.parser.phase() + + def resetInsertionMode(self): + # The name of this method is mostly historical. (It's also used in the + # specification.) + last = False + newModes = { + "select": "inSelect", + "td": "inCell", + "th": "inCell", + "tr": "inRow", + "tbody": "inTableBody", + "thead": "inTableBody", + "tfoot": "inTableBody", + "caption": "inCaption", + "colgroup": "inColumnGroup", + "table": "inTable", + "head": "inBody", + "body": "inBody", + "frameset": "inFrameset", + "html": "beforeHead" + } + for node in self.tree.openElements[::-1]: + nodeName = node.name + new_phase = None + if node == self.tree.openElements[0]: + assert self.innerHTML + last = True + nodeName = self.innerHTML + # Check for conditions that should only happen in the innerHTML + # case + if nodeName in ("select", "colgroup", "head", "html"): + assert self.innerHTML + + if not last and node.namespace != self.tree.defaultNamespace: + continue + + if nodeName in newModes: + new_phase = self.phases[newModes[nodeName]] + break + elif last: + new_phase = self.phases["inBody"] + break + + self.phase = new_phase + + def parseRCDataRawtext(self, token, contentType): + """Generic RCDATA/RAWTEXT Parsing algorithm + contentType - RCDATA or RAWTEXT + """ + assert contentType in ("RAWTEXT", "RCDATA") + + self.tree.insertElement(token) + + if contentType == "RAWTEXT": + self.tokenizer.state = self.tokenizer.rawtextState + else: + self.tokenizer.state = self.tokenizer.rcdataState + + self.originalPhase = self.phase + + self.phase = self.phases["text"] + + +def getPhases(debug): + def log(function): + """Logger that records which phase processes each token""" + type_names = dict((value, key) for key, value in + constants.tokenTypes.items()) + + def wrapped(self, *args, **kwargs): + if function.__name__.startswith("process") and len(args) > 0: + token = args[0] + try: + info = {"type": type_names[token['type']]} + except: + raise + if token['type'] in constants.tagTokenTypes: + info["name"] = token['name'] + + self.parser.log.append((self.parser.tokenizer.state.__name__, + self.parser.phase.__class__.__name__, + self.__class__.__name__, + function.__name__, + info)) + return function(self, *args, **kwargs) + else: + return function(self, *args, **kwargs) + return wrapped + + def getMetaclass(use_metaclass, metaclass_func): + if use_metaclass: + return method_decorator_metaclass(metaclass_func) + else: + return type + + class Phase(with_metaclass(getMetaclass(debug, log))): + """Base class for helper object that implements each phase of processing + """ + + def __init__(self, parser, tree): + self.parser = parser + self.tree = tree + + def processEOF(self): + raise NotImplementedError + + def processComment(self, token): + # For most phases the following is correct. Where it's not it will be + # overridden. + self.tree.insertComment(token, self.tree.openElements[-1]) + + def processDoctype(self, token): + self.parser.parseError("unexpected-doctype") + + def processCharacters(self, token): + self.tree.insertText(token["data"]) + + def processSpaceCharacters(self, token): + self.tree.insertText(token["data"]) + + def processStartTag(self, token): + return self.startTagHandler[token["name"]](token) + + def startTagHtml(self, token): + if not self.parser.firstStartTag and token["name"] == "html": + self.parser.parseError("non-html-root") + # XXX Need a check here to see if the first start tag token emitted is + # this token... If it's not, invoke self.parser.parseError(). + for attr, value in token["data"].items(): + if attr not in self.tree.openElements[0].attributes: + self.tree.openElements[0].attributes[attr] = value + self.parser.firstStartTag = False + + def processEndTag(self, token): + return self.endTagHandler[token["name"]](token) + + class InitialPhase(Phase): + def processSpaceCharacters(self, token): + pass + + def processComment(self, token): + self.tree.insertComment(token, self.tree.document) + + def processDoctype(self, token): + name = token["name"] + publicId = token["publicId"] + systemId = token["systemId"] + correct = token["correct"] + + if (name != "html" or publicId is not None or + systemId is not None and systemId != "about:legacy-compat"): + self.parser.parseError("unknown-doctype") + + if publicId is None: + publicId = "" + + self.tree.insertDoctype(token) + + if publicId != "": + publicId = publicId.translate(asciiUpper2Lower) + + if (not correct or token["name"] != "html" + or publicId.startswith( + ("+//silmaril//dtd html pro v0r11 19970101//", + "-//advasoft ltd//dtd html 3.0 aswedit + extensions//", + "-//as//dtd html 3.0 aswedit + extensions//", + "-//ietf//dtd html 2.0 level 1//", + "-//ietf//dtd html 2.0 level 2//", + "-//ietf//dtd html 2.0 strict level 1//", + "-//ietf//dtd html 2.0 strict level 2//", + "-//ietf//dtd html 2.0 strict//", + "-//ietf//dtd html 2.0//", + "-//ietf//dtd html 2.1e//", + "-//ietf//dtd html 3.0//", + "-//ietf//dtd html 3.2 final//", + "-//ietf//dtd html 3.2//", + "-//ietf//dtd html 3//", + "-//ietf//dtd html level 0//", + "-//ietf//dtd html level 1//", + "-//ietf//dtd html level 2//", + "-//ietf//dtd html level 3//", + "-//ietf//dtd html strict level 0//", + "-//ietf//dtd html strict level 1//", + "-//ietf//dtd html strict level 2//", + "-//ietf//dtd html strict level 3//", + "-//ietf//dtd html strict//", + "-//ietf//dtd html//", + "-//metrius//dtd metrius presentational//", + "-//microsoft//dtd internet explorer 2.0 html strict//", + "-//microsoft//dtd internet explorer 2.0 html//", + "-//microsoft//dtd internet explorer 2.0 tables//", + "-//microsoft//dtd internet explorer 3.0 html strict//", + "-//microsoft//dtd internet explorer 3.0 html//", + "-//microsoft//dtd internet explorer 3.0 tables//", + "-//netscape comm. corp.//dtd html//", + "-//netscape comm. corp.//dtd strict html//", + "-//o'reilly and associates//dtd html 2.0//", + "-//o'reilly and associates//dtd html extended 1.0//", + "-//o'reilly and associates//dtd html extended relaxed 1.0//", + "-//softquad software//dtd hotmetal pro 6.0::19990601::extensions to html 4.0//", + "-//softquad//dtd hotmetal pro 4.0::19971010::extensions to html 4.0//", + "-//spyglass//dtd html 2.0 extended//", + "-//sq//dtd html 2.0 hotmetal + extensions//", + "-//sun microsystems corp.//dtd hotjava html//", + "-//sun microsystems corp.//dtd hotjava strict html//", + "-//w3c//dtd html 3 1995-03-24//", + "-//w3c//dtd html 3.2 draft//", + "-//w3c//dtd html 3.2 final//", + "-//w3c//dtd html 3.2//", + "-//w3c//dtd html 3.2s draft//", + "-//w3c//dtd html 4.0 frameset//", + "-//w3c//dtd html 4.0 transitional//", + "-//w3c//dtd html experimental 19960712//", + "-//w3c//dtd html experimental 970421//", + "-//w3c//dtd w3 html//", + "-//w3o//dtd w3 html 3.0//", + "-//webtechs//dtd mozilla html 2.0//", + "-//webtechs//dtd mozilla html//")) + or publicId in + ("-//w3o//dtd w3 html strict 3.0//en//", + "-/w3c/dtd html 4.0 transitional/en", + "html") + or publicId.startswith( + ("-//w3c//dtd html 4.01 frameset//", + "-//w3c//dtd html 4.01 transitional//")) and + systemId is None + or systemId and systemId.lower() == "http://www.ibm.com/data/dtd/v11/ibmxhtml1-transitional.dtd"): + self.parser.compatMode = "quirks" + elif (publicId.startswith( + ("-//w3c//dtd xhtml 1.0 frameset//", + "-//w3c//dtd xhtml 1.0 transitional//")) + or publicId.startswith( + ("-//w3c//dtd html 4.01 frameset//", + "-//w3c//dtd html 4.01 transitional//")) and + systemId is not None): + self.parser.compatMode = "limited quirks" + + self.parser.phase = self.parser.phases["beforeHtml"] + + def anythingElse(self): + self.parser.compatMode = "quirks" + self.parser.phase = self.parser.phases["beforeHtml"] + + def processCharacters(self, token): + self.parser.parseError("expected-doctype-but-got-chars") + self.anythingElse() + return token + + def processStartTag(self, token): + self.parser.parseError("expected-doctype-but-got-start-tag", + {"name": token["name"]}) + self.anythingElse() + return token + + def processEndTag(self, token): + self.parser.parseError("expected-doctype-but-got-end-tag", + {"name": token["name"]}) + self.anythingElse() + return token + + def processEOF(self): + self.parser.parseError("expected-doctype-but-got-eof") + self.anythingElse() + return True + + class BeforeHtmlPhase(Phase): + # helper methods + def insertHtmlElement(self): + self.tree.insertRoot(impliedTagToken("html", "StartTag")) + self.parser.phase = self.parser.phases["beforeHead"] + + # other + def processEOF(self): + self.insertHtmlElement() + return True + + def processComment(self, token): + self.tree.insertComment(token, self.tree.document) + + def processSpaceCharacters(self, token): + pass + + def processCharacters(self, token): + self.insertHtmlElement() + return token + + def processStartTag(self, token): + if token["name"] == "html": + self.parser.firstStartTag = True + self.insertHtmlElement() + return token + + def processEndTag(self, token): + if token["name"] not in ("head", "body", "html", "br"): + self.parser.parseError("unexpected-end-tag-before-html", + {"name": token["name"]}) + else: + self.insertHtmlElement() + return token + + class BeforeHeadPhase(Phase): + def __init__(self, parser, tree): + Phase.__init__(self, parser, tree) + + self.startTagHandler = utils.MethodDispatcher([ + ("html", self.startTagHtml), + ("head", self.startTagHead) + ]) + self.startTagHandler.default = self.startTagOther + + self.endTagHandler = utils.MethodDispatcher([ + (("head", "body", "html", "br"), self.endTagImplyHead) + ]) + self.endTagHandler.default = self.endTagOther + + def processEOF(self): + self.startTagHead(impliedTagToken("head", "StartTag")) + return True + + def processSpaceCharacters(self, token): + pass + + def processCharacters(self, token): + self.startTagHead(impliedTagToken("head", "StartTag")) + return token + + def startTagHtml(self, token): + return self.parser.phases["inBody"].processStartTag(token) + + def startTagHead(self, token): + self.tree.insertElement(token) + self.tree.headPointer = self.tree.openElements[-1] + self.parser.phase = self.parser.phases["inHead"] + + def startTagOther(self, token): + self.startTagHead(impliedTagToken("head", "StartTag")) + return token + + def endTagImplyHead(self, token): + self.startTagHead(impliedTagToken("head", "StartTag")) + return token + + def endTagOther(self, token): + self.parser.parseError("end-tag-after-implied-root", + {"name": token["name"]}) + + class InHeadPhase(Phase): + def __init__(self, parser, tree): + Phase.__init__(self, parser, tree) + + self.startTagHandler = utils.MethodDispatcher([ + ("html", self.startTagHtml), + ("title", self.startTagTitle), + (("noscript", "noframes", "style"), self.startTagNoScriptNoFramesStyle), + ("script", self.startTagScript), + (("base", "basefont", "bgsound", "command", "link"), + self.startTagBaseLinkCommand), + ("meta", self.startTagMeta), + ("head", self.startTagHead) + ]) + self.startTagHandler.default = self.startTagOther + + self. endTagHandler = utils.MethodDispatcher([ + ("head", self.endTagHead), + (("br", "html", "body"), self.endTagHtmlBodyBr) + ]) + self.endTagHandler.default = self.endTagOther + + # the real thing + def processEOF(self): + self.anythingElse() + return True + + def processCharacters(self, token): + self.anythingElse() + return token + + def startTagHtml(self, token): + return self.parser.phases["inBody"].processStartTag(token) + + def startTagHead(self, token): + self.parser.parseError("two-heads-are-not-better-than-one") + + def startTagBaseLinkCommand(self, token): + self.tree.insertElement(token) + self.tree.openElements.pop() + token["selfClosingAcknowledged"] = True + + def startTagMeta(self, token): + self.tree.insertElement(token) + self.tree.openElements.pop() + token["selfClosingAcknowledged"] = True + + attributes = token["data"] + if self.parser.tokenizer.stream.charEncoding[1] == "tentative": + if "charset" in attributes: + self.parser.tokenizer.stream.changeEncoding(attributes["charset"]) + elif ("content" in attributes and + "http-equiv" in attributes and + attributes["http-equiv"].lower() == "content-type"): + # Encoding it as UTF-8 here is a hack, as really we should pass + # the abstract Unicode string, and just use the + # ContentAttrParser on that, but using UTF-8 allows all chars + # to be encoded and as a ASCII-superset works. + data = inputstream.EncodingBytes(attributes["content"].encode("utf-8")) + parser = inputstream.ContentAttrParser(data) + codec = parser.parse() + self.parser.tokenizer.stream.changeEncoding(codec) + + def startTagTitle(self, token): + self.parser.parseRCDataRawtext(token, "RCDATA") + + def startTagNoScriptNoFramesStyle(self, token): + # Need to decide whether to implement the scripting-disabled case + self.parser.parseRCDataRawtext(token, "RAWTEXT") + + def startTagScript(self, token): + self.tree.insertElement(token) + self.parser.tokenizer.state = self.parser.tokenizer.scriptDataState + self.parser.originalPhase = self.parser.phase + self.parser.phase = self.parser.phases["text"] + + def startTagOther(self, token): + self.anythingElse() + return token + + def endTagHead(self, token): + node = self.parser.tree.openElements.pop() + assert node.name == "head", "Expected head got %s" % node.name + self.parser.phase = self.parser.phases["afterHead"] + + def endTagHtmlBodyBr(self, token): + self.anythingElse() + return token + + def endTagOther(self, token): + self.parser.parseError("unexpected-end-tag", {"name": token["name"]}) + + def anythingElse(self): + self.endTagHead(impliedTagToken("head")) + + # XXX If we implement a parser for which scripting is disabled we need to + # implement this phase. + # + # class InHeadNoScriptPhase(Phase): + class AfterHeadPhase(Phase): + def __init__(self, parser, tree): + Phase.__init__(self, parser, tree) + + self.startTagHandler = utils.MethodDispatcher([ + ("html", self.startTagHtml), + ("body", self.startTagBody), + ("frameset", self.startTagFrameset), + (("base", "basefont", "bgsound", "link", "meta", "noframes", "script", + "style", "title"), + self.startTagFromHead), + ("head", self.startTagHead) + ]) + self.startTagHandler.default = self.startTagOther + self.endTagHandler = utils.MethodDispatcher([(("body", "html", "br"), + self.endTagHtmlBodyBr)]) + self.endTagHandler.default = self.endTagOther + + def processEOF(self): + self.anythingElse() + return True + + def processCharacters(self, token): + self.anythingElse() + return token + + def startTagHtml(self, token): + return self.parser.phases["inBody"].processStartTag(token) + + def startTagBody(self, token): + self.parser.framesetOK = False + self.tree.insertElement(token) + self.parser.phase = self.parser.phases["inBody"] + + def startTagFrameset(self, token): + self.tree.insertElement(token) + self.parser.phase = self.parser.phases["inFrameset"] + + def startTagFromHead(self, token): + self.parser.parseError("unexpected-start-tag-out-of-my-head", + {"name": token["name"]}) + self.tree.openElements.append(self.tree.headPointer) + self.parser.phases["inHead"].processStartTag(token) + for node in self.tree.openElements[::-1]: + if node.name == "head": + self.tree.openElements.remove(node) + break + + def startTagHead(self, token): + self.parser.parseError("unexpected-start-tag", {"name": token["name"]}) + + def startTagOther(self, token): + self.anythingElse() + return token + + def endTagHtmlBodyBr(self, token): + self.anythingElse() + return token + + def endTagOther(self, token): + self.parser.parseError("unexpected-end-tag", {"name": token["name"]}) + + def anythingElse(self): + self.tree.insertElement(impliedTagToken("body", "StartTag")) + self.parser.phase = self.parser.phases["inBody"] + self.parser.framesetOK = True + + class InBodyPhase(Phase): + # http://www.whatwg.org/specs/web-apps/current-work/#parsing-main-inbody + # the really-really-really-very crazy mode + def __init__(self, parser, tree): + Phase.__init__(self, parser, tree) + + # Keep a ref to this for special handling of whitespace in
+            self.processSpaceCharactersNonPre = self.processSpaceCharacters
+
+            self.startTagHandler = utils.MethodDispatcher([
+                ("html", self.startTagHtml),
+                (("base", "basefont", "bgsound", "command", "link", "meta",
+                  "script", "style", "title"),
+                 self.startTagProcessInHead),
+                ("body", self.startTagBody),
+                ("frameset", self.startTagFrameset),
+                (("address", "article", "aside", "blockquote", "center", "details",
+                  "details", "dir", "div", "dl", "fieldset", "figcaption", "figure",
+                  "footer", "header", "hgroup", "main", "menu", "nav", "ol", "p",
+                  "section", "summary", "ul"),
+                 self.startTagCloseP),
+                (headingElements, self.startTagHeading),
+                (("pre", "listing"), self.startTagPreListing),
+                ("form", self.startTagForm),
+                (("li", "dd", "dt"), self.startTagListItem),
+                ("plaintext", self.startTagPlaintext),
+                ("a", self.startTagA),
+                (("b", "big", "code", "em", "font", "i", "s", "small", "strike",
+                  "strong", "tt", "u"), self.startTagFormatting),
+                ("nobr", self.startTagNobr),
+                ("button", self.startTagButton),
+                (("applet", "marquee", "object"), self.startTagAppletMarqueeObject),
+                ("xmp", self.startTagXmp),
+                ("table", self.startTagTable),
+                (("area", "br", "embed", "img", "keygen", "wbr"),
+                 self.startTagVoidFormatting),
+                (("param", "source", "track"), self.startTagParamSource),
+                ("input", self.startTagInput),
+                ("hr", self.startTagHr),
+                ("image", self.startTagImage),
+                ("isindex", self.startTagIsIndex),
+                ("textarea", self.startTagTextarea),
+                ("iframe", self.startTagIFrame),
+                (("noembed", "noframes", "noscript"), self.startTagRawtext),
+                ("select", self.startTagSelect),
+                (("rp", "rt"), self.startTagRpRt),
+                (("option", "optgroup"), self.startTagOpt),
+                (("math"), self.startTagMath),
+                (("svg"), self.startTagSvg),
+                (("caption", "col", "colgroup", "frame", "head",
+                  "tbody", "td", "tfoot", "th", "thead",
+                  "tr"), self.startTagMisplaced)
+            ])
+            self.startTagHandler.default = self.startTagOther
+
+            self.endTagHandler = utils.MethodDispatcher([
+                ("body", self.endTagBody),
+                ("html", self.endTagHtml),
+                (("address", "article", "aside", "blockquote", "button", "center",
+                  "details", "dialog", "dir", "div", "dl", "fieldset", "figcaption", "figure",
+                  "footer", "header", "hgroup", "listing", "main", "menu", "nav", "ol", "pre",
+                  "section", "summary", "ul"), self.endTagBlock),
+                ("form", self.endTagForm),
+                ("p", self.endTagP),
+                (("dd", "dt", "li"), self.endTagListItem),
+                (headingElements, self.endTagHeading),
+                (("a", "b", "big", "code", "em", "font", "i", "nobr", "s", "small",
+                  "strike", "strong", "tt", "u"), self.endTagFormatting),
+                (("applet", "marquee", "object"), self.endTagAppletMarqueeObject),
+                ("br", self.endTagBr),
+            ])
+            self.endTagHandler.default = self.endTagOther
+
+        def isMatchingFormattingElement(self, node1, node2):
+            if node1.name != node2.name or node1.namespace != node2.namespace:
+                return False
+            elif len(node1.attributes) != len(node2.attributes):
+                return False
+            else:
+                attributes1 = sorted(node1.attributes.items())
+                attributes2 = sorted(node2.attributes.items())
+                for attr1, attr2 in zip(attributes1, attributes2):
+                    if attr1 != attr2:
+                        return False
+            return True
+
+        # helper
+        def addFormattingElement(self, token):
+            self.tree.insertElement(token)
+            element = self.tree.openElements[-1]
+
+            matchingElements = []
+            for node in self.tree.activeFormattingElements[::-1]:
+                if node is Marker:
+                    break
+                elif self.isMatchingFormattingElement(node, element):
+                    matchingElements.append(node)
+
+            assert len(matchingElements) <= 3
+            if len(matchingElements) == 3:
+                self.tree.activeFormattingElements.remove(matchingElements[-1])
+            self.tree.activeFormattingElements.append(element)
+
+        # the real deal
+        def processEOF(self):
+            allowed_elements = frozenset(("dd", "dt", "li", "p", "tbody", "td",
+                                          "tfoot", "th", "thead", "tr", "body",
+                                          "html"))
+            for node in self.tree.openElements[::-1]:
+                if node.name not in allowed_elements:
+                    self.parser.parseError("expected-closing-tag-but-got-eof")
+                    break
+            # Stop parsing
+
+        def processSpaceCharactersDropNewline(self, token):
+            # Sometimes (start of 
, , and