import collections
import pyld

from rest_framework import serializers
from django.core import validators

from . import ns

# Remote documents fetched over HTTP by document_loader, keyed by URL.
# Entries are never evicted, so this grows with the number of distinct
# remote URLs fetched during the process lifetime.
_cache = {}


def document_loader(url, options=None):
    """
    pyld document loader that avoids network access when possible.

    Lookup order:

    1. ``ns.get_by_url(url)``; a truthy result is returned as-is.
    2. The module-level ``_cache``.
    3. A real fetch through ``pyld.jsonld.requests_document_loader()``; its
       result is stored in ``_cache`` before being returned.

    :param url: URL of the document or context to load.
    :param options: optional loader options. Newer pyld releases call
        document loaders as ``loader(url, options)``; older ones pass only
        the URL. It is forwarded to the HTTP loader on a cache miss and is
        not part of the cache key.
    :return: whatever ``ns.get_by_url`` or the requests loader returns.
    """
    doc = ns.get_by_url(url)
    if doc:
        return doc
    if url in _cache:
        return _cache[url]
    # Only build the HTTP loader when we actually need to hit the network.
    loader = pyld.jsonld.requests_document_loader()
    resp = loader(url) if options is None else loader(url, options)
    _cache[url] = resp
    return resp


def expand(document):
    """
    Expand a JSON-LD ``document`` using pyld.

    Any remote contexts are resolved through pyld's currently configured
    default document loader.

    :param document: the JSON-LD input handed to ``pyld.jsonld.expand``.
    :return: the expanded form produced by pyld.
    """
    expanded = pyld.jsonld.expand(document)
    return expanded


# Import-time, process-wide side effect: make pyld use document_loader
# (local ns lookup, then in-memory cache, then HTTP) as its default loader
# for every subsequent expand/compact call.
pyld.jsonld.set_document_loader(document_loader)


class TypeField(serializers.ListField):
    """
    Field holding the ``@type`` of an expanded JSON-LD document.

    Expansion produces a list of type IRIs for ``@type``. Only the first
    entry is kept as the internal value; representation is passed through
    unchanged.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("child", serializers.URLField())
        super().__init__(*args, **kwargs)

    def to_internal_value(self, value):
        """
        Return the first type found in ``value``.

        :raises ValidationError: with code ``not_a_list`` when ``value`` is
            not a list/tuple, or ``empty`` when it contains no item.
        """
        # Reuse ListField's built-in error keys. Field.fail() accepts only
        # keyword arguments for message formatting, and ListField defines
        # no "invalid" message.
        if not isinstance(value, (list, tuple)):
            self.fail("not_a_list", input_type=type(value).__name__)
        if not value:
            self.fail("empty")
        return value[0]

    def to_representation(self, value):
        return value


class JsonLdField(serializers.ListField):
    """
    List field receiving the value of one property of an expanded JSON-LD
    document, optionally flattening it.

    By default pyld expands every property to a list such as
    ``[{"@value": "something"}]``. We may not always want that shape, so
    these extra keyword arguments (removed before ``ListField.__init__``)
    adjust the internal value:

    :param single: return only the first item instead of the list (an
        empty list is still returned as an empty list).
    :param value_only: replace every item with its ``@value``.
    :param id_only: replace every item with its ``@id``; takes precedence
        over ``value_only``.
    :param base_field: required; the field used by ``to_representation``.
    """

    default_error_messages = {
        "missing_key": 'Expected each item to be an object with a "{key}" key.',
    }

    def __init__(self, *args, **kwargs):
        # Custom options are popped first: DRF's Field.__init__ does not
        # accept unknown keyword arguments.
        self.single = kwargs.pop("single", False)
        self.value_only = kwargs.pop("value_only", False)
        self.id_only = kwargs.pop("id_only", False)
        self.base_field = kwargs.pop("base_field")
        super().__init__(*args, **kwargs)
        if self.single:
            # to_internal_value returns a scalar for single fields, and DRF
            # runs validators on that returned value, so ListField's length
            # validators would measure the scalar instead of the list.
            self.validators = [
                v
                for v in self.validators
                if not isinstance(
                    v, (validators.MinLengthValidator, validators.MaxLengthValidator)
                )
            ]

    def to_internal_value(self, value):
        """
        Validate ``value`` as a list and flatten it according to the field
        options.

        :raises ValidationError: ``missing_key`` when ``id_only`` or
            ``value_only`` is set and an item is not an object carrying the
            expected key.
        """
        # Pin the output path to the raw field name. DRF treats ``source``
        # as a dotted attribute path, and expanded JSON-LD field names are
        # IRIs that usually contain dots. NOTE(review): this also overrides
        # any explicit ``source`` kwarg.
        self.source_attrs = [self.field_name]
        value = super().to_internal_value(value)
        if self.id_only:
            value = self._extract(value, "@id")
        elif self.value_only:
            value = self._extract(value, "@value")
        if value and self.single:
            return value[0]
        return value

    def _extract(self, items, key):
        """Return ``item[key]`` for each item; fail validation on bad items."""
        extracted = []
        for item in items:
            # Remote (untrusted) documents may hold literals where ids are
            # expected or vice versa; report instead of raising KeyError.
            if not isinstance(item, dict) or key not in item:
                self.fail("missing_key", key=key)
            extracted.append(item[key])
        return extracted

    def to_representation(self, value):
        return self.base_field.to_representation(value)


class Serializer(serializers.Serializer):
    """
    Base DRF serializer for JSON-LD documents.

    ``run_validation`` expands incoming data with pyld before the usual DRF
    validation, so declared field names are expanded (full IRI) names.
    ``to_representation`` compacts the output against ``ns.CONTEXTS`` and
    strips prefixes from the resulting keys.

    Subclasses configure it through class attributes:

    - ``jsonld_fields``: ``(field_name, field, kwargs)`` triples; each one
      becomes a ``JsonLdField`` wrapping ``field`` (``kwargs`` forwarded).
    - ``include_id`` / ``id_source``: add a required ``@id`` URL field,
      optionally stored under ``id_source``.
    - ``include_type`` / ``type_source``: same for ``@type``.
    - ``valid_types``: when non-empty, the accepted ``@type`` values.
    """

    jsonld_fields = []
    type_source = None
    id_source = None
    include_id = True
    include_type = True
    valid_types = []

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.include_id:
            self.fields["@id"] = serializers.URLField(
                required=True, **{"source": self.id_source} if self.id_source else {}
            )
        if self.include_type:
            self.fields["@type"] = TypeField(
                required=True,
                **{"source": self.type_source} if self.type_source else {}
            )
        for field_name, field, kw in self.jsonld_fields:
            # Required fields must carry at least one expanded value.
            self.fields[field_name] = JsonLdField(
                required=field.required,
                base_field=field,
                min_length=1 if field.required else 0,
                **kw
            )

    def run_validation(self, initial_data):
        """
        Expand ``initial_data`` as JSON-LD, then validate its first node.

        :raises ValidationError: if pyld cannot expand the document (invalid
            JSON-LD, unloadable remote context) or expansion yields nothing.
        """
        try:
            document = expand(initial_data)[0]
        except (IndexError, pyld.jsonld.JsonLdError) as exc:
            # JsonLdError must become a ValidationError too, otherwise it
            # escapes is_valid() uncaught.
            raise serializers.ValidationError(
                "Cannot parse json-ld, maybe you are missing context."
            ) from exc
        return super().run_validation(document)

    def validate(self, validated_data):
        """Reject the document when ``valid_types`` is set and ``@type`` is not in it."""
        validated_data = super().validate(validated_data)
        # When type_source is set, the validated type lives under that name,
        # not under "@type".
        type_key = self.type_source or "@type"
        if self.valid_types and type_key in validated_data:
            if validated_data[type_key] not in self.valid_types:
                raise serializers.ValidationError(
                    {
                        "@type": "Invalid type. Allowed types are: {}".format(
                            ", ".join(self.valid_types)
                        )
                    }
                )
        return validated_data

    def to_representation(self, data):
        data = super().to_representation(data)
        compacted = pyld.jsonld.compact(data, {"@context": ns.CONTEXTS})
        # While the compacted representation is valid JSON-LD, it's probably
        # not what most implementations in the fediverse expect, so we strip
        # the prefix (e.g. "as:" in "as:inbox") from every key.
        return remove_semicolons(compacted)


def _strip_compact_iri_prefix(key):
    """Return ``key`` without its ``prefix:`` part, leaving absolute IRIs intact."""
    prefix, sep, suffix = key.partition(":")
    # A suffix starting with "//" means the key is an absolute IRI such as
    # "https://example.com/x", not a compact IRI; stripping it would
    # produce a meaningless "//example.com/x".
    if not sep or suffix.startswith("//"):
        return key
    return suffix


def _remove_semicolons_from_value(value, skip):
    """Recursively apply remove_semicolons to dicts nested in ``value``."""
    if isinstance(value, dict):
        return remove_semicolons(value, skip)
    if isinstance(value, (list, tuple)):
        return [_remove_semicolons_from_value(item, skip) for item in value]
    return value


def remove_semicolons(struct: dict, skip=("@context",)):
    """
    Strip compact-IRI prefixes from the keys of a JSON-LD structure.

    Despite the name, this removes the ``prefix:`` part (up to the first
    colon) of every key, recursing into nested dicts and lists. Given::

        {"as:inbox": "https://test.example/inbox"}

    it returns::

        {"inbox": "https://test.example/inbox"}

    Keys that are absolute IRIs (``http://...``) are kept unchanged. If two
    keys strip to the same name, the later one wins.

    :param struct: mapping to rewrite; it is not modified.
    :param skip: keys copied verbatim (value untouched) at every nesting
        level; ``@context`` by default.
    :return: a new ``collections.OrderedDict`` preserving key order. Tuples
        found in values come back as lists.
    """
    new_data = collections.OrderedDict()
    for key, value in struct.items():
        if key in skip:
            new_data[key] = value
            continue
        new_data[_strip_compact_iri_prefix(key)] = _remove_semicolons_from_value(
            value, skip
        )
    return new_data