From 64850e787be5d957979f1d0c360e30ccb4c1ac31 Mon Sep 17 00:00:00 2001
From: Eliot Berriot <contact@eliotberriot.com>
Date: Thu, 6 Jun 2019 15:25:01 +0200
Subject: [PATCH] Added caching for webfinger/AP responses to reduce latency

---
 config/settings/base.py           |  2 +-
 retribute_api/cache.py            | 60 +++++++++++++++++++++++++++++++
 retribute_api/search/consumers.py |  6 ++--
 retribute_api/search/sources.py   | 22 +++++++-----
 retribute_api/search/webfinger.py | 11 ++++--
 setup.cfg                         |  1 +
 tests/conftest.py                 |  8 +++++
 tests/search/test_consumers.py    |  9 +++--
 tests/search/test_sources.py      | 10 ++++--
 tests/search/test_webfinger.py    | 10 ++++--
 10 files changed, 119 insertions(+), 20 deletions(-)
 create mode 100644 retribute_api/cache.py

diff --git a/config/settings/base.py b/config/settings/base.py
index fa5d814..aa4adc4 100644
--- a/config/settings/base.py
+++ b/config/settings/base.py
@@ -218,7 +218,7 @@ ALLOWED_HOSTS = env.list("DJANGO_ALLOWED_HOSTS", default=["retribute.me"])
 # CACHES
 # ------------------------------------------------------------------------------
 CACHES = {"default": env.cache()}
-
+ASYNC_REDIS_PARAMS = {"address": CACHES["default"]["LOCATION"]}
 CHANNEL_LAYERS = {
     "default": {
         "BACKEND": "channels_redis.core.RedisChannelLayer",
diff --git a/retribute_api/cache.py b/retribute_api/cache.py
new file mode 100644
index 0000000..779af31
--- /dev/null
+++ b/retribute_api/cache.py
@@ -0,0 +1,84 @@
+import aioredis
+import asyncio
+import json
+
+from django.conf import settings
+
+
+class Backend:
+    """Interface shared by the async cache backends.
+
+    ``get`` signals a missing key by raising ``NotFound``.
+    """
+
+    class NotFound(Exception):
+        """Raised by ``get`` when the requested key is not cached."""
+
+    async def get(self, key):
+        """Return the value stored under ``key`` or raise ``NotFound``."""
+        raise NotImplementedError
+
+    async def set(self, key, value):
+        """Store ``value`` under ``key``."""
+        raise NotImplementedError
+
+
+class Dummy(Backend):
+    """In-memory backend keeping values in a plain dict."""
+
+    def __init__(self):
+        self._cache = {}
+
+    async def get(self, key):
+        try:
+            return self._cache[key]
+        except KeyError:
+            raise self.NotFound(key)
+
+    async def set(self, key, value):
+        self._cache[key] = value
+
+
+class Redis(Backend):
+    """Backend storing JSON-encoded values in Redis through aioredis."""
+
+    def __init__(self, params):
+        # Keyword arguments forwarded to aioredis.create_redis().
+        self.params = params
+        # Opened lazily by redis() on first use.
+        self._redis = None
+
+    async def redis(self):
+        """Return the connection, opening it on the first call."""
+        if self._redis:
+            return self._redis
+        self._redis = await aioredis.create_redis(
+            loop=asyncio.get_event_loop(), **self.params
+        )
+        return self._redis
+
+    async def get(self, key):
+        """Return the decoded value under ``key`` or raise ``NotFound``."""
+        r = await self.redis()
+        v = await r.get(key)
+        # GET on a missing key yields None rather than raising, so the miss
+        # has to be detected here; json.loads(None) would raise TypeError
+        # and callers catching NotFound would never see a cache miss.
+        if v is None:
+            raise self.NotFound(key)
+        return json.loads(v)
+
+    async def set(self, key, value):
+        """JSON-encode ``value`` and store it under ``key``."""
+        r = await self.redis()
+        await r.set(key, json.dumps(value))
+
+
+def get_default():
+    """Build a Redis backend from ``settings.ASYNC_REDIS_PARAMS``.
+
+    NOTE(review): every call returns a fresh ``Redis`` instance, which
+    opens its own connection on first use and is never closed here --
+    confirm whether callers should share a single backend.
+    """
+    return Redis(settings.ASYNC_REDIS_PARAMS)
diff --git a/retribute_api/search/consumers.py b/retribute_api/search/consumers.py
index eb3867a..c7d05c3 100644
--- a/retribute_api/search/consumers.py
+++ b/retribute_api/search/consumers.py
@@ -5,6 +5,7 @@ import ssl
 
 from channels.generic.http import AsyncHttpConsumer
 
+from .. import cache
 from . import exceptions
 from . import sources
 from . import serializers
@@ -30,6 +31,7 @@ def wrapper_500(callback):
             await callback(self, body)
         except Exception as e:
             await json_response(self, 400, {"detail": str(e)})
+            raise
 
     return callback
 
@@ -92,7 +94,7 @@ class SearchSingleConsumer(AsyncHttpConsumer):
             await json_response(self, 400, {"detail": "Invalid lookup"})
         try:
             async with aiohttp.client.ClientSession(timeout=aiohttp_timeout) as session:
-                data = await source.get(lookup, session)
+                data = await source.get(lookup, session, cache=cache.get_default())
             profile = sources.result_to_retribute_profile(lookup_type, lookup, data)
         except (exceptions.SearchError, aiohttp.ClientError) as e:
             await json_response(self, 400, {"detail": e.message})
@@ -105,7 +107,7 @@ class SearchSingleConsumer(AsyncHttpConsumer):
 
 async def do_lookup(lookup, lookup_type, session, source, results):
     try:
-        data = await source.get(lookup, session)
+        data = await source.get(lookup, session, cache=cache.get_default())
         profile = sources.result_to_retribute_profile(lookup_type, lookup, data)
     except (
         exceptions.SearchError,
diff --git a/retribute_api/search/sources.py b/retribute_api/search/sources.py
index ea74c24..6285062 100644
--- a/retribute_api/search/sources.py
+++ b/retribute_api/search/sources.py
@@ -37,12 +37,16 @@ class Activitypub(Source):
         # '#nobot',
     ]
 
-    async def get(self, lookup, session):
-        async with session.get(
-            lookup, headers={"Accept": "application/activity+json"}
-        ) as response:
-            response.raise_for_status()
-            actor_data = await response.json()
+    async def get(self, lookup, session, cache):
+        try:
+            actor_data = await cache.get("activitypub:profile:{}".format(lookup))
+        except cache.NotFound:
+            async with session.get(
+                lookup, headers={"Accept": "application/activity+json"}
+            ) as response:
+                response.raise_for_status()
+                actor_data = await response.json()
+            await cache.set("activitypub:profile:{}".format(lookup), actor_data)
         serializer = activitypub.ActorSerializer(data=actor_data)
         serializer.is_valid(raise_exception=True)
         for tag in serializer.validated_data["tag"]:
@@ -62,12 +66,12 @@ class Activitypub(Source):
 class Webfinger(Source):
     id = "webfinger"
 
-    async def get(self, lookup, session):
-        webfinger_data = await webfinger.lookup(lookup, session)
+    async def get(self, lookup, session, cache):
+        webfinger_data = await webfinger.lookup(lookup, session, cache=cache)
         links = webfinger.get_links(webfinger_data)
         found = None
         if "activitypub" in links:
-            found = await Activitypub().get(links["activitypub"], session)
+            found = await Activitypub().get(links["activitypub"], session, cache=cache)
         return found
 
 
diff --git a/retribute_api/search/webfinger.py b/retribute_api/search/webfinger.py
index 9e8cc44..c89a7be 100644
--- a/retribute_api/search/webfinger.py
+++ b/retribute_api/search/webfinger.py
@@ -3,7 +3,12 @@ from rest_framework import serializers
 from . import exceptions
 
 
-async def lookup(name, session):
+async def lookup(name, session, cache):
+    try:
+        return await cache.get("webfinger:links:{}".format(name))
+    except cache.NotFound:
+        pass
+
     try:
         username, domain = name.split("@")
     except ValueError:
@@ -14,7 +19,9 @@ async def lookup(name, session):
         params={"resource": "acct:{}".format(name)},
     ) as response:
         response.raise_for_status()
-        return await response.json()
+        payload = await response.json()
+        await cache.set("webfinger:links:{}".format(name), payload)
+        return payload
 
 
 class AccountLinkSerializer(serializers.Serializer):
diff --git a/setup.cfg b/setup.cfg
index 5902555..ffdc904 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -20,6 +20,7 @@ include_package_data = True
 packages = find:
 install_requires =
     pytz
+    aioredis
     argon2-cffi
     whitenoise
     redis
diff --git a/tests/conftest.py b/tests/conftest.py
index 36e3281..3ced31a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,6 +6,7 @@ import asynctest
 from django.utils import timezone
 
 from config import routing
+from retribute_api import cache
 
 pytest_plugins = "aiohttp.pytest_plugin"
 
@@ -37,3 +38,10 @@ def now(mocker):
 @pytest.fixture
 def application():
     return routing.application
+
+
+@pytest.fixture(autouse=True)
+def dummycache(mocker):
+    """Patch cache.get_default to hand out one in-memory Dummy; return it."""
+    patched = mocker.patch.object(cache, "get_default", return_value=cache.Dummy())
+    return patched.return_value
diff --git a/tests/search/test_consumers.py b/tests/search/test_consumers.py
index b70de23..e24f1d4 100644
--- a/tests/search/test_consumers.py
+++ b/tests/search/test_consumers.py
@@ -7,7 +7,9 @@ from retribute_api.search import exceptions
 from retribute_api.search import sources
 
 
-async def test_search_consumer_success(loop, application, mocker, coroutine_mock):
+async def test_search_consumer_success(
+    loop, application, mocker, coroutine_mock, dummycache
+):
     get = mocker.patch.object(sources.Webfinger, "get", coroutine_mock())
     expected = {"dummy": "json"}
     get_profile = mocker.patch.object(
@@ -18,6 +20,7 @@ async def test_search_consumer_success(loop, application, mocker, coroutine_mock
     )
     response = await communicator.get_response()
     assert get.call_args[0][0] == "test@user.domain"
+    assert get.call_args[1]["cache"] == dummycache
     get_profile.assert_called_once_with(
         "webfinger", "test@user.domain", get.return_value
     )
@@ -29,7 +32,7 @@ async def test_search_consumer_success(loop, application, mocker, coroutine_mock
     assert response["body"] == json.dumps(expected, indent=2, sort_keys=True).encode()
 
 
-async def test_search_multiple(loop, application, mocker, coroutine_mock):
+async def test_search_multiple(loop, application, mocker, coroutine_mock, dummycache):
     get = mocker.patch.object(
         sources.Webfinger,
         "get",
@@ -61,6 +64,8 @@ async def test_search_multiple(loop, application, mocker, coroutine_mock):
     response = await communicator.get_response()
     assert response["status"] == 200
     assert get.call_count == 4
+    for call in get.call_args_list:
+        assert call[1]["cache"] == dummycache
     assert get_profile.call_count == 2
     get_profile.assert_any_call("webfinger", "1", "1-data")
     get_profile.assert_any_call("webfinger", "4", "4-data")
diff --git a/tests/search/test_sources.py b/tests/search/test_sources.py
index 9830332..9c596fb 100644
--- a/tests/search/test_sources.py
+++ b/tests/search/test_sources.py
@@ -1,7 +1,7 @@
 from retribute_api.search import sources
 
 
-async def test_webfinger_source(mocker, session, responses):
+async def test_webfinger_source(mocker, session, responses, dummycache):
     name = "username@domain.test"
 
     webfinger_response = {
@@ -46,8 +46,14 @@ async def test_webfinger_source(mocker, session, responses):
     }
 
     source = sources.Webfinger()
-    result = await source.get(name, session)
+    cache_get = mocker.spy(dummycache, "get")
+    cache_set = mocker.spy(dummycache, "set")
 
+    result = await source.get(name, session, cache=dummycache)
+    cache_get.assert_any_call("activitypub:profile:https://domain.test/users/user")
+    cache_set.assert_any_call(
+        "activitypub:profile:https://domain.test/users/user", actor_response
+    )
     assert result == expected
 
 
diff --git a/tests/search/test_webfinger.py b/tests/search/test_webfinger.py
index ee3f9f7..37092b8 100644
--- a/tests/search/test_webfinger.py
+++ b/tests/search/test_webfinger.py
@@ -1,7 +1,7 @@
 from retribute_api.search import webfinger
 
 
-async def test_wellknown_success(responses, session):
+async def test_wellknown_success(responses, session, dummycache, mocker):
     name = "user@domain.test"
     webfinger_response = {
         "subject": "acct:user@domain.test",
@@ -18,8 +18,14 @@ async def test_wellknown_success(responses, session):
         "https://domain.test/.well-known/webfinger?resource=acct:{}".format(name),
         payload=webfinger_response,
     )
+    cache_get = mocker.spy(dummycache, "get")
+    cache_set = mocker.spy(dummycache, "set")
 
-    response = await webfinger.lookup(name, session)
+    response = await webfinger.lookup(name, session, cache=dummycache)
+    cache_get.assert_called_once_with("webfinger:links:{}".format(name))
+    cache_set.assert_called_once_with(
+        "webfinger:links:{}".format(name), webfinger_response
+    )
     assert response == webfinger_response
 
 
-- 
GitLab