http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala
new file mode 100644
index 0000000..0b0867f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.ui
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.ui.{SparkUI, SparkUITab}
+
+private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI)
+  extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging {
+
+  val parent = sparkUI
+  val listener = sqlContext.listener
+
+  attachPage(new AllExecutionsPage(this))
+  attachPage(new ExecutionPage(this))
+  parent.attachTab(this)
+
+  parent.addStaticHandler(SQLTab.STATIC_RESOURCE_DIR, "/static/sql")
+}
+
+private[sql] object SQLTab {
+
+  private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static"
+
+  private val nextTabId = new AtomicInteger(0)
+
+  private def nextTabName: String = {
+    val nextId = nextTabId.getAndIncrement()
+    if (nextId == 0) "SQL" else s"SQL$nextId"
+  }
+}
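The companion object's nextTabName keeps tab names unique when several SQLContexts attach a tab to the same SparkUI: the first tab is named "SQL", later ones "SQL1", "SQL2", and so on. A minimal, self-contained sketch of the same pattern, with TabNamer standing in for the SQLTab companion:

import java.util.concurrent.atomic.AtomicInteger

object TabNamer {
  private val nextTabId = new AtomicInteger(0)

  // getAndIncrement is atomic, so concurrent callers still receive unique ids.
  def nextTabName: String = {
    val nextId = nextTabId.getAndIncrement()
    if (nextId == 0) "SQL" else s"SQL$nextId"
  }
}

// Seq.fill re-evaluates its by-name argument, so this yields
// Seq("SQL", "SQL1", "SQL2"):
//   Seq.fill(3)(TabNamer.nextTabName)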
http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
new file mode 100644
index 0000000..ae3d752
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.ui
+
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue}
+
+/**
+ * A graph used for storing information about the execution plan of a DataFrame.
+ *
+ * Each graph is defined with a set of nodes and a set of edges. Each node represents a node in the
+ * SparkPlan tree, and each edge represents a parent-child relationship between two nodes.
+ */
+private[ui] case class SparkPlanGraph(
+    nodes: Seq[SparkPlanGraphNode], edges: Seq[SparkPlanGraphEdge]) {
+
+  def makeDotFile(metrics: Map[Long, Any]): String = {
+    val dotFile = new StringBuilder
+    dotFile.append("digraph G {\n")
+    nodes.foreach(node => dotFile.append(node.makeDotNode(metrics) + "\n"))
+    edges.foreach(edge => dotFile.append(edge.makeDotEdge + "\n"))
+    dotFile.append("}")
+    dotFile.toString()
+  }
+}
+
+private[sql] object SparkPlanGraph {
+
+  /**
+   * Build a SparkPlanGraph from the root of a SparkPlan tree.
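makeDotFile above renders the graph as Graphviz dot markup for the web UI, and apply, just below, assigns node ids in visit order and wires each edge from a child's id to its parent's. A hand-assembled two-node example (plan description strings are hypothetical, no metrics attached):

val graph = SparkPlanGraph(
  nodes = Seq(
    SparkPlanGraphNode(0, "Project", "Project [name#0]", Nil),
    SparkPlanGraphNode(1, "Filter", "Filter (age#1 > 21)", Nil)),
  edges = Seq(SparkPlanGraphEdge(fromId = 1, toId = 0)))

graph.makeDotFile(Map.empty)
// Produces, roughly:
//   digraph G {
//    0 [label="Project"];
//    1 [label="Filter"];
//    1->0;
//   }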
+ */ + def apply(plan: SparkPlan): SparkPlanGraph = { + val nodeIdGenerator = new AtomicLong(0) + val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() + val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() + buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) + new SparkPlanGraph(nodes, edges) + } + + private def buildSparkPlanGraphNode( + plan: SparkPlan, + nodeIdGenerator: AtomicLong, + nodes: mutable.ArrayBuffer[SparkPlanGraphNode], + edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { + val metrics = plan.metrics.toSeq.map { case (key, metric) => + SQLPlanMetric(metric.name.getOrElse(key), metric.id, + metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) + } + val node = SparkPlanGraphNode( + nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) + nodes += node + val childrenNodes = plan.children.map( + child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) + for (child <- childrenNodes) { + edges += SparkPlanGraphEdge(child.id, node.id) + } + node + } +} + +/** + * Represent a node in the SparkPlan tree, along with its metrics. + * + * @param id generated by "SparkPlanGraph". There is no duplicate id in a graph + * @param name the name of this SparkPlan node + * @param metrics metrics that this SparkPlan node will track + */ +private[ui] case class SparkPlanGraphNode( + id: Long, name: String, desc: String, metrics: Seq[SQLPlanMetric]) { + + def makeDotNode(metricsValue: Map[Long, Any]): String = { + val values = { + for (metric <- metrics; + value <- metricsValue.get(metric.accumulatorId)) yield { + metric.name + ": " + value + } + } + val label = if (values.isEmpty) { + name + } else { + // If there are metrics, display all metrics in a separate line. We should use an escaped + // "\n" here to follow the dot syntax. + // + // Note: whitespace between two "\n"s is to create an empty line between the name of + // SparkPlan and metrics. If removing it, it won't display the empty line in UI. + name + "\\n \\n" + values.mkString("\\n") + } + s""" $id [label="$label"];""" + } +} + +/** + * Represent an edge in the SparkPlan tree. `fromId` is the parent node id, and `toId` is the child + * node id. + */ +private[ui] case class SparkPlanGraphEdge(fromId: Long, toId: Long) { + + def makeDotEdge: String = s""" $fromId->$toId;\n""" +} http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala deleted file mode 100644 index 3cf70db..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ /dev/null @@ -1,490 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} -import java.util.Properties - -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} - -/** - * Data corresponding to one partition of a JDBCRDD. - */ -private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition { - override def index: Int = idx -} - - -private[sql] object JDBCRDD extends Logging { - - /** - * Maps a JDBC type to a Catalyst type. This function is called only when - * the JdbcDialect class corresponding to your database driver returns null. - * - * @param sqlType - A field of java.sql.Types - * @return The Catalyst type corresponding to sqlType. - */ - private def getCatalystType( - sqlType: Int, - precision: Int, - scale: Int, - signed: Boolean): DataType = { - val answer = sqlType match { - // scalastyle:off - case java.sql.Types.ARRAY => null - case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) } - case java.sql.Types.BINARY => BinaryType - case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks - case java.sql.Types.BLOB => BinaryType - case java.sql.Types.BOOLEAN => BooleanType - case java.sql.Types.CHAR => StringType - case java.sql.Types.CLOB => StringType - case java.sql.Types.DATALINK => null - case java.sql.Types.DATE => DateType - case java.sql.Types.DECIMAL - if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale) - case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT - case java.sql.Types.DISTINCT => null - case java.sql.Types.DOUBLE => DoubleType - case java.sql.Types.FLOAT => FloatType - case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType } - case java.sql.Types.JAVA_OBJECT => null - case java.sql.Types.LONGNVARCHAR => StringType - case java.sql.Types.LONGVARBINARY => BinaryType - case java.sql.Types.LONGVARCHAR => StringType - case java.sql.Types.NCHAR => StringType - case java.sql.Types.NCLOB => StringType - case java.sql.Types.NULL => null - case java.sql.Types.NUMERIC - if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale) - case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT - case java.sql.Types.NVARCHAR => StringType - case java.sql.Types.OTHER => null - case java.sql.Types.REAL => DoubleType - case java.sql.Types.REF => StringType - case java.sql.Types.ROWID => LongType - case java.sql.Types.SMALLINT => IntegerType - case java.sql.Types.SQLXML => StringType - case java.sql.Types.STRUCT => StringType - case java.sql.Types.TIME => TimestampType - case java.sql.Types.TIMESTAMP => TimestampType - case java.sql.Types.TINYINT => IntegerType - case java.sql.Types.VARBINARY => BinaryType - case java.sql.Types.VARCHAR => StringType - case _ => null - // scalastyle:on - } - - if (answer == null) throw new SQLException("Unsupported type " + sqlType) - answer - } - - /** - * Takes a (schema, table) specification and returns the table's Catalyst - * schema. 
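One subtlety in the getCatalystType mapping above is signedness: an unsigned BIGINT does not fit in Catalyst's signed LongType, so it maps to DecimalType(20, 0), and an unsigned INTEGER is widened to LongType. That corner in isolation, as a sketch rather than the full mapping:

import java.sql.{SQLException, Types}
import org.apache.spark.sql.types._

// Signed/unsigned widening only; the real mapping above covers all of java.sql.Types.
def integralCatalystType(sqlType: Int, signed: Boolean): DataType = sqlType match {
  case Types.BIGINT  => if (signed) LongType else DecimalType(20, 0) // unsigned 64-bit needs 20 digits
  case Types.INTEGER => if (signed) IntegerType else LongType        // unsigned 32-bit fits in a Long
  case _             => throw new SQLException("Unsupported type " + sqlType)
}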
- * - * @param url - The JDBC url to fetch information from. - * @param table - The table name of the desired table. This may also be a - * SQL query wrapped in parentheses. - * - * @return A StructType giving the table's Catalyst schema. - * @throws SQLException if the table specification is garbage. - * @throws SQLException if the table contains an unsupported type. - */ - def resolveTable(url: String, table: String, properties: Properties): StructType = { - val dialect = JdbcDialects.get(url) - val conn: Connection = DriverManager.getConnection(url, properties) - try { - val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() - try { - val rsmd = rs.getMetaData - val ncols = rsmd.getColumnCount - val fields = new Array[StructField](ncols) - var i = 0 - while (i < ncols) { - val columnName = rsmd.getColumnLabel(i + 1) - val dataType = rsmd.getColumnType(i + 1) - val typeName = rsmd.getColumnTypeName(i + 1) - val fieldSize = rsmd.getPrecision(i + 1) - val fieldScale = rsmd.getScale(i + 1) - val isSigned = rsmd.isSigned(i + 1) - val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls - val metadata = new MetadataBuilder().putString("name", columnName) - val columnType = - dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( - getCatalystType(dataType, fieldSize, fieldScale, isSigned)) - fields(i) = StructField(columnName, columnType, nullable, metadata.build()) - i = i + 1 - } - return new StructType(fields) - } finally { - rs.close() - } - } finally { - conn.close() - } - - throw new RuntimeException("This line is unreachable.") - } - - /** - * Prune all but the specified columns from the specified Catalyst schema. - * - * @param schema - The Catalyst schema of the master table - * @param columns - The list of desired columns - * - * @return A Catalyst schema corresponding to columns in the given order. - */ - private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { - val fieldMap = Map(schema.fields map { x => x.metadata.getString("name") -> x }: _*) - new StructType(columns map { name => fieldMap(name) }) - } - - /** - * Given a driver string and an url, return a function that loads the - * specified driver string then returns a connection to the JDBC url. - * getConnector is run on the driver code, while the function it returns - * is run on the executor. - * - * @param driver - The class name of the JDBC driver for the given url. - * @param url - The JDBC url to connect to. - * - * @return A function that loads the driver and connects to the url. - */ - def getConnector(driver: String, url: String, properties: Properties): () => Connection = { - () => { - try { - if (driver != null) DriverRegistry.register(driver) - } catch { - case e: ClassNotFoundException => { - logWarning(s"Couldn't find class $driver", e); - } - } - DriverManager.getConnection(url, properties) - } - } - - /** - * Build and return JDBCRDD from the given information. - * - * @param sc - Your SparkContext. - * @param schema - The Catalyst schema of the underlying database table. - * @param driver - The class name of the JDBC driver for the given url. - * @param url - The JDBC url to connect to. - * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use. - * @param requiredColumns - The names of the columns to SELECT. - * @param filters - The filters to include in all WHERE clauses. - * @param parts - An array of JDBCPartitions specifying partition ids and - * per-partition WHERE clauses. 
- * - * @return An RDD representing "SELECT requiredColumns FROM fqTable". - */ - def scanTable( - sc: SparkContext, - schema: StructType, - driver: String, - url: String, - properties: Properties, - fqTable: String, - requiredColumns: Array[String], - filters: Array[Filter], - parts: Array[Partition]): RDD[InternalRow] = { - val dialect = JdbcDialects.get(url) - val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) - new JDBCRDD( - sc, - getConnector(driver, url, properties), - pruneSchema(schema, requiredColumns), - fqTable, - quotedColumns, - filters, - parts, - properties) - } -} - -/** - * An RDD representing a table in a database accessed via JDBC. Both the - * driver code and the workers must be able to access the database; the driver - * needs to fetch the schema while the workers need to fetch the data. - */ -private[sql] class JDBCRDD( - sc: SparkContext, - getConnection: () => Connection, - schema: StructType, - fqTable: String, - columns: Array[String], - filters: Array[Filter], - partitions: Array[Partition], - properties: Properties) - extends RDD[InternalRow](sc, Nil) { - - /** - * Retrieve the list of partitions corresponding to this RDD. - */ - override def getPartitions: Array[Partition] = partitions - - /** - * `columns`, but as a String suitable for injection into a SQL query. - */ - private val columnList: String = { - val sb = new StringBuilder() - columns.foreach(x => sb.append(",").append(x)) - if (sb.length == 0) "1" else sb.substring(1) - } - - /** - * Converts value to SQL expression. - */ - private def compileValue(value: Any): Any = value match { - case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'" - case _ => value - } - - private def escapeSql(value: String): String = - if (value == null) null else StringUtils.replace(value, "'", "''") - - /** - * Turns a single Filter into a String representing a SQL expression. - * Returns null for an unhandled filter. - */ - private def compileFilter(f: Filter): String = f match { - case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" - case LessThan(attr, value) => s"$attr < ${compileValue(value)}" - case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" - case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" - case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" - case _ => null - } - - /** - * `filters`, but as a WHERE clause suitable for injection into a SQL query. - */ - private val filterWhereClause: String = { - val filterStrings = filters map compileFilter filter (_ != null) - if (filterStrings.size > 0) { - val sb = new StringBuilder("WHERE ") - filterStrings.foreach(x => sb.append(x).append(" AND ")) - sb.substring(0, sb.length - 5) - } else "" - } - - /** - * A WHERE clause representing both `filters`, if any, and the current partition. - */ - private def getWhereClause(part: JDBCPartition): String = { - if (part.whereClause != null && filterWhereClause.length > 0) { - filterWhereClause + " AND " + part.whereClause - } else if (part.whereClause != null) { - "WHERE " + part.whereClause - } else { - filterWhereClause - } - } - - // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that - // we don't have to potentially poke around in the Metadata once for every - // row. - // Is there a better way to do this? I'd rather be using a type that - // contains only the tags I define. 
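The "type that contains only the tags I define" wished for in that comment is what Scala's sealed modifier provides: subtypes are confined to the defining file, and matches over them get compiler exhaustiveness warnings. A sketch of the first few tags written that way:

// A sealed base type is closed to this file, so a match over JDBCConversion
// that misses one of the tags draws a compiler warning.
sealed abstract class JDBCConversion
case object BooleanConversion extends JDBCConversion
case object DateConversion extends JDBCConversion
case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion
// ...remaining case objects exactly as declared below...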
- abstract class JDBCConversion - case object BooleanConversion extends JDBCConversion - case object DateConversion extends JDBCConversion - case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion - case object DoubleConversion extends JDBCConversion - case object FloatConversion extends JDBCConversion - case object IntegerConversion extends JDBCConversion - case object LongConversion extends JDBCConversion - case object BinaryLongConversion extends JDBCConversion - case object StringConversion extends JDBCConversion - case object TimestampConversion extends JDBCConversion - case object BinaryConversion extends JDBCConversion - - /** - * Maps a StructType to a type tag list. - */ - def getConversions(schema: StructType): Array[JDBCConversion] = { - schema.fields.map(sf => sf.dataType match { - case BooleanType => BooleanConversion - case DateType => DateConversion - case DecimalType.Fixed(p, s) => DecimalConversion(p, s) - case DoubleType => DoubleConversion - case FloatType => FloatConversion - case IntegerType => IntegerConversion - case LongType => - if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion - case StringType => StringConversion - case TimestampType => TimestampConversion - case BinaryType => BinaryConversion - case _ => throw new IllegalArgumentException(s"Unsupported field $sf") - }).toArray - } - - - /** - * Runs the SQL query against the JDBC driver. - */ - override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = - new Iterator[InternalRow] { - var closed = false - var finished = false - var gotNext = false - var nextValue: InternalRow = null - - context.addTaskCompletionListener{ context => close() } - val part = thePart.asInstanceOf[JDBCPartition] - val conn = getConnection() - - // H2's JDBC driver does not support the setSchema() method. We pass a - // fully-qualified table name in the SELECT statement. I don't know how to - // talk about a table in a completely portable way. - - val myWhereClause = getWhereClause(part) - - val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause" - val stmt = conn.prepareStatement(sqlText, - ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) - val fetchSize = properties.getProperty("fetchSize", "0").toInt - stmt.setFetchSize(fetchSize) - val rs = stmt.executeQuery() - - val conversions = getConversions(schema) - val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType)) - - def getNext(): InternalRow = { - if (rs.next()) { - var i = 0 - while (i < conversions.length) { - val pos = i + 1 - conversions(i) match { - case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) - case DateConversion => - // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it. - val dateVal = rs.getDate(pos) - if (dateVal != null) { - mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal)) - } else { - mutableRow.update(i, null) - } - // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal - // object returned by ResultSet.getBigDecimal is not correctly matched to the table - // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale. - // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through - // a BigDecimal object with scale as 0. But the dataframe schema has correct type as - // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then - // retrieve it, you will get wrong result 199.99. 
- // So it is needed to set precision and scale for Decimal based on JDBC metadata. - case DecimalConversion(p, s) => - val decimalVal = rs.getBigDecimal(pos) - if (decimalVal == null) { - mutableRow.update(i, null) - } else { - mutableRow.update(i, Decimal(decimalVal, p, s)) - } - case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) - case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) - case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) - case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) - // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 - case StringConversion => mutableRow.update(i, UTF8String.fromString(rs.getString(pos))) - case TimestampConversion => - val t = rs.getTimestamp(pos) - if (t != null) { - mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t)) - } else { - mutableRow.update(i, null) - } - case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) - case BinaryLongConversion => { - val bytes = rs.getBytes(pos) - var ans = 0L - var j = 0 - while (j < bytes.size) { - ans = 256 * ans + (255 & bytes(j)) - j = j + 1; - } - mutableRow.setLong(i, ans) - } - } - if (rs.wasNull) mutableRow.setNullAt(i) - i = i + 1 - } - mutableRow - } else { - finished = true - null.asInstanceOf[InternalRow] - } - } - - def close() { - if (closed) return - try { - if (null != rs) { - rs.close() - } - } catch { - case e: Exception => logWarning("Exception closing resultset", e) - } - try { - if (null != stmt) { - stmt.close() - } - } catch { - case e: Exception => logWarning("Exception closing statement", e) - } - try { - if (null != conn) { - conn.close() - } - logInfo("closed connection") - } catch { - case e: Exception => logWarning("Exception closing connection", e) - } - } - - override def hasNext: Boolean = { - if (!finished) { - if (!gotNext) { - nextValue = getNext() - if (finished) { - close() - } - gotNext = true - } - } - !finished - } - - override def next(): InternalRow = { - if (!hasNext) { - throw new NoSuchElementException("End of stream") - } - gotNext = false - nextValue - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala deleted file mode 100644 index 48d97ce..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.jdbc - -import java.util.Properties - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.Partition -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} - -/** - * Instructions on how to partition the table among workers. - */ -private[sql] case class JDBCPartitioningInfo( - column: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int) - -private[sql] object JDBCRelation { - /** - * Given a partitioning schematic (a column of integral type, a number of - * partitions, and upper and lower bounds on the column's value), generate - * WHERE clauses for each partition so that each row in the table appears - * exactly once. The parameters minValue and maxValue are advisory in that - * incorrect values may cause the partitioning to be poor, but no data - * will fail to be represented. - */ - def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { - if (partitioning == null) return Array[Partition](JDBCPartition(null, 0)) - - val numPartitions = partitioning.numPartitions - val column = partitioning.column - if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0)) - // Overflow and silliness can happen if you subtract then divide. - // Here we get a little roundoff, but that's (hopefully) OK. - val stride: Long = (partitioning.upperBound / numPartitions - - partitioning.lowerBound / numPartitions) - var i: Int = 0 - var currentValue: Long = partitioning.lowerBound - var ans = new ArrayBuffer[Partition]() - while (i < numPartitions) { - val lowerBound = if (i != 0) s"$column >= $currentValue" else null - currentValue += stride - val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null - val whereClause = - if (upperBound == null) { - lowerBound - } else if (lowerBound == null) { - upperBound - } else { - s"$lowerBound AND $upperBound" - } - ans += JDBCPartition(whereClause, i) - i = i + 1 - } - ans.toArray - } -} - -private[sql] class DefaultSource extends RelationProvider with DataSourceRegister { - - def format(): String = "jdbc" - - /** Returns a new base relation with the given parameters. 
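To make columnPartition above concrete: with hypothetical bounds of 0 and 100 on column id and four partitions, the stride is 100/4 - 0/4 = 25, and the first and last clauses are left open-ended so rows outside the advisory bounds are still captured:

val parts = JDBCRelation.columnPartition(
  JDBCPartitioningInfo(column = "id", lowerBound = 0L, upperBound = 100L, numPartitions = 4))

// parts.map(_.asInstanceOf[JDBCPartition].whereClause) yields:
//   "id < 25"
//   "id >= 25 AND id < 50"
//   "id >= 50 AND id < 75"
//   "id >= 75"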
*/ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val driver = parameters.getOrElse("driver", null) - val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) - val partitionColumn = parameters.getOrElse("partitionColumn", null) - val lowerBound = parameters.getOrElse("lowerBound", null) - val upperBound = parameters.getOrElse("upperBound", null) - val numPartitions = parameters.getOrElse("numPartitions", null) - - if (driver != null) DriverRegistry.register(driver) - - if (partitionColumn != null - && (lowerBound == null || upperBound == null || numPartitions == null)) { - sys.error("Partitioning incompletely specified") - } - - val partitionInfo = if (partitionColumn == null) { - null - } else { - JDBCPartitioningInfo( - partitionColumn, - lowerBound.toLong, - upperBound.toLong, - numPartitions.toInt) - } - val parts = JDBCRelation.columnPartition(partitionInfo) - val properties = new Properties() // Additional properties that we will pass to getConnection - parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JDBCRelation(url, table, parts, properties)(sqlContext) - } -} - -private[sql] case class JDBCRelation( - url: String, - table: String, - parts: Array[Partition], - properties: Properties = new Properties())(@transient val sqlContext: SQLContext) - extends BaseRelation - with PrunedFilteredScan - with InsertableRelation { - - override val needConversion: Boolean = false - - override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) - - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { - val driver: String = DriverRegistry.getDriverClassName(url) - // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] - JDBCRDD.scanTable( - sqlContext.sparkContext, - schema, - driver, - url, - properties, - table, - requiredColumns, - filters, - parts).asInstanceOf[RDD[Row]] - } - - override def insert(data: DataFrame, overwrite: Boolean): Unit = { - data.write - .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) - .jdbc(url, table, properties) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala deleted file mode 100644 index cc918c2..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.sql.{Connection, DriverManager} -import java.util.Properties - -import scala.util.Try - -/** - * Util functions for JDBC tables. - */ -private[sql] object JdbcUtils { - - /** - * Establishes a JDBC connection. - */ - def createConnection(url: String, connectionProperties: Properties): Connection = { - DriverManager.getConnection(url, connectionProperties) - } - - /** - * Returns true if the table already exists in the JDBC database. - */ - def tableExists(conn: Connection, table: String): Boolean = { - // Somewhat hacky, but there isn't a good way to identify whether a table exists for all - // SQL database systems, considering "table" could also include the database name. - Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess - } - - /** - * Drops a table from the JDBC database. - */ - def dropTable(conn: Connection, table: String): Unit = { - conn.prepareStatement(s"DROP TABLE $table").executeUpdate() - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala deleted file mode 100644 index 035e051..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement, SQLFeatureNotSupportedException} -import java.util.Properties - -import scala.collection.mutable - -import org.apache.spark.Logging -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - -package object jdbc { - private[sql] object JDBCWriteDetails extends Logging { - /** - * Returns a PreparedStatement that inserts a row into table via conn. - */ - def insertStatement(conn: Connection, table: String, rddSchema: StructType): - PreparedStatement = { - val sql = new StringBuilder(s"INSERT INTO $table VALUES (") - var fieldsLeft = rddSchema.fields.length - while (fieldsLeft > 0) { - sql.append("?") - if (fieldsLeft > 1) sql.append(", ") else sql.append(")") - fieldsLeft = fieldsLeft - 1 - } - conn.prepareStatement(sql.toString) - } - - /** - * Saves a partition of a DataFrame to the JDBC database. This is done in - * a single database transaction in order to avoid repeatedly inserting - * data as much as possible. 
- * - * It is still theoretically possible for rows in a DataFrame to be - * inserted into the database more than once if a stage somehow fails after - * the commit occurs but before the stage can return successfully. - * - * This is not a closure inside saveTable() because apparently cosmetic - * implementation changes elsewhere might easily render such a closure - * non-Serializable. Instead, we explicitly close over all variables that - * are used. - */ - def savePartition( - getConnection: () => Connection, - table: String, - iterator: Iterator[Row], - rddSchema: StructType, - nullTypes: Array[Int]): Iterator[Byte] = { - val conn = getConnection() - var committed = false - try { - conn.setAutoCommit(false) // Everything in the same db transaction. - val stmt = insertStatement(conn, table, rddSchema) - try { - while (iterator.hasNext) { - val row = iterator.next() - val numFields = rddSchema.fields.length - var i = 0 - while (i < numFields) { - if (row.isNullAt(i)) { - stmt.setNull(i + 1, nullTypes(i)) - } else { - rddSchema.fields(i).dataType match { - case IntegerType => stmt.setInt(i + 1, row.getInt(i)) - case LongType => stmt.setLong(i + 1, row.getLong(i)) - case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) - case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) - case ShortType => stmt.setInt(i + 1, row.getShort(i)) - case ByteType => stmt.setInt(i + 1, row.getByte(i)) - case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) - case StringType => stmt.setString(i + 1, row.getString(i)) - case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) - case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) - case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) - case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) - case _ => throw new IllegalArgumentException( - s"Can't translate non-null value for field $i") - } - } - i = i + 1 - } - stmt.executeUpdate() - } - } finally { - stmt.close() - } - conn.commit() - committed = true - } finally { - if (!committed) { - // The stage must fail. We got here through an exception path, so - // let the exception through unless rollback() or close() want to - // tell the user about another problem. - conn.rollback() - conn.close() - } else { - // The stage must succeed. We cannot propagate any exception close() might throw. - try { - conn.close() - } catch { - case e: Exception => logWarning("Transaction succeeded, but closing failed", e) - } - } - } - Array[Byte]().iterator - } - - /** - * Compute the schema string for this RDD. 
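A concrete illustration of the fallback mapping in schemaString below, assuming the dialect supplies no override for these types:

import org.apache.spark.sql.types._

// Hypothetical DataFrame schema.
val schema = StructType(Seq(
  StructField("name", StringType, nullable = true),
  StructField("age", IntegerType, nullable = false)))

// schemaString(df, url) over this schema produces:
//   "name TEXT , age INTEGER NOT NULL"
// (StringType -> TEXT and IntegerType -> INTEGER come from the fallback
// table; "NOT NULL" is appended only when field.nullable is false.)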
- */ - def schemaString(df: DataFrame, url: String): String = { - val sb = new StringBuilder() - val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => { - val name = field.name - val typ: String = - dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( - field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - }) - val nullable = if (field.nullable) "" else "NOT NULL" - sb.append(s", $name $typ $nullable") - }} - if (sb.length < 2) "" else sb.substring(2) - } - - /** - * Saves the RDD to the database in a single transaction. - */ - def saveTable( - df: DataFrame, - url: String, - table: String, - properties: Properties = new Properties()) { - val dialect = JdbcDialects.get(url) - val nullTypes: Array[Int] = df.schema.fields.map { field => - dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( - field.dataType match { - case IntegerType => java.sql.Types.INTEGER - case LongType => java.sql.Types.BIGINT - case DoubleType => java.sql.Types.DOUBLE - case FloatType => java.sql.Types.REAL - case ShortType => java.sql.Types.INTEGER - case ByteType => java.sql.Types.INTEGER - case BooleanType => java.sql.Types.BIT - case StringType => java.sql.Types.CLOB - case BinaryType => java.sql.Types.BLOB - case TimestampType => java.sql.Types.TIMESTAMP - case DateType => java.sql.Types.DATE - case t: DecimalType => java.sql.Types.DECIMAL - case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") - }) - } - - val rddSchema = df.schema - val driver: String = DriverRegistry.getDriverClassName(url) - val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) - df.foreachPartition { iterator => - JDBCWriteDetails.savePartition(getConnection, table, iterator, rddSchema, nullTypes) - } - } - - } - - private [sql] class DriverWrapper(val wrapped: Driver) extends Driver { - override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) - - override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() - - override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { - wrapped.getPropertyInfo(url, info) - } - - override def getMinorVersion: Int = wrapped.getMinorVersion - - def getParentLogger: java.util.logging.Logger = - throw new SQLFeatureNotSupportedException( - s"${this.getClass().getName}.getParentLogger is not yet implemented.") - - override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) - - override def getMajorVersion: Int = wrapped.getMajorVersion - } - - /** - * java.sql.DriverManager is always loaded by bootstrap classloader, - * so it can't load JDBC drivers accessible by Spark ClassLoader. - * - * To solve the problem, drivers from user-supplied jars are wrapped - * into thin wrapper. 
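In concrete terms the registration flow looks like this; the driver class name and JDBC URL are hypothetical. DriverManager ends up holding a Spark-loaded DriverWrapper that delegates to the user-jar driver, and getDriverClassName unwraps it again:

// Driver shipped in a user-supplied jar.
DriverRegistry.register("org.example.jdbc.ExampleDriver")

// DriverManager now sees the bootstrap-visible wrapper, not the raw driver.
val driverClass = DriverRegistry.getDriverClassName("jdbc:example://host/db")
// driverClass == "org.example.jdbc.ExampleDriver"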
- */ - private [sql] object DriverRegistry extends Logging { - - private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty - - def register(className: String): Unit = { - val cls = Utils.getContextOrSparkClassLoader.loadClass(className) - if (cls.getClassLoader == null) { - logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") - } else if (wrapperMap.get(className).isDefined) { - logTrace(s"Wrapper for $className already exists") - } else { - synchronized { - if (wrapperMap.get(className).isEmpty) { - val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) - DriverManager.registerDriver(wrapper) - wrapperMap(className) = wrapper - logTrace(s"Wrapper for $className registered") - } - } - } - } - - def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { - case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName - } - } - -} // package object jdbc http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala deleted file mode 100644 index ec5668c..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.json - -import com.fasterxml.jackson.core._ - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.json.JacksonUtils.nextUntil -import org.apache.spark.sql.types._ - -private[sql] object InferSchema { - /** - * Infer the type of a collection of json records in three stages: - * 1. Infer the type of each record - * 2. Merge types by choosing the lowest type necessary to cover equal keys - * 3. 
Replace any remaining null fields with string, the top type - */ - def apply( - json: RDD[String], - samplingRatio: Double = 1.0, - columnNameOfCorruptRecords: String): StructType = { - require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") - val schemaData = if (samplingRatio > 0.99) { - json - } else { - json.sample(withReplacement = false, samplingRatio, 1) - } - - // perform schema inference on each row and merge afterwards - val rootType = schemaData.mapPartitions { iter => - val factory = new JsonFactory() - iter.map { row => - try { - val parser = factory.createParser(row) - parser.nextToken() - inferField(parser) - } catch { - case _: JsonParseException => - StructType(Seq(StructField(columnNameOfCorruptRecords, StringType))) - } - } - }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) - - canonicalizeType(rootType) match { - case Some(st: StructType) => st - case _ => - // canonicalizeType erases all empty structs, including the only one we want to keep - StructType(Seq()) - } - } - - /** - * Infer the type of a json document from the parser's token stream - */ - private def inferField(parser: JsonParser): DataType = { - import com.fasterxml.jackson.core.JsonToken._ - parser.getCurrentToken match { - case null | VALUE_NULL => NullType - - case FIELD_NAME => - parser.nextToken() - inferField(parser) - - case VALUE_STRING if parser.getTextLength < 1 => - // Zero length strings and nulls have special handling to deal - // with JSON generators that do not distinguish between the two. - // To accurately infer types for empty strings that are really - // meant to represent nulls we assume that the two are isomorphic - // but will defer treating null fields as strings until all the - // record fields' types have been combined. - NullType - - case VALUE_STRING => StringType - case START_OBJECT => - val builder = Seq.newBuilder[StructField] - while (nextUntil(parser, END_OBJECT)) { - builder += StructField(parser.getCurrentName, inferField(parser), nullable = true) - } - - StructType(builder.result().sortBy(_.name)) - - case START_ARRAY => - // If this JSON array is empty, we use NullType as a placeholder. - // If this array is not empty in other JSON objects, we can resolve - // the type as we pass through all JSON objects. - var elementType: DataType = NullType - while (nextUntil(parser, END_ARRAY)) { - elementType = compatibleType(elementType, inferField(parser)) - } - - ArrayType(elementType) - - case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => - import JsonParser.NumberType._ - parser.getNumberType match { - // For Integer values, use LongType by default. - case INT | LONG => LongType - // Since we do not have a data type backed by BigInteger, - // when we see a Java BigInteger, we use DecimalType. - case BIG_INTEGER | BIG_DECIMAL => - val v = parser.getDecimalValue - DecimalType(v.precision(), v.scale()) - case FLOAT | DOUBLE => - // TODO(davies): Should we use decimal if possible? 
- DoubleType - } - - case VALUE_TRUE | VALUE_FALSE => BooleanType - } - } - - /** - * Convert NullType to StringType and remove StructTypes with no fields - */ - private def canonicalizeType: DataType => Option[DataType] = { - case at @ ArrayType(elementType, _) => - for { - canonicalType <- canonicalizeType(elementType) - } yield { - at.copy(canonicalType) - } - - case StructType(fields) => - val canonicalFields = for { - field <- fields - if field.name.nonEmpty - canonicalType <- canonicalizeType(field.dataType) - } yield { - field.copy(dataType = canonicalType) - } - - if (canonicalFields.nonEmpty) { - Some(StructType(canonicalFields)) - } else { - // per SPARK-8093: empty structs should be deleted - None - } - - case NullType => Some(StringType) - case other => Some(other) - } - - /** - * Remove top-level ArrayType wrappers and merge the remaining schemas - */ - private def compatibleRootType: (DataType, DataType) => DataType = { - case (ArrayType(ty1, _), ty2) => compatibleRootType(ty1, ty2) - case (ty1, ArrayType(ty2, _)) => compatibleRootType(ty1, ty2) - case (ty1, ty2) => compatibleType(ty1, ty2) - } - - /** - * Returns the most general data type for two given data types. - */ - private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { - // t1 or t2 is a StructType, ArrayType, or an unexpected type. - (t1, t2) match { - // Double support larger range than fixed decimal, DecimalType.Maximum should be enough - // in most case, also have better precision. - case (DoubleType, t: DecimalType) => - DoubleType - case (t: DecimalType, DoubleType) => - DoubleType - case (t1: DecimalType, t2: DecimalType) => - val scale = math.max(t1.scale, t2.scale) - val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) - if (range + scale > 38) { - // DecimalType can't support precision > 38 - DoubleType - } else { - DecimalType(range + scale, scale) - } - - case (StructType(fields1), StructType(fields2)) => - val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { - case (name, fieldTypes) => - val dataType = fieldTypes.view.map(_.dataType).reduce(compatibleType) - StructField(name, dataType, nullable = true) - } - StructType(newFields.toSeq.sortBy(_.name)) - - case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) - - // strings and every string is a Json object. - case (_, _) => StringType - } - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala deleted file mode 100644 index 5bb9e62..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ /dev/null @@ -1,203 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.json - -import java.io.CharArrayWriter - -import com.fasterxml.jackson.core.JsonFactory -import com.google.common.base.Objects -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{LongWritable, NullWritable, Text} -import org.apache.hadoop.mapred.{JobConf, TextInputFormat} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} - -import org.apache.spark.Logging -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.PartitionSpec -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} -import org.apache.spark.util.SerializableConfiguration - -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - - def format(): String = "json" - - override def createRelation( - sqlContext: SQLContext, - paths: Array[String], - dataSchema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation = { - val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - - new JSONRelation(None, samplingRatio, dataSchema, None, partitionColumns, paths)(sqlContext) - } -} - -private[sql] class JSONRelation( - val inputRDD: Option[RDD[String]], - val samplingRatio: Double, - val maybeDataSchema: Option[StructType], - val maybePartitionSpec: Option[PartitionSpec], - override val userDefinedPartitionColumns: Option[StructType], - override val paths: Array[String] = Array.empty[String])(@transient val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) { - - /** Constraints to be imposed on schema to be stored. 
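The duplicate-name detection that checkConstraints below performs, reduced to a stand-alone sketch (sample field names hypothetical):

val fieldNames = Array("a", "b", "a")
val duplicateColumns = fieldNames.groupBy(identity).collect {
  case (x, ys) if ys.length > 1 => "\"" + x + "\""
}.mkString(", ")
// duplicateColumns == "\"a\"" -- non-empty, so the save is rejected with an
// AnalysisException naming the offending columns.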
*/ - private def checkConstraints(schema: StructType): Unit = { - if (schema.fieldNames.length != schema.fieldNames.distinct.length) { - val duplicateColumns = schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to JSON format") - } - } - - override val needConversion: Boolean = false - - private def createBaseRdd(inputPaths: Array[FileStatus]): RDD[String] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration - - val paths = inputPaths.map(_.getPath) - - if (paths.nonEmpty) { - FileInputFormat.setInputPaths(job, paths: _*) - } - - sqlContext.sparkContext.hadoopRDD( - conf.asInstanceOf[JobConf], - classOf[TextInputFormat], - classOf[LongWritable], - classOf[Text]).map(_._2.toString) // get the text line - } - - override lazy val dataSchema = { - val jsonSchema = maybeDataSchema.getOrElse { - val files = cachedLeafStatuses().filterNot { status => - val name = status.getPath.getName - name.startsWith("_") || name.startsWith(".") - }.toArray - InferSchema( - inputRDD.getOrElse(createBaseRdd(files)), - samplingRatio, - sqlContext.conf.columnNameOfCorruptRecord) - } - checkConstraints(jsonSchema) - - jsonSchema - } - - override private[sql] def buildScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputPaths: Array[String], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { - refresh() - super.buildScan(requiredColumns, filters, inputPaths, broadcastedConf) - } - - override def buildScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputPaths: Array[FileStatus]): RDD[Row] = { - JacksonParser( - inputRDD.getOrElse(createBaseRdd(inputPaths)), - StructType(requiredColumns.map(dataSchema(_))), - sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]] - } - - override def equals(other: Any): Boolean = other match { - case that: JSONRelation => - ((inputRDD, that.inputRDD) match { - case (Some(thizRdd), Some(thatRdd)) => thizRdd eq thatRdd - case (None, None) => true - case _ => false - }) && paths.toSet == that.paths.toSet && - dataSchema == that.dataSchema && - schema == that.schema - case _ => false - } - - override def hashCode(): Int = { - Objects.hashCode( - inputRDD, - paths.toSet, - dataSchema, - schema, - partitionColumns) - } - - override def prepareJobForWrite(job: Job): OutputWriterFactory = { - new OutputWriterFactory { - override def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new JsonOutputWriter(path, dataSchema, context) - } - } - } -} - -private[json] class JsonOutputWriter( - path: String, - dataSchema: StructType, - context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with Logging { - - val writer = new CharArrayWriter() - // create the Generator without separator inserted between 2 records - val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) - - val result = new Text() - - private val recordWriter: RecordWriter[NullWritable, Text] = { - new TextOutputFormat[NullWritable, Text]() { - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val split = context.getTaskAttemptID.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") - } - 
}.getRecordWriter(context) - } - - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override protected[sql] def writeInternal(row: InternalRow): Unit = { - JacksonGenerator(dataSchema, gen, row) - gen.flush() - - result.set(writer.toString) - writer.reset() - - recordWriter.write(NullWritable.get(), result) - } - - override def close(): Unit = { - gen.close() - recordWriter.close(context) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala deleted file mode 100644 index d734e7e..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala
deleted file mode 100644
index d734e7e..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.json
-
-import org.apache.spark.sql.catalyst.InternalRow
-
-import scala.collection.Map
-
-import com.fasterxml.jackson.core._
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.types._
-
-private[sql] object JacksonGenerator {
-  /** Transforms a single Row to JSON using Jackson
-   *
-   * @param rowSchema the schema object used for conversion
-   * @param gen a JsonGenerator object
-   * @param row The row to convert
-   */
-  def apply(rowSchema: StructType, gen: JsonGenerator)(row: Row): Unit = {
-    def valWriter: (DataType, Any) => Unit = {
-      case (_, null) | (NullType, _) => gen.writeNull()
-      case (StringType, v: String) => gen.writeString(v)
-      case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
-      case (IntegerType, v: Int) => gen.writeNumber(v)
-      case (ShortType, v: Short) => gen.writeNumber(v)
-      case (FloatType, v: Float) => gen.writeNumber(v)
-      case (DoubleType, v: Double) => gen.writeNumber(v)
-      case (LongType, v: Long) => gen.writeNumber(v)
-      case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v)
-      case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
-      case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
-      case (BooleanType, v: Boolean) => gen.writeBoolean(v)
-      case (DateType, v) => gen.writeString(v.toString)
-      case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v))
-
-      case (ArrayType(ty, _), v: Seq[_]) =>
-        gen.writeStartArray()
-        v.foreach(valWriter(ty, _))
-        gen.writeEndArray()
-
-      case (MapType(kv, vv, _), v: Map[_, _]) =>
-        gen.writeStartObject()
-        v.foreach { p =>
-          gen.writeFieldName(p._1.toString)
-          valWriter(vv, p._2)
-        }
-        gen.writeEndObject()
-
-      case (StructType(ty), v: Row) =>
-        gen.writeStartObject()
-        ty.zip(v.toSeq).foreach {
-          case (_, null) =>
-          case (field, v) =>
-            gen.writeFieldName(field.name)
-            valWriter(field.dataType, v)
-        }
-        gen.writeEndObject()
-    }
-
-    valWriter(rowSchema, row)
-  }
-
-  /** Transforms a single InternalRow to JSON using Jackson
-   *
-   * TODO: make the code shared with the other apply method.
-   *
-   * @param rowSchema the schema object used for conversion
-   * @param gen a JsonGenerator object
-   * @param row The row to convert
-   */
-  def apply(rowSchema: StructType, gen: JsonGenerator, row: InternalRow): Unit = {
-    def valWriter: (DataType, Any) => Unit = {
-      case (_, null) | (NullType, _) => gen.writeNull()
-      case (StringType, v) => gen.writeString(v.toString)
-      case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
-      case (IntegerType, v: Int) => gen.writeNumber(v)
-      case (ShortType, v: Short) => gen.writeNumber(v)
-      case (FloatType, v: Float) => gen.writeNumber(v)
-      case (DoubleType, v: Double) => gen.writeNumber(v)
-      case (LongType, v: Long) => gen.writeNumber(v)
-      case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v)
-      case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
-      case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
-      case (BooleanType, v: Boolean) => gen.writeBoolean(v)
-      case (DateType, v) => gen.writeString(v.toString)
-      case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v))
-
-      case (ArrayType(ty, _), v: ArrayData) =>
-        gen.writeStartArray()
-        v.foreach(ty, (_, value) => valWriter(ty, value))
-        gen.writeEndArray()
-
-      case (MapType(kv, vv, _), v: Map[_, _]) =>
-        gen.writeStartObject()
-        v.foreach { p =>
-          gen.writeFieldName(p._1.toString)
-          valWriter(vv, p._2)
-        }
-        gen.writeEndObject()
-
-      case (StructType(ty), v: InternalRow) =>
-        gen.writeStartObject()
-        var i = 0
-        while (i < ty.length) {
-          val field = ty(i)
-          val value = v.get(i, field.dataType)
-          if (value != null) {
-            gen.writeFieldName(field.name)
-            valWriter(field.dataType, value)
-          }
-          i += 1
-        }
-        gen.writeEndObject()
-    }
-
-    valWriter(rowSchema, row)
-  }
-}
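For reference, the Row-based apply above is curried so a (schema, generator) pair can be fixed once and then applied per row. A hypothetical usage sketch, assuming the pre-move org.apache.spark.sql.json package shown in this diff and an invented two-field schema:

  import java.io.CharArrayWriter
  import com.fasterxml.jackson.core.JsonFactory
  import org.apache.spark.sql.Row
  import org.apache.spark.sql.json.JacksonGenerator
  import org.apache.spark.sql.types._

  val schema = StructType(Seq(
    StructField("name", StringType),
    StructField("age", IntegerType)))
  val writer = new CharArrayWriter()
  val gen = new JsonFactory().createGenerator(writer)
  JacksonGenerator(schema, gen)(Row("alice", 42)) // writes {"name":"alice","age":42}
  gen.flush()
  val json = writer.toString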
http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
deleted file mode 100644
index b8fd3b9..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ /dev/null
@@ -1,228 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.json
-
-import java.io.ByteArrayOutputStream
-
-import com.fasterxml.jackson.core._
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.json.JacksonUtils.nextUntil
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-private[sql] object JacksonParser {
-  def apply(
-      json: RDD[String],
-      schema: StructType,
-      columnNameOfCorruptRecords: String): RDD[InternalRow] = {
-    parseJson(json, schema, columnNameOfCorruptRecords)
-  }
-
-  /**
-   * Parse the current token (and related children) according to a desired schema
-   */
-  private[sql] def convertField(
-      factory: JsonFactory,
-      parser: JsonParser,
-      schema: DataType): Any = {
-    import com.fasterxml.jackson.core.JsonToken._
-    (parser.getCurrentToken, schema) match {
-      case (null | VALUE_NULL, _) =>
-        null
-
-      case (FIELD_NAME, _) =>
-        parser.nextToken()
-        convertField(factory, parser, schema)
-
-      case (VALUE_STRING, StringType) =>
-        UTF8String.fromString(parser.getText)
-
-      case (VALUE_STRING, _) if parser.getTextLength < 1 =>
-        // guard the non string type
-        null
-
-      case (VALUE_STRING, DateType) =>
-        DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime)
-
-      case (VALUE_STRING, TimestampType) =>
-        DateTimeUtils.stringToTime(parser.getText).getTime * 1000L
-
-      case (VALUE_NUMBER_INT, TimestampType) =>
-        parser.getLongValue * 1000L
-
-      case (_, StringType) =>
-        val writer = new ByteArrayOutputStream()
-        val generator = factory.createGenerator(writer, JsonEncoding.UTF8)
-        generator.copyCurrentStructure(parser)
-        generator.close()
-        UTF8String.fromBytes(writer.toByteArray)
-
-      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) =>
-        parser.getFloatValue
-
-      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) =>
-        parser.getDoubleValue
-
-      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) =>
-        Decimal(parser.getDecimalValue, dt.precision, dt.scale)
-
-      case (VALUE_NUMBER_INT, ByteType) =>
-        parser.getByteValue
-
-      case (VALUE_NUMBER_INT, ShortType) =>
-        parser.getShortValue
-
-      case (VALUE_NUMBER_INT, IntegerType) =>
-        parser.getIntValue
-
-      case (VALUE_NUMBER_INT, LongType) =>
-        parser.getLongValue
-
-      case (VALUE_TRUE, BooleanType) =>
-        true
-
-      case (VALUE_FALSE, BooleanType) =>
-        false
-
-      case (START_OBJECT, st: StructType) =>
-        convertObject(factory, parser, st)
-
-      case (START_ARRAY, st: StructType) =>
-        // SPARK-3308: support reading top level JSON arrays and take every element
-        // in such an array as a row
-        convertArray(factory, parser, st)
-
-      case (START_ARRAY, ArrayType(st, _)) =>
-        convertArray(factory, parser, st)
-
-      case (START_OBJECT, ArrayType(st, _)) =>
-        // the business end of SPARK-3308:
-        // when an object is found but an array is requested just wrap it in a list
-        convertField(factory, parser, st) :: Nil
-
-      case (START_OBJECT, MapType(StringType, kt, _)) =>
-        convertMap(factory, parser, kt)
-
-      case (_, udt: UserDefinedType[_]) =>
-        convertField(factory, parser, udt.sqlType)
-    }
-  }
-
-  /**
-   * Parse an object from the token stream into a new Row representing the schema.
-   *
-   * Fields in the json that are not defined in the requested schema will be dropped.
-   */
-  private def convertObject(
-      factory: JsonFactory,
-      parser: JsonParser,
-      schema: StructType): InternalRow = {
-    val row = new GenericMutableRow(schema.length)
-    while (nextUntil(parser, JsonToken.END_OBJECT)) {
-      schema.getFieldIndex(parser.getCurrentName) match {
-        case Some(index) =>
-          row.update(index, convertField(factory, parser, schema(index).dataType))
-
-        case None =>
-          parser.skipChildren()
-      }
-    }
-
-    row
-  }
-
-  /**
-   * Parse an object as a Map, preserving all fields
-   */
-  private def convertMap(
-      factory: JsonFactory,
-      parser: JsonParser,
-      valueType: DataType): MapData = {
-    val keys = ArrayBuffer.empty[UTF8String]
-    val values = ArrayBuffer.empty[Any]
-    while (nextUntil(parser, JsonToken.END_OBJECT)) {
-      keys += UTF8String.fromString(parser.getCurrentName)
-      values += convertField(factory, parser, valueType)
-    }
-    ArrayBasedMapData(keys.toArray, values.toArray)
-  }
-
-  private def convertArray(
-      factory: JsonFactory,
-      parser: JsonParser,
-      elementType: DataType): ArrayData = {
-    val values = ArrayBuffer.empty[Any]
-    while (nextUntil(parser, JsonToken.END_ARRAY)) {
-      values += convertField(factory, parser, elementType)
-    }
-
-    new GenericArrayData(values.toArray)
-  }
-
-  private def parseJson(
-      json: RDD[String],
-      schema: StructType,
-      columnNameOfCorruptRecords: String): RDD[InternalRow] = {
-
-    def failedRecord(record: String): Seq[InternalRow] = {
-      // create a row even if no corrupt record column is present
-      val row = new GenericMutableRow(schema.length)
-      for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) {
-        require(schema(corruptIndex).dataType == StringType)
-        row.update(corruptIndex, UTF8String.fromString(record))
-      }
-
-      Seq(row)
-    }
-
-    json.mapPartitions { iter =>
-      val factory = new JsonFactory()
-
-      iter.flatMap { record =>
-        try {
-          val parser = factory.createParser(record)
-          parser.nextToken()
-
-          convertField(factory, parser, schema) match {
-            case null => failedRecord(record)
-            case row: InternalRow => row :: Nil
-            case array: ArrayData =>
-              if (array.numElements() == 0) {
-                Nil
-              } else {
-                array.toArray[InternalRow](schema)
-              }
-            case _ =>
-              sys.error(
-                s"Failed to parse record $record. Please make sure that each line of the file " +
-                  "(or each string in the RDD) is a valid JSON object or an array of JSON objects.")
-          }
-        } catch {
-          case _: JsonProcessingException =>
-            failedRecord(record)
-        }
-      }
-    }
-  }
-}
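Taken together, parseJson above yields zero or more InternalRows per input string: a top-level object becomes one row, a top-level array becomes one row per element (SPARK-3308), and an unparseable record is captured in the corrupt-record column rather than failing the job. A hypothetical driver-side sketch, assuming the pre-move package, an active SparkContext sc, and an invented schema whose corrupt-record column is named _corrupt_record:

  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.json.JacksonParser
  import org.apache.spark.sql.types._

  val schema = StructType(Seq(
    StructField("a", LongType),
    StructField("_corrupt_record", StringType)))
  val json = sc.parallelize(Seq(
    """{"a": 1}""",             // one row
    """[{"a": 2}, {"a": 3}]""", // SPARK-3308: one row per array element
    """not json"""))            // whole record stored in _corrupt_record
  val rows: RDD[InternalRow] = JacksonParser(json, schema, "_corrupt_record")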
http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala
deleted file mode 100644
index fde9685..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.json
-
-import com.fasterxml.jackson.core.{JsonParser, JsonToken}
-
-private object JacksonUtils {
-  /**
-   * Advance the parser until a null or a specific token is found
-   */
-  def nextUntil(parser: JsonParser, stopOn: JsonToken): Boolean = {
-    parser.nextToken() match {
-      case null => false
-      case x => x != stopOn
-    }
-  }
-}
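nextUntil above returns true while it can keep advancing and false once the stop token (or end of input) is hit, which is what lets convertObject, convertMap, and convertArray be written as simple while loops. A hypothetical illustration (JacksonUtils is private to its package, so this assumes code living in org.apache.spark.sql.json; the sample document is invented):

  import com.fasterxml.jackson.core.{JsonFactory, JsonToken}

  val parser = new JsonFactory().createParser("""{"a": 1, "b": 2}""")
  parser.nextToken() // position the parser on START_OBJECT
  while (JacksonUtils.nextUntil(parser, JsonToken.END_OBJECT)) {
    // visits the FIELD_NAME and value tokens alternately until END_OBJECT
    println(s"${parser.getCurrentToken} name=${parser.getCurrentName}")
  }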
http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala
deleted file mode 100644
index 3b907e5..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala
+++ /dev/null
@@ -1,149 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql.metric
-
-import org.apache.spark.{Accumulable, AccumulableParam, SparkContext}
-
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- *
- * An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
- */
-private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
-    name: String, val param: SQLMetricParam[R, T])
-  extends Accumulable[R, T](param.zero, param, Some(name), true)
-
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {
-
-  def zero: R
-}
-
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricValue[T] extends Serializable {
-
-  def value: T
-
-  override def toString: String = value.toString
-}
-
-/**
- * A wrapper of Long to avoid boxing and unboxing when using Accumulator
- */
-private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] {
-
-  def add(incr: Long): LongSQLMetricValue = {
-    _value += incr
-    this
-  }
-
-  // Although there is a boxing here, it's fine because it's only called in SQLListener
-  override def value: Long = _value
-}
-
-/**
- * A wrapper of Int to avoid boxing and unboxing when using Accumulator
- */
-private[sql] class IntSQLMetricValue(private var _value: Int) extends SQLMetricValue[Int] {
-
-  def add(term: Int): IntSQLMetricValue = {
-    _value += term
-    this
-  }
-
-  // Although there is a boxing here, it's fine because it's only called in SQLListener
-  override def value: Int = _value
-}
-
-/**
- * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
- * `+=` and `add`.
- */
-private[sql] class LongSQLMetric private[metric](name: String)
-  extends SQLMetric[LongSQLMetricValue, Long](name, LongSQLMetricParam) {
-
-  override def +=(term: Long): Unit = {
-    localValue.add(term)
-  }
-
-  override def add(term: Long): Unit = {
-    localValue.add(term)
-  }
-}
-
-/**
- * A specialized int Accumulable to avoid boxing and unboxing when using Accumulator's
- * `+=` and `add`.
- */
-private[sql] class IntSQLMetric private[metric](name: String)
-  extends SQLMetric[IntSQLMetricValue, Int](name, IntSQLMetricParam) {
-
-  override def +=(term: Int): Unit = {
-    localValue.add(term)
-  }
-
-  override def add(term: Int): Unit = {
-    localValue.add(term)
-  }
-}
-
-private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] {
-
-  override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)
-
-  override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue =
-    r1.add(r2.value)
-
-  override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero
-
-  override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L)
-}
-
-private object IntSQLMetricParam extends SQLMetricParam[IntSQLMetricValue, Int] {
-
-  override def addAccumulator(r: IntSQLMetricValue, t: Int): IntSQLMetricValue = r.add(t)
-
-  override def addInPlace(r1: IntSQLMetricValue, r2: IntSQLMetricValue): IntSQLMetricValue =
-    r1.add(r2.value)
-
-  override def zero(initialValue: IntSQLMetricValue): IntSQLMetricValue = zero
-
-  override def zero: IntSQLMetricValue = new IntSQLMetricValue(0)
-}
-
-private[sql] object SQLMetrics {
-
-  def createIntMetric(sc: SparkContext, name: String): IntSQLMetric = {
-    val acc = new IntSQLMetric(name)
-    sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
-    acc
-  }
-
-  def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
-    val acc = new LongSQLMetric(name)
-    sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
-    acc
-  }
-}
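A sketch of the intended use of these metrics (everything here is private[sql], so this assumes code inside org.apache.spark.sql; the metric name and job are invented): the specialized += mutates the local LongSQLMetricValue in place on each executor, avoiding the per-element Long boxing that a plain Accumulator's generic += would incur.

  import org.apache.spark.SparkContext
  import org.apache.spark.sql.metric.SQLMetrics

  def countRows(sc: SparkContext): Long = {
    val numRows = SQLMetrics.createLongMetric(sc, "number of rows")
    sc.parallelize(1 to 1000000).foreach { _ =>
      numRows += 1L // specialized +=: adds into the local LongSQLMetricValue, no boxing
    }
    numRows.value.value // Accumulable#value returns the LongSQLMetricValue wrapper
  }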