From 20311344d70cb656596fea2f90feef82b477ddc4 Mon Sep 17 00:00:00 2001
From: Eliot Berriot <contact@eliotberriot.com>
Date: Wed, 18 Dec 2019 15:08:58 +0100
Subject: [PATCH] Resolve "Switch to proper full-text-search system"

---
 api/funkwhale_api/common/search.py            | 14 ++-
 .../0045_full_text_search_stop_words.py       | 94 +++++++++++++++++++
 api/funkwhale_api/music/views.py              | 29 +++++-
 api/tests/music/test_triggers.py              |  4 +-
 api/tests/music/test_views.py                 | 18 ++++
 front/src/components/library/Albums.vue       |  3 +
 front/src/components/library/Artists.vue      |  5 +-
 7 files changed, 159 insertions(+), 8 deletions(-)
 create mode 100644 api/funkwhale_api/music/migrations/0045_full_text_search_stop_words.py

diff --git a/api/funkwhale_api/common/search.py b/api/funkwhale_api/common/search.py
index 6a4091d7d..deb2607f9 100644
--- a/api/funkwhale_api/common/search.py
+++ b/api/funkwhale_api/common/search.py
@@ -84,11 +84,21 @@ def get_fts_query(query_string, fts_fields=["body_text"], model=None):
             fk_field = model._meta.get_field(fk_field_name)
             related_model = fk_field.related_model
             subquery = related_model.objects.filter(
-                **{lookup: SearchQuery(query_string, search_type="raw")}
+                **{
+                    lookup: SearchQuery(
+                        query_string, search_type="raw", config="english_nostop"
+                    )
+                }
             ).values_list("pk", flat=True)
             new_query = Q(**{"{}__in".format(fk_field_name): list(subquery)})
         else:
-            new_query = Q(**{field: SearchQuery(query_string, search_type="raw")})
+            new_query = Q(
+                **{
+                    field: SearchQuery(
+                        query_string, search_type="raw", config="english_nostop"
+                    )
+                }
+            )
         query = utils.join_queries_or(query, new_query)
 
     return query
diff --git a/api/funkwhale_api/music/migrations/0045_full_text_search_stop_words.py b/api/funkwhale_api/music/migrations/0045_full_text_search_stop_words.py
new file mode 100644
index 000000000..cae2f2bfd
--- /dev/null
+++ b/api/funkwhale_api/music/migrations/0045_full_text_search_stop_words.py
@@ -0,0 +1,94 @@
+# Generated by Django 2.2.7 on 2019-12-16 15:06
+
+import django.contrib.postgres.search
+import django.contrib.postgres.indexes
+from django.db import migrations, models
+import django.db.models.deletion
+from django.db import connection
+
+FIELDS = {  # model label -> text columns fed into body_text, plus the trigger that maintains it
+    "music.Artist": {
+        "fields": [
+            'name',
+        ],
+        "trigger_name": "music_artist_update_body_text"
+    },
+    "music.Track": {
+        "fields": ['title', 'copyright'],
+        "trigger_name": "music_track_update_body_text"
+    },
+    "music.Album": {
+        "fields": ['title'],
+        "trigger_name": "music_album_update_body_text"
+    },
+}
+
+def populate_body_text(apps, schema_editor):
+    """Recompute body_text for every model listed in FIELDS with the public.english_nostop config."""
+    for model_label, settings in FIELDS.items():
+        historical_model = apps.get_model(*model_label.split('.'))
+        print('Updating search index for {}…'.format(historical_model.__name__))
+        historical_model.objects.update(body_text=django.contrib.postgres.search.SearchVector(*settings['fields'], config="public.english_nostop"))
+
+def rewind(apps, schema_editor):
+    """No-op reverse step paired with populate_body_text; body_text is left untouched."""
+
+def setup_dictionary(apps, schema_editor):
+    """Create the stop-word-free search configuration and rebuild the body_text triggers.
+
+    Statements run on ``schema_editor.connection`` (the database being migrated) and
+    any leftover dictionary/configuration is dropped first so the step can be re-run.
+    """
+    statements = [
+        # The reverse step keeps these objects, so remove leftovers before recreating them.
+        "DROP TEXT SEARCH CONFIGURATION IF EXISTS public.english_nostop;",
+        "DROP TEXT SEARCH DICTIONARY IF EXISTS english_stem_nostop;",
+        # Snowball stemmer declared without a StopWords option: no word is filtered out.
+        "CREATE TEXT SEARCH DICTIONARY english_stem_nostop ( Template = snowball, Language = english );",
+        "CREATE TEXT SEARCH CONFIGURATION public.english_nostop ( COPY = pg_catalog.english );",
+        "ALTER TEXT SEARCH CONFIGURATION public.english_nostop ALTER MAPPING FOR asciiword, asciihword, hword_asciipart, hword, hword_part, word WITH english_stem_nostop;",
+    ]
+    drop_trigger = "DROP TRIGGER IF EXISTS {trigger_name} ON {table}"
+    create_trigger = """
+        CREATE TRIGGER {trigger_name}
+            BEFORE INSERT OR UPDATE
+            ON {table}
+            FOR EACH ROW
+            EXECUTE PROCEDURE
+                tsvector_update_trigger(body_text, 'public.english_nostop', {fields})
+    """
+    # The context manager closes the cursor even when a statement fails.
+    with schema_editor.connection.cursor() as cursor:
+        print('Create non stopword dictionary and search configuration…')
+        for statement in statements:
+            cursor.execute(statement)
+
+        for label, search_config in FIELDS.items():
+            table = apps.get_model(*label.split('.'))._meta.db_table
+            trigger_name = search_config['trigger_name']
+            # Drop before create: presumably the trigger already exists from 0044_full_text_search.
+            print('Dropping database trigger {} on {}…'.format(trigger_name, table))
+            cursor.execute(drop_trigger.format(trigger_name=trigger_name, table=table))
+            print('Creating database trigger {} on {}…'.format(trigger_name, table))
+            cursor.execute(
+                create_trigger.format(
+                    trigger_name=trigger_name, table=table, fields=', '.join(search_config['fields'])
+                )
+            )
+
+def rewind_dictionary(apps, schema_editor):
+    """Reverse step for setup_dictionary; makes no database changes.
+
+    The english_nostop dictionary, configuration and triggers all stay in place.
+    """
+
+class Migration(migrations.Migration):
+    # Moves music full-text search onto the public.english_nostop configuration.
+    dependencies = [
+        ('music', '0044_full_text_search'),
+    ]
+    # Order matters: populate_body_text uses the configuration that setup_dictionary creates.
+    operations = [
+        migrations.RunPython(setup_dictionary, rewind_dictionary),
+        migrations.RunPython(populate_body_text, rewind),
+    ]
diff --git a/api/funkwhale_api/music/views.py b/api/funkwhale_api/music/views.py
index a85607a69..39c327e37 100644
--- a/api/funkwhale_api/music/views.py
+++ b/api/funkwhale_api/music/views.py
@@ -95,7 +95,22 @@ def refetch_obj(obj, queryset):
     return obj
 
 
-class ArtistViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelViewSet):
+class HandleInvalidSearch(object):
+    """Mixin whose list() answers HTTP 400 when the database reports a tsquery error."""
+    def list(self, *args, **kwargs):
+        try:
+            return super().list(*args, **kwargs)
+        except django.db.utils.ProgrammingError as exc:
+            if "in tsquery:" not in str(exc):
+                raise  # unrelated database error: let it propagate unchanged
+            return Response({"detail": "Invalid query"}, status=400)
+
+
+class ArtistViewSet(
+    HandleInvalidSearch,
+    common_views.SkipFilterForGetObject,
+    viewsets.ReadOnlyModelViewSet,
+):
     queryset = (
         models.Artist.objects.all()
         .prefetch_related("attributed_to")
@@ -149,7 +164,11 @@ class ArtistViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelV
     )
 
 
-class AlbumViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelViewSet):
+class AlbumViewSet(
+    HandleInvalidSearch,
+    common_views.SkipFilterForGetObject,
+    viewsets.ReadOnlyModelViewSet,
+):
     queryset = (
         models.Album.objects.all()
         .order_by("-creation_date")
@@ -254,7 +273,11 @@ class LibraryViewSet(
         return Response(serializer.data)
 
 
-class TrackViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelViewSet):
+class TrackViewSet(
+    HandleInvalidSearch,
+    common_views.SkipFilterForGetObject,
+    viewsets.ReadOnlyModelViewSet,
+):
     """
     A simple ViewSet for viewing and editing accounts.
     """
diff --git a/api/tests/music/test_triggers.py b/api/tests/music/test_triggers.py
index e8f5f53a9..62f9ae81e 100644
--- a/api/tests/music/test_triggers.py
+++ b/api/tests/music/test_triggers.py
@@ -16,7 +16,7 @@ def test_body_text_trigger_creation(factory_name, fields, factories):
     obj.refresh_from_db()
     cursor = connection.cursor()
     sql = """
-        SELECT to_tsvector('{indexed_text}')
+        SELECT to_tsvector('english_nostop', '{indexed_text}')
     """.format(
         indexed_text=" ".join([getattr(obj, f) for f in fields if getattr(obj, f)]),
     )
@@ -41,7 +41,7 @@ def test_body_text_trigger_updaten(factory_name, fields, factories, faker):
     obj.refresh_from_db()
     cursor = connection.cursor()
     sql = """
-        SELECT to_tsvector('{indexed_text}')
+        SELECT to_tsvector('english_nostop', '{indexed_text}')
     """.format(
         indexed_text=" ".join([getattr(obj, f) for f in fields if getattr(obj, f)]),
     )
diff --git a/api/tests/music/test_views.py b/api/tests/music/test_views.py
index 7eb71c6d3..744edd2d0 100644
--- a/api/tests/music/test_views.py
+++ b/api/tests/music/test_views.py
@@ -1238,3 +1238,21 @@ def test_search_get_fts_advanced(settings, logged_in_api_client, factories):
 
     assert response.status_code == 200
     assert response.data == expected
+
+
+def test_search_get_fts_stop_words(settings, logged_in_api_client, factories):
+    """An artist named "she" is the only search hit for the query "sh"."""
+    settings.USE_FULL_TEXT_SEARCH = True
+    stop_word_artist = factories["music.Artist"](name="she")
+    factories["music.Artist"]()  # second artist, expected to stay out of the results
+    expected = {
+        "artists": [serializers.ArtistWithAlbumsSerializer(stop_word_artist).data],
+        "albums": [],
+        "tracks": [],
+        "tags": [],
+    }
+
+    response = logged_in_api_client.get(reverse("api:v1:search"), {"q": "sh"})
+
+    assert response.status_code == 200
+    assert response.data == expected
diff --git a/front/src/components/library/Albums.vue b/front/src/components/library/Albums.vue
index 5ce51be1b..3964cbad3 100644
--- a/front/src/components/library/Albums.vue
+++ b/front/src/components/library/Albums.vue
@@ -183,6 +183,9 @@ export default {
       ).then(response => {
         self.result = response.data
         self.isLoading = false
+      }, error => {
+        self.result = null
+        self.isLoading = false
       })
     }, 500),
     selectPage: function(page) {
diff --git a/front/src/components/library/Artists.vue b/front/src/components/library/Artists.vue
index f0dcd8a11..905c33d37 100644
--- a/front/src/components/library/Artists.vue
+++ b/front/src/components/library/Artists.vue
@@ -154,7 +154,7 @@ export default {
       let params = {
         page: this.page,
         page_size: this.paginateBy,
-        name__icontains: this.query,
+        q: this.query,
         ordering: this.getOrderingAsString(),
         playable: "true",
         tag: this.tags,
@@ -171,6 +171,9 @@ export default {
       ).then(response => {
         self.result = response.data
         self.isLoading = false
+      }, error => {
+        self.result = null
+        self.isLoading = false
       })
     }, 500),
     selectPage: function(page) {
-- 
GitLab