Repository: incubator-airflow
Updated Branches:
  refs/heads/master 66168efa1 -> 48135ad25


[AIRFLOW 1149][AIRFLOW-1149] Allow for custom filters in Jinja2 templates

Closes #2258 from
NielsZeilemaker/jinja_custom_filters


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/48135ad2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/48135ad2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/48135ad2

Branch: refs/heads/master
Commit: 48135ad255263d4718bbcace39c746aea5929568
Parents: 66168ef
Author: Niels Zeilemaker <nielszeilema...@godatadriven.com>
Authored: Sat Apr 29 17:14:40 2017 +0200
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Sat Apr 29 17:14:40 2017 +0200

----------------------------------------------------------------------
 airflow/models.py | 19 ++++++++++++++---
 docs/tutorial.rst | 10 +++++++++
 tests/models.py   | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 81 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/48135ad2/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index d2f7894..aab4833 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -2249,11 +2249,13 @@ class BaseOperator(object):
         memo[id(self)] = result
 
         for k, v in list(self.__dict__.items()):
-            if k not in ('user_defined_macros', 'params'):
+            if k not in ('user_defined_macros', 'user_defined_filters', 'params'):
                 setattr(result, k, copy.deepcopy(v, memo))
         result.params = self.params
         if hasattr(self, 'user_defined_macros'):
             result.user_defined_macros = self.user_defined_macros
+        if hasattr(self, 'user_defined_filters'):
+            result.user_defined_filters = self.user_defined_filters
         return result
 
     def render_template_from_field(self, attr, content, context, jinja_env):
@@ -2644,6 +2646,12 @@ class DAG(BaseDag, LoggingMixin):
         templates related to this DAG. Note that you can pass any
         type of object here.
     :type user_defined_macros: dict
+    :param user_defined_filters: a dictionary of filters that will be exposed
+        in your jinja templates. For example, passing
+        ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows
+        you to ``{{ 'world' | hello }}`` in all jinja templates related to
+        this DAG.
+    :type user_defined_filters: dict
     :param default_args: A dictionary of default parameters to be used
         as constructor keyword parameters when initialising operators.
         Note that operators have the same hook, and precede those defined
@@ -2684,6 +2692,7 @@ class DAG(BaseDag, LoggingMixin):
             full_filepath=None,
             template_searchpath=None,
             user_defined_macros=None,
+            user_defined_filters=None,
             default_args=None,
             concurrency=configuration.getint('core', 'dag_concurrency'),
             max_active_runs=configuration.getint(
@@ -2696,6 +2705,7 @@ class DAG(BaseDag, LoggingMixin):
             params=None):
 
         self.user_defined_macros = user_defined_macros
+        self.user_defined_filters = user_defined_filters
         self.default_args = default_args or {}
         self.params = params or {}
 
@@ -3034,7 +3044,7 @@ class DAG(BaseDag, LoggingMixin):
     def get_template_env(self):
         """
         Returns a jinja2 Environment while taking into account the DAGs
-        template_searchpath and user_defined_macros
+        template_searchpath, user_defined_macros and user_defined_filters
         """
         searchpath = [self.folder]
         if self.template_searchpath:
@@ -3046,6 +3056,8 @@ class DAG(BaseDag, LoggingMixin):
             cache_size=0)
         if self.user_defined_macros:
             env.globals.update(self.user_defined_macros)
+        if self.user_defined_filters:
+            env.filters.update(self.user_defined_filters)
 
         return env
 
@@ -3212,10 +3224,11 @@ class DAG(BaseDag, LoggingMixin):
         result = cls.__new__(cls)
         memo[id(self)] = result
         for k, v in list(self.__dict__.items()):
-            if k not in ('user_defined_macros', 'params'):
+            if k not in ('user_defined_macros', 'user_defined_filters', 'params'):
                 setattr(result, k, copy.deepcopy(v, memo))
 
         result.user_defined_macros = self.user_defined_macros
+        result.user_defined_filters = self.user_defined_filters
         result.params = self.params
         return result
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/48135ad2/docs/tutorial.rst
----------------------------------------------------------------------
diff --git a/docs/tutorial.rst b/docs/tutorial.rst
index d047f82..dc09482 100644
--- a/docs/tutorial.rst
+++ b/docs/tutorial.rst
@@ -231,6 +231,16 @@ different languages, and general flexibility in structuring pipelines. It is
 also possible to define your ``template_searchpath`` as pointing to any folder
 locations in the DAG constructor call.
 
+Using that same DAG constructor call, it is possible to define
+``user_defined_macros`` which allow you to specify your own variables.
+For example, passing ``dict(foo='bar')`` to this argument allows you
+to use ``{{ foo }}`` in your templates. Moreover, specifying
+``user_defined_filters`` allows you to register your own filters. For example,
+passing ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows
+you to use ``{{ 'world' | hello }}`` in your templates. For more information
+regarding custom filters have a look at the
+`Jinja Documentation <http://jinja.pocoo.org/docs/dev/api/#writing-filters>`_.
+
 For more information on the variables and macros that can be referenced
 in templates, make sure to read through the :ref:`macros` section
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/48135ad2/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index 49e5c75..4c2a15f 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -233,6 +233,61 @@ class DagTest(unittest.TestCase):
             states=[None, State.QUEUED, State.RUNNING], session=session))
         session.close()
 
+    def test_render_template_field(self):
+        """Tests if render_template from a field works"""
+
+        dag = DAG('test-dag',
+                  start_date=DEFAULT_DATE)
+
+        with dag:
+            task = DummyOperator(task_id='op1')
+
+        result = task.render_template('', '{{ foo }}', dict(foo='bar'))
+        self.assertEqual(result, 'bar')
+
+    def test_render_template_field_macro(self):
+        """ Tests if render_template from a field works,
+            if a custom macro was defined"""
+
+        dag = DAG('test-dag',
+                  start_date=DEFAULT_DATE,
+                  user_defined_macros = dict(foo='bar'))
+
+        with dag:
+            task = DummyOperator(task_id='op1')
+
+        result = task.render_template('', '{{ foo }}', dict())
+        self.assertEqual(result, 'bar')
+
+    def test_user_defined_filters(self):
+        def jinja_udf(name):
+            return 'Hello %s' %name
+
+        dag = models.DAG('test-dag',
+                         start_date=DEFAULT_DATE,
+                         user_defined_filters=dict(hello=jinja_udf))
+        jinja_env = dag.get_template_env()
+
+        self.assertIn('hello', jinja_env.filters)
+        self.assertEqual(jinja_env.filters['hello'], jinja_udf)
+
+    def test_render_template_field_filter(self):
+        """ Tests if render_template from a field works,
+            if a custom filter was defined"""
+
+        def jinja_udf(name):
+            return 'Hello %s' %name
+
+        dag = DAG('test-dag',
+                  start_date=DEFAULT_DATE,
+                  user_defined_filters = dict(hello=jinja_udf))
+
+        with dag:
+            task = DummyOperator(task_id='op1')
+
+        result = task.render_template('', "{{ 'world' | hello}}", dict())
+        self.assertEqual(result, 'Hello world')
+
 
 class DagStatTest(unittest.TestCase):
     def test_dagstats_crud(self):

Reply via email to