From f8b15a3f48882deca80e3341dcfae33ea19cce7f Mon Sep 17 00:00:00 2001 From: Eliot Berriot <contact@eliotberriot.com> Date: Tue, 20 Mar 2018 19:56:42 +0100 Subject: [PATCH] Added API endpoint to insert multiple tracks into playlist --- api/funkwhale_api/playlists/models.py | 31 +++++++++++++++++----- api/funkwhale_api/playlists/serializers.py | 22 ++++++++++++++- api/funkwhale_api/playlists/views.py | 27 ++++++++++++++++++- api/tests/playlists/test_models.py | 27 ++++++++++++++++--- api/tests/playlists/test_views.py | 17 ++++++++++++ 5 files changed, 112 insertions(+), 12 deletions(-) diff --git a/api/funkwhale_api/playlists/models.py b/api/funkwhale_api/playlists/models.py index f26dbeff..ed4307b6 100644 --- a/api/funkwhale_api/playlists/models.py +++ b/api/funkwhale_api/playlists/models.py @@ -1,8 +1,10 @@ -from django import forms +from django.conf import settings from django.db import models from django.db import transaction from django.utils import timezone +from rest_framework import exceptions + from funkwhale_api.common import fields @@ -40,17 +42,15 @@ class Playlist(models.Model): index = total if index > total: - raise forms.ValidationError('Index is not continuous') + raise exceptions.ValidationError('Index is not continuous') if index < 0: - raise forms.ValidationError('Index must be zero or positive') + raise exceptions.ValidationError('Index must be zero or positive') if move: # we remove the index temporarily, to avoid integrity errors plt.index = None plt.save(update_fields=['index']) - - if move: if index > old_index: # new index is higher than current, we decrement previous tracks to_update = existing.filter( @@ -58,8 +58,7 @@ class Playlist(models.Model): to_update.update(index=models.F('index') - 1) if index < old_index: # new index is lower than current, we increment next tracks - to_update = existing.filter( - index__lt=old_index, index__gte=index) + to_update = existing.filter(index__lt=old_index, index__gte=index) to_update.update(index=models.F('index') + 1) else: to_update = existing.filter(index__gte=index) @@ -77,6 +76,24 @@ class Playlist(models.Model): to_update = existing.filter(index__gt=index) return to_update.update(index=models.F('index') - 1) + @transaction.atomic + def insert_many(self, tracks): + existing = self.playlist_tracks.select_for_update() + now = timezone.now() + total = existing.filter(index__isnull=False).count() + if existing.count() + len(tracks) > settings.PLAYLISTS_MAX_TRACKS: + raise exceptions.ValidationError( + 'Playlist would reach the maximum of {} tracks'.format( + settings.PLAYLISTS_MAX_TRACKS)) + self.save(update_fields=['modification_date']) + start = total + plts = [ + PlaylistTrack( + creation_date=now, playlist=self, track=track, index=start+i) + for i, track in enumerate(tracks) + ] + return PlaylistTrack.objects.bulk_create(plts) + class PlaylistTrack(models.Model): track = models.ForeignKey( diff --git a/api/funkwhale_api/playlists/serializers.py b/api/funkwhale_api/playlists/serializers.py index 0680dc8d..2f71c823 100644 --- a/api/funkwhale_api/playlists/serializers.py +++ b/api/funkwhale_api/playlists/serializers.py @@ -3,8 +3,9 @@ from django.db import transaction from rest_framework import serializers from taggit.models import Tag +from funkwhale_api.music.models import Track from funkwhale_api.music.serializers import TrackSerializerNested - +from funkwhale_api.users.serializers import UserBasicSerializer from . import models @@ -61,20 +62,34 @@ class PlaylistTrackWriteSerializer(serializers.ModelSerializer): return [] +class PlaylistWriteSerializer(serializers.ModelSerializer): + + class Meta: + model = models.Playlist + fields = [ + 'id', + 'name', + 'privacy_level', + ] + + class PlaylistSerializer(serializers.ModelSerializer): tracks_count = serializers.SerializerMethodField() + user = UserBasicSerializer() class Meta: model = models.Playlist fields = ( 'id', 'name', + 'user', 'tracks_count', 'privacy_level', 'creation_date', 'modification_date') read_only_fields = [ 'id', + 'user', 'modification_date', 'creation_date',] @@ -84,3 +99,8 @@ class PlaylistSerializer(serializers.ModelSerializer): except AttributeError: # no annotation? return obj.playlist_tracks.count() + + +class PlaylistAddManySerializer(serializers.Serializer): + tracks = serializers.PrimaryKeyRelatedField( + many=True, queryset=Track.objects.for_nested_serialization()) diff --git a/api/funkwhale_api/playlists/views.py b/api/funkwhale_api/playlists/views.py index 797f7567..a077dec2 100644 --- a/api/funkwhale_api/playlists/views.py +++ b/api/funkwhale_api/playlists/views.py @@ -1,5 +1,7 @@ from django.db.models import Count +from django.db import transaction +from rest_framework import exceptions from rest_framework import generics, mixins, viewsets from rest_framework import status from rest_framework.decorators import detail_route @@ -25,7 +27,7 @@ class PlaylistViewSet( serializer_class = serializers.PlaylistSerializer queryset = ( - models.Playlist.objects.all() + models.Playlist.objects.all().select_related('user') .annotate(tracks_count=Count('playlist_tracks')) ) permission_classes = [ @@ -36,6 +38,11 @@ class PlaylistViewSet( owner_checks = ['write'] filter_class = filters.PlaylistFilter + def get_serializer_class(self): + if self.request.method in ['PUT', 'PATCH', 'DELETE', 'POST']: + return serializers.PlaylistWriteSerializer + return self.serializer_class + @detail_route(methods=['get']) def tracks(self, request, *args, **kwargs): playlist = self.get_object() @@ -47,6 +54,24 @@ class PlaylistViewSet( } return Response(data, status=200) + @detail_route(methods=['post']) + @transaction.atomic + def add(self, request, *args, **kwargs): + playlist = self.get_object() + serializer = serializers.PlaylistAddManySerializer(data=request.data) + serializer.is_valid(raise_exception=True) + try: + plts = playlist.insert_many(serializer.validated_data['tracks']) + except exceptions.ValidationError as e: + payload = {'playlist': e.detail} + return Response(payload, status=400) + serializer = serializers.PlaylistTrackSerializer(plts, many=True) + data = { + 'count': len(plts), + 'results': serializer.data + } + return Response(data, status=201) + def get_queryset(self): return self.queryset.filter( fields.privacy_level_query(self.request.user)) diff --git a/api/tests/playlists/test_models.py b/api/tests/playlists/test_models.py index 414ecaff..c9def4da 100644 --- a/api/tests/playlists/test_models.py +++ b/api/tests/playlists/test_models.py @@ -1,6 +1,6 @@ import pytest -from django import forms +from rest_framework import exceptions def test_can_insert_plt(factories): @@ -79,14 +79,14 @@ def test_can_insert_and_move_last_to_0(factories): def test_cannot_insert_at_wrong_index(factories): plt = factories['playlists.PlaylistTrack']() new = factories['playlists.PlaylistTrack'](playlist=plt.playlist) - with pytest.raises(forms.ValidationError): + with pytest.raises(exceptions.ValidationError): plt.playlist.insert(new, 2) def test_cannot_insert_at_negative_index(factories): plt = factories['playlists.PlaylistTrack']() new = factories['playlists.PlaylistTrack'](playlist=plt.playlist) - with pytest.raises(forms.ValidationError): + with pytest.raises(exceptions.ValidationError): plt.playlist.insert(new, -1) @@ -103,3 +103,24 @@ def test_remove_update_indexes(factories): assert first.index == 0 assert third.index == 1 + + +def test_can_insert_many(factories): + playlist = factories['playlists.Playlist']() + existing = factories['playlists.PlaylistTrack'](playlist=playlist, index=0) + tracks = factories['music.Track'].create_batch(size=3) + plts = playlist.insert_many(tracks) + for i, plt in enumerate(plts): + assert plt.index == i + 1 + assert plt.track == tracks[i] + assert plt.playlist == playlist + + +def test_insert_many_honor_max_tracks(factories, settings): + settings.PLAYLISTS_MAX_TRACKS = 4 + playlist = factories['playlists.Playlist']() + plts = factories['playlists.PlaylistTrack'].create_batch( + size=2, playlist=playlist) + track = factories['music.Track']() + with pytest.raises(exceptions.ValidationError): + playlist.insert_many([track, track, track]) diff --git a/api/tests/playlists/test_views.py b/api/tests/playlists/test_views.py index df7d04a2..ae3fd007 100644 --- a/api/tests/playlists/test_views.py +++ b/api/tests/playlists/test_views.py @@ -153,3 +153,20 @@ def test_can_list_tracks_from_playlist( assert response.data['count'] == 1 assert response.data['results'][0] == serialized_plt + + +def test_can_add_multiple_tracks_at_once_via_api( + factories, mocker, logged_in_api_client): + playlist = factories['playlists.Playlist'](user=logged_in_api_client.user) + tracks = factories['music.Track'].create_batch(size=5) + track_ids = [t.id for t in tracks] + mocker.spy(playlist, 'insert_many') + url = reverse('api:v1:playlists-add', kwargs={'pk': playlist.pk}) + response = logged_in_api_client.post(url, {'tracks': track_ids}) + + assert response.status_code == 201 + assert playlist.playlist_tracks.count() == len(track_ids) + + for plt in playlist.playlist_tracks.order_by('index'): + assert response.data['results'][plt.index]['id'] == plt.id + assert plt.track == tracks[plt.index] -- GitLab