Upgraded beets to 1.3.10, including patches

This commit is contained in:
Bas Stottelaar
2015-01-27 22:26:35 +01:00
parent ea842a95ca
commit cf6a6a876f
31 changed files with 4792 additions and 3002 deletions

21
lib/beets/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
The MIT License
Copyright (c) 2010-2014 Adrian Sampson
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

94
lib/beets/README.rst Normal file
View File

@@ -0,0 +1,94 @@
.. image:: https://travis-ci.org/sampsyo/beets.svg?branch=master
:target: https://travis-ci.org/sampsyo/beets
.. image:: http://img.shields.io/coveralls/sampsyo/beets.svg
:target: https://coveralls.io/r/sampsyo/beets
.. image:: http://img.shields.io/pypi/v/beets.svg
:target: https://pypi.python.org/pypi/beets
Beets is the media library management system for obsessive-compulsive music
geeks.
The purpose of beets is to get your music collection right once and for all.
It catalogs your collection, automatically improving its metadata as it goes.
It then provides a bouquet of tools for manipulating and accessing your music.
Here's an example of beets' brainy tag corrector doing its thing::
$ beet import ~/music/ladytron
Tagging:
Ladytron - Witching Hour
(Similarity: 98.4%)
* Last One Standing -> The Last One Standing
* Beauty -> Beauty*2
* White Light Generation -> Whitelightgenerator
* All the Way -> All the Way...
Because beets is designed as a library, it can do almost anything you can
imagine for your music collection. Via `plugins`_, beets becomes a panacea:
- Fetch or calculate all the metadata you could possibly need: `album art`_,
`lyrics`_, `genres`_, `tempos`_, `ReplayGain`_ levels, or `acoustic
fingerprints`_.
- Get metadata from `MusicBrainz`_, `Discogs`_, or `Beatport`_. Or guess
metadata using songs' filenames or their acoustic fingerprints.
- `Transcode audio`_ to any format you like.
- Check your library for `duplicate tracks and albums`_ or for `albums that
are missing tracks`_.
- Clean up crufty tags left behind by other, less-awesome tools.
- Embed and extract album art from files' metadata.
- Browse your music library graphically through a Web browser and play it in any
browser that supports `HTML5 Audio`_.
- Analyze music files' metadata from the command line.
- Listen to your library with a music player that speaks the `MPD`_ protocol
and works with a staggering variety of interfaces.
If beets doesn't do what you want yet, `writing your own plugin`_ is
shockingly simple if you know a little Python.
.. _plugins: http://beets.readthedocs.org/page/plugins/
.. _MPD: http://www.musicpd.org/
.. _MusicBrainz music collection: http://musicbrainz.org/doc/Collections/
.. _writing your own plugin:
http://beets.readthedocs.org/page/dev/plugins.html
.. _HTML5 Audio:
http://www.w3.org/TR/html-markup/audio.html
.. _albums that are missing tracks:
http://beets.readthedocs.org/page/plugins/missing.html
.. _duplicate tracks and albums:
http://beets.readthedocs.org/page/plugins/duplicates.html
.. _Transcode audio:
http://beets.readthedocs.org/page/plugins/convert.html
.. _Beatport: http://www.beatport.com/
.. _Discogs: http://www.discogs.com/
.. _acoustic fingerprints:
http://beets.readthedocs.org/page/plugins/chroma.html
.. _ReplayGain: http://beets.readthedocs.org/page/plugins/replaygain.html
.. _tempos: http://beets.readthedocs.org/page/plugins/echonest.html
.. _genres: http://beets.readthedocs.org/page/plugins/lastgenre.html
.. _album art: http://beets.readthedocs.org/page/plugins/fetchart.html
.. _lyrics: http://beets.readthedocs.org/page/plugins/lyrics.html
.. _MusicBrainz: http://musicbrainz.org/
Read More
---------
Learn more about beets at `its Web site`_. Follow `@b33ts`_ on Twitter for
news and updates.
You can install beets by typing ``pip install beets``. Then check out the
`Getting Started`_ guide.
.. _its Web site: http://beets.radbox.org/
.. _Getting Started: http://beets.readthedocs.org/page/guides/main.html
.. _@b33ts: http://twitter.com/b33ts/
Authors
-------
Beets is by `Adrian Sampson`_ with a supporting cast of thousands. For help,
please contact the `mailing list`_.
.. _mailing list: https://groups.google.com/forum/#!forum/beets-users
.. _Adrian Sampson: http://homes.cs.washington.edu/~asampson/

View File

@@ -1,5 +1,5 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
@@ -12,13 +12,14 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
# This particular version has been slightly modified to work with headphones
# This particular version has been slightly modified to work with Headphones
# https://github.com/rembo10/headphones
import os
__version__ = '1.3.4'
__version__ = '1.3.10-headphones'
__author__ = 'Adrian Sampson <adrian@radbox.org>'
import os
import beets.library
from beets.util import confit

View File

@@ -14,135 +14,21 @@
"""Facilities for automatically determining files' correct metadata.
"""
import os
import logging
import re
from beets import library, mediafile, config
from beets.util import sorted_walk, ancestry, displayable_path
from beets import config
# Parts of external interface.
from .hooks import AlbumInfo, TrackInfo, AlbumMatch, TrackMatch
from .match import tag_item, tag_album
from .match import recommendation
from .hooks import AlbumInfo, TrackInfo, AlbumMatch, TrackMatch # noqa
from .match import tag_item, tag_album # noqa
from .match import Recommendation # noqa
# Global logger.
log = logging.getLogger('beets')
# Constants for directory walker.
MULTIDISC_MARKERS = (r'dis[ck]', r'cd')
MULTIDISC_PAT_FMT = r'^(.*%s[\W_]*)\d'
# Additional utilities for the main interface.
def albums_in_dir(path):
"""Recursively searches the given directory and returns an iterable
of (paths, items) where paths is a list of directories and items is
a list of Items that is probably an album. Specifically, any folder
containing any media files is an album.
"""
collapse_pat = collapse_paths = collapse_items = None
for root, dirs, files in sorted_walk(path,
ignore=config['ignore'].as_str_seq(),
logger=log):
# Get a list of items in the directory.
items = []
for filename in files:
try:
i = library.Item.from_path(os.path.join(root, filename))
except mediafile.FileTypeError:
pass
except mediafile.UnreadableFileError:
log.warn(u'unreadable file: {0}'.format(
displayable_path(filename))
)
else:
items.append(i)
# If we're currently collapsing the constituent directories in a
# multi-disc album, check whether we should continue collapsing
# and add the current directory. If so, just add the directory
# and move on to the next directory. If not, stop collapsing.
if collapse_paths:
if (not collapse_pat and collapse_paths[0] in ancestry(root)) or \
(collapse_pat and
collapse_pat.match(os.path.basename(root))):
# Still collapsing.
collapse_paths.append(root)
collapse_items += items
continue
else:
# Collapse finished. Yield the collapsed directory and
# proceed to process the current one.
if collapse_items:
yield collapse_paths, collapse_items
collapse_pat = collapse_paths = collapse_items = None
# Check whether this directory looks like the *first* directory
# in a multi-disc sequence. There are two indicators: the file
# is named like part of a multi-disc sequence (e.g., "Title Disc
# 1") or it contains no items but only directories that are
# named in this way.
start_collapsing = False
for marker in MULTIDISC_MARKERS:
marker_pat = re.compile(MULTIDISC_PAT_FMT % marker, re.I)
match = marker_pat.match(os.path.basename(root))
# Is this directory the root of a nested multi-disc album?
if dirs and not items:
# Check whether all subdirectories have the same prefix.
start_collapsing = True
subdir_pat = None
for subdir in dirs:
# The first directory dictates the pattern for
# the remaining directories.
if not subdir_pat:
match = marker_pat.match(subdir)
if match:
subdir_pat = re.compile(r'^%s\d' %
re.escape(match.group(1)), re.I)
else:
start_collapsing = False
break
# Subsequent directories must match the pattern.
elif not subdir_pat.match(subdir):
start_collapsing = False
break
# If all subdirectories match, don't check other
# markers.
if start_collapsing:
break
# Is this directory the first in a flattened multi-disc album?
elif match:
start_collapsing = True
# Set the current pattern to match directories with the same
# prefix as this one, followed by a digit.
collapse_pat = re.compile(r'^%s\d' %
re.escape(match.group(1)), re.I)
break
# If either of the above heuristics indicated that this is the
# beginning of a multi-disc album, initialize the collapsed
# directory and item lists and check the next directory.
if start_collapsing:
# Start collapsing; continue to the next iteration.
collapse_paths = [root]
collapse_items = items
continue
# If it's nonempty, yield it.
if items:
yield [root], items
# Clear out any unfinished collapse.
if collapse_paths and collapse_items:
yield collapse_paths, collapse_items
def apply_item_metadata(item, track_info):
"""Set an item's metadata from its matched TrackInfo object.
"""
@@ -156,6 +42,7 @@ def apply_item_metadata(item, track_info):
# At the moment, the other metadata is left intact (including album
# and track number). Perhaps these should be emptied?
def apply_metadata(album_info, mapping):
"""Set the items' metadata to match an AlbumInfo object using a
mapping from Items to TrackInfo objects.
@@ -171,8 +58,8 @@ def apply_metadata(album_info, mapping):
# Artist sort and credit names.
item.artist_sort = track_info.artist_sort or album_info.artist_sort
item.artist_credit = track_info.artist_credit or \
album_info.artist_credit
item.artist_credit = (track_info.artist_credit or
album_info.artist_credit)
item.albumartist_sort = album_info.artist_sort
item.albumartist_credit = album_info.artist_credit
@@ -235,7 +122,6 @@ def apply_metadata(album_info, mapping):
'language',
'country',
'albumstatus',
'media',
'albumdisambig'):
value = getattr(album_info, field)
if value is not None:
@@ -243,5 +129,8 @@ def apply_metadata(album_info, mapping):
if track_info.disctitle is not None:
item.disctitle = track_info.disctitle
if track_info.media is not None:
item.media = track_info.media
# Headphones seal of approval
item.comments = 'tagged by headphones/beets'
item.comments = 'tagged by headphones/beets'

View File

@@ -116,6 +116,7 @@ class AlbumInfo(object):
for track in self.tracks:
track.decode(codec)
class TrackInfo(object):
"""Describes a canonical track present on a release. Appears as part
of an AlbumInfo's ``tracks`` list. Consists of these data members:
@@ -126,6 +127,7 @@ class TrackInfo(object):
- ``artist_id``
- ``length``: float: duration of the track in seconds
- ``index``: position on the entire release
- ``media``: delivery mechanism (Vinyl, etc.)
- ``medium``: the disc number this track appears on in the album
- ``medium_index``: the track's position on the disc
- ``medium_total``: the number of tracks on the item's disc
@@ -140,13 +142,15 @@ class TrackInfo(object):
def __init__(self, title, track_id, artist=None, artist_id=None,
length=None, index=None, medium=None, medium_index=None,
medium_total=None, artist_sort=None, disctitle=None,
artist_credit=None, data_source=None, data_url=None):
artist_credit=None, data_source=None, data_url=None,
media=None):
self.title = title
self.track_id = track_id
self.artist = artist
self.artist_id = artist_id
self.length = length
self.index = index
self.media = media
self.medium = medium
self.medium_index = medium_index
self.medium_total = medium_total
@@ -162,7 +166,7 @@ class TrackInfo(object):
to Unicode.
"""
for fld in ['title', 'artist', 'medium', 'artist_sort', 'disctitle',
'artist_credit']:
'artist_credit', 'media']:
value = getattr(self, fld)
if isinstance(value, str):
setattr(self, fld, value.decode(codec, 'ignore'))
@@ -187,6 +191,7 @@ SD_REPLACE = [
(r'&', 'and'),
]
def _string_dist_basic(str1, str2):
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
@@ -201,13 +206,16 @@ def _string_dist_basic(str1, str2):
return 0.0
return levenshtein(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1, str2):
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
"""
if str1 is None and str2 is None: return 0.0
if str1 is None or str2 is None: return 1.0
if str1 is None and str2 is None:
return 0.0
if str1 is None or str2 is None:
return 1.0
str1 = str1.lower()
str2 = str2.lower()
@@ -217,9 +225,9 @@ def string_dist(str1, str2):
# "something, the".
for word in SD_END_WORDS:
if str1.endswith(', %s' % word):
str1 = '%s %s' % (word, str1[:-len(word)-2])
str1 = '%s %s' % (word, str1[:-len(word) - 2])
if str2.endswith(', %s' % word):
str2 = '%s %s' % (word, str2[:-len(word)-2])
str2 = '%s %s' % (word, str2[:-len(word) - 2])
# Perform a couple of basic normalizing substitutions.
for pat, repl in SD_REPLACE:
@@ -256,6 +264,23 @@ def string_dist(str1, str2):
return base_dist + penalty
class LazyClassProperty(object):
"""A decorator implementing a read-only property that is *lazy* in
the sense that the getter is only invoked once. Subsequent accesses
through *any* instance use the cached result.
"""
def __init__(self, getter):
self.getter = getter
self.computed = False
def __get__(self, obj, owner):
if not self.computed:
self.value = self.getter(owner)
self.computed = True
return self.value
class Distance(object):
"""Keeps track of multiple distance penalties. Provides a single
weighted distance for all penalties as well as a weighted distance
@@ -264,11 +289,15 @@ class Distance(object):
def __init__(self):
self._penalties = {}
@LazyClassProperty
def _weights(cls):
"""A dictionary from keys to floating-point weights.
"""
weights_view = config['match']['distance_weights']
self._weights = {}
weights = {}
for key in weights_view.keys():
self._weights[key] = weights_view[key].as_number()
weights[key] = weights_view[key].as_number()
return weights
# Access the components and their aggregates.
@@ -313,8 +342,7 @@ class Distance(object):
# Convert distance into a negative float we can sort items in
# ascending order (for keys, when the penalty is equal) and
# still get the items with the biggest distance first.
return sorted(list_, key=lambda (key, dist): (0-dist, key))
return sorted(list_, key=lambda (key, dist): (0 - dist, key))
# Behave like a float.
@@ -323,13 +351,13 @@ class Distance(object):
def __float__(self):
return self.distance
def __sub__(self, other):
return self.distance - other
def __rsub__(self, other):
return other - self.distance
# Behave like a dict.
def __getitem__(self, key):
@@ -355,11 +383,11 @@ class Distance(object):
"""
if not isinstance(dist, Distance):
raise ValueError(
'`dist` must be a Distance object. It is: %r' % dist)
'`dist` must be a Distance object, not {0}'.format(type(dist))
)
for key, penalties in dist._penalties.iteritems():
self._penalties.setdefault(key, []).extend(penalties)
# Adding components.
def _eq(self, value1, value2):
@@ -379,7 +407,8 @@ class Distance(object):
"""
if not 0.0 <= dist <= 1.0:
raise ValueError(
'`dist` must be between 0.0 and 1.0. It is: %r' % dist)
'`dist` must be between 0.0 and 1.0, not {0}'.format(dist)
)
self._penalties.setdefault(key, []).append(dist)
def add_equality(self, key, value, options):
@@ -476,6 +505,7 @@ def album_for_mbid(release_id):
except mb.MusicBrainzAPIError as exc:
exc.log(log)
def track_for_mbid(recording_id):
"""Get a TrackInfo object for a MusicBrainz recording ID. Return None
if the ID is not found.
@@ -485,18 +515,21 @@ def track_for_mbid(recording_id):
except mb.MusicBrainzAPIError as exc:
exc.log(log)
def albums_for_id(album_id):
"""Get a list of albums for an ID."""
candidates = [album_for_mbid(album_id)]
candidates.extend(plugins.album_for_id(album_id))
return filter(None, candidates)
def tracks_for_id(track_id):
"""Get a list of tracks for an ID."""
candidates = [track_for_mbid(track_id)]
candidates.extend(plugins.track_for_id(track_id))
return filter(None, candidates)
def album_candidates(items, artist, album, va_likely):
"""Search for album matches. ``items`` is a list of Item objects
that make up the album. ``artist`` and ``album`` are the respective
@@ -525,6 +558,7 @@ def album_candidates(items, artist, album, va_likely):
return out
def item_candidates(item, artist, title):
"""Search for item matches. ``item`` is the Item to be matched.
``artist`` and ``title`` are strings and either reflect the item or

View File

@@ -25,11 +25,8 @@ from munkres import Munkres
from beets import plugins
from beets import config
from beets.util import plurality
from beets.util.enumeration import enum
from beets.autotag import hooks
# Recommendation enumeration.
recommendation = enum('none', 'low', 'medium', 'strong', name='recommendation')
from beets.util.enumeration import OrderedEnum
# Artist signals that indicate "various artists". These are used at the
# album level to determine whether a given release is likely a VA
@@ -41,6 +38,18 @@ VA_ARTISTS = (u'', u'various artists', u'various', u'va', u'unknown')
log = logging.getLogger('beets')
# Recommendation enumeration.
class Recommendation(OrderedEnum):
"""Indicates a qualitative suggestion to the user about what should
be done with a given match.
"""
none = 0
low = 1
medium = 2
strong = 3
# Primary matching functionality.
def current_metadata(items):
@@ -56,10 +65,10 @@ def current_metadata(items):
fields = ['artist', 'album', 'albumartist', 'year', 'disctotal',
'mb_albumid', 'label', 'catalognum', 'country', 'media',
'albumdisambig']
for key in fields:
values = [getattr(item, key) for item in items if item]
likelies[key], freq = plurality(values)
consensus[key] = (freq == len(values))
for field in fields:
values = [item[field] for item in items if item]
likelies[field], freq = plurality(values)
consensus[field] = (freq == len(values))
# If there's an album artist consensus, use this for the artist.
if consensus['albumartist'] and likelies['albumartist']:
@@ -67,6 +76,7 @@ def current_metadata(items):
return likelies, consensus
def assign_items(items, tracks):
"""Given a list of Items and a list of TrackInfo objects, find the
best mapping between them. Returns a mapping from Items to TrackInfo
@@ -93,12 +103,14 @@ def assign_items(items, tracks):
extra_tracks.sort(key=lambda t: (t.index, t.title))
return mapping, extra_items, extra_tracks
def track_index_changed(item, track_info):
"""Returns True if the item and track info index is different. Tolerates
per disc and per release numbering.
"""
return item.track not in (track_info.medium_index, track_info.index)
def track_distance(item, track_info, incl_artist=False):
"""Determines the significance of a track metadata change. Returns a
Distance object. `incl_artist` indicates that a distance component should
@@ -109,7 +121,7 @@ def track_distance(item, track_info, incl_artist=False):
# Length.
if track_info.length:
diff = abs(item.length - track_info.length) - \
config['match']['track_length_grace'].as_number()
config['match']['track_length_grace'].as_number()
dist.add_ratio('track_length', diff,
config['match']['track_length_max'].as_number())
@@ -134,6 +146,7 @@ def track_distance(item, track_info, incl_artist=False):
return dist
def distance(items, album_info, mapping):
"""Determines how "significant" an album metadata change would be.
Returns a Distance object. `album_info` is an AlbumInfo object
@@ -239,6 +252,7 @@ def distance(items, album_info, mapping):
return dist
def match_by_id(items):
"""If the items are tagged with a MusicBrainz album ID, returns an
AlbumInfo object for the corresponding album. Otherwise, returns
@@ -247,16 +261,17 @@ def match_by_id(items):
# Is there a consensus on the MB album ID?
albumids = [item.mb_albumid for item in items if item.mb_albumid]
if not albumids:
log.debug('No album IDs found.')
log.debug(u'No album IDs found.')
return None
# If all album IDs are equal, look up the album.
if bool(reduce(lambda x,y: x if x==y else (), albumids)):
if bool(reduce(lambda x, y: x if x == y else (), albumids)):
albumid = albumids[0]
log.debug('Searching for discovered album ID: ' + albumid)
log.debug(u'Searching for discovered album ID: {0}'.format(albumid))
return hooks.album_for_mbid(albumid)
else:
log.debug('No album ID consensus.')
log.debug(u'No album ID consensus.')
def _recommendation(results):
"""Given a sorted list of AlbumMatch or TrackMatch objects, return a
@@ -268,26 +283,26 @@ def _recommendation(results):
"""
if not results:
# No candidates: no recommendation.
return recommendation.none
return Recommendation.none
# Basic distance thresholding.
min_dist = results[0].distance
if min_dist < config['match']['strong_rec_thresh'].as_number():
# Strong recommendation level.
rec = recommendation.strong
rec = Recommendation.strong
elif min_dist <= config['match']['medium_rec_thresh'].as_number():
# Medium recommendation level.
rec = recommendation.medium
rec = Recommendation.medium
elif len(results) == 1:
# Only a single candidate.
rec = recommendation.low
rec = Recommendation.low
elif results[1].distance - min_dist >= \
config['match']['rec_gap_thresh'].as_number():
# Gap between first two candidates is large.
rec = recommendation.low
rec = Recommendation.low
else:
# No conclusion. Return immediately. Can't be downgraded any further.
return recommendation.none
return Recommendation.none
# Downgrade to the max rec if it is lower than the current rec for an
# applied penalty.
@@ -299,28 +314,40 @@ def _recommendation(results):
for key in keys:
if key in max_rec_view.keys():
max_rec = max_rec_view[key].as_choice({
'strong': recommendation.strong,
'medium': recommendation.medium,
'low': recommendation.low,
'none': recommendation.none,
'strong': Recommendation.strong,
'medium': Recommendation.medium,
'low': Recommendation.low,
'none': Recommendation.none,
})
rec = min(rec, max_rec)
return rec
def _add_candidate(items, results, info):
"""Given a candidate AlbumInfo object, attempt to add the candidate
to the output dictionary of AlbumMatch objects. This involves
checking the track count, ordering the items, checking for
duplicates, and calculating the distance.
"""
log.debug('Candidate: %s - %s' % (info.artist, info.album))
log.debug(u'Candidate: {0} - {1}'.format(info.artist, info.album))
# Discard albums with zero tracks.
if not info.tracks:
log.debug('No tracks.')
return
# Don't duplicate.
if info.album_id in results:
log.debug('Duplicate.')
log.debug(u'Duplicate.')
return
# Discard matches without required tags.
for req_tag in config['match']['required'].as_str_seq():
if getattr(info, req_tag) is None:
log.debug(u'Ignored. Missing required tag: {0}'.format(req_tag))
return
# Find mapping between the items and the track info.
mapping, extra_items, extra_tracks = assign_items(items, info.tracks)
@@ -331,30 +358,36 @@ def _add_candidate(items, results, info):
penalties = [key for _, key in dist]
for penalty in config['match']['ignored'].as_str_seq():
if penalty in penalties:
log.debug('Ignored. Penalty: %s' % penalty)
log.debug(u'Ignored. Penalty: {0}'.format(penalty))
return
log.debug('Success. Distance: %f' % dist)
log.debug(u'Success. Distance: {0}'.format(dist))
results[info.album_id] = hooks.AlbumMatch(dist, info, mapping,
extra_items, extra_tracks)
def tag_album(items, search_artist=None, search_album=None,
search_id=None):
"""Bundles together the functionality used to infer tags for a
set of items comprised by an album. Returns everything relevant:
- The current artist.
- The current album.
- A list of AlbumMatch objects. The candidates are sorted by
distance (i.e., best match first).
- A recommendation.
If search_artist and search_album or search_id are provided, then
they are used as search terms in place of the current metadata.
"""Return a tuple of a artist name, an album name, a list of
`AlbumMatch` candidates from the metadata backend, and a
`Recommendation`.
The artist and album are the most common values of these fields
among `items`.
The `AlbumMatch` objects are generated by searching the metadata
backends. By default, the metadata of the items is used for the
search. This can be customized by setting the parameters. The
`mapping` field of the album has the matched `items` as keys.
The recommendation is calculated from the match qualitiy of the
candidates.
"""
# Get current metadata.
likelies, consensus = current_metadata(items)
cur_artist = likelies['artist']
cur_album = likelies['album']
log.debug('Tagging %s - %s' % (cur_artist, cur_album))
log.debug(u'Tagging {0} - {1}'.format(cur_artist, cur_album))
# The output result (distance, AlbumInfo) tuples (keyed by MB album
# ID).
@@ -362,7 +395,7 @@ def tag_album(items, search_artist=None, search_album=None,
# Search by explicit ID.
if search_id is not None:
log.debug('Searching for album ID: ' + search_id)
log.debug(u'Searching for album ID: {0}'.format(search_id))
search_cands = hooks.albums_for_id(search_id)
# Use existing metadata or text search.
@@ -372,32 +405,33 @@ def tag_album(items, search_artist=None, search_album=None,
if id_info:
_add_candidate(items, candidates, id_info)
rec = _recommendation(candidates.values())
log.debug('Album ID match recommendation is ' + str(rec))
log.debug(u'Album ID match recommendation is {0}'.format(str(rec)))
if candidates and not config['import']['timid']:
# If we have a very good MBID match, return immediately.
# Otherwise, this match will compete against metadata-based
# matches.
if rec == recommendation.strong:
log.debug('ID match.')
if rec == Recommendation.strong:
log.debug(u'ID match.')
return cur_artist, cur_album, candidates.values(), rec
# Search terms.
if not (search_artist and search_album):
# No explicit search terms -- use current metadata.
search_artist, search_album = cur_artist, cur_album
log.debug(u'Search terms: %s - %s' % (search_artist, search_album))
log.debug(u'Search terms: {0} - {1}'.format(search_artist,
search_album))
# Is this album likely to be a "various artist" release?
va_likely = ((not consensus['artist']) or
(search_artist.lower() in VA_ARTISTS) or
any(item.comp for item in items))
log.debug(u'Album might be VA: %s' % str(va_likely))
(search_artist.lower() in VA_ARTISTS) or
any(item.comp for item in items))
log.debug(u'Album might be VA: {0}'.format(str(va_likely)))
# Get the results from the data sources.
search_cands = hooks.album_candidates(items, search_artist,
search_album, va_likely)
log.debug(u'Evaluating %i candidates.' % len(search_cands))
log.debug(u'Evaluating {0} candidates.'.format(len(search_cands)))
for info in search_cands:
_add_candidate(items, candidates, info)
@@ -406,6 +440,7 @@ def tag_album(items, search_artist=None, search_album=None,
rec = _recommendation(candidates)
return cur_artist, cur_album, candidates, rec
def tag_item(item, search_artist=None, search_title=None,
search_id=None):
"""Attempts to find metadata for a single track. Returns a
@@ -421,15 +456,15 @@ def tag_item(item, search_artist=None, search_title=None,
# First, try matching by MusicBrainz ID.
trackid = search_id or item.mb_trackid
if trackid:
log.debug('Searching for track ID: ' + trackid)
log.debug(u'Searching for track ID: {0}'.format(trackid))
for track_info in hooks.tracks_for_id(trackid):
dist = track_distance(item, track_info, incl_artist=True)
candidates[track_info.track_id] = \
hooks.TrackMatch(dist, track_info)
hooks.TrackMatch(dist, track_info)
# If this is a good match, then don't keep searching.
rec = _recommendation(candidates.values())
if rec == recommendation.strong and not config['import']['timid']:
log.debug('Track ID match.')
if rec == Recommendation.strong and not config['import']['timid']:
log.debug(u'Track ID match.')
return candidates.values(), rec
# If we're searching by ID, don't proceed.
@@ -437,12 +472,13 @@ def tag_item(item, search_artist=None, search_title=None,
if candidates:
return candidates.values(), rec
else:
return [], recommendation.none
return [], Recommendation.none
# Search terms.
if not (search_artist and search_title):
search_artist, search_title = item.artist, item.title
log.debug(u'Item search terms: %s - %s' % (search_artist, search_title))
log.debug(u'Item search terms: {0} - {1}'.format(search_artist,
search_title))
# Get and evaluate candidate metadata.
for track_info in hooks.item_candidates(item, search_artist, search_title):
@@ -450,7 +486,7 @@ def tag_item(item, search_artist=None, search_title=None,
candidates[track_info.track_id] = hooks.TrackMatch(dist, track_info)
# Sort by distance and return with recommendation.
log.debug('Found %i candidates.' % len(candidates))
log.debug(u'Found {0} candidates.'.format(len(candidates)))
candidates = sorted(candidates.itervalues())
rec = _recommendation(candidates)
return candidates, rec

View File

@@ -32,6 +32,7 @@ BASE_URL = 'http://musicbrainz.org/'
musicbrainzngs.set_useragent('beets', beets.__version__,
'http://beets.radbox.org/')
class MusicBrainzAPIError(util.HumanReadableException):
"""An error while talking to MusicBrainz. The `query` field is the
parameter to the action and may have any type.
@@ -41,7 +42,7 @@ class MusicBrainzAPIError(util.HumanReadableException):
super(MusicBrainzAPIError, self).__init__(reason, verb, tb)
def get_message(self):
return u'"{0}" in {1} with query {2}'.format(
return u'{0} in {1} with query {2}'.format(
self._reasonstr(), self.verb, repr(self.query)
)
@@ -51,12 +52,15 @@ RELEASE_INCLUDES = ['artists', 'media', 'recordings', 'release-groups',
'labels', 'artist-credits', 'aliases']
TRACK_INCLUDES = ['artists', 'aliases']
def track_url(trackid):
return urljoin(BASE_URL, 'recording/' + trackid)
def album_url(albumid):
return urljoin(BASE_URL, 'release/' + albumid)
def configure():
"""Set up the python-musicbrainz-ngs module according to settings
from the beets configuration. This should be called at startup.
@@ -67,6 +71,7 @@ def configure():
config['musicbrainz']['ratelimit'].get(int),
)
def _preferred_alias(aliases):
"""Given an list of alias structures for an artist credit, select
and return the user's preferred alias alias or None if no matching
@@ -81,13 +86,15 @@ def _preferred_alias(aliases):
# Search configured locales in order.
for locale in config['import']['languages'].as_str_seq():
# Find matching primary aliases for this locale.
matches = [a for a in aliases if a['locale'] == locale and 'primary' in a]
matches = [a for a in aliases
if a['locale'] == locale and 'primary' in a]
# Skip to the next locale if we have no matches
if not matches:
continue
return matches[0]
def _flatten_artist_credit(credit):
"""Given a list representing an ``artist-credit`` block, flatten the
data into a triple of joined artist name strings: canonical, sort, and
@@ -133,6 +140,7 @@ def _flatten_artist_credit(credit):
''.join(artist_credit_parts),
)
def track_info(recording, index=None, medium=None, medium_index=None,
medium_total=None):
"""Translates a MusicBrainz recording result dictionary into a beets
@@ -167,6 +175,7 @@ def track_info(recording, index=None, medium=None, medium_index=None,
info.decode()
return info
def _set_date_str(info, date_str, original=False):
"""Given a (possibly partial) YYYY-MM-DD string and an AlbumInfo
object, set the object's release date fields appropriately. If
@@ -186,6 +195,7 @@ def _set_date_str(info, date_str, original=False):
key = 'original_' + key
setattr(info, key, date_num)
def album_info(release):
"""Takes a MusicBrainz release result dictionary and returns a beets
AlbumInfo object containing the interesting data about that release.
@@ -199,6 +209,7 @@ def album_info(release):
index = 0
for medium in release['medium-list']:
disctitle = medium.get('title')
format = medium.get('format')
for track in medium['track-list']:
# Basic information from the recording.
index += 1
@@ -210,6 +221,7 @@ def album_info(release):
len(medium['track-list']),
)
ti.disctitle = disctitle
ti.media = format
# Prefer track data, where present, over recording data.
if track.get('title'):
@@ -288,6 +300,7 @@ def album_info(release):
info.decode()
return info
def match_album(artist, album, tracks=None, limit=SEARCH_LIMIT):
"""Searches for a single album ("release" in MusicBrainz parlance)
and returns an iterator over AlbumInfo objects. May raise a
@@ -297,9 +310,9 @@ def match_album(artist, album, tracks=None, limit=SEARCH_LIMIT):
optionally, a number of tracks on the album.
"""
# Build search criteria.
criteria = {'release': album.lower()}
criteria = {'release': album.lower().strip()}
if artist is not None:
criteria['artist'] = artist.lower()
criteria['artist'] = artist.lower().strip()
else:
# Various Artists search.
criteria['arid'] = VARIOUS_ARTISTS_ID
@@ -322,13 +335,14 @@ def match_album(artist, album, tracks=None, limit=SEARCH_LIMIT):
if albuminfo is not None:
yield albuminfo
def match_track(artist, title, limit=SEARCH_LIMIT):
"""Searches for a single track and returns an iterable of TrackInfo
objects. May raise a MusicBrainzAPIError.
"""
criteria = {
'artist': artist.lower(),
'recording': title.lower(),
'artist': artist.lower().strip(),
'recording': title.lower().strip(),
}
if not any(criteria.itervalues()):
@@ -342,6 +356,7 @@ def match_track(artist, title, limit=SEARCH_LIMIT):
for recording in res['recording-list']:
yield track_info(recording)
def _parse_id(s):
"""Search for a MusicBrainz ID in the given string and return it. If
no ID can be found, return None.
@@ -351,38 +366,40 @@ def _parse_id(s):
if match:
return match.group()
def album_for_id(albumid):
def album_for_id(releaseid):
"""Fetches an album by its MusicBrainz ID and returns an AlbumInfo
object or None if the album is not found. May raise a
MusicBrainzAPIError.
"""
albumid = _parse_id(albumid)
albumid = _parse_id(releaseid)
if not albumid:
log.error('Invalid MBID.')
log.debug(u'Invalid MBID ({0}).'.format(releaseid))
return
try:
res = musicbrainzngs.get_release_by_id(albumid,
RELEASE_INCLUDES)
except musicbrainzngs.ResponseError:
log.debug('Album ID match failed.')
log.debug(u'Album ID match failed.')
return None
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(exc, 'get release by ID', albumid,
traceback.format_exc())
return album_info(res['release'])
def track_for_id(trackid):
def track_for_id(releaseid):
"""Fetches a track by its MusicBrainz ID. Returns a TrackInfo object
or None if no track is found. May raise a MusicBrainzAPIError.
"""
trackid = _parse_id(trackid)
trackid = _parse_id(releaseid)
if not trackid:
log.error('Invalid MBID.')
log.debug(u'Invalid MBID ({0}).'.format(releaseid))
return
try:
res = musicbrainzngs.get_recording_by_id(trackid, TRACK_INCLUDES)
except musicbrainzngs.ResponseError:
log.debug('Track ID match failed.')
log.debug(u'Track ID match failed.')
return None
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(exc, 'get recording by ID', trackid,

View File

@@ -5,6 +5,7 @@ import:
write: yes
copy: yes
move: no
link: no
delete: no
resume: ask
incremental: no
@@ -20,6 +21,7 @@ import:
detail: no
flat: no
group_albums: no
pretend: false
clutter: ["Thumbs.DB", ".DS_Store"]
ignore: [".*", "*~", "System Volume Information"]
@@ -32,6 +34,7 @@ replace:
'\s+$': ''
'^\s+': ''
path_sep_replace: _
asciify_paths: false
art_filename: cover
max_filename_length: 0
@@ -54,6 +57,9 @@ list_format_item: $artist - $album - $title
list_format_album: $albumartist - $album
time_format: '%Y-%m-%d %H:%M:%S'
sort_album: albumartist+ album+
sort_item: artist+ album+ disc+ track+
paths:
default: $albumartist/$album%aunique{}/$track $title
singleton: Non-Album/$artist/$title
@@ -98,5 +104,6 @@ match:
media: []
original_year: no
ignored: []
required: []
track_length_grace: 10
track_length_max: 30

View File

@@ -18,3 +18,8 @@ Library.
from .db import Model, Database
from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery
from .types import Type
from .queryparse import query_from_strings
from .queryparse import sort_from_strings
from .queryparse import parse_sorted_query
# flake8: noqa

View File

@@ -20,16 +20,62 @@ from collections import defaultdict
import threading
import sqlite3
import contextlib
import collections
import beets
from beets.util.functemplate import Template
from .query import MatchQuery
from beets.dbcore import types
from .query import MatchQuery, NullSort, TrueQuery
class FormattedMapping(collections.Mapping):
"""A `dict`-like formatted view of a model.
The accessor `mapping[key]` returns the formated version of
`model[key]` as a unicode string.
If `for_path` is true, all path separators in the formatted values
are replaced.
"""
def __init__(self, model, for_path=False):
self.for_path = for_path
self.model = model
self.model_keys = model.keys(True)
def __getitem__(self, key):
if key in self.model_keys:
return self._get_formatted(self.model, key)
else:
raise KeyError(key)
def __iter__(self):
return iter(self.model_keys)
def __len__(self):
return len(self.model_keys)
def get(self, key, default=None):
if default is None:
default = self.model._type(key).format(None)
return super(FormattedMapping, self).get(key, default)
def _get_formatted(self, model, key):
value = model._type(key).format(model.get(key))
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
if self.for_path:
sep_repl = beets.config['path_sep_replace'].get(unicode)
for sep in (os.path.sep, os.path.altsep):
if sep:
value = value.replace(sep, sep_repl)
return value
# Abstract base for model classes.
class Model(object):
"""An abstract object representing an object in the database. Model
objects act like dictionaries (i.e., the allow subscript access like
@@ -66,12 +112,7 @@ class Model(object):
_fields = {}
"""A mapping indicating available "fixed" fields on this type. The
keys are field names and the values are Type objects.
"""
_bytes_keys = ()
"""Keys whose values should be stored as raw bytes blobs rather than
strings.
keys are field names and the values are `Type` objects.
"""
_search_fields = ()
@@ -79,6 +120,21 @@ class Model(object):
terms.
"""
_types = {}
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
"""
_sorts = {}
"""Optional named sort criteria. The keys are strings and the values
are subclasses of `Sort`.
"""
_always_dirty = False
"""By default, fields only become "dirty" when their value actually
changes. Enabling this flag marks fields as dirty even when the new
value is the same as the old value (e.g., `o.f = o.f`).
"""
@classmethod
def _getters(cls):
"""Return a mapping from field names to getter functions.
@@ -94,7 +150,6 @@ class Model(object):
# As above: we could consider caching this result.
raise NotImplementedError()
# Basic operation.
def __init__(self, db=None, **values):
@@ -110,6 +165,20 @@ class Model(object):
self.update(values)
self.clear_dirty()
@classmethod
def _awaken(cls, db=None, fixed_values={}, flex_values={}):
"""Create an object with values drawn from the database.
This is a performance optimization: the checks involved with
ordinary construction are bypassed.
"""
obj = cls(db)
for key, value in fixed_values.iteritems():
obj._values_fixed[key] = cls._type(key).from_sql(value)
for key, value in flex_values.iteritems():
obj._values_flex[key] = cls._type(key).from_sql(value)
return obj
def __repr__(self):
return '{0}({1})'.format(
type(self).__name__,
@@ -132,9 +201,17 @@ class Model(object):
if need_id and not self.id:
raise ValueError('{0} has no id'.format(type(self).__name__))
# Essential field accessors.
@classmethod
def _type(self, key):
"""Get the type of a field, a `Type` instance.
If the field has no explicit type, it is given the base `Type`,
which does no conversion.
"""
return self._fields.get(key) or self._types.get(key) or types.DEFAULT
def __getitem__(self, key):
"""Get the value for a field. Raise a KeyError if the field is
not available.
@@ -152,11 +229,19 @@ class Model(object):
def __setitem__(self, key, value):
"""Assign the value for a field.
"""
source = self._values_fixed if key in self._fields \
else self._values_flex
# Choose where to place the value.
if key in self._fields:
source = self._values_fixed
else:
source = self._values_flex
# If the field has a type, filter the value.
value = self._type(key).normalize(value)
# Assign value and possibly mark as dirty.
old_value = source.get(key)
source[key] = value
if old_value != value:
if self._always_dirty or old_value != value:
self._dirty.add(key)
def __delitem__(self, key):
@@ -183,7 +268,6 @@ class Model(object):
else:
return base_keys
# Act like a dictionary.
def update(self, values):
@@ -219,7 +303,6 @@ class Model(object):
"""
return iter(self.keys())
# Convenient attribute access.
def __getattr__(self, key):
@@ -243,7 +326,6 @@ class Model(object):
else:
del self[key]
# Database interaction (CRUD methods).
def store(self):
@@ -252,19 +334,15 @@ class Model(object):
self._check_db()
# Build assignments for query.
assignments = ''
assignments = []
subvars = []
for key in self._fields:
if key != 'id' and key in self._dirty:
self._dirty.remove(key)
assignments += key + '=?,'
value = self[key]
# Wrap path strings in buffers so they get stored
# "in the raw".
if key in self._bytes_keys and isinstance(value, str):
value = buffer(value)
assignments.append(key + '=?')
value = self._type(key).to_sql(self[key])
subvars.append(value)
assignments = assignments[:-1] # Knock off last ,
assignments = ','.join(assignments)
with self._db.transaction() as tx:
# Main table update.
@@ -302,6 +380,8 @@ class Model(object):
self._check_db()
stored_obj = self._db._get(type(self), self.id)
assert stored_obj is not None, "object {0} not in DB".format(self.id)
self._values_fixed = {}
self._values_flex = {}
self.update(dict(stored_obj))
self.clear_dirty()
@@ -344,76 +424,26 @@ class Model(object):
self._dirty.add(key)
self.store()
# Formatting and templating.
@classmethod
def _format(cls, key, value, for_path=False):
"""Format a value as the given field for this model.
"""
# Format the value as a string according to its type, if any.
if key in cls._fields:
value = cls._fields[key].format(value)
# Formatting must result in a string. To deal with
# Python2isms, implicitly convert ASCII strings.
assert isinstance(value, basestring), \
u'field formatter must produce strings'
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
_formatter = FormattedMapping
elif not isinstance(value, unicode):
# Fallback formatter. Convert to unicode at all cost.
if value is None:
value = u''
elif isinstance(value, basestring):
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
else:
value = unicode(value)
if for_path:
sep_repl = beets.config['path_sep_replace'].get(unicode)
for sep in (os.path.sep, os.path.altsep):
if sep:
value = value.replace(sep, sep_repl)
return value
def _get_formatted(self, key, for_path=False):
"""Get a field value formatted as a string (`unicode` object)
for display to the user. If `for_path` is true, then the value
will be sanitized for inclusion in a pathname (i.e., path
separators will be removed from the value).
"""
return self._format(key, self.get(key), for_path)
def _formatted_mapping(self, for_path=False):
def formatted(self, for_path=False):
"""Get a mapping containing all values on this object formatted
as human-readable strings.
as human-readable unicode strings.
"""
# In the future, this could be made "lazy" to avoid computing
# fields unnecessarily.
out = {}
for key in self.keys(True):
out[key] = self._get_formatted(key, for_path)
return out
return self._formatter(self, for_path)
def evaluate_template(self, template, for_path=False):
"""Evaluate a template (a string or a `Template` object) using
the object's fields. If `for_path` is true, then no new path
separators will be added to the template.
"""
# Build value mapping.
mapping = self._formatted_mapping(for_path)
# Get template functions.
funcs = self._template_funcs()
# Perform substitution.
if isinstance(template, basestring):
template = Template(template)
return template.substitute(mapping, funcs)
return template.substitute(self.formatted(for_path),
self._template_funcs())
# Parsing.
@@ -424,63 +454,117 @@ class Model(object):
if not isinstance(string, basestring):
raise TypeError("_parse() argument must be a string")
typ = cls._fields.get(key)
if typ:
return typ.parse(string)
else:
# Fall back to unparsed string.
return string
return cls._type(key).parse(string)
# Database controller and supporting interfaces.
class Results(object):
"""An item query result set. Iterating over the collection lazily
constructs LibModel objects that reflect database rows.
"""
def __init__(self, model_class, rows, db, query=None):
def __init__(self, model_class, rows, db, query=None, sort=None):
"""Create a result set that will construct objects of type
`model_class`, which should be a subclass of `LibModel`, out of
the query result mapping in `rows`. The new objects are
associated with the database `db`. If `query` is provided, it is
used as a predicate to filter the results for a "slow query" that
cannot be evaluated by the database directly.
`model_class`.
`model_class` is a subclass of `LibModel` that will be
constructed. `rows` is a query result: a list of mappings. The
new objects will be associated with the database `db`.
If `query` is provided, it is used as a predicate to filter the
results for a "slow query" that cannot be evaluated by the
database directly. If `sort` is provided, it is used to sort the
full list of results before returning. This means it is a "slow
sort" and all objects must be built before returning the first
one.
"""
self.model_class = model_class
self.rows = rows
self.db = db
self.query = query
self.sort = sort
# We keep a queue of rows we haven't yet consumed for
# materialization. We preserve the original total number of
# rows.
self._rows = rows
self._row_count = len(rows)
# The materialized objects corresponding to rows that have been
# consumed.
self._objects = []
def _get_objects(self):
"""Construct and generate Model objects for they query. The
objects are returned in the order emitted from the database; no
slow sort is applied.
For performance, this generator caches materialized objects to
avoid constructing them more than once. This way, iterating over
a `Results` object a second time should be much faster than the
first.
"""
index = 0 # Position in the materialized objects.
while index < len(self._objects) or self._rows:
# Are there previously-materialized objects to produce?
if index < len(self._objects):
yield self._objects[index]
index += 1
# Otherwise, we consume another row, materialize its object
# and produce it.
else:
while self._rows:
row = self._rows.pop(0)
obj = self._make_model(row)
# If there is a slow-query predicate, ensurer that the
# object passes it.
if not self.query or self.query.match(obj):
self._objects.append(obj)
index += 1
yield obj
break
def __iter__(self):
"""Construct Python objects for all rows that pass the query
predicate.
"""Construct and generate Model objects for all matching
objects, in sorted order.
"""
for row in self.rows:
# Get the flexible attributes for the object.
with self.db.transaction() as tx:
flex_rows = tx.query(
'SELECT * FROM {0} WHERE entity_id=?'.format(
self.model_class._flex_table
),
(row['id'],)
)
values = dict(row)
values.update(
dict((row['key'], row['value']) for row in flex_rows)
if self.sort:
# Slow sort. Must build the full list first.
objects = self.sort.sort(list(self._get_objects()))
return iter(objects)
else:
# Objects are pre-sorted (i.e., by the database).
return self._get_objects()
def _make_model(self, row):
# Get the flexible attributes for the object.
with self.db.transaction() as tx:
flex_rows = tx.query(
'SELECT * FROM {0} WHERE entity_id=?'.format(
self.model_class._flex_table
),
(row['id'],)
)
# Construct the Python object and yield it if it passes the
# predicate.
obj = self.model_class(self.db, **values)
if not self.query or self.query.match(obj):
yield obj
cols = dict(row)
values = dict((k, v) for (k, v) in cols.items()
if not k[:4] == 'flex')
flex_values = dict((row['key'], row['value']) for row in flex_rows)
# Construct the Python object
obj = self.model_class._awaken(self.db, values, flex_values)
return obj
def __len__(self):
"""Get the number of matching objects.
"""
if self.query:
if not self._rows:
# Fully materialized. Just count the objects.
return len(self._objects)
elif self.query:
# A slow query. Fall back to testing every object.
count = 0
for obj in self:
@@ -489,7 +573,7 @@ class Results(object):
else:
# A fast query. Just count the rows.
return len(self.rows)
return self._row_count
def __nonzero__(self):
"""Does this result contain any objects?
@@ -500,6 +584,11 @@ class Results(object):
"""Get the nth item in this result set. This is inefficient: all
items up to n are materialized and thrown away.
"""
if not self._rows and not self.sort:
# Fully materialized and already in order. Just look up the
# object.
return self._objects[n]
it = iter(self)
try:
for i in range(n):
@@ -604,7 +693,6 @@ class Database(object):
self._make_table(model_cls._table, model_cls._fields)
self._make_attribute_table(model_cls._flex_table)
# Primitive access control: connections and transactions.
def _connection(self):
@@ -644,7 +732,6 @@ class Database(object):
"""
return Transaction(self)
# Schema setup and migration.
def _make_table(self, table, fields):
@@ -698,27 +785,33 @@ class Database(object):
ON {0} (entity_id);
""".format(flex_table))
# Querying.
def _fetch(self, model_cls, query, order_by=None):
def _fetch(self, model_cls, query=None, sort=None):
"""Fetch the objects of type `model_cls` matching the given
query. The query may be given as a string, string sequence, a
Query object, or None (to fetch everything). If provided,
`order_by` is a SQLite ORDER BY clause for sorting.
Query object, or None (to fetch everything). `sort` is an
`Sort` object.
"""
query = query or TrueQuery() # A null query.
sort = sort or NullSort() # Unsorted.
where, subvals = query.clause()
order_by = sort.order_clause()
sql = "SELECT * FROM {0} WHERE {1}".format(
sql = ("SELECT * FROM {0} WHERE {1} {2}").format(
model_cls._table,
where or '1',
"ORDER BY {0}".format(order_by) if order_by else '',
)
if order_by:
sql += " ORDER BY {0}".format(order_by)
with self.transaction() as tx:
rows = tx.query(sql, subvals)
return Results(model_cls, rows, self, None if where else query)
return Results(
model_cls, rows, self,
None if where else query, # Slow query component.
sort if sort.is_slow() else None, # Slow sort component.
)
def _get(self, model_cls, id):
"""Get a Model object by its id or None if the id does not

View File

@@ -15,6 +15,7 @@
"""The Query type hierarchy for DBCore.
"""
import re
from operator import attrgetter
from beets import util
from datetime import datetime, timedelta
@@ -82,6 +83,23 @@ class MatchQuery(FieldQuery):
return pattern == value
class NoneQuery(FieldQuery):
def __init__(self, field, fast=True):
self.field = field
self.fast = fast
def col_clause(self):
return self.field + " IS NULL", ()
@classmethod
def match(self, item):
try:
return item[self.field] is None
except KeyError:
return True
class StringFieldQuery(FieldQuery):
"""A FieldQuery that converts values to strings before matching
them.
@@ -104,8 +122,11 @@ class StringFieldQuery(FieldQuery):
class SubstringQuery(StringFieldQuery):
"""A query that matches a substring in a specific item field."""
def col_clause(self):
search = '%' + (self.pattern.replace('\\','\\\\').replace('%','\\%')
.replace('_','\\_')) + '%'
pattern = (self.pattern
.replace('\\', '\\\\')
.replace('%', '\\%')
.replace('_', '\\_'))
search = '%' + pattern + '%'
clause = self.field + " like ? escape '\\'"
subvals = [search]
return clause, subvals
@@ -200,7 +221,9 @@ class NumericQuery(FieldQuery):
self.rangemax = self._convert(parts[1])
def match(self, item):
value = getattr(item, self.field)
if self.field not in item:
return False
value = item[self.field]
if isinstance(value, basestring):
value = self._convert(value)
@@ -236,12 +259,16 @@ class CollectionQuery(Query):
self.subqueries = subqueries
# Act like a sequence.
def __len__(self):
return len(self.subqueries)
def __getitem__(self, key):
return self.subqueries[key]
def __iter__(self):
return iter(self.subqueries)
def __contains__(self, item):
return item in self.subqueries
@@ -334,10 +361,8 @@ class FalseQuery(Query):
return False
# Time/date queries.
def _to_epoch_time(date):
"""Convert a `datetime` object to an integer number of seconds since
the (local) Unix epoch.
@@ -393,10 +418,14 @@ class Period(object):
return None
ordinal = string.count('-')
if ordinal >= len(cls.date_formats):
raise ValueError('date is not in one of the formats '
+ ', '.join(cls.date_formats))
# Too many components.
return None
date_format = cls.date_formats[ordinal]
date = datetime.strptime(string, date_format)
try:
date = datetime.strptime(string, date_format)
except ValueError:
# Parsing failed.
return None
precision = cls.precisions[ordinal]
return cls(date, precision)
@@ -492,3 +521,134 @@ class DateQuery(FieldQuery):
# Match any date.
clause = '1'
return clause, subvals
# Sorting.
class Sort(object):
"""An abstract class representing a sort operation for a query into
the item database.
"""
def order_clause(self):
"""Generates a SQL fragment to be used in a ORDER BY clause, or
None if no fragment is used (i.e., this is a slow sort).
"""
return None
def sort(self, items):
"""Sort the list of objects and return a list.
"""
return sorted(items)
def is_slow(self):
"""Indicate whether this query is *slow*, meaning that it cannot
be executed in SQL and must be executed in Python.
"""
return False
class MultipleSort(Sort):
"""Sort that encapsulates multiple sub-sorts.
"""
def __init__(self, sorts=None):
self.sorts = sorts or []
def add_sort(self, sort):
self.sorts.append(sort)
def _sql_sorts(self):
"""Return the list of sub-sorts for which we can be (at least
partially) fast.
A contiguous suffix of fast (SQL-capable) sub-sorts are
executable in SQL. The remaining, even if they are fast
independently, must be executed slowly.
"""
sql_sorts = []
for sort in reversed(self.sorts):
if not sort.order_clause() is None:
sql_sorts.append(sort)
else:
break
sql_sorts.reverse()
return sql_sorts
def order_clause(self):
order_strings = []
for sort in self._sql_sorts():
order = sort.order_clause()
order_strings.append(order)
return ", ".join(order_strings)
def is_slow(self):
for sort in self.sorts:
if sort.is_slow():
return True
return False
def sort(self, items):
slow_sorts = []
switch_slow = False
for sort in reversed(self.sorts):
if switch_slow:
slow_sorts.append(sort)
elif sort.order_clause() is None:
switch_slow = True
slow_sorts.append(sort)
else:
pass
for sort in slow_sorts:
items = sort.sort(items)
return items
def __repr__(self):
return u'MultipleSort({0})'.format(repr(self.sorts))
class FieldSort(Sort):
"""An abstract sort criterion that orders by a specific field (of
any kind).
"""
def __init__(self, field, ascending=True):
self.field = field
self.ascending = ascending
def sort(self, objs):
# TODO: Conversion and null-detection here. In Python 3,
# comparisons with None fail. We should also support flexible
# attributes with different types without falling over.
return sorted(objs, key=attrgetter(self.field),
reverse=not self.ascending)
def __repr__(self):
return u'<{0}: {1}{2}>'.format(
type(self).__name__,
self.field,
'+' if self.ascending else '-',
)
class FixedFieldSort(FieldSort):
"""Sort object to sort on a fixed field.
"""
def order_clause(self):
order = "ASC" if self.ascending else "DESC"
return "{0} {1}".format(self.field, order)
class SlowFieldSort(FieldSort):
"""A sort criterion by some model field other than a fixed field:
i.e., a computed or flexible field.
"""
def is_slow(self):
return True
class NullSort(Sort):
"""No sorting. Leave results unsorted."""
def sort(items):
return items

View File

@@ -0,0 +1,180 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Parsing of strings into DBCore queries.
"""
import re
import itertools
from . import query
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
r'(?:'
r'(\S+?)' # The field key.
r'(?<!\\):' # Unescaped :
r')?'
r'(.*)', # The term itself.
re.I # Case-insensitive.
)
def parse_query_part(part, query_classes={}, prefixes={},
default_class=query.SubstringQuery):
"""Take a query in the form of a key/value pair separated by a
colon and return a tuple of `(key, value, cls)`. `key` may be None,
indicating that any field may be matched. `cls` is a subclass of
`FieldQuery`.
The optional `query_classes` parameter maps field names to default
query types; `default_class` is the fallback. `prefixes` is a map
from query prefix markers and query types. Prefix-indicated queries
take precedence over type-based queries.
To determine the query class, two factors are used: prefixes and
field types. For example, the colon prefix denotes a regular
expression query and a type map might provide a special kind of
query for numeric values. If neither a prefix nor a specific query
class is available, `default_class` is used.
For instance,
'stapler' -> (None, 'stapler', SubstringQuery)
'color:red' -> ('color', 'red', SubstringQuery)
':^Quiet' -> (None, '^Quiet', RegexpQuery)
'color::b..e' -> ('color', 'b..e', RegexpQuery)
Prefixes may be "escaped" with a backslash to disable the keying
behavior.
"""
part = part.strip()
match = PARSE_QUERY_PART_REGEX.match(part)
assert match # Regex should always match.
key = match.group(1)
term = match.group(2).replace('\:', ':')
# Match the search term against the list of prefixes.
for pre, query_class in prefixes.items():
if term.startswith(pre):
return key, term[len(pre):], query_class
# No matching prefix: use type-based or fallback/default query.
query_class = query_classes.get(key, default_class)
return key, term, query_class
def construct_query_part(model_cls, prefixes, query_part):
"""Create a query from a single query component, `query_part`, for
querying instances of `model_cls`. Return a `Query` instance.
"""
# Shortcut for empty query parts.
if not query_part:
return query.TrueQuery()
# Get the query classes for each possible field.
query_classes = {}
for k, t in itertools.chain(model_cls._fields.items(),
model_cls._types.items()):
query_classes[k] = t.query
# Parse the string.
key, pattern, query_class = \
parse_query_part(query_part, query_classes, prefixes)
# No key specified.
if key is None:
if issubclass(query_class, query.FieldQuery):
# The query type matches a specific field, but none was
# specified. So we use a version of the query that matches
# any field.
return query.AnyFieldQuery(pattern, model_cls._search_fields,
query_class)
else:
# Other query type.
return query_class(pattern)
key = key.lower()
return query_class(key.lower(), pattern, key in model_cls._fields)
def query_from_strings(query_cls, model_cls, prefixes, query_parts):
"""Creates a collection query of type `query_cls` from a list of
strings in the format used by parse_query_part. `model_cls`
determines how queries are constructed from strings.
"""
subqueries = []
for part in query_parts:
subqueries.append(construct_query_part(model_cls, prefixes, part))
if not subqueries: # No terms in query.
subqueries = [query.TrueQuery()]
return query_cls(subqueries)
def construct_sort_part(model_cls, part):
"""Create a `Sort` from a single string criterion.
`model_cls` is the `Model` being queried. `part` is a single string
ending in ``+`` or ``-`` indicating the sort.
"""
assert part, "part must be a field name and + or -"
field = part[:-1]
assert field, "field is missing"
direction = part[-1]
assert direction in ('+', '-'), "part must end with + or -"
is_ascending = direction == '+'
if field in model_cls._sorts:
sort = model_cls._sorts[field](model_cls, is_ascending)
elif field in model_cls._fields:
sort = query.FixedFieldSort(field, is_ascending)
else:
# Flexible or computed.
sort = query.SlowFieldSort(field, is_ascending)
return sort
def sort_from_strings(model_cls, sort_parts):
"""Create a `Sort` from a list of sort criteria (strings).
"""
if not sort_parts:
return query.NullSort()
else:
sort = query.MultipleSort()
for part in sort_parts:
sort.add_sort(construct_sort_part(model_cls, part))
return sort
def parse_sorted_query(model_cls, parts, prefixes={},
query_cls=query.AndQuery):
"""Given a list of strings, create the `Query` and `Sort` that they
represent.
"""
# Separate query token and sort token.
query_parts = []
sort_parts = []
for part in parts:
if part.endswith((u'+', u'-')) and u':' not in part:
sort_parts.append(part)
else:
query_parts.append(part)
# Parse each.
q = query_from_strings(
query_cls, model_cls, prefixes, query_parts
)
s = sort_from_strings(model_cls, sort_parts)
return q, s

View File

@@ -18,55 +18,111 @@ from . import query
from beets.util import str2bool
# Abstract base.
class Type(object):
"""An object encapsulating the type of a model field. Includes
information about how to store the value in the database, query,
format, and parse a given field.
information about how to store, query, format, and parse a given
field.
"""
sql = None
sql = u'TEXT'
"""The SQLite column type for the value.
"""
query = None
query = query.SubstringQuery
"""The `Query` subclass to be used when querying the field.
"""
model_type = unicode
"""The Python type that is used to represent the value in the model.
The model is guaranteed to return a value of this type if the field
is accessed. To this end, the constructor is used by the `normalize`
and `from_sql` methods and the `default` property.
"""
@property
def null(self):
"""The value to be exposed when the underlying value is None.
"""
return self.model_type()
def format(self, value):
"""Given a value of this type, produce a Unicode string
representing the value. This is used in template evaluation.
"""
raise NotImplementedError()
if value is None:
value = self.null
# `self.null` might be `None`
if value is None:
value = u''
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
return unicode(value)
def parse(self, string):
"""Parse a (possibly human-written) string and return the
indicated value of this type.
"""
raise NotImplementedError()
try:
return self.model_type(string)
except ValueError:
return self.null
def normalize(self, value):
"""Given a value that will be assigned into a field of this
type, normalize the value to have the appropriate type. This
base implementation only reinterprets `None`.
"""
if value is None:
return self.null
else:
# TODO This should eventually be replaced by
# `self.model_type(value)`
return value
def from_sql(self, sql_value):
"""Receives the value stored in the SQL backend and return the
value to be stored in the model.
For fixed fields the type of `value` is determined by the column
type affinity given in the `sql` property and the SQL to Python
mapping of the database adapter. For more information see:
http://www.sqlite.org/datatype3.html
https://docs.python.org/2/library/sqlite3.html#sqlite-and-python-types
Flexible fields have the type afinity `TEXT`. This means the
`sql_value` is either a `buffer` or a `unicode` object` and the
method must handle these in addition.
"""
if isinstance(sql_value, buffer):
sql_value = bytes(sql_value).decode('utf8', 'ignore')
if isinstance(sql_value, unicode):
return self.parse(sql_value)
else:
return self.normalize(sql_value)
def to_sql(self, model_value):
"""Convert a value as stored in the model object to a value used
by the database adapter.
"""
return model_value
# Reusable types.
class Default(Type):
null = None
class Integer(Type):
"""A basic integer type.
"""
sql = u'INTEGER'
query = query.NumericQuery
def format(self, value):
return unicode(value or 0)
def parse(self, string):
try:
return int(string)
except ValueError:
return 0
model_type = int
class PaddedInt(Integer):
@@ -93,9 +149,14 @@ class ScaledInt(Integer):
class Id(Integer):
"""An integer used as the row key for a SQLite table.
"""An integer used as the row id or a foreign key in a SQLite table.
This type is nullable: None values are not translated to zero.
"""
sql = u'INTEGER PRIMARY KEY'
null = None
def __init__(self, primary=True):
if primary:
self.sql = u'INTEGER PRIMARY KEY'
class Float(Type):
@@ -103,15 +164,16 @@ class Float(Type):
"""
sql = u'REAL'
query = query.NumericQuery
model_type = float
def format(self, value):
return u'{0:.1f}'.format(value or 0.0)
def parse(self, string):
try:
return float(string)
except ValueError:
return 0.0
class NullFloat(Float):
"""Same as `Float`, but does not normalize `None` to `0.0`.
"""
null = None
class String(Type):
@@ -120,21 +182,27 @@ class String(Type):
sql = u'TEXT'
query = query.SubstringQuery
def format(self, value):
return unicode(value) if value else u''
def parse(self, string):
return string
class Boolean(Type):
"""A boolean type.
"""
sql = u'INTEGER'
query = query.BooleanQuery
model_type = bool
def format(self, value):
return unicode(bool(value))
def parse(self, string):
return str2bool(string)
# Shared instances of common types.
DEFAULT = Default()
INTEGER = Integer()
PRIMARY_ID = Id(True)
FOREIGN_ID = Id(False)
FLOAT = Float()
NULL_FLOAT = NullFloat()
STRING = String()
BOOLEAN = Boolean()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -16,8 +16,11 @@
import logging
import traceback
import inspect
import re
from collections import defaultdict
import beets
from beets import mediafile
@@ -30,6 +33,14 @@ LASTFM_KEY = '2dc3914abf35f0d9c92d97d8f8e42b43'
log = logging.getLogger('beets')
class PluginConflictException(Exception):
"""Indicates that the services provided by one plugin conflict with
those of another.
For example two plugins may define different types for flexible fields.
"""
# Managing the plugins themselves.
class BeetsPlugin(object):
@@ -40,7 +51,6 @@ class BeetsPlugin(object):
def __init__(self, name=None):
"""Perform one-time plugin setup.
"""
_add_media_fields(self.item_fields())
self.import_stages = []
self.name = name or self.__module__.split('.')[-1]
self.config = beets.config[self.name]
@@ -86,14 +96,6 @@ class BeetsPlugin(object):
"""
return ()
def item_fields(self):
"""Returns field descriptors to be added to the MediaFile class,
in the form of a dictionary whose keys are field names and whose
values are descriptor (e.g., MediaField) instances. The Library
database schema is not (currently) extended.
"""
return {}
def album_for_id(self, album_id):
"""Return an AlbumInfo object or None if no matching release was
found.
@@ -106,6 +108,20 @@ class BeetsPlugin(object):
"""
return None
def add_media_field(self, name, descriptor):
"""Add a field that is synchronized between media files and items.
When a media field is added ``item.write()`` will set the name
property of the item's MediaFile to ``item[name]`` and save the
changes. Similarly ``item.read()`` will set ``item[name]`` to
the value of the name property of the media file.
``descriptor`` must be an instance of ``mediafile.MediaField``.
"""
# Defer impor to prevent circular dependency
from beets import library
mediafile.MediaFile.add_field(name, descriptor)
library.Item._media_fields.add(name)
listeners = None
@@ -130,7 +146,7 @@ class BeetsPlugin(object):
>>> @MyPlugin.listen("imported")
>>> def importListener(**kwargs):
>>> pass
... pass
"""
def helper(func):
if cls.listeners is None:
@@ -170,7 +186,10 @@ class BeetsPlugin(object):
return func
return helper
_classes = set()
def load_plugins(names=()):
"""Imports the modules for a sequence of plugin names. Each name
must be the name of a Python module under the "beetsplug" namespace
@@ -185,7 +204,7 @@ def load_plugins(names=()):
except ImportError as exc:
# Again, this is hacky:
if exc.args[0].endswith(' ' + name):
log.warn('** plugin %s not found' % name)
log.warn(u'** plugin {0} not found'.format(name))
else:
raise
else:
@@ -195,10 +214,13 @@ def load_plugins(names=()):
_classes.add(obj)
except:
log.warn('** error loading plugin %s' % name)
log.warn(u'** error loading plugin {0}'.format(name))
log.warn(traceback.format_exc())
_instances = {}
def find_plugins():
"""Returns a list of BeetsPlugin subclass instances from all
currently loaded beets plugins. Loads the default plugin set
@@ -224,6 +246,7 @@ def commands():
out += plugin.commands()
return out
def queries():
"""Returns a dict mapping prefix strings to Query subclasses all loaded
plugins.
@@ -233,6 +256,24 @@ def queries():
out.update(plugin.queries())
return out
def types(model_cls):
# Gives us `item_types` and `album_types`
attr_name = '{0}_types'.format(model_cls.__name__.lower())
types = {}
for plugin in find_plugins():
plugin_types = getattr(plugin, attr_name, {})
for field in plugin_types:
if field in types and plugin_types[field] != types[field]:
raise PluginConflictException(
u'Plugin {0} defines flexible field {1} '
'which has already been defined with '
'another type.'.format(plugin.name, field)
)
types.update(plugin_types)
return types
def track_distance(item, info):
"""Gets the track distance calculated by all loaded plugins.
Returns a Distance object.
@@ -243,6 +284,7 @@ def track_distance(item, info):
dist.update(plugin.track_distance(item, info))
return dist
def album_distance(items, album_info, mapping):
"""Returns the album distance calculated by plugins."""
from beets.autotag.hooks import Distance
@@ -251,6 +293,7 @@ def album_distance(items, album_info, mapping):
dist.update(plugin.album_distance(items, album_info, mapping))
return dist
def candidates(items, artist, album, va_likely):
"""Gets MusicBrainz candidates for an album from each plugin.
"""
@@ -259,6 +302,7 @@ def candidates(items, artist, album, va_likely):
out.extend(plugin.candidates(items, artist, album, va_likely))
return out
def item_candidates(item, artist, title):
"""Gets MusicBrainz candidates for an item from the plugins.
"""
@@ -267,6 +311,7 @@ def item_candidates(item, artist, title):
out.extend(plugin.item_candidates(item, artist, title))
return out
def album_for_id(album_id):
"""Get AlbumInfo objects for a given ID string.
"""
@@ -277,6 +322,7 @@ def album_for_id(album_id):
out.append(res)
return out
def track_for_id(track_id):
"""Get TrackInfo objects for a given ID string.
"""
@@ -287,6 +333,7 @@ def track_for_id(track_id):
out.append(res)
return out
def template_funcs():
"""Get all the template functions declared by plugins as a
dictionary.
@@ -297,12 +344,6 @@ def template_funcs():
funcs.update(plugin.template_funcs)
return funcs
def _add_media_fields(fields):
"""Adds a {name: descriptor} dictionary of fields to the MediaFile
class. Called during the plugin initialization.
"""
for key, value in fields.iteritems():
setattr(mediafile.MediaFile, key, value)
def import_stages():
"""Get a list of import stage functions defined by plugins."""
@@ -325,6 +366,7 @@ def item_field_getters():
funcs.update(plugin.template_fields)
return funcs
def album_field_getters():
"""As above, for album fields.
"""
@@ -348,6 +390,7 @@ def event_handlers():
all_handlers[event] += handlers
return all_handlers
def send(event, **arguments):
"""Sends an event to all assigned event listeners. Event is the
name of the event to send, all other named arguments go to the
@@ -355,5 +398,38 @@ def send(event, **arguments):
Returns a list of return values from the handlers.
"""
log.debug('Sending event: %s' % event)
return [handler(**arguments) for handler in event_handlers()[event]]
log.debug(u'Sending event: {0}'.format(event))
for handler in event_handlers()[event]:
# Don't break legacy plugins if we want to pass more arguments
argspec = inspect.getargspec(handler).args
args = dict((k, v) for k, v in arguments.items() if k in argspec)
handler(**args)
def feat_tokens(for_artist=True):
"""Return a regular expression that matches phrases like "featuring"
that separate a main artist or a song title from secondary artists.
The `for_artist` option determines whether the regex should be
suitable for matching artist fields (the default) or title fields.
"""
feat_words = ['ft', 'featuring', 'feat', 'feat.', 'ft.']
if for_artist:
feat_words += ['with', 'vs', 'and', 'con', '&']
return '(?<=\s)(?:{0})(?=\s)'.format(
'|'.join(re.escape(x) for x in feat_words)
)
def sanitize_choices(choices, choices_all):
"""Clean up a stringlist configuration attribute: keep only choices
elements present in choices_all, remove duplicate elements, expand '*'
wildcard while keeping original stringlist order.
"""
seen = set()
others = [x for x in choices_all if x not in choices]
res = []
for s in choices:
if s in list(choices_all) + ['*']:
if not (s in seen or seen.add(s)):
res.extend(list(others) if s == '*' else [s])
return res

View File

@@ -29,6 +29,7 @@ import errno
import re
import struct
import traceback
import os.path
from beets import library
from beets import plugins
@@ -38,9 +39,7 @@ from beets import config
from beets.util import confit
from beets.autotag import mb
# On Windows platforms, use colorama to support "ANSI" terminal colors.
if sys.platform == 'win32':
try:
import colorama
@@ -50,8 +49,10 @@ if sys.platform == 'win32':
colorama.init()
# Constants.
log = logging.getLogger('beets')
if not log.handlers:
log.addHandler(logging.StreamHandler())
log.propagate = False # Don't propagate to root handler.
PF_KEY_QUERIES = {
@@ -59,19 +60,15 @@ PF_KEY_QUERIES = {
'singleton': 'singleton:true',
}
# UI exception. Commands should throw this in order to display
# nonrecoverable errors to the user.
class UserError(Exception):
pass
# Main logger.
log = logging.getLogger('beets')
"""UI exception. Commands should throw this in order to display
nonrecoverable errors to the user.
"""
# Utilities.
def _encoding():
"""Tries to guess the encoding used by the terminal."""
# Configured override?
@@ -170,7 +167,7 @@ def input_options(options, require=False, prompt=None, fallback_prompt=None,
# Infer a letter.
for letter in option:
if not letter.isalpha():
continue # Don't use punctuation.
continue # Don't use punctuation.
if letter not in letters:
found_letter = letter
break
@@ -181,9 +178,10 @@ def input_options(options, require=False, prompt=None, fallback_prompt=None,
index = option.index(found_letter)
# Mark the option's shortcut letter for display.
if not require and ((default is None and not numrange and first) or
(isinstance(default, basestring) and
found_letter.lower() == default.lower())):
if not require and (
(default is None and not numrange and first) or
(isinstance(default, basestring) and
found_letter.lower() == default.lower())):
# The first option is the default; mark it.
show_letter = '[%s]' % found_letter.upper()
is_default = True
@@ -352,11 +350,13 @@ def human_seconds_short(interval):
# http://dev.pocoo.org/hg/pygments-main/file/b2deea5b5030/pygments/console.py
# (pygments is by Tim Hatch, Armin Ronacher, et al.)
COLOR_ESCAPE = "\x1b["
DARK_COLORS = ["black", "darkred", "darkgreen", "brown", "darkblue",
"purple", "teal", "lightgray"]
DARK_COLORS = ["black", "darkred", "darkgreen", "brown", "darkblue",
"purple", "teal", "lightgray"]
LIGHT_COLORS = ["darkgray", "red", "green", "yellow", "blue",
"fuchsia", "turquoise", "white"]
RESET_COLOR = COLOR_ESCAPE + "39;49;00m"
def _colorize(color, text):
"""Returns a string that prints the given text in the given color
in a terminal that is ANSI color-aware. The color must be something
@@ -441,30 +441,6 @@ def colordiff(a, b, highlight='red'):
return unicode(a), unicode(b)
def color_diff_suffix(a, b, highlight='red'):
"""Colorize the differing suffix between two strings."""
a, b = unicode(a), unicode(b)
if not config['color']:
return a, b
# Fast path.
if a == b:
return a, b
# Find the longest common prefix.
first_diff = None
for i in range(min(len(a), len(b))):
if a[i] != b[i]:
first_diff = i
break
else:
first_diff = min(len(a), len(b))
# Colorize from the first difference on.
return a[:first_diff] + colorize(highlight, a[first_diff:]), \
b[:first_diff] + colorize(highlight, b[first_diff:])
def get_path_formats(subview=None):
"""Get the configuration's path formats as a list of query/template
pairs.
@@ -494,21 +470,6 @@ def get_replacements():
return replacements
def get_plugin_paths():
"""Get the list of search paths for plugins from the config file.
The value for "pluginpath" may be a single string or a list of
strings.
"""
pluginpaths = config['pluginpath'].get()
if isinstance(pluginpaths, basestring):
pluginpaths = [pluginpaths]
if not isinstance(pluginpaths, list):
raise confit.ConfigTypeError(
u'pluginpath must be string or a list of strings'
)
return map(util.normpath, pluginpaths)
def _pick_format(album, fmt=None):
"""Pick a format string for printing Album or Item objects,
falling back to config options and defaults.
@@ -558,6 +519,8 @@ def term_width():
FLOAT_EPSILON = 0.01
def _field_diff(field, old, new):
"""Given two Model objects, format their values for `field` and
highlight changes among them. Return a human-readable string. If the
@@ -574,13 +537,13 @@ def _field_diff(field, old, new):
return None
# Get formatted values for output.
oldstr = old._get_formatted(field)
newstr = new._get_formatted(field)
oldstr = old.formatted().get(field, u'')
newstr = new.formatted().get(field, u'')
# For strings, highlight changes. For others, colorize the whole
# thing.
if isinstance(oldval, basestring):
oldstr, newstr = colordiff(oldval, newval)
oldstr, newstr = colordiff(oldval, newstr)
else:
oldstr, newstr = colorize('red', oldstr), colorize('red', newstr)
@@ -613,9 +576,12 @@ def show_model_changes(new, old=None, fields=None, always=False):
# New fields.
for field in set(new) - set(old):
if fields and field not in fields:
continue
changes.append(u' {0}: {1}'.format(
field,
colorize('red', new._get_formatted(field))
colorize('red', new.formatted()[field])
))
# Print changes.
@@ -627,10 +593,8 @@ def show_model_changes(new, old=None, fields=None, always=False):
return bool(changes)
# Subcommand parsing infrastructure.
#
# This is a fairly generic subcommand parser for optparse. It is
# maintained externally here:
# http://gist.github.com/462717
@@ -653,46 +617,56 @@ class Subcommand(object):
self.aliases = aliases
self.help = help
self.hide = hide
self._root_parser = None
def print_help(self):
self.parser.print_help()
def parse_args(self, args):
return self.parser.parse_args(args)
@property
def root_parser(self):
return self._root_parser
@root_parser.setter
def root_parser(self, root_parser):
self._root_parser = root_parser
self.parser.prog = '{0} {1}'.format(root_parser.get_prog_name(),
self.name)
class SubcommandsOptionParser(optparse.OptionParser):
"""A variant of OptionParser that parses subcommands and their
arguments.
"""
# A singleton command used to give help on other subcommands.
_HelpSubcommand = Subcommand('help', optparse.OptionParser(),
help='give detailed help on a specific sub-command',
aliases=('?',))
def __init__(self, *args, **kwargs):
"""Create a new subcommand-aware option parser. All of the
options to OptionParser.__init__ are supported in addition
to subcommands, a sequence of Subcommand objects.
"""
# The subcommand array, with the help command included.
self.subcommands = list(kwargs.pop('subcommands', []))
self.subcommands.append(self._HelpSubcommand)
# A more helpful default usage.
if 'usage' not in kwargs:
kwargs['usage'] = """
%prog COMMAND [ARGS...]
%prog help COMMAND"""
kwargs['add_help_option'] = False
# Super constructor.
optparse.OptionParser.__init__(self, *args, **kwargs)
# Adjust the help-visible name of each subcommand.
for subcommand in self.subcommands:
subcommand.parser.prog = '%s %s' % \
(self.get_prog_name(), subcommand.name)
# Our root parser needs to stop on the first unrecognized argument.
self.disable_interspersed_args()
def add_subcommand(self, cmd):
self.subcommands = []
def add_subcommand(self, *cmds):
"""Adds a Subcommand object to the parser's list of commands.
"""
self.subcommands.append(cmd)
for cmd in cmds:
cmd.root_parser = self
self.subcommands.append(cmd)
# Add the list of subcommands to the help message.
def format_help(self, formatter=None):
@@ -711,6 +685,7 @@ class SubcommandsOptionParser(optparse.OptionParser):
disp_names = []
help_position = 0
subcommands = [c for c in self.subcommands if not c.hide]
subcommands.sort(key=lambda c: c.name)
for subcommand in subcommands:
name = subcommand.name
if subcommand.aliases:
@@ -756,52 +731,40 @@ class SubcommandsOptionParser(optparse.OptionParser):
return subcommand
return None
def parse_args(self, a=None, v=None):
"""Like OptionParser.parse_args, but returns these four items:
- options: the options passed to the root parser
- subcommand: the Subcommand object that was invoked
- suboptions: the options passed to the subcommand parser
- subargs: the positional arguments passed to the subcommand
def parse_global_options(self, args):
"""Parse options up to the subcommand argument. Returns a tuple
of the options object and the remaining arguments.
"""
options, args = optparse.OptionParser.parse_args(self, a, v)
subcommand, suboptions, subargs = self._parse_sub(args)
return options, subcommand, suboptions, subargs
options, subargs = self.parse_args(args)
def _parse_sub(self, args):
"""Given the `args` left unused by a typical OptionParser
`parse_args`, return the invoked subcommand, the subcommand
options, and the subcommand arguments.
# Force the help command
if options.help:
subargs = ['help']
elif options.version:
subargs = ['version']
return options, subargs
def parse_subcommand(self, args):
"""Given the `args` left unused by a `parse_global_options`,
return the invoked subcommand, the subcommand options, and the
subcommand arguments.
"""
# Help is default command
if not args:
# No command given.
self.print_help()
self.exit()
else:
cmdname = args.pop(0)
subcommand = self._subcommand_for_name(cmdname)
if not subcommand:
self.error('unknown command ' + cmdname)
args = ['help']
suboptions, subargs = subcommand.parser.parse_args(args)
if subcommand is self._HelpSubcommand:
if subargs:
# particular
cmdname = subargs[0]
helpcommand = self._subcommand_for_name(cmdname)
if not helpcommand:
self.error('no command named {0}'.format(cmdname))
helpcommand.parser.print_help()
self.exit()
else:
# general
self.print_help()
self.exit()
cmdname = args.pop(0)
subcommand = self._subcommand_for_name(cmdname)
if not subcommand:
raise UserError("unknown command '{0}'".format(cmdname))
suboptions, subargs = subcommand.parse_args(args)
return subcommand, suboptions, subargs
optparse.Option.ALWAYS_TYPED_ACTIONS += ('callback',)
def vararg_callback(option, opt_str, value, parser):
"""Callback for an option with variable arguments.
Manually collect arguments right of a callback-action
@@ -838,53 +801,54 @@ def vararg_callback(option, opt_str, value, parser):
setattr(parser.values, option.dest, value)
# The main entry point and bootstrapping.
def _load_plugins():
def _load_plugins(config):
"""Load the plugins specified in the configuration.
"""
# Add plugin paths.
paths = config['pluginpath'].get(confit.StrSeq(split=False))
paths = map(util.normpath, paths)
import beetsplug
beetsplug.__path__ = get_plugin_paths() + beetsplug.__path__
beetsplug.__path__ = paths + beetsplug.__path__
# For backwards compatibility.
sys.path += get_plugin_paths()
sys.path += paths
# Load requested plugins.
plugins.load_plugins(config['plugins'].as_str_seq())
plugins.send("pluginload")
return plugins
def _configure(args):
"""Parse the command line, load configuration files (including
loading any indicated plugins), and return the invoked subcomand,
the subcommand options, and the subcommand arguments.
def _setup(options, lib=None):
"""Prepare and global state and updates it with command line options.
Returns a list of subcommands, a list of plugins, and a library instance.
"""
# Temporary: Migrate from 1.0-style configuration.
from beets.ui import migrate
migrate.automigrate()
# Configure the MusicBrainz API.
mb.configure()
config = _configure(options)
plugins = _load_plugins(config)
# Get the default subcommands.
from beets.ui.commands import default_commands
# Construct the root parser.
commands = list(default_commands)
commands.append(migrate.migrate_cmd) # Temporary.
parser = SubcommandsOptionParser(subcommands=commands)
parser.add_option('-l', '--library', dest='library',
help='library database file to use')
parser.add_option('-d', '--directory', dest='directory',
help="destination music directory")
parser.add_option('-v', '--verbose', dest='verbose', action='store_true',
help='print debugging information')
parser.add_option('-c', '--config', dest='config',
help='path to configuration file')
subcommands = list(default_commands)
subcommands.extend(plugins.commands())
# Parse the command-line!
options, args = optparse.OptionParser.parse_args(parser, args)
if lib is None:
lib = _open_library(config)
plugins.send("library_opened", lib=lib)
library.Item._types = plugins.types(library.Item)
library.Album._types = plugins.types(library.Album)
return subcommands, plugins, lib
def _configure(options):
"""Amend the global configuration object with command line options.
"""
# Add any additional config files specified with --config. This
# special handling lets specified plugins get loaded before we
# finish parsing the command line.
@@ -894,22 +858,28 @@ def _configure(args):
config.set_file(config_path)
config.set_args(options)
# Now add the plugin commands to the parser.
_load_plugins()
for cmd in plugins.commands():
parser.add_subcommand(cmd)
# Configure the logger.
if config['verbose'].get(bool):
log.setLevel(logging.DEBUG)
else:
log.setLevel(logging.INFO)
# Parse the remainder of the command line with loaded plugins.
return parser._parse_sub(args)
config_path = config.user_config_path()
if os.path.isfile(config_path):
log.debug(u'user configuration: {0}'.format(
util.displayable_path(config_path)))
else:
log.debug(u'no user configuration found at {0}'.format(
util.displayable_path(config_path)))
log.debug(u'data directory: {0}'
.format(util.displayable_path(config.config_dir())))
return config
def _raw_main(args):
"""A helper function for `main` without top-level exception
handling.
def _open_library(config):
"""Create a new library instance from the configuration.
"""
subcommand, suboptions, subargs = _configure(args)
# Open library file.
dbpath = config['library'].as_filename()
try:
lib = library.Library(
@@ -918,32 +888,52 @@ def _raw_main(args):
get_path_formats(),
get_replacements(),
)
except sqlite3.OperationalError:
lib.get_item(0) # Test database connection.
except (sqlite3.OperationalError, sqlite3.DatabaseError):
log.debug(traceback.format_exc())
raise UserError(u"database file {0} could not be opened".format(
util.displayable_path(dbpath)
))
plugins.send("library_opened", lib=lib)
log.debug(u'library database: {0}\n'
u'library directory: {1}'
.format(util.displayable_path(lib.path),
util.displayable_path(lib.directory)))
return lib
# Configure the logger.
if config['verbose'].get(bool):
log.setLevel(logging.DEBUG)
else:
log.setLevel(logging.INFO)
log.debug(u'data directory: {0}\n'
u'library database: {1}\n'
u'library directory: {2}'
.format(
util.displayable_path(config.config_dir()),
util.displayable_path(lib.path),
util.displayable_path(lib.directory),
)
)
# Configure the MusicBrainz API.
mb.configure()
def _raw_main(args, lib=None):
"""A helper function for `main` without top-level exception
handling.
"""
parser = SubcommandsOptionParser()
parser.add_option('-l', '--library', dest='library',
help='library database file to use')
parser.add_option('-d', '--directory', dest='directory',
help="destination music directory")
parser.add_option('-v', '--verbose', dest='verbose', action='store_true',
help='print debugging information')
parser.add_option('-c', '--config', dest='config',
help='path to configuration file')
parser.add_option('-h', '--help', dest='help', action='store_true',
help='how this help message and exit')
parser.add_option('--version', dest='version', action='store_true',
help=optparse.SUPPRESS_HELP)
# Invoke the subcommand.
options, subargs = parser.parse_global_options(args)
# Special case for the `config --edit` command: bypass _setup so
# that an invalid configuration does not prevent the editor from
# starting.
if subargs[0] == 'config' and ('-e' in subargs or '--edit' in subargs):
from beets.ui.commands import config_edit
return config_edit()
subcommands, plugins, lib = _setup(options, lib)
parser.add_subcommand(*subcommands)
subcommand, suboptions, subargs = parser.parse_subcommand(subargs)
subcommand.func(lib, suboptions, subargs)
plugins.send('cli_exit', lib=lib)

File diff suppressed because it is too large Load Diff

View File

@@ -1,401 +0,0 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Conversion from legacy (pre-1.1) configuration to Confit/YAML
configuration.
"""
import os
import ConfigParser
import codecs
import yaml
import logging
import time
import itertools
import re
import beets
from beets import util
from beets import ui
from beets.util import confit
CONFIG_PATH_VAR = 'BEETSCONFIG'
DEFAULT_CONFIG_FILENAME_UNIX = '.beetsconfig'
DEFAULT_CONFIG_FILENAME_WINDOWS = 'beetsconfig.ini'
DEFAULT_LIBRARY_FILENAME_UNIX = '.beetsmusic.blb'
DEFAULT_LIBRARY_FILENAME_WINDOWS = 'beetsmusic.blb'
WINDOWS_BASEDIR = os.environ.get('APPDATA') or '~'
OLD_CONFIG_SUFFIX = '.old'
PLUGIN_NAMES = {
'rdm': 'random',
'fuzzy_search': 'fuzzy',
}
AUTO_KEYS = ('automatic', 'autofetch', 'autoembed', 'autoscrub')
IMPORTFEEDS_PREFIX = 'feeds_'
CONFIG_MIGRATED_MESSAGE = u"""
You appear to be upgrading from beets 1.0 (or earlier) to 1.1. Your
configuration file has been migrated automatically to:
{newconfig}
Edit this file to configure beets. You might want to remove your
old-style ".beetsconfig" file now. See the documentation for more
details on the new configuration system:
http://beets.readthedocs.org/page/reference/config.html
""".strip()
DB_MIGRATED_MESSAGE = u'Your database file has also been copied to:\n{newdb}'
YAML_COMMENT = '# Automatically migrated from legacy .beetsconfig.\n\n'
log = logging.getLogger('beets')
# An itertools recipe.
def grouper(n, iterable):
args = [iter(iterable)] * n
return itertools.izip_longest(*args)
def _displace(fn):
"""Move a file aside using a timestamp suffix so a new file can be
put in its place.
"""
util.move(
fn,
u'{0}.old.{1}'.format(fn, int(time.time())),
True
)
def default_paths():
"""Produces the appropriate default config and library database
paths for the current system. On Unix, this is always in ~. On
Windows, tries ~ first and then $APPDATA for the config and library
files (for backwards compatibility).
"""
windows = os.path.__name__ == 'ntpath'
if windows:
windata = os.environ.get('APPDATA') or '~'
# Shorthand for joining paths.
def exp(*vals):
return os.path.expanduser(os.path.join(*vals))
config = exp('~', DEFAULT_CONFIG_FILENAME_UNIX)
if windows and not os.path.exists(config):
config = exp(windata, DEFAULT_CONFIG_FILENAME_WINDOWS)
libpath = exp('~', DEFAULT_LIBRARY_FILENAME_UNIX)
if windows and not os.path.exists(libpath):
libpath = exp(windata, DEFAULT_LIBRARY_FILENAME_WINDOWS)
return config, libpath
def get_config():
"""Using the same logic as beets 1.0, locate and read the
.beetsconfig file. Return a ConfigParser instance or None if no
config is found.
"""
default_config, default_libpath = default_paths()
if CONFIG_PATH_VAR in os.environ:
configpath = os.path.expanduser(os.environ[CONFIG_PATH_VAR])
else:
configpath = default_config
config = ConfigParser.SafeConfigParser()
if os.path.exists(util.syspath(configpath)):
with codecs.open(configpath, 'r', encoding='utf-8') as f:
config.readfp(f)
return config, configpath
else:
return None, configpath
def flatten_config(config):
"""Given a ConfigParser, flatten the values into a dict-of-dicts
representation where each section gets its own dictionary of values.
"""
out = confit.OrderedDict()
for section in config.sections():
sec_dict = out[section] = confit.OrderedDict()
for option in config.options(section):
sec_dict[option] = config.get(section, option, True)
return out
def transform_value(value):
"""Given a string read as the value of a config option, return a
massaged version of that value (possibly with a different type).
"""
# Booleans.
if value.lower() in ('false', 'no', 'off'):
return False
elif value.lower() in ('true', 'yes', 'on'):
return True
# Integers.
try:
return int(value)
except ValueError:
pass
# Floats.
try:
return float(value)
except ValueError:
pass
return value
def transform_data(data):
"""Given a dict-of-dicts representation of legacy config data, tweak
the data into a new form. This new form is suitable for dumping as
YAML.
"""
out = confit.OrderedDict()
for section, pairs in data.items():
if section == 'beets':
# The "main" section. In the new config system, these values
# are in the "root": no section at all.
for key, value in pairs.items():
value = transform_value(value)
if key.startswith('import_'):
# Importer config is now under an "import:" key.
if 'import' not in out:
out['import'] = confit.OrderedDict()
out['import'][key[7:]] = value
elif key == 'plugins':
# Renamed plugins.
plugins = value.split()
new_plugins = [PLUGIN_NAMES.get(p, p) for p in plugins]
out['plugins'] = ' '.join(new_plugins)
elif key == 'replace':
# YAMLy representation for character replacements.
replacements = confit.OrderedDict()
for pat, repl in grouper(2, value.split()):
if repl == '<strip>':
repl = ''
replacements[pat] = repl
out['replace'] = replacements
elif key == 'pluginpath':
# Used to be a colon-separated string. Now a list.
out['pluginpath'] = value.split(':')
else:
out[key] = value
elif pairs:
# Other sections (plugins, etc).
sec_out = out[section] = confit.OrderedDict()
for key, value in pairs.items():
# Standardized "auto" option.
if key in AUTO_KEYS:
key = 'auto'
# Unnecessary : hack in queries.
if section == 'paths':
key = key.replace('_', ':')
# Changed option names for importfeeds plugin.
if section == 'importfeeds':
if key.startswith(IMPORTFEEDS_PREFIX):
key = key[len(IMPORTFEEDS_PREFIX):]
sec_out[key] = transform_value(value)
return out
class Dumper(yaml.SafeDumper):
"""A PyYAML Dumper that represents OrderedDicts as ordinary mappings
(in order, of course).
"""
# From http://pyyaml.org/attachment/ticket/161/use_ordered_dict.py
def represent_mapping(self, tag, mapping, flow_style=None):
value = []
node = yaml.MappingNode(tag, value, flow_style=flow_style)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
best_style = True
if hasattr(mapping, 'items'):
mapping = list(mapping.items())
for item_key, item_value in mapping:
node_key = self.represent_data(item_key)
node_value = self.represent_data(item_value)
if not (isinstance(node_key, yaml.ScalarNode) and \
not node_key.style):
best_style = False
if not (isinstance(node_value, yaml.ScalarNode) and \
not node_value.style):
best_style = False
value.append((node_key, node_value))
if flow_style is None:
if self.default_flow_style is not None:
node.flow_style = self.default_flow_style
else:
node.flow_style = best_style
return node
Dumper.add_representer(confit.OrderedDict, Dumper.represent_dict)
def migrate_config(replace=False):
"""Migrate a legacy beetsconfig file to a new-style config.yaml file
in an appropriate place. If `replace` is enabled, then any existing
config.yaml will be moved aside. Otherwise, the process is aborted
when the file exists.
"""
# Load legacy configuration data, if any.
config, configpath = get_config()
if not config:
log.debug(u'no config file found at {0}'.format(
util.displayable_path(configpath)
))
return
# Get the new configuration file path and possibly move it out of
# the way.
destfn = os.path.join(beets.config.config_dir(), confit.CONFIG_FILENAME)
if os.path.exists(destfn):
if replace:
log.debug(u'moving old config aside: {0}'.format(
util.displayable_path(destfn)
))
_displace(destfn)
else:
# File exists and we won't replace it. We're done.
return
log.debug(u'migrating config file {0}'.format(
util.displayable_path(configpath)
))
# Convert the configuration to a data structure ready to be dumped
# as the new Confit file.
data = transform_data(flatten_config(config))
# Encode result as YAML.
yaml_out = yaml.dump(
data,
Dumper=Dumper,
default_flow_style=False,
indent=4,
width=1000,
)
# A ridiculous little hack to add some whitespace between "sections"
# in the YAML output. I hope this doesn't break any YAML syntax.
yaml_out = re.sub(r'(\n\w+:\n [^-\s])', '\n\\1', yaml_out)
yaml_out = YAML_COMMENT + yaml_out
# Write the data to the new config destination.
log.debug(u'writing migrated config to {0}'.format(
util.displayable_path(destfn)
))
with open(destfn, 'w') as f:
f.write(yaml_out)
return destfn
def migrate_db(replace=False):
"""Copy the beets library database file to the new location (e.g.,
from ~/.beetsmusic.blb to ~/.config/beets/library.db).
"""
_, srcfn = default_paths()
destfn = beets.config['library'].as_filename()
if not os.path.exists(srcfn) or srcfn == destfn:
# Old DB does not exist or we're configured to point to the same
# database. Do nothing.
return
if os.path.exists(destfn):
if replace:
log.debug(u'moving old database aside: {0}'.format(
util.displayable_path(destfn)
))
_displace(destfn)
else:
return
log.debug(u'copying database from {0} to {1}'.format(
util.displayable_path(srcfn), util.displayable_path(destfn)
))
util.copy(srcfn, destfn)
return destfn
def migrate_state(replace=False):
"""Copy the beets runtime state file from the old path (i.e.,
~/.beetsstate) to the new path (i.e., ~/.config/beets/state.pickle).
"""
srcfn = os.path.expanduser(os.path.join('~', '.beetsstate'))
if not os.path.exists(srcfn):
return
destfn = beets.config['statefile'].as_filename()
if os.path.exists(destfn):
if replace:
_displace(destfn)
else:
return
log.debug(u'copying state file from {0} to {1}'.format(
util.displayable_path(srcfn), util.displayable_path(destfn)
))
util.copy(srcfn, destfn)
return destfn
# Automatic migration when beets starts.
def automigrate():
"""Migrate the configuration, database, and state files. If any
migration occurs, print out a notice with some helpful next steps.
"""
config_fn = migrate_config()
db_fn = migrate_db()
migrate_state()
if config_fn:
ui.print_(ui.colorize('fuchsia', u'MIGRATED CONFIGURATION'))
ui.print_(CONFIG_MIGRATED_MESSAGE.format(
newconfig=util.displayable_path(config_fn))
)
if db_fn:
ui.print_(DB_MIGRATED_MESSAGE.format(
newdb=util.displayable_path(db_fn)
))
ui.input_(ui.colorize('fuchsia', u'Press ENTER to continue:'))
ui.print_()
# CLI command for explicit migration.
migrate_cmd = ui.Subcommand('migrate', help='convert legacy config')
def migrate_func(lib, opts, args):
"""Explicit command for migrating files. Existing files in each
destination are moved aside.
"""
config_fn = migrate_config(replace=True)
if config_fn:
log.info(u'Migrated configuration to: {0}'.format(
util.displayable_path(config_fn)
))
db_fn = migrate_db(replace=True)
if db_fn:
log.info(u'Migrated library database to: {0}'.format(
util.displayable_path(db_fn)
))
state_fn = migrate_state(replace=True)
if state_fn:
log.info(u'Migrated state file to: {0}'.format(
util.displayable_path(state_fn)
))
migrate_cmd.func = migrate_func

View File

@@ -23,10 +23,13 @@ import fnmatch
from collections import defaultdict
import traceback
import subprocess
import platform
MAX_FILENAME_LENGTH = 200
WINDOWS_MAGIC_PREFIX = u'\\\\?\\'
class HumanReadableException(Exception):
"""An Exception that can include a human-readable error message to
be logged without a traceback. Can preserve a traceback for
@@ -82,6 +85,7 @@ class HumanReadableException(Exception):
logger.debug(self.tb)
logger.error(u'{0}: {1}'.format(self.error_kind, self.args[0]))
class FilesystemError(HumanReadableException):
"""An error that occurred while performing a filesystem manipulation
via a function in this module. The `paths` field is a sequence of
@@ -111,6 +115,7 @@ class FilesystemError(HumanReadableException):
return u'{0} {1}'.format(self._reasonstr(), clause)
def normpath(path):
"""Provide the canonical form of the path suitable for storing in
the database.
@@ -119,6 +124,7 @@ def normpath(path):
path = os.path.normpath(os.path.abspath(os.path.expanduser(path)))
return bytestring_path(path)
def ancestry(path):
"""Return a list consisting of path's parent directory, its
grandparent, and so on. For instance:
@@ -137,10 +143,12 @@ def ancestry(path):
break
last_path = path
if path: # don't yield ''
if path:
# don't yield ''
out.insert(0, path)
return out
def sorted_walk(path, ignore=(), logger=None):
"""Like `os.walk`, but yields things in case-insensitive sorted,
breadth-first order. Directory and file names matching any glob
@@ -192,6 +200,7 @@ def sorted_walk(path, ignore=(), logger=None):
for res in sorted_walk(cur, ignore, logger):
yield res
def mkdirall(path):
"""Make all the enclosing directories of path (like mkdir -p on the
parent).
@@ -204,6 +213,7 @@ def mkdirall(path):
raise FilesystemError(exc, 'create', (ancestor,),
traceback.format_exc())
def fnmatch_all(names, patterns):
"""Determine whether all strings in `names` match at least one of
the `patterns`, which should be shell glob expressions.
@@ -218,6 +228,7 @@ def fnmatch_all(names, patterns):
return False
return True
def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
"""If path is an empty directory, then remove it. Recursively remove
path's ancestry up to root (which is never removed) where there are
@@ -236,7 +247,7 @@ def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
ancestors = []
elif root in ancestors:
# Only remove directories below the root.
ancestors = ancestors[ancestors.index(root)+1:]
ancestors = ancestors[ancestors.index(root) + 1:]
else:
# Remove nothing.
return
@@ -258,6 +269,7 @@ def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
else:
break
def components(path):
"""Return a list of the path components in path. For instance:
@@ -281,6 +293,7 @@ def components(path):
return comps
def _fsencoding():
"""Get the system's filesystem encoding. On Windows, this is always
UTF-8 (not MBCS).
@@ -295,6 +308,7 @@ def _fsencoding():
encoding = 'utf8'
return encoding
def bytestring_path(path):
"""Given a path, which is either a str or a unicode, returns a str
path (ensuring that we never deal with Unicode pathnames).
@@ -315,6 +329,7 @@ def bytestring_path(path):
except (UnicodeError, LookupError):
return path.encode('utf8')
def displayable_path(path, separator=u'; '):
"""Attempts to decode a bytestring path to a unicode object for the
purpose of displaying it to the user. If the `path` argument is a
@@ -333,6 +348,7 @@ def displayable_path(path, separator=u'; '):
except (UnicodeError, LookupError):
return path.decode('utf8', 'ignore')
def syspath(path, prefix=True):
"""Convert a path for use by the operating system. In particular,
paths on Windows must receive a magic prefix and must be converted
@@ -356,16 +372,22 @@ def syspath(path, prefix=True):
encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
path = path.decode(encoding, 'replace')
# Add the magic prefix if it isn't already there
# Add the magic prefix if it isn't already there.
# http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
if prefix and not path.startswith(WINDOWS_MAGIC_PREFIX):
if path.startswith(u'\\\\'):
# UNC path. Final path should look like \\?\UNC\...
path = u'UNC' + path[1:]
path = WINDOWS_MAGIC_PREFIX + path
return path
def samefile(p1, p2):
"""Safer equality for paths."""
return shutil._samefile(syspath(p1), syspath(p2))
def remove(path, soft=True):
"""Remove the file. If `soft`, then no error will be raised if the
file does not exist.
@@ -378,6 +400,7 @@ def remove(path, soft=True):
except (OSError, IOError) as exc:
raise FilesystemError(exc, 'delete', (path,), traceback.format_exc())
def copy(path, dest, replace=False):
"""Copy a plain file. Permissions are not copied. If `dest` already
exists, raises a FilesystemError unless `replace` is True. Has no
@@ -396,6 +419,7 @@ def copy(path, dest, replace=False):
raise FilesystemError(exc, 'copy', (path, dest),
traceback.format_exc())
def move(path, dest, replace=False):
"""Rename a file. `dest` may not be a directory. If `dest` already
exists, raises an OSError unless `replace` is True. Has no effect if
@@ -424,6 +448,27 @@ def move(path, dest, replace=False):
raise FilesystemError(exc, 'move', (path, dest),
traceback.format_exc())
def link(path, dest, replace=False):
"""Create a symbolic link from path to `dest`. Raises an OSError if
`dest` already exists, unless `replace` is True. Does nothing if
`path` == `dest`."""
if (samefile(path, dest)):
return
path = syspath(path)
dest = syspath(dest)
if os.path.exists(dest) and not replace:
raise FilesystemError('file exists', 'rename', (path, dest),
traceback.format_exc())
try:
os.symlink(path, dest)
except OSError:
raise FilesystemError('Operating system does not support symbolic '
'links.', 'link', (path, dest),
traceback.format_exc())
def unique_path(path):
"""Returns a version of ``path`` that does not exist on the
filesystem. Specifically, if ``path` itself already exists, then
@@ -457,6 +502,8 @@ CHAR_REPLACE = [
(re.compile(ur'\.$'), u'_'), # Trailing dots.
(re.compile(ur'\s+$'), u''), # Trailing whitespace.
]
def sanitize_path(path, replacements=None):
"""Takes a path (as a Unicode string) and makes sure that it is
legal. Returns a new path. Only works with fragments; won't work
@@ -477,6 +524,7 @@ def sanitize_path(path, replacements=None):
comps[i] = comp
return os.path.join(*comps)
def truncate_path(path, length=MAX_FILENAME_LENGTH):
"""Given a bytestring path or a Unicode path fragment, truncate the
components to a legal length. In the last component, the extension
@@ -493,6 +541,7 @@ def truncate_path(path, length=MAX_FILENAME_LENGTH):
return os.path.join(*out)
def str2bool(value):
"""Returns a boolean reflecting a human-entered string."""
if value.lower() in ('yes', '1', 'true', 't', 'y'):
@@ -500,6 +549,7 @@ def str2bool(value):
else:
return False
def as_string(value):
"""Convert a value to a Unicode object for matching with a query.
None becomes the empty string. Bytestrings are silently decoded.
@@ -513,6 +563,7 @@ def as_string(value):
else:
return unicode(value)
def levenshtein(s1, s2):
"""A nice DP edit distance implementation from Wikibooks:
http://en.wikibooks.org/wiki/Algorithm_implementation/Strings/
@@ -535,6 +586,7 @@ def levenshtein(s1, s2):
return previous_row[-1]
def plurality(objs):
"""Given a sequence of comparable objects, returns the object that
is most common in the set and the frequency of that object. The
@@ -558,6 +610,7 @@ def plurality(objs):
return res, max_freq
def cpu_count():
"""Return the number of hardware thread contexts (cores or SMT
threads) in the system.
@@ -571,7 +624,7 @@ def cpu_count():
num = 0
elif sys.platform == 'darwin':
try:
num = int(os.popen('sysctl -n hw.ncpu').read())
num = int(command_output(['sysctl', '-n', 'hw.ncpu']))
except ValueError:
num = 0
else:
@@ -584,23 +637,38 @@ def cpu_count():
else:
return 1
def command_output(cmd):
"""Wraps the `subprocess` module to invoke a command (given as a
list of arguments starting with the command name) and collect
stdout. The stderr stream is ignored. May raise
`subprocess.CalledProcessError` or an `OSError`.
def command_output(cmd, shell=False):
"""Runs the command and returns its output after it has exited.
``cmd`` is a list of arguments starting with the command names. If
``shell`` is true, ``cmd`` is assumed to be a string and passed to a
shell to execute.
If the process exits with a non-zero return code
``subprocess.CalledProcessError`` is raised. May also raise
``OSError``.
This replaces `subprocess.check_output`, which isn't available in
Python 2.6 and which can have problems if lots of output is sent to
stderr.
"""
with open(os.devnull, 'w') as devnull:
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=devnull)
stdout, _ = proc.communicate()
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
close_fds=platform.system() != 'Windows',
shell=shell
)
stdout, stderr = proc.communicate()
if proc.returncode:
raise subprocess.CalledProcessError(proc.returncode, cmd)
raise subprocess.CalledProcessError(
returncode=proc.returncode,
cmd=' '.join(cmd),
)
return stdout
def max_filename_length(path, limit=MAX_FILENAME_LENGTH):
"""Attempt to determine the maximum filename length for the
filesystem containing `path`. If the value is greater than `limit`,

View File

@@ -1,5 +1,5 @@
# This file is part of beets.
# Copyright 2013, Fabrice Laporte
# Copyright 2014, Fabrice Laporte
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
@@ -18,6 +18,7 @@ public resizing proxy if neither is available.
import urllib
import subprocess
import os
import re
from tempfile import NamedTemporaryFile
import logging
from beets import util
@@ -37,7 +38,7 @@ def resize_url(url, maxwidth):
maxwidth (preserving aspect ratio).
"""
return '{0}?{1}'.format(PROXY_URL, urllib.urlencode({
'url': url.replace('http://',''),
'url': url.replace('http://', ''),
'w': str(maxwidth),
}))
@@ -76,7 +77,7 @@ def pil_resize(maxwidth, path_in, path_out=None):
def im_resize(maxwidth, path_in, path_out=None):
"""Resize using ImageMagick's ``convert`` tool.
tool. Return the output path of resized image.
Return the output path of resized image.
"""
path_out = path_out or temp_file_for(path_in)
log.debug(u'artresizer: ImageMagick resizing {0} to {1}'.format(
@@ -132,8 +133,9 @@ class ArtResizer(object):
"""Create a resizer object for the given method or, if none is
specified, with an inferred method.
"""
self.method = method or self._guess_method()
self.method = self._check_method(method)
log.debug(u"artresizer: method is {0}".format(self.method))
self.can_compare = self._can_compare()
def resize(self, maxwidth, path_in, path_out=None):
"""Manipulate an image file according to the method, returning a
@@ -141,7 +143,7 @@ class ArtResizer(object):
temporary file. For WEBPROXY, returns `path_in` unmodified.
"""
if self.local:
func = BACKEND_FUNCS[self.method]
func = BACKEND_FUNCS[self.method[0]]
return func(maxwidth, path_in, path_out)
else:
return path_in
@@ -159,30 +161,51 @@ class ArtResizer(object):
@property
def local(self):
"""A boolean indicating whether the resizing method is performed
locally (i.e., PIL or IMAGEMAGICK).
locally (i.e., PIL or ImageMagick).
"""
return self.method in BACKEND_FUNCS
return self.method[0] in BACKEND_FUNCS
def _can_compare(self):
"""A boolean indicating whether image comparison is available"""
return self.method[0] == IMAGEMAGICK and self.method[1] > (6, 8, 7)
@staticmethod
def _guess_method():
"""Determine which resizing method to use. Returns PIL,
IMAGEMAGICK, or WEBPROXY depending on available dependencies.
def _check_method(method=None):
"""A tuple indicating whether current method is available and its
version. If no method is given, it returns a supported one.
"""
# Try importing PIL.
try:
__import__('PIL', fromlist=['Image'])
return PIL
except ImportError:
pass
# Guess available method
if not method:
for m in [IMAGEMAGICK, PIL]:
_, version = ArtResizer._check_method(m)
if version:
return (m, version)
return (WEBPROXY, (0))
# Try invoking ImageMagick's "convert".
try:
out = util.command_output(['convert', '--version'])
if 'imagemagick' in out.lower():
# system32/convert.exe may be interfering
return IMAGEMAGICK
except (subprocess.CalledProcessError, OSError):
pass
if method == IMAGEMAGICK:
# Fall back to Web proxy method.
return WEBPROXY
# Try invoking ImageMagick's "convert".
try:
out = util.command_output(['identify', '--version'])
if 'imagemagick' in out.lower():
pattern = r".+ (\d+)\.(\d+)\.(\d+).*"
match = re.search(pattern, out)
if match:
return (IMAGEMAGICK,
(int(match.group(1)),
int(match.group(2)),
int(match.group(3))))
return (IMAGEMAGICK, (0))
except (subprocess.CalledProcessError, OSError):
return (IMAGEMAGICK, None)
if method == PIL:
# Try importing PIL.
try:
__import__('PIL', fromlist=['Image'])
return (PIL, (0))
except ImportError:
return (PIL, None)

View File

@@ -38,6 +38,7 @@ class Event(object):
"""
pass
class WaitableEvent(Event):
"""A waitable event is one encapsulating an action that can be
waited for using a select() call. That is, it's an event with an
@@ -57,21 +58,25 @@ class WaitableEvent(Event):
"""
pass
class ValueEvent(Event):
"""An event that does nothing but return a fixed value."""
def __init__(self, value):
self.value = value
class ExceptionEvent(Event):
"""Raise an exception at the yield point. Used internally."""
def __init__(self, exc_info):
self.exc_info = exc_info
class SpawnEvent(Event):
"""Add a new coroutine thread to the scheduler."""
def __init__(self, coro):
self.spawned = coro
class JoinEvent(Event):
"""Suspend the thread until the specified child thread has
completed.
@@ -79,11 +84,13 @@ class JoinEvent(Event):
def __init__(self, child):
self.child = child
class KillEvent(Event):
"""Unschedule a child thread."""
def __init__(self, child):
self.child = child
class DelegationEvent(Event):
"""Suspend execution of the current thread, start a new thread and,
once the child thread finished, return control to the parent
@@ -92,6 +99,7 @@ class DelegationEvent(Event):
def __init__(self, coro):
self.spawned = coro
class ReturnEvent(Event):
"""Return a value the current thread's delegator at the point of
delegation. Ends the current (delegate) thread.
@@ -99,6 +107,7 @@ class ReturnEvent(Event):
def __init__(self, value):
self.value = value
class SleepEvent(WaitableEvent):
"""Suspend the thread for a given duration.
"""
@@ -108,6 +117,7 @@ class SleepEvent(WaitableEvent):
def time_left(self):
return max(self.wakeup_time - time.time(), 0.0)
class ReadEvent(WaitableEvent):
"""Reads from a file-like object."""
def __init__(self, fd, bufsize):
@@ -120,6 +130,7 @@ class ReadEvent(WaitableEvent):
def fire(self):
return self.fd.read(self.bufsize)
class WriteEvent(WaitableEvent):
"""Writes to a file-like object."""
def __init__(self, fd, data):
@@ -192,15 +203,19 @@ def _event_select(events):
return ready_events
class ThreadException(Exception):
def __init__(self, coro, exc_info):
self.coro = coro
self.exc_info = exc_info
def reraise(self):
_reraise(self.exc_info[0], self.exc_info[1], self.exc_info[2])
SUSPENDED = Event() # Special sentinel placeholder for suspended threads.
class Delegated(Event):
"""Placeholder indicating that a thread has delegated execution to a
different thread.
@@ -208,6 +223,7 @@ class Delegated(Event):
def __init__(self, child):
self.child = child
def run(root_coro):
"""Schedules a coroutine, running it to completion. This
encapsulates the Bluelet scheduler, which the root coroutine can
@@ -329,7 +345,7 @@ def run(root_coro):
break
# Wait and fire.
event2coro = dict((v,k) for k,v in threads.items())
event2coro = dict((v, k) for k, v in threads.items())
for event in _event_select(threads.values()):
# Run the IO operation, but catch socket errors.
try:
@@ -378,6 +394,7 @@ def run(root_coro):
class SocketClosedError(Exception):
pass
class Listener(object):
"""A socket wrapper object for listening sockets.
"""
@@ -407,6 +424,7 @@ class Listener(object):
self._closed = True
self.sock.close()
class Connection(object):
"""A socket wrapper object for connected sockets.
"""
@@ -468,6 +486,7 @@ class Connection(object):
yield ReturnEvent(line)
break
class AcceptEvent(WaitableEvent):
"""An event for Listener objects (listening sockets) that suspends
execution until the socket gets a connection.
@@ -482,6 +501,7 @@ class AcceptEvent(WaitableEvent):
sock, addr = self.listener.sock.accept()
return Connection(sock, addr)
class ReceiveEvent(WaitableEvent):
"""An event for Connection objects (connected sockets) for
asynchronously reading data.
@@ -496,6 +516,7 @@ class ReceiveEvent(WaitableEvent):
def fire(self):
return self.conn.sock.recv(self.bufsize)
class SendEvent(WaitableEvent):
"""An event for Connection objects (connected sockets) for
asynchronously writing data.
@@ -523,6 +544,7 @@ def null():
"""
return ValueEvent(None)
def spawn(coro):
"""Event: add another coroutine to the scheduler. Both the parent
and child coroutines run concurrently.
@@ -531,6 +553,7 @@ def spawn(coro):
raise ValueError('%s is not a coroutine' % str(coro))
return SpawnEvent(coro)
def call(coro):
"""Event: delegate to another coroutine. The current coroutine
is resumed once the sub-coroutine finishes. If the sub-coroutine
@@ -540,12 +563,14 @@ def call(coro):
raise ValueError('%s is not a coroutine' % str(coro))
return DelegationEvent(coro)
def end(value=None):
"""Event: ends the coroutine and returns a value to its
delegator.
"""
return ReturnEvent(value)
def read(fd, bufsize=None):
"""Event: read from a file descriptor asynchronously."""
if bufsize is None:
@@ -563,10 +588,12 @@ def read(fd, bufsize=None):
else:
return ReadEvent(fd, bufsize)
def write(fd, data):
"""Event: write to a file descriptor asynchronously."""
return WriteEvent(fd, data)
def connect(host, port):
"""Event: connect to a network address and return a Connection
object for communicating on the socket.
@@ -575,17 +602,20 @@ def connect(host, port):
sock = socket.create_connection(addr)
return ValueEvent(Connection(sock, addr))
def sleep(duration):
"""Event: suspend the thread for ``duration`` seconds.
"""
return SleepEvent(duration)
def join(coro):
"""Suspend the thread until another, previously `spawn`ed thread
completes.
"""
return JoinEvent(coro)
def kill(coro):
"""Halt the execution of a different `spawn`ed thread.
"""

View File

@@ -21,6 +21,8 @@ import pkgutil
import sys
import yaml
import types
import collections
import re
try:
from collections import OrderedDict
except ImportError:
@@ -47,6 +49,7 @@ BASESTRING = str if PY3 else basestring
NUMERIC_TYPES = (int, float) if PY3 else (int, float, long)
TYPE_TYPES = (type,) if PY3 else (type, types.ClassType)
def iter_first(sequence):
"""Get the first element from an iterable or raise a ValueError if
the iterator generates no values.
@@ -67,16 +70,25 @@ class ConfigError(Exception):
"""Base class for exceptions raised when querying a configuration.
"""
class NotFoundError(ConfigError):
"""A requested value could not be found in the configuration trees.
"""
class ConfigTypeError(ConfigError, TypeError):
class ConfigValueError(ConfigError):
"""The value in the configuration is illegal."""
class ConfigTypeError(ConfigValueError):
"""The value in the configuration did not match the expected type.
"""
class ConfigValueError(ConfigError, ValueError):
"""The value in the configuration is illegal."""
class ConfigTemplateError(ConfigError):
"""Base class for exceptions raised because of an invalid template.
"""
class ConfigReadError(ConfigError):
"""A configuration file could not be read."""
@@ -132,6 +144,7 @@ class ConfigSource(dict):
else:
raise TypeError('source value must be a dict')
class ConfigView(object):
"""A configuration "view" is a query into a program's configuration
data. A view represents a hypothetical location in the configuration
@@ -207,6 +220,9 @@ class ConfigView(object):
"""
self.set({key: value})
def __contains__(self, key):
return self[key].exists()
def set_args(self, namespace):
"""Overlay parsed command-line arguments, generated by a library
like argparse or optparse, onto this view's value.
@@ -310,98 +326,6 @@ class ConfigView(object):
# Validation and conversion.
def get(self, typ=None):
"""Returns the canonical value for the view, checked against the
passed-in type. If the value is not an instance of the given
type, a ConfigTypeError is raised. May also raise a
NotFoundError.
"""
value, _ = self.first()
if typ is not None:
if not isinstance(typ, TYPE_TYPES):
raise TypeError('argument to get() must be a type')
if not isinstance(value, typ):
raise ConfigTypeError(
"{0} must be of type {1}, not {2}".format(
self.name, typ.__name__, type(value).__name__
)
)
return value
def as_filename(self):
"""Get a string as a normalized as an absolute, tilde-free path.
Relative paths are relative to the configuration directory (see
the `config_dir` method) if they come from a file. Otherwise,
they are relative to the current working directory. This helps
attain the expected behavior when using command-line options.
"""
path, source = self.first()
if not isinstance(path, BASESTRING):
raise ConfigTypeError('{0} must be a filename, not {1}'.format(
self.name, type(path).__name__
))
path = os.path.expanduser(STRING(path))
if not os.path.isabs(path) and source.filename:
# From defaults: relative to the app's directory.
path = os.path.join(self.root().config_dir(), path)
return os.path.abspath(path)
def as_choice(self, choices):
"""Ensure that the value is among a collection of choices and
return it. If `choices` is a dictionary, then return the
corresponding value rather than the value itself (the key).
"""
value = self.get()
if value not in choices:
raise ConfigValueError(
'{0} must be one of {1}, not {2}'.format(
self.name, repr(list(choices)), repr(value)
)
)
if isinstance(choices, dict):
return choices[value]
else:
return value
def as_number(self):
"""Ensure that a value is of numeric type."""
value = self.get()
if isinstance(value, NUMERIC_TYPES):
return value
raise ConfigTypeError(
'{0} must be numeric, not {1}'.format(
self.name, type(value).__name__
)
)
def as_str_seq(self):
"""Get the value as a list of strings. The underlying configured
value can be a sequence or a single string. In the latter case,
the string is treated as a white-space separated list of words.
"""
value = self.get()
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
if isinstance(value, STRING):
return value.split()
else:
try:
return list(value)
except TypeError:
raise ConfigTypeError(
'{0} must be a whitespace-separated string or '
'a list'.format(self.name)
)
def flatten(self):
"""Create a hierarchy of OrderedDicts containing the data from
this view, recursively reifying all views to get their
@@ -415,6 +339,35 @@ class ConfigView(object):
od[key] = view.get()
return od
def get(self, template=None):
"""Retrieve the value for this view according to the template.
The `template` against which the values are checked can be
anything convertible to a `Template` using `as_template`. This
means you can pass in a default integer or string value, for
example, or a type to just check that something matches the type
you expect.
May raise a `ConfigValueError` (or its subclass,
`ConfigTypeError`) or a `NotFoundError` when the configuration
doesn't satisfy the template.
"""
return as_template(template).value(self, template)
# Old validation methods (deprecated).
def as_filename(self):
return self.get(Filename())
def as_choice(self, choices):
return self.get(Choice(choices))
def as_number(self):
return self.get(Number())
def as_str_seq(self):
return self.get(StrSeq())
class RootView(ConfigView):
"""The base of a view hierarchy. This view keeps track of the
@@ -518,6 +471,7 @@ def _package_path(name):
return os.path.dirname(os.path.abspath(filepath))
def config_dirs():
"""Return a platform-specific list of candidates for user
configuration directories on the system.
@@ -606,10 +560,12 @@ class Loader(yaml.SafeLoader):
plain = super(Loader, self).check_plain()
return plain or self.peek() == '%'
Loader.add_constructor('tag:yaml.org,2002:str', Loader._construct_unicode)
Loader.add_constructor('tag:yaml.org,2002:map', Loader.construct_yaml_map)
Loader.add_constructor('tag:yaml.org,2002:omap', Loader.construct_yaml_map)
def load_yaml(filename):
"""Read a YAML document from a file. If the file cannot be read or
parsed, a ConfigReadError is raised.
@@ -679,11 +635,13 @@ class Dumper(yaml.SafeDumper):
"""
return self.represent_scalar('tag:yaml.org,2002:null', '')
Dumper.add_representer(OrderedDict, Dumper.represent_dict)
Dumper.add_representer(bool, Dumper.represent_bool)
Dumper.add_representer(type(None), Dumper.represent_none)
Dumper.add_representer(list, Dumper.represent_list)
def restore_yaml_comments(data, default_data):
"""Scan default_data for comments (we include empty lines in our
definition of comments) and place them before the same keys in data.
@@ -898,3 +856,426 @@ class LazyConfig(Configuration):
del self.sources[:]
self._lazy_suffix = []
self._lazy_prefix = []
# "Validated" configuration views: experimental!
REQUIRED = object()
"""A sentinel indicating that there is no default value and an exception
should be raised when the value is missing.
"""
class Template(object):
"""A value template for configuration fields.
The template works like a type and instructs Confit about how to
interpret a deserialized YAML value. This includes type conversions,
providing a default value, and validating for errors. For example, a
filepath type might expand tildes and check that the file exists.
"""
def __init__(self, default=REQUIRED):
"""Create a template with a given default value.
If `default` is the sentinel `REQUIRED` (as it is by default),
then an error will be raised when a value is missing. Otherwise,
missing values will instead return `default`.
"""
self.default = default
def __call__(self, view):
"""Invoking a template on a view gets the view's value according
to the template.
"""
return self.value(view, self)
def value(self, view, template=None):
"""Get the value for a `ConfigView`.
May raise a `NotFoundError` if the value is missing (and the
template requires it) or a `ConfigValueError` for invalid values.
"""
if view.exists():
value, _ = view.first()
return self.convert(value, view)
elif self.default is REQUIRED:
# Missing required value. This is an error.
raise NotFoundError("{0} not found".format(view.name))
else:
# Missing value, but not required.
return self.default
def convert(self, value, view):
"""Convert the YAML-deserialized value to a value of the desired
type.
Subclasses should override this to provide useful conversions.
May raise a `ConfigValueError` when the configuration is wrong.
"""
# Default implementation does no conversion.
return value
def fail(self, message, view, type_error=False):
"""Raise an exception indicating that a value cannot be
accepted.
`type_error` indicates whether the error is due to a type
mismatch rather than a malformed value. In this case, a more
specific exception is raised.
"""
exc_class = ConfigTypeError if type_error else ConfigValueError
raise exc_class(
'{0}: {1}'.format(view.name, message)
)
def __repr__(self):
return '{0}({1})'.format(
type(self).__name__,
'' if self.default is REQUIRED else repr(self.default),
)
class Integer(Template):
"""An integer configuration value template.
"""
def convert(self, value, view):
"""Check that the value is an integer. Floats are rounded.
"""
if isinstance(value, int):
return value
elif isinstance(value, float):
return int(value)
else:
self.fail('must be a number', view, True)
class Number(Template):
"""A numeric type: either an integer or a floating-point number.
"""
def convert(self, value, view):
"""Check that the value is an int or a float.
"""
if isinstance(value, NUMERIC_TYPES):
return value
else:
self.fail(
'must be numeric, not {0}'.format(type(value).__name__),
view,
True
)
class MappingTemplate(Template):
"""A template that uses a dictionary to specify other types for the
values for a set of keys and produce a validated `AttrDict`.
"""
def __init__(self, mapping):
"""Create a template according to a dict (mapping). The
mapping's values should themselves either be Types or
convertible to Types.
"""
subtemplates = {}
for key, typ in mapping.items():
subtemplates[key] = as_template(typ)
self.subtemplates = subtemplates
def value(self, view, template=None):
"""Get a dict with the same keys as the template and values
validated according to the value types.
"""
out = AttrDict()
for key, typ in self.subtemplates.items():
out[key] = typ.value(view[key], self)
return out
def __repr__(self):
return 'MappingTemplate({0})'.format(repr(self.subtemplates))
class String(Template):
"""A string configuration value template.
"""
def __init__(self, default=REQUIRED, pattern=None):
"""Create a template with the added optional `pattern` argument,
a regular expression string that the value should match.
"""
super(String, self).__init__(default)
self.pattern = pattern
if pattern:
self.regex = re.compile(pattern)
def convert(self, value, view):
"""Check that the value is a string and matches the pattern.
"""
if isinstance(value, BASESTRING):
if self.pattern and not self.regex.match(value):
self.fail(
"must match the pattern {0}".format(self.pattern),
view
)
return value
else:
self.fail('must be a string', view, True)
class Choice(Template):
"""A template that permits values from a sequence of choices.
"""
def __init__(self, choices):
"""Create a template that validates any of the values from the
iterable `choices`.
If `choices` is a map, then the corresponding value is emitted.
Otherwise, the value itself is emitted.
"""
self.choices = choices
def convert(self, value, view):
"""Ensure that the value is among the choices (and remap if the
choices are a mapping).
"""
if value not in self.choices:
self.fail(
'must be one of {0}, not {1}'.format(
repr(list(self.choices)), repr(value)
),
view
)
if isinstance(self.choices, collections.Mapping):
return self.choices[value]
else:
return value
def __repr__(self):
return 'Choice({0!r})'.format(self.choices)
class StrSeq(Template):
"""A template for values that are lists of strings.
Validates both actual YAML string lists and single strings. Strings
can optionally be split on whitespace.
"""
def __init__(self, split=True):
"""Create a new template.
`split` indicates whether, when the underlying value is a single
string, it should be split on whitespace. Otherwise, the
resulting value is a list containing a single string.
"""
super(StrSeq, self).__init__()
self.split = split
def convert(self, value, view):
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
if isinstance(value, STRING):
if self.split:
return value.split()
else:
return [value]
try:
value = list(value)
except TypeError:
self.fail('must be a whitespace-separated string or a list',
view, True)
def convert(x):
if isinstance(x, unicode):
return x
elif isinstance(x, BASESTRING):
return x.decode('utf8', 'ignore')
else:
self.fail('must be a list of strings', view, True)
return map(convert, value)
class Filename(Template):
"""A template that validates strings as filenames.
Filenames are returned as absolute, tilde-free paths.
Relative paths are relative to the template's `cwd` argument
when it is specified, then the configuration directory (see
the `config_dir` method) if they come from a file. Otherwise,
they are relative to the current working directory. This helps
attain the expected behavior when using command-line options.
"""
def __init__(self, default=REQUIRED, cwd=None, relative_to=None,
in_app_dir=False):
""" `relative_to` is the name of a sibling value that is
being validated at the same time.
`in_app_dir` indicates whether the path should be resolved
inside the application's config directory (even when the setting
does not come from a file).
"""
super(Filename, self).__init__(default)
self.cwd = cwd
self.relative_to = relative_to
self.in_app_dir = in_app_dir
def __repr__(self):
args = []
if self.default is not REQUIRED:
args.append(repr(self.default))
if self.cwd is not None:
args.append('cwd=' + repr(self.cwd))
if self.relative_to is not None:
args.append('relative_to=' + repr(self.relative_to))
if self.in_app_dir:
args.append('in_app_dir=True')
return 'Filename({0})'.format(', '.join(args))
def resolve_relative_to(self, view, template):
if not isinstance(template, (collections.Mapping, MappingTemplate)):
# disallow config.get(Filename(relative_to='foo'))
raise ConfigTemplateError(
'relative_to may only be used when getting multiple values.'
)
elif self.relative_to == view.key:
raise ConfigTemplateError(
'{0} is relative to itself'.format(view.name)
)
elif self.relative_to not in view.parent.keys():
# self.relative_to is not in the config
self.fail(
(
'needs sibling value "{0}" to expand relative path'
).format(self.relative_to),
view
)
old_template = {}
old_template.update(template.subtemplates)
# save time by skipping MappingTemplate's init loop
next_template = MappingTemplate({})
next_relative = self.relative_to
# gather all the needed templates and nothing else
while next_relative is not None:
try:
# pop to avoid infinite loop because of recursive
# relative paths
rel_to_template = old_template.pop(next_relative)
except KeyError:
if next_relative in template.subtemplates:
# we encountered this config key previously
raise ConfigTemplateError((
'{0} and {1} are recursively relative'
).format(view.name, self.relative_to))
else:
raise ConfigTemplateError((
'missing template for {0}, needed to expand {1}\'s' +
'relative path'
).format(self.relative_to, view.name))
next_template.subtemplates[next_relative] = rel_to_template
next_relative = rel_to_template.relative_to
return view.parent.get(next_template)[self.relative_to]
def value(self, view, template=None):
path, source = view.first()
if not isinstance(path, BASESTRING):
self.fail(
'must be a filename, not {0}'.format(type(path).__name__),
view,
True
)
path = os.path.expanduser(STRING(path))
if not os.path.isabs(path):
if self.cwd is not None:
# relative to the template's argument
path = os.path.join(self.cwd, path)
elif self.relative_to is not None:
path = os.path.join(
self.resolve_relative_to(view, template),
path,
)
elif source.filename or self.in_app_dir:
# From defaults: relative to the app's directory.
path = os.path.join(view.root().config_dir(), path)
return os.path.abspath(path)
class TypeTemplate(Template):
"""A simple template that checks that a value is an instance of a
desired Python type.
"""
def __init__(self, typ, default=REQUIRED):
"""Create a template that checks that the value is an instance
of `typ`.
"""
super(TypeTemplate, self).__init__(default)
self.typ = typ
def convert(self, value, view):
if not isinstance(value, self.typ):
self.fail(
'must be a {0}, not {1}'.format(
self.typ.__name__,
type(value).__name__,
),
view,
True
)
return value
class AttrDict(dict):
"""A `dict` subclass that can be accessed via attributes (dot
notation) for convenience.
"""
def __getattr__(self, key):
if key in self:
return self[key]
else:
raise AttributeError(key)
def as_template(value):
"""Convert a simple "shorthand" Python value to a `Template`.
"""
if isinstance(value, Template):
# If it's already a Template, pass it through.
return value
elif isinstance(value, collections.Mapping):
# Dictionaries work as templates.
return MappingTemplate(value)
elif value is int:
return Integer()
elif isinstance(value, int):
return Integer(value)
elif isinstance(value, type) and issubclass(value, BASESTRING):
return String()
elif isinstance(value, BASESTRING):
return String(value)
elif value is float:
return Number()
elif value is None:
return Template()
elif value is dict:
return TypeTemplate(collections.Mapping)
elif value is list:
return TypeTemplate(collections.Sequence)
elif isinstance(value, type):
return TypeTemplate(value)
else:
raise ValueError('cannot convert to template: {0!r}'.format(value))

View File

@@ -12,167 +12,29 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""A metaclass for enumerated types that really are types.
from enum import Enum
You can create enumerations with `enum(values, [name])` and they work
how you would expect them to.
>>> from enumeration import enum
>>> Direction = enum('north east south west', name='Direction')
>>> Direction.west
Direction.west
>>> Direction.west == Direction.west
True
>>> Direction.west == Direction.east
False
>>> isinstance(Direction.west, Direction)
True
>>> Direction[3]
Direction.west
>>> Direction['west']
Direction.west
>>> Direction.west.name
'west'
>>> Direction.north < Direction.west
True
Enumerations are classes; their instances represent the possible values
of the enumeration. Because Python classes must have names, you may
provide a `name` parameter to `enum`; if you don't, a meaningless one
will be chosen for you.
"""
import random
class Enumeration(type):
"""A metaclass whose classes are enumerations.
The `values` attribute of the class is used to populate the
enumeration. Values may either be a list of enumerated names or a
string containing a space-separated list of names. When the class
is created, it is instantiated for each name value in `values`.
Each such instance is the name of the enumerated item as the sole
argument.
The `Enumerated` class is a good choice for a superclass.
class OrderedEnum(Enum):
"""
def __init__(cls, name, bases, dic):
super(Enumeration, cls).__init__(name, bases, dic)
if 'values' not in dic:
# Do nothing if no values are provided (i.e., with
# Enumerated itself).
return
# May be called with a single string, in which case we split on
# whitespace for convenience.
values = dic['values']
if isinstance(values, basestring):
values = values.split()
# Create the Enumerated instances for each value. We have to use
# super's __setattr__ here because we disallow setattr below.
super(Enumeration, cls).__setattr__('_items_dict', {})
super(Enumeration, cls).__setattr__('_items_list', [])
for value in values:
item = cls(value, len(cls._items_list))
cls._items_dict[value] = item
cls._items_list.append(item)
def __getattr__(cls, key):
try:
return cls._items_dict[key]
except KeyError:
raise AttributeError("enumeration '" + cls.__name__ +
"' has no item '" + key + "'")
def __setattr__(cls, key, val):
raise TypeError("enumerations do not support attribute assignment")
def __getitem__(cls, key):
if isinstance(key, int):
return cls._items_list[key]
else:
return getattr(cls, key)
def __len__(cls):
return len(cls._items_list)
def __iter__(cls):
return iter(cls._items_list)
def __nonzero__(cls):
# Ensures that __len__ doesn't get called before __init__ by
# pydoc.
return True
class Enumerated(object):
"""An item in an enumeration.
Contains instance methods inherited by enumerated objects. The
metaclass is preset to `Enumeration` for your convenience.
Instance attributes:
name -- The name of the item.
index -- The index of the item in its enumeration.
>>> from enumeration import Enumerated
>>> class Garment(Enumerated):
... values = 'hat glove belt poncho lederhosen suspenders'
... def wear(self):
... print('now wearing a ' + self.name)
...
>>> Garment.poncho.wear()
now wearing a poncho
An Enum subclass that allows comparison of members.
"""
def __ge__(self, other):
if self.__class__ is other.__class__:
return self.value >= other.value
return NotImplemented
__metaclass__ = Enumeration
def __gt__(self, other):
if self.__class__ is other.__class__:
return self.value > other.value
return NotImplemented
def __init__(self, name, index):
self.name = name
self.index = index
def __le__(self, other):
if self.__class__ is other.__class__:
return self.value <= other.value
return NotImplemented
def __str__(self):
return type(self).__name__ + '.' + self.name
def __repr__(self):
return str(self)
def __cmp__(self, other):
if type(self) is type(other):
# Note that we're assuming that the items are direct
# instances of the same Enumeration (i.e., no fancy
# subclassing), which is probably okay.
return cmp(self.index, other.index)
else:
return NotImplemented
def enum(*values, **kwargs):
"""Shorthand for creating a new Enumeration class.
Call with enumeration values as a list, a space-delimited string, or
just an argument list. To give the class a name, pass it as the
`name` keyword argument. Otherwise, a name will be chosen for you.
The following are all equivalent:
enum('pinkie ring middle index thumb')
enum('pinkie', 'ring', 'middle', 'index', 'thumb')
enum(['pinkie', 'ring', 'middle', 'index', 'thumb'])
"""
if ('name' not in kwargs) or kwargs['name'] is None:
# Create a probably-unique name. It doesn't really have to be
# unique, but getting distinct names each time helps with
# identification in debugging.
name = 'Enumeration' + hex(random.randint(0,0xfffffff))[2:].upper()
else:
name = kwargs['name']
if len(values) == 1:
# If there's only one value, we have a couple of alternate calling
# styles.
if isinstance(values[0], basestring) or hasattr(values[0], '__iter__'):
values = values[0]
return type(name, (Enumerated,), {'values': values})
def __lt__(self, other):
if self.__class__ is other.__class__:
return self.value < other.value
return NotImplemented

View File

@@ -42,6 +42,7 @@ ESCAPE_CHAR = u'$'
VARIABLE_PREFIX = '__var_'
FUNCTION_PREFIX = '__func_'
class Environment(object):
"""Contains the values and functions to be substituted into a
template.
@@ -57,10 +58,12 @@ def ex_lvalue(name):
"""A variable load expression."""
return ast.Name(name, ast.Store())
def ex_rvalue(name):
"""A variable store expression."""
return ast.Name(name, ast.Load())
def ex_literal(val):
"""An int, float, long, bool, string, or None literal with the given
value.
@@ -75,6 +78,7 @@ def ex_literal(val):
return ast.Str(val)
raise TypeError('no literal for {0}'.format(type(val)))
def ex_varassign(name, expr):
"""Assign an expression into a single variable. The expression may
either be an `ast.expr` object or a value to be used as a literal.
@@ -83,6 +87,7 @@ def ex_varassign(name, expr):
expr = ex_literal(expr)
return ast.Assign([ex_lvalue(name)], expr)
def ex_call(func, args):
"""A function-call expression with only positional parameters. The
function may be an expression or the name of a function. Each
@@ -98,6 +103,7 @@ def ex_call(func, args):
return ast.Call(func, args, [], None, None)
def compile_func(arg_names, statements, name='_the_func', debug=False):
"""Compile a list of statements as the body of a function and return
the resulting Python function. If `debug`, then print out the
@@ -157,6 +163,7 @@ class Symbol(object):
expr = ex_rvalue(VARIABLE_PREFIX + self.ident.encode('utf8'))
return [expr], set([self.ident.encode('utf8')]), set()
class Call(object):
"""A function call in a template."""
def __init__(self, ident, args, original):
@@ -214,6 +221,7 @@ class Call(object):
)
return [subexpr_call], varnames, funcnames
class Expression(object):
"""Top-level template construct: contains a list of text blobs,
Symbols, and Calls.
@@ -259,6 +267,7 @@ class Expression(object):
class ParseError(Exception):
pass
class Parser(object):
"""Parses a template expression string. Instantiate the class with
the template source and call ``parse_expression``. The ``pos`` field
@@ -316,13 +325,13 @@ class Parser(object):
next_char = self.string[self.pos + 1]
if char == ESCAPE_CHAR and next_char in \
(SYMBOL_DELIM, FUNC_DELIM, GROUP_CLOSE, ARG_SEP):
(SYMBOL_DELIM, FUNC_DELIM, GROUP_CLOSE, ARG_SEP):
# An escaped special character ($$, $}, etc.). Note that
# ${ is not an escape sequence: this is ambiguous with
# the start of a symbol and it's not necessary (just
# using { suffices in all cases).
text_parts.append(next_char)
self.pos += 2 # Skip the next character.
self.pos += 2 # Skip the next character.
continue
# Shift all characters collected so far into a single string.
@@ -372,7 +381,7 @@ class Parser(object):
if next_char == GROUP_OPEN:
# A symbol like ${this}.
self.pos += 1 # Skip opening.
self.pos += 1 # Skip opening.
closer = self.string.find(GROUP_CLOSE, self.pos)
if closer == -1 or closer == self.pos:
# No closing brace found or identifier is empty.
@@ -431,7 +440,7 @@ class Parser(object):
self.parts.append(self.string[start_pos:self.pos])
return
self.pos += 1 # Move past closing brace.
self.pos += 1 # Move past closing brace.
self.parts.append(Call(ident, args, self.string[start_pos:self.pos]))
def parse_argument_list(self):
@@ -472,6 +481,7 @@ class Parser(object):
self.pos += len(ident)
return ident
def _parse(template):
"""Parse a top-level template string Expression. Any extraneous text
is considered literal text.

View File

@@ -35,13 +35,13 @@ from __future__ import print_function
import Queue
from threading import Thread, Lock
import sys
import types
BUBBLE = '__PIPELINE_BUBBLE__'
POISON = '__PIPELINE_POISON__'
DEFAULT_QUEUE_SIZE = 16
def _invalidate_queue(q, val=None, sync=True):
"""Breaks a Queue such that it never blocks, always has size 1,
and has no maximum size. get()ing from the queue returns `val`,
@@ -50,8 +50,10 @@ def _invalidate_queue(q, val=None, sync=True):
"""
def _qsize(len=len):
return 1
def _put(item):
pass
def _get():
return val
@@ -70,6 +72,7 @@ def _invalidate_queue(q, val=None, sync=True):
if sync:
q.mutex.release()
class CountedQueue(Queue.Queue):
"""A queue that keeps track of the number of threads that are
still feeding into it. The queue is poisoned when all threads are
@@ -104,6 +107,7 @@ class CountedQueue(Queue.Queue):
# Replacement _get invalidates when no items remain.
_old_get = self._get
def _get():
out = _old_get()
if not self.queue:
@@ -117,18 +121,67 @@ class CountedQueue(Queue.Queue):
# No items. Invalidate immediately.
_invalidate_queue(self, POISON, False)
class MultiMessage(object):
"""A message yielded by a pipeline stage encapsulating multiple
values to be sent to the next stage.
"""
def __init__(self, messages):
self.messages = messages
def multiple(messages):
"""Yield multiple([message, ..]) from a pipeline stage to send
multiple values to the next pipeline stage.
"""
return MultiMessage(messages)
def stage(func):
"""Decorate a function to become a simple stage.
>>> @stage
... def add(n, i):
... return i + n
>>> pipe = Pipeline([
... iter([1, 2, 3]),
... add(2),
... ])
>>> list(pipe.pull())
[3, 4, 5]
"""
def coro(*args):
task = None
while True:
task = yield task
task = func(*(args + (task,)))
return coro
def mutator_stage(func):
"""Decorate a function that manipulates items in a coroutine to
become a simple stage.
>>> @mutator_stage
... def setkey(key, item):
... item[key] = True
>>> pipe = Pipeline([
... iter([{'x': False}, {'a': False}]),
... setkey('x'),
... ])
>>> list(pipe.pull())
[{'x': True}, {'a': False, 'x': True}]
"""
def coro(*args):
task = None
while True:
task = yield task
func(*(args + (task,)))
return coro
def _allmsgs(obj):
"""Returns a list of all the messages encapsulated in obj. If obj
is a MultiMessage, returns its enclosed messages. If obj is BUBBLE,
@@ -141,6 +194,7 @@ def _allmsgs(obj):
else:
return [obj]
class PipelineThread(Thread):
"""Abstract base class for pipeline-stage threads."""
def __init__(self, all_threads):
@@ -169,6 +223,7 @@ class PipelineThread(Thread):
for thread in self.all_threads:
thread.abort()
class FirstPipelineThread(PipelineThread):
"""The thread running the first stage in a parallel pipeline setup.
The coroutine should just be a generator.
@@ -209,6 +264,7 @@ class FirstPipelineThread(PipelineThread):
# Generator finished; shut down the pipeline.
self.out_queue.release()
class MiddlePipelineThread(PipelineThread):
"""A thread running any stage in the pipeline except the first or
last.
@@ -256,6 +312,7 @@ class MiddlePipelineThread(PipelineThread):
# Pipeline is shutting down normally.
self.out_queue.release()
class LastPipelineThread(PipelineThread):
"""A thread running the last stage in a pipeline. The coroutine
should yield nothing.
@@ -291,6 +348,7 @@ class LastPipelineThread(PipelineThread):
self.abort_all(sys.exc_info())
return
class Pipeline(object):
"""Represents a staged pattern of work. Each stage in the pipeline
is a coroutine that receives messages from the previous stage and
@@ -322,7 +380,8 @@ class Pipeline(object):
messages between the stages are stored in queues of the given
size.
"""
queues = [CountedQueue(queue_size) for i in range(len(self.stages)-1)]
queue_count = len(self.stages) - 1
queues = [CountedQueue(queue_size) for i in range(queue_count)]
threads = []
# Set up first stage.
@@ -330,10 +389,10 @@ class Pipeline(object):
threads.append(FirstPipelineThread(coro, queues[0], threads))
# Middle stages.
for i in range(1, len(self.stages)-1):
for i in range(1, queue_count):
for coro in self.stages[i]:
threads.append(MiddlePipelineThread(
coro, queues[i-1], queues[i], threads
coro, queues[i - 1], queues[i], threads
))
# Last stage.
@@ -408,17 +467,20 @@ if __name__ == '__main__':
print('generating %i' % i)
time.sleep(1)
yield i
def work():
num = yield
while True:
print('processing %i' % num)
time.sleep(2)
num = yield num*2
num = yield num * 2
def consume():
while True:
num = yield
time.sleep(1)
print('received %i' % num)
ts_start = time.time()
Pipeline([produce(), work(), consume()]).run_sequential()
ts_seq = time.time()
@@ -437,6 +499,7 @@ if __name__ == '__main__':
print('generating %i' % i)
time.sleep(1)
yield i
def exc_work():
num = yield
while True:
@@ -445,10 +508,10 @@ if __name__ == '__main__':
if num == 3:
raise Exception()
num = yield num * 2
def exc_consume():
while True:
num = yield
#if num == 4:
# raise Exception()
print('received %i' % num)
Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)

View File

@@ -20,6 +20,7 @@ from beets import util
Node = namedtuple('Node', ['files', 'dirs'])
def _insert(node, path, itemid):
"""Insert an item into a virtual filesystem node."""
if len(path) == 1:
@@ -33,6 +34,7 @@ def _insert(node, path, itemid):
node.dirs[dirname] = Node({}, {})
_insert(node.dirs[dirname], rest, itemid)
def libtree(lib):
"""Generates a filesystem-like directory tree for the files
contained in `lib`. Filesystem nodes are (files, dirs) named

View File

@@ -16,6 +16,9 @@
import os.path
import logging
import imghdr
import subprocess
import platform
from tempfile import NamedTemporaryFile
from beets.plugins import BeetsPlugin
from beets import mediafile
@@ -25,6 +28,7 @@ from beets.util import syspath, normpath, displayable_path
from beets.util.artresizer import ArtResizer
from beets import config
log = logging.getLogger('beets')
@@ -36,12 +40,19 @@ class EmbedCoverArtPlugin(BeetsPlugin):
self.config.add({
'maxwidth': 0,
'auto': True,
'compare_threshold': 0,
'ifempty': False,
})
if self.config['maxwidth'].get(int) and \
not ArtResizer.shared.local:
if self.config['maxwidth'].get(int) and not ArtResizer.shared.local:
self.config['maxwidth'] = 0
log.warn(u"embedart: ImageMagick or PIL not found; "
u"'maxwidth' option ignored")
if self.config['compare_threshold'].get(int) and not \
ArtResizer.shared.can_compare:
self.config['compare_threshold'] = 0
log.warn(u"embedart: ImageMagick 6.8.7 or higher not installed; "
u"'compare_threshold' option ignored")
def commands(self):
# Embed command.
@@ -52,12 +63,15 @@ class EmbedCoverArtPlugin(BeetsPlugin):
'-f', '--file', metavar='PATH', help='the image file to embed'
)
maxwidth = config['embedart']['maxwidth'].get(int)
compare_threshold = config['embedart']['compare_threshold'].get(int)
ifempty = config['embedart']['ifempty'].get(bool)
def embed_func(lib, opts, args):
if opts.file:
imagepath = normpath(opts.file)
for item in lib.items(decargs(args)):
embed_item(item, imagepath, maxwidth)
embed_item(item, imagepath, maxwidth, None,
compare_threshold, ifempty)
else:
for album in lib.albums(decargs(args)):
embed_album(album, maxwidth)
@@ -72,7 +86,8 @@ class EmbedCoverArtPlugin(BeetsPlugin):
def extract_func(lib, opts, args):
outpath = normpath(opts.outpath or 'cover')
extract(lib, outpath, decargs(args))
item = lib.items(decargs(args)).get()
extract(outpath, item)
extract_cmd.func = extract_func
# Clear command.
@@ -91,23 +106,43 @@ def album_imported(lib, album):
"""Automatically embed art into imported albums.
"""
if album.artpath and config['embedart']['auto']:
embed_album(album, config['embedart']['maxwidth'].get(int))
embed_album(album, config['embedart']['maxwidth'].get(int), True)
def embed_item(item, imagepath, maxwidth=None, itempath=None):
def embed_item(item, imagepath, maxwidth=None, itempath=None,
compare_threshold=0, ifempty=False, as_album=False):
"""Embed an image into the item's media file.
"""
if compare_threshold:
if not check_art_similarity(item, imagepath, compare_threshold):
log.warn(u'Image not similar; skipping.')
return
if ifempty:
art = get_art(item)
if not art:
pass
else:
log.debug(u'embedart: media file contained art already {0}'.format(
displayable_path(imagepath)
))
return
if maxwidth and not as_album:
imagepath = resize_image(imagepath, maxwidth)
try:
log.debug(u'embedart: embedding {0}'.format(
displayable_path(imagepath)
))
item['images'] = [_mediafile_image(imagepath, maxwidth)]
item.try_write(itempath)
except IOError as exc:
log.error(u'embedart: could not read image file: {0}'.format(exc))
finally:
# We don't want to store the image in the database
else:
# We don't want to store the image in the database.
item.try_write(itempath)
del item['images']
def embed_album(album, maxwidth=None):
def embed_album(album, maxwidth=None, quiet=False):
"""Embed album art into all of the album's items.
"""
imagepath = album.artpath
@@ -115,39 +150,78 @@ def embed_album(album, maxwidth=None):
log.info(u'No album art present: {0} - {1}'.
format(album.albumartist, album.album))
return
if not os.path.isfile(imagepath):
if not os.path.isfile(syspath(imagepath)):
log.error(u'Album art not found at {0}'
.format(imagepath))
.format(displayable_path(imagepath)))
return
if maxwidth:
imagepath = resize_image(imagepath, maxwidth)
log.info(u'Embedding album art into {0.albumartist} - {0.album}.'
.format(album))
log.log(
logging.DEBUG if quiet else logging.INFO,
u'Embedding album art into {0.albumartist} - {0.album}.'.format(album),
)
for item in album.items():
embed_item(item, imagepath, maxwidth)
embed_item(item, imagepath, maxwidth, None,
config['embedart']['compare_threshold'].get(int),
config['embedart']['ifempty'].get(bool), as_album=True)
def resize_image(imagepath, maxwidth):
"""Returns path to an image resized to maxwidth.
"""
log.info(u'Resizing album art to {0} pixels wide'
.format(maxwidth))
imagepath = ArtResizer.shared.resize(maxwidth, syspath(imagepath))
return imagepath
def check_art_similarity(item, imagepath, compare_threshold):
"""A boolean indicating if an image is similar to embedded item art.
"""
with NamedTemporaryFile(delete=True) as f:
art = extract(f.name, item)
if art:
# Converting images to grayscale tends to minimize the weight
# of colors in the diff score
cmd = 'convert {0} {1} -colorspace gray MIFF:- | ' \
'compare -metric PHASH - null:'.format(syspath(imagepath),
syspath(art))
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
close_fds=platform.system() != 'Windows',
shell=True)
stdout, stderr = proc.communicate()
if proc.returncode:
if proc.returncode != 1:
log.warn(u'embedart: IM phashes compare failed for {0}, \
{1}'.format(displayable_path(imagepath),
displayable_path(art)))
return
phashDiff = float(stderr)
else:
phashDiff = float(stdout)
log.info(u'embedart: compare PHASH score is {0}'.format(phashDiff))
if phashDiff > compare_threshold:
return False
return True
def _mediafile_image(image_path, maxwidth=None):
"""Return a `mediafile.Image` object for the path.
If maxwidth is set the image is resized if necessary.
"""
if maxwidth:
image_path = ArtResizer.shared.resize(maxwidth, syspath(image_path))
with open(syspath(image_path), 'rb') as f:
data = f.read()
return mediafile.Image(data, type=mediafile.ImageType.front)
# 'extractart' command.
def extract(lib, outpath, query):
item = lib.items(query).get()
if not item:
log.error(u'No item matches query.')
return
def get_art(item):
# Extract the art.
try:
mf = mediafile.MediaFile(syspath(item.path))
@@ -157,7 +231,18 @@ def extract(lib, outpath, query):
))
return
art = mf.art
return mf.art
# 'extractart' command.
def extract(outpath, item):
if not item:
log.error(u'No item matches query.')
return
art = get_art(item)
if not art:
log.error(u'No album art present in {0} - {1}.'
.format(item.artist, item.title))
@@ -170,10 +255,11 @@ def extract(lib, outpath, query):
return
outpath += '.' + ext
log.info(u'Extracting album art from: {0.artist} - {0.title}\n'
u'To: {1}'.format(item, displayable_path(outpath)))
log.info(u'Extracting album art from: {0.artist} - {0.title} '
u'to: {1}'.format(item, displayable_path(outpath)))
with open(syspath(outpath), 'wb') as f:
f.write(art)
return outpath
# 'clearart' command.
@@ -190,5 +276,5 @@ def clear(lib, query):
displayable_path(item.path), exc
))
continue
mf.art = None
del mf.art
mf.save()

View File

@@ -1,5 +1,5 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
@@ -22,12 +22,18 @@ from tempfile import NamedTemporaryFile
import requests
from beets.plugins import BeetsPlugin
from beets.util.artresizer import ArtResizer
from beets import plugins
from beets import importer
from beets import ui
from beets import util
from beets import config
from beets.util.artresizer import ArtResizer
try:
import itunes
HAVE_ITUNES = True
except ImportError:
HAVE_ITUNES = False
IMAGE_EXTENSIONS = ['png', 'jpg', 'jpeg']
CONTENT_TYPES = ('image/jpeg',)
@@ -73,17 +79,14 @@ CAA_URL = 'http://coverartarchive.org/release/{mbid}/front-500.jpg'
CAA_GROUP_URL = 'http://coverartarchive.org/release-group/{mbid}/front-500.jpg'
def caa_art(release_id):
"""Return the Cover Art Archive URL given a MusicBrainz release ID.
def caa_art(album):
"""Return the Cover Art Archive and Cover Art Archive release group URLs
using album MusicBrainz release ID and release group ID.
"""
return CAA_URL.format(mbid=release_id)
def caa_group_art(release_group_id):
"""Return the Cover Art Archive release group URL given a MusicBrainz
release group ID.
"""
return CAA_GROUP_URL.format(mbid=release_group_id)
if album.mb_albumid:
yield CAA_URL.format(mbid=album.mb_albumid)
if album.mb_releasegroupid:
yield CAA_GROUP_URL.format(mbid=album.mb_releasegroupid)
# Art from Amazon.
@@ -92,10 +95,12 @@ AMAZON_URL = 'http://images.amazon.com/images/P/%s.%02i.LZZZZZZZ.jpg'
AMAZON_INDICES = (1, 2)
def art_for_asin(asin):
"""Generate URLs for an Amazon ID (ASIN) string."""
for index in AMAZON_INDICES:
yield AMAZON_URL % (asin, index)
def art_for_asin(album):
"""Generate URLs using Amazon ID (ASIN) string.
"""
if album.asin:
for index in AMAZON_INDICES:
yield AMAZON_URL % (album.asin, index)
# AlbumArt.org scraper.
@@ -104,11 +109,14 @@ AAO_URL = 'http://www.albumart.org/index_detail.php'
AAO_PAT = r'href\s*=\s*"([^>"]*)"[^>]*title\s*=\s*"View larger image"'
def aao_art(asin):
"""Return art URL from AlbumArt.org given an ASIN."""
def aao_art(album):
"""Return art URL from AlbumArt.org using album ASIN.
"""
if not album.asin:
return
# Get the page from albumart.org.
try:
resp = requests_session.get(AAO_URL, params={'asin': asin})
resp = requests_session.get(AAO_URL, params={'asin': album.asin})
log.debug(u'fetchart: scraped art URL: {0}'.format(resp.url))
except requests.RequestException:
log.debug(u'fetchart: error scraping art page')
@@ -118,7 +126,7 @@ def aao_art(asin):
m = re.search(AAO_PAT, resp.text)
if m:
image_url = m.group(1)
return image_url
yield image_url
else:
log.debug(u'fetchart: no image found on page')
@@ -132,6 +140,8 @@ def google_art(album):
"""Return art URL from google.org given an album title and
interpreter.
"""
if not (album.albumartist and album.album):
return
search_string = (album.albumartist + ',' + album.album).encode('utf-8')
response = requests_session.get(GOOGLE_URL, params={
'v': '1.0',
@@ -145,14 +155,39 @@ def google_art(album):
data = results['responseData']
dataInfo = data['results']
for myUrl in dataInfo:
return myUrl['unescapedUrl']
yield myUrl['unescapedUrl']
except:
log.debug(u'fetchart: error scraping art page')
return
# Art from the iTunes Store.
def itunes_art(album):
"""Return art URL from iTunes Store given an album title.
"""
search_string = (album.albumartist + ' ' + album.album).encode('utf-8')
try:
# Isolate bugs in the iTunes library while searching.
try:
itunes_album = itunes.search_album(search_string)[0]
except Exception as exc:
log.debug('fetchart: iTunes search failed: {0}'.format(exc))
return
if itunes_album.get_artwork()['100']:
small_url = itunes_album.get_artwork()['100']
big_url = small_url.replace('100x100', '1200x1200')
yield big_url
else:
log.debug(u'fetchart: album has no artwork in iTunes Store')
except IndexError:
log.debug(u'fetchart: album not found in iTunes Store')
# Art from the filesystem.
def filename_priority(filename, cover_names):
"""Sort order for image names.
@@ -164,7 +199,8 @@ def filename_priority(filename, cover_names):
def art_in_path(path, cover_names, cautious):
"""Look for album art files in a specified directory."""
"""Look for album art files in a specified directory.
"""
if not os.path.isdir(path):
return
@@ -195,31 +231,27 @@ def art_in_path(path, cover_names, cautious):
# Try each source in turn.
SOURCES_ALL = [u'coverart', u'itunes', u'amazon', u'albumart', u'google']
def _source_urls(album):
ART_FUNCS = {
u'coverart': caa_art,
u'itunes': itunes_art,
u'albumart': aao_art,
u'amazon': art_for_asin,
u'google': google_art,
}
def _source_urls(album, sources=SOURCES_ALL):
"""Generate possible source URLs for an album's art. The URLs are
not guaranteed to work so they each need to be attempted in turn.
This allows the main `art_for_album` function to abort iteration
through this sequence early to avoid the cost of scraping when not
necessary.
"""
# Cover Art Archive.
if album.mb_albumid:
yield caa_art(album.mb_albumid)
if album.mb_releasegroupid:
yield caa_group_art(album.mb_releasegroupid)
# Amazon and AlbumArt.org.
if album.asin:
for url in art_for_asin(album.asin):
yield url
url = aao_art(album.asin)
if url:
yield url
if config['fetchart']['google_search']:
url = google_art(album)
if url:
for s in sources:
urls = ART_FUNCS[s](album)
for url in urls:
yield url
@@ -245,7 +277,8 @@ def art_for_album(album, paths, maxwidth=None, local_only=False):
# Web art sources.
remote_priority = config['fetchart']['remote_priority'].get(bool)
if not local_only and (remote_priority or not out):
for url in _source_urls(album):
for url in _source_urls(album,
config['fetchart']['sources'].as_str_seq()):
if maxwidth:
url = ArtResizer.shared.proxy_url(maxwidth, url)
candidate = _fetch_image(url)
@@ -286,7 +319,7 @@ def batch_fetch_art(lib, albums, force, maxwidth=None):
message))
class FetchArtPlugin(BeetsPlugin):
class FetchArtPlugin(plugins.BeetsPlugin):
def __init__(self):
super(FetchArtPlugin, self).__init__()
@@ -297,6 +330,7 @@ class FetchArtPlugin(BeetsPlugin):
'cautious': False,
'google_search': False,
'cover_names': ['cover', 'front', 'art', 'album', 'folder'],
'sources': SOURCES_ALL,
})
# Holds paths to downloaded images between fetching them and
@@ -309,6 +343,12 @@ class FetchArtPlugin(BeetsPlugin):
self.import_stages = [self.fetch_art]
self.register_listener('import_task_files', self.assign_art)
available_sources = list(SOURCES_ALL)
if not HAVE_ITUNES and u'itunes' in available_sources:
available_sources.remove(u'itunes')
self.config['sources'] = plugins.sanitize_choices(
self.config['sources'].as_str_seq(), available_sources)
# Asynchronous; after music is added to the library.
def fetch_art(self, session, task):
"""Find art for the album being imported."""

View File

@@ -18,25 +18,26 @@ from __future__ import print_function
import re
import logging
import urllib
import requests
import json
import unicodedata
import urllib
import difflib
import itertools
from HTMLParser import HTMLParseError
from beets.plugins import BeetsPlugin
from beets import ui
from beets import config
from beets import plugins
from beets import config, ui
# Global logger.
log = logging.getLogger('beets')
DIV_RE = re.compile(r'<(/?)div>?')
DIV_RE = re.compile(r'<(/?)div>?', re.I)
COMMENT_RE = re.compile(r'<!--.*-->', re.S)
TAG_RE = re.compile(r'<[^>]*>')
BREAK_RE = re.compile(r'<br\s*/?>')
BREAK_RE = re.compile(r'\n?\s*<br([\s|/][^>]*)*>\s*\n?', re.I)
URL_CHARACTERS = {
u'\u2018': u"'",
u'\u2019': u"'",
@@ -60,10 +61,14 @@ def fetch_url(url):
is unreachable.
"""
try:
return urllib.urlopen(url).read()
except IOError as exc:
log.debug(u'failed to fetch: {0} ({1})'.format(url, unicode(exc)))
return None
r = requests.get(url, verify=False)
except requests.RequestException as exc:
log.debug(u'lyrics request failed: {0}'.format(exc))
return
if r.status_code == requests.codes.ok:
return r.text
else:
log.debug(u'failed to fetch: {0} ({1})'.format(url, r.status_code))
def unescape(text):
@@ -79,10 +84,20 @@ def unescape(text):
return out
def extract_text(html, starttag):
def extract_text_between(html, start_marker, end_marker):
try:
_, html = html.split(start_marker, 1)
html, _ = html.split(end_marker, 1)
except ValueError:
return u''
return html
def extract_text_in(html, starttag):
"""Extract the text from a <DIV> tag in the HTML starting with
``starttag``. Returns None if parsing fails.
"""
# Strip off the leading text before opening tag.
try:
_, html = html.split(starttag, 1)
@@ -101,7 +116,6 @@ def extract_text(html, starttag):
else: # Opening tag.
if level == 0:
parts.append(html[pos:match.start()])
level += 1
if level == -1:
@@ -110,26 +124,7 @@ def extract_text(html, starttag):
else:
print('no closing tag found!')
return
lyrics = ''.join(parts)
return strip_cruft(lyrics)
def strip_cruft(lyrics, wscollapse=True):
"""Clean up HTML from an extracted lyrics string. For example, <BR>
tags are replaced with newlines.
"""
lyrics = COMMENT_RE.sub('', lyrics)
lyrics = unescape(lyrics)
if wscollapse:
lyrics = re.sub(r'\s+', ' ', lyrics) # Whitespace collapse.
lyrics = re.sub(r'<(script).*?</\1>(?s)', '', lyrics) # Strip script tags.
lyrics = BREAK_RE.sub('\n', lyrics) # <BR> newlines.
lyrics = re.sub(r'\n +', '\n', lyrics)
lyrics = re.sub(r' +\n', '\n', lyrics)
lyrics = TAG_RE.sub('', lyrics) # Strip remaining HTML tags.
lyrics = lyrics.replace('\r', '\n')
lyrics = lyrics.strip()
return lyrics
return u''.join(parts)
def search_pairs(item):
@@ -140,7 +135,7 @@ def search_pairs(item):
In addition to the artist and title obtained from the `item` the
method tries to strip extra information like paranthesized suffixes
and featured artists from the strings and add them as caniddates.
and featured artists from the strings and add them as candidates.
The method also tries to split multiple titles separated with `/`.
"""
@@ -149,7 +144,7 @@ def search_pairs(item):
artists = [artist]
# Remove any featuring artists from the artists name
pattern = r"(.*?) (&|\b(and|ft|feat(uring)?\b))"
pattern = r"(.*?) {0}".format(plugins.feat_tokens())
match = re.search(pattern, artist, re.IGNORECASE)
if match:
artists.append(match.group(1))
@@ -162,8 +157,8 @@ def search_pairs(item):
titles.append(match.group(1))
# Remove any featuring artists from the title
pattern = r"(.*?) \b(ft|feat(uring)?)\b"
for title in titles:
pattern = r"(.*?) {0}".format(plugins.feat_tokens(for_artist=False))
for title in titles[:]:
match = re.search(pattern, title, re.IGNORECASE)
if match:
titles.append(match.group(1))
@@ -189,6 +184,19 @@ def _encode(s):
s = s.encode('utf8', 'ignore')
return urllib.quote(s)
# Musixmatch
MUSIXMATCH_URL_PATTERN = 'https://www.musixmatch.com/lyrics/%s/%s'
def fetch_musixmatch(artist, title):
url = MUSIXMATCH_URL_PATTERN % (_lw_encode(artist.title()),
_lw_encode(title.title()))
html = fetch_url(url)
if not html:
return
lyrics = extract_text_between(html, '"lyrics_body":', '"lyrics_language":')
return lyrics.strip(',"').replace('\\n', '\n')
# LyricsWiki.
@@ -212,7 +220,7 @@ def fetch_lyricswiki(artist, title):
if not html:
return
lyrics = extract_text(html, "<div class='lyricbox'>")
lyrics = extract_text_in(html, u"<div class='lyricbox'>")
if lyrics and 'Unfortunately, we are not licensed' not in lyrics:
return lyrics
@@ -238,8 +246,8 @@ def fetch_lyricscom(artist, title):
html = fetch_url(url)
if not html:
return
lyrics = extract_text(html, '<div id="lyric_space">')
lyrics = extract_text_between(html, '<div id="lyrics" class="SCREENONLY" '
'itemprop="description">', '</div>')
if not lyrics:
return
for not_found_str in LYRICSCOM_NOT_FOUND:
@@ -280,7 +288,6 @@ def is_page_candidate(urlLink, urlTitle, title, artist):
artist = slugify(artist.lower())
sitename = re.search(u"//([^/]+)/.*", slugify(urlLink.lower())).group(1)
urlTitle = slugify(urlTitle.lower())
# Check if URL title contains song title (exact match)
if urlTitle.find(title) != -1:
return True
@@ -289,41 +296,11 @@ def is_page_candidate(urlLink, urlTitle, title, artist):
tokens = [by + '_' + artist for by in BY_TRANS] + \
[artist, sitename, sitename.replace('www.', '')] + LYRICS_TRANS
songTitle = re.sub(u'(%s)' % u'|'.join(tokens), u'', urlTitle)
typoRatio = .8
songTitle = songTitle.strip('_|')
typoRatio = .9
return difflib.SequenceMatcher(None, songTitle, title).ratio() >= typoRatio
def insert_line_feeds(text):
"""Insert newlines before upper-case characters.
"""
tokensStr = re.split("([a-z][A-Z])", text)
for idx in range(1, len(tokensStr), 2):
ltoken = list(tokensStr[idx])
tokensStr[idx] = ltoken[0] + '\n' + ltoken[1]
return ''.join(tokensStr)
def sanitize_lyrics(text):
"""Clean text, returning raw lyrics as output or None if it happens
that input text is actually not lyrics content. Clean (x)html tags
in text, correct layout and syntax...
"""
text = strip_cruft(text, False)
# Restore \n in input text
if '\n' not in text:
text = insert_line_feeds(text)
while text.count('\n\n') > text.count('\n') // 4:
# Remove first occurrence of \n for each sequence of \n
text = re.sub(r'\n(\n+)', '\g<1>', text)
text = re.sub(r'\n\n+', '\n\n', text) # keep at most two \n in a row
return text
def remove_credits(text):
"""Remove first/last line of text if it contains the word 'lyrics'
eg 'Lyrics by songsdatabase.com'
@@ -342,13 +319,12 @@ def is_lyrics(text, artist=None):
"""Determine whether the text seems to be valid lyrics.
"""
if not text:
return
return False
badTriggersOcc = []
nbLines = text.count('\n')
if nbLines <= 1:
log.debug(u"Ignoring too short lyrics '{0}'".format(text))
return 0
return False
elif nbLines < 5:
badTriggersOcc.append('too_short')
else:
@@ -356,7 +332,7 @@ def is_lyrics(text, artist=None):
# down
text = remove_credits(text)
badTriggers = ['lyrics', 'copyright', 'property']
badTriggers = ['lyrics', 'copyright', 'property', 'links']
if artist:
badTriggersOcc += [artist]
@@ -366,62 +342,58 @@ def is_lyrics(text, artist=None):
if badTriggersOcc:
log.debug(u'Bad triggers detected: {0}'.format(badTriggersOcc))
return len(badTriggersOcc) < 2
def scrape_lyrics_from_url(url):
def _scrape_strip_cruft(html, plain_text_out=False):
"""Clean up HTML
"""
html = unescape(html)
html = html.replace('\r', '\n') # Normalize EOL.
html = re.sub(r' +', ' ', html) # Whitespaces collapse.
html = BREAK_RE.sub('\n', html) # <br> eats up surrounding '\n'.
html = re.sub(r'<(script).*?</\1>(?s)', '', html) # Strip script tags.
if plain_text_out: # Strip remaining HTML tags
html = COMMENT_RE.sub('', html)
html = TAG_RE.sub('', html)
html = '\n'.join([x.strip() for x in html.strip().split('\n')])
html = re.sub(r'\n{3,}', r'\n\n', html)
return html
def _scrape_merge_paragraphs(html):
html = re.sub(r'</p>\s*<p(\s*[^>]*)>', '\n', html)
return re.sub(r'<div .*>\s*</div>', '\n', html)
def scrape_lyrics_from_html(html):
"""Scrape lyrics from a URL. If no lyrics can be found, return None
instead.
"""
from bs4 import BeautifulSoup, Comment
html = fetch_url(url)
from bs4 import SoupStrainer, BeautifulSoup
if not html:
return None
soup = BeautifulSoup(html)
for tag in soup.findAll('br'):
tag.replaceWith('\n')
# Remove non relevant html parts
[s.extract() for s in soup(['head', 'script'])]
comments = soup.findAll(text=lambda text: isinstance(text, Comment))
[s.extract() for s in comments]
def is_text_notcode(text):
length = len(text)
return (length > 20 and
text.count(' ') > length / 25 and
(text.find('{') == -1 or text.find(';') == -1))
html = _scrape_strip_cruft(html)
html = _scrape_merge_paragraphs(html)
# extract all long text blocks that are not code
try:
for tag in soup.findAll(True):
tag.name = 'p' # keep tag contents
except Exception, e:
log.debug(u'Error {0} when replacing containing marker by p marker'
.format(e, exc_info=True))
# Make better soup from current soup! The previous unclosed <p> sections
# are now closed. Use str() rather than prettify() as it's more
# conservative concerning EOL
soup = BeautifulSoup(str(soup))
# In case lyrics are nested in no markup but <body>
# Insert the whole body in a <p>
bodyTag = soup.find('body')
if bodyTag:
pTag = soup.new_tag("p")
bodyTag.parent.insert(0, pTag)
pTag.insert(0, bodyTag)
tagTokens = []
for tag in soup.findAll('p'):
soup2 = BeautifulSoup(str(tag))
# Extract all text of <p> section.
tagTokens += soup2.findAll(text=True)
if tagTokens:
# Lyrics are expected to be the longest paragraph
tagTokens = sorted(tagTokens, key=len, reverse=True)
soup = BeautifulSoup(tagTokens[0])
return unescape(tagTokens[0].strip("\n\r: "))
soup = BeautifulSoup(html, "html.parser",
parse_only=SoupStrainer(text=is_text_notcode))
except HTMLParseError:
return None
soup = sorted(soup.stripped_strings, key=len)[-1]
return soup
def fetch_google(artist, title):
@@ -443,15 +415,14 @@ def fetch_google(artist, title):
if 'items' in data.keys():
for item in data['items']:
urlLink = item['link']
urlTitle = item['title']
urlTitle = item.get('title', u'')
if not is_page_candidate(urlLink, urlTitle, title, artist):
continue
lyrics = scrape_lyrics_from_url(urlLink)
html = fetch_url(urlLink)
lyrics = scrape_lyrics_from_html(html)
if not lyrics:
continue
lyrics = sanitize_lyrics(lyrics)
if is_lyrics(lyrics, artist):
log.debug(u'got lyrics from {0}'.format(item['displayLink']))
return lyrics
@@ -459,8 +430,16 @@ def fetch_google(artist, title):
# Plugin logic.
SOURCES = ['google', 'lyricwiki', 'lyrics.com', 'musixmatch']
SOURCE_BACKENDS = {
'google': fetch_google,
'lyricwiki': fetch_lyricswiki,
'lyrics.com': fetch_lyricscom,
'musixmatch': fetch_musixmatch,
}
class LyricsPlugin(BeetsPlugin):
class LyricsPlugin(plugins.BeetsPlugin):
def __init__(self):
super(LyricsPlugin, self).__init__()
self.import_stages = [self.imported]
@@ -469,12 +448,19 @@ class LyricsPlugin(BeetsPlugin):
'google_API_key': None,
'google_engine_ID': u'009217259823014548361:lndtuqkycfu',
'fallback': None,
'force': False,
'sources': SOURCES,
})
self.backends = [fetch_lyricswiki, fetch_lyricscom]
if self.config['google_API_key'].get():
self.backends.insert(0, fetch_google)
available_sources = list(SOURCES)
if not self.config['google_API_key'].get() and \
'google' in SOURCES:
available_sources.remove('google')
self.config['sources'] = plugins.sanitize_choices(
self.config['sources'].as_str_seq(), available_sources)
self.backends = []
for key in self.config['sources'].as_str_seq():
self.backends.append(SOURCE_BACKENDS[key])
def commands(self):
cmd = ui.Subcommand('lyrics', help='fetch song lyrics')
@@ -490,8 +476,10 @@ class LyricsPlugin(BeetsPlugin):
# import_write config value.
write = config['import']['write'].get(bool)
for item in lib.items(ui.decargs(args)):
self.fetch_item_lyrics(lib, logging.INFO, item, write,
opts.force_refetch)
self.fetch_item_lyrics(
lib, logging.INFO, item, write,
opts.force_refetch or self.config['force'],
)
if opts.printlyr and item.lyrics:
ui.print_(item.lyrics)
@@ -504,7 +492,7 @@ class LyricsPlugin(BeetsPlugin):
if self.config['auto']:
for item in task.imported_items():
self.fetch_item_lyrics(session.lib, logging.DEBUG, item,
False, False)
False, self.config['force'])
def fetch_item_lyrics(self, lib, loglevel, item, write, force):
"""Fetch and store lyrics for a single item. If ``write``, then the
@@ -551,8 +539,6 @@ class LyricsPlugin(BeetsPlugin):
for backend in self.backends:
lyrics = backend(artist, title)
if lyrics:
if isinstance(lyrics, str):
lyrics = lyrics.decode('utf8', 'ignore')
log.debug(u'got lyrics from backend: {0}'
.format(backend.__name__))
return lyrics.strip()
return _scrape_strip_cruft(lyrics, True)