Repository: incubator-ariatosca
Updated Branches:
  refs/heads/ARIA-18-migrate-tosca-parser 557f8ec46 -> b2c8379c6 (forced update)


ARIA-11 Add cancel workflow execution support

This commit also adds an API to supply custom field validation
functions in the storage mechanism.


Project: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/commit/48076260
Tree: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/tree/48076260
Diff: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/diff/48076260

Branch: refs/heads/ARIA-18-migrate-tosca-parser
Commit: 48076260d088dd0270c8ed6f6121f3b2e3a2b6e5
Parents: 9a44178
Author: Dan Kilman <d...@gigaspaces.com>
Authored: Wed Nov 9 16:37:43 2016 +0200
Committer: Dan Kilman <d...@gigaspaces.com>
Committed: Sun Nov 13 18:33:45 2016 +0200

----------------------------------------------------------------------
 aria/context/workflow.py                     | 18 ++++++
 aria/events/__init__.py                      |  2 +
 aria/events/builtin_event_handler.py         | 36 ++++++++---
 aria/events/workflow_engine_event_handler.py | 14 ++++-
 aria/storage/models.py                       | 30 +++++++--
 aria/storage/structures.py                   | 11 +++-
 aria/tools/validation.py                     | 11 ----
 aria/workflows/core/engine.py                | 21 ++++++-
 tests/context/__init__.py                    | 14 +++++
 tests/context/test_workflow.py               | 62 +++++++++++++++++++
 tests/mock/context.py                        |  1 +
 tests/storage/test_field.py                  | 24 ++++++--
 tests/storage/test_models.py                 | 68 +++++++++++++++++++-
 tests/workflows/core/test_engine.py          | 75 +++++++++++++++++++++--
 14 files changed, 346 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/aria/context/workflow.py
----------------------------------------------------------------------
diff --git a/aria/context/workflow.py b/aria/context/workflow.py
index 44db38d..329757a 100644
--- a/aria/context/workflow.py
+++ b/aria/context/workflow.py
@@ -61,6 +61,12 @@ class WorkflowContext(logger.LoggerMixin):
         self.parameters = parameters or {}
         self.task_max_retries = task_max_retries
         self.task_retry_interval = task_retry_interval
+        # TODO: execution creation should happen somewhere else
+        # should be moved there, when such logical place exists
+        try:
+            self.model.execution.get(self.execution_id)
+        except exceptions.StorageError:
+            self._create_execution()
 
     def __repr__(self):
         return (
@@ -69,6 +75,18 @@ class WorkflowContext(logger.LoggerMixin):
             'execution_id={self.execution_id})'.format(
                 name=self.__class__.__name__, self=self))
 
+    def _create_execution(self):
+        execution_cls = self.model.execution.model_cls
+        execution = self.model.execution.model_cls(
+            id=self.execution_id,
+            deployment_id=self.deployment_id,
+            workflow_id=self.workflow_id,
+            blueprint_id=self.blueprint_id,
+            status=execution_cls.PENDING,
+            parameters=self.parameters,
+        )
+        self.model.execution.store(execution)
+
     @property
     def blueprint_id(self):
         """

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/aria/events/__init__.py
----------------------------------------------------------------------
diff --git a/aria/events/__init__.py b/aria/events/__init__.py
index 74f3e22..2e88733 100644
--- a/aria/events/__init__.py
+++ b/aria/events/__init__.py
@@ -46,6 +46,8 @@ on_failure_task_signal = signal('failure_task_signal')
 
 # workflow engine workflow signals:
 start_workflow_signal = signal('start_workflow_signal')
+on_cancelling_workflow_signal = signal('on_cancelling_workflow_signal')
+on_cancelled_workflow_signal = signal('on_cancelled_workflow_signal')
 on_success_workflow_signal = signal('on_success_workflow_signal')
 on_failure_workflow_signal = signal('on_failure_workflow_signal')
 

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/aria/events/builtin_event_handler.py
----------------------------------------------------------------------
diff --git a/aria/events/builtin_event_handler.py b/aria/events/builtin_event_handler.py
index b54024b..ac65978 100644
--- a/aria/events/builtin_event_handler.py
+++ b/aria/events/builtin_event_handler.py
@@ -30,6 +30,8 @@ from . import (
     start_workflow_signal,
     on_success_workflow_signal,
     on_failure_workflow_signal,
+    on_cancelled_workflow_signal,
+    on_cancelling_workflow_signal,
     sent_task_signal,
     start_task_signal,
     on_success_task_signal,
@@ -71,16 +73,9 @@ def _task_succeeded(task, *args, **kwargs):
 
 @start_workflow_signal.connect
 def _workflow_started(workflow_context, *args, **kwargs):
-    execution_cls = workflow_context.model.execution.model_cls
-    execution = execution_cls(
-        id=workflow_context.execution_id,
-        deployment_id=workflow_context.deployment_id,
-        workflow_id=workflow_context.workflow_id,
-        blueprint_id=workflow_context.blueprint_id,
-        status=execution_cls.PENDING,
-        started_at=datetime.utcnow(),
-        parameters=workflow_context.parameters,
-    )
+    execution = workflow_context.execution
+    execution.status = execution.STARTED
+    execution.started_at = datetime.utcnow()
     workflow_context.execution = execution
 
 
@@ -99,3 +94,24 @@ def _workflow_succeeded(workflow_context, *args, **kwargs):
     execution.status = execution.TERMINATED
     execution.ended_at = datetime.utcnow()
     workflow_context.execution = execution
+
+
+@on_cancelled_workflow_signal.connect
+def _workflow_cancelled(workflow_context, *args, **kwargs):
+    execution = workflow_context.execution
+    # _workflow_cancelling function may have called this function
+    # already
+    if execution.status == execution.CANCELLED:
+        return
+    execution.status = execution.CANCELLED
+    execution.ended_at = datetime.utcnow()
+    workflow_context.execution = execution
+
+
+@on_cancelling_workflow_signal.connect
+def _workflow_cancelling(workflow_context, *args, **kwargs):
+    execution = workflow_context.execution
+    if execution.status == execution.PENDING:
+        return _workflow_cancelled(workflow_context=workflow_context)
+    execution.status = execution.CANCELLING
+    workflow_context.execution = execution

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/aria/events/workflow_engine_event_handler.py
----------------------------------------------------------------------
diff --git a/aria/events/workflow_engine_event_handler.py b/aria/events/workflow_engine_event_handler.py
index 2f74ded..7df11d1 100644
--- a/aria/events/workflow_engine_event_handler.py
+++ b/aria/events/workflow_engine_event_handler.py
@@ -27,7 +27,9 @@ from . import (
     on_failure_task_signal,
     start_workflow_signal,
     on_success_workflow_signal,
-    on_failure_workflow_signal
+    on_failure_workflow_signal,
+    on_cancelled_workflow_signal,
+    on_cancelling_workflow_signal,
 )
 
 
@@ -60,3 +62,13 @@ def _failure_workflow_handler(context, **kwargs):
 @on_success_workflow_signal.connect
 def _success_workflow_handler(context, **kwargs):
     context.logger.debug('Event: Workflow success: {context.name}'.format(context=context))
+
+
+@on_cancelled_workflow_signal.connect
+def _cancel_workflow_handler(context, **kwargs):
+    context.logger.debug('Event: Workflow cancelled: {context.name}'.format(context=context))
+
+
+@on_cancelling_workflow_signal.connect
+def _cancelling_workflow_handler(context, **kwargs):
+    context.logger.debug('Event: Workflow cancelling: {context.name}'.format(context=context))

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/aria/storage/models.py
----------------------------------------------------------------------
diff --git a/aria/storage/models.py b/aria/storage/models.py
index a3686de..c0f8c28 100644
--- a/aria/storage/models.py
+++ b/aria/storage/models.py
@@ -185,13 +185,28 @@ class Execution(Model):
     """
     A Model which represents an execution
     """
+
+    class _Validation(object):
+
+        @staticmethod
+        def execution_status_transition_validation(_, value, instance):
+            """Validation function that verifies execution status transitions are OK"""
+            try:
+                current_status = instance.status
+            except AttributeError:
+                return
+            valid_transitions = Execution.VALID_TRANSITIONS.get(current_status, [])
+            if current_status != value and value not in valid_transitions:
+                raise ValueError('Cannot change execution status from {current} to {new}'.format(
+                    current=current_status,
+                    new=value))
+
     TERMINATED = 'terminated'
     FAILED = 'failed'
     CANCELLED = 'cancelled'
     PENDING = 'pending'
     STARTED = 'started'
     CANCELLING = 'cancelling'
-    FORCE_CANCELLING = 'force_cancelling'
     STATES = (
         TERMINATED,
         FAILED,
@@ -199,21 +214,26 @@ class Execution(Model):
         PENDING,
         STARTED,
         CANCELLING,
-        FORCE_CANCELLING,
     )
     END_STATES = [TERMINATED, FAILED, CANCELLED]
     ACTIVE_STATES = [state for state in STATES if state not in END_STATES]
+    VALID_TRANSITIONS = {
+        PENDING: [STARTED, CANCELLED],
+        STARTED: END_STATES + [CANCELLING],
+        CANCELLING: END_STATES
+    }
 
     id = Field(type=basestring, default=uuid_generator)
-    status = Field(type=basestring, choices=STATES)
+    status = Field(type=basestring, choices=STATES,
+                   validation_func=_Validation.execution_status_transition_validation)
     deployment_id = Field(type=basestring)
     workflow_id = Field(type=basestring)
     blueprint_id = Field(type=basestring)
-    started_at = Field(type=datetime)
+    created_at = Field(type=datetime, default=datetime.utcnow)
+    started_at = Field(type=datetime, default=None)
     ended_at = Field(type=datetime, default=None)
     error = Field(type=basestring, default=None)
     parameters = Field()
-    is_system_workflow = Field(type=bool, default=False)
 
 
 class Relationship(Model):

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/aria/storage/structures.py
----------------------------------------------------------------------
diff --git a/aria/storage/structures.py b/aria/storage/structures.py
index c692d36..a26e7eb 100644
--- a/aria/storage/structures.py
+++ b/aria/storage/structures.py
@@ -72,6 +72,7 @@ class Field(ValidatorMixin):
             self,
             type=None,
             choices=(),
+            validation_func=None,
             default=NO_DEFAULT,
             **kwargs):
         """
@@ -85,6 +86,7 @@ class Field(ValidatorMixin):
         self.type = type
         self.choices = choices
         self.default = default
+        self.validation_func = validation_func
         super(Field, self).__init__(**kwargs)
 
     def __get__(self, instance, owner):
@@ -104,15 +106,16 @@ class Field(ValidatorMixin):
 
     def __set__(self, instance, value):
         field_name = self._field_name(instance)
-        self.validate_value(field_name, value)
+        self.validate_value(field_name, value, instance)
         setattr(instance, self._ATTRIBUTE_NAME(field_name), value)
 
-    def validate_value(self, name, value):
+    def validate_value(self, name, value, instance):
         """
         Validates the value of the field.
 
         :param name: the name of the field.
         :param value: the value of the field.
+        :param instance: the instance containing the field.
         """
         if self.default != self.NO_DEFAULT and value == self.default:
             return
@@ -120,6 +123,8 @@ class Field(ValidatorMixin):
             self.validate_instance(name, value, self.type)
         if self.choices:
             self.validate_in_choice(name, value, self.choices)
+        if self.validation_func:
+            self.validation_func(name, value, instance)
 
     def _field_name(self, instance):
         """
@@ -147,7 +152,7 @@ class IterField(Field):
         """
         super(IterField, self).__init__(choices=(), **kwargs)
 
-    def validate_value(self, name, values):
+    def validate_value(self, name, values, *args):
         """
         Validates the value of each iterable value.
 

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/aria/tools/validation.py
----------------------------------------------------------------------
diff --git a/aria/tools/validation.py b/aria/tools/validation.py
index ea1dae7..a33f7a2 100644
--- a/aria/tools/validation.py
+++ b/aria/tools/validation.py
@@ -24,20 +24,9 @@ class ValidatorMixin(object):
     """
 
     _ARGUMENT_TYPE_MESSAGE = '{name} argument must be {type} based, got {arg!r}'
-    _ACTION_MESSAGE = 'action arg options: {actions}, got {action}'
     _ARGUMENT_CHOICE_MESSAGE = '{name} argument must be in {choices}, got {arg!r}'
 
     @classmethod
-    def validate_actions(cls, action):
-        """
-        Validate action is defined in the class ``ACTIONS`` attribute
-        """
-        # todo: remove this and use validate choice
-        if action not in cls.ACTIONS:
-            raise TypeError(cls._ACTION_MESSAGE.format(
-                actions=cls.ACTIONS, action=action))
-
-    @classmethod
     def validate_in_choice(cls, name, argument, choices):
         """
         Validate ``argument`` is in ``choices``

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/aria/workflows/core/engine.py
----------------------------------------------------------------------
diff --git a/aria/workflows/core/engine.py b/aria/workflows/core/engine.py
index 3ed137c..b32ef11 100644
--- a/aria/workflows/core/engine.py
+++ b/aria/workflows/core/engine.py
@@ -48,7 +48,11 @@ class Engine(logger.LoggerMixin):
         """
         try:
             events.start_workflow_signal.send(self._workflow_context)
+            cancel = False
             while True:
+                cancel = self._is_cancel()
+                if cancel:
+                    break
                 for task in self._ended_tasks():
                     self._handle_ended_tasks(task)
                 for task in self._executable_tasks():
@@ -57,11 +61,26 @@ class Engine(logger.LoggerMixin):
                     break
                 else:
                     time.sleep(0.1)
-            events.on_success_workflow_signal.send(self._workflow_context)
+            if cancel:
+                events.on_cancelled_workflow_signal.send(self._workflow_context)
+            else:
+                events.on_success_workflow_signal.send(self._workflow_context)
         except BaseException as e:
             events.on_failure_workflow_signal.send(self._workflow_context, exception=e)
             raise
 
+    def cancel_execution(self):
+        """
+        Send a cancel request to the engine. If execution already started, execution status
+        will be modified to 'cancelling' status. If execution is in pending mode, execution status
+        will be modified to 'cancelled' directly.
+        """
+        events.on_cancelling_workflow_signal.send(self._workflow_context)
+
+    def _is_cancel(self):
+        return self._workflow_context.execution.status in [models.Execution.CANCELLING,
+                                                           models.Execution.CANCELLED]
+
     def _executable_tasks(self):
         now = datetime.utcnow()
         return (task for task in self._tasks_iter()

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/tests/context/__init__.py
----------------------------------------------------------------------
diff --git a/tests/context/__init__.py b/tests/context/__init__.py
new file mode 100644
index 0000000..ae1e83e
--- /dev/null
+++ b/tests/context/__init__.py
@@ -0,0 +1,14 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/tests/context/test_workflow.py
----------------------------------------------------------------------
diff --git a/tests/context/test_workflow.py b/tests/context/test_workflow.py
new file mode 100644
index 0000000..b40a66d
--- /dev/null
+++ b/tests/context/test_workflow.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import datetime
+
+import pytest
+
+from aria import context, application_model_storage
+
+from ..mock import models
+from ..storage import InMemoryModelDriver
+
+
+class TestWorkflowContext(object):
+
+    def test_execution_creation_on_workflow_context_creation(self, storage):
+        self._create_ctx(storage)
+        execution = storage.execution.get(models.EXECUTION_ID)
+        assert execution.id == models.EXECUTION_ID
+        assert execution.deployment_id == models.DEPLOYMENT_ID
+        assert execution.workflow_id == models.WORKFLOW_ID
+        assert execution.blueprint_id == models.BLUEPRINT_ID
+        assert execution.status == storage.execution.model_cls.PENDING
+        assert execution.parameters == {}
+        assert execution.created_at <= datetime.utcnow()
+
+    def test_subsequent_workflow_context_creation_do_not_fail(self, storage):
+        self._create_ctx(storage)
+        self._create_ctx(storage)
+
+    @staticmethod
+    def _create_ctx(storage):
+        return context.workflow.WorkflowContext(
+            name='simple_context',
+            model_storage=storage,
+            resource_storage=None,
+            deployment_id=models.DEPLOYMENT_ID,
+            workflow_id=models.WORKFLOW_ID,
+            execution_id=models.EXECUTION_ID,
+            task_max_retries=models.TASK_MAX_RETRIES,
+            task_retry_interval=models.TASK_RETRY_INTERVAL
+        )
+
+
+@pytest.fixture(scope='function')
+def storage():
+    result = application_model_storage(InMemoryModelDriver())
+    result.setup()
+    result.deployment.store(models.get_deployment())
+    return result

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/tests/mock/context.py
----------------------------------------------------------------------
diff --git a/tests/mock/context.py b/tests/mock/context.py
index 13020f3..bef2437 100644
--- a/tests/mock/context.py
+++ b/tests/mock/context.py
@@ -22,6 +22,7 @@ from ..storage import InMemoryModelDriver
 def simple():
     storage = application_model_storage(InMemoryModelDriver())
     storage.setup()
+    storage.deployment.store(models.get_deployment())
     return context.workflow.WorkflowContext(
         name='simple_context',
         model_storage=storage,

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/tests/storage/test_field.py
----------------------------------------------------------------------
diff --git a/tests/storage/test_field.py b/tests/storage/test_field.py
index fcabe80..cab218f 100644
--- a/tests/storage/test_field.py
+++ b/tests/storage/test_field.py
@@ -90,9 +90,9 @@ def test_pointer_field():
         vars(PointerField(type=Model, choices=(), default=Field.NO_DEFAULT))
     with pytest.raises(AssertionError):
         PointerField(type=list)
-    pointer_field.validate_value('pointer_field', test_model)
+    pointer_field.validate_value('pointer_field', test_model, None)
     with pytest.raises(TypeError):
-        pointer_field.validate_value('pointer_field', int)
+        pointer_field.validate_value('pointer_field', int, None)
 
 
 def test_iterable_pointer_field():
@@ -103,6 +103,22 @@ def test_iterable_pointer_field():
     with pytest.raises(AssertionError):
         IterPointerField(type=list)
 
-    iter_pointer_field.validate_value('iter_pointer_field', [test_model, test_model])
+    iter_pointer_field.validate_value('iter_pointer_field', [test_model, test_model], None)
     with pytest.raises(TypeError):
-        iter_pointer_field.validate_value('iter_pointer_field', [int, test_model])
+        iter_pointer_field.validate_value('iter_pointer_field', [int, test_model], None)
+
+
+def test_custom_field_validation():
+    def validation_func(name, value, instance):
+        assert name == 'id'
+        assert value == 'value'
+        assert isinstance(instance, TestModel)
+
+    class TestModel(Model):
+        id = Field(default='_', validation_func=validation_func)
+
+    obj = TestModel()
+    obj.id = 'value'
+
+    with pytest.raises(AssertionError):
+        obj.id = 'not_value'

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/tests/storage/test_models.py
----------------------------------------------------------------------
diff --git a/tests/storage/test_models.py b/tests/storage/test_models.py
index 19262bb..5c9abe3 100644
--- a/tests/storage/test_models.py
+++ b/tests/storage/test_models.py
@@ -26,7 +26,8 @@ from aria.storage.models import (
     RelationshipInstance,
     Node,
     NodeInstance,
-    Blueprint)
+    Blueprint,
+    Execution)
 
 # TODO: add tests per model
 
@@ -273,3 +274,68 @@ def test_relationship_instance():
     assert set(relationship_instances) == set(chain(
         node_instance.relationships_by_target('target_0'),
         node_instance.relationships_by_target('target_1')))
+
+
+def test_execution_status_transition():
+    def create_execution(status):
+        return Execution(
+            id='e_id',
+            deployment_id='d_id',
+            workflow_id='w_id',
+            blueprint_id='b_id',
+            status=status,
+            parameters={}
+        )
+
+    valid_transitions = {
+        Execution.PENDING: [Execution.STARTED,
+                            Execution.CANCELLED,
+                            Execution.PENDING],
+        Execution.STARTED: [Execution.FAILED,
+                            Execution.TERMINATED,
+                            Execution.CANCELLED,
+                            Execution.CANCELLING,
+                            Execution.STARTED],
+        Execution.CANCELLING: [Execution.FAILED,
+                               Execution.TERMINATED,
+                               Execution.CANCELLED,
+                               Execution.CANCELLING],
+        Execution.FAILED: [Execution.FAILED],
+        Execution.TERMINATED: [Execution.TERMINATED],
+        Execution.CANCELLED: [Execution.CANCELLED]
+    }
+
+    invalid_transitions = {
+        Execution.PENDING: [Execution.FAILED,
+                            Execution.TERMINATED,
+                            Execution.CANCELLING],
+        Execution.STARTED: [Execution.PENDING],
+        Execution.CANCELLING: [Execution.PENDING,
+                               Execution.STARTED],
+        Execution.FAILED: [Execution.PENDING,
+                           Execution.STARTED,
+                           Execution.TERMINATED,
+                           Execution.CANCELLED,
+                           Execution.CANCELLING],
+        Execution.TERMINATED: [Execution.PENDING,
+                               Execution.STARTED,
+                               Execution.FAILED,
+                               Execution.CANCELLED,
+                               Execution.CANCELLING],
+        Execution.CANCELLED: [Execution.PENDING,
+                              Execution.STARTED,
+                              Execution.FAILED,
+                              Execution.TERMINATED,
+                              Execution.CANCELLING],
+    }
+
+    for current_status, valid_transitioned_statues in valid_transitions.items():
+        for transitioned_status in valid_transitioned_statues:
+            execution = create_execution(current_status)
+            execution.status = transitioned_status
+
+    for current_status, invalid_transitioned_statues in invalid_transitions.items():
+        for transitioned_status in invalid_transitioned_statues:
+            execution = create_execution(current_status)
+            with pytest.raises(ValueError):
+                execution.status = transitioned_status

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/48076260/tests/workflows/core/test_engine.py
----------------------------------------------------------------------
diff --git a/tests/workflows/core/test_engine.py b/tests/workflows/core/test_engine.py
index d0d41f3..984d5da 100644
--- a/tests/workflows/core/test_engine.py
+++ b/tests/workflows/core/test_engine.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import time
+import threading
 from datetime import datetime
 
 import pytest
@@ -37,11 +38,20 @@ global_test_holder = {}
 
 class BaseTest(object):
 
+    @classmethod
+    def _execute(cls, workflow_func, workflow_context, executor):
+        eng = cls._engine(workflow_func=workflow_func,
+                          workflow_context=workflow_context,
+                          executor=executor)
+        eng.execute()
+        return eng
+
     @staticmethod
-    def _execute(workflow_func, workflow_context, executor):
+    def _engine(workflow_func, workflow_context, executor):
         graph = workflow_func(ctx=workflow_context)
-        eng = engine.Engine(executor=executor, workflow_context=workflow_context, tasks_graph=graph)
-        eng.execute()
+        return engine.Engine(executor=executor,
+                             workflow_context=workflow_context,
+                             tasks_graph=graph)
 
     @staticmethod
     def _op(func, ctx, inputs=None, max_retries=None, retry_interval=None):
@@ -78,9 +88,13 @@ class BaseTest(object):
             workflow_context.states.append('failure')
             workflow_context.exception = exception
 
+        def cancel_workflow_handler(workflow_context, *args, **kwargs):
+            workflow_context.states.append('cancel')
+
         events.start_workflow_signal.connect(start_workflow_handler)
         events.on_success_workflow_signal.connect(success_workflow_handler)
         events.on_failure_workflow_signal.connect(failure_workflow_handler)
+        events.on_cancelled_workflow_signal.connect(cancel_workflow_handler)
         events.sent_task_signal.connect(sent_task_handler)
         try:
             yield
@@ -88,6 +102,7 @@ class BaseTest(object):
             events.start_workflow_signal.disconnect(start_workflow_handler)
             events.on_success_workflow_signal.disconnect(success_workflow_handler)
             events.on_failure_workflow_signal.disconnect(failure_workflow_handler)
+            events.on_cancelled_workflow_signal.disconnect(cancel_workflow_handler)
             events.sent_task_signal.disconnect(sent_task_handler)
 
     @pytest.fixture(scope='function')
@@ -137,6 +152,10 @@ class TestEngine(BaseTest):
         assert workflow_context.states == ['start', 'success']
         assert workflow_context.exception is None
         assert 'sent_task_signal_calls' not in global_test_holder
+        execution = workflow_context.execution
+        assert execution.started_at <= execution.ended_at <= datetime.utcnow()
+        assert execution.error is None
+        assert execution.status == models.Execution.TERMINATED
 
     def test_single_task_successful_execution(self, workflow_context, executor):
         @workflow
@@ -162,6 +181,10 @@ class TestEngine(BaseTest):
         assert workflow_context.states == ['start', 'failure']
         assert isinstance(workflow_context.exception, exceptions.ExecutorException)
         assert global_test_holder.get('sent_task_signal_calls') == 1
+        execution = workflow_context.execution
+        assert execution.started_at <= execution.ended_at <= datetime.utcnow()
+        assert execution.error is not None
+        assert execution.status == models.Execution.FAILED
 
     def test_two_tasks_execution_order(self, workflow_context, executor):
         @workflow
@@ -189,17 +212,53 @@ class TestEngine(BaseTest):
         @workflow
         def mock_workflow(ctx, graph):
             graph.add_tasks(api.task.WorkflowTask(sub_workflow, ctx=ctx))
-
         self._execute(workflow_func=mock_workflow,
                       workflow_context=workflow_context,
                       executor=executor)
-
         assert workflow_context.states == ['start', 'success']
         assert workflow_context.exception is None
         assert global_test_holder.get('invocations') == [1, 2]
         assert global_test_holder.get('sent_task_signal_calls') == 2
 
 
+class TestCancel(BaseTest):
+
+    def test_cancel_started_execution(self, workflow_context, executor):
+        number_of_tasks = 100
+
+        @workflow
+        def mock_workflow(ctx, graph):
+            return graph.sequence(*(self._op(mock_sleep_task, ctx, inputs={'seconds': 0.1})
+                                    for _ in range(number_of_tasks)))
+        eng = self._engine(workflow_func=mock_workflow,
+                           workflow_context=workflow_context,
+                           executor=executor)
+        t = threading.Thread(target=eng.execute)
+        t.start()
+        time.sleep(1)
+        eng.cancel_execution()
+        t.join(timeout=30)
+        assert workflow_context.states == ['start', 'cancel']
+        assert workflow_context.exception is None
+        invocations = global_test_holder.get('invocations', [])
+        assert 0 < len(invocations) < number_of_tasks
+        execution = workflow_context.execution
+        assert execution.started_at <= execution.ended_at <= datetime.utcnow()
+        assert execution.error is None
+        assert execution.status == models.Execution.CANCELLED
+
+    def test_cancel_pending_execution(self, workflow_context, executor):
+        @workflow
+        def mock_workflow(graph, **_):
+            return graph
+        eng = self._engine(workflow_func=mock_workflow,
+                           workflow_context=workflow_context,
+                           executor=executor)
+        eng.cancel_execution()
+        execution = workflow_context.execution
+        assert execution.status == models.Execution.CANCELLED
+
+
 class TestRetries(BaseTest):
 
     def test_one_max_retries_and_success_on_retry(self, workflow_context, executor):
@@ -334,3 +393,9 @@ def mock_conditional_failure_task(failure_count):
             raise RuntimeError
     finally:
         invocations.append(time.time())
+
+
+def mock_sleep_task(seconds):
+    invocations = global_test_holder.setdefault('invocations', [])
+    invocations.append(time.time())
+    time.sleep(seconds)

Reply via email to