fabianmenges closed pull request #2993: Adding YAML Import/Export for 
Datasources to CLI
URL: https://github.com/apache/incubator-superset/pull/2993
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/import_export_datasources.rst 
b/docs/import_export_datasources.rst
new file mode 100644
index 0000000000..351707743e
--- /dev/null
+++ b/docs/import_export_datasources.rst
@@ -0,0 +1,97 @@
+Importing and Exporting Datasources
+===================================
+
+The Superset CLI allows you to import and export datasources from and to YAML.
+Datasources include both databases and Druid clusters. The data is expected to be organized in the following hierarchy: ::
+
+    .
+    ├──databases
+    |  ├──database_1
+    |  |  ├──table_1
+    |  |  |  ├──columns
+    |  |  |  |  ├──column_1
+    |  |  |  |  ├──column_2
+    |  |  |  |  └──... (more columns)
+    |  |  |  └──metrics
+    |  |  |     ├──metric_1
+    |  |  |     ├──metric_2
+    |  |  |     └──... (more metrics)
+    |  |  └── ... (more tables)
+    |  └── ... (more databases)
+    └──druid_clusters
+       ├──cluster_1
+       |  ├──datasource_1
+       |  |  ├──columns
+       |  |  |  ├──column_1
+       |  |  |  ├──column_2
+       |  |  |  └──... (more columns)
+       |  |  └──metrics
+       |  |     ├──metric_1
+       |  |     ├──metric_2
+       |  |     └──... (more metrics)
+       |  └── ... (more datasources)
+       └── ... (more clusters)
+
+
+Exporting Datasources to YAML
+-----------------------------
+You can print your current datasources to stdout by running: ::
+
+    superset export_datasources
+
+
+To save your datasources to a file run: ::
+
+    superset export_datasources -f <filename>
+
+
+By default, default values will be omitted. Use the ``-d`` flag to include 
them.
+If you want back references to be included (e.g. a column to include the table 
id
+it belongs to) use the ``-b`` flag.
+
+
+Exporting the complete supported YAML schema
+--------------------------------------------
+In order to obtain an exhaustive list of all fields you can import using the 
YAML import run: ::
+
+    superset export_datasource_schema
+
+Again, you can use the ``-b`` flag to include back references.
+
+
+Importing Datasources from YAML
+-------------------------------
+In order to import datasources from one or more YAML files, run: ::
+
+    superset import_datasources -p <path or filename>
+
+If you supply a path all files ending with ``*.yaml`` or ``*.yml`` will be 
parsed.
+You can apply additional flags, e.g.: ::
+
+    superset import_datasources -p <path> -r
+
+This will search the supplied path recursively.
+
+The sync flag ``-s`` takes a comma-separated list of element types to sync with
+your file. Be careful: this can delete the contents of your meta database. Example: ::
+
+    superset import_datasources -p <path / filename> -s columns,metrics
+
+This will sync all ``metrics`` and ``columns`` for all datasources found in the
+``<path / filename>`` in the Superset meta database. This means columns and 
metrics
+not specified in YAML will be deleted. If you add ``tables`` to ``columns,metrics``,
+tables will be synchronised as well.
+
+
+If you don't supply the sync flag (``-s``), importing will only add and update (override) fields.
+E.g. you can add a ``verbose_name`` to the column ``ds`` in the table ``random_time_series`` from the example datasets
+by saving the following YAML to a file and then running the ``import_datasources`` command. ::
+
+    databases:
+    - database_name: main
+      tables:
+      - table_name: random_time_series
+        columns:
+        - column_name: ds
+          verbose_name: datetime
+
diff --git a/setup.py b/setup.py
index ec12deafb9..1feebf086b 100644
--- a/setup.py
+++ b/setup.py
@@ -64,9 +64,11 @@ def get_git_sha():
         'markdown==2.6.8',
         'pandas==0.20.3',
         'parsedatetime==2.0.0',
+        'pathlib2==2.3.0',
         'pydruid==0.3.1',
         'PyHive>=0.4.0',
         'python-dateutil==2.6.0',
+        'pyyaml>=3.11',
         'requests==2.17.3',
         'simplejson==3.10.0',
         'six==1.10.0',
diff --git a/superset/cli.py b/superset/cli.py
index ddbdf4bc19..72d33f7bd5 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -4,15 +4,18 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
-from datetime import datetime
 import logging
+from datetime import datetime
+from pathlib2 import Path
 from subprocess import Popen
+from sys import stdout
+import yaml
 
 from colorama import Fore, Style
 from flask_migrate import MigrateCommand
 from flask_script import Manager
 
-from superset import app, db, security, utils
+from superset import app, db, dict_import_export_util, security, utils
 
 config = app.config
 celery_app = utils.get_celery_app(config)
@@ -178,6 +181,81 @@ def refresh_druid(datasource, merge):
     session.commit()
 
 
+@manager.option(
+    '-p', '--path', dest='path',
+    help='Path to a single YAML file or path containing multiple YAML ' 
+         'files to import (*.yaml or *.yml)')
+@manager.option(
+    '-s', '--sync', dest='sync', default='',
+    help='comma seperated list of element types to synchronize '
+         'e.g. "metrics,columns" deletes metrics and columns in the DB '
+         'that are not specified in the YAML file')
+@manager.option(
+    '-r', '--recursive', dest='recursive', action="store_true",
+    help='recursively search the path for yaml files')
+def import_datasources(path, sync, recursive=False):
+    """Import datasources from YAML"""
+    sync_array = sync.split(',')
+    p = Path(path)
+    files = []
+    if p.is_file():
+        files.append(p)
+    elif p.exists() and not recursive:
+        files.extend(p.glob('*.yaml'))
+        files.extend(p.glob('*.yml'))
+    elif p.exists() and recursive:
+        files.extend(p.rglob('*.yaml'))
+        files.extend(p.rglob('*.yml'))
+    for f in files:
+        logging.info("Importing datasources from file %s", f)
+        try:
+          with f.open() as data_stream:
+              dict_import_export_util.import_from_dict(db.session,
+                                                   yaml.load(data_stream),
+                                                   sync=sync_array)
+        except Exception as e:
+            logging.error("Error when importing datasources from file %s", f)
+            logging.error(e)
+
+
+@manager.option(
+    '-f', '--datasource-file', default=None, dest='datasource_file',
+    help="Specify the the file to export to")
+@manager.option(
+    '-p', '--print', action='store_true', dest='print_stdout',
+    help="Print YAML to stdout")
+@manager.option(
+    '-b', '--back-references', action='store_true', dest='back_references',
+    help="Include parent back references")
+@manager.option(
+    '-d', '--include-defaults', action='store_true', dest='include_defaults',
+    help="Include fields containing defaults")
+def export_datasources(print_stdout, datasource_file,
+                       back_references, include_defaults):
+    """Export datasources to YAML"""
+    data = dict_import_export_util.export_to_dict(
+        session=db.session,
+        recursive=True,
+        back_references=back_references,
+        include_defaults=include_defaults)
+    if print_stdout or not datasource_file:
+        yaml.safe_dump(data, stdout, default_flow_style=False)
+    if datasource_file:
+        logging.info("Exporting datasources to %s", datasource_file)
+        with open(datasource_file, 'w') as data_stream:
+            yaml.safe_dump(data, data_stream, default_flow_style=False)
+
+
+@manager.option(
+    '-b', '--back-references', action='store_false',
+    help="Include parent back references")
+def export_datasource_schema(back_references):
+    """Export datasource YAML schema to stdout"""
+    data = dict_import_export_util.export_schema_to_dict(
+        back_references=back_references)
+    yaml.safe_dump(data, stdout, default_flow_style=False)
+
+
 @manager.command
 def update_datasources_cache():
     """Refresh sqllab datasources cache"""
diff --git a/superset/connectors/druid/models.py 
b/superset/connectors/druid/models.py
index 84445151b3..4864f6647a 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -6,6 +6,7 @@
 import logging
 from multiprocessing.pool import ThreadPool
 
+from sqlalchemy.schema import UniqueConstraint
 from dateutil.parser import parse as dparse
 from flask import escape, Markup
 from flask_appbuilder import Model
@@ -28,7 +29,9 @@
 
 from superset import conf, db, import_util, sm, utils
 from superset.connectors.base.models import BaseColumn, BaseDatasource, 
BaseMetric
-from superset.models.helpers import AuditMixinNullable, QueryResult, set_perm
+from superset.models.helpers import (
+  AuditMixinNullable, ImportMixin, QueryResult, set_perm
+)
 from superset.utils import (
     DimSelector, DTTM_ALIAS, flasher, MetricPermException,
 )
@@ -60,7 +63,7 @@ def __init__(self, name, post_aggregator):
         self.post_aggregator = post_aggregator
 
 
-class DruidCluster(Model, AuditMixinNullable):
+class DruidCluster(Model, AuditMixinNullable, ImportMixin):
 
     """ORM object referencing the Druid clusters"""
 
@@ -81,6 +84,11 @@ class DruidCluster(Model, AuditMixinNullable):
     metadata_last_refreshed = Column(DateTime)
     cache_timeout = Column(Integer)
 
+    export_fields = ('cluster_name', 'coordinator_host', 'coordinator_port',
+                     'coordinator_endpoint', 'broker_host', 'broker_port',
+                     'broker_endpoint', 'cache_timeout')
+    export_children = ['datasources']
+
     def __repr__(self):
         return self.verbose_name if self.verbose_name else self.cluster_name
 
@@ -219,6 +227,7 @@ class DruidColumn(Model, BaseColumn):
     """ORM model for storing Druid datasource column metadata"""
 
     __tablename__ = 'columns'
+    __table_args__ = (UniqueConstraint('column_name', 'datasource_id'),)
 
     datasource_id = Column(
         Integer,
@@ -233,8 +242,9 @@ class DruidColumn(Model, BaseColumn):
     export_fields = (
         'datasource_id', 'column_name', 'is_active', 'type', 'groupby',
         'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable',
-        'description', 'dimension_spec_json',
+        'description', 'dimension_spec_json', 'verbose_name'
     )
+    export_parent = 'datasource'
 
     def __repr__(self):
         return self.column_name
@@ -360,6 +370,7 @@ class DruidMetric(Model, BaseMetric):
     """ORM object referencing Druid metrics for a datasource"""
 
     __tablename__ = 'metrics'
+    __table_args__ = (UniqueConstraint('metric_name', 'datasource_id'),)
     datasource_id = Column(
         Integer,
         ForeignKey('datasources.id'))
@@ -374,6 +385,7 @@ class DruidMetric(Model, BaseMetric):
         'metric_name', 'verbose_name', 'metric_type', 'datasource_id',
         'json', 'description', 'is_restricted', 'd3format',
     )
+    export_parent = 'datasource'
 
     @property
     def expression(self):
@@ -409,6 +421,7 @@ class DruidDatasource(Model, BaseDatasource):
     """ORM object referencing Druid datasources (tables)"""
 
     __tablename__ = 'datasources'
+    __table_args__ = (UniqueConstraint('datasource_name', 'cluster_name'),)
 
     type = 'druid'
     query_langtage = 'json'
@@ -438,6 +451,9 @@ class DruidDatasource(Model, BaseDatasource):
         'cluster_name', 'offset', 'cache_timeout', 'params',
     )
 
+    export_parent = 'cluster'
+    export_children = ['columns', 'metrics']
+
     @property
     def database(self):
         return self.cluster
@@ -556,9 +572,12 @@ def int_or_0(v):
         v2nums = [int_or_0(n) for n in v2.split('.')]
         v1nums = (v1nums + [0, 0, 0])[:3]
         v2nums = (v2nums + [0, 0, 0])[:3]
-        return v1nums[0] > v2nums[0] or \
-            (v1nums[0] == v2nums[0] and v1nums[1] > v2nums[1]) or \
-            (v1nums[0] == v2nums[0] and v1nums[1] == v2nums[1] and v1nums[2] > 
v2nums[2])
+        return (
+                   v1nums[0] > v2nums[0] or
+                   (v1nums[0] == v2nums[0] and v1nums[1] > v2nums[1]) or
+                   (v1nums[0] == v2nums[0] and v1nums[1] == v2nums[1] and
+                       v1nums[2] > v2nums[2])
+               )
 
     def latest_metadata(self):
         """Returns segment metadata from the latest segment"""
diff --git a/superset/connectors/druid/views.py 
b/superset/connectors/druid/views.py
index 713a43c36b..1c1fb3e906 100644
--- a/superset/connectors/druid/views.py
+++ b/superset/connectors/druid/views.py
@@ -14,7 +14,7 @@
 from superset.views.base import (
     BaseSupersetView, DatasourceFilter, DeleteMixin,
     get_datasource_exist_error_mgs, ListWidgetWithCheckboxes, 
SupersetModelView,
-    validate_json,
+    validate_json, YamlExportMixin
 )
 from . import models
 
@@ -122,7 +122,7 @@ def post_update(self, metric):
 appbuilder.add_view_no_menu(DruidMetricInlineView)
 
 
-class DruidClusterModelView(SupersetModelView, DeleteMixin):  # noqa
+class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin):  
# noqa
     datamodel = SQLAInterface(models.DruidCluster)
 
     list_title = _('List Druid Cluster')
@@ -168,7 +168,7 @@ def _delete(self, pk):
     category_icon='fa-database',)
 
 
-class DruidDatasourceModelView(DatasourceModelView, DeleteMixin):  # noqa
+class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, 
YamlExportMixin):  # noqa
     datamodel = SQLAInterface(models.DruidDatasource)
 
     list_title = _('List Druid Datasource')
diff --git a/superset/connectors/sqla/models.py 
b/superset/connectors/sqla/models.py
index 599062265d..ff9e30d458 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -12,6 +12,7 @@
     and_, asc, Boolean, Column, DateTime, desc, ForeignKey, Integer, or_,
     select, String, Text,
 )
+from sqlalchemy.schema import UniqueConstraint
 from sqlalchemy.orm import backref, relationship
 from sqlalchemy.sql import column, literal_column, table, text
 from sqlalchemy.sql.expression import TextAsFrom
@@ -31,6 +32,7 @@ class TableColumn(Model, BaseColumn):
     """ORM object for table columns, each table can have multiple columns"""
 
     __tablename__ = 'table_columns'
+    __table_args__ = (UniqueConstraint('table_id', 'column_name'),)
     table_id = Column(Integer, ForeignKey('tables.id'))
     table = relationship(
         'SqlaTable',
@@ -47,6 +49,7 @@ class TableColumn(Model, BaseColumn):
         'filterable', 'expression', 'description', 'python_date_format',
         'database_expression',
     )
+    export_parent = 'table'
 
     @property
     def sqla_col(self):
@@ -120,6 +123,7 @@ class SqlMetric(Model, BaseMetric):
     """ORM object for metrics, each table can have multiple metrics"""
 
     __tablename__ = 'sql_metrics'
+    __table_args__ = (UniqueConstraint('table_id', 'metric_name'),)
     table_id = Column(Integer, ForeignKey('tables.id'))
     table = relationship(
         'SqlaTable',
@@ -130,6 +134,7 @@ class SqlMetric(Model, BaseMetric):
     export_fields = (
         'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression',
         'description', 'is_restricted', 'd3format')
+    export_parent = 'table'
 
     @property
     def sqla_col(self):
@@ -162,6 +167,8 @@ class SqlaTable(Model, BaseDatasource):
     column_class = TableColumn
 
     __tablename__ = 'tables'
+    __table_args__ = (UniqueConstraint('database_id', 'table_name'),)
+
     table_name = Column(String(250))
     main_dttm_col = Column(String(250))
     database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
@@ -179,15 +186,13 @@ class SqlaTable(Model, BaseDatasource):
     sql = Column(Text)
 
     baselink = 'tablemodelview'
+
     export_fields = (
         'table_name', 'main_dttm_col', 'description', 'default_endpoint',
         'database_id', 'offset', 'cache_timeout', 'schema',
         'sql', 'params')
-
-    __table_args__ = (
-        sa.UniqueConstraint(
-            'database_id', 'schema', 'table_name',
-            name='_customer_location_uc'),)
+    export_parent = 'database'
+    export_children = ['metrics', 'columns']
 
     def __repr__(self):
         return self.name
diff --git a/superset/connectors/sqla/views.py 
b/superset/connectors/sqla/views.py
index 57b271499c..8db8a1aded 100644
--- a/superset/connectors/sqla/views.py
+++ b/superset/connectors/sqla/views.py
@@ -12,7 +12,7 @@
 from superset.utils import has_access
 from superset.views.base import (
     DatasourceFilter, DeleteMixin, get_datasource_exist_error_mgs,
-    ListWidgetWithCheckboxes, SupersetModelView,
+    ListWidgetWithCheckboxes, SupersetModelView, YamlExportMixin
 )
 from . import models
 
@@ -148,7 +148,7 @@ def post_update(self, metric):
 appbuilder.add_view_no_menu(SqlMetricInlineView)
 
 
-class TableModelView(DatasourceModelView, DeleteMixin):  # noqa
+class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):  # 
noqa
     datamodel = SQLAInterface(models.SqlaTable)
 
     list_title = _('List Tables')
diff --git a/superset/dict_import_export_util.py 
b/superset/dict_import_export_util.py
new file mode 100644
index 0000000000..0c61f7b022
--- /dev/null
+++ b/superset/dict_import_export_util.py
@@ -0,0 +1,64 @@
+import logging
+
+from superset.models.core import Database
+from superset.connectors.druid.models import DruidCluster
+
+DATABASES_KEY = 'databases'
+DRUID_CLUSTERS_KEY = 'druid_clusters'
+
+
+def export_schema_to_dict(back_references):
+    """Exports the supported import/export schema to a dictionary"""
+    databases = [Database.export_schema(recursive=True,
+                 include_parent_ref=back_references)]
+    clusters = [DruidCluster.export_schema(recursive=True,
+                include_parent_ref=back_references)]
+    data = dict()
+    if databases:
+        data[DATABASES_KEY] = databases
+    if clusters:
+        data[DRUID_CLUSTERS_KEY] = clusters
+    return data
+
+
+def export_to_dict(session,
+                   recursive,
+                   back_references,
+                   include_defaults):
+    """Exports databases and druid clusters to a dictionary"""
+    logging.info("Starting export")
+    dbs = session.query(Database)
+    databases = [database.export_to_dict(recursive=recursive,
+                 include_parent_ref=back_references,
+                 include_defaults=include_defaults) for database in dbs]
+    logging.info("Exported %d %s", len(databases), DATABASES_KEY)
+    cls = session.query(DruidCluster)
+    clusters = [cluster.export_to_dict(recursive=recursive,
+                include_parent_ref=back_references,
+                include_defaults=include_defaults) for cluster in cls]
+    logging.info("Exported %d %s", len(clusters), DRUID_CLUSTERS_KEY)
+    data = dict()
+    if databases:
+        data[DATABASES_KEY] = databases
+    if clusters:
+        data[DRUID_CLUSTERS_KEY] = clusters
+    return data
+
+
+def import_from_dict(session, data, sync=[]):
+    """Imports databases and druid clusters from dictionary"""
+    if isinstance(data, dict):
+      logging.info("Importing %d %s",
+                   len(data.get(DATABASES_KEY, [])),
+                   DATABASES_KEY)
+      for database in data.get(DATABASES_KEY, []):
+          Database.import_from_dict(session, database, sync=sync)
+
+      logging.info("Importing %d %s",
+                   len(data.get(DRUID_CLUSTERS_KEY, [])),
+                   DRUID_CLUSTERS_KEY)
+      for datasource in data.get(DRUID_CLUSTERS_KEY, []):
+          DruidCluster.import_from_dict(session, datasource, sync=sync)
+      session.commit()
+    else:
+      logging.info("Supplied object is not a dictionary.")
diff --git a/superset/models/core.py b/superset/models/core.py
index 68c305f227..d170dd9cc6 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -28,6 +28,7 @@
 from sqlalchemy.orm import relationship, subqueryload
 from sqlalchemy.orm.session import make_transient
 from sqlalchemy.pool import NullPool
+from sqlalchemy.schema import UniqueConstraint
 from sqlalchemy.sql import text
 from sqlalchemy.sql.expression import TextAsFrom
 from sqlalchemy_utils import EncryptedType
@@ -537,12 +538,14 @@ def export_dashboards(cls, dashboard_ids):
         })
 
 
-class Database(Model, AuditMixinNullable):
+class Database(Model, AuditMixinNullable, ImportMixin):
 
     """An ORM object that stores Database related information"""
 
     __tablename__ = 'dbs'
     type = 'table'
+    __table_args__ = (UniqueConstraint('database_name'),)
+
 
     id = Column(Integer, primary_key=True)
     verbose_name = Column(String(250), unique=True)
@@ -567,6 +570,10 @@ class Database(Model, AuditMixinNullable):
     perm = Column(String(1000))
     custom_password_store = config.get('SQLALCHEMY_CUSTOM_PASSWORD_STORE')
     impersonate_user = Column(Boolean, default=False)
+    export_fields = ('database_name', 'sqlalchemy_uri', 'cache_timeout',
+                     'expose_in_sqllab', 'allow_run_sync', 'allow_run_async',
+                     'allow_ctas', 'extra')
+    export_children = ['tables']
 
     def __repr__(self):
         return self.verbose_name if self.verbose_name else self.database_name
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index d4ae9f45e8..2e6453ec11 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -6,7 +6,12 @@
 
 from datetime import datetime
 import json
+import yaml
+import logging
 import re
+from sqlalchemy.orm.exc import MultipleResultsFound
+from sqlalchemy import and_, or_
+from sqlalchemy import UniqueConstraint
 
 from flask import escape, Markup
 from flask_appbuilder.models.decorators import renders
@@ -20,6 +25,174 @@
 
 
 class ImportMixin(object):
+    export_parent = None
+    # The name of the attribute
+    # with the SQL Alchemy back reference
+
+    export_children = []
+    # List of (str) names of attributes
+    # with the SQL Alchemy forward references
+
+    export_fields = []
+    # The names of the attributes
+    # that are available for import and export
+
+    @classmethod
+    def _parent_foreign_key_mappings(cls):
+        """Get a mapping of foreign name to the local name of foreign keys"""
+        parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
+        if parent_rel:
+            return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs}
+        return {}
+
+    @classmethod
+    def _unique_constrains(cls):
+        """Get all (single column and multi column) unique constraints"""
+        unique = [{c.name for c in u.columns} for u in cls.__table_args__
+                  if isinstance(u, UniqueConstraint)]
+        unique.extend({c.name} for c in cls.__table__.columns if c.unique)
+        return unique
+
+    @classmethod
+    def export_schema(cls, recursive=True, include_parent_ref=False):
+        """Export schema as a dictionary"""
+        parent_excludes = {}
+        if not include_parent_ref:
+            parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
+            if parent_ref:
+                parent_excludes = {c.name for c in parent_ref.local_columns}
+
+        def formatter(c): return ("{0} Default ({1})".format(
+            str(c.type), c.default.arg) if c.default else str(c.type))
+
+        schema = {c.name: formatter(c) for c in cls.__table__.columns
+                  if (c.name in cls.export_fields and
+                  c.name not in parent_excludes)}
+        if recursive:
+            for c in cls.export_children:
+                child_class = cls.__mapper__.relationships[c].argument.class_
+                schema[c] = [child_class.export_schema(recursive=recursive,
+                             include_parent_ref=include_parent_ref)]
+        return schema
+
+    @classmethod
+    def import_from_dict(cls, session, dict_rep, parent=None,
+                         recursive=True, sync=[]):
+        """Import obj from a dictionary"""
+        parent_refs = cls._parent_foreign_key_mappings()
+        export_fields = set(cls.export_fields) | set(parent_refs.keys())
+        new_children = {c: dict_rep.get(c) for c in cls.export_children
+                        if c in dict_rep}
+        unique_constrains = cls._unique_constrains()
+
+        filters = []  # Using these filters to check if obj already exists
+
+        # Remove fields that should not get imported
+        for k in list(dict_rep):
+            if k not in export_fields:
+                del dict_rep[k]
+
+        if not parent:
+            if cls.export_parent:
+                for p in parent_refs.keys():
+                    if p not in dict_rep:
+                        raise RuntimeError(
+                          '{0}: Missing field {1}'.format(cls.__name__, p))
+        else:
+            # Set foreign keys to parent obj
+            for k, v in parent_refs.items():
+                dict_rep[k] = getattr(parent, v)
+
+        # Add filter for parent obj
+        filters.extend([getattr(cls, k) == dict_rep.get(k)
+                        for k in parent_refs.keys()])
+
+        # Add filter for unique constraints
+        ucs = [and_(*[getattr(cls, k) == dict_rep.get(k)
+               for k in cs if dict_rep.get(k) is not None])
+               for cs in unique_constrains]
+        filters.append(or_(*ucs))
+
+        # Check if object already exists in DB, break if more than one is found
+        try:
+            obj_query = session.query(cls).filter(and_(*filters))
+            obj = obj_query.one_or_none()
+        except MultipleResultsFound, e:
+            logging.error('Error importing %s \n %s \n %s', cls.__name__,
+                str(obj_query),
+                yaml.safe_dump(dict_rep))
+            raise e
+
+        if not obj:
+            is_new_obj = True
+            # Create new DB object
+            obj = cls(**dict_rep)
+            logging.info("Importing new %s %s", obj.__tablename__, str(obj))
+            if cls.export_parent and parent:
+                setattr(obj, cls.export_parent, parent)
+            session.add(obj)
+        else:
+            is_new_obj = False
+            logging.info("Updating %s %s", obj.__tablename__, str(obj))
+            # Update columns
+            for k, v in dict_rep.items():
+                setattr(obj, k, v)
+
+        # Recursively create children
+        if recursive:
+            for c in cls.export_children:
+                child_class = cls.__mapper__.relationships[c].argument.class_
+                added = []
+                for c_obj in new_children.get(c, []):
+                    added.append(child_class.import_from_dict(session=session,
+                                                   dict_rep=c_obj,
+                                                   parent=obj,
+                                                   sync=sync))
+                # If children should get synced, delete the ones that did not
+                # get updated.
+                if c in sync and not is_new_obj:
+                    back_refs = child_class._parent_foreign_key_mappings()
+                    delete_filters = [getattr(child_class, k) ==
+                                      getattr(obj, back_refs.get(k))
+                                      for k in back_refs.keys()]
+                    to_delete = set(session.query(child_class)
+                        .filter(and_(*delete_filters))).difference(set(added))
+                    for o in to_delete:
+                        logging.info("Deleting %s %s", c, str(obj))
+                        session.delete(o)
+
+        return obj
+
+    def export_to_dict(self, recursive=True, include_parent_ref=False,
+                       include_defaults=False):
+        """Export obj to dictionary"""
+        cls = self.__class__
+        parent_excludes = {}
+        if recursive and not include_parent_ref:
+            parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
+            if parent_ref:
+                parent_excludes = {c.name for c in parent_ref.local_columns}
+        dict_rep = {c.name: getattr(self, c.name)
+                    for c in cls.__table__.columns
+                    if (c.name in self.export_fields and
+                        c.name not in parent_excludes and
+                        (include_defaults or (
+                             getattr(self, c.name) is not None and
+                             (not c.default or
+                              getattr(self, c.name) != c.default.arg))))
+                    }
+        if recursive:
+            for c in self.export_children:
+                # sorting to make lists of children stable
+                dict_rep[c] = sorted([child.export_to_dict(
+                        recursive=recursive,
+                        include_parent_ref=include_parent_ref,
+                        include_defaults=include_defaults)
+                               for child in getattr(self, c)],
+                        key=lambda k: sorted(k.items()))
+
+        return dict_rep
+
     def override(self, obj):
         """Overrides the plain fields of the dashboard."""
         for field in obj.__class__.export_fields:
diff --git a/superset/views/base.py b/superset/views/base.py
index 7bc55d2c27..a7e626f45a 100644
--- a/superset/views/base.py
+++ b/superset/views/base.py
@@ -1,7 +1,9 @@
+from datetime import datetime
 import functools
 import json
 import logging
 import traceback
+import yaml
 
 from flask import abort, flash, g, get_flashed_messages, redirect, Response
 from flask_appbuilder import BaseView, ModelView
@@ -41,6 +43,14 @@ def json_error_response(msg=None, status=500, 
stacktrace=None, payload=None):
         status=status, mimetype='application/json')
 
 
+def generate_download_headers(extension, filename=None):
+  filename = filename if filename else datetime.now().strftime("%Y%m%d_%H%M%S")
+  content_disp = "attachment; filename={}.{}".format(filename, extension)
+  headers = {
+    "Content-Disposition": content_disp,
+  }
+  return headers
+
 def api(f):
     """
     A decorator to label an endpoint as an API. Catches uncaught exceptions and
@@ -219,6 +229,19 @@ def validate_json(form, field):  # noqa
         raise Exception(_("json isn't valid"))
 
 
+class YamlExportMixin(object):
+    @action("yaml_export", __("Export as YAML"), __("Export as YAML?"), 
"fa-download")
+    def yaml_export(self, items):
+        if not isinstance(items, list):
+            items = [items]
+
+        data = [t.export_to_dict() for t in items]
+        return Response(
+            yaml.safe_dump(data),
+            headers=generate_download_headers("yaml"),
+            mimetype="application/text")
+
+
 class DeleteMixin(object):
     def _delete(self, pk):
         """
diff --git a/superset/views/core.py b/superset/views/core.py
index ef0cbf5844..e0a5916b85 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -45,7 +45,8 @@
 from superset.utils import has_access, merge_extra_filters, QueryStatus
 from .base import (
     api, BaseSupersetView, CsvResponse, DeleteMixin, get_error_msg,
-    get_user_roles, json_error_response, SupersetFilter, SupersetModelView,
+    generate_download_headers, get_user_roles, json_error_response,
+    SupersetFilter, SupersetModelView, YamlExportMixin
 )
 
 config = app.config
@@ -161,16 +162,9 @@ def apply(self, query, func):  # noqa
         return query
 
 
-def generate_download_headers(extension):
-    filename = datetime.now().strftime('%Y%m%d_%H%M%S')
-    content_disp = 'attachment; filename={}.{}'.format(filename, extension)
-    headers = {
-        'Content-Disposition': content_disp,
-    }
-    return headers
 
 
-class DatabaseView(SupersetModelView, DeleteMixin):  # noqa
+class DatabaseView(SupersetModelView, DeleteMixin, YamlExportMixin):  # noqa
     datamodel = SQLAInterface(models.Database)
 
     list_title = _('List Databases')
diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py
new file mode 100644
index 0000000000..68e1dd57ba
--- /dev/null
+++ b/tests/dict_import_export_tests.py
@@ -0,0 +1,348 @@
+"""Unit tests for Superset"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import json
+import unittest
+import yaml
+
+from superset import db
+from superset.connectors.druid.models import (
+    DruidDatasource, DruidColumn, DruidMetric)
+from superset.connectors.sqla.models import SqlaTable, TableColumn, SqlMetric
+
+from .base_tests import SupersetTestCase
+
+DBREF = 'dict_import__export_test'
+NAME_PREFIX = 'dict_'
+ID_PREFIX = 20000
+
+
+class DictImportExportTests(SupersetTestCase):
+    """Testing export import functionality for dashboards"""
+
+    def __init__(self, *args, **kwargs):
+        super(DictImportExportTests, self).__init__(*args, **kwargs)
+
+    @classmethod
+    def delete_imports(cls):
+        # Imported data clean up
+        session = db.session
+        for table in session.query(SqlaTable):
+            if DBREF in table.params_dict:
+                session.delete(table)
+        for datasource in session.query(DruidDatasource):
+            if DBREF in datasource.params_dict:
+                session.delete(datasource)
+        session.commit()
+
+    @classmethod
+    def setUpClass(cls):
+        cls.delete_imports()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.delete_imports()
+
+    def create_table(
+            self, name, schema='', id=0, cols_names=[], metric_names=[]):
+        database_name = 'main'
+        name = '{0}{1}'.format(NAME_PREFIX, name)
+        params = {DBREF: id, 'database_name': database_name}
+
+        dict_rep = {
+            'database_id': self.get_main_database(db.session).id,
+            'table_name': name,
+            'schema': schema,
+            'id': id,
+            'params': json.dumps(params),
+            'columns': [{'column_name': c}
+                        for c in cols_names],
+            'metrics': [{'metric_name': c} for c in metric_names]
+        }
+
+        table = SqlaTable(
+            id=id,
+            schema=schema,
+            table_name=name,
+            params=json.dumps(params)
+        )
+        for col_name in cols_names:
+            table.columns.append(TableColumn(column_name=col_name))
+        for metric_name in metric_names:
+            table.metrics.append(SqlMetric(metric_name=metric_name))
+        return table, dict_rep
+
+    def create_druid_datasource(
+            self, name, id=0, cols_names=[], metric_names=[]):
+        name = '{0}{1}'.format(NAME_PREFIX, name)
+        cluster_name = 'druid_test'
+        params = {DBREF: id, 'database_name': cluster_name}
+        dict_rep = {
+          'cluster_name': cluster_name,
+          'datasource_name': name,
+          'id': id,
+          'params': json.dumps(params),
+          'columns': [{"column_name": c} for c in cols_names],
+          'metrics': [{"metric_name": c} for c in metric_names]
+        }
+
+        datasource = DruidDatasource(
+            id=id,
+            datasource_name=name,
+            cluster_name=cluster_name,
+            params=json.dumps(params)
+        )
+        for col_name in cols_names:
+            datasource.columns.append(DruidColumn(column_name=col_name))
+        for metric_name in metric_names:
+            datasource.metrics.append(DruidMetric(metric_name=metric_name))
+        return datasource, dict_rep
+
+    def get_datasource(self, datasource_id):
+        return db.session.query(DruidDatasource).filter_by(
+            id=datasource_id).first()
+
+    def get_table_by_name(self, name):
+        return db.session.query(SqlaTable).filter_by(
+            table_name=name).first()
+
+    def yaml_compare(self, obj_1, obj_2):
+        obj_1_str = yaml.safe_dump(obj_1, default_flow_style=False)
+        obj_2_str = yaml.safe_dump(obj_2, default_flow_style=False)
+        self.assertEquals(obj_1_str, obj_2_str)
+
+    def assert_table_equals(self, expected_ds, actual_ds):
+        self.assertEquals(expected_ds.table_name, actual_ds.table_name)
+        self.assertEquals(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
+        self.assertEquals(expected_ds.schema, actual_ds.schema)
+        self.assertEquals(len(expected_ds.metrics), len(actual_ds.metrics))
+        self.assertEquals(len(expected_ds.columns), len(actual_ds.columns))
+        self.assertEquals(
+            set([c.column_name for c in expected_ds.columns]),
+            set([c.column_name for c in actual_ds.columns]))
+        self.assertEquals(
+            set([m.metric_name for m in expected_ds.metrics]),
+            set([m.metric_name for m in actual_ds.metrics]))
+
+    def assert_datasource_equals(self, expected_ds, actual_ds):
+        self.assertEquals(
+            expected_ds.datasource_name, actual_ds.datasource_name)
+        self.assertEquals(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
+        self.assertEquals(len(expected_ds.metrics), len(actual_ds.metrics))
+        self.assertEquals(len(expected_ds.columns), len(actual_ds.columns))
+        self.assertEquals(
+            set([c.column_name for c in expected_ds.columns]),
+            set([c.column_name for c in actual_ds.columns]))
+        self.assertEquals(
+            set([m.metric_name for m in expected_ds.metrics]),
+            set([m.metric_name for m in actual_ds.metrics]))
+
+    def test_import_table_no_metadata(self):
+        table, dict_table = self.create_table('pure_table', id=ID_PREFIX + 1)
+        new_table = SqlaTable.import_from_dict(db.session, dict_table)
+        db.session.commit()
+        imported_id = new_table.id
+        imported = self.get_table(imported_id)
+        self.assert_table_equals(table, imported)
+        self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
+
+    def test_import_table_1_col_1_met(self):
+        table, dict_table = self.create_table(
+            'table_1_col_1_met', id=ID_PREFIX + 2,
+            cols_names=["col1"], metric_names=["metric1"])
+        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+        db.session.commit()
+        imported = self.get_table(imported_table.id)
+        self.assert_table_equals(table, imported)
+        self.assertEquals(
+            {DBREF: ID_PREFIX + 2, 'database_name': 'main'},
+            json.loads(imported.params))
+        self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
+
+    def test_import_table_2_col_2_met(self):
+        table, dict_table = self.create_table(
+            'table_2_col_2_met', id=ID_PREFIX + 3, cols_names=['c1', 'c2'],
+            metric_names=['m1', 'm2'])
+        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+        db.session.commit()
+        imported = self.get_table(imported_table.id)
+        self.assert_table_equals(table, imported)
+        self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
+
+    def test_import_table_override_append(self):
+        table, dict_table = self.create_table(
+            'table_override', id=ID_PREFIX + 3,
+            cols_names=['col1'],
+            metric_names=['m1'])
+        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+        db.session.commit()
+        table_over, dict_table_over = self.create_table(
+                'table_override', id=ID_PREFIX + 3,
+                cols_names=['new_col1', 'col2', 'col3'],
+                metric_names=['new_metric1'])
+        imported_over_table = SqlaTable.import_from_dict(
+                db.session,
+                dict_table_over)
+        db.session.commit()
+
+        imported_over = self.get_table(imported_over_table.id)
+        self.assertEquals(imported_table.id, imported_over.id)
+        expected_table, _ = self.create_table(
+            'table_override', id=ID_PREFIX + 3,
+            metric_names=['new_metric1', 'm1'],
+            cols_names=['col1', 'new_col1', 'col2', 'col3'])
+        self.assert_table_equals(expected_table, imported_over)
+        self.yaml_compare(expected_table.export_to_dict(),
+                          imported_over.export_to_dict())
+
+    def test_import_table_override_sync(self):
+        table, dict_table = self.create_table(
+            'table_override', id=ID_PREFIX + 3,
+            cols_names=['col1'],
+            metric_names=['m1'])
+        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+        db.session.commit()
+        table_over, dict_table_over = self.create_table(
+            'table_override', id=ID_PREFIX + 3,
+            cols_names=['new_col1', 'col2', 'col3'],
+            metric_names=['new_metric1'])
+        imported_over_table = SqlaTable.import_from_dict(
+            session=db.session,
+            dict_rep=dict_table_over,
+            sync=['metrics', 'columns'])
+        db.session.commit()
+
+        imported_over = self.get_table(imported_over_table.id)
+        self.assertEquals(imported_table.id, imported_over.id)
+        expected_table, _ = self.create_table(
+            'table_override', id=ID_PREFIX + 3,
+            metric_names=['new_metric1'],
+            cols_names=['new_col1', 'col2', 'col3'])
+        self.assert_table_equals(expected_table, imported_over)
+        self.yaml_compare(expected_table.export_to_dict(),
+          imported_over.export_to_dict())
+
+    def test_import_table_override_identical(self):
+        table, dict_table = self.create_table(
+            'copy_cat', id=ID_PREFIX + 4,
+            cols_names=['new_col1', 'col2', 'col3'],
+            metric_names=['new_metric1'])
+        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+        db.session.commit()
+        copy_table, dict_copy_table = self.create_table(
+            'copy_cat', id=ID_PREFIX + 4,
+            cols_names=['new_col1', 'col2', 'col3'],
+            metric_names=['new_metric1'])
+        imported_copy_table = SqlaTable.import_from_dict(db.session,
+                                                         dict_copy_table)
+        db.session.commit()
+        self.assertEquals(imported_table.id, imported_copy_table.id)
+        self.assert_table_equals(copy_table, self.get_table(imported_table.id))
+        self.yaml_compare(imported_copy_table.export_to_dict(),
+                          imported_table.export_to_dict())
+
+    def test_import_druid_no_metadata(self):
+        datasource, dict_datasource = self.create_druid_datasource(
+            'pure_druid', id=ID_PREFIX + 1)
+        imported_cluster = DruidDatasource.import_from_dict(db.session,
+                                                            dict_datasource)
+        db.session.commit()
+        imported = self.get_datasource(imported_cluster.id)
+        self.assert_datasource_equals(datasource, imported)
+
+    def test_import_druid_1_col_1_met(self):
+        datasource, dict_datasource = self.create_druid_datasource(
+            'druid_1_col_1_met', id=ID_PREFIX + 2,
+            cols_names=["col1"], metric_names=["metric1"])
+        imported_cluster = DruidDatasource.import_from_dict(db.session,
+                                                            dict_datasource)
+        db.session.commit()
+        imported = self.get_datasource(imported_cluster.id)
+        self.assert_datasource_equals(datasource, imported)
+        self.assertEquals(
+            {DBREF: ID_PREFIX + 2, 'database_name': 'druid_test'},
+            json.loads(imported.params))
+
+    def test_import_druid_2_col_2_met(self):
+        datasource, dict_datasource = self.create_druid_datasource(
+            'druid_2_col_2_met', id=ID_PREFIX + 3, cols_names=['c1', 'c2'],
+            metric_names=['m1', 'm2'])
+        imported_cluster = DruidDatasource.import_from_dict(db.session,
+                                                            dict_datasource)
+        db.session.commit()
+        imported = self.get_datasource(imported_cluster.id)
+        self.assert_datasource_equals(datasource, imported)
+
+    def test_import_druid_override_append(self):
+        datasource, dict_datasource = self.create_druid_datasource(
+            'druid_override', id=ID_PREFIX + 3, cols_names=['col1'],
+            metric_names=['m1'])
+        imported_cluster = DruidDatasource.import_from_dict(db.session,
+                                                            dict_datasource)
+        db.session.commit()
+        table_over, table_over_dict = self.create_druid_datasource(
+            'druid_override', id=ID_PREFIX + 3,
+            cols_names=['new_col1', 'col2', 'col3'],
+            metric_names=['new_metric1'])
+        imported_over_cluster = DruidDatasource.import_from_dict(
+                db.session,
+                table_over_dict)
+        db.session.commit()
+        imported_over = self.get_datasource(imported_over_cluster.id)
+        self.assertEquals(imported_cluster.id, imported_over.id)
+        expected_datasource, _ = self.create_druid_datasource(
+            'druid_override', id=ID_PREFIX + 3,
+            metric_names=['new_metric1', 'm1'],
+            cols_names=['col1', 'new_col1', 'col2', 'col3'])
+        self.assert_datasource_equals(expected_datasource, imported_over)
+
+    def test_import_druid_override_sync(self):
+        datasource, dict_datasource = self.create_druid_datasource(
+            'druid_override', id=ID_PREFIX + 3, cols_names=['col1'],
+            metric_names=['m1'])
+        imported_cluster = DruidDatasource.import_from_dict(db.session,
+          dict_datasource)
+        db.session.commit()
+        table_over, table_over_dict = self.create_druid_datasource(
+            'druid_override', id=ID_PREFIX + 3,
+            cols_names=['new_col1', 'col2', 'col3'],
+            metric_names=['new_metric1'])
+        imported_over_cluster = DruidDatasource.import_from_dict(
+            session=db.session,
+            dict_rep=table_over_dict,
+            sync=['metrics', 'columns'])  # syncing metrics and columns
+        db.session.commit()
+        imported_over = self.get_datasource(imported_over_cluster.id)
+        self.assertEquals(imported_cluster.id, imported_over.id)
+        expected_datasource, _ = self.create_druid_datasource(
+            'druid_override', id=ID_PREFIX + 3,
+            metric_names=['new_metric1'],
+            cols_names=['new_col1', 'col2', 'col3'])
+        self.assert_datasource_equals(expected_datasource, imported_over)
+
+    def test_import_druid_override_identical(self):
+        datasource, dict_datasource = self.create_druid_datasource(
+            'copy_cat', id=ID_PREFIX + 4,
+            cols_names=['new_col1', 'col2', 'col3'],
+            metric_names=['new_metric1'])
+        imported = DruidDatasource.import_from_dict(session=db.session,
+                                                    dict_rep=dict_datasource)
+        db.session.commit()
+        copy_datasource, dict_cp_datasource = self.create_druid_datasource(
+            'copy_cat', id=ID_PREFIX + 4,
+            cols_names=['new_col1', 'col2', 'col3'],
+            metric_names=['new_metric1'])
+        imported_copy = DruidDatasource.import_from_dict(db.session,
+                                                         dict_cp_datasource)
+        db.session.commit()
+
+        self.assertEquals(imported.id, imported_copy.id)
+        self.assert_datasource_equals(
+            copy_datasource, self.get_datasource(imported.id))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
index 2a9f069cb1..d51b959297 100644
--- a/tests/import_export_tests.py
+++ b/tests/import_export_tests.py
@@ -441,7 +441,7 @@ def test_import_table_override(self):
             cols_names=['col1', 'new_col1', 'col2', 'col3'])
         self.assert_table_equals(expected_table, imported_over)
 
-    def test_import_table_override_idential(self):
+    def test_import_table_override_identical(self):
         table = self.create_table(
             'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
             metric_names=['new_metric1'])
@@ -505,7 +505,7 @@ def test_import_druid_override(self):
             cols_names=['col1', 'new_col1', 'col2', 'col3'])
         self.assert_datasource_equals(expected_datasource, imported_over)
 
-    def test_import_druid_override_idential(self):
+    def test_import_druid_override_identical(self):
         datasource = self.create_druid_datasource(
             'copy_cat', id=10005, cols_names=['new_col1', 'col2', 'col3'],
             metric_names=['new_metric1'])


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to