Hi All,

Using Sqlalchemy 1.3.23

I am getting a NotImplementedError: Operator 'getitem' is not supported on 
this expression
when sorting on some hybrid_properties.

I have attached a sample code to replicate it.

falls over with the following traceback:

Traceback (most recent call last):
  File "testholdings.py", line 526, in <module>
    trans = 
db_session.query(Transaction).order_by(desc(Transaction.total_cost)).all()
  File 
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/ext/hybrid.py",
 
line 898, in __get__
    return self._expr_comparator(owner)
  File 
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/ext/hybrid.py",
 
line 1105, in expr_comparator
    comparator(owner),
  File "testholdings.py", line 135, in total_cost
    return TotalCostComparator(cls)
  File "testholdings.py", line 89, in __init__
    expr = case(
  File "<string>", line 2, in case
  File 
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/elements.py",
 
line 2437, in __init__
    whenlist = [
  File 
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/elements.py",
 
line 2439, in <listcomp>
    for (c, r) in whens
  File 
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/operators.py",
 
line 432, in __getitem__
    return self.operate(getitem, index)
  File 
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/elements.py",
 
line 762, in operate
    return op(self.comparator, *other, **kwargs)
  File 
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/operators.py",
 
line 432, in __getitem__
    return self.operate(getitem, index)
  File "<string>", line 1, in <lambda>
  File 
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/type_api.py",
 
line 67, in operate
    return o[0](self.expr, op, *(other + o[1:]), **kwargs)
  File 
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/default_comparator.py",
 
line 237, in _getitem_impl
    _unsupported_impl(expr, op, other, **kw)
  File 
"/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/default_comparator.py",
 
line 241, in _unsupported_impl
    raise NotImplementedError(
NotImplementedError: Operator 'getitem' is not supported on this expression


Thanks in advance,
George

-- 
SQLAlchemy - 
The Python SQL Toolkit and Object Relational Mapper

http://www.sqlalchemy.org/

To post example code, please provide an MCVE: Minimal, Complete, and Verifiable 
Example.  See  http://stackoverflow.com/help/mcve for a full description.
--- 
You received this message because you are subscribed to the Google Groups 
"sqlalchemy" group.
To unsubscribe from this group and stop receiving emails from it, send an email 
to sqlalchemy+unsubscr...@googlegroups.com.
To view this discussion on the web visit 
https://groups.google.com/d/msgid/sqlalchemy/6c323f70-51c1-4724-ac12-58100fb1e3fen%40googlegroups.com.
import decimal
from datetime import date

from sqlalchemy import select, asc, desc, cast, case, event, create_engine
from sqlalchemy import Column, ForeignKey, Integer, Numeric, Enum, Date, String
from sqlalchemy.orm.session import object_session
from sqlalchemy.orm import relationship, configure_mappers, sessionmaker
from sqlalchemy.ext.hybrid import hybrid_property, Comparator
from sqlalchemy.ext.declarative import declarative_base

engine = create_engine('sqlite:///:memory:', echo=False)
Base = declarative_base()


class StockCompanyShareInfo(Base):
    __tablename__ = "StockCompanyShareInfos"

    Id = Column(Integer, autoincrement=True, primary_key=True, nullable=False)

    LastTradeDate = Column(Date, default=None)
    LastPrice = Column(Numeric(9,4), default=0)

    # OneToOne side of StockCompany
    ItemStockCompany_Id = Column(Integer, ForeignKey("StockCompanies.Id"))
    ItemStockCompany = relationship("StockCompany", back_populates="ShareInfo",
                                    primaryjoin="StockCompany.Id==StockCompanyShareInfo.ItemStockCompany_Id")


class StockCompany(Base):
    __tablename__ = "StockCompanies"

    Id = Column(Integer, autoincrement=True, primary_key=True, nullable=False)
    Ticker = Column(String(8), index=True, nullable=False)

    # One2One side of CompanyShare
    ShareInfo = relationship("StockCompanyShareInfo", uselist=False,
                             back_populates="ItemStockCompany",
                             primaryjoin="StockCompanyShareInfo.ItemStockCompany_Id==StockCompany.Id",
                             cascade="all, delete-orphan")


class HybridComparator(Comparator):
    def __init__(self, expr):
        super().__init__(expr)

    def __eq__(self, val):
        expr = self.__clause_element__()
        return expr == val

    def __ne__(self, val):
        expr = self.__clause_element__()
        return expr != val

    def __ge__(self, val):
        expr = self.__clause_element__()
        return expr >= val

    def __gt__(self, val):
        expr = self.__clause_element__()
        return expr > val

    def __le__(self, val):
        expr = self.__clause_element__()
        return expr <= val

    def __lt__(self, val):
        expr = self.__clause_element__()
        return expr < val

    def asc(self):
        expr = self.__clause_element__()
        return asc(expr)

    def desc(self):
        expr = self.__clause_element__()
        return desc(expr)


class TotalValueComparator(HybridComparator):
    def __init__(self, cls):
        # need to cast to 2 decimals - display is 2 decimals
        # cls.UnitPrice is 4 decimals
        expr = cast(cls.Units * cls.UnitPrice, Numeric(9, 2))
        super().__init__(expr)


class TotalCostComparator(HybridComparator):
    def __init__(self, cls):
        expr = case(
            (cls.Type == "SELL", cast(cls.Units * cls.UnitPrice, Numeric(9, 2)) - cls.Brokerage),
            else_=cast(cls.Units * cls.UnitPrice, Numeric(9, 2)) + cls.Brokerage
        )
        #expr = case(
        #    (cls.Type == "SELL", cls.total_value - cls.Brokerage),
        #    else_=cls.total_value + cls.Brokerage
        #)
        super().__init__(expr)


class Transaction(Base):
    __tablename__ = "Transactions"

    Id = Column(Integer, autoincrement=True, primary_key=True, nullable=False)

    Type = Column(Enum("BUY", "SELL", name="HoldingTransactionType"),
                  nullable=False, default="BUY"
                  )

    Date = Column(Date, nullable=False, default=None)

    Units = Column(Integer, nullable=False)
    UnitPrice = Column(Numeric(9, 4), nullable=False)
    Brokerage = Column(Numeric(9, 2))

    # Many2One side of Holding
    ItemHolding_Id = Column(Integer, ForeignKey("Holdings.Id"))

    # calculated columns
    @hybrid_property
    def total_value(self):
        return self.Units * self.UnitPrice

    @total_value.comparator
    def total_value(cls):
        return TotalValueComparator(cls)

    @hybrid_property
    def total_cost(self):
        if self.Type == "SELL":
            return self.total_value - self.Brokerage
        return self.total_value + self.Brokerage

    @total_cost.comparator
    def total_cost(cls):
        return TotalCostComparator(cls)


def on_transaction_delete(mapper, connection, target):
    db_session = object_session(target)

    holding = getattr(target, "ItemHolding")

    total_units = 0
    total_value = decimal.Decimal(0)
    running_unit_price = decimal.Decimal(0)

    for trans in sorted(holding.Transactions, key=lambda obj: obj.Date):
        if trans in db_session.deleted:
            continue
        units = int(trans.Units)
        trans_unit_price = decimal.Decimal(trans.UnitPrice)
        if trans.Type == "SELL":
            total_units = total_units - units
            total_value = total_value - (units * running_unit_price)
        else:
            total_units = total_units + units
            total_value = total_value + (units * trans_unit_price)
        if total_units == 0:
            running_unit_price = 0
        else:
            running_unit_price = total_value / total_units

    if total_units == 0:
        ave_unit_price = 0
    else:
        ave_unit_price = total_value / total_units

    connection.execute(
        Holding.__table__.update().
            values(Units=total_units, UnitPrice=ave_unit_price).
            where(Holding.Id == target.ItemHolding_Id)
    )


event.listen(Transaction, "after_delete", on_transaction_delete)  # Mapper Event


class MarketValueComparator(HybridComparator):
    def __init__(self, cls):
        # need to cast to 2 decimals - display is 2 decimals
        # cls.last_price is 4 decimals
        expr = cast(cls.Units * cls.last_price, Numeric(9, 2))
        super().__init__(expr)


class VarianceComparator(HybridComparator):
    def __init__(self, cls):
        # need to cast to 2 decimals - display is 2 decimals
        total_cost = cast(cls.Units * cls.UnitPrice, Numeric(9, 2))
        market_value = cast(cls.Units * cls.last_price, Numeric(9, 2))
        expr = market_value - total_cost
        super().__init__(expr)


class VariancePercentComparator(HybridComparator):
    def __init__(self, cls):
        total_cost = cast(cls.Units * cls.UnitPrice, Numeric(9, 2))
        market_value = cast(cls.Units * cls.last_price, Numeric(9, 2))
        if total_cost == 0:
            expr = 0
        else:
            expr = ((market_value - total_cost) / total_cost) * 100
        super().__init__(expr)


class ThresholdPriceComparator(HybridComparator):
    def __init__(self, cls):
        if cls.Threshold == 0:
            expr = 0
        else:
            threshold = (1 - (cls.Threshold / 100))
            expr = case(
                (cls.UnitPrice > cls.last_price, cast(cls.UnitPrice * threshold, Numeric(9,2))),
                else_=cast(cls.last_price * threshold, Numeric(9,2))
            )
        super().__init__(expr)


class ThresholdValueComparator(HybridComparator):
    def __init__(self, cls):
        expr = cls.threshold_price * cls.Units
        super().__init__(expr)


class ThresholdVarianceComparator(HybridComparator):
    def __init__(self, cls):
        expr = cls.threshold_value - cls.total_cost
        super().__init__(expr)


class Holding(Base):
    __tablename__ = "Holdings"

    Id = Column(Integer, autoincrement=True, primary_key=True, nullable=False)

    # Many2One
    StockCompany_Id = Column(Integer, ForeignKey("StockCompanies.Id"), nullable=False)
    StockCompany = relationship("StockCompany", primaryjoin="StockCompany.Id==Holding.StockCompany_Id")

    Units = Column(Integer, default=0)
    UnitPrice = Column(Numeric(9, 4), default=0)
    Threshold = Column(Integer, default=0)

    # One2Many
    Transactions = relationship("Transaction", uselist=True, backref="ItemHolding",
                                order_by="desc(Transaction.Date)",
                                cascade="all, delete-orphan")

    # calculated columns
    @hybrid_property
    def total_cost(self):
        return self.Units * self.UnitPrice

    @total_cost.comparator
    def total_cost(cls):
        return TotalValueComparator(cls)

    @hybrid_property
    def last_price(self):
        return self.StockCompany.ShareInfo.LastPrice

    @last_price.expression
    def last_price(cls):
        return select([StockCompanyShareInfo.LastPrice]).\
            where(StockCompanyShareInfo.ItemStockCompany_Id == cls.StockCompany_Id).\
            as_scalar()

    @hybrid_property
    def market_value(self):
        return self.Units * self.last_price

    @market_value.comparator
    def market_value(cls):
        return MarketValueComparator(cls)

    @hybrid_property
    def variance(self):
        return self.market_value - self.total_cost

    @variance.comparator
    def variance(cls):
        return VarianceComparator(cls)

    @hybrid_property
    def variance_percent(self):
        if self.total_cost == 0:
            return 0
        return (self.variance / self.total_cost) * 100

    @variance_percent.comparator
    def variance_percent(cls):
        return VariancePercentComparator(cls)

    @hybrid_property
    def threshold_price(self):
        if self.Threshold == 0:
            return 0

        threshold = decimal.Decimal(1 - (self.Threshold / 100))  # float to decimal
        if self.UnitPrice > self.last_price:
            return self.UnitPrice * threshold
        else:
            return self.last_price * threshold

    @threshold_price.comparator
    def threshold_price(cls):
        return ThresholdPriceComparator(cls)

    @hybrid_property
    def threshold_value(self):
        return self.threshold_price * self.Units

    @threshold_value.comparator
    def threshold_value(cls):
        return ThresholdValueComparator(cls)

    @hybrid_property
    def threshold_variance(self):
        return self.threshold_value - self.total_cost

    @threshold_variance.comparator
    def threshold_variance(cls):
        return ThresholdVarianceComparator(cls)


def on_holding_update(mapper, connection, target):
    db_session = object_session(target)
    transactions = getattr(target, "Transactions")

    total_units = 0
    total_value = decimal.Decimal(0)
    running_unit_price = decimal.Decimal(0)

    counter = []
    for trans in sorted(transactions, key=lambda obj: obj.Date):
        # dont know why but there are duplicate transactions in update mode
        if trans in db_session.dirty:
            if trans.Id in counter:
                continue
            else:
                counter.append(trans.Id)
        units = int(trans.Units)
        trans_unit_price = decimal.Decimal(trans.UnitPrice)
        if trans.Type == "SELL":
            total_units = total_units - units
            total_value = total_value - (units * running_unit_price)
        else:
            total_units = total_units + units
            total_value = total_value + (units * trans_unit_price)
        if total_units == 0:
            running_unit_price = 0
        else:
            running_unit_price = total_value / total_units

    if total_units == 0:
        ave_unit_price = 0
    else:
        ave_unit_price = total_value / total_units

    setattr(target, "Units", total_units)
    setattr(target, "UnitPrice", ave_unit_price)


event.listen(Holding, "before_insert", on_holding_update)  # Mapper Event
event.listen(Holding, "before_update", on_holding_update)  # Mapper Event


if __name__ == '__main__':
    configure_mappers()
    Base.metadata.create_all(engine)

    db_session = sessionmaker(bind=engine)()

    # populate tables
    coy = StockCompany()
    coy.Ticker = "GVV"
    info = StockCompanyShareInfo()
    info.LastTradeDate = date(2021, 11, 18)
    info.LastPrice = 0.0300
    db_session.add(info)
    coy.ShareInfo = info
    db_session.add(coy)
    db_session.commit()

    coy = StockCompany()
    coy.Ticker = "PRV"
    info = StockCompanyShareInfo()
    info.LastTradeDate = date(2021, 11, 18)
    info.LastPrice = 0.1000
    db_session.add(info)
    coy.ShareInfo = info
    db_session.add(coy)
    db_session.commit()

    coy = db_session.query(StockCompany).filter(StockCompany.Ticker == "GVV").first()
    if coy is not None:
        hold = Holding()
        hold.StockCompany = coy
        hold.Threshold = 10
        db_session.add(hold)
        trans = Transaction()
        trans.Type = "BUY"
        trans.Date = date(2018, 3, 27)
        trans.Units = 250000
        trans.UnitPrice = 0.0200
        trans.Brokerage = 19.95
        db_session.add(trans)
        hold.Transactions.append(trans)
        trans = Transaction()
        trans.Type = "SELL"
        trans.Date = date(2018, 4, 20)
        trans.Units = 250000
        trans.UnitPrice = 0.0210
        trans.Brokerage = 19.95
        db_session.add(trans)
        hold.Transactions.append(trans)
        trans = Transaction()
        trans.Type = "BUY"
        trans.Date = date(2018, 5, 2)
        trans.Units = 312500
        trans.UnitPrice = 0.0160
        trans.Brokerage = 19.95
        db_session.add(trans)
        hold.Transactions.append(trans)
        db_session.commit()

    coy = db_session.query(StockCompany).filter(StockCompany.Ticker == "PRV").first()
    if coy is not None:
        hold = Holding()
        hold.StockCompany = coy
        hold.Threshold = 0
        db_session.add(hold)
        trans = Transaction()
        trans.Type = "BUY"
        trans.Date = date(2021, 7, 12)
        trans.Units = 57472
        trans.UnitPrice = 0.0870
        trans.Brokerage = 19.95
        db_session.add(trans)
        hold.Transactions.append(trans)
        trans = Transaction()
        trans.Type = "BUY"
        trans.Date = date(2021, 10, 7)
        trans.Units = 143800
        trans.UnitPrice = 0.1450
        trans.Brokerage = 19.95
        db_session.add(trans)
        hold.Transactions.append(trans)
        trans = Transaction()
        trans.Type = "BUY"
        trans.Date = date(2021, 11, 1)
        trans.Units = 1643
        trans.UnitPrice = 0.1450
        trans.Brokerage = 0
        db_session.add(trans)
        hold.Transactions.append(trans)
        db_session.commit()

    print("StockCompanies")
    coys = db_session.query(StockCompany).all()
    format_string = "{:<8} {:<13} {:>9}"
    print(format_string.format("Ticker","LastTradeDate","LastPrice"))
    for co in coys:
        print(format_string.format(co.Ticker, str(co.ShareInfo.LastTradeDate),
              str(co.ShareInfo.LastPrice)))

    print("\nHoldings")
    holds = db_session.query(Holding).all()
    format_string = "{:<8} {:>7} {:>9} {:>9} {:>10} {:>11} {:>10} {:>9} {:>13} {:>14} {:>17} {:>10}"
    print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost",
                               "MarketValue","Variance","Variance%","ThresholdPrice",
                               "ThresholdValue","ThresholdVariance","Threshold%"))
    for hold in holds:
        print(format_string.format(hold.StockCompany.Ticker, str(hold.Units),
                                   str(hold.UnitPrice), str(hold.last_price),
                                   str(hold.total_cost), str(hold.market_value),
                                   str(hold.variance),
                                   "{:.2f}".format(hold.variance_percent),
                                   "{:.4f}".format(hold.threshold_price),
                                   "{:.2f}".format(hold.threshold_value),
                                   "{:.2f}".format(hold.threshold_variance),
                                   str(hold.Threshold)))

    print("\nHoldings - Sort by Variance% asc - OK")
    holds = db_session.query(Holding).order_by(asc(Holding.variance_percent)).all()
    print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost",
                               "MarketValue","Variance","Variance%","ThresholdPrice",
                               "ThresholdValue","ThresholdVariance","Threshold%"))
    for hold in holds:
        print(format_string.format(hold.StockCompany.Ticker, str(hold.Units),
                                   str(hold.UnitPrice), str(hold.last_price),
                                   str(hold.total_cost), str(hold.market_value),
                                   str(hold.variance),
                                   "{:.2f}".format(hold.variance_percent),
                                   "{:.4f}".format(hold.threshold_price),
                                   "{:.2f}".format(hold.threshold_value),
                                   "{:.2f}".format(hold.threshold_variance),
                                   str(hold.Threshold)))

    print("\nTransactions")
    trans = db_session.query(Transaction).all()
    format_string = "{:<8} {:<4} {:<10} {:>7} {:>9} {:>10} {:>9} {:>10}"
    print(format_string.format("Ticker","Type","Date","Units","UnitPrice","TotalValue",
                               "Brokerage","TotalCost"))
    for tran in trans:
        print(format_string.format(tran.ItemHolding.StockCompany.Ticker,
                                   tran.Type, str(tran.Date), str(tran.Units),
                                   str(tran.UnitPrice),
                                   "{:.2f}".format(tran.total_value),
                                   str(tran.Brokerage),
                                   "{:.2f}".format(tran.total_cost)))

    print("\nTransactions - Sort by TotalValue asc - OK")
    trans = db_session.query(Transaction).order_by(asc(Transaction.total_value)).all()
    print(format_string.format("Ticker","Type","Date","Units","UnitPrice","TotalValue",
                               "Brokerage","TotalCost"))
    for tran in trans:
        print(format_string.format(tran.ItemHolding.StockCompany.Ticker,
                                   tran.Type, str(tran.Date), str(tran.Units),
                                   str(tran.UnitPrice),
                                   "{:.2f}".format(tran.total_value),
                                   str(tran.Brokerage),
                                   "{:.2f}".format(tran.total_cost)))
    ###################################################################################
    print("\nTransactions - Sort by TotalCost desc - NOT OK")
    trans = db_session.query(Transaction).order_by(desc(Transaction.total_cost)).all()
    # trans = db_session.query(Transaction).order_by(Transaction.total_cost.desc()).all()
    print(format_string.format("Ticker","Type","Date","Units","UnitPrice","TotalValue",
                               "Brokerage","TotalCost"))
    for tran in trans:
        print(format_string.format(tran.ItemHolding.StockCompany.Ticker,
                                   tran.Type, str(tran.Date), str(tran.Units),
                                   str(tran.UnitPrice),
                                   "{:.2f}".format(tran.total_value),
                                   str(tran.Brokerage),
                                   "{:.2f}".format(tran.total_cost)))

    print("\nHoldings - Sort by ThresholdPrice asc - NOT OK")
    holds = db_session.query(Holding).order_by(asc(Holding.threshold_price)).all()
    format_string = "{:<8} {:>7} {:>9} {:>9} {:>10} {:>11} {:>10} {:>9} {:>13} {:>14} {:>17} {:>10}"
    print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost",
                               "MarketValue","Variance","Variance%","ThresholdPrice",
                               "ThresholdValue","ThresholdVariance","Threshold%"))
    for hold in holds:
        print(format_string.format(hold.StockCompany.Ticker, str(hold.Units),
                                   str(hold.UnitPrice), str(hold.last_price),
                                   str(hold.total_cost), str(hold.market_value),
                                   str(hold.variance),
                                   "{:.2f}".format(hold.variance_percent),
                                   "{:.4f}".format(hold.threshold_price),
                                   "{:.2f}".format(hold.threshold_value),
                                   "{:.2f}".format(hold.threshold_variance),
                                   str(hold.Threshold)))

    print("\nHoldings - Sort by ThresholdValue asc - NOT OK")
    holds = db_session.query(Holding).order_by(asc(Holding.threshold_value)).all()
    print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost",
                               "MarketValue","Variance","Variance%","ThresholdPrice",
                               "ThresholdValue","ThresholdVariance","Threshold%"))
    for hold in holds:
        print(format_string.format(hold.StockCompany.Ticker, str(hold.Units),
                                   str(hold.UnitPrice), str(hold.last_price),
                                   str(hold.total_cost), str(hold.market_value),
                                   str(hold.variance),
                                   "{:.2f}".format(hold.variance_percent),
                                   "{:.4f}".format(hold.threshold_price),
                                   "{:.2f}".format(hold.threshold_value),
                                   "{:.2f}".format(hold.threshold_variance),
                                   str(hold.Threshold)))

    print("\nHoldings - Sort by ThresholdVariance asc - NOT OK")
    holds = db_session.query(Holding).order_by(asc(Holding.threshold_variance)).all()
    print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost",
                               "MarketValue","Variance","Variance%","ThresholdPrice",
                               "ThresholdValue","ThresholdVariance","Threshold%"))
    for hold in holds:
        print(format_string.format(hold.StockCompany.Ticker, str(hold.Units),
                                   str(hold.UnitPrice), str(hold.last_price),
                                   str(hold.total_cost), str(hold.market_value),
                                   str(hold.variance),
                                   "{:.2f}".format(hold.variance_percent),
                                   "{:.4f}".format(hold.threshold_price),
                                   "{:.2f}".format(hold.threshold_value),
                                   "{:.2f}".format(hold.threshold_variance),
                                   str(hold.Threshold)))

Reply via email to