Almost LGTM: this now returns immediately if the job completes, and the
distinction between poll and update frequency is clear. Please also override
the abstract CancelJob method in the RapiJobPollCb class (client_utils.py).
On Tuesday, April 26, 2016 at 3:43:18 PM UTC+1, Federico Pareschi wrote:
>
> This commit fixes a bug where long-running OsDiagnose jobs would
> keep executing unnecessarily after a RAPI client had already closed the
> connection. The fix introduces a couple of new polling/timeout
> constants for Luxi and RAPI, which are used to poll the HTTP
> connection while at the same time monitoring the state of the
> submitted job.
>
> If the connection is closed client-side, the job is canceled and
> removed from the job queue, provided it is not already running.
>
> ---
> lib/cli.py | 69
> ++++++++++++++++++++++++++++++++++-------
> lib/errors.py | 9 ++++++
> lib/http/server.py | 10 +++---
> lib/rapi/client_utils.py | 3 +-
> lib/rapi/rlib2.py | 28 ++++++++++++++++-
> lib/rapi/testutils.py | 3 +-
> src/Ganeti/Constants.hs | 10 +++++-
> test/py/ganeti.cli_unittest.py | 42 +++++++++++++++++++++++--
> test/py/ganeti.http_unittest.py | 23 +++++++-------
> 9 files changed, 165 insertions(+), 32 deletions(-)
>
> diff --git a/lib/cli.py b/lib/cli.py
> index 2001ed9..a470ffa 100644
> --- a/lib/cli.py
> +++ b/lib/cli.py
> @@ -705,7 +705,8 @@ def SendJob(ops, cl=None):
> return job_id
>
>
> -def GenericPollJob(job_id, cbs, report_cbs):
> +def GenericPollJob(job_id, cbs, report_cbs, cancel_fn=None,
> + update_freq=constants.DEFAULT_WFJC_TIMEOUT):
> """Generic job-polling function.
>
> @type job_id: number
> @@ -714,9 +715,13 @@ def GenericPollJob(job_id, cbs, report_cbs):
> @param cbs: Data callbacks
> @type report_cbs: Instance of L{JobPollReportCbBase}
> @param report_cbs: Reporting callbacks
> -
> + @type cancel_fn: Function returning a boolean
> + @param cancel_fn: Function to check if we should cancel the running job
> + @type update_freq: int/long
> + @param update_freq: number of seconds between each WFJC report
> @return: the opresult of the job
> @raise errors.JobLost: If job can't be found
> + @raise errors.JobCanceled: If job is canceled
> @raise errors.OpExecError: If job didn't succeed
>
> """
> @@ -724,17 +729,36 @@ def GenericPollJob(job_id, cbs, report_cbs):
> prev_logmsg_serial = None
>
> status = None
> + should_cancel = False
> +
> + if update_freq <= 0:
> + raise errors.ParameterError("Update frequency must be a positive
> number")
>
> while True:
> - result = cbs.WaitForJobChangeOnce(job_id, ["status"], prev_job_info,
> - prev_logmsg_serial)
> + if cancel_fn:
> + timer = 0
> + while timer < update_freq:
> + result = cbs.WaitForJobChangeOnce(job_id, ["status"],
> prev_job_info,
> + prev_logmsg_serial,
> +
> timeout=constants.CLI_WFJC_FREQUENCY)
> + should_cancel = cancel_fn()
> + if should_cancel or not result or result !=
> constants.JOB_NOTCHANGED:
> + break
> + timer += constants.CLI_WFJC_FREQUENCY
> + else:
> + result = cbs.WaitForJobChangeOnce(job_id, ["status"],
> prev_job_info,
> + prev_logmsg_serial,
> timeout=update_freq)
> if not result:
> # job not found, go away!
> raise errors.JobLost("Job with id %s lost" % job_id)
>
> + if should_cancel:
> + logging.info("Job %s canceled because the client timed out.",
> job_id)
> + cbs.CancelJob(job_id)
> + raise errors.JobCanceled("Job was canceled")
> +
> if result == constants.JOB_NOTCHANGED:
> report_cbs.ReportNotChanged(job_id, status)
> -
> # Wait again
> continue
>
> @@ -768,7 +792,7 @@ def GenericPollJob(job_id, cbs, report_cbs):
> return result
>
> if status in (constants.JOB_STATUS_CANCELING,
> constants.JOB_STATUS_CANCELED):
> - raise errors.OpExecError("Job was canceled")
> + raise errors.JobCanceled("Job was canceled")
>
> has_ok = False
> for idx, (status, msg) in enumerate(zip(opstatus, result)):
> @@ -797,7 +821,8 @@ class JobPollCbBase(object):
> """
>
> def WaitForJobChangeOnce(self, job_id, fields,
> - prev_job_info, prev_log_serial):
> + prev_job_info, prev_log_serial,
> + timeout=constants.DEFAULT_WFJC_TIMEOUT):
> """Waits for changes on a job.
>
> """
> @@ -814,6 +839,15 @@ class JobPollCbBase(object):
> """
> raise NotImplementedError()
>
> + def CancelJob(self, job_id):
> + """Cancels a currently running job.
> +
> + @type job_id: number
> + @param job_id: The ID of the Job we want to cancel
> +
> + """
> + raise NotImplementedError()
> +
>
> class JobPollReportCbBase(object):
> """Base class for L{GenericPollJob} reporting callbacks.
> @@ -851,12 +885,14 @@ class _LuxiJobPollCb(JobPollCbBase):
> self.cl = cl
>
> def WaitForJobChangeOnce(self, job_id, fields,
> - prev_job_info, prev_log_serial):
> + prev_job_info, prev_log_serial,
> + timeout=constants.DEFAULT_WFJC_TIMEOUT):
> """Waits for changes on a job.
>
> """
> return self.cl.WaitForJobChangeOnce(job_id, fields,
> - prev_job_info, prev_log_serial)
> + prev_job_info, prev_log_serial,
> + timeout=timeout)
>
> def QueryJobs(self, job_ids, fields):
> """Returns the selected fields for the selected job IDs.
> @@ -864,6 +900,11 @@ class _LuxiJobPollCb(JobPollCbBase):
> """
> return self.cl.QueryJobs(job_ids, fields)
>
> + def CancelJob(self, job_id):
> + """Cancels a currently running job.
> +
> + """
> + return self.cl.CancelJob(job_id)
>
> class FeedbackFnJobPollReportCb(JobPollReportCbBase):
> def __init__(self, feedback_fn):
> @@ -932,7 +973,8 @@ def FormatLogMessage(log_type, log_msg):
> return utils.SafeEncode(log_msg)
>
>
> -def PollJob(job_id, cl=None, feedback_fn=None, reporter=None):
> +def PollJob(job_id, cl=None, feedback_fn=None, reporter=None,
> cancel_fn=None,
> + update_freq=constants.DEFAULT_WFJC_TIMEOUT):
> """Function to poll for the result of a job.
>
> @type job_id: job identified
> @@ -940,6 +982,10 @@ def PollJob(job_id, cl=None, feedback_fn=None,
> reporter=None):
> @type cl: luxi.Client
> @param cl: the luxi client to use for communicating with the master;
> if None, a new client will be created
> + @type cancel_fn: Function returning a boolean
> + @param cancel_fn: Function to check if we should cancel the running job
> + @type update_freq: int/long
> + @param update_freq: number of seconds between each WFJC report
>
> """
> if cl is None:
> @@ -953,7 +999,8 @@ def PollJob(job_id, cl=None, feedback_fn=None,
> reporter=None):
> elif feedback_fn:
> raise errors.ProgrammerError("Can't specify reporter and feedback
> function")
>
> - return GenericPollJob(job_id, _LuxiJobPollCb(cl), reporter)
> + return GenericPollJob(job_id, _LuxiJobPollCb(cl), reporter,
> + cancel_fn=cancel_fn, update_freq=update_freq)
>
>
> def SubmitOpCode(op, cl=None, feedback_fn=None, opts=None,
> reporter=None):
> diff --git a/lib/errors.py b/lib/errors.py
> index b53ced7..e8671a2 100644
> --- a/lib/errors.py
> +++ b/lib/errors.py
> @@ -242,6 +242,15 @@ class JobLost(GenericError):
> """
>
>
> +class JobCanceled(GenericError):
> + """Submitted job was canceled.
> +
> + The job that was submitted has transitioned to a canceling or canceled
> + state.
> +
> + """
> +
> +
> class JobFileCorrupted(GenericError):
> """Job file could not be properly decoded/restored.
>
> diff --git a/lib/http/server.py b/lib/http/server.py
> index 81be4a6..f191c25 100644
> --- a/lib/http/server.py
> +++ b/lib/http/server.py
> @@ -86,12 +86,13 @@ class _HttpServerRequest(object):
> """Data structure for HTTP request on server side.
>
> """
> - def __init__(self, method, path, headers, body):
> + def __init__(self, method, path, headers, body, sock):
> # Request attributes
> self.request_method = method
> self.request_path = path
> self.request_headers = headers
> self.request_body = body
> + self.request_sock = sock
>
> # Response attributes
> self.resp_headers = {}
> @@ -225,14 +226,15 @@ class
> _HttpClientToServerMessageReader(http.HttpMessageReader):
> return http.HttpClientToServerStartLine(method, path, version)
>
>
> -def _HandleServerRequestInner(handler, req_msg):
> +def _HandleServerRequestInner(handler, req_msg, reader):
> """Calls the handler function for the current request.
>
> """
> handler_context = _HttpServerRequest(req_msg.start_line.method,
> req_msg.start_line.path,
> req_msg.headers,
> - req_msg.body)
> + req_msg.body,
> + reader.sock)
>
> logging.debug("Handling request %r", handler_context)
>
> @@ -308,7 +310,7 @@ class HttpResponder(object):
>
> (response_msg.start_line.code, response_msg.headers,
> response_msg.body) = \
> - _HandleServerRequestInner(self._handler, request_msg)
> + _HandleServerRequestInner(self._handler, request_msg,
> req_msg_reader)
> except http.HttpException, err:
> self._SetError(self.responses, self._handler, response_msg, err)
> else:
> diff --git a/lib/rapi/client_utils.py b/lib/rapi/client_utils.py
> index 224e1a2..9b7911d 100644
> --- a/lib/rapi/client_utils.py
> +++ b/lib/rapi/client_utils.py
> @@ -53,7 +53,8 @@ class RapiJobPollCb(cli.JobPollCbBase):
> self.cl = cl
>
> def WaitForJobChangeOnce(self, job_id, fields,
> - prev_job_info, prev_log_serial):
> + prev_job_info, prev_log_serial,
> + timeout=constants.DEFAULT_WFJC_TIMEOUT):
> """Waits for changes on a job.
>
> """
> diff --git a/lib/rapi/rlib2.py b/lib/rapi/rlib2.py
> index 34b4124..8916004 100644
> --- a/lib/rapi/rlib2.py
> +++ b/lib/rapi/rlib2.py
> @@ -64,6 +64,8 @@ PUT should be prefered over POST.
>
> # C0103: Invalid name, since the R_* names are not conforming
>
> +import OpenSSL
> +
> from ganeti import opcodes
> from ganeti import objects
> from ganeti import http
> @@ -198,6 +200,28 @@ def _UpdateBeparams(inst):
>
> return inst
>
> +def _CheckIfConnectionDropped(sock):
> + """Utility function to monitor the state of an open connection.
> +
> + @param sock: Connection's open socket
> + @return: True if the connection was remotely closed, otherwise False
> +
> + """
> + try:
> + result = sock.recv(0)
> + if result == "":
> + return True
> + # The connection is still open
> + except OpenSSL.SSL.WantReadError:
> + return False
> + # The connection has been terminated gracefully
> + except OpenSSL.SSL.ZeroReturnError:
> + return True
> + # The connection was terminated
> + except OpenSSL.SSL.SysCallError:
> + return True
> + return False
> +
>
> class R_root(baserlib.ResourceBase):
> """/ resource.
> @@ -278,9 +302,11 @@ class R_2_os(baserlib.OpcodeResource):
> """
> cl = self.GetClient()
> op = opcodes.OpOsDiagnose(output_fields=["name", "variants"],
> names=[])
> + cancel_fn = (lambda:
> _CheckIfConnectionDropped(self._req.request_sock))
> job_id = self.SubmitJob([op], cl=cl)
> # we use custom feedback function, instead of print we log the status
> - result = cli.PollJob(job_id, cl, feedback_fn=baserlib.FeedbackFn)
> + result = cli.PollJob(job_id, cl, feedback_fn=baserlib.FeedbackFn,
> + cancel_fn=cancel_fn)
> diagnose_data = result[0]
>
> if not isinstance(diagnose_data, list):
> diff --git a/lib/rapi/testutils.py b/lib/rapi/testutils.py
> index 8f7a7ad..4d054cd 100644
> --- a/lib/rapi/testutils.py
> +++ b/lib/rapi/testutils.py
> @@ -252,9 +252,10 @@ class _RapiMock(object):
> http.HttpClientToServerStartLine(method, path, http.HTTP_1_0)
> req_msg.headers = headers
> req_msg.body = request_body
> + req_reader = type('TestReader', (object, ), {'sock': None})()
>
> (_, _, _, resp_msg) = \
> - http.server.HttpResponder(self.handler)(lambda: (req_msg, None))
> + http.server.HttpResponder(self.handler)(lambda: (req_msg,
> req_reader))
>
> return (resp_msg.start_line.code, resp_msg.headers, resp_msg.body)
>
> diff --git a/src/Ganeti/Constants.hs b/src/Ganeti/Constants.hs
> index 5ef9dac..09783d4 100644
> --- a/src/Ganeti/Constants.hs
> +++ b/src/Ganeti/Constants.hs
> @@ -5051,7 +5051,7 @@ luxiDefCtmo = 10
> luxiDefRwto :: Int
> luxiDefRwto = 60
>
> --- | 'WaitForJobChange' timeout
> +-- | Luxi 'WaitForJobChange' timeout
> luxiWfjcTimeout :: Int
> luxiWfjcTimeout = (luxiDefRwto - 1) `div` 2
>
> @@ -5368,3 +5368,11 @@ dataCollectorsEnabledName =
> "enabled_data_collectors"
>
> dataCollectorsIntervalName :: String
> dataCollectorsIntervalName = "data_collector_interval"
> +
> +-- | The polling frequency to wait for a job status change
> +cliWfjcFrequency :: Int
> +cliWfjcFrequency = 20
> +
> +-- | Default 'WaitForJobChange' timeout in seconds
> +defaultWfjcTimeout :: Int
> +defaultWfjcTimeout = 60
> diff --git a/test/py/ganeti.cli_unittest.py b/test/py/
> ganeti.cli_unittest.py
> index 1a8674e..a3e28af 100755
> --- a/test/py/ganeti.cli_unittest.py
> +++ b/test/py/ganeti.cli_unittest.py
> @@ -502,7 +502,8 @@ class _MockJobPollCb(cli.JobPollCbBase,
> cli.JobPollReportCbBase):
> self._jobstatus.append(args)
>
> def WaitForJobChangeOnce(self, job_id, fields,
> - prev_job_info, prev_log_serial):
> + prev_job_info, prev_log_serial,
> + timeout=constants.DEFAULT_WFJC_TIMEOUT):
> self.tc.assertEqual(job_id, self.job_id)
> self.tc.assertEqualValues(fields, ["status"])
> self.tc.assertFalse(self._expect_notchanged)
> @@ -531,6 +532,9 @@ class _MockJobPollCb(cli.JobPollCbBase,
> cli.JobPollReportCbBase):
> self.tc.assertEqual(len(fields), len(result))
> return [result]
>
> + def CancelJob(self, job_id):
> + self.tc.assertEqual(job_id, self.job_id)
> +
> def ReportLogMessage(self, job_id, serial, timestamp, log_type,
> log_msg):
> self.tc.assertEqual(job_id, self.job_id)
> self.tc.assertEqualValues((serial, timestamp, log_type, log_msg),
> @@ -646,9 +650,43 @@ class TestGenericPollJob(testutils.GanetiTestCase):
> [constants.OP_STATUS_CANCELING,
> constants.OP_STATUS_CANCELING],
> [None, None])
> - self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id,
> cbs, cbs)
> + self.assertRaises(errors.JobCanceled, cli.GenericPollJob, job_id,
> cbs, cbs)
> cbs.CheckEmpty()
>
> + def testNegativeUpdateFreqParameter(self):
> + job_id = 12345
> +
> + cbs = _MockJobPollCb(self, job_id)
> + self.assertRaises(errors.ParameterError, cli.GenericPollJob, job_id,
> cbs,
> + cbs, update_freq=-30)
> +
> + def testZeroUpdateFreqParameter(self):
> + job_id = 12345
> +
> + cbs = _MockJobPollCb(self, job_id)
> + self.assertRaises(errors.ParameterError, cli.GenericPollJob, job_id,
> cbs,
> + cbs, update_freq=0)
> +
> + def testShouldCancel(self):
> + job_id = 12345
> +
> + cbs = _MockJobPollCb(self, job_id)
> + cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
> + self.assertRaises(errors.JobCanceled, cli.GenericPollJob, job_id,
> cbs, cbs,
> + cancel_fn=(lambda: True))
> +
> + def testIgnoreCancel(self):
> + job_id = 12345
> + cbs = _MockJobPollCb(self, job_id)
> + cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_SUCCESS, ),
> None))
> + cbs.AddQueryJobsResult(constants.JOB_STATUS_SUCCESS,
> + [constants.OP_STATUS_SUCCESS,
> + constants.OP_STATUS_SUCCESS],
> + ["Hello World", "Foo man bar"])
> + self.assertEqual(["Hello World", "Foo man bar"],
> + cli.GenericPollJob(
> + job_id, cbs, cbs, cancel_fn=(lambda: False)))
> + cbs.CheckEmpty()
>
> class TestFormatLogMessage(unittest.TestCase):
> def test(self):
> diff --git a/test/py/ganeti.http_unittest.py b/test/py/
> ganeti.http_unittest.py
> index 518f817..99bceff 100755
> --- a/test/py/ganeti.http_unittest.py
> +++ b/test/py/ganeti.http_unittest.py
> @@ -85,7 +85,8 @@ class TestMisc(unittest.TestCase):
>
> def testHttpServerRequest(self):
> """Test ganeti.http.server._HttpServerRequest"""
> - server_request = http.server._HttpServerRequest("GET", "/", None,
> None)
> + server_request = \
> + http.server._HttpServerRequest("GET", "/", None, None, None)
>
> # These are expected by users of the HTTP server
> self.assert_(hasattr(server_request, "request_method"))
> @@ -210,30 +211,30 @@ class _SimpleAuthenticator:
>
> class TestHttpServerRequestAuthentication(unittest.TestCase):
> def testNoAuth(self):
> - req = http.server._HttpServerRequest("GET", "/", None, None)
> + req = http.server._HttpServerRequest("GET", "/", None, None, None)
> _FakeRequestAuth("area1", False, None).PreHandleRequest(req)
>
> def testNoRealm(self):
> headers = { http.HTTP_AUTHORIZATION: "", }
> - req = http.server._HttpServerRequest("GET", "/", headers, None)
> + req = http.server._HttpServerRequest("GET", "/", headers, None, None)
> ra = _FakeRequestAuth(None, False, None)
> self.assertRaises(AssertionError, ra.PreHandleRequest, req)
>
> def testNoScheme(self):
> headers = { http.HTTP_AUTHORIZATION: "", }
> - req = http.server._HttpServerRequest("GET", "/", headers, None)
> + req = http.server._HttpServerRequest("GET", "/", headers, None, None)
> ra = _FakeRequestAuth("area1", False, None)
> self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
>
> def testUnknownScheme(self):
> headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", }
> - req = http.server._HttpServerRequest("GET", "/", headers, None)
> + req = http.server._HttpServerRequest("GET", "/", headers, None, None)
> ra = _FakeRequestAuth("area1", False, None)
> self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
>
> def testInvalidBase64(self):
> headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", }
> - req = http.server._HttpServerRequest("GET", "/", headers, None)
> + req = http.server._HttpServerRequest("GET", "/", headers, None, None)
> ra = _FakeRequestAuth("area1", False, None)
> self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
>
> @@ -241,7 +242,7 @@ class
> TestHttpServerRequestAuthentication(unittest.TestCase):
> headers = {
> http.HTTP_AUTHORIZATION: "Basic %s" %
> ("foo".encode("base64").strip(), ),
> }
> - req = http.server._HttpServerRequest("GET", "/", headers, None)
> + req = http.server._HttpServerRequest("GET", "/", headers, None, None)
> ra = _FakeRequestAuth("area1", False, None)
> self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
>
> @@ -250,12 +251,12 @@ class
> TestHttpServerRequestAuthentication(unittest.TestCase):
> http.HTTP_AUTHORIZATION:
> "Basic %s" % ("foo:bar".encode("base64").strip(), ),
> }
> - req = http.server._HttpServerRequest("GET", "/", headers, None)
> + req = http.server._HttpServerRequest("GET", "/", headers, None, None)
> ac = _SimpleAuthenticator("foo", "bar")
> ra = _FakeRequestAuth("area1", False, ac)
> ra.PreHandleRequest(req)
>
> - req = http.server._HttpServerRequest("GET", "/", headers, None)
> + req = http.server._HttpServerRequest("GET", "/", headers, None, None)
> ac = _SimpleAuthenticator("something", "else")
> ra = _FakeRequestAuth("area1", False, ac)
> self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
> @@ -270,7 +271,7 @@ class
> TestHttpServerRequestAuthentication(unittest.TestCase):
> for exc, headers in checks.items():
> for i in headers:
> headers = { http.HTTP_AUTHORIZATION: i, }
> - req = http.server._HttpServerRequest("GET", "/", headers, None)
> + req = http.server._HttpServerRequest("GET", "/", headers, None,
> None)
> ra = _FakeRequestAuth("area1", False, None)
> self.assertRaises(exc, ra.PreHandleRequest, req)
>
> @@ -286,7 +287,7 @@ class
> TestHttpServerRequestAuthentication(unittest.TestCase):
> http.HTTP_AUTHORIZATION:
> "Basic %s" % (basic_auth.encode("base64").strip(), ),
> }
> - req = http.server._HttpServerRequest("GET", "/", headers, None)
> + req = http.server._HttpServerRequest("GET", "/", headers, None,
> None)
>
> ac = _SimpleAuthenticator(user, pw)
> self.assertFalse(ac.called)
> --
> 2.8.0.rc3.226.g39d4020
>
>