Hi All, Using Sqlalchemy 1.3.23
I am getting a NotImplementedError: Operator 'getitem' is not supported on this expression when sorting on some hybrid_properties. I have attached a sample code to replicate it. falls over with the following traceback: Traceback (most recent call last): File "testholdings.py", line 526, in <module> trans = db_session.query(Transaction).order_by(desc(Transaction.total_cost)).all() File "/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/ext/hybrid.py", line 898, in __get__ return self._expr_comparator(owner) File "/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/ext/hybrid.py", line 1105, in expr_comparator comparator(owner), File "testholdings.py", line 135, in total_cost return TotalCostComparator(cls) File "testholdings.py", line 89, in __init__ expr = case( File "<string>", line 2, in case File "/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/elements.py", line 2437, in __init__ whenlist = [ File "/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/elements.py", line 2439, in <listcomp> for (c, r) in whens File "/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/operators.py", line 432, in __getitem__ return self.operate(getitem, index) File "/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/elements.py", line 762, in operate return op(self.comparator, *other, **kwargs) File "/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/operators.py", line 432, in __getitem__ return self.operate(getitem, index) File "<string>", line 1, in <lambda> File "/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/type_api.py", line 67, in operate return o[0](self.expr, op, *(other + o[1:]), **kwargs) File "/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/default_comparator.py", line 237, in _getitem_impl _unsupported_impl(expr, op, other, **kw) File "/home/gvv/Projects/uby/venv/lib/python3.8/site-packages/sqlalchemy/sql/default_comparator.py", line 241, in _unsupported_impl raise NotImplementedError( NotImplementedError: Operator 'getitem' is not supported on this expression Thanks in advance, George -- SQLAlchemy - The Python SQL Toolkit and Object Relational Mapper http://www.sqlalchemy.org/ To post example code, please provide an MCVE: Minimal, Complete, and Verifiable Example. See http://stackoverflow.com/help/mcve for a full description. --- You received this message because you are subscribed to the Google Groups "sqlalchemy" group. To unsubscribe from this group and stop receiving emails from it, send an email to sqlalchemy+unsubscr...@googlegroups.com. To view this discussion on the web visit https://groups.google.com/d/msgid/sqlalchemy/6c323f70-51c1-4724-ac12-58100fb1e3fen%40googlegroups.com.
import decimal from datetime import date from sqlalchemy import select, asc, desc, cast, case, event, create_engine from sqlalchemy import Column, ForeignKey, Integer, Numeric, Enum, Date, String from sqlalchemy.orm.session import object_session from sqlalchemy.orm import relationship, configure_mappers, sessionmaker from sqlalchemy.ext.hybrid import hybrid_property, Comparator from sqlalchemy.ext.declarative import declarative_base engine = create_engine('sqlite:///:memory:', echo=False) Base = declarative_base() class StockCompanyShareInfo(Base): __tablename__ = "StockCompanyShareInfos" Id = Column(Integer, autoincrement=True, primary_key=True, nullable=False) LastTradeDate = Column(Date, default=None) LastPrice = Column(Numeric(9,4), default=0) # OneToOne side of StockCompany ItemStockCompany_Id = Column(Integer, ForeignKey("StockCompanies.Id")) ItemStockCompany = relationship("StockCompany", back_populates="ShareInfo", primaryjoin="StockCompany.Id==StockCompanyShareInfo.ItemStockCompany_Id") class StockCompany(Base): __tablename__ = "StockCompanies" Id = Column(Integer, autoincrement=True, primary_key=True, nullable=False) Ticker = Column(String(8), index=True, nullable=False) # One2One side of CompanyShare ShareInfo = relationship("StockCompanyShareInfo", uselist=False, back_populates="ItemStockCompany", primaryjoin="StockCompanyShareInfo.ItemStockCompany_Id==StockCompany.Id", cascade="all, delete-orphan") class HybridComparator(Comparator): def __init__(self, expr): super().__init__(expr) def __eq__(self, val): expr = self.__clause_element__() return expr == val def __ne__(self, val): expr = self.__clause_element__() return expr != val def __ge__(self, val): expr = self.__clause_element__() return expr >= val def __gt__(self, val): expr = self.__clause_element__() return expr > val def __le__(self, val): expr = self.__clause_element__() return expr <= val def __lt__(self, val): expr = self.__clause_element__() return expr < val def asc(self): expr = self.__clause_element__() return asc(expr) def desc(self): expr = self.__clause_element__() return desc(expr) class TotalValueComparator(HybridComparator): def __init__(self, cls): # need to cast to 2 decimals - display is 2 decimals # cls.UnitPrice is 4 decimals expr = cast(cls.Units * cls.UnitPrice, Numeric(9, 2)) super().__init__(expr) class TotalCostComparator(HybridComparator): def __init__(self, cls): expr = case( (cls.Type == "SELL", cast(cls.Units * cls.UnitPrice, Numeric(9, 2)) - cls.Brokerage), else_=cast(cls.Units * cls.UnitPrice, Numeric(9, 2)) + cls.Brokerage ) #expr = case( # (cls.Type == "SELL", cls.total_value - cls.Brokerage), # else_=cls.total_value + cls.Brokerage #) super().__init__(expr) class Transaction(Base): __tablename__ = "Transactions" Id = Column(Integer, autoincrement=True, primary_key=True, nullable=False) Type = Column(Enum("BUY", "SELL", name="HoldingTransactionType"), nullable=False, default="BUY" ) Date = Column(Date, nullable=False, default=None) Units = Column(Integer, nullable=False) UnitPrice = Column(Numeric(9, 4), nullable=False) Brokerage = Column(Numeric(9, 2)) # Many2One side of Holding ItemHolding_Id = Column(Integer, ForeignKey("Holdings.Id")) # calculated columns @hybrid_property def total_value(self): return self.Units * self.UnitPrice @total_value.comparator def total_value(cls): return TotalValueComparator(cls) @hybrid_property def total_cost(self): if self.Type == "SELL": return self.total_value - self.Brokerage return self.total_value + self.Brokerage @total_cost.comparator def total_cost(cls): return TotalCostComparator(cls) def on_transaction_delete(mapper, connection, target): db_session = object_session(target) holding = getattr(target, "ItemHolding") total_units = 0 total_value = decimal.Decimal(0) running_unit_price = decimal.Decimal(0) for trans in sorted(holding.Transactions, key=lambda obj: obj.Date): if trans in db_session.deleted: continue units = int(trans.Units) trans_unit_price = decimal.Decimal(trans.UnitPrice) if trans.Type == "SELL": total_units = total_units - units total_value = total_value - (units * running_unit_price) else: total_units = total_units + units total_value = total_value + (units * trans_unit_price) if total_units == 0: running_unit_price = 0 else: running_unit_price = total_value / total_units if total_units == 0: ave_unit_price = 0 else: ave_unit_price = total_value / total_units connection.execute( Holding.__table__.update(). values(Units=total_units, UnitPrice=ave_unit_price). where(Holding.Id == target.ItemHolding_Id) ) event.listen(Transaction, "after_delete", on_transaction_delete) # Mapper Event class MarketValueComparator(HybridComparator): def __init__(self, cls): # need to cast to 2 decimals - display is 2 decimals # cls.last_price is 4 decimals expr = cast(cls.Units * cls.last_price, Numeric(9, 2)) super().__init__(expr) class VarianceComparator(HybridComparator): def __init__(self, cls): # need to cast to 2 decimals - display is 2 decimals total_cost = cast(cls.Units * cls.UnitPrice, Numeric(9, 2)) market_value = cast(cls.Units * cls.last_price, Numeric(9, 2)) expr = market_value - total_cost super().__init__(expr) class VariancePercentComparator(HybridComparator): def __init__(self, cls): total_cost = cast(cls.Units * cls.UnitPrice, Numeric(9, 2)) market_value = cast(cls.Units * cls.last_price, Numeric(9, 2)) if total_cost == 0: expr = 0 else: expr = ((market_value - total_cost) / total_cost) * 100 super().__init__(expr) class ThresholdPriceComparator(HybridComparator): def __init__(self, cls): if cls.Threshold == 0: expr = 0 else: threshold = (1 - (cls.Threshold / 100)) expr = case( (cls.UnitPrice > cls.last_price, cast(cls.UnitPrice * threshold, Numeric(9,2))), else_=cast(cls.last_price * threshold, Numeric(9,2)) ) super().__init__(expr) class ThresholdValueComparator(HybridComparator): def __init__(self, cls): expr = cls.threshold_price * cls.Units super().__init__(expr) class ThresholdVarianceComparator(HybridComparator): def __init__(self, cls): expr = cls.threshold_value - cls.total_cost super().__init__(expr) class Holding(Base): __tablename__ = "Holdings" Id = Column(Integer, autoincrement=True, primary_key=True, nullable=False) # Many2One StockCompany_Id = Column(Integer, ForeignKey("StockCompanies.Id"), nullable=False) StockCompany = relationship("StockCompany", primaryjoin="StockCompany.Id==Holding.StockCompany_Id") Units = Column(Integer, default=0) UnitPrice = Column(Numeric(9, 4), default=0) Threshold = Column(Integer, default=0) # One2Many Transactions = relationship("Transaction", uselist=True, backref="ItemHolding", order_by="desc(Transaction.Date)", cascade="all, delete-orphan") # calculated columns @hybrid_property def total_cost(self): return self.Units * self.UnitPrice @total_cost.comparator def total_cost(cls): return TotalValueComparator(cls) @hybrid_property def last_price(self): return self.StockCompany.ShareInfo.LastPrice @last_price.expression def last_price(cls): return select([StockCompanyShareInfo.LastPrice]).\ where(StockCompanyShareInfo.ItemStockCompany_Id == cls.StockCompany_Id).\ as_scalar() @hybrid_property def market_value(self): return self.Units * self.last_price @market_value.comparator def market_value(cls): return MarketValueComparator(cls) @hybrid_property def variance(self): return self.market_value - self.total_cost @variance.comparator def variance(cls): return VarianceComparator(cls) @hybrid_property def variance_percent(self): if self.total_cost == 0: return 0 return (self.variance / self.total_cost) * 100 @variance_percent.comparator def variance_percent(cls): return VariancePercentComparator(cls) @hybrid_property def threshold_price(self): if self.Threshold == 0: return 0 threshold = decimal.Decimal(1 - (self.Threshold / 100)) # float to decimal if self.UnitPrice > self.last_price: return self.UnitPrice * threshold else: return self.last_price * threshold @threshold_price.comparator def threshold_price(cls): return ThresholdPriceComparator(cls) @hybrid_property def threshold_value(self): return self.threshold_price * self.Units @threshold_value.comparator def threshold_value(cls): return ThresholdValueComparator(cls) @hybrid_property def threshold_variance(self): return self.threshold_value - self.total_cost @threshold_variance.comparator def threshold_variance(cls): return ThresholdVarianceComparator(cls) def on_holding_update(mapper, connection, target): db_session = object_session(target) transactions = getattr(target, "Transactions") total_units = 0 total_value = decimal.Decimal(0) running_unit_price = decimal.Decimal(0) counter = [] for trans in sorted(transactions, key=lambda obj: obj.Date): # dont know why but there are duplicate transactions in update mode if trans in db_session.dirty: if trans.Id in counter: continue else: counter.append(trans.Id) units = int(trans.Units) trans_unit_price = decimal.Decimal(trans.UnitPrice) if trans.Type == "SELL": total_units = total_units - units total_value = total_value - (units * running_unit_price) else: total_units = total_units + units total_value = total_value + (units * trans_unit_price) if total_units == 0: running_unit_price = 0 else: running_unit_price = total_value / total_units if total_units == 0: ave_unit_price = 0 else: ave_unit_price = total_value / total_units setattr(target, "Units", total_units) setattr(target, "UnitPrice", ave_unit_price) event.listen(Holding, "before_insert", on_holding_update) # Mapper Event event.listen(Holding, "before_update", on_holding_update) # Mapper Event if __name__ == '__main__': configure_mappers() Base.metadata.create_all(engine) db_session = sessionmaker(bind=engine)() # populate tables coy = StockCompany() coy.Ticker = "GVV" info = StockCompanyShareInfo() info.LastTradeDate = date(2021, 11, 18) info.LastPrice = 0.0300 db_session.add(info) coy.ShareInfo = info db_session.add(coy) db_session.commit() coy = StockCompany() coy.Ticker = "PRV" info = StockCompanyShareInfo() info.LastTradeDate = date(2021, 11, 18) info.LastPrice = 0.1000 db_session.add(info) coy.ShareInfo = info db_session.add(coy) db_session.commit() coy = db_session.query(StockCompany).filter(StockCompany.Ticker == "GVV").first() if coy is not None: hold = Holding() hold.StockCompany = coy hold.Threshold = 10 db_session.add(hold) trans = Transaction() trans.Type = "BUY" trans.Date = date(2018, 3, 27) trans.Units = 250000 trans.UnitPrice = 0.0200 trans.Brokerage = 19.95 db_session.add(trans) hold.Transactions.append(trans) trans = Transaction() trans.Type = "SELL" trans.Date = date(2018, 4, 20) trans.Units = 250000 trans.UnitPrice = 0.0210 trans.Brokerage = 19.95 db_session.add(trans) hold.Transactions.append(trans) trans = Transaction() trans.Type = "BUY" trans.Date = date(2018, 5, 2) trans.Units = 312500 trans.UnitPrice = 0.0160 trans.Brokerage = 19.95 db_session.add(trans) hold.Transactions.append(trans) db_session.commit() coy = db_session.query(StockCompany).filter(StockCompany.Ticker == "PRV").first() if coy is not None: hold = Holding() hold.StockCompany = coy hold.Threshold = 0 db_session.add(hold) trans = Transaction() trans.Type = "BUY" trans.Date = date(2021, 7, 12) trans.Units = 57472 trans.UnitPrice = 0.0870 trans.Brokerage = 19.95 db_session.add(trans) hold.Transactions.append(trans) trans = Transaction() trans.Type = "BUY" trans.Date = date(2021, 10, 7) trans.Units = 143800 trans.UnitPrice = 0.1450 trans.Brokerage = 19.95 db_session.add(trans) hold.Transactions.append(trans) trans = Transaction() trans.Type = "BUY" trans.Date = date(2021, 11, 1) trans.Units = 1643 trans.UnitPrice = 0.1450 trans.Brokerage = 0 db_session.add(trans) hold.Transactions.append(trans) db_session.commit() print("StockCompanies") coys = db_session.query(StockCompany).all() format_string = "{:<8} {:<13} {:>9}" print(format_string.format("Ticker","LastTradeDate","LastPrice")) for co in coys: print(format_string.format(co.Ticker, str(co.ShareInfo.LastTradeDate), str(co.ShareInfo.LastPrice))) print("\nHoldings") holds = db_session.query(Holding).all() format_string = "{:<8} {:>7} {:>9} {:>9} {:>10} {:>11} {:>10} {:>9} {:>13} {:>14} {:>17} {:>10}" print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost", "MarketValue","Variance","Variance%","ThresholdPrice", "ThresholdValue","ThresholdVariance","Threshold%")) for hold in holds: print(format_string.format(hold.StockCompany.Ticker, str(hold.Units), str(hold.UnitPrice), str(hold.last_price), str(hold.total_cost), str(hold.market_value), str(hold.variance), "{:.2f}".format(hold.variance_percent), "{:.4f}".format(hold.threshold_price), "{:.2f}".format(hold.threshold_value), "{:.2f}".format(hold.threshold_variance), str(hold.Threshold))) print("\nHoldings - Sort by Variance% asc - OK") holds = db_session.query(Holding).order_by(asc(Holding.variance_percent)).all() print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost", "MarketValue","Variance","Variance%","ThresholdPrice", "ThresholdValue","ThresholdVariance","Threshold%")) for hold in holds: print(format_string.format(hold.StockCompany.Ticker, str(hold.Units), str(hold.UnitPrice), str(hold.last_price), str(hold.total_cost), str(hold.market_value), str(hold.variance), "{:.2f}".format(hold.variance_percent), "{:.4f}".format(hold.threshold_price), "{:.2f}".format(hold.threshold_value), "{:.2f}".format(hold.threshold_variance), str(hold.Threshold))) print("\nTransactions") trans = db_session.query(Transaction).all() format_string = "{:<8} {:<4} {:<10} {:>7} {:>9} {:>10} {:>9} {:>10}" print(format_string.format("Ticker","Type","Date","Units","UnitPrice","TotalValue", "Brokerage","TotalCost")) for tran in trans: print(format_string.format(tran.ItemHolding.StockCompany.Ticker, tran.Type, str(tran.Date), str(tran.Units), str(tran.UnitPrice), "{:.2f}".format(tran.total_value), str(tran.Brokerage), "{:.2f}".format(tran.total_cost))) print("\nTransactions - Sort by TotalValue asc - OK") trans = db_session.query(Transaction).order_by(asc(Transaction.total_value)).all() print(format_string.format("Ticker","Type","Date","Units","UnitPrice","TotalValue", "Brokerage","TotalCost")) for tran in trans: print(format_string.format(tran.ItemHolding.StockCompany.Ticker, tran.Type, str(tran.Date), str(tran.Units), str(tran.UnitPrice), "{:.2f}".format(tran.total_value), str(tran.Brokerage), "{:.2f}".format(tran.total_cost))) ################################################################################### print("\nTransactions - Sort by TotalCost desc - NOT OK") trans = db_session.query(Transaction).order_by(desc(Transaction.total_cost)).all() # trans = db_session.query(Transaction).order_by(Transaction.total_cost.desc()).all() print(format_string.format("Ticker","Type","Date","Units","UnitPrice","TotalValue", "Brokerage","TotalCost")) for tran in trans: print(format_string.format(tran.ItemHolding.StockCompany.Ticker, tran.Type, str(tran.Date), str(tran.Units), str(tran.UnitPrice), "{:.2f}".format(tran.total_value), str(tran.Brokerage), "{:.2f}".format(tran.total_cost))) print("\nHoldings - Sort by ThresholdPrice asc - NOT OK") holds = db_session.query(Holding).order_by(asc(Holding.threshold_price)).all() format_string = "{:<8} {:>7} {:>9} {:>9} {:>10} {:>11} {:>10} {:>9} {:>13} {:>14} {:>17} {:>10}" print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost", "MarketValue","Variance","Variance%","ThresholdPrice", "ThresholdValue","ThresholdVariance","Threshold%")) for hold in holds: print(format_string.format(hold.StockCompany.Ticker, str(hold.Units), str(hold.UnitPrice), str(hold.last_price), str(hold.total_cost), str(hold.market_value), str(hold.variance), "{:.2f}".format(hold.variance_percent), "{:.4f}".format(hold.threshold_price), "{:.2f}".format(hold.threshold_value), "{:.2f}".format(hold.threshold_variance), str(hold.Threshold))) print("\nHoldings - Sort by ThresholdValue asc - NOT OK") holds = db_session.query(Holding).order_by(asc(Holding.threshold_value)).all() print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost", "MarketValue","Variance","Variance%","ThresholdPrice", "ThresholdValue","ThresholdVariance","Threshold%")) for hold in holds: print(format_string.format(hold.StockCompany.Ticker, str(hold.Units), str(hold.UnitPrice), str(hold.last_price), str(hold.total_cost), str(hold.market_value), str(hold.variance), "{:.2f}".format(hold.variance_percent), "{:.4f}".format(hold.threshold_price), "{:.2f}".format(hold.threshold_value), "{:.2f}".format(hold.threshold_variance), str(hold.Threshold))) print("\nHoldings - Sort by ThresholdVariance asc - NOT OK") holds = db_session.query(Holding).order_by(asc(Holding.threshold_variance)).all() print(format_string.format("Ticker","Units","UnitPrice","LastPrice","TotalCost", "MarketValue","Variance","Variance%","ThresholdPrice", "ThresholdValue","ThresholdVariance","Threshold%")) for hold in holds: print(format_string.format(hold.StockCompany.Ticker, str(hold.Units), str(hold.UnitPrice), str(hold.last_price), str(hold.total_cost), str(hold.market_value), str(hold.variance), "{:.2f}".format(hold.variance_percent), "{:.4f}".format(hold.threshold_price), "{:.2f}".format(hold.threshold_value), "{:.2f}".format(hold.threshold_variance), str(hold.Threshold)))