http://git-wip-us.apache.org/repos/asf/cassandra/blob/f2883879/pylib/cqlshlib/copyutil.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/copyutil.py b/pylib/cqlshlib/copyutil.py
index 6d9a455..a154363 100644
--- a/pylib/cqlshlib/copyutil.py
+++ b/pylib/cqlshlib/copyutil.py
@@ -14,12 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import ConfigParser
import csv
+import datetime
import json
+import glob
import multiprocessing as mp
import os
import Queue
-import random
import re
import struct
import sys
@@ -46,81 +48,202 @@ from displaying import NO_COLOR_MAP
from formatting import format_value_default, DateTimeFormat, EMPTY, get_formatter
from sslhandling import ssl_settings
+CopyOptions = namedtuple('CopyOptions', 'copy dialect unrecognized')
-def parse_options(shell, opts):
- """
- Parse options for import (COPY FROM) and export (COPY TO) operations.
- Extract from opts csv and dialect options.
- :return: 3 dictionaries: the csv options, the dialect options, any unrecognized options.
- """
- dialect_options = shell.csv_dialect_defaults.copy()
- if 'quote' in opts:
- dialect_options['quotechar'] = opts.pop('quote')
- if 'escape' in opts:
- dialect_options['escapechar'] = opts.pop('escape')
- if 'delimiter' in opts:
- dialect_options['delimiter'] = opts.pop('delimiter')
- if dialect_options['quotechar'] == dialect_options['escapechar']:
- dialect_options['doublequote'] = True
- del dialect_options['escapechar']
-
- csv_options = dict()
- csv_options['nullval'] = opts.pop('null', '')
- csv_options['header'] = bool(opts.pop('header', '').lower() == 'true')
- csv_options['encoding'] = opts.pop('encoding', 'utf8')
- csv_options['maxrequests'] = int(opts.pop('maxrequests', 6))
- csv_options['pagesize'] = int(opts.pop('pagesize', 1000))
- # by default the page timeout is 10 seconds per 1000 entries in the page size or 10 seconds if pagesize is smaller
- csv_options['pagetimeout'] = int(opts.pop('pagetimeout', max(10, 10 * (csv_options['pagesize'] / 1000))))
- csv_options['maxattempts'] = int(opts.pop('maxattempts', 5))
- csv_options['dtformats'] = DateTimeFormat(opts.pop('timeformat', shell.display_timestamp_format),
- shell.display_date_format,
- shell.display_nanotime_format)
- csv_options['float_precision'] = shell.display_float_precision
- csv_options['chunksize'] = int(opts.pop('chunksize', 1000))
- csv_options['ingestrate'] = int(opts.pop('ingestrate', 100000))
- csv_options['maxbatchsize'] = int(opts.pop('maxbatchsize', 20))
- csv_options['minbatchsize'] = int(opts.pop('minbatchsize', 2))
- csv_options['reportfrequency'] = float(opts.pop('reportfrequency', 0.25))
-
- return csv_options, dialect_options, opts
-
-
-def get_num_processes(cap):
+def safe_normpath(fname):
"""
- Pick a reasonable number of child processes. We need to leave at
- least one core for the parent process. This doesn't necessarily
- need to be capped, but 4 is currently enough to keep
- a single local Cassandra node busy so we use this for import, whilst
- for export we use 16 since we can connect to multiple Cassandra nodes.
- Eventually this parameter will become an option.
+ :return the normalized path but only if there is a filename, we don't want to convert
+ an empty string (which means no file name) to a dot. Also expand any user variables such as ~ to the full path
"""
- try:
- return max(1, min(cap, mp.cpu_count() - 1))
- except NotImplementedError:
- return 1
+ return os.path.normpath(os.path.expanduser(fname)) if fname else fname
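
A minimal usage sketch (not part of the patch, POSIX paths assumed) of the behaviour safe_normpath aims for:

    import os.path

    def safe_normpath(fname):
        # mirror of the helper above: normalize and expand '~', but leave '' alone
        return os.path.normpath(os.path.expanduser(fname)) if fname else fname

    assert safe_normpath('') == ''                        # empty stays empty, never '.'
    assert safe_normpath('data//rows.csv') == 'data/rows.csv'
    # safe_normpath('~/rows.csv') expands to an absolute path under the user's home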
class CopyTask(object):
"""
A base class for ImportTask and ExportTask
"""
- def __init__(self, shell, ks, cf, columns, fname, csv_options, dialect_options, protocol_version, config_file):
+ def __init__(self, shell, ks, table, columns, fname, opts, protocol_version, config_file, direction):
self.shell = shell
- self.csv_options = csv_options
- self.dialect_options = dialect_options
self.ks = ks
- self.cf = cf
- self.columns = shell.get_column_names(ks, cf) if columns is None else columns
- self.fname = fname
+ self.table = table
+ self.local_dc = shell.conn.metadata.get_host(shell.hostname).datacenter
+ self.fname = safe_normpath(fname)
self.protocol_version = protocol_version
self.config_file = config_file
+ # do not display messages when exporting to STDOUT
+ self.printmsg = self._printmsg if self.fname is not None or direction == 'in' else lambda _, eol='\n': None
+ self.options = self.parse_options(opts, direction)
+
+ self.num_processes = self.options.copy['numprocesses']
+ self.printmsg('Using %d child processes' % (self.num_processes,))
self.processes = []
self.inmsg = mp.Queue()
self.outmsg = mp.Queue()
+ self.columns = CopyTask.get_columns(shell, ks, table, columns)
+ self.time_start = time.time()
+
+ @staticmethod
+ def _printmsg(msg, eol='\n'):
+ sys.stdout.write(msg + eol)
+ sys.stdout.flush()
+
+ def maybe_read_config_file(self, opts, direction):
+ """
+ Read optional sections from a configuration file that was specified in the command options or from the default
+ cqlshrc configuration file if none was specified.
+ """
+ config_file = opts.pop('configfile', '')
+ if not config_file:
+ config_file = self.config_file
+
+ if not os.path.isfile(config_file):
+ return opts
+
+ configs = ConfigParser.RawConfigParser()
+ configs.readfp(open(config_file))
+
+ ret = dict()
+ config_sections = list(['copy', 'copy-%s' % (direction,),
+ 'copy:%s.%s' % (self.ks, self.table),
+ 'copy-%s:%s.%s' % (direction, self.ks, self.table)])
+
+ for section in config_sections:
+ if configs.has_section(section):
+ options = dict(configs.items(section))
+ self.printmsg("Reading options from %s:[%s]: %s" %
(config_file, section, options))
+ ret.update(options)
+
+ # Update this last so the command line options take precedence over the configuration file options
+ if opts:
+ self.printmsg("Reading options from the command line: %s" %
(opts,))
+ ret.update(opts)
+
+ if self.shell.debug: # this is important for testing, do not remove
+ self.printmsg("Using options: '%s'" % (ret,))
+
+ return ret
+
+ @staticmethod
+ def clean_options(opts):
+ """
+ Convert all option values to valid string literals unless they are path names
+ """
+ return dict([(k, v.decode('string_escape') if k not in ['errfile', 'ratefile'] else v)
+ for k, v, in opts.iteritems()])
+
+ def parse_options(self, opts, direction):
+ """
+ Parse options for import (COPY FROM) and export (COPY TO) operations.
+ Extract from opts csv and dialect options.
+
+ :return: 3 dictionaries: the csv options, the dialect options, any unrecognized options.
+ """
+ shell = self.shell
+ opts = self.clean_options(self.maybe_read_config_file(opts, direction))
+
+ dialect_options = dict()
+ dialect_options['quotechar'] = opts.pop('quote', '"')
+ dialect_options['escapechar'] = opts.pop('escape', '\\')
+ dialect_options['delimiter'] = opts.pop('delimiter', ',')
+ if dialect_options['quotechar'] == dialect_options['escapechar']:
+ dialect_options['doublequote'] = True
+ del dialect_options['escapechar']
+ else:
+ dialect_options['doublequote'] = False
+
+ copy_options = dict()
+ copy_options['nullval'] = opts.pop('null', '')
+ copy_options['header'] = bool(opts.pop('header', '').lower() == 'true')
+ copy_options['encoding'] = opts.pop('encoding', 'utf8')
+ copy_options['maxrequests'] = int(opts.pop('maxrequests', 6))
+ copy_options['pagesize'] = int(opts.pop('pagesize', 1000))
+ # by default the page timeout is 10 seconds per 1000 entries
+ # in the page size or 10 seconds if pagesize is smaller
+ copy_options['pagetimeout'] = int(opts.pop('pagetimeout', max(10, 10 * (copy_options['pagesize'] / 1000))))
+ copy_options['maxattempts'] = int(opts.pop('maxattempts', 5))
+ copy_options['dtformats'] = DateTimeFormat(opts.pop('datetimeformat', shell.display_timestamp_format),
+ shell.display_date_format, shell.display_nanotime_format)
+ copy_options['float_precision'] = shell.display_float_precision
+ copy_options['chunksize'] = int(opts.pop('chunksize', 1000))
+ copy_options['ingestrate'] = int(opts.pop('ingestrate', 100000))
+ copy_options['maxbatchsize'] = int(opts.pop('maxbatchsize', 20))
+ copy_options['minbatchsize'] = int(opts.pop('minbatchsize', 2))
+ copy_options['reportfrequency'] = float(opts.pop('reportfrequency', 0.25))
+ copy_options['consistencylevel'] = shell.consistency_level
+ copy_options['decimalsep'] = opts.pop('decimalsep', '.')
+ copy_options['thousandssep'] = opts.pop('thousandssep', '')
+ copy_options['boolstyle'] = [s.strip() for s in opts.pop('boolstyle', 'True, False').split(',')]
+ copy_options['numprocesses'] = int(opts.pop('numprocesses', self.get_num_processes(cap=16)))
+ copy_options['begintoken'] = opts.pop('begintoken', '')
+ copy_options['endtoken'] = opts.pop('endtoken', '')
+ copy_options['maxrows'] = int(opts.pop('maxrows', '-1'))
+ copy_options['skiprows'] = int(opts.pop('skiprows', '0'))
+ copy_options['skipcols'] = opts.pop('skipcols', '')
+ copy_options['maxparseerrors'] = int(opts.pop('maxparseerrors', '-1'))
+ copy_options['maxinserterrors'] = int(opts.pop('maxinserterrors', '-1'))
+ copy_options['errfile'] = safe_normpath(opts.pop('errfile', 'import_%s_%s.err' % (self.ks, self.table,)))
+ copy_options['ratefile'] = safe_normpath(opts.pop('ratefile', ''))
+ copy_options['maxoutputsize'] = int(opts.pop('maxoutputsize', '-1'))
+
+ self.check_options(copy_options)
+ return CopyOptions(copy=copy_options, dialect=dialect_options, unrecognized=opts)
+
+ @staticmethod
+ def check_options(copy_options):
+ """
+ Check any options that require a sanity check beyond a simple type conversion and if required
+ raise a value error:
+
+ - boolean styles must be exactly 2, they must be different and they cannot be empty
+ """
+ bool_styles = copy_options['boolstyle']
+ if len(bool_styles) != 2 or bool_styles[0] == bool_styles[1] or not bool_styles[0] or not bool_styles[1]:
+ raise ValueError("Invalid boolean styles %s" % copy_options['boolstyle'])
+
+ @staticmethod
+ def get_num_processes(cap):
+ """
+ Pick a reasonable number of child processes. We need to leave at
+ least one core for the parent process. This doesn't necessarily
+ need to be capped, but 4 is currently enough to keep
+ a single local Cassandra node busy so we use this for import, whilst
+ for export we use 16 since we can connect to multiple Cassandra nodes.
+ Eventually this parameter will become an option.
+ """
+ try:
+ return max(1, min(cap, mp.cpu_count() - 1))
+ except NotImplementedError:
+ return 1
+
+ @staticmethod
+ def describe_interval(seconds):
+ desc = []
+ for length, unit in ((86400, 'day'), (3600, 'hour'), (60, 'minute')):
+ num = int(seconds) / length
+ if num > 0:
+ desc.append('%d %s' % (num, unit))
+ if num > 1:
+ desc[-1] += 's'
+ seconds %= length
+ words = '%.03f seconds' % seconds
+ if len(desc) > 1:
+ words = ', '.join(desc) + ', and ' + words
+ elif len(desc) == 1:
+ words = desc[0] + ' and ' + words
+ return words
+
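A minimal sketch of the interval formatting above (re-implemented with '//' so it also runs on Python 3; the patch itself targets Python 2, where '/' on ints truncates):

    def describe_interval(seconds):
        desc = []
        for length, unit in ((86400, 'day'), (3600, 'hour'), (60, 'minute')):
            num = int(seconds) // length
            if num > 0:
                desc.append('%d %s' % (num, unit) + ('s' if num > 1 else ''))
            seconds %= length
        words = '%.03f seconds' % seconds
        if len(desc) > 1:
            return ', '.join(desc) + ', and ' + words
        if desc:
            return desc[0] + ' and ' + words
        return words

    assert describe_interval(3725.5) == '1 hour, 2 minutes, and 5.500 seconds'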
+ @staticmethod
+ def get_columns(shell, ks, table, columns):
+ """
+ Return all columns if none were specified or only the columns specified.
+ Possible enhancement: introduce a regex like syntax (^) to allow users
+ to specify all columns except a few.
+ """
+ return shell.get_column_names(ks, table) if not columns else columns
+
def close(self):
for process in self.processes:
process.terminate()
@@ -144,11 +267,10 @@ class CopyTask(object):
return dict(inmsg=self.outmsg, # see comment above
outmsg=self.inmsg, # see comment above
ks=self.ks,
- cf=self.cf,
+ table=self.table,
+ local_dc=self.local_dc,
columns=self.columns,
- csv_options=self.csv_options,
- dialect_options=self.dialect_options,
- consistency_level=shell.consistency_level,
+ options=self.options,
connect_timeout=shell.conn.connect_timeout,
hostname=shell.hostname,
port=shell.port,
@@ -161,55 +283,157 @@ class CopyTask(object):
)
+class ExportWriter(object):
+ """
+ A class that writes to one or more csv files, or STDOUT
+ """
+
+ def __init__(self, fname, shell, columns, options):
+ self.fname = fname
+ self.shell = shell
+ self.columns = columns
+ self.options = options
+ self.header = options.copy['header']
+ self.max_output_size = long(options.copy['maxoutputsize'])
+ self.current_dest = None
+ self.num_files = 0
+
+ if self.max_output_size > 0:
+ if fname is not None:
+ self.write = self._write_with_split
+ self.num_written = 0
+ else:
+ shell.printerr("WARNING: maxoutputsize {} ignored when writing
to STDOUT".format(self.max_output_size))
+ self.write = self._write_without_split
+ else:
+ self.write = self._write_without_split
+
+ def open(self):
+ self.current_dest = self._get_dest(self.fname)
+ if self.current_dest is None:
+ return False
+
+ if self.header:
+ writer = csv.writer(self.current_dest.output, **self.options.dialect)
+ writer.writerow(self.columns)
+
+ return True
+
+ def close(self):
+ self._close_current_dest()
+
+ def _next_dest(self):
+ self._close_current_dest()
+ self.current_dest = self._get_dest(self.fname + '.%d' % (self.num_files,))
+
+ def _get_dest(self, source_name):
+ """
+ Open the output file if any or else use stdout. Return a namedtuple
+ containing the out and a boolean indicating if the output should be closed.
+ """
+ CsvDest = namedtuple('CsvDest', 'output close')
+
+ if self.fname is None:
+ return CsvDest(output=sys.stdout, close=False)
+ else:
+ try:
+ ret = CsvDest(output=open(source_name, 'wb'), close=True)
+ self.num_files += 1
+ return ret
+ except IOError, e:
+ self.shell.printerr("Can't open %r for writing: %s" %
(source_name, e))
+ return None
+
+ def _close_current_dest(self):
+ if self.current_dest and self.current_dest.close:
+ self.current_dest.output.close()
+ self.current_dest = None
+
+ def _write_without_split(self, data, _):
+ """
+ Write the data to the current destination output.
+ """
+ self.current_dest.output.write(data)
+
+ def _write_with_split(self, data, num):
+ """
+ Write the data to the current destination output if we still
+ haven't reached the maximum number of rows. Otherwise split
+ the rows between the current destination and the next.
+ """
+ if (self.num_written + num) > self.max_output_size:
+ num_remaining = self.max_output_size - self.num_written
+ last_switch = 0
+ for i, row in enumerate(filter(None, data.split(os.linesep))):
+ if i == num_remaining:
+ self._next_dest()
+ last_switch = i
+ num_remaining += self.max_output_size
+ self.current_dest.output.write(row + '\n')
+
+ self.num_written = num - last_switch
+ else:
+ self.num_written += num
+ self.current_dest.output.write(data)
+
+
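The MAXOUTPUTSIZE accounting above switches the export to numbered '<fname>.<n>' files once max_output_size rows have been written to the current one. A simplified, self-contained sketch of that row bookkeeping (illustrative only, it collects in-memory lists instead of opening files):

    def split_rows(rows, max_output_size):
        # distribute rows into chunks of at most max_output_size, like the file switch above
        files, current = [], []
        for row in rows:
            if len(current) == max_output_size:
                files.append(current)
                current = []
            current.append(row)
        if current:
            files.append(current)
        return files

    assert split_rows(['r1', 'r2', 'r3', 'r4', 'r5'], 2) == [['r1', 'r2'], ['r3', 'r4'], ['r5']]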
class ExportTask(CopyTask):
"""
A class that exports data to .csv by instantiating one or more processes
that work in parallel (ExportProcess).
"""
+ def __init__(self, shell, ks, table, columns, fname, opts, protocol_version, config_file):
+ CopyTask.__init__(self, shell, ks, table, columns, fname, opts, protocol_version, config_file, 'to')
+
+ options = self.options
+ self.begin_token = long(options.copy['begintoken']) if options.copy['begintoken'] else None
+ self.end_token = long(options.copy['endtoken']) if options.copy['endtoken'] else None
+ self.writer = ExportWriter(fname, shell, columns, options)
def run(self):
"""
- Initiates the export by creating the processes.
+ Initiates the export by starting the worker processes.
+ Then hand over control to export_records.
"""
shell = self.shell
- fname = self.fname
- if fname is None:
- do_close = False
- csvdest = sys.stdout
- else:
- do_close = True
- try:
- csvdest = open(fname, 'wb')
- except IOError, e:
- shell.printerr("Can't open %r for writing: %s" % (fname, e))
- return 0
+ if self.options.unrecognized:
+ shell.printerr('Unrecognized COPY TO options: %s' % ', '.join(self.options.unrecognized.keys()))
+ return
- if self.csv_options['header']:
- writer = csv.writer(csvdest, **self.dialect_options)
- writer.writerow(self.columns)
+ if not self.columns:
+ shell.printerr("No column specified")
+ return 0
ranges = self.get_ranges()
- num_processes = get_num_processes(cap=min(16, len(ranges)))
- params = self.make_params()
+ if not ranges:
+ return 0
+
+ if not self.writer.open():
+ return 0
+
+ self.printmsg("\nStarting copy of %s.%s with columns %s." % (self.ks,
self.table, self.columns))
- for i in xrange(num_processes):
+ params = self.make_params()
+ for i in xrange(self.num_processes):
self.processes.append(ExportProcess(params))
for process in self.processes:
process.start()
try:
- return self.check_processes(csvdest, ranges)
+ self.export_records(ranges)
finally:
self.close()
- if do_close:
- csvdest.close()
+
+ def close(self):
+ CopyTask.close(self)
+ self.writer.close()
def get_ranges(self):
"""
return a queue of tuples, where the first tuple entry is a token range (from, to]
and the second entry is a list of hosts that own that range. Each host is responsible
- for all the tokens in the rage (from, to].
+ for all the tokens in the range (from, to].
The ring information comes from the driver metadata token map, which is built by
querying System.PEERS.
@@ -219,43 +443,83 @@ class ExportTask(CopyTask):
"""
shell = self.shell
hostname = shell.hostname
+ local_dc = self.local_dc
ranges = dict()
+ min_token = self.get_min_token()
+ begin_token = self.begin_token
+ end_token = self.end_token
- def make_range(hosts):
+ def make_range(prev, curr):
+ """
+ Return the intersection of (prev, curr) and (begin_token, end_token),
+ return None if the intersection is empty
+ """
+ ret = (prev, curr)
+ if begin_token:
+ if ret[1] < begin_token:
+ return None
+ elif ret[0] < begin_token:
+ ret = (begin_token, ret[1])
+
+ if end_token:
+ if ret[0] > end_token:
+ return None
+ elif ret[1] > end_token:
+ ret = (ret[0], end_token)
+
+ return ret
+
+ def make_range_data(replicas=[]):
+ hosts = []
+ for r in replicas:
+ if r.is_up and r.datacenter == local_dc:
+ hosts.append(r.address)
+ if not hosts:
+ hosts.append(hostname) # fallback to default host if no replicas in current dc
return {'hosts': tuple(hosts), 'attempts': 0, 'rows': 0}
- min_token = self.get_min_token()
+ if begin_token and begin_token < min_token:
+ shell.printerr('Begin token %d must be bigger or equal to min token %d' % (begin_token, min_token))
+ return ranges
+
+ if begin_token and end_token and begin_token > end_token:
+ shell.printerr('Begin token %d must be smaller than end token %d' % (begin_token, end_token))
+ return ranges
+
if shell.conn.metadata.token_map is None or min_token is None:
- ranges[(None, None)] = make_range([hostname])
+ ranges[(begin_token, end_token)] = make_range_data()
return ranges
- local_dc = shell.conn.metadata.get_host(hostname).datacenter
ring = shell.get_ring(self.ks).items()
ring.sort()
- previous_previous = None
+ # If the ring is empty we get the entire ring from the host we are currently connected to
+ if not ring:
+ ranges[(begin_token, end_token)] = make_range_data()
+ return ranges
+
+ first_range_data = None
previous = None
for token, replicas in ring:
if previous is None and token.value == min_token:
continue # avoids looping entire ring
- hosts = []
- for host in replicas:
- if host.is_up and host.datacenter == local_dc:
- hosts.append(host.address)
- if not hosts:
- hosts.append(hostname) # fallback to default host if no replicas in current dc
- ranges[(previous, token.value)] = make_range(hosts)
- previous_previous = previous
+ if previous is None: # we use it at the end when wrapping around
+ first_range_data = make_range_data(replicas)
+
+ current_range = make_range(previous, token.value)
+ if not current_range:
+ continue
+
+ ranges[current_range] = make_range_data(replicas)
previous = token.value
- # If the ring is empty we get the entire ring from the
- # host we are currently connected to, otherwise for the last ring interval
- # we query the same replicas that hold the last token in the ring
+ # For the last ring interval we query the same replicas that hold the first token in the ring
+ if previous is not None and (not end_token or previous < end_token):
+ ranges[(previous, end_token)] = first_range_data
+
if not ranges:
- ranges[(None, None)] = make_range([hostname])
- else:
- ranges[(previous, None)] = ranges[(previous_previous, previous)].copy()
+ shell.printerr('Found no ranges to query, check begin and end tokens: %s - %s' % (begin_token, end_token))
return ranges
@@ -280,17 +544,21 @@ class ExportTask(CopyTask):
self.outmsg.put((token_range, ranges[token_range]))
ranges[token_range]['attempts'] += 1
- def check_processes(self, csvdest, ranges):
+ def export_records(self, ranges):
"""
- Here we monitor all child processes by collecting their results
- or any errors. We terminate when we have processed all the ranges or when there
- are no more processes.
+ Send records to child processes and monitor them by collecting their results
+ or any errors. We terminate when we have processed all the ranges or when one child
+ process has died (since in this case we will never get any ACK for the ranges
+ processed by it and at the moment we don't keep track of which ranges a
+ process is handling).
"""
shell = self.shell
processes = self.processes
- meter = RateMeter(update_interval=self.csv_options['reportfrequency'])
+ meter = RateMeter(log_fcn=self.printmsg,
+ update_interval=self.options.copy['reportfrequency'],
+ log_file=self.options.copy['ratefile'])
total_requests = len(ranges)
- max_attempts = self.csv_options['maxattempts']
+ max_attempts = self.options.copy['maxattempts']
self.send_work(ranges, ranges.keys())
@@ -306,8 +574,10 @@ class ExportTask(CopyTask):
if token_range is None: # the entire process failed
shell.printerr('Error from worker process: %s' % (result))
else: # only this token_range failed, retry up to max_attempts if no rows received yet,
- # if rows are receive we risk duplicating data, there is a back-off policy in place
- # in the worker process as well, see ExpBackoffRetryPolicy
+ # If rows were already received we'd risk duplicating data.
+ # Note that there is still a slight risk of duplicating data, even if we have
+ # an error with no rows received yet, it's just less likely. To avoid retrying on
+ # all timeouts would however mean we could risk not exporting some rows.
if ranges[token_range]['attempts'] < max_attempts and ranges[token_range]['rows'] == 0:
shell.printerr('Error for %s: %s (will try again later attempt %d of %d)'
% (token_range, result, ranges[token_range]['attempts'], max_attempts))
@@ -319,7 +589,7 @@ class ExportTask(CopyTask):
failed += 1
else: # partial result received
data, num = result
- csvdest.write(data)
+ self.writer.write(data, num)
meter.increment(n=num)
ranges[token_range]['rows'] += num
except Queue.Empty:
@@ -334,27 +604,207 @@ class ExportTask(CopyTask):
shell.printerr('Exported %d ranges out of %d total ranges, some records might be missing'
% (succeeded, total_requests))
- return meter.get_total_records()
+ self.printmsg("\n%d rows exported to %d files in %s." %
+ (meter.get_total_records(),
+ self.writer.num_files,
+ self.describe_interval(time.time() - self.time_start)))
class ImportReader(object):
"""
A wrapper around a csv reader to keep track of when we have
- exhausted reading input records.
+ exhausted reading input files. We are passed a comma separated
+ list of paths, where each path is a valid glob expression.
+ We generate a source generator and we read each source one
+ by one.
"""
- def __init__(self, linesource, chunksize, dialect_options):
- self.linesource = linesource
- self.chunksize = chunksize
- self.reader = csv.reader(linesource, **dialect_options)
- self.exhausted = False
-
- def read_rows(self):
- if self.exhausted:
+ def __init__(self, task):
+ self.shell = task.shell
+ self.options = task.options
+ self.printmsg = task.printmsg
+ self.chunk_size = self.options.copy['chunksize']
+ self.header = self.options.copy['header']
+ self.max_rows = self.options.copy['maxrows']
+ self.skip_rows = self.options.copy['skiprows']
+ self.sources = self.get_source(task.fname)
+ self.num_sources = 0
+ self.current_source = None
+ self.current_reader = None
+ self.num_read = 0
+
+ def get_source(self, paths):
+ """
+ Return a source generator. Each source is a named tuple
+ wrapping the source input, file name and a boolean indicating
+ if it requires closing.
+ """
+ shell = self.shell
+ LineSource = namedtuple('LineSource', 'input close fname')
+
+ def make_source(fname):
+ try:
+ ret = LineSource(input=open(fname, 'rb'), close=True, fname=fname)
+ return ret
+ except IOError, e:
+ shell.printerr("Can't open %r for reading: %s" % (fname, e))
+ return None
+
+ if paths is None:
+ self.printmsg("[Use \. on a line by itself to end input]")
+ yield LineSource(input=shell.use_stdin_reader(prompt='[copy] ', until=r'\.'), close=False, fname='')
+ else:
+ for path in paths.split(','):
+ path = path.strip()
+ if os.path.isfile(path):
+ yield make_source(path)
+ else:
+ for f in glob.glob(path):
+ yield (make_source(f))
+
+ def start(self):
+ self.next_source()
+
+ @property
+ def exhausted(self):
+ return not self.current_reader
+
+ def next_source(self):
+ """
+ Close the current source, if any, and open the next one. Return true
+ if there is another source, false otherwise.
+ """
+ self.close_current_source()
+ while self.current_source is None:
+ try:
+ self.current_source = self.sources.next()
+ if self.current_source and self.current_source.fname:
+ self.num_sources += 1
+ except StopIteration:
+ return False
+
+ if self.header:
+ self.current_source.input.next()
+
+ self.current_reader = csv.reader(self.current_source.input, **self.options.dialect)
+ return True
+
+ def close_current_source(self):
+ if not self.current_source:
+ return
+
+ if self.current_source.close:
+ self.current_source.input.close()
+ elif self.shell.tty:
+ print
+
+ self.current_source = None
+ self.current_reader = None
+
+ def close(self):
+ self.close_current_source()
+
+ def read_rows(self, max_rows):
+ if not self.current_reader:
return []
- rows = list(next(self.reader) for _ in xrange(self.chunksize))
- self.exhausted = len(rows) < self.chunksize
- return rows
+ rows = []
+ for i in xrange(min(max_rows, self.chunk_size)):
+ try:
+ row = self.current_reader.next()
+ self.num_read += 1
+
+ if 0 <= self.max_rows < self.num_read:
+ self.next_source()
+ break
+
+ if self.num_read > self.skip_rows:
+ rows.append(row)
+
+ except StopIteration:
+ self.next_source()
+ break
+
+ return filter(None, rows)
+
+
+class ImportErrors(object):
+ """
+ A small class for managing import errors
+ """
+ def __init__(self, task):
+ self.shell = task.shell
+ self.reader = task.reader
+ self.options = task.options
+ self.printmsg = task.printmsg
+ self.max_attempts = self.options.copy['maxattempts']
+ self.max_parse_errors = self.options.copy['maxparseerrors']
+ self.max_insert_errors = self.options.copy['maxinserterrors']
+ self.err_file = self.options.copy['errfile']
+ self.parse_errors = 0
+ self.insert_errors = 0
+ self.num_rows_failed = 0
+
+ if os.path.isfile(self.err_file):
+ now = datetime.datetime.now()
+ old_err_file = self.err_file + now.strftime('.%Y%m%d_%H%M%S')
+ self.printmsg("Renaming existing %s to %s\n" % (self.err_file,
old_err_file))
+ os.rename(self.err_file, old_err_file)
+
+ def max_exceeded(self):
+ if self.insert_errors > self.max_insert_errors >= 0:
+ self.shell.printerr("Exceeded maximum number of insert errors %d"
% self.max_insert_errors)
+ return True
+
+ if self.parse_errors > self.max_parse_errors >= 0:
+ self.shell.printerr("Exceeded maximum number of parse errors %d" %
self.max_parse_errors)
+ return True
+
+ return False
+
+ def add_failed_rows(self, rows):
+ self.num_rows_failed += len(rows)
+
+ with open(self.err_file, "a") as f:
+ writer = csv.writer(f, **self.options.dialect)
+ for row in rows:
+ writer.writerow(row)
+
+ def handle_error(self, err, batch):
+ """
+ Handle an error by printing the appropriate error message and incrementing the correct counter.
+ Return true if we should retry this batch, false if the error is non-recoverable
+ """
+ shell = self.shell
+ err = str(err)
+
+ if self.is_parse_error(err):
+ self.parse_errors += len(batch['rows'])
+ self.add_failed_rows(batch['rows'])
+ shell.printerr("Failed to import %d rows: %s - given up without
retries"
+ % (len(batch['rows']), err))
+ return False
+ else:
+ self.insert_errors += len(batch['rows'])
+ if batch['attempts'] < self.max_attempts:
+ shell.printerr("Failed to import %d rows: %s - will retry
later, attempt %d of %d"
+ % (len(batch['rows']), err, batch['attempts'],
+ self.max_attempts))
+ return True
+ else:
+ self.add_failed_rows(batch['rows'])
+ shell.printerr("Failed to import %d rows: %s - given up after
%d attempts"
+ % (len(batch['rows']), err, batch['attempts']))
+ return False
+
+ @staticmethod
+ def is_parse_error(err):
+ """
+ We treat parse errors as unrecoverable and we have different global counters for giving up when
+ a maximum has been reached. We consider value and type errors as parse errors as well since they
+ are typically non recoverable.
+ """
+ return err.startswith('ValueError') or err.startswith('TypeError') or \
+ err.startswith('ParseError') or err.startswith('IndexError')
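
The max_exceeded() thresholds above rely on Python's chained comparisons, which is why '-1' (the default for MAXPARSEERRORS and MAXINSERTERRORS) means unlimited. A minimal sketch of that check:

    def exceeded(errors, maximum):
        # true only when a non-negative maximum exists and has been passed
        return errors > maximum >= 0

    assert exceeded(5, 3) is True    # limit of 3 exceeded
    assert exceeded(5, -1) is False  # -1 disables the limit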
class ImportTask(CopyTask):
@@ -362,44 +812,54 @@ class ImportTask(CopyTask):
A class to import data from .csv by instantiating one or more processes
that work in parallel (ImportProcess).
"""
- def __init__(self, shell, ks, cf, columns, fname, csv_options, dialect_options, protocol_version, config_file):
- CopyTask.__init__(self, shell, ks, cf, columns, fname,
- csv_options, dialect_options, protocol_version, config_file)
-
- self.num_processes = get_num_processes(cap=4)
- self.chunk_size = csv_options['chunksize']
- self.ingest_rate = csv_options['ingestrate']
- self.max_attempts = csv_options['maxattempts']
- self.header = self.csv_options['header']
- self.table_meta = self.shell.get_table_meta(self.ks, self.cf)
+ def __init__(self, shell, ks, table, columns, fname, opts, protocol_version, config_file):
+ CopyTask.__init__(self, shell, ks, table, columns, fname, opts, protocol_version, config_file, 'from')
+
+ options = self.options
+ self.ingest_rate = options.copy['ingestrate']
+ self.max_attempts = options.copy['maxattempts']
+ self.header = options.copy['header']
+ self.skip_columns = [c.strip() for c in self.options.copy['skipcols'].split(',')]
+ self.valid_columns = [c for c in self.columns if c not in self.skip_columns]
+ self.table_meta = self.shell.get_table_meta(self.ks, self.table)
self.batch_id = 0
- self.receive_meter = RateMeter(update_interval=csv_options['reportfrequency'])
- self.send_meter = RateMeter(update_interval=1, log=False)
+ self.receive_meter = RateMeter(log_fcn=self.printmsg,
+ update_interval=options.copy['reportfrequency'],
+ log_file=options.copy['ratefile'])
+ self.send_meter = RateMeter(log_fcn=None, update_interval=1)
+ self.reader = ImportReader(self)
+ self.import_errors = ImportErrors(self)
self.retries = deque([])
self.failed = 0
self.succeeded = 0
self.sent = 0
+ def make_params(self):
+ ret = CopyTask.make_params(self)
+ ret['skip_columns'] = self.skip_columns
+ ret['valid_columns'] = self.valid_columns
+ return ret
+
def run(self):
shell = self.shell
- if self.fname is None:
- do_close = False
- print "[Use \. on a line by itself to end input]"
- linesource = shell.use_stdin_reader(prompt='[copy] ', until=r'\.')
- else:
- do_close = True
- try:
- linesource = open(self.fname, 'rb')
- except IOError, e:
- shell.printerr("Can't open %r for reading: %s" % (self.fname,
e))
+ if self.options.unrecognized:
+ shell.printerr('Unrecognized COPY FROM options: %s' % ', '.join(self.options.unrecognized.keys()))
+ return
+
+ if not self.valid_columns:
+ shell.printerr("No column specified")
+ return 0
+
+ for c in self.table_meta.primary_key:
+ if c.name not in self.valid_columns:
+ shell.printerr("Primary key column '%s' missing or skipped" %
(c.name,))
return 0
- try:
- if self.header:
- linesource.next()
+ self.printmsg("\nStarting copy of %s.%s with columns %s." % (self.ks,
self.table, self.valid_columns))
- reader = ImportReader(linesource, self.chunk_size,
self.dialect_options)
+ try:
+ self.reader.start()
params = self.make_params()
for i in range(self.num_processes):
@@ -408,7 +868,7 @@ class ImportTask(CopyTask):
for process in self.processes:
process.start()
- return self.process_records(reader)
+ self.import_records()
except Exception, exc:
shell.printerr(str(exc))
@@ -417,31 +877,43 @@ class ImportTask(CopyTask):
return 0
finally:
self.close()
- if do_close:
- linesource.close()
- elif shell.tty:
- print
- def process_records(self, reader):
+ def close(self):
+ CopyTask.close(self)
+ self.reader.close()
+
+ def import_records(self):
"""
Keep on running until we have stuff to receive or send and until all processes are running.
Send data (batches or retries) up to the max ingest rate. If we are waiting for stuff to
receive check the incoming queue.
"""
- while (self.has_more_to_send(reader) or self.has_more_to_receive()) and self.all_processes_running():
+ reader = self.reader
+
+ while self.has_more_to_send(reader) or self.has_more_to_receive():
if self.has_more_to_send(reader):
- if self.send_meter.current_record <= self.ingest_rate:
- self.send_batches(reader)
- else:
- self.send_meter.maybe_update()
+ self.send_batches(reader)
if self.has_more_to_receive():
self.receive()
- if self.succeeded < self.sent:
- self.shell.printerr("Failed to process %d batches" % (self.sent -
self.succeeded))
+ if self.import_errors.max_exceeded() or not
self.all_processes_running():
+ break
+
+ if self.import_errors.num_rows_failed:
+ self.shell.printerr("Failed to process %d rows; failed rows
written to %s" %
+ (self.import_errors.num_rows_failed,
+ self.import_errors.err_file))
- return self.receive_meter.get_total_records()
+ if not self.all_processes_running():
+ self.shell.printerr("{} child process(es) died unexpectedly,
aborting"
+ .format(self.num_processes -
self.num_live_processes()))
+
+ self.printmsg("\n%d rows imported from %d files in %s (%d skipped)." %
+ (self.receive_meter.get_total_records(),
+ self.reader.num_sources,
+ self.describe_interval(time.time() - self.time_start),
+ self.reader.skip_rows))
def has_more_to_receive(self):
return (self.succeeded + self.failed) < self.sent
@@ -453,12 +925,11 @@ class ImportTask(CopyTask):
return self.num_live_processes() == self.num_processes
def receive(self):
- shell = self.shell
start_time = time.time()
- while time.time() - start_time < 0.01: # 10 millis
+ while time.time() - start_time < 0.001:
try:
- batch, err = self.inmsg.get(timeout=0.001) # 1 millisecond
+ batch, err = self.inmsg.get(timeout=0.00001)
if err is None:
self.succeeded += batch['imported']
@@ -466,35 +937,39 @@ class ImportTask(CopyTask):
else:
err = str(err)
- if err.startswith('ValueError') or err.startswith('TypeError') or err.startswith('IndexError') \
- or batch['attempts'] >= self.max_attempts:
- shell.printerr("Failed to import %d rows: %s - given
up after %d attempts"
- % (len(batch['rows']), err,
batch['attempts']))
- self.failed += len(batch['rows'])
- else:
- shell.printerr("Failed to import %d rows: %s - will
retry later, attempt %d of %d"
- % (len(batch['rows']), err,
batch['attempts'],
- self.max_attempts))
+ if self.import_errors.handle_error(err, batch):
self.retries.append(self.reset_batch(batch))
+ else:
+ self.failed += len(batch['rows'])
+
except Queue.Empty:
- break
+ pass
def send_batches(self, reader):
"""
- Send batches to the queue until we have exceeded the ingest rate. In the export case we queue
- everything and let the worker processes throttle using max_requests, here we throttle
- in the parent process because of memory usage concerns.
+ Send one batch per worker process to the queue unless we have exceeded the ingest rate.
+ In the export case we queue everything and let the worker processes throttle using max_requests,
+ here we throttle using the ingest rate in the parent process because of memory usage concerns.
When we have finished reading the csv file, then send any retries.
"""
- while self.send_meter.current_record <= self.ingest_rate:
+ for _ in xrange(self.num_processes):
+ max_rows = self.ingest_rate - self.send_meter.current_record
+ if max_rows <= 0:
+ self.send_meter.maybe_update()
+ break
+
if not reader.exhausted:
- rows = reader.read_rows()
+ rows = reader.read_rows(max_rows)
if rows:
self.sent += self.send_batch(self.new_batch(rows))
elif self.retries:
batch = self.retries.popleft()
- self.send_batch(batch)
+ if len(batch['rows']) <= max_rows:
+ self.send_batch(batch)
+ else:
+ self.send_batch(self.split_batch(batch, batch['rows'][:max_rows]))
+ self.retries.append(self.split_batch(batch, batch['rows'][max_rows:]))
else:
break
@@ -515,6 +990,10 @@ class ImportTask(CopyTask):
return batch
@staticmethod
+ def split_batch(batch, rows):
+ return ImportTask.make_batch(batch['id'], rows, batch['attempts'])
+
+ @staticmethod
def make_batch(batch_id, rows, attempts):
return {'id': batch_id, 'rows': rows, 'attempts': attempts, 'imported': 0}
@@ -529,12 +1008,12 @@ class ChildProcess(mp.Process):
self.inmsg = params['inmsg']
self.outmsg = params['outmsg']
self.ks = params['ks']
- self.cf = params['cf']
+ self.table = params['table']
+ self.local_dc = params['local_dc']
self.columns = params['columns']
self.debug = params['debug']
self.port = params['port']
self.hostname = params['hostname']
- self.consistency_level = params['consistency_level']
self.connect_timeout = params['connect_timeout']
self.cql_version = params['cql_version']
self.auth_provider = params['auth_provider']
@@ -542,18 +1021,24 @@ class ChildProcess(mp.Process):
self.protocol_version = params['protocol_version']
self.config_file = params['config_file']
+ options = params['options']
+ self.date_time_format = options.copy['dtformats']
+ self.consistency_level = options.copy['consistencylevel']
+ self.decimal_sep = options.copy['decimalsep']
+ self.thousands_sep = options.copy['thousandssep']
+ self.boolean_styles = options.copy['boolstyle']
# Here we inject some failures for testing purposes, only if this environment variable is set
if os.environ.get('CQLSH_COPY_TEST_FAILURES', ''):
self.test_failures = json.loads(os.environ.get('CQLSH_COPY_TEST_FAILURES', ''))
else:
self.test_failures = None
- def printmsg(self, text):
+ def printdebugmsg(self, text):
if self.debug:
- sys.stderr.write(text + os.linesep)
+ sys.stdout.write(text + '\n')
def close(self):
- self.printmsg("Closing queues...")
+ self.printdebugmsg("Closing queues...")
self.inmsg.close()
self.outmsg.close()
@@ -565,7 +1050,7 @@ class ExpBackoffRetryPolicy(RetryPolicy):
def __init__(self, parent_process):
RetryPolicy.__init__(self)
self.max_attempts = parent_process.max_attempts
- self.printmsg = parent_process.printmsg
+ self.printdebugmsg = parent_process.printdebugmsg
def on_read_timeout(self, query, consistency, required_responses, received_responses, data_retrieved, retry_num):
@@ -578,14 +1063,14 @@ class ExpBackoffRetryPolicy(RetryPolicy):
def _handle_timeout(self, consistency, retry_num):
delay = self.backoff(retry_num)
if delay > 0:
- self.printmsg("Timeout received, retrying after %d seconds" %
(delay))
+ self.printdebugmsg("Timeout received, retrying after %d seconds" %
(delay,))
time.sleep(delay)
return self.RETRY, consistency
elif delay == 0:
- self.printmsg("Timeout received, retrying immediately")
+ self.printdebugmsg("Timeout received, retrying immediately")
return self.RETRY, consistency
else:
- self.printmsg("Timeout received, giving up after %d attempts" %
(retry_num + 1))
+ self.printdebugmsg("Timeout received, giving up after %d attempts"
% (retry_num + 1))
return self.RETHROW, None
def backoff(self, retry_num):
@@ -615,16 +1100,17 @@ class ExportSession(object):
def __init__(self, cluster, export_process):
session = cluster.connect(export_process.ks)
session.row_factory = tuple_factory
- session.default_fetch_size = export_process.csv_options['pagesize']
- session.default_timeout = export_process.csv_options['pagetimeout']
+ session.default_fetch_size = export_process.options.copy['pagesize']
+ session.default_timeout = export_process.options.copy['pagetimeout']
- export_process.printmsg("Created connection to %s with page size %d
and timeout %d seconds per page"
- % (session.hosts, session.default_fetch_size,
session.default_timeout))
+ export_process.printdebugmsg("Created connection to %s with page size
%d and timeout %d seconds per page"
+ % (session.hosts,
session.default_fetch_size, session.default_timeout))
self.cluster = cluster
self.session = session
self.requests = 1
self.lock = Lock()
+ self.consistency_level = export_process.consistency_level
def add_request(self):
with self.lock:
@@ -639,7 +1125,7 @@ class ExportSession(object):
return self.requests
def execute_async(self, query):
- return self.session.execute_async(query)
+ return self.session.execute_async(SimpleStatement(query, consistency_level=self.consistency_level))
def shutdown(self):
self.cluster.shutdown()
@@ -652,18 +1138,16 @@ class ExportProcess(ChildProcess):
def __init__(self, params):
ChildProcess.__init__(self, params=params, target=self.run)
- self.dialect_options = params['dialect_options']
- self.hosts_to_sessions = dict()
+ options = params['options']
+ self.encoding = options.copy['encoding']
+ self.float_precision = options.copy['float_precision']
+ self.nullval = options.copy['nullval']
+ self.max_attempts = options.copy['maxattempts']
+ self.max_requests = options.copy['maxrequests']
- csv_options = params['csv_options']
- self.encoding = csv_options['encoding']
- self.date_time_format = csv_options['dtformats']
- self.float_precision = csv_options['float_precision']
- self.nullval = csv_options['nullval']
- self.max_attempts = csv_options['maxattempts']
- self.max_requests = csv_options['maxrequests']
- self.csv_options = csv_options
+ self.hosts_to_sessions = dict()
self.formatters = dict()
+ self.options = options
def run(self):
try:
@@ -699,7 +1183,7 @@ class ExportProcess(ChildProcess):
else:
msg = str(err)
- self.printmsg(msg)
+ self.printdebugmsg(msg)
self.outmsg.put((token_range, Exception(msg)))
def start_request(self, token_range, info):
@@ -708,7 +1192,7 @@ class ExportProcess(ChildProcess):
will later on invoke the callbacks attached in attach_callbacks.
"""
session = self.get_session(info['hosts'])
- metadata = session.cluster.metadata.keyspaces[self.ks].tables[self.cf]
+ metadata = session.cluster.metadata.keyspaces[self.ks].tables[self.table]
query = self.prepare_query(metadata.partition_key, token_range, info['attempts'])
future = session.execute_async(query)
self.attach_callbacks(token_range, future, session)
@@ -736,13 +1220,15 @@ class ExportProcess(ChildProcess):
ssl_options=ssl_settings(host, self.config_file) if self.ssl else None,
load_balancing_policy=TokenAwarePolicy(WhiteListRoundRobinPolicy(hosts)),
default_retry_policy=ExpBackoffRetryPolicy(self),
- compression=None)
+ compression=None,
+ control_connection_timeout=self.connect_timeout,
+ connect_timeout=self.connect_timeout)
session = ExportSession(new_cluster, self)
self.hosts_to_sessions[host] = session
return session
else:
- host = min(hosts, key=lambda h: self.hosts_to_sessions[h].requests)
+ host = min(hosts, key=lambda hh: self.hosts_to_sessions[hh].requests)
session = self.hosts_to_sessions[host]
session.add_request()
return session
@@ -769,7 +1255,7 @@ class ExportProcess(ChildProcess):
try:
output = StringIO()
- writer = csv.writer(output, **self.dialect_options)
+ writer = csv.writer(output, **self.options.dialect)
for row in rows:
writer.writerow(map(self.format_value, row))
@@ -792,7 +1278,9 @@ class ExportProcess(ChildProcess):
self.formatters[ctype] = formatter
return formatter(val, encoding=self.encoding, colormap=NO_COLOR_MAP, date_time_format=self.date_time_format,
- float_precision=self.float_precision, nullval=self.nullval, quote=False)
+ float_precision=self.float_precision, nullval=self.nullval, quote=False,
+ decimal_sep=self.decimal_sep, thousands_sep=self.thousands_sep,
+ boolean_styles=self.boolean_styles)
def close(self):
ChildProcess.close(self)
@@ -841,7 +1329,7 @@ class ExportProcess(ChildProcess):
pk_cols = ", ".join(protect_names(col.name for col in partition_key))
columnlist = ', '.join(protect_names(self.columns))
start_token, end_token = token_range
- query = 'SELECT %s FROM %s.%s' % (columnlist, protect_name(self.ks), protect_name(self.cf))
+ query = 'SELECT %s FROM %s.%s' % (columnlist, protect_name(self.ks), protect_name(self.table))
if start_token is not None or end_token is not None:
query += ' WHERE'
if start_token is not None:
@@ -853,6 +1341,11 @@ class ExportProcess(ChildProcess):
return query
+class ParseError(Exception):
+ """ We failed to parse an import record """
+ pass
+
+
class ImportConversion(object):
"""
A class for converting strings to values when importing from csv, used by ImportProcess,
@@ -860,10 +1353,15 @@ class ImportConversion(object):
"""
def __init__(self, parent, table_meta, statement):
self.ks = parent.ks
- self.cf = parent.cf
- self.columns = parent.columns
+ self.table = parent.table
+ self.columns = parent.valid_columns
self.nullval = parent.nullval
- self.printmsg = parent.printmsg
+ self.printdebugmsg = parent.printdebugmsg
+ self.decimal_sep = parent.decimal_sep
+ self.thousands_sep = parent.thousands_sep
+ self.boolean_styles = parent.boolean_styles
+ self.date_time_format = parent.date_time_format.timestamp_format
+
self.table_meta = table_meta
self.primary_key_indexes = [self.columns.index(col.name) for col in self.table_meta.primary_key]
self.partition_key_indexes = [self.columns.index(col.name) for col in self.table_meta.partition_key]
@@ -885,6 +1383,40 @@ class ImportConversion(object):
def convert(t, v):
return converters.get(t.typename, convert_unknown)(unprotect(v), ct=t)
+ def convert_blob(v, **_):
+ return bytearray.fromhex(v[2:])
+
+ def convert_text(v, **_):
+ return v
+
+ def convert_uuid(v, **_):
+ return UUID(v)
+
+ def convert_bool(v, **_):
+ return True if v.lower() == self.boolean_styles[0].lower() else False
+
+ def get_convert_integer_fcn(adapter=int):
+ """
+ Return a slow and a fast integer conversion function depending on self.thousands_sep
+ """
+ if self.thousands_sep:
+ return lambda v, ct=cql_type: adapter(v.replace(self.thousands_sep, ''))
+ else:
+ return lambda v, ct=cql_type: adapter(v)
+
+ def get_convert_decimal_fcn(adapter=float):
+ """
+ Return a slow and a fast decimal conversion function depending on self.thousands_sep and self.decimal_sep
+ """
+ if self.thousands_sep and self.decimal_sep:
+ return lambda v, ct=cql_type: adapter(v.replace(self.thousands_sep, '').replace(self.decimal_sep, '.'))
+ elif self.thousands_sep:
+ return lambda v, ct=cql_type: adapter(v.replace(self.thousands_sep, ''))
+ elif self.decimal_sep:
+ return lambda v, ct=cql_type: adapter(v.replace(self.decimal_sep, '.'))
+ else:
+ return lambda v, ct=cql_type: adapter(v)
+
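A minimal sketch (with assumed separators, not values taken from the patch) of what the generated decimal converter does when THOUSANDSSEP='.' and DECIMALSEP=',':

    thousands_sep, decimal_sep = '.', ','

    def convert_decimal(v, adapter=float):
        # strip grouping separators, then swap the decimal separator for '.'
        return adapter(v.replace(thousands_sep, '').replace(decimal_sep, '.'))

    assert convert_decimal('1.234,56') == 1234.56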
def split(val, sep=','):
"""
Split into a list of values whenever we encounter a separator but
@@ -917,17 +1449,23 @@ class ImportConversion(object):
"(?:(\d{2}):(\d{2})(?::(\d{2}))?)?" + # [HH:MM[:SS]]
"(?:([+\-])(\d{2}):?(\d{2}))?") # [(+|-)HH[:]MM]]
- def convert_date(val, **_):
+ def convert_datetime(val, **_):
+ try:
+ tval = time.strptime(val, self.date_time_format)
+ return timegm(tval) * 1e3 # scale seconds to millis for the raw value
+ except ValueError:
+ pass # if it's not in the default format we try CQL formats
+
m = p.match(val)
if not m:
- raise ValueError("can't interpret %r as a date" % (val,))
+ raise ValueError("can't interpret %r as a date with this
format: %s" % (val, self.date_time_format))
# https://docs.python.org/2/library/time.html#time.struct_time
tval = time.struct_time((int(m.group(1)), int(m.group(2)), int(m.group(3)), # year, month, day
- int(m.group(4)) if m.group(4) else 0, # hour
- int(m.group(5)) if m.group(5) else 0, # minute
- int(m.group(6)) if m.group(6) else 0, # second
- 0, 1, -1)) # day of week, day of year, dst-flag
+ int(m.group(4)) if m.group(4) else 0, # hour
+ int(m.group(5)) if m.group(5) else 0, # minute
+ int(m.group(6)) if m.group(6) else 0, # second
+ 0, 1, -1)) # day of week, day of year, dst-flag
if m.group(7):
offset = (int(m.group(8)) * 3600 + int(m.group(9)) * 60) * int(m.group(7) + '1')
@@ -937,6 +1475,12 @@ class ImportConversion(object):
# scale seconds to millis for the raw value
return (timegm(tval) + offset) * 1e3
+ def convert_date(v, **_):
+ return Date(v)
+
+ def convert_time(v, **_):
+ return Time(v)
+
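The default-format branch of convert_datetime above parses with time.strptime and scales to epoch milliseconds via timegm, treating the value as UTC. A minimal standalone sketch (the format string is only an assumed example):

    import calendar
    import time

    tval = time.strptime('2016-01-20 13:30:54', '%Y-%m-%d %H:%M:%S')
    millis = calendar.timegm(tval) * 1e3  # calendar.timegm plays the role of the patch's timegm
    assert millis == 1453296654000.0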
def convert_tuple(val, ct=cql_type):
return tuple(convert(t, v) for t, v in zip(ct.subtypes, split(val)))
@@ -979,30 +1523,30 @@ class ImportConversion(object):
elif issubclass(ct, ReversedType):
return convert_single_subtype(val, ct=ct)
- self.printmsg("Unknown type %s (%s) for val %s" % (ct,
ct.typename, val))
+ self.printdebugmsg("Unknown type %s (%s) for val %s" % (ct,
ct.typename, val))
return val
converters = {
- 'blob': (lambda v, ct=cql_type: bytearray.fromhex(v[2:])),
- 'decimal': (lambda v, ct=cql_type: Decimal(v)),
- 'uuid': (lambda v, ct=cql_type: UUID(v)),
- 'boolean': (lambda v, ct=cql_type: bool(v)),
- 'tinyint': (lambda v, ct=cql_type: int(v)),
- 'ascii': (lambda v, ct=cql_type: v),
- 'float': (lambda v, ct=cql_type: float(v)),
- 'double': (lambda v, ct=cql_type: float(v)),
- 'bigint': (lambda v, ct=cql_type: long(v)),
- 'int': (lambda v, ct=cql_type: int(v)),
- 'varint': (lambda v, ct=cql_type: int(v)),
- 'inet': (lambda v, ct=cql_type: v),
- 'counter': (lambda v, ct=cql_type: long(v)),
- 'timestamp': convert_date,
- 'timeuuid': (lambda v, ct=cql_type: UUID(v)),
- 'date': (lambda v, ct=cql_type: Date(v)),
- 'smallint': (lambda v, ct=cql_type: int(v)),
- 'time': (lambda v, ct=cql_type: Time(v)),
- 'text': (lambda v, ct=cql_type: v),
- 'varchar': (lambda v, ct=cql_type: v),
+ 'blob': convert_blob,
+ 'decimal': get_convert_decimal_fcn(adapter=Decimal),
+ 'uuid': convert_uuid,
+ 'boolean': convert_bool,
+ 'tinyint': get_convert_integer_fcn(),
+ 'ascii': convert_text,
+ 'float': get_convert_decimal_fcn(),
+ 'double': get_convert_decimal_fcn(),
+ 'bigint': get_convert_integer_fcn(adapter=long),
+ 'int': get_convert_integer_fcn(),
+ 'varint': get_convert_integer_fcn(),
+ 'inet': convert_text,
+ 'counter': get_convert_integer_fcn(adapter=long),
+ 'timestamp': convert_datetime,
+ 'timeuuid': convert_uuid,
+ 'date': convert_date,
+ 'smallint': get_convert_integer_fcn(),
+ 'time': convert_time,
+ 'text': convert_text,
+ 'varchar': convert_text,
'list': convert_list,
'set': convert_set,
'map': convert_map,
@@ -1016,13 +1560,19 @@ class ImportConversion(object):
"""
Parse the row into a list of row values to be returned
"""
+ def convert(n, val):
+ try:
+ return self.converters[self.columns[n]](val)
+ except Exception, e:
+ raise ParseError(e.message)
+
ret = [None] * len(row)
for i, val in enumerate(row):
if val != self.nullval:
- ret[i] = self.converters[self.columns[i]](val)
+ ret[i] = convert(i, val)
else:
if i in self.primary_key_indexes:
- raise ValueError(self.get_null_primary_key_message(i))
+ raise ParseError(self.get_null_primary_key_message(i))
ret[i] = None
@@ -1041,10 +1591,13 @@ class ImportConversion(object):
as expected by metadata.get_replicas(), see also
BoundStatement.routing_key.
"""
def serialize(n):
- c, v = self.columns[n], row[n]
- if v == self.nullval:
- raise ValueError(self.get_null_primary_key_message(n))
- return self.cqltypes[c].serialize(self.converters[c](v), self.proto_version)
+ try:
+ c, v = self.columns[n], row[n]
+ if v == self.nullval:
+ raise ParseError(self.get_null_primary_key_message(n))
+ return self.cqltypes[c].serialize(self.converters[c](v), self.proto_version)
+ except Exception, e:
+ raise ParseError(e.message)
partition_key_indexes = self.partition_key_indexes
if len(partition_key_indexes) == 1:
@@ -1063,11 +1616,15 @@ class ImportProcess(ChildProcess):
def __init__(self, params):
ChildProcess.__init__(self, params=params, target=self.run)
- csv_options = params['csv_options']
- self.nullval = csv_options['nullval']
- self.max_attempts = csv_options['maxattempts']
- self.min_batch_size = csv_options['minbatchsize']
- self.max_batch_size = csv_options['maxbatchsize']
+ self.skip_columns = params['skip_columns']
+ self.valid_columns = params['valid_columns']
+ self.skip_column_indexes = [i for i, c in enumerate(self.columns) if c in self.skip_columns]
+
+ options = params['options']
+ self.nullval = options.copy['nullval']
+ self.max_attempts = options.copy['maxattempts']
+ self.min_batch_size = options.copy['minbatchsize']
+ self.max_batch_size = options.copy['maxbatchsize']
self._session = None
@property
@@ -1079,10 +1636,11 @@ class ImportProcess(ChildProcess):
cql_version=self.cql_version,
protocol_version=self.protocol_version,
auth_provider=self.auth_provider,
- load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
+ load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy(local_dc=self.local_dc)),
ssl_options=ssl_settings(self.hostname, self.config_file) if self.ssl else None,
default_retry_policy=ExpBackoffRetryPolicy(self),
compression=None,
+ control_connection_timeout=self.connect_timeout,
connect_timeout=self.connect_timeout)
self._session = cluster.connect(self.ks)
@@ -1091,8 +1649,8 @@ class ImportProcess(ChildProcess):
def run(self):
try:
- table_meta = self.session.cluster.metadata.keyspaces[self.ks].tables[self.cf]
- is_counter = ("counter" in [table_meta.columns[name].cql_type for name in self.columns])
+ table_meta = self.session.cluster.metadata.keyspaces[self.ks].tables[self.table]
+ is_counter = ("counter" in [table_meta.columns[name].cql_type for name in self.valid_columns])
if is_counter:
self.run_counter(table_meta)
@@ -1115,22 +1673,20 @@ class ImportProcess(ChildProcess):
"""
Main run method for tables that contain counter columns.
"""
- query = 'UPDATE %s.%s SET %%s WHERE %%s' % (protect_name(self.ks), protect_name(self.cf))
+ query = 'UPDATE %s.%s SET %%s WHERE %%s' % (protect_name(self.ks), protect_name(self.table))
# We prepare a query statement to find out the types of the partition key columns so we can
# route the update query to the correct replicas. As far as I understood this is the easiest
# way to find out the types of the partition columns, we will never use this prepared statement
where_clause = ' AND '.join(['%s = ?' % (protect_name(c.name)) for c in table_meta.partition_key])
- select_query = 'SELECT * FROM %s.%s WHERE %s' % (protect_name(self.ks), protect_name(self.cf), where_clause)
+ select_query = 'SELECT * FROM %s.%s WHERE %s' % (protect_name(self.ks), protect_name(self.table), where_clause)
conv = ImportConversion(self, table_meta, self.session.prepare(select_query))
while True:
+ batch = self.inmsg.get()
try:
- batch = self.inmsg.get()
-
- for batches in self.split_batches(batch, conv):
- for b in batches:
- self.send_counter_batch(query, conv, b)
+ for b in self.split_batches(batch, conv):
+ self.send_counter_batch(query, conv, b)
except Exception, exc:
self.outmsg.put((batch, '%s - %s' % (exc.__class__.__name__, exc.message)))
@@ -1142,19 +1698,19 @@ class ImportProcess(ChildProcess):
Main run method for normal tables, i.e. tables that do not contain counter columns.
"""
query = 'INSERT INTO %s.%s (%s) VALUES (%s)' % (protect_name(self.ks),
- protect_name(self.cf),
- ', '.join(protect_names(self.columns),),
- ', '.join(['?' for _ in self.columns]))
+ protect_name(self.table),
+ ', '.join(protect_names(self.valid_columns),),
+ ', '.join(['?' for _ in self.valid_columns]))
+
query_statement = self.session.prepare(query)
+ query_statement.consistency_level = self.consistency_level
conv = ImportConversion(self, table_meta, query_statement)
while True:
+ batch = self.inmsg.get()
try:
- batch = self.inmsg.get()
-
- for batches in self.split_batches(batch, conv):
- for b in batches:
- self.send_normal_batch(conv, query_statement, b)
+ for b in self.split_batches(batch, conv):
+ self.send_normal_batch(conv, query_statement, b)
except Exception, exc:
self.outmsg.put((batch, '%s - %s' % (exc.__class__.__name__, exc.message)))
@@ -1165,35 +1721,82 @@ class ImportProcess(ChildProcess):
if self.test_failures and self.maybe_inject_failures(batch):
return
- columns = self.columns
+ error_rows = []
batch_statement = BatchStatement(batch_type=BatchType.COUNTER, consistency_level=self.consistency_level)
- for row in batch['rows']:
+
+ for r in batch['rows']:
+ row = self.filter_row_values(r)
+ if len(row) != len(self.valid_columns):
+ error_rows.append(row)
+ continue
+
where_clause = []
set_clause = []
for i, value in enumerate(row):
if i in conv.primary_key_indexes:
- where_clause.append("%s=%s" % (columns[i], value))
+ where_clause.append("%s=%s" % (self.valid_columns[i], value))
else:
- set_clause.append("%s=%s+%s" % (columns[i], columns[i], value))
+ set_clause.append("%s=%s+%s" % (self.valid_columns[i], self.valid_columns[i], value))
full_query_text = query_text % (','.join(set_clause), ' AND '.join(where_clause))
batch_statement.add(full_query_text)
self.execute_statement(batch_statement, batch)
+ if error_rows:
+ self.outmsg.put((ImportTask.split_batch(batch, error_rows),
+ '%s - %s' % (ParseError.__name__, "Failed to parse one or more rows")))
+
def send_normal_batch(self, conv, query_statement, batch):
- try:
- if self.test_failures and self.maybe_inject_failures(batch):
- return
+ if self.test_failures and self.maybe_inject_failures(batch):
+ return
- batch_statement = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=self.consistency_level)
- for row in batch['rows']:
- batch_statement.add(query_statement, conv.get_row_values(row))
+ good_rows, converted_rows, errors = self.convert_rows(conv, batch['rows'])
- self.execute_statement(batch_statement, batch)
+ if converted_rows:
+ try:
+ statement = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=self.consistency_level)
+ for row in converted_rows:
+ statement.add(query_statement, row)
+ self.execute_statement(statement, ImportTask.split_batch(batch, good_rows))
+ except Exception, exc:
+ self.err_callback(exc, ImportTask.split_batch(batch, good_rows))
- except Exception, exc:
- self.err_callback(exc, batch)
+ if errors:
+ for msg, rows in errors.iteritems():
+ self.outmsg.put((ImportTask.split_batch(batch, rows),
+ '%s - %s' % (ParseError.__name__, msg)))
+
+ def convert_rows(self, conv, rows):
+ """
+ Try to convert each row. If conversion succeeds, add the converted result to converted_rows
+ and the original string to good_rows; otherwise record the original string in errors, keyed
+ by the error message. Return the two lists and the errors dictionary.
+ """
+ good_rows = []
+ errors = defaultdict(list)
+ converted_rows = []
+
+ for r in rows:
+ row = self.filter_row_values(r)
+ if len(row) != len(self.valid_columns):
+ msg = 'Invalid row length %d should be %d' % (len(row), len(self.valid_columns))
+ errors[msg].append(row)
+ continue
+
+ try:
+ converted_rows.append(conv.get_row_values(row))
+ good_rows.append(row)
+ except ParseError, err:
+ errors[err.message].append(row)
+
+ return good_rows, converted_rows, errors
+
+ def filter_row_values(self, row):
+ if not self.skip_column_indexes:
+ return row
+
+ return [v for i, v in enumerate(row) if i not in self.skip_column_indexes]
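
The contract of convert_rows and filter_row_values can be shown with a small self-contained toy: rows that convert cleanly are kept twice (the original strings for retries, the converted values for binding), and failures are grouped by error message so each distinct problem is reported once. Purely illustrative; the toy converter and ValueError stand in for ImportConversion and ParseError:

from collections import defaultdict

def convert_rows(rows, convert, expected_len):
    good_rows, converted_rows, errors = [], [], defaultdict(list)
    for row in rows:
        if len(row) != expected_len:
            errors['Invalid row length %d should be %d' % (len(row), expected_len)].append(row)
            continue
        try:
            converted_rows.append([convert(v) for v in row])
            good_rows.append(row)
        except ValueError, err:   # stands in for ParseError in the real code
            errors[str(err)].append(row)
    return good_rows, converted_rows, errors

good, converted, errs = convert_rows([['1', '2'], ['x', '3'], ['4']], int, 2)
# good == [['1', '2']], converted == [[1, 2]]; errs holds one length error and one conversion error
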
def maybe_inject_failures(self, batch):
"""
@@ -1225,43 +1828,66 @@ class ImportProcess(ChildProcess):
def split_batches(self, batch, conv):
"""
- Split a batch into sub-batches with the same
- partition key, if possible. If there are at least
- batch_size rows with the same partition key value then
- create a sub-batch with that partition key value, else
- aggregate all remaining rows in a single 'left-overs' batch
+ Batch rows by partition key, if there are at least min_batch_size (2)
+ rows with the same partition key. These batches can be as big as needed,
+ since this translates to a single insert operation server side.
+
+ If there are fewer than min_batch_size rows for a partition, work out the
+ first replica for this partition and add the rows to that replica's left-over rows.
+
+ Then batch the left-overs of each replica up to max_batch_size.
"""
rows_by_pk = defaultdict(list)
+ errors = defaultdict(list)
for row in batch['rows']:
- pk = conv.get_row_partition_key_values(row)
- rows_by_pk[pk].append(row)
-
- ret = dict()
- remaining_rows = []
-
- for pk, rows in rows_by_pk.items():
+ try:
+ pk = conv.get_row_partition_key_values(row)
+ rows_by_pk[pk].append(row)
+ except ParseError, e:
+ errors[e.message].append(row)
+
+ if errors:
+ for msg, rows in errors.iteritems():
+ self.outmsg.put((ImportTask.split_batch(batch, rows),
+ '%s - %s' % (ParseError.__name__, msg)))
+
+ rows_by_replica = defaultdict(list)
+ for pk, rows in rows_by_pk.iteritems():
if len(rows) >= self.min_batch_size:
- ret[pk] = self.batches(rows, batch)
+ yield ImportTask.make_batch(batch['id'], rows, batch['attempts'])
else:
- remaining_rows.extend(rows)
+ replica = self.get_replica(pk)
+ rows_by_replica[replica].extend(rows)
- if remaining_rows:
- ret[self.hostname] = self.batches(remaining_rows, batch)
+ for replica, rows in rows_by_replica.iteritems():
+ for b in self.batches(rows, batch):
+ yield b
- return ret.itervalues()
+ def get_replica(self, pk):
+ """
+ Return the first replica, or the host we are already connected to if there are no local
+ replicas that are up. We always use the first replica to match the replica chosen by the
+ driver's token-aware routing, see TokenAwarePolicy.make_query_plan().
+ """
+ metadata = self.session.cluster.metadata
+ replicas = filter(lambda r: r.is_up and r.datacenter == self.local_dc, metadata.get_replicas(self.ks, pk))
+ ret = replicas[0].address if len(replicas) > 0 else self.hostname
+ return ret
def batches(self, rows, batch):
+ """
+ Split rows into batches of max_batch_size
+ """
for i in xrange(0, len(rows), self.max_batch_size):
yield ImportTask.make_batch(batch['id'], rows[i:i + self.max_batch_size], batch['attempts'])
- def result_callback(self, result, batch):
+ def result_callback(self, _, batch):
batch['imported'] = len(batch['rows'])
- batch['rows'] = [] # no need to resend these
+ batch['rows'] = [] # no need to resend these, just send the count in 'imported'
self.outmsg.put((batch, None))
def err_callback(self, response, batch):
- batch['imported'] = len(batch['rows'])
self.outmsg.put((batch, '%s - %s' % (response.__class__.__name__, response.message)))
if self.debug:
traceback.print_exc(response)
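
Taken together, split_batches, get_replica and batches above implement the policy described in the split_batches docstring: big-enough partitions become their own batch, and everything else is pooled by first replica and chunked to max_batch_size. A condensed, self-contained sketch of that policy; get_pk and get_replica are hypothetical callables standing in for the conversion and cluster-metadata lookups in the patch:

from collections import defaultdict

def split_batches(rows, get_pk, get_replica, min_batch_size=2, max_batch_size=20):
    rows_by_pk = defaultdict(list)
    for row in rows:
        rows_by_pk[get_pk(row)].append(row)

    rows_by_replica = defaultdict(list)
    for pk, pk_rows in rows_by_pk.iteritems():
        if len(pk_rows) >= min_batch_size:
            yield pk_rows                                     # one server-side partition, any size
        else:
            rows_by_replica[get_replica(pk)].extend(pk_rows)  # pool the left-overs per replica

    for replica_rows in rows_by_replica.itervalues():
        for i in xrange(0, len(replica_rows), max_batch_size):
            yield replica_rows[i:i + max_batch_size]
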
@@ -1269,15 +1895,19 @@ class ImportProcess(ChildProcess):
class RateMeter(object):
- def __init__(self, update_interval=0.25, log=True):
- self.log = log # true if we should log
+ def __init__(self, log_fcn, update_interval=0.25, log_file=''):
+ self.log_fcn = log_fcn # the function for logging, may be None to disable logging
self.update_interval = update_interval # how often we update in seconds
+ self.log_file = log_file # an optional file in which to log statistics, in addition to stdout
self.start_time = time.time() # the start time
self.last_checkpoint_time = self.start_time # last time we logged
self.current_rate = 0.0 # rows per second
self.current_record = 0 # number of records since we last updated
self.total_records = 0 # total number of records
+ if os.path.isfile(self.log_file):
+ os.unlink(self.log_file)
+
def increment(self, n=1):
self.current_record += n
self.maybe_update()
@@ -1315,11 +1945,15 @@ class RateMeter(object):
return self.total_records / time_difference if time_difference >= 1e-09 else 0
def log_message(self):
- if self.log:
- output = 'Processed: %d rows; Rate: %7.0f rows/s; Avg. rage: %7.0f rows/s\r' % \
- (self.total_records, self.current_rate, self.get_avg_rate())
- sys.stdout.write(output)
- sys.stdout.flush()
+ if not self.log_fcn:
+ return
+
+ output = 'Processed: %d rows; Rate: %7.0f rows/s; Avg. rate: %7.0f rows/s\r' % \
+ (self.total_records, self.current_rate, self.get_avg_rate())
+ self.log_fcn(output, eol='\r')
+ if self.log_file:
+ with open(self.log_file, "a") as f:
+ f.write(output + '\n')
def get_total_records(self):
self.update(time.time())
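
For context, the reporting half of RateMeter boils down to two things: the average rate is total rows over elapsed time, and each progress line goes to the logging callback and, optionally, is appended to a log file. A stripped-down sketch under those assumptions; log_fcn and log_file mirror the constructor arguments above, while the interval-based current rate tracked by the real class is omitted:

import time

class MiniRateMeter(object):
    def __init__(self, log_fcn, log_file=''):
        self.log_fcn = log_fcn            # printing callback, None disables logging
        self.log_file = log_file          # optional file that also receives each line
        self.start_time = time.time()
        self.total_records = 0

    def increment(self, n=1):
        self.total_records += n

    def get_avg_rate(self):
        elapsed = time.time() - self.start_time
        return self.total_records / elapsed if elapsed >= 1e-09 else 0

    def log_message(self):
        if not self.log_fcn:
            return
        output = 'Processed: %d rows; Avg. rate: %7.0f rows/s' % (self.total_records, self.get_avg_rate())
        self.log_fcn(output)              # the real code passes eol='\r' to keep a single console line
        if self.log_file:
            with open(self.log_file, 'a') as f:
                f.write(output + '\n')
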