This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 0386fe1917 [MINOR] Fix in Python API GatewayServerListener
0386fe1917 is described below
commit 0386fe1917fd71df6008f24c4f000efd0ffece0f
Author: e-strauss <[email protected]>
AuthorDate: Mon Mar 31 15:22:01 2025 +0200
[MINOR] Fix in Python API GatewayServerListener
The current inner class DMLGateWayListener implements functions from the
GatewayServerListener interface that are never invoked by the GatewayServer
(since the GatewayServer, which also implements GatewayServerListener, does not
implement these methods). Furthermore, DMLGateWayListener previously called
System.exit(), which I think is not correct, since it breaks the proper shutdown
of the GatewayServer. Finally, this commit adds a new unit test, which checks
the functionality of the D [...]
Closes #2243
---
.../java/org/apache/sysds/api/PythonDMLScript.java | 42 ++++-------
.../sysds/test/usertest/pythonapi/StartupTest.java | 88 ++++++++++++++++++++++
2 files changed, 101 insertions(+), 29 deletions(-)
diff --git a/src/main/java/org/apache/sysds/api/PythonDMLScript.java
b/src/main/java/org/apache/sysds/api/PythonDMLScript.java
index c4957d4e9f..80f5ffcd75 100644
--- a/src/main/java/org/apache/sysds/api/PythonDMLScript.java
+++ b/src/main/java/org/apache/sysds/api/PythonDMLScript.java
@@ -21,17 +21,20 @@ package org.apache.sysds.api;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
import org.apache.sysds.api.jmlc.Connection;
+import py4j.DefaultGatewayServerListener;
import py4j.GatewayServer;
-import py4j.GatewayServerListener;
import py4j.Py4JNetworkException;
-import py4j.Py4JServerConnection;
+
public class PythonDMLScript {
private static final Log LOG =
LogFactory.getLog(PythonDMLScript.class.getName());
final private Connection _connection;
+ public static GatewayServer GwS;
/**
* Entry point for Python API.
@@ -42,7 +45,7 @@ public class PythonDMLScript {
public static void main(String[] args) throws Exception {
final DMLOptions dmlOptions = DMLOptions.parseCLArguments(args);
DMLScript.loadConfiguration(dmlOptions.configFile);
- final GatewayServer GwS = new GatewayServer(new
PythonDMLScript(), dmlOptions.pythonPort);
+ GwS = new GatewayServer(new PythonDMLScript(),
dmlOptions.pythonPort);
GwS.addListener(new DMLGateWayListener());
try {
GwS.start();
@@ -67,38 +70,20 @@ public class PythonDMLScript {
_connection = new Connection();
}
+ public static void setDMLGateWayListenerLoggerLevel(Level l){
+ Logger.getLogger(DMLGateWayListener.class).setLevel(l);
+ }
+
public Connection getConnection() {
return _connection;
}
- protected static class DMLGateWayListener implements
GatewayServerListener {
+ protected static class DMLGateWayListener extends
DefaultGatewayServerListener {
private static final Log LOG =
LogFactory.getLog(DMLGateWayListener.class.getName());
- @Override
- public void connectionError(Exception e) {
- LOG.warn("Connection error: " + e.getMessage());
- System.exit(1);
- }
-
- @Override
- public void connectionStarted(Py4JServerConnection
gatewayConnection) {
- LOG.debug("Connection Started: " +
gatewayConnection.toString());
- }
-
- @Override
- public void connectionStopped(Py4JServerConnection
gatewayConnection) {
- LOG.debug("Connection stopped: " +
gatewayConnection.toString());
- }
-
- @Override
- public void serverError(Exception e) {
- LOG.error("Server Error " + e.getMessage());
- }
-
@Override
public void serverPostShutdown() {
LOG.info("Shutdown done");
- System.exit(0);
}
@Override
@@ -108,13 +93,12 @@ public class PythonDMLScript {
@Override
public void serverStarted() {
- LOG.info("GatewayServer Started");
+ LOG.info("GatewayServer started");
}
@Override
public void serverStopped() {
- LOG.info("GatewayServer Stopped");
- System.exit(0);
+ LOG.info("GatewayServer stopped");
}
}
diff --git
a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
index 4b8395107f..9e7cda13ee 100644
--- a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
+++ b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
@@ -19,11 +19,55 @@
package org.apache.sysds.test.usertest.pythonapi;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
+import org.apache.log4j.spi.LoggingEvent;
import org.apache.sysds.api.PythonDMLScript;
+import org.apache.sysds.test.LoggingUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
import org.junit.Test;
+import py4j.GatewayServer;
+
+import java.security.Permission;
+import java.util.List;
+
/** Simple tests to verify startup of Python Gateway server happens without
crashes */
public class StartupTest {
+ private LoggingUtils.TestAppender appender;
+ private SecurityManager sm;
+
+ @Before
+ public void setUp() {
+ appender = LoggingUtils.overwrite();
+ sm = System.getSecurityManager();
+ System.setSecurityManager(new NoExitSecurityManager());
+ PythonDMLScript.setDMLGateWayListenerLoggerLevel(Level.ALL);
+
Logger.getLogger(PythonDMLScript.class.getName()).setLevel(Level.ALL);
+ }
+
+ @After
+ public void tearDown() {
+ LoggingUtils.reinsert(appender);
+ System.setSecurityManager(sm);
+ }
+
+ private void assertLogMessages(String... expectedMessages) {
+ List<LoggingEvent> log = LoggingUtils.reinsert(appender);
+ log.stream().forEach(l -> System.out.println(l.getMessage()));
+ Assert.assertEquals("Unexpected number of log messages",
expectedMessages.length, log.size());
+
+ for (int i = 0; i < expectedMessages.length; i++) {
+ // order does not matter
+ boolean found = false;
+ for (String message : expectedMessages) {
+ found |=
log.get(i).getMessage().toString().startsWith(message);
+ }
+ Assert.assertTrue("Unexpected log message: " +
log.get(i).getMessage(),found);
+ }
+ }
@Test(expected = Exception.class)
public void testStartupIncorrect_1() throws Exception {
@@ -50,4 +94,48 @@ public class StartupTest {
// Number out of range
PythonDMLScript.main(new String[] {"-python", "918757"});
}
+
+ @Test
+ public void testStartupIncorrect_6() throws Exception {
+ GatewayServer gws1 = null;
+ try {
+ PythonDMLScript.main(new String[]{"-python", "4001"});
+ gws1 = PythonDMLScript.GwS;
+ Thread.sleep(200);
+ PythonDMLScript.main(new String[]{"-python", "4001"});
+ Thread.sleep(200);
+ } catch (SecurityException e) {
+ assertLogMessages(
+ "GatewayServer started",
+ "failed startup"
+ );
+ gws1.shutdown();
+ }
+ }
+
+ @Test
+ public void testStartupCorrect() throws Exception {
+ PythonDMLScript.main(new String[]{"-python", "4002"});
+ Thread.sleep(200);
+ PythonDMLScript script = (PythonDMLScript)
PythonDMLScript.GwS.getGateway().getEntryPoint();
+ script.getConnection();
+ PythonDMLScript.GwS.shutdown();
+ Thread.sleep(200);
+ assertLogMessages(
+ "GatewayServer started",
+ "Starting JVM shutdown",
+ "Shutdown done",
+ "GatewayServer stopped"
+ );
+ }
+
+ class NoExitSecurityManager extends SecurityManager {
+ @Override
+ public void checkPermission(Permission perm) { }
+
+ @Override
+ public void checkExit(int status) {
+ throw new SecurityException("Intercepted exit()");
+ }
+ }
}