Skip to content
Snippets Groups Projects
Verified Commit e5d8d0f7 authored by Eliot Berriot's avatar Eliot Berriot
Browse files

Added support for filters on /domains

parent f85ffdee
No related branches found
No related tags found
No related merge requests found
......@@ -115,3 +115,33 @@ async def get_stats(conn):
data["listenings"], "total", int(check["usage_listenings_total"])
)
return data
def get_domain_query(**kwargs):
supported_fields = [
"up",
"open_registrations",
"federation_enabled",
"anonymous_can_listen",
"private",
]
base_query = "SELECT DISTINCT on (domain) domain, * FROM checks{where_clause} ORDER BY domain, time DESC"
filters = [(key, value) for key, value in kwargs.items() if key in supported_fields]
if not filters:
return base_query.format(where_clause=""), []
params = []
where_clauses = []
for field, value in sorted(filters):
where_clauses.append(f"{field} = %s")
params.append(value)
where_clause = " WHERE {}".format(" AND ".join(where_clauses))
return base_query.format(where_clause=where_clause), params
@dict_cursor
async def get_domains(cursor, **kwargs):
query, params = get_domain_query(**kwargs)
await cursor.execute(query, params)
return list(await cursor.fetchall())
......@@ -4,6 +4,9 @@ import json
import os
import urllib.parse
from webargs import fields
from webargs.aiohttpparser import parser
from . import crawler
from . import db
from . import exceptions
......@@ -13,6 +16,13 @@ from . import settings
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR = os.path.join(BASE_DIR, "static")
domain_filters = {
"up": fields.Bool(),
"open_registrations": fields.Bool(),
"anonymous_can_listen": fields.Bool(),
"federation_enabled": fields.Bool(),
}
def validate_domain(raw):
if not raw:
......@@ -38,7 +48,9 @@ async def index(request):
async def domains(request):
if request.method == "GET":
rows = await db.get_latest_check_by_domain(request["conn"])
filters = await parser.parse(domain_filters, request)
rows = await db.get_domains(request["conn"], private=False, **filters)
payload = {
"count": len(rows),
"previous": None,
......
......@@ -26,6 +26,7 @@ install_requires =
semver
asynctest
django-environ
webargs
[options.entry_points]
console_scripts =
funkwhale-network = funkwhale_network.cli:main
......
import aiohttp
import marshmallow
import psycopg2
import pytest
from funkwhale_network import crawler, serializers
......@@ -177,7 +178,7 @@ async def test_clean_check_result():
assert crawler.clean_check(check, data) == expected
async def test_save_check(populated_db, db_cursor, db_conn, factories):
async def test_save_check(populated_db, db_conn, factories):
await factories["Check"].c(domain="test.domain", private=False)
await serializers.create_domain(db_conn, {"name": "test.domain"})
data = {
......@@ -207,14 +208,18 @@ async def test_save_check(populated_db, db_cursor, db_conn, factories):
sql = "SELECT * from checks ORDER BY time DESC"
result = await crawler.save_check(db_conn, data)
await db_cursor.execute(sql)
row = await db_cursor.fetchone()
data["time"] = result["time"]
assert data == result
assert row == data
await db_cursor.execute("SELECT * FROM domains WHERE name = %s", ["test.domain"])
domain = await db_cursor.fetchone()
async with db_conn.cursor(
cursor_factory=psycopg2.extras.RealDictCursor
) as db_cursor:
await db_cursor.execute(sql)
row = await db_cursor.fetchone()
data["time"] = result["time"]
assert data == result
assert row == data
await db_cursor.execute(
"SELECT * FROM domains WHERE name = %s", ["test.domain"]
)
domain = await db_cursor.fetchone()
assert domain["node_name"] == "Test Domain"
......@@ -241,6 +246,7 @@ async def test_private_domain_delete_past_checks(
sql = "SELECT * from checks"
assert await crawler.save_check(db_conn, data) is None
await db_cursor.execute(sql)
result = await db_cursor.fetchall()
async with db_conn.cursor() as db_cursor:
await db_cursor.execute(sql)
result = await db_cursor.fetchall()
assert result == []
import pytest
from funkwhale_network import db
......@@ -62,3 +64,34 @@ async def test_get_stats(factories, db_conn):
"listenings": {"total": 42},
}
assert await db.get_stats(db_conn) == expected
@pytest.mark.parametrize(
"kwargs, expected_query, expected_params",
[
(
{},
"SELECT DISTINCT on (domain) domain, * FROM checks ORDER BY domain, time DESC",
[],
),
(
{"up": True},
"SELECT DISTINCT on (domain) domain, * FROM checks WHERE up = %s ORDER BY domain, time DESC",
[True],
),
(
{"up": True, "open_registrations": False},
"SELECT DISTINCT on (domain) domain, * FROM checks WHERE open_registrations = %s AND up = %s ORDER BY domain, time DESC",
[False, True],
),
(
{"up": True, "private": False},
"SELECT DISTINCT on (domain) domain, * FROM checks WHERE private = %s AND up = %s ORDER BY domain, time DESC",
[False, True],
),
],
)
def test_get_domain_query(kwargs, expected_query, expected_params):
query, params = db.get_domain_query(**kwargs)
assert query == expected_query
assert params == expected_params
import pytest
from funkwhale_network import serializers
async def test_domains_get(client, factories):
await factories["Check"].c(private=True)
checks = sorted(
[
await factories["Check"].c(private=False),
await factories["Check"].c(private=False),
await factories["Check"].c(private=False),
],
key=lambda o: o["domain"],
)
......@@ -21,6 +23,26 @@ async def test_domains_get(client, factories):
}
@pytest.mark.parametrize(
"field", ["up", "open_registrations", "anonymous_can_listen", "federation_enabled"]
)
@pytest.mark.parametrize("value", [True, False])
async def test_domains_get_filter_up(client, factories, field, value):
await factories["Check"].c(private=True)
await factories["Check"].c(private=False, **{field: not value}),
check = await factories["Check"].c(private=False, **{field: value})
resp = await client.get("/api/domains", params={field: str(value)})
assert resp.status == 200
assert await resp.json() == {
"count": 1,
"next": None,
"previous": None,
"results": [serializers.serialize_domain_from_check(check)],
}
async def test_domains_create(client, coroutine_mock, mocker):
payload = {"name": "test.domain"}
mocker.patch("funkwhale_network.crawler.fetch_nodeinfo", coroutine_mock())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment