diff --git a/api/funkwhale_api/radios/radios.py b/api/funkwhale_api/radios/radios.py index 585bbbe334f0c3ca297d9ada8883b48bb848354c..0d045ea4dc89ba2d545364e2a3123f3f625a001e 100644 --- a/api/funkwhale_api/radios/radios.py +++ b/api/funkwhale_api/radios/radios.py @@ -1,5 +1,6 @@ import random from rest_framework import serializers +from django.db.models import Count from django.core.exceptions import ValidationError from taggit.models import Tag from funkwhale_api.users.models import User @@ -39,8 +40,11 @@ class SessionRadio(SimpleRadio): self.session = models.RadioSession.objects.create(user=user, radio_type=self.radio_type, **kwargs) return self.session - def get_queryset(self): - raise NotImplementedError + def get_queryset(self, **kwargs): + qs = Track.objects.annotate( + files_count=Count('files') + ) + return qs.filter(files_count__gt=0) def get_queryset_kwargs(self): return {} @@ -75,7 +79,9 @@ class SessionRadio(SimpleRadio): @registry.register(name='random') class RandomRadio(SessionRadio): def get_queryset(self, **kwargs): - return Track.objects.all() + qs = super().get_queryset(**kwargs) + return qs.order_by('?') + @registry.register(name='favorites') class FavoritesRadio(SessionRadio): @@ -87,8 +93,9 @@ class FavoritesRadio(SessionRadio): return kwargs def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) track_ids = kwargs['user'].track_favorites.all().values_list('track', flat=True) - return Track.objects.filter(pk__in=track_ids) + return qs.filter(pk__in=track_ids) @registry.register(name='custom') @@ -101,7 +108,11 @@ class CustomRadio(SessionRadio): return kwargs def get_queryset(self, **kwargs): - return filters.run(kwargs['custom_radio'].config) + qs = super().get_queryset(**kwargs) + return filters.run( + kwargs['custom_radio'].config, + candidates=qs, + ) def validate_session(self, data, **context): data = super().validate_session(data, **context) @@ -141,6 +152,7 @@ class TagRadio(RelatedObjectRadio): model = Tag def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) return Track.objects.filter(tags__in=[self.session.related_object]) @registry.register(name='artist') @@ -148,7 +160,8 @@ class ArtistRadio(RelatedObjectRadio): model = Artist def get_queryset(self, **kwargs): - return self.session.related_object.tracks.all() + qs = super().get_queryset(**kwargs) + return qs.filter(artist=self.session.related_object) @registry.register(name='less-listened') @@ -160,5 +173,6 @@ class LessListenedRadio(RelatedObjectRadio): super().clean(instance) def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) listened = self.session.user.listenings.all().values_list('track', flat=True) - return Track.objects.exclude(pk__in=listened).order_by('?') + return qs.exclude(pk__in=listened).order_by('?') diff --git a/api/tests/radios/test_radios.py b/api/tests/radios/test_radios.py index b00bfcd79ce3b917a26698ac4d868dea39521428..b731e3024b039bd5006023bb80276405565edd8e 100644 --- a/api/tests/radios/test_radios.py +++ b/api/tests/radios/test_radios.py @@ -51,7 +51,8 @@ def test_can_pick_by_weight(): def test_can_get_choices_for_favorites_radio(factories): - tracks = factories['music.Track'].create_batch(10) + files = factories['music.TrackFile'].create_batch(10) + tracks = [f.track for f in files] user = factories['users.User']() for i in range(5): TrackFavorite.add(track=random.choice(tracks), user=user) @@ -71,8 +72,12 @@ def test_can_get_choices_for_favorites_radio(factories): def test_can_get_choices_for_custom_radio(factories): artist = factories['music.Artist']() - tracks = factories['music.Track'].create_batch(5, artist=artist) - wrong_tracks = factories['music.Track'].create_batch(5) + files = factories['music.TrackFile'].create_batch( + 5, track__artist=artist) + tracks = [f.track for f in files] + wrong_files = factories['music.TrackFile'].create_batch(5) + wrong_tracks = [f.track for f in wrong_files] + session = factories['radios.CustomRadioSession']( custom_radio__config=[{'type': 'artist', 'ids': [artist.pk]}] ) @@ -113,7 +118,8 @@ def test_can_start_custom_radio_from_api(logged_in_client, factories): def test_can_use_radio_session_to_filter_choices(factories): - tracks = factories['music.Track'].create_batch(30) + files = factories['music.TrackFile'].create_batch(30) + tracks = [f.track for f in files] user = factories['users.User']() radio = radios.RandomRadio() session = radio.start_session(user) @@ -156,8 +162,8 @@ def test_can_start_radio_for_anonymous_user(client, db): def test_can_get_track_for_session_from_api(factories, logged_in_client): - tracks = factories['music.Track'].create_batch(size=1) - + files = factories['music.TrackFile'].create_batch(1) + tracks = [f.track for f in files] url = reverse('api:v1:radios:sessions-list') response = logged_in_client.post(url, {'radio_type': 'random'}) session = models.RadioSession.objects.latest('id') @@ -169,7 +175,7 @@ def test_can_get_track_for_session_from_api(factories, logged_in_client): assert data['track']['id'] == tracks[0].id assert data['position'] == 1 - next_track = factories['music.Track']() + next_track = factories['music.TrackFile']().track response = logged_in_client.post(url, {'session': session.pk}) data = json.loads(response.content.decode('utf-8')) @@ -193,8 +199,11 @@ def test_related_object_radio_validate_related_object(factories): def test_can_start_artist_radio(factories): user = factories['users.User']() artist = factories['music.Artist']() - wrong_tracks = factories['music.Track'].create_batch(5) - good_tracks = factories['music.Track'].create_batch(5, artist=artist) + wrong_files = factories['music.TrackFile'].create_batch(5) + wrong_tracks = [f.track for f in wrong_files] + good_files = factories['music.TrackFile'].create_batch( + 5, track__artist=artist) + good_tracks = [f.track for f in good_files] radio = radios.ArtistRadio() session = radio.start_session(user, related_object=artist) @@ -206,8 +215,11 @@ def test_can_start_artist_radio(factories): def test_can_start_tag_radio(factories): user = factories['users.User']() tag = factories['taggit.Tag']() - wrong_tracks = factories['music.Track'].create_batch(5) - good_tracks = factories['music.Track'].create_batch(5, tags=[tag]) + wrong_files = factories['music.TrackFile'].create_batch(5) + wrong_tracks = [f.track for f in wrong_files] + good_files = factories['music.TrackFile'].create_batch( + 5, track__tags=[tag]) + good_tracks = [f.track for f in good_files] radio = radios.TagRadio() session = radio.start_session(user, related_object=tag) @@ -229,9 +241,11 @@ def test_can_start_artist_radio_from_api(client, factories): def test_can_start_less_listened_radio(factories): user = factories['users.User']() - history = factories['history.Listening'].create_batch(5, user=user) - wrong_tracks = [h.track for h in history] - good_tracks = factories['music.Track'].create_batch(size=5) + wrong_files = factories['music.TrackFile'].create_batch(5) + for f in wrong_files: + factories['history.Listening'](track=f.track, user=user) + good_files = factories['music.TrackFile'].create_batch(5) + good_tracks = [f.track for f in good_files] radio = radios.LessListenedRadio() session = radio.start_session(user) assert session.related_object == user diff --git a/changes/changelog.d/88.bugfix b/changes/changelog.d/88.bugfix new file mode 100644 index 0000000000000000000000000000000000000000..d2f707b4467eecdb86d787841fec538648c393b9 --- /dev/null +++ b/changes/changelog.d/88.bugfix @@ -0,0 +1 @@ +Now exclude tracks without file from radio candidates (#88)