From 550dbe46cc7d5fe2b239733bcd00e2a214b5ab11 Mon Sep 17 00:00:00 2001
From: Agate <me@agate.blue>
Date: Mon, 18 May 2020 12:03:30 +0200
Subject: [PATCH] Support session/cookie based auth, see #1108

---
 api/config/api_urls.py                    |  4 ++-
 api/config/routing.py                     | 11 +++++--
 api/config/settings/common.py             |  8 ++++-
 api/funkwhale_api/common/auth.py          |  3 ++
 api/funkwhale_api/common/middleware.py    |  8 ++++-
 api/funkwhale_api/users/api_urls.py       |  7 +++--
 api/funkwhale_api/users/serializers.py    | 22 +++++++++++++
 api/funkwhale_api/users/views.py          | 33 +++++++++++++++++++-
 api/tests/common/test_middleware.py       | 35 +++++++++++++++------
 api/tests/users/test_views.py             | 38 +++++++++++++++++++++++
 front/Dockerfile                          |  3 +-
 front/src/main.js                         |  5 ++-
 front/src/store/auth.js                   | 31 ++++++++----------
 front/tests/unit/specs/store/auth.spec.js | 26 ++--------------
 14 files changed, 172 insertions(+), 62 deletions(-)

diff --git a/api/config/api_urls.py b/api/config/api_urls.py
index b50066f3db..74c5e248da 100644
--- a/api/config/api_urls.py
+++ b/api/config/api_urls.py
@@ -77,9 +77,11 @@ v1_patterns += [
         r"^history/",
         include(("funkwhale_api.history.urls", "history"), namespace="history"),
     ),
+    url(r"^", include(("funkwhale_api.users.api_urls", "users"), namespace="users"),),
+    # XXX: 1.0: remove this
     url(
         r"^users/",
-        include(("funkwhale_api.users.api_urls", "users"), namespace="users"),
+        include(("funkwhale_api.users.api_urls", "users"), namespace="users-nested"),
     ),
     url(
         r"^oauth/",
diff --git a/api/config/routing.py b/api/config/routing.py
index 13a67cd1e4..d0858e2432 100644
--- a/api/config/routing.py
+++ b/api/config/routing.py
@@ -1,14 +1,19 @@
+from channels.auth import AuthMiddlewareStack
 from channels.routing import ProtocolTypeRouter, URLRouter
-from django.conf.urls import url
 
+from django.conf.urls import url
 from funkwhale_api.common.auth import TokenAuthMiddleware
 from funkwhale_api.instance import consumers
 
 application = ProtocolTypeRouter(
     {
         # Empty for now (http->django views is added by default)
-        "websocket": TokenAuthMiddleware(
-            URLRouter([url("^api/v1/activity$", consumers.InstanceActivityConsumer)])
+        "websocket": AuthMiddlewareStack(
+            TokenAuthMiddleware(
+                URLRouter(
+                    [url("^api/v1/activity$", consumers.InstanceActivityConsumer)]
+                )
+            )
         )
     }
 )
diff --git a/api/config/settings/common.py b/api/config/settings/common.py
index 87848881d0..ad24d43db2 100644
--- a/api/config/settings/common.py
+++ b/api/config/settings/common.py
@@ -276,10 +276,12 @@ MIDDLEWARE = tuple(ADDITIONAL_MIDDLEWARES_BEFORE) + (
     "django.middleware.security.SecurityMiddleware",
     "django.middleware.clickjacking.XFrameOptionsMiddleware",
     "corsheaders.middleware.CorsMiddleware",
-    "funkwhale_api.common.middleware.SPAFallbackMiddleware",
+    # needs to be before SPA middleware
     "django.contrib.sessions.middleware.SessionMiddleware",
     "django.middleware.common.CommonMiddleware",
     "django.middleware.csrf.CsrfViewMiddleware",
+    # /end
+    "funkwhale_api.common.middleware.SPAFallbackMiddleware",
     "django.contrib.auth.middleware.AuthenticationMiddleware",
     "django.contrib.messages.middleware.MessageMiddleware",
     "funkwhale_api.users.middleware.RecordActivityMiddleware",
@@ -998,6 +1000,10 @@ THROTTLING_RATES = {
         "rate": THROTTLING_USER_RATES.get("oauth-revoke-token", "100/hour"),
         "description": "OAuth token deletion",
     },
+    "login": {
+        "rate": THROTTLING_USER_RATES.get("login", "30/hour"),
+        "description": "Login",
+    },
     "jwt-login": {
         "rate": THROTTLING_USER_RATES.get("jwt-login", "30/hour"),
         "description": "JWT token creation",
diff --git a/api/funkwhale_api/common/auth.py b/api/funkwhale_api/common/auth.py
index 736364337f..b404bbca27 100644
--- a/api/funkwhale_api/common/auth.py
+++ b/api/funkwhale_api/common/auth.py
@@ -29,6 +29,9 @@ class TokenAuthMiddleware:
         self.inner = inner
 
     def __call__(self, scope):
+        if "user" in scope:
+            # auth already handled
+            return self.inner(scope)
         # XXX: 1.0 remove this, replace with websocket/scopedtoken
         auth = TokenHeaderAuth()
         try:
diff --git a/api/funkwhale_api/common/middleware.py b/api/funkwhale_api/common/middleware.py
index de06fd1d44..64bb6f80bf 100644
--- a/api/funkwhale_api/common/middleware.py
+++ b/api/funkwhale_api/common/middleware.py
@@ -10,6 +10,7 @@ import xml.sax.saxutils
 from django import http
 from django.conf import settings
 from django.core.cache import caches
+from django.middleware import csrf
 from django import urls
 from rest_framework import views
 
@@ -81,7 +82,12 @@ def serve_spa(request):
         body, tail = tail.split("</body>", 1)
         css = "<style>{}</style>".format(css)
         tail = body + "\n" + css + "\n</body>" + tail
-    return http.HttpResponse(head + tail)
+
+    # set a csrf token so that visitor can login / query API if needed
+    token = csrf.get_token(request)
+    response = http.HttpResponse(head + tail)
+    response.set_cookie("csrftoken", token, max_age=None)
+    return response
 
 
 MANIFEST_LINK_REGEX = re.compile(r"<link [^>]*rel=(?:'|\")?manifest(?:'|\")?[^>]*>")
diff --git a/api/funkwhale_api/users/api_urls.py b/api/funkwhale_api/users/api_urls.py
index 89930f57be..1c39797f2e 100644
--- a/api/funkwhale_api/users/api_urls.py
+++ b/api/funkwhale_api/users/api_urls.py
@@ -1,8 +1,11 @@
+from django.conf.urls import url
 from funkwhale_api.common import routers
-
 from . import views
 
 router = routers.OptionalSlashRouter()
 router.register(r"users", views.UserViewSet, "users")
 
-urlpatterns = router.urls
+urlpatterns = [
+    url(r"^users/login/?$", views.login, name="login"),
+    url(r"^users/logout/?$", views.logout, name="logout"),
+] + router.urls
diff --git a/api/funkwhale_api/users/serializers.py b/api/funkwhale_api/users/serializers.py
index 542f6e58a4..8646d3b4ab 100644
--- a/api/funkwhale_api/users/serializers.py
+++ b/api/funkwhale_api/users/serializers.py
@@ -4,6 +4,8 @@ from django.core import validators
 from django.utils.deconstruct import deconstructible
 from django.utils.translation import gettext_lazy as _
 
+from django.contrib import auth
+
 from rest_auth.serializers import PasswordResetSerializer as PRS
 from rest_auth.registration.serializers import RegisterSerializer as RS, get_adapter
 from rest_framework import serializers
@@ -265,3 +267,23 @@ class UserDeleteSerializer(serializers.Serializer):
         if not value:
             raise serializers.ValidationError("Please confirm deletion")
         return value
+
+
+class LoginSerializer(serializers.Serializer):
+    username = serializers.CharField()
+    password = serializers.CharField()
+
+    def validate(self, data):
+        user = auth.authenticate(request=self.context.get("request"), **data)
+        if not user:
+            raise serializers.ValidationError(
+                "Unable to log in with provided credentials"
+            )
+
+        if not user.is_active:
+            raise serializers.ValidationError("This account was disabled")
+
+        return user
+
+    def save(self, request):
+        return auth.login(request, self.validated_data)
diff --git a/api/funkwhale_api/users/views.py b/api/funkwhale_api/users/views.py
index 848bc7e6bc..a143c4fd24 100644
--- a/api/funkwhale_api/users/views.py
+++ b/api/funkwhale_api/users/views.py
@@ -1,12 +1,20 @@
+import json
+
+from django import http
+from django.contrib import auth
+from django.middleware import csrf
+
 from allauth.account.adapter import get_adapter
 from rest_auth import views as rest_auth_views
 from rest_auth.registration import views as registration_views
-from rest_framework import mixins, viewsets
+from rest_framework import mixins
+from rest_framework import viewsets
 from rest_framework.decorators import action
 from rest_framework.response import Response
 
 from funkwhale_api.common import authentication
 from funkwhale_api.common import preferences
+from funkwhale_api.common import throttling
 
 from . import models, serializers, tasks
 
@@ -105,3 +113,26 @@ class UserViewSet(mixins.UpdateModelMixin, viewsets.GenericViewSet):
         if not self.request.user.username == kwargs.get("username"):
             return Response(status=403)
         return super().partial_update(request, *args, **kwargs)
+
+
+def login(request):
+    throttling.check_request(request, "login")
+    if request.method != "POST":
+        return http.HttpResponse(status=405)
+    serializer = serializers.LoginSerializer(
+        data=request.POST, context={"request": request}
+    )
+    if not serializer.is_valid():
+        return http.HttpResponse(
+            json.dumps(serializer.errors), status=400, content_type="application/json"
+        )
+    serializer.save(request)
+    csrf.rotate_token(request)
+    return http.HttpResponse(status=200)
+
+
+def logout(request):
+    if request.method != "POST":
+        return http.HttpResponse(status=405)
+    auth.logout(request)
+    return http.HttpResponse(status=200)
diff --git a/api/tests/common/test_middleware.py b/api/tests/common/test_middleware.py
index 8f04ba3184..b5d4d02f13 100644
--- a/api/tests/common/test_middleware.py
+++ b/api/tests/common/test_middleware.py
@@ -14,7 +14,7 @@ from funkwhale_api.common import utils
 def test_spa_fallback_middleware_no_404(mocker):
     get_response = mocker.Mock()
     get_response.return_value = mocker.Mock(status_code=200)
-    request = mocker.Mock(path="/")
+    request = mocker.Mock(path="/", META={})
     m = middleware.SPAFallbackMiddleware(get_response)
 
     assert m(request) == get_response.return_value
@@ -26,7 +26,7 @@ def test_spa_middleware_calls_should_fallback_false(mocker):
     should_falback = mocker.patch.object(
         middleware, "should_fallback_to_spa", return_value=False
     )
-    request = mocker.Mock(path="/")
+    request = mocker.Mock(path="/", META={})
 
     m = middleware.SPAFallbackMiddleware(get_response)
 
@@ -37,7 +37,7 @@ def test_spa_middleware_calls_should_fallback_false(mocker):
 def test_spa_middleware_should_fallback_true(mocker):
     get_response = mocker.Mock()
     get_response.return_value = mocker.Mock(status_code=404)
-    request = mocker.Mock(path="/")
+    request = mocker.Mock(path="/", META={})
     mocker.patch.object(middleware, "should_fallback_to_spa", return_value=True)
     serve_spa = mocker.patch.object(middleware, "serve_spa")
     m = middleware.SPAFallbackMiddleware(get_response)
@@ -56,7 +56,7 @@ def test_should_fallback(path, expected, mocker):
 
 def test_serve_spa_from_cache(mocker, settings, preferences, no_api_auth):
     preferences["instance__name"] = 'Best Funkwhale "pod"'
-    request = mocker.Mock(path="/")
+    request = mocker.Mock(path="/", META={})
     get_spa_html = mocker.patch.object(
         middleware,
         "get_spa_html",
@@ -155,7 +155,7 @@ def test_get_route_head_tags(mocker, settings):
 
 
 def test_serve_spa_includes_custom_css(mocker, no_api_auth):
-    request = mocker.Mock(path="/")
+    request = mocker.Mock(path="/", META={})
     mocker.patch.object(
         middleware,
         "get_spa_html",
@@ -178,6 +178,23 @@ def test_serve_spa_includes_custom_css(mocker, no_api_auth):
     assert response.content == "\n".join(expected).encode()
 
 
+def test_serve_spa_sets_csrf_token(mocker, no_api_auth):
+    request = mocker.Mock(path="/", META={})
+    get_token = mocker.patch.object(middleware.csrf, "get_token", return_value="test")
+    mocker.patch.object(
+        middleware,
+        "get_spa_html",
+        return_value="<html><head></head><body></body></html>",
+    )
+    mocker.patch.object(middleware, "get_default_head_tags", return_value=[])
+    mocker.patch.object(middleware, "get_request_head_tags", return_value=[])
+    response = middleware.serve_spa(request)
+
+    assert response.status_code == 200
+    get_token.assert_called_once_with(request)
+    assert response.cookies["csrftoken"].value == get_token.return_value
+
+
 @pytest.mark.parametrize(
     "custom_css, expected",
     [
@@ -281,7 +298,7 @@ def test_rewrite_manifest_json_url(link, new_url, expected, mocker, settings):
     spa_html = "<html><head><link rel=before>{}<link rel=after></head></html>".format(
         link
     )
-    request = mocker.Mock(path="/")
+    request = mocker.Mock(path="/", META={})
     mocker.patch.object(middleware, "get_spa_html", return_value=spa_html)
     mocker.patch.object(
         middleware, "get_default_head_tags", return_value=[],
@@ -299,7 +316,7 @@ def test_rewrite_manifest_json_url_rewrite_disabled(mocker, settings):
     settings.FUNKWHALE_SPA_REWRITE_MANIFEST = False
     settings.FUNKWHALE_SPA_REWRITE_MANIFEST_URL = "custom_url"
     spa_html = "<html><head><link href=/manifest.json rel=manifest></head></html>"
-    request = mocker.Mock(path="/")
+    request = mocker.Mock(path="/", META={})
     mocker.patch.object(middleware, "get_spa_html", return_value=spa_html)
     mocker.patch.object(
         middleware, "get_default_head_tags", return_value=[],
@@ -318,7 +335,7 @@ def test_rewrite_manifest_json_url_rewrite_default_url(mocker, settings):
     settings.FUNKWHALE_SPA_REWRITE_MANIFEST_URL = None
     spa_html = "<html><head><link href=/manifest.json rel=manifest></head></html>"
     expected_url = federation_utils.full_url(reverse("api:v1:instance:spa-manifest"))
-    request = mocker.Mock(path="/")
+    request = mocker.Mock(path="/", META={})
     mocker.patch.object(middleware, "get_spa_html", return_value=spa_html)
     mocker.patch.object(
         middleware, "get_default_head_tags", return_value=[],
@@ -342,7 +359,7 @@ def test_spa_middleware_handles_api_redirect(mocker):
     match = mocker.Mock(args=["hello"], kwargs={"foo": "bar"}, func=api_view)
     mocker.patch.object(middleware.urls, "resolve", return_value=match)
 
-    request = mocker.Mock(path="/")
+    request = mocker.Mock(path="/", META={})
 
     m = middleware.SPAFallbackMiddleware(get_response)
 
diff --git a/api/tests/users/test_views.py b/api/tests/users/test_views.py
index 1b75d98168..12e3317982 100644
--- a/api/tests/users/test_views.py
+++ b/api/tests/users/test_views.py
@@ -1,6 +1,8 @@
 import pytest
 from django.urls import reverse
 
+from django.test import Client
+
 from funkwhale_api.common import serializers as common_serializers
 from funkwhale_api.common import utils as common_utils
 from funkwhale_api.moderation import tasks as moderation_tasks
@@ -518,3 +520,39 @@ def test_user_login_jwt_honor_email_verification(
     url = reverse("api:v1:token")
     response = api_client.post(url, data)
     assert response.status_code == expected_status_code
+
+
+def test_login_via_api(api_client, factories):
+    user = factories["users.User"]()
+    url = reverse("api:v1:users:login")
+    payload = {"username": user.username, "password": "test"}
+
+    response = api_client.post(url, payload)
+    assert response.status_code == 200
+    assert api_client.session["_auth_user_id"] == str(user.pk)
+
+
+def test_login_via_api_inactive(api_client, factories):
+    user = factories["users.User"](is_active=False)
+    url = reverse("api:v1:users:login")
+    payload = {"username": user.username, "password": "test"}
+
+    response = api_client.post(url, payload)
+    assert response.status_code == 400
+
+
+def test_login_via_api_no_csrf(factories):
+    user = factories["users.User"]()
+    url = reverse("api:v1:users:login")
+    payload = {"username": user.username, "password": "test"}
+    api_client = Client(enforce_csrf_checks=True)
+    response = api_client.post(url, payload)
+    assert response.status_code == 403
+
+
+def test_logout(api_client, factories, mocker):
+    auth_logout = mocker.patch("django.contrib.auth.logout")
+    url = reverse("api:v1:users:logout")
+    response = api_client.post(url)
+    assert response.status_code == 200
+    assert auth_logout.call_count == 1
diff --git a/front/Dockerfile b/front/Dockerfile
index f5d832ce04..90a075b49c 100644
--- a/front/Dockerfile
+++ b/front/Dockerfile
@@ -5,7 +5,8 @@ RUN curl -L -o /usr/local/bin/jq https://github.com/stedolan/jq/releases/downloa
 
 EXPOSE 8080
 WORKDIR /app/
-ADD package.json yarn.lock ./
+COPY scripts/ ./scripts/
+ADD package.json yarn.lock  ./
 RUN yarn install
 
 COPY . .
diff --git a/front/src/main.js b/front/src/main.js
index 30286cd07d..9047d67440 100644
--- a/front/src/main.js
+++ b/front/src/main.js
@@ -68,7 +68,10 @@ Vue.directive('dropdown', function (el, binding) {
     ...(binding.value || {})
   })
 })
+axios.defaults.xsrfCookieName = 'csrftoken'
+axios.defaults.xsrfHeaderName = 'X-CSRFToken'
 axios.interceptors.request.use(function (config) {
+
   // Do something before request is sent
   if (store.state.auth.token) {
     config.headers['Authorization'] = store.getters['auth/header']
@@ -84,7 +87,7 @@ axios.interceptors.response.use(function (response) {
   return response
 }, function (error) {
   error.backendErrors = []
-  if (error.response.status === 401) {
+  if (store.state.auth.authenticated && error.response.status === 401) {
     store.commit('auth/authenticated', false)
     logger.default.warn('Received 401 response from API, redirecting to login form', router.currentRoute.fullPath)
     router.push({name: 'login', query: {next: router.currentRoute.fullPath}})
diff --git a/front/src/store/auth.js b/front/src/store/auth.js
index 700288d1e2..8919dc1226 100644
--- a/front/src/store/auth.js
+++ b/front/src/store/auth.js
@@ -89,9 +89,13 @@ export default {
   actions: {
     // Send a request to the login URL and save the returned JWT
     login ({commit, dispatch}, {next, credentials, onError}) {
-      return axios.post('token/', credentials).then(response => {
+      var form = new FormData();
+      Object.keys(credentials).forEach((k) => {
+        form.set(k, credentials[k])
+      })
+      return axios.post('users/login', form).then(response => {
         logger.default.info('Successfully logged in as', credentials.username)
-        commit('token', response.data.token)
+        // commit('token', response.data.token)
         dispatch('fetchProfile').then(() => {
           // Redirect to a specified route
           router.push(next)
@@ -101,7 +105,8 @@ export default {
         onError(response)
       })
     },
-    logout ({commit}) {
+    async logout ({commit}) {
+      await axios.post('users/logout')
       let modules = [
         'auth',
         'favorites',
@@ -116,16 +121,14 @@ export default {
       logger.default.info('Log out, goodbye!')
       router.push({name: 'index'})
     },
-    check ({commit, dispatch, state}) {
+    async check ({commit, dispatch, state}) {
       logger.default.info('Checking authentication…')
-      var jwt = state.token
-      if (jwt) {
-        commit('token', jwt)
-        dispatch('fetchProfile')
-        dispatch('refreshToken')
+      commit('authenticated', false)
+      let profile = await dispatch('fetchProfile')
+      if (profile) {
+        commit('authenticated', true)
       } else {
         logger.default.info('Anonymous user')
-        commit('authenticated', false)
       }
     },
     fetchProfile ({commit, dispatch, state}) {
@@ -174,13 +177,5 @@ export default {
         resolve()
       })
     },
-    refreshToken ({commit, dispatch, state}) {
-      return axios.post('token/refresh/', {token: state.token}).then(response => {
-        logger.default.info('Refreshed auth token')
-        commit('token', response.data.token)
-      }, response => {
-        logger.default.error('Error while refreshing token', response.data)
-      })
-    }
   }
 }
diff --git a/front/tests/unit/specs/store/auth.spec.js b/front/tests/unit/specs/store/auth.spec.js
index 625c55edc6..63c6d2da06 100644
--- a/front/tests/unit/specs/store/auth.spec.js
+++ b/front/tests/unit/specs/store/auth.spec.js
@@ -91,20 +91,11 @@ describe('store/auth', () => {
         action: store.actions.check,
         params: {state: {}},
         expectedMutations: [
-          { type: 'authenticated', payload: false }
-        ]
-      })
-    })
-    it('check jwt set', () => {
-      testAction({
-        action: store.actions.check,
-        params: {state: {token: 'test', username: 'user'}},
-        expectedMutations: [
-          { type: 'token', payload: 'test' }
+          { type: 'authenticated', payload: false },
+          { type: 'authenticated', payload: true },
         ],
         expectedActions: [
           { type: 'fetchProfile' },
-          { type: 'refreshToken' }
         ]
       })
     })
@@ -173,18 +164,5 @@ describe('store/auth', () => {
         ]
       })
     })
-    it('refreshToken', () => {
-      moxios.stubRequest('token/refresh/', {
-        status: 200,
-        response: {token: 'newtoken'}
-      })
-      testAction({
-        action: store.actions.refreshToken,
-        params: {state: {token: 'oldtoken'}},
-        expectedMutations: [
-          { type: 'token', payload: 'newtoken' }
-        ]
-      })
-    })
   })
 })
-- 
GitLab