This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new a4deb7188a [SYSTEMDS-3500] Fix lineage support / tests for contains-value function
a4deb7188a is described below
commit a4deb7188a4ca8b45df9e58efef289e21e06a93d
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Feb 23 23:05:25 2023 +0100
[SYSTEMDS-3500] Fix lineage support / tests for contains-value function
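This patch adds the missing lineage-reconstruction support (a "contains"
case in ParameterizedBuiltinCPInstruction.getLineageItem) for the new
contains-value function. It also fixes one Python federated test, which now
expects the fed_contains instruction instead of fed_isnan.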
---
.../instructions/cp/ParameterizedBuiltinCPInstruction.java | 8 +++++++-
src/main/python/tests/federated/test_federated_mnist.py | 2 +-
2 files changed, 8 insertions(+), 2 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index d3c88fd5ff..a67f8cd20d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -474,7 +474,13 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
@Override
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
String opcode = getOpcode();
- if(opcode.equalsIgnoreCase("groupedagg")) {
+ if(opcode.equalsIgnoreCase("contains")) {
+ CPOperand target = getTargetOperand();
+ CPOperand pattern = getFP64Literal("pattern");
+ return Pair.of(output.getName(),
+ new LineageItem(getOpcode(),
LineageItemUtils.getLineage(ec, target, pattern)));
+ }
+ else if(opcode.equalsIgnoreCase("groupedagg")) {
CPOperand target = getTargetOperand();
CPOperand groups = new
CPOperand(params.get(Statement.GAGG_GROUPS), ValueType.FP64, DataType.MATRIX);
String wt = params.containsKey(Statement.GAGG_WEIGHTS)
? params.get(Statement.GAGG_WEIGHTS) : String
diff --git a/src/main/python/tests/federated/test_federated_mnist.py
b/src/main/python/tests/federated/test_federated_mnist.py
index 3b11bd3194..c49f64897c 100644
--- a/src/main/python/tests/federated/test_federated_mnist.py
+++ b/src/main/python/tests/federated/test_federated_mnist.py
@@ -114,7 +114,7 @@ class TestFederatedMnist(unittest.TestCase):
with self.sds.capture_stats_context():
[_, _, acc] = multiLogRegPredict(Xt, bias, Yt).compute()
stats = self.sds.take_stats()
- for fed_instr in ["fed_isnan", "fed_*", "fed_-", "fed_uark+",
"fed_r'", "fed_rightIndex"]:
+ for fed_instr in ["fed_contains", "fed_*", "fed_-", "fed_uark+",
"fed_r'", "fed_rightIndex"]:
self.assertIn(fed_instr, stats)
self.assertGreater(acc, 80)