Repository: cassandra
Updated Branches:
  refs/heads/cassandra-3.0 942e5e599 -> 1deb04432

(cqlsh) further optimise COPY FROM

patch by Stefania Alborghetti; reviewed by Adam Holmberg for CASSANDRA-9302


Branch: refs/heads/cassandra-3.0
Commit: 124f1bd2613e400f69f8369ada0ad15c28738530
Parents: 994250c
Author: Stefania Alborghetti <>
Authored: Thu Oct 22 17:16:50 2015 +0800
Committer: Aleksey Yeschenko <>
Committed: Tue Dec 15 21:03:31 2015 +0000

 CHANGES.txt                |   4 +-
 bin/cqlsh                  | 285 ++-----------
 pylib/cqlshlib/copyutil.py | 910 ++++++++++++++++++++++++++++++++++------
 pylib/cqlshlib/util.py     |  19 +
 4 files changed, 838 insertions(+), 380 deletions(-)
diff --git a/CHANGES.txt b/CHANGES.txt
index 8e58703..90f1bca 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,12 +1,10 @@
-<<<<<<< HEAD
+ * (cqlsh) further optimise COPY FROM (CASSANDRA-9302)
  * Make Stress compiles within eclipse (CASSANDRA-10807)
  * Cassandra Daemon should print JVM arguments (CASSANDRA-10764)
  * Allow cancellation of index summary redistribution (CASSANDRA-8805)
  * sstableloader will fail if there are collections in the schema tables 
->>>>>>> 5377183... stableloader will fail if there are collections in the 
schema tables
  * Disable reloading of GossipingPropertyFileSnitch (CASSANDRA-9474)
  * Fix Stress profile parsing on Windows (CASSANDRA-10808)
diff --git a/bin/cqlsh b/bin/cqlsh
index e72624a..651420d 100755
--- a/bin/cqlsh
+++ b/bin/cqlsh
@@ -37,7 +37,6 @@ import ConfigParser
 import csv
 import getpass
 import locale
-import multiprocessing as mp
 import optparse
 import os
 import platform
@@ -48,7 +47,6 @@ import warnings
 from StringIO import StringIO
 from contextlib import contextmanager
-from functools import partial
 from glob import glob
 from uuid import UUID
@@ -110,10 +108,10 @@ except ImportError, e:
 from cassandra.auth import PlainTextAuthProvider
 from cassandra.cluster import Cluster, PagedResult
-from cassandra.metadata import protect_name, protect_names, protect_value
+from cassandra.metadata import protect_name, protect_names
 from cassandra.policies import WhiteListRoundRobinPolicy
-from cassandra.protocol import QueryMessage, ResultMessage
-from cassandra.query import SimpleStatement, ordered_dict_factory
+from cassandra.protocol import ResultMessage
+from cassandra.query import SimpleStatement, ordered_dict_factory, 
 # cqlsh should run correctly when run out of a Cassandra source tree,
 # out of an unpacked Cassandra tarball, and after a proper package install.
@@ -334,7 +332,7 @@ cqlsh_extra_syntax_rules = r'''
 <copyOptionVal> ::= <identifier>
                   | <reserved_identifier>
-                  | <stringLiteral>
+                  | <term>
 # avoiding just "DEBUG" so that this rule doesn't get treated as a terminal
@@ -412,17 +410,20 @@ def complete_copy_column_names(ctxt, cqlsh):
     return set(colnames[1:]) - set(existcols)
+                       'MAXATTEMPTS', 'REPORTFREQUENCY']
 @cqlsh_syntax_completer('copyOption', 'optnames')
 def complete_copy_options(ctxt, cqlsh):
     optnames = map(str.upper, ctxt.get_binding('optnames', ()))
     direction = ctxt.get_binding('dir').upper()
-    opts = set(COPY_OPTIONS) - set(optnames)
     if direction == 'FROM':
-        opts -= set(['ENCODING', 'TIMEFORMAT', 'JOBS', 'PAGESIZE', 
+        opts = set(COPY_COMMON_OPTIONS + COPY_FROM_OPTIONS) - set(optnames)
+    elif direction == 'TO':
+        opts = set(COPY_COMMON_OPTIONS + COPY_TO_OPTIONS) - set(optnames)
     return opts
@@ -1520,12 +1521,18 @@ class Shell(cmd.Cmd):
           ESCAPE='\'              - character to appear before the QUOTE char 
when quoted
           HEADER=false            - whether to ignore the first line
           NULL=''                 - string that represents a null value
-          ENCODING='utf8'         - encoding for CSV output (COPY TO only)
-          TIMEFORMAT=             - timestamp strftime format (COPY TO only)
+          ENCODING='utf8'         - encoding for CSV output (COPY TO)
+          TIMEFORMAT=             - timestamp strftime format (COPY TO)
             '%Y-%m-%d %H:%M:%S%z'   defaults to time_format value in cqlshrc
-          PAGESIZE='1000'         - the page size for fetching results (COPY 
TO only)
-          PAGETIMEOUT=10          - the page timeout for fetching results 
(COPY TO only)
-          MAXATTEMPTS='5'         - the maximum number of attempts for errors 
(COPY TO only)
+          MAXREQUESTS=6           - the maximum number of requests each worker 
process can work on in parallel (COPY TO)
+          PAGESIZE=1000           - the page size for fetching results (COPY 
+          PAGETIMEOUT=10          - the page timeout for fetching results 
+          MAXATTEMPTS=5           - the maximum number of attempts for errors
+          CHUNKSIZE=1000          - the size of chunks passed to worker 
processes (COPY FROM)
+          INGESTRATE=100000       - an approximate ingest rate in rows per 
second (COPY FROM)
+          MAXBATCHSIZE=20         - the maximum size of an import batch (COPY 
+          MINBATCHSIZE=2          - the minimum size of an import batch (COPY 
+          REPORTFREQUENCY=0.25    - the frequency with which we display status 
updates in seconds
         When entering CSV data on STDIN, you can use the sequence "\."
         on a line by itself to end the data input.
@@ -1571,253 +1578,11 @@ class Shell(cmd.Cmd):
     def perform_csv_import(self, ks, cf, columns, fname, opts):
         csv_options, dialect_options, unrecognized_options = 
copyutil.parse_options(self, opts)
         if unrecognized_options:
-            self.printerr('Unrecognized COPY FROM options: %s'
-                          % ', '.join(unrecognized_options.keys()))
+            self.printerr('Unrecognized COPY FROM options: %s' % ', 
             return 0
-        nullval, header = csv_options['nullval'], csv_options['header']
-        if fname is None:
-            do_close = False
-            print "[Use \. on a line by itself to end input]"
-            linesource = self.use_stdin_reader(prompt='[copy] ', until=r'\.')
-        else:
-            do_close = True
-            try:
-                linesource = open(fname, 'rb')
-            except IOError, e:
-                self.printerr("Can't open %r for reading: %s" % (fname, e))
-                return 0
-        current_record = None
-        processes, pipes = [], [],
-        try:
-            if header:
-            reader = csv.reader(linesource, **dialect_options)
-            num_processes = copyutil.get_num_processes(cap=4)
-            for i in range(num_processes):
-                parent_conn, child_conn = mp.Pipe()
-                pipes.append(parent_conn)
-                proc_args = (child_conn, ks, cf, columns, nullval)
-                processes.append(mp.Process(target=self.multiproc_import, 
-            for process in processes:
-                process.start()
-            meter = copyutil.RateMeter(10000)
-            for current_record, row in enumerate(reader, start=1):
-                # write to the child process
-                pipes[current_record % num_processes].send((current_record, 
-                # update the progress and current rate periodically
-                meter.increment()
-                # check for any errors reported by the children
-                if (current_record % 100) == 0:
-                    if self._check_import_processes(current_record, pipes):
-                        # no errors seen, continue with outer loop
-                        continue
-                    else:
-                        # errors seen, break out of outer loop
-                        break
-        except Exception, exc:
-            if current_record is None:
-                # we failed before we started
-                self.printerr("\nError starting import process:\n")
-                self.printerr(str(exc))
-                if self.debug:
-                    traceback.print_exc()
-            else:
-                self.printerr("\n" + str(exc))
-                self.printerr("\nAborting import at record #%d. "
-                              "Previously inserted records and some records 
after "
-                              "this number may be present."
-                              % (current_record,))
-                if self.debug:
-                    traceback.print_exc()
-        finally:
-            # send a message that indicates we're done
-            for pipe in pipes:
-                pipe.send((None, None))
-            for process in processes:
-                process.join()
-            self._check_import_processes(current_record, pipes)
-            for pipe in pipes:
-                pipe.close()
-            if do_close:
-                linesource.close()
-            elif self.tty:
-                print
-        return current_record
-    def _check_import_processes(self, current_record, pipes):
-        for pipe in pipes:
-            if pipe.poll():
-                try:
-                    (record_num, error) = pipe.recv()
-                    self.printerr("\n" + str(error))
-                    self.printerr(
-                        "Aborting import at record #%d. "
-                        "Previously inserted records are still present, "
-                        "and some records after that may be present as well."
-                        % (record_num,))
-                    return False
-                except EOFError:
-                    # pipe is closed, nothing to read
-                    self.printerr("\nChild process died without notification, "
-                                  "aborting import at record #%d. Previously "
-                                  "inserted records are probably still 
present, "
-                                  "and some records after that may be present "
-                                  "as well." % (current_record,))
-                    return False
-        return True
-    def multiproc_import(self, pipe, ks, cf, columns, nullval):
-        """
-        This method is where child processes start when doing a COPY FROM
-        operation.  The child process will open one connection to the node and
-        interact directly with the connection, bypassing most of the driver
-        code.  Because we don't need retries, connection pooling, thread 
-        and other fancy features, this is okay.
-        """
-        # open a new connection for this subprocess
-        new_cluster = Cluster(
-            contact_points=(self.hostname,),
-            port=self.port,
-            cql_version=self.conn.cql_version,
-            protocol_version=DEFAULT_PROTOCOL_VERSION,
-            auth_provider=self.auth_provider,
-            ssl_options=sslhandling.ssl_settings(self.hostname, CONFIG_FILE) 
if self.ssl else None,
-            load_balancing_policy=WhiteListRoundRobinPolicy([self.hostname]),
-            compression=None,
-            connect_timeout=self.conn.connect_timeout)
-        session = new_cluster.connect(self.keyspace)
-        conn = session._pools.values()[0]._connection
-        # pre-build as much of the query as we can
-        table_meta = self.get_table_meta(ks, cf)
-        pk_cols = [ for col in table_meta.primary_key]
-        cqltypes = [table_meta.columns[name].typestring for name in columns]
-        pk_indexes = [columns.index( for col in 
-        is_counter_table = ("counter" in cqltypes)
-        if is_counter_table:
-            query = 'Update %s.%s SET %%s WHERE %%s' % (
-                protect_name(ks),
-                protect_name(cf))
-        else:
-            query = 'INSERT INTO %s.%s (%s) VALUES (%%s)' % (
-                protect_name(ks),
-                protect_name(cf),
-                ', '.join(protect_names(columns)))
-        # we need to handle some types specially
-        should_escape = [t in ('ascii', 'text', 'timestamp', 'date', 'time', 
'inet') for t in cqltypes]
-        insert_timestamp = int(time.time() * 1e6)
-        def callback(record_num, response):
-            # This is the callback we register for all inserts.  Because this
-            # is run on the event-loop thread, we need to hold a lock when
-            # adjusting in_flight.
-            with conn.lock:
-                conn.in_flight -= 1
-            if not isinstance(response, ResultMessage):
-                # It's an error. Notify the parent process and let it send
-                # a stop signal to all child processes (including this one).
-                pipe.send((record_num, str(response)))
-                if isinstance(response, Exception) and self.debug:
-                    traceback.print_exc(response)
-        current_record = 0
-        insert_num = 0
-        try:
-            while True:
-                # To avoid totally maxing out the connection,
-                # defer to the reactor thread when we're close
-                # to capacity
-                if conn.in_flight > (conn.max_request_id * 0.9):
-                    conn._readable = True
-                    time.sleep(0.05)
-                    continue
-                try:
-                    (current_record, row) = pipe.recv()
-                except EOFError:
-                    # the pipe was closed and there's nothing to receive
-                    sys.stdout.write('Failed to read from pipe:\n\n')
-                    sys.stdout.flush()
-                    conn._writable = True
-                    conn._readable = True
-                    break
-                # see if the parent process has signaled that we are done
-                if (current_record, row) == (None, None):
-                    conn._writable = True
-                    conn._readable = True
-                    pipe.close()
-                    break
-                # format the values in the row
-                for i, value in enumerate(row):
-                    if value != nullval:
-                        if should_escape[i]:
-                            row[i] = protect_value(value)
-                    elif i in pk_indexes:
-                        # By default, nullval is an empty string. See 
CASSANDRA-7792 for details.
-                        message = "Cannot insert null value for primary key 
column '%s'." % (pk_cols[i],)
-                        if nullval == '':
-                            message += " If you want to insert empty strings, 
consider using " \
-                                       "the WITH NULL=<marker> option for 
-                        pipe.send((current_record, message))
-                        return
-                    else:
-                        row[i] = 'null'
-                if is_counter_table:
-                    where_clause = []
-                    set_clause = []
-                    for i, value in enumerate(row):
-                        if i in pk_indexes:
-                            where_clause.append("%s=%s" % (columns[i], value))
-                        else:
-                            set_clause.append("%s=%s+%s" % (columns[i], 
columns[i], value))
-                    full_query = query % (','.join(set_clause), ' AND 
-                else:
-                    full_query = query % (','.join(row),)
-                query_message = QueryMessage(
-                    full_query, self.consistency_level, 
-                    fetch_size=None, paging_state=None, 
-                request_id = conn.get_request_id()
-                conn.send_msg(query_message, request_id=request_id, 
cb=partial(callback, current_record))
-                with conn.lock:
-                    conn.in_flight += 1
-                # every 50 records, clear the pending writes queue and read
-                # any responses we have
-                if insert_num % 50 == 0:
-                    conn._writable = True
-                    conn._readable = True
-                insert_num += 1
-        except Exception, exc:
-            pipe.send((current_record, exc))
-        finally:
-            # wait for any pending requests to finish
-            while conn.in_flight > 0:
-                conn._readable = True
-                time.sleep(0.01)
-            new_cluster.shutdown()
+        return copyutil.ImportTask(self, ks, cf, columns, fname, csv_options, 
+                                   DEFAULT_PROTOCOL_VERSION, CONFIG_FILE).run()
     def perform_csv_export(self, ks, cf, columns, fname, opts):
         csv_options, dialect_options, unrecognized_options = 
copyutil.parse_options(self, opts)
diff --git a/pylib/cqlshlib/ b/pylib/cqlshlib/
index 8534b98..f699e64 100644
--- a/pylib/cqlshlib/
+++ b/pylib/cqlshlib/
@@ -19,23 +19,32 @@ import json
 import multiprocessing as mp
 import os
 import Queue
+import random
+import re
+import struct
 import sys
 import time
 import traceback
-from StringIO import StringIO
+from calendar import timegm
+from collections import defaultdict, deque, namedtuple
+from decimal import Decimal
 from random import randrange
+from StringIO import StringIO
 from threading import Lock
+from uuid import UUID
 from cassandra.cluster import Cluster
+from cassandra.cqltypes import ReversedType, UserType
 from cassandra.metadata import protect_name, protect_names
-from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, 
-from cassandra.query import tuple_factory
+from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, 
TokenAwarePolicy, DCAwareRoundRobinPolicy
+from cassandra.query import BatchStatement, BatchType, SimpleStatement, 
+from cassandra.util import Date, Time
-import sslhandling
+from cql3handling import CqlRuleSet
 from displaying import NO_COLOR_MAP
 from formatting import format_value_default, EMPTY, get_formatter
+from sslhandling import ssl_settings
 def parse_options(shell, opts):
@@ -60,13 +69,18 @@ def parse_options(shell, opts):
     csv_options['nullval'] = opts.pop('null', '')
     csv_options['header'] = bool(opts.pop('header', '').lower() == 'true')
     csv_options['encoding'] = opts.pop('encoding', 'utf8')
-    csv_options['jobs'] = int(opts.pop('jobs', 12))
+    csv_options['maxrequests'] = int(opts.pop('maxrequests', 6))
     csv_options['pagesize'] = int(opts.pop('pagesize', 1000))
     # by default the page timeout is 10 seconds per 1000 entries in the page 
size or 10 seconds if pagesize is smaller
     csv_options['pagetimeout'] = int(opts.pop('pagetimeout', max(10, 10 * 
(csv_options['pagesize'] / 1000))))
     csv_options['maxattempts'] = int(opts.pop('maxattempts', 5))
     csv_options['dtformats'] = opts.pop('timeformat', 
     csv_options['float_precision'] = shell.display_float_precision
+    csv_options['chunksize'] = int(opts.pop('chunksize', 1000))
+    csv_options['ingestrate'] = int(opts.pop('ingestrate', 100000))
+    csv_options['maxbatchsize'] = int(opts.pop('maxbatchsize', 20))
+    csv_options['minbatchsize'] = int(opts.pop('minbatchsize', 2))
+    csv_options['reportfrequency'] = float(opts.pop('reportfrequency', 0.25))
     return csv_options, dialect_options, opts
@@ -86,9 +100,9 @@ def get_num_processes(cap):
         return 1
-class ExportTask(object):
+class CopyTask(object):
-    A class that exports data to .csv by instantiating one or more processes 
that work in parallel (ExportProcess).
+    A base class for ImportTask and ExportTask
     def __init__(self, shell, ks, cf, columns, fname, csv_options, 
dialect_options, protocol_version, config_file): = shell
@@ -101,6 +115,55 @@ class ExportTask(object):
         self.protocol_version = protocol_version
         self.config_file = config_file
+        self.processes = []
+        self.inmsg = mp.Queue()
+        self.outmsg = mp.Queue()
+    def close(self):
+        for process in self.processes:
+            process.terminate()
+        self.inmsg.close()
+        self.outmsg.close()
+    def num_live_processes(self):
+        return sum(1 for p in self.processes if p.is_alive())
+    def make_params(self):
+        """
+        Return a dictionary of parameters to be used by the worker processes.
+        On Windows this dictionary must be pickle-able.
+        inmsg is the message queue flowing from parent to child process, so 
outmsg from the parent point
+        of view and, vice-versa,  outmsg is the message queue flowing from 
child to parent, so inmsg
+        from the parent point of view, hence the two are swapped below.
+        """
+        shell =
+        return dict(inmsg=self.outmsg,  # see comment above
+                    outmsg=self.inmsg,  # see comment above
+                    ks=self.ks,
+          ,
+                    columns=self.columns,
+                    csv_options=self.csv_options,
+                    dialect_options=self.dialect_options,
+                    consistency_level=shell.consistency_level,
+                    connect_timeout=shell.conn.connect_timeout,
+                    hostname=shell.hostname,
+                    port=shell.port,
+                    ssl=shell.ssl,
+                    auth_provider=shell.auth_provider,
+                    cql_version=shell.conn.cql_version,
+                    config_file=self.config_file,
+                    protocol_version=self.protocol_version,
+                    debug=shell.debug
+                    )
+class ExportTask(CopyTask):
+    """
+    A class that exports data to .csv by instantiating one or more processes 
that work in parallel (ExportProcess).
+    """
     def run(self):
         Initiates the export by creating the processes.
@@ -125,25 +188,18 @@ class ExportTask(object):
         ranges = self.get_ranges()
         num_processes = get_num_processes(cap=min(16, len(ranges)))
+        params = self.make_params()
-        inmsg = mp.Queue()
-        outmsg = mp.Queue()
-        processes = []
         for i in xrange(num_processes):
-            process = ExportProcess(outmsg, inmsg, self.ks,, 
self.columns, self.dialect_options,
-                                    self.csv_options, shell.debug, shell.port, 
-                                    shell.auth_provider, shell.ssl, 
self.protocol_version, self.config_file)
+            self.processes.append(ExportProcess(params))
+        for process in self.processes:
-            processes.append(process)
-            return self.check_processes(csvdest, ranges, inmsg, outmsg, 
+            return self.check_processes(csvdest, ranges)
-            for process in processes:
-                process.terminate()
-            inmsg.close()
-            outmsg.close()
+            self.close()
             if do_close:
@@ -183,9 +239,9 @@ class ExportTask(object):
             hosts = []
             for host in replicas:
-                if host.datacenter == local_dc:
+                if host.is_up and host.datacenter == local_dc:
-            if len(hosts) == 0:
+            if not hosts:
                 hosts.append(hostname)  # fallback to default host if no 
replicas in current dc
             ranges[(previous, token.value)] = make_range(hosts)
             previous_previous = previous
@@ -194,7 +250,7 @@ class ExportTask(object):
         #  If the ring is empty we get the entire ring from the
         #  host we are currently connected to, otherwise for the last ring 
         #  we query the same replicas that hold the last token in the ring
-        if len(ranges) == 0:
+        if not ranges:
             ranges[(None, None)] = make_range([hostname])
             ranges[(previous, None)] = ranges[(previous_previous, 
@@ -217,32 +273,32 @@ class ExportTask(object):
             return None
-    @staticmethod
-    def send_work(ranges, tokens_to_send, queue):
+    def send_work(self, ranges, tokens_to_send):
         for token_range in tokens_to_send:
-            queue.put((token_range, ranges[token_range]))
+            self.outmsg.put((token_range, ranges[token_range]))
             ranges[token_range]['attempts'] += 1
-    def check_processes(self, csvdest, ranges, inmsg, outmsg, processes):
+    def check_processes(self, csvdest, ranges):
         Here we monitor all child processes by collecting their results
         or any errors. We terminate when we have processed all the ranges or 
when there
         are no more processes.
         shell =
-        meter = RateMeter(10000)
-        total_jobs = len(ranges)
+        processes = self.processes
+        meter = RateMeter(update_interval=self.csv_options['reportfrequency'])
+        total_requests = len(ranges)
         max_attempts = self.csv_options['maxattempts']
-        self.send_work(ranges, ranges.keys(), outmsg)
+        self.send_work(ranges, ranges.keys())
         num_processes = len(processes)
         succeeded = 0
         failed = 0
-        while (failed + succeeded) < total_jobs and 
self.num_live_processes(processes) == num_processes:
+        while (failed + succeeded) < total_requests and 
self.num_live_processes() == num_processes:
-                token_range, result = inmsg.get(timeout=1.0)
-                if token_range is None and result is None:  # a job has 
+                token_range, result = self.inmsg.get(timeout=1.0)
+                if token_range is None and result is None:  # a request has 
                     succeeded += 1
                 elif isinstance(result, Exception):  # an error occurred
                     if token_range is None:  # the entire process failed
@@ -253,7 +309,7 @@ class ExportTask(object):
                         if ranges[token_range]['attempts'] < max_attempts and 
ranges[token_range]['rows'] == 0:
                             shell.printerr('Error for %s: %s (will try again 
later attempt %d of %d)'
                                            % (token_range, result, 
ranges[token_range]['attempts'], max_attempts))
-                            self.send_work(ranges, [token_range], outmsg)
+                            self.send_work(ranges, [token_range])
                             shell.printerr('Error for %s: %s (permanently 
given up after %d rows and %d attempts)'
                                            % (token_range, result, 
@@ -267,34 +323,257 @@ class ExportTask(object):
             except Queue.Empty:
-        if self.num_live_processes(processes) < len(processes):
+        if self.num_live_processes() < len(processes):
             for process in processes:
                 if not process.is_alive():
                     shell.printerr('Child process %d died with exit code %d' % 
(, process.exitcode))
-        if succeeded < total_jobs:
+        if succeeded < total_requests:
             shell.printerr('Exported %d ranges out of %d total ranges, some 
records might be missing'
-                           % (succeeded, total_jobs))
+                           % (succeeded, total_requests))
         return meter.get_total_records()
+class ImportReader(object):
+    """
+    A wrapper around a csv reader to keep track of when we have
+    exhausted reading input records.
+    """
+    def __init__(self, linesource, chunksize, dialect_options):
+        self.linesource = linesource
+        self.chunksize = chunksize
+        self.reader = csv.reader(linesource, **dialect_options)
+        self.exhausted = False
+    def read_rows(self):
+        if self.exhausted:
+            return []
+        rows = list(next(self.reader) for _ in xrange(self.chunksize))
+        self.exhausted = len(rows) < self.chunksize
+        return rows
+class ImportTask(CopyTask):
+    """
+    A class to import data from .csv by instantiating one or more processes
+    that work in parallel (ImportProcess).
+    """
+    def __init__(self, shell, ks, cf, columns, fname, csv_options, 
dialect_options, protocol_version, config_file):
+        CopyTask.__init__(self, shell, ks, cf, columns, fname,
+                          csv_options, dialect_options, protocol_version, 
+        self.num_processes = get_num_processes(cap=4)
+        self.chunk_size = csv_options['chunksize']
+        self.ingest_rate = csv_options['ingestrate']
+        self.max_attempts = csv_options['maxattempts']
+        self.header = self.csv_options['header']
+        self.table_meta =,
+        self.batch_id = 0
+        self.receive_meter = 
+        self.send_meter = RateMeter(update_interval=1, log=False)
+        self.retries = deque([])
+        self.failed = 0
+        self.succeeded = 0
+        self.sent = 0
+    def run(self):
+        shell =
+        if self.fname is None:
+            do_close = False
+            print "[Use \. on a line by itself to end input]"
+            linesource = shell.use_stdin_reader(prompt='[copy] ', until=r'\.')
+        else:
+            do_close = True
+            try:
+                linesource = open(self.fname, 'rb')
+            except IOError, e:
+                shell.printerr("Can't open %r for reading: %s" % (self.fname, 
+                return 0
+        try:
+            if self.header:
+            reader = ImportReader(linesource, self.chunk_size, 
+            params = self.make_params()
+            for i in range(self.num_processes):
+                self.processes.append(ImportProcess(params))
+            for process in self.processes:
+                process.start()
+            return self.process_records(reader)
+        except Exception, exc:
+            shell.printerr(str(exc))
+            if shell.debug:
+                traceback.print_exc()
+            return 0
+        finally:
+            self.close()
+            if do_close:
+                linesource.close()
+            elif shell.tty:
+                print
+    def process_records(self, reader):
+        """
+        Keep on running until we have stuff to receive or send and until all 
processes are running.
+        Send data (batches or retries) up to the max ingest rate. If we are 
waiting for stuff to
+        receive check the incoming queue.
+        """
+        while (self.has_more_to_send(reader) or self.has_more_to_receive()) 
and self.all_processes_running():
+            if self.has_more_to_send(reader):
+                if self.send_meter.current_record <= self.ingest_rate:
+                    self.send_batches(reader)
+                else:
+                    self.send_meter.maybe_update()
+            if self.has_more_to_receive():
+                self.receive()
+        if self.succeeded < self.sent:
+  "Failed to process %d batches" % (self.sent - 
+        return self.receive_meter.get_total_records()
+    def has_more_to_receive(self):
+        return (self.succeeded + self.failed) < self.sent
+    def has_more_to_send(self, reader):
+        return (not reader.exhausted) or self.retries
+    def all_processes_running(self):
+        return self.num_live_processes() == self.num_processes
+    def receive(self):
+        shell =
+        start_time = time.time()
+        while time.time() - start_time < 0.01:  # 10 millis
+            try:
+                batch, err = self.inmsg.get(timeout=0.001)  # 1 millisecond
+                if err is None:
+                    self.succeeded += batch['imported']
+                    self.receive_meter.increment(batch['imported'])
+                else:
+                    err = str(err)
+                    if err.startswith('ValueError') or 
err.startswith('TypeError') or err.startswith('IndexError') \
+                            or batch['attempts'] >= self.max_attempts:
+                        shell.printerr("Failed to import %d rows: %s -  given 
up after %d attempts"
+                                       % (len(batch['rows']), err, 
+                        self.failed += len(batch['rows'])
+                    else:
+                        shell.printerr("Failed to import %d rows: %s -  will 
retry later, attempt %d of %d"
+                                       % (len(batch['rows']), err, 
+                                          self.max_attempts))
+                        self.retries.append(self.reset_batch(batch))
+            except Queue.Empty:
+                break
+    def send_batches(self, reader):
+        """
+        Send batches to the queue until we have exceeded the ingest rate. In 
the export case we queue
+        everything and let the worker processes throttle using max_requests, 
here we throttle
+        in the parent process because of memory usage concerns.
+        When we have finished reading the csv file, then send any retries.
+        """
+        while self.send_meter.current_record <= self.ingest_rate:
+            if not reader.exhausted:
+                rows = reader.read_rows()
+                if rows:
+                    self.sent += self.send_batch(self.new_batch(rows))
+            elif self.retries:
+                batch = self.retries.popleft()
+                self.send_batch(batch)
+            else:
+                break
+    def send_batch(self, batch):
+        batch['attempts'] += 1
+        num_rows = len(batch['rows'])
+        self.send_meter.increment(num_rows)
+        self.outmsg.put(batch)
+        return num_rows
+    def new_batch(self, rows):
+        self.batch_id += 1
+        return self.make_batch(self.batch_id, rows, 0)
+    @staticmethod
+    def reset_batch(batch):
+        batch['imported'] = 0
+        return batch
-    def num_live_processes(processes):
-        return sum(1 for p in processes if p.is_alive())
+    def make_batch(batch_id, rows, attempts):
+        return {'id': batch_id, 'rows': rows, 'attempts': attempts, 
'imported': 0}
+class ChildProcess(mp.Process):
+    """
+    An child worker process, this is for common functionality between 
ImportProcess and ExportProcess.
+    """
+    def __init__(self, params, target):
+        mp.Process.__init__(self, target=target)
+        self.inmsg = params['inmsg']
+        self.outmsg = params['outmsg']
+        self.ks = params['ks']
+ = params['cf']
+        self.columns = params['columns']
+        self.debug = params['debug']
+        self.port = params['port']
+        self.hostname = params['hostname']
+        self.consistency_level = params['consistency_level']
+        self.connect_timeout = params['connect_timeout']
+        self.cql_version = params['cql_version']
+        self.auth_provider = params['auth_provider']
+        self.ssl = params['ssl']
+        self.protocol_version = params['protocol_version']
+        self.config_file = params['config_file']
+        # Here we inject some failures for testing purposes, only if this 
environment variable is set
+        if os.environ.get('CQLSH_COPY_TEST_FAILURES', ''):
+            self.test_failures = 
json.loads(os.environ.get('CQLSH_COPY_TEST_FAILURES', ''))
+        else:
+            self.test_failures = None
+    def printmsg(self, text):
+        if self.debug:
+            sys.stderr.write(text + os.linesep)
+    def close(self):
+        self.printmsg("Closing queues...")
+        self.inmsg.close()
+        self.outmsg.close()
 class ExpBackoffRetryPolicy(RetryPolicy):
-    A retry policy with exponential back-off for read timeouts,
-    see ExportProcess.
+    A retry policy with exponential back-off for read timeouts and write 
-    def __init__(self, export_process):
+    def __init__(self, parent_process):
-        self.max_attempts = export_process.csv_options['maxattempts']
-        self.printmsg = lambda txt: export_process.printmsg(txt)
+        self.max_attempts = parent_process.max_attempts
+        self.printmsg = parent_process.printmsg
     def on_read_timeout(self, query, consistency, required_responses,
                         received_responses, data_retrieved, retry_num):
+        return self._handle_timeout(consistency, retry_num)
+    def on_write_timeout(self, query, consistency, write_type,
+                         required_responses, received_responses, retry_num):
+        return self._handle_timeout(consistency, retry_num)
+    def _handle_timeout(self, consistency, retry_num):
         delay = self.backoff(retry_num)
         if delay > 0:
             self.printmsg("Timeout received, retrying after %d seconds" % 
@@ -327,7 +606,7 @@ class ExpBackoffRetryPolicy(RetryPolicy):
 class ExportSession(object):
     A class for connecting to a cluster and storing the number
-    of jobs that this connection is processing. It wraps the methods
+    of requests that this connection is processing. It wraps the methods
     for executing a query asynchronously and for shutting down the
     connection to the cluster.
@@ -342,20 +621,20 @@ class ExportSession(object):
         self.cluster = cluster
         self.session = session
- = 1
+        self.requests = 1
         self.lock = Lock()
-    def add_job(self):
+    def add_request(self):
         with self.lock:
-   += 1
+            self.requests += 1
-    def complete_job(self):
+    def complete_request(self):
         with self.lock:
-   -= 1
+            self.requests -= 1
-    def num_jobs(self):
+    def num_requests(self):
         with self.lock:
-            return
+            return self.requests
     def execute_async(self, query):
         return self.session.execute_async(query)
@@ -364,48 +643,26 @@ class ExportSession(object):
-class ExportProcess(mp.Process):
+class ExportProcess(ChildProcess):
     An child worker process for the export task, ExportTask.
-    def __init__(self, inmsg, outmsg, ks, cf, columns, dialect_options, 
-                 debug, port, cql_version, auth_provider, ssl, 
protocol_version, config_file):
-        mp.Process.__init__(self,
-        self.inmsg = inmsg
-        self.outmsg = outmsg
-        self.ks = ks
- = cf
-        self.columns = columns
-        self.dialect_options = dialect_options
+    def __init__(self, params):
+        ChildProcess.__init__(self, params=params,
+        self.dialect_options = params['dialect_options']
         self.hosts_to_sessions = dict()
-        self.debug = debug
-        self.port = port
-        self.cql_version = cql_version
-        self.auth_provider = auth_provider
-        self.ssl = ssl
-        self.protocol_version = protocol_version
-        self.config_file = config_file
+        csv_options = params['csv_options']
         self.encoding = csv_options['encoding']
         self.time_format = csv_options['dtformats']
         self.float_precision = csv_options['float_precision']
         self.nullval = csv_options['nullval']
-        self.maxjobs = csv_options['jobs']
+        self.max_attempts = csv_options['maxattempts']
+        self.max_requests = csv_options['maxrequests']
         self.csv_options = csv_options
         self.formatters = dict()
-        # Here we inject some failures for testing purposes, only if this 
environment variable is set
-        if os.environ.get('CQLSH_COPY_TEST_FAILURES', ''):
-            self.test_failures = 
json.loads(os.environ.get('CQLSH_COPY_TEST_FAILURES', ''))
-        else:
-            self.test_failures = None
-    def printmsg(self, text):
-        if self.debug:
-            sys.stderr.write(text + os.linesep)
     def run(self):
@@ -423,12 +680,12 @@ class ExportProcess(mp.Process):
         We terminate when the inbound queue is closed.
         while True:
-            if self.num_jobs() > self.maxjobs:
+            if self.num_requests() > self.max_requests:
                 time.sleep(0.001)  # 1 millisecond
             token_range, info = self.inmsg.get()
-            self.start_job(token_range, info)
+            self.start_request(token_range, info)
     def report_error(self, err, token_range=None):
         if isinstance(err, str):
@@ -443,7 +700,7 @@ class ExportProcess(mp.Process):
         self.outmsg.put((token_range, Exception(msg)))
-    def start_job(self, token_range, info):
+    def start_request(self, token_range, info):
         Begin querying a range by executing an async query that
         will later on invoke the callbacks attached in attach_callbacks.
@@ -454,14 +711,14 @@ class ExportProcess(mp.Process):
         future = session.execute_async(query)
         self.attach_callbacks(token_range, future, session)
-    def num_jobs(self):
-        return sum(session.num_jobs() for session in 
+    def num_requests(self):
+        return sum(session.num_requests() for session in 
     def get_session(self, hosts):
         We select a host to connect to. If we have no connections to one of 
the hosts
         yet then we select this host, else we pick the one with the smallest 
-        of jobs.
+        of requests.
         :return: An ExportSession connected to the chosen host.
@@ -474,19 +731,18 @@ class ExportProcess(mp.Process):
-                ssl_options=sslhandling.ssl_settings(host, self.config_file) 
if self.ssl else None,
+                ssl_options=ssl_settings(host, self.config_file) if self.ssl 
else None,
-                compression=None,
-                executor_threads=max(2, self.csv_options['jobs'] / 2))
+                compression=None)
             session = ExportSession(new_cluster, self)
             self.hosts_to_sessions[host] = session
             return session
-            host = min(hosts, key=lambda h: self.hosts_to_sessions[h].jobs)
+            host = min(hosts, key=lambda h: self.hosts_to_sessions[h].requests)
             session = self.hosts_to_sessions[host]
-            session.add_job()
+            session.add_request()
             return session
     def attach_callbacks(self, token_range, future, session):
@@ -497,16 +753,16 @@ class ExportProcess(mp.Process):
                 self.write_rows_to_csv(token_range, rows)
                 self.outmsg.put((None, None))
-                session.complete_job()
+                session.complete_request()
         def err_callback(err):
             self.report_error(err, token_range)
-            session.complete_job()
+            session.complete_request()
         future.add_callbacks(callback=result_callback, errback=err_callback)
     def write_rows_to_csv(self, token_range, rows):
-        if len(rows) == 0:
+        if not rows:
             return  # no rows in this range
@@ -537,12 +793,9 @@ class ExportProcess(mp.Process):
nullval=self.nullval, quote=False)
     def close(self):
-        self.printmsg("Export process terminating...")
-        self.inmsg.close()
-        self.outmsg.close()
+        ChildProcess.close(self)
         for session in self.hosts_to_sessions.values():
-        self.printmsg("Export process terminated")
     def prepare_query(self, partition_key, token_range, attempts):
@@ -598,26 +851,439 @@ class ExportProcess(mp.Process):
         return query
+class ImportConversion(object):
+    """
+    A class for converting strings to values when importing from csv, used by 
+    the parent.
+    """
+    def __init__(self, parent, table_meta, statement):
+        self.ks = parent.ks
+ =
+        self.columns = parent.columns
+        self.nullval = parent.nullval
+        self.printmsg = parent.printmsg
+        self.table_meta = table_meta
+        self.primary_key_indexes = [self.columns.index( for col in 
+        self.partition_key_indexes = [self.columns.index( for col in 
+        self.proto_version = statement.protocol_version
+        self.cqltypes = dict([(, c.type) for c in 
+        self.converters = dict([(, self._get_converter(c.type)) for c in 
+    def _get_converter(self, cql_type):
+        """
+        Return a function that converts a string into a value the can be passed
+        into BoundStatement.bind() for the given cql type. See 
+        for more details.
+        """
+        def unprotect(v):
+            if v is not None:
+                return CqlRuleSet.dequote_value(v)
+        def convert(t, v):
+            return converters.get(t.typename, convert_unknown)(unprotect(v), 
+        def split(val, sep=','):
+            """
+            Split into a list of values whenever we encounter a separator but
+            ignore separators inside parentheses or single quotes, except for 
the two
+            outermost parentheses, which will be ignored. We expect val to be 
at least
+            2 characters long (the two outer parentheses).
+            """
+            ret = []
+            last = 1
+            level = 0
+            quote = False
+            for i, c in enumerate(val):
+                if c == '{' or c == '[' or c == '(':
+                    level += 1
+                elif c == '}' or c == ']' or c == ')':
+                    level -= 1
+                elif c == '\'':
+                    quote = not quote
+                elif c == sep and level == 1 and not quote:
+                    ret.append(val[last:i])
+                    last = i + 1
+            else:
+                if last < len(val) - 1:
+                    ret.append(val[last:-1])
+            return ret
+        # this should match all possible CQL datetime formats
+        p = re.compile("(\d{4})\-(\d{2})\-(\d{2})\s?(?:'T')?" +  # 
YYYY-MM-DD[( |'T')]
+                       "(?:(\d{2}):(\d{2})(?::(\d{2}))?)?" +  # [HH:MM[:SS]]
+                       "(?:([+\-])(\d{2}):?(\d{2}))?")  # [(+|-)HH[:]MM]]
+        def convert_date(val, **_):
+            m = p.match(val)
+            if not m:
+                raise ValueError("can't interpret %r as a date" % (val,))
+            #
+            tval = time.struct_time((int(, int(, 
int(,  # year, month, day
+                                     int( if else 0,  # 
+                                     int( if else 0,  # 
+                                     int( if else 0,  # 
+                                     0, 1, -1))  # day of week, day of year, 
+            if
+                offset = (int( * 3600 + int( * 60) * 
int( + '1')
+            else:
+                offset = -time.timezone
+            # scale seconds to millis for the raw value
+            return (timegm(tval) + offset) * 1e3
+        def convert_tuple(val, ct=cql_type):
+            return tuple(convert(t, v) for t, v in zip(ct.subtypes, 
+        def convert_list(val, ct=cql_type):
+            return list(convert(ct.subtypes[0], v) for v in split(val))
+        def convert_set(val, ct=cql_type):
+            return frozenset(convert(ct.subtypes[0], v) for v in split(val))
+        def convert_map(val, ct=cql_type):
+            """
+            We need to pass to BoundStatement.bind() a dict() because it calls 
+            except we can't create a dict with another dict as the key, hence 
we use a class
+            that adds iteritems to a frozen set of tuples (which is how dict 
are normally made
+            immutable in python).
+            """
+            class ImmutableDict(frozenset):
+                iteritems = frozenset.__iter__
+            return ImmutableDict(frozenset((convert(ct.subtypes[0], v[0]), 
convert(ct.subtypes[1], v[1]))
+                                 for v in [split('{%s}' % vv, sep=':') for vv 
in split(val)]))
+        def convert_user_type(val, ct=cql_type):
+            """
+            A user type is a dictionary except that we must convert each key 
+            an attribute, so we are using named tuples. It must also be 
+            so we cannot use dictionaries. Maybe there is a way to instantiate 
+            directly but I could not work it out.
+            """
+            vals = [v for v in [split('{%s}' % vv, sep=':') for vv in 
+            ret_type = namedtuple(ct.typename, [unprotect(v[0]) for v in vals])
+            return ret_type(*tuple(convert(t, v[1]) for t, v in 
zip(ct.subtypes, vals)))
+        def convert_single_subtype(val, ct=cql_type):
+            return converters.get(ct.subtypes[0].typename, 
convert_unknown)(val, ct=ct.subtypes[0])
+        def convert_unknown(val, ct=cql_type):
+            if issubclass(ct, UserType):
+                return convert_user_type(val, ct=ct)
+            elif issubclass(ct, ReversedType):
+                return convert_single_subtype(val, ct=ct)
+            self.printmsg("Unknown type %s (%s) for val %s" % (ct, 
ct.typename, val))
+            return val
+        converters = {
+            'blob': (lambda v, ct=cql_type: bytearray.fromhex(v[2:])),
+            'decimal': (lambda v, ct=cql_type: Decimal(v)),
+            'uuid': (lambda v, ct=cql_type: UUID(v)),
+            'boolean': (lambda v, ct=cql_type: bool(v)),
+            'tinyint': (lambda v, ct=cql_type: int(v)),
+            'ascii': (lambda v, ct=cql_type: v),
+            'float': (lambda v, ct=cql_type: float(v)),
+            'double': (lambda v, ct=cql_type: float(v)),
+            'bigint': (lambda v, ct=cql_type: long(v)),
+            'int': (lambda v, ct=cql_type: int(v)),
+            'varint': (lambda v, ct=cql_type: int(v)),
+            'inet': (lambda v, ct=cql_type: v),
+            'counter': (lambda v, ct=cql_type: long(v)),
+            'timestamp': convert_date,
+            'timeuuid': (lambda v, ct=cql_type: UUID(v)),
+            'date': (lambda v, ct=cql_type: Date(v)),
+            'smallint': (lambda v, ct=cql_type: int(v)),
+            'time': (lambda v, ct=cql_type: Time(v)),
+            'text': (lambda v, ct=cql_type: v),
+            'varchar': (lambda v, ct=cql_type: v),
+            'list': convert_list,
+            'set': convert_set,
+            'map': convert_map,
+            'tuple': convert_tuple,
+            'frozen': convert_single_subtype,
+        }
+        return converters.get(cql_type.typename, convert_unknown)
+    def get_row_values(self, row):
+        """
+        Parse the row into a list of row values to be returned
+        """
+        ret = [None] * len(row)
+        for i, val in enumerate(row):
+            if val != self.nullval:
+                ret[i] = self.converters[self.columns[i]](val)
+            else:
+                if i in self.primary_key_indexes:
+                    message = "Cannot insert null value for primary key column 
'%s'." % (self.columns[i],)
+                    if self.nullval == '':
+                        message += " If you want to insert empty strings, 
consider using" \
+                                   " the WITH NULL=<marker> option for COPY."
+                    raise Exception(message=message)
+                ret[i] = None
+        return ret
+    def get_row_partition_key_values(self, row):
+        """
+        Return a string composed of the partition key values, serialized and 
binary packed -
+        as expected by metadata.get_replicas(), see also 
+        """
+        def serialize(n):
+            c, v = self.columns[n], row[n]
+            return self.cqltypes[c].serialize(self.converters[c](v), 
+        partition_key_indexes = self.partition_key_indexes
+        if len(partition_key_indexes) == 1:
+            return serialize(partition_key_indexes[0])
+        else:
+            pk_values = []
+            for i in partition_key_indexes:
+                val = serialize(i)
+                l = len(val)
+                pk_values.append(struct.pack(">H%dsB" % l, l, val, 0))
+            return b"".join(pk_values)
+class ImportProcess(ChildProcess):
+    def __init__(self, params):
+        ChildProcess.__init__(self, params=params,
+        csv_options = params['csv_options']
+        self.nullval = csv_options['nullval']
+        self.max_attempts = csv_options['maxattempts']
+        self.min_batch_size = csv_options['minbatchsize']
+        self.max_batch_size = csv_options['maxbatchsize']
+        self._session = None
+    @property
+    def session(self):
+        if not self._session:
+            cluster = Cluster(
+                contact_points=(self.hostname,),
+                port=self.port,
+                cql_version=self.cql_version,
+                protocol_version=self.protocol_version,
+                auth_provider=self.auth_provider,
+                ssl_options=ssl_settings(self.hostname, self.config_file) if 
self.ssl else None,
+                default_retry_policy=ExpBackoffRetryPolicy(self),
+                compression=None,
+                connect_timeout=self.connect_timeout)
+            self._session = cluster.connect(self.ks)
+            self._session.default_timeout = None
+        return self._session
+    def run(self):
+        try:
+            table_meta = 
+            is_counter = ("counter" in [table_meta.columns[name].typestring 
for name in self.columns])
+            if is_counter:
+                self.run_counter(table_meta)
+            else:
+                self.run_normal(table_meta)
+        except Exception, exc:
+            if self.debug:
+                traceback.print_exc(exc)
+        finally:
+            self.close()
+    def close(self):
+        if self._session:
+            self._session.cluster.shutdown()
+        ChildProcess.close(self)
+    def run_counter(self, table_meta):
+        """
+        Main run method for tables that contain counter columns.
+        """
+        query = 'UPDATE %s.%s SET %%s WHERE %%s' % (protect_name(self.ks), 
+        # We prepare a query statement to find out the types of the partition 
key columns so we can
+        # route the update query to the correct replicas. As far as I 
understood this is the easiest
+        # way to find out the types of the partition columns, we will never 
use this prepared statement
+        where_clause = ' AND '.join(['%s = ?' % (protect_name( for c 
in table_meta.partition_key])
+        select_query = 'SELECT * FROM %s.%s WHERE %s' % 
(protect_name(self.ks), protect_name(, where_clause)
+        conv = ImportConversion(self, table_meta, 
+        while True:
+            try:
+                batch = self.inmsg.get()
+                for batches in self.split_batches(batch, conv):
+                    for b in batches:
+                        self.send_counter_batch(query, conv, b)
+            except Exception, exc:
+                self.outmsg.put((batch, '%s - %s' % (exc.__class__.__name__, 
+                if self.debug:
+                    traceback.print_exc(exc)
+    def run_normal(self, table_meta):
+        """
+        Main run method for normal tables, i.e. tables that do not contain 
counter columns.
+        """
+        query = 'INSERT INTO %s.%s (%s) VALUES (%s)' % (protect_name(self.ks),
+                                                        protect_name(,
+                                                        ', 
+                                                        ', '.join(['?' for _ 
in self.columns]))
+        query_statement = self.session.prepare(query)
+        conv = ImportConversion(self, table_meta, query_statement)
+        while True:
+            try:
+                batch = self.inmsg.get()
+                for batches in self.split_batches(batch, conv):
+                    for b in batches:
+                        self.send_normal_batch(conv, query_statement, b)
+            except Exception, exc:
+                self.outmsg.put((batch, '%s - %s' % (exc.__class__.__name__, 
+                if self.debug:
+                    traceback.print_exc(exc)
+    def send_counter_batch(self, query_text, conv, batch):
+        if self.test_failures and self.maybe_inject_failures(batch):
+            return
+        columns = self.columns
+        batch_statement = BatchStatement(batch_type=BatchType.COUNTER, 
+        for row in batch['rows']:
+            where_clause = []
+            set_clause = []
+            for i, value in enumerate(row):
+                if i in conv.primary_key_indexes:
+                    where_clause.append("%s=%s" % (columns[i], value))
+                else:
+                    set_clause.append("%s=%s+%s" % (columns[i], columns[i], 
+            full_query_text = query_text % (','.join(set_clause), ' AND 
+            batch_statement.add(full_query_text)
+        self.execute_statement(batch_statement, batch)
+    def send_normal_batch(self, conv, query_statement, batch):
+        try:
+            if self.test_failures and self.maybe_inject_failures(batch):
+                return
+            batch_statement = BatchStatement(batch_type=BatchType.UNLOGGED, 
+            for row in batch['rows']:
+                batch_statement.add(query_statement, conv.get_row_values(row))
+            self.execute_statement(batch_statement, batch)
+        except Exception, exc:
+            self.err_callback(exc, batch)
+    def maybe_inject_failures(self, batch):
+        """
+        Examine self.test_failures and see if token_range is either a token 
+        supposed to cause a failure (failing_range) or to terminate the worker 
+        (exit_range). If not then call prepare_export_query(), which 
implements the
+        normal behavior.
+        """
+        if 'failing_batch' in self.test_failures:
+            failing_batch = self.test_failures['failing_batch']
+            if failing_batch['id'] == batch['id']:
+                if batch['attempts'] < failing_batch['failures']:
+                    statement = SimpleStatement("INSERT INTO badtable (a, b) 
VALUES (1, 2)",
+                    self.execute_statement(statement, batch)
+                    return True
+        if 'exit_batch' in self.test_failures:
+            exit_batch = self.test_failures['exit_batch']
+            if exit_batch['id'] == batch['id']:
+                sys.exit(1)
+        return False  # carry on as normal
+    def execute_statement(self, statement, batch):
+        future = self.session.execute_async(statement)
+        future.add_callbacks(callback=self.result_callback, 
callback_args=(batch, ),
+                             errback=self.err_callback, errback_args=(batch, ))
+    def split_batches(self, batch, conv):
+        """
+        Split a batch into sub-batches with the same
+        partition key, if possible. If there are at least
+        batch_size rows with the same partition key value then
+        create a sub-batch with that partition key value, else
+        aggregate all remaining rows in a single 'left-overs' batch
+        """
+        rows_by_pk = defaultdict(list)
+        for row in batch['rows']:
+            pk = conv.get_row_partition_key_values(row)
+            rows_by_pk[pk].append(row)
+        ret = dict()
+        remaining_rows = []
+        for pk, rows in rows_by_pk.items():
+            if len(rows) >= self.min_batch_size:
+                ret[pk] = self.batches(rows, batch)
+            else:
+                remaining_rows.extend(rows)
+        if remaining_rows:
+            ret[self.hostname] = self.batches(remaining_rows, batch)
+        return ret.itervalues()
+    def batches(self, rows, batch):
+        for i in xrange(0, len(rows), self.max_batch_size):
+            yield ImportTask.make_batch(batch['id'], rows[i:i + 
self.max_batch_size], batch['attempts'])
+    def result_callback(self, result, batch):
+        batch['imported'] = len(batch['rows'])
+        batch['rows'] = []  # no need to resend these
+        self.outmsg.put((batch, None))
+    def err_callback(self, response, batch):
+        batch['imported'] = len(batch['rows'])
+        self.outmsg.put((batch, '%s - %s' % (response.__class__.__name__, 
+        if self.debug:
+            traceback.print_exc(response)
 class RateMeter(object):
-    def __init__(self, log_threshold):
-        self.log_threshold = log_threshold  # number of records after which we 
-        self.last_checkpoint_time = time.time()  # last time we logged
+    def __init__(self, update_interval=0.25, log=True):
+        self.log = log  # true if we should log
+        self.update_interval = update_interval  # how often we update in 
+        self.start_time = time.time()  # the start time
+        self.last_checkpoint_time = self.start_time  # last time we logged
         self.current_rate = 0.0  # rows per second
-        self.current_record = 0  # number of records since we last logged
+        self.current_record = 0  # number of records since we last updated
         self.total_records = 0   # total number of records
     def increment(self, n=1):
         self.current_record += n
+        self.maybe_update()
-        if self.current_record >= self.log_threshold:
-            self.update()
-            self.log()
-    def update(self):
+    def maybe_update(self):
         new_checkpoint_time = time.time()
+        if new_checkpoint_time - self.last_checkpoint_time >= 
+            self.update(new_checkpoint_time)
+            self.log_message()
+    def update(self, new_checkpoint_time):
         time_difference = new_checkpoint_time - self.last_checkpoint_time
-        if time_difference != 0.0:
+        if time_difference >= 1e-09:
             self.current_rate = self.get_new_rate(self.current_record / 
         self.last_checkpoint_time = new_checkpoint_time
@@ -626,19 +1292,29 @@ class RateMeter(object):
     def get_new_rate(self, new_rate):
-         return the previous rate averaged with the new rate to smooth a bit
+         return the rate of the last period: this is the new rate but
+         averaged with the last rate to smooth a bit
         if self.current_rate == 0.0:
             return new_rate
             return (self.current_rate + new_rate) / 2.0
-    def log(self):
-        output = 'Processed %d rows; Written: %f rows/s\r' % 
(self.total_records, self.current_rate,)
-        sys.stdout.write(output)
-        sys.stdout.flush()
+    def get_avg_rate(self):
+        """
+         return the average rate since we started measuring
+        """
+        time_difference = time.time() - self.start_time
+        return self.total_records / time_difference if time_difference >= 
1e-09 else 0
+    def log_message(self):
+        if self.log:
+            output = 'Processed: %d rows; Rate: %7.0f rows/s; Avg. rage: %7.0f 
rows/s\r' % \
+                     (self.total_records, self.current_rate, 
+            sys.stdout.write(output)
+            sys.stdout.flush()
     def get_total_records(self):
-        self.update()
-        self.log()
+        self.update(time.time())
+        self.log_message()
         return self.total_records
diff --git a/pylib/cqlshlib/ b/pylib/cqlshlib/
index 4d6cf8a..281aad6 100644
--- a/pylib/cqlshlib/
+++ b/pylib/cqlshlib/
@@ -14,9 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import cProfile
 import codecs
+import pstats
 from itertools import izip
 from datetime import timedelta, tzinfo
+from StringIO import StringIO
 ZERO = timedelta(0)
@@ -122,3 +127,17 @@ def get_file_encoding_bomsize(filename):
         file_encoding, size = "utf-8", 0
     return (file_encoding, size)
+def profile_on():
+    pr = cProfile.Profile()
+    pr.enable()
+    return pr
+def profile_off(pr):
+    pr.disable()
+    s = StringIO()
+    ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
+    ps.print_stats()
+    print s.getvalue()

Reply via email to