diff --git a/funkwhale_network/db.py b/funkwhale_network/db.py index c29c0dcddfc1c024361838c1df92198b42e6ac08..20699f12b03711a227a05d9c1d407a43faf5ade3 100644 --- a/funkwhale_network/db.py +++ b/funkwhale_network/db.py @@ -115,3 +115,33 @@ async def get_stats(conn): data["listenings"], "total", int(check["usage_listenings_total"]) ) return data + + +def get_domain_query(**kwargs): + supported_fields = [ + "up", + "open_registrations", + "federation_enabled", + "anonymous_can_listen", + "private", + ] + base_query = "SELECT DISTINCT on (domain) domain, * FROM checks{where_clause} ORDER BY domain, time DESC" + filters = [(key, value) for key, value in kwargs.items() if key in supported_fields] + if not filters: + return base_query.format(where_clause=""), [] + + params = [] + where_clauses = [] + for field, value in sorted(filters): + where_clauses.append(f"{field} = %s") + params.append(value) + + where_clause = " WHERE {}".format(" AND ".join(where_clauses)) + return base_query.format(where_clause=where_clause), params + + +@dict_cursor +async def get_domains(cursor, **kwargs): + query, params = get_domain_query(**kwargs) + await cursor.execute(query, params) + return list(await cursor.fetchall()) diff --git a/funkwhale_network/routes.py b/funkwhale_network/routes.py index 7cd614a0ebb93684eaf790c53592703afa1218a1..5f84b009b97987b3f9c7a701a0d246a824b60b5f 100644 --- a/funkwhale_network/routes.py +++ b/funkwhale_network/routes.py @@ -4,6 +4,9 @@ import json import os import urllib.parse +from webargs import fields +from webargs.aiohttpparser import parser + from . import crawler from . import db from . import exceptions @@ -13,6 +16,13 @@ from . import settings BASE_DIR = os.path.dirname(os.path.abspath(__file__)) STATIC_DIR = os.path.join(BASE_DIR, "static") +domain_filters = { + "up": fields.Bool(), + "open_registrations": fields.Bool(), + "anonymous_can_listen": fields.Bool(), + "federation_enabled": fields.Bool(), +} + def validate_domain(raw): if not raw: @@ -38,7 +48,9 @@ async def index(request): async def domains(request): if request.method == "GET": - rows = await db.get_latest_check_by_domain(request["conn"]) + + filters = await parser.parse(domain_filters, request) + rows = await db.get_domains(request["conn"], private=False, **filters) payload = { "count": len(rows), "previous": None, diff --git a/setup.cfg b/setup.cfg index 4446a1b98aab53e4dadb45d62be4b762dfb10248..ba9e3faab7e1022630d4478ddc0a1532d8b39020 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,7 @@ install_requires = semver asynctest django-environ + webargs [options.entry_points] console_scripts = funkwhale-network = funkwhale_network.cli:main diff --git a/tests/test_crawler.py b/tests/test_crawler.py index b30a81f11e4114bf06d43be58951c31be22fc48f..1a0928f9b84043f9709597b606a179e9c4775996 100644 --- a/tests/test_crawler.py +++ b/tests/test_crawler.py @@ -1,5 +1,6 @@ import aiohttp import marshmallow +import psycopg2 import pytest from funkwhale_network import crawler, serializers @@ -177,7 +178,7 @@ async def test_clean_check_result(): assert crawler.clean_check(check, data) == expected -async def test_save_check(populated_db, db_cursor, db_conn, factories): +async def test_save_check(populated_db, db_conn, factories): await factories["Check"].c(domain="test.domain", private=False) await serializers.create_domain(db_conn, {"name": "test.domain"}) data = { @@ -207,14 +208,18 @@ async def test_save_check(populated_db, db_cursor, db_conn, factories): sql = "SELECT * from checks ORDER BY time DESC" result = await crawler.save_check(db_conn, data) - await db_cursor.execute(sql) - row = await db_cursor.fetchone() - data["time"] = result["time"] - assert data == result - assert row == data - - await db_cursor.execute("SELECT * FROM domains WHERE name = %s", ["test.domain"]) - domain = await db_cursor.fetchone() + async with db_conn.cursor( + cursor_factory=psycopg2.extras.RealDictCursor + ) as db_cursor: + await db_cursor.execute(sql) + row = await db_cursor.fetchone() + data["time"] = result["time"] + assert data == result + assert row == data + await db_cursor.execute( + "SELECT * FROM domains WHERE name = %s", ["test.domain"] + ) + domain = await db_cursor.fetchone() assert domain["node_name"] == "Test Domain" @@ -241,6 +246,7 @@ async def test_private_domain_delete_past_checks( sql = "SELECT * from checks" assert await crawler.save_check(db_conn, data) is None - await db_cursor.execute(sql) - result = await db_cursor.fetchall() + async with db_conn.cursor() as db_cursor: + await db_cursor.execute(sql) + result = await db_cursor.fetchall() assert result == [] diff --git a/tests/test_db.py b/tests/test_db.py index fa66f12f1c46c2ba57b3a4dd940ee44e07fe4ce1..c44e130a92aa9e3c8152971273f5c2d9b3303524 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,3 +1,5 @@ +import pytest + from funkwhale_network import db @@ -62,3 +64,34 @@ async def test_get_stats(factories, db_conn): "listenings": {"total": 42}, } assert await db.get_stats(db_conn) == expected + + +@pytest.mark.parametrize( + "kwargs, expected_query, expected_params", + [ + ( + {}, + "SELECT DISTINCT on (domain) domain, * FROM checks ORDER BY domain, time DESC", + [], + ), + ( + {"up": True}, + "SELECT DISTINCT on (domain) domain, * FROM checks WHERE up = %s ORDER BY domain, time DESC", + [True], + ), + ( + {"up": True, "open_registrations": False}, + "SELECT DISTINCT on (domain) domain, * FROM checks WHERE open_registrations = %s AND up = %s ORDER BY domain, time DESC", + [False, True], + ), + ( + {"up": True, "private": False}, + "SELECT DISTINCT on (domain) domain, * FROM checks WHERE private = %s AND up = %s ORDER BY domain, time DESC", + [False, True], + ), + ], +) +def test_get_domain_query(kwargs, expected_query, expected_params): + query, params = db.get_domain_query(**kwargs) + assert query == expected_query + assert params == expected_params diff --git a/tests/test_routes.py b/tests/test_routes.py index 993f74497be4f961ade261e856376a9ed4743965..1596ac701479fdab5123fc101c39ad816cb8a383 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -1,13 +1,15 @@ +import pytest + from funkwhale_network import serializers async def test_domains_get(client, factories): + await factories["Check"].c(private=True) checks = sorted( [ await factories["Check"].c(private=False), await factories["Check"].c(private=False), - await factories["Check"].c(private=False), ], key=lambda o: o["domain"], ) @@ -21,6 +23,26 @@ async def test_domains_get(client, factories): } +@pytest.mark.parametrize( + "field", ["up", "open_registrations", "anonymous_can_listen", "federation_enabled"] +) +@pytest.mark.parametrize("value", [True, False]) +async def test_domains_get_filter_up(client, factories, field, value): + + await factories["Check"].c(private=True) + await factories["Check"].c(private=False, **{field: not value}), + check = await factories["Check"].c(private=False, **{field: value}) + + resp = await client.get("/api/domains", params={field: str(value)}) + assert resp.status == 200 + assert await resp.json() == { + "count": 1, + "next": None, + "previous": None, + "results": [serializers.serialize_domain_from_check(check)], + } + + async def test_domains_create(client, coroutine_mock, mocker): payload = {"name": "test.domain"} mocker.patch("funkwhale_network.crawler.fetch_nodeinfo", coroutine_mock())