Here is an updated version of the original that now works. It implements Mike's
suggestion of using single-table inheritance for the association object.

Thank you again!


from enum import Enum as eEnum

from sqlalchemy import Column, Enum, ForeignKey, Integer, Text, create_engine
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref, relationship, scoped_session, sessionmaker

# Declarative base class: every mapped class in this example subclasses it,
# and its ``metadata`` collects their tables for ``create_all`` below.
# NOTE(review): SQLAlchemy 1.4 moved declarative_base to sqlalchemy.orm and
# 2.0 deprecates the sqlalchemy.ext.declarative location — consider
# switching (verify against the installed SQLAlchemy version).
Base = declarative_base()

# The roles a worker can hold at an employer, in declaration order
# CONTRACTOR (0), EMPLOYEE (1), PART_TIME (2). Defined with the functional
# Enum API; the association table's discriminator column stores these members.
WorkerType = eEnum(
    "WorkerType",
    [("CONTRACTOR", 0), ("EMPLOYEE", 1), ("PART_TIME", 2)],
    module=__name__,
)

class Employer(Base):
    """An organisation that engages workers in one of three roles.

    Each role attribute is an association proxy over a collection of
    association objects (``employer_workers_*``, created by the backrefs on
    the ``EmployerWorker*`` subclasses). Reading it yields the ``Worker``
    objects behind those rows. Appending a ``Worker`` wraps it in the
    association subclass for that role.
    """

    __tablename__ = "employer"

    id = Column(Integer, primary_key=True)
    name = Column(Text)

    # The creators are lambdas so the association classes, which are defined
    # later in the module, are looked up only when an append happens.
    contractor = association_proxy(
        target_collection="employer_workers_contractor",
        attr="worker",
        creator=lambda worker: EmployerWorkerContractor(worker=worker),
    )

    employee = association_proxy(
        target_collection="employer_workers_employee",
        attr="worker",
        creator=lambda worker: EmployerWorkerEmployee(worker=worker),
    )

    part_time = association_proxy(
        target_collection="employer_workers_part_time",
        attr="worker",
        creator=lambda worker: EmployerWorkerPartTime(worker=worker),
    )


class EmployerWorkerAssociation(Base):
    """Base association object linking one Employer to one Worker.

    The role-specific subclasses below declare no table of their own, so all
    of them share this ``employer_to_worker`` table (single-table
    inheritance). The ``worker_type`` column records which subclass a row
    belongs to.
    """

    __tablename__ = "employer_to_worker"

    # Composite primary key over all three columns: a given employer/worker
    # pair can appear at most once per WorkerType, so one worker may hold
    # several different roles at the same employer.
    employer_id = Column(ForeignKey("employer.id"), primary_key=True)
    worker_id = Column(ForeignKey("worker.id"), primary_key=True)
    # Discriminator column (see "polymorphic_on" below); holds a WorkerType.
    worker_type = Column(Enum(WorkerType), primary_key=True)

    # NOTE(review): the base identity is a plain string, not a WorkerType
    # member, so it presumably cannot be persisted in the Enum(WorkerType)
    # discriminator. Instantiating this base class directly looks
    # unsupported; create one of the subclasses instead. Confirm.
    __mapper_args__ = {
        "polymorphic_identity": "EmployerWorkerAssociation",
        "polymorphic_on": "worker_type"
    }


class EmployerWorkerContractor(EmployerWorkerAssociation):
    """Association row: the worker is a contractor for the employer.

    Stored in the shared ``employer_to_worker`` table with
    ``worker_type == WorkerType.CONTRACTOR``.
    """

    # The backref adds ``Employer.employer_workers_contractor``. Per
    # SQLAlchemy's cascade rules, "all, delete-orphan" deletes a row of this
    # class when its employer is deleted or when the row is removed from that
    # collection.
    employer = relationship(
        "Employer",
        backref=backref(
            "employer_workers_contractor", cascade="all, delete-orphan"
        ),
    )

    # The backref adds ``Worker.worker_employers_contractor`` (default
    # cascade).
    worker = relationship(
        "Worker",
        backref=backref("worker_employers_contractor"),
    )

    __mapper_args__ = {"polymorphic_identity": WorkerType.CONTRACTOR}


class EmployerWorkerEmployee(EmployerWorkerAssociation):
    """Association row: the worker is an employee of the employer.

    Stored in the shared ``employer_to_worker`` table with
    ``worker_type == WorkerType.EMPLOYEE``.
    """

    # The backref adds ``Employer.employer_workers_employee``. Per
    # SQLAlchemy's cascade rules, "all, delete-orphan" deletes a row of this
    # class when its employer is deleted or when the row is removed from that
    # collection.
    employer = relationship(
        "Employer",
        backref=backref(
            "employer_workers_employee", cascade="all, delete-orphan"
        ),
    )

    # The backref adds ``Worker.worker_employers_employee`` (default
    # cascade).
    worker = relationship(
        "Worker",
        backref=backref("worker_employers_employee"),
    )

    __mapper_args__ = {"polymorphic_identity": WorkerType.EMPLOYEE}


class EmployerWorkerPartTime(EmployerWorkerAssociation):
    """Association row: the worker works part-time for the employer.

    Stored in the shared ``employer_to_worker`` table with
    ``worker_type == WorkerType.PART_TIME``.
    """

    # The backref adds ``Employer.employer_workers_part_time``. Per
    # SQLAlchemy's cascade rules, "all, delete-orphan" deletes a row of this
    # class when its employer is deleted or when the row is removed from that
    # collection.
    employer = relationship(
        "Employer",
        backref=backref(
            "employer_workers_part_time", cascade="all, delete-orphan"
        ),
    )

    # The backref adds ``Worker.worker_employers_part_time`` (default
    # cascade).
    worker = relationship(
        "Worker",
        backref=backref("worker_employers_part_time"),
    )

    __mapper_args__ = {"polymorphic_identity": WorkerType.PART_TIME}




class Worker(Base):
    """A person who can work for employers in one of three roles.

    Each role attribute is an association proxy over a collection of
    association objects (``worker_employers_*``, created by the backrefs on
    the ``EmployerWorker*`` subclasses). Reading it yields the ``Employer``
    objects behind those rows. Appending an ``Employer`` wraps it in the
    association subclass for that role.
    """

    __tablename__ = "worker"
    id = Column(Integer, primary_key=True)
    name = Column(Text)

    # Each creator wraps a bare Employer in the association subclass that
    # records the corresponding role.
    contractor = association_proxy(
        target_collection="worker_employers_contractor",
        attr="employer",
        creator=lambda employer: EmployerWorkerContractor(employer=employer),
    )

    employee = association_proxy(
        target_collection="worker_employers_employee",
        attr="employer",
        creator=lambda employer: EmployerWorkerEmployee(employer=employer),
    )

    part_time = association_proxy(
        target_collection="worker_employers_part_time",
        attr="employer",
        creator=lambda employer: EmployerWorkerPartTime(employer=employer),
    )



def _role_counts(party):
    """Return "c=<n> e=<n> p=<n>" for *party*'s three role proxies.

    *party* is an Employer or a Worker; both expose ``contractor``,
    ``employee`` and ``part_time`` association proxies.
    """
    return (
        f"c={len(party.contractor)} e={len(party.employee)} "
        f"p={len(party.part_time)}"
    )


# Demo: create the schema in an in-memory SQLite database and exercise the
# role proxies from both sides of the association.
engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
session = scoped_session(Session)

e1 = Employer(name="The Company")
session.add(e1)
session.commit()

w1 = Worker(name="The Programmer")
session.add(w1)
session.commit()

# Appending through the employer-side proxy creates an
# EmployerWorkerContractor row. e1 was added and committed above, so it is
# already in the session and does not need to be re-added.
e1.contractor.append(w1)
session.commit()

print(f"Contractors: {len(e1.contractor)}")
print(f"Employees  : {len(e1.employee)}")
print(f"Part Timers: {len(e1.part_time)}")

e2 = Employer(name="The Enterprise")
session.add(e2)

# The same mechanism from the worker side creates an EmployerWorkerPartTime row.
w1.part_time.append(e2)
session.commit()

print(f"worker   {w1.name} {_role_counts(w1)}")
print(f"employer {e1.name} {_role_counts(e1)}")
print(f"employer {e2.name} {_role_counts(e2)}")



-- 
SQLAlchemy - 
The Python SQL Toolkit and Object Relational Mapper

http://www.sqlalchemy.org/

To post example code, please provide an MCVE: Minimal, Complete, and Verifiable 
Example.  See  http://stackoverflow.com/help/mcve for a full description.
--- 
You received this message because you are subscribed to the Google Groups 
"sqlalchemy" group.
To unsubscribe from this group and stop receiving emails from it, send an email 
to sqlalchemy+unsubscr...@googlegroups.com.
To post to this group, send email to sqlalchemy@googlegroups.com.
Visit this group at https://groups.google.com/group/sqlalchemy.
For more options, visit https://groups.google.com/d/optout.

Reply via email to