Verified commit 64850e78 authored by Eliot Berriot

Added caching for webfinger/AP responses to reduce latency

parent d715fa0e
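At a high level, the commit applies a read-through caching pattern: try the cache first, fall back to the remote fetch on a miss, then store the payload for the next request. A minimal sketch of that flow against the new cache backend (the cached_fetch helper and its key/URL arguments are illustrative, not part of the commit):

import aiohttp

from retribute_api import cache


async def cached_fetch(key, url):
    # Read-through: serve from cache when possible, otherwise fetch and store.
    backend = cache.get_default()
    try:
        return await backend.get(key)
    except backend.NotFound:
        pass
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            response.raise_for_status()
            payload = await response.json()
    await backend.set(key, payload)
    return payload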
@@ -218,7 +218,7 @@ ALLOWED_HOSTS = env.list("DJANGO_ALLOWED_HOSTS", default=["retribute.me"])
 # CACHES
 # ------------------------------------------------------------------------------
 CACHES = {"default": env.cache()}
+ASYNC_REDIS_PARAMS = {"address": CACHES["default"]["LOCATION"]}
 CHANNEL_LAYERS = {
     "default": {
         "BACKEND": "channels_redis.core.RedisChannelLayer",
...
import aioredis
import asyncio
import json

from django.conf import settings


class Backend:
    class NotFound(Exception):
        pass

    async def get(self, key):
        raise NotImplementedError

    async def set(self, key, value):
        raise NotImplementedError


class Dummy(Backend):
    # In-memory backend, used in tests.
    def __init__(self):
        self._cache = {}

    async def get(self, key):
        try:
            return self._cache[key]
        except KeyError:
            raise self.NotFound(key)

    async def set(self, key, value):
        self._cache[key] = value


class Redis(Backend):
    def __init__(self, params):
        self.params = params
        self._redis = None

    async def redis(self):
        # Lazily open (and reuse) the aioredis connection.
        if self._redis:
            return self._redis
        self._redis = await aioredis.create_redis(
            loop=asyncio.get_event_loop(), **self.params
        )
        return self._redis

    async def get(self, key):
        r = await self.redis()
        v = await r.get(key)
        if v is None:
            # aioredis returns None for missing keys instead of raising
            raise self.NotFound(key)
        return json.loads(v)

    async def set(self, key, value):
        r = await self.redis()
        await r.set(key, json.dumps(value))


def get_default():
    return Redis(settings.ASYNC_REDIS_PARAMS)
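A quick sketch of how the backends are meant to behave (illustrative only): a miss surfaces as Backend.NotFound rather than returning None, and the Redis backend round-trips values through JSON:

import asyncio

from retribute_api import cache


async def demo():
    backend = cache.Dummy()  # swap in cache.get_default() to hit Redis instead
    try:
        await backend.get("webfinger:links:user@domain.test")
    except backend.NotFound:
        print("miss")
    await backend.set(
        "webfinger:links:user@domain.test", {"subject": "acct:user@domain.test"}
    )
    print(await backend.get("webfinger:links:user@domain.test"))


asyncio.get_event_loop().run_until_complete(demo())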
@@ -5,6 +5,7 @@ import ssl
 from channels.generic.http import AsyncHttpConsumer
+from .. import cache
 from . import exceptions
 from . import sources
 from . import serializers
@@ -30,6 +31,7 @@ def wrapper_500(callback):
             await callback(self, body)
         except Exception as e:
             await json_response(self, 400, {"detail": str(e)})
+            raise
     return callback
@@ -92,7 +94,7 @@ class SearchSingleConsumer(AsyncHttpConsumer):
             await json_response(self, 400, {"detail": "Invalid lookup"})
         try:
             async with aiohttp.client.ClientSession(timeout=aiohttp_timeout) as session:
-                data = await source.get(lookup, session)
+                data = await source.get(lookup, session, cache=cache.get_default())
                 profile = sources.result_to_retribute_profile(lookup_type, lookup, data)
         except (exceptions.SearchError, aiohttp.ClientError) as e:
             await json_response(self, 400, {"detail": e.message})
@@ -105,7 +107,7 @@ class SearchSingleConsumer(AsyncHttpConsumer):
 async def do_lookup(lookup, lookup_type, session, source, results):
     try:
-        data = await source.get(lookup, session)
+        data = await source.get(lookup, session, cache=cache.get_default())
         profile = sources.result_to_retribute_profile(lookup_type, lookup, data)
     except (
         exceptions.SearchError,
...
@@ -37,12 +37,16 @@ class Activitypub(Source):
         # '#nobot',
     ]
-    async def get(self, lookup, session):
+    async def get(self, lookup, session, cache):
+        try:
+            actor_data = await cache.get("activitypub:profile:{}".format(lookup))
+        except cache.NotFound:
             async with session.get(
                 lookup, headers={"Accept": "application/activity+json"}
             ) as response:
                 response.raise_for_status()
                 actor_data = await response.json()
+            await cache.set("activitypub:profile:{}".format(lookup), actor_data)
         serializer = activitypub.ActorSerializer(data=actor_data)
         serializer.is_valid(raise_exception=True)
         for tag in serializer.validated_data["tag"]:
@@ -62,12 +66,12 @@ class Activitypub(Source):
 class Webfinger(Source):
     id = "webfinger"
-    async def get(self, lookup, session):
-        webfinger_data = await webfinger.lookup(lookup, session)
+    async def get(self, lookup, session, cache):
+        webfinger_data = await webfinger.lookup(lookup, session, cache=cache)
         links = webfinger.get_links(webfinger_data)
         found = None
         if "activitypub" in links:
-            found = await Activitypub().get(links["activitypub"], session)
+            found = await Activitypub().get(links["activitypub"], session, cache=cache)
         return found
...
@@ -3,7 +3,12 @@ from rest_framework import serializers
 from . import exceptions
-async def lookup(name, session):
+async def lookup(name, session, cache):
+    try:
+        return await cache.get("webfinger:links:{}".format(name))
+    except cache.NotFound:
+        pass
     try:
         username, domain = name.split("@")
     except ValueError:
@@ -14,7 +19,9 @@ async def lookup(name, session):
         params={"resource": "acct:{}".format(name)},
     ) as response:
         response.raise_for_status()
-        return await response.json()
+        payload = await response.json()
+        await cache.set("webfinger:links:{}".format(name), payload)
+        return payload
 class AccountLinkSerializer(serializers.Serializer):
...
@@ -20,6 +20,7 @@ include_package_data = True
 packages = find:
 install_requires =
     pytz
+    aioredis
     argon2-cffi
     whitenoise
     redis
...
@@ -6,6 +6,7 @@ import asynctest
 from django.utils import timezone
 from config import routing
+from retribute_api import cache
 pytest_plugins = "aiohttp.pytest_plugin"
@@ -37,3 +38,10 @@ def now(mocker):
 @pytest.fixture
 def application():
     return routing.application
+@pytest.fixture(autouse=True)
+def dummycache(mocker):
+    c = cache.Dummy()
+    mocker.patch.object(cache, "get_default", return_value=c)
+    return c
@@ -7,7 +7,9 @@ from retribute_api.search import exceptions
 from retribute_api.search import sources
-async def test_search_consumer_success(loop, application, mocker, coroutine_mock):
+async def test_search_consumer_success(
+    loop, application, mocker, coroutine_mock, dummycache
+):
     get = mocker.patch.object(sources.Webfinger, "get", coroutine_mock())
     expected = {"dummy": "json"}
     get_profile = mocker.patch.object(
@@ -18,6 +20,7 @@ async def test_search_consumer_success(loop, application, mocker, coroutine_mock
     )
     response = await communicator.get_response()
     assert get.call_args[0][0] == "test@user.domain"
+    assert get.call_args[1]["cache"] == dummycache
     get_profile.assert_called_once_with(
         "webfinger", "test@user.domain", get.return_value
     )
@@ -29,7 +32,7 @@ async def test_search_consumer_success(loop, application, mocker, coroutine_mock
     assert response["body"] == json.dumps(expected, indent=2, sort_keys=True).encode()
-async def test_search_multiple(loop, application, mocker, coroutine_mock):
+async def test_search_multiple(loop, application, mocker, coroutine_mock, dummycache):
     get = mocker.patch.object(
         sources.Webfinger,
         "get",
@@ -61,6 +64,8 @@ async def test_search_multiple(loop, application, mocker, coroutine_mock):
     response = await communicator.get_response()
     assert response["status"] == 200
     assert get.call_count == 4
+    for call in get.call_args_list:
+        assert call[1]["cache"] == dummycache
     assert get_profile.call_count == 2
     get_profile.assert_any_call("webfinger", "1", "1-data")
     get_profile.assert_any_call("webfinger", "4", "4-data")
...
 from retribute_api.search import sources
-async def test_webfinger_source(mocker, session, responses):
+async def test_webfinger_source(mocker, session, responses, dummycache):
     name = "username@domain.test"
     webfinger_response = {
@@ -46,8 +46,14 @@ async def test_webfinger_source(mocker, session, responses):
     }
     source = sources.Webfinger()
-    result = await source.get(name, session)
+    cache_get = mocker.spy(dummycache, "get")
+    cache_set = mocker.spy(dummycache, "set")
+    result = await source.get(name, session, cache=dummycache)
+    cache_get.assert_any_call("activitypub:profile:https://domain.test/users/user")
+    cache_set.assert_any_call(
+        "activitypub:profile:https://domain.test/users/user", actor_response
+    )
     assert result == expected
...
 from retribute_api.search import webfinger
-async def test_wellknown_success(responses, session):
+async def test_wellknown_success(responses, session, dummycache, mocker):
     name = "user@domain.test"
     webfinger_response = {
         "subject": "acct:user@domain.test",
@@ -18,8 +18,14 @@ async def test_wellknown_success(responses, session):
         "https://domain.test/.well-known/webfinger?resource=acct:{}".format(name),
         payload=webfinger_response,
     )
-    response = await webfinger.lookup(name, session)
+    cache_get = mocker.spy(dummycache, "get")
+    cache_set = mocker.spy(dummycache, "set")
+    response = await webfinger.lookup(name, session, cache=dummycache)
+    cache_get.assert_called_once_with("webfinger:links:{}".format(name))
+    cache_set.assert_called_once_with(
+        "webfinger:links:{}".format(name), webfinger_response
+    )
     assert response == webfinger_response
...