Skip to content
Snippets Groups Projects
Commit 6f907458 authored by Eliot Berriot's avatar Eliot Berriot
Browse files

Merge branch '88-exclude-empty-radio' into 'develop'

Fixed #88: Now exclude tracks without file from radio candidates

Closes #88

See merge request funkwhale/funkwhale!62
parents eb505a80 db4ae180
No related branches found
No related tags found
No related merge requests found
import random import random
from rest_framework import serializers from rest_framework import serializers
from django.db.models import Count
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from taggit.models import Tag from taggit.models import Tag
from funkwhale_api.users.models import User from funkwhale_api.users.models import User
...@@ -39,8 +40,11 @@ class SessionRadio(SimpleRadio): ...@@ -39,8 +40,11 @@ class SessionRadio(SimpleRadio):
self.session = models.RadioSession.objects.create(user=user, radio_type=self.radio_type, **kwargs) self.session = models.RadioSession.objects.create(user=user, radio_type=self.radio_type, **kwargs)
return self.session return self.session
def get_queryset(self): def get_queryset(self, **kwargs):
raise NotImplementedError qs = Track.objects.annotate(
files_count=Count('files')
)
return qs.filter(files_count__gt=0)
def get_queryset_kwargs(self): def get_queryset_kwargs(self):
return {} return {}
...@@ -75,7 +79,9 @@ class SessionRadio(SimpleRadio): ...@@ -75,7 +79,9 @@ class SessionRadio(SimpleRadio):
@registry.register(name='random') @registry.register(name='random')
class RandomRadio(SessionRadio): class RandomRadio(SessionRadio):
def get_queryset(self, **kwargs): def get_queryset(self, **kwargs):
return Track.objects.all() qs = super().get_queryset(**kwargs)
return qs.order_by('?')
@registry.register(name='favorites') @registry.register(name='favorites')
class FavoritesRadio(SessionRadio): class FavoritesRadio(SessionRadio):
...@@ -87,8 +93,9 @@ class FavoritesRadio(SessionRadio): ...@@ -87,8 +93,9 @@ class FavoritesRadio(SessionRadio):
return kwargs return kwargs
def get_queryset(self, **kwargs): def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
track_ids = kwargs['user'].track_favorites.all().values_list('track', flat=True) track_ids = kwargs['user'].track_favorites.all().values_list('track', flat=True)
return Track.objects.filter(pk__in=track_ids) return qs.filter(pk__in=track_ids)
@registry.register(name='custom') @registry.register(name='custom')
...@@ -101,7 +108,11 @@ class CustomRadio(SessionRadio): ...@@ -101,7 +108,11 @@ class CustomRadio(SessionRadio):
return kwargs return kwargs
def get_queryset(self, **kwargs): def get_queryset(self, **kwargs):
return filters.run(kwargs['custom_radio'].config) qs = super().get_queryset(**kwargs)
return filters.run(
kwargs['custom_radio'].config,
candidates=qs,
)
def validate_session(self, data, **context): def validate_session(self, data, **context):
data = super().validate_session(data, **context) data = super().validate_session(data, **context)
...@@ -141,6 +152,7 @@ class TagRadio(RelatedObjectRadio): ...@@ -141,6 +152,7 @@ class TagRadio(RelatedObjectRadio):
model = Tag model = Tag
def get_queryset(self, **kwargs): def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
return Track.objects.filter(tags__in=[self.session.related_object]) return Track.objects.filter(tags__in=[self.session.related_object])
@registry.register(name='artist') @registry.register(name='artist')
...@@ -148,7 +160,8 @@ class ArtistRadio(RelatedObjectRadio): ...@@ -148,7 +160,8 @@ class ArtistRadio(RelatedObjectRadio):
model = Artist model = Artist
def get_queryset(self, **kwargs): def get_queryset(self, **kwargs):
return self.session.related_object.tracks.all() qs = super().get_queryset(**kwargs)
return qs.filter(artist=self.session.related_object)
@registry.register(name='less-listened') @registry.register(name='less-listened')
...@@ -160,5 +173,6 @@ class LessListenedRadio(RelatedObjectRadio): ...@@ -160,5 +173,6 @@ class LessListenedRadio(RelatedObjectRadio):
super().clean(instance) super().clean(instance)
def get_queryset(self, **kwargs): def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
listened = self.session.user.listenings.all().values_list('track', flat=True) listened = self.session.user.listenings.all().values_list('track', flat=True)
return Track.objects.exclude(pk__in=listened).order_by('?') return qs.exclude(pk__in=listened).order_by('?')
...@@ -51,7 +51,8 @@ def test_can_pick_by_weight(): ...@@ -51,7 +51,8 @@ def test_can_pick_by_weight():
def test_can_get_choices_for_favorites_radio(factories): def test_can_get_choices_for_favorites_radio(factories):
tracks = factories['music.Track'].create_batch(10) files = factories['music.TrackFile'].create_batch(10)
tracks = [f.track for f in files]
user = factories['users.User']() user = factories['users.User']()
for i in range(5): for i in range(5):
TrackFavorite.add(track=random.choice(tracks), user=user) TrackFavorite.add(track=random.choice(tracks), user=user)
...@@ -71,8 +72,12 @@ def test_can_get_choices_for_favorites_radio(factories): ...@@ -71,8 +72,12 @@ def test_can_get_choices_for_favorites_radio(factories):
def test_can_get_choices_for_custom_radio(factories): def test_can_get_choices_for_custom_radio(factories):
artist = factories['music.Artist']() artist = factories['music.Artist']()
tracks = factories['music.Track'].create_batch(5, artist=artist) files = factories['music.TrackFile'].create_batch(
wrong_tracks = factories['music.Track'].create_batch(5) 5, track__artist=artist)
tracks = [f.track for f in files]
wrong_files = factories['music.TrackFile'].create_batch(5)
wrong_tracks = [f.track for f in wrong_files]
session = factories['radios.CustomRadioSession']( session = factories['radios.CustomRadioSession'](
custom_radio__config=[{'type': 'artist', 'ids': [artist.pk]}] custom_radio__config=[{'type': 'artist', 'ids': [artist.pk]}]
) )
...@@ -113,7 +118,8 @@ def test_can_start_custom_radio_from_api(logged_in_client, factories): ...@@ -113,7 +118,8 @@ def test_can_start_custom_radio_from_api(logged_in_client, factories):
def test_can_use_radio_session_to_filter_choices(factories): def test_can_use_radio_session_to_filter_choices(factories):
tracks = factories['music.Track'].create_batch(30) files = factories['music.TrackFile'].create_batch(30)
tracks = [f.track for f in files]
user = factories['users.User']() user = factories['users.User']()
radio = radios.RandomRadio() radio = radios.RandomRadio()
session = radio.start_session(user) session = radio.start_session(user)
...@@ -156,8 +162,8 @@ def test_can_start_radio_for_anonymous_user(client, db): ...@@ -156,8 +162,8 @@ def test_can_start_radio_for_anonymous_user(client, db):
def test_can_get_track_for_session_from_api(factories, logged_in_client): def test_can_get_track_for_session_from_api(factories, logged_in_client):
tracks = factories['music.Track'].create_batch(size=1) files = factories['music.TrackFile'].create_batch(1)
tracks = [f.track for f in files]
url = reverse('api:v1:radios:sessions-list') url = reverse('api:v1:radios:sessions-list')
response = logged_in_client.post(url, {'radio_type': 'random'}) response = logged_in_client.post(url, {'radio_type': 'random'})
session = models.RadioSession.objects.latest('id') session = models.RadioSession.objects.latest('id')
...@@ -169,7 +175,7 @@ def test_can_get_track_for_session_from_api(factories, logged_in_client): ...@@ -169,7 +175,7 @@ def test_can_get_track_for_session_from_api(factories, logged_in_client):
assert data['track']['id'] == tracks[0].id assert data['track']['id'] == tracks[0].id
assert data['position'] == 1 assert data['position'] == 1
next_track = factories['music.Track']() next_track = factories['music.TrackFile']().track
response = logged_in_client.post(url, {'session': session.pk}) response = logged_in_client.post(url, {'session': session.pk})
data = json.loads(response.content.decode('utf-8')) data = json.loads(response.content.decode('utf-8'))
...@@ -193,8 +199,11 @@ def test_related_object_radio_validate_related_object(factories): ...@@ -193,8 +199,11 @@ def test_related_object_radio_validate_related_object(factories):
def test_can_start_artist_radio(factories): def test_can_start_artist_radio(factories):
user = factories['users.User']() user = factories['users.User']()
artist = factories['music.Artist']() artist = factories['music.Artist']()
wrong_tracks = factories['music.Track'].create_batch(5) wrong_files = factories['music.TrackFile'].create_batch(5)
good_tracks = factories['music.Track'].create_batch(5, artist=artist) wrong_tracks = [f.track for f in wrong_files]
good_files = factories['music.TrackFile'].create_batch(
5, track__artist=artist)
good_tracks = [f.track for f in good_files]
radio = radios.ArtistRadio() radio = radios.ArtistRadio()
session = radio.start_session(user, related_object=artist) session = radio.start_session(user, related_object=artist)
...@@ -206,8 +215,11 @@ def test_can_start_artist_radio(factories): ...@@ -206,8 +215,11 @@ def test_can_start_artist_radio(factories):
def test_can_start_tag_radio(factories): def test_can_start_tag_radio(factories):
user = factories['users.User']() user = factories['users.User']()
tag = factories['taggit.Tag']() tag = factories['taggit.Tag']()
wrong_tracks = factories['music.Track'].create_batch(5) wrong_files = factories['music.TrackFile'].create_batch(5)
good_tracks = factories['music.Track'].create_batch(5, tags=[tag]) wrong_tracks = [f.track for f in wrong_files]
good_files = factories['music.TrackFile'].create_batch(
5, track__tags=[tag])
good_tracks = [f.track for f in good_files]
radio = radios.TagRadio() radio = radios.TagRadio()
session = radio.start_session(user, related_object=tag) session = radio.start_session(user, related_object=tag)
...@@ -229,9 +241,11 @@ def test_can_start_artist_radio_from_api(client, factories): ...@@ -229,9 +241,11 @@ def test_can_start_artist_radio_from_api(client, factories):
def test_can_start_less_listened_radio(factories): def test_can_start_less_listened_radio(factories):
user = factories['users.User']() user = factories['users.User']()
history = factories['history.Listening'].create_batch(5, user=user) wrong_files = factories['music.TrackFile'].create_batch(5)
wrong_tracks = [h.track for h in history] for f in wrong_files:
good_tracks = factories['music.Track'].create_batch(size=5) factories['history.Listening'](track=f.track, user=user)
good_files = factories['music.TrackFile'].create_batch(5)
good_tracks = [f.track for f in good_files]
radio = radios.LessListenedRadio() radio = radios.LessListenedRadio()
session = radio.start_session(user) session = radio.start_session(user)
assert session.related_object == user assert session.related_object == user
......
Now exclude tracks without file from radio candidates (#88)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment