Skip to content
Snippets Groups Projects
Verified Commit f8b15a3f authored by Eliot Berriot's avatar Eliot Berriot
Browse files

Added API endpoint to insert multiple tracks into playlist

parent 1729c4f8
No related branches found
No related tags found
No related merge requests found
from django import forms from django.conf import settings
from django.db import models from django.db import models
from django.db import transaction from django.db import transaction
from django.utils import timezone from django.utils import timezone
from rest_framework import exceptions
from funkwhale_api.common import fields from funkwhale_api.common import fields
...@@ -40,17 +42,15 @@ class Playlist(models.Model): ...@@ -40,17 +42,15 @@ class Playlist(models.Model):
index = total index = total
if index > total: if index > total:
raise forms.ValidationError('Index is not continuous') raise exceptions.ValidationError('Index is not continuous')
if index < 0: if index < 0:
raise forms.ValidationError('Index must be zero or positive') raise exceptions.ValidationError('Index must be zero or positive')
if move: if move:
# we remove the index temporarily, to avoid integrity errors # we remove the index temporarily, to avoid integrity errors
plt.index = None plt.index = None
plt.save(update_fields=['index']) plt.save(update_fields=['index'])
if move:
if index > old_index: if index > old_index:
# new index is higher than current, we decrement previous tracks # new index is higher than current, we decrement previous tracks
to_update = existing.filter( to_update = existing.filter(
...@@ -58,8 +58,7 @@ class Playlist(models.Model): ...@@ -58,8 +58,7 @@ class Playlist(models.Model):
to_update.update(index=models.F('index') - 1) to_update.update(index=models.F('index') - 1)
if index < old_index: if index < old_index:
# new index is lower than current, we increment next tracks # new index is lower than current, we increment next tracks
to_update = existing.filter( to_update = existing.filter(index__lt=old_index, index__gte=index)
index__lt=old_index, index__gte=index)
to_update.update(index=models.F('index') + 1) to_update.update(index=models.F('index') + 1)
else: else:
to_update = existing.filter(index__gte=index) to_update = existing.filter(index__gte=index)
...@@ -77,6 +76,24 @@ class Playlist(models.Model): ...@@ -77,6 +76,24 @@ class Playlist(models.Model):
to_update = existing.filter(index__gt=index) to_update = existing.filter(index__gt=index)
return to_update.update(index=models.F('index') - 1) return to_update.update(index=models.F('index') - 1)
@transaction.atomic
def insert_many(self, tracks):
existing = self.playlist_tracks.select_for_update()
now = timezone.now()
total = existing.filter(index__isnull=False).count()
if existing.count() + len(tracks) > settings.PLAYLISTS_MAX_TRACKS:
raise exceptions.ValidationError(
'Playlist would reach the maximum of {} tracks'.format(
settings.PLAYLISTS_MAX_TRACKS))
self.save(update_fields=['modification_date'])
start = total
plts = [
PlaylistTrack(
creation_date=now, playlist=self, track=track, index=start+i)
for i, track in enumerate(tracks)
]
return PlaylistTrack.objects.bulk_create(plts)
class PlaylistTrack(models.Model): class PlaylistTrack(models.Model):
track = models.ForeignKey( track = models.ForeignKey(
......
...@@ -3,8 +3,9 @@ from django.db import transaction ...@@ -3,8 +3,9 @@ from django.db import transaction
from rest_framework import serializers from rest_framework import serializers
from taggit.models import Tag from taggit.models import Tag
from funkwhale_api.music.models import Track
from funkwhale_api.music.serializers import TrackSerializerNested from funkwhale_api.music.serializers import TrackSerializerNested
from funkwhale_api.users.serializers import UserBasicSerializer
from . import models from . import models
...@@ -61,20 +62,34 @@ class PlaylistTrackWriteSerializer(serializers.ModelSerializer): ...@@ -61,20 +62,34 @@ class PlaylistTrackWriteSerializer(serializers.ModelSerializer):
return [] return []
class PlaylistWriteSerializer(serializers.ModelSerializer):
class Meta:
model = models.Playlist
fields = [
'id',
'name',
'privacy_level',
]
class PlaylistSerializer(serializers.ModelSerializer): class PlaylistSerializer(serializers.ModelSerializer):
tracks_count = serializers.SerializerMethodField() tracks_count = serializers.SerializerMethodField()
user = UserBasicSerializer()
class Meta: class Meta:
model = models.Playlist model = models.Playlist
fields = ( fields = (
'id', 'id',
'name', 'name',
'user',
'tracks_count', 'tracks_count',
'privacy_level', 'privacy_level',
'creation_date', 'creation_date',
'modification_date') 'modification_date')
read_only_fields = [ read_only_fields = [
'id', 'id',
'user',
'modification_date', 'modification_date',
'creation_date',] 'creation_date',]
...@@ -84,3 +99,8 @@ class PlaylistSerializer(serializers.ModelSerializer): ...@@ -84,3 +99,8 @@ class PlaylistSerializer(serializers.ModelSerializer):
except AttributeError: except AttributeError:
# no annotation? # no annotation?
return obj.playlist_tracks.count() return obj.playlist_tracks.count()
class PlaylistAddManySerializer(serializers.Serializer):
tracks = serializers.PrimaryKeyRelatedField(
many=True, queryset=Track.objects.for_nested_serialization())
from django.db.models import Count from django.db.models import Count
from django.db import transaction
from rest_framework import exceptions
from rest_framework import generics, mixins, viewsets from rest_framework import generics, mixins, viewsets
from rest_framework import status from rest_framework import status
from rest_framework.decorators import detail_route from rest_framework.decorators import detail_route
...@@ -25,7 +27,7 @@ class PlaylistViewSet( ...@@ -25,7 +27,7 @@ class PlaylistViewSet(
serializer_class = serializers.PlaylistSerializer serializer_class = serializers.PlaylistSerializer
queryset = ( queryset = (
models.Playlist.objects.all() models.Playlist.objects.all().select_related('user')
.annotate(tracks_count=Count('playlist_tracks')) .annotate(tracks_count=Count('playlist_tracks'))
) )
permission_classes = [ permission_classes = [
...@@ -36,6 +38,11 @@ class PlaylistViewSet( ...@@ -36,6 +38,11 @@ class PlaylistViewSet(
owner_checks = ['write'] owner_checks = ['write']
filter_class = filters.PlaylistFilter filter_class = filters.PlaylistFilter
def get_serializer_class(self):
if self.request.method in ['PUT', 'PATCH', 'DELETE', 'POST']:
return serializers.PlaylistWriteSerializer
return self.serializer_class
@detail_route(methods=['get']) @detail_route(methods=['get'])
def tracks(self, request, *args, **kwargs): def tracks(self, request, *args, **kwargs):
playlist = self.get_object() playlist = self.get_object()
...@@ -47,6 +54,24 @@ class PlaylistViewSet( ...@@ -47,6 +54,24 @@ class PlaylistViewSet(
} }
return Response(data, status=200) return Response(data, status=200)
@detail_route(methods=['post'])
@transaction.atomic
def add(self, request, *args, **kwargs):
playlist = self.get_object()
serializer = serializers.PlaylistAddManySerializer(data=request.data)
serializer.is_valid(raise_exception=True)
try:
plts = playlist.insert_many(serializer.validated_data['tracks'])
except exceptions.ValidationError as e:
payload = {'playlist': e.detail}
return Response(payload, status=400)
serializer = serializers.PlaylistTrackSerializer(plts, many=True)
data = {
'count': len(plts),
'results': serializer.data
}
return Response(data, status=201)
def get_queryset(self): def get_queryset(self):
return self.queryset.filter( return self.queryset.filter(
fields.privacy_level_query(self.request.user)) fields.privacy_level_query(self.request.user))
......
import pytest import pytest
from django import forms from rest_framework import exceptions
def test_can_insert_plt(factories): def test_can_insert_plt(factories):
...@@ -79,14 +79,14 @@ def test_can_insert_and_move_last_to_0(factories): ...@@ -79,14 +79,14 @@ def test_can_insert_and_move_last_to_0(factories):
def test_cannot_insert_at_wrong_index(factories): def test_cannot_insert_at_wrong_index(factories):
plt = factories['playlists.PlaylistTrack']() plt = factories['playlists.PlaylistTrack']()
new = factories['playlists.PlaylistTrack'](playlist=plt.playlist) new = factories['playlists.PlaylistTrack'](playlist=plt.playlist)
with pytest.raises(forms.ValidationError): with pytest.raises(exceptions.ValidationError):
plt.playlist.insert(new, 2) plt.playlist.insert(new, 2)
def test_cannot_insert_at_negative_index(factories): def test_cannot_insert_at_negative_index(factories):
plt = factories['playlists.PlaylistTrack']() plt = factories['playlists.PlaylistTrack']()
new = factories['playlists.PlaylistTrack'](playlist=plt.playlist) new = factories['playlists.PlaylistTrack'](playlist=plt.playlist)
with pytest.raises(forms.ValidationError): with pytest.raises(exceptions.ValidationError):
plt.playlist.insert(new, -1) plt.playlist.insert(new, -1)
...@@ -103,3 +103,24 @@ def test_remove_update_indexes(factories): ...@@ -103,3 +103,24 @@ def test_remove_update_indexes(factories):
assert first.index == 0 assert first.index == 0
assert third.index == 1 assert third.index == 1
def test_can_insert_many(factories):
playlist = factories['playlists.Playlist']()
existing = factories['playlists.PlaylistTrack'](playlist=playlist, index=0)
tracks = factories['music.Track'].create_batch(size=3)
plts = playlist.insert_many(tracks)
for i, plt in enumerate(plts):
assert plt.index == i + 1
assert plt.track == tracks[i]
assert plt.playlist == playlist
def test_insert_many_honor_max_tracks(factories, settings):
settings.PLAYLISTS_MAX_TRACKS = 4
playlist = factories['playlists.Playlist']()
plts = factories['playlists.PlaylistTrack'].create_batch(
size=2, playlist=playlist)
track = factories['music.Track']()
with pytest.raises(exceptions.ValidationError):
playlist.insert_many([track, track, track])
...@@ -153,3 +153,20 @@ def test_can_list_tracks_from_playlist( ...@@ -153,3 +153,20 @@ def test_can_list_tracks_from_playlist(
assert response.data['count'] == 1 assert response.data['count'] == 1
assert response.data['results'][0] == serialized_plt assert response.data['results'][0] == serialized_plt
def test_can_add_multiple_tracks_at_once_via_api(
factories, mocker, logged_in_api_client):
playlist = factories['playlists.Playlist'](user=logged_in_api_client.user)
tracks = factories['music.Track'].create_batch(size=5)
track_ids = [t.id for t in tracks]
mocker.spy(playlist, 'insert_many')
url = reverse('api:v1:playlists-add', kwargs={'pk': playlist.pk})
response = logged_in_api_client.post(url, {'tracks': track_ids})
assert response.status_code == 201
assert playlist.playlist_tracks.count() == len(track_ids)
for plt in playlist.playlist_tracks.order_by('index'):
assert response.data['results'][plt.index]['id'] == plt.id
assert plt.track == tracks[plt.index]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment