Repository: cassandra
Updated Branches:
  refs/heads/trunk 61c8ff4db -> 237d0bc62


Improve performance of cqlsh COPY FROM

Patch by Tyler Hobbs; reviewed by Aleksey Yeschenko for CASSANDRA-8225


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/7110904e
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/7110904e
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/7110904e

Branch: refs/heads/trunk
Commit: 7110904e40dd2b9025b0189496afa7c7a30b646c
Parents: 5cf61e7
Author: Tyler Hobbs <tylerho...@apache.org>
Authored: Fri Mar 20 10:26:47 2015 -0500
Committer: Tyler Hobbs <tylerho...@apache.org>
Committed: Fri Mar 20 10:26:47 2015 -0500

----------------------------------------------------------------------
 CHANGES.txt                    |   1 +
 bin/cqlsh                      | 290 +++++++++++++++++++++++++++++-------
 pylib/cqlshlib/async_insert.py | 115 --------------
 pylib/cqlshlib/meter.py        |  59 --------
 4 files changed, 237 insertions(+), 228 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/7110904e/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 14a45a3..a142999 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 2.1.4
+ * (cqlsh) Greatly improve performance of COPY FROM (CASSANDRA-8225)
  * IndexSummary effectiveIndexInterval is now a guideline, not a rule (CASSANDRA-8993)
  * Use correct bounds for page cache eviction of compressed files (CASSANDRA-8746)
  * SSTableScanner enforces its bounds (CASSANDRA-8946)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/7110904e/bin/cqlsh
----------------------------------------------------------------------
diff --git a/bin/cqlsh b/bin/cqlsh
index 3ec9457..fdf6ce1 100755
--- a/bin/cqlsh
+++ b/bin/cqlsh
@@ -51,6 +51,8 @@ import platform
 import warnings
 import csv
 import getpass
+from functools import partial
+import traceback
 
 
 readline = None
@@ -108,6 +110,8 @@ except ImportError, e:
 from cassandra.cluster import Cluster, PagedResult
 from cassandra.query import SimpleStatement, ordered_dict_factory
 from cassandra.policies import WhiteListRoundRobinPolicy
+from cassandra.protocol import QueryMessage, ResultMessage
+from cassandra.marshal import int16_pack, int32_pack, uint64_pack
 from cassandra.metadata import protect_name, protect_names, protect_value
 from cassandra.auth import PlainTextAuthProvider
 
@@ -117,7 +121,7 @@ cqlshlibdir = os.path.join(CASSANDRA_PATH, 'pylib')
 if os.path.isdir(cqlshlibdir):
     sys.path.insert(0, cqlshlibdir)
 
-from cqlshlib import cqlhandling, cql3handling, pylexotron, sslhandling, async_insert, meter
+from cqlshlib import cqlhandling, cql3handling, pylexotron, sslhandling
 from cqlshlib.displaying import (RED, BLUE, CYAN, ANSI_RESET, COLUMN_NAME_COLORS,
                                  FormattedValue, colorme)
 from cqlshlib.formatting import format_by_type, formatter_for, format_value_utype
@@ -550,6 +554,7 @@ class Shell(cmd.Cmd):
             self.auth_provider = PlainTextAuthProvider(username=username, password=password)
         self.username = username
         self.keyspace = keyspace
+        self.ssl = ssl
         self.tracing_enabled = tracing_enabled
         self.expand_enabled = expand_enabled
         if use_conn:
@@ -913,7 +918,6 @@ class Shell(cmd.Cmd):
                 self.handle_statement(st, statementtext)
             except Exception, e:
                 if self.debug:
-                    import traceback
                     traceback.print_exc()
                 else:
                     self.printerr(e)
@@ -1435,73 +1439,251 @@ class Shell(cmd.Cmd):
             except IOError, e:
                 self.printerr("Can't open %r for reading: %s" % (fname, e))
                 return 0
+
+        current_record = None
         try:
             if header:
                 linesource.next()
             table_meta = self.get_table_meta(ks, cf)
             reader = csv.reader(linesource, **dialect_options)
-            from functools import partial
-            rownum, error = \
-                async_insert.insert_concurrent(self.session, enumerate(reader, start=1),
-                                               partial(
-                                                   self.create_insert_statement,
-                                                   columns, nullval,
-                                                   table_meta))
-            if error:
-                self.printerr(str(error[0]))
-                self.printerr("Aborting import at record #%d. "
-                              "Previously-inserted values still present."
-                               % error[1])
+
+            from multiprocessing import Process, Pipe, cpu_count
+
+            # Pick a resonable number of child processes. We need to leave at
+            # least one core for the parent process.  This doesn't necessarily
+            # need to be capped at 4, but it's currently enough to keep
+            # a single local Cassandra node busy, and I see lower throughput
+            # with more processes.
+            try:
+                num_processes = max(1, min(4, cpu_count() - 1))
+            except NotImplementedError:
+                num_processes = 1
+
+            processes, pipes = [], [],
+            for i in range(num_processes):
+                parent_conn, child_conn = Pipe()
+                pipes.append(parent_conn)
+                processes.append(Process(target=self.multiproc_import, args=(child_conn, ks, cf, columns, nullval)))
+
+            for process in processes:
+                process.start()
+
+            last_checkpoint_time = time.time()
+            current_rate = 0.0
+            for current_record, row in enumerate(reader, start=1):
+                # write to the child process
+                pipes[current_record % num_processes].send((current_record, row))
+
+                # update the progress and current rate periodically
+                if (current_record % 10000) == 0:
+                    new_checkpoint_time = time.time()
+                    new_rate = 10000.0 / (new_checkpoint_time - last_checkpoint_time)
+                    last_checkpoint_time = new_checkpoint_time
+
+                    # smooth the rate a bit
+                    if current_rate == 0.0:
+                        current_rate = new_rate
+                    else:
+                        current_rate = (current_rate + new_rate) / 2.0
+
+                    output = 'Processed %s rows; Write: %.2f rows/s\r' % \
+                             (current_record, current_rate)
+                    sys.stdout.write(output)
+                    sys.stdout.flush()
+
+                # check for any errors reported by the children
+                if (current_record % 100) == 0:
+                    if self._check_child_pipes(current_record, pipes):
+                        # no errors seen, continue with outer loop
+                        continue
+                    else:
+                        # errors seen, break out of outer loop
+                        break
+        except Exception, exc:
+            if current_record is None:
+                # we failed before we started
+                self.printerr("\nError starting import process:\n")
+                self.printerr(str(exc))
+                if self.debug:
+                    traceback.print_exc()
+            else:
+                self.printerr("\n" + str(exc))
+                self.printerr("\nAborting import at record #%d. "
+                              "Previously inserted records and some records after "
+                              "this number may be present."
+                              % (current_record,))
+                if self.debug:
+                    traceback.print_exc()
         finally:
+            # send a message that indicates we're done
+            for pipe in pipes:
+                pipe.send((None, None))
+
+            for process in processes:
+                process.join()
+
+            self._check_child_pipes(current_record, pipes)
+
+            for pipe in pipes:
+                pipe.close()
+
             if do_close:
                 linesource.close()
             elif self.tty:
                 print
-        return rownum
 
-    def create_insert_statement(self, columns, nullval, table_meta, row):
+        return current_record
 
-        if len(row) != len(columns):
-            raise ValueError(
-                "Record has the wrong number of fields (%d instead of %d)."
-                % (len(row), len(columns)))
+    def _check_child_pipes(self, current_record, pipes):
+        # check the pipes for errors from child processes
+        for pipe in pipes:
+            if pipe.poll():
+                try:
+                    (record_num, error) = pipe.recv()
+                    self.printerr("\n" + str(error))
+                    self.printerr(
+                        "Aborting import at record #%d. "
+                        "Previously inserted records are still present, "
+                        "and some records after that may be present as well."
+                        % (record_num,))
+                    return False
+                except EOFError:
+                    # pipe is closed, nothing to read
+                    self.printerr("\nChild process died without notification, "
+                                  "aborting import at record #%d. Previously "
+                                  "inserted records are probably still present, "
+                                  "and some records after that may be present "
+                                  "as well." % (current_record,))
+                    return False
+        return True
 
-        rowmap = {}
-        primary_key_columns = [col.name for col in table_meta.primary_key]
-        for name, value in zip(columns, row):
-            type = table_meta.columns[name].data_type
-            cqltype = table_meta.columns[name].typestring
+    def multiproc_import(self, pipe, ks, cf, columns, nullval):
+        """
+        This method is where child processes start when doing a COPY FROM
+        operation.  The child process will open one connection to the node and
+        interact directly with the connection, bypassing most of the driver
+        code.  Because we don't need retries, connection pooling, thread safety,
+        and other fancy features, this is okay.
+        """
 
-            if value != nullval:
-                if cqltype in ('ascii', 'text', 'timestamp', 'date', 'time', 'inet'):
-                    rowmap[name] = protect_value(value)
-                else:
-                    rowmap[name] = value
-            elif name in primary_key_columns:
-                # By default, nullval is an empty string. See CASSANDRA-7792 for details.
-                message = "Cannot insert null value for primary key column '%s'." % (name,)
-                if nullval == '':
-                    message += " If you want to insert empty strings, consider using " \
-                               "the WITH NULL=<marker> option for COPY."
-                self.printerr(message)
-                return False
-            else:
-                rowmap[name] = 'null'
-        # would be nice to be able to use a prepared query here, but in order
-        # to use that interface, we'd need to have all the input as native
-        # values already, reading them from text just like the various
-        # Cassandra cql types do. Better just to submit them all as intact
-        # CQL string literals and let Cassandra do its thing.
-        query = 'INSERT INTO %s.%s (%s) VALUES (%s)' % (
-            protect_name(table_meta.keyspace.name),
-            protect_name(table_meta.name),
-            ', '.join(protect_names(rowmap.keys())),
-            ', '.join(rowmap.values())
-        )
-        if self.debug:
-            print 'Import using CQL: %s' % query
-        return SimpleStatement(query)
+        # open a new connection for this subprocess
+        new_cluster = Cluster(
+                contact_points=(self.hostname,),
+                port=self.port,
+                cql_version=self.conn.cql_version,
+                protocol_version=DEFAULT_PROTOCOL_VERSION,
+                auth_provider=self.auth_provider,
+                ssl_options=sslhandling.ssl_settings(self.hostname, CONFIG_FILE) if self.ssl else None,
+                load_balancing_policy=WhiteListRoundRobinPolicy([self.hostname]),
+                compression=None)
+        session = new_cluster.connect(self.keyspace)
+        conn = session._pools.values()[0]._connection
+
+        # pre-build as much of the query as we can
+        table_meta = self.get_table_meta(ks, cf)
+        pk_cols = [col.name for col in table_meta.primary_key]
+        cqltypes = [table_meta.columns[name].typestring for name in columns]
+        pk_indexes = [columns.index(col.name) for col in table_meta.primary_key]
+        query = 'INSERT INTO %s.%s (%s) VALUES (%%s)' % (
+            protect_name(ks),
+            protect_name(cf),
+            ', '.join(columns))
+
+        # we need to handle some types specially
+        should_escape = [t in ('ascii', 'text', 'timestamp', 'date', 'time', 'inet') for t in cqltypes]
+
+        insert_timestamp = int(time.time() * 1e6)
+
+        def callback(record_num, response):
+            # This is the callback we register for all inserts.  Because this
+            # is run on the event-loop thread, we need to hold a lock when
+            # adjusting in_flight.
+            with conn.lock:
+                conn.in_flight -= 1
+
+            if not isinstance(response, ResultMessage):
+                # It's an error. Notify the parent process and let it send
+                # a stop signal to all child processes (including this one).
+                pipe.send((record_num, str(response)))
+                if isinstance(response, Exception) and self.debug:
+                    traceback.print_exc(response)
+
+        current_record = 0
+        insert_num = 0
+        try:
+            while True:
+                # To avoid totally maxing out the connection,
+                # defer to the reactor thread when we're close
+                # to capacity
+                if conn.in_flight > (conn.max_request_id * 0.9):
+                    conn._readable = True
+                    time.sleep(0.05)
+                    continue
+
+                try:
+                    (current_record, row) = pipe.recv()
+                except EOFError:
+                    # the pipe was closed and there's nothing to receive
+                    sys.stdout.write('Failed to read from pipe:\n\n')
+                    sys.stdout.flush()
+                    conn._writable = True
+                    conn._readable = True
+                    break
+
+                # see if the parent process has signaled that we are done
+                if (current_record, row) == (None, None):
+                    conn._writable = True
+                    conn._readable = True
+                    pipe.close()
+                    break
+
+                # format the values in the row
+                for i, value in enumerate(row):
+                    if value != nullval:
+                        if should_escape[i]:
+                            row[i] = protect_value(value)
+                    elif i in pk_indexes:
+                        # By default, nullval is an empty string. See CASSANDRA-7792 for details.
+                        message = "Cannot insert null value for primary key column '%s'." % (pk_cols[i],)
+                        if nullval == '':
+                            message += " If you want to insert empty strings, consider using " \
+                                       "the WITH NULL=<marker> option for COPY."
+                        pipe.send((current_record, message))
+                        return
+                    else:
+                        row[i] = 'null'
+
+                full_query = query % (','.join(row),)
+                query_message = QueryMessage(
+                    full_query, self.consistency_level, serial_consistency_level=None,
+                    fetch_size=None, paging_state=None, timestamp=insert_timestamp)
+
+                request_id = conn.get_request_id()
+                binary_message = query_message.to_binary(
+                    stream_id=request_id, protocol_version=DEFAULT_PROTOCOL_VERSION, compression=None)
+
+                # add the message directly to the connection's queue
+                with conn.lock:
+                    conn.in_flight += 1
+                conn._callbacks[request_id] = partial(callback, current_record)
+                conn.deque.append(binary_message)
+
+                # every 50 records, clear the pending writes queue and read
+                # any responses we have
+                if insert_num % 50 == 0:
+                    conn._writable = True
+                    conn._readable = True
+
+                insert_num += 1
+        except Exception, exc:
+            pipe.send((current_record, exc))
+        finally:
+            # wait for any pending requests to finish
+            while conn.in_flight > 0:
+                conn._readable = True
+                time.sleep(0.01)
 
+            new_cluster.shutdown()
 
     def perform_csv_export(self, ks, cf, columns, fname, opts):
         dialect_options = self.csv_dialect_defaults.copy()

http://git-wip-us.apache.org/repos/asf/cassandra/blob/7110904e/pylib/cqlshlib/async_insert.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/async_insert.py b/pylib/cqlshlib/async_insert.py
deleted file mode 100644
index d325716..0000000
--- a/pylib/cqlshlib/async_insert.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from threading import Event, Condition
-from . import meter
-import sys
-
-class _CountDownLatch(object):
-    def __init__(self, counter=1):
-        self._count = counter
-        self._lock = Condition()
-
-    def count_down(self):
-        with self._lock:
-            self._count -= 1
-            if self._count <= 0:
-                self._lock.notifyAll()
-
-    def await(self):
-        with self._lock:
-            while self._count > 0:
-                # use a timeout so that the main thread wakes up occasionally
-                # so it can see keyboard interrupts (CASSANDRA-7815)
-                self._lock.wait(0.5)
-
-
-class _ChainedWriter(object):
-
-    CONCURRENCY = 100
-
-    def __init__(self, session, enumerated_reader, statement_func):
-        self._sentinel = object()
-        self._session = session
-        self._cancellation_event = Event()
-        self._first_error = None
-        self._task_counter = _CountDownLatch(self.CONCURRENCY)
-        self._enumerated_reader = enumerated_reader
-        self._statement_func = statement_func
-        self._meter = meter.Meter()
-
-    def insert(self):
-        if not self._enumerated_reader:
-            return 0, None
-
-        for i in xrange(self.CONCURRENCY):
-            self._execute_next(self._sentinel, 0)
-
-        try:
-            self._task_counter.await()
-        except KeyboardInterrupt:
-            self._cancellation_event.set()
-            sys.stdout.write('Aborting due to keyboard interrupt\n')
-            self._task_counter.await()
-        self._meter.done()
-        return self._meter.num_finished(), self._first_error
-
-
-    def _abort(self, error, failed_record):
-        if not self._first_error:
-            self._first_error = error, failed_record
-        self._task_counter.count_down()
-        self._cancellation_event.set()
-
-    def _handle_error(self, error, failed_record):
-        self._abort(error, failed_record)
-
-    def _execute_next(self, result, last_completed_record):
-        if self._cancellation_event.is_set():
-            self._task_counter.count_down()
-            return
-
-        if result is not self._sentinel:
-            self._meter.mark_written()
-
-        try:
-            (current_record, row) = next(self._enumerated_reader)
-        except StopIteration:
-            self._task_counter.count_down()
-            return
-        except Exception as exc:
-            self._abort(exc, last_completed_record)
-            return
-
-        if self._cancellation_event.is_set():
-            self._task_counter.count_down()
-            return
-
-        try:
-            statement = self._statement_func(row)
-            future = self._session.execute_async(statement)
-            future.add_callbacks(callback=self._execute_next,
-                                 callback_args=(current_record,),
-                                 errback=self._handle_error,
-                                 errback_args=(current_record,))
-        except Exception as exc:
-            self._abort(exc, current_record)
-            return
-
-
-def insert_concurrent(session, enumerated_reader, statement_func):
-    return _ChainedWriter(session, enumerated_reader, statement_func).insert()
-

http://git-wip-us.apache.org/repos/asf/cassandra/blob/7110904e/pylib/cqlshlib/meter.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/meter.py b/pylib/cqlshlib/meter.py
deleted file mode 100644
index e1a6bfc..0000000
--- a/pylib/cqlshlib/meter.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from time import time
-import sys
-from threading import RLock
-
-
-class Meter(object):
-
-    def __init__(self):
-        self._num_finished = 0
-        self._last_checkpoint_time = None
-        self._current_rate = 0.0
-        self._lock = RLock()
-
-    def mark_written(self):
-        with self._lock:
-            if not self._last_checkpoint_time:
-                self._last_checkpoint_time = time()
-            self._num_finished += 1
-
-            if self._num_finished % 10000 == 0:
-                previous_checkpoint_time = self._last_checkpoint_time
-                self._last_checkpoint_time = time()
-                new_rate = 10000.0 / (self._last_checkpoint_time - previous_checkpoint_time)
-                if self._current_rate == 0.0:
-                    self._current_rate = new_rate
-                else:
-                    self._current_rate = (self._current_rate + new_rate) / 2.0
-
-            if self._num_finished % 1000 != 0:
-                return
-            output = 'Processed %s rows; Write: %.2f rows/s\r' % \
-                     (self._num_finished, self._current_rate)
-            sys.stdout.write(output)
-            sys.stdout.flush()
-
-    def num_finished(self):
-        with self._lock:
-            return self._num_finished
-
-    def done(self):
-        print ""
-
-

Reply via email to