Skip to content
Snippets Groups Projects
Verified Commit a1d6eeea authored by Eliot Berriot's avatar Eliot Berriot
Browse files

See #261: Support manual throttle check in non rest_framework views

parent ce919e2c
No related branches found
No related tags found
No related merge requests found
......@@ -616,11 +616,13 @@ REST_FRAMEWORK = {
"django_filters.rest_framework.DjangoFilterBackend",
),
"DEFAULT_RENDERER_CLASSES": ("rest_framework.renderers.JSONRenderer",),
"DEFAULT_THROTTLE_CLASSES": env.list(
}
THROTTLING_ENABLED = env.bool("THROTTLING_ENABLED", default=True)
if THROTTLING_ENABLED:
REST_FRAMEWORK["DEFAULT_THROTTLE_CLASSES"] = env.list(
"THROTTLE_CLASSES",
default=["funkwhale_api.common.throttling.FunkwhaleThrottle"],
),
}
)
THROTTLING_SCOPES = {
"*": {"anonymous": "anonymous-wildcard", "authenticated": "authenticated-wildcard"}
......
......@@ -10,6 +10,7 @@ from django import urls
from rest_framework import views
from . import preferences
from . import throttling
from . import utils
EXCLUDED_PATHS = ["/api", "/federation", "/.well-known"]
......@@ -209,12 +210,18 @@ class ThrottleStatusMiddleware:
self.get_response = get_response
def __call__(self, request):
response = self.get_response(request)
try:
api_request = request._api_request
response = self.get_response(request)
except throttling.TooManyRequests:
# manual throttling in non rest_framework view, we have to return
# the proper response ourselves
response = http.HttpResponse(status=429)
request_to_check = request
try:
request_to_check = request._api_request
except AttributeError:
return response
throttle_status = getattr(api_request, "_throttle_status", None)
pass
throttle_status = getattr(request_to_check, "_throttle_status", None)
if throttle_status:
response["X-RateLimit-Limit"] = str(throttle_status["num_requests"])
response["X-RateLimit-Scope"] = str(throttle_status["scope"])
......
......@@ -74,3 +74,30 @@ class FunkwhaleThrottle(rest_throttling.SimpleRateThrottle):
def throttle_failure(self):
self.attach_info()
return super().throttle_failure()
class TooManyRequests(Exception):
pass
DummyView = collections.namedtuple("DummyView", "action throttling_scopes")
def check_request(request, scope):
"""
A simple wrapper around FunkwhaleThrottle for views that aren't API views
or cannot use rest_framework automatic throttling.
Raise TooManyRequests if limit is reached.
"""
if not settings.THROTTLING_ENABLED:
return True
view = DummyView(
action=scope,
throttling_scopes={scope: {"anonymous": scope, "authenticated": scope}},
)
throttle = FunkwhaleThrottle()
if not throttle.allow_request(request, view):
raise TooManyRequests()
return True
......@@ -4,6 +4,7 @@ import pytest
from django.http import HttpResponse
from funkwhale_api.common import middleware
from funkwhale_api.common import throttling
def test_spa_fallback_middleware_no_404(mocker):
......@@ -213,3 +214,12 @@ def test_throttle_status_middleware_includes_info_in_response_headers(mocker):
assert response["X-RateLimit-Duration"] == "3600"
assert response["X-RateLimit-Scope"] == "hello"
assert response["X-RateLimit-Reset"] == str(int(time.time()) + 2000)
def test_throttle_status_middleware_returns_proper_response(mocker):
get_response = mocker.Mock(side_effect=throttling.TooManyRequests())
request = mocker.Mock(path="/", _api_request=None, _throttle_status=None)
m = middleware.ThrottleStatusMiddleware(get_response)
response = m(request)
assert response.status_code == 429
......@@ -266,3 +266,33 @@ def test_throttle_calls_attach_info(method, mocker):
func()
attach_info.assert_called_once_with()
def test_allow_request(api_request, settings, mocker):
settings.THROTTLING_RATES = {"test": "2/s"}
ip = "92.92.92.92"
request = api_request.get("/", HTTP_X_FORWARDED_FOR=ip)
allow_request = mocker.spy(throttling.FunkwhaleThrottle, "allow_request")
action = "test"
throttling_scopes = {"test": {"anonymous": "test", "authenticated": "test"}}
throttling.check_request(request, action)
throttling.check_request(request, action)
with pytest.raises(throttling.TooManyRequests):
throttling.check_request(request, action)
assert allow_request.call_count == 3
assert allow_request.call_args[0][1] == request
assert allow_request.call_args[0][2] == throttling.DummyView(
action=action, throttling_scopes=throttling_scopes
)
def test_allow_request_throttling_disabled(api_request, settings):
settings.THROTTLING_RATES = {"test": "1/s"}
settings.THROTTLING_ENABLED = False
ip = "92.92.92.92"
request = api_request.get("/", HTTP_X_FORWARDED_FOR=ip)
action = "test"
throttling.check_request(request, action)
# even exceeding request doesn't raise any exception
throttling.check_request(request, action)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment