Skip to content
Snippets Groups Projects
radios.py 5.02 KiB
Newer Older
  • Learn to ignore specific revisions
    import random

    from django.core.exceptions import ValidationError
    from rest_framework import serializers
    from taggit.models import Tag

    from funkwhale_api.music.models import Track, Artist
    from funkwhale_api.users.models import User

    from . import filters
    from . import models
    from .registries import registry
    
    
    class SimpleRadio(object):
        """Stateless base radio providing random selection helpers.

        The helpers here operate on plain iterables; subclasses decide where
        the candidate choices come from.
        """

        def clean(self, instance):
            """Validate *instance* before a session starts; no-op by default."""
            return

        def pick(self, choices, previous_choices=None):
            """Return one random element of *choices* not in *previous_choices*.

            :param choices: iterable of hashable candidates; duplicates collapse.
            :param previous_choices: iterable of candidates to exclude
                (``None`` means exclude nothing).
            :raises ValueError: when no candidate remains after exclusion.
            """
            # random.sample() no longer accepts sets (removed in Python 3.11),
            # so the de-duplicated candidates are materialized as a list.
            remaining = list(set(choices).difference(previous_choices or ()))
            return random.sample(remaining, 1)[0]

        def pick_many(self, choices, quantity):
            """Return *quantity* distinct random elements drawn from *choices*.

            Duplicates in *choices* are collapsed first.

            :raises ValueError: when fewer than *quantity* distinct elements exist.
            """
            return random.sample(list(set(choices)), quantity)

        def weighted_pick(self, choices, previous_choices=None):
            """Return a choice drawn with probability proportional to its weight.

            :param choices: iterable of ``(choice, weight)`` pairs; weights are
                assumed non-negative — TODO confirm with callers.
            :param previous_choices: accepted for interface symmetry with
                :meth:`pick`; currently ignored.
            :returns: the selected choice, or ``None`` when *choices* is empty.
            """
            # Materialize once: weights are summed and then walked again, so a
            # one-shot iterator would otherwise be exhausted by the first pass.
            pairs = list(choices)
            total = sum(weight for _, weight in pairs)
            r = random.uniform(0, total)
            upto = 0
            for choice, weight in pairs:
                if upto + weight >= r:
                    return choice
                upto += weight
            return None
    
    
    class SessionRadio(SimpleRadio):
        """Radio whose picks come from a queryset and are recorded in a session.

        Subclasses implement :meth:`get_queryset` and optionally
        :meth:`get_queryset_kwargs`. When a session is attached, tracks
        already recorded in it are excluded and new picks are added to it.
        """

        def __init__(self, session=None):
            self.session = session

        def start_session(self, user, **kwargs):
            """Create, attach and return a new ``RadioSession`` for *user*.

            ``self.radio_type`` is not defined in this class; presumably the
            radio registry assigns it on registration — TODO confirm.
            """
            self.session = models.RadioSession.objects.create(
                user=user, radio_type=self.radio_type, **kwargs)
            return self.session

        def get_queryset(self, **kwargs):
            """Return the candidate queryset; subclasses must override.

            Accepts ``**kwargs`` like every override, so that
            :meth:`get_choices` reaches NotImplementedError rather than a
            TypeError about unexpected keyword arguments.
            """
            raise NotImplementedError

        def get_queryset_kwargs(self):
            """Return extra keyword arguments merged into get_queryset()'s call."""
            return {}

        def get_choices(self, **kwargs):
            """Return the candidate queryset, minus tracks already in the session."""
            # Radio-provided kwargs override caller-supplied ones on key clashes.
            kwargs.update(self.get_queryset_kwargs())
            queryset = self.get_queryset(**kwargs)
            if self.session:
                queryset = self.filter_from_session(queryset)
            return queryset

        def filter_from_session(self, queryset):
            """Exclude from *queryset* every track already recorded in the session."""
            already_played = self.session.session_tracks.all().values_list(
                'track', flat=True)
            return queryset.exclude(pk__in=already_played)

        def pick(self, **kwargs):
            """Pick and return a single choice (see :meth:`pick_many`)."""
            return self.pick_many(quantity=1, **kwargs)[0]

        def pick_many(self, quantity, **kwargs):
            """Pick *quantity* distinct choices and record them in the session.

            :raises ValueError: when fewer than *quantity* candidates remain.
            """
            choices = self.get_choices(**kwargs)
            picked_choices = super().pick_many(choices=choices, quantity=quantity)
            if self.session:
                for choice in picked_choices:
                    self.session.add(choice)
            return picked_choices

        def validate_session(self, data, **context):
            """Validate session-creation *data*; returns it unchanged by default."""
            return data
    
    
    
    @registry.register(name='random')
    class RandomRadio(SessionRadio):
        """Radio drawing from the whole track library."""

        def get_queryset(self, **kwargs):
            """Return every track; extra keyword arguments are ignored."""
            everything = Track.objects.all()
            return everything
    
    @registry.register(name='favorites')
    class FavoritesRadio(SessionRadio):
        """Radio drawing from the tracks in a user's ``track_favorites``."""

        def get_queryset_kwargs(self):
            """Extend the base kwargs with the session's user when a session exists."""
            extra = super().get_queryset_kwargs()
            if not self.session:
                return extra
            extra['user'] = self.session.user
            return extra

        def get_queryset(self, **kwargs):
            """Return the tracks referenced by ``kwargs['user']``'s favorites.

            Raises KeyError when no ``user`` keyword is supplied.
            """
            owner = kwargs['user']
            favorite_ids = owner.track_favorites.all().values_list('track', flat=True)
            return Track.objects.filter(pk__in=favorite_ids)
    
    
    
    @registry.register(name='custom')
    class CustomRadio(SessionRadio):
        """Radio whose tracks come from a custom radio's filter configuration."""

        def get_queryset_kwargs(self):
            """Add the session's user and custom radio to the queryset kwargs.

            Without a session nothing is added, so callers must then pass
            ``custom_radio`` explicitly (same guard as the favorites radio).
            """
            kwargs = super().get_queryset_kwargs()
            if self.session:
                kwargs['user'] = self.session.user
                kwargs['custom_radio'] = self.session.custom_radio
            return kwargs

        def get_queryset(self, **kwargs):
            """Return the result of running the custom radio's filter config."""
            return filters.run(kwargs['custom_radio'].config)

        def validate_session(self, data, **context):
            """Ensure a custom radio is given and the user may access it.

            The user is read from ``data['user']`` or, failing that,
            ``context['user']``. Access is granted when the radio's ``user``
            equals that user or the radio ``is_public``.

            :raises serializers.ValidationError: when no custom radio is
                provided or the user may not access it.
            """
            data = super().validate_session(data, **context)
            try:
                user = data['user']
            except KeyError:
                user = context['user']
            custom_radio = data.get('custom_radio')
            if custom_radio is None:
                raise serializers.ValidationError(
                    'You must provide a custom radio')
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would silently disable this permission check.
            if not (custom_radio.user == user or custom_radio.is_public):
                raise serializers.ValidationError(
                    "You don't have access to this radio")
            return data
    
    
    
    class RelatedObjectRadio(SessionRadio):
        """Abstract radio related to an object (tag, artist, user...)

        Concrete subclasses set ``model`` to the class that the session's
        related object must be an instance of.
        """

        def clean(self, instance):
            """Reject *instance* unless its related object is a truthy ``self.model``.

            :raises ValidationError: on a missing or wrongly-typed related object.
            """
            super().clean(instance)
            target = instance.related_object
            if not target:
                raise ValidationError('Cannot start RelatedObjectRadio without related object')
            if isinstance(target, self.model):
                return
            raise ValidationError('Trying to start radio with bad related object')

        def get_related_object(self, pk):
            """Fetch the ``self.model`` instance whose primary key is *pk*."""
            manager = self.model.objects
            return manager.get(pk=pk)
    
    
    @registry.register(name='tag')
    class TagRadio(RelatedObjectRadio):
        """Radio drawing from tracks tagged with the session's related Tag."""
        model = Tag

        def get_queryset(self, **kwargs):
            """Return tracks whose tags include the session's related tag."""
            wanted_tags = [self.session.related_object]
            return Track.objects.filter(tags__in=wanted_tags)
    
    @registry.register(name='artist')
    class ArtistRadio(RelatedObjectRadio):
        """Radio drawing from the tracks of the session's related Artist."""
        model = Artist

        def get_queryset(self, **kwargs):
            """Return all tracks of the session's related artist."""
            artist = self.session.related_object
            return artist.tracks.all()
    
    
    @registry.register(name='less-listened')
    class LessListenedRadio(RelatedObjectRadio):
        """Radio drawing from tracks absent from the session user's listenings."""
        model = User

        def clean(self, instance):
            """Use the session's own user as the related object, then validate it."""
            instance.related_object = instance.user
            super().clean(instance)

        def get_queryset(self, **kwargs):
            """Return tracks not in the session user's listenings, in random DB order."""
            heard_ids = self.session.user.listenings.all().values_list('track', flat=True)
            unheard = Track.objects.exclude(pk__in=heard_ids)
            return unheard.order_by('?')