http://git-wip-us.apache.org/repos/asf/cassandra/blob/57d558fc/pylib/cqlshlib/copyutil.py
----------------------------------------------------------------------
diff --cc pylib/cqlshlib/copyutil.py
index a2fab00,f699e64..a117ec3
--- a/pylib/cqlshlib/copyutil.py
+++ b/pylib/cqlshlib/copyutil.py
@@@ -23,19 -26,25 +26,25 @@@ import sy
import time
import traceback
- from StringIO import StringIO
+ from calendar import timegm
+ from collections import defaultdict, deque, namedtuple
+ from decimal import Decimal
from random import randrange
+ from StringIO import StringIO
from threading import Lock
+ from uuid import UUID
from cassandra.cluster import Cluster
+ from cassandra.cqltypes import ReversedType, UserType
from cassandra.metadata import protect_name, protect_names
- from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, TokenAwarePolicy
- from cassandra.query import tuple_factory
+ from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, TokenAwarePolicy, DCAwareRoundRobinPolicy
+ from cassandra.query import BatchStatement, BatchType, SimpleStatement, tuple_factory
+ from cassandra.util import Date, Time
-
- import sslhandling
+ from cql3handling import CqlRuleSet
from displaying import NO_COLOR_MAP
-from formatting import format_value_default, EMPTY, get_formatter
+from formatting import format_value_default, DateTimeFormat, EMPTY, get_formatter
+ from sslhandling import ssl_settings
def parse_options(shell, opts):
@@@ -65,10 -74,13 +74,15 @@@
# by default the page timeout is 10 seconds per 1000 entries in the page size or 10 seconds if pagesize is smaller
csv_options['pagetimeout'] = int(opts.pop('pagetimeout', max(10, 10 * (csv_options['pagesize'] / 1000))))
csv_options['maxattempts'] = int(opts.pop('maxattempts', 5))
- csv_options['float_precision'] = shell.display_float_precision
- csv_options['dtformats'] = opts.pop('timeformat', shell.display_time_format)
+ csv_options['dtformats'] = DateTimeFormat(opts.pop('timeformat', shell.display_timestamp_format),
+ shell.display_date_format,
+ shell.display_nanotime_format)
+ csv_options['float_precision'] = shell.display_float_precision
+ csv_options['chunksize'] = int(opts.pop('chunksize', 1000))
+ csv_options['ingestrate'] = int(opts.pop('ingestrate', 100000))
+ csv_options['maxbatchsize'] = int(opts.pop('maxbatchsize', 20))
+ csv_options['minbatchsize'] = int(opts.pop('minbatchsize', 2))
+ csv_options['reportfrequency'] = float(opts.pop('reportfrequency', 0.25))
return csv_options, dialect_options, opts
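Not part of the patch: a minimal sketch of how the default page timeout above scales with the page size (the helper name is hypothetical):

    def default_page_timeout(page_size):
        # 10 seconds per 1000 entries in the page size, but never less than 10 seconds
        return int(max(10, 10 * (page_size / 1000)))

    assert default_page_timeout(1000) == 10
    assert default_page_timeout(5000) == 50
    assert default_page_timeout(100) == 10   # small pages still get the 10 second floor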
@@@ -371,30 -648,18 +650,18 @@@ class ExportProcess(ChildProcess)
A child worker process for the export task, ExportTask.
"""
- def __init__(self, inmsg, outmsg, ks, cf, columns, dialect_options, csv_options,
- debug, port, cql_version, auth_provider, ssl, protocol_version, config_file):
- mp.Process.__init__(self, target=self.run)
- self.inmsg = inmsg
- self.outmsg = outmsg
- self.ks = ks
- self.cf = cf
- self.columns = columns
- self.dialect_options = dialect_options
+ def __init__(self, params):
+ ChildProcess.__init__(self, params=params, target=self.run)
+ self.dialect_options = params['dialect_options']
self.hosts_to_sessions = dict()
- self.debug = debug
- self.port = port
- self.cql_version = cql_version
- self.auth_provider = auth_provider
- self.ssl = ssl
- self.protocol_version = protocol_version
- self.config_file = config_file
-
+ csv_options = params['csv_options']
self.encoding = csv_options['encoding']
- self.time_format = csv_options['dtformats']
+ self.date_time_format = csv_options['dtformats']
self.float_precision = csv_options['float_precision']
self.nullval = csv_options['nullval']
- self.maxjobs = csv_options['jobs']
+ self.max_attempts = csv_options['maxattempts']
+ self.max_requests = csv_options['maxrequests']
self.csv_options = csv_options
self.formatters = dict()
@@@ -600,13 -851,424 +853,424 @@@
return query
+ class ImportConversion(object):
+ """
+ A class for converting strings to values when importing from csv, used by ImportProcess,
+ the parent.
+ """
+ def __init__(self, parent, table_meta, statement):
+ self.ks = parent.ks
+ self.cf = parent.cf
+ self.columns = parent.columns
+ self.nullval = parent.nullval
+ self.printmsg = parent.printmsg
+ self.table_meta = table_meta
+ self.primary_key_indexes = [self.columns.index(col.name) for col in self.table_meta.primary_key]
+ self.partition_key_indexes = [self.columns.index(col.name) for col in self.table_meta.partition_key]
+
+ self.proto_version = statement.protocol_version
+ self.cqltypes = dict([(c.name, c.type) for c in statement.column_metadata])
+ self.converters = dict([(c.name, self._get_converter(c.type)) for c in statement.column_metadata])
+
+ def _get_converter(self, cql_type):
+ """
+ Return a function that converts a string into a value that can be passed
+ into BoundStatement.bind() for the given cql type. See cassandra.cqltypes
+ for more details.
+ """
+ def unprotect(v):
+ if v is not None:
+ return CqlRuleSet.dequote_value(v)
+
+ def convert(t, v):
+ return converters.get(t.typename, convert_unknown)(unprotect(v), ct=t)
+
+ def split(val, sep=','):
+ """
+ Split into a list of values whenever we encounter a separator but
+ ignore separators inside parentheses or single quotes, except for the two
+ outermost parentheses, which will be ignored. We expect val to be at least
+ 2 characters long (the two outer parentheses).
+ """
+ ret = []
+ last = 1
+ level = 0
+ quote = False
+ for i, c in enumerate(val):
+ if c == '{' or c == '[' or c == '(':
+ level += 1
+ elif c == '}' or c == ']' or c == ')':
+ level -= 1
+ elif c == '\'':
+ quote = not quote
+ elif c == sep and level == 1 and not quote:
+ ret.append(val[last:i])
+ last = i + 1
+ else:
+ if last < len(val) - 1:
+ ret.append(val[last:-1])
+
+ return ret
+
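To make the splitting rule concrete, here is a self-contained sketch that mirrors the nested split() helper above; split_sketch is a hypothetical stand-in and not part of the patch:

    def split_sketch(val, sep=','):
        ret, last, level, quote = [], 1, 0, False
        for i, c in enumerate(val):
            if c in ('{', '[', '('):
                level += 1
            elif c in ('}', ']', ')'):
                level -= 1
            elif c == "'":
                quote = not quote
            elif c == sep and level == 1 and not quote:
                ret.append(val[last:i])
                last = i + 1
        else:
            if last < len(val) - 1:
                ret.append(val[last:-1])  # trailing element, outer bracket stripped
        return ret

    # separators nested one level down are left alone
    assert split_sketch("[1,[2,3],4]") == ['1', '[2,3]', '4']
    # splitting a map entry on ':' keeps the quoted key intact
    assert split_sketch("{'k':'v'}", sep=':') == ["'k'", "'v'"]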
+ # this should match all possible CQL datetime formats
+ p = re.compile("(\d{4})\-(\d{2})\-(\d{2})\s?(?:'T')?" + # YYYY-MM-DD[( |'T')]
+ "(?:(\d{2}):(\d{2})(?::(\d{2}))?)?" + # [HH:MM[:SS]]
+ "(?:([+\-])(\d{2}):?(\d{2}))?") # [(+|-)HH[:]MM]]
+
+ def convert_date(val, **_):
+ m = p.match(val)
+ if not m:
+ raise ValueError("can't interpret %r as a date" % (val,))
+
+ # https://docs.python.org/2/library/time.html#time.struct_time
+ tval = time.struct_time((int(m.group(1)), int(m.group(2)), int(m.group(3)), # year, month, day
+ int(m.group(4)) if m.group(4) else 0, # hour
+ int(m.group(5)) if m.group(5) else 0, # minute
+ int(m.group(6)) if m.group(6) else 0, # second
+ 0, 1, -1)) # day of week, day of year, dst-flag
+
+ if m.group(7):
+ offset = (int(m.group(8)) * 3600 + int(m.group(9)) * 60) * int(m.group(7) + '1')
+ else:
+ offset = -time.timezone
+
+ # scale seconds to millis for the raw value
+ return (timegm(tval) + offset) * 1e3
+
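As a quick sanity check of the pattern above (illustration only, not part of the patch), the capture groups for a timestamp with an explicit UTC offset come out as:

    import re

    p = re.compile(r"(\d{4})\-(\d{2})\-(\d{2})\s?(?:'T')?" +   # YYYY-MM-DD[( |'T')]
                   r"(?:(\d{2}):(\d{2})(?::(\d{2}))?)?" +      # [HH:MM[:SS]]
                   r"(?:([+\-])(\d{2}):?(\d{2}))?")            # [(+|-)HH[:]MM]]
    m = p.match("2015-01-14 12:30:45+0100")
    # year, month, day, hour, minute, second, sign, offset hours, offset minutes
    assert m.groups() == ('2015', '01', '14', '12', '30', '45', '+', '01', '00')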
+ def convert_tuple(val, ct=cql_type):
+ return tuple(convert(t, v) for t, v in zip(ct.subtypes, split(val)))
+
+ def convert_list(val, ct=cql_type):
+ return list(convert(ct.subtypes[0], v) for v in split(val))
+
+ def convert_set(val, ct=cql_type):
+ return frozenset(convert(ct.subtypes[0], v) for v in split(val))
+
+ def convert_map(val, ct=cql_type):
+ """
+ We need to pass a dict() to BoundStatement.bind() because it calls iteritems(),
+ except we can't create a dict with another dict as the key; hence we use a class
+ that adds iteritems to a frozenset of tuples (which is how dicts are normally made
+ immutable in Python).
+ """
+ class ImmutableDict(frozenset):
+ iteritems = frozenset.__iter__
+
+ return ImmutableDict(frozenset((convert(ct.subtypes[0], v[0]), convert(ct.subtypes[1], v[1]))
+ for v in [split('{%s}' % vv, sep=':') for vv in split(val)]))
+
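The ImmutableDict trick is worth seeing in isolation (illustration only, not part of the patch): a frozenset of (key, value) tuples that stays hashable while still offering the iteritems() the driver calls when binding a map:

    class ImmutableDict(frozenset):
        iteritems = frozenset.__iter__

    d = ImmutableDict([('a', 1), ('b', 2)])
    assert sorted(d.iteritems()) == [('a', 1), ('b', 2)]
    # equal contents hash equally, so the value can itself be used as a set member or map key
    assert hash(d) == hash(ImmutableDict([('b', 2), ('a', 1)]))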
+ def convert_user_type(val, ct=cql_type):
+ """
+ A user type is a dictionary except that we must convert each key into
+ an attribute, so we are using named tuples. It must also be hashable,
+ so we cannot use dictionaries. Maybe there is a way to instantiate ct
+ directly but I could not work it out.
+ """
+ vals = [v for v in [split('{%s}' % vv, sep=':') for vv in split(val)]]
+ ret_type = namedtuple(ct.typename, [unprotect(v[0]) for v in vals])
+ return ret_type(*tuple(convert(t, v[1]) for t, v in zip(ct.subtypes, vals)))
+
+ def convert_single_subtype(val, ct=cql_type):
+ return converters.get(ct.subtypes[0].typename, convert_unknown)(val, ct=ct.subtypes[0])
+
+ def convert_unknown(val, ct=cql_type):
+ if issubclass(ct, UserType):
+ return convert_user_type(val, ct=ct)
+ elif issubclass(ct, ReversedType):
+ return convert_single_subtype(val, ct=ct)
+
+ self.printmsg("Unknown type %s (%s) for val %s" % (ct, ct.typename, val))
+ return val
+
+ converters = {
+ 'blob': (lambda v, ct=cql_type: bytearray.fromhex(v[2:])),
+ 'decimal': (lambda v, ct=cql_type: Decimal(v)),
+ 'uuid': (lambda v, ct=cql_type: UUID(v)),
+ 'boolean': (lambda v, ct=cql_type: bool(v)),
+ 'tinyint': (lambda v, ct=cql_type: int(v)),
+ 'ascii': (lambda v, ct=cql_type: v),
+ 'float': (lambda v, ct=cql_type: float(v)),
+ 'double': (lambda v, ct=cql_type: float(v)),
+ 'bigint': (lambda v, ct=cql_type: long(v)),
+ 'int': (lambda v, ct=cql_type: int(v)),
+ 'varint': (lambda v, ct=cql_type: int(v)),
+ 'inet': (lambda v, ct=cql_type: v),
+ 'counter': (lambda v, ct=cql_type: long(v)),
+ 'timestamp': convert_date,
+ 'timeuuid': (lambda v, ct=cql_type: UUID(v)),
+ 'date': (lambda v, ct=cql_type: Date(v)),
+ 'smallint': (lambda v, ct=cql_type: int(v)),
+ 'time': (lambda v, ct=cql_type: Time(v)),
+ 'text': (lambda v, ct=cql_type: v),
+ 'varchar': (lambda v, ct=cql_type: v),
+ 'list': convert_list,
+ 'set': convert_set,
+ 'map': convert_map,
+ 'tuple': convert_tuple,
+ 'frozen': convert_single_subtype,
+ }
+
+ return converters.get(cql_type.typename, convert_unknown)
+
+ def get_row_values(self, row):
+ """
+ Parse the row into a list of row values to be returned
+ """
+ ret = [None] * len(row)
+ for i, val in enumerate(row):
+ if val != self.nullval:
+ ret[i] = self.converters[self.columns[i]](val)
+ else:
+ if i in self.primary_key_indexes:
+ message = "Cannot insert null value for primary key
column '%s'." % (self.columns[i],)
+ if self.nullval == '':
+ message += " If you want to insert empty strings,
consider using" \
+ " the WITH NULL=<marker> option for COPY."
+ raise Exception(message=message)
+
+ ret[i] = None
+
+ return ret
+
+ def get_row_partition_key_values(self, row):
+ """
+ Return a string composed of the partition key values, serialized and binary packed -
+ as expected by metadata.get_replicas(), see also BoundStatement.routing_key.
+ """
+ def serialize(n):
+ c, v = self.columns[n], row[n]
+ return self.cqltypes[c].serialize(self.converters[c](v), self.proto_version)
+
+ partition_key_indexes = self.partition_key_indexes
+ if len(partition_key_indexes) == 1:
+ return serialize(partition_key_indexes[0])
+ else:
+ pk_values = []
+ for i in partition_key_indexes:
+ val = serialize(i)
+ l = len(val)
+ pk_values.append(struct.pack(">H%dsB" % l, l, val, 0))
+ return b"".join(pk_values)
+
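For reference, a small sketch of the composite routing-key layout packed above (pack_components is a hypothetical helper, not part of the patch): each serialized component is prefixed with its length as a big-endian unsigned short and terminated with a zero byte:

    import struct

    def pack_components(serialized_values):
        parts = []
        for val in serialized_values:
            l = len(val)
            parts.append(struct.pack(">H%dsB" % l, l, val, 0))
        return b"".join(parts)

    # two components of one and two bytes respectively
    assert pack_components([b"\x01", b"ab"]) == b"\x00\x01\x01\x00" + b"\x00\x02ab\x00"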
+
+ class ImportProcess(ChildProcess):
+
+ def __init__(self, params):
+ ChildProcess.__init__(self, params=params, target=self.run)
+
+ csv_options = params['csv_options']
+ self.nullval = csv_options['nullval']
+ self.max_attempts = csv_options['maxattempts']
+ self.min_batch_size = csv_options['minbatchsize']
+ self.max_batch_size = csv_options['maxbatchsize']
+ self._session = None
+
+ @property
+ def session(self):
+ if not self._session:
+ cluster = Cluster(
+ contact_points=(self.hostname,),
+ port=self.port,
+ cql_version=self.cql_version,
+ protocol_version=self.protocol_version,
+ auth_provider=self.auth_provider,
+ load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
+ ssl_options=ssl_settings(self.hostname, self.config_file) if self.ssl else None,
+ default_retry_policy=ExpBackoffRetryPolicy(self),
+ compression=None,
+ connect_timeout=self.connect_timeout)
+
+ self._session = cluster.connect(self.ks)
+ self._session.default_timeout = None
+ return self._session
+
+ def run(self):
+ try:
+ table_meta = self.session.cluster.metadata.keyspaces[self.ks].tables[self.cf]
- is_counter = ("counter" in [table_meta.columns[name].typestring for name in self.columns])
++ is_counter = ("counter" in [table_meta.columns[name].cql_type for name in self.columns])
+
+ if is_counter:
+ self.run_counter(table_meta)
+ else:
+ self.run_normal(table_meta)
+
+ except Exception, exc:
+ if self.debug:
+ traceback.print_exc(exc)
+
+ finally:
+ self.close()
+
+ def close(self):
+ if self._session:
+ self._session.cluster.shutdown()
+ ChildProcess.close(self)
+
+ def run_counter(self, table_meta):
+ """
+ Main run method for tables that contain counter columns.
+ """
+ query = 'UPDATE %s.%s SET %%s WHERE %%s' % (protect_name(self.ks), protect_name(self.cf))
+
+ # We prepare a query statement to find out the types of the partition key columns so we can
+ # route the update query to the correct replicas. As far as I understood, this is the easiest
+ # way to find out the types of the partition columns; we will never actually use this prepared statement.
+ where_clause = ' AND '.join(['%s = ?' % (protect_name(c.name)) for c in table_meta.partition_key])
+ select_query = 'SELECT * FROM %s.%s WHERE %s' % (protect_name(self.ks), protect_name(self.cf), where_clause)
+ conv = ImportConversion(self, table_meta, self.session.prepare(select_query))
+
+ while True:
+ try:
+ batch = self.inmsg.get()
+
+ for batches in self.split_batches(batch, conv):
+ for b in batches:
+ self.send_counter_batch(query, conv, b)
+
+ except Exception, exc:
+ self.outmsg.put((batch, '%s - %s' % (exc.__class__.__name__, exc.message)))
+ if self.debug:
+ traceback.print_exc(exc)
+
+ def run_normal(self, table_meta):
+ """
+ Main run method for normal tables, i.e. tables that do not contain counter columns.
+ """
+ query = 'INSERT INTO %s.%s (%s) VALUES (%s)' % (protect_name(self.ks),
+ protect_name(self.cf),
+ ', '.join(protect_names(self.columns),),
+ ', '.join(['?' for _ in self.columns]))
+ query_statement = self.session.prepare(query)
+ conv = ImportConversion(self, table_meta, query_statement)
+
+ while True:
+ try:
+ batch = self.inmsg.get()
+
+ for batches in self.split_batches(batch, conv):
+ for b in batches:
+ self.send_normal_batch(conv, query_statement, b)
+
+ except Exception, exc:
+ self.outmsg.put((batch, '%s - %s' % (exc.__class__.__name__, exc.message)))
+ if self.debug:
+ traceback.print_exc(exc)
+
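To illustrate the statement text built above, a sketch with made-up keyspace, table and column names, and protect_name() quoting omitted (not part of the patch):

    columns = ['a', 'b', 'c']
    query = 'INSERT INTO %s.%s (%s) VALUES (%s)' % ('ks1', 't1',
                                                    ', '.join(columns),
                                                    ', '.join(['?' for _ in columns]))
    assert query == 'INSERT INTO ks1.t1 (a, b, c) VALUES (?, ?, ?)'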
+ def send_counter_batch(self, query_text, conv, batch):
+ if self.test_failures and self.maybe_inject_failures(batch):
+ return
+
+ columns = self.columns
+ batch_statement = BatchStatement(batch_type=BatchType.COUNTER, consistency_level=self.consistency_level)
+ for row in batch['rows']:
+ where_clause = []
+ set_clause = []
+ for i, value in enumerate(row):
+ if i in conv.primary_key_indexes:
+ where_clause.append("%s=%s" % (columns[i], value))
+ else:
+ set_clause.append("%s=%s+%s" % (columns[i], columns[i], value))
+
+ full_query_text = query_text % (','.join(set_clause), ' AND '.join(where_clause))
+ batch_statement.add(full_query_text)
+
+ self.execute_statement(batch_statement, batch)
+
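The counter path above builds plain CQL text per row instead of binding a prepared statement; a sketch with made-up names ('pk' as the key column, 'cnt' as the counter; not part of the patch):

    query_text = 'UPDATE ks1.t1 SET %s WHERE %s'
    where_clause = ['pk=1']
    set_clause = ['cnt=cnt+5']
    full_query_text = query_text % (','.join(set_clause), ' AND '.join(where_clause))
    assert full_query_text == 'UPDATE ks1.t1 SET cnt=cnt+5 WHERE pk=1'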
+ def send_normal_batch(self, conv, query_statement, batch):
+ try:
+ if self.test_failures and self.maybe_inject_failures(batch):
+ return
+
+ batch_statement = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=self.consistency_level)
+ for row in batch['rows']:
+ batch_statement.add(query_statement, conv.get_row_values(row))
+
+ self.execute_statement(batch_statement, batch)
+
+ except Exception, exc:
+ self.err_callback(exc, batch)
+
+ def maybe_inject_failures(self, batch):
+ """
+ Examine self.test_failures and see if the current batch is one that is
+ supposed to cause a failure (failing_batch) or to terminate the worker process
+ (exit_batch). Return True if a failure was injected, False to carry on as normal.
+ """
+ if 'failing_batch' in self.test_failures:
+ failing_batch = self.test_failures['failing_batch']
+ if failing_batch['id'] == batch['id']:
+ if batch['attempts'] < failing_batch['failures']:
+ statement = SimpleStatement("INSERT INTO badtable (a, b) VALUES (1, 2)",
+ consistency_level=self.consistency_level)
+ self.execute_statement(statement, batch)
+ return True
+
+ if 'exit_batch' in self.test_failures:
+ exit_batch = self.test_failures['exit_batch']
+ if exit_batch['id'] == batch['id']:
+ sys.exit(1)
+
+ return False # carry on as normal
+
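For clarity, a sketch of the test_failures layout the method above consults; the ids and counts are made up (not part of the patch):

    test_failures = {
        'failing_batch': {'id': 3, 'failures': 2},  # batch 3 fails on its first 2 attempts
        'exit_batch': {'id': 7},                    # the worker exits when it sees batch 7
    }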
+ def execute_statement(self, statement, batch):
+ future = self.session.execute_async(statement)
+ future.add_callbacks(callback=self.result_callback, callback_args=(batch, ),
+ errback=self.err_callback, errback_args=(batch, ))
+
+ def split_batches(self, batch, conv):
+ """
+ Split a batch into sub-batches with the same
+ partition key, if possible. If there are at least
+ batch_size rows with the same partition key value then
+ create a sub-batch with that partition key value, else
+ aggregate all remaining rows in a single 'left-overs' batch
+ """
+ rows_by_pk = defaultdict(list)
+
+ for row in batch['rows']:
+ pk = conv.get_row_partition_key_values(row)
+ rows_by_pk[pk].append(row)
+
+ ret = dict()
+ remaining_rows = []
+
+ for pk, rows in rows_by_pk.items():
+ if len(rows) >= self.min_batch_size:
+ ret[pk] = self.batches(rows, batch)
+ else:
+ remaining_rows.extend(rows)
+
+ if remaining_rows:
+ ret[self.hostname] = self.batches(remaining_rows, batch)
+
+ return ret.itervalues()
+
+ def batches(self, rows, batch):
+ for i in xrange(0, len(rows), self.max_batch_size):
+ yield ImportTask.make_batch(batch['id'], rows[i:i + self.max_batch_size], batch['attempts'])
+
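The slicing in batches() above is plain fixed-size chunking; a stand-alone sketch (the max batch size of 3 is arbitrary, not part of the patch):

    def chunks(rows, max_batch_size):
        for i in range(0, len(rows), max_batch_size):
            yield rows[i:i + max_batch_size]

    assert list(chunks([1, 2, 3, 4, 5, 6, 7], 3)) == [[1, 2, 3], [4, 5, 6], [7]]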
+ def result_callback(self, result, batch):
+ batch['imported'] = len(batch['rows'])
+ batch['rows'] = [] # no need to resend these
+ self.outmsg.put((batch, None))
+
+ def err_callback(self, response, batch):
+ batch['imported'] = len(batch['rows'])
+ self.outmsg.put((batch, '%s - %s' % (response.__class__.__name__, response.message)))
+ if self.debug:
+ traceback.print_exc(response)
+
+
class RateMeter(object):
- def __init__(self, log_threshold):
- self.log_threshold = log_threshold # number of records after which we log
- self.last_checkpoint_time = time.time() # last time we logged
+ def __init__(self, update_interval=0.25, log=True):
+ self.log = log # true if we should log
+ self.update_interval = update_interval # how often we update in seconds
+ self.start_time = time.time() # the start time
+ self.last_checkpoint_time = self.start_time # last time we logged
self.current_rate = 0.0 # rows per second
- self.current_record = 0 # number of records since we last logged
+ self.current_record = 0 # number of records since we last updated
self.total_records = 0 # total number of records
def increment(self, n=1):