GitHub user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22468#discussion_r225521397
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala ---
    @@ -0,0 +1,173 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.sql.catalyst.expressions
    +
    +import org.apache.spark.SparkException
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
    +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
    +import org.apache.spark.sql.types._
    +
    +
    +/**
    + * An interpreted version of a safe projection.
    + *
    + * @param expressions that produce the resulting fields. These expressions must be bound
    + *                    to a schema.
    + */
    +class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection {
    +
    +  private[this] val mutableRow = new SpecificInternalRow(expressions.map(_.dataType))
    +
    +  private[this] val exprsWithWriters = expressions.zipWithIndex.filter {
    +    case (NoOp, _) => false
    +    case _ => true
    +  }.map { case (e, i) =>
    +    val converter = generateSafeValueConverter(e.dataType)
    +    val writer = generateRowWriter(i, e.dataType)
    +    val f = if (!e.nullable) {
    +      (v: Any) => writer(converter(v))
    +    } else {
    +      (v: Any) => {
    +        if (v == null) {
    +          mutableRow.setNullAt(i)
    +        } else {
    +          writer(converter(v))
    +        }
    +      }
    +    }
    +    (e, f)
    +  }
    +
    +  private def isPrimitive(dataType: DataType): Boolean = dataType match {
    +    case BooleanType => true
    +    case ByteType => true
    +    case ShortType => true
    +    case IntegerType => true
    +    case LongType => true
    +    case FloatType => true
    +    case DoubleType => true
    +    case _ => false
    +  }
    +
    +  private def generateSafeValueConverter(dt: DataType): Any => Any = dt match {
    +    case ArrayType(elemType, _) =>
    +      if (isPrimitive(elemType)) {
    +        v => {
    +          val arrayValue = v.asInstanceOf[ArrayData]
    +          new GenericArrayData(arrayValue.toArray[Any](elemType))
    +        }
    +      } else {
    +        val elementConverter = generateSafeValueConverter(elemType)
    +        v => {
    +          val arrayValue = v.asInstanceOf[ArrayData]
    +          val result = new Array[Any](arrayValue.numElements())
    +          arrayValue.foreach(elemType, (i, e) => {
    +            result(i) = elementConverter(e)
    +          })
    +          new GenericArrayData(result)
    +        }
    +      }
    +
    +    case st: StructType =>
    +      val fieldTypes = st.fields.map(_.dataType)
    +      val fieldConverters = fieldTypes.map(generateSafeValueConverter)
    +      v => {
    +        val row = v.asInstanceOf[InternalRow]
    +        val ar = new Array[Any](row.numFields)
    +        var idx = 0
    +        while (idx < row.numFields) {
    +          ar(idx) = fieldConverters(idx)(row.get(idx, fieldTypes(idx)))
    +          idx += 1
    +        }
    +        new GenericInternalRow(ar)
    +      }
    +
    +    case MapType(keyType, valueType, _) =>
    +      lazy val keyConverter = generateSafeValueConverter(keyType)
    +      lazy val valueConverter = generateSafeValueConverter(valueType)
    +      v => {
    +        val mapValue = v.asInstanceOf[MapData]
    +        val keys = mapValue.keyArray().toArray[Any](keyType)
    +        val values = mapValue.valueArray().toArray[Any](valueType)
    +        val convertedKeys =
    +          if (isPrimitive(keyType)) keys else keys.map(keyConverter)
    --- End diff --
    
    ditto
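    
    For context, a minimal usage sketch of the projection above (hypothetical
    inputs and bindings; only the `InterpretedSafeProjection` constructor is
    taken from this diff):
    
        import org.apache.spark.sql.catalyst.InternalRow
        import org.apache.spark.sql.catalyst.expressions.{BoundReference, InterpretedSafeProjection}
        import org.apache.spark.sql.types.{IntegerType, StringType}
        import org.apache.spark.unsafe.types.UTF8String
    
        // Hypothetical usage, not part of the PR: bind two column references
        // to an input schema (int, string), then project an incoming row into
        // a freshly allocated "safe" (generic) row.
        val projection = new InterpretedSafeProjection(Seq(
          BoundReference(0, IntegerType, nullable = false),
          BoundReference(1, StringType, nullable = true)))
    
        val safeRow: InternalRow = projection(InternalRow(1, UTF8String.fromString("a")))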

