From 046d4d82b49c721d68e0561bc08621e6a2d6ed20 Mon Sep 17 00:00:00 2001
From: AdeHub
Date: Sat, 24 Aug 2024 16:44:41 +1200
Subject: [PATCH] update Beets
---
lib/beets/__init__.py | 15 +-
lib/beets/__main__.py | 1 +
lib/beets/art.py | 220 +-
lib/beets/autotag/__init__.py | 161 +-
lib/beets/autotag/hooks.py | 426 ++--
lib/beets/autotag/match.py | 394 ++--
lib/beets/autotag/mb.py | 807 ++++---
lib/beets/config_default.yaml | 155 +-
lib/beets/dbcore/__init__.py | 20 +-
lib/beets/dbcore/db.py | 670 ++++--
lib/beets/dbcore/query.py | 719 +++---
lib/beets/dbcore/queryparse.py | 156 +-
lib/beets/dbcore/types.py | 229 +-
lib/beets/importer.py | 765 ++++---
lib/beets/library.py | 1373 +++++++-----
lib/beets/logging.py | 106 +-
lib/beets/mediafile.py | 14 +-
lib/beets/plugins.py | 208 +-
lib/beets/random.py | 15 +-
.../{util/confit.py => test/__init__.py} | 21 +-
lib/beets/test/_common.py | 328 +++
lib/beets/test/helper.py | 1003 +++++++++
lib/beets/ui/__init__.py | 1188 +++++++---
lib/beets/ui/commands.py | 1963 +++++++++++------
lib/beets/ui/completion_base.sh | 33 +-
lib/beets/util/__init__.py | 716 +++---
lib/beets/util/artresizer.py | 837 ++++---
lib/beets/util/bluelet.py | 101 +-
lib/beets/util/enumeration.py | 1 +
lib/beets/util/functemplate.py | 221 +-
lib/beets/util/hidden.py | 81 +-
lib/beets/util/id_extractors.py | 65 +
lib/beets/util/m3u.py | 97 +
lib/beets/util/pipeline.py | 57 +-
lib/beets/vfs.py | 3 +-
lib/beetsplug/__init__.py | 1 +
lib/beetsplug/absubmit.py | 141 +-
lib/beetsplug/acousticbrainz.py | 292 ++-
lib/beetsplug/advancedrewrite.py | 174 ++
lib/beetsplug/albumtypes.py | 44 +-
lib/beetsplug/aura.py | 248 ++-
lib/beetsplug/autobpm.py | 92 +
lib/beetsplug/badfiles.py | 96 +-
lib/beetsplug/bareasc.py | 41 +-
lib/beetsplug/beatport.py | 337 +--
lib/beetsplug/bench.py | 114 +-
lib/beetsplug/bpd/__init__.py | 686 +++---
lib/beetsplug/bpd/gstplayer.py | 42 +-
lib/beetsplug/bpm.py | 40 +-
lib/beetsplug/bpsync.py | 61 +-
lib/beetsplug/bucket.py | 167 +-
lib/beetsplug/chroma.py | 199 +-
lib/beetsplug/convert.py | 694 ++++--
lib/beetsplug/deezer.py | 202 +-
lib/beetsplug/discogs.py | 571 +++--
lib/beetsplug/duplicates.py | 321 +--
lib/beetsplug/edit.py | 153 +-
lib/beetsplug/embedart.py | 231 +-
lib/beetsplug/embyupdate.py | 117 +-
lib/beetsplug/export.py | 177 +-
lib/beetsplug/fetchart.py | 1113 ++++++----
lib/beetsplug/filefilter.py | 39 +-
lib/beetsplug/fish.py | 258 ++-
lib/beetsplug/freedesktop.py | 17 +-
lib/beetsplug/fromfilename.py | 107 +-
lib/beetsplug/ftintitle.py | 54 +-
lib/beetsplug/fuzzy.py | 25 +-
lib/beetsplug/gmusic.py | 8 +-
lib/beetsplug/hook.py | 66 +-
lib/beetsplug/ihate.py | 44 +-
lib/beetsplug/importadded.py | 122 +-
lib/beetsplug/importfeeds.py | 102 +-
lib/beetsplug/info.py | 79 +-
lib/beetsplug/inline.py | 56 +-
lib/beetsplug/ipfs.py | 124 +-
lib/beetsplug/keyfinder.py | 61 +-
lib/beetsplug/kodiupdate.py | 78 +-
lib/beetsplug/lastgenre/__init__.py | 247 ++-
lib/beetsplug/lastgenre/genres.txt | 1 -
lib/beetsplug/lastimport.py | 237 +-
lib/beetsplug/limit.py | 96 +
lib/beetsplug/listenbrainz.py | 266 +++
lib/beetsplug/loadext.py | 15 +-
lib/beetsplug/lyrics.py | 778 ++++---
lib/beetsplug/mbcollection.py | 106 +-
lib/beetsplug/mbsubmit.py | 70 +-
lib/beetsplug/mbsync.py | 103 +-
lib/beetsplug/metasync/__init__.py | 62 +-
lib/beetsplug/metasync/amarok.py | 68 +-
lib/beetsplug/metasync/itunes.py | 93 +-
lib/beetsplug/missing.py | 177 +-
lib/beetsplug/mpdstats.py | 241 +-
lib/beetsplug/mpdupdate.py | 69 +-
lib/beetsplug/parentwork.py | 164 +-
lib/beetsplug/permissions.py | 69 +-
lib/beetsplug/play.py | 164 +-
lib/beetsplug/playlist.py | 139 +-
lib/beetsplug/plexupdate.py | 112 +-
lib/beetsplug/random.py | 37 +-
lib/beetsplug/replaygain.py | 1281 ++++++-----
lib/beetsplug/rewrite.py | 16 +-
lib/beetsplug/scrub.py | 107 +-
lib/beetsplug/smartplaylist.py | 274 ++-
lib/beetsplug/sonosupdate.py | 15 +-
lib/beetsplug/spotify.py | 483 ++--
lib/beetsplug/subsonicplaylist.py | 137 +-
lib/beetsplug/subsonicupdate.py | 112 +-
lib/beetsplug/substitute.py | 56 +
lib/beetsplug/the.py | 63 +-
lib/beetsplug/thumbnails.py | 218 +-
lib/beetsplug/types.py | 18 +-
lib/beetsplug/unimported.py | 38 +-
lib/beetsplug/web/__init__.py | 326 +--
lib/beetsplug/web/static/backbone.js | 6 +-
lib/beetsplug/web/static/jquery.js | 8 +-
lib/beetsplug/zero.py | 82 +-
116 files changed, 17353 insertions(+), 9964 deletions(-)
rename lib/beets/{util/confit.py => test/__init__.py} (62%)
create mode 100644 lib/beets/test/_common.py
create mode 100644 lib/beets/test/helper.py
create mode 100644 lib/beets/util/id_extractors.py
create mode 100644 lib/beets/util/m3u.py
create mode 100644 lib/beetsplug/advancedrewrite.py
create mode 100644 lib/beetsplug/autobpm.py
create mode 100644 lib/beetsplug/limit.py
create mode 100644 lib/beetsplug/listenbrainz.py
create mode 100644 lib/beetsplug/substitute.py
diff --git a/lib/beets/__init__.py b/lib/beets/__init__.py
index 9642a6f3..16f51f85 100644
--- a/lib/beets/__init__.py
+++ b/lib/beets/__init__.py
@@ -13,28 +13,29 @@
# included in all copies or substantial portions of the Software.
-import confuse
from sys import stderr
-__version__ = '1.6.0'
-__author__ = 'Adrian Sampson '
+import confuse
+
+__version__ = "2.0.0"
+__author__ = "Adrian Sampson "
class IncludeLazyConfig(confuse.LazyConfig):
"""A version of Confuse's LazyConfig that also merges in data from
YAML files specified in an `include` setting.
"""
+
def read(self, user=True, defaults=True):
super().read(user, defaults)
try:
- for view in self['include']:
+ for view in self["include"]:
self.set_file(view.as_filename())
except confuse.NotFoundError:
pass
except confuse.ConfigReadError as err:
- stderr.write("configuration `import` failed: {}"
- .format(err.reason))
+ stderr.write("configuration `import` failed: {}".format(err.reason))
-config = IncludeLazyConfig('beets', __name__)
+config = IncludeLazyConfig("beets", __name__)
diff --git a/lib/beets/__main__.py b/lib/beets/__main__.py
index ac829de9..81995f7a 100644
--- a/lib/beets/__main__.py
+++ b/lib/beets/__main__.py
@@ -18,6 +18,7 @@
import sys
+
from .ui import main
if __name__ == "__main__":
diff --git a/lib/beets/art.py b/lib/beets/art.py
index 13d5dfbd..466d4000 100644
--- a/lib/beets/art.py
+++ b/lib/beets/art.py
@@ -17,21 +17,19 @@ music and items' embedded album art.
"""
-import subprocess
-import platform
-from tempfile import NamedTemporaryFile
import os
+from tempfile import NamedTemporaryFile
-from beets.util import displayable_path, syspath, bytestring_path
-from beets.util.artresizer import ArtResizer
import mediafile
+from beets.util import bytestring_path, displayable_path, syspath
+from beets.util.artresizer import ArtResizer
+
def mediafile_image(image_path, maxwidth=None):
- """Return a `mediafile.Image` object for the path.
- """
+ """Return a `mediafile.Image` object for the path."""
- with open(syspath(image_path), 'rb') as f:
+ with open(syspath(image_path), "rb") as f:
data = f.read()
return mediafile.Image(data, type=mediafile.ImageType.front)
@@ -41,170 +39,168 @@ def get_art(log, item):
try:
mf = mediafile.MediaFile(syspath(item.path))
except mediafile.UnreadableFileError as exc:
- log.warning('Could not extract art from {0}: {1}',
- displayable_path(item.path), exc)
+ log.warning(
+ "Could not extract art from {0}: {1}",
+ displayable_path(item.path),
+ exc,
+ )
return
return mf.art
-def embed_item(log, item, imagepath, maxwidth=None, itempath=None,
- compare_threshold=0, ifempty=False, as_album=False, id3v23=None,
- quality=0):
- """Embed an image into the item's media file.
- """
- # Conditions and filters.
+def embed_item(
+ log,
+ item,
+ imagepath,
+ maxwidth=None,
+ itempath=None,
+ compare_threshold=0,
+ ifempty=False,
+ as_album=False,
+ id3v23=None,
+ quality=0,
+):
+ """Embed an image into the item's media file."""
+ # Conditions.
if compare_threshold:
- if not check_art_similarity(log, item, imagepath, compare_threshold):
- log.info('Image not similar; skipping.')
+ is_similar = check_art_similarity(
+ log, item, imagepath, compare_threshold
+ )
+ if is_similar is None:
+ log.warning("Error while checking art similarity; skipping.")
return
+ elif not is_similar:
+ log.info("Image not similar; skipping.")
+ return
+
if ifempty and get_art(log, item):
- log.info('media file already contained art')
+ log.info("media file already contained art")
return
+
+ # Filters.
if maxwidth and not as_album:
imagepath = resize_image(log, imagepath, maxwidth, quality)
# Get the `Image` object from the file.
try:
- log.debug('embedding {0}', displayable_path(imagepath))
+ log.debug("embedding {0}", displayable_path(imagepath))
image = mediafile_image(imagepath, maxwidth)
except OSError as exc:
- log.warning('could not read image file: {0}', exc)
+ log.warning("could not read image file: {0}", exc)
return
# Make sure the image kind is safe (some formats only support PNG
# and JPEG).
- if image.mime_type not in ('image/jpeg', 'image/png'):
- log.info('not embedding image of unsupported type: {}',
- image.mime_type)
+ if image.mime_type not in ("image/jpeg", "image/png"):
+ log.info("not embedding image of unsupported type: {}", image.mime_type)
return
- item.try_write(path=itempath, tags={'images': [image]}, id3v23=id3v23)
+ item.try_write(path=itempath, tags={"images": [image]}, id3v23=id3v23)
-def embed_album(log, album, maxwidth=None, quiet=False, compare_threshold=0,
- ifempty=False, quality=0):
- """Embed album art into all of the album's items.
- """
+def embed_album(
+ log,
+ album,
+ maxwidth=None,
+ quiet=False,
+ compare_threshold=0,
+ ifempty=False,
+ quality=0,
+):
+ """Embed album art into all of the album's items."""
imagepath = album.artpath
if not imagepath:
- log.info('No album art present for {0}', album)
+ log.info("No album art present for {0}", album)
return
if not os.path.isfile(syspath(imagepath)):
- log.info('Album art not found at {0} for {1}',
- displayable_path(imagepath), album)
+ log.info(
+ "Album art not found at {0} for {1}",
+ displayable_path(imagepath),
+ album,
+ )
return
if maxwidth:
imagepath = resize_image(log, imagepath, maxwidth, quality)
- log.info('Embedding album art into {0}', album)
+ log.info("Embedding album art into {0}", album)
for item in album.items():
- embed_item(log, item, imagepath, maxwidth, None, compare_threshold,
- ifempty, as_album=True, quality=quality)
+ embed_item(
+ log,
+ item,
+ imagepath,
+ maxwidth,
+ None,
+ compare_threshold,
+ ifempty,
+ as_album=True,
+ quality=quality,
+ )
def resize_image(log, imagepath, maxwidth, quality):
"""Returns path to an image resized to maxwidth and encoded with the
specified quality level.
"""
- log.debug('Resizing album art to {0} pixels wide and encoding at quality \
- level {1}', maxwidth, quality)
- imagepath = ArtResizer.shared.resize(maxwidth, syspath(imagepath),
- quality=quality)
+ log.debug(
+ "Resizing album art to {0} pixels wide and encoding at quality \
+ level {1}",
+ maxwidth,
+ quality,
+ )
+ imagepath = ArtResizer.shared.resize(
+ maxwidth, syspath(imagepath), quality=quality
+ )
return imagepath
-def check_art_similarity(log, item, imagepath, compare_threshold):
+def check_art_similarity(
+ log,
+ item,
+ imagepath,
+ compare_threshold,
+ artresizer=None,
+):
"""A boolean indicating if an image is similar to embedded item art.
+
+ If no embedded art exists, always return `True`. If the comparison fails
+ for some reason, the return value is `None`.
+
+ This must only be called if `ArtResizer.shared.can_compare` is `True`.
"""
with NamedTemporaryFile(delete=True) as f:
art = extract(log, f.name, item)
- if art:
- is_windows = platform.system() == "Windows"
+ if not art:
+ return True
- # Converting images to grayscale tends to minimize the weight
- # of colors in the diff score. So we first convert both images
- # to grayscale and then pipe them into the `compare` command.
- # On Windows, ImageMagick doesn't support the magic \\?\ prefix
- # on paths, so we pass `prefix=False` to `syspath`.
- convert_cmd = ['convert', syspath(imagepath, prefix=False),
- syspath(art, prefix=False),
- '-colorspace', 'gray', 'MIFF:-']
- compare_cmd = ['compare', '-metric', 'PHASH', '-', 'null:']
- log.debug('comparing images with pipeline {} | {}',
- convert_cmd, compare_cmd)
- convert_proc = subprocess.Popen(
- convert_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- close_fds=not is_windows,
- )
- compare_proc = subprocess.Popen(
- compare_cmd,
- stdin=convert_proc.stdout,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- close_fds=not is_windows,
- )
+ if artresizer is None:
+ artresizer = ArtResizer.shared
- # Check the convert output. We're not interested in the
- # standard output; that gets piped to the next stage.
- convert_proc.stdout.close()
- convert_stderr = convert_proc.stderr.read()
- convert_proc.stderr.close()
- convert_proc.wait()
- if convert_proc.returncode:
- log.debug(
- 'ImageMagick convert failed with status {}: {!r}',
- convert_proc.returncode,
- convert_stderr,
- )
- return
-
- # Check the compare output.
- stdout, stderr = compare_proc.communicate()
- if compare_proc.returncode:
- if compare_proc.returncode != 1:
- log.debug('ImageMagick compare failed: {0}, {1}',
- displayable_path(imagepath),
- displayable_path(art))
- return
- out_str = stderr
- else:
- out_str = stdout
-
- try:
- phash_diff = float(out_str)
- except ValueError:
- log.debug('IM output is not a number: {0!r}', out_str)
- return
-
- log.debug('ImageMagick compare score: {0}', phash_diff)
- return phash_diff <= compare_threshold
-
- return True
+ return artresizer.compare(art, imagepath, compare_threshold)
def extract(log, outpath, item):
art = get_art(log, item)
outpath = bytestring_path(outpath)
if not art:
- log.info('No album art present in {0}, skipping.', item)
+ log.info("No album art present in {0}, skipping.", item)
return
# Add an extension to the filename.
ext = mediafile.image_extension(art)
if not ext:
- log.warning('Unknown image type in {0}.',
- displayable_path(item.path))
+ log.warning("Unknown image type in {0}.", displayable_path(item.path))
return
- outpath += bytestring_path('.' + ext)
+ outpath += bytestring_path("." + ext)
- log.info('Extracting album art from: {0} to: {1}',
- item, displayable_path(outpath))
- with open(syspath(outpath), 'wb') as f:
+ log.info(
+ "Extracting album art from: {0} to: {1}",
+ item,
+ displayable_path(outpath),
+ )
+ with open(syspath(outpath), "wb") as f:
f.write(art)
return outpath
@@ -218,7 +214,7 @@ def extract_first(log, outpath, items):
def clear(log, lib, query):
items = lib.items(query)
- log.info('Clearing album art from {0} items', len(items))
+ log.info("Clearing album art from {0} items", len(items))
for item in items:
- log.debug('Clearing art for {0}', item)
- item.try_write(tags={'images': None})
+ log.debug("Clearing art for {0}", item)
+ item.try_write(tags={"images": None})
diff --git a/lib/beets/autotag/__init__.py b/lib/beets/autotag/__init__.py
index e62f492c..54a9d554 100644
--- a/lib/beets/autotag/__init__.py
+++ b/lib/beets/autotag/__init__.py
@@ -14,78 +14,91 @@
"""Facilities for automatically determining files' correct metadata.
"""
+from typing import Mapping
-
-from beets import logging
-from beets import config
+from beets import config, logging
+from beets.library import Item
# Parts of external interface.
from .hooks import ( # noqa
AlbumInfo,
- TrackInfo,
AlbumMatch,
- TrackMatch,
Distance,
+ TrackInfo,
+ TrackMatch,
)
-from .match import tag_item, tag_album, Proposal # noqa
from .match import Recommendation # noqa
+from .match import Proposal, current_metadata, tag_album, tag_item # noqa
# Global logger.
-log = logging.getLogger('beets')
+log = logging.getLogger("beets")
# Metadata fields that are already hardcoded, or where the tag name changes.
SPECIAL_FIELDS = {
- 'album': (
- 'va',
- 'releasegroup_id',
- 'artist_id',
- 'album_id',
- 'mediums',
- 'tracks',
- 'year',
- 'month',
- 'day',
- 'artist',
- 'artist_credit',
- 'artist_sort',
- 'data_url'
+ "album": (
+ "va",
+ "releasegroup_id",
+ "artist_id",
+ "artists_ids",
+ "album_id",
+ "mediums",
+ "tracks",
+ "year",
+ "month",
+ "day",
+ "artist",
+ "artists",
+ "artist_credit",
+ "artists_credit",
+ "artist_sort",
+ "artists_sort",
+ "data_url",
+ ),
+ "track": (
+ "track_alt",
+ "artist_id",
+ "artists_ids",
+ "release_track_id",
+ "medium",
+ "index",
+ "medium_index",
+ "title",
+ "artist_credit",
+ "artists_credit",
+ "artist_sort",
+ "artists_sort",
+ "artist",
+ "artists",
+ "track_id",
+ "medium_total",
+ "data_url",
+ "length",
),
- 'track': (
- 'track_alt',
- 'artist_id',
- 'release_track_id',
- 'medium',
- 'index',
- 'medium_index',
- 'title',
- 'artist_credit',
- 'artist_sort',
- 'artist',
- 'track_id',
- 'medium_total',
- 'data_url',
- 'length'
- )
}
# Additional utilities for the main interface.
-def apply_item_metadata(item, track_info):
- """Set an item's metadata from its matched TrackInfo object.
- """
+
+def apply_item_metadata(item: Item, track_info: TrackInfo):
+ """Set an item's metadata from its matched TrackInfo object."""
item.artist = track_info.artist
+ item.artists = track_info.artists
item.artist_sort = track_info.artist_sort
+ item.artists_sort = track_info.artists_sort
item.artist_credit = track_info.artist_credit
+ item.artists_credit = track_info.artists_credit
item.title = track_info.title
item.mb_trackid = track_info.track_id
item.mb_releasetrackid = track_info.release_track_id
if track_info.artist_id:
item.mb_artistid = track_info.artist_id
+ if track_info.artists_ids:
+ item.mb_artistids = track_info.artists_ids
for field, value in track_info.items():
# We only overwrite fields that are not already hardcoded.
- if field in SPECIAL_FIELDS['track']:
+ if field in SPECIAL_FIELDS["track"]:
continue
if value is None:
continue
@@ -95,45 +108,62 @@ def apply_item_metadata(item, track_info):
# and track number). Perhaps these should be emptied?
-def apply_metadata(album_info, mapping):
+def apply_metadata(album_info: AlbumInfo, mapping: Mapping[Item, TrackInfo]):
"""Set the items' metadata to match an AlbumInfo object using a
mapping from Items to TrackInfo objects.
"""
for item, track_info in mapping.items():
# Artist or artist credit.
- if config['artist_credit']:
- item.artist = (track_info.artist_credit or
- track_info.artist or
- album_info.artist_credit or
- album_info.artist)
- item.albumartist = (album_info.artist_credit or
- album_info.artist)
+ if config["artist_credit"]:
+ item.artist = (
+ track_info.artist_credit
+ or track_info.artist
+ or album_info.artist_credit
+ or album_info.artist
+ )
+ item.artists = (
+ track_info.artists_credit
+ or track_info.artists
+ or album_info.artists_credit
+ or album_info.artists
+ )
+ item.albumartist = album_info.artist_credit or album_info.artist
+ item.albumartists = album_info.artists_credit or album_info.artists
else:
- item.artist = (track_info.artist or album_info.artist)
+ item.artist = track_info.artist or album_info.artist
+ item.artists = track_info.artists or album_info.artists
item.albumartist = album_info.artist
+ item.albumartists = album_info.artists
# Album.
item.album = album_info.album
# Artist sort and credit names.
item.artist_sort = track_info.artist_sort or album_info.artist_sort
- item.artist_credit = (track_info.artist_credit or
- album_info.artist_credit)
+ item.artists_sort = track_info.artists_sort or album_info.artists_sort
+ item.artist_credit = (
+ track_info.artist_credit or album_info.artist_credit
+ )
+ item.artists_credit = (
+ track_info.artists_credit or album_info.artists_credit
+ )
item.albumartist_sort = album_info.artist_sort
+ item.albumartists_sort = album_info.artists_sort
item.albumartist_credit = album_info.artist_credit
+ item.albumartists_credit = album_info.artists_credit
# Release date.
- for prefix in '', 'original_':
- if config['original_date'] and not prefix:
+ for prefix in "", "original_":
+ if config["original_date"] and not prefix:
# Ignore specific release date.
continue
- for suffix in 'year', 'month', 'day':
+ for suffix in "year", "month", "day":
key = prefix + suffix
value = getattr(album_info, key) or 0
# If we don't even have a year, apply nothing.
- if suffix == 'year' and not value:
+ if suffix == "year" and not value:
break
# Otherwise, set the fetched value (or 0 for the month
@@ -142,13 +172,13 @@ def apply_metadata(album_info, mapping):
# If we're using original release date for both fields,
# also set item.year = info.original_year, etc.
- if config['original_date']:
+ if config["original_date"]:
item[suffix] = value
# Title.
item.title = track_info.title
- if config['per_disc_numbering']:
+ if config["per_disc_numbering"]:
# We want to let the track number be zero, but if the medium index
# is not provided we need to fall back to the overall index.
if track_info.medium_index is not None:
@@ -172,7 +202,14 @@ def apply_metadata(album_info, mapping):
item.mb_artistid = track_info.artist_id
else:
item.mb_artistid = album_info.artist_id
+
+ if track_info.artists_ids:
+ item.mb_artistids = track_info.artists_ids
+ else:
+ item.mb_artistids = album_info.artists_ids
+
item.mb_albumartistid = album_info.artist_id
+ item.mb_albumartistids = album_info.artists_ids
item.mb_releasegroupid = album_info.releasegroup_id
# Compilation flag.
@@ -184,17 +221,17 @@ def apply_metadata(album_info, mapping):
# Don't overwrite fields with empty values unless the
# field is explicitly allowed to be overwritten
for field, value in album_info.items():
- if field in SPECIAL_FIELDS['album']:
+ if field in SPECIAL_FIELDS["album"]:
continue
- clobber = field in config['overwrite_null']['album'].as_str_seq()
+ clobber = field in config["overwrite_null"]["album"].as_str_seq()
if value is None and not clobber:
continue
item[field] = value
for field, value in track_info.items():
- if field in SPECIAL_FIELDS['track']:
+ if field in SPECIAL_FIELDS["track"]:
continue
- clobber = field in config['overwrite_null']['track'].as_str_seq()
+ clobber = field in config["overwrite_null"]["track"].as_str_seq()
value = getattr(track_info, field)
if value is None and not clobber:
continue
diff --git a/lib/beets/autotag/hooks.py b/lib/beets/autotag/hooks.py
index 9cd6f2cd..0a9b7daf 100644
--- a/lib/beets/autotag/hooks.py
+++ b/lib/beets/autotag/hooks.py
@@ -14,40 +14,51 @@
"""Glue between metadata sources and the matching logic."""
+from __future__ import annotations
+
+import re
from collections import namedtuple
from functools import total_ordering
-import re
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+)
-from beets import logging
-from beets import plugins
-from beets import config
-from beets.util import as_string
-from beets.autotag import mb
from jellyfish import levenshtein_distance
from unidecode import unidecode
-log = logging.getLogger('beets')
+from beets import config, logging, plugins
+from beets.autotag import mb
+from beets.library import Item
+from beets.util import as_string, cached_classproperty
-# The name of the type for patterns in re changed in Python 3.7.
-try:
- Pattern = re._pattern_type
-except AttributeError:
- Pattern = re.Pattern
+log = logging.getLogger("beets")
+
+V = TypeVar("V")
# Classes used to represent candidate options.
-class AttrDict(dict):
+class AttrDict(Dict[str, V]):
"""A dictionary that supports attribute ("dot") access, so `d.field`
is equivalent to `d['field']`.
"""
- def __getattr__(self, attr):
+ def __getattr__(self, attr: str) -> V:
if attr in self:
- return self.get(attr)
+ return self[attr]
else:
raise AttributeError
- def __setattr__(self, key, value):
+ def __setattr__(self, key: str, value: V):
self.__setitem__(key, value)
def __hash__(self):
@@ -68,32 +79,73 @@ class AlbumInfo(AttrDict):
The others are optional and may be None.
"""
- def __init__(self, tracks, album=None, album_id=None, artist=None,
- artist_id=None, asin=None, albumtype=None, va=False,
- year=None, month=None, day=None, label=None, mediums=None,
- artist_sort=None, releasegroup_id=None, catalognum=None,
- script=None, language=None, country=None, style=None,
- genre=None, albumstatus=None, media=None, albumdisambig=None,
- releasegroupdisambig=None, artist_credit=None,
- original_year=None, original_month=None,
- original_day=None, data_source=None, data_url=None,
- discogs_albumid=None, discogs_labelid=None,
- discogs_artistid=None, **kwargs):
+ # TYPING: are all of these correct? I've assumed optional strings
+ def __init__(
+ self,
+ tracks: List[TrackInfo],
+ album: Optional[str] = None,
+ album_id: Optional[str] = None,
+ artist: Optional[str] = None,
+ artist_id: Optional[str] = None,
+ artists: Optional[List[str]] = None,
+ artists_ids: Optional[List[str]] = None,
+ asin: Optional[str] = None,
+ albumtype: Optional[str] = None,
+ albumtypes: Optional[List[str]] = None,
+ va: bool = False,
+ year: Optional[int] = None,
+ month: Optional[int] = None,
+ day: Optional[int] = None,
+ label: Optional[str] = None,
+ barcode: Optional[str] = None,
+ mediums: Optional[int] = None,
+ artist_sort: Optional[str] = None,
+ artists_sort: Optional[List[str]] = None,
+ releasegroup_id: Optional[str] = None,
+ release_group_title: Optional[str] = None,
+ catalognum: Optional[str] = None,
+ script: Optional[str] = None,
+ language: Optional[str] = None,
+ country: Optional[str] = None,
+ style: Optional[str] = None,
+ genre: Optional[str] = None,
+ albumstatus: Optional[str] = None,
+ media: Optional[str] = None,
+ albumdisambig: Optional[str] = None,
+ releasegroupdisambig: Optional[str] = None,
+ artist_credit: Optional[str] = None,
+ artists_credit: Optional[List[str]] = None,
+ original_year: Optional[int] = None,
+ original_month: Optional[int] = None,
+ original_day: Optional[int] = None,
+ data_source: Optional[str] = None,
+ data_url: Optional[str] = None,
+ discogs_albumid: Optional[str] = None,
+ discogs_labelid: Optional[str] = None,
+ discogs_artistid: Optional[str] = None,
+ **kwargs,
+ ):
self.album = album
self.album_id = album_id
self.artist = artist
self.artist_id = artist_id
+ self.artists = artists or []
+ self.artists_ids = artists_ids or []
self.tracks = tracks
self.asin = asin
self.albumtype = albumtype
+ self.albumtypes = albumtypes or []
self.va = va
self.year = year
self.month = month
self.day = day
self.label = label
+ self.barcode = barcode
self.mediums = mediums
self.artist_sort = artist_sort
+ self.artists_sort = artists_sort or []
self.releasegroup_id = releasegroup_id
+ self.release_group_title = release_group_title
self.catalognum = catalognum
self.script = script
self.language = language
@@ -105,6 +157,7 @@ class AlbumInfo(AttrDict):
self.albumdisambig = albumdisambig
self.releasegroupdisambig = releasegroupdisambig
self.artist_credit = artist_credit
+ self.artists_credit = artists_credit or []
self.original_year = original_year
self.original_month = original_month
self.original_day = original_day
@@ -115,27 +168,7 @@ class AlbumInfo(AttrDict):
self.discogs_artistid = discogs_artistid
self.update(kwargs)
- # Work around a bug in python-musicbrainz-ngs that causes some
- # strings to be bytes rather than Unicode.
- # https://github.com/alastair/python-musicbrainz-ngs/issues/85
- def decode(self, codec='utf-8'):
- """Ensure that all string attributes on this object, and the
- constituent `TrackInfo` objects, are decoded to Unicode.
- """
- for fld in ['album', 'artist', 'albumtype', 'label', 'artist_sort',
- 'catalognum', 'script', 'language', 'country', 'style',
- 'genre', 'albumstatus', 'albumdisambig',
- 'releasegroupdisambig', 'artist_credit',
- 'media', 'discogs_albumid', 'discogs_labelid',
- 'discogs_artistid']:
- value = getattr(self, fld)
- if isinstance(value, bytes):
- setattr(self, fld, value.decode(codec, 'ignore'))
-
- for track in self.tracks:
- track.decode(codec)
-
- def copy(self):
+ def copy(self) -> AlbumInfo:
dupe = AlbumInfo([])
dupe.update(self)
dupe.tracks = [track.copy() for track in self.tracks]
@@ -154,20 +187,50 @@ class TrackInfo(AttrDict):
are all 1-based.
"""
- def __init__(self, title=None, track_id=None, release_track_id=None,
- artist=None, artist_id=None, length=None, index=None,
- medium=None, medium_index=None, medium_total=None,
- artist_sort=None, disctitle=None, artist_credit=None,
- data_source=None, data_url=None, media=None, lyricist=None,
- composer=None, composer_sort=None, arranger=None,
- track_alt=None, work=None, mb_workid=None,
- work_disambig=None, bpm=None, initial_key=None, genre=None,
- **kwargs):
+ # TYPING: are all of these correct? I've assumed optional strings
+ def __init__(
+ self,
+ title: Optional[str] = None,
+ track_id: Optional[str] = None,
+ release_track_id: Optional[str] = None,
+ artist: Optional[str] = None,
+ artist_id: Optional[str] = None,
+ artists: Optional[List[str]] = None,
+ artists_ids: Optional[List[str]] = None,
+ length: Optional[float] = None,
+ index: Optional[int] = None,
+ medium: Optional[int] = None,
+ medium_index: Optional[int] = None,
+ medium_total: Optional[int] = None,
+ artist_sort: Optional[str] = None,
+ artists_sort: Optional[List[str]] = None,
+ disctitle: Optional[str] = None,
+ artist_credit: Optional[str] = None,
+ artists_credit: Optional[List[str]] = None,
+ data_source: Optional[str] = None,
+ data_url: Optional[str] = None,
+ media: Optional[str] = None,
+ lyricist: Optional[str] = None,
+ composer: Optional[str] = None,
+ composer_sort: Optional[str] = None,
+ arranger: Optional[str] = None,
+ track_alt: Optional[str] = None,
+ work: Optional[str] = None,
+ mb_workid: Optional[str] = None,
+ work_disambig: Optional[str] = None,
+ bpm: Optional[str] = None,
+ initial_key: Optional[str] = None,
+ genre: Optional[str] = None,
+ album: Optional[str] = None,
+ **kwargs,
+ ):
self.title = title
self.track_id = track_id
self.release_track_id = release_track_id
self.artist = artist
self.artist_id = artist_id
+ self.artists = artists or []
+ self.artists_ids = artists_ids or []
self.length = length
self.index = index
self.media = media
@@ -175,8 +238,10 @@ class TrackInfo(AttrDict):
self.medium_index = medium_index
self.medium_total = medium_total
self.artist_sort = artist_sort
+ self.artists_sort = artists_sort or []
self.disctitle = disctitle
self.artist_credit = artist_credit
+ self.artists_credit = artists_credit or []
self.data_source = data_source
self.data_url = data_url
self.lyricist = lyricist
@@ -190,20 +255,10 @@ class TrackInfo(AttrDict):
self.bpm = bpm
self.initial_key = initial_key
self.genre = genre
+ self.album = album
self.update(kwargs)
- # As above, work around a bug in python-musicbrainz-ngs.
- def decode(self, codec='utf-8'):
- """Ensure that all string attributes on this object are decoded
- to Unicode.
- """
- for fld in ['title', 'artist', 'medium', 'artist_sort', 'disctitle',
- 'artist_credit', 'media']:
- value = getattr(self, fld)
- if isinstance(value, bytes):
- setattr(self, fld, value.decode(codec, 'ignore'))
-
- def copy(self):
+ def copy(self) -> TrackInfo:
dupe = TrackInfo()
dupe.update(self)
return dupe
@@ -213,23 +268,23 @@ class TrackInfo(AttrDict):
# Parameters for string distance function.
# Words that can be moved to the end of a string using a comma.
-SD_END_WORDS = ['the', 'a', 'an']
+SD_END_WORDS = ["the", "a", "an"]
# Reduced weights for certain portions of the string.
SD_PATTERNS = [
- (r'^the ', 0.1),
- (r'[\[\(]?(ep|single)[\]\)]?', 0.0),
- (r'[\[\(]?(featuring|feat|ft)[\. :].+', 0.1),
- (r'\(.*?\)', 0.3),
- (r'\[.*?\]', 0.3),
- (r'(, )?(pt\.|part) .+', 0.2),
+ (r"^the ", 0.1),
+ (r"[\[\(]?(ep|single)[\]\)]?", 0.0),
+ (r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1),
+ (r"\(.*?\)", 0.3),
+ (r"\[.*?\]", 0.3),
+ (r"(, )?(pt\.|part) .+", 0.2),
]
# Replacements to use before testing distance.
SD_REPLACE = [
- (r'&', 'and'),
+ (r"&", "and"),
]
-def _string_dist_basic(str1, str2):
+def _string_dist_basic(str1: str, str2: str) -> float:
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
transliteration/lowering to ASCII characters. Normalized by string
@@ -239,14 +294,14 @@ def _string_dist_basic(str1, str2):
assert isinstance(str2, str)
str1 = as_string(unidecode(str1))
str2 = as_string(unidecode(str2))
- str1 = re.sub(r'[^a-z0-9]', '', str1.lower())
- str2 = re.sub(r'[^a-z0-9]', '', str2.lower())
+ str1 = re.sub(r"[^a-z0-9]", "", str1.lower())
+ str2 = re.sub(r"[^a-z0-9]", "", str2.lower())
if not str1 and not str2:
return 0.0
return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2)))
-def string_dist(str1, str2):
+def string_dist(str1: Optional[str], str2: Optional[str]) -> float:
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
@@ -263,10 +318,10 @@ def string_dist(str1, str2):
# example, "the something" should be considered equal to
# "something, the".
for word in SD_END_WORDS:
- if str1.endswith(', %s' % word):
- str1 = '{} {}'.format(word, str1[:-len(word) - 2])
- if str2.endswith(', %s' % word):
- str2 = '{} {}'.format(word, str2[:-len(word) - 2])
+ if str1.endswith(", %s" % word):
+ str1 = "{} {}".format(word, str1[: -len(word) - 2])
+ if str2.endswith(", %s" % word):
+ str2 = "{} {}".format(word, str2[: -len(word) - 2])
# Perform a couple of basic normalizing substitutions.
for pat, repl in SD_REPLACE:
@@ -281,8 +336,8 @@ def string_dist(str1, str2):
penalty = 0.0
for pat, weight in SD_PATTERNS:
# Get strings that drop the pattern.
- case_str1 = re.sub(pat, '', str1)
- case_str2 = re.sub(pat, '', str2)
+ case_str1 = re.sub(pat, "", str1)
+ case_str2 = re.sub(pat, "", str2)
if case_str1 != str1 or case_str2 != str2:
# If the pattern was present (i.e., it is deleted in the
@@ -304,23 +359,6 @@ def string_dist(str1, str2):
return base_dist + penalty
-class LazyClassProperty:
- """A decorator implementing a read-only property that is *lazy* in
- the sense that the getter is only invoked once. Subsequent accesses
- through *any* instance use the cached result.
- """
-
- def __init__(self, getter):
- self.getter = getter
- self.computed = False
-
- def __get__(self, obj, owner):
- if not self.computed:
- self.value = self.getter(owner)
- self.computed = True
- return self.value
-
-
@total_ordering
class Distance:
"""Keeps track of multiple distance penalties. Provides a single
@@ -330,12 +368,12 @@ class Distance:
def __init__(self):
self._penalties = {}
+ self.tracks: Dict[TrackInfo, Distance] = {}
- @LazyClassProperty
- def _weights(cls): # noqa: N805
- """A dictionary from keys to floating-point weights.
- """
- weights_view = config['match']['distance_weights']
+ @cached_classproperty
+ def _weights(cls) -> Dict[str, float]: # noqa: N805
+ """A dictionary from keys to floating-point weights."""
+ weights_view = config["match"]["distance_weights"]
weights = {}
for key in weights_view.keys():
weights[key] = weights_view[key].as_number()
@@ -344,7 +382,7 @@ class Distance:
# Access the components and their aggregates.
@property
- def distance(self):
+ def distance(self) -> float:
"""Return a weighted and normalized distance across all
penalties.
"""
@@ -354,24 +392,22 @@ class Distance:
return 0.0
@property
- def max_distance(self):
- """Return the maximum distance penalty (normalization factor).
- """
+ def max_distance(self) -> float:
+ """Return the maximum distance penalty (normalization factor)."""
dist_max = 0.0
for key, penalty in self._penalties.items():
dist_max += len(penalty) * self._weights[key]
return dist_max
@property
- def raw_distance(self):
- """Return the raw (denormalized) distance.
- """
+ def raw_distance(self) -> float:
+ """Return the raw (denormalized) distance."""
dist_raw = 0.0
for key, penalty in self._penalties.items():
dist_raw += sum(penalty) * self._weights[key]
return dist_raw
- def items(self):
+ def items(self) -> List[Tuple[str, float]]:
"""Return a list of (key, dist) pairs, with `dist` being the
weighted distance, sorted from highest to lowest. Does not
include penalties with a zero value.
@@ -385,87 +421,88 @@ class Distance:
# ascending order (for keys, when the penalty is equal) and
# still get the items with the biggest distance first.
return sorted(
- list_,
- key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
+ list_, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
)
- def __hash__(self):
+ def __hash__(self) -> int:
return id(self)
- def __eq__(self, other):
+ def __eq__(self, other) -> bool:
return self.distance == other
# Behave like a float.
- def __lt__(self, other):
+ def __lt__(self, other) -> bool:
return self.distance < other
- def __float__(self):
+ def __float__(self) -> float:
return self.distance
- def __sub__(self, other):
+ def __sub__(self, other) -> float:
return self.distance - other
- def __rsub__(self, other):
+ def __rsub__(self, other) -> float:
return other - self.distance
- def __str__(self):
+ def __str__(self) -> str:
return f"{self.distance:.2f}"
# Behave like a dict.
- def __getitem__(self, key):
- """Returns the weighted distance for a named penalty.
- """
+ def __getitem__(self, key) -> float:
+ """Returns the weighted distance for a named penalty."""
dist = sum(self._penalties[key]) * self._weights[key]
dist_max = self.max_distance
if dist_max:
return dist / dist_max
return 0.0
- def __iter__(self):
+ def __iter__(self) -> Iterator[Tuple[str, float]]:
return iter(self.items())
- def __len__(self):
+ def __len__(self) -> int:
return len(self.items())
- def keys(self):
+ def keys(self) -> List[str]:
return [key for key, _ in self.items()]
- def update(self, dist):
- """Adds all the distance penalties from `dist`.
- """
+ def update(self, dist: "Distance"):
+ """Adds all the distance penalties from `dist`."""
if not isinstance(dist, Distance):
raise ValueError(
- '`dist` must be a Distance object, not {}'.format(type(dist))
+ "`dist` must be a Distance object, not {}".format(type(dist))
)
for key, penalties in dist._penalties.items():
self._penalties.setdefault(key, []).extend(penalties)
# Adding components.
- def _eq(self, value1, value2):
+ def _eq(self, value1: Union[re.Pattern[str], Any], value2: Any) -> bool:
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.
"""
- if isinstance(value1, Pattern):
+ if isinstance(value1, re.Pattern):
+ value2 = cast(str, value2)
return bool(value1.match(value2))
return value1 == value2
- def add(self, key, dist):
+ def add(self, key: str, dist: float):
"""Adds a distance penalty. `key` must correspond with a
configured weight setting. `dist` must be a float between 0.0
and 1.0, and will be added to any existing distance penalties
for the same key.
"""
if not 0.0 <= dist <= 1.0:
- raise ValueError(
- f'`dist` must be between 0.0 and 1.0, not {dist}'
- )
+ raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}")
self._penalties.setdefault(key, []).append(dist)
- def add_equality(self, key, value, options):
+ def add_equality(
+ self,
+ key: str,
+ value: Any,
+ options: Union[List[Any], Tuple[Any, ...], Any],
+ ):
"""Adds a distance penalty of 1.0 if `value` doesn't match any
of the values in `options`. If an option is a compiled regular
expression, it will be considered equal if it matches against
@@ -481,7 +518,7 @@ class Distance:
dist = 1.0
self.add(key, dist)
- def add_expr(self, key, expr):
+ def add_expr(self, key: str, expr: bool):
"""Adds a distance penalty of 1.0 if `expr` evaluates to True,
or 0.0.
"""
@@ -490,7 +527,7 @@ class Distance:
else:
self.add(key, 0.0)
- def add_number(self, key, number1, number2):
+ def add_number(self, key: str, number1: int, number2: int):
"""Adds a distance penalty of 1.0 for each number of difference
between `number1` and `number2`, or 0.0 when there is no
difference. Use this when there is no upper limit on the
@@ -503,7 +540,12 @@ class Distance:
else:
self.add(key, 0.0)
- def add_priority(self, key, value, options):
+ def add_priority(
+ self,
+ key: str,
+ value: Any,
+ options: Union[List[Any], Tuple[Any, ...], Any],
+ ):
"""Adds a distance penalty that corresponds to the position at
which `value` appears in `options`. A distance penalty of 0.0
for the first option, or 1.0 if there is no matching option. If
@@ -521,7 +563,12 @@ class Distance:
dist = 1.0
self.add(key, dist)
- def add_ratio(self, key, number1, number2):
+ def add_ratio(
+ self,
+ key: str,
+ number1: Union[int, float],
+ number2: Union[int, float],
+ ):
"""Adds a distance penalty for `number1` as a ratio of `number2`.
`number1` is bound at 0 and `number2`.
"""
@@ -532,7 +579,7 @@ class Distance:
dist = 0.0
self.add(key, dist)
- def add_string(self, key, str1, str2):
+ def add_string(self, key: str, str1: Optional[str], str2: Optional[str]):
"""Adds a distance penalty based on the edit distance between
`str1` and `str2`.
"""
@@ -542,64 +589,82 @@ class Distance:
# Structures that compose all the information for a candidate match.
-AlbumMatch = namedtuple('AlbumMatch', ['distance', 'info', 'mapping',
- 'extra_items', 'extra_tracks'])
+AlbumMatch = namedtuple(
+ "AlbumMatch", ["distance", "info", "mapping", "extra_items", "extra_tracks"]
+)
-TrackMatch = namedtuple('TrackMatch', ['distance', 'info'])
+TrackMatch = namedtuple("TrackMatch", ["distance", "info"])
# Aggregation of sources.
-def album_for_mbid(release_id):
+
+def album_for_mbid(release_id: str) -> Optional[AlbumInfo]:
"""Get an AlbumInfo object for a MusicBrainz release ID. Return None
if the ID is not found.
"""
try:
album = mb.album_for_id(release_id)
if album:
- plugins.send('albuminfo_received', info=album)
+ plugins.send("albuminfo_received", info=album)
return album
except mb.MusicBrainzAPIError as exc:
exc.log(log)
+ return None
-def track_for_mbid(recording_id):
+def track_for_mbid(recording_id: str) -> Optional[TrackInfo]:
"""Get a TrackInfo object for a MusicBrainz recording ID. Return None
if the ID is not found.
"""
try:
track = mb.track_for_id(recording_id)
if track:
- plugins.send('trackinfo_received', info=track)
+ plugins.send("trackinfo_received", info=track)
return track
except mb.MusicBrainzAPIError as exc:
exc.log(log)
+ return None
-def albums_for_id(album_id):
+def albums_for_id(album_id: str) -> Iterable[AlbumInfo]:
"""Get a list of albums for an ID."""
a = album_for_mbid(album_id)
if a:
yield a
for a in plugins.album_for_id(album_id):
if a:
- plugins.send('albuminfo_received', info=a)
+ plugins.send("albuminfo_received", info=a)
yield a
-def tracks_for_id(track_id):
+def tracks_for_id(track_id: str) -> Iterable[TrackInfo]:
"""Get a list of tracks for an ID."""
t = track_for_mbid(track_id)
if t:
yield t
for t in plugins.track_for_id(track_id):
if t:
- plugins.send('trackinfo_received', info=t)
+ plugins.send("trackinfo_received", info=t)
yield t
-@plugins.notify_info_yielded('albuminfo_received')
-def album_candidates(items, artist, album, va_likely, extra_tags):
+def invoke_mb(call_func: Callable, *args):
+ try:
+ return call_func(*args)
+ except mb.MusicBrainzAPIError as exc:
+ exc.log(log)
+ return ()
+
+
+@plugins.notify_info_yielded("albuminfo_received")
+def album_candidates(
+ items: List[Item],
+ artist: str,
+ album: str,
+ va_likely: bool,
+ extra_tags: Dict,
+) -> Iterable[Tuple]:
"""Search for album matches. ``items`` is a list of Item objects
that make up the album. ``artist`` and ``album`` are the respective
names (strings), which may be derived from the item list or may be
@@ -609,40 +674,33 @@ def album_candidates(items, artist, album, va_likely, extra_tags):
constrain the search.
"""
- # Base candidates if we have album and artist to match.
- if artist and album:
- try:
- yield from mb.match_album(artist, album, len(items),
- extra_tags)
- except mb.MusicBrainzAPIError as exc:
- exc.log(log)
+ if config["musicbrainz"]["enabled"]:
+ # Base candidates if we have album and artist to match.
+ if artist and album:
+ yield from invoke_mb(
+ mb.match_album, artist, album, len(items), extra_tags
+ )
- # Also add VA matches from MusicBrainz where appropriate.
- if va_likely and album:
- try:
- yield from mb.match_album(None, album, len(items),
- extra_tags)
- except mb.MusicBrainzAPIError as exc:
- exc.log(log)
+ # Also add VA matches from MusicBrainz where appropriate.
+ if va_likely and album:
+ yield from invoke_mb(
+ mb.match_album, None, album, len(items), extra_tags
+ )
# Candidates from plugins.
- yield from plugins.candidates(items, artist, album, va_likely,
- extra_tags)
+ yield from plugins.candidates(items, artist, album, va_likely, extra_tags)
-@plugins.notify_info_yielded('trackinfo_received')
-def item_candidates(item, artist, title):
+@plugins.notify_info_yielded("trackinfo_received")
+def item_candidates(item: Item, artist: str, title: str) -> Iterable[Tuple]:
"""Search for item matches. ``item`` is the Item to be matched.
``artist`` and ``title`` are strings and either reflect the item or
are specified by the user.
"""
# MusicBrainz candidates.
- if artist and title:
- try:
- yield from mb.match_track(artist, title)
- except mb.MusicBrainzAPIError as exc:
- exc.log(log)
+ if config["musicbrainz"]["enabled"] and artist and title:
+ yield from invoke_mb(mb.match_track, artist, title)
# Plugin candidates.
yield from plugins.item_candidates(item, artist, title)
diff --git a/lib/beets/autotag/match.py b/lib/beets/autotag/match.py
index d352a013..63db9e33 100644
--- a/lib/beets/autotag/match.py
+++ b/lib/beets/autotag/match.py
@@ -19,32 +19,53 @@ releases and tracks.
import datetime
import re
-from munkres import Munkres
from collections import namedtuple
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+)
-from beets import logging
-from beets import plugins
-from beets import config
+from munkres import Munkres
+
+from beets import config, logging, plugins
+from beets.autotag import (
+ AlbumInfo,
+ AlbumMatch,
+ Distance,
+ TrackInfo,
+ TrackMatch,
+ hooks,
+)
+from beets.library import Item
from beets.util import plurality
-from beets.autotag import hooks
from beets.util.enumeration import OrderedEnum
# Artist signals that indicate "various artists". These are used at the
# album level to determine whether a given release is likely a VA
# release and also on the track level to to remove the penalty for
# differing artists.
-VA_ARTISTS = ('', 'various artists', 'various', 'va', 'unknown')
+VA_ARTISTS = ("", "various artists", "various", "va", "unknown")
# Global logger.
-log = logging.getLogger('beets')
+log = logging.getLogger("beets")
# Recommendation enumeration.
+
class Recommendation(OrderedEnum):
"""Indicates a qualitative suggestion to the user about what should
be done with a given match.
"""
+
none = 0
low = 1
medium = 2
@@ -55,12 +76,15 @@ class Recommendation(OrderedEnum):
# consists of a list of possible candidates (i.e., AlbumInfo or TrackInfo
# objects) and a recommendation value.
-Proposal = namedtuple('Proposal', ('candidates', 'recommendation'))
+Proposal = namedtuple("Proposal", ("candidates", "recommendation"))
# Primary matching functionality.
-def current_metadata(items):
+
+def current_metadata(
+ items: Iterable[Item],
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Extract the likely current metadata for an album given a list of its
items. Return two dictionaries:
- The most common value for each field.
@@ -70,22 +94,36 @@ def current_metadata(items):
likelies = {}
consensus = {}
- fields = ['artist', 'album', 'albumartist', 'year', 'disctotal',
- 'mb_albumid', 'label', 'catalognum', 'country', 'media',
- 'albumdisambig']
+ fields = [
+ "artist",
+ "album",
+ "albumartist",
+ "year",
+ "disctotal",
+ "mb_albumid",
+ "label",
+ "barcode",
+ "catalognum",
+ "country",
+ "media",
+ "albumdisambig",
+ ]
for field in fields:
values = [item[field] for item in items if item]
likelies[field], freq = plurality(values)
- consensus[field] = (freq == len(values))
+ consensus[field] = freq == len(values)
# If there's an album artist consensus, use this for the artist.
- if consensus['albumartist'] and likelies['albumartist']:
- likelies['artist'] = likelies['albumartist']
+ if consensus["albumartist"] and likelies["albumartist"]:
+ likelies["artist"] = likelies["albumartist"]
return likelies, consensus
-def assign_items(items, tracks):
+def assign_items(
+ items: Sequence[Item],
+ tracks: Sequence[TrackInfo],
+) -> Tuple[Dict[Item, TrackInfo], List[Item], List[TrackInfo]]:
"""Given a list of Items and a list of TrackInfo objects, find the
best mapping between them. Returns a mapping from Items to TrackInfo
objects, a set of extra Items, and a set of extra TrackInfo
@@ -93,17 +131,17 @@ def assign_items(items, tracks):
of objects of the two types.
"""
# Construct the cost matrix.
- costs = []
+ costs: List[List[Distance]] = []
for item in items:
row = []
- for i, track in enumerate(tracks):
+ for track in tracks:
row.append(track_distance(item, track))
costs.append(row)
# Find a minimum-cost bipartite matching.
- log.debug('Computing track assignment...')
+ log.debug("Computing track assignment...")
matching = Munkres().compute(costs)
- log.debug('...done.')
+ log.debug("...done.")
# Produce the output matching.
mapping = {items[i]: tracks[j] for (i, j) in matching}
@@ -114,14 +152,18 @@ def assign_items(items, tracks):
return mapping, extra_items, extra_tracks
-def track_index_changed(item, track_info):
+def track_index_changed(item: Item, track_info: TrackInfo) -> bool:
"""Returns True if the item and track info index is different. Tolerates
per disc and per release numbering.
"""
return item.track not in (track_info.medium_index, track_info.index)
-def track_distance(item, track_info, incl_artist=False):
+def track_distance(
+ item: Item,
+ track_info: TrackInfo,
+ incl_artist: bool = False,
+) -> Distance:
"""Determines the significance of a track metadata change. Returns a
Distance object. `incl_artist` indicates that a distance component should
be included for the track artist (i.e., for various-artist releases).
@@ -130,26 +172,37 @@ def track_distance(item, track_info, incl_artist=False):
# Length.
if track_info.length:
- diff = abs(item.length - track_info.length) - \
- config['match']['track_length_grace'].as_number()
- dist.add_ratio('track_length', diff,
- config['match']['track_length_max'].as_number())
+ item_length = cast(float, item.length)
+ track_length_grace = cast(
+ Union[float, int],
+ config["match"]["track_length_grace"].as_number(),
+ )
+ track_length_max = cast(
+ Union[float, int],
+ config["match"]["track_length_max"].as_number(),
+ )
+
+ diff = abs(item_length - track_info.length) - track_length_grace
+ dist.add_ratio("track_length", diff, track_length_max)
# Title.
- dist.add_string('track_title', item.title, track_info.title)
+ dist.add_string("track_title", item.title, track_info.title)
# Artist. Only check if there is actually an artist in the track data.
- if incl_artist and track_info.artist and \
- item.artist.lower() not in VA_ARTISTS:
- dist.add_string('track_artist', item.artist, track_info.artist)
+ if (
+ incl_artist
+ and track_info.artist
+ and item.artist.lower() not in VA_ARTISTS
+ ):
+ dist.add_string("track_artist", item.artist, track_info.artist)
# Track index.
if track_info.index and item.track:
- dist.add_expr('track_index', track_index_changed(item, track_info))
+ dist.add_expr("track_index", track_index_changed(item, track_info))
# Track ID.
if item.mb_trackid:
- dist.add_expr('track_id', item.mb_trackid != track_info.track_id)
+ dist.add_expr("track_id", item.mb_trackid != track_info.track_id)
# Plugins.
dist.update(plugins.track_distance(item, track_info))
@@ -157,7 +210,11 @@ def track_distance(item, track_info, incl_artist=False):
return dist
-def distance(items, album_info, mapping):
+def distance(
+ items: Sequence[Item],
+ album_info: AlbumInfo,
+ mapping: Dict[Item, TrackInfo],
+) -> Distance:
"""Determines how "significant" an album metadata change would be.
Returns a Distance object. `album_info` is an AlbumInfo object
reflecting the album to be compared. `items` is a sequence of all
@@ -172,90 +229,96 @@ def distance(items, album_info, mapping):
# Artist, if not various.
if not album_info.va:
- dist.add_string('artist', likelies['artist'], album_info.artist)
+ dist.add_string("artist", likelies["artist"], album_info.artist)
# Album.
- dist.add_string('album', likelies['album'], album_info.album)
+ dist.add_string("album", likelies["album"], album_info.album)
# Current or preferred media.
if album_info.media:
# Preferred media options.
- patterns = config['match']['preferred']['media'].as_str_seq()
- options = [re.compile(r'(\d+x)?(%s)' % pat, re.I) for pat in patterns]
+ patterns = config["match"]["preferred"]["media"].as_str_seq()
+ patterns = cast(Sequence[str], patterns)
+ options = [re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in patterns]
if options:
- dist.add_priority('media', album_info.media, options)
+ dist.add_priority("media", album_info.media, options)
# Current media.
- elif likelies['media']:
- dist.add_equality('media', album_info.media, likelies['media'])
+ elif likelies["media"]:
+ dist.add_equality("media", album_info.media, likelies["media"])
# Mediums.
- if likelies['disctotal'] and album_info.mediums:
- dist.add_number('mediums', likelies['disctotal'], album_info.mediums)
+ if likelies["disctotal"] and album_info.mediums:
+ dist.add_number("mediums", likelies["disctotal"], album_info.mediums)
# Prefer earliest release.
- if album_info.year and config['match']['preferred']['original_year']:
+ if album_info.year and config["match"]["preferred"]["original_year"]:
# Assume 1889 (earliest first gramophone discs) if we don't know the
# original year.
original = album_info.original_year or 1889
diff = abs(album_info.year - original)
diff_max = abs(datetime.date.today().year - original)
- dist.add_ratio('year', diff, diff_max)
+ dist.add_ratio("year", diff, diff_max)
# Year.
- elif likelies['year'] and album_info.year:
- if likelies['year'] in (album_info.year, album_info.original_year):
+ elif likelies["year"] and album_info.year:
+ if likelies["year"] in (album_info.year, album_info.original_year):
# No penalty for matching release or original year.
- dist.add('year', 0.0)
+ dist.add("year", 0.0)
elif album_info.original_year:
# Prefer matchest closest to the release year.
- diff = abs(likelies['year'] - album_info.year)
- diff_max = abs(datetime.date.today().year -
- album_info.original_year)
- dist.add_ratio('year', diff, diff_max)
+ diff = abs(likelies["year"] - album_info.year)
+ diff_max = abs(
+ datetime.date.today().year - album_info.original_year
+ )
+ dist.add_ratio("year", diff, diff_max)
else:
# Full penalty when there is no original year.
- dist.add('year', 1.0)
+ dist.add("year", 1.0)
# Preferred countries.
- patterns = config['match']['preferred']['countries'].as_str_seq()
+ patterns = config["match"]["preferred"]["countries"].as_str_seq()
+ patterns = cast(Sequence[str], patterns)
options = [re.compile(pat, re.I) for pat in patterns]
if album_info.country and options:
- dist.add_priority('country', album_info.country, options)
+ dist.add_priority("country", album_info.country, options)
# Country.
- elif likelies['country'] and album_info.country:
- dist.add_string('country', likelies['country'], album_info.country)
+ elif likelies["country"] and album_info.country:
+ dist.add_string("country", likelies["country"], album_info.country)
# Label.
- if likelies['label'] and album_info.label:
- dist.add_string('label', likelies['label'], album_info.label)
+ if likelies["label"] and album_info.label:
+ dist.add_string("label", likelies["label"], album_info.label)
# Catalog number.
- if likelies['catalognum'] and album_info.catalognum:
- dist.add_string('catalognum', likelies['catalognum'],
- album_info.catalognum)
+ if likelies["catalognum"] and album_info.catalognum:
+ dist.add_string(
+ "catalognum", likelies["catalognum"], album_info.catalognum
+ )
# Disambiguation.
- if likelies['albumdisambig'] and album_info.albumdisambig:
- dist.add_string('albumdisambig', likelies['albumdisambig'],
- album_info.albumdisambig)
+ if likelies["albumdisambig"] and album_info.albumdisambig:
+ dist.add_string(
+ "albumdisambig", likelies["albumdisambig"], album_info.albumdisambig
+ )
# Album ID.
- if likelies['mb_albumid']:
- dist.add_equality('album_id', likelies['mb_albumid'],
- album_info.album_id)
+ if likelies["mb_albumid"]:
+ dist.add_equality(
+ "album_id", likelies["mb_albumid"], album_info.album_id
+ )
# Tracks.
dist.tracks = {}
for item, track in mapping.items():
dist.tracks[track] = track_distance(item, track, album_info.va)
- dist.add('tracks', dist.tracks[track].distance)
+ dist.add("tracks", dist.tracks[track].distance)
# Missing tracks.
- for i in range(len(album_info.tracks) - len(mapping)):
- dist.add('missing_tracks', 1.0)
+ for _ in range(len(album_info.tracks) - len(mapping)):
+ dist.add("missing_tracks", 1.0)
# Unmatched tracks.
- for i in range(len(items) - len(mapping)):
- dist.add('unmatched_tracks', 1.0)
+ for _ in range(len(items) - len(mapping)):
+ dist.add("unmatched_tracks", 1.0)
# Plugins.
dist.update(plugins.album_distance(items, album_info, mapping))
@@ -263,7 +326,7 @@ def distance(items, album_info, mapping):
return dist
-def match_by_id(items):
+def match_by_id(items: Iterable[Item]):
"""If the items are tagged with a MusicBrainz album ID, returns an
AlbumInfo object for the corresponding album. Otherwise, returns
None.
@@ -274,20 +337,22 @@ def match_by_id(items):
try:
first = next(albumids)
except StopIteration:
- log.debug('No album ID found.')
+ log.debug("No album ID found.")
return None
# Is there a consensus on the MB album ID?
for other in albumids:
if other != first:
- log.debug('No album ID consensus.')
+ log.debug("No album ID consensus.")
return None
# If all album IDs are equal, look up the album.
- log.debug('Searching for discovered album ID: {0}', first)
+ log.debug("Searching for discovered album ID: {0}", first)
return hooks.album_for_mbid(first)
-def _recommendation(results):
+def _recommendation(
+ results: Sequence[Union[AlbumMatch, TrackMatch]],
+) -> Recommendation:
"""Given a sorted list of AlbumMatch or TrackMatch objects, return a
recommendation based on the results' distances.
@@ -301,17 +366,19 @@ def _recommendation(results):
# Basic distance thresholding.
min_dist = results[0].distance
- if min_dist < config['match']['strong_rec_thresh'].as_number():
+ if min_dist < config["match"]["strong_rec_thresh"].as_number():
# Strong recommendation level.
rec = Recommendation.strong
- elif min_dist <= config['match']['medium_rec_thresh'].as_number():
+ elif min_dist <= config["match"]["medium_rec_thresh"].as_number():
# Medium recommendation level.
rec = Recommendation.medium
elif len(results) == 1:
# Only a single candidate.
rec = Recommendation.low
- elif results[1].distance - min_dist >= \
- config['match']['rec_gap_thresh'].as_number():
+ elif (
+ results[1].distance - min_dist
+ >= config["match"]["rec_gap_thresh"].as_number()
+ ):
# Gap between first two candidates is large.
rec = Recommendation.low
else:
@@ -324,48 +391,60 @@ def _recommendation(results):
if isinstance(results[0], hooks.AlbumMatch):
for track_dist in min_dist.tracks.values():
keys.update(list(track_dist.keys()))
- max_rec_view = config['match']['max_rec']
+ max_rec_view = config["match"]["max_rec"]
for key in keys:
if key in list(max_rec_view.keys()):
- max_rec = max_rec_view[key].as_choice({
- 'strong': Recommendation.strong,
- 'medium': Recommendation.medium,
- 'low': Recommendation.low,
- 'none': Recommendation.none,
- })
+ max_rec = max_rec_view[key].as_choice(
+ {
+ "strong": Recommendation.strong,
+ "medium": Recommendation.medium,
+ "low": Recommendation.low,
+ "none": Recommendation.none,
+ }
+ )
rec = min(rec, max_rec)
return rec
-def _sort_candidates(candidates):
+AnyMatch = TypeVar("AnyMatch", TrackMatch, AlbumMatch)
+
+
+def _sort_candidates(candidates: Iterable[AnyMatch]) -> Sequence[AnyMatch]:
"""Sort candidates by distance."""
return sorted(candidates, key=lambda match: match.distance)
-def _add_candidate(items, results, info):
+def _add_candidate(
+ items: Sequence[Item],
+ results: Dict[Any, AlbumMatch],
+ info: AlbumInfo,
+):
"""Given a candidate AlbumInfo object, attempt to add the candidate
to the output dictionary of AlbumMatch objects. This involves
checking the track count, ordering the items, checking for
duplicates, and calculating the distance.
"""
- log.debug('Candidate: {0} - {1} ({2})',
- info.artist, info.album, info.album_id)
+ log.debug(
+ "Candidate: {0} - {1} ({2})", info.artist, info.album, info.album_id
+ )
# Discard albums with zero tracks.
if not info.tracks:
- log.debug('No tracks.')
+ log.debug("No tracks.")
return
- # Don't duplicate.
- if info.album_id in results:
- log.debug('Duplicate.')
+ # Prevent duplicates.
+ if info.album_id and info.album_id in results:
+ log.debug("Duplicate.")
return
# Discard matches without required tags.
- for req_tag in config['match']['required'].as_str_seq():
+ for req_tag in cast(
+ Sequence[str], config["match"]["required"].as_str_seq()
+ ):
if getattr(info, req_tag) is None:
- log.debug('Ignored. Missing required tag: {0}', req_tag)
+ log.debug("Ignored. Missing required tag: {0}", req_tag)
return
# Find mapping between the items and the track info.
@@ -376,18 +455,24 @@ def _add_candidate(items, results, info):
# Skip matches with ignored penalties.
penalties = [key for key, _ in dist]
- for penalty in config['match']['ignored'].as_str_seq():
+ ignored = cast(Sequence[str], config["match"]["ignored"].as_str_seq())
+ for penalty in ignored:
if penalty in penalties:
- log.debug('Ignored. Penalty: {0}', penalty)
+ log.debug("Ignored. Penalty: {0}", penalty)
return
- log.debug('Success. Distance: {0}', dist)
- results[info.album_id] = hooks.AlbumMatch(dist, info, mapping,
- extra_items, extra_tracks)
+ log.debug("Success. Distance: {0}", dist)
+ results[info.album_id] = hooks.AlbumMatch(
+ dist, info, mapping, extra_items, extra_tracks
+ )
-def tag_album(items, search_artist=None, search_album=None,
- search_ids=[]):
+def tag_album(
+ items,
+ search_artist: Optional[str] = None,
+ search_album: Optional[str] = None,
+ search_ids: List[str] = [],
+) -> Tuple[str, str, Proposal]:
"""Return a tuple of the current artist name, the current album
name, and a `Proposal` containing `AlbumMatch` candidates.
@@ -407,20 +492,19 @@ def tag_album(items, search_artist=None, search_album=None,
"""
# Get current metadata.
likelies, consensus = current_metadata(items)
- cur_artist = likelies['artist']
- cur_album = likelies['album']
- log.debug('Tagging {0} - {1}', cur_artist, cur_album)
+ cur_artist = cast(str, likelies["artist"])
+ cur_album = cast(str, likelies["album"])
+ log.debug("Tagging {0} - {1}", cur_artist, cur_album)
- # The output result (distance, AlbumInfo) tuples (keyed by MB album
- # ID).
- candidates = {}
+ # The output result, keys are the MB album ID.
+ candidates: Dict[Any, AlbumMatch] = {}
# Search by explicit ID.
if search_ids:
for search_id in search_ids:
- log.debug('Searching for album ID: {0}', search_id)
- for id_candidate in hooks.albums_for_id(search_id):
- _add_candidate(items, candidates, id_candidate)
+ log.debug("Searching for album ID: {0}", search_id)
+ for album_info_for_id in hooks.albums_for_id(search_id):
+ _add_candidate(items, candidates, album_info_for_id)
# Use existing metadata or text search.
else:
@@ -429,51 +513,58 @@ def tag_album(items, search_artist=None, search_album=None,
if id_info:
_add_candidate(items, candidates, id_info)
rec = _recommendation(list(candidates.values()))
- log.debug('Album ID match recommendation is {0}', rec)
- if candidates and not config['import']['timid']:
+ log.debug("Album ID match recommendation is {0}", rec)
+ if candidates and not config["import"]["timid"]:
# If we have a very good MBID match, return immediately.
# Otherwise, this match will compete against metadata-based
# matches.
if rec == Recommendation.strong:
- log.debug('ID match.')
- return cur_artist, cur_album, \
- Proposal(list(candidates.values()), rec)
+ log.debug("ID match.")
+ return (
+ cur_artist,
+ cur_album,
+ Proposal(list(candidates.values()), rec),
+ )
# Search terms.
if not (search_artist and search_album):
# No explicit search terms -- use current metadata.
search_artist, search_album = cur_artist, cur_album
- log.debug('Search terms: {0} - {1}', search_artist, search_album)
+ log.debug("Search terms: {0} - {1}", search_artist, search_album)
extra_tags = None
- if config['musicbrainz']['extra_tags']:
- tag_list = config['musicbrainz']['extra_tags'].get()
+ if config["musicbrainz"]["extra_tags"]:
+ tag_list = config["musicbrainz"]["extra_tags"].get()
extra_tags = {k: v for (k, v) in likelies.items() if k in tag_list}
- log.debug('Additional search terms: {0}', extra_tags)
+ log.debug("Additional search terms: {0}", extra_tags)
# Is this album likely to be a "various artist" release?
- va_likely = ((not consensus['artist']) or
- (search_artist.lower() in VA_ARTISTS) or
- any(item.comp for item in items))
- log.debug('Album might be VA: {0}', va_likely)
+ va_likely = (
+ (not consensus["artist"])
+ or (search_artist.lower() in VA_ARTISTS)
+ or any(item.comp for item in items)
+ )
+ log.debug("Album might be VA: {0}", va_likely)
# Get the results from the data sources.
- for matched_candidate in hooks.album_candidates(items,
- search_artist,
- search_album,
- va_likely,
- extra_tags):
+ for matched_candidate in hooks.album_candidates(
+ items, search_artist, search_album, va_likely, extra_tags
+ ):
_add_candidate(items, candidates, matched_candidate)
- log.debug('Evaluating {0} candidates.', len(candidates))
+ log.debug("Evaluating {0} candidates.", len(candidates))
# Sort and get the recommendation.
- candidates = _sort_candidates(candidates.values())
- rec = _recommendation(candidates)
- return cur_artist, cur_album, Proposal(candidates, rec)
+ candidates_sorted = _sort_candidates(candidates.values())
+ rec = _recommendation(candidates_sorted)
+ return cur_artist, cur_album, Proposal(candidates_sorted, rec)
-def tag_item(item, search_artist=None, search_title=None,
- search_ids=[]):
+def tag_item(
+ item,
+ search_artist: Optional[str] = None,
+ search_title: Optional[str] = None,
+ search_ids: Optional[List[str]] = None,
+) -> Proposal:
"""Find metadata for a single track. Return a `Proposal` consisting
of `TrackMatch` objects.
@@ -485,26 +576,31 @@ def tag_item(item, search_artist=None, search_title=None,
# Holds candidates found so far: keys are MBIDs; values are
# (distance, TrackInfo) pairs.
candidates = {}
+ rec: Optional[Recommendation] = None
# First, try matching by MusicBrainz ID.
trackids = search_ids or [t for t in [item.mb_trackid] if t]
if trackids:
for trackid in trackids:
- log.debug('Searching for track ID: {0}', trackid)
+ log.debug("Searching for track ID: {0}", trackid)
for track_info in hooks.tracks_for_id(trackid):
dist = track_distance(item, track_info, incl_artist=True)
- candidates[track_info.track_id] = \
- hooks.TrackMatch(dist, track_info)
+ candidates[track_info.track_id] = hooks.TrackMatch(
+ dist, track_info
+ )
# If this is a good match, then don't keep searching.
rec = _recommendation(_sort_candidates(candidates.values()))
- if rec == Recommendation.strong and \
- not config['import']['timid']:
- log.debug('Track ID match.')
+ if (
+ rec == Recommendation.strong
+ and not config["import"]["timid"]
+ ):
+ log.debug("Track ID match.")
return Proposal(_sort_candidates(candidates.values()), rec)
# If we're searching by ID, don't proceed.
if search_ids:
if candidates:
+ assert rec is not None
return Proposal(_sort_candidates(candidates.values()), rec)
else:
return Proposal([], Recommendation.none)
@@ -512,7 +608,7 @@ def tag_item(item, search_artist=None, search_title=None,
# Search terms.
if not (search_artist and search_title):
search_artist, search_title = item.artist, item.title
- log.debug('Item search terms: {0} - {1}', search_artist, search_title)
+ log.debug("Item search terms: {0} - {1}", search_artist, search_title)
# Get and evaluate candidate metadata.
for track_info in hooks.item_candidates(item, search_artist, search_title):
@@ -520,7 +616,7 @@ def tag_item(item, search_artist=None, search_title=None,
candidates[track_info.track_id] = hooks.TrackMatch(dist, track_info)
# Sort by distance and return with recommendation.
- log.debug('Found {0} candidates.', len(candidates))
- candidates = _sort_candidates(candidates.values())
- rec = _recommendation(candidates)
- return Proposal(candidates, rec)
+ log.debug("Found {0} candidates.", len(candidates))
+ candidates_sorted = _sort_candidates(candidates.values())
+ rec = _recommendation(candidates_sorted)
+ return Proposal(candidates_sorted, rec)
diff --git a/lib/beets/autotag/mb.py b/lib/beets/autotag/mb.py
index e6a2e277..0d0eb975 100644
--- a/lib/beets/autotag/mb.py
+++ b/lib/beets/autotag/mb.py
@@ -14,36 +14,43 @@
"""Searches for albums in the MusicBrainz database.
"""
+from __future__ import annotations
-import musicbrainzngs
import re
import traceback
-
-from beets import logging
-from beets import plugins
-import beets.autotag.hooks
-import beets
-from beets import util
-from beets import config
from collections import Counter
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, cast
from urllib.parse import urljoin
-VARIOUS_ARTISTS_ID = '89ad4ac3-39f7-470e-963a-56509c546377'
+import musicbrainzngs
-BASE_URL = 'https://musicbrainz.org/'
+import beets
+import beets.autotag.hooks
+from beets import config, logging, plugins, util
+from beets.plugins import MetadataSourcePlugin
+from beets.util.id_extractors import (
+ beatport_id_regex,
+ deezer_id_regex,
+ extract_discogs_id_regex,
+ spotify_id_regex,
+)
-SKIPPED_TRACKS = ['[data track]']
+VARIOUS_ARTISTS_ID = "89ad4ac3-39f7-470e-963a-56509c546377"
+
+BASE_URL = "https://musicbrainz.org/"
+
+SKIPPED_TRACKS = ["[data track]"]
FIELDS_TO_MB_KEYS = {
- 'catalognum': 'catno',
- 'country': 'country',
- 'label': 'label',
- 'media': 'format',
- 'year': 'date',
+ "catalognum": "catno",
+ "country": "country",
+ "label": "label",
+ "barcode": "barcode",
+ "media": "format",
+ "year": "date",
}
-musicbrainzngs.set_useragent('beets', beets.__version__,
- 'https://beets.io/')
+musicbrainzngs.set_useragent("beets", beets.__version__, "https://beets.io/")
class MusicBrainzAPIError(util.HumanReadableException):
@@ -54,59 +61,76 @@ class MusicBrainzAPIError(util.HumanReadableException):
def __init__(self, reason, verb, query, tb=None):
self.query = query
if isinstance(reason, musicbrainzngs.WebServiceError):
- reason = 'MusicBrainz not reachable'
+ reason = "MusicBrainz not reachable"
super().__init__(reason, verb, tb)
def get_message(self):
- return '{} in {} with query {}'.format(
+ return "{} in {} with query {}".format(
self._reasonstr(), self.verb, repr(self.query)
)
-log = logging.getLogger('beets')
+log = logging.getLogger("beets")
-RELEASE_INCLUDES = ['artists', 'media', 'recordings', 'release-groups',
- 'labels', 'artist-credits', 'aliases',
- 'recording-level-rels', 'work-rels',
- 'work-level-rels', 'artist-rels', 'isrcs']
-BROWSE_INCLUDES = ['artist-credits', 'work-rels',
- 'artist-rels', 'recording-rels', 'release-rels']
-if "work-level-rels" in musicbrainzngs.VALID_BROWSE_INCLUDES['recording']:
+RELEASE_INCLUDES = [
+ "artists",
+ "media",
+ "recordings",
+ "release-groups",
+ "labels",
+ "artist-credits",
+ "aliases",
+ "recording-level-rels",
+ "work-rels",
+ "work-level-rels",
+ "artist-rels",
+ "isrcs",
+ "url-rels",
+ "release-rels",
+]
+BROWSE_INCLUDES = [
+ "artist-credits",
+ "work-rels",
+ "artist-rels",
+ "recording-rels",
+ "release-rels",
+]
+if "work-level-rels" in musicbrainzngs.VALID_BROWSE_INCLUDES["recording"]:
BROWSE_INCLUDES.append("work-level-rels")
BROWSE_CHUNKSIZE = 100
BROWSE_MAXTRACKS = 500
-TRACK_INCLUDES = ['artists', 'aliases', 'isrcs']
-if 'work-level-rels' in musicbrainzngs.VALID_INCLUDES['recording']:
- TRACK_INCLUDES += ['work-level-rels', 'artist-rels']
-if 'genres' in musicbrainzngs.VALID_INCLUDES['recording']:
- RELEASE_INCLUDES += ['genres']
+TRACK_INCLUDES = ["artists", "aliases", "isrcs"]
+if "work-level-rels" in musicbrainzngs.VALID_INCLUDES["recording"]:
+ TRACK_INCLUDES += ["work-level-rels", "artist-rels"]
+if "genres" in musicbrainzngs.VALID_INCLUDES["recording"]:
+ RELEASE_INCLUDES += ["genres"]
-def track_url(trackid):
- return urljoin(BASE_URL, 'recording/' + trackid)
+def track_url(trackid: str) -> str:
+ return urljoin(BASE_URL, "recording/" + trackid)
-def album_url(albumid):
- return urljoin(BASE_URL, 'release/' + albumid)
+def album_url(albumid: str) -> str:
+ return urljoin(BASE_URL, "release/" + albumid)
def configure():
"""Set up the python-musicbrainz-ngs module according to settings
from the beets configuration. This should be called at startup.
"""
- hostname = config['musicbrainz']['host'].as_str()
- https = config['musicbrainz']['https'].get(bool)
+ hostname = config["musicbrainz"]["host"].as_str()
+ https = config["musicbrainz"]["https"].get(bool)
# Only call set_hostname when a custom server is configured. Since
# musicbrainz-ngs connects to musicbrainz.org with HTTPS by default
if hostname != "musicbrainz.org":
musicbrainzngs.set_hostname(hostname, https)
musicbrainzngs.set_rate_limit(
- config['musicbrainz']['ratelimit_interval'].as_number(),
- config['musicbrainz']['ratelimit'].get(int),
+ config["musicbrainz"]["ratelimit_interval"].as_number(),
+ config["musicbrainz"]["ratelimit"].get(int),
)
-def _preferred_alias(aliases):
+def _preferred_alias(aliases: List):
"""Given an list of alias structures for an artist credit, select
and return the user's preferred alias alias or None if no matching
alias is found.
@@ -115,13 +139,25 @@ def _preferred_alias(aliases):
return
# Only consider aliases that have locales set.
- aliases = [a for a in aliases if 'locale' in a]
+ aliases = [a for a in aliases if "locale" in a]
+
+ # Get any ignored alias types and lower case them to prevent case issues
+ ignored_alias_types = config["import"]["ignored_alias_types"].as_str_seq()
+ ignored_alias_types = [a.lower() for a in ignored_alias_types]
# Search configured locales in order.
- for locale in config['import']['languages'].as_str_seq():
- # Find matching primary aliases for this locale.
- matches = [a for a in aliases
- if a['locale'] == locale and 'primary' in a]
+ for locale in config["import"]["languages"].as_str_seq():
+ # Find matching primary aliases for this locale that are not
+ # being ignored
+ matches = []
+ for a in aliases:
+ if (
+ a["locale"] == locale
+ and "primary" in a
+ and a.get("type", "").lower() not in ignored_alias_types
+ ):
+ matches.append(a)
+
# Skip to the next locale if we have no matches
if not matches:
continue
@@ -129,27 +165,30 @@ def _preferred_alias(aliases):
return matches[0]
-def _preferred_release_event(release):
+def _preferred_release_event(release: Dict[str, Any]) -> Tuple[str, str]:
"""Given a release, select and return the user's preferred release
event as a tuple of (country, release_date). Fall back to the
default release event if a preferred event is not found.
"""
- countries = config['match']['preferred']['countries'].as_str_seq()
+ countries = config["match"]["preferred"]["countries"].as_str_seq()
+ countries = cast(Sequence, countries)
for country in countries:
- for event in release.get('release-event-list', {}):
+ for event in release.get("release-event-list", {}):
try:
- if country in event['area']['iso-3166-1-code-list']:
- return country, event['date']
+ if country in event["area"]["iso-3166-1-code-list"]:
+ return country, event["date"]
except KeyError:
pass
- return release.get('country'), release.get('date')
+ return (cast(str, release.get("country")), cast(str, release.get("date")))
-def _flatten_artist_credit(credit):
- """Given a list representing an ``artist-credit`` block, flatten the
- data into a triple of joined artist name strings: canonical, sort, and
+def _multi_artist_credit(
+ credit: List[Dict], include_join_phrase: bool
+) -> Tuple[List[str], List[str], List[str]]:
+ """Given a list representing an ``artist-credit`` block, accumulate
+ data into a triple of joined artist name lists: canonical, sort, and
credit.
"""
artist_parts = []
@@ -158,43 +197,90 @@ def _flatten_artist_credit(credit):
for el in credit:
if isinstance(el, str):
# Join phrase.
- artist_parts.append(el)
- artist_credit_parts.append(el)
- artist_sort_parts.append(el)
+ if include_join_phrase:
+ artist_parts.append(el)
+ artist_credit_parts.append(el)
+ artist_sort_parts.append(el)
else:
- alias = _preferred_alias(el['artist'].get('alias-list', ()))
+ alias = _preferred_alias(el["artist"].get("alias-list", ()))
# An artist.
if alias:
- cur_artist_name = alias['alias']
+ cur_artist_name = alias["alias"]
else:
- cur_artist_name = el['artist']['name']
+ cur_artist_name = el["artist"]["name"]
artist_parts.append(cur_artist_name)
# Artist sort name.
if alias:
- artist_sort_parts.append(alias['sort-name'])
- elif 'sort-name' in el['artist']:
- artist_sort_parts.append(el['artist']['sort-name'])
+ artist_sort_parts.append(alias["sort-name"])
+ elif "sort-name" in el["artist"]:
+ artist_sort_parts.append(el["artist"]["sort-name"])
else:
artist_sort_parts.append(cur_artist_name)
# Artist credit.
- if 'name' in el:
- artist_credit_parts.append(el['name'])
+ if "name" in el:
+ artist_credit_parts.append(el["name"])
else:
artist_credit_parts.append(cur_artist_name)
return (
- ''.join(artist_parts),
- ''.join(artist_sort_parts),
- ''.join(artist_credit_parts),
+ artist_parts,
+ artist_sort_parts,
+ artist_credit_parts,
)
-def track_info(recording, index=None, medium=None, medium_index=None,
- medium_total=None):
+def _flatten_artist_credit(credit: List[Dict]) -> Tuple[str, str, str]:
+ """Given a list representing an ``artist-credit`` block, flatten the
+ data into a triple of joined artist name strings: canonical, sort, and
+ credit.
+ """
+ artist_parts, artist_sort_parts, artist_credit_parts = _multi_artist_credit(
+ credit, include_join_phrase=True
+ )
+ return (
+ "".join(artist_parts),
+ "".join(artist_sort_parts),
+ "".join(artist_credit_parts),
+ )
+
+
+def _artist_ids(credit: List[Dict]) -> List[str]:
+ """
+ Given a list representing an ``artist-credit``,
+ return a list of artist IDs
+ """
+ artist_ids: List[str] = []
+ for el in credit:
+ if isinstance(el, dict):
+ artist_ids.append(el["artist"]["id"])
+
+ return artist_ids
+
+
+def _get_related_artist_names(relations, relation_type):
+ """Given a list representing the artist relationships extract the names of
+ the remixers and concatenate them.
+ """
+ related_artists = []
+
+ for relation in relations:
+ if relation["type"] == relation_type:
+ related_artists.append(relation["artist"]["name"])
+
+ return ", ".join(related_artists)
+
+
+def track_info(
+ recording: Dict,
+ index: Optional[int] = None,
+ medium: Optional[int] = None,
+ medium_index: Optional[int] = None,
+ medium_total: Optional[int] = None,
+) -> beets.autotag.hooks.TrackInfo:
"""Translates a MusicBrainz recording result dictionary into a beets
``TrackInfo`` object. Three parameters are optional and are used
only for tracks that appear on releases (non-singletons): ``index``,
@@ -203,86 +289,104 @@ def track_info(recording, index=None, medium=None, medium_index=None,
the number of tracks on the medium. Each number is a 1-based index.
"""
info = beets.autotag.hooks.TrackInfo(
- title=recording['title'],
- track_id=recording['id'],
+ title=recording["title"],
+ track_id=recording["id"],
index=index,
medium=medium,
medium_index=medium_index,
medium_total=medium_total,
- data_source='MusicBrainz',
- data_url=track_url(recording['id']),
+ data_source="MusicBrainz",
+ data_url=track_url(recording["id"]),
)
- if recording.get('artist-credit'):
+ if recording.get("artist-credit"):
# Get the artist names.
- info.artist, info.artist_sort, info.artist_credit = \
- _flatten_artist_credit(recording['artist-credit'])
+ (
+ info.artist,
+ info.artist_sort,
+ info.artist_credit,
+ ) = _flatten_artist_credit(recording["artist-credit"])
- # Get the ID and sort name of the first artist.
- artist = recording['artist-credit'][0]['artist']
- info.artist_id = artist['id']
+ (
+ info.artists,
+ info.artists_sort,
+ info.artists_credit,
+ ) = _multi_artist_credit(
+ recording["artist-credit"], include_join_phrase=False
+ )
- if recording.get('length'):
- info.length = int(recording['length']) / (1000.0)
+ info.artists_ids = _artist_ids(recording["artist-credit"])
+ info.artist_id = info.artists_ids[0]
- info.trackdisambig = recording.get('disambiguation')
+ if recording.get("artist-relation-list"):
+ info.remixer = _get_related_artist_names(
+ recording["artist-relation-list"], relation_type="remixer"
+ )
- if recording.get('isrc-list'):
- info.isrc = ';'.join(recording['isrc-list'])
+ if recording.get("length"):
+ info.length = int(recording["length"]) / 1000.0
+
+ info.trackdisambig = recording.get("disambiguation")
+
+ if recording.get("isrc-list"):
+ info.isrc = ";".join(recording["isrc-list"])
lyricist = []
composer = []
composer_sort = []
- for work_relation in recording.get('work-relation-list', ()):
- if work_relation['type'] != 'performance':
+ for work_relation in recording.get("work-relation-list", ()):
+ if work_relation["type"] != "performance":
continue
- info.work = work_relation['work']['title']
- info.mb_workid = work_relation['work']['id']
- if 'disambiguation' in work_relation['work']:
- info.work_disambig = work_relation['work']['disambiguation']
+ info.work = work_relation["work"]["title"]
+ info.mb_workid = work_relation["work"]["id"]
+ if "disambiguation" in work_relation["work"]:
+ info.work_disambig = work_relation["work"]["disambiguation"]
- for artist_relation in work_relation['work'].get(
- 'artist-relation-list', ()):
- if 'type' in artist_relation:
- type = artist_relation['type']
- if type == 'lyricist':
- lyricist.append(artist_relation['artist']['name'])
- elif type == 'composer':
- composer.append(artist_relation['artist']['name'])
- composer_sort.append(
- artist_relation['artist']['sort-name'])
+ for artist_relation in work_relation["work"].get(
+ "artist-relation-list", ()
+ ):
+ if "type" in artist_relation:
+ type = artist_relation["type"]
+ if type == "lyricist":
+ lyricist.append(artist_relation["artist"]["name"])
+ elif type == "composer":
+ composer.append(artist_relation["artist"]["name"])
+ composer_sort.append(artist_relation["artist"]["sort-name"])
if lyricist:
- info.lyricist = ', '.join(lyricist)
+ info.lyricist = ", ".join(lyricist)
if composer:
- info.composer = ', '.join(composer)
- info.composer_sort = ', '.join(composer_sort)
+ info.composer = ", ".join(composer)
+ info.composer_sort = ", ".join(composer_sort)
arranger = []
- for artist_relation in recording.get('artist-relation-list', ()):
- if 'type' in artist_relation:
- type = artist_relation['type']
- if type == 'arranger':
- arranger.append(artist_relation['artist']['name'])
+ for artist_relation in recording.get("artist-relation-list", ()):
+ if "type" in artist_relation:
+ type = artist_relation["type"]
+ if type == "arranger":
+ arranger.append(artist_relation["artist"]["name"])
if arranger:
- info.arranger = ', '.join(arranger)
+ info.arranger = ", ".join(arranger)
# Supplementary fields provided by plugins
- extra_trackdatas = plugins.send('mb_track_extract', data=recording)
+ extra_trackdatas = plugins.send("mb_track_extract", data=recording)
for extra_trackdata in extra_trackdatas:
info.update(extra_trackdata)
- info.decode()
return info
-def _set_date_str(info, date_str, original=False):
+def _set_date_str(
+ info: beets.autotag.hooks.AlbumInfo,
+ date_str: str,
+ original: bool = False,
+):
"""Given a (possibly partial) YYYY-MM-DD string and an AlbumInfo
object, set the object's release date fields appropriately. If
`original`, then set the original_year, etc., fields.
"""
if date_str:
- date_parts = date_str.split('-')
- for key in ('year', 'month', 'day'):
+ date_parts = date_str.split("-")
+ for key in ("year", "month", "day"):
if date_parts:
date_part = date_parts.pop(0)
try:
@@ -291,143 +395,184 @@ def _set_date_str(info, date_str, original=False):
continue
if original:
- key = 'original_' + key
+ key = "original_" + key
setattr(info, key, date_num)
-def album_info(release):
+def album_info(release: Dict) -> beets.autotag.hooks.AlbumInfo:
"""Takes a MusicBrainz release result dictionary and returns a beets
AlbumInfo object containing the interesting data about that release.
"""
# Get artist name using join phrases.
- artist_name, artist_sort_name, artist_credit_name = \
- _flatten_artist_credit(release['artist-credit'])
+ artist_name, artist_sort_name, artist_credit_name = _flatten_artist_credit(
+ release["artist-credit"]
+ )
- ntracks = sum(len(m['track-list']) for m in release['medium-list'])
+ (
+ artists_names,
+ artists_sort_names,
+ artists_credit_names,
+ ) = _multi_artist_credit(
+ release["artist-credit"], include_join_phrase=False
+ )
+
+ ntracks = sum(len(m["track-list"]) for m in release["medium-list"])
# The MusicBrainz API omits 'artist-relation-list' and 'work-relation-list'
# when the release has more than 500 tracks. So we use browse_recordings
# on chunks of tracks to recover the same information in this case.
if ntracks > BROWSE_MAXTRACKS:
- log.debug('Album {} has too many tracks', release['id'])
+ log.debug("Album {} has too many tracks", release["id"])
recording_list = []
for i in range(0, ntracks, BROWSE_CHUNKSIZE):
- log.debug('Retrieving tracks starting at {}', i)
- recording_list.extend(musicbrainzngs.browse_recordings(
- release=release['id'], limit=BROWSE_CHUNKSIZE,
- includes=BROWSE_INCLUDES,
- offset=i)['recording-list'])
- track_map = {r['id']: r for r in recording_list}
- for medium in release['medium-list']:
- for recording in medium['track-list']:
- recording_info = track_map[recording['recording']['id']]
- recording['recording'] = recording_info
+ log.debug("Retrieving tracks starting at {}", i)
+ recording_list.extend(
+ musicbrainzngs.browse_recordings(
+ release=release["id"],
+ limit=BROWSE_CHUNKSIZE,
+ includes=BROWSE_INCLUDES,
+ offset=i,
+ )["recording-list"]
+ )
+ track_map = {r["id"]: r for r in recording_list}
+ for medium in release["medium-list"]:
+ for recording in medium["track-list"]:
+ recording_info = track_map[recording["recording"]["id"]]
+ recording["recording"] = recording_info
# Basic info.
track_infos = []
index = 0
- for medium in release['medium-list']:
- disctitle = medium.get('title')
- format = medium.get('format')
+ for medium in release["medium-list"]:
+ disctitle = medium.get("title")
+ format = medium.get("format")
- if format in config['match']['ignored_media'].as_str_seq():
+ if format in config["match"]["ignored_media"].as_str_seq():
continue
- all_tracks = medium['track-list']
- if ('data-track-list' in medium
- and not config['match']['ignore_data_tracks']):
- all_tracks += medium['data-track-list']
+ all_tracks = medium["track-list"]
+ if (
+ "data-track-list" in medium
+ and not config["match"]["ignore_data_tracks"]
+ ):
+ all_tracks += medium["data-track-list"]
track_count = len(all_tracks)
- if 'pregap' in medium:
- all_tracks.insert(0, medium['pregap'])
+ if "pregap" in medium:
+ all_tracks.insert(0, medium["pregap"])
for track in all_tracks:
-
- if ('title' in track['recording'] and
- track['recording']['title'] in SKIPPED_TRACKS):
+ if (
+ "title" in track["recording"]
+ and track["recording"]["title"] in SKIPPED_TRACKS
+ ):
continue
- if ('video' in track['recording'] and
- track['recording']['video'] == 'true' and
- config['match']['ignore_video_tracks']):
+ if (
+ "video" in track["recording"]
+ and track["recording"]["video"] == "true"
+ and config["match"]["ignore_video_tracks"]
+ ):
continue
# Basic information from the recording.
index += 1
ti = track_info(
- track['recording'],
+ track["recording"],
index,
- int(medium['position']),
- int(track['position']),
+ int(medium["position"]),
+ int(track["position"]),
track_count,
)
- ti.release_track_id = track['id']
+ ti.release_track_id = track["id"]
ti.disctitle = disctitle
ti.media = format
- ti.track_alt = track['number']
+ ti.track_alt = track["number"]
# Prefer track data, where present, over recording data.
- if track.get('title'):
- ti.title = track['title']
- if track.get('artist-credit'):
+ if track.get("title"):
+ ti.title = track["title"]
+ if track.get("artist-credit"):
# Get the artist names.
- ti.artist, ti.artist_sort, ti.artist_credit = \
- _flatten_artist_credit(track['artist-credit'])
- ti.artist_id = track['artist-credit'][0]['artist']['id']
- if track.get('length'):
- ti.length = int(track['length']) / (1000.0)
+ (
+ ti.artist,
+ ti.artist_sort,
+ ti.artist_credit,
+ ) = _flatten_artist_credit(track["artist-credit"])
+
+ (
+ ti.artists,
+ ti.artists_sort,
+ ti.artists_credit,
+ ) = _multi_artist_credit(
+ track["artist-credit"], include_join_phrase=False
+ )
+
+ ti.artists_ids = _artist_ids(track["artist-credit"])
+ ti.artist_id = ti.artists_ids[0]
+ if track.get("length"):
+ ti.length = int(track["length"]) / (1000.0)
track_infos.append(ti)
+ album_artist_ids = _artist_ids(release["artist-credit"])
info = beets.autotag.hooks.AlbumInfo(
- album=release['title'],
- album_id=release['id'],
+ album=release["title"],
+ album_id=release["id"],
artist=artist_name,
- artist_id=release['artist-credit'][0]['artist']['id'],
+ artist_id=album_artist_ids[0],
+ artists=artists_names,
+ artists_ids=album_artist_ids,
tracks=track_infos,
- mediums=len(release['medium-list']),
+ mediums=len(release["medium-list"]),
artist_sort=artist_sort_name,
+ artists_sort=artists_sort_names,
artist_credit=artist_credit_name,
- data_source='MusicBrainz',
- data_url=album_url(release['id']),
+ artists_credit=artists_credit_names,
+ data_source="MusicBrainz",
+ data_url=album_url(release["id"]),
+ barcode=release.get("barcode"),
)
info.va = info.artist_id == VARIOUS_ARTISTS_ID
if info.va:
- info.artist = config['va_name'].as_str()
- info.asin = release.get('asin')
- info.releasegroup_id = release['release-group']['id']
- info.albumstatus = release.get('status')
+ info.artist = config["va_name"].as_str()
+ info.asin = release.get("asin")
+ info.releasegroup_id = release["release-group"]["id"]
+ info.albumstatus = release.get("status")
+
+ if release["release-group"].get("title"):
+ info.release_group_title = release["release-group"].get("title")
# Get the disambiguation strings at the release and release group level.
- if release['release-group'].get('disambiguation'):
- info.releasegroupdisambig = \
- release['release-group'].get('disambiguation')
- if release.get('disambiguation'):
- info.albumdisambig = release.get('disambiguation')
+ if release["release-group"].get("disambiguation"):
+ info.releasegroupdisambig = release["release-group"].get(
+ "disambiguation"
+ )
+ if release.get("disambiguation"):
+ info.albumdisambig = release.get("disambiguation")
# Get the "classic" Release type. This data comes from a legacy API
# feature before MusicBrainz supported multiple release types.
- if 'type' in release['release-group']:
- reltype = release['release-group']['type']
+ if "type" in release["release-group"]:
+ reltype = release["release-group"]["type"]
if reltype:
info.albumtype = reltype.lower()
# Set the new-style "primary" and "secondary" release types.
albumtypes = []
- if 'primary-type' in release['release-group']:
- rel_primarytype = release['release-group']['primary-type']
+ if "primary-type" in release["release-group"]:
+ rel_primarytype = release["release-group"]["primary-type"]
if rel_primarytype:
albumtypes.append(rel_primarytype.lower())
- if 'secondary-type-list' in release['release-group']:
- if release['release-group']['secondary-type-list']:
- for sec_type in release['release-group']['secondary-type-list']:
+ if "secondary-type-list" in release["release-group"]:
+ if release["release-group"]["secondary-type-list"]:
+ for sec_type in release["release-group"]["secondary-type-list"]:
albumtypes.append(sec_type.lower())
- info.albumtypes = '; '.join(albumtypes)
+ info.albumtypes = albumtypes
# Release events.
info.country, release_date = _preferred_release_event(release)
- release_group_date = release['release-group'].get('first-release-date')
+ release_group_date = release["release-group"].get("first-release-date")
if not release_date:
# Fall back if release-specific date is not available.
release_date = release_group_date
@@ -435,46 +580,117 @@ def album_info(release):
_set_date_str(info, release_group_date, True)
# Label name.
- if release.get('label-info-list'):
- label_info = release['label-info-list'][0]
- if label_info.get('label'):
- label = label_info['label']['name']
- if label != '[no label]':
+ if release.get("label-info-list"):
+ label_info = release["label-info-list"][0]
+ if label_info.get("label"):
+ label = label_info["label"]["name"]
+ if label != "[no label]":
info.label = label
- info.catalognum = label_info.get('catalog-number')
+ info.catalognum = label_info.get("catalog-number")
# Text representation data.
- if release.get('text-representation'):
- rep = release['text-representation']
- info.script = rep.get('script')
- info.language = rep.get('language')
+ if release.get("text-representation"):
+ rep = release["text-representation"]
+ info.script = rep.get("script")
+ info.language = rep.get("language")
# Media (format).
- if release['medium-list']:
- first_medium = release['medium-list'][0]
- info.media = first_medium.get('format')
+ if release["medium-list"]:
+ # If all media are the same, use that medium name
+ if len({m.get("format") for m in release["medium-list"]}) == 1:
+ info.media = release["medium-list"][0].get("format")
+ # Otherwise, let's just call it "Media"
+ else:
+ info.media = "Media"
- if config['musicbrainz']['genres']:
+ if config["musicbrainz"]["genres"]:
sources = [
- release['release-group'].get('genre-list', []),
- release.get('genre-list', []),
+ release["release-group"].get("genre-list", []),
+ release.get("genre-list", []),
]
- genres = Counter()
+ genres: Counter[str] = Counter()
for source in sources:
for genreitem in source:
- genres[genreitem['name']] += int(genreitem['count'])
- info.genre = '; '.join(g[0] for g in sorted(genres.items(),
- key=lambda g: -g[1]))
+ genres[genreitem["name"]] += int(genreitem["count"])
+ info.genre = "; ".join(
+ genre
+ for genre, _count in sorted(genres.items(), key=lambda g: -g[1])
+ )
- extra_albumdatas = plugins.send('mb_album_extract', data=release)
+ # We might find links to external sources (Discogs, Bandcamp, ...)
+ if any(
+ config["musicbrainz"]["external_ids"].get().values()
+ ) and release.get("url-relation-list"):
+ discogs_url, bandcamp_url, spotify_url = None, None, None
+ deezer_url, beatport_url, tidal_url = None, None, None
+ fetch_discogs, fetch_bandcamp, fetch_spotify = False, False, False
+ fetch_deezer, fetch_beatport, fetch_tidal = False, False, False
+
+ if config["musicbrainz"]["external_ids"]["discogs"].get():
+ fetch_discogs = True
+ if config["musicbrainz"]["external_ids"]["bandcamp"].get():
+ fetch_bandcamp = True
+ if config["musicbrainz"]["external_ids"]["spotify"].get():
+ fetch_spotify = True
+ if config["musicbrainz"]["external_ids"]["deezer"].get():
+ fetch_deezer = True
+ if config["musicbrainz"]["external_ids"]["beatport"].get():
+ fetch_beatport = True
+ if config["musicbrainz"]["external_ids"]["tidal"].get():
+ fetch_tidal = True
+
+ for url in release["url-relation-list"]:
+ if fetch_discogs and url["type"] == "discogs":
+ log.debug("Found link to Discogs release via MusicBrainz")
+ discogs_url = url["target"]
+ if fetch_bandcamp and "bandcamp.com" in url["target"]:
+ log.debug("Found link to Bandcamp release via MusicBrainz")
+ bandcamp_url = url["target"]
+ if fetch_spotify and "spotify.com" in url["target"]:
+ log.debug("Found link to Spotify album via MusicBrainz")
+ spotify_url = url["target"]
+ if fetch_deezer and "deezer.com" in url["target"]:
+ log.debug("Found link to Deezer album via MusicBrainz")
+ deezer_url = url["target"]
+ if fetch_beatport and "beatport.com" in url["target"]:
+ log.debug("Found link to Beatport release via MusicBrainz")
+ beatport_url = url["target"]
+ if fetch_tidal and "tidal.com" in url["target"]:
+ log.debug("Found link to Tidal release via MusicBrainz")
+ tidal_url = url["target"]
+
+ if discogs_url:
+ info.discogs_albumid = extract_discogs_id_regex(discogs_url)
+ if bandcamp_url:
+ info.bandcamp_album_id = bandcamp_url
+ if spotify_url:
+ info.spotify_album_id = MetadataSourcePlugin._get_id(
+ "album", spotify_url, spotify_id_regex
+ )
+ if deezer_url:
+ info.deezer_album_id = MetadataSourcePlugin._get_id(
+ "album", deezer_url, deezer_id_regex
+ )
+ if beatport_url:
+ info.beatport_album_id = MetadataSourcePlugin._get_id(
+ "album", beatport_url, beatport_id_regex
+ )
+ if tidal_url:
+ info.tidal_album_id = tidal_url.split("/")[-1]
+
+ extra_albumdatas = plugins.send("mb_album_extract", data=release)
for extra_albumdata in extra_albumdatas:
info.update(extra_albumdata)
- info.decode()
return info
-def match_album(artist, album, tracks=None, extra_tags=None):
+def match_album(
+ artist: str,
+ album: str,
+ tracks: Optional[int] = None,
+ extra_tags: Optional[Dict[str, Any]] = None,
+) -> Iterator[beets.autotag.hooks.AlbumInfo]:
"""Searches for a single album ("release" in MusicBrainz parlance)
and returns an iterator over AlbumInfo objects. May raise a
MusicBrainzAPIError.
@@ -483,22 +699,22 @@ def match_album(artist, album, tracks=None, extra_tags=None):
optionally, a number of tracks on the album and any other extra tags.
"""
# Build search criteria.
- criteria = {'release': album.lower().strip()}
+ criteria = {"release": album.lower().strip()}
if artist is not None:
- criteria['artist'] = artist.lower().strip()
+ criteria["artist"] = artist.lower().strip()
else:
# Various Artists search.
- criteria['arid'] = VARIOUS_ARTISTS_ID
+ criteria["arid"] = VARIOUS_ARTISTS_ID
if tracks is not None:
- criteria['tracks'] = str(tracks)
+ criteria["tracks"] = str(tracks)
# Additional search cues from existing metadata.
if extra_tags:
- for tag in extra_tags:
+ for tag, value in extra_tags.items():
key = FIELDS_TO_MB_KEYS[tag]
- value = str(extra_tags.get(tag, '')).lower().strip()
- if key == 'catno':
- value = value.replace(' ', '')
+ value = str(value).lower().strip()
+ if key == "catno":
+ value = value.replace(" ", "")
if value:
criteria[key] = value
@@ -507,27 +723,32 @@ def match_album(artist, album, tracks=None, extra_tags=None):
return
try:
- log.debug('Searching for MusicBrainz releases with: {!r}', criteria)
+ log.debug("Searching for MusicBrainz releases with: {!r}", criteria)
res = musicbrainzngs.search_releases(
- limit=config['musicbrainz']['searchlimit'].get(int), **criteria)
+ limit=config["musicbrainz"]["searchlimit"].get(int), **criteria
+ )
except musicbrainzngs.MusicBrainzError as exc:
- raise MusicBrainzAPIError(exc, 'release search', criteria,
- traceback.format_exc())
- for release in res['release-list']:
+ raise MusicBrainzAPIError(
+ exc, "release search", criteria, traceback.format_exc()
+ )
+ for release in res["release-list"]:
# The search result is missing some data (namely, the tracks),
# so we just use the ID and fetch the rest of the information.
- albuminfo = album_for_id(release['id'])
+ albuminfo = album_for_id(release["id"])
if albuminfo is not None:
yield albuminfo
-def match_track(artist, title):
+def match_track(
+ artist: str,
+ title: str,
+) -> Iterator[beets.autotag.hooks.TrackInfo]:
"""Searches for a single track and returns an iterable of TrackInfo
objects. May raise a MusicBrainzAPIError.
"""
criteria = {
- 'artist': artist.lower().strip(),
- 'recording': title.lower().strip(),
+ "artist": artist.lower().strip(),
+ "recording": title.lower().strip(),
}
if not any(criteria.values()):
@@ -535,60 +756,144 @@ def match_track(artist, title):
try:
res = musicbrainzngs.search_recordings(
- limit=config['musicbrainz']['searchlimit'].get(int), **criteria)
+ limit=config["musicbrainz"]["searchlimit"].get(int), **criteria
+ )
except musicbrainzngs.MusicBrainzError as exc:
- raise MusicBrainzAPIError(exc, 'recording search', criteria,
- traceback.format_exc())
- for recording in res['recording-list']:
+ raise MusicBrainzAPIError(
+ exc, "recording search", criteria, traceback.format_exc()
+ )
+ for recording in res["recording-list"]:
yield track_info(recording)
-def _parse_id(s):
+def _parse_id(s: str) -> Optional[str]:
"""Search for a MusicBrainz ID in the given string and return it. If
no ID can be found, return None.
"""
# Find the first thing that looks like a UUID/MBID.
- match = re.search('[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}', s)
- if match:
- return match.group()
+ match = re.search("[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}", s)
+ if match is not None:
+ return match.group() if match else None
+ return None
-def album_for_id(releaseid):
+def _is_translation(r):
+ _trans_key = "transl-tracklisting"
+ return r["type"] == _trans_key and r["direction"] == "backward"
+
+
+def _find_actual_release_from_pseudo_release(
+ pseudo_rel: Dict,
+) -> Optional[Dict]:
+ try:
+ relations = pseudo_rel["release"]["release-relation-list"]
+ except KeyError:
+ return None
+
+ # currently we only support trans(liter)ation's
+ translations = [r for r in relations if _is_translation(r)]
+
+ if not translations:
+ return None
+
+ actual_id = translations[0]["target"]
+
+ return musicbrainzngs.get_release_by_id(actual_id, RELEASE_INCLUDES)
+
+
+def _merge_pseudo_and_actual_album(
+ pseudo: beets.autotag.hooks.AlbumInfo, actual: beets.autotag.hooks.AlbumInfo
+) -> Optional[beets.autotag.hooks.AlbumInfo]:
+ """
+ Merges a pseudo release with its actual release.
+
+ This implementation is naive, it doesn't overwrite fields,
+ like status or ids.
+
+ According to the ticket PICARD-145, the main release id should be used.
+ But the ticket has been in limbo since over a decade now.
+ It also suggests the introduction of the tag `musicbrainz_pseudoreleaseid`,
+ but as of this field can't be found in any official Picard docs,
+ hence why we did not implement that for now.
+ """
+ merged = pseudo.copy()
+ from_actual = {
+ k: actual[k]
+ for k in [
+ "media",
+ "mediums",
+ "country",
+ "catalognum",
+ "year",
+ "month",
+ "day",
+ "original_year",
+ "original_month",
+ "original_day",
+ "label",
+ "barcode",
+ "asin",
+ "style",
+ "genre",
+ ]
+ }
+ merged.update(from_actual)
+ return merged
+
+
+def album_for_id(releaseid: str) -> Optional[beets.autotag.hooks.AlbumInfo]:
"""Fetches an album by its MusicBrainz ID and returns an AlbumInfo
object or None if the album is not found. May raise a
MusicBrainzAPIError.
"""
- log.debug('Requesting MusicBrainz release {}', releaseid)
+ log.debug("Requesting MusicBrainz release {}", releaseid)
albumid = _parse_id(releaseid)
if not albumid:
- log.debug('Invalid MBID ({0}).', releaseid)
- return
+ log.debug("Invalid MBID ({0}).", releaseid)
+ return None
try:
- res = musicbrainzngs.get_release_by_id(albumid,
- RELEASE_INCLUDES)
+ res = musicbrainzngs.get_release_by_id(albumid, RELEASE_INCLUDES)
+
+ # resolve linked release relations
+ actual_res = None
+
+ if res["release"].get("status") == "Pseudo-Release":
+ actual_res = _find_actual_release_from_pseudo_release(res)
+
except musicbrainzngs.ResponseError:
- log.debug('Album ID match failed.')
+ log.debug("Album ID match failed.")
return None
except musicbrainzngs.MusicBrainzError as exc:
- raise MusicBrainzAPIError(exc, 'get release by ID', albumid,
- traceback.format_exc())
- return album_info(res['release'])
+ raise MusicBrainzAPIError(
+ exc, "get release by ID", albumid, traceback.format_exc()
+ )
+
+ # release is potentially a pseudo release
+ release = album_info(res["release"])
+
+ # should be None unless we're dealing with a pseudo release
+ if actual_res is not None:
+ actual_release = album_info(actual_res["release"])
+ return _merge_pseudo_and_actual_album(release, actual_release)
+ else:
+ return release
-def track_for_id(releaseid):
+def track_for_id(releaseid: str) -> Optional[beets.autotag.hooks.TrackInfo]:
"""Fetches a track by its MusicBrainz ID. Returns a TrackInfo object
or None if no track is found. May raise a MusicBrainzAPIError.
"""
trackid = _parse_id(releaseid)
if not trackid:
- log.debug('Invalid MBID ({0}).', releaseid)
- return
+ log.debug("Invalid MBID ({0}).", releaseid)
+ return None
try:
res = musicbrainzngs.get_recording_by_id(trackid, TRACK_INCLUDES)
except musicbrainzngs.ResponseError:
- log.debug('Track ID match failed.')
+ log.debug("Track ID match failed.")
return None
except musicbrainzngs.MusicBrainzError as exc:
- raise MusicBrainzAPIError(exc, 'get recording by ID', trackid,
- traceback.format_exc())
- return track_info(res['recording'])
+ raise MusicBrainzAPIError(
+ exc, "get recording by ID", trackid, traceback.format_exc()
+ )
+ return track_info(res["recording"])
diff --git a/lib/beets/config_default.yaml b/lib/beets/config_default.yaml
index 74540891..b28165c2 100644
--- a/lib/beets/config_default.yaml
+++ b/lib/beets/config_default.yaml
@@ -1,10 +1,34 @@
+# --------------- Main ---------------
+
library: library.db
directory: ~/Music
+statefile: state.pickle
+
+# --------------- Plugins ---------------
+
+plugins: []
+pluginpath: []
+
+# --------------- Import ---------------
+
+clutter: ["Thumbs.DB", ".DS_Store"]
+ignore: [".*", "*~", "System Volume Information", "lost+found"]
+ignore_hidden: yes
import:
+ # common options
write: yes
copy: yes
move: no
+ timid: no
+ quiet: no
+ log:
+ # other options
+ default_action: apply
+ languages: []
+ quiet_fallback: skip
+ none_rec_action: ask
+ # rare options
link: no
hardlink: no
reflink: no
@@ -13,76 +37,117 @@ import:
incremental: no
incremental_skip_later: no
from_scratch: no
- quiet_fallback: skip
- none_rec_action: ask
- timid: no
- log:
autotag: yes
- quiet: no
singletons: no
- default_action: apply
- languages: []
detail: no
flat: no
group_albums: no
pretend: false
search_ids: []
+ duplicate_keys:
+ album: albumartist album
+ item: artist title
duplicate_action: ask
+ duplicate_verbose_prompt: no
bell: no
set_fields: {}
+ ignored_alias_types: []
+ singleton_album_disambig: yes
-clutter: ["Thumbs.DB", ".DS_Store"]
-ignore: [".*", "*~", "System Volume Information", "lost+found"]
-ignore_hidden: yes
+# --------------- Paths ---------------
-replace:
- '[\\/]': _
- '^\.': _
- '[\x00-\x1f]': _
- '[<>:"\?\*\|]': _
- '\.$': _
- '\s+$': ''
- '^\s+': ''
- '^-': _
path_sep_replace: _
drive_sep_replace: _
asciify_paths: false
art_filename: cover
max_filename_length: 0
+replace:
+ # Replace bad characters with _
+ # prohibited in many filesystem paths
+ '[<>:\?\*\|]': _
+ # double quotation mark "
+ '\"': _
+ # path separators: \ or /
+ '[\\/]': _
+ # starting and closing periods
+ '^\.': _
+ '\.$': _
+ # control characters
+ '[\x00-\x1f]': _
+ # dash at the start of a filename (causes command line ambiguity)
+ '^-': _
+ # Replace bad characters with nothing
+ # starting and closing whitespace
+ '\s+$': ''
+ '^\s+': ''
aunique:
keys: albumartist album
disambiguators: albumtype year label catalognum albumdisambig releasegroupdisambig
bracket: '[]'
-overwrite_null:
- album: []
- track: []
+sunique:
+ keys: artist title
+ disambiguators: year trackdisambig
+ bracket: '[]'
+
+# --------------- Tagging ---------------
-plugins: []
-pluginpath: []
-threaded: yes
-timeout: 5.0
per_disc_numbering: no
-verbose: 0
-terminal_encoding:
original_date: no
artist_credit: no
id3v23: no
va_name: "Various Artists"
+paths:
+ default: $albumartist/$album%aunique{}/$track $title
+ singleton: Non-Album/$artist/$title
+ comp: Compilations/$album%aunique{}/$track $title
+
+# --------------- Performance ---------------
+
+threaded: yes
+timeout: 5.0
+
+# --------------- UI ---------------
+
+verbose: 0
+terminal_encoding:
ui:
terminal_width: 80
length_diff_thresh: 10.0
color: yes
colors:
- text_success: green
- text_warning: yellow
- text_error: red
- text_highlight: red
- text_highlight_minor: lightgray
- action_default: turquoise
- action: blue
+ text_success: ['bold', 'green']
+ text_warning: ['bold', 'yellow']
+ text_error: ['bold', 'red']
+ text_highlight: ['bold', 'red']
+ text_highlight_minor: ['white']
+ action_default: ['bold', 'cyan']
+ action: ['bold', 'cyan']
+ # New Colors
+ text: ['normal']
+ text_faint: ['faint']
+ import_path: ['bold', 'blue']
+ import_path_items: ['bold', 'blue']
+ added: ['green']
+ removed: ['red']
+ changed: ['yellow']
+ added_highlight: ['bold', 'green']
+ removed_highlight: ['bold', 'red']
+ changed_highlight: ['bold', 'yellow']
+ text_diff_added: ['bold', 'red']
+ text_diff_removed: ['bold', 'red']
+ text_diff_changed: ['bold', 'red']
+ action_description: ['white']
+ import:
+ indentation:
+ match_header: 2
+ match_details: 2
+ match_tracklist: 5
+ layout: column
+
+# --------------- Search ---------------
format_item: $artist - $album - $title
format_album: $albumartist - $album
@@ -93,14 +158,13 @@ sort_album: albumartist+ album+
sort_item: artist+ album+ disc+ track+
sort_case_insensitive: yes
-paths:
- default: $albumartist/$album%aunique{}/$track $title
- singleton: Non-Album/$artist/$title
- comp: Compilations/$album%aunique{}/$track $title
-
-statefile: state.pickle
+# --------------- Autotagger ---------------
+overwrite_null:
+ album: []
+ track: []
musicbrainz:
+ enabled: yes
host: musicbrainz.org
https: no
ratelimit: 1
@@ -108,6 +172,13 @@ musicbrainz:
searchlimit: 5
extra_tags: []
genres: no
+ external_ids:
+ discogs: no
+ bandcamp: no
+ spotify: no
+ deezer: no
+ beatport: no
+ tidal: no
match:
strong_rec_thresh: 0.04
@@ -147,3 +218,5 @@ match:
ignore_video_tracks: yes
track_length_grace: 10
track_length_max: 30
+ album_disambig_fields: data_source media year country label catalognum albumdisambig
+ singleton_disambig_fields: data_source index track_alt album
diff --git a/lib/beets/dbcore/__init__.py b/lib/beets/dbcore/__init__.py
index 923c34ca..06d0b3dc 100644
--- a/lib/beets/dbcore/__init__.py
+++ b/lib/beets/dbcore/__init__.py
@@ -16,12 +16,20 @@
Library.
"""
-from .db import Model, Database
-from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery
+from .db import Database, Model, Results
+from .query import (
+ AndQuery,
+ FieldQuery,
+ InvalidQueryError,
+ MatchQuery,
+ OrQuery,
+ Query,
+)
+from .queryparse import (
+ parse_sorted_query,
+ query_from_strings,
+ sort_from_strings,
+)
from .types import Type
-from .queryparse import query_from_strings
-from .queryparse import sort_from_strings
-from .queryparse import parse_sorted_query
-from .query import InvalidQueryError
# flake8: noqa
diff --git a/lib/beets/dbcore/db.py b/lib/beets/dbcore/db.py
index acd131be..566c1163 100755
--- a/lib/beets/dbcore/db.py
+++ b/lib/beets/dbcore/db.py
@@ -12,23 +12,56 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
-"""The central Model and Database constructs for DBCore.
-"""
+"""The central Model and Database constructs for DBCore."""
-import time
+from __future__ import annotations
+
+import contextlib
import os
import re
-from collections import defaultdict
-import threading
import sqlite3
-import contextlib
+import threading
+import time
+from abc import ABC
+from collections import defaultdict
+from sqlite3 import Connection
+from types import TracebackType
+from typing import (
+ Any,
+ AnyStr,
+ Callable,
+ DefaultDict,
+ Dict,
+ Generator,
+ Generic,
+ Iterable,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
+
+from unidecode import unidecode
import beets
-from beets.util import functemplate
-from beets.util import py3_path
-from beets.dbcore import types
-from .query import MatchQuery, NullSort, TrueQuery
-from collections.abc import Mapping
+
+from ..util import cached_classproperty, functemplate
+from . import types
+from .query import (
+ AndQuery,
+ FieldQuery,
+ MatchQuery,
+ NullSort,
+ Query,
+ Sort,
+ TrueQuery,
+)
class DBAccessError(Exception):
@@ -40,7 +73,7 @@ class DBAccessError(Exception):
"""
-class FormattedMapping(Mapping):
+class FormattedMapping(Mapping[str, str]):
"""A `dict`-like formatted view of a model.
The accessor `mapping[key]` returns the formatted version of
@@ -54,9 +87,14 @@ class FormattedMapping(Mapping):
are replaced.
"""
- ALL_KEYS = '*'
+ ALL_KEYS = "*"
- def __init__(self, model, included_keys=ALL_KEYS, for_path=False):
+ def __init__(
+ self,
+ model: Model,
+ included_keys: str = ALL_KEYS,
+ for_path: bool = False,
+ ):
self.for_path = for_path
self.model = model
if included_keys == self.ALL_KEYS:
@@ -65,34 +103,41 @@ class FormattedMapping(Mapping):
else:
self.model_keys = included_keys
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> str:
if key in self.model_keys:
return self._get_formatted(self.model, key)
else:
raise KeyError(key)
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
return iter(self.model_keys)
- def __len__(self):
+ def __len__(self) -> int:
return len(self.model_keys)
- def get(self, key, default=None):
+ # The following signature is incompatible with `Mapping[str, str]`, since
+ # the return type doesn't include `None` (but `default` can be `None`).
+ def get( # type: ignore
+ self,
+ key: str,
+ default: Optional[str] = None,
+ ) -> str:
+ """Similar to Mapping.get(key, default), but always formats to str."""
if default is None:
default = self.model._type(key).format(None)
return super().get(key, default)
- def _get_formatted(self, model, key):
+ def _get_formatted(self, model: Model, key: str) -> str:
value = model._type(key).format(model.get(key))
if isinstance(value, bytes):
- value = value.decode('utf-8', 'ignore')
+ value = value.decode("utf-8", "ignore")
if self.for_path:
- sep_repl = beets.config['path_sep_replace'].as_str()
- sep_drive = beets.config['drive_sep_replace'].as_str()
+ sep_repl = cast(str, beets.config["path_sep_replace"].as_str())
+ sep_drive = cast(str, beets.config["drive_sep_replace"].as_str())
- if re.match(r'^\w:', value):
- value = re.sub(r'(?<=^\w):', sep_drive, value)
+ if re.match(r"^\w:", value):
+ value = re.sub(r"(?<=^\w):", sep_drive, value)
for sep in (os.path.sep, os.path.altsep):
if sep:
@@ -101,80 +146,81 @@ class FormattedMapping(Mapping):
return value
+# NOTE: This seems like it should be a `Mapping`, i.e.
+# ```
+# class LazyConvertDict(Mapping[str, Any])
+# ```
+# but there are some conflicts with the `Mapping` protocol such that we
+# can't do this without changing behaviour: In particular, iterators returned
+# by some methods build intermediate lists, such that modification of the
+# `LazyConvertDict` becomes safe during iteration. Some code does in fact rely
+# on this.
class LazyConvertDict:
- """Lazily convert types for attributes fetched from the database
- """
+ """Lazily convert types for attributes fetched from the database"""
- def __init__(self, model_cls):
- """Initialize the object empty
- """
- self.data = {}
+ def __init__(self, model_cls: "Model"):
+ """Initialize the object empty"""
+ # FIXME: Dict[str, SQLiteType]
+ self._data: Dict[str, Any] = {}
self.model_cls = model_cls
- self._converted = {}
+ self._converted: Dict[str, Any] = {}
- def init(self, data):
- """Set the base data that should be lazily converted
- """
- self.data = data
+ def init(self, data: Dict[str, Any]):
+ """Set the base data that should be lazily converted"""
+ self._data = data
- def _convert(self, key, value):
- """Convert the attribute type according the the SQL type
- """
+ def _convert(self, key: str, value: Any):
+ """Convert the attribute type according to the SQL type"""
return self.model_cls._type(key).from_sql(value)
- def __setitem__(self, key, value):
- """Set an attribute value, assume it's already converted
- """
+ def __setitem__(self, key: str, value: Any):
+ """Set an attribute value, assume it's already converted"""
self._converted[key] = value
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> Any:
"""Get an attribute value, converting the type on demand
if needed
"""
if key in self._converted:
return self._converted[key]
- elif key in self.data:
- value = self._convert(key, self.data[key])
+ elif key in self._data:
+ value = self._convert(key, self._data[key])
self._converted[key] = value
return value
- def __delitem__(self, key):
- """Delete both converted and base data
- """
+ def __delitem__(self, key: str):
+ """Delete both converted and base data"""
if key in self._converted:
del self._converted[key]
- if key in self.data:
- del self.data[key]
+ if key in self._data:
+ del self._data[key]
- def keys(self):
- """Get a list of available field names for this object.
- """
- return list(self._converted.keys()) + list(self.data.keys())
+ def keys(self) -> List[str]:
+ """Get a list of available field names for this object."""
+ return list(self._converted.keys()) + list(self._data.keys())
- def copy(self):
- """Create a copy of the object.
- """
+ def copy(self) -> LazyConvertDict:
+ """Create a copy of the object."""
new = self.__class__(self.model_cls)
- new.data = self.data.copy()
+ new._data = self._data.copy()
new._converted = self._converted.copy()
return new
# Act like a dictionary.
- def update(self, values):
- """Assign all values in the given dict.
- """
+ def update(self, values: Mapping[str, Any]):
+ """Assign all values in the given dict."""
for key, value in values.items():
self[key] = value
- def items(self):
+ def items(self) -> Iterable[Tuple[str, Any]]:
"""Iterate over (key, value) pairs that this object contains.
Computed fields are not included.
"""
for key in self:
yield key, self[key]
- def get(self, key, default=None):
+ def get(self, key: str, default: Optional[Any] = None):
"""Get the value for a given key or `default` if it does not
exist.
"""
@@ -183,21 +229,30 @@ class LazyConvertDict:
else:
return default
- def __contains__(self, key):
- """Determine whether `key` is an attribute on this object.
- """
- return key in self.keys()
+ def __contains__(self, key: Any) -> bool:
+ """Determine whether `key` is an attribute on this object."""
+ return key in self._converted or key in self._data
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
"""Iterate over the available field names (excluding computed
fields).
"""
+ # NOTE: It would be nice to use the following:
+ # yield from self._converted
+ # yield from self._data
+ # but that won't work since some code relies on modifying `self`
+ # during iteration.
return iter(self.keys())
+ def __len__(self) -> int:
+ # FIXME: This is incorrect due to duplication of keys
+ return len(self._converted) + len(self._data)
+
# Abstract base for model classes.
-class Model:
+
+class Model(ABC):
"""An abstract object representing an object in the database. Model
objects act like dictionaries (i.e., they allow subscript access like
``obj['field']``). The same field set is available via attribute
@@ -223,34 +278,34 @@ class Model:
# Abstract components (to be provided by subclasses).
- _table = None
+ _table: str
"""The main SQLite table name.
"""
- _flex_table = None
+ _flex_table: str
"""The flex field SQLite table name.
"""
- _fields = {}
+ _fields: Dict[str, types.Type] = {}
"""A mapping indicating available "fixed" fields on this type. The
keys are field names and the values are `Type` objects.
"""
- _search_fields = ()
+ _search_fields: Sequence[str] = ()
"""The fields that should be queried by default by unqualified query
terms.
"""
- _types = {}
+ _types: Dict[str, types.Type] = {}
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
"""
- _sorts = {}
+ _sorts: Dict[str, Type[Sort]] = {}
"""Optional named sort criteria. The keys are strings and the values
are subclasses of `Sort`.
"""
- _queries = {}
+ _queries: Dict[str, Type[FieldQuery]] = {}
"""Named queries that use a field-like `name:value` syntax but which
do not relate to any specific field.
"""
@@ -266,15 +321,40 @@ class Model:
to the database.
"""
- @classmethod
- def _getters(cls):
- """Return a mapping from field names to getter functions.
+ @cached_classproperty
+ def _relation(cls) -> type[Model]:
+ """The model that this model is closely related to."""
+ return cls
+
+ @cached_classproperty
+ def relation_join(cls) -> str:
+ """Return the join required to include the related table in the query.
+
+ This is intended to be used as a FROM clause in the SQL query.
"""
+ return ""
+
+ @cached_classproperty
+ def all_db_fields(cls) -> set[str]:
+ return cls._fields.keys() | cls._relation._fields.keys()
+
+ @cached_classproperty
+ def shared_db_fields(cls) -> set[str]:
+ return cls._fields.keys() & cls._relation._fields.keys()
+
+ @cached_classproperty
+ def other_db_fields(cls) -> set[str]:
+ """Fields in the related table."""
+ return cls._relation._fields.keys() - cls.shared_db_fields
+
+ @classmethod
+ def _getters(cls: Type["Model"]):
+ """Return a mapping from field names to getter functions."""
# We could cache this if it becomes a performance problem to
# gather the getter mapping every time.
raise NotImplementedError()
- def _template_funcs(self):
+ def _template_funcs(self) -> Mapping[str, Callable[[str], str]]:
"""Return a mapping from function names to text-transformer
functions.
"""
@@ -283,12 +363,12 @@ class Model:
# Basic operation.
- def __init__(self, db=None, **values):
+ def __init__(self, db: Optional[Database] = None, **values):
"""Create a new object with an optional Database association and
initial field values.
"""
self._db = db
- self._dirty = set()
+ self._dirty: set[str] = set()
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)
@@ -297,7 +377,12 @@ class Model:
self.clear_dirty()
@classmethod
- def _awaken(cls, db=None, fixed_values={}, flex_values={}):
+ def _awaken(
+ cls: Type[AnyModel],
+ db: Optional[Database] = None,
+ fixed_values: Dict[str, Any] = {},
+ flex_values: Dict[str, Any] = {},
+ ) -> AnyModel:
"""Create an object with values drawn from the database.
This is a performance optimization: the checks involved with
@@ -310,10 +395,10 @@ class Model:
return obj
- def __repr__(self):
- return '{}({})'.format(
+ def __repr__(self) -> str:
+ return "{}({})".format(
type(self).__name__,
- ', '.join(f'{k}={v!r}' for k, v in dict(self).items()),
+ ", ".join(f"{k}={v!r}" for k, v in dict(self).items()),
)
def clear_dirty(self):
@@ -324,19 +409,19 @@ class Model:
if self._db:
self._revision = self._db.revision
- def _check_db(self, need_id=True):
+ def _check_db(self, need_id: bool = True) -> Database:
"""Ensure that this object is associated with a database row: it
has a reference to a database (`_db`) and an id. A ValueError
exception is raised otherwise.
"""
if not self._db:
- raise ValueError(
- '{} has no database'.format(type(self).__name__)
- )
+ raise ValueError("{} has no database".format(type(self).__name__))
if need_id and not self.id:
- raise ValueError('{} has no id'.format(type(self).__name__))
+ raise ValueError("{} has no id".format(type(self).__name__))
- def copy(self):
+ return self._db
+
+ def copy(self) -> "Model":
"""Create a copy of the model object.
The field values and other state is duplicated, but the new copy
@@ -354,7 +439,7 @@ class Model:
# Essential field accessors.
@classmethod
- def _type(cls, key):
+ def _type(cls, key) -> types.Type:
"""Get the type of a field, a `Type` instance.
If the field has no explicit type, it is given the base `Type`,
@@ -362,7 +447,7 @@ class Model:
"""
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
- def _get(self, key, default=None, raise_=False):
+ def _get(self, key, default: Any = None, raise_: bool = False):
"""Get the value for a field, or `default`. Alternatively,
raise a KeyError if the field is not available.
"""
@@ -412,24 +497,22 @@ class Model:
return changed
def __setitem__(self, key, value):
- """Assign the value for a field.
- """
+ """Assign the value for a field."""
self._setitem(key, value)
def __delitem__(self, key):
- """Remove a flexible attribute from the model.
- """
+ """Remove a flexible attribute from the model."""
if key in self._values_flex: # Flexible.
del self._values_flex[key]
self._dirty.add(key) # Mark for dropping on store.
elif key in self._fields: # Fixed
setattr(self, key, self._type(key).null)
elif key in self._getters(): # Computed.
- raise KeyError(f'computed field {key} cannot be deleted')
+ raise KeyError(f"computed field {key} cannot be deleted")
else:
- raise KeyError(f'no such field {key}')
+ raise KeyError(f"no such field {key}")
- def keys(self, computed=False):
+ def keys(self, computed: bool = False):
"""Get a list of available field names for this object. The
`computed` parameter controls whether computed (plugin-provided)
fields are included in the key list.
@@ -450,24 +533,22 @@ class Model:
# Act like a dictionary.
def update(self, values):
- """Assign all values in the given dict.
- """
+ """Assign all values in the given dict."""
for key, value in values.items():
self[key] = value
- def items(self):
+ def items(self) -> Iterator[Tuple[str, Any]]:
"""Iterate over (key, value) pairs that this object contains.
Computed fields are not included.
"""
for key in self:
yield key, self[key]
- def __contains__(self, key):
- """Determine whether `key` is an attribute on this object.
- """
+ def __contains__(self, key) -> bool:
+ """Determine whether `key` is an attribute on this object."""
return key in self.keys(computed=True)
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
"""Iterate over the available field names (excluding computed
fields).
"""
@@ -476,53 +557,52 @@ class Model:
# Convenient attribute access.
def __getattr__(self, key):
- if key.startswith('_'):
- raise AttributeError(f'model has no attribute {key!r}')
+ if key.startswith("_"):
+ raise AttributeError(f"model has no attribute {key!r}")
else:
try:
return self[key]
except KeyError:
- raise AttributeError(f'no such field {key!r}')
+ raise AttributeError(f"no such field {key!r}")
def __setattr__(self, key, value):
- if key.startswith('_'):
+ if key.startswith("_"):
super().__setattr__(key, value)
else:
self[key] = value
def __delattr__(self, key):
- if key.startswith('_'):
+ if key.startswith("_"):
super().__delattr__(key)
else:
del self[key]
# Database interaction (CRUD methods).
- def store(self, fields=None):
+ def store(self, fields: Optional[Iterable[str]] = None):
"""Save the object's metadata into the library database.
:param fields: the fields to be stored. If not specified, all fields
will be.
"""
if fields is None:
fields = self._fields
- self._check_db()
+ db = self._check_db()
# Build assignments for query.
assignments = []
subvars = []
for key in fields:
- if key != 'id' and key in self._dirty:
+ if key != "id" and key in self._dirty:
self._dirty.remove(key)
- assignments.append(key + '=?')
+ assignments.append(key + "=?")
value = self._type(key).to_sql(self[key])
subvars.append(value)
- assignments = ','.join(assignments)
- with self._db.transaction() as tx:
+ with db.transaction() as tx:
# Main table update.
if assignments:
- query = 'UPDATE {} SET {} WHERE id=?'.format(
- self._table, assignments
+ query = "UPDATE {} SET {} WHERE id=?".format(
+ self._table, ",".join(assignments)
)
subvars.append(self.id)
tx.mutate(query, subvars)
@@ -532,18 +612,17 @@ class Model:
if key in self._dirty:
self._dirty.remove(key)
tx.mutate(
- 'INSERT INTO {} '
- '(entity_id, key, value) '
- 'VALUES (?, ?, ?);'.format(self._flex_table),
+ "INSERT INTO {} "
+ "(entity_id, key, value) "
+ "VALUES (?, ?, ?);".format(self._flex_table),
(self.id, key, value),
)
# Deleted flexible attributes.
for key in self._dirty:
tx.mutate(
- 'DELETE FROM {} '
- 'WHERE entity_id=? AND key=?'.format(self._flex_table),
- (self.id, key)
+ f"DELETE FROM {self._flex_table} WHERE entity_id=? AND key=?",
+ (self.id, key),
)
self.clear_dirty()
@@ -554,11 +633,11 @@ class Model:
If check_revision is true, the database is only queried loaded when a
transaction has been committed since the item was last loaded.
"""
- self._check_db()
- if not self._dirty and self._db.revision == self._revision:
+ db = self._check_db()
+ if not self._dirty and db.revision == self._revision:
# Exit early
return
- stored_obj = self._db._get(type(self), self.id)
+ stored_obj = db._get(type(self), self.id)
assert stored_obj is not None, f"object {self.id} not in DB"
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)
@@ -566,20 +645,15 @@ class Model:
self.clear_dirty()
def remove(self):
- """Remove the object's associated rows from the database.
- """
- self._check_db()
- with self._db.transaction() as tx:
+ """Remove the object's associated rows from the database."""
+ db = self._check_db()
+ with db.transaction() as tx:
+ tx.mutate(f"DELETE FROM {self._table} WHERE id=?", (self.id,))
tx.mutate(
- f'DELETE FROM {self._table} WHERE id=?',
- (self.id,)
- )
- tx.mutate(
- f'DELETE FROM {self._flex_table} WHERE entity_id=?',
- (self.id,)
+ f"DELETE FROM {self._flex_table} WHERE entity_id=?", (self.id,)
)
- def add(self, db=None):
+ def add(self, db: Optional["Database"] = None):
"""Add the object to the library database. This object must be
associated with a database; you can provide one via the `db`
parameter or use the currently associated database.
@@ -589,12 +663,10 @@ class Model:
"""
if db:
self._db = db
- self._check_db(False)
+ db = self._check_db(False)
- with self._db.transaction() as tx:
- new_id = tx.mutate(
- f'INSERT INTO {self._table} DEFAULT VALUES'
- )
+ with db.transaction() as tx:
+ new_id = tx.mutate(f"INSERT INTO {self._table} DEFAULT VALUES")
self.id = new_id
self.added = time.time()
@@ -608,53 +680,101 @@ class Model:
_formatter = FormattedMapping
- def formatted(self, included_keys=_formatter.ALL_KEYS, for_path=False):
+ def formatted(
+ self,
+ included_keys: str = _formatter.ALL_KEYS,
+ for_path: bool = False,
+ ):
"""Get a mapping containing all values on this object formatted
as human-readable unicode strings.
"""
return self._formatter(self, included_keys, for_path)
- def evaluate_template(self, template, for_path=False):
+ def evaluate_template(
+ self,
+ template: Union[str, functemplate.Template],
+ for_path: bool = False,
+ ) -> str:
"""Evaluate a template (a string or a `Template` object) using
the object's fields. If `for_path` is true, then no new path
separators will be added to the template.
"""
# Perform substitution.
if isinstance(template, str):
- template = functemplate.template(template)
- return template.substitute(self.formatted(for_path=for_path),
- self._template_funcs())
+ t = functemplate.template(template)
+ else:
+ # Help out mypy
+ t = template
+ return t.substitute(
+ self.formatted(for_path=for_path), self._template_funcs()
+ )
# Parsing.
@classmethod
- def _parse(cls, key, string):
- """Parse a string as a value for the given key.
- """
+ def _parse(cls, key, string: str) -> Any:
+ """Parse a string as a value for the given key."""
if not isinstance(string, str):
raise TypeError("_parse() argument must be a string")
return cls._type(key).parse(string)
- def set_parse(self, key, string):
- """Set the object's key to a value represented by a string.
- """
+ def set_parse(self, key, string: str):
+ """Set the object's key to a value represented by a string."""
self[key] = self._parse(key, string)
+ # Convenient queries.
+
+ @classmethod
+ def field_query(
+ cls,
+ field,
+ pattern,
+ query_cls: Type[FieldQuery] = MatchQuery,
+ ) -> FieldQuery:
+ """Get a `FieldQuery` for this model."""
+ return query_cls(field, pattern, field in cls._fields)
+
+ @classmethod
+ def all_fields_query(
+ cls: Type["Model"],
+ pats: Mapping,
+ query_cls: Type[FieldQuery] = MatchQuery,
+ ):
+ """Get a query that matches many fields with different patterns.
+
+ `pats` should be a mapping from field names to patterns. The
+ resulting query is a conjunction ("and") of per-field queries
+ for all of these field/pattern pairs.
+ """
+ subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()]
+ return AndQuery(subqueries)
+
# Database controller and supporting interfaces.
-class Results:
+
+AnyModel = TypeVar("AnyModel", bound=Model)
+
+
+class Results(Generic[AnyModel]):
"""An item query result set. Iterating over the collection lazily
- constructs LibModel objects that reflect database rows.
+ constructs Model objects that reflect database rows.
"""
- def __init__(self, model_class, rows, db, flex_rows,
- query=None, sort=None):
+ def __init__(
+ self,
+ model_class: Type[AnyModel],
+ rows: List[Mapping],
+ db: "Database",
+ flex_rows,
+ query: Optional[Query] = None,
+ sort=None,
+ ):
"""Create a result set that will construct objects of type
`model_class`.
- `model_class` is a subclass of `LibModel` that will be
+ `model_class` is a subclass of `Model` that will be
constructed. `rows` is a query result: a list of mappings. The
new objects will be associated with the database `db`.
@@ -680,9 +800,9 @@ class Results:
# The materialized objects corresponding to rows that have been
# consumed.
- self._objects = []
+ self._objects: List[AnyModel] = []
- def _get_objects(self):
+ def _get_objects(self) -> Iterator[AnyModel]:
"""Construct and generate Model objects for they query. The
objects are returned in the order emitted from the database; no
slow sort is applied.
@@ -708,7 +828,7 @@ class Results:
else:
while self._rows:
row = self._rows.pop(0)
- obj = self._make_model(row, flex_attrs.get(row['id'], {}))
+ obj = self._make_model(row, flex_attrs.get(row["id"], {}))
# If there is a slow-query predicate, ensurer that the
# object passes it.
if not self.query or self.query.match(obj):
@@ -717,7 +837,7 @@ class Results:
yield obj
break
- def __iter__(self):
+ def __iter__(self) -> Iterator[AnyModel]:
"""Construct and generate Model objects for all matching
objects, in sorted order.
"""
@@ -730,32 +850,28 @@ class Results:
# Objects are pre-sorted (i.e., by the database).
return self._get_objects()
- def _get_indexed_flex_attrs(self):
- """ Index flexible attributes by the entity id they belong to
- """
- flex_values = {}
+ def _get_indexed_flex_attrs(self) -> Mapping:
+ """Index flexible attributes by the entity id they belong to"""
+ flex_values: Dict[int, Dict[str, Any]] = {}
for row in self.flex_rows:
- if row['entity_id'] not in flex_values:
- flex_values[row['entity_id']] = {}
+ if row["entity_id"] not in flex_values:
+ flex_values[row["entity_id"]] = {}
- flex_values[row['entity_id']][row['key']] = row['value']
+ flex_values[row["entity_id"]][row["key"]] = row["value"]
return flex_values
- def _make_model(self, row, flex_values={}):
- """ Create a Model object for the given row
- """
+ def _make_model(self, row, flex_values: Dict = {}) -> AnyModel:
+ """Create a Model object for the given row"""
cols = dict(row)
- values = {k: v for (k, v) in cols.items()
- if not k[:4] == 'flex'}
+ values = {k: v for (k, v) in cols.items() if not k[:4] == "flex"}
# Construct the Python object
obj = self.model_class._awaken(self.db, values, flex_values)
return obj
- def __len__(self):
- """Get the number of matching objects.
- """
+ def __len__(self) -> int:
+ """Get the number of matching objects."""
if not self._rows:
# Fully materialized. Just count the objects.
return len(self._objects)
@@ -771,14 +887,12 @@ class Results:
# A fast query. Just count the rows.
return self._row_count
- def __nonzero__(self):
- """Does this result contain any objects?
- """
+ def __nonzero__(self) -> bool:
+ """Does this result contain any objects?"""
return self.__bool__()
- def __bool__(self):
- """Does this result contain any objects?
- """
+ def __bool__(self) -> bool:
+ """Does this result contain any objects?"""
return bool(len(self))
def __getitem__(self, n):
@@ -796,9 +910,9 @@ class Results:
next(it)
return next(it)
except StopIteration:
- raise IndexError(f'result index {n} out of range')
+ raise IndexError(f"result index {n} out of range")
- def get(self):
+ def get(self) -> Optional[AnyModel]:
"""Return the first matching object, or None if no objects
match.
"""
@@ -819,10 +933,10 @@ class Transaction:
current transaction.
"""
- def __init__(self, db):
+ def __init__(self, db: "Database"):
self.db = db
- def __enter__(self):
+ def __enter__(self) -> "Transaction":
"""Begin a transaction. This transaction may be created while
another is active in a different thread.
"""
@@ -835,7 +949,12 @@ class Transaction:
self.db._db_lock.acquire()
return self
- def __exit__(self, exc_type, exc_value, traceback):
+ def __exit__(
+ self,
+ exc_type: Type[Exception],
+ exc_value: Exception,
+ traceback: TracebackType,
+ ):
"""Complete a transaction. This must be the most recently
entered but not yet exited transaction. If it is the last active
transaction, the database updates are committed.
@@ -851,14 +970,14 @@ class Transaction:
self._mutated = False
self.db._db_lock.release()
- def query(self, statement, subvals=()):
+ def query(self, statement: str, subvals: Sequence = ()) -> List:
"""Execute an SQL statement with substitution values and return
a list of rows from the database.
"""
cursor = self.db._connection().execute(statement, subvals)
return cursor.fetchall()
- def mutate(self, statement, subvals=()):
+ def mutate(self, statement: str, subvals: Sequence = ()) -> Any:
"""Execute an SQL statement with substitution values and return
the row ID of the last affected row.
"""
@@ -868,8 +987,10 @@ class Transaction:
# In two specific cases, SQLite reports an error while accessing
# the underlying database file. We surface these exceptions as
# DBAccessError so the application can abort.
- if e.args[0] in ("attempt to write a readonly database",
- "unable to open database file"):
+ if e.args[0] in (
+ "attempt to write a readonly database",
+ "unable to open database file",
+ ):
raise DBAccessError(e.args[0])
else:
raise
@@ -877,7 +998,7 @@ class Transaction:
self._mutated = True
return cursor.lastrowid
- def script(self, statements):
+ def script(self, statements: str):
"""Execute a string containing multiple SQL statements."""
# We don't know whether this mutates, but quite likely it does.
self._mutated = True
@@ -889,11 +1010,11 @@ class Database:
the backend.
"""
- _models = ()
+ _models: Sequence[Type[Model]] = ()
"""The Model subclasses representing tables in this database.
"""
- supports_extensions = hasattr(sqlite3.Connection, 'enable_load_extension')
+ supports_extensions = hasattr(sqlite3.Connection, "enable_load_extension")
"""Whether or not the current version of SQLite supports extensions"""
revision = 0
@@ -901,13 +1022,18 @@ class Database:
data is written in a transaction.
"""
- def __init__(self, path, timeout=5.0):
+ def __init__(self, path, timeout: float = 5.0):
+ if sqlite3.threadsafety == 0:
+ raise RuntimeError(
+ "sqlite3 must be compiled with multi-threading support"
+ )
+
self.path = path
self.timeout = timeout
- self._connections = {}
- self._tx_stacks = defaultdict(list)
- self._extensions = []
+ self._connections: Dict[int, sqlite3.Connection] = {}
+ self._tx_stacks: DefaultDict[int, List[Transaction]] = defaultdict(list)
+ self._extensions: List[str] = []
# A lock to protect the _connections and _tx_stacks maps, which
# both map thread IDs to private resources.
@@ -930,11 +1056,16 @@ class Database:
# Primitive access control: connections and transactions.
- def _connection(self):
+ def _connection(self) -> Connection:
"""Get a SQLite connection object to the underlying database.
One connection object is created per thread.
"""
thread_id = threading.current_thread().ident
+ # Help the type checker: ident can only be None if the thread has not
+ # been started yet; but since this results from current_thread(), that
+ # can't happen
+ assert thread_id is not None
+
with self._shared_map_lock:
if thread_id in self._connections:
return self._connections[thread_id]
@@ -943,7 +1074,7 @@ class Database:
self._connections[thread_id] = conn
return conn
- def _create_connection(self):
+ def _create_connection(self) -> Connection:
"""Create a SQLite connection to the underlying database.
Makes a new connection every time. If you need to configure the
@@ -952,10 +1083,15 @@ class Database:
"""
# Make a new connection. The `sqlite3` module can't use
# bytestring paths here on Python 3, so we need to
- # provide a `str` using `py3_path`.
+ # provide a `str` using `os.fsdecode`.
conn = sqlite3.connect(
- py3_path(self.path), timeout=self.timeout
+ os.fsdecode(self.path),
+ timeout=self.timeout,
+ # We have our own same-thread checks in _connection(), but need to
+ # call conn.close() in _close()
+ check_same_thread=False,
)
+ self.add_functions(conn)
if self.supports_extensions:
conn.enable_load_extension(True)
@@ -968,35 +1104,66 @@ class Database:
conn.row_factory = sqlite3.Row
return conn
+ def add_functions(self, conn):
+ def regexp(value, pattern):
+ if isinstance(value, bytes):
+ value = value.decode()
+ return re.search(pattern, str(value)) is not None
+
+ def bytelower(bytestring: Optional[AnyStr]) -> Optional[AnyStr]:
+ """A custom ``bytelower`` sqlite function so we can compare
+ bytestrings in a semi case insensitive fashion.
+
+ This is to work around sqlite builds are that compiled with
+ ``-DSQLITE_LIKE_DOESNT_MATCH_BLOBS``. See
+ ``https://github.com/beetbox/beets/issues/2172`` for details.
+ """
+ if bytestring is not None:
+ return bytestring.lower()
+
+ return bytestring
+
+ conn.create_function("regexp", 2, regexp)
+ conn.create_function("unidecode", 1, unidecode)
+ conn.create_function("bytelower", 1, bytelower)
+
def _close(self):
"""Close the all connections to the underlying SQLite database
from all threads. This does not render the database object
unusable; new connections can still be opened on demand.
"""
with self._shared_map_lock:
- self._connections.clear()
+ while self._connections:
+ _thread_id, conn = self._connections.popitem()
+ conn.close()
@contextlib.contextmanager
- def _tx_stack(self):
+ def _tx_stack(self) -> Generator[List, None, None]:
"""A context manager providing access to the current thread's
transaction stack. The context manager synchronizes access to
the stack map. Transactions should never migrate across threads.
"""
thread_id = threading.current_thread().ident
+ # Help the type checker: ident can only be None if the thread has not
+ # been started yet; but since this results from current_thread(), that
+ # can't happen
+ assert thread_id is not None
+
with self._shared_map_lock:
yield self._tx_stacks[thread_id]
- def transaction(self):
+ def transaction(self) -> Transaction:
"""Get a :class:`Transaction` object for interacting directly
with the underlying SQLite database.
"""
return Transaction(self)
- def load_extension(self, path):
+ def load_extension(self, path: str):
"""Load an SQLite extension into all open connections."""
if not self.supports_extensions:
raise ValueError(
- 'this sqlite3 installation does not support extensions')
+ "this sqlite3 installation does not support extensions"
+ )
self._extensions.append(path)
@@ -1006,13 +1173,13 @@ class Database:
# Schema setup and migration.
- def _make_table(self, table, fields):
+ def _make_table(self, table: str, fields: Mapping[str, types.Type]):
"""Set up the schema of the database. `fields` is a mapping
from field names to `Type`s. Columns are added if necessary.
"""
# Get current schema.
with self.transaction() as tx:
- rows = tx.query('PRAGMA table_info(%s)' % table)
+ rows = tx.query("PRAGMA table_info(%s)" % table)
current_fields = {row[1] for row in rows}
field_names = set(fields.keys())
@@ -1024,29 +1191,31 @@ class Database:
# No table exists.
columns = []
for name, typ in fields.items():
- columns.append(f'{name} {typ.sql}')
- setup_sql = 'CREATE TABLE {} ({});\n'.format(table,
- ', '.join(columns))
+ columns.append(f"{name} {typ.sql}")
+ setup_sql = "CREATE TABLE {} ({});\n".format(
+ table, ", ".join(columns)
+ )
else:
# Table exists does not match the field set.
- setup_sql = ''
+ setup_sql = ""
for name, typ in fields.items():
if name in current_fields:
continue
- setup_sql += 'ALTER TABLE {} ADD COLUMN {} {};\n'.format(
+ setup_sql += "ALTER TABLE {} ADD COLUMN {} {};\n".format(
table, name, typ.sql
)
with self.transaction() as tx:
tx.script(setup_sql)
- def _make_attribute_table(self, flex_table):
+ def _make_attribute_table(self, flex_table: str):
"""Create a table and associated index for flexible attributes
for the given entity (if they don't exist).
"""
with self.transaction() as tx:
- tx.script("""
+ tx.script(
+ """
CREATE TABLE IF NOT EXISTS {0} (
id INTEGER PRIMARY KEY,
entity_id INTEGER,
@@ -1055,11 +1224,19 @@ class Database:
UNIQUE(entity_id, key) ON CONFLICT REPLACE);
CREATE INDEX IF NOT EXISTS {0}_by_entity
ON {0} (entity_id);
- """.format(flex_table))
+ """.format(
+ flex_table
+ )
+ )
# Querying.
- def _fetch(self, model_cls, query=None, sort=None):
+ def _fetch(
+ self,
+ model_cls: Type[AnyModel],
+ query: Optional[Query] = None,
+ sort: Optional[Sort] = None,
+ ) -> Results[AnyModel]:
"""Fetch the objects of type `model_cls` matching the given
query. The query may be given as a string, string sequence, a
Query object, or None (to fetch everything). `sort` is an
@@ -1070,37 +1247,54 @@ class Database:
where, subvals = query.clause()
order_by = sort.order_clause()
- sql = ("SELECT * FROM {} WHERE {} {}").format(
- model_cls._table,
- where or '1',
- f"ORDER BY {order_by}" if order_by else '',
- )
+ table = model_cls._table
+ _from = table
+ if query.field_names & model_cls.other_db_fields:
+ _from += f" {model_cls.relation_join}"
+ # group by id to avoid duplicates when joining with the relation
+ sql = (
+ f"SELECT {table}.* "
+ f"FROM ({_from}) "
+ f"WHERE {where or 1} "
+ f"GROUP BY {table}.id"
+ )
# Fetch flexible attributes for items matching the main query.
# Doing the per-item filtering in python is faster than issuing
# one query per item to sqlite.
- flex_sql = ("""
- SELECT * FROM {} WHERE entity_id IN
- (SELECT id FROM {} WHERE {});
- """.format(
- model_cls._flex_table,
- model_cls._table,
- where or '1',
- )
+ flex_sql = (
+ "SELECT * "
+ f"FROM {model_cls._flex_table} "
+ f"WHERE entity_id IN (SELECT id FROM ({sql}))"
)
+ if order_by:
+ # the sort field may exist in both 'items' and 'albums' tables
+ # (when they are joined), causing ambiguous column OperationalError
+ # if we try to order directly.
+ # Since the join is required only for filtering, we can filter in
+ # a subquery and order the result, which returns unique fields.
+ sql = f"SELECT * FROM ({sql}) ORDER BY {order_by}"
+
with self.transaction() as tx:
rows = tx.query(sql, subvals)
flex_rows = tx.query(flex_sql, subvals)
return Results(
- model_cls, rows, self, flex_rows,
+ model_cls,
+ rows,
+ self,
+ flex_rows,
None if where else query, # Slow query component.
sort if sort.is_slow() else None, # Slow sort component.
)
- def _get(self, model_cls, id):
+ def _get(
+ self,
+ model_cls: Type[AnyModel],
+ id,
+ ) -> Optional[AnyModel]:
"""Get a Model object by its id or None if the id does not
exist.
"""
- return self._fetch(model_cls, MatchQuery('id', id)).get()
+ return self._fetch(model_cls, MatchQuery("id", id)).get()
diff --git a/lib/beets/dbcore/query.py b/lib/beets/dbcore/query.py
index 96476a5b..f8cf7fe4 100644
--- a/lib/beets/dbcore/query.py
+++ b/lib/beets/dbcore/query.py
@@ -12,19 +12,42 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
-"""The Query type hierarchy for DBCore.
-"""
+"""The Query type hierarchy for DBCore."""
+
+from __future__ import annotations
import re
-from operator import mul
-from beets import util
-from datetime import datetime, timedelta
import unicodedata
+from abc import ABC, abstractmethod
+from datetime import datetime, timedelta
from functools import reduce
+from operator import mul, or_
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Generic,
+ Iterator,
+ List,
+ MutableSequence,
+ Optional,
+ Pattern,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+from beets import util
+
+if TYPE_CHECKING:
+ from beets.dbcore import Model
class ParsingError(ValueError):
- """Abstract class for any unparseable user-requested album/query
+ """Abstract class for any unparsable user-requested album/query
specification.
"""
@@ -56,36 +79,54 @@ class InvalidQueryArgumentValueError(ParsingError):
super().__init__(message)
-class Query:
- """An abstract class representing a query into the item database.
- """
+class Query(ABC):
+ """An abstract class representing a query into the database."""
- def clause(self):
+ @property
+ def field_names(self) -> Set[str]:
+ """Return a set with field names that this query operates on."""
+ return set()
+
+ def clause(self) -> Tuple[Optional[str], Sequence[Any]]:
"""Generate an SQLite expression implementing the query.
Return (clause, subvals) where clause is a valid sqlite
WHERE clause implementing the query and subvals is a list of
items to be substituted for ?s in the clause.
+
+ The default implementation returns None, falling back to a slow query
+ using `match()`.
"""
return None, ()
- def match(self, item):
- """Check whether this query matches a given Item. Can be used to
- perform queries on arbitrary sets of Items.
+ @abstractmethod
+ def match(self, obj: Model):
+ """Check whether this query matches a given Model. Can be used to
+ perform queries on arbitrary sets of Model.
"""
- raise NotImplementedError
+ ...
- def __repr__(self):
+ def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
- def __eq__(self, other):
- return type(self) == type(other)
+ def __eq__(self, other) -> bool:
+ return type(self) is type(other)
- def __hash__(self):
- return 0
+ def __hash__(self) -> int:
+ """Minimalistic default implementation of a hash.
+
+ Given the implementation if __eq__ above, this is
+ certainly correct.
+ """
+ return hash(type(self))
-class FieldQuery(Query):
+P = TypeVar("P")
+SQLiteType = Union[str, bytes, float, int, memoryview]
+AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType)
+
+
+class FieldQuery(Query, Generic[P]):
"""An abstract query that searches in a specific field for a
pattern. Subclasses must provide a `value_match` class method, which
determines whether a certain pattern string matches a certain value
@@ -93,15 +134,26 @@ class FieldQuery(Query):
same matching functionality in SQLite.
"""
- def __init__(self, field, pattern, fast=True):
- self.field = field
+ @property
+ def field(self) -> str:
+ return (
+ f"{self.table}.{self.field_name}" if self.table else self.field_name
+ )
+
+ @property
+ def field_names(self) -> Set[str]:
+ """Return a set with field names that this query operates on."""
+ return {self.field_name}
+
+ def __init__(self, field_name: str, pattern: P, fast: bool = True):
+ self.table, _, self.field_name = field_name.rpartition(".")
self.pattern = pattern
self.fast = fast
- def col_clause(self):
- return None, ()
+ def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
+ return self.field, ()
- def clause(self):
+ def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
if self.fast:
return self.col_clause()
else:
@@ -109,161 +161,205 @@ class FieldQuery(Query):
return None, ()
@classmethod
- def value_match(cls, pattern, value):
- """Determine whether the value matches the pattern. Both
- arguments are strings.
- """
+ def value_match(cls, pattern: P, value: Any):
+ """Determine whether the value matches the pattern."""
raise NotImplementedError()
- def match(self, item):
- return self.value_match(self.pattern, item.get(self.field))
+ def match(self, obj: Model) -> bool:
+ return self.value_match(self.pattern, obj.get(self.field_name))
- def __repr__(self):
- return ("{0.__class__.__name__}({0.field!r}, {0.pattern!r}, "
- "{0.fast})".format(self))
+ def __repr__(self) -> str:
+ return (
+ f"{self.__class__.__name__}({self.field_name!r}, {self.pattern!r}, "
+ f"fast={self.fast})"
+ )
- def __eq__(self, other):
- return super().__eq__(other) and \
- self.field == other.field and self.pattern == other.pattern
+ def __eq__(self, other) -> bool:
+ return (
+ super().__eq__(other)
+ and self.field_name == other.field_name
+ and self.pattern == other.pattern
+ )
- def __hash__(self):
- return hash((self.field, hash(self.pattern)))
+ def __hash__(self) -> int:
+ return hash((self.field_name, hash(self.pattern)))
-class MatchQuery(FieldQuery):
- """A query that looks for exact matches in an item field."""
+class MatchQuery(FieldQuery[AnySQLiteType]):
+ """A query that looks for exact matches in an Model field."""
- def col_clause(self):
+ def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field + " = ?", [self.pattern]
@classmethod
- def value_match(cls, pattern, value):
+ def value_match(cls, pattern: AnySQLiteType, value: Any) -> bool:
return pattern == value
-class NoneQuery(FieldQuery):
+class NoneQuery(FieldQuery[None]):
"""A query that checks whether a field is null."""
- def __init__(self, field, fast=True):
+ def __init__(self, field, fast: bool = True):
super().__init__(field, None, fast)
- def col_clause(self):
+ def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field + " IS NULL", ()
- def match(self, item):
- return item.get(self.field) is None
+ def match(self, obj: Model) -> bool:
+ return obj.get(self.field_name) is None
- def __repr__(self):
- return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self)
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self.field_name!r}, {self.fast})"
-class StringFieldQuery(FieldQuery):
+class StringFieldQuery(FieldQuery[P]):
"""A FieldQuery that converts values to strings before matching
them.
"""
@classmethod
- def value_match(cls, pattern, value):
+ def value_match(cls, pattern: P, value: Any):
"""Determine whether the value matches the pattern. The value
may have any type.
"""
return cls.string_match(pattern, util.as_string(value))
@classmethod
- def string_match(cls, pattern, value):
+ def string_match(
+ cls,
+ pattern: P,
+ value: str,
+ ) -> bool:
"""Determine whether the value matches the pattern. Both
arguments are strings. Subclasses implement this method.
"""
raise NotImplementedError()
-class SubstringQuery(StringFieldQuery):
- """A query that matches a substring in a specific item field."""
+class StringQuery(StringFieldQuery[str]):
+ """A query that matches a whole string in a specific Model field."""
- def col_clause(self):
- pattern = (self.pattern
- .replace('\\', '\\\\')
- .replace('%', '\\%')
- .replace('_', '\\_'))
- search = '%' + pattern + '%'
+ def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
+ search = (
+ self.pattern.replace("\\", "\\\\")
+ .replace("%", "\\%")
+ .replace("_", "\\_")
+ )
clause = self.field + " like ? escape '\\'"
subvals = [search]
return clause, subvals
@classmethod
- def string_match(cls, pattern, value):
+ def string_match(cls, pattern: str, value: str) -> bool:
+ return pattern.lower() == value.lower()
+
+
+class SubstringQuery(StringFieldQuery[str]):
+ """A query that matches a substring in a specific Model field."""
+
+ def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
+ pattern = (
+ self.pattern.replace("\\", "\\\\")
+ .replace("%", "\\%")
+ .replace("_", "\\_")
+ )
+ search = "%" + pattern + "%"
+ clause = self.field + " like ? escape '\\'"
+ subvals = [search]
+ return clause, subvals
+
+ @classmethod
+ def string_match(cls, pattern: str, value: str) -> bool:
return pattern.lower() in value.lower()
-class RegexpQuery(StringFieldQuery):
- """A query that matches a regular expression in a specific item
- field.
+class RegexpQuery(StringFieldQuery[Pattern[str]]):
+ """A query that matches a regular expression in a specific Model field.
Raises InvalidQueryError when the pattern is not a valid regular
expression.
"""
- def __init__(self, field, pattern, fast=True):
- super().__init__(field, pattern, fast)
+ def __init__(self, field_name: str, pattern: str, fast: bool = True):
pattern = self._normalize(pattern)
try:
- self.pattern = re.compile(self.pattern)
+ pattern_re = re.compile(pattern)
except re.error as exc:
# Invalid regular expression.
- raise InvalidQueryArgumentValueError(pattern,
- "a regular expression",
- format(exc))
+ raise InvalidQueryArgumentValueError(
+ pattern, "a regular expression", format(exc)
+ )
+
+ super().__init__(field_name, pattern_re, fast)
+
+ def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
+ return f" regexp({self.field}, ?)", [self.pattern.pattern]
@staticmethod
- def _normalize(s):
+ def _normalize(s: str) -> str:
"""Normalize a Unicode string's representation (used on both
patterns and matched values).
"""
- return unicodedata.normalize('NFC', s)
+ return unicodedata.normalize("NFC", s)
@classmethod
- def string_match(cls, pattern, value):
+ def string_match(cls, pattern: Pattern, value: str) -> bool:
return pattern.search(cls._normalize(value)) is not None
-class BooleanQuery(MatchQuery):
+class BooleanQuery(MatchQuery[int]):
"""Matches a boolean field. Pattern should either be a boolean or a
string reflecting a boolean.
"""
- def __init__(self, field, pattern, fast=True):
- super().__init__(field, pattern, fast)
+ def __init__(
+ self,
+ field_name: str,
+ pattern: bool,
+ fast: bool = True,
+ ):
if isinstance(pattern, str):
- self.pattern = util.str2bool(pattern)
- self.pattern = int(self.pattern)
+ pattern = util.str2bool(pattern)
+
+ pattern_int = int(pattern)
+
+ super().__init__(field_name, pattern_int, fast)
-class BytesQuery(MatchQuery):
+class BytesQuery(FieldQuery[bytes]):
"""Match a raw bytes field (i.e., a path). This is a necessary hack
to work around the `sqlite3` module's desire to treat `bytes` and
`unicode` equivalently in Python 2. Always use this query instead of
`MatchQuery` when matching on BLOB values.
"""
- def __init__(self, field, pattern):
- super().__init__(field, pattern)
-
+ def __init__(self, field_name: str, pattern: Union[bytes, str, memoryview]):
# Use a buffer/memoryview representation of the pattern for SQLite
# matching. This instructs SQLite to treat the blob as binary
# rather than encoded Unicode.
- if isinstance(self.pattern, (str, bytes)):
- if isinstance(self.pattern, str):
- self.pattern = self.pattern.encode('utf-8')
- self.buf_pattern = memoryview(self.pattern)
- elif isinstance(self.pattern, memoryview):
- self.buf_pattern = self.pattern
- self.pattern = bytes(self.pattern)
+ if isinstance(pattern, (str, bytes)):
+ if isinstance(pattern, str):
+ bytes_pattern = pattern.encode("utf-8")
+ else:
+ bytes_pattern = pattern
+ self.buf_pattern = memoryview(bytes_pattern)
+ elif isinstance(pattern, memoryview):
+ self.buf_pattern = pattern
+ bytes_pattern = bytes(pattern)
+ else:
+ raise ValueError("pattern must be bytes, str, or memoryview")
- def col_clause(self):
+ super().__init__(field_name, bytes_pattern)
+
+ def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field + " = ?", [self.buf_pattern]
+ @classmethod
+ def value_match(cls, pattern: bytes, value: Any) -> bool:
+ return pattern == value
-class NumericQuery(FieldQuery):
+
+class NumericQuery(FieldQuery[str]):
"""Matches numeric fields. A syntax using Ruby-style range ellipses
(``..``) lets users specify one- or two-sided ranges. For example,
``year:2001..`` finds music released since the turn of the century.
@@ -272,7 +368,7 @@ class NumericQuery(FieldQuery):
a float.
"""
- def _convert(self, s):
+ def _convert(self, s: str) -> Union[float, int, None]:
"""Convert a string to a numeric type (float or int).
Return None if `s` is empty.
@@ -289,10 +385,10 @@ class NumericQuery(FieldQuery):
except ValueError:
raise InvalidQueryArgumentValueError(s, "an int or a float")
- def __init__(self, field, pattern, fast=True):
- super().__init__(field, pattern, fast)
+ def __init__(self, field_name: str, pattern: str, fast: bool = True):
+ super().__init__(field_name, pattern, fast)
- parts = pattern.split('..', 1)
+ parts = pattern.split("..", 1)
if len(parts) == 1:
# No range.
self.point = self._convert(parts[0])
@@ -304,10 +400,10 @@ class NumericQuery(FieldQuery):
self.rangemin = self._convert(parts[0])
self.rangemax = self._convert(parts[1])
- def match(self, item):
- if self.field not in item:
+ def match(self, obj: Model) -> bool:
+ if self.field_name not in obj:
return False
- value = item[self.field]
+ value = obj[self.field_name]
if isinstance(value, str):
value = self._convert(value)
@@ -320,19 +416,43 @@ class NumericQuery(FieldQuery):
return False
return True
- def col_clause(self):
+ def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
if self.point is not None:
- return self.field + '=?', (self.point,)
+ return self.field + "=?", (self.point,)
else:
if self.rangemin is not None and self.rangemax is not None:
- return ('{0} >= ? AND {0} <= ?'.format(self.field),
- (self.rangemin, self.rangemax))
+ return (
+ "{0} >= ? AND {0} <= ?".format(self.field),
+ (self.rangemin, self.rangemax),
+ )
elif self.rangemin is not None:
- return f'{self.field} >= ?', (self.rangemin,)
+ return f"{self.field} >= ?", (self.rangemin,)
elif self.rangemax is not None:
- return f'{self.field} <= ?', (self.rangemax,)
+ return f"{self.field} <= ?", (self.rangemax,)
else:
- return '1', ()
+ return "1", ()
+
+
+class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
+ """Query which matches values in the given set."""
+
+ field_name: str
+ pattern: Sequence[AnySQLiteType]
+ fast: bool = True
+
+ @property
+ def subvals(self) -> Sequence[SQLiteType]:
+ return self.pattern
+
+ def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
+ placeholders = ", ".join(["?"] * len(self.subvals))
+ return f"{self.field_name} IN ({placeholders})", self.subvals
+
+ @classmethod
+ def value_match(
+ cls, pattern: Sequence[AnySQLiteType], value: AnySQLiteType
+ ) -> bool:
+ return value in pattern
class CollectionQuery(Query):
@@ -340,24 +460,32 @@ class CollectionQuery(Query):
indexed like a list to access the sub-queries.
"""
- def __init__(self, subqueries=()):
+ @property
+ def field_names(self) -> Set[str]:
+ """Return a set with field names that this query operates on."""
+ return reduce(or_, (sq.field_names for sq in self.subqueries))
+
+ def __init__(self, subqueries: Sequence = ()):
self.subqueries = subqueries
# Act like a sequence.
- def __len__(self):
+ def __len__(self) -> int:
return len(self.subqueries)
def __getitem__(self, key):
return self.subqueries[key]
- def __iter__(self):
+ def __iter__(self) -> Iterator:
return iter(self.subqueries)
- def __contains__(self, item):
- return item in self.subqueries
+ def __contains__(self, subq) -> bool:
+ return subq in self.subqueries
- def clause_with_joiner(self, joiner):
+ def clause_with_joiner(
+ self,
+ joiner: str,
+ ) -> Tuple[Optional[str], Sequence[SQLiteType]]:
"""Return a clause created by joining together the clauses of
all subqueries with the string joiner (padded by spaces).
"""
@@ -368,19 +496,18 @@ class CollectionQuery(Query):
if not subq_clause:
# Fall back to slow query.
return None, ()
- clause_parts.append('(' + subq_clause + ')')
+ clause_parts.append("(" + subq_clause + ")")
subvals += subq_subvals
- clause = (' ' + joiner + ' ').join(clause_parts)
+ clause = (" " + joiner + " ").join(clause_parts)
return clause, subvals
- def __repr__(self):
- return "{0.__class__.__name__}({0.subqueries!r})".format(self)
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self.subqueries!r})"
- def __eq__(self, other):
- return super().__eq__(other) and \
- self.subqueries == other.subqueries
+ def __eq__(self, other) -> bool:
+ return super().__eq__(other) and self.subqueries == other.subqueries
- def __hash__(self):
+ def __hash__(self) -> int:
"""Since subqueries are mutable, this object should not be hashable.
However and for conveniences purposes, it can be hashed.
"""
@@ -393,7 +520,12 @@ class AnyFieldQuery(CollectionQuery):
constructor.
"""
- def __init__(self, pattern, fields, cls):
+ @property
+ def field_names(self) -> Set[str]:
+ """Return a set with field names that this query operates on."""
+ return set(self.fields)
+
+ def __init__(self, pattern, fields, cls: Type[FieldQuery]):
self.pattern = pattern
self.fields = fields
self.query_class = cls
@@ -401,26 +533,28 @@ class AnyFieldQuery(CollectionQuery):
subqueries = []
for field in self.fields:
subqueries.append(cls(field, pattern, True))
+ # TYPING ERROR
super().__init__(subqueries)
- def clause(self):
- return self.clause_with_joiner('or')
+ def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
+ return self.clause_with_joiner("or")
- def match(self, item):
+ def match(self, obj: Model) -> bool:
for subq in self.subqueries:
- if subq.match(item):
+ if subq.match(obj):
return True
return False
- def __repr__(self):
- return ("{0.__class__.__name__}({0.pattern!r}, {0.fields!r}, "
- "{0.query_class.__name__})".format(self))
+ def __repr__(self) -> str:
+ return (
+ f"{self.__class__.__name__}({self.pattern!r}, {self.fields!r}, "
+ f"{self.query_class.__name__})"
+ )
- def __eq__(self, other):
- return super().__eq__(other) and \
- self.query_class == other.query_class
+ def __eq__(self, other) -> bool:
+ return super().__eq__(other) and self.query_class == other.query_class
- def __hash__(self):
+ def __hash__(self) -> int:
return hash((self.pattern, tuple(self.fields), self.query_class))
@@ -429,6 +563,8 @@ class MutableCollectionQuery(CollectionQuery):
query is initialized.
"""
+ subqueries: MutableSequence
+
def __setitem__(self, key, value):
self.subqueries[key] = value
@@ -439,94 +575,86 @@ class MutableCollectionQuery(CollectionQuery):
class AndQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
- def clause(self):
- return self.clause_with_joiner('and')
+ def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
+ return self.clause_with_joiner("and")
- def match(self, item):
- return all(q.match(item) for q in self.subqueries)
+ def match(self, obj: Model) -> bool:
+ return all(q.match(obj) for q in self.subqueries)
class OrQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
- def clause(self):
- return self.clause_with_joiner('or')
+ def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
+ return self.clause_with_joiner("or")
- def match(self, item):
- return any(q.match(item) for q in self.subqueries)
+ def match(self, obj: Model) -> bool:
+ return any(q.match(obj) for q in self.subqueries)
class NotQuery(Query):
- """A query that matches the negation of its `subquery`, as a shorcut for
+ """A query that matches the negation of its `subquery`, as a shortcut for
performing `not(subquery)` without using regular expressions.
"""
+ @property
+ def field_names(self) -> Set[str]:
+ """Return a set with field names that this query operates on."""
+ return self.subquery.field_names
+
def __init__(self, subquery):
self.subquery = subquery
- def clause(self):
+ def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
clause, subvals = self.subquery.clause()
if clause:
- return f'not ({clause})', subvals
+ return f"not ({clause})", subvals
else:
# If there is no clause, there is nothing to negate. All the logic
# is handled by match() for slow queries.
return clause, subvals
- def match(self, item):
- return not self.subquery.match(item)
+ def match(self, obj: Model) -> bool:
+ return not self.subquery.match(obj)
- def __repr__(self):
- return "{0.__class__.__name__}({0.subquery!r})".format(self)
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self.subquery!r})"
- def __eq__(self, other):
- return super().__eq__(other) and \
- self.subquery == other.subquery
+ def __eq__(self, other) -> bool:
+ return super().__eq__(other) and self.subquery == other.subquery
- def __hash__(self):
- return hash(('not', hash(self.subquery)))
+ def __hash__(self) -> int:
+ return hash(("not", hash(self.subquery)))
class TrueQuery(Query):
"""A query that always matches."""
- def clause(self):
- return '1', ()
+ def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
+ return "1", ()
- def match(self, item):
+ def match(self, obj: Model) -> bool:
return True
class FalseQuery(Query):
"""A query that never matches."""
- def clause(self):
- return '0', ()
+ def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
+ return "0", ()
- def match(self, item):
+ def match(self, obj: Model) -> bool:
return False
# Time/date queries.
-def _to_epoch_time(date):
- """Convert a `datetime` object to an integer number of seconds since
- the (local) Unix epoch.
- """
- if hasattr(date, 'timestamp'):
- # The `timestamp` method exists on Python 3.3+.
- return int(date.timestamp())
- else:
- epoch = datetime.fromtimestamp(0)
- delta = date - epoch
- return int(delta.total_seconds())
-
-def _parse_periods(pattern):
+def _parse_periods(pattern: str) -> Tuple[Optional[Period], Optional[Period]]:
"""Parse a string containing two dates separated by two dots (..).
Return a pair of `Period` objects.
"""
- parts = pattern.split('..', 1)
+ parts = pattern.split("..", 1)
if len(parts) == 1:
instant = Period.parse(parts[0])
return (instant, instant)
@@ -543,31 +671,32 @@ class Period:
instants of time during January 2014.
"""
- precisions = ('year', 'month', 'day', 'hour', 'minute', 'second')
+ precisions = ("year", "month", "day", "hour", "minute", "second")
date_formats = (
- ('%Y',), # year
- ('%Y-%m',), # month
- ('%Y-%m-%d',), # day
- ('%Y-%m-%dT%H', '%Y-%m-%d %H'), # hour
- ('%Y-%m-%dT%H:%M', '%Y-%m-%d %H:%M'), # minute
- ('%Y-%m-%dT%H:%M:%S', '%Y-%m-%d %H:%M:%S') # second
+ ("%Y",), # year
+ ("%Y-%m",), # month
+ ("%Y-%m-%d",), # day
+ ("%Y-%m-%dT%H", "%Y-%m-%d %H"), # hour
+ ("%Y-%m-%dT%H:%M", "%Y-%m-%d %H:%M"), # minute
+ ("%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"), # second
+ )
+ relative_units = {"y": 365, "m": 30, "w": 7, "d": 1}
+ relative_re = (
+ "(?P[+|-]?)(?P[0-9]+)" + "(?P[y|m|w|d])"
)
- relative_units = {'y': 365, 'm': 30, 'w': 7, 'd': 1}
- relative_re = '(?P[+|-]?)(?P[0-9]+)' + \
- '(?P[y|m|w|d])'
- def __init__(self, date, precision):
+ def __init__(self, date: datetime, precision: str):
"""Create a period with the given date (a `datetime` object) and
precision (a string, one of "year", "month", "day", "hour", "minute",
or "second").
"""
if precision not in Period.precisions:
- raise ValueError(f'Invalid precision {precision}')
+ raise ValueError(f"Invalid precision {precision}")
self.date = date
self.precision = precision
@classmethod
- def parse(cls, string):
+ def parse(cls: Type["Period"], string: str) -> Optional["Period"]:
"""Parse a date and return a `Period` object or `None` if the
string is empty, or raise an InvalidQueryArgumentValueError if
the string cannot be parsed to a date.
@@ -584,7 +713,9 @@ class Period:
and a "year" is exactly 365 days.
"""
- def find_date_and_format(string):
+ def find_date_and_format(
+ string: str,
+ ) -> Union[Tuple[None, None], Tuple[datetime, int]]:
for ord, format in enumerate(cls.date_formats):
for format_option in format:
try:
@@ -598,52 +729,57 @@ class Period:
if not string:
return None
+ date: Optional[datetime]
+
# Check for a relative date.
match_dq = re.match(cls.relative_re, string)
if match_dq:
- sign = match_dq.group('sign')
- quantity = match_dq.group('quantity')
- timespan = match_dq.group('timespan')
+ sign = match_dq.group("sign")
+ quantity = match_dq.group("quantity")
+ timespan = match_dq.group("timespan")
# Add or subtract the given amount of time from the current
# date.
- multiplier = -1 if sign == '-' else 1
+ multiplier = -1 if sign == "-" else 1
days = cls.relative_units[timespan]
- date = datetime.now() + \
- timedelta(days=int(quantity) * days) * multiplier
+ date = (
+ datetime.now()
+ + timedelta(days=int(quantity) * days) * multiplier
+ )
return cls(date, cls.precisions[5])
# Check for an absolute date.
date, ordinal = find_date_and_format(string)
- if date is None:
- raise InvalidQueryArgumentValueError(string,
- 'a valid date/time string')
+ if date is None or ordinal is None:
+ raise InvalidQueryArgumentValueError(
+ string, "a valid date/time string"
+ )
precision = cls.precisions[ordinal]
return cls(date, precision)
- def open_right_endpoint(self):
+ def open_right_endpoint(self) -> datetime:
"""Based on the precision, convert the period to a precise
`datetime` for use as a right endpoint in a right-open interval.
"""
precision = self.precision
date = self.date
- if 'year' == self.precision:
+ if "year" == self.precision:
return date.replace(year=date.year + 1, month=1)
- elif 'month' == precision:
- if (date.month < 12):
+ elif "month" == precision:
+ if date.month < 12:
return date.replace(month=date.month + 1)
else:
return date.replace(year=date.year + 1, month=1)
- elif 'day' == precision:
+ elif "day" == precision:
return date + timedelta(days=1)
- elif 'hour' == precision:
+ elif "hour" == precision:
return date + timedelta(hours=1)
- elif 'minute' == precision:
+ elif "minute" == precision:
return date + timedelta(minutes=1)
- elif 'second' == precision:
+ elif "second" == precision:
return date + timedelta(seconds=1)
else:
- raise ValueError(f'unhandled precision {precision}')
+ raise ValueError(f"unhandled precision {precision}")
class DateInterval:
@@ -653,33 +789,37 @@ class DateInterval:
A right endpoint of None means towards infinity.
"""
- def __init__(self, start, end):
+ def __init__(self, start: Optional[datetime], end: Optional[datetime]):
if start is not None and end is not None and not start < end:
- raise ValueError("start date {} is not before end date {}"
- .format(start, end))
+ raise ValueError(
+ "start date {} is not before end date {}".format(start, end)
+ )
self.start = start
self.end = end
@classmethod
- def from_periods(cls, start, end):
- """Create an interval with two Periods as the endpoints.
- """
+ def from_periods(
+ cls,
+ start: Optional[Period],
+ end: Optional[Period],
+ ) -> DateInterval:
+ """Create an interval with two Periods as the endpoints."""
end_date = end.open_right_endpoint() if end is not None else None
start_date = start.date if start is not None else None
return cls(start_date, end_date)
- def contains(self, date):
+ def contains(self, date: datetime) -> bool:
if self.start is not None and date < self.start:
return False
if self.end is not None and date >= self.end:
return False
return True
- def __str__(self):
- return f'[{self.start}, {self.end})'
+ def __str__(self) -> str:
+ return f"[{self.start}, {self.end})"
-class DateQuery(FieldQuery):
+class DateQuery(FieldQuery[str]):
"""Matches date fields stored as seconds since Unix epoch time.
Dates can be specified as ``year-month-day`` strings where only year
@@ -689,38 +829,40 @@ class DateQuery(FieldQuery):
using an ellipsis interval syntax similar to that of NumericQuery.
"""
- def __init__(self, field, pattern, fast=True):
- super().__init__(field, pattern, fast)
+ def __init__(self, field_name: str, pattern: str, fast: bool = True):
+ super().__init__(field_name, pattern, fast)
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)
- def match(self, item):
- if self.field not in item:
+ def match(self, obj: Model) -> bool:
+ if self.field_name not in obj:
return False
- timestamp = float(item[self.field])
+ timestamp = float(obj[self.field_name])
date = datetime.fromtimestamp(timestamp)
return self.interval.contains(date)
_clause_tmpl = "{0} {1} ?"
- def col_clause(self):
+ def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
clause_parts = []
subvals = []
+ # Convert the `datetime` objects to an integer number of seconds since
+ # the (local) Unix epoch using `datetime.timestamp()`.
if self.interval.start:
clause_parts.append(self._clause_tmpl.format(self.field, ">="))
- subvals.append(_to_epoch_time(self.interval.start))
+ subvals.append(int(self.interval.start.timestamp()))
if self.interval.end:
clause_parts.append(self._clause_tmpl.format(self.field, "<"))
- subvals.append(_to_epoch_time(self.interval.end))
+ subvals.append(int(self.interval.end.timestamp()))
if clause_parts:
# One- or two-sided interval.
- clause = ' AND '.join(clause_parts)
+ clause = " AND ".join(clause_parts)
else:
# Match any date.
- clause = '1'
+ clause = "1"
return clause, subvals
@@ -733,7 +875,7 @@ class DurationQuery(NumericQuery):
or M:SS time interval.
"""
- def _convert(self, s):
+ def _convert(self, s: str) -> Optional[float]:
"""Convert a M:SS or numeric string to a float.
Return None if `s` is empty.
@@ -748,77 +890,72 @@ class DurationQuery(NumericQuery):
return float(s)
except ValueError:
raise InvalidQueryArgumentValueError(
- s,
- "a M:SS string or a float")
+ s, "a M:SS string or a float"
+ )
# Sorting.
+
class Sort:
"""An abstract class representing a sort operation for a query into
- the item database.
+ the database.
"""
- def order_clause(self):
+ def order_clause(self) -> Optional[str]:
"""Generates a SQL fragment to be used in a ORDER BY clause, or
None if no fragment is used (i.e., this is a slow sort).
"""
return None
- def sort(self, items):
- """Sort the list of objects and return a list.
- """
+ def sort(self, items: List) -> List:
+ """Sort the list of objects and return a list."""
return sorted(items)
- def is_slow(self):
+ def is_slow(self) -> bool:
"""Indicate whether this query is *slow*, meaning that it cannot
be executed in SQL and must be executed in Python.
"""
return False
- def __hash__(self):
+ def __hash__(self) -> int:
return 0
- def __eq__(self, other):
- return type(self) == type(other)
+ def __eq__(self, other) -> bool:
+ return type(self) is type(other)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}()"
class MultipleSort(Sort):
- """Sort that encapsulates multiple sub-sorts.
- """
+ """Sort that encapsulates multiple sub-sorts."""
- def __init__(self, sorts=None):
+ def __init__(self, sorts: Optional[List[Sort]] = None):
self.sorts = sorts or []
- def add_sort(self, sort):
+ def add_sort(self, sort: Sort):
self.sorts.append(sort)
- def _sql_sorts(self):
- """Return the list of sub-sorts for which we can be (at least
- partially) fast.
+ def order_clause(self) -> str:
+ """Return the list SQL clauses for those sub-sorts for which we can be
+ (at least partially) fast.
A contiguous suffix of fast (SQL-capable) sub-sorts are
executable in SQL. The remaining, even if they are fast
independently, must be executed slowly.
"""
- sql_sorts = []
- for sort in reversed(self.sorts):
- if not sort.order_clause() is None:
- sql_sorts.append(sort)
- else:
- break
- sql_sorts.reverse()
- return sql_sorts
-
- def order_clause(self):
order_strings = []
- for sort in self._sql_sorts():
- order = sort.order_clause()
- order_strings.append(order)
+ for sort in reversed(self.sorts):
+ clause = sort.order_clause()
+ if clause is None:
+ break
+ order_strings.append(clause)
+ order_strings.reverse()
return ", ".join(order_strings)
- def is_slow(self):
+ def is_slow(self) -> bool:
for sort in self.sorts:
if sort.is_slow():
return True
@@ -841,14 +978,13 @@ class MultipleSort(Sort):
return items
def __repr__(self):
- return f'MultipleSort({self.sorts!r})'
+ return f"{self.__class__.__name__}({self.sorts!r})"
def __hash__(self):
return hash(tuple(self.sorts))
def __eq__(self, other):
- return super().__eq__(other) and \
- self.sorts == other.sorts
+ return super().__eq__(other) and self.sorts == other.sorts
class FieldSort(Sort):
@@ -856,51 +992,58 @@ class FieldSort(Sort):
any kind).
"""
- def __init__(self, field, ascending=True, case_insensitive=True):
+ def __init__(
+ self,
+ field,
+ ascending: bool = True,
+ case_insensitive: bool = True,
+ ):
self.field = field
self.ascending = ascending
self.case_insensitive = case_insensitive
- def sort(self, objs):
+ def sort(self, objs: Collection):
# TODO: Conversion and null-detection here. In Python 3,
# comparisons with None fail. We should also support flexible
# attributes with different types without falling over.
- def key(item):
- field_val = item.get(self.field, '')
+ def key(obj: Model) -> Any:
+ field_val = obj.get(self.field, "")
if self.case_insensitive and isinstance(field_val, str):
field_val = field_val.lower()
return field_val
return sorted(objs, key=key, reverse=not self.ascending)
- def __repr__(self):
- return '<{}: {}{}>'.format(
- type(self).__name__,
- self.field,
- '+' if self.ascending else '-',
+ def __repr__(self) -> str:
+ return (
+ f"{self.__class__.__name__}"
+ f"({self.field!r}, ascending={self.ascending!r})"
)
- def __hash__(self):
+ def __hash__(self) -> int:
return hash((self.field, self.ascending))
- def __eq__(self, other):
- return super().__eq__(other) and \
- self.field == other.field and \
- self.ascending == other.ascending
+ def __eq__(self, other) -> bool:
+ return (
+ super().__eq__(other)
+ and self.field == other.field
+ and self.ascending == other.ascending
+ )
class FixedFieldSort(FieldSort):
- """Sort object to sort on a fixed field.
- """
+ """Sort object to sort on a fixed field."""
- def order_clause(self):
+ def order_clause(self) -> str:
order = "ASC" if self.ascending else "DESC"
if self.case_insensitive:
- field = '(CASE ' \
- 'WHEN TYPEOF({0})="text" THEN LOWER({0}) ' \
- 'WHEN TYPEOF({0})="blob" THEN LOWER({0}) ' \
- 'ELSE {0} END)'.format(self.field)
+ field = (
+ "(CASE "
+ 'WHEN TYPEOF({0})="text" THEN LOWER({0}) '
+ 'WHEN TYPEOF({0})="blob" THEN LOWER({0}) '
+ "ELSE {0} END)".format(self.field)
+ )
else:
field = self.field
return f"{field} {order}"
@@ -911,24 +1054,24 @@ class SlowFieldSort(FieldSort):
i.e., a computed or flexible field.
"""
- def is_slow(self):
+ def is_slow(self) -> bool:
return True
class NullSort(Sort):
"""No sorting. Leave results unsorted."""
- def sort(self, items):
+ def sort(self, items: List) -> List:
return items
- def __nonzero__(self):
+ def __nonzero__(self) -> bool:
return self.__bool__()
- def __bool__(self):
+ def __bool__(self) -> bool:
return False
- def __eq__(self, other):
- return type(self) == type(other) or other is None
+ def __eq__(self, other) -> bool:
+ return type(self) is type(other) or other is None
- def __hash__(self):
+ def __hash__(self) -> int:
return 0
diff --git a/lib/beets/dbcore/queryparse.py b/lib/beets/dbcore/queryparse.py
index 3bf02e4d..8d2a0ae0 100644
--- a/lib/beets/dbcore/queryparse.py
+++ b/lib/beets/dbcore/queryparse.py
@@ -12,30 +12,33 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
-"""Parsing of strings into DBCore queries.
-"""
+"""Parsing of strings into DBCore queries."""
-import re
import itertools
-from . import query
+import re
+from typing import Collection, Dict, List, Optional, Sequence, Tuple, Type
+
+from . import Model, query
+from .query import Sort
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
- r'(-|\^)?' # Negation prefixes.
-
- r'(?:'
- r'(\S+?)' # The field key.
- r'(? Tuple[Optional[str], str, Type[query.FieldQuery], bool]:
"""Parse a single *query part*, which is a chunk of a complete query
string representing a single criterion.
@@ -86,13 +89,13 @@ def parse_query_part(part, query_classes={}, prefixes={},
assert match # Regex should always match
negate = bool(match.group(1))
key = match.group(2)
- term = match.group(3).replace('\\:', ':')
+ term = match.group(3).replace("\\:", ":")
# Check whether there's a prefix in the query and use the
# corresponding query type.
for pre, query_class in prefixes.items():
if term.startswith(pre):
- return key, term[len(pre):], query_class, negate
+ return key, term[len(pre) :], query_class, negate
# No matching prefix, so use either the query class determined by
# the field or the default as a fallback.
@@ -100,7 +103,11 @@ def parse_query_part(part, query_classes={}, prefixes={},
return key, term, query_class, negate
-def construct_query_part(model_cls, prefixes, query_part):
+def construct_query_part(
+ model_cls: Type[Model],
+ prefixes: Dict,
+ query_part: str,
+) -> query.Query:
"""Parse a *query part* string and return a :class:`Query` object.
:param model_cls: The :class:`Model` class that this is a query for.
@@ -116,40 +123,44 @@ def construct_query_part(model_cls, prefixes, query_part):
if not query_part:
return query.TrueQuery()
+ out_query: query.Query
+
# Use `model_cls` to build up a map from field (or query) names to
# `Query` classes.
- query_classes = {}
- for k, t in itertools.chain(model_cls._fields.items(),
- model_cls._types.items()):
+ query_classes: Dict[str, Type[query.FieldQuery]] = {}
+ for k, t in itertools.chain(
+ model_cls._fields.items(), model_cls._types.items()
+ ):
query_classes[k] = t.query
query_classes.update(model_cls._queries) # Non-field queries.
# Parse the string.
- key, pattern, query_class, negate = \
- parse_query_part(query_part, query_classes, prefixes)
+ key, pattern, query_class, negate = parse_query_part(
+ query_part, query_classes, prefixes
+ )
# If there's no key (field name) specified, this is a "match
# anything" query.
if key is None:
- if issubclass(query_class, query.FieldQuery):
- # The query type matches a specific field, but none was
- # specified. So we use a version of the query that matches
- # any field.
- out_query = query.AnyFieldQuery(pattern, model_cls._search_fields,
- query_class)
- else:
- # Non-field query type.
- out_query = query_class(pattern)
+ # The query type matches a specific field, but none was
+ # specified. So we use a version of the query that matches
+ # any field.
+ out_query = query.AnyFieldQuery(
+ pattern, model_cls._search_fields, query_class
+ )
# Field queries get constructed according to the name of the field
# they are querying.
- elif issubclass(query_class, query.FieldQuery):
- key = key.lower()
- out_query = query_class(key.lower(), pattern, key in model_cls._fields)
-
- # Non-field (named) query.
else:
- out_query = query_class(pattern)
+ field = table = key.lower()
+ if field in model_cls.shared_db_fields:
+ # This field exists in both tables, so SQLite will encounter
+ # an OperationalError if we try to query it in a join.
+ # Using an explicit table name resolves this.
+ table = f"{model_cls._table}.{field}"
+
+ field_in_db = field in model_cls.all_db_fields
+ out_query = query_class(table, pattern, field_in_db)
# Apply negation.
if negate:
@@ -158,7 +169,13 @@ def construct_query_part(model_cls, prefixes, query_part):
return out_query
-def query_from_strings(query_cls, model_cls, prefixes, query_parts):
+# TYPING ERROR
+def query_from_strings(
+ query_cls: Type[query.CollectionQuery],
+ model_cls: Type[Model],
+ prefixes: Dict,
+ query_parts: Collection[str],
+) -> query.Query:
"""Creates a collection query of type `query_cls` from a list of
strings in the format used by parse_query_part. `model_cls`
determines how queries are constructed from strings.
@@ -171,7 +188,11 @@ def query_from_strings(query_cls, model_cls, prefixes, query_parts):
return query_cls(subqueries)
-def construct_sort_part(model_cls, part, case_insensitive=True):
+def construct_sort_part(
+ model_cls: Type[Model],
+ part: str,
+ case_insensitive: bool = True,
+) -> Sort:
"""Create a `Sort` from a single string criterion.
`model_cls` is the `Model` being queried. `part` is a single string
@@ -183,12 +204,13 @@ def construct_sort_part(model_cls, part, case_insensitive=True):
field = part[:-1]
assert field, "field is missing"
direction = part[-1]
- assert direction in ('+', '-'), "part must end with + or -"
- is_ascending = direction == '+'
+ assert direction in ("+", "-"), "part must end with + or -"
+ is_ascending = direction == "+"
if field in model_cls._sorts:
- sort = model_cls._sorts[field](model_cls, is_ascending,
- case_insensitive)
+ sort = model_cls._sorts[field](
+ model_cls, is_ascending, case_insensitive
+ )
elif field in model_cls._fields:
sort = query.FixedFieldSort(field, is_ascending, case_insensitive)
else:
@@ -197,23 +219,31 @@ def construct_sort_part(model_cls, part, case_insensitive=True):
return sort
-def sort_from_strings(model_cls, sort_parts, case_insensitive=True):
- """Create a `Sort` from a list of sort criteria (strings).
- """
+def sort_from_strings(
+ model_cls: Type[Model],
+ sort_parts: Sequence[str],
+ case_insensitive: bool = True,
+) -> Sort:
+ """Create a `Sort` from a list of sort criteria (strings)."""
if not sort_parts:
- sort = query.NullSort()
+ return query.NullSort()
elif len(sort_parts) == 1:
- sort = construct_sort_part(model_cls, sort_parts[0], case_insensitive)
+ return construct_sort_part(model_cls, sort_parts[0], case_insensitive)
else:
sort = query.MultipleSort()
for part in sort_parts:
- sort.add_sort(construct_sort_part(model_cls, part,
- case_insensitive))
- return sort
+ sort.add_sort(
+ construct_sort_part(model_cls, part, case_insensitive)
+ )
+ return sort
-def parse_sorted_query(model_cls, parts, prefixes={},
- case_insensitive=True):
+def parse_sorted_query(
+ model_cls: Type[Model],
+ parts: List[str],
+ prefixes: Dict = {},
+ case_insensitive: bool = True,
+) -> Tuple[query.Query, Sort]:
"""Given a list of strings, create the `Query` and `Sort` that they
represent.
"""
@@ -224,24 +254,24 @@ def parse_sorted_query(model_cls, parts, prefixes={},
# Split up query in to comma-separated subqueries, each representing
# an AndQuery, which need to be joined together in one OrQuery
subquery_parts = []
- for part in parts + [',']:
- if part.endswith(','):
+ for part in parts + [","]:
+ if part.endswith(","):
# Ensure we can catch "foo, bar" as well as "foo , bar"
last_subquery_part = part[:-1]
if last_subquery_part:
subquery_parts.append(last_subquery_part)
# Parse the subquery in to a single AndQuery
# TODO: Avoid needlessly wrapping AndQueries containing 1 subquery?
- query_parts.append(query_from_strings(
- query.AndQuery, model_cls, prefixes, subquery_parts
- ))
+ query_parts.append(
+ query_from_strings(
+ query.AndQuery, model_cls, prefixes, subquery_parts
+ )
+ )
del subquery_parts[:]
else:
# Sort parts (1) end in + or -, (2) don't have a field, and
# (3) consist of more than just the + or -.
- if part.endswith(('+', '-')) \
- and ':' not in part \
- and len(part) > 1:
+ if part.endswith(("+", "-")) and ":" not in part and len(part) > 1:
sort_parts.append(part)
else:
subquery_parts.append(part)
diff --git a/lib/beets/dbcore/types.py b/lib/beets/dbcore/types.py
index 40f6a080..432db2b7 100644
--- a/lib/beets/dbcore/types.py
+++ b/lib/beets/dbcore/types.py
@@ -14,28 +14,46 @@
"""Representation of type information for DBCore model fields.
"""
+import typing
+from abc import ABC
+from typing import Any, Generic, List, TypeVar, Union, cast
-from . import query
from beets.util import str2bool
+from .query import BooleanQuery, FieldQuery, NumericQuery, SubstringQuery
-# Abstract base.
-class Type:
+class ModelType(typing.Protocol):
+ """Protocol that specifies the required constructor for model types,
+ i.e. a function that takes any argument and attempts to parse it to the
+ given type.
+ """
+
+ def __init__(self, value: Any = None): ...
+
+
+# Generic type variables, used for the value type T and null type N (if
+# nullable, else T and N are set to the same type for the concrete subclasses
+# of Type).
+N = TypeVar("N")
+T = TypeVar("T", bound=ModelType)
+
+
+class Type(ABC, Generic[T, N]):
"""An object encapsulating the type of a model field. Includes
information about how to store, query, format, and parse a given
field.
"""
- sql = 'TEXT'
+ sql: str = "TEXT"
"""The SQLite column type for the value.
"""
- query = query.SubstringQuery
+ query: typing.Type[FieldQuery] = SubstringQuery
"""The `Query` subclass to be used when querying the field.
"""
- model_type = str
+ model_type: typing.Type[T]
"""The Python type that is used to represent the value in the model.
The model is guaranteed to return a value of this type if the field
@@ -44,12 +62,14 @@ class Type:
"""
@property
- def null(self):
- """The value to be exposed when the underlying value is None.
- """
- return self.model_type()
+ def null(self) -> N:
+ """The value to be exposed when the underlying value is None."""
+ # Note that this default implementation only makes sense for T = N.
+ # It would be better to implement `null()` only in subclasses, or
+ # have a field null_type similar to `model_type` and use that here.
+ return cast(N, self.model_type())
- def format(self, value):
+ def format(self, value: Union[N, T]) -> str:
"""Given a value of this type, produce a Unicode string
representing the value. This is used in template evaluation.
"""
@@ -57,13 +77,13 @@ class Type:
value = self.null
# `self.null` might be `None`
if value is None:
- value = ''
- if isinstance(value, bytes):
- value = value.decode('utf-8', 'ignore')
+ return ""
+ elif isinstance(value, bytes):
+ return value.decode("utf-8", "ignore")
+ else:
+ return str(value)
- return str(value)
-
- def parse(self, string):
+ def parse(self, string: str) -> Union[T, N]:
"""Parse a (possibly human-written) string and return the
indicated value of this type.
"""
@@ -72,19 +92,23 @@ class Type:
except ValueError:
return self.null
- def normalize(self, value):
+ def normalize(self, value: Any) -> Union[T, N]:
"""Given a value that will be assigned into a field of this
type, normalize the value to have the appropriate type. This
base implementation only reinterprets `None`.
"""
+ # TYPING ERROR
if value is None:
return self.null
else:
# TODO This should eventually be replaced by
# `self.model_type(value)`
- return value
+ return cast(T, value)
- def from_sql(self, sql_value):
+ def from_sql(
+ self,
+ sql_value: Union[None, int, float, str, bytes],
+ ) -> Union[T, N]:
"""Receives the value stored in the SQL backend and return the
value to be stored in the model.
@@ -99,13 +123,13 @@ class Type:
and the method must handle these in addition.
"""
if isinstance(sql_value, memoryview):
- sql_value = bytes(sql_value).decode('utf-8', 'ignore')
+ sql_value = bytes(sql_value).decode("utf-8", "ignore")
if isinstance(sql_value, str):
return self.parse(sql_value)
else:
return self.normalize(sql_value)
- def to_sql(self, model_value):
+ def to_sql(self, model_value: Any) -> Union[None, int, float, str, bytes]:
"""Convert a value as stored in the model object to a value used
by the database adapter.
"""
@@ -114,18 +138,23 @@ class Type:
# Reusable types.
-class Default(Type):
- null = None
+
+class Default(Type[str, None]):
+ model_type = str
+
+ @property
+ def null(self):
+ return None
-class Integer(Type):
- """A basic integer type.
- """
- sql = 'INTEGER'
- query = query.NumericQuery
+class BaseInteger(Type[int, N]):
+ """A basic integer type."""
+
+ sql = "INTEGER"
+ query = NumericQuery
model_type = int
- def normalize(self, value):
+ def normalize(self, value: Any) -> Union[int, N]:
try:
return self.model_type(round(float(value)))
except ValueError:
@@ -134,91 +163,153 @@ class Integer(Type):
return self.null
-class PaddedInt(Integer):
+class Integer(BaseInteger[int]):
+ @property
+ def null(self) -> int:
+ return 0
+
+
+class NullInteger(BaseInteger[None]):
+ @property
+ def null(self) -> None:
+ return None
+
+
+class BasePaddedInt(BaseInteger[N]):
"""An integer field that is formatted with a given number of digits,
padded with zeroes.
"""
- def __init__(self, digits):
+
+ def __init__(self, digits: int):
self.digits = digits
- def format(self, value):
- return '{0:0{1}d}'.format(value or 0, self.digits)
+ def format(self, value: Union[int, N]) -> str:
+ return "{0:0{1}d}".format(value or 0, self.digits)
-class NullPaddedInt(PaddedInt):
- """Same as `PaddedInt`, but does not normalize `None` to `0.0`.
- """
- null = None
+class PaddedInt(BasePaddedInt[int]):
+ pass
+
+
+class NullPaddedInt(BasePaddedInt[None]):
+ """Same as `PaddedInt`, but does not normalize `None` to `0`."""
+
+ @property
+ def null(self) -> None:
+ return None
class ScaledInt(Integer):
"""An integer whose formatting operation scales the number by a
constant and adds a suffix. Good for units with large magnitudes.
"""
- def __init__(self, unit, suffix=''):
+
+ def __init__(self, unit: int, suffix: str = ""):
self.unit = unit
self.suffix = suffix
- def format(self, value):
- return '{}{}'.format((value or 0) // self.unit, self.suffix)
+ def format(self, value: int) -> str:
+ return "{}{}".format((value or 0) // self.unit, self.suffix)
-class Id(Integer):
+class Id(NullInteger):
"""An integer used as the row id or a foreign key in a SQLite table.
This type is nullable: None values are not translated to zero.
"""
- null = None
- def __init__(self, primary=True):
+ @property
+ def null(self) -> None:
+ return None
+
+ def __init__(self, primary: bool = True):
if primary:
- self.sql = 'INTEGER PRIMARY KEY'
+ self.sql = "INTEGER PRIMARY KEY"
-class Float(Type):
+class BaseFloat(Type[float, N]):
"""A basic floating-point type. The `digits` parameter specifies how
many decimal places to use in the human-readable representation.
"""
- sql = 'REAL'
- query = query.NumericQuery
+
+ sql = "REAL"
+ query: typing.Type[FieldQuery[Any]] = NumericQuery
model_type = float
- def __init__(self, digits=1):
+ def __init__(self, digits: int = 1):
self.digits = digits
- def format(self, value):
- return '{0:.{1}f}'.format(value or 0, self.digits)
+ def format(self, value: Union[float, N]) -> str:
+ return "{0:.{1}f}".format(value or 0, self.digits)
-class NullFloat(Float):
- """Same as `Float`, but does not normalize `None` to `0.0`.
- """
- null = None
+class Float(BaseFloat[float]):
+ """Floating-point type that normalizes `None` to `0.0`."""
+
+ @property
+ def null(self) -> float:
+ return 0.0
-class String(Type):
- """A Unicode string type.
- """
- sql = 'TEXT'
- query = query.SubstringQuery
+class NullFloat(BaseFloat[None]):
+ """Same as `Float`, but does not normalize `None` to `0.0`."""
- def normalize(self, value):
+ @property
+ def null(self) -> None:
+ return None
+
+
+class BaseString(Type[T, N]):
+ """A Unicode string type."""
+
+ sql = "TEXT"
+ query = SubstringQuery
+
+ def normalize(self, value: Any) -> Union[T, N]:
if value is None:
return self.null
else:
return self.model_type(value)
-class Boolean(Type):
- """A boolean type.
+class String(BaseString[str, Any]):
+ """A Unicode string type."""
+
+ model_type = str
+
+
+class DelimitedString(BaseString[List[str], List[str]]):
+ """A list of Unicode strings, represented in-database by a single string
+ containing delimiter-separated values.
"""
- sql = 'INTEGER'
- query = query.BooleanQuery
+
+ model_type = list
+
+ def __init__(self, delimiter: str):
+ self.delimiter = delimiter
+
+ def format(self, value: List[str]):
+ return self.delimiter.join(value)
+
+ def parse(self, string: str):
+ if not string:
+ return []
+ return string.split(self.delimiter)
+
+ def to_sql(self, model_value: List[str]):
+ return self.delimiter.join(model_value)
+
+
+class Boolean(Type):
+ """A boolean type."""
+
+ sql = "INTEGER"
+ query = BooleanQuery
model_type = bool
- def format(self, value):
+ def format(self, value: bool) -> str:
return str(bool(value))
- def parse(self, string):
+ def parse(self, string: str) -> bool:
return str2bool(string)
@@ -231,3 +322,7 @@ FLOAT = Float()
NULL_FLOAT = NullFloat()
STRING = String()
BOOLEAN = Boolean()
+SEMICOLON_SPACE_DSV = DelimitedString(delimiter="; ")
+
+# Will set the proper null char in mediafile
+MULTI_VALUE_DSV = DelimitedString(delimiter="\\␀")
diff --git a/lib/beets/importer.py b/lib/beets/importer.py
index 561cedd2..f6517b51 100644
--- a/lib/beets/importer.py
+++ b/lib/beets/importer.py
@@ -17,78 +17,99 @@
autotagging music files.
"""
-import os
-import re
-import pickle
import itertools
-from collections import defaultdict
-from tempfile import mkdtemp
-from bisect import insort, bisect_left
-from contextlib import contextmanager
+import os
+import pickle
+import re
import shutil
import time
-
-from beets import logging
-from beets import autotag
-from beets import library
-from beets import dbcore
-from beets import plugins
-from beets import util
-from beets import config
-from beets.util import pipeline, sorted_walk, ancestry, MoveOperation
-from beets.util import syspath, normpath, displayable_path
+from bisect import bisect_left, insort
+from collections import defaultdict
+from contextlib import contextmanager
from enum import Enum
+from tempfile import mkdtemp
+
import mediafile
-action = Enum('action',
- ['SKIP', 'ASIS', 'TRACKS', 'APPLY', 'ALBUMS', 'RETAG'])
+from beets import autotag, config, dbcore, library, logging, plugins, util
+from beets.util import (
+ MoveOperation,
+ ancestry,
+ displayable_path,
+ normpath,
+ pipeline,
+ sorted_walk,
+ syspath,
+)
+
+action = Enum("action", ["SKIP", "ASIS", "TRACKS", "APPLY", "ALBUMS", "RETAG"])
# The RETAG action represents "don't apply any match, but do record
# new metadata". It's not reachable via the standard command prompt but
# can be used by plugins.
QUEUE_SIZE = 128
SINGLE_ARTIST_THRESH = 0.25
-PROGRESS_KEY = 'tagprogress'
-HISTORY_KEY = 'taghistory'
+PROGRESS_KEY = "tagprogress"
+HISTORY_KEY = "taghistory"
+# Usually flexible attributes are preserved (i.e., not updated) during
+# reimports. The following two lists (globally) change this behaviour for
+# certain fields. To alter these lists only when a specific plugin is in use,
+# something like this can be used within that plugin's code:
+#
+# from beets import importer
+# def extend_reimport_fresh_fields_item():
+# importer.REIMPORT_FRESH_FIELDS_ITEM.extend(['tidal_track_popularity']
+# )
+REIMPORT_FRESH_FIELDS_ALBUM = ["data_source"]
+REIMPORT_FRESH_FIELDS_ITEM = [
+ "data_source",
+ "bandcamp_album_id",
+ "spotify_album_id",
+ "deezer_album_id",
+ "beatport_album_id",
+ "tidal_album_id",
+]
# Global logger.
-log = logging.getLogger('beets')
+log = logging.getLogger("beets")
class ImportAbort(Exception):
- """Raised when the user aborts the tagging operation.
- """
+ """Raised when the user aborts the tagging operation."""
+
pass
# Utilities.
+
def _open_state():
"""Reads the state file, returning a dictionary."""
try:
- with open(config['statefile'].as_filename(), 'rb') as f:
+ with open(config["statefile"].as_filename(), "rb") as f:
return pickle.load(f)
except Exception as exc:
# The `pickle` module can emit all sorts of exceptions during
# unpickling, including ImportError. We use a catch-all
# exception to avoid enumerating them all (the docs don't even have a
# full list!).
- log.debug('state file could not be read: {0}', exc)
+ log.debug("state file could not be read: {0}", exc)
return {}
def _save_state(state):
"""Writes the state dictionary out to disk."""
try:
- with open(config['statefile'].as_filename(), 'wb') as f:
+ with open(config["statefile"].as_filename(), "wb") as f:
pickle.dump(state, f)
except OSError as exc:
- log.error('state file could not be written: {0}', exc)
+ log.error("state file could not be written: {0}", exc)
# Utilities for reading and writing the beets progress file, which
# allows long tagging tasks to be resumed when they pause (or crash).
+
def progress_read():
state = _open_state()
return state.setdefault(PROGRESS_KEY, {})
@@ -120,8 +141,7 @@ def progress_add(toppath, *paths):
def progress_element(toppath, path):
- """Return whether `path` has been imported in `toppath`.
- """
+ """Return whether `path` has been imported in `toppath`."""
state = progress_read()
if toppath not in state:
return False
@@ -148,6 +168,7 @@ def progress_reset(toppath):
# This keeps track of all directories that were ever imported, which
# allows the importer to only import new stuff.
+
def history_add(paths):
"""Indicate that the import of the album in `paths` is completed and
should not be repeated in incremental imports.
@@ -162,8 +183,7 @@ def history_add(paths):
def history_get():
- """Get the set of completed path tuples in incremental imports.
- """
+ """Get the set of completed path tuples in incremental imports."""
state = _open_state()
if HISTORY_KEY not in state:
return set()
@@ -172,6 +192,7 @@ def history_get():
# Abstract session class.
+
class ImportSession:
"""Controls an import action. Subclasses should implement methods to
communicate with the user or otherwise make decisions.
@@ -212,52 +233,53 @@ class ImportSession:
self.config = iconfig
# Incremental and progress are mutually exclusive.
- if iconfig['incremental']:
- iconfig['resume'] = False
+ if iconfig["incremental"]:
+ iconfig["resume"] = False
# When based on a query instead of directories, never
# save progress or try to resume.
if self.query is not None:
- iconfig['resume'] = False
- iconfig['incremental'] = False
+ iconfig["resume"] = False
+ iconfig["incremental"] = False
- if iconfig['reflink']:
- iconfig['reflink'] = iconfig['reflink'] \
- .as_choice(['auto', True, False])
+ if iconfig["reflink"]:
+ iconfig["reflink"] = iconfig["reflink"].as_choice(
+ ["auto", True, False]
+ )
# Copy, move, reflink, link, and hardlink are mutually exclusive.
- if iconfig['move']:
- iconfig['copy'] = False
- iconfig['link'] = False
- iconfig['hardlink'] = False
- iconfig['reflink'] = False
- elif iconfig['link']:
- iconfig['copy'] = False
- iconfig['move'] = False
- iconfig['hardlink'] = False
- iconfig['reflink'] = False
- elif iconfig['hardlink']:
- iconfig['copy'] = False
- iconfig['move'] = False
- iconfig['link'] = False
- iconfig['reflink'] = False
- elif iconfig['reflink']:
- iconfig['copy'] = False
- iconfig['move'] = False
- iconfig['link'] = False
- iconfig['hardlink'] = False
+ if iconfig["move"]:
+ iconfig["copy"] = False
+ iconfig["link"] = False
+ iconfig["hardlink"] = False
+ iconfig["reflink"] = False
+ elif iconfig["link"]:
+ iconfig["copy"] = False
+ iconfig["move"] = False
+ iconfig["hardlink"] = False
+ iconfig["reflink"] = False
+ elif iconfig["hardlink"]:
+ iconfig["copy"] = False
+ iconfig["move"] = False
+ iconfig["link"] = False
+ iconfig["reflink"] = False
+ elif iconfig["reflink"]:
+ iconfig["copy"] = False
+ iconfig["move"] = False
+ iconfig["link"] = False
+ iconfig["hardlink"] = False
# Only delete when copying.
- if not iconfig['copy']:
- iconfig['delete'] = False
+ if not iconfig["copy"]:
+ iconfig["delete"] = False
- self.want_resume = config['resume'].as_choice([True, False, 'ask'])
+ self.want_resume = config["resume"].as_choice([True, False, "ask"])
def tag_log(self, status, paths):
"""Log a message about a given album to the importer log. The status
should reflect the reason the album couldn't be tagged.
"""
- self.logger.info('{0} {1}', status, displayable_path(paths))
+ self.logger.info("{0} {1}", status, displayable_path(paths))
def log_choice(self, task, duplicate=False):
"""Logs the task's current choice if it should be logged. If
@@ -268,17 +290,17 @@ class ImportSession:
if duplicate:
# Duplicate: log all three choices (skip, keep both, and trump).
if task.should_remove_duplicates:
- self.tag_log('duplicate-replace', paths)
+ self.tag_log("duplicate-replace", paths)
elif task.choice_flag in (action.ASIS, action.APPLY):
- self.tag_log('duplicate-keep', paths)
+ self.tag_log("duplicate-keep", paths)
elif task.choice_flag is (action.SKIP):
- self.tag_log('duplicate-skip', paths)
+ self.tag_log("duplicate-skip", paths)
else:
# Non-duplicate: log "skip" and "asis" choices.
if task.choice_flag is action.ASIS:
- self.tag_log('asis', paths)
+ self.tag_log("asis", paths)
elif task.choice_flag is action.SKIP:
- self.tag_log('skip', paths)
+ self.tag_log("skip", paths)
def should_resume(self, path):
raise NotImplementedError
@@ -293,10 +315,9 @@ class ImportSession:
raise NotImplementedError
def run(self):
- """Run the import task.
- """
- self.logger.info('import started {0}', time.asctime())
- self.set_config(config['import'])
+ """Run the import task."""
+ self.logger.info("import started {0}", time.asctime())
+ self.set_config(config["import"])
# Set up the pipeline.
if self.query is None:
@@ -305,11 +326,10 @@ class ImportSession:
stages = [query_tasks(self)]
# In pretend mode, just log what would otherwise be imported.
- if self.config['pretend']:
+ if self.config["pretend"]:
stages += [log_files(self)]
else:
- if self.config['group_albums'] and \
- not self.config['singletons']:
+ if self.config["group_albums"] and not self.config["singletons"]:
# Split directory tasks into one task for each album.
stages += [group_albums(self)]
@@ -318,7 +338,7 @@ class ImportSession:
# import everything as-is. In *both* cases, these stages
# also add the music to the library database, so later
# stages need to read and write data from there.
- if self.config['autotag']:
+ if self.config["autotag"]:
stages += [lookup_candidates(self), user_query(self)]
else:
stages += [import_asis(self)]
@@ -334,9 +354,9 @@ class ImportSession:
pl = pipeline.Pipeline(stages)
# Run the pipeline.
- plugins.send('import_begin', session=self)
+ plugins.send("import_begin", session=self)
try:
- if config['threaded']:
+ if config["threaded"]:
pl.run_parallel(QUEUE_SIZE)
else:
pl.run_sequential()
@@ -350,18 +370,18 @@ class ImportSession:
"""Returns true if the files belonging to this task have already
been imported in a previous session.
"""
- if self.is_resuming(toppath) \
- and all([progress_element(toppath, p) for p in paths]):
+ if self.is_resuming(toppath) and all(
+ [progress_element(toppath, p) for p in paths]
+ ):
return True
- if self.config['incremental'] \
- and tuple(paths) in self.history_dirs:
+ if self.config["incremental"] and tuple(paths) in self.history_dirs:
return True
return False
@property
def history_dirs(self):
- if not hasattr(self, '_history_dirs'):
+ if not hasattr(self, "_history_dirs"):
self._history_dirs = history_get()
return self._history_dirs
@@ -370,17 +390,17 @@ class ImportSession:
during previous tasks.
"""
for path in paths:
- if path not in self._merged_items \
- and path not in self._merged_dirs:
+ if path not in self._merged_items and path not in self._merged_dirs:
return False
return True
def mark_merged(self, paths):
- """Mark paths and directories as merged for future reimport tasks.
- """
+ """Mark paths and directories as merged for future reimport tasks."""
self._merged_items.update(paths)
- dirs = {os.path.dirname(path) if os.path.isfile(path) else path
- for path in paths}
+ dirs = {
+ os.path.dirname(path) if os.path.isfile(syspath(path)) else path
+ for path in paths
+ }
self._merged_dirs.update(dirs)
def is_resuming(self, toppath):
@@ -392,16 +412,17 @@ class ImportSession:
def ask_resume(self, toppath):
"""If import of `toppath` was aborted in an earlier session, ask
- user if she wants to resume the import.
+ user if they want to resume the import.
Determines the return value of `is_resuming(toppath)`.
"""
if self.want_resume and has_progress(toppath):
# Either accept immediately or prompt for input to decide.
- if self.want_resume is True or \
- self.should_resume(toppath):
- log.warning('Resuming interrupted import of {0}',
- util.displayable_path(toppath))
+ if self.want_resume is True or self.should_resume(toppath):
+ log.warning(
+ "Resuming interrupted import of {0}",
+ util.displayable_path(toppath),
+ )
self._is_resuming[toppath] = True
else:
# Clear progress; we're starting from the top.
@@ -410,11 +431,12 @@ class ImportSession:
# The importer task class.
+
class BaseImportTask:
"""An abstract base class for importer tasks.
Tasks flow through the importer pipeline. Each stage can update
- them. """
+ them."""
def __init__(self, toppath, paths, items):
"""Create a task. The primary fields that define a task are:
@@ -488,8 +510,13 @@ class ImportTask(BaseImportTask):
"""
# Not part of the task structure:
assert choice != action.APPLY # Only used internally.
- if choice in (action.SKIP, action.ASIS, action.TRACKS, action.ALBUMS,
- action.RETAG):
+ if choice in (
+ action.SKIP,
+ action.ASIS,
+ action.TRACKS,
+ action.ALBUMS,
+ action.RETAG,
+ ):
self.choice_flag = choice
self.match = None
else:
@@ -504,8 +531,7 @@ class ImportTask(BaseImportTask):
progress_add(self.toppath, *self.paths)
def save_history(self):
- """Save the directory in the history for incremental imports.
- """
+ """Save the directory in the history for incremental imports."""
if self.paths:
history_add(self.paths)
@@ -521,17 +547,18 @@ class ImportTask(BaseImportTask):
# Convenient data.
- def chosen_ident(self):
- """Returns identifying metadata about the current choice. For
- albums, this is an (artist, album) pair. For items, this is
- (artist, title). May only be called when the choice flag is ASIS
- or RETAG (in which case the data comes from the files' current
- metadata) or APPLY (data comes from the choice).
+ def chosen_info(self):
+ """Return a dictionary of metadata about the current choice.
+ May only be called when the choice flag is ASIS or RETAG
+ (in which case the data comes from the files' current metadata)
+ or APPLY (in which case the data comes from the choice).
"""
if self.choice_flag in (action.ASIS, action.RETAG):
- return (self.cur_artist, self.cur_album)
+ likelies, consensus = autotag.current_metadata(self.items)
+ return likelies
elif self.choice_flag is action.APPLY:
- return (self.match.info.artist, self.match.info.album)
+ return self.match.info.copy()
+ assert False
def imported_items(self):
"""Return a list of Items that should be added to the library.
@@ -547,9 +574,8 @@ class ImportTask(BaseImportTask):
assert False
def apply_metadata(self):
- """Copy metadata from match info to the items.
- """
- if config['import']['from_scratch']:
+ """Copy metadata from match info to the items."""
+ if config["import"]["from_scratch"]:
for item in self.match.mapping:
item.clear()
@@ -563,57 +589,60 @@ class ImportTask(BaseImportTask):
def remove_duplicates(self, lib):
duplicate_items = self.duplicate_items(lib)
- log.debug('removing {0} old duplicated items', len(duplicate_items))
+ log.debug("removing {0} old duplicated items", len(duplicate_items))
for item in duplicate_items:
item.remove()
if lib.directory in util.ancestry(item.path):
- log.debug('deleting duplicate {0}',
- util.displayable_path(item.path))
+ log.debug(
+ "deleting duplicate {0}", util.displayable_path(item.path)
+ )
util.remove(item.path)
- util.prune_dirs(os.path.dirname(item.path),
- lib.directory)
+ util.prune_dirs(os.path.dirname(item.path), lib.directory)
def set_fields(self, lib):
"""Sets the fields given at CLI or configuration to the specified
values, for both the album and all its items.
"""
items = self.imported_items()
- for field, view in config['import']['set_fields'].items():
+ for field, view in config["import"]["set_fields"].items():
value = view.get()
- log.debug('Set field {1}={2} for {0}',
- displayable_path(self.paths),
- field,
- value)
- self.album[field] = value
+ log.debug(
+ "Set field {1}={2} for {0}",
+ displayable_path(self.paths),
+ field,
+ value,
+ )
+ self.album.set_parse(field, format(self.album, value))
for item in items:
- item[field] = value
+ item.set_parse(field, format(item, value))
with lib.transaction():
for item in items:
item.store()
self.album.store()
def finalize(self, session):
- """Save progress, clean up files, and emit plugin event.
- """
+ """Save progress, clean up files, and emit plugin event."""
# Update progress.
if session.want_resume:
self.save_progress()
- if session.config['incremental'] and not (
+ if session.config["incremental"] and not (
# Should we skip recording to incremental list?
- self.skip and session.config['incremental_skip_later']
+ self.skip
+ and session.config["incremental_skip_later"]
):
self.save_history()
- self.cleanup(copy=session.config['copy'],
- delete=session.config['delete'],
- move=session.config['move'])
+ self.cleanup(
+ copy=session.config["copy"],
+ delete=session.config["delete"],
+ move=session.config["move"],
+ )
if not self.skip:
self._emit_imported(session.lib)
def cleanup(self, copy=False, delete=False, move=False):
- """Remove and prune imported paths.
- """
+ """Remove and prune imported paths."""
# Do not delete any files or prune directories when skipping.
if self.skip:
return
@@ -635,7 +664,7 @@ class ImportTask(BaseImportTask):
self.prune(old_path)
def _emit_imported(self, lib):
- plugins.send('album_imported', lib=lib, album=self.album)
+ plugins.send("album_imported", lib=lib, album=self.album)
def handle_created(self, session):
"""Send the `import_task_created` event for this task. Return a list of
@@ -643,7 +672,7 @@ class ImportTask(BaseImportTask):
list containing only the task itself, but plugins can replace the task
with new ones.
"""
- tasks = plugins.send('import_task_created', session=session, task=self)
+ tasks = plugins.send("import_task_created", session=session, task=self)
if not tasks:
tasks = [self]
else:
@@ -656,8 +685,9 @@ class ImportTask(BaseImportTask):
candidate IDs are stored in self.search_ids: if present, the
initial lookup is restricted to only those IDs.
"""
- artist, album, prop = \
- autotag.tag_album(self.items, search_ids=self.search_ids)
+ artist, album, prop = autotag.tag_album(
+ self.items, search_ids=self.search_ids
+ )
self.cur_artist = artist
self.cur_album = album
self.candidates = prop.candidates
@@ -667,26 +697,33 @@ class ImportTask(BaseImportTask):
"""Return a list of albums from `lib` with the same artist and
album name as the task.
"""
- artist, album = self.chosen_ident()
+ info = self.chosen_info()
+ info["albumartist"] = info["artist"]
- if artist is None:
+ if info["artist"] is None:
# As-is import with no artist. Skip check.
return []
- duplicates = []
- task_paths = {i.path for i in self.items if i}
- duplicate_query = dbcore.AndQuery((
- dbcore.MatchQuery('albumartist', artist),
- dbcore.MatchQuery('album', album),
- ))
+ # Construct a query to find duplicates with this metadata. We
+ # use a temporary Album object to generate any computed fields.
+ tmp_album = library.Album(lib, **info)
+ keys = config["import"]["duplicate_keys"]["album"].as_str_seq()
+ dup_query = library.Album.all_fields_query(
+ {key: tmp_album.get(key) for key in keys}
+ )
- for album in lib.albums(duplicate_query):
+ # Don't count albums with the same files as duplicates.
+ task_paths = {i.path for i in self.items if i}
+
+ duplicates = []
+ for album in lib.albums(dup_query):
# Check whether the album paths are all present in the task
# i.e. album is being completely re-imported by the task,
# in which case it is not a duplicate (will be replaced).
album_paths = {i.path for i in album.items()}
if not (album_paths <= task_paths):
duplicates.append(album)
+
return duplicates
def align_album_level_fields(self):
@@ -702,32 +739,37 @@ class ImportTask(BaseImportTask):
plur_albumartist, freq = util.plurality(
[i.albumartist or i.artist for i in self.items]
)
- if freq == len(self.items) or \
- (freq > 1 and
- float(freq) / len(self.items) >= SINGLE_ARTIST_THRESH):
+ if freq == len(self.items) or (
+ freq > 1
+ and float(freq) / len(self.items) >= SINGLE_ARTIST_THRESH
+ ):
# Single-artist album.
- changes['albumartist'] = plur_albumartist
- changes['comp'] = False
+ changes["albumartist"] = plur_albumartist
+ changes["comp"] = False
else:
# VA.
- changes['albumartist'] = config['va_name'].as_str()
- changes['comp'] = True
+ changes["albumartist"] = config["va_name"].as_str()
+ changes["comp"] = True
elif self.choice_flag in (action.APPLY, action.RETAG):
# Applying autotagged metadata. Just get AA from the first
# item.
if not self.items[0].albumartist:
- changes['albumartist'] = self.items[0].artist
+ changes["albumartist"] = self.items[0].artist
+ if not self.items[0].albumartists:
+ changes["albumartists"] = self.items[0].artists
if not self.items[0].mb_albumartistid:
- changes['mb_albumartistid'] = self.items[0].mb_artistid
+ changes["mb_albumartistid"] = self.items[0].mb_artistid
+ if not self.items[0].mb_albumartistids:
+ changes["mb_albumartistids"] = self.items[0].mb_artistids
# Apply new metadata.
for item in self.items:
item.update(changes)
def manipulate_files(self, operation=None, write=False, session=None):
- """ Copy, move, link, hardlink or reflink (depending on `operation`) the files
- as well as write metadata.
+ """Copy, move, link, hardlink or reflink (depending on `operation`)
+ the files as well as write metadata.
`operation` should be an instance of `util.MoveOperation`.
@@ -744,9 +786,11 @@ class ImportTask(BaseImportTask):
# move in-library files. (Out-of-library files are
# copied/moved as usual).
old_path = item.path
- if (operation != MoveOperation.MOVE
- and self.replaced_items[item]
- and session.lib.directory in util.ancestry(old_path)):
+ if (
+ operation != MoveOperation.MOVE
+ and self.replaced_items[item]
+ and session.lib.directory in util.ancestry(old_path)
+ ):
item.move()
# We moved the item, so remove the
# now-nonexistent file from old_paths.
@@ -763,17 +807,16 @@ class ImportTask(BaseImportTask):
for item in self.imported_items():
item.store()
- plugins.send('import_task_files', session=session, task=self)
+ plugins.send("import_task_files", session=session, task=self)
def add(self, lib):
- """Add the items as an album to the library and remove replaced items.
- """
+ """Add the items as an album to the library and remove replaced items."""
self.align_album_level_fields()
with lib.transaction():
self.record_replaced(lib)
self.remove_replaced(lib)
self.album = lib.add_album(self.imported_items())
- if 'data_source' in self.imported_items()[0]:
+ if "data_source" in self.imported_items()[0]:
self.album.data_source = self.imported_items()[0].data_source
self.reimport_metadata(lib)
@@ -785,13 +828,15 @@ class ImportTask(BaseImportTask):
self.replaced_albums = defaultdict(list)
replaced_album_ids = set()
for item in self.imported_items():
- dup_items = list(lib.items(
- dbcore.query.BytesQuery('path', item.path)
- ))
+ dup_items = list(
+ lib.items(dbcore.query.BytesQuery("path", item.path))
+ )
self.replaced_items[item] = dup_items
for dup_item in dup_items:
- if (not dup_item.album_id or
- dup_item.album_id in replaced_album_ids):
+ if (
+ not dup_item.album_id
+ or dup_item.album_id in replaced_album_ids
+ ):
continue
replaced_album = dup_item._cached_album
if replaced_album:
@@ -802,20 +847,59 @@ class ImportTask(BaseImportTask):
"""For reimports, preserves metadata for reimported items and
albums.
"""
+
+ def _reduce_and_log(new_obj, existing_fields, overwrite_keys):
+ """Some flexible attributes should be overwritten (rather than
+ preserved) on reimports; Copies existing_fields, logs and removes
+ entries that should not be preserved and returns a dict containing
+ those fields left to actually be preserved.
+ """
+ noun = "album" if isinstance(new_obj, library.Album) else "item"
+ existing_fields = dict(existing_fields)
+ overwritten_fields = [
+ k
+ for k in existing_fields
+ if k in overwrite_keys
+ and new_obj.get(k)
+ and existing_fields.get(k) != new_obj.get(k)
+ ]
+ if overwritten_fields:
+ log.debug(
+ "Reimported {} {}. Not preserving flexible attributes {}. "
+ "Path: {}",
+ noun,
+ new_obj.id,
+ overwritten_fields,
+ displayable_path(new_obj.path),
+ )
+ for key in overwritten_fields:
+ del existing_fields[key]
+ return existing_fields
+
if self.is_album:
replaced_album = self.replaced_albums.get(self.album.path)
if replaced_album:
+ album_fields = _reduce_and_log(
+ self.album,
+ replaced_album._values_flex,
+ REIMPORT_FRESH_FIELDS_ALBUM,
+ )
self.album.added = replaced_album.added
- self.album.update(replaced_album._values_flex)
+ self.album.update(album_fields)
self.album.artpath = replaced_album.artpath
self.album.store()
log.debug(
- 'Reimported album: added {0}, flexible '
- 'attributes {1} from album {2} for {3}',
- self.album.added,
- replaced_album._values_flex.keys(),
- replaced_album.id,
- displayable_path(self.album.path)
+ "Reimported album {}. Preserving attribute ['added']. "
+ "Path: {}",
+ self.album.id,
+ displayable_path(self.album.path),
+ )
+ log.debug(
+ "Reimported album {}. Preserving flexible attributes {}. "
+ "Path: {}",
+ self.album.id,
+ list(album_fields.keys()),
+ displayable_path(self.album.path),
)
for item in self.imported_items():
@@ -824,19 +908,21 @@ class ImportTask(BaseImportTask):
if dup_item.added and dup_item.added != item.added:
item.added = dup_item.added
log.debug(
- 'Reimported item added {0} '
- 'from item {1} for {2}',
- item.added,
- dup_item.id,
- displayable_path(item.path)
+ "Reimported item {}. Preserving attribute ['added']. "
+ "Path: {}",
+ item.id,
+ displayable_path(item.path),
)
- item.update(dup_item._values_flex)
+ item_fields = _reduce_and_log(
+ item, dup_item._values_flex, REIMPORT_FRESH_FIELDS_ITEM
+ )
+ item.update(item_fields)
log.debug(
- 'Reimported item flexible attributes {0} '
- 'from item {1} for {2}',
- dup_item._values_flex.keys(),
- dup_item.id,
- displayable_path(item.path)
+ "Reimported item {}. Preserving flexible attributes {}. "
+ "Path: {}",
+ item.id,
+ list(item_fields.keys()),
+ displayable_path(item.path),
)
item.store()
@@ -846,23 +932,26 @@ class ImportTask(BaseImportTask):
"""
for item in self.imported_items():
for dup_item in self.replaced_items[item]:
- log.debug('Replacing item {0}: {1}',
- dup_item.id, displayable_path(item.path))
+ log.debug(
+ "Replacing item {0}: {1}",
+ dup_item.id,
+ displayable_path(item.path),
+ )
dup_item.remove()
- log.debug('{0} of {1} items replaced',
- sum(bool(l) for l in self.replaced_items.values()),
- len(self.imported_items()))
+ log.debug(
+ "{0} of {1} items replaced",
+ sum(bool(l) for l in self.replaced_items.values()),
+ len(self.imported_items()),
+ )
def choose_match(self, session):
- """Ask the session which match should apply and apply it.
- """
+ """Ask the session which match should apply and apply it."""
choice = session.choose_match(self)
self.set_choice(choice)
session.log_choice(self)
def reload(self):
- """Reload albums and items from the database.
- """
+ """Reload albums and items from the database."""
for item in self.imported_items():
item.load()
self.album.load()
@@ -876,15 +965,16 @@ class ImportTask(BaseImportTask):
the file still exists, no pruning is performed, so it's safe to
call when the file in question may not have been removed.
"""
- if self.toppath and not os.path.exists(filename):
- util.prune_dirs(os.path.dirname(filename),
- self.toppath,
- clutter=config['clutter'].as_str_seq())
+ if self.toppath and not os.path.exists(syspath(filename)):
+ util.prune_dirs(
+ os.path.dirname(filename),
+ self.toppath,
+ clutter=config["clutter"].as_str_seq(),
+ )
class SingletonImportTask(ImportTask):
- """ImportTask for a single track that is not associated to an album.
- """
+ """ImportTask for a single track that is not associated to an album."""
def __init__(self, toppath, item):
super().__init__(toppath, [item.path], [item])
@@ -892,12 +982,17 @@ class SingletonImportTask(ImportTask):
self.is_album = False
self.paths = [item.path]
- def chosen_ident(self):
- assert self.choice_flag in (action.ASIS, action.APPLY, action.RETAG)
+ def chosen_info(self):
+ """Return a dictionary of metadata about the current choice.
+ May only be called when the choice flag is ASIS or RETAG
+ (in which case the data comes from the files' current metadata)
+ or APPLY (in which case the data comes from the choice).
+ """
+ assert self.choice_flag in (action.ASIS, action.RETAG, action.APPLY)
if self.choice_flag in (action.ASIS, action.RETAG):
- return (self.item.artist, self.item.title)
+ return dict(self.item)
elif self.choice_flag is action.APPLY:
- return (self.match.info.artist, self.match.info.title)
+ return self.match.info.copy()
def imported_items(self):
return [self.item]
@@ -907,7 +1002,7 @@ class SingletonImportTask(ImportTask):
def _emit_imported(self, lib):
for item in self.imported_items():
- plugins.send('item_imported', lib=lib, item=item)
+ plugins.send("item_imported", lib=lib, item=item)
def lookup_candidates(self):
prop = autotag.tag_item(self.item, search_ids=self.search_ids)
@@ -918,14 +1013,18 @@ class SingletonImportTask(ImportTask):
"""Return a list of items from `lib` that have the same artist
and title as the task.
"""
- artist, title = self.chosen_ident()
+ info = self.chosen_info()
+
+ # Query for existing items using the same metadata. We use a
+ # temporary `Item` object to generate any computed fields.
+ tmp_item = library.Item(lib, **info)
+ keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
+ dup_query = library.Album.all_fields_query(
+ {key: tmp_item.get(key) for key in keys}
+ )
found_items = []
- query = dbcore.AndQuery((
- dbcore.MatchQuery('artist', artist),
- dbcore.MatchQuery('title', title),
- ))
- for other_item in lib.items(query):
+ for other_item in lib.items(dup_query):
# Existing items not considered duplicates.
if other_item.path != self.item.path:
found_items.append(other_item)
@@ -944,8 +1043,7 @@ class SingletonImportTask(ImportTask):
raise NotImplementedError
def choose_match(self, session):
- """Ask the session which match should apply and apply it.
- """
+ """Ask the session which match should apply and apply it."""
choice = session.choose_item(self)
self.set_choice(choice)
session.log_choice(self)
@@ -957,13 +1055,15 @@ class SingletonImportTask(ImportTask):
"""Sets the fields given at CLI or configuration to the specified
values, for the singleton item.
"""
- for field, view in config['import']['set_fields'].items():
+ for field, view in config["import"]["set_fields"].items():
value = view.get()
- log.debug('Set field {1}={2} for {0}',
- displayable_path(self.paths),
- field,
- value)
- self.item[field] = value
+ log.debug(
+ "Set field {1}={2} for {0}",
+ displayable_path(self.paths),
+ field,
+ value,
+ )
+ self.item.set_parse(field, format(self.item, value))
self.item.store()
@@ -1036,7 +1136,7 @@ class ArchiveImportTask(SentinelImportTask):
return False
for path_test, _ in cls.handlers():
- if path_test(util.py3_path(path)):
+ if path_test(os.fsdecode(path)):
return True
return False
@@ -1049,20 +1149,22 @@ class ArchiveImportTask(SentinelImportTask):
handled by `ArchiveClass`. `ArchiveClass` is a class that
implements the same interface as `tarfile.TarFile`.
"""
- if not hasattr(cls, '_handlers'):
+ if not hasattr(cls, "_handlers"):
cls._handlers = []
- from zipfile import is_zipfile, ZipFile
+ from zipfile import ZipFile, is_zipfile
+
cls._handlers.append((is_zipfile, ZipFile))
import tarfile
+
cls._handlers.append((tarfile.is_tarfile, tarfile.open))
try:
- from rarfile import is_rarfile, RarFile
+ from rarfile import RarFile, is_rarfile
except ImportError:
pass
else:
cls._handlers.append((is_rarfile, RarFile))
try:
- from py7zr import is_7zfile, SevenZipFile
+ from py7zr import SevenZipFile, is_7zfile
except ImportError:
pass
else:
@@ -1071,25 +1173,39 @@ class ArchiveImportTask(SentinelImportTask):
return cls._handlers
def cleanup(self, **kwargs):
- """Removes the temporary directory the archive was extracted to.
- """
+ """Removes the temporary directory the archive was extracted to."""
if self.extracted:
- log.debug('Removing extracted directory: {0}',
- displayable_path(self.toppath))
- shutil.rmtree(self.toppath)
+ log.debug(
+ "Removing extracted directory: {0}",
+ displayable_path(self.toppath),
+ )
+ shutil.rmtree(syspath(self.toppath))
def extract(self):
"""Extracts the archive to a temporary directory and sets
`toppath` to that directory.
"""
for path_test, handler_class in self.handlers():
- if path_test(util.py3_path(self.toppath)):
+ if path_test(os.fsdecode(self.toppath)):
break
extract_to = mkdtemp()
- archive = handler_class(util.py3_path(self.toppath), mode='r')
+ archive = handler_class(os.fsdecode(self.toppath), mode="r")
try:
archive.extractall(extract_to)
+
+ # Adjust the files' mtimes to match the information from the
+ # archive. Inspired by: https://stackoverflow.com/q/9813243
+ for f in archive.infolist():
+ # The date_time will need to adjusted otherwise
+ # the item will have the current date_time of extraction.
+ # The (0, 0, -1) is added to date_time because the
+ # function time.mktime expects a 9-element tuple.
+ # The -1 indicates that the DST flag is unknown.
+ date_time = time.mktime(f.date_time + (0, 0, -1))
+ fullpath = os.path.join(extract_to, f.filename)
+ os.utime(fullpath, (date_time, date_time))
+
finally:
archive.close()
self.extracted = True
@@ -1135,7 +1251,7 @@ class ImportTaskFactory:
# Search for music in the directory.
for dirs, paths in self.paths():
- if self.session.config['singletons']:
+ if self.session.config["singletons"]:
for path in paths:
tasks = self._create(self.singleton(path))
yield from tasks
@@ -1178,7 +1294,7 @@ class ImportTaskFactory:
"""
if not os.path.isdir(syspath(self.toppath)):
yield [self.toppath], [self.toppath]
- elif self.session.config['flat']:
+ elif self.session.config["flat"]:
paths = []
for dirs, paths_in_dir in albums_in_dir(self.toppath):
paths += paths_in_dir
@@ -1188,11 +1304,11 @@ class ImportTaskFactory:
yield dirs, paths
def singleton(self, path):
- """Return a `SingletonImportTask` for the music file.
- """
+ """Return a `SingletonImportTask` for the music file."""
if self.session.already_imported(self.toppath, [path]):
- log.debug('Skipping previously-imported path: {0}',
- displayable_path(path))
+ log.debug(
+ "Skipping previously-imported path: {0}", displayable_path(path)
+ )
self.skipped += 1
return None
@@ -1215,8 +1331,9 @@ class ImportTaskFactory:
dirs = list({os.path.dirname(p) for p in paths})
if self.session.already_imported(self.toppath, dirs):
- log.debug('Skipping previously-imported path: {0}',
- displayable_path(dirs))
+ log.debug(
+ "Skipping previously-imported path: {0}", displayable_path(dirs)
+ )
self.skipped += 1
return None
@@ -1243,24 +1360,24 @@ class ImportTaskFactory:
"""
assert self.is_archive
- if not (self.session.config['move'] or
- self.session.config['copy']):
- log.warning("Archive importing requires either "
- "'copy' or 'move' to be enabled.")
+ if not (self.session.config["move"] or self.session.config["copy"]):
+ log.warning(
+ "Archive importing requires either "
+ "'copy' or 'move' to be enabled."
+ )
return
- log.debug('Extracting archive: {0}',
- displayable_path(self.toppath))
+ log.debug("Extracting archive: {0}", displayable_path(self.toppath))
archive_task = ArchiveImportTask(self.toppath)
try:
archive_task.extract()
except Exception as exc:
- log.error('extraction failed: {0}', exc)
+ log.error("extraction failed: {0}", exc)
return
# Now read albums from the extracted directory.
self.toppath = archive_task.toppath
- log.debug('Archive extracted to: {0}', self.toppath)
+ log.debug("Archive extracted to: {0}", self.toppath)
return archive_task
def read_item(self, path):
@@ -1276,14 +1393,14 @@ class ImportTaskFactory:
# Silently ignore non-music files.
pass
elif isinstance(exc.reason, mediafile.UnreadableFileError):
- log.warning('unreadable file: {0}', displayable_path(path))
+ log.warning("unreadable file: {0}", displayable_path(path))
else:
- log.error('error reading {0}: {1}',
- displayable_path(path), exc)
+ log.error("error reading {0}: {1}", displayable_path(path), exc)
# Pipeline utilities
+
def _freshen_items(items):
# Clear IDs from re-tagged items so they appear "fresh" when
# we add them back to the library.
@@ -1294,7 +1411,7 @@ def _freshen_items(items):
def _extend_pipeline(tasks, *stages):
# Return pipeline extension for stages with list of tasks
- if type(tasks) == list:
+ if isinstance(tasks, list):
task_iter = iter(tasks)
else:
task_iter = tasks
@@ -1305,6 +1422,7 @@ def _extend_pipeline(tasks, *stages):
# Full-album pipeline stages.
+
def read_tasks(session):
"""A generator yielding all the albums (as ImportTask objects) found
in the user-specified list of paths. In the case of a singleton
@@ -1321,12 +1439,11 @@ def read_tasks(session):
skipped += task_factory.skipped
if not task_factory.imported:
- log.warning('No files imported from {0}',
- displayable_path(toppath))
+ log.warning("No files imported from {0}", displayable_path(toppath))
# Show skipped directories (due to incremental/resume).
if skipped:
- log.info('Skipped {0} paths.', skipped)
+ log.info("Skipped {0} paths.", skipped)
def query_tasks(session):
@@ -1334,7 +1451,7 @@ def query_tasks(session):
Instead of finding files from the filesystem, a query is used to
match items from the library.
"""
- if session.config['singletons']:
+ if session.config["singletons"]:
# Search for items.
for item in session.lib.items(session.query):
task = SingletonImportTask(None, item)
@@ -1344,8 +1461,12 @@ def query_tasks(session):
else:
# Search for albums.
for album in session.lib.albums(session.query):
- log.debug('yielding album {0}: {1} - {2}',
- album.id, album.albumartist, album.album)
+ log.debug(
+ "yielding album {0}: {1} - {2}",
+ album.id,
+ album.albumartist,
+ album.album,
+ )
items = list(album.items())
_freshen_items(items)
@@ -1366,12 +1487,12 @@ def lookup_candidates(session, task):
# abstraction.
return
- plugins.send('import_task_start', session=session, task=task)
- log.debug('Looking up: {0}', displayable_path(task.paths))
+ plugins.send("import_task_start", session=session, task=task)
+ log.debug("Looking up: {0}", displayable_path(task.paths))
# Restrict the initial lookup to IDs specified by the user via the -m
# option. Currently all the IDs are passed onto the tasks directly.
- task.search_ids = session.config['search_ids'].as_str_seq()
+ task.search_ids = session.config["search_ids"].as_str_seq()
task.lookup_candidates()
@@ -1387,7 +1508,7 @@ def user_query(session, task):
and the processed task is yielded.
It emits the ``import_task_choice`` event for plugins. Plugins have
- acces to the choice via the ``taks.choice_flag`` property and may
+ access to the choice via the ``task.choice_flag`` property and may
choose to change it.
"""
if task.skip:
@@ -1398,7 +1519,7 @@ def user_query(session, task):
# Ask the user for a choice.
task.choose_match(session)
- plugins.send('import_task_choice', session=session, task=task)
+ plugins.send("import_task_choice", session=session, task=task)
# As-tracks: transition to singleton workflow.
if task.choice_flag is action.TRACKS:
@@ -1409,16 +1530,18 @@ def user_query(session, task):
yield from task.handle_created(session)
yield SentinelImportTask(task.toppath, task.paths)
- return _extend_pipeline(emitter(task),
- lookup_candidates(session),
- user_query(session))
+ return _extend_pipeline(
+ emitter(task), lookup_candidates(session), user_query(session)
+ )
# As albums: group items by albums and create task for each album
if task.choice_flag is action.ALBUMS:
- return _extend_pipeline([task],
- group_albums(session),
- lookup_candidates(session),
- user_query(session))
+ return _extend_pipeline(
+ [task],
+ group_albums(session),
+ lookup_candidates(session),
+ user_query(session),
+ )
resolve_duplicates(session, task)
@@ -1434,12 +1557,13 @@ def user_query(session, task):
# Record merged paths in the session so they are not reimported
session.mark_merged(duplicate_paths)
- merged_task = ImportTask(None, task.paths + duplicate_paths,
- task.items + duplicate_items)
+ merged_task = ImportTask(
+ None, task.paths + duplicate_paths, task.items + duplicate_items
+ )
- return _extend_pipeline([merged_task],
- lookup_candidates(session),
- user_query(session))
+ return _extend_pipeline(
+ [merged_task], lookup_candidates(session), user_query(session)
+ )
apply_choice(session, task)
return task
@@ -1452,30 +1576,32 @@ def resolve_duplicates(session, task):
if task.choice_flag in (action.ASIS, action.APPLY, action.RETAG):
found_duplicates = task.find_duplicates(session.lib)
if found_duplicates:
- log.debug('found duplicates: {}'.format(
- [o.id for o in found_duplicates]
- ))
+ log.debug(
+ "found duplicates: {}".format([o.id for o in found_duplicates])
+ )
# Get the default action to follow from config.
- duplicate_action = config['import']['duplicate_action'].as_choice({
- 'skip': 's',
- 'keep': 'k',
- 'remove': 'r',
- 'merge': 'm',
- 'ask': 'a',
- })
- log.debug('default action for duplicates: {0}', duplicate_action)
+ duplicate_action = config["import"]["duplicate_action"].as_choice(
+ {
+ "skip": "s",
+ "keep": "k",
+ "remove": "r",
+ "merge": "m",
+ "ask": "a",
+ }
+ )
+ log.debug("default action for duplicates: {0}", duplicate_action)
- if duplicate_action == 's':
+ if duplicate_action == "s":
# Skip new.
task.set_choice(action.SKIP)
- elif duplicate_action == 'k':
+ elif duplicate_action == "k":
# Keep both. Do nothing; leave the choice intact.
pass
- elif duplicate_action == 'r':
+ elif duplicate_action == "r":
# Remove old.
task.should_remove_duplicates = True
- elif duplicate_action == 'm':
+ elif duplicate_action == "m":
# Merge duplicates together
task.should_merge_duplicates = True
else:
@@ -1495,7 +1621,7 @@ def import_asis(session, task):
if task.skip:
return
- log.info('{}', displayable_path(task.paths))
+ log.info("{}", displayable_path(task.paths))
task.set_choice(action.ASIS)
apply_choice(session, task)
@@ -1510,7 +1636,7 @@ def apply_choice(session, task):
# Change metadata.
if task.apply:
task.apply_metadata()
- plugins.send('import_task_apply', session=session, task=task)
+ plugins.send("import_task_apply", session=session, task=task)
task.add(session.lib)
@@ -1519,7 +1645,7 @@ def apply_choice(session, task):
# NOTE: This cannot be done before the ``task.add()`` call above,
# because then the ``ImportTask`` won't have an `album` for which
# it can set the fields.
- if config['import']['set_fields']:
+ if config["import"]["set_fields"]:
task.set_fields(session.lib)
@@ -1550,22 +1676,24 @@ def manipulate_files(session, task):
if task.should_remove_duplicates:
task.remove_duplicates(session.lib)
- if session.config['move']:
+ if session.config["move"]:
operation = MoveOperation.MOVE
- elif session.config['copy']:
+ elif session.config["copy"]:
operation = MoveOperation.COPY
- elif session.config['link']:
+ elif session.config["link"]:
operation = MoveOperation.LINK
- elif session.config['hardlink']:
+ elif session.config["hardlink"]:
operation = MoveOperation.HARDLINK
- elif session.config['reflink']:
+ elif session.config["reflink"] == "auto":
+ operation = MoveOperation.REFLINK_AUTO
+ elif session.config["reflink"]:
operation = MoveOperation.REFLINK
else:
operation = None
task.manipulate_files(
operation,
- write=session.config['write'],
+ write=session.config["write"],
session=session,
)
@@ -1575,14 +1703,13 @@ def manipulate_files(session, task):
@pipeline.stage
def log_files(session, task):
- """A coroutine (pipeline stage) to log each file to be imported.
- """
+ """A coroutine (pipeline stage) to log each file to be imported."""
if isinstance(task, SingletonImportTask):
- log.info('Singleton: {0}', displayable_path(task.item['path']))
+ log.info("Singleton: {0}", displayable_path(task.item["path"]))
elif task.items:
- log.info('Album: {0}', displayable_path(task.paths[0]))
+ log.info("Album: {0}", displayable_path(task.paths[0]))
for item in task.items:
- log.info(' {0}', displayable_path(item['path']))
+ log.info(" {0}", displayable_path(item["path"]))
def group_albums(session):
@@ -1592,6 +1719,7 @@ def group_albums(session):
Groups are identified using their artist and album fields. The
pipeline stage emits new album tasks for each discovered group.
"""
+
def group(item):
return (item.albumartist or item.artist, item.album)
@@ -1604,16 +1732,15 @@ def group_albums(session):
sorted_items = sorted(task.items, key=group)
for _, items in itertools.groupby(sorted_items, group):
items = list(items)
- task = ImportTask(task.toppath, [i.path for i in items],
- items)
+ task = ImportTask(task.toppath, [i.path for i in items], items)
tasks += task.handle_created(session)
tasks.append(SentinelImportTask(task.toppath, task.paths))
task = pipeline.multiple(tasks)
-MULTIDISC_MARKERS = (br'dis[ck]', br'cd')
-MULTIDISC_PAT_FMT = br'^(.*%s[\W_]*)\d'
+MULTIDISC_MARKERS = (rb"dis[ck]", rb"cd")
+MULTIDISC_PAT_FMT = rb"^(.*%s[\W_]*)\d"
def is_subdir_of_any_in_list(path, dirs):
@@ -1631,21 +1758,21 @@ def albums_in_dir(path):
containing any media files is an album.
"""
collapse_pat = collapse_paths = collapse_items = None
- ignore = config['ignore'].as_str_seq()
- ignore_hidden = config['ignore_hidden'].get(bool)
+ ignore = config["ignore"].as_str_seq()
+ ignore_hidden = config["ignore_hidden"].get(bool)
- for root, dirs, files in sorted_walk(path, ignore=ignore,
- ignore_hidden=ignore_hidden,
- logger=log):
+ for root, dirs, files in sorted_walk(
+ path, ignore=ignore, ignore_hidden=ignore_hidden, logger=log
+ ):
items = [os.path.join(root, f) for f in files]
# If we're currently collapsing the constituent directories in a
# multi-disc album, check whether we should continue collapsing
# and add the current directory. If so, just add the directory
# and move on to the next directory. If not, stop collapsing.
if collapse_paths:
- if (is_subdir_of_any_in_list(root, collapse_paths)) or \
- (collapse_pat and
- collapse_pat.match(os.path.basename(root))):
+ if (is_subdir_of_any_in_list(root, collapse_paths)) or (
+ collapse_pat and collapse_pat.match(os.path.basename(root))
+ ):
# Still collapsing.
collapse_paths.append(root)
collapse_items += items
@@ -1665,7 +1792,7 @@ def albums_in_dir(path):
start_collapsing = False
for marker in MULTIDISC_MARKERS:
# We're using replace on %s due to lack of .format() on bytestrings
- p = MULTIDISC_PAT_FMT.replace(b'%s', marker)
+ p = MULTIDISC_PAT_FMT.replace(b"%s", marker)
marker_pat = re.compile(p, re.I)
match = marker_pat.match(os.path.basename(root))
@@ -1683,8 +1810,7 @@ def albums_in_dir(path):
if match:
match_group = re.escape(match.group(1))
subdir_pat = re.compile(
- b''.join([b'^', match_group, br'\d']),
- re.I
+ b"".join([b"^", match_group, rb"\d"]), re.I
)
else:
start_collapsing = False
@@ -1706,8 +1832,7 @@ def albums_in_dir(path):
# Set the current pattern to match directories with the same
# prefix as this one, followed by a digit.
collapse_pat = re.compile(
- b''.join([b'^', re.escape(match.group(1)), br'\d']),
- re.I
+ b"".join([b"^", re.escape(match.group(1)), rb"\d"]), re.I
)
break
diff --git a/lib/beets/library.py b/lib/beets/library.py
index 888836cd..6d0ee613 100644
--- a/lib/beets/library.py
+++ b/lib/beets/library.py
@@ -14,37 +14,62 @@
"""The core data store and collection logic for beets.
"""
+from __future__ import annotations
import os
-import sys
-import unicodedata
-import time
import re
-import string
import shlex
+import string
+import sys
+import time
+import unicodedata
+from functools import cached_property
-from beets import logging
from mediafile import MediaFile, UnreadableFileError
-from beets import plugins
-from beets import util
-from beets.util import bytestring_path, syspath, normpath, samefile, \
- MoveOperation, lazy_property
-from beets.util.functemplate import template, Template
-from beets import dbcore
-from beets.dbcore import types
+
import beets
+from beets import dbcore, logging, plugins, util
+from beets.dbcore import Results, types
+from beets.util import (
+ MoveOperation,
+ bytestring_path,
+ cached_classproperty,
+ normpath,
+ samefile,
+ syspath,
+)
+from beets.util.functemplate import Template, template
# To use the SQLite "blob" type, it doesn't suffice to provide a byte
# string; SQLite treats that as encoded text. Wrapping it in a
# `memoryview` tells it that we actually mean non-text data.
BLOB_TYPE = memoryview
-log = logging.getLogger('beets')
+log = logging.getLogger("beets")
# Library-specific query types.
-class PathQuery(dbcore.FieldQuery):
+
+class SingletonQuery(dbcore.FieldQuery[str]):
+ """This query is responsible for the 'singleton' lookup.
+
+ It is based on the FieldQuery and constructs a SQL clause
+ 'album_id is NULL' which yields the same result as the previous filter
+ in Python but is more performant since it's done in SQL.
+
+ Using util.str2bool ensures that lookups like singleton:true, singleton:1
+ and singleton:false, singleton:0 are handled consistently.
+ """
+
+ def __new__(cls, field: str, value: str, *args, **kwargs):
+ query = dbcore.query.NoneQuery("album_id")
+ if util.str2bool(value):
+ return query
+ return dbcore.query.NotQuery(query)
+
+
+class PathQuery(dbcore.FieldQuery[bytes]):
"""A query that matches all items under a given path.
Matching can either be case-insensitive or case-sensitive. By
@@ -52,30 +77,40 @@ class PathQuery(dbcore.FieldQuery):
and case-sensitive otherwise.
"""
+ # For tests
+ force_implicit_query_detection = False
+
def __init__(self, field, pattern, fast=True, case_sensitive=None):
- """Create a path query. `pattern` must be a path, either to a
- file or a directory.
+ """Create a path query.
+
+ `pattern` must be a path, either to a file or a directory.
`case_sensitive` can be a bool or `None`, indicating that the
behavior should depend on the filesystem.
"""
super().__init__(field, pattern, fast)
+ path = util.normpath(pattern)
+
# By default, the case sensitivity depends on the filesystem
# that the query path is located on.
if case_sensitive is None:
- path = util.bytestring_path(util.normpath(pattern))
- case_sensitive = beets.util.case_sensitive(path)
+ case_sensitive = util.case_sensitive(path)
self.case_sensitive = case_sensitive
# Use a normalized-case pattern for case-insensitive matches.
if not case_sensitive:
- pattern = pattern.lower()
+ # We need to lowercase the entire path, not just the pattern.
+ # In particular, on Windows, the drive letter is otherwise not
+ # lowercased.
+ # This also ensures that the `match()` method below and the SQL
+ # from `col_clause()` do the same thing.
+ path = path.lower()
# Match the path as a single file.
- self.file_path = util.bytestring_path(util.normpath(pattern))
+ self.file_path = path
# As a directory (prefix).
- self.dir_path = util.bytestring_path(os.path.join(self.file_path, b''))
+ self.dir_path = os.path.join(path, b"")
@classmethod
def is_path_query(cls, query_part):
@@ -83,17 +118,20 @@ class PathQuery(dbcore.FieldQuery):
Condition: separator precedes colon and the file exists.
"""
- colon = query_part.find(':')
+ colon = query_part.find(":")
if colon != -1:
query_part = query_part[:colon]
# Test both `sep` and `altsep` (i.e., both slash and backslash on
# Windows).
- return (
- (os.sep in query_part or
- (os.altsep and os.altsep in query_part)) and
- os.path.exists(syspath(normpath(query_part)))
- )
+ if not (
+ os.sep in query_part or (os.altsep and os.altsep in query_part)
+ ):
+ return False
+
+ if cls.force_implicit_query_detection:
+ return True
+ return os.path.exists(syspath(normpath(query_part)))
def match(self, item):
path = item.path if self.case_sensitive else item.path.lower()
@@ -104,32 +142,42 @@ class PathQuery(dbcore.FieldQuery):
dir_blob = BLOB_TYPE(self.dir_path)
if self.case_sensitive:
- query_part = '({0} = ?) || (substr({0}, 1, ?) = ?)'
+ query_part = "({0} = ?) || (substr({0}, 1, ?) = ?)"
else:
- query_part = '(BYTELOWER({0}) = BYTELOWER(?)) || \
- (substr(BYTELOWER({0}), 1, ?) = BYTELOWER(?))'
+ query_part = "(BYTELOWER({0}) = BYTELOWER(?)) || \
+ (substr(BYTELOWER({0}), 1, ?) = BYTELOWER(?))"
- return query_part.format(self.field), \
- (file_blob, len(dir_blob), dir_blob)
+ return query_part.format(self.field), (
+ file_blob,
+ len(dir_blob),
+ dir_blob,
+ )
+
+ def __repr__(self) -> str:
+ return (
+ f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, "
+ f"fast={self.fast}, case_sensitive={self.case_sensitive})"
+ )
# Library-specific field types.
+
class DateType(types.Float):
# TODO representation should be `datetime` object
# TODO distinguish between date and time types
query = dbcore.query.DateQuery
def format(self, value):
- return time.strftime(beets.config['time_format'].as_str(),
- time.localtime(value or 0))
+ return time.strftime(
+ beets.config["time_format"].as_str(), time.localtime(value or 0)
+ )
def parse(self, string):
try:
# Try a formatted date string.
return time.mktime(
- time.strptime(string,
- beets.config['time_format'].as_str())
+ time.strptime(string, beets.config["time_format"].as_str())
)
except ValueError:
# Fall back to a plain timestamp number.
@@ -139,18 +187,21 @@ class DateType(types.Float):
return self.null
-class PathType(types.Type):
- """A dbcore type for filesystem paths. These are represented as
- `bytes` objects, in keeping with the Unix filesystem abstraction.
+class PathType(types.Type[bytes, bytes]):
+ """A dbcore type for filesystem paths.
+
+ These are represented as `bytes` objects, in keeping with
+ the Unix filesystem abstraction.
"""
- sql = 'BLOB'
+ sql = "BLOB"
query = PathQuery
model_type = bytes
def __init__(self, nullable=False):
- """Create a path type object. `nullable` controls whether the
- type may be missing, i.e., None.
+ """Create a path type object.
+
+ `nullable` controls whether the type may be missing, i.e., None.
"""
self.nullable = nullable
@@ -159,7 +210,7 @@ class PathType(types.Type):
if self.nullable:
return None
else:
- return b''
+ return b""
def format(self, value):
return util.displayable_path(value)
@@ -193,12 +244,13 @@ class MusicalKey(types.String):
The standard format is C, Cm, C#, C#m, etc.
"""
+
ENHARMONIC = {
- r'db': 'c#',
- r'eb': 'd#',
- r'gb': 'f#',
- r'ab': 'g#',
- r'bb': 'a#',
+ r"db": "c#",
+ r"eb": "d#",
+ r"gb": "f#",
+ r"ab": "g#",
+ r"bb": "a#",
}
null = None
@@ -207,8 +259,8 @@ class MusicalKey(types.String):
key = key.lower()
for flat, sharp in self.ENHARMONIC.items():
key = re.sub(flat, sharp, key)
- key = re.sub(r'[\W\s]+minor', 'm', key)
- key = re.sub(r'[\W\s]+major', '', key)
+ key = re.sub(r"[\W\s]+minor", "m", key)
+ key = re.sub(r"[\W\s]+major", "", key)
return key.capitalize()
def normalize(self, key):
@@ -220,10 +272,11 @@ class MusicalKey(types.String):
class DurationType(types.Float):
"""Human-friendly (M:SS) representation of a time interval."""
+
query = dbcore.query.DurationQuery
def format(self, value):
- if not beets.config['format_raw_length'].get(bool):
+ if not beets.config["format_raw_length"].get(bool):
return beets.ui.human_seconds_short(value or 0.0)
else:
return value
@@ -242,6 +295,7 @@ class DurationType(types.Float):
# Library-specific sort types.
+
class SmartArtistSort(dbcore.query.Sort):
"""Sort by artist (either album artist or track artist),
prioritizing the sort field over the raw field.
@@ -254,35 +308,43 @@ class SmartArtistSort(dbcore.query.Sort):
def order_clause(self):
order = "ASC" if self.ascending else "DESC"
- field = 'albumartist' if self.album else 'artist'
- collate = 'COLLATE NOCASE' if self.case_insensitive else ''
- return ('(CASE {0}_sort WHEN NULL THEN {0} '
- 'WHEN "" THEN {0} '
- 'ELSE {0}_sort END) {1} {2}').format(field, collate, order)
+ field = "albumartist" if self.album else "artist"
+ collate = "COLLATE NOCASE" if self.case_insensitive else ""
+ return (
+ "(CASE {0}_sort WHEN NULL THEN {0} "
+ 'WHEN "" THEN {0} '
+ "ELSE {0}_sort END) {1} {2}"
+ ).format(field, collate, order)
def sort(self, objs):
if self.album:
+
def field(a):
return a.albumartist_sort or a.albumartist
+
else:
+
def field(i):
return i.artist_sort or i.artist
if self.case_insensitive:
+
def key(x):
return field(x).lower()
+
else:
key = field
return sorted(objs, key=key, reverse=not self.ascending)
# Special path format key.
-PF_KEY_DEFAULT = 'default'
+PF_KEY_DEFAULT = "default"
# Exceptions.
class FileOperationError(Exception):
- """Indicates an error when interacting with a file on disk.
+ """Indicate an error when interacting with a file on disk.
+
Possibilities include an unsupported media type, a permissions
error, and an unhandled Mutagen exception.
"""
@@ -295,45 +357,36 @@ class FileOperationError(Exception):
self.path = path
self.reason = reason
- def text(self):
- """Get a string representing the error. Describes both the
- underlying reason and the file path in question.
- """
- return '{}: {}'.format(
- util.displayable_path(self.path),
- str(self.reason)
- )
+ def __str__(self):
+ """Get a string representing the error.
- # define __str__ as text to avoid infinite loop on super() calls
- # with @six.python_2_unicode_compatible
- __str__ = text
+ Describe both the underlying reason and the file path in question.
+ """
+ return f"{util.displayable_path(self.path)}: {self.reason}"
class ReadError(FileOperationError):
- """An error while reading a file (i.e. in `Item.read`).
- """
+ """An error while reading a file (i.e. in `Item.read`)."""
def __str__(self):
- return 'error reading ' + super().text()
+ return "error reading " + str(super())
class WriteError(FileOperationError):
- """An error while writing a file (i.e. in `Item.write`).
- """
+ """An error while writing a file (i.e. in `Item.write`)."""
def __str__(self):
- return 'error writing ' + super().text()
+ return "error writing " + str(super())
# Item and Album model classes.
-class LibModel(dbcore.Model):
- """Shared concrete functionality for Items and Albums.
- """
- _format_config_key = None
- """Config key that specifies how an instance should be formatted.
- """
+class LibModel(dbcore.Model):
+ """Shared concrete functionality for Items and Albums."""
+
+ # Config key that specifies how an instance should be formatted.
+ _format_config_key: str
def _template_funcs(self):
funcs = DefaultTemplateFunctions(self, self._db).functions()
@@ -342,15 +395,15 @@ class LibModel(dbcore.Model):
def store(self, fields=None):
super().store(fields)
- plugins.send('database_change', lib=self._db, model=self)
+ plugins.send("database_change", lib=self._db, model=self)
def remove(self):
super().remove()
- plugins.send('database_change', lib=self._db, model=self)
+ plugins.send("database_change", lib=self._db, model=self)
def add(self, lib=None):
super().add(lib)
- plugins.send('database_change', lib=self._db, model=self)
+ plugins.send("database_change", lib=self._db, model=self)
def __format__(self, spec):
if not spec:
@@ -362,7 +415,7 @@ class LibModel(dbcore.Model):
return format(self)
def __bytes__(self):
- return self.__str__().encode('utf-8')
+ return self.__str__().encode("utf-8")
class FormattedItemMapping(dbcore.db.FormattedMapping):
@@ -371,13 +424,12 @@ class FormattedItemMapping(dbcore.db.FormattedMapping):
Album-level fields take precedence if `for_path` is true.
"""
- ALL_KEYS = '*'
+ ALL_KEYS = "*"
def __init__(self, item, included_keys=ALL_KEYS, for_path=False):
# We treat album and item keys specially here,
# so exclude transitive album keys from the model's keys.
- super().__init__(item, included_keys=[],
- for_path=for_path)
+ super().__init__(item, included_keys=[], for_path=for_path)
self.included_keys = included_keys
if included_keys == self.ALL_KEYS:
# Performance note: this triggers a database query.
@@ -386,19 +438,21 @@ class FormattedItemMapping(dbcore.db.FormattedMapping):
self.model_keys = included_keys
self.item = item
- @lazy_property
+ @cached_property
def all_keys(self):
return set(self.model_keys).union(self.album_keys)
- @lazy_property
+ @cached_property
def album_keys(self):
album_keys = []
if self.album:
if self.included_keys == self.ALL_KEYS:
# Performance note: this triggers a database query.
for key in self.album.keys(computed=True):
- if key in Album.item_keys \
- or key not in self.item._fields.keys():
+ if (
+ key in Album.item_keys
+ or key not in self.item._fields.keys()
+ ):
album_keys.append(key)
else:
album_keys = self.included_keys
@@ -410,6 +464,7 @@ class FormattedItemMapping(dbcore.db.FormattedMapping):
def _get(self, key):
"""Get the value for a key, either from the album or the item.
+
Raise a KeyError for invalid keys.
"""
if self.for_path and key in self.album_keys:
@@ -422,8 +477,10 @@ class FormattedItemMapping(dbcore.db.FormattedMapping):
raise KeyError(key)
def __getitem__(self, key):
- """Get the value for a key. `artist` and `albumartist`
- are fallback values for each other when not set.
+ """Get the value for a key.
+
+ `artist` and `albumartist` are fallback values for each other
+ when not set.
"""
value = self._get(key)
@@ -431,10 +488,10 @@ class FormattedItemMapping(dbcore.db.FormattedMapping):
# This is helpful in path formats when the album artist is unset
# on as-is imports.
try:
- if key == 'artist' and not value:
- return self._get('albumartist')
- elif key == 'albumartist' and not value:
- return self._get('artist')
+ if key == "artist" and not value:
+ return self._get("albumartist")
+ elif key == "albumartist" and not value:
+ return self._get("artist")
except KeyError:
pass
@@ -448,122 +505,158 @@ class FormattedItemMapping(dbcore.db.FormattedMapping):
class Item(LibModel):
- _table = 'items'
- _flex_table = 'item_attributes'
+ """Represent a song or track."""
+
+ _table = "items"
+ _flex_table = "item_attributes"
_fields = {
- 'id': types.PRIMARY_ID,
- 'path': PathType(),
- 'album_id': types.FOREIGN_ID,
-
- 'title': types.STRING,
- 'artist': types.STRING,
- 'artist_sort': types.STRING,
- 'artist_credit': types.STRING,
- 'album': types.STRING,
- 'albumartist': types.STRING,
- 'albumartist_sort': types.STRING,
- 'albumartist_credit': types.STRING,
- 'genre': types.STRING,
- 'style': types.STRING,
- 'discogs_albumid': types.INTEGER,
- 'discogs_artistid': types.INTEGER,
- 'discogs_labelid': types.INTEGER,
- 'lyricist': types.STRING,
- 'composer': types.STRING,
- 'composer_sort': types.STRING,
- 'work': types.STRING,
- 'mb_workid': types.STRING,
- 'work_disambig': types.STRING,
- 'arranger': types.STRING,
- 'grouping': types.STRING,
- 'year': types.PaddedInt(4),
- 'month': types.PaddedInt(2),
- 'day': types.PaddedInt(2),
- 'track': types.PaddedInt(2),
- 'tracktotal': types.PaddedInt(2),
- 'disc': types.PaddedInt(2),
- 'disctotal': types.PaddedInt(2),
- 'lyrics': types.STRING,
- 'comments': types.STRING,
- 'bpm': types.INTEGER,
- 'comp': types.BOOLEAN,
- 'mb_trackid': types.STRING,
- 'mb_albumid': types.STRING,
- 'mb_artistid': types.STRING,
- 'mb_albumartistid': types.STRING,
- 'mb_releasetrackid': types.STRING,
- 'trackdisambig': types.STRING,
- 'albumtype': types.STRING,
- 'albumtypes': types.STRING,
- 'label': types.STRING,
- 'acoustid_fingerprint': types.STRING,
- 'acoustid_id': types.STRING,
- 'mb_releasegroupid': types.STRING,
- 'asin': types.STRING,
- 'isrc': types.STRING,
- 'catalognum': types.STRING,
- 'script': types.STRING,
- 'language': types.STRING,
- 'country': types.STRING,
- 'albumstatus': types.STRING,
- 'media': types.STRING,
- 'albumdisambig': types.STRING,
- 'releasegroupdisambig': types.STRING,
- 'disctitle': types.STRING,
- 'encoder': types.STRING,
- 'rg_track_gain': types.NULL_FLOAT,
- 'rg_track_peak': types.NULL_FLOAT,
- 'rg_album_gain': types.NULL_FLOAT,
- 'rg_album_peak': types.NULL_FLOAT,
- 'r128_track_gain': types.NullPaddedInt(6),
- 'r128_album_gain': types.NullPaddedInt(6),
- 'original_year': types.PaddedInt(4),
- 'original_month': types.PaddedInt(2),
- 'original_day': types.PaddedInt(2),
- 'initial_key': MusicalKey(),
-
- 'length': DurationType(),
- 'bitrate': types.ScaledInt(1000, 'kbps'),
- 'format': types.STRING,
- 'samplerate': types.ScaledInt(1000, 'kHz'),
- 'bitdepth': types.INTEGER,
- 'channels': types.INTEGER,
- 'mtime': DateType(),
- 'added': DateType(),
+ "id": types.PRIMARY_ID,
+ "path": PathType(),
+ "album_id": types.FOREIGN_ID,
+ "title": types.STRING,
+ "artist": types.STRING,
+ "artists": types.MULTI_VALUE_DSV,
+ "artists_ids": types.MULTI_VALUE_DSV,
+ "artist_sort": types.STRING,
+ "artists_sort": types.MULTI_VALUE_DSV,
+ "artist_credit": types.STRING,
+ "artists_credit": types.MULTI_VALUE_DSV,
+ "remixer": types.STRING,
+ "album": types.STRING,
+ "albumartist": types.STRING,
+ "albumartists": types.MULTI_VALUE_DSV,
+ "albumartist_sort": types.STRING,
+ "albumartists_sort": types.MULTI_VALUE_DSV,
+ "albumartist_credit": types.STRING,
+ "albumartists_credit": types.MULTI_VALUE_DSV,
+ "genre": types.STRING,
+ "style": types.STRING,
+ "discogs_albumid": types.INTEGER,
+ "discogs_artistid": types.INTEGER,
+ "discogs_labelid": types.INTEGER,
+ "lyricist": types.STRING,
+ "composer": types.STRING,
+ "composer_sort": types.STRING,
+ "work": types.STRING,
+ "mb_workid": types.STRING,
+ "work_disambig": types.STRING,
+ "arranger": types.STRING,
+ "grouping": types.STRING,
+ "year": types.PaddedInt(4),
+ "month": types.PaddedInt(2),
+ "day": types.PaddedInt(2),
+ "track": types.PaddedInt(2),
+ "tracktotal": types.PaddedInt(2),
+ "disc": types.PaddedInt(2),
+ "disctotal": types.PaddedInt(2),
+ "lyrics": types.STRING,
+ "comments": types.STRING,
+ "bpm": types.INTEGER,
+ "comp": types.BOOLEAN,
+ "mb_trackid": types.STRING,
+ "mb_albumid": types.STRING,
+ "mb_artistid": types.STRING,
+ "mb_artistids": types.MULTI_VALUE_DSV,
+ "mb_albumartistid": types.STRING,
+ "mb_albumartistids": types.MULTI_VALUE_DSV,
+ "mb_releasetrackid": types.STRING,
+ "trackdisambig": types.STRING,
+ "albumtype": types.STRING,
+ "albumtypes": types.SEMICOLON_SPACE_DSV,
+ "label": types.STRING,
+ "barcode": types.STRING,
+ "acoustid_fingerprint": types.STRING,
+ "acoustid_id": types.STRING,
+ "mb_releasegroupid": types.STRING,
+ "release_group_title": types.STRING,
+ "asin": types.STRING,
+ "isrc": types.STRING,
+ "catalognum": types.STRING,
+ "script": types.STRING,
+ "language": types.STRING,
+ "country": types.STRING,
+ "albumstatus": types.STRING,
+ "media": types.STRING,
+ "albumdisambig": types.STRING,
+ "releasegroupdisambig": types.STRING,
+ "disctitle": types.STRING,
+ "encoder": types.STRING,
+ "rg_track_gain": types.NULL_FLOAT,
+ "rg_track_peak": types.NULL_FLOAT,
+ "rg_album_gain": types.NULL_FLOAT,
+ "rg_album_peak": types.NULL_FLOAT,
+ "r128_track_gain": types.NULL_FLOAT,
+ "r128_album_gain": types.NULL_FLOAT,
+ "original_year": types.PaddedInt(4),
+ "original_month": types.PaddedInt(2),
+ "original_day": types.PaddedInt(2),
+ "initial_key": MusicalKey(),
+ "length": DurationType(),
+ "bitrate": types.ScaledInt(1000, "kbps"),
+ "bitrate_mode": types.STRING,
+ "encoder_info": types.STRING,
+ "encoder_settings": types.STRING,
+ "format": types.STRING,
+ "samplerate": types.ScaledInt(1000, "kHz"),
+ "bitdepth": types.INTEGER,
+ "channels": types.INTEGER,
+ "mtime": DateType(),
+ "added": DateType(),
}
- _search_fields = ('artist', 'title', 'comments',
- 'album', 'albumartist', 'genre')
+ _search_fields = (
+ "artist",
+ "title",
+ "comments",
+ "album",
+ "albumartist",
+ "genre",
+ )
_types = {
- 'data_source': types.STRING,
+ "data_source": types.STRING,
}
- _media_fields = set(MediaFile.readable_fields()) \
- .intersection(_fields.keys())
- """Set of item fields that are backed by `MediaFile` fields.
-
- Any kind of field (fixed, flexible, and computed) may be a media
- field. Only these fields are read from disk in `read` and written in
- `write`.
- """
+ # Set of item fields that are backed by `MediaFile` fields.
+ # Any kind of field (fixed, flexible, and computed) may be a media
+ # field. Only these fields are read from disk in `read` and written in
+ # `write`.
+ _media_fields = set(MediaFile.readable_fields()).intersection(
+ _fields.keys()
+ )
+ # Set of item fields that are backed by *writable* `MediaFile` tag
+ # fields.
+ # This excludes fields that represent audio data, such as `bitrate` or
+ # `length`.
_media_tag_fields = set(MediaFile.fields()).intersection(_fields.keys())
- """Set of item fields that are backed by *writable* `MediaFile` tag
- fields.
-
- This excludes fields that represent audio data, such as `bitrate` or
- `length`.
- """
_formatter = FormattedItemMapping
- _sorts = {'artist': SmartArtistSort}
+ _sorts = {"artist": SmartArtistSort}
- _format_config_key = 'format_item'
+ _queries = {"singleton": SingletonQuery}
+ _format_config_key = "format_item"
+
+ # Cached album object. Read-only.
__album = None
- """Cached album object. Read-only."""
+
+ @cached_classproperty
+ def _relation(cls) -> type[Album]:
+ return Album
+
+ @cached_classproperty
+ def relation_join(cls) -> str:
+ """Return the FROM clause which includes related albums.
+
+ We need to use a LEFT JOIN here, otherwise items that are not part of
+ an album (e.g. singletons) would be left out.
+ """
+ return (
+ f"LEFT JOIN {cls._relation._table} "
+ f"ON {cls._table}.album_id = {cls._relation._table}.id"
+ )
@property
def _cached_album(self):
@@ -588,14 +681,13 @@ class Item(LibModel):
@classmethod
def _getters(cls):
getters = plugins.item_field_getters()
- getters['singleton'] = lambda i: i.album_id is None
- getters['filesize'] = Item.try_filesize # In bytes.
+ getters["singleton"] = lambda i: i.album_id is None
+ getters["filesize"] = Item.try_filesize # In bytes.
return getters
@classmethod
def from_path(cls, path):
- """Creates a new item from the media file at the specified path.
- """
+ """Create a new item from the media file at the specified path."""
# Initiate with values that aren't read from files.
i = cls(album_id=None)
i.read(path)
@@ -603,15 +695,14 @@ class Item(LibModel):
return i
def __setitem__(self, key, value):
- """Set the item's value for a standard field or a flexattr.
- """
+ """Set the item's value for a standard field or a flexattr."""
# Encode unicode paths and read buffers.
- if key == 'path':
+ if key == "path":
if isinstance(value, str):
value = bytestring_path(value)
elif isinstance(value, BLOB_TYPE):
value = bytes(value)
- elif key == 'album_id':
+ elif key == "album_id":
self._cached_album = None
changed = super()._setitem(key, value)
@@ -621,7 +712,9 @@ class Item(LibModel):
def __getitem__(self, key):
"""Get the value for a field, falling back to the album if
- necessary. Raise a KeyError if the field is not available.
+ necessary.
+
+ Raise a KeyError if the field is not available.
"""
try:
return super().__getitem__(key)
@@ -634,15 +727,18 @@ class Item(LibModel):
# This must not use `with_album=True`, because that might access
# the database. When debugging, that is not guaranteed to succeed, and
# can even deadlock due to the database lock.
- return '{}({})'.format(
+ return "{}({})".format(
type(self).__name__,
- ', '.join('{}={!r}'.format(k, self[k])
- for k in self.keys(with_album=False)),
+ ", ".join(
+ "{}={!r}".format(k, self[k])
+ for k in self.keys(with_album=False)
+ ),
)
def keys(self, computed=False, with_album=True):
- """Get a list of available field names. `with_album`
- controls whether the album's fields are included.
+ """Get a list of available field names.
+
+ `with_album` controls whether the album's fields are included.
"""
keys = super().keys(computed=computed)
if with_album and self._cached_album:
@@ -653,7 +749,9 @@ class Item(LibModel):
def get(self, key, default=None, with_album=True):
"""Get the value for a given key or `default` if it does not
- exist. Set `with_album` to false to skip album fallback.
+ exist.
+
+ Set `with_album` to false to skip album fallback.
"""
try:
return self._get(key, default, raise_=with_album)
@@ -663,12 +761,13 @@ class Item(LibModel):
return default
def update(self, values):
- """Set all key/value pairs in the mapping. If mtime is
- specified, it is not reset (as it might otherwise be).
+ """Set all key/value pairs in the mapping.
+
+ If mtime is specified, it is not reset (as it might otherwise be).
"""
super().update(values)
- if self.mtime == 0 and 'mtime' in values:
- self.mtime = values['mtime']
+ if self.mtime == 0 and "mtime" in values:
+ self.mtime = values["mtime"]
def clear(self):
"""Set all key/value pairs to None."""
@@ -690,10 +789,10 @@ class Item(LibModel):
"""Read the metadata from the associated file.
If `read_path` is specified, read metadata from that file
- instead. Updates all the properties in `_media_fields`
+ instead. Update all the properties in `_media_fields`
from the media file.
- Raises a `ReadError` if the file could not be read.
+ Raise a `ReadError` if the file could not be read.
"""
if read_path is None:
read_path = self.path
@@ -740,15 +839,16 @@ class Item(LibModel):
path = normpath(path)
if id3v23 is None:
- id3v23 = beets.config['id3v23'].get(bool)
+ id3v23 = beets.config["id3v23"].get(bool)
# Get the data to write to the file.
item_tags = dict(self)
- item_tags = {k: v for k, v in item_tags.items()
- if k in self._media_fields} # Only write media fields.
+ item_tags = {
+ k: v for k, v in item_tags.items() if k in self._media_fields
+ } # Only write media fields.
if tags is not None:
item_tags.update(tags)
- plugins.send('write', item=self, path=path, tags=item_tags)
+ plugins.send("write", item=self, path=path, tags=item_tags)
# Open the file.
try:
@@ -766,13 +866,13 @@ class Item(LibModel):
# The file has a new mtime.
if path == self.path:
self.mtime = self.current_mtime()
- plugins.send('after_write', item=self, path=path)
+ plugins.send("after_write", item=self, path=path)
def try_write(self, *args, **kwargs):
- """Calls `write()` but catches and logs `FileOperationError`
+ """Call `write()` but catch and log `FileOperationError`
exceptions.
- Returns `False` an exception was caught and `True` otherwise.
+ Return `False` an exception was caught and `True` otherwise.
"""
try:
self.write(*args, **kwargs)
@@ -782,7 +882,7 @@ class Item(LibModel):
return False
def try_sync(self, write, move, with_album=True):
- """Synchronize the item with the database and, possibly, updates its
+ """Synchronize the item with the database and, possibly, update its
tags on disk and its path (by moving the file).
`write` indicates whether to write new tags into the file. Similarly,
@@ -798,15 +898,17 @@ class Item(LibModel):
if move:
# Check whether this file is inside the library directory.
if self._db and self._db.directory in util.ancestry(self.path):
- log.debug('moving {0} to synchronize path',
- util.displayable_path(self.path))
+ log.debug(
+ "moving {0} to synchronize path",
+ util.displayable_path(self.path),
+ )
self.move(with_album=with_album)
self.store()
# Files themselves.
def move_file(self, dest, operation=MoveOperation.MOVE):
- """Move, copy, link or hardlink the item's depending on `operation`,
+ """Move, copy, link or hardlink the item depending on `operation`,
updating the path value if the move succeeds.
If a file exists at `dest`, then it is slightly modified to be unique.
@@ -816,39 +918,49 @@ class Item(LibModel):
if not util.samefile(self.path, dest):
dest = util.unique_path(dest)
if operation == MoveOperation.MOVE:
- plugins.send("before_item_moved", item=self, source=self.path,
- destination=dest)
+ plugins.send(
+ "before_item_moved",
+ item=self,
+ source=self.path,
+ destination=dest,
+ )
util.move(self.path, dest)
- plugins.send("item_moved", item=self, source=self.path,
- destination=dest)
+ plugins.send(
+ "item_moved", item=self, source=self.path, destination=dest
+ )
elif operation == MoveOperation.COPY:
util.copy(self.path, dest)
- plugins.send("item_copied", item=self, source=self.path,
- destination=dest)
+ plugins.send(
+ "item_copied", item=self, source=self.path, destination=dest
+ )
elif operation == MoveOperation.LINK:
util.link(self.path, dest)
- plugins.send("item_linked", item=self, source=self.path,
- destination=dest)
+ plugins.send(
+ "item_linked", item=self, source=self.path, destination=dest
+ )
elif operation == MoveOperation.HARDLINK:
util.hardlink(self.path, dest)
- plugins.send("item_hardlinked", item=self, source=self.path,
- destination=dest)
+ plugins.send(
+ "item_hardlinked", item=self, source=self.path, destination=dest
+ )
elif operation == MoveOperation.REFLINK:
util.reflink(self.path, dest, fallback=False)
- plugins.send("item_reflinked", item=self, source=self.path,
- destination=dest)
+ plugins.send(
+ "item_reflinked", item=self, source=self.path, destination=dest
+ )
elif operation == MoveOperation.REFLINK_AUTO:
util.reflink(self.path, dest, fallback=True)
- plugins.send("item_reflinked", item=self, source=self.path,
- destination=dest)
+ plugins.send(
+ "item_reflinked", item=self, source=self.path, destination=dest
+ )
else:
- assert False, 'unknown MoveOperation'
+ assert False, "unknown MoveOperation"
# Either copying or moving succeeded, so update the stored path.
self.path = dest
def current_mtime(self):
- """Returns the current mtime of the file, rounded to the nearest
+ """Return the current mtime of the file, rounded to the nearest
integer.
"""
return int(os.path.getmtime(syspath(self.path)))
@@ -861,15 +973,18 @@ class Item(LibModel):
try:
return os.path.getsize(syspath(self.path))
except (OSError, Exception) as exc:
- log.warning('could not get filesize: {0}', exc)
+ log.warning("could not get filesize: {0}", exc)
return 0
# Model methods.
def remove(self, delete=False, with_album=True):
- """Removes the item. If `delete`, then the associated file is
- removed from disk. If `with_album`, then the item's album (if
- any) is removed if it the item was the last in the album.
+ """Remove the item.
+
+ If `delete`, then the associated file is removed from disk.
+
+ If `with_album`, then the item's album (if any) is removed
+ if the item was the last in the album.
"""
super().remove()
@@ -880,7 +995,7 @@ class Item(LibModel):
album.remove(delete, False)
# Send a 'item_removed' signal to plugins
- plugins.send('item_removed', item=self)
+ plugins.send("item_removed", item=self)
# Delete the associated file.
if delete:
@@ -889,12 +1004,18 @@ class Item(LibModel):
self._db._memotable = {}
- def move(self, operation=MoveOperation.MOVE, basedir=None,
- with_album=True, store=True):
+ def move(
+ self,
+ operation=MoveOperation.MOVE,
+ basedir=None,
+ with_album=True,
+ store=True,
+ ):
"""Move the item to its designated location within the library
- directory (provided by destination()). Subdirectories are
- created as needed. If the operation succeeds, the item's path
- field is updated to reflect the new location.
+ directory (provided by destination()).
+
+ Subdirectories are created as needed. If the operation succeeds,
+ the item's path field is updated to reflect the new location.
Instead of moving the item it can also be copied, linked or hardlinked
depending on `operation` which should be an instance of
@@ -908,8 +1029,8 @@ class Item(LibModel):
By default, the item is stored to the database if it is in the
database, so any dirty fields prior to the move() call will be written
as a side effect.
- If `store` is `False` however, the item won't be stored and you'll
- have to manually store it after invoking this method.
+ If `store` is `False` however, the item won't be stored and it will
+ have to be manually stored after invoking this method.
"""
self._check_db()
dest = self.destination(basedir=basedir)
@@ -937,14 +1058,21 @@ class Item(LibModel):
# Templating.
- def destination(self, fragment=False, basedir=None, platform=None,
- path_formats=None, replacements=None):
- """Returns the path in the library directory designated for the
- item (i.e., where the file ought to be). fragment makes this
- method return just the path fragment underneath the root library
- directory; the path is also returned as Unicode instead of
- encoded as a bytestring. basedir can override the library's base
- directory for the destination.
+ def destination(
+ self,
+ fragment=False,
+ basedir=None,
+ platform=None,
+ path_formats=None,
+ replacements=None,
+ ):
+ """Return the path in the library directory designated for the
+ item (i.e., where the file ought to be).
+
+ fragment makes this method return just the path fragment underneath
+ the root library directory; the path is also returned as Unicode
+ instead of encoded as a bytestring. basedir can override the library's
+ base directory for the destination.
"""
self._check_db()
platform = platform or sys.platform
@@ -979,34 +1107,36 @@ class Item(LibModel):
subpath = self.evaluate_template(subpath_tmpl, True)
# Prepare path for output: normalize Unicode characters.
- if platform == 'darwin':
- subpath = unicodedata.normalize('NFD', subpath)
+ if platform == "darwin":
+ subpath = unicodedata.normalize("NFD", subpath)
else:
- subpath = unicodedata.normalize('NFC', subpath)
+ subpath = unicodedata.normalize("NFC", subpath)
- if beets.config['asciify_paths']:
+ if beets.config["asciify_paths"]:
subpath = util.asciify_path(
- subpath,
- beets.config['path_sep_replace'].as_str()
+ subpath, beets.config["path_sep_replace"].as_str()
)
- maxlen = beets.config['max_filename_length'].get(int)
+ maxlen = beets.config["max_filename_length"].get(int)
if not maxlen:
# When zero, try to determine from filesystem.
maxlen = util.max_filename_length(self._db.directory)
subpath, fellback = util.legalize_path(
- subpath, replacements, maxlen,
- os.path.splitext(self.path)[1], fragment
+ subpath,
+ replacements,
+ maxlen,
+ os.path.splitext(self.path)[1],
+ fragment,
)
if fellback:
# Print an error message if legalization fell back to
# default replacements because of the maximum length.
log.warning(
- 'Fell back to default replacements when naming '
- 'file {}. Configure replacements to avoid lengthening '
- 'the filename.',
- subpath
+ "Fell back to default replacements when naming "
+ "file {}. Configure replacements to avoid lengthening "
+ "the filename.",
+ subpath,
)
if fragment:
@@ -1016,134 +1146,168 @@ class Item(LibModel):
class Album(LibModel):
- """Provides access to information about albums stored in a
- library. Reflects the library's "albums" table, including album
- art.
+ """Provide access to information about albums stored in a
+ library.
+
+ Reflects the library's "albums" table, including album art.
"""
- _table = 'albums'
- _flex_table = 'album_attributes'
+
+ _table = "albums"
+ _flex_table = "album_attributes"
_always_dirty = True
_fields = {
- 'id': types.PRIMARY_ID,
- 'artpath': PathType(True),
- 'added': DateType(),
-
- 'albumartist': types.STRING,
- 'albumartist_sort': types.STRING,
- 'albumartist_credit': types.STRING,
- 'album': types.STRING,
- 'genre': types.STRING,
- 'style': types.STRING,
- 'discogs_albumid': types.INTEGER,
- 'discogs_artistid': types.INTEGER,
- 'discogs_labelid': types.INTEGER,
- 'year': types.PaddedInt(4),
- 'month': types.PaddedInt(2),
- 'day': types.PaddedInt(2),
- 'disctotal': types.PaddedInt(2),
- 'comp': types.BOOLEAN,
- 'mb_albumid': types.STRING,
- 'mb_albumartistid': types.STRING,
- 'albumtype': types.STRING,
- 'albumtypes': types.STRING,
- 'label': types.STRING,
- 'mb_releasegroupid': types.STRING,
- 'asin': types.STRING,
- 'catalognum': types.STRING,
- 'script': types.STRING,
- 'language': types.STRING,
- 'country': types.STRING,
- 'albumstatus': types.STRING,
- 'albumdisambig': types.STRING,
- 'releasegroupdisambig': types.STRING,
- 'rg_album_gain': types.NULL_FLOAT,
- 'rg_album_peak': types.NULL_FLOAT,
- 'r128_album_gain': types.NullPaddedInt(6),
- 'original_year': types.PaddedInt(4),
- 'original_month': types.PaddedInt(2),
- 'original_day': types.PaddedInt(2),
+ "id": types.PRIMARY_ID,
+ "artpath": PathType(True),
+ "added": DateType(),
+ "albumartist": types.STRING,
+ "albumartist_sort": types.STRING,
+ "albumartist_credit": types.STRING,
+ "albumartists": types.MULTI_VALUE_DSV,
+ "albumartists_sort": types.MULTI_VALUE_DSV,
+ "albumartists_credit": types.MULTI_VALUE_DSV,
+ "album": types.STRING,
+ "genre": types.STRING,
+ "style": types.STRING,
+ "discogs_albumid": types.INTEGER,
+ "discogs_artistid": types.INTEGER,
+ "discogs_labelid": types.INTEGER,
+ "year": types.PaddedInt(4),
+ "month": types.PaddedInt(2),
+ "day": types.PaddedInt(2),
+ "disctotal": types.PaddedInt(2),
+ "comp": types.BOOLEAN,
+ "mb_albumid": types.STRING,
+ "mb_albumartistid": types.STRING,
+ "albumtype": types.STRING,
+ "albumtypes": types.SEMICOLON_SPACE_DSV,
+ "label": types.STRING,
+ "barcode": types.STRING,
+ "mb_releasegroupid": types.STRING,
+ "release_group_title": types.STRING,
+ "asin": types.STRING,
+ "catalognum": types.STRING,
+ "script": types.STRING,
+ "language": types.STRING,
+ "country": types.STRING,
+ "albumstatus": types.STRING,
+ "albumdisambig": types.STRING,
+ "releasegroupdisambig": types.STRING,
+ "rg_album_gain": types.NULL_FLOAT,
+ "rg_album_peak": types.NULL_FLOAT,
+ "r128_album_gain": types.NULL_FLOAT,
+ "original_year": types.PaddedInt(4),
+ "original_month": types.PaddedInt(2),
+ "original_day": types.PaddedInt(2),
}
- _search_fields = ('album', 'albumartist', 'genre')
+ _search_fields = ("album", "albumartist", "genre")
_types = {
- 'path': PathType(),
- 'data_source': types.STRING,
+ "path": PathType(),
+ "data_source": types.STRING,
}
_sorts = {
- 'albumartist': SmartArtistSort,
- 'artist': SmartArtistSort,
+ "albumartist": SmartArtistSort,
+ "artist": SmartArtistSort,
}
+ # List of keys that are set on an album's items.
item_keys = [
- 'added',
- 'albumartist',
- 'albumartist_sort',
- 'albumartist_credit',
- 'album',
- 'genre',
- 'style',
- 'discogs_albumid',
- 'discogs_artistid',
- 'discogs_labelid',
- 'year',
- 'month',
- 'day',
- 'disctotal',
- 'comp',
- 'mb_albumid',
- 'mb_albumartistid',
- 'albumtype',
- 'albumtypes',
- 'label',
- 'mb_releasegroupid',
- 'asin',
- 'catalognum',
- 'script',
- 'language',
- 'country',
- 'albumstatus',
- 'albumdisambig',
- 'releasegroupdisambig',
- 'rg_album_gain',
- 'rg_album_peak',
- 'r128_album_gain',
- 'original_year',
- 'original_month',
- 'original_day',
+ "added",
+ "albumartist",
+ "albumartists",
+ "albumartist_sort",
+ "albumartists_sort",
+ "albumartist_credit",
+ "albumartists_credit",
+ "album",
+ "genre",
+ "style",
+ "discogs_albumid",
+ "discogs_artistid",
+ "discogs_labelid",
+ "year",
+ "month",
+ "day",
+ "disctotal",
+ "comp",
+ "mb_albumid",
+ "mb_albumartistid",
+ "albumtype",
+ "albumtypes",
+ "label",
+ "barcode",
+ "mb_releasegroupid",
+ "asin",
+ "catalognum",
+ "script",
+ "language",
+ "country",
+ "albumstatus",
+ "albumdisambig",
+ "releasegroupdisambig",
+ "release_group_title",
+ "rg_album_gain",
+ "rg_album_peak",
+ "r128_album_gain",
+ "original_year",
+ "original_month",
+ "original_day",
]
- """List of keys that are set on an album's items.
- """
- _format_config_key = 'format_album'
+ _format_config_key = "format_album"
+
+ @cached_classproperty
+ def _relation(cls) -> type[Item]:
+ return Item
+
+ @cached_classproperty
+ def relation_join(cls) -> str:
+ """Return FROM clause which joins on related album items.
+
+ Use LEFT join to select all albums, including those that do not have
+ any items.
+ """
+ return (
+ f"LEFT JOIN {cls._relation._table} "
+ f"ON {cls._table}.id = {cls._relation._table}.album_id"
+ )
@classmethod
def _getters(cls):
# In addition to plugin-provided computed fields, also expose
# the album's directory as `path`.
getters = plugins.album_field_getters()
- getters['path'] = Album.item_dir
- getters['albumtotal'] = Album._albumtotal
+ getters["path"] = Album.item_dir
+ getters["albumtotal"] = Album._albumtotal
return getters
def items(self):
- """Returns an iterable over the items associated with this
+ """Return an iterable over the items associated with this
album.
+
+ This method conflicts with :meth:`LibModel.items`, which is
+ inherited from :meth:`beets.dbcore.Model.items`.
+ Since :meth:`Album.items` predates these methods, and is
+ likely to be used by plugins, we keep this interface as-is.
"""
- return self._db.items(dbcore.MatchQuery('album_id', self.id))
+ return self._db.items(dbcore.MatchQuery("album_id", self.id))
def remove(self, delete=False, with_items=True):
- """Removes this album and all its associated items from the
- library. If delete, then the items' files are also deleted
- from disk, along with any album art. The directories
- containing the album are also removed (recursively) if empty.
+ """Remove this album and all its associated items from the
+ library.
+
+ If delete, then the items' files are also deleted from disk,
+ along with any album art. The directories containing the album are
+ also removed (recursively) if empty.
+
Set with_items to False to avoid removing the album's items.
"""
super().remove()
# Send a 'album_removed' signal to plugins
- plugins.send('album_removed', album=self)
+ plugins.send("album_removed", album=self)
# Delete art file.
if delete:
@@ -1167,9 +1331,11 @@ class Album(LibModel):
if not old_art:
return
- if not os.path.exists(old_art):
- log.error('removing reference to missing album art file {}',
- util.displayable_path(old_art))
+ if not os.path.exists(syspath(old_art)):
+ log.error(
+ "removing reference to missing album art file {}",
+ util.displayable_path(old_art),
+ )
self.artpath = None
return
@@ -1178,9 +1344,11 @@ class Album(LibModel):
return
new_art = util.unique_path(new_art)
- log.debug('moving album art {0} to {1}',
- util.displayable_path(old_art),
- util.displayable_path(new_art))
+ log.debug(
+ "moving album art {0} to {1}",
+ util.displayable_path(old_art),
+ util.displayable_path(new_art),
+ )
if operation == MoveOperation.MOVE:
util.move(old_art, new_art)
util.prune_dirs(os.path.dirname(old_art), self._db.directory)
@@ -1195,7 +1363,7 @@ class Album(LibModel):
elif operation == MoveOperation.REFLINK_AUTO:
util.reflink(old_art, new_art, fallback=True)
else:
- assert False, 'unknown MoveOperation'
+ assert False, "unknown MoveOperation"
self.artpath = new_art
def move(self, operation=MoveOperation.MOVE, basedir=None, store=True):
@@ -1208,8 +1376,8 @@ class Album(LibModel):
By default, the album is stored to the database, persisting any
modifications to its metadata. If `store` is `False` however,
- the album is not stored automatically, and you'll have to manually
- store it after invoking this method.
+ the album is not stored automatically, and it will have to be manually
+ stored after invoking this method.
"""
basedir = basedir or self._db.directory
@@ -1221,8 +1389,7 @@ class Album(LibModel):
# Move items.
items = list(self.items())
for item in items:
- item.move(operation, basedir=basedir, with_album=False,
- store=store)
+ item.move(operation, basedir=basedir, with_album=False, store=store)
# Move art.
self.move_art(operation)
@@ -1230,18 +1397,17 @@ class Album(LibModel):
self.store()
def item_dir(self):
- """Returns the directory containing the album's first item,
+ """Return the directory containing the album's first item,
provided that such an item exists.
"""
item = self.items().get()
if not item:
- raise ValueError('empty album for album id %d' % self.id)
+ raise ValueError("empty album for album id %d" % self.id)
return os.path.dirname(item.path)
def _albumtotal(self):
- """Return the total number of tracks on all discs on the album
- """
- if self.disctotal == 1 or not beets.config['per_disc_numbering']:
+ """Return the total number of tracks on all discs on the album."""
+ if self.disctotal == 1 or not beets.config["per_disc_numbering"]:
return self.items()[0].tracktotal
counted = []
@@ -1260,8 +1426,10 @@ class Album(LibModel):
return total
def art_destination(self, image, item_dir=None):
- """Returns a path to the destination for the album art image
- for the album. `image` is the path of the image that will be
+ """Return a path to the destination for the album art image
+ for the album.
+
+ `image` is the path of the image that will be
moved there (used for its extension).
The path construction uses the existing path of the album's
@@ -1271,16 +1439,15 @@ class Album(LibModel):
image = bytestring_path(image)
item_dir = item_dir or self.item_dir()
- filename_tmpl = template(
- beets.config['art_filename'].as_str())
+ filename_tmpl = template(beets.config["art_filename"].as_str())
subpath = self.evaluate_template(filename_tmpl, True)
- if beets.config['asciify_paths']:
+ if beets.config["asciify_paths"]:
subpath = util.asciify_path(
- subpath,
- beets.config['path_sep_replace'].as_str()
+ subpath, beets.config["path_sep_replace"].as_str()
)
- subpath = util.sanitize_path(subpath,
- replacements=self._db.replacements)
+ subpath = util.sanitize_path(
+ subpath, replacements=self._db.replacements
+ )
subpath = bytestring_path(subpath)
_, ext = os.path.splitext(image)
@@ -1289,11 +1456,12 @@ class Album(LibModel):
return bytestring_path(dest)
def set_art(self, path, copy=True):
- """Sets the album's cover art to the image at the given path.
+ """Set the album's cover art to the image at the given path.
+
The image is copied (or moved) into place, replacing any
existing art.
- Sends an 'art_set' event with `self` as the sole argument.
+ Send an 'art_set' event with `self` as the sole argument.
"""
path = bytestring_path(path)
oldart = self.artpath
@@ -1317,19 +1485,29 @@ class Album(LibModel):
util.move(path, artdest)
self.artpath = artdest
- plugins.send('art_set', album=self)
+ plugins.send("art_set", album=self)
- def store(self, fields=None):
- """Update the database with the album information. The album's
- tracks are also updated.
- :param fields: The fields to be stored. If not specified, all fields
- will be.
+ def store(self, fields=None, inherit=True):
+ """Update the database with the album information.
+
+ `fields` represents the fields to be stored. If not specified,
+ all fields will be.
+
+ The album's tracks are also updated when the `inherit` flag is enabled.
+ This applies to fixed attributes as well as flexible ones. The `id`
+ attribute of the album will never be inherited.
"""
# Get modified track fields.
track_updates = {}
- for key in self.item_keys:
- if key in self._dirty:
- track_updates[key] = self[key]
+ track_deletes = set()
+ for key in self._dirty:
+ if inherit:
+ if key in self.item_keys: # is a fixed attribute
+ track_updates[key] = self[key]
+ elif key not in self: # is a fixed or a flexible attribute
+ track_deletes.add(key)
+ elif key != "id": # is a flexible attribute
+ track_updates[key] = self[key]
with self._db.transaction():
super().store(fields)
@@ -1338,8 +1516,14 @@ class Album(LibModel):
for key, value in track_updates.items():
item[key] = value
item.store()
+ if track_deletes:
+ for item in self.items():
+ for key in track_deletes:
+ if key in item:
+ del item[key]
+ item.store()
- def try_sync(self, write, move):
+ def try_sync(self, write, move, inherit=True):
"""Synchronize the album and its items with the database.
Optionally, also write any new tags into the files and update
their paths.
@@ -1348,46 +1532,40 @@ class Album(LibModel):
`move` controls whether files (both audio and album art) are
moved.
"""
- self.store()
+ self.store(inherit=inherit)
for item in self.items():
item.try_sync(write, move)
# Query construction helpers.
+
def parse_query_parts(parts, model_cls):
"""Given a beets query string as a list of components, return the
`Query` and `Sort` they represent.
Like `dbcore.parse_sorted_query`, with beets query prefixes and
- special path query detection.
+ ensuring that implicit path queries are made explicit with 'path::'
"""
# Get query types and their prefix characters.
- prefixes = {':': dbcore.query.RegexpQuery}
+ prefixes = {
+ ":": dbcore.query.RegexpQuery,
+ "=~": dbcore.query.StringQuery,
+ "=": dbcore.query.MatchQuery,
+ }
prefixes.update(plugins.queries())
# Special-case path-like queries, which are non-field queries
# containing path separators (/).
- path_parts = []
- non_path_parts = []
- for s in parts:
- if PathQuery.is_path_query(s):
- path_parts.append(s)
- else:
- non_path_parts.append(s)
+ parts = [f"path:{s}" if PathQuery.is_path_query(s) else s for s in parts]
- case_insensitive = beets.config['sort_case_insensitive'].get(bool)
+ case_insensitive = beets.config["sort_case_insensitive"].get(bool)
query, sort = dbcore.parse_sorted_query(
- model_cls, non_path_parts, prefixes, case_insensitive
+ model_cls, parts, prefixes, case_insensitive
)
-
- # Add path queries to aggregate query.
- # Match field / flexattr depending on whether the model has the path field
- fast_path_query = 'path' in model_cls._fields
- query.subqueries += [PathQuery('path', s, fast_path_query)
- for s in path_parts]
-
+ log.debug("Parsed query: {!r}", query)
+ log.debug("Parsed sort: {!r}", sort)
return query, sort
@@ -1406,29 +1584,22 @@ def parse_query_string(s, model_cls):
return parse_query_parts(parts, model_cls)
-def _sqlite_bytelower(bytestring):
- """ A custom ``bytelower`` sqlite function so we can compare
- bytestrings in a semi case insensitive fashion. This is to work
- around sqlite builds are that compiled with
- ``-DSQLITE_LIKE_DOESNT_MATCH_BLOBS``. See
- ``https://github.com/beetbox/beets/issues/2172`` for details.
- """
- return bytestring.lower()
-
-
# The Library: interface to the database.
+
class Library(dbcore.Database):
- """A database of music containing songs and albums.
- """
+ """A database of music containing songs and albums."""
+
_models = (Item, Album)
- def __init__(self, path='library.blb',
- directory='~/Music',
- path_formats=((PF_KEY_DEFAULT,
- '$artist/$album/$track $title'),),
- replacements=None):
- timeout = beets.config['timeout'].as_number()
+ def __init__(
+ self,
+ path="library.blb",
+ directory="~/Music",
+ path_formats=((PF_KEY_DEFAULT, "$artist/$album/$track $title"),),
+ replacements=None,
+ ):
+ timeout = beets.config["timeout"].as_number()
super().__init__(path, timeout=timeout)
self.directory = bytestring_path(normpath(directory))
@@ -1437,16 +1608,13 @@ class Library(dbcore.Database):
self._memotable = {} # Used for template substitution performance.
- def _create_connection(self):
- conn = super()._create_connection()
- conn.create_function('bytelower', 1, _sqlite_bytelower)
- return conn
-
# Adding objects to the database.
def add(self, obj):
"""Add the :class:`Item` or :class:`Album` object to the library
- database. Return the object's new id.
+ database.
+
+ Return the object's new id.
"""
obj.add(self)
self._memotable = {}
@@ -1460,7 +1628,7 @@ class Library(dbcore.Database):
be empty.
"""
if not items:
- raise ValueError('need at least one item')
+ raise ValueError("need at least one item")
# Create the album structure using metadata from the first item.
values = {key: items[0][key] for key in Album.item_keys}
@@ -1482,8 +1650,10 @@ class Library(dbcore.Database):
# Querying.
def _fetch(self, model_cls, query, sort=None):
- """Parse a query and fetch. If a order specification is present
- in the query string the `sort` argument is ignored.
+ """Parse a query and fetch.
+
+ If an order specification is present in the query string
+ the `sort` argument is ignored.
"""
# Parse the query, if necessary.
try:
@@ -1500,46 +1670,44 @@ class Library(dbcore.Database):
if parsed_sort and not isinstance(parsed_sort, dbcore.query.NullSort):
sort = parsed_sort
- return super()._fetch(
- model_cls, query, sort
- )
+ return super()._fetch(model_cls, query, sort)
@staticmethod
def get_default_album_sort():
- """Get a :class:`Sort` object for albums from the config option.
- """
+ """Get a :class:`Sort` object for albums from the config option."""
return dbcore.sort_from_strings(
- Album, beets.config['sort_album'].as_str_seq())
+ Album, beets.config["sort_album"].as_str_seq()
+ )
@staticmethod
def get_default_item_sort():
- """Get a :class:`Sort` object for items from the config option.
- """
+ """Get a :class:`Sort` object for items from the config option."""
return dbcore.sort_from_strings(
- Item, beets.config['sort_item'].as_str_seq())
+ Item, beets.config["sort_item"].as_str_seq()
+ )
- def albums(self, query=None, sort=None):
- """Get :class:`Album` objects matching the query.
- """
+ def albums(self, query=None, sort=None) -> Results[Album]:
+ """Get :class:`Album` objects matching the query."""
return self._fetch(Album, query, sort or self.get_default_album_sort())
- def items(self, query=None, sort=None):
- """Get :class:`Item` objects matching the query.
- """
+ def items(self, query=None, sort=None) -> Results[Item]:
+ """Get :class:`Item` objects matching the query."""
return self._fetch(Item, query, sort or self.get_default_item_sort())
# Convenience accessors.
def get_item(self, id):
- """Fetch an :class:`Item` by its ID. Returns `None` if no match is
- found.
+ """Fetch a :class:`Item` by its ID.
+
+ Return `None` if no match is found.
"""
return self._get(Item, id)
def get_album(self, item_or_id):
"""Given an album ID or an item associated with an album, return
- an :class:`Album` object for the album. If no such album exists,
- returns `None`.
+ a :class:`Album` object for the album.
+
+ If no such album exists, return `None`.
"""
if isinstance(item_or_id, int):
album_id = item_or_id
@@ -1552,37 +1720,46 @@ class Library(dbcore.Database):
# Default path template resources.
+
def _int_arg(s):
"""Convert a string argument to an integer for use in a template
- function. May raise a ValueError.
+ function.
+
+ May raise a ValueError.
"""
return int(s.strip())
class DefaultTemplateFunctions:
"""A container class for the default functions provided to path
- templates. These functions are contained in an object to provide
+ templates.
+
+ These functions are contained in an object to provide
additional context to the functions -- specifically, the Item being
evaluated.
"""
- _prefix = 'tmpl_'
+
+ _prefix = "tmpl_"
def __init__(self, item=None, lib=None):
- """Parametrize the functions. If `item` or `lib` is None, then
- some functions (namely, ``aunique``) will always evaluate to the
- empty string.
+ """Parametrize the functions.
+
+ If `item` or `lib` is None, then some functions (namely, ``aunique``)
+ will always evaluate to the empty string.
"""
self.item = item
self.lib = lib
def functions(self):
- """Returns a dictionary containing the functions defined in this
- object. The keys are function names (as exposed in templates)
+ """Return a dictionary containing the functions defined in this
+ object.
+
+ The keys are function names (as exposed in templates)
and the values are Python functions.
"""
out = {}
for key in self._func_names:
- out[key[len(self._prefix):]] = getattr(self, key)
+ out[key[len(self._prefix) :]] = getattr(self, key)
return out
@staticmethod
@@ -1592,7 +1769,7 @@ class DefaultTemplateFunctions:
@staticmethod
def tmpl_upper(s):
- """Covert a string to upper case."""
+ """Convert a string to upper case."""
return s.upper()
@staticmethod
@@ -1603,15 +1780,15 @@ class DefaultTemplateFunctions:
@staticmethod
def tmpl_left(s, chars):
"""Get the leftmost characters of a string."""
- return s[0:_int_arg(chars)]
+ return s[0 : _int_arg(chars)]
@staticmethod
def tmpl_right(s, chars):
"""Get the rightmost characters of a string."""
- return s[-_int_arg(chars):]
+ return s[-_int_arg(chars) :]
@staticmethod
- def tmpl_if(condition, trueval, falseval=''):
+ def tmpl_if(condition, trueval, falseval=""):
"""If ``condition`` is nonempty and nonzero, emit ``trueval``;
otherwise, emit ``falseval`` (if provided).
"""
@@ -1630,21 +1807,20 @@ class DefaultTemplateFunctions:
@staticmethod
def tmpl_asciify(s):
- """Translate non-ASCII characters to their ASCII equivalents.
- """
- return util.asciify_path(s, beets.config['path_sep_replace'].as_str())
+ """Translate non-ASCII characters to their ASCII equivalents."""
+ return util.asciify_path(s, beets.config["path_sep_replace"].as_str())
@staticmethod
def tmpl_time(s, fmt):
- """Format a time value using `strftime`.
- """
- cur_fmt = beets.config['time_format'].as_str()
+ """Format a time value using `strftime`."""
+ cur_fmt = beets.config["time_format"].as_str()
return time.strftime(fmt, time.strptime(s, cur_fmt))
def tmpl_aunique(self, keys=None, disam=None, bracket=None):
"""Generate a string that is guaranteed to be unique among all
- albums in the library who share the same set of keys. A fields
- from "disam" is used in the string if one is sufficient to
+ albums in the library who share the same set of keys.
+
+ A fields from "disam" is used in the string if one is sufficient to
disambiguate the albums. Otherwise, a fallback opaque value is
used. Both "keys" and "disam" should be given as
whitespace-separated lists of field names, while "bracket" is a
@@ -1653,7 +1829,7 @@ class DefaultTemplateFunctions:
"""
# Fast paths: no album, no item or library, or memoized value.
if not self.item or not self.lib:
- return ''
+ return ""
if isinstance(self.item, Item):
album_id = self.item.album_id
@@ -1661,17 +1837,114 @@ class DefaultTemplateFunctions:
album_id = self.item.id
if album_id is None:
- return ''
+ return ""
- memokey = ('aunique', keys, disam, album_id)
+ memokey = self._tmpl_unique_memokey("aunique", keys, disam, album_id)
memoval = self.lib._memotable.get(memokey)
if memoval is not None:
return memoval
- keys = keys or beets.config['aunique']['keys'].as_str()
- disam = disam or beets.config['aunique']['disambiguators'].as_str()
+ album = self.lib.get_album(album_id)
+
+ return self._tmpl_unique(
+ "aunique",
+ keys,
+ disam,
+ bracket,
+ album_id,
+ album,
+ album.item_keys,
+ # Do nothing for singletons.
+ lambda a: a is None,
+ )
+
+ def tmpl_sunique(self, keys=None, disam=None, bracket=None):
+ """Generate a string that is guaranteed to be unique among all
+ singletons in the library who share the same set of keys.
+
+ A fields from "disam" is used in the string if one is sufficient to
+ disambiguate the albums. Otherwise, a fallback opaque value is
+ used. Both "keys" and "disam" should be given as
+ whitespace-separated lists of field names, while "bracket" is a
+ pair of characters to be used as brackets surrounding the
+ disambiguator or empty to have no brackets.
+ """
+ # Fast paths: no album, no item or library, or memoized value.
+ if not self.item or not self.lib:
+ return ""
+
+ if isinstance(self.item, Item):
+ item_id = self.item.id
+ else:
+ raise NotImplementedError("sunique is only implemented for items")
+
+ if item_id is None:
+ return ""
+
+ return self._tmpl_unique(
+ "sunique",
+ keys,
+ disam,
+ bracket,
+ item_id,
+ self.item,
+ Item.all_keys(),
+ # Do nothing for non singletons.
+ lambda i: i.album_id is not None,
+ initial_subqueries=[dbcore.query.NoneQuery("album_id", True)],
+ )
+
+ def _tmpl_unique_memokey(self, name, keys, disam, item_id):
+ """Get the memokey for the unique template named "name" for the
+ specific parameters.
+ """
+ return (name, keys, disam, item_id)
+
+ def _tmpl_unique(
+ self,
+ name,
+ keys,
+ disam,
+ bracket,
+ item_id,
+ db_item,
+ item_keys,
+ skip_item,
+ initial_subqueries=None,
+ ):
+ """Generate a string that is guaranteed to be unique among all items of
+ the same type as "db_item" who share the same set of keys.
+
+ A field from "disam" is used in the string if one is sufficient to
+ disambiguate the items. Otherwise, a fallback opaque value is
+ used. Both "keys" and "disam" should be given as
+ whitespace-separated lists of field names, while "bracket" is a
+ pair of characters to be used as brackets surrounding the
+ disambiguator or empty to have no brackets.
+
+ "name" is the name of the templates. It is also the name of the
+ configuration section where the default values of the parameters
+ are stored.
+
+ "skip_item" is a function that must return True when the template
+ should return an empty string.
+
+ "initial_subqueries" is a list of subqueries that should be included
+ in the query to find the ambiguous items.
+ """
+ memokey = self._tmpl_unique_memokey(name, keys, disam, item_id)
+ memoval = self.lib._memotable.get(memokey)
+ if memoval is not None:
+ return memoval
+
+ if skip_item(db_item):
+ self.lib._memotable[memokey] = ""
+ return ""
+
+ keys = keys or beets.config[name]["keys"].as_str()
+ disam = disam or beets.config[name]["disambiguators"].as_str()
if bracket is None:
- bracket = beets.config['aunique']['bracket'].as_str()
+ bracket = beets.config[name]["bracket"].as_str()
keys = keys.split()
disam = disam.split()
@@ -1680,82 +1953,86 @@ class DefaultTemplateFunctions:
bracket_l = bracket[0]
bracket_r = bracket[1]
else:
- bracket_l = ''
- bracket_r = ''
+ bracket_l = ""
+ bracket_r = ""
- album = self.lib.get_album(album_id)
- if not album:
- # Do nothing for singletons.
- self.lib._memotable[memokey] = ''
- return ''
-
- # Find matching albums to disambiguate with.
+ # Find matching items to disambiguate with.
subqueries = []
+ if initial_subqueries is not None:
+ subqueries.extend(initial_subqueries)
for key in keys:
- value = album.get(key, '')
+ value = db_item.get(key, "")
# Use slow queries for flexible attributes.
- fast = key in album.item_keys
+ fast = key in item_keys
subqueries.append(dbcore.MatchQuery(key, value, fast))
- albums = self.lib.albums(dbcore.AndQuery(subqueries))
+ query = dbcore.AndQuery(subqueries)
+ ambigous_items = (
+ self.lib.items(query)
+ if isinstance(db_item, Item)
+ else self.lib.albums(query)
+ )
- # If there's only one album to matching these details, then do
+ # If there's only one item to matching these details, then do
# nothing.
- if len(albums) == 1:
- self.lib._memotable[memokey] = ''
- return ''
+ if len(ambigous_items) == 1:
+ self.lib._memotable[memokey] = ""
+ return ""
- # Find the first disambiguator that distinguishes the albums.
+ # Find the first disambiguator that distinguishes the items.
for disambiguator in disam:
- # Get the value for each album for the current field.
- disam_values = {a.get(disambiguator, '') for a in albums}
+ # Get the value for each item for the current field.
+ disam_values = {s.get(disambiguator, "") for s in ambigous_items}
# If the set of unique values is equal to the number of
- # albums in the disambiguation set, we're done -- this is
+ # items in the disambiguation set, we're done -- this is
# sufficient disambiguation.
- if len(disam_values) == len(albums):
+ if len(disam_values) == len(ambigous_items):
break
-
else:
# No disambiguator distinguished all fields.
- res = f' {bracket_l}{album.id}{bracket_r}'
+ res = f" {bracket_l}{item_id}{bracket_r}"
self.lib._memotable[memokey] = res
return res
# Flatten disambiguation value into a string.
- disam_value = album.formatted(for_path=True).get(disambiguator)
+ disam_value = db_item.formatted(for_path=True).get(disambiguator)
# Return empty string if disambiguator is empty.
if disam_value:
- res = f' {bracket_l}{disam_value}{bracket_r}'
+ res = f" {bracket_l}{disam_value}{bracket_r}"
else:
- res = ''
+ res = ""
self.lib._memotable[memokey] = res
return res
@staticmethod
- def tmpl_first(s, count=1, skip=0, sep='; ', join_str='; '):
- """ Gets the item(s) from x to y in a string separated by something
- and join then with something
+ def tmpl_first(s, count=1, skip=0, sep="; ", join_str="; "):
+ """Get the item(s) from x to y in a string separated by something
+ and join then with something.
- :param s: the string
- :param count: The number of items included
- :param skip: The number of items skipped
- :param sep: the separator. Usually is '; ' (default) or '/ '
- :param join_str: the string which will join the items, default '; '.
+ Args:
+ s: the string
+ count: The number of items included
+ skip: The number of items skipped
+ sep: the separator. Usually is '; ' (default) or '/ '
+ join_str: the string which will join the items, default '; '.
"""
skip = int(skip)
count = skip + int(count)
return join_str.join(s.split(sep)[skip:count])
- def tmpl_ifdef(self, field, trueval='', falseval=''):
- """ If field exists return trueval or the field (default)
+ def tmpl_ifdef(self, field, trueval="", falseval=""):
+ """If field exists return trueval or the field (default)
otherwise, emit return falseval (if provided).
- :param field: The name of the field
- :param trueval: The string if the condition is true
- :param falseval: The string if the condition is false
- :return: The string, based on condition
+ Args:
+ field: The name of the field
+ trueval: The string if the condition is true
+ falseval: The string if the condition is false
+
+ Returns:
+ The string, based on condition.
"""
if field in self.item:
return trueval if trueval else self.item.formatted().get(field)
@@ -1764,6 +2041,8 @@ class DefaultTemplateFunctions:
# Get the name of tmpl_* functions in the above class.
-DefaultTemplateFunctions._func_names = \
- [s for s in dir(DefaultTemplateFunctions)
- if s.startswith(DefaultTemplateFunctions._prefix)]
+DefaultTemplateFunctions._func_names = [
+ s
+ for s in dir(DefaultTemplateFunctions)
+ if s.startswith(DefaultTemplateFunctions._prefix)
+]
diff --git a/lib/beets/logging.py b/lib/beets/logging.py
index 4f004f8d..faa93d59 100644
--- a/lib/beets/logging.py
+++ b/lib/beets/logging.py
@@ -12,63 +12,50 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
-"""A drop-in replacement for the standard-library `logging` module that
-allows {}-style log formatting on Python 2 and 3.
+"""A drop-in replacement for the standard-library `logging` module.
-Provides everything the "logging" module does. The only difference is
-that when getLogger(name) instantiates a logger that logger uses
-{}-style formatting.
+Provides everything the "logging" module does. In addition, beets' logger
+(as obtained by `getLogger(name)`) supports thread-local levels, and messages
+use {}-style formatting and can interpolate keywords arguments to the logging
+calls (`debug`, `info`, etc).
"""
-from copy import copy
-from logging import * # noqa
-import subprocess
+import logging
import threading
+from copy import copy
def logsafe(val):
- """Coerce a potentially "problematic" value so it can be formatted
- in a Unicode log string.
+ """Coerce `bytes` to `str` to avoid crashes solely due to logging.
- This works around a number of pitfalls when logging objects in
- Python 2:
- - Logging path names, which must be byte strings, requires
- conversion for output.
- - Some objects, including some exceptions, will crash when you call
- `unicode(v)` while `str(v)` works fine. CalledProcessError is an
- example.
+ This is particularly relevant for bytestring paths. Much of our code
+ explicitly uses `displayable_path` for them, but better be safe and prevent
+ any crashes that are solely due to log formatting.
"""
- # Already Unicode.
- if isinstance(val, str):
- return val
-
- # Bytestring: needs decoding.
- elif isinstance(val, bytes):
+ # Bytestring: Needs decoding to be safe for substitution in format strings.
+ if isinstance(val, bytes):
# Blindly convert with UTF-8. Eventually, it would be nice to
# (a) only do this for paths, if they can be given a distinct
# type, and (b) warn the developer if they do this for other
# bytestrings.
- return val.decode('utf-8', 'replace')
-
- # A "problem" object: needs a workaround.
- elif isinstance(val, subprocess.CalledProcessError):
- try:
- return str(val)
- except UnicodeDecodeError:
- # An object with a broken __unicode__ formatter. Use __str__
- # instead.
- return str(val).decode('utf-8', 'replace')
+ return val.decode("utf-8", "replace")
# Other objects are used as-is so field access, etc., still works in
- # the format string.
- else:
- return val
+ # the format string. Relies on a working __str__ implementation.
+ return val
-class StrFormatLogger(Logger):
+class StrFormatLogger(logging.Logger):
"""A version of `Logger` that uses `str.format`-style formatting
- instead of %-style formatting.
+ instead of %-style formatting and supports keyword arguments.
+
+ We cannot easily get rid of this even in the Python 3 era: This custom
+ formatting supports substitution from `kwargs` into the message, which the
+ default `logging.Logger._log()` implementation does not.
+
+ Remark by @sampsyo: https://stackoverflow.com/a/24683360 might be a way to
+ achieve this with less code.
"""
class _LogMessage:
@@ -82,19 +69,39 @@ class StrFormatLogger(Logger):
kwargs = {k: logsafe(v) for (k, v) in self.kwargs.items()}
return self.msg.format(*args, **kwargs)
- def _log(self, level, msg, args, exc_info=None, extra=None, **kwargs):
+ def _log(
+ self,
+ level,
+ msg,
+ args,
+ exc_info=None,
+ extra=None,
+ stack_info=False,
+ **kwargs,
+ ):
"""Log msg.format(*args, **kwargs)"""
m = self._LogMessage(msg, args, kwargs)
- return super()._log(level, m, (), exc_info, extra)
+
+ stacklevel = kwargs.pop("stacklevel", 1)
+ stacklevel = {"stacklevel": stacklevel}
+
+ return super()._log(
+ level,
+ m,
+ (),
+ exc_info=exc_info,
+ extra=extra,
+ stack_info=stack_info,
+ **stacklevel,
+ )
-class ThreadLocalLevelLogger(Logger):
- """A version of `Logger` whose level is thread-local instead of shared.
- """
+class ThreadLocalLevelLogger(logging.Logger):
+ """A version of `Logger` whose level is thread-local instead of shared."""
- def __init__(self, name, level=NOTSET):
+ def __init__(self, name, level=logging.NOTSET):
self._thread_level = threading.local()
- self.default_level = NOTSET
+ self.default_level = logging.NOTSET
super().__init__(name, level)
@property
@@ -121,12 +128,17 @@ class BeetsLogger(ThreadLocalLevelLogger, StrFormatLogger):
pass
-my_manager = copy(Logger.manager)
+my_manager = copy(logging.Logger.manager)
my_manager.loggerClass = BeetsLogger
+# Act like the stdlib logging module by re-exporting its namespace.
+from logging import * # noqa
+
+
+# Override the `getLogger` to use our machinery.
def getLogger(name=None): # noqa
if name:
return my_manager.getLogger(name)
else:
- return Logger.root
+ return logging.Logger.root
diff --git a/lib/beets/mediafile.py b/lib/beets/mediafile.py
index 82bcc973..8bde9274 100644
--- a/lib/beets/mediafile.py
+++ b/lib/beets/mediafile.py
@@ -13,14 +13,22 @@
# included in all copies or substantial portions of the Software.
+import warnings
+
import mediafile
-import warnings
-warnings.warn("beets.mediafile is deprecated; use mediafile instead")
+warnings.warn(
+ "beets.mediafile is deprecated; use mediafile instead",
+ # Show the location of the `import mediafile` statement as the warning's
+ # source, rather than this file, such that the offending module can be
+ # identified easily.
+ stacklevel=2,
+)
# Import everything from the mediafile module into this module.
for key, value in mediafile.__dict__.items():
- if key not in ['__name__']:
+ if key not in ["__name__"]:
globals()[key] = value
+# Cleanup namespace.
del key, value, warnings, mediafile
diff --git a/lib/beets/plugins.py b/lib/beets/plugins.py
index ed1f82d8..35995c34 100644
--- a/lib/beets/plugins.py
+++ b/lib/beets/plugins.py
@@ -15,26 +15,25 @@
"""Support for beets plugins."""
-import traceback
-import re
-import inspect
import abc
+import inspect
+import re
+import traceback
from collections import defaultdict
from functools import wraps
+import mediafile
import beets
from beets import logging
-import mediafile
-
-PLUGIN_NAMESPACE = 'beetsplug'
+PLUGIN_NAMESPACE = "beetsplug"
# Plugins using the Last.fm API can share the same API key.
-LASTFM_KEY = '2dc3914abf35f0d9c92d97d8f8e42b43'
+LASTFM_KEY = "2dc3914abf35f0d9c92d97d8f8e42b43"
# Global logger.
-log = logging.getLogger('beets')
+log = logging.getLogger("beets")
class PluginConflictException(Exception):
@@ -51,11 +50,10 @@ class PluginLogFilter(logging.Filter):
"""
def __init__(self, plugin):
- self.prefix = f'{plugin.name}: '
+ self.prefix = f"{plugin.name}: "
def filter(self, record):
- if hasattr(record.msg, 'msg') and isinstance(record.msg.msg,
- str):
+ if hasattr(record.msg, "msg") and isinstance(record.msg.msg, str):
# A _LogMessage from our hacked-up Logging replacement.
record.msg.msg = self.prefix + record.msg.msg
elif isinstance(record.msg, str):
@@ -65,6 +63,7 @@ class PluginLogFilter(logging.Filter):
# Managing the plugins themselves.
+
class BeetsPlugin:
"""The base class for all beets plugins. Plugins provide
functionality by defining a subclass of BeetsPlugin and overriding
@@ -72,9 +71,8 @@ class BeetsPlugin:
"""
def __init__(self, name=None):
- """Perform one-time plugin setup.
- """
- self.name = name or self.__module__.split('.')[-1]
+ """Perform one-time plugin setup."""
+ self.name = name or self.__module__.split(".")[-1]
self.config = beets.config[self.name]
if not self.template_funcs:
self.template_funcs = {}
@@ -97,10 +95,11 @@ class BeetsPlugin:
return ()
def _set_stage_log_level(self, stages):
- """Adjust all the stages in `stages` to WARNING logging level.
- """
- return [self._set_log_level_and_params(logging.WARNING, stage)
- for stage in stages]
+ """Adjust all the stages in `stages` to WARNING logging level."""
+ return [
+ self._set_log_level_and_params(logging.WARNING, stage)
+ for stage in stages
+ ]
def get_early_import_stages(self):
"""Return a list of functions that should be called as importer
@@ -134,12 +133,11 @@ class BeetsPlugin:
def wrapper(*args, **kwargs):
assert self._log.level == logging.NOTSET
- verbosity = beets.config['verbose'].get(int)
+ verbosity = beets.config["verbose"].get(int)
log_level = max(logging.DEBUG, base_log_level - 10 * verbosity)
self._log.setLevel(log_level)
if argspec.varkw is None:
- kwargs = {k: v for k, v in kwargs.items()
- if k in argspec.args}
+ kwargs = {k: v for k, v in kwargs.items() if k in argspec.args}
try:
return func(*args, **kwargs)
@@ -149,8 +147,7 @@ class BeetsPlugin:
return wrapper
def queries(self):
- """Should return a dict mapping prefixes to Query subclasses.
- """
+ """Return a dict mapping prefixes to Query subclasses."""
return {}
def track_distance(self, item, info):
@@ -201,6 +198,7 @@ class BeetsPlugin:
"""
# Defer import to prevent circular dependency
from beets import library
+
mediafile.MediaFile.add_field(name, descriptor)
library.Item._media_fields.add(name)
@@ -208,8 +206,7 @@ class BeetsPlugin:
listeners = None
def register_listener(self, event, func):
- """Add a function as a listener for the specified event.
- """
+ """Add a function as a listener for the specified event."""
wrapped_func = self._set_log_level_and_params(logging.WARNING, func)
cls = self.__class__
@@ -230,11 +227,13 @@ class BeetsPlugin:
function will be invoked as ``%name{}`` from path format
strings.
"""
+
def helper(func):
if cls.template_funcs is None:
cls.template_funcs = {}
cls.template_funcs[name] = func
return func
+
return helper
@classmethod
@@ -244,11 +243,13 @@ class BeetsPlugin:
strings. The function must accept a single parameter, the Item
being formatted.
"""
+
def helper(func):
if cls.template_fields is None:
cls.template_fields = {}
cls.template_fields[name] = func
return func
+
return helper
@@ -262,25 +263,29 @@ def load_plugins(names=()):
BeetsPlugin subclasses desired.
"""
for name in names:
- modname = f'{PLUGIN_NAMESPACE}.{name}'
+ modname = f"{PLUGIN_NAMESPACE}.{name}"
try:
try:
namespace = __import__(modname, None, None)
except ImportError as exc:
# Again, this is hacky:
- if exc.args[0].endswith(' ' + name):
- log.warning('** plugin {0} not found', name)
+ if exc.args[0].endswith(" " + name):
+ log.warning("** plugin {0} not found", name)
else:
raise
else:
for obj in getattr(namespace, name).__dict__.values():
- if isinstance(obj, type) and issubclass(obj, BeetsPlugin) \
- and obj != BeetsPlugin and obj not in _classes:
+ if (
+ isinstance(obj, type)
+ and issubclass(obj, BeetsPlugin)
+ and obj != BeetsPlugin
+ and obj not in _classes
+ ):
_classes.add(obj)
except Exception:
log.warning(
- '** error loading plugin {}:\n{}',
+ "** error loading plugin {}:\n{}",
name,
traceback.format_exc(),
)
@@ -311,9 +316,9 @@ def find_plugins():
# Communication with plugins.
+
def commands():
- """Returns a list of Subcommand objects from all loaded plugins.
- """
+ """Returns a list of Subcommand objects from all loaded plugins."""
out = []
for plugin in find_plugins():
out += plugin.commands()
@@ -332,16 +337,16 @@ def queries():
def types(model_cls):
# Gives us `item_types` and `album_types`
- attr_name = f'{model_cls.__name__.lower()}_types'
+ attr_name = f"{model_cls.__name__.lower()}_types"
types = {}
for plugin in find_plugins():
plugin_types = getattr(plugin, attr_name, {})
for field in plugin_types:
if field in types and plugin_types[field] != types[field]:
raise PluginConflictException(
- 'Plugin {} defines flexible field {} '
- 'which has already been defined with '
- 'another type.'.format(plugin.name, field)
+ "Plugin {} defines flexible field {} "
+ "which has already been defined with "
+ "another type.".format(plugin.name, field)
)
types.update(plugin_types)
return types
@@ -349,7 +354,7 @@ def types(model_cls):
def named_queries(model_cls):
# Gather `item_queries` and `album_queries` from the plugins.
- attr_name = f'{model_cls.__name__.lower()}_queries'
+ attr_name = f"{model_cls.__name__.lower()}_queries"
queries = {}
for plugin in find_plugins():
plugin_queries = getattr(plugin, attr_name, {})
@@ -362,6 +367,7 @@ def track_distance(item, info):
Returns a Distance object.
"""
from beets.autotag.hooks import Distance
+
dist = Distance()
for plugin in find_plugins():
dist.update(plugin.track_distance(item, info))
@@ -371,6 +377,7 @@ def track_distance(item, info):
def album_distance(items, album_info, mapping):
"""Returns the album distance calculated by plugins."""
from beets.autotag.hooks import Distance
+
dist = Distance()
for plugin in find_plugins():
dist.update(plugin.album_distance(items, album_info, mapping))
@@ -378,23 +385,21 @@ def album_distance(items, album_info, mapping):
def candidates(items, artist, album, va_likely, extra_tags=None):
- """Gets MusicBrainz candidates for an album from each plugin.
- """
+ """Gets MusicBrainz candidates for an album from each plugin."""
for plugin in find_plugins():
- yield from plugin.candidates(items, artist, album, va_likely,
- extra_tags)
+ yield from plugin.candidates(
+ items, artist, album, va_likely, extra_tags
+ )
def item_candidates(item, artist, title):
- """Gets MusicBrainz candidates for an item from the plugins.
- """
+ """Gets MusicBrainz candidates for an item from the plugins."""
for plugin in find_plugins():
yield from plugin.item_candidates(item, artist, title)
def album_for_id(album_id):
- """Get AlbumInfo objects for a given ID string.
- """
+ """Get AlbumInfo objects for a given ID string."""
for plugin in find_plugins():
album = plugin.album_for_id(album_id)
if album:
@@ -402,8 +407,7 @@ def album_for_id(album_id):
def track_for_id(track_id):
- """Get TrackInfo objects for a given ID string.
- """
+ """Get TrackInfo objects for a given ID string."""
for plugin in find_plugins():
track = plugin.track_for_id(track_id)
if track:
@@ -439,29 +443,44 @@ def import_stages():
# New-style (lazy) plugin-provided fields.
+
+def _check_conflicts_and_merge(plugin, plugin_funcs, funcs):
+ """Check the provided template functions for conflicts and merge into funcs.
+
+ Raises a `PluginConflictException` if a plugin defines template functions
+ for fields that another plugin has already defined template functions for.
+ """
+ if plugin_funcs:
+ if not plugin_funcs.keys().isdisjoint(funcs.keys()):
+ conflicted_fields = ", ".join(plugin_funcs.keys() & funcs.keys())
+ raise PluginConflictException(
+ f"Plugin {plugin.name} defines template functions for "
+ f"{conflicted_fields} that conflict with another plugin."
+ )
+ funcs.update(plugin_funcs)
+
+
def item_field_getters():
"""Get a dictionary mapping field names to unary functions that
compute the field's value.
"""
funcs = {}
for plugin in find_plugins():
- if plugin.template_fields:
- funcs.update(plugin.template_fields)
+ _check_conflicts_and_merge(plugin, plugin.template_fields, funcs)
return funcs
def album_field_getters():
- """As above, for album fields.
- """
+ """As above, for album fields."""
funcs = {}
for plugin in find_plugins():
- if plugin.album_template_fields:
- funcs.update(plugin.album_template_fields)
+ _check_conflicts_and_merge(plugin, plugin.album_template_fields, funcs)
return funcs
# Event dispatch.
+
def event_handlers():
"""Find all event handlers from plugins as a dictionary mapping
event names to sequences of callables.
@@ -482,7 +501,7 @@ def send(event, **arguments):
Return a list of non-None values returned from the handlers.
"""
- log.debug('Sending event: {0}', event)
+ log.debug("Sending event: {0}", event)
results = []
for handler in event_handlers()[event]:
result = handler(**arguments)
@@ -497,11 +516,11 @@ def feat_tokens(for_artist=True):
The `for_artist` option determines whether the regex should be
suitable for matching artist fields (the default) or title fields.
"""
- feat_words = ['ft', 'featuring', 'feat', 'feat.', 'ft.']
+ feat_words = ["ft", "featuring", "feat", "feat.", "ft."]
if for_artist:
- feat_words += ['with', 'vs', 'and', 'con', '&']
- return r'(?<=\s)(?:{})(?=\s)'.format(
- '|'.join(re.escape(x) for x in feat_words)
+ feat_words += ["with", "vs", "and", "con", "&"]
+ return r"(?<=\s)(?:{})(?=\s)".format(
+ "|".join(re.escape(x) for x in feat_words)
)
@@ -517,7 +536,7 @@ def sanitize_choices(choices, choices_all):
if s not in seen:
if s in list(choices_all):
res.append(s)
- elif s == '*':
+ elif s == "*":
res.extend(others)
seen.add(s)
return res
@@ -550,11 +569,11 @@ def sanitize_pairs(pairs, pairs_all):
if x not in seen:
seen.add(x)
res.append(x)
- elif k == '*':
+ elif k == "*":
new = [o for o in others if o not in seen]
seen.update(new)
res.extend(new)
- elif v == '*':
+ elif v == "*":
new = [o for o in others if o not in seen and o[0] == k]
seen.update(new)
res.extend(new)
@@ -568,12 +587,15 @@ def notify_info_yielded(event):
Each yielded value is passed to plugins using the 'info' parameter of
'send'.
"""
+
def decorator(generator):
def decorated(*args, **kwargs):
for v in generator(*args, **kwargs):
send(event, info=v)
yield v
+
return decorated
+
return decorator
@@ -583,7 +605,7 @@ def get_distance(config, data_source, info):
"""
dist = beets.autotag.Distance()
if info.data_source == data_source:
- dist.add('source', config['source_weight'].as_number())
+ dist.add("source", config["source_weight"].as_number())
return dist
@@ -620,7 +642,7 @@ def apply_item_changes(lib, item, move, pretend, write):
class MetadataSourcePlugin(metaclass=abc.ABCMeta):
def __init__(self):
super().__init__()
- self.config.add({'source_weight': 0.5})
+ self.config.add({"source_weight": 0.5})
@abc.abstractproperty
def id_regex(self):
@@ -643,7 +665,7 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
raise NotImplementedError
@abc.abstractmethod
- def _search_api(self, query_type, filters, keywords=''):
+ def _search_api(self, query_type, filters, keywords=""):
raise NotImplementedError
@abc.abstractmethod
@@ -655,7 +677,7 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
raise NotImplementedError
@staticmethod
- def get_artist(artists, id_key='id', name_key='name'):
+ def get_artist(artists, id_key="id", name_key="name", join_key=None):
"""Returns an artist string (all artists) and an artist_id (the main
artist) for a list of artist object dicts.
@@ -663,6 +685,8 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
and 'the') to the front and strips trailing disambiguation numbers. It
returns a tuple containing the comma-separated string of all
normalized artists and the ``id`` of the main/first artist.
+ Alternatively a keyword can be used to combine artists together into a
+ single string by passing the join_key argument.
:param artists: Iterable of artist dicts or lists returned by API.
:type artists: list[dict] or list[list]
@@ -673,39 +697,55 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
to concatenate for the artist string (containing all artists).
Defaults to 'name'.
:type name_key: str or int
+ :param join_key: Key or index corresponding to a field containing a
+ keyword to use for combining artists into a single string, for
+ example "Feat.", "Vs.", "And" or similar. The default is None
+ which keeps the default behaviour (comma-separated).
+ :type join_key: str or int
:return: Normalized artist string.
:rtype: str
"""
artist_id = None
- artist_names = []
- for artist in artists:
+ artist_string = ""
+ artists = list(artists) # In case a generator was passed.
+ total = len(artists)
+ for idx, artist in enumerate(artists):
if not artist_id:
artist_id = artist[id_key]
name = artist[name_key]
# Strip disambiguation number.
- name = re.sub(r' \(\d+\)$', '', name)
+ name = re.sub(r" \(\d+\)$", "", name)
# Move articles to the front.
- name = re.sub(r'^(.*?), (a|an|the)$', r'\2 \1', name, flags=re.I)
- artist_names.append(name)
- artist = ', '.join(artist_names).replace(' ,', ',') or None
- return artist, artist_id
+ name = re.sub(r"^(.*?), (a|an|the)$", r"\2 \1", name, flags=re.I)
+ # Use a join keyword if requested and available.
+ if idx < (total - 1): # Skip joining on last.
+ if join_key and artist.get(join_key, None):
+ name += f" {artist[join_key]} "
+ else:
+ name += ", "
+ artist_string += name
- def _get_id(self, url_type, id_):
+ return artist_string, artist_id
+
+ @staticmethod
+ def _get_id(url_type, id_, id_regex):
"""Parse an ID from its URL if necessary.
:param url_type: Type of URL. Either 'album' or 'track'.
:type url_type: str
:param id_: Album/track ID or URL.
:type id_: str
+ :param id_regex: A dictionary containing a regular expression
+ extracting an ID from an URL (if it's not an ID already) in
+ 'pattern' and the number of the match group in 'match_group'.
+ :type id_regex: dict
:return: Album/track ID.
:rtype: str
"""
- self._log.debug(
- "Searching {} for {} '{}'", self.data_source, url_type, id_
- )
- match = re.search(self.id_regex['pattern'].format(url_type), str(id_))
+ log.debug("Extracting {} ID from '{}'", url_type, id_)
+ match = re.search(id_regex["pattern"].format(url_type), str(id_))
if match:
- id_ = match.group(self.id_regex['match_group'])
+ id_ = match.group(id_regex["match_group"])
if id_:
return id_
return None
@@ -726,11 +766,11 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
:return: Candidate AlbumInfo objects.
:rtype: list[beets.autotag.hooks.AlbumInfo]
"""
- query_filters = {'album': album}
+ query_filters = {"album": album}
if not va_likely:
- query_filters['artist'] = artist
- results = self._search_api(query_type='album', filters=query_filters)
- albums = [self.album_for_id(album_id=r['id']) for r in results]
+ query_filters["artist"] = artist
+ results = self._search_api(query_type="album", filters=query_filters)
+ albums = [self.album_for_id(album_id=r["id"]) for r in results]
return [a for a in albums if a is not None]
def item_candidates(self, item, artist, title):
@@ -747,7 +787,7 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
:rtype: list[beets.autotag.hooks.TrackInfo]
"""
tracks = self._search_api(
- query_type='track', keywords=title, filters={'artist': artist}
+ query_type="track", keywords=title, filters={"artist": artist}
)
return [self.track_for_id(track_data=track) for track in tracks]
diff --git a/lib/beets/random.py b/lib/beets/random.py
index eb4f55af..f3318054 100644
--- a/lib/beets/random.py
+++ b/lib/beets/random.py
@@ -12,24 +12,22 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
-"""Get a random song or album from the library.
-"""
+"""Get a random song or album from the library."""
import random
-from operator import attrgetter
from itertools import groupby
+from operator import attrgetter
def _length(obj, album):
- """Get the duration of an item or album.
- """
+ """Get the duration of an item or album."""
if album:
return sum(i.length for i in obj.items())
else:
return obj.length
-def _equal_chance_permutation(objs, field='albumartist', random_gen=None):
+def _equal_chance_permutation(objs, field="albumartist", random_gen=None):
"""Generate (lazily) a permutation of the objects where every group
with equal values for `field` have an equal chance of appearing in
any given position.
@@ -86,8 +84,9 @@ def _take_time(iter, secs, album):
return out
-def random_objs(objs, album, number=1, time=None, equal_chance=False,
- random_gen=None):
+def random_objs(
+ objs, album, number=1, time=None, equal_chance=False, random_gen=None
+):
"""Get a random subset of the provided `objs`.
If `number` is provided, produce that many matches. Otherwise, if
diff --git a/lib/beets/util/confit.py b/lib/beets/test/__init__.py
similarity index 62%
rename from lib/beets/util/confit.py
rename to lib/beets/test/__init__.py
index dd912c44..2af37583 100644
--- a/lib/beets/util/confit.py
+++ b/lib/beets/test/__init__.py
@@ -1,5 +1,5 @@
# This file is part of beets.
-# Copyright 2016-2019, Adrian Sampson.
+# Copyright 2024, Lars Kruse
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
@@ -12,17 +12,8 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
-
-import confuse
-
-import warnings
-warnings.warn("beets.util.confit is deprecated; use confuse instead")
-
-# Import everything from the confuse module into this module.
-for key, value in confuse.__dict__.items():
- if key not in ['__name__']:
- globals()[key] = value
-
-
-# Cleanup namespace.
-del key, value, warnings, confuse
+"""This module contains components of beets' test environment, which
+may be of use for testing procedures of external libraries or programs.
+For example the 'TestHelper' class may be useful for creating an
+in-memory beets library filled with a few example items.
+"""
diff --git a/lib/beets/test/_common.py b/lib/beets/test/_common.py
new file mode 100644
index 00000000..50dbde43
--- /dev/null
+++ b/lib/beets/test/_common.py
@@ -0,0 +1,328 @@
+# This file is part of beets.
+# Copyright 2016, Adrian Sampson.
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the
+# "Software"), to deal in the Software without restriction, including
+# without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to
+# permit persons to whom the Software is furnished to do so, subject to
+# the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+
+"""Some common functionality for beets' test cases."""
+
+import os
+import sys
+import tempfile
+import unittest
+from contextlib import contextmanager
+
+import beets
+import beets.library
+
+# Make sure the development versions of the plugins are used
+import beetsplug
+from beets import importer, logging, util
+from beets.ui import commands
+from beets.util import syspath
+
+beetsplug.__path__ = [
+ os.path.abspath(
+ os.path.join(
+ os.path.dirname(__file__),
+ os.path.pardir,
+ os.path.pardir,
+ "beetsplug",
+ )
+ )
+]
+
+# Test resources path.
+RSRC = util.bytestring_path(
+ os.path.abspath(
+ os.path.join(
+ os.path.dirname(__file__),
+ os.path.pardir,
+ os.path.pardir,
+ "test",
+ "rsrc",
+ )
+ )
+)
+PLUGINPATH = os.path.join(RSRC.decode(), "beetsplug")
+
+# Propagate to root logger so the test runner can capture it
+log = logging.getLogger("beets")
+log.propagate = True
+log.setLevel(logging.DEBUG)
+
+# Dummy item creation.
+_item_ident = 0
+
+# OS feature test.
+HAVE_SYMLINK = sys.platform != "win32"
+HAVE_HARDLINK = sys.platform != "win32"
+
+try:
+ import reflink
+
+ HAVE_REFLINK = reflink.supported_at(tempfile.gettempdir())
+except ImportError:
+ HAVE_REFLINK = False
+
+
+def item(lib=None):
+ global _item_ident
+ _item_ident += 1
+ i = beets.library.Item(
+ title="the title",
+ artist="the artist",
+ albumartist="the album artist",
+ album="the album",
+ genre="the genre",
+ lyricist="the lyricist",
+ composer="the composer",
+ arranger="the arranger",
+ grouping="the grouping",
+ work="the work title",
+ mb_workid="the work musicbrainz id",
+ work_disambig="the work disambiguation",
+ year=1,
+ month=2,
+ day=3,
+ track=4,
+ tracktotal=5,
+ disc=6,
+ disctotal=7,
+ lyrics="the lyrics",
+ comments="the comments",
+ bpm=8,
+ comp=True,
+ path=f"somepath{_item_ident}",
+ length=60.0,
+ bitrate=128000,
+ format="FLAC",
+ mb_trackid="someID-1",
+ mb_albumid="someID-2",
+ mb_artistid="someID-3",
+ mb_albumartistid="someID-4",
+ mb_releasetrackid="someID-5",
+ album_id=None,
+ mtime=12345,
+ )
+ if lib:
+ lib.add(i)
+ return i
+
+
+def album(lib=None):
+ global _item_ident
+ _item_ident += 1
+ i = beets.library.Album(
+ artpath=None,
+ albumartist="some album artist",
+ albumartist_sort="some sort album artist",
+ albumartist_credit="some album artist credit",
+ album="the album",
+ genre="the genre",
+ year=2014,
+ month=2,
+ day=5,
+ tracktotal=0,
+ disctotal=1,
+ comp=False,
+ mb_albumid="someID-1",
+ mb_albumartistid="someID-1",
+ )
+ if lib:
+ lib.add(i)
+ return i
+
+
+# Dummy import session.
+def import_session(lib=None, loghandler=None, paths=[], query=[], cli=False):
+ cls = commands.TerminalImportSession if cli else importer.ImportSession
+ return cls(lib, loghandler, paths, query)
+
+
+class Assertions:
+ """A mixin with additional unit test assertions."""
+
+ def assertExists(self, path): # noqa
+ assert os.path.exists(syspath(path)), f"file does not exist: {path!r}"
+
+ def assertNotExists(self, path): # noqa
+ assert not os.path.exists(syspath(path)), f"file exists: {path!r}"
+
+ def assertIsFile(self, path): # noqa
+ self.assertExists(path)
+ assert os.path.isfile(
+ syspath(path)
+ ), "path exists, but is not a regular file: {!r}".format(path)
+
+ def assertIsDir(self, path): # noqa
+ self.assertExists(path)
+ assert os.path.isdir(
+ syspath(path)
+ ), "path exists, but is not a directory: {!r}".format(path)
+
+ def assert_equal_path(self, a, b):
+ """Check that two paths are equal."""
+ a_bytes, b_bytes = util.normpath(a), util.normpath(b)
+
+ assert a_bytes == b_bytes, f"{a_bytes=} != {b_bytes=}"
+
+
+# Mock I/O.
+
+
+class InputException(Exception):
+ def __init__(self, output=None):
+ self.output = output
+
+ def __str__(self):
+ msg = "Attempt to read with no input provided."
+ if self.output is not None:
+ msg += f" Output: {self.output!r}"
+ return msg
+
+
+class DummyOut:
+ encoding = "utf-8"
+
+ def __init__(self):
+ self.buf = []
+
+ def write(self, s):
+ self.buf.append(s)
+
+ def get(self):
+ return "".join(self.buf)
+
+ def flush(self):
+ self.clear()
+
+ def clear(self):
+ self.buf = []
+
+
+class DummyIn:
+ encoding = "utf-8"
+
+ def __init__(self, out=None):
+ self.buf = []
+ self.reads = 0
+ self.out = out
+
+ def add(self, s):
+ self.buf.append(s + "\n")
+
+ def close(self):
+ pass
+
+ def readline(self):
+ if not self.buf:
+ if self.out:
+ raise InputException(self.out.get())
+ else:
+ raise InputException()
+ self.reads += 1
+ return self.buf.pop(0)
+
+
+class DummyIO:
+ """Mocks input and output streams for testing UI code."""
+
+ def __init__(self):
+ self.stdout = DummyOut()
+ self.stdin = DummyIn(self.stdout)
+
+ def addinput(self, s):
+ self.stdin.add(s)
+
+ def getoutput(self):
+ res = self.stdout.get()
+ self.stdout.clear()
+ return res
+
+ def readcount(self):
+ return self.stdin.reads
+
+ def install(self):
+ sys.stdin = self.stdin
+ sys.stdout = self.stdout
+
+ def restore(self):
+ sys.stdin = sys.__stdin__
+ sys.stdout = sys.__stdout__
+
+
+# Utility.
+
+
+def touch(path):
+ open(syspath(path), "a").close()
+
+
+class Bag:
+ """An object that exposes a set of fields given as keyword
+ arguments. Any field not found in the dictionary appears to be None.
+ Used for mocking Album objects and the like.
+ """
+
+ def __init__(self, **fields):
+ self.fields = fields
+
+ def __getattr__(self, key):
+ return self.fields.get(key)
+
+
+# Platform mocking.
+
+
+@contextmanager
+def platform_windows():
+ import ntpath
+
+ old_path = os.path
+ try:
+ os.path = ntpath
+ yield
+ finally:
+ os.path = old_path
+
+
+@contextmanager
+def platform_posix():
+ import posixpath
+
+ old_path = os.path
+ try:
+ os.path = posixpath
+ yield
+ finally:
+ os.path = old_path
+
+
+@contextmanager
+def system_mock(name):
+ import platform
+
+ old_system = platform.system
+ platform.system = lambda: name
+ try:
+ yield
+ finally:
+ platform.system = old_system
+
+
+def slow_test(unused=None):
+ def _id(obj):
+ return obj
+
+ if "SKIP_SLOW_TESTS" in os.environ:
+ return unittest.skip("test is slow")
+ return _id
diff --git a/lib/beets/test/helper.py b/lib/beets/test/helper.py
new file mode 100644
index 00000000..470498b5
--- /dev/null
+++ b/lib/beets/test/helper.py
@@ -0,0 +1,1003 @@
+# This file is part of beets.
+# Copyright 2016, Thomas Scholtes.
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the
+# "Software"), to deal in the Software without restriction, including
+# without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to
+# permit persons to whom the Software is furnished to do so, subject to
+# the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+
+"""This module includes various helpers that provide fixtures, capture
+information or mock the environment.
+
+- The `control_stdin` and `capture_stdout` context managers allow one to
+ interact with the user interface.
+
+- `has_program` checks the presence of a command on the system.
+
+- The `generate_album_info` and `generate_track_info` functions return
+ fixtures to be used when mocking the autotagger.
+
+- The `ImportSessionFixture` allows one to run importer code while
+ controlling the interactions through code.
+
+- The `TestHelper` class encapsulates various fixtures that can be set up.
+"""
+
+from __future__ import annotations
+
+import os
+import os.path
+import shutil
+import subprocess
+import sys
+import unittest
+from contextlib import contextmanager
+from enum import Enum
+from functools import cached_property
+from io import StringIO
+from pathlib import Path
+from tempfile import mkdtemp, mkstemp
+from typing import Any, ClassVar
+from unittest.mock import patch
+
+import responses
+from mediafile import Image, MediaFile
+
+import beets
+import beets.plugins
+from beets import autotag, config, importer, logging, util
+from beets.autotag.hooks import AlbumInfo, TrackInfo
+from beets.importer import ImportSession
+from beets.library import Album, Item, Library
+from beets.test import _common
+from beets.ui.commands import TerminalImportSession
+from beets.util import (
+ MoveOperation,
+ bytestring_path,
+ clean_module_tempdir,
+ syspath,
+)
+
+
+class LogCapture(logging.Handler):
+ def __init__(self):
+ logging.Handler.__init__(self)
+ self.messages = []
+
+ def emit(self, record):
+ self.messages.append(str(record.msg))
+
+
+@contextmanager
+def capture_log(logger="beets"):
+ capture = LogCapture()
+ log = logging.getLogger(logger)
+ log.addHandler(capture)
+ try:
+ yield capture.messages
+ finally:
+ log.removeHandler(capture)
+
+
+@contextmanager
+def control_stdin(input=None):
+ """Sends ``input`` to stdin.
+
+ >>> with control_stdin('yes'):
+ ... input()
+ 'yes'
+ """
+ org = sys.stdin
+ sys.stdin = StringIO(input)
+ try:
+ yield sys.stdin
+ finally:
+ sys.stdin = org
+
+
+@contextmanager
+def capture_stdout():
+ """Save stdout in a StringIO.
+
+ >>> with capture_stdout() as output:
+ ... print('spam')
+ ...
+ >>> output.getvalue()
+ 'spam'
+ """
+ org = sys.stdout
+ sys.stdout = capture = StringIO()
+ try:
+ yield sys.stdout
+ finally:
+ sys.stdout = org
+ print(capture.getvalue())
+
+
+def _convert_args(args):
+ """Convert args to bytestrings for Python 2 and convert them to strings
+ on Python 3.
+ """
+ for i, elem in enumerate(args):
+ if isinstance(elem, bytes):
+ args[i] = elem.decode(util.arg_encoding())
+
+ return args
+
+
+def has_program(cmd, args=["--version"]):
+ """Returns `True` if `cmd` can be executed."""
+ full_cmd = _convert_args([cmd] + args)
+ try:
+ with open(os.devnull, "wb") as devnull:
+ subprocess.check_call(
+ full_cmd, stderr=devnull, stdout=devnull, stdin=devnull
+ )
+ except OSError:
+ return False
+ except subprocess.CalledProcessError:
+ return False
+ else:
+ return True
+
+
+class TestHelper(_common.Assertions):
+ """Helper mixin for high-level cli and plugin tests.
+
+ This mixin provides methods to isolate beets' global state provide
+ fixtures.
+ """
+
+ db_on_disk: ClassVar[bool] = False
+
+ # TODO automate teardown through hook registration
+
+ def setup_beets(self):
+ """Setup pristine global configuration and library for testing.
+
+ Sets ``beets.config`` so we can safely use any functionality
+ that uses the global configuration. All paths used are
+ contained in a temporary directory
+
+ Sets the following properties on itself.
+
+ - ``temp_dir`` Path to a temporary directory containing all
+ files specific to beets
+
+ - ``libdir`` Path to a subfolder of ``temp_dir``, containing the
+ library's media files. Same as ``config['directory']``.
+
+ - ``config`` The global configuration used by beets.
+
+ - ``lib`` Library instance created with the settings from
+ ``config``.
+
+ Make sure you call ``teardown_beets()`` afterwards.
+ """
+ self.create_temp_dir()
+ temp_dir_str = os.fsdecode(self.temp_dir)
+ self.env_patcher = patch.dict(
+ "os.environ",
+ {
+ "BEETSDIR": temp_dir_str,
+ "HOME": temp_dir_str, # used by Confuse to create directories.
+ },
+ )
+ self.env_patcher.start()
+
+ self.config = beets.config
+ self.config.sources = []
+ self.config.read(user=False, defaults=True)
+
+ self.config["plugins"] = []
+ self.config["verbose"] = 1
+ self.config["ui"]["color"] = False
+ self.config["threaded"] = False
+
+ self.libdir = os.path.join(self.temp_dir, b"libdir")
+ os.mkdir(syspath(self.libdir))
+ self.config["directory"] = os.fsdecode(self.libdir)
+
+ if self.db_on_disk:
+ dbpath = util.bytestring_path(self.config["library"].as_filename())
+ else:
+ dbpath = ":memory:"
+ self.lib = Library(dbpath, self.libdir)
+
+ # Initialize, but don't install, a DummyIO.
+ self.io = _common.DummyIO()
+
+ def teardown_beets(self):
+ self.env_patcher.stop()
+ self.io.restore()
+ self.lib._close()
+ self.remove_temp_dir()
+ beets.config.clear()
+ beets.config._materialized = False
+
+ # Library fixtures methods
+
+ def create_item(self, **values):
+ """Return an `Item` instance with sensible default values.
+
+ The item receives its attributes from `**values` paratmeter. The
+ `title`, `artist`, `album`, `track`, `format` and `path`
+ attributes have defaults if they are not given as parameters.
+ The `title` attribute is formatted with a running item count to
+ prevent duplicates. The default for the `path` attribute
+ respects the `format` value.
+
+ The item is attached to the database from `self.lib`.
+ """
+ item_count = self._get_item_count()
+ values_ = {
+ "title": "t\u00eftle {0}",
+ "artist": "the \u00e4rtist",
+ "album": "the \u00e4lbum",
+ "track": item_count,
+ "format": "MP3",
+ }
+ values_.update(values)
+ values_["title"] = values_["title"].format(item_count)
+ values_["db"] = self.lib
+ item = Item(**values_)
+ if "path" not in values:
+ item["path"] = "audio." + item["format"].lower()
+ # mtime needs to be set last since other assignments reset it.
+ item.mtime = 12345
+ return item
+
+ def add_item(self, **values):
+ """Add an item to the library and return it.
+
+ Creates the item by passing the parameters to `create_item()`.
+
+ If `path` is not set in `values` it is set to `item.destination()`.
+ """
+ # When specifying a path, store it normalized (as beets does
+ # ordinarily).
+ if "path" in values:
+ values["path"] = util.normpath(values["path"])
+
+ item = self.create_item(**values)
+ item.add(self.lib)
+
+ # Ensure every item has a path.
+ if "path" not in values:
+ item["path"] = item.destination()
+ item.store()
+
+ return item
+
+ def add_item_fixture(self, **values):
+ """Add an item with an actual audio file to the library."""
+ item = self.create_item(**values)
+ extension = item["format"].lower()
+ item["path"] = os.path.join(
+ _common.RSRC, util.bytestring_path("min." + extension)
+ )
+ item.add(self.lib)
+ item.move(operation=MoveOperation.COPY)
+ item.store()
+ return item
+
+ def add_album(self, **values):
+ item = self.add_item(**values)
+ return self.lib.add_album([item])
+
+ def add_item_fixtures(self, ext="mp3", count=1):
+ """Add a number of items with files to the database."""
+ # TODO base this on `add_item()`
+ items = []
+ path = os.path.join(_common.RSRC, util.bytestring_path("full." + ext))
+ for i in range(count):
+ item = Item.from_path(path)
+ item.album = f"\u00e4lbum {i}" # Check unicode paths
+ item.title = f"t\u00eftle {i}"
+ # mtime needs to be set last since other assignments reset it.
+ item.mtime = 12345
+ item.add(self.lib)
+ item.move(operation=MoveOperation.COPY)
+ item.store()
+ items.append(item)
+ return items
+
+ def add_album_fixture(
+ self,
+ track_count=1,
+ fname="full",
+ ext="mp3",
+ disc_count=1,
+ ):
+ """Add an album with files to the database."""
+ items = []
+ path = os.path.join(
+ _common.RSRC,
+ util.bytestring_path(f"{fname}.{ext}"),
+ )
+ for discnumber in range(1, disc_count + 1):
+ for i in range(track_count):
+ item = Item.from_path(path)
+ item.album = "\u00e4lbum" # Check unicode paths
+ item.title = f"t\u00eftle {i}"
+ item.disc = discnumber
+ # mtime needs to be set last since other assignments reset it.
+ item.mtime = 12345
+ item.add(self.lib)
+ item.move(operation=MoveOperation.COPY)
+ item.store()
+ items.append(item)
+ return self.lib.add_album(items)
+
+ def create_mediafile_fixture(self, ext="mp3", images=[]):
+ """Copy a fixture mediafile with the extension to `temp_dir`.
+
+ `images` is a subset of 'png', 'jpg', and 'tiff'. For each
+ specified extension a cover art image is added to the media
+ file.
+ """
+ src = os.path.join(_common.RSRC, util.bytestring_path("full." + ext))
+ handle, path = mkstemp(dir=self.temp_dir)
+ path = bytestring_path(path)
+ os.close(handle)
+ shutil.copyfile(syspath(src), syspath(path))
+
+ if images:
+ mediafile = MediaFile(path)
+ imgs = []
+ for img_ext in images:
+ file = util.bytestring_path(f"image-2x3.{img_ext}")
+ img_path = os.path.join(_common.RSRC, file)
+ with open(img_path, "rb") as f:
+ imgs.append(Image(f.read()))
+ mediafile.images = imgs
+ mediafile.save()
+
+ return path
+
+ def _get_item_count(self):
+ if not hasattr(self, "__item_count"):
+ count = 0
+ self.__item_count = count + 1
+ return count
+
+ # Running beets commands
+
+ def run_command(self, *args, **kwargs):
+ """Run a beets command with an arbitrary amount of arguments. The
+ Library` defaults to `self.lib`, but can be overridden with
+ the keyword argument `lib`.
+ """
+ sys.argv = ["beet"] # avoid leakage from test suite args
+ lib = None
+ if hasattr(self, "lib"):
+ lib = self.lib
+ lib = kwargs.get("lib", lib)
+ beets.ui._raw_main(_convert_args(list(args)), lib)
+
+ def run_with_output(self, *args):
+ with capture_stdout() as out:
+ self.run_command(*args)
+ return out.getvalue()
+
+ # Safe file operations
+
+ def create_temp_dir(self, **kwargs):
+ """Create a temporary directory and assign it into
+ `self.temp_dir`. Call `remove_temp_dir` later to delete it.
+ """
+ temp_dir = mkdtemp(**kwargs)
+ self.temp_dir = util.bytestring_path(temp_dir)
+
+ def remove_temp_dir(self):
+ """Delete the temporary directory created by `create_temp_dir`."""
+ shutil.rmtree(syspath(self.temp_dir))
+
+ def touch(self, path, dir=None, content=""):
+ """Create a file at `path` with given content.
+
+ If `dir` is given, it is prepended to `path`. After that, if the
+ path is relative, it is resolved with respect to
+ `self.temp_dir`.
+ """
+ if dir:
+ path = os.path.join(dir, path)
+
+ if not os.path.isabs(path):
+ path = os.path.join(self.temp_dir, path)
+
+ parent = os.path.dirname(path)
+ if not os.path.isdir(syspath(parent)):
+ os.makedirs(syspath(parent))
+
+ with open(syspath(path), "a+") as f:
+ f.write(content)
+ return path
+
+
+# A test harness for all beets tests.
+# Provides temporary, isolated configuration.
+class BeetsTestCase(unittest.TestCase, TestHelper):
+ """A unittest.TestCase subclass that saves and restores beets'
+ global configuration. This allows tests to make temporary
+ modifications that will then be automatically removed when the test
+ completes. Also provides some additional assertion methods, a
+ temporary directory, and a DummyIO.
+ """
+
+ def setUp(self):
+ self.setup_beets()
+
+ def tearDown(self):
+ self.teardown_beets()
+
+
+class ItemInDBTestCase(BeetsTestCase):
+ """A test case that includes an in-memory library object (`lib`) and
+ an item added to the library (`i`).
+ """
+
+ def setUp(self):
+ super().setUp()
+ self.i = _common.item(self.lib)
+
+
+class PluginMixin:
+ plugin: ClassVar[str]
+ preload_plugin: ClassVar[bool] = True
+
+ def setUp(self):
+ super().setUp()
+ if self.preload_plugin:
+ self.load_plugins()
+
+ def tearDown(self):
+ super().tearDown()
+ self.unload_plugins()
+
+ def load_plugins(self, *plugins: str) -> None:
+ """Load and initialize plugins by names.
+
+ Similar setting a list of plugins in the configuration. Make
+ sure you call ``unload_plugins()`` afterwards.
+ """
+ # FIXME this should eventually be handled by a plugin manager
+ plugins = (self.plugin,) if hasattr(self, "plugin") else plugins
+ beets.config["plugins"] = plugins
+ beets.plugins.load_plugins(plugins)
+ beets.plugins.find_plugins()
+
+ # Take a backup of the original _types and _queries to restore
+ # when unloading.
+ Item._original_types = dict(Item._types)
+ Album._original_types = dict(Album._types)
+ Item._types.update(beets.plugins.types(Item))
+ Album._types.update(beets.plugins.types(Album))
+
+ Item._original_queries = dict(Item._queries)
+ Album._original_queries = dict(Album._queries)
+ Item._queries.update(beets.plugins.named_queries(Item))
+ Album._queries.update(beets.plugins.named_queries(Album))
+
+ def unload_plugins(self) -> None:
+ """Unload all plugins and remove them from the configuration."""
+ # FIXME this should eventually be handled by a plugin manager
+ for plugin_class in beets.plugins._instances:
+ plugin_class.listeners = None
+ beets.config["plugins"] = []
+ beets.plugins._classes = set()
+ beets.plugins._instances = {}
+ Item._types = getattr(Item, "_original_types", {})
+ Album._types = getattr(Album, "_original_types", {})
+ Item._queries = getattr(Item, "_original_queries", {})
+ Album._queries = getattr(Album, "_original_queries", {})
+
+ @contextmanager
+ def configure_plugin(self, config: list[Any] | dict[str, Any]):
+ if isinstance(config, list):
+ beets.config[self.plugin] = config
+ else:
+ for key, value in config.items():
+ beets.config[self.plugin][key] = value
+ self.load_plugins(self.plugin)
+
+ yield
+
+ self.unload_plugins()
+
+
+class PluginTestCase(PluginMixin, BeetsTestCase):
+ pass
+
+
+class ImportHelper(TestHelper):
+ """Provides tools to setup a library, a directory containing files that are
+ to be imported and an import session. The class also provides stubs for the
+ autotagging library and several assertions for the library.
+ """
+
+ resource_path = syspath(os.path.join(_common.RSRC, b"full.mp3"))
+ default_import_config = {
+ "autotag": True,
+ "copy": True,
+ "hardlink": False,
+ "link": False,
+ "move": False,
+ "resume": False,
+ "singletons": False,
+ "timid": True,
+ }
+
+ lib: Library
+ importer: ImportSession
+
+ @cached_property
+ def import_path(self) -> Path:
+ import_path = Path(os.fsdecode(self.temp_dir)) / "import"
+ import_path.mkdir(exist_ok=True)
+ return import_path
+
+ @cached_property
+ def import_dir(self) -> bytes:
+ return bytestring_path(self.import_path)
+
+ def setUp(self):
+ super().setUp()
+ self.import_media = []
+ self.lib.path_formats = [
+ ("default", os.path.join("$artist", "$album", "$title")),
+ ("singleton:true", os.path.join("singletons", "$title")),
+ ("comp:true", os.path.join("compilations", "$album", "$title")),
+ ]
+
+ def prepare_track_for_import(
+ self,
+ track_id: int,
+ album_path: Path,
+ album_id: int | None = None,
+ ) -> Path:
+ track_path = album_path / f"track_{track_id}.mp3"
+ shutil.copy(self.resource_path, track_path)
+ medium = MediaFile(track_path)
+ medium.update(
+ {
+ "album": "Tag Album" + (f" {album_id}" if album_id else ""),
+ "albumartist": None,
+ "mb_albumid": None,
+ "comp": None,
+ "artist": "Tag Artist",
+ "title": f"Tag Track {track_id}",
+ "track": track_id,
+ "mb_trackid": None,
+ }
+ )
+ medium.save()
+ self.import_media.append(medium)
+ return track_path
+
+ def prepare_album_for_import(
+ self,
+ item_count: int,
+ album_id: int | None = None,
+ album_path: Path | None = None,
+ ) -> list[Path]:
+ """Create an album directory with media files to import.
+
+ The directory has following layout
+ album/
+ track_1.mp3
+ track_2.mp3
+ track_3.mp3
+ """
+ if not album_path:
+ album_dir = f"album_{album_id}" if album_id else "album"
+ album_path = self.import_path / album_dir
+
+ album_path.mkdir(exist_ok=True)
+
+ return [
+ self.prepare_track_for_import(tid, album_path, album_id=album_id)
+ for tid in range(1, item_count + 1)
+ ]
+
+ def prepare_albums_for_import(self, count: int = 1) -> None:
+ album_dirs = Path(os.fsdecode(self.import_dir)).glob("album_*")
+ base_idx = int(str(max(album_dirs, default="0")).split("_")[-1]) + 1
+
+ for album_id in range(base_idx, count + base_idx):
+ self.prepare_album_for_import(1, album_id=album_id)
+
+ def _get_import_session(self, import_dir: bytes) -> ImportSession:
+ return ImportSessionFixture(
+ self.lib,
+ loghandler=None,
+ query=None,
+ paths=[import_dir],
+ )
+
+ def setup_importer(
+ self, import_dir: bytes | None = None, **kwargs
+ ) -> ImportSession:
+ config["import"].set_args({**self.default_import_config, **kwargs})
+ self.importer = self._get_import_session(import_dir or self.import_dir)
+ return self.importer
+
+ def setup_singleton_importer(self, **kwargs) -> ImportSession:
+ return self.setup_importer(singletons=True, **kwargs)
+
+ def assert_file_in_lib(self, *segments):
+ """Join the ``segments`` and assert that this path exists in the
+ library directory.
+ """
+ self.assertExists(os.path.join(self.libdir, *segments))
+
+ def assert_file_not_in_lib(self, *segments):
+ """Join the ``segments`` and assert that this path does not
+ exist in the library directory.
+ """
+ self.assertNotExists(os.path.join(self.libdir, *segments))
+
+ def assert_lib_dir_empty(self):
+ assert not os.listdir(syspath(self.libdir))
+
+
+class AsIsImporterMixin:
+ def setUp(self):
+ super().setUp()
+ self.prepare_album_for_import(1)
+
+ def run_asis_importer(self, **kwargs):
+ importer = self.setup_importer(autotag=False, **kwargs)
+ importer.run()
+ return importer
+
+
+class ImportTestCase(ImportHelper, BeetsTestCase):
+ pass
+
+
+class ImportSessionFixture(ImportSession):
+ """ImportSession that can be controlled programaticaly.
+
+ >>> lib = Library(':memory:')
+ >>> importer = ImportSessionFixture(lib, paths=['/path/to/import'])
+ >>> importer.add_choice(importer.action.SKIP)
+ >>> importer.add_choice(importer.action.ASIS)
+ >>> importer.default_choice = importer.action.APPLY
+ >>> importer.run()
+
+ This imports ``/path/to/import`` into `lib`. It skips the first
+ album and imports thesecond one with metadata from the tags. For the
+ remaining albums, the metadata from the autotagger will be applied.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._choices = []
+ self._resolutions = []
+
+ default_choice = importer.action.APPLY
+
+ def add_choice(self, choice):
+ self._choices.append(choice)
+
+ def clear_choices(self):
+ self._choices = []
+
+ def choose_match(self, task):
+ try:
+ choice = self._choices.pop(0)
+ except IndexError:
+ choice = self.default_choice
+
+ if choice == importer.action.APPLY:
+ return task.candidates[0]
+ elif isinstance(choice, int):
+ return task.candidates[choice - 1]
+ else:
+ return choice
+
+ choose_item = choose_match
+
+ Resolution = Enum("Resolution", "REMOVE SKIP KEEPBOTH MERGE")
+
+ default_resolution = "REMOVE"
+
+ def add_resolution(self, resolution):
+ assert isinstance(resolution, self.Resolution)
+ self._resolutions.append(resolution)
+
+ def resolve_duplicate(self, task, found_duplicates):
+ try:
+ res = self._resolutions.pop(0)
+ except IndexError:
+ res = self.default_resolution
+
+ if res == self.Resolution.SKIP:
+ task.set_choice(importer.action.SKIP)
+ elif res == self.Resolution.REMOVE:
+ task.should_remove_duplicates = True
+ elif res == self.Resolution.MERGE:
+ task.should_merge_duplicates = True
+
+
+class TerminalImportSessionFixture(TerminalImportSession):
+ def __init__(self, *args, **kwargs):
+ self.io = kwargs.pop("io")
+ super().__init__(*args, **kwargs)
+ self._choices = []
+
+ default_choice = importer.action.APPLY
+
+ def add_choice(self, choice):
+ self._choices.append(choice)
+
+ def clear_choices(self):
+ self._choices = []
+
+ def choose_match(self, task):
+ self._add_choice_input()
+ return super().choose_match(task)
+
+ def choose_item(self, task):
+ self._add_choice_input()
+ return super().choose_item(task)
+
+ def _add_choice_input(self):
+ try:
+ choice = self._choices.pop(0)
+ except IndexError:
+ choice = self.default_choice
+
+ if choice == importer.action.APPLY:
+ self.io.addinput("A")
+ elif choice == importer.action.ASIS:
+ self.io.addinput("U")
+ elif choice == importer.action.ALBUMS:
+ self.io.addinput("G")
+ elif choice == importer.action.TRACKS:
+ self.io.addinput("T")
+ elif choice == importer.action.SKIP:
+ self.io.addinput("S")
+ elif isinstance(choice, int):
+ self.io.addinput("M")
+ self.io.addinput(str(choice))
+ self._add_choice_input()
+ else:
+ raise Exception("Unknown choice %s" % choice)
+
+
+class TerminalImportMixin(ImportHelper):
+ """Provides_a terminal importer for the import session."""
+
+ io: _common.DummyIO
+
+ def _get_import_session(self, import_dir: bytes) -> importer.ImportSession:
+ self.io.install()
+ return TerminalImportSessionFixture(
+ self.lib,
+ loghandler=None,
+ query=None,
+ io=self.io,
+ paths=[import_dir],
+ )
+
+
+def generate_album_info(album_id, track_values):
+ """Return `AlbumInfo` populated with mock data.
+
+ Sets the album info's `album_id` field is set to the corresponding
+ argument. For each pair (`id`, `values`) in `track_values` the `TrackInfo`
+ from `generate_track_info` is added to the album info's `tracks` field.
+ Most other fields of the album and track info are set to "album
+ info" and "track info", respectively.
+ """
+ tracks = [generate_track_info(id, values) for id, values in track_values]
+ album = AlbumInfo(
+ album_id="album info",
+ album="album info",
+ artist="album info",
+ artist_id="album info",
+ tracks=tracks,
+ )
+ for field in ALBUM_INFO_FIELDS:
+ setattr(album, field, "album info")
+
+ return album
+
+
+ALBUM_INFO_FIELDS = [
+ "album",
+ "album_id",
+ "artist",
+ "artist_id",
+ "asin",
+ "albumtype",
+ "va",
+ "label",
+ "barcode",
+ "artist_sort",
+ "releasegroup_id",
+ "catalognum",
+ "language",
+ "country",
+ "albumstatus",
+ "media",
+ "albumdisambig",
+ "releasegroupdisambig",
+ "artist_credit",
+ "data_source",
+ "data_url",
+]
+
+
+def generate_track_info(track_id="track info", values={}):
+ """Return `TrackInfo` populated with mock data.
+
+ The `track_id` field is set to the corresponding argument. All other
+ string fields are set to "track info".
+ """
+ track = TrackInfo(
+ title="track info",
+ track_id=track_id,
+ )
+ for field in TRACK_INFO_FIELDS:
+ setattr(track, field, "track info")
+ for field, value in values.items():
+ setattr(track, field, value)
+ return track
+
+
+TRACK_INFO_FIELDS = [
+ "artist",
+ "artist_id",
+ "artist_sort",
+ "disctitle",
+ "artist_credit",
+ "data_source",
+ "data_url",
+]
+
+
+class AutotagStub:
+ """Stub out MusicBrainz album and track matcher and control what the
+ autotagger returns.
+ """
+
+ NONE = "NONE"
+ IDENT = "IDENT"
+ GOOD = "GOOD"
+ BAD = "BAD"
+ MISSING = "MISSING"
+ """Generate an album match for all but one track
+ """
+
+ length = 2
+ matching = IDENT
+
+ def install(self):
+ self.mb_match_album = autotag.mb.match_album
+ self.mb_match_track = autotag.mb.match_track
+ self.mb_album_for_id = autotag.mb.album_for_id
+ self.mb_track_for_id = autotag.mb.track_for_id
+
+ autotag.mb.match_album = self.match_album
+ autotag.mb.match_track = self.match_track
+ autotag.mb.album_for_id = self.album_for_id
+ autotag.mb.track_for_id = self.track_for_id
+
+ return self
+
+ def restore(self):
+ autotag.mb.match_album = self.mb_match_album
+ autotag.mb.match_track = self.mb_match_track
+ autotag.mb.album_for_id = self.mb_album_for_id
+ autotag.mb.track_for_id = self.mb_track_for_id
+
+ def match_album(self, albumartist, album, tracks, extra_tags):
+ if self.matching == self.IDENT:
+ yield self._make_album_match(albumartist, album, tracks)
+
+ elif self.matching == self.GOOD:
+ for i in range(self.length):
+ yield self._make_album_match(albumartist, album, tracks, i)
+
+ elif self.matching == self.BAD:
+ for i in range(self.length):
+ yield self._make_album_match(albumartist, album, tracks, i + 1)
+
+ elif self.matching == self.MISSING:
+ yield self._make_album_match(albumartist, album, tracks, missing=1)
+
+ def match_track(self, artist, title):
+ yield TrackInfo(
+ title=title.replace("Tag", "Applied"),
+ track_id="trackid",
+ artist=artist.replace("Tag", "Applied"),
+ artist_id="artistid",
+ length=1,
+ index=0,
+ )
+
+ def album_for_id(self, mbid):
+ return None
+
+ def track_for_id(self, mbid):
+ return None
+
+ def _make_track_match(self, artist, album, number):
+ return TrackInfo(
+ title="Applied Track %d" % number,
+ track_id="match %d" % number,
+ artist=artist,
+ length=1,
+ index=0,
+ )
+
+ def _make_album_match(self, artist, album, tracks, distance=0, missing=0):
+ if distance:
+ id = " " + "M" * distance
+ else:
+ id = ""
+ if artist is None:
+ artist = "Various Artists"
+ else:
+ artist = artist.replace("Tag", "Applied") + id
+ album = album.replace("Tag", "Applied") + id
+
+ track_infos = []
+ for i in range(tracks - missing):
+ track_infos.append(self._make_track_match(artist, album, i + 1))
+
+ return AlbumInfo(
+ artist=artist,
+ album=album,
+ tracks=track_infos,
+ va=False,
+ album_id="albumid" + id,
+ artist_id="artistid" + id,
+ albumtype="soundtrack",
+ data_source="match_source",
+ )
+
+
+class FetchImageHelper:
+ """Helper mixin for mocking requests when fetching images
+ with remote art sources.
+ """
+
+ @responses.activate
+ def run(self, *args, **kwargs):
+ super().run(*args, **kwargs)
+
+ IMAGEHEADER = {
+ "image/jpeg": b"\x00" * 6 + b"JFIF",
+ "image/png": b"\211PNG\r\n\032\n",
+ }
+
+ def mock_response(self, url, content_type="image/jpeg", file_type=None):
+ if file_type is None:
+ file_type = content_type
+ responses.add(
+ responses.GET,
+ url,
+ content_type=content_type,
+ # imghdr reads 32 bytes
+ body=self.IMAGEHEADER.get(file_type, b"").ljust(32, b"\x00"),
+ )
+
+
+class CleanupModulesMixin:
+ modules: ClassVar[tuple[str, ...]]
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ """Remove files created by the plugin."""
+ for module in cls.modules:
+ clean_module_tempdir(module)
diff --git a/lib/beets/ui/__init__.py b/lib/beets/ui/__init__.py
index 121cb5dc..8580bd1e 100644
--- a/lib/beets/ui/__init__.py
+++ b/lib/beets/ui/__init__.py
@@ -18,31 +18,29 @@ CLI commands are implemented in the ui.commands module.
"""
-import optparse
-import textwrap
-import sys
-from difflib import SequenceMatcher
-import sqlite3
import errno
-import re
-import struct
-import traceback
+import optparse
import os.path
+import re
+import sqlite3
+import struct
+import sys
+import textwrap
+import traceback
+from difflib import SequenceMatcher
+from typing import Any, Callable, List
-from beets import logging
-from beets import library
-from beets import plugins
-from beets import util
-from beets.util.functemplate import template
-from beets import config
-from beets.util import as_string
-from beets.autotag import mb
-from beets.dbcore import query as db_query
-from beets.dbcore import db
import confuse
+from beets import config, library, logging, plugins, util
+from beets.autotag import mb
+from beets.dbcore import db
+from beets.dbcore import query as db_query
+from beets.util import as_string
+from beets.util.functemplate import template
+
# On Windows platforms, use colorama to support "ANSI" terminal colors.
-if sys.platform == 'win32':
+if sys.platform == "win32":
try:
import colorama
except ImportError:
@@ -51,15 +49,15 @@ if sys.platform == 'win32':
colorama.init()
-log = logging.getLogger('beets')
+log = logging.getLogger("beets")
if not log.handlers:
log.addHandler(logging.StreamHandler())
log.propagate = False # Don't propagate to root handler.
PF_KEY_QUERIES = {
- 'comp': 'comp:true',
- 'singleton': 'singleton:true',
+ "comp": "comp:true",
+ "singleton": "singleton:true",
}
@@ -73,31 +71,29 @@ class UserError(Exception):
def _in_encoding():
- """Get the encoding to use for *inputting* strings from the console.
- """
+ """Get the encoding to use for *inputting* strings from the console."""
return _stream_encoding(sys.stdin)
def _out_encoding():
- """Get the encoding to use for *outputting* strings to the console.
- """
+ """Get the encoding to use for *outputting* strings to the console."""
return _stream_encoding(sys.stdout)
-def _stream_encoding(stream, default='utf-8'):
+def _stream_encoding(stream, default="utf-8"):
"""A helper for `_in_encoding` and `_out_encoding`: get the stream's
preferred encoding, using a configured override or a default
fallback if neither is not specified.
"""
# Configured override?
- encoding = config['terminal_encoding'].get()
+ encoding = config["terminal_encoding"].get()
if encoding:
return encoding
# For testing: When sys.stdout or sys.stdin is a StringIO under the
# test harness, it doesn't have an `encoding` attribute. Just use
# UTF-8.
- if not hasattr(stream, 'encoding'):
+ if not hasattr(stream, "encoding"):
return default
# Python's guessed output stream encoding, or UTF-8 as a fallback
@@ -124,19 +120,19 @@ def print_(*strings, **kwargs):
(it defaults to a newline).
"""
if not strings:
- strings = ['']
+ strings = [""]
assert isinstance(strings[0], str)
- txt = ' '.join(strings)
- txt += kwargs.get('end', '\n')
+ txt = " ".join(strings)
+ txt += kwargs.get("end", "\n")
# Encode the string and write it to stdout.
# On Python 3, sys.stdout expects text strings and uses the
# exception-throwing encoding error policy. To avoid throwing
# errors and use our configurable encoding override, we use the
# underlying bytes buffer instead.
- if hasattr(sys.stdout, 'buffer'):
- out = txt.encode(_out_encoding(), 'replace')
+ if hasattr(sys.stdout, "buffer"):
+ out = txt.encode(_out_encoding(), "replace")
sys.stdout.buffer.write(out)
sys.stdout.buffer.flush()
else:
@@ -147,9 +143,9 @@ def print_(*strings, **kwargs):
# Configuration wrappers.
+
def _bool_fallback(a, b):
- """Given a boolean or None, return the original value or a fallback.
- """
+ """Given a boolean or None, return the original value or a fallback."""
if a is None:
assert isinstance(b, bool)
return b
@@ -162,7 +158,7 @@ def should_write(write_opt=None):
"""Decide whether a command that updates metadata should also write
tags, using the importer configuration as the default.
"""
- return _bool_fallback(write_opt, config['import']['write'].get(bool))
+ return _bool_fallback(write_opt, config["import"]["write"].get(bool))
def should_move(move_opt=None):
@@ -177,13 +173,19 @@ def should_move(move_opt=None):
"""
return _bool_fallback(
move_opt,
- config['import']['move'].get(bool) or
- config['import']['copy'].get(bool)
+ config["import"]["move"].get(bool)
+ or config["import"]["copy"].get(bool),
)
# Input prompts.
+
+def indent(count):
+ """Returns a string with `count` many spaces."""
+ return " " * count
+
+
def input_(prompt=None):
"""Like `input`, but decodes the result to a Unicode string.
Raises a UserError if stdin is not available. The prompt is sent to
@@ -194,18 +196,25 @@ def input_(prompt=None):
# use print_() explicitly to display prompts.
# https://bugs.python.org/issue1927
if prompt:
- print_(prompt, end=' ')
+ print_(prompt, end=" ")
try:
resp = input()
except EOFError:
- raise UserError('stdin stream ended while input required')
+ raise UserError("stdin stream ended while input required")
return resp
-def input_options(options, require=False, prompt=None, fallback_prompt=None,
- numrange=None, default=None, max_width=72):
+def input_options(
+ options,
+ require=False,
+ prompt=None,
+ fallback_prompt=None,
+ numrange=None,
+ default=None,
+ max_width=72,
+):
"""Prompts a user for input. The sequence of `options` defines the
choices the user has. A single-letter shortcut is inferred for each
option; the user's choice is returned as that single, lower-case
@@ -245,30 +254,37 @@ def input_options(options, require=False, prompt=None, fallback_prompt=None,
found_letter = letter
break
else:
- raise ValueError('no unambiguous lettering found')
+ raise ValueError("no unambiguous lettering found")
letters[found_letter.lower()] = option
index = option.index(found_letter)
# Mark the option's shortcut letter for display.
if not require and (
- (default is None and not numrange and first) or
- (isinstance(default, str) and
- found_letter.lower() == default.lower())):
+ (default is None and not numrange and first)
+ or (
+ isinstance(default, str)
+ and found_letter.lower() == default.lower()
+ )
+ ):
# The first option is the default; mark it.
- show_letter = '[%s]' % found_letter.upper()
+ show_letter = "[%s]" % found_letter.upper()
is_default = True
else:
show_letter = found_letter.upper()
is_default = False
# Colorize the letter shortcut.
- show_letter = colorize('action_default' if is_default else 'action',
- show_letter)
+ show_letter = colorize(
+ "action_default" if is_default else "action", show_letter
+ )
# Insert the highlighted letter back into the word.
+ descr_color = "action_default" if is_default else "action_description"
capitalized.append(
- option[:index] + show_letter + option[index + 1:]
+ colorize(descr_color, option[:index])
+ + show_letter
+ + colorize(descr_color, option[index + 1 :])
)
display_letters.append(found_letter.upper())
@@ -290,36 +306,38 @@ def input_options(options, require=False, prompt=None, fallback_prompt=None,
if numrange:
if isinstance(default, int):
default_name = str(default)
- default_name = colorize('action_default', default_name)
- tmpl = '# selection (default %s)'
+ default_name = colorize("action_default", default_name)
+ tmpl = "# selection (default %s)"
prompt_parts.append(tmpl % default_name)
prompt_part_lengths.append(len(tmpl % str(default)))
else:
- prompt_parts.append('# selection')
+ prompt_parts.append("# selection")
prompt_part_lengths.append(len(prompt_parts[-1]))
prompt_parts += capitalized
prompt_part_lengths += [len(s) for s in options]
# Wrap the query text.
- prompt = ''
+ # Start prompt with U+279C: Heavy Round-Tipped Rightwards Arrow
+ prompt = colorize("action", "\u279C ")
line_length = 0
- for i, (part, length) in enumerate(zip(prompt_parts,
- prompt_part_lengths)):
+ for i, (part, length) in enumerate(
+ zip(prompt_parts, prompt_part_lengths)
+ ):
# Add punctuation.
if i == len(prompt_parts) - 1:
- part += '?'
+ part += colorize("action_description", "?")
else:
- part += ','
+ part += colorize("action_description", ",")
length += 1
# Choose either the current line or the beginning of the next.
if line_length + length + 1 > max_width:
- prompt += '\n'
+ prompt += "\n"
line_length = 0
if line_length != 0:
# Not the beginning of the line; need a space.
- part = ' ' + part
+ part = " " + part
length += 1
prompt += part
@@ -328,10 +346,10 @@ def input_options(options, require=False, prompt=None, fallback_prompt=None,
# Make a fallback prompt too. This is displayed if the user enters
# something that is not recognized.
if not fallback_prompt:
- fallback_prompt = 'Enter one of '
+ fallback_prompt = "Enter one of "
if numrange:
- fallback_prompt += '%i-%i, ' % numrange
- fallback_prompt += ', '.join(display_letters) + ':'
+ fallback_prompt += "%i-%i, " % numrange
+ fallback_prompt += ", ".join(display_letters) + ":"
resp = input_(prompt)
while True:
@@ -368,10 +386,12 @@ def input_yn(prompt, require=False):
"""Prompts the user for a "yes" or "no" response. The default is
"yes" unless `require` is `True`, in which case there is no default.
"""
- sel = input_options(
- ('y', 'n'), require, prompt, 'Enter Y or N:'
+ # Start prompt with U+279C: Heavy Round-Tipped Rightwards Arrow
+ yesno = colorize("action", "\u279C ") + colorize(
+ "action_description", "Enter Y or N:"
)
- return sel == 'y'
+ sel = input_options(("y", "n"), require, prompt, yesno)
+ return sel == "y"
def input_select_objects(prompt, objs, rep, prompt_all=None):
@@ -385,24 +405,26 @@ def input_select_objects(prompt, objs, rep, prompt_all=None):
objects individually.
"""
choice = input_options(
- ('y', 'n', 's'), False,
- '%s? (Yes/no/select)' % (prompt_all or prompt))
+ ("y", "n", "s"), False, "%s? (Yes/no/select)" % (prompt_all or prompt)
+ )
print() # Blank line.
- if choice == 'y': # Yes.
+ if choice == "y": # Yes.
return objs
- elif choice == 's': # Select.
+ elif choice == "s": # Select.
out = []
for obj in objs:
rep(obj)
answer = input_options(
- ('y', 'n', 'q'), True, '%s? (yes/no/quit)' % prompt,
- 'Enter Y or N:'
+ ("y", "n", "q"),
+ True,
+ "%s? (yes/no/quit)" % prompt,
+ "Enter Y or N:",
)
- if answer == 'y':
+ if answer == "y":
out.append(obj)
- elif answer == 'q':
+ elif answer == "q":
return out
return out
@@ -412,15 +434,16 @@ def input_select_objects(prompt, objs, rep, prompt_all=None):
# Human output formatting.
+
def human_bytes(size):
"""Formats size, a number of bytes, in a human-readable way."""
- powers = ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y', 'H']
- unit = 'B'
+ powers = ["", "K", "M", "G", "T", "P", "E", "Z", "Y", "H"]
+ unit = "B"
for power in powers:
if size < 1024:
return f"{size:3.1f} {power}{unit}"
size /= 1024.0
- unit = 'iB'
+ unit = "iB"
return "big"
@@ -429,13 +452,13 @@ def human_seconds(interval):
interval using English words.
"""
units = [
- (1, 'second'),
- (60, 'minute'),
- (60, 'hour'),
- (24, 'day'),
- (7, 'week'),
- (52, 'year'),
- (10, 'decade'),
+ (1, "second"),
+ (60, "minute"),
+ (60, "hour"),
+ (24, "day"),
+ (7, "week"),
+ (52, "year"),
+ (10, "decade"),
]
for i in range(len(units) - 1):
increment, suffix = units[i]
@@ -456,7 +479,7 @@ def human_seconds_short(interval):
string.
"""
interval = int(interval)
- return '%i:%02i' % (interval // 60, interval % 60)
+ return "%i:%02i" % (interval // 60, interval % 60)
# Colorization.
@@ -465,51 +488,102 @@ def human_seconds_short(interval):
# https://bitbucket.org/birkenfeld/pygments-main/src/default/pygments/console.py
# (pygments is by Tim Hatch, Armin Ronacher, et al.)
COLOR_ESCAPE = "\x1b["
-DARK_COLORS = {
- "black": 0,
- "darkred": 1,
- "darkgreen": 2,
- "brown": 3,
- "darkyellow": 3,
- "darkblue": 4,
- "purple": 5,
- "darkmagenta": 5,
- "teal": 6,
- "darkcyan": 6,
- "lightgray": 7
+LEGACY_COLORS = {
+ "black": ["black"],
+ "darkred": ["red"],
+ "darkgreen": ["green"],
+ "brown": ["yellow"],
+ "darkyellow": ["yellow"],
+ "darkblue": ["blue"],
+ "purple": ["magenta"],
+ "darkmagenta": ["magenta"],
+ "teal": ["cyan"],
+ "darkcyan": ["cyan"],
+ "lightgray": ["white"],
+ "darkgray": ["bold", "black"],
+ "red": ["bold", "red"],
+ "green": ["bold", "green"],
+ "yellow": ["bold", "yellow"],
+ "blue": ["bold", "blue"],
+ "fuchsia": ["bold", "magenta"],
+ "magenta": ["bold", "magenta"],
+ "turquoise": ["bold", "cyan"],
+ "cyan": ["bold", "cyan"],
+ "white": ["bold", "white"],
}
-LIGHT_COLORS = {
- "darkgray": 0,
- "red": 1,
- "green": 2,
- "yellow": 3,
- "blue": 4,
- "fuchsia": 5,
- "magenta": 5,
- "turquoise": 6,
- "cyan": 6,
- "white": 7
+# All ANSI Colors.
+ANSI_CODES = {
+ # Styles.
+ "normal": 0,
+ "bold": 1,
+ "faint": 2,
+ # "italic": 3,
+ "underline": 4,
+ # "blink_slow": 5,
+ # "blink_rapid": 6,
+ "inverse": 7,
+ # "conceal": 8,
+ # "crossed_out": 9
+ # Text colors.
+ "black": 30,
+ "red": 31,
+ "green": 32,
+ "yellow": 33,
+ "blue": 34,
+ "magenta": 35,
+ "cyan": 36,
+ "white": 37,
+ # Background colors.
+ "bg_black": 40,
+ "bg_red": 41,
+ "bg_green": 42,
+ "bg_yellow": 43,
+ "bg_blue": 44,
+ "bg_magenta": 45,
+ "bg_cyan": 46,
+ "bg_white": 47,
}
RESET_COLOR = COLOR_ESCAPE + "39;49;00m"
# These abstract COLOR_NAMES are lazily mapped on to the actual color in COLORS
# as they are defined in the configuration files, see function: colorize
-COLOR_NAMES = ['text_success', 'text_warning', 'text_error', 'text_highlight',
- 'text_highlight_minor', 'action_default', 'action']
+COLOR_NAMES = [
+ "text_success",
+ "text_warning",
+ "text_error",
+ "text_highlight",
+ "text_highlight_minor",
+ "action_default",
+ "action",
+ # New Colors
+ "text",
+ "text_faint",
+ "import_path",
+ "import_path_items",
+ "action_description",
+ "added",
+ "removed",
+ "changed",
+ "added_highlight",
+ "removed_highlight",
+ "changed_highlight",
+ "text_diff_added",
+ "text_diff_removed",
+ "text_diff_changed",
+]
COLORS = None
def _colorize(color, text):
"""Returns a string that prints the given text in the given color
- in a terminal that is ANSI color-aware. The color must be something
- in DARK_COLORS or LIGHT_COLORS.
+ in a terminal that is ANSI color-aware. The color must be a list of strings
+ from ANSI_CODES.
"""
- if color in DARK_COLORS:
- escape = COLOR_ESCAPE + "%im" % (DARK_COLORS[color] + 30)
- elif color in LIGHT_COLORS:
- escape = COLOR_ESCAPE + "%i;01m" % (LIGHT_COLORS[color] + 30)
- else:
- raise ValueError('no such color %s', color)
+ # Construct escape sequence to be put before the text by iterating
+ # over all "ANSI codes" in `color`.
+ escape = ""
+ for code in color:
+ escape = escape + COLOR_ESCAPE + "%im" % ANSI_CODES[code]
return escape + text + RESET_COLOR
@@ -517,81 +591,165 @@ def colorize(color_name, text):
"""Colorize text if colored output is enabled. (Like _colorize but
conditional.)
"""
- if not config['ui']['color'] or 'NO_COLOR' in os.environ.keys():
+ if config["ui"]["color"] and "NO_COLOR" not in os.environ:
+ global COLORS
+ if not COLORS:
+ # Read all color configurations and set global variable COLORS.
+ COLORS = dict()
+ for name in COLOR_NAMES:
+ # Convert legacy color definitions (strings) into the new
+ # list-based color definitions. Do this by trying to read the
+ # color definition from the configuration as unicode - if this
+ # is successful, the color definition is a legacy definition
+ # and has to be converted.
+ try:
+ color_def = config["ui"]["colors"][name].get(str)
+ except (confuse.ConfigTypeError, NameError):
+ # Normal color definition (type: list of unicode).
+ color_def = config["ui"]["colors"][name].get(list)
+ else:
+ # Legacy color definition (type: unicode). Convert.
+ if color_def in LEGACY_COLORS:
+ color_def = LEGACY_COLORS[color_def]
+ else:
+ raise UserError("no such color %s", color_def)
+ for code in color_def:
+ if code not in ANSI_CODES.keys():
+ raise ValueError("no such ANSI code %s", code)
+ COLORS[name] = color_def
+ # In case a 3rd party plugin is still passing the actual color ('red')
+ # instead of the abstract color name ('text_error')
+ color = COLORS.get(color_name)
+ if not color:
+ log.debug("Invalid color_name: {0}", color_name)
+ color = color_name
+ return _colorize(color, text)
+ else:
return text
- global COLORS
- if not COLORS:
- COLORS = {name:
- config['ui']['colors'][name].as_str()
- for name in COLOR_NAMES}
- # In case a 3rd party plugin is still passing the actual color ('red')
- # instead of the abstract color name ('text_error')
- color = COLORS.get(color_name)
- if not color:
- log.debug('Invalid color_name: {0}', color_name)
- color = color_name
- return _colorize(color, text)
+
+def uncolorize(colored_text):
+ """Remove colors from a string."""
+ # Define a regular expression to match ANSI codes.
+ # See: http://stackoverflow.com/a/2187024/1382707
+ # Explanation of regular expression:
+ # \x1b - matches ESC character
+ # \[ - matches opening square bracket
+ # [;\d]* - matches a sequence consisting of one or more digits or
+ # semicola
+ # [A-Za-z] - matches a letter
+ ansi_code_regex = re.compile(r"\x1b\[[;\d]*[A-Za-z]", re.VERBOSE)
+ # Strip ANSI codes from `colored_text` using the regular expression.
+ text = ansi_code_regex.sub("", colored_text)
+ return text
-def _colordiff(a, b, highlight='text_highlight',
- minor_highlight='text_highlight_minor'):
+def color_split(colored_text, index):
+ ansi_code_regex = re.compile(r"(\x1b\[[;\d]*[A-Za-z])", re.VERBOSE)
+ length = 0
+ pre_split = ""
+ post_split = ""
+ found_color_code = None
+ found_split = False
+ for part in ansi_code_regex.split(colored_text):
+ # Count how many real letters we have passed
+ length += color_len(part)
+ if found_split:
+ post_split += part
+ else:
+ if ansi_code_regex.match(part):
+ # This is a color code
+ if part == RESET_COLOR:
+ found_color_code = None
+ else:
+ found_color_code = part
+ pre_split += part
+ else:
+ if index < length:
+ # Found part with our split in.
+ split_index = index - (length - color_len(part))
+ found_split = True
+ if found_color_code:
+ pre_split += part[:split_index] + RESET_COLOR
+ post_split += found_color_code + part[split_index:]
+ else:
+ pre_split += part[:split_index]
+ post_split += part[split_index:]
+ else:
+ # Not found, add this part to the pre split
+ pre_split += part
+ return pre_split, post_split
+
+
+def color_len(colored_text):
+ """Measure the length of a string while excluding ANSI codes from the
+ measurement. The standard `len(my_string)` method also counts ANSI codes
+ to the string length, which is counterproductive when layouting a
+ Terminal interface.
+ """
+ # Return the length of the uncolored string.
+ return len(uncolorize(colored_text))
+
+
+def _colordiff(a, b):
"""Given two values, return the same pair of strings except with
their differences highlighted in the specified color. Strings are
highlighted intelligently to show differences; other values are
stringified and highlighted in their entirety.
"""
- if not isinstance(a, str) \
- or not isinstance(b, str):
- # Non-strings: use ordinary equality.
- a = str(a)
- b = str(b)
- if a == b:
- return a, b
- else:
- return colorize(highlight, a), colorize(highlight, b)
-
+ # First, convert paths to readable format
if isinstance(a, bytes) or isinstance(b, bytes):
# A path field.
a = util.displayable_path(a)
b = util.displayable_path(b)
+ if not isinstance(a, str) or not isinstance(b, str):
+ # Non-strings: use ordinary equality.
+ if a == b:
+ return str(a), str(b)
+ else:
+ return (
+ colorize("text_diff_removed", str(a)),
+ colorize("text_diff_added", str(b)),
+ )
+
a_out = []
b_out = []
matcher = SequenceMatcher(lambda x: False, a, b)
for op, a_start, a_end, b_start, b_end in matcher.get_opcodes():
- if op == 'equal':
+ if op == "equal":
# In both strings.
a_out.append(a[a_start:a_end])
b_out.append(b[b_start:b_end])
- elif op == 'insert':
+ elif op == "insert":
# Right only.
- b_out.append(colorize(highlight, b[b_start:b_end]))
- elif op == 'delete':
+ b_out.append(colorize("text_diff_added", b[b_start:b_end]))
+ elif op == "delete":
# Left only.
- a_out.append(colorize(highlight, a[a_start:a_end]))
- elif op == 'replace':
+ a_out.append(colorize("text_diff_removed", a[a_start:a_end]))
+ elif op == "replace":
# Right and left differ. Colorise with second highlight if
# it's just a case change.
if a[a_start:a_end].lower() != b[b_start:b_end].lower():
- color = highlight
+ a_color = "text_diff_removed"
+ b_color = "text_diff_added"
else:
- color = minor_highlight
- a_out.append(colorize(color, a[a_start:a_end]))
- b_out.append(colorize(color, b[b_start:b_end]))
+ a_color = b_color = "text_highlight_minor"
+ a_out.append(colorize(a_color, a[a_start:a_end]))
+ b_out.append(colorize(b_color, b[b_start:b_end]))
else:
- assert(False)
+ assert False
- return ''.join(a_out), ''.join(b_out)
+ return "".join(a_out), "".join(b_out)
-def colordiff(a, b, highlight='text_highlight'):
+def colordiff(a, b):
"""Colorize differences between two values if color is enabled.
(Like _colordiff but conditional.)
"""
- if config['ui']['color']:
- return _colordiff(a, b, highlight)
+ if config["ui"]["color"]:
+ return _colordiff(a, b)
else:
return str(a), str(b)
@@ -601,7 +759,7 @@ def get_path_formats(subview=None):
pairs.
"""
path_formats = []
- subview = subview or config['paths']
+ subview = subview or config["paths"]
for query, view in subview.items():
query = PF_KEY_QUERIES.get(query, query) # Expand common queries.
path_formats.append((query, template(view.as_str())))
@@ -609,25 +767,22 @@ def get_path_formats(subview=None):
def get_replacements():
- """Confuse validation function that reads regex/string pairs.
- """
+ """Confuse validation function that reads regex/string pairs."""
replacements = []
- for pattern, repl in config['replace'].get(dict).items():
- repl = repl or ''
+ for pattern, repl in config["replace"].get(dict).items():
+ repl = repl or ""
try:
replacements.append((re.compile(pattern), repl))
except re.error:
raise UserError(
- 'malformed regular expression in replace: {}'.format(
- pattern
- )
+ "malformed regular expression in replace: {}".format(pattern)
)
return replacements
def term_width():
"""Get the width (columns) of the terminal."""
- fallback = config['ui']['terminal_width'].get(int)
+ fallback = config["ui"]["terminal_width"].get(int)
# The fcntl and termios modules are not available on non-Unix
# platforms, so we fall back to a constant.
@@ -638,16 +793,354 @@ def term_width():
return fallback
try:
- buf = fcntl.ioctl(0, termios.TIOCGWINSZ, ' ' * 4)
+ buf = fcntl.ioctl(0, termios.TIOCGWINSZ, " " * 4)
except OSError:
return fallback
try:
- height, width = struct.unpack('hh', buf)
+ height, width = struct.unpack("hh", buf)
except struct.error:
return fallback
return width
+def split_into_lines(string, width_tuple):
+ """Splits string into a list of substrings at whitespace.
+
+ `width_tuple` is a 3-tuple of `(first_width, last_width, middle_width)`.
+ The first substring has a length not longer than `first_width`, the last
+ substring has a length not longer than `last_width`, and all other
+ substrings have a length not longer than `middle_width`.
+ `string` may contain ANSI codes at word borders.
+ """
+ first_width, middle_width, last_width = width_tuple
+ words = []
+ esc_text = re.compile(
+ r"""(?P[^\x1b]*)
+ (?P(?:\x1b\[[;\d]*[A-Za-z])+)
+ (?P[^\x1b]+)(?P\x1b\[39;49;00m)
+ (?P[^\x1b]*)""",
+ re.VERBOSE,
+ )
+ if uncolorize(string) == string:
+ # No colors in string
+ words = string.split()
+ else:
+ # Use a regex to find escapes and the text within them.
+ for m in esc_text.finditer(string):
+ # m contains four groups:
+ # pretext - any text before escape sequence
+ # esc - intitial escape sequence
+ # text - text, no escape sequence, may contain spaces
+ # reset - ASCII colour reset
+ space_before_text = False
+ if m.group("pretext") != "":
+ # Some pretext found, let's handle it
+ # Add any words in the pretext
+ words += m.group("pretext").split()
+ if m.group("pretext")[-1] == " ":
+ # Pretext ended on a space
+ space_before_text = True
+ else:
+ # Pretext ended mid-word, ensure next word
+ pass
+ else:
+ # pretext empty, treat as if there is a space before
+ space_before_text = True
+ if m.group("text")[0] == " ":
+ # First character of the text is a space
+ space_before_text = True
+ # Now, handle the words in the main text:
+ raw_words = m.group("text").split()
+ if space_before_text:
+ # Colorize each word with pre/post escapes
+ # Reconstruct colored words
+ words += [
+ m.group("esc") + raw_word + RESET_COLOR
+ for raw_word in raw_words
+ ]
+ elif raw_words:
+ # Pretext stops mid-word
+ if m.group("esc") != RESET_COLOR:
+ # Add the rest of the current word, with a reset after it
+ words[-1] += m.group("esc") + raw_words[0] + RESET_COLOR
+ # Add the subsequent colored words:
+ words += [
+ m.group("esc") + raw_word + RESET_COLOR
+ for raw_word in raw_words[1:]
+ ]
+ else:
+ # Caught a mid-word escape sequence
+ words[-1] += raw_words[0]
+ words += raw_words[1:]
+ if (
+ m.group("text")[-1] != " "
+ and m.group("posttext") != ""
+ and m.group("posttext")[0] != " "
+ ):
+ # reset falls mid-word
+ post_text = m.group("posttext").split()
+ words[-1] += post_text[0]
+ words += post_text[1:]
+ else:
+ # Add any words after escape sequence
+ words += m.group("posttext").split()
+ result = []
+ next_substr = ""
+ # Iterate over all words.
+ previous_fit = False
+ for i in range(len(words)):
+ if i == 0:
+ pot_substr = words[i]
+ else:
+ # (optimistically) add the next word to check the fit
+ pot_substr = " ".join([next_substr, words[i]])
+ # Find out if the pot(ential)_substr fits into the next substring.
+ fits_first = len(result) == 0 and color_len(pot_substr) <= first_width
+ fits_middle = len(result) != 0 and color_len(pot_substr) <= middle_width
+ if fits_first or fits_middle:
+ # Fitted(!) let's try and add another word before appending
+ next_substr = pot_substr
+ previous_fit = True
+ elif not fits_first and not fits_middle and previous_fit:
+ # Extra word didn't fit, append what we have
+ result.append(next_substr)
+ next_substr = words[i]
+ previous_fit = color_len(next_substr) <= middle_width
+ else:
+ # Didn't fit anywhere
+ if uncolorize(pot_substr) == pot_substr:
+ # Simple uncolored string, append a cropped word
+ if len(result) == 0:
+ # Crop word by the first_width for the first line
+ result.append(pot_substr[:first_width])
+ # add rest of word to next line
+ next_substr = pot_substr[first_width:]
+ else:
+ result.append(pot_substr[:middle_width])
+ next_substr = pot_substr[middle_width:]
+ else:
+ # Colored strings
+ if len(result) == 0:
+ this_line, next_line = color_split(pot_substr, first_width)
+ result.append(this_line)
+ next_substr = next_line
+ else:
+ this_line, next_line = color_split(pot_substr, middle_width)
+ result.append(this_line)
+ next_substr = next_line
+ previous_fit = color_len(next_substr) <= middle_width
+
+ # We finished constructing the substrings, but the last substring
+ # has not yet been added to the result.
+ result.append(next_substr)
+ # Also, the length of the last substring was only checked against
+ # `middle_width`. Append an empty substring as the new last substring if
+ # the last substring is too long.
+ if not color_len(next_substr) <= last_width:
+ result.append("")
+ return result
+
+
+def print_column_layout(
+ indent_str, left, right, separator=" -> ", max_width=term_width()
+):
+ """Print left & right data, with separator inbetween
+ 'left' and 'right' have a structure of:
+ {'prefix':u'','contents':u'','suffix':u'','width':0}
+ In a column layout the printing will be:
+ {indent_str}{lhs0}{separator}{rhs0}
+ {lhs1 / padding }{rhs1}
+ ...
+ The first line of each column (i.e. {lhs0} or {rhs0}) is:
+ {prefix}{part of contents}{suffix}
+ With subsequent lines (i.e. {lhs1}, {rhs1} onwards) being the
+ rest of contents, wrapped if the width would be otherwise exceeded.
+ """
+ if right["prefix"] + right["contents"] + right["suffix"] == "":
+ # No right hand information, so we don't need a separator.
+ separator = ""
+ first_line_no_wrap = (
+ indent_str
+ + left["prefix"]
+ + left["contents"]
+ + left["suffix"]
+ + separator
+ + right["prefix"]
+ + right["contents"]
+ + right["suffix"]
+ )
+ if color_len(first_line_no_wrap) < max_width:
+ # Everything fits, print out line.
+ print_(first_line_no_wrap)
+ else:
+ # Wrap into columns
+ if "width" not in left or "width" not in right:
+ # If widths have not been defined, set to share space.
+ left["width"] = (
+ max_width - len(indent_str) - color_len(separator)
+ ) // 2
+ right["width"] = (
+ max_width - len(indent_str) - color_len(separator)
+ ) // 2
+ # On the first line, account for suffix as well as prefix
+ left_width_tuple = (
+ left["width"]
+ - color_len(left["prefix"])
+ - color_len(left["suffix"]),
+ left["width"] - color_len(left["prefix"]),
+ left["width"] - color_len(left["prefix"]),
+ )
+
+ left_split = split_into_lines(left["contents"], left_width_tuple)
+ right_width_tuple = (
+ right["width"]
+ - color_len(right["prefix"])
+ - color_len(right["suffix"]),
+ right["width"] - color_len(right["prefix"]),
+ right["width"] - color_len(right["prefix"]),
+ )
+
+ right_split = split_into_lines(right["contents"], right_width_tuple)
+ max_line_count = max(len(left_split), len(right_split))
+
+ out = ""
+ for i in range(max_line_count):
+ # indentation
+ out += indent_str
+
+ # Prefix or indent_str for line
+ if i == 0:
+ out += left["prefix"]
+ else:
+ out += indent(color_len(left["prefix"]))
+
+ # Line i of left hand side contents.
+ if i < len(left_split):
+ out += left_split[i]
+ left_part_len = color_len(left_split[i])
+ else:
+ left_part_len = 0
+
+ # Padding until end of column.
+ # Note: differs from original
+ # column calcs in not -1 afterwards for space
+ # in track number as that is included in 'prefix'
+ padding = left["width"] - color_len(left["prefix"]) - left_part_len
+
+ # Remove some padding on the first line to display
+ # length
+ if i == 0:
+ padding -= color_len(left["suffix"])
+
+ out += indent(padding)
+
+ if i == 0:
+ out += left["suffix"]
+
+ # Separator between columns.
+ if i == 0:
+ out += separator
+ else:
+ out += indent(color_len(separator))
+
+ # Right prefix, contents, padding, suffix
+ if i == 0:
+ out += right["prefix"]
+ else:
+ out += indent(color_len(right["prefix"]))
+
+ # Line i of right hand side.
+ if i < len(right_split):
+ out += right_split[i]
+ right_part_len = color_len(right_split[i])
+ else:
+ right_part_len = 0
+
+ # Padding until end of column
+ padding = (
+ right["width"] - color_len(right["prefix"]) - right_part_len
+ )
+ # Remove some padding on the first line to display
+ # length
+ if i == 0:
+ padding -= color_len(right["suffix"])
+ out += indent(padding)
+ # Length in first line
+ if i == 0:
+ out += right["suffix"]
+
+ # Linebreak, except in the last line.
+ if i < max_line_count - 1:
+ out += "\n"
+
+ # Constructed all of the columns, now print
+ print_(out)
+
+
+def print_newline_layout(
+ indent_str, left, right, separator=" -> ", max_width=term_width()
+):
+ """Prints using a newline separator between left & right if
+ they go over their allocated widths. The datastructures are
+ shared with the column layout. In contrast to the column layout,
+ the prefix and suffix are printed at the beginning and end of
+ the contents. If no wrapping is required (i.e. everything fits) the
+ first line will look exactly the same as the column layout:
+ {indent}{lhs0}{separator}{rhs0}
+ However if this would go over the width given, the layout now becomes:
+ {indent}{lhs0}
+ {indent}{separator}{rhs0}
+ If {lhs0} would go over the maximum width, the subsequent lines are
+ indented a second time for ease of reading.
+ """
+ if right["prefix"] + right["contents"] + right["suffix"] == "":
+ # No right hand information, so we don't need a separator.
+ separator = ""
+ first_line_no_wrap = (
+ indent_str
+ + left["prefix"]
+ + left["contents"]
+ + left["suffix"]
+ + separator
+ + right["prefix"]
+ + right["contents"]
+ + right["suffix"]
+ )
+ if color_len(first_line_no_wrap) < max_width:
+ # Everything fits, print out line.
+ print_(first_line_no_wrap)
+ else:
+ # Newline separation, with wrapping
+ empty_space = max_width - len(indent_str)
+ # On lower lines we will double the indent for clarity
+ left_width_tuple = (
+ empty_space,
+ empty_space - len(indent_str),
+ empty_space - len(indent_str),
+ )
+ left_str = left["prefix"] + left["contents"] + left["suffix"]
+ left_split = split_into_lines(left_str, left_width_tuple)
+ # Repeat calculations for rhs, including separator on first line
+ right_width_tuple = (
+ empty_space - color_len(separator),
+ empty_space - len(indent_str),
+ empty_space - len(indent_str),
+ )
+ right_str = right["prefix"] + right["contents"] + right["suffix"]
+ right_split = split_into_lines(right_str, right_width_tuple)
+ for i, line in enumerate(left_split):
+ if i == 0:
+ print_(indent_str + line)
+ elif line != "":
+ # Ignore empty lines
+ print_(indent_str * 2 + line)
+ for i, line in enumerate(right_split):
+ if i == 0:
+ print_(indent_str + separator + line)
+ elif line != "":
+ print_(indent_str * 2 + line)
+
+
FLOAT_EPSILON = 0.01
@@ -660,25 +1153,28 @@ def _field_diff(field, old, old_fmt, new, new_fmt):
newval = new.get(field)
# If no change, abort.
- if isinstance(oldval, float) and isinstance(newval, float) and \
- abs(oldval - newval) < FLOAT_EPSILON:
+ if (
+ isinstance(oldval, float)
+ and isinstance(newval, float)
+ and abs(oldval - newval) < FLOAT_EPSILON
+ ):
return None
elif oldval == newval:
return None
# Get formatted values for output.
- oldstr = old_fmt.get(field, '')
- newstr = new_fmt.get(field, '')
+ oldstr = old_fmt.get(field, "")
+ newstr = new_fmt.get(field, "")
# For strings, highlight changes. For others, colorize the whole
# thing.
if isinstance(oldval, str):
oldstr, newstr = colordiff(oldval, newstr)
else:
- oldstr = colorize('text_error', oldstr)
- newstr = colorize('text_error', newstr)
+ oldstr = colorize("text_error", oldstr)
+ newstr = colorize("text_error", newstr)
- return f'{oldstr} -> {newstr}'
+ return f"{oldstr} -> {newstr}"
def show_model_changes(new, old=None, fields=None, always=False):
@@ -702,29 +1198,28 @@ def show_model_changes(new, old=None, fields=None, always=False):
changes = []
for field in old:
# Subset of the fields. Never show mtime.
- if field == 'mtime' or (fields and field not in fields):
+ if field == "mtime" or (fields and field not in fields):
continue
# Detect and show difference for this field.
line = _field_diff(field, old, old_fmt, new, new_fmt)
if line:
- changes.append(f' {field}: {line}')
+ changes.append(f" {field}: {line}")
# New fields.
for field in set(new) - set(old):
if fields and field not in fields:
continue
- changes.append(' {}: {}'.format(
- field,
- colorize('text_highlight', new_fmt[field])
- ))
+ changes.append(
+ " {}: {}".format(field, colorize("text_highlight", new_fmt[field]))
+ )
# Print changes.
if changes or always:
print_(format(old))
if changes:
- print_('\n'.join(changes))
+ print_("\n".join(changes))
return bool(changes)
@@ -751,31 +1246,34 @@ def show_path_changes(path_changes):
destinations = list(map(util.displayable_path, destinations))
# Calculate widths for terminal split
- col_width = (term_width() - len(' -> ')) // 2
+ col_width = (term_width() - len(" -> ")) // 2
max_width = len(max(sources + destinations, key=len))
if max_width > col_width:
# Print every change over two lines
for source, dest in zip(sources, destinations):
color_source, color_dest = colordiff(source, dest)
- print_('{0} \n -> {1}'.format(color_source, color_dest))
+ print_("{0} \n -> {1}".format(color_source, color_dest))
else:
# Print every change on a single line, and add a header
- title_pad = max_width - len('Source ') + len(' -> ')
+ title_pad = max_width - len("Source ") + len(" -> ")
- print_('Source {0} Destination'.format(' ' * title_pad))
+ print_("Source {0} Destination".format(" " * title_pad))
for source, dest in zip(sources, destinations):
pad = max_width - len(source)
color_source, color_dest = colordiff(source, dest)
- print_('{0} {1} -> {2}'.format(
- color_source,
- ' ' * pad,
- color_dest,
- ))
+ print_(
+ "{0} {1} -> {2}".format(
+ color_source,
+ " " * pad,
+ color_dest,
+ )
+ )
# Helper functions for option parsing.
+
def _store_dict(option, opt_str, value, parser):
"""Custom action callback to parse options which have ``key=value``
pairs as values. All such pairs passed for this option are
@@ -790,17 +1288,16 @@ def _store_dict(option, opt_str, value, parser):
setattr(parser.values, dest, {})
option_values = getattr(parser.values, dest)
- # Decode the argument using the platform's argument encoding.
- value = util.text_string(value, util.arg_encoding())
-
try:
- key, value = value.split('=', 1)
+ key, value = value.split("=", 1)
if not (key and value):
raise ValueError
except ValueError:
raise UserError(
- "supplied argument `{}' is not of the form `key=value'"
- .format(value))
+ "supplied argument `{}' is not of the form `key=value'".format(
+ value
+ )
+ )
option_values[key] = value
@@ -828,20 +1325,29 @@ class CommonOptionsParser(optparse.OptionParser):
# us to check whether it has been specified on the CLI - bypassing the
# fact that arguments may be in any order
- def add_album_option(self, flags=('-a', '--album')):
+ def add_album_option(self, flags=("-a", "--album")):
"""Add a -a/--album option to match albums instead of tracks.
If used then the format option can auto-detect whether we're setting
the format for items or albums.
Sets the album property on the options extracted from the CLI.
"""
- album = optparse.Option(*flags, action='store_true',
- help='match albums instead of tracks')
+ album = optparse.Option(
+ *flags, action="store_true", help="match albums instead of tracks"
+ )
self.add_option(album)
self._album_flags = set(flags)
- def _set_format(self, option, opt_str, value, parser, target=None,
- fmt=None, store_true=False):
+ def _set_format(
+ self,
+ option,
+ opt_str,
+ value,
+ parser,
+ target=None,
+ fmt=None,
+ store_true=False,
+ ):
"""Internal callback that sets the correct format while parsing CLI
arguments.
"""
@@ -852,9 +1358,9 @@ class CommonOptionsParser(optparse.OptionParser):
if fmt:
value = fmt
elif value:
- value, = decargs([value])
+ (value,) = decargs([value])
else:
- value = ''
+ value = ""
parser.values.format = value
if target:
@@ -874,7 +1380,7 @@ class CommonOptionsParser(optparse.OptionParser):
config[library.Item._format_config_key].set(value)
config[library.Album._format_config_key].set(value)
- def add_path_option(self, flags=('-p', '--path')):
+ def add_path_option(self, flags=("-p", "--path")):
"""Add a -p/--path option to display the path instead of the default
format.
@@ -884,14 +1390,17 @@ class CommonOptionsParser(optparse.OptionParser):
Sets the format property to '$path' on the options extracted from the
CLI.
"""
- path = optparse.Option(*flags, nargs=0, action='callback',
- callback=self._set_format,
- callback_kwargs={'fmt': '$path',
- 'store_true': True},
- help='print paths for matched items or albums')
+ path = optparse.Option(
+ *flags,
+ nargs=0,
+ action="callback",
+ callback=self._set_format,
+ callback_kwargs={"fmt": "$path", "store_true": True},
+ help="print paths for matched items or albums",
+ )
self.add_option(path)
- def add_format_option(self, flags=('-f', '--format'), target=None):
+ def add_format_option(self, flags=("-f", "--format"), target=None):
"""Add -f/--format option to print some LibModel instances with a
custom format.
@@ -909,19 +1418,20 @@ class CommonOptionsParser(optparse.OptionParser):
kwargs = {}
if target:
if isinstance(target, str):
- target = {'item': library.Item,
- 'album': library.Album}[target]
- kwargs['target'] = target
+ target = {"item": library.Item, "album": library.Album}[target]
+ kwargs["target"] = target
- opt = optparse.Option(*flags, action='callback',
- callback=self._set_format,
- callback_kwargs=kwargs,
- help='print with custom format')
+ opt = optparse.Option(
+ *flags,
+ action="callback",
+ callback=self._set_format,
+ callback_kwargs=kwargs,
+ help="print with custom format",
+ )
self.add_option(opt)
def add_all_common_options(self):
- """Add album, path and format options.
- """
+ """Add album, path and format options."""
self.add_album_option()
self.add_path_option()
self.add_format_option()
@@ -935,12 +1445,15 @@ class CommonOptionsParser(optparse.OptionParser):
# There you will also find a better description of the code and a more
# succinct example program.
+
class Subcommand:
"""A subcommand of a root command-line application that may be
invoked by a SubcommandOptionParser.
"""
- def __init__(self, name, parser=None, help='', aliases=(), hide=False):
+ func: Callable[[library.Library, optparse.Values, List[str]], Any]
+
+ def __init__(self, name, parser=None, help="", aliases=(), hide=False):
"""Creates a new subcommand. name is the primary way to invoke
the subcommand; aliases are alternate names. parser is an
OptionParser responsible for parsing the subcommand's options.
@@ -967,8 +1480,9 @@ class Subcommand:
@root_parser.setter
def root_parser(self, root_parser):
self._root_parser = root_parser
- self.parser.prog = '{} {}'.format(
- as_string(root_parser.get_prog_name()), self.name)
+ self.parser.prog = "{} {}".format(
+ as_string(root_parser.get_prog_name()), self.name
+ )
class SubcommandsOptionParser(CommonOptionsParser):
@@ -982,11 +1496,13 @@ class SubcommandsOptionParser(CommonOptionsParser):
to subcommands, a sequence of Subcommand objects.
"""
# A more helpful default usage.
- if 'usage' not in kwargs:
- kwargs['usage'] = """
+ if "usage" not in kwargs:
+ kwargs[
+ "usage"
+ ] = """
%prog COMMAND [ARGS...]
%prog help COMMAND"""
- kwargs['add_help_option'] = False
+ kwargs["add_help_option"] = False
# Super constructor.
super().__init__(*args, **kwargs)
@@ -997,8 +1513,7 @@ class SubcommandsOptionParser(CommonOptionsParser):
self.subcommands = []
def add_subcommand(self, *cmds):
- """Adds a Subcommand object to the parser's list of commands.
- """
+ """Adds a Subcommand object to the parser's list of commands."""
for cmd in cmds:
cmd.root_parser = self
self.subcommands.append(cmd)
@@ -1012,7 +1527,7 @@ class SubcommandsOptionParser(CommonOptionsParser):
# Subcommands header.
result = ["\n"]
- result.append(formatter.format_heading('Commands'))
+ result.append(formatter.format_heading("Commands"))
formatter.indent()
# Generate the display names (including aliases).
@@ -1024,7 +1539,7 @@ class SubcommandsOptionParser(CommonOptionsParser):
for subcommand in subcommands:
name = subcommand.name
if subcommand.aliases:
- name += ' (%s)' % ', '.join(subcommand.aliases)
+ name += " (%s)" % ", ".join(subcommand.aliases)
disp_names.append(name)
# Set the help position based on the max width.
@@ -1040,16 +1555,24 @@ class SubcommandsOptionParser(CommonOptionsParser):
name = "%*s%s\n" % (formatter.current_indent, "", name)
indent_first = help_position
else:
- name = "%*s%-*s " % (formatter.current_indent, "",
- name_width, name)
+ name = "%*s%-*s " % (
+ formatter.current_indent,
+ "",
+ name_width,
+ name,
+ )
indent_first = 0
result.append(name)
help_width = formatter.width - help_position
help_lines = textwrap.wrap(subcommand.help, help_width)
- help_line = help_lines[0] if help_lines else ''
+ help_line = help_lines[0] if help_lines else ""
result.append("%*s%s\n" % (indent_first, "", help_line))
- result.extend(["%*s%s\n" % (help_position, "", line)
- for line in help_lines[1:]])
+ result.extend(
+ [
+ "%*s%s\n" % (help_position, "", line)
+ for line in help_lines[1:]
+ ]
+ )
formatter.dedent()
# Concatenate the original help message with the subcommand
@@ -1062,8 +1585,7 @@ class SubcommandsOptionParser(CommonOptionsParser):
an alias. If no subcommand matches, returns None.
"""
for subcommand in self.subcommands:
- if name == subcommand.name or \
- name in subcommand.aliases:
+ if name == subcommand.name or name in subcommand.aliases:
return subcommand
return None
@@ -1075,9 +1597,9 @@ class SubcommandsOptionParser(CommonOptionsParser):
# Force the help command
if options.help:
- subargs = ['help']
+ subargs = ["help"]
elif options.version:
- subargs = ['version']
+ subargs = ["version"]
return options, subargs
def parse_subcommand(self, args):
@@ -1087,7 +1609,7 @@ class SubcommandsOptionParser(CommonOptionsParser):
"""
# Help is default command
if not args:
- args = ['help']
+ args = ["help"]
cmdname = args.pop(0)
subcommand = self._subcommand_for_name(cmdname)
@@ -1098,23 +1620,24 @@ class SubcommandsOptionParser(CommonOptionsParser):
return subcommand, suboptions, subargs
-optparse.Option.ALWAYS_TYPED_ACTIONS += ('callback',)
+optparse.Option.ALWAYS_TYPED_ACTIONS += ("callback",)
# The main entry point and bootstrapping.
+
def _load_plugins(options, config):
- """Load the plugins specified on the command line or in the configuration.
- """
- paths = config['pluginpath'].as_str_seq(split=False)
+ """Load the plugins specified on the command line or in the configuration."""
+ paths = config["pluginpath"].as_str_seq(split=False)
paths = [util.normpath(p) for p in paths]
- log.debug('plugin paths: {0}', util.displayable_path(paths))
+ log.debug("plugin paths: {0}", util.displayable_path(paths))
# On Python 3, the search paths need to be unicode.
- paths = [util.py3_path(p) for p in paths]
+ paths = [os.fsdecode(p) for p in paths]
# Extend the `beetsplug` package to include the plugin paths.
import beetsplug
+
beetsplug.__path__ = paths + list(beetsplug.__path__)
# For backwards compatibility, also support plugin paths that
@@ -1123,10 +1646,17 @@ def _load_plugins(options, config):
# If we were given any plugins on the command line, use those.
if options.plugins is not None:
- plugin_list = (options.plugins.split(',')
- if len(options.plugins) > 0 else [])
+ plugin_list = (
+ options.plugins.split(",") if len(options.plugins) > 0 else []
+ )
else:
- plugin_list = config['plugins'].as_str_seq()
+ plugin_list = config["plugins"].as_str_seq()
+
+ # Exclude any plugins that were specified on the command line
+ if options.exclude is not None:
+ plugin_list = [
+ p for p in plugin_list if p not in options.exclude.split(",")
+ ]
plugins.load_plugins(plugin_list)
return plugins
@@ -1171,12 +1701,11 @@ def _setup(options, lib=None):
def _configure(options):
- """Amend the global configuration object with command line options.
- """
+ """Amend the global configuration object with command line options."""
# Add any additional config files specified with --config. This
# special handling lets specified plugins get loaded before we
# finish parsing the command line.
- if getattr(options, 'config', None) is not None:
+ if getattr(options, "config", None) is not None:
overlay_path = options.config
del options.config
config.set_file(overlay_path)
@@ -1185,50 +1714,67 @@ def _configure(options):
config.set_args(options)
# Configure the logger.
- if config['verbose'].get(int):
+ if config["verbose"].get(int):
log.set_global_level(logging.DEBUG)
else:
log.set_global_level(logging.INFO)
if overlay_path:
- log.debug('overlaying configuration: {0}',
- util.displayable_path(overlay_path))
+ log.debug(
+ "overlaying configuration: {0}", util.displayable_path(overlay_path)
+ )
config_path = config.user_config_path()
if os.path.isfile(config_path):
- log.debug('user configuration: {0}',
- util.displayable_path(config_path))
+ log.debug("user configuration: {0}", util.displayable_path(config_path))
else:
- log.debug('no user configuration found at {0}',
- util.displayable_path(config_path))
+ log.debug(
+ "no user configuration found at {0}",
+ util.displayable_path(config_path),
+ )
- log.debug('data directory: {0}',
- util.displayable_path(config.config_dir()))
+ log.debug("data directory: {0}", util.displayable_path(config.config_dir()))
return config
+def _ensure_db_directory_exists(path):
+ if path == b":memory:": # in memory db
+ return
+ newpath = os.path.dirname(path)
+ if not os.path.isdir(newpath):
+ if input_yn(
+ "The database directory {} does not \
+ exist. Create it (Y/n)?".format(
+ util.displayable_path(newpath)
+ )
+ ):
+ os.makedirs(newpath)
+
+
def _open_library(config):
- """Create a new library instance from the configuration.
- """
- dbpath = util.bytestring_path(config['library'].as_filename())
+ """Create a new library instance from the configuration."""
+ dbpath = util.bytestring_path(config["library"].as_filename())
+ _ensure_db_directory_exists(dbpath)
try:
lib = library.Library(
dbpath,
- config['directory'].as_filename(),
+ config["directory"].as_filename(),
get_path_formats(),
get_replacements(),
)
lib.get_item(0) # Test database connection.
except (sqlite3.OperationalError, sqlite3.DatabaseError) as db_error:
- log.debug('{}', traceback.format_exc())
- raise UserError("database file {} cannot not be opened: {}".format(
- util.displayable_path(dbpath),
- db_error
- ))
- log.debug('library database: {0}\n'
- 'library directory: {1}',
- util.displayable_path(lib.path),
- util.displayable_path(lib.directory))
+ log.debug("{}", traceback.format_exc())
+ raise UserError(
+ "database file {} cannot not be opened: {}".format(
+ util.displayable_path(dbpath), db_error
+ )
+ )
+ log.debug(
+ "library database: {0}\n" "library directory: {1}",
+ util.displayable_path(lib.path),
+ util.displayable_path(lib.directory),
+ )
return lib
@@ -1237,31 +1783,65 @@ def _raw_main(args, lib=None):
handling.
"""
parser = SubcommandsOptionParser()
- parser.add_format_option(flags=('--format-item',), target=library.Item)
- parser.add_format_option(flags=('--format-album',), target=library.Album)
- parser.add_option('-l', '--library', dest='library',
- help='library database file to use')
- parser.add_option('-d', '--directory', dest='directory',
- help="destination music directory")
- parser.add_option('-v', '--verbose', dest='verbose', action='count',
- help='log more details (use twice for even more)')
- parser.add_option('-c', '--config', dest='config',
- help='path to configuration file')
- parser.add_option('-p', '--plugins', dest='plugins',
- help='a comma-separated list of plugins to load')
- parser.add_option('-h', '--help', dest='help', action='store_true',
- help='show this help message and exit')
- parser.add_option('--version', dest='version', action='store_true',
- help=optparse.SUPPRESS_HELP)
+ parser.add_format_option(flags=("--format-item",), target=library.Item)
+ parser.add_format_option(flags=("--format-album",), target=library.Album)
+ parser.add_option(
+ "-l", "--library", dest="library", help="library database file to use"
+ )
+ parser.add_option(
+ "-d",
+ "--directory",
+ dest="directory",
+ help="destination music directory",
+ )
+ parser.add_option(
+ "-v",
+ "--verbose",
+ dest="verbose",
+ action="count",
+ help="log more details (use twice for even more)",
+ )
+ parser.add_option(
+ "-c", "--config", dest="config", help="path to configuration file"
+ )
+ parser.add_option(
+ "-p",
+ "--plugins",
+ dest="plugins",
+ help="a comma-separated list of plugins to load",
+ )
+ parser.add_option(
+ "-P",
+ "--disable-plugins",
+ dest="exclude",
+ help="a comma-separated list of plugins to disable",
+ )
+ parser.add_option(
+ "-h",
+ "--help",
+ dest="help",
+ action="store_true",
+ help="show this help message and exit",
+ )
+ parser.add_option(
+ "--version",
+ dest="version",
+ action="store_true",
+ help=optparse.SUPPRESS_HELP,
+ )
options, subargs = parser.parse_global_options(args)
# Special case for the `config --edit` command: bypass _setup so
# that an invalid configuration does not prevent the editor from
# starting.
- if subargs and subargs[0] == 'config' \
- and ('-e' in subargs or '--edit' in subargs):
+ if (
+ subargs
+ and subargs[0] == "config"
+ and ("-e" in subargs or "--edit" in subargs)
+ ):
from beets.ui.commands import config_edit
+
return config_edit()
test_lib = bool(lib)
@@ -1271,7 +1851,7 @@ def _raw_main(args, lib=None):
subcommand, suboptions, subargs = parser.parse_subcommand(subargs)
subcommand.func(lib, suboptions, subargs)
- plugins.send('cli_exit', lib=lib)
+ plugins.send("cli_exit", lib=lib)
if not test_lib:
# Clean up the library unless it came from the test harness.
lib._close()
@@ -1285,7 +1865,7 @@ def main(args=None):
_raw_main(args)
except UserError as exc:
message = exc.args[0] if exc.args else None
- log.error('error: {0}', message)
+ log.error("error: {0}", message)
sys.exit(1)
except util.HumanReadableException as exc:
exc.log(log)
@@ -1293,14 +1873,14 @@ def main(args=None):
except library.FileOperationError as exc:
# These errors have reasonable human-readable descriptions, but
# we still want to log their tracebacks for debugging.
- log.debug('{}', traceback.format_exc())
- log.error('{}', exc)
+ log.debug("{}", traceback.format_exc())
+ log.error("{}", exc)
sys.exit(1)
except confuse.ConfigError as exc:
- log.error('configuration error: {0}', exc)
+ log.error("configuration error: {0}", exc)
sys.exit(1)
except db_query.InvalidQueryError as exc:
- log.error('invalid query: {0}', exc)
+ log.error("invalid query: {0}", exc)
sys.exit(1)
except OSError as exc:
if exc.errno == errno.EPIPE:
@@ -1310,11 +1890,11 @@ def main(args=None):
raise
except KeyboardInterrupt:
# Silently ignore ^C except in verbose mode.
- log.debug('{}', traceback.format_exc())
+ log.debug("{}", traceback.format_exc())
except db.DBAccessError as exc:
log.error(
- 'database access error: {0}\n'
- 'the library file might have a permissions problem',
- exc
+ "database access error: {0}\n"
+ "the library file might have a permissions problem",
+ exc,
)
sys.exit(1)
diff --git a/lib/beets/ui/commands.py b/lib/beets/ui/commands.py
index 3a337401..826dc07a 100755
--- a/lib/beets/ui/commands.py
+++ b/lib/beets/ui/commands.py
@@ -19,32 +19,38 @@ interface.
import os
import re
-from platform import python_version
-from collections import namedtuple, Counter
+from collections import Counter, namedtuple
from itertools import chain
+from platform import python_version
+from typing import Sequence
import beets
-from beets import ui
-from beets.ui import print_, input_, decargs, show_path_changes
-from beets import autotag
-from beets.autotag import Recommendation
-from beets.autotag import hooks
-from beets import plugins
-from beets import importer
-from beets import util
-from beets.util import syspath, normpath, ancestry, displayable_path, \
- MoveOperation
-from beets import library
-from beets import config
-from beets import logging
+from beets import autotag, config, importer, library, logging, plugins, ui, util
+from beets.autotag import Recommendation, hooks
+from beets.ui import (
+ decargs,
+ input_,
+ print_,
+ print_column_layout,
+ print_newline_layout,
+ show_path_changes,
+)
+from beets.util import (
+ MoveOperation,
+ ancestry,
+ displayable_path,
+ functemplate,
+ normpath,
+ syspath,
+)
from . import _store_dict
-VARIOUS_ARTISTS = 'Various Artists'
-PromptChoice = namedtuple('PromptChoice', ['short', 'long', 'callback'])
+VARIOUS_ARTISTS = "Various Artists"
+PromptChoice = namedtuple("PromptChoice", ["short", "long", "callback"])
# Global logger.
-log = logging.getLogger('beets')
+log = logging.getLogger("beets")
# The list of default subcommands. This is populated with Subcommand
# objects that can be fed to a SubcommandsOptionParser.
@@ -53,6 +59,7 @@ default_commands = []
# Utilities.
+
def _do_query(lib, query, album, also_items=True):
"""For commands that operate on matched items, performs a query
and returns a list of matching items and a list of matching
@@ -72,27 +79,67 @@ def _do_query(lib, query, album, also_items=True):
items = list(lib.items(query))
if album and not albums:
- raise ui.UserError('No matching albums found.')
+ raise ui.UserError("No matching albums found.")
elif not album and not items:
- raise ui.UserError('No matching items found.')
+ raise ui.UserError("No matching items found.")
return items, albums
+def _paths_from_logfile(path):
+ """Parse the logfile and yield skipped paths to pass to the `import`
+ command.
+ """
+ with open(path, encoding="utf-8") as fp:
+ for i, line in enumerate(fp, start=1):
+ verb, sep, paths = line.rstrip("\n").partition(" ")
+ if not sep:
+ raise ValueError(f"line {i} is invalid")
+
+ # Ignore informational lines that don't need to be re-imported.
+ if verb in {"import", "duplicate-keep", "duplicate-replace"}:
+ continue
+
+ if verb not in {"asis", "skip", "duplicate-skip"}:
+ raise ValueError(f"line {i} contains unknown verb {verb}")
+
+ yield os.path.commonpath(paths.split("; "))
+
+
+def _parse_logfiles(logfiles):
+ """Parse all `logfiles` and yield paths from it."""
+ for logfile in logfiles:
+ try:
+ yield from _paths_from_logfile(syspath(normpath(logfile)))
+ except ValueError as err:
+ raise ui.UserError(
+ "malformed logfile {}: {}".format(
+ util.displayable_path(logfile), str(err)
+ )
+ ) from err
+ except OSError as err:
+ raise ui.UserError(
+ "unreadable logfile {}: {}".format(
+ util.displayable_path(logfile), str(err)
+ )
+ ) from err
+
+
# fields: Shows a list of available fields for queries and format strings.
+
def _print_keys(query):
"""Given a SQLite query result, print the `key` field of each
returned row, with indentation of 2 spaces.
"""
for row in query:
- print_(' ' * 2 + row['key'])
+ print_(" " * 2 + row["key"])
def fields_func(lib, opts, args):
def _print_rows(names):
names.sort()
- print_(' ' + '\n '.join(names))
+ print_(" " + "\n ".join(names))
print_("Item fields:")
_print_rows(library.Item.all_keys())
@@ -102,7 +149,7 @@ def fields_func(lib, opts, args):
with lib.transaction() as tx:
# The SQL uses the DISTINCT to get unique values from the query
- unique_fields = 'SELECT DISTINCT key FROM (%s)'
+ unique_fields = "SELECT DISTINCT key FROM (%s)"
print_("Item flexible attributes:")
_print_keys(tx.query(unique_fields % library.Item._flex_table))
@@ -112,8 +159,7 @@ def fields_func(lib, opts, args):
fields_cmd = ui.Subcommand(
- 'fields',
- help='show fields available for queries and format strings'
+ "fields", help="show fields available for queries and format strings"
)
fields_cmd.func = fields_func
default_commands.append(fields_cmd)
@@ -121,12 +167,13 @@ default_commands.append(fields_cmd)
# help: Print help text for commands
-class HelpCommand(ui.Subcommand):
+class HelpCommand(ui.Subcommand):
def __init__(self):
super().__init__(
- 'help', aliases=('?',),
- help='give detailed help on a specific sub-command',
+ "help",
+ aliases=("?",),
+ help="give detailed help on a specific sub-command",
)
def func(self, lib, opts, args):
@@ -147,50 +194,92 @@ default_commands.append(HelpCommand())
# Importer utilities and support.
+
def disambig_string(info):
"""Generate a string for an AlbumInfo or TrackInfo object that
provides context that helps disambiguate similar-looking albums and
tracks.
"""
- disambig = []
- if info.data_source and info.data_source != 'MusicBrainz':
- disambig.append(info.data_source)
-
if isinstance(info, hooks.AlbumInfo):
- if info.media:
- if info.mediums and info.mediums > 1:
- disambig.append('{}x{}'.format(
- info.mediums, info.media
- ))
- else:
- disambig.append(info.media)
- if info.year:
- disambig.append(str(info.year))
- if info.country:
- disambig.append(info.country)
- if info.label:
- disambig.append(info.label)
- if info.catalognum:
- disambig.append(info.catalognum)
- if info.albumdisambig:
- disambig.append(info.albumdisambig)
+ disambig = get_album_disambig_fields(info)
+ elif isinstance(info, hooks.TrackInfo):
+ disambig = get_singleton_disambig_fields(info)
+ else:
+ return ""
- if disambig:
- return ', '.join(disambig)
+ return ", ".join(disambig)
+
+
+def get_singleton_disambig_fields(info: hooks.TrackInfo) -> Sequence[str]:
+ out = []
+ chosen_fields = config["match"]["singleton_disambig_fields"].as_str_seq()
+ calculated_values = {
+ "index": "Index {}".format(str(info.index)),
+ "track_alt": "Track {}".format(info.track_alt),
+ "album": (
+ "[{}]".format(info.album)
+ if (
+ config["import"]["singleton_album_disambig"].get()
+ and info.get("album")
+ )
+ else ""
+ ),
+ }
+
+ for field in chosen_fields:
+ if field in calculated_values:
+ out.append(str(calculated_values[field]))
+ else:
+ try:
+ out.append(str(info[field]))
+ except (AttributeError, KeyError):
+ print(f"Disambiguation string key {field} does not exist.")
+
+ return out
+
+
+def get_album_disambig_fields(info: hooks.AlbumInfo) -> Sequence[str]:
+ out = []
+ chosen_fields = config["match"]["album_disambig_fields"].as_str_seq()
+ calculated_values = {
+ "media": (
+ "{}x{}".format(info.mediums, info.media)
+ if (info.mediums and info.mediums > 1)
+ else info.media
+ ),
+ }
+
+ for field in chosen_fields:
+ if field in calculated_values:
+ out.append(str(calculated_values[field]))
+ else:
+ try:
+ out.append(str(info[field]))
+ except (AttributeError, KeyError):
+ print(f"Disambiguation string key {field} does not exist.")
+
+ return out
+
+
+def dist_colorize(string, dist):
+ """Formats a string as a colorized similarity string according to
+ a distance.
+ """
+ if dist <= config["match"]["strong_rec_thresh"].as_number():
+ string = ui.colorize("text_success", string)
+ elif dist <= config["match"]["medium_rec_thresh"].as_number():
+ string = ui.colorize("text_warning", string)
+ else:
+ string = ui.colorize("text_error", string)
+ return string
def dist_string(dist):
"""Formats a distance (a float) as a colorized similarity percentage
string.
"""
- out = '%.1f%%' % ((1 - dist) * 100)
- if dist <= config['match']['strong_rec_thresh'].as_number():
- out = ui.colorize('text_success', out)
- elif dist <= config['match']['medium_rec_thresh'].as_number():
- out = ui.colorize('text_warning', out)
- else:
- out = ui.colorize('text_error', out)
- return out
+ string = "{:.1f}%".format(((1 - dist) * 100))
+ return dist_colorize(string, dist)
def penalty_string(distance, limit=None):
@@ -199,31 +288,185 @@ def penalty_string(distance, limit=None):
"""
penalties = []
for key in distance.keys():
- key = key.replace('album_', '')
- key = key.replace('track_', '')
- key = key.replace('_', ' ')
+ key = key.replace("album_", "")
+ key = key.replace("track_", "")
+ key = key.replace("_", " ")
penalties.append(key)
if penalties:
if limit and len(penalties) > limit:
- penalties = penalties[:limit] + ['...']
- return ui.colorize('text_warning', '(%s)' % ', '.join(penalties))
+ penalties = penalties[:limit] + ["..."]
+ # Prefix penalty string with U+2260: Not Equal To
+ penalty_string = "\u2260 {}".format(", ".join(penalties))
+ return ui.colorize("changed", penalty_string)
-def show_change(cur_artist, cur_album, match):
- """Print out a representation of the changes that will be made if an
- album's tags are changed according to `match`, which must be an AlbumMatch
- object.
+class ChangeRepresentation:
+ """Keeps track of all information needed to generate a (colored) text
+ representation of the changes that will be made if an album or singleton's
+ tags are changed according to `match`, which must be an AlbumMatch or
+ TrackMatch object, accordingly.
"""
- def show_album(artist, album):
- if artist:
- album_description = f' {artist} - {album}'
- elif album:
- album_description = ' %s' % album
- else:
- album_description = ' (unknown album)'
- print_(album_description)
- def format_index(track_info):
+ cur_artist = None
+ # cur_album set if album, cur_title set if singleton
+ cur_album = None
+ cur_title = None
+ match = None
+ indent_header = ""
+ indent_detail = ""
+
+ def __init__(self):
+ # Read match header indentation width from config.
+ match_header_indent_width = config["ui"]["import"]["indentation"][
+ "match_header"
+ ].as_number()
+ self.indent_header = ui.indent(match_header_indent_width)
+
+ # Read match detail indentation width from config.
+ match_detail_indent_width = config["ui"]["import"]["indentation"][
+ "match_details"
+ ].as_number()
+ self.indent_detail = ui.indent(match_detail_indent_width)
+
+ # Read match tracklist indentation width from config
+ match_tracklist_indent_width = config["ui"]["import"]["indentation"][
+ "match_tracklist"
+ ].as_number()
+ self.indent_tracklist = ui.indent(match_tracklist_indent_width)
+ self.layout = config["ui"]["import"]["layout"].as_choice(
+ {
+ "column": 0,
+ "newline": 1,
+ }
+ )
+
+ def print_layout(
+ self, indent, left, right, separator=" -> ", max_width=None
+ ):
+ if not max_width:
+ # If no max_width provided, use terminal width
+ max_width = ui.term_width()
+ if self.layout == 0:
+ print_column_layout(indent, left, right, separator, max_width)
+ else:
+ print_newline_layout(indent, left, right, separator, max_width)
+
+ def show_match_header(self):
+ """Print out a 'header' identifying the suggested match (album name,
+ artist name,...) and summarizing the changes that would be made should
+ the user accept the match.
+ """
+ # Print newline at beginning of change block.
+ print_("")
+
+ # 'Match' line and similarity.
+ print_(
+ self.indent_header + f"Match ({dist_string(self.match.distance)}):"
+ )
+
+ if self.match.info.get("album"):
+ # Matching an album - print that
+ artist_album_str = (
+ f"{self.match.info.artist}" + f" - {self.match.info.album}"
+ )
+ else:
+ # Matching a single track
+ artist_album_str = (
+ f"{self.match.info.artist}" + f" - {self.match.info.title}"
+ )
+ print_(
+ self.indent_header
+ + dist_colorize(artist_album_str, self.match.distance)
+ )
+
+ # Penalties.
+ penalties = penalty_string(self.match.distance)
+ if penalties:
+ print_(self.indent_header + penalties)
+
+ # Disambiguation.
+ disambig = disambig_string(self.match.info)
+ if disambig:
+ print_(self.indent_header + disambig)
+
+ # Data URL.
+ if self.match.info.data_url:
+ url = ui.colorize("text_faint", f"{self.match.info.data_url}")
+ print_(self.indent_header + url)
+
+ def show_match_details(self):
+ """Print out the details of the match, including changes in album name
+ and artist name.
+ """
+ # Artist.
+ artist_l, artist_r = self.cur_artist or "", self.match.info.artist
+ if artist_r == VARIOUS_ARTISTS:
+ # Hide artists for VA releases.
+ artist_l, artist_r = "", ""
+ if artist_l != artist_r:
+ artist_l, artist_r = ui.colordiff(artist_l, artist_r)
+ # Prefix with U+2260: Not Equal To
+ left = {
+ "prefix": ui.colorize("changed", "\u2260") + " Artist: ",
+ "contents": artist_l,
+ "suffix": "",
+ }
+ right = {"prefix": "", "contents": artist_r, "suffix": ""}
+ self.print_layout(self.indent_detail, left, right)
+
+ else:
+ print_(self.indent_detail + "*", "Artist:", artist_r)
+
+ if self.cur_album:
+ # Album
+ album_l, album_r = self.cur_album or "", self.match.info.album
+ if (
+ self.cur_album != self.match.info.album
+ and self.match.info.album != VARIOUS_ARTISTS
+ ):
+ album_l, album_r = ui.colordiff(album_l, album_r)
+ # Prefix with U+2260: Not Equal To
+ left = {
+ "prefix": ui.colorize("changed", "\u2260") + " Album: ",
+ "contents": album_l,
+ "suffix": "",
+ }
+ right = {"prefix": "", "contents": album_r, "suffix": ""}
+ self.print_layout(self.indent_detail, left, right)
+ else:
+ print_(self.indent_detail + "*", "Album:", album_r)
+ elif self.cur_title:
+ # Title - for singletons
+ title_l, title_r = self.cur_title or "", self.match.info.title
+ if self.cur_title != self.match.info.title:
+ title_l, title_r = ui.colordiff(title_l, title_r)
+ # Prefix with U+2260: Not Equal To
+ left = {
+ "prefix": ui.colorize("changed", "\u2260") + " Title: ",
+ "contents": title_l,
+ "suffix": "",
+ }
+ right = {"prefix": "", "contents": title_r, "suffix": ""}
+ self.print_layout(self.indent_detail, left, right)
+ else:
+ print_(self.indent_detail + "*", "Title:", title_r)
+
+ def make_medium_info_line(self, track_info):
+ """Construct a line with the current medium's info."""
+ track_media = track_info.get("media", "Media")
+ # Build output string.
+ if self.match.info.mediums > 1 and track_info.disctitle:
+ return (
+ f"* {track_media} {track_info.medium}: {track_info.disctitle}"
+ )
+ elif self.match.info.mediums > 1:
+ return f"* {track_media} {track_info.medium}"
+ elif track_info.disctitle:
+ return f"* {track_media}: {track_info.disctitle}"
+ else:
+ return ""
+
+ def format_index(self, track_info):
"""Return a string representing the track index of the given
TrackInfo or Item object.
"""
@@ -231,209 +474,290 @@ def show_change(cur_artist, cur_album, match):
index = track_info.index
medium_index = track_info.medium_index
medium = track_info.medium
- mediums = match.info.mediums
+ mediums = self.match.info.mediums
else:
index = medium_index = track_info.track
medium = track_info.disc
mediums = track_info.disctotal
- if config['per_disc_numbering']:
+ if config["per_disc_numbering"]:
if mediums and mediums > 1:
- return f'{medium}-{medium_index}'
+ return f"{medium}-{medium_index}"
else:
- return str(medium_index if medium_index is not None
- else index)
+ return str(medium_index if medium_index is not None else index)
else:
return str(index)
- # Identify the album in question.
- if cur_artist != match.info.artist or \
- (cur_album != match.info.album and
- match.info.album != VARIOUS_ARTISTS):
- artist_l, artist_r = cur_artist or '', match.info.artist
- album_l, album_r = cur_album or '', match.info.album
- if artist_r == VARIOUS_ARTISTS:
- # Hide artists for VA releases.
- artist_l, artist_r = '', ''
-
- if config['artist_credit']:
- artist_r = match.info.artist_credit
-
- artist_l, artist_r = ui.colordiff(artist_l, artist_r)
- album_l, album_r = ui.colordiff(album_l, album_r)
-
- print_("Correcting tags from:")
- show_album(artist_l, album_l)
- print_("To:")
- show_album(artist_r, album_r)
- else:
- print_("Tagging:\n {0.artist} - {0.album}".format(match.info))
-
- # Data URL.
- if match.info.data_url:
- print_('URL:\n %s' % match.info.data_url)
-
- # Info line.
- info = []
- # Similarity.
- info.append('(Similarity: %s)' % dist_string(match.distance))
- # Penalties.
- penalties = penalty_string(match.distance)
- if penalties:
- info.append(penalties)
- # Disambiguation.
- disambig = disambig_string(match.info)
- if disambig:
- info.append(ui.colorize('text_highlight_minor', '(%s)' % disambig))
- print_(' '.join(info))
-
- # Tracks.
- pairs = list(match.mapping.items())
- pairs.sort(key=lambda item_and_track_info: item_and_track_info[1].index)
-
- # Build up LHS and RHS for track difference display. The `lines` list
- # contains ``(lhs, rhs, width)`` tuples where `width` is the length (in
- # characters) of the uncolorized LHS.
- lines = []
- medium = disctitle = None
- for item, track_info in pairs:
-
- # Medium number and title.
- if medium != track_info.medium or disctitle != track_info.disctitle:
- media = match.info.media or 'Media'
- if match.info.mediums > 1 and track_info.disctitle:
- lhs = '{} {}: {}'.format(media, track_info.medium,
- track_info.disctitle)
- elif match.info.mediums > 1:
- lhs = f'{media} {track_info.medium}'
- elif track_info.disctitle:
- lhs = f'{media}: {track_info.disctitle}'
+ def make_track_numbers(self, item, track_info):
+ """Format colored track indices."""
+ cur_track = self.format_index(item)
+ new_track = self.format_index(track_info)
+ templ = "(#{})"
+ changed = False
+ # Choose color based on change.
+ if cur_track != new_track:
+ changed = True
+ if item.track in (track_info.index, track_info.medium_index):
+ highlight_color = "text_highlight_minor"
else:
- lhs = None
- if lhs:
- lines.append((lhs, '', 0))
- medium, disctitle = track_info.medium, track_info.disctitle
+ highlight_color = "text_highlight"
+ else:
+ highlight_color = "text_faint"
- # Titles.
+ cur_track = templ.format(cur_track)
+ new_track = templ.format(new_track)
+ lhs_track = ui.colorize(highlight_color, cur_track)
+ rhs_track = ui.colorize(highlight_color, new_track)
+ return lhs_track, rhs_track, changed
+
+ @staticmethod
+ def make_track_titles(item, track_info):
+ """Format colored track titles."""
new_title = track_info.title
if not item.title.strip():
- # If there's no title, we use the filename.
+ # If there's no title, we use the filename. Don't colordiff.
cur_title = displayable_path(os.path.basename(item.path))
- lhs, rhs = cur_title, new_title
+ return cur_title, new_title, True
else:
+ # If there is a title, highlight differences.
cur_title = item.title.strip()
- lhs, rhs = ui.colordiff(cur_title, new_title)
- lhs_width = len(cur_title)
+ cur_col, new_col = ui.colordiff(cur_title, new_title)
+ return cur_col, new_col, cur_title != new_title
+ @staticmethod
+ def make_track_lengths(item, track_info):
+ """Format colored track lengths."""
+ changed = False
+ if (
+ item.length
+ and track_info.length
+ and abs(item.length - track_info.length)
+ >= config["ui"]["length_diff_thresh"].as_number()
+ ):
+ highlight_color = "text_highlight"
+ changed = True
+ else:
+ highlight_color = "text_highlight_minor"
+
+ # Handle nonetype lengths by setting to 0
+ cur_length0 = item.length if item.length else 0
+ new_length0 = track_info.length if track_info.length else 0
+ # format into string
+ cur_length = f"({ui.human_seconds_short(cur_length0)})"
+ new_length = f"({ui.human_seconds_short(new_length0)})"
+ # colorize
+ lhs_length = ui.colorize(highlight_color, cur_length)
+ rhs_length = ui.colorize(highlight_color, new_length)
+
+ return lhs_length, rhs_length, changed
+
+ def make_line(self, item, track_info):
+ """Extract changes from item -> new TrackInfo object, and colorize
+ appropriately. Returns (lhs, rhs) for column printing.
+ """
+ # Track titles.
+ lhs_title, rhs_title, diff_title = self.make_track_titles(
+ item, track_info
+ )
# Track number change.
- cur_track, new_track = format_index(item), format_index(track_info)
- if cur_track != new_track:
- if item.track in (track_info.index, track_info.medium_index):
- color = 'text_highlight_minor'
- else:
- color = 'text_highlight'
- templ = ui.colorize(color, ' (#{0})')
- lhs += templ.format(cur_track)
- rhs += templ.format(new_track)
- lhs_width += len(cur_track) + 4
-
+ lhs_track, rhs_track, diff_track = self.make_track_numbers(
+ item, track_info
+ )
# Length change.
- if item.length and track_info.length and \
- abs(item.length - track_info.length) > \
- config['ui']['length_diff_thresh'].as_number():
- cur_length = ui.human_seconds_short(item.length)
- new_length = ui.human_seconds_short(track_info.length)
- templ = ui.colorize('text_highlight', ' ({0})')
- lhs += templ.format(cur_length)
- rhs += templ.format(new_length)
- lhs_width += len(cur_length) + 3
+ lhs_length, rhs_length, diff_length = self.make_track_lengths(
+ item, track_info
+ )
- # Penalties.
- penalties = penalty_string(match.distance.tracks[track_info])
- if penalties:
- rhs += ' %s' % penalties
+ changed = diff_title or diff_track or diff_length
- if lhs != rhs:
- lines.append((' * %s' % lhs, rhs, lhs_width))
- elif config['import']['detail']:
- lines.append((' * %s' % lhs, '', lhs_width))
+ # Construct lhs and rhs dicts.
+ # Previously, we printed the penalties, however this is no longer
+ # the case, thus the 'info' dictionary is unneeded.
+ # penalties = penalty_string(self.match.distance.tracks[track_info])
- # Print each track in two columns, or across two lines.
- col_width = (ui.term_width() - len(''.join([' * ', ' -> ']))) // 2
- if lines:
- max_width = max(w for _, _, w in lines)
- for lhs, rhs, lhs_width in lines:
- if not rhs:
- print_(lhs)
- elif max_width > col_width:
- print_(f'{lhs} ->\n {rhs}')
+ prefix = ui.colorize("changed", "\u2260 ") if changed else "* "
+ lhs = {
+ "prefix": prefix + lhs_track + " ",
+ "contents": lhs_title,
+ "suffix": " " + lhs_length,
+ }
+ rhs = {"prefix": "", "contents": "", "suffix": ""}
+ if not changed:
+ # Only return the left side, as nothing changed.
+ return (lhs, rhs)
+ else:
+ # Construct a dictionary for the "changed to" side
+ rhs = {
+ "prefix": rhs_track + " ",
+ "contents": rhs_title,
+ "suffix": " " + rhs_length,
+ }
+ return (lhs, rhs)
+
+ def print_tracklist(self, lines):
+ """Calculates column widths for tracks stored as line tuples:
+ (left, right). Then prints each line of tracklist.
+ """
+ if len(lines) == 0:
+ # If no lines provided, e.g. details not required, do nothing.
+ return
+
+ def get_width(side):
+ """Return the width of left or right in uncolorized characters."""
+ try:
+ return len(
+ ui.uncolorize(
+ " ".join(
+ [side["prefix"], side["contents"], side["suffix"]]
+ )
+ )
+ )
+ except KeyError:
+ # An empty dictionary -> Nothing to report
+ return 0
+
+ # Check how to fit content into terminal window
+ indent_width = len(self.indent_tracklist)
+ terminal_width = ui.term_width()
+ joiner_width = len("".join(["* ", " -> "]))
+ col_width = (terminal_width - indent_width - joiner_width) // 2
+ max_width_l = max(get_width(line_tuple[0]) for line_tuple in lines)
+ max_width_r = max(get_width(line_tuple[1]) for line_tuple in lines)
+
+ if (
+ (max_width_l <= col_width)
+ and (max_width_r <= col_width)
+ or (
+ ((max_width_l > col_width) or (max_width_r > col_width))
+ and ((max_width_l + max_width_r) <= col_width * 2)
+ )
+ ):
+ # All content fits. Either both maximum widths are below column
+ # widths, or one of the columns is larger than allowed but the
+ # other is smaller than allowed.
+ # In this case we can afford to shrink the columns to fit their
+ # largest string
+ col_width_l = max_width_l
+ col_width_r = max_width_r
+ else:
+ # Not all content fits - stick with original half/half split
+ col_width_l = col_width
+ col_width_r = col_width
+
+ # Print out each line, using the calculated width from above.
+ for left, right in lines:
+ left["width"] = col_width_l
+ right["width"] = col_width_r
+ self.print_layout(self.indent_tracklist, left, right)
+
+
+class AlbumChange(ChangeRepresentation):
+ """Album change representation, setting cur_album"""
+
+ def __init__(self, cur_artist, cur_album, match):
+ super().__init__()
+ self.cur_artist = cur_artist
+ self.cur_album = cur_album
+ self.match = match
+
+ def show_match_tracks(self):
+ """Print out the tracks of the match, summarizing changes the match
+ suggests for them.
+ """
+ # Tracks.
+ # match is an AlbumMatch named tuple, mapping is a dict
+ # Sort the pairs by the track_info index (at index 1 of the namedtuple)
+ pairs = list(self.match.mapping.items())
+ pairs.sort(key=lambda item_and_track_info: item_and_track_info[1].index)
+ # Build up LHS and RHS for track difference display. The `lines` list
+ # contains `(left, right)` tuples.
+ lines = []
+ medium = disctitle = None
+ for item, track_info in pairs:
+ # If the track is the first on a new medium, show medium
+ # number and title.
+ if medium != track_info.medium or disctitle != track_info.disctitle:
+ # Create header for new medium
+ header = self.make_medium_info_line(track_info)
+ if header != "":
+ # Print tracks from previous medium
+ self.print_tracklist(lines)
+ lines = []
+ print_(self.indent_detail + header)
+ # Save new medium details for future comparison.
+ medium, disctitle = track_info.medium, track_info.disctitle
+
+ # Construct the line tuple for the track.
+ left, right = self.make_line(item, track_info)
+ if right["contents"] != "":
+ lines.append((left, right))
else:
- pad = max_width - lhs_width
- print_('{}{} -> {}'.format(lhs, ' ' * pad, rhs))
+ if config["import"]["detail"]:
+ lines.append((left, right))
+ self.print_tracklist(lines)
- # Missing and unmatched tracks.
- if match.extra_tracks:
- print_('Missing tracks ({}/{} - {:.1%}):'.format(
- len(match.extra_tracks),
- len(match.info.tracks),
- len(match.extra_tracks) / len(match.info.tracks)
- ))
- pad_width = max(len(track_info.title) for track_info in
- match.extra_tracks)
- for track_info in match.extra_tracks:
- line = ' ! {0: <{width}} (#{1: >2})'.format(track_info.title,
- format_index(track_info),
- width=pad_width)
- if track_info.length:
- line += ' (%s)' % ui.human_seconds_short(track_info.length)
- print_(ui.colorize('text_warning', line))
- if match.extra_items:
- print_('Unmatched tracks ({}):'.format(len(match.extra_items)))
- pad_width = max(len(item.title) for item in match.extra_items)
- for item in match.extra_items:
- line = ' ! {0: <{width}} (#{1: >2})'.format(item.title,
- format_index(item),
- width=pad_width)
- if item.length:
- line += ' (%s)' % ui.human_seconds_short(item.length)
- print_(ui.colorize('text_warning', line))
+ # Missing and unmatched tracks.
+ if self.match.extra_tracks:
+ print_(
+ "Missing tracks ({0}/{1} - {2:.1%}):".format(
+ len(self.match.extra_tracks),
+ len(self.match.info.tracks),
+ len(self.match.extra_tracks) / len(self.match.info.tracks),
+ )
+ )
+ for track_info in self.match.extra_tracks:
+ line = f" ! {track_info.title} (#{self.format_index(track_info)})"
+ if track_info.length:
+ line += f" ({ui.human_seconds_short(track_info.length)})"
+ print_(ui.colorize("text_warning", line))
+ if self.match.extra_items:
+ print_(f"Unmatched tracks ({len(self.match.extra_items)}):")
+ for item in self.match.extra_items:
+ line = " ! {} (#{})".format(item.title, self.format_index(item))
+ if item.length:
+ line += " ({})".format(ui.human_seconds_short(item.length))
+ print_(ui.colorize("text_warning", line))
+
+
+class TrackChange(ChangeRepresentation):
+ """Track change representation, comparing item with match."""
+
+ def __init__(self, cur_artist, cur_title, match):
+ super().__init__()
+ self.cur_artist = cur_artist
+ self.cur_title = cur_title
+ self.match = match
+
+
+def show_change(cur_artist, cur_album, match):
+ """Print out a representation of the changes that will be made if an
+ album's tags are changed according to `match`, which must be an AlbumMatch
+ object.
+ """
+ change = AlbumChange(
+ cur_artist=cur_artist, cur_album=cur_album, match=match
+ )
+
+ # Print the match header.
+ change.show_match_header()
+
+ # Print the match details.
+ change.show_match_details()
+
+ # Print the match tracks.
+ change.show_match_tracks()
def show_item_change(item, match):
"""Print out the change that would occur by tagging `item` with the
metadata from `match`, a TrackMatch object.
"""
- cur_artist, new_artist = item.artist, match.info.artist
- cur_title, new_title = item.title, match.info.title
-
- if cur_artist != new_artist or cur_title != new_title:
- cur_artist, new_artist = ui.colordiff(cur_artist, new_artist)
- cur_title, new_title = ui.colordiff(cur_title, new_title)
-
- print_("Correcting track tags from:")
- print_(f" {cur_artist} - {cur_title}")
- print_("To:")
- print_(f" {new_artist} - {new_title}")
-
- else:
- print_(f"Tagging track: {cur_artist} - {cur_title}")
-
- # Data URL.
- if match.info.data_url:
- print_('URL:\n %s' % match.info.data_url)
-
- # Info line.
- info = []
- # Similarity.
- info.append('(Similarity: %s)' % dist_string(match.distance))
- # Penalties.
- penalties = penalty_string(match.distance)
- if penalties:
- info.append(penalties)
- # Disambiguation.
- disambig = disambig_string(match.info)
- if disambig:
- info.append(ui.colorize('text_highlight_minor', '(%s)' % disambig))
- print_(' '.join(info))
+ change = TrackChange(
+ cur_artist=item.artist, cur_title=item.title, match=match
+ )
+ # Print the match header.
+ change.show_match_header()
+ # Print the match details.
+ change.show_match_details()
def summarize_items(items, singleton):
@@ -458,23 +782,24 @@ def summarize_items(items, singleton):
# Enumerate all the formats by decreasing frequencies:
for fmt, count in sorted(
format_counts.items(),
- key=lambda fmt_and_count: (-fmt_and_count[1], fmt_and_count[0])
+ key=lambda fmt_and_count: (-fmt_and_count[1], fmt_and_count[0]),
):
- summary_parts.append(f'{fmt} {count}')
+ summary_parts.append(f"{fmt} {count}")
if items:
average_bitrate = sum([item.bitrate for item in items]) / len(items)
total_duration = sum([item.length for item in items])
total_filesize = sum([item.filesize for item in items])
- summary_parts.append('{}kbps'.format(int(average_bitrate / 1000)))
+ summary_parts.append("{}kbps".format(int(average_bitrate / 1000)))
if items[0].format == "FLAC":
- sample_bits = '{}kHz/{} bit'.format(
- round(int(items[0].samplerate) / 1000, 1), items[0].bitdepth)
+ sample_bits = "{}kHz/{} bit".format(
+ round(int(items[0].samplerate) / 1000, 1), items[0].bitdepth
+ )
summary_parts.append(sample_bits)
summary_parts.append(ui.human_seconds_short(total_duration))
summary_parts.append(ui.human_bytes(total_filesize))
- return ', '.join(summary_parts)
+ return ", ".join(summary_parts)
def _summary_judgment(rec):
@@ -485,35 +810,46 @@ def _summary_judgment(rec):
summary judgment is made.
"""
- if config['import']['quiet']:
+ if config["import"]["quiet"]:
if rec == Recommendation.strong:
return importer.action.APPLY
else:
- action = config['import']['quiet_fallback'].as_choice({
- 'skip': importer.action.SKIP,
- 'asis': importer.action.ASIS,
- })
- elif config['import']['timid']:
+ action = config["import"]["quiet_fallback"].as_choice(
+ {
+ "skip": importer.action.SKIP,
+ "asis": importer.action.ASIS,
+ }
+ )
+ elif config["import"]["timid"]:
return None
elif rec == Recommendation.none:
- action = config['import']['none_rec_action'].as_choice({
- 'skip': importer.action.SKIP,
- 'asis': importer.action.ASIS,
- 'ask': None,
- })
+ action = config["import"]["none_rec_action"].as_choice(
+ {
+ "skip": importer.action.SKIP,
+ "asis": importer.action.ASIS,
+ "ask": None,
+ }
+ )
else:
return None
if action == importer.action.SKIP:
- print_('Skipping.')
+ print_("Skipping.")
elif action == importer.action.ASIS:
- print_('Importing as-is.')
+ print_("Importing as-is.")
return action
-def choose_candidate(candidates, singleton, rec, cur_artist=None,
- cur_album=None, item=None, itemcount=None,
- choices=[]):
+def choose_candidate(
+ candidates,
+ singleton,
+ rec,
+ cur_artist=None,
+ cur_album=None,
+ item=None,
+ itemcount=None,
+ choices=[],
+):
"""Given a sorted list of candidates, ask the user for a selection
of which candidate to use. Applies to both full albums and
singletons (tracks). Candidates are either AlbumMatch or TrackMatch
@@ -544,10 +880,11 @@ def choose_candidate(candidates, singleton, rec, cur_artist=None,
if singleton:
print_("No matching recordings found.")
else:
- print_("No matching release found for {} tracks."
- .format(itemcount))
- print_('For help, see: '
- 'https://beets.readthedocs.org/en/latest/faq.html#nomatch')
+ print_("No matching release found for {} tracks.".format(itemcount))
+ print_(
+ "For help, see: "
+ "https://beets.readthedocs.org/en/latest/faq.html#nomatch"
+ )
sel = ui.input_options(choice_opts)
if sel in choice_actions:
return choice_actions[sel]
@@ -566,41 +903,46 @@ def choose_candidate(candidates, singleton, rec, cur_artist=None,
if not bypass_candidates:
# Display list of candidates.
- print_('Finding tags for {} "{} - {}".'.format(
- 'track' if singleton else 'album',
- item.artist if singleton else cur_artist,
- item.title if singleton else cur_album,
- ))
+ print_("")
+ print_(
+ 'Finding tags for {} "{} - {}".'.format(
+ "track" if singleton else "album",
+ item.artist if singleton else cur_artist,
+ item.title if singleton else cur_album,
+ )
+ )
- print_('Candidates:')
+ print_(ui.indent(2) + "Candidates:")
for i, match in enumerate(candidates):
# Index, metadata, and distance.
- line = [
- '{}.'.format(i + 1),
- '{} - {}'.format(
- match.info.artist,
- match.info.title if singleton else match.info.album,
- ),
- '({})'.format(dist_string(match.distance)),
- ]
+ index0 = "{0}.".format(i + 1)
+ index = dist_colorize(index0, match.distance)
+ dist = "({:.1f}%)".format((1 - match.distance) * 100)
+ distance = dist_colorize(dist, match.distance)
+ metadata = "{0} - {1}".format(
+ match.info.artist,
+ match.info.title if singleton else match.info.album,
+ )
+ if i == 0:
+ metadata = dist_colorize(metadata, match.distance)
+ else:
+ metadata = ui.colorize("text_highlight_minor", metadata)
+ line1 = [index, distance, metadata]
+ print_(ui.indent(2) + " ".join(line1))
# Penalties.
penalties = penalty_string(match.distance, 3)
if penalties:
- line.append(penalties)
+ print_(ui.indent(13) + penalties)
# Disambiguation
disambig = disambig_string(match.info)
if disambig:
- line.append(ui.colorize('text_highlight_minor',
- '(%s)' % disambig))
-
- print_(' '.join(line))
+ print_(ui.indent(13) + disambig)
# Ask the user for a choice.
- sel = ui.input_options(choice_opts,
- numrange=(1, len(candidates)))
- if sel == 'm':
+ sel = ui.input_options(choice_opts, numrange=(1, len(candidates)))
+ if sel == "m":
pass
elif sel in choice_actions:
return choice_actions[sel]
@@ -619,24 +961,29 @@ def choose_candidate(candidates, singleton, rec, cur_artist=None,
show_change(cur_artist, cur_album, match)
# Exact match => tag automatically if we're not in timid mode.
- if rec == Recommendation.strong and not config['import']['timid']:
+ if rec == Recommendation.strong and not config["import"]["timid"]:
return match
# Ask for confirmation.
- default = config['import']['default_action'].as_choice({
- 'apply': 'a',
- 'skip': 's',
- 'asis': 'u',
- 'none': None,
- })
+ default = config["import"]["default_action"].as_choice(
+ {
+ "apply": "a",
+ "skip": "s",
+ "asis": "u",
+ "none": None,
+ }
+ )
if default is None:
require = True
# Bell ring when user interaction is needed.
- if config['import']['bell']:
- ui.print_('\a', end='')
- sel = ui.input_options(('Apply', 'More candidates') + choice_opts,
- require=require, default=default)
- if sel == 'a':
+ if config["import"]["bell"]:
+ ui.print_("\a", end="")
+ sel = ui.input_options(
+ ("Apply", "More candidates") + choice_opts,
+ require=require,
+ default=default,
+ )
+ if sel == "a":
return match
elif sel in choice_actions:
return choice_actions[sel]
@@ -648,13 +995,11 @@ def manual_search(session, task):
Input either an artist and album (for full albums) or artist and
track name (for singletons) for manual search.
"""
- artist = input_('Artist:').strip()
- name = input_('Album:' if task.is_album else 'Track:').strip()
+ artist = input_("Artist:").strip()
+ name = input_("Album:" if task.is_album else "Track:").strip()
if task.is_album:
- _, _, prop = autotag.tag_album(
- task.items, artist, name
- )
+ _, _, prop = autotag.tag_album(task.items, artist, name)
return prop
else:
return autotag.tag_item(task.item, artist, name)
@@ -665,28 +1010,23 @@ def manual_id(session, task):
Input an ID, either for an album ("release") or a track ("recording").
"""
- prompt = 'Enter {} ID:'.format('release' if task.is_album
- else 'recording')
+ prompt = "Enter {} ID:".format("release" if task.is_album else "recording")
search_id = input_(prompt).strip()
if task.is_album:
- _, _, prop = autotag.tag_album(
- task.items, search_ids=search_id.split()
- )
+ _, _, prop = autotag.tag_album(task.items, search_ids=search_id.split())
return prop
else:
return autotag.tag_item(task.item, search_ids=search_id.split())
def abort_action(session, task):
- """A prompt choice callback that aborts the importer.
- """
+ """A prompt choice callback that aborts the importer."""
raise importer.ImportAbort()
class TerminalImportSession(importer.ImportSession):
- """An import session that runs in a terminal.
- """
+ """An import session that runs in a terminal."""
def choose_match(self, task):
"""Given an initial autotagging of items, go through an interactive
@@ -695,21 +1035,27 @@ class TerminalImportSession(importer.ImportSession):
"""
# Show what we're tagging.
print_()
- print_(displayable_path(task.paths, '\n') +
- ' ({} items)'.format(len(task.items)))
+
+ path_str0 = displayable_path(task.paths, "\n")
+ path_str = ui.colorize("import_path", path_str0)
+ items_str0 = "({} items)".format(len(task.items))
+ items_str = ui.colorize("import_path_items", items_str0)
+ print_(" ".join([path_str, items_str]))
# Let plugins display info or prompt the user before we go through the
# process of selecting candidate.
- results = plugins.send('import_task_before_choice',
- session=self, task=task)
+ results = plugins.send(
+ "import_task_before_choice", session=self, task=task
+ )
actions = [action for action in results if action]
if len(actions) == 1:
return actions[0]
elif len(actions) > 1:
raise plugins.PluginConflictException(
- 'Only one handler for `import_task_before_choice` may return '
- 'an action.')
+ "Only one handler for `import_task_before_choice` may return "
+ "an action."
+ )
# Take immediate action if appropriate.
action = _summary_judgment(task.rec)
@@ -728,8 +1074,13 @@ class TerminalImportSession(importer.ImportSession):
# `PromptChoice`.
choices = self._get_choices(task)
choice = choose_candidate(
- task.candidates, False, task.rec, task.cur_artist,
- task.cur_album, itemcount=len(task.items), choices=choices
+ task.candidates,
+ False,
+ task.rec,
+ task.cur_artist,
+ task.cur_album,
+ itemcount=len(task.items),
+ choices=choices,
)
# Basic choices that require no more action here.
@@ -775,8 +1126,9 @@ class TerminalImportSession(importer.ImportSession):
while True:
# Ask for a choice.
choices = self._get_choices(task)
- choice = choose_candidate(candidates, True, rec, item=task.item,
- choices=choices)
+ choice = choose_candidate(
+ candidates, True, rec, item=task.item, choices=choices
+ )
if choice in (importer.action.SKIP, importer.action.ASIS):
return choice
@@ -798,49 +1150,71 @@ class TerminalImportSession(importer.ImportSession):
"""Decide what to do when a new album or item seems similar to one
that's already in the library.
"""
- log.warning("This {0} is already in the library!",
- ("album" if task.is_album else "item"))
+ log.warning(
+ "This {0} is already in the library!",
+ ("album" if task.is_album else "item"),
+ )
- if config['import']['quiet']:
+ if config["import"]["quiet"]:
# In quiet mode, don't prompt -- just skip.
- log.info('Skipping.')
- sel = 's'
+ log.info("Skipping.")
+ sel = "s"
else:
# Print some detail about the existing and new items so the
# user can make an informed decision.
for duplicate in found_duplicates:
- print_("Old: " + summarize_items(
- list(duplicate.items()) if task.is_album else [duplicate],
- not task.is_album,
- ))
+ print_(
+ "Old: "
+ + summarize_items(
+ (
+ list(duplicate.items())
+ if task.is_album
+ else [duplicate]
+ ),
+ not task.is_album,
+ )
+ )
+ if config["import"]["duplicate_verbose_prompt"]:
+ if task.is_album:
+ for dup in duplicate.items():
+ print(f" {dup}")
+ else:
+ print(f" {duplicate}")
- print_("New: " + summarize_items(
- task.imported_items(),
- not task.is_album,
- ))
+ print_(
+ "New: "
+ + summarize_items(
+ task.imported_items(),
+ not task.is_album,
+ )
+ )
+ if config["import"]["duplicate_verbose_prompt"]:
+ for item in task.imported_items():
+ print(f" {item}")
sel = ui.input_options(
- ('Skip new', 'Keep all', 'Remove old', 'Merge all')
+ ("Skip new", "Keep all", "Remove old", "Merge all")
)
- if sel == 's':
+ if sel == "s":
# Skip new.
task.set_choice(importer.action.SKIP)
- elif sel == 'k':
+ elif sel == "k":
# Keep both. Do nothing; leave the choice intact.
pass
- elif sel == 'r':
+ elif sel == "r":
# Remove old.
task.should_remove_duplicates = True
- elif sel == 'm':
+ elif sel == "m":
task.should_merge_duplicates = True
else:
assert False
def should_resume(self, path):
- return ui.input_yn("Import of the directory:\n{}\n"
- "was interrupted. Resume (Y/n)?"
- .format(displayable_path(path)))
+ return ui.input_yn(
+ "Import of the directory:\n{}\n"
+ "was interrupted. Resume (Y/n)?".format(displayable_path(path))
+ )
def _get_choices(self, task):
"""Get the list of prompt choices that should be presented to the
@@ -860,47 +1234,61 @@ class TerminalImportSession(importer.ImportSession):
"""
# Standard, built-in choices.
choices = [
- PromptChoice('s', 'Skip',
- lambda s, t: importer.action.SKIP),
- PromptChoice('u', 'Use as-is',
- lambda s, t: importer.action.ASIS)
+ PromptChoice("s", "Skip", lambda s, t: importer.action.SKIP),
+ PromptChoice("u", "Use as-is", lambda s, t: importer.action.ASIS),
]
if task.is_album:
choices += [
- PromptChoice('t', 'as Tracks',
- lambda s, t: importer.action.TRACKS),
- PromptChoice('g', 'Group albums',
- lambda s, t: importer.action.ALBUMS),
+ PromptChoice(
+ "t", "as Tracks", lambda s, t: importer.action.TRACKS
+ ),
+ PromptChoice(
+ "g", "Group albums", lambda s, t: importer.action.ALBUMS
+ ),
]
choices += [
- PromptChoice('e', 'Enter search', manual_search),
- PromptChoice('i', 'enter Id', manual_id),
- PromptChoice('b', 'aBort', abort_action),
+ PromptChoice("e", "Enter search", manual_search),
+ PromptChoice("i", "enter Id", manual_id),
+ PromptChoice("b", "aBort", abort_action),
]
# Send the before_choose_candidate event and flatten list.
- extra_choices = list(chain(*plugins.send('before_choose_candidate',
- session=self, task=task)))
+ extra_choices = list(
+ chain(
+ *plugins.send(
+ "before_choose_candidate", session=self, task=task
+ )
+ )
+ )
# Add a "dummy" choice for the other baked-in option, for
# duplicate checking.
- all_choices = [
- PromptChoice('a', 'Apply', None),
- ] + choices + extra_choices
+ all_choices = (
+ [
+ PromptChoice("a", "Apply", None),
+ ]
+ + choices
+ + extra_choices
+ )
# Check for conflicts.
short_letters = [c.short for c in all_choices]
if len(short_letters) != len(set(short_letters)):
# Duplicate short letter has been found.
- duplicates = [i for i, count in Counter(short_letters).items()
- if count > 1]
+ duplicates = [
+ i for i, count in Counter(short_letters).items() if count > 1
+ ]
for short in duplicates:
# Keep the first of the choices, removing the rest.
dup_choices = [c for c in all_choices if c.short == short]
for c in dup_choices[1:]:
- log.warning("Prompt choice '{0}' removed due to conflict "
- "with '{1}' (short letter: '{2}')",
- c.long, dup_choices[0].long, c.short)
+ log.warning(
+ "Prompt choice '{0}' removed due to conflict "
+ "with '{1}' (short letter: '{2}')",
+ c.long,
+ dup_choices[0].long,
+ c.short,
+ )
extra_choices.remove(c)
return choices + extra_choices
@@ -913,46 +1301,41 @@ def import_files(lib, paths, query):
"""Import the files in the given list of paths or matching the
query.
"""
- # Check the user-specified directories.
- for path in paths:
- if not os.path.exists(syspath(normpath(path))):
- raise ui.UserError('no such file or directory: {}'.format(
- displayable_path(path)))
-
# Check parameter consistency.
- if config['import']['quiet'] and config['import']['timid']:
+ if config["import"]["quiet"] and config["import"]["timid"]:
raise ui.UserError("can't be both quiet and timid")
# Open the log.
- if config['import']['log'].get() is not None:
- logpath = syspath(config['import']['log'].as_filename())
+ if config["import"]["log"].get() is not None:
+ logpath = syspath(config["import"]["log"].as_filename())
try:
- loghandler = logging.FileHandler(logpath)
+ loghandler = logging.FileHandler(logpath, encoding="utf-8")
except OSError:
- raise ui.UserError("could not open log file for writing: "
- "{}".format(displayable_path(logpath)))
+ raise ui.UserError(
+ "could not open log file for writing: "
+ "{}".format(displayable_path(logpath))
+ )
else:
loghandler = None
# Never ask for input in quiet mode.
- if config['import']['resume'].get() == 'ask' and \
- config['import']['quiet']:
- config['import']['resume'] = False
+ if config["import"]["resume"].get() == "ask" and config["import"]["quiet"]:
+ config["import"]["resume"] = False
session = TerminalImportSession(lib, loghandler, paths, query)
session.run()
# Emit event.
- plugins.send('import', lib=lib, paths=paths)
+ plugins.send("import", lib=lib, paths=paths)
def import_func(lib, opts, args):
- config['import'].set_args(opts)
+ config["import"].set_args(opts)
# Special case: --copy flag suppresses import_move (which would
# otherwise take precedence).
if opts.copy:
- config['import']['move'] = False
+ config["import"]["move"] = False
if opts.library:
query = decargs(args)
@@ -960,112 +1343,238 @@ def import_func(lib, opts, args):
else:
query = None
paths = args
- if not paths:
- raise ui.UserError('no path specified')
+
+ # The paths from the logfiles go into a separate list to allow handling
+ # errors differently from user-specified paths.
+ paths_from_logfiles = list(_parse_logfiles(opts.from_logfiles or []))
+
+ if not paths and not paths_from_logfiles:
+ raise ui.UserError("no path specified")
# On Python 2, we used to get filenames as raw bytes, which is
# what we need. On Python 3, we need to undo the "helpful"
# conversion to Unicode strings to get the real bytestring
# filename.
- paths = [p.encode(util.arg_encoding(), 'surrogateescape')
- for p in paths]
+ paths = [
+ p.encode(util.arg_encoding(), "surrogateescape") for p in paths
+ ]
+ paths_from_logfiles = [
+ p.encode(util.arg_encoding(), "surrogateescape")
+ for p in paths_from_logfiles
+ ]
+
+ # Check the user-specified directories.
+ for path in paths:
+ if not os.path.exists(syspath(normpath(path))):
+ raise ui.UserError(
+ "no such file or directory: {}".format(
+ displayable_path(path)
+ )
+ )
+
+ # Check the directories from the logfiles, but don't throw an error in
+ # case those paths don't exist. Maybe some of those paths have already
+ # been imported and moved separately, so logging a warning should
+ # suffice.
+ for path in paths_from_logfiles:
+ if not os.path.exists(syspath(normpath(path))):
+ log.warning(
+ "No such file or directory: {}".format(
+ displayable_path(path)
+ )
+ )
+ continue
+
+ paths.append(path)
+
+ # If all paths were read from a logfile, and none of them exist, throw
+ # an error
+ if not paths:
+ raise ui.UserError("none of the paths are importable")
import_files(lib, paths, query)
import_cmd = ui.Subcommand(
- 'import', help='import new music', aliases=('imp', 'im')
+ "import", help="import new music", aliases=("imp", "im")
)
import_cmd.parser.add_option(
- '-c', '--copy', action='store_true', default=None,
- help="copy tracks into library directory (default)"
+ "-c",
+ "--copy",
+ action="store_true",
+ default=None,
+ help="copy tracks into library directory (default)",
)
import_cmd.parser.add_option(
- '-C', '--nocopy', action='store_false', dest='copy',
- help="don't copy tracks (opposite of -c)"
+ "-C",
+ "--nocopy",
+ action="store_false",
+ dest="copy",
+ help="don't copy tracks (opposite of -c)",
)
import_cmd.parser.add_option(
- '-m', '--move', action='store_true', dest='move',
- help="move tracks into the library (overrides -c)"
+ "-m",
+ "--move",
+ action="store_true",
+ dest="move",
+ help="move tracks into the library (overrides -c)",
)
import_cmd.parser.add_option(
- '-w', '--write', action='store_true', default=None,
- help="write new metadata to files' tags (default)"
+ "-w",
+ "--write",
+ action="store_true",
+ default=None,
+ help="write new metadata to files' tags (default)",
)
import_cmd.parser.add_option(
- '-W', '--nowrite', action='store_false', dest='write',
- help="don't write metadata (opposite of -w)"
+ "-W",
+ "--nowrite",
+ action="store_false",
+ dest="write",
+ help="don't write metadata (opposite of -w)",
)
import_cmd.parser.add_option(
- '-a', '--autotag', action='store_true', dest='autotag',
- help="infer tags for imported files (default)"
+ "-a",
+ "--autotag",
+ action="store_true",
+ dest="autotag",
+ help="infer tags for imported files (default)",
)
import_cmd.parser.add_option(
- '-A', '--noautotag', action='store_false', dest='autotag',
- help="don't infer tags for imported files (opposite of -a)"
+ "-A",
+ "--noautotag",
+ action="store_false",
+ dest="autotag",
+ help="don't infer tags for imported files (opposite of -a)",
)
import_cmd.parser.add_option(
- '-p', '--resume', action='store_true', default=None,
- help="resume importing if interrupted"
+ "-p",
+ "--resume",
+ action="store_true",
+ default=None,
+ help="resume importing if interrupted",
)
import_cmd.parser.add_option(
- '-P', '--noresume', action='store_false', dest='resume',
- help="do not try to resume importing"
+ "-P",
+ "--noresume",
+ action="store_false",
+ dest="resume",
+ help="do not try to resume importing",
)
import_cmd.parser.add_option(
- '-q', '--quiet', action='store_true', dest='quiet',
- help="never prompt for input: skip albums instead"
+ "-q",
+ "--quiet",
+ action="store_true",
+ dest="quiet",
+ help="never prompt for input: skip albums instead",
)
import_cmd.parser.add_option(
- '-l', '--log', dest='log',
- help='file to log untaggable albums for later review'
+ "--quiet-fallback",
+ type="string",
+ dest="quiet_fallback",
+ help="decision in quiet mode when no strong match: skip or asis",
)
import_cmd.parser.add_option(
- '-s', '--singletons', action='store_true',
- help='import individual tracks instead of full albums'
+ "-l",
+ "--log",
+ dest="log",
+ help="file to log untaggable albums for later review",
)
import_cmd.parser.add_option(
- '-t', '--timid', dest='timid', action='store_true',
- help='always confirm all actions'
+ "-s",
+ "--singletons",
+ action="store_true",
+ help="import individual tracks instead of full albums",
)
import_cmd.parser.add_option(
- '-L', '--library', dest='library', action='store_true',
- help='retag items matching a query'
+ "-t",
+ "--timid",
+ dest="timid",
+ action="store_true",
+ help="always confirm all actions",
)
import_cmd.parser.add_option(
- '-i', '--incremental', dest='incremental', action='store_true',
- help='skip already-imported directories'
+ "-L",
+ "--library",
+ dest="library",
+ action="store_true",
+ help="retag items matching a query",
)
import_cmd.parser.add_option(
- '-I', '--noincremental', dest='incremental', action='store_false',
- help='do not skip already-imported directories'
+ "-i",
+ "--incremental",
+ dest="incremental",
+ action="store_true",
+ help="skip already-imported directories",
)
import_cmd.parser.add_option(
- '--from-scratch', dest='from_scratch', action='store_true',
- help='erase existing metadata before applying new metadata'
+ "-I",
+ "--noincremental",
+ dest="incremental",
+ action="store_false",
+ help="do not skip already-imported directories",
)
import_cmd.parser.add_option(
- '--flat', dest='flat', action='store_true',
- help='import an entire tree as a single album'
+ "-R",
+ "--incremental-skip-later",
+ action="store_true",
+ dest="incremental_skip_later",
+ help="do not record skipped files during incremental import",
)
import_cmd.parser.add_option(
- '-g', '--group-albums', dest='group_albums', action='store_true',
- help='group tracks in a folder into separate albums'
+ "-r",
+ "--noincremental-skip-later",
+ action="store_false",
+ dest="incremental_skip_later",
+ help="record skipped files during incremental import",
)
import_cmd.parser.add_option(
- '--pretend', dest='pretend', action='store_true',
- help='just print the files to import'
+ "--from-scratch",
+ dest="from_scratch",
+ action="store_true",
+ help="erase existing metadata before applying new metadata",
)
import_cmd.parser.add_option(
- '-S', '--search-id', dest='search_ids', action='append',
- metavar='ID',
- help='restrict matching to a specific metadata backend ID'
+ "--flat",
+ dest="flat",
+ action="store_true",
+ help="import an entire tree as a single album",
)
import_cmd.parser.add_option(
- '--set', dest='set_fields', action='callback',
+ "-g",
+ "--group-albums",
+ dest="group_albums",
+ action="store_true",
+ help="group tracks in a folder into separate albums",
+)
+import_cmd.parser.add_option(
+ "--pretend",
+ dest="pretend",
+ action="store_true",
+ help="just print the files to import",
+)
+import_cmd.parser.add_option(
+ "-S",
+ "--search-id",
+ dest="search_ids",
+ action="append",
+ metavar="ID",
+ help="restrict matching to a specific metadata backend ID",
+)
+import_cmd.parser.add_option(
+ "--from-logfile",
+ dest="from_logfiles",
+ action="append",
+ metavar="PATH",
+ help="read skipped paths from an existing logfile",
+)
+import_cmd.parser.add_option(
+ "--set",
+ dest="set_fields",
+ action="callback",
callback=_store_dict,
- metavar='FIELD=VALUE',
- help='set the given fields to the supplied values'
+ metavar="FIELD=VALUE",
+ help="set the given fields to the supplied values",
)
import_cmd.func = import_func
default_commands.append(import_cmd)
@@ -1073,7 +1582,8 @@ default_commands.append(import_cmd)
# list: Query and show library contents.
-def list_items(lib, query, album, fmt=''):
+
+def list_items(lib, query, album, fmt=""):
"""Print out items in lib matching query. If album, then search for
albums instead of single items.
"""
@@ -1089,9 +1599,10 @@ def list_func(lib, opts, args):
list_items(lib, decargs(args), opts.album)
-list_cmd = ui.Subcommand('list', help='query the library', aliases=('ls',))
-list_cmd.parser.usage += "\n" \
- 'Example: %prog -f \'$album: $title\' artist:beatles'
+list_cmd = ui.Subcommand("list", help="query the library", aliases=("ls",))
+list_cmd.parser.usage += (
+ "\n" "Example: %prog -f '$album: $title' artist:beatles"
+)
list_cmd.parser.add_all_common_options()
list_cmd.func = list_func
default_commands.append(list_cmd)
@@ -1099,27 +1610,45 @@ default_commands.append(list_cmd)
# update: Update library contents according to on-disk tags.
-def update_items(lib, query, album, move, pretend, fields):
+
+def update_items(lib, query, album, move, pretend, fields, exclude_fields=None):
"""For all the items matched by the query, update the library to
reflect the item's embedded tags.
:param fields: The fields to be stored. If not specified, all fields will
be.
+ :param exclude_fields: The fields to not be stored. If not specified, all
+ fields will be.
"""
with lib.transaction():
- if move and fields is not None and 'path' not in fields:
+ items, _ = _do_query(lib, query, album)
+ if move and fields is not None and "path" not in fields:
# Special case: if an item needs to be moved, the path field has to
# updated; otherwise the new path will not be reflected in the
# database.
- fields.append('path')
- items, _ = _do_query(lib, query, album)
+ fields.append("path")
+ if fields is None:
+ # no fields were provided, update all media fields
+ item_fields = fields or library.Item._media_fields
+ if move and "path" not in item_fields:
+ # move is enabled, add 'path' to the list of fields to update
+ item_fields.add("path")
+ else:
+ # fields was provided, just update those
+ item_fields = fields
+ # get all the album fields to update
+ album_fields = fields or library.Album._fields.keys()
+ if exclude_fields:
+ # remove any excluded fields from the item and album sets
+ item_fields = [f for f in item_fields if f not in exclude_fields]
+ album_fields = [f for f in album_fields if f not in exclude_fields]
# Walk through the items and pick up their changes.
affected_albums = set()
for item in items:
# Item deleted?
- if not os.path.exists(syspath(item.path)):
+ if not item.path or not os.path.exists(syspath(item.path)):
ui.print_(format(item))
- ui.print_(ui.colorize('text_error', ' deleted'))
+ ui.print_(ui.colorize("text_error", " deleted"))
if not pretend:
item.remove(True)
affected_albums.add(item.album_id)
@@ -1127,16 +1656,20 @@ def update_items(lib, query, album, move, pretend, fields):
# Did the item change since last checked?
if item.current_mtime() <= item.mtime:
- log.debug('skipping {0} because mtime is up to date ({1})',
- displayable_path(item.path), item.mtime)
+ log.debug(
+ "skipping {0} because mtime is up to date ({1})",
+ displayable_path(item.path),
+ item.mtime,
+ )
continue
# Read new data.
try:
item.read()
except library.ReadError as exc:
- log.error('error reading {0}: {1}',
- displayable_path(item.path), exc)
+ log.error(
+ "error reading {0}: {1}", displayable_path(item.path), exc
+ )
continue
# Special-case album artist when it matches track artist. (Hacky
@@ -1146,12 +1679,10 @@ def update_items(lib, query, album, move, pretend, fields):
old_item = lib.get_item(item.id)
if old_item.albumartist == old_item.artist == item.artist:
item.albumartist = old_item.albumartist
- item._dirty.discard('albumartist')
+ item._dirty.discard("albumartist")
# Check for and display changes.
- changed = ui.show_model_changes(
- item,
- fields=fields or library.Item._media_fields)
+ changed = ui.show_model_changes(item, fields=item_fields)
# Save changes.
if not pretend:
@@ -1160,14 +1691,14 @@ def update_items(lib, query, album, move, pretend, fields):
if move and lib.directory in ancestry(item.path):
item.move(store=False)
- item.store(fields=fields)
+ item.store(fields=item_fields)
affected_albums.add(item.album_id)
else:
# The file's mtime was different, but there were no
# changes to the metadata. Store the new mtime,
# which is set in the call to read(), so we don't
# check this again in the future.
- item.store(fields=fields)
+ item.store(fields=item_fields)
# Skip album changes while pretending.
if pretend:
@@ -1179,59 +1710,91 @@ def update_items(lib, query, album, move, pretend, fields):
continue
album = lib.get_album(album_id)
if not album: # Empty albums have already been removed.
- log.debug('emptied album {0}', album_id)
+ log.debug("emptied album {0}", album_id)
continue
first_item = album.items().get()
# Update album structure to reflect an item in it.
for key in library.Album.item_keys:
album[key] = first_item[key]
- album.store(fields=fields)
+ album.store(fields=album_fields)
# Move album art (and any inconsistent items).
if move and lib.directory in ancestry(first_item.path):
- log.debug('moving album {0}', album_id)
+ log.debug("moving album {0}", album_id)
# Manually moving and storing the album.
items = list(album.items())
for item in items:
item.move(store=False, with_album=False)
- item.store(fields=fields)
+ item.store(fields=item_fields)
album.move(store=False)
- album.store(fields=fields)
+ album.store(fields=album_fields)
def update_func(lib, opts, args):
# Verify that the library folder exists to prevent accidental wipes.
- if not os.path.isdir(lib.directory):
+ if not os.path.isdir(syspath(lib.directory)):
ui.print_("Library path is unavailable or does not exist.")
ui.print_(lib.directory)
if not ui.input_yn("Are you sure you want to continue (y/n)?", True):
return
- update_items(lib, decargs(args), opts.album, ui.should_move(opts.move),
- opts.pretend, opts.fields)
+ update_items(
+ lib,
+ decargs(args),
+ opts.album,
+ ui.should_move(opts.move),
+ opts.pretend,
+ opts.fields,
+ opts.exclude_fields,
+ )
update_cmd = ui.Subcommand(
- 'update', help='update the library', aliases=('upd', 'up',)
+ "update",
+ help="update the library",
+ aliases=(
+ "upd",
+ "up",
+ ),
)
update_cmd.parser.add_album_option()
update_cmd.parser.add_format_option()
update_cmd.parser.add_option(
- '-m', '--move', action='store_true', dest='move',
- help="move files in the library directory"
+ "-m",
+ "--move",
+ action="store_true",
+ dest="move",
+ help="move files in the library directory",
)
update_cmd.parser.add_option(
- '-M', '--nomove', action='store_false', dest='move',
- help="don't move files in library"
+ "-M",
+ "--nomove",
+ action="store_false",
+ dest="move",
+ help="don't move files in library",
)
update_cmd.parser.add_option(
- '-p', '--pretend', action='store_true',
- help="show all changes but do nothing"
+ "-p",
+ "--pretend",
+ action="store_true",
+ help="show all changes but do nothing",
)
update_cmd.parser.add_option(
- '-F', '--field', default=None, action='append', dest='fields',
- help='list of fields to update'
+ "-F",
+ "--field",
+ default=None,
+ action="append",
+ dest="fields",
+ help="list of fields to update",
+)
+update_cmd.parser.add_option(
+ "-e",
+ "--exclude-field",
+ default=None,
+ action="append",
+ dest="exclude_fields",
+ help="list of fields to exclude from updates",
)
update_cmd.func = update_func
default_commands.append(update_cmd)
@@ -1239,6 +1802,7 @@ default_commands.append(update_cmd)
# remove: Remove items from library, delete files.
+
def remove_items(lib, query, album, delete, force):
"""Remove items matching query from lib. If album, then match and
remove whole albums. If delete, also remove files from disk.
@@ -1250,21 +1814,23 @@ def remove_items(lib, query, album, delete, force):
# Confirm file removal if not forcing removal.
if not force:
# Prepare confirmation with user.
- album_str = " in {} album{}".format(
- len(albums), 's' if len(albums) > 1 else ''
- ) if album else ""
+ album_str = (
+ " in {} album{}".format(len(albums), "s" if len(albums) > 1 else "")
+ if album
+ else ""
+ )
if delete:
- fmt = '$path - $title'
- prompt = 'Really DELETE'
- prompt_all = 'Really DELETE {} file{}{}'.format(
- len(items), 's' if len(items) > 1 else '', album_str
+ fmt = "$path - $title"
+ prompt = "Really DELETE"
+ prompt_all = "Really DELETE {} file{}{}".format(
+ len(items), "s" if len(items) > 1 else "", album_str
)
else:
- fmt = ''
- prompt = 'Really remove from the library?'
- prompt_all = 'Really remove {} item{}{} from the library?'.format(
- len(items), 's' if len(items) > 1 else '', album_str
+ fmt = ""
+ prompt = "Really remove from the library?"
+ prompt_all = "Really remove {} item{}{} from the library?".format(
+ len(items), "s" if len(items) > 1 else "", album_str
)
# Helpers for printing affected items
@@ -1283,8 +1849,9 @@ def remove_items(lib, query, album, delete, force):
fmt_obj(o)
# Confirm with user.
- objs = ui.input_select_objects(prompt, objs, fmt_obj,
- prompt_all=prompt_all)
+ objs = ui.input_select_objects(
+ prompt, objs, fmt_obj, prompt_all=prompt_all
+ )
if not objs:
return
@@ -1300,15 +1867,13 @@ def remove_func(lib, opts, args):
remove_cmd = ui.Subcommand(
- 'remove', help='remove matching items from the library', aliases=('rm',)
+ "remove", help="remove matching items from the library", aliases=("rm",)
)
remove_cmd.parser.add_option(
- "-d", "--delete", action="store_true",
- help="also remove files from disk"
+ "-d", "--delete", action="store_true", help="also remove files from disk"
)
remove_cmd.parser.add_option(
- "-f", "--force", action="store_true",
- help="do not ask when removing items"
+ "-f", "--force", action="store_true", help="do not ask when removing items"
)
remove_cmd.parser.add_album_option()
remove_cmd.func = remove_func
@@ -1317,6 +1882,7 @@ default_commands.append(remove_cmd)
# stats: Show library/query statistics.
+
def show_stats(lib, query, exact):
"""Shows some statistics about the matched items."""
items = lib.items(query)
@@ -1333,7 +1899,7 @@ def show_stats(lib, query, exact):
try:
total_size += os.path.getsize(syspath(item.path))
except OSError as exc:
- log.info('could not get size of {}: {}', item.path, exc)
+ log.info("could not get size of {}: {}", item.path, exc)
else:
total_size += int(item.length * item.bitrate / 8)
total_time += item.length
@@ -1343,24 +1909,26 @@ def show_stats(lib, query, exact):
if item.album_id:
albums.add(item.album_id)
- size_str = '' + ui.human_bytes(total_size)
+ size_str = "" + ui.human_bytes(total_size)
if exact:
- size_str += f' ({total_size} bytes)'
+ size_str += f" ({total_size} bytes)"
- print_("""Tracks: {}
+ print_(
+ """Tracks: {}
Total time: {}{}
{}: {}
Artists: {}
Albums: {}
Album artists: {}""".format(
- total_items,
- ui.human_seconds(total_time),
- f' ({total_time:.2f} seconds)' if exact else '',
- 'Total size' if exact else 'Approximate total size',
- size_str,
- len(artists),
- len(albums),
- len(album_artists)),
+ total_items,
+ ui.human_seconds(total_time),
+ f" ({total_time:.2f} seconds)" if exact else "",
+ "Total size" if exact else "Approximate total size",
+ size_str,
+ len(artists),
+ len(albums),
+ len(album_artists),
+ ),
)
@@ -1369,11 +1937,10 @@ def stats_func(lib, opts, args):
stats_cmd = ui.Subcommand(
- 'stats', help='show statistics about the library or a query'
+ "stats", help="show statistics about the library or a query"
)
stats_cmd.parser.add_option(
- '-e', '--exact', action='store_true',
- help='exact size and time'
+ "-e", "--exact", action="store_true", help="exact size and time"
)
stats_cmd.func = stats_func
default_commands.append(stats_cmd)
@@ -1381,27 +1948,27 @@ default_commands.append(stats_cmd)
# version: Show current beets version.
+
def show_version(lib, opts, args):
- print_('beets version %s' % beets.__version__)
- print_(f'Python version {python_version()}')
+ print_("beets version %s" % beets.__version__)
+ print_(f"Python version {python_version()}")
# Show plugins.
names = sorted(p.name for p in plugins.find_plugins())
if names:
- print_('plugins:', ', '.join(names))
+ print_("plugins:", ", ".join(names))
else:
- print_('no plugins loaded')
+ print_("no plugins loaded")
-version_cmd = ui.Subcommand(
- 'version', help='output version information'
-)
+version_cmd = ui.Subcommand("version", help="output version information")
version_cmd.func = show_version
default_commands.append(version_cmd)
# modify: Declaratively change metadata.
-def modify_items(lib, mods, dels, query, write, move, album, confirm):
+
+def modify_items(lib, mods, dels, query, write, move, album, confirm, inherit):
"""Modifies matching items according to user-specified assignments and
deletions.
@@ -1411,47 +1978,51 @@ def modify_items(lib, mods, dels, query, write, move, album, confirm):
# Parse key=value specifications into a dictionary.
model_cls = library.Album if album else library.Item
- for key, value in mods.items():
- mods[key] = model_cls._parse(key, value)
-
# Get the items to modify.
items, albums = _do_query(lib, query, album, False)
objs = albums if album else items
# Apply changes *temporarily*, preview them, and collect modified
# objects.
- print_('Modifying {} {}s.'
- .format(len(objs), 'album' if album else 'item'))
+ print_("Modifying {} {}s.".format(len(objs), "album" if album else "item"))
changed = []
+ templates = {
+ key: functemplate.template(value) for key, value in mods.items()
+ }
for obj in objs:
- if print_and_modify(obj, mods, dels) and obj not in changed:
+ obj_mods = {
+ key: model_cls._parse(key, obj.evaluate_template(templates[key]))
+ for key in mods.keys()
+ }
+ if print_and_modify(obj, obj_mods, dels) and obj not in changed:
changed.append(obj)
# Still something to do?
if not changed:
- print_('No changes to make.')
+ print_("No changes to make.")
return
# Confirm action.
if confirm:
if write and move:
- extra = ', move and write tags'
+ extra = ", move and write tags"
elif write:
- extra = ' and write tags'
+ extra = " and write tags"
elif move:
- extra = ' and move'
+ extra = " and move"
else:
- extra = ''
+ extra = ""
changed = ui.input_select_objects(
- 'Really modify%s' % extra, changed,
- lambda o: print_and_modify(o, mods, dels)
+ "Really modify%s" % extra,
+ changed,
+ lambda o: print_and_modify(o, mods, dels),
)
# Apply changes to database and files
with lib.transaction():
for obj in changed:
- obj.try_sync(write, move)
+ obj.try_sync(write, move, inherit)
def print_and_modify(obj, mods, dels):
@@ -1479,10 +2050,10 @@ def modify_parse_args(args):
dels = []
query = []
for arg in args:
- if arg.endswith('!') and '=' not in arg and ':' not in arg:
+ if arg.endswith("!") and "=" not in arg and ":" not in arg:
dels.append(arg[:-1]) # Strip trailing !.
- elif '=' in arg and ':' not in arg.split('=', 1)[0]:
- key, val = arg.split('=', 1)
+ elif "=" in arg and ":" not in arg.split("=", 1)[0]:
+ key, val = arg.split("=", 1)
mods[key] = val
else:
query.append(arg)
@@ -1492,35 +2063,63 @@ def modify_parse_args(args):
def modify_func(lib, opts, args):
query, mods, dels = modify_parse_args(decargs(args))
if not mods and not dels:
- raise ui.UserError('no modifications specified')
- modify_items(lib, mods, dels, query, ui.should_write(opts.write),
- ui.should_move(opts.move), opts.album, not opts.yes)
+ raise ui.UserError("no modifications specified")
+ modify_items(
+ lib,
+ mods,
+ dels,
+ query,
+ ui.should_write(opts.write),
+ ui.should_move(opts.move),
+ opts.album,
+ not opts.yes,
+ opts.inherit,
+ )
modify_cmd = ui.Subcommand(
- 'modify', help='change metadata fields', aliases=('mod',)
+ "modify", help="change metadata fields", aliases=("mod",)
)
modify_cmd.parser.add_option(
- '-m', '--move', action='store_true', dest='move',
- help="move files in the library directory"
+ "-m",
+ "--move",
+ action="store_true",
+ dest="move",
+ help="move files in the library directory",
)
modify_cmd.parser.add_option(
- '-M', '--nomove', action='store_false', dest='move',
- help="don't move files in library"
+ "-M",
+ "--nomove",
+ action="store_false",
+ dest="move",
+ help="don't move files in library",
)
modify_cmd.parser.add_option(
- '-w', '--write', action='store_true', default=None,
- help="write new metadata to files' tags (default)"
+ "-w",
+ "--write",
+ action="store_true",
+ default=None,
+ help="write new metadata to files' tags (default)",
)
modify_cmd.parser.add_option(
- '-W', '--nowrite', action='store_false', dest='write',
- help="don't write metadata (opposite of -w)"
+ "-W",
+ "--nowrite",
+ action="store_false",
+ dest="write",
+ help="don't write metadata (opposite of -w)",
)
modify_cmd.parser.add_album_option()
-modify_cmd.parser.add_format_option(target='item')
+modify_cmd.parser.add_format_option(target="item")
modify_cmd.parser.add_option(
- '-y', '--yes', action='store_true',
- help='skip confirmation'
+ "-y", "--yes", action="store_true", help="skip confirmation"
+)
+modify_cmd.parser.add_option(
+ "-I",
+ "--noinherit",
+ action="store_false",
+ dest="inherit",
+ default=True,
+ help="when modifying albums, don't also change item data",
)
modify_cmd.func = modify_func
default_commands.append(modify_cmd)
@@ -1528,8 +2127,10 @@ default_commands.append(modify_cmd)
# move: Move/copy files to the library or a new base directory.
-def move_items(lib, dest, query, copy, album, pretend, confirm=False,
- export=False):
+
+def move_items(
+ lib, dest, query, copy, album, pretend, confirm=False, export=False
+):
"""Moves or copies items to a new base directory, given by dest. If
dest is None, then the library's base directory is used, making the
command "consolidate" files.
@@ -1548,40 +2149,56 @@ def move_items(lib, dest, query, copy, album, pretend, confirm=False,
objs = [o for o in objs if (isalbummoved if album else isitemmoved)(o)]
num_unmoved = num_objs - len(objs)
# Report unmoved files that match the query.
- unmoved_msg = ''
+ unmoved_msg = ""
if num_unmoved > 0:
- unmoved_msg = f' ({num_unmoved} already in place)'
+ unmoved_msg = f" ({num_unmoved} already in place)"
copy = copy or export # Exporting always copies.
- action = 'Copying' if copy else 'Moving'
- act = 'copy' if copy else 'move'
- entity = 'album' if album else 'item'
- log.info('{0} {1} {2}{3}{4}.', action, len(objs), entity,
- 's' if len(objs) != 1 else '', unmoved_msg)
+ action = "Copying" if copy else "Moving"
+ act = "copy" if copy else "move"
+ entity = "album" if album else "item"
+ log.info(
+ "{0} {1} {2}{3}{4}.",
+ action,
+ len(objs),
+ entity,
+ "s" if len(objs) != 1 else "",
+ unmoved_msg,
+ )
if not objs:
return
if pretend:
if album:
- show_path_changes([(item.path, item.destination(basedir=dest))
- for obj in objs for item in obj.items()])
+ show_path_changes(
+ [
+ (item.path, item.destination(basedir=dest))
+ for obj in objs
+ for item in obj.items()
+ ]
+ )
else:
- show_path_changes([(obj.path, obj.destination(basedir=dest))
- for obj in objs])
+ show_path_changes(
+ [(obj.path, obj.destination(basedir=dest)) for obj in objs]
+ )
else:
if confirm:
objs = ui.input_select_objects(
- 'Really %s' % act, objs,
+ "Really %s" % act,
+ objs,
lambda o: show_path_changes(
- [(o.path, o.destination(basedir=dest))]))
+ [(o.path, o.destination(basedir=dest))]
+ ),
+ )
for obj in objs:
- log.debug('moving: {0}', util.displayable_path(obj.path))
+ log.debug("moving: {0}", util.displayable_path(obj.path))
if export:
# Copy without affecting the database.
- obj.move(operation=MoveOperation.COPY, basedir=dest,
- store=False)
+ obj.move(
+ operation=MoveOperation.COPY, basedir=dest, store=False
+ )
else:
# Ordinary move/copy: store the new path.
if copy:
@@ -1594,35 +2211,54 @@ def move_func(lib, opts, args):
dest = opts.dest
if dest is not None:
dest = normpath(dest)
- if not os.path.isdir(dest):
- raise ui.UserError('no such directory: %s' % dest)
+ if not os.path.isdir(syspath(dest)):
+ raise ui.UserError(
+ "no such directory: {}".format(displayable_path(dest))
+ )
- move_items(lib, dest, decargs(args), opts.copy, opts.album, opts.pretend,
- opts.timid, opts.export)
+ move_items(
+ lib,
+ dest,
+ decargs(args),
+ opts.copy,
+ opts.album,
+ opts.pretend,
+ opts.timid,
+ opts.export,
+ )
-move_cmd = ui.Subcommand(
- 'move', help='move or copy items', aliases=('mv',)
+move_cmd = ui.Subcommand("move", help="move or copy items", aliases=("mv",))
+move_cmd.parser.add_option(
+ "-d", "--dest", metavar="DIR", dest="dest", help="destination directory"
)
move_cmd.parser.add_option(
- '-d', '--dest', metavar='DIR', dest='dest',
- help='destination directory'
+ "-c",
+ "--copy",
+ default=False,
+ action="store_true",
+ help="copy instead of moving",
)
move_cmd.parser.add_option(
- '-c', '--copy', default=False, action='store_true',
- help='copy instead of moving'
+ "-p",
+ "--pretend",
+ default=False,
+ action="store_true",
+ help="show how files would be moved, but don't touch anything",
)
move_cmd.parser.add_option(
- '-p', '--pretend', default=False, action='store_true',
- help='show how files would be moved, but don\'t touch anything'
+ "-t",
+ "--timid",
+ dest="timid",
+ action="store_true",
+ help="always confirm all actions",
)
move_cmd.parser.add_option(
- '-t', '--timid', dest='timid', action='store_true',
- help='always confirm all actions'
-)
-move_cmd.parser.add_option(
- '-e', '--export', default=False, action='store_true',
- help='copy without changing the database path'
+ "-e",
+ "--export",
+ default=False,
+ action="store_true",
+ help="copy without changing the database path",
)
move_cmd.parser.add_album_option()
move_cmd.func = move_func
@@ -1631,6 +2267,7 @@ default_commands.append(move_cmd)
# write: Write tags into files.
+
def write_items(lib, query, pretend, force):
"""Write tag information from the database to the respective files
in the filesystem.
@@ -1640,20 +2277,22 @@ def write_items(lib, query, pretend, force):
for item in items:
# Item deleted?
if not os.path.exists(syspath(item.path)):
- log.info('missing file: {0}', util.displayable_path(item.path))
+ log.info("missing file: {0}", util.displayable_path(item.path))
continue
# Get an Item object reflecting the "clean" (on-disk) state.
try:
clean_item = library.Item.from_path(item.path)
except library.ReadError as exc:
- log.error('error reading {0}: {1}',
- displayable_path(item.path), exc)
+ log.error(
+ "error reading {0}: {1}", displayable_path(item.path), exc
+ )
continue
# Check for and display changes.
- changed = ui.show_model_changes(item, clean_item,
- library.Item._media_tag_fields, force)
+ changed = ui.show_model_changes(
+ item, clean_item, library.Item._media_tag_fields, force
+ )
if (changed or force) and not pretend:
# We use `try_sync` here to keep the mtime up to date in the
# database.
@@ -1664,14 +2303,18 @@ def write_func(lib, opts, args):
write_items(lib, decargs(args), opts.pretend, opts.force)
-write_cmd = ui.Subcommand('write', help='write tag information to files')
+write_cmd = ui.Subcommand("write", help="write tag information to files")
write_cmd.parser.add_option(
- '-p', '--pretend', action='store_true',
- help="show all changes but do nothing"
+ "-p",
+ "--pretend",
+ action="store_true",
+ help="show all changes but do nothing",
)
write_cmd.parser.add_option(
- '-f', '--force', action='store_true',
- help="write tags even if the existing tags match the database"
+ "-f",
+ "--force",
+ action="store_true",
+ help="write tags even if the existing tags match the database",
)
write_cmd.func = write_func
default_commands.append(write_cmd)
@@ -1679,6 +2322,7 @@ default_commands.append(write_cmd)
# config: Show and edit user configuration.
+
def config_func(lib, opts, args):
# Make sure lazy configuration is loaded
config.resolve()
@@ -1708,8 +2352,8 @@ def config_func(lib, opts, args):
# Dump configuration.
else:
config_out = config.dump(full=opts.defaults, redact=opts.redact)
- if config_out.strip() != '{}':
- print_(util.text_string(config_out))
+ if config_out.strip() != "{}":
+ print_(config_out)
else:
print("Empty configuration")
@@ -1722,33 +2366,43 @@ def config_edit():
editor = util.editor_command()
try:
if not os.path.isfile(path):
- open(path, 'w+').close()
+ open(path, "w+").close()
util.interactive_open([path], editor)
except OSError as exc:
message = f"Could not edit configuration: {exc}"
if not editor:
- message += ". Please set the EDITOR environment variable"
+ message += (
+ ". Please set the VISUAL (or EDITOR) environment variable"
+ )
raise ui.UserError(message)
-config_cmd = ui.Subcommand('config',
- help='show or edit the user configuration')
+config_cmd = ui.Subcommand("config", help="show or edit the user configuration")
config_cmd.parser.add_option(
- '-p', '--paths', action='store_true',
- help='show files that configuration was loaded from'
+ "-p",
+ "--paths",
+ action="store_true",
+ help="show files that configuration was loaded from",
)
config_cmd.parser.add_option(
- '-e', '--edit', action='store_true',
- help='edit user configuration with $EDITOR'
+ "-e",
+ "--edit",
+ action="store_true",
+ help="edit user configuration with $VISUAL (or $EDITOR)",
)
config_cmd.parser.add_option(
- '-d', '--defaults', action='store_true',
- help='include the default configuration'
+ "-d",
+ "--defaults",
+ action="store_true",
+ help="include the default configuration",
)
config_cmd.parser.add_option(
- '-c', '--clear', action='store_false',
- dest='redact', default=True,
- help='do not redact sensitive fields'
+ "-c",
+ "--clear",
+ action="store_false",
+ dest="redact",
+ default=True,
+ help="do not redact sensitive fields",
)
config_cmd.func = config_func
default_commands.append(config_cmd)
@@ -1756,23 +2410,26 @@ default_commands.append(config_cmd)
# completion: print completion script
+
def print_completion(*args):
for line in completion_script(default_commands + plugins.commands()):
- print_(line, end='')
- if not any(map(os.path.isfile, BASH_COMPLETION_PATHS)):
- log.warning('Warning: Unable to find the bash-completion package. '
- 'Command line completion might not work.')
+ print_(line, end="")
+ if not any(os.path.isfile(syspath(p)) for p in BASH_COMPLETION_PATHS):
+ log.warning(
+ "Warning: Unable to find the bash-completion package. "
+ "Command line completion might not work."
+ )
-BASH_COMPLETION_PATHS = map(syspath, [
- '/etc/bash_completion',
- '/usr/share/bash-completion/bash_completion',
- '/usr/local/share/bash-completion/bash_completion',
+BASH_COMPLETION_PATHS = [
+ b"/etc/bash_completion",
+ b"/usr/share/bash-completion/bash_completion",
+ b"/usr/local/share/bash-completion/bash_completion",
# SmartOS
- '/opt/local/share/bash-completion/bash_completion',
+ b"/opt/local/share/bash-completion/bash_completion",
# Homebrew (before bash-completion2)
- '/usr/local/etc/bash_completion',
-])
+ b"/usr/local/etc/bash_completion",
+]
def completion_script(commands):
@@ -1781,9 +2438,9 @@ def completion_script(commands):
``commands`` is alist of ``ui.Subcommand`` instances to generate
completion data for.
"""
- base_script = os.path.join(os.path.dirname(__file__), 'completion_base.sh')
+ base_script = os.path.join(os.path.dirname(__file__), "completion_base.sh")
with open(base_script) as base_script:
- yield util.text_string(base_script.read())
+ yield base_script.read()
options = {}
aliases = {}
@@ -1795,50 +2452,47 @@ def completion_script(commands):
command_names.append(name)
for alias in cmd.aliases:
- if re.match(r'^\w+$', alias):
+ if re.match(r"^\w+$", alias):
aliases[alias] = name
- options[name] = {'flags': [], 'opts': []}
+ options[name] = {"flags": [], "opts": []}
for opts in cmd.parser._get_all_options()[1:]:
- if opts.action in ('store_true', 'store_false'):
- option_type = 'flags'
+ if opts.action in ("store_true", "store_false"):
+ option_type = "flags"
else:
- option_type = 'opts'
+ option_type = "opts"
options[name][option_type].extend(
opts._short_opts + opts._long_opts
)
# Add global options
- options['_global'] = {
- 'flags': ['-v', '--verbose'],
- 'opts':
- '-l --library -c --config -d --directory -h --help'.split(' ')
+ options["_global"] = {
+ "flags": ["-v", "--verbose"],
+ "opts": "-l --library -c --config -d --directory -h --help".split(" "),
}
# Add flags common to all commands
- options['_common'] = {
- 'flags': ['-h', '--help']
- }
+ options["_common"] = {"flags": ["-h", "--help"]}
# Start generating the script
yield "_beet() {\n"
# Command names
- yield " local commands='%s'\n" % ' '.join(command_names)
+ yield " local commands='%s'\n" % " ".join(command_names)
yield "\n"
# Command aliases
- yield " local aliases='%s'\n" % ' '.join(aliases.keys())
+ yield " local aliases='%s'\n" % " ".join(aliases.keys())
for alias, cmd in aliases.items():
- yield " local alias__{}={}\n".format(alias.replace('-', '_'), cmd)
- yield '\n'
+ yield " local alias__{}={}\n".format(alias.replace("-", "_"), cmd)
+ yield "\n"
# Fields
- yield " fields='%s'\n" % ' '.join(
+ yield " fields='%s'\n" % " ".join(
set(
- list(library.Item._fields.keys()) +
- list(library.Album._fields.keys())
+ list(library.Item._fields.keys())
+ + list(library.Album._fields.keys())
)
)
@@ -1846,17 +2500,18 @@ def completion_script(commands):
for cmd, opts in options.items():
for option_type, option_list in opts.items():
if option_list:
- option_list = ' '.join(option_list)
+ option_list = " ".join(option_list)
yield " local {}__{}='{}'\n".format(
- option_type, cmd.replace('-', '_'), option_list)
+ option_type, cmd.replace("-", "_"), option_list
+ )
- yield ' _beet_dispatch\n'
- yield '}\n'
+ yield " _beet_dispatch\n"
+ yield "}\n"
completion_cmd = ui.Subcommand(
- 'completion',
- help='print shell script that provides command line completion'
+ "completion",
+ help="print shell script that provides command line completion",
)
completion_cmd.func = print_completion
completion_cmd.hide = True
diff --git a/lib/beets/ui/completion_base.sh b/lib/beets/ui/completion_base.sh
index 1eaa4db3..e83f9d2c 100644
--- a/lib/beets/ui/completion_base.sh
+++ b/lib/beets/ui/completion_base.sh
@@ -31,7 +31,7 @@
# plugins dynamically
#
# Currently, only Bash 3.2 and newer is supported and the
-# `bash-completion` package is requied.
+# `bash-completion` package (v2.8 or newer) is required.
#
# TODO
# ----
@@ -46,7 +46,30 @@
# * Support long options with `=`, e.g. `--config=file`. Debian's bash
# completion package can handle this.
#
+# Note that 'bash-completion' v2.8 is a part of Debian 10, which is part of
+# LTS until 2024-06-30. After this date, the minimum version requirement can
+# be changed, and newer features can be used unconditionally. See PR#5301.
+#
+if [[ ${BASH_COMPLETION_VERSINFO[0]} -ne 2 \
+ || ${BASH_COMPLETION_VERSINFO[1]} -lt 8 ]]; then
+ echo "Incompatible version of 'bash-completion'!"
+ return 1
+fi
+
+# The later code relies on 'bash-completion' version 2.12, but older versions
+# are still supported. Here, we provide implementations of the newer functions
+# in terms of older ones, if 'bash-completion' is too old to have them.
+
+if [[ ${BASH_COMPLETION_VERSINFO[1]} -lt 12 ]]; then
+ _comp_get_words() {
+ _get_comp_words_by_ref "$@"
+ }
+
+ _comp_compgen_filedir() {
+ _filedir "$@"
+ }
+fi
# Determines the beets subcommand and dispatches the completion
# accordingly.
@@ -54,7 +77,7 @@ _beet_dispatch() {
local cur prev cmd=
COMPREPLY=()
- _get_comp_words_by_ref -n : cur prev
+ _comp_get_words -n : cur prev
# Look for the beets subcommand
local arg
@@ -99,7 +122,7 @@ _beet_complete() {
completions="${flags___common} ${opts} ${flags}"
COMPREPLY+=( $(compgen -W "$completions" -- $cur) )
else
- _filedir
+ _comp_compgen_filedir
fi
}
@@ -114,12 +137,12 @@ _beet_complete_global() {
;;
-l|--library|-c|--config)
# Filename completion
- _filedir
+ _comp_compgen_filedir
return
;;
-d|--directory)
# Directory completion
- _filedir -d
+ _comp_compgen_filedir -d
return
;;
esac
diff --git a/lib/beets/util/__init__.py b/lib/beets/util/__init__.py
index d58bb28e..bfb23c05 100644
--- a/lib/beets/util/__init__.py
+++ b/lib/beets/util/__init__.py
@@ -13,29 +13,55 @@
# included in all copies or substantial portions of the Software.
"""Miscellaneous utility functions."""
+from __future__ import annotations
-import os
-import sys
import errno
-import locale
-import re
-import tempfile
-import shutil
import fnmatch
-import functools
-from collections import Counter, namedtuple
-from multiprocessing.pool import ThreadPool
-import traceback
-import subprocess
+import os
import platform
+import re
import shlex
-from beets.util import hidden
-from unidecode import unidecode
+import shutil
+import subprocess
+import sys
+import tempfile
+import traceback
+from collections import Counter, namedtuple
+from contextlib import suppress
from enum import Enum
+from logging import Logger
+from multiprocessing.pool import ThreadPool
+from pathlib import Path
+from typing import (
+ Any,
+ AnyStr,
+ Callable,
+ Generator,
+ Iterable,
+ List,
+ MutableSequence,
+ Optional,
+ Pattern,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+)
+if sys.version_info >= (3, 10):
+ from typing import TypeAlias
+else:
+ from typing_extensions import TypeAlias
+
+from unidecode import unidecode
+
+from beets.util import hidden
MAX_FILENAME_LENGTH = 200
-WINDOWS_MAGIC_PREFIX = '\\\\?\\'
+WINDOWS_MAGIC_PREFIX = "\\\\?\\"
+T = TypeVar("T")
+Bytes_or_String: TypeAlias = Union[str, bytes]
+PathLike = Union[str, bytes, Path]
class HumanReadableException(Exception):
@@ -51,7 +77,8 @@ class HumanReadableException(Exception):
associated exception. (Note that this is not necessary in Python 3.x
and should be removed when we make the transition.)
"""
- error_kind = 'Error' # Human-readable description of error type.
+
+ error_kind = "Error" # Human-readable description of error type.
def __init__(self, reason, verb, tb=None):
self.reason = reason
@@ -60,12 +87,11 @@ class HumanReadableException(Exception):
super().__init__(self.get_message())
def _gerund(self):
- """Generate a (likely) gerund form of the English verb.
- """
- if ' ' in self.verb:
+ """Generate a (likely) gerund form of the English verb."""
+ if " " in self.verb:
return self.verb
- gerund = self.verb[:-1] if self.verb.endswith('e') else self.verb
- gerund += 'ing'
+ gerund = self.verb[:-1] if self.verb.endswith("e") else self.verb
+ gerund += "ing"
return gerund
def _reasonstr(self):
@@ -73,8 +99,8 @@ class HumanReadableException(Exception):
if isinstance(self.reason, str):
return self.reason
elif isinstance(self.reason, bytes):
- return self.reason.decode('utf-8', 'ignore')
- elif hasattr(self.reason, 'strerror'): # i.e., EnvironmentError
+ return self.reason.decode("utf-8", "ignore")
+ elif hasattr(self.reason, "strerror"): # i.e., EnvironmentError
return self.reason.strerror
else:
return '"{}"'.format(str(self.reason))
@@ -91,7 +117,7 @@ class HumanReadableException(Exception):
"""
if self.tb:
logger.debug(self.tb)
- logger.error('{0}: {1}', self.error_kind, self.args[0])
+ logger.error("{0}: {1}", self.error_kind, self.args[0])
class FilesystemError(HumanReadableException):
@@ -106,28 +132,27 @@ class FilesystemError(HumanReadableException):
def get_message(self):
# Use a nicer English phrasing for some specific verbs.
- if self.verb in ('move', 'copy', 'rename'):
- clause = 'while {} {} to {}'.format(
+ if self.verb in ("move", "copy", "rename"):
+ clause = "while {} {} to {}".format(
self._gerund(),
displayable_path(self.paths[0]),
- displayable_path(self.paths[1])
+ displayable_path(self.paths[1]),
)
- elif self.verb in ('delete', 'write', 'create', 'read'):
- clause = 'while {} {}'.format(
- self._gerund(),
- displayable_path(self.paths[0])
+ elif self.verb in ("delete", "write", "create", "read"):
+ clause = "while {} {}".format(
+ self._gerund(), displayable_path(self.paths[0])
)
else:
- clause = 'during {} of paths {}'.format(
- self.verb, ', '.join(displayable_path(p) for p in self.paths)
+ clause = "during {} of paths {}".format(
+ self.verb, ", ".join(displayable_path(p) for p in self.paths)
)
- return f'{self._reasonstr()} {clause}'
+ return f"{self._reasonstr()} {clause}"
class MoveOperation(Enum):
- """The file operations that e.g. various move functions can carry out.
- """
+ """The file operations that e.g. various move functions can carry out."""
+
MOVE = 0
COPY = 1
LINK = 2
@@ -136,7 +161,7 @@ class MoveOperation(Enum):
REFLINK_AUTO = 5
-def normpath(path):
+def normpath(path: bytes) -> bytes:
"""Provide the canonical form of the path suitable for storing in
the database.
"""
@@ -145,11 +170,11 @@ def normpath(path):
return bytestring_path(path)
-def ancestry(path):
+def ancestry(path: bytes) -> List[str]:
"""Return a list consisting of path's parent directory, its
grandparent, and so on. For instance:
- >>> ancestry('/a/b/c')
+ >>> ancestry(b'/a/b/c')
['/', '/a', '/a/b']
The argument should *not* be the result of a call to `syspath`.
@@ -169,13 +194,18 @@ def ancestry(path):
return out
-def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None):
+def sorted_walk(
+ path: AnyStr,
+ ignore: Sequence = (),
+ ignore_hidden: bool = False,
+ logger: Optional[Logger] = None,
+) -> Generator[Tuple, None, None]:
"""Like `os.walk`, but yields things in case-insensitive sorted,
breadth-first order. Directory and file names matching any glob
pattern in `ignore` are skipped. If `logger` is provided, then
warning messages are logged there when a directory cannot be listed.
"""
- # Make sure the pathes aren't Unicode strings.
+ # Make sure the paths aren't Unicode strings.
path = bytestring_path(path)
ignore = [bytestring_path(i) for i in ignore]
@@ -184,9 +214,11 @@ def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None):
contents = os.listdir(syspath(path))
except OSError as exc:
if logger:
- logger.warning('could not list directory {}: {}'.format(
- displayable_path(path), exc.strerror
- ))
+ logger.warning(
+ "could not list directory {}: {}".format(
+ displayable_path(path), exc.strerror
+ )
+ )
return
dirs = []
files = []
@@ -198,9 +230,9 @@ def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None):
for pat in ignore:
if fnmatch.fnmatch(base, pat):
if logger:
- logger.debug('ignoring {} due to ignore rule {}'.format(
- base, pat
- ))
+ logger.debug(
+ "ignoring {} due to ignore rule {}".format(base, pat)
+ )
skip = True
break
if skip:
@@ -226,14 +258,14 @@ def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None):
yield from sorted_walk(cur, ignore, ignore_hidden, logger)
-def path_as_posix(path):
+def path_as_posix(path: bytes) -> bytes:
"""Return the string representation of the path with forward (/)
slashes.
"""
- return path.replace(b'\\', b'/')
+ return path.replace(b"\\", b"/")
-def mkdirall(path):
+def mkdirall(path: bytes):
"""Make all the enclosing directories of path (like mkdir -p on the
parent).
"""
@@ -242,11 +274,12 @@ def mkdirall(path):
try:
os.mkdir(syspath(ancestor))
except OSError as exc:
- raise FilesystemError(exc, 'create', (ancestor,),
- traceback.format_exc())
+ raise FilesystemError(
+ exc, "create", (ancestor,), traceback.format_exc()
+ )
-def fnmatch_all(names, patterns):
+def fnmatch_all(names: Sequence[bytes], patterns: Sequence[bytes]) -> bool:
"""Determine whether all strings in `names` match at least one of
the `patterns`, which should be shell glob expressions.
"""
@@ -261,7 +294,11 @@ def fnmatch_all(names, patterns):
return True
-def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
+def prune_dirs(
+ path: str,
+ root: Optional[Bytes_or_String] = None,
+ clutter: Sequence[str] = (".DS_Store", "Thumbs.db"),
+):
"""If path is an empty directory, then remove it. Recursively remove
path's ancestry up to root (which is never removed) where there are
empty directories. If path is not contained in root, then nothing is
@@ -279,7 +316,7 @@ def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
ancestors = []
elif root in ancestors:
# Only remove directories below the root.
- ancestors = ancestors[ancestors.index(root) + 1:]
+ ancestors = ancestors[ancestors.index(root) + 1 :]
else:
# Remove nothing.
return
@@ -292,7 +329,7 @@ def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
if not os.path.exists(directory):
# Directory gone already.
continue
- clutter = [bytestring_path(c) for c in clutter]
+ clutter: List[bytes] = [bytestring_path(c) for c in clutter]
match_paths = [bytestring_path(d) for d in os.listdir(directory)]
try:
if fnmatch_all(match_paths, clutter):
@@ -304,10 +341,10 @@ def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
break
-def components(path):
+def components(path: AnyStr) -> MutableSequence[AnyStr]:
"""Return a list of the path components in path. For instance:
- >>> components('/a/b/c')
+ >>> components(b'/a/b/c')
['a', 'b', 'c']
The argument should *not* be the result of a call to `syspath`.
@@ -328,58 +365,61 @@ def components(path):
return comps
-def arg_encoding():
+def arg_encoding() -> str:
"""Get the encoding for command-line arguments (and other OS
locale-sensitive strings).
"""
- try:
- return locale.getdefaultlocale()[1] or 'utf-8'
- except ValueError:
- # Invalid locale environment variable setting. To avoid
- # failing entirely for no good reason, assume UTF-8.
- return 'utf-8'
+ return sys.getfilesystemencoding()
-def _fsencoding():
+def _fsencoding() -> str:
"""Get the system's filesystem encoding. On Windows, this is always
UTF-8 (not MBCS).
"""
encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
- if encoding == 'mbcs':
+ if encoding == "mbcs":
# On Windows, a broken encoding known to Python as "MBCS" is
# used for the filesystem. However, we only use the Unicode API
# for Windows paths, so the encoding is actually immaterial so
# we can avoid dealing with this nastiness. We arbitrarily
# choose UTF-8.
- encoding = 'utf-8'
+ encoding = "utf-8"
return encoding
-def bytestring_path(path):
+def bytestring_path(path: PathLike) -> bytes:
"""Given a path, which is either a bytes or a unicode, returns a str
- path (ensuring that we never deal with Unicode pathnames).
+ path (ensuring that we never deal with Unicode pathnames). Path should be
+ bytes but has safeguards for strings to be converted.
"""
# Pass through bytestrings.
if isinstance(path, bytes):
return path
+ str_path = str(path)
+
# On Windows, remove the magic prefix added by `syspath`. This makes
# ``bytestring_path(syspath(X)) == X``, i.e., we can safely
# round-trip through `syspath`.
- if os.path.__name__ == 'ntpath' and path.startswith(WINDOWS_MAGIC_PREFIX):
- path = path[len(WINDOWS_MAGIC_PREFIX):]
+ if os.path.__name__ == "ntpath" and str_path.startswith(
+ WINDOWS_MAGIC_PREFIX
+ ):
+ str_path = str_path[len(WINDOWS_MAGIC_PREFIX) :]
# Try to encode with default encodings, but fall back to utf-8.
try:
- return path.encode(_fsencoding())
+ return str_path.encode(_fsencoding())
except (UnicodeError, LookupError):
- return path.encode('utf-8')
+ return str_path.encode("utf-8")
-PATH_SEP = bytestring_path(os.sep)
+PATH_SEP: bytes = bytestring_path(os.sep)
-def displayable_path(path, separator='; '):
+def displayable_path(
+ path: Union[bytes, str, Tuple[Union[bytes, str], ...]],
+ separator: str = "; ",
+) -> str:
"""Attempts to decode a bytestring path to a unicode object for the
purpose of displaying it to the user. If the `path` argument is a
list or a tuple, the elements are joined with `separator`.
@@ -393,66 +433,55 @@ def displayable_path(path, separator='; '):
return str(path)
try:
- return path.decode(_fsencoding(), 'ignore')
+ return path.decode(_fsencoding(), "ignore")
except (UnicodeError, LookupError):
- return path.decode('utf-8', 'ignore')
+ return path.decode("utf-8", "ignore")
-def syspath(path, prefix=True):
+def syspath(path: PathLike, prefix: bool = True) -> str:
"""Convert a path for use by the operating system. In particular,
paths on Windows must receive a magic prefix and must be converted
to Unicode before they are sent to the OS. To disable the magic
prefix on Windows, set `prefix` to False---but only do this if you
*really* know what you're doing.
"""
+ str_path = os.fsdecode(path)
# Don't do anything if we're not on windows
- if os.path.__name__ != 'ntpath':
- return path
-
- if not isinstance(path, str):
- # Beets currently represents Windows paths internally with UTF-8
- # arbitrarily. But earlier versions used MBCS because it is
- # reported as the FS encoding by Windows. Try both.
- try:
- path = path.decode('utf-8')
- except UnicodeError:
- # The encoding should always be MBCS, Windows' broken
- # Unicode representation.
- encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
- path = path.decode(encoding, 'replace')
+ if os.path.__name__ != "ntpath":
+ return str_path
# Add the magic prefix if it isn't already there.
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
- if prefix and not path.startswith(WINDOWS_MAGIC_PREFIX):
- if path.startswith('\\\\'):
+ if prefix and not str_path.startswith(WINDOWS_MAGIC_PREFIX):
+ if str_path.startswith("\\\\"):
# UNC path. Final path should look like \\?\UNC\...
- path = 'UNC' + path[1:]
- path = WINDOWS_MAGIC_PREFIX + path
+ str_path = "UNC" + str_path[1:]
+ str_path = WINDOWS_MAGIC_PREFIX + str_path
- return path
+ return str_path
-def samefile(p1, p2):
+def samefile(p1: bytes, p2: bytes) -> bool:
"""Safer equality for paths."""
if p1 == p2:
return True
return shutil._samefile(syspath(p1), syspath(p2))
-def remove(path, soft=True):
+def remove(path: Optional[bytes], soft: bool = True):
"""Remove the file. If `soft`, then no error will be raised if the
file does not exist.
"""
path = syspath(path)
- if soft and not os.path.exists(path):
+ if not path or (soft and not os.path.exists(path)):
return
try:
os.remove(path)
except OSError as exc:
- raise FilesystemError(exc, 'delete', (path,), traceback.format_exc())
+ raise FilesystemError(exc, "delete", (path,), traceback.format_exc())
-def copy(path, dest, replace=False):
+def copy(path: bytes, dest: bytes, replace: bool = False):
"""Copy a plain file. Permissions are not copied. If `dest` already
exists, raises a FilesystemError unless `replace` is True. Has no
effect if `path` is the same as `dest`. Paths are translated to
@@ -463,15 +492,14 @@ def copy(path, dest, replace=False):
path = syspath(path)
dest = syspath(dest)
if not replace and os.path.exists(dest):
- raise FilesystemError('file exists', 'copy', (path, dest))
+ raise FilesystemError("file exists", "copy", (path, dest))
try:
shutil.copyfile(path, dest)
except OSError as exc:
- raise FilesystemError(exc, 'copy', (path, dest),
- traceback.format_exc())
+ raise FilesystemError(exc, "copy", (path, dest), traceback.format_exc())
-def move(path, dest, replace=False):
+def move(path: bytes, dest: bytes, replace: bool = False):
"""Rename a file. `dest` may not be a directory. If `dest` already
exists, raises an OSError unless `replace` is True. Has no effect if
`path` is the same as `dest`. If the paths are on different
@@ -479,40 +507,49 @@ def move(path, dest, replace=False):
instead, in which case metadata will *not* be preserved. Paths are
translated to system paths.
"""
- if os.path.isdir(path):
- raise FilesystemError(u'source is directory', 'move', (path, dest))
- if os.path.isdir(dest):
- raise FilesystemError(u'destination is directory', 'move',
- (path, dest))
+ if os.path.isdir(syspath(path)):
+ raise FilesystemError("source is directory", "move", (path, dest))
+ if os.path.isdir(syspath(dest)):
+ raise FilesystemError("destination is directory", "move", (path, dest))
if samefile(path, dest):
return
- path = syspath(path)
- dest = syspath(dest)
- if os.path.exists(dest) and not replace:
- raise FilesystemError('file exists', 'rename', (path, dest))
+ if os.path.exists(syspath(dest)) and not replace:
+ raise FilesystemError("file exists", "rename", (path, dest))
# First, try renaming the file.
try:
- os.replace(path, dest)
+ os.replace(syspath(path), syspath(dest))
except OSError:
- tmp = tempfile.mktemp(suffix='.beets',
- prefix=py3_path(b'.' + os.path.basename(dest)),
- dir=py3_path(os.path.dirname(dest)))
- tmp = syspath(tmp)
+ # Copy the file to a temporary destination.
+ basename = os.path.basename(bytestring_path(dest))
+ dirname = os.path.dirname(bytestring_path(dest))
+ tmp = tempfile.NamedTemporaryFile(
+ suffix=syspath(b".beets", prefix=False),
+ prefix=syspath(b"." + basename + b".", prefix=False),
+ dir=syspath(dirname),
+ delete=False,
+ )
try:
- shutil.copyfile(path, tmp)
- os.replace(tmp, dest)
+ with open(syspath(path), "rb") as f:
+ shutil.copyfileobj(f, tmp)
+ finally:
+ tmp.close()
+
+ # Move the copied file into place.
+ try:
+ os.replace(tmp.name, syspath(dest))
tmp = None
- os.remove(path)
+ os.remove(syspath(path))
except OSError as exc:
- raise FilesystemError(exc, 'move', (path, dest),
- traceback.format_exc())
+ raise FilesystemError(
+ exc, "move", (path, dest), traceback.format_exc()
+ )
finally:
if tmp is not None:
os.remove(tmp)
-def link(path, dest, replace=False):
+def link(path: bytes, dest: bytes, replace: bool = False):
"""Create a symbolic link from path to `dest`. Raises an OSError if
`dest` already exists, unless `replace` is True. Does nothing if
`path` == `dest`.
@@ -521,23 +558,21 @@ def link(path, dest, replace=False):
return
if os.path.exists(syspath(dest)) and not replace:
- raise FilesystemError('file exists', 'rename', (path, dest))
+ raise FilesystemError("file exists", "rename", (path, dest))
try:
os.symlink(syspath(path), syspath(dest))
except NotImplementedError:
# raised on python >= 3.2 and Windows versions before Vista
- raise FilesystemError('OS does not support symbolic links.'
- 'link', (path, dest), traceback.format_exc())
+ raise FilesystemError(
+ "OS does not support symbolic links." "link",
+ (path, dest),
+ traceback.format_exc(),
+ )
except OSError as exc:
- # TODO: Windows version checks can be removed for python 3
- if hasattr('sys', 'getwindowsversion'):
- if sys.getwindowsversion()[0] < 6: # is before Vista
- exc = 'OS does not support symbolic links.'
- raise FilesystemError(exc, 'link', (path, dest),
- traceback.format_exc())
+ raise FilesystemError(exc, "link", (path, dest), traceback.format_exc())
-def hardlink(path, dest, replace=False):
+def hardlink(path: bytes, dest: bytes, replace: bool = False):
"""Create a hard link from path to `dest`. Raises an OSError if
`dest` already exists, unless `replace` is True. Does nothing if
`path` == `dest`.
@@ -546,22 +581,34 @@ def hardlink(path, dest, replace=False):
return
if os.path.exists(syspath(dest)) and not replace:
- raise FilesystemError('file exists', 'rename', (path, dest))
+ raise FilesystemError("file exists", "rename", (path, dest))
try:
os.link(syspath(path), syspath(dest))
except NotImplementedError:
- raise FilesystemError('OS does not support hard links.'
- 'link', (path, dest), traceback.format_exc())
+ raise FilesystemError(
+ "OS does not support hard links." "link",
+ (path, dest),
+ traceback.format_exc(),
+ )
except OSError as exc:
if exc.errno == errno.EXDEV:
- raise FilesystemError('Cannot hard link across devices.'
- 'link', (path, dest), traceback.format_exc())
+ raise FilesystemError(
+ "Cannot hard link across devices." "link",
+ (path, dest),
+ traceback.format_exc(),
+ )
else:
- raise FilesystemError(exc, 'link', (path, dest),
- traceback.format_exc())
+ raise FilesystemError(
+ exc, "link", (path, dest), traceback.format_exc()
+ )
-def reflink(path, dest, replace=False, fallback=False):
+def reflink(
+ path: bytes,
+ dest: bytes,
+ replace: bool = False,
+ fallback: bool = False,
+):
"""Create a reflink from `dest` to `path`.
Raise an `OSError` if `dest` already exists, unless `replace` is
@@ -578,7 +625,7 @@ def reflink(path, dest, replace=False, fallback=False):
return
if os.path.exists(syspath(dest)) and not replace:
- raise FilesystemError('file exists', 'rename', (path, dest))
+ raise FilesystemError("file exists", "rename", (path, dest))
try:
pyreflink.reflink(path, dest)
@@ -586,11 +633,15 @@ def reflink(path, dest, replace=False, fallback=False):
if fallback:
copy(path, dest, replace)
else:
- raise FilesystemError('OS/filesystem does not support reflinks.',
- 'link', (path, dest), traceback.format_exc())
+ raise FilesystemError(
+ "OS/filesystem does not support reflinks.",
+ "link",
+ (path, dest),
+ traceback.format_exc(),
+ )
-def unique_path(path):
+def unique_path(path: bytes) -> bytes:
"""Returns a version of ``path`` that does not exist on the
filesystem. Specifically, if ``path` itself already exists, then
something unique is appended to the path.
@@ -599,15 +650,15 @@ def unique_path(path):
return path
base, ext = os.path.splitext(path)
- match = re.search(br'\.(\d)+$', base)
+ match = re.search(rb"\.(\d)+$", base)
if match:
num = int(match.group(1))
- base = base[:match.start()]
+ base = base[: match.start()]
else:
num = 0
while True:
num += 1
- suffix = f'.{num}'.encode() + ext
+ suffix = f".{num}".encode() + ext
new_path = base + suffix
if not os.path.exists(new_path):
return new_path
@@ -617,17 +668,20 @@ def unique_path(path):
# Unix. They are forbidden here because they cause problems on Samba
# shares, which are sufficiently common as to cause frequent problems.
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
-CHAR_REPLACE = [
- (re.compile(r'[\\/]'), '_'), # / and \ -- forbidden everywhere.
- (re.compile(r'^\.'), '_'), # Leading dot (hidden files on Unix).
- (re.compile(r'[\x00-\x1f]'), ''), # Control characters.
- (re.compile(r'[<>:"\?\*\|]'), '_'), # Windows "reserved characters".
- (re.compile(r'\.$'), '_'), # Trailing dots.
- (re.compile(r'\s+$'), ''), # Trailing whitespace.
+CHAR_REPLACE: List[Tuple[Pattern, str]] = [
+ (re.compile(r"[\\/]"), "_"), # / and \ -- forbidden everywhere.
+ (re.compile(r"^\."), "_"), # Leading dot (hidden files on Unix).
+ (re.compile(r"[\x00-\x1f]"), ""), # Control characters.
+ (re.compile(r'[<>:"\?\*\|]'), "_"), # Windows "reserved characters".
+ (re.compile(r"\.$"), "_"), # Trailing dots.
+ (re.compile(r"\s+$"), ""), # Trailing whitespace.
]
-def sanitize_path(path, replacements=None):
+def sanitize_path(
+ path: str,
+ replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]] = None,
+) -> str:
"""Takes a path (as a Unicode string) and makes sure that it is
legal. Returns a new path. Only works with fragments; won't work
reliably on Windows when a path begins with a drive letter. Path
@@ -640,7 +694,7 @@ def sanitize_path(path, replacements=None):
comps = components(path)
if not comps:
- return ''
+ return ""
for i, comp in enumerate(comps):
for regex, repl in replacements:
comp = regex.sub(repl, comp)
@@ -648,7 +702,7 @@ def sanitize_path(path, replacements=None):
return os.path.join(*comps)
-def truncate_path(path, length=MAX_FILENAME_LENGTH):
+def truncate_path(path: AnyStr, length: int = MAX_FILENAME_LENGTH) -> AnyStr:
"""Given a bytestring path or a Unicode path fragment, truncate the
components to a legal length. In the last component, the extension
is preserved.
@@ -659,13 +713,19 @@ def truncate_path(path, length=MAX_FILENAME_LENGTH):
base, ext = os.path.splitext(comps[-1])
if ext:
# Last component has an extension.
- base = base[:length - len(ext)]
+ base = base[: length - len(ext)]
out[-1] = base + ext
return os.path.join(*out)
-def _legalize_stage(path, replacements, length, extension, fragment):
+def _legalize_stage(
+ path: str,
+ replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]],
+ length: int,
+ extension: str,
+ fragment: bool,
+) -> Tuple[Bytes_or_String, bool]:
"""Perform a single round of path legalization steps
(sanitation/replacement, encoding from Unicode to bytes,
extension-appending, and truncation). Return the path (Unicode if
@@ -677,7 +737,7 @@ def _legalize_stage(path, replacements, length, extension, fragment):
# Encode for the filesystem.
if not fragment:
- path = bytestring_path(path)
+ path = bytestring_path(path) # type: ignore
# Preserve extension.
path += extension.lower()
@@ -689,7 +749,13 @@ def _legalize_stage(path, replacements, length, extension, fragment):
return path, path != pre_truncate_path
-def legalize_path(path, replacements, length, extension, fragment):
+def legalize_path(
+ path: str,
+ replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]],
+ length: int,
+ extension: bytes,
+ fragment: bool,
+) -> Tuple[Union[Bytes_or_String, bool]]:
"""Given a path-like Unicode string, produce a legal path. Return
the path and a flag indicating whether some replacements had to be
ignored (see below).
@@ -713,7 +779,7 @@ def legalize_path(path, replacements, length, extension, fragment):
if fragment:
# Outputting Unicode.
- extension = extension.decode('utf-8', 'ignore')
+ extension = extension.decode("utf-8", "ignore")
first_stage_path, _ = _legalize_stage(
path, replacements, length, extension, fragment
@@ -737,104 +803,44 @@ def legalize_path(path, replacements, length, extension, fragment):
return second_stage_path, retruncated
-def py3_path(path):
- """Convert a bytestring path to Unicode on Python 3 only. On Python
- 2, return the bytestring path unchanged.
-
- This helps deal with APIs on Python 3 that *only* accept Unicode
- (i.e., `str` objects). I philosophically disagree with this
- decision, because paths are sadly bytes on Unix, but that's the way
- it is. So this function helps us "smuggle" the true bytes data
- through APIs that took Python 3's Unicode mandate too seriously.
- """
- if isinstance(path, str):
- return path
- assert isinstance(path, bytes)
- return os.fsdecode(path)
-
-
-def str2bool(value):
+def str2bool(value: str) -> bool:
"""Returns a boolean reflecting a human-entered string."""
- return value.lower() in ('yes', '1', 'true', 't', 'y')
+ return value.lower() in ("yes", "1", "true", "t", "y")
-def as_string(value):
+def as_string(value: Any) -> str:
"""Convert a value to a Unicode object for matching with a query.
None becomes the empty string. Bytestrings are silently decoded.
"""
if value is None:
- return ''
+ return ""
elif isinstance(value, memoryview):
- return bytes(value).decode('utf-8', 'ignore')
+ return bytes(value).decode("utf-8", "ignore")
elif isinstance(value, bytes):
- return value.decode('utf-8', 'ignore')
+ return value.decode("utf-8", "ignore")
else:
return str(value)
-def text_string(value, encoding='utf-8'):
- """Convert a string, which can either be bytes or unicode, to
- unicode.
-
- Text (unicode) is left untouched; bytes are decoded. This is useful
- to convert from a "native string" (bytes on Python 2, str on Python
- 3) to a consistently unicode value.
- """
- if isinstance(value, bytes):
- return value.decode(encoding)
- return value
-
-
-def plurality(objs):
+def plurality(objs: Sequence[T]) -> T:
"""Given a sequence of hashble objects, returns the object that
is most common in the set and the its number of appearance. The
sequence must contain at least one object.
"""
c = Counter(objs)
if not c:
- raise ValueError('sequence must be non-empty')
+ raise ValueError("sequence must be non-empty")
return c.most_common(1)[0]
-def cpu_count():
- """Return the number of hardware thread contexts (cores or SMT
- threads) in the system.
- """
- # Adapted from the soundconverter project:
- # https://github.com/kassoulet/soundconverter
- if sys.platform == 'win32':
- try:
- num = int(os.environ['NUMBER_OF_PROCESSORS'])
- except (ValueError, KeyError):
- num = 0
- elif sys.platform == 'darwin':
- try:
- num = int(command_output([
- '/usr/sbin/sysctl',
- '-n',
- 'hw.ncpu',
- ]).stdout)
- except (ValueError, OSError, subprocess.CalledProcessError):
- num = 0
- else:
- try:
- num = os.sysconf('SC_NPROCESSORS_ONLN')
- except (ValueError, OSError, AttributeError):
- num = 0
- if num >= 1:
- return num
- else:
- return 1
-
-
-def convert_command_args(args):
- """Convert command arguments to bytestrings on Python 2 and
- surrogate-escaped strings on Python 3."""
+def convert_command_args(args: List[bytes]) -> List[str]:
+ """Convert command arguments, which may either be `bytes` or `str`
+ objects, to uniformly surrogate-escaped strings."""
assert isinstance(args, list)
- def convert(arg):
+ def convert(arg) -> str:
if isinstance(arg, bytes):
- arg = arg.decode(arg_encoding(), 'surrogateescape')
+ return os.fsdecode(arg)
return arg
return [convert(a) for a in args]
@@ -844,7 +850,10 @@ def convert_command_args(args):
CommandOutput = namedtuple("CommandOutput", ("stdout", "stderr"))
-def command_output(cmd, shell=False):
+def command_output(
+ cmd: List[Bytes_or_String],
+ shell: bool = False,
+) -> CommandOutput:
"""Runs the command and returns its output after it has exited.
Returns a CommandOutput. The attributes ``stdout`` and ``stderr`` contain
@@ -864,37 +873,34 @@ def command_output(cmd, shell=False):
"""
cmd = convert_command_args(cmd)
- try: # python >= 3.3
- devnull = subprocess.DEVNULL
- except AttributeError:
- devnull = open(os.devnull, 'r+b')
+ devnull = subprocess.DEVNULL
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=devnull,
- close_fds=platform.system() != 'Windows',
- shell=shell
+ close_fds=platform.system() != "Windows",
+ shell=shell,
)
stdout, stderr = proc.communicate()
if proc.returncode:
raise subprocess.CalledProcessError(
returncode=proc.returncode,
- cmd=' '.join(cmd),
+ cmd=" ".join(map(str, cmd)),
output=stdout + stderr,
)
return CommandOutput(stdout, stderr)
-def max_filename_length(path, limit=MAX_FILENAME_LENGTH):
+def max_filename_length(path: AnyStr, limit=MAX_FILENAME_LENGTH) -> int:
"""Attempt to determine the maximum filename length for the
filesystem containing `path`. If the value is greater than `limit`,
then `limit` is used instead (to prevent errors when a filesystem
misreports its capacity). If it cannot be determined (e.g., on
Windows), return `limit`.
"""
- if hasattr(os, 'statvfs'):
+ if hasattr(os, "statvfs"):
try:
res = os.statvfs(path)
except OSError:
@@ -904,34 +910,34 @@ def max_filename_length(path, limit=MAX_FILENAME_LENGTH):
return limit
-def open_anything():
+def open_anything() -> str:
"""Return the system command that dispatches execution to the correct
program.
"""
sys_name = platform.system()
- if sys_name == 'Darwin':
- base_cmd = 'open'
- elif sys_name == 'Windows':
- base_cmd = 'start'
+ if sys_name == "Darwin":
+ base_cmd = "open"
+ elif sys_name == "Windows":
+ base_cmd = "start"
else: # Assume Unix
- base_cmd = 'xdg-open'
+ base_cmd = "xdg-open"
return base_cmd
-def editor_command():
+def editor_command() -> str:
"""Get a command for opening a text file.
- Use the `EDITOR` environment variable by default. If it is not
- present, fall back to `open_anything()`, the platform-specific tool
- for opening files in general.
+ First try environment variable `VISUAL` followed by `EDITOR`. As last resort
+ fall back to `open_anything()`, the platform-specific tool for opening files
+ in general.
+
"""
- editor = os.environ.get('EDITOR')
- if editor:
- return editor
- return open_anything()
+ return (
+ os.environ.get("VISUAL") or os.environ.get("EDITOR") or open_anything()
+ )
-def interactive_open(targets, command):
+def interactive_open(targets: Sequence[str], command: str):
"""Open the files in `targets` by `exec`ing a new `command`, given
as a Unicode string. (The new program takes over, and Python
execution ends: this does not fork a subprocess.)
@@ -953,77 +959,68 @@ def interactive_open(targets, command):
return os.execlp(*args)
-def _windows_long_path_name(short_path):
- """Use Windows' `GetLongPathNameW` via ctypes to get the canonical,
- long path given a short filename.
- """
- if not isinstance(short_path, str):
- short_path = short_path.decode(_fsencoding())
-
- import ctypes
- buf = ctypes.create_unicode_buffer(260)
- get_long_path_name_w = ctypes.windll.kernel32.GetLongPathNameW
- return_value = get_long_path_name_w(short_path, buf, 260)
-
- if return_value == 0 or return_value > 260:
- # An error occurred
- return short_path
- else:
- long_path = buf.value
- # GetLongPathNameW does not change the case of the drive
- # letter.
- if len(long_path) > 1 and long_path[1] == ':':
- long_path = long_path[0].upper() + long_path[1:]
- return long_path
-
-
-def case_sensitive(path):
+def case_sensitive(path: bytes) -> bool:
"""Check whether the filesystem at the given path is case sensitive.
To work best, the path should point to a file or a directory. If the path
does not exist, assume a case sensitive file system on every platform
except Windows.
+
+ Currently only used for absolute paths by beets; may have a trailing
+ path separator.
"""
- # A fallback in case the path does not exist.
- if not os.path.exists(syspath(path)):
- # By default, the case sensitivity depends on the platform.
- return platform.system() != 'Windows'
+ # Look at parent paths until we find a path that actually exists, or
+ # reach the root.
+ while True:
+ head, tail = os.path.split(path)
+ if head == path:
+ # We have reached the root of the file system.
+ # By default, the case sensitivity depends on the platform.
+ return platform.system() != "Windows"
- # If an upper-case version of the path exists but a lower-case
- # version does not, then the filesystem must be case-sensitive.
- # (Otherwise, we have more work to do.)
- if not (os.path.exists(syspath(path.lower())) and
- os.path.exists(syspath(path.upper()))):
- return True
+ # Trailing path separator, or path does not exist.
+ if not tail or not os.path.exists(path):
+ path = head
+ continue
- # Both versions of the path exist on the file system. Check whether
- # they refer to different files by their inodes. Alas,
- # `os.path.samefile` is only available on Unix systems on Python 2.
- if platform.system() != 'Windows':
- return not os.path.samefile(syspath(path.lower()),
- syspath(path.upper()))
+ upper_tail = tail.upper()
+ lower_tail = tail.lower()
- # On Windows, we check whether the canonical, long filenames for the
- # files are the same.
- lower = _windows_long_path_name(path.lower())
- upper = _windows_long_path_name(path.upper())
- return lower != upper
+ # In case we can't tell from the given path name, look at the
+ # parent directory.
+ if upper_tail == lower_tail:
+ path = head
+ continue
+
+ upper_sys = syspath(os.path.join(head, upper_tail))
+ lower_sys = syspath(os.path.join(head, lower_tail))
+
+ # If either the upper-cased or lower-cased path does not exist, the
+ # filesystem must be case-sensitive.
+ # (Otherwise, we have more work to do.)
+ if not os.path.exists(upper_sys) or not os.path.exists(lower_sys):
+ return True
+
+ # Original and both upper- and lower-cased versions of the path
+ # exist on the file system. Check whether they refer to different
+ # files by their inodes (or an alternative method on Windows).
+ return not os.path.samefile(lower_sys, upper_sys)
-def raw_seconds_short(string):
+def raw_seconds_short(string: str) -> float:
"""Formats a human-readable M:SS string as a float (number of seconds).
Raises ValueError if the conversion cannot take place due to `string` not
being in the right format.
"""
- match = re.match(r'^(\d+):([0-5]\d)$', string)
+ match = re.match(r"^(\d+):([0-5]\d)$", string)
if not match:
- raise ValueError('String not in M:SS format')
+ raise ValueError("String not in M:SS format")
minutes, seconds = map(int, match.groups())
return float(minutes * 60 + seconds)
-def asciify_path(path, sep_replace):
+def asciify_path(path: str, sep_replace: str) -> str:
"""Decodes all unicode characters in a path into ASCII equivalents.
Substitutions are provided by the unidecode module. Path separators in the
@@ -1036,22 +1033,20 @@ def asciify_path(path, sep_replace):
# if this platform has an os.altsep, change it to os.sep.
if os.altsep:
path = path.replace(os.altsep, os.sep)
- path_components = path.split(os.sep)
+ path_components: List[Bytes_or_String] = path.split(os.sep)
for index, item in enumerate(path_components):
path_components[index] = unidecode(item).replace(os.sep, sep_replace)
if os.altsep:
path_components[index] = unidecode(item).replace(
- os.altsep,
- sep_replace
+ os.altsep, sep_replace
)
return os.sep.join(path_components)
-def par_map(transform, items):
+def par_map(transform: Callable, items: Iterable):
"""Apply the function `transform` to all the elements in the
iterable `items`, like `map(transform, items)` but with no return
- value. The map *might* happen in parallel: it's parallel on Python 3
- and sequential on Python 2.
+ value.
The parallelism uses threads (not processes), so this is only useful
for IO-bound `transform`s.
@@ -1062,42 +1057,61 @@ def par_map(transform, items):
pool.join()
-def lazy_property(func):
- """A decorator that creates a lazily evaluated property. On first access,
- the property is assigned the return value of `func`. This first value is
- stored, so that future accesses do not have to evaluate `func` again.
-
- This behaviour is useful when `func` is expensive to evaluate, and it is
- not certain that the result will be needed.
+class cached_classproperty: # noqa: N801
+ """A decorator implementing a read-only property that is *lazy* in
+ the sense that the getter is only invoked once. Subsequent accesses
+ through *any* instance use the cached result.
"""
- field_name = '_' + func.__name__
- @property
- @functools.wraps(func)
- def wrapper(self):
- if hasattr(self, field_name):
- return getattr(self, field_name)
+ def __init__(self, getter):
+ self.getter = getter
+ self.cache = {}
- value = func(self)
- setattr(self, field_name, value)
- return value
+ def __get__(self, instance, owner):
+ if owner not in self.cache:
+ self.cache[owner] = self.getter(owner)
- return wrapper
+ return self.cache[owner]
-def decode_commandline_path(path):
- """Prepare a path for substitution into commandline template.
+def get_module_tempdir(module: str) -> Path:
+ """Return the temporary directory for the given module.
- On Python 3, we need to construct the subprocess commands to invoke as a
- Unicode string. On Unix, this is a little unfortunate---the OS is
- expecting bytes---so we use surrogate escaping and decode with the
- argument encoding, which is the same encoding that will then be
- *reversed* to recover the same bytes before invoking the OS. On
- Windows, we want to preserve the Unicode filename "as is."
+ The directory is created within the `/tmp/beets/` directory on
+ Linux (or the equivalent temporary directory on other systems).
+
+ Dots in the module name are replaced by underscores.
"""
- # On Python 3, the template is a Unicode string, which only supports
- # substitution of Unicode variables.
- if platform.system() == 'Windows':
- return path.decode(_fsencoding())
- else:
- return path.decode(arg_encoding(), 'surrogateescape')
+ module = module.replace("beets.", "").replace(".", "_")
+ return Path(tempfile.gettempdir()) / "beets" / module
+
+
+def clean_module_tempdir(module: str) -> None:
+ """Clean the temporary directory for the given module."""
+ tempdir = get_module_tempdir(module)
+ shutil.rmtree(tempdir, ignore_errors=True)
+ with suppress(OSError):
+ # remove parent (/tmp/beets) directory if it is empty
+ tempdir.parent.rmdir()
+
+
+def get_temp_filename(
+ module: str,
+ prefix: str = "",
+ path: PathLike | None = None,
+ suffix: str = "",
+) -> bytes:
+ """Return temporary filename for the given module and prefix.
+
+ The filename starts with the given `prefix`.
+ If 'suffix' is given, it is used a the file extension.
+ If 'path' is given, we use the same suffix.
+ """
+ if not suffix and path:
+ suffix = Path(os.fsdecode(path)).suffix
+
+ tempdir = get_module_tempdir(module)
+ tempdir.mkdir(parents=True, exist_ok=True)
+
+ _, filename = tempfile.mkstemp(dir=tempdir, prefix=prefix, suffix=suffix)
+ return bytestring_path(filename)
diff --git a/lib/beets/util/artresizer.py b/lib/beets/util/artresizer.py
index 8683e228..09cc29e0 100644
--- a/lib/beets/util/artresizer.py
+++ b/lib/beets/util/artresizer.py
@@ -16,23 +16,20 @@
public resizing proxy if neither is available.
"""
-import subprocess
import os
import os.path
+import platform
import re
-from tempfile import NamedTemporaryFile
+import subprocess
+from itertools import chain
from urllib.parse import urlencode
-from beets import logging
-from beets import util
-# Resizing methods
-PIL = 1
-IMAGEMAGICK = 2
-WEBPROXY = 3
+from beets import logging, util
+from beets.util import displayable_path, get_temp_filename, syspath
-PROXY_URL = 'https://images.weserv.nl/'
+PROXY_URL = "https://images.weserv.nl/"
-log = logging.getLogger('beets')
+log = logging.getLogger("beets")
def resize_url(url, maxwidth, quality=0):
@@ -40,265 +37,473 @@ def resize_url(url, maxwidth, quality=0):
maxwidth (preserving aspect ratio).
"""
params = {
- 'url': url.replace('http://', ''),
- 'w': maxwidth,
+ "url": url.replace("http://", ""),
+ "w": maxwidth,
}
if quality > 0:
- params['q'] = quality
+ params["q"] = quality
- return '{}?{}'.format(PROXY_URL, urlencode(params))
+ return "{}?{}".format(PROXY_URL, urlencode(params))
-def temp_file_for(path):
- """Return an unused filename with the same extension as the
- specified path.
- """
- ext = os.path.splitext(path)[1]
- with NamedTemporaryFile(suffix=util.py3_path(ext), delete=False) as f:
- return util.bytestring_path(f.name)
+class LocalBackendNotAvailableError(Exception):
+ pass
-def pil_resize(maxwidth, path_in, path_out=None, quality=0, max_filesize=0):
- """Resize using Python Imaging Library (PIL). Return the output path
- of resized image.
- """
- path_out = path_out or temp_file_for(path_in)
- from PIL import Image
+_NOT_AVAILABLE = object()
- log.debug('artresizer: PIL resizing {0} to {1}',
- util.displayable_path(path_in), util.displayable_path(path_out))
- try:
- im = Image.open(util.syspath(path_in))
- size = maxwidth, maxwidth
- im.thumbnail(size, Image.ANTIALIAS)
+class LocalBackend:
+ @classmethod
+ def available(cls):
+ try:
+ cls.version()
+ return True
+ except LocalBackendNotAvailableError:
+ return False
- if quality == 0:
- # Use PIL's default quality.
- quality = -1
- # progressive=False only affects JPEGs and is the default,
- # but we include it here for explicitness.
- im.save(util.py3_path(path_out), quality=quality, progressive=False)
+class IMBackend(LocalBackend):
+ NAME = "ImageMagick"
- if max_filesize > 0:
- # If maximum filesize is set, we attempt to lower the quality of
- # jpeg conversion by a proportional amount, up to 3 attempts
- # First, set the maximum quality to either provided, or 95
- if quality > 0:
- lower_qual = quality
- else:
- lower_qual = 95
- for i in range(5):
- # 5 attempts is an abitrary choice
- filesize = os.stat(util.syspath(path_out)).st_size
- log.debug("PIL Pass {0} : Output size: {1}B", i, filesize)
- if filesize <= max_filesize:
- return path_out
- # The relationship between filesize & quality will be
- # image dependent.
- lower_qual -= 10
- # Restrict quality dropping below 10
- if lower_qual < 10:
- lower_qual = 10
- # Use optimize flag to improve filesize decrease
- im.save(util.py3_path(path_out), quality=lower_qual,
- optimize=True, progressive=False)
- log.warning("PIL Failed to resize file to below {0}B",
- max_filesize)
- return path_out
+ # These fields are used as a cache for `version()`. `_legacy` indicates
+ # whether the modern `magick` binary is available or whether to fall back
+ # to the old-style `convert`, `identify`, etc. commands.
+ _version = None
+ _legacy = None
+ @classmethod
+ def version(cls):
+ """Obtain and cache ImageMagick version.
+
+ Raises `LocalBackendNotAvailableError` if not available.
+ """
+ if cls._version is None:
+ for cmd_name, legacy in (("magick", False), ("convert", True)):
+ try:
+ out = util.command_output([cmd_name, "--version"]).stdout
+ except (subprocess.CalledProcessError, OSError) as exc:
+ log.debug("ImageMagick version check failed: {}", exc)
+ cls._version = _NOT_AVAILABLE
+ else:
+ if b"imagemagick" in out.lower():
+ pattern = rb".+ (\d+)\.(\d+)\.(\d+).*"
+ match = re.search(pattern, out)
+ if match:
+ cls._version = (
+ int(match.group(1)),
+ int(match.group(2)),
+ int(match.group(3)),
+ )
+ cls._legacy = legacy
+
+ if cls._version is _NOT_AVAILABLE:
+ raise LocalBackendNotAvailableError()
else:
- return path_out
- except OSError:
- log.error("PIL cannot create thumbnail for '{0}'",
- util.displayable_path(path_in))
- return path_in
+ return cls._version
+ def __init__(self):
+ """Initialize a wrapper around ImageMagick for local image operations.
-def im_resize(maxwidth, path_in, path_out=None, quality=0, max_filesize=0):
- """Resize using ImageMagick.
+ Stores the ImageMagick version and legacy flag. If ImageMagick is not
+ available, raise an Exception.
+ """
+ self.version()
- Use the ``magick`` program or ``convert`` on older versions. Return
- the output path of resized image.
- """
- path_out = path_out or temp_file_for(path_in)
- log.debug('artresizer: ImageMagick resizing {0} to {1}',
- util.displayable_path(path_in), util.displayable_path(path_out))
+ # Use ImageMagick's magick binary when it's available.
+ # If it's not, fall back to the older, separate convert
+ # and identify commands.
+ if self._legacy:
+ self.convert_cmd = ["convert"]
+ self.identify_cmd = ["identify"]
+ self.compare_cmd = ["compare"]
+ else:
+ self.convert_cmd = ["magick"]
+ self.identify_cmd = ["magick", "identify"]
+ self.compare_cmd = ["magick", "compare"]
- # "-resize WIDTHx>" shrinks images with the width larger
- # than the given width while maintaining the aspect ratio
- # with regards to the height.
- # ImageMagick already seems to default to no interlace, but we include it
- # here for the sake of explicitness.
- cmd = ArtResizer.shared.im_convert_cmd + [
- util.syspath(path_in, prefix=False),
- '-resize', f'{maxwidth}x>',
- '-interlace', 'none',
- ]
+ def resize(
+ self, maxwidth, path_in, path_out=None, quality=0, max_filesize=0
+ ):
+ """Resize using ImageMagick.
- if quality > 0:
- cmd += ['-quality', f'{quality}']
+ Use the ``magick`` program or ``convert`` on older versions. Return
+ the output path of resized image.
+ """
+ if not path_out:
+ path_out = get_temp_filename(__name__, "resize_IM_", path_in)
- # "-define jpeg:extent=SIZEb" sets the target filesize for imagemagick to
- # SIZE in bytes.
- if max_filesize > 0:
- cmd += ['-define', f'jpeg:extent={max_filesize}b']
-
- cmd.append(util.syspath(path_out, prefix=False))
-
- try:
- util.command_output(cmd)
- except subprocess.CalledProcessError:
- log.warning('artresizer: IM convert failed for {0}',
- util.displayable_path(path_in))
- return path_in
-
- return path_out
-
-
-BACKEND_FUNCS = {
- PIL: pil_resize,
- IMAGEMAGICK: im_resize,
-}
-
-
-def pil_getsize(path_in):
- from PIL import Image
-
- try:
- im = Image.open(util.syspath(path_in))
- return im.size
- except OSError as exc:
- log.error("PIL could not read file {}: {}",
- util.displayable_path(path_in), exc)
-
-
-def im_getsize(path_in):
- cmd = ArtResizer.shared.im_identify_cmd + \
- ['-format', '%w %h', util.syspath(path_in, prefix=False)]
-
- try:
- out = util.command_output(cmd).stdout
- except subprocess.CalledProcessError as exc:
- log.warning('ImageMagick size query failed')
log.debug(
- '`convert` exited with (status {}) when '
- 'getting size with command {}:\n{}',
- exc.returncode, cmd, exc.output.strip()
+ "artresizer: ImageMagick resizing {0} to {1}",
+ displayable_path(path_in),
+ displayable_path(path_out),
)
- return
- try:
- return tuple(map(int, out.split(b' ')))
- except IndexError:
- log.warning('Could not understand IM output: {0!r}', out)
+ # "-resize WIDTHx>" shrinks images with the width larger
+ # than the given width while maintaining the aspect ratio
+ # with regards to the height.
+ # ImageMagick already seems to default to no interlace, but we include
+ # it here for the sake of explicitness.
+ cmd = self.convert_cmd + [
+ syspath(path_in, prefix=False),
+ "-resize",
+ f"{maxwidth}x>",
+ "-interlace",
+ "none",
+ ]
-BACKEND_GET_SIZE = {
- PIL: pil_getsize,
- IMAGEMAGICK: im_getsize,
-}
+ if quality > 0:
+ cmd += ["-quality", f"{quality}"]
+ # "-define jpeg:extent=SIZEb" sets the target filesize for imagemagick
+ # to SIZE in bytes.
+ if max_filesize > 0:
+ cmd += ["-define", f"jpeg:extent={max_filesize}b"]
-def pil_deinterlace(path_in, path_out=None):
- path_out = path_out or temp_file_for(path_in)
- from PIL import Image
+ cmd.append(syspath(path_out, prefix=False))
+
+ try:
+ util.command_output(cmd)
+ except subprocess.CalledProcessError:
+ log.warning(
+ "artresizer: IM convert failed for {0}",
+ displayable_path(path_in),
+ )
+ return path_in
- try:
- im = Image.open(util.syspath(path_in))
- im.save(util.py3_path(path_out), progressive=False)
return path_out
- except IOError:
- return path_in
+ def get_size(self, path_in):
+ cmd = self.identify_cmd + [
+ "-format",
+ "%w %h",
+ syspath(path_in, prefix=False),
+ ]
-def im_deinterlace(path_in, path_out=None):
- path_out = path_out or temp_file_for(path_in)
+ try:
+ out = util.command_output(cmd).stdout
+ except subprocess.CalledProcessError as exc:
+ log.warning("ImageMagick size query failed")
+ log.debug(
+ "`convert` exited with (status {}) when "
+ "getting size with command {}:\n{}",
+ exc.returncode,
+ cmd,
+ exc.output.strip(),
+ )
+ return None
+ try:
+ return tuple(map(int, out.split(b" ")))
+ except IndexError:
+ log.warning("Could not understand IM output: {0!r}", out)
+ return None
- cmd = ArtResizer.shared.im_convert_cmd + [
- util.syspath(path_in, prefix=False),
- '-interlace', 'none',
- util.syspath(path_out, prefix=False),
- ]
+ def deinterlace(self, path_in, path_out=None):
+ if not path_out:
+ path_out = get_temp_filename(__name__, "deinterlace_IM_", path_in)
- try:
- util.command_output(cmd)
- return path_out
- except subprocess.CalledProcessError:
- return path_in
+ cmd = self.convert_cmd + [
+ syspath(path_in, prefix=False),
+ "-interlace",
+ "none",
+ syspath(path_out, prefix=False),
+ ]
+ try:
+ util.command_output(cmd)
+ return path_out
+ except subprocess.CalledProcessError:
+ # FIXME: Should probably issue a warning?
+ return path_in
-DEINTERLACE_FUNCS = {
- PIL: pil_deinterlace,
- IMAGEMAGICK: im_deinterlace,
-}
+ def get_format(self, filepath):
+ cmd = self.identify_cmd + ["-format", "%[magick]", syspath(filepath)]
+ try:
+ return util.command_output(cmd).stdout
+ except subprocess.CalledProcessError:
+ # FIXME: Should probably issue a warning?
+ return None
-def im_get_format(filepath):
- cmd = ArtResizer.shared.im_identify_cmd + [
- '-format', '%[magick]',
- util.syspath(filepath)
- ]
+ def convert_format(self, source, target, deinterlaced):
+ cmd = self.convert_cmd + [
+ syspath(source),
+ *(["-interlace", "none"] if deinterlaced else []),
+ syspath(target),
+ ]
- try:
- return util.command_output(cmd).stdout
- except subprocess.CalledProcessError:
- return None
-
-
-def pil_get_format(filepath):
- from PIL import Image, UnidentifiedImageError
-
- try:
- with Image.open(util.syspath(filepath)) as im:
- return im.format
- except (ValueError, TypeError, UnidentifiedImageError, FileNotFoundError):
- log.exception("failed to detect image format for {}", filepath)
- return None
-
-
-BACKEND_GET_FORMAT = {
- PIL: pil_get_format,
- IMAGEMAGICK: im_get_format,
-}
-
-
-def im_convert_format(source, target, deinterlaced):
- cmd = ArtResizer.shared.im_convert_cmd + [
- util.syspath(source),
- *(["-interlace", "none"] if deinterlaced else []),
- util.syspath(target),
- ]
-
- try:
- subprocess.check_call(
- cmd,
- stderr=subprocess.DEVNULL,
- stdout=subprocess.DEVNULL
- )
- return target
- except subprocess.CalledProcessError:
- return source
-
-
-def pil_convert_format(source, target, deinterlaced):
- from PIL import Image, UnidentifiedImageError
-
- try:
- with Image.open(util.syspath(source)) as im:
- im.save(util.py3_path(target), progressive=not deinterlaced)
+ try:
+ subprocess.check_call(
+ cmd, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL
+ )
return target
- except (ValueError, TypeError, UnidentifiedImageError, FileNotFoundError,
- OSError):
- log.exception("failed to convert image {} -> {}", source, target)
- return source
+ except subprocess.CalledProcessError:
+ # FIXME: Should probably issue a warning?
+ return source
+
+ @property
+ def can_compare(self):
+ return self.version() > (6, 8, 7)
+
+ def compare(self, im1, im2, compare_threshold):
+ is_windows = platform.system() == "Windows"
+
+ # Converting images to grayscale tends to minimize the weight
+ # of colors in the diff score. So we first convert both images
+ # to grayscale and then pipe them into the `compare` command.
+ # On Windows, ImageMagick doesn't support the magic \\?\ prefix
+ # on paths, so we pass `prefix=False` to `syspath`.
+ convert_cmd = self.convert_cmd + [
+ syspath(im2, prefix=False),
+ syspath(im1, prefix=False),
+ "-colorspace",
+ "gray",
+ "MIFF:-",
+ ]
+ compare_cmd = self.compare_cmd + [
+ "-define",
+ "phash:colorspaces=sRGB,HCLp",
+ "-metric",
+ "PHASH",
+ "-",
+ "null:",
+ ]
+ log.debug(
+ "comparing images with pipeline {} | {}", convert_cmd, compare_cmd
+ )
+ convert_proc = subprocess.Popen(
+ convert_cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ close_fds=not is_windows,
+ )
+ compare_proc = subprocess.Popen(
+ compare_cmd,
+ stdin=convert_proc.stdout,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ close_fds=not is_windows,
+ )
+
+ # Check the convert output. We're not interested in the
+ # standard output; that gets piped to the next stage.
+ convert_proc.stdout.close()
+ convert_stderr = convert_proc.stderr.read()
+ convert_proc.stderr.close()
+ convert_proc.wait()
+ if convert_proc.returncode:
+ log.debug(
+ "ImageMagick convert failed with status {}: {!r}",
+ convert_proc.returncode,
+ convert_stderr,
+ )
+ return None
+
+ # Check the compare output.
+ stdout, stderr = compare_proc.communicate()
+ if compare_proc.returncode:
+ if compare_proc.returncode != 1:
+ log.debug(
+ "ImageMagick compare failed: {0}, {1}",
+ displayable_path(im2),
+ displayable_path(im1),
+ )
+ return None
+ out_str = stderr
+ else:
+ out_str = stdout
+
+ try:
+ phash_diff = float(out_str)
+ except ValueError:
+ log.debug("IM output is not a number: {0!r}", out_str)
+ return None
+
+ log.debug("ImageMagick compare score: {0}", phash_diff)
+ return phash_diff <= compare_threshold
+
+ @property
+ def can_write_metadata(self):
+ return True
+
+ def write_metadata(self, file, metadata):
+ assignments = list(
+ chain.from_iterable(("-set", k, v) for k, v in metadata.items())
+ )
+ command = self.convert_cmd + [file, *assignments, file]
+
+ util.command_output(command)
-BACKEND_CONVERT_IMAGE_FORMAT = {
- PIL: pil_convert_format,
- IMAGEMAGICK: im_convert_format,
-}
+class PILBackend(LocalBackend):
+ NAME = "PIL"
+
+ @classmethod
+ def version(cls):
+ try:
+ __import__("PIL", fromlist=["Image"])
+ except ImportError:
+ raise LocalBackendNotAvailableError()
+
+ def __init__(self):
+ """Initialize a wrapper around PIL for local image operations.
+
+ If PIL is not available, raise an Exception.
+ """
+ self.version()
+
+ def resize(
+ self, maxwidth, path_in, path_out=None, quality=0, max_filesize=0
+ ):
+ """Resize using Python Imaging Library (PIL). Return the output path
+ of resized image.
+ """
+ if not path_out:
+ path_out = get_temp_filename(__name__, "resize_PIL_", path_in)
+
+ from PIL import Image
+
+ log.debug(
+ "artresizer: PIL resizing {0} to {1}",
+ displayable_path(path_in),
+ displayable_path(path_out),
+ )
+
+ try:
+ im = Image.open(syspath(path_in))
+ size = maxwidth, maxwidth
+ im.thumbnail(size, Image.Resampling.LANCZOS)
+
+ if quality == 0:
+ # Use PIL's default quality.
+ quality = -1
+
+ # progressive=False only affects JPEGs and is the default,
+ # but we include it here for explicitness.
+ im.save(os.fsdecode(path_out), quality=quality, progressive=False)
+
+ if max_filesize > 0:
+ # If maximum filesize is set, we attempt to lower the quality
+ # of jpeg conversion by a proportional amount, up to 3 attempts
+ # First, set the maximum quality to either provided, or 95
+ if quality > 0:
+ lower_qual = quality
+ else:
+ lower_qual = 95
+ for i in range(5):
+ # 5 attempts is an arbitrary choice
+ filesize = os.stat(syspath(path_out)).st_size
+ log.debug("PIL Pass {0} : Output size: {1}B", i, filesize)
+ if filesize <= max_filesize:
+ return path_out
+ # The relationship between filesize & quality will be
+ # image dependent.
+ lower_qual -= 10
+ # Restrict quality dropping below 10
+ if lower_qual < 10:
+ lower_qual = 10
+ # Use optimize flag to improve filesize decrease
+ im.save(
+ os.fsdecode(path_out),
+ quality=lower_qual,
+ optimize=True,
+ progressive=False,
+ )
+ log.warning(
+ "PIL Failed to resize file to below {0}B", max_filesize
+ )
+ return path_out
+
+ else:
+ return path_out
+ except OSError:
+ log.error(
+ "PIL cannot create thumbnail for '{0}'",
+ displayable_path(path_in),
+ )
+ return path_in
+
+ def get_size(self, path_in):
+ from PIL import Image
+
+ try:
+ im = Image.open(syspath(path_in))
+ return im.size
+ except OSError as exc:
+ log.error(
+ "PIL could not read file {}: {}", displayable_path(path_in), exc
+ )
+ return None
+
+ def deinterlace(self, path_in, path_out=None):
+ if not path_out:
+ path_out = get_temp_filename(__name__, "deinterlace_PIL_", path_in)
+
+ from PIL import Image
+
+ try:
+ im = Image.open(syspath(path_in))
+ im.save(os.fsdecode(path_out), progressive=False)
+ return path_out
+ except OSError:
+ # FIXME: Should probably issue a warning?
+ return path_in
+
+ def get_format(self, filepath):
+ from PIL import Image, UnidentifiedImageError
+
+ try:
+ with Image.open(syspath(filepath)) as im:
+ return im.format
+ except (
+ ValueError,
+ TypeError,
+ UnidentifiedImageError,
+ FileNotFoundError,
+ ):
+ log.exception("failed to detect image format for {}", filepath)
+ return None
+
+ def convert_format(self, source, target, deinterlaced):
+ from PIL import Image, UnidentifiedImageError
+
+ try:
+ with Image.open(syspath(source)) as im:
+ im.save(os.fsdecode(target), progressive=not deinterlaced)
+ return target
+ except (
+ ValueError,
+ TypeError,
+ UnidentifiedImageError,
+ FileNotFoundError,
+ OSError,
+ ):
+ log.exception("failed to convert image {} -> {}", source, target)
+ return source
+
+ @property
+ def can_compare(self):
+ return False
+
+ def compare(self, im1, im2, compare_threshold):
+ # It is an error to call this when ArtResizer.can_compare is not True.
+ raise NotImplementedError()
+
+ @property
+ def can_write_metadata(self):
+ return True
+
+ def write_metadata(self, file, metadata):
+ from PIL import Image, PngImagePlugin
+
+ # FIXME: Detect and handle other file types (currently, the only user
+ # is the thumbnails plugin, which generates PNG images).
+ im = Image.open(syspath(file))
+ meta = PngImagePlugin.PngInfo()
+ for k, v in metadata.items():
+ meta.add_text(k, v, 0)
+ im.save(os.fsdecode(file), "PNG", pnginfo=meta)
class Shareable(type):
@@ -319,28 +524,36 @@ class Shareable(type):
return cls._instance
+BACKEND_CLASSES = [
+ IMBackend,
+ PILBackend,
+]
+
+
class ArtResizer(metaclass=Shareable):
- """A singleton class that performs image resizes.
- """
+ """A singleton class that performs image resizes."""
def __init__(self):
- """Create a resizer object with an inferred method.
- """
- self.method = self._check_method()
- log.debug("artresizer: method is {0}", self.method)
- self.can_compare = self._can_compare()
+ """Create a resizer object with an inferred method."""
+ # Check if a local backend is available, and store an instance of the
+ # backend class. Otherwise, fallback to the web proxy.
+ for backend_cls in BACKEND_CLASSES:
+ try:
+ self.local_method = backend_cls()
+ log.debug(f"artresizer: method is {self.local_method.NAME}")
+ break
+ except LocalBackendNotAvailableError:
+ continue
+ else:
+ log.debug("artresizer: method is WEBPROXY")
+ self.local_method = None
- # Use ImageMagick's magick binary when it's available. If it's
- # not, fall back to the older, separate convert and identify
- # commands.
- if self.method[0] == IMAGEMAGICK:
- self.im_legacy = self.method[2]
- if self.im_legacy:
- self.im_convert_cmd = ['convert']
- self.im_identify_cmd = ['identify']
- else:
- self.im_convert_cmd = ['magick']
- self.im_identify_cmd = ['magick', 'identify']
+ @property
+ def method(self):
+ if self.local:
+ return self.local_method.NAME
+ else:
+ return "WEBPROXY"
def resize(
self, maxwidth, path_in, path_out=None, quality=0, max_filesize=0
@@ -351,17 +564,26 @@ class ArtResizer(metaclass=Shareable):
For WEBPROXY, returns `path_in` unmodified.
"""
if self.local:
- func = BACKEND_FUNCS[self.method[0]]
- return func(maxwidth, path_in, path_out,
- quality=quality, max_filesize=max_filesize)
+ return self.local_method.resize(
+ maxwidth,
+ path_in,
+ path_out,
+ quality=quality,
+ max_filesize=max_filesize,
+ )
else:
+ # Handled by `proxy_url` already.
return path_in
def deinterlace(self, path_in, path_out=None):
+ """Deinterlace an image.
+
+ Only available locally.
+ """
if self.local:
- func = DEINTERLACE_FUNCS[self.method[0]]
- return func(path_in, path_out)
+ return self.local_method.deinterlace(path_in, path_out)
else:
+ # FIXME: Should probably issue a warning?
return path_in
def proxy_url(self, maxwidth, url, quality=0):
@@ -370,6 +592,7 @@ class ArtResizer(metaclass=Shareable):
Otherwise, the URL is returned unmodified.
"""
if self.local:
+ # Going to be handled by `resize()`.
return url
else:
return resize_url(url, maxwidth, quality)
@@ -379,7 +602,7 @@ class ArtResizer(metaclass=Shareable):
"""A boolean indicating whether the resizing method is performed
locally (i.e., PIL or ImageMagick).
"""
- return self.method[0] in BACKEND_FUNCS
+ return self.local_method is not None
def get_size(self, path_in):
"""Return the size of an image file as an int couple (width, height)
@@ -388,8 +611,10 @@ class ArtResizer(metaclass=Shareable):
Only available locally.
"""
if self.local:
- func = BACKEND_GET_SIZE[self.method[0]]
- return func(path_in)
+ return self.local_method.get_size(path_in)
+ else:
+ # FIXME: Should probably issue a warning?
+ return path_in
def get_format(self, path_in):
"""Returns the format of the image as a string.
@@ -397,8 +622,10 @@ class ArtResizer(metaclass=Shareable):
Only available locally.
"""
if self.local:
- func = BACKEND_GET_FORMAT[self.method[0]]
- return func(path_in)
+ return self.local_method.get_format(path_in)
+ else:
+ # FIXME: Should probably issue a warning?
+ return None
def reformat(self, path_in, new_format, deinterlaced=True):
"""Converts image to desired format, updating its extension, but
@@ -407,86 +634,66 @@ class ArtResizer(metaclass=Shareable):
Only available locally.
"""
if not self.local:
+ # FIXME: Should probably issue a warning?
return path_in
new_format = new_format.lower()
# A nonexhaustive map of image "types" to extensions overrides
new_format = {
- 'jpeg': 'jpg',
+ "jpeg": "jpg",
}.get(new_format, new_format)
fname, ext = os.path.splitext(path_in)
- path_new = fname + b'.' + new_format.encode('utf8')
- func = BACKEND_CONVERT_IMAGE_FORMAT[self.method[0]]
+ path_new = fname + b"." + new_format.encode("utf8")
# allows the exception to propagate, while still making sure a changed
# file path was removed
result_path = path_in
try:
- result_path = func(path_in, path_new, deinterlaced)
+ result_path = self.local_method.convert_format(
+ path_in, path_new, deinterlaced
+ )
finally:
if result_path != path_in:
os.unlink(path_in)
return result_path
- def _can_compare(self):
+ @property
+ def can_compare(self):
"""A boolean indicating whether image comparison is available"""
- return self.method[0] == IMAGEMAGICK and self.method[1] > (6, 8, 7)
-
- @staticmethod
- def _check_method():
- """Return a tuple indicating an available method and its version.
-
- The result has at least two elements:
- - The method, eitehr WEBPROXY, PIL, or IMAGEMAGICK.
- - The version.
-
- If the method is IMAGEMAGICK, there is also a third element: a
- bool flag indicating whether to use the `magick` binary or
- legacy single-purpose executables (`convert`, `identify`, etc.)
- """
- version = get_im_version()
- if version:
- version, legacy = version
- return IMAGEMAGICK, version, legacy
-
- version = get_pil_version()
- if version:
- return PIL, version
-
- return WEBPROXY, (0)
-
-
-def get_im_version():
- """Get the ImageMagick version and legacy flag as a pair. Or return
- None if ImageMagick is not available.
- """
- for cmd_name, legacy in ((['magick'], False), (['convert'], True)):
- cmd = cmd_name + ['--version']
-
- try:
- out = util.command_output(cmd).stdout
- except (subprocess.CalledProcessError, OSError) as exc:
- log.debug('ImageMagick version check failed: {}', exc)
+ if self.local:
+ return self.local_method.can_compare
else:
- if b'imagemagick' in out.lower():
- pattern = br".+ (\d+)\.(\d+)\.(\d+).*"
- match = re.search(pattern, out)
- if match:
- version = (int(match.group(1)),
- int(match.group(2)),
- int(match.group(3)))
- return version, legacy
+ return False
- return None
+ def compare(self, im1, im2, compare_threshold):
+ """Return a boolean indicating whether two images are similar.
+ Only available locally.
+ """
+ if self.local:
+ return self.local_method.compare(im1, im2, compare_threshold)
+ else:
+ # FIXME: Should probably issue a warning?
+ return None
-def get_pil_version():
- """Get the PIL/Pillow version, or None if it is unavailable.
- """
- try:
- __import__('PIL', fromlist=['Image'])
- return (0,)
- except ImportError:
- return None
+ @property
+ def can_write_metadata(self):
+ """A boolean indicating whether writing image metadata is supported."""
+
+ if self.local:
+ return self.local_method.can_write_metadata
+ else:
+ return False
+
+ def write_metadata(self, file, metadata):
+ """Write key-value metadata to the image file.
+
+ Only available locally. Currently, expects the image to be a PNG file.
+ """
+ if self.local:
+ self.local_method.write_metadata(file, metadata)
+ else:
+ # FIXME: Should probably issue a warning?
+ pass
diff --git a/lib/beets/util/bluelet.py b/lib/beets/util/bluelet.py
index a40f3b2f..db34486b 100644
--- a/lib/beets/util/bluelet.py
+++ b/lib/beets/util/bluelet.py
@@ -6,23 +6,24 @@ asyncore.
Bluelet: easy concurrency without all the messy parallelism.
"""
-import socket
-import select
-import sys
-import types
-import errno
-import traceback
-import time
import collections
-
+import errno
+import select
+import socket
+import sys
+import time
+import traceback
+import types
# Basic events used for thread scheduling.
+
class Event:
"""Just a base class identifying Bluelet events. An event is an
object yielded from a Bluelet thread coroutine to suspend operation
and communicate with the scheduler.
"""
+
pass
@@ -31,6 +32,7 @@ class WaitableEvent(Event):
waited for using a select() call. That is, it's an event with an
associated file descriptor.
"""
+
def waitables(self):
"""Return "waitable" objects to pass to select(). Should return
three iterables for input readiness, output readiness, and
@@ -48,18 +50,21 @@ class WaitableEvent(Event):
class ValueEvent(Event):
"""An event that does nothing but return a fixed value."""
+
def __init__(self, value):
self.value = value
class ExceptionEvent(Event):
"""Raise an exception at the yield point. Used internally."""
+
def __init__(self, exc_info):
self.exc_info = exc_info
class SpawnEvent(Event):
"""Add a new coroutine thread to the scheduler."""
+
def __init__(self, coro):
self.spawned = coro
@@ -68,12 +73,14 @@ class JoinEvent(Event):
"""Suspend the thread until the specified child thread has
completed.
"""
+
def __init__(self, child):
self.child = child
class KillEvent(Event):
"""Unschedule a child thread."""
+
def __init__(self, child):
self.child = child
@@ -83,6 +90,7 @@ class DelegationEvent(Event):
once the child thread finished, return control to the parent
thread.
"""
+
def __init__(self, coro):
self.spawned = coro
@@ -91,13 +99,14 @@ class ReturnEvent(Event):
"""Return a value the current thread's delegator at the point of
delegation. Ends the current (delegate) thread.
"""
+
def __init__(self, value):
self.value = value
class SleepEvent(WaitableEvent):
- """Suspend the thread for a given duration.
- """
+ """Suspend the thread for a given duration."""
+
def __init__(self, duration):
self.wakeup_time = time.time() + duration
@@ -107,6 +116,7 @@ class SleepEvent(WaitableEvent):
class ReadEvent(WaitableEvent):
"""Reads from a file-like object."""
+
def __init__(self, fd, bufsize):
self.fd = fd
self.bufsize = bufsize
@@ -120,6 +130,7 @@ class ReadEvent(WaitableEvent):
class WriteEvent(WaitableEvent):
"""Writes to a file-like object."""
+
def __init__(self, fd, data):
self.fd = fd
self.data = data
@@ -133,6 +144,7 @@ class WriteEvent(WaitableEvent):
# Core logic for executing and scheduling threads.
+
def _event_select(events):
"""Perform a select() over all the Events provided, returning the
ones ready to be fired. Only WaitableEvents (including SleepEvents)
@@ -154,11 +166,11 @@ def _event_select(events):
wlist += w
xlist += x
for waitable in r:
- waitable_to_event[('r', waitable)] = event
+ waitable_to_event[("r", waitable)] = event
for waitable in w:
- waitable_to_event[('w', waitable)] = event
+ waitable_to_event[("w", waitable)] = event
for waitable in x:
- waitable_to_event[('x', waitable)] = event
+ waitable_to_event[("x", waitable)] = event
# If we have a any sleeping threads, determine how long to sleep.
if earliest_wakeup:
@@ -177,11 +189,11 @@ def _event_select(events):
# Gather ready events corresponding to the ready waitables.
ready_events = set()
for ready in rready:
- ready_events.add(waitable_to_event[('r', ready)])
+ ready_events.add(waitable_to_event[("r", ready)])
for ready in wready:
- ready_events.add(waitable_to_event[('w', ready)])
+ ready_events.add(waitable_to_event[("w", ready)])
for ready in xready:
- ready_events.add(waitable_to_event[('x', ready)])
+ ready_events.add(waitable_to_event[("x", ready)])
# Gather any finished sleeps.
for event in events:
@@ -207,6 +219,7 @@ class Delegated(Event):
"""Placeholder indicating that a thread has delegated execution to a
different thread.
"""
+
def __init__(self, child):
self.child = child
@@ -277,8 +290,7 @@ def run(root_coro):
threads[coro] = next_event
def kill_thread(coro):
- """Unschedule this thread and its (recursive) delegates.
- """
+ """Unschedule this thread and its (recursive) delegates."""
# Collect all coroutines in the delegation stack.
coros = [coro]
while isinstance(threads[coro], Delegated):
@@ -338,12 +350,16 @@ def run(root_coro):
try:
value = event.fire()
except OSError as exc:
- if isinstance(exc.args, tuple) and \
- exc.args[0] == errno.EPIPE:
+ if (
+ isinstance(exc.args, tuple)
+ and exc.args[0] == errno.EPIPE
+ ):
# Broken pipe. Remote host disconnected.
pass
- elif isinstance(exc.args, tuple) and \
- exc.args[0] == errno.ECONNRESET:
+ elif (
+ isinstance(exc.args, tuple)
+ and exc.args[0] == errno.ECONNRESET
+ ):
# Connection was reset by peer.
pass
else:
@@ -382,16 +398,16 @@ def run(root_coro):
# Sockets and their associated events.
+
class SocketClosedError(Exception):
pass
class Listener:
- """A socket wrapper object for listening sockets.
- """
+ """A socket wrapper object for listening sockets."""
+
def __init__(self, host, port):
- """Create a listening socket on the given hostname and port.
- """
+ """Create a listening socket on the given hostname and port."""
self._closed = False
self.host = host
self.port = port
@@ -410,19 +426,18 @@ class Listener:
return AcceptEvent(self)
def close(self):
- """Immediately close the listening socket. (Not an event.)
- """
+ """Immediately close the listening socket. (Not an event.)"""
self._closed = True
self.sock.close()
class Connection:
- """A socket wrapper object for connected sockets.
- """
+ """A socket wrapper object for connected sockets."""
+
def __init__(self, sock, addr):
self.sock = sock
self.addr = addr
- self._buf = b''
+ self._buf = b""
self._closed = False
def close(self):
@@ -473,7 +488,7 @@ class Connection:
self._buf += data
else:
line = self._buf
- self._buf = b''
+ self._buf = b""
yield ReturnEvent(line)
break
@@ -482,6 +497,7 @@ class AcceptEvent(WaitableEvent):
"""An event for Listener objects (listening sockets) that suspends
execution until the socket gets a connection.
"""
+
def __init__(self, listener):
self.listener = listener
@@ -497,6 +513,7 @@ class ReceiveEvent(WaitableEvent):
"""An event for Connection objects (connected sockets) for
asynchronously reading data.
"""
+
def __init__(self, conn, bufsize):
self.conn = conn
self.bufsize = bufsize
@@ -512,6 +529,7 @@ class SendEvent(WaitableEvent):
"""An event for Connection objects (connected sockets) for
asynchronously writing data.
"""
+
def __init__(self, conn, data, sendall=False):
self.conn = conn
self.data = data
@@ -530,9 +548,9 @@ class SendEvent(WaitableEvent):
# Public interface for threads; each returns an event object that
# can immediately be "yield"ed.
+
def null():
- """Event: yield to the scheduler without doing anything special.
- """
+ """Event: yield to the scheduler without doing anything special."""
return ValueEvent(None)
@@ -541,7 +559,7 @@ def spawn(coro):
and child coroutines run concurrently.
"""
if not isinstance(coro, types.GeneratorType):
- raise ValueError('%s is not a coroutine' % coro)
+ raise ValueError("%s is not a coroutine" % coro)
return SpawnEvent(coro)
@@ -551,7 +569,7 @@ def call(coro):
returns a value using end(), then this event returns that value.
"""
if not isinstance(coro, types.GeneratorType):
- raise ValueError('%s is not a coroutine' % coro)
+ raise ValueError("%s is not a coroutine" % coro)
return DelegationEvent(coro)
@@ -573,7 +591,8 @@ def read(fd, bufsize=None):
if not data:
break
buf.append(data)
- yield ReturnEvent(''.join(buf))
+ yield ReturnEvent("".join(buf))
+
return DelegationEvent(reader())
else:
@@ -595,8 +614,7 @@ def connect(host, port):
def sleep(duration):
- """Event: suspend the thread for ``duration`` seconds.
- """
+ """Event: suspend the thread for ``duration`` seconds."""
return SleepEvent(duration)
@@ -608,19 +626,20 @@ def join(coro):
def kill(coro):
- """Halt the execution of a different `spawn`ed thread.
- """
+ """Halt the execution of a different `spawn`ed thread."""
return KillEvent(coro)
# Convenience function for running socket servers.
+
def server(host, port, func):
"""A coroutine that runs a network server. Host and port specify the
listening address. func should be a coroutine that takes a single
parameter, a Connection object. The coroutine is invoked for every
incoming connection on the listening socket.
"""
+
def handler(conn):
try:
yield func(conn)
diff --git a/lib/beets/util/enumeration.py b/lib/beets/util/enumeration.py
index e49f6fdd..33a6be58 100644
--- a/lib/beets/util/enumeration.py
+++ b/lib/beets/util/enumeration.py
@@ -20,6 +20,7 @@ class OrderedEnum(Enum):
"""
An Enum subclass that allows comparison of members.
"""
+
def __ge__(self, other):
if self.__class__ is other.__class__:
return self.value >= other.value
diff --git a/lib/beets/util/functemplate.py b/lib/beets/util/functemplate.py
index 289a436d..7d7e8f01 100644
--- a/lib/beets/util/functemplate.py
+++ b/lib/beets/util/functemplate.py
@@ -27,22 +27,21 @@ engine like Jinja2 or Mustache.
"""
-import re
import ast
import dis
-import types
-import sys
import functools
+import re
+import types
-SYMBOL_DELIM = '$'
-FUNC_DELIM = '%'
-GROUP_OPEN = '{'
-GROUP_CLOSE = '}'
-ARG_SEP = ','
-ESCAPE_CHAR = '$'
+SYMBOL_DELIM = "$"
+FUNC_DELIM = "%"
+GROUP_OPEN = "{"
+GROUP_CLOSE = "}"
+ARG_SEP = ","
+ESCAPE_CHAR = "$"
-VARIABLE_PREFIX = '__var_'
-FUNCTION_PREFIX = '__func_'
+VARIABLE_PREFIX = "__var_"
+FUNCTION_PREFIX = "__func_"
class Environment:
@@ -57,10 +56,6 @@ class Environment:
# Code generation helpers.
-def ex_lvalue(name):
- """A variable load expression."""
- return ast.Name(name, ast.Store())
-
def ex_rvalue(name):
"""A variable store expression."""
@@ -74,15 +69,6 @@ def ex_literal(val):
return ast.Constant(val)
-def ex_varassign(name, expr):
- """Assign an expression into a single variable. The expression may
- either be an `ast.expr` object or a value to be used as a literal.
- """
- if not isinstance(expr, ast.expr):
- expr = ex_literal(expr)
- return ast.Assign([ex_lvalue(name)], expr)
-
-
def ex_call(func, args):
"""A function-call expression with only positional parameters. The
function may be an expression or the name of a function. Each
@@ -99,19 +85,18 @@ def ex_call(func, args):
return ast.Call(func, args, [])
-def compile_func(arg_names, statements, name='_the_func', debug=False):
+def compile_func(arg_names, statements, name="_the_func", debug=False):
"""Compile a list of statements as the body of a function and return
the resulting Python function. If `debug`, then print out the
bytecode of the compiled function.
"""
args_fields = {
- 'args': [ast.arg(arg=n, annotation=None) for n in arg_names],
- 'kwonlyargs': [],
- 'kw_defaults': [],
- 'defaults': [ex_literal(None) for _ in arg_names],
+ "args": [ast.arg(arg=n, annotation=None) for n in arg_names],
+ "kwonlyargs": [],
+ "kw_defaults": [],
+ "defaults": [ex_literal(None) for _ in arg_names],
}
- if 'posonlyargs' in ast.arguments._fields: # Added in Python 3.8.
- args_fields['posonlyargs'] = []
+ args_fields["posonlyargs"] = []
args = ast.arguments(**args_fields)
func_def = ast.FunctionDef(
@@ -123,14 +108,11 @@ def compile_func(arg_names, statements, name='_the_func', debug=False):
# The ast.Module signature changed in 3.8 to accept a list of types to
# ignore.
- if sys.version_info >= (3, 8):
- mod = ast.Module([func_def], [])
- else:
- mod = ast.Module([func_def])
+ mod = ast.Module([func_def], [])
ast.fix_missing_locations(mod)
- prog = compile(mod, '', 'exec')
+ prog = compile(mod, "", "exec")
# Debug: show bytecode.
if debug:
@@ -146,6 +128,7 @@ def compile_func(arg_names, statements, name='_the_func', debug=False):
# AST nodes for the template language.
+
class Symbol:
"""A variable-substitution symbol in a template."""
@@ -154,7 +137,7 @@ class Symbol:
self.original = original
def __repr__(self):
- return 'Symbol(%s)' % repr(self.ident)
+ return "Symbol(%s)" % repr(self.ident)
def evaluate(self, env):
"""Evaluate the symbol in the environment, returning a Unicode
@@ -183,8 +166,9 @@ class Call:
self.original = original
def __repr__(self):
- return 'Call({}, {}, {})'.format(repr(self.ident), repr(self.args),
- repr(self.original))
+ return "Call({}, {}, {})".format(
+ repr(self.ident), repr(self.args), repr(self.original)
+ )
def evaluate(self, env):
"""Evaluate the function call in the environment, returning a
@@ -197,7 +181,7 @@ class Call:
except Exception as exc:
# Function raised exception! Maybe inlining the name of
# the exception will help debug.
- return '<%s>' % str(exc)
+ return "<%s>" % str(exc)
return str(out)
else:
return self.original
@@ -215,21 +199,22 @@ class Call:
# Create a subexpression that joins the result components of
# the arguments.
- arg_exprs.append(ex_call(
- ast.Attribute(ex_literal(''), 'join', ast.Load()),
- [ex_call(
- 'map',
+ arg_exprs.append(
+ ex_call(
+ ast.Attribute(ex_literal(""), "join", ast.Load()),
[
- ex_rvalue(str.__name__),
- ast.List(subexprs, ast.Load()),
- ]
- )],
- ))
+ ex_call(
+ "map",
+ [
+ ex_rvalue(str.__name__),
+ ast.List(subexprs, ast.Load()),
+ ],
+ )
+ ],
+ )
+ )
- subexpr_call = ex_call(
- FUNCTION_PREFIX + self.ident,
- arg_exprs
- )
+ subexpr_call = ex_call(FUNCTION_PREFIX + self.ident, arg_exprs)
return [subexpr_call], varnames, funcnames
@@ -242,7 +227,7 @@ class Expression:
self.parts = parts
def __repr__(self):
- return 'Expression(%s)' % (repr(self.parts))
+ return "Expression(%s)" % (repr(self.parts))
def evaluate(self, env):
"""Evaluate the entire expression in the environment, returning
@@ -254,7 +239,7 @@ class Expression:
out.append(part)
else:
out.append(part.evaluate(env))
- return ''.join(map(str, out))
+ return "".join(map(str, out))
def translate(self):
"""Compile the expression to a list of Python AST expressions, a
@@ -276,6 +261,7 @@ class Expression:
# Parser.
+
class ParseError(Exception):
pass
@@ -295,7 +281,7 @@ class Parser:
"""
def __init__(self, string, in_argument=False):
- """ Create a new parser.
+ """Create a new parser.
:param in_arguments: boolean that indicates the parser is to be
used for parsing function arguments, ie. considering commas
(`ARG_SEP`) a special character
@@ -306,10 +292,16 @@ class Parser:
self.parts = []
# Common parsing resources.
- special_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_OPEN, GROUP_CLOSE,
- ESCAPE_CHAR)
- special_char_re = re.compile(r'[%s]|\Z' %
- ''.join(re.escape(c) for c in special_chars))
+ special_chars = (
+ SYMBOL_DELIM,
+ FUNC_DELIM,
+ GROUP_OPEN,
+ GROUP_CLOSE,
+ ESCAPE_CHAR,
+ )
+ special_char_re = re.compile(
+ r"[%s]|\Z" % "".join(re.escape(c) for c in special_chars)
+ )
escapable_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_CLOSE, ARG_SEP)
terminator_chars = (GROUP_CLOSE,)
@@ -326,9 +318,10 @@ class Parser:
if self.in_argument:
extra_special_chars = (ARG_SEP,)
special_char_re = re.compile(
- r'[%s]|\Z' % ''.join(
- re.escape(c) for c in
- self.special_chars + extra_special_chars
+ r"[%s]|\Z"
+ % "".join(
+ re.escape(c)
+ for c in self.special_chars + extra_special_chars
)
)
@@ -341,10 +334,10 @@ class Parser:
# A non-special character. Skip to the next special
# character, treating the interstice as literal text.
next_pos = (
- special_char_re.search(
- self.string[self.pos:]).start() + self.pos
+ special_char_re.search(self.string[self.pos :]).start()
+ + self.pos
)
- text_parts.append(self.string[self.pos:next_pos])
+ text_parts.append(self.string[self.pos : next_pos])
self.pos = next_pos
continue
@@ -358,8 +351,9 @@ class Parser:
break
next_char = self.string[self.pos + 1]
- if char == ESCAPE_CHAR and next_char in (self.escapable_chars +
- extra_special_chars):
+ if char == ESCAPE_CHAR and next_char in (
+ self.escapable_chars + extra_special_chars
+ ):
# An escaped special character ($$, $}, etc.). Note that
# ${ is not an escape sequence: this is ambiguous with
# the start of a symbol and it's not necessary (just
@@ -370,7 +364,7 @@ class Parser:
# Shift all characters collected so far into a single string.
if text_parts:
- self.parts.append(''.join(text_parts))
+ self.parts.append("".join(text_parts))
text_parts = []
if char == SYMBOL_DELIM:
@@ -392,7 +386,7 @@ class Parser:
# If any parsed characters remain, shift them into a string.
if text_parts:
- self.parts.append(''.join(text_parts))
+ self.parts.append("".join(text_parts))
def parse_symbol(self):
"""Parse a variable reference (like ``$foo`` or ``${foo}``)
@@ -419,21 +413,23 @@ class Parser:
closer = self.string.find(GROUP_CLOSE, self.pos)
if closer == -1 or closer == self.pos:
# No closing brace found or identifier is empty.
- self.parts.append(self.string[start_pos:self.pos])
+ self.parts.append(self.string[start_pos : self.pos])
else:
# Closer found.
- ident = self.string[self.pos:closer]
+ ident = self.string[self.pos : closer]
self.pos = closer + 1
- self.parts.append(Symbol(ident,
- self.string[start_pos:self.pos]))
+ self.parts.append(
+ Symbol(ident, self.string[start_pos : self.pos])
+ )
else:
# A bare-word symbol.
ident = self._parse_ident()
if ident:
# Found a real symbol.
- self.parts.append(Symbol(ident,
- self.string[start_pos:self.pos]))
+ self.parts.append(
+ Symbol(ident, self.string[start_pos : self.pos])
+ )
else:
# A standalone $.
self.parts.append(SYMBOL_DELIM)
@@ -457,25 +453,24 @@ class Parser:
if self.pos >= len(self.string):
# Identifier terminates string.
- self.parts.append(self.string[start_pos:self.pos])
+ self.parts.append(self.string[start_pos : self.pos])
return
if self.string[self.pos] != GROUP_OPEN:
# Argument list not opened.
- self.parts.append(self.string[start_pos:self.pos])
+ self.parts.append(self.string[start_pos : self.pos])
return
# Skip past opening brace and try to parse an argument list.
self.pos += 1
args = self.parse_argument_list()
- if self.pos >= len(self.string) or \
- self.string[self.pos] != GROUP_CLOSE:
+ if self.pos >= len(self.string) or self.string[self.pos] != GROUP_CLOSE:
# Arguments unclosed.
- self.parts.append(self.string[start_pos:self.pos])
+ self.parts.append(self.string[start_pos : self.pos])
return
self.pos += 1 # Move past closing brace.
- self.parts.append(Call(ident, args, self.string[start_pos:self.pos]))
+ self.parts.append(Call(ident, args, self.string[start_pos : self.pos]))
def parse_argument_list(self):
"""Parse a list of arguments starting at ``pos``, returning a
@@ -487,15 +482,17 @@ class Parser:
expressions = []
while self.pos < len(self.string):
- subparser = Parser(self.string[self.pos:], in_argument=True)
+ subparser = Parser(self.string[self.pos :], in_argument=True)
subparser.parse_expression()
# Extract and advance past the parsed expression.
expressions.append(Expression(subparser.parts))
self.pos += subparser.pos
- if self.pos >= len(self.string) or \
- self.string[self.pos] == GROUP_CLOSE:
+ if (
+ self.pos >= len(self.string)
+ or self.string[self.pos] == GROUP_CLOSE
+ ):
# Argument list terminated by EOF or closing brace.
break
@@ -510,8 +507,8 @@ class Parser:
"""Parse an identifier and return it (possibly an empty string).
Updates ``pos``.
"""
- remainder = self.string[self.pos:]
- ident = re.match(r'\w*', remainder).group(0)
+ remainder = self.string[self.pos :]
+ ident = re.match(r"\w*", remainder).group(0)
self.pos += len(ident)
return ident
@@ -524,32 +521,20 @@ def _parse(template):
parser.parse_expression()
parts = parser.parts
- remainder = parser.string[parser.pos:]
+ remainder = parser.string[parser.pos :]
if remainder:
parts.append(remainder)
return Expression(parts)
-def cached(func):
- """Like the `functools.lru_cache` decorator, but works (as a no-op)
- on Python < 3.2.
- """
- if hasattr(functools, 'lru_cache'):
- return functools.lru_cache(maxsize=128)(func)
- else:
- # Do nothing when lru_cache is not available.
- return func
-
-
-@cached
+@functools.lru_cache(maxsize=128)
def template(fmt):
return Template(fmt)
# External interface.
class Template:
- """A string template, including text, Symbols, and Calls.
- """
+ """A string template, including text, Symbols, and Calls."""
def __init__(self, template):
self.expr = _parse(template)
@@ -568,8 +553,7 @@ class Template:
return self.expr.evaluate(Environment(values, functions))
def substitute(self, values={}, functions={}):
- """Evaluate the template given the values and functions.
- """
+ """Evaluate the template given the values and functions."""
try:
res = self.compiled(values, functions)
except Exception: # Handle any exceptions thrown by compiled version.
@@ -599,24 +583,29 @@ class Template:
for funcname in funcnames:
args[FUNCTION_PREFIX + funcname] = functions[funcname]
parts = func(**args)
- return ''.join(parts)
+ return "".join(parts)
return wrapper_func
# Performance tests.
-if __name__ == '__main__':
+if __name__ == "__main__":
import timeit
- _tmpl = Template('foo $bar %baz{foozle $bar barzle} $bar')
- _vars = {'bar': 'qux'}
- _funcs = {'baz': str.upper}
- interp_time = timeit.timeit('_tmpl.interpret(_vars, _funcs)',
- 'from __main__ import _tmpl, _vars, _funcs',
- number=10000)
+
+ _tmpl = Template("foo $bar %baz{foozle $bar barzle} $bar")
+ _vars = {"bar": "qux"}
+ _funcs = {"baz": str.upper}
+ interp_time = timeit.timeit(
+ "_tmpl.interpret(_vars, _funcs)",
+ "from __main__ import _tmpl, _vars, _funcs",
+ number=10000,
+ )
print(interp_time)
- comp_time = timeit.timeit('_tmpl.substitute(_vars, _funcs)',
- 'from __main__ import _tmpl, _vars, _funcs',
- number=10000)
+ comp_time = timeit.timeit(
+ "_tmpl.substitute(_vars, _funcs)",
+ "from __main__ import _tmpl, _vars, _funcs",
+ number=10000,
+ )
print(comp_time)
- print('Speedup:', interp_time / comp_time)
+ print("Speedup:", interp_time / comp_time)
diff --git a/lib/beets/util/hidden.py b/lib/beets/util/hidden.py
index 881de1ac..d2c66fac 100644
--- a/lib/beets/util/hidden.py
+++ b/lib/beets/util/hidden.py
@@ -1,5 +1,6 @@
# This file is part of beets.
# Copyright 2016, Adrian Sampson.
+# Copyright 2024, Arav K.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
@@ -14,71 +15,49 @@
"""Simple library to work out if a file is hidden on different platforms."""
+import ctypes
import os
import stat
-import ctypes
import sys
-import beets.util
+from pathlib import Path
+from typing import Union
-def _is_hidden_osx(path):
- """Return whether or not a file is hidden on OS X.
-
- This uses os.lstat to work out if a file has the "hidden" flag.
+def is_hidden(path: Union[bytes, Path]) -> bool:
"""
- file_stat = os.lstat(beets.util.syspath(path))
-
- if hasattr(file_stat, 'st_flags') and hasattr(stat, 'UF_HIDDEN'):
- return bool(file_stat.st_flags & stat.UF_HIDDEN)
- else:
- return False
-
-
-def _is_hidden_win(path):
- """Return whether or not a file is hidden on Windows.
-
- This uses GetFileAttributes to work out if a file has the "hidden" flag
- (FILE_ATTRIBUTE_HIDDEN).
+ Determine whether the given path is treated as a 'hidden file' by the OS.
"""
- # FILE_ATTRIBUTE_HIDDEN = 2 (0x2) from GetFileAttributes documentation.
- hidden_mask = 2
- # Retrieve the attributes for the file.
- attrs = ctypes.windll.kernel32.GetFileAttributesW(beets.util.syspath(path))
+ if isinstance(path, bytes):
+ path = Path(os.fsdecode(path))
- # Ensure we have valid attribues and compare them against the mask.
- return attrs >= 0 and attrs & hidden_mask
+ # TODO: Avoid doing a platform check on every invocation of the function.
+ # TODO: Stop supporting 'bytes' inputs once 'pathlib' is fully integrated.
+ if sys.platform == "win32":
+ # On Windows, we check for an FS-provided attribute.
-def _is_hidden_dot(path):
- """Return whether or not a file starts with a dot.
+ # FILE_ATTRIBUTE_HIDDEN = 2 (0x2) from GetFileAttributes documentation.
+ hidden_mask = 2
- Files starting with a dot are seen as "hidden" files on Unix-based OSes.
- """
- return os.path.basename(path).startswith(b'.')
+ # Retrieve the attributes for the file.
+ attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path))
+ # Ensure the attribute mask is valid.
+ if attrs < 0:
+ return False
-def is_hidden(path):
- """Return whether or not a file is hidden. `path` should be a
- bytestring filename.
+ # Check for the hidden attribute.
+ return attrs & hidden_mask
- This method works differently depending on the platform it is called on.
+ # On OS X, we check for an FS-provided attribute.
+ if sys.platform == "darwin":
+ if hasattr(os.stat_result, "st_flags") and hasattr(stat, "UF_HIDDEN"):
+ if path.lstat().st_flags & stat.UF_HIDDEN:
+ return True
- On OS X, it uses both the result of `is_hidden_osx` and `is_hidden_dot` to
- work out if a file is hidden.
+ # On all non-Windows platforms, we check for a '.'-prefixed file name.
+ if path.name.startswith("."):
+ return True
- On Windows, it uses the result of `is_hidden_win` to work out if a file is
- hidden.
-
- On any other operating systems (i.e. Linux), it uses `is_hidden_dot` to
- work out if a file is hidden.
- """
- # Run platform specific functions depending on the platform
- if sys.platform == 'darwin':
- return _is_hidden_osx(path) or _is_hidden_dot(path)
- elif sys.platform == 'win32':
- return _is_hidden_win(path)
- else:
- return _is_hidden_dot(path)
-
-__all__ = ['is_hidden']
+ return False
diff --git a/lib/beets/util/id_extractors.py b/lib/beets/util/id_extractors.py
new file mode 100644
index 00000000..04e9e94a
--- /dev/null
+++ b/lib/beets/util/id_extractors.py
@@ -0,0 +1,65 @@
+# This file is part of beets.
+# Copyright 2016, Adrian Sampson.
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the
+# "Software"), to deal in the Software without restriction, including
+# without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to
+# permit persons to whom the Software is furnished to do so, subject to
+# the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+
+"""Helpers around the extraction of album/track ID's from metadata sources."""
+
+import re
+
+# Spotify IDs consist of 22 alphanumeric characters
+# (zero-left-padded base62 representation of randomly generated UUID4)
+spotify_id_regex = {
+ "pattern": r"(^|open\.spotify\.com/{}/)([0-9A-Za-z]{{22}})",
+ "match_group": 2,
+}
+
+deezer_id_regex = {
+ "pattern": r"(^|deezer\.com/)([a-z]*/)?({}/)?(\d+)",
+ "match_group": 4,
+}
+
+beatport_id_regex = {
+ "pattern": r"(^|beatport\.com/release/.+/)(\d+)$",
+ "match_group": 2,
+}
+
+# A note on Bandcamp: There is no such thing as a Bandcamp album or artist ID,
+# the URL can be used as the identifier. The Bandcamp metadata source plugin
+# works that way - https://github.com/snejus/beetcamp. Bandcamp album
+# URLs usually look like: https://nameofartist.bandcamp.com/album/nameofalbum
+
+
+def extract_discogs_id_regex(album_id):
+ """Returns the Discogs_id or None."""
+ # Discogs-IDs are simple integers. In order to avoid confusion with
+ # other metadata plugins, we only look for very specific formats of the
+ # input string:
+ # - plain integer, optionally wrapped in brackets and prefixed by an
+ # 'r', as this is how discogs displays the release ID on its webpage.
+ # - legacy url format: discogs.com//release/
+ # - legacy url short format: discogs.com/release/
+ # - current url format: discogs.com/release/-
+ # See #291, #4080 and #4085 for the discussions leading up to these
+ # patterns.
+ # Regex has been tested here https://regex101.com/r/TOu7kw/1
+
+ for pattern in [
+ r"^\[?r?(?P\d+)\]?$",
+ r"discogs\.com/release/(?P\d+)-?",
+ r"discogs\.com/[^/]+/release/(?P\d+)",
+ ]:
+ match = re.search(pattern, album_id)
+ if match:
+ return int(match.group("id"))
+
+ return None
diff --git a/lib/beets/util/m3u.py b/lib/beets/util/m3u.py
new file mode 100644
index 00000000..b6e355e0
--- /dev/null
+++ b/lib/beets/util/m3u.py
@@ -0,0 +1,97 @@
+# This file is part of beets.
+# Copyright 2022, J0J0 Todos.
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the
+# "Software"), to deal in the Software without restriction, including
+# without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to
+# permit persons to whom the Software is furnished to do so, subject to
+# the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+
+"""Provides utilities to read, write and manipulate m3u playlist files."""
+
+import traceback
+
+from beets.util import FilesystemError, mkdirall, normpath, syspath
+
+
+class EmptyPlaylistError(Exception):
+ """Raised when a playlist file without media files is saved or loaded."""
+
+ pass
+
+
+class M3UFile:
+ """Reads and writes m3u or m3u8 playlist files."""
+
+ def __init__(self, path):
+ """``path`` is the absolute path to the playlist file.
+
+ The playlist file type, m3u or m3u8 is determined by 1) the ending
+ being m3u8 and 2) the file paths contained in the list being utf-8
+ encoded. Since the list is passed from the outside, this is currently
+ out of control of this class.
+ """
+ self.path = path
+ self.extm3u = False
+ self.media_list = []
+
+ def load(self):
+ """Reads the m3u file from disk and sets the object's attributes."""
+ pl_normpath = normpath(self.path)
+ try:
+ with open(syspath(pl_normpath), "rb") as pl_file:
+ raw_contents = pl_file.readlines()
+ except OSError as exc:
+ raise FilesystemError(
+ exc, "read", (pl_normpath,), traceback.format_exc()
+ )
+
+ self.extm3u = True if raw_contents[0].rstrip() == b"#EXTM3U" else False
+ for line in raw_contents[1:]:
+ if line.startswith(b"#"):
+ # Support for specific EXTM3U comments could be added here.
+ continue
+ self.media_list.append(normpath(line.rstrip()))
+ if not self.media_list:
+ raise EmptyPlaylistError
+
+ def set_contents(self, media_list, extm3u=True):
+ """Sets self.media_list to a list of media file paths.
+
+ Also sets additional flags, changing the final m3u-file's format.
+
+ ``media_list`` is a list of paths to media files that should be added
+ to the playlist (relative or absolute paths, that's the responsibility
+ of the caller). By default the ``extm3u`` flag is set, to ensure a
+ save-operation writes an m3u-extended playlist (comment "#EXTM3U" at
+ the top of the file).
+ """
+ self.media_list = media_list
+ self.extm3u = extm3u
+
+ def write(self):
+ """Writes the m3u file to disk.
+
+ Handles the creation of potential parent directories.
+ """
+ header = [b"#EXTM3U"] if self.extm3u else []
+ if not self.media_list:
+ raise EmptyPlaylistError
+ contents = header + self.media_list
+ pl_normpath = normpath(self.path)
+ mkdirall(pl_normpath)
+
+ try:
+ with open(syspath(pl_normpath), "wb") as pl_file:
+ for line in contents:
+ pl_file.write(line + b"\n")
+ pl_file.write(b"\n") # Final linefeed to prevent noeol file.
+ except OSError as exc:
+ raise FilesystemError(
+ exc, "create", (pl_normpath,), traceback.format_exc()
+ )
diff --git a/lib/beets/util/pipeline.py b/lib/beets/util/pipeline.py
index d338cb51..c4933ff0 100644
--- a/lib/beets/util/pipeline.py
+++ b/lib/beets/util/pipeline.py
@@ -33,11 +33,11 @@ in place of any single coroutine.
import queue
-from threading import Thread, Lock
import sys
+from threading import Lock, Thread
-BUBBLE = '__PIPELINE_BUBBLE__'
-POISON = '__PIPELINE_POISON__'
+BUBBLE = "__PIPELINE_BUBBLE__"
+POISON = "__PIPELINE_POISON__"
DEFAULT_QUEUE_SIZE = 16
@@ -48,6 +48,7 @@ def _invalidate_queue(q, val=None, sync=True):
which defaults to None. `sync` controls whether a lock is
required (because it's not reentrant!).
"""
+
def _qsize(len=len):
return 1
@@ -75,8 +76,8 @@ def _invalidate_queue(q, val=None, sync=True):
q._qsize = _qsize
q._put = _put
q._get = _get
- q.not_empty.notifyAll()
- q.not_full.notifyAll()
+ q.not_empty.notify_all()
+ q.not_full.notify_all()
finally:
if sync:
@@ -168,6 +169,7 @@ def stage(func):
while True:
task = yield task
task = func(*(args + (task,)))
+
return coro
@@ -191,6 +193,7 @@ def mutator_stage(func):
while True:
task = yield task
func(*(args + (task,)))
+
return coro
@@ -218,20 +221,18 @@ class PipelineThread(Thread):
self.exc_info = None
def abort(self):
- """Shut down the thread at the next chance possible.
- """
+ """Shut down the thread at the next chance possible."""
with self.abort_lock:
self.abort_flag = True
# Ensure that we are not blocking on a queue read or write.
- if hasattr(self, 'in_queue'):
+ if hasattr(self, "in_queue"):
_invalidate_queue(self.in_queue, POISON)
- if hasattr(self, 'out_queue'):
+ if hasattr(self, "out_queue"):
_invalidate_queue(self.out_queue, POISON)
def abort_all(self, exc_info):
- """Abort all other threads in the system for an exception.
- """
+ """Abort all other threads in the system for an exception."""
self.exc_info = exc_info
for thread in self.all_threads:
thread.abort()
@@ -373,7 +374,7 @@ class Pipeline:
be at least two stages.
"""
if len(stages) < 2:
- raise ValueError('pipeline must have at least two stages')
+ raise ValueError("pipeline must have at least two stages")
self.stages = []
for stage in stages:
if isinstance(stage, (list, tuple)):
@@ -405,15 +406,15 @@ class Pipeline:
# Middle stages.
for i in range(1, queue_count):
for coro in self.stages[i]:
- threads.append(MiddlePipelineThread(
- coro, queues[i - 1], queues[i], threads
- ))
+ threads.append(
+ MiddlePipelineThread(
+ coro, queues[i - 1], queues[i], threads
+ )
+ )
# Last stage.
for coro in self.stages[-1]:
- threads.append(
- LastPipelineThread(coro, queues[-1], threads)
- )
+ threads.append(LastPipelineThread(coro, queues[-1], threads))
# Start threads.
for thread in threads:
@@ -472,21 +473,21 @@ class Pipeline:
# Smoke test.
-if __name__ == '__main__':
+if __name__ == "__main__":
import time
# Test a normally-terminating pipeline both in sequence and
# in parallel.
def produce():
for i in range(5):
- print('generating %i' % i)
+ print("generating %i" % i)
time.sleep(1)
yield i
def work():
num = yield
while True:
- print('processing %i' % num)
+ print("processing %i" % num)
time.sleep(2)
num = yield num * 2
@@ -494,7 +495,7 @@ if __name__ == '__main__':
while True:
num = yield
time.sleep(1)
- print('received %i' % num)
+ print("received %i" % num)
ts_start = time.time()
Pipeline([produce(), work(), consume()]).run_sequential()
@@ -503,22 +504,22 @@ if __name__ == '__main__':
ts_par = time.time()
Pipeline([produce(), (work(), work()), consume()]).run_parallel()
ts_end = time.time()
- print('Sequential time:', ts_seq - ts_start)
- print('Parallel time:', ts_par - ts_seq)
- print('Multiply-parallel time:', ts_end - ts_par)
+ print("Sequential time:", ts_seq - ts_start)
+ print("Parallel time:", ts_par - ts_seq)
+ print("Multiply-parallel time:", ts_end - ts_par)
print()
# Test a pipeline that raises an exception.
def exc_produce():
for i in range(10):
- print('generating %i' % i)
+ print("generating %i" % i)
time.sleep(1)
yield i
def exc_work():
num = yield
while True:
- print('processing %i' % num)
+ print("processing %i" % num)
time.sleep(3)
if num == 3:
raise Exception()
@@ -527,6 +528,6 @@ if __name__ == '__main__':
def exc_consume():
while True:
num = yield
- print('received %i' % num)
+ print("received %i" % num)
Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)
diff --git a/lib/beets/vfs.py b/lib/beets/vfs.py
index aef69650..4a9681a9 100644
--- a/lib/beets/vfs.py
+++ b/lib/beets/vfs.py
@@ -17,9 +17,10 @@ libraries.
"""
from collections import namedtuple
+
from beets import util
-Node = namedtuple('Node', ['files', 'dirs'])
+Node = namedtuple("Node", ["files", "dirs"])
def _insert(node, path, itemid):
diff --git a/lib/beetsplug/__init__.py b/lib/beetsplug/__init__.py
index da248491..763ff3a0 100644
--- a/lib/beetsplug/__init__.py
+++ b/lib/beetsplug/__init__.py
@@ -17,4 +17,5 @@
# Make this a namespace package.
from pkgutil import extend_path
+
__path__ = extend_path(__path__, __name__)
diff --git a/lib/beetsplug/absubmit.py b/lib/beetsplug/absubmit.py
index d1ea692f..fc40b85e 100644
--- a/lib/beetsplug/absubmit.py
+++ b/lib/beetsplug/absubmit.py
@@ -22,16 +22,14 @@ import json
import os
import subprocess
import tempfile
-
from distutils.spawn import find_executable
+
import requests
-from beets import plugins
-from beets import util
-from beets import ui
+from beets import plugins, ui, util
# We use this field to check whether AcousticBrainz info is present.
-PROBE_FIELD = 'mood_acoustic'
+PROBE_FIELD = "mood_acoustic"
class ABSubmitError(Exception):
@@ -47,39 +45,39 @@ def call(args):
return util.command_output(args).stdout
except subprocess.CalledProcessError as e:
raise ABSubmitError(
- '{} exited with status {}'.format(args[0], e.returncode)
+ "{} exited with status {}".format(args[0], e.returncode)
)
class AcousticBrainzSubmitPlugin(plugins.BeetsPlugin):
-
def __init__(self):
super().__init__()
- self.config.add({
- 'extractor': '',
- 'force': False,
- 'pretend': False
- })
+ self._log.warning("This plugin is deprecated.")
- self.extractor = self.config['extractor'].as_str()
+ self.config.add(
+ {"extractor": "", "force": False, "pretend": False, "base_url": ""}
+ )
+
+ self.extractor = self.config["extractor"].as_str()
if self.extractor:
self.extractor = util.normpath(self.extractor)
- # Expicit path to extractor
+ # Explicit path to extractor
if not os.path.isfile(self.extractor):
raise ui.UserError(
- 'Extractor command does not exist: {0}.'.
- format(self.extractor)
+ "Extractor command does not exist: {0}.".format(
+ self.extractor
+ )
)
else:
# Implicit path to extractor, search for it in path
- self.extractor = 'streaming_extractor_music'
+ self.extractor = "streaming_extractor_music"
try:
call([self.extractor])
except OSError:
raise ui.UserError(
- 'No extractor command found: please install the extractor'
- ' binary from https://acousticbrainz.org/download'
+ "No extractor command found: please install the extractor"
+ " binary from https://essentia.upf.edu/"
)
except ABSubmitError:
# Extractor found, will exit with an error if not called with
@@ -92,36 +90,58 @@ class AcousticBrainzSubmitPlugin(plugins.BeetsPlugin):
# Calculate extractor hash.
self.extractor_sha = hashlib.sha1()
- with open(self.extractor, 'rb') as extractor:
+ with open(self.extractor, "rb") as extractor:
self.extractor_sha.update(extractor.read())
self.extractor_sha = self.extractor_sha.hexdigest()
- base_url = 'https://acousticbrainz.org/api/v1/{mbid}/low-level'
+ self.url = ""
+ base_url = self.config["base_url"].as_str()
+ if base_url:
+ if not base_url.startswith("http"):
+ raise ui.UserError(
+ "AcousticBrainz server base URL must start "
+ "with an HTTP scheme"
+ )
+ elif base_url[-1] != "/":
+ base_url = base_url + "/"
+ self.url = base_url + "{mbid}/low-level"
def commands(self):
cmd = ui.Subcommand(
- 'absubmit',
- help='calculate and submit AcousticBrainz analysis'
+ "absubmit", help="calculate and submit AcousticBrainz analysis"
)
cmd.parser.add_option(
- '-f', '--force', dest='force_refetch',
- action='store_true', default=False,
- help='re-download data when already present'
+ "-f",
+ "--force",
+ dest="force_refetch",
+ action="store_true",
+ default=False,
+ help="re-download data when already present",
)
cmd.parser.add_option(
- '-p', '--pretend', dest='pretend_fetch',
- action='store_true', default=False,
- help='pretend to perform action, but show \
-only files which would be processed'
+ "-p",
+ "--pretend",
+ dest="pretend_fetch",
+ action="store_true",
+ default=False,
+ help="pretend to perform action, but show \
+only files which would be processed",
)
cmd.func = self.command
return [cmd]
def command(self, lib, opts, args):
- # Get items from arguments
- items = lib.items(ui.decargs(args))
- self.opts = opts
- util.par_map(self.analyze_submit, items)
+ if not self.url:
+ raise ui.UserError(
+ "This plugin is deprecated since AcousticBrainz no longer "
+ "accepts new submissions. See the base_url configuration "
+ "option."
+ )
+ else:
+ # Get items from arguments
+ items = lib.items(ui.decargs(args))
+ self.opts = opts
+ util.par_map(self.analyze_submit, items)
def analyze_submit(self, item):
analysis = self._get_analysis(item)
@@ -129,28 +149,29 @@ only files which would be processed'
self._submit_data(item, analysis)
def _get_analysis(self, item):
- mbid = item['mb_trackid']
+ mbid = item["mb_trackid"]
# Avoid re-analyzing files that already have AB data.
- if not self.opts.force_refetch and not self.config['force']:
+ if not self.opts.force_refetch and not self.config["force"]:
if item.get(PROBE_FIELD):
return None
# If file has no MBID, skip it.
if not mbid:
- self._log.info('Not analysing {}, missing '
- 'musicbrainz track id.', item)
+ self._log.info(
+ "Not analysing {}, missing " "musicbrainz track id.", item
+ )
return None
- if self.opts.pretend_fetch or self.config['pretend']:
- self._log.info('pretend action - extract item: {}', item)
+ if self.opts.pretend_fetch or self.config["pretend"]:
+ self._log.info("pretend action - extract item: {}", item)
return None
# Temporary file to save extractor output to, extractor only works
# if an output file is given. Here we use a temporary file to copy
# the data into a python object and then remove the file from the
# system.
- tmp_file, filename = tempfile.mkstemp(suffix='.json')
+ tmp_file, filename = tempfile.mkstemp(suffix=".json")
try:
# Close the file, so the extractor can overwrite it.
os.close(tmp_file)
@@ -158,15 +179,17 @@ only files which would be processed'
call([self.extractor, util.syspath(item.path), filename])
except ABSubmitError as e:
self._log.warning(
- 'Failed to analyse {item} for AcousticBrainz: {error}',
- item=item, error=e
+ "Failed to analyse {item} for AcousticBrainz: {error}",
+ item=item,
+ error=e,
)
return None
with open(filename) as tmp_file:
analysis = json.load(tmp_file)
# Add the hash to the output.
- analysis['metadata']['version']['essentia_build_sha'] = \
- self.extractor_sha
+ analysis["metadata"]["version"][
+ "essentia_build_sha"
+ ] = self.extractor_sha
return analysis
finally:
try:
@@ -177,20 +200,28 @@ only files which would be processed'
raise
def _submit_data(self, item, data):
- mbid = item['mb_trackid']
- headers = {'Content-Type': 'application/json'}
- response = requests.post(self.base_url.format(mbid=mbid),
- json=data, headers=headers)
+ mbid = item["mb_trackid"]
+ headers = {"Content-Type": "application/json"}
+ response = requests.post(
+ self.url.format(mbid=mbid),
+ json=data,
+ headers=headers,
+ timeout=10,
+ )
# Test that request was successful and raise an error on failure.
if response.status_code != 200:
try:
- message = response.json()['message']
+ message = response.json()["message"]
except (ValueError, KeyError) as e:
- message = f'unable to get error message: {e}'
+ message = f"unable to get error message: {e}"
self._log.error(
- 'Failed to submit AcousticBrainz analysis of {item}: '
- '{message}).', item=item, message=message
+ "Failed to submit AcousticBrainz analysis of {item}: "
+ "{message}).",
+ item=item,
+ message=message,
)
else:
- self._log.debug('Successfully submitted AcousticBrainz analysis '
- 'for {}.', item)
+ self._log.debug(
+ "Successfully submitted AcousticBrainz analysis " "for {}.",
+ item,
+ )
diff --git a/lib/beetsplug/acousticbrainz.py b/lib/beetsplug/acousticbrainz.py
index eabc5849..a4b153fc 100644
--- a/lib/beetsplug/acousticbrainz.py
+++ b/lib/beetsplug/acousticbrainz.py
@@ -22,220 +22,187 @@ import requests
from beets import plugins, ui
from beets.dbcore import types
-ACOUSTIC_BASE = "https://acousticbrainz.org/"
LEVELS = ["/low-level", "/high-level"]
ABSCHEME = {
- 'highlevel': {
- 'danceability': {
- 'all': {
- 'danceable': 'danceable'
- }
- },
- 'gender': {
- 'value': 'gender'
- },
- 'genre_rosamerica': {
- 'value': 'genre_rosamerica'
- },
- 'mood_acoustic': {
- 'all': {
- 'acoustic': 'mood_acoustic'
- }
- },
- 'mood_aggressive': {
- 'all': {
- 'aggressive': 'mood_aggressive'
- }
- },
- 'mood_electronic': {
- 'all': {
- 'electronic': 'mood_electronic'
- }
- },
- 'mood_happy': {
- 'all': {
- 'happy': 'mood_happy'
- }
- },
- 'mood_party': {
- 'all': {
- 'party': 'mood_party'
- }
- },
- 'mood_relaxed': {
- 'all': {
- 'relaxed': 'mood_relaxed'
- }
- },
- 'mood_sad': {
- 'all': {
- 'sad': 'mood_sad'
- }
- },
- 'moods_mirex': {
- 'value': 'moods_mirex'
- },
- 'ismir04_rhythm': {
- 'value': 'rhythm'
- },
- 'tonal_atonal': {
- 'all': {
- 'tonal': 'tonal'
- }
- },
- 'timbre': {
- 'value': 'timbre'
- },
- 'voice_instrumental': {
- 'value': 'voice_instrumental'
- },
+ "highlevel": {
+ "danceability": {"all": {"danceable": "danceable"}},
+ "gender": {"value": "gender"},
+ "genre_rosamerica": {"value": "genre_rosamerica"},
+ "mood_acoustic": {"all": {"acoustic": "mood_acoustic"}},
+ "mood_aggressive": {"all": {"aggressive": "mood_aggressive"}},
+ "mood_electronic": {"all": {"electronic": "mood_electronic"}},
+ "mood_happy": {"all": {"happy": "mood_happy"}},
+ "mood_party": {"all": {"party": "mood_party"}},
+ "mood_relaxed": {"all": {"relaxed": "mood_relaxed"}},
+ "mood_sad": {"all": {"sad": "mood_sad"}},
+ "moods_mirex": {"value": "moods_mirex"},
+ "ismir04_rhythm": {"value": "rhythm"},
+ "tonal_atonal": {"all": {"tonal": "tonal"}},
+ "timbre": {"value": "timbre"},
+ "voice_instrumental": {"value": "voice_instrumental"},
},
- 'lowlevel': {
- 'average_loudness': 'average_loudness'
+ "lowlevel": {"average_loudness": "average_loudness"},
+ "rhythm": {"bpm": "bpm"},
+ "tonal": {
+ "chords_changes_rate": "chords_changes_rate",
+ "chords_key": "chords_key",
+ "chords_number_rate": "chords_number_rate",
+ "chords_scale": "chords_scale",
+ "key_key": ("initial_key", 0),
+ "key_scale": ("initial_key", 1),
+ "key_strength": "key_strength",
},
- 'rhythm': {
- 'bpm': 'bpm'
- },
- 'tonal': {
- 'chords_changes_rate': 'chords_changes_rate',
- 'chords_key': 'chords_key',
- 'chords_number_rate': 'chords_number_rate',
- 'chords_scale': 'chords_scale',
- 'key_key': ('initial_key', 0),
- 'key_scale': ('initial_key', 1),
- 'key_strength': 'key_strength'
-
- }
}
class AcousticPlugin(plugins.BeetsPlugin):
item_types = {
- 'average_loudness': types.Float(6),
- 'chords_changes_rate': types.Float(6),
- 'chords_key': types.STRING,
- 'chords_number_rate': types.Float(6),
- 'chords_scale': types.STRING,
- 'danceable': types.Float(6),
- 'gender': types.STRING,
- 'genre_rosamerica': types.STRING,
- 'initial_key': types.STRING,
- 'key_strength': types.Float(6),
- 'mood_acoustic': types.Float(6),
- 'mood_aggressive': types.Float(6),
- 'mood_electronic': types.Float(6),
- 'mood_happy': types.Float(6),
- 'mood_party': types.Float(6),
- 'mood_relaxed': types.Float(6),
- 'mood_sad': types.Float(6),
- 'moods_mirex': types.STRING,
- 'rhythm': types.Float(6),
- 'timbre': types.STRING,
- 'tonal': types.Float(6),
- 'voice_instrumental': types.STRING,
+ "average_loudness": types.Float(6),
+ "chords_changes_rate": types.Float(6),
+ "chords_key": types.STRING,
+ "chords_number_rate": types.Float(6),
+ "chords_scale": types.STRING,
+ "danceable": types.Float(6),
+ "gender": types.STRING,
+ "genre_rosamerica": types.STRING,
+ "initial_key": types.STRING,
+ "key_strength": types.Float(6),
+ "mood_acoustic": types.Float(6),
+ "mood_aggressive": types.Float(6),
+ "mood_electronic": types.Float(6),
+ "mood_happy": types.Float(6),
+ "mood_party": types.Float(6),
+ "mood_relaxed": types.Float(6),
+ "mood_sad": types.Float(6),
+ "moods_mirex": types.STRING,
+ "rhythm": types.Float(6),
+ "timbre": types.STRING,
+ "tonal": types.Float(6),
+ "voice_instrumental": types.STRING,
}
def __init__(self):
super().__init__()
- self.config.add({
- 'auto': True,
- 'force': False,
- 'tags': []
- })
+ self._log.warning("This plugin is deprecated.")
- if self.config['auto']:
- self.register_listener('import_task_files',
- self.import_task_files)
+ self.config.add(
+ {"auto": True, "force": False, "tags": [], "base_url": ""}
+ )
+
+ self.base_url = self.config["base_url"].as_str()
+ if self.base_url:
+ if not self.base_url.startswith("http"):
+ raise ui.UserError(
+ "AcousticBrainz server base URL must start "
+ "with an HTTP scheme"
+ )
+ elif self.base_url[-1] != "/":
+ self.base_url = self.base_url + "/"
+
+ if self.config["auto"]:
+ self.register_listener("import_task_files", self.import_task_files)
def commands(self):
- cmd = ui.Subcommand('acousticbrainz',
- help="fetch metadata from AcousticBrainz")
+ cmd = ui.Subcommand(
+ "acousticbrainz", help="fetch metadata from AcousticBrainz"
+ )
cmd.parser.add_option(
- '-f', '--force', dest='force_refetch',
- action='store_true', default=False,
- help='re-download data when already present'
+ "-f",
+ "--force",
+ dest="force_refetch",
+ action="store_true",
+ default=False,
+ help="re-download data when already present",
)
def func(lib, opts, args):
items = lib.items(ui.decargs(args))
- self._fetch_info(items, ui.should_write(),
- opts.force_refetch or self.config['force'])
+ self._fetch_info(
+ items,
+ ui.should_write(),
+ opts.force_refetch or self.config["force"],
+ )
cmd.func = func
return [cmd]
def import_task_files(self, session, task):
- """Function is called upon beet import.
- """
+ """Function is called upon beet import."""
self._fetch_info(task.imported_items(), False, True)
def _get_data(self, mbid):
+ if not self.base_url:
+ raise ui.UserError(
+ "This plugin is deprecated since AcousticBrainz has shut "
+ "down. See the base_url configuration option."
+ )
data = {}
- for url in _generate_urls(mbid):
- self._log.debug('fetching URL: {}', url)
+ for url in _generate_urls(self.base_url, mbid):
+ self._log.debug("fetching URL: {}", url)
try:
- res = requests.get(url)
+ res = requests.get(url, timeout=10)
except requests.RequestException as exc:
- self._log.info('request error: {}', exc)
+ self._log.info("request error: {}", exc)
return {}
if res.status_code == 404:
- self._log.info('recording ID {} not found', mbid)
+ self._log.info("recording ID {} not found", mbid)
return {}
try:
data.update(res.json())
except ValueError:
- self._log.debug('Invalid Response: {}', res.text)
+ self._log.debug("Invalid Response: {}", res.text)
return {}
return data
def _fetch_info(self, items, write, force):
- """Fetch additional information from AcousticBrainz for the `item`s.
- """
- tags = self.config['tags'].as_str_seq()
+ """Fetch additional information from AcousticBrainz for the `item`s."""
+ tags = self.config["tags"].as_str_seq()
for item in items:
# If we're not forcing re-downloading for all tracks, check
# whether the data is already present. We use one
# representative field name to check for previously fetched
# data.
if not force:
- mood_str = item.get('mood_acoustic', '')
+ mood_str = item.get("mood_acoustic", "")
if mood_str:
- self._log.info('data already present for: {}', item)
+ self._log.info("data already present for: {}", item)
continue
# We can only fetch data for tracks with MBIDs.
if not item.mb_trackid:
continue
- self._log.info('getting data for: {}', item)
+ self._log.info("getting data for: {}", item)
data = self._get_data(item.mb_trackid)
if data:
for attr, val in self._map_data_to_scheme(data, ABSCHEME):
if not tags or attr in tags:
- self._log.debug('attribute {} of {} set to {}',
- attr,
- item,
- val)
+ self._log.debug(
+ "attribute {} of {} set to {}", attr, item, val
+ )
setattr(item, attr, val)
else:
- self._log.debug('skipping attribute {} of {}'
- ' (value {}) due to config',
- attr,
- item,
- val)
+ self._log.debug(
+ "skipping attribute {} of {}"
+ " (value {}) due to config",
+ attr,
+ item,
+ val,
+ )
item.store()
if write:
item.try_write()
def _map_data_to_scheme(self, data, scheme):
- """Given `data` as a structure of nested dictionaries, and `scheme` as a
- structure of nested dictionaries , `yield` tuples `(attr, val)` where
- `attr` and `val` are corresponding leaf nodes in `scheme` and `data`.
+ """Given `data` as a structure of nested dictionaries, and
+ `scheme` as a structure of nested dictionaries , `yield` tuples
+ `(attr, val)` where `attr` and `val` are corresponding leaf
+ nodes in `scheme` and `data`.
As its name indicates, `scheme` defines how the data is structured,
so this function tries to find leaf nodes in `data` that correspond
@@ -286,14 +253,12 @@ class AcousticPlugin(plugins.BeetsPlugin):
# The recursive traversal.
composites = defaultdict(list)
- yield from self._data_to_scheme_child(data,
- scheme,
- composites)
+ yield from self._data_to_scheme_child(data, scheme, composites)
# When composites has been populated, yield the composite attributes
# by joining their parts.
for composite_attr, value_parts in composites.items():
- yield composite_attr, ' '.join(value_parts)
+ yield composite_attr, " ".join(value_parts)
def _data_to_scheme_child(self, subdata, subscheme, composites):
"""The recursive business logic of :meth:`_map_data_to_scheme`:
@@ -307,28 +272,33 @@ class AcousticPlugin(plugins.BeetsPlugin):
"""
for k, v in subscheme.items():
if k in subdata:
- if type(v) == dict:
- yield from self._data_to_scheme_child(subdata[k],
- v,
- composites)
- elif type(v) == tuple:
+ if isinstance(v, dict):
+ yield from self._data_to_scheme_child(
+ subdata[k], v, composites
+ )
+ elif isinstance(v, tuple):
composite_attribute, part_number = v
attribute_parts = composites[composite_attribute]
# Parts are not guaranteed to be inserted in order
while len(attribute_parts) <= part_number:
- attribute_parts.append('')
+ attribute_parts.append("")
attribute_parts[part_number] = subdata[k]
else:
yield v, subdata[k]
else:
- self._log.warning('Acousticbrainz did not provide info'
- 'about {}', k)
- self._log.debug('Data {} could not be mapped to scheme {} '
- 'because key {} was not found', subdata, v, k)
+ self._log.warning(
+ "Acousticbrainz did not provide info " "about {}", k
+ )
+ self._log.debug(
+ "Data {} could not be mapped to scheme {} "
+ "because key {} was not found",
+ subdata,
+ v,
+ k,
+ )
-def _generate_urls(mbid):
- """Generates AcousticBrainz end point urls for given `mbid`.
- """
+def _generate_urls(base_url, mbid):
+ """Generates AcousticBrainz end point urls for given `mbid`."""
for level in LEVELS:
- yield ACOUSTIC_BASE + mbid + level
+ yield base_url + mbid + level
diff --git a/lib/beetsplug/advancedrewrite.py b/lib/beetsplug/advancedrewrite.py
new file mode 100644
index 00000000..9a5feaaf
--- /dev/null
+++ b/lib/beetsplug/advancedrewrite.py
@@ -0,0 +1,174 @@
+# This file is part of beets.
+# Copyright 2023, Max Rumpf.
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the
+# "Software"), to deal in the Software without restriction, including
+# without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to
+# permit persons to whom the Software is furnished to do so, subject to
+# the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+
+"""Plugin to rewrite fields based on a given query."""
+
+import re
+import shlex
+from collections import defaultdict
+
+import confuse
+
+from beets.dbcore import AndQuery, query_from_strings
+from beets.dbcore.types import MULTI_VALUE_DSV
+from beets.library import Album, Item
+from beets.plugins import BeetsPlugin
+from beets.ui import UserError
+
+
+def rewriter(field, simple_rules, advanced_rules):
+ """Template field function factory.
+
+ Create a template field function that rewrites the given field
+ with the given rewriting rules.
+ ``simple_rules`` must be a list of (pattern, replacement) pairs.
+ ``advanced_rules`` must be a list of (query, replacement) pairs.
+ """
+
+ def fieldfunc(item):
+ value = item._values_fixed[field]
+ for pattern, replacement in simple_rules:
+ if pattern.match(value.lower()):
+ # Rewrite activated.
+ return replacement
+ for query, replacement in advanced_rules:
+ if query.match(item):
+ # Rewrite activated.
+ return replacement
+ # Not activated; return original value.
+ return value
+
+ return fieldfunc
+
+
+class AdvancedRewritePlugin(BeetsPlugin):
+ """Plugin to rewrite fields based on a given query."""
+
+ def __init__(self):
+ """Parse configuration and register template fields for rewriting."""
+ super().__init__()
+
+ template = confuse.Sequence(
+ confuse.OneOf(
+ [
+ confuse.MappingValues(str),
+ {
+ "match": str,
+ "replacements": confuse.MappingValues(
+ confuse.OneOf([str, confuse.Sequence(str)]),
+ ),
+ },
+ ]
+ )
+ )
+
+ # Used to apply the same rewrite to the corresponding album field.
+ corresponding_album_fields = {
+ "artist": "albumartist",
+ "artists": "albumartists",
+ "artist_sort": "albumartist_sort",
+ "artists_sort": "albumartists_sort",
+ }
+
+ # Gather all the rewrite rules for each field.
+ class RulesContainer:
+ def __init__(self):
+ self.simple = []
+ self.advanced = []
+
+ rules = defaultdict(RulesContainer)
+ for rule in self.config.get(template):
+ if "match" not in rule:
+ # Simple syntax
+ if len(rule) != 1:
+ raise UserError(
+ "Simple rewrites must have only one rule, "
+ "but found multiple entries. "
+ "Did you forget to prepend a dash (-)?"
+ )
+ key, value = next(iter(rule.items()))
+ try:
+ fieldname, pattern = key.split(None, 1)
+ except ValueError:
+ raise UserError(
+ f"Invalid simple rewrite specification {key}"
+ )
+ if fieldname not in Item._fields:
+ raise UserError(
+ f"invalid field name {fieldname} in rewriter"
+ )
+ self._log.debug(
+ f"adding simple rewrite '{pattern}' → '{value}' "
+ f"for field {fieldname}"
+ )
+ pattern = re.compile(pattern.lower())
+ rules[fieldname].simple.append((pattern, value))
+
+ # Apply the same rewrite to the corresponding album field.
+ if fieldname in corresponding_album_fields:
+ album_fieldname = corresponding_album_fields[fieldname]
+ rules[album_fieldname].simple.append((pattern, value))
+ else:
+ # Advanced syntax
+ match = rule["match"]
+ replacements = rule["replacements"]
+ if len(replacements) == 0:
+ raise UserError(
+ "Advanced rewrites must have at least one replacement"
+ )
+ query = query_from_strings(
+ AndQuery,
+ Item,
+ prefixes={},
+ query_parts=shlex.split(match),
+ )
+ for fieldname, replacement in replacements.items():
+ if fieldname not in Item._fields:
+ raise UserError(
+ f"Invalid field name {fieldname} in rewriter"
+ )
+ self._log.debug(
+ f"adding advanced rewrite to '{replacement}' "
+ f"for field {fieldname}"
+ )
+ if isinstance(replacement, list):
+ if Item._fields[fieldname] is not MULTI_VALUE_DSV:
+ raise UserError(
+ f"Field {fieldname} is not a multi-valued field "
+ f"but a list was given: {', '.join(replacement)}"
+ )
+ elif isinstance(replacement, str):
+ if Item._fields[fieldname] is MULTI_VALUE_DSV:
+ replacement = [replacement]
+ else:
+ raise UserError(
+ f"Invalid type of replacement {replacement} "
+ f"for field {fieldname}"
+ )
+
+ rules[fieldname].advanced.append((query, replacement))
+
+ # Apply the same rewrite to the corresponding album field.
+ if fieldname in corresponding_album_fields:
+ album_fieldname = corresponding_album_fields[fieldname]
+ rules[album_fieldname].advanced.append(
+ (query, replacement)
+ )
+
+ # Replace each template field with the new rewriter function.
+ for fieldname, fieldrules in rules.items():
+ getter = rewriter(fieldname, fieldrules.simple, fieldrules.advanced)
+ self.template_fields[fieldname] = getter
+ if fieldname in Album._fields:
+ self.album_template_fields[fieldname] = getter
diff --git a/lib/beetsplug/albumtypes.py b/lib/beetsplug/albumtypes.py
index 47f8dc64..5200b5c6 100644
--- a/lib/beetsplug/albumtypes.py
+++ b/lib/beetsplug/albumtypes.py
@@ -26,40 +26,42 @@ class AlbumTypesPlugin(BeetsPlugin):
def __init__(self):
"""Init AlbumTypesPlugin."""
super().__init__()
- self.album_template_fields['atypes'] = self._atypes
- self.config.add({
- 'types': [
- ('ep', 'EP'),
- ('single', 'Single'),
- ('soundtrack', 'OST'),
- ('live', 'Live'),
- ('compilation', 'Anthology'),
- ('remix', 'Remix')
- ],
- 'ignore_va': ['compilation'],
- 'bracket': '[]'
- })
+ self.album_template_fields["atypes"] = self._atypes
+ self.config.add(
+ {
+ "types": [
+ ("ep", "EP"),
+ ("single", "Single"),
+ ("soundtrack", "OST"),
+ ("live", "Live"),
+ ("compilation", "Anthology"),
+ ("remix", "Remix"),
+ ],
+ "ignore_va": ["compilation"],
+ "bracket": "[]",
+ }
+ )
def _atypes(self, item: Album):
"""Returns a formatted string based on album's types."""
- types = self.config['types'].as_pairs()
- ignore_va = self.config['ignore_va'].as_str_seq()
- bracket = self.config['bracket'].as_str()
+ types = self.config["types"].as_pairs()
+ ignore_va = self.config["ignore_va"].as_str_seq()
+ bracket = self.config["bracket"].as_str()
# Assign a left and right bracket or leave blank if argument is empty.
if len(bracket) == 2:
bracket_l = bracket[0]
bracket_r = bracket[1]
else:
- bracket_l = ''
- bracket_r = ''
+ bracket_l = ""
+ bracket_r = ""
- res = ''
- albumtypes = item.albumtypes.split('; ')
+ res = ""
+ albumtypes = item.albumtypes
is_va = item.mb_albumartistid == VARIOUS_ARTISTS_ID
for type in types:
if type[0] in albumtypes and type[1]:
if not is_va or (type[0] not in ignore_va and is_va):
- res += f'{bracket_l}{type[1]}{bracket_r}'
+ res += f"{bracket_l}{type[1]}{bracket_r}"
return res
diff --git a/lib/beetsplug/aura.py b/lib/beetsplug/aura.py
index f4ae5527..09d85920 100644
--- a/lib/beetsplug/aura.py
+++ b/lib/beetsplug/aura.py
@@ -15,35 +15,41 @@
"""An AURA server using Flask."""
-from mimetypes import guess_type
+import os
import re
-import os.path
-from os.path import isfile, getsize
-
-from beets.plugins import BeetsPlugin
-from beets.ui import Subcommand, _open_library
-from beets import config
-from beets.util import py3_path
-from beets.library import Item, Album
-from beets.dbcore.query import (
- MatchQuery,
- NotQuery,
- RegexpQuery,
- AndQuery,
- FixedFieldSort,
- SlowFieldSort,
- MultipleSort,
-)
+import sys
+from dataclasses import dataclass
+from mimetypes import guess_type
+from typing import ClassVar, Mapping, Type
from flask import (
Blueprint,
Flask,
current_app,
- send_file,
make_response,
request,
+ send_file,
)
+if sys.version_info >= (3, 11):
+ from typing import Self
+else:
+ from typing_extensions import Self
+
+from beets import config
+from beets.dbcore.query import (
+ AndQuery,
+ FixedFieldSort,
+ MatchQuery,
+ MultipleSort,
+ NotQuery,
+ RegexpQuery,
+ SlowFieldSort,
+ SQLiteType,
+)
+from beets.library import Album, Item, LibModel, Library
+from beets.plugins import BeetsPlugin
+from beets.ui import Subcommand, _open_library
# Constants
@@ -118,9 +124,20 @@ ARTIST_ATTR_MAP = {
}
+@dataclass
class AURADocument:
"""Base class for building AURA documents."""
+ model_cls: ClassVar[Type[LibModel]]
+
+ lib: Library
+ args: Mapping[str, str]
+
+ @classmethod
+ def from_app(cls) -> Self:
+ """Initialise the document using the global app and request."""
+ return cls(current_app.config["lib"], request.args)
+
@staticmethod
def error(status, title, detail):
"""Make a response for an error following the JSON:API spec.
@@ -136,13 +153,29 @@ class AURADocument:
}
return make_response(document, status)
+ @classmethod
+ def get_attribute_converter(cls, beets_attr: str) -> Type[SQLiteType]:
+ """Work out what data type an attribute should be for beets.
+
+ Args:
+ beets_attr: The name of the beets attribute, e.g. "title".
+ """
+ try:
+ # Look for field in list of Album fields
+ # and get python type of database type.
+ # See beets.library.Album and beets.dbcore.types
+ return cls.model_cls._fields[beets_attr].model_type
+ except KeyError:
+ # Fall back to string (NOTE: probably not good)
+ return str
+
def translate_filters(self):
"""Translate filters from request arguments to a beets Query."""
# The format of each filter key in the request parameter is:
# filter[]. This regex extracts .
pattern = re.compile(r"filter\[(?P[a-zA-Z0-9_-]+)\]")
queries = []
- for key, value in request.args.items():
+ for key, value in self.args.items():
match = pattern.match(key)
if match:
# Extract attribute name from key
@@ -191,10 +224,10 @@ class AURADocument:
albums) or a list of strings (artists).
"""
# Pages start from zero
- page = request.args.get("page", 0, int)
+ page = self.args.get("page", 0, int)
# Use page limit defined in config by default.
default_limit = config["aura"]["page_limit"].get(int)
- limit = request.args.get("limit", default_limit, int)
+ limit = self.args.get("limit", default_limit, int)
# start = offset of first item to return
start = page * limit
# end = offset of last item + 1
@@ -204,10 +237,10 @@ class AURADocument:
next_url = None
else:
# Not the last page so work out links.next url
- if not request.args:
+ if not self.args:
# No existing arguments, so current page is 0
next_url = request.url + "?page=1"
- elif not request.args.get("page", None):
+ elif not self.args.get("page", None):
# No existing page argument, so add one to the end
next_url = request.url + "&page=1"
else:
@@ -216,7 +249,10 @@ class AURADocument:
f"page={page}", "page={}".format(page + 1)
)
# Get only the items in the page range
- data = [self.resource_object(collection[i]) for i in range(start, end)]
+ data = [
+ self.get_resource_object(self.lib, collection[i])
+ for i in range(start, end)
+ ]
return data, next_url
def get_included(self, data, include_str):
@@ -250,18 +286,26 @@ class AURADocument:
res_type = identifier["type"]
if res_type == "track":
track_id = int(identifier["id"])
- track = current_app.config["lib"].get_item(track_id)
- included.append(TrackDocument.resource_object(track))
+ track = self.lib.get_item(track_id)
+ included.append(
+ TrackDocument.get_resource_object(self.lib, track)
+ )
elif res_type == "album":
album_id = int(identifier["id"])
- album = current_app.config["lib"].get_album(album_id)
- included.append(AlbumDocument.resource_object(album))
+ album = self.lib.get_album(album_id)
+ included.append(
+ AlbumDocument.get_resource_object(self.lib, album)
+ )
elif res_type == "artist":
artist_id = identifier["id"]
- included.append(ArtistDocument.resource_object(artist_id))
+ included.append(
+ ArtistDocument.get_resource_object(self.lib, artist_id)
+ )
elif res_type == "image":
image_id = identifier["id"]
- included.append(ImageDocument.resource_object(image_id))
+ included.append(
+ ImageDocument.get_resource_object(self.lib, image_id)
+ )
else:
raise ValueError(f"Invalid resource type: {res_type}")
return included
@@ -269,7 +313,7 @@ class AURADocument:
def all_resources(self):
"""Build document for /tracks, /albums or /artists."""
query = self.translate_filters()
- sort_arg = request.args.get("sort", None)
+ sort_arg = self.args.get("sort", None)
if sort_arg:
sort = self.translate_sorts(sort_arg)
# For each sort field add a query which ensures all results
@@ -292,7 +336,7 @@ class AURADocument:
if next_url:
document["links"] = {"next": next_url}
# Include related resources for each element in "data"
- include_str = request.args.get("include", None)
+ include_str = self.args.get("include", None)
if include_str:
document["included"] = self.get_included(data, include_str)
return document
@@ -305,7 +349,7 @@ class AURADocument:
resource object.
"""
document = {"data": resource_object}
- include_str = request.args.get("include", None)
+ include_str = self.args.get("include", None)
if include_str:
# [document["data"]] is because arg needs to be list
document["included"] = self.get_included(
@@ -317,6 +361,8 @@ class AURADocument:
class TrackDocument(AURADocument):
"""Class for building documents for /tracks endpoints."""
+ model_cls = Item
+
attribute_map = TRACK_ATTR_MAP
def get_collection(self, query=None, sort=None):
@@ -326,9 +372,10 @@ class TrackDocument(AURADocument):
query: A beets Query object or a beets query string.
sort: A beets Sort object.
"""
- return current_app.config["lib"].items(query, sort)
+ return self.lib.items(query, sort)
- def get_attribute_converter(self, beets_attr):
+ @classmethod
+ def get_attribute_converter(cls, beets_attr: str) -> Type[SQLiteType]:
"""Work out what data type an attribute should be for beets.
Args:
@@ -336,20 +383,12 @@ class TrackDocument(AURADocument):
"""
# filesize is a special field (read from disk not db?)
if beets_attr == "filesize":
- converter = int
- else:
- try:
- # Look for field in list of Item fields
- # and get python type of database type.
- # See beets.library.Item and beets.dbcore.types
- converter = Item._fields[beets_attr].model_type
- except KeyError:
- # Fall back to string (NOTE: probably not good)
- converter = str
- return converter
+ return int
+
+ return super().get_attribute_converter(beets_attr)
@staticmethod
- def resource_object(track):
+ def get_resource_object(lib: Library, track):
"""Construct a JSON:API resource object from a beets Item.
Args:
@@ -387,7 +426,7 @@ class TrackDocument(AURADocument):
Args:
track_id: The beets id of the track (integer).
"""
- track = current_app.config["lib"].get_item(track_id)
+ track = self.lib.get_item(track_id)
if not track:
return self.error(
"404 Not Found",
@@ -396,12 +435,16 @@ class TrackDocument(AURADocument):
track_id
),
)
- return self.single_resource_document(self.resource_object(track))
+ return self.single_resource_document(
+ self.get_resource_object(self.lib, track)
+ )
class AlbumDocument(AURADocument):
"""Class for building documents for /albums endpoints."""
+ model_cls = Album
+
attribute_map = ALBUM_ATTR_MAP
def get_collection(self, query=None, sort=None):
@@ -411,26 +454,10 @@ class AlbumDocument(AURADocument):
query: A beets Query object or a beets query string.
sort: A beets Sort object.
"""
- return current_app.config["lib"].albums(query, sort)
-
- def get_attribute_converter(self, beets_attr):
- """Work out what data type an attribute should be for beets.
-
- Args:
- beets_attr: The name of the beets attribute, e.g. "title".
- """
- try:
- # Look for field in list of Album fields
- # and get python type of database type.
- # See beets.library.Album and beets.dbcore.types
- converter = Album._fields[beets_attr].model_type
- except KeyError:
- # Fall back to string (NOTE: probably not good)
- converter = str
- return converter
+ return self.lib.albums(query, sort)
@staticmethod
- def resource_object(album):
+ def get_resource_object(lib: Library, album):
"""Construct a JSON:API resource object from a beets Album.
Args:
@@ -449,7 +476,7 @@ class AlbumDocument(AURADocument):
# track number. Sorting is not required but it's nice.
query = MatchQuery("album_id", album.id)
sort = FixedFieldSort("track", ascending=True)
- tracks = current_app.config["lib"].items(query, sort)
+ tracks = lib.items(query, sort)
# JSON:API one-to-many relationship to tracks on the album
relationships = {
"tracks": {
@@ -458,7 +485,7 @@ class AlbumDocument(AURADocument):
}
# Add images relationship if album has associated images
if album.artpath:
- path = py3_path(album.artpath)
+ path = os.fsdecode(album.artpath)
filename = path.split("/")[-1]
image_id = f"album-{album.id}-{filename}"
relationships["images"] = {
@@ -485,7 +512,7 @@ class AlbumDocument(AURADocument):
Args:
album_id: The beets id of the album (integer).
"""
- album = current_app.config["lib"].get_album(album_id)
+ album = self.lib.get_album(album_id)
if not album:
return self.error(
"404 Not Found",
@@ -494,12 +521,16 @@ class AlbumDocument(AURADocument):
album_id
),
)
- return self.single_resource_document(self.resource_object(album))
+ return self.single_resource_document(
+ self.get_resource_object(self.lib, album)
+ )
class ArtistDocument(AURADocument):
"""Class for building documents for /artists endpoints."""
+ model_cls = Item
+
attribute_map = ARTIST_ATTR_MAP
def get_collection(self, query=None, sort=None):
@@ -510,7 +541,7 @@ class ArtistDocument(AURADocument):
sort: A beets Sort object.
"""
# Gets only tracks with matching artist information
- tracks = current_app.config["lib"].items(query, sort)
+ tracks = self.lib.items(query, sort)
collection = []
for track in tracks:
# Do not add duplicates
@@ -518,24 +549,8 @@ class ArtistDocument(AURADocument):
collection.append(track.artist)
return collection
- def get_attribute_converter(self, beets_attr):
- """Work out what data type an attribute should be for beets.
-
- Args:
- beets_attr: The name of the beets attribute, e.g. "artist".
- """
- try:
- # Look for field in list of Item fields
- # and get python type of database type.
- # See beets.library.Item and beets.dbcore.types
- converter = Item._fields[beets_attr].model_type
- except KeyError:
- # Fall back to string (NOTE: probably not good)
- converter = str
- return converter
-
@staticmethod
- def resource_object(artist_id):
+ def get_resource_object(lib: Library, artist_id):
"""Construct a JSON:API resource object for the given artist.
Args:
@@ -543,7 +558,7 @@ class ArtistDocument(AURADocument):
"""
# Get tracks where artist field exactly matches artist_id
query = MatchQuery("artist", artist_id)
- tracks = current_app.config["lib"].items(query)
+ tracks = lib.items(query)
if not tracks:
return None
@@ -565,7 +580,7 @@ class ArtistDocument(AURADocument):
}
}
album_query = MatchQuery("albumartist", artist_id)
- albums = current_app.config["lib"].albums(query=album_query)
+ albums = lib.albums(query=album_query)
if len(albums) != 0:
relationships["albums"] = {
"data": [{"type": "album", "id": str(a.id)} for a in albums]
@@ -584,7 +599,7 @@ class ArtistDocument(AURADocument):
Args:
artist_id: A string which is the artist's name.
"""
- artist_resource = self.resource_object(artist_id)
+ artist_resource = self.get_resource_object(self.lib, artist_id)
if not artist_resource:
return self.error(
"404 Not Found",
@@ -608,7 +623,7 @@ def safe_filename(fn):
return False
# In single names, rule out Unix directory traversal names.
- if fn in ('.', '..'):
+ if fn in (".", ".."):
return False
return True
@@ -617,8 +632,10 @@ def safe_filename(fn):
class ImageDocument(AURADocument):
"""Class for building documents for /images/(id) endpoints."""
+ model_cls = Album
+
@staticmethod
- def get_image_path(image_id):
+ def get_image_path(lib: Library, image_id):
"""Works out the full path to the image with the given id.
Returns None if there is no such image.
@@ -640,13 +657,13 @@ class ImageDocument(AURADocument):
# Get the path to the directory parent's images are in
if parent_type == "album":
- album = current_app.config["lib"].get_album(int(parent_id))
+ album = lib.get_album(int(parent_id))
if not album or not album.artpath:
return None
# Cut the filename off of artpath
# This is in preparation for supporting images in the same
# directory that are not tracked by beets.
- artpath = py3_path(album.artpath)
+ artpath = os.fsdecode(album.artpath)
dir_path = "/".join(artpath.split("/")[:-1])
else:
# Images for other resource types are not supported
@@ -654,13 +671,13 @@ class ImageDocument(AURADocument):
img_path = os.path.join(dir_path, img_filename)
# Check the image actually exists
- if isfile(img_path):
+ if os.path.isfile(img_path):
return img_path
else:
return None
@staticmethod
- def resource_object(image_id):
+ def get_resource_object(lib: Library, image_id):
"""Construct a JSON:API resource object for the given image.
Args:
@@ -669,14 +686,14 @@ class ImageDocument(AURADocument):
"""
# Could be called as a static method, so can't use
# self.get_image_path()
- image_path = ImageDocument.get_image_path(image_id)
+ image_path = ImageDocument.get_image_path(lib, image_id)
if not image_path:
return None
attributes = {
"role": "cover",
"mimetype": guess_type(image_path)[0],
- "size": getsize(image_path),
+ "size": os.path.getsize(image_path),
}
try:
from PIL import Image
@@ -709,7 +726,7 @@ class ImageDocument(AURADocument):
image_id: A string in the form
"--".
"""
- image_resource = self.resource_object(image_id)
+ image_resource = self.get_resource_object(self.lib, image_id)
if not image_resource:
return self.error(
"404 Not Found",
@@ -737,8 +754,7 @@ def server_info():
@aura_bp.route("/tracks")
def all_tracks():
"""Respond with a list of all tracks and related information."""
- doc = TrackDocument()
- return doc.all_resources()
+ return TrackDocument.from_app().all_resources()
@aura_bp.route("/tracks/")
@@ -748,8 +764,7 @@ def single_track(track_id):
Args:
track_id: The id of the track provided in the URL (integer).
"""
- doc = TrackDocument()
- return doc.single_resource(track_id)
+ return TrackDocument.from_app().single_resource(track_id)
@aura_bp.route("/tracks//audio")
@@ -769,8 +784,8 @@ def audio_file(track_id):
),
)
- path = py3_path(track.path)
- if not isfile(path):
+ path = os.fsdecode(track.path)
+ if not os.path.isfile(path):
return AURADocument.error(
"404 Not Found",
"No audio file for the requested track.",
@@ -821,8 +836,7 @@ def audio_file(track_id):
@aura_bp.route("/albums")
def all_albums():
"""Respond with a list of all albums and related information."""
- doc = AlbumDocument()
- return doc.all_resources()
+ return AlbumDocument.from_app().all_resources()
@aura_bp.route("/albums/")
@@ -832,8 +846,7 @@ def single_album(album_id):
Args:
album_id: The id of the album provided in the URL (integer).
"""
- doc = AlbumDocument()
- return doc.single_resource(album_id)
+ return AlbumDocument.from_app().single_resource(album_id)
# Artist endpoints
@@ -843,8 +856,7 @@ def single_album(album_id):
@aura_bp.route("/artists")
def all_artists():
"""Respond with a list of all artists and related information."""
- doc = ArtistDocument()
- return doc.all_resources()
+ return ArtistDocument.from_app().all_resources()
# Using the path converter allows slashes in artist_id
@@ -856,8 +868,7 @@ def single_artist(artist_id):
artist_id: The id of the artist provided in the URL. A string
which is the artist's name.
"""
- doc = ArtistDocument()
- return doc.single_resource(artist_id)
+ return ArtistDocument.from_app().single_resource(artist_id)
# Image endpoints
@@ -873,8 +884,7 @@ def single_image(image_id):
image_id: The id of the image provided in the URL. A string in
the form "--".
"""
- doc = ImageDocument()
- return doc.single_resource(image_id)
+ return ImageDocument.from_app().single_resource(image_id)
@aura_bp.route("/images//file")
@@ -885,7 +895,7 @@ def image_file(image_id):
image_id: The id of the image provided in the URL. A string in
the form "--".
"""
- img_path = ImageDocument.get_image_path(image_id)
+ img_path = ImageDocument.get_image_path(current_app.config["lib"], image_id)
if not img_path:
return AURADocument.error(
"404 Not Found",
diff --git a/lib/beetsplug/autobpm.py b/lib/beetsplug/autobpm.py
new file mode 100644
index 00000000..aace0c59
--- /dev/null
+++ b/lib/beetsplug/autobpm.py
@@ -0,0 +1,92 @@
+# This file is part of beets.
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the
+# "Software"), to deal in the Software without restriction, including
+# without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to
+# permit persons to whom the Software is furnished to do so, subject to
+# the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+
+"""Uses Librosa to calculate the `bpm` field.
+"""
+
+
+from librosa import beat, load
+from soundfile import LibsndfileError
+
+from beets import ui, util
+from beets.plugins import BeetsPlugin
+
+
+class AutoBPMPlugin(BeetsPlugin):
+ def __init__(self):
+ super().__init__()
+ self.config.add(
+ {
+ "auto": True,
+ "overwrite": False,
+ }
+ )
+
+ if self.config["auto"].get(bool):
+ self.import_stages = [self.imported]
+
+ def commands(self):
+ cmd = ui.Subcommand(
+ "autobpm", help="detect and add bpm from audio using Librosa"
+ )
+ cmd.func = self.command
+ return [cmd]
+
+ def command(self, lib, opts, args):
+ self.calculate_bpm(lib.items(ui.decargs(args)), write=ui.should_write())
+
+ def imported(self, session, task):
+ self.calculate_bpm(task.imported_items())
+
+ def calculate_bpm(self, items, write=False):
+ overwrite = self.config["overwrite"].get(bool)
+
+ for item in items:
+ if item["bpm"]:
+ self._log.info(
+ "found bpm {0} for {1}",
+ item["bpm"],
+ util.displayable_path(item.path),
+ )
+ if not overwrite:
+ continue
+
+ try:
+ y, sr = load(util.syspath(item.path), res_type="kaiser_fast")
+ except LibsndfileError as exc:
+ self._log.error(
+ "LibsndfileError: failed to load {0} {1}",
+ util.displayable_path(item.path),
+ exc,
+ )
+ continue
+ except ValueError as exc:
+ self._log.error(
+ "ValueError: failed to load {0} {1}",
+ util.displayable_path(item.path),
+ exc,
+ )
+ continue
+
+ tempo, _ = beat.beat_track(y=y, sr=sr)
+ bpm = round(tempo)
+ item["bpm"] = bpm
+ self._log.info(
+ "added computed bpm {0} for {1}",
+ bpm,
+ util.displayable_path(item.path),
+ )
+
+ if write:
+ item.try_write()
+ item.store()
diff --git a/lib/beetsplug/badfiles.py b/lib/beetsplug/badfiles.py
index ec465895..056b6534 100644
--- a/lib/beetsplug/badfiles.py
+++ b/lib/beetsplug/badfiles.py
@@ -16,18 +16,18 @@
"""
-from subprocess import check_output, CalledProcessError, list2cmdline, STDOUT
-
-import shlex
-import os
import errno
+import os
+import shlex
import sys
+from subprocess import STDOUT, CalledProcessError, check_output, list2cmdline
+
import confuse
+
+from beets import importer, ui
from beets.plugins import BeetsPlugin
from beets.ui import Subcommand
from beets.util import displayable_path, par_map
-from beets import ui
-from beets import importer
class CheckerCommandException(Exception):
@@ -52,14 +52,15 @@ class BadFiles(BeetsPlugin):
super().__init__()
self.verbose = False
- self.register_listener('import_task_start',
- self.on_import_task_start)
- self.register_listener('import_task_before_choice',
- self.on_import_task_before_choice)
+ self.register_listener("import_task_start", self.on_import_task_start)
+ self.register_listener(
+ "import_task_before_choice", self.on_import_task_before_choice
+ )
def run_command(self, cmd):
- self._log.debug("running command: {}",
- displayable_path(list2cmdline(cmd)))
+ self._log.debug(
+ "running command: {}", displayable_path(list2cmdline(cmd))
+ )
try:
output = check_output(cmd, stderr=STDOUT)
errors = 0
@@ -70,7 +71,7 @@ class BadFiles(BeetsPlugin):
status = e.returncode
except OSError as e:
raise CheckerCommandException(cmd, e)
- output = output.decode(sys.getdefaultencoding(), 'replace')
+ output = output.decode(sys.getdefaultencoding(), "replace")
return status, errors, [line for line in output.split("\n") if line]
def check_mp3val(self, path):
@@ -88,12 +89,13 @@ class BadFiles(BeetsPlugin):
cmd = shlex.split(command)
cmd.append(path)
return self.run_command(cmd)
+
return checker
def get_checker(self, ext):
ext = ext.lower()
try:
- command = self.config['commands'].get(dict).get(ext)
+ command = self.config["commands"].get(dict).get(ext)
except confuse.NotFoundError:
command = None
if command:
@@ -109,15 +111,17 @@ class BadFiles(BeetsPlugin):
dpath = displayable_path(item.path)
self._log.debug("checking path: {}", dpath)
if not os.path.exists(item.path):
- ui.print_("{}: file does not exist".format(
- ui.colorize('text_error', dpath)))
+ ui.print_(
+ "{}: file does not exist".format(
+ ui.colorize("text_error", dpath)
+ )
+ )
# Run the checker against the file if one is found
- ext = os.path.splitext(item.path)[1][1:].decode('utf8', 'ignore')
+ ext = os.path.splitext(item.path)[1][1:].decode("utf8", "ignore")
checker = self.get_checker(ext)
if not checker:
- self._log.error("no checker specified in the config for {}",
- ext)
+ self._log.error("no checker specified in the config for {}", ext)
return []
path = item.path
if not isinstance(path, str):
@@ -129,7 +133,7 @@ class BadFiles(BeetsPlugin):
self._log.error(
"command not found: {} when validating file: {}",
e.checker,
- e.path
+ e.path,
)
else:
self._log.error("error invoking {}: {}", e.checker, e.msg)
@@ -139,25 +143,30 @@ class BadFiles(BeetsPlugin):
if status > 0:
error_lines.append(
- "{}: checker exited with status {}"
- .format(ui.colorize('text_error', dpath), status))
+ "{}: checker exited with status {}".format(
+ ui.colorize("text_error", dpath), status
+ )
+ )
for line in output:
error_lines.append(f" {line}")
elif errors > 0:
error_lines.append(
- "{}: checker found {} errors or warnings"
- .format(ui.colorize('text_warning', dpath), errors))
+ "{}: checker found {} errors or warnings".format(
+ ui.colorize("text_warning", dpath), errors
+ )
+ )
for line in output:
error_lines.append(f" {line}")
elif self.verbose:
error_lines.append(
- "{}: ok".format(ui.colorize('text_success', dpath)))
+ "{}: ok".format(ui.colorize("text_success", dpath))
+ )
return error_lines
def on_import_task_start(self, task, session):
- if not self.config['check_on_import'].get(False):
+ if not self.config["check_on_import"].get(False):
return
checks_failed = []
@@ -171,26 +180,29 @@ class BadFiles(BeetsPlugin):
task._badfiles_checks_failed = checks_failed
def on_import_task_before_choice(self, task, session):
- if hasattr(task, '_badfiles_checks_failed'):
- ui.print_('{} one or more files failed checks:'
- .format(ui.colorize('text_warning', 'BAD')))
+ if hasattr(task, "_badfiles_checks_failed"):
+ ui.print_(
+ "{} one or more files failed checks:".format(
+ ui.colorize("text_warning", "BAD")
+ )
+ )
for error in task._badfiles_checks_failed:
for error_line in error:
ui.print_(error_line)
ui.print_()
- ui.print_('What would you like to do?')
+ ui.print_("What would you like to do?")
- sel = ui.input_options(['aBort', 'skip', 'continue'])
+ sel = ui.input_options(["aBort", "skip", "continue"])
- if sel == 's':
+ if sel == "s":
return importer.action.SKIP
- elif sel == 'c':
+ elif sel == "c":
return None
- elif sel == 'b':
+ elif sel == "b":
raise importer.ImportAbort()
else:
- raise Exception(f'Unexpected selection: {sel}')
+ raise Exception(f"Unexpected selection: {sel}")
def command(self, lib, opts, args):
# Get items from arguments
@@ -204,12 +216,16 @@ class BadFiles(BeetsPlugin):
par_map(check_and_print, items)
def commands(self):
- bad_command = Subcommand('bad',
- help='check for corrupt or missing files')
+ bad_command = Subcommand(
+ "bad", help="check for corrupt or missing files"
+ )
bad_command.parser.add_option(
- '-v', '--verbose',
- action='store_true', default=False, dest='verbose',
- help='view results for both the bad and uncorrupted files'
+ "-v",
+ "--verbose",
+ action="store_true",
+ default=False,
+ dest="verbose",
+ help="view results for both the bad and uncorrupted files",
)
bad_command.func = self.command
return [bad_command]
diff --git a/lib/beetsplug/bareasc.py b/lib/beetsplug/bareasc.py
index 21836936..8cdcbb11 100644
--- a/lib/beetsplug/bareasc.py
+++ b/lib/beetsplug/bareasc.py
@@ -19,15 +19,17 @@
"""Provides a bare-ASCII matching query."""
-from beets import ui
-from beets.ui import print_, decargs
-from beets.plugins import BeetsPlugin
-from beets.dbcore.query import StringFieldQuery
from unidecode import unidecode
+from beets import ui
+from beets.dbcore.query import StringFieldQuery
+from beets.plugins import BeetsPlugin
+from beets.ui import decargs, print_
-class BareascQuery(StringFieldQuery):
+
+class BareascQuery(StringFieldQuery[str]):
"""Compare items using bare ASCII, without accents etc."""
+
@classmethod
def string_match(cls, pattern, val):
"""Convert both pattern and string to plain ASCII before matching.
@@ -42,27 +44,40 @@ class BareascQuery(StringFieldQuery):
val = unidecode(val)
return pattern in val
+ def col_clause(self):
+ """Compare ascii version of the pattern."""
+ clause = f"unidecode({self.field})"
+ if self.pattern.islower():
+ clause = f"lower({clause})"
+
+ return rf"{clause} LIKE ? ESCAPE '\'", [f"%{unidecode(self.pattern)}%"]
+
class BareascPlugin(BeetsPlugin):
"""Plugin to provide bare-ASCII option for beets matching."""
+
def __init__(self):
"""Default prefix for selecting bare-ASCII matching is #."""
super().__init__()
- self.config.add({
- 'prefix': '#',
- })
+ self.config.add(
+ {
+ "prefix": "#",
+ }
+ )
def queries(self):
"""Register bare-ASCII matching."""
- prefix = self.config['prefix'].as_str()
+ prefix = self.config["prefix"].as_str()
return {prefix: BareascQuery}
def commands(self):
"""Add bareasc command as unidecode version of 'list'."""
- cmd = ui.Subcommand('bareasc',
- help='unidecode version of beet list command')
- cmd.parser.usage += "\n" \
- 'Example: %prog -f \'$album: $title\' artist:beatles'
+ cmd = ui.Subcommand(
+ "bareasc", help="unidecode version of beet list command"
+ )
+ cmd.parser.usage += (
+ "\n" "Example: %prog -f '$album: $title' artist:beatles"
+ )
cmd.parser.add_all_common_options()
cmd.func = self.unidecode_list
return [cmd]
diff --git a/lib/beetsplug/beatport.py b/lib/beetsplug/beatport.py
index 133441d7..6108b039 100644
--- a/lib/beetsplug/beatport.py
+++ b/lib/beetsplug/beatport.py
@@ -19,19 +19,22 @@ import json
import re
from datetime import datetime, timedelta
+import confuse
from requests_oauthlib import OAuth1Session
-from requests_oauthlib.oauth1_session import (TokenRequestDenied, TokenMissing,
- VerifierMissing)
+from requests_oauthlib.oauth1_session import (
+ TokenMissing,
+ TokenRequestDenied,
+ VerifierMissing,
+)
import beets
import beets.ui
from beets.autotag.hooks import AlbumInfo, TrackInfo
from beets.plugins import BeetsPlugin, MetadataSourcePlugin, get_distance
-import confuse
-
+from beets.util.id_extractors import beatport_id_regex
AUTH_ERRORS = (TokenRequestDenied, TokenMissing, VerifierMissing)
-USER_AGENT = f'beets/{beets.__version__} +https://beets.io/'
+USER_AGENT = f"beets/{beets.__version__} +https://beets.io/"
class BeatportAPIError(Exception):
@@ -40,24 +43,23 @@ class BeatportAPIError(Exception):
class BeatportObject:
def __init__(self, data):
- self.beatport_id = data['id']
- self.name = str(data['name'])
- if 'releaseDate' in data:
- self.release_date = datetime.strptime(data['releaseDate'],
- '%Y-%m-%d')
- if 'artists' in data:
- self.artists = [(x['id'], str(x['name']))
- for x in data['artists']]
- if 'genres' in data:
- self.genres = [str(x['name'])
- for x in data['genres']]
+ self.beatport_id = data["id"]
+ self.name = str(data["name"])
+ if "releaseDate" in data:
+ self.release_date = datetime.strptime(
+ data["releaseDate"], "%Y-%m-%d"
+ )
+ if "artists" in data:
+ self.artists = [(x["id"], str(x["name"])) for x in data["artists"]]
+ if "genres" in data:
+ self.genres = [str(x["name"]) for x in data["genres"]]
class BeatportClient:
- _api_base = 'https://oauth-api.beatport.com'
+ _api_base = "https://oauth-api.beatport.com"
def __init__(self, c_key, c_secret, auth_key=None, auth_secret=None):
- """ Initiate the client with OAuth information.
+ """Initiate the client with OAuth information.
For the initial authentication with the backend `auth_key` and
`auth_secret` can be `None`. Use `get_authorize_url` and
@@ -69,14 +71,16 @@ class BeatportClient:
:param auth_secret: OAuth1 resource owner secret
"""
self.api = OAuth1Session(
- client_key=c_key, client_secret=c_secret,
+ client_key=c_key,
+ client_secret=c_secret,
resource_owner_key=auth_key,
resource_owner_secret=auth_secret,
- callback_uri='oob')
- self.api.headers = {'User-Agent': USER_AGENT}
+ callback_uri="oob",
+ )
+ self.api.headers = {"User-Agent": USER_AGENT}
def get_authorize_url(self):
- """ Generate the URL for the user to authorize the application.
+ """Generate the URL for the user to authorize the application.
Retrieves a request token from the Beatport API and returns the
corresponding authorization URL on their end that the user has
@@ -91,12 +95,14 @@ class BeatportClient:
:rtype: unicode
"""
self.api.fetch_request_token(
- self._make_url('/identity/1/oauth/request-token'))
+ self._make_url("/identity/1/oauth/request-token")
+ )
return self.api.authorization_url(
- self._make_url('/identity/1/oauth/authorize'))
+ self._make_url("/identity/1/oauth/authorize")
+ )
def get_access_token(self, auth_data):
- """ Obtain the final access token and secret for the API.
+ """Obtain the final access token and secret for the API.
:param auth_data: URL-encoded authorization data as displayed at
the authorization url (obtained via
@@ -106,13 +112,15 @@ class BeatportClient:
:rtype: (unicode, unicode) tuple
"""
self.api.parse_authorization_response(
- "https://beets.io/auth?" + auth_data)
+ "https://beets.io/auth?" + auth_data
+ )
access_data = self.api.fetch_access_token(
- self._make_url('/identity/1/oauth/access-token'))
- return access_data['oauth_token'], access_data['oauth_token_secret']
+ self._make_url("/identity/1/oauth/access-token")
+ )
+ return access_data["oauth_token"], access_data["oauth_token_secret"]
- def search(self, query, release_type='release', details=True):
- """ Perform a search of the Beatport catalogue.
+ def search(self, query, release_type="release", details=True):
+ """Perform a search of the Beatport catalogue.
:param query: Query string
:param release_type: Type of releases to search for, can be
@@ -126,27 +134,30 @@ class BeatportClient:
py:class:`BeatportRelease` or
:py:class:`BeatportTrack`
"""
- response = self._get('catalog/3/search',
- query=query, perPage=5,
- facets=[f'fieldType:{release_type}'])
+ response = self._get(
+ "catalog/3/search",
+ query=query,
+ perPage=5,
+ facets=[f"fieldType:{release_type}"],
+ )
for item in response:
- if release_type == 'release':
+ if release_type == "release":
if details:
- release = self.get_release(item['id'])
+ release = self.get_release(item["id"])
else:
release = BeatportRelease(item)
yield release
- elif release_type == 'track':
+ elif release_type == "track":
yield BeatportTrack(item)
def get_release(self, beatport_id):
- """ Get information about a single release.
+ """Get information about a single release.
:param beatport_id: Beatport ID of the release
:returns: The matching release
:rtype: :py:class:`BeatportRelease`
"""
- response = self._get('/catalog/3/releases', id=beatport_id)
+ response = self._get("/catalog/3/releases", id=beatport_id)
if response:
release = BeatportRelease(response[0])
release.tracks = self.get_release_tracks(beatport_id)
@@ -154,34 +165,35 @@ class BeatportClient:
return None
def get_release_tracks(self, beatport_id):
- """ Get all tracks for a given release.
+ """Get all tracks for a given release.
:param beatport_id: Beatport ID of the release
:returns: Tracks in the matching release
:rtype: list of :py:class:`BeatportTrack`
"""
- response = self._get('/catalog/3/tracks', releaseId=beatport_id,
- perPage=100)
+ response = self._get(
+ "/catalog/3/tracks", releaseId=beatport_id, perPage=100
+ )
return [BeatportTrack(t) for t in response]
def get_track(self, beatport_id):
- """ Get information about a single track.
+ """Get information about a single track.
:param beatport_id: Beatport ID of the track
:returns: The matching track
:rtype: :py:class:`BeatportTrack`
"""
- response = self._get('/catalog/3/tracks', id=beatport_id)
+ response = self._get("/catalog/3/tracks", id=beatport_id)
return BeatportTrack(response[0])
def _make_url(self, endpoint):
- """ Get complete URL for a given API endpoint. """
- if not endpoint.startswith('/'):
- endpoint = '/' + endpoint
+ """Get complete URL for a given API endpoint."""
+ if not endpoint.startswith("/"):
+ endpoint = "/" + endpoint
return self._api_base + endpoint
def _get(self, endpoint, **kwargs):
- """ Perform a GET request on a given API endpoint.
+ """Perform a GET request on a given API endpoint.
Automatically extracts result data from the response and converts HTTP
exceptions into :py:class:`BeatportAPIError` objects.
@@ -189,13 +201,16 @@ class BeatportClient:
try:
response = self.api.get(self._make_url(endpoint), params=kwargs)
except Exception as e:
- raise BeatportAPIError("Error connecting to Beatport API: {}"
- .format(e))
+ raise BeatportAPIError(
+ "Error connecting to Beatport API: {}".format(e)
+ )
if not response:
raise BeatportAPIError(
- "Error {0.status_code} for '{0.request.path_url}"
- .format(response))
- return response.json()['results']
+ "Error {0.status_code} for '{0.request.path_url}".format(
+ response
+ )
+ )
+ return response.json()["results"]
class BeatportRelease(BeatportObject):
@@ -211,79 +226,83 @@ class BeatportRelease(BeatportObject):
)
def __repr__(self):
- return str(self).encode('utf-8')
+ return str(self).encode("utf-8")
def __init__(self, data):
BeatportObject.__init__(self, data)
- if 'catalogNumber' in data:
- self.catalog_number = data['catalogNumber']
- if 'label' in data:
- self.label_name = data['label']['name']
- if 'category' in data:
- self.category = data['category']
- if 'slug' in data:
+ if "catalogNumber" in data:
+ self.catalog_number = data["catalogNumber"]
+ if "label" in data:
+ self.label_name = data["label"]["name"]
+ if "category" in data:
+ self.category = data["category"]
+ if "slug" in data:
self.url = "https://beatport.com/release/{}/{}".format(
- data['slug'], data['id'])
- self.genre = data.get('genre')
+ data["slug"], data["id"]
+ )
+ self.genre = data.get("genre")
class BeatportTrack(BeatportObject):
def __str__(self):
artist_str = ", ".join(x[1] for x in self.artists)
- return (""
- .format(artist_str, self.name, self.mix_name))
+ return "".format(
+ artist_str, self.name, self.mix_name
+ )
def __repr__(self):
- return str(self).encode('utf-8')
+ return str(self).encode("utf-8")
def __init__(self, data):
BeatportObject.__init__(self, data)
- if 'title' in data:
- self.title = str(data['title'])
- if 'mixName' in data:
- self.mix_name = str(data['mixName'])
- self.length = timedelta(milliseconds=data.get('lengthMs', 0) or 0)
+ if "title" in data:
+ self.title = str(data["title"])
+ if "mixName" in data:
+ self.mix_name = str(data["mixName"])
+ self.length = timedelta(milliseconds=data.get("lengthMs", 0) or 0)
if not self.length:
try:
- min, sec = data.get('length', '0:0').split(':')
+ min, sec = data.get("length", "0:0").split(":")
self.length = timedelta(minutes=int(min), seconds=int(sec))
except ValueError:
pass
- if 'slug' in data:
- self.url = "https://beatport.com/track/{}/{}" \
- .format(data['slug'], data['id'])
- self.track_number = data.get('trackNumber')
- self.bpm = data.get('bpm')
- self.initial_key = str(
- (data.get('key') or {}).get('shortName')
- )
+ if "slug" in data:
+ self.url = "https://beatport.com/track/{}/{}".format(
+ data["slug"], data["id"]
+ )
+ self.track_number = data.get("trackNumber")
+ self.bpm = data.get("bpm")
+ self.initial_key = str((data.get("key") or {}).get("shortName"))
# Use 'subgenre' and if not present, 'genre' as a fallback.
- if data.get('subGenres'):
- self.genre = str(data['subGenres'][0].get('name'))
- elif data.get('genres'):
- self.genre = str(data['genres'][0].get('name'))
+ if data.get("subGenres"):
+ self.genre = str(data["subGenres"][0].get("name"))
+ elif data.get("genres"):
+ self.genre = str(data["genres"][0].get("name"))
class BeatportPlugin(BeetsPlugin):
- data_source = 'Beatport'
+ data_source = "Beatport"
+ id_regex = beatport_id_regex
def __init__(self):
super().__init__()
- self.config.add({
- 'apikey': '57713c3906af6f5def151b33601389176b37b429',
- 'apisecret': 'b3fe08c93c80aefd749fe871a16cd2bb32e2b954',
- 'tokenfile': 'beatport_token.json',
- 'source_weight': 0.5,
- })
- self.config['apikey'].redact = True
- self.config['apisecret'].redact = True
+ self.config.add(
+ {
+ "apikey": "57713c3906af6f5def151b33601389176b37b429",
+ "apisecret": "b3fe08c93c80aefd749fe871a16cd2bb32e2b954",
+ "tokenfile": "beatport_token.json",
+ "source_weight": 0.5,
+ }
+ )
+ self.config["apikey"].redact = True
+ self.config["apisecret"].redact = True
self.client = None
- self.register_listener('import_begin', self.setup)
+ self.register_listener("import_begin", self.setup)
def setup(self, session=None):
- c_key = self.config['apikey'].as_str()
- c_secret = self.config['apisecret'].as_str()
+ c_key = self.config["apikey"].as_str()
+ c_secret = self.config["apisecret"].as_str()
# Get the OAuth token from a file or log in.
try:
@@ -293,8 +312,8 @@ class BeatportPlugin(BeetsPlugin):
# No token yet. Generate one.
token, secret = self.authenticate(c_key, c_secret)
else:
- token = tokendata['token']
- secret = tokendata['secret']
+ token = tokendata["token"]
+ secret = tokendata["secret"]
self.client = BeatportClient(c_key, c_secret, token, secret)
@@ -304,8 +323,8 @@ class BeatportPlugin(BeetsPlugin):
try:
url = auth_client.get_authorize_url()
except AUTH_ERRORS as e:
- self._log.debug('authentication error: {0}', e)
- raise beets.ui.UserError('communication with Beatport failed')
+ self._log.debug("authentication error: {0}", e)
+ raise beets.ui.UserError("communication with Beatport failed")
beets.ui.print_("To authenticate with Beatport, visit:")
beets.ui.print_(url)
@@ -315,29 +334,26 @@ class BeatportPlugin(BeetsPlugin):
try:
token, secret = auth_client.get_access_token(data)
except AUTH_ERRORS as e:
- self._log.debug('authentication error: {0}', e)
- raise beets.ui.UserError('Beatport token request failed')
+ self._log.debug("authentication error: {0}", e)
+ raise beets.ui.UserError("Beatport token request failed")
# Save the token for later use.
- self._log.debug('Beatport token {0}, secret {1}', token, secret)
- with open(self._tokenfile(), 'w') as f:
- json.dump({'token': token, 'secret': secret}, f)
+ self._log.debug("Beatport token {0}, secret {1}", token, secret)
+ with open(self._tokenfile(), "w") as f:
+ json.dump({"token": token, "secret": secret}, f)
return token, secret
def _tokenfile(self):
- """Get the path to the JSON file for storing the OAuth token.
- """
- return self.config['tokenfile'].get(confuse.Filename(in_app_dir=True))
+ """Get the path to the JSON file for storing the OAuth token."""
+ return self.config["tokenfile"].get(confuse.Filename(in_app_dir=True))
def album_distance(self, items, album_info, mapping):
"""Returns the Beatport source weight and the maximum source weight
for albums.
"""
return get_distance(
- data_source=self.data_source,
- info=album_info,
- config=self.config
+ data_source=self.data_source, info=album_info, config=self.config
)
def track_distance(self, item, track_info):
@@ -345,9 +361,7 @@ class BeatportPlugin(BeetsPlugin):
for individual tracks.
"""
return get_distance(
- data_source=self.data_source,
- info=track_info,
- config=self.config
+ data_source=self.data_source, info=track_info, config=self.config
)
def candidates(self, items, artist, release, va_likely, extra_tags=None):
@@ -357,34 +371,36 @@ class BeatportPlugin(BeetsPlugin):
if va_likely:
query = release
else:
- query = f'{artist} {release}'
+ query = f"{artist} {release}"
try:
return self._get_releases(query)
except BeatportAPIError as e:
- self._log.debug('API Error: {0} (query: {1})', e, query)
+ self._log.debug("API Error: {0} (query: {1})", e, query)
return []
def item_candidates(self, item, artist, title):
"""Returns a list of TrackInfo objects for beatport search results
matching title and artist.
"""
- query = f'{artist} {title}'
+ query = f"{artist} {title}"
try:
return self._get_tracks(query)
except BeatportAPIError as e:
- self._log.debug('API Error: {0} (query: {1})', e, query)
+ self._log.debug("API Error: {0} (query: {1})", e, query)
return []
def album_for_id(self, release_id):
"""Fetches a release by its Beatport ID and returns an AlbumInfo object
or None if the query is not a valid ID or release is not found.
"""
- self._log.debug('Searching for release {0}', release_id)
- match = re.search(r'(^|beatport\.com/release/.+/)(\d+)$', release_id)
- if not match:
- self._log.debug('Not a valid Beatport release ID.')
+ self._log.debug("Searching for release {0}", release_id)
+
+ release_id = self._get_id("album", release_id, self.id_regex)
+ if release_id is None:
+ self._log.debug("Not a valid Beatport release ID.")
return None
- release = self.client.get_release(match.group(2))
+
+ release = self.client.get_release(release_id)
if release:
return self._get_album_info(release)
return None
@@ -393,10 +409,10 @@ class BeatportPlugin(BeetsPlugin):
"""Fetches a track by its Beatport ID and returns a TrackInfo object
or None if the track is not a valid Beatport ID or track is not found.
"""
- self._log.debug('Searching for track {0}', track_id)
- match = re.search(r'(^|beatport\.com/track/.+/)(\d+)$', track_id)
+ self._log.debug("Searching for track {0}", track_id)
+ match = re.search(r"(^|beatport\.com/track/.+/)(\d+)$", track_id)
if not match:
- self._log.debug('Not a valid Beatport track ID.')
+ self._log.debug("Not a valid Beatport track ID.")
return None
bp_track = self.client.get_track(match.group(2))
if bp_track is not None:
@@ -404,55 +420,67 @@ class BeatportPlugin(BeetsPlugin):
return None
def _get_releases(self, query):
- """Returns a list of AlbumInfo objects for a beatport search query.
- """
+ """Returns a list of AlbumInfo objects for a beatport search query."""
# Strip non-word characters from query. Things like "!" and "-" can
# cause a query to return no results, even if they match the artist or
# album title. Use `re.UNICODE` flag to avoid stripping non-english
# word characters.
- query = re.sub(r'\W+', ' ', query, flags=re.UNICODE)
+ query = re.sub(r"\W+", " ", query, flags=re.UNICODE)
# Strip medium information from query, Things like "CD1" and "disk 1"
# can also negate an otherwise positive result.
- query = re.sub(r'\b(CD|disc)\s*\d+', '', query, flags=re.I)
- albums = [self._get_album_info(x)
- for x in self.client.search(query)]
+ query = re.sub(r"\b(CD|disc)\s*\d+", "", query, flags=re.I)
+ albums = [self._get_album_info(x) for x in self.client.search(query)]
return albums
def _get_album_info(self, release):
- """Returns an AlbumInfo object for a Beatport Release object.
- """
+ """Returns an AlbumInfo object for a Beatport Release object."""
va = len(release.artists) > 3
artist, artist_id = self._get_artist(release.artists)
if va:
artist = "Various Artists"
tracks = [self._get_track_info(x) for x in release.tracks]
- return AlbumInfo(album=release.name, album_id=release.beatport_id,
- artist=artist, artist_id=artist_id, tracks=tracks,
- albumtype=release.category, va=va,
- year=release.release_date.year,
- month=release.release_date.month,
- day=release.release_date.day,
- label=release.label_name,
- catalognum=release.catalog_number, media='Digital',
- data_source=self.data_source, data_url=release.url,
- genre=release.genre)
+ return AlbumInfo(
+ album=release.name,
+ album_id=release.beatport_id,
+ beatport_album_id=release.beatport_id,
+ artist=artist,
+ artist_id=artist_id,
+ tracks=tracks,
+ albumtype=release.category,
+ va=va,
+ year=release.release_date.year,
+ month=release.release_date.month,
+ day=release.release_date.day,
+ label=release.label_name,
+ catalognum=release.catalog_number,
+ media="Digital",
+ data_source=self.data_source,
+ data_url=release.url,
+ genre=release.genre,
+ )
def _get_track_info(self, track):
- """Returns a TrackInfo object for a Beatport Track object.
- """
+ """Returns a TrackInfo object for a Beatport Track object."""
title = track.name
if track.mix_name != "Original Mix":
title += f" ({track.mix_name})"
artist, artist_id = self._get_artist(track.artists)
length = track.length.total_seconds()
- return TrackInfo(title=title, track_id=track.beatport_id,
- artist=artist, artist_id=artist_id,
- length=length, index=track.track_number,
- medium_index=track.track_number,
- data_source=self.data_source, data_url=track.url,
- bpm=track.bpm, initial_key=track.initial_key,
- genre=track.genre)
+ return TrackInfo(
+ title=title,
+ track_id=track.beatport_id,
+ artist=artist,
+ artist_id=artist_id,
+ length=length,
+ index=track.track_number,
+ medium_index=track.track_number,
+ data_source=self.data_source,
+ data_url=track.url,
+ bpm=track.bpm,
+ initial_key=track.initial_key,
+ genre=track.genre,
+ )
def _get_artist(self, artists):
"""Returns an artist string (all artists) and an artist_id (the main
@@ -463,8 +491,7 @@ class BeatportPlugin(BeetsPlugin):
)
def _get_tracks(self, query):
- """Returns a list of TrackInfo objects for a Beatport query.
- """
- bp_tracks = self.client.search(query, release_type='track')
+ """Returns a list of TrackInfo objects for a Beatport query."""
+ bp_tracks = self.client.search(query, release_type="track")
tracks = [self._get_track_info(x) for x in bp_tracks]
return tracks
diff --git a/lib/beetsplug/bench.py b/lib/beetsplug/bench.py
index 6dffbdda..673b9b7c 100644
--- a/lib/beetsplug/bench.py
+++ b/lib/beetsplug/bench.py
@@ -16,17 +16,14 @@
"""
-from beets.plugins import BeetsPlugin
-from beets import ui
-from beets import vfs
-from beets import library
-from beets.util.functemplate import Template
-from beets.autotag import match
-from beets import plugins
-from beets import importer
import cProfile
import timeit
+from beets import importer, library, plugins, ui, vfs
+from beets.autotag import match
+from beets.plugins import BeetsPlugin
+from beets.util.functemplate import Template
+
def aunique_benchmark(lib, prof):
def _build_tree():
@@ -34,74 +31,103 @@ def aunique_benchmark(lib, prof):
# Measure path generation performance with %aunique{} included.
lib.path_formats = [
- (library.PF_KEY_DEFAULT,
- Template('$albumartist/$album%aunique{}/$track $title')),
+ (
+ library.PF_KEY_DEFAULT,
+ Template("$albumartist/$album%aunique{}/$track $title"),
+ ),
]
if prof:
- cProfile.runctx('_build_tree()', {}, {'_build_tree': _build_tree},
- 'paths.withaunique.prof')
+ cProfile.runctx(
+ "_build_tree()",
+ {},
+ {"_build_tree": _build_tree},
+ "paths.withaunique.prof",
+ )
else:
interval = timeit.timeit(_build_tree, number=1)
- print('With %aunique:', interval)
+ print("With %aunique:", interval)
# And with %aunique replaceed with a "cheap" no-op function.
lib.path_formats = [
- (library.PF_KEY_DEFAULT,
- Template('$albumartist/$album%lower{}/$track $title')),
+ (
+ library.PF_KEY_DEFAULT,
+ Template("$albumartist/$album%lower{}/$track $title"),
+ ),
]
if prof:
- cProfile.runctx('_build_tree()', {}, {'_build_tree': _build_tree},
- 'paths.withoutaunique.prof')
+ cProfile.runctx(
+ "_build_tree()",
+ {},
+ {"_build_tree": _build_tree},
+ "paths.withoutaunique.prof",
+ )
else:
interval = timeit.timeit(_build_tree, number=1)
- print('Without %aunique:', interval)
+ print("Without %aunique:", interval)
def match_benchmark(lib, prof, query=None, album_id=None):
# If no album ID is provided, we'll match against a suitably huge
# album.
if not album_id:
- album_id = '9c5c043e-bc69-4edb-81a4-1aaf9c81e6dc'
+ album_id = "9c5c043e-bc69-4edb-81a4-1aaf9c81e6dc"
# Get an album from the library to use as the source for the match.
items = lib.albums(query).get().items()
# Ensure fingerprinting is invoked (if enabled).
- plugins.send('import_task_start',
- task=importer.ImportTask(None, None, items),
- session=importer.ImportSession(lib, None, None, None))
+ plugins.send(
+ "import_task_start",
+ task=importer.ImportTask(None, None, items),
+ session=importer.ImportSession(lib, None, None, None),
+ )
# Run the match.
def _run_match():
match.tag_album(items, search_ids=[album_id])
+
if prof:
- cProfile.runctx('_run_match()', {}, {'_run_match': _run_match},
- 'match.prof')
+ cProfile.runctx(
+ "_run_match()", {}, {"_run_match": _run_match}, "match.prof"
+ )
else:
interval = timeit.timeit(_run_match, number=1)
- print('match duration:', interval)
+ print("match duration:", interval)
class BenchmarkPlugin(BeetsPlugin):
- """A plugin for performing some simple performance benchmarks.
- """
- def commands(self):
- aunique_bench_cmd = ui.Subcommand('bench_aunique',
- help='benchmark for %aunique{}')
- aunique_bench_cmd.parser.add_option('-p', '--profile',
- action='store_true', default=False,
- help='performance profiling')
- aunique_bench_cmd.func = lambda lib, opts, args: \
- aunique_benchmark(lib, opts.profile)
+ """A plugin for performing some simple performance benchmarks."""
- match_bench_cmd = ui.Subcommand('bench_match',
- help='benchmark for track matching')
- match_bench_cmd.parser.add_option('-p', '--profile',
- action='store_true', default=False,
- help='performance profiling')
- match_bench_cmd.parser.add_option('-i', '--id', default=None,
- help='album ID to match against')
- match_bench_cmd.func = lambda lib, opts, args: \
- match_benchmark(lib, opts.profile, ui.decargs(args), opts.id)
+ def commands(self):
+ aunique_bench_cmd = ui.Subcommand(
+ "bench_aunique", help="benchmark for %aunique{}"
+ )
+ aunique_bench_cmd.parser.add_option(
+ "-p",
+ "--profile",
+ action="store_true",
+ default=False,
+ help="performance profiling",
+ )
+ aunique_bench_cmd.func = lambda lib, opts, args: aunique_benchmark(
+ lib, opts.profile
+ )
+
+ match_bench_cmd = ui.Subcommand(
+ "bench_match", help="benchmark for track matching"
+ )
+ match_bench_cmd.parser.add_option(
+ "-p",
+ "--profile",
+ action="store_true",
+ default=False,
+ help="performance profiling",
+ )
+ match_bench_cmd.parser.add_option(
+ "-i", "--id", default=None, help="album ID to match against"
+ )
+ match_bench_cmd.func = lambda lib, opts, args: match_benchmark(
+ lib, opts.profile, ui.decargs(args), opts.id
+ )
return [aunique_bench_cmd, match_bench_cmd]
diff --git a/lib/beetsplug/bpd/__init__.py b/lib/beetsplug/bpd/__init__.py
index 07198b1b..a4cb4d29 100644
--- a/lib/beetsplug/bpd/__init__.py
+++ b/lib/beetsplug/bpd/__init__.py
@@ -18,35 +18,36 @@ use of the wide range of MPD clients.
"""
-import re
-import sys
-from string import Template
-import traceback
-import random
-import time
-import math
import inspect
+import math
+import random
+import re
import socket
+import sys
+import time
+import traceback
+from string import Template
+from typing import List
-import beets
-from beets.plugins import BeetsPlugin
-import beets.ui
-from beets import vfs
-from beets.util import bluelet
-from beets.library import Item
-from beets import dbcore
from mediafile import MediaFile
-PROTOCOL_VERSION = '0.16.0'
+import beets
+import beets.ui
+from beets import dbcore, vfs
+from beets.library import Item
+from beets.plugins import BeetsPlugin
+from beets.util import bluelet
+
+PROTOCOL_VERSION = "0.16.0"
BUFSIZE = 1024
-HELLO = 'OK MPD %s' % PROTOCOL_VERSION
-CLIST_BEGIN = 'command_list_begin'
-CLIST_VERBOSE_BEGIN = 'command_list_ok_begin'
-CLIST_END = 'command_list_end'
-RESP_OK = 'OK'
-RESP_CLIST_VERBOSE = 'list_OK'
-RESP_ERR = 'ACK'
+HELLO = "OK MPD %s" % PROTOCOL_VERSION
+CLIST_BEGIN = "command_list_begin"
+CLIST_VERBOSE_BEGIN = "command_list_ok_begin"
+CLIST_END = "command_list_end"
+RESP_OK = "OK"
+RESP_CLIST_VERBOSE = "list_OK"
+RESP_ERR = "ACK"
NEWLINE = "\n"
@@ -68,15 +69,28 @@ VOLUME_MAX = 100
SAFE_COMMANDS = (
# Commands that are available when unauthenticated.
- 'close', 'commands', 'notcommands', 'password', 'ping',
+ "close",
+ "commands",
+ "notcommands",
+ "password",
+ "ping",
)
# List of subsystems/events used by the `idle` command.
SUBSYSTEMS = [
- 'update', 'player', 'mixer', 'options', 'playlist', 'database',
+ "update",
+ "player",
+ "mixer",
+ "options",
+ "playlist",
+ "database",
# Related to unsupported commands:
- 'stored_playlist', 'output', 'subscription', 'sticker', 'message',
- 'partition',
+ "stored_playlist",
+ "output",
+ "subscription",
+ "sticker",
+ "message",
+ "partition",
]
ITEM_KEYS_WRITABLE = set(MediaFile.fields()).intersection(Item._fields.keys())
@@ -89,48 +103,53 @@ class NoGstreamerError(Exception):
# Error-handling, exceptions, parameter parsing.
+
class BPDError(Exception):
"""An error that should be exposed to the client to the BPD
server.
"""
- def __init__(self, code, message, cmd_name='', index=0):
+
+ def __init__(self, code, message, cmd_name="", index=0):
self.code = code
self.message = message
self.cmd_name = cmd_name
self.index = index
- template = Template('$resp [$code@$index] {$cmd_name} $message')
+ template = Template("$resp [$code@$index] {$cmd_name} $message")
def response(self):
"""Returns a string to be used as the response code for the
erring command.
"""
- return self.template.substitute({
- 'resp': RESP_ERR,
- 'code': self.code,
- 'index': self.index,
- 'cmd_name': self.cmd_name,
- 'message': self.message,
- })
+ return self.template.substitute(
+ {
+ "resp": RESP_ERR,
+ "code": self.code,
+ "index": self.index,
+ "cmd_name": self.cmd_name,
+ "message": self.message,
+ }
+ )
def make_bpd_error(s_code, s_message):
- """Create a BPDError subclass for a static code and message.
- """
+ """Create a BPDError subclass for a static code and message."""
class NewBPDError(BPDError):
code = s_code
message = s_message
- cmd_name = ''
+ cmd_name = ""
index = 0
def __init__(self):
pass
+
return NewBPDError
-ArgumentTypeError = make_bpd_error(ERROR_ARG, 'invalid type for argument')
-ArgumentIndexError = make_bpd_error(ERROR_ARG, 'argument out of range')
-ArgumentNotFoundError = make_bpd_error(ERROR_NO_EXIST, 'argument not found')
+
+ArgumentTypeError = make_bpd_error(ERROR_ARG, "invalid type for argument")
+ArgumentIndexError = make_bpd_error(ERROR_ARG, "argument out of range")
+ArgumentNotFoundError = make_bpd_error(ERROR_NO_EXIST, "argument not found")
def cast_arg(t, val):
@@ -140,7 +159,7 @@ def cast_arg(t, val):
If 't' is the special string 'intbool', attempts to cast first
to an int and then to a bool (i.e., 1=True, 0=False).
"""
- if t == 'intbool':
+ if t == "intbool":
return cast_arg(bool, cast_arg(int, val))
else:
try:
@@ -159,6 +178,7 @@ class BPDIdle(Exception):
"""Raised by a command to indicate the client wants to enter the idle state
and should be notified when a relevant event happens.
"""
+
def __init__(self, subsystems):
super().__init__()
self.subsystems = set(subsystems)
@@ -201,8 +221,8 @@ class BaseServer:
self.volume = VOLUME_MAX
self.crossfade = 0
self.mixrampdb = 0.0
- self.mixrampdelay = float('nan')
- self.replay_gain_mode = 'off'
+ self.mixrampdelay = float("nan")
+ self.replay_gain_mode = "off"
self.playlist = []
self.playlist_version = 0
self.current_index = -1
@@ -216,13 +236,11 @@ class BaseServer:
self.random_obj = random.Random()
def connect(self, conn):
- """A new client has connected.
- """
+ """A new client has connected."""
self.connections.add(conn)
def disconnect(self, conn):
- """Client has disconnected; clean up residual state.
- """
+ """Client has disconnected; clean up residual state."""
self.connections.remove(conn)
def run(self):
@@ -233,15 +251,20 @@ class BaseServer:
def start():
yield bluelet.spawn(
- bluelet.server(self.ctrl_host, self.ctrl_port,
- ControlConnection.handler(self)))
- yield bluelet.server(self.host, self.port,
- MPDConnection.handler(self))
+ bluelet.server(
+ self.ctrl_host,
+ self.ctrl_port,
+ ControlConnection.handler(self),
+ )
+ )
+ yield bluelet.server(
+ self.host, self.port, MPDConnection.handler(self)
+ )
+
bluelet.run(start())
def dispatch_events(self):
- """If any clients have idle events ready, send them.
- """
+ """If any clients have idle events ready, send them."""
# We need a copy of `self.connections` here since clients might
# disconnect once we try and send to them, changing `self.connections`.
for conn in list(self.connections):
@@ -255,7 +278,7 @@ class BaseServer:
if not self.ctrl_sock:
self.ctrl_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.ctrl_sock.connect((self.ctrl_host, self.ctrl_port))
- self.ctrl_sock.sendall((message + '\n').encode('utf-8'))
+ self.ctrl_sock.sendall((message + "\n").encode("utf-8"))
def _send_event(self, event):
"""Notify subscribed connections of an event."""
@@ -269,8 +292,7 @@ class BaseServer:
raise NotImplementedError
def _item_id(self, item):
- """An abstract method returning the integer id for an item.
- """
+ """An abstract method returning the integer id for an item."""
raise NotImplementedError
def _id_to_index(self, track_id):
@@ -326,8 +348,7 @@ class BaseServer:
subsystems = subsystems or SUBSYSTEMS
for system in subsystems:
if system not in SUBSYSTEMS:
- raise BPDError(ERROR_ARG,
- f'Unrecognised idle event: {system}')
+ raise BPDError(ERROR_ARG, f"Unrecognised idle event: {system}")
raise BPDIdle(subsystems) # put the connection into idle mode
def cmd_kill(self, conn):
@@ -344,30 +365,30 @@ class BaseServer:
conn.authenticated = True
else:
conn.authenticated = False
- raise BPDError(ERROR_PASSWORD, 'incorrect password')
+ raise BPDError(ERROR_PASSWORD, "incorrect password")
def cmd_commands(self, conn):
"""Lists the commands available to the user."""
if self.password and not conn.authenticated:
# Not authenticated. Show limited list of commands.
for cmd in SAFE_COMMANDS:
- yield 'command: ' + cmd
+ yield "command: " + cmd
else:
# Authenticated. Show all commands.
for func in dir(self):
- if func.startswith('cmd_'):
- yield 'command: ' + func[4:]
+ if func.startswith("cmd_"):
+ yield "command: " + func[4:]
def cmd_notcommands(self, conn):
"""Lists all unavailable commands."""
if self.password and not conn.authenticated:
# Not authenticated. Show privileged commands.
for func in dir(self):
- if func.startswith('cmd_'):
+ if func.startswith("cmd_"):
cmd = func[4:]
if cmd not in SAFE_COMMANDS:
- yield 'command: ' + cmd
+ yield "command: " + cmd
else:
# Authenticated. No commands are unavailable.
@@ -381,43 +402,43 @@ class BaseServer:
playlist, playlistlength, and xfade.
"""
yield (
- 'repeat: ' + str(int(self.repeat)),
- 'random: ' + str(int(self.random)),
- 'consume: ' + str(int(self.consume)),
- 'single: ' + str(int(self.single)),
- 'playlist: ' + str(self.playlist_version),
- 'playlistlength: ' + str(len(self.playlist)),
- 'mixrampdb: ' + str(self.mixrampdb),
+ "repeat: " + str(int(self.repeat)),
+ "random: " + str(int(self.random)),
+ "consume: " + str(int(self.consume)),
+ "single: " + str(int(self.single)),
+ "playlist: " + str(self.playlist_version),
+ "playlistlength: " + str(len(self.playlist)),
+ "mixrampdb: " + str(self.mixrampdb),
)
if self.volume > 0:
- yield 'volume: ' + str(self.volume)
+ yield "volume: " + str(self.volume)
if not math.isnan(self.mixrampdelay):
- yield 'mixrampdelay: ' + str(self.mixrampdelay)
+ yield "mixrampdelay: " + str(self.mixrampdelay)
if self.crossfade > 0:
- yield 'xfade: ' + str(self.crossfade)
+ yield "xfade: " + str(self.crossfade)
if self.current_index == -1:
- state = 'stop'
+ state = "stop"
elif self.paused:
- state = 'pause'
+ state = "pause"
else:
- state = 'play'
- yield 'state: ' + state
+ state = "play"
+ yield "state: " + state
if self.current_index != -1: # i.e., paused or playing
current_id = self._item_id(self.playlist[self.current_index])
- yield 'song: ' + str(self.current_index)
- yield 'songid: ' + str(current_id)
+ yield "song: " + str(self.current_index)
+ yield "songid: " + str(current_id)
if len(self.playlist) > self.current_index + 1:
# If there's a next song, report its index too.
next_id = self._item_id(self.playlist[self.current_index + 1])
- yield 'nextsong: ' + str(self.current_index + 1)
- yield 'nextsongid: ' + str(next_id)
+ yield "nextsong: " + str(self.current_index + 1)
+ yield "nextsongid: " + str(next_id)
if self.error:
- yield 'error: ' + self.error
+ yield "error: " + self.error
def cmd_clearerror(self, conn):
"""Removes the persistent error state of the server. This
@@ -428,32 +449,32 @@ class BaseServer:
def cmd_random(self, conn, state):
"""Set or unset random (shuffle) mode."""
- self.random = cast_arg('intbool', state)
- self._send_event('options')
+ self.random = cast_arg("intbool", state)
+ self._send_event("options")
def cmd_repeat(self, conn, state):
"""Set or unset repeat mode."""
- self.repeat = cast_arg('intbool', state)
- self._send_event('options')
+ self.repeat = cast_arg("intbool", state)
+ self._send_event("options")
def cmd_consume(self, conn, state):
"""Set or unset consume mode."""
- self.consume = cast_arg('intbool', state)
- self._send_event('options')
+ self.consume = cast_arg("intbool", state)
+ self._send_event("options")
def cmd_single(self, conn, state):
"""Set or unset single mode."""
# TODO support oneshot in addition to 0 and 1 [MPD 0.20]
- self.single = cast_arg('intbool', state)
- self._send_event('options')
+ self.single = cast_arg("intbool", state)
+ self._send_event("options")
def cmd_setvol(self, conn, vol):
"""Set the player's volume level (0-100)."""
vol = cast_arg(int, vol)
if vol < VOLUME_MIN or vol > VOLUME_MAX:
- raise BPDError(ERROR_ARG, 'volume out of range')
+ raise BPDError(ERROR_ARG, "volume out of range")
self.volume = vol
- self._send_event('mixer')
+ self._send_event("mixer")
def cmd_volume(self, conn, vol_delta):
"""Deprecated command to change the volume by a relative amount."""
@@ -464,53 +485,53 @@ class BaseServer:
"""Set the number of seconds of crossfading."""
crossfade = cast_arg(int, crossfade)
if crossfade < 0:
- raise BPDError(ERROR_ARG, 'crossfade time must be nonnegative')
- self._log.warning('crossfade is not implemented in bpd')
+ raise BPDError(ERROR_ARG, "crossfade time must be nonnegative")
+ self._log.warning("crossfade is not implemented in bpd")
self.crossfade = crossfade
- self._send_event('options')
+ self._send_event("options")
def cmd_mixrampdb(self, conn, db):
"""Set the mixramp normalised max volume in dB."""
db = cast_arg(float, db)
if db > 0:
- raise BPDError(ERROR_ARG, 'mixrampdb time must be negative')
- self._log.warning('mixramp is not implemented in bpd')
+ raise BPDError(ERROR_ARG, "mixrampdb time must be negative")
+ self._log.warning("mixramp is not implemented in bpd")
self.mixrampdb = db
- self._send_event('options')
+ self._send_event("options")
def cmd_mixrampdelay(self, conn, delay):
"""Set the mixramp delay in seconds."""
delay = cast_arg(float, delay)
if delay < 0:
- raise BPDError(ERROR_ARG, 'mixrampdelay time must be nonnegative')
- self._log.warning('mixramp is not implemented in bpd')
+ raise BPDError(ERROR_ARG, "mixrampdelay time must be nonnegative")
+ self._log.warning("mixramp is not implemented in bpd")
self.mixrampdelay = delay
- self._send_event('options')
+ self._send_event("options")
def cmd_replay_gain_mode(self, conn, mode):
"""Set the replay gain mode."""
- if mode not in ['off', 'track', 'album', 'auto']:
- raise BPDError(ERROR_ARG, 'Unrecognised replay gain mode')
- self._log.warning('replay gain is not implemented in bpd')
+ if mode not in ["off", "track", "album", "auto"]:
+ raise BPDError(ERROR_ARG, "Unrecognised replay gain mode")
+ self._log.warning("replay gain is not implemented in bpd")
self.replay_gain_mode = mode
- self._send_event('options')
+ self._send_event("options")
def cmd_replay_gain_status(self, conn):
"""Get the replaygain mode."""
- yield 'replay_gain_mode: ' + str(self.replay_gain_mode)
+ yield "replay_gain_mode: " + str(self.replay_gain_mode)
def cmd_clear(self, conn):
"""Clear the playlist."""
self.playlist = []
self.playlist_version += 1
self.cmd_stop(conn)
- self._send_event('playlist')
+ self._send_event("playlist")
def cmd_delete(self, conn, index):
"""Remove the song at index from the playlist."""
index = cast_arg(int, index)
try:
- del(self.playlist[index])
+ del self.playlist[index]
except IndexError:
raise ArgumentIndexError()
self.playlist_version += 1
@@ -520,7 +541,7 @@ class BaseServer:
elif index < self.current_index: # Deleted before playing.
# Shift playing index down.
self.current_index -= 1
- self._send_event('playlist')
+ self._send_event("playlist")
def cmd_deleteid(self, conn, track_id):
self.cmd_delete(conn, self._id_to_index(track_id))
@@ -544,7 +565,7 @@ class BaseServer:
self.current_index += 1
self.playlist_version += 1
- self._send_event('playlist')
+ self._send_event("playlist")
def cmd_moveid(self, conn, idx_from, idx_to):
idx_from = self._id_to_index(idx_from)
@@ -570,7 +591,7 @@ class BaseServer:
self.current_index = i
self.playlist_version += 1
- self._send_event('playlist')
+ self._send_event("playlist")
def cmd_swapid(self, conn, i_id, j_id):
i = self._id_to_index(i_id)
@@ -618,12 +639,11 @@ class BaseServer:
Also a dummy implementation.
"""
for idx, track in enumerate(self.playlist):
- yield 'cpos: ' + str(idx)
- yield 'Id: ' + str(track.id)
+ yield "cpos: " + str(idx)
+ yield "Id: " + str(track.id)
def cmd_currentsong(self, conn):
- """Sends information about the currently-playing song.
- """
+ """Sends information about the currently-playing song."""
if self.current_index != -1: # -1 means stopped.
track = self.playlist[self.current_index]
yield self._item_info(track)
@@ -668,8 +688,8 @@ class BaseServer:
if state is None:
self.paused = not self.paused # Toggle.
else:
- self.paused = cast_arg('intbool', state)
- self._send_event('player')
+ self.paused = cast_arg("intbool", state)
+ self._send_event("player")
def cmd_play(self, conn, index=-1):
"""Begin playback, possibly at a specified playlist index."""
@@ -689,7 +709,7 @@ class BaseServer:
self.current_index = index
self.paused = False
- self._send_event('player')
+ self._send_event("player")
def cmd_playid(self, conn, track_id=0):
track_id = cast_arg(int, track_id)
@@ -703,7 +723,7 @@ class BaseServer:
"""Stop playback."""
self.current_index = -1
self.paused = False
- self._send_event('player')
+ self._send_event("player")
def cmd_seek(self, conn, index, pos):
"""Seek to a specified point in a specified song."""
@@ -711,7 +731,7 @@ class BaseServer:
if index < 0 or index >= len(self.playlist):
raise ArgumentIndexError()
self.current_index = index
- self._send_event('player')
+ self._send_event("player")
def cmd_seekid(self, conn, track_id, pos):
index = self._id_to_index(track_id)
@@ -725,23 +745,21 @@ class BaseServer:
without crashing, and that this is not treated as ERROR_ARG (since it
is caused by a programming error, not a protocol error).
"""
- 'a' + 2
+ "a" + 2
class Connection:
- """A connection between a client and the server.
- """
+ """A connection between a client and the server."""
+
def __init__(self, server, sock):
- """Create a new connection for the accepted socket `client`.
- """
+ """Create a new connection for the accepted socket `client`."""
self.server = server
self.sock = sock
- self.address = '{}:{}'.format(*sock.sock.getpeername())
+ self.address = "{}:{}".format(*sock.sock.getpeername())
- def debug(self, message, kind=' '):
- """Log a debug message about this connection.
- """
- self.server._log.debug('{}[{}]: {}', kind, self.address, message)
+ def debug(self, message, kind=" "):
+ """Log a debug message about this connection."""
+ self.server._log.debug("{}[{}]: {}", kind, self.address, message)
def run(self):
pass
@@ -756,26 +774,25 @@ class Connection:
lines = [lines]
out = NEWLINE.join(lines) + NEWLINE
for l in out.split(NEWLINE)[:-1]:
- self.debug(l, kind='>')
+ self.debug(l, kind=">")
if isinstance(out, str):
- out = out.encode('utf-8')
+ out = out.encode("utf-8")
return self.sock.sendall(out)
@classmethod
def handler(cls, server):
def _handle(sock):
- """Creates a new `Connection` and runs it.
- """
+ """Creates a new `Connection` and runs it."""
return cls(server, sock).run()
+
return _handle
class MPDConnection(Connection):
- """A connection that receives commands from an MPD-compatible client.
- """
+ """A connection that receives commands from an MPD-compatible client."""
+
def __init__(self, server, sock):
- """Create a new connection for the accepted socket `client`.
- """
+ """Create a new connection for the accepted socket `client`."""
super().__init__(server, sock)
self.authenticated = False
self.notifications = set()
@@ -794,23 +811,20 @@ class MPDConnection(Connection):
yield self.send(RESP_OK)
def disconnect(self):
- """The connection has closed for any reason.
- """
+ """The connection has closed for any reason."""
self.server.disconnect(self)
- self.debug('disconnected', kind='*')
+ self.debug("disconnected", kind="*")
def notify(self, event):
- """Queue up an event for sending to this client.
- """
+ """Queue up an event for sending to this client."""
self.notifications.add(event)
def send_notifications(self, force_close_idle=False):
- """Send the client any queued events now.
- """
+ """Send the client any queued events now."""
pending = self.notifications.intersection(self.idle_subscriptions)
try:
for event in pending:
- yield self.send(f'changed: {event}')
+ yield self.send(f"changed: {event}")
if pending or force_close_idle:
self.idle_subscriptions = set()
self.notifications = self.notifications.difference(pending)
@@ -822,7 +836,7 @@ class MPDConnection(Connection):
"""Send a greeting to the client and begin processing commands
as they arrive.
"""
- self.debug('connected', kind='*')
+ self.debug("connected", kind="*")
self.server.connect(self)
yield self.send(HELLO)
@@ -834,25 +848,26 @@ class MPDConnection(Connection):
break
line = line.strip()
if not line:
- err = BPDError(ERROR_UNKNOWN, 'No command given')
+ err = BPDError(ERROR_UNKNOWN, "No command given")
yield self.send(err.response())
self.disconnect() # Client sent a blank line.
break
- line = line.decode('utf8') # MPD protocol uses UTF-8.
+ line = line.decode("utf8") # MPD protocol uses UTF-8.
for l in line.split(NEWLINE):
- self.debug(l, kind='<')
+ self.debug(l, kind="<")
if self.idle_subscriptions:
# The connection is in idle mode.
- if line == 'noidle':
+ if line == "noidle":
yield bluelet.call(self.send_notifications(True))
else:
- err = BPDError(ERROR_UNKNOWN,
- f'Got command while idle: {line}')
+ err = BPDError(
+ ERROR_UNKNOWN, f"Got command while idle: {line}"
+ )
yield self.send(err.response())
break
continue
- if line == 'noidle':
+ if line == "noidle":
# When not in idle, this command sends no response.
continue
@@ -880,26 +895,25 @@ class MPDConnection(Connection):
return
except BPDIdle as e:
self.idle_subscriptions = e.subsystems
- self.debug('awaiting: {}'.format(' '.join(e.subsystems)),
- kind='z')
+ self.debug(
+ "awaiting: {}".format(" ".join(e.subsystems)), kind="z"
+ )
yield bluelet.call(self.server.dispatch_events())
class ControlConnection(Connection):
- """A connection used to control BPD for debugging and internal events.
- """
+ """A connection used to control BPD for debugging and internal events."""
+
def __init__(self, server, sock):
- """Create a new connection for the accepted socket `client`.
- """
+ """Create a new connection for the accepted socket `client`."""
super().__init__(server, sock)
- def debug(self, message, kind=' '):
- self.server._log.debug('CTRL {}[{}]: {}', kind, self.address, message)
+ def debug(self, message, kind=" "):
+ self.server._log.debug("CTRL {}[{}]: {}", kind, self.address, message)
def run(self):
- """Listen for control commands and delegate to `ctrl_*` methods.
- """
- self.debug('connected', kind='*')
+ """Listen for control commands and delegate to `ctrl_*` methods."""
+ self.debug("connected", kind="*")
while True:
line = yield self.sock.readline()
if not line:
@@ -907,47 +921,45 @@ class ControlConnection(Connection):
line = line.strip()
if not line:
break # Client sent a blank line.
- line = line.decode('utf8') # Protocol uses UTF-8.
+ line = line.decode("utf8") # Protocol uses UTF-8.
for l in line.split(NEWLINE):
- self.debug(l, kind='<')
+ self.debug(l, kind="<")
command = Command(line)
try:
- func = command.delegate('ctrl_', self)
+ func = command.delegate("ctrl_", self)
yield bluelet.call(func(*command.args))
except (AttributeError, TypeError) as e:
- yield self.send('ERROR: {}'.format(e.args[0]))
+ yield self.send("ERROR: {}".format(e.args[0]))
except Exception:
- yield self.send(['ERROR: server error',
- traceback.format_exc().rstrip()])
+ yield self.send(
+ ["ERROR: server error", traceback.format_exc().rstrip()]
+ )
def ctrl_play_finished(self):
- """Callback from the player signalling a song finished playing.
- """
+ """Callback from the player signalling a song finished playing."""
yield bluelet.call(self.server.dispatch_events())
def ctrl_profile(self):
- """Memory profiling for debugging.
- """
+ """Memory profiling for debugging."""
from guppy import hpy
+
heap = hpy().heap()
yield self.send(heap)
def ctrl_nickname(self, oldlabel, newlabel):
- """Rename a client in the log messages.
- """
+ """Rename a client in the log messages."""
for c in self.server.connections:
if c.address == oldlabel:
c.address = newlabel
break
else:
- yield self.send(f'ERROR: no such client: {oldlabel}')
+ yield self.send(f"ERROR: no such client: {oldlabel}")
class Command:
- """A command issued by the client for processing by the server.
- """
+ """A command issued by the client for processing by the server."""
- command_re = re.compile(r'^([^ \t]+)[ \t]*')
+ command_re = re.compile(r"^([^ \t]+)[ \t]*")
arg_re = re.compile(r'"((?:\\"|[^"])+)"|([^ \t"]+)')
def __init__(self, s):
@@ -958,12 +970,12 @@ class Command:
self.name = command_match.group(1)
self.args = []
- arg_matches = self.arg_re.findall(s[command_match.end():])
+ arg_matches = self.arg_re.findall(s[command_match.end() :])
for match in arg_matches:
if match[0]:
# Quoted argument.
arg = match[0]
- arg = arg.replace('\\"', '"').replace('\\\\', '\\')
+ arg = arg.replace('\\"', '"').replace("\\\\", "\\")
else:
# Unquoted argument.
arg = match[1]
@@ -994,8 +1006,10 @@ class Command:
wrong_num = (len(self.args) > max_args) or (len(self.args) < min_args)
# If the command accepts a variable number of arguments skip the check.
if wrong_num and not argspec.varargs:
- raise TypeError('wrong number of arguments for "{}"'
- .format(self.name), self.name)
+ raise TypeError(
+ 'wrong number of arguments for "{}"'.format(self.name),
+ self.name,
+ )
return func
@@ -1005,17 +1019,19 @@ class Command:
"""
try:
# `conn` is an extra argument to all cmd handlers.
- func = self.delegate('cmd_', conn.server, extra_args=1)
+ func = self.delegate("cmd_", conn.server, extra_args=1)
except AttributeError as e:
raise BPDError(ERROR_UNKNOWN, e.args[0])
except TypeError as e:
raise BPDError(ERROR_ARG, e.args[0], self.name)
# Ensure we have permission for this command.
- if conn.server.password and \
- not conn.authenticated and \
- self.name not in SAFE_COMMANDS:
- raise BPDError(ERROR_PERMISSION, 'insufficient privileges')
+ if (
+ conn.server.password
+ and not conn.authenticated
+ and self.name not in SAFE_COMMANDS
+ ):
+ raise BPDError(ERROR_PERMISSION, "insufficient privileges")
try:
args = [conn] + self.args
@@ -1040,11 +1056,11 @@ class Command:
except Exception:
# An "unintentional" error. Hide it from the client.
- conn.server._log.error('{}', traceback.format_exc())
- raise BPDError(ERROR_SYSTEM, 'server error', self.name)
+ conn.server._log.error("{}", traceback.format_exc())
+ raise BPDError(ERROR_SYSTEM, "server error", self.name)
-class CommandList(list):
+class CommandList(List[Command]):
"""A list of commands issued by the client for processing by the
server. May be verbose, in which case the response is delimited, or
not. Should be a list of `Command` objects.
@@ -1060,8 +1076,7 @@ class CommandList(list):
self.verbose = verbose
def run(self, conn):
- """Coroutine executing all the commands in this list.
- """
+ """Coroutine executing all the commands in this list."""
for i, command in enumerate(self):
try:
yield bluelet.call(command.run(conn))
@@ -1079,6 +1094,7 @@ class CommandList(list):
# A subclass of the basic, protocol-handling server that actually plays
# music.
+
class Server(BaseServer):
"""An MPD-compatible server using GStreamer to play audio and beets
to store its library.
@@ -1089,50 +1105,50 @@ class Server(BaseServer):
from beetsplug.bpd import gstplayer
except ImportError as e:
# This is a little hacky, but it's the best I know for now.
- if e.args[0].endswith(' gst'):
+ if e.args[0].endswith(" gst"):
raise NoGstreamerError()
else:
raise
- log.info('Starting server...')
+ log.info("Starting server...")
super().__init__(host, port, password, ctrl_port, log)
self.lib = library
self.player = gstplayer.GstPlayer(self.play_finished)
self.cmd_update(None)
- log.info('Server ready and listening on {}:{}'.format(
- host, port))
- log.debug('Listening for control signals on {}:{}'.format(
- host, ctrl_port))
+ log.info("Server ready and listening on {}:{}".format(host, port))
+ log.debug(
+ "Listening for control signals on {}:{}".format(host, ctrl_port)
+ )
def run(self):
self.player.run()
super().run()
def play_finished(self):
- """A callback invoked every time our player finishes a track.
- """
+ """A callback invoked every time our player finishes a track."""
self.cmd_next(None)
- self._ctrl_send('play_finished')
+ self._ctrl_send("play_finished")
# Metadata helper functions.
def _item_info(self, item):
info_lines = [
- 'file: ' + item.destination(fragment=True),
- 'Time: ' + str(int(item.length)),
- 'duration: ' + f'{item.length:.3f}',
- 'Id: ' + str(item.id),
+ "file: " + item.destination(fragment=True),
+ "Time: " + str(int(item.length)),
+ "duration: " + f"{item.length:.3f}",
+ "Id: " + str(item.id),
]
try:
pos = self._id_to_index(item.id)
- info_lines.append('Pos: ' + str(pos))
+ info_lines.append("Pos: " + str(pos))
except ArgumentNotFoundError:
# Don't include position if not in playlist.
pass
for tagtype, field in self.tagtype_map.items():
- info_lines.append('{}: {}'.format(
- tagtype, str(getattr(item, field))))
+ info_lines.append(
+ "{}: {}".format(tagtype, str(getattr(item, field)))
+ )
return info_lines
@@ -1142,11 +1158,11 @@ class Server(BaseServer):
commands. Sometimes a single number can be provided instead.
"""
try:
- start, stop = str(items).split(':', 1)
+ start, stop = str(items).split(":", 1)
except ValueError:
if accept_single_number:
return [cast_arg(int, items)]
- raise BPDError(ERROR_ARG, 'bad range syntax')
+ raise BPDError(ERROR_ARG, "bad range syntax")
start = cast_arg(int, start)
stop = cast_arg(int, stop)
return range(start, stop)
@@ -1156,17 +1172,16 @@ class Server(BaseServer):
# Database updating.
- def cmd_update(self, conn, path='/'):
- """Updates the catalog to reflect the current database state.
- """
+ def cmd_update(self, conn, path="/"):
+ """Updates the catalog to reflect the current database state."""
# Path is ignored. Also, the real MPD does this asynchronously;
# this is done inline.
- self._log.debug('Building directory tree...')
+ self._log.debug("Building directory tree...")
self.tree = vfs.libtree(self.lib)
- self._log.debug('Finished building directory tree.')
+ self._log.debug("Finished building directory tree.")
self.updated_time = time.time()
- self._send_event('update')
- self._send_event('database')
+ self._send_event("update")
+ self._send_event("database")
# Path (directory tree) browsing.
@@ -1174,7 +1189,7 @@ class Server(BaseServer):
"""Returns a VFS node or an item ID located at the path given.
If the path does not exist, raises a
"""
- components = path.split('/')
+ components = path.split("/")
node = self.tree
for component in components:
@@ -1196,15 +1211,15 @@ class Server(BaseServer):
def _path_join(self, p1, p2):
"""Smashes together two BPD paths."""
- out = p1 + '/' + p2
- return out.replace('//', '/').replace('//', '/')
+ out = p1 + "/" + p2
+ return out.replace("//", "/").replace("//", "/")
def cmd_lsinfo(self, conn, path="/"):
"""Sends info on all the items in the path."""
node = self._resolve_path(path)
if isinstance(node, int):
# Trying to list a track.
- raise BPDError(ERROR_ARG, 'this is not a directory')
+ raise BPDError(ERROR_ARG, "this is not a directory")
else:
for name, itemid in iter(sorted(node.files.items())):
item = self.lib.get_item(itemid)
@@ -1214,7 +1229,7 @@ class Server(BaseServer):
if dirpath.startswith("/"):
# Strip leading slash (libmpc rejects this).
dirpath = dirpath[1:]
- yield 'directory: %s' % dirpath
+ yield "directory: %s" % dirpath
def _listall(self, basepath, node, info=False):
"""Helper function for recursive listing. If info, show
@@ -1226,7 +1241,7 @@ class Server(BaseServer):
item = self.lib.get_item(node)
yield self._item_info(item)
else:
- yield 'file: ' + basepath
+ yield "file: " + basepath
else:
# List a directory. Recurse into both directories and files.
for name, itemid in sorted(node.files.items()):
@@ -1235,7 +1250,7 @@ class Server(BaseServer):
yield from self._listall(newpath, itemid, info)
for name, subdir in sorted(node.dirs.items()):
newpath = self._path_join(basepath, name)
- yield 'directory: ' + newpath
+ yield "directory: " + newpath
yield from self._listall(newpath, subdir, info)
def cmd_listall(self, conn, path="/"):
@@ -1249,8 +1264,7 @@ class Server(BaseServer):
# Playlist manipulation.
def _all_items(self, node):
- """Generator yielding all items under a VFS node.
- """
+ """Generator yielding all items under a VFS node."""
if isinstance(node, int):
# Could be more efficient if we built up all the IDs and
# then issued a single SELECT.
@@ -1270,9 +1284,9 @@ class Server(BaseServer):
for item in self._all_items(self._resolve_path(path)):
self.playlist.append(item)
if send_id:
- yield 'Id: ' + str(item.id)
+ yield "Id: " + str(item.id)
self.playlist_version += 1
- self._send_event('playlist')
+ self._send_event("playlist")
def cmd_add(self, conn, path):
"""Adds a track or directory to the playlist, specified by a
@@ -1292,8 +1306,8 @@ class Server(BaseServer):
item = self.playlist[self.current_index]
yield (
- 'bitrate: ' + str(item.bitrate / 1000),
- 'audio: {}:{}:{}'.format(
+ "bitrate: " + str(item.bitrate / 1000),
+ "audio: {}:{}:{}".format(
str(item.samplerate),
str(item.bitdepth),
str(item.channels),
@@ -1302,12 +1316,12 @@ class Server(BaseServer):
(pos, total) = self.player.time()
yield (
- 'time: {}:{}'.format(
+ "time: {}:{}".format(
str(int(pos)),
str(int(total)),
),
- 'elapsed: ' + f'{pos:.3f}',
- 'duration: ' + f'{total:.3f}',
+ "elapsed: " + f"{pos:.3f}",
+ "duration: " + f"{total:.3f}",
)
# Also missing 'updating_db'.
@@ -1315,55 +1329,57 @@ class Server(BaseServer):
def cmd_stats(self, conn):
"""Sends some statistics about the library."""
with self.lib.transaction() as tx:
- statement = 'SELECT COUNT(DISTINCT artist), ' \
- 'COUNT(DISTINCT album), ' \
- 'COUNT(id), ' \
- 'SUM(length) ' \
- 'FROM items'
+ statement = (
+ "SELECT COUNT(DISTINCT artist), "
+ "COUNT(DISTINCT album), "
+ "COUNT(id), "
+ "SUM(length) "
+ "FROM items"
+ )
artists, albums, songs, totaltime = tx.query(statement)[0]
yield (
- 'artists: ' + str(artists),
- 'albums: ' + str(albums),
- 'songs: ' + str(songs),
- 'uptime: ' + str(int(time.time() - self.startup_time)),
- 'playtime: ' + '0', # Missing.
- 'db_playtime: ' + str(int(totaltime)),
- 'db_update: ' + str(int(self.updated_time)),
+ "artists: " + str(artists),
+ "albums: " + str(albums),
+ "songs: " + str(songs),
+ "uptime: " + str(int(time.time() - self.startup_time)),
+ "playtime: " + "0", # Missing.
+ "db_playtime: " + str(int(totaltime)),
+ "db_update: " + str(int(self.updated_time)),
)
def cmd_decoders(self, conn):
"""Send list of supported decoders and formats."""
decoders = self.player.get_decoders()
for name, (mimes, exts) in decoders.items():
- yield f'plugin: {name}'
+ yield f"plugin: {name}"
for ext in exts:
- yield f'suffix: {ext}'
+ yield f"suffix: {ext}"
for mime in mimes:
- yield f'mime_type: {mime}'
+ yield f"mime_type: {mime}"
# Searching.
tagtype_map = {
- 'Artist': 'artist',
- 'ArtistSort': 'artist_sort',
- 'Album': 'album',
- 'Title': 'title',
- 'Track': 'track',
- 'AlbumArtist': 'albumartist',
- 'AlbumArtistSort': 'albumartist_sort',
- 'Label': 'label',
- 'Genre': 'genre',
- 'Date': 'year',
- 'OriginalDate': 'original_year',
- 'Composer': 'composer',
- 'Disc': 'disc',
- 'Comment': 'comments',
- 'MUSICBRAINZ_TRACKID': 'mb_trackid',
- 'MUSICBRAINZ_ALBUMID': 'mb_albumid',
- 'MUSICBRAINZ_ARTISTID': 'mb_artistid',
- 'MUSICBRAINZ_ALBUMARTISTID': 'mb_albumartistid',
- 'MUSICBRAINZ_RELEASETRACKID': 'mb_releasetrackid',
+ "Artist": "artist",
+ "ArtistSort": "artist_sort",
+ "Album": "album",
+ "Title": "title",
+ "Track": "track",
+ "AlbumArtist": "albumartist",
+ "AlbumArtistSort": "albumartist_sort",
+ "Label": "label",
+ "Genre": "genre",
+ "Date": "year",
+ "OriginalDate": "original_year",
+ "Composer": "composer",
+ "Disc": "disc",
+ "Comment": "comments",
+ "MUSICBRAINZ_TRACKID": "mb_trackid",
+ "MUSICBRAINZ_ALBUMID": "mb_albumid",
+ "MUSICBRAINZ_ARTISTID": "mb_artistid",
+ "MUSICBRAINZ_ALBUMARTISTID": "mb_albumartistid",
+ "MUSICBRAINZ_RELEASETRACKID": "mb_releasetrackid",
}
def cmd_tagtypes(self, conn):
@@ -1371,7 +1387,7 @@ class Server(BaseServer):
searching.
"""
for tag in self.tagtype_map:
- yield 'tagtype: ' + tag
+ yield "tagtype: " + tag
def _tagtype_lookup(self, tag):
"""Uses `tagtype_map` to look up the beets column name for an
@@ -1383,7 +1399,7 @@ class Server(BaseServer):
# Match case-insensitively.
if test_tag.lower() == tag.lower():
return test_tag, key
- raise BPDError(ERROR_UNKNOWN, 'no such tagtype')
+ raise BPDError(ERROR_UNKNOWN, "no such tagtype")
def _metadata_query(self, query_type, any_query_type, kv):
"""Helper function returns a query object that will find items
@@ -1396,13 +1412,15 @@ class Server(BaseServer):
# Iterate pairwise over the arguments.
it = iter(kv)
for tag, value in zip(it, it):
- if tag.lower() == 'any':
+ if tag.lower() == "any":
if any_query_type:
- queries.append(any_query_type(value,
- ITEM_KEYS_WRITABLE,
- query_type))
+ queries.append(
+ any_query_type(
+ value, ITEM_KEYS_WRITABLE, query_type
+ )
+ )
else:
- raise BPDError(ERROR_UNKNOWN, 'no such tagtype')
+ raise BPDError(ERROR_UNKNOWN, "no such tagtype")
else:
_, key = self._tagtype_lookup(tag)
queries.append(query_type(key, value))
@@ -1412,17 +1430,15 @@ class Server(BaseServer):
def cmd_search(self, conn, *kv):
"""Perform a substring match for items."""
- query = self._metadata_query(dbcore.query.SubstringQuery,
- dbcore.query.AnyFieldQuery,
- kv)
+ query = self._metadata_query(
+ dbcore.query.SubstringQuery, dbcore.query.AnyFieldQuery, kv
+ )
for item in self.lib.items(query):
yield self._item_info(item)
def cmd_find(self, conn, *kv):
"""Perform an exact match for items."""
- query = self._metadata_query(dbcore.query.MatchQuery,
- None,
- kv)
+ query = self._metadata_query(dbcore.query.MatchQuery, None, kv)
for item in self.lib.items(query):
yield self._item_info(item)
@@ -1432,22 +1448,27 @@ class Server(BaseServer):
"""
show_tag_canon, show_key = self._tagtype_lookup(show_tag)
if len(kv) == 1:
- if show_tag_canon == 'Album':
+ if show_tag_canon == "Album":
# If no tag was given, assume artist. This is because MPD
# supports a short version of this command for fetching the
# albums belonging to a particular artist, and some clients
# rely on this behaviour (e.g. MPDroid, M.A.L.P.).
- kv = ('Artist', kv[0])
+ kv = ("Artist", kv[0])
else:
raise BPDError(ERROR_ARG, 'should be "Album" for 3 arguments')
elif len(kv) % 2 != 0:
- raise BPDError(ERROR_ARG, 'Incorrect number of filter arguments')
+ raise BPDError(ERROR_ARG, "Incorrect number of filter arguments")
query = self._metadata_query(dbcore.query.MatchQuery, None, kv)
clause, subvals = query.clause()
- statement = 'SELECT DISTINCT ' + show_key + \
- ' FROM items WHERE ' + clause + \
- ' ORDER BY ' + show_key
+ statement = (
+ "SELECT DISTINCT "
+ + show_key
+ + " FROM items WHERE "
+ + clause
+ + " ORDER BY "
+ + show_key
+ )
self._log.debug(statement)
with self.lib.transaction() as tx:
rows = tx.query(statement, subvals)
@@ -1456,7 +1477,7 @@ class Server(BaseServer):
if not row[0]:
# Skip any empty values of the field.
continue
- yield show_tag_canon + ': ' + str(row[0])
+ yield show_tag_canon + ": " + str(row[0])
def cmd_count(self, conn, tag, value):
"""Returns the number and total time of songs matching the
@@ -1468,44 +1489,44 @@ class Server(BaseServer):
for item in self.lib.items(dbcore.query.MatchQuery(key, value)):
songs += 1
playtime += item.length
- yield 'songs: ' + str(songs)
- yield 'playtime: ' + str(int(playtime))
+ yield "songs: " + str(songs)
+ yield "playtime: " + str(int(playtime))
# Persistent playlist manipulation. In MPD this is an optional feature so
# these dummy implementations match MPD's behaviour with the feature off.
def cmd_listplaylist(self, conn, playlist):
- raise BPDError(ERROR_NO_EXIST, 'No such playlist')
+ raise BPDError(ERROR_NO_EXIST, "No such playlist")
def cmd_listplaylistinfo(self, conn, playlist):
- raise BPDError(ERROR_NO_EXIST, 'No such playlist')
+ raise BPDError(ERROR_NO_EXIST, "No such playlist")
def cmd_listplaylists(self, conn):
- raise BPDError(ERROR_UNKNOWN, 'Stored playlists are disabled')
+ raise BPDError(ERROR_UNKNOWN, "Stored playlists are disabled")
def cmd_load(self, conn, playlist):
- raise BPDError(ERROR_NO_EXIST, 'Stored playlists are disabled')
+ raise BPDError(ERROR_NO_EXIST, "Stored playlists are disabled")
def cmd_playlistadd(self, conn, playlist, uri):
- raise BPDError(ERROR_UNKNOWN, 'Stored playlists are disabled')
+ raise BPDError(ERROR_UNKNOWN, "Stored playlists are disabled")
def cmd_playlistclear(self, conn, playlist):
- raise BPDError(ERROR_UNKNOWN, 'Stored playlists are disabled')
+ raise BPDError(ERROR_UNKNOWN, "Stored playlists are disabled")
def cmd_playlistdelete(self, conn, playlist, index):
- raise BPDError(ERROR_UNKNOWN, 'Stored playlists are disabled')
+ raise BPDError(ERROR_UNKNOWN, "Stored playlists are disabled")
def cmd_playlistmove(self, conn, playlist, from_index, to_index):
- raise BPDError(ERROR_UNKNOWN, 'Stored playlists are disabled')
+ raise BPDError(ERROR_UNKNOWN, "Stored playlists are disabled")
def cmd_rename(self, conn, playlist, new_name):
- raise BPDError(ERROR_UNKNOWN, 'Stored playlists are disabled')
+ raise BPDError(ERROR_UNKNOWN, "Stored playlists are disabled")
def cmd_rm(self, conn, playlist):
- raise BPDError(ERROR_UNKNOWN, 'Stored playlists are disabled')
+ raise BPDError(ERROR_UNKNOWN, "Stored playlists are disabled")
def cmd_save(self, conn, playlist):
- raise BPDError(ERROR_UNKNOWN, 'Stored playlists are disabled')
+ raise BPDError(ERROR_UNKNOWN, "Stored playlists are disabled")
# "Outputs." Just a dummy implementation because we don't control
# any outputs.
@@ -1513,9 +1534,9 @@ class Server(BaseServer):
def cmd_outputs(self, conn):
"""List the available outputs."""
yield (
- 'outputid: 0',
- 'outputname: gstreamer',
- 'outputenabled: 1',
+ "outputid: 0",
+ "outputname: gstreamer",
+ "outputenabled: 1",
)
def cmd_enableoutput(self, conn, output_id):
@@ -1526,7 +1547,7 @@ class Server(BaseServer):
def cmd_disableoutput(self, conn, output_id):
output_id = cast_arg(int, output_id)
if output_id == 0:
- raise BPDError(ERROR_ARG, 'cannot disable this output')
+ raise BPDError(ERROR_ARG, "cannot disable this output")
else:
raise ArgumentIndexError()
@@ -1574,20 +1595,24 @@ class Server(BaseServer):
# Beets plugin hooks.
+
class BPDPlugin(BeetsPlugin):
"""Provides the "beet bpd" command for running a music player
server.
"""
+
def __init__(self):
super().__init__()
- self.config.add({
- 'host': '',
- 'port': 6600,
- 'control_port': 6601,
- 'password': '',
- 'volume': VOLUME_MAX,
- })
- self.config['password'].redact = True
+ self.config.add(
+ {
+ "host": "",
+ "port": 6600,
+ "control_port": 6601,
+ "password": "",
+ "volume": VOLUME_MAX,
+ }
+ )
+ self.config["password"].redact = True
def start_bpd(self, lib, host, port, password, volume, ctrl_port):
"""Starts a BPD server."""
@@ -1596,29 +1621,32 @@ class BPDPlugin(BeetsPlugin):
server.cmd_setvol(None, volume)
server.run()
except NoGstreamerError:
- self._log.error('Gstreamer Python bindings not found.')
- self._log.error('Install "gstreamer1.0" and "python-gi"'
- 'or similar package to use BPD.')
+ self._log.error("Gstreamer Python bindings not found.")
+ self._log.error(
+ 'Install "gstreamer1.0" and "python-gi"'
+ "or similar package to use BPD."
+ )
def commands(self):
cmd = beets.ui.Subcommand(
- 'bpd', help='run an MPD-compatible music player server'
+ "bpd", help="run an MPD-compatible music player server"
)
def func(lib, opts, args):
- host = self.config['host'].as_str()
+ host = self.config["host"].as_str()
host = args.pop(0) if args else host
- port = args.pop(0) if args else self.config['port'].get(int)
+ port = args.pop(0) if args else self.config["port"].get(int)
if args:
ctrl_port = args.pop(0)
else:
- ctrl_port = self.config['control_port'].get(int)
+ ctrl_port = self.config["control_port"].get(int)
if args:
- raise beets.ui.UserError('too many arguments')
- password = self.config['password'].as_str()
- volume = self.config['volume'].get(int)
- self.start_bpd(lib, host, int(port), password, volume,
- int(ctrl_port))
+ raise beets.ui.UserError("too many arguments")
+ password = self.config["password"].as_str()
+ volume = self.config["volume"].get(int)
+ self.start_bpd(
+ lib, host, int(port), password, volume, int(ctrl_port)
+ )
cmd.func = func
return [cmd]
diff --git a/lib/beetsplug/bpd/gstplayer.py b/lib/beetsplug/bpd/gstplayer.py
index 64954b1c..77ddc198 100644
--- a/lib/beetsplug/bpd/gstplayer.py
+++ b/lib/beetsplug/bpd/gstplayer.py
@@ -17,18 +17,19 @@ music player.
"""
+import _thread
+import copy
+import os
import sys
import time
-import _thread
-import os
-import copy
import urllib
-from beets import ui
import gi
-gi.require_version('Gst', '1.0')
-from gi.repository import GLib, Gst # noqa: E402
+from beets import ui
+
+gi.require_version("Gst", "1.0")
+from gi.repository import GLib, Gst # noqa: E402
Gst.init(None)
@@ -128,8 +129,8 @@ class GstPlayer:
"""
self.player.set_state(Gst.State.NULL)
if isinstance(path, str):
- path = path.encode('utf-8')
- uri = 'file://' + urllib.parse.quote(path)
+ path = path.encode("utf-8")
+ uri = "file://" + urllib.parse.quote(path)
self.player.set_property("uri", uri)
self.player.set_state(Gst.State.PLAYING)
self.playing = True
@@ -175,12 +176,12 @@ class GstPlayer:
posq = self.player.query_position(fmt)
if not posq[0]:
raise QueryError("query_position failed")
- pos = posq[1] / (10 ** 9)
+ pos = posq[1] / (10**9)
lengthq = self.player.query_duration(fmt)
if not lengthq[0]:
raise QueryError("query_duration failed")
- length = lengthq[1] / (10 ** 9)
+ length = lengthq[1] / (10**9)
self.cached_time = (pos, length)
return (pos, length)
@@ -202,7 +203,7 @@ class GstPlayer:
return
fmt = Gst.Format(Gst.Format.TIME)
- ns = position * 10 ** 9 # convert to nanoseconds
+ ns = position * 10**9 # convert to nanoseconds
self.player.seek_simple(fmt, Gst.SeekFlags.FLUSH, ns)
# save new cached time
@@ -223,11 +224,13 @@ def get_decoders():
and file extensions.
"""
# We only care about audio decoder elements.
- filt = (Gst.ELEMENT_FACTORY_TYPE_DEPAYLOADER |
- Gst.ELEMENT_FACTORY_TYPE_DEMUXER |
- Gst.ELEMENT_FACTORY_TYPE_PARSER |
- Gst.ELEMENT_FACTORY_TYPE_DECODER |
- Gst.ELEMENT_FACTORY_TYPE_MEDIA_AUDIO)
+ filt = (
+ Gst.ELEMENT_FACTORY_TYPE_DEPAYLOADER
+ | Gst.ELEMENT_FACTORY_TYPE_DEMUXER
+ | Gst.ELEMENT_FACTORY_TYPE_PARSER
+ | Gst.ELEMENT_FACTORY_TYPE_DECODER
+ | Gst.ELEMENT_FACTORY_TYPE_MEDIA_AUDIO
+ )
decoders = {}
mime_types = set()
@@ -239,7 +242,7 @@ def get_decoders():
for i in range(caps.get_size()):
struct = caps.get_structure(i)
mime = struct.get_name()
- if mime == 'unknown/unknown':
+ if mime == "unknown/unknown":
continue
mimes.add(mime)
mime_types.add(mime)
@@ -295,10 +298,9 @@ def play_complicated(paths):
time.sleep(1)
-if __name__ == '__main__':
+if __name__ == "__main__":
# A very simple command-line player. Just give it names of audio
# files on the command line; these are all played in sequence.
- paths = [os.path.abspath(os.path.expanduser(p))
- for p in sys.argv[1:]]
+ paths = [os.path.abspath(os.path.expanduser(p)) for p in sys.argv[1:]]
# play_simple(paths)
play_complicated(paths)
diff --git a/lib/beetsplug/bpm.py b/lib/beetsplug/bpm.py
index 5aa2d95a..3edcbef8 100644
--- a/lib/beetsplug/bpm.py
+++ b/lib/beetsplug/bpm.py
@@ -30,7 +30,7 @@ def bpm(max_strokes):
for i in range(max_strokes):
# Press enter to the rhythm...
s = input()
- if s == '':
+ if s == "":
t1 = time.time()
# Only start measuring at the second stroke
if t0:
@@ -46,18 +46,20 @@ def bpm(max_strokes):
class BPMPlugin(BeetsPlugin):
-
def __init__(self):
super().__init__()
- self.config.add({
- 'max_strokes': 3,
- 'overwrite': True,
- })
+ self.config.add(
+ {
+ "max_strokes": 3,
+ "overwrite": True,
+ }
+ )
def commands(self):
- cmd = ui.Subcommand('bpm',
- help='determine bpm of a song by pressing '
- 'a key to the rhythm')
+ cmd = ui.Subcommand(
+ "bpm",
+ help="determine bpm of a song by pressing " "a key to the rhythm",
+ )
cmd.func = self.command
return [cmd]
@@ -67,21 +69,23 @@ class BPMPlugin(BeetsPlugin):
self.get_bpm(items, write)
def get_bpm(self, items, write=False):
- overwrite = self.config['overwrite'].get(bool)
+ overwrite = self.config["overwrite"].get(bool)
if len(items) > 1:
- raise ValueError('Can only get bpm of one song at time')
+ raise ValueError("Can only get bpm of one song at time")
item = items[0]
- if item['bpm']:
- self._log.info('Found bpm {0}', item['bpm'])
+ if item["bpm"]:
+ self._log.info("Found bpm {0}", item["bpm"])
if not overwrite:
return
- self._log.info('Press Enter {0} times to the rhythm or Ctrl-D '
- 'to exit', self.config['max_strokes'].get(int))
- new_bpm = bpm(self.config['max_strokes'].get(int))
- item['bpm'] = int(new_bpm)
+ self._log.info(
+ "Press Enter {0} times to the rhythm or Ctrl-D " "to exit",
+ self.config["max_strokes"].get(int),
+ )
+ new_bpm = bpm(self.config["max_strokes"].get(int))
+ item["bpm"] = int(new_bpm)
if write:
item.try_write()
item.store()
- self._log.info('Added new bpm {0}', item['bpm'])
+ self._log.info("Added new bpm {0}", item["bpm"])
diff --git a/lib/beetsplug/bpsync.py b/lib/beetsplug/bpsync.py
index 5b28d6d2..4f3e0e90 100644
--- a/lib/beetsplug/bpsync.py
+++ b/lib/beetsplug/bpsync.py
@@ -15,8 +15,8 @@
"""Update library's tags using Beatport.
"""
-from beets.plugins import BeetsPlugin, apply_item_changes
from beets import autotag, library, ui, util
+from beets.plugins import BeetsPlugin, apply_item_changes
from .beatport import BeatportPlugin
@@ -28,33 +28,33 @@ class BPSyncPlugin(BeetsPlugin):
self.beatport_plugin.setup()
def commands(self):
- cmd = ui.Subcommand('bpsync', help='update metadata from Beatport')
+ cmd = ui.Subcommand("bpsync", help="update metadata from Beatport")
cmd.parser.add_option(
- '-p',
- '--pretend',
- action='store_true',
- help='show all changes but do nothing',
+ "-p",
+ "--pretend",
+ action="store_true",
+ help="show all changes but do nothing",
)
cmd.parser.add_option(
- '-m',
- '--move',
- action='store_true',
- dest='move',
+ "-m",
+ "--move",
+ action="store_true",
+ dest="move",
help="move files in the library directory",
)
cmd.parser.add_option(
- '-M',
- '--nomove',
- action='store_false',
- dest='move',
+ "-M",
+ "--nomove",
+ action="store_false",
+ dest="move",
help="don't move files in library",
)
cmd.parser.add_option(
- '-W',
- '--nowrite',
- action='store_false',
+ "-W",
+ "--nowrite",
+ action="store_false",
default=None,
- dest='write',
+ dest="write",
help="don't write updated metadata to files",
)
cmd.parser.add_format_option()
@@ -62,8 +62,7 @@ class BPSyncPlugin(BeetsPlugin):
return [cmd]
def func(self, lib, opts, args):
- """Command handler for the bpsync function.
- """
+ """Command handler for the bpsync function."""
move = ui.should_move(opts.move)
pretend = opts.pretend
write = ui.should_write(opts.write)
@@ -76,16 +75,16 @@ class BPSyncPlugin(BeetsPlugin):
"""Retrieve and apply info from the autotagger for items matched by
query.
"""
- for item in lib.items(query + ['singleton:true']):
+ for item in lib.items(query + ["singleton:true"]):
if not item.mb_trackid:
self._log.info(
- 'Skipping singleton with no mb_trackid: {}', item
+ "Skipping singleton with no mb_trackid: {}", item
)
continue
if not self.is_beatport_track(item):
self._log.info(
- 'Skipping non-{} singleton: {}',
+ "Skipping non-{} singleton: {}",
self.beatport_plugin.data_source,
item,
)
@@ -100,27 +99,27 @@ class BPSyncPlugin(BeetsPlugin):
@staticmethod
def is_beatport_track(item):
return (
- item.get('data_source') == BeatportPlugin.data_source
+ item.get("data_source") == BeatportPlugin.data_source
and item.mb_trackid.isnumeric()
)
def get_album_tracks(self, album):
if not album.mb_albumid:
- self._log.info('Skipping album with no mb_albumid: {}', album)
+ self._log.info("Skipping album with no mb_albumid: {}", album)
return False
if not album.mb_albumid.isnumeric():
self._log.info(
- 'Skipping album with invalid {} ID: {}',
+ "Skipping album with invalid {} ID: {}",
self.beatport_plugin.data_source,
album,
)
return False
items = list(album.items())
- if album.get('data_source') == self.beatport_plugin.data_source:
+ if album.get("data_source") == self.beatport_plugin.data_source:
return items
if not all(self.is_beatport_track(item) for item in items):
self._log.info(
- 'Skipping non-{} release: {}',
+ "Skipping non-{} release: {}",
self.beatport_plugin.data_source,
album,
)
@@ -142,7 +141,7 @@ class BPSyncPlugin(BeetsPlugin):
albuminfo = self.beatport_plugin.album_for_id(album.mb_albumid)
if not albuminfo:
self._log.info(
- 'Release ID {} not found for album {}',
+ "Release ID {} not found for album {}",
album.mb_albumid,
album,
)
@@ -159,7 +158,7 @@ class BPSyncPlugin(BeetsPlugin):
for track_id, item in library_trackid_to_item.items()
}
- self._log.info('applying changes to {}', album)
+ self._log.info("applying changes to {}", album)
with lib.transaction():
autotag.apply_metadata(albuminfo, item_to_trackinfo)
changed = False
@@ -182,5 +181,5 @@ class BPSyncPlugin(BeetsPlugin):
# Move album art (and any inconsistent items).
if move and lib.directory in util.ancestry(items[0].path):
- self._log.debug('moving album {}', album)
+ self._log.debug("moving album {}", album)
album.move()
diff --git a/lib/beetsplug/bucket.py b/lib/beetsplug/bucket.py
index 9ed50b45..59ee080b 100644
--- a/lib/beetsplug/bucket.py
+++ b/lib/beetsplug/bucket.py
@@ -16,14 +16,13 @@
"""
-from datetime import datetime
import re
import string
+from datetime import datetime
from itertools import tee
from beets import plugins, ui
-
ASCII_DIGITS = string.digits + string.ascii_lowercase
@@ -39,12 +38,10 @@ def pairwise(iterable):
def span_from_str(span_str):
- """Build a span dict from the span string representation.
- """
+ """Build a span dict from the span string representation."""
def normalize_year(d, yearfrom):
- """Convert string to a 4 digits year
- """
+ """Convert string to a 4 digits year"""
if yearfrom < 100:
raise BucketError("%d must be expressed on 4 digits" % yearfrom)
@@ -57,31 +54,33 @@ def span_from_str(span_str):
d = (yearfrom - yearfrom % 100) + d
return d
- years = [int(x) for x in re.findall(r'\d+', span_str)]
+ years = [int(x) for x in re.findall(r"\d+", span_str)]
if not years:
- raise ui.UserError("invalid range defined for year bucket '%s': no "
- "year found" % span_str)
+ raise ui.UserError(
+ "invalid range defined for year bucket '%s': no "
+ "year found" % span_str
+ )
try:
years = [normalize_year(x, years[0]) for x in years]
except BucketError as exc:
- raise ui.UserError("invalid range defined for year bucket '%s': %s" %
- (span_str, exc))
+ raise ui.UserError(
+ "invalid range defined for year bucket '%s': %s" % (span_str, exc)
+ )
- res = {'from': years[0], 'str': span_str}
+ res = {"from": years[0], "str": span_str}
if len(years) > 1:
- res['to'] = years[-1]
+ res["to"] = years[-1]
return res
def complete_year_spans(spans):
- """Set the `to` value of spans if empty and sort them chronologically.
- """
- spans.sort(key=lambda x: x['from'])
- for (x, y) in pairwise(spans):
- if 'to' not in x:
- x['to'] = y['from'] - 1
- if spans and 'to' not in spans[-1]:
- spans[-1]['to'] = datetime.now().year
+ """Set the `to` value of spans if empty and sort them chronologically."""
+ spans.sort(key=lambda x: x["from"])
+ for x, y in pairwise(spans):
+ if "to" not in x:
+ x["to"] = y["from"] - 1
+ if spans and "to" not in spans[-1]:
+ spans[-1]["to"] = datetime.now().year
def extend_year_spans(spans, spanlen, start=1900, end=2014):
@@ -89,17 +88,17 @@ def extend_year_spans(spans, spanlen, start=1900, end=2014):
belongs to a span.
"""
extended_spans = spans[:]
- for (x, y) in pairwise(spans):
+ for x, y in pairwise(spans):
# if a gap between two spans, fill the gap with as much spans of
# spanlen length as necessary
- for span_from in range(x['to'] + 1, y['from'], spanlen):
- extended_spans.append({'from': span_from})
+ for span_from in range(x["to"] + 1, y["from"], spanlen):
+ extended_spans.append({"from": span_from})
# Create spans prior to declared ones
- for span_from in range(spans[0]['from'] - spanlen, start, -spanlen):
- extended_spans.append({'from': span_from})
+ for span_from in range(spans[0]["from"] - spanlen, start, -spanlen):
+ extended_spans.append({"from": span_from})
# Create spans after the declared ones
- for span_from in range(spans[-1]['to'] + 1, end, spanlen):
- extended_spans.append({'from': span_from})
+ for span_from in range(spans[-1]["to"] + 1, end, spanlen):
+ extended_spans.append({"from": span_from})
complete_year_spans(extended_spans)
return extended_spans
@@ -117,25 +116,29 @@ def build_year_spans(year_spans_str):
def str2fmt(s):
- """Deduces formatting syntax from a span string.
- """
- regex = re.compile(r"(?P\D*)(?P\d+)(?P\D*)"
- r"(?P\d*)(?P\D*)")
+ """Deduces formatting syntax from a span string."""
+ regex = re.compile(
+ r"(?P\D*)(?P\d+)(?P\D*)"
+ r"(?P\d*)(?P\D*)"
+ )
m = re.match(regex, s)
- res = {'fromnchars': len(m.group('fromyear')),
- 'tonchars': len(m.group('toyear'))}
- res['fmt'] = "{}%s{}{}{}".format(m.group('bef'),
- m.group('sep'),
- '%s' if res['tonchars'] else '',
- m.group('after'))
+ res = {
+ "fromnchars": len(m.group("fromyear")),
+ "tonchars": len(m.group("toyear")),
+ }
+ res["fmt"] = "{}%s{}{}{}".format(
+ m.group("bef"),
+ m.group("sep"),
+ "%s" if res["tonchars"] else "",
+ m.group("after"),
+ )
return res
def format_span(fmt, yearfrom, yearto, fromnchars, tonchars):
- """Return a span string representation.
- """
- args = (str(yearfrom)[-fromnchars:])
+ """Return a span string representation."""
+ args = str(yearfrom)[-fromnchars:]
if tonchars:
args = (str(yearfrom)[-fromnchars:], str(yearto)[-tonchars:])
@@ -143,11 +146,10 @@ def format_span(fmt, yearfrom, yearto, fromnchars, tonchars):
def extract_modes(spans):
- """Extract the most common spans lengths and representation formats
- """
- rangelen = sorted([x['to'] - x['from'] + 1 for x in spans])
+ """Extract the most common spans lengths and representation formats"""
+ rangelen = sorted([x["to"] - x["from"] + 1 for x in spans])
deflen = sorted(rangelen, key=rangelen.count)[-1]
- reprs = [str2fmt(x['str']) for x in spans]
+ reprs = [str2fmt(x["str"]) for x in spans]
deffmt = sorted(reprs, key=reprs.count)[-1]
return deflen, deffmt
@@ -167,13 +169,16 @@ def build_alpha_spans(alpha_spans_str, alpha_regexs):
begin_index = ASCII_DIGITS.index(bucket[0])
end_index = ASCII_DIGITS.index(bucket[-1])
else:
- raise ui.UserError("invalid range defined for alpha bucket "
- "'%s': no alphanumeric character found" %
- elem)
+ raise ui.UserError(
+ "invalid range defined for alpha bucket "
+ "'%s': no alphanumeric character found" % elem
+ )
spans.append(
re.compile(
- "^[" + ASCII_DIGITS[begin_index:end_index + 1] +
- ASCII_DIGITS[begin_index:end_index + 1].upper() + "]"
+ "^["
+ + ASCII_DIGITS[begin_index : end_index + 1]
+ + ASCII_DIGITS[begin_index : end_index + 1].upper()
+ + "]"
)
)
return spans
@@ -182,29 +187,32 @@ def build_alpha_spans(alpha_spans_str, alpha_regexs):
class BucketPlugin(plugins.BeetsPlugin):
def __init__(self):
super().__init__()
- self.template_funcs['bucket'] = self._tmpl_bucket
+ self.template_funcs["bucket"] = self._tmpl_bucket
- self.config.add({
- 'bucket_year': [],
- 'bucket_alpha': [],
- 'bucket_alpha_regex': {},
- 'extrapolate': False
- })
+ self.config.add(
+ {
+ "bucket_year": [],
+ "bucket_alpha": [],
+ "bucket_alpha_regex": {},
+ "extrapolate": False,
+ }
+ )
self.setup()
def setup(self):
- """Setup plugin from config options
- """
- self.year_spans = build_year_spans(self.config['bucket_year'].get())
- if self.year_spans and self.config['extrapolate']:
- [self.ys_len_mode,
- self.ys_repr_mode] = extract_modes(self.year_spans)
- self.year_spans = extend_year_spans(self.year_spans,
- self.ys_len_mode)
+ """Setup plugin from config options"""
+ self.year_spans = build_year_spans(self.config["bucket_year"].get())
+ if self.year_spans and self.config["extrapolate"]:
+ [self.ys_len_mode, self.ys_repr_mode] = extract_modes(
+ self.year_spans
+ )
+ self.year_spans = extend_year_spans(
+ self.year_spans, self.ys_len_mode
+ )
self.alpha_spans = build_alpha_spans(
- self.config['bucket_alpha'].get(),
- self.config['bucket_alpha_regex'].get()
+ self.config["bucket_alpha"].get(),
+ self.config["bucket_alpha_regex"].get(),
)
def find_bucket_year(self, year):
@@ -212,30 +220,33 @@ class BucketPlugin(plugins.BeetsPlugin):
if no matching bucket.
"""
for ys in self.year_spans:
- if ys['from'] <= int(year) <= ys['to']:
- if 'str' in ys:
- return ys['str']
+ if ys["from"] <= int(year) <= ys["to"]:
+ if "str" in ys:
+ return ys["str"]
else:
- return format_span(self.ys_repr_mode['fmt'],
- ys['from'], ys['to'],
- self.ys_repr_mode['fromnchars'],
- self.ys_repr_mode['tonchars'])
+ return format_span(
+ self.ys_repr_mode["fmt"],
+ ys["from"],
+ ys["to"],
+ self.ys_repr_mode["fromnchars"],
+ self.ys_repr_mode["tonchars"],
+ )
return year
def find_bucket_alpha(self, s):
"""Return alpha-range bucket that matches given string or return the
string initial if no matching bucket.
"""
- for (i, span) in enumerate(self.alpha_spans):
+ for i, span in enumerate(self.alpha_spans):
if span.match(s):
- return self.config['bucket_alpha'].get()[i]
+ return self.config["bucket_alpha"].get()[i]
return s[0].upper()
def _tmpl_bucket(self, text, field=None):
if not field and len(text) == 4 and text.isdigit():
- field = 'year'
+ field = "year"
- if field == 'year':
+ if field == "year":
func = self.find_bucket_year
else:
func = self.find_bucket_alpha
diff --git a/lib/beetsplug/chroma.py b/lib/beetsplug/chroma.py
index 353923aa..369a3cc7 100644
--- a/lib/beetsplug/chroma.py
+++ b/lib/beetsplug/chroma.py
@@ -16,18 +16,17 @@
autotagger. Requires the pyacoustid library.
"""
-from beets import plugins
-from beets import ui
-from beets import util
-from beets import config
-from beets.autotag import hooks
-import confuse
-import acoustid
+import re
from collections import defaultdict
from functools import partial
-import re
-API_KEY = '1vOwZtEn'
+import acoustid
+import confuse
+
+from beets import config, plugins, ui, util
+from beets.autotag import hooks
+
+API_KEY = "1vOwZtEn"
SCORE_THRESH = 0.5
TRACK_ID_WEIGHT = 10.0
COMMON_REL_THRESH = 0.6 # How many tracks must have an album in common?
@@ -49,8 +48,7 @@ _acoustids = {}
def prefix(it, count):
- """Truncate an iterable to at most `count` items.
- """
+ """Truncate an iterable to at most `count` items."""
for i, v in enumerate(it):
if i >= count:
break
@@ -58,13 +56,12 @@ def prefix(it, count):
def releases_key(release, countries, original_year):
- """Used as a key to sort releases by date then preferred country
- """
- date = release.get('date')
+ """Used as a key to sort releases by date then preferred country"""
+ date = release.get("date")
if date and original_year:
- year = date.get('year', 9999)
- month = date.get('month', 99)
- day = date.get('day', 99)
+ year = date.get("year", 9999)
+ month = date.get("month", 99)
+ day = date.get("day", 99)
else:
year = 9999
month = 99
@@ -72,9 +69,9 @@ def releases_key(release, countries, original_year):
# Uses index of preferred countries to sort
country_key = 99
- if release.get('country'):
+ if release.get("country"):
for i, country in enumerate(countries):
- if country.match(release['country']):
+ if country.match(release["country"]):
country_key = i
break
@@ -88,56 +85,63 @@ def acoustid_match(log, path):
try:
duration, fp = acoustid.fingerprint_file(util.syspath(path))
except acoustid.FingerprintGenerationError as exc:
- log.error('fingerprinting of {0} failed: {1}',
- util.displayable_path(repr(path)), exc)
+ log.error(
+ "fingerprinting of {0} failed: {1}",
+ util.displayable_path(repr(path)),
+ exc,
+ )
return None
fp = fp.decode()
_fingerprints[path] = fp
try:
- res = acoustid.lookup(API_KEY, fp, duration,
- meta='recordings releases')
+ res = acoustid.lookup(API_KEY, fp, duration, meta="recordings releases")
except acoustid.AcoustidError as exc:
- log.debug('fingerprint matching {0} failed: {1}',
- util.displayable_path(repr(path)), exc)
+ log.debug(
+ "fingerprint matching {0} failed: {1}",
+ util.displayable_path(repr(path)),
+ exc,
+ )
return None
- log.debug('chroma: fingerprinted {0}',
- util.displayable_path(repr(path)))
+ log.debug("chroma: fingerprinted {0}", util.displayable_path(repr(path)))
# Ensure the response is usable and parse it.
- if res['status'] != 'ok' or not res.get('results'):
- log.debug('no match found')
+ if res["status"] != "ok" or not res.get("results"):
+ log.debug("no match found")
return None
- result = res['results'][0] # Best match.
- if result['score'] < SCORE_THRESH:
- log.debug('no results above threshold')
+ result = res["results"][0] # Best match.
+ if result["score"] < SCORE_THRESH:
+ log.debug("no results above threshold")
return None
- _acoustids[path] = result['id']
+ _acoustids[path] = result["id"]
# Get recording and releases from the result
- if not result.get('recordings'):
- log.debug('no recordings found')
+ if not result.get("recordings"):
+ log.debug("no recordings found")
return None
recording_ids = []
releases = []
- for recording in result['recordings']:
- recording_ids.append(recording['id'])
- if 'releases' in recording:
- releases.extend(recording['releases'])
+ for recording in result["recordings"]:
+ recording_ids.append(recording["id"])
+ if "releases" in recording:
+ releases.extend(recording["releases"])
# The releases list is essentially in random order from the Acoustid lookup
# so we optionally sort it using the match.preferred configuration options.
# 'original_year' to sort the earliest first and
# 'countries' to then sort preferred countries first.
- country_patterns = config['match']['preferred']['countries'].as_str_seq()
+ country_patterns = config["match"]["preferred"]["countries"].as_str_seq()
countries = [re.compile(pat, re.I) for pat in country_patterns]
- original_year = config['match']['preferred']['original_year']
- releases.sort(key=partial(releases_key,
- countries=countries,
- original_year=original_year))
- release_ids = [rel['id'] for rel in releases]
+ original_year = config["match"]["preferred"]["original_year"]
+ releases.sort(
+ key=partial(
+ releases_key, countries=countries, original_year=original_year
+ )
+ )
+ release_ids = [rel["id"] for rel in releases]
- log.debug('matched recordings {0} on releases {1}',
- recording_ids, release_ids)
+ log.debug(
+ "matched recordings {0} on releases {1}", recording_ids, release_ids
+ )
_matches[path] = recording_ids, release_ids
@@ -167,14 +171,16 @@ class AcoustidPlugin(plugins.BeetsPlugin):
def __init__(self):
super().__init__()
- self.config.add({
- 'auto': True,
- })
- config['acoustid']['apikey'].redact = True
+ self.config.add(
+ {
+ "auto": True,
+ }
+ )
+ config["acoustid"]["apikey"].redact = True
- if self.config['auto']:
- self.register_listener('import_task_start', self.fingerprint_task)
- self.register_listener('import_task_apply', apply_acoustid_metadata)
+ if self.config["auto"]:
+ self.register_listener("import_task_start", self.fingerprint_task)
+ self.register_listener("import_task_apply", apply_acoustid_metadata)
def fingerprint_task(self, task, session):
return fingerprint_task(self._log, task, session)
@@ -186,7 +192,7 @@ class AcoustidPlugin(plugins.BeetsPlugin):
return dist
recording_ids, _ = _matches[item.path]
- dist.add_expr('track_id', info.track_id not in recording_ids)
+ dist.add_expr("track_id", info.track_id not in recording_ids)
return dist
def candidates(self, items, artist, album, va_likely, extra_tags=None):
@@ -196,7 +202,7 @@ class AcoustidPlugin(plugins.BeetsPlugin):
if album:
albums.append(album)
- self._log.debug('acoustid album candidates: {0}', len(albums))
+ self._log.debug("acoustid album candidates: {0}", len(albums))
return albums
def item_candidates(self, item, artist, title):
@@ -209,29 +215,31 @@ class AcoustidPlugin(plugins.BeetsPlugin):
track = hooks.track_for_mbid(recording_id)
if track:
tracks.append(track)
- self._log.debug('acoustid item candidates: {0}', len(tracks))
+ self._log.debug("acoustid item candidates: {0}", len(tracks))
return tracks
def commands(self):
- submit_cmd = ui.Subcommand('submit',
- help='submit Acoustid fingerprints')
+ submit_cmd = ui.Subcommand(
+ "submit", help="submit Acoustid fingerprints"
+ )
def submit_cmd_func(lib, opts, args):
try:
- apikey = config['acoustid']['apikey'].as_str()
+ apikey = config["acoustid"]["apikey"].as_str()
except confuse.NotFoundError:
- raise ui.UserError('no Acoustid user API key provided')
+ raise ui.UserError("no Acoustid user API key provided")
submit_items(self._log, apikey, lib.items(ui.decargs(args)))
+
submit_cmd.func = submit_cmd_func
fingerprint_cmd = ui.Subcommand(
- 'fingerprint',
- help='generate fingerprints for items without them'
+ "fingerprint", help="generate fingerprints for items without them"
)
def fingerprint_cmd_func(lib, opts, args):
for item in lib.items(ui.decargs(args)):
fingerprint_item(self._log, item, write=ui.should_write())
+
fingerprint_cmd.func = fingerprint_cmd_func
return [submit_cmd, fingerprint_cmd]
@@ -250,8 +258,7 @@ def fingerprint_task(log, task, session):
def apply_acoustid_metadata(task, session):
- """Apply Acoustid metadata (fingerprint and ID) to the task's items.
- """
+ """Apply Acoustid metadata (fingerprint and ID) to the task's items."""
for item in task.imported_items():
if item.path in _fingerprints:
item.acoustid_fingerprint = _fingerprints[item.path]
@@ -263,17 +270,16 @@ def apply_acoustid_metadata(task, session):
def submit_items(log, userkey, items, chunksize=64):
- """Submit fingerprints for the items to the Acoustid server.
- """
+ """Submit fingerprints for the items to the Acoustid server."""
data = [] # The running list of dictionaries to submit.
def submit_chunk():
"""Submit the current accumulated fingerprint data."""
- log.info('submitting {0} fingerprints', len(data))
+ log.info("submitting {0} fingerprints", len(data))
try:
acoustid.submit(API_KEY, userkey, data)
except acoustid.AcoustidError as exc:
- log.warning('acoustid submission error: {0}', exc)
+ log.warning("acoustid submission error: {0}", exc)
del data[:]
for item in items:
@@ -281,23 +287,25 @@ def submit_items(log, userkey, items, chunksize=64):
# Construct a submission dictionary for this item.
item_data = {
- 'duration': int(item.length),
- 'fingerprint': fp,
+ "duration": int(item.length),
+ "fingerprint": fp,
}
if item.mb_trackid:
- item_data['mbid'] = item.mb_trackid
- log.debug('submitting MBID')
+ item_data["mbid"] = item.mb_trackid
+ log.debug("submitting MBID")
else:
- item_data.update({
- 'track': item.title,
- 'artist': item.artist,
- 'album': item.album,
- 'albumartist': item.albumartist,
- 'year': item.year,
- 'trackno': item.track,
- 'discno': item.disc,
- })
- log.debug('submitting textual metadata')
+ item_data.update(
+ {
+ "track": item.title,
+ "artist": item.artist,
+ "album": item.album,
+ "albumartist": item.albumartist,
+ "year": item.year,
+ "trackno": item.track,
+ "discno": item.disc,
+ }
+ )
+ log.debug("submitting textual metadata")
data.append(item_data)
# If we have enough data, submit a chunk.
@@ -318,28 +326,31 @@ def fingerprint_item(log, item, write=False):
"""
# Get a fingerprint and length for this track.
if not item.length:
- log.info('{0}: no duration available',
- util.displayable_path(item.path))
+ log.info("{0}: no duration available", util.displayable_path(item.path))
elif item.acoustid_fingerprint:
if write:
- log.info('{0}: fingerprint exists, skipping',
- util.displayable_path(item.path))
+ log.info(
+ "{0}: fingerprint exists, skipping",
+ util.displayable_path(item.path),
+ )
else:
- log.info('{0}: using existing fingerprint',
- util.displayable_path(item.path))
+ log.info(
+ "{0}: using existing fingerprint",
+ util.displayable_path(item.path),
+ )
return item.acoustid_fingerprint
else:
- log.info('{0}: fingerprinting',
- util.displayable_path(item.path))
+ log.info("{0}: fingerprinting", util.displayable_path(item.path))
try:
_, fp = acoustid.fingerprint_file(util.syspath(item.path))
item.acoustid_fingerprint = fp.decode()
if write:
- log.info('{0}: writing fingerprint',
- util.displayable_path(item.path))
+ log.info(
+ "{0}: writing fingerprint", util.displayable_path(item.path)
+ )
item.try_write()
if item._db:
item.store()
return item.acoustid_fingerprint
except acoustid.FingerprintGenerationError as exc:
- log.info('fingerprint generation failed: {0}', exc)
+ log.info("fingerprint generation failed: {0}", exc)
diff --git a/lib/beetsplug/convert.py b/lib/beetsplug/convert.py
index 6bc07c28..f150b7c3 100644
--- a/lib/beetsplug/convert.py
+++ b/lib/beetsplug/convert.py
@@ -14,33 +14,33 @@
"""Converts tracks or albums to external directory
"""
-from beets.util import par_map, decode_commandline_path, arg_encoding
-
+import logging
import os
-import threading
+import shlex
import subprocess
import tempfile
-import shlex
+import threading
from string import Template
-from beets import ui, util, plugins, config
+from confuse import ConfigTypeError, Optional
+
+from beets import art, config, plugins, ui, util
+from beets.library import Item, parse_query_string
from beets.plugins import BeetsPlugin
-from confuse import ConfigTypeError
-from beets import art
+from beets.util import arg_encoding, par_map
from beets.util.artresizer import ArtResizer
-from beets.library import parse_query_string
-from beets.library import Item
+from beets.util.m3u import M3UFile
_fs_lock = threading.Lock()
_temp_files = [] # Keep track of temporary transcoded files for deletion.
# Some convenient alternate names for formats.
ALIASES = {
- 'wma': 'windows media',
- 'vorbis': 'ogg',
+ "windows media": "wma",
+ "vorbis": "ogg",
}
-LOSSLESS_FORMATS = ['ape', 'flac', 'alac', 'wav', 'aiff']
+LOSSLESS_FORMATS = ["ape", "flac", "alac", "wave", "aiff"]
def replace_ext(path, ext):
@@ -48,140 +48,217 @@ def replace_ext(path, ext):
The new extension must not contain a leading dot.
"""
- ext_dot = b'.' + ext
+ ext_dot = b"." + ext
return os.path.splitext(path)[0] + ext_dot
def get_format(fmt=None):
- """Return the command template and the extension from the config.
- """
+ """Return the command template and the extension from the config."""
if not fmt:
- fmt = config['convert']['format'].as_str().lower()
+ fmt = config["convert"]["format"].as_str().lower()
fmt = ALIASES.get(fmt, fmt)
try:
- format_info = config['convert']['formats'][fmt].get(dict)
- command = format_info['command']
- extension = format_info.get('extension', fmt)
+ format_info = config["convert"]["formats"][fmt].get(dict)
+ command = format_info["command"]
+ extension = format_info.get("extension", fmt)
except KeyError:
raise ui.UserError(
- 'convert: format {} needs the "command" field'
- .format(fmt)
+ 'convert: format {} needs the "command" field'.format(fmt)
)
except ConfigTypeError:
- command = config['convert']['formats'][fmt].get(str)
+ command = config["convert"]["formats"][fmt].get(str)
extension = fmt
# Convenience and backwards-compatibility shortcuts.
- keys = config['convert'].keys()
- if 'command' in keys:
- command = config['convert']['command'].as_str()
- elif 'opts' in keys:
+ keys = config["convert"].keys()
+ if "command" in keys:
+ command = config["convert"]["command"].as_str()
+ elif "opts" in keys:
# Undocumented option for backwards compatibility with < 1.3.1.
- command = 'ffmpeg -i $source -y {} $dest'.format(
- config['convert']['opts'].as_str()
+ command = "ffmpeg -i $source -y {} $dest".format(
+ config["convert"]["opts"].as_str()
)
- if 'extension' in keys:
- extension = config['convert']['extension'].as_str()
+ if "extension" in keys:
+ extension = config["convert"]["extension"].as_str()
- return (command.encode('utf-8'), extension.encode('utf-8'))
+ return (command.encode("utf-8"), extension.encode("utf-8"))
def should_transcode(item, fmt):
"""Determine whether the item should be transcoded as part of
conversion (i.e., its bitrate is high or it has the wrong format).
"""
- no_convert_queries = config['convert']['no_convert'].as_str_seq()
+ no_convert_queries = config["convert"]["no_convert"].as_str_seq()
if no_convert_queries:
for query_string in no_convert_queries:
query, _ = parse_query_string(query_string, Item)
if query.match(item):
return False
- if config['convert']['never_convert_lossy_files'] and \
- not (item.format.lower() in LOSSLESS_FORMATS):
+ if config["convert"]["never_convert_lossy_files"] and not (
+ item.format.lower() in LOSSLESS_FORMATS
+ ):
return False
- maxbr = config['convert']['max_bitrate'].get(int)
- return fmt.lower() != item.format.lower() or \
- item.bitrate >= 1000 * maxbr
+ maxbr = config["convert"]["max_bitrate"].get(Optional(int))
+ if maxbr is not None and item.bitrate >= 1000 * maxbr:
+ return True
+ return fmt.lower() != item.format.lower()
class ConvertPlugin(BeetsPlugin):
def __init__(self):
super().__init__()
- self.config.add({
- 'dest': None,
- 'pretend': False,
- 'link': False,
- 'hardlink': False,
- 'threads': util.cpu_count(),
- 'format': 'mp3',
- 'id3v23': 'inherit',
- 'formats': {
- 'aac': {
- 'command': 'ffmpeg -i $source -y -vn -acodec aac '
- '-aq 1 $dest',
- 'extension': 'm4a',
+ self.config.add(
+ {
+ "dest": None,
+ "pretend": False,
+ "link": False,
+ "hardlink": False,
+ "threads": os.cpu_count(),
+ "format": "mp3",
+ "id3v23": "inherit",
+ "formats": {
+ "aac": {
+ "command": "ffmpeg -i $source -y -vn -acodec aac "
+ "-aq 1 $dest",
+ "extension": "m4a",
+ },
+ "alac": {
+ "command": "ffmpeg -i $source -y -vn -acodec alac $dest",
+ "extension": "m4a",
+ },
+ "flac": "ffmpeg -i $source -y -vn -acodec flac $dest",
+ "mp3": "ffmpeg -i $source -y -vn -aq 2 $dest",
+ "opus": "ffmpeg -i $source -y -vn -acodec libopus -ab 96k $dest",
+ "ogg": "ffmpeg -i $source -y -vn -acodec libvorbis -aq 3 $dest",
+ "wma": "ffmpeg -i $source -y -vn -acodec wmav2 -vn $dest",
},
- 'alac': {
- 'command': 'ffmpeg -i $source -y -vn -acodec alac $dest',
- 'extension': 'm4a',
- },
- 'flac': 'ffmpeg -i $source -y -vn -acodec flac $dest',
- 'mp3': 'ffmpeg -i $source -y -vn -aq 2 $dest',
- 'opus':
- 'ffmpeg -i $source -y -vn -acodec libopus -ab 96k $dest',
- 'ogg':
- 'ffmpeg -i $source -y -vn -acodec libvorbis -aq 3 $dest',
- 'wma':
- 'ffmpeg -i $source -y -vn -acodec wmav2 -vn $dest',
- },
- 'max_bitrate': 500,
- 'auto': False,
- 'tmpdir': None,
- 'quiet': False,
- 'embed': True,
- 'paths': {},
- 'no_convert': '',
- 'never_convert_lossy_files': False,
- 'copy_album_art': False,
- 'album_art_maxwidth': 0,
- 'delete_originals': False,
- })
- self.early_import_stages = [self.auto_convert]
+ "max_bitrate": None,
+ "auto": False,
+ "auto_keep": False,
+ "tmpdir": None,
+ "quiet": False,
+ "embed": True,
+ "paths": {},
+ "no_convert": "",
+ "never_convert_lossy_files": False,
+ "copy_album_art": False,
+ "album_art_maxwidth": 0,
+ "delete_originals": False,
+ "playlist": None,
+ }
+ )
+ self.early_import_stages = [self.auto_convert, self.auto_convert_keep]
- self.register_listener('import_task_files', self._cleanup)
+ self.register_listener("import_task_files", self._cleanup)
def commands(self):
- cmd = ui.Subcommand('convert', help='convert to external location')
- cmd.parser.add_option('-p', '--pretend', action='store_true',
- help='show actions but do nothing')
- cmd.parser.add_option('-t', '--threads', action='store', type='int',
- help='change the number of threads, \
- defaults to maximum available processors')
- cmd.parser.add_option('-k', '--keep-new', action='store_true',
- dest='keep_new', help='keep only the converted \
- and move the old files')
- cmd.parser.add_option('-d', '--dest', action='store',
- help='set the destination directory')
- cmd.parser.add_option('-f', '--format', action='store', dest='format',
- help='set the target format of the tracks')
- cmd.parser.add_option('-y', '--yes', action='store_true', dest='yes',
- help='do not ask for confirmation')
- cmd.parser.add_option('-l', '--link', action='store_true', dest='link',
- help='symlink files that do not \
- need transcoding.')
- cmd.parser.add_option('-H', '--hardlink', action='store_true',
- dest='hardlink',
- help='hardlink files that do not \
- need transcoding. Overrides --link.')
+ cmd = ui.Subcommand("convert", help="convert to external location")
+ cmd.parser.add_option(
+ "-p",
+ "--pretend",
+ action="store_true",
+ help="show actions but do nothing",
+ )
+ cmd.parser.add_option(
+ "-t",
+ "--threads",
+ action="store",
+ type="int",
+ help="change the number of threads, \
+ defaults to maximum available processors",
+ )
+ cmd.parser.add_option(
+ "-k",
+ "--keep-new",
+ action="store_true",
+ dest="keep_new",
+ help="keep only the converted \
+ and move the old files",
+ )
+ cmd.parser.add_option(
+ "-d", "--dest", action="store", help="set the destination directory"
+ )
+ cmd.parser.add_option(
+ "-f",
+ "--format",
+ action="store",
+ dest="format",
+ help="set the target format of the tracks",
+ )
+ cmd.parser.add_option(
+ "-y",
+ "--yes",
+ action="store_true",
+ dest="yes",
+ help="do not ask for confirmation",
+ )
+ cmd.parser.add_option(
+ "-l",
+ "--link",
+ action="store_true",
+ dest="link",
+ help="symlink files that do not \
+ need transcoding.",
+ )
+ cmd.parser.add_option(
+ "-H",
+ "--hardlink",
+ action="store_true",
+ dest="hardlink",
+ help="hardlink files that do not \
+ need transcoding. Overrides --link.",
+ )
+ cmd.parser.add_option(
+ "-m",
+ "--playlist",
+ action="store",
+ help="""create an m3u8 playlist file containing
+ the converted files. The playlist file will be
+ saved below the destination directory, thus
+ PLAYLIST could be a file name or a relative path.
+ To ensure a working playlist when transferred to
+ a different computer, or opened from an external
+ drive, relative paths pointing to media files
+ will be used.""",
+ )
cmd.parser.add_album_option()
cmd.func = self.convert_func
return [cmd]
def auto_convert(self, config, task):
- if self.config['auto']:
- par_map(lambda item: self.convert_on_import(config.lib, item),
- task.imported_items())
+ if self.config["auto"]:
+ par_map(
+ lambda item: self.convert_on_import(config.lib, item),
+ task.imported_items(),
+ )
+
+ def auto_convert_keep(self, config, task):
+ if self.config["auto_keep"]:
+ empty_opts = self.commands()[0].parser.get_default_values()
+ (
+ dest,
+ threads,
+ path_formats,
+ fmt,
+ pretend,
+ hardlink,
+ link,
+ playlist,
+ ) = self._get_opts_and_config(empty_opts)
+
+ items = task.imported_items()
+ self._parallel_convert(
+ dest,
+ False,
+ path_formats,
+ fmt,
+ pretend,
+ link,
+ hardlink,
+ threads,
+ items,
+ )
# Utilities converted from functions to methods on logging overhaul
@@ -196,55 +273,70 @@ class ConvertPlugin(BeetsPlugin):
assert isinstance(source, bytes)
assert isinstance(dest, bytes)
- quiet = self.config['quiet'].get(bool)
+ quiet = self.config["quiet"].get(bool)
if not quiet and not pretend:
- self._log.info('Encoding {0}', util.displayable_path(source))
+ self._log.info("Encoding {0}", util.displayable_path(source))
- command = command.decode(arg_encoding(), 'surrogateescape')
- source = decode_commandline_path(source)
- dest = decode_commandline_path(dest)
+ command = command.decode(arg_encoding(), "surrogateescape")
+ source = os.fsdecode(source)
+ dest = os.fsdecode(dest)
# Substitute $source and $dest in the argument list.
args = shlex.split(command)
encode_cmd = []
for i, arg in enumerate(args):
- args[i] = Template(arg).safe_substitute({
- 'source': source,
- 'dest': dest,
- })
+ args[i] = Template(arg).safe_substitute(
+ {
+ "source": source,
+ "dest": dest,
+ }
+ )
encode_cmd.append(args[i].encode(util.arg_encoding()))
if pretend:
- self._log.info('{0}', ' '.join(ui.decargs(args)))
+ self._log.info("{0}", " ".join(ui.decargs(args)))
return
try:
util.command_output(encode_cmd)
except subprocess.CalledProcessError as exc:
# Something went wrong (probably Ctrl+C), remove temporary files
- self._log.info('Encoding {0} failed. Cleaning up...',
- util.displayable_path(source))
- self._log.debug('Command {0} exited with status {1}: {2}',
- args,
- exc.returncode,
- exc.output)
+ self._log.info(
+ "Encoding {0} failed. Cleaning up...",
+ util.displayable_path(source),
+ )
+ self._log.debug(
+ "Command {0} exited with status {1}: {2}",
+ args,
+ exc.returncode,
+ exc.output,
+ )
util.remove(dest)
util.prune_dirs(os.path.dirname(dest))
raise
except OSError as exc:
raise ui.UserError(
"convert: couldn't invoke '{}': {}".format(
- ' '.join(ui.decargs(args)), exc
+ " ".join(ui.decargs(args)), exc
)
)
if not quiet and not pretend:
- self._log.info('Finished encoding {0}',
- util.displayable_path(source))
+ self._log.info(
+ "Finished encoding {0}", util.displayable_path(source)
+ )
- def convert_item(self, dest_dir, keep_new, path_formats, fmt,
- pretend=False, link=False, hardlink=False):
+ def convert_item(
+ self,
+ dest_dir,
+ keep_new,
+ path_formats,
+ fmt,
+ pretend=False,
+ link=False,
+ hardlink=False,
+ ):
"""A pipeline thread that converts `Item` objects from a
library.
"""
@@ -252,8 +344,7 @@ class ConvertPlugin(BeetsPlugin):
item, original, converted = None, None, None
while True:
item = yield (item, original, converted)
- dest = item.destination(basedir=dest_dir,
- path_formats=path_formats)
+ dest = item.destination(basedir=dest_dir, path_formats=path_formats)
# When keeping the new file in the library, we first move the
# current (pristine) file to the destination. We'll then copy it
@@ -277,18 +368,23 @@ class ConvertPlugin(BeetsPlugin):
util.mkdirall(dest)
if os.path.exists(util.syspath(dest)):
- self._log.info('Skipping {0} (target file exists)',
- util.displayable_path(item.path))
+ self._log.info(
+ "Skipping {0} (target file exists)",
+ util.displayable_path(item.path),
+ )
continue
if keep_new:
if pretend:
- self._log.info('mv {0} {1}',
- util.displayable_path(item.path),
- util.displayable_path(original))
+ self._log.info(
+ "mv {0} {1}",
+ util.displayable_path(item.path),
+ util.displayable_path(original),
+ )
else:
- self._log.info('Moving to {0}',
- util.displayable_path(original))
+ self._log.info(
+ "Moving to {0}", util.displayable_path(original)
+ )
util.move(item.path, original)
if should_transcode(item, fmt):
@@ -300,20 +396,25 @@ class ConvertPlugin(BeetsPlugin):
else:
linked = link or hardlink
if pretend:
- msg = 'ln' if hardlink else ('ln -s' if link else 'cp')
+ msg = "ln" if hardlink else ("ln -s" if link else "cp")
- self._log.info('{2} {0} {1}',
- util.displayable_path(original),
- util.displayable_path(converted),
- msg)
+ self._log.info(
+ "{2} {0} {1}",
+ util.displayable_path(original),
+ util.displayable_path(converted),
+ msg,
+ )
else:
# No transcoding necessary.
- msg = 'Hardlinking' if hardlink \
- else ('Linking' if link else 'Copying')
+ msg = (
+ "Hardlinking"
+ if hardlink
+ else ("Linking" if link else "Copying")
+ )
- self._log.info('{1} {0}',
- util.displayable_path(item.path),
- msg)
+ self._log.info(
+ "{1} {0}", util.displayable_path(item.path), msg
+ )
if hardlink:
util.hardlink(original, converted)
@@ -325,8 +426,8 @@ class ConvertPlugin(BeetsPlugin):
if pretend:
continue
- id3v23 = self.config['id3v23'].as_choice([True, False, 'inherit'])
- if id3v23 == 'inherit':
+ id3v23 = self.config["id3v23"].as_choice([True, False, "inherit"])
+ if id3v23 == "inherit":
id3v23 = None
# Write tags from the database to the converted file.
@@ -339,23 +440,41 @@ class ConvertPlugin(BeetsPlugin):
item.read()
item.store() # Store new path and audio data.
- if self.config['embed'] and not linked:
+ if self.config["embed"] and not linked:
album = item._cached_album
if album and album.artpath:
- self._log.debug('embedding album art from {}',
- util.displayable_path(album.artpath))
- art.embed_item(self._log, item, album.artpath,
- itempath=converted, id3v23=id3v23)
+ maxwidth = self._get_art_resize(album.artpath)
+ self._log.debug(
+ "embedding album art from {}",
+ util.displayable_path(album.artpath),
+ )
+ art.embed_item(
+ self._log,
+ item,
+ album.artpath,
+ maxwidth,
+ itempath=converted,
+ id3v23=id3v23,
+ )
if keep_new:
- plugins.send('after_convert', item=item,
- dest=dest, keepnew=True)
+ plugins.send(
+ "after_convert", item=item, dest=dest, keepnew=True
+ )
else:
- plugins.send('after_convert', item=item,
- dest=converted, keepnew=False)
+ plugins.send(
+ "after_convert", item=item, dest=converted, keepnew=False
+ )
- def copy_album_art(self, album, dest_dir, path_formats, pretend=False,
- link=False, hardlink=False):
+ def copy_album_art(
+ self,
+ album,
+ dest_dir,
+ path_formats,
+ pretend=False,
+ link=False,
+ hardlink=False,
+ ):
"""Copies or converts the associated cover art of the album. Album must
have at least one track.
"""
@@ -369,8 +488,9 @@ class ConvertPlugin(BeetsPlugin):
# Get the destination of the first item (track) of the album, we use
# this function to format the path accordingly to path_formats.
- dest = album_item.destination(basedir=dest_dir,
- path_formats=path_formats)
+ dest = album_item.destination(
+ basedir=dest_dir, path_formats=path_formats
+ )
# Remove item from the path.
dest = os.path.join(*util.components(dest)[:-1])
@@ -383,46 +503,47 @@ class ConvertPlugin(BeetsPlugin):
util.mkdirall(dest)
if os.path.exists(util.syspath(dest)):
- self._log.info('Skipping {0} (target file exists)',
- util.displayable_path(album.artpath))
+ self._log.info(
+ "Skipping {0} (target file exists)",
+ util.displayable_path(album.artpath),
+ )
return
# Decide whether we need to resize the cover-art image.
- resize = False
- maxwidth = None
- if self.config['album_art_maxwidth']:
- maxwidth = self.config['album_art_maxwidth'].get(int)
- size = ArtResizer.shared.get_size(album.artpath)
- self._log.debug('image size: {}', size)
- if size:
- resize = size[0] > maxwidth
- else:
- self._log.warning('Could not get size of image (please see '
- 'documentation for dependencies).')
+ maxwidth = self._get_art_resize(album.artpath)
# Either copy or resize (while copying) the image.
- if resize:
- self._log.info('Resizing cover art from {0} to {1}',
- util.displayable_path(album.artpath),
- util.displayable_path(dest))
+ if maxwidth is not None:
+ self._log.info(
+ "Resizing cover art from {0} to {1}",
+ util.displayable_path(album.artpath),
+ util.displayable_path(dest),
+ )
if not pretend:
ArtResizer.shared.resize(maxwidth, album.artpath, dest)
else:
if pretend:
- msg = 'ln' if hardlink else ('ln -s' if link else 'cp')
+ msg = "ln" if hardlink else ("ln -s" if link else "cp")
- self._log.info('{2} {0} {1}',
- util.displayable_path(album.artpath),
- util.displayable_path(dest),
- msg)
+ self._log.info(
+ "{2} {0} {1}",
+ util.displayable_path(album.artpath),
+ util.displayable_path(dest),
+ msg,
+ )
else:
- msg = 'Hardlinking' if hardlink \
- else ('Linking' if link else 'Copying')
+ msg = (
+ "Hardlinking"
+ if hardlink
+ else ("Linking" if link else "Copying")
+ )
- self._log.info('{2} cover art from {0} to {1}',
- util.displayable_path(album.artpath),
- util.displayable_path(dest),
- msg)
+ self._log.info(
+ "{2} cover art from {0} to {1}",
+ util.displayable_path(album.artpath),
+ util.displayable_path(dest),
+ msg,
+ )
if hardlink:
util.hardlink(album.artpath, dest)
elif link:
@@ -431,79 +552,92 @@ class ConvertPlugin(BeetsPlugin):
util.copy(album.artpath, dest)
def convert_func(self, lib, opts, args):
- dest = opts.dest or self.config['dest'].get()
- if not dest:
- raise ui.UserError('no convert destination set')
- dest = util.bytestring_path(dest)
-
- threads = opts.threads or self.config['threads'].get(int)
-
- path_formats = ui.get_path_formats(self.config['paths'] or None)
-
- fmt = opts.format or self.config['format'].as_str().lower()
-
- if opts.pretend is not None:
- pretend = opts.pretend
- else:
- pretend = self.config['pretend'].get(bool)
-
- if opts.hardlink is not None:
- hardlink = opts.hardlink
- link = False
- elif opts.link is not None:
- hardlink = False
- link = opts.link
- else:
- hardlink = self.config['hardlink'].get(bool)
- link = self.config['link'].get(bool)
+ (
+ dest,
+ threads,
+ path_formats,
+ fmt,
+ pretend,
+ hardlink,
+ link,
+ playlist,
+ ) = self._get_opts_and_config(opts)
if opts.album:
albums = lib.albums(ui.decargs(args))
items = [i for a in albums for i in a.items()]
if not pretend:
for a in albums:
- ui.print_(format(a, ''))
+ ui.print_(format(a, ""))
else:
items = list(lib.items(ui.decargs(args)))
if not pretend:
for i in items:
- ui.print_(format(i, ''))
+ ui.print_(format(i, ""))
if not items:
- self._log.error('Empty query result.')
+ self._log.error("Empty query result.")
return
if not (pretend or opts.yes or ui.input_yn("Convert? (Y/n)")):
return
- if opts.album and self.config['copy_album_art']:
+ if opts.album and self.config["copy_album_art"]:
for album in albums:
- self.copy_album_art(album, dest, path_formats, pretend,
- link, hardlink)
+ self.copy_album_art(
+ album, dest, path_formats, pretend, link, hardlink
+ )
- convert = [self.convert_item(dest,
- opts.keep_new,
- path_formats,
- fmt,
- pretend,
- link,
- hardlink)
- for _ in range(threads)]
- pipe = util.pipeline.Pipeline([iter(items), convert])
- pipe.run_parallel()
+ self._parallel_convert(
+ dest,
+ opts.keep_new,
+ path_formats,
+ fmt,
+ pretend,
+ link,
+ hardlink,
+ threads,
+ items,
+ )
+
+ if playlist:
+ # Playlist paths are understood as relative to the dest directory.
+ pl_normpath = util.normpath(playlist)
+ pl_dir = os.path.dirname(pl_normpath)
+ self._log.info("Creating playlist file {0}", pl_normpath)
+ # Generates a list of paths to media files, ensures the paths are
+ # relative to the playlist's location and translates the unicode
+ # strings we get from item.destination to bytes.
+ items_paths = [
+ os.path.relpath(
+ util.bytestring_path(
+ item.destination(
+ basedir=dest,
+ path_formats=path_formats,
+ fragment=False,
+ )
+ ),
+ pl_dir,
+ )
+ for item in items
+ ]
+ if not pretend:
+ m3ufile = M3UFile(playlist)
+ m3ufile.set_contents(items_paths)
+ m3ufile.write()
def convert_on_import(self, lib, item):
"""Transcode a file automatically after it is imported into the
library.
"""
- fmt = self.config['format'].as_str().lower()
+ fmt = self.config["format"].as_str().lower()
if should_transcode(item, fmt):
command, ext = get_format()
# Create a temporary file for the conversion.
- tmpdir = self.config['tmpdir'].get()
+ tmpdir = self.config["tmpdir"].get()
if tmpdir:
- tmpdir = util.py3_path(util.bytestring_path(tmpdir))
- fd, dest = tempfile.mkstemp(util.py3_path(b'.' + ext), dir=tmpdir)
+ tmpdir = os.fsdecode(util.bytestring_path(tmpdir))
+ fd, dest = tempfile.mkstemp(os.fsdecode(b"." + ext), dir=tmpdir)
os.close(fd)
dest = util.bytestring_path(dest)
_temp_files.append(dest) # Delete the transcode later.
@@ -522,13 +656,107 @@ class ConvertPlugin(BeetsPlugin):
item.read() # Load new audio information data.
item.store()
- if self.config['delete_originals']:
- self._log.info('Removing original file {0}', source_path)
+ if self.config["delete_originals"]:
+ self._log.log(
+ logging.DEBUG if self.config["quiet"] else logging.INFO,
+ "Removing original file {0}",
+ source_path,
+ )
util.remove(source_path, False)
+ def _get_art_resize(self, artpath):
+ """For a given piece of album art, determine whether or not it needs
+ to be resized according to the user's settings. If so, returns the
+ new size. If not, returns None.
+ """
+ newwidth = None
+ if self.config["album_art_maxwidth"]:
+ maxwidth = self.config["album_art_maxwidth"].get(int)
+ size = ArtResizer.shared.get_size(artpath)
+ self._log.debug("image size: {}", size)
+ if size:
+ if size[0] > maxwidth:
+ newwidth = maxwidth
+ else:
+ self._log.warning(
+ "Could not get size of image (please see "
+ "documentation for dependencies)."
+ )
+ return newwidth
+
def _cleanup(self, task, session):
for path in task.old_paths:
if path in _temp_files:
- if os.path.isfile(path):
+ if os.path.isfile(util.syspath(path)):
util.remove(path)
_temp_files.remove(path)
+
+ def _get_opts_and_config(self, opts):
+ """Returns parameters needed for convert function.
+ Get parameters from command line if available,
+ default to config if not available.
+ """
+ dest = opts.dest or self.config["dest"].get()
+ if not dest:
+ raise ui.UserError("no convert destination set")
+ dest = util.bytestring_path(dest)
+
+ threads = opts.threads or self.config["threads"].get(int)
+
+ path_formats = ui.get_path_formats(self.config["paths"] or None)
+
+ fmt = opts.format or self.config["format"].as_str().lower()
+
+ playlist = opts.playlist or self.config["playlist"].get()
+ if playlist is not None:
+ playlist = os.path.join(dest, util.bytestring_path(playlist))
+
+ if opts.pretend is not None:
+ pretend = opts.pretend
+ else:
+ pretend = self.config["pretend"].get(bool)
+
+ if opts.hardlink is not None:
+ hardlink = opts.hardlink
+ link = False
+ elif opts.link is not None:
+ hardlink = False
+ link = opts.link
+ else:
+ hardlink = self.config["hardlink"].get(bool)
+ link = self.config["link"].get(bool)
+
+ return (
+ dest,
+ threads,
+ path_formats,
+ fmt,
+ pretend,
+ hardlink,
+ link,
+ playlist,
+ )
+
+ def _parallel_convert(
+ self,
+ dest,
+ keep_new,
+ path_formats,
+ fmt,
+ pretend,
+ link,
+ hardlink,
+ threads,
+ items,
+ ):
+ """Run the convert_item function for every items on as many thread as
+ defined in threads
+ """
+ convert = [
+ self.convert_item(
+ dest, keep_new, path_formats, fmt, pretend, link, hardlink
+ )
+ for _ in range(threads)
+ ]
+ pipe = util.pipeline.Pipeline([iter(items), convert])
+ pipe.run_parallel()
diff --git a/lib/beetsplug/deezer.py b/lib/beetsplug/deezer.py
index 5f158f93..a861ea0e 100644
--- a/lib/beetsplug/deezer.py
+++ b/lib/beetsplug/deezer.py
@@ -16,32 +16,66 @@
"""
import collections
+import time
-import unidecode
import requests
+import unidecode
from beets import ui
from beets.autotag import AlbumInfo, TrackInfo
-from beets.plugins import MetadataSourcePlugin, BeetsPlugin
+from beets.dbcore import types
+from beets.library import DateType
+from beets.plugins import BeetsPlugin, MetadataSourcePlugin
+from beets.util.id_extractors import deezer_id_regex
class DeezerPlugin(MetadataSourcePlugin, BeetsPlugin):
- data_source = 'Deezer'
+ data_source = "Deezer"
+
+ item_types = {
+ "deezer_track_rank": types.INTEGER,
+ "deezer_track_id": types.INTEGER,
+ "deezer_updated": DateType(),
+ }
# Base URLs for the Deezer API
# Documentation: https://developers.deezer.com/api/
- search_url = 'https://api.deezer.com/search/'
- album_url = 'https://api.deezer.com/album/'
- track_url = 'https://api.deezer.com/track/'
+ search_url = "https://api.deezer.com/search/"
+ album_url = "https://api.deezer.com/album/"
+ track_url = "https://api.deezer.com/track/"
- id_regex = {
- 'pattern': r'(^|deezer\.com/)([a-z]*/)?({}/)?(\d+)',
- 'match_group': 4,
- }
+ id_regex = deezer_id_regex
def __init__(self):
super().__init__()
+ def commands(self):
+ """Add beet UI commands to interact with Deezer."""
+ deezer_update_cmd = ui.Subcommand(
+ "deezerupdate", help=f"Update {self.data_source} rank"
+ )
+
+ def func(lib, opts, args):
+ items = lib.items(ui.decargs(args))
+ self.deezerupdate(items, ui.should_write())
+
+ deezer_update_cmd.func = func
+
+ return [deezer_update_cmd]
+
+ def fetch_data(self, url):
+ try:
+ response = requests.get(url, timeout=10)
+ response.raise_for_status()
+ data = response.json()
+ except requests.exceptions.RequestException as e:
+ self._log.error("Error fetching data from {}\n Error: {}", url, e)
+ return None
+ if "error" in data:
+ self._log.error("Deezer API error: {}", data["error"]["message"])
+ return None
+ return data
+
def album_for_id(self, album_id):
"""Fetch an album by its Deezer ID or URL and return an
AlbumInfo object or None if the album is not found.
@@ -51,15 +85,20 @@ class DeezerPlugin(MetadataSourcePlugin, BeetsPlugin):
:return: AlbumInfo object for album.
:rtype: beets.autotag.hooks.AlbumInfo or None
"""
- deezer_id = self._get_id('album', album_id)
+ deezer_id = self._get_id("album", album_id, self.id_regex)
if deezer_id is None:
return None
+ album_data = self.fetch_data(self.album_url + deezer_id)
+ if album_data is None:
+ return None
+ contributors = album_data.get("contributors")
+ if contributors is not None:
+ artist, artist_id = self.get_artist(contributors)
+ else:
+ artist, artist_id = None, None
- album_data = requests.get(self.album_url + deezer_id).json()
- artist, artist_id = self.get_artist(album_data['contributors'])
-
- release_date = album_data['release_date']
- date_parts = [int(part) for part in release_date.split('-')]
+ release_date = album_data["release_date"]
+ date_parts = [int(part) for part in release_date.split("-")]
num_date_parts = len(date_parts)
if num_date_parts == 3:
@@ -76,12 +115,23 @@ class DeezerPlugin(MetadataSourcePlugin, BeetsPlugin):
"Invalid `release_date` returned "
"by {} API: '{}'".format(self.data_source, release_date)
)
-
- tracks_data = requests.get(
- self.album_url + deezer_id + '/tracks'
- ).json()['data']
+ tracks_obj = self.fetch_data(self.album_url + deezer_id + "/tracks")
+ if tracks_obj is None:
+ return None
+ try:
+ tracks_data = tracks_obj["data"]
+ except KeyError:
+ self._log.debug("Error fetching album tracks for {}", deezer_id)
+ tracks_data = None
if not tracks_data:
return None
+ while "next" in tracks_obj:
+ tracks_obj = requests.get(
+ tracks_obj["next"],
+ timeout=10,
+ ).json()
+ tracks_data.extend(tracks_obj["data"])
+
tracks = []
medium_totals = collections.defaultdict(int)
for i, track_data in enumerate(tracks_data, start=1):
@@ -93,22 +143,24 @@ class DeezerPlugin(MetadataSourcePlugin, BeetsPlugin):
track.medium_total = medium_totals[track.medium]
return AlbumInfo(
- album=album_data['title'],
+ album=album_data["title"],
album_id=deezer_id,
+ deezer_album_id=deezer_id,
artist=artist,
- artist_credit=self.get_artist([album_data['artist']])[0],
+ artist_credit=self.get_artist([album_data["artist"]])[0],
artist_id=artist_id,
tracks=tracks,
- albumtype=album_data['record_type'],
- va=len(album_data['contributors']) == 1
- and artist.lower() == 'various artists',
+ albumtype=album_data["record_type"],
+ va=len(album_data["contributors"]) == 1
+ and artist.lower() == "various artists",
year=year,
month=month,
day=day,
- label=album_data['label'],
+ label=album_data["label"],
mediums=max(medium_totals.keys()),
data_source=self.data_source,
- data_url=album_data['link'],
+ data_url=album_data["link"],
+ cover_art_url=album_data.get("cover_xl"),
)
def _get_track(self, track_data):
@@ -120,19 +172,23 @@ class DeezerPlugin(MetadataSourcePlugin, BeetsPlugin):
:rtype: beets.autotag.hooks.TrackInfo
"""
artist, artist_id = self.get_artist(
- track_data.get('contributors', [track_data['artist']])
+ track_data.get("contributors", [track_data["artist"]])
)
return TrackInfo(
- title=track_data['title'],
- track_id=track_data['id'],
+ title=track_data["title"],
+ track_id=track_data["id"],
+ deezer_track_id=track_data["id"],
+ isrc=track_data.get("isrc"),
artist=artist,
artist_id=artist_id,
- length=track_data['duration'],
- index=track_data['track_position'],
- medium=track_data['disk_number'],
- medium_index=track_data['track_position'],
+ length=track_data["duration"],
+ index=track_data.get("track_position"),
+ medium=track_data.get("disk_number"),
+ deezer_track_rank=track_data.get("rank"),
+ medium_index=track_data.get("track_position"),
data_source=self.data_source,
- data_url=track_data['link'],
+ data_url=track_data["link"],
+ deezer_updated=time.time(),
)
def track_for_id(self, track_id=None, track_data=None):
@@ -149,29 +205,40 @@ class DeezerPlugin(MetadataSourcePlugin, BeetsPlugin):
:rtype: beets.autotag.hooks.TrackInfo or None
"""
if track_data is None:
- deezer_id = self._get_id('track', track_id)
+ deezer_id = self._get_id("track", track_id, self.id_regex)
if deezer_id is None:
return None
- track_data = requests.get(self.track_url + deezer_id).json()
+ track_data = self.fetch_data(self.track_url + deezer_id)
+ if track_data is None:
+ return None
track = self._get_track(track_data)
# Get album's tracks to set `track.index` (position on the entire
# release) and `track.medium_total` (total number of tracks on
# the track's disc).
- album_tracks_data = requests.get(
- self.album_url + str(track_data['album']['id']) + '/tracks'
- ).json()['data']
+ album_tracks_obj = self.fetch_data(
+ self.album_url + str(track_data["album"]["id"]) + "/tracks"
+ )
+ if album_tracks_obj is None:
+ return None
+ try:
+ album_tracks_data = album_tracks_obj["data"]
+ except KeyError:
+ self._log.debug(
+ "Error fetching album tracks for {}", track_data["album"]["id"]
+ )
+ return None
medium_total = 0
for i, track_data in enumerate(album_tracks_data, start=1):
- if track_data['disk_number'] == track.medium:
+ if track_data["disk_number"] == track.medium:
medium_total += 1
- if track_data['id'] == track.track_id:
+ if track_data["id"] == track.track_id:
track.index = i
track.medium_total = medium_total
return track
@staticmethod
- def _construct_search_query(filters=None, keywords=''):
+ def _construct_search_query(filters=None, keywords=""):
"""Construct a query string with the specified filters and keywords to
be provided to the Deezer Search API
(https://developers.deezer.com/api/search).
@@ -185,14 +252,14 @@ class DeezerPlugin(MetadataSourcePlugin, BeetsPlugin):
"""
query_components = [
keywords,
- ' '.join(f'{k}:"{v}"' for k, v in filters.items()),
+ " ".join(f'{k}:"{v}"' for k, v in filters.items()),
]
- query = ' '.join([q for q in query_components if q])
+ query = " ".join([q for q in query_components if q])
if not isinstance(query, str):
- query = query.decode('utf8')
+ query = query.decode("utf8")
return unidecode.unidecode(query)
- def _search_api(self, query_type, filters=None, keywords=''):
+ def _search_api(self, query_type, filters=None, keywords=""):
"""Query the Deezer Search API for the specified ``keywords``, applying
the provided ``filters``.
@@ -208,19 +275,17 @@ class DeezerPlugin(MetadataSourcePlugin, BeetsPlugin):
if no search results are returned.
:rtype: dict or None
"""
- query = self._construct_search_query(
- keywords=keywords, filters=filters
- )
+ query = self._construct_search_query(keywords=keywords, filters=filters)
if not query:
return None
- self._log.debug(
- f"Searching {self.data_source} for '{query}'"
- )
+ self._log.debug(f"Searching {self.data_source} for '{query}'")
response = requests.get(
- self.search_url + query_type, params={'q': query}
+ self.search_url + query_type,
+ params={"q": query},
+ timeout=10,
)
response.raise_for_status()
- response_data = response.json().get('data', [])
+ response_data = response.json().get("data", [])
self._log.debug(
"Found {} result(s) from {} for '{}'",
len(response_data),
@@ -228,3 +293,30 @@ class DeezerPlugin(MetadataSourcePlugin, BeetsPlugin):
query,
)
return response_data
+
+ def deezerupdate(self, items, write):
+ """Obtain rank information from Deezer."""
+ for index, item in enumerate(items, start=1):
+ self._log.info(
+ "Processing {}/{} tracks - {} ", index, len(items), item
+ )
+ try:
+ deezer_track_id = item.deezer_track_id
+ except AttributeError:
+ self._log.debug("No deezer_track_id present for: {}", item)
+ continue
+ try:
+ rank = self.fetch_data(
+ f"{self.track_url}{deezer_track_id}"
+ ).get("rank")
+ self._log.debug(
+ "Deezer track: {} has {} rank", deezer_track_id, rank
+ )
+ except Exception as e:
+ self._log.debug("Invalid Deezer track_id: {}", e)
+ continue
+ item.deezer_track_rank = int(rank)
+ item.store()
+ item.deezer_updated = time.time()
+ if write:
+ item.try_write()
diff --git a/lib/beetsplug/discogs.py b/lib/beetsplug/discogs.py
index d015e420..344d67a2 100644
--- a/lib/beetsplug/discogs.py
+++ b/lib/beetsplug/discogs.py
@@ -16,62 +16,87 @@
python3-discogs-client library.
"""
-import beets.ui
-from beets import config
-from beets.autotag.hooks import AlbumInfo, TrackInfo
-from beets.plugins import MetadataSourcePlugin, BeetsPlugin, get_distance
-import confuse
-from discogs_client import Release, Master, Client
-from discogs_client.exceptions import DiscogsAPIError
-from requests.exceptions import ConnectionError
import http.client
-import beets
-import re
-import time
import json
-import socket
import os
+import re
+import socket
+import time
import traceback
from string import ascii_lowercase
+import confuse
+from discogs_client import Client, Master, Release
+from discogs_client import __version__ as dc_string
+from discogs_client.exceptions import DiscogsAPIError
+from requests.exceptions import ConnectionError
-USER_AGENT = f'beets/{beets.__version__} +https://beets.io/'
-API_KEY = 'rAzVUQYRaoFjeBjyWuWZ'
-API_SECRET = 'plxtUTqoCzwxZpqdPysCwGuBSmZNdZVy'
+import beets
+import beets.ui
+from beets import config
+from beets.autotag.hooks import AlbumInfo, TrackInfo, string_dist
+from beets.plugins import BeetsPlugin, MetadataSourcePlugin, get_distance
+from beets.util.id_extractors import extract_discogs_id_regex
+
+USER_AGENT = f"beets/{beets.__version__} +https://beets.io/"
+API_KEY = "rAzVUQYRaoFjeBjyWuWZ"
+API_SECRET = "plxtUTqoCzwxZpqdPysCwGuBSmZNdZVy"
# Exceptions that discogs_client should really handle but does not.
-CONNECTION_ERRORS = (ConnectionError, socket.error, http.client.HTTPException,
- ValueError, # JSON decoding raises a ValueError.
- DiscogsAPIError)
+CONNECTION_ERRORS = (
+ ConnectionError,
+ socket.error,
+ http.client.HTTPException,
+ ValueError, # JSON decoding raises a ValueError.
+ DiscogsAPIError,
+)
class DiscogsPlugin(BeetsPlugin):
-
def __init__(self):
super().__init__()
- self.config.add({
- 'apikey': API_KEY,
- 'apisecret': API_SECRET,
- 'tokenfile': 'discogs_token.json',
- 'source_weight': 0.5,
- 'user_token': '',
- 'separator': ', ',
- 'index_tracks': False,
- })
- self.config['apikey'].redact = True
- self.config['apisecret'].redact = True
- self.config['user_token'].redact = True
+ self.check_discogs_client()
+ self.config.add(
+ {
+ "apikey": API_KEY,
+ "apisecret": API_SECRET,
+ "tokenfile": "discogs_token.json",
+ "source_weight": 0.5,
+ "user_token": "",
+ "separator": ", ",
+ "index_tracks": False,
+ "append_style_genre": False,
+ }
+ )
+ self.config["apikey"].redact = True
+ self.config["apisecret"].redact = True
+ self.config["user_token"].redact = True
self.discogs_client = None
- self.register_listener('import_begin', self.setup)
+ self.register_listener("import_begin", self.setup)
+
+ def check_discogs_client(self):
+ """Ensure python3-discogs-client version >= 2.3.15"""
+ dc_min_version = [2, 3, 15]
+ dc_version = [int(elem) for elem in dc_string.split(".")]
+ min_len = min(len(dc_version), len(dc_min_version))
+ gt_min = [
+ (elem > elem_min)
+ for elem, elem_min in zip(
+ dc_version[:min_len], dc_min_version[:min_len]
+ )
+ ]
+ if True not in gt_min:
+ self._log.warning(
+ "python3-discogs-client version should be >= 2.3.15"
+ )
def setup(self, session=None):
- """Create the `discogs_client` field. Authenticate if necessary.
- """
- c_key = self.config['apikey'].as_str()
- c_secret = self.config['apisecret'].as_str()
+ """Create the `discogs_client` field. Authenticate if necessary."""
+ c_key = self.config["apikey"].as_str()
+ c_secret = self.config["apisecret"].as_str()
# Try using a configured user token (bypassing OAuth login).
- user_token = self.config['user_token'].as_str()
+ user_token = self.config["user_token"].as_str()
if user_token:
# The rate limit for authenticated users goes up to 60
# requests per minute.
@@ -86,22 +111,19 @@ class DiscogsPlugin(BeetsPlugin):
# No token yet. Generate one.
token, secret = self.authenticate(c_key, c_secret)
else:
- token = tokendata['token']
- secret = tokendata['secret']
+ token = tokendata["token"]
+ secret = tokendata["secret"]
- self.discogs_client = Client(USER_AGENT, c_key, c_secret,
- token, secret)
+ self.discogs_client = Client(USER_AGENT, c_key, c_secret, token, secret)
def reset_auth(self):
- """Delete token file & redo the auth steps.
- """
+ """Delete token file & redo the auth steps."""
os.remove(self._tokenfile())
self.setup()
def _tokenfile(self):
- """Get the path to the JSON file for storing the OAuth token.
- """
- return self.config['tokenfile'].get(confuse.Filename(in_app_dir=True))
+ """Get the path to the JSON file for storing the OAuth token."""
+ return self.config["tokenfile"].get(confuse.Filename(in_app_dir=True))
def authenticate(self, c_key, c_secret):
# Get the link for the OAuth page.
@@ -109,8 +131,8 @@ class DiscogsPlugin(BeetsPlugin):
try:
_, _, url = auth_client.get_authorize_url()
except CONNECTION_ERRORS as e:
- self._log.debug('connection error: {0}', e)
- raise beets.ui.UserError('communication with Discogs failed')
+ self._log.debug("connection error: {0}", e)
+ raise beets.ui.UserError("communication with Discogs failed")
beets.ui.print_("To authenticate with Discogs, visit:")
beets.ui.print_(url)
@@ -120,34 +142,28 @@ class DiscogsPlugin(BeetsPlugin):
try:
token, secret = auth_client.get_access_token(code)
except DiscogsAPIError:
- raise beets.ui.UserError('Discogs authorization failed')
+ raise beets.ui.UserError("Discogs authorization failed")
except CONNECTION_ERRORS as e:
- self._log.debug('connection error: {0}', e)
- raise beets.ui.UserError('Discogs token request failed')
+ self._log.debug("connection error: {0}", e)
+ raise beets.ui.UserError("Discogs token request failed")
# Save the token for later use.
- self._log.debug('Discogs token {0}, secret {1}', token, secret)
- with open(self._tokenfile(), 'w') as f:
- json.dump({'token': token, 'secret': secret}, f)
+ self._log.debug("Discogs token {0}, secret {1}", token, secret)
+ with open(self._tokenfile(), "w") as f:
+ json.dump({"token": token, "secret": secret}, f)
return token, secret
def album_distance(self, items, album_info, mapping):
- """Returns the album distance.
- """
+ """Returns the album distance."""
return get_distance(
- data_source='Discogs',
- info=album_info,
- config=self.config
+ data_source="Discogs", info=album_info, config=self.config
)
def track_distance(self, item, track_info):
- """Returns the track distance.
- """
+ """Returns the track distance."""
return get_distance(
- data_source='Discogs',
- info=track_info,
- config=self.config
+ data_source="Discogs", info=track_info, config=self.config
)
def candidates(self, items, artist, album, va_likely, extra_tags=None):
@@ -157,48 +173,110 @@ class DiscogsPlugin(BeetsPlugin):
if not self.discogs_client:
return
+ if not album and not artist:
+ self._log.debug(
+ "Skipping Discogs query. Files missing album and "
+ "artist tags."
+ )
+ return []
+
if va_likely:
query = album
else:
- query = f'{artist} {album}'
+ query = f"{artist} {album}"
try:
return self.get_albums(query)
except DiscogsAPIError as e:
- self._log.debug('API Error: {0} (query: {1})', e, query)
+ self._log.debug("API Error: {0} (query: {1})", e, query)
if e.status_code == 401:
self.reset_auth()
return self.candidates(items, artist, album, va_likely)
else:
return []
except CONNECTION_ERRORS:
- self._log.debug('Connection error in album search', exc_info=True)
+ self._log.debug("Connection error in album search", exc_info=True)
return []
- @staticmethod
- def extract_release_id_regex(album_id):
- """Returns the Discogs_id or None."""
- # Discogs-IDs are simple integers. In order to avoid confusion with
- # other metadata plugins, we only look for very specific formats of the
- # input string:
- # - plain integer, optionally wrapped in brackets and prefixed by an
- # 'r', as this is how discogs displays the release ID on its webpage.
- # - legacy url format: discogs.com//release/
- # - current url format: discogs.com/release/-
- # See #291, #4080 and #4085 for the discussions leading up to these
- # patterns.
- # Regex has been tested here https://regex101.com/r/wyLdB4/2
+ def get_track_from_album_by_title(
+ self, album_info, title, dist_threshold=0.3
+ ):
+ def compare_func(track_info):
+ track_title = getattr(track_info, "title", None)
+ dist = string_dist(track_title, title)
+ return track_title and dist < dist_threshold
- for pattern in [
- r'^\[?r?(?P\d+)\]?$',
- r'discogs\.com/release/(?P\d+)-',
- r'discogs\.com/[^/]+/release/(?P\d+)',
- ]:
- match = re.search(pattern, album_id)
- if match:
- return int(match.group('id'))
+ return self.get_track_from_album(album_info, compare_func)
+
+ def get_track_from_album(self, album_info, compare_func):
+ """Return the first track of the release where `compare_func` returns
+ true.
+
+ :return: TrackInfo object.
+ :rtype: beets.autotag.hooks.TrackInfo
+ """
+ if not album_info:
+ return None
+
+ for track_info in album_info.tracks:
+ # check for matching position
+ if not compare_func(track_info):
+ continue
+
+ # attach artist info if not provided
+ if not track_info["artist"]:
+ track_info["artist"] = album_info.artist
+ track_info["artist_id"] = album_info.artist_id
+ # attach album info
+ track_info["album"] = album_info.album
+
+ return track_info
return None
+ def item_candidates(self, item, artist, title):
+ """Returns a list of TrackInfo objects for Search API results
+ matching ``title`` and ``artist``.
+ :param item: Singleton item to be matched.
+ :type item: beets.library.Item
+ :param artist: The artist of the track to be matched.
+ :type artist: str
+ :param title: The title of the track to be matched.
+ :type title: str
+ :return: Candidate TrackInfo objects.
+ :rtype: list[beets.autotag.hooks.TrackInfo]
+ """
+ if not self.discogs_client:
+ return []
+
+ if not artist and not title:
+ self._log.debug(
+ "Skipping Discogs query. File missing artist and " "title tags."
+ )
+ return []
+
+ query = f"{artist} {title}"
+ try:
+ albums = self.get_albums(query)
+ except DiscogsAPIError as e:
+ self._log.debug("API Error: {0} (query: {1})", e, query)
+ if e.status_code == 401:
+ self.reset_auth()
+ return self.item_candidates(item, artist, title)
+ else:
+ return []
+ except CONNECTION_ERRORS:
+ self._log.debug("Connection error in track search", exc_info=True)
+ candidates = []
+ for album_cur in albums:
+ self._log.debug("searching within album {0}", album_cur.album)
+ track_result = self.get_track_from_album_by_title(
+ album_cur, item["title"]
+ )
+ if track_result:
+ candidates.append(track_result)
+ # first 10 results, don't overwhelm with options
+ return candidates[:10]
+
def album_for_id(self, album_id):
"""Fetches an album by its Discogs ID and returns an AlbumInfo object
or None if the album is not found.
@@ -206,83 +284,90 @@ class DiscogsPlugin(BeetsPlugin):
if not self.discogs_client:
return
- self._log.debug('Searching for release {0}', album_id)
+ self._log.debug("Searching for release {0}", album_id)
- discogs_id = self.extract_release_id_regex(album_id)
+ discogs_id = extract_discogs_id_regex(album_id)
if not discogs_id:
return None
- result = Release(self.discogs_client, {'id': discogs_id})
+ result = Release(self.discogs_client, {"id": discogs_id})
# Try to obtain title to verify that we indeed have a valid Release
try:
- getattr(result, 'title')
+ getattr(result, "title")
except DiscogsAPIError as e:
if e.status_code != 404:
- self._log.debug('API Error: {0} (query: {1})', e,
- result.data['resource_url'])
+ self._log.debug(
+ "API Error: {0} (query: {1})",
+ e,
+ result.data["resource_url"],
+ )
if e.status_code == 401:
self.reset_auth()
return self.album_for_id(album_id)
return None
except CONNECTION_ERRORS:
- self._log.debug('Connection error in album lookup',
- exc_info=True)
+ self._log.debug("Connection error in album lookup", exc_info=True)
return None
return self.get_album_info(result)
def get_albums(self, query):
- """Returns a list of AlbumInfo objects for a discogs search query.
- """
+ """Returns a list of AlbumInfo objects for a discogs search query."""
# Strip non-word characters from query. Things like "!" and "-" can
# cause a query to return no results, even if they match the artist or
# album title. Use `re.UNICODE` flag to avoid stripping non-english
# word characters.
- query = re.sub(r'(?u)\W+', ' ', query)
+ query = re.sub(r"(?u)\W+", " ", query)
# Strip medium information from query, Things like "CD1" and "disk 1"
# can also negate an otherwise positive result.
- query = re.sub(r'(?i)\b(CD|disc)\s*\d+', '', query)
+ query = re.sub(r"(?i)\b(CD|disc|vinyl)\s*\d+", "", query)
try:
- releases = self.discogs_client.search(query,
- type='release').page(1)
+ releases = self.discogs_client.search(query, type="release").page(1)
except CONNECTION_ERRORS:
- self._log.debug("Communication error while searching for {0!r}",
- query, exc_info=True)
+ self._log.debug(
+ "Communication error while searching for {0!r}",
+ query,
+ exc_info=True,
+ )
return []
- return [album for album in map(self.get_album_info, releases[:5])
- if album]
+ return [
+ album for album in map(self.get_album_info, releases[:5]) if album
+ ]
def get_master_year(self, master_id):
"""Fetches a master release given its Discogs ID and returns its year
or None if the master release is not found.
"""
- self._log.debug('Searching for master release {0}', master_id)
- result = Master(self.discogs_client, {'id': master_id})
+ self._log.debug("Searching for master release {0}", master_id)
+ result = Master(self.discogs_client, {"id": master_id})
try:
- year = result.fetch('year')
+ year = result.fetch("year")
return year
except DiscogsAPIError as e:
if e.status_code != 404:
- self._log.debug('API Error: {0} (query: {1})', e,
- result.data['resource_url'])
+ self._log.debug(
+ "API Error: {0} (query: {1})",
+ e,
+ result.data["resource_url"],
+ )
if e.status_code == 401:
self.reset_auth()
return self.get_master_year(master_id)
return None
except CONNECTION_ERRORS:
- self._log.debug('Connection error in master release lookup',
- exc_info=True)
+ self._log.debug(
+ "Connection error in master release lookup", exc_info=True
+ )
return None
def get_album_info(self, result):
- """Returns an AlbumInfo object for a discogs Release object.
- """
+ """Returns an AlbumInfo object for a discogs Release object."""
# Explicitly reload the `Release` fields, as they might not be yet
# present if the result is from a `discogs_client.search()`.
- if not result.data.get('artists'):
+ if not result.data.get("artists"):
result.refresh()
# Sanity check for required fields. The list of required fields is
@@ -290,99 +375,138 @@ class DiscogsPlugin(BeetsPlugin):
# lacking some of these fields. This function expects at least:
# `artists` (>0), `title`, `id`, `tracklist` (>0)
# https://www.discogs.com/help/doc/submission-guidelines-general-rules
- if not all([result.data.get(k) for k in ['artists', 'title', 'id',
- 'tracklist']]):
+ if not all(
+ [
+ result.data.get(k)
+ for k in ["artists", "title", "id", "tracklist"]
+ ]
+ ):
self._log.warning("Release does not contain the required fields")
return None
artist, artist_id = MetadataSourcePlugin.get_artist(
- [a.data for a in result.artists]
+ [a.data for a in result.artists], join_key="join"
)
- album = re.sub(r' +', ' ', result.title)
- album_id = result.data['id']
+ album = re.sub(r" +", " ", result.title)
+ album_id = result.data["id"]
# Use `.data` to access the tracklist directly instead of the
# convenient `.tracklist` property, which will strip out useful artist
# information and leave us with skeleton `Artist` objects that will
# each make an API call just to get the same data back.
- tracks = self.get_tracks(result.data['tracklist'])
+ tracks = self.get_tracks(result.data["tracklist"])
# Extract information for the optional AlbumInfo fields, if possible.
- va = result.data['artists'][0].get('name', '').lower() == 'various'
- year = result.data.get('year')
+ va = result.data["artists"][0].get("name", "").lower() == "various"
+ year = result.data.get("year")
mediums = [t.medium for t in tracks]
- country = result.data.get('country')
- data_url = result.data.get('uri')
- style = self.format(result.data.get('styles'))
- genre = self.format(result.data.get('genres'))
- discogs_albumid = self.extract_release_id(result.data.get('uri'))
+ country = result.data.get("country")
+ data_url = result.data.get("uri")
+ style = self.format(result.data.get("styles"))
+ base_genre = self.format(result.data.get("genres"))
+
+ if self.config["append_style_genre"] and style:
+ genre = self.config["separator"].as_str().join([base_genre, style])
+ else:
+ genre = base_genre
+
+ discogs_albumid = extract_discogs_id_regex(result.data.get("uri"))
# Extract information for the optional AlbumInfo fields that are
# contained on nested discogs fields.
albumtype = media = label = catalogno = labelid = None
- if result.data.get('formats'):
- albumtype = ', '.join(
- result.data['formats'][0].get('descriptions', [])) or None
- media = result.data['formats'][0]['name']
- if result.data.get('labels'):
- label = result.data['labels'][0].get('name')
- catalogno = result.data['labels'][0].get('catno')
- labelid = result.data['labels'][0].get('id')
+ if result.data.get("formats"):
+ albumtype = (
+ ", ".join(result.data["formats"][0].get("descriptions", []))
+ or None
+ )
+ media = result.data["formats"][0]["name"]
+ if result.data.get("labels"):
+ label = result.data["labels"][0].get("name")
+ catalogno = result.data["labels"][0].get("catno")
+ labelid = result.data["labels"][0].get("id")
+
+ cover_art_url = self.select_cover_art(result)
# Additional cleanups (various artists name, catalog number, media).
if va:
- artist = config['va_name'].as_str()
- if catalogno == 'none':
+ artist = config["va_name"].as_str()
+ if catalogno == "none":
catalogno = None
# Explicitly set the `media` for the tracks, since it is expected by
# `autotag.apply_metadata`, and set `medium_total`.
for track in tracks:
track.media = media
track.medium_total = mediums.count(track.medium)
+ if not track.artist: # get_track_info often fails to find artist
+ track.artist = artist
+ if not track.artist_id:
+ track.artist_id = artist_id
# Discogs does not have track IDs. Invent our own IDs as proposed
# in #2336.
track.track_id = str(album_id) + "-" + track.track_alt
+ track.data_url = data_url
+ track.data_source = "Discogs"
# Retrieve master release id (returns None if there isn't one).
- master_id = result.data.get('master_id')
+ master_id = result.data.get("master_id")
# Assume `original_year` is equal to `year` for releases without
# a master release, otherwise fetch the master release.
original_year = self.get_master_year(master_id) if master_id else year
- return AlbumInfo(album=album, album_id=album_id, artist=artist,
- artist_id=artist_id, tracks=tracks,
- albumtype=albumtype, va=va, year=year,
- label=label, mediums=len(set(mediums)),
- releasegroup_id=master_id, catalognum=catalogno,
- country=country, style=style, genre=genre,
- media=media, original_year=original_year,
- data_source='Discogs', data_url=data_url,
- discogs_albumid=discogs_albumid,
- discogs_labelid=labelid, discogs_artistid=artist_id)
+ return AlbumInfo(
+ album=album,
+ album_id=album_id,
+ artist=artist,
+ artist_id=artist_id,
+ tracks=tracks,
+ albumtype=albumtype,
+ va=va,
+ year=year,
+ label=label,
+ mediums=len(set(mediums)),
+ releasegroup_id=master_id,
+ catalognum=catalogno,
+ country=country,
+ style=style,
+ genre=genre,
+ media=media,
+ original_year=original_year,
+ data_source="Discogs",
+ data_url=data_url,
+ discogs_albumid=discogs_albumid,
+ discogs_labelid=labelid,
+ discogs_artistid=artist_id,
+ cover_art_url=cover_art_url,
+ )
+
+ def select_cover_art(self, result):
+ """Returns the best candidate image, if any, from a Discogs `Release` object."""
+ if result.data.get("images") and len(result.data.get("images")) > 0:
+ # The first image in this list appears to be the one displayed first
+ # on the release page - even if it is not flagged as `type: "primary"` - and
+ # so it is the best candidate for the cover art.
+ return result.data.get("images")[0].get("uri")
+
+ return None
def format(self, classification):
if classification:
- return self.config['separator'].as_str() \
- .join(sorted(classification))
- else:
- return None
-
- def extract_release_id(self, uri):
- if uri:
- return uri.split("/")[-1]
+ return (
+ self.config["separator"].as_str().join(sorted(classification))
+ )
else:
return None
def get_tracks(self, tracklist):
- """Returns a list of TrackInfo objects for a discogs tracklist.
- """
+ """Returns a list of TrackInfo objects for a discogs tracklist."""
try:
clean_tracklist = self.coalesce_tracks(tracklist)
except Exception as exc:
# FIXME: this is an extra precaution for making sure there are no
# side effects after #2222. It should be removed after further
# testing.
- self._log.debug('{}', traceback.format_exc())
- self._log.error('uncaught exception in coalesce_tracks: {}', exc)
+ self._log.debug("{}", traceback.format_exc())
+ self._log.error("uncaught exception in coalesce_tracks: {}", exc)
clean_tracklist = tracklist
tracks = []
index_tracks = {}
@@ -391,7 +515,7 @@ class DiscogsPlugin(BeetsPlugin):
divisions, next_divisions = [], []
for track in clean_tracklist:
# Only real tracks have `position`. Otherwise, it's an index track.
- if track['position']:
+ if track["position"]:
index += 1
if next_divisions:
# End of a block of index tracks: update the current
@@ -399,17 +523,17 @@ class DiscogsPlugin(BeetsPlugin):
divisions += next_divisions
del next_divisions[:]
track_info = self.get_track_info(track, index, divisions)
- track_info.track_alt = track['position']
+ track_info.track_alt = track["position"]
tracks.append(track_info)
else:
- next_divisions.append(track['title'])
+ next_divisions.append(track["title"])
# We expect new levels of division at the beginning of the
# tracklist (and possibly elsewhere).
try:
divisions.pop()
except IndexError:
pass
- index_tracks[index + 1] = track['title']
+ index_tracks[index + 1] = track["title"]
# Fix up medium and medium_index for each track. Discogs position is
# unreliable, but tracks are in order.
@@ -423,7 +547,7 @@ class DiscogsPlugin(BeetsPlugin):
m = sorted({track.medium.lower() for track in tracks})
# If all track.medium are single consecutive letters, assume it is
# a 2-sided medium.
- if ''.join(m) in ascii_lowercase:
+ if "".join(m) in ascii_lowercase:
sides_per_medium = 2
for track in tracks:
@@ -433,10 +557,15 @@ class DiscogsPlugin(BeetsPlugin):
# are the track index, not the medium.
# side_count is the number of mediums or medium sides (in the case
# of two-sided mediums) that were seen before.
- medium_is_index = track.medium and not track.medium_index and (
- len(track.medium) != 1 or
- # Not within standard incremental medium values (A, B, C, ...).
- ord(track.medium) - 64 != side_count + 1
+ medium_is_index = (
+ track.medium
+ and not track.medium_index
+ and (
+ len(track.medium) != 1
+ or
+ # Not within standard incremental medium values (A, B, C, ...).
+ ord(track.medium) - 64 != side_count + 1
+ )
)
if not medium_is_index and medium != track.medium:
@@ -473,51 +602,54 @@ class DiscogsPlugin(BeetsPlugin):
title for the merged track is the one from the previous index track,
if present; otherwise it is a combination of the subtracks titles.
"""
+
def add_merged_subtracks(tracklist, subtracks):
"""Modify `tracklist` in place, merging a list of `subtracks` into
a single track into `tracklist`."""
# Calculate position based on first subtrack, without subindex.
- idx, medium_idx, sub_idx = \
- self.get_track_index(subtracks[0]['position'])
- position = '{}{}'.format(idx or '', medium_idx or '')
+ idx, medium_idx, sub_idx = self.get_track_index(
+ subtracks[0]["position"]
+ )
+ position = "{}{}".format(idx or "", medium_idx or "")
- if tracklist and not tracklist[-1]['position']:
+ if tracklist and not tracklist[-1]["position"]:
# Assume the previous index track contains the track title.
if sub_idx:
# "Convert" the track title to a real track, discarding the
# subtracks assuming they are logical divisions of a
# physical track (12.2.9 Subtracks).
- tracklist[-1]['position'] = position
+ tracklist[-1]["position"] = position
else:
# Promote the subtracks to real tracks, discarding the
# index track, assuming the subtracks are physical tracks.
index_track = tracklist.pop()
# Fix artists when they are specified on the index track.
- if index_track.get('artists'):
+ if index_track.get("artists"):
for subtrack in subtracks:
- if not subtrack.get('artists'):
- subtrack['artists'] = index_track['artists']
+ if not subtrack.get("artists"):
+ subtrack["artists"] = index_track["artists"]
# Concatenate index with track title when index_tracks
# option is set
- if self.config['index_tracks']:
+ if self.config["index_tracks"]:
for subtrack in subtracks:
- subtrack['title'] = '{}: {}'.format(
- index_track['title'], subtrack['title'])
+ subtrack["title"] = "{}: {}".format(
+ index_track["title"], subtrack["title"]
+ )
tracklist.extend(subtracks)
else:
# Merge the subtracks, pick a title, and append the new track.
track = subtracks[0].copy()
- track['title'] = ' / '.join([t['title'] for t in subtracks])
+ track["title"] = " / ".join([t["title"] for t in subtracks])
tracklist.append(track)
# Pre-process the tracklist, trying to identify subtracks.
subtracks = []
tracklist = []
- prev_subindex = ''
+ prev_subindex = ""
for track in raw_tracklist:
# Regular subtrack (track with subindex).
- if track['position']:
- _, _, subindex = self.get_track_index(track['position'])
+ if track["position"]:
+ _, _, subindex = self.get_track_index(track["position"])
if subindex:
if subindex.rjust(len(raw_tracklist)) > prev_subindex:
# Subtrack still part of the current main track.
@@ -530,17 +662,17 @@ class DiscogsPlugin(BeetsPlugin):
continue
# Index track with nested sub_tracks.
- if not track['position'] and 'sub_tracks' in track:
+ if not track["position"] and "sub_tracks" in track:
# Append the index track, assuming it contains the track title.
tracklist.append(track)
- add_merged_subtracks(tracklist, track['sub_tracks'])
+ add_merged_subtracks(tracklist, track["sub_tracks"])
continue
# Regular track or index track without nested sub_tracks.
if subtracks:
add_merged_subtracks(tracklist, subtracks)
subtracks = []
- prev_subindex = ''
+ prev_subindex = ""
tracklist.append(track)
# Merge and add the remaining subtracks, if any.
@@ -550,22 +682,28 @@ class DiscogsPlugin(BeetsPlugin):
return tracklist
def get_track_info(self, track, index, divisions):
- """Returns a TrackInfo object for a discogs track.
- """
- title = track['title']
- if self.config['index_tracks']:
- prefix = ', '.join(divisions)
+ """Returns a TrackInfo object for a discogs track."""
+ title = track["title"]
+ if self.config["index_tracks"]:
+ prefix = ", ".join(divisions)
if prefix:
- title = f'{prefix}: {title}'
+ title = f"{prefix}: {title}"
track_id = None
- medium, medium_index, _ = self.get_track_index(track['position'])
+ medium, medium_index, _ = self.get_track_index(track["position"])
artist, artist_id = MetadataSourcePlugin.get_artist(
- track.get('artists', [])
+ track.get("artists", []), join_key="join"
+ )
+ length = self.get_track_length(track["duration"])
+ return TrackInfo(
+ title=title,
+ track_id=track_id,
+ artist=artist,
+ artist_id=artist_id,
+ length=length,
+ index=index,
+ medium=medium,
+ medium_index=medium_index,
)
- length = self.get_track_length(track['duration'])
- return TrackInfo(title=title, track_id=track_id, artist=artist,
- artist_id=artist_id, length=length, index=index,
- medium=medium, medium_index=medium_index)
def get_track_index(self, position):
"""Returns the medium, medium index and subtrack index for a discogs
@@ -573,34 +711,33 @@ class DiscogsPlugin(BeetsPlugin):
# Match the standard Discogs positions (12.2.9), which can have several
# forms (1, 1-1, A1, A1.1, A1a, ...).
match = re.match(
- r'^(.*?)' # medium: everything before medium_index.
- r'(\d*?)' # medium_index: a number at the end of
- # `position`, except if followed by a subtrack
- # index.
- # subtrack_index: can only be matched if medium
- # or medium_index have been matched, and can be
- r'((?<=\w)\.[\w]+' # - a dot followed by a string (A.1, 2.A)
- r'|(?<=\d)[A-Z]+' # - a string that follows a number (1A, B2a)
- r')?'
- r'$',
- position.upper()
+ r"^(.*?)" # medium: everything before medium_index.
+ r"(\d*?)" # medium_index: a number at the end of
+ # `position`, except if followed by a subtrack
+ # index.
+ # subtrack_index: can only be matched if medium
+ # or medium_index have been matched, and can be
+ r"((?<=\w)\.[\w]+" # - a dot followed by a string (A.1, 2.A)
+ r"|(?<=\d)[A-Z]+" # - a string that follows a number (1A, B2a)
+ r")?"
+ r"$",
+ position.upper(),
)
if match:
medium, index, subindex = match.groups()
- if subindex and subindex.startswith('.'):
+ if subindex and subindex.startswith("."):
subindex = subindex[1:]
else:
- self._log.debug('Invalid position: {0}', position)
+ self._log.debug("Invalid position: {0}", position)
medium = index = subindex = None
return medium or None, index or None, subindex or None
def get_track_length(self, duration):
- """Returns the track length in seconds for a discogs duration.
- """
+ """Returns the track length in seconds for a discogs duration."""
try:
- length = time.strptime(duration, '%M:%S')
+ length = time.strptime(duration, "%M:%S")
except ValueError:
return None
return length.tm_min * 60 + length.tm_sec
diff --git a/lib/beetsplug/duplicates.py b/lib/beetsplug/duplicates.py
index fdd5c175..ced96e40 100644
--- a/lib/beetsplug/duplicates.py
+++ b/lib/beetsplug/duplicates.py
@@ -15,122 +15,150 @@
"""List duplicate tracks or albums.
"""
+import os
import shlex
+from beets.library import Album, Item
from beets.plugins import BeetsPlugin
-from beets.ui import decargs, print_, Subcommand, UserError
-from beets.util import command_output, displayable_path, subprocess, \
- bytestring_path, MoveOperation, decode_commandline_path
-from beets.library import Item, Album
+from beets.ui import Subcommand, UserError, decargs, print_
+from beets.util import (
+ MoveOperation,
+ bytestring_path,
+ command_output,
+ displayable_path,
+ subprocess,
+)
-
-PLUGIN = 'duplicates'
+PLUGIN = "duplicates"
class DuplicatesPlugin(BeetsPlugin):
- """List duplicate tracks or albums
- """
+ """List duplicate tracks or albums"""
+
def __init__(self):
super().__init__()
- self.config.add({
- 'album': False,
- 'checksum': '',
- 'copy': '',
- 'count': False,
- 'delete': False,
- 'format': '',
- 'full': False,
- 'keys': [],
- 'merge': False,
- 'move': '',
- 'path': False,
- 'tiebreak': {},
- 'strict': False,
- 'tag': '',
- })
+ self.config.add(
+ {
+ "album": False,
+ "checksum": "",
+ "copy": "",
+ "count": False,
+ "delete": False,
+ "format": "",
+ "full": False,
+ "keys": [],
+ "merge": False,
+ "move": "",
+ "path": False,
+ "tiebreak": {},
+ "strict": False,
+ "tag": "",
+ }
+ )
- self._command = Subcommand('duplicates',
- help=__doc__,
- aliases=['dup'])
+ self._command = Subcommand("duplicates", help=__doc__, aliases=["dup"])
self._command.parser.add_option(
- '-c', '--count', dest='count',
- action='store_true',
- help='show duplicate counts',
+ "-c",
+ "--count",
+ dest="count",
+ action="store_true",
+ help="show duplicate counts",
)
self._command.parser.add_option(
- '-C', '--checksum', dest='checksum',
- action='store', metavar='PROG',
- help='report duplicates based on arbitrary command',
+ "-C",
+ "--checksum",
+ dest="checksum",
+ action="store",
+ metavar="PROG",
+ help="report duplicates based on arbitrary command",
)
self._command.parser.add_option(
- '-d', '--delete', dest='delete',
- action='store_true',
- help='delete items from library and disk',
+ "-d",
+ "--delete",
+ dest="delete",
+ action="store_true",
+ help="delete items from library and disk",
)
self._command.parser.add_option(
- '-F', '--full', dest='full',
- action='store_true',
- help='show all versions of duplicate tracks or albums',
+ "-F",
+ "--full",
+ dest="full",
+ action="store_true",
+ help="show all versions of duplicate tracks or albums",
)
self._command.parser.add_option(
- '-s', '--strict', dest='strict',
- action='store_true',
- help='report duplicates only if all attributes are set',
+ "-s",
+ "--strict",
+ dest="strict",
+ action="store_true",
+ help="report duplicates only if all attributes are set",
)
self._command.parser.add_option(
- '-k', '--key', dest='keys',
- action='append', metavar='KEY',
- help='report duplicates based on keys (use multiple times)',
+ "-k",
+ "--key",
+ dest="keys",
+ action="append",
+ metavar="KEY",
+ help="report duplicates based on keys (use multiple times)",
)
self._command.parser.add_option(
- '-M', '--merge', dest='merge',
- action='store_true',
- help='merge duplicate items',
+ "-M",
+ "--merge",
+ dest="merge",
+ action="store_true",
+ help="merge duplicate items",
)
self._command.parser.add_option(
- '-m', '--move', dest='move',
- action='store', metavar='DEST',
- help='move items to dest',
+ "-m",
+ "--move",
+ dest="move",
+ action="store",
+ metavar="DEST",
+ help="move items to dest",
)
self._command.parser.add_option(
- '-o', '--copy', dest='copy',
- action='store', metavar='DEST',
- help='copy items to dest',
+ "-o",
+ "--copy",
+ dest="copy",
+ action="store",
+ metavar="DEST",
+ help="copy items to dest",
)
self._command.parser.add_option(
- '-t', '--tag', dest='tag',
- action='store',
- help='tag matched items with \'k=v\' attribute',
+ "-t",
+ "--tag",
+ dest="tag",
+ action="store",
+ help="tag matched items with 'k=v' attribute",
)
self._command.parser.add_all_common_options()
def commands(self):
-
def _dup(lib, opts, args):
self.config.set_args(opts)
- album = self.config['album'].get(bool)
- checksum = self.config['checksum'].get(str)
- copy = bytestring_path(self.config['copy'].as_str())
- count = self.config['count'].get(bool)
- delete = self.config['delete'].get(bool)
- fmt = self.config['format'].get(str)
- full = self.config['full'].get(bool)
- keys = self.config['keys'].as_str_seq()
- merge = self.config['merge'].get(bool)
- move = bytestring_path(self.config['move'].as_str())
- path = self.config['path'].get(bool)
- tiebreak = self.config['tiebreak'].get(dict)
- strict = self.config['strict'].get(bool)
- tag = self.config['tag'].get(str)
+ album = self.config["album"].get(bool)
+ checksum = self.config["checksum"].get(str)
+ copy = bytestring_path(self.config["copy"].as_str())
+ count = self.config["count"].get(bool)
+ delete = self.config["delete"].get(bool)
+ fmt = self.config["format"].get(str)
+ full = self.config["full"].get(bool)
+ keys = self.config["keys"].as_str_seq()
+ merge = self.config["merge"].get(bool)
+ move = bytestring_path(self.config["move"].as_str())
+ path = self.config["path"].get(bool)
+ tiebreak = self.config["tiebreak"].get(dict)
+ strict = self.config["strict"].get(bool)
+ tag = self.config["tag"].get(str)
if album:
if not keys:
- keys = ['mb_albumid']
+ keys = ["mb_albumid"]
items = lib.albums(decargs(args))
else:
if not keys:
- keys = ['mb_trackid', 'mb_albumid']
+ keys = ["mb_trackid", "mb_albumid"]
items = lib.items(decargs(args))
# If there's nothing to do, return early. The code below assumes
@@ -139,43 +167,47 @@ class DuplicatesPlugin(BeetsPlugin):
return
if path:
- fmt = '$path'
+ fmt = "$path"
# Default format string for count mode.
if count and not fmt:
if album:
- fmt = '$albumartist - $album'
+ fmt = "$albumartist - $album"
else:
- fmt = '$albumartist - $album - $title'
- fmt += ': {0}'
+ fmt = "$albumartist - $album - $title"
+ fmt += ": {0}"
if checksum:
for i in items:
k, _ = self._checksum(i, checksum)
keys = [k]
- for obj_id, obj_count, objs in self._duplicates(items,
- keys=keys,
- full=full,
- strict=strict,
- tiebreak=tiebreak,
- merge=merge):
+ for obj_id, obj_count, objs in self._duplicates(
+ items,
+ keys=keys,
+ full=full,
+ strict=strict,
+ tiebreak=tiebreak,
+ merge=merge,
+ ):
if obj_id: # Skip empty IDs.
for o in objs:
- self._process_item(o,
- copy=copy,
- move=move,
- delete=delete,
- tag=tag,
- fmt=fmt.format(obj_count))
+ self._process_item(
+ o,
+ copy=copy,
+ move=move,
+ delete=delete,
+ tag=tag,
+ fmt=fmt.format(obj_count),
+ )
self._command.func = _dup
return [self._command]
- def _process_item(self, item, copy=False, move=False, delete=False,
- tag=False, fmt=''):
- """Process Item `item`.
- """
+ def _process_item(
+ self, item, copy=False, move=False, delete=False, tag=False, fmt=""
+ ):
+ """Process Item `item`."""
print_(format(item, fmt))
if copy:
item.move(basedir=copy, operation=MoveOperation.COPY)
@@ -187,11 +219,9 @@ class DuplicatesPlugin(BeetsPlugin):
item.remove(delete=True)
if tag:
try:
- k, v = tag.split('=')
+ k, v = tag.split("=")
except Exception:
- raise UserError(
- f"{PLUGIN}: can't parse k=v tag: {tag}"
- )
+ raise UserError(f"{PLUGIN}: can't parse k=v tag: {tag}")
setattr(item, k, v)
item.store()
@@ -200,27 +230,36 @@ class DuplicatesPlugin(BeetsPlugin):
output as flexattr on a key that is the name of the program, and
return the key, checksum tuple.
"""
- args = [p.format(file=decode_commandline_path(item.path))
- for p in shlex.split(prog)]
+ args = [
+ p.format(file=os.fsdecode(item.path)) for p in shlex.split(prog)
+ ]
key = args[0]
checksum = getattr(item, key, False)
if not checksum:
- self._log.debug('key {0} on item {1} not cached:'
- 'computing checksum',
- key, displayable_path(item.path))
+ self._log.debug(
+ "key {0} on item {1} not cached:" "computing checksum",
+ key,
+ displayable_path(item.path),
+ )
try:
checksum = command_output(args).stdout
setattr(item, key, checksum)
item.store()
- self._log.debug('computed checksum for {0} using {1}',
- item.title, key)
+ self._log.debug(
+ "computed checksum for {0} using {1}", item.title, key
+ )
except subprocess.CalledProcessError as e:
- self._log.debug('failed to checksum {0}: {1}',
- displayable_path(item.path), e)
+ self._log.debug(
+ "failed to checksum {0}: {1}",
+ displayable_path(item.path),
+ e,
+ )
else:
- self._log.debug('key {0} on item {1} cached:'
- 'not computing checksum',
- key, displayable_path(item.path))
+ self._log.debug(
+ "key {0} on item {1} cached:" "not computing checksum",
+ key,
+ displayable_path(item.path),
+ )
return key, checksum
def _group_by(self, objs, keys, strict):
@@ -230,18 +269,23 @@ class DuplicatesPlugin(BeetsPlugin):
If strict, all attributes must be defined for a duplicate match.
"""
import collections
+
counts = collections.defaultdict(list)
for obj in objs:
values = [getattr(obj, k, None) for k in keys]
- values = [v for v in values if v not in (None, '')]
+ values = [v for v in values if v not in (None, "")]
if strict and len(values) < len(keys):
- self._log.debug('some keys {0} on item {1} are null or empty:'
- ' skipping',
- keys, displayable_path(obj.path))
- elif (not strict and not len(values)):
- self._log.debug('all keys {0} on item {1} are null or empty:'
- ' skipping',
- keys, displayable_path(obj.path))
+ self._log.debug(
+ "some keys {0} on item {1} are null or empty:" " skipping",
+ keys,
+ displayable_path(obj.path),
+ )
+ elif not strict and not len(values):
+ self._log.debug(
+ "all keys {0} on item {1} are null or empty:" " skipping",
+ keys,
+ displayable_path(obj.path),
+ )
else:
key = tuple(values)
counts[key].append(obj)
@@ -257,18 +301,21 @@ class DuplicatesPlugin(BeetsPlugin):
"completeness" (objects with more non-null fields come first)
and Albums are ordered by their track count.
"""
- kind = 'items' if all(isinstance(o, Item) for o in objs) else 'albums'
+ kind = "items" if all(isinstance(o, Item) for o in objs) else "albums"
if tiebreak and kind in tiebreak.keys():
key = lambda x: tuple(getattr(x, k) for k in tiebreak[kind])
else:
- if kind == 'items':
+ if kind == "items":
+
def truthy(v):
# Avoid a Unicode warning by avoiding comparison
# between a bytes object and the empty Unicode
# string ''.
- return v is not None and \
- (v != '' if isinstance(v, str) else True)
+ return v is not None and (
+ v != "" if isinstance(v, str) else True
+ )
+
fields = Item.all_keys()
key = lambda x: sum(1 for f in fields if truthy(getattr(x, f)))
else:
@@ -285,13 +332,16 @@ class DuplicatesPlugin(BeetsPlugin):
fields = Item.all_keys()
for f in fields:
for o in objs[1:]:
- if getattr(objs[0], f, None) in (None, ''):
+ if getattr(objs[0], f, None) in (None, ""):
value = getattr(o, f, None)
if value:
- self._log.debug('key {0} on item {1} is null '
- 'or empty: setting from item {2}',
- f, displayable_path(objs[0].path),
- displayable_path(o.path))
+ self._log.debug(
+ "key {0} on item {1} is null "
+ "or empty: setting from item {2}",
+ f,
+ displayable_path(objs[0].path),
+ displayable_path(o.path),
+ )
setattr(objs[0], f, value)
objs[0].store()
break
@@ -309,12 +359,14 @@ class DuplicatesPlugin(BeetsPlugin):
missing = Item.from_path(i.path)
missing.album_id = objs[0].id
missing.add(i._db)
- self._log.debug('item {0} missing from album {1}:'
- ' merging from {2} into {3}',
- missing,
- objs[0],
- displayable_path(o.path),
- displayable_path(missing.destination()))
+ self._log.debug(
+ "item {0} missing from album {1}:"
+ " merging from {2} into {3}",
+ missing,
+ objs[0],
+ displayable_path(o.path),
+ displayable_path(missing.destination()),
+ )
missing.move(operation=MoveOperation.COPY)
return objs
@@ -330,8 +382,7 @@ class DuplicatesPlugin(BeetsPlugin):
return objs
def _duplicates(self, objs, keys, full, strict, tiebreak, merge):
- """Generate triples of keys, duplicate counts, and constituent objects.
- """
+ """Generate triples of keys, duplicate counts, and constituent objects."""
offset = 0 if full else 1
for k, objs in self._group_by(objs, keys, strict).items():
if len(objs) > 1:
diff --git a/lib/beetsplug/edit.py b/lib/beetsplug/edit.py
index 6f03fa4d..323dd9e4 100644
--- a/lib/beetsplug/edit.py
+++ b/lib/beetsplug/edit.py
@@ -15,23 +15,22 @@
"""Open metadata information in a text editor to let the user edit it.
"""
-from beets import plugins
-from beets import util
-from beets import ui
-from beets.dbcore import types
-from beets.importer import action
-from beets.ui.commands import _do_query, PromptChoice
import codecs
-import subprocess
-import yaml
-from tempfile import NamedTemporaryFile
import os
import shlex
+import subprocess
+from tempfile import NamedTemporaryFile
+import yaml
+
+from beets import plugins, ui, util
+from beets.dbcore import types
+from beets.importer import action
+from beets.ui.commands import PromptChoice, _do_query
# These "safe" types can avoid the format/parse cycle that most fields go
# through: they are safe to edit with native YAML types.
-SAFE_TYPES = (types.Float, types.Integer, types.Boolean)
+SAFE_TYPES = (types.BaseFloat, types.BaseInteger, types.Boolean)
class ParseError(Exception):
@@ -41,22 +40,20 @@ class ParseError(Exception):
def edit(filename, log):
- """Open `filename` in a text editor.
- """
+ """Open `filename` in a text editor."""
cmd = shlex.split(util.editor_command())
cmd.append(filename)
- log.debug('invoking editor command: {!r}', cmd)
+ log.debug("invoking editor command: {!r}", cmd)
try:
subprocess.call(cmd)
except OSError as exc:
- raise ui.UserError('could not run editor command {!r}: {}'.format(
- cmd[0], exc
- ))
+ raise ui.UserError(
+ "could not run editor command {!r}: {}".format(cmd[0], exc)
+ )
def dump(arg):
- """Dump a sequence of dictionaries as YAML for editing.
- """
+ """Dump a sequence of dictionaries as YAML for editing."""
return yaml.safe_dump_all(
arg,
allow_unicode=True,
@@ -75,7 +72,7 @@ def load(s):
for d in yaml.safe_load_all(s):
if not isinstance(d, dict):
raise ParseError(
- 'each entry must be a dictionary; found {}'.format(
+ "each entry must be a dictionary; found {}".format(
type(d).__name__
)
)
@@ -85,7 +82,7 @@ def load(s):
out.append({str(k): v for k, v in d.items()})
except yaml.YAMLError as e:
- raise ParseError(f'invalid YAML: {e}')
+ raise ParseError(f"invalid YAML: {e}")
return out
@@ -145,51 +142,50 @@ def apply_(obj, data):
class EditPlugin(plugins.BeetsPlugin):
-
def __init__(self):
super().__init__()
- self.config.add({
- # The default fields to edit.
- 'albumfields': 'album albumartist',
- 'itemfields': 'track title artist album',
+ self.config.add(
+ {
+ # The default fields to edit.
+ "albumfields": "album albumartist",
+ "itemfields": "track title artist album",
+ # Silently ignore any changes to these fields.
+ "ignore_fields": "id path",
+ }
+ )
- # Silently ignore any changes to these fields.
- 'ignore_fields': 'id path',
- })
-
- self.register_listener('before_choose_candidate',
- self.before_choose_candidate_listener)
+ self.register_listener(
+ "before_choose_candidate", self.before_choose_candidate_listener
+ )
def commands(self):
- edit_command = ui.Subcommand(
- 'edit',
- help='interactively edit metadata'
+ edit_command = ui.Subcommand("edit", help="interactively edit metadata")
+ edit_command.parser.add_option(
+ "-f",
+ "--field",
+ metavar="FIELD",
+ action="append",
+ help="edit this field also",
)
edit_command.parser.add_option(
- '-f', '--field',
- metavar='FIELD',
- action='append',
- help='edit this field also',
- )
- edit_command.parser.add_option(
- '--all',
- action='store_true', dest='all',
- help='edit all fields',
+ "--all",
+ action="store_true",
+ dest="all",
+ help="edit all fields",
)
edit_command.parser.add_album_option()
edit_command.func = self._edit_command
return [edit_command]
def _edit_command(self, lib, opts, args):
- """The CLI command function for the `beet edit` command.
- """
+ """The CLI command function for the `beet edit` command."""
# Get the objects to edit.
query = ui.decargs(args)
items, albums = _do_query(lib, query, opts.album, False)
objs = albums if opts.album else items
if not objs:
- ui.print_('Nothing to edit.')
+ ui.print_("Nothing to edit.")
return
# Get the fields to edit.
@@ -200,20 +196,19 @@ class EditPlugin(plugins.BeetsPlugin):
self.edit(opts.album, objs, fields)
def _get_fields(self, album, extra):
- """Get the set of fields to edit.
- """
+ """Get the set of fields to edit."""
# Start with the configured base fields.
if album:
- fields = self.config['albumfields'].as_str_seq()
+ fields = self.config["albumfields"].as_str_seq()
else:
- fields = self.config['itemfields'].as_str_seq()
+ fields = self.config["itemfields"].as_str_seq()
# Add the requested extra fields.
if extra:
fields += extra
# Ensure we always have the `id` field for identification.
- fields.append('id')
+ fields.append("id")
return set(fields)
@@ -225,7 +220,7 @@ class EditPlugin(plugins.BeetsPlugin):
- `fields`: The set of field names to edit (or None to edit
everything).
"""
- # Present the YAML to the user and let her change it.
+ # Present the YAML to the user and let them change it.
success = self.edit_objects(objs, fields)
# Save the new data.
@@ -242,8 +237,9 @@ class EditPlugin(plugins.BeetsPlugin):
old_data = [flatten(o, fields) for o in objs]
# Set up a temporary file with the initial data for editing.
- new = NamedTemporaryFile(mode='w', suffix='.yaml', delete=False,
- encoding='utf-8')
+ new = NamedTemporaryFile(
+ mode="w", suffix=".yaml", delete=False, encoding="utf-8"
+ )
old_str = dump(old_data)
new.write(old_str)
new.close()
@@ -256,7 +252,7 @@ class EditPlugin(plugins.BeetsPlugin):
# Read the data back after editing and check whether anything
# changed.
- with codecs.open(new.name, encoding='utf-8') as f:
+ with codecs.open(new.name, encoding="utf-8") as f:
new_str = f.read()
if new_str == old_str:
ui.print_("No changes; aborting.")
@@ -275,29 +271,29 @@ class EditPlugin(plugins.BeetsPlugin):
# Show the changes.
# If the objects are not on the DB yet, we need a copy of their
# original state for show_model_changes.
- objs_old = [obj.copy() if obj.id < 0 else None
- for obj in objs]
+ objs_old = [obj.copy() if obj.id < 0 else None for obj in objs]
self.apply_data(objs, old_data, new_data)
changed = False
for obj, obj_old in zip(objs, objs_old):
changed |= ui.show_model_changes(obj, obj_old)
if not changed:
- ui.print_('No changes to apply.')
+ ui.print_("No changes to apply.")
return False
# Confirm the changes.
choice = ui.input_options(
- ('continue Editing', 'apply', 'cancel')
+ ("continue Editing", "apply", "cancel")
)
- if choice == 'a': # Apply.
+ if choice == "a": # Apply.
return True
- elif choice == 'c': # Cancel.
+ elif choice == "c": # Cancel.
return False
- elif choice == 'e': # Keep editing.
+ elif choice == "e": # Keep editing.
# Reset the temporary changes to the objects. I we have a
# copy from above, use that, else reload from the database.
- objs = [(old_obj or obj)
- for old_obj, obj in zip(objs_old, objs)]
+ objs = [
+ (old_obj or obj) for old_obj, obj in zip(objs_old, objs)
+ ]
for obj in objs:
if not obj.id < 0:
obj.load()
@@ -315,33 +311,35 @@ class EditPlugin(plugins.BeetsPlugin):
are temporary.
"""
if len(old_data) != len(new_data):
- self._log.warning('number of objects changed from {} to {}',
- len(old_data), len(new_data))
+ self._log.warning(
+ "number of objects changed from {} to {}",
+ len(old_data),
+ len(new_data),
+ )
obj_by_id = {o.id: o for o in objs}
- ignore_fields = self.config['ignore_fields'].as_str_seq()
+ ignore_fields = self.config["ignore_fields"].as_str_seq()
for old_dict, new_dict in zip(old_data, new_data):
# Prohibit any changes to forbidden fields to avoid
# clobbering `id` and such by mistake.
forbidden = False
for key in ignore_fields:
if old_dict.get(key) != new_dict.get(key):
- self._log.warning('ignoring object whose {} changed', key)
+ self._log.warning("ignoring object whose {} changed", key)
forbidden = True
break
if forbidden:
continue
- id_ = int(old_dict['id'])
+ id_ = int(old_dict["id"])
apply_(obj_by_id[id_], new_dict)
def save_changes(self, objs):
- """Save a list of updated Model objects to the database.
- """
+ """Save a list of updated Model objects to the database."""
# Save to the database and possibly write tags.
for ob in objs:
if ob._dirty:
- self._log.debug('saving changes to {}', ob)
+ self._log.debug("saving changes to {}", ob)
ob.try_sync(ui.should_write(), ui.should_move())
# Methods for interactive importer execution.
@@ -350,10 +348,13 @@ class EditPlugin(plugins.BeetsPlugin):
"""Append an "Edit" choice and an "edit Candidates" choice (if
there are candidates) to the interactive importer prompt.
"""
- choices = [PromptChoice('d', 'eDit', self.importer_edit)]
+ choices = [PromptChoice("d", "eDit", self.importer_edit)]
if task.candidates:
- choices.append(PromptChoice('c', 'edit Candidates',
- self.importer_edit_candidate))
+ choices.append(
+ PromptChoice(
+ "c", "edit Candidates", self.importer_edit_candidate
+ )
+ )
return choices
@@ -369,7 +370,7 @@ class EditPlugin(plugins.BeetsPlugin):
if not obj._db or obj.id is None:
obj.id = -i
- # Present the YAML to the user and let her change it.
+ # Present the YAML to the user and let them change it.
fields = self._get_fields(album=False, extra=[])
success = self.edit_objects(task.items, fields)
diff --git a/lib/beetsplug/embedart.py b/lib/beetsplug/embedart.py
index 6db46f8c..740863bf 100644
--- a/lib/beetsplug/embedart.py
+++ b/lib/beetsplug/embedart.py
@@ -15,14 +15,16 @@
"""Allows beets to embed album art into file metadata."""
import os.path
+import tempfile
+from mimetypes import guess_extension
+import requests
+
+from beets import art, config, ui
from beets.plugins import BeetsPlugin
-from beets import ui
-from beets.ui import print_, decargs
-from beets.util import syspath, normpath, displayable_path, bytestring_path
+from beets.ui import decargs, print_
+from beets.util import bytestring_path, displayable_path, normpath, syspath
from beets.util.artresizer import ArtResizer
-from beets import config
-from beets import art
def _confirm(objs, album):
@@ -32,11 +34,9 @@ def _confirm(objs, album):
`album` is a Boolean indicating whether these are albums (as opposed
to items).
"""
- noun = 'album' if album else 'file'
- prompt = 'Modify artwork for {} {}{} (Y/n)?'.format(
- len(objs),
- noun,
- 's' if len(objs) > 1 else ''
+ noun = "album" if album else "file"
+ prompt = "Modify artwork for {} {}{} (Y/n)?".format(
+ len(objs), noun, "s" if len(objs) > 1 else ""
)
# Show all the items or albums.
@@ -48,54 +48,72 @@ def _confirm(objs, album):
class EmbedCoverArtPlugin(BeetsPlugin):
- """Allows albumart to be embedded into the actual files.
- """
+ """Allows albumart to be embedded into the actual files."""
+
def __init__(self):
super().__init__()
- self.config.add({
- 'maxwidth': 0,
- 'auto': True,
- 'compare_threshold': 0,
- 'ifempty': False,
- 'remove_art_file': False,
- 'quality': 0,
- })
+ self.config.add(
+ {
+ "maxwidth": 0,
+ "auto": True,
+ "compare_threshold": 0,
+ "ifempty": False,
+ "remove_art_file": False,
+ "quality": 0,
+ }
+ )
- if self.config['maxwidth'].get(int) and not ArtResizer.shared.local:
- self.config['maxwidth'] = 0
- self._log.warning("ImageMagick or PIL not found; "
- "'maxwidth' option ignored")
- if self.config['compare_threshold'].get(int) and not \
- ArtResizer.shared.can_compare:
- self.config['compare_threshold'] = 0
- self._log.warning("ImageMagick 6.8.7 or higher not installed; "
- "'compare_threshold' option ignored")
+ if self.config["maxwidth"].get(int) and not ArtResizer.shared.local:
+ self.config["maxwidth"] = 0
+ self._log.warning(
+ "ImageMagick or PIL not found; " "'maxwidth' option ignored"
+ )
+ if (
+ self.config["compare_threshold"].get(int)
+ and not ArtResizer.shared.can_compare
+ ):
+ self.config["compare_threshold"] = 0
+ self._log.warning(
+ "ImageMagick 6.8.7 or higher not installed; "
+ "'compare_threshold' option ignored"
+ )
- self.register_listener('art_set', self.process_album)
+ self.register_listener("art_set", self.process_album)
def commands(self):
# Embed command.
embed_cmd = ui.Subcommand(
- 'embedart', help='embed image files into file metadata'
+ "embedart", help="embed image files into file metadata"
)
embed_cmd.parser.add_option(
- '-f', '--file', metavar='PATH', help='the image file to embed'
+ "-f", "--file", metavar="PATH", help="the image file to embed"
)
+
embed_cmd.parser.add_option(
"-y", "--yes", action="store_true", help="skip confirmation"
)
- maxwidth = self.config['maxwidth'].get(int)
- quality = self.config['quality'].get(int)
- compare_threshold = self.config['compare_threshold'].get(int)
- ifempty = self.config['ifempty'].get(bool)
+
+ embed_cmd.parser.add_option(
+ "-u",
+ "--url",
+ metavar="URL",
+ help="the URL of the image file to embed",
+ )
+
+ maxwidth = self.config["maxwidth"].get(int)
+ quality = self.config["quality"].get(int)
+ compare_threshold = self.config["compare_threshold"].get(int)
+ ifempty = self.config["ifempty"].get(bool)
def embed_func(lib, opts, args):
if opts.file:
imagepath = normpath(opts.file)
if not os.path.isfile(syspath(imagepath)):
- raise ui.UserError('image file {} not found'.format(
- displayable_path(imagepath)
- ))
+ raise ui.UserError(
+ "image file {} not found".format(
+ displayable_path(imagepath)
+ )
+ )
items = lib.items(decargs(args))
@@ -104,66 +122,122 @@ class EmbedCoverArtPlugin(BeetsPlugin):
return
for item in items:
- art.embed_item(self._log, item, imagepath, maxwidth,
- None, compare_threshold, ifempty,
- quality=quality)
+ art.embed_item(
+ self._log,
+ item,
+ imagepath,
+ maxwidth,
+ None,
+ compare_threshold,
+ ifempty,
+ quality=quality,
+ )
+ elif opts.url:
+ try:
+ response = requests.get(opts.url, timeout=5)
+ response.raise_for_status()
+ except requests.exceptions.RequestException as e:
+ self._log.error("{}".format(e))
+ return
+ extension = guess_extension(response.headers["Content-Type"])
+ if extension is None:
+ self._log.error("Invalid image file")
+ return
+ file = f"image{extension}"
+ tempimg = os.path.join(tempfile.gettempdir(), file)
+ try:
+ with open(tempimg, "wb") as f:
+ f.write(response.content)
+ except Exception as e:
+ self._log.error("Unable to save image: {}".format(e))
+ return
+ items = lib.items(decargs(args))
+ # Confirm with user.
+ if not opts.yes and not _confirm(items, not opts.url):
+ os.remove(tempimg)
+ return
+ for item in items:
+ art.embed_item(
+ self._log,
+ item,
+ tempimg,
+ maxwidth,
+ None,
+ compare_threshold,
+ ifempty,
+ quality=quality,
+ )
+ os.remove(tempimg)
else:
albums = lib.albums(decargs(args))
-
# Confirm with user.
if not opts.yes and not _confirm(albums, not opts.file):
return
-
for album in albums:
- art.embed_album(self._log, album, maxwidth,
- False, compare_threshold, ifempty,
- quality=quality)
+ art.embed_album(
+ self._log,
+ album,
+ maxwidth,
+ False,
+ compare_threshold,
+ ifempty,
+ quality=quality,
+ )
self.remove_artfile(album)
embed_cmd.func = embed_func
# Extract command.
extract_cmd = ui.Subcommand(
- 'extractart',
- help='extract an image from file metadata',
+ "extractart",
+ help="extract an image from file metadata",
)
extract_cmd.parser.add_option(
- '-o', dest='outpath',
- help='image output file',
+ "-o",
+ dest="outpath",
+ help="image output file",
)
extract_cmd.parser.add_option(
- '-n', dest='filename',
- help='image filename to create for all matched albums',
+ "-n",
+ dest="filename",
+ help="image filename to create for all matched albums",
)
extract_cmd.parser.add_option(
- '-a', dest='associate', action='store_true',
- help='associate the extracted images with the album',
+ "-a",
+ dest="associate",
+ action="store_true",
+ help="associate the extracted images with the album",
)
def extract_func(lib, opts, args):
if opts.outpath:
- art.extract_first(self._log, normpath(opts.outpath),
- lib.items(decargs(args)))
+ art.extract_first(
+ self._log, normpath(opts.outpath), lib.items(decargs(args))
+ )
else:
- filename = bytestring_path(opts.filename or
- config['art_filename'].get())
- if os.path.dirname(filename) != b'':
+ filename = bytestring_path(
+ opts.filename or config["art_filename"].get()
+ )
+ if os.path.dirname(filename) != b"":
self._log.error(
- "Only specify a name rather than a path for -n")
+ "Only specify a name rather than a path for -n"
+ )
return
for album in lib.albums(decargs(args)):
artpath = normpath(os.path.join(album.path, filename))
- artpath = art.extract_first(self._log, artpath,
- album.items())
+ artpath = art.extract_first(
+ self._log, artpath, album.items()
+ )
if artpath and opts.associate:
album.set_art(artpath)
album.store()
+
extract_cmd.func = extract_func
# Clear command.
clear_cmd = ui.Subcommand(
- 'clearart',
- help='remove images from file metadata',
+ "clearart",
+ help="remove images from file metadata",
)
clear_cmd.parser.add_option(
"-y", "--yes", action="store_true", help="skip confirmation"
@@ -175,27 +249,32 @@ class EmbedCoverArtPlugin(BeetsPlugin):
if not opts.yes and not _confirm(items, False):
return
art.clear(self._log, lib, decargs(args))
+
clear_cmd.func = clear_func
return [embed_cmd, extract_cmd, clear_cmd]
def process_album(self, album):
- """Automatically embed art after art has been set
- """
- if self.config['auto'] and ui.should_write():
- max_width = self.config['maxwidth'].get(int)
- art.embed_album(self._log, album, max_width, True,
- self.config['compare_threshold'].get(int),
- self.config['ifempty'].get(bool))
+ """Automatically embed art after art has been set"""
+ if self.config["auto"] and ui.should_write():
+ max_width = self.config["maxwidth"].get(int)
+ art.embed_album(
+ self._log,
+ album,
+ max_width,
+ True,
+ self.config["compare_threshold"].get(int),
+ self.config["ifempty"].get(bool),
+ )
self.remove_artfile(album)
def remove_artfile(self, album):
"""Possibly delete the album art file for an album (if the
appropriate configuration option is enabled).
"""
- if self.config['remove_art_file'] and album.artpath:
- if os.path.isfile(album.artpath):
- self._log.debug('Removing album art file for {0}', album)
- os.remove(album.artpath)
+ if self.config["remove_art_file"] and album.artpath:
+ if os.path.isfile(syspath(album.artpath)):
+ self._log.debug("Removing album art file for {0}", album)
+ os.remove(syspath(album.artpath))
album.artpath = None
album.store()
diff --git a/lib/beetsplug/embyupdate.py b/lib/beetsplug/embyupdate.py
index c17fabad..22c88947 100644
--- a/lib/beetsplug/embyupdate.py
+++ b/lib/beetsplug/embyupdate.py
@@ -9,9 +9,10 @@
"""
import hashlib
+from urllib.parse import parse_qs, urlencode, urljoin, urlsplit, urlunsplit
+
import requests
-from urllib.parse import urlencode, urljoin, parse_qs, urlsplit, urlunsplit
from beets import config
from beets.plugins import BeetsPlugin
@@ -32,24 +33,20 @@ def api_url(host, port, endpoint):
"""
# check if http or https is defined as host and create hostname
hostname_list = [host]
- if host.startswith('http://') or host.startswith('https://'):
- hostname = ''.join(hostname_list)
+ if host.startswith("http://") or host.startswith("https://"):
+ hostname = "".join(hostname_list)
else:
- hostname_list.insert(0, 'http://')
- hostname = ''.join(hostname_list)
+ hostname_list.insert(0, "http://")
+ hostname = "".join(hostname_list)
joined = urljoin(
- '{hostname}:{port}'.format(
- hostname=hostname,
- port=port
- ),
- endpoint
+ "{hostname}:{port}".format(hostname=hostname, port=port), endpoint
)
scheme, netloc, path, query_string, fragment = urlsplit(joined)
query_params = parse_qs(query_string)
- query_params['format'] = ['json']
+ query_params["format"] = ["json"]
new_query_string = urlencode(query_params, doseq=True)
return urlunsplit((scheme, netloc, path, new_query_string, fragment))
@@ -66,9 +63,9 @@ def password_data(username, password):
:rtype: dict
"""
return {
- 'username': username,
- 'password': hashlib.sha1(password.encode('utf-8')).hexdigest(),
- 'passwordMd5': hashlib.md5(password.encode('utf-8')).hexdigest()
+ "username": username,
+ "password": hashlib.sha1(password.encode("utf-8")).hexdigest(),
+ "passwordMd5": hashlib.md5(password.encode("utf-8")).hexdigest(),
}
@@ -92,10 +89,10 @@ def create_headers(user_id, token=None):
'Version="0.0.0"'
).format(user_id=user_id)
- headers['x-emby-authorization'] = authorization
+ headers["x-emby-authorization"] = authorization
if token:
- headers['x-mediabrowser-token'] = token
+ headers["x-mediabrowser-token"] = token
return headers
@@ -114,10 +111,15 @@ def get_token(host, port, headers, auth_data):
:returns: Access Token
:rtype: str
"""
- url = api_url(host, port, '/Users/AuthenticateByName')
- r = requests.post(url, headers=headers, data=auth_data)
+ url = api_url(host, port, "/Users/AuthenticateByName")
+ r = requests.post(
+ url,
+ headers=headers,
+ data=auth_data,
+ timeout=10,
+ )
- return r.json().get('AccessToken')
+ return r.json().get("AccessToken")
def get_user(host, port, username):
@@ -132,9 +134,9 @@ def get_user(host, port, username):
:returns: Matched Users
:rtype: list
"""
- url = api_url(host, port, '/Users/Public')
- r = requests.get(url)
- user = [i for i in r.json() if i['Name'] == username]
+ url = api_url(host, port, "/Users/Public")
+ r = requests.get(url, timeout=10)
+ user = [i for i in r.json() if i["Name"] == username]
return user
@@ -144,62 +146,67 @@ class EmbyUpdate(BeetsPlugin):
super().__init__()
# Adding defaults.
- config['emby'].add({
- 'host': 'http://localhost',
- 'port': 8096,
- 'apikey': None,
- 'password': None,
- })
+ config["emby"].add(
+ {
+ "host": "http://localhost",
+ "port": 8096,
+ "apikey": None,
+ "password": None,
+ }
+ )
- self.register_listener('database_change', self.listen_for_db_change)
+ self.register_listener("database_change", self.listen_for_db_change)
def listen_for_db_change(self, lib, model):
- """Listens for beets db change and register the update for the end.
- """
- self.register_listener('cli_exit', self.update)
+ """Listens for beets db change and register the update for the end."""
+ self.register_listener("cli_exit", self.update)
def update(self, lib):
- """When the client exists try to send refresh request to Emby.
- """
- self._log.info('Updating Emby library...')
+ """When the client exists try to send refresh request to Emby."""
+ self._log.info("Updating Emby library...")
- host = config['emby']['host'].get()
- port = config['emby']['port'].get()
- username = config['emby']['username'].get()
- password = config['emby']['password'].get()
- token = config['emby']['apikey'].get()
+ host = config["emby"]["host"].get()
+ port = config["emby"]["port"].get()
+ username = config["emby"]["username"].get()
+ password = config["emby"]["password"].get()
+ userid = config["emby"]["userid"].get()
+ token = config["emby"]["apikey"].get()
# Check if at least a apikey or password is given.
if not any([password, token]):
- self._log.warning('Provide at least Emby password or apikey.')
+ self._log.warning("Provide at least Emby password or apikey.")
return
- # Get user information from the Emby API.
- user = get_user(host, port, username)
- if not user:
- self._log.warning(f'User {username} could not be found.')
- return
+ if not userid:
+ # Get user information from the Emby API.
+ user = get_user(host, port, username)
+ if not user:
+ self._log.warning(f"User {username} could not be found.")
+ return
+ userid = user[0]["Id"]
if not token:
# Create Authentication data and headers.
auth_data = password_data(username, password)
- headers = create_headers(user[0]['Id'])
+ headers = create_headers(userid)
# Get authentication token.
token = get_token(host, port, headers, auth_data)
if not token:
- self._log.warning(
- 'Could not get token for user {0}', username
- )
+ self._log.warning("Could not get token for user {0}", username)
return
# Recreate headers with a token.
- headers = create_headers(user[0]['Id'], token=token)
+ headers = create_headers(userid, token=token)
# Trigger the Update.
- url = api_url(host, port, '/Library/Refresh')
- r = requests.post(url, headers=headers)
+ url = api_url(host, port, "/Library/Refresh")
+ r = requests.post(
+ url,
+ headers=headers,
+ timeout=10,
+ )
if r.status_code != 204:
- self._log.warning('Update could not be triggered')
+ self._log.warning("Update could not be triggered")
else:
- self._log.info('Update triggered.')
+ self._log.info("Update triggered.")
diff --git a/lib/beetsplug/export.py b/lib/beetsplug/export.py
index 99f6d706..ef3ba94a 100644
--- a/lib/beetsplug/export.py
+++ b/lib/beetsplug/export.py
@@ -15,22 +15,23 @@
"""
-import sys
import codecs
-import json
import csv
+import json
+import sys
+from datetime import date, datetime
from xml.etree import ElementTree
-from datetime import datetime, date
-from beets.plugins import BeetsPlugin
-from beets import ui
-from beets import util
import mediafile
+
+from beets import ui, util
+from beets.plugins import BeetsPlugin
from beetsplug.info import library_data, tag_data
class ExportEncoder(json.JSONEncoder):
"""Deals with dates because JSON doesn't have a standard"""
+
def default(self, o):
if isinstance(o, (datetime, date)):
return o.isoformat()
@@ -38,89 +39,99 @@ class ExportEncoder(json.JSONEncoder):
class ExportPlugin(BeetsPlugin):
-
def __init__(self):
super().__init__()
- self.config.add({
- 'default_format': 'json',
- 'json': {
- # JSON module formatting options.
- 'formatting': {
- 'ensure_ascii': False,
- 'indent': 4,
- 'separators': (',', ': '),
- 'sort_keys': True
- }
- },
- 'jsonlines': {
- # JSON Lines formatting options.
- 'formatting': {
- 'ensure_ascii': False,
- 'separators': (',', ': '),
- 'sort_keys': True
- }
- },
- 'csv': {
- # CSV module formatting options.
- 'formatting': {
- # The delimiter used to seperate columns.
- 'delimiter': ',',
- # The dialect to use when formating the file output.
- 'dialect': 'excel'
- }
- },
- 'xml': {
- # XML module formatting options.
- 'formatting': {}
+ self.config.add(
+ {
+ "default_format": "json",
+ "json": {
+ # JSON module formatting options.
+ "formatting": {
+ "ensure_ascii": False,
+ "indent": 4,
+ "separators": (",", ": "),
+ "sort_keys": True,
+ }
+ },
+ "jsonlines": {
+ # JSON Lines formatting options.
+ "formatting": {
+ "ensure_ascii": False,
+ "separators": (",", ": "),
+ "sort_keys": True,
+ }
+ },
+ "csv": {
+ # CSV module formatting options.
+ "formatting": {
+ # The delimiter used to separate columns.
+ "delimiter": ",",
+ # The dialect to use when formatting the file output.
+ "dialect": "excel",
+ }
+ },
+ "xml": {
+ # XML module formatting options.
+ "formatting": {}
+ },
+ # TODO: Use something like the edit plugin
+ # 'item_fields': []
}
- # TODO: Use something like the edit plugin
- # 'item_fields': []
- })
+ )
def commands(self):
- cmd = ui.Subcommand('export', help='export data from beets')
+ cmd = ui.Subcommand("export", help="export data from beets")
cmd.func = self.run
cmd.parser.add_option(
- '-l', '--library', action='store_true',
- help='show library fields instead of tags',
+ "-l",
+ "--library",
+ action="store_true",
+ help="show library fields instead of tags",
)
cmd.parser.add_option(
- '-a', '--album', action='store_true',
+ "-a",
+ "--album",
+ action="store_true",
help='show album fields instead of tracks (implies "--library")',
)
cmd.parser.add_option(
- '--append', action='store_true', default=False,
- help='if should append data to the file',
+ "--append",
+ action="store_true",
+ default=False,
+ help="if should append data to the file",
)
cmd.parser.add_option(
- '-i', '--include-keys', default=[],
- action='append', dest='included_keys',
- help='comma separated list of keys to show',
+ "-i",
+ "--include-keys",
+ default=[],
+ action="append",
+ dest="included_keys",
+ help="comma separated list of keys to show",
)
cmd.parser.add_option(
- '-o', '--output',
- help='path for the output file. If not given, will print the data'
+ "-o",
+ "--output",
+ help="path for the output file. If not given, will print the data",
)
cmd.parser.add_option(
- '-f', '--format', default='json',
- help="the output format: json (default), jsonlines, csv, or xml"
+ "-f",
+ "--format",
+ default="json",
+ help="the output format: json (default), jsonlines, csv, or xml",
)
return [cmd]
def run(self, lib, opts, args):
file_path = opts.output
- file_mode = 'a' if opts.append else 'w'
- file_format = opts.format or self.config['default_format'].get(str)
- file_format_is_line_based = (file_format == 'jsonlines')
- format_options = self.config[file_format]['formatting'].get(dict)
+ file_mode = "a" if opts.append else "w"
+ file_format = opts.format or self.config["default_format"].get(str)
+ file_format_is_line_based = file_format == "jsonlines"
+ format_options = self.config[file_format]["formatting"].get(dict)
export_format = ExportFormat.factory(
file_type=file_format,
- **{
- 'file_path': file_path,
- 'file_mode': file_mode
- }
+ **{"file_path": file_path, "file_mode": file_mode},
)
if opts.library or opts.album:
@@ -130,17 +141,18 @@ class ExportPlugin(BeetsPlugin):
included_keys = []
for keys in opts.included_keys:
- included_keys.extend(keys.split(','))
+ included_keys.extend(keys.split(","))
items = []
for data_emitter in data_collector(
- lib, ui.decargs(args),
- album=opts.album,
+ lib,
+ ui.decargs(args),
+ album=opts.album,
):
try:
- data, item = data_emitter(included_keys or '*')
+ data, item = data_emitter(included_keys or "*")
except (mediafile.UnreadableFileError, OSError) as ex:
- self._log.error('cannot read file: {0}', ex)
+ self._log.error("cannot read file: {0}", ex)
continue
for key, value in data.items():
@@ -158,13 +170,17 @@ class ExportPlugin(BeetsPlugin):
class ExportFormat:
"""The output format type"""
- def __init__(self, file_path, file_mode='w', encoding='utf-8'):
+
+ def __init__(self, file_path, file_mode="w", encoding="utf-8"):
self.path = file_path
self.mode = file_mode
self.encoding = encoding
# creates a file object to write/append or sets to stdout
- self.out_stream = codecs.open(self.path, self.mode, self.encoding) \
- if self.path else sys.stdout
+ self.out_stream = (
+ codecs.open(self.path, self.mode, self.encoding)
+ if self.path
+ else sys.stdout
+ )
@classmethod
def factory(cls, file_type, **kwargs):
@@ -183,17 +199,19 @@ class ExportFormat:
class JsonFormat(ExportFormat):
"""Saves in a json file"""
- def __init__(self, file_path, file_mode='w', encoding='utf-8'):
+
+ def __init__(self, file_path, file_mode="w", encoding="utf-8"):
super().__init__(file_path, file_mode, encoding)
def export(self, data, **kwargs):
json.dump(data, self.out_stream, cls=ExportEncoder, **kwargs)
- self.out_stream.write('\n')
+ self.out_stream.write("\n")
class CSVFormat(ExportFormat):
"""Saves in a csv file"""
- def __init__(self, file_path, file_mode='w', encoding='utf-8'):
+
+ def __init__(self, file_path, file_mode="w", encoding="utf-8"):
super().__init__(file_path, file_mode, encoding)
def export(self, data, **kwargs):
@@ -205,23 +223,24 @@ class CSVFormat(ExportFormat):
class XMLFormat(ExportFormat):
"""Saves in a xml file"""
- def __init__(self, file_path, file_mode='w', encoding='utf-8'):
+
+ def __init__(self, file_path, file_mode="w", encoding="utf-8"):
super().__init__(file_path, file_mode, encoding)
def export(self, data, **kwargs):
# Creates the XML file structure.
- library = ElementTree.Element('library')
- tracks = ElementTree.SubElement(library, 'tracks')
+ library = ElementTree.Element("library")
+ tracks = ElementTree.SubElement(library, "tracks")
if data and isinstance(data[0], dict):
for index, item in enumerate(data):
- track = ElementTree.SubElement(tracks, 'track')
+ track = ElementTree.SubElement(tracks, "track")
for key, value in item.items():
track_details = ElementTree.SubElement(track, key)
track_details.text = value
# Depending on the version of python the encoding needs to change
try:
- data = ElementTree.tostring(library, encoding='unicode', **kwargs)
+ data = ElementTree.tostring(library, encoding="unicode", **kwargs)
except LookupError:
- data = ElementTree.tostring(library, encoding='utf-8', **kwargs)
+ data = ElementTree.tostring(library, encoding="utf-8", **kwargs)
self.out_stream.write(data)
diff --git a/lib/beetsplug/fetchart.py b/lib/beetsplug/fetchart.py
index f2c1e5a7..72aa3aa2 100644
--- a/lib/beetsplug/fetchart.py
+++ b/lib/beetsplug/fetchart.py
@@ -15,29 +15,28 @@
"""Fetches album art.
"""
-from contextlib import closing
import os
import re
-from tempfile import NamedTemporaryFile
from collections import OrderedDict
+from contextlib import closing
-import requests
-
-from beets import plugins
-from beets import importer
-from beets import ui
-from beets import util
-from beets import config
-from mediafile import image_mime_type
-from beets.util.artresizer import ArtResizer
-from beets.util import sorted_walk
-from beets.util import syspath, bytestring_path, py3_path
import confuse
+import requests
+from mediafile import image_mime_type
-CONTENT_TYPES = {
- 'image/jpeg': [b'jpg', b'jpeg'],
- 'image/png': [b'png']
-}
+from beets import config, importer, plugins, ui, util
+from beets.util import bytestring_path, get_temp_filename, sorted_walk, syspath
+from beets.util.artresizer import ArtResizer
+
+try:
+ from bs4 import BeautifulSoup
+
+ HAS_BEAUTIFUL_SOUP = True
+except ImportError:
+ HAS_BEAUTIFUL_SOUP = False
+
+
+CONTENT_TYPES = {"image/jpeg": [b"jpg", b"jpeg"], "image/png": [b"png"]}
IMAGE_EXTENSIONS = [ext for exts in CONTENT_TYPES.values() for ext in exts]
@@ -45,6 +44,7 @@ class Candidate:
"""Holds information about a matching artwork, deals with validation of
dimension restrictions and resizing.
"""
+
CANDIDATE_BAD = 0
CANDIDATE_EXACT = 1
CANDIDATE_DOWNSCALE = 2
@@ -55,8 +55,9 @@ class Candidate:
MATCH_EXACT = 0
MATCH_FALLBACK = 1
- def __init__(self, log, path=None, url=None, source='',
- match=None, size=None):
+ def __init__(
+ self, log, path=None, url=None, source="", match=None, size=None
+ ):
self._log = log
self.path = path
self.url = url
@@ -65,10 +66,15 @@ class Candidate:
self.match = match
self.size = size
- def _validate(self, plugin):
+ def _validate(self, plugin, skip_check_for=None):
"""Determine whether the candidate artwork is valid based on
its dimensions (width and ratio).
+ `skip_check_for` is a check or list of checks to skip. This is used to
+ avoid redundant checks when the candidate has already been
+ validated for a particular operation without changing
+ plugin configuration.
+
Return `CANDIDATE_BAD` if the file is unusable.
Return `CANDIDATE_EXACT` if the file is usable as-is.
Return `CANDIDATE_DOWNSCALE` if the file must be rescaled.
@@ -80,22 +86,34 @@ class Candidate:
if not self.path:
return self.CANDIDATE_BAD
- if (not (plugin.enforce_ratio or plugin.minwidth or plugin.maxwidth
- or plugin.max_filesize or plugin.deinterlace
- or plugin.cover_format)):
+ if skip_check_for is None:
+ skip_check_for = []
+ if isinstance(skip_check_for, int):
+ skip_check_for = [skip_check_for]
+
+ if not (
+ plugin.enforce_ratio
+ or plugin.minwidth
+ or plugin.maxwidth
+ or plugin.max_filesize
+ or plugin.deinterlace
+ or plugin.cover_format
+ ):
return self.CANDIDATE_EXACT
# get_size returns None if no local imaging backend is available
if not self.size:
self.size = ArtResizer.shared.get_size(self.path)
- self._log.debug('image size: {}', self.size)
+ self._log.debug("image size: {}", self.size)
if not self.size:
- self._log.warning('Could not get size of image (please see '
- 'documentation for dependencies). '
- 'The configuration options `minwidth`, '
- '`enforce_ratio` and `max_filesize` '
- 'may be violated.')
+ self._log.warning(
+ "Could not get size of image (please see "
+ "documentation for dependencies). "
+ "The configuration options `minwidth`, "
+ "`enforce_ratio` and `max_filesize` "
+ "may be violated."
+ )
return self.CANDIDATE_EXACT
short_edge = min(self.size)
@@ -103,8 +121,9 @@ class Candidate:
# Check minimum dimension.
if plugin.minwidth and self.size[0] < plugin.minwidth:
- self._log.debug('image too small ({} < {})',
- self.size[0], plugin.minwidth)
+ self._log.debug(
+ "image too small ({} < {})", self.size[0], plugin.minwidth
+ )
return self.CANDIDATE_BAD
# Check aspect ratio.
@@ -112,28 +131,38 @@ class Candidate:
if plugin.enforce_ratio:
if plugin.margin_px:
if edge_diff > plugin.margin_px:
- self._log.debug('image is not close enough to being '
- 'square, ({} - {} > {})',
- long_edge, short_edge, plugin.margin_px)
+ self._log.debug(
+ "image is not close enough to being "
+ "square, ({} - {} > {})",
+ long_edge,
+ short_edge,
+ plugin.margin_px,
+ )
return self.CANDIDATE_BAD
elif plugin.margin_percent:
margin_px = plugin.margin_percent * long_edge
if edge_diff > margin_px:
- self._log.debug('image is not close enough to being '
- 'square, ({} - {} > {})',
- long_edge, short_edge, margin_px)
+ self._log.debug(
+ "image is not close enough to being "
+ "square, ({} - {} > {})",
+ long_edge,
+ short_edge,
+ margin_px,
+ )
return self.CANDIDATE_BAD
elif edge_diff:
# also reached for margin_px == 0 and margin_percent == 0.0
- self._log.debug('image is not square ({} != {})',
- self.size[0], self.size[1])
+ self._log.debug(
+ "image is not square ({} != {})", self.size[0], self.size[1]
+ )
return self.CANDIDATE_BAD
# Check maximum dimension.
downscale = False
if plugin.maxwidth and self.size[0] > plugin.maxwidth:
- self._log.debug('image needs rescaling ({} > {})',
- self.size[0], plugin.maxwidth)
+ self._log.debug(
+ "image needs rescaling ({} > {})", self.size[0], plugin.maxwidth
+ )
downscale = True
# Check filesize.
@@ -141,8 +170,11 @@ class Candidate:
if plugin.max_filesize:
filesize = os.stat(syspath(self.path)).st_size
if filesize > plugin.max_filesize:
- self._log.debug('image needs resizing ({}B > {}B)',
- filesize, plugin.max_filesize)
+ self._log.debug(
+ "image needs resizing ({}B > {}B)",
+ filesize,
+ plugin.max_filesize,
+ )
downsize = True
# Check image format
@@ -151,43 +183,71 @@ class Candidate:
fmt = ArtResizer.shared.get_format(self.path)
reformat = fmt != plugin.cover_format
if reformat:
- self._log.debug('image needs reformatting: {} -> {}',
- fmt, plugin.cover_format)
+ self._log.debug(
+ "image needs reformatting: {} -> {}",
+ fmt,
+ plugin.cover_format,
+ )
- if downscale:
+ if downscale and (self.CANDIDATE_DOWNSCALE not in skip_check_for):
return self.CANDIDATE_DOWNSCALE
- elif downsize:
- return self.CANDIDATE_DOWNSIZE
- elif plugin.deinterlace:
- return self.CANDIDATE_DEINTERLACE
- elif reformat:
+ if reformat and (self.CANDIDATE_REFORMAT not in skip_check_for):
return self.CANDIDATE_REFORMAT
- else:
- return self.CANDIDATE_EXACT
+ if plugin.deinterlace and (
+ self.CANDIDATE_DEINTERLACE not in skip_check_for
+ ):
+ return self.CANDIDATE_DEINTERLACE
+ if downsize and (self.CANDIDATE_DOWNSIZE not in skip_check_for):
+ return self.CANDIDATE_DOWNSIZE
+ return self.CANDIDATE_EXACT
- def validate(self, plugin):
- self.check = self._validate(plugin)
+ def validate(self, plugin, skip_check_for=None):
+ self.check = self._validate(plugin, skip_check_for)
return self.check
def resize(self, plugin):
- if self.check == self.CANDIDATE_DOWNSCALE:
- self.path = \
- ArtResizer.shared.resize(plugin.maxwidth, self.path,
- quality=plugin.quality,
- max_filesize=plugin.max_filesize)
- elif self.check == self.CANDIDATE_DOWNSIZE:
+ """Resize the candidate artwork according to the plugin's
+ configuration until it is valid or no further resizing is
+ possible.
+ """
+ # validate the candidate in case it hasn't been done yet
+ current_check = self.validate(plugin)
+ checks_performed = []
+
+ # we don't want to resize the image if it's valid or bad
+ while current_check not in [self.CANDIDATE_BAD, self.CANDIDATE_EXACT]:
+ self._resize(plugin, current_check)
+ checks_performed.append(current_check)
+ current_check = self.validate(
+ plugin, skip_check_for=checks_performed
+ )
+
+ def _resize(self, plugin, check=None):
+ """Resize the candidate artwork according to the plugin's
+ configuration and the specified check.
+ """
+ if check == self.CANDIDATE_DOWNSCALE:
+ self.path = ArtResizer.shared.resize(
+ plugin.maxwidth,
+ self.path,
+ quality=plugin.quality,
+ max_filesize=plugin.max_filesize,
+ )
+ elif check == self.CANDIDATE_DOWNSIZE:
# dimensions are correct, so maxwidth is set to maximum dimension
- self.path = \
- ArtResizer.shared.resize(max(self.size), self.path,
- quality=plugin.quality,
- max_filesize=plugin.max_filesize)
- elif self.check == self.CANDIDATE_DEINTERLACE:
+ self.path = ArtResizer.shared.resize(
+ max(self.size),
+ self.path,
+ quality=plugin.quality,
+ max_filesize=plugin.max_filesize,
+ )
+ elif check == self.CANDIDATE_DEINTERLACE:
self.path = ArtResizer.shared.deinterlace(self.path)
- elif self.check == self.CANDIDATE_REFORMAT:
+ elif check == self.CANDIDATE_REFORMAT:
self.path = ArtResizer.shared.reformat(
- self.path,
- plugin.cover_format,
- deinterlaced=plugin.deinterlace,
+ self.path,
+ plugin.cover_format,
+ deinterlaced=plugin.deinterlace,
)
@@ -206,26 +266,28 @@ def _logged_get(log, *args, **kwargs):
# `requests.Session.request`.
req_kwargs = kwargs
send_kwargs = {}
- for arg in ('stream', 'verify', 'proxies', 'cert', 'timeout'):
+ for arg in ("stream", "verify", "proxies", "cert", "timeout"):
if arg in kwargs:
send_kwargs[arg] = req_kwargs.pop(arg)
+ if "timeout" not in send_kwargs:
+ send_kwargs["timeout"] = 10
# Our special logging message parameter.
- if 'message' in kwargs:
- message = kwargs.pop('message')
+ if "message" in kwargs:
+ message = kwargs.pop("message")
else:
- message = 'getting URL'
+ message = "getting URL"
- req = requests.Request('GET', *args, **req_kwargs)
+ req = requests.Request("GET", *args, **req_kwargs)
with requests.Session() as s:
- s.headers = {'User-Agent': 'beets'}
+ s.headers = {"User-Agent": "beets"}
prepped = s.prepare_request(req)
settings = s.merge_environment_settings(
prepped.url, {}, None, None, None
)
send_kwargs.update(settings)
- log.debug('{}: {}', message, prepped.url)
+ log.debug("{}: {}", message, prepped.url)
return s.send(prepped, **send_kwargs)
@@ -244,14 +306,26 @@ class RequestMixin:
# ART SOURCES ################################################################
+
class ArtSource(RequestMixin):
- VALID_MATCHING_CRITERIA = ['default']
+ VALID_MATCHING_CRITERIA = ["default"]
def __init__(self, log, config, match_by=None):
self._log = log
self._config = config
self.match_by = match_by or self.VALID_MATCHING_CRITERIA
+ @staticmethod
+ def add_default_config(config):
+ pass
+
+ @classmethod
+ def available(cls, log, config):
+ """Return whether or not all dependencies are met and the art source is
+ in fact usable.
+ """
+ return True
+
def get(self, album, plugin, paths):
raise NotImplementedError()
@@ -267,7 +341,7 @@ class ArtSource(RequestMixin):
class LocalArtSource(ArtSource):
IS_LOCAL = True
- LOC_STR = 'local'
+ LOC_STR = "local"
def fetch_image(self, candidate, plugin):
pass
@@ -275,7 +349,7 @@ class LocalArtSource(ArtSource):
class RemoteArtSource(ArtSource):
IS_LOCAL = False
- LOC_STR = 'remote'
+ LOC_STR = "remote"
def fetch_image(self, candidate, plugin):
"""Downloads an image from a URL and checks whether it seems to
@@ -283,12 +357,16 @@ class RemoteArtSource(ArtSource):
Otherwise, returns None.
"""
if plugin.maxwidth:
- candidate.url = ArtResizer.shared.proxy_url(plugin.maxwidth,
- candidate.url)
+ candidate.url = ArtResizer.shared.proxy_url(
+ plugin.maxwidth, candidate.url
+ )
try:
- with closing(self.request(candidate.url, stream=True,
- message='downloading image')) as resp:
- ct = resp.headers.get('Content-Type', None)
+ with closing(
+ self.request(
+ candidate.url, stream=True, message="downloading image"
+ )
+ ) as resp:
+ ct = resp.headers.get("Content-Type", None)
# Download the image to a temporary file. As some servers
# (notably fanart.tv) have proven to return wrong Content-Types
@@ -296,7 +374,7 @@ class RemoteArtSource(ArtSource):
# rely on it. Instead validate the type using the file magic
# and only then determine the extension.
data = resp.iter_content(chunk_size=1024)
- header = b''
+ header = b""
for chunk in data:
header += chunk
if len(header) >= 32:
@@ -315,34 +393,41 @@ class RemoteArtSource(ArtSource):
real_ct = ct
if real_ct not in CONTENT_TYPES:
- self._log.debug('not a supported image: {}',
- real_ct or 'unknown content type')
+ self._log.debug(
+ "not a supported image: {}",
+ real_ct or "unknown content type",
+ )
return
- ext = b'.' + CONTENT_TYPES[real_ct][0]
+ ext = b"." + CONTENT_TYPES[real_ct][0]
if real_ct != ct:
- self._log.warning('Server specified {}, but returned a '
- '{} image. Correcting the extension '
- 'to {}',
- ct, real_ct, ext)
+ self._log.warning(
+ "Server specified {}, but returned a "
+ "{} image. Correcting the extension "
+ "to {}",
+ ct,
+ real_ct,
+ ext,
+ )
- suffix = py3_path(ext)
- with NamedTemporaryFile(suffix=suffix, delete=False) as fh:
+ filename = get_temp_filename(__name__, suffix=ext.decode())
+ with open(filename, "wb") as fh:
# write the first already loaded part of the image
fh.write(header)
# download the remaining part of the image
for chunk in data:
fh.write(chunk)
- self._log.debug('downloaded art to: {0}',
- util.displayable_path(fh.name))
- candidate.path = util.bytestring_path(fh.name)
+ self._log.debug(
+ "downloaded art to: {0}", util.displayable_path(filename)
+ )
+ candidate.path = util.bytestring_path(filename)
return
except (OSError, requests.RequestException, TypeError) as exc:
# Handling TypeError works around a urllib3 bug:
# https://github.com/shazow/urllib3/issues/556
- self._log.debug('error fetching art: {}', exc)
+ self._log.debug("error fetching art: {}", exc)
return
def cleanup(self, candidate):
@@ -350,46 +435,56 @@ class RemoteArtSource(ArtSource):
try:
util.remove(path=candidate.path)
except util.FilesystemError as exc:
- self._log.debug('error cleaning up tmp art: {}', exc)
+ self._log.debug("error cleaning up tmp art: {}", exc)
class CoverArtArchive(RemoteArtSource):
NAME = "Cover Art Archive"
- VALID_MATCHING_CRITERIA = ['release', 'releasegroup']
+ VALID_MATCHING_CRITERIA = ["release", "releasegroup"]
VALID_THUMBNAIL_SIZES = [250, 500, 1200]
- URL = 'https://coverartarchive.org/release/{mbid}'
- GROUP_URL = 'https://coverartarchive.org/release-group/{mbid}'
+ URL = "https://coverartarchive.org/release/{mbid}"
+ GROUP_URL = "https://coverartarchive.org/release-group/{mbid}"
def get(self, album, plugin, paths):
- """Return the Cover Art Archive and Cover Art Archive release group URLs
- using album MusicBrainz release ID and release group ID.
+ """Return the Cover Art Archive and Cover Art Archive release
+ group URLs using album MusicBrainz release ID and release group
+ ID.
"""
- def get_image_urls(url, size_suffix=None):
+ def get_image_urls(url, preferred_width=None):
try:
response = self.request(url)
except requests.RequestException:
- self._log.debug('{}: error receiving response'
- .format(self.NAME))
+ self._log.debug(
+ "{}: error receiving response".format(self.NAME)
+ )
return
try:
data = response.json()
except ValueError:
- self._log.debug('{}: error loading response: {}'
- .format(self.NAME, response.text))
+ self._log.debug(
+ "{}: error loading response: {}".format(
+ self.NAME, response.text
+ )
+ )
return
- for item in data.get('images', []):
+ for item in data.get("images", []):
try:
- if 'Front' not in item['types']:
+ if "Front" not in item["types"]:
continue
- if size_suffix:
- yield item['thumbnails'][size_suffix]
- else:
- yield item['image']
+ # If there is a pre-sized thumbnail of the desired size
+ # we select it. Otherwise, we return the raw image.
+ image_url: str = item["image"]
+ if preferred_width is not None:
+ if isinstance(item.get("thumbnails"), dict):
+ image_url = item["thumbnails"].get(
+ preferred_width, image_url
+ )
+ yield image_url
except KeyError:
pass
@@ -398,51 +493,51 @@ class CoverArtArchive(RemoteArtSource):
# Cover Art Archive API offers pre-resized thumbnails at several sizes.
# If the maxwidth config matches one of the already available sizes
- # fetch it directly intead of fetching the full sized image and
+ # fetch it directly instead of fetching the full sized image and
# resizing it.
- size_suffix = None
+ preferred_width = None
if plugin.maxwidth in self.VALID_THUMBNAIL_SIZES:
- size_suffix = "-" + str(plugin.maxwidth)
+ preferred_width = str(plugin.maxwidth)
- if 'release' in self.match_by and album.mb_albumid:
- for url in get_image_urls(release_url, size_suffix):
+ if "release" in self.match_by and album.mb_albumid:
+ for url in get_image_urls(release_url, preferred_width):
yield self._candidate(url=url, match=Candidate.MATCH_EXACT)
- if 'releasegroup' in self.match_by and album.mb_releasegroupid:
- for url in get_image_urls(release_group_url):
+ if "releasegroup" in self.match_by and album.mb_releasegroupid:
+ for url in get_image_urls(release_group_url, preferred_width):
yield self._candidate(url=url, match=Candidate.MATCH_FALLBACK)
class Amazon(RemoteArtSource):
NAME = "Amazon"
- URL = 'https://images.amazon.com/images/P/%s.%02i.LZZZZZZZ.jpg'
+ URL = "https://images.amazon.com/images/P/%s.%02i.LZZZZZZZ.jpg"
INDICES = (1, 2)
def get(self, album, plugin, paths):
- """Generate URLs using Amazon ID (ASIN) string.
- """
+ """Generate URLs using Amazon ID (ASIN) string."""
if album.asin:
for index in self.INDICES:
- yield self._candidate(url=self.URL % (album.asin, index),
- match=Candidate.MATCH_EXACT)
+ yield self._candidate(
+ url=self.URL % (album.asin, index),
+ match=Candidate.MATCH_EXACT,
+ )
class AlbumArtOrg(RemoteArtSource):
NAME = "AlbumArt.org scraper"
- URL = 'https://www.albumart.org/index_detail.php'
+ URL = "https://www.albumart.org/index_detail.php"
PAT = r'href\s*=\s*"([^>"]*)"[^>]*title\s*=\s*"View larger image"'
def get(self, album, plugin, paths):
- """Return art URL from AlbumArt.org using album ASIN.
- """
+ """Return art URL from AlbumArt.org using album ASIN."""
if not album.asin:
return
# Get the page from albumart.org.
try:
- resp = self.request(self.URL, params={'asin': album.asin})
- self._log.debug('scraped art URL: {0}', resp.url)
+ resp = self.request(self.URL, params={"asin": album.asin})
+ self._log.debug("scraped art URL: {0}", resp.url)
except requests.RequestException:
- self._log.debug('error scraping art page')
+ self._log.debug("error scraping art page")
return
# Search the page for the image URL.
@@ -451,17 +546,34 @@ class AlbumArtOrg(RemoteArtSource):
image_url = m.group(1)
yield self._candidate(url=image_url, match=Candidate.MATCH_EXACT)
else:
- self._log.debug('no image found on page')
+ self._log.debug("no image found on page")
class GoogleImages(RemoteArtSource):
NAME = "Google Images"
- URL = 'https://www.googleapis.com/customsearch/v1'
+ URL = "https://www.googleapis.com/customsearch/v1"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.key = self._config['google_key'].get(),
- self.cx = self._config['google_engine'].get(),
+ self.key = (self._config["google_key"].get(),)
+ self.cx = (self._config["google_engine"].get(),)
+
+ @staticmethod
+ def add_default_config(config):
+ config.add(
+ {
+ "google_key": None,
+ "google_engine": "001442825323518660753:hrh5ch1gjzm",
+ }
+ )
+ config["google_key"].redact = True
+
+ @classmethod
+ def available(cls, log, config):
+ has_key = bool(config["google_key"].get())
+ if not has_key:
+ log.debug("google: Disabling art source due to missing key")
+ return has_key
def get(self, album, plugin, paths):
"""Return art URL from google custom search engine
@@ -469,48 +581,63 @@ class GoogleImages(RemoteArtSource):
"""
if not (album.albumartist and album.album):
return
- search_string = (album.albumartist + ',' + album.album).encode('utf-8')
+ search_string = (album.albumartist + "," + album.album).encode("utf-8")
try:
- response = self.request(self.URL, params={
- 'key': self.key,
- 'cx': self.cx,
- 'q': search_string,
- 'searchType': 'image'
- })
+ response = self.request(
+ self.URL,
+ params={
+ "key": self.key,
+ "cx": self.cx,
+ "q": search_string,
+ "searchType": "image",
+ },
+ )
except requests.RequestException:
- self._log.debug('google: error receiving response')
+ self._log.debug("google: error receiving response")
return
# Get results using JSON.
try:
data = response.json()
except ValueError:
- self._log.debug('google: error loading response: {}'
- .format(response.text))
+ self._log.debug(
+ "google: error loading response: {}".format(response.text)
+ )
return
- if 'error' in data:
- reason = data['error']['errors'][0]['reason']
- self._log.debug('google fetchart error: {0}', reason)
+ if "error" in data:
+ reason = data["error"]["errors"][0]["reason"]
+ self._log.debug("google fetchart error: {0}", reason)
return
- if 'items' in data.keys():
- for item in data['items']:
- yield self._candidate(url=item['link'],
- match=Candidate.MATCH_EXACT)
+ if "items" in data.keys():
+ for item in data["items"]:
+ yield self._candidate(
+ url=item["link"], match=Candidate.MATCH_EXACT
+ )
class FanartTV(RemoteArtSource):
"""Art from fanart.tv requested using their API"""
+
NAME = "fanart.tv"
- API_URL = 'https://webservice.fanart.tv/v3/'
- API_ALBUMS = API_URL + 'music/albums/'
- PROJECT_KEY = '61a7d0ab4e67162b7a0c7c35915cd48e'
+ API_URL = "https://webservice.fanart.tv/v3/"
+ API_ALBUMS = API_URL + "music/albums/"
+ PROJECT_KEY = "61a7d0ab4e67162b7a0c7c35915cd48e"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.client_key = self._config['fanarttv_key'].get()
+ self.client_key = self._config["fanarttv_key"].get()
+
+ @staticmethod
+ def add_default_config(config):
+ config.add(
+ {
+ "fanarttv_key": None,
+ }
+ )
+ config["fanarttv_key"].redact = True
def get(self, album, plugin, paths):
if not album.mb_releasegroupid:
@@ -519,126 +646,142 @@ class FanartTV(RemoteArtSource):
try:
response = self.request(
self.API_ALBUMS + album.mb_releasegroupid,
- headers={'api-key': self.PROJECT_KEY,
- 'client-key': self.client_key})
+ headers={
+ "api-key": self.PROJECT_KEY,
+ "client-key": self.client_key,
+ },
+ )
except requests.RequestException:
- self._log.debug('fanart.tv: error receiving response')
+ self._log.debug("fanart.tv: error receiving response")
return
try:
data = response.json()
except ValueError:
- self._log.debug('fanart.tv: error loading response: {}',
- response.text)
+ self._log.debug(
+ "fanart.tv: error loading response: {}", response.text
+ )
return
- if 'status' in data and data['status'] == 'error':
- if 'not found' in data['error message'].lower():
- self._log.debug('fanart.tv: no image found')
- elif 'api key' in data['error message'].lower():
- self._log.warning('fanart.tv: Invalid API key given, please '
- 'enter a valid one in your config file.')
+ if "status" in data and data["status"] == "error":
+ if "not found" in data["error message"].lower():
+ self._log.debug("fanart.tv: no image found")
+ elif "api key" in data["error message"].lower():
+ self._log.warning(
+ "fanart.tv: Invalid API key given, please "
+ "enter a valid one in your config file."
+ )
else:
- self._log.debug('fanart.tv: error on request: {}',
- data['error message'])
+ self._log.debug(
+ "fanart.tv: error on request: {}", data["error message"]
+ )
return
matches = []
# can there be more than one releasegroupid per response?
- for mbid, art in data.get('albums', {}).items():
+ for mbid, art in data.get("albums", {}).items():
# there might be more art referenced, e.g. cdart, and an albumcover
# might not be present, even if the request was successful
- if album.mb_releasegroupid == mbid and 'albumcover' in art:
- matches.extend(art['albumcover'])
+ if album.mb_releasegroupid == mbid and "albumcover" in art:
+ matches.extend(art["albumcover"])
# can this actually occur?
else:
- self._log.debug('fanart.tv: unexpected mb_releasegroupid in '
- 'response!')
+ self._log.debug(
+ "fanart.tv: unexpected mb_releasegroupid in " "response!"
+ )
- matches.sort(key=lambda x: x['likes'], reverse=True)
+ matches.sort(key=lambda x: int(x["likes"]), reverse=True)
for item in matches:
# fanart.tv has a strict size requirement for album art to be
# uploaded
- yield self._candidate(url=item['url'],
- match=Candidate.MATCH_EXACT,
- size=(1000, 1000))
+ yield self._candidate(
+ url=item["url"], match=Candidate.MATCH_EXACT, size=(1000, 1000)
+ )
class ITunesStore(RemoteArtSource):
NAME = "iTunes Store"
- API_URL = 'https://itunes.apple.com/search'
+ API_URL = "https://itunes.apple.com/search"
def get(self, album, plugin, paths):
- """Return art URL from iTunes Store given an album title.
- """
+ """Return art URL from iTunes Store given an album title."""
if not (album.albumartist and album.album):
return
payload = {
- 'term': album.albumartist + ' ' + album.album,
- 'entity': 'album',
- 'media': 'music',
- 'limit': 200
+ "term": album.albumartist + " " + album.album,
+ "entity": "album",
+ "media": "music",
+ "limit": 200,
}
try:
r = self.request(self.API_URL, params=payload)
r.raise_for_status()
except requests.RequestException as e:
- self._log.debug('iTunes search failed: {0}', e)
+ self._log.debug("iTunes search failed: {0}", e)
return
try:
- candidates = r.json()['results']
+ candidates = r.json()["results"]
except ValueError as e:
- self._log.debug('Could not decode json response: {0}', e)
+ self._log.debug("Could not decode json response: {0}", e)
return
except KeyError as e:
- self._log.debug('{} not found in json. Fields are {} ',
- e,
- list(r.json().keys()))
+ self._log.debug(
+ "{} not found in json. Fields are {} ", e, list(r.json().keys())
+ )
return
if not candidates:
- self._log.debug('iTunes search for {!r} got no results',
- payload['term'])
+ self._log.debug(
+ "iTunes search for {!r} got no results", payload["term"]
+ )
return
- if self._config['high_resolution']:
- image_suffix = '100000x100000-999'
+ if self._config["high_resolution"]:
+ image_suffix = "100000x100000-999"
else:
- image_suffix = '1200x1200bb'
+ image_suffix = "1200x1200bb"
for c in candidates:
try:
- if (c['artistName'] == album.albumartist
- and c['collectionName'] == album.album):
- art_url = c['artworkUrl100']
- art_url = art_url.replace('100x100bb',
- image_suffix)
- yield self._candidate(url=art_url,
- match=Candidate.MATCH_EXACT)
+ if (
+ c["artistName"] == album.albumartist
+ and c["collectionName"] == album.album
+ ):
+ art_url = c["artworkUrl100"]
+ art_url = art_url.replace("100x100bb", image_suffix)
+ yield self._candidate(
+ url=art_url, match=Candidate.MATCH_EXACT
+ )
except KeyError as e:
- self._log.debug('Malformed itunes candidate: {} not found in {}', # NOQA E501
- e,
- list(c.keys()))
+ self._log.debug(
+ "Malformed itunes candidate: {} not found in {}", # NOQA E501
+ e,
+ list(c.keys()),
+ )
try:
- fallback_art_url = candidates[0]['artworkUrl100']
- fallback_art_url = fallback_art_url.replace('100x100bb',
- image_suffix)
- yield self._candidate(url=fallback_art_url,
- match=Candidate.MATCH_FALLBACK)
+ fallback_art_url = candidates[0]["artworkUrl100"]
+ fallback_art_url = fallback_art_url.replace(
+ "100x100bb", image_suffix
+ )
+ yield self._candidate(
+ url=fallback_art_url, match=Candidate.MATCH_FALLBACK
+ )
except KeyError as e:
- self._log.debug('Malformed itunes candidate: {} not found in {}',
- e,
- list(c.keys()))
+ self._log.debug(
+ "Malformed itunes candidate: {} not found in {}",
+ e,
+ list(c.keys()),
+ )
class Wikipedia(RemoteArtSource):
NAME = "Wikipedia (queried through DBpedia)"
- DBPEDIA_URL = 'https://dbpedia.org/sparql'
- WIKIPEDIA_URL = 'https://en.wikipedia.org/w/api.php'
- SPARQL_QUERY = '''PREFIX rdf:
+ DBPEDIA_URL = "https://dbpedia.org/sparql"
+ WIKIPEDIA_URL = "https://en.wikipedia.org/w/api.php"
+ SPARQL_QUERY = """PREFIX rdf:
PREFIX dbpprop:
PREFIX owl:
PREFIX rdfs:
@@ -658,7 +801,7 @@ class Wikipedia(RemoteArtSource):
?subject dbpprop:cover ?coverFilename .
FILTER ( regex(?name, "{album}", "i") )
}}
- Limit 1'''
+ Limit 1"""
def get(self, album, plugin, paths):
if not (album.albumartist and album.album):
@@ -671,28 +814,31 @@ class Wikipedia(RemoteArtSource):
dbpedia_response = self.request(
self.DBPEDIA_URL,
params={
- 'format': 'application/sparql-results+json',
- 'timeout': 2500,
- 'query': self.SPARQL_QUERY.format(
- artist=album.albumartist.title(), album=album.album)
+ "format": "application/sparql-results+json",
+ "timeout": 2500,
+ "query": self.SPARQL_QUERY.format(
+ artist=album.albumartist.title(), album=album.album
+ ),
},
- headers={'content-type': 'application/json'},
+ headers={"content-type": "application/json"},
)
except requests.RequestException:
- self._log.debug('dbpedia: error receiving response')
+ self._log.debug("dbpedia: error receiving response")
return
try:
data = dbpedia_response.json()
- results = data['results']['bindings']
+ results = data["results"]["bindings"]
if results:
- cover_filename = 'File:' + results[0]['coverFilename']['value']
- page_id = results[0]['pageId']['value']
+ cover_filename = "File:" + results[0]["coverFilename"]["value"]
+ page_id = results[0]["pageId"]["value"]
else:
- self._log.debug('wikipedia: album not found on dbpedia')
+ self._log.debug("wikipedia: album not found on dbpedia")
except (ValueError, KeyError, IndexError):
- self._log.debug('wikipedia: error scraping dbpedia response: {}',
- dbpedia_response.text)
+ self._log.debug(
+ "wikipedia: error scraping dbpedia response: {}",
+ dbpedia_response.text,
+ )
# Ensure we have a filename before attempting to query wikipedia
if not (cover_filename and page_id):
@@ -703,43 +849,44 @@ class Wikipedia(RemoteArtSource):
# An additional Wikipedia call can help to find the real filename.
# This may be removed once the DBPedia issue is resolved, see:
# https://github.com/dbpedia/extraction-framework/issues/396
- if ' .' in cover_filename and \
- '.' not in cover_filename.split(' .')[-1]:
+ if " ." in cover_filename and "." not in cover_filename.split(" .")[-1]:
self._log.debug(
- 'wikipedia: dbpedia provided incomplete cover_filename'
+ "wikipedia: dbpedia provided incomplete cover_filename"
)
- lpart, rpart = cover_filename.rsplit(' .', 1)
+ lpart, rpart = cover_filename.rsplit(" .", 1)
# Query all the images in the page
try:
wikipedia_response = self.request(
self.WIKIPEDIA_URL,
params={
- 'format': 'json',
- 'action': 'query',
- 'continue': '',
- 'prop': 'images',
- 'pageids': page_id,
+ "format": "json",
+ "action": "query",
+ "continue": "",
+ "prop": "images",
+ "pageids": page_id,
},
- headers={'content-type': 'application/json'},
+ headers={"content-type": "application/json"},
)
except requests.RequestException:
- self._log.debug('wikipedia: error receiving response')
+ self._log.debug("wikipedia: error receiving response")
return
# Try to see if one of the images on the pages matches our
# incomplete cover_filename
try:
data = wikipedia_response.json()
- results = data['query']['pages'][page_id]['images']
+ results = data["query"]["pages"][page_id]["images"]
for result in results:
- if re.match(re.escape(lpart) + r'.*?\.' + re.escape(rpart),
- result['title']):
- cover_filename = result['title']
+ if re.match(
+ re.escape(lpart) + r".*?\." + re.escape(rpart),
+ result["title"],
+ ):
+ cover_filename = result["title"]
break
except (ValueError, KeyError):
self._log.debug(
- 'wikipedia: failed to retrieve a cover_filename'
+ "wikipedia: failed to retrieve a cover_filename"
)
return
@@ -748,28 +895,29 @@ class Wikipedia(RemoteArtSource):
wikipedia_response = self.request(
self.WIKIPEDIA_URL,
params={
- 'format': 'json',
- 'action': 'query',
- 'continue': '',
- 'prop': 'imageinfo',
- 'iiprop': 'url',
- 'titles': cover_filename.encode('utf-8'),
+ "format": "json",
+ "action": "query",
+ "continue": "",
+ "prop": "imageinfo",
+ "iiprop": "url",
+ "titles": cover_filename.encode("utf-8"),
},
- headers={'content-type': 'application/json'},
+ headers={"content-type": "application/json"},
)
except requests.RequestException:
- self._log.debug('wikipedia: error receiving response')
+ self._log.debug("wikipedia: error receiving response")
return
try:
data = wikipedia_response.json()
- results = data['query']['pages']
+ results = data["query"]["pages"]
for _, result in results.items():
- image_url = result['imageinfo'][0]['url']
- yield self._candidate(url=image_url,
- match=Candidate.MATCH_EXACT)
+ image_url = result["imageinfo"][0]["url"]
+ yield self._candidate(
+ url=image_url, match=Candidate.MATCH_EXACT
+ )
except (ValueError, KeyError, IndexError):
- self._log.debug('wikipedia: error scraping imageinfo')
+ self._log.debug("wikipedia: error scraping imageinfo")
return
@@ -787,13 +935,12 @@ class FileSystem(LocalArtSource):
return [idx for (idx, x) in enumerate(cover_names) if x in filename]
def get(self, album, plugin, paths):
- """Look for album art files in the specified directories.
- """
+ """Look for album art files in the specified directories."""
if not paths:
return
cover_names = list(map(util.bytestring_path, plugin.cover_names))
- cover_names_str = b'|'.join(cover_names)
- cover_pat = br''.join([br"(\b|_)(", cover_names_str, br")(\b|_)"])
+ cover_names_str = b"|".join(cover_names)
+ cover_pat = rb"".join([rb"(\b|_)(", cover_names_str, rb")(\b|_)"])
for path in paths:
if not os.path.isdir(syspath(path)):
@@ -801,113 +948,233 @@ class FileSystem(LocalArtSource):
# Find all files that look like images in the directory.
images = []
- ignore = config['ignore'].as_str_seq()
- ignore_hidden = config['ignore_hidden'].get(bool)
- for _, _, files in sorted_walk(path, ignore=ignore,
- ignore_hidden=ignore_hidden):
+ ignore = config["ignore"].as_str_seq()
+ ignore_hidden = config["ignore_hidden"].get(bool)
+ for _, _, files in sorted_walk(
+ path, ignore=ignore, ignore_hidden=ignore_hidden
+ ):
for fn in files:
fn = bytestring_path(fn)
for ext in IMAGE_EXTENSIONS:
- if fn.lower().endswith(b'.' + ext) and \
- os.path.isfile(syspath(os.path.join(path, fn))):
+ if fn.lower().endswith(b"." + ext) and os.path.isfile(
+ syspath(os.path.join(path, fn))
+ ):
images.append(fn)
# Look for "preferred" filenames.
- images = sorted(images,
- key=lambda x:
- self.filename_priority(x, cover_names))
+ images = sorted(
+ images, key=lambda x: self.filename_priority(x, cover_names)
+ )
remaining = []
for fn in images:
if re.search(cover_pat, os.path.splitext(fn)[0], re.I):
- self._log.debug('using well-named art file {0}',
- util.displayable_path(fn))
- yield self._candidate(path=os.path.join(path, fn),
- match=Candidate.MATCH_EXACT)
+ self._log.debug(
+ "using well-named art file {0}",
+ util.displayable_path(fn),
+ )
+ yield self._candidate(
+ path=os.path.join(path, fn), match=Candidate.MATCH_EXACT
+ )
else:
remaining.append(fn)
# Fall back to any image in the folder.
if remaining and not plugin.cautious:
- self._log.debug('using fallback art file {0}',
- util.displayable_path(remaining[0]))
- yield self._candidate(path=os.path.join(path, remaining[0]),
- match=Candidate.MATCH_FALLBACK)
+ self._log.debug(
+ "using fallback art file {0}",
+ util.displayable_path(remaining[0]),
+ )
+ yield self._candidate(
+ path=os.path.join(path, remaining[0]),
+ match=Candidate.MATCH_FALLBACK,
+ )
class LastFM(RemoteArtSource):
NAME = "Last.fm"
# Sizes in priority order.
- SIZES = OrderedDict([
- ('mega', (300, 300)),
- ('extralarge', (300, 300)),
- ('large', (174, 174)),
- ('medium', (64, 64)),
- ('small', (34, 34)),
- ])
+ SIZES = OrderedDict(
+ [
+ ("mega", (300, 300)),
+ ("extralarge", (300, 300)),
+ ("large", (174, 174)),
+ ("medium", (64, 64)),
+ ("small", (34, 34)),
+ ]
+ )
- API_URL = 'https://ws.audioscrobbler.com/2.0'
+ API_URL = "https://ws.audioscrobbler.com/2.0"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.key = self._config['lastfm_key'].get(),
+ self.key = (self._config["lastfm_key"].get(),)
+
+ @staticmethod
+ def add_default_config(config):
+ config.add(
+ {
+ "lastfm_key": None,
+ }
+ )
+ config["lastfm_key"].redact = True
+
+ @classmethod
+ def available(cls, log, config):
+ has_key = bool(config["lastfm_key"].get())
+ if not has_key:
+ log.debug("lastfm: Disabling art source due to missing key")
+ return has_key
def get(self, album, plugin, paths):
if not album.mb_albumid:
return
try:
- response = self.request(self.API_URL, params={
- 'method': 'album.getinfo',
- 'api_key': self.key,
- 'mbid': album.mb_albumid,
- 'format': 'json',
- })
+ response = self.request(
+ self.API_URL,
+ params={
+ "method": "album.getinfo",
+ "api_key": self.key,
+ "mbid": album.mb_albumid,
+ "format": "json",
+ },
+ )
except requests.RequestException:
- self._log.debug('lastfm: error receiving response')
+ self._log.debug("lastfm: error receiving response")
return
try:
data = response.json()
- if 'error' in data:
- if data['error'] == 6:
- self._log.debug('lastfm: no results for {}',
- album.mb_albumid)
+ if "error" in data:
+ if data["error"] == 6:
+ self._log.debug(
+ "lastfm: no results for {}", album.mb_albumid
+ )
else:
self._log.error(
- 'lastfm: failed to get album info: {} ({})',
- data['message'], data['error'])
+ "lastfm: failed to get album info: {} ({})",
+ data["message"],
+ data["error"],
+ )
else:
- images = {image['size']: image['#text']
- for image in data['album']['image']}
+ images = {
+ image["size"]: image["#text"]
+ for image in data["album"]["image"]
+ }
# Provide candidates in order of size.
for size in self.SIZES.keys():
if size in images:
- yield self._candidate(url=images[size],
- size=self.SIZES[size])
+ yield self._candidate(
+ url=images[size], size=self.SIZES[size]
+ )
except ValueError:
- self._log.debug('lastfm: error loading response: {}'
- .format(response.text))
+ self._log.debug(
+ "lastfm: error loading response: {}".format(response.text)
+ )
return
+
+class Spotify(RemoteArtSource):
+ NAME = "Spotify"
+
+ SPOTIFY_ALBUM_URL = "https://open.spotify.com/album/"
+
+ @classmethod
+ def available(cls, log, config):
+ if not HAS_BEAUTIFUL_SOUP:
+ log.debug(
+ "To use Spotify as an album art source, "
+ "you must install the beautifulsoup4 module. See "
+ "the documentation for further details."
+ )
+ return HAS_BEAUTIFUL_SOUP
+
+ def get(self, album, plugin, paths):
+ try:
+ url = self.SPOTIFY_ALBUM_URL + album.items().get().spotify_album_id
+ except AttributeError:
+ self._log.debug("Fetchart: no Spotify album ID found")
+ return
+ try:
+ response = requests.get(url, timeout=10)
+ response.raise_for_status()
+ except requests.RequestException as e:
+ self._log.debug("Error: " + str(e))
+ return
+ try:
+ html = response.text
+ soup = BeautifulSoup(html, "html.parser")
+ image_url = soup.find("meta", attrs={"property": "og:image"})[
+ "content"
+ ]
+ yield self._candidate(url=image_url, match=Candidate.MATCH_EXACT)
+ except ValueError:
+ self._log.debug(
+ "Spotify: error loading response: {}".format(response.text)
+ )
+ return
+
+
+class CoverArtUrl(RemoteArtSource):
+ # This source is intended to be used with a plugin that sets the
+ # cover_art_url field on albums or tracks. Users can also manually update
+ # the cover_art_url field using the "set" command. This source will then
+ # use that URL to fetch the image.
+
+ NAME = "Cover Art URL"
+
+ def get(self, album, plugin, paths):
+ image_url = None
+ try:
+ # look for cover_art_url on album or first track
+ if album.get("cover_art_url"):
+ image_url = album.cover_art_url
+ else:
+ image_url = album.items().get().cover_art_url
+ self._log.debug(f"Cover art URL {image_url} found for {album}")
+ except (AttributeError, TypeError):
+ self._log.debug(f"Cover art URL not found for {album}")
+ return
+ if image_url:
+ yield self._candidate(url=image_url, match=Candidate.MATCH_EXACT)
+ else:
+ self._log.debug(f"Cover art URL not found for {album}")
+ return
+
+
# Try each source in turn.
-SOURCES_ALL = ['filesystem',
- 'coverart', 'itunes', 'amazon', 'albumart',
- 'wikipedia', 'google', 'fanarttv', 'lastfm']
+# Note that SOURCES_ALL is redundant (and presently unused). However, we keep
+# it around nn order not break plugins that "register" (a.k.a. monkey-patch)
+# their own fetchart sources.
+SOURCES_ALL = [
+ "filesystem",
+ "coverart",
+ "itunes",
+ "amazon",
+ "albumart",
+ "wikipedia",
+ "google",
+ "fanarttv",
+ "lastfm",
+ "spotify",
+]
ART_SOURCES = {
- 'filesystem': FileSystem,
- 'coverart': CoverArtArchive,
- 'itunes': ITunesStore,
- 'albumart': AlbumArtOrg,
- 'amazon': Amazon,
- 'wikipedia': Wikipedia,
- 'google': GoogleImages,
- 'fanarttv': FanartTV,
- 'lastfm': LastFM,
+ "filesystem": FileSystem,
+ "coverart": CoverArtArchive,
+ "itunes": ITunesStore,
+ "albumart": AlbumArtOrg,
+ "amazon": Amazon,
+ "wikipedia": Wikipedia,
+ "google": GoogleImages,
+ "fanarttv": FanartTV,
+ "lastfm": LastFM,
+ "spotify": Spotify,
+ "cover_art_url": CoverArtUrl,
}
SOURCE_NAMES = {v: k for k, v in ART_SOURCES.items()}
@@ -925,114 +1192,127 @@ class FetchArtPlugin(plugins.BeetsPlugin, RequestMixin):
# fetching them and placing them in the filesystem.
self.art_candidates = {}
- self.config.add({
- 'auto': True,
- 'minwidth': 0,
- 'maxwidth': 0,
- 'quality': 0,
- 'max_filesize': 0,
- 'enforce_ratio': False,
- 'cautious': False,
- 'cover_names': ['cover', 'front', 'art', 'album', 'folder'],
- 'sources': ['filesystem',
- 'coverart', 'itunes', 'amazon', 'albumart'],
- 'google_key': None,
- 'google_engine': '001442825323518660753:hrh5ch1gjzm',
- 'fanarttv_key': None,
- 'lastfm_key': None,
- 'store_source': False,
- 'high_resolution': False,
- 'deinterlace': False,
- 'cover_format': None,
- })
- self.config['google_key'].redact = True
- self.config['fanarttv_key'].redact = True
- self.config['lastfm_key'].redact = True
+ self.config.add(
+ {
+ "auto": True,
+ "minwidth": 0,
+ "maxwidth": 0,
+ "quality": 0,
+ "max_filesize": 0,
+ "enforce_ratio": False,
+ "cautious": False,
+ "cover_names": ["cover", "front", "art", "album", "folder"],
+ "sources": [
+ "filesystem",
+ "coverart",
+ "itunes",
+ "amazon",
+ "albumart",
+ "cover_art_url",
+ ],
+ "store_source": False,
+ "high_resolution": False,
+ "deinterlace": False,
+ "cover_format": None,
+ }
+ )
+ for source in ART_SOURCES.values():
+ source.add_default_config(self.config)
- self.minwidth = self.config['minwidth'].get(int)
- self.maxwidth = self.config['maxwidth'].get(int)
- self.max_filesize = self.config['max_filesize'].get(int)
- self.quality = self.config['quality'].get(int)
+ self.minwidth = self.config["minwidth"].get(int)
+ self.maxwidth = self.config["maxwidth"].get(int)
+ self.max_filesize = self.config["max_filesize"].get(int)
+ self.quality = self.config["quality"].get(int)
# allow both pixel and percentage-based margin specifications
- self.enforce_ratio = self.config['enforce_ratio'].get(
- confuse.OneOf([bool,
- confuse.String(pattern=self.PAT_PX),
- confuse.String(pattern=self.PAT_PERCENT)]))
+ self.enforce_ratio = self.config["enforce_ratio"].get(
+ confuse.OneOf(
+ [
+ bool,
+ confuse.String(pattern=self.PAT_PX),
+ confuse.String(pattern=self.PAT_PERCENT),
+ ]
+ )
+ )
self.margin_px = None
self.margin_percent = None
- self.deinterlace = self.config['deinterlace'].get(bool)
+ self.deinterlace = self.config["deinterlace"].get(bool)
if type(self.enforce_ratio) is str:
- if self.enforce_ratio[-1] == '%':
+ if self.enforce_ratio[-1] == "%":
self.margin_percent = float(self.enforce_ratio[:-1]) / 100
- elif self.enforce_ratio[-2:] == 'px':
+ elif self.enforce_ratio[-2:] == "px":
self.margin_px = int(self.enforce_ratio[:-2])
else:
# shouldn't happen
raise confuse.ConfigValueError()
self.enforce_ratio = True
- cover_names = self.config['cover_names'].as_str_seq()
+ cover_names = self.config["cover_names"].as_str_seq()
self.cover_names = list(map(util.bytestring_path, cover_names))
- self.cautious = self.config['cautious'].get(bool)
- self.store_source = self.config['store_source'].get(bool)
+ self.cautious = self.config["cautious"].get(bool)
+ self.store_source = self.config["store_source"].get(bool)
- self.src_removed = (config['import']['delete'].get(bool) or
- config['import']['move'].get(bool))
+ self.src_removed = config["import"]["delete"].get(bool) or config[
+ "import"
+ ]["move"].get(bool)
- self.cover_format = self.config['cover_format'].get(
+ self.cover_format = self.config["cover_format"].get(
confuse.Optional(str)
)
- if self.config['auto']:
+ if self.config["auto"]:
# Enable two import hooks when fetching is enabled.
self.import_stages = [self.fetch_art]
- self.register_listener('import_task_files', self.assign_art)
+ self.register_listener("import_task_files", self.assign_art)
- available_sources = list(SOURCES_ALL)
- if not self.config['google_key'].get() and \
- 'google' in available_sources:
- available_sources.remove('google')
- if not self.config['lastfm_key'].get() and \
- 'lastfm' in available_sources:
- available_sources.remove('lastfm')
- available_sources = [(s, c)
- for s in available_sources
- for c in ART_SOURCES[s].VALID_MATCHING_CRITERIA]
+ available_sources = [
+ (s_name, c)
+ for (s_name, s_cls) in ART_SOURCES.items()
+ if s_cls.available(self._log, self.config)
+ for c in s_cls.VALID_MATCHING_CRITERIA
+ ]
sources = plugins.sanitize_pairs(
- self.config['sources'].as_pairs(default_value='*'),
- available_sources)
+ self.config["sources"].as_pairs(default_value="*"),
+ available_sources,
+ )
- if 'remote_priority' in self.config:
+ if "remote_priority" in self.config:
self._log.warning(
- 'The `fetch_art.remote_priority` configuration option has '
- 'been deprecated. Instead, place `filesystem` at the end of '
- 'your `sources` list.')
- if self.config['remote_priority'].get(bool):
+ "The `fetch_art.remote_priority` configuration option has "
+ "been deprecated. Instead, place `filesystem` at the end of "
+ "your `sources` list."
+ )
+ if self.config["remote_priority"].get(bool):
fs = []
others = []
for s, c in sources:
- if s == 'filesystem':
+ if s == "filesystem":
fs.append((s, c))
else:
others.append((s, c))
sources = others + fs
- self.sources = [ART_SOURCES[s](self._log, self.config, match_by=[c])
- for s, c in sources]
+ self.sources = [
+ ART_SOURCES[s](self._log, self.config, match_by=[c])
+ for s, c in sources
+ ]
# Asynchronous; after music is added to the library.
def fetch_art(self, session, task):
"""Find art for the album being imported."""
if task.is_album: # Only fetch art for full albums.
- if task.album.artpath and os.path.isfile(task.album.artpath):
+ if task.album.artpath and os.path.isfile(
+ syspath(task.album.artpath)
+ ):
# Album already has art (probably a re-import); skip it.
return
if task.choice_flag == importer.action.ASIS:
# For as-is imports, don't search Web sources for art.
local = True
- elif task.choice_flag in (importer.action.APPLY,
- importer.action.RETAG):
+ elif task.choice_flag in (
+ importer.action.APPLY,
+ importer.action.RETAG,
+ ):
# Search everywhere for art.
local = False
else:
@@ -1049,8 +1329,8 @@ class FetchArtPlugin(plugins.BeetsPlugin, RequestMixin):
if self.store_source:
# store the source of the chosen artwork in a flexible field
self._log.debug(
- "Storing art_source for {0.albumartist} - {0.album}",
- album)
+ "Storing art_source for {0.albumartist} - {0.album}", album
+ )
album.art_source = SOURCE_NAMES[type(candidate.source)]
album.store()
@@ -1067,21 +1347,29 @@ class FetchArtPlugin(plugins.BeetsPlugin, RequestMixin):
# Manual album art fetching.
def commands(self):
- cmd = ui.Subcommand('fetchart', help='download album art')
+ cmd = ui.Subcommand("fetchart", help="download album art")
cmd.parser.add_option(
- '-f', '--force', dest='force',
- action='store_true', default=False,
- help='re-download art when already present'
+ "-f",
+ "--force",
+ dest="force",
+ action="store_true",
+ default=False,
+ help="re-download art when already present",
)
cmd.parser.add_option(
- '-q', '--quiet', dest='quiet',
- action='store_true', default=False,
- help='quiet mode: do not output albums that already have artwork'
+ "-q",
+ "--quiet",
+ dest="quiet",
+ action="store_true",
+ default=False,
+ help="quiet mode: do not output albums that already have artwork",
)
def func(lib, opts, args):
- self.batch_fetch_art(lib, lib.albums(ui.decargs(args)), opts.force,
- opts.quiet)
+ self.batch_fetch_art(
+ lib, lib.albums(ui.decargs(args)), opts.force, opts.quiet
+ )
+
cmd.func = func
return [cmd]
@@ -1100,7 +1388,7 @@ class FetchArtPlugin(plugins.BeetsPlugin, RequestMixin):
for source in self.sources:
if source.IS_LOCAL or not local_only:
self._log.debug(
- 'trying source {0} for album {1.albumartist} - {1.album}',
+ "trying source {0} for album {1.albumartist} - {1.album}",
SOURCE_NAMES[type(source)],
album,
)
@@ -1111,8 +1399,10 @@ class FetchArtPlugin(plugins.BeetsPlugin, RequestMixin):
if candidate.validate(self):
out = candidate
self._log.debug(
- 'using {0.LOC_STR} image {1}'.format(
- source, util.displayable_path(out.path)))
+ "using {0.LOC_STR} image {1}".format(
+ source, util.displayable_path(out.path)
+ )
+ )
break
# Remove temporary files for invalid candidates.
source.cleanup(candidate)
@@ -1129,11 +1419,16 @@ class FetchArtPlugin(plugins.BeetsPlugin, RequestMixin):
fetchart CLI command.
"""
for album in albums:
- if album.artpath and not force and os.path.isfile(album.artpath):
+ if (
+ album.artpath
+ and not force
+ and os.path.isfile(syspath(album.artpath))
+ ):
if not quiet:
- message = ui.colorize('text_highlight_minor',
- 'has album art')
- self._log.info('{0}: {1}', album, message)
+ message = ui.colorize(
+ "text_highlight_minor", "has album art"
+ )
+ self._log.info("{0}: {1}", album, message)
else:
# In ordinary invocations, look for images on the
# filesystem. When forcing, however, always go to the Web
@@ -1143,7 +1438,7 @@ class FetchArtPlugin(plugins.BeetsPlugin, RequestMixin):
candidate = self.art_for_album(album, local_paths)
if candidate:
self._set_art(album, candidate)
- message = ui.colorize('text_success', 'found album art')
+ message = ui.colorize("text_success", "found album art")
else:
- message = ui.colorize('text_error', 'no art found')
- self._log.info('{0}: {1}', album, message)
+ message = ui.colorize("text_error", "no art found")
+ self._log.info("{0}: {1}", album, message)
diff --git a/lib/beetsplug/filefilter.py b/lib/beetsplug/filefilter.py
index ec8fddb4..5618c1bd 100644
--- a/lib/beetsplug/filefilter.py
+++ b/lib/beetsplug/filefilter.py
@@ -17,38 +17,40 @@
import re
+
from beets import config
-from beets.util import bytestring_path
-from beets.plugins import BeetsPlugin
from beets.importer import SingletonImportTask
+from beets.plugins import BeetsPlugin
+from beets.util import bytestring_path
class FileFilterPlugin(BeetsPlugin):
def __init__(self):
super().__init__()
- self.register_listener('import_task_created',
- self.import_task_created_event)
- self.config.add({
- 'path': '.*'
- })
+ self.register_listener(
+ "import_task_created", self.import_task_created_event
+ )
+ self.config.add({"path": ".*"})
- self.path_album_regex = \
- self.path_singleton_regex = \
- re.compile(bytestring_path(self.config['path'].get()))
+ self.path_album_regex = self.path_singleton_regex = re.compile(
+ bytestring_path(self.config["path"].get())
+ )
- if 'album_path' in self.config:
+ if "album_path" in self.config:
self.path_album_regex = re.compile(
- bytestring_path(self.config['album_path'].get()))
+ bytestring_path(self.config["album_path"].get())
+ )
- if 'singleton_path' in self.config:
+ if "singleton_path" in self.config:
self.path_singleton_regex = re.compile(
- bytestring_path(self.config['singleton_path'].get()))
+ bytestring_path(self.config["singleton_path"].get())
+ )
def import_task_created_event(self, session, task):
if task.items and len(task.items) > 0:
items_to_import = []
for item in task.items:
- if self.file_filter(item['path']):
+ if self.file_filter(item["path"]):
items_to_import.append(item)
if len(items_to_import) > 0:
task.items = items_to_import
@@ -58,7 +60,7 @@ class FileFilterPlugin(BeetsPlugin):
return []
elif isinstance(task, SingletonImportTask):
- if not self.file_filter(task.item['path']):
+ if not self.file_filter(task.item["path"]):
return []
# If not filtered, return the original task unchanged.
@@ -68,10 +70,9 @@ class FileFilterPlugin(BeetsPlugin):
"""Checks if the configured regular expressions allow the import
of the file given in full_path.
"""
- import_config = dict(config['import'])
+ import_config = dict(config["import"])
full_path = bytestring_path(full_path)
- if 'singletons' not in import_config or not import_config[
- 'singletons']:
+ if "singletons" not in import_config or not import_config["singletons"]:
# Album
return self.path_album_regex.match(full_path) is not None
else:
diff --git a/lib/beetsplug/fish.py b/lib/beetsplug/fish.py
index 21fd67f6..71ac8574 100644
--- a/lib/beetsplug/fish.py
+++ b/lib/beetsplug/fish.py
@@ -23,17 +23,19 @@ by default but can be added via the `-e` / `--extravalues` flag. For example:
"""
-from beets.plugins import BeetsPlugin
-from beets import library, ui
-from beets.ui import commands
-from operator import attrgetter
import os
+from operator import attrgetter
+
+from beets import library, ui
+from beets.plugins import BeetsPlugin
+from beets.ui import commands
+
BL_NEED2 = """complete -c beet -n '__fish_beet_needs_command' {} {}\n"""
BL_USE3 = """complete -c beet -n '__fish_beet_using_command {}' {} {}\n"""
BL_SUBS = """complete -c beet -n '__fish_at_level {} ""' {} {}\n"""
BL_EXTRA3 = """complete -c beet -n '__fish_beet_use_extra {}' {} {}\n"""
-HEAD = '''
+HEAD = """
function __fish_beet_needs_command
set cmd (commandline -opc)
if test (count $cmd) -eq 1
@@ -62,25 +64,35 @@ function __fish_beet_use_extra
end
return 1
end
-'''
+"""
class FishPlugin(BeetsPlugin):
-
def commands(self):
- cmd = ui.Subcommand('fish', help='generate Fish shell tab completions')
+ cmd = ui.Subcommand("fish", help="generate Fish shell tab completions")
cmd.func = self.run
- cmd.parser.add_option('-f', '--noFields', action='store_true',
- default=False,
- help='omit album/track field completions')
cmd.parser.add_option(
- '-e',
- '--extravalues',
- action='append',
- type='choice',
- choices=library.Item.all_keys() +
- library.Album.all_keys(),
- help='include specified field *values* in completions')
+ "-f",
+ "--noFields",
+ action="store_true",
+ default=False,
+ help="omit album/track field completions",
+ )
+ cmd.parser.add_option(
+ "-e",
+ "--extravalues",
+ action="append",
+ type="choice",
+ choices=library.Item.all_keys() + library.Album.all_keys(),
+ help="include specified field *values* in completions",
+ )
+ cmd.parser.add_option(
+ "-o",
+ "--output",
+ default="~/.config/fish/completions/beet.fish",
+ help="where to save the script. default: "
+ "~/.config/fish/completions",
+ )
return [cmd]
def run(self, lib, opts, args):
@@ -89,22 +101,20 @@ class FishPlugin(BeetsPlugin):
# If specified, also collect the values for these fields.
# Make a giant string of all the above, formatted in a way that
# allows Fish to do tab completion for the `beet` command.
- home_dir = os.path.expanduser("~")
- completion_dir = os.path.join(home_dir, '.config/fish/completions')
- try:
- os.makedirs(completion_dir)
- except OSError:
- if not os.path.isdir(completion_dir):
- raise
- completion_file_path = os.path.join(completion_dir, 'beet.fish')
+
+ completion_file_path = os.path.expanduser(opts.output)
+ completion_dir = os.path.dirname(completion_file_path)
+
+ if completion_dir != "":
+ os.makedirs(completion_dir, exist_ok=True)
+
nobasicfields = opts.noFields # Do not complete for album/track fields
extravalues = opts.extravalues # e.g., Also complete artists names
beetcmds = sorted(
- (commands.default_commands +
- commands.plugins.commands()),
- key=attrgetter('name'))
- fields = sorted(set(
- library.Album.all_keys() + library.Item.all_keys()))
+ (commands.default_commands + commands.plugins.commands()),
+ key=attrgetter("name"),
+ )
+ fields = sorted(set(library.Album.all_keys() + library.Item.all_keys()))
# Collect commands, their aliases, and their help text
cmd_names_help = []
for cmd in beetcmds:
@@ -115,19 +125,26 @@ class FishPlugin(BeetsPlugin):
# Concatenate the string
totstring = HEAD + "\n"
totstring += get_cmds_list([name[0] for name in cmd_names_help])
- totstring += '' if nobasicfields else get_standard_fields(fields)
- totstring += get_extravalues(lib, extravalues) if extravalues else ''
- totstring += "\n" + "# ====== {} =====".format(
- "setup basic beet completion") + "\n" * 2
+ totstring += "" if nobasicfields else get_standard_fields(fields)
+ totstring += get_extravalues(lib, extravalues) if extravalues else ""
+ totstring += (
+ "\n"
+ + "# ====== {} =====".format("setup basic beet completion")
+ + "\n" * 2
+ )
totstring += get_basic_beet_options()
- totstring += "\n" + "# ====== {} =====".format(
- "setup field completion for subcommands") + "\n"
- totstring += get_subcommands(
- cmd_names_help, nobasicfields, extravalues)
+ totstring += (
+ "\n"
+ + "# ====== {} =====".format(
+ "setup field completion for subcommands"
+ )
+ + "\n"
+ )
+ totstring += get_subcommands(cmd_names_help, nobasicfields, extravalues)
# Set up completion for all the command options
totstring += get_all_commands(beetcmds)
- with open(completion_file_path, 'w') as fish_file:
+ with open(completion_file_path, "w") as fish_file:
fish_file.write(totstring)
@@ -140,32 +157,31 @@ def _escape(name):
def get_cmds_list(cmds_names):
# Make a list of all Beets core & plugin commands
- substr = ''
- substr += (
- "set CMDS " + " ".join(cmds_names) + ("\n" * 2)
- )
+ substr = ""
+ substr += "set CMDS " + " ".join(cmds_names) + ("\n" * 2)
return substr
def get_standard_fields(fields):
# Make a list of album/track fields and append with ':'
fields = (field + ":" for field in fields)
- substr = ''
- substr += (
- "set FIELDS " + " ".join(fields) + ("\n" * 2)
- )
+ substr = ""
+ substr += "set FIELDS " + " ".join(fields) + ("\n" * 2)
return substr
def get_extravalues(lib, extravalues):
# Make a list of all values from an album/track field.
# 'beet ls albumartist: ' yields completions for ABBA, Beatles, etc.
- word = ''
+ word = ""
values_set = get_set_of_values_for_field(lib, extravalues)
for fld in extravalues:
- extraname = fld.upper() + 'S'
+ extraname = fld.upper() + "S"
word += (
- "set " + extraname + " " + " ".join(sorted(values_set[fld]))
+ "set "
+ + extraname
+ + " "
+ + " ".join(sorted(values_set[fld]))
+ ("\n" * 2)
)
return word
@@ -184,21 +200,24 @@ def get_set_of_values_for_field(lib, fields):
def get_basic_beet_options():
word = (
- BL_NEED2.format("-l format-item",
- "-f -d 'print with custom format'") +
- BL_NEED2.format("-l format-album",
- "-f -d 'print with custom format'") +
- BL_NEED2.format("-s l -l library",
- "-f -r -d 'library database file to use'") +
- BL_NEED2.format("-s d -l directory",
- "-f -r -d 'destination music directory'") +
- BL_NEED2.format("-s v -l verbose",
- "-f -d 'print debugging information'") +
-
- BL_NEED2.format("-s c -l config",
- "-f -r -d 'path to configuration file'") +
- BL_NEED2.format("-s h -l help",
- "-f -d 'print this help message and exit'"))
+ BL_NEED2.format("-l format-item", "-f -d 'print with custom format'")
+ + BL_NEED2.format("-l format-album", "-f -d 'print with custom format'")
+ + BL_NEED2.format(
+ "-s l -l library", "-f -r -d 'library database file to use'"
+ )
+ + BL_NEED2.format(
+ "-s d -l directory", "-f -r -d 'destination music directory'"
+ )
+ + BL_NEED2.format(
+ "-s v -l verbose", "-f -d 'print debugging information'"
+ )
+ + BL_NEED2.format(
+ "-s c -l config", "-f -r -d 'path to configuration file'"
+ )
+ + BL_NEED2.format(
+ "-s h -l help", "-f -d 'print this help message and exit'"
+ )
+ )
return word
@@ -208,27 +227,35 @@ def get_subcommands(cmd_name_and_help, nobasicfields, extravalues):
for cmdname, cmdhelp in cmd_name_and_help:
cmdname = _escape(cmdname)
- word += "\n" + "# ------ {} -------".format(
- "fieldsetups for " + cmdname) + "\n"
word += (
- BL_NEED2.format(
- ("-a " + cmdname),
- ("-f " + "-d " + wrap(clean_whitespace(cmdhelp)))))
+ "\n"
+ + "# ------ {} -------".format("fieldsetups for " + cmdname)
+ + "\n"
+ )
+ word += BL_NEED2.format(
+ ("-a " + cmdname), ("-f " + "-d " + wrap(clean_whitespace(cmdhelp)))
+ )
if nobasicfields is False:
- word += (
- BL_USE3.format(
- cmdname,
- ("-a " + wrap("$FIELDS")),
- ("-f " + "-d " + wrap("fieldname"))))
+ word += BL_USE3.format(
+ cmdname,
+ ("-a " + wrap("$FIELDS")),
+ ("-f " + "-d " + wrap("fieldname")),
+ )
if extravalues:
for f in extravalues:
setvar = wrap("$" + f.upper() + "S")
- word += " ".join(BL_EXTRA3.format(
- (cmdname + " " + f + ":"),
- ('-f ' + '-A ' + '-a ' + setvar),
- ('-d ' + wrap(f))).split()) + "\n"
+ word += (
+ " ".join(
+ BL_EXTRA3.format(
+ (cmdname + " " + f + ":"),
+ ("-f " + "-A " + "-a " + setvar),
+ ("-d " + wrap(f)),
+ ).split()
+ )
+ + "\n"
+ )
return word
@@ -242,30 +269,59 @@ def get_all_commands(beetcmds):
name = _escape(name)
word += "\n"
- word += ("\n" * 2) + "# ====== {} =====".format(
- "completions for " + name) + "\n"
+ word += (
+ ("\n" * 2)
+ + "# ====== {} =====".format("completions for " + name)
+ + "\n"
+ )
for option in cmd.parser._get_all_options()[1:]:
- cmd_l = (" -l " + option._long_opts[0].replace('--', '')
- )if option._long_opts else ''
- cmd_s = (" -s " + option._short_opts[0].replace('-', '')
- ) if option._short_opts else ''
- cmd_need_arg = ' -r ' if option.nargs in [1] else ''
- cmd_helpstr = (" -d " + wrap(' '.join(option.help.split()))
- ) if option.help else ''
- cmd_arglist = (' -a ' + wrap(" ".join(option.choices))
- ) if option.choices else ''
+ cmd_l = (
+ (" -l " + option._long_opts[0].replace("--", ""))
+ if option._long_opts
+ else ""
+ )
+ cmd_s = (
+ (" -s " + option._short_opts[0].replace("-", ""))
+ if option._short_opts
+ else ""
+ )
+ cmd_need_arg = " -r " if option.nargs in [1] else ""
+ cmd_helpstr = (
+ (" -d " + wrap(" ".join(option.help.split())))
+ if option.help
+ else ""
+ )
+ cmd_arglist = (
+ (" -a " + wrap(" ".join(option.choices)))
+ if option.choices
+ else ""
+ )
- word += " ".join(BL_USE3.format(
+ word += (
+ " ".join(
+ BL_USE3.format(
+ name,
+ (
+ cmd_need_arg
+ + cmd_s
+ + cmd_l
+ + " -f "
+ + cmd_arglist
+ ),
+ cmd_helpstr,
+ ).split()
+ )
+ + "\n"
+ )
+
+ word = word + " ".join(
+ BL_USE3.format(
name,
- (cmd_need_arg + cmd_s + cmd_l + " -f " + cmd_arglist),
- cmd_helpstr).split()) + "\n"
-
- word = (word + " ".join(BL_USE3.format(
- name,
- ("-s " + "h " + "-l " + "help" + " -f "),
- ('-d ' + wrap("print help") + "\n")
- ).split()))
+ ("-s " + "h " + "-l " + "help" + " -f "),
+ ("-d " + wrap("print help") + "\n"),
+ ).split()
+ )
return word
@@ -276,7 +332,7 @@ def clean_whitespace(word):
def wrap(word):
# Need " or ' around strings but watch out if they're in the string
- sptoken = '\"'
+ sptoken = '"'
if ('"') in word and ("'") in word:
word.replace('"', sptoken)
return '"' + word + '"'
diff --git a/lib/beetsplug/freedesktop.py b/lib/beetsplug/freedesktop.py
index ba4d5879..a9a25279 100644
--- a/lib/beetsplug/freedesktop.py
+++ b/lib/beetsplug/freedesktop.py
@@ -16,20 +16,25 @@
"""
-from beets.plugins import BeetsPlugin
from beets import ui
+from beets.plugins import BeetsPlugin
class FreedesktopPlugin(BeetsPlugin):
def commands(self):
deprecated = ui.Subcommand(
"freedesktop",
- help="Print a message to redirect to thumbnails --dolphin")
+ help="Print a message to redirect to thumbnails --dolphin",
+ )
deprecated.func = self.deprecation_message
return [deprecated]
def deprecation_message(self, lib, opts, args):
- ui.print_("This plugin is deprecated. Its functionality is "
- "superseded by the 'thumbnails' plugin")
- ui.print_("'thumbnails --dolphin' replaces freedesktop. See doc & "
- "changelog for more information")
+ ui.print_(
+ "This plugin is deprecated. Its functionality is "
+ "superseded by the 'thumbnails' plugin"
+ )
+ ui.print_(
+ "'thumbnails --dolphin' replaces freedesktop. See doc & "
+ "changelog for more information"
+ )
diff --git a/lib/beetsplug/fromfilename.py b/lib/beetsplug/fromfilename.py
index 55684a27..103e8290 100644
--- a/lib/beetsplug/fromfilename.py
+++ b/lib/beetsplug/fromfilename.py
@@ -16,35 +16,34 @@
filename.
"""
-from beets import plugins
-from beets.util import displayable_path
import os
import re
+from beets import plugins
+from beets.util import displayable_path
# Filename field extraction patterns.
PATTERNS = [
- # Useful patterns.
- r'^(?P.+)[\-_](?P.+)[\-_](?P.*)$',
- r'^(?P
\s*]*)>', '\n', html)
- return re.sub(r'
\s*
', '\n', html)
+ html = re.sub(r"\s*]*)>", "\n", html)
+ return re.sub(r"
\s*
", "\n", html)
def scrape_lyrics_from_html(html):
"""Scrape lyrics from a URL. If no lyrics can be found, return None
instead.
"""
+
def is_text_notcode(text):
+ if not text:
+ return False
length = len(text)
- return (length > 20 and
- text.count(' ') > length / 25 and
- (text.find('{') == -1 or text.find(';') == -1))
+ return (
+ length > 20
+ and text.count(" ") > length / 25
+ and (text.find("{") == -1 or text.find(";") == -1)
+ )
+
html = _scrape_strip_cruft(html)
html = _scrape_merge_paragraphs(html)
# extract all long text blocks that are not code
- soup = try_parse_html(html, parse_only=SoupStrainer(text=is_text_notcode))
+ soup = try_parse_html(html, parse_only=SoupStrainer(string=is_text_notcode))
if not soup:
return None
@@ -566,55 +658,53 @@ class Google(Backend):
def __init__(self, config, log):
super().__init__(config, log)
- self.api_key = config['google_API_key'].as_str()
- self.engine_id = config['google_engine_ID'].as_str()
+ self.api_key = config["google_API_key"].as_str()
+ self.engine_id = config["google_engine_ID"].as_str()
def is_lyrics(self, text, artist=None):
- """Determine whether the text seems to be valid lyrics.
- """
+ """Determine whether the text seems to be valid lyrics."""
if not text:
return False
bad_triggers_occ = []
- nb_lines = text.count('\n')
+ nb_lines = text.count("\n")
if nb_lines <= 1:
self._log.debug("Ignoring too short lyrics '{0}'", text)
return False
elif nb_lines < 5:
- bad_triggers_occ.append('too_short')
+ bad_triggers_occ.append("too_short")
else:
# Lyrics look legit, remove credits to avoid being penalized
# further down
text = remove_credits(text)
- bad_triggers = ['lyrics', 'copyright', 'property', 'links']
+ bad_triggers = ["lyrics", "copyright", "property", "links"]
if artist:
bad_triggers += [artist]
for item in bad_triggers:
- bad_triggers_occ += [item] * len(re.findall(r'\W%s\W' % item,
- text, re.I))
+ bad_triggers_occ += [item] * len(
+ re.findall(r"\W%s\W" % item, text, re.I)
+ )
if bad_triggers_occ:
- self._log.debug('Bad triggers detected: {0}', bad_triggers_occ)
+ self._log.debug("Bad triggers detected: {0}", bad_triggers_occ)
return len(bad_triggers_occ) < 2
def slugify(self, text):
- """Normalize a string and remove non-alphanumeric characters.
- """
- text = re.sub(r"[-'_\s]", '_', text)
- text = re.sub(r"_+", '_', text).strip('_')
+ """Normalize a string and remove non-alphanumeric characters."""
+ text = re.sub(r"[-'_\s]", "_", text)
+ text = re.sub(r"_+", "_", text).strip("_")
pat = r"([^,\(]*)\((.*?)\)" # Remove content within parentheses
- text = re.sub(pat, r'\g<1>', text).strip()
+ text = re.sub(pat, r"\g<1>", text).strip()
try:
- text = unicodedata.normalize('NFKD', text).encode('ascii',
- 'ignore')
- text = str(re.sub(r'[-\s]+', ' ', text.decode('utf-8')))
+ text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore")
+ text = str(re.sub(r"[-\s]+", " ", text.decode("utf-8")))
except UnicodeDecodeError:
self._log.exception("Failing to normalize '{0}'", text)
return text
- BY_TRANS = ['by', 'par', 'de', 'von']
- LYRICS_TRANS = ['lyrics', 'paroles', 'letras', 'liedtexte']
+ BY_TRANS = ["by", "par", "de", "von"]
+ LYRICS_TRANS = ["lyrics", "paroles", "letras", "liedtexte"]
def is_page_candidate(self, url_link, url_title, title, artist):
"""Return True if the URL title makes it a good candidate to be a
@@ -622,8 +712,9 @@ class Google(Backend):
"""
title = self.slugify(title.lower())
artist = self.slugify(artist.lower())
- sitename = re.search("//([^/]+)/.*",
- self.slugify(url_link.lower())).group(1)
+ sitename = re.search(
+ "//([^/]+)/.*", self.slugify(url_link.lower())
+ ).group(1)
url_title = self.slugify(url_title.lower())
# Check if URL title contains song title (exact match)
@@ -632,42 +723,47 @@ class Google(Backend):
# or try extracting song title from URL title and check if
# they are close enough
- tokens = [by + '_' + artist for by in self.BY_TRANS] + \
- [artist, sitename, sitename.replace('www.', '')] + \
- self.LYRICS_TRANS
+ tokens = (
+ [by + "_" + artist for by in self.BY_TRANS]
+ + [artist, sitename, sitename.replace("www.", "")]
+ + self.LYRICS_TRANS
+ )
tokens = [re.escape(t) for t in tokens]
- song_title = re.sub('(%s)' % '|'.join(tokens), '', url_title)
+ song_title = re.sub("(%s)" % "|".join(tokens), "", url_title)
- song_title = song_title.strip('_|')
- typo_ratio = .9
+ song_title = song_title.strip("_|")
+ typo_ratio = 0.9
ratio = difflib.SequenceMatcher(None, song_title, title).ratio()
return ratio >= typo_ratio
- def fetch(self, artist, title):
+ def fetch(self, artist, title, album=None, length=None):
query = f"{artist} {title}"
- url = 'https://www.googleapis.com/customsearch/v1?key=%s&cx=%s&q=%s' \
- % (self.api_key, self.engine_id,
- urllib.parse.quote(query.encode('utf-8')))
+ url = "https://www.googleapis.com/customsearch/v1?key=%s&cx=%s&q=%s" % (
+ self.api_key,
+ self.engine_id,
+ urllib.parse.quote(query.encode("utf-8")),
+ )
data = self.fetch_url(url)
if not data:
- self._log.debug('google backend returned no data')
+ self._log.debug("google backend returned no data")
return None
try:
data = json.loads(data)
except ValueError as exc:
- self._log.debug('google backend returned malformed JSON: {}', exc)
- if 'error' in data:
- reason = data['error']['errors'][0]['reason']
- self._log.debug('google backend error: {0}', reason)
+ self._log.debug("google backend returned malformed JSON: {}", exc)
+ if "error" in data:
+ reason = data["error"]["errors"][0]["reason"]
+ self._log.debug("google backend error: {0}", reason)
return None
- if 'items' in data.keys():
- for item in data['items']:
- url_link = item['link']
- url_title = item.get('title', '')
- if not self.is_page_candidate(url_link, url_title,
- title, artist):
+ if "items" in data.keys():
+ for item in data["items"]:
+ url_link = item["link"]
+ url_title = item.get("title", "")
+ if not self.is_page_candidate(
+ url_link, url_title, title, artist
+ ):
continue
html = self.fetch_url(url_link)
if not html:
@@ -677,48 +773,53 @@ class Google(Backend):
continue
if self.is_lyrics(lyrics, artist):
- self._log.debug('got lyrics from {0}',
- item['displayLink'])
+ self._log.debug("got lyrics from {0}", item["displayLink"])
return lyrics
return None
class LyricsPlugin(plugins.BeetsPlugin):
- SOURCES = ['google', 'musixmatch', 'genius', 'tekstowo']
+ SOURCES = ["google", "musixmatch", "genius", "tekstowo", "lrclib"]
SOURCE_BACKENDS = {
- 'google': Google,
- 'musixmatch': MusiXmatch,
- 'genius': Genius,
- 'tekstowo': Tekstowo,
+ "google": Google,
+ "musixmatch": MusiXmatch,
+ "genius": Genius,
+ "tekstowo": Tekstowo,
+ "lrclib": LRCLib,
}
def __init__(self):
super().__init__()
self.import_stages = [self.imported]
- self.config.add({
- 'auto': True,
- 'bing_client_secret': None,
- 'bing_lang_from': [],
- 'bing_lang_to': None,
- 'google_API_key': None,
- 'google_engine_ID': '009217259823014548361:lndtuqkycfu',
- 'genius_api_key':
- "Ryq93pUGm8bM6eUWwD_M3NOFFDAtp2yEE7W"
+ self.config.add(
+ {
+ "auto": True,
+ "bing_client_secret": None,
+ "bing_lang_from": [],
+ "bing_lang_to": None,
+ "google_API_key": None,
+ "google_engine_ID": "009217259823014548361:lndtuqkycfu",
+ "genius_api_key": "Ryq93pUGm8bM6eUWwD_M3NOFFDAtp2yEE7W"
"76V-uFL5jks5dNvcGCdarqFjDhP9c",
- 'fallback': None,
- 'force': False,
- 'local': False,
- 'sources': self.SOURCES,
- })
- self.config['bing_client_secret'].redact = True
- self.config['google_API_key'].redact = True
- self.config['google_engine_ID'].redact = True
- self.config['genius_api_key'].redact = True
+ "fallback": None,
+ "force": False,
+ "local": False,
+ "synced": False,
+ # Musixmatch is disabled by default as they are currently blocking
+ # requests with the beets user agent.
+ "sources": [s for s in self.SOURCES if s != "musixmatch"],
+ "dist_thresh": 0.1,
+ }
+ )
+ self.config["bing_client_secret"].redact = True
+ self.config["google_API_key"].redact = True
+ self.config["google_engine_ID"].redact = True
+ self.config["genius_api_key"].redact = True
# State information for the ReST writer.
# First, the current artist we're writing.
- self.artist = 'Unknown artist'
+ self.artist = "Unknown artist"
# The current album: False means no album yet.
self.album = False
# The current rest file content. None means the file is not
@@ -727,41 +828,49 @@ class LyricsPlugin(plugins.BeetsPlugin):
available_sources = list(self.SOURCES)
sources = plugins.sanitize_choices(
- self.config['sources'].as_str_seq(), available_sources)
+ self.config["sources"].as_str_seq(), available_sources
+ )
if not HAS_BEAUTIFUL_SOUP:
sources = self.sanitize_bs_sources(sources)
- if 'google' in sources:
- if not self.config['google_API_key'].get():
+ if "google" in sources:
+ if not self.config["google_API_key"].get():
# We log a *debug* message here because the default
# configuration includes `google`. This way, the source
# is silent by default but can be enabled just by
# setting an API key.
- self._log.debug('Disabling google source: '
- 'no API key configured.')
- sources.remove('google')
+ self._log.debug(
+ "Disabling google source: " "no API key configured."
+ )
+ sources.remove("google")
- self.config['bing_lang_from'] = [
- x.lower() for x in self.config['bing_lang_from'].as_str_seq()]
+ self.config["bing_lang_from"] = [
+ x.lower() for x in self.config["bing_lang_from"].as_str_seq()
+ ]
self.bing_auth_token = None
- if not HAS_LANGDETECT and self.config['bing_client_secret'].get():
- self._log.warning('To use bing translations, you need to '
- 'install the langdetect module. See the '
- 'documentation for further details.')
+ if not HAS_LANGDETECT and self.config["bing_client_secret"].get():
+ self._log.warning(
+ "To use bing translations, you need to "
+ "install the langdetect module. See the "
+ "documentation for further details."
+ )
- self.backends = [self.SOURCE_BACKENDS[source](self.config, self._log)
- for source in sources]
+ self.backends = [
+ self.SOURCE_BACKENDS[source](self.config, self._log)
+ for source in sources
+ ]
def sanitize_bs_sources(self, sources):
enabled_sources = []
for source in sources:
if self.SOURCE_BACKENDS[source].REQUIRES_BS:
- self._log.debug('To use the %s lyrics source, you must '
- 'install the beautifulsoup4 module. See '
- 'the documentation for further details.'
- % source)
+ self._log.debug(
+ "To use the %s lyrics source, you must "
+ "install the beautifulsoup4 module. See "
+ "the documentation for further details." % source
+ )
else:
enabled_sources.append(source)
@@ -769,43 +878,62 @@ class LyricsPlugin(plugins.BeetsPlugin):
def get_bing_access_token(self):
params = {
- 'client_id': 'beets',
- 'client_secret': self.config['bing_client_secret'],
- 'scope': "https://api.microsofttranslator.com",
- 'grant_type': 'client_credentials',
+ "client_id": "beets",
+ "client_secret": self.config["bing_client_secret"],
+ "scope": "https://api.microsofttranslator.com",
+ "grant_type": "client_credentials",
}
- oauth_url = 'https://datamarket.accesscontrol.windows.net/v2/OAuth2-13'
- oauth_token = json.loads(requests.post(
- oauth_url,
- data=urllib.parse.urlencode(params)).content)
- if 'access_token' in oauth_token:
- return "Bearer " + oauth_token['access_token']
+ oauth_url = "https://datamarket.accesscontrol.windows.net/v2/OAuth2-13"
+ oauth_token = json.loads(
+ requests.post(
+ oauth_url,
+ data=urllib.parse.urlencode(params),
+ timeout=10,
+ ).content
+ )
+ if "access_token" in oauth_token:
+ return "Bearer " + oauth_token["access_token"]
else:
- self._log.warning('Could not get Bing Translate API access token.'
- ' Check your "bing_client_secret" password')
+ self._log.warning(
+ "Could not get Bing Translate API access token."
+ ' Check your "bing_client_secret" password'
+ )
def commands(self):
- cmd = ui.Subcommand('lyrics', help='fetch song lyrics')
+ cmd = ui.Subcommand("lyrics", help="fetch song lyrics")
cmd.parser.add_option(
- '-p', '--print', dest='printlyr',
- action='store_true', default=False,
- help='print lyrics to console',
+ "-p",
+ "--print",
+ dest="printlyr",
+ action="store_true",
+ default=False,
+ help="print lyrics to console",
)
cmd.parser.add_option(
- '-r', '--write-rest', dest='writerest',
- action='store', default=None, metavar='dir',
- help='write lyrics to given directory as ReST files',
+ "-r",
+ "--write-rest",
+ dest="writerest",
+ action="store",
+ default=None,
+ metavar="dir",
+ help="write lyrics to given directory as ReST files",
)
cmd.parser.add_option(
- '-f', '--force', dest='force_refetch',
- action='store_true', default=False,
- help='always re-download lyrics',
+ "-f",
+ "--force",
+ dest="force_refetch",
+ action="store_true",
+ default=False,
+ help="always re-download lyrics",
)
cmd.parser.add_option(
- '-l', '--local', dest='local_only',
- action='store_true', default=False,
- help='do not fetch missing lyrics',
+ "-l",
+ "--local",
+ dest="local_only",
+ action="store_true",
+ default=False,
+ help="do not fetch missing lyrics",
)
def func(lib, opts, args):
@@ -816,10 +944,12 @@ class LyricsPlugin(plugins.BeetsPlugin):
self.writerest_indexes(opts.writerest)
items = lib.items(ui.decargs(args))
for item in items:
- if not opts.local_only and not self.config['local']:
+ if not opts.local_only and not self.config["local"]:
self.fetch_item_lyrics(
- lib, item, write,
- opts.force_refetch or self.config['force'],
+ lib,
+ item,
+ write,
+ opts.force_refetch or self.config["force"],
)
if item.lyrics:
if opts.printlyr:
@@ -829,14 +959,21 @@ class LyricsPlugin(plugins.BeetsPlugin):
if opts.writerest and items:
# flush last artist & write to ReST
self.writerest(opts.writerest)
- ui.print_('ReST files generated. to build, use one of:')
- ui.print_(' sphinx-build -b html %s _build/html'
- % opts.writerest)
- ui.print_(' sphinx-build -b epub %s _build/epub'
- % opts.writerest)
- ui.print_((' sphinx-build -b latex %s _build/latex '
- '&& make -C _build/latex all-pdf')
- % opts.writerest)
+ ui.print_("ReST files generated. to build, use one of:")
+ ui.print_(
+ " sphinx-build -b html %s _build/html" % opts.writerest
+ )
+ ui.print_(
+ " sphinx-build -b epub %s _build/epub" % opts.writerest
+ )
+ ui.print_(
+ (
+ " sphinx-build -b latex %s _build/latex "
+ "&& make -C _build/latex all-pdf"
+ )
+ % opts.writerest
+ )
+
cmd.func = func
return [cmd]
@@ -851,29 +988,30 @@ class LyricsPlugin(plugins.BeetsPlugin):
# Write current file and start a new one ~ item.albumartist
self.writerest(directory)
self.artist = item.albumartist.strip()
- self.rest = "%s\n%s\n\n.. contents::\n :local:\n\n" \
- % (self.artist,
- '=' * len(self.artist))
+ self.rest = "%s\n%s\n\n.. contents::\n :local:\n\n" % (
+ self.artist,
+ "=" * len(self.artist),
+ )
if self.album != item.album:
tmpalbum = self.album = item.album.strip()
- if self.album == '':
- tmpalbum = 'Unknown album'
- self.rest += "{}\n{}\n\n".format(tmpalbum, '-' * len(tmpalbum))
+ if self.album == "":
+ tmpalbum = "Unknown album"
+ self.rest += "{}\n{}\n\n".format(tmpalbum, "-" * len(tmpalbum))
title_str = ":index:`%s`" % item.title.strip()
- block = '| ' + item.lyrics.replace('\n', '\n| ')
- self.rest += "{}\n{}\n\n{}\n\n".format(title_str,
- '~' * len(title_str),
- block)
+ block = "| " + item.lyrics.replace("\n", "\n| ")
+ self.rest += "{}\n{}\n\n{}\n\n".format(
+ title_str, "~" * len(title_str), block
+ )
def writerest(self, directory):
- """Write self.rest to a ReST file
- """
+ """Write self.rest to a ReST file"""
if self.rest is not None and self.artist is not None:
- path = os.path.join(directory, 'artists',
- slug(self.artist) + '.rst')
- with open(path, 'wb') as output:
- output.write(self.rest.encode('utf-8'))
+ path = os.path.join(
+ directory, "artists", slug(self.artist) + ".rst"
+ )
+ with open(path, "wb") as output:
+ output.write(self.rest.encode("utf-8"))
def writerest_indexes(self, directory):
"""Write conf.py and index.rst files necessary for Sphinx
@@ -882,59 +1020,65 @@ class LyricsPlugin(plugins.BeetsPlugin):
to operate. We do not overwrite existing files so that
customizations are respected."""
try:
- os.makedirs(os.path.join(directory, 'artists'))
+ os.makedirs(os.path.join(directory, "artists"))
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
- indexfile = os.path.join(directory, 'index.rst')
+ indexfile = os.path.join(directory, "index.rst")
if not os.path.exists(indexfile):
- with open(indexfile, 'w') as output:
+ with open(indexfile, "w") as output:
output.write(REST_INDEX_TEMPLATE)
- conffile = os.path.join(directory, 'conf.py')
+ conffile = os.path.join(directory, "conf.py")
if not os.path.exists(conffile):
- with open(conffile, 'w') as output:
+ with open(conffile, "w") as output:
output.write(REST_CONF_TEMPLATE)
def imported(self, session, task):
- """Import hook for fetching lyrics automatically.
- """
- if self.config['auto']:
+ """Import hook for fetching lyrics automatically."""
+ if self.config["auto"]:
for item in task.imported_items():
- self.fetch_item_lyrics(session.lib, item,
- False, self.config['force'])
+ self.fetch_item_lyrics(
+ session.lib, item, False, self.config["force"]
+ )
def fetch_item_lyrics(self, lib, item, write, force):
"""Fetch and store lyrics for a single item. If ``write``, then the
- lyrics will also be written to the file itself.
+ lyrics will also be written to the file itself.
"""
# Skip if the item already has lyrics.
if not force and item.lyrics:
- self._log.info('lyrics already present: {0}', item)
+ self._log.info("lyrics already present: {0}", item)
return
lyrics = None
+ album = item.album
+ length = round(item.length)
for artist, titles in search_pairs(item):
- lyrics = [self.get_lyrics(artist, title) for title in titles]
+ lyrics = [
+ self.get_lyrics(artist, title, album=album, length=length)
+ for title in titles
+ ]
if any(lyrics):
break
lyrics = "\n\n---\n\n".join([l for l in lyrics if l])
if lyrics:
- self._log.info('fetched lyrics: {0}', item)
- if HAS_LANGDETECT and self.config['bing_client_secret'].get():
+ self._log.info("fetched lyrics: {0}", item)
+ if HAS_LANGDETECT and self.config["bing_client_secret"].get():
lang_from = langdetect.detect(lyrics)
- if self.config['bing_lang_to'].get() != lang_from and (
- not self.config['bing_lang_from'] or (
- lang_from in self.config[
- 'bing_lang_from'].as_str_seq())):
+ if self.config["bing_lang_to"].get() != lang_from and (
+ not self.config["bing_lang_from"]
+ or (lang_from in self.config["bing_lang_from"].as_str_seq())
+ ):
lyrics = self.append_translation(
- lyrics, self.config['bing_lang_to'])
+ lyrics, self.config["bing_lang_to"]
+ )
else:
- self._log.info('lyrics not found: {0}', item)
- fallback = self.config['fallback'].get()
+ self._log.info("lyrics not found: {0}", item)
+ fallback = self.config["fallback"].get()
if fallback:
lyrics = fallback
else:
@@ -944,15 +1088,16 @@ class LyricsPlugin(plugins.BeetsPlugin):
item.try_write()
item.store()
- def get_lyrics(self, artist, title):
+ def get_lyrics(self, artist, title, album=None, length=None):
"""Fetch lyrics, trying each source in turn. Return a string or
None if no lyrics were found.
"""
for backend in self.backends:
- lyrics = backend.fetch(artist, title)
+ lyrics = backend.fetch(artist, title, album=album, length=length)
if lyrics:
- self._log.debug('got lyrics from backend: {0}',
- backend.__class__.__name__)
+ self._log.debug(
+ "got lyrics from backend: {0}", backend.__class__.__name__
+ )
return _scrape_strip_cruft(lyrics, True)
def append_translation(self, text, to_lang):
@@ -962,23 +1107,30 @@ class LyricsPlugin(plugins.BeetsPlugin):
self.bing_auth_token = self.get_bing_access_token()
if self.bing_auth_token:
# Extract unique lines to limit API request size per song
- text_lines = set(text.split('\n'))
- url = ('https://api.microsofttranslator.com/v2/Http.svc/'
- 'Translate?text=%s&to=%s' % ('|'.join(text_lines), to_lang))
- r = requests.get(url,
- headers={"Authorization ": self.bing_auth_token})
+ text_lines = set(text.split("\n"))
+ url = (
+ "https://api.microsofttranslator.com/v2/Http.svc/"
+ "Translate?text=%s&to=%s" % ("|".join(text_lines), to_lang)
+ )
+ r = requests.get(
+ url,
+ headers={"Authorization ": self.bing_auth_token},
+ timeout=10,
+ )
if r.status_code != 200:
- self._log.debug('translation API error {}: {}', r.status_code,
- r.text)
- if 'token has expired' in r.text:
+ self._log.debug(
+ "translation API error {}: {}", r.status_code, r.text
+ )
+ if "token has expired" in r.text:
self.bing_auth_token = None
return self.append_translation(text, to_lang)
return text
lines_translated = ElementTree.fromstring(
- r.text.encode('utf-8')).text
+ r.text.encode("utf-8")
+ ).text
# Use a translation mapping dict to build resulting lyrics
- translations = dict(zip(text_lines, lines_translated.split('|')))
- result = ''
- for line in text.split('\n'):
- result += '{} / {}\n'.format(line, translations[line])
+ translations = dict(zip(text_lines, lines_translated.split("|")))
+ result = ""
+ for line in text.split("\n"):
+ result += "{} / {}\n".format(line, translations[line])
return result
diff --git a/lib/beetsplug/mbcollection.py b/lib/beetsplug/mbcollection.py
index f4a0d161..1c010bf5 100644
--- a/lib/beetsplug/mbcollection.py
+++ b/lib/beetsplug/mbcollection.py
@@ -13,30 +13,29 @@
# included in all copies or substantial portions of the Software.
-from beets.plugins import BeetsPlugin
-from beets.ui import Subcommand
-from beets import ui
-from beets import config
+import re
+
import musicbrainzngs
-import re
+from beets import config, ui
+from beets.plugins import BeetsPlugin
+from beets.ui import Subcommand
SUBMISSION_CHUNK_SIZE = 200
FETCH_CHUNK_SIZE = 100
-UUID_REGEX = r'^[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}$'
+UUID_REGEX = r"^[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}$"
def mb_call(func, *args, **kwargs):
- """Call a MusicBrainz API function and catch exceptions.
- """
+ """Call a MusicBrainz API function and catch exceptions."""
try:
return func(*args, **kwargs)
except musicbrainzngs.AuthenticationError:
- raise ui.UserError('authentication with MusicBrainz failed')
+ raise ui.UserError("authentication with MusicBrainz failed")
except (musicbrainzngs.ResponseError, musicbrainzngs.NetworkError) as exc:
- raise ui.UserError(f'MusicBrainz API error: {exc}')
+ raise ui.UserError(f"MusicBrainz API error: {exc}")
except musicbrainzngs.UsageError:
- raise ui.UserError('MusicBrainz credentials missing')
+ raise ui.UserError("MusicBrainz credentials missing")
def submit_albums(collection_id, release_ids):
@@ -44,45 +43,45 @@ def submit_albums(collection_id, release_ids):
requests are made if there are many release IDs to submit.
"""
for i in range(0, len(release_ids), SUBMISSION_CHUNK_SIZE):
- chunk = release_ids[i:i + SUBMISSION_CHUNK_SIZE]
- mb_call(
- musicbrainzngs.add_releases_to_collection,
- collection_id, chunk
- )
+ chunk = release_ids[i : i + SUBMISSION_CHUNK_SIZE]
+ mb_call(musicbrainzngs.add_releases_to_collection, collection_id, chunk)
class MusicBrainzCollectionPlugin(BeetsPlugin):
def __init__(self):
super().__init__()
- config['musicbrainz']['pass'].redact = True
+ config["musicbrainz"]["pass"].redact = True
musicbrainzngs.auth(
- config['musicbrainz']['user'].as_str(),
- config['musicbrainz']['pass'].as_str(),
+ config["musicbrainz"]["user"].as_str(),
+ config["musicbrainz"]["pass"].as_str(),
)
- self.config.add({
- 'auto': False,
- 'collection': '',
- 'remove': False,
- })
- if self.config['auto']:
+ self.config.add(
+ {
+ "auto": False,
+ "collection": "",
+ "remove": False,
+ }
+ )
+ if self.config["auto"]:
self.import_stages = [self.imported]
def _get_collection(self):
collections = mb_call(musicbrainzngs.get_collections)
- if not collections['collection-list']:
- raise ui.UserError('no collections exist for user')
+ if not collections["collection-list"]:
+ raise ui.UserError("no collections exist for user")
# Get all collection IDs, avoiding event collections
- collection_ids = [x['id'] for x in collections['collection-list']]
+ collection_ids = [x["id"] for x in collections["collection-list"]]
if not collection_ids:
- raise ui.UserError('No collection found.')
+ raise ui.UserError("No collection found.")
# Check that the collection exists so we can present a nice error
- collection = self.config['collection'].as_str()
+ collection = self.config["collection"].as_str()
if collection:
if collection not in collection_ids:
- raise ui.UserError('invalid collection ID: {}'
- .format(collection))
+ raise ui.UserError(
+ "invalid collection ID: {}".format(collection)
+ )
return collection
# No specified collection. Just return the first collection ID
@@ -94,9 +93,9 @@ class MusicBrainzCollectionPlugin(BeetsPlugin):
musicbrainzngs.get_releases_in_collection,
id,
limit=FETCH_CHUNK_SIZE,
- offset=offset
- )['collection']
- return [x['id'] for x in res['release-list']], res['release-count']
+ offset=offset,
+ )["collection"]
+ return [x["id"] for x in res["release-list"]], res["release-count"]
offset = 0
albums_in_collection, release_count = _fetch(offset)
@@ -107,13 +106,15 @@ class MusicBrainzCollectionPlugin(BeetsPlugin):
return albums_in_collection
def commands(self):
- mbupdate = Subcommand('mbupdate',
- help='Update MusicBrainz collection')
- mbupdate.parser.add_option('-r', '--remove',
- action='store_true',
- default=None,
- dest='remove',
- help='Remove albums not in beets library')
+ mbupdate = Subcommand("mbupdate", help="Update MusicBrainz collection")
+ mbupdate.parser.add_option(
+ "-r",
+ "--remove",
+ action="store_true",
+ default=None,
+ dest="remove",
+ help="Remove albums not in beets library",
+ )
mbupdate.func = self.update_collection
return [mbupdate]
@@ -122,26 +123,25 @@ class MusicBrainzCollectionPlugin(BeetsPlugin):
albums_in_collection = self._get_albums_in_collection(collection_id)
remove_me = list(set(albums_in_collection) - lib_ids)
for i in range(0, len(remove_me), FETCH_CHUNK_SIZE):
- chunk = remove_me[i:i + FETCH_CHUNK_SIZE]
+ chunk = remove_me[i : i + FETCH_CHUNK_SIZE]
mb_call(
musicbrainzngs.remove_releases_from_collection,
- collection_id, chunk
+ collection_id,
+ chunk,
)
def update_collection(self, lib, opts, args):
self.config.set_args(opts)
- remove_missing = self.config['remove'].get(bool)
+ remove_missing = self.config["remove"].get(bool)
self.update_album_list(lib, lib.albums(), remove_missing)
def imported(self, session, task):
- """Add each imported album to the collection.
- """
+ """Add each imported album to the collection."""
if task.is_album:
self.update_album_list(session.lib, [task.album])
def update_album_list(self, lib, album_list, remove_missing=False):
- """Update the MusicBrainz collection from a list of Beets albums
- """
+ """Update the MusicBrainz collection from a list of Beets albums"""
collection_id = self._get_collection()
# Get a list of all the album IDs.
@@ -152,13 +152,11 @@ class MusicBrainzCollectionPlugin(BeetsPlugin):
if re.match(UUID_REGEX, aid):
album_ids.append(aid)
else:
- self._log.info('skipping invalid MBID: {0}', aid)
+ self._log.info("skipping invalid MBID: {0}", aid)
# Submit to MusicBrainz.
- self._log.info(
- 'Updating MusicBrainz collection {0}...', collection_id
- )
+ self._log.info("Updating MusicBrainz collection {0}...", collection_id)
submit_albums(collection_id, album_ids)
if remove_missing:
self.remove_missing(collection_id, lib.albums())
- self._log.info('...MusicBrainz collection updated.')
+ self._log.info("...MusicBrainz collection updated.")
diff --git a/lib/beetsplug/mbsubmit.py b/lib/beetsplug/mbsubmit.py
index 3ede0125..d215e616 100644
--- a/lib/beetsplug/mbsubmit.py
+++ b/lib/beetsplug/mbsubmit.py
@@ -21,10 +21,13 @@ implemented by MusicBrainz yet.
[1] https://wiki.musicbrainz.org/History:How_To_Parse_Track_Listings
"""
+import subprocess
+from beets import ui
from beets.autotag import Recommendation
from beets.plugins import BeetsPlugin
from beets.ui.commands import PromptChoice
+from beets.util import displayable_path
from beetsplug.info import print_data
@@ -32,26 +35,65 @@ class MBSubmitPlugin(BeetsPlugin):
def __init__(self):
super().__init__()
- self.config.add({
- 'format': '$track. $title - $artist ($length)',
- 'threshold': 'medium',
- })
+ self.config.add(
+ {
+ "format": "$track. $title - $artist ($length)",
+ "threshold": "medium",
+ "picard_path": "picard",
+ }
+ )
# Validate and store threshold.
- self.threshold = self.config['threshold'].as_choice({
- 'none': Recommendation.none,
- 'low': Recommendation.low,
- 'medium': Recommendation.medium,
- 'strong': Recommendation.strong
- })
+ self.threshold = self.config["threshold"].as_choice(
+ {
+ "none": Recommendation.none,
+ "low": Recommendation.low,
+ "medium": Recommendation.medium,
+ "strong": Recommendation.strong,
+ }
+ )
- self.register_listener('before_choose_candidate',
- self.before_choose_candidate_event)
+ self.register_listener(
+ "before_choose_candidate", self.before_choose_candidate_event
+ )
def before_choose_candidate_event(self, session, task):
if task.rec <= self.threshold:
- return [PromptChoice('p', 'Print tracks', self.print_tracks)]
+ return [
+ PromptChoice("p", "Print tracks", self.print_tracks),
+ PromptChoice("o", "Open files with Picard", self.picard),
+ ]
+
+ def picard(self, session, task):
+ paths = []
+ for p in task.paths:
+ paths.append(displayable_path(p))
+ try:
+ picard_path = self.config["picard_path"].as_str()
+ subprocess.Popen([picard_path] + paths)
+ self._log.info("launched picard from\n{}", picard_path)
+ except OSError as exc:
+ self._log.error(f"Could not open picard, got error:\n{exc}")
def print_tracks(self, session, task):
for i in sorted(task.items, key=lambda i: i.track):
- print_data(None, i, self.config['format'].as_str())
+ print_data(None, i, self.config["format"].as_str())
+
+ def commands(self):
+ """Add beet UI commands for mbsubmit."""
+ mbsubmit_cmd = ui.Subcommand(
+ "mbsubmit", help="Submit Tracks to MusicBrainz"
+ )
+
+ def func(lib, opts, args):
+ items = lib.items(ui.decargs(args))
+ self._mbsubmit(items)
+
+ mbsubmit_cmd.func = func
+
+ return [mbsubmit_cmd]
+
+ def _mbsubmit(self, items):
+ """Print track information to be submitted to MusicBrainz."""
+ for i in sorted(items, key=lambda i: i.track):
+ print_data(None, i, self.config["format"].as_str())
diff --git a/lib/beetsplug/mbsync.py b/lib/beetsplug/mbsync.py
index 26778830..0e63a6f2 100644
--- a/lib/beetsplug/mbsync.py
+++ b/lib/beetsplug/mbsync.py
@@ -15,12 +15,12 @@
"""Update library's tags using MusicBrainz.
"""
-from beets.plugins import BeetsPlugin, apply_item_changes
-from beets import autotag, library, ui, util
-from beets.autotag import hooks
+import re
from collections import defaultdict
-import re
+from beets import autotag, library, ui, util
+from beets.autotag import hooks
+from beets.plugins import BeetsPlugin, apply_item_changes
MBID_REGEX = r"(\d|\w){8}-(\d|\w){4}-(\d|\w){4}-(\d|\w){4}-(\d|\w){12}"
@@ -30,28 +30,41 @@ class MBSyncPlugin(BeetsPlugin):
super().__init__()
def commands(self):
- cmd = ui.Subcommand('mbsync',
- help='update metadata from musicbrainz')
+ cmd = ui.Subcommand("mbsync", help="update metadata from musicbrainz")
cmd.parser.add_option(
- '-p', '--pretend', action='store_true',
- help='show all changes but do nothing')
+ "-p",
+ "--pretend",
+ action="store_true",
+ help="show all changes but do nothing",
+ )
cmd.parser.add_option(
- '-m', '--move', action='store_true', dest='move',
- help="move files in the library directory")
+ "-m",
+ "--move",
+ action="store_true",
+ dest="move",
+ help="move files in the library directory",
+ )
cmd.parser.add_option(
- '-M', '--nomove', action='store_false', dest='move',
- help="don't move files in library")
+ "-M",
+ "--nomove",
+ action="store_false",
+ dest="move",
+ help="don't move files in library",
+ )
cmd.parser.add_option(
- '-W', '--nowrite', action='store_false',
- default=None, dest='write',
- help="don't write updated metadata to files")
+ "-W",
+ "--nowrite",
+ action="store_false",
+ default=None,
+ dest="write",
+ help="don't write updated metadata to files",
+ )
cmd.parser.add_format_option()
cmd.func = self.func
return [cmd]
def func(self, lib, opts, args):
- """Command handler for the mbsync function.
- """
+ """Command handler for the mbsync function."""
move = ui.should_move(opts.move)
pretend = opts.pretend
write = ui.should_write(opts.write)
@@ -64,25 +77,30 @@ class MBSyncPlugin(BeetsPlugin):
"""Retrieve and apply info from the autotagger for items matched by
query.
"""
- for item in lib.items(query + ['singleton:true']):
+ for item in lib.items(query + ["singleton:true"]):
item_formatted = format(item)
if not item.mb_trackid:
- self._log.info('Skipping singleton with no mb_trackid: {0}',
- item_formatted)
+ self._log.info(
+ "Skipping singleton with no mb_trackid: {0}", item_formatted
+ )
continue
# Do we have a valid MusicBrainz track ID?
if not re.match(MBID_REGEX, item.mb_trackid):
- self._log.info('Skipping singleton with invalid mb_trackid:' +
- ' {0}', item_formatted)
+ self._log.info(
+ "Skipping singleton with invalid mb_trackid:" + " {0}",
+ item_formatted,
+ )
continue
# Get the MusicBrainz recording info.
track_info = hooks.track_for_mbid(item.mb_trackid)
if not track_info:
- self._log.info('Recording ID not found: {0} for track {0}',
- item.mb_trackid,
- item_formatted)
+ self._log.info(
+ "Recording ID not found: {0} for track {0}",
+ item.mb_trackid,
+ item_formatted,
+ )
continue
# Apply.
@@ -98,24 +116,29 @@ class MBSyncPlugin(BeetsPlugin):
for a in lib.albums(query):
album_formatted = format(a)
if not a.mb_albumid:
- self._log.info('Skipping album with no mb_albumid: {0}',
- album_formatted)
+ self._log.info(
+ "Skipping album with no mb_albumid: {0}", album_formatted
+ )
continue
items = list(a.items())
# Do we have a valid MusicBrainz album ID?
if not re.match(MBID_REGEX, a.mb_albumid):
- self._log.info('Skipping album with invalid mb_albumid: {0}',
- album_formatted)
+ self._log.info(
+ "Skipping album with invalid mb_albumid: {0}",
+ album_formatted,
+ )
continue
# Get the MusicBrainz album information.
album_info = hooks.album_for_mbid(a.mb_albumid)
if not album_info:
- self._log.info('Release ID {0} not found for album {1}',
- a.mb_albumid,
- album_formatted)
+ self._log.info(
+ "Release ID {0} not found for album {1}",
+ a.mb_albumid,
+ album_formatted,
+ )
continue
# Map release track and recording MBIDs to their information.
@@ -132,8 +155,10 @@ class MBSyncPlugin(BeetsPlugin):
# work for albums that have missing or extra tracks.
mapping = {}
for item in items:
- if item.mb_releasetrackid and \
- item.mb_releasetrackid in releasetrack_index:
+ if (
+ item.mb_releasetrackid
+ and item.mb_releasetrackid in releasetrack_index
+ ):
mapping[item] = releasetrack_index[item.mb_releasetrackid]
else:
candidates = track_index[item.mb_trackid]
@@ -143,13 +168,15 @@ class MBSyncPlugin(BeetsPlugin):
# If there are multiple copies of a recording, they are
# disambiguated using their disc and track number.
for c in candidates:
- if (c.medium_index == item.track and
- c.medium == item.disc):
+ if (
+ c.medium_index == item.track
+ and c.medium == item.disc
+ ):
mapping[item] = c
break
# Apply.
- self._log.debug('applying changes to {}', album_formatted)
+ self._log.debug("applying changes to {}", album_formatted)
with lib.transaction():
autotag.apply_metadata(album_info, mapping)
changed = False
@@ -174,5 +201,5 @@ class MBSyncPlugin(BeetsPlugin):
# Move album art (and any inconsistent items).
if move and lib.directory in util.ancestry(items[0].path):
- self._log.debug('moving album {0}', album_formatted)
+ self._log.debug("moving album {0}", album_formatted)
a.move()
diff --git a/lib/beetsplug/metasync/__init__.py b/lib/beetsplug/metasync/__init__.py
index 361071fb..d17071b5 100644
--- a/lib/beetsplug/metasync/__init__.py
+++ b/lib/beetsplug/metasync/__init__.py
@@ -16,20 +16,20 @@
"""
-from abc import abstractmethod, ABCMeta
+from abc import ABCMeta, abstractmethod
from importlib import import_module
from confuse import ConfigValueError
+
from beets import ui
from beets.plugins import BeetsPlugin
-
-METASYNC_MODULE = 'beetsplug.metasync'
+METASYNC_MODULE = "beetsplug.metasync"
# Dictionary to map the MODULE and the CLASS NAME of meta sources
SOURCES = {
- 'amarok': 'Amarok',
- 'itunes': 'Itunes',
+ "amarok": "Amarok",
+ "itunes": "Itunes",
}
@@ -45,13 +45,13 @@ class MetaSource(metaclass=ABCMeta):
def load_meta_sources():
- """ Returns a dictionary of all the MetaSources
+ """Returns a dictionary of all the MetaSources
E.g., {'itunes': Itunes} with isinstance(Itunes, MetaSource) true
"""
meta_sources = {}
for module_path, class_name in SOURCES.items():
- module = import_module(METASYNC_MODULE + '.' + module_path)
+ module = import_module(METASYNC_MODULE + "." + module_path)
meta_sources[class_name.lower()] = getattr(module, class_name)
return meta_sources
@@ -61,8 +61,7 @@ META_SOURCES = load_meta_sources()
def load_item_types():
- """ Returns a dictionary containing the item_types of all the MetaSources
- """
+ """Returns a dictionary containing the item_types of all the MetaSources"""
item_types = {}
for meta_source in META_SOURCES.values():
item_types.update(meta_source.item_types)
@@ -70,42 +69,50 @@ def load_item_types():
class MetaSyncPlugin(BeetsPlugin):
-
item_types = load_item_types()
def __init__(self):
super().__init__()
def commands(self):
- cmd = ui.Subcommand('metasync',
- help='update metadata from music player libraries')
- cmd.parser.add_option('-p', '--pretend', action='store_true',
- help='show all changes but do nothing')
- cmd.parser.add_option('-s', '--source', default=[],
- action='append', dest='sources',
- help='comma-separated list of sources to sync')
+ cmd = ui.Subcommand(
+ "metasync", help="update metadata from music player libraries"
+ )
+ cmd.parser.add_option(
+ "-p",
+ "--pretend",
+ action="store_true",
+ help="show all changes but do nothing",
+ )
+ cmd.parser.add_option(
+ "-s",
+ "--source",
+ default=[],
+ action="append",
+ dest="sources",
+ help="comma-separated list of sources to sync",
+ )
cmd.parser.add_format_option()
cmd.func = self.func
return [cmd]
def func(self, lib, opts, args):
- """Command handler for the metasync function.
- """
+ """Command handler for the metasync function."""
pretend = opts.pretend
query = ui.decargs(args)
sources = []
for source in opts.sources:
- sources.extend(source.split(','))
+ sources.extend(source.split(","))
- sources = sources or self.config['source'].as_str_seq()
+ sources = sources or self.config["source"].as_str_seq()
meta_source_instances = {}
items = lib.items(query)
# Avoid needlessly instantiating meta sources (can be expensive)
if not items:
- self._log.info('No items found matching query')
+ self._log.info("No items found matching query")
return
# Instantiate the meta sources
@@ -113,18 +120,19 @@ class MetaSyncPlugin(BeetsPlugin):
try:
cls = META_SOURCES[player]
except KeyError:
- self._log.error('Unknown metadata source \'{}\''.format(
- player))
+ self._log.error("Unknown metadata source '{}'".format(player))
try:
meta_source_instances[player] = cls(self.config, self._log)
except (ImportError, ConfigValueError) as e:
- self._log.error('Failed to instantiate metadata source '
- '\'{}\': {}'.format(player, e))
+ self._log.error(
+ "Failed to instantiate metadata source "
+ "'{}': {}".format(player, e)
+ )
# Avoid needlessly iterating over items
if not meta_source_instances:
- self._log.error('No valid metadata sources found')
+ self._log.error("No valid metadata sources found")
return
# Sync the items with all of the meta sources
diff --git a/lib/beetsplug/metasync/amarok.py b/lib/beetsplug/metasync/amarok.py
index a49eecc3..195cd878 100644
--- a/lib/beetsplug/metasync/amarok.py
+++ b/lib/beetsplug/metasync/amarok.py
@@ -16,35 +16,35 @@
"""
-from os.path import basename
from datetime import datetime
+from os.path import basename
from time import mktime
from xml.sax.saxutils import quoteattr
-from beets.util import displayable_path
from beets.dbcore import types
from beets.library import DateType
+from beets.util import displayable_path
from beetsplug.metasync import MetaSource
def import_dbus():
try:
- return __import__('dbus')
+ return __import__("dbus")
except ImportError:
return None
+
dbus = import_dbus()
class Amarok(MetaSource):
-
item_types = {
- 'amarok_rating': types.INTEGER,
- 'amarok_score': types.FLOAT,
- 'amarok_uid': types.STRING,
- 'amarok_playcount': types.INTEGER,
- 'amarok_firstplayed': DateType(),
- 'amarok_lastplayed': DateType(),
+ "amarok_rating": types.INTEGER,
+ "amarok_score": types.FLOAT,
+ "amarok_uid": types.STRING,
+ "amarok_playcount": types.INTEGER,
+ "amarok_firstplayed": DateType(),
+ "amarok_lastplayed": DateType(),
}
query_xml = ' \
@@ -57,10 +57,11 @@ class Amarok(MetaSource):
super().__init__(config, log)
if not dbus:
- raise ImportError('failed to import dbus')
+ raise ImportError("failed to import dbus")
- self.collection = \
- dbus.SessionBus().get_object('org.kde.amarok', '/Collection')
+ self.collection = dbus.SessionBus().get_object(
+ "org.kde.amarok", "/Collection"
+ )
def sync_from_source(self, item):
path = displayable_path(item.path)
@@ -73,35 +74,36 @@ class Amarok(MetaSource):
self.query_xml % quoteattr(basename(path))
)
for result in results:
- if result['xesam:url'] != path:
+ if result["xesam:url"] != path:
continue
- item.amarok_rating = result['xesam:userRating']
- item.amarok_score = result['xesam:autoRating']
- item.amarok_playcount = result['xesam:useCount']
- item.amarok_uid = \
- result['xesam:id'].replace('amarok-sqltrackuid://', '')
+ item.amarok_rating = result["xesam:userRating"]
+ item.amarok_score = result["xesam:autoRating"]
+ item.amarok_playcount = result["xesam:useCount"]
+ item.amarok_uid = result["xesam:id"].replace(
+ "amarok-sqltrackuid://", ""
+ )
- if result['xesam:firstUsed'][0][0] != 0:
+ if result["xesam:firstUsed"][0][0] != 0:
# These dates are stored as timestamps in amarok's db, but
# exposed over dbus as fixed integers in the current timezone.
first_played = datetime(
- result['xesam:firstUsed'][0][0],
- result['xesam:firstUsed'][0][1],
- result['xesam:firstUsed'][0][2],
- result['xesam:firstUsed'][1][0],
- result['xesam:firstUsed'][1][1],
- result['xesam:firstUsed'][1][2]
+ result["xesam:firstUsed"][0][0],
+ result["xesam:firstUsed"][0][1],
+ result["xesam:firstUsed"][0][2],
+ result["xesam:firstUsed"][1][0],
+ result["xesam:firstUsed"][1][1],
+ result["xesam:firstUsed"][1][2],
)
- if result['xesam:lastUsed'][0][0] != 0:
+ if result["xesam:lastUsed"][0][0] != 0:
last_played = datetime(
- result['xesam:lastUsed'][0][0],
- result['xesam:lastUsed'][0][1],
- result['xesam:lastUsed'][0][2],
- result['xesam:lastUsed'][1][0],
- result['xesam:lastUsed'][1][1],
- result['xesam:lastUsed'][1][2]
+ result["xesam:lastUsed"][0][0],
+ result["xesam:lastUsed"][0][1],
+ result["xesam:lastUsed"][0][2],
+ result["xesam:lastUsed"][1][0],
+ result["xesam:lastUsed"][1][1],
+ result["xesam:lastUsed"][1][2],
)
else:
last_played = first_played
diff --git a/lib/beetsplug/metasync/itunes.py b/lib/beetsplug/metasync/itunes.py
index e50a5713..15cbd7bb 100644
--- a/lib/beetsplug/metasync/itunes.py
+++ b/lib/beetsplug/metasync/itunes.py
@@ -16,31 +16,32 @@
"""
-from contextlib import contextmanager
import os
+import plistlib
import shutil
import tempfile
-import plistlib
-
-from urllib.parse import urlparse, unquote
+from contextlib import contextmanager
from time import mktime
+from urllib.parse import unquote, urlparse
+
+from confuse import ConfigValueError
from beets import util
from beets.dbcore import types
from beets.library import DateType
-from confuse import ConfigValueError
+from beets.util import bytestring_path, syspath
from beetsplug.metasync import MetaSource
@contextmanager
def create_temporary_copy(path):
- temp_dir = tempfile.mkdtemp()
- temp_path = os.path.join(temp_dir, 'temp_itunes_lib')
- shutil.copyfile(path, temp_path)
+ temp_dir = bytestring_path(tempfile.mkdtemp())
+ temp_path = os.path.join(temp_dir, b"temp_itunes_lib")
+ shutil.copyfile(syspath(path), syspath(temp_path))
try:
yield temp_path
finally:
- shutil.rmtree(temp_dir)
+ shutil.rmtree(syspath(temp_dir))
def _norm_itunes_path(path):
@@ -54,72 +55,74 @@ def _norm_itunes_path(path):
# which is unwanted in the case of Windows systems.
# E.g., '\\G:\\Music\\bar' needs to be stripped to 'G:\\Music\\bar'
- return util.bytestring_path(os.path.normpath(
- unquote(urlparse(path).path)).lstrip('\\')).lower()
+ return util.bytestring_path(
+ os.path.normpath(unquote(urlparse(path).path)).lstrip("\\")
+ ).lower()
class Itunes(MetaSource):
-
item_types = {
- 'itunes_rating': types.INTEGER, # 0..100 scale
- 'itunes_playcount': types.INTEGER,
- 'itunes_skipcount': types.INTEGER,
- 'itunes_lastplayed': DateType(),
- 'itunes_lastskipped': DateType(),
- 'itunes_dateadded': DateType(),
+ "itunes_rating": types.INTEGER, # 0..100 scale
+ "itunes_playcount": types.INTEGER,
+ "itunes_skipcount": types.INTEGER,
+ "itunes_lastplayed": DateType(),
+ "itunes_lastskipped": DateType(),
+ "itunes_dateadded": DateType(),
}
def __init__(self, config, log):
super().__init__(config, log)
- config.add({'itunes': {
- 'library': '~/Music/iTunes/iTunes Library.xml'
- }})
+ config.add({"itunes": {"library": "~/Music/iTunes/iTunes Library.xml"}})
# Load the iTunes library, which has to be the .xml one (not the .itl)
- library_path = config['itunes']['library'].as_filename()
+ library_path = config["itunes"]["library"].as_filename()
try:
- self._log.debug(
- f'loading iTunes library from {library_path}')
+ self._log.debug(f"loading iTunes library from {library_path}")
with create_temporary_copy(library_path) as library_copy:
- with open(library_copy, 'rb') as library_copy_f:
+ with open(library_copy, "rb") as library_copy_f:
raw_library = plistlib.load(library_copy_f)
except OSError as e:
- raise ConfigValueError('invalid iTunes library: ' + e.strerror)
+ raise ConfigValueError("invalid iTunes library: " + e.strerror)
except Exception:
# It's likely the user configured their '.itl' library (<> xml)
- if os.path.splitext(library_path)[1].lower() != '.xml':
- hint = ': please ensure that the configured path' \
- ' points to the .XML library'
+ if os.path.splitext(library_path)[1].lower() != ".xml":
+ hint = (
+ ": please ensure that the configured path"
+ " points to the .XML library"
+ )
else:
- hint = ''
- raise ConfigValueError('invalid iTunes library' + hint)
+ hint = ""
+ raise ConfigValueError("invalid iTunes library" + hint)
# Make the iTunes library queryable using the path
- self.collection = {_norm_itunes_path(track['Location']): track
- for track in raw_library['Tracks'].values()
- if 'Location' in track}
+ self.collection = {
+ _norm_itunes_path(track["Location"]): track
+ for track in raw_library["Tracks"].values()
+ if "Location" in track
+ }
def sync_from_source(self, item):
result = self.collection.get(util.bytestring_path(item.path).lower())
if not result:
- self._log.warning(f'no iTunes match found for {item}')
+ self._log.warning(f"no iTunes match found for {item}")
return
- item.itunes_rating = result.get('Rating')
- item.itunes_playcount = result.get('Play Count')
- item.itunes_skipcount = result.get('Skip Count')
+ item.itunes_rating = result.get("Rating")
+ item.itunes_playcount = result.get("Play Count")
+ item.itunes_skipcount = result.get("Skip Count")
- if result.get('Play Date UTC'):
+ if result.get("Play Date UTC"):
item.itunes_lastplayed = mktime(
- result.get('Play Date UTC').timetuple())
+ result.get("Play Date UTC").timetuple()
+ )
- if result.get('Skip Date'):
+ if result.get("Skip Date"):
item.itunes_lastskipped = mktime(
- result.get('Skip Date').timetuple())
+ result.get("Skip Date").timetuple()
+ )
- if result.get('Date Added'):
- item.itunes_dateadded = mktime(
- result.get('Date Added').timetuple())
+ if result.get("Date Added"):
+ item.itunes_dateadded = mktime(result.get("Date Added").timetuple())
diff --git a/lib/beetsplug/missing.py b/lib/beetsplug/missing.py
index 771978c1..2e37fde7 100644
--- a/lib/beetsplug/missing.py
+++ b/lib/beetsplug/missing.py
@@ -16,21 +16,21 @@
"""List missing tracks.
"""
-import musicbrainzngs
-
-from musicbrainzngs.musicbrainz import MusicBrainzError
from collections import defaultdict
+
+import musicbrainzngs
+from musicbrainzngs.musicbrainz import MusicBrainzError
+
+from beets import config
from beets.autotag import hooks
+from beets.dbcore import types
from beets.library import Item
from beets.plugins import BeetsPlugin
-from beets.ui import decargs, print_, Subcommand
-from beets import config
-from beets.dbcore import types
+from beets.ui import Subcommand, decargs, print_
def _missing_count(album):
- """Return number of missing items in `album`.
- """
+ """Return number of missing items in `album`."""
return (album.albumtotal or 0) - len(album.items())
@@ -45,80 +45,93 @@ def _item(track_info, album_info, album_id):
t = track_info
a = album_info
- return Item(**{
- 'album_id': album_id,
- 'album': a.album,
- 'albumartist': a.artist,
- 'albumartist_credit': a.artist_credit,
- 'albumartist_sort': a.artist_sort,
- 'albumdisambig': a.albumdisambig,
- 'albumstatus': a.albumstatus,
- 'albumtype': a.albumtype,
- 'artist': t.artist,
- 'artist_credit': t.artist_credit,
- 'artist_sort': t.artist_sort,
- 'asin': a.asin,
- 'catalognum': a.catalognum,
- 'comp': a.va,
- 'country': a.country,
- 'day': a.day,
- 'disc': t.medium,
- 'disctitle': t.disctitle,
- 'disctotal': a.mediums,
- 'label': a.label,
- 'language': a.language,
- 'length': t.length,
- 'mb_albumid': a.album_id,
- 'mb_artistid': t.artist_id,
- 'mb_releasegroupid': a.releasegroup_id,
- 'mb_trackid': t.track_id,
- 'media': t.media,
- 'month': a.month,
- 'script': a.script,
- 'title': t.title,
- 'track': t.index,
- 'tracktotal': len(a.tracks),
- 'year': a.year,
- })
+ return Item(
+ **{
+ "album_id": album_id,
+ "album": a.album,
+ "albumartist": a.artist,
+ "albumartist_credit": a.artist_credit,
+ "albumartist_sort": a.artist_sort,
+ "albumdisambig": a.albumdisambig,
+ "albumstatus": a.albumstatus,
+ "albumtype": a.albumtype,
+ "artist": t.artist,
+ "artist_credit": t.artist_credit,
+ "artist_sort": t.artist_sort,
+ "asin": a.asin,
+ "catalognum": a.catalognum,
+ "comp": a.va,
+ "country": a.country,
+ "day": a.day,
+ "disc": t.medium,
+ "disctitle": t.disctitle,
+ "disctotal": a.mediums,
+ "label": a.label,
+ "language": a.language,
+ "length": t.length,
+ "mb_albumid": a.album_id,
+ "mb_artistid": t.artist_id,
+ "mb_releasegroupid": a.releasegroup_id,
+ "mb_trackid": t.track_id,
+ "media": t.media,
+ "month": a.month,
+ "script": a.script,
+ "title": t.title,
+ "track": t.index,
+ "tracktotal": len(a.tracks),
+ "year": a.year,
+ }
+ )
class MissingPlugin(BeetsPlugin):
- """List missing tracks
- """
+ """List missing tracks"""
album_types = {
- 'missing': types.INTEGER,
+ "missing": types.INTEGER,
}
def __init__(self):
super().__init__()
- self.config.add({
- 'count': False,
- 'total': False,
- 'album': False,
- })
+ self.config.add(
+ {
+ "count": False,
+ "total": False,
+ "album": False,
+ }
+ )
- self.album_template_fields['missing'] = _missing_count
+ self.album_template_fields["missing"] = _missing_count
- self._command = Subcommand('missing',
- help=__doc__,
- aliases=['miss'])
+ self._command = Subcommand("missing", help=__doc__, aliases=["miss"])
self._command.parser.add_option(
- '-c', '--count', dest='count', action='store_true',
- help='count missing tracks per album')
+ "-c",
+ "--count",
+ dest="count",
+ action="store_true",
+ help="count missing tracks per album",
+ )
self._command.parser.add_option(
- '-t', '--total', dest='total', action='store_true',
- help='count total of missing tracks')
+ "-t",
+ "--total",
+ dest="total",
+ action="store_true",
+ help="count total of missing tracks",
+ )
self._command.parser.add_option(
- '-a', '--album', dest='album', action='store_true',
- help='show missing albums for artist instead of tracks')
+ "-a",
+ "--album",
+ dest="album",
+ action="store_true",
+ help="show missing albums for artist instead of tracks",
+ )
self._command.parser.add_format_option()
def commands(self):
def _miss(lib, opts, args):
self.config.set_args(opts)
- albms = self.config['album'].get()
+ albms = self.config["album"].get()
helper = self._missing_albums if albms else self._missing_tracks
helper(lib, decargs(args))
@@ -132,9 +145,9 @@ class MissingPlugin(BeetsPlugin):
"""
albums = lib.albums(query)
- count = self.config['count'].get()
- total = self.config['total'].get()
- fmt = config['format_album' if count else 'format_item'].get()
+ count = self.config["count"].get()
+ total = self.config["total"].get()
+ fmt = config["format_album" if count else "format_item"].get()
if total:
print(sum([_missing_count(a) for a in albums]))
@@ -142,7 +155,7 @@ class MissingPlugin(BeetsPlugin):
# Default format string for count mode.
if count:
- fmt += ': $missing'
+ fmt += ": $missing"
for album in albums:
if count:
@@ -157,13 +170,13 @@ class MissingPlugin(BeetsPlugin):
"""Print a listing of albums missing from each artist in the library
matching query.
"""
- total = self.config['total'].get()
+ total = self.config["total"].get()
albums = lib.albums(query)
# build dict mapping artist to list of their albums in library
albums_by_artist = defaultdict(list)
for alb in albums:
- artist = (alb['albumartist'], alb['mb_albumartistid'])
+ artist = (alb["albumartist"], alb["mb_albumartistid"])
albums_by_artist[artist].append(alb)
total_missing = 0
@@ -171,20 +184,24 @@ class MissingPlugin(BeetsPlugin):
# build dict mapping artist to list of all albums
for artist, albums in albums_by_artist.items():
if artist[1] is None or artist[1] == "":
- albs_no_mbid = ["'" + a['album'] + "'" for a in albums]
+ albs_no_mbid = ["'" + a["album"] + "'" for a in albums]
self._log.info(
"No musicbrainz ID for artist '{}' found in album(s) {}; "
- "skipping", artist[0], ", ".join(albs_no_mbid)
+ "skipping",
+ artist[0],
+ ", ".join(albs_no_mbid),
)
continue
try:
resp = musicbrainzngs.browse_release_groups(artist=artist[1])
- release_groups = resp['release-group-list']
+ release_groups = resp["release-group-list"]
except MusicBrainzError as err:
self._log.info(
"Couldn't fetch info for artist '{}' ({}) - '{}'",
- artist[0], artist[1], err
+ artist[0],
+ artist[1],
+ err,
)
continue
@@ -193,7 +210,7 @@ class MissingPlugin(BeetsPlugin):
for rg in release_groups:
missing.append(rg)
for alb in albums:
- if alb['mb_releasegroupid'] == rg['id']:
+ if alb["mb_releasegroupid"] == rg["id"]:
missing.remove(rg)
present.append(rg)
break
@@ -202,7 +219,7 @@ class MissingPlugin(BeetsPlugin):
if total:
continue
- missing_titles = {rg['title'] for rg in missing}
+ missing_titles = {rg["title"] for rg in missing}
for release_title in missing_titles:
print_("{} - {}".format(artist[0], release_title))
@@ -211,16 +228,18 @@ class MissingPlugin(BeetsPlugin):
print(total_missing)
def _missing(self, album):
- """Query MusicBrainz to determine items missing from `album`.
- """
+ """Query MusicBrainz to determine items missing from `album`."""
item_mbids = [x.mb_trackid for x in album.items()]
if len(list(album.items())) < album.albumtotal:
# fetch missing items
# TODO: Implement caching that without breaking other stuff
album_info = hooks.album_for_mbid(album.mb_albumid)
- for track_info in getattr(album_info, 'tracks', []):
+ for track_info in getattr(album_info, "tracks", []):
if track_info.track_id not in item_mbids:
item = _item(track_info, album_info, album.id)
- self._log.debug('track {0} in album {1}',
- track_info.track_id, album_info.album_id)
+ self._log.debug(
+ "track {0} in album {1}",
+ track_info.track_id,
+ album_info.album_id,
+ )
yield item
diff --git a/lib/beetsplug/mpdstats.py b/lib/beetsplug/mpdstats.py
index 96291cf4..6d4c269d 100644
--- a/lib/beetsplug/mpdstats.py
+++ b/lib/beetsplug/mpdstats.py
@@ -13,16 +13,14 @@
# included in all copies or substantial portions of the Software.
-import mpd
-import time
import os
+import time
-from beets import ui
-from beets import config
-from beets import plugins
-from beets import library
-from beets.util import displayable_path
+import mpd
+
+from beets import config, library, plugins, ui
from beets.dbcore import types
+from beets.util import displayable_path
# If we lose the connection, how many times do we want to retry and how
# much time should we wait between retries?
@@ -30,60 +28,55 @@ RETRIES = 10
RETRY_INTERVAL = 5
-mpd_config = config['mpd']
+mpd_config = config["mpd"]
def is_url(path):
- """Try to determine if the path is an URL.
- """
+ """Try to determine if the path is an URL."""
if isinstance(path, bytes): # if it's bytes, then it's a path
return False
- return path.split('://', 1)[0] in ['http', 'https']
+ return path.split("://", 1)[0] in ["http", "https"]
class MPDClientWrapper:
def __init__(self, log):
self._log = log
- self.music_directory = mpd_config['music_directory'].as_str()
- self.strip_path = mpd_config['strip_path'].as_str()
+ self.music_directory = mpd_config["music_directory"].as_str()
+ self.strip_path = mpd_config["strip_path"].as_str()
# Ensure strip_path end with '/'
- if not self.strip_path.endswith('/'):
- self.strip_path += '/'
+ if not self.strip_path.endswith("/"):
+ self.strip_path += "/"
- self._log.debug('music_directory: {0}', self.music_directory)
- self._log.debug('strip_path: {0}', self.strip_path)
+ self._log.debug("music_directory: {0}", self.music_directory)
+ self._log.debug("strip_path: {0}", self.strip_path)
self.client = mpd.MPDClient()
def connect(self):
- """Connect to the MPD.
- """
- host = mpd_config['host'].as_str()
- port = mpd_config['port'].get(int)
+ """Connect to the MPD."""
+ host = mpd_config["host"].as_str()
+ port = mpd_config["port"].get(int)
- if host[0] in ['/', '~']:
+ if host[0] in ["/", "~"]:
host = os.path.expanduser(host)
- self._log.info('connecting to {0}:{1}', host, port)
+ self._log.info("connecting to {0}:{1}", host, port)
try:
self.client.connect(host, port)
except OSError as e:
- raise ui.UserError(f'could not connect to MPD: {e}')
+ raise ui.UserError(f"could not connect to MPD: {e}")
- password = mpd_config['password'].as_str()
+ password = mpd_config["password"].as_str()
if password:
try:
self.client.password(password)
except mpd.CommandError as e:
- raise ui.UserError(
- f'could not authenticate to MPD: {e}'
- )
+ raise ui.UserError(f"could not authenticate to MPD: {e}")
def disconnect(self):
- """Disconnect from the MPD.
- """
+ """Disconnect from the MPD."""
self.client.close()
self.client.disconnect()
@@ -94,11 +87,11 @@ class MPDClientWrapper:
try:
return getattr(self.client, command)()
except (OSError, mpd.ConnectionError) as err:
- self._log.error('{0}', err)
+ self._log.error("{0}", err)
if retries <= 0:
# if we exited without breaking, we couldn't reconnect in time :(
- raise ui.UserError('communication with MPD server failed')
+ raise ui.UserError("communication with MPD server failed")
time.sleep(RETRY_INTERVAL)
@@ -119,28 +112,27 @@ class MPDClientWrapper:
`strip_path` defaults to ''.
"""
result = None
- entry = self.get('currentsong')
- if 'file' in entry:
- if not is_url(entry['file']):
- file = entry['file']
+ entry = self.get("currentsong")
+ if "file" in entry:
+ if not is_url(entry["file"]):
+ file = entry["file"]
if file.startswith(self.strip_path):
- file = file[len(self.strip_path):]
+ file = file[len(self.strip_path) :]
result = os.path.join(self.music_directory, file)
else:
- result = entry['file']
- self._log.debug('returning: {0}', result)
- return result, entry.get('id')
+ result = entry["file"]
+ self._log.debug("returning: {0}", result)
+ return result, entry.get("id")
def status(self):
- """Return the current status of the MPD.
- """
- return self.get('status')
+ """Return the current status of the MPD."""
+ return self.get("status")
def events(self):
"""Return list of events. This may block a long time while waiting for
an answer from MPD.
"""
- return self.get('idle')
+ return self.get("idle")
class MPDStats:
@@ -148,8 +140,8 @@ class MPDStats:
self.lib = lib
self._log = log
- self.do_rating = mpd_config['rating'].get(bool)
- self.rating_mix = mpd_config['rating_mix'].get(float)
+ self.do_rating = mpd_config["rating"].get(bool)
+ self.rating_mix = mpd_config["rating_mix"].get(float)
self.time_threshold = 10.0 # TODO: maybe add config option?
self.now_playing = None
@@ -160,22 +152,20 @@ class MPDStats:
old rating and the fact if it was skipped or not.
"""
if skipped:
- rolling = (rating - rating / 2.0)
+ rolling = rating - rating / 2.0
else:
- rolling = (rating + (1.0 - rating) / 2.0)
+ rolling = rating + (1.0 - rating) / 2.0
stable = (play_count + 1.0) / (play_count + skip_count + 2.0)
- return (self.rating_mix * stable +
- (1.0 - self.rating_mix) * rolling)
+ return self.rating_mix * stable + (1.0 - self.rating_mix) * rolling
def get_item(self, path):
- """Return the beets item related to path.
- """
- query = library.PathQuery('path', path)
+ """Return the beets item related to path."""
+ query = library.PathQuery("path", path)
item = self.lib.items(query).get()
if item:
return item
else:
- self._log.info('item not found: {0}', displayable_path(path))
+ self._log.info("item not found: {0}", displayable_path(path))
def update_item(self, item, attribute, value=None, increment=None):
"""Update the beets item. Set attribute to value or increment the value
@@ -193,10 +183,12 @@ class MPDStats:
item[attribute] = value
item.store()
- self._log.debug('updated: {0} = {1} [{2}]',
- attribute,
- item[attribute],
- displayable_path(item.path))
+ self._log.debug(
+ "updated: {0} = {1} [{2}]",
+ attribute,
+ item[attribute],
+ displayable_path(item.path),
+ )
def update_rating(self, item, skipped):
"""Update the rating for a beets item. The `item` can either be a
@@ -207,12 +199,13 @@ class MPDStats:
item.load()
rating = self.rating(
- int(item.get('play_count', 0)),
- int(item.get('skip_count', 0)),
- float(item.get('rating', 0.5)),
- skipped)
+ int(item.get("play_count", 0)),
+ int(item.get("skip_count", 0)),
+ float(item.get("rating", 0.5)),
+ skipped,
+ )
- self.update_item(item, 'rating', rating)
+ self.update_item(item, "rating", rating)
def handle_song_change(self, song):
"""Determine if a song was skipped or not and update its attributes.
@@ -222,7 +215,7 @@ class MPDStats:
Returns whether the change was manual (skipped previous song or not)
"""
- diff = abs(song['remaining'] - (time.time() - song['started']))
+ diff = abs(song["remaining"] - (time.time() - song["started"]))
skipped = diff >= self.time_threshold
@@ -232,89 +225,89 @@ class MPDStats:
self.handle_played(song)
if self.do_rating:
- self.update_rating(song['beets_item'], skipped)
+ self.update_rating(song["beets_item"], skipped)
return skipped
def handle_played(self, song):
- """Updates the play count of a song.
- """
- self.update_item(song['beets_item'], 'play_count', increment=1)
- self._log.info('played {0}', displayable_path(song['path']))
+ """Updates the play count of a song."""
+ self.update_item(song["beets_item"], "play_count", increment=1)
+ self._log.info("played {0}", displayable_path(song["path"]))
def handle_skipped(self, song):
- """Updates the skip count of a song.
- """
- self.update_item(song['beets_item'], 'skip_count', increment=1)
- self._log.info('skipped {0}', displayable_path(song['path']))
+ """Updates the skip count of a song."""
+ self.update_item(song["beets_item"], "skip_count", increment=1)
+ self._log.info("skipped {0}", displayable_path(song["path"]))
def on_stop(self, status):
- self._log.info('stop')
+ self._log.info("stop")
# if the current song stays the same it means that we stopped on the
# current track and should not record a skip.
- if self.now_playing and self.now_playing['id'] != status.get('songid'):
+ if self.now_playing and self.now_playing["id"] != status.get("songid"):
self.handle_song_change(self.now_playing)
self.now_playing = None
def on_pause(self, status):
- self._log.info('pause')
+ self._log.info("pause")
self.now_playing = None
def on_play(self, status):
-
path, songid = self.mpd.currentsong()
if not path:
return
- played, duration = map(int, status['time'].split(':', 1))
+ played, duration = map(int, status["time"].split(":", 1))
remaining = duration - played
if self.now_playing:
- if self.now_playing['path'] != path:
+ if self.now_playing["path"] != path:
self.handle_song_change(self.now_playing)
else:
# In case we got mpd play event with same song playing
# multiple times,
# assume low diff means redundant second play event
# after natural song start.
- diff = abs(time.time() - self.now_playing['started'])
+ diff = abs(time.time() - self.now_playing["started"])
if diff <= self.time_threshold:
return
- if self.now_playing['path'] == path and played == 0:
+ if self.now_playing["path"] == path and played == 0:
self.handle_song_change(self.now_playing)
if is_url(path):
- self._log.info('playing stream {0}', displayable_path(path))
+ self._log.info("playing stream {0}", displayable_path(path))
self.now_playing = None
return
- self._log.info('playing {0}', displayable_path(path))
+ self._log.info("playing {0}", displayable_path(path))
self.now_playing = {
- 'started': time.time(),
- 'remaining': remaining,
- 'path': path,
- 'id': songid,
- 'beets_item': self.get_item(path),
+ "started": time.time(),
+ "remaining": remaining,
+ "path": path,
+ "id": songid,
+ "beets_item": self.get_item(path),
}
- self.update_item(self.now_playing['beets_item'],
- 'last_played', value=int(time.time()))
+ self.update_item(
+ self.now_playing["beets_item"],
+ "last_played",
+ value=int(time.time()),
+ )
def run(self):
self.mpd.connect()
- events = ['player']
+ events = ["player"]
while True:
- if 'player' in events:
+ if "player" in events:
status = self.mpd.status()
- handler = getattr(self, 'on_' + status['state'], None)
+ handler = getattr(self, "on_" + status["state"], None)
if handler:
handler(status)
@@ -325,51 +318,61 @@ class MPDStats:
class MPDStatsPlugin(plugins.BeetsPlugin):
-
item_types = {
- 'play_count': types.INTEGER,
- 'skip_count': types.INTEGER,
- 'last_played': library.DateType(),
- 'rating': types.FLOAT,
+ "play_count": types.INTEGER,
+ "skip_count": types.INTEGER,
+ "last_played": library.DateType(),
+ "rating": types.FLOAT,
}
def __init__(self):
super().__init__()
- mpd_config.add({
- 'music_directory': config['directory'].as_filename(),
- 'strip_path': '',
- 'rating': True,
- 'rating_mix': 0.75,
- 'host': os.environ.get('MPD_HOST', 'localhost'),
- 'port': int(os.environ.get('MPD_PORT', 6600)),
- 'password': '',
- })
- mpd_config['password'].redact = True
+ mpd_config.add(
+ {
+ "music_directory": config["directory"].as_filename(),
+ "strip_path": "",
+ "rating": True,
+ "rating_mix": 0.75,
+ "host": os.environ.get("MPD_HOST", "localhost"),
+ "port": int(os.environ.get("MPD_PORT", 6600)),
+ "password": "",
+ }
+ )
+ mpd_config["password"].redact = True
def commands(self):
cmd = ui.Subcommand(
- 'mpdstats',
- help='run a MPD client to gather play statistics')
+ "mpdstats", help="run a MPD client to gather play statistics"
+ )
cmd.parser.add_option(
- '--host', dest='host', type='string',
- help='set the hostname of the server to connect to')
+ "--host",
+ dest="host",
+ type="string",
+ help="set the hostname of the server to connect to",
+ )
cmd.parser.add_option(
- '--port', dest='port', type='int',
- help='set the port of the MPD server to connect to')
+ "--port",
+ dest="port",
+ type="int",
+ help="set the port of the MPD server to connect to",
+ )
cmd.parser.add_option(
- '--password', dest='password', type='string',
- help='set the password of the MPD server to connect to')
+ "--password",
+ dest="password",
+ type="string",
+ help="set the password of the MPD server to connect to",
+ )
def func(lib, opts, args):
mpd_config.set_args(opts)
# Overrides for MPD settings.
if opts.host:
- mpd_config['host'] = opts.host.decode('utf-8')
+ mpd_config["host"] = opts.host.decode("utf-8")
if opts.port:
- mpd_config['host'] = int(opts.port)
+ mpd_config["host"] = int(opts.port)
if opts.password:
- mpd_config['password'] = opts.password.decode('utf-8')
+ mpd_config["password"] = opts.password.decode("utf-8")
try:
MPDStats(lib, self._log).run()
diff --git a/lib/beetsplug/mpdupdate.py b/lib/beetsplug/mpdupdate.py
index e5264e18..cb53afaa 100644
--- a/lib/beetsplug/mpdupdate.py
+++ b/lib/beetsplug/mpdupdate.py
@@ -21,10 +21,11 @@ Put something like the following in your config.yaml to configure:
password: seekrit
"""
-from beets.plugins import BeetsPlugin
import os
import socket
+
from beets import config
+from beets.plugins import BeetsPlugin
# No need to introduce a dependency on an MPD library for such a
@@ -32,14 +33,15 @@ from beets import config
# easier.
class BufferedSocket:
"""Socket abstraction that allows reading by line."""
- def __init__(self, host, port, sep=b'\n'):
- if host[0] in ['/', '~']:
+
+ def __init__(self, host, port, sep=b"\n"):
+ if host[0] in ["/", "~"]:
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.sock.connect(os.path.expanduser(host))
else:
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.connect((host, port))
- self.buf = b''
+ self.buf = b""
self.sep = sep
def readline(self):
@@ -52,7 +54,7 @@ class BufferedSocket:
res, self.buf = self.buf.split(self.sep, 1)
return res + self.sep
else:
- return b''
+ return b""
def send(self, data):
self.sock.send(data)
@@ -64,63 +66,64 @@ class BufferedSocket:
class MPDUpdatePlugin(BeetsPlugin):
def __init__(self):
super().__init__()
- config['mpd'].add({
- 'host': os.environ.get('MPD_HOST', 'localhost'),
- 'port': int(os.environ.get('MPD_PORT', 6600)),
- 'password': '',
- })
- config['mpd']['password'].redact = True
+ config["mpd"].add(
+ {
+ "host": os.environ.get("MPD_HOST", "localhost"),
+ "port": int(os.environ.get("MPD_PORT", 6600)),
+ "password": "",
+ }
+ )
+ config["mpd"]["password"].redact = True
# For backwards compatibility, use any values from the
# plugin-specific "mpdupdate" section.
- for key in config['mpd'].keys():
+ for key in config["mpd"].keys():
if self.config[key].exists():
- config['mpd'][key] = self.config[key].get()
+ config["mpd"][key] = self.config[key].get()
- self.register_listener('database_change', self.db_change)
+ self.register_listener("database_change", self.db_change)
def db_change(self, lib, model):
- self.register_listener('cli_exit', self.update)
+ self.register_listener("cli_exit", self.update)
def update(self, lib):
self.update_mpd(
- config['mpd']['host'].as_str(),
- config['mpd']['port'].get(int),
- config['mpd']['password'].as_str(),
+ config["mpd"]["host"].as_str(),
+ config["mpd"]["port"].get(int),
+ config["mpd"]["password"].as_str(),
)
- def update_mpd(self, host='localhost', port=6600, password=None):
+ def update_mpd(self, host="localhost", port=6600, password=None):
"""Sends the "update" command to the MPD server indicated,
possibly authenticating with a password first.
"""
- self._log.info('Updating MPD database...')
+ self._log.info("Updating MPD database...")
try:
s = BufferedSocket(host, port)
except OSError as e:
- self._log.warning('MPD connection failed: {0}',
- str(e.strerror))
+ self._log.warning("MPD connection failed: {0}", str(e.strerror))
return
resp = s.readline()
- if b'OK MPD' not in resp:
- self._log.warning('MPD connection failed: {0!r}', resp)
+ if b"OK MPD" not in resp:
+ self._log.warning("MPD connection failed: {0!r}", resp)
return
if password:
- s.send(b'password "%s"\n' % password.encode('utf8'))
+ s.send(b'password "%s"\n' % password.encode("utf8"))
resp = s.readline()
- if b'OK' not in resp:
- self._log.warning('Authentication failed: {0!r}', resp)
- s.send(b'close\n')
+ if b"OK" not in resp:
+ self._log.warning("Authentication failed: {0!r}", resp)
+ s.send(b"close\n")
s.close()
return
- s.send(b'update\n')
+ s.send(b"update\n")
resp = s.readline()
- if b'updating_db' not in resp:
- self._log.warning('Update failed: {0!r}', resp)
+ if b"updating_db" not in resp:
+ self._log.warning("Update failed: {0!r}", resp)
- s.send(b'close\n')
+ s.send(b"close\n")
s.close()
- self._log.info('Database updated.')
+ self._log.info("Database updated.")
diff --git a/lib/beetsplug/parentwork.py b/lib/beetsplug/parentwork.py
index 75307b8f..4ddef1c1 100644
--- a/lib/beetsplug/parentwork.py
+++ b/lib/beetsplug/parentwork.py
@@ -17,37 +17,38 @@ and work composition date
"""
+import musicbrainzngs
+
from beets import ui
from beets.plugins import BeetsPlugin
-import musicbrainzngs
-
def direct_parent_id(mb_workid, work_date=None):
"""Given a Musicbrainz work id, find the id one of the works the work is
part of and the first composition date it encounters.
"""
- work_info = musicbrainzngs.get_work_by_id(mb_workid,
- includes=["work-rels",
- "artist-rels"])
- if 'artist-relation-list' in work_info['work'] and work_date is None:
- for artist in work_info['work']['artist-relation-list']:
- if artist['type'] == 'composer':
- if 'end' in artist.keys():
- work_date = artist['end']
+ work_info = musicbrainzngs.get_work_by_id(
+ mb_workid, includes=["work-rels", "artist-rels"]
+ )
+ if "artist-relation-list" in work_info["work"] and work_date is None:
+ for artist in work_info["work"]["artist-relation-list"]:
+ if artist["type"] == "composer":
+ if "end" in artist.keys():
+ work_date = artist["end"]
- if 'work-relation-list' in work_info['work']:
- for direct_parent in work_info['work']['work-relation-list']:
- if direct_parent['type'] == 'parts' \
- and direct_parent.get('direction') == 'backward':
- direct_id = direct_parent['work']['id']
+ if "work-relation-list" in work_info["work"]:
+ for direct_parent in work_info["work"]["work-relation-list"]:
+ if (
+ direct_parent["type"] == "parts"
+ and direct_parent.get("direction") == "backward"
+ ):
+ direct_id = direct_parent["work"]["id"]
return direct_id, work_date
return None, work_date
def work_parent_id(mb_workid):
- """Find the parent work id and composition date of a work given its id.
- """
+ """Find the parent work id and composition date of a work given its id."""
work_date = None
while True:
new_mb_workid, work_date = direct_parent_id(mb_workid, work_date)
@@ -62,8 +63,9 @@ def find_parentwork_info(mb_workid):
the artist relations, and the composition date for a work's parent work.
"""
parent_id, work_date = work_parent_id(mb_workid)
- work_info = musicbrainzngs.get_work_by_id(parent_id,
- includes=["artist-rels"])
+ work_info = musicbrainzngs.get_work_by_id(
+ parent_id, includes=["artist-rels"]
+ )
return work_info, work_date
@@ -71,19 +73,20 @@ class ParentWorkPlugin(BeetsPlugin):
def __init__(self):
super().__init__()
- self.config.add({
- 'auto': False,
- 'force': False,
- })
+ self.config.add(
+ {
+ "auto": False,
+ "force": False,
+ }
+ )
- if self.config['auto']:
+ if self.config["auto"]:
self.import_stages = [self.imported]
def commands(self):
-
def func(lib, opts, args):
self.config.set_args(opts)
- force_parent = self.config['force'].get(bool)
+ force_parent = self.config["force"].get(bool)
write = ui.should_write()
for item in lib.items(ui.decargs(args)):
@@ -92,22 +95,26 @@ class ParentWorkPlugin(BeetsPlugin):
item.store()
if write:
item.try_write()
+
command = ui.Subcommand(
- 'parentwork',
- help='fetch parent works, composers and dates')
+ "parentwork", help="fetch parent works, composers and dates"
+ )
command.parser.add_option(
- '-f', '--force', dest='force',
- action='store_true', default=None,
- help='re-fetch when parent work is already present')
+ "-f",
+ "--force",
+ dest="force",
+ action="store_true",
+ default=None,
+ help="re-fetch when parent work is already present",
+ )
command.func = func
return [command]
def imported(self, session, task):
- """Import hook for fetching parent works automatically.
- """
- force_parent = self.config['force'].get(bool)
+ """Import hook for fetching parent works automatically."""
+ force_parent = self.config["force"].get(bool)
for item in task.imported_items():
self.find_work(item, force_parent)
@@ -124,35 +131,38 @@ class ParentWorkPlugin(BeetsPlugin):
parentwork_info = {}
composer_exists = False
- if 'artist-relation-list' in work_info['work']:
- for artist in work_info['work']['artist-relation-list']:
- if artist['type'] == 'composer':
+ if "artist-relation-list" in work_info["work"]:
+ for artist in work_info["work"]["artist-relation-list"]:
+ if artist["type"] == "composer":
composer_exists = True
- parent_composer.append(artist['artist']['name'])
- parent_composer_sort.append(artist['artist']['sort-name'])
- if 'end' in artist.keys():
- parentwork_info["parentwork_date"] = artist['end']
+ parent_composer.append(artist["artist"]["name"])
+ parent_composer_sort.append(artist["artist"]["sort-name"])
+ if "end" in artist.keys():
+ parentwork_info["parentwork_date"] = artist["end"]
- parentwork_info['parent_composer'] = ', '.join(parent_composer)
- parentwork_info['parent_composer_sort'] = ', '.join(
- parent_composer_sort)
+ parentwork_info["parent_composer"] = ", ".join(parent_composer)
+ parentwork_info["parent_composer_sort"] = ", ".join(
+ parent_composer_sort
+ )
if not composer_exists:
self._log.debug(
- 'no composer for {}; add one at '
- 'https://musicbrainz.org/work/{}',
- item, work_info['work']['id'],
+ "no composer for {}; add one at "
+ "https://musicbrainz.org/work/{}",
+ item,
+ work_info["work"]["id"],
)
- parentwork_info['parentwork'] = work_info['work']['title']
- parentwork_info['mb_parentworkid'] = work_info['work']['id']
+ parentwork_info["parentwork"] = work_info["work"]["title"]
+ parentwork_info["mb_parentworkid"] = work_info["work"]["id"]
- if 'disambiguation' in work_info['work']:
- parentwork_info['parentwork_disambig'] = work_info[
- 'work']['disambiguation']
+ if "disambiguation" in work_info["work"]:
+ parentwork_info["parentwork_disambig"] = work_info["work"][
+ "disambiguation"
+ ]
else:
- parentwork_info['parentwork_disambig'] = None
+ parentwork_info["parentwork_disambig"] = None
return parentwork_info
@@ -169,13 +179,17 @@ class ParentWorkPlugin(BeetsPlugin):
"""
if not item.mb_workid:
- self._log.info('No work for {}, \
-add one at https://musicbrainz.org/recording/{}', item, item.mb_trackid)
+ self._log.info(
+ "No work for {}, \
+add one at https://musicbrainz.org/recording/{}",
+ item,
+ item.mb_trackid,
+ )
return
- hasparent = hasattr(item, 'parentwork')
+ hasparent = hasattr(item, "parentwork")
work_changed = True
- if hasattr(item, 'parentwork_workid_current'):
+ if hasattr(item, "parentwork_workid_current"):
work_changed = item.parentwork_workid_current != item.mb_workid
if force or not hasparent or work_changed:
try:
@@ -184,14 +198,18 @@ add one at https://musicbrainz.org/recording/{}', item, item.mb_trackid)
self._log.debug("error fetching work: {}", e)
return
parent_info = self.get_info(item, work_info)
- parent_info['parentwork_workid_current'] = item.mb_workid
- if 'parent_composer' in parent_info:
- self._log.debug("Work fetched: {} - {}",
- parent_info['parentwork'],
- parent_info['parent_composer'])
+ parent_info["parentwork_workid_current"] = item.mb_workid
+ if "parent_composer" in parent_info:
+ self._log.debug(
+ "Work fetched: {} - {}",
+ parent_info["parentwork"],
+ parent_info["parent_composer"],
+ )
else:
- self._log.debug("Work fetched: {} - no parent composer",
- parent_info['parentwork'])
+ self._log.debug(
+ "Work fetched: {} - no parent composer",
+ parent_info["parentwork"],
+ )
elif hasparent:
self._log.debug("{}: Work present, skipping", item)
@@ -203,9 +221,17 @@ add one at https://musicbrainz.org/recording/{}', item, item.mb_trackid)
item[key] = value
if work_date:
- item['work_date'] = work_date
+ item["work_date"] = work_date
return ui.show_model_changes(
- item, fields=['parentwork', 'parentwork_disambig',
- 'mb_parentworkid', 'parent_composer',
- 'parent_composer_sort', 'work_date',
- 'parentwork_workid_current', 'parentwork_date'])
+ item,
+ fields=[
+ "parentwork",
+ "parentwork_disambig",
+ "mb_parentworkid",
+ "parent_composer",
+ "parent_composer_sort",
+ "work_date",
+ "parentwork_workid_current",
+ "parentwork_date",
+ ],
+ )
diff --git a/lib/beetsplug/permissions.py b/lib/beetsplug/permissions.py
index f5aab056..8f58f24b 100644
--- a/lib/beetsplug/permissions.py
+++ b/lib/beetsplug/permissions.py
@@ -5,10 +5,13 @@ like the following in your config.yaml to configure:
file: 644
dir: 755
"""
+
import os
-from beets import config, util
+import stat
+
+from beets import config
from beets.plugins import BeetsPlugin
-from beets.util import ancestry
+from beets.util import ancestry, displayable_path, syspath
def convert_perm(perm):
@@ -25,7 +28,7 @@ def check_permissions(path, permission):
"""Check whether the file's permissions equal the given vector.
Return a boolean.
"""
- return oct(os.stat(path).st_mode & 0o777) == oct(permission)
+ return oct(stat.S_IMODE(os.stat(syspath(path)).st_mode)) == oct(permission)
def assert_permissions(path, permission, log):
@@ -33,24 +36,20 @@ def assert_permissions(path, permission, log):
log a warning message. Return a boolean indicating the match, like
`check_permissions`.
"""
- if not check_permissions(util.syspath(path), permission):
- log.warning(
- 'could not set permissions on {}',
- util.displayable_path(path),
- )
+ if not check_permissions(path, permission):
+ log.warning("could not set permissions on {}", displayable_path(path))
log.debug(
- 'set permissions to {}, but permissions are now {}',
+ "set permissions to {}, but permissions are now {}",
permission,
- os.stat(util.syspath(path)).st_mode & 0o777,
+ os.stat(syspath(path)).st_mode & 0o777,
)
def dirs_in_library(library, item):
- """Creates a list of ancestor directories in the beets library path.
- """
- return [ancestor
- for ancestor in ancestry(item)
- if ancestor.startswith(library)][1:]
+ """Creates a list of ancestor directories in the beets library path."""
+ return [
+ ancestor for ancestor in ancestry(item) if ancestor.startswith(library)
+ ][1:]
class Permissions(BeetsPlugin):
@@ -58,18 +57,19 @@ class Permissions(BeetsPlugin):
super().__init__()
# Adding defaults.
- self.config.add({
- 'file': '644',
- 'dir': '755',
- })
+ self.config.add(
+ {
+ "file": "644",
+ "dir": "755",
+ }
+ )
- self.register_listener('item_imported', self.fix)
- self.register_listener('album_imported', self.fix)
- self.register_listener('art_set', self.fix_art)
+ self.register_listener("item_imported", self.fix)
+ self.register_listener("album_imported", self.fix)
+ self.register_listener("art_set", self.fix_art)
def fix(self, lib, item=None, album=None):
- """Fix the permissions for an imported Item or Album.
- """
+ """Fix the permissions for an imported Item or Album."""
files = []
dirs = set()
if item:
@@ -82,8 +82,7 @@ class Permissions(BeetsPlugin):
self.set_permissions(files=files, dirs=dirs)
def fix_art(self, album):
- """Fix the permission for Album art file.
- """
+ """Fix the permission for Album art file."""
if album.artpath:
self.set_permissions(files=[album.artpath])
@@ -92,18 +91,19 @@ class Permissions(BeetsPlugin):
# string (in YAML quotes) or, for convenience, as an integer so the
# quotes can be omitted. In the latter case, we need to reinterpret the
# integer as octal, not decimal.
- file_perm = config['permissions']['file'].get()
- dir_perm = config['permissions']['dir'].get()
+ file_perm = config["permissions"]["file"].get()
+ dir_perm = config["permissions"]["dir"].get()
file_perm = convert_perm(file_perm)
dir_perm = convert_perm(dir_perm)
for path in files:
# Changing permissions on the destination file.
self._log.debug(
- 'setting file permissions on {}',
- util.displayable_path(path),
+ "setting file permissions on {}",
+ displayable_path(path),
)
- os.chmod(util.syspath(path), file_perm)
+ if not check_permissions(path, file_perm):
+ os.chmod(syspath(path), file_perm)
# Checks if the destination path has the permissions configured.
assert_permissions(path, file_perm, self._log)
@@ -112,10 +112,11 @@ class Permissions(BeetsPlugin):
for path in dirs:
# Changing permissions on the destination directory.
self._log.debug(
- 'setting directory permissions on {}',
- util.displayable_path(path),
+ "setting directory permissions on {}",
+ displayable_path(path),
)
- os.chmod(util.syspath(path), dir_perm)
+ if not check_permissions(path, dir_perm):
+ os.chmod(syspath(path), dir_perm)
# Checks if the destination path has the permissions configured.
assert_permissions(path, dir_perm, self._log)
diff --git a/lib/beetsplug/play.py b/lib/beetsplug/play.py
index f4233490..3476e582 100644
--- a/lib/beetsplug/play.py
+++ b/lib/beetsplug/play.py
@@ -15,31 +15,37 @@
"""Send the results of a query to the configured music player as a playlist.
"""
+import shlex
+import subprocess
+from os.path import relpath
+
+from beets import config, ui, util
from beets.plugins import BeetsPlugin
from beets.ui import Subcommand
from beets.ui.commands import PromptChoice
-from beets import config
-from beets import ui
-from beets import util
-from os.path import relpath
-from tempfile import NamedTemporaryFile
-import subprocess
-import shlex
+from beets.util import get_temp_filename
# Indicate where arguments should be inserted into the command string.
# If this is missing, they're placed at the end.
-ARGS_MARKER = '$args'
+ARGS_MARKER = "$args"
-def play(command_str, selection, paths, open_args, log, item_type='track',
- keep_open=False):
+def play(
+ command_str,
+ selection,
+ paths,
+ open_args,
+ log,
+ item_type="track",
+ keep_open=False,
+):
"""Play items in paths with command_str and optional arguments. If
keep_open, return to beets, otherwise exit once command runs.
"""
# Print number of tracks or albums to be played, log command to be run.
- item_type += 's' if len(selection) > 1 else ''
- ui.print_('Playing {} {}.'.format(len(selection), item_type))
- log.debug('executing command: {} {!r}', command_str, open_args)
+ item_type += "s" if len(selection) > 1 else ""
+ ui.print_("Playing {} {}.".format(len(selection), item_type))
+ log.debug("executing command: {} {!r}", command_str, open_args)
try:
if keep_open:
@@ -49,42 +55,44 @@ def play(command_str, selection, paths, open_args, log, item_type='track',
else:
util.interactive_open(open_args, command_str)
except OSError as exc:
- raise ui.UserError(
- f"Could not play the query: {exc}")
+ raise ui.UserError(f"Could not play the query: {exc}")
class PlayPlugin(BeetsPlugin):
-
def __init__(self):
super().__init__()
- config['play'].add({
- 'command': None,
- 'use_folders': False,
- 'relative_to': None,
- 'raw': False,
- 'warning_threshold': 100,
- 'bom': False,
- })
+ config["play"].add(
+ {
+ "command": None,
+ "use_folders": False,
+ "relative_to": None,
+ "raw": False,
+ "warning_threshold": 100,
+ "bom": False,
+ }
+ )
- self.register_listener('before_choose_candidate',
- self.before_choose_candidate_listener)
+ self.register_listener(
+ "before_choose_candidate", self.before_choose_candidate_listener
+ )
def commands(self):
play_command = Subcommand(
- 'play',
- help='send music to a player as a playlist'
+ "play", help="send music to a player as a playlist"
)
play_command.parser.add_album_option()
play_command.parser.add_option(
- '-A', '--args',
- action='store',
- help='add additional arguments to the command',
+ "-A",
+ "--args",
+ action="store",
+ help="add additional arguments to the command",
)
play_command.parser.add_option(
- '-y', '--yes',
+ "-y",
+ "--yes",
action="store_true",
- help='skip the warning threshold',
+ help="skip the warning threshold",
)
play_command.func = self._play_command
return [play_command]
@@ -93,8 +101,8 @@ class PlayPlugin(BeetsPlugin):
"""The CLI command function for `beet play`. Create a list of paths
from query, determine if tracks or albums are to be played.
"""
- use_folders = config['play']['use_folders'].get(bool)
- relative_to = config['play']['relative_to'].get()
+ use_folders = config["play"]["use_folders"].get(bool)
+ relative_to = config["play"]["relative_to"].get()
if relative_to:
relative_to = util.normpath(relative_to)
# Perform search by album and add folders rather than tracks to
@@ -108,22 +116,20 @@ class PlayPlugin(BeetsPlugin):
if use_folders:
paths.append(album.item_dir())
else:
- paths.extend(item.path
- for item in sort.sort(album.items()))
- item_type = 'album'
+ paths.extend(item.path for item in sort.sort(album.items()))
+ item_type = "album"
# Perform item query and add tracks to playlist.
else:
selection = lib.items(ui.decargs(args))
paths = [item.path for item in selection]
- item_type = 'track'
+ item_type = "track"
if relative_to:
paths = [relpath(path, relative_to) for path in paths]
if not selection:
- ui.print_(ui.colorize('text_warning',
- f'No {item_type} to play.'))
+ ui.print_(ui.colorize("text_warning", f"No {item_type} to play."))
return
open_args = self._playlist_or_paths(paths)
@@ -132,14 +138,13 @@ class PlayPlugin(BeetsPlugin):
# Check if the selection exceeds configured threshold. If True,
# cancel, otherwise proceed with play command.
if opts.yes or not self._exceeds_threshold(
- selection, command_str, open_args, item_type):
- play(command_str, selection, paths, open_args, self._log,
- item_type)
+ selection, command_str, open_args, item_type
+ ):
+ play(command_str, selection, paths, open_args, self._log, item_type)
def _command_str(self, args=None):
- """Create a command string from the config command and optional args.
- """
- command_str = config['play']['command'].get()
+ """Create a command string from the config command and optional args."""
+ command_str = config["play"]["command"].get()
if not command_str:
return util.open_anything()
# Add optional arguments to the player command.
@@ -153,57 +158,58 @@ class PlayPlugin(BeetsPlugin):
return command_str.replace(" " + ARGS_MARKER, "")
def _playlist_or_paths(self, paths):
- """Return either the raw paths of items or a playlist of the items.
- """
- if config['play']['raw']:
+ """Return either the raw paths of items or a playlist of the items."""
+ if config["play"]["raw"]:
return paths
else:
return [self._create_tmp_playlist(paths)]
- def _exceeds_threshold(self, selection, command_str, open_args,
- item_type='track'):
+ def _exceeds_threshold(
+ self, selection, command_str, open_args, item_type="track"
+ ):
"""Prompt user whether to abort if playlist exceeds threshold. If
True, cancel playback. If False, execute play command.
"""
- warning_threshold = config['play']['warning_threshold'].get(int)
+ warning_threshold = config["play"]["warning_threshold"].get(int)
# Warn user before playing any huge playlists.
if warning_threshold and len(selection) > warning_threshold:
if len(selection) > 1:
- item_type += 's'
+ item_type += "s"
- ui.print_(ui.colorize(
- 'text_warning',
- 'You are about to queue {} {}.'.format(
- len(selection), item_type)))
+ ui.print_(
+ ui.colorize(
+ "text_warning",
+ "You are about to queue {} {}.".format(
+ len(selection), item_type
+ ),
+ )
+ )
- if ui.input_options(('Continue', 'Abort')) == 'a':
+ if ui.input_options(("Continue", "Abort")) == "a":
return True
return False
def _create_tmp_playlist(self, paths_list):
- """Create a temporary .m3u file. Return the filename.
- """
- utf8_bom = config['play']['bom'].get(bool)
- m3u = NamedTemporaryFile('wb', suffix='.m3u', delete=False)
+ """Create a temporary .m3u file. Return the filename."""
+ utf8_bom = config["play"]["bom"].get(bool)
+ filename = get_temp_filename(__name__, suffix=".m3u")
+ with open(filename, "wb") as m3u:
+ if utf8_bom:
+ m3u.write(b"\xEF\xBB\xBF")
- if utf8_bom:
- m3u.write(b'\xEF\xBB\xBF')
+ for item in paths_list:
+ m3u.write(item + b"\n")
- for item in paths_list:
- m3u.write(item + b'\n')
- m3u.close()
- return m3u.name
+ return filename
def before_choose_candidate_listener(self, session, task):
- """Append a "Play" choice to the interactive importer prompt.
- """
- return [PromptChoice('y', 'plaY', self.importer_play)]
+ """Append a "Play" choice to the interactive importer prompt."""
+ return [PromptChoice("y", "plaY", self.importer_play)]
def importer_play(self, session, task):
- """Get items from current import task and send to play function.
- """
+ """Get items from current import task and send to play function."""
selection = task.items
paths = [item.path for item in selection]
@@ -211,5 +217,11 @@ class PlayPlugin(BeetsPlugin):
command_str = self._command_str()
if not self._exceeds_threshold(selection, command_str, open_args):
- play(command_str, selection, paths, open_args, self._log,
- keep_open=True)
+ play(
+ command_str,
+ selection,
+ paths,
+ open_args,
+ self._log,
+ keep_open=True,
+ )
diff --git a/lib/beetsplug/playlist.py b/lib/beetsplug/playlist.py
index 265b8bad..83f95796 100644
--- a/lib/beetsplug/playlist.py
+++ b/lib/beetsplug/playlist.py
@@ -12,98 +12,104 @@
# included in all copies or substantial portions of the Software.
-import os
import fnmatch
+import os
import tempfile
+from typing import Sequence
+
import beets
+from beets.dbcore.query import InQuery
+from beets.library import BLOB_TYPE
from beets.util import path_as_posix
-class PlaylistQuery(beets.dbcore.Query):
- """Matches files listed by a playlist file.
- """
- def __init__(self, pattern):
- self.pattern = pattern
- config = beets.config['playlist']
+class PlaylistQuery(InQuery[bytes]):
+ """Matches files listed by a playlist file."""
+
+ @property
+ def subvals(self) -> Sequence[BLOB_TYPE]:
+ return [BLOB_TYPE(p) for p in self.pattern]
+
+ def __init__(self, _, pattern: str, __):
+ config = beets.config["playlist"]
# Get the full path to the playlist
playlist_paths = (
pattern,
- os.path.abspath(os.path.join(
- config['playlist_dir'].as_filename(),
- f'{pattern}.m3u',
- )),
+ os.path.abspath(
+ os.path.join(
+ config["playlist_dir"].as_filename(),
+ f"{pattern}.m3u",
+ )
+ ),
)
- self.paths = []
+ paths = []
for playlist_path in playlist_paths:
- if not fnmatch.fnmatch(playlist_path, '*.[mM]3[uU]'):
+ if not fnmatch.fnmatch(playlist_path, "*.[mM]3[uU]"):
# This is not am M3U playlist, skip this candidate
continue
try:
- f = open(beets.util.syspath(playlist_path), mode='rb')
+ f = open(beets.util.syspath(playlist_path), mode="rb")
except OSError:
continue
- if config['relative_to'].get() == 'library':
- relative_to = beets.config['directory'].as_filename()
- elif config['relative_to'].get() == 'playlist':
+ if config["relative_to"].get() == "library":
+ relative_to = beets.config["directory"].as_filename()
+ elif config["relative_to"].get() == "playlist":
relative_to = os.path.dirname(playlist_path)
else:
- relative_to = config['relative_to'].as_filename()
+ relative_to = config["relative_to"].as_filename()
relative_to = beets.util.bytestring_path(relative_to)
for line in f:
- if line[0] == '#':
+ if line[0] == "#":
# ignore comments, and extm3u extension
continue
- self.paths.append(beets.util.normpath(
- os.path.join(relative_to, line.rstrip())
- ))
+ paths.append(
+ beets.util.normpath(
+ os.path.join(relative_to, line.rstrip())
+ )
+ )
f.close()
break
-
- def col_clause(self):
- if not self.paths:
- # Playlist is empty
- return '0', ()
- clause = 'path IN ({})'.format(', '.join('?' for path in self.paths))
- return clause, (beets.library.BLOB_TYPE(p) for p in self.paths)
-
- def match(self, item):
- return item.path in self.paths
+ super().__init__("path", paths)
class PlaylistPlugin(beets.plugins.BeetsPlugin):
- item_queries = {'playlist': PlaylistQuery}
+ item_queries = {"playlist": PlaylistQuery}
def __init__(self):
super().__init__()
- self.config.add({
- 'auto': False,
- 'playlist_dir': '.',
- 'relative_to': 'library',
- 'forward_slash': False,
- })
+ self.config.add(
+ {
+ "auto": False,
+ "playlist_dir": ".",
+ "relative_to": "library",
+ "forward_slash": False,
+ }
+ )
- self.playlist_dir = self.config['playlist_dir'].as_filename()
+ self.playlist_dir = self.config["playlist_dir"].as_filename()
self.changes = {}
- if self.config['relative_to'].get() == 'library':
+ if self.config["relative_to"].get() == "library":
self.relative_to = beets.util.bytestring_path(
- beets.config['directory'].as_filename())
- elif self.config['relative_to'].get() != 'playlist':
+ beets.config["directory"].as_filename()
+ )
+ elif self.config["relative_to"].get() != "playlist":
self.relative_to = beets.util.bytestring_path(
- self.config['relative_to'].as_filename())
+ self.config["relative_to"].as_filename()
+ )
else:
self.relative_to = None
- if self.config['auto']:
- self.register_listener('item_moved', self.item_moved)
- self.register_listener('item_removed', self.item_removed)
- self.register_listener('cli_exit', self.cli_exit)
+ if self.config["auto"]:
+ self.register_listener("item_moved", self.item_moved)
+ self.register_listener("item_removed", self.item_removed)
+ self.register_listener("cli_exit", self.cli_exit)
def item_moved(self, item, source, destination):
self.changes[source] = destination
@@ -114,29 +120,36 @@ class PlaylistPlugin(beets.plugins.BeetsPlugin):
def cli_exit(self, lib):
for playlist in self.find_playlists():
- self._log.info(f'Updating playlist: {playlist}')
+ self._log.info(f"Updating playlist: {playlist}")
base_dir = beets.util.bytestring_path(
- self.relative_to if self.relative_to
+ self.relative_to
+ if self.relative_to
else os.path.dirname(playlist)
)
try:
self.update_playlist(playlist, base_dir)
except beets.util.FilesystemError:
- self._log.error('Failed to update playlist: {}'.format(
- beets.util.displayable_path(playlist)))
+ self._log.error(
+ "Failed to update playlist: {}".format(
+ beets.util.displayable_path(playlist)
+ )
+ )
def find_playlists(self):
"""Find M3U playlists in the playlist directory."""
try:
dir_contents = os.listdir(beets.util.syspath(self.playlist_dir))
except OSError:
- self._log.warning('Unable to open playlist directory {}'.format(
- beets.util.displayable_path(self.playlist_dir)))
+ self._log.warning(
+ "Unable to open playlist directory {}".format(
+ beets.util.displayable_path(self.playlist_dir)
+ )
+ )
return
for filename in dir_contents:
- if fnmatch.fnmatch(filename, '*.[mM]3[uU]'):
+ if fnmatch.fnmatch(filename, "*.[mM]3[uU]"):
yield os.path.join(self.playlist_dir, filename)
def update_playlist(self, filename, base_dir):
@@ -144,11 +157,11 @@ class PlaylistPlugin(beets.plugins.BeetsPlugin):
changes = 0
deletions = 0
- with tempfile.NamedTemporaryFile(mode='w+b', delete=False) as tempfp:
+ with tempfile.NamedTemporaryFile(mode="w+b", delete=False) as tempfp:
new_playlist = tempfp.name
- with open(filename, mode='rb') as fp:
+ with open(filename, mode="rb") as fp:
for line in fp:
- original_path = line.rstrip(b'\r\n')
+ original_path = line.rstrip(b"\r\n")
# Ensure that path from playlist is absolute
is_relative = not os.path.isabs(line)
@@ -160,7 +173,7 @@ class PlaylistPlugin(beets.plugins.BeetsPlugin):
try:
new_path = self.changes[beets.util.normpath(lookup)]
except KeyError:
- if self.config['forward_slash']:
+ if self.config["forward_slash"]:
line = path_as_posix(line)
tempfp.write(line)
else:
@@ -173,13 +186,15 @@ class PlaylistPlugin(beets.plugins.BeetsPlugin):
if is_relative:
new_path = os.path.relpath(new_path, base_dir)
line = line.replace(original_path, new_path)
- if self.config['forward_slash']:
+ if self.config["forward_slash"]:
line = path_as_posix(line)
tempfp.write(line)
if changes or deletions:
self._log.info(
- 'Updated playlist {} ({} changes, {} deletions)'.format(
- filename, changes, deletions))
+ "Updated playlist {} ({} changes, {} deletions)".format(
+ filename, changes, deletions
+ )
+ )
beets.util.copy(new_playlist, filename, replace=True)
beets.util.remove(new_playlist)
diff --git a/lib/beetsplug/plexupdate.py b/lib/beetsplug/plexupdate.py
index 2261a55f..9b4419c7 100644
--- a/lib/beetsplug/plexupdate.py
+++ b/lib/beetsplug/plexupdate.py
@@ -8,66 +8,77 @@ Put something like the following in your config.yaml to configure:
token: token
"""
-import requests
+from urllib.parse import urlencode, urljoin
from xml.etree import ElementTree
-from urllib.parse import urljoin, urlencode
+
+import requests
+
from beets import config
from beets.plugins import BeetsPlugin
-def get_music_section(host, port, token, library_name, secure,
- ignore_cert_errors):
- """Getting the section key for the music library in Plex.
- """
- api_endpoint = append_token('library/sections', token)
- url = urljoin('{}://{}:{}'.format(get_protocol(secure), host,
- port), api_endpoint)
+def get_music_section(
+ host, port, token, library_name, secure, ignore_cert_errors
+):
+ """Getting the section key for the music library in Plex."""
+ api_endpoint = append_token("library/sections", token)
+ url = urljoin(
+ "{}://{}:{}".format(get_protocol(secure), host, port), api_endpoint
+ )
# Sends request.
- r = requests.get(url, verify=not ignore_cert_errors)
+ r = requests.get(
+ url,
+ verify=not ignore_cert_errors,
+ timeout=10,
+ )
# Parse xml tree and extract music section key.
tree = ElementTree.fromstring(r.content)
- for child in tree.findall('Directory'):
- if child.get('title') == library_name:
- return child.get('key')
+ for child in tree.findall("Directory"):
+ if child.get("title") == library_name:
+ return child.get("key")
-def update_plex(host, port, token, library_name, secure,
- ignore_cert_errors):
- """Ignore certificate errors if configured to.
- """
+def update_plex(host, port, token, library_name, secure, ignore_cert_errors):
+ """Ignore certificate errors if configured to."""
if ignore_cert_errors:
import urllib3
+
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
"""Sends request to the Plex api to start a library refresh.
"""
# Getting section key and build url.
- section_key = get_music_section(host, port, token, library_name,
- secure, ignore_cert_errors)
- api_endpoint = f'library/sections/{section_key}/refresh'
+ section_key = get_music_section(
+ host, port, token, library_name, secure, ignore_cert_errors
+ )
+ api_endpoint = f"library/sections/{section_key}/refresh"
api_endpoint = append_token(api_endpoint, token)
- url = urljoin('{}://{}:{}'.format(get_protocol(secure), host,
- port), api_endpoint)
+ url = urljoin(
+ "{}://{}:{}".format(get_protocol(secure), host, port), api_endpoint
+ )
# Sends request and returns requests object.
- r = requests.get(url, verify=not ignore_cert_errors)
+ r = requests.get(
+ url,
+ verify=not ignore_cert_errors,
+ timeout=10,
+ )
return r
def append_token(url, token):
- """Appends the Plex Home token to the api call if required.
- """
+ """Appends the Plex Home token to the api call if required."""
if token:
- url += '?' + urlencode({'X-Plex-Token': token})
+ url += "?" + urlencode({"X-Plex-Token": token})
return url
def get_protocol(secure):
if secure:
- return 'https'
+ return "https"
else:
- return 'http'
+ return "http"
class PlexUpdate(BeetsPlugin):
@@ -75,36 +86,39 @@ class PlexUpdate(BeetsPlugin):
super().__init__()
# Adding defaults.
- config['plex'].add({
- 'host': 'localhost',
- 'port': 32400,
- 'token': '',
- 'library_name': 'Music',
- 'secure': False,
- 'ignore_cert_errors': False})
+ config["plex"].add(
+ {
+ "host": "localhost",
+ "port": 32400,
+ "token": "",
+ "library_name": "Music",
+ "secure": False,
+ "ignore_cert_errors": False,
+ }
+ )
- config['plex']['token'].redact = True
- self.register_listener('database_change', self.listen_for_db_change)
+ config["plex"]["token"].redact = True
+ self.register_listener("database_change", self.listen_for_db_change)
def listen_for_db_change(self, lib, model):
"""Listens for beets db change and register the update for the end"""
- self.register_listener('cli_exit', self.update)
+ self.register_listener("cli_exit", self.update)
def update(self, lib):
- """When the client exists try to send refresh request to Plex server.
- """
- self._log.info('Updating Plex library...')
+ """When the client exists try to send refresh request to Plex server."""
+ self._log.info("Updating Plex library...")
# Try to send update request.
try:
update_plex(
- config['plex']['host'].get(),
- config['plex']['port'].get(),
- config['plex']['token'].get(),
- config['plex']['library_name'].get(),
- config['plex']['secure'].get(bool),
- config['plex']['ignore_cert_errors'].get(bool))
- self._log.info('... started.')
+ config["plex"]["host"].get(),
+ config["plex"]["port"].get(),
+ config["plex"]["token"].get(),
+ config["plex"]["library_name"].get(),
+ config["plex"]["secure"].get(bool),
+ config["plex"]["ignore_cert_errors"].get(bool),
+ )
+ self._log.info("... started.")
except requests.exceptions.RequestException:
- self._log.warning('Update failed.')
+ self._log.warning("Update failed.")
diff --git a/lib/beetsplug/random.py b/lib/beetsplug/random.py
index ea9b7b98..dc94a0e3 100644
--- a/lib/beetsplug/random.py
+++ b/lib/beetsplug/random.py
@@ -16,13 +16,12 @@
"""
from beets.plugins import BeetsPlugin
-from beets.ui import Subcommand, decargs, print_
from beets.random import random_objs
+from beets.ui import Subcommand, decargs, print_
def random_func(lib, opts, args):
- """Select some random items or albums and print the results.
- """
+ """Select some random items or albums and print the results."""
# Fetch all the objects matching the query into a list.
query = decargs(args)
if opts.album:
@@ -31,23 +30,35 @@ def random_func(lib, opts, args):
objs = list(lib.items(query))
# Print a random subset.
- objs = random_objs(objs, opts.album, opts.number, opts.time,
- opts.equal_chance)
+ objs = random_objs(
+ objs, opts.album, opts.number, opts.time, opts.equal_chance
+ )
for obj in objs:
print_(format(obj))
-random_cmd = Subcommand('random',
- help='choose a random track or album')
+random_cmd = Subcommand("random", help="choose a random track or album")
random_cmd.parser.add_option(
- '-n', '--number', action='store', type="int",
- help='number of objects to choose', default=1)
+ "-n",
+ "--number",
+ action="store",
+ type="int",
+ help="number of objects to choose",
+ default=1,
+)
random_cmd.parser.add_option(
- '-e', '--equal-chance', action='store_true',
- help='each artist has the same chance')
+ "-e",
+ "--equal-chance",
+ action="store_true",
+ help="each artist has the same chance",
+)
random_cmd.parser.add_option(
- '-t', '--time', action='store', type="float",
- help='total length in minutes of objects to choose')
+ "-t",
+ "--time",
+ action="store",
+ type="float",
+ help="total length in minutes of objects to choose",
+)
random_cmd.parser.add_all_common_options()
random_cmd.func = random_func
diff --git a/lib/beetsplug/replaygain.py b/lib/beetsplug/replaygain.py
index b6297d93..a2753f96 100644
--- a/lib/beetsplug/replaygain.py
+++ b/lib/beetsplug/replaygain.py
@@ -16,23 +16,44 @@
import collections
import enum
import math
+import optparse
import os
+import queue
import signal
import subprocess
import sys
import warnings
-from multiprocessing.pool import ThreadPool, RUN
-from six.moves import queue
-from threading import Thread, Event
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from logging import Logger
+from multiprocessing.pool import ThreadPool
+from threading import Event, Thread
+from typing import (
+ Any,
+ Callable,
+ DefaultDict,
+ Dict,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
+
+from confuse import ConfigView
from beets import ui
+from beets.importer import ImportSession, ImportTask
+from beets.library import Album, Item, Library
from beets.plugins import BeetsPlugin
-from beets.util import (syspath, command_output, displayable_path,
- py3_path, cpu_count)
-
+from beets.util import command_output, displayable_path, syspath
# Utilities.
+
class ReplayGainError(Exception):
"""Raised when a local (to a track or an album) error occurs in one
of the backends.
@@ -40,8 +61,7 @@ class ReplayGainError(Exception):
class FatalReplayGainError(Exception):
- """Raised when a fatal error occurs in one of the backends.
- """
+ """Raised when a fatal error occurs in one of the backends."""
class FatalGstreamerPluginReplayGainError(FatalReplayGainError):
@@ -49,13 +69,14 @@ class FatalGstreamerPluginReplayGainError(FatalReplayGainError):
loading the required plugins."""
-def call(args, **kwargs):
+def call(args: List[Any], log: Logger, **kwargs: Any):
"""Execute the command and return its output or raise a
ReplayGainError on failure.
"""
try:
return command_output(args, **kwargs)
except subprocess.CalledProcessError as e:
+ log.debug(e.output.decode("utf8", "ignore"))
raise ReplayGainError(
"{} exited with status {}".format(args[0], e.returncode)
)
@@ -66,12 +87,7 @@ def call(args, **kwargs):
raise ReplayGainError("argument encoding failed")
-def after_version(version_a, version_b):
- return tuple(int(s) for s in version_a.split('.')) \
- >= tuple(int(s) for s in version_b.split('.'))
-
-
-def db_to_lufs(db):
+def db_to_lufs(db: float) -> float:
"""Convert db to LUFS.
According to https://wiki.hydrogenaud.io/index.php?title=
@@ -80,7 +96,7 @@ def db_to_lufs(db):
return db - 107
-def lufs_to_db(db):
+def lufs_to_db(db: float) -> float:
"""Convert LUFS to db.
According to https://wiki.hydrogenaud.io/index.php?title=
@@ -91,59 +107,202 @@ def lufs_to_db(db):
# Backend base and plumbing classes.
-# gain: in LU to reference level
-# peak: part of full scale (FS is 1.0)
-Gain = collections.namedtuple("Gain", "gain peak")
-# album_gain: Gain object
-# track_gains: list of Gain objects
-AlbumGain = collections.namedtuple("AlbumGain", "album_gain track_gains")
+
+@dataclass
+class Gain:
+ # gain: in LU to reference level
+ gain: float
+ # peak: part of full scale (FS is 1.0)
+ peak: float
-class Peak(enum.Enum):
- none = 0
+class PeakMethod(enum.Enum):
true = 1
sample = 2
-class Backend:
- """An abstract class representing engine for calculating RG values.
+class RgTask:
+ """State and methods for a single replaygain calculation (rg version).
+
+ Bundles the state (parameters and results) of a single replaygain
+ calculation (either for one item, one disk, or one full album).
+
+ This class provides methods to store the resulting gains and peaks as plain
+ old rg tags.
"""
+ def __init__(
+ self,
+ items: Sequence[Item],
+ album: Optional[Album],
+ target_level: float,
+ peak_method: Optional[PeakMethod],
+ backend_name: str,
+ log: Logger,
+ ):
+ self.items = items
+ self.album = album
+ self.target_level = target_level
+ self.peak_method = peak_method
+ self.backend_name = backend_name
+ self._log = log
+ self.album_gain: Optional[Gain] = None
+ self.track_gains: Optional[List[Gain]] = None
+
+ def _store_track_gain(self, item: Item, track_gain: Gain):
+ """Store track gain for a single item in the database."""
+ item.rg_track_gain = track_gain.gain
+ item.rg_track_peak = track_gain.peak
+ item.store()
+ self._log.debug(
+ "applied track gain {0} LU, peak {1} of FS",
+ item.rg_track_gain,
+ item.rg_track_peak,
+ )
+
+ def _store_album_gain(self, item: Item, album_gain: Gain):
+ """Store album gain for a single item in the database.
+
+ The caller needs to ensure that `self.album_gain is not None`.
+ """
+ item.rg_album_gain = album_gain.gain
+ item.rg_album_peak = album_gain.peak
+ item.store()
+ self._log.debug(
+ "applied album gain {0} LU, peak {1} of FS",
+ item.rg_album_gain,
+ item.rg_album_peak,
+ )
+
+ def _store_track(self, write: bool):
+ """Store track gain for the first track of the task in the database."""
+ item = self.items[0]
+ if self.track_gains is None or len(self.track_gains) != 1:
+ # In some cases, backends fail to produce a valid
+ # `track_gains` without throwing FatalReplayGainError
+ # => raise non-fatal exception & continue
+ raise ReplayGainError(
+ "ReplayGain backend `{}` failed for track {}".format(
+ self.backend_name, item
+ )
+ )
+
+ self._store_track_gain(item, self.track_gains[0])
+ if write:
+ item.try_write()
+ self._log.debug("done analyzing {0}", item)
+
+ def _store_album(self, write: bool):
+ """Store track/album gains for all tracks of the task in the database."""
+ if (
+ self.album_gain is None
+ or self.track_gains is None
+ or len(self.track_gains) != len(self.items)
+ ):
+ # In some cases, backends fail to produce a valid
+ # `album_gain` without throwing FatalReplayGainError
+ # => raise non-fatal exception & continue
+ raise ReplayGainError(
+ "ReplayGain backend `{}` failed "
+ "for some tracks in album {}".format(
+ self.backend_name, self.album
+ )
+ )
+ for item, track_gain in zip(self.items, self.track_gains):
+ self._store_track_gain(item, track_gain)
+ self._store_album_gain(item, self.album_gain)
+ if write:
+ item.try_write()
+ self._log.debug("done analyzing {0}", item)
+
+ def store(self, write: bool):
+ """Store computed gains for the items of this task in the database."""
+ if self.album is not None:
+ self._store_album(write)
+ else:
+ self._store_track(write)
+
+
+class R128Task(RgTask):
+ """State and methods for a single replaygain calculation (r128 version).
+
+ Bundles the state (parameters and results) of a single replaygain
+ calculation (either for one item, one disk, or one full album).
+
+ This class provides methods to store the resulting gains and peaks as R128
+ tags.
+ """
+
+ def __init__(
+ self,
+ items: Sequence[Item],
+ album: Optional[Album],
+ target_level: float,
+ backend_name: str,
+ log: Logger,
+ ):
+ # R128_* tags do not store the track/album peak
+ super().__init__(items, album, target_level, None, backend_name, log)
+
+ def _store_track_gain(self, item: Item, track_gain: Gain):
+ item.r128_track_gain = track_gain.gain
+ item.store()
+ self._log.debug("applied r128 track gain {0} LU", item.r128_track_gain)
+
+ def _store_album_gain(self, item: Item, album_gain: Gain):
+ """
+
+ The caller needs to ensure that `self.album_gain is not None`.
+ """
+ item.r128_album_gain = album_gain.gain
+ item.store()
+ self._log.debug("applied r128 album gain {0} LU", item.r128_album_gain)
+
+
+AnyRgTask = TypeVar("AnyRgTask", bound=RgTask)
+
+
+class Backend(ABC):
+ """An abstract class representing engine for calculating RG values."""
+
+ NAME = ""
do_parallel = False
- def __init__(self, config, log):
+ def __init__(self, config: ConfigView, log: Logger):
"""Initialize the backend with the configuration view for the
plugin.
"""
self._log = log
- def compute_track_gain(self, items, target_level, peak):
- """Computes the track gain of the given tracks, returns a list
- of Gain objects.
+ @abstractmethod
+ def compute_track_gain(self, task: AnyRgTask) -> AnyRgTask:
+ """Computes the track gain for the tracks belonging to `task`, and sets
+ the `track_gains` attribute on the task. Returns `task`.
"""
raise NotImplementedError()
- def compute_album_gain(self, items, target_level, peak):
- """Computes the album gain of the given album, returns an
- AlbumGain object.
+ @abstractmethod
+ def compute_album_gain(self, task: AnyRgTask) -> AnyRgTask:
+ """Computes the album gain for the album belonging to `task`, and sets
+ the `album_gain` attribute on the task. Returns `task`.
"""
raise NotImplementedError()
# ffmpeg backend
class FfmpegBackend(Backend):
- """A replaygain backend using ffmpeg's ebur128 filter.
- """
+ """A replaygain backend using ffmpeg's ebur128 filter."""
+ NAME = "ffmpeg"
do_parallel = True
- def __init__(self, config, log):
+ def __init__(self, config: ConfigView, log: Logger):
super().__init__(config, log)
self._ffmpeg_path = "ffmpeg"
# check that ffmpeg is installed
try:
- ffmpeg_version_out = call([self._ffmpeg_path, "-version"])
+ ffmpeg_version_out = call([self._ffmpeg_path, "-version"], log)
except OSError:
raise FatalReplayGainError(
f"could not find ffmpeg at {self._ffmpeg_path}"
@@ -165,81 +324,95 @@ class FfmpegBackend(Backend):
"the --enable-libebur128 configuration option is required."
)
- def compute_track_gain(self, items, target_level, peak):
- """Computes the track gain of the given tracks, returns a list
- of Gain objects (the track gains).
+ def compute_track_gain(self, task: AnyRgTask) -> AnyRgTask:
+ """Computes the track gain for the tracks belonging to `task`, and sets
+ the `track_gains` attribute on the task. Returns `task`.
"""
- gains = []
- for item in items:
- gains.append(
- self._analyse_item(
- item,
- target_level,
- peak,
- count_blocks=False,
- )[0] # take only the gain, discarding number of gating blocks
- )
- return gains
+ task.track_gains = [
+ self._analyse_item(
+ item,
+ task.target_level,
+ task.peak_method,
+ count_blocks=False,
+ )[
+ 0
+ ] # take only the gain, discarding number of gating blocks
+ for item in task.items
+ ]
- def compute_album_gain(self, items, target_level, peak):
- """Computes the album gain of the given album, returns an
- AlbumGain object.
+ return task
+
+ def compute_album_gain(self, task: AnyRgTask) -> AnyRgTask:
+ """Computes the album gain for the album belonging to `task`, and sets
+ the `album_gain` attribute on the task. Returns `task`.
"""
- target_level_lufs = db_to_lufs(target_level)
+ target_level_lufs = db_to_lufs(task.target_level)
# analyse tracks
- # list of track Gain objects
- track_gains = []
- # maximum peak
- album_peak = 0
- # sum of BS.1770 gating block powers
- sum_powers = 0
- # total number of BS.1770 gating blocks
- n_blocks = 0
-
- for item in items:
- track_gain, track_n_blocks = self._analyse_item(
- item, target_level, peak
+ # Gives a list of tuples (track_gain, track_n_blocks)
+ track_results: List[Tuple[Gain, int]] = [
+ self._analyse_item(
+ item,
+ task.target_level,
+ task.peak_method,
+ count_blocks=True,
)
- track_gains.append(track_gain)
+ for item in task.items
+ ]
- # album peak is maximum track peak
- album_peak = max(album_peak, track_gain.peak)
+ track_gains: List[Gain] = [tg for tg, _nb in track_results]
- # prepare album_gain calculation
- # total number of blocks is sum of track blocks
- n_blocks += track_n_blocks
+ # Album peak is maximum track peak
+ album_peak = max(tg.peak for tg in track_gains)
+ # Total number of BS.1770 gating blocks
+ n_blocks = sum(nb for _tg, nb in track_results)
+
+ def sum_of_track_powers(track_gain: Gain, track_n_blocks: int):
# convert `LU to target_level` -> LUFS
- track_loudness = target_level_lufs - track_gain.gain
+ loudness = target_level_lufs - track_gain.gain
+
# This reverses ITU-R BS.1770-4 p. 6 equation (5) to convert
# from loudness to power. The result is the average gating
# block power.
- track_power = 10**((track_loudness + 0.691) / 10)
+ power = 10 ** ((loudness + 0.691) / 10)
- # Weight that average power by the number of gating blocks to
- # get the sum of all their powers. Add that to the sum of all
- # block powers in this album.
- sum_powers += track_power * track_n_blocks
+ # Multiply that average power by the number of gating blocks to get
+ # the sum of all block powers in this track.
+ return track_n_blocks * power
# calculate album gain
if n_blocks > 0:
+ # Sum over all tracks to get the sum of BS.1770 gating block powers
+ # for the entire album.
+ sum_powers = sum(
+ sum_of_track_powers(tg, nb) for tg, nb in track_results
+ )
+
# compare ITU-R BS.1770-4 p. 6 equation (5)
# Album gain is the replaygain of the concatenation of all tracks.
album_gain = -0.691 + 10 * math.log10(sum_powers / n_blocks)
else:
album_gain = -70
+
# convert LUFS -> `LU to target_level`
album_gain = target_level_lufs - album_gain
self._log.debug(
- "{}: gain {} LU, peak {}"
- .format(items, album_gain, album_peak)
- )
+ "{}: gain {} LU, peak {}",
+ task.album,
+ album_gain,
+ album_peak,
+ )
- return AlbumGain(Gain(album_gain, album_peak), track_gains)
+ task.album_gain = Gain(album_gain, album_peak)
+ task.track_gains = track_gains
- def _construct_cmd(self, item, peak_method):
+ return task
+
+ def _construct_cmd(
+ self, item: Item, peak_method: Optional[PeakMethod]
+ ) -> List[Union[str, bytes]]:
"""Construct the shell command to analyse items."""
return [
self._ffmpeg_path,
@@ -250,13 +423,21 @@ class FfmpegBackend(Backend):
"-map",
"a:0",
"-filter",
- f"ebur128=peak={peak_method}",
+ "ebur128=peak={}".format(
+ "none" if peak_method is None else peak_method.name
+ ),
"-f",
"null",
"-",
]
- def _analyse_item(self, item, target_level, peak, count_blocks=True):
+ def _analyse_item(
+ self,
+ item: Item,
+ target_level: float,
+ peak_method: Optional[PeakMethod],
+ count_blocks: bool = True,
+ ) -> Tuple[Gain, int]:
"""Analyse item. Return a pair of a Gain object and the number
of gating blocks above the threshold.
@@ -264,44 +445,51 @@ class FfmpegBackend(Backend):
will be 0.
"""
target_level_lufs = db_to_lufs(target_level)
- peak_method = peak.name
# call ffmpeg
self._log.debug(f"analyzing {item}")
cmd = self._construct_cmd(item, peak_method)
- self._log.debug(
- 'executing {0}', ' '.join(map(displayable_path, cmd))
- )
- output = call(cmd).stderr.splitlines()
+ self._log.debug("executing {0}", " ".join(map(displayable_path, cmd)))
+ output = call(cmd, self._log).stderr.splitlines()
# parse output
- if peak == Peak.none:
- peak = 0
+ if peak_method is None:
+ peak = 0.0
else:
line_peak = self._find_line(
output,
- f" {peak_method.capitalize()} peak:".encode(),
- start_line=len(output) - 1, step_size=-1,
+ # `peak_method` is non-`None` in this arm of the conditional
+ f" {peak_method.name.capitalize()} peak:".encode(),
+ start_line=len(output) - 1,
+ step_size=-1,
)
peak = self._parse_float(
- output[self._find_line(
- output, b" Peak:",
- line_peak,
- )]
+ output[
+ self._find_line(
+ output,
+ b" Peak:",
+ line_peak,
+ )
+ ]
)
# convert TPFS -> part of FS
- peak = 10**(peak / 20)
+ peak = 10 ** (peak / 20)
line_integrated_loudness = self._find_line(
- output, b" Integrated loudness:",
- start_line=len(output) - 1, step_size=-1,
+ output,
+ b" Integrated loudness:",
+ start_line=len(output) - 1,
+ step_size=-1,
)
gain = self._parse_float(
- output[self._find_line(
- output, b" I:",
- line_integrated_loudness,
- )]
+ output[
+ self._find_line(
+ output,
+ b" I:",
+ line_integrated_loudness,
+ )
+ ]
)
# convert LUFS -> LU from target level
gain = target_level_lufs - gain
@@ -310,10 +498,13 @@ class FfmpegBackend(Backend):
n_blocks = 0
if count_blocks:
gating_threshold = self._parse_float(
- output[self._find_line(
- output, b" Threshold:",
- start_line=line_integrated_loudness,
- )]
+ output[
+ self._find_line(
+ output,
+ b" Threshold:",
+ start_line=line_integrated_loudness,
+ )
+ ]
)
for line in output:
if not line.startswith(b"[Parsed_ebur128"):
@@ -326,18 +517,22 @@ class FfmpegBackend(Backend):
if self._parse_float(b"M: " + line[1]) >= gating_threshold:
n_blocks += 1
self._log.debug(
- "{}: {} blocks over {} LUFS"
- .format(item, n_blocks, gating_threshold)
+ "{}: {} blocks over {} LUFS".format(
+ item, n_blocks, gating_threshold
+ )
)
- self._log.debug(
- "{}: gain {} LU, peak {}"
- .format(item, gain, peak)
- )
+ self._log.debug("{}: gain {} LU, peak {}".format(item, gain, peak))
return Gain(gain, peak), n_blocks
- def _find_line(self, output, search, start_line=0, step_size=1):
+ def _find_line(
+ self,
+ output: Sequence[bytes],
+ search: bytes,
+ start_line: int = 0,
+ step_size: int = 1,
+ ) -> int:
"""Return index of line beginning with `search`.
Begins searching at index `start_line` in `output`.
@@ -347,24 +542,24 @@ class FfmpegBackend(Backend):
if output[i].startswith(search):
return i
raise ReplayGainError(
- "ffmpeg output: missing {} after line {}"
- .format(repr(search), start_line)
+ "ffmpeg output: missing {} after line {}".format(
+ repr(search), start_line
)
+ )
- def _parse_float(self, line):
+ def _parse_float(self, line: bytes) -> float:
"""Extract a float from a key value pair in `line`.
This format is expected: /[^:]:[[:space:]]*value.*/, where `value` is
the float.
"""
# extract value
- value = line.split(b":", 1)
- if len(value) < 2:
+ parts = line.split(b":", 1)
+ if len(parts) < 2:
raise ReplayGainError(
- "ffmpeg output: expected key value pair, found {}"
- .format(line)
- )
- value = value[1].lstrip()
+ f"ffmpeg output: expected key value pair, found {line!r}"
+ )
+ value = parts[1].lstrip()
# strip unit
value = value.split(b" ", 1)[0]
# cast value to float
@@ -372,87 +567,97 @@ class FfmpegBackend(Backend):
return float(value)
except ValueError:
raise ReplayGainError(
- "ffmpeg output: expected float value, found {}"
- .format(value)
- )
+ f"ffmpeg output: expected float value, found {value!r}"
+ )
# mpgain/aacgain CLI tool backend.
class CommandBackend(Backend):
+ NAME = "command"
do_parallel = True
- def __init__(self, config, log):
+ def __init__(self, config: ConfigView, log: Logger):
super().__init__(config, log)
- config.add({
- 'command': "",
- 'noclip': True,
- })
+ config.add(
+ {
+ "command": "",
+ "noclip": True,
+ }
+ )
- self.command = config["command"].as_str()
+ self.command = cast(str, config["command"].as_str())
if self.command:
# Explicit executable path.
if not os.path.isfile(self.command):
raise FatalReplayGainError(
- 'replaygain command does not exist: {}'.format(
- self.command)
+ "replaygain command does not exist: {}".format(self.command)
)
else:
# Check whether the program is in $PATH.
- for cmd in ('mp3gain', 'aacgain'):
+ for cmd in ("mp3gain", "aacgain"):
try:
- call([cmd, '-v'])
+ call([cmd, "-v"], self._log)
self.command = cmd
except OSError:
pass
if not self.command:
raise FatalReplayGainError(
- 'no replaygain command found: install mp3gain or aacgain'
+ "no replaygain command found: install mp3gain or aacgain"
)
- self.noclip = config['noclip'].get(bool)
+ self.noclip = config["noclip"].get(bool)
- def compute_track_gain(self, items, target_level, peak):
- """Computes the track gain of the given tracks, returns a list
- of TrackGain objects.
+ def compute_track_gain(self, task: AnyRgTask) -> AnyRgTask:
+ """Computes the track gain for the tracks belonging to `task`, and sets
+ the `track_gains` attribute on the task. Returns `task`.
"""
- supported_items = list(filter(self.format_supported, items))
- output = self.compute_gain(supported_items, target_level, False)
- return output
+ supported_items = list(filter(self.format_supported, task.items))
+ output = self.compute_gain(supported_items, task.target_level, False)
+ task.track_gains = output
+ return task
- def compute_album_gain(self, items, target_level, peak):
- """Computes the album gain of the given album, returns an
- AlbumGain object.
+ def compute_album_gain(self, task: AnyRgTask) -> AnyRgTask:
+ """Computes the album gain for the album belonging to `task`, and sets
+ the `album_gain` attribute on the task. Returns `task`.
"""
# TODO: What should be done when not all tracks in the album are
# supported?
- supported_items = list(filter(self.format_supported, items))
- if len(supported_items) != len(items):
- self._log.debug('tracks are of unsupported format')
- return AlbumGain(None, [])
+ supported_items = list(filter(self.format_supported, task.items))
+ if len(supported_items) != len(task.items):
+ self._log.debug("tracks are of unsupported format")
+ task.album_gain = None
+ task.track_gains = None
+ return task
- output = self.compute_gain(supported_items, target_level, True)
- return AlbumGain(output[-1], output[:-1])
+ output = self.compute_gain(supported_items, task.target_level, True)
+ task.album_gain = output[-1]
+ task.track_gains = output[:-1]
+ return task
- def format_supported(self, item):
- """Checks whether the given item is supported by the selected tool.
- """
- if 'mp3gain' in self.command and item.format != 'MP3':
+ def format_supported(self, item: Item) -> bool:
+ """Checks whether the given item is supported by the selected tool."""
+ if "mp3gain" in self.command and item.format != "MP3":
return False
- elif 'aacgain' in self.command and item.format not in ('MP3', 'AAC'):
+ elif "aacgain" in self.command and item.format not in ("MP3", "AAC"):
return False
return True
- def compute_gain(self, items, target_level, is_album):
+ def compute_gain(
+ self,
+ items: Sequence[Item],
+ target_level: float,
+ is_album: bool,
+ ) -> List[Gain]:
"""Computes the track or album gain of a list of items, returns
a list of TrackGain objects.
When computing album gain, the last TrackGain object returned is
the album gain
"""
- if len(items) == 0:
- self._log.debug('no supported tracks to analyze')
+ if not items:
+ self._log.debug("no supported tracks to analyze")
return []
"""Compute ReplayGain values and return a list of results
@@ -464,51 +669,54 @@ class CommandBackend(Backend):
# tag-writing; this turns the mp3gain/aacgain tool into a gain
# calculator rather than a tag manipulator because we take care
# of changing tags ourselves.
- cmd = [self.command, '-o', '-s', 's']
+ cmd: List[Union[bytes, str]] = [self.command, "-o", "-s", "s"]
if self.noclip:
# Adjust to avoid clipping.
- cmd = cmd + ['-k']
+ cmd = cmd + ["-k"]
else:
# Disable clipping warning.
- cmd = cmd + ['-c']
- cmd = cmd + ['-d', str(int(target_level - 89))]
+ cmd = cmd + ["-c"]
+ cmd = cmd + ["-d", str(int(target_level - 89))]
cmd = cmd + [syspath(i.path) for i in items]
- self._log.debug('analyzing {0} files', len(items))
+ self._log.debug("analyzing {0} files", len(items))
self._log.debug("executing {0}", " ".join(map(displayable_path, cmd)))
- output = call(cmd).stdout
- self._log.debug('analysis finished')
- return self.parse_tool_output(output,
- len(items) + (1 if is_album else 0))
+ output = call(cmd, self._log).stdout
+ self._log.debug("analysis finished")
+ return self.parse_tool_output(
+ output, len(items) + (1 if is_album else 0)
+ )
- def parse_tool_output(self, text, num_lines):
+ def parse_tool_output(self, text: bytes, num_lines: int) -> List[Gain]:
"""Given the tab-delimited output from an invocation of mp3gain
or aacgain, parse the text and return a list of dictionaries
containing information about each analyzed file.
"""
out = []
- for line in text.split(b'\n')[1:num_lines + 1]:
- parts = line.split(b'\t')
- if len(parts) != 6 or parts[0] == b'File':
- self._log.debug('bad tool output: {0}', text)
- raise ReplayGainError('mp3gain failed')
- d = {
- 'file': parts[0],
- 'mp3gain': int(parts[1]),
- 'gain': float(parts[2]),
- 'peak': float(parts[3]) / (1 << 15),
- 'maxgain': int(parts[4]),
- 'mingain': int(parts[5]),
+ for line in text.split(b"\n")[1 : num_lines + 1]:
+ parts = line.split(b"\t")
+ if len(parts) != 6 or parts[0] == b"File":
+ self._log.debug("bad tool output: {0}", text)
+ raise ReplayGainError("mp3gain failed")
- }
- out.append(Gain(d['gain'], d['peak']))
+ # _file = parts[0]
+ # _mp3gain = int(parts[1])
+ gain = float(parts[2])
+ peak = float(parts[3]) / (1 << 15)
+ # _maxgain = int(parts[4])
+ # _mingain = int(parts[5])
+
+ out.append(Gain(gain, peak))
return out
# GStreamer-based backend.
+
class GStreamerBackend(Backend):
- def __init__(self, config, log):
+ NAME = "gstreamer"
+
+ def __init__(self, config: ConfigView, log: Logger):
super().__init__(config, log)
self._import_gst()
@@ -523,8 +731,13 @@ class GStreamerBackend(Backend):
self._res = self.Gst.ElementFactory.make("audioresample", "res")
self._rg = self.Gst.ElementFactory.make("rganalysis", "rg")
- if self._src is None or self._decbin is None or self._conv is None \
- or self._res is None or self._rg is None:
+ if (
+ self._src is None
+ or self._decbin is None
+ or self._conv is None
+ or self._res is None
+ or self._rg is None
+ ):
raise FatalGstreamerPluginReplayGainError(
"Failed to load required GStreamer plugins"
)
@@ -560,7 +773,7 @@ class GStreamerBackend(Backend):
self._main_loop = self.GLib.MainLoop()
- self._files = []
+ self._files: List[bytes] = []
def _import_gst(self):
"""Import the necessary GObject-related modules and assign `Gst`
@@ -575,13 +788,12 @@ class GStreamerBackend(Backend):
)
try:
- gi.require_version('Gst', '1.0')
+ gi.require_version("Gst", "1.0")
except ValueError as e:
- raise FatalReplayGainError(
- f"Failed to load GStreamer 1.0: {e}"
- )
+ raise FatalReplayGainError(f"Failed to load GStreamer 1.0: {e}")
+
+ from gi.repository import GLib, GObject, Gst
- from gi.repository import GObject, Gst, GLib
# Calling GObject.threads_init() is not needed for
# PyGObject 3.10.2+
with warnings.catch_warnings():
@@ -593,14 +805,17 @@ class GStreamerBackend(Backend):
self.GLib = GLib
self.Gst = Gst
- def compute(self, files, target_level, album):
- self._error = None
- self._files = list(files)
-
- if len(self._files) == 0:
+ def compute(self, items: Sequence[Item], target_level: float, album: bool):
+ if len(items) == 0:
return
- self._file_tags = collections.defaultdict(dict)
+ self._error = None
+ self._files = [i.path for i in items]
+
+ # FIXME: Turn this into DefaultDict[bytes, Gain]
+ self._file_tags: DefaultDict[bytes, Dict[str, float]] = (
+ collections.defaultdict(dict)
+ )
self._rg.set_property("reference-level", target_level)
@@ -612,21 +827,32 @@ class GStreamerBackend(Backend):
if self._error is not None:
raise self._error
- def compute_track_gain(self, items, target_level, peak):
- self.compute(items, target_level, False)
- if len(self._file_tags) != len(items):
+ def compute_track_gain(self, task: AnyRgTask) -> AnyRgTask:
+ """Computes the track gain for the tracks belonging to `task`, and sets
+ the `track_gains` attribute on the task. Returns `task`.
+ """
+ self.compute(task.items, task.target_level, False)
+ if len(self._file_tags) != len(task.items):
raise ReplayGainError("Some tracks did not receive tags")
ret = []
- for item in items:
- ret.append(Gain(self._file_tags[item]["TRACK_GAIN"],
- self._file_tags[item]["TRACK_PEAK"]))
+ for item in task.items:
+ ret.append(
+ Gain(
+ self._file_tags[item.path]["TRACK_GAIN"],
+ self._file_tags[item.path]["TRACK_PEAK"],
+ )
+ )
- return ret
+ task.track_gains = ret
+ return task
- def compute_album_gain(self, items, target_level, peak):
- items = list(items)
- self.compute(items, target_level, True)
+ def compute_album_gain(self, task: AnyRgTask) -> AnyRgTask:
+ """Computes the album gain for the album belonging to `task`, and sets
+ the `album_gain` attribute on the task. Returns `task`.
+ """
+ items = list(task.items)
+ self.compute(items, task.target_level, True)
if len(self._file_tags) != len(items):
raise ReplayGainError("Some items in album did not receive tags")
@@ -634,21 +860,23 @@ class GStreamerBackend(Backend):
track_gains = []
for item in items:
try:
- gain = self._file_tags[item]["TRACK_GAIN"]
- peak = self._file_tags[item]["TRACK_PEAK"]
+ gain = self._file_tags[item.path]["TRACK_GAIN"]
+ peak = self._file_tags[item.path]["TRACK_PEAK"]
except KeyError:
raise ReplayGainError("results missing for track")
track_gains.append(Gain(gain, peak))
# Get album gain information from the last track.
- last_tags = self._file_tags[items[-1]]
+ last_tags = self._file_tags[items[-1].path]
try:
gain = last_tags["ALBUM_GAIN"]
peak = last_tags["ALBUM_PEAK"]
except KeyError:
raise ReplayGainError("results missing for album")
- return AlbumGain(Gain(gain, peak), track_gains)
+ task.album_gain = Gain(gain, peak)
+ task.track_gains = track_gains
+ return task
def close(self):
self._bus.remove_signal_watch()
@@ -680,36 +908,40 @@ class GStreamerBackend(Backend):
# store the computed tags, we overwrite the RG values of
# received a second time.
if tag == self.Gst.TAG_TRACK_GAIN:
- self._file_tags[self._file]["TRACK_GAIN"] = \
- taglist.get_double(tag)[1]
+ self._file_tags[self._file]["TRACK_GAIN"] = taglist.get_double(
+ tag
+ )[1]
elif tag == self.Gst.TAG_TRACK_PEAK:
- self._file_tags[self._file]["TRACK_PEAK"] = \
- taglist.get_double(tag)[1]
+ self._file_tags[self._file]["TRACK_PEAK"] = taglist.get_double(
+ tag
+ )[1]
elif tag == self.Gst.TAG_ALBUM_GAIN:
- self._file_tags[self._file]["ALBUM_GAIN"] = \
- taglist.get_double(tag)[1]
+ self._file_tags[self._file]["ALBUM_GAIN"] = taglist.get_double(
+ tag
+ )[1]
elif tag == self.Gst.TAG_ALBUM_PEAK:
- self._file_tags[self._file]["ALBUM_PEAK"] = \
- taglist.get_double(tag)[1]
+ self._file_tags[self._file]["ALBUM_PEAK"] = taglist.get_double(
+ tag
+ )[1]
elif tag == self.Gst.TAG_REFERENCE_LEVEL:
- self._file_tags[self._file]["REFERENCE_LEVEL"] = \
+ self._file_tags[self._file]["REFERENCE_LEVEL"] = (
taglist.get_double(tag)[1]
+ )
tags.foreach(handle_tag, None)
- def _set_first_file(self):
+ def _set_first_file(self) -> bool:
if len(self._files) == 0:
return False
self._file = self._files.pop(0)
self._pipe.set_state(self.Gst.State.NULL)
- self._src.set_property("location", py3_path(syspath(self._file.path)))
+ self._src.set_property("location", os.fsdecode(syspath(self._file)))
self._pipe.set_state(self.Gst.State.PLAYING)
return True
- def _set_file(self):
- """Initialize the filesrc element with the next file to be analyzed.
- """
+ def _set_file(self) -> bool:
+ """Initialize the filesrc element with the next file to be analyzed."""
# No more files, we're done
if len(self._files) == 0:
return False
@@ -734,14 +966,14 @@ class GStreamerBackend(Backend):
# Set a new file on the filesrc element, can only be done in the
# READY state
self._src.set_state(self.Gst.State.READY)
- self._src.set_property("location", py3_path(syspath(self._file.path)))
+ self._src.set_property("location", os.fsdecode(syspath(self._file)))
self._decbin.link(self._conv)
self._pipe.set_state(self.Gst.State.READY)
return True
- def _set_next_file(self):
+ def _set_next_file(self) -> bool:
"""Set the next file to be analyzed while keeping the pipeline
in the PAUSED state so that the rganalysis element can correctly
handle album gain.
@@ -755,23 +987,23 @@ class GStreamerBackend(Backend):
if ret:
# Seek to the beginning in order to clear the EOS state of the
# various elements of the pipeline
- self._pipe.seek_simple(self.Gst.Format.TIME,
- self.Gst.SeekFlags.FLUSH,
- 0)
+ self._pipe.seek_simple(
+ self.Gst.Format.TIME, self.Gst.SeekFlags.FLUSH, 0
+ )
self._pipe.set_state(self.Gst.State.PLAYING)
return ret
def _on_pad_added(self, decbin, pad):
sink_pad = self._conv.get_compatible_pad(pad, None)
- assert(sink_pad is not None)
+ assert sink_pad is not None
pad.link(sink_pad)
def _on_pad_removed(self, decbin, pad):
# Called when the decodebin element is disconnected from the
# rest of the pipeline while switching input files
peer = pad.get_peer()
- assert(peer is None)
+ assert peer is None
class AudioToolsBackend(Backend):
@@ -780,7 +1012,9 @@ class AudioToolsBackend(Backend):
file formats and compute ReplayGain values using it replaygain module.
"""
- def __init__(self, config, log):
+ NAME = "audiotools"
+
+ def __init__(self, config: ConfigView, log: Logger):
super().__init__(config, log)
self._import_audiotools()
@@ -800,7 +1034,7 @@ class AudioToolsBackend(Backend):
self._mod_audiotools = audiotools
self._mod_replaygain = audiotools.replaygain
- def open_audio_file(self, item):
+ def open_audio_file(self, item: Item):
"""Open the file to read the PCM stream from the using
``item.path``.
@@ -810,19 +1044,17 @@ class AudioToolsBackend(Backend):
file format is not supported
"""
try:
- audiofile = self._mod_audiotools.open(py3_path(syspath(item.path)))
+ audiofile = self._mod_audiotools.open(
+ os.fsdecode(syspath(item.path))
+ )
except OSError:
- raise ReplayGainError(
- f"File {item.path} was not found"
- )
+ raise ReplayGainError(f"File {item.path} was not found")
except self._mod_audiotools.UnsupportedFile:
- raise ReplayGainError(
- f"Unsupported file type {item.format}"
- )
+ raise ReplayGainError(f"Unsupported file type {item.format}")
return audiofile
- def init_replaygain(self, audiofile, item):
+ def init_replaygain(self, audiofile, item: Item):
"""Return an initialized :class:`audiotools.replaygain.ReplayGain`
instance, which requires the sample rate of the song(s) on which
the ReplayGain values will be computed. The item is passed in case
@@ -835,26 +1067,28 @@ class AudioToolsBackend(Backend):
try:
rg = self._mod_replaygain.ReplayGain(audiofile.sample_rate())
except ValueError:
- raise ReplayGainError(
- f"Unsupported sample rate {item.samplerate}")
+ raise ReplayGainError(f"Unsupported sample rate {item.samplerate}")
return
return rg
- def compute_track_gain(self, items, target_level, peak):
- """Compute ReplayGain values for the requested items.
-
- :return list: list of :class:`Gain` objects
+ def compute_track_gain(self, task: AnyRgTask) -> AnyRgTask:
+ """Computes the track gain for the tracks belonging to `task`, and sets
+ the `track_gains` attribute on the task. Returns `task`.
"""
- return [self._compute_track_gain(item, target_level) for item in items]
+ gains = [
+ self._compute_track_gain(i, task.target_level) for i in task.items
+ ]
+ task.track_gains = gains
+ return task
- def _with_target_level(self, gain, target_level):
+ def _with_target_level(self, gain: float, target_level: float):
"""Return `gain` relative to `target_level`.
Assumes `gain` is relative to 89 db.
"""
return gain + (target_level - 89)
- def _title_gain(self, rg, audiofile, target_level):
+ def _title_gain(self, rg, audiofile, target_level: float):
"""Get the gain result pair from PyAudioTools using the `ReplayGain`
instance `rg` for the given `audiofile`.
@@ -868,11 +1102,11 @@ class AudioToolsBackend(Backend):
except ValueError as exc:
# `audiotools.replaygain` can raise a `ValueError` if the sample
# rate is incorrect.
- self._log.debug('error in rg.title_gain() call: {}', exc)
- raise ReplayGainError('audiotools audio data error')
+ self._log.debug("error in rg.title_gain() call: {}", exc)
+ raise ReplayGainError("audiotools audio data error")
return self._with_target_level(gain, target_level), peak
- def _compute_track_gain(self, item, target_level):
+ def _compute_track_gain(self, item: Item, target_level: float):
"""Compute ReplayGain value for the requested item.
:rtype: :class:`Gain`
@@ -886,53 +1120,64 @@ class AudioToolsBackend(Backend):
rg, audiofile, target_level
)
- self._log.debug('ReplayGain for track {0} - {1}: {2:.2f}, {3:.2f}',
- item.artist, item.title, rg_track_gain, rg_track_peak)
+ self._log.debug(
+ "ReplayGain for track {0} - {1}: {2:.2f}, {3:.2f}",
+ item.artist,
+ item.title,
+ rg_track_gain,
+ rg_track_peak,
+ )
return Gain(gain=rg_track_gain, peak=rg_track_peak)
- def compute_album_gain(self, items, target_level, peak):
- """Compute ReplayGain values for the requested album and its items.
-
- :rtype: :class:`AlbumGain`
+ def compute_album_gain(self, task: AnyRgTask) -> AnyRgTask:
+ """Computes the album gain for the album belonging to `task`, and sets
+ the `album_gain` attribute on the task. Returns `task`.
"""
# The first item is taken and opened to get the sample rate to
# initialize the replaygain object. The object is used for all the
# tracks in the album to get the album values.
- item = list(items)[0]
+ item = list(task.items)[0]
audiofile = self.open_audio_file(item)
rg = self.init_replaygain(audiofile, item)
track_gains = []
- for item in items:
+ for item in task.items:
audiofile = self.open_audio_file(item)
rg_track_gain, rg_track_peak = self._title_gain(
- rg, audiofile, target_level
+ rg, audiofile, task.target_level
)
- track_gains.append(
- Gain(gain=rg_track_gain, peak=rg_track_peak)
+ track_gains.append(Gain(gain=rg_track_gain, peak=rg_track_peak))
+ self._log.debug(
+ "ReplayGain for track {0}: {1:.2f}, {2:.2f}",
+ item,
+ rg_track_gain,
+ rg_track_peak,
)
- self._log.debug('ReplayGain for track {0}: {1:.2f}, {2:.2f}',
- item, rg_track_gain, rg_track_peak)
# After getting the values for all tracks, it's possible to get the
# album values.
rg_album_gain, rg_album_peak = rg.album_gain()
- rg_album_gain = self._with_target_level(rg_album_gain, target_level)
- self._log.debug('ReplayGain for album {0}: {1:.2f}, {2:.2f}',
- items[0].album, rg_album_gain, rg_album_peak)
-
- return AlbumGain(
- Gain(gain=rg_album_gain, peak=rg_album_peak),
- track_gains=track_gains
+ rg_album_gain = self._with_target_level(
+ rg_album_gain, task.target_level
)
+ self._log.debug(
+ "ReplayGain for album {0}: {1:.2f}, {2:.2f}",
+ task.items[0].album,
+ rg_album_gain,
+ rg_album_peak,
+ )
+
+ task.album_gain = Gain(gain=rg_album_gain, peak=rg_album_peak)
+ task.track_gains = track_gains
+ return task
class ExceptionWatcher(Thread):
"""Monitors a queue for exceptions asynchronously.
- Once an exception occurs, raise it and execute a callback.
+ Once an exception occurs, raise it and execute a callback.
"""
- def __init__(self, queue, callback):
+ def __init__(self, queue: queue.Queue, callback: Callable[[], None]):
self._queue = queue
self._callback = callback
self._stopevent = Event()
@@ -943,166 +1188,173 @@ class ExceptionWatcher(Thread):
try:
exc = self._queue.get_nowait()
self._callback()
- raise exc[1].with_traceback(exc[2])
+ raise exc
except queue.Empty:
# No exceptions yet, loop back to check
# whether `_stopevent` is set
pass
- def join(self, timeout=None):
+ def join(self, timeout: Optional[float] = None):
self._stopevent.set()
Thread.join(self, timeout)
# Main plugin logic.
+BACKEND_CLASSES: List[Type[Backend]] = [
+ CommandBackend,
+ GStreamerBackend,
+ AudioToolsBackend,
+ FfmpegBackend,
+]
+BACKENDS: Dict[str, Type[Backend]] = {b.NAME: b for b in BACKEND_CLASSES}
+
+
class ReplayGainPlugin(BeetsPlugin):
- """Provides ReplayGain analysis.
- """
-
- backends = {
- "command": CommandBackend,
- "gstreamer": GStreamerBackend,
- "audiotools": AudioToolsBackend,
- "ffmpeg": FfmpegBackend,
- }
-
- peak_methods = {
- "true": Peak.true,
- "sample": Peak.sample,
- }
+ """Provides ReplayGain analysis."""
def __init__(self):
super().__init__()
# default backend is 'command' for backward-compatibility.
- self.config.add({
- 'overwrite': False,
- 'auto': True,
- 'backend': 'command',
- 'threads': cpu_count(),
- 'parallel_on_import': False,
- 'per_disc': False,
- 'peak': 'true',
- 'targetlevel': 89,
- 'r128': ['Opus'],
- 'r128_targetlevel': lufs_to_db(-23),
- })
+ self.config.add(
+ {
+ "overwrite": False,
+ "auto": True,
+ "backend": "command",
+ "threads": os.cpu_count(),
+ "parallel_on_import": False,
+ "per_disc": False,
+ "peak": "true",
+ "targetlevel": 89,
+ "r128": ["Opus"],
+ "r128_targetlevel": lufs_to_db(-23),
+ }
+ )
- self.overwrite = self.config['overwrite'].get(bool)
- self.per_disc = self.config['per_disc'].get(bool)
+ # FIXME: Consider renaming the configuration option and deprecating the
+ # old name 'overwrite'.
+ self.force_on_import = cast(bool, self.config["overwrite"].get(bool))
# Remember which backend is used for CLI feedback
- self.backend_name = self.config['backend'].as_str()
+ self.backend_name = self.config["backend"].as_str()
- if self.backend_name not in self.backends:
+ if self.backend_name not in BACKENDS:
raise ui.UserError(
"Selected ReplayGain backend {} is not supported. "
"Please select one of: {}".format(
- self.backend_name,
- ', '.join(self.backends.keys())
+ self.backend_name, ", ".join(BACKENDS.keys())
)
)
+
+ # FIXME: Consider renaming the configuration option to 'peak_method'
+ # and deprecating the old name 'peak'.
peak_method = self.config["peak"].as_str()
- if peak_method not in self.peak_methods:
+ if peak_method not in PeakMethod.__members__:
raise ui.UserError(
"Selected ReplayGain peak method {} is not supported. "
"Please select one of: {}".format(
- peak_method,
- ', '.join(self.peak_methods.keys())
+ peak_method, ", ".join(PeakMethod.__members__)
)
)
- self._peak_method = self.peak_methods[peak_method]
+ # This only applies to plain old rg tags, r128 doesn't store peak
+ # values.
+ self.peak_method = PeakMethod[peak_method]
# On-import analysis.
- if self.config['auto']:
- self.register_listener('import_begin', self.import_begin)
- self.register_listener('import', self.import_end)
+ if self.config["auto"]:
+ self.register_listener("import_begin", self.import_begin)
+ self.register_listener("import", self.import_end)
self.import_stages = [self.imported]
# Formats to use R128.
- self.r128_whitelist = self.config['r128'].as_str_seq()
+ self.r128_whitelist = self.config["r128"].as_str_seq()
try:
- self.backend_instance = self.backends[self.backend_name](
+ self.backend_instance = BACKENDS[self.backend_name](
self.config, self._log
)
except (ReplayGainError, FatalReplayGainError) as e:
- raise ui.UserError(
- f'replaygain initialization failed: {e}')
+ raise ui.UserError(f"replaygain initialization failed: {e}")
- def should_use_r128(self, item):
+ # Start threadpool lazily.
+ self.pool = None
+
+ def should_use_r128(self, item: Item) -> bool:
"""Checks the plugin setting to decide whether the calculation
should be done using the EBU R128 standard and use R128_ tags instead.
"""
return item.format in self.r128_whitelist
- def track_requires_gain(self, item):
- return self.overwrite or \
- (self.should_use_r128(item) and not item.r128_track_gain) or \
- (not self.should_use_r128(item) and
- (not item.rg_track_gain or not item.rg_track_peak))
+ @staticmethod
+ def has_r128_track_data(item: Item) -> bool:
+ return item.r128_track_gain is not None
- def album_requires_gain(self, album):
+ @staticmethod
+ def has_rg_track_data(item: Item) -> bool:
+ return item.rg_track_gain is not None and item.rg_track_peak is not None
+
+ def track_requires_gain(self, item: Item) -> bool:
+ if self.should_use_r128(item):
+ if not self.has_r128_track_data(item):
+ return True
+ else:
+ if not self.has_rg_track_data(item):
+ return True
+
+ return False
+
+ @staticmethod
+ def has_r128_album_data(item: Item) -> bool:
+ return (
+ item.r128_track_gain is not None
+ and item.r128_album_gain is not None
+ )
+
+ @staticmethod
+ def has_rg_album_data(item: Item) -> bool:
+ return item.rg_album_gain is not None and item.rg_album_peak is not None
+
+ def album_requires_gain(self, album: Album) -> bool:
# Skip calculating gain only when *all* files don't need
# recalculation. This way, if any file among an album's tracks
# needs recalculation, we still get an accurate album gain
# value.
- return self.overwrite or \
- any([self.should_use_r128(item) and
- (not item.r128_track_gain or not item.r128_album_gain)
- for item in album.items()]) or \
- any([not self.should_use_r128(item) and
- (not item.rg_album_gain or not item.rg_album_peak)
- for item in album.items()])
+ for item in album.items():
+ if self.should_use_r128(item):
+ if not self.has_r128_album_data(item):
+ return True
+ else:
+ if not self.has_rg_album_data(item):
+ return True
- def store_track_gain(self, item, track_gain):
- item.rg_track_gain = track_gain.gain
- item.rg_track_peak = track_gain.peak
- item.store()
- self._log.debug('applied track gain {0} LU, peak {1} of FS',
- item.rg_track_gain, item.rg_track_peak)
+ return False
- def store_album_gain(self, item, album_gain):
- item.rg_album_gain = album_gain.gain
- item.rg_album_peak = album_gain.peak
- item.store()
- self._log.debug('applied album gain {0} LU, peak {1} of FS',
- item.rg_album_gain, item.rg_album_peak)
-
- def store_track_r128_gain(self, item, track_gain):
- item.r128_track_gain = track_gain.gain
- item.store()
-
- self._log.debug('applied r128 track gain {0} LU',
- item.r128_track_gain)
-
- def store_album_r128_gain(self, item, album_gain):
- item.r128_album_gain = album_gain.gain
- item.store()
- self._log.debug('applied r128 album gain {0} LU',
- item.r128_album_gain)
-
- def tag_specific_values(self, items):
- """Return some tag specific values.
-
- Returns a tuple (store_track_gain, store_album_gain, target_level,
- peak_method).
- """
- if any([self.should_use_r128(item) for item in items]):
- store_track_gain = self.store_track_r128_gain
- store_album_gain = self.store_album_r128_gain
- target_level = self.config['r128_targetlevel'].as_number()
- peak = Peak.none # R128_* tags do not store the track/album peak
+ def create_task(
+ self,
+ items: Sequence[Item],
+ use_r128: bool,
+ album: Optional[Album] = None,
+ ) -> RgTask:
+ if use_r128:
+ return R128Task(
+ items,
+ album,
+ self.config["r128_targetlevel"].as_number(),
+ self.backend_instance.NAME,
+ self._log,
+ )
else:
- store_track_gain = self.store_track_gain
- store_album_gain = self.store_album_gain
- target_level = self.config['targetlevel'].as_number()
- peak = self._peak_method
+ return RgTask(
+ items,
+ album,
+ self.config["targetlevel"].as_number(),
+ self.peak_method,
+ self.backend_instance.NAME,
+ self._log,
+ )
- return store_track_gain, store_album_gain, target_level, peak
-
- def handle_album(self, album, write, force=False):
+ def handle_album(self, album: Album, write: bool, force: bool = False):
"""Compute album and track replay gain store it in all of the
album's items.
@@ -1111,23 +1363,22 @@ class ReplayGainPlugin(BeetsPlugin):
items, nothing is done.
"""
if not force and not self.album_requires_gain(album):
- self._log.info('Skipping album {0}', album)
+ self._log.info("Skipping album {0}", album)
return
- if (any([self.should_use_r128(item) for item in album.items()]) and not
- all([self.should_use_r128(item) for item in album.items()])):
+ items_iter = iter(album.items())
+ use_r128 = self.should_use_r128(next(items_iter))
+ if any(use_r128 != self.should_use_r128(i) for i in items_iter):
self._log.error(
"Cannot calculate gain for album {0} (incompatible formats)",
- album)
+ album,
+ )
return
- self._log.info('analyzing {0}', album)
+ self._log.info("analyzing {0}", album)
- tag_vals = self.tag_specific_values(album.items())
- store_track_gain, store_album_gain, target_level, peak = tag_vals
-
- discs = {}
- if self.per_disc:
+ discs: Dict[int, List[Item]] = {}
+ if self.config["per_disc"].get(bool):
for item in album.items():
if discs.get(item.disc) is None:
discs[item.disc] = []
@@ -1135,43 +1386,24 @@ class ReplayGainPlugin(BeetsPlugin):
else:
discs[1] = album.items()
- for discnumber, items in discs.items():
- def _store_album(album_gain):
- if not album_gain or not album_gain.album_gain \
- or len(album_gain.track_gains) != len(items):
- # In some cases, backends fail to produce a valid
- # `album_gain` without throwing FatalReplayGainError
- # => raise non-fatal exception & continue
- raise ReplayGainError(
- "ReplayGain backend `{}` failed "
- "for some tracks in album {}"
- .format(self.backend_name, album)
- )
- for item, track_gain in zip(items,
- album_gain.track_gains):
- store_track_gain(item, track_gain)
- store_album_gain(item, album_gain.album_gain)
- if write:
- item.try_write()
- self._log.debug('done analyzing {0}', item)
+ def store_cb(task: RgTask):
+ task.store(write)
+ for discnumber, items in discs.items():
+ task = self.create_task(items, use_r128, album=album)
try:
self._apply(
- self.backend_instance.compute_album_gain, args=(),
- kwds={
- "items": list(items),
- "target_level": target_level,
- "peak": peak
- },
- callback=_store_album
+ self.backend_instance.compute_album_gain,
+ args=[task],
+ kwds={},
+ callback=store_cb,
)
except ReplayGainError as e:
self._log.info("ReplayGain error: {0}", e)
except FatalReplayGainError as e:
- raise ui.UserError(
- f"Fatal replay gain error: {e}")
+ raise ui.UserError(f"Fatal replay gain error: {e}")
- def handle_track(self, item, write, force=False):
+ def handle_track(self, item: Item, write: bool, force: bool = False):
"""Compute track replay gain and store it in the item.
If ``write`` is truthy then ``item.write()`` is called to write
@@ -1179,101 +1411,79 @@ class ReplayGainPlugin(BeetsPlugin):
in the item, nothing is done.
"""
if not force and not self.track_requires_gain(item):
- self._log.info('Skipping track {0}', item)
+ self._log.info("Skipping track {0}", item)
return
- tag_vals = self.tag_specific_values([item])
- store_track_gain, store_album_gain, target_level, peak = tag_vals
+ use_r128 = self.should_use_r128(item)
- def _store_track(track_gains):
- if not track_gains or len(track_gains) != 1:
- # In some cases, backends fail to produce a valid
- # `track_gains` without throwing FatalReplayGainError
- # => raise non-fatal exception & continue
- raise ReplayGainError(
- "ReplayGain backend `{}` failed for track {}"
- .format(self.backend_name, item)
- )
-
- store_track_gain(item, track_gains[0])
- if write:
- item.try_write()
- self._log.debug('done analyzing {0}', item)
+ def store_cb(task: RgTask):
+ task.store(write)
+ task = self.create_task([item], use_r128)
try:
self._apply(
- self.backend_instance.compute_track_gain, args=(),
- kwds={
- "items": [item],
- "target_level": target_level,
- "peak": peak,
- },
- callback=_store_track
+ self.backend_instance.compute_track_gain,
+ args=[task],
+ kwds={},
+ callback=store_cb,
)
except ReplayGainError as e:
self._log.info("ReplayGain error: {0}", e)
except FatalReplayGainError as e:
raise ui.UserError(f"Fatal replay gain error: {e}")
- def _has_pool(self):
- """Check whether a `ThreadPool` is running instance in `self.pool`
- """
- if hasattr(self, 'pool'):
- if isinstance(self.pool, ThreadPool) and self.pool._state == RUN:
- return True
- return False
-
- def open_pool(self, threads):
- """Open a `ThreadPool` instance in `self.pool`
- """
- if not self._has_pool() and self.backend_instance.do_parallel:
+ def open_pool(self, threads: int):
+ """Open a `ThreadPool` instance in `self.pool`"""
+ if self.pool is None and self.backend_instance.do_parallel:
self.pool = ThreadPool(threads)
- self.exc_queue = queue.Queue()
+ self.exc_queue: queue.Queue = queue.Queue()
signal.signal(signal.SIGINT, self._interrupt)
self.exc_watcher = ExceptionWatcher(
- self.exc_queue, # threads push exceptions here
- self.terminate_pool # abort once an exception occurs
+ self.exc_queue, # threads push exceptions here
+ self.terminate_pool, # abort once an exception occurs
)
self.exc_watcher.start()
- def _apply(self, func, args, kwds, callback):
- if self._has_pool():
- def catch_exc(func, exc_queue, log):
- """Wrapper to catch raised exceptions in threads
- """
- def wfunc(*args, **kwargs):
- try:
- return func(*args, **kwargs)
- except ReplayGainError as e:
- log.info(e.args[0]) # log non-fatal exceptions
- except Exception:
- exc_queue.put(sys.exc_info())
- return wfunc
+ def _apply(
+ self,
+ func: Callable[..., AnyRgTask],
+ args: List[Any],
+ kwds: Dict[str, Any],
+ callback: Callable[[AnyRgTask], Any],
+ ):
+ if self.pool is not None:
- # Wrap function and callback to catch exceptions
- func = catch_exc(func, self.exc_queue, self._log)
- callback = catch_exc(callback, self.exc_queue, self._log)
+ def handle_exc(exc):
+ """Handle exceptions in the async work."""
+ if isinstance(exc, ReplayGainError):
+ self._log.info(exc.args[0]) # Log non-fatal exceptions.
+ else:
+ self.exc_queue.put(exc)
- self.pool.apply_async(func, args, kwds, callback)
+ self.pool.apply_async(
+ func, args, kwds, callback, error_callback=handle_exc
+ )
else:
callback(func(*args, **kwds))
def terminate_pool(self):
- """Terminate the `ThreadPool` instance in `self.pool`
- (e.g. stop execution in case of exception)
+ """Forcibly terminate the `ThreadPool` instance in `self.pool`
+
+ Sends SIGTERM to all processes.
"""
- # Don't call self._as_pool() here,
- # self.pool._state may not be == RUN
- if hasattr(self, 'pool') and isinstance(self.pool, ThreadPool):
+ if self.pool is not None:
self.pool.terminate()
self.pool.join()
+ # Terminating the processes leaves the ExceptionWatcher's queues
+ # in an unknown state, so don't wait for it.
# self.exc_watcher.join()
+ self.pool = None
def _interrupt(self, signal, frame):
try:
- self._log.info('interrupted')
+ self._log.info("interrupted")
self.terminate_pool()
sys.exit(0)
except SystemExit:
@@ -1281,60 +1491,70 @@ class ReplayGainPlugin(BeetsPlugin):
pass
def close_pool(self):
- """Close the `ThreadPool` instance in `self.pool` (if there is one)
- """
- if self._has_pool():
+ """Regularly close the `ThreadPool` instance in `self.pool`."""
+ if self.pool is not None:
self.pool.close()
self.pool.join()
self.exc_watcher.join()
+ self.pool = None
- def import_begin(self, session):
- """Handle `import_begin` event -> open pool
- """
- threads = self.config['threads'].get(int)
+ def import_begin(self, session: ImportSession):
+ """Handle `import_begin` event -> open pool"""
+ threads = cast(int, self.config["threads"].get(int))
- if self.config['parallel_on_import'] \
- and self.config['auto'] \
- and threads:
+ if (
+ self.config["parallel_on_import"]
+ and self.config["auto"]
+ and threads
+ ):
self.open_pool(threads)
def import_end(self, paths):
- """Handle `import` event -> close pool
- """
+ """Handle `import` event -> close pool"""
self.close_pool()
- def imported(self, session, task):
- """Add replay gain info to items or albums of ``task``.
- """
- if self.config['auto']:
+ def imported(self, session: ImportSession, task: ImportTask):
+ """Add replay gain info to items or albums of ``task``."""
+ if self.config["auto"]:
if task.is_album:
- self.handle_album(task.album, False)
+ self.handle_album(task.album, False, self.force_on_import)
else:
- self.handle_track(task.item, False)
+ # Should be a SingletonImportTask
+ assert hasattr(task, "item")
+ self.handle_track(task.item, False, self.force_on_import)
- def command_func(self, lib, opts, args):
+ def command_func(
+ self,
+ lib: Library,
+ opts: optparse.Values,
+ args: List[str],
+ ):
try:
write = ui.should_write(opts.write)
force = opts.force
# Bypass self.open_pool() if called with `--threads 0`
if opts.threads != 0:
- threads = opts.threads or self.config['threads'].get(int)
+ threads = opts.threads or cast(
+ int, self.config["threads"].get(int)
+ )
self.open_pool(threads)
if opts.album:
albums = lib.albums(ui.decargs(args))
self._log.info(
- "Analyzing {} albums ~ {} backend..."
- .format(len(albums), self.backend_name)
+ "Analyzing {} albums ~ {} backend...".format(
+ len(albums), self.backend_name
+ )
)
for album in albums:
self.handle_album(album, write, force)
else:
items = lib.items(ui.decargs(args))
self._log.info(
- "Analyzing {} tracks ~ {} backend..."
- .format(len(items), self.backend_name)
+ "Analyzing {} tracks ~ {} backend...".format(
+ len(items), self.backend_name
+ )
)
for item in items:
self.handle_track(item, write, force)
@@ -1344,25 +1564,40 @@ class ReplayGainPlugin(BeetsPlugin):
# Silence interrupt exceptions
pass
- def commands(self):
- """Return the "replaygain" ui subcommand.
- """
- cmd = ui.Subcommand('replaygain', help='analyze for ReplayGain')
+ def commands(self) -> List[ui.Subcommand]:
+ """Return the "replaygain" ui subcommand."""
+ cmd = ui.Subcommand("replaygain", help="analyze for ReplayGain")
cmd.parser.add_album_option()
cmd.parser.add_option(
- "-t", "--threads", dest="threads", type=int,
- help='change the number of threads, \
- defaults to maximum available processors'
+ "-t",
+ "--threads",
+ dest="threads",
+ type=int,
+ help="change the number of threads, \
+ defaults to maximum available processors",
)
cmd.parser.add_option(
- "-f", "--force", dest="force", action="store_true", default=False,
+ "-f",
+ "--force",
+ dest="force",
+ action="store_true",
+ default=False,
help="analyze all files, including those that "
- "already have ReplayGain metadata")
+ "already have ReplayGain metadata",
+ )
cmd.parser.add_option(
- "-w", "--write", default=None, action="store_true",
- help="write new metadata to files' tags")
+ "-w",
+ "--write",
+ default=None,
+ action="store_true",
+ help="write new metadata to files' tags",
+ )
cmd.parser.add_option(
- "-W", "--nowrite", dest="write", action="store_false",
- help="don't write metadata (opposite of -w)")
+ "-W",
+ "--nowrite",
+ dest="write",
+ action="store_false",
+ help="don't write metadata (opposite of -w)",
+ )
cmd.func = self.command_func
return [cmd]
diff --git a/lib/beetsplug/rewrite.py b/lib/beetsplug/rewrite.py
index e02e4080..83829d65 100644
--- a/lib/beetsplug/rewrite.py
+++ b/lib/beetsplug/rewrite.py
@@ -19,9 +19,8 @@ formats.
import re
from collections import defaultdict
+from beets import library, ui
from beets.plugins import BeetsPlugin
-from beets import ui
-from beets import library
def rewriter(field, rules):
@@ -29,6 +28,7 @@ def rewriter(field, rules):
with the given rewriting rules. ``rules`` must be a list of
(pattern, replacement) pairs.
"""
+
def fieldfunc(item):
value = item._values_fixed[field]
for pattern, replacement in rules:
@@ -37,6 +37,7 @@ def rewriter(field, rules):
return replacement
# Not activated; return original value.
return value
+
return fieldfunc
@@ -55,15 +56,16 @@ class RewritePlugin(BeetsPlugin):
except ValueError:
raise ui.UserError("invalid rewrite specification")
if fieldname not in library.Item._fields:
- raise ui.UserError("invalid field name (%s) in rewriter" %
- fieldname)
- self._log.debug('adding template field {0}', key)
+ raise ui.UserError(
+ "invalid field name (%s) in rewriter" % fieldname
+ )
+ self._log.debug("adding template field {0}", key)
pattern = re.compile(pattern.lower())
rules[fieldname].append((pattern, value))
- if fieldname == 'artist':
+ if fieldname == "artist":
# Special case for the artist field: apply the same
# rewrite for "albumartist" as well.
- rules['albumartist'].append((pattern, value))
+ rules["albumartist"].append((pattern, value))
# Replace each template field with the new rewriter function.
for fieldname, fieldrules in rules.items():
diff --git a/lib/beetsplug/scrub.py b/lib/beetsplug/scrub.py
index d8044668..d1e63ee3 100644
--- a/lib/beetsplug/scrub.py
+++ b/lib/beetsplug/scrub.py
@@ -17,74 +17,78 @@ automatically whenever tags are written.
"""
-from beets.plugins import BeetsPlugin
-from beets import ui
-from beets import util
-from beets import config
import mediafile
import mutagen
+from beets import config, ui, util
+from beets.plugins import BeetsPlugin
+
_MUTAGEN_FORMATS = {
- 'asf': 'ASF',
- 'apev2': 'APEv2File',
- 'flac': 'FLAC',
- 'id3': 'ID3FileType',
- 'mp3': 'MP3',
- 'mp4': 'MP4',
- 'oggflac': 'OggFLAC',
- 'oggspeex': 'OggSpeex',
- 'oggtheora': 'OggTheora',
- 'oggvorbis': 'OggVorbis',
- 'oggopus': 'OggOpus',
- 'trueaudio': 'TrueAudio',
- 'wavpack': 'WavPack',
- 'monkeysaudio': 'MonkeysAudio',
- 'optimfrog': 'OptimFROG',
+ "asf": "ASF",
+ "apev2": "APEv2File",
+ "flac": "FLAC",
+ "id3": "ID3FileType",
+ "mp3": "MP3",
+ "mp4": "MP4",
+ "oggflac": "OggFLAC",
+ "oggspeex": "OggSpeex",
+ "oggtheora": "OggTheora",
+ "oggvorbis": "OggVorbis",
+ "oggopus": "OggOpus",
+ "trueaudio": "TrueAudio",
+ "wavpack": "WavPack",
+ "monkeysaudio": "MonkeysAudio",
+ "optimfrog": "OptimFROG",
}
class ScrubPlugin(BeetsPlugin):
"""Removes extraneous metadata from files' tags."""
+
def __init__(self):
super().__init__()
- self.config.add({
- 'auto': True,
- })
+ self.config.add(
+ {
+ "auto": True,
+ }
+ )
- if self.config['auto']:
+ if self.config["auto"]:
self.register_listener("import_task_files", self.import_task_files)
def commands(self):
def scrub_func(lib, opts, args):
# Walk through matching files and remove tags.
for item in lib.items(ui.decargs(args)):
- self._log.info('scrubbing: {0}',
- util.displayable_path(item.path))
+ self._log.info(
+ "scrubbing: {0}", util.displayable_path(item.path)
+ )
self._scrub_item(item, opts.write)
- scrub_cmd = ui.Subcommand('scrub', help='clean audio tags')
+ scrub_cmd = ui.Subcommand("scrub", help="clean audio tags")
scrub_cmd.parser.add_option(
- '-W', '--nowrite', dest='write',
- action='store_false', default=True,
- help='leave tags empty')
+ "-W",
+ "--nowrite",
+ dest="write",
+ action="store_false",
+ default=True,
+ help="leave tags empty",
+ )
scrub_cmd.func = scrub_func
return [scrub_cmd]
@staticmethod
def _mutagen_classes():
- """Get a list of file type classes from the Mutagen module.
- """
+ """Get a list of file type classes from the Mutagen module."""
classes = []
for modname, clsname in _MUTAGEN_FORMATS.items():
- mod = __import__(f'mutagen.{modname}',
- fromlist=[clsname])
+ mod = __import__(f"mutagen.{modname}", fromlist=[clsname])
classes.append(getattr(mod, clsname))
return classes
def _scrub(self, path):
- """Remove all tags from a file.
- """
+ """Remove all tags from a file."""
for cls in self._mutagen_classes():
# Try opening the file with this type, but just skip in the
# event of any error.
@@ -106,21 +110,22 @@ class ScrubPlugin(BeetsPlugin):
del f[tag]
f.save()
except (OSError, mutagen.MutagenError) as exc:
- self._log.error('could not scrub {0}: {1}',
- util.displayable_path(path), exc)
+ self._log.error(
+ "could not scrub {0}: {1}", util.displayable_path(path), exc
+ )
- def _scrub_item(self, item, restore=True):
+ def _scrub_item(self, item, restore):
"""Remove tags from an Item's associated file and, if `restore`
is enabled, write the database's tags back to the file.
"""
# Get album art if we need to restore it.
if restore:
try:
- mf = mediafile.MediaFile(util.syspath(item.path),
- config['id3v23'].get(bool))
+ mf = mediafile.MediaFile(
+ util.syspath(item.path), config["id3v23"].get(bool)
+ )
except mediafile.UnreadableFileError as exc:
- self._log.error('could not open file to scrub: {0}',
- exc)
+ self._log.error("could not open file to scrub: {0}", exc)
return
images = mf.images
@@ -129,21 +134,23 @@ class ScrubPlugin(BeetsPlugin):
# Restore tags, if enabled.
if restore:
- self._log.debug('writing new tags after scrub')
+ self._log.debug("writing new tags after scrub")
item.try_write()
if images:
- self._log.debug('restoring art')
+ self._log.debug("restoring art")
try:
- mf = mediafile.MediaFile(util.syspath(item.path),
- config['id3v23'].get(bool))
+ mf = mediafile.MediaFile(
+ util.syspath(item.path), config["id3v23"].get(bool)
+ )
mf.images = images
mf.save()
except mediafile.UnreadableFileError as exc:
- self._log.error('could not write tags: {0}', exc)
+ self._log.error("could not write tags: {0}", exc)
def import_task_files(self, session, task):
"""Automatically scrub imported files."""
for item in task.imported_items():
- self._log.debug('auto-scrubbing {0}',
- util.displayable_path(item.path))
- self._scrub_item(item)
+ self._log.debug(
+ "auto-scrubbing {0}", util.displayable_path(item.path)
+ )
+ self._scrub_item(item, ui.should_write())
diff --git a/lib/beetsplug/smartplaylist.py b/lib/beetsplug/smartplaylist.py
index 4c921ecc..9df2cca6 100644
--- a/lib/beetsplug/smartplaylist.py
+++ b/lib/beetsplug/smartplaylist.py
@@ -16,48 +16,112 @@
"""
-from beets.plugins import BeetsPlugin
+import json
+import os
+from urllib.request import pathname2url
+
from beets import ui
-from beets.util import (mkdirall, normpath, sanitize_path, syspath,
- bytestring_path, path_as_posix)
-from beets.library import Item, Album, parse_query_string
from beets.dbcore import OrQuery
from beets.dbcore.query import MultipleSort, ParsingError
-import os
-
-try:
- from urllib.request import pathname2url
-except ImportError:
- # python2 is a bit different
- from urllib import pathname2url
+from beets.library import Album, Item, parse_query_string
+from beets.plugins import BeetsPlugin
+from beets.plugins import send as send_event
+from beets.util import (
+ bytestring_path,
+ displayable_path,
+ mkdirall,
+ normpath,
+ path_as_posix,
+ sanitize_path,
+ syspath,
+)
class SmartPlaylistPlugin(BeetsPlugin):
-
def __init__(self):
super().__init__()
- self.config.add({
- 'relative_to': None,
- 'playlist_dir': '.',
- 'auto': True,
- 'playlists': [],
- 'forward_slash': False,
- 'prefix': '',
- 'urlencode': False,
- })
+ self.config.add(
+ {
+ "relative_to": None,
+ "playlist_dir": ".",
+ "auto": True,
+ "playlists": [],
+ "uri_format": None,
+ "fields": [],
+ "forward_slash": False,
+ "prefix": "",
+ "urlencode": False,
+ "pretend_paths": False,
+ "output": "m3u",
+ }
+ )
- self.config['prefix'].redact = True # May contain username/password.
+ self.config["prefix"].redact = True # May contain username/password.
self._matched_playlists = None
self._unmatched_playlists = None
- if self.config['auto']:
- self.register_listener('database_change', self.db_change)
+ if self.config["auto"]:
+ self.register_listener("database_change", self.db_change)
def commands(self):
spl_update = ui.Subcommand(
- 'splupdate',
- help='update the smart playlists. Playlist names may be '
- 'passed as arguments.'
+ "splupdate",
+ help="update the smart playlists. Playlist names may be "
+ "passed as arguments.",
+ )
+ spl_update.parser.add_option(
+ "-p",
+ "--pretend",
+ action="store_true",
+ help="display query results but don't write playlist files.",
+ )
+ spl_update.parser.add_option(
+ "--pretend-paths",
+ action="store_true",
+ dest="pretend_paths",
+ help="in pretend mode, log the playlist item URIs/paths.",
+ )
+ spl_update.parser.add_option(
+ "-d",
+ "--playlist-dir",
+ dest="playlist_dir",
+ metavar="PATH",
+ type="string",
+ help="directory to write the generated playlist files to.",
+ )
+ spl_update.parser.add_option(
+ "--relative-to",
+ dest="relative_to",
+ metavar="PATH",
+ type="string",
+ help="generate playlist item paths relative to this path.",
+ )
+ spl_update.parser.add_option(
+ "--prefix",
+ type="string",
+ help="prepend string to every path in the playlist file.",
+ )
+ spl_update.parser.add_option(
+ "--forward-slash",
+ action="store_true",
+ dest="forward_slash",
+ help="force forward slash in paths within playlists.",
+ )
+ spl_update.parser.add_option(
+ "--urlencode",
+ action="store_true",
+ help="URL-encode all paths.",
+ )
+ spl_update.parser.add_option(
+ "--uri-format",
+ dest="uri_format",
+ type="string",
+ help="playlist item URI template, e.g. http://beets:8337/item/$id/file.",
+ )
+ spl_update.parser.add_option(
+ "--output",
+ type="string",
+ help="specify the playlist format: m3u|extm3u.",
)
spl_update.func = self.update_cmd
return [spl_update]
@@ -70,13 +134,16 @@ class SmartPlaylistPlugin(BeetsPlugin):
if not a.endswith(".m3u"):
args.add(f"{a}.m3u")
- playlists = {(name, q, a_q)
- for name, q, a_q in self._unmatched_playlists
- if name in args}
+ playlists = {
+ (name, q, a_q)
+ for name, q, a_q in self._unmatched_playlists
+ if name in args
+ }
if not playlists:
raise ui.UserError(
- 'No playlist matching any of {} found'.format(
- [name for name, _, _ in self._unmatched_playlists])
+ "No playlist matching any of {} found".format(
+ [name for name, _, _ in self._unmatched_playlists]
+ )
)
self._matched_playlists = playlists
@@ -84,7 +151,13 @@ class SmartPlaylistPlugin(BeetsPlugin):
else:
self._matched_playlists = self._unmatched_playlists
- self.update_playlists(lib)
+ self.__apply_opts_to_config(opts)
+ self.update_playlists(lib, opts.pretend)
+
+ def __apply_opts_to_config(self, opts):
+ for k, v in opts.__dict__.items():
+ if v is not None and k in self.config:
+ self.config[k] = v
def build_queries(self):
"""
@@ -104,15 +177,14 @@ class SmartPlaylistPlugin(BeetsPlugin):
self._unmatched_playlists = set()
self._matched_playlists = set()
- for playlist in self.config['playlists'].get(list):
- if 'name' not in playlist:
+ for playlist in self.config["playlists"].get(list):
+ if "name" not in playlist:
self._log.warning("playlist configuration is missing name")
continue
- playlist_data = (playlist['name'],)
+ playlist_data = (playlist["name"],)
try:
- for key, model_cls in (('query', Item),
- ('album_query', Album)):
+ for key, model_cls in (("query", Item), ("album_query", Album)):
qs = playlist.get(key)
if qs is None:
query_and_sort = None, None
@@ -122,8 +194,9 @@ class SmartPlaylistPlugin(BeetsPlugin):
query_and_sort = parse_query_string(qs[0], model_cls)
else:
# multiple queries and sorts
- queries, sorts = zip(*(parse_query_string(q, model_cls)
- for q in qs))
+ queries, sorts = zip(
+ *(parse_query_string(q, model_cls) for q in qs)
+ )
query = OrQuery(queries)
final_sorts = []
for s in sorts:
@@ -135,7 +208,7 @@ class SmartPlaylistPlugin(BeetsPlugin):
if not final_sorts:
sort = None
elif len(final_sorts) == 1:
- sort, = final_sorts
+ (sort,) = final_sorts
else:
sort = MultipleSort(final_sorts)
query_and_sort = query, sort
@@ -143,8 +216,9 @@ class SmartPlaylistPlugin(BeetsPlugin):
playlist_data += (query_and_sort,)
except ParsingError as exc:
- self._log.warning("invalid query in playlist {}: {}",
- playlist['name'], exc)
+ self._log.warning(
+ "invalid query in playlist {}: {}", playlist["name"], exc
+ )
continue
self._unmatched_playlists.add(playlist_data)
@@ -163,20 +237,28 @@ class SmartPlaylistPlugin(BeetsPlugin):
for playlist in self._unmatched_playlists:
n, (q, _), (a_q, _) = playlist
if self.matches(model, q, a_q):
- self._log.debug(
- "{0} will be updated because of {1}", n, model)
+ self._log.debug("{0} will be updated because of {1}", n, model)
self._matched_playlists.add(playlist)
- self.register_listener('cli_exit', self.update_playlists)
+ self.register_listener("cli_exit", self.update_playlists)
self._unmatched_playlists -= self._matched_playlists
- def update_playlists(self, lib):
- self._log.info("Updating {0} smart playlists...",
- len(self._matched_playlists))
+ def update_playlists(self, lib, pretend=False):
+ if pretend:
+ self._log.info(
+ "Showing query results for {0} smart playlists...",
+ len(self._matched_playlists),
+ )
+ else:
+ self._log.info(
+ "Updating {0} smart playlists...", len(self._matched_playlists)
+ )
- playlist_dir = self.config['playlist_dir'].as_filename()
+ playlist_dir = self.config["playlist_dir"].as_filename()
playlist_dir = bytestring_path(playlist_dir)
- relative_to = self.config['relative_to'].get()
+ tpl = self.config["uri_format"].get()
+ prefix = bytestring_path(self.config["prefix"].as_str())
+ relative_to = self.config["relative_to"].get()
if relative_to:
relative_to = normpath(relative_to)
@@ -185,7 +267,10 @@ class SmartPlaylistPlugin(BeetsPlugin):
for playlist in self._matched_playlists:
name, (query, q_sort), (album_query, a_q_sort) = playlist
- self._log.debug("Creating playlist {0}", name)
+ if pretend:
+ self._log.info("Results for playlist {}:", name)
+ else:
+ self._log.info("Creating playlist {0}", name)
items = []
if query:
@@ -201,24 +286,71 @@ class SmartPlaylistPlugin(BeetsPlugin):
m3u_name = sanitize_path(m3u_name, lib.replacements)
if m3u_name not in m3us:
m3us[m3u_name] = []
- item_path = item.path
- if relative_to:
- item_path = os.path.relpath(item.path, relative_to)
- if item_path not in m3us[m3u_name]:
- m3us[m3u_name].append(item_path)
+ item_uri = item.path
+ if tpl:
+ item_uri = tpl.replace("$id", str(item.id)).encode("utf-8")
+ else:
+ if relative_to:
+ item_uri = os.path.relpath(item_uri, relative_to)
+ if self.config["forward_slash"].get():
+ item_uri = path_as_posix(item_uri)
+ if self.config["urlencode"]:
+ item_uri = bytestring_path(pathname2url(item_uri))
+ item_uri = prefix + item_uri
- prefix = bytestring_path(self.config['prefix'].as_str())
- # Write all of the accumulated track lists to files.
- for m3u in m3us:
- m3u_path = normpath(os.path.join(playlist_dir,
- bytestring_path(m3u)))
- mkdirall(m3u_path)
- with open(syspath(m3u_path), 'wb') as f:
- for path in m3us[m3u]:
- if self.config['forward_slash'].get():
- path = path_as_posix(path)
- if self.config['urlencode']:
- path = bytestring_path(pathname2url(path))
- f.write(prefix + path + b'\n')
+ if item_uri not in m3us[m3u_name]:
+ m3us[m3u_name].append(PlaylistItem(item, item_uri))
+ if pretend and self.config["pretend_paths"]:
+ print(displayable_path(item_uri))
+ elif pretend:
+ print(item)
- self._log.info("{0} playlists updated", len(self._matched_playlists))
+ if not pretend:
+ # Write all of the accumulated track lists to files.
+ for m3u in m3us:
+ m3u_path = normpath(
+ os.path.join(playlist_dir, bytestring_path(m3u))
+ )
+ mkdirall(m3u_path)
+ pl_format = self.config["output"].get()
+ if pl_format != "m3u" and pl_format != "extm3u":
+ msg = "Unsupported output format '{}' provided! "
+ msg += "Supported: m3u, extm3u"
+ raise Exception(msg.format(pl_format))
+ extm3u = pl_format == "extm3u"
+ with open(syspath(m3u_path), "wb") as f:
+ keys = []
+ if extm3u:
+ keys = self.config["fields"].get(list)
+ f.write(b"#EXTM3U\n")
+ for entry in m3us[m3u]:
+ item = entry.item
+ comment = ""
+ if extm3u:
+ attr = [(k, entry.item[k]) for k in keys]
+ al = [
+ f" {a[0]}={json.dumps(str(a[1]))}" for a in attr
+ ]
+ attrs = "".join(al)
+ comment = "#EXTINF:{}{},{} - {}\n".format(
+ int(item.length), attrs, item.artist, item.title
+ )
+ f.write(comment.encode("utf-8") + entry.uri + b"\n")
+ # Send an event when playlists were updated.
+ send_event("smartplaylist_update")
+
+ if pretend:
+ self._log.info(
+ "Displayed results for {0} playlists",
+ len(self._matched_playlists),
+ )
+ else:
+ self._log.info(
+ "{0} playlists updated", len(self._matched_playlists)
+ )
+
+
+class PlaylistItem:
+ def __init__(self, item, uri):
+ self.item = item
+ self.uri = uri
diff --git a/lib/beetsplug/sonosupdate.py b/lib/beetsplug/sonosupdate.py
index aeb211d8..af3410ff 100644
--- a/lib/beetsplug/sonosupdate.py
+++ b/lib/beetsplug/sonosupdate.py
@@ -16,31 +16,32 @@
This is based on the Kodi Update plugin.
"""
-from beets.plugins import BeetsPlugin
import soco
+from beets.plugins import BeetsPlugin
+
class SonosUpdate(BeetsPlugin):
def __init__(self):
super().__init__()
- self.register_listener('database_change', self.listen_for_db_change)
+ self.register_listener("database_change", self.listen_for_db_change)
def listen_for_db_change(self, lib, model):
"""Listens for beets db change and register the update"""
- self.register_listener('cli_exit', self.update)
+ self.register_listener("cli_exit", self.update)
def update(self, lib):
"""When the client exists try to send refresh request to a Sonos
- controler.
+ controller.
"""
- self._log.info('Requesting a Sonos library update...')
+ self._log.info("Requesting a Sonos library update...")
device = soco.discovery.any_soco()
if device:
device.music_library.start_library_update()
else:
- self._log.warning('Could not find a Sonos device.')
+ self._log.warning("Could not find a Sonos device.")
return
- self._log.info('Sonos update triggered')
+ self._log.info("Sonos update triggered")
diff --git a/lib/beetsplug/spotify.py b/lib/beetsplug/spotify.py
index 2529160d..55a77a8a 100644
--- a/lib/beetsplug/spotify.py
+++ b/lib/beetsplug/spotify.py
@@ -1,5 +1,6 @@
# This file is part of beets.
# Copyright 2019, Rahul Ahuja.
+# Copyright 2022, Alok Saboo.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
@@ -16,59 +17,97 @@
Spotify playlist construction.
"""
-import re
-import json
import base64
-import webbrowser
import collections
+import json
+import re
+import time
+import webbrowser
-import unidecode
-import requests
import confuse
+import requests
+import unidecode
from beets import ui
from beets.autotag.hooks import AlbumInfo, TrackInfo
-from beets.plugins import MetadataSourcePlugin, BeetsPlugin
+from beets.dbcore import types
+from beets.library import DateType
+from beets.plugins import BeetsPlugin, MetadataSourcePlugin
+from beets.util.id_extractors import spotify_id_regex
+
+DEFAULT_WAITING_TIME = 5
+
+
+class SpotifyAPIError(Exception):
+ pass
class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
- data_source = 'Spotify'
+ data_source = "Spotify"
+
+ item_types = {
+ "spotify_track_popularity": types.INTEGER,
+ "spotify_acousticness": types.FLOAT,
+ "spotify_danceability": types.FLOAT,
+ "spotify_energy": types.FLOAT,
+ "spotify_instrumentalness": types.FLOAT,
+ "spotify_key": types.FLOAT,
+ "spotify_liveness": types.FLOAT,
+ "spotify_loudness": types.FLOAT,
+ "spotify_mode": types.INTEGER,
+ "spotify_speechiness": types.FLOAT,
+ "spotify_tempo": types.FLOAT,
+ "spotify_time_signature": types.INTEGER,
+ "spotify_valence": types.FLOAT,
+ "spotify_updated": DateType(),
+ }
# Base URLs for the Spotify API
# Documentation: https://developer.spotify.com/web-api
- oauth_token_url = 'https://accounts.spotify.com/api/token'
- open_track_url = 'https://open.spotify.com/track/'
- search_url = 'https://api.spotify.com/v1/search'
- album_url = 'https://api.spotify.com/v1/albums/'
- track_url = 'https://api.spotify.com/v1/tracks/'
+ oauth_token_url = "https://accounts.spotify.com/api/token"
+ open_track_url = "https://open.spotify.com/track/"
+ search_url = "https://api.spotify.com/v1/search"
+ album_url = "https://api.spotify.com/v1/albums/"
+ track_url = "https://api.spotify.com/v1/tracks/"
+ audio_features_url = "https://api.spotify.com/v1/audio-features/"
- # Spotify IDs consist of 22 alphanumeric characters
- # (zero-left-padded base62 representation of randomly generated UUID4)
- id_regex = {
- 'pattern': r'(^|open\.spotify\.com/{}/)([0-9A-Za-z]{{22}})',
- 'match_group': 2,
+ id_regex = spotify_id_regex
+
+ spotify_audio_features = {
+ "acousticness": "spotify_acousticness",
+ "danceability": "spotify_danceability",
+ "energy": "spotify_energy",
+ "instrumentalness": "spotify_instrumentalness",
+ "key": "spotify_key",
+ "liveness": "spotify_liveness",
+ "loudness": "spotify_loudness",
+ "mode": "spotify_mode",
+ "speechiness": "spotify_speechiness",
+ "tempo": "spotify_tempo",
+ "time_signature": "spotify_time_signature",
+ "valence": "spotify_valence",
}
def __init__(self):
super().__init__()
self.config.add(
{
- 'mode': 'list',
- 'tiebreak': 'popularity',
- 'show_failures': False,
- 'artist_field': 'albumartist',
- 'album_field': 'album',
- 'track_field': 'title',
- 'region_filter': None,
- 'regex': [],
- 'client_id': '4e414367a1d14c75a5c5129a627fcab8',
- 'client_secret': 'f82bdc09b2254f1a8286815d02fd46dc',
- 'tokenfile': 'spotify_token.json',
+ "mode": "list",
+ "tiebreak": "popularity",
+ "show_failures": False,
+ "artist_field": "albumartist",
+ "album_field": "album",
+ "track_field": "title",
+ "region_filter": None,
+ "regex": [],
+ "client_id": "4e414367a1d14c75a5c5129a627fcab8",
+ "client_secret": "f82bdc09b2254f1a8286815d02fd46dc",
+ "tokenfile": "spotify_token.json",
}
)
- self.config['client_secret'].redact = True
+ self.config["client_secret"].redact = True
- self.tokenfile = self.config['tokenfile'].get(
+ self.tokenfile = self.config["tokenfile"].get(
confuse.Filename(in_app_dir=True)
) # Path to the JSON file for storing the OAuth access token.
self.setup()
@@ -81,45 +120,46 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
except OSError:
self._authenticate()
else:
- self.access_token = token_data['access_token']
+ self.access_token = token_data["access_token"]
def _authenticate(self):
"""Request an access token via the Client Credentials Flow:
https://developer.spotify.com/documentation/general/guides/authorization-guide/#client-credentials-flow
"""
headers = {
- 'Authorization': 'Basic {}'.format(
+ "Authorization": "Basic {}".format(
base64.b64encode(
- ':'.join(
+ ":".join(
self.config[k].as_str()
- for k in ('client_id', 'client_secret')
+ for k in ("client_id", "client_secret")
).encode()
).decode()
)
}
response = requests.post(
self.oauth_token_url,
- data={'grant_type': 'client_credentials'},
+ data={"grant_type": "client_credentials"},
headers=headers,
+ timeout=10,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
raise ui.UserError(
- 'Spotify authorization failed: {}\n{}'.format(
- e, response.text
- )
+ "Spotify authorization failed: {}\n{}".format(e, response.text)
)
- self.access_token = response.json()['access_token']
+ self.access_token = response.json()["access_token"]
# Save the token for later use.
self._log.debug(
- '{} access token: {}', self.data_source, self.access_token
+ "{} access token: {}", self.data_source, self.access_token
)
- with open(self.tokenfile, 'w') as f:
- json.dump({'access_token': self.access_token}, f)
+ with open(self.tokenfile, "w") as f:
+ json.dump({"access_token": self.access_token}, f)
- def _handle_response(self, request_type, url, params=None):
+ def _handle_response(
+ self, request_type, url, params=None, retry_count=0, max_retries=3
+ ):
"""Send a request, reauthenticating if necessary.
:param request_type: Type of :class:`Request` constructor,
@@ -133,26 +173,65 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
:return: JSON data for the class:`Response ` object.
:rtype: dict
"""
- response = request_type(
- url,
- headers={'Authorization': f'Bearer {self.access_token}'},
- params=params,
- )
- if response.status_code != 200:
- if 'token expired' in response.text:
+ try:
+ response = request_type(
+ url,
+ headers={"Authorization": f"Bearer {self.access_token}"},
+ params=params,
+ timeout=10,
+ )
+ response.raise_for_status()
+ return response.json()
+ except requests.exceptions.ReadTimeout:
+ self._log.error("ReadTimeout.")
+ raise SpotifyAPIError("Request timed out.")
+ except requests.exceptions.ConnectionError as e:
+ self._log.error(f"Network error: {e}")
+ raise SpotifyAPIError("Network error.")
+ except requests.exceptions.RequestException as e:
+ if e.response.status_code == 401:
self._log.debug(
- '{} access token has expired. Reauthenticating.',
- self.data_source,
+ f"{self.data_source} access token has expired. "
+ f"Reauthenticating."
)
self._authenticate()
return self._handle_response(request_type, url, params=params)
- else:
- raise ui.UserError(
- '{} API error:\n{}\nURL:\n{}\nparams:\n{}'.format(
- self.data_source, response.text, url, params
- )
+ elif e.response.status_code == 404:
+ raise SpotifyAPIError(
+ f"API Error: {e.response.status_code}\n"
+ f"URL: {url}\nparams: {params}"
)
- return response.json()
+ elif e.response.status_code == 429:
+ if retry_count >= max_retries:
+ raise SpotifyAPIError("Maximum retries reached.")
+ seconds = response.headers.get(
+ "Retry-After", DEFAULT_WAITING_TIME
+ )
+ self._log.debug(
+ f"Too many API requests. Retrying after "
+ f"{seconds} seconds."
+ )
+ time.sleep(int(seconds) + 1)
+ return self._handle_response(
+ request_type,
+ url,
+ params=params,
+ retry_count=retry_count + 1,
+ )
+ elif e.response.status_code == 503:
+ self._log.error("Service Unavailable.")
+ raise SpotifyAPIError("Service Unavailable.")
+ elif e.response.status_code == 502:
+ self._log.error("Bad Gateway.")
+ raise SpotifyAPIError("Bad Gateway.")
+ elif e.response is not None:
+ raise SpotifyAPIError(
+ f"{self.data_source} API error:\n{e.response.text}\n"
+ f"URL:\n{url}\nparams:\n{params}"
+ )
+ else:
+ self._log.error(f"Request failed. Error: {e}")
+ raise SpotifyAPIError("Request failed.")
def album_for_id(self, album_id):
"""Fetch an album by its Spotify ID or URL and return an
@@ -163,26 +242,29 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
:return: AlbumInfo object for album
:rtype: beets.autotag.hooks.AlbumInfo or None
"""
- spotify_id = self._get_id('album', album_id)
+ spotify_id = self._get_id("album", album_id, self.id_regex)
if spotify_id is None:
return None
album_data = self._handle_response(
requests.get, self.album_url + spotify_id
)
- artist, artist_id = self.get_artist(album_data['artists'])
+ if album_data["name"] == "":
+ self._log.debug("Album removed from Spotify: {}", album_id)
+ return None
+ artist, artist_id = self.get_artist(album_data["artists"])
date_parts = [
- int(part) for part in album_data['release_date'].split('-')
+ int(part) for part in album_data["release_date"].split("-")
]
- release_date_precision = album_data['release_date_precision']
- if release_date_precision == 'day':
+ release_date_precision = album_data["release_date_precision"]
+ if release_date_precision == "day":
year, month, day = date_parts
- elif release_date_precision == 'month':
+ elif release_date_precision == "month":
year, month = date_parts
day = None
- elif release_date_precision == 'year':
+ elif release_date_precision == "year":
year = date_parts[0]
month = None
day = None
@@ -194,9 +276,17 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
)
)
+ tracks_data = album_data["tracks"]
+ tracks_items = tracks_data["items"]
+ while tracks_data["next"]:
+ tracks_data = self._handle_response(
+ requests.get, tracks_data["next"]
+ )
+ tracks_items.extend(tracks_data["items"])
+
tracks = []
medium_totals = collections.defaultdict(int)
- for i, track_data in enumerate(album_data['tracks']['items'], start=1):
+ for i, track_data in enumerate(tracks_items, start=1):
track = self._get_track(track_data)
track.index = i
medium_totals[track.medium] += 1
@@ -205,21 +295,23 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
track.medium_total = medium_totals[track.medium]
return AlbumInfo(
- album=album_data['name'],
+ album=album_data["name"],
album_id=spotify_id,
+ spotify_album_id=spotify_id,
artist=artist,
artist_id=artist_id,
+ spotify_artist_id=artist_id,
tracks=tracks,
- albumtype=album_data['album_type'],
- va=len(album_data['artists']) == 1
- and artist.lower() == 'various artists',
+ albumtype=album_data["album_type"],
+ va=len(album_data["artists"]) == 1
+ and artist.lower() == "various artists",
year=year,
month=month,
day=day,
- label=album_data['label'],
+ label=album_data["label"],
mediums=max(medium_totals.keys()),
data_source=self.data_source,
- data_url=album_data['external_urls']['spotify'],
+ data_url=album_data["external_urls"]["spotify"],
)
def _get_track(self, track_data):
@@ -231,18 +323,27 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
:return: TrackInfo object for track
:rtype: beets.autotag.hooks.TrackInfo
"""
- artist, artist_id = self.get_artist(track_data['artists'])
+ artist, artist_id = self.get_artist(track_data["artists"])
+
+ # Get album information for spotify tracks
+ try:
+ album = track_data["album"]["name"]
+ except (KeyError, TypeError):
+ album = None
return TrackInfo(
- title=track_data['name'],
- track_id=track_data['id'],
+ title=track_data["name"],
+ track_id=track_data["id"],
+ spotify_track_id=track_data["id"],
artist=artist,
+ album=album,
artist_id=artist_id,
- length=track_data['duration_ms'] / 1000,
- index=track_data['track_number'],
- medium=track_data['disc_number'],
- medium_index=track_data['track_number'],
+ spotify_artist_id=artist_id,
+ length=track_data["duration_ms"] / 1000,
+ index=track_data["track_number"],
+ medium=track_data["disc_number"],
+ medium_index=track_data["track_number"],
data_source=self.data_source,
- data_url=track_data['external_urls']['spotify'],
+ data_url=track_data["external_urls"]["spotify"],
)
def track_for_id(self, track_id=None, track_data=None):
@@ -259,7 +360,7 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
:rtype: beets.autotag.hooks.TrackInfo or None
"""
if track_data is None:
- spotify_id = self._get_id('track', track_id)
+ spotify_id = self._get_id("track", track_id, self.id_regex)
if spotify_id is None:
return None
track_data = self._handle_response(
@@ -271,19 +372,19 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
# release) and `track.medium_total` (total number of tracks on
# the track's disc).
album_data = self._handle_response(
- requests.get, self.album_url + track_data['album']['id']
+ requests.get, self.album_url + track_data["album"]["id"]
)
medium_total = 0
- for i, track_data in enumerate(album_data['tracks']['items'], start=1):
- if track_data['disc_number'] == track.medium:
+ for i, track_data in enumerate(album_data["tracks"]["items"], start=1):
+ if track_data["disc_number"] == track.medium:
medium_total += 1
- if track_data['id'] == track.track_id:
+ if track_data["id"] == track.track_id:
track.index = i
track.medium_total = medium_total
return track
@staticmethod
- def _construct_search_query(filters=None, keywords=''):
+ def _construct_search_query(filters=None, keywords=""):
"""Construct a query string with the specified filters and keywords to
be provided to the Spotify Search API
(https://developer.spotify.com/documentation/web-api/reference/search/search/#writing-a-query---guidelines).
@@ -297,16 +398,16 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
"""
query_components = [
keywords,
- ' '.join(':'.join((k, v)) for k, v in filters.items()),
+ " ".join(":".join((k, v)) for k, v in filters.items()),
]
- query = ' '.join([q for q in query_components if q])
+ query = " ".join([q for q in query_components if q])
if not isinstance(query, str):
- query = query.decode('utf8')
+ query = query.decode("utf8")
return unidecode.unidecode(query)
- def _search_api(self, query_type, filters=None, keywords=''):
- """Query the Spotify Search API for the specified ``keywords``, applying
- the provided ``filters``.
+ def _search_api(self, query_type, filters=None, keywords=""):
+ """Query the Spotify Search API for the specified ``keywords``,
+ applying the provided ``filters``.
:param query_type: Item type to search across. Valid types are:
'album', 'artist', 'playlist', and 'track'.
@@ -319,23 +420,20 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
if no search results are returned.
:rtype: dict or None
"""
- query = self._construct_search_query(
- keywords=keywords, filters=filters
- )
+ query = self._construct_search_query(keywords=keywords, filters=filters)
if not query:
return None
- self._log.debug(
- f"Searching {self.data_source} for '{query}'"
- )
- response_data = (
- self._handle_response(
+ self._log.debug(f"Searching {self.data_source} for '{query}'")
+ try:
+ response = self._handle_response(
requests.get,
self.search_url,
- params={'q': query, 'type': query_type},
+ params={"q": query, "type": query_type},
)
- .get(query_type + 's', {})
- .get('items', [])
- )
+ except SpotifyAPIError as e:
+ self._log.debug("Spotify API error: {}", e)
+ return []
+ response_data = response.get(query_type + "s", {}).get("items", [])
self._log.debug(
"Found {} result(s) from {} for '{}'",
len(response_data),
@@ -345,6 +443,7 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
return response_data
def commands(self):
+ # autotagger import command
def queries(lib, opts, args):
success = self._parse_opts(opts)
if success:
@@ -352,37 +451,56 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
self._output_match_results(results)
spotify_cmd = ui.Subcommand(
- 'spotify', help=f'build a {self.data_source} playlist'
+ "spotify", help=f"build a {self.data_source} playlist"
)
spotify_cmd.parser.add_option(
- '-m',
- '--mode',
- action='store',
+ "-m",
+ "--mode",
+ action="store",
help='"open" to open {} with playlist, '
'"list" to print (default)'.format(self.data_source),
)
spotify_cmd.parser.add_option(
- '-f',
- '--show-failures',
- action='store_true',
- dest='show_failures',
- help='list tracks that did not match a {} ID'.format(
+ "-f",
+ "--show-failures",
+ action="store_true",
+ dest="show_failures",
+ help="list tracks that did not match a {} ID".format(
self.data_source
),
)
spotify_cmd.func = queries
- return [spotify_cmd]
+
+ # spotifysync command
+ sync_cmd = ui.Subcommand(
+ "spotifysync", help="fetch track attributes from Spotify"
+ )
+ sync_cmd.parser.add_option(
+ "-f",
+ "--force",
+ dest="force_refetch",
+ action="store_true",
+ default=False,
+ help="re-download data when already present",
+ )
+
+ def func(lib, opts, args):
+ items = lib.items(ui.decargs(args))
+ self._fetch_info(items, ui.should_write(), opts.force_refetch)
+
+ sync_cmd.func = func
+ return [spotify_cmd, sync_cmd]
def _parse_opts(self, opts):
if opts.mode:
- self.config['mode'].set(opts.mode)
+ self.config["mode"].set(opts.mode)
if opts.show_failures:
- self.config['show_failures'].set(True)
+ self.config["show_failures"].set(True)
- if self.config['mode'].get() not in ['list', 'open']:
+ if self.config["mode"].get() not in ["list", "open"]:
self._log.warning(
- '{0} is not a valid mode', self.config['mode'].get()
+ "{0} is not a valid mode", self.config["mode"].get()
)
return False
@@ -408,37 +526,37 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
if not items:
self._log.debug(
- 'Your beets query returned no items, skipping {}.',
+ "Your beets query returned no items, skipping {}.",
self.data_source,
)
return
- self._log.info('Processing {} tracks...', len(items))
+ self._log.info("Processing {} tracks...", len(items))
for item in items:
# Apply regex transformations if provided
- for regex in self.config['regex'].get():
+ for regex in self.config["regex"].get():
if (
- not regex['field']
- or not regex['search']
- or not regex['replace']
+ not regex["field"]
+ or not regex["search"]
+ or not regex["replace"]
):
continue
- value = item[regex['field']]
- item[regex['field']] = re.sub(
- regex['search'], regex['replace'], value
+ value = item[regex["field"]]
+ item[regex["field"]] = re.sub(
+ regex["search"], regex["replace"], value
)
# Custom values can be passed in the config (just in case)
- artist = item[self.config['artist_field'].get()]
- album = item[self.config['album_field'].get()]
- keywords = item[self.config['track_field'].get()]
+ artist = item[self.config["artist_field"].get()]
+ album = item[self.config["album_field"].get()]
+ keywords = item[self.config["track_field"].get()]
# Query the Web API for each track, look for the items' JSON data
- query_filters = {'artist': artist, 'album': album}
+ query_filters = {"artist": artist, "album": album}
response_data_tracks = self._search_api(
- query_type='track', keywords=keywords, filters=query_filters
+ query_type="track", keywords=keywords, filters=query_filters
)
if not response_data_tracks:
query = self._construct_search_query(
@@ -448,20 +566,20 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
continue
# Apply market filter if requested
- region_filter = self.config['region_filter'].get()
+ region_filter = self.config["region_filter"].get()
if region_filter:
response_data_tracks = [
track_data
for track_data in response_data_tracks
- if region_filter in track_data['available_markets']
+ if region_filter in track_data["available_markets"]
]
if (
len(response_data_tracks) == 1
- or self.config['tiebreak'].get() == 'first'
+ or self.config["tiebreak"].get() == "first"
):
self._log.debug(
- '{} track(s) found, count: {}',
+ "{} track(s) found, count: {}",
self.data_source,
len(response_data_tracks),
)
@@ -469,29 +587,29 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
else:
# Use the popularity filter
self._log.debug(
- 'Most popular track chosen, count: {}',
+ "Most popular track chosen, count: {}",
len(response_data_tracks),
)
chosen_result = max(
- response_data_tracks, key=lambda x: x['popularity']
+ response_data_tracks, key=lambda x: x["popularity"]
)
results.append(chosen_result)
failure_count = len(failures)
if failure_count > 0:
- if self.config['show_failures'].get():
+ if self.config["show_failures"].get():
self._log.info(
- '{} track(s) did not match a {} ID:',
+ "{} track(s) did not match a {} ID:",
failure_count,
self.data_source,
)
for track in failures:
- self._log.info('track: {}', track)
- self._log.info('')
+ self._log.info("track: {}", track)
+ self._log.info("")
else:
self._log.warning(
- '{} track(s) did not match a {} ID:\n'
- 'use --show-failures to display',
+ "{} track(s) did not match a {} ID:\n"
+ "use --show-failures to display",
failure_count,
self.data_source,
)
@@ -507,14 +625,14 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
:type results: list[dict]
"""
if results:
- spotify_ids = [track_data['id'] for track_data in results]
- if self.config['mode'].get() == 'open':
+ spotify_ids = [track_data["id"] for track_data in results]
+ if self.config["mode"].get() == "open":
self._log.info(
- 'Attempting to open {} with playlist'.format(
+ "Attempting to open {} with playlist".format(
self.data_source
)
)
- spotify_url = 'spotify:trackset:Playlist:' + ','.join(
+ spotify_url = "spotify:trackset:Playlist:" + ",".join(
spotify_ids
)
webbrowser.open(spotify_url)
@@ -523,5 +641,72 @@ class SpotifyPlugin(MetadataSourcePlugin, BeetsPlugin):
print(self.open_track_url + spotify_id)
else:
self._log.warning(
- f'No {self.data_source} tracks found from beets query'
+ f"No {self.data_source} tracks found from beets query"
)
+
+ def _fetch_info(self, items, write, force):
+ """Obtain track information from Spotify."""
+
+ self._log.debug("Total {} tracks", len(items))
+
+ for index, item in enumerate(items, start=1):
+ self._log.info(
+ "Processing {}/{} tracks - {} ", index, len(items), item
+ )
+ # If we're not forcing re-downloading for all tracks, check
+ # whether the popularity data is already present
+ if not force:
+ if "spotify_track_popularity" in item:
+ self._log.debug("Popularity already present for: {}", item)
+ continue
+ try:
+ spotify_track_id = item.spotify_track_id
+ except AttributeError:
+ self._log.debug("No track_id present for: {}", item)
+ continue
+
+ popularity, isrc, ean, upc = self.track_info(spotify_track_id)
+ item["spotify_track_popularity"] = popularity
+ item["isrc"] = isrc
+ item["ean"] = ean
+ item["upc"] = upc
+ audio_features = self.track_audio_features(spotify_track_id)
+ if audio_features is None:
+ self._log.info("No audio features found for: {}", item)
+ continue
+ for feature in audio_features.keys():
+ if feature in self.spotify_audio_features.keys():
+ item[self.spotify_audio_features[feature]] = audio_features[
+ feature
+ ]
+ item["spotify_updated"] = time.time()
+ item.store()
+ if write:
+ item.try_write()
+
+ def track_info(self, track_id=None):
+ """Fetch a track's popularity and external IDs using its Spotify ID."""
+ track_data = self._handle_response(
+ requests.get, self.track_url + track_id
+ )
+ self._log.debug(
+ "track_popularity: {} and track_isrc: {}",
+ track_data.get("popularity"),
+ track_data.get("external_ids").get("isrc"),
+ )
+ return (
+ track_data.get("popularity"),
+ track_data.get("external_ids").get("isrc"),
+ track_data.get("external_ids").get("ean"),
+ track_data.get("external_ids").get("upc"),
+ )
+
+ def track_audio_features(self, track_id=None):
+ """Fetch track audio features by its Spotify ID."""
+ try:
+ return self._handle_response(
+ requests.get, self.audio_features_url + track_id
+ )
+ except SpotifyAPIError as e:
+ self._log.debug("Spotify API error: {}", e)
+ return None
diff --git a/lib/beetsplug/subsonicplaylist.py b/lib/beetsplug/subsonicplaylist.py
index ead78919..606cdc8b 100644
--- a/lib/beetsplug/subsonicplaylist.py
+++ b/lib/beetsplug/subsonicplaylist.py
@@ -15,9 +15,9 @@
import random
import string
-from xml.etree import ElementTree
from hashlib import md5
from urllib.parse import urlencode
+from xml.etree import ElementTree
import requests
@@ -26,7 +26,7 @@ from beets.dbcore.query import MatchQuery
from beets.plugins import BeetsPlugin
from beets.ui import Subcommand
-__author__ = 'https://github.com/MrNuggelz'
+__author__ = "https://github.com/MrNuggelz"
def filter_to_be_removed(items, keys):
@@ -34,17 +34,22 @@ def filter_to_be_removed(items, keys):
dont_remove = []
for artist, album, title in keys:
for item in items:
- if artist == item['artist'] and \
- album == item['album'] and \
- title == item['title']:
+ if (
+ artist == item["artist"]
+ and album == item["album"]
+ and title == item["title"]
+ ):
dont_remove.append(item)
return [item for item in items if item not in dont_remove]
else:
+
def to_be_removed(item):
for artist, album, title in keys:
- if artist == item['artist'] and\
- album == item['album'] and\
- title == item['title']:
+ if (
+ artist == item["artist"]
+ and album == item["album"]
+ and title == item["title"]
+ ):
return False
return True
@@ -52,111 +57,121 @@ def filter_to_be_removed(items, keys):
class SubsonicPlaylistPlugin(BeetsPlugin):
-
def __init__(self):
super().__init__()
self.config.add(
{
- 'delete': False,
- 'playlist_ids': [],
- 'playlist_names': [],
- 'username': '',
- 'password': ''
+ "delete": False,
+ "playlist_ids": [],
+ "playlist_names": [],
+ "username": "",
+ "password": "",
}
)
- self.config['password'].redact = True
+ self.config["password"].redact = True
def update_tags(self, playlist_dict, lib):
with lib.transaction():
for query, playlist_tag in playlist_dict.items():
- query = AndQuery([MatchQuery("artist", query[0]),
- MatchQuery("album", query[1]),
- MatchQuery("title", query[2])])
+ query = AndQuery(
+ [
+ MatchQuery("artist", query[0]),
+ MatchQuery("album", query[1]),
+ MatchQuery("title", query[2]),
+ ]
+ )
items = lib.items(query)
if not items:
- self._log.warn("{} | track not found ({})", playlist_tag,
- query)
+ self._log.warn(
+ "{} | track not found ({})", playlist_tag, query
+ )
continue
for item in items:
item.subsonic_playlist = playlist_tag
item.try_sync(write=True, move=False)
def get_playlist(self, playlist_id):
- xml = self.send('getPlaylist', {'id': playlist_id}).text
+ xml = self.send("getPlaylist", {"id": playlist_id}).text
playlist = ElementTree.fromstring(xml)[0]
- if playlist.attrib.get('code', '200') != '200':
- alt_error = 'error getting playlist, but no error message found'
- self._log.warn(playlist.attrib.get('message', alt_error))
+ if playlist.attrib.get("code", "200") != "200":
+ alt_error = "error getting playlist, but no error message found"
+ self._log.warn(playlist.attrib.get("message", alt_error))
return
- name = playlist.attrib.get('name', 'undefined')
- tracks = [(t.attrib['artist'], t.attrib['album'], t.attrib['title'])
- for t in playlist]
+ name = playlist.attrib.get("name", "undefined")
+ tracks = [
+ (t.attrib["artist"], t.attrib["album"], t.attrib["title"])
+ for t in playlist
+ ]
return name, tracks
def commands(self):
def build_playlist(lib, opts, args):
self.config.set_args(opts)
- ids = self.config['playlist_ids'].as_str_seq()
- if self.config['playlist_names'].as_str_seq():
+ ids = self.config["playlist_ids"].as_str_seq()
+ if self.config["playlist_names"].as_str_seq():
playlists = ElementTree.fromstring(
- self.send('getPlaylists').text)[0]
- if playlists.attrib.get('code', '200') != '200':
- alt_error = 'error getting playlists,' \
- ' but no error message found'
- self._log.warn(
- playlists.attrib.get('message', alt_error))
+ self.send("getPlaylists").text
+ )[0]
+ if playlists.attrib.get("code", "200") != "200":
+ alt_error = (
+ "error getting playlists," " but no error message found"
+ )
+ self._log.warn(playlists.attrib.get("message", alt_error))
return
- for name in self.config['playlist_names'].as_str_seq():
+ for name in self.config["playlist_names"].as_str_seq():
for playlist in playlists:
- if name == playlist.attrib['name']:
- ids.append(playlist.attrib['id'])
+ if name == playlist.attrib["name"]:
+ ids.append(playlist.attrib["id"])
playlist_dict = self.get_playlists(ids)
# delete old tags
- if self.config['delete']:
+ if self.config["delete"]:
existing = list(lib.items('subsonic_playlist:";"'))
to_be_removed = filter_to_be_removed(
- existing,
- playlist_dict.keys())
+ existing, playlist_dict.keys()
+ )
for item in to_be_removed:
- item['subsonic_playlist'] = ''
+ item["subsonic_playlist"] = ""
with lib.transaction():
item.try_sync(write=True, move=False)
self.update_tags(playlist_dict, lib)
subsonicplaylist_cmds = Subcommand(
- 'subsonicplaylist', help='import a subsonic playlist'
+ "subsonicplaylist", help="import a subsonic playlist"
)
subsonicplaylist_cmds.parser.add_option(
- '-d',
- '--delete',
- action='store_true',
- help='delete tag from items not in any playlist anymore',
+ "-d",
+ "--delete",
+ action="store_true",
+ help="delete tag from items not in any playlist anymore",
)
subsonicplaylist_cmds.func = build_playlist
return [subsonicplaylist_cmds]
def generate_token(self):
- salt = ''.join(random.choices(string.ascii_lowercase + string.digits))
- return md5(
- (self.config['password'].get() + salt).encode()).hexdigest(), salt
+ salt = "".join(random.choices(string.ascii_lowercase + string.digits))
+ return (
+ md5((self.config["password"].get() + salt).encode()).hexdigest(),
+ salt,
+ )
def send(self, endpoint, params=None):
if params is None:
params = {}
a, b = self.generate_token()
- params['u'] = self.config['username']
- params['t'] = a
- params['s'] = b
- params['v'] = '1.12.0'
- params['c'] = 'beets'
- resp = requests.get('{}/rest/{}?{}'.format(
- self.config['base_url'].get(),
- endpoint,
- urlencode(params))
+ params["u"] = self.config["username"]
+ params["t"] = a
+ params["s"] = b
+ params["v"] = "1.12.0"
+ params["c"] = "beets"
+ resp = requests.get(
+ "{}/rest/{}?{}".format(
+ self.config["base_url"].get(), endpoint, urlencode(params)
+ ),
+ timeout=10,
)
return resp
@@ -166,6 +181,6 @@ class SubsonicPlaylistPlugin(BeetsPlugin):
name, tracks = self.get_playlist(playlist_id)
for track in tracks:
if track not in output:
- output[track] = ';'
- output[track] += name + ';'
+ output[track] = ";"
+ output[track] += name + ";"
return output
diff --git a/lib/beetsplug/subsonicupdate.py b/lib/beetsplug/subsonicupdate.py
index 9480bcb4..2a537e35 100644
--- a/lib/beetsplug/subsonicupdate.py
+++ b/lib/beetsplug/subsonicupdate.py
@@ -32,28 +32,37 @@ is not supported, use password instead:
import hashlib
import random
import string
+from binascii import hexlify
import requests
-from binascii import hexlify
from beets import config
from beets.plugins import BeetsPlugin
-__author__ = 'https://github.com/maffo999'
+__author__ = "https://github.com/maffo999"
class SubsonicUpdate(BeetsPlugin):
def __init__(self):
super().__init__()
# Set default configuration values
- config['subsonic'].add({
- 'user': 'admin',
- 'pass': 'admin',
- 'url': 'http://localhost:4040',
- 'auth': 'token',
- })
- config['subsonic']['pass'].redact = True
- self.register_listener('import', self.start_scan)
+ config["subsonic"].add(
+ {
+ "user": "admin",
+ "pass": "admin",
+ "url": "http://localhost:4040",
+ "auth": "token",
+ }
+ )
+ config["subsonic"]["pass"].redact = True
+ self.register_listener("database_change", self.db_change)
+ self.register_listener("smartplaylist_update", self.spl_update)
+
+ def db_change(self, lib, model):
+ self.register_listener("cli_exit", self.start_scan)
+
+ def spl_update(self):
+ self.register_listener("cli_exit", self.start_scan)
@staticmethod
def __create_token():
@@ -61,13 +70,13 @@ class SubsonicUpdate(BeetsPlugin):
:return: The generated salt and hashed token
"""
- password = config['subsonic']['pass'].as_str()
+ password = config["subsonic"]["pass"].as_str()
# Pick the random sequence and salt the password
r = string.ascii_letters + string.digits
salt = "".join([random.choice(r) for _ in range(6)])
salted_password = password + salt
- token = hashlib.md5(salted_password.encode('utf-8')).hexdigest()
+ token = hashlib.md5(salted_password.encode("utf-8")).hexdigest()
# Put together the payload of the request to the server and the URL
return salt, token
@@ -81,64 +90,71 @@ class SubsonicUpdate(BeetsPlugin):
:return: Endpoint for updating Subsonic
"""
- url = config['subsonic']['url'].as_str()
- if url and url.endswith('/'):
+ url = config["subsonic"]["url"].as_str()
+ if url and url.endswith("/"):
url = url[:-1]
# @deprecated("Use url config option instead")
if not url:
- host = config['subsonic']['host'].as_str()
- port = config['subsonic']['port'].get(int)
- context_path = config['subsonic']['contextpath'].as_str()
- if context_path == '/':
- context_path = ''
+ host = config["subsonic"]["host"].as_str()
+ port = config["subsonic"]["port"].get(int)
+ context_path = config["subsonic"]["contextpath"].as_str()
+ if context_path == "/":
+ context_path = ""
url = f"http://{host}:{port}{context_path}"
- return url + f'/rest/{endpoint}'
+ return url + f"/rest/{endpoint}"
def start_scan(self):
- user = config['subsonic']['user'].as_str()
- auth = config['subsonic']['auth'].as_str()
+ user = config["subsonic"]["user"].as_str()
+ auth = config["subsonic"]["auth"].as_str()
url = self.__format_url("startScan")
- self._log.debug('URL is {0}', url)
- self._log.debug('auth type is {0}', config['subsonic']['auth'])
+ self._log.debug("URL is {0}", url)
+ self._log.debug("auth type is {0}", config["subsonic"]["auth"])
if auth == "token":
salt, token = self.__create_token()
payload = {
- 'u': user,
- 't': token,
- 's': salt,
- 'v': '1.13.0', # Subsonic 5.3 and newer
- 'c': 'beets',
- 'f': 'json'
+ "u": user,
+ "t": token,
+ "s": salt,
+ "v": "1.13.0", # Subsonic 5.3 and newer
+ "c": "beets",
+ "f": "json",
}
elif auth == "password":
- password = config['subsonic']['pass'].as_str()
+ password = config["subsonic"]["pass"].as_str()
encpass = hexlify(password.encode()).decode()
payload = {
- 'u': user,
- 'p': f'enc:{encpass}',
- 'v': '1.12.0',
- 'c': 'beets',
- 'f': 'json'
+ "u": user,
+ "p": f"enc:{encpass}",
+ "v": "1.12.0",
+ "c": "beets",
+ "f": "json",
}
else:
return
try:
- response = requests.get(url, params=payload)
+ response = requests.get(
+ url,
+ params=payload,
+ timeout=10,
+ )
json = response.json()
- if response.status_code == 200 and \
- json['subsonic-response']['status'] == "ok":
- count = json['subsonic-response']['scanStatus']['count']
- self._log.info(
- f'Updating Subsonic; scanning {count} tracks')
- elif response.status_code == 200 and \
- json['subsonic-response']['status'] == "failed":
- error_message = json['subsonic-response']['error']['message']
- self._log.error(f'Error: {error_message}')
+ if (
+ response.status_code == 200
+ and json["subsonic-response"]["status"] == "ok"
+ ):
+ count = json["subsonic-response"]["scanStatus"]["count"]
+ self._log.info(f"Updating Subsonic; scanning {count} tracks")
+ elif (
+ response.status_code == 200
+ and json["subsonic-response"]["status"] == "failed"
+ ):
+ error_message = json["subsonic-response"]["error"]["message"]
+ self._log.error(f"Error: {error_message}")
else:
- self._log.error('Error: {0}', json)
+ self._log.error("Error: {0}", json)
except Exception as error:
- self._log.error(f'Error: {error}')
+ self._log.error(f"Error: {error}")
diff --git a/lib/beetsplug/substitute.py b/lib/beetsplug/substitute.py
new file mode 100644
index 00000000..94b79007
--- /dev/null
+++ b/lib/beetsplug/substitute.py
@@ -0,0 +1,56 @@
+# This file is part of beets.
+# Copyright 2023, Daniele Ferone.
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the
+# "Software"), to deal in the Software without restriction, including
+# without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to
+# permit persons to whom the Software is furnished to do so, subject to
+# the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+
+"""The substitute plugin module.
+
+Uses user-specified substitution rules to canonicalize names for path formats.
+"""
+
+import re
+
+from beets.plugins import BeetsPlugin
+
+
+class Substitute(BeetsPlugin):
+ """The substitute plugin class.
+
+ Create a template field function that substitute the given field with the
+ given substitution rules. ``rules`` must be a list of (pattern,
+ replacement) pairs.
+ """
+
+ def tmpl_substitute(self, text):
+ """Do the actual replacing."""
+ if text:
+ for pattern, replacement in self.substitute_rules:
+ if pattern.match(text.lower()):
+ return replacement
+ return text
+ else:
+ return ""
+
+ def __init__(self):
+ """Initialize the substitute plugin.
+
+ Get the configuration, register template function and create list of
+ substitute rules.
+ """
+ super().__init__()
+ self.substitute_rules = []
+ self.template_funcs["substitute"] = self.tmpl_substitute
+
+ for key, view in self.config.items():
+ value = view.as_str()
+ pattern = re.compile(key.lower())
+ self.substitute_rules.append((pattern, value))
diff --git a/lib/beetsplug/the.py b/lib/beetsplug/the.py
index e6626d2b..c6fb46dd 100644
--- a/lib/beetsplug/the.py
+++ b/lib/beetsplug/the.py
@@ -16,50 +16,55 @@
import re
+from typing import List
+
from beets.plugins import BeetsPlugin
-__author__ = 'baobab@heresiarch.info'
-__version__ = '1.1'
+__author__ = "baobab@heresiarch.info"
+__version__ = "1.1"
-PATTERN_THE = '^the\\s'
-PATTERN_A = '^[a][n]?\\s'
-FORMAT = '{0}, {1}'
+PATTERN_THE = "^the\\s"
+PATTERN_A = "^[a][n]?\\s"
+FORMAT = "{0}, {1}"
class ThePlugin(BeetsPlugin):
-
- patterns = []
+ patterns: List[str] = []
def __init__(self):
super().__init__()
- self.template_funcs['the'] = self.the_template_func
+ self.template_funcs["the"] = self.the_template_func
- self.config.add({
- 'the': True,
- 'a': True,
- 'format': '{0}, {1}',
- 'strip': False,
- 'patterns': [],
- })
+ self.config.add(
+ {
+ "the": True,
+ "a": True,
+ "format": "{0}, {1}",
+ "strip": False,
+ "patterns": [],
+ }
+ )
- self.patterns = self.config['patterns'].as_str_seq()
+ self.patterns = self.config["patterns"].as_str_seq()
for p in self.patterns:
if p:
try:
re.compile(p)
except re.error:
- self._log.error('invalid pattern: {0}', p)
+ self._log.error("invalid pattern: {0}", p)
else:
- if not (p.startswith('^') or p.endswith('$')):
- self._log.warning('warning: \"{0}\" will not '
- 'match string start/end', p)
- if self.config['a']:
+ if not (p.startswith("^") or p.endswith("$")):
+ self._log.warning(
+ 'warning: "{0}" will not ' "match string start/end",
+ p,
+ )
+ if self.config["a"]:
self.patterns = [PATTERN_A] + self.patterns
- if self.config['the']:
+ if self.config["the"]:
self.patterns = [PATTERN_THE] + self.patterns
if not self.patterns:
- self._log.warning('no patterns defined!')
+ self._log.warning("no patterns defined!")
def unthe(self, text, pattern):
"""Moves pattern in the path format string or strips it
@@ -75,14 +80,14 @@ class ThePlugin(BeetsPlugin):
except IndexError:
return text
else:
- r = re.sub(r, '', text).strip()
- if self.config['strip']:
+ r = re.sub(r, "", text).strip()
+ if self.config["strip"]:
return r
else:
- fmt = self.config['format'].as_str()
+ fmt = self.config["format"].as_str()
return fmt.format(r, t.strip()).strip()
else:
- return ''
+ return ""
def the_template_func(self, text):
if not self.patterns:
@@ -91,8 +96,8 @@ class ThePlugin(BeetsPlugin):
for p in self.patterns:
r = self.unthe(text, p)
if r != text:
- self._log.debug('\"{0}\" -> \"{1}\"', text, r)
+ self._log.debug('"{0}" -> "{1}"', text, r)
break
return r
else:
- return ''
+ return ""
diff --git a/lib/beetsplug/thumbnails.py b/lib/beetsplug/thumbnails.py
index 6bd9cbac..19c19f06 100644
--- a/lib/beetsplug/thumbnails.py
+++ b/lib/beetsplug/thumbnails.py
@@ -19,51 +19,60 @@ Spec: standards.freedesktop.org/thumbnail-spec/latest/index.html
"""
-from hashlib import md5
-import os
-import shutil
-from itertools import chain
-from pathlib import PurePosixPath
import ctypes
import ctypes.util
+import os
+import shutil
+from hashlib import md5
+from pathlib import PurePosixPath
from xdg import BaseDirectory
+from beets import util
from beets.plugins import BeetsPlugin
from beets.ui import Subcommand, decargs
-from beets import util
-from beets.util.artresizer import ArtResizer, get_im_version, get_pil_version
-
+from beets.util import bytestring_path, displayable_path, syspath
+from beets.util.artresizer import ArtResizer
BASE_DIR = os.path.join(BaseDirectory.xdg_cache_home, "thumbnails")
-NORMAL_DIR = util.bytestring_path(os.path.join(BASE_DIR, "normal"))
-LARGE_DIR = util.bytestring_path(os.path.join(BASE_DIR, "large"))
+NORMAL_DIR = bytestring_path(os.path.join(BASE_DIR, "normal"))
+LARGE_DIR = bytestring_path(os.path.join(BASE_DIR, "large"))
class ThumbnailsPlugin(BeetsPlugin):
def __init__(self):
super().__init__()
- self.config.add({
- 'auto': True,
- 'force': False,
- 'dolphin': False,
- })
+ self.config.add(
+ {
+ "auto": True,
+ "force": False,
+ "dolphin": False,
+ }
+ )
- self.write_metadata = None
- if self.config['auto'] and self._check_local_ok():
- self.register_listener('art_set', self.process_album)
+ if self.config["auto"] and self._check_local_ok():
+ self.register_listener("art_set", self.process_album)
def commands(self):
- thumbnails_command = Subcommand("thumbnails",
- help="Create album thumbnails")
+ thumbnails_command = Subcommand(
+ "thumbnails", help="Create album thumbnails"
+ )
thumbnails_command.parser.add_option(
- '-f', '--force',
- dest='force', action='store_true', default=False,
- help='force regeneration of thumbnails deemed fine (existing & '
- 'recent enough)')
+ "-f",
+ "--force",
+ dest="force",
+ action="store_true",
+ default=False,
+ help="force regeneration of thumbnails deemed fine (existing & "
+ "recent enough)",
+ )
thumbnails_command.parser.add_option(
- '--dolphin', dest='dolphin', action='store_true', default=False,
- help="create Dolphin-compatible thumbnail information (for KDE)")
+ "--dolphin",
+ dest="dolphin",
+ action="store_true",
+ default=False,
+ help="create Dolphin-compatible thumbnail information (for KDE)",
+ )
thumbnails_command.func = self.process_query
return [thumbnails_command]
@@ -75,29 +84,29 @@ class ThumbnailsPlugin(BeetsPlugin):
self.process_album(album)
def _check_local_ok(self):
- """Check that's everythings ready:
- - local capability to resize images
- - thumbnail dirs exist (create them if needed)
- - detect whether we'll use PIL or IM
- - detect whether we'll use GIO or Python to get URIs
+ """Check that everything is ready:
+ - local capability to resize images
+ - thumbnail dirs exist (create them if needed)
+ - detect whether we'll use PIL or IM
+ - detect whether we'll use GIO or Python to get URIs
"""
if not ArtResizer.shared.local:
- self._log.warning("No local image resizing capabilities, "
- "cannot generate thumbnails")
+ self._log.warning(
+ "No local image resizing capabilities, "
+ "cannot generate thumbnails"
+ )
return False
for dir in (NORMAL_DIR, LARGE_DIR):
- if not os.path.exists(dir):
- os.makedirs(dir)
+ if not os.path.exists(syspath(dir)):
+ os.makedirs(syspath(dir))
- if get_im_version():
- self.write_metadata = write_metadata_im
- tool = "IM"
- else:
- assert get_pil_version() # since we're local
- self.write_metadata = write_metadata_pil
- tool = "PIL"
- self._log.debug("using {0} to write metadata", tool)
+ if not ArtResizer.shared.can_write_metadata:
+ raise RuntimeError(
+ f"Thumbnails: ArtResizer backend {ArtResizer.shared.method}"
+ f" unexpectedly cannot write image metadata."
+ )
+ self._log.debug(f"using {ArtResizer.shared.method} to write metadata")
uri_getter = GioURI()
if not uri_getter.available:
@@ -108,20 +117,20 @@ class ThumbnailsPlugin(BeetsPlugin):
return True
def process_album(self, album):
- """Produce thumbnails for the album folder.
- """
- self._log.debug('generating thumbnail for {0}', album)
+ """Produce thumbnails for the album folder."""
+ self._log.debug("generating thumbnail for {0}", album)
if not album.artpath:
- self._log.info('album {0} has no art', album)
+ self._log.info("album {0} has no art", album)
return
- if self.config['dolphin']:
+ if self.config["dolphin"]:
self.make_dolphin_cover_thumbnail(album)
size = ArtResizer.shared.get_size(album.artpath)
if not size:
- self._log.warning('problem getting the picture size for {0}',
- album.artpath)
+ self._log.warning(
+ "problem getting the picture size for {0}", album.artpath
+ )
return
wrote = True
@@ -130,9 +139,9 @@ class ThumbnailsPlugin(BeetsPlugin):
wrote &= self.make_cover_thumbnail(album, 128, NORMAL_DIR)
if wrote:
- self._log.info('wrote thumbnail for {0}', album)
+ self._log.info("wrote thumbnail for {0}", album)
else:
- self._log.info('nothing to do for {0}', album)
+ self._log.info("nothing to do for {0}", album)
def make_cover_thumbnail(self, album, size, target_dir):
"""Make a thumbnail of given size for `album` and put it in
@@ -140,19 +149,28 @@ class ThumbnailsPlugin(BeetsPlugin):
"""
target = os.path.join(target_dir, self.thumbnail_file_name(album.path))
- if os.path.exists(target) and \
- os.stat(target).st_mtime > os.stat(album.artpath).st_mtime:
- if self.config['force']:
- self._log.debug("found a suitable {1}x{1} thumbnail for {0}, "
- "forcing regeneration", album, size)
+ if (
+ os.path.exists(syspath(target))
+ and os.stat(syspath(target)).st_mtime
+ > os.stat(syspath(album.artpath)).st_mtime
+ ):
+ if self.config["force"]:
+ self._log.debug(
+ "found a suitable {1}x{1} thumbnail for {0}, "
+ "forcing regeneration",
+ album,
+ size,
+ )
else:
- self._log.debug("{1}x{1} thumbnail for {0} exists and is "
- "recent enough", album, size)
+ self._log.debug(
+ "{1}x{1} thumbnail for {0} exists and is " "recent enough",
+ album,
+ size,
+ )
return False
- resized = ArtResizer.shared.resize(size, album.artpath,
- util.syspath(target))
- self.add_tags(album, util.syspath(resized))
- shutil.move(resized, target)
+ resized = ArtResizer.shared.resize(size, album.artpath, target)
+ self.add_tags(album, resized)
+ shutil.move(syspath(resized), syspath(target))
return True
def thumbnail_file_name(self, path):
@@ -160,52 +178,35 @@ class ThumbnailsPlugin(BeetsPlugin):
See https://standards.freedesktop.org/thumbnail-spec/latest/x227.html
"""
uri = self.get_uri(path)
- hash = md5(uri.encode('utf-8')).hexdigest()
- return util.bytestring_path(f"{hash}.png")
+ hash = md5(uri.encode("utf-8")).hexdigest()
+ return bytestring_path(f"{hash}.png")
def add_tags(self, album, image_path):
"""Write required metadata to the thumbnail
See https://standards.freedesktop.org/thumbnail-spec/latest/x142.html
"""
- mtime = os.stat(album.artpath).st_mtime
- metadata = {"Thumb::URI": self.get_uri(album.artpath),
- "Thumb::MTime": str(mtime)}
+ mtime = os.stat(syspath(album.artpath)).st_mtime
+ metadata = {
+ "Thumb::URI": self.get_uri(album.artpath),
+ "Thumb::MTime": str(mtime),
+ }
try:
- self.write_metadata(image_path, metadata)
+ ArtResizer.shared.write_metadata(image_path, metadata)
except Exception:
- self._log.exception("could not write metadata to {0}",
- util.displayable_path(image_path))
+ self._log.exception(
+ "could not write metadata to {0}", displayable_path(image_path)
+ )
def make_dolphin_cover_thumbnail(self, album):
outfilename = os.path.join(album.path, b".directory")
- if os.path.exists(outfilename):
+ if os.path.exists(syspath(outfilename)):
return
artfile = os.path.split(album.artpath)[1]
- with open(outfilename, 'w') as f:
- f.write('[Desktop Entry]\n')
- f.write('Icon=./{}'.format(artfile.decode('utf-8')))
+ with open(syspath(outfilename), "w") as f:
+ f.write("[Desktop Entry]\n")
+ f.write("Icon=./{}".format(artfile.decode("utf-8")))
f.close()
- self._log.debug("Wrote file {0}", util.displayable_path(outfilename))
-
-
-def write_metadata_im(file, metadata):
- """Enrich the file metadata with `metadata` dict thanks to IM."""
- command = ['convert', file] + \
- list(chain.from_iterable(('-set', k, v)
- for k, v in metadata.items())) + [file]
- util.command_output(command)
- return True
-
-
-def write_metadata_pil(file, metadata):
- """Enrich the file metadata with `metadata` dict thanks to PIL."""
- from PIL import Image, PngImagePlugin
- im = Image.open(file)
- meta = PngImagePlugin.PngInfo()
- for k, v in metadata.items():
- meta.add_text(k, v, 0)
- im.save(file, "PNG", pnginfo=meta)
- return True
+ self._log.debug("Wrote file {0}", displayable_path(outfilename))
class URIGetter:
@@ -221,7 +222,7 @@ class PathlibURI(URIGetter):
name = "Python Pathlib"
def uri(self, path):
- return PurePosixPath(util.py3_path(path)).as_uri()
+ return PurePosixPath(os.fsdecode(path)).as_uri()
def copy_c_string(c_string):
@@ -232,12 +233,12 @@ def copy_c_string(c_string):
# work. A more surefire way would be to allocate a ctypes buffer and copy
# the data with `memcpy` or somesuch.
s = ctypes.cast(c_string, ctypes.c_char_p).value
- return b'' + s
+ return b"" + s
class GioURI(URIGetter):
- """Use gio URI function g_file_get_uri. Paths must be utf-8 encoded.
- """
+ """Use gio URI function g_file_get_uri. Paths must be utf-8 encoded."""
+
name = "GIO"
def __init__(self):
@@ -266,8 +267,11 @@ class GioURI(URIGetter):
def uri(self, path):
g_file_ptr = self.libgio.g_file_new_for_path(path)
if not g_file_ptr:
- raise RuntimeError("No gfile pointer received for {}".format(
- util.displayable_path(path)))
+ raise RuntimeError(
+ "No gfile pointer received for {}".format(
+ displayable_path(path)
+ )
+ )
try:
uri_ptr = self.libgio.g_file_get_uri(g_file_ptr)
@@ -275,8 +279,10 @@ class GioURI(URIGetter):
self.libgio.g_object_unref(g_file_ptr)
if not uri_ptr:
self.libgio.g_free(uri_ptr)
- raise RuntimeError("No URI received from the gfile pointer for "
- "{}".format(util.displayable_path(path)))
+ raise RuntimeError(
+ "No URI received from the gfile pointer for "
+ "{}".format(displayable_path(path))
+ )
try:
uri = copy_c_string(uri_ptr)
@@ -286,6 +292,4 @@ class GioURI(URIGetter):
try:
return uri.decode(util._fsencoding())
except UnicodeDecodeError:
- raise RuntimeError(
- f"Could not decode filename from GIO: {uri!r}"
- )
+ raise RuntimeError(f"Could not decode filename from GIO: {uri!r}")
diff --git a/lib/beetsplug/types.py b/lib/beetsplug/types.py
index 930d5e86..9ba3aac6 100644
--- a/lib/beetsplug/types.py
+++ b/lib/beetsplug/types.py
@@ -13,14 +13,14 @@
# included in all copies or substantial portions of the Software.
-from beets.plugins import BeetsPlugin
-from beets.dbcore import types
from confuse import ConfigValueError
+
from beets import library
+from beets.dbcore import types
+from beets.plugins import BeetsPlugin
class TypesPlugin(BeetsPlugin):
-
@property
def item_types(self):
return self._types()
@@ -35,16 +35,16 @@ class TypesPlugin(BeetsPlugin):
mytypes = {}
for key, value in self.config.items():
- if value.get() == 'int':
+ if value.get() == "int":
mytypes[key] = types.INTEGER
- elif value.get() == 'float':
+ elif value.get() == "float":
mytypes[key] = types.FLOAT
- elif value.get() == 'bool':
+ elif value.get() == "bool":
mytypes[key] = types.BOOLEAN
- elif value.get() == 'date':
+ elif value.get() == "date":
mytypes[key] = library.DateType()
else:
raise ConfigValueError(
- "unknown type '{}' for the '{}' field"
- .format(value, key))
+ "unknown type '{}' for the '{}' field".format(value, key)
+ )
return mytypes
diff --git a/lib/beetsplug/unimported.py b/lib/beetsplug/unimported.py
index 7714ec83..b473a346 100644
--- a/lib/beetsplug/unimported.py
+++ b/lib/beetsplug/unimported.py
@@ -23,46 +23,44 @@ from beets import util
from beets.plugins import BeetsPlugin
from beets.ui import Subcommand, print_
-__author__ = 'https://github.com/MrNuggelz'
+__author__ = "https://github.com/MrNuggelz"
class Unimported(BeetsPlugin):
-
def __init__(self):
super().__init__()
- self.config.add(
- {
- 'ignore_extensions': []
- }
- )
+ self.config.add({"ignore_extensions": [], "ignore_subdirectories": []})
def commands(self):
def print_unimported(lib, opts, args):
ignore_exts = [
- ('.' + x).encode()
+ ("." + x).encode()
for x in self.config["ignore_extensions"].as_str_seq()
]
ignore_dirs = [
os.path.join(lib.directory, x.encode())
for x in self.config["ignore_subdirectories"].as_str_seq()
]
- in_folder = {
- os.path.join(r, file)
- for r, d, f in os.walk(lib.directory)
- for file in f
- if not any(
- [file.endswith(ext) for ext in ignore_exts]
- + [r in ignore_dirs]
- )
- }
+ in_folder = set()
+ for root, _, files in os.walk(lib.directory):
+ # do not traverse if root is a child of an ignored directory
+ if any(root.startswith(ignored) for ignored in ignore_dirs):
+ continue
+ for file in files:
+ # ignore files with ignored extensions
+ if any(file.endswith(ext) for ext in ignore_exts):
+ continue
+ in_folder.add(os.path.join(root, file))
+
in_library = {x.path for x in lib.items()}
art_files = {x.artpath for x in lib.albums()}
for f in in_folder - in_library - art_files:
print_(util.displayable_path(f))
unimported = Subcommand(
- 'unimported',
- help='list all files in the library folder which are not listed'
- ' in the beets library database')
+ "unimported",
+ help="list all files in the library folder which are not listed"
+ " in the beets library database",
+ )
unimported.func = print_unimported
return [unimported]
diff --git a/lib/beetsplug/web/__init__.py b/lib/beetsplug/web/__init__.py
index 240126e9..dcd0ba38 100644
--- a/lib/beetsplug/web/__init__.py
+++ b/lib/beetsplug/web/__init__.py
@@ -14,21 +14,22 @@
"""A Web interface to beets."""
-from beets.plugins import BeetsPlugin
-from beets import ui
-from beets import util
-import beets.library
+import base64
+import json
+import os
+
import flask
from flask import g, jsonify
-from werkzeug.routing import BaseConverter, PathConverter
-import os
from unidecode import unidecode
-import json
-import base64
+from werkzeug.routing import BaseConverter, PathConverter
+import beets.library
+from beets import ui, util
+from beets.plugins import BeetsPlugin
# Utilities.
+
def _rep(obj, expand=False):
"""Get a flat -- i.e., JSON-ish -- representation of a beets Item or
Album object. For Albums, `expand` dictates whether tracks are
@@ -37,32 +38,32 @@ def _rep(obj, expand=False):
out = dict(obj)
if isinstance(obj, beets.library.Item):
- if app.config.get('INCLUDE_PATHS', False):
- out['path'] = util.displayable_path(out['path'])
+ if app.config.get("INCLUDE_PATHS", False):
+ out["path"] = util.displayable_path(out["path"])
else:
- del out['path']
+ del out["path"]
# Filter all bytes attributes and convert them to strings.
for key, value in out.items():
if isinstance(out[key], bytes):
- out[key] = base64.b64encode(value).decode('ascii')
+ out[key] = base64.b64encode(value).decode("ascii")
# Get the size (in bytes) of the backing file. This is useful
# for the Tomahawk resolver API.
try:
- out['size'] = os.path.getsize(util.syspath(obj.path))
+ out["size"] = os.path.getsize(util.syspath(obj.path))
except OSError:
- out['size'] = 0
+ out["size"] = 0
return out
elif isinstance(obj, beets.library.Album):
- if app.config.get('INCLUDE_PATHS', False):
- out['artpath'] = util.displayable_path(out['artpath'])
+ if app.config.get("INCLUDE_PATHS", False):
+ out["artpath"] = util.displayable_path(out["artpath"])
else:
- del out['artpath']
+ del out["artpath"]
if expand:
- out['items'] = [_rep(item) for item in obj.items()]
+ out["items"] = [_rep(item) for item in obj.items()]
return out
@@ -81,15 +82,15 @@ def json_generator(items, root, expand=False):
if first:
first = False
else:
- yield ','
+ yield ","
yield json.dumps(_rep(item, expand=expand))
- yield ']}'
+ yield "]}"
def is_expand():
"""Returns whether the current request is for an expanded response."""
- return flask.request.args.get('expand') is not None
+ return flask.request.args.get("expand") is not None
def is_delete():
@@ -97,7 +98,7 @@ def is_delete():
files.
"""
- return flask.request.args.get('delete') is not None
+ return flask.request.args.get("delete") is not None
def get_method():
@@ -106,25 +107,24 @@ def get_method():
def resource(name, patchable=False):
- """Decorates a function to handle RESTful HTTP requests for a resource.
- """
+ """Decorates a function to handle RESTful HTTP requests for a resource."""
+
def make_responder(retriever):
def responder(ids):
entities = [retriever(id) for id in ids]
entities = [entity for entity in entities if entity]
if get_method() == "DELETE":
-
- if app.config.get('READONLY', True):
+ if app.config.get("READONLY", True):
return flask.abort(405)
for entity in entities:
entity.remove(delete=is_delete())
- return flask.make_response(jsonify({'deleted': True}), 200)
+ return flask.make_response(jsonify({"deleted": True}), 200)
elif get_method() == "PATCH" and patchable:
- if app.config.get('READONLY', True):
+ if app.config.get("READONLY", True):
return flask.abort(405)
for entity in entities:
@@ -136,7 +136,7 @@ def resource(name, patchable=False):
elif entities:
return app.response_class(
json_generator(entities, root=name),
- mimetype='application/json'
+ mimetype="application/json",
)
elif get_method() == "GET":
@@ -145,7 +145,7 @@ def resource(name, patchable=False):
elif entities:
return app.response_class(
json_generator(entities, root=name),
- mimetype='application/json'
+ mimetype="application/json",
)
else:
return flask.abort(404)
@@ -153,31 +153,31 @@ def resource(name, patchable=False):
else:
return flask.abort(405)
- responder.__name__ = f'get_{name}'
+ responder.__name__ = f"get_{name}"
return responder
+
return make_responder
def resource_query(name, patchable=False):
- """Decorates a function to handle RESTful HTTP queries for resources.
- """
+ """Decorates a function to handle RESTful HTTP queries for resources."""
+
def make_responder(query_func):
def responder(queries):
entities = query_func(queries)
if get_method() == "DELETE":
-
- if app.config.get('READONLY', True):
+ if app.config.get("READONLY", True):
return flask.abort(405)
for entity in entities:
entity.remove(delete=is_delete())
- return flask.make_response(jsonify({'deleted': True}), 200)
+ return flask.make_response(jsonify({"deleted": True}), 200)
elif get_method() == "PATCH" and patchable:
- if app.config.get('READONLY', True):
+ if app.config.get("READONLY", True):
return flask.abort(405)
for entity in entities:
@@ -186,22 +186,21 @@ def resource_query(name, patchable=False):
return app.response_class(
json_generator(entities, root=name),
- mimetype='application/json'
+ mimetype="application/json",
)
elif get_method() == "GET":
return app.response_class(
json_generator(
- entities,
- root='results', expand=is_expand()
+ entities, root="results", expand=is_expand()
),
- mimetype='application/json'
+ mimetype="application/json",
)
else:
return flask.abort(405)
- responder.__name__ = f'query_{name}'
+ responder.__name__ = f"query_{name}"
return responder
@@ -212,34 +211,39 @@ def resource_list(name):
"""Decorates a function to handle RESTful HTTP request for a list of
resources.
"""
+
def make_responder(list_all):
def responder():
return app.response_class(
json_generator(list_all(), root=name, expand=is_expand()),
- mimetype='application/json'
+ mimetype="application/json",
)
- responder.__name__ = f'all_{name}'
+
+ responder.__name__ = f"all_{name}"
return responder
+
return make_responder
def _get_unique_table_field_values(model, field, sort_field):
- """ retrieve all unique values belonging to a key from a model """
+ """retrieve all unique values belonging to a key from a model"""
if field not in model.all_keys() or sort_field not in model.all_keys():
raise KeyError
with g.lib.transaction() as tx:
- rows = tx.query('SELECT DISTINCT "{}" FROM "{}" ORDER BY "{}"'
- .format(field, model._table, sort_field))
+ rows = tx.query(
+ 'SELECT DISTINCT "{}" FROM "{}" ORDER BY "{}"'.format(
+ field, model._table, sort_field
+ )
+ )
return [row[0] for row in rows]
class IdListConverter(BaseConverter):
- """Converts comma separated lists of ids in urls to integer lists.
- """
+ """Converts comma separated lists of ids in urls to integer lists."""
def to_python(self, value):
ids = []
- for id in value.split(','):
+ for id in value.split(","):
try:
ids.append(int(id))
except ValueError:
@@ -247,98 +251,103 @@ class IdListConverter(BaseConverter):
return ids
def to_url(self, value):
- return ','.join(str(v) for v in value)
+ return ",".join(str(v) for v in value)
class QueryConverter(PathConverter):
- """Converts slash separated lists of queries in the url to string list.
- """
+ """Converts slash separated lists of queries in the url to string list."""
def to_python(self, value):
- queries = value.split('/')
+ queries = value.split("/")
"""Do not do path substitution on regex value tests"""
- return [query if '::' in query else query.replace('\\', os.sep)
- for query in queries]
+ return [
+ query if "::" in query else query.replace("\\", os.sep)
+ for query in queries
+ ]
def to_url(self, value):
- return ','.join([v.replace(os.sep, '\\') for v in value])
+ return "/".join([v.replace(os.sep, "\\") for v in value])
class EverythingConverter(PathConverter):
- regex = '.*?'
+ part_isolating = False
+ regex = ".*?"
# Flask setup.
app = flask.Flask(__name__)
-app.url_map.converters['idlist'] = IdListConverter
-app.url_map.converters['query'] = QueryConverter
-app.url_map.converters['everything'] = EverythingConverter
+app.url_map.converters["idlist"] = IdListConverter
+app.url_map.converters["query"] = QueryConverter
+app.url_map.converters["everything"] = EverythingConverter
@app.before_request
def before_request():
- g.lib = app.config['lib']
+ g.lib = app.config["lib"]
# Items.
-@app.route('/item/', methods=["GET", "DELETE", "PATCH"])
-@resource('items', patchable=True)
+
+@app.route("/item/", methods=["GET", "DELETE", "PATCH"])
+@resource("items", patchable=True)
def get_item(id):
return g.lib.get_item(id)
-@app.route('/item/')
-@app.route('/item/query/')
-@resource_list('items')
+@app.route("/item/")
+@app.route("/item/query/")
+@resource_list("items")
def all_items():
return g.lib.items()
-@app.route('/item//file')
+@app.route("/item//file")
def item_file(item_id):
item = g.lib.get_item(item_id)
# On Windows under Python 2, Flask wants a Unicode path. On Python 3, it
# *always* wants a Unicode path.
- if os.name == 'nt':
+ if os.name == "nt":
item_path = util.syspath(item.path)
else:
- item_path = util.py3_path(item.path)
+ item_path = os.fsdecode(item.path)
- try:
- unicode_item_path = util.text_string(item.path)
- except (UnicodeDecodeError, UnicodeEncodeError):
- unicode_item_path = util.displayable_path(item.path)
+ base_filename = os.path.basename(item_path)
+ # FIXME: Arguably, this should just use `displayable_path`: The latter
+ # tries `_fsencoding()` first, but then falls back to `utf-8`, too.
+ if isinstance(base_filename, bytes):
+ try:
+ unicode_base_filename = base_filename.decode("utf-8")
+ except UnicodeError:
+ unicode_base_filename = util.displayable_path(base_filename)
+ else:
+ unicode_base_filename = base_filename
- base_filename = os.path.basename(unicode_item_path)
try:
# Imitate http.server behaviour
base_filename.encode("latin-1", "strict")
- except UnicodeEncodeError:
+ except UnicodeError:
safe_filename = unidecode(base_filename)
else:
- safe_filename = base_filename
+ safe_filename = unicode_base_filename
response = flask.send_file(
- item_path,
- as_attachment=True,
- attachment_filename=safe_filename
+ item_path, as_attachment=True, download_name=safe_filename
)
- response.headers['Content-Length'] = os.path.getsize(item_path)
return response
-@app.route('/item/query/', methods=["GET", "DELETE", "PATCH"])
-@resource_query('items', patchable=True)
+@app.route("/item/query/", methods=["GET", "DELETE", "PATCH"])
+@resource_query("items", patchable=True)
def item_query(queries):
return g.lib.items(queries)
-@app.route('/item/path/')
+@app.route("/item/path/")
def item_at_path(path):
- query = beets.library.PathQuery('path', path.encode('utf-8'))
+ query = beets.library.PathQuery("path", path.encode("utf-8"))
item = g.lib.items(query).get()
if item:
return flask.jsonify(_rep(item))
@@ -346,12 +355,13 @@ def item_at_path(path):
return flask.abort(404)
-@app.route('/item/values/')
+@app.route("/item/values/")
def item_unique_field_values(key):
- sort_key = flask.request.args.get('sort_key', key)
+ sort_key = flask.request.args.get("sort_key", key)
try:
- values = _get_unique_table_field_values(beets.library.Item, key,
- sort_key)
+ values = _get_unique_table_field_values(
+ beets.library.Item, key, sort_key
+ )
except KeyError:
return flask.abort(404)
return flask.jsonify(values=values)
@@ -359,26 +369,27 @@ def item_unique_field_values(key):
# Albums.
-@app.route('/album/', methods=["GET", "DELETE"])
-@resource('albums')
+
+@app.route("/album/", methods=["GET", "DELETE"])
+@resource("albums")
def get_album(id):
return g.lib.get_album(id)
-@app.route('/album/')
-@app.route('/album/query/')
-@resource_list('albums')
+@app.route("/album/")
+@app.route("/album/query/")
+@resource_list("albums")
def all_albums():
return g.lib.albums()
-@app.route('/album/query/', methods=["GET", "DELETE"])
-@resource_query('albums')
+@app.route("/album/query/", methods=["GET", "DELETE"])
+@resource_query("albums")
def album_query(queries):
return g.lib.albums(queries)
-@app.route('/album//art')
+@app.route("/album//art")
def album_art(album_id):
album = g.lib.get_album(album_id)
if album and album.artpath:
@@ -387,12 +398,13 @@ def album_art(album_id):
return flask.abort(404)
-@app.route('/album/values/')
+@app.route("/album/values/")
def album_unique_field_values(key):
- sort_key = flask.request.args.get('sort_key', key)
+ sort_key = flask.request.args.get("sort_key", key)
try:
- values = _get_unique_table_field_values(beets.library.Album, key,
- sort_key)
+ values = _get_unique_table_field_values(
+ beets.library.Album, key, sort_key
+ )
except KeyError:
return flask.abort(404)
return flask.jsonify(values=values)
@@ -400,7 +412,8 @@ def album_unique_field_values(key):
# Artists.
-@app.route('/artist/')
+
+@app.route("/artist/")
def all_artists():
with g.lib.transaction() as tx:
rows = tx.query("SELECT DISTINCT albumartist FROM albums")
@@ -410,88 +423,106 @@ def all_artists():
# Library information.
-@app.route('/stats')
+
+@app.route("/stats")
def stats():
with g.lib.transaction() as tx:
item_rows = tx.query("SELECT COUNT(*) FROM items")
album_rows = tx.query("SELECT COUNT(*) FROM albums")
- return flask.jsonify({
- 'items': item_rows[0][0],
- 'albums': album_rows[0][0],
- })
+ return flask.jsonify(
+ {
+ "items": item_rows[0][0],
+ "albums": album_rows[0][0],
+ }
+ )
# UI.
-@app.route('/')
+
+@app.route("/")
def home():
- return flask.render_template('index.html')
+ return flask.render_template("index.html")
# Plugin hook.
+
class WebPlugin(BeetsPlugin):
def __init__(self):
super().__init__()
- self.config.add({
- 'host': '127.0.0.1',
- 'port': 8337,
- 'cors': '',
- 'cors_supports_credentials': False,
- 'reverse_proxy': False,
- 'include_paths': False,
- 'readonly': True,
- })
+ self.config.add(
+ {
+ "host": "127.0.0.1",
+ "port": 8337,
+ "cors": "",
+ "cors_supports_credentials": False,
+ "reverse_proxy": False,
+ "include_paths": False,
+ "readonly": True,
+ }
+ )
def commands(self):
- cmd = ui.Subcommand('web', help='start a Web interface')
- cmd.parser.add_option('-d', '--debug', action='store_true',
- default=False, help='debug mode')
+ cmd = ui.Subcommand("web", help="start a Web interface")
+ cmd.parser.add_option(
+ "-d",
+ "--debug",
+ action="store_true",
+ default=False,
+ help="debug mode",
+ )
def func(lib, opts, args):
args = ui.decargs(args)
if args:
- self.config['host'] = args.pop(0)
+ self.config["host"] = args.pop(0)
if args:
- self.config['port'] = int(args.pop(0))
+ self.config["port"] = int(args.pop(0))
- app.config['lib'] = lib
+ app.config["lib"] = lib
# Normalizes json output
- app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
+ app.config["JSONIFY_PRETTYPRINT_REGULAR"] = False
- app.config['INCLUDE_PATHS'] = self.config['include_paths']
- app.config['READONLY'] = self.config['readonly']
+ app.config["INCLUDE_PATHS"] = self.config["include_paths"]
+ app.config["READONLY"] = self.config["readonly"]
# Enable CORS if required.
- if self.config['cors']:
- self._log.info('Enabling CORS with origin: {0}',
- self.config['cors'])
+ if self.config["cors"]:
+ self._log.info(
+ "Enabling CORS with origin: {0}", self.config["cors"]
+ )
from flask_cors import CORS
- app.config['CORS_ALLOW_HEADERS'] = "Content-Type"
- app.config['CORS_RESOURCES'] = {
- r"/*": {"origins": self.config['cors'].get(str)}
+
+ app.config["CORS_ALLOW_HEADERS"] = "Content-Type"
+ app.config["CORS_RESOURCES"] = {
+ r"/*": {"origins": self.config["cors"].get(str)}
}
CORS(
app,
supports_credentials=self.config[
- 'cors_supports_credentials'
- ].get(bool)
+ "cors_supports_credentials"
+ ].get(bool),
)
# Allow serving behind a reverse proxy
- if self.config['reverse_proxy']:
+ if self.config["reverse_proxy"]:
app.wsgi_app = ReverseProxied(app.wsgi_app)
# Start the web application.
- app.run(host=self.config['host'].as_str(),
- port=self.config['port'].get(int),
- debug=opts.debug, threaded=True)
+ app.run(
+ host=self.config["host"].as_str(),
+ port=self.config["port"].get(int),
+ debug=opts.debug,
+ threaded=True,
+ )
+
cmd.func = func
return [cmd]
class ReverseProxied:
- '''Wrap the application in this middleware and configure the
+ """Wrap the application in this middleware and configure the
front-end server to add these headers, to let you quietly bind
this to a URL other than / and to an HTTP scheme that is
different than what is used locally.
@@ -508,19 +539,20 @@ class ReverseProxied:
From: http://flask.pocoo.org/snippets/35/
:param app: the WSGI application
- '''
+ """
+
def __init__(self, app):
self.app = app
def __call__(self, environ, start_response):
- script_name = environ.get('HTTP_X_SCRIPT_NAME', '')
+ script_name = environ.get("HTTP_X_SCRIPT_NAME", "")
if script_name:
- environ['SCRIPT_NAME'] = script_name
- path_info = environ['PATH_INFO']
+ environ["SCRIPT_NAME"] = script_name
+ path_info = environ["PATH_INFO"]
if path_info.startswith(script_name):
- environ['PATH_INFO'] = path_info[len(script_name):]
+ environ["PATH_INFO"] = path_info[len(script_name) :]
- scheme = environ.get('HTTP_X_SCHEME', '')
+ scheme = environ.get("HTTP_X_SCHEME", "")
if scheme:
- environ['wsgi.url_scheme'] = scheme
+ environ["wsgi.url_scheme"] = scheme
return self.app(environ, start_response)
diff --git a/lib/beetsplug/web/static/backbone.js b/lib/beetsplug/web/static/backbone.js
index b2e49322..f8a098a0 100644
--- a/lib/beetsplug/web/static/backbone.js
+++ b/lib/beetsplug/web/static/backbone.js
@@ -274,7 +274,7 @@
},
// Fetch the model from the server. If the server's representation of the
- // model differs from its current attributes, they will be overriden,
+ // model differs from its current attributes, they will be overridden,
// triggering a `"change"` event.
fetch : function(options) {
options || (options = {});
@@ -885,7 +885,7 @@
};
// Element lookup, scoped to DOM elements within the current view.
- // This should be prefered to global lookups, if you're dealing with
+ // This should be preferred to global lookups, if you're dealing with
// a specific view.
var selectorDelegate = function(selector) {
return $(selector, this.el);
@@ -984,7 +984,7 @@
// Ensure that the View has a DOM element to render into.
// If `this.el` is a string, pass it through `$()`, take the first
// matching element, and re-assign it to `el`. Otherwise, create
- // an element from the `id`, `className` and `tagName` proeprties.
+ // an element from the `id`, `className` and `tagName` properties.
_ensureElement : function() {
if (!this.el) {
var attrs = this.attributes || {};
diff --git a/lib/beetsplug/web/static/jquery.js b/lib/beetsplug/web/static/jquery.js
index e1414212..ed506237 100644
--- a/lib/beetsplug/web/static/jquery.js
+++ b/lib/beetsplug/web/static/jquery.js
@@ -2278,7 +2278,7 @@ jQuery.fn.extend({
classNames = value.split( rspace );
while ( (className = classNames[ i++ ]) ) {
- // check each className given, space seperated list
+ // check each className given, space separated list
state = isBool ? state : !self.hasClass( className );
self[ state ? "addClass" : "removeClass" ]( className );
}
@@ -3868,7 +3868,7 @@ var chunker = /((?:\((?:\([^()]+\)|[^()]+)+\)|\[(?:\[[^\[\]]*\]|['"][^'"]*['"]|[
rNonWord = /\W/;
// Here we check if the JavaScript engine is using some sort of
-// optimization where it does not always call our comparision
+// optimization where it does not always call our comparison
// function. If that is the case, discard the hasDuplicate value.
// Thus far that includes Google Chrome.
[0, 0].sort(function() {
@@ -4180,7 +4180,7 @@ Sizzle.error = function( msg ) {
};
/**
- * Utility function for retreiving the text value of an array of DOM nodes
+ * Utility function for retrieving the text value of an array of DOM nodes
* @param {Array|Element} elem
*/
var getText = Sizzle.getText = function( elem ) {
@@ -8111,7 +8111,7 @@ if ( jQuery.support.ajax ) {
xml;
// Firefox throws exceptions when accessing properties
- // of an xhr when a network error occured
+ // of an xhr when a network error occurred
// http://helpful.knobs-dials.com/index.php/Component_returned_failure_code:_0x80040111_(NS_ERROR_NOT_AVAILABLE)
try {
diff --git a/lib/beetsplug/zero.py b/lib/beetsplug/zero.py
index f05b1b5a..14c157ce 100644
--- a/lib/beetsplug/zero.py
+++ b/lib/beetsplug/zero.py
@@ -17,29 +17,33 @@
import re
-from beets.plugins import BeetsPlugin
-from mediafile import MediaFile
-from beets.importer import action
-from beets.ui import Subcommand, decargs, input_yn
import confuse
+from mediafile import MediaFile
-__author__ = 'baobab@heresiarch.info'
+from beets.importer import action
+from beets.plugins import BeetsPlugin
+from beets.ui import Subcommand, decargs, input_yn
+
+__author__ = "baobab@heresiarch.info"
class ZeroPlugin(BeetsPlugin):
def __init__(self):
super().__init__()
- self.register_listener('write', self.write_event)
- self.register_listener('import_task_choice',
- self.import_task_choice_event)
+ self.register_listener("write", self.write_event)
+ self.register_listener(
+ "import_task_choice", self.import_task_choice_event
+ )
- self.config.add({
- 'auto': True,
- 'fields': [],
- 'keep_fields': [],
- 'update_database': False,
- })
+ self.config.add(
+ {
+ "auto": True,
+ "fields": [],
+ "keep_fields": [],
+ "update_database": False,
+ }
+ )
self.fields_to_progs = {}
self.warned = False
@@ -51,29 +55,30 @@ class ZeroPlugin(BeetsPlugin):
A field is zeroed if its value matches one of the associated progs. If
progs is empty, then the associated field is always zeroed.
"""
- if self.config['fields'] and self.config['keep_fields']:
- self._log.warning(
- 'cannot blacklist and whitelist at the same time'
- )
+ if self.config["fields"] and self.config["keep_fields"]:
+ self._log.warning("cannot blacklist and whitelist at the same time")
# Blacklist mode.
- elif self.config['fields']:
- for field in self.config['fields'].as_str_seq():
+ elif self.config["fields"]:
+ for field in self.config["fields"].as_str_seq():
self._set_pattern(field)
# Whitelist mode.
- elif self.config['keep_fields']:
+ elif self.config["keep_fields"]:
for field in MediaFile.fields():
- if (field not in self.config['keep_fields'].as_str_seq() and
- # These fields should always be preserved.
- field not in ('id', 'path', 'album_id')):
+ if (
+ field not in self.config["keep_fields"].as_str_seq()
+ and
+ # These fields should always be preserved.
+ field not in ("id", "path", "album_id")
+ ):
self._set_pattern(field)
def commands(self):
- zero_command = Subcommand('zero', help='set fields to null')
+ zero_command = Subcommand("zero", help="set fields to null")
def zero_fields(lib, opts, args):
if not decargs(args) and not input_yn(
- "Remove fields for all items? (Y/n)",
- True):
+ "Remove fields for all items? (Y/n)", True
+ ):
return
for item in lib.items(decargs(args)):
self.process_item(item)
@@ -86,10 +91,11 @@ class ZeroPlugin(BeetsPlugin):
Do some sanity checks then compile the regexes.
"""
if field not in MediaFile.fields():
- self._log.error('invalid field: {0}', field)
- elif field in ('id', 'path', 'album_id'):
- self._log.warning('field \'{0}\' ignored, zeroing '
- 'it would be dangerous', field)
+ self._log.error("invalid field: {0}", field)
+ elif field in ("id", "path", "album_id"):
+ self._log.warning(
+ "field '{0}' ignored, zeroing " "it would be dangerous", field
+ )
else:
try:
for pattern in self.config[field].as_str_seq():
@@ -101,12 +107,12 @@ class ZeroPlugin(BeetsPlugin):
def import_task_choice_event(self, session, task):
if task.choice_flag == action.ASIS and not self.warned:
- self._log.warning('cannot zero in \"as-is\" mode')
+ self._log.warning('cannot zero in "as-is" mode')
self.warned = True
# TODO request write in as-is mode
def write_event(self, item, path, tags):
- if self.config['auto']:
+ if self.config["auto"]:
self.set_fields(item, tags)
def set_fields(self, item, tags):
@@ -119,7 +125,7 @@ class ZeroPlugin(BeetsPlugin):
fields_set = False
if not self.fields_to_progs:
- self._log.warning('no fields, nothing to do')
+ self._log.warning("no fields, nothing to do")
return False
for field, progs in self.fields_to_progs.items():
@@ -127,14 +133,14 @@ class ZeroPlugin(BeetsPlugin):
value = tags[field]
match = _match_progs(tags[field], progs)
else:
- value = ''
+ value = ""
match = not progs
if match:
fields_set = True
- self._log.debug('{0}: {1} -> None', field, value)
+ self._log.debug("{0}: {1} -> None", field, value)
tags[field] = None
- if self.config['update_database']:
+ if self.config["update_database"]:
item[field] = None
return fields_set
@@ -144,7 +150,7 @@ class ZeroPlugin(BeetsPlugin):
if self.set_fields(item, tags):
item.write(tags=tags)
- if self.config['update_database']:
+ if self.config["update_database"]:
item.store(fields=tags)