Skip to content
Snippets Groups Projects
Verified Commit ec3fcefa authored by Eliot Berriot's avatar Eliot Berriot
Browse files

Ensure radio tracks only return playable tracks

parent 4d425e92
Branches
Tags
No related merge requests found
...@@ -15,7 +15,7 @@ redeliver_deliveries.short_description = "Redeliver" ...@@ -15,7 +15,7 @@ redeliver_deliveries.short_description = "Redeliver"
def redeliver_activities(modeladmin, request, queryset): def redeliver_activities(modeladmin, request, queryset):
for activity in queryset.select_related("actor__user"): for activity in queryset.select_related("actor__user"):
if activity.actor.is_local: if activity.actor.get_user():
tasks.dispatch_outbox.delay(activity_id=activity.pk) tasks.dispatch_outbox.delay(activity_id=activity.pk)
else: else:
tasks.dispatch_inbox.delay(activity_id=activity.pk) tasks.dispatch_inbox.delay(activity_id=activity.pk)
......
...@@ -94,7 +94,7 @@ class LibraryViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet): ...@@ -94,7 +94,7 @@ class LibraryViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet):
queryset = ( queryset = (
music_models.Library.objects.all() music_models.Library.objects.all()
.order_by("-creation_date") .order_by("-creation_date")
.select_related("actor") .select_related("actor__user")
.annotate(_uploads_count=Count("uploads")) .annotate(_uploads_count=Count("uploads"))
) )
serializer_class = api_serializers.LibrarySerializer serializer_class = api_serializers.LibrarySerializer
...@@ -107,7 +107,7 @@ class LibraryViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet): ...@@ -107,7 +107,7 @@ class LibraryViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet):
@decorators.detail_route(methods=["post"]) @decorators.detail_route(methods=["post"])
def scan(self, request, *args, **kwargs): def scan(self, request, *args, **kwargs):
library = self.get_object() library = self.get_object()
if library.actor.is_local: if library.actor.get_user():
return response.Response({"status": "skipped"}, 200) return response.Response({"status": "skipped"}, 200)
scan = library.schedule_scan(actor=request.user.actor) scan = library.schedule_scan(actor=request.user.actor)
......
...@@ -576,7 +576,7 @@ TRACK_FILE_IMPORT_STATUS_CHOICES = ( ...@@ -576,7 +576,7 @@ TRACK_FILE_IMPORT_STATUS_CHOICES = (
def get_file_path(instance, filename): def get_file_path(instance, filename):
if instance.library.actor.is_local: if instance.library.actor.get_user():
return common_utils.ChunkedPath("tracks")(instance, filename) return common_utils.ChunkedPath("tracks")(instance, filename)
else: else:
# we cache remote tracks in a different directory # we cache remote tracks in a different directory
...@@ -725,7 +725,7 @@ class Upload(models.Model): ...@@ -725,7 +725,7 @@ class Upload(models.Model):
self.mimetype = mimetypes.guess_type(self.source)[0] self.mimetype = mimetypes.guess_type(self.source)[0]
if not self.size and self.audio_file: if not self.size and self.audio_file:
self.size = self.audio_file.size self.size = self.audio_file.size
if not self.pk and not self.fid and self.library.actor.is_local: if not self.pk and not self.fid and self.library.actor.get_user():
self.fid = self.get_federation_id() self.fid = self.get_federation_id()
return super().save(**kwargs) return super().save(**kwargs)
...@@ -908,7 +908,7 @@ class Library(federation_models.FederationMixin): ...@@ -908,7 +908,7 @@ class Library(federation_models.FederationMixin):
def should_autoapprove_follow(self, actor): def should_autoapprove_follow(self, actor):
if self.privacy_level == "everyone": if self.privacy_level == "everyone":
return True return True
if self.privacy_level == "instance" and actor.is_local: if self.privacy_level == "instance" and actor.get_user():
return True return True
return False return False
......
...@@ -54,6 +54,8 @@ class SessionRadio(SimpleRadio): ...@@ -54,6 +54,8 @@ class SessionRadio(SimpleRadio):
queryset = self.get_queryset(**kwargs) queryset = self.get_queryset(**kwargs)
if self.session: if self.session:
queryset = self.filter_from_session(queryset) queryset = self.filter_from_session(queryset)
if kwargs.pop("filter_playable", True):
queryset = queryset.playable_by(self.session.user.actor)
return queryset return queryset
def filter_from_session(self, queryset): def filter_from_session(self, queryset):
......
...@@ -76,7 +76,7 @@ def test_can_get_choices_for_custom_radio(factories): ...@@ -76,7 +76,7 @@ def test_can_get_choices_for_custom_radio(factories):
session = factories["radios.CustomRadioSession"]( session = factories["radios.CustomRadioSession"](
custom_radio__config=[{"type": "artist", "ids": [artist.pk]}] custom_radio__config=[{"type": "artist", "ids": [artist.pk]}]
) )
choices = session.radio.get_choices() choices = session.radio.get_choices(filter_playable=False)
expected = [t.pk for t in tracks] expected = [t.pk for t in tracks]
assert list(choices.values_list("id", flat=True)) == expected assert list(choices.values_list("id", flat=True)) == expected
...@@ -94,19 +94,19 @@ def test_cannot_start_custom_radio_if_not_owner_or_not_public(factories): ...@@ -94,19 +94,19 @@ def test_cannot_start_custom_radio_if_not_owner_or_not_public(factories):
assert message in serializer.errors["non_field_errors"] assert message in serializer.errors["non_field_errors"]
def test_can_start_custom_radio_from_api(logged_in_client, factories): def test_can_start_custom_radio_from_api(logged_in_api_client, factories):
artist = factories["music.Artist"]() artist = factories["music.Artist"]()
radio = factories["radios.Radio"]( radio = factories["radios.Radio"](
config=[{"type": "artist", "ids": [artist.pk]}], user=logged_in_client.user config=[{"type": "artist", "ids": [artist.pk]}], user=logged_in_api_client.user
) )
url = reverse("api:v1:radios:sessions-list") url = reverse("api:v1:radios:sessions-list")
response = logged_in_client.post( response = logged_in_api_client.post(
url, {"radio_type": "custom", "custom_radio": radio.pk} url, {"radio_type": "custom", "custom_radio": radio.pk}
) )
assert response.status_code == 201 assert response.status_code == 201
session = radio.sessions.latest("id") session = radio.sessions.latest("id")
assert session.radio_type == "custom" assert session.radio_type == "custom"
assert session.user == logged_in_client.user assert session.user == logged_in_api_client.user
def test_can_use_radio_session_to_filter_choices(factories): def test_can_use_radio_session_to_filter_choices(factories):
...@@ -116,7 +116,7 @@ def test_can_use_radio_session_to_filter_choices(factories): ...@@ -116,7 +116,7 @@ def test_can_use_radio_session_to_filter_choices(factories):
session = radio.start_session(user) session = radio.start_session(user)
for i in range(10): for i in range(10):
radio.pick() radio.pick(filter_playable=False)
# ensure 10 differents tracks have been suggested # ensure 10 differents tracks have been suggested
tracks_id = [ tracks_id = [
...@@ -134,30 +134,34 @@ def test_can_restore_radio_from_previous_session(factories): ...@@ -134,30 +134,34 @@ def test_can_restore_radio_from_previous_session(factories):
assert radio.session == restarted_radio.session assert radio.session == restarted_radio.session
def test_can_start_radio_for_logged_in_user(logged_in_client): def test_can_start_radio_for_logged_in_user(logged_in_api_client):
url = reverse("api:v1:radios:sessions-list") url = reverse("api:v1:radios:sessions-list")
logged_in_client.post(url, {"radio_type": "random"}) logged_in_api_client.post(url, {"radio_type": "random"})
session = models.RadioSession.objects.latest("id") session = models.RadioSession.objects.latest("id")
assert session.radio_type == "random" assert session.radio_type == "random"
assert session.user == logged_in_client.user assert session.user == logged_in_api_client.user
def test_can_get_track_for_session_from_api(factories, logged_in_client): def test_can_get_track_for_session_from_api(factories, logged_in_api_client):
files = factories["music.Upload"].create_batch(1) actor = logged_in_api_client.user.create_actor()
tracks = [f.track for f in files] track = factories["music.Upload"](
library__actor=actor, import_status="finished"
).track
url = reverse("api:v1:radios:sessions-list") url = reverse("api:v1:radios:sessions-list")
response = logged_in_client.post(url, {"radio_type": "random"}) response = logged_in_api_client.post(url, {"radio_type": "random"})
session = models.RadioSession.objects.latest("id") session = models.RadioSession.objects.latest("id")
url = reverse("api:v1:radios:tracks-list") url = reverse("api:v1:radios:tracks-list")
response = logged_in_client.post(url, {"session": session.pk}) response = logged_in_api_client.post(url, {"session": session.pk})
data = json.loads(response.content.decode("utf-8")) data = json.loads(response.content.decode("utf-8"))
assert data["track"]["id"] == tracks[0].id assert data["track"]["id"] == track.pk
assert data["position"] == 1 assert data["position"] == 1
next_track = factories["music.Upload"]().track next_track = factories["music.Upload"](
response = logged_in_client.post(url, {"session": session.pk}) library__actor=actor, import_status="finished"
).track
response = logged_in_api_client.post(url, {"session": session.pk})
data = json.loads(response.content.decode("utf-8")) data = json.loads(response.content.decode("utf-8"))
assert data["track"]["id"] == next_track.id assert data["track"]["id"] == next_track.id
...@@ -188,7 +192,7 @@ def test_can_start_artist_radio(factories): ...@@ -188,7 +192,7 @@ def test_can_start_artist_radio(factories):
session = radio.start_session(user, related_object=artist) session = radio.start_session(user, related_object=artist)
assert session.radio_type == "artist" assert session.radio_type == "artist"
for i in range(5): for i in range(5):
assert radio.pick() in good_tracks assert radio.pick(filter_playable=False) in good_tracks
def test_can_start_tag_radio(factories): def test_can_start_tag_radio(factories):
...@@ -202,7 +206,7 @@ def test_can_start_tag_radio(factories): ...@@ -202,7 +206,7 @@ def test_can_start_tag_radio(factories):
session = radio.start_session(user, related_object=tag) session = radio.start_session(user, related_object=tag)
assert session.radio_type == "tag" assert session.radio_type == "tag"
for i in range(5): for i in range(5):
assert radio.pick() in good_tracks assert radio.pick(filter_playable=False) in good_tracks
def test_can_start_artist_radio_from_api(logged_in_api_client, preferences, factories): def test_can_start_artist_radio_from_api(logged_in_api_client, preferences, factories):
...@@ -232,4 +236,4 @@ def test_can_start_less_listened_radio(factories): ...@@ -232,4 +236,4 @@ def test_can_start_less_listened_radio(factories):
radio.start_session(user) radio.start_session(user)
for i in range(5): for i in range(5):
assert radio.pick() in good_tracks assert radio.pick(filter_playable=False) in good_tracks
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment