GitHub user mxmrlv commented on a diff in the pull request:

    https://github.com/apache/incubator-ariatosca/pull/31#discussion_r90777224
  
    --- Diff: aria/storage/sql_mapi.py ---
    @@ -0,0 +1,361 @@
    +# Licensed to the Apache Software Foundation (ASF) under one or more
    +# contributor license agreements.  See the NOTICE file distributed with
    +# this work for additional information regarding copyright ownership.
    +# The ASF licenses this file to You under the Apache License, Version 2.0
    +# (the "License"); you may not use this file except in compliance with
    +# the License.  You may obtain a copy of the License at
    +#
    +#     http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +"""
    +SQLAlchemy based MAPI
    +"""
    +
    +from sqlalchemy.exc import SQLAlchemyError
    +from sqlalchemy.sql.elements import Label
    +
    +from aria.utils.collections import OrderedDict
    +
    +from aria.storage import (
    +    api,
    +    exceptions
    +)
    +
    +
    +# Default SQL dialect name for this storage backend.
    +# NOTE(review): not referenced in the visible part of this module;
    +# presumably consumed elsewhere — confirm.
    +DEFAULT_SQL_DIALECT = 'sqlite'
    +
    +
    +class SQLAlchemyModelAPI(api.ModelAPI):
    +    """
    +    SQL based MAPI.
    +    """
    +
    +    def __init__(self,
    +                 engine,
    +                 session,
    +                 **kwargs):
    +        super(SQLAlchemyModelAPI, self).__init__(**kwargs)
    +        self._engine = engine
    +        self._session = session
    +
    +    def get(self, entry_id, include=None, filters=None, locking=False, 
**kwargs):
    +        """Return a single result based on the model class and element ID
    +        """
    +        filters = filters or {'id': entry_id}
    +        query = self._get_query(include, filters)
    +        if locking:
    +            query = query.with_for_update()
    +        result = query.first()
    +
    +        if not result:
    +            raise exceptions.StorageError(
    +                'Requested {0} with ID `{1}` was not found'
    +                .format(self.model_cls.__name__, entry_id)
    +            )
    +        return result
    +
    +    def iter(self,
    +             include=None,
    +             filters=None,
    +             pagination=None,
    +             sort=None,
    +             **kwargs):
    +        """Return a (possibly empty) list of `model_class` results
    +        """
    +        query = self._get_query(include, filters, sort)
    +
    +        results, _, _, _ = self._paginate(query, pagination)
    +
    +        for result in results:
    +            yield result
    +
    +    def put(self, entry, **kwargs):
    +        """Create a `model_class` instance from a serializable `model` 
object
    +
    +        :param entry: A dict with relevant kwargs, or an instance of a 
class
    +        that has a `to_dict` method, and whose attributes match the columns
    +        of `model_class` (might also my just an instance of `model_class`)
    +        :return: An instance of `model_class`
    +        """
    +        self._session.add(entry)
    +        self._safe_commit()
    +        return entry
    +
    +    def delete(self, entry_id, filters=None, **kwargs):
    +        """Delete a single result based on the model class and element ID
    +        """
    +        try:
    +            instance = self.get(
    +                entry_id,
    +                filters=filters
    +            )
    +        except exceptions.StorageError:
    +            raise exceptions.StorageError(
    +                'Could not delete {0} with ID `{1}` - element not found'
    +                .format(
    +                    self.model_cls.__name__,
    +                    entry_id
    +                )
    +            )
    +        self._load_properties(instance)
    +        self._session.delete(instance)
    +        self._safe_commit()
    +        return instance
    +
    +    # TODO: this might need rework
    +    def update(self, entry, **kwargs):
    +        """Add `instance` to the DB session, and attempt to commit
    +
    +        :return: The updated instance
    +        """
    +        return self.put(entry)
    +
    +    def refresh(self, entry):
    +        """Reload the instance with fresh information from the DB
    +
    +        :param entry: Instance to be re-loaded from the DB
    +        :return: The refreshed instance
    +        """
    +        self._session.refresh(entry)
    +        self._load_properties(entry)
    +        return entry
    +
    +    def _destroy_connection(self):
    +        pass
    +
    +    def _establish_connection(self):
    +        pass
    +
    +    def create(self):
    +        self.model_cls.__table__.create(self._engine)
    +
    +    def drop(self):
    +        """
    +        Drop the table from the storage.
    +        :return:
    +        """
    +        self.model_cls.__table__.drop(self._engine)
    +
    +    def _safe_commit(self):
    +        """Try to commit changes in the session. Roll back if exception 
raised
    +        Excepts SQLAlchemy errors and rollbacks if they're caught
    +        """
    +        try:
    +            self._session.commit()
    +        except SQLAlchemyError as e:
    +            self._session.rollback()
    +            raise exceptions.StorageError('SQL Storage error: 
{0}'.format(str(e)))
    +
    +    def _get_base_query(self, include, joins):
    +        """Create the initial query from the model class and included 
columns
    +
    +        :param include: A (possibly empty) list of columns to include in
    +        the query
    +        :param joins: A (possibly empty) list of models on which the query
    +        should join
    +        :return: An SQLAlchemy AppenderQuery object
    +        """
    +
    +        # If only some columns are included, query through the session 
object
    +        if include:
    +            query = self._session.query(*include)
    +        else:
    +            # If all columns should be returned, query directly from the 
model
    +            query = self._session.query(self.model_cls)
    +
    +        # Add any joins that might be necessary
    +        for join_model in joins:
    +            query = query.join(join_model)
    +
    +        return query
    +
    +    @staticmethod
    +    def _sort_query(query, sort=None):
    +        """Add sorting clauses to the query
    +
    +        :param query: Base SQL query
    +        :param sort: An optional dictionary where keys are column names to
    +        sort by, and values are the order (asc/desc)
    +        :return: An SQLAlchemy AppenderQuery object
    +        """
    +        if sort:
    +            for column, order in sort.items():
    +                if order == 'desc':
    +                    column = column.desc()
    +                query = query.order_by(column)
    +        return query
    +
    +    @staticmethod
    +    def _filter_query(query, filters):
    +        """Add filter clauses to the query
    +
    +        :param query: Base SQL query
    +        :param filters: An optional dictionary where keys are column names 
to
    +        filter by, and values are values applicable for those columns (or 
lists
    +        of such values)
    +        :return: An SQLAlchemy AppenderQuery object
    +        """
    +        for column, value in filters.items():
    +            # If there are multiple values, use `in_`, otherwise, use `eq`
    +            if isinstance(value, (list, tuple)):
    +                query = query.filter(column.in_(value))
    +            else:
    +                query = query.filter(column == value)
    +
    +        return query
    +
    +    def _get_query(self,
    +                   include=None,
    +                   filters=None,
    +                   sort=None):
    +        """Get an SQL query object based on the params passed
    +
    +        :param include: An optional list of columns to include in the query
    +        :param filters: An optional dictionary where keys are column names 
to
    +        filter by, and values are values applicable for those columns (or 
lists
    +        of such values)
    +        :param sort: An optional dictionary where keys are column names to
    +        sort by, and values are the order (asc/desc)
    +        :return: A sorted and filtered query with only the relevant
    +        columns
    +        """
    +
    +        include = include or []
    +        filters = filters or dict()
    +        sort = sort or OrderedDict()
    +
    +        joins = self._get_join_models_list(include, filters, sort)
    +        include, filters, sort = self._get_columns_from_field_names(
    +            include, filters, sort
    +        )
    +
    +        query = self._get_base_query(include, joins)
    +        query = self._filter_query(query, filters)
    +        query = self._sort_query(query, sort)
    +        return query
    +
    +    def _get_columns_from_field_names(self,
    +                                      include,
    +                                      filters,
    +                                      sort):
    +        """Go over the optional parameters (include, filters, sort), and
    +        replace column names with actual SQLA column objects
    +        """
    +        all_includes = [self._get_column(c) for c in include]
    +        include = []
    +        # Columns that are inferred from properties (Labels) should be 
included
    +        # last for the following joins to work properly
    +        for col in all_includes:
    +            if isinstance(col, Label):
    +                include.append(col)
    +            else:
    +                include.insert(0, col)
    +
    +        filters = dict((self._get_column(c), filters[c]) for c in filters)
    +        sort = OrderedDict((self._get_column(c), sort[c]) for c in sort)
    +
    +        return include, filters, sort
    +
    +    def _get_join_models_list(self, include, filters, sort):
    +        """Return a list of models on which the query should be joined, as
    +        inferred from the include, filter and sort column names
    +        """
    +        if not self.model_cls.is_resource:
    +            return []
    +
    +        all_column_names = include + filters.keys() + sort.keys()
    +        join_columns = set(column_name for column_name in all_column_names
    +                           if self._is_join_column(column_name))
    +
    +        # If the only columns included are the columns on which we would
    +        # normally join, there isn't actually a need to join, as the FROM
    +        # clause in the query will be generated from the relevant models 
anyway
    +        if include == list(join_columns):
    +            return []
    +
    +        # Initializing a set, because the same model can appear in several
    +        # join lists
    +        join_models = set()
    +        for column_name in join_columns:
    +            join_models.update(
    +                self.model_cls.join_properties[column_name]['models']
    +            )
    +        # Sort the models by their correct join order
    +        join_models = sorted(join_models,
    +                             key=lambda model: model.join_order, 
reverse=True)
    +
    +        return join_models
    +
    +    def _is_join_column(self, column_name):
    --- End diff --
    
    Check how Cloudify implements this.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

Reply via email to