diff --git a/funkwhale_network/routes.py b/funkwhale_network/routes.py index 07ef63af627bbb5d20706442b7b1d402b9cc4b77..113a7ea9375e7d8950d8bbfe6c3c8f7edcedaf26 100644 --- a/funkwhale_network/routes.py +++ b/funkwhale_network/routes.py @@ -26,6 +26,7 @@ domain_filters = { def validate_domain(raw): + raw = raw.split('://')[-1] if not raw: raise ValueError() diff --git a/tests/test_routes.py b/tests/test_routes.py index 2fed96d8d32eb84825aeab9d1ce1574424c9fbab..0d8103edaa015c26fbb310265cf1f80da4e0c190 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -2,6 +2,7 @@ import datetime import pytest from funkwhale_network import db +from funkwhale_network import routes from funkwhale_network import serializers @@ -129,3 +130,12 @@ async def test_domains_rss(db_conn, factories, client, mocker): assert resp.status == 200 assert await resp.text() == expected assert resp.headers["content-type"] == "application/rss+xml" + + +@pytest.mark.parametrize('input,expected', [ + ('example.com', 'example.com'), + ('http://example.com', 'example.com'), + ('https://example.com', 'example.com'), +]) +def test_validate_domain(input, expected): + assert routes.validate_domain(input) == expected