Skip to content
Snippets Groups Projects
Verified Commit 1e5eab9a authored by Eliot Berriot's avatar Eliot Berriot
Browse files

Added multiple query search engine

parent b64e9210
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,7 @@ application = ProtocolTypeRouter(
{
"http": URLRouter(
[
url(r"^api/v1/search/$", consumers.SearchMultipleConsumer),
url(
r"^api/v1/search/(?P<lookup_type>.+):(?P<lookup>.+)$",
consumers.SearchSingleConsumer,
......
import asyncio
import aiohttp.client
import json
......@@ -5,6 +6,7 @@ from channels.generic.http import AsyncHttpConsumer
from . import exceptions
from . import sources
from . import serializers
async def json_response(self, status, content):
......@@ -38,10 +40,59 @@ class SearchSingleConsumer(AsyncHttpConsumer):
except KeyError:
await json_response(self, 400, {"detail": "Invalid lookup"})
try:
async with aiohttp.client.ClientSession() as session:
async with aiohttp.client.ClientSession(timeout=5) as session:
data = await source.get(lookup, session)
profile = sources.result_to_retribute_profile(lookup_type, lookup, data)
except exceptions.SearchError as e:
except (exceptions.SearchError, aiohttp.ClientError) as e:
await json_response(self, 400, {"detail": e.message})
await json_response(self, 200, profile)
async def do_lookup(lookup, lookup_type, session, source, results):
try:
data = await source.get(lookup, session)
profile = sources.result_to_retribute_profile(lookup_type, lookup, data)
except (exceptions.SearchError, aiohttp.ClientError) as e:
results[":".join([lookup_type, lookup])] = None
return
results[":".join([lookup_type, lookup])] = profile
class SearchMultipleConsumer(AsyncHttpConsumer):
@wrapper_500
async def handle(self, body):
if self.scope["method"] not in ["POST"]:
return await self.send_response(405, b"")
try:
parsed_body = json.loads(body)
except ValueError:
return await json_response(self, 400, {"detail": "Invalid JSON"})
serializer = serializers.SearchMultipleSerializer(data=parsed_body)
if not serializer.is_valid():
return await json_response(self, 400, {"detail": "Invalid data"})
lookups = serializer.validated_data["lookups"]
results = {}
tasks = []
async with aiohttp.client.ClientSession(timeout=15) as session:
for lookup_type, lookup in lookups:
try:
source = sources.registry._data[lookup_type]
except KeyError:
results[":".join([lookup_type, lookup])] = None
continue
tasks.append(
do_lookup(
lookup=lookup,
lookup_type=lookup_type,
session=session,
source=source,
results=results,
)
)
await asyncio.gather(*tasks)
await json_response(self, 200, results)
from rest_framework import serializers
class Lookup(serializers.CharField):
def to_internal_value(self, value):
value = super().to_internal_value(value)
try:
lookup_type, lookup = value.split(":")
except (ValueError, TypeError, AttributeError):
raise serializers.ValidationError("Invalid lookup {}".format(value))
return lookup_type, lookup
class SearchMultipleSerializer(serializers.Serializer):
lookups = serializers.ListField(child=Lookup(), min_length=1, max_length=20)
import json
import aiohttp
from channels.testing import HttpCommunicator
from retribute_api.search import consumers
from retribute_api.search import exceptions
from retribute_api.search import sources
......@@ -25,3 +27,46 @@ async def test_search_consumer_success(loop, application, mocker, coroutine_mock
(b"Access-Control-Allow-Origin", b"*"),
]
assert response["body"] == json.dumps(expected, indent=2, sort_keys=True).encode()
async def test_search_multiple(loop, application, mocker, coroutine_mock):
get = mocker.patch.object(
sources.Webfinger,
"get",
coroutine_mock(
side_effect=[
"1-data",
exceptions.SearchError(),
aiohttp.ClientError(),
"4-data",
]
),
)
expected = {
"webfinger:1": "1-profile",
"webfinger:2": None,
"webfinger:3": None,
"webfinger:4": "4-profile",
}
profile_results = {"1-data": "1-profile", "4-data": "4-profile"}
get_profile = mocker.patch.object(
sources,
"result_to_retribute_profile",
side_effect=lambda a, b, c: profile_results[c],
)
query = {"lookups": ["webfinger:1", "webfinger:2", "webfinger:3", "webfinger:4"]}
communicator = HttpCommunicator(
application, "POST", "/api/v1/search/", body=json.dumps(query).encode()
)
response = await communicator.get_response()
assert response["status"] == 200
assert get.call_count == 4
assert get_profile.call_count == 2
get_profile.assert_any_call("webfinger", "1", "1-data")
get_profile.assert_any_call("webfinger", "4", "4-data")
assert response["headers"] == [
(b"Content-Type", b"application/json"),
(b"Access-Control-Allow-Origin", b"*"),
]
assert response["body"] == json.dumps(expected, indent=2, sort_keys=True).encode()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment