diff --git a/app/blueprints/collections/__init__.py b/app/blueprints/collections/__init__.py index f3beab40..f7b92c7a 100644 --- a/app/blueprints/collections/__init__.py +++ b/app/blueprints/collections/__init__.py @@ -108,9 +108,12 @@ def create_edit(author=None, name=None): form = CollectionForm(formdata=request.form, obj=collection) - initial_package = None + initial_packages = [] if "package" in request.args: - initial_package = Package.get_by_key(request.args["package"]) + for package_id in request.args.getlist("package"): + package = Package.get_by_key(package_id) + if package: + initial_packages.append(package) if request.method == "GET": # HACK: fix bug in wtforms @@ -123,7 +126,7 @@ def create_edit(author=None, name=None): form.name = None if form.validate_on_submit(): - ret = handle_create_edit(collection, form, initial_package, author) + ret = handle_create_edit(collection, form, initial_packages, author) if ret: return ret @@ -132,7 +135,7 @@ def create_edit(author=None, name=None): def handle_create_edit(collection: Collection, form: CollectionForm, - initial_package: typing.Optional[Package], author: User): + initial_packages: typing.List[Package], author: User): severity = AuditSeverity.NORMAL if author == current_user else AuditSeverity.EDITOR name = form.name.data if collection else regex_invalid_chars.sub("", form.title.data.lower().replace(" ", "_")) @@ -157,9 +160,9 @@ def handle_create_edit(collection: Collection, form: CollectionForm, collection.name = name db.session.add(collection) - if initial_package: + for package in initial_packages: link = CollectionPackage() - link.package = initial_package + link.package = package link.collection = collection link.order = len(collection.items) db.session.add(link) diff --git a/app/models/packages.py b/app/models/packages.py index a5ea10d7..42303654 100644 --- a/app/models/packages.py +++ b/app/models/packages.py @@ -21,6 +21,7 @@ import enum from flask import url_for from flask_babel import lazy_gettext from flask_sqlalchemy import BaseQuery +from sqlalchemy import or_ from sqlalchemy_searchable import SearchQueryMixin from sqlalchemy_utils.types import TSVectorType from sqlalchemy.dialects.postgresql import insert @@ -473,7 +474,13 @@ class Package(db.Model): if len(parts) != 2: return None - return Package.query.filter(Package.name == parts[1], Package.author.has(username=parts[0])).first() + name = parts[1] + if name.endswith("_game"): + name = name[:-5] + + return Package.query.filter( + or_(Package.name == name, Package.name == name + "_game"), + Package.author.has(username=parts[0])).first() def get_id(self): return "{}/{}".format(self.author.username, self.name)