This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e4ca8424474 [SPARK-39384][SQL] Compile built-in linear regression aggregate functions for JDBC dialect
e4ca8424474 is described below

commit e4ca8424474e571d8e137388fe5d54732b68c2f3
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Sat Jul 16 09:05:28 2022 -0500

    [SPARK-39384][SQL] Compile built-in linear regression aggregate functions for JDBC dialect
    
    ### What changes were proposed in this pull request?
    Recently, the Spark DS V2 pushdown framework gained translation for a number of standard linear regression aggregate functions.
    Currently, only H2Dialect compiles these standard linear regression aggregate functions. This PR compiles them for the other built-in JDBC dialects as well.
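    
    For illustration, a minimal sketch of a query whose linear regression aggregates can now be pushed to a built-in dialect such as PostgresDialect. The catalog and table names below are hypothetical, and the catalog registration is shown only as an assumed setup:
    
    ```scala
    // Assumed setup: a DS V2 JDBC catalog named "postgres" (names are hypothetical).
    spark.conf.set("spark.sql.catalog.postgres",
      "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
    spark.conf.set("spark.sql.catalog.postgres.url", "jdbc:postgresql://host:5432/db")
    
    // With this change, the REGR_* aggregates below compile to dialect SQL and
    // are evaluated inside PostgreSQL rather than in Spark.
    spark.sql(
      """SELECT REGR_SLOPE(bonus, bonus), REGR_INTERCEPT(bonus, bonus)
        |FROM postgres.test.employee
        |WHERE dept > 0
        |GROUP BY dept""".stripMargin).show()
    ```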
    
    ### Why are the changes needed?
    Makes the built-in JDBC dialects support compiling linear regression aggregates for push-down.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    New feature.
    
    ### How was this patch tested?
    New test cases.
    
    Closes #37188 from beliefer/SPARK-39384.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Sean Owen <sro...@gmail.com>
---
 .../spark/sql/jdbc/v2/DB2IntegrationSuite.scala    |   4 +
 .../spark/sql/jdbc/v2/OracleIntegrationSuite.scala |   4 +
 .../sql/jdbc/v2/PostgresIntegrationSuite.scala     |   8 ++
 .../org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala  | 118 ++++++++++++++++-----
 .../org/apache/spark/sql/jdbc/DB2Dialect.scala     |  14 ++-
 .../org/apache/spark/sql/jdbc/MySQLDialect.scala   |  32 +++++-
 .../org/apache/spark/sql/jdbc/OracleDialect.scala  |  33 +++++-
 .../apache/spark/sql/jdbc/PostgresDialect.scala    |   3 +-
 8 files changed, 185 insertions(+), 31 deletions(-)

diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
index 4b2bbbdd849..1a25cd2802d 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
@@ -106,4 +106,8 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
   testStddevSamp(true)
   testCovarPop()
   testCovarSamp()
+  testRegrIntercept()
+  testRegrSlope()
+  testRegrR2()
+  testRegrSXY()
 }
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
index 8bc79a244e7..5de76089188 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
@@ -111,4 +111,8 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
   testCovarPop()
   testCovarSamp()
   testCorr()
+  testRegrIntercept()
+  testRegrSlope()
+  testRegrR2()
+  testRegrSXY()
 }
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
index 77ace3f3f4e..1ff7527c97b 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
@@ -104,4 +104,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
   testCovarSamp(true)
   testCorr()
   testCorr(true)
+  testRegrIntercept()
+  testRegrIntercept(true)
+  testRegrSlope()
+  testRegrSlope(true)
+  testRegrR2()
+  testRegrR2(true)
+  testRegrSXY()
+  testRegrSXY(true)
 }
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
index 0f85bd534c3..543c8465ed2 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
@@ -406,9 +406,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
 
   protected def caseConvert(tableName: String): String = tableName
 
+  private def withOrWithout(isDistinct: Boolean): String = if (isDistinct) "with" else "without"
+
   protected def testVarPop(isDistinct: Boolean = false): Unit = {
     val distinct = if (isDistinct) "DISTINCT " else ""
-    test(s"scan with aggregate push-down: VAR_POP with distinct: $isDistinct") 
{
+    test(s"scan with aggregate push-down: VAR_POP ${withOrWithout(isDistinct)} 
DISTINCT") {
       val df = sql(s"SELECT VAR_POP(${distinct}bonus) FROM 
$catalogAndNamespace." +
         s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY 
dept")
       checkFilterPushed(df)
@@ -416,15 +418,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
       checkAggregatePushed(df, "VAR_POP")
       val row = df.collect()
       assert(row.length === 3)
-      assert(row(0).getDouble(0) === 10000d)
-      assert(row(1).getDouble(0) === 2500d)
-      assert(row(2).getDouble(0) === 0d)
+      assert(row(0).getDouble(0) === 10000.0)
+      assert(row(1).getDouble(0) === 2500.0)
+      assert(row(2).getDouble(0) === 0.0)
     }
   }
 
   protected def testVarSamp(isDistinct: Boolean = false): Unit = {
     val distinct = if (isDistinct) "DISTINCT " else ""
-    test(s"scan with aggregate push-down: VAR_SAMP with distinct: 
$isDistinct") {
+    test(s"scan with aggregate push-down: VAR_SAMP 
${withOrWithout(isDistinct)} DISTINCT") {
       val df = sql(
         s"SELECT VAR_SAMP(${distinct}bonus) FROM $catalogAndNamespace." +
         s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY 
dept")
@@ -433,15 +435,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
       checkAggregatePushed(df, "VAR_SAMP")
       val row = df.collect()
       assert(row.length === 3)
-      assert(row(0).getDouble(0) === 20000d)
-      assert(row(1).getDouble(0) === 5000d)
+      assert(row(0).getDouble(0) === 20000.0)
+      assert(row(1).getDouble(0) === 5000.0)
       assert(row(2).isNullAt(0))
     }
   }
 
   protected def testStddevPop(isDistinct: Boolean = false): Unit = {
     val distinct = if (isDistinct) "DISTINCT " else ""
-    test(s"scan with aggregate push-down: STDDEV_POP with distinct: 
$isDistinct") {
+    test(s"scan with aggregate push-down: STDDEV_POP 
${withOrWithout(isDistinct)} DISTINCT") {
       val df = sql(
         s"SELECT STDDEV_POP(${distinct}bonus) FROM $catalogAndNamespace." +
         s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY 
dept")
@@ -450,15 +452,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
       checkAggregatePushed(df, "STDDEV_POP")
       val row = df.collect()
       assert(row.length === 3)
-      assert(row(0).getDouble(0) === 100d)
-      assert(row(1).getDouble(0) === 50d)
-      assert(row(2).getDouble(0) === 0d)
+      assert(row(0).getDouble(0) === 100.0)
+      assert(row(1).getDouble(0) === 50.0)
+      assert(row(2).getDouble(0) === 0.0)
     }
   }
 
   protected def testStddevSamp(isDistinct: Boolean = false): Unit = {
     val distinct = if (isDistinct) "DISTINCT " else ""
-    test(s"scan with aggregate push-down: STDDEV_SAMP with distinct: 
$isDistinct") {
+    test(s"scan with aggregate push-down: STDDEV_SAMP 
${withOrWithout(isDistinct)} DISTINCT") {
       val df = sql(
         s"SELECT STDDEV_SAMP(${distinct}bonus) FROM $catalogAndNamespace." +
         s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY 
dept")
@@ -467,15 +469,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
       checkAggregatePushed(df, "STDDEV_SAMP")
       val row = df.collect()
       assert(row.length === 3)
-      assert(row(0).getDouble(0) === 141.4213562373095d)
-      assert(row(1).getDouble(0) === 70.71067811865476d)
+      assert(row(0).getDouble(0) === 141.4213562373095)
+      assert(row(1).getDouble(0) === 70.71067811865476)
       assert(row(2).isNullAt(0))
     }
   }
 
   protected def testCovarPop(isDistinct: Boolean = false): Unit = {
     val distinct = if (isDistinct) "DISTINCT " else ""
-    test(s"scan with aggregate push-down: COVAR_POP with distinct: 
$isDistinct") {
+    test(s"scan with aggregate push-down: COVAR_POP 
${withOrWithout(isDistinct)} DISTINCT") {
       val df = sql(
         s"SELECT COVAR_POP(${distinct}bonus, bonus) FROM 
$catalogAndNamespace." +
         s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY 
dept")
@@ -484,15 +486,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
       checkAggregatePushed(df, "COVAR_POP")
       val row = df.collect()
       assert(row.length === 3)
-      assert(row(0).getDouble(0) === 10000d)
-      assert(row(1).getDouble(0) === 2500d)
-      assert(row(2).getDouble(0) === 0d)
+      assert(row(0).getDouble(0) === 10000.0)
+      assert(row(1).getDouble(0) === 2500.0)
+      assert(row(2).getDouble(0) === 0.0)
     }
   }
 
   protected def testCovarSamp(isDistinct: Boolean = false): Unit = {
     val distinct = if (isDistinct) "DISTINCT " else ""
-    test(s"scan with aggregate push-down: COVAR_SAMP with distinct: 
$isDistinct") {
+    test(s"scan with aggregate push-down: COVAR_SAMP 
${withOrWithout(isDistinct)} DISTINCT") {
       val df = sql(
         s"SELECT COVAR_SAMP(${distinct}bonus, bonus) FROM 
$catalogAndNamespace." +
         s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY 
dept")
@@ -501,15 +503,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
       checkAggregatePushed(df, "COVAR_SAMP")
       val row = df.collect()
       assert(row.length === 3)
-      assert(row(0).getDouble(0) === 20000d)
-      assert(row(1).getDouble(0) === 5000d)
+      assert(row(0).getDouble(0) === 20000.0)
+      assert(row(1).getDouble(0) === 5000.0)
       assert(row(2).isNullAt(0))
     }
   }
 
   protected def testCorr(isDistinct: Boolean = false): Unit = {
     val distinct = if (isDistinct) "DISTINCT " else ""
-    test(s"scan with aggregate push-down: CORR with distinct: $isDistinct") {
+    test(s"scan with aggregate push-down: CORR ${withOrWithout(isDistinct)} 
DISTINCT") {
       val df = sql(
         s"SELECT CORR(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
         s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY 
dept")
@@ -518,9 +520,77 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
       checkAggregatePushed(df, "CORR")
       val row = df.collect()
       assert(row.length === 3)
-      assert(row(0).getDouble(0) === 1d)
-      assert(row(1).getDouble(0) === 1d)
+      assert(row(0).getDouble(0) === 1.0)
+      assert(row(1).getDouble(0) === 1.0)
+      assert(row(2).isNullAt(0))
+    }
+  }
+
+  protected def testRegrIntercept(isDistinct: Boolean = false): Unit = {
+    val distinct = if (isDistinct) "DISTINCT " else ""
+    test(s"scan with aggregate push-down: REGR_INTERCEPT 
${withOrWithout(isDistinct)} DISTINCT") {
+      val df = sql(
+        s"SELECT REGR_INTERCEPT(${distinct}bonus, bonus) FROM 
$catalogAndNamespace." +
+          s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY 
dept")
+      checkFilterPushed(df)
+      checkAggregateRemoved(df)
+      checkAggregatePushed(df, "REGR_INTERCEPT")
+      val row = df.collect()
+      assert(row.length === 3)
+      assert(row(0).getDouble(0) === 0.0)
+      assert(row(1).getDouble(0) === 0.0)
+      assert(row(2).isNullAt(0))
+    }
+  }
+
+  protected def testRegrSlope(isDistinct: Boolean = false): Unit = {
+    val distinct = if (isDistinct) "DISTINCT " else ""
+    test(s"scan with aggregate push-down: REGR_SLOPE 
${withOrWithout(isDistinct)} DISTINCT") {
+      val df = sql(
+        s"SELECT REGR_SLOPE(${distinct}bonus, bonus) FROM 
$catalogAndNamespace." +
+          s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY 
dept")
+      checkFilterPushed(df)
+      checkAggregateRemoved(df)
+      checkAggregatePushed(df, "REGR_SLOPE")
+      val row = df.collect()
+      assert(row.length === 3)
+      assert(row(0).getDouble(0) === 1.0)
+      assert(row(1).getDouble(0) === 1.0)
+      assert(row(2).isNullAt(0))
+    }
+  }
+
+  protected def testRegrR2(isDistinct: Boolean = false): Unit = {
+    val distinct = if (isDistinct) "DISTINCT " else ""
+    test(s"scan with aggregate push-down: REGR_R2 ${withOrWithout(isDistinct)} 
DISTINCT") {
+      val df = sql(
+        s"SELECT REGR_R2(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
+          s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY 
dept")
+      checkFilterPushed(df)
+      checkAggregateRemoved(df)
+      checkAggregatePushed(df, "REGR_R2")
+      val row = df.collect()
+      assert(row.length === 3)
+      assert(row(0).getDouble(0) === 1.0)
+      assert(row(1).getDouble(0) === 1.0)
       assert(row(2).isNullAt(0))
     }
   }
+
+  protected def testRegrSXY(isDistinct: Boolean = false): Unit = {
+    val distinct = if (isDistinct) "DISTINCT " else ""
+    test(s"scan with aggregate push-down: REGR_SXY 
${withOrWithout(isDistinct)} DISTINCT") {
+      val df = sql(
+        s"SELECT REGR_SXY(${distinct}bonus, bonus) FROM $catalogAndNamespace." 
+
+          s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY 
dept")
+      checkFilterPushed(df)
+      checkAggregateRemoved(df)
+      checkAggregatePushed(df, "REGR_SXY")
+      val row = df.collect()
+      assert(row.length === 3)
+      assert(row(0).getDouble(0) === 20000.0)
+      assert(row(1).getDouble(0) === 5000.0)
+      assert(row(2).getDouble(0) === 0.0)
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index a3637e57266..6c7c1bfe737 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -32,15 +32,27 @@ private object DB2Dialect extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2")
 
+  private val distinctUnsupportedAggregateFunctions =
+    Set("COVAR_POP", "COVAR_SAMP", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", 
"REGR_SXY")
+
   // See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate
   private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
-    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP")
+    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions
   private val supportedFunctions = supportedAggregateFunctions
 
   override def isSupportedFunction(funcName: String): Boolean =
     supportedFunctions.contains(funcName)
 
   class DB2SQLBuilder extends JDBCSQLBuilder {
+    override def visitAggregateFunction(
+        funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
+      if (isDistinct && 
distinctUnsupportedAggregateFunctions.contains(funcName)) {
+        throw new 
UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
+          s"support aggregate function: $funcName with DISTINCT");
+      } else {
+        super.visitAggregateFunction(funcName, isDistinct, inputs)
+      }
+
     override def dialectFunctionName(funcName: String): String = funcName match {
       case "VAR_POP" => "VARIANCE"
       case "VAR_SAMP" => "VARIANCE_SAMP"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index cc04b5c7c92..7dc76eed49f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -22,13 +22,14 @@ import java.util
 import java.util.Locale
 
 import scala.collection.mutable.ArrayBuilder
+import scala.util.control.NonFatal
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException}
 import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.index.TableIndex
-import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
+import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder}
@@ -38,14 +39,39 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper {
   override def canHandle(url : String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql")
 
+  private val distinctUnsupportedAggregateFunctions =
+    Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")
+
   // See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html
-  private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
-    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")
+  private val supportedAggregateFunctions =
+    Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions
   private val supportedFunctions = supportedAggregateFunctions
 
   override def isSupportedFunction(funcName: String): Boolean =
     supportedFunctions.contains(funcName)
 
+  class MySQLSQLBuilder extends JDBCSQLBuilder {
+    override def visitAggregateFunction(
+        funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
+      if (isDistinct && 
distinctUnsupportedAggregateFunctions.contains(funcName)) {
+        throw new 
UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
+          s"support aggregate function: $funcName with DISTINCT");
+      } else {
+        super.visitAggregateFunction(funcName, isDistinct, inputs)
+      }
+  }
+
+  override def compileExpression(expr: Expression): Option[String] = {
+    val mysqlSQLBuilder = new MySQLSQLBuilder()
+    try {
+      Some(mysqlSQLBuilder.build(expr))
+    } catch {
+      case NonFatal(e) =>
+        logWarning("Error occurs while compiling V2 expression", e)
+        None
+    }
+  }
+
   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
     if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
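
The MySQL, Oracle, and DB2 builders all follow the same pattern: visitAggregateFunction throws for DISTINCT on functions the database cannot evaluate with DISTINCT, and compileExpression catches any non-fatal error and returns None, so the aggregate simply stays un-pushed instead of failing the query. A self-contained sketch of that control flow, using illustrative stand-ins rather than Spark's actual connector types:

```scala
import scala.util.control.NonFatal

// Illustrative stand-in for a V2 aggregate expression (not Spark's API).
case class Agg(name: String, isDistinct: Boolean, inputs: Seq[String])

class DialectBuilder(distinctUnsupported: Set[String]) {
  def visitAggregateFunction(agg: Agg): String =
    if (agg.isDistinct && distinctUnsupported.contains(agg.name)) {
      throw new UnsupportedOperationException(
        s"${this.getClass.getSimpleName} does not support aggregate function: " +
          s"${agg.name} with DISTINCT")
    } else {
      val d = if (agg.isDistinct) "DISTINCT " else ""
      s"${agg.name}($d${agg.inputs.mkString(", ")})"
    }
}

// compileExpression-style wrapper: a failed translation yields None, and Spark
// then computes the aggregate itself instead of pushing it down.
def compile(builder: DialectBuilder, agg: Agg): Option[String] =
  try Some(builder.visitAggregateFunction(agg))
  catch { case NonFatal(_) => None }

val b = new DialectBuilder(Set("VAR_POP"))
assert(compile(b, Agg("VAR_POP", isDistinct = true, Seq("bonus"))).isEmpty)
assert(compile(b, Agg("VAR_POP", isDistinct = false, Seq("bonus"))) == Some("VAR_POP(bonus)"))
```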
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index 820bff354ca..79ac248d723 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -20,7 +20,10 @@ package org.apache.spark.sql.jdbc
 import java.sql.{Date, Timestamp, Types}
 import java.util.{Locale, TimeZone}
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.connector.expressions.Expression
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -33,16 +36,42 @@ private case object OracleDialect extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle")
 
+  private val distinctUnsupportedAggregateFunctions =
+    Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", 
"COVAR_SAMP", "CORR",
+      "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY")
+
   // scalastyle:off line.size.limit
   // https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Aggregate-Functions.html#GUID-62BE676B-AF18-4E63-BD14-25206FEA0848
   // scalastyle:on line.size.limit
-  private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
-    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR")
+  private val supportedAggregateFunctions =
+    Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions
   private val supportedFunctions = supportedAggregateFunctions
 
   override def isSupportedFunction(funcName: String): Boolean =
     supportedFunctions.contains(funcName)
 
+  class OracleSQLBuilder extends JDBCSQLBuilder {
+    override def visitAggregateFunction(
+        funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
+      if (isDistinct && 
distinctUnsupportedAggregateFunctions.contains(funcName)) {
+        throw new 
UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
+          s"support aggregate function: $funcName with DISTINCT");
+      } else {
+        super.visitAggregateFunction(funcName, isDistinct, inputs)
+      }
+  }
+
+  override def compileExpression(expr: Expression): Option[String] = {
+    val oracleSQLBuilder = new OracleSQLBuilder()
+    try {
+      Some(oracleSQLBuilder.build(expr))
+    } catch {
+      case NonFatal(e) =>
+        logWarning("Error occurs while compiling V2 expression", e)
+        None
+    }
+  }
+
   private def supportTimeZoneTypes: Boolean = {
     val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone)
     // TODO: support timezone types when users are not using the JVM timezone, which
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index cb78bc806e2..878d7a7cfe6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -38,7 +38,8 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper {
 
   // See https://www.postgresql.org/docs/8.4/functions-aggregate.html
   private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
-    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR")
+    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR",
+    "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY")
   private val supportedFunctions = supportedAggregateFunctions
 
   override def isSupportedFunction(funcName: String): Boolean =


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
