Hello. I've created an encryption wrapper that marshals things in and out 
of encryption, except that for the TypeDecorator the input and output type 
are different. I'm having trouble wrapping float and Enum types while still 
keeping them queryable.

Can someone take a look at the attached reproducer and let me know what I 
can do to fix it up/work around the issue?

-- 
SQLAlchemy - 
The Python SQL Toolkit and Object Relational Mapper

http://www.sqlalchemy.org/

To post example code, please provide an MCVE: Minimal, Complete, and Verifiable 
Example.  See  http://stackoverflow.com/help/mcve for a full description.
--- 
You received this message because you are subscribed to the Google Groups 
"sqlalchemy" group.
To unsubscribe from this group and stop receiving emails from it, send an email 
to sqlalchemy+unsubscr...@googlegroups.com.
To view this discussion on the web visit 
https://groups.google.com/d/msgid/sqlalchemy/c44677c2-343e-49de-8bfe-0b5cfab70b41%40googlegroups.com.
import sqlalchemy
from sqlalchemy.ext.declarative import declarative_base
import sqlalchemy.dialects.postgresql as postgresql
import enum
import tempfile
import contextlib
import pathlib
import subprocess
import getpass

PASSPHRASE_USER = "PASSPHRASE_USER"

class Encrypted(sqlalchemy.types.TypeDecorator):
    impl = postgresql.BYTEA

    def __init__(self, inner, passphrase):
        super().__init__()
        self.inner = inner
        self.passphrase = passphrase

    def bind_expression(self, bindvalue):
        bindvalue = sqlalchemy.cast(bindvalue, sqlalchemy.String)
        return sqlalchemy.func.pgp_sym_encrypt(bindvalue, self.passphrase)

    def column_expression(self, col):
        value = sqlalchemy.func.pgp_sym_decrypt(col, self.passphrase)
        value = sqlalchemy.cast(value, self.inner)
        return value

Base = declarative_base()

class Type(enum.Enum):
    unverified = 0
    normal = 1
    superuser = 2

class User(Base):
    __tablename__ = "users"

    id = sqlalchemy.Column(sqlalchemy.types.BigInteger, primary_key=True)
    user_type = sqlalchemy.Column(
        Encrypted(sqlalchemy.types.Enum(Type), PASSPHRASE_USER)
    )
    reputation = sqlalchemy.Column(
        Encrypted(sqlalchemy.types.Float, PASSPHRASE_USER)
    )

@contextlib.contextmanager
def postgresql_instance():
    with tempfile.TemporaryDirectory() as tempdir:
        tmpdir = pathlib.Path(tempdir)
        datadir = tmpdir / "data_dir"

        pwfile = tmpdir / "passwd"
        datadir.mkdir()
        with open(pwfile, 'w') as f:
            f.write("random_pw")
        subprocess.run(
            [
                "initdb",
                "-D",
                str(datadir),
                "-A",
                "password",
                "--pwfile",
                str(pwfile),
            ],
            check=True,
        )
        subprocess.run(
            ["pg_ctl", "start", "-D", str(datadir), "-l", str(tmpdir / "logfile")],
            check=True,
        )
        yield
        subprocess.run(["pg_ctl", "stop", "-D", str(datadir)], check=True)

@contextlib.contextmanager
def postgresql_engine():
    with postgresql_instance():
        from sqlalchemy import create_engine
        from sqlalchemy_utils import create_database

        connect_string = f"postgresql+psycopg2://{getpass.getuser()}:random_pw@localhost:5432/streamserver"
        engine = create_engine(connect_string, echo=True)
        create_database(engine.url)
        Base.metadata.create_all(engine)

        with engine.connect() as con:
            con.execute("CREATE EXTENSION pgcrypto;")

        yield engine

@contextlib.contextmanager
def postgresql_session():
    from sqlalchemy.orm import sessionmaker

    with postgresql_engine() as engine:
        Session = sessionmaker(bind=engine)
        session = Session()
        yield session
        session.close()

if __name__ == "__main__":
    with postgresql_session() as session:
        user1 = User(
            user_type=Type.unverified,
            reputation=0.0
        )
        session.add(user1)
        session.commit()

        assert session.query(User).count() == 1
        users = session.query(User).all()

        for user in users:
            assert user == user1

Reply via email to