This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new debac336 [SEDONA-311] Refactor `Inferred*Expression` base class for 
Sedona SQL (#871)
debac336 is described below

commit debac336e923111b63ece886901e43f8961f5fbe
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Sun Jun 25 09:25:53 2023 +0800

    [SEDONA-311] Refactor `Inferred*Expression` base class for Sedona SQL (#871)
---
 .../sql/sedona_sql/expressions/Constructors.scala  |  30 +-
 .../expressions/FoldableExpression.scala           |  29 ++
 .../sql/sedona_sql/expressions/Functions.scala     | 199 ++++-----
 .../expressions/InferrableFunctionConverter.scala  | 494 +++++++++++++++++++++
 .../expressions/InferredExpression.scala           | 223 ++++++++++
 .../expressions/NullSafeExpressions.scala          | 366 ---------------
 6 files changed, 858 insertions(+), 483 deletions(-)

diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
index fff08228..ccc7ffe9 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
@@ -26,6 +26,7 @@ import 
org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.expressions.{Expression, 
ImplicitCastInputTypes}
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.sedona_sql.expressions.implicits.GeometryEnhancer
+import 
org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -36,7 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String
   *                         string, the second parameter is the delimiter. 
String format should be similar to CSV/TSV
   */
 case class ST_PointFromText(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Constructors.pointFromText) with 
FoldableExpression {
+  extends InferredExpression(Constructors.pointFromText _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
@@ -48,7 +49,7 @@ case class ST_PointFromText(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_PolygonFromText(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Constructors.polygonFromText) with 
FoldableExpression {
+  extends InferredExpression(Constructors.polygonFromText _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
@@ -60,7 +61,7 @@ case class ST_PolygonFromText(inputExpressions: 
Seq[Expression])
   * @param inputExpressions
   */
 case class ST_LineFromText(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Constructors.lineFromText) with 
FoldableExpression {
+  extends InferredExpression(Constructors.lineFromText _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
@@ -71,7 +72,7 @@ case class ST_LineFromText(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_LineStringFromText(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Constructors.lineStringFromText) with 
FoldableExpression {
+  extends InferredExpression(Constructors.lineStringFromText _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
@@ -83,7 +84,7 @@ case class ST_LineStringFromText(inputExpressions: 
Seq[Expression])
   * @param inputExpressions This function takes a geometry string and a srid. 
The string format must be WKT.
   */
 case class ST_GeomFromWKT(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Constructors.geomFromWKT) with 
FoldableExpression {
+  extends InferredExpression(Constructors.geomFromWKT _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -97,7 +98,7 @@ case class ST_GeomFromWKT(inputExpressions: Seq[Expression])
   * @param inputExpressions This function takes a geometry string and a srid. 
The string format must be WKT.
   */
 case class ST_GeomFromText(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Constructors.geomFromWKT) with 
FoldableExpression {
+  extends InferredExpression(Constructors.geomFromWKT _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -180,7 +181,7 @@ case class ST_GeomFromGeoJSON(inputExpressions: 
Seq[Expression])
   * @param inputExpressions This function takes 2 parameters which are point x, 
y.
   */
 case class ST_Point(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Constructors.point) with FoldableExpression 
{
+  extends InferredExpression(Constructors.point _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -193,7 +194,7 @@ case class ST_Point(inputExpressions: Seq[Expression])
  * @param inputExpressions This function takes 4 parameters which are point x, 
y, z and srid (default 0).
  */
 case class ST_PointZ(inputExpressions: Seq[Expression])
-  extends InferredQuarternaryExpression(Constructors.pointZ) with 
FoldableExpression {
+  extends InferredExpression(Constructors.pointZ _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -207,7 +208,7 @@ case class ST_PointZ(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_PolygonFromEnvelope(inputExpressions: Seq[Expression])
-  extends InferredQuarternaryExpression(Constructors.polygonFromEnvelope) with 
FoldableExpression {
+  extends InferredExpression(Constructors.polygonFromEnvelope _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -226,22 +227,21 @@ trait UserDataGeneratator {
 }
 
 case class ST_GeomFromGeoHash(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Constructors.geomFromGeoHash) with 
FoldableExpression {
+  extends 
InferredExpression(InferrableFunction.allowRightNull(Constructors.geomFromGeoHash))
 {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-  override def allowRightNull: Boolean = true
 }
 
 case class ST_GeomFromGML(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Constructors.geomFromGML) with 
FoldableExpression {
+  extends InferredExpression(Constructors.geomFromGML _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
 }
 
 case class ST_GeomFromKML(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Constructors.geomFromKML) with 
FoldableExpression {
+  extends InferredExpression(Constructors.geomFromKML _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
@@ -253,7 +253,7 @@ case class ST_GeomFromKML(inputExpressions: Seq[Expression])
  * @param inputExpressions This function takes a geometry string and a srid. 
The string format must be WKT.
  */
 case class ST_MPolyFromText(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Constructors.mPolyFromText) with 
FoldableExpression {
+  extends InferredExpression(Constructors.mPolyFromText _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
@@ -265,7 +265,7 @@ case class ST_MPolyFromText(inputExpressions: 
Seq[Expression])
  * @param inputExpressions This function takes a geometry string and a srid. 
The string format must be WKT.
  */
 case class ST_MLineFromText(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Constructors.mLineFromText) with 
FoldableExpression {
+  extends InferredExpression(Constructors.mLineFromText _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/FoldableExpression.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/FoldableExpression.scala
new file mode 100644
index 00000000..08c0acb2
--- /dev/null
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/FoldableExpression.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+/**
+  * Make expression foldable by constant folding optimizer. If all children
+  * expressions are foldable, then the expression itself is foldable.
+  */
+trait FoldableExpression extends Expression {
+  override def foldable: Boolean = children.forall(_.foldable)
+}
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index a324da69..c084eec1 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.sedona_sql.expressions.implicits._
 import org.apache.spark.sql.types._
 import org.locationtech.jts.algorithm.MinimumBoundingCircle
 import org.locationtech.jts.geom._
+import 
org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
 
 /**
   * Return the distance between two geometries.
@@ -36,7 +37,7 @@ import org.locationtech.jts.geom._
   * @param inputExpressions This function takes two geometries and calculates 
the distance between two objects.
   */
 case class ST_Distance(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.distance) with FoldableExpression 
{
+  extends InferredExpression(Functions.distance _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -45,7 +46,7 @@ case class ST_Distance(inputExpressions: Seq[Expression])
 
 
 case class ST_YMax(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.yMax) with FoldableExpression {
+  extends InferredExpression(Functions.yMax _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -53,7 +54,7 @@ case class ST_YMax(inputExpressions: Seq[Expression])
 }
 
 case class ST_YMin(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.yMin) with FoldableExpression {
+  extends InferredExpression(Functions.yMin _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -66,7 +67,7 @@ case class ST_YMin(inputExpressions: Seq[Expression])
   * @param inputExpressions This function takes a geometry and returns the 
maximum of all Z-coordinate values.
 */
 case class ST_ZMax(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.zMax) with FoldableExpression {
+  extends InferredExpression(Functions.zMax _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -79,7 +80,7 @@ case class ST_ZMax(inputExpressions: Seq[Expression])
  * @param inputExpressions This function takes a geometry and returns the 
minimum of all Z-coordinate values.
 */
 case class ST_ZMin(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.zMin) with FoldableExpression {
+  extends InferredExpression(Functions.zMin _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -87,7 +88,7 @@ case class ST_ZMin(inputExpressions: Seq[Expression])
 }
 
 case class ST_3DDistance(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.distance3d) with 
FoldableExpression {
+  extends InferredExpression(Functions.distance3d _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -100,7 +101,7 @@ case class ST_3DDistance(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_ConcaveHull(inputExpressions: Seq[Expression])
-  extends InferredTernaryExpression(Functions.concaveHull) with 
FoldableExpression {
+  extends InferredExpression(Functions.concaveHull _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression = {
     copy(inputExpressions = newChildren)
@@ -113,7 +114,7 @@ case class ST_ConcaveHull(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_ConvexHull(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.convexHull) with 
FoldableExpression {
+  extends InferredExpression(Functions.convexHull _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -126,7 +127,7 @@ case class ST_ConvexHull(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_NPoints(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.nPoints) with FoldableExpression {
+  extends InferredExpression(Functions.nPoints _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -139,7 +140,7 @@ case class ST_NPoints(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_NDims(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.nDims) with FoldableExpression {
+  extends InferredExpression(Functions.nDims _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -152,7 +153,7 @@ case class ST_NDims(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Buffer(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.buffer) with FoldableExpression {
+  extends InferredExpression(Functions.buffer _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -166,7 +167,7 @@ case class ST_Buffer(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Envelope(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.envelope) with FoldableExpression {
+  extends InferredExpression(Functions.envelope _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -179,7 +180,7 @@ case class ST_Envelope(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Length(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.length) with FoldableExpression {
+  extends InferredExpression(Functions.length _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -192,7 +193,7 @@ case class ST_Length(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Area(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.area) with FoldableExpression {
+  extends InferredExpression(Functions.area _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -205,7 +206,7 @@ case class ST_Area(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Centroid(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.getCentroid) with 
FoldableExpression {
+  extends InferredExpression(Functions.getCentroid _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -218,7 +219,7 @@ case class ST_Centroid(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Transform(inputExpressions: Seq[Expression])
-  extends InferredQuarternaryExpression(Functions.transform) with 
FoldableExpression {
+  extends InferredExpression(inferrableFunction4(Functions.transform)) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -231,7 +232,7 @@ case class ST_Transform(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Intersection(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.intersection) with 
FoldableExpression {
+  extends InferredExpression(Functions.intersection _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -245,7 +246,7 @@ case class ST_Intersection(inputExpressions: 
Seq[Expression])
   * @param inputExpressions
   */
 case class ST_MakeValid(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.makeValid) with 
FoldableExpression {
+  extends InferredExpression(Functions.makeValid _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -258,7 +259,7 @@ case class ST_MakeValid(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_IsValid(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.isValid) with FoldableExpression {
+  extends InferredExpression(Functions.isValid _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -271,7 +272,7 @@ case class ST_IsValid(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_IsSimple(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.isSimple) with FoldableExpression {
+  extends InferredExpression(Functions.isSimple _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -287,7 +288,7 @@ case class ST_IsSimple(inputExpressions: Seq[Expression])
   *                         the second arg is the distance tolerance for the 
simplification (all vertices in the simplified geometry will be within this 
distance of the original geometry)
   */
 case class ST_SimplifyPreserveTopology(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.simplifyPreserveTopology) with 
FoldableExpression {
+  extends InferredExpression(Functions.simplifyPreserveTopology _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -301,7 +302,7 @@ case class ST_SimplifyPreserveTopology(inputExpressions: 
Seq[Expression])
   *                         be rounded to the nearest number.
   */
 case class ST_PrecisionReduce(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.reducePrecision) with 
FoldableExpression {
+  extends InferredExpression(Functions.reducePrecision _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -309,7 +310,7 @@ case class ST_PrecisionReduce(inputExpressions: 
Seq[Expression])
 }
 
 case class ST_AsText(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.asWKT) with FoldableExpression {
+  extends InferredExpression(Functions.asWKT _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -317,7 +318,7 @@ case class ST_AsText(inputExpressions: Seq[Expression])
 }
 
 case class ST_AsGeoJSON(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.asGeoJson) with FoldableExpression 
{
+  extends InferredExpression(Functions.asGeoJson _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -325,7 +326,7 @@ case class ST_AsGeoJSON(inputExpressions: Seq[Expression])
 }
 
 case class ST_AsBinary(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.asWKB) with FoldableExpression {
+  extends InferredExpression(Functions.asWKB _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -333,7 +334,7 @@ case class ST_AsBinary(inputExpressions: Seq[Expression])
 }
 
 case class ST_AsEWKB(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.asEWKB) with FoldableExpression {
+  extends InferredExpression(Functions.asEWKB _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -341,7 +342,7 @@ case class ST_AsEWKB(inputExpressions: Seq[Expression])
 }
 
 case class ST_SRID(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.getSRID) with FoldableExpression {
+  extends InferredExpression(Functions.getSRID _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -349,7 +350,7 @@ case class ST_SRID(inputExpressions: Seq[Expression])
 }
 
 case class ST_SetSRID(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.setSRID) with FoldableExpression {
+  extends InferredExpression(Functions.setSRID _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -357,7 +358,7 @@ case class ST_SetSRID(inputExpressions: Seq[Expression])
 }
 
 case class ST_GeometryType(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.geometryType) with 
FoldableExpression {
+  extends InferredExpression(Functions.geometryType _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -372,7 +373,7 @@ case class ST_GeometryType(inputExpressions: 
Seq[Expression])
   * @param inputExpressions Geometry
   */
 case class ST_LineMerge(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.lineMerge) with FoldableExpression 
{
+  extends InferredExpression(Functions.lineMerge _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -380,7 +381,7 @@ case class ST_LineMerge(inputExpressions: Seq[Expression])
 }
 
 case class ST_Azimuth(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.azimuth) with FoldableExpression {
+  extends InferredExpression(Functions.azimuth _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -388,7 +389,7 @@ case class ST_Azimuth(inputExpressions: Seq[Expression])
 }
 
 case class ST_X(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.x) with FoldableExpression {
+  extends InferredExpression(Functions.x _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -397,7 +398,7 @@ case class ST_X(inputExpressions: Seq[Expression])
 
 
 case class ST_Y(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.y) with FoldableExpression {
+  extends InferredExpression(Functions.y _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -405,7 +406,7 @@ case class ST_Y(inputExpressions: Seq[Expression])
 }
 
 case class ST_Z(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.z) with FoldableExpression {
+  extends InferredExpression(Functions.z _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -413,7 +414,7 @@ case class ST_Z(inputExpressions: Seq[Expression])
 }
 
 case class ST_StartPoint(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.startPoint) with 
FoldableExpression {
+  extends InferredExpression(Functions.startPoint _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -421,7 +422,7 @@ case class ST_StartPoint(inputExpressions: Seq[Expression])
 }
 
 case class ST_Boundary(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.boundary) with FoldableExpression {
+  extends InferredExpression(Functions.boundary _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -471,7 +472,7 @@ case class ST_MinimumBoundingRadius(inputExpressions: 
Seq[Expression])
 
 
 case class ST_MinimumBoundingCircle(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.minimumBoundingCircle) with 
FoldableExpression {
+  extends InferredExpression(Functions.minimumBoundingCircle _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -486,7 +487,7 @@ case class ST_MinimumBoundingCircle(inputExpressions: 
Seq[Expression])
  * @param inputExpressions
  */
 case class ST_LineSubstring(inputExpressions: Seq[Expression])
-  extends InferredTernaryExpression(Functions.lineSubString) with 
FoldableExpression {
+  extends InferredExpression(Functions.lineSubString _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -501,7 +502,7 @@ case class ST_LineSubstring(inputExpressions: 
Seq[Expression])
  * @param inputExpressions
  */
 case class ST_LineInterpolatePoint(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.lineInterpolatePoint) with 
FoldableExpression {
+  extends InferredExpression(Functions.lineInterpolatePoint _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -509,7 +510,7 @@ case class ST_LineInterpolatePoint(inputExpressions: 
Seq[Expression])
 }
 
 case class ST_EndPoint(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.endPoint) with FoldableExpression {
+  extends InferredExpression(Functions.endPoint _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -517,7 +518,7 @@ case class ST_EndPoint(inputExpressions: Seq[Expression])
 }
 
 case class ST_ExteriorRing(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.exteriorRing) with 
FoldableExpression {
+  extends InferredExpression(Functions.exteriorRing _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -526,7 +527,7 @@ case class ST_ExteriorRing(inputExpressions: 
Seq[Expression])
 
 
 case class ST_GeometryN(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.geometryN) with 
FoldableExpression {
+  extends InferredExpression(Functions.geometryN _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -534,7 +535,7 @@ case class ST_GeometryN(inputExpressions: Seq[Expression])
 }
 
 case class ST_InteriorRingN(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.interiorRingN) with 
FoldableExpression {
+  extends InferredExpression(Functions.interiorRingN _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -542,7 +543,7 @@ case class ST_InteriorRingN(inputExpressions: 
Seq[Expression])
 }
 
 case class ST_Dump(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.dump) with FoldableExpression {
+  extends InferredExpression(Functions.dump _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -550,7 +551,7 @@ case class ST_Dump(inputExpressions: Seq[Expression])
 }
 
 case class ST_DumpPoints(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.dumpPoints) with 
FoldableExpression {
+  extends InferredExpression(Functions.dumpPoints _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -559,7 +560,7 @@ case class ST_DumpPoints(inputExpressions: Seq[Expression])
 
 
 case class ST_IsClosed(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.isClosed) with FoldableExpression {
+  extends InferredExpression(Functions.isClosed _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -567,7 +568,7 @@ case class ST_IsClosed(inputExpressions: Seq[Expression])
 }
 
 case class ST_NumInteriorRings(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.numInteriorRings) with 
FoldableExpression {
+  extends InferredExpression(Functions.numInteriorRings _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -575,7 +576,7 @@ case class ST_NumInteriorRings(inputExpressions: 
Seq[Expression])
 }
 
 case class ST_AddPoint(inputExpressions: Seq[Expression])
-  extends InferredTernaryExpression(Functions.addPoint) with 
FoldableExpression {
+  extends InferredExpression(inferrableFunction3(Functions.addPoint)) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -583,7 +584,7 @@ case class ST_AddPoint(inputExpressions: Seq[Expression])
 }
 
 case class ST_RemovePoint(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.removePoint) with 
FoldableExpression {
+  extends InferredExpression(inferrableFunction2(Functions.removePoint)) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -591,7 +592,7 @@ case class ST_RemovePoint(inputExpressions: Seq[Expression])
 }
 
 case class ST_SetPoint(inputExpressions: Seq[Expression])
-  extends InferredTernaryExpression(Functions.setPoint) with 
FoldableExpression {
+  extends InferredExpression(Functions.setPoint _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -599,24 +600,22 @@ case class ST_SetPoint(inputExpressions: Seq[Expression])
 }
 
 case class ST_IsRing(inputExpressions: Seq[Expression])
-  extends UnaryGeometryExpression with FoldableExpression with CodegenFallback 
{
-
-  override protected def nullSafeEval(geometry: Geometry): Any = {
-    geometry match {
-      case string: LineString => Functions.isRing(string)
-      case _ => null
-    }
-  }
-
-  override def dataType: DataType = BooleanType
-
-  override def children: Seq[Expression] = inputExpressions
+  extends InferredExpression(ST_IsRing.isRing _) { // evaluation now delegates to the companion's isRing, whose Option[Boolean] result replaces the removed nullSafeEval (None where the old code returned null)
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
 }
 
+object ST_IsRing { // companion holding the function that the ST_IsRing expression wraps
+  def isRing(geom: Geometry): Option[Boolean] = { // ring test restricted to LineString input; None signals "no answer"
+    geom match {
+      case _: LineString => Some(Functions.isRing(geom)) // only line strings are forwarded to Functions.isRing
+      case _ => None // every other input (any non-LineString value) produces None
+    }
+  }
+}
+
 /**
   * Returns the number of Geometries. If geometry is a GEOMETRYCOLLECTION (or 
MULTI*) return the number of geometries,
   * for single geometries will return 1
@@ -626,7 +625,7 @@ case class ST_IsRing(inputExpressions: Seq[Expression])
   * @param inputExpressions Geometry
   */
 case class ST_NumGeometries(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.numGeometries) with 
FoldableExpression {
+  extends InferredExpression(Functions.numGeometries _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -639,7 +638,7 @@ case class ST_NumGeometries(inputExpressions: 
Seq[Expression])
   * @param inputExpressions Geometry
   */
 case class ST_FlipCoordinates(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.flipCoordinates) with 
FoldableExpression {
+  extends InferredExpression(Functions.flipCoordinates _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -647,7 +646,7 @@ case class ST_FlipCoordinates(inputExpressions: 
Seq[Expression])
 }
 
 case class ST_SubDivide(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.subDivide) with 
FoldableExpression {
+  extends InferredExpression(Functions.subDivide _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -680,17 +679,15 @@ case class ST_SubDivideExplode(children: Seq[Expression])
 }
 
 case class ST_MakePolygon(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.makePolygon) with 
FoldableExpression {
+  extends 
InferredExpression(InferrableFunction.allowRightNull(Functions.makePolygon)) { // the allowRightNull wrapper takes over from the removed 'override def allowRightNull: Boolean = true'; NOTE(review): presumably lets the second argument be null - confirm in InferrableFunction
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def allowRightNull: Boolean = true
 }
 
 case class ST_GeoHash(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.geohash) with FoldableExpression {
+  extends InferredExpression(Functions.geohash _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -703,7 +700,7 @@ case class ST_GeoHash(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_Difference(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.difference) with 
FoldableExpression {
+  extends InferredExpression(Functions.difference _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -716,7 +713,7 @@ case class ST_Difference(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_SymDifference(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.symDifference) with 
FoldableExpression {
+  extends InferredExpression(Functions.symDifference _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -729,7 +726,7 @@ case class ST_SymDifference(inputExpressions: 
Seq[Expression])
  * @param inputExpressions
  */
 case class ST_Union(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.union) with FoldableExpression {
+  extends InferredExpression(Functions.union _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -737,7 +734,7 @@ case class ST_Union(inputExpressions: Seq[Expression])
 }
 
 case class ST_Multi(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.createMultiGeometryFromOneElement) 
with FoldableExpression {
+  extends InferredExpression(Functions.createMultiGeometryFromOneElement _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -750,7 +747,7 @@ case class ST_Multi(inputExpressions: Seq[Expression])
  * @param inputExpressions Geometry
  */
 case class ST_PointOnSurface(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.pointOnSurface) with 
FoldableExpression {
+  extends InferredExpression(Functions.pointOnSurface _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -763,7 +760,7 @@ case class ST_PointOnSurface(inputExpressions: 
Seq[Expression])
  * @param inputExpressions
  */
 case class ST_Reverse(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.reverse) with FoldableExpression {
+  extends InferredExpression(Functions.reverse _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -776,7 +773,7 @@ case class ST_Reverse(inputExpressions: Seq[Expression])
  * @param inputExpressions sequence of 2 input arguments, a geometry and a 
value 'n'
  */
 case class ST_PointN(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.pointN) with FoldableExpression {
+  extends InferredExpression(Functions.pointN _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
       copy(inputExpressions = newChildren)
@@ -789,7 +786,7 @@ case class ST_PointN(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_Force_2D(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.force2D) with FoldableExpression {
+  extends InferredExpression(Functions.force2D _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -802,7 +799,7 @@ case class ST_Force_2D(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_AsEWKT(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.asEWKT) with FoldableExpression {
+  extends InferredExpression(Functions.asEWKT _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -810,7 +807,7 @@ case class ST_AsEWKT(inputExpressions: Seq[Expression])
 }
 
 case class ST_AsGML(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.asGML) with FoldableExpression {
+  extends InferredExpression(Functions.asGML _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -818,7 +815,7 @@ case class ST_AsGML(inputExpressions: Seq[Expression])
 }
 
 case class ST_AsKML(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.asKML) with FoldableExpression {
+  extends InferredExpression(Functions.asKML _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -831,7 +828,7 @@ case class ST_AsKML(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_IsEmpty(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.isEmpty) with FoldableExpression {
+  extends InferredExpression(Functions.isEmpty _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -844,7 +841,7 @@ case class ST_IsEmpty(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_XMax(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.xMax) with FoldableExpression {
+  extends InferredExpression(Functions.xMax _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -857,7 +854,7 @@ case class ST_XMax(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_XMin(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.xMin) with FoldableExpression {
+  extends InferredExpression(Functions.xMin _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -871,7 +868,7 @@ case class ST_XMin(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_BuildArea(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.buildArea) with FoldableExpression 
{
+  extends InferredExpression(Functions.buildArea _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression = {
     copy(inputExpressions = newChildren)
@@ -884,7 +881,7 @@ case class ST_BuildArea(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_Normalize(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.normalize) with FoldableExpression 
{
+  extends InferredExpression(Functions.normalize _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression = {
     copy(inputExpressions = newChildren)
@@ -897,7 +894,7 @@ case class ST_Normalize(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_LineFromMultiPoint(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.lineFromMultiPoint) with 
FoldableExpression {
+  extends InferredExpression(Functions.lineFromMultiPoint _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -910,7 +907,7 @@ case class ST_LineFromMultiPoint(inputExpressions: 
Seq[Expression])
  * @param inputExpressions
  */
 case class ST_Split(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.split) with FoldableExpression {
+  extends InferredExpression(Functions.split _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -918,7 +915,7 @@ case class ST_Split(inputExpressions: Seq[Expression])
 }
 
 case class ST_S2CellIDs(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.s2CellIDs) with 
FoldableExpression {
+  extends InferredExpression(Functions.s2CellIDs _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -926,13 +923,11 @@ case class ST_S2CellIDs(inputExpressions: Seq[Expression])
 }
 
 case class ST_CollectionExtract(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.collectionExtract) with 
FoldableExpression {
+  extends 
InferredExpression(InferrableFunction.allowRightNull(Functions.collectionExtract))
 { // the allowRightNull wrapper takes over from the removed 'override def allowRightNull: Boolean = true'; NOTE(review): presumably lets the second argument be null - confirm in InferrableFunction
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def allowRightNull: Boolean = true
 }
 
 /**
@@ -942,7 +937,7 @@ case class ST_CollectionExtract(inputExpressions: 
Seq[Expression])
  * @param inputExpressions Geometry
  */
 case class ST_GeometricMedian(inputExpressions: Seq[Expression])
-  extends InferredQuarternaryExpression(Functions.geometricMedian) with 
FoldableExpression {
+  extends InferredExpression(inferrableFunction4(Functions.geometricMedian)) 
with FoldableExpression {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -950,7 +945,7 @@ case class ST_GeometricMedian(inputExpressions: 
Seq[Expression])
 }
 
 case class ST_DistanceSphere(inputExpressions: Seq[Expression])
-  extends InferredTernaryExpression(Haversine.distance) with 
FoldableExpression {
+  extends InferredExpression(inferrableFunction3(Haversine.distance)) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -958,7 +953,7 @@ case class ST_DistanceSphere(inputExpressions: 
Seq[Expression])
 }
 
 case class ST_DistanceSpheroid(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Spheroid.distance) with FoldableExpression {
+  extends InferredExpression(Spheroid.distance _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -966,7 +961,7 @@ case class ST_DistanceSpheroid(inputExpressions: 
Seq[Expression])
 }
 
 case class ST_AreaSpheroid(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Spheroid.area) with FoldableExpression {
+  extends InferredExpression(Spheroid.area _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -974,7 +969,7 @@ case class ST_AreaSpheroid(inputExpressions: 
Seq[Expression])
 }
 
 case class ST_LengthSpheroid(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Spheroid.length) with FoldableExpression {
+  extends InferredExpression(Spheroid.length _) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -982,14 +977,14 @@ case class ST_LengthSpheroid(inputExpressions: 
Seq[Expression])
 }
 
 case class ST_NumPoints(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.numPoints) with FoldableExpression 
{
+  extends InferredExpression(Functions.numPoints _) { // 'numPoints _' passes the method as a function value; NOTE(review): presumably turned into an InferrableFunction by an implicit from InferrableFunctionConverter - confirm the import
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
 }
 
 case class ST_Force3D(inputExpressions: Seq[Expression])
-  extends InferredBinaryExpression(Functions.force3D) with FoldableExpression {
+  extends InferredExpression(inferrableFunction2(Functions.force3D)) {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
@@ -997,14 +992,14 @@ case class ST_Force3D(inputExpressions: Seq[Expression])
 }
 
 case class ST_NRings(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.nRings) with FoldableExpression {
+  extends InferredExpression(Functions.nRings _) { // 'nRings _' passes the method as a function value; NOTE(review): presumably turned into an InferrableFunction by an implicit from InferrableFunctionConverter - confirm the import
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
 }
 
 case class ST_Translate(inputExpressions: Seq[Expression])
-  extends InferredQuarternaryExpression(Functions.translate) with 
FoldableExpression {
+  extends InferredExpression(inferrableFunction4(Functions.translate)) with 
FoldableExpression {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
@@ -1018,14 +1013,14 @@ case class ST_Dimension(inputExpressions: 
Seq[Expression])
 }
 
 case class ST_BoundingDiagonal(inputExpressions: Seq[Expression])
-  extends InferredUnaryExpression(Functions.boundingDiagonal) with 
FoldableExpression {
+  extends InferredExpression(Functions.boundingDiagonal _) { // 'boundingDiagonal _' passes the method as a function value; NOTE(review): presumably turned into an InferrableFunction by an implicit from InferrableFunctionConverter - confirm the import
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
 }
 
 case class ST_HausdorffDistance(inputExpressions: Seq[Expression])
-  extends InferredTernaryExpression(Functions.hausdorffDistance) with 
FoldableExpression {
+  extends InferredExpression(inferrableFunction3(Functions.hausdorffDistance)) 
{
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableFunctionConverter.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableFunctionConverter.scala
new file mode 100644
index 00000000..83156db9
--- /dev/null
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableFunctionConverter.scala
@@ -0,0 +1,494 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.sedona_sql.expressions
+
+import scala.reflect.runtime.universe.TypeTag
+
+/**
+ * Implicit conversions from Java/Scala functions to [[InferrableFunction]]. 
This should be used in conjunction with
+ * [[InferredExpression]] to make wrapping Java/Scala functions as catalyst 
expressions much easier.
+ */
+object InferrableFunctionConverter {
+  // scalastyle:off line.size.limit
+  implicit def inferrableFunction1[R: InferrableType, A1: InferrableType](f: 
(A1) => R)(implicit typeTag: TypeTag[(A1) => R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => { // implicit conversion of a 1-argument function into an InferrableFunction; the builder lambda receives the argument extractors
+    val func = f.asInstanceOf[(Any) => Any] // erase the static types so the extracted value can be passed through unchanged
+    val extractor1 = argExtractors(0) // looked up once here, outside the per-input closure
+    input => { // closure handed to InferrableFunction; runs the wrapped function against one input
+      val arg1 = extractor1(input)
+      if (arg1 != null) { // null-propagating: f is invoked only for a non-null argument
+        func(arg1)
+      } else {
+        null // a null argument short-circuits to a null result without calling f
+      }
+    }
+  })
+
+  implicit def inferrableFunction2[R: InferrableType, A1: InferrableType, A2: 
InferrableType](f: (A1, A2) => R)(implicit typeTag: TypeTag[(A1, A2) => R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => { // implicit conversion of a 2-argument function into an InferrableFunction; the builder lambda receives one extractor per argument
+    val func = f.asInstanceOf[(Any, Any) => Any] // erase the static types so extracted values can be passed through unchanged
+    val extractor1 = argExtractors(0) // extractors are looked up once here, outside the per-input closure
+    val extractor2 = argExtractors(1)
+    input => { // closure handed to InferrableFunction; runs the wrapped function against one input
+      val arg1 = extractor1(input) // every argument is extracted before the null check below
+      val arg2 = extractor2(input)
+      if (arg1 != null && arg2 != null) { // null-propagating: f is invoked only when both arguments are non-null
+        func(arg1, arg2)
+      } else {
+        null // any null argument yields a null result without calling f
+      }
+    }
+  })
+
+  implicit def inferrableFunction3[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType](f: (A1, A2, A3) => R)(implicit typeTag: 
TypeTag[(A1, A2, A3) => R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => { // implicit conversion of a 3-argument function into an InferrableFunction; the builder lambda receives one extractor per argument
+    val func = f.asInstanceOf[(Any, Any, Any) => Any] // erase the static types so extracted values can be passed through unchanged
+    val extractor1 = argExtractors(0) // extractors are looked up once here, outside the per-input closure
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    input => { // closure handed to InferrableFunction; runs the wrapped function against one input
+      val arg1 = extractor1(input) // every argument is extracted before the null check below
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      if (arg1 != null && arg2 != null && arg3 != null) { // null-propagating: f is invoked only when all three arguments are non-null
+        func(arg1, arg2, arg3)
+      } else {
+        null // any null argument yields a null result without calling f
+      }
+    }
+  })
+
+  implicit def inferrableFunction4[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType](f: (A1, A2, A3, A4) => 
R)(implicit typeTag: TypeTag[(A1, A2, A3, A4) => R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => { // implicit conversion of a 4-argument function into an InferrableFunction; the builder lambda receives one extractor per argument
+    val func = f.asInstanceOf[(Any, Any, Any, Any) => Any] // erase the static types so extracted values can be passed through unchanged
+    val extractor1 = argExtractors(0) // extractors are looked up once here, outside the per-input closure
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    val extractor4 = argExtractors(3)
+    input => { // closure handed to InferrableFunction; runs the wrapped function against one input
+      val arg1 = extractor1(input) // every argument is extracted before the null check below
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      val arg4 = extractor4(input)
+      if (arg1 != null && arg2 != null && arg3 != null && arg4 != null) { // null-propagating: f is invoked only when all four arguments are non-null
+        func(arg1, arg2, arg3, arg4)
+      } else {
+        null // any null argument yields a null result without calling f
+      }
+    }
+  })
+
+  implicit def inferrableFunction5[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType](f: 
(A1, A2, A3, A4, A5) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5) => R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => { // implicit conversion of a 5-argument function into an InferrableFunction; the builder lambda receives one extractor per argument
+    val func = f.asInstanceOf[(Any, Any, Any, Any, Any) => Any] // erase the static types so extracted values can be passed through unchanged
+    val extractor1 = argExtractors(0) // extractors are looked up once here, outside the per-input closure
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    val extractor4 = argExtractors(3)
+    val extractor5 = argExtractors(4)
+    input => { // closure handed to InferrableFunction; runs the wrapped function against one input
+      val arg1 = extractor1(input) // every argument is extracted before the null check below
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      val arg4 = extractor4(input)
+      val arg5 = extractor5(input)
+      if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 
!= null) { // null-propagating: f is invoked only when all five arguments are non-null
+        func(arg1, arg2, arg3, arg4, arg5)
+      } else {
+        null // any null argument yields a null result without calling f
+      }
+    }
+  })
+
+  implicit def inferrableFunction6[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: 
InferrableType](f: (A1, A2, A3, A4, A5, A6) => R)(implicit typeTag: 
TypeTag[(A1, A2, A3, A4, A5, A6) => R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => { // implicit conversion of a 6-argument function into an InferrableFunction; the builder lambda receives one extractor per argument
+    val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any] // erase the static types so extracted values can be passed through unchanged
+    val extractor1 = argExtractors(0) // extractors are looked up once here, outside the per-input closure
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    val extractor4 = argExtractors(3)
+    val extractor5 = argExtractors(4)
+    val extractor6 = argExtractors(5)
+    input => { // closure handed to InferrableFunction; runs the wrapped function against one input
+      val arg1 = extractor1(input) // every argument is extracted before the null check below
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      val arg4 = extractor4(input)
+      val arg5 = extractor5(input)
+      val arg6 = extractor6(input)
+      if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 
!= null && arg6 != null) { // null-propagating: f is invoked only when all six arguments are non-null
+        func(arg1, arg2, arg3, arg4, arg5, arg6)
+      } else {
+        null // any null argument yields a null result without calling f
+      }
+    }
+  })
+
+  implicit def inferrableFunction7[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: 
InferrableType, A7: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7) => 
R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7) => R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => { // implicit conversion of a 7-argument function into an InferrableFunction; the builder lambda receives one extractor per argument
+    val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any] // erase the static types so extracted values can be passed through unchanged
+    val extractor1 = argExtractors(0) // extractors are looked up once here, outside the per-input closure
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    val extractor4 = argExtractors(3)
+    val extractor5 = argExtractors(4)
+    val extractor6 = argExtractors(5)
+    val extractor7 = argExtractors(6)
+    input => { // closure handed to InferrableFunction; runs the wrapped function against one input
+      val arg1 = extractor1(input) // every argument is extracted before the null check below
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      val arg4 = extractor4(input)
+      val arg5 = extractor5(input)
+      val arg6 = extractor6(input)
+      val arg7 = extractor7(input)
+      if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 
!= null && arg6 != null && arg7 != null) { // null-propagating: f is invoked only when all seven arguments are non-null
+        func(arg1, arg2, arg3, arg4, arg5, arg6, arg7)
+      } else {
+        null // any null argument yields a null result without calling f
+      }
+    }
+  })
+
+  implicit def inferrableFunction8[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: 
InferrableType, A7: InferrableType, A8: InferrableType](f: (A1, A2, A3, A4, A5, 
A6, A7, A8) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8) => 
R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => { // implicit conversion of an 8-argument function into an InferrableFunction; the builder lambda receives one extractor per argument
+    val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any] // erase the static types so extracted values can be passed through unchanged
+    val extractor1 = argExtractors(0) // extractors are looked up once here, outside the per-input closure
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    val extractor4 = argExtractors(3)
+    val extractor5 = argExtractors(4)
+    val extractor6 = argExtractors(5)
+    val extractor7 = argExtractors(6)
+    val extractor8 = argExtractors(7)
+    input => { // closure handed to InferrableFunction; runs the wrapped function against one input
+      val arg1 = extractor1(input) // every argument is extracted before the null check below
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      val arg4 = extractor4(input)
+      val arg5 = extractor5(input)
+      val arg6 = extractor6(input)
+      val arg7 = extractor7(input)
+      val arg8 = extractor8(input)
+      if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 
!= null && arg6 != null && arg7 != null && arg8 != null) { // null-propagating: f is invoked only when all eight arguments are non-null
+        func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8)
+      } else {
+        null // any null argument yields a null result without calling f
+      }
+    }
+  })
+
+  implicit def inferrableFunction9[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: 
InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType](f: 
(A1, A2, A3, A4, A5, A6, A7, A8, A9) => R)(implicit typeTag: TypeTag[(A1, A2, 
A3, A4, A5, A6, A7, A8, A9) => R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => { // implicit conversion of a 9-argument function into an InferrableFunction; the builder lambda receives one extractor per argument
+    val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => 
Any] // erase the static types so extracted values can be passed through unchanged
+    val extractor1 = argExtractors(0) // extractors are looked up once here, outside the per-input closure
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    val extractor4 = argExtractors(3)
+    val extractor5 = argExtractors(4)
+    val extractor6 = argExtractors(5)
+    val extractor7 = argExtractors(6)
+    val extractor8 = argExtractors(7)
+    val extractor9 = argExtractors(8)
+    input => { // closure handed to InferrableFunction; runs the wrapped function against one input
+      val arg1 = extractor1(input) // every argument is extracted before the null check below
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      val arg4 = extractor4(input)
+      val arg5 = extractor5(input)
+      val arg6 = extractor6(input)
+      val arg7 = extractor7(input)
+      val arg8 = extractor8(input)
+      val arg9 = extractor9(input)
+      if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 
!= null && arg6 != null && arg7 != null && arg8 != null && arg9 != null) { // null-propagating: f is invoked only when all nine arguments are non-null
+        func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9)
+      } else {
+        null // any null argument yields a null result without calling f
+      }
+    }
+  })
+
+  implicit def inferrableFunction10[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: 
InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, 
A10: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => 
R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => { // implicit conversion of a 10-argument function into an InferrableFunction; the builder lambda receives one extractor per argument
+    val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any) => Any] // erase the static types so extracted values can be passed through unchanged
+    val extractor1 = argExtractors(0) // extractors are looked up once here, outside the per-input closure
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    val extractor4 = argExtractors(3)
+    val extractor5 = argExtractors(4)
+    val extractor6 = argExtractors(5)
+    val extractor7 = argExtractors(6)
+    val extractor8 = argExtractors(7)
+    val extractor9 = argExtractors(8)
+    val extractor10 = argExtractors(9)
+    input => { // closure handed to InferrableFunction; runs the wrapped function against one input
+      val arg1 = extractor1(input) // every argument is extracted before the null check below
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      val arg4 = extractor4(input)
+      val arg5 = extractor5(input)
+      val arg6 = extractor6(input)
+      val arg7 = extractor7(input)
+      val arg8 = extractor8(input)
+      val arg9 = extractor9(input)
+      val arg10 = extractor10(input)
+      if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 
!= null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && 
arg10 != null) { // null-propagating: f is invoked only when all ten arguments are non-null
+        func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10)
+      } else {
+        null // any null argument yields a null result without calling f
+      }
+    }
+  })
+
+  implicit def inferrableFunction11[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: 
InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, 
A10: InferrableType, A11: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8, 
A9, A10, A11) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, 
A9, A10, A11) => R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => { // implicit conversion of an 11-argument function into an InferrableFunction; the builder lambda receives one extractor per argument
+    val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any, Any) => Any] // erase the static types so extracted values can be passed through unchanged
+    val extractor1 = argExtractors(0) // extractors are looked up once here, outside the per-input closure
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    val extractor4 = argExtractors(3)
+    val extractor5 = argExtractors(4)
+    val extractor6 = argExtractors(5)
+    val extractor7 = argExtractors(6)
+    val extractor8 = argExtractors(7)
+    val extractor9 = argExtractors(8)
+    val extractor10 = argExtractors(9)
+    val extractor11 = argExtractors(10)
+    input => { // closure handed to InferrableFunction; runs the wrapped function against one input
+      val arg1 = extractor1(input) // every argument is extracted before the null check below
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      val arg4 = extractor4(input)
+      val arg5 = extractor5(input)
+      val arg6 = extractor6(input)
+      val arg7 = extractor7(input)
+      val arg8 = extractor8(input)
+      val arg9 = extractor9(input)
+      val arg10 = extractor10(input)
+      val arg11 = extractor11(input)
+      if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 
!= null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && 
arg10 != null && arg11 != null) { // null-propagating: f is invoked only when all eleven arguments are non-null
+        func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, 
arg11)
+      } else {
+        null // any null argument yields a null result without calling f
+      }
+    }
+  })
+
+  implicit def inferrableFunction12[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: 
InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, 
A10: InferrableType, A11: InferrableType, A12: InferrableType](f: (A1, A2, A3, 
A4, A5, A6, A7, A8, A9, A10, A11, A12) => R)(implicit typeTag: TypeTag[(A1, A2, 
A3, A4, A5, A6, A7, A8, A9, A10, A11, A12) => R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => { // implicit conversion of a 12-argument function into an InferrableFunction; the builder lambda receives one extractor per argument
+    val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any, Any, Any) => Any] // erase the static types so extracted values can be passed through unchanged
+    val extractor1 = argExtractors(0) // extractors are looked up once here, outside the per-input closure
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    val extractor4 = argExtractors(3)
+    val extractor5 = argExtractors(4)
+    val extractor6 = argExtractors(5)
+    val extractor7 = argExtractors(6)
+    val extractor8 = argExtractors(7)
+    val extractor9 = argExtractors(8)
+    val extractor10 = argExtractors(9)
+    val extractor11 = argExtractors(10)
+    val extractor12 = argExtractors(11)
+    input => { // closure handed to InferrableFunction; runs the wrapped function against one input
+      val arg1 = extractor1(input) // every argument is extracted before the null check below
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      val arg4 = extractor4(input)
+      val arg5 = extractor5(input)
+      val arg6 = extractor6(input)
+      val arg7 = extractor7(input)
+      val arg8 = extractor8(input)
+      val arg9 = extractor9(input)
+      val arg10 = extractor10(input)
+      val arg11 = extractor11(input)
+      val arg12 = extractor12(input)
+      if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 
!= null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && 
arg10 != null && arg11 != null && arg12 != null) { // null-propagating: f is invoked only when all twelve arguments are non-null
+        func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, 
arg11, arg12)
+      } else {
+        null // any null argument yields a null result without calling f
+      }
+    }
+  })
+
+  /**
+   * Implicitly wraps a 13-argument function as an [[InferrableFunction]].
+   *
+   * The produced evaluator extracts all thirteen arguments from the input row and
+   * calls `f` only when none of them is null; otherwise it yields null.
+   */
+  implicit def inferrableFunction13[R: InferrableType, A1: InferrableType, A2: InferrableType,
+    A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: InferrableType,
+    A7: InferrableType, A8: InferrableType, A9: InferrableType, A10: InferrableType,
+    A11: InferrableType, A12: InferrableType, A13: InferrableType](
+    f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13) => R)(
+    implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13) => R])
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => {
+    val func =
+      f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+    // Look every positional extractor up once, outside the per-row closure.
+    val get1 = argExtractors(0)
+    val get2 = argExtractors(1)
+    val get3 = argExtractors(2)
+    val get4 = argExtractors(3)
+    val get5 = argExtractors(4)
+    val get6 = argExtractors(5)
+    val get7 = argExtractors(6)
+    val get8 = argExtractors(7)
+    val get9 = argExtractors(8)
+    val get10 = argExtractors(9)
+    val get11 = argExtractors(10)
+    val get12 = argExtractors(11)
+    val get13 = argExtractors(12)
+    row => {
+      val v1 = get1(row)
+      val v2 = get2(row)
+      val v3 = get3(row)
+      val v4 = get4(row)
+      val v5 = get5(row)
+      val v6 = get6(row)
+      val v7 = get7(row)
+      val v8 = get8(row)
+      val v9 = get9(row)
+      val v10 = get10(row)
+      val v11 = get11(row)
+      val v12 = get12(row)
+      val v13 = get13(row)
+      // All arguments are extracted before the null check, as in the unrolled siblings.
+      val anyNull = v1 == null || v2 == null || v3 == null || v4 == null || v5 == null ||
+        v6 == null || v7 == null || v8 == null || v9 == null || v10 == null ||
+        v11 == null || v12 == null || v13 == null
+      if (anyNull) null else func(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13)
+    }
+  })
+
+  /**
+   * Implicitly wraps a 14-argument function as an [[InferrableFunction]].
+   *
+   * The produced evaluator extracts all fourteen arguments from the input row and
+   * calls `f` only when none of them is null; otherwise it yields null.
+   */
+  implicit def inferrableFunction14[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: 
InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, 
A10: InferrableType, A11: InferrableType, A12: InferrableType, A13: 
InferrableType, A14: InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, 
A10, A11, A12, A13, A14) => R)(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, 
A6, A7, A8, A9, A10, A11, A12, A13 [...]
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => {
+    val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any, Any, Any, Any, Any) => Any]
+    // Look every positional extractor up once, outside the per-row closure.
+    val extractor1 = argExtractors(0)
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    val extractor4 = argExtractors(3)
+    val extractor5 = argExtractors(4)
+    val extractor6 = argExtractors(5)
+    val extractor7 = argExtractors(6)
+    val extractor8 = argExtractors(7)
+    val extractor9 = argExtractors(8)
+    val extractor10 = argExtractors(9)
+    val extractor11 = argExtractors(10)
+    val extractor12 = argExtractors(11)
+    val extractor13 = argExtractors(12)
+    val extractor14 = argExtractors(13)
+    input => {
+      // Every argument is extracted before any null check is made.
+      val arg1 = extractor1(input)
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      val arg4 = extractor4(input)
+      val arg5 = extractor5(input)
+      val arg6 = extractor6(input)
+      val arg7 = extractor7(input)
+      val arg8 = extractor8(input)
+      val arg9 = extractor9(input)
+      val arg10 = extractor10(input)
+      val arg11 = extractor11(input)
+      val arg12 = extractor12(input)
+      val arg13 = extractor13(input)
+      val arg14 = extractor14(input)
+      // The wrapped function is skipped as soon as one argument is null.
+      if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 
!= null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && 
arg10 != null && arg11 != null && arg12 != null && arg13 != null && arg14 != 
null) {
+        func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, 
arg11, arg12, arg13, arg14)
+      } else {
+        null
+      }
+    }
+  })
+
+  /**
+   * Implicitly wraps a 15-argument function as an [[InferrableFunction]].
+   *
+   * The produced evaluator extracts all fifteen arguments from the input row and
+   * calls `f` only when none of them is null; otherwise it yields null.
+   */
+  implicit def inferrableFunction15[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: 
InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, 
A10: InferrableType, A11: InferrableType, A12: InferrableType, A13: 
InferrableType, A14: InferrableType, A15: InferrableType](f: (A1, A2, A3, A4, 
A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15) => R)(implicit typeTag: 
TypeTag[(A1, A2, A3, A4, A5, A6, A7,  [...]
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => {
+    val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any, Any, Any, Any, Any, Any) => Any]
+    // Look every positional extractor up once, outside the per-row closure.
+    val extractor1 = argExtractors(0)
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    val extractor4 = argExtractors(3)
+    val extractor5 = argExtractors(4)
+    val extractor6 = argExtractors(5)
+    val extractor7 = argExtractors(6)
+    val extractor8 = argExtractors(7)
+    val extractor9 = argExtractors(8)
+    val extractor10 = argExtractors(9)
+    val extractor11 = argExtractors(10)
+    val extractor12 = argExtractors(11)
+    val extractor13 = argExtractors(12)
+    val extractor14 = argExtractors(13)
+    val extractor15 = argExtractors(14)
+    input => {
+      // Every argument is extracted before any null check is made.
+      val arg1 = extractor1(input)
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      val arg4 = extractor4(input)
+      val arg5 = extractor5(input)
+      val arg6 = extractor6(input)
+      val arg7 = extractor7(input)
+      val arg8 = extractor8(input)
+      val arg9 = extractor9(input)
+      val arg10 = extractor10(input)
+      val arg11 = extractor11(input)
+      val arg12 = extractor12(input)
+      val arg13 = extractor13(input)
+      val arg14 = extractor14(input)
+      val arg15 = extractor15(input)
+      // The wrapped function is skipped as soon as one argument is null.
+      if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 
!= null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && 
arg10 != null && arg11 != null && arg12 != null && arg13 != null && arg14 != 
null && arg15 != null) {
+        func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, 
arg11, arg12, arg13, arg14, arg15)
+      } else {
+        null
+      }
+    }
+  })
+
+  /**
+   * Implicitly wraps a 16-argument function as an [[InferrableFunction]].
+   *
+   * The produced evaluator extracts all sixteen arguments from the input row and
+   * calls `f` only when none of them is null; otherwise it yields null.
+   */
+  implicit def inferrableFunction16[R: InferrableType, A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType, A5: InferrableType, A6: 
InferrableType, A7: InferrableType, A8: InferrableType, A9: InferrableType, 
A10: InferrableType, A11: InferrableType, A12: InferrableType, A13: 
InferrableType, A14: InferrableType, A15: InferrableType, A16: 
InferrableType](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, 
A14, A15, A16) => R)(implicit typeTag: TypeTag[(A1 [...]
+  : InferrableFunction = InferrableFunction(typeTag, argExtractors => {
+    val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any) => Any]
+    // Look every positional extractor up once, outside the per-row closure.
+    val extractor1 = argExtractors(0)
+    val extractor2 = argExtractors(1)
+    val extractor3 = argExtractors(2)
+    val extractor4 = argExtractors(3)
+    val extractor5 = argExtractors(4)
+    val extractor6 = argExtractors(5)
+    val extractor7 = argExtractors(6)
+    val extractor8 = argExtractors(7)
+    val extractor9 = argExtractors(8)
+    val extractor10 = argExtractors(9)
+    val extractor11 = argExtractors(10)
+    val extractor12 = argExtractors(11)
+    val extractor13 = argExtractors(12)
+    val extractor14 = argExtractors(13)
+    val extractor15 = argExtractors(14)
+    val extractor16 = argExtractors(15)
+    input => {
+      // Every argument is extracted before any null check is made.
+      val arg1 = extractor1(input)
+      val arg2 = extractor2(input)
+      val arg3 = extractor3(input)
+      val arg4 = extractor4(input)
+      val arg5 = extractor5(input)
+      val arg6 = extractor6(input)
+      val arg7 = extractor7(input)
+      val arg8 = extractor8(input)
+      val arg9 = extractor9(input)
+      val arg10 = extractor10(input)
+      val arg11 = extractor11(input)
+      val arg12 = extractor12(input)
+      val arg13 = extractor13(input)
+      val arg14 = extractor14(input)
+      val arg15 = extractor15(input)
+      val arg16 = extractor16(input)
+      // The wrapped function is skipped as soon as one argument is null.
+      if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && arg5 
!= null && arg6 != null && arg7 != null && arg8 != null && arg9 != null && 
arg10 != null && arg11 != null && arg12 != null && arg13 != null && arg14 != 
null && arg15 != null && arg16 != null) {
+        func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, 
arg11, arg12, arg13, arg14, arg15, arg16)
+      } else {
+        null
+      }
+    }
+  })
+  // scalastyle:on
+}
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
new file mode 100644
index 00000000..a1ce24b9
--- /dev/null
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, 
ImplicitCastInputTypes}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, 
DataType, DataTypes, DoubleType, IntegerType, LongType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+import org.locationtech.jts.geom.Geometry
+import org.apache.spark.sql.sedona_sql.expressions.implicits._
+
+import scala.reflect.runtime.universe.TypeTag
+import scala.reflect.runtime.universe.Type
+import scala.reflect.runtime.universe.typeOf
+
+/**
+ * Base class that exposes a plain Java/Scala function as a catalyst expression in Spark SQL.
+ *
+ * Input types, the result type, argument extraction and result serialization are all taken
+ * from the supplied [[InferrableFunction]].
+ *
+ * @param f The function to be wrapped. Subclasses can simply pass a function to this
+ *          constructor; it is converted to an [[InferrableFunction]] automatically by
+ *          [[InferrableFunctionConverter]].
+ */
+abstract class InferredExpression(f: InferrableFunction)
+  extends Expression
+    with ImplicitCastInputTypes
+    with SerdeAware
+    with CodegenFallback
+    with FoldableExpression
+    with Serializable {
+
+  /** Argument expressions, supplied by the concrete subclass. */
+  def inputExpressions: Seq[Expression]
+
+  override def children: Seq[Expression] = inputExpressions
+  override def inputTypes: Seq[AbstractDataType] = f.sparkInputTypes
+  override def dataType: DataType = f.sparkReturnType
+  override def nullable: Boolean = true
+  override def toString: String = s" **${getClass.getName}**  "
+
+  // Built once per expression instance: one extractor per argument, then the evaluator
+  // that combines them. The evaluator depends on the extractors, so order matters here.
+  private val argExtractors: Array[InternalRow => Any] =
+    f.buildExtractors(inputExpressions)
+  private val evaluator: InternalRow => Any =
+    f.evaluatorBuilder(argExtractors)
+
+  /** Runs the wrapped function and returns its raw (unserialized) result. */
+  override def evalWithoutSerialization(input: InternalRow): Any = evaluator(input)
+
+  /** Runs the wrapped function and passes the result through the function's serializer. */
+  override def eval(input: InternalRow): Any = {
+    val rawResult = evaluator(input)
+    f.serializer(rawResult)
+  }
+}
+
+// Compile time type shield for the types we are able to infer: wherever an
+// InferrableType[T] is required, a T without an implicit instance below causes
+// a compilation error. This is the Scala 2 way of making a union type.
+sealed class InferrableType[T: TypeTag]
+
+object InferrableType {
+  // Geometry types
+  implicit val geometryInstance: InferrableType[Geometry] = new InferrableType[Geometry] {}
+  implicit val geometryArrayInstance: InferrableType[Array[Geometry]] = new InferrableType[Array[Geometry]] {}
+
+  // Boxed Java numerics
+  implicit val javaDoubleInstance: InferrableType[java.lang.Double] = new InferrableType[java.lang.Double] {}
+  implicit val javaIntegerInstance: InferrableType[java.lang.Integer] = new InferrableType[java.lang.Integer] {}
+
+  // Scala primitives and optional boolean
+  implicit val doubleInstance: InferrableType[Double] = new InferrableType[Double] {}
+  implicit val intInstance: InferrableType[Int] = new InferrableType[Int] {}
+  implicit val booleanInstance: InferrableType[Boolean] = new InferrableType[Boolean] {}
+  implicit val booleanOptInstance: InferrableType[Option[Boolean]] = new InferrableType[Option[Boolean]] {}
+
+  // Strings, binary and arrays
+  implicit val stringInstance: InferrableType[String] = new InferrableType[String] {}
+  implicit val binaryInstance: InferrableType[Array[Byte]] = new InferrableType[Array[Byte]] {}
+  implicit val longArrayInstance: InferrableType[Array[java.lang.Long]] = new InferrableType[Array[java.lang.Long]] {}
+}
+
+/**
+ * Reflection-driven helpers that derive, from a Scala [[Type]], the argument extractor,
+ * the result serializer and the Spark SQL data type used by inferred expressions.
+ */
+object InferredTypes {
+  // Lifts a conversion so that a null result is passed through without being converted.
+  private def nullSafe(convert: Any => Any): Any => Any =
+    output => if (output == null) null else convert(output)
+
+  /**
+   * Picks how an argument of type `t` is read from an input row.
+   * Geometry, Array[Geometry] and String use dedicated conversions; every other type
+   * is read with a plain `expr.eval`.
+   */
+  def buildArgumentExtractor(t: Type): Expression => InternalRow => Any = t match {
+    case tpe if tpe =:= typeOf[Geometry] =>
+      expr => input => expr.toGeometry(input)
+    case tpe if tpe =:= typeOf[Array[Geometry]] =>
+      expr => input => expr.toGeometryArray(input)
+    case tpe if tpe =:= typeOf[String] =>
+      expr => input => expr.asString(input)
+    case _ =>
+      expr => input => expr.eval(input)
+  }
+
+  /**
+   * Picks how a result of type `t` is converted before being returned from `eval`.
+   * Null results are always returned as null; types without a dedicated conversion
+   * are returned unchanged.
+   */
+  def buildSerializer(t: Type): Any => Any = t match {
+    case tpe if tpe =:= typeOf[Geometry] =>
+      nullSafe(output => output.asInstanceOf[Geometry].toGenericArrayData)
+    case tpe if tpe =:= typeOf[String] =>
+      nullSafe(output => UTF8String.fromString(output.asInstanceOf[String]))
+    case tpe if tpe =:= typeOf[Array[java.lang.Long]] =>
+      nullSafe(output => ArrayData.toArrayData(output))
+    case tpe if tpe =:= typeOf[Array[Geometry]] =>
+      nullSafe(output =>
+        ArrayData.toArrayData(output.asInstanceOf[Array[Geometry]].map(_.toGenericArrayData)))
+    case tpe if tpe =:= typeOf[Option[Boolean]] =>
+      // None becomes null, Some(b) becomes b.
+      nullSafe(output => output.asInstanceOf[Option[Boolean]].orNull)
+    case _ =>
+      output => output
+  }
+
+  /**
+   * Maps a Scala type to the Spark SQL data type used for it.
+   * NOTE: any type not listed explicitly (including Boolean) falls through to BooleanType.
+   */
+  def inferSparkType(t: Type): DataType = t match {
+    case tpe if tpe =:= typeOf[Geometry] => GeometryUDT
+    case tpe if tpe =:= typeOf[Array[Geometry]] => DataTypes.createArrayType(GeometryUDT)
+    case tpe if tpe =:= typeOf[java.lang.Double] => DoubleType
+    case tpe if tpe =:= typeOf[java.lang.Integer] => IntegerType
+    case tpe if tpe =:= typeOf[Double] => DoubleType
+    case tpe if tpe =:= typeOf[Int] => IntegerType
+    case tpe if tpe =:= typeOf[String] => StringType
+    case tpe if tpe =:= typeOf[Array[Byte]] => BinaryType
+    case tpe if tpe =:= typeOf[Array[java.lang.Long]] => DataTypes.createArrayType(LongType)
+    case tpe if tpe =:= typeOf[Option[Boolean]] => BooleanType
+    case _ => BooleanType
+  }
+}
+
+/**
+ * A function together with everything needed to run it as a catalyst expression.
+ *
+ * @param sparkInputTypes Spark SQL types of the arguments.
+ * @param sparkReturnType Spark SQL type of the result.
+ * @param serializer Converts the raw result of the function before it is returned from `eval`.
+ * @param argExtractorBuilders One builder per argument; given the argument's expression it
+ *                             yields a function reading that argument from an input row.
+ * @param evaluatorBuilder Given the per-argument extractors, yields the row evaluator.
+ */
+case class InferrableFunction(sparkInputTypes: Seq[AbstractDataType],
+                              sparkReturnType: DataType,
+                              serializer: Any => Any,
+                              argExtractorBuilders: Seq[Expression => InternalRow => Any],
+                              evaluatorBuilder: Array[InternalRow => Any] => InternalRow => Any) {
+  /**
+   * Builds one extractor per argument builder, pairing builders and expressions by position.
+   * Expressions beyond the last builder are ignored; a builder with no matching expression
+   * is applied to null.
+   */
+  def buildExtractors(expressions: Seq[Expression]): Array[InternalRow => Any] = {
+    argExtractorBuilders.zipWithIndex.collect {
+      case (builder, position) if builder != null =>
+        builder(expressions.lift(position).orNull)
+    }.toArray
+  }
+}
+
+object InferrableFunction {
+  /**
+   * Builds an [[InferrableFunction]] from the type tag of a function type: all type arguments
+   * but the last are treated as argument types, the last one as the return type.
+   * @param typeTag Type tag of the function.
+   * @param evaluatorBuilder Builder for the evaluator.
+   * @return InferrableFunction.
+   */
+  def apply(typeTag: TypeTag[_], evaluatorBuilder: Array[InternalRow => Any] => InternalRow => Any): InferrableFunction = {
+    val signature = typeTag.tpe.typeArgs
+    val parameterTypes = signature.init
+    val resultType = signature.last
+    val inputTypes: Seq[AbstractDataType] = parameterTypes.map(InferredTypes.inferSparkType)
+    val outputType: DataType = InferredTypes.inferSparkType(resultType)
+    val resultSerializer = InferredTypes.buildSerializer(resultType)
+    val extractorBuilders = parameterTypes.map(InferredTypes.buildArgumentExtractor)
+    InferrableFunction(inputTypes, outputType, resultSerializer, extractorBuilders, evaluatorBuilder)
+  }
+
+  /**
+   * A variant of binary inferred expression which allows the second argument to be null:
+   * only a null first argument short-circuits to a null result.
+   * @param f Function to be wrapped as a catalyst expression.
+   * @param typeTag Type tag of the function.
+   * @tparam R Return type of the function.
+   * @tparam A1 Type of the first argument.
+   * @tparam A2 Type of the second argument.
+   * @return InferrableFunction.
+   */
+  def allowRightNull[R, A1, A2](f: (A1, A2) => R)(implicit typeTag: TypeTag[(A1, A2) => R]): InferrableFunction = {
+    apply(typeTag, extractors => {
+      val func = f.asInstanceOf[(Any, Any) => Any]
+      val readLeft = extractors(0)
+      val readRight = extractors(1)
+      row => {
+        // Both sides are always extracted, even when the left one turns out to be null.
+        val left = readLeft(row)
+        val right = readRight(row)
+        if (left == null) null else func(left, right)
+      }
+    })
+  }
+}
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
deleted file mode 100644
index f526baf0..00000000
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
+++ /dev/null
@@ -1,366 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.spark.sql.sedona_sql.expressions
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.sedona_sql.expressions.implicits._
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-import org.locationtech.jts.geom.Geometry
-
-import scala.reflect.runtime.universe._
-
-/**
-  * Make expression foldable by constant folding optimizer. If all children
-  * expressions are foldable, then the expression itself is foldable.
-  */
-trait FoldableExpression extends Expression {
-  override def foldable: Boolean = children.forall(_.foldable)
-}
-
-abstract class UnaryGeometryExpression extends Expression with SerdeAware with 
ExpectsInputTypes {
-  def inputExpressions: Seq[Expression]
-
-  override def nullable: Boolean = true
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT)
-
-  override def eval(input: InternalRow): Any = {
-    val result = evalWithoutSerialization(input)
-    serializeResult(result)
-  }
-
-  override def evalWithoutSerialization(input: InternalRow): Any ={
-    val inputExpression = inputExpressions.head
-    val geometry = inputExpression match {
-      case expr: SerdeAware => expr.evalWithoutSerialization(input)
-      case expr: Any => expr.toGeometry(input)
-    }
-
-    (geometry) match {
-      case (geometry: Geometry) => nullSafeEval(geometry)
-      case _ => null
-    }
-  }
-
-  protected def serializeResult(result: Any): Any = {
-    result match {
-      case geometry: Geometry => geometry.toGenericArrayData
-      case _ => result
-    }
-  }
-
-  protected def nullSafeEval(geometry: Geometry): Any
-
-
-}
-
-abstract class BinaryGeometryExpression extends Expression with SerdeAware 
with ExpectsInputTypes {
-  def inputExpressions: Seq[Expression]
-
-  override def nullable: Boolean = true
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, 
GeometryUDT)
-
-  override def eval(input: InternalRow): Any = {
-    val result = evalWithoutSerialization(input)
-    serializeResult(result)
-  }
-
-  override def evalWithoutSerialization(input: InternalRow): Any = {
-    val leftExpression = inputExpressions(0)
-    val leftGeometry = leftExpression match {
-      case expr: SerdeAware => expr.evalWithoutSerialization(input)
-      case _ => leftExpression.toGeometry(input)
-    }
-
-    val rightExpression = inputExpressions(1)
-    val rightGeometry = rightExpression match {
-      case expr: SerdeAware => expr.evalWithoutSerialization(input)
-      case _ => rightExpression.toGeometry(input)
-    }
-
-    (leftGeometry, rightGeometry) match {
-      case (leftGeometry: Geometry, rightGeometry: Geometry) => 
nullSafeEval(leftGeometry, rightGeometry)
-      case _ => null
-    }
-  }
-
-  protected def serializeResult(result: Any): Any = {
-    result match {
-      case geometry: Geometry => geometry.toGenericArrayData
-      case _ => result
-    }
-  }
-
-  protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): 
Any
-}
-
-// This is a compile time type shield for the types we are able to infer. 
Anything
-// other than these types will cause a compilation error. This is the Scala
-// 2 way of making a union type.
-sealed class InferrableType[T: TypeTag]
-object InferrableType {
-  implicit val geometryInstance: InferrableType[Geometry] =
-    new InferrableType[Geometry] {}
-  implicit val geometryArrayInstance: InferrableType[Array[Geometry]] =
-    new InferrableType[Array[Geometry]] {}
-  implicit val javaDoubleInstance: InferrableType[java.lang.Double] =
-    new InferrableType[java.lang.Double] {}
-  implicit val javaIntegerInstance: InferrableType[java.lang.Integer] =
-    new InferrableType[java.lang.Integer] {}
-  implicit val doubleInstance: InferrableType[Double] =
-    new InferrableType[Double] {}
-  implicit val booleanInstance: InferrableType[Boolean] =
-    new InferrableType[Boolean] {}
-  implicit val intInstance: InferrableType[Int] =
-    new InferrableType[Int] {}
-  implicit val stringInstance: InferrableType[String] =
-    new InferrableType[String] {}
-  implicit val binaryInstance: InferrableType[Array[Byte]] =
-    new InferrableType[Array[Byte]] {}
-  implicit val longArrayInstance: InferrableType[Array[java.lang.Long]] =
-    new InferrableType[Array[java.lang.Long]] {}
-}
-
-object InferredTypes {
-  def buildExtractor[T: TypeTag](expr: Expression): InternalRow => T = {
-    if (typeOf[T] =:= typeOf[Geometry]) {
-      input: InternalRow => expr.toGeometry(input).asInstanceOf[T]
-    } else if (typeOf[T] =:= typeOf[Array[Geometry]]) {
-      input: InternalRow => expr.toGeometryArray(input).asInstanceOf[T]
-    } else if (typeOf[T] =:= typeOf[String]) {
-      input: InternalRow => expr.asString(input).asInstanceOf[T]
-    } else {
-      input: InternalRow => expr.eval(input).asInstanceOf[T]
-    }
-  }
-
-  def buildSerializer[T: TypeTag]: T => Any = {
-    if (typeOf[T] =:= typeOf[Geometry]) {
-      output: T => if (output != null) {
-        output.asInstanceOf[Geometry].toGenericArrayData
-      } else {
-        null
-      }
-    } else if (typeOf[T] =:= typeOf[String]) {
-      output: T => if (output != null) {
-        UTF8String.fromString(output.asInstanceOf[String])
-      } else {
-        null
-      }
-    } else if (typeOf[T] =:= typeOf[Array[java.lang.Long]]) {
-      output: T =>
-        if (output != null) {
-          ArrayData.toArrayData(output)
-        } else {
-          null
-        }
-    } else if (typeOf[T] =:= typeOf[Array[Geometry]]) {
-      output: T =>
-        if (output != null) {
-          
ArrayData.toArrayData(output.asInstanceOf[Array[Geometry]].map(_.toGenericArrayData))
-        } else {
-          null
-        }
-    } else {
-      output: T => output
-    }
-  }
-
-  def inferSparkType[T: TypeTag]: DataType = {
-    if (typeOf[T] =:= typeOf[Geometry]) {
-      GeometryUDT
-    } else if (typeOf[T] =:= typeOf[Array[Geometry]]) {
-      DataTypes.createArrayType(GeometryUDT)
-    } else if (typeOf[T] =:= typeOf[java.lang.Double]) {
-      DoubleType
-    } else if (typeOf[T] =:= typeOf[java.lang.Integer]) {
-      IntegerType
-    } else if (typeOf[T] =:= typeOf[Double]) {
-      DoubleType
-    } else if (typeOf[T] =:= typeOf[Int]) {
-      IntegerType
-    } else if (typeOf[T] =:= typeOf[String]) {
-      StringType
-    } else if (typeOf[T] =:= typeOf[Array[Byte]]) {
-      BinaryType
-    } else if (typeOf[T] =:= typeOf[Array[java.lang.Long]]) {
-      DataTypes.createArrayType(LongType)
-    } else {
-      BooleanType
-    }
-  }
-}
-
-/**
-  * The implicit TypeTag's tell Scala to maintain generic type info at 
runtime. Normally type
-  * erasure would remove any knowledge of what the passed in generic type is.
-  */
-abstract class InferredUnaryExpression[A1: InferrableType, R: InferrableType]
-    (f: (A1) => R)
-    (implicit val a1Tag: TypeTag[A1], implicit val rTag: TypeTag[R])
-    extends Expression with ImplicitCastInputTypes with SerdeAware with 
CodegenFallback with Serializable {
-  import InferredTypes._
-
-  def inputExpressions: Seq[Expression]
-
-  override def children: Seq[Expression] = inputExpressions
-
-  override def toString: String = s" **${getClass.getName}**  "
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(inferSparkType[A1])
-
-  override def nullable: Boolean = true
-
-  override def dataType = inferSparkType[R]
-
-  lazy val extract = buildExtractor[A1](inputExpressions(0))
-
-  lazy val serialize = buildSerializer[R]
-
-  override def eval(input: InternalRow): Any = 
serialize(evalWithoutSerialization(input).asInstanceOf[R])
-
-  override def evalWithoutSerialization(input: InternalRow): Any = {
-    val value = extract(input)
-    if (value != null) {
-      f(value)
-    } else {
-      null
-    }
-  }
-}
-
-abstract class InferredBinaryExpression[A1: InferrableType, A2: 
InferrableType, R: InferrableType]
-    (f: (A1, A2) => R)
-    (implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], 
implicit val rTag: TypeTag[R])
-    extends Expression with ImplicitCastInputTypes with SerdeAware with 
CodegenFallback with Serializable {
-  import InferredTypes._
-
-  def inputExpressions: Seq[Expression]
-
-  override def children: Seq[Expression] = inputExpressions
-
-  override def toString: String = s" **${getClass.getName}**  "
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(inferSparkType[A1], 
inferSparkType[A2])
-
-  override def nullable: Boolean = true
-
-  def allowRightNull: Boolean = false
-
-  override def dataType = inferSparkType[R]
-
-  lazy val extractLeft = buildExtractor[A1](inputExpressions(0))
-  lazy val extractRight = buildExtractor[A2](inputExpressions(1))
-
-  lazy val serialize = buildSerializer[R]
-
-  override def eval(input: InternalRow): Any = 
serialize(evalWithoutSerialization(input).asInstanceOf[R])
-
-  override def evalWithoutSerialization(input: InternalRow): Any = {
-    val left = extractLeft(input)
-    val right = extractRight(input)
-    if (left != null && (right != null || allowRightNull)) {
-        f(left, right)
-    } else {
-      null
-    }
-  }
-}
-
-abstract class InferredTernaryExpression[A1: InferrableType, A2: 
InferrableType, A3: InferrableType, R: InferrableType]
-(f: (A1, A2, A3) => R)
-(implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], implicit 
val a3Tag: TypeTag[A3], implicit val rTag: TypeTag[R])
-  extends Expression with ImplicitCastInputTypes with SerdeAware with 
CodegenFallback with Serializable {
-  import InferredTypes._
-
-  def inputExpressions: Seq[Expression]
-
-  override def children: Seq[Expression] = inputExpressions
-
-  override def toString: String = s" **${getClass.getName}**  "
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(inferSparkType[A1], 
inferSparkType[A2], inferSparkType[A3])
-
-  override def nullable: Boolean = true
-
-  override def dataType = inferSparkType[R]
-
-  lazy val extractFirst = buildExtractor[A1](inputExpressions(0))
-  lazy val extractSecond = buildExtractor[A2](inputExpressions(1))
-  lazy val extractThird = buildExtractor[A3](inputExpressions(2))
-
-  lazy val serialize = buildSerializer[R]
-
-  override def eval(input: InternalRow): Any = 
serialize(evalWithoutSerialization(input).asInstanceOf[R])
-
-  override def evalWithoutSerialization(input: InternalRow): Any = {
-    val first = extractFirst(input)
-    val second = extractSecond(input)
-    val third = extractThird(input)
-    if (first != null && second != null && third != null) {
-      f(first, second, third)
-    } else {
-      null
-    }
-  }
-}
-
-abstract class InferredQuarternaryExpression[A1: InferrableType, A2: 
InferrableType, A3: InferrableType, A4: InferrableType, R: InferrableType]
-(f: (A1, A2, A3, A4) => R)
-(implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], implicit 
val a3Tag: TypeTag[A3], implicit val a4Tag: TypeTag[A4], implicit val rTag: 
TypeTag[R])
-  extends Expression with ImplicitCastInputTypes with CodegenFallback with 
Serializable {
-  import InferredTypes._
-
-  def inputExpressions: Seq[Expression]
-
-  override def children: Seq[Expression] = inputExpressions
-
-  override def toString: String = s" **${getClass.getName}**  "
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(inferSparkType[A1], 
inferSparkType[A2], inferSparkType[A3], inferSparkType[A4])
-
-  override def nullable: Boolean = true
-
-  override def dataType = inferSparkType[R]
-
-  lazy val extractFirst = buildExtractor[A1](inputExpressions(0))
-  lazy val extractSecond = buildExtractor[A2](inputExpressions(1))
-  lazy val extractThird = buildExtractor[A3](inputExpressions(2))
-  lazy val extractForth = buildExtractor[A4](inputExpressions(3))
-
-  lazy val serialize = buildSerializer[R]
-
-  override def eval(input: InternalRow): Any = {
-    val first = extractFirst(input)
-    val second = extractSecond(input)
-    val third = extractThird(input)
-    val forth = extractForth(input)
-    if (first != null && second != null && third != null && forth != null) {
-      serialize(f(first, second, third, forth))
-    } else {
-      null
-    }
-  }
-}


Reply via email to