This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 1283803844 chore: add unique constraint to tagged_objects (#26654)
1283803844 is described below

commit 12838038449d4c4efdad749a2fe8f6d936e4375e
Author: Maxime Beauchemin <maximebeauche...@gmail.com>
AuthorDate: Fri Jan 19 15:12:54 2024 -0800

    chore: add unique constraint to tagged_objects (#26654)
---
 superset/daos/tag.py                               | 21 ++++-
 superset/migrations/__init__.py                    |  7 ++
 superset/migrations/migration_utils.py             | 46 +++++++++++
 ...96164e3017c6_tagged_object_unique_constraint.py | 89 +++++++++++++++++++++
 ...-01-18_12-12_15a2c68a2e6b_merging_two_heads.py} | 22 +++++
 .../2024-01-19_08-42_1cf8e4344e2b_merging.py}      | 22 +++++
 superset/tags/models.py                            | 93 ++++++++++++++++++----
 tests/integration_tests/tags/api_tests.py          | 22 +++--
 tests/integration_tests/tags/commands_tests.py     | 12 +--
 9 files changed, 303 insertions(+), 31 deletions(-)

diff --git a/superset/daos/tag.py b/superset/daos/tag.py
index e4aa891816..46a1d2538f 100644
--- a/superset/daos/tag.py
+++ b/superset/daos/tag.py
@@ -51,14 +51,28 @@ class TagDAO(BaseDAO[Tag]):
         object_type: ObjectType, object_id: int, tag_names: list[str]
     ) -> None:
         tagged_objects = []
-        for name in tag_names:
+
+        # strip whitespace and de-duplicate the requested tag names
+        clean_tag_names: set[str] = {tag.strip() for tag in tag_names}
+
+        for name in clean_tag_names:
             type_ = TagType.custom
-            tag_name = name.strip()
-            tag = TagDAO.get_by_name(tag_name, type_)
-            tagged_objects.append(
-                TaggedObject(object_id=object_id, object_type=object_type, tag=tag)
-            )
+            tag = TagDAO.get_by_name(name, type_)
 
+            # Only create the association when it does not exist yet: the
+            # (tag_id, object_id, object_type) unique constraint rejects
+            # duplicates, so an unconditional append would fail on commit.
+            existing_tagged_object = (
+                db.session.query(TaggedObject)
+                .filter_by(object_id=object_id, object_type=object_type, tag=tag)
+                .first()
+            )
+
+            if not existing_tagged_object:
+                tagged_objects.append(
+                    TaggedObject(object_id=object_id, object_type=object_type, tag=tag)
+                )
+
         db.session.add_all(tagged_objects)
         db.session.commit()
 
diff --git a/superset/migrations/__init__.py b/superset/migrations/__init__.py
index 13a83393a9..b083f44bb4 100644
--- a/superset/migrations/__init__.py
+++ b/superset/migrations/__init__.py
@@ -14,3 +14,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+import os
+import sys
+
+# HACK: add this directory to sys.path so revisions can `import migration_utils`
+module_dir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(module_dir)
diff --git a/superset/migrations/migration_utils.py 
b/superset/migrations/migration_utils.py
new file mode 100644
index 0000000000..c754669a1a
--- /dev/null
+++ b/superset/migrations/migration_utils.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from alembic.operations import BatchOperations, Operations
+
+naming_convention = {  # passed to batch_alter_table by the helpers below
+    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+    "uq": "uq_%(table_name)s_%(column_0_name)s",
+}
+
+
+def create_unique_constraint(  # add UNIQUE(uix_columns) named index_id
+    op: Operations, index_id: str, table_name: str, uix_columns: list[str]
+) -> None:
+    with op.batch_alter_table(
+        table_name, naming_convention=naming_convention
+    ) as batch_op:
+        batch_op.create_unique_constraint(index_id, uix_columns)
+
+
+def drop_unique_constraint(op: Operations, index_id: str, table_name: str) -> 
None:
+    dialect = op.get_bind().dialect.name
+
+    with op.batch_alter_table(
+        table_name, naming_convention=naming_convention
+    ) as batch_op:
+        if dialect == "mysql":
+            # MySQL requires specifying the type of constraint
+            batch_op.drop_constraint(index_id, type_="unique")
+        else:
+            # For other databases, a standard drop_constraint call is sufficient
+            batch_op.drop_constraint(index_id)
diff --git 
a/superset/migrations/versions/2024-01-17_13-09_96164e3017c6_tagged_object_unique_constraint.py
 
b/superset/migrations/versions/2024-01-17_13-09_96164e3017c6_tagged_object_unique_constraint.py
new file mode 100644
index 0000000000..0b67ad5024
--- /dev/null
+++ 
b/superset/migrations/versions/2024-01-17_13-09_96164e3017c6_tagged_object_unique_constraint.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import enum
+
+import migration_utils as utils
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy import Column, Enum, Integer, MetaData, Table
+from sqlalchemy.sql import and_, func, select
+
+# revision identifiers, used by Alembic.
+revision = "96164e3017c6"
+down_revision = "59a1450b3c10"
+
+
+class ObjectType(enum.Enum):
+    # pylint: disable=invalid-name
+    query = 1
+    chart = 2
+    dashboard = 3
+    dataset = 4
+
+
+# Minimal definition of tagged_object with only the columns this migration uses
+metadata = MetaData()
+tagged_object_table = Table(
+    "tagged_object",
+    metadata,
+    Column("id", Integer, primary_key=True),
+    Column("tag_id", Integer),
+    Column("object_id", Integer),
+    Column("object_type", Enum(ObjectType)),  # migration-local enum defined above
+)
+
+index_id = "uix_tagged_object"
+table_name = "tagged_object"
+uix_columns = ["tag_id", "object_id", "object_type"]
+
+
+def upgrade():
+    bind = op.get_bind()  # Get the database connection bind
+
+    # Reflect existing tables into `metadata`; NOTE(review): nothing below uses them
+    metadata.reflect(bind=bind)
+
+    # Delete duplicates, keeping the lowest id per (tag_id, object_id, object_type)
+    min_id_subquery = (
+        select(
+            [
+                func.min(tagged_object_table.c.id).label("min_id"),
+                tagged_object_table.c.tag_id,
+                tagged_object_table.c.object_id,
+                tagged_object_table.c.object_type,
+            ]
+        )
+        .group_by(
+            tagged_object_table.c.tag_id,
+            tagged_object_table.c.object_id,
+            tagged_object_table.c.object_type,
+        )
+        .alias("min_ids")
+    )
+
+    delete_query = tagged_object_table.delete().where(
+        tagged_object_table.c.id.notin_(select([min_id_subquery.c.min_id]))
+    )
+
+    bind.execute(delete_query)
+
+    # Duplicates are gone, so the unique constraint can now be created
+    utils.create_unique_constraint(op, index_id, table_name, uix_columns)
+
+
+def downgrade():
+    utils.drop_unique_constraint(op, index_id, table_name)
diff --git a/superset/migrations/__init__.py 
b/superset/migrations/versions/2024-01-18_12-12_15a2c68a2e6b_merging_two_heads.py
similarity index 68%
copy from superset/migrations/__init__.py
copy to 
superset/migrations/versions/2024-01-18_12-12_15a2c68a2e6b_merging_two_heads.py
index 13a83393a9..7904d9298d 100644
--- a/superset/migrations/__init__.py
+++ 
b/superset/migrations/versions/2024-01-18_12-12_15a2c68a2e6b_merging_two_heads.py
@@ -14,3 +14,25 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""merging two heads
+
+Revision ID: 15a2c68a2e6b
+Revises: ('96164e3017c6', 'a32e0c4d8646')
+Create Date: 2024-01-18 12:12:52.174742
+
+"""
+
+# revision identifiers, used by Alembic; the tuple down_revision merges two heads.
+revision = "15a2c68a2e6b"
+down_revision = ("96164e3017c6", "a32e0c4d8646")
+
+import sqlalchemy as sa
+from alembic import op
+
+
+def upgrade():
+    pass  # merge-only revision: no schema changes
+
+
+def downgrade():
+    pass  # merge-only revision: no schema changes
diff --git a/superset/migrations/__init__.py 
b/superset/migrations/versions/2024-01-19_08-42_1cf8e4344e2b_merging.py
similarity index 69%
copy from superset/migrations/__init__.py
copy to superset/migrations/versions/2024-01-19_08-42_1cf8e4344e2b_merging.py
index 13a83393a9..9ac2a9b24f 100644
--- a/superset/migrations/__init__.py
+++ b/superset/migrations/versions/2024-01-19_08-42_1cf8e4344e2b_merging.py
@@ -14,3 +14,25 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""merging
+
+Revision ID: 1cf8e4344e2b
+Revises: ('e863403c0c50', '15a2c68a2e6b')
+Create Date: 2024-01-19 08:42:37.694192
+
+"""
+
+# revision identifiers, used by Alembic; the tuple down_revision merges two heads.
+revision = "1cf8e4344e2b"
+down_revision = ("e863403c0c50", "15a2c68a2e6b")
+
+import sqlalchemy as sa
+from alembic import op
+
+
+def upgrade():
+    pass  # merge-only revision: no schema changes
+
+
+def downgrade():
+    pass  # merge-only revision: no schema changes
diff --git a/superset/tags/models.py b/superset/tags/models.py
index bae4417507..1e8ca7de1a 100644
--- a/superset/tags/models.py
+++ b/superset/tags/models.py
@@ -21,10 +21,21 @@ from typing import TYPE_CHECKING
 
 from flask import escape
 from flask_appbuilder import Model
-from sqlalchemy import Column, Enum, ForeignKey, Integer, orm, String, Table, 
Text
+from sqlalchemy import (
+    Column,
+    Enum,
+    exists,
+    ForeignKey,
+    Integer,
+    orm,
+    String,
+    Table,
+    Text,
+)
 from sqlalchemy.engine.base import Connection
 from sqlalchemy.orm import relationship, sessionmaker
 from sqlalchemy.orm.mapper import Mapper
+from sqlalchemy.schema import UniqueConstraint
 
 from superset import security_manager
 from superset.models.helpers import AuditMixinNullable
@@ -110,6 +121,14 @@ class TaggedObject(Model, AuditMixinNullable):
     object_type = Column(Enum(ObjectType))
 
     tag = relationship("Tag", back_populates="objects", overlaps="tags")
+    __table_args__ = (  # keep in sync with index_id in migration 96164e3017c6
+        UniqueConstraint(
+            "tag_id", "object_id", "object_type", name="uix_tagged_object"
+        ),
+    )
+
+    def __str__(self) -> str:
+        return f"<TaggedObject: {self.object_type}:{self.object_id} 
TAG:{self.tag_id}>"
 
 
 def get_tag(name: str, session: orm.Session, type_: TagType) -> Tag:
@@ -138,7 +157,7 @@ def get_object_type(class_name: str) -> ObjectType:
 
 
 class ObjectUpdater:
-    object_type: str | None = None
+    object_type: str = "default"  # presumably overridden by subclasses
 
     @classmethod
     def get_owners_ids(
@@ -146,6 +165,19 @@ class ObjectUpdater:
     ) -> list[int]:
         raise NotImplementedError("Subclass should implement `get_owners_ids`")
 
+    @classmethod
+    def get_owner_tag_ids(  # ids of the `owner:<id>` tags for target's owners
+        cls,
+        session: orm.Session,
+        target: Dashboard | FavStar | Slice | Query | SqlaTable,
+    ) -> set[int]:
+        tag_ids = set()
+        for owner_id in cls.get_owners_ids(target):
+            name = f"owner:{owner_id}"
+            tag = get_tag(name, session, TagType.owner)
+            tag_ids.add(tag.id)
+        return tag_ids
+
     @classmethod
     def _add_owners(
         cls,
@@ -153,10 +185,28 @@ class ObjectUpdater:
         target: Dashboard | FavStar | Slice | Query | SqlaTable,
     ) -> None:
         for owner_id in cls.get_owners_ids(target):
-            name = f"owner:{owner_id}"
+            name: str = f"owner:{owner_id}"
             tag = get_tag(name, session, TagType.owner)
+            cls.add_tag_object_if_not_tagged(
+                session, tag_id=tag.id, object_id=target.id, 
object_type=cls.object_type
+            )
+
+    @classmethod
+    def add_tag_object_if_not_tagged(  # adds the row only if it does not exist yet
+        cls, session: orm.Session, tag_id: int, object_id: int, object_type: 
str
+    ) -> None:
+        # Check whether this exact (tag, object) association already exists
+        exists_query = exists().where(
+            TaggedObject.tag_id == tag_id,
+            TaggedObject.object_id == object_id,
+            TaggedObject.object_type == object_type,
+        )
+        already_tagged = session.query(exists_query).scalar()
+
+        # Add TaggedObject to the session if it isn't already tagged
+        if not already_tagged:
             tagged_object = TaggedObject(
-                tag_id=tag.id, object_id=target.id, object_type=cls.object_type
+                tag_id=tag_id, object_id=object_id, object_type=object_type
             )
             session.add(tagged_object)
 
@@ -173,10 +223,9 @@ class ObjectUpdater:
 
             # add `type:` tags
             tag = get_tag(f"type:{cls.object_type}", session, TagType.type)
-            tagged_object = TaggedObject(
-                tag_id=tag.id, object_id=target.id, object_type=cls.object_type
+            cls.add_tag_object_if_not_tagged(
+                session, tag_id=tag.id, object_id=target.id, 
object_type=cls.object_type
             )
-            session.add(tagged_object)
             session.commit()
 
     @classmethod
@@ -187,23 +236,35 @@ class ObjectUpdater:
         target: Dashboard | FavStar | Slice | Query | SqlaTable,
     ) -> None:
         with Session(bind=connection) as session:
-            # delete current `owner:` tags
-            query = (
-                session.query(TaggedObject.id)
+            # Fetch the object's current `owner:` TaggedObject rows
+            existing_tags = (
+                session.query(TaggedObject)
                 .join(Tag)
                 .filter(
                     TaggedObject.object_type == cls.object_type,
                     TaggedObject.object_id == target.id,
                     Tag.type == TagType.owner,
                 )
+                .all()
             )
-            ids = [row[0] for row in query]
-            
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
-                synchronize_session=False
-            )
+            existing_owner_tag_ids = {tag.tag_id for tag in existing_tags}
 
-            # add `owner:` tags
-            cls._add_owners(session, target)
+            # Ids of the `owner:` tags the object should carry now
+            new_owner_tag_ids = cls.get_owner_tag_ids(session, target)
+
+            # Tag owners that are not tagged yet
+            for owner_tag_id in new_owner_tag_ids - existing_owner_tag_ids:
+                tagged_object = TaggedObject(
+                    tag_id=owner_tag_id,
+                    object_id=target.id,
+                    object_type=cls.object_type,
+                )
+                session.add(tagged_object)
+
+            # Drop `owner:` tags that no longer match a current owner
+            for tag in existing_tags:
+                if tag.tag_id not in new_owner_tag_ids:
+                    session.delete(tag)
             session.commit()
 
     @classmethod
diff --git a/tests/integration_tests/tags/api_tests.py 
b/tests/integration_tests/tags/api_tests.py
index 863288a3e7..d79261c2a3 100644
--- a/tests/integration_tests/tags/api_tests.py
+++ b/tests/integration_tests/tags/api_tests.py
@@ -577,15 +577,25 @@ class TestTagApi(SupersetTestCase):
         result = TagDAO.get_tagged_objects_for_tags(tags, ["chart"])
         assert len(result) == 1
 
-        tagged_objects = db.session.query(TaggedObject).filter(
-            TaggedObject.object_id == dashboard.id,
-            TaggedObject.object_type == ObjectType.dashboard,
+        tagged_objects = (
+            db.session.query(TaggedObject)
+            .join(Tag)
+            .filter(
+                TaggedObject.object_id == dashboard.id,
+                TaggedObject.object_type == ObjectType.dashboard,
+                Tag.type == TagType.custom,  # skip automatic owner:/type: tags
+            )
         )
         assert tagged_objects.count() == 2
 
-        tagged_objects = db.session.query(TaggedObject).filter(
-            TaggedObject.object_id == chart.id,
-            TaggedObject.object_type == ObjectType.chart,
+        tagged_objects = (
+            db.session.query(TaggedObject)
+            .join(Tag)
+            .filter(
+                TaggedObject.object_id == chart.id,
+                TaggedObject.object_type == ObjectType.chart,
+                Tag.type == TagType.custom,  # skip automatic owner:/type: tags
+            )
         )
         assert tagged_objects.count() == 2
 
diff --git a/tests/integration_tests/tags/commands_tests.py 
b/tests/integration_tests/tags/commands_tests.py
index 83762f8f6e..3644c076e6 100644
--- a/tests/integration_tests/tags/commands_tests.py
+++ b/tests/integration_tests/tags/commands_tests.py
@@ -63,7 +63,7 @@ class TestCreateCustomTagCommand(SupersetTestCase):
         example_dashboard = (
             db.session.query(Dashboard).filter_by(slug="world_health").one()
         )
-        example_tags = ["create custom tag example 1", "create custom tag 
example 2"]
+        example_tags = {"create custom tag example 1", "create custom tag 
example 2"}
         command = CreateCustomTagCommand(
             ObjectType.dashboard.value, example_dashboard.id, example_tags
         )
@@ -78,7 +78,7 @@ class TestCreateCustomTagCommand(SupersetTestCase):
             )
             .all()
         )
-        assert example_tags == [tag.name for tag in created_tags]
+        assert example_tags == {tag.name for tag in created_tags}
 
         # cleanup
         tags = db.session.query(Tag).filter(Tag.name.in_(example_tags))
@@ -99,7 +99,7 @@ class TestDeleteTagsCommand(SupersetTestCase):
             .filter_by(dashboard_title="World Bank's Data")
             .one()
         )
-        example_tags = ["create custom tag example 1", "create custom tag 
example 2"]
+        example_tags = {"create custom tag example 1", "create custom tag 
example 2"}
         command = CreateCustomTagCommand(
             ObjectType.dashboard.value, example_dashboard.id, example_tags
         )
@@ -115,7 +115,7 @@ class TestDeleteTagsCommand(SupersetTestCase):
             .order_by(Tag.name)
             .all()
         )
-        assert example_tags == [tag.name for tag in created_tags]
+        assert example_tags == {tag.name for tag in created_tags}
 
         command = DeleteTagsCommand(example_tags)
         command.run()
@@ -132,7 +132,7 @@ class TestDeleteTaggedObjectCommand(SupersetTestCase):
         example_dashboard = (
             db.session.query(Dashboard).filter_by(slug="world_health").one()
         )
-        example_tags = ["create custom tag example 1", "create custom tag 
example 2"]
+        example_tags = {"create custom tag example 1", "create custom tag 
example 2"}
         command = CreateCustomTagCommand(
             ObjectType.dashboard.value, example_dashboard.id, example_tags
         )
@@ -152,7 +152,7 @@ class TestDeleteTaggedObjectCommand(SupersetTestCase):
         command = DeleteTaggedObjectCommand(
             object_type=ObjectType.dashboard.value,
             object_id=example_dashboard.id,
-            tag=example_tags[0],
+            tag=list(example_tags)[0],  # arbitrary element: sets are unordered
         )
         command.run()
         tagged_objects = (

Reply via email to