import asyncio
import functools
import sys

import aiohttp
import arq
import click

from funkwhale_network.db import DB
from funkwhale_network.worker import WorkerSettings

from . import output


def async_command(f):
    """
    Wrap the coroutine function ``f`` so it can be called synchronously.

    Used to turn ``async def`` click callbacks into regular functions: each
    call runs ``f(*args, **kwargs)`` to completion on a fresh event loop and
    returns its result. Metadata (name, docstring, ``__dict__`` — where click
    stores its parameters) is copied from ``f``.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # asyncio.get_event_loop() with no running loop is deprecated since
        # Python 3.10 and no longer creates a loop implicitly on recent
        # versions; asyncio.run() creates, runs and closes a loop per call.
        return asyncio.run(f(*args, **kwargs))

    return wrapper


def conn_command(f):
    """
    Decorate the coroutine function ``f`` so it is handed a DB connection.

    On every call a pool is obtained via ``db.get_pool(settings.DB_DSN)``, one
    connection is acquired from it and passed to ``f`` as the ``conn`` keyword
    argument (overriding any caller-supplied ``conn``). The pool is closed and
    awaited afterwards, whether or not ``f`` raised.
    """

    @functools.wraps(f)
    async def wrapper(*args, **kwargs):
        from . import db as db_module, settings

        pool = await db_module.get_pool(settings.DB_DSN)
        try:
            async with pool.acquire() as connection:
                call_kwargs = {**kwargs, "conn": connection}
                return await f(*args, **call_kwargs)
        finally:
            pool.close()
            await pool.wait_closed()

    return wrapper


@click.group()
def cli():
    """
    Funkwhale network command-line interface (server, crawl, poll, db, worker).
    """


# Sub-commands attach themselves with `@db.command()`; the group callback has
# nothing to do, so its body is only the help docstring shown by `--help`.
@cli.group()
def db():
    """
    Database related commands (migrate, clear…)
    """


@cli.group()
def worker():
    """
    Background worker related commands (start…)
    """


@db.command()
@async_command
async def migrate():
    """
    Create database tables.
    """
    # nl=False keeps "Migrating …" and "… Done" on one line; click.echo also
    # flushes, so the progress message is visible while the migration runs.
    click.echo("Migrating …", nl=False)

    # `database` (not `db`) avoids shadowing the module-level `db` group.
    async with DB() as database:
        await database.create()

    # Terminate the line so the shell prompt does not follow "Done" directly.
    click.echo(" … Done")


@db.command()
@async_command
async def clear():
    """
    Drop database tables.
    """
    # `database` (not `db`) avoids shadowing the module-level `db` group.
    async with DB() as database:
        await database.clear()


@cli.command()
def server():
    """
    Start web server.
    """
    # Imported lazily so other commands do not pay for loading the web stack;
    # aliased because the module name collides with this function's name.
    from funkwhale_network import server as web_server
    from funkwhale_network import settings

    web_server.start(port=settings.PORT)


async def launch_domain_poll(session, domain):
    """
    Run ``crawler.check`` for one ``domain`` using the shared HTTP ``session``.

    Returns whatever ``crawler.check`` returns.
    """
    from . import crawler

    check = crawler.check
    outcome = await check(session=session, domain=domain)
    return outcome


@cli.command()
@click.argument("domains", type=str, nargs=-1)
@async_command
async def poll(domains):
    """
    Retrieve and store data for the specified domains.
    """
    from . import crawler

    await DB.create_pool()
    # Release the pool in `finally` so it is not leaked if polling fails.
    try:
        if not domains:
            # No explicit domains given: poll every domain already stored.
            async with DB() as db:
                domains_db = await db.get_all_domains()
                domains = [d["name"] for d in domains_db]

        if not domains:
            # asyncio cannot wait on an empty set of tasks; nothing to do.
            click.echo("No domains to poll.")
            return

        kwargs = crawler.get_session_kwargs()
        async with aiohttp.ClientSession(**kwargs) as session:
            # gather() accepts bare coroutines (asyncio.wait() rejects them
            # since Python 3.11). return_exceptions=True keeps one failing
            # domain from aborting the others, and the errors are returned
            # here instead of being silently lost.
            results = await asyncio.gather(
                *(launch_domain_poll(session, d) for d in domains),
                return_exceptions=True,
            )

        for domain, result in zip(domains, results):
            if isinstance(result, BaseException):
                click.echo(f"Error while polling {domain}: {result!r}", err=True)
    finally:
        await DB.close_pool()


# Sentinel default for the `crawl --detail` option; `crawl` compares `detail`
# against it to decide whether to print the per-domain table.
# NOTE(review): click may cast an option's default through the option's type
# (click 8 converts it to `str`), in which case `detail` never equals this
# object — confirm against the installed click version.
NOOP = object()


@cli.command()
@click.argument("domain", type=str, nargs=-1)
@click.option("--use-public", is_flag=True)
@click.option("--detail", default=None)
@click.option("--passes", type=click.INT, default=999)
@click.option("--sort", default="Active users (30d)")
@async_command
async def crawl(domain, use_public, detail, passes, sort):
    """
    Crawl the network starting from the given domain(s).
    """
    from . import crawler, settings

    kwargs = crawler.get_session_kwargs()
    async with aiohttp.ClientSession(**kwargs) as session:
        if use_public:
            url = "https://network.funkwhale.audio/api/domains?up=true"
            click.echo(f"Retrieving list of public pods from {url}…")
            # Use the response as a context manager so the connection is
            # released. Fail on an HTTP error status instead of trying to
            # decode an error page as JSON.
            async with session.get(url) as response:
                response.raise_for_status()
                payload = await response.json()
            domain = {d["name"] for d in payload["results"]}
        click.echo(f"Launching crawl with {len(domain)} seed domains…")
        results = await crawler.crawl_all(
            session, *domain, stdout=click.echo, max_passes=passes
        )

    click.echo("Complete after {} passes:".format(results["pass_number"]))
    aggregate = aggregate_crawl_results(results["results"])

    # `--detail` absent -> None: no per-domain table.
    # `--detail ""` -> default column set.
    # `--detail a,b` -> those columns.
    # An `object()` sentinel cannot be used here: click casts option
    # defaults to the option's type (a string), so it never compared equal.
    if detail is not None:

        click.echo("")
        click.echo("Info per domain")
        click.echo("===============")
        click.echo("")

        if not detail:
            fields = [
                "Domain",
                "Active users (30d)",
                "Users",
                "Listenings",
                "Downloads",
                "Open registrations",
                "Anonymous access",
                "Private",
            ]
        else:
            fields = detail.split(",")

        click.echo(
            output.table(
                results["results"].values(), type="Domain", fields=fields, sort=sort
            )
        )

    click.echo("")
    click.echo("Aggregated data")
    click.echo("===============")
    click.echo("")
    click.echo(
        output.obj_table(
            aggregate,
            type="Summary",
            fields=[
                "Domains",
                "Active users (30d)",
                "Active users (180d)",
                "Users",
                "Listenings",
                "Downloads",
                "Tracks",
                "Albums",
                "Artists",
                "Hours of music",
                "Open registrations",
                "Federation enabled",
                "Anonymous access",
                "Private",
            ],
        )
    )


def aggregate_crawl_results(domains_info):
    """
    Summarise per-domain crawl data into network-wide figures.

    ``domains_info`` is a mapping whose values are per-domain dicts. Every
    value must contain all of the keys listed below; a missing key raises
    ``KeyError``. The returned dict has those keys, in that order, mapped to:

    * ``domain``: the number of entries;
    * numeric counters: the sum of their truthy values (``None``/0 skipped);
    * boolean flags: how many entries have a truthy value.
    """

    def _count_truthy(items):
        return sum(1 if item else 0 for item in items)

    def _sum_truthy(items):
        return sum(item for item in items if item)

    reducers = (
        ("domain", len),
        ("usage_users_total", _sum_truthy),
        ("usage_users_active_half_year", _sum_truthy),
        ("usage_users_active_month", _sum_truthy),
        ("usage_listenings_total", _sum_truthy),
        ("usage_downloads_total", _sum_truthy),
        ("library_tracks_total", _sum_truthy),
        ("library_albums_total", _sum_truthy),
        ("library_artists_total", _sum_truthy),
        ("library_music_hours", _sum_truthy),
        ("open_registrations", _count_truthy),
        ("federation_enabled", _count_truthy),
        ("anonymous_can_listen", _count_truthy),
        ("private", _count_truthy),
    )
    return {
        key: reduce_([entry[key] for entry in domains_info.values()])
        for key, reduce_ in reducers
    }


@worker.command()
@click.option("-v", "--verbose", is_flag=True)
@click.option("--check", is_flag=True)
def start(*, check, verbose):
    """
    Start the background worker.
    """
    import logging

    # Nothing visible here configures logging; without it only WARNING and
    # above reach stderr. Honour --verbose by lowering the threshold to DEBUG.
    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO)
    if check:
        # Mirror arq's own `--check`: report worker health and exit with its
        # status code instead of starting a worker (the flag was ignored).
        sys.exit(arq.worker.check_health(WorkerSettings))
    arq.worker.run_worker(WorkerSettings)


def main():
    """Console entry point: hand control to the top-level click group."""
    return cli()


if __name__ == "__main__":
    main()