Author: eevans Date: Mon Feb 21 16:55:19 2011 New Revision: 1073064 URL: http://svn.apache.org/viewvc?rev=1073064&view=rev Log: python CQL driver result decoding
Patch by eevans for CASSANDRA-1711 Added: cassandra/trunk/drivers/py/cql/decoders.py cassandra/trunk/drivers/py/cql/results.py Modified: cassandra/trunk/drivers/py/cql/connection.py cassandra/trunk/drivers/py/cql/connection_pool.py cassandra/trunk/drivers/py/cql/marshal.py cassandra/trunk/drivers/py/cqlsh Modified: cassandra/trunk/drivers/py/cql/connection.py URL: http://svn.apache.org/viewvc/cassandra/trunk/drivers/py/cql/connection.py?rev=1073064&r1=1073063&r2=1073064&view=diff ============================================================================== --- cassandra/trunk/drivers/py/cql/connection.py (original) +++ cassandra/trunk/drivers/py/cql/connection.py Mon Feb 21 16:55:19 2011 @@ -21,7 +21,9 @@ from thrift.protocol import TBinaryProto from thrift.Thrift import TApplicationException from errors import CQLException, InvalidCompressionScheme from marshal import prepare -import zlib +from decoders import SchemaDecoder +from results import RowsProxy +import zlib, re try: from cassandra import Cassandra @@ -57,20 +59,79 @@ class Connection(object): ... for column in row.columns: ... print "%s is %s years of age" % (r.key, column.age) """ + _keyspace_re = re.compile("USE (\w+);?", re.I | re.M) + _cfamily_re = re.compile("SELECT\s+.+\s+FROM\s+(\w+)", re.I | re.M) + def __init__(self, host, port=9160, keyspace=None, username=None, - password=None): + password=None, decoder=None): + """ + Params: + * host .........: hostname of Cassandra node. + * port .........: port number to connect to (optional). + * keyspace .....: keyspace name (optional). + * username .....: username used in authentication (optional). + * password .....: password used in authentication (optional). + * decoder ......: result decoder instance (optional, defaults to none). 
+ """ socket = TSocket.TSocket(host, port) self.transport = TTransport.TFramedTransport(socket) protocol = TBinaryProtocol.TBinaryProtocolAccelerated(self.transport) self.client = Cassandra.Client(protocol) socket.open() + # XXX: "current" is probably a misnomer. + self._cur_keyspace = None + self._cur_column_family = None + if username and password: credentials = {"username": username, "password": password} self.client.login(AuthenticationRequest(credentials=credentials)) if keyspace: self.execute('USE %s;' % keyspace) + self._cur_keyspace = keyspace + + if not decoder: + self.decoder = SchemaDecoder(self.__get_schema()) + else: + self.decoder = decoder + + def __get_schema(self): + def columns(metadata): + results = {} + for col in metadata: + results[col.name] = col.validation_class + return results + + def column_families(cf_defs): + cfresults = {} + for cf in cf_defs: + cfresults[cf.name] = {"comparator": cf.comparator_type} + cfresults[cf.name]["default_validation_class"] = \ + cf.default_validation_class + cfresults[cf.name]["columns"] = columns(cf.column_metadata) + return cfresults + + schema = {} + for ksdef in self.client.describe_keyspaces(): + schema[ksdef.name] = column_families(ksdef.cf_defs) + return schema + + def prepare(self, query, *args): + prepared_query = prepare(query, *args) + + # Snag the keyspace or column family and stash it for later use in + # decoding columns. These regexes don't match every query, but the + # current column family only needs to be current for SELECTs. 
+ match = Connection._cfamily_re.match(prepared_query) + if match: + self._cur_column_family = match.group(1) + else: + match = Connection._keyspace_re.match(prepared_query) + if match: + self._cur_keyspace = match.group(1) + + return prepared_query def execute(self, query, *args, **kwargs): """ @@ -85,8 +146,8 @@ class Connection(object): compress = kwargs.get("compression").upper() else: compress = DEFAULT_COMPRESSION - - compressed_query = Connection.compress_query(prepare(query, *args), + + compressed_query = Connection.compress_query(self.prepare(query, *args), compress) request_compression = getattr(Compression, compress) @@ -101,7 +162,11 @@ class Connection(object): raise CQLException(exc) if response.type == CqlResultType.ROWS: - return response.rows + return RowsProxy(response.rows, + self._cur_keyspace, + self._cur_column_family, + self.decoder) + if response.type == CqlResultType.INT: return response.num @@ -127,5 +192,5 @@ class Connection(object): if compression == 'GZIP': return zlib.compress(query) - + # vi: ai ts=4 tw=0 sw=4 et Modified: cassandra/trunk/drivers/py/cql/connection_pool.py URL: http://svn.apache.org/viewvc/cassandra/trunk/drivers/py/cql/connection_pool.py?rev=1073064&r1=1073063&r2=1073064&view=diff ============================================================================== --- cassandra/trunk/drivers/py/cql/connection_pool.py (original) +++ cassandra/trunk/drivers/py/cql/connection_pool.py Mon Feb 21 16:55:19 2011 @@ -38,12 +38,14 @@ class ConnectionPool(object): >>> pool.return_connection(conn) """ def __init__(self, hostname, port=9160, keyspace=None, username=None, - password=None, max_conns=25, max_idle=5, eviction_delay=10000): + password=None, decoder=None, max_conns=25, max_idle=5, + eviction_delay=10000): self.hostname = hostname self.port = port self.keyspace = keyspace self.username = username self.password = password + self.decoder = decoder self.max_conns = max_conns self.max_idle = max_idle self.eviction_delay = 
eviction_delay @@ -59,7 +61,8 @@ class ConnectionPool(object): port=self.port, keyspace=self.keyspace, username=self.username, - password=self.password) + password=self.password, + decoder=self.decoder) def borrow_connection(self): try: Added: cassandra/trunk/drivers/py/cql/decoders.py URL: http://svn.apache.org/viewvc/cassandra/trunk/drivers/py/cql/decoders.py?rev=1073064&view=auto ============================================================================== --- cassandra/trunk/drivers/py/cql/decoders.py (added) +++ cassandra/trunk/drivers/py/cql/decoders.py Mon Feb 21 16:55:19 2011 @@ -0,0 +1,61 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from os.path import abspath, dirname, join +from marshal import unmarshal + +class BaseDecoder(object): + def decode_column(self, keyspace, column_family, name, value): + raise NotImplementedError() + +class NoopDecoder(BaseDecoder): + def decode_column(self, keyspace, column_family, name, value): + return (name, value) + +class SchemaDecoder(BaseDecoder): + """ + Decode binary column names/values according to schema. 
+ """ + def __init__(self, schema={}): + self.schema = schema + + def __get_column_family_def(self, keyspace, column_family): + if self.schema.has_key(keyspace): + if self.schema[keyspace].has_key(column_family): + return self.schema[keyspace][column_family] + return None + + def __comparator_for(self, keyspace, column_family): + cfam = self.__get_column_family_def(keyspace, column_family) + if cfam and cfam.has_key("comparator"): + return cfam["comparator"] + return None + + def __validator_for(self, keyspace, column_family, name): + cfam = self.__get_column_family_def(keyspace, column_family) + if cfam: + if cfam["columns"].has_key(name): + return cfam["columns"][name] + else: + return cfam["default_validation_class"] + return None + + def decode_column(self, keyspace, column_family, name, value): + comparator = self.__comparator_for(keyspace, column_family) + validator = self.__validator_for(keyspace, column_family, name) + return (unmarshal(name, comparator), unmarshal(value, validator)) + Modified: cassandra/trunk/drivers/py/cql/marshal.py URL: http://svn.apache.org/viewvc/cassandra/trunk/drivers/py/cql/marshal.py?rev=1073064&r1=1073063&r2=1073064&view=diff ============================================================================== --- cassandra/trunk/drivers/py/cql/marshal.py (original) +++ cassandra/trunk/drivers/py/cql/marshal.py Mon Feb 21 16:55:19 2011 @@ -18,8 +18,9 @@ from uuid import UUID from StringIO import StringIO from errors import InvalidQueryFormat +from struct import unpack -__all__ = ['prepare'] +__all__ = ['prepare', 'marshal', 'unmarshal'] def prepare(query, *args): result = StringIO() @@ -60,3 +61,28 @@ def marshal(term): return "uuid(\"%s\")" % str(term) else: return str(term) + +def unmarshal(bytestr, typestr): + if typestr == "org.apache.cassandra.db.marshal.BytesType": + return bytestr + elif typestr == "org.apache.cassandra.db.marshal.AsciiType": + return bytestr + elif typestr == "org.apache.cassandra.db.marshal.UTF8Type": + return 
bytestr.decode("utf8") + elif typestr == "org.apache.cassandra.db.marshal.IntegerType": + return decode_bigint(bytestr) + elif typestr == "org.apache.cassandra.db.marshal.LongType": + return unpack(">q", bytestr)[0] + elif typestr == "org.apache.cassandra.db.marshal.LexicalUUIDType": + return UUID(bytes=bytestr) + elif typestr == "org.apache.cassandra.db.marshal.TimeUUIDType": + return UUID(bytes=bytestr) + else: + return bytestr + +def decode_bigint(term): + val = int(term.encode('hex'), 16) + if (ord(term[0]) & 128) != 0: + val = val - (1 << (len(term) * 8)) + return val + Added: cassandra/trunk/drivers/py/cql/results.py URL: http://svn.apache.org/viewvc/cassandra/trunk/drivers/py/cql/results.py?rev=1073064&view=auto ============================================================================== --- cassandra/trunk/drivers/py/cql/results.py (added) +++ cassandra/trunk/drivers/py/cql/results.py Mon Feb 21 16:55:19 2011 @@ -0,0 +1,82 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +class RowsProxy(object): + def __init__(self, rows, keyspace, cfam, decoder): + self.rows = rows + self.keyspace = keyspace + self.cfam = cfam + self.decoder = decoder + + def __len__(self): + return len(self.rows) + + def __getitem__(self, idx): + return Row(self.rows[idx].key, + self.rows[idx].columns, + self.keyspace, + self.cfam, + self.decoder) + + def __iter__(self): + for r in self.rows: + yield Row(r.key, r.columns, self.keyspace, self.cfam, self.decoder) + +class Row(object): + def __init__(self, key, columns, keyspace, cfam, decoder): + self.key = key + self.columns = ColumnsProxy(columns, keyspace, cfam, decoder) + +class ColumnsProxy(object): + def __init__(self, columns, keyspace, cfam, decoder): + self.columns = columns + self.keyspace = keyspace + self.cfam = cfam + self.decoder = decoder + + def __len__(self): + return len(self.columns) + + def __getitem__(self, idx): + return Column(self.decoder.decode_column(self.keyspace, + self.cfam, + self.columns[idx].name, + self.columns[idx].value)) + + def __iter__(self): + for c in self.columns: + yield Column(self.decoder.decode_column(self.keyspace, + self.cfam, + c.name, + c.value)) + + def __str__(self): + return "ColumnsProxy(columns=%s)" % self.columns + + def __repr__(self): + return str(self) + +class Column(object): + def __init__(self, (name, value)): + self.name = name + self.value = value + + def __str__(self): + return "Column(%s, %s)" % (self.name, self.value) + + def __repr__(self): + return str(self) \ No newline at end of file Modified: cassandra/trunk/drivers/py/cqlsh URL: http://svn.apache.org/viewvc/cassandra/trunk/drivers/py/cqlsh?rev=1073064&r1=1073063&r2=1073064&view=diff ============================================================================== --- cassandra/trunk/drivers/py/cqlsh (original) +++ cassandra/trunk/drivers/py/cqlsh Mon Feb 21 16:55:19 2011 @@ -12,10 +12,12 @@ import re try: from cql import Connection from cql.errors import CQLException + from cql.results import 
RowsProxy except ImportError: sys.path.append(os.path.abspath(os.path.dirname(__file__))) from cql import Connection from cql.errors import CQLException + from cql.results import RowsProxy HISTORY = os.path.join(os.path.expanduser('~'), '.cqlsh') CQLTYPES = ("bytes", "ascii", "utf8", "timeuuid", "uuid", "long", "int") @@ -76,7 +78,7 @@ class Shell(cmd.Cmd): result = self.conn.execute(statement) - if isinstance(result, list): + if isinstance(result, RowsProxy): for row in result: self.printout(row.key, BLUE, False) for column in row.columns: