Skip to content
Snippets Groups Projects
Verified Commit 9839d1de authored by Eliot Berriot's avatar Eliot Berriot
Browse files

Added basic validation for nodeinfo data

parent 5e17c3e3
No related branches found
No related tags found
No related merge requests found
from . import schemas
async def fetch(session, domain):
nodeinfo = await get_well_known_data(session, domain)
data = await get_nodeinfo(session, nodeinfo)
return data
async def get_well_known_data(session, domain, protocol="https"):
url = f"https://{domain}/.well-known/nodeinfo"
response = await session.get(url)
return await response.json()
async def get_nodeinfo(session, nodeinfo):
for link in nodeinfo.get("links", []):
if link["rel"] == "http://nodeinfo.diaspora.software/ns/schema/2.0":
response = await session.get(link["href"])
return await response.json()
raise
def clean_nodeinfo(data):
schema = schemas.NodeInfo2Schema()
return schema.load(data)
import marshmallow
import semver
import re
class VersionField(marshmallow.fields.Str):
def deserialize(self, value, *args, **kwargs):
value = super().deserialize(value, *args, **kwargs)
try:
return semver.parse(value)
except ValueError:
# funkwhale does not always include the patch version, so we add the 0 ourself and
# try again
try:
v_regex = r"(\d+\.\d+)"
match = re.match(v_regex, value)
if match and match[0]:
new_version = f"{match[0]}.0"
return semver.parse(value.replace(match[0], new_version, 1))
raise ValueError()
except (ValueError, IndexError):
raise marshmallow.ValidationError(
f"{value} is not a semver version number"
)
return value
class SoftwareSchema(marshmallow.Schema):
name = marshmallow.fields.String(
required=True, validate=[marshmallow.validate.OneOf(["funkwhale", "Funkwhale"])]
)
version = VersionField(required=True)
"""
"openRegistrations": False,
"usage": {"users": {"total": 78}},
"metadata": {
"private": False,
"nodeName": "Funkwhale 101",
"library": {
"federationEnabled": True,
"federationNeedsApproval": True,
"anonymousCanListen": True,
"tracks": {"total": 98552},
"artists": {"total": 9831},
"albums": {"total": 10872},
"music": {"hours": 7650.678055555555},
},
"usage": {
"favorites": {"tracks": {"total": 1683}},
"listenings": {"total": 50294},
},
},
}
"""
class StatisticsSchema(marshmallow.Schema):
total = marshmallow.fields.Integer(required=False)
class UsageSchema(marshmallow.Schema):
users = marshmallow.fields.Nested(StatisticsSchema, required=False)
class LibraryMetadataSchema(marshmallow.Schema):
anonymousCanListen = marshmallow.fields.Boolean(required=True)
federationEnabled = marshmallow.fields.Boolean(required=True)
class MetadataSchema(marshmallow.Schema):
nodeName = marshmallow.fields.String(required=True)
private = marshmallow.fields.Boolean(required=True)
library = usage = marshmallow.fields.Nested(LibraryMetadataSchema, required=False)
class NodeInfo2Schema(marshmallow.Schema):
software = marshmallow.fields.Nested(SoftwareSchema)
openRegistrations = marshmallow.fields.Boolean(required=True)
usage = marshmallow.fields.Nested(UsageSchema, required=False)
metadata = marshmallow.fields.Nested(MetadataSchema, required=False)
......@@ -21,6 +21,8 @@ install_requires =
aiopg
aiohttp
arq
marshmallow
semver
[options.entry_points]
console_scripts =
......@@ -30,6 +32,7 @@ console_scripts =
dev = ipdb
pytest
pytest-mock
aioresponses
[options.packages.find]
exclude =
......
import os
import pytest
from aiohttp import web
import aiohttp
from aioresponses import aioresponses
import funkwhale_network
from funkwhale_network import db
import os
pytest_plugins = "aiohttp.pytest_plugin"
@pytest.fixture
def client(loop, aiohttp_client, populated_db, db_pool):
app = web.Application(middlewares=funkwhale_network.MIDDLEWARES)
app = aiohttp.web.Application(middlewares=funkwhale_network.MIDDLEWARES)
funkwhale_network.prepare_app(app, pool=db_pool)
yield loop.run_until_complete(aiohttp_client(app))
......@@ -38,3 +41,15 @@ async def populated_db(db_pool):
await db.create(conn)
yield conn
await db.clear(conn)
@pytest.fixture
def responses():
with aioresponses() as m:
yield m
@pytest.fixture
async def session(loop):
async with aiohttp.ClientSession() as session:
yield session
from funkwhale_network import crawler
async def test_fetch(session, responses):
domain = "test.domain"
well_known_payload = {
"links": [
{
"rel": "http://nodeinfo.diaspora.software/ns/schema/2.0",
"href": "https://test.domain/nodeinfo/2.0/",
}
]
}
payload = {"hello": "world"}
responses.get(
"https://test.domain/.well-known/nodeinfo", payload=well_known_payload
)
responses.get("https://test.domain/nodeinfo/2.0/", payload=payload)
result = await crawler.fetch(session, domain)
assert result == payload
def test_validate_data(populated_db):
payload = {
"version": "2.0",
"software": {"name": "funkwhale", "version": "0.18-dev+git.b575999e"},
"openRegistrations": False,
"usage": {"users": {"total": 78}},
"metadata": {
"private": False,
"nodeName": "Funkwhale 101",
"library": {
"federationEnabled": True,
"federationNeedsApproval": True,
"anonymousCanListen": True,
"tracks": {"total": 98552},
"artists": {"total": 9831},
"albums": {"total": 10872},
"music": {"hours": 7650.678055555555},
},
"usage": {
"favorites": {"tracks": {"total": 1683}},
"listenings": {"total": 50294},
},
},
}
expected = {
"software": {
"name": "funkwhale",
"version": {
"major": 0,
"minor": 18,
"patch": 0,
"prerelease": "dev",
"build": "git.b575999e",
},
},
"openRegistrations": False,
"usage": {"users": {"total": 78}},
"metadata": {
"private": False,
"nodeName": "Funkwhale 101",
"library": {"federationEnabled": True, "anonymousCanListen": True},
},
}
result = crawler.clean_nodeinfo(payload)
assert result.data == expected
import pytest
from funkwhale_network import schemas
@pytest.mark.parametrize(
"value, expected",
[
(
"1.2.3-dev+build-1",
{
"major": 1,
"minor": 2,
"patch": 3,
"prerelease": "dev",
"build": "build-1",
},
),
(
"1.2-dev+build-1",
{
"major": 1,
"minor": 2,
"patch": 0,
"prerelease": "dev",
"build": "build-1",
},
),
(
"1.2+build-1",
{
"major": 1,
"minor": 2,
"patch": 0,
"prerelease": None,
"build": "build-1",
},
),
(
"1.2",
{"major": 1, "minor": 2, "patch": 0, "prerelease": None, "build": None},
),
],
)
def test_validate_version_number(value, expected):
assert schemas.VersionField().deserialize(value) == expected
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment