This commit fixes a bug where long-running OsDiagnose jobs would keep
executing unnecessarily after a RAPI client had already closed the
connection. The fix introduces a couple of new polling/timeout
constants for Luxi and RAPI, used to poll the HTTP connection while
simultaneously monitoring the state of the submitted job.

If the connection is closed on the client side, the job is canceled
and, unless it is already running, removed from the job queue.

Signed-off-by: Federico Morg Pareschi <[email protected]>
---
 lib/cli.py                      | 69 ++++++++++++++++++++++++++++++++++-------
 lib/errors.py                   |  9 ++++++
 lib/http/server.py              | 10 +++---
 lib/rapi/client_utils.py        |  3 +-
 lib/rapi/rlib2.py               | 28 ++++++++++++++++-
 lib/rapi/testutils.py           |  3 +-
 src/Ganeti/Constants.hs         | 10 +++++-
 test/py/ganeti.cli_unittest.py  | 42 +++++++++++++++++++++++--
 test/py/ganeti.http_unittest.py | 23 +++++++-------
 9 files changed, 165 insertions(+), 32 deletions(-)

diff --git a/lib/cli.py b/lib/cli.py
index 2001ed9..a470ffa 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -705,7 +705,8 @@ def SendJob(ops, cl=None):
   return job_id
 
 
-def GenericPollJob(job_id, cbs, report_cbs):
+def GenericPollJob(job_id, cbs, report_cbs, cancel_fn=None,
+                   update_freq=constants.DEFAULT_WFJC_TIMEOUT):
   """Generic job-polling function.
 
   @type job_id: number
@@ -714,9 +715,13 @@ def GenericPollJob(job_id, cbs, report_cbs):
   @param cbs: Data callbacks
   @type report_cbs: Instance of L{JobPollReportCbBase}
   @param report_cbs: Reporting callbacks
-
+  @type cancel_fn: Function returning a boolean
+  @param cancel_fn: Function to check if we should cancel the running job
+  @type update_freq: int/long
+  @param update_freq: number of seconds between each WFJC reports
   @return: the opresult of the job
   @raise errors.JobLost: If job can't be found
+  @raise errors.JobCanceled: If job is canceled
   @raise errors.OpExecError: If job didn't succeed
 
   """
@@ -724,17 +729,36 @@ def GenericPollJob(job_id, cbs, report_cbs):
   prev_logmsg_serial = None
 
   status = None
+  should_cancel = False
+
+  if update_freq <= 0:
+    raise errors.ParameterError("Update frequency must be a positive number")
 
   while True:
-    result = cbs.WaitForJobChangeOnce(job_id, ["status"], prev_job_info,
-                                      prev_logmsg_serial)
+    if cancel_fn:
+      timer = 0
+      while timer < update_freq:
+        result = cbs.WaitForJobChangeOnce(job_id, ["status"], prev_job_info,
+                                          prev_logmsg_serial,
+                                          timeout=constants.CLI_WFJC_FREQUENCY)
+        should_cancel = cancel_fn()
+        if should_cancel or not result or result != constants.JOB_NOTCHANGED:
+          break
+        timer += constants.CLI_WFJC_FREQUENCY
+    else:
+      result = cbs.WaitForJobChangeOnce(job_id, ["status"], prev_job_info,
+                                      prev_logmsg_serial, timeout=update_freq)
     if not result:
       # job not found, go away!
       raise errors.JobLost("Job with id %s lost" % job_id)
 
+    if should_cancel:
+      logging.info("Job %s canceled because the client timed out.", job_id)
+      cbs.CancelJob(job_id)
+      raise errors.JobCanceled("Job was canceled")
+
     if result == constants.JOB_NOTCHANGED:
       report_cbs.ReportNotChanged(job_id, status)
-
       # Wait again
       continue
 
@@ -768,7 +792,7 @@ def GenericPollJob(job_id, cbs, report_cbs):
     return result
 
   if status in (constants.JOB_STATUS_CANCELING, constants.JOB_STATUS_CANCELED):
-    raise errors.OpExecError("Job was canceled")
+    raise errors.JobCanceled("Job was canceled")
 
   has_ok = False
   for idx, (status, msg) in enumerate(zip(opstatus, result)):
@@ -797,7 +821,8 @@ class JobPollCbBase(object):
     """
 
   def WaitForJobChangeOnce(self, job_id, fields,
-                           prev_job_info, prev_log_serial):
+                           prev_job_info, prev_log_serial,
+                           timeout=constants.DEFAULT_WFJC_TIMEOUT):
     """Waits for changes on a job.
 
     """
@@ -814,6 +839,15 @@ class JobPollCbBase(object):
     """
     raise NotImplementedError()
 
+  def CancelJob(self, job_id):
+    """Cancels a currently running job.
+
+    @type job_id: number
+    @param job_id: The ID of the Job we want to cancel
+
+    """
+    raise NotImplementedError()
+
 
 class JobPollReportCbBase(object):
   """Base class for L{GenericPollJob} reporting callbacks.
@@ -851,12 +885,14 @@ class _LuxiJobPollCb(JobPollCbBase):
     self.cl = cl
 
   def WaitForJobChangeOnce(self, job_id, fields,
-                           prev_job_info, prev_log_serial):
+                           prev_job_info, prev_log_serial,
+                           timeout=constants.DEFAULT_WFJC_TIMEOUT):
     """Waits for changes on a job.
 
     """
     return self.cl.WaitForJobChangeOnce(job_id, fields,
-                                        prev_job_info, prev_log_serial)
+                                        prev_job_info, prev_log_serial,
+                                        timeout=timeout)
 
   def QueryJobs(self, job_ids, fields):
     """Returns the selected fields for the selected job IDs.
@@ -864,6 +900,11 @@ class _LuxiJobPollCb(JobPollCbBase):
     """
     return self.cl.QueryJobs(job_ids, fields)
 
+  def CancelJob(self, job_id):
+    """Cancels a currently running job via the luxi client."""
+    # Whatever the luxi client returns is passed through to the caller
+    return self.cl.CancelJob(job_id)
+
 
 class FeedbackFnJobPollReportCb(JobPollReportCbBase):
   def __init__(self, feedback_fn):
@@ -932,7 +973,8 @@ def FormatLogMessage(log_type, log_msg):
   return utils.SafeEncode(log_msg)
 
 
-def PollJob(job_id, cl=None, feedback_fn=None, reporter=None):
+def PollJob(job_id, cl=None, feedback_fn=None, reporter=None, cancel_fn=None,
+            update_freq=constants.DEFAULT_WFJC_TIMEOUT):
   """Function to poll for the result of a job.
 
   @type job_id: job identified
@@ -940,6 +982,10 @@ def PollJob(job_id, cl=None, feedback_fn=None, reporter=None):
   @type cl: luxi.Client
   @param cl: the luxi client to use for communicating with the master;
              if None, a new client will be created
+  @type cancel_fn: Function returning a boolean
+  @param cancel_fn: Function to check if we should cancel the running job
+  @type update_freq: int/long
+  @param update_freq: number of seconds between each WFJC report
 
   """
   if cl is None:
@@ -953,7 +999,8 @@ def PollJob(job_id, cl=None, feedback_fn=None, reporter=None):
   elif feedback_fn:
     raise errors.ProgrammerError("Can't specify reporter and feedback function")
 
-  return GenericPollJob(job_id, _LuxiJobPollCb(cl), reporter)
+  return GenericPollJob(job_id, _LuxiJobPollCb(cl), reporter,
+                        cancel_fn=cancel_fn, update_freq=update_freq)
 
 
 def SubmitOpCode(op, cl=None, feedback_fn=None, opts=None, reporter=None):
diff --git a/lib/errors.py b/lib/errors.py
index b53ced7..e8671a2 100644
--- a/lib/errors.py
+++ b/lib/errors.py
@@ -242,6 +242,15 @@ class JobLost(GenericError):
   """
 
 
+class JobCanceled(GenericError):
+  """Submitted job was canceled.
+
+  The job that was submitted has transitioned to a canceling or canceled
+  state.
+
+  """
+
+
 class JobFileCorrupted(GenericError):
   """Job file could not be properly decoded/restored.
 
diff --git a/lib/http/server.py b/lib/http/server.py
index 81be4a6..f191c25 100644
--- a/lib/http/server.py
+++ b/lib/http/server.py
@@ -86,12 +86,13 @@ class _HttpServerRequest(object):
   """Data structure for HTTP request on server side.
 
   """
-  def __init__(self, method, path, headers, body):
+  def __init__(self, method, path, headers, body, sock):
     # Request attributes
     self.request_method = method
     self.request_path = path
     self.request_headers = headers
     self.request_body = body
+    self.request_sock = sock
 
     # Response attributes
     self.resp_headers = {}
@@ -225,14 +226,15 @@ class _HttpClientToServerMessageReader(http.HttpMessageReader):
     return http.HttpClientToServerStartLine(method, path, version)
 
 
-def _HandleServerRequestInner(handler, req_msg):
+def _HandleServerRequestInner(handler, req_msg, reader):
   """Calls the handler function for the current request.
 
   """
   handler_context = _HttpServerRequest(req_msg.start_line.method,
                                        req_msg.start_line.path,
                                        req_msg.headers,
-                                       req_msg.body)
+                                       req_msg.body,
+                                       reader.sock)
 
   logging.debug("Handling request %r", handler_context)
 
@@ -308,7 +310,7 @@ class HttpResponder(object):
 
       (response_msg.start_line.code, response_msg.headers,
        response_msg.body) = \
-        _HandleServerRequestInner(self._handler, request_msg)
+        _HandleServerRequestInner(self._handler, request_msg, req_msg_reader)
     except http.HttpException, err:
       self._SetError(self.responses, self._handler, response_msg, err)
     else:
diff --git a/lib/rapi/client_utils.py b/lib/rapi/client_utils.py
index 224e1a2..9b7911d 100644
--- a/lib/rapi/client_utils.py
+++ b/lib/rapi/client_utils.py
@@ -53,7 +53,8 @@ class RapiJobPollCb(cli.JobPollCbBase):
     self.cl = cl
 
   def WaitForJobChangeOnce(self, job_id, fields,
-                           prev_job_info, prev_log_serial):
+                           prev_job_info, prev_log_serial,
+                           timeout=constants.DEFAULT_WFJC_TIMEOUT):
     """Waits for changes on a job.
 
     """
diff --git a/lib/rapi/rlib2.py b/lib/rapi/rlib2.py
index 34b4124..8916004 100644
--- a/lib/rapi/rlib2.py
+++ b/lib/rapi/rlib2.py
@@ -64,6 +64,8 @@ PUT should be prefered over POST.
 
 # C0103: Invalid name, since the R_* names are not conforming
 
+import OpenSSL
+
 from ganeti import opcodes
 from ganeti import objects
 from ganeti import http
@@ -198,6 +200,28 @@ def _UpdateBeparams(inst):
 
   return inst
 
+def _CheckIfConnectionDropped(sock):
+  """Utility function to monitor the state of an open connection.
+
+  @param sock: Connection's open socket, or None when no socket is
+      available (as in the RAPI test mock); None is never reported as closed
+  @return: True if the connection was remotely closed, otherwise False
+
+  """
+  if sock is None:
+    return False
+  try:
+    # An empty read means the peer closed the connection.
+    # NOTE(review): on a plain (non-SSL) socket recv(0) may return "" even
+    # while the peer is still connected -- confirm before use without SSL.
+    return sock.recv(0) == ""
+  except OpenSSL.SSL.WantReadError:
+    # No data available yet: the connection is still open
+    return False
+  except (OpenSSL.SSL.ZeroReturnError, OpenSSL.SSL.SysCallError):
+    # Graceful TLS shutdown or abrupt termination of the connection
+    return True
+
 
 class R_root(baserlib.ResourceBase):
   """/ resource.
@@ -278,9 +302,11 @@ class R_2_os(baserlib.OpcodeResource):
     """
     cl = self.GetClient()
     op = opcodes.OpOsDiagnose(output_fields=["name", "variants"], names=[])
+    cancel_fn = (lambda: _CheckIfConnectionDropped(self._req.request_sock))
     job_id = self.SubmitJob([op], cl=cl)
     # we use custom feedback function, instead of print we log the status
-    result = cli.PollJob(job_id, cl, feedback_fn=baserlib.FeedbackFn)
+    result = cli.PollJob(job_id, cl, feedback_fn=baserlib.FeedbackFn,
+                         cancel_fn=cancel_fn)
     diagnose_data = result[0]
 
     if not isinstance(diagnose_data, list):
diff --git a/lib/rapi/testutils.py b/lib/rapi/testutils.py
index 8f7a7ad..4d054cd 100644
--- a/lib/rapi/testutils.py
+++ b/lib/rapi/testutils.py
@@ -252,9 +252,10 @@ class _RapiMock(object):
       http.HttpClientToServerStartLine(method, path, http.HTTP_1_0)
     req_msg.headers = headers
     req_msg.body = request_body
+    req_reader = type('TestReader', (object, ), {'sock': None})()
 
     (_, _, _, resp_msg) = \
-      http.server.HttpResponder(self.handler)(lambda: (req_msg, None))
+      http.server.HttpResponder(self.handler)(lambda: (req_msg, req_reader))
 
     return (resp_msg.start_line.code, resp_msg.headers, resp_msg.body)
 
diff --git a/src/Ganeti/Constants.hs b/src/Ganeti/Constants.hs
index 5ef9dac..09783d4 100644
--- a/src/Ganeti/Constants.hs
+++ b/src/Ganeti/Constants.hs
@@ -5051,7 +5051,7 @@ luxiDefCtmo = 10
 luxiDefRwto :: Int
 luxiDefRwto = 60
 
--- | 'WaitForJobChange' timeout
+-- | Luxi 'WaitForJobChange' timeout
 luxiWfjcTimeout :: Int
 luxiWfjcTimeout = (luxiDefRwto - 1) `div` 2
 
@@ -5368,3 +5368,11 @@ dataCollectorsEnabledName = "enabled_data_collectors"
 
 dataCollectorsIntervalName :: String
 dataCollectorsIntervalName = "data_collector_interval"
+
+-- | The polling frequency to wait for a job status change
+cliWfjcFrequency :: Int
+cliWfjcFrequency = 20
+
+-- | Default 'WaitForJobChange' timeout in seconds
+defaultWfjcTimeout :: Int
+defaultWfjcTimeout = 60
diff --git a/test/py/ganeti.cli_unittest.py b/test/py/ganeti.cli_unittest.py
index 1a8674e..a3e28af 100755
--- a/test/py/ganeti.cli_unittest.py
+++ b/test/py/ganeti.cli_unittest.py
@@ -502,7 +502,8 @@ class _MockJobPollCb(cli.JobPollCbBase, cli.JobPollReportCbBase):
     self._jobstatus.append(args)
 
   def WaitForJobChangeOnce(self, job_id, fields,
-                           prev_job_info, prev_log_serial):
+                           prev_job_info, prev_log_serial,
+                           timeout=constants.DEFAULT_WFJC_TIMEOUT):
     self.tc.assertEqual(job_id, self.job_id)
     self.tc.assertEqualValues(fields, ["status"])
     self.tc.assertFalse(self._expect_notchanged)
@@ -531,6 +532,9 @@ class _MockJobPollCb(cli.JobPollCbBase, cli.JobPollReportCbBase):
     self.tc.assertEqual(len(fields), len(result))
     return [result]
 
+  def CancelJob(self, job_id):
+    self.tc.assertEqual(job_id, self.job_id)
+
   def ReportLogMessage(self, job_id, serial, timestamp, log_type, log_msg):
     self.tc.assertEqual(job_id, self.job_id)
     self.tc.assertEqualValues((serial, timestamp, log_type, log_msg),
@@ -646,9 +650,43 @@ class TestGenericPollJob(testutils.GanetiTestCase):
                            [constants.OP_STATUS_CANCELING,
                             constants.OP_STATUS_CANCELING],
                            [None, None])
-    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
+    self.assertRaises(errors.JobCanceled, cli.GenericPollJob, job_id, cbs, cbs)
     cbs.CheckEmpty()
 
+  def testNegativeUpdateFreqParameter(self):
+    job_id = 12345
+
+    cbs = _MockJobPollCb(self, job_id)
+    self.assertRaises(errors.ParameterError, cli.GenericPollJob, job_id, cbs,
+                      cbs, update_freq=-30)
+
+  def testZeroUpdateFreqParameter(self):
+    job_id = 12345
+
+    cbs = _MockJobPollCb(self, job_id)
+    self.assertRaises(errors.ParameterError, cli.GenericPollJob, job_id, cbs,
+                      cbs, update_freq=0)
+
+  def testShouldCancel(self):
+    job_id = 12345
+
+    cbs = _MockJobPollCb(self, job_id)
+    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
+    self.assertRaises(errors.JobCanceled, cli.GenericPollJob, job_id, cbs, cbs,
+                      cancel_fn=(lambda: True))
+
+  def testIgnoreCancel(self):
+    job_id = 12345
+    cbs = _MockJobPollCb(self, job_id)
+    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_SUCCESS, ), None))
+    cbs.AddQueryJobsResult(constants.JOB_STATUS_SUCCESS,
+                           [constants.OP_STATUS_SUCCESS,
+                            constants.OP_STATUS_SUCCESS],
+                           ["Hello World", "Foo man bar"])
+    self.assertEqual(["Hello World", "Foo man bar"],
+                     cli.GenericPollJob(
+                         job_id, cbs, cbs, cancel_fn=(lambda: False)))
+    cbs.CheckEmpty()
 
 class TestFormatLogMessage(unittest.TestCase):
   def test(self):
diff --git a/test/py/ganeti.http_unittest.py b/test/py/ganeti.http_unittest.py
index 518f817..99bceff 100755
--- a/test/py/ganeti.http_unittest.py
+++ b/test/py/ganeti.http_unittest.py
@@ -85,7 +85,8 @@ class TestMisc(unittest.TestCase):
 
   def testHttpServerRequest(self):
     """Test ganeti.http.server._HttpServerRequest"""
-    server_request = http.server._HttpServerRequest("GET", "/", None, None)
+    server_request = \
+        http.server._HttpServerRequest("GET", "/", None, None, None)
 
     # These are expected by users of the HTTP server
     self.assert_(hasattr(server_request, "request_method"))
@@ -210,30 +211,30 @@ class _SimpleAuthenticator:
 
 class TestHttpServerRequestAuthentication(unittest.TestCase):
   def testNoAuth(self):
-    req = http.server._HttpServerRequest("GET", "/", None, None)
+    req = http.server._HttpServerRequest("GET", "/", None, None, None)
     _FakeRequestAuth("area1", False, None).PreHandleRequest(req)
 
   def testNoRealm(self):
     headers = { http.HTTP_AUTHORIZATION: "", }
-    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    req = http.server._HttpServerRequest("GET", "/", headers, None, None)
     ra = _FakeRequestAuth(None, False, None)
     self.assertRaises(AssertionError, ra.PreHandleRequest, req)
 
   def testNoScheme(self):
     headers = { http.HTTP_AUTHORIZATION: "", }
-    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    req = http.server._HttpServerRequest("GET", "/", headers, None, None)
     ra = _FakeRequestAuth("area1", False, None)
     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
 
   def testUnknownScheme(self):
     headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", }
-    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    req = http.server._HttpServerRequest("GET", "/", headers, None, None)
     ra = _FakeRequestAuth("area1", False, None)
     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
 
   def testInvalidBase64(self):
     headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", }
-    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    req = http.server._HttpServerRequest("GET", "/", headers, None, None)
     ra = _FakeRequestAuth("area1", False, None)
     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
 
@@ -241,7 +242,7 @@ class TestHttpServerRequestAuthentication(unittest.TestCase):
     headers = {
       http.HTTP_AUTHORIZATION: "Basic %s" % ("foo".encode("base64").strip(), ),
       }
-    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    req = http.server._HttpServerRequest("GET", "/", headers, None, None)
     ra = _FakeRequestAuth("area1", False, None)
     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
 
@@ -250,12 +251,12 @@ class TestHttpServerRequestAuthentication(unittest.TestCase):
       http.HTTP_AUTHORIZATION:
         "Basic %s" % ("foo:bar".encode("base64").strip(), ),
       }
-    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    req = http.server._HttpServerRequest("GET", "/", headers, None, None)
     ac = _SimpleAuthenticator("foo", "bar")
     ra = _FakeRequestAuth("area1", False, ac)
     ra.PreHandleRequest(req)
 
-    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    req = http.server._HttpServerRequest("GET", "/", headers, None, None)
     ac = _SimpleAuthenticator("something", "else")
     ra = _FakeRequestAuth("area1", False, ac)
     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
@@ -270,7 +271,7 @@ class TestHttpServerRequestAuthentication(unittest.TestCase):
     for exc, headers in checks.items():
       for i in headers:
         headers = { http.HTTP_AUTHORIZATION: i, }
-        req = http.server._HttpServerRequest("GET", "/", headers, None)
+        req = http.server._HttpServerRequest("GET", "/", headers, None, None)
         ra = _FakeRequestAuth("area1", False, None)
         self.assertRaises(exc, ra.PreHandleRequest, req)
 
@@ -286,7 +287,7 @@ class TestHttpServerRequestAuthentication(unittest.TestCase):
               http.HTTP_AUTHORIZATION:
                 "Basic %s" % (basic_auth.encode("base64").strip(), ),
             }
-          req = http.server._HttpServerRequest("GET", "/", headers, None)
+          req = http.server._HttpServerRequest("GET", "/", headers, None, None)
 
           ac = _SimpleAuthenticator(user, pw)
           self.assertFalse(ac.called)
-- 
2.8.0.rc3.226.g39d4020

Reply via email to