This is an automated email from the ASF dual-hosted git repository.
estrauss pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 264bdcdd80 [SYSTEMDS-3902] Accelerated data transfer Python <--> JVM
264bdcdd80 is described below
commit 264bdcdd80955251f874eed91f3b22e29da367cf
Author: e-strauss <[email protected]>
AuthorDate: Thu Jul 31 19:20:43 2025 +0200
[SYSTEMDS-3902] Accelerated data transfer Python <--> JVM
Introduced a new data transfer mechanism on Unix systems using FIFO (named)
pipes as a faster alternative to py4j-based communication.
- Supports multiple value types (uint8, int32, fp32, fp64) for dense matrix
exchange.
- Adds experimental support for partitioned matrix transfer from Python to
Java via multiple concurrent pipes (disabled by default due to limited
performance improvement).
- Significantly reduces overhead compared to py4j for large matrix
transfers in supported scenarios.
Closes #2296.
---
.github/workflows/javaTests.yml | 2 +
.../java/org/apache/sysds/api/PythonDMLScript.java | 120 +++++++++
.../apache/sysds/runtime/util/UnixPipeUtils.java | 268 +++++++++++++++++++++
.../python/systemds/context/systemds_context.py | 151 +++++++++++-
src/main/python/systemds/operator/nodes/matrix.py | 2 +-
.../python/systemds/operator/nodes/multi_return.py | 4 +-
src/main/python/systemds/script_building/script.py | 2 +-
src/main/python/systemds/utils/converters.py | 193 +++++++++++++--
.../python/tests/matrix/test_block_converter.py | 2 +-
.../tests/matrix/test_block_converter_unix_pipe.py | 104 ++++++++
.../test/component/utils/UnixPipeUtilsTest.java | 191 +++++++++++++++
.../sysds/test/usertest/pythonapi/StartupTest.java | 152 +++++++++++-
12 files changed, 1161 insertions(+), 30 deletions(-)
diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml
index c11f00ed4f..ba7da60c8e 100644
--- a/.github/workflows/javaTests.yml
+++ b/.github/workflows/javaTests.yml
@@ -29,6 +29,7 @@ on:
- '*.html'
- 'src/main/python/**'
- 'dev/**'
+ - '.github/workflows/python.yml'
branches:
- main
pull_request:
@@ -38,6 +39,7 @@ on:
- '*.html'
- 'src/main/python/**'
- 'dev/**'
+ - '.github/workflows/python.yml'
branches:
- main
diff --git a/src/main/java/org/apache/sysds/api/PythonDMLScript.java
b/src/main/java/org/apache/sysds/api/PythonDMLScript.java
index 80f5ffcd75..3b1864d71d 100644
--- a/src/main/java/org/apache/sysds/api/PythonDMLScript.java
+++ b/src/main/java/org/apache/sysds/api/PythonDMLScript.java
@@ -25,10 +25,26 @@ import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysds.api.jmlc.Connection;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+import org.apache.sysds.runtime.util.UnixPipeUtils;
import py4j.DefaultGatewayServerListener;
import py4j.GatewayServer;
import py4j.Py4JNetworkException;
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
public class PythonDMLScript {
@@ -36,6 +52,13 @@ public class PythonDMLScript {
final private Connection _connection;
public static GatewayServer GwS;
+ private static String fromPythonBase = "py2java";
+ private static String toPythonBase = "java2py";
+ public HashMap<Integer, BufferedInputStream> fromPython = null;
+ public HashMap<Integer, BufferedOutputStream> toPython = null;
+ public String baseDir;
+ private static int BATCH_SIZE = 32*1024;
+
/**
* Entry point for Python API.
*
@@ -78,6 +101,103 @@ public class PythonDMLScript {
return _connection;
}
+
+ public void openPipes(String path, int num) throws IOException {
+ fromPython = new HashMap<>(num * 2);
+ toPython = new HashMap<>(num * 2);
+ baseDir = path;
+ for (int i = 0; i < num; i++) {
+ BufferedInputStream pipe_in =
UnixPipeUtils.openInput(path + "/" + fromPythonBase + "-" + i, i);
+ LOG.debug("PY2JAVA pipe "+i+" is ready!");
+ fromPython.put(i, pipe_in);
+
+ BufferedOutputStream pipe_out =
UnixPipeUtils.openOutput(path + "/" + toPythonBase + "-" + i, i);
+ toPython.put(i, pipe_out);
+ }
+ }
+
+ public MatrixBlock startReadingMbFromPipe(int id, int rlen, int clen,
Types.ValueType type) throws IOException {
+ long limit = (long) rlen * clen;
+ LOG.debug("trying to read matrix from "+id+" with "+rlen+" rows
and "+clen+" columns. Total size: "+limit);
+ if(limit > Integer.MAX_VALUE)
+ throw new DMLRuntimeException("Dense NumPy array of
size " + limit +
+ " cannot be converted to MatrixBlock");
+ MatrixBlock mb = new MatrixBlock(rlen, clen, false, -1);
+ if(fromPython != null){
+ BufferedInputStream pipe = fromPython.get(id);
+ double[] denseBlock = new double[(int) limit];
+ UnixPipeUtils.readNumpyArrayInBatches(pipe, id,
BATCH_SIZE, (int) limit, type, denseBlock, 0);
+ mb.init(denseBlock, rlen, clen);
+ } else {
+ throw new DMLRuntimeException("FIFO Pipes are not
initialized.");
+ }
+ mb.recomputeNonZeros();
+ mb.examSparsity();
+ LOG.debug("Reading from Python finished");
+ return mb;
+ }
+
+ public MatrixBlock startReadingMbFromPipes(int[] blockSizes, int rlen,
int clen, Types.ValueType type) throws ExecutionException, InterruptedException
{
+ long limit = (long) rlen * clen;
+ if(limit > Integer.MAX_VALUE)
+ throw new DMLRuntimeException("Dense NumPy array of
size " + limit +
+ " cannot be converted to MatrixBlock");
+ MatrixBlock mb = new MatrixBlock(rlen, clen, false, -1);
+ if(fromPython != null){
+ ExecutorService pool = CommonThreadPool.get();
+ double[] denseBlock = new double[(int) limit];
+ int offsetOut = 0;
+ List<Future<Void>> futures = new ArrayList<>();
+ for (int i = 0; i < blockSizes.length; i++) {
+ BufferedInputStream pipe = fromPython.get(i);
+ int id = i, blockSize = blockSizes[i],
_offsetOut = offsetOut;
+ Callable<Void> task = () -> {
+
UnixPipeUtils.readNumpyArrayInBatches(pipe, id, BATCH_SIZE, blockSize, type,
denseBlock, _offsetOut);
+ return null;
+ };
+
+ futures.add(pool.submit(task));
+ offsetOut += blockSize;
+ }
+ // Wait for all tasks and propagate exceptions
+ for (Future<Void> f : futures) {
+ f.get();
+ }
+
+ mb.init(denseBlock, rlen, clen);
+ } else {
+ throw new DMLRuntimeException("FIFO Pipes are not
initialized.");
+ }
+ mb.recomputeNonZeros();
+ mb.examSparsity();
+ return mb;
+ }
+
+ public void startWritingMbToPipe(int id, MatrixBlock mb) throws
IOException {
+ if (toPython != null) {
+ int rlen = mb.getNumRows();
+ int clen = mb.getNumColumns();
+ int numElem = rlen * clen;
+ LOG.debug("Trying to write matrix ["+baseDir + "-"+
id+"] with "+rlen+" rows and "+clen+" columns. Total size: "+numElem*8);
+
+ BufferedOutputStream out = toPython.get(id);
+ long bytes =
UnixPipeUtils.writeNumpyArrayInBatches(out, id, BATCH_SIZE, numElem,
Types.ValueType.FP64, mb);
+
+ LOG.debug("Writing of " + bytes +" Bytes to Python
["+baseDir + "-"+ id+"] finished");
+ } else {
+ throw new DMLRuntimeException("FIFO Pipes are not
initialized.");
+ }
+ }
+
+ public void closePipes() throws IOException {
+ LOG.debug("Closing all pipes in Java");
+ for (BufferedInputStream pipe : fromPython.values())
+ pipe.close();
+ for (BufferedOutputStream pipe : toPython.values())
+ pipe.close();
+ LOG.debug("Closed all pipes in Java");
+ }
+
protected static class DMLGateWayListener extends
DefaultGatewayServerListener {
private static final Log LOG =
LogFactory.getLog(DMLGateWayListener.class.getName());
diff --git a/src/main/java/org/apache/sysds/runtime/util/UnixPipeUtils.java
b/src/main/java/org/apache/sysds/runtime/util/UnixPipeUtils.java
new file mode 100644
index 0000000000..69014acc0f
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/util/UnixPipeUtils.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.util;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.EOFException;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+
+public class UnixPipeUtils {
+ private static final Log LOG =
LogFactory.getLog(UnixPipeUtils.class.getName());
+
+ /**
+ * Opens a named pipe for input, reads 4 bytes as an int, compares it
to the expected ID.
+ * If matched, returns the InputStream for further use.
+ *
+ * @param pipePath The filesystem path to the FIFO pipe
+ * @param expectedId The expected handshake ID
+ * @return BufferedInputStream if handshake succeeds
+ * @throws IOException if file access fails
+ * @throws IllegalStateException if handshake ID doesn't match
+ */
+
+ public static BufferedInputStream openInput(String pipePath, int
expectedId) throws IOException {
+ File pipeFile = new File(pipePath);
+ if (!pipeFile.exists()) {
+ throw new FileNotFoundException("Pipe not found at
path: " + pipePath);
+ }
+
+ FileInputStream fis = new FileInputStream(pipeFile);
+ BufferedInputStream bis = new BufferedInputStream(fis);
+
+ readHandshake(expectedId, bis);
+
+ return bis;
+ }
+
+ public static void readHandshake(int expectedId, BufferedInputStream
bis) throws IOException {
+ // Read 4 bytes for handshake
+ byte[] buffer = new byte[4];
+ int bytesRead = bis.read(buffer);
+ if (bytesRead != 4) {
+ bis.close();
+ throw new IOException("Failed to read handshake integer
from pipe");
+ }
+
+ // Convert bytes to int (assuming little-endian to match
typical Python struct.pack)
+ int receivedId =
ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN).getInt();
+ expectedId += 1000;
+
+ if (receivedId != expectedId) {
+ bis.close();
+ throw new IllegalStateException("Handshake ID mismatch:
expected " + expectedId + ", got " + receivedId);
+ }
+ }
+
+ public static BufferedOutputStream openOutput(String pipePath, int
expectedId) throws IOException {
+ File pipeFile = new File(pipePath);
+ if (!pipeFile.exists()) {
+ throw new FileNotFoundException("Pipe not found at
path: " + pipePath);
+ }
+
+ FileOutputStream fos = new FileOutputStream(pipeFile);
+ BufferedOutputStream bos = new BufferedOutputStream(fos);
+
+ writeHandshake(expectedId, bos);
+
+ return bos;
+ }
+
+ public static void writeHandshake(int expectedId, BufferedOutputStream
bos) throws IOException {
+ // Convert int to 4-byte little-endian and send as handshake
+ byte[] handshake =
ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(expectedId +
1000).array();
+ bos.write(handshake);
+ bos.flush();
+ }
+
+ public static void readNumpyArrayInBatches(BufferedInputStream in, int
id, int batchSize, int numElem,
+
Types.ValueType type, double[] out, int offsetOut)
+ throws IOException {
+ int elemSize;
+ switch (type){
+ case UINT8 -> elemSize = 1;
+ case INT32, FP32 -> elemSize = 4;
+ default -> elemSize = 8;
+ }
+
+ try {
+ // Read start header
+ readHandshake(id, in);
+ long bytesRemaining = ((long) numElem) * elemSize;
+ byte[] buffer = new byte[batchSize];
+
+ while (bytesRemaining > 0) {
+ int currentBatchSize = (int)
Math.min(batchSize, bytesRemaining);
+ int totalRead = 0;
+
+ while (totalRead < currentBatchSize) {
+ int bytesRead = in.read(buffer,
totalRead, currentBatchSize - totalRead);
+ if (bytesRead == -1) {
+ throw new
EOFException("Unexpected end of stream in pipe #" + id +
+ ": expected " +
currentBatchSize + " bytes, got " + totalRead);
+ }
+ totalRead += bytesRead;
+ }
+
+ // Interpret bytes with value type and fill the
dense MB
+ offsetOut = fillDoubleArrayFromByteArray(type,
out, offsetOut, buffer, currentBatchSize);
+ bytesRemaining -= currentBatchSize;
+ }
+
+ // Read end header
+ readHandshake(id, in);
+
+ } catch (Exception e) {
+ LOG.error("Error occurred while reading data from pipe
#" + id, e);
+ throw e;
+ }
+ }
+
+ private static int fillDoubleArrayFromByteArray(Types.ValueType type,
double[] out, int offsetOut, byte[] buffer,
+
int currentBatchSize) {
+ ByteBuffer bb = ByteBuffer.wrap(buffer, 0,
currentBatchSize).order(ByteOrder.LITTLE_ENDIAN);
+ switch (type){
+ default -> {
+ DoubleBuffer doubleBuffer = bb.asDoubleBuffer();
+ int numDoubles = doubleBuffer.remaining();
+ doubleBuffer.get(out, offsetOut, numDoubles);
+ offsetOut += numDoubles;
+ }
+ case FP32 -> {
+ FloatBuffer floatBuffer = bb.asFloatBuffer();
+ int numFloats = floatBuffer.remaining();
+ for (int i = 0; i < numFloats; i++) {
+ out[offsetOut++] = floatBuffer.get();
+ }
+ }
+ case INT32 -> {
+ IntBuffer intBuffer = bb.asIntBuffer();
+ int numInts = intBuffer.remaining();
+ for (int i = 0; i < numInts; i++) {
+ out[offsetOut++] = intBuffer.get();
+ }
+ }
+ case UINT8 -> {
+ for (int i = 0; i < currentBatchSize; i++) {
+ out[offsetOut++] = bb.get(i) & 0xFF;
+ }
+ }
+ }
+ return offsetOut;
+ }
+
+ public static long writeNumpyArrayInBatches(BufferedOutputStream out,
int id, int batchSize, int numElem,
+
Types.ValueType type, MatrixBlock mb) throws IOException {
+ int elemSize;
+ switch (type) {
+ case UINT8 -> elemSize = 1;
+ case INT32, FP32 -> elemSize = 4;
+ default -> elemSize = 8;
+ }
+ long totalBytesWritten = 0;
+
+ // Write start header
+ writeHandshake(id, out);
+
+ int bytesRemaining = numElem * elemSize;
+ int offset = 0;
+
+ byte[] buffer = new byte[batchSize];
+
+ while (bytesRemaining > 0) {
+ int currentBatchSize = Math.min(batchSize,
bytesRemaining);
+
+ // Fill buffer from MatrixBlock into byte[] (typed)
+ int bytesWritten = fillByteArrayFromDoubleArray(type,
mb, offset, buffer, currentBatchSize);
+ totalBytesWritten += bytesWritten;
+
+ out.write(buffer, 0, currentBatchSize);
+ offset += currentBatchSize / elemSize;
+ bytesRemaining -= currentBatchSize;
+ }
+
+ out.flush();
+
+ // Write end header
+ writeHandshake(id, out);
+ return totalBytesWritten;
+ }
+
+ private static int fillByteArrayFromDoubleArray(Types.ValueType type,
MatrixBlock mb, int offsetIn,
+
byte[] buffer, int maxBytes) {
+ ByteBuffer bb = ByteBuffer.wrap(buffer, 0,
maxBytes).order(ByteOrder.LITTLE_ENDIAN);
+ int r,c;
+ switch (type) {
+ default -> { // FP64
+ DoubleBuffer doubleBuffer = bb.asDoubleBuffer();
+ int count = Math.min(doubleBuffer.remaining(),
mb.getNumRows() * mb.getNumColumns() - offsetIn);
+ for (int i = 0; i < count; i++) {
+ r = (offsetIn + i) / mb.getNumColumns();
+ c = (offsetIn + i) % mb.getNumColumns();
+ doubleBuffer.put(mb.getDouble(r,c));
+ }
+ return count * 8;
+ }
+ case FP32 -> {
+ FloatBuffer floatBuffer = bb.asFloatBuffer();
+ int count = Math.min(floatBuffer.remaining(),
mb.getNumRows() * mb.getNumColumns() - offsetIn);
+ for (int i = 0; i < count; i++) {
+ r = (offsetIn + i) / mb.getNumColumns();
+ c = (offsetIn + i) % mb.getNumColumns();
+ floatBuffer.put((float)
mb.getDouble(r,c));
+ }
+ return count * 4;
+ }
+ case INT32 -> {
+ IntBuffer intBuffer = bb.asIntBuffer();
+ int count = Math.min(intBuffer.remaining(),
mb.getNumRows() * mb.getNumColumns() - offsetIn);
+ for (int i = 0; i < count; i++) {
+ r = (offsetIn + i) / mb.getNumColumns();
+ c = (offsetIn + i) % mb.getNumColumns();
+ intBuffer.put((int) mb.getDouble(r,c));
+ }
+ return count * 4;
+ }
+ case UINT8 -> {
+ int count = Math.min(maxBytes, mb.getNumRows()
* mb.getNumColumns() - offsetIn);
+ for (int i = 0; i < count; i++) {
+ r = (offsetIn + i) / mb.getNumColumns();
+ c = (offsetIn + i) % mb.getNumColumns();
+ buffer[i] = (byte) ((int)
mb.getDouble(r,c) & 0xFF);
+ }
+ return count;
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/main/python/systemds/context/systemds_context.py
b/src/main/python/systemds/context/systemds_context.py
index a559850c33..440c51c3c8 100644
--- a/src/main/python/systemds/context/systemds_context.py
+++ b/src/main/python/systemds/context/systemds_context.py
@@ -24,15 +24,19 @@ __all__ = ["SystemDSContext"]
import json
import logging
import os
+import uuid
import socket
import sys
+import struct
+import traceback
from contextlib import contextmanager
from glob import glob
from queue import Queue
from subprocess import PIPE, Popen
from threading import Thread
-from time import sleep
+from time import sleep, time
from typing import Dict, Iterable, Sequence, Tuple, Union
+from concurrent.futures import ThreadPoolExecutor
import numpy as np
import pandas as pd
@@ -66,7 +70,15 @@ class SystemDSContext(object):
_log: logging.Logger
__stdout: Queue = None
__stderr: Queue = None
+ _FIFO_PATH = "/tmp/systemds/pipes/"
+ _FIFO_PY2JAVA_BASE = "py2java"
+ _FIFO_JAVA2PY_BASE = "java2py"
+ _FIFO_PY2JAVA_PIPES = []
+ _FIFO_JAVA2PY_PIPES = []
+ _data_transfer_mode = 0
+ _multi_pipe_enabled = False
_logging_initialized = False
+ _executor_pool = ThreadPoolExecutor(max_workers=os.cpu_count() * 2 or 4)
def __init__(
self,
@@ -75,6 +87,8 @@ class SystemDSContext(object):
capture_stdout: bool = False,
logging_level: int = 20,
py4j_logging_level: int = 50,
+ data_transfer_mode: int = 1,
+ multi_pipe_enabled: bool = False,
):
"""Starts a new instance of SystemDSContext, in which the connection
to a JVM systemds instance is handled
Any new instance of this SystemDS Context, would start a separate new
JVM.
@@ -89,11 +103,106 @@ class SystemDSContext(object):
The logging levels are as follows: 10 DEBUG, 20 INFO, 30 WARNING,
40 ERROR, 50 CRITICAL.
:param py4j_logging_level: The logging level for Py4j to use, since
all communication to the JVM is done through this,
it can be verbose if not set high.
+ :param data_transfer_mode: 1 (default) uses FIFO named pipes on Unix systems; 0 falls back to py4j-based transfer.
"""
+
self.__setup_logging(logging_level, py4j_logging_level)
self.__start(port, capture_stdout)
self.capture_stats(capture_statistics)
self._log.debug("Started JVM and SystemDS python context manager")
+ self.__setup_data_transfer(data_transfer_mode, multi_pipe_enabled)
+
+ def __setup_data_transfer(self, data_transfer_mode=0,
multi_pipe_enabled=False):
+ self._data_transfer_mode = data_transfer_mode
+ self._multi_pipe_enabled = multi_pipe_enabled
+ if os.name == "posix" and data_transfer_mode == 1:
+ num_pipes = (os.cpu_count() or 2) if multi_pipe_enabled else 1
+ self.__make_fifo_named_pipes(num_pipes)
+ in_pipes, out_pipes = self.__init_pipes(num_pipes)
+
+ self._log.info(
+ "Data transfer: Handshake done for {} IN / OUT
Pipes".format(num_pipes)
+ )
+ self._FIFO_PY2JAVA_PIPES = out_pipes
+ self._FIFO_JAVA2PY_PIPES = in_pipes
+ else:
+ self._data_transfer_mode = 0
+
+ def __init_pipes(self, num_pipes):
+
+ def open_pipe_write(path, pipe_id):
+ self._log.debug("Opening {} for writing in Py".format(path))
+ pipe = open(path, "wb")
+ handshake = struct.pack("<i", pipe_id + 1000)
+ os.write(pipe.fileno(), handshake)
+ return pipe
+
+ def open_pipe_read(path, expected_id):
+ expected_id += 1000
+ self._log.debug("Opening {} for reading in Py".format(path))
+ pipe = open(path, "rb")
+ handshake = os.read(pipe.fileno(), 4)
+ received_id = struct.unpack("<i", handshake)[0]
+ if received_id != expected_id:
+ self._log.debug(
+ "Mismatch: {}, expected: {}".format(received_id,
expected_id)
+ )
+ else:
+ self._log.debug("JAVA2PY pipe {} is ready!".format(expected_id
- 1000))
+ return pipe
+
+ def open_pipes_java(ep, p, n):
+ ep.openPipes(p, n)
+
+ fut = self._executor_pool.submit(
+ open_pipes_java, self.java_gateway.entry_point, self._FIFO_PATH,
num_pipes
+ )
+ out_f, in_f = ([], [])
+ for i in range(num_pipes):
+ out_f.append(
+ self._executor_pool.submit(
+ open_pipe_write,
+ "{}/{}-{}".format(self._FIFO_PATH,
self._FIFO_PY2JAVA_BASE, i),
+ i,
+ )
+ )
+ in_f.append(
+ self._executor_pool.submit(
+ open_pipe_read,
+ "{}/{}-{}".format(self._FIFO_PATH,
self._FIFO_JAVA2PY_BASE, i),
+ i,
+ )
+ )
+ fut.result()
+ out_pipes = [f.result() for f in out_f]
+ in_pipes = [f.result() for f in in_f]
+ return in_pipes, out_pipes
+
+ def __delete_tmp_files(self):
+ if os.path.isdir(self._FIFO_PATH):
+ for root, dirs, files in os.walk(self._FIFO_PATH):
+ for file in files:
+ file_path = os.path.join(root, file)
+ os.remove(file_path)
+ os.rmdir(self._FIFO_PATH)
+ self._log.debug("Cleaned up tmp files")
+ else:
+ self._log.debug("FIFO Path does not exist")
+
+ def __make_fifo_named_pipes(self, num_pipes):
+ self._FIFO_PATH += str(uuid.uuid4())
+ os.makedirs(self._FIFO_PATH, exist_ok=True)
+ direction, name, i = (None, None, 0)
+ directions = [self._FIFO_PY2JAVA_BASE, self._FIFO_JAVA2PY_BASE]
+ try:
+ for direction in directions:
+ for i in range(num_pipes):
+ name = direction + "-{}".format(i)
+ os.mkfifo(self._FIFO_PATH + "/" + name)
+ except IOError:
+ # clean up already created pipes in self.FIFO_PATH
+ self.__delete_tmp_files()
+ raise Exception("Creating named pipe {} failed".format(name))
def get_stdout(self, lines: int = -1):
"""Getter for the stdout of the java subprocess
@@ -140,6 +249,20 @@ class SystemDSContext(object):
message += "\n\n"
message += str(exception)
sys.tracebacklimit = trace_back_limit
+
+ # print error in case we encounter another exception during closing
+ # also print stacktrace of the original exception
+ tb_str = (
+ exception
+ if isinstance(exception, str)
+ else "".join(
+ traceback.format_exception(
+ type(exception), exception, exception.__traceback__
+ )
+ )
+ )
+ print("Encountered exception & shutting down:
\n\n{}".format(tb_str.strip()))
+
self.close()
raise RuntimeError(message)
@@ -252,7 +375,7 @@ class SystemDSContext(object):
command.append("org.apache.sysds.api.PythonDMLScript")
# Find the configuration file for systemds.
- # TODO: refine the choise of configuration file
+ # TODO: refine the choice of configuration file
files = glob(os.path.join(root, "conf", "SystemDS*.xml"))
if len(files) > 1:
self._log.warning("Multiple config files found selecting: " +
files[0])
@@ -352,15 +475,37 @@ class SystemDSContext(object):
def close(self):
"""Close the connection to the java process and do necessary
cleanup."""
if hasattr(self, "java_gateway"):
- self.__kill_Popen(self.java_gateway.java_process)
+ if self._data_transfer_mode == 1:
+ self._log.debug("Closing all Pipes in Python")
+ for pipe in self._FIFO_JAVA2PY_PIPES:
+ pipe.close()
+ self._FIFO_JAVA2PY_PIPES = []
+ for pipe in self._FIFO_PY2JAVA_PIPES:
+ pipe.close()
+ self._FIFO_PY2JAVA_PIPES = []
+ self._log.debug("All Pipes are closed in Python")
+
+ if self.java_gateway._gateway_client.is_connected:
+ self._log.debug("Closing all Pipes in Java")
+ self.java_gateway.entry_point.closePipes()
+ self._log.debug("All Pipes are closed in Java")
+ else:
+ self._log.debug("Java Gateway is not connected anymore")
+ self.__delete_tmp_files()
+
self.java_gateway.shutdown()
+ self.__kill_Popen(self.java_gateway.java_process)
if hasattr(self, "__process"):
logging.error("Has process variable")
self.__kill_Popen(self.__process)
if hasattr(self, "__stdout_thread") and
self.__stdout_thread.is_alive():
self.__stdout_thread.join(0)
+ for line in self.get_stdout():
+ print(line)
if hasattr(self, "__stderr_thread") and
self.__stderr_thread.is_alive():
self.__stderr_thread.join(0)
+ for line in self.get_stderr():
+ print(line)
def __kill_Popen(self, process: Popen):
"""Stop the process at the Popen.
diff --git a/src/main/python/systemds/operator/nodes/matrix.py
b/src/main/python/systemds/operator/nodes/matrix.py
index 208e248ec6..562897e463 100644
--- a/src/main/python/systemds/operator/nodes/matrix.py
+++ b/src/main/python/systemds/operator/nodes/matrix.py
@@ -109,7 +109,7 @@ class Matrix(OperationNode):
def _parse_output_result_variables(self, result_variables):
return matrix_block_to_numpy(
- self.sds_context.java_gateway.jvm,
+ self.sds_context,
result_variables.getMatrixBlock(self._script.out_var_name[0]),
)
diff --git a/src/main/python/systemds/operator/nodes/multi_return.py
b/src/main/python/systemds/operator/nodes/multi_return.py
index a43c478a08..523537602d 100644
--- a/src/main/python/systemds/operator/nodes/multi_return.py
+++ b/src/main/python/systemds/operator/nodes/multi_return.py
@@ -81,7 +81,9 @@ class MultiReturn(OperationNode):
output = self._outputs[idx]
if str(output) == "MatrixNode":
result_var.append(
- matrix_block_to_numpy(jvmV,
result_variables.getMatrixBlock(v))
+ matrix_block_to_numpy(
+ self.sds_context, result_variables.getMatrixBlock(v)
+ )
)
elif str(output) == "FrameNode":
result_var.append(
diff --git a/src/main/python/systemds/script_building/script.py
b/src/main/python/systemds/script_building/script.py
index 37bd4cdca0..351d3fccfb 100644
--- a/src/main/python/systemds/script_building/script.py
+++ b/src/main/python/systemds/script_building/script.py
@@ -95,7 +95,7 @@ class DMLScript:
exception_str = "Py4JNetworkError: no connection to JVM, most
likely due to previous crash or closed JVM from calls to close()"
trace_back_limit = 0
except Exception as e:
- exception_str = str(e)
+ exception_str = e
trace_back_limit = None
self.sds_context.exception_and_close(exception_str, trace_back_limit)
diff --git a/src/main/python/systemds/utils/converters.py
b/src/main/python/systemds/utils/converters.py
index 551a233257..855342d3c1 100644
--- a/src/main/python/systemds/utils/converters.py
+++ b/src/main/python/systemds/utils/converters.py
@@ -20,11 +20,62 @@
# -------------------------------------------------------------
import struct
+import tempfile
+import mmap
+import time
import numpy as np
import pandas as pd
import concurrent.futures
from py4j.java_gateway import JavaClass, JavaGateway, JavaObject, JVMView
+import os
+
+
+def format_bytes(size):
+ for unit in ["Bytes", "KB", "MB", "GB", "TB", "PB"]:
+ if size < 1024.0:
+ return f"{size:.2f} {unit}"
+ size /= 1024.0
+
+
+def pipe_transfer_header(pipe, pipe_id):
+ handshake = struct.pack("<i", pipe_id + 1000)
+ os.write(pipe.fileno(), handshake)
+
+
+def pipe_transfer_bytes(pipe, offset, end, batch_size_bytes, mem_view):
+ while offset < end:
+ # Slice the memoryview without copying
+ slice_end = min(offset + batch_size_bytes, end)
+ chunk = mem_view[offset:slice_end]
+ written = os.write(pipe.fileno(), chunk)
+ if written == 0:
+ raise Exception("Buffer issue")
+ offset += written
+
+
+def pipe_receive_header(pipe, pipe_id, logger):
+ expected_handshake = pipe_id + 1000
+ header = os.read(pipe.fileno(), 4) # pipe.read(4)
+ if len(header) < 4:
+ raise IOError("Failed to read handshake header")
+ received = struct.unpack("<i", header)[0]
+ if received != expected_handshake:
+ raise ValueError(
+ f"Handshake mismatch: expected {expected_handshake}, got
{received}"
+ )
+ logger.debug("Read handshake successfully")
+
+
+def pipe_receive_bytes(pipe, view, offset, end, batch_size_bytes, logger):
+ while offset < end:
+ slice_end = min(offset + batch_size_bytes, end)
+ chunk = os.read(pipe.fileno(), slice_end - offset)
+ if not chunk:
+ raise IOError("Pipe read returned empty data unexpectedly")
+ actual_size = len(chunk)
+ view[offset : offset + actual_size] = chunk
+ offset += actual_size
def numpy_to_matrix_block(sds, np_arr: np.array):
@@ -37,13 +88,17 @@ def numpy_to_matrix_block(sds, np_arr: np.array):
rows = np_arr.shape[0]
cols = np_arr.shape[1] if np_arr.ndim == 2 else 1
+ if rows > 2147483647:
+ raise Exception("")
+
# If not numpy array then convert to numpy array
if not isinstance(np_arr, np.ndarray):
np_arr = np.asarray(np_arr, dtype=np.float64)
jvm: JVMView = sds.java_gateway.jvm
+ ep = sds.java_gateway.entry_point
- # flatten and prepare byte buffer.
+ # flatten and set value type
if np_arr.dtype is np.dtype(np.uint8):
arr = np_arr.ravel()
value_type = jvm.org.apache.sysds.common.Types.ValueType.UINT8
@@ -56,31 +111,137 @@ def numpy_to_matrix_block(sds, np_arr: np.array):
else:
arr = np_arr.ravel().astype(np.float64)
value_type = jvm.org.apache.sysds.common.Types.ValueType.FP64
- buf = arr.tobytes()
- # Send data to java.
- try:
+ if sds._data_transfer_mode == 1:
+ mv = memoryview(arr).cast("B")
+ total_bytes = mv.nbytes
+ min_bytes_per_pipe = 1024 * 1024 * 1024 * 1
+ batch_size_bytes = 32 * 1024 # pipe's ring buffer is 64KB
+
+ # Using multiple pipes is disabled by default
+ use_single_pipe = (
+ not sds._multi_pipe_enabled or total_bytes < 2 * min_bytes_per_pipe
+ )
+ if use_single_pipe:
+ sds._log.debug(
+ "Using single FIFO pipe for reading
{}".format(format_bytes(total_bytes))
+ )
+ pipe_id = 0
+ pipe = sds._FIFO_PY2JAVA_PIPES[pipe_id]
+ fut = sds._executor_pool.submit(
+ ep.startReadingMbFromPipe, pipe_id, rows, cols, value_type
+ )
+
+ pipe_transfer_header(pipe, pipe_id) # start
+ pipe_transfer_bytes(pipe, 0, total_bytes, batch_size_bytes, mv)
+ pipe_transfer_header(pipe, pipe_id) # end
+
+ return fut.result() # Java returns MatrixBlock
+ else:
+ num_pipes = min(
+ len(sds._FIFO_PY2JAVA_PIPES), total_bytes // min_bytes_per_pipe
+ )
+ # align blocks per element
+ num_elems = len(arr)
+ elem_size = np_arr.dtype.itemsize
+ min_elems_block = num_elems // num_pipes
+ left_over = num_elems % num_pipes
+ block_sizes = sds.java_gateway.new_array(jvm.int, num_pipes)
+ for i in range(num_pipes):
+ block_sizes[i] = min_elems_block + int(i < left_over)
+
+ # run java readers in parallel
+ fut_java = sds._executor_pool.submit(
+ ep.startReadingMbFromPipes, block_sizes, rows, cols, value_type
+ )
+
+ # run writers in parallel
+ def _pipe_write_task(_pipe_id, _pipe, memview, start, end):
+ pipe_transfer_header(_pipe, _pipe_id)
+ pipe_transfer_bytes(_pipe, start, end, batch_size_bytes,
memview)
+ pipe_transfer_header(_pipe, _pipe_id)
+
+ cur = 0
+ futures = []
+ for i, size in enumerate(block_sizes):
+ pipe = sds._FIFO_PY2JAVA_PIPES[i]
+ start_byte = cur * elem_size
+ cur += size
+ end_byte = cur * elem_size
+
+ fut = sds._executor_pool.submit(
+ _pipe_write_task, i, pipe, mv, start_byte, end_byte
+ )
+ futures.append(fut)
+
+ return fut_java.result() # Java returns MatrixBlock
+ else:
+ # prepare byte buffer.
+ buf = arr.tobytes()
+
+ # Send data to java.
j_class: JavaClass =
jvm.org.apache.sysds.runtime.util.Py4jConverterUtils
return j_class.convertPy4JArrayToMB(buf, rows, cols, value_type)
- except Exception as e:
- sds.exception_and_close(e)
-def matrix_block_to_numpy(jvm: JVMView, mb: JavaObject):
+def matrix_block_to_numpy(sds, mb: JavaObject):
"""Converts a MatrixBlock object in the JVM to a numpy array.
- :param jvm: The current JVM instance running systemds.
+ :param sds: The current systemds context.
:param mb: A pointer to the JVM's MatrixBlock object.
"""
+ jvm: JVMView = sds.java_gateway.jvm
+ ep = sds.java_gateway.entry_point
+
+ rows = mb.getNumRows()
+ cols = mb.getNumColumns()
+ try:
+ if sds._data_transfer_mode == 1:
+ dtype = np.float64
+
+ elem_size = np.dtype(dtype).itemsize
+ num_elements = rows * cols
+ total_bytes = num_elements * elem_size
+ batch_size_bytes = 32 * 1024 # 32 KB
+
+ arr = np.empty(num_elements, dtype=dtype)
+ mv = memoryview(arr).cast("B")
+
+ pipe_id = 0
+ pipe = sds._FIFO_JAVA2PY_PIPES[pipe_id]
- num_ros = mb.getNumRows()
- num_cols = mb.getNumColumns()
- buf =
jvm.org.apache.sysds.runtime.util.Py4jConverterUtils.convertMBtoPy4JDenseArr(
- mb
- )
- return np.frombuffer(buf, count=num_ros * num_cols,
dtype=np.float64).reshape(
- (num_ros, num_cols)
- )
+ sds._log.debug(
+ "Using single FIFO pipe for reading {}".format(
+ format_bytes(total_bytes)
+ )
+ )
+
+ # Java starts writing to pipe in background
+ fut = sds._executor_pool.submit(ep.startWritingMbToPipe, pipe_id,
mb)
+
+ pipe_receive_header(pipe, pipe_id, sds._log)
+ sds._log.debug(
+ "Py4j task for writing {} [{}] is: done=[{}],
running=[{}]".format(
+ format_bytes(total_bytes), sds._FIFO_PATH, fut.done(),
fut.running()
+ )
+ )
+ pipe_receive_bytes(pipe, mv, 0, total_bytes, batch_size_bytes,
sds._log)
+ pipe_receive_header(pipe, pipe_id, sds._log)
+
+ fut.result()
+ sds._log.debug("Reading is done for
{}".format(format_bytes(total_bytes)))
+ return arr.reshape((rows, cols))
+
+ else:
+ buf =
jvm.org.apache.sysds.runtime.util.Py4jConverterUtils.convertMBtoPy4JDenseArr(
+ mb
+ )
+ return np.frombuffer(buf, count=rows * cols,
dtype=np.float64).reshape(
+ (rows, cols)
+ )
+ except Exception as e:
+ sds.exception_and_close(e)
+ return None
def convert(jvm, fb, idx, num_elements, value_type, pd_series,
conversion="column"):
diff --git a/src/main/python/tests/matrix/test_block_converter.py
b/src/main/python/tests/matrix/test_block_converter.py
index 5fe4b205b6..3e132b4238 100644
--- a/src/main/python/tests/matrix/test_block_converter.py
+++ b/src/main/python/tests/matrix/test_block_converter.py
@@ -73,7 +73,7 @@ class Test_MatrixBlockConverter(unittest.TestCase):
    def convert_back_and_forth(self, array):
        """Round-trip: numpy -> JVM MatrixBlock -> numpy, asserting equality."""
        matrix_block = numpy_to_matrix_block(self.sds, array)
        # use the ability to call functions on matrix_block.
        returned = matrix_block_to_numpy(self.sds, matrix_block)
        self.assertTrue(np.allclose(array, returned))
diff --git a/src/main/python/tests/matrix/test_block_converter_unix_pipe.py
b/src/main/python/tests/matrix/test_block_converter_unix_pipe.py
new file mode 100644
index 0000000000..53a2eb4630
--- /dev/null
+++ b/src/main/python/tests/matrix/test_block_converter_unix_pipe.py
@@ -0,0 +1,104 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+
+import os
+import shutil
+import unittest
+import pandas as pd
+import numpy as np
+from systemds.context import SystemDSContext
+
+
class TestMatrixBlockConverterUnixPipe(unittest.TestCase):
    """Round-trip tests for matrix transfer with FIFO pipes enabled
    (data_transfer_mode=1)."""

    sds: SystemDSContext = None
    temp_dir: str = "tests/iotests/temp_write_csv/"

    @classmethod
    def setUpClass(cls):
        cls.sds = SystemDSContext(
            data_transfer_mode=1, logging_level=10, capture_stdout=True
        )
        if not os.path.exists(cls.temp_dir):
            os.makedirs(cls.temp_dir)

    @classmethod
    def tearDownClass(cls):
        cls.sds.close()
        shutil.rmtree(cls.temp_dir, ignore_errors=True)

    def test_python_to_java(self):
        # (n_rows, n_cols); n_cols == 0 denotes a 1-D input vector
        shapes = [(5, 0), (5, 1), (10, 10)]

        for rows, cols in shapes:
            if cols == 0:
                data = np.random.random(rows)
            else:
                data = np.random.random((rows, cols))

            # Transfer into SystemDS and write to CSV
            sds_matrix = self.sds.from_numpy(data)
            sds_matrix.write(
                self.temp_dir + "into_systemds_matrix.csv", format="csv", header=False
            ).compute(verbose=True)

            # Read the CSV file using pandas
            roundtrip = pd.read_csv(
                self.temp_dir + "into_systemds_matrix.csv", header=None
            ).to_numpy()
            if cols == 0:
                roundtrip = roundtrip.flatten()

            # Verify the data
            self.assertTrue(np.allclose(roundtrip, data))

    def test_java_to_python(self):
        shapes = [(5, 1), (10, 10)]  # (n_rows, n_cols)

        for rows, cols in shapes:
            data = np.random.random((rows, cols))

            # Create a CSV file to read into SystemDS
            pd.DataFrame(data).to_csv(
                self.temp_dir + "out_of_systemds_matrix.csv", header=False, index=False
            )

            roundtrip = self.sds.read(
                self.temp_dir + "out_of_systemds_matrix.csv",
                data_type="matrix",
                format="csv",
            ).compute()

            # Verify the data
            self.assertTrue(np.allclose(roundtrip, data))


if __name__ == "__main__":
    unittest.main(exit=False)
diff --git
a/src/test/java/org/apache/sysds/test/component/utils/UnixPipeUtilsTest.java
b/src/test/java/org/apache/sysds/test/component/utils/UnixPipeUtilsTest.java
new file mode 100644
index 0000000000..650d6c1053
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/utils/UnixPipeUtilsTest.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.utils;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.UnixPipeUtils;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.EOFException;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.Arrays;
+import java.util.Collection;
+
+import static org.junit.Assert.assertArrayEquals;
+
+
/**
 * Unit tests for {@code UnixPipeUtils}: batched read/write of numpy-style
 * arrays over stream endpoints, plus handshake validation on open.
 * Regular temp files stand in for FIFO pipes here.
 */
@RunWith(Enclosed.class)
public class UnixPipeUtilsTest {

	/** Write/read round-trip for each supported value type. */
	@RunWith(Parameterized.class)
	public static class ParameterizedTest {
		@Rule
		public TemporaryFolder folder = new TemporaryFolder();

		// {value type, #elements, batch size in bytes, pipe id, input block}
		@Parameterized.Parameters(name = "{index}: type={0}")
		public static Collection<Object[]> data() {
			return Arrays.asList(new Object[][]{
				{Types.ValueType.FP64, 6, 48, 99, new MatrixBlock(2, 3, new double[]{1.0, 2.0, 3.0, 4.0, 5.0, 6.0})},
				{Types.ValueType.FP32, 6, 24, 88, new MatrixBlock(3, 2, new double[]{1.0, 2.0, 3.0, 4.0, 5.0, 6.0})},
				{Types.ValueType.INT32, 4, 16, 77, new MatrixBlock(2, 2, new double[]{0, -1, 2, -3})},
				{Types.ValueType.UINT8, 4, 4, 66, new MatrixBlock(2, 2, new double[]{0, 1, 2, 3})}
			});
		}

		private final Types.ValueType type;
		private final int numElem;
		private final int batchSize;
		private final int id;
		private final MatrixBlock matrixBlock;


		public ParameterizedTest(Types.ValueType type, int numElem, int batchSize, int id, MatrixBlock matrixBlock) {
			this.type = type;
			this.numElem = numElem;
			this.batchSize = batchSize;
			this.id = id;
			this.matrixBlock = matrixBlock;
		}

		/** Writes the block to a file and reads it back; values must survive the round trip. */
		@Test
		public void testReadWriteNumpyArrayBatch() throws IOException {
			File tempFile = folder.newFile("pipe_test_" + type.name());

			try (BufferedOutputStream out = UnixPipeUtils.openOutput(tempFile.getAbsolutePath(), id)) {
				UnixPipeUtils.writeNumpyArrayInBatches(out, id, batchSize, numElem, type, matrixBlock);
			}

			double[] output = new double[numElem];
			try (BufferedInputStream in = UnixPipeUtils.openInput(tempFile.getAbsolutePath(), id)) {
				UnixPipeUtils.readNumpyArrayInBatches(in, id, batchSize, numElem, type, output, 0);
			}

			// 1e-9 tolerance covers the FP32 lossy round trip of these small values
			assertArrayEquals(matrixBlock.getDenseBlockValues(), output, 1e-9);
		}
	}

	/** Error-path tests: missing files, handshake mismatch, truncated streams. */
	public static class NonParameterizedTest {
		@Rule
		public TemporaryFolder folder = new TemporaryFolder();

		@Test(expected = FileNotFoundException.class)
		public void testOpenInputFileNotFound() throws IOException {
			// instantiate class once for coverage
			new UnixPipeUtils();

			// Create a path that does not exist
			File nonExistentFile = new File(folder.getRoot(), "nonexistent.pipe");

			// This should throw FileNotFoundException
			UnixPipeUtils.openInput(nonExistentFile.getAbsolutePath(), 123);
		}

		@Test(expected = FileNotFoundException.class)
		public void testOpenOutputFileNotFound() throws IOException {
			// Create a path that does not exist
			File nonExistentFile = new File(folder.getRoot(), "nonexistent.pipe");

			// This should throw FileNotFoundException
			UnixPipeUtils.openOutput(nonExistentFile.getAbsolutePath(), 123);
		}


		@Test
		public void testOpenInputAndOutputHandshakeMatch() throws IOException {
			File tempFile = folder.newFile("pipe_test1");
			int id = 42;

			// Write expected handshake (openOutput emits it on open)
			try (BufferedOutputStream bos = UnixPipeUtils.openOutput(tempFile.getAbsolutePath(), id)) {}

			// Read and validate handshake
			try (BufferedInputStream bis = UnixPipeUtils.openInput(tempFile.getAbsolutePath(), id)) {
				// success: no exception = handshake passed
			}
		}

		@Test(expected = IllegalStateException.class)
		public void testOpenInputHandshakeMismatch() throws IOException {
			File tempFile = folder.newFile("pipe_test2");
			int writeId = 123;
			int wrongReadId = 456;

			try (BufferedOutputStream bos = UnixPipeUtils.openOutput(tempFile.getAbsolutePath(), writeId)) {}

			// Will throw due to ID mismatch
			UnixPipeUtils.openInput(tempFile.getAbsolutePath(), wrongReadId);
		}

		@Test(expected = IOException.class)
		public void testOpenInputIncompleteHandshake() throws IOException {
			File tempFile = folder.newFile("short_handshake.pipe");

			// Write only 2 bytes instead of 4
			try (FileOutputStream fos = new FileOutputStream(tempFile)) {
				fos.write(new byte[]{0x01, 0x02});
			}

			UnixPipeUtils.openInput(tempFile.getAbsolutePath(), 100);
		}

		@Test(expected = EOFException.class)
		public void testReadNumpyArrayUnexpectedEOF() throws IOException {
			File tempFile = folder.newFile("pipe_test5");
			int id = 12;
			int numElem = 5;
			int batchSize = 40;
			Types.ValueType type = Types.ValueType.FP64;

			// Write partial data (handshake + 3 doubles instead of 5)
			// NOTE(review): start token is written as id + 1000 — presumably the
			// protocol offset expected by readNumpyArrayInBatches; confirm in UnixPipeUtils.
			try (BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(tempFile))) {
				ByteBuffer bb = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(id + 1000);
				out.write(bb.array());

				// Write 3 doubles only
				bb = ByteBuffer.allocate(8 * 3).order(ByteOrder.LITTLE_ENDIAN);
				for (int i = 0; i < 3; i++)
					bb.putDouble(i + 1.0);
				out.write(bb.array());

				// no end handshake
				out.flush();
			}

			double[] outArr = new double[numElem];
			try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(tempFile))) {
				UnixPipeUtils.readNumpyArrayInBatches(in, id, batchSize, numElem, type, outArr, 0);
			}
		}
	}
}
diff --git
a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
index 549d4cde7b..b048e66e95 100644
--- a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
+++ b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
@@ -23,16 +23,27 @@ import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.spi.LoggingEvent;
import org.apache.sysds.api.PythonDMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.UnixPipeUtils;
import org.apache.sysds.test.LoggingUtils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
import py4j.GatewayServer;
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.File;
import java.security.Permission;
import java.util.List;
+import static org.junit.Assert.assertArrayEquals;
+
/** Simple tests to verify startup of Python Gateway server happens without
crashes */
public class StartupTest {
@@ -58,18 +69,36 @@ public class StartupTest {
}
private void assertLogMessages(String... expectedMessages) {
+ assertLogMessages(true, expectedMessages);
+ }
+
+ private void assertLogMessages(boolean strict, String...
expectedMessages) {
List<LoggingEvent> log = LoggingUtils.reinsert(appender);
log.stream().forEach(l -> System.out.println(l.getMessage()));
- Assert.assertEquals("Unexpected number of log messages",
expectedMessages.length, log.size());
+ if (strict){
+ Assert.assertEquals("Unexpected number of log
messages", expectedMessages.length, log.size());
- for (int i = 0; i < expectedMessages.length; i++) {
- // order does not matter
- boolean found = false;
+ for (int i = 0; i < expectedMessages.length; i++) {
+ // order does not matter
+ boolean found = false;
+ for (String message : expectedMessages) {
+ found |=
log.get(i).getMessage().toString().startsWith(message);
+ }
+ Assert.assertTrue("Unexpected log message: " +
log.get(i).getMessage(),found);
+ }
+ } else {
for (String message : expectedMessages) {
- found |=
log.get(i).getMessage().toString().startsWith(message);
+ // order does not matter
+ boolean found = false;
+
+ for (LoggingEvent loggingEvent : log) {
+ found |=
loggingEvent.getMessage().toString().startsWith(message);
+ }
+ Assert.assertTrue("Expected log message not
found: " + message,found);
}
- Assert.assertTrue("Unexpected log message: " +
log.get(i).getMessage(),found);
}
+
+
}
@Test(expected = Exception.class)
@@ -108,7 +137,7 @@ public class StartupTest {
PythonDMLScript.main(new String[]{"-python", "4001"});
Thread.sleep(200);
} catch (SecurityException e) {
- assertLogMessages(
+ assertLogMessages(false,
"GatewayServer started",
"failed startup"
);
@@ -125,6 +154,7 @@ public class StartupTest {
PythonDMLScript.GwS.shutdown();
Thread.sleep(200);
assertLogMessages(
+ false,
"GatewayServer started",
"Starting JVM shutdown",
"Shutdown done",
@@ -132,6 +162,98 @@ public class StartupTest {
);
}
	@Rule
	public TemporaryFolder folder = new TemporaryFolder();

	/**
	 * Round-trip over a single pipe pair: Java writes a MatrixBlock that this
	 * test reads back, then this test writes raw data that Java reads back.
	 * The py2java endpoint is opened before openPipes and java2py after it —
	 * NOTE(review): this ordering presumably mirrors FIFO open semantics
	 * (open blocks until both ends attach); confirm against openPipes.
	 */
	@Test
	public void testDataTransfer() throws Exception {
		PythonDMLScript.main(new String[]{"-python", "4003"});
		Thread.sleep(200);
		PythonDMLScript script = (PythonDMLScript) PythonDMLScript.GwS.getGateway().getEntryPoint();

		File in = folder.newFile("py2java-0");
		File out = folder.newFile("java2py-0");

		// Init Test
		BufferedOutputStream py2java = UnixPipeUtils.openOutput(in.getAbsolutePath(), 0);
		script.openPipes(folder.getRoot().getPath(), 1);
		BufferedInputStream java2py = UnixPipeUtils.openInput(out.getAbsolutePath(), 0);

		// Write Test: Java side writes the block, we read it as a flat array
		double[] data = new double[]{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
		MatrixBlock mb = new MatrixBlock(2, 3, data);
		script.startWritingMbToPipe(0, mb);
		double[] rcv_data = new double[data.length];
		UnixPipeUtils.readNumpyArrayInBatches(java2py, 0, 32, data.length, Types.ValueType.FP64, rcv_data, 0);
		assertArrayEquals(data, rcv_data, 1e-9);

		// Read Test: we write the array, Java side reconstructs a 2x3 block
		UnixPipeUtils.writeNumpyArrayInBatches(py2java, 0, 32, data.length, Types.ValueType.FP64, mb);
		MatrixBlock rcv_mb = script.startReadingMbFromPipe(0, 2, 3, Types.ValueType.FP64);
		assertArrayEquals(data, rcv_mb.getDenseBlockValues(), 1e-9);


		script.closePipes();

		PythonDMLScript.GwS.shutdown();
		Thread.sleep(200);
	}
+
	/**
	 * Partitioned transfer: two pipes each carry a 3x1 slice, and the Java
	 * side concatenates them into a single 6x1 block via startReadingMbFromPipes.
	 */
	@Test
	public void testDataTransferMultiPipes() throws Exception {
		PythonDMLScript.main(new String[]{"-python", "4004"});
		Thread.sleep(200);
		PythonDMLScript script = (PythonDMLScript) PythonDMLScript.GwS.getGateway().getEntryPoint();

		// Create both pipe pairs; only the py2java endpoints are used here
		File in = folder.newFile("py2java-0");
		folder.newFile("java2py-0");
		File in2 = folder.newFile("py2java-1");
		folder.newFile("java2py-1");

		// Init Test
		BufferedOutputStream py2java = UnixPipeUtils.openOutput(in.getAbsolutePath(), 0);
		BufferedOutputStream py2java2 = UnixPipeUtils.openOutput(in2.getAbsolutePath(), 1);
		script.openPipes(folder.getRoot().getPath(), 2);

		// Read Test: same 3-element slice through each pipe, rows {3,3} -> 6x1
		double[] data = new double[]{1.0, 2.0, 3.0};
		MatrixBlock mb = new MatrixBlock(3, 1, data);
		UnixPipeUtils.writeNumpyArrayInBatches(py2java, 0, 32, 3, Types.ValueType.FP64, mb);
		UnixPipeUtils.writeNumpyArrayInBatches(py2java2, 1, 32, 3, Types.ValueType.FP64, mb);
		MatrixBlock rcv_mb = script.startReadingMbFromPipes(new int[]{3,3}, 6, 1, Types.ValueType.FP64);
		data = new double[]{1.0, 2.0, 3.0, 1.0, 2.0, 3.0};
		assertArrayEquals(data, rcv_mb.getDenseBlockValues(), 1e-9);

		script.closePipes();

		PythonDMLScript.GwS.shutdown();
		Thread.sleep(200);
	}
+
	/** Reading without openPipes must fail with DMLRuntimeException. */
	@Test(expected = DMLRuntimeException.class)
	public void testDataTransferNotInit1() throws Exception {
		PythonDMLScript.main(new String[]{"-python", "4005"});
		Thread.sleep(200);
		PythonDMLScript script = (PythonDMLScript) PythonDMLScript.GwS.getGateway().getEntryPoint();
		script.startReadingMbFromPipe(0, 2, 3, Types.ValueType.FP64);
	}

	/** Writing without openPipes must fail with DMLRuntimeException. */
	@Test(expected = DMLRuntimeException.class)
	public void testDataTransferNotInit2() throws Exception {
		PythonDMLScript.main(new String[]{"-python", "4006"});
		Thread.sleep(200);
		PythonDMLScript script = (PythonDMLScript) PythonDMLScript.GwS.getGateway().getEntryPoint();
		script.startWritingMbToPipe(0, null);
	}

	/** Multi-pipe read without openPipes must fail with DMLRuntimeException. */
	@Test(expected = DMLRuntimeException.class)
	public void testDataTransferNotInit3() throws Exception {
		PythonDMLScript.main(new String[]{"-python", "4007"});
		Thread.sleep(200);
		PythonDMLScript script = (PythonDMLScript) PythonDMLScript.GwS.getGateway().getEntryPoint();
		script.startReadingMbFromPipes(new int[]{3,3}, 2, 3, Types.ValueType.FP64);
	}
+
@SuppressWarnings("removal")
class NoExitSecurityManager extends SecurityManager {
@Override
@@ -142,4 +264,20 @@ public class StartupTest {
throw new SecurityException("Intercepted exit()");
}
}
+
	/** Oversized dimensions (Integer.MAX_VALUE rows) must be rejected on single-pipe read. */
	@Test(expected = DMLRuntimeException.class)
	public void testDataTransferMaxValue1() throws Exception {
		PythonDMLScript.main(new String[]{"-python", "4008"});
		Thread.sleep(200);
		PythonDMLScript script = (PythonDMLScript) PythonDMLScript.GwS.getGateway().getEntryPoint();
		script.startReadingMbFromPipe(0, Integer.MAX_VALUE, 3, Types.ValueType.FP64);
	}

	/** Oversized dimensions must also be rejected on the multi-pipe read path. */
	@Test(expected = DMLRuntimeException.class)
	public void testDataTransferMaxValue2() throws Exception {
		PythonDMLScript.main(new String[]{"-python", "4009"});
		Thread.sleep(200);
		PythonDMLScript script = (PythonDMLScript) PythonDMLScript.GwS.getGateway().getEntryPoint();
		script.startReadingMbFromPipes(new int[]{3,3}, Integer.MAX_VALUE, 2, Types.ValueType.FP64);
	}
}