Skip to content
Snippets Groups Projects
Verified Commit 2bb27d17 authored by Georg Krause's avatar Georg Krause
Browse files

Fix most tests again

parent e419e32b
No related branches found
No related tags found
No related merge requests found
Pipeline #26428 failed with stages
in 8 minutes and 59 seconds
......@@ -83,8 +83,8 @@ def increment_stat(data, key, value):
data[key] += value
async def get_stats(conn):
checks = await get_latest_check_by_domain(conn)
async def get_stats():
checks = await get_latest_check_by_domain()
data = {
"users": {"total": 0, "activeMonth": 0, "activeHalfyear": 0},
"instances": {"total": 0, "anonymousCanListen": 0, "openRegistrations": 0},
......@@ -125,7 +125,7 @@ def get_domain_query(**kwargs):
async def get_domains(**kwargs):
conn = aiopg.connect(settings.DB_DSN)
conn = await aiopg.connect(settings.DB_DSN)
cursor = await conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
filters = kwargs.copy()
filters.setdefault("private", False)
......@@ -161,8 +161,8 @@ def should_keep(domain, filters):
return True
async def get_domain(cursor, name):
conn = aiopg.connect(settings.DB_DSN)
async def get_domain(name):
conn = await aiopg.connect(settings.DB_DSN)
cursor = await conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
await cursor.execute("SELECT * FROM domains WHERE name = %s", (name,))
return list(await cursor.fetchall())[0]
......@@ -52,7 +52,7 @@ async def domains(request):
if request.method == "GET":
filters = await parser.parse(domain_filters, request, location="querystring")
limit = int(request.query.get("limit", default=0))
rows = await db.get_domains(request["conn"], **filters)
rows = await db.get_domains(**{"filters": filters})
total = len(rows)
if limit:
rows = rows[:limit]
......@@ -93,7 +93,7 @@ async def domains(request):
{"error": f"Invalid domain name {payload['name']}"}, status=400
)
domain = await serializers.create_domain(request["conn"], payload)
domain = await serializers.create_domain(payload)
if domain:
payload = serializers.serialize_domain(domain)
return web.json_response(payload, status=201)
......
......@@ -72,7 +72,7 @@ class CheckFactory(DBFactory):
@classmethod
async def pre_create(cls, o):
await serializers.create_domain(cls._conn, {"name": o["domain"]})
await serializers.create_domain({"name": o["domain"]})
ALL = [("Check", CheckFactory, "checks"), ("Domain", DomainFactory, "domains")]
......@@ -42,7 +42,7 @@ async def test_check(db_conn, populated_db, session, mocker, coroutine_mock):
clean_check.assert_called_once_with(
{"up": True, "domain": "test.domain"}, {"cleaned": "nodeinfo"}
)
save_check.assert_called_once_with(db_conn, {"cleaned": "check"})
save_check.assert_called_once_with({"cleaned": "check"})
async def test_check_nodeinfo_connection_error(
......@@ -56,7 +56,7 @@ async def test_check_nodeinfo_connection_error(
save_check = mocker.patch.object(crawler, "save_check", coroutine_mock())
await crawler.check(db_conn, session, "test.domain")
fetch_nodeinfo.assert_called_once_with(session, "test.domain")
save_check.assert_called_once_with(db_conn, {"domain": "test.domain", "up": False})
save_check.assert_called_once_with({"domain": "test.domain", "up": False})
def test_clean_nodeinfo(populated_db):
......@@ -181,7 +181,7 @@ async def test_clean_check_result():
async def test_save_check(populated_db, db_conn, factories):
await factories["Check"].c(domain="test.domain", private=False)
await serializers.create_domain(db_conn, {"name": "test.domain"})
await serializers.create_domain({"name": "test.domain"})
data = {
"domain": "test.domain",
"node_name": "Test Domain",
......@@ -208,7 +208,7 @@ async def test_save_check(populated_db, db_conn, factories):
}
sql = "SELECT * from checks ORDER BY time DESC"
result = await crawler.save_check(db_conn, data)
result = await crawler.save_check(data)
async with db_conn.cursor(
cursor_factory=psycopg2.extras.RealDictCursor
......@@ -247,7 +247,7 @@ async def test_private_domain_delete_past_checks(
}
sql = "SELECT * from checks"
assert await crawler.save_check(db_conn, data) is None
assert await crawler.save_check(data) is None
async with db_conn.cursor() as db_cursor:
await db_cursor.execute(sql)
result = await db_cursor.fetchall()
......
import pytest
from funkwhale_network import db
......@@ -17,19 +15,19 @@ async def test_db_create(db_pool):
await db.clear(conn)
async def test_get_latest_checks_by_domain(factories, db_conn):
async def test_get_latest_checks_by_domain(factories):
await factories["Check"].c(domain="test1.domain", private=False)
check2 = await factories["Check"].c(domain="test1.domain", private=False)
check3 = await factories["Check"].c(domain="test2.domain", private=False)
expected = [check2, check3]
for check in expected:
domain = await db.get_domain(db_conn, check["domain"])
domain = await db.get_domain(check["domain"])
check["first_seen"] = domain["first_seen"]
check["node_name"] = domain["node_name"]
check["blocked"] = domain["blocked"]
check["name"] = domain["name"]
result = await db.get_latest_check_by_domain(db_conn)
result = await db.get_latest_check_by_domain()
assert len(result) == 2
for i, row in enumerate(result):
assert dict(row) == dict(expected[i])
......@@ -75,4 +73,4 @@ async def test_get_stats(factories, db_conn):
"listenings": {"total": 42},
"downloads": {"total": 63},
}
assert await db.get_stats(db_conn) == expected
assert await db.get_stats() == expected
......@@ -17,7 +17,7 @@ async def test_domains_get(db_conn, client, factories):
key=lambda o: o["domain"],
)
for check in checks:
domain = await db.get_domain(db_conn, check["domain"])
domain = await db.get_domain(check["domain"])
check["first_seen"] = domain["first_seen"]
check["node_name"] = domain["node_name"]
resp = await client.get("/api/domains")
......@@ -40,7 +40,7 @@ async def test_domains_get_page_size(db_conn, client, factories):
key=lambda o: o["domain"],
)
for check in checks:
domain = await db.get_domain(db_conn, check["domain"])
domain = await db.get_domain(check["domain"])
check["first_seen"] = domain["first_seen"]
check["node_name"] = domain["node_name"]
resp = await client.get("/api/domains", params={"limit": 1})
......@@ -66,7 +66,7 @@ async def test_domains_get_filter_up(db_conn, client, factories, field, value):
check = await factories["Check"].c(
domain=domain["name"], private=False, **{field: value}
)
domain = await db.get_domain(db_conn, check["domain"])
domain = await db.get_domain(check["domain"])
check["first_seen"] = domain["first_seen"]
check["node_name"] = domain["node_name"]
resp = await client.get("/api/domains", params={field: str(value)})
......@@ -84,7 +84,7 @@ async def test_domains_exclude_blocked(db_conn, client, factories):
blocked = await factories["Domain"].c(blocked=True)
await factories["Check"].c(private=False, domain=blocked["name"])
check = await factories["Check"].c(private=False)
domain = await db.get_domain(db_conn, check["domain"])
domain = await db.get_domain(check["domain"])
check["first_seen"] = domain["first_seen"]
check["node_name"] = domain["node_name"]
resp = await client.get("/api/domains")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment