Repository: incubator-ariatosca Updated Branches: refs/heads/ARIA-63-runtime-properties-modification e81f42758 -> 9f29d2912 (forced update)
ARIA-63 Implement attribute tracking for subprocesses Project: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/commit/9f29d291 Tree: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/tree/9f29d291 Diff: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/diff/9f29d291 Branch: refs/heads/ARIA-63-runtime-properties-modification Commit: 9f29d2912142039d1dc8e1398eefc323803a8e73 Parents: dac4da7 Author: Dan Kilman <d...@gigaspaces.com> Authored: Sun Jan 15 17:42:23 2017 +0200 Committer: Dan Kilman <d...@gigaspaces.com> Committed: Tue Jan 17 04:36:23 2017 +0200 ---------------------------------------------------------------------- aria/orchestrator/workflows/executor/process.py | 49 ++-- aria/storage/instrumentation.py | 124 +++++++++ aria/storage/type.py | 20 +- tests/.pylintrc | 2 +- .../test_process_executor_tracked_changes.py | 95 +++++++ tests/storage/test_instrumentation.py | 274 +++++++++++++++++++ 6 files changed, 544 insertions(+), 20 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/9f29d291/aria/orchestrator/workflows/executor/process.py ---------------------------------------------------------------------- diff --git a/aria/orchestrator/workflows/executor/process.py b/aria/orchestrator/workflows/executor/process.py index e0a8aeb..3c86c51 100644 --- a/aria/orchestrator/workflows/executor/process.py +++ b/aria/orchestrator/workflows/executor/process.py @@ -42,6 +42,8 @@ import jsonpickle from aria.utils import imports from aria.orchestrator.workflows.executor import base from aria.orchestrator.context import serialization +from aria.storage import instrumentation +from aria.storage import type as storage_type _IS_WIN = os.name == 'nt' @@ -139,10 +141,17 @@ class ProcessExecutor(base.BaseExecutor): if message_type == 'started': self._task_started(self._tasks[task_id]) 
elif message_type == 'succeeded': - self._task_succeeded(self._remove_task(task_id)) + task = self._remove_task(task_id) + instrumentation.apply_tracked_changes( + tracked_changes=message['tracked_changes'], + model=task.context.model) + self._task_succeeded(task) elif message_type == 'failed': - self._task_failed(self._remove_task(task_id), - exception=message['exception']) + task = self._remove_task(task_id) + instrumentation.apply_tracked_changes( + tracked_changes=message['tracked_changes'], + model=task.context.model) + self._task_failed(task, exception=message['exception']) else: raise RuntimeError('Invalid state') except BaseException as e: @@ -227,26 +236,27 @@ class _Messenger(object): """Task started message""" self._send_message(type='started') - def succeeded(self): + def succeeded(self, tracked_changes): """Task succeeded message""" - self._send_message(type='succeeded') + self._send_message(type='succeeded', tracked_changes=tracked_changes) - def failed(self, exception): + def failed(self, tracked_changes, exception): """Task failed message""" - self._send_message(type='failed', exception=exception) + self._send_message(type='failed', tracked_changes=tracked_changes, exception=exception) def closed(self): """Executor closed message""" self._send_message(type='closed') - def _send_message(self, type, exception=None): + def _send_message(self, type, tracked_changes=None, exception=None): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', self.port)) try: data = jsonpickle.dumps({ 'type': type, 'task_id': self.task_id, - 'exception': exception + 'exception': exception, + 'tracked_changes': tracked_changes }) sock.send(struct.pack(_INT_FMT, len(data))) sock.sendall(data) @@ -271,13 +281,20 @@ def _main(): operation_mapping = arguments['operation_mapping'] operation_inputs = arguments['operation_inputs'] context_dict = arguments['context'] - try: - ctx = serialization.operation_context_from_dict(context_dict) - task_func = 
imports.load_attribute(operation_mapping) - task_func(ctx=ctx, **operation_inputs) - messenger.succeeded() - except BaseException as e: - messenger.failed(exception=e) + + # This must happen before any model class is loaded, because that would trigger + # the listener we are trying to remove. Once it is triggered, many other listeners + # will then be registered. At that point, it is too late. + storage_type.remove_mutable_association_listener() + + with instrumentation.track_changes() as instrument: + try: + ctx = serialization.operation_context_from_dict(context_dict) + task_func = imports.load_attribute(operation_mapping) + task_func(ctx=ctx, **operation_inputs) + messenger.succeeded(tracked_changes=instrument.tracked_changes) + except BaseException as e: + messenger.failed(exception=e, tracked_changes=instrument.tracked_changes) if __name__ == '__main__': _main() http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/9f29d291/aria/storage/instrumentation.py ---------------------------------------------------------------------- diff --git a/aria/storage/instrumentation.py b/aria/storage/instrumentation.py new file mode 100644 index 0000000..1023b94 --- /dev/null +++ b/aria/storage/instrumentation.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+import sqlalchemy.event
+
+from . import api
+from . import model as _model
+
+# Placeholder "initial" value for attributes that were assigned (rather than
+# loaded) while tracking. It compares unequal to any real value, so such
+# attributes are always treated as changed by apply_tracked_changes.
+_STUB = object()
+# Default attributes to track, mapped to the type used to copy their values.
+_INSTRUMENTED = {
+    _model.NodeInstance.runtime_properties: dict
+}
+
+
+def track_changes(instrumented=None):
+    """Start tracking changes to the given model attributes.
+
+    :param instrumented: mapping of instrumented (SQLAlchemy) attribute to the type
+                         used to copy its value; defaults to ``_INSTRUMENTED``.
+    :return: an ``_Instrumentation``; usable as a context manager, its listeners are
+             removed on exit (or by calling ``restore()``).
+    """
+    return _Instrumentation(instrumented or _INSTRUMENTED)
+
+
+class _Instrumentation(object):
+    """Registers SQLAlchemy listeners that record attribute values in ``tracked_changes``.
+
+    ``tracked_changes`` is shaped ``{mapi_name: {instance_id: {attribute_name: _Value}}}``;
+    ``listeners`` holds the ``(target, event_name, fn)`` tuples that were registered.
+    """
+
+    def __init__(self, instrumented):
+        self.tracked_changes = {}
+        self.listeners = []
+        self._track_changes(instrumented)
+
+    def _track_changes(self, instrumented):
+        instrumented_classes = {}
+        for instrumented_attribute, attribute_type in instrumented.items():
+            self._register_set_attribute_listener(
+                instrumented_attribute=instrumented_attribute,
+                attribute_type=attribute_type)
+            instrumented_class = instrumented_attribute.parent.entity
+            instrumented_class_attributes = instrumented_classes.setdefault(instrumented_class, {})
+            instrumented_class_attributes[instrumented_attribute.key] = attribute_type
+        # Load/refresh listeners are registered once per class, covering all of its
+        # tracked attributes.
+        for instrumented_class, instrumented_attributes in instrumented_classes.items():
+            self._register_instance_listeners(
+                instrumented_class=instrumented_class,
+                instrumented_attributes=instrumented_attributes)
+
+    def _register_set_attribute_listener(self, instrumented_attribute, attribute_type):
+        # Records assignments to the attribute; the value it had before is not known
+        # here, hence the _STUB initial value.
+        def listener(target, value, *_):
+            mapi_name = api.generate_lower_name(target.__class__)
+            tracked_instances = self.tracked_changes.setdefault(mapi_name, {})
+            tracked_attributes = tracked_instances.setdefault(target.id, {})
+            if value is None:
+                current = None
+            else:
+                current = copy.deepcopy(attribute_type(value))
+            tracked_attributes[instrumented_attribute.key] = _Value(_STUB, current)
+            # With retval=True the returned copy becomes the assigned value, so later
+            # in-place mutations of the attribute also show up in the tracked value.
+            return current
+        listener_args = (instrumented_attribute, 'set', listener)
+        sqlalchemy.event.listen(*listener_args, retval=True)
+        self.listeners.append(listener_args)
+
+    def _register_instance_listeners(self, instrumented_class, instrumented_attributes):
+        # On load/refresh, snapshot each attribute's initial value and put a deep copy
+        # into the instance, so in-place mutations land in the tracked "current" value.
+        def listener(target, *_):
+            mapi_name = api.generate_lower_name(instrumented_class)
+            tracked_instances = self.tracked_changes.setdefault(mapi_name, {})
+            tracked_attributes = tracked_instances.setdefault(target.id, {})
+            for attribute_name, attribute_type in instrumented_attributes.items():
+                if attribute_name not in tracked_attributes:
+                    initial = getattr(target, attribute_name)
+                    if initial is None:
+                        current = None
+                    else:
+                        current = copy.deepcopy(attribute_type(initial))
+                    tracked_attributes[attribute_name] = _Value(initial, current)
+                # Writing __dict__ directly bypasses attribute instrumentation, so no
+                # 'set' event fires for this swap.
+                target.__dict__[attribute_name] = tracked_attributes[attribute_name].current
+        for listener_args in [(instrumented_class, 'load', listener),
+                              (instrumented_class, 'refresh', listener),
+                              (instrumented_class, 'refresh_flush', listener)]:
+            sqlalchemy.event.listen(*listener_args)
+            self.listeners.append(listener_args)
+
+    def restore(self):
+        """Remove every registered listener; safe to call more than once."""
+        for listener_args in self.listeners:
+            if sqlalchemy.event.contains(*listener_args):
+                sqlalchemy.event.remove(*listener_args)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.restore()
+
+
+class _Value(object):
+    """An attribute's value when tracking started (``initial``) and its latest value
+    (``current``). ``initial`` is ``_STUB`` when the attribute was assigned, not loaded.
+    """
+
+    def __init__(self, initial, current):
+        self.initial = initial
+        self.current = current
+
+    def __eq__(self, other):
+        if not isinstance(other, _Value):
+            return False
+        return self.initial == other.initial and self.current == other.current
+
+    def __ne__(self, other):
+        # On Python 2, != is not derived from __eq__.
+        return not self == other
+
+    # The tracked payloads are typically dicts/lists, which cannot be hashed, so a
+    # hash built from them raised TypeError; make _Value explicitly unhashable.
+    __hash__ = None
+
+
+def apply_tracked_changes(tracked_changes, model):
+    """Apply changes recorded by ``track_changes`` to ``model``.
+
+    Only attributes whose current value differs from their initial value are written,
+    and every instance that was modified is then persisted through its model API.
+
+    :param tracked_changes: ``{mapi_name: {instance_id: {attribute_name: _Value}}}``
+    :param model: model storage exposing one model API attribute per ``mapi_name``
+    """
+    for mapi_name, tracked_instances in tracked_changes.items():
+        mapi = getattr(model, mapi_name)
+        for instance_id, tracked_attributes in tracked_instances.items():
+            instance = None
+            for attribute_name, value in tracked_attributes.items():
+                if value.initial != value.current:
+                    if instance is None:
+                        instance = mapi.get(instance_id)
+                    setattr(instance, attribute_name, value.current)
+            # Persist explicitly instead of relying on an unrelated later commit of
+            # the same session to flush these modifications.
+            if instance is not None:
+                mapi.update(instance)
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/9f29d291/aria/storage/type.py ---------------------------------------------------------------------- diff --git a/aria/storage/type.py b/aria/storage/type.py index ab50b0f..ec81b2c 100644 --- a/aria/storage/type.py +++ b/aria/storage/type.py @@ -16,7 +16,8 @@ import json from sqlalchemy import ( TypeDecorator, - VARCHAR + VARCHAR, + event ) from sqlalchemy.ext import mutable @@ -84,5 +85,18 @@ class _MutableList(mutable.MutableList): raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e))) -_MutableList.associate_with(List) -_MutableDict.associate_with(Dict) +def _mutable_association_listener(mapper, cls): + for prop in mapper.column_attrs: + column_type = prop.columns[0].type + if isinstance(column_type, Dict): + _MutableDict.associate_with_attribute(getattr(cls, prop.key)) + if isinstance(column_type, List): + _MutableList.associate_with_attribute(getattr(cls, prop.key)) + + +def remove_mutable_association_listener(): + if event.contains(*_LISTENER_ARGS): + event.remove(*_LISTENER_ARGS) + +_LISTENER_ARGS = (mutable.mapper, 'mapper_configured', _mutable_association_listener) +event.listen(*_LISTENER_ARGS) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/9f29d291/tests/.pylintrc ---------------------------------------------------------------------- diff --git a/tests/.pylintrc b/tests/.pylintrc index 23251af..5de0691 100644 --- a/tests/.pylintrc +++ b/tests/.pylintrc @@ -77,7 +77,7 @@ confidence= # --enable=similarities". 
If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" -disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating,redefined-builtin,no-self-use,missing-docstring,attribute-defined-outside-init,redefined-outer-name,import-error,redefined-variable-type,broad -except,protected-access,global-statement,too-many-locals 
+disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating,redefined-builtin,no-self-use,missing-docstring,attribute-defined-outside-init,redefined-outer-name,import-error,redefined-variable-type,broad -except,protected-access,global-statement,too-many-locals,abstract-method [REPORTS] http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/9f29d291/tests/orchestrator/workflows/executor/test_process_executor_tracked_changes.py ---------------------------------------------------------------------- diff --git a/tests/orchestrator/workflows/executor/test_process_executor_tracked_changes.py b/tests/orchestrator/workflows/executor/test_process_executor_tracked_changes.py new file mode 100644 index 0000000..1564292 --- /dev/null +++ b/tests/orchestrator/workflows/executor/test_process_executor_tracked_changes.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from aria.orchestrator.workflows import api +from aria.orchestrator.workflows.core import engine +from aria.orchestrator.workflows.executor import process +from aria.orchestrator import workflow, operation +from aria.orchestrator.workflows import exceptions + +import tests +from tests import mock +from tests import storage + + +_TEST_RUNTIME_PROPERTIES = { + 'some': 'values', 'that': 'are', 'most': 'likely', 'only': 'set', 'here': 'yo' +} + + +def test_track_changes_of_successful_operation(context, executor): + _run_workflow(context=context, executor=executor, op_func=_mock_success_operation) + _assert_tracked_changes_are_applied(context) + + +def test_track_changes_of_failed_operation(context, executor): + with pytest.raises(exceptions.ExecutorException): + _run_workflow(context=context, executor=executor, op_func=_mock_fail_operation) + _assert_tracked_changes_are_applied(context) + + +def _assert_tracked_changes_are_applied(context): + instance = context.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) + assert instance.runtime_properties == _TEST_RUNTIME_PROPERTIES + + +def _update_runtime_properties(context): + context.node_instance.runtime_properties.clear() + context.node_instance.runtime_properties.update(_TEST_RUNTIME_PROPERTIES) + + +def _run_workflow(context, executor, op_func): + @workflow + def mock_workflow(ctx, graph): + node_instance = 
ctx.model.node_instance.get_by_name( + mock.models.DEPENDENCY_NODE_INSTANCE_NAME) + node_instance.node.operations['test.op'] = {'operation': _operation_mapping(op_func)} + task = api.task.OperationTask.node_instance(instance=node_instance, name='test.op') + graph.add_tasks(task) + return graph + graph = mock_workflow(ctx=context) # pylint: disable=no-value-for-parameter + eng = engine.Engine(executor=executor, workflow_context=context, tasks_graph=graph) + eng.execute() + + +@operation +def _mock_success_operation(ctx): + _update_runtime_properties(ctx) + + +@operation +def _mock_fail_operation(ctx): + _update_runtime_properties(ctx) + raise RuntimeError + + +def _operation_mapping(func): + return '{name}.{func.__name__}'.format(name=__name__, func=func) + + +@pytest.fixture +def executor(): + result = process.ProcessExecutor(python_path=[tests.ROOT_DIR]) + yield result + result.close() + + +@pytest.fixture +def context(tmpdir): + result = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + yield result + storage.release_sqlite_storage(result.model) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/9f29d291/tests/storage/test_instrumentation.py ---------------------------------------------------------------------- diff --git a/tests/storage/test_instrumentation.py b/tests/storage/test_instrumentation.py new file mode 100644 index 0000000..b00bbd3 --- /dev/null +++ b/tests/storage/test_instrumentation.py @@ -0,0 +1,274 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from sqlalchemy import Column, Text, Integer, event + +from aria.storage import ( + model, + structure, + type as aria_type, + ModelStorage, + sql_mapi, + instrumentation +) +from ..storage import get_sqlite_api_kwargs, release_sqlite_storage + + +STUB = instrumentation._STUB +Value = instrumentation._Value +instruments_holder = [] + + +class TestInstrumentation(object): + + def test_track_changes(self, storage): + model_kwargs = dict( + name='name', + dict1={'initial': 'value'}, + dict2={'initial': 'value'}, + list1=['initial'], + list2=['initial'], + int1=0, + int2=0, + string2='string') + model1_instance = MockModel1(**model_kwargs) + model2_instance = MockModel2(**model_kwargs) + storage.mock_model_1.put(model1_instance) + storage.mock_model_2.put(model2_instance) + + instrument = self._track_changes({ + MockModel1.dict1: dict, + MockModel1.list1: list, + MockModel1.int1: int, + MockModel1.string2: str, + MockModel2.dict2: dict, + MockModel2.list2: list, + MockModel2.int2: int, + MockModel2.name: str + }) + + assert not instrument.tracked_changes + + storage_model1_instance = storage.mock_model_1.get(model1_instance.id) + storage_model2_instance = storage.mock_model_2.get(model2_instance.id) + + storage_model1_instance.dict1 = {'hello': 'world'} + storage_model1_instance.dict2 = {'should': 'not track'} + storage_model1_instance.list1 = ['hello'] + storage_model1_instance.list2 = ['should not track'] + storage_model1_instance.int1 = 100 + storage_model1_instance.int2 = 20000 + storage_model1_instance.name = 'should not track' + 
storage_model1_instance.string2 = 'new_string' + + storage_model2_instance.dict1.update({'should': 'not track'}) + storage_model2_instance.dict2.update({'hello': 'world'}) + storage_model2_instance.list1.append('should not track') + storage_model2_instance.list2.append('hello') + storage_model2_instance.int1 = 100 + storage_model2_instance.int2 = 20000 + storage_model2_instance.name = 'new_name' + storage_model2_instance.string2 = 'should not track' + + assert instrument.tracked_changes == { + 'mock_model_1': { + model1_instance.id: { + 'dict1': Value(STUB, {'hello': 'world'}), + 'list1': Value(STUB, ['hello']), + 'int1': Value(STUB, 100), + 'string2': Value(STUB, 'new_string') + } + }, + 'mock_model_2': { + model2_instance.id: { + 'dict2': Value({'initial': 'value'}, {'hello': 'world', 'initial': 'value'}), + 'list2': Value(['initial'], ['initial', 'hello']), + 'int2': Value(STUB, 20000), + 'name': Value(STUB, 'new_name'), + } + } + } + + def test_attribute_initial_none_value(self, storage): + instance1 = MockModel1(name='name1', dict1=None) + instance2 = MockModel1(name='name2', dict1=None) + storage.mock_model_1.put(instance1) + storage.mock_model_1.put(instance2) + instrument = self._track_changes({MockModel1.dict1: dict}) + instance1 = storage.mock_model_1.get(instance1.id) + instance2 = storage.mock_model_1.get(instance2.id) + instance1.dict1 = {'new': 'value'} + assert instrument.tracked_changes == { + 'mock_model_1': { + instance1.id: {'dict1': Value(STUB, {'new': 'value'})}, + instance2.id: {'dict1': Value(None, None)}, + } + } + + def test_attribute_set_none_value(self, storage): + instance = MockModel1(name='name') + storage.mock_model_1.put(instance) + instrument = self._track_changes({ + MockModel1.dict1: dict, + MockModel1.list1: list, + MockModel1.string2: str, + MockModel1.int1: int + }) + instance = storage.mock_model_1.get(instance.id) + instance.dict1 = None + instance.list1 = None + instance.string2 = None + instance.int1 = None + assert 
instrument.tracked_changes == { + 'mock_model_1': { + instance.id: { + 'dict1': Value(STUB, None), + 'list1': Value(STUB, None), + 'string2': Value(STUB, None), + 'int1': Value(STUB, None) + } + } + } + + def test_restore(self): + instrument = self._track_changes({MockModel1.dict1: dict}) + # set instance attribute, load instance, refresh instance and flush_refresh listeners + assert len(instrument.listeners) == 4 + for listener_args in instrument.listeners: + assert event.contains(*listener_args) + instrument.restore() + assert len(instrument.listeners) == 4 + for listener_args in instrument.listeners: + assert not event.contains(*listener_args) + return instrument + + def test_restore_twice(self): + instrument = self.test_restore() + instrument.restore() + + def test_instrumentation_context_manager(self, storage): + instance = MockModel1(name='name') + storage.mock_model_1.put(instance) + with self._track_changes({MockModel1.dict1: dict}) as instrument: + instance = storage.mock_model_1.get(instance.id) + instance.dict1 = {'new': 'value'} + assert instrument.tracked_changes == { + 'mock_model_1': {instance.id: {'dict1': Value(STUB, {'new': 'value'})}} + } + assert len(instrument.listeners) == 4 + for listener_args in instrument.listeners: + assert event.contains(*listener_args) + for listener_args in instrument.listeners: + assert not event.contains(*listener_args) + + def test_apply_tracked_changes(self, storage): + initial_values = {'dict1': {'initial': 'value'}, 'list1': ['initial']} + instance1_1 = MockModel1(name='instance1_1', **initial_values) + instance1_2 = MockModel1(name='instance1_2', **initial_values) + instance2_1 = MockModel2(name='instance2_1', **initial_values) + instance2_2 = MockModel2(name='instance2_2', **initial_values) + storage.mock_model_1.put(instance1_1) + storage.mock_model_1.put(instance1_2) + storage.mock_model_2.put(instance2_1) + storage.mock_model_2.put(instance2_2) + + instrument = self._track_changes({ + MockModel1.dict1: dict, 
+ MockModel1.list1: list, + MockModel2.dict1: dict, + MockModel2.list1: list + }) + + def get_instances(): + return (storage.mock_model_1.get(instance1_1.id), + storage.mock_model_1.get(instance1_2.id), + storage.mock_model_2.get(instance2_1.id), + storage.mock_model_2.get(instance2_2.id)) + + instance1_1, instance1_2, instance2_1, instance2_2 = get_instances() + instance1_1.dict1 = {'new': 'value'} + instance1_2.list1 = ['new_value'] + instance2_1.dict1.update({'new': 'value'}) + instance2_2.list1.append('new_value') + + instrument.restore() + storage.mock_model_1._session.expire_all() + + instance1_1, instance1_2, instance2_1, instance2_2 = get_instances() + instance1_1.dict1 = {'overriding': 'value'} + instance1_2.list1 = ['overriding_value'] + instance2_1.dict1 = {'overriding': 'value'} + instance2_2.list1 = ['overriding_value'] + storage.mock_model_1.put(instance1_1) + storage.mock_model_1.put(instance1_2) + storage.mock_model_2.put(instance2_1) + storage.mock_model_2.put(instance2_2) + instance1_1, instance1_2, instance2_1, instance2_2 = get_instances() + assert instance1_1.dict1 == {'overriding': 'value'} + assert instance1_2.list1 == ['overriding_value'] + assert instance2_1.dict1 == {'overriding': 'value'} + assert instance2_2.list1 == ['overriding_value'] + + instrumentation.apply_tracked_changes( + tracked_changes=instrument.tracked_changes, + model=storage) + + instance1_1, instance1_2, instance2_1, instance2_2 = get_instances() + assert instance1_1.dict1 == {'new': 'value'} + assert instance1_2.list1 == ['new_value'] + assert instance2_1.dict1 == {'initial': 'value', 'new': 'value'} + assert instance2_2.list1 == ['initial', 'new_value'] + + def _track_changes(self, instrumented): + instrument = instrumentation.track_changes(instrumented) + instruments_holder.append(instrument) + return instrument + + +@pytest.fixture(autouse=True) +def restore_instrumentation(): + for instrument in instruments_holder: + instrument.restore() + del instruments_holder[:] 
+ + +@pytest.fixture +def storage(): + result = ModelStorage( + api_cls=sql_mapi.SQLAlchemyModelAPI, + api_kwargs=get_sqlite_api_kwargs(), + items=(MockModel1, MockModel2)) + yield result + release_sqlite_storage(result) + + +class _MockModel(structure.ModelMixin): + name = Column(Text) + dict1 = Column(aria_type.Dict) + dict2 = Column(aria_type.Dict) + list1 = Column(aria_type.List) + list2 = Column(aria_type.List) + int1 = Column(Integer) + int2 = Column(Integer) + string2 = Column(Text) + + +class MockModel1(model.DeclarativeBase, _MockModel): + __tablename__ = 'mock_model1' + + +class MockModel2(model.DeclarativeBase, _MockModel): + __tablename__ = 'mock_model2'