Verified Commit 1e5eab9a authored by Eliot Berriot's avatar Eliot Berriot
Browse files

Added multiple query search engine

parent b64e9210
......@@ -18,6 +18,7 @@ application = ProtocolTypeRouter(
{
"http": URLRouter(
[
url(r"^api/v1/search/$", consumers.SearchMultipleConsumer),
url(
r"^api/v1/search/(?P<lookup_type>.+):(?P<lookup>.+)$",
consumers.SearchSingleConsumer,
......
import asyncio
import aiohttp.client
import json
......@@ -5,6 +6,7 @@ from channels.generic.http import AsyncHttpConsumer
from . import exceptions
from . import sources
from . import serializers
async def json_response(self, status, content):
......@@ -38,10 +40,59 @@ class SearchSingleConsumer(AsyncHttpConsumer):
except KeyError:
await json_response(self, 400, {"detail": "Invalid lookup"})
try:
async with aiohttp.client.ClientSession() as session:
async with aiohttp.client.ClientSession(timeout=5) as session:
data = await source.get(lookup, session)
profile = sources.result_to_retribute_profile(lookup_type, lookup, data)
except exceptions.SearchError as e:
except (exceptions.SearchError, aiohttp.ClientError) as e:
await json_response(self, 400, {"detail": e.message})
await json_response(self, 200, profile)
async def do_lookup(lookup, lookup_type, session, source, results):
try:
data = await source.get(lookup, session)
profile = sources.result_to_retribute_profile(lookup_type, lookup, data)
except (exceptions.SearchError, aiohttp.ClientError) as e:
results[":".join([lookup_type, lookup])] = None
return
results[":".join([lookup_type, lookup])] = profile
class SearchMultipleConsumer(AsyncHttpConsumer):
@wrapper_500
async def handle(self, body):
if self.scope["method"] not in ["POST"]:
return await self.send_response(405, b"")
try:
parsed_body = json.loads(body)
except ValueError:
return await json_response(self, 400, {"detail": "Invalid JSON"})
serializer = serializers.SearchMultipleSerializer(data=parsed_body)
if not serializer.is_valid():
return await json_response(self, 400, {"detail": "Invalid data"})
lookups = serializer.validated_data["lookups"]
results = {}
tasks = []
async with aiohttp.client.ClientSession(timeout=15) as session:
for lookup_type, lookup in lookups:
try:
source = sources.registry._data[lookup_type]
except KeyError:
results[":".join([lookup_type, lookup])] = None
continue
tasks.append(
do_lookup(
lookup=lookup,
lookup_type=lookup_type,
session=session,
source=source,
results=results,
)
)
await asyncio.gather(*tasks)
await json_response(self, 200, results)
from rest_framework import serializers
class Lookup(serializers.CharField):
def to_internal_value(self, value):
value = super().to_internal_value(value)
try:
lookup_type, lookup = value.split(":")
except (ValueError, TypeError, AttributeError):
raise serializers.ValidationError("Invalid lookup {}".format(value))
return lookup_type, lookup
class SearchMultipleSerializer(serializers.Serializer):
lookups = serializers.ListField(child=Lookup(), min_length=1, max_length=20)
import json
import aiohttp
from channels.testing import HttpCommunicator
from retribute_api.search import consumers
from retribute_api.search import exceptions
from retribute_api.search import sources
......@@ -25,3 +27,46 @@ async def test_search_consumer_success(loop, application, mocker, coroutine_mock
(b"Access-Control-Allow-Origin", b"*"),
]
assert response["body"] == json.dumps(expected, indent=2, sort_keys=True).encode()
async def test_search_multiple(loop, application, mocker, coroutine_mock):
get = mocker.patch.object(
sources.Webfinger,
"get",
coroutine_mock(
side_effect=[
"1-data",
exceptions.SearchError(),
aiohttp.ClientError(),
"4-data",
]
),
)
expected = {
"webfinger:1": "1-profile",
"webfinger:2": None,
"webfinger:3": None,
"webfinger:4": "4-profile",
}
profile_results = {"1-data": "1-profile", "4-data": "4-profile"}
get_profile = mocker.patch.object(
sources,
"result_to_retribute_profile",
side_effect=lambda a, b, c: profile_results[c],
)
query = {"lookups": ["webfinger:1", "webfinger:2", "webfinger:3", "webfinger:4"]}
communicator = HttpCommunicator(
application, "POST", "/api/v1/search/", body=json.dumps(query).encode()
)
response = await communicator.get_response()
assert response["status"] == 200
assert get.call_count == 4
assert get_profile.call_count == 2
get_profile.assert_any_call("webfinger", "1", "1-data")
get_profile.assert_any_call("webfinger", "4", "4-data")
assert response["headers"] == [
(b"Content-Type", b"application/json"),
(b"Access-Control-Allow-Origin", b"*"),
]
assert response["body"] == json.dumps(expected, indent=2, sort_keys=True).encode()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment