Skip to content
Snippets Groups Projects
middleware.py 10.3 KiB
Newer Older
  • Learn to ignore specific revisions
  • 
    from django import http
    from django.conf import settings
    from django.core.cache import caches
    from django import urls
    
    from funkwhale_api.federation import utils as federation_utils
    
    
    Eliot Berriot's avatar
    Eliot Berriot committed
    from . import session
    
    from . import utils
    
    # URL prefixes that are served by the backend itself and must never be
    # answered with the single-page-app HTML.
    EXCLUDED_PATHS = ["/api", "/federation", "/.well-known"]


    def should_fallback_to_spa(path):
        """
        Tell whether a request path may be answered with the SPA index page.

        The root path always qualifies; any other path qualifies unless it
        starts with one of the EXCLUDED_PATHS prefixes.
        """
        if path == "/":
            return True

        for prefix in EXCLUDED_PATHS:
            if path.startswith(prefix):
                return False
        return True
    
    
    def serve_spa(request):
        """
        Build the HTTP response that serves the single-page-app index page.

        The SPA HTML is loaded through ``get_spa_html`` and then customized:

        - when ``settings.FUNKWHALE_SPA_REWRITE_MANIFEST`` is set, the manifest
          link in the head is pointed to the configured (or API-provided) URL
        - meta tags specific to the requested route (only when the API does not
          require authentication) and default meta tags are injected right
          before ``</head>``
        - the custom CSS, if any, is injected right before ``</body>``

        Raises ValueError when the HTML has no ``</head>`` or, if custom CSS is
        configured, no ``</body>``.
        """
        spa_html = get_spa_html(settings.FUNKWHALE_SPA_HTML_ROOT)
        head, tail = spa_html.split("</head>", 1)

        if settings.FUNKWHALE_SPA_REWRITE_MANIFEST:
            new_url = (
                settings.FUNKWHALE_SPA_REWRITE_MANIFEST_URL
                or federation_utils.full_url(urls.reverse("api:v1:instance:spa-manifest"))
            )
            head = replace_manifest_url(head, new_url)

        if not preferences.get("common__api_authentication_required"):
            try:
                request_tags = get_request_head_tags(request) or []
            except urls.exceptions.Resolver404:
                # we don't have any custom tags for this route
                request_tags = []
        else:
            # API is not open, we don't expose any custom data
            request_tags = []
        default_tags = get_default_head_tags(request.path)
        unique_attributes = ["name", "property"]

        # work on a copy so the list handed back by the view is left untouched
        final_tags = list(request_tags)

        # values of the identifying attributes already provided by the route:
        # a default tag with the same "name"/"property" must not be added again
        skip = [t[attr] for t in final_tags for attr in unique_attributes if attr in t]

        for t in default_tags:
            if not any(t.get(attr) in skip for attr in unique_attributes):
                final_tags.append(t)

        # let's inject our meta tags in the HTML
        head += "\n" + "\n".join(render_tags(final_tags)) + "\n</head>"

        css = get_custom_css() or ""
        if css:
            # We add the style at the end of the body to ensure it has the highest
            # priority (since it will come after other stylesheets)
            body, tail = tail.split("</body>", 1)
            css = "<style>{}</style>".format(css)
            tail = body + "\n" + css + "\n</body>" + tail

        # the caller (SPAFallbackMiddleware) returns our result as the response,
        # so we must hand back an actual response object
        return http.HttpResponse(head + tail)
    
    # ``[^>]*`` (instead of ``.*``) keeps the match inside a single <link ...> tag:
    # with a greedy ``.*`` the match would extend to the last ``>`` of the line
    # and swallow neighbouring tags when the HTML is minified.
    MANIFEST_LINK_REGEX = re.compile(r"<link [^>]*rel=(?:'|\")?manifest(?:'|\")?[^>]*>")


    def replace_manifest_url(head, new_url):
        """
        Return ``head`` with every manifest ``<link>`` tag replaced by one
        pointing to ``new_url``. ``head`` is returned unchanged when it contains
        no manifest link.
        """
        replacement = '<link rel=manifest href="{}">'.format(new_url)
        # a callable replacement is inserted verbatim, whereas a string would
        # have its backslashes interpreted as escapes / group references
        return MANIFEST_LINK_REGEX.sub(lambda _match: replacement, head)


    def get_spa_html(spa_url):
        """
        Return the content of the SPA ``index.html`` found at ``spa_url``.
        """
        return get_spa_file(spa_url, "index.html")
    
    
    def get_spa_file(spa_url, name):
        """
        Return the text content of the SPA file ``name`` located next to
        ``spa_url``.

        When ``spa_url`` is a path on the local disk, the file is read directly.
        Otherwise the file is fetched over HTTP and the result is kept in the
        "local" cache for ``settings.FUNKWHALE_SPA_HTML_CACHE_DURATION``.

        NOTE(review): the local-path guard and the ``return f.read()`` below were
        missing from this block (the ``with`` statement had no body); they were
        reconstructed from the comment describing the local-disk case -- confirm
        against the upstream file.
        """
        if spa_url.startswith("/"):
            # XXX: spa_url is an absolute path to index.html, on the local disk.
            # However, we may want to access manifest.json or other files as well, so we
            # strip the filename
            path = os.path.join(os.path.dirname(spa_url), name)

            with open(path) as f:
                return f.read()

        cache_key = "spa-file:{}:{}".format(spa_url, name)

        cached = caches["local"].get(cache_key)
        if cached:
            return cached

        response = session.get_session().get(utils.join_url(spa_url, name))

        # do not cache error pages
        response.raise_for_status()
        content = response.text
        caches["local"].set(cache_key, content, settings.FUNKWHALE_SPA_HTML_CACHE_DURATION)
        return content
    
    
    def get_default_head_tags(path):
        """
        Return the Open Graph meta tags (as dicts suitable for ``render_tags``)
        that describe the instance itself for the page at ``path``.

        The site name is made of the instance name and the application name,
        joined with " - ", leaving out whichever is empty.
        """
        instance_name = preferences.get("instance__name")
        description = preferences.get("instance__short_description")
        site_name = " - ".join(filter(None, (instance_name, settings.APP_NAME)))

        def og_meta(prop, content):
            return {"tag": "meta", "property": prop, "content": content}

        return [
            og_meta("og:type", "website"),
            og_meta("og:site_name", site_name),
            og_meta("og:description", description),
            og_meta(
                "og:image", utils.join_url(settings.FUNKWHALE_URL, "/front/favicon.png")
            ),
            og_meta("og:url", utils.join_url(settings.FUNKWHALE_URL, path)),
        ]
    
    
    def render_tags(tags):
        """
        Given an iterable of dicts like {'tag': 'meta', 'hello': 'world'}
        yield, for each of them, a html ready tag like
        <meta hello="world" />

        Attributes are rendered sorted by name, their values are HTML-escaped,
        and attributes with an empty value are left out. The given dicts are
        not modified, so the same tags can safely be rendered more than once.

        Raises KeyError if a dict has no "tag" key.
        """
        for tag in tags:
            attrs = " ".join(
                '{}="{}"'.format(name, html.escape(str(value)))
                for name, value in sorted(tag.items())
                # "tag" holds the element name, it is not an attribute
                if name != "tag" and value
            )
            yield "<{tag} {attrs} />".format(tag=tag["tag"], attrs=attrs)
    
    
    def get_request_head_tags(request):
        """
        Resolve the request path against the SPA urlconf and return whatever
        the matching view returns for this request.

        ``urls.resolve`` raises Resolver404 when no route matches the path.
        """
        resolved = urls.resolve(request.path, urlconf=settings.SPA_URLCONF)
        view, args, kwargs = resolved.func, resolved.args, resolved.kwargs
        return view(request, *args, **kwargs)
    
    
    
    def get_custom_css():
        """
        Return the custom CSS stored in the "ui__custom_css" preference,
        stripped and XML-escaped, or None when the preference is blank.
        """
        stripped = preferences.get("ui__custom_css").strip()
        return xml.sax.saxutils.escape(stripped) if stripped else None
    
    
    
    class SPAFallbackMiddleware:
        """
        Serve the single-page-app HTML in place of a 404 response, for paths
        that are eligible according to ``should_fallback_to_spa``.
        """

        def __init__(self, get_response):
            self.get_response = get_response

        def __call__(self, request):
            response = self.get_response(request)

            # anything but a 404 is passed through untouched
            if response.status_code != 404:
                return response
            if not should_fallback_to_spa(request.path):
                return response

            return serve_spa(request)
    
    
    
    class DevHttpsMiddleware:
        """
        In development, it's sometimes difficult to have django use HTTPS
        when we have django behind nginx behind traefik.

        We thus use a simple setting (in dev ONLY) to control that: when
        ``settings.FORCE_HTTPS_URLS`` is set, requests report "https" as their
        scheme and a host on port 443 instead of port 80.
        """

        def __init__(self, get_response):
            self.get_response = get_response

        def __call__(self, request):
            if settings.FORCE_HTTPS_URLS:
                request_class = request.__class__
                # set on the class, so this applies to every request of that class
                setattr(request_class, "scheme", "https")
                # shadow get_host on this instance only, delegating to the
                # class implementation for the actual host
                setattr(
                    request,
                    "get_host",
                    lambda: self._force_https_port(request_class.get_host(request)),
                )
            return self.get_response(request)

        @staticmethod
        def _force_https_port(host):
            """
            Return ``host`` with a trailing ":80" port replaced by ":443".

            Only an exact port 80 is rewritten: hosts such as "localhost:8000"
            are returned unchanged.
            """
            http_port = ":80"
            if host.endswith(http_port):
                return host[: -len(http_port)] + ":443"
            return host
    
    
    
    def monkey_patch_rest_initialize_request():
        """
        Rest framework uses its own request object, meaning we can't easily
        access our throttling info in the middleware. So we monkey patch the
        `initialize_request` method of rest_framework's APIView to keep a link
        between both requests: the request built by rest_framework is stored
        on the incoming request as `_api_request`.
        """
        wrapped = views.APIView.initialize_request

        def replacement(self, request, *args, **kwargs):
            api_request = wrapped(self, request, *args, **kwargs)
            request._api_request = api_request
            return api_request

        views.APIView.initialize_request = replacement


    # applied once, when this module is imported
    monkey_patch_rest_initialize_request()
    
    
    class ThrottleStatusMiddleware:
        """
        Expose the throttling state of the current request through
        X-RateLimit-* and Retry-After response headers, so that clients can
        adapt their request rate.
        """

        def __init__(self, get_response):
            self.get_response = get_response

        def __call__(self, request):
            try:
                response = self.get_response(request)
            except throttling.TooManyRequests:
                # throttling was triggered manually, outside of a rest_framework
                # view: building the 429 response is up to us
                response = http.HttpResponse(status=429)

            # the throttle status lives on the linked API request when there is
            # one, on the plain request otherwise
            checked_request = getattr(request, "_api_request", request)
            status = getattr(checked_request, "_throttle_status", None)
            if not status:
                return response

            response["X-RateLimit-Limit"] = str(status["num_requests"])
            response["X-RateLimit-Scope"] = str(status["scope"])
            response["X-RateLimit-Remaining"] = status["num_requests"] - len(
                status["history"]
            )
            response["X-RateLimit-Duration"] = str(status["duration"])
            if not status["history"]:
                return response

            now = int(time.time())

            # Seconds until the last history entry leaves the window: from
            # then on, the client can send additional requests
            last_entry = status["history"][-1]
            retry_after = status["duration"] - (now - int(last_entry))
            response["Retry-After"] = str(retry_after)

            # Seconds until the first history entry leaves the window: from
            # then on, the whole rate limit is reset
            first_entry = status["history"][0]
            reset_in = status["duration"] - (now - int(first_entry))
            response["X-RateLimit-Reset"] = str(now + reset_in)
            response["X-RateLimit-ResetSeconds"] = str(reset_in)

            return response
    
    
    
    class ProfilerMiddleware:
        """
        from https://github.com/omarish/django-cprofile-middleware/blob/master/django_cprofile_middleware/middleware.py
        Simple profile middleware to profile django views. To run it, add ?prof to
        the URL like this:
            http://localhost:8000/view/?prof
        Optionally pass the following to modify the output:
        ?prof-sort => Sort the output by a given metric. Default is cumtime.
            See
            http://docs.python.org/2/library/profile.html#pstats.Stats.sort_stats
            for all sort options.
        ?count => The number of rows to display. Default is 100.
        ?prof-download => Download profile file suitable for visualization. For
            example in snakeviz or RunSnakeRun
        This is adapted from an example found here:
        http://www.slideshare.net/zeeg/django-con-high-performance-django-presentation.

        When profiling is requested, the response of the profiled view is
        discarded and replaced by the profiling report.
        """

        def __init__(self, get_response):
            self.get_response = get_response

        def __call__(self, request):
            if "prof" not in request.GET:
                return self.get_response(request)
            # only needed when profiling, hence the local imports
            import html
            import io
            import profile
            import pstats

            profiler = profile.Profile()
            profiler.runcall(self.get_response, request)
            profiler.create_stats()
            if "prof-download" in request.GET:
                import marshal

                output = marshal.dumps(profiler.stats)
                response = http.HttpResponse(
                    output, content_type="application/octet-stream"
                )
                response["Content-Disposition"] = "attachment; filename=view.prof"
                response["Content-Length"] = len(output)
                # return right away, otherwise the download would be replaced
                # by the HTML report below
                return response

            stream = io.StringIO()
            stats = pstats.Stats(profiler, stream=stream)

            stats.sort_stats(request.GET.get("prof-sort", "cumtime"))
            stats.print_stats(int(request.GET.get("count", 100)))

            # the report contains entries such as <listcomp> or <built-in method ...>
            # that a browser would treat as (unknown) tags if left unescaped
            return http.HttpResponse("<pre>%s</pre>" % html.escape(stream.getvalue()))