Verified Commit 64850e78 authored by Eliot Berriot's avatar Eliot Berriot
Browse files

Added caching for webfinger/AP responses to reduce latency

parent d715fa0e
......@@ -218,7 +218,7 @@ ALLOWED_HOSTS = env.list("DJANGO_ALLOWED_HOSTS", default=["retribute.me"])
# CACHES
# ------------------------------------------------------------------------------
CACHES = {"default": env.cache()}
ASYNC_REDIS_PARAMS = {"address": CACHES["default"]["LOCATION"]}
CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels_redis.core.RedisChannelLayer",
......
import aioredis
import asyncio
import json
from django.conf import settings
class Backend:
class NotFound(Exception):
pass
async def get(self, key):
raise NotImplementedError
async def set(self, key):
raise NotImplementedError
class Dummy(Backend):
def __init__(self):
self._cache = {}
async def get(self, key):
try:
return self._cache[key]
except KeyError:
raise self.NotFound(key)
async def set(self, key, value):
self._cache[key] = value
class Redis(Backend):
def __init__(self, params):
self.params = params
self._redis = None
async def redis(self):
if self._redis:
return self._redis
self._redis = await aioredis.create_redis(
loop=asyncio.get_event_loop(), **self.params
)
return self._redis
async def get(self, key):
r = await self.redis()
try:
v = await r.get(key)
except KeyError:
raise self.NotFound(key)
return json.loads(v)
async def set(self, key, value):
r = await self.redis()
await r.set(key, json.dumps(value))
def get_default():
return Redis(settings.ASYNC_REDIS_PARAMS)
......@@ -5,6 +5,7 @@ import ssl
from channels.generic.http import AsyncHttpConsumer
from .. import cache
from . import exceptions
from . import sources
from . import serializers
......@@ -30,6 +31,7 @@ def wrapper_500(callback):
await callback(self, body)
except Exception as e:
await json_response(self, 400, {"detail": str(e)})
raise
return callback
......@@ -92,7 +94,7 @@ class SearchSingleConsumer(AsyncHttpConsumer):
await json_response(self, 400, {"detail": "Invalid lookup"})
try:
async with aiohttp.client.ClientSession(timeout=aiohttp_timeout) as session:
data = await source.get(lookup, session)
data = await source.get(lookup, session, cache=cache.get_default())
profile = sources.result_to_retribute_profile(lookup_type, lookup, data)
except (exceptions.SearchError, aiohttp.ClientError) as e:
await json_response(self, 400, {"detail": e.message})
......@@ -105,7 +107,7 @@ class SearchSingleConsumer(AsyncHttpConsumer):
async def do_lookup(lookup, lookup_type, session, source, results):
try:
data = await source.get(lookup, session)
data = await source.get(lookup, session, cache=cache.get_default())
profile = sources.result_to_retribute_profile(lookup_type, lookup, data)
except (
exceptions.SearchError,
......
......@@ -37,12 +37,16 @@ class Activitypub(Source):
# '#nobot',
]
async def get(self, lookup, session):
async with session.get(
lookup, headers={"Accept": "application/activity+json"}
) as response:
response.raise_for_status()
actor_data = await response.json()
async def get(self, lookup, session, cache):
try:
actor_data = await cache.get("activitypub:profile:{}".format(lookup))
except cache.NotFound:
async with session.get(
lookup, headers={"Accept": "application/activity+json"}
) as response:
response.raise_for_status()
actor_data = await response.json()
await cache.set("activitypub:profile:{}".format(lookup), actor_data)
serializer = activitypub.ActorSerializer(data=actor_data)
serializer.is_valid(raise_exception=True)
for tag in serializer.validated_data["tag"]:
......@@ -62,12 +66,12 @@ class Activitypub(Source):
class Webfinger(Source):
id = "webfinger"
async def get(self, lookup, session):
webfinger_data = await webfinger.lookup(lookup, session)
async def get(self, lookup, session, cache):
webfinger_data = await webfinger.lookup(lookup, session, cache=cache)
links = webfinger.get_links(webfinger_data)
found = None
if "activitypub" in links:
found = await Activitypub().get(links["activitypub"], session)
found = await Activitypub().get(links["activitypub"], session, cache=cache)
return found
......
......@@ -3,7 +3,12 @@ from rest_framework import serializers
from . import exceptions
async def lookup(name, session):
async def lookup(name, session, cache):
try:
return await cache.get("webfinger:links:{}".format(name))
except cache.NotFound:
pass
try:
username, domain = name.split("@")
except ValueError:
......@@ -14,7 +19,9 @@ async def lookup(name, session):
params={"resource": "acct:{}".format(name)},
) as response:
response.raise_for_status()
return await response.json()
payload = await response.json()
await cache.set("webfinger:links:{}".format(name), payload)
return payload
class AccountLinkSerializer(serializers.Serializer):
......
......@@ -20,6 +20,7 @@ include_package_data = True
packages = find:
install_requires =
pytz
aioredis
argon2-cffi
whitenoise
redis
......
......@@ -6,6 +6,7 @@ import asynctest
from django.utils import timezone
from config import routing
from retribute_api import cache
pytest_plugins = "aiohttp.pytest_plugin"
......@@ -37,3 +38,10 @@ def now(mocker):
@pytest.fixture
def application():
return routing.application
@pytest.fixture(autouse=True)
def dummycache(mocker):
c = cache.Dummy()
mocker.patch.object(cache, "get_default", return_value=c)
return c
......@@ -7,7 +7,9 @@ from retribute_api.search import exceptions
from retribute_api.search import sources
async def test_search_consumer_success(loop, application, mocker, coroutine_mock):
async def test_search_consumer_success(
loop, application, mocker, coroutine_mock, dummycache
):
get = mocker.patch.object(sources.Webfinger, "get", coroutine_mock())
expected = {"dummy": "json"}
get_profile = mocker.patch.object(
......@@ -18,6 +20,7 @@ async def test_search_consumer_success(loop, application, mocker, coroutine_mock
)
response = await communicator.get_response()
assert get.call_args[0][0] == "test@user.domain"
assert get.call_args[1]["cache"] == dummycache
get_profile.assert_called_once_with(
"webfinger", "test@user.domain", get.return_value
)
......@@ -29,7 +32,7 @@ async def test_search_consumer_success(loop, application, mocker, coroutine_mock
assert response["body"] == json.dumps(expected, indent=2, sort_keys=True).encode()
async def test_search_multiple(loop, application, mocker, coroutine_mock):
async def test_search_multiple(loop, application, mocker, coroutine_mock, dummycache):
get = mocker.patch.object(
sources.Webfinger,
"get",
......@@ -61,6 +64,8 @@ async def test_search_multiple(loop, application, mocker, coroutine_mock):
response = await communicator.get_response()
assert response["status"] == 200
assert get.call_count == 4
for call in get.call_args_list:
assert call[1]["cache"] == dummycache
assert get_profile.call_count == 2
get_profile.assert_any_call("webfinger", "1", "1-data")
get_profile.assert_any_call("webfinger", "4", "4-data")
......
from retribute_api.search import sources
async def test_webfinger_source(mocker, session, responses):
async def test_webfinger_source(mocker, session, responses, dummycache):
name = "username@domain.test"
webfinger_response = {
......@@ -46,8 +46,14 @@ async def test_webfinger_source(mocker, session, responses):
}
source = sources.Webfinger()
result = await source.get(name, session)
cache_get = mocker.spy(dummycache, "get")
cache_set = mocker.spy(dummycache, "set")
result = await source.get(name, session, cache=dummycache)
cache_get.assert_any_call("activitypub:profile:https://domain.test/users/user")
cache_set.assert_any_call(
"activitypub:profile:https://domain.test/users/user", actor_response
)
assert result == expected
......
from retribute_api.search import webfinger
async def test_wellknown_success(responses, session):
async def test_wellknown_success(responses, session, dummycache, mocker):
name = "user@domain.test"
webfinger_response = {
"subject": "acct:user@domain.test",
......@@ -18,8 +18,14 @@ async def test_wellknown_success(responses, session):
"https://domain.test/.well-known/webfinger?resource=acct:{}".format(name),
payload=webfinger_response,
)
cache_get = mocker.spy(dummycache, "get")
cache_set = mocker.spy(dummycache, "set")
response = await webfinger.lookup(name, session)
response = await webfinger.lookup(name, session, cache=dummycache)
cache_get.assert_called_once_with("webfinger:links:{}".format(name))
cache_set.assert_called_once_with(
"webfinger:links:{}".format(name), webfinger_response
)
assert response == webfinger_response
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment