From f8b15a3f48882deca80e3341dcfae33ea19cce7f Mon Sep 17 00:00:00 2001
From: Eliot Berriot <contact@eliotberriot.com>
Date: Tue, 20 Mar 2018 19:56:42 +0100
Subject: [PATCH] Added API endpoint to insert multiple tracks into playlist

---
 api/funkwhale_api/playlists/models.py      | 31 +++++++++++++++++-----
 api/funkwhale_api/playlists/serializers.py | 22 ++++++++++++++-
 api/funkwhale_api/playlists/views.py       | 27 ++++++++++++++++++-
 api/tests/playlists/test_models.py         | 27 ++++++++++++++++---
 api/tests/playlists/test_views.py          | 17 ++++++++++++
 5 files changed, 112 insertions(+), 12 deletions(-)

diff --git a/api/funkwhale_api/playlists/models.py b/api/funkwhale_api/playlists/models.py
index f26dbeff..ed4307b6 100644
--- a/api/funkwhale_api/playlists/models.py
+++ b/api/funkwhale_api/playlists/models.py
@@ -1,8 +1,10 @@
-from django import forms
+from django.conf import settings
 from django.db import models
 from django.db import transaction
 from django.utils import timezone
 
+from rest_framework import exceptions
+
 from funkwhale_api.common import fields
 
 
@@ -40,17 +42,15 @@ class Playlist(models.Model):
             index = total
 
         if index > total:
-            raise forms.ValidationError('Index is not continuous')
+            raise exceptions.ValidationError('Index is not continuous')
 
         if index < 0:
-            raise forms.ValidationError('Index must be zero or positive')
+            raise exceptions.ValidationError('Index must be zero or positive')
 
         if move:
             # we remove the index temporarily, to avoid integrity errors
             plt.index = None
             plt.save(update_fields=['index'])
-
-        if move:
             if index > old_index:
                 # new index is higher than current, we decrement previous tracks
                 to_update = existing.filter(
@@ -58,8 +58,7 @@ class Playlist(models.Model):
                 to_update.update(index=models.F('index') - 1)
             if index < old_index:
                 # new index is lower than current, we increment next tracks
-                to_update = existing.filter(
-                    index__lt=old_index, index__gte=index)
+                to_update = existing.filter(index__lt=old_index, index__gte=index)
                 to_update.update(index=models.F('index') + 1)
         else:
             to_update = existing.filter(index__gte=index)
@@ -77,6 +76,24 @@ class Playlist(models.Model):
         to_update = existing.filter(index__gt=index)
         return to_update.update(index=models.F('index') - 1)
 
+    @transaction.atomic
+    def insert_many(self, tracks):
+        existing = self.playlist_tracks.select_for_update()
+        now = timezone.now()
+        total = existing.filter(index__isnull=False).count()
+        if existing.count() + len(tracks) > settings.PLAYLISTS_MAX_TRACKS:
+            raise exceptions.ValidationError(
+                'Playlist would reach the maximum of {} tracks'.format(
+                    settings.PLAYLISTS_MAX_TRACKS))
+        self.save(update_fields=['modification_date'])
+        start = total
+        plts = [
+            PlaylistTrack(
+                creation_date=now, playlist=self, track=track, index=start+i)
+            for i, track in enumerate(tracks)
+        ]
+        return PlaylistTrack.objects.bulk_create(plts)
+
 
 class PlaylistTrack(models.Model):
     track = models.ForeignKey(
diff --git a/api/funkwhale_api/playlists/serializers.py b/api/funkwhale_api/playlists/serializers.py
index 0680dc8d..2f71c823 100644
--- a/api/funkwhale_api/playlists/serializers.py
+++ b/api/funkwhale_api/playlists/serializers.py
@@ -3,8 +3,9 @@ from django.db import transaction
 from rest_framework import serializers
 from taggit.models import Tag
 
+from funkwhale_api.music.models import Track
 from funkwhale_api.music.serializers import TrackSerializerNested
-
+from funkwhale_api.users.serializers import UserBasicSerializer
 from . import models
 
 
@@ -61,20 +62,34 @@ class PlaylistTrackWriteSerializer(serializers.ModelSerializer):
         return []
 
 
+class PlaylistWriteSerializer(serializers.ModelSerializer):
+
+    class Meta:
+        model = models.Playlist
+        fields = [
+            'id',
+            'name',
+            'privacy_level',
+        ]
+
+
 class PlaylistSerializer(serializers.ModelSerializer):
     tracks_count = serializers.SerializerMethodField()
+    user = UserBasicSerializer()
 
     class Meta:
         model = models.Playlist
         fields = (
             'id',
             'name',
+            'user',
             'tracks_count',
             'privacy_level',
             'creation_date',
             'modification_date')
         read_only_fields = [
             'id',
+            'user',
             'modification_date',
             'creation_date',]
 
@@ -84,3 +99,8 @@ class PlaylistSerializer(serializers.ModelSerializer):
         except AttributeError:
             # no annotation?
             return obj.playlist_tracks.count()
+
+
+class PlaylistAddManySerializer(serializers.Serializer):
+    tracks = serializers.PrimaryKeyRelatedField(
+        many=True, queryset=Track.objects.for_nested_serialization())
diff --git a/api/funkwhale_api/playlists/views.py b/api/funkwhale_api/playlists/views.py
index 797f7567..a077dec2 100644
--- a/api/funkwhale_api/playlists/views.py
+++ b/api/funkwhale_api/playlists/views.py
@@ -1,5 +1,7 @@
 from django.db.models import Count
+from django.db import transaction
 
+from rest_framework import exceptions
 from rest_framework import generics, mixins, viewsets
 from rest_framework import status
 from rest_framework.decorators import detail_route
@@ -25,7 +27,7 @@ class PlaylistViewSet(
 
     serializer_class = serializers.PlaylistSerializer
     queryset = (
-        models.Playlist.objects.all()
+        models.Playlist.objects.all().select_related('user')
               .annotate(tracks_count=Count('playlist_tracks'))
     )
     permission_classes = [
@@ -36,6 +38,11 @@ class PlaylistViewSet(
     owner_checks = ['write']
     filter_class = filters.PlaylistFilter
 
+    def get_serializer_class(self):
+        if self.request.method in ['PUT', 'PATCH', 'DELETE', 'POST']:
+            return serializers.PlaylistWriteSerializer
+        return self.serializer_class
+
     @detail_route(methods=['get'])
     def tracks(self, request, *args, **kwargs):
         playlist = self.get_object()
@@ -47,6 +54,24 @@ class PlaylistViewSet(
         }
         return Response(data, status=200)
 
+    @detail_route(methods=['post'])
+    @transaction.atomic
+    def add(self, request, *args, **kwargs):
+        playlist = self.get_object()
+        serializer = serializers.PlaylistAddManySerializer(data=request.data)
+        serializer.is_valid(raise_exception=True)
+        try:
+            plts = playlist.insert_many(serializer.validated_data['tracks'])
+        except exceptions.ValidationError as e:
+            payload = {'playlist': e.detail}
+            return Response(payload, status=400)
+        serializer = serializers.PlaylistTrackSerializer(plts, many=True)
+        data = {
+            'count': len(plts),
+            'results': serializer.data
+        }
+        return Response(data, status=201)
+
     def get_queryset(self):
         return self.queryset.filter(
             fields.privacy_level_query(self.request.user))
diff --git a/api/tests/playlists/test_models.py b/api/tests/playlists/test_models.py
index 414ecaff..c9def4da 100644
--- a/api/tests/playlists/test_models.py
+++ b/api/tests/playlists/test_models.py
@@ -1,6 +1,6 @@
 import pytest
 
-from django import forms
+from rest_framework import exceptions
 
 
 def test_can_insert_plt(factories):
@@ -79,14 +79,14 @@ def test_can_insert_and_move_last_to_0(factories):
 def test_cannot_insert_at_wrong_index(factories):
     plt = factories['playlists.PlaylistTrack']()
     new = factories['playlists.PlaylistTrack'](playlist=plt.playlist)
-    with pytest.raises(forms.ValidationError):
+    with pytest.raises(exceptions.ValidationError):
         plt.playlist.insert(new, 2)
 
 
 def test_cannot_insert_at_negative_index(factories):
     plt = factories['playlists.PlaylistTrack']()
     new = factories['playlists.PlaylistTrack'](playlist=plt.playlist)
-    with pytest.raises(forms.ValidationError):
+    with pytest.raises(exceptions.ValidationError):
         plt.playlist.insert(new, -1)
 
 
@@ -103,3 +103,24 @@ def test_remove_update_indexes(factories):
 
     assert first.index == 0
     assert third.index == 1
+
+
+def test_can_insert_many(factories):
+    playlist = factories['playlists.Playlist']()
+    existing = factories['playlists.PlaylistTrack'](playlist=playlist, index=0)
+    tracks = factories['music.Track'].create_batch(size=3)
+    plts = playlist.insert_many(tracks)
+    for i, plt in enumerate(plts):
+        assert plt.index == i + 1
+        assert plt.track == tracks[i]
+        assert plt.playlist == playlist
+
+
+def test_insert_many_honor_max_tracks(factories, settings):
+    settings.PLAYLISTS_MAX_TRACKS = 4
+    playlist = factories['playlists.Playlist']()
+    plts = factories['playlists.PlaylistTrack'].create_batch(
+        size=2, playlist=playlist)
+    track = factories['music.Track']()
+    with pytest.raises(exceptions.ValidationError):
+        playlist.insert_many([track, track, track])
diff --git a/api/tests/playlists/test_views.py b/api/tests/playlists/test_views.py
index df7d04a2..ae3fd007 100644
--- a/api/tests/playlists/test_views.py
+++ b/api/tests/playlists/test_views.py
@@ -153,3 +153,20 @@ def test_can_list_tracks_from_playlist(
 
     assert response.data['count'] == 1
     assert response.data['results'][0] == serialized_plt
+
+
+def test_can_add_multiple_tracks_at_once_via_api(
+        factories, mocker, logged_in_api_client):
+    playlist = factories['playlists.Playlist'](user=logged_in_api_client.user)
+    tracks = factories['music.Track'].create_batch(size=5)
+    track_ids = [t.id for t in tracks]
+    mocker.spy(playlist, 'insert_many')
+    url = reverse('api:v1:playlists-add', kwargs={'pk': playlist.pk})
+    response = logged_in_api_client.post(url, {'tracks': track_ids})
+
+    assert response.status_code == 201
+    assert playlist.playlist_tracks.count() == len(track_ids)
+
+    for plt in playlist.playlist_tracks.order_by('index'):
+        assert response.data['results'][plt.index]['id'] == plt.id
+        assert plt.track == tracks[plt.index]
-- 
GitLab