Skip to content
Snippets Groups Projects
Verified Commit fbc27992 authored by Eliot Berriot's avatar Eliot Berriot
Browse files

Fixed filtering issue in domains api

parent 4f399746
No related branches found
No related tags found
No related merge requests found
...@@ -118,6 +118,20 @@ async def get_stats(conn): ...@@ -118,6 +118,20 @@ async def get_stats(conn):
def get_domain_query(**kwargs): def get_domain_query(**kwargs):
base_query = "SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name WHERE domains.blocked = false ORDER BY domain, time DESC"
return base_query.format(where_clause=""), []
@dict_cursor
async def get_domains(cursor, **kwargs):
filters = kwargs.copy()
filters.setdefault("private", False)
filters.setdefault("up", True)
query, params = get_domain_query()
await cursor.execute(query, params)
domains = list(await cursor.fetchall())
# we do the filtering in Python because I didn't figure how to filter on the latest check
# values only
supported_fields = dict( supported_fields = dict(
[ [
("up", "up"), ("up", "up"),
...@@ -125,33 +139,23 @@ def get_domain_query(**kwargs): ...@@ -125,33 +139,23 @@ def get_domain_query(**kwargs):
("federation_enabled", "federation_enabled"), ("federation_enabled", "federation_enabled"),
("anonymous_can_listen", "anonymous_can_listen"), ("anonymous_can_listen", "anonymous_can_listen"),
("private", "private"), ("private", "private"),
("blocked", "domains.blocked"),
] ]
) )
base_query = "SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name{where_clause} ORDER BY domain, time DESC"
filters = [ filters = [
(supported_fields[key], value) (supported_fields[key], value)
for key, value in kwargs.items() for key, value in filters.items()
if key in supported_fields if key in supported_fields
] ]
if not filters:
return base_query.format(where_clause=""), []
params = [] domains = [d for d in domains if should_keep(d, filters)]
where_clauses = [] return domains
for field, value in sorted(filters):
where_clauses.append(f"{field} = %s")
params.append(value)
where_clause = " WHERE {}".format(" AND ".join(where_clauses))
return base_query.format(where_clause=where_clause), params
def should_keep(domain, filters):
@dict_cursor for key, value in filters:
async def get_domains(cursor, **kwargs): if domain[key] != value:
query, params = get_domain_query(**kwargs) return False
await cursor.execute(query, params) return True
return list(await cursor.fetchall())
@dict_cursor @dict_cursor
......
...@@ -51,9 +51,7 @@ async def domains(request): ...@@ -51,9 +51,7 @@ async def domains(request):
if request.method == "GET": if request.method == "GET":
filters = await parser.parse(domain_filters, request) filters = await parser.parse(domain_filters, request)
limit = filters.pop("limit", 0) limit = filters.pop("limit", 0)
rows = await db.get_domains( rows = await db.get_domains(request["conn"], **filters)
request["conn"], private=False, blocked=False, **filters
)
total = len(rows) total = len(rows)
if limit: if limit:
rows = rows[:limit] rows = rows[:limit]
......
...@@ -48,10 +48,10 @@ class DomainFactory(DBFactory): ...@@ -48,10 +48,10 @@ class DomainFactory(DBFactory):
class CheckFactory(DBFactory): class CheckFactory(DBFactory):
time = "NOW()" time = "NOW()"
up = factory.Faker("boolean") up = True
domain = factory.Faker("domain_name") domain = factory.Faker("domain_name")
open_registrations = factory.Faker("boolean") open_registrations = factory.Faker("boolean")
private = factory.Faker("boolean") private = False
federation_enabled = factory.Faker("boolean") federation_enabled = factory.Faker("boolean")
anonymous_can_listen = factory.Faker("boolean") anonymous_can_listen = factory.Faker("boolean")
usage_users_total = factory.Faker("random_int") usage_users_total = factory.Faker("random_int")
......
...@@ -73,34 +73,3 @@ async def test_get_stats(factories, db_conn): ...@@ -73,34 +73,3 @@ async def test_get_stats(factories, db_conn):
"listenings": {"total": 42}, "listenings": {"total": 42},
} }
assert await db.get_stats(db_conn) == expected assert await db.get_stats(db_conn) == expected
@pytest.mark.parametrize(
"kwargs, expected_query, expected_params",
[
(
{},
"SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name ORDER BY domain, time DESC",
[],
),
(
{"up": True},
"SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name WHERE up = %s ORDER BY domain, time DESC",
[True],
),
(
{"up": True, "open_registrations": False},
"SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name WHERE open_registrations = %s AND up = %s ORDER BY domain, time DESC",
[False, True],
),
(
{"up": True, "private": False},
"SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name WHERE private = %s AND up = %s ORDER BY domain, time DESC",
[False, True],
),
],
)
def test_get_domain_query(kwargs, expected_query, expected_params):
query, params = db.get_domain_query(**kwargs)
assert query == expected_query
assert params == expected_params
...@@ -57,10 +57,10 @@ async def test_domains_get_page_size(db_conn, client, factories): ...@@ -57,10 +57,10 @@ async def test_domains_get_page_size(db_conn, client, factories):
) )
@pytest.mark.parametrize("value", [True, False]) @pytest.mark.parametrize("value", [True, False])
async def test_domains_get_filter_up(db_conn, client, factories, field, value): async def test_domains_get_filter_up(db_conn, client, factories, field, value):
domain = factories['Domain']()
await factories["Check"].c(private=True) await factories["Check"].c(domain=domain['name'], private=True)
await factories["Check"].c(private=False, **{field: not value}), await factories["Check"].c(domain=domain['name'], private=False, **{field: not value}),
check = await factories["Check"].c(private=False, **{field: value}) check = await factories["Check"].c(domain=domain['name'], private=False, **{field: value})
domain = await db.get_domain(db_conn, check["domain"]) domain = await db.get_domain(db_conn, check["domain"])
check["first_seen"] = domain["first_seen"] check["first_seen"] = domain["first_seen"]
check["node_name"] = domain["node_name"] check["node_name"] = domain["node_name"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment