From 64850e787be5d957979f1d0c360e30ccb4c1ac31 Mon Sep 17 00:00:00 2001 From: Eliot Berriot <contact@eliotberriot.com> Date: Thu, 6 Jun 2019 15:25:01 +0200 Subject: [PATCH] Added caching for webfinger/AP responses to reduce latency --- config/settings/base.py | 2 +- retribute_api/cache.py | 60 +++++++++++++++++++++++++++++++ retribute_api/search/consumers.py | 6 ++-- retribute_api/search/sources.py | 22 +++++++----- retribute_api/search/webfinger.py | 11 ++++-- setup.cfg | 1 + tests/conftest.py | 8 +++++ tests/search/test_consumers.py | 9 +++-- tests/search/test_sources.py | 10 ++++-- tests/search/test_webfinger.py | 10 ++++-- 10 files changed, 119 insertions(+), 20 deletions(-) create mode 100644 retribute_api/cache.py diff --git a/config/settings/base.py b/config/settings/base.py index fa5d814..aa4adc4 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -218,7 +218,7 @@ ALLOWED_HOSTS = env.list("DJANGO_ALLOWED_HOSTS", default=["retribute.me"]) # CACHES # ------------------------------------------------------------------------------ CACHES = {"default": env.cache()} - +ASYNC_REDIS_PARAMS = {"address": CACHES["default"]["LOCATION"]} CHANNEL_LAYERS = { "default": { "BACKEND": "channels_redis.core.RedisChannelLayer", diff --git a/retribute_api/cache.py b/retribute_api/cache.py new file mode 100644 index 0000000..779af31 --- /dev/null +++ b/retribute_api/cache.py @@ -0,0 +1,60 @@ +import aioredis +import asyncio +import json + +from django.conf import settings + + +class Backend: + class NotFound(Exception): + pass + + async def get(self, key): + raise NotImplementedError + + async def set(self, key): + raise NotImplementedError + + +class Dummy(Backend): + def __init__(self): + self._cache = {} + + async def get(self, key): + try: + return self._cache[key] + except KeyError: + raise self.NotFound(key) + + async def set(self, key, value): + self._cache[key] = value + + +class Redis(Backend): + def __init__(self, params): + self.params = params + self._redis = None + + async def redis(self): + if self._redis: + return self._redis + self._redis = await aioredis.create_redis( + loop=asyncio.get_event_loop(), **self.params + ) + return self._redis + + async def get(self, key): + r = await self.redis() + try: + v = await r.get(key) + except KeyError: + raise self.NotFound(key) + return json.loads(v) + + async def set(self, key, value): + r = await self.redis() + await r.set(key, json.dumps(value)) + + +def get_default(): + return Redis(settings.ASYNC_REDIS_PARAMS) diff --git a/retribute_api/search/consumers.py b/retribute_api/search/consumers.py index eb3867a..c7d05c3 100644 --- a/retribute_api/search/consumers.py +++ b/retribute_api/search/consumers.py @@ -5,6 +5,7 @@ import ssl from channels.generic.http import AsyncHttpConsumer +from .. import cache from . import exceptions from . import sources from . import serializers @@ -30,6 +31,7 @@ def wrapper_500(callback): await callback(self, body) except Exception as e: await json_response(self, 400, {"detail": str(e)}) + raise return callback @@ -92,7 +94,7 @@ class SearchSingleConsumer(AsyncHttpConsumer): await json_response(self, 400, {"detail": "Invalid lookup"}) try: async with aiohttp.client.ClientSession(timeout=aiohttp_timeout) as session: - data = await source.get(lookup, session) + data = await source.get(lookup, session, cache=cache.get_default()) profile = sources.result_to_retribute_profile(lookup_type, lookup, data) except (exceptions.SearchError, aiohttp.ClientError) as e: await json_response(self, 400, {"detail": e.message}) @@ -105,7 +107,7 @@ class SearchSingleConsumer(AsyncHttpConsumer): async def do_lookup(lookup, lookup_type, session, source, results): try: - data = await source.get(lookup, session) + data = await source.get(lookup, session, cache=cache.get_default()) profile = sources.result_to_retribute_profile(lookup_type, lookup, data) except ( exceptions.SearchError, diff --git a/retribute_api/search/sources.py b/retribute_api/search/sources.py index ea74c24..6285062 100644 --- a/retribute_api/search/sources.py +++ b/retribute_api/search/sources.py @@ -37,12 +37,16 @@ class Activitypub(Source): # '#nobot', ] - async def get(self, lookup, session): - async with session.get( - lookup, headers={"Accept": "application/activity+json"} - ) as response: - response.raise_for_status() - actor_data = await response.json() + async def get(self, lookup, session, cache): + try: + actor_data = await cache.get("activitypub:profile:{}".format(lookup)) + except cache.NotFound: + async with session.get( + lookup, headers={"Accept": "application/activity+json"} + ) as response: + response.raise_for_status() + actor_data = await response.json() + await cache.set("activitypub:profile:{}".format(lookup), actor_data) serializer = activitypub.ActorSerializer(data=actor_data) serializer.is_valid(raise_exception=True) for tag in serializer.validated_data["tag"]: @@ -62,12 +66,12 @@ class Activitypub(Source): class Webfinger(Source): id = "webfinger" - async def get(self, lookup, session): - webfinger_data = await webfinger.lookup(lookup, session) + async def get(self, lookup, session, cache): + webfinger_data = await webfinger.lookup(lookup, session, cache=cache) links = webfinger.get_links(webfinger_data) found = None if "activitypub" in links: - found = await Activitypub().get(links["activitypub"], session) + found = await Activitypub().get(links["activitypub"], session, cache=cache) return found diff --git a/retribute_api/search/webfinger.py b/retribute_api/search/webfinger.py index 9e8cc44..c89a7be 100644 --- a/retribute_api/search/webfinger.py +++ b/retribute_api/search/webfinger.py @@ -3,7 +3,12 @@ from rest_framework import serializers from . import exceptions -async def lookup(name, session): +async def lookup(name, session, cache): + try: + return await cache.get("webfinger:links:{}".format(name)) + except cache.NotFound: + pass + try: username, domain = name.split("@") except ValueError: @@ -14,7 +19,9 @@ async def lookup(name, session): params={"resource": "acct:{}".format(name)}, ) as response: response.raise_for_status() - return await response.json() + payload = await response.json() + await cache.set("webfinger:links:{}".format(name), payload) + return payload class AccountLinkSerializer(serializers.Serializer): diff --git a/setup.cfg b/setup.cfg index 5902555..ffdc904 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,7 @@ include_package_data = True packages = find: install_requires = pytz + aioredis argon2-cffi whitenoise redis diff --git a/tests/conftest.py b/tests/conftest.py index 36e3281..3ced31a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ import asynctest from django.utils import timezone from config import routing +from retribute_api import cache pytest_plugins = "aiohttp.pytest_plugin" @@ -37,3 +38,10 @@ def now(mocker): @pytest.fixture def application(): return routing.application + + +@pytest.fixture(autouse=True) +def dummycache(mocker): + c = cache.Dummy() + mocker.patch.object(cache, "get_default", return_value=c) + return c diff --git a/tests/search/test_consumers.py b/tests/search/test_consumers.py index b70de23..e24f1d4 100644 --- a/tests/search/test_consumers.py +++ b/tests/search/test_consumers.py @@ -7,7 +7,9 @@ from retribute_api.search import exceptions from retribute_api.search import sources -async def test_search_consumer_success(loop, application, mocker, coroutine_mock): +async def test_search_consumer_success( + loop, application, mocker, coroutine_mock, dummycache +): get = mocker.patch.object(sources.Webfinger, "get", coroutine_mock()) expected = {"dummy": "json"} get_profile = mocker.patch.object( @@ -18,6 +20,7 @@ async def test_search_consumer_success(loop, application, mocker, coroutine_mock ) response = await communicator.get_response() assert get.call_args[0][0] == "test@user.domain" + assert get.call_args[1]["cache"] == dummycache get_profile.assert_called_once_with( "webfinger", "test@user.domain", get.return_value ) @@ -29,7 +32,7 @@ async def test_search_consumer_success(loop, application, mocker, coroutine_mock assert response["body"] == json.dumps(expected, indent=2, sort_keys=True).encode() -async def test_search_multiple(loop, application, mocker, coroutine_mock): +async def test_search_multiple(loop, application, mocker, coroutine_mock, dummycache): get = mocker.patch.object( sources.Webfinger, "get", @@ -61,6 +64,8 @@ async def test_search_multiple(loop, application, mocker, coroutine_mock): response = await communicator.get_response() assert response["status"] == 200 assert get.call_count == 4 + for call in get.call_args_list: + assert call[1]["cache"] == dummycache assert get_profile.call_count == 2 get_profile.assert_any_call("webfinger", "1", "1-data") get_profile.assert_any_call("webfinger", "4", "4-data") diff --git a/tests/search/test_sources.py b/tests/search/test_sources.py index 9830332..9c596fb 100644 --- a/tests/search/test_sources.py +++ b/tests/search/test_sources.py @@ -1,7 +1,7 @@ from retribute_api.search import sources -async def test_webfinger_source(mocker, session, responses): +async def test_webfinger_source(mocker, session, responses, dummycache): name = "username@domain.test" webfinger_response = { @@ -46,8 +46,14 @@ async def test_webfinger_source(mocker, session, responses): } source = sources.Webfinger() - result = await source.get(name, session) + cache_get = mocker.spy(dummycache, "get") + cache_set = mocker.spy(dummycache, "set") + result = await source.get(name, session, cache=dummycache) + cache_get.assert_any_call("activitypub:profile:https://domain.test/users/user") + cache_set.assert_any_call( + "activitypub:profile:https://domain.test/users/user", actor_response + ) assert result == expected diff --git a/tests/search/test_webfinger.py b/tests/search/test_webfinger.py index ee3f9f7..37092b8 100644 --- a/tests/search/test_webfinger.py +++ b/tests/search/test_webfinger.py @@ -1,7 +1,7 @@ from retribute_api.search import webfinger -async def test_wellknown_success(responses, session): +async def test_wellknown_success(responses, session, dummycache, mocker): name = "user@domain.test" webfinger_response = { "subject": "acct:user@domain.test", @@ -18,8 +18,14 @@ async def test_wellknown_success(responses, session): "https://domain.test/.well-known/webfinger?resource=acct:{}".format(name), payload=webfinger_response, ) + cache_get = mocker.spy(dummycache, "get") + cache_set = mocker.spy(dummycache, "set") - response = await webfinger.lookup(name, session) + response = await webfinger.lookup(name, session, cache=dummycache) + cache_get.assert_called_once_with("webfinger:links:{}".format(name)) + cache_set.assert_called_once_with( + "webfinger:links:{}".format(name), webfinger_response + ) assert response == webfinger_response -- GitLab