Skip to content
Snippets Groups Projects
routes.py 3.25 KiB
Newer Older
  • Learn to ignore specific revisions
import json
import os
import urllib.parse

import aiohttp
from aiohttp import web
from webargs import fields
from webargs.aiohttpparser import parser

from . import crawler
from . import db
from . import exceptions
from . import serializers
from . import settings
    
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    STATIC_DIR = os.path.join(BASE_DIR, "static")
    
    
    domain_filters = {
        "up": fields.Bool(),
        "open_registrations": fields.Bool(),
        "anonymous_can_listen": fields.Bool(),
        "federation_enabled": fields.Bool(),
    
    
    def validate_domain(raw):
    
        raw = raw.split("://")[-1]
    
        if not raw:
            raise ValueError()
    
        if not isinstance(raw, str):
            raise ValueError()
    
        url = f"http://{raw}/"
        v = urllib.parse.urlparse(url).hostname
    
    Eliot Berriot's avatar
    Eliot Berriot committed
        if not v or "." not in v:
    
            raise ValueError()
        return v
    
    
    async def index(request):
        with open(os.path.join(BASE_DIR, "index.html"), "r") as f:
            index_html = f.read()
    
    Eliot Berriot's avatar
    Eliot Berriot committed
        return web.Response(
            text=index_html.format(settings=settings), content_type="text/html"
        )
    
    
    async def domains(request):
        if request.method == "GET":
    
            filters = await parser.parse(domain_filters, request, location="querystring")
    
            limit = int(request.query.get("limit", default=0))
    
            rows = await db.get_domains(request["conn"], **filters)
    
            total = len(rows)
            if limit:
                rows = rows[:limit]
    
            if request.query.get("format") == "rss":
                response = web.Response(
                    text=serializers.serialize_rss_feed_from_checks(rows)
                )
                response.headers["Content-Type"] = "application/rss+xml"
                return response
            else:
                payload = {
                    "count": total,
                    "previous": None,
                    "next": None,
                    "results": [
                        serializers.serialize_domain_from_check(check) for check in rows
                    ],
                }
                return web.json_response(payload)
    
            try:
                payload = await request.json()
            except json.decoder.JSONDecodeError:
                payload = await request.post()
            try:
                payload = {"name": validate_domain(payload["name"])}
            except (TypeError, KeyError, AttributeError, ValueError):
                return web.json_response(
                    {"error": f"Invalid payload {payload}"}, status=400
                )
            try:
                kwargs = crawler.get_session_kwargs()
                async with aiohttp.ClientSession(**kwargs) as session:
                    await crawler.fetch_nodeinfo(session, payload["name"])
            except (aiohttp.client_exceptions.ClientError, exceptions.CrawlerError) as e:
                return web.json_response(
                    {"error": f"Invalid domain name {payload['name']}"}, status=400
                )
    
    
            domain = await serializers.create_domain(request["conn"], payload)
    
            if domain:
                payload = serializers.serialize_domain(domain)
                return web.json_response(payload, status=201)
            else:
                # already exist
                return web.json_response({}, status=204)
    
    Eliot Berriot's avatar
    Eliot Berriot committed
    
    
    async def stats(request):
        payload = await db.get_stats(request["conn"])
        return web.json_response(payload)