This is an automated email from the ASF dual-hosted git repository.
linxinyuan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/texera.git
The following commit(s) were added to refs/heads/main by this push:
new 6a58780f39 fix: prediction input shape in sklearn testing operator (#4236)
6a58780f39 is described below
commit 6a58780f3976b7dcfe62bd3baf68cbb63a28ad0d
Author: Xinyuan Lin <[email protected]>
AuthorDate: Wed Feb 25 11:48:00 2026 -0800
fix: prediction input shape in sklearn testing operator (#4236)
---
.../texera/amber/operator/sklearn/testing/SklearnTestingOpDesc.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/sklearn/testing/SklearnTestingOpDesc.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/sklearn/testing/SklearnTestingOpDesc.scala
index 4c7af2db98..df7d933665 100644
--- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/sklearn/testing/SklearnTestingOpDesc.scala
+++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/sklearn/testing/SklearnTestingOpDesc.scala
@@ -69,7 +69,7 @@ class SklearnTestingOpDesc extends PythonOperatorDescriptor {
| table = Table(self.data)
| Y = table[$target]
| X = table.drop($target, axis=1)
- | predictions = model.predict(X)
+ | predictions = model.predict(X.squeeze())
| if $isRegressionStr:
| tuple_["R2"] = r2_score(Y, predictions)
 | tuple_["RMSE"] = root_mean_squared_error(Y, predictions)