Skip to content
Snippets Groups Projects
views.py 4.94 KiB
Newer Older
from django.db.models import Q
from django.http import Http404

from rest_framework import generics, mixins, viewsets
from rest_framework import status
from rest_framework.response import Response
from rest_framework.decorators import detail_route, list_route

from funkwhale_api.music.serializers import TrackSerializerNested
from funkwhale_api.common.permissions import ConditionalAuthentication

from . import models
from . import filters
from . import filtersets

class RadioViewSet(
        mixins.CreateModelMixin,
        mixins.RetrieveModelMixin,
        mixins.UpdateModelMixin,
        mixins.ListModelMixin,
Bat's avatar
Bat committed
        mixins.DestroyModelMixin,
        viewsets.GenericViewSet):

    serializer_class = serializers.RadioSerializer
    permission_classes = [ConditionalAuthentication]
    filter_class = filtersets.RadioFilter

    def get_queryset(self):
        query = Q(is_public=True)
        if self.request.user.is_authenticated:
            query |= Q(user=self.request.user)
        return models.Radio.objects.filter(query)

    def perform_create(self, serializer):
        return serializer.save(user=self.request.user)

    def perform_update(self, serializer):
        if serializer.instance.user != self.request.user:
            raise Http404
        return serializer.save(user=self.request.user)

Bat's avatar
Bat committed
    @detail_route(methods=['get'])
    def tracks(self, request, *args, **kwargs):
        radio = self.get_object()
        tracks = radio.get_candidates().for_nested_serialization()
        serializer = TrackSerializerNested(tracks, many=True)
        data = {
            'count': len(tracks),
            'results': serializer.data
        }
        return Response(data, status=200)

    @list_route(methods=['get'])
    def filters(self, request, *args, **kwargs):
        serializer = serializers.FilterSerializer(
            filters.registry.exposed_filters, many=True)
        return Response(serializer.data)

    @list_route(methods=['post'])
    def validate(self, request, *args, **kwargs):
        try:
            f_list = request.data['filters']
        except KeyError:
            return Response(
                {'error': 'You must provide a filters list'}, status=400)
        data = {
            'filters': []
        }
        for f in f_list:
            results = filters.test(f)
            if results['candidates']['sample']:
                qs = results['candidates']['sample'].for_nested_serialization()
                results['candidates']['sample'] = TrackSerializerNested(
                    qs, many=True).data
            data['filters'].append(results)

        return Response(data)


class RadioSessionViewSet(mixins.CreateModelMixin,
                          mixins.RetrieveModelMixin,
                          viewsets.GenericViewSet):

    serializer_class = serializers.RadioSessionSerializer
    queryset = models.RadioSession.objects.all()
    permission_classes = [ConditionalAuthentication]

    def get_queryset(self):
        queryset = super().get_queryset()
        if self.request.user.is_authenticated:
            return queryset.filter(user=self.request.user)
        else:
            return queryset.filter(session_key=self.request.session.session_key)

    def get_serializer_context(self):
        context = super().get_serializer_context()
        if self.request.user.is_authenticated:
            context['user'] = self.request.user
        else:
            context['session_key'] = self.request.session.session_key
        return context


class RadioSessionTrackViewSet(mixins.CreateModelMixin,
                               viewsets.GenericViewSet):
    serializer_class = serializers.RadioSessionTrackSerializer
    queryset = models.RadioSessionTrack.objects.all()
    permission_classes = [ConditionalAuthentication]

    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        session = serializer.validated_data['session']
        try:
            if request.user.is_authenticated:
                assert request.user == session.user
            else:
                assert request.session.session_key == session.session_key
        except AssertionError:
            return Response(status=status.HTTP_403_FORBIDDEN)
        track = session.radio.pick()
        session_track = session.session_tracks.all().latest('id')
        # self.perform_create(serializer)
        # dirty override here, since we use a different serializer for creation and detail
        serializer = self.serializer_class(instance=session_track, context=self.get_serializer_context())
        headers = self.get_success_headers(serializer.data)
        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

    def get_serializer_class(self, *args, **kwargs):
        if self.action == 'create':
            return serializers.RadioSessionTrackSerializerCreate
        return super().get_serializer_class(*args, **kwargs)