github-actions[bot] opened a new issue, #629:
URL: https://github.com/apache/incubator-wayang/issues/629
Add the functionality required for this test: `load_collection` and `collect`.

https://github.com/apache/incubator-wayang/blob/1e0f9e8166225176fe3022de5fbcce3dbcba96b9/python/src/pywy/tests/test_decision_tree_regression.py#L23

```python
# limitations under the License.
#

from pywy.dataquanta import WayangContext
from pywy.platforms.java import JavaPlugin
from pywy.platforms.spark import SparkPlugin
import pytest


# TODO: add functionality required for this test load_collection & collect
@pytest.mark.skip(reason="no way of currently testing this, since we are missing implementations for load_collection & collect")
def test_train_and_predict():
    # Initialize context with platforms
    ctx = WayangContext().register({JavaPlugin, SparkPlugin})

    # Input features and labels
    features = ctx.load_collection([
        [1.0, 2.0],
        [2.0, 3.0],
        [3.0, 4.0],
        [4.0, 5.0]
    ])
    labels = ctx.load_collection([3.0, 4.0, 5.0, 6.0])

    # Train the model
    model = features.train_decision_tree_regression(labels, max_depth=3, min_instances=1)

    # Run predictions on same features
    predictions = model.predict(features)

    # Collect and validate
    result = predictions.collect()
    print("Predictions:", result)
    # Compare values with ==, not identity with `is`
    assert len(result) == 4, f"Expected len(result) to be 4, but got: {len(result)}"
    for pred in result:
        # `pred is float` compares a value to the type object and always fails
        assert isinstance(pred, float)
        assert pred > 1.0
        assert pred <= 7.0
```

fac400e04a14d77f47335d07ed4a704d15f02e42

--
This is an automated message from the Apache Git Service.
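To make the missing contract concrete, here is a minimal pure-Python sketch, assuming `load_collection` should wrap an in-memory Python list as a data quanta and `collect` should execute the plan and return the elements as a plain Python list. `InMemoryContext` and `InMemoryQuanta` are hypothetical stand-ins, not pywy's API. A real implementation would presumably map these onto Wayang's `CollectionSource` and `LocalCallbackSink` operators on the Java side, but that is an assumption here.

```python
from typing import Callable, Generic, Iterable, List, TypeVar

T = TypeVar("T")
U = TypeVar("U")


class InMemoryQuanta(Generic[T]):
    """Hypothetical stand-in for a data quanta backed by a Python list."""

    def __init__(self, elements: Iterable[T]) -> None:
        # Copy the input so later mutation of the caller's list has no effect.
        self._elements: List[T] = list(elements)

    def map(self, fn: Callable[[T], U]) -> "InMemoryQuanta[U]":
        # A real engine would record this as a plan operator;
        # this sketch applies fn eagerly to every element.
        return InMemoryQuanta(fn(e) for e in self._elements)

    def collect(self) -> List[T]:
        # "Execute" and hand the results back as a fresh Python list.
        return list(self._elements)


class InMemoryContext:
    """Hypothetical stand-in for a context exposing load_collection."""

    def load_collection(self, elements: Iterable[T]) -> InMemoryQuanta[T]:
        # Source operator: wrap an in-memory collection.
        return InMemoryQuanta(elements)


ctx = InMemoryContext()
labels = ctx.load_collection([3.0, 4.0, 5.0, 6.0])
result = labels.map(lambda x: x + 1.0).collect()
print(result)  # [4.0, 5.0, 6.0, 7.0]
```

Returning a fresh list from `collect` keeps callers from mutating the quanta's internal state, which matches how the test treats `result` as an ordinary Python list.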
