Hello everyone,
I am working on a plugin that provides RESTful endpoints to query the Airflow
DB (a modified/extended version of the experimental API).
An excerpt of the plugin code is:
# dagstatus_plugin.py
# Flask blueprint for the plugin; every endpoint below is mounted under /dagstatus.
DagStatusBlueprint = Blueprint('dag_status', __name__, url_prefix='/dagstatus')
# GET /dagstatus/dags — query the metadata DB for DAGs.
# csrf.exempt: read-only endpoint intended to be called programmatically.
@DagStatusBlueprint.route('/dags', methods=['GET'])
@csrf.exempt
def get_dags():
…
# GET /dagstatus/dags/<dag_id>/dagruns — query the DAG runs of one DAG.
@DagStatusBlueprint.route('/dags/<dag_id>/dagruns', methods=['GET'])
@csrf.exempt
def get_dag_dagruns(dag_id):
…
The approach I am using to unit-test the plugin is the following:
I keep a template of airflow.cfg where I parametrized values that I want to
instrument during the tests. For example:
[core]
# The home folder for airflow, default is ~/airflow
airflow_home = {AIRFLOW_HOME}
# The folder where your airflow pipelines live, most likely a
# subfolder in a code repository
# This path must be absolute
dags_folder = {AIRFLOW_HOME}/dags
The main reason to have a custom airflow.cfg file is that I would like
to control the creation of the metadata database and I want to force the
insertion of dags and dag runs into the respective tables so that I can test
the output coming from the endpoints.
Then I use the Flask test_client() method to create a client that allows me to
call the endpoints exposed by my plugin. In this way I can check that the
endpoint (and the query string) returns the expected output.
The test code looks like the following:
import unittest
…
# Directory containing this test file; doubles as AIRFLOW_HOME for the test run.
path_to_this = os.path.dirname(os.path.realpath(__file__))
# SQLite metadata DB created next to the test file.
sqlitedb_path = os.path.normpath(os.path.join(path_to_this, 'test.db'))
# Populated by setUpModule(): Flask test client and SQLAlchemy session.
client = None
session = None
def setUpModule():
    """Prepare an isolated Airflow environment for the whole test module.

    Renders the airflow.cfg template with this directory as AIRFLOW_HOME,
    points the AIRFLOW_HOME env var at it (this must happen *before*
    importing airflow so the instrumented config is picked up), initializes
    the metadata DB, and creates the module-level Flask test ``client`` and
    SQLAlchemy ``session`` used by the tests.

    Raises:
        Exception: re-raises whatever went wrong during setup, after
            cleaning up the generated files via ``clean_test_env()``.
    """
    try:
        # Load the airflow.cfg template file and substitute custom
        # parameter values (just AIRFLOW_HOME in this example).
        template_path = os.path.join(
            path_to_this, 'airflow_test_dag_plugin.template.cfg')
        with open(template_path, 'r') as fr:
            template_config_file = fr.read()
        config_file = template_config_file.format(
            AIRFLOW_HOME=path_to_this
        )
        # Create an airflow.cfg in the current directory.
        # `with` guarantees the handle is closed (and the content flushed)
        # even on error — the original never closed this handle.
        with open(os.path.join(path_to_this, 'airflow.cfg'), 'w') as fw:
            fw.write(config_file)
        # Set AIRFLOW_HOME to the current directory so that when airflow
        # starts it will use the instrumented config file.
        os.environ['AIRFLOW_HOME'] = path_to_this
        # Start airflow — import order matters: AIRFLOW_HOME is already set.
        import airflow
        # Init test db.
        from airflow import settings
        from airflow.utils import db
        from airflow.www import app
        db.initdb()
        global client, session
        # Flask test client to interact with the endpoints.
        # Bind to a fresh name so we don't shadow the imported `app` module.
        flask_app = app.create_app(config=None, testing=True)
        client = flask_app.test_client()
        # Session object to load/clear db tables.
        session = settings.Session()
    except Exception:
        # In case of errors, remove all the generated files, then re-raise:
        # the original swallowed the exception, leaving client/session as
        # None and making every test fail with an obscure AttributeError.
        clean_test_env()
        raise
def tearDownModule():
    """Dispose of the module-level fixtures: remove all generated files."""
    clean_test_env()
class TestDagPlugin(unittest.TestCase):
    """Exercise the /dagstatus endpoints against a seeded metadata DB.

    Relies on the module-level ``client`` (Flask test client) and
    ``session`` (SQLAlchemy session) created by ``setUpModule()``.
    """

    # DAG fixtures, mirroring rows of the `dag` table.
    test_dags = [
        {
            u'dag_id': u'test_dag_1', u'is_paused': False, u'is_subdag': False,
            u'is_active': True,
            u'last_scheduler_run': None, u'last_pickled': None,
            u'last_expired': None, u'scheduler_lock': False,
            u'pickle_id': None, u'fileloc': None, u'owners': None
        },
        # ... further DAG fixtures elided ...
    ]

    # DAG-run fixtures, mirroring rows of the `dag_run` table.
    test_dag_run = [
        {
            'id': 1,
            'dag_id': 'test_dag_1',
            # ... remaining columns elided ...
        },
        # ... further dag-run fixtures elided ...
    ]

    def setUp(self):
        # Load the fixtures into the test db before every test.
        for dag in TestDagPlugin.test_dags:
            session.execute(
                text("INSERT INTO dag VALUES(:dag_id, :is_paused, :is_subdag, "
                     ":is_active, :last_scheduler_run, NULL, :last_expired, "
                     ":scheduler_lock, NULL, :fileloc, :owners)"),
                dag)
        for dag_run in TestDagPlugin.test_dag_run:
            session.execute(
                text("INSERT INTO dag_run VALUES(:id, :dag_id, :execution_date, "
                     ":state, :run_id, :external_trigger, :conf, :end_date, "
                     ":start_date)"),
                dag_run)
        session.commit()

    def test_get_dags(self):
        """
        Invoke the endpoint, get the JSON and check that every returned DAG
        matches the corresponding fixture.
        """
        r = client.get('/dagstatus/dags')
        rjson = json.loads(r.data)
        data = rjson['data']
        sorted_returned_dags = sorted(data, key=lambda d: d['dag_id'])
        for i, dag in enumerate(sorted_returned_dags):
            # assertEqual reports a useful diff on failure,
            # unlike assertTrue(a == b).
            self.assertEqual(dag, TestDagPlugin.test_dags[i])

    def test_get_dags_filter_is_paused(self):
        """
        Invoke the endpoint with an is_paused filter, get the JSON and check
        that exactly the paused fixtures are returned.
        """
        r = client.get('/dagstatus/dags?is_paused=1')
        rjson = json.loads(r.data)
        data = rjson['data']
        # NOTE: the original called filter(iterable, function) — the argument
        # order was reversed; a comprehension avoids the pitfall entirely.
        paused_dags = [dag for dag in TestDagPlugin.test_dags
                       if dag['is_paused'] is True]
        # The original computed paused_dags but never asserted anything;
        # compare the returned rows against the expected paused fixtures.
        sorted_returned = sorted(data, key=lambda d: d['dag_id'])
        expected = sorted(paused_dags, key=lambda d: d['dag_id'])
        self.assertEqual(sorted_returned, expected)

    def tearDown(self):
        # Clear the tables so each test starts from the same fixtures.
        session.execute("DELETE FROM dag;")
        session.execute("DELETE FROM dag_run;")
        # Commit so the deletes are not left pending in the session.
        session.commit()
I know that testing is a hot topic in the Airflow community and I would like
to hear your opinion about this, if you have any ideas to improve it or suggest
better practices.
Thanks,
Emanuele