From bf8b143700006179f3262bed7a6722c5f6fa62f7 Mon Sep 17 00:00:00 2001
From: Eliot Berriot <contact@eliotberriot.com>
Date: Thu, 21 Jun 2018 19:21:51 +0200
Subject: [PATCH] See #248: better structure for action serializers

---
 api/funkwhale_api/common/serializers.py     | 43 +++++++++++++--------
 api/funkwhale_api/federation/serializers.py |  2 +-
 api/funkwhale_api/manage/serializers.py     |  3 +-
 api/tests/common/test_serializers.py        | 37 ++++++++++++------
 4 files changed, 54 insertions(+), 31 deletions(-)

diff --git a/api/funkwhale_api/common/serializers.py b/api/funkwhale_api/common/serializers.py
index 029338ef..b3e2d310 100644
--- a/api/funkwhale_api/common/serializers.py
+++ b/api/funkwhale_api/common/serializers.py
@@ -1,6 +1,17 @@
+import collections
 from rest_framework import serializers
 
 
+class Action(object):
+    def __init__(self, name, allow_all=False, filters=None):
+        self.name = name
+        self.allow_all = allow_all
+        self.filters = filters or {}
+
+    def __repr__(self):
+        return "<Action {}>".format(self.name)
+
+
 class ActionSerializer(serializers.Serializer):
     """
     A special serializer that can operate on a list of objects
@@ -11,19 +22,16 @@ class ActionSerializer(serializers.Serializer):
     objects = serializers.JSONField(required=True)
     filters = serializers.DictField(required=False)
     actions = None
-    filterset_class = None
-    # those are actions identifier where we don't want to allow the "all"
-    # selector because it's to dangerous. Like object deletion.
-    dangerous_actions = []
 
     def __init__(self, *args, **kwargs):
+        self.actions_by_name = {a.name: a for a in self.actions}
         self.queryset = kwargs.pop("queryset")
         if self.actions is None:
             raise ValueError(
                 "You must declare a list of actions on " "the serializer class"
             )
 
-        for action in self.actions:
+        for action in self.actions_by_name.keys():
             handler_name = "handle_{}".format(action)
             assert hasattr(self, handler_name), "{} miss a {} method".format(
                 self.__class__.__name__, handler_name
@@ -31,13 +39,14 @@ class ActionSerializer(serializers.Serializer):
         super().__init__(self, *args, **kwargs)
 
     def validate_action(self, value):
-        if value not in self.actions:
+        try:
+            return self.actions_by_name[value]
+        except KeyError:
             raise serializers.ValidationError(
                 "{} is not a valid action. Pick one of {}.".format(
-                    value, ", ".join(self.actions)
+                    value, ", ".join(self.actions_by_name.keys())
                 )
             )
-        return value
 
     def validate_objects(self, value):
         if value == "all":
@@ -51,15 +60,15 @@ class ActionSerializer(serializers.Serializer):
         )
 
     def validate(self, data):
-        dangerous = data["action"] in self.dangerous_actions
-        if dangerous and self.initial_data["objects"] == "all":
+        allow_all = data["action"].allow_all
+        if not allow_all and self.initial_data["objects"] == "all":
             raise serializers.ValidationError(
-                "This action is to dangerous to be applied to all objects"
-            )
-        if self.filterset_class and "filters" in data:
-            qs_filterset = self.filterset_class(
-                data["filters"], queryset=data["objects"]
+                "You cannot apply this action on all objects"
             )
+        final_filters = data.get("filters", {}) or {}
+        final_filters.update(data["action"].filters)
+        if self.filterset_class and final_filters:
+            qs_filterset = self.filterset_class(final_filters, queryset=data["objects"])
             try:
                 assert qs_filterset.form.is_valid()
             except (AssertionError, TypeError):
@@ -72,12 +81,12 @@ class ActionSerializer(serializers.Serializer):
         return data
 
     def save(self):
-        handler_name = "handle_{}".format(self.validated_data["action"])
+        handler_name = "handle_{}".format(self.validated_data["action"].name)
         handler = getattr(self, handler_name)
         result = handler(self.validated_data["objects"])
         payload = {
             "updated": self.validated_data["count"],
-            "action": self.validated_data["action"],
+            "action": self.validated_data["action"].name,
             "result": result,
         }
         return payload
diff --git a/api/funkwhale_api/federation/serializers.py b/api/funkwhale_api/federation/serializers.py
index 062f74f4..451c199c 100644
--- a/api/funkwhale_api/federation/serializers.py
+++ b/api/funkwhale_api/federation/serializers.py
@@ -769,7 +769,7 @@ class CollectionSerializer(serializers.Serializer):
 
 
 class LibraryTrackActionSerializer(common_serializers.ActionSerializer):
-    actions = ["import"]
+    actions = [common_serializers.Action('import', allow_all=True)]
     filterset_class = filters.LibraryTrackFilter
 
     @transaction.atomic
diff --git a/api/funkwhale_api/manage/serializers.py b/api/funkwhale_api/manage/serializers.py
index e8f1e328..f5d52bca 100644
--- a/api/funkwhale_api/manage/serializers.py
+++ b/api/funkwhale_api/manage/serializers.py
@@ -61,8 +61,7 @@ class ManageTrackFileSerializer(serializers.ModelSerializer):
 
 
 class ManageTrackFileActionSerializer(common_serializers.ActionSerializer):
-    actions = ["delete"]
-    dangerous_actions = ["delete"]
+    actions = [common_serializers.Action("delete", allow_all=False)]
     filterset_class = filters.ManageTrackFileFilterSet
 
     @transaction.atomic
diff --git a/api/tests/common/test_serializers.py b/api/tests/common/test_serializers.py
index ca5e5ad8..dbbd38a0 100644
--- a/api/tests/common/test_serializers.py
+++ b/api/tests/common/test_serializers.py
@@ -11,7 +11,7 @@ class TestActionFilterSet(django_filters.FilterSet):
 
 
 class TestSerializer(serializers.ActionSerializer):
-    actions = ["test"]
+    actions = [serializers.Action("test", allow_all=True)]
     filterset_class = TestActionFilterSet
 
     def handle_test(self, objects):
@@ -19,8 +19,10 @@ class TestSerializer(serializers.ActionSerializer):
 
 
 class TestDangerousSerializer(serializers.ActionSerializer):
-    actions = ["test", "test_dangerous"]
-    dangerous_actions = ["test_dangerous"]
+    actions = [
+        serializers.Action("test", allow_all=True),
+        serializers.Action("test_dangerous"),
+    ]
 
     def handle_test(self, objects):
         pass
@@ -29,6 +31,14 @@ class TestDangerousSerializer(serializers.ActionSerializer):
         pass
 
 
+class TestDeleteOnlyInactiveSerializer(serializers.ActionSerializer):
+    actions = [serializers.Action("test", allow_all=True, filters={"is_active": False})]
+    filterset_class = TestActionFilterSet
+
+    def handle_test(self, objects):
+        pass
+
+
 def test_action_serializer_validates_action():
     data = {"objects": "all", "action": "nope"}
     serializer = TestSerializer(data, queryset=models.User.objects.none())
@@ -52,7 +62,7 @@ def test_action_serializers_objects_clean_ids(factories):
     data = {"objects": [user1.pk], "action": "test"}
     serializer = TestSerializer(data, queryset=models.User.objects.all())
 
-    assert serializer.is_valid() is True
+    assert serializer.is_valid(raise_exception=True) is True
     assert list(serializer.validated_data["objects"]) == [user1]
 
 
@@ -63,7 +73,7 @@ def test_action_serializers_objects_clean_all(factories):
     data = {"objects": "all", "action": "test"}
     serializer = TestSerializer(data, queryset=models.User.objects.all())
 
-    assert serializer.is_valid() is True
+    assert serializer.is_valid(raise_exception=True) is True
     assert list(serializer.validated_data["objects"]) == [user1, user2]
 
 
@@ -75,7 +85,7 @@ def test_action_serializers_save(factories, mocker):
     data = {"objects": "all", "action": "test"}
     serializer = TestSerializer(data, queryset=models.User.objects.all())
 
-    assert serializer.is_valid() is True
+    assert serializer.is_valid(raise_exception=True) is True
     result = serializer.save()
     assert result == {"updated": 2, "action": "test", "result": {"hello": "world"}}
     handler.assert_called_once()
@@ -88,7 +98,7 @@ def test_action_serializers_filterset(factories):
     data = {"objects": "all", "action": "test", "filters": {"is_active": True}}
     serializer = TestSerializer(data, queryset=models.User.objects.all())
 
-    assert serializer.is_valid() is True
+    assert serializer.is_valid(raise_exception=True) is True
     assert list(serializer.validated_data["objects"]) == [user2]
 
 
@@ -109,9 +119,14 @@ def test_dangerous_actions_refuses_all(factories):
     assert "non_field_errors" in serializer.errors
 
 
-def test_dangerous_actions_refuses_not_listed(factories):
-    factories["users.User"]()
+def test_action_serializers_can_require_filter(factories):
+    user1 = factories["users.User"](is_active=False)
+    factories["users.User"](is_active=True)
+
     data = {"objects": "all", "action": "test"}
-    serializer = TestDangerousSerializer(data, queryset=models.User.objects.all())
+    serializer = TestDeleteOnlyInactiveSerializer(
+        data, queryset=models.User.objects.all()
+    )
 
-    assert serializer.is_valid() is True
+    assert serializer.is_valid(raise_exception=True) is True
+    assert list(serializer.validated_data["objects"]) == [user1]
-- 
GitLab