Repository: cassandra Updated Branches: refs/heads/trunk 61c8ff4db -> 237d0bc62
Improve performance of cqlsh COPY FROM

Patch by Tyler Hobbs; reviewed by Aleksey Yeschenko for CASSANDRA-8225

Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/7110904e
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/7110904e
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/7110904e

Branch: refs/heads/trunk
Commit: 7110904e40dd2b9025b0189496afa7c7a30b646c
Parents: 5cf61e7
Author: Tyler Hobbs <tylerho...@apache.org>
Authored: Fri Mar 20 10:26:47 2015 -0500
Committer: Tyler Hobbs <tylerho...@apache.org>
Committed: Fri Mar 20 10:26:47 2015 -0500

----------------------------------------------------------------------
 CHANGES.txt                    |   1 +
 bin/cqlsh                      | 290 +++++++++++++++++++++++++++++-------
 pylib/cqlshlib/async_insert.py | 115 --------------
 pylib/cqlshlib/meter.py        |  59 --------
 4 files changed, 237 insertions(+), 228 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/7110904e/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 14a45a3..a142999 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 2.1.4
+ * (cqlsh) Greatly improve performance of COPY FROM (CASSANDRA-8225)
 * IndexSummary effectiveIndexInterval is now a guideline, not a rule (CASSANDRA-8993)
 * Use correct bounds for page cache eviction of compressed files (CASSANDRA-8746)
 * SSTableScanner enforces its bounds (CASSANDRA-8946)
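For orientation before the main diff: the patch replaces the old callback-chained async insert machinery (async_insert.py and meter.py, deleted below) with a pool of worker processes, each fed (record_num, row) tuples over a multiprocessing.Pipe and shut down with a (None, None) sentinel. A minimal, self-contained sketch of that fan-out shape, with a toy worker standing in for the real insert path (names and the worker body are illustrative, not the patch's code):

    # Sketch only: the same parent/worker pipe fan-out the patch uses.
    from multiprocessing import Process, Pipe

    def worker(conn):
        while True:
            record_num, row = conn.recv()
            if (record_num, row) == (None, None):
                break  # parent sent the shutdown sentinel
            # the real worker builds an INSERT and writes it to its connection
            print('record %d: %s' % (record_num, row))

    if __name__ == '__main__':
        num_processes = 2
        pipes, processes = [], []
        for _ in range(num_processes):
            parent_conn, child_conn = Pipe()
            pipes.append(parent_conn)
            processes.append(Process(target=worker, args=(child_conn,)))
        for process in processes:
            process.start()
        for record_num, row in enumerate([['a', '1'], ['b', '2']], start=1):
            # round-robin rows across the workers
            pipes[record_num % num_processes].send((record_num, row))
        for pipe in pipes:
            pipe.send((None, None))  # tell every worker to exit
        for process in processes:
            process.join()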
http://git-wip-us.apache.org/repos/asf/cassandra/blob/7110904e/bin/cqlsh
----------------------------------------------------------------------
diff --git a/bin/cqlsh b/bin/cqlsh
index 3ec9457..fdf6ce1 100755
--- a/bin/cqlsh
+++ b/bin/cqlsh
@@ -51,6 +51,8 @@ import platform
 import warnings
 import csv
 import getpass
+from functools import partial
+import traceback
 
 readline = None
 
@@ -108,6 +110,8 @@ except ImportError, e:
 from cassandra.cluster import Cluster, PagedResult
 from cassandra.query import SimpleStatement, ordered_dict_factory
 from cassandra.policies import WhiteListRoundRobinPolicy
+from cassandra.protocol import QueryMessage, ResultMessage
+from cassandra.marshal import int16_pack, int32_pack, uint64_pack
 from cassandra.metadata import protect_name, protect_names, protect_value
 from cassandra.auth import PlainTextAuthProvider
 
@@ -117,7 +121,7 @@ cqlshlibdir = os.path.join(CASSANDRA_PATH, 'pylib')
 if os.path.isdir(cqlshlibdir):
     sys.path.insert(0, cqlshlibdir)
 
-from cqlshlib import cqlhandling, cql3handling, pylexotron, sslhandling, async_insert, meter
+from cqlshlib import cqlhandling, cql3handling, pylexotron, sslhandling
 from cqlshlib.displaying import (RED, BLUE, CYAN, ANSI_RESET,
                                  COLUMN_NAME_COLORS,
                                  FormattedValue, colorme)
 from cqlshlib.formatting import format_by_type, formatter_for, format_value_utype
@@ -550,6 +554,7 @@ class Shell(cmd.Cmd):
             self.auth_provider = PlainTextAuthProvider(username=username, password=password)
         self.username = username
         self.keyspace = keyspace
+        self.ssl = ssl
         self.tracing_enabled = tracing_enabled
         self.expand_enabled = expand_enabled
         if use_conn:
@@ -913,7 +918,6 @@ class Shell(cmd.Cmd):
                 self.handle_statement(st, statementtext)
             except Exception, e:
                 if self.debug:
-                    import traceback
                     traceback.print_exc()
                 else:
                     self.printerr(e)
@@ -1435,73 +1439,251 @@ class Shell(cmd.Cmd):
         except IOError, e:
             self.printerr("Can't open %r for reading: %s" % (fname, e))
             return 0
+
+        current_record = None
         try:
             if header:
                 linesource.next()
             table_meta = self.get_table_meta(ks, cf)
             reader = csv.reader(linesource, **dialect_options)
-            from functools import partial
-            rownum, error = \
-                async_insert.insert_concurrent(self.session, enumerate(reader, start=1),
-                                               partial(
-                                                   self.create_insert_statement,
-                                                   columns, nullval,
-                                                   table_meta))
-            if error:
-                self.printerr(str(error[0]))
-                self.printerr("Aborting import at record #%d. "
-                              "Previously-inserted values still present."
-                              % error[1])
+
+            from multiprocessing import Process, Pipe, cpu_count
+
+            # Pick a reasonable number of child processes. We need to leave at
+            # least one core for the parent process.  This doesn't necessarily
+            # need to be capped at 4, but it's currently enough to keep
+            # a single local Cassandra node busy, and I see lower throughput
+            # with more processes.
+            try:
+                num_processes = max(1, min(4, cpu_count() - 1))
+            except NotImplementedError:
+                num_processes = 1
+
+            processes, pipes = [], []
+            for i in range(num_processes):
+                parent_conn, child_conn = Pipe()
+                pipes.append(parent_conn)
+                processes.append(Process(target=self.multiproc_import,
+                                         args=(child_conn, ks, cf, columns, nullval)))
+
+            for process in processes:
+                process.start()
+
+            last_checkpoint_time = time.time()
+            current_rate = 0.0
+            for current_record, row in enumerate(reader, start=1):
+                # write to the child process
+                pipes[current_record % num_processes].send((current_record, row))
+
+                # update the progress and current rate periodically
+                if (current_record % 10000) == 0:
+                    new_checkpoint_time = time.time()
+                    new_rate = 10000.0 / (new_checkpoint_time - last_checkpoint_time)
+                    last_checkpoint_time = new_checkpoint_time
+
+                    # smooth the rate a bit
+                    if current_rate == 0.0:
+                        current_rate = new_rate
+                    else:
+                        current_rate = (current_rate + new_rate) / 2.0
+
+                    output = 'Processed %s rows; Write: %.2f rows/s\r' % \
+                             (current_record, current_rate)
+                    sys.stdout.write(output)
+                    sys.stdout.flush()
+
+                # check for any errors reported by the children
+                if (current_record % 100) == 0:
+                    if self._check_child_pipes(current_record, pipes):
+                        # no errors seen, continue with outer loop
+                        continue
+                    else:
+                        # errors seen, break out of outer loop
+                        break
+        except Exception, exc:
+            if current_record is None:
+                # we failed before we started
+                self.printerr("\nError starting import process:\n")
+                self.printerr(str(exc))
+                if self.debug:
+                    traceback.print_exc()
+            else:
+                self.printerr("\n" + str(exc))
+                self.printerr("\nAborting import at record #%d. "
+                              "Previously inserted records and some records after "
+                              "this number may be present."
+                              % (current_record,))
+                if self.debug:
+                    traceback.print_exc()
         finally:
+            # send a message that indicates we're done
+            for pipe in pipes:
+                pipe.send((None, None))
+
+            for process in processes:
+                process.join()
+
+            self._check_child_pipes(current_record, pipes)
+
+            for pipe in pipes:
+                pipe.close()
+
             if do_close:
                 linesource.close()
             elif self.tty:
                 print
-        return rownum
 
-    def create_insert_statement(self, columns, nullval, table_meta, row):
+        return current_record
 
-        if len(row) != len(columns):
-            raise ValueError(
-                "Record has the wrong number of fields (%d instead of %d)."
-                % (len(row), len(columns)))
+    def _check_child_pipes(self, current_record, pipes):
+        # check the pipes for errors from child processes
+        for pipe in pipes:
+            if pipe.poll():
+                try:
+                    (record_num, error) = pipe.recv()
+                    self.printerr("\n" + str(error))
+                    self.printerr(
+                        "Aborting import at record #%d. "
+                        "Previously inserted records are still present, "
+                        "and some records after that may be present as well."
+                        % (record_num,))
+                    return False
+                except EOFError:
+                    # pipe is closed, nothing to read
+                    self.printerr("\nChild process died without notification, "
                                  "aborting import at record #%d. Previously "
                                  "inserted records are probably still present, "
                                  "and some records after that may be present "
                                  "as well." % (current_record,))
+                    return False
+        return True
 
-        rowmap = {}
-        primary_key_columns = [col.name for col in table_meta.primary_key]
-        for name, value in zip(columns, row):
-            type = table_meta.columns[name].data_type
-            cqltype = table_meta.columns[name].typestring
+    def multiproc_import(self, pipe, ks, cf, columns, nullval):
+        """
+        This method is where child processes start when doing a COPY FROM
+        operation.  The child process will open one connection to the node and
+        interact directly with the connection, bypassing most of the driver
+        code.  Because we don't need retries, connection pooling, thread safety,
+        and other fancy features, this is okay.
+        """
 
-            if value != nullval:
-                if cqltype in ('ascii', 'text', 'timestamp', 'date', 'time', 'inet'):
-                    rowmap[name] = protect_value(value)
-                else:
-                    rowmap[name] = value
-            elif name in primary_key_columns:
-                # By default, nullval is an empty string. See CASSANDRA-7792 for details.
-                message = "Cannot insert null value for primary key column '%s'." % (name,)
-                if nullval == '':
-                    message += " If you want to insert empty strings, consider using " \
-                               "the WITH NULL=<marker> option for COPY."
-                self.printerr(message)
-                return False
-            else:
-                rowmap[name] = 'null'
-        # would be nice to be able to use a prepared query here, but in order
-        # to use that interface, we'd need to have all the input as native
-        # values already, reading them from text just like the various
-        # Cassandra cql types do.  Better just to submit them all as intact
-        # CQL string literals and let Cassandra do its thing.
-        query = 'INSERT INTO %s.%s (%s) VALUES (%s)' % (
-            protect_name(table_meta.keyspace.name),
-            protect_name(table_meta.name),
-            ', '.join(protect_names(rowmap.keys())),
-            ', '.join(rowmap.values())
-        )
-        if self.debug:
-            print 'Import using CQL: %s' % query
-        return SimpleStatement(query)
+        # open a new connection for this subprocess
+        new_cluster = Cluster(
+            contact_points=(self.hostname,),
+            port=self.port,
+            cql_version=self.conn.cql_version,
+            protocol_version=DEFAULT_PROTOCOL_VERSION,
+            auth_provider=self.auth_provider,
+            ssl_options=sslhandling.ssl_settings(self.hostname, CONFIG_FILE) if self.ssl else None,
+            load_balancing_policy=WhiteListRoundRobinPolicy([self.hostname]),
+            compression=None)
+        session = new_cluster.connect(self.keyspace)
+        conn = session._pools.values()[0]._connection
+
+        # pre-build as much of the query as we can
+        table_meta = self.get_table_meta(ks, cf)
+        pk_cols = [col.name for col in table_meta.primary_key]
+        cqltypes = [table_meta.columns[name].typestring for name in columns]
+        pk_indexes = [columns.index(col.name) for col in table_meta.primary_key]
+        query = 'INSERT INTO %s.%s (%s) VALUES (%%s)' % (
+            protect_name(ks),
+            protect_name(cf),
+            ', '.join(columns))
+
+        # we need to handle some types specially
+        should_escape = [t in ('ascii', 'text', 'timestamp', 'date', 'time', 'inet')
+                         for t in cqltypes]
+
+        insert_timestamp = int(time.time() * 1e6)
+
+        def callback(record_num, response):
+            # This is the callback we register for all inserts.  Because this
+            # is run on the event-loop thread, we need to hold a lock when
+            # adjusting in_flight.
+            with conn.lock:
+                conn.in_flight -= 1
+
+            if not isinstance(response, ResultMessage):
+                # It's an error. Notify the parent process and let it send
+                # a stop signal to all child processes (including this one).
+                pipe.send((record_num, str(response)))
+                if isinstance(response, Exception) and self.debug:
+                    traceback.print_exception(type(response), response, None)
+
+        current_record = 0
+        insert_num = 0
+        try:
+            while True:
+                # To avoid totally maxing out the connection,
+                # defer to the reactor thread when we're close
+                # to capacity
+                if conn.in_flight > (conn.max_request_id * 0.9):
+                    conn._readable = True
+                    time.sleep(0.05)
+                    continue
+
+                try:
+                    (current_record, row) = pipe.recv()
+                except EOFError:
+                    # the pipe was closed and there's nothing to receive
+                    sys.stdout.write('Failed to read from pipe:\n\n')
+                    sys.stdout.flush()
+                    conn._writable = True
+                    conn._readable = True
+                    break
+
+                # see if the parent process has signaled that we are done
+                if (current_record, row) == (None, None):
+                    conn._writable = True
+                    conn._readable = True
+                    pipe.close()
+                    break
+
+                # format the values in the row
+                for i, value in enumerate(row):
+                    if value != nullval:
+                        if should_escape[i]:
+                            row[i] = protect_value(value)
+                    elif i in pk_indexes:
+                        # By default, nullval is an empty string. See CASSANDRA-7792 for details.
+                        message = "Cannot insert null value for primary key column '%s'." % (pk_cols[i],)
+                        if nullval == '':
+                            message += " If you want to insert empty strings, consider using " \
+                                       "the WITH NULL=<marker> option for COPY."
+                        pipe.send((current_record, message))
+                        return
+                    else:
+                        row[i] = 'null'
+
+                full_query = query % (','.join(row),)
+                query_message = QueryMessage(
+                    full_query, self.consistency_level, serial_consistency_level=None,
+                    fetch_size=None, paging_state=None, timestamp=insert_timestamp)
+
+                request_id = conn.get_request_id()
+                binary_message = query_message.to_binary(
+                    stream_id=request_id, protocol_version=DEFAULT_PROTOCOL_VERSION, compression=None)
+
+                # add the message directly to the connection's queue
+                with conn.lock:
+                    conn.in_flight += 1
+
+                conn._callbacks[request_id] = partial(callback, current_record)
+                conn.deque.append(binary_message)
+
+                # every 50 records, clear the pending writes queue and read
+                # any responses we have
+                if insert_num % 50 == 0:
+                    conn._writable = True
+                    conn._readable = True
+
+                insert_num += 1
+        except Exception, exc:
+            pipe.send((current_record, exc))
+        finally:
+            # wait for any pending requests to finish
+            while conn.in_flight > 0:
+                conn._readable = True
+                time.sleep(0.01)
+            new_cluster.shutdown()
 
     def perform_csv_export(self, ks, cf, columns, fname, opts):
         dialect_options = self.csv_dialect_defaults.copy()
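The child's write loop above throttles itself by comparing the connection's in_flight count to roughly 90% of its stream-id budget and napping briefly so the event-loop thread can drain responses. A toy illustration of that soft-capacity check, using a stand-in connection object rather than the driver's real Connection internals:

    import time

    class FakeConn(object):
        # stand-in for the driver connection; only the two fields the
        # check needs
        max_request_id = 127  # 128 stream ids per connection in protocol v2
        in_flight = 0

    def submit(conn, message):
        # defer while the connection is ~90% full so responses can drain
        # (each response callback would decrement in_flight)
        while conn.in_flight > conn.max_request_id * 0.9:
            time.sleep(0.05)
        conn.in_flight += 1
        # ... append the binary message to the connection's queue here ...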
http://git-wip-us.apache.org/repos/asf/cassandra/blob/7110904e/pylib/cqlshlib/async_insert.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/async_insert.py b/pylib/cqlshlib/async_insert.py
deleted file mode 100644
index d325716..0000000
--- a/pylib/cqlshlib/async_insert.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from threading import Event, Condition
-from . import meter
-import sys
-
-
-class _CountDownLatch(object):
-    def __init__(self, counter=1):
-        self._count = counter
-        self._lock = Condition()
-
-    def count_down(self):
-        with self._lock:
-            self._count -= 1
-            if self._count <= 0:
-                self._lock.notifyAll()
-
-    def await(self):
-        with self._lock:
-            while self._count > 0:
-                # use a timeout so that the main thread wakes up occasionally
-                # so it can see keyboard interrupts (CASSANDRA-7815)
-                self._lock.wait(0.5)
-
-
-class _ChainedWriter(object):
-
-    CONCURRENCY = 100
-
-    def __init__(self, session, enumerated_reader, statement_func):
-        self._sentinel = object()
-        self._session = session
-        self._cancellation_event = Event()
-        self._first_error = None
-        self._task_counter = _CountDownLatch(self.CONCURRENCY)
-        self._enumerated_reader = enumerated_reader
-        self._statement_func = statement_func
-        self._meter = meter.Meter()
-
-    def insert(self):
-        if not self._enumerated_reader:
-            return 0, None
-
-        for i in xrange(self.CONCURRENCY):
-            self._execute_next(self._sentinel, 0)
-
-        try:
-            self._task_counter.await()
-        except KeyboardInterrupt:
-            self._cancellation_event.set()
-            sys.stdout.write('Aborting due to keyboard interrupt\n')
-            self._task_counter.await()
-        self._meter.done()
-        return self._meter.num_finished(), self._first_error
-
-
-    def _abort(self, error, failed_record):
-        if not self._first_error:
-            self._first_error = error, failed_record
-        self._task_counter.count_down()
-        self._cancellation_event.set()
-
-    def _handle_error(self, error, failed_record):
-        self._abort(error, failed_record)
-
-    def _execute_next(self, result, last_completed_record):
-        if self._cancellation_event.is_set():
-            self._task_counter.count_down()
-            return
-
-        if result is not self._sentinel:
-            self._meter.mark_written()
-
-        try:
-            (current_record, row) = next(self._enumerated_reader)
-        except StopIteration:
-            self._task_counter.count_down()
-            return
-        except Exception as exc:
-            self._abort(exc, last_completed_record)
-            return
-
-        if self._cancellation_event.is_set():
-            self._task_counter.count_down()
-            return
-
-        try:
-            statement = self._statement_func(row)
-            future = self._session.execute_async(statement)
-            future.add_callbacks(callback=self._execute_next,
-                                 callback_args=(current_record,),
-                                 errback=self._handle_error,
-                                 errback_args=(current_record,))
-        except Exception as exc:
-            self._abort(exc, current_record)
-            return
-
-
-def insert_concurrent(session, enumerated_reader, statement_func):
-    return _ChainedWriter(session, enumerated_reader, statement_func).insert()
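The deleted _CountDownLatch is worth noting for its timed wait: waiting with a 0.5-second timeout instead of blocking indefinitely lets the main thread notice Ctrl-C, per CASSANDRA-7815. A minimal standalone sketch of the same latch pattern (class and worker names are illustrative; notify_all and wait are the standard threading.Condition API):

    import time
    from threading import Condition, Thread

    class CountDownLatch(object):
        def __init__(self, counter=1):
            self._count = counter
            self._lock = Condition()

        def count_down(self):
            with self._lock:
                self._count -= 1
                if self._count <= 0:
                    self._lock.notify_all()

        def wait_for_zero(self):
            with self._lock:
                while self._count > 0:
                    # timed wait so KeyboardInterrupt is noticed promptly
                    self._lock.wait(0.5)

    latch = CountDownLatch(3)

    def task(n):
        time.sleep(0.1 * n)  # pretend to do some work
        latch.count_down()

    for n in range(3):
        Thread(target=task, args=(n,)).start()
    latch.wait_for_zero()  # returns once all three tasks have counted down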
http://git-wip-us.apache.org/repos/asf/cassandra/blob/7110904e/pylib/cqlshlib/meter.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/meter.py b/pylib/cqlshlib/meter.py
deleted file mode 100644
index e1a6bfc..0000000
--- a/pylib/cqlshlib/meter.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from time import time
-import sys
-from threading import RLock
-
-
-class Meter(object):
-
-    def __init__(self):
-        self._num_finished = 0
-        self._last_checkpoint_time = None
-        self._current_rate = 0.0
-        self._lock = RLock()
-
-    def mark_written(self):
-        with self._lock:
-            if not self._last_checkpoint_time:
-                self._last_checkpoint_time = time()
-            self._num_finished += 1
-
-            if self._num_finished % 10000 == 0:
-                previous_checkpoint_time = self._last_checkpoint_time
-                self._last_checkpoint_time = time()
-                new_rate = 10000.0 / (self._last_checkpoint_time - previous_checkpoint_time)
-                if self._current_rate == 0.0:
-                    self._current_rate = new_rate
-                else:
-                    self._current_rate = (self._current_rate + new_rate) / 2.0
-
-            if self._num_finished % 1000 != 0:
-                return
-            output = 'Processed %s rows; Write: %.2f rows/s\r' % \
-                     (self._num_finished, self._current_rate)
-            sys.stdout.write(output)
-            sys.stdout.flush()
-
-    def num_finished(self):
-        with self._lock:
-            return self._num_finished
-
-    def done(self):
-        print ""
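Both the deleted Meter and the new inline progress code in bin/cqlsh smooth the displayed throughput the same way: each fresh 10,000-row sample is averaged into the running rate. The recurrence, worked with invented sample values:

    def smooth(current_rate, new_rate):
        # the first sample seeds the estimate; later samples are averaged in
        if current_rate == 0.0:
            return new_rate
        return (current_rate + new_rate) / 2.0

    rate = 0.0
    for sample in (12000.0, 8000.0, 10000.0):
        rate = smooth(rate, sample)
        print('smoothed rate: %.2f rows/s' % rate)
    # prints 12000.00, then (12000 + 8000) / 2 = 10000.00,
    # then (10000 + 10000) / 2 = 10000.00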