Verified Commit f8b15a3f authored by Eliot Berriot's avatar Eliot Berriot
Browse files

Added API endpoint to insert multiple tracks into playlist

parent 1729c4f8
from django import forms
from django.conf import settings
from django.db import models
from django.db import transaction
from django.utils import timezone
from rest_framework import exceptions
from funkwhale_api.common import fields
......@@ -40,17 +42,15 @@ class Playlist(models.Model):
index = total
if index > total:
raise forms.ValidationError('Index is not continuous')
raise exceptions.ValidationError('Index is not continuous')
if index < 0:
raise forms.ValidationError('Index must be zero or positive')
raise exceptions.ValidationError('Index must be zero or positive')
if move:
# we remove the index temporarily, to avoid integrity errors
plt.index = None
plt.save(update_fields=['index'])
if move:
if index > old_index:
# new index is higher than current, we decrement previous tracks
to_update = existing.filter(
......@@ -58,8 +58,7 @@ class Playlist(models.Model):
to_update.update(index=models.F('index') - 1)
if index < old_index:
# new index is lower than current, we increment next tracks
to_update = existing.filter(
index__lt=old_index, index__gte=index)
to_update = existing.filter(index__lt=old_index, index__gte=index)
to_update.update(index=models.F('index') + 1)
else:
to_update = existing.filter(index__gte=index)
......@@ -77,6 +76,24 @@ class Playlist(models.Model):
to_update = existing.filter(index__gt=index)
return to_update.update(index=models.F('index') - 1)
@transaction.atomic
def insert_many(self, tracks):
existing = self.playlist_tracks.select_for_update()
now = timezone.now()
total = existing.filter(index__isnull=False).count()
if existing.count() + len(tracks) > settings.PLAYLISTS_MAX_TRACKS:
raise exceptions.ValidationError(
'Playlist would reach the maximum of {} tracks'.format(
settings.PLAYLISTS_MAX_TRACKS))
self.save(update_fields=['modification_date'])
start = total
plts = [
PlaylistTrack(
creation_date=now, playlist=self, track=track, index=start+i)
for i, track in enumerate(tracks)
]
return PlaylistTrack.objects.bulk_create(plts)
class PlaylistTrack(models.Model):
track = models.ForeignKey(
......
......@@ -3,8 +3,9 @@ from django.db import transaction
from rest_framework import serializers
from taggit.models import Tag
from funkwhale_api.music.models import Track
from funkwhale_api.music.serializers import TrackSerializerNested
from funkwhale_api.users.serializers import UserBasicSerializer
from . import models
......@@ -61,20 +62,34 @@ class PlaylistTrackWriteSerializer(serializers.ModelSerializer):
return []
class PlaylistWriteSerializer(serializers.ModelSerializer):
class Meta:
model = models.Playlist
fields = [
'id',
'name',
'privacy_level',
]
class PlaylistSerializer(serializers.ModelSerializer):
tracks_count = serializers.SerializerMethodField()
user = UserBasicSerializer()
class Meta:
model = models.Playlist
fields = (
'id',
'name',
'user',
'tracks_count',
'privacy_level',
'creation_date',
'modification_date')
read_only_fields = [
'id',
'user',
'modification_date',
'creation_date',]
......@@ -84,3 +99,8 @@ class PlaylistSerializer(serializers.ModelSerializer):
except AttributeError:
# no annotation?
return obj.playlist_tracks.count()
class PlaylistAddManySerializer(serializers.Serializer):
tracks = serializers.PrimaryKeyRelatedField(
many=True, queryset=Track.objects.for_nested_serialization())
from django.db.models import Count
from django.db import transaction
from rest_framework import exceptions
from rest_framework import generics, mixins, viewsets
from rest_framework import status
from rest_framework.decorators import detail_route
......@@ -25,7 +27,7 @@ class PlaylistViewSet(
serializer_class = serializers.PlaylistSerializer
queryset = (
models.Playlist.objects.all()
models.Playlist.objects.all().select_related('user')
.annotate(tracks_count=Count('playlist_tracks'))
)
permission_classes = [
......@@ -36,6 +38,11 @@ class PlaylistViewSet(
owner_checks = ['write']
filter_class = filters.PlaylistFilter
def get_serializer_class(self):
if self.request.method in ['PUT', 'PATCH', 'DELETE', 'POST']:
return serializers.PlaylistWriteSerializer
return self.serializer_class
@detail_route(methods=['get'])
def tracks(self, request, *args, **kwargs):
playlist = self.get_object()
......@@ -47,6 +54,24 @@ class PlaylistViewSet(
}
return Response(data, status=200)
@detail_route(methods=['post'])
@transaction.atomic
def add(self, request, *args, **kwargs):
playlist = self.get_object()
serializer = serializers.PlaylistAddManySerializer(data=request.data)
serializer.is_valid(raise_exception=True)
try:
plts = playlist.insert_many(serializer.validated_data['tracks'])
except exceptions.ValidationError as e:
payload = {'playlist': e.detail}
return Response(payload, status=400)
serializer = serializers.PlaylistTrackSerializer(plts, many=True)
data = {
'count': len(plts),
'results': serializer.data
}
return Response(data, status=201)
def get_queryset(self):
return self.queryset.filter(
fields.privacy_level_query(self.request.user))
......
import pytest
from django import forms
from rest_framework import exceptions
def test_can_insert_plt(factories):
......@@ -79,14 +79,14 @@ def test_can_insert_and_move_last_to_0(factories):
def test_cannot_insert_at_wrong_index(factories):
plt = factories['playlists.PlaylistTrack']()
new = factories['playlists.PlaylistTrack'](playlist=plt.playlist)
with pytest.raises(forms.ValidationError):
with pytest.raises(exceptions.ValidationError):
plt.playlist.insert(new, 2)
def test_cannot_insert_at_negative_index(factories):
plt = factories['playlists.PlaylistTrack']()
new = factories['playlists.PlaylistTrack'](playlist=plt.playlist)
with pytest.raises(forms.ValidationError):
with pytest.raises(exceptions.ValidationError):
plt.playlist.insert(new, -1)
......@@ -103,3 +103,24 @@ def test_remove_update_indexes(factories):
assert first.index == 0
assert third.index == 1
def test_can_insert_many(factories):
playlist = factories['playlists.Playlist']()
existing = factories['playlists.PlaylistTrack'](playlist=playlist, index=0)
tracks = factories['music.Track'].create_batch(size=3)
plts = playlist.insert_many(tracks)
for i, plt in enumerate(plts):
assert plt.index == i + 1
assert plt.track == tracks[i]
assert plt.playlist == playlist
def test_insert_many_honor_max_tracks(factories, settings):
settings.PLAYLISTS_MAX_TRACKS = 4
playlist = factories['playlists.Playlist']()
plts = factories['playlists.PlaylistTrack'].create_batch(
size=2, playlist=playlist)
track = factories['music.Track']()
with pytest.raises(exceptions.ValidationError):
playlist.insert_many([track, track, track])
......@@ -153,3 +153,20 @@ def test_can_list_tracks_from_playlist(
assert response.data['count'] == 1
assert response.data['results'][0] == serialized_plt
def test_can_add_multiple_tracks_at_once_via_api(
factories, mocker, logged_in_api_client):
playlist = factories['playlists.Playlist'](user=logged_in_api_client.user)
tracks = factories['music.Track'].create_batch(size=5)
track_ids = [t.id for t in tracks]
mocker.spy(playlist, 'insert_many')
url = reverse('api:v1:playlists-add', kwargs={'pk': playlist.pk})
response = logged_in_api_client.post(url, {'tracks': track_ids})
assert response.status_code == 201
assert playlist.playlist_tracks.count() == len(track_ids)
for plt in playlist.playlist_tracks.order_by('index'):
assert response.data['results'][plt.index]['id'] == plt.id
assert plt.track == tracks[plt.index]
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment