SQLAlchemy Many-to-Many Relationship: UNIQUE constraint failed

56 Views Asked by At

So, I have a many-to-many SQLAlchemy relationship defined like so:

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint, Table, create_engine
from sqlalchemy.orm import relationship, registry


# Separate registry used only for the imperative mapping of the bridge class;
# the declarative models below hang off ``Base``.
mapper_registry = registry()
Base = declarative_base()


# Association table linking ``video`` and ``category`` rows.
# Both foreign-key columns are part of the composite primary key, so the
# explicit UniqueConstraint repeats a guarantee the key already gives.
bridge_category = Table(
    "bridge_category",
    Base.metadata,
    Column("video_id", ForeignKey("video.id"), primary_key=True),
    Column("category_id", ForeignKey("category.id"), primary_key=True),
    UniqueConstraint("video_id", "category_id"),
)


class BridgeCategory:
    """Plain class mapped imperatively onto the ``bridge_category`` table."""


mapper_registry.map_imperatively(BridgeCategory, bridge_category)


class Video(Base):
    """A video row, linked to many categories through ``bridge_category``."""

    __tablename__ = 'video'

    id = Column(Integer, primary_key=True)
    title = Column(String)

    # Many-to-many side paired with ``Category.videos``.
    categories = relationship(
        "Category",
        back_populates="videos",
        secondary=bridge_category,
    )


class Category(Base):
    """A category row; ``text`` carries a UNIQUE constraint."""

    __tablename__ = 'category'

    id = Column(Integer, primary_key=True)
    text = Column(String, unique=True)

    # Many-to-many side paired with ``Video.categories``.
    videos = relationship(
        "Video",
        back_populates="categories",
        secondary=bridge_category,
    )


# In-memory SQLite database; echo=True logs the emitted SQL.
engine = create_engine('sqlite:///:memory:', echo=True)
Base.metadata.create_all(engine)

Session = sessionmaker(bind=engine)

with Session() as session:
    # Each video builds its own Category(text='red') instance, so three
    # separate 'red' rows are pending when the session is committed.
    v1 = Video(
        title='A',
        categories=[Category(text='blue'), Category(text='red')],
    )
    v2 = Video(
        title='B',
        categories=[Category(text='green'), Category(text='red')],
    )
    v3 = Video(
        title='C',
        categories=[Category(text='grey'), Category(text='red')],
    )
    videos = [v1, v2, v3]

    session.add_all(videos)
    session.commit()

Of course, because of the unique constraint on Category.text, we get the following error.

sqlalchemy.exc.IntegrityError: (sqlite3.IntegrityError) UNIQUE constraint failed: category.text
[SQL: INSERT INTO category (text) VALUES (?) RETURNING id]
[parameters: ('red',)]

I am wondering what the best way of dealing with this is. With my program, I get a lot of video objects, each with a list of unique Category objects. The text collisions happen across all these video objects.

I could loop through all videos, and all categories, forming a Category set, but that's kinda lame. I'd also have to do that with the 12+ other many-to-many relationships my Video object has, and that seems really inefficient.

Is there something like an "insert ignore" flag I can set for this? I haven't been able to find anything online concerning this situation.

1

There is 1 best solution below

0
scrollout On BEST ANSWER

With a lot of help from the maintainer of SQLAlchemy, I came up with a generic implementation that requires hardly any configuration or repeated steps, and that works for a single SA model object containing multiple many-to-many relationships.

from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import UniqueConstraint
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import registry
from sqlalchemy.orm import relationship
from sqlalchemy.orm import RelationshipDirection
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker


# Separate registry used only for the imperative mapping of the bridge
# classes; the declarative models below hang off ``Base``.
mapper_registry = registry()
Base = declarative_base()


# Association table linking ``video`` and ``category`` rows.
# Both foreign-key columns are part of the composite primary key, so the
# explicit UniqueConstraint repeats a guarantee the key already gives.
bridge_category = Table(
    "bridge_category",
    Base.metadata,
    Column("video_id", ForeignKey("video.id"), primary_key=True),
    Column("category_id", ForeignKey("category.id"), primary_key=True),
    UniqueConstraint("video_id", "category_id"),
)


class BridgeCategory:
    """Plain class mapped imperatively onto the ``bridge_category`` table."""


mapper_registry.map_imperatively(BridgeCategory, bridge_category)


# Association table linking ``video`` and ``format`` rows.
# Both foreign-key columns are part of the composite primary key, so the
# explicit UniqueConstraint repeats a guarantee the key already gives.
bridge_format = Table(
    "bridge_format",
    Base.metadata,
    Column("video_id", ForeignKey("video.id"), primary_key=True),
    Column("format_id", ForeignKey("format.id"), primary_key=True),
    UniqueConstraint("video_id", "format_id"),
)


class BridgeFormat:
    """Plain class mapped imperatively onto the ``bridge_format`` table."""


mapper_registry.map_imperatively(BridgeFormat, bridge_format)


class Video(Base):
    """A video row with many-to-many links to categories and formats."""

    __tablename__ = "video"

    id = Column(Integer, primary_key=True)
    title = Column(String)

    # Paired with ``Category.videos`` via the ``bridge_category`` table.
    categories = relationship(
        "Category",
        back_populates="videos",
        secondary=bridge_category,
    )
    # Paired with ``Format.videos`` via the ``bridge_format`` table.
    formats = relationship(
        "Format",
        back_populates="videos",
        secondary=bridge_format,
    )

class Category(Base):
    """A category row; ``text`` carries a UNIQUE constraint."""

    __tablename__ = "category"

    id = Column(Integer, primary_key=True)
    text = Column(String, unique=True)

    # Paired with ``Video.categories`` via the ``bridge_category`` table.
    videos = relationship(
        "Video",
        back_populates="categories",
        secondary=bridge_category,
    )

class Format(Base):
    """A format row; ``text`` carries a UNIQUE constraint."""

    __tablename__ = "format"

    id = Column(Integer, primary_key=True, index=True)
    text = Column(String, unique=True)

    # Paired with ``Video.formats`` via the ``bridge_format`` table.
    videos = relationship(
        "Video",
        secondary=bridge_format,
        back_populates="formats",
    )


def unique_robs(session_or_factory, main_obj, rob_unique_col):
    """Register a listener that de-duplicates many-to-many related objects.

    After this call, whenever an instance of ``main_obj`` is about to be
    attached to a session (the ``before_attach`` event), each of its
    many-to-many list collections is rewritten so that related objects
    ("robs") sharing the same ``rob_unique_col`` value resolve to a single
    instance: one loaded from the database, one already seen through the
    same session, or otherwise the first occurrence encountered.

    Args:
        session_or_factory: Target handed to ``event.listens_for``; in the
            accompanying example this is a ``sessionmaker``.
        main_obj: Class whose instances own the many-to-many collections.
        rob_unique_col: Name of the attribute that uniquely identifies a
            related object (e.g. ``"text"``). The same attribute name is
            used for every many-to-many relationship of ``main_obj``.

    Returns:
        None. The listener is registered as a side effect.
    """

    def _unique_robs(session, robs, rob_name):
        """Return ``robs`` with duplicates replaced by canonical instances.

        The canonical instances are cached per relationship name in
        ``session.info`` so later objects attached to the same session
        reuse them without another query.
        """
        if not robs:
            return robs

        # NOTE: the lookup is built from the first element's class, so the
        # collection is assumed to hold objects of a single mapped class.
        rob_type = type(robs[0])
        unique_col = getattr(rob_type, rob_unique_col)

        # Run the lookup without triggering an autoflush of pending objects.
        with session.no_autoflush:
            known = session.info.setdefault(rob_name, {})

            values = [getattr(rob, rob_unique_col) for rob in robs]
            # dict.fromkeys drops repeated values while keeping their order.
            missing = [val for val in dict.fromkeys(values) if val not in known]

            # Only query when something is not cached yet; this avoids a
            # pointless SELECT with an empty IN () list.
            if missing:
                stmt = select(rob_type).where(unique_col.in_(missing))
                for existing in session.scalars(stmt):
                    known[getattr(existing, rob_unique_col)] = existing

            result = []
            seen = set()
            for rob, val in zip(robs, values):
                # A value repeated inside one collection would otherwise put
                # the same canonical object into the list twice, producing
                # duplicate association rows on flush.
                if val in seen:
                    continue
                seen.add(val)
                # Reuse the cached instance, or make this one canonical.
                result.append(known.setdefault(val, rob))

            return result

    @event.listens_for(session_or_factory, "before_attach", retval=True)
    def before_attach(session, obj):
        """Uniquifies all `main_obj` many-to-many relationships."""
        if not isinstance(obj, main_obj):
            return

        for rel in inspect(obj).mapper.relationships:
            if rel.direction is not RelationshipDirection.MANYTOMANY:
                continue

            rob_name = rel.class_attribute.key
            robs = getattr(obj, rob_name, None)
            # Only list-based collections are rewritten.
            if isinstance(robs, list):
                setattr(obj, rob_name, _unique_robs(session, robs, rob_name))


if __name__ == "__main__":
    # File-backed SQLite database; echo=True logs the emitted SQL.
    engine = create_engine("sqlite:///test.db", echo=True)
    Base.metadata.create_all(engine)

    Session = sessionmaker(bind=engine)
    # Hook every session produced by this factory so Video collections are
    # uniquified on their 'text' attribute before attachment.
    unique_robs(Session, Video, 'text')

    v1 = Video(
        title="A",
        categories=[Category(text="blue"), Category(text="red")],
    )
    v2 = Video(
        title="B",
        categories=[Category(text="green"), Category(text="red")],
        formats=[Format(text='h264')],
    )
    v3 = Video(
        title="C",
        categories=[Category(text="grey"), Category(text="red")],
        formats=[Format(text='h264'), Format(text='vp9')],
    )
    videos = [v1, v2, v3]

    with Session() as session:
        session.add_all(videos)
        session.commit()