Skip to content
Snippets Groups Projects
db.py 9.14 KiB
Newer Older
  • Learn to ignore specific revisions
  • import aiopg
    
    Georg Krause's avatar
    Georg Krause committed
    from funkwhale_network import settings
    
    
    class DB:
        """Async data-access layer for the funkwhale network crawler database.

        Intended usage is ``async with DB() as db: ...``. All instances share a
        single class-level aiopg connection pool, created lazily on first entry.
        """

        # Shared connection pool: assigned by ``create_pool``, removed by
        # ``close_pool``. There is deliberately no default value, because
        # ``create_pool`` relies on the attribute being absent. The annotation
        # is a string so the class body does not evaluate it at definition time.
        pool: "aiopg.Pool"

        async def __aenter__(self):
            """Register the table creators, ensure the shared pool exists, return self."""
            # (table name, creator coroutine) pairs used by ``create`` and ``clear``.
            self.TABLES = [
                ("domains", self.create_domains_table),
                ("checks", self.create_checks_table),
            ]

            await DB.create_pool()
            return self

        async def __aexit__(self, exc_type, exc, tb):
            """Leave the context without closing the shared pool (see ``close_pool``).

            Returning None lets any exception raised inside the block propagate.
            """
            return

        @classmethod
        async def close_pool(cls):
            """Close the shared pool, if one exists, and forget it.

            Dropping the reference lets a later ``create_pool`` build a fresh pool
            instead of handing out the closed one.
            """
            try:
                pool = cls.pool
            except AttributeError:
                return  # no pool was ever created; nothing to close
            pool.close()
            await pool.wait_closed()
            del cls.pool

        @classmethod
        async def create_pool(cls):
            """Create the shared pool from ``settings.DB_DSN`` unless it already exists."""
            try:
                cls.pool
            except AttributeError:
                cls.pool = await aiopg.create_pool(settings.DB_DSN)

        def get_cursor(self):
            """Yield the result of ``self.pool.cursor()`` once.

            NOTE(review): this is a plain generator (no context-manager decorator)
            and it yields the un-awaited ``pool.cursor()`` result. Nothing in this
            file calls it; confirm the intended usage before relying on it.
            """
            yield self.pool.cursor()

        async def create_domains_table(self):
            """Create the ``domains`` table if it does not exist yet."""
            # The pool cursor context manager closes the cursor and releases the
            # connection on exit, so no explicit ``cursor.close()`` is needed.
            with await self.pool.cursor() as cursor:
                await cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS domains (
                        name VARCHAR(255) PRIMARY KEY,
                        node_name VARCHAR(255) NULL,
                        blocked    BOOLEAN              DEFAULT false,
                        first_seen    TIMESTAMP WITH TIME ZONE DEFAULT NOW()
                    );
                    """
                )

        async def create_checks_table(self):
            """Create the ``checks`` hypertable and apply the additive column migration.

            The ``ALTER TABLE ... ADD COLUMN IF NOT EXISTS`` upgrades databases
            created before ``usage_downloads_total`` existed. ``create_hypertable``
            requires the TimescaleDB extension.
            """
            sql = """
            CREATE TABLE IF NOT EXISTS checks (
                time        TIMESTAMPTZ       NOT NULL,
                domain      VARCHAR(255) REFERENCES domains(name),
                up    BOOLEAN              NOT NULL,
                open_registrations    BOOLEAN              NULL,
                private    BOOLEAN              NULL,
                federation_enabled BOOLEAN              NULL,
                anonymous_can_listen BOOLEAN              NULL,
                usage_users_total INTEGER NULL,
                usage_users_active_half_year INTEGER NULL,
                usage_users_active_month INTEGER NULL,
                usage_listenings_total INTEGER NULL,
                library_tracks_total INTEGER NULL,
                library_albums_total INTEGER NULL,
                library_artists_total INTEGER NULL,
                library_music_hours INTEGER NULL,
                software_name VARCHAR(255) NULL,
                software_version_major SMALLINT NULL,
                software_version_minor SMALLINT NULL,
                software_version_patch SMALLINT NULL,
                software_prerelease VARCHAR(255) NULL,
                software_build VARCHAR(255) NULL
            );
            ALTER TABLE checks ADD COLUMN IF NOT EXISTS usage_downloads_total INTEGER NULL;
            SELECT create_hypertable('checks', 'time', if_not_exists => TRUE);
            """

            with await self.pool.cursor() as cursor:
                await cursor.execute(sql)

        async def create(self):
            """Run every registered table creator, in ``self.TABLES`` order."""
            for _, create_handler in self.TABLES:
                await create_handler()

        async def clear(self):
            """Drop every table listed in ``self.TABLES`` (``CASCADE``)."""
            with await self.pool.cursor() as cursor:
                for table, _ in self.TABLES:
                    # Table names come from the fixed TABLES list, never from user
                    # input, so interpolating them into the statement is safe.
                    await cursor.execute(f"DROP TABLE IF EXISTS {table} CASCADE")

        async def get_latest_check_by_domain(self):
            """Return, per domain, the newest check with ``private = false``.

            Rows for blocked domains are excluded. Each row is a dict
            (``RealDictCursor``) holding the joined ``checks`` and ``domains``
            columns.
            """
            sql = """
                SELECT DISTINCT on (domain) domain, *
                FROM checks
                INNER JOIN domains ON checks.domain = domains.name
                WHERE private = %s AND domains.blocked = false
                ORDER BY domain, time DESC
            """

            with await self.pool.cursor(
                cursor_factory=psycopg2.extras.RealDictCursor
            ) as cursor:
                await cursor.execute(sql, [False])
                return list(await cursor.fetchall())

        def increment_stat(self, data, key, value):
            """Add ``value`` to ``data[key]``; falsy values (None, 0, False) are skipped."""
            if not value:
                return
            data[key] += value

        async def get_stats(self):
            """Aggregate network-wide totals over the latest check of each domain.

            Returns nested counters for users, instances, artists, albums, tracks,
            listenings and downloads. A NULL column contributes nothing, instead
            of crashing ``int(None)``.
            """
            checks = await self.get_latest_check_by_domain()
            data = {
                "users": {"total": 0, "activeMonth": 0, "activeHalfyear": 0},
                "instances": {"total": 0, "anonymousCanListen": 0, "openRegistrations": 0},
                "artists": {"total": 0},
                "albums": {"total": 0},
                "tracks": {"total": 0},
                "listenings": {"total": 0},
                "downloads": {"total": 0},
            }
            for check in checks:
                self.increment_stat(data["users"], "total", check["usage_users_total"])
                self.increment_stat(
                    data["users"], "activeMonth", check["usage_users_active_month"]
                )
                self.increment_stat(
                    data["users"], "activeHalfyear", check["usage_users_active_half_year"]
                )
                self.increment_stat(data["instances"], "total", 1)
                # The columns below are nullable; ``or 0`` maps NULL to 0 before int().
                self.increment_stat(
                    data["instances"],
                    "openRegistrations",
                    int(check["open_registrations"] or 0),
                )
                self.increment_stat(
                    data["instances"],
                    "anonymousCanListen",
                    int(check["anonymous_can_listen"] or 0),
                )
                self.increment_stat(
                    data["artists"], "total", int(check["library_artists_total"] or 0)
                )
                self.increment_stat(
                    data["tracks"], "total", int(check["library_tracks_total"] or 0)
                )
                self.increment_stat(
                    data["albums"], "total", int(check["library_albums_total"] or 0)
                )
                self.increment_stat(
                    data["listenings"], "total", int(check["usage_listenings_total"] or 0)
                )
                self.increment_stat(
                    data["downloads"], "total", int(check["usage_downloads_total"] or 0)
                )
            return data

        def get_domain_query(self, **kwargs):
            """Return ``(sql, params)`` selecting the latest check of each unblocked domain.

            ``kwargs`` is accepted for interface compatibility but currently unused;
            callers filter the rows in Python.
            """
            base_query = """
                SELECT DISTINCT on (domain) domain, *
                FROM checks
                INNER JOIN domains ON checks.domain = domains.name
                WHERE domains.blocked = false
                ORDER BY domain, time DESC
            """
            return base_query, []

        async def get_all_domains(self):
            """Return every row of ``domains`` as ``{"name": ...}`` dicts."""
            with await self.pool.cursor(
                cursor_factory=psycopg2.extras.RealDictCursor
            ) as cursor:
                await cursor.execute("SELECT name FROM domains")
                domains = list(await cursor.fetchall())
                return domains

        async def get_domains(self, **kwargs):
            """Return the latest check of each unblocked domain, filtered by ``kwargs``.

            Defaults to ``private=False`` and ``up=True``. Only the keys in
            ``supported_fields`` are honoured; any other keyword is ignored.
            """
            filters = kwargs.copy()
            filters.setdefault("private", False)
            filters.setdefault("up", True)
            supported_fields = dict(
                [
                    ("up", "up"),
                    ("open_registrations", "open_registrations"),
                    ("federation_enabled", "federation_enabled"),
                    ("anonymous_can_listen", "anonymous_can_listen"),
                    ("private", "private"),
                ]
            )
            filters = [
                (supported_fields[key], value)
                for key, value in filters.items()
                if key in supported_fields
            ]

            query, params = self.get_domain_query()
            with await self.pool.cursor(
                cursor_factory=psycopg2.extras.RealDictCursor
            ) as cursor:
                await cursor.execute(query, params)
                domains = list(await cursor.fetchall())

            # Filtering happens in Python (after the connection is released)
            # because filtering on the latest check's values in SQL was not
            # worked out.
            return [d for d in domains if self.should_keep(d, filters)]

        def should_keep(self, domain, filters):
            """Return True when ``domain[key] == value`` for every ``(key, value)``."""
            for key, value in filters:
                if domain[key] != value:
                    return False
            return True

        async def get_domain(self, name):
            """Return the ``domains`` row for ``name``.

            Raises IndexError when no such domain exists.
            """
            with await self.pool.cursor(
                cursor_factory=psycopg2.extras.RealDictCursor
            ) as cursor:
                await cursor.execute("SELECT * FROM domains WHERE name = %s", (name,))
                return list(await cursor.fetchall())[0]

        async def save_check(self, data):
            """Insert a check for ``data["domain"]`` and return the stored row.

            ``node_name``, if present, is written to ``domains`` instead of
            ``checks``. When ``data["private"]`` is True, every check of the
            domain, including the one just inserted, is deleted and None is
            returned.
            """
            # Copy so popping ``node_name`` does not mutate the caller's dict.
            data = dict(data)
            node_name = data.pop("node_name", None)
            fields = list(data.keys())
            values = list(data.values())

            # NOTE(review): column names are interpolated from the dict keys, so
            # this assumes ``data`` keys are trusted column names, not external input.
            sql = "INSERT INTO checks (time, {}) VALUES (NOW(), {}) RETURNING *".format(
                ", ".join(fields), ", ".join(["%s" for _ in values])
            )

            with await self.pool.cursor(
                cursor_factory=psycopg2.extras.RealDictCursor
            ) as cursor:
                await cursor.execute(sql, values)
                check = await cursor.fetchone()

                if data.get("private") is True:
                    # let's clean previous checks
                    sql = "DELETE FROM checks WHERE domain = %s"
                    await cursor.execute(sql, [data["domain"]])
                    return
                if node_name:
                    await cursor.execute(
                        "UPDATE domains SET node_name = %s WHERE name = %s",
                        [node_name, data["domain"]],
                    )

                return check

        async def create_domain(self, data):
            """Insert ``data["name"]`` into ``domains`` and return the new row.

            Returns None when the domain already exists (``ON CONFLICT DO NOTHING``
            returns no row).
            """
            with await self.pool.cursor(
                cursor_factory=psycopg2.extras.RealDictCursor
            ) as cursor:
                sql = "INSERT INTO domains (name) VALUES (%s) ON CONFLICT DO NOTHING RETURNING *"
                await cursor.execute(sql, [data["name"]])
                domain = await cursor.fetchone()
                return domain