Skip to content
Snippets Groups Projects
Select Git revision
  • fix-index.html-encoding
  • develop default protected
  • document-loglevel
  • master
  • 1.0.1
  • 1121-download
  • plugins-v3
  • 876-http-signature
  • plugins-v2
  • plugins
  • 1.0.1
  • 1.0
  • 1.0-rc1
  • 0.21.2
  • 0.21.1
  • 0.21
  • 0.21-rc2
  • 0.21-rc1
  • 0.20.1
  • 0.20.0
  • 0.20.0-rc1
  • 0.19.1
  • 0.19.0
  • 0.19.0-rc2
  • 0.19.0-rc1
  • 0.18.3
  • 0.18.2
  • 0.18.1
  • 0.18
  • 0.17
30 results

middleware.py

Blame
  • Forked from funkwhale / funkwhale
    4795 commits behind the upstream repository.
    middleware.py 10.35 KiB
    import html
    import io
    import os
    import re
    import time
    import xml.sax.saxutils
    
    from django import http
    from django.conf import settings
    from django.core.cache import caches
    from django import urls
    from rest_framework import views
    
    from funkwhale_api.federation import utils as federation_utils
    
    from . import preferences
    from . import session
    from . import throttling
    from . import utils
    
    # Path prefixes that are never answered with the SPA fallback
    EXCLUDED_PATHS = ["/api", "/federation", "/.well-known"]


    def should_fallback_to_spa(path):
        """
        Tell whether a request on ``path`` may be answered with the SPA.

        The root path always qualifies. Any other path qualifies unless it
        starts with one of the prefixes listed in EXCLUDED_PATHS.
        """
        if path == "/":
            return True
        # str.startswith accepts a tuple of candidate prefixes
        reserved_prefixes = tuple(EXCLUDED_PATHS)
        return not path.startswith(reserved_prefixes)
    
    
    def serve_spa(request):
        """
        Build the HTTP response that serves the SPA index page for ``request``.

        Steps, in order:
        - load the index HTML and cut it at ``</head>``
        - optionally rewrite the manifest link found in the head
        - gather the route-specific head tags (only when the API does not
          require authentication) and complete them with the default tags
        - inject the rendered tags right before ``</head>``
        - inject the custom CSS, if any, right before ``</body>``

        A ValueError is raised by the tuple unpacking when the HTML has no
        ``</head>``, or has no ``</body>`` while some custom CSS is set.
        """
        page = get_spa_html(settings.FUNKWHALE_SPA_HTML_ROOT)
        head, tail = page.split("</head>", 1)
        if settings.FUNKWHALE_SPA_REWRITE_MANIFEST:
            manifest_url = settings.FUNKWHALE_SPA_REWRITE_MANIFEST_URL
            if not manifest_url:
                manifest_url = federation_utils.full_url(
                    urls.reverse("api:v1:instance:spa-manifest")
                )
            head = replace_manifest_url(head, manifest_url)

        # When the API is not open, we don't expose any custom data
        final_tags = []
        if not preferences.get("common__api_authentication_required"):
            try:
                final_tags = get_request_head_tags(request) or []
            except urls.exceptions.Resolver404:
                # we don't have any custom tags for this route
                pass
        default_tags = get_default_head_tags(request.path)
        unique_attributes = ["name", "property"]

        # Values already used by the route-specific tags for the unique attributes
        taken = [
            tag[attr]
            for tag in final_tags
            for attr in unique_attributes
            if attr in tag
        ]
        # A default tag is only kept when no route-specific tag overrides it
        final_tags.extend(
            tag
            for tag in default_tags
            if not any(tag.get(attr) in taken for attr in unique_attributes)
        )

        # let's inject our meta tags in the HTML
        head += "\n" + "\n".join(render_tags(final_tags)) + "\n</head>"
        css = get_custom_css() or ""
        if css:
            # The style goes at the end of the body to ensure it has the highest
            # priority (since it will come after other stylesheets)
            body, tail = tail.split("</body>", 1)
            style_tag = "<style>{}</style>".format(css)
            tail = body + "\n" + style_tag + "\n</body>" + tail
        return http.HttpResponse(head + tail)
    
    
    # Matches one <link ...> tag whose rel attribute is "manifest" (quoted or not).
    # [^>]* keeps the match inside a single tag: with a greedy ".*", a minified
    # head (many tags on one line) would be matched from its first "<link " up to
    # its last ">".
    MANIFEST_LINK_REGEX = re.compile(
        r"<link [^>]*rel=(?:'|\")?manifest(?:'|\")?[^>]*>"
    )


    def replace_manifest_url(head, new_url):
        """
        Return ``head`` with every manifest <link> tag pointing to ``new_url``.

        Each tag matched by MANIFEST_LINK_REGEX is replaced by
        ``<link rel=manifest href="new_url">``. ``head`` is returned unchanged
        when it contains no manifest link. ``new_url`` is inserted as is (no
        HTML escaping is applied).
        """
        replacement = '<link rel=manifest href="{}">'.format(new_url)
        # A callable replacement is inserted verbatim, so backslashes in the URL
        # are not interpreted as regex escapes or group references
        return MANIFEST_LINK_REGEX.sub(lambda match: replacement, head)
    
    
    def get_spa_html(spa_url):
        """Return the content of the ``index.html`` file of the SPA at ``spa_url``."""
        index_name = "index.html"
        return get_spa_file(spa_url, index_name)
    
    
    def get_spa_file(spa_url, name):
        """
        Return the text content of the SPA file called ``name``.

        When ``spa_url`` starts with "/", it is taken as the absolute path of
        index.html on the local disk and ``name`` is read from the same
        directory. Errors from ``open`` (e.g. a missing file) propagate.

        Otherwise ``name`` is fetched over HTTP, relative to ``spa_url``, and
        the content is stored in the "local" cache for
        settings.FUNKWHALE_SPA_HTML_CACHE_DURATION. ``raise_for_status()`` is
        called on the response, so HTTP errors propagate and are not cached.

        In both cases the content is decoded as UTF-8 unless the HTTP server
        explicitly declares another charset.
        """
        if spa_url.startswith("/"):
            # XXX: spa_url is an absolute path to index.html, on the local disk.
            # However, we may want to access manifest.json or other files as well, so we
            # strip the filename
            path = os.path.join(os.path.dirname(spa_url), name)
            # Explicit encoding: without it, open() decodes with the locale of the
            # process, which garbles non-ASCII content on non-UTF-8 systems
            with open(path, encoding="utf-8") as f:
                return f.read()
        cache_key = "spa-file:{}:{}".format(spa_url, name)
        cached = caches["local"].get(cache_key)
        if cached:
            return cached

        response = session.get_session().get(utils.join_url(spa_url, name))
        response.raise_for_status()
        # NOTE(review): assumes a requests-style response object. When the server
        # declares no charset, the client falls back to a guessed/legacy encoding
        # for ``.text``, so we force UTF-8 in that case only.
        if "charset" not in response.headers.get("Content-Type", "").lower():
            response.encoding = "utf-8"
        content = response.text
        caches["local"].set(cache_key, content, settings.FUNKWHALE_SPA_HTML_CACHE_DURATION)
        return content
    
    
    def get_default_head_tags(path):
        """
        Return the default Open Graph meta tags for the page at ``path``.

        The result is a list of dicts, one per <meta> tag, each with the
        "tag", "property" and "content" keys. The site name joins the instance
        name and the application name with " - ", skipping empty values.
        """
        instance_name = preferences.get("instance__name")
        short_description = preferences.get("instance__short_description")
        site_name = " - ".join(filter(None, [instance_name, settings.APP_NAME]))
        image_url = utils.join_url(settings.FUNKWHALE_URL, "/front/favicon.png")
        page_url = utils.join_url(settings.FUNKWHALE_URL, path)

        og_content = [
            ("og:type", "website"),
            ("og:site_name", site_name),
            ("og:description", short_description),
            ("og:image", image_url),
            ("og:url", page_url),
        ]
        return [
            {"tag": "meta", "property": og_property, "content": content}
            for og_property, content in og_content
        ]
    
    
    def render_tags(tags):
        """
        Given an iterable of dicts like {'tag': 'meta', 'hello': 'world'}
        yield html ready tags like
        <meta hello="world" />

        Attributes are rendered sorted by name, their values are converted to
        str and HTML-escaped, and attributes with a falsy value are left out.

        The given dicts are not modified, so the same tags can safely be
        rendered several times.

        Raises KeyError when a dict has no 'tag' key.
        """
        for tag in tags:
            # Plain lookup instead of pop(): popping would strip the 'tag' key from
            # the caller's dict and make a second rendering fail
            tag_name = tag["tag"]
            attrs = " ".join(
                '{}="{}"'.format(name, html.escape(str(value)))
                for name, value in sorted(tag.items())
                if name != "tag" and value
            )
            yield "<{tag} {attrs} />".format(tag=tag_name, attrs=attrs)
    
    
    def get_request_head_tags(request):
        """
        Return what the SPA route matching ``request.path`` produces.

        The path is resolved against settings.SPA_URLCONF, then the matched
        callable is invoked with the request and the arguments captured from
        the path. Its return value is returned as is. The exception raised by
        urls.resolve when no route matches is not caught here.
        """
        resolved = urls.resolve(request.path, urlconf=settings.SPA_URLCONF)
        handler = resolved.func
        return handler(request, *resolved.args, **resolved.kwargs)
    
    
    def get_custom_css():
        """
        Return the custom CSS stored in the "ui__custom_css" preference.

        The value is stripped, then escaped with xml.sax.saxutils.escape.
        None is returned when the stripped value is empty.
        """
        raw_css = preferences.get("ui__custom_css")
        css = raw_css.strip()
        if css:
            return xml.sax.saxutils.escape(css)
        return None
    
    
    class SPAFallbackMiddleware:
        """
        Serve the SPA in place of a 404 response, for the paths that
        should_fallback_to_spa accepts. Any other response is returned as is.
        """

        def __init__(self, get_response):
            self.get_response = get_response

        def __call__(self, request):
            response = self.get_response(request)

            # Only "not found" responses are candidates for the fallback
            if response.status_code != 404:
                return response
            if not should_fallback_to_spa(request.path):
                return response
            return serve_spa(request)
    
    
    class DevHttpsMiddleware:
        """
        In development, it's sometimes difficult to have django use HTTPS
        when we have django behind nginx behind traefik.

        We thus use a simple setting (in dev ONLY) to control that: when
        settings.FORCE_HTTPS_URLS is truthy, requests report the "https" scheme
        and a host on port 80 is reported on port 443 instead.
        """

        def __init__(self, get_response):
            self.get_response = get_response

        def __call__(self, request):
            if settings.FORCE_HTTPS_URLS:
                # Set on the class, so it applies to every request of that class
                request.__class__.scheme = "https"

                def get_host():
                    host = request.__class__.get_host(request)
                    # Only an explicit port 80 is swapped for 443. A blind
                    # replace(":80", ":443") would also turn ":8000" into ":44300"
                    if host.endswith(":80"):
                        return host[: -len(":80")] + ":443"
                    return host

                # Set on the instance, shadowing the method of the class
                request.get_host = get_host
            return self.get_response(request)
    
    
    def monkey_patch_rest_initialize_request():
        """
        Rest framework uses its own APIRequest, meaning we can't easily
        access our throttling info in the middleware. So we monkey patch the
        `initialize_request` method from rest_framework to keep a link between
        both requests: the request built by rest_framework is stored on the
        incoming request, as its `_api_request` attribute.
        """
        wrapped = views.APIView.initialize_request

        def replacement(self, request, *args, **kwargs):
            api_request = wrapped(self, request, *args, **kwargs)
            request._api_request = api_request
            return api_request

        views.APIView.initialize_request = replacement


    # Applied once, when this module is imported
    monkey_patch_rest_initialize_request()
    
    
    class ThrottleStatusMiddleware:
        """
        Include useful information regarding throttling in API responses to
        ensure clients can adapt.

        The throttling status is read from the `_throttle_status` attribute of
        the rest_framework request linked to the incoming request (its
        `_api_request` attribute), or of the incoming request itself when there
        is no such link. Nothing is added when no status is found.

        All header values are set as strings.
        """

        def __init__(self, get_response):
            self.get_response = get_response

        def __call__(self, request):
            try:
                response = self.get_response(request)
            except throttling.TooManyRequests:
                # manual throttling in non rest_framework view, we have to return
                # the proper response ourselves
                response = http.HttpResponse(status=429)

            request_to_check = getattr(request, "_api_request", request)
            throttle_status = getattr(request_to_check, "_throttle_status", None)
            if not throttle_status:
                return response

            limit = throttle_status["num_requests"]
            duration = throttle_status["duration"]
            history = throttle_status["history"]

            response["X-RateLimit-Limit"] = str(limit)
            response["X-RateLimit-Scope"] = str(throttle_status["scope"])
            response["X-RateLimit-Remaining"] = str(limit - len(history))
            response["X-RateLimit-Duration"] = str(duration)
            if history:
                now = int(time.time())
                # The last history entry is treated as the oldest request: once it
                # expires, the client can send additional requests
                oldest_request = history[-1]
                remaining = duration - (now - int(oldest_request))
                response["Retry-After"] = str(remaining)
                # The first history entry is treated as the latest request: once it
                # expires, the whole rate limit is reset
                latest_request = history[0]
                remaining = duration - (now - int(latest_request))
                response["X-RateLimit-Reset"] = str(now + remaining)
                response["X-RateLimit-ResetSeconds"] = str(remaining)

            return response
    
    
    class ProfilerMiddleware:
        """
        from https://github.com/omarish/django-cprofile-middleware/blob/master/django_cprofile_middleware/middleware.py
        Simple profile middleware to profile django views. To run it, add ?prof to
        the URL like this:
            http://localhost:8000/view/?prof
        Optionally pass the following to modify the output:
        ?prof-sort => Sort the output by a given metric. Default is cumtime.
            See
            https://docs.python.org/3/library/profile.html#pstats.Stats.sort_stats
            for all sort options.
        ?count => The number of rows to display. Default is 100.
        ?prof-download => Download profile file suitable for visualization. For
            example in snakeviz or RunSnakeRun
        This is adapted from an example found here:
        http://www.slideshare.net/zeeg/django-con-high-performance-django-presentation.

        When profiling, the response of the profiled view is discarded: the
        profiling report (or the profile file) is returned in its place.
        """

        def __init__(self, get_response):
            self.get_response = get_response

        def __call__(self, request):
            if "prof" not in request.GET:
                return self.get_response(request)
            import profile
            import pstats

            profiler = profile.Profile()
            profiler.runcall(self.get_response, request)
            profiler.create_stats()
            if "prof-download" in request.GET:
                import marshal

                output = marshal.dumps(profiler.stats)
                response = http.HttpResponse(
                    output, content_type="application/octet-stream"
                )
                response["Content-Disposition"] = "attachment; filename=view.prof"
                response["Content-Length"] = len(output)
                # Return right away, otherwise the HTML report built below would
                # replace the file the client asked for
                return response

            stream = io.StringIO()
            stats = pstats.Stats(profiler, stream=stream)
            stats.sort_stats(request.GET.get("prof-sort", "cumtime"))
            stats.print_stats(int(request.GET.get("count", 100)))

            # The report contains names such as "<listcomp>" or "<built-in method ...>",
            # which must be escaped to show up in the HTML page
            return http.HttpResponse("<pre>%s</pre>" % html.escape(stream.getvalue()))