Skip to content
Snippets Groups Projects
Verified Commit fbc27992 authored by Eliot Berriot's avatar Eliot Berriot
Browse files

Fixed filtering issue in domains api

parent 4f399746
No related branches found
No related tags found
No related merge requests found
......@@ -118,6 +118,20 @@ async def get_stats(conn):
def get_domain_query(**kwargs):
base_query = "SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name WHERE domains.blocked = false ORDER BY domain, time DESC"
return base_query.format(where_clause=""), []
@dict_cursor
async def get_domains(cursor, **kwargs):
filters = kwargs.copy()
filters.setdefault("private", False)
filters.setdefault("up", True)
query, params = get_domain_query()
await cursor.execute(query, params)
domains = list(await cursor.fetchall())
# we do the filtering in Python because I didn't figure how to filter on the latest check
# values only
supported_fields = dict(
[
("up", "up"),
......@@ -125,33 +139,23 @@ def get_domain_query(**kwargs):
("federation_enabled", "federation_enabled"),
("anonymous_can_listen", "anonymous_can_listen"),
("private", "private"),
("blocked", "domains.blocked"),
]
)
base_query = "SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name{where_clause} ORDER BY domain, time DESC"
filters = [
(supported_fields[key], value)
for key, value in kwargs.items()
for key, value in filters.items()
if key in supported_fields
]
if not filters:
return base_query.format(where_clause=""), []
params = []
where_clauses = []
for field, value in sorted(filters):
where_clauses.append(f"{field} = %s")
params.append(value)
domains = [d for d in domains if should_keep(d, filters)]
return domains
where_clause = " WHERE {}".format(" AND ".join(where_clauses))
return base_query.format(where_clause=where_clause), params
@dict_cursor
async def get_domains(cursor, **kwargs):
query, params = get_domain_query(**kwargs)
await cursor.execute(query, params)
return list(await cursor.fetchall())
def should_keep(domain, filters):
for key, value in filters:
if domain[key] != value:
return False
return True
@dict_cursor
......
......@@ -51,9 +51,7 @@ async def domains(request):
if request.method == "GET":
filters = await parser.parse(domain_filters, request)
limit = filters.pop("limit", 0)
rows = await db.get_domains(
request["conn"], private=False, blocked=False, **filters
)
rows = await db.get_domains(request["conn"], **filters)
total = len(rows)
if limit:
rows = rows[:limit]
......
......@@ -48,10 +48,10 @@ class DomainFactory(DBFactory):
class CheckFactory(DBFactory):
time = "NOW()"
up = factory.Faker("boolean")
up = True
domain = factory.Faker("domain_name")
open_registrations = factory.Faker("boolean")
private = factory.Faker("boolean")
private = False
federation_enabled = factory.Faker("boolean")
anonymous_can_listen = factory.Faker("boolean")
usage_users_total = factory.Faker("random_int")
......
......@@ -73,34 +73,3 @@ async def test_get_stats(factories, db_conn):
"listenings": {"total": 42},
}
assert await db.get_stats(db_conn) == expected
@pytest.mark.parametrize(
"kwargs, expected_query, expected_params",
[
(
{},
"SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name ORDER BY domain, time DESC",
[],
),
(
{"up": True},
"SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name WHERE up = %s ORDER BY domain, time DESC",
[True],
),
(
{"up": True, "open_registrations": False},
"SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name WHERE open_registrations = %s AND up = %s ORDER BY domain, time DESC",
[False, True],
),
(
{"up": True, "private": False},
"SELECT DISTINCT on (domain) domain, * FROM checks INNER JOIN domains ON checks.domain = domains.name WHERE private = %s AND up = %s ORDER BY domain, time DESC",
[False, True],
),
],
)
def test_get_domain_query(kwargs, expected_query, expected_params):
query, params = db.get_domain_query(**kwargs)
assert query == expected_query
assert params == expected_params
......@@ -57,10 +57,10 @@ async def test_domains_get_page_size(db_conn, client, factories):
)
@pytest.mark.parametrize("value", [True, False])
async def test_domains_get_filter_up(db_conn, client, factories, field, value):
await factories["Check"].c(private=True)
await factories["Check"].c(private=False, **{field: not value}),
check = await factories["Check"].c(private=False, **{field: value})
domain = factories['Domain']()
await factories["Check"].c(domain=domain['name'], private=True)
await factories["Check"].c(domain=domain['name'], private=False, **{field: not value}),
check = await factories["Check"].c(domain=domain['name'], private=False, **{field: value})
domain = await db.get_domain(db_conn, check["domain"])
check["first_seen"] = domain["first_seen"]
check["node_name"] = domain["node_name"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment