From c8fcf1b0d9e5693e6ca774f2bb63996bdc598fd3 Mon Sep 17 00:00:00 2001
From: Eliot Berriot <contact@eliotberriot.com>
Date: Fri, 3 May 2019 12:23:45 +0200
Subject: [PATCH] Support oauth token in URL

---
 api/config/settings/common.py           |  1 +
 api/funkwhale_api/users/oauth/server.py | 25 +++++++++++++++++
 api/tests/test_auth.py                  | 36 +++++++++++++++++++++++++
 api/tests/test_jwt_querystring.py       | 19 -------------
 4 files changed, 62 insertions(+), 19 deletions(-)
 create mode 100644 api/funkwhale_api/users/oauth/server.py
 create mode 100644 api/tests/test_auth.py
 delete mode 100644 api/tests/test_jwt_querystring.py

diff --git a/api/config/settings/common.py b/api/config/settings/common.py
index 3000fedd..bf357e17 100644
--- a/api/config/settings/common.py
+++ b/api/config/settings/common.py
@@ -374,6 +374,7 @@ OAUTH2_PROVIDER = {
     "REFRESH_TOKEN_EXPIRE_SECONDS": 3600 * 24 * 15,
     "AUTHORIZATION_CODE_EXPIRE_SECONDS": 5 * 60,
     "ACCESS_TOKEN_EXPIRE_SECONDS": 60 * 60 * 10,
+    "OAUTH2_SERVER_CLASS": "funkwhale_api.users.oauth.server.OAuth2Server",
 }
 OAUTH2_PROVIDER_APPLICATION_MODEL = "users.Application"
 OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL = "users.AccessToken"
diff --git a/api/funkwhale_api/users/oauth/server.py b/api/funkwhale_api/users/oauth/server.py
new file mode 100644
index 00000000..f62ebf48
--- /dev/null
+++ b/api/funkwhale_api/users/oauth/server.py
@@ -0,0 +1,25 @@
+import urllib.parse
+import oauthlib.oauth2
+
+
+class OAuth2Server(oauthlib.oauth2.Server):
+    def verify_request(self, uri, *args, **kwargs):
+        valid, request = super().verify_request(uri, *args, **kwargs)
+        if valid:
+            return valid, request
+
+        # maybe the token was given in the querystring?
+        query = urllib.parse.urlparse(request.uri).query
+        token = None
+        if query:
+            parsed_qs = urllib.parse.parse_qs(query)
+            token = parsed_qs.get("token", [])
+            if len(token) > 0:
+                token = token[0]
+
+        if token:
+            valid = self.request_validator.validate_bearer_token(
+                token, request.scopes, request
+            )
+
+        return valid, request
diff --git a/api/tests/test_auth.py b/api/tests/test_auth.py
new file mode 100644
index 00000000..653110f0
--- /dev/null
+++ b/api/tests/test_auth.py
@@ -0,0 +1,36 @@
+from django.urls import reverse
+from rest_framework_jwt.settings import api_settings
+
+jwt_payload_handler = api_settings.JWT_PAYLOAD_HANDLER
+jwt_encode_handler = api_settings.JWT_ENCODE_HANDLER
+
+
+def test_can_authenticate_using_jwt_token_param_in_url(factories, preferences, client):
+    user = factories["users.User"]()
+    preferences["common__api_authentication_required"] = True
+    url = reverse("api:v1:tracks-list")
+    response = client.get(url)
+
+    assert response.status_code == 401
+
+    payload = jwt_payload_handler(user)
+    token = jwt_encode_handler(payload)
+    response = client.get(url, data={"jwt": token})
+    assert response.status_code == 200
+
+
+def test_can_authenticate_using_oauth_token_param_in_url(
+    factories, preferences, client, mocker
+):
+    mocker.patch(
+        "funkwhale_api.users.oauth.permissions.should_allow", return_value=True
+    )
+    token = factories["users.AccessToken"]()
+    preferences["common__api_authentication_required"] = True
+    url = reverse("api:v1:tracks-list")
+    response = client.get(url)
+
+    assert response.status_code == 401
+
+    response = client.get(url, data={"token": token.token})
+    assert response.status_code == 200
diff --git a/api/tests/test_jwt_querystring.py b/api/tests/test_jwt_querystring.py
deleted file mode 100644
index 18a673fb..00000000
--- a/api/tests/test_jwt_querystring.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from django.urls import reverse
-from rest_framework_jwt.settings import api_settings
-
-jwt_payload_handler = api_settings.JWT_PAYLOAD_HANDLER
-jwt_encode_handler = api_settings.JWT_ENCODE_HANDLER
-
-
-def test_can_authenticate_using_token_param_in_url(factories, preferences, client):
-    user = factories["users.User"]()
-    preferences["common__api_authentication_required"] = True
-    url = reverse("api:v1:tracks-list")
-    response = client.get(url)
-
-    assert response.status_code == 401
-
-    payload = jwt_payload_handler(user)
-    token = jwt_encode_handler(payload)
-    response = client.get(url, data={"jwt": token})
-    assert response.status_code == 200
-- 
GitLab