Skip to content
Snippets Groups Projects
Commit 2e86ad40 authored by Eliot Berriot's avatar Eliot Berriot
Browse files

Merge branch '994-front-fts' into 'develop'

Resolve "Switch to proper full-text-search system"

See merge request funkwhale/funkwhale!976
parents 8556f137 20311344
No related branches found
No related tags found
No related merge requests found
......@@ -84,11 +84,21 @@ def get_fts_query(query_string, fts_fields=["body_text"], model=None):
fk_field = model._meta.get_field(fk_field_name)
related_model = fk_field.related_model
subquery = related_model.objects.filter(
**{lookup: SearchQuery(query_string, search_type="raw")}
**{
lookup: SearchQuery(
query_string, search_type="raw", config="english_nostop"
)
}
).values_list("pk", flat=True)
new_query = Q(**{"{}__in".format(fk_field_name): list(subquery)})
else:
new_query = Q(**{field: SearchQuery(query_string, search_type="raw")})
new_query = Q(
**{
field: SearchQuery(
query_string, search_type="raw", config="english_nostop"
)
}
)
query = utils.join_queries_or(query, new_query)
return query
......
# Generated by Django 2.2.7 on 2019-12-16 15:06
import django.contrib.postgres.search
import django.contrib.postgres.indexes
from django.db import migrations, models
import django.db.models.deletion
from django.db import connection
FIELDS = {
"music.Artist": {
"fields": [
'name',
],
"trigger_name": "music_artist_update_body_text"
},
"music.Track": {
"fields": ['title', 'copyright'],
"trigger_name": "music_track_update_body_text"
},
"music.Album": {
"fields": ['title'],
"trigger_name": "music_album_update_body_text"
},
}
def populate_body_text(apps, schema_editor):
for label, search_config in FIELDS.items():
model = apps.get_model(*label.split('.'))
print('Updating search index for {}…'.format(model.__name__))
vector = django.contrib.postgres.search.SearchVector(*search_config['fields'], config="public.english_nostop")
model.objects.update(body_text=vector)
def rewind(apps, schema_editor):
pass
def setup_dictionary(apps, schema_editor):
cursor = connection.cursor()
statements = [
"""
CREATE TEXT SEARCH DICTIONARY english_stem_nostop (
Template = snowball
, Language = english
);
""",
"CREATE TEXT SEARCH CONFIGURATION public.english_nostop ( COPY = pg_catalog.english );",
"ALTER TEXT SEARCH CONFIGURATION public.english_nostop ALTER MAPPING FOR asciiword, asciihword, hword_asciipart, hword, hword_part, word WITH english_stem_nostop;",
]
print('Create non stopword dictionary and search configuration…')
for statement in statements:
cursor.execute(statement)
for label, search_config in FIELDS.items():
model = apps.get_model(*label.split('.'))
table = model._meta.db_table
print('Dropping database trigger {} on {}…'.format(search_config['trigger_name'], table))
sql = """
DROP TRIGGER IF EXISTS {trigger_name} ON {table}
""".format(
trigger_name=search_config['trigger_name'],
table=table,
)
cursor.execute(sql)
print('Creating database trigger {} on {}…'.format(search_config['trigger_name'], table))
sql = """
CREATE TRIGGER {trigger_name}
BEFORE INSERT OR UPDATE
ON {table}
FOR EACH ROW
EXECUTE PROCEDURE
tsvector_update_trigger(body_text, 'public.english_nostop', {fields})
""".format(
trigger_name=search_config['trigger_name'],
table=table,
fields=', '.join(search_config['fields']),
)
cursor.execute(sql)
def rewind_dictionary(apps, schema_editor):
cursor = connection.cursor()
for label, search_config in FIELDS.items():
model = apps.get_model(*label.split('.'))
table = model._meta.db_table
class Migration(migrations.Migration):
dependencies = [
('music', '0044_full_text_search'),
]
operations = [
migrations.RunPython(setup_dictionary, rewind_dictionary),
migrations.RunPython(populate_body_text, rewind),
]
......@@ -95,7 +95,22 @@ def refetch_obj(obj, queryset):
return obj
class ArtistViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelViewSet):
class HandleInvalidSearch(object):
def list(self, *args, **kwargs):
try:
return super().list(*args, **kwargs)
except django.db.utils.ProgrammingError as e:
if "in tsquery:" in str(e):
return Response({"detail": "Invalid query"}, status=400)
else:
raise
class ArtistViewSet(
HandleInvalidSearch,
common_views.SkipFilterForGetObject,
viewsets.ReadOnlyModelViewSet,
):
queryset = (
models.Artist.objects.all()
.prefetch_related("attributed_to")
......@@ -149,7 +164,11 @@ class ArtistViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelV
)
class AlbumViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelViewSet):
class AlbumViewSet(
HandleInvalidSearch,
common_views.SkipFilterForGetObject,
viewsets.ReadOnlyModelViewSet,
):
queryset = (
models.Album.objects.all()
.order_by("-creation_date")
......@@ -254,7 +273,11 @@ class LibraryViewSet(
return Response(serializer.data)
class TrackViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelViewSet):
class TrackViewSet(
HandleInvalidSearch,
common_views.SkipFilterForGetObject,
viewsets.ReadOnlyModelViewSet,
):
"""
A simple ViewSet for viewing and editing accounts.
"""
......
......@@ -16,7 +16,7 @@ def test_body_text_trigger_creation(factory_name, fields, factories):
obj.refresh_from_db()
cursor = connection.cursor()
sql = """
SELECT to_tsvector('{indexed_text}')
SELECT to_tsvector('english_nostop', '{indexed_text}')
""".format(
indexed_text=" ".join([getattr(obj, f) for f in fields if getattr(obj, f)]),
)
......@@ -41,7 +41,7 @@ def test_body_text_trigger_updaten(factory_name, fields, factories, faker):
obj.refresh_from_db()
cursor = connection.cursor()
sql = """
SELECT to_tsvector('{indexed_text}')
SELECT to_tsvector('english_nostop', '{indexed_text}')
""".format(
indexed_text=" ".join([getattr(obj, f) for f in fields if getattr(obj, f)]),
)
......
......@@ -1238,3 +1238,21 @@ def test_search_get_fts_advanced(settings, logged_in_api_client, factories):
assert response.status_code == 200
assert response.data == expected
def test_search_get_fts_stop_words(settings, logged_in_api_client, factories):
settings.USE_FULL_TEXT_SEARCH = True
artist = factories["music.Artist"](name="she")
factories["music.Artist"]()
url = reverse("api:v1:search")
expected = {
"artists": [serializers.ArtistWithAlbumsSerializer(artist).data],
"albums": [],
"tracks": [],
"tags": [],
}
response = logged_in_api_client.get(url, {"q": "sh"})
assert response.status_code == 200
assert response.data == expected
......@@ -183,6 +183,9 @@ export default {
).then(response => {
self.result = response.data
self.isLoading = false
}, error => {
self.result = null
self.isLoading = false
})
}, 500),
selectPage: function(page) {
......
......@@ -154,7 +154,7 @@ export default {
let params = {
page: this.page,
page_size: this.paginateBy,
name__icontains: this.query,
q: this.query,
ordering: this.getOrderingAsString(),
playable: "true",
tag: this.tags,
......@@ -171,6 +171,9 @@ export default {
).then(response => {
self.result = response.data
self.isLoading = false
}, error => {
self.result = null
self.isLoading = false
})
}, 500),
selectPage: function(page) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment