http://git-wip-us.apache.org/repos/asf/spark/blob/08488c17/python/pyspark/sql/__init__.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py new file mode 100644 index 0000000..0a5ba00 --- /dev/null +++ b/python/pyspark/sql/__init__.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +public classes of Spark SQL: + + - L{SQLContext} + Main entry point for SQL functionality. + - L{DataFrame} + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, DataFrames also support SQL. + - L{GroupedData} + - L{Column} + Column is a DataFrame with a single column. + - L{Row} + A Row of data returned by a Spark SQL query. + - L{HiveContext} + Main entry point for accessing data stored in Apache Hive.. +""" + +from pyspark.sql.context import SQLContext, HiveContext +from pyspark.sql.types import Row +from pyspark.sql.dataframe import DataFrame, GroupedData, Column, Dsl, SchemaRDD + +__all__ = [ + 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', + 'Dsl', +]
http://git-wip-us.apache.org/repos/asf/spark/blob/08488c17/python/pyspark/sql/context.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py new file mode 100644 index 0000000..49f016a --- /dev/null +++ b/python/pyspark/sql/context.py @@ -0,0 +1,642 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +import json +from array import array +from itertools import imap + +from py4j.protocol import Py4JError + +from pyspark.rdd import _prepare_for_python_RDD +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.sql.types import StringType, StructType, _verify_type, \ + _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter +from pyspark.sql.dataframe import DataFrame + +__all__ = ["SQLContext", "HiveContext"] + + +class SQLContext(object): + + """Main entry point for Spark SQL functionality. + + A SQLContext can be used create L{DataFrame}, register L{DataFrame} as + tables, execute SQL over tables, cache tables, and read parquet files. + """ + + def __init__(self, sparkContext, sqlContext=None): + """Create a new SQLContext. + + :param sparkContext: The SparkContext to wrap. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + SQLContext in the JVM, instead we make all calls to this object. + + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError:... + + >>> bad_rdd = sc.parallelize([1,2,3]) + >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + + >>> from datetime import datetime + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, + ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), + ... time=datetime(2014, 8, 1, 14, 1, 5))]) + >>> df = sqlCtx.inferSchema(allTypes) + >>> df.registerTempTable("allTypes") + >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + ... 'from allTypes where b and i > 0').collect() + [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + ... x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] + """ + self._sc = sparkContext + self._jsc = self._sc._jsc + self._jvm = self._sc._jvm + self._scala_SQLContext = sqlContext + + @property + def _ssql_ctx(self): + """Accessor for the JVM Spark SQL context. + + Subclasses can override this property to provide their own + JVM Contexts. + """ + if self._scala_SQLContext is None: + self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) + return self._scala_SQLContext + + def registerFunction(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + [Row(c0=u'4')] + >>> from pyspark.sql.types import IntegerType + >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + """ + func = lambda _, it: imap(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self) + self._ssql_ctx.udf().registerPython(name, + bytearray(pickled_cmd), + env, + includes, + self._sc.pythonExec, + bvars, + self._sc._javaAccumulator, + returnType.json()) + + def inferSchema(self, rdd, samplingRatio=None): + """Infer and apply a schema to an RDD of L{Row}. + + When samplingRatio is specified, the schema is inferred by looking + at the types of each row in the sampled dataset. Otherwise, the + first 100 rows of the RDD are inspected. Nested collections are + supported, which can include array, dict, list, Row, tuple, + namedtuple, or object. + + Each row could be L{pyspark.sql.Row} object or namedtuple or objects. + Using top level dicts is deprecated, as dict is used to represent Maps. + + If a single column has multiple distinct inferred types, it may cause + runtime exceptions. + + >>> rdd = sc.parallelize( + ... [Row(field1=1, field2="row1"), + ... Row(field1=2, field2="row2"), + ... Row(field1=3, field2="row3")]) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] + Row(field1=1, field2=u'row1') + + >>> NestedRow = Row("f1", "f2") + >>> nestedRdd1 = sc.parallelize([ + ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), + ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) + >>> df = sqlCtx.inferSchema(nestedRdd1) + >>> df.collect() + [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] + + >>> nestedRdd2 = sc.parallelize([ + ... NestedRow([[1, 2], [2, 3]], [1, 2]), + ... NestedRow([[2, 3], [3, 4]], [2, 3])]) + >>> df = sqlCtx.inferSchema(nestedRdd2) + >>> df.collect() + [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] + + >>> from collections import namedtuple + >>> CustomRow = namedtuple('CustomRow', 'field1 field2') + >>> rdd = sc.parallelize( + ... [CustomRow(field1=1, field2="row1"), + ... CustomRow(field1=2, field2="row2"), + ... CustomRow(field1=3, field2="row3")]) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] + Row(field1=1, field2=u'row1') + """ + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + first = rdd.first() + if not first: + raise ValueError("The first row in RDD is empty, " + "can not infer schema") + if type(first) is dict: + warnings.warn("Using RDD of dict to inferSchema is deprecated," + "please use pyspark.sql.Row instead") + + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + warnings.warn("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio > 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + + converter = _create_converter(schema) + rdd = rdd.map(converter) + return self.applySchema(rdd, schema) + + def applySchema(self, rdd, schema): + """ + Applies the given schema to the given RDD of L{tuple} or L{list}. + + These tuples or lists can contain complex nested structures like + lists, maps or nested rows. + + The schema should be a StructType. + + It is important that the schema matches the types of the objects + in each row or exceptions could be thrown at runtime. + + >>> from pyspark.sql.types import * + >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) + >>> schema = StructType([StructField("field1", IntegerType(), False), + ... StructField("field2", StringType(), False)]) + >>> df = sqlCtx.applySchema(rdd2, schema) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT * from table1") + >>> df2.collect() + [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] + + >>> from datetime import date, datetime + >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, + ... date(2010, 1, 1), + ... datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, (2,), [1, 2, 3], None)]) + >>> schema = StructType([ + ... StructField("byte1", ByteType(), False), + ... StructField("byte2", ByteType(), False), + ... StructField("short1", ShortType(), False), + ... StructField("short2", ShortType(), False), + ... StructField("int", IntegerType(), False), + ... StructField("float", FloatType(), False), + ... StructField("date", DateType(), False), + ... StructField("time", TimestampType(), False), + ... StructField("map", + ... MapType(StringType(), IntegerType(), False), False), + ... StructField("struct", + ... StructType([StructField("b", ShortType(), False)]), False), + ... StructField("list", ArrayType(ByteType(), False), False), + ... StructField("null", DoubleType(), True)]) + >>> df = sqlCtx.applySchema(rdd, schema) + >>> results = df.map( + ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, + ... x.time, x.map["a"], x.struct.b, x.list, x.null)) + >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE + (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), + datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + + >>> df.registerTempTable("table2") + >>> sqlCtx.sql( + ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + + ... "float + 1.5 as float FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] + + >>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type + >>> rdd = sc.parallelize([(127, -32768, 1.0, + ... datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, (2,), [1, 2, 3])]) + >>> abstract = "byte short float time map{} struct(b) list[]" + >>> schema = _parse_schema_abstract(abstract) + >>> typedSchema = _infer_schema_type(rdd.first(), schema) + >>> df = sqlCtx.applySchema(rdd, typedSchema) + >>> df.collect() + [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] + """ + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + + # take the first few rows to verify schema + rows = rdd.take(10) + # Row() cannot been deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) + + for row in rows: + _verify_type(row, schema) + + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) + + def registerRDDAsTable(self, rdd, tableName): + """Registers the given RDD as a temporary table in the catalog. + + Temporary tables exist only during the lifetime of this instance of + SQLContext. + + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + """ + if (rdd.__class__ is DataFrame): + df = rdd._jdf + self._ssql_ctx.registerRDDAsTable(df, tableName) + else: + raise ValueError("Can only register DataFrame as table") + + def parquetFile(self, *paths): + """Loads a Parquet file, returning the result as a L{DataFrame}. + + >>> import tempfile, shutil + >>> parquetFile = tempfile.mkdtemp() + >>> shutil.rmtree(parquetFile) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + gateway = self._sc._gateway + jpath = paths[0] + jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths) - 1) + for i in range(1, len(paths)): + jpaths[i] = paths[i] + jdf = self._ssql_ctx.parquetFile(jpath, jpaths) + return DataFrame(jdf, self) + + def jsonFile(self, path, schema=None, samplingRatio=1.0): + """ + Loads a text file storing one JSON object per line as a + L{DataFrame}. + + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. + + >>> import tempfile, shutil + >>> jsonFile = tempfile.mkdtemp() + >>> shutil.rmtree(jsonFile) + >>> ofn = open(jsonFile, 'w') + >>> for json in jsonStrings: + ... print>>ofn, json + >>> ofn.close() + >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table1") + >>> for r in df2.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + + >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table2") + >>> for r in df4.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", + ... ArrayType(IntegerType(), False), True)]), False)]) + >>> df5 = sqlCtx.jsonFile(jsonFile, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, " + ... "field3.field5[0] as f3 from table3") + >>> df6.collect() + [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] + """ + if schema is None: + df = self._ssql_ctx.jsonFile(path, samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonFile(path, scala_datatype) + return DataFrame(df, self) + + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): + """Loads an RDD storing one JSON object per string as a L{DataFrame}. + + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. + + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table1") + >>> for r in df2.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + + >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table2") + >>> for r in df4.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", + ... ArrayType(IntegerType(), False), True)]), False)]) + >>> df5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, " + ... "field3.field5[0] as f3 from table3") + >>> df6.collect() + [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] + + >>> sqlCtx.jsonRDD(sc.parallelize(['{}', + ... '{"key0": {"key1": "value1"}}'])).collect() + [Row(key0=None), Row(key0=Row(key1=u'value1'))] + >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', + ... '{"key0": {"key1": "value1"}}'])).collect() + [Row(key0=None), Row(key0=Row(key1=u'value1'))] + """ + + def func(iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x + keyed = rdd.mapPartitions(func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._jvm.BytesToString()) + if schema is None: + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return DataFrame(df, self) + + def sql(self, sqlQuery): + """Return a L{DataFrame} representing the result of the given query. + + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() + [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] + """ + return DataFrame(self._ssql_ctx.sql(sqlQuery), self) + + def table(self, tableName): + """Returns the specified table as a L{DataFrame}. + + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + return DataFrame(self._ssql_ctx.table(tableName), self) + + def cacheTable(self, tableName): + """Caches the specified table in-memory.""" + self._ssql_ctx.cacheTable(tableName) + + def uncacheTable(self, tableName): + """Removes the specified table from the in-memory cache.""" + self._ssql_ctx.uncacheTable(tableName) + + +class HiveContext(SQLContext): + + """A variant of Spark SQL that integrates with data stored in Hive. + + Configuration for Hive is read from hive-site.xml on the classpath. + It supports running both SQL and HiveQL commands. + """ + + def __init__(self, sparkContext, hiveContext=None): + """Create a new HiveContext. + + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + HiveContext in the JVM, instead we make all calls to this object. + """ + SQLContext.__init__(self, sparkContext) + + if hiveContext: + self._scala_HiveContext = hiveContext + + @property + def _ssql_ctx(self): + try: + if not hasattr(self, '_scala_HiveContext'): + self._scala_HiveContext = self._get_hive_ctx() + return self._scala_HiveContext + except Py4JError as e: + raise Exception("You must build Spark with Hive. " + "Export 'SPARK_HIVE=true' and run " + "build/sbt assembly", e) + + def _get_hive_ctx(self): + return self._jvm.HiveContext(self._jsc.sc()) + + +def _create_row(fields, values): + row = Row(*values) + row.__FIELDS__ = fields + return row + + +class Row(tuple): + + """ + A row in L{DataFrame}. The fields in it can be accessed like attributes. + + Row can be used to create a row object by using named arguments, + the fields will be sorted by names. + + >>> row = Row(name="Alice", age=11) + >>> row + Row(age=11, name='Alice') + >>> row.name, row.age + ('Alice', 11) + + Row also can be used to create another Row like class, then it + could be used to create Row objects, such as + + >>> Person = Row("name", "age") + >>> Person + <Row(name, age)> + >>> Person("Alice", 11) + Row(name='Alice', age=11) + """ + + def __new__(self, *args, **kwargs): + if args and kwargs: + raise ValueError("Can not use both args " + "and kwargs to create Row") + if args: + # create row class or objects + return tuple.__new__(self, args) + + elif kwargs: + # create row objects + names = sorted(kwargs.keys()) + values = tuple(kwargs[n] for n in names) + row = tuple.__new__(self, values) + row.__FIELDS__ = names + return row + + else: + raise ValueError("No args or kwargs") + + def asDict(self): + """ + Return as an dict + """ + if not hasattr(self, "__FIELDS__"): + raise TypeError("Cannot convert a Row class into dict") + return dict(zip(self.__FIELDS__, self)) + + # let obect acs like class + def __call__(self, *args): + """create new Row object""" + return _create_row(self, args) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__FIELDS__.index(item) + return self[idx] + except IndexError: + raise AttributeError(item) + + def __reduce__(self): + if hasattr(self, "__FIELDS__"): + return (_create_row, (self.__FIELDS__, tuple(self))) + else: + return tuple.__reduce__(self) + + def __repr__(self): + if hasattr(self, "__FIELDS__"): + return "Row(%s)" % ", ".join("%s=%r" % (k, v) + for k, v in zip(self.__FIELDS__, self)) + else: + return "<Row(%s)>" % ", ".join(self) + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.context + globs = pyspark.sql.context.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['rdd'] = sc.parallelize( + [Row(field1=1, field2="row1"), + Row(field1=2, field2="row2"), + Row(field1=3, field2="row3")] + ) + jsonStrings = [ + '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', + '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' + '"field6":[{"field7": "row2"}]}', + '{"field1" : null, "field2": "row3", ' + '"field3":{"field4":33, "field5": []}}' + ] + globs['jsonStrings'] = jsonStrings + globs['json'] = sc.parallelize(jsonStrings) + (failure_count, test_count) = doctest.testmod( + pyspark.sql.context, globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() http://git-wip-us.apache.org/repos/asf/spark/blob/08488c17/python/pyspark/sql/dataframe.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py new file mode 100644 index 0000000..cda704e --- /dev/null +++ b/python/pyspark/sql/dataframe.py @@ -0,0 +1,974 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import itertools +import warnings +import random +import os +from tempfile import NamedTemporaryFile +from itertools import imap + +from py4j.java_collections import ListConverter, MapConverter + +from pyspark.context import SparkContext +from pyspark.rdd import RDD, _prepare_for_python_RDD +from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ + UTF8Deserializer +from pyspark.storagelevel import StorageLevel +from pyspark.traceback_utils import SCCallSiteSync +from pyspark.sql.types import * +from pyspark.sql.types import _create_cls, _parse_datatype_json_string + + +__all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"] + + +class DataFrame(object): + + """A collection of rows that have the same columns. + + A :class:`DataFrame` is equivalent to a relational table in Spark SQL, + and can be created using various functions in :class:`SQLContext`:: + + people = sqlContext.parquetFile("...") + + Once created, it can be manipulated using the various domain-specific-language + (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. + + To select a column from the data frame, use the apply method:: + + ageCol = people.age + + Note that the :class:`Column` type can also be manipulated + through its various functions:: + + # The following creates a new column that increases everybody's age by 10. + people.age + 10 + + + A more concrete example:: + + # To create DataFrame using SQLContext + people = sqlContext.parquetFile("...") + department = sqlContext.parquetFile("...") + + people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + self._sc = sql_ctx and sql_ctx._sc + self.is_cached = False + + @property + def rdd(self): + """ + Return the content of the :class:`DataFrame` as an :class:`RDD` + of :class:`Row` s. + """ + if not hasattr(self, '_lazy_rdd'): + jrdd = self._jdf.javaToPython() + rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + schema = self.schema() + + def applySchema(it): + cls = _create_cls(schema) + return itertools.imap(cls, it) + + self._lazy_rdd = rdd.mapPartitions(applySchema) + + return self._lazy_rdd + + def toJSON(self, use_unicode=False): + """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. + + >>> df.toJSON().first() + '{"age":2,"name":"Alice"}' + """ + rdd = self._jdf.toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + + def saveAsParquetFile(self, path): + """Save the contents as a Parquet file, preserving the schema. + + Files that are written out using this method can be read back in as + a DataFrame using the L{SQLContext.parquetFile} method. + + >>> import tempfile, shutil + >>> parquetFile = tempfile.mkdtemp() + >>> shutil.rmtree(parquetFile) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df2.collect()) == sorted(df.collect()) + True + """ + self._jdf.saveAsParquetFile(path) + + def registerTempTable(self, name): + """Registers this RDD as a temporary table using the given name. + + The lifetime of this temporary table is tied to the L{SQLContext} + that was used to create this DataFrame. + + >>> df.registerTempTable("people") + >>> df2 = sqlCtx.sql("select * from people") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + self._jdf.registerTempTable(name) + + def registerAsTable(self, name): + """DEPRECATED: use registerTempTable() instead""" + warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + self.registerTempTable(name) + + def insertInto(self, tableName, overwrite=False): + """Inserts the contents of this DataFrame into the specified table. + + Optionally overwriting any existing data. + """ + self._jdf.insertInto(tableName, overwrite) + + def saveAsTable(self, tableName): + """Creates a new table with the contents of this DataFrame.""" + self._jdf.saveAsTable(tableName) + + def schema(self): + """Returns the schema of this DataFrame (represented by + a L{StructType}). + + >>> df.schema() + StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) + """ + return _parse_datatype_json_string(self._jdf.schema().json()) + + def printSchema(self): + """Prints out the schema in the tree format. + + >>> df.printSchema() + root + |-- age: integer (nullable = true) + |-- name: string (nullable = true) + <BLANKLINE> + """ + print (self._jdf.schema().treeString()) + + def count(self): + """Return the number of elements in this RDD. + + Unlike the base RDD implementation of count, this implementation + leverages the query optimizer to compute the count on the DataFrame, + which supports features such as filter pushdown. + + >>> df.count() + 2L + """ + return self._jdf.count() + + def collect(self): + """Return a list that contains all of the rows. + + Each object in the list is a Row, the fields can be accessed as + attributes. + + >>> df.collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + with SCCallSiteSync(self._sc) as css: + bytesInJava = self._jdf.javaToPython().collect().iterator() + tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) + tempFile.close() + self._sc._writeToFile(bytesInJava, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) + os.unlink(tempFile.name) + cls = _create_cls(self.schema()) + return [cls(r) for r in rs] + + def limit(self, num): + """Limit the result count to the number specified. + + >>> df.limit(1).collect() + [Row(age=2, name=u'Alice')] + >>> df.limit(0).collect() + [] + """ + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) + + def take(self, num): + """Take the first num rows of the RDD. + + Each object in the list is a Row, the fields can be accessed as + attributes. + + >>> df.take(2) + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + return self.limit(num).collect() + + def map(self, f): + """ Return a new RDD by applying a function to each Row, it's a + shorthand for df.rdd.map() + + >>> df.map(lambda p: p.name).collect() + [u'Alice', u'Bob'] + """ + return self.rdd.map(f) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(iterator): yield 1 + >>> rdd.mapPartitions(f).sum() + 4 + """ + return self.rdd.mapPartitions(f, preservesPartitioning) + + def cache(self): + """ Persist with the default storage level (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self._jdf.cache() + return self + + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + """ Set the storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdf.persist(javaStorageLevel) + return self + + def unpersist(self, blocking=True): + """ Mark it as non-persistent, and remove all blocks for it from + memory and disk. + """ + self.is_cached = False + self._jdf.unpersist(blocking) + return self + + # def coalesce(self, numPartitions, shuffle=False): + # rdd = self._jdf.coalesce(numPartitions, shuffle, None) + # return DataFrame(rdd, self.sql_ctx) + + def repartition(self, numPartitions): + """ Return a new :class:`DataFrame` that has exactly `numPartitions` + partitions. + """ + rdd = self._jdf.repartition(numPartitions, None) + return DataFrame(rdd, self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """ + Return a sampled subset of this DataFrame. + + >>> df.sample(False, 0.5, 97).count() + 1L + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxint) + rdd = self._jdf.sample(withReplacement, fraction, long(seed)) + return DataFrame(rdd, self.sql_ctx) + + # def takeSample(self, withReplacement, num, seed=None): + # """Return a fixed-size sampled subset of this DataFrame. + # + # >>> df = sqlCtx.inferSchema(rdd) + # >>> df.takeSample(False, 2, 97) + # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + # """ + # seed = seed if seed is not None else random.randint(0, sys.maxint) + # with SCCallSiteSync(self.context) as css: + # bytesInJava = self._jdf \ + # .takeSampleToPython(withReplacement, num, long(seed)) \ + # .iterator() + # cls = _create_cls(self.schema()) + # return map(cls, self._collect_iterator_through_file(bytesInJava)) + + @property + def dtypes(self): + """Return all column names and their data types as a list. + + >>> df.dtypes + [('age', 'integer'), ('name', 'string')] + """ + return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields] + + @property + def columns(self): + """ Return all column names as a list. + + >>> df.columns + [u'age', u'name'] + """ + return [f.name for f in self.schema().fields] + + def join(self, other, joinExprs=None, joinType=None): + """ + Join with another DataFrame, using the given join expression. + The following performs a full outer join between `df1` and `df2`:: + + :param other: Right side of the join + :param joinExprs: Join expression + :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + + >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() + [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + """ + + if joinExprs is None: + jdf = self._jdf.join(other._jdf) + else: + assert isinstance(joinExprs, Column), "joinExprs should be Column" + if joinType is None: + jdf = self._jdf.join(other._jdf, joinExprs._jc) + else: + assert isinstance(joinType, basestring), "joinType should be basestring" + jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) + return DataFrame(jdf, self.sql_ctx) + + def sort(self, *cols): + """ Return a new :class:`DataFrame` sorted by the specified column. + + :param cols: The columns or expressions used for sorting + + >>> df.sort(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.sortBy(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + """ + if not cols: + raise ValueError("should sort by at least one column") + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + sortBy = sort + + def head(self, n=None): + """ Return the first `n` rows or the first row if n is None. + + >>> df.head() + Row(age=2, name=u'Alice') + >>> df.head(1) + [Row(age=2, name=u'Alice')] + """ + if n is None: + rs = self.head(1) + return rs[0] if rs else None + return self.take(n) + + def first(self): + """ Return the first row. + + >>> df.first() + Row(age=2, name=u'Alice') + """ + return self.head() + + def __getitem__(self, item): + """ Return the column by given name + + >>> df['age'].collect() + [Row(age=2), Row(age=5)] + >>> df[ ["name", "age"]].collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df[ df.age > 3 ].collect() + [Row(age=5, name=u'Bob')] + """ + if isinstance(item, basestring): + jc = self._jdf.apply(item) + return Column(jc, self.sql_ctx) + elif isinstance(item, Column): + return self.filter(item) + elif isinstance(item, list): + return self.select(*item) + else: + raise IndexError("unexpected index: %s" % item) + + def __getattr__(self, name): + """ Return the column by given name + + >>> df.age.collect() + [Row(age=2), Row(age=5)] + """ + if name.startswith("__"): + raise AttributeError(name) + jc = self._jdf.apply(name) + return Column(jc, self.sql_ctx) + + def select(self, *cols): + """ Selecting a set of expressions. + + >>> df.select().collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('*').collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('name', 'age').collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df.select(df.name, (df.age + 10).alias('age')).collect() + [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] + """ + if not cols: + cols = ["*"] + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + def selectExpr(self, *expr): + """ + Selects a set of SQL expressions. This is a variant of + `select` that accepts SQL expressions. + + >>> df.selectExpr("age * 2", "abs(age)").collect() + [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)] + """ + jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) + jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) + return DataFrame(jdf, self.sql_ctx) + + def filter(self, condition): + """ Filtering rows using the given condition, which could be + Column expression or string of SQL expression. + + where() is an alias for filter(). + + >>> df.filter(df.age > 3).collect() + [Row(age=5, name=u'Bob')] + >>> df.where(df.age == 2).collect() + [Row(age=2, name=u'Alice')] + + >>> df.filter("age > 3").collect() + [Row(age=5, name=u'Bob')] + >>> df.where("age = 2").collect() + [Row(age=2, name=u'Alice')] + """ + if isinstance(condition, basestring): + jdf = self._jdf.filter(condition) + elif isinstance(condition, Column): + jdf = self._jdf.filter(condition._jc) + else: + raise TypeError("condition should be string or Column") + return DataFrame(jdf, self.sql_ctx) + + where = filter + + def groupBy(self, *cols): + """ Group the :class:`DataFrame` using the specified columns, + so we can run aggregation on them. See :class:`GroupedData` + for all the available aggregate functions. + + >>> df.groupBy().avg().collect() + [Row(AVG(age#0)=3.5)] + >>> df.groupBy('name').agg({'age': 'mean'}).collect() + [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] + >>> df.groupBy(df.name).avg().collect() + [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] + """ + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return GroupedData(jdf, self.sql_ctx) + + def agg(self, *exprs): + """ Aggregate on the entire :class:`DataFrame` without groups + (shorthand for df.groupBy.agg()). + + >>> df.agg({"age": "max"}).collect() + [Row(MAX(age#0)=5)] + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.min(df.age)).collect() + [Row(MIN(age#0)=2)] + """ + return self.groupBy().agg(*exprs) + + def unionAll(self, other): + """ Return a new DataFrame containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. + """ + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + + def intersect(self, other): + """ Return a new :class:`DataFrame` containing rows only in + both this frame and another frame. + + This is equivalent to `INTERSECT` in SQL. + """ + return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + + def subtract(self, other): + """ Return a new :class:`DataFrame` containing rows in this frame + but not in another frame. + + This is equivalent to `EXCEPT` in SQL. + """ + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + + def addColumn(self, colName, col): + """ Return a new :class:`DataFrame` by adding a column. + + >>> df.addColumn('age2', df.age + 2).collect() + [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ + return self.select('*', col.alias(colName)) + + def to_pandas(self): + """ + Collect all the rows and return a `pandas.DataFrame`. + + >>> df.to_pandas() # doctest: +SKIP + age name + 0 2 Alice + 1 5 Bob + """ + import pandas as pd + return pd.DataFrame.from_records(self.collect(), columns=self.columns) + + +# Having SchemaRDD for backward compatibility (for docs) +class SchemaRDD(DataFrame): + """ + SchemaRDD is deprecated, please use DataFrame + """ + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedData(object): + + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by DataFrame.groupBy(). + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + def agg(self, *exprs): + """ Compute aggregates by specifying a map from column name + to aggregate methods. + + The available aggregate methods are `avg`, `max`, `min`, + `sum`, `count`. + + :param exprs: list or aggregate columns or a map from column + name to aggregate methods. + + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"age": "max"}).collect() + [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)] + >>> from pyspark.sql import Dsl + >>> gdf.agg(Dsl.min(df.age)).collect() + [Row(MIN(age#0)=5), Row(MIN(age#0)=2)] + """ + assert exprs, "exprs should not be empty" + if len(exprs) == 1 and isinstance(exprs[0], dict): + jmap = MapConverter().convert(exprs[0], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(jmap) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" + jcols = ListConverter().convert([c._jc for c in exprs[1:]], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + def count(self): + """ Count the number of rows for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ + + @dfapi + def mean(self): + """Compute the average value for each numeric columns + for each group. This is an alias for `avg`.""" + + @dfapi + def avg(self): + """Compute the average value for each numeric columns + for each group.""" + + @dfapi + def max(self): + """Compute the max value for each numeric columns for + each group. """ + + @dfapi + def min(self): + """Compute the min value for each numeric column for + each group.""" + + @dfapi + def sum(self): + """Compute the sum for each numeric columns for each + group.""" + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.Dsl.lit(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.Dsl.col(name) + + +def _to_java_column(col): + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + return jcol + + +def _unary_op(name, doc="unary operator"): + """ Create a method for given unary operator """ + def _(self): + jc = getattr(self._jc, name)() + return Column(jc, self.sql_ctx) + _.__doc__ = doc + return _ + + +def _dsl_op(name, doc=''): + def _(self): + jc = getattr(self._sc._jvm.Dsl, name)(self._jc) + return Column(jc, self.sql_ctx) + _.__doc__ = doc + return _ + + +def _bin_op(name, doc="binary operator"): + """ Create a method for given binary operator + """ + def _(self, other): + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, name)(jc) + return Column(njc, self.sql_ctx) + _.__doc__ = doc + return _ + + +def _reverse_op(name, doc="binary operator"): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + jother = _create_column_from_literal(other) + jc = getattr(jother, name)(self._jc) + return Column(jc, self.sql_ctx) + _.__doc__ = doc + return _ + + +class Column(DataFrame): + + """ + A column in a DataFrame. + + `Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + df.colName + df["colName"] + + # 2. Create from an expression + df.colName + 1 + 1 / df.colName + """ + + def __init__(self, jc, sql_ctx=None): + self._jc = jc + super(Column, self).__init__(jc, sql_ctx) + + # arithmetic operators + __neg__ = _dsl_op("negate") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __mod__ = _bin_op("mod") + __radd__ = _bin_op("plus") + __rsub__ = _reverse_op("minus") + __rmul__ = _bin_op("multiply") + __rdiv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") + + # logistic operators + __eq__ = _bin_op("equalTo") + __ne__ = _bin_op("notEqual") + __lt__ = _bin_op("lt") + __le__ = _bin_op("leq") + __ge__ = _bin_op("geq") + __gt__ = _bin_op("gt") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('and') + __or__ = _bin_op('or') + __invert__ = _dsl_op('not') + __rand__ = _bin_op("and") + __ror__ = _bin_op("or") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("getItem") + getField = _bin_op("getField", "An expression that gets a field by name in a StructField.") + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + + def substr(self, startPos, length): + """ + Return a Column which is a substring of the column + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.name.substr(1, 3).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + jc = self._jc.substr(startPos, length) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, length._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc, self.sql_ctx) + + __getslice__ = substr + + # order + asc = _unary_op("asc") + desc = _unary_op("desc") + + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + + def alias(self, alias): + """Return a alias for this column + + >>> df.age.alias("age2").collect() + [Row(age2=2), Row(age2=5)] + """ + return Column(getattr(self._jc, "as")(alias), self.sql_ctx) + + def cast(self, dataType): + """ Convert the column into type `dataType` + + >>> df.select(df.age.cast("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ + if self.sql_ctx is None: + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + else: + ssql_ctx = self.sql_ctx._ssql_ctx + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + return Column(jc, self.sql_ctx) + + def to_pandas(self): + """ + Return a pandas.Series from the column + + >>> df.age.to_pandas() # doctest: +SKIP + 0 2 + 1 5 + dtype: int64 + """ + import pandas as pd + data = [c for c, in self.collect()] + return pd.Series(data) + + +def _aggregate_func(name, doc=""): + """ Create a function for aggregator by name""" + def _(col): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col)) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return staticmethod(_) + + +class UserDefinedFunction(object): + def __init__(self, func, returnType): + self.func = func + self.returnType = returnType + self._broadcast = None + self._judf = self._create_judf() + + def _create_judf(self): + f = self.func # put it in closure `func` + func = lambda _, it: imap(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + sc = SparkContext._active_spark_context + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(self.returnType.json()) + judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, + includes, sc.pythonExec, broadcast_vars, + sc._javaAccumulator, jdt) + return judf + + def __del__(self): + if self._broadcast is not None: + self._broadcast.unpersist() + self._broadcast = None + + def __call__(self, *cols): + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + sc._gateway._gateway_client) + jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + +class Dsl(object): + """ + A collections of builtin aggregators + """ + DSLS = { + 'lit': 'Creates a :class:`Column` of literal value.', + 'col': 'Returns a :class:`Column` based on the given column name.', + 'column': 'Returns a :class:`Column` based on the given column name.', + 'upper': 'Converts a string expression to upper case.', + 'lower': 'Converts a string expression to upper case.', + 'sqrt': 'Computes the square root of the specified float value.', + 'abs': 'Computes the absolutle value.', + + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', + 'min': 'Aggregate function: returns the minimum value of the expression in a group.', + 'first': 'Aggregate function: returns the first value in a group.', + 'last': 'Aggregate function: returns the last value in a group.', + 'count': 'Aggregate function: returns the number of items in a group.', + 'sum': 'Aggregate function: returns the sum of all values in the expression.', + 'avg': 'Aggregate function: returns the average of the values in a group.', + 'mean': 'Aggregate function: returns the average of the values in a group.', + 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', + } + + for _name, _doc in DSLS.items(): + locals()[_name] = _aggregate_func(_name, _doc) + del _name, _doc + + @staticmethod + def countDistinct(col, *cols): + """ Return a new Column for distinct count of (col, *cols) + + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect() + [Row(c=2)] + + >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + sc._gateway._gateway_client) + jc = sc._jvm.Dsl.countDistinct(_to_java_column(col), + sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + @staticmethod + def approxCountDistinct(col, rsd=None): + """ Return a new Column for approxiate distinct count of (col, *cols) + + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + if rsd is None: + jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col)) + else: + jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd) + return Column(jc) + + @staticmethod + def udf(f, returnType=StringType()): + """Create a user defined function (UDF) + + >>> slen = Dsl.udf(lambda s: len(s), IntegerType()) + >>> df.select(slen(df.name).alias('slen')).collect() + [Row(slen=5), Row(slen=3)] + """ + return UserDefinedFunction(f, returnType) + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.dataframe + globs = pyspark.sql.dataframe.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx = SQLContext(sc) + rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]) + rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]) + globs['df'] = sqlCtx.inferSchema(rdd2) + globs['df2'] = sqlCtx.inferSchema(rdd3) + (failure_count, test_count) = doctest.testmod( + pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() http://git-wip-us.apache.org/repos/asf/spark/blob/08488c17/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py new file mode 100644 index 0000000..d25c636 --- /dev/null +++ b/python/pyspark/sql/tests.py @@ -0,0 +1,300 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for pyspark.sql; additional tests are implemented as doctests in +individual modules. +""" +import os +import sys +import pydoc +import shutil +import tempfile + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + + +from pyspark.sql import SQLContext, Column +from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType, LongType +from pyspark.tests import ReusedPySparkTestCase + + +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + +class SQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sqlCtx = SQLContext(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = cls.sc.parallelize(cls.testData) + cls.df = cls.sqlCtx.inferSchema(rdd) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(4, res[0]) + + def test_udf_with_array_type(self): + d = [Row(l=range(3), d={"key": range(5)})] + rdd = self.sc.parallelize(d) + self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(range(3), l1) + self.assertEqual(1, l2) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + def test_basic_functions(self): + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema() + + # cache and checkpoint + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() + + def test_apply_schema_to_row(self): + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) + self.assertEqual(df.collect(), df2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = self.sqlCtx.applySchema(rdd, df.schema()) + self.assertEqual(10, df3.count()) + + def test_serialize_nested_array_and_map(self): + d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.inferSchema(rdd) + row = df.head() + self.assertEqual(1, len(row.l)) + self.assertEqual(1, row.l[0].a) + self.assertEqual("2", row.d["key"].d) + + l = df.map(lambda x: x.l).first() + self.assertEqual(1, len(l)) + self.assertEqual('s', l[0].b) + + d = df.map(lambda x: x.d).first() + self.assertEqual(1, len(d)) + self.assertEqual(1.0, d["key"].c) + + row = df.map(lambda x: x.d["key"]).first() + self.assertEqual(1.0, row.c) + self.assertEqual("2", row.d) + + def test_infer_schema(self): + d = [Row(l=[], d={}), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + df.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + df2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(df.schema(), df2.schema()) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + def test_struct_in_map(self): + d = [Row(m={Row(i=1): Row(s="")})] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.inferSchema(rdd) + k, v = df.head().m.items()[0] + self.assertEqual(1, k.i) + self.assertEqual("", v.s) + + def test_convert_row_to_dict(self): + row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) + self.assertEqual(1, row.asDict()['l'][0].a) + rdd = self.sc.parallelize([row]) + df = self.sqlCtx.inferSchema(rdd) + df.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").head() + self.assertEqual(1, row.asDict()["l"][0].a) + self.assertEqual(1.0, row.asDict()['d']['key'].c) + + def test_infer_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + df = self.sqlCtx.inferSchema(rdd) + schema = df.schema() + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df = self.sqlCtx.applySchema(rdd, schema) + point = df.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.sql.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + df0 = self.sqlCtx.inferSchema(rdd) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + df0.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_column_operators(self): + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbool = (ci & ci), (ci | ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbool)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + + from pyspark.sql import Dsl + self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first())) + self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0]) + self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0]) + + def test_help_command(self): + # Regression test for SPARK-5464 + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(df) + pydoc.render_doc(df.foo) + pydoc.render_doc(df.take(1)) + + +if __name__ == "__main__": + unittest.main() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org