import aiohttp
import asyncio
import click
import functools
import logging.config
import ssl
import sys

from . import output

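# SSL protocol classes whose late shutdown errors we may need to silence;
# uvloop ships its own SSLProtocol, so include it when uvloop is installed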
SSL_PROTOCOLS = (asyncio.sslproto.SSLProtocol,)
try:
    import uvloop.loop
except ImportError:
    pass
else:
    SSL_PROTOCOLS = (*SSL_PROTOCOLS, uvloop.loop.SSLProtocol)


def ignore_aiohttp_ssl_error(loop):
    """Ignore aiohttp #3535 / cpython #13548 issue with SSL data after close

    There is an issue in Python 3.7 up to 3.7.3 that over-reports a
    ssl.SSLError fatal error (ssl.SSLError: [SSL: KRB5_S_INIT] application data
    after close notify (_ssl.c:2609)) after we are already done with the
    connection. See GitHub issues aio-libs/aiohttp#3535 and
    python/cpython#13548.

    Given a loop, this sets up an exception handler that ignores this specific
    exception, but passes everything else on to the previous exception handler
    this one replaces.

    Checks for fixed Python versions, disabling itself when running on 3.7.4+
    or 3.8.

    """
    if sys.version_info >= (3, 7, 4):
        return

    orig_handler = loop.get_exception_handler()

    def ignore_ssl_error(loop, context):
        if context.get("message") in {
            "SSL error in data received",
            "Fatal error on transport",
            "SSL handshake failed",
            "[SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error",
        }:
            # validate we have the right exception, transport and protocol
            exception = context.get("exception")
            protocol = context.get("protocol")
            if (
                isinstance(exception, ssl.SSLError)
                and exception.reason in ("KRB5_S_INIT", "TLSV1_ALERT_INTERNAL_ERROR")
                and isinstance(protocol, SSL_PROTOCOLS)
            ):
                if loop.get_debug():
                    asyncio.log.logger.debug(
                        "Ignoring asyncio SSL error: %s", exception.reason
                    )
                return
        if orig_handler is not None:
            orig_handler(loop, context)
        else:
            loop.default_exception_handler(context)

    loop.set_exception_handler(ignore_ssl_error)


def async_command(f):
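    """Run the decorated coroutine command to completion on the event loop."""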
    def wrapper(*args, **kwargs):
        loop = asyncio.get_event_loop()
        ignore_aiohttp_ssl_error(loop)
        return loop.run_until_complete(f(*args, **kwargs))

    return functools.update_wrapper(wrapper, f)


def conn_command(f):
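    """Inject a pooled database connection into the decorated command as ``conn``."""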
    async def wrapper(*args, **kwargs):
        from . import db
        from . import settings

        pool = await db.get_pool(settings.DB_DSN)
        try:
            async with pool.acquire() as conn:
                kwargs["conn"] = conn
                return await f(*args, **kwargs)
        finally:
            pool.close()
            await pool.wait_closed()

    return functools.update_wrapper(wrapper, f)


@click.group()
def cli():
    pass


@cli.group()
def db():
    """
    Database-related commands (migrate, clear…)
    """
    pass


@cli.group()
def worker():
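    """
    Worker-related commands (start…)
    """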
    pass


@db.command()
@async_command
@conn_command
async def migrate(conn):
    """
    Create database tables.
    """
    from . import db

    await db.create(conn)


@db.command()
@async_command
@conn_command
async def clear(conn):
    """
    Drop database tables.
    """
    from . import db

    await db.clear(conn)


@cli.command()
def server():
    """
    Start web server.
    """
    from funkwhale_network import server
    from funkwhale_network import settings

    server.start(port=settings.PORT)


async def launch_domain_poll(pool, session, domain):
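    """Check a single domain, using one connection from the shared pool."""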
    from . import crawler

    async with pool.acquire() as conn:
        return await crawler.check(conn=conn, session=session, domain=domain)


@cli.command()
@click.argument("domain", type=str, nargs=-1)
@async_command
async def poll(domain):
    """
    Retrieve and store data for the specified domains.
    """
    from . import crawler
    from . import db
    from . import settings
    from . import worker

    pool = await db.get_pool(settings.DB_DSN)
    try:
        if not domain:
            click.echo("Polling all domains…")
            poller = worker.Crawler()
            return await poller.poll_all()

        kwargs = crawler.get_session_kwargs()
        async with aiohttp.ClientSession(**kwargs) as session:
            # wrap coroutines in tasks: asyncio.wait() does not accept bare
            # coroutines on recent Python versions
            tasks = [
                asyncio.ensure_future(launch_domain_poll(pool, session, d))
                for d in domain
            ]
            return await asyncio.wait(tasks)
    finally:
        pool.close()
        await pool.wait_closed()


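# sentinel default for --detail: lets the crawl command tell "--detail was not
# passed" (skip the per-domain table) apart from an explicit value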
NOOP = object()


@cli.command()
@click.argument("domain", type=str, nargs=-1)
@click.option("--use-public", is_flag=True)
@click.option("--detail", default=NOOP)
@click.option("--passes", type=click.INT, default=999)
@click.option("--sort", default="Active users (30d)")
@async_command
async def crawl(domain, use_public, detail, passes, sort):
    """
    Crawl the network starting from the given domain(s).
    """
    from . import crawler
    from . import settings

    kwargs = crawler.get_session_kwargs()
    async with aiohttp.ClientSession(**kwargs) as session:
        if use_public:
            url = "https://network.funkwhale.audio/api/domains?up=true"
            click.echo("Retrieving list of public pods from {}…".format(url))
            response = await session.get(url)
            payload = await response.json()
            # the public pod list replaces any domains passed on the command line
            domain = {d["name"] for d in payload["results"]}
        click.echo("Launching crawl with {} seed domains…".format(len(domain)))
        results = await crawler.crawl_all(
            session, *domain, stdout=click.echo, max_passes=passes
        )

    click.echo("Complete after {} passes:".format(results["pass_number"]))
    aggregate = aggregate_crawl_results(results["results"])

    if detail != NOOP:
        click.echo("")
        click.echo("Info per domain")
        click.echo("===============")
        click.echo("")

        if not detail:
            fields = [
                "Domain",
                "Active users (30d)",
                "Users",
                "Listenings",
                "Open registrations",
                "Anonymous access",
                "Private",
            ]
        else:
            fields = detail.split(",")

        click.echo(
            output.table(
                results["results"].values(), type="Domain", fields=fields, sort=sort
            )
        )

    click.echo("")
    click.echo("Aggregated data")
    click.echo("===============")
    click.echo("")
    click.echo(
        output.obj_table(
            aggregate,
            type="Summary",
            fields=[
                "Domains",
                "Active users (30d)",
                "Active users (180d)",
                "Users",
                "Listenings",
                "Tracks",
                "Albums",
                "Artists",
                "Hours of music",
                "Open registrations",
                "Federation enabled",
                "Anonymous access",
                "Private",
            ],
        )
    )


def aggregate_crawl_results(domains_info):
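    # boolean fields: count how many pods report a truthy value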
    def count_true(values):
        return sum([1 for v in values if v])

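    # numeric fields: sum only truthy values, so pods that report None for a
    # stat do not break sum()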
    def permissive_sum(values):
        return sum([v for v in values if v])

    fields = {
        "domain": len,
        "usage_users_total": permissive_sum,
        "usage_users_active_half_year": permissive_sum,
        "usage_users_active_month": permissive_sum,
        "usage_listenings_total": permissive_sum,
        "usage_downloads_total": permissive_sum,
        "library_tracks_total": permissive_sum,
        "library_albums_total": permissive_sum,
        "library_artists_total": permissive_sum,
        "library_music_hours": permissive_sum,
        "open_registrations": count_true,
        "federation_enabled": count_true,
        "anonymous_can_listen": count_true,
        "private": count_true,
    }
    aggregate = {}
    for field, handler in fields.items():
        values = []
        for info in domains_info.values():
            values.append(info[field])
        aggregate[field] = handler(values)

    return aggregate


@worker.command()
@click.option("-v", "--verbose", is_flag=True)
@click.option("--check", is_flag=True)
def start(*, check, verbose):
    # worker = arq.worker.import_string("funkwhale_network.worker", "Worker")
    # logging.config.dictConfig(worker.logging_config(verbose))
    if check:
        pass
        # exit(worker.check_health())
    else:
        pass
        # arq.RunWorkerProcess("funkwhale_network.worker", "Worker", burst=False)