diff --git a/api/funkwhale_api/common/management/commands/load_test_data.py b/api/funkwhale_api/common/management/commands/load_test_data.py index 26f787f48df7f6704a40549a9c425c1d976992af..9eab5ff083c7843f19e5846baf830f54d23bcc7c 100644 --- a/api/funkwhale_api/common/management/commands/load_test_data.py +++ b/api/funkwhale_api/common/management/commands/load_test_data.py @@ -46,16 +46,28 @@ def create_local_accounts(factories, count, dependencies): return actors -def create_tagged_tracks(factories, count, dependencies): +def create_taggable_items(dependency): + def inner(factories, count, dependencies): - objs = [] - for track in dependencies["tracks"]: - tag = random.choice(dependencies["tags"]) - objs.append(factories["tags.TaggedItem"].build(content_object=track, tag=tag)) + objs = [] + tagged_objects = dependencies.get( + dependency, list(CONFIG_BY_ID[dependency]["model"].objects.all().only("pk")) + ) + tags = dependencies.get("tags", list(tags_models.Tag.objects.all().only("pk"))) + for i in range(count): + tag = random.choice(tags) + tagged_object = random.choice(tagged_objects) + objs.append( + factories["tags.TaggedItem"].build( + content_object=tagged_object, tag=tag + ) + ) + + return tags_models.TaggedItem.objects.bulk_create( + objs, batch_size=BATCH_SIZE, ignore_conflicts=True + ) - return tags_models.TaggedItem.objects.bulk_create( - objs, batch_size=BATCH_SIZE, ignore_conflicts=True - ) + return inner CONFIG = [ @@ -110,7 +122,10 @@ CONFIG = [ { "id": "track_tags", "model": tags_models.TaggedItem, - "handler": create_tagged_tracks, + "queryset": tags_models.TaggedItem.objects.filter( + content_type__app_label="music", content_type__model="track" + ), + "handler": create_taggable_items("tracks"), "depends_on": [ { "field": "tag", @@ -127,6 +142,52 @@ CONFIG = [ }, ], }, + { + "id": "album_tags", + "model": tags_models.TaggedItem, + "queryset": tags_models.TaggedItem.objects.filter( + content_type__app_label="music", content_type__model="album" + ), + "handler": create_taggable_items("albums"), + "depends_on": [ + { + "field": "tag", + "id": "tags", + "default_factor": 0.1, + "queryset": tags_models.Tag.objects.all(), + "set": False, + }, + { + "field": "content_object", + "id": "albums", + "default_factor": 1, + "set": False, + }, + ], + }, + { + "id": "artist_tags", + "model": tags_models.TaggedItem, + "queryset": tags_models.TaggedItem.objects.filter( + content_type__app_label="music", content_type__model="artist" + ), + "handler": create_taggable_items("artists"), + "depends_on": [ + { + "field": "tag", + "id": "tags", + "default_factor": 0.1, + "queryset": tags_models.Tag.objects.all(), + "set": False, + }, + { + "field": "content_object", + "id": "artists", + "default_factor": 1, + "set": False, + }, + ], + }, ] CONFIG_BY_ID = {c["id"]: c for c in CONFIG} @@ -194,8 +255,8 @@ class Command(BaseCommand): self.stdout.write("\nFinal state of database:\n\n") for row in CONFIG: - model = row["model"] - total = model.objects.all().count() + qs = row.get("queryset", row["model"].objects.all()) + total = qs.count() self.stdout.write("- {} {} objects".format(total, row["id"])) self.stdout.write("")