diff --git a/funkwhale_network/db.py b/funkwhale_network/db.py index 2994f3f0b99b03fb60bab87ee2850b0580b6c261..0e0d9fb9f8994e7286fab59d81162644ce1376e4 100644 --- a/funkwhale_network/db.py +++ b/funkwhale_network/db.py @@ -118,6 +118,20 @@ async def get_stats(conn): def get_domain_query(**kwargs): + base_query = "SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name WHERE domains.blocked = false ORDER BY domain, time DESC" + return base_query.format(where_clause=""), [] + + +@dict_cursor +async def get_domains(cursor, **kwargs): + filters = kwargs.copy() + filters.setdefault("private", False) + filters.setdefault("up", True) + query, params = get_domain_query() + await cursor.execute(query, params) + domains = list(await cursor.fetchall()) + # we do the filtering in Python because I didn't figure how to filter on the latest check + # values only supported_fields = dict( [ ("up", "up"), @@ -125,33 +139,23 @@ def get_domain_query(**kwargs): ("federation_enabled", "federation_enabled"), ("anonymous_can_listen", "anonymous_can_listen"), ("private", "private"), - ("blocked", "domains.blocked"), ] ) - base_query = "SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name{where_clause} ORDER BY domain, time DESC" filters = [ (supported_fields[key], value) - for key, value in kwargs.items() + for key, value in filters.items() if key in supported_fields ] - if not filters: - return base_query.format(where_clause=""), [] - params = [] - where_clauses = [] - for field, value in sorted(filters): - where_clauses.append(f"{field} = %s") - params.append(value) + domains = [d for d in domains if should_keep(d, filters)] + return domains - where_clause = " WHERE {}".format(" AND ".join(where_clauses)) - return base_query.format(where_clause=where_clause), params - -@dict_cursor -async def get_domains(cursor, **kwargs): - query, params = get_domain_query(**kwargs) - await cursor.execute(query, params) - return list(await cursor.fetchall()) +def should_keep(domain, filters): + for key, value in filters: + if domain[key] != value: + return False + return True @dict_cursor diff --git a/funkwhale_network/routes.py b/funkwhale_network/routes.py index ba6a5fb4d4f6a9e487837d3eaf2fca3dee292f04..62a3c6841051432ecbf084b05b6f4ded05afc1fb 100644 --- a/funkwhale_network/routes.py +++ b/funkwhale_network/routes.py @@ -51,9 +51,7 @@ async def domains(request): if request.method == "GET": filters = await parser.parse(domain_filters, request) limit = filters.pop("limit", 0) - rows = await db.get_domains( - request["conn"], private=False, blocked=False, **filters - ) + rows = await db.get_domains(request["conn"], **filters) total = len(rows) if limit: rows = rows[:limit] diff --git a/tests/factories.py b/tests/factories.py index 94bc6655a329ef9e588f37206368d409d900bf89..129098b9c2651c5f6723332eb9821717c1583c1e 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -48,10 +48,10 @@ class DomainFactory(DBFactory): class CheckFactory(DBFactory): time = "NOW()" - up = factory.Faker("boolean") + up = True domain = factory.Faker("domain_name") open_registrations = factory.Faker("boolean") - private = factory.Faker("boolean") + private = False federation_enabled = factory.Faker("boolean") anonymous_can_listen = factory.Faker("boolean") usage_users_total = factory.Faker("random_int") diff --git a/tests/test_db.py b/tests/test_db.py index cbd1d918f12479d1a81c76338bbe9a6c1d4ced60..5386c7d9247c0bf2d14559c0b4b4988063f03e36 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -73,34 +73,3 @@ async def test_get_stats(factories, db_conn): "listenings": {"total": 42}, } assert await db.get_stats(db_conn) == expected - - -@pytest.mark.parametrize( - "kwargs, expected_query, expected_params", - [ - ( - {}, - "SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name ORDER BY domain, time DESC", - [], - ), - ( - {"up": True}, - "SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name WHERE up = %s ORDER BY domain, time DESC", - [True], - ), - ( - {"up": True, "open_registrations": False}, - "SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name WHERE open_registrations = %s AND up = %s ORDER BY domain, time DESC", - [False, True], - ), - ( - {"up": True, "private": False}, - "SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name WHERE private = %s AND up = %s ORDER BY domain, time DESC", - [False, True], - ), - ], -) -def test_get_domain_query(kwargs, expected_query, expected_params): - query, params = db.get_domain_query(**kwargs) - assert query == expected_query - assert params == expected_params diff --git a/tests/test_routes.py b/tests/test_routes.py index 3c5fbc105f230320191ffa511e2f2e5b7eb08770..2ddc8f7cb0d793aba7bce9cd7c3b9bb97fa4f49e 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -57,10 +57,10 @@ async def test_domains_get_page_size(db_conn, client, factories): ) @pytest.mark.parametrize("value", [True, False]) async def test_domains_get_filter_up(db_conn, client, factories, field, value): - - await factories["Check"].c(private=True) - await factories["Check"].c(private=False, **{field: not value}), - check = await factories["Check"].c(private=False, **{field: value}) + domain = factories['Domain']() + await factories["Check"].c(domain=domain['name'], private=True) + await factories["Check"].c(domain=domain['name'], private=False, **{field: not value}), + check = await factories["Check"].c(domain=domain['name'], private=False, **{field: value}) domain = await db.get_domain(db_conn, check["domain"]) check["first_seen"] = domain["first_seen"] check["node_name"] = domain["node_name"]