This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 19514938176 Add RunInference example for TensorFlow Hub pre-trained model (#24529) 19514938176 is described below commit 1951493817638f75bcf03e57a09cc934257cb31a Author: harrisonlimh <97203667+harrisonl...@users.noreply.github.com> AuthorDate: Tue Dec 6 11:23:22 2022 -0800 Add RunInference example for TensorFlow Hub pre-trained model (#24529) * Create run_inference_tensorflow_hub.ipynb Add a notebook for using RunInference() with a TensorFlow Hub model. * Update run_inference_tensorflow_hub.ipynb * Update run_inference_tensorflow_hub.ipynb * Update run_inference_tensorflow_hub.ipynb * Update run_inference_tensorflow_hub.ipynb * Update run_inference_tensorflow_hub.ipynb * Update run_inference_tensorflow_hub.ipynb * Update run_inference_tensorflow_hub.ipynb * Update run_inference_tensorflow_hub.ipynb * Update run_inference_tensorflow_hub.ipynb --- .../beam-ml/run_inference_tensorflow_hub.ipynb | 543 +++++++++++++++++++++ 1 file changed, 543 insertions(+) diff --git a/examples/notebooks/beam-ml/run_inference_tensorflow_hub.ipynb b/examples/notebooks/beam-ml/run_inference_tensorflow_hub.ipynb new file mode 100644 index 00000000000..77007f72efc --- /dev/null +++ b/examples/notebooks/beam-ml/run_inference_tensorflow_hub.ipynb @@ -0,0 +1,543 @@ +{ + "cells": [ + { + "cell_type": "code", + "source": [ + "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", + "\n", + "# Licensed to the Apache Software Foundation (ASF) under one\n", + "# or more contributor license agreements. See the NOTICE file\n", + "# distributed with this work for additional information\n", + "# regarding copyright ownership. The ASF licenses this file\n", + "# to you under the Apache License, Version 2.0 (the\n", + "# \"License\"); you may not use this file except in compliance\n", + "# with the License. 
You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing,\n", + "# software distributed under the License is distributed on an\n", + "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", + "# KIND, either express or implied. See the License for the\n", + "# specific language governing permissions and limitations\n", + "# under the License" + ], + "metadata": { + "id": "Qx4wHX2zIKS1", + "cellView": "form" + }, + "id": "Qx4wHX2zIKS1", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3ac8fc4a-a0ef-47b9-bd80-10801eebe13e" + }, + "source": [ + "# RunInference with Sentence-T5 (ST5) model\n", + "\n", + "This example demonstrates the use of the RunInference transform with the pre-trained [ST5 text encoder model](https://tfhub.dev/google/sentence-t5/st5-base/1) from TensorFlow Hub. The transform runs locally using the [Interactive Runner](https://beam.apache.org/releases/pydoc/2.41.0/apache_beam.runners.interactive.interactive_runner.html)." 
+ ], + "id": "3ac8fc4a-a0ef-47b9-bd80-10801eebe13e" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3402ecc9-28d6-4226-99b1-147a2d23b7a9" + }, + "source": [ + "## Download and install the dependencies\n" + ], + "id": "3402ecc9-28d6-4226-99b1-147a2d23b7a9" + }, + { + "cell_type": "code", + "source": [ + "!pip install apache_beam[gcp,interactive]==2.41.0\n", + "!pip install tensorflow==2.10.0\n", + "!pip install tensorflow_text==2.10.0\n", + "!pip install keras==2.10.0\n", + "!pip install tfx_bsl==1.10.0\n", + "!pip install pillow==8.4.0" + ], + "metadata": { + "id": "H2-orNBqsZ95" + }, + "id": "H2-orNBqsZ95", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "f313a508-59ea-47ed-86eb-c9c8e67785f2", + "scrolled": true + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_hub as hub\n", + "import tensorflow_text\n", + "\n", + "from tensorflow import keras\n", + "\n", + "import apache_beam as beam\n", + "import apache_beam.runners.interactive.interactive_beam as ib\n", + "\n", + "from apache_beam.ml.inference.base import RunInference\n", + "from apache_beam.ml.inference.base import ModelHandler\n", + "from apache_beam.runners.interactive.interactive_runner import InteractiveRunner\n", + "\n", + "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n", + "from tfx_bsl.public.proto import model_spec_pb2" + ], + "id": "f313a508-59ea-47ed-86eb-c9c8e67785f2" + }, + { + "cell_type": "markdown", + "source": [ + "## Authenticate with Google Cloud\n", + "This notebook relies on saving the model to Google Cloud. To use your Google Cloud account, authenticate this notebook." 
+ ], + "metadata": { + "id": "6zAnF4EmomUS" + }, + "id": "6zAnF4EmomUS" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "r1fgn5I_-mzA" + }, + "outputs": [], + "source": [ + "from google.colab import auth\n", + "auth.authenticate_user()" + ], + "id": "r1fgn5I_-mzA" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "74db0203-3d26-4bc4-8271-81fad9756297" + }, + "source": [ + "## Create a Keras model from the TensorFlow Hub model\n", + "\n", + "Replace `GCS_BUCKET` with the name of your Cloud Storage bucket. Your model will be saved in `MODEL_EXPORT_DIR`." + ], + "id": "74db0203-3d26-4bc4-8271-81fad9756297" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2ff8e394-f577-4dea-bef9-a4f4528c1378" + }, + "outputs": [], + "source": [ + "GCS_BUCKET = '<GCS Bucket>'\n", + "\n", + "MODEL_EXPORT_DIR = f'gs://{GCS_BUCKET}/st5-base/1'" + ], + "id": "2ff8e394-f577-4dea-bef9-a4f4528c1378" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ccaede25-1c1a-4ec4-9296-25c9a2ac43d7" + }, + "outputs": [], + "source": [ + "inp = tf.keras.layers.Input(shape=[], dtype=tf.string, name='input')\n", + "hub_url = \"https://tfhub.dev/google/sentence-t5/st5-base/1\"\n", + "imported = hub.KerasLayer(hub_url)\n", + "outp = imported(inp)\n", + "model = tf.keras.Model(inp, outp)" + ], + "id": "ccaede25-1c1a-4ec4-9296-25c9a2ac43d7" + }, + { + "cell_type": "code", + "source": [ + "# The ST5 model returns a 768-dimensional vector for an English text input.\n", + "model.summary()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Xvb-E0D1JHnr", + "outputId": "121a5924-94b0-4b01-97da-14a7223ec61c" + }, + "id": "Xvb-E0D1JHnr", + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model: \"model\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + 
"=================================================================\n", + " input (InputLayer) [(None,)] 0 \n", + " \n", + " keras_layer (KerasLayer) [(None, 768)] 0 \n", + " \n", + "=================================================================\n", + "Total params: 0\n", + "Trainable params: 0\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "29803d5b-93b9-41fc-b414-f7c737c5d7bc" + }, + "source": [ + "## Save the model\n", + "Save the model with a TF function definition for RunInference." + ], + "id": "29803d5b-93b9-41fc-b414-f7c737c5d7bc" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "704abfca-5956-4fc1-9717-4c6d5bf2db8e" + }, + "outputs": [], + "source": [ + "RAW_DATA_PREDICT_SPEC = {\n", + " 'input': tf.io.FixedLenFeature([], tf.string),\n", + "}\n", + "\n", + "@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])\n", + "def call(serialized_examples):\n", + " features = tf.io.parse_example(serialized_examples, RAW_DATA_PREDICT_SPEC)\n", + " return model(features)\n", + "\n", + "tf.saved_model.save(model, MODEL_EXPORT_DIR, signatures={'serving_default': call})" + ], + "id": "704abfca-5956-4fc1-9717-4c6d5bf2db8e" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7b56569d-e540-44ed-a46a-9cec886522f6" + }, + "source": [ + "## Create and test the RunInference pipeline locally\n", + "Use TFX_BSL's [CreateModelHandler](https://www.tensorflow.org/tfx/tfx_bsl/api_docs/python/tfx_bsl/public/beam/run_inference/CreateModelHandler) function for RunInference with TensorFlow models." 
+ ], + "id": "7b56569d-e540-44ed-a46a-9cec886522f6" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fad13b30-b159-425c-8c15-a41301abd3a4" + }, + "outputs": [], + "source": [ + "# Creates a TensorFlow example to feed to the model handler.\n", + "class ExampleProcessor:\n", + " def create_example(self, feature: str):\n", + " return tf.train.Example(\n", + " features=tf.train.Features(\n", + " feature={'input' : self.create_feature(feature)})\n", + " )\n", + "\n", + " def create_feature(self, element: str):\n", + " return tf.train.Feature(bytes_list=tf.train.BytesList(value=[element.encode()], ))\n" + ], + "id": "fad13b30-b159-425c-8c15-a41301abd3a4" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "60380ebd-5bc8-4bc6-9cf4-3030bf687367", + "scrolled": true, + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "1b53641d-806f-4f0b-c8f3-9badf67ab98a" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "\n", + " if (typeof window.interactive_beam_jquery == 'undefined') {\n", + " var jqueryScript = document.createElement('script');\n", + " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", + " jqueryScript.type = 'text/javascript';\n", + " jqueryScript.onload = function() {\n", + " var datatableScript = document.createElement('script');\n", + " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", + " datatableScript.type = 'text/javascript';\n", + " datatableScript.onload = function() {\n", + " window.interactive_beam_jquery = jQuery.noConflict(true);\n", + " window.interactive_beam_jquery(document).ready(function($){\n", + " \n", + " });\n", + " }\n", + " document.head.appendChild(datatableScript);\n", + " };\n", + " document.head.appendChild(jqueryScript);\n", + " } else {\n", + " window.interactive_beam_jquery(document).ready(function($){\n", + 
" \n", + " });\n", + " }" + ] + }, + "metadata": {} + } + ], + "source": [ + "saved_model_spec = model_spec_pb2.SavedModelSpec(model_path=MODEL_EXPORT_DIR)\n", + "inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)\n", + "model_handler = CreateModelHandler(inference_spec_type)\n", + "\n", + "questions = [\n", + " 'what is the official slogan for the 2018 winter olympics?',\n", + "]\n", + "\n", + "pipeline = beam.Pipeline(InteractiveRunner())\n", + "\n", + "inference = (pipeline | 'CreateSentences' >> beam.Create(questions)\n", + " | 'Convert input to Tensor' >> beam.Map(lambda x: ExampleProcessor().create_example(x))\n", + " | 'RunInference with T5' >> RunInference(model_handler))" + ], + "id": "60380ebd-5bc8-4bc6-9cf4-3030bf687367" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "29a84182-baa5-45c4-abcf-d9cab84835c9", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 565 + }, + "outputId": "d6ef64c2-f120-4bc1-c972-78c1cd8a729e" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + "\n", + " <link rel=\"stylesheet\" href=\"https://stackpath.bootstrapcdn.com/bootstrap/4.4.1/css/bootstrap.min.css\" integrity=\"sha384-Vkoo8x4CGsO3+Hhxv8T/Q5PaXtkKtu6ug5TOeNV6gBiFeWPGFN9MuhOf23Q9Ifjh\" crossorigin=\"anonymous\">\n", + " <div id=\"progress_indicator_02aa6852261e6f1837821c6131548b47\">\n", + " <div class=\"spinner-border text-info\" role=\"status\"></div>\n", + " <span class=\"text-info\">Processing... 
show</span>\n", + " </div>\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:tensorflow:From /usr/local/lib/python3.8/dist-packages/tfx_bsl/beam/run_inference.py:615: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.\n", + "2022-12-06 09:30:47.084208: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled\n", + "2022-12-06 09:30:54.471173: I tensorflow/compiler/xla/service/service.cc:173] XLA service 0x4a3e0a00 initialized for platform Host (this does not guarantee that XLA will be used). Devices:\n", + "2022-12-06 09:30:54.471244: I tensorflow/compiler/xla/service/service.cc:181] StreamExecutor device (0): Host, Default Version\n", + "2022-12-06 09:30:54.537285: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", + "2022-12-06 09:31:00.441479: I tensorflow/compiler/jit/xla_compilation_cache.cc:476] Compiled cluster using XLA! 
This line is logged at most once for the lifetime of the process.\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + "\n", + " <style>\n", + " .p-Widget.jp-OutputPrompt.jp-OutputArea-prompt:empty {\n", + " padding: 0;\n", + " border: 0;\n", + " }\n", + " .p-Widget.jp-RenderedJavaScript.jp-mod-trusted.jp-OutputArea-output:empty {\n", + " padding: 0;\n", + " border: 0;\n", + " }\n", + " </style>\n", + " <link rel=\"stylesheet\" href=\"https://cdn.datatables.net/1.10.20/css/jquery.dataTables.min.css\">\n", + " <table id=\"table_df_e626fa8c14641f539b5385fe8e9e2dff\" class=\"display\" style=\"display:block\"></table>\n", + " <script>\n", + " \n", + " if (typeof window.interactive_beam_jquery == 'undefined') {\n", + " var jqueryScript = document.createElement('script');\n", + " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", + " jqueryScript.type = 'text/javascript';\n", + " jqueryScript.onload = function() {\n", + " var datatableScript = document.createElement('script');\n", + " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", + " datatableScript.type = 'text/javascript';\n", + " datatableScript.onload = function() {\n", + " window.interactive_beam_jquery = jQuery.noConflict(true);\n", + " window.interactive_beam_jquery(document).ready(function($){\n", + " \n", + " var dt;\n", + " if ($.fn.dataTable.isDataTable(\"#table_df_e626fa8c14641f539b5385fe8e9e2dff\")) {\n", + " dt = $(\"#table_df_e626fa8c14641f539b5385fe8e9e2dff\").dataTable();\n", + " } else if ($(\"#table_df_e626fa8c14641f539b5385fe8e9e2dff_wrapper\").length == 0) {\n", + " dt = $(\"#table_df_e626fa8c14641f539b5385fe8e9e2dff\").dataTable({\n", + " \n", + " bAutoWidth: false,\n", + " columns: [{'title': ''}, {'title': 'inference.0'}],\n", + " destroy: true,\n", + " responsive: true,\n", + " columnDefs: [\n", + " {\n", + " targets: \"_all\",\n", + " 
className: \"dt-left\"\n", + " },\n", + " {\n", + " \"targets\": 0,\n", + " \"width\": \"10px\",\n", + " \"title\": \"\"\n", + " }\n", + " ]\n", + " });\n", + " } else {\n", + " return;\n", + " }\n", + " dt.api()\n", + " .clear()\n", + " .rows.add([{1: 'predict_log {\\n request {\\n model_spec {\\n signature_name: \"serving_default\"\\n }\\n inputs {\\n key: \"serialized_examples\"\\n value {\\n dtype: DT_STRING\\n tensor_shape {\\n dim {\\n size: 1\\n }\\n }\\n string_val: \"\\\\nH\\\\nF\\\\n\\\\005input\\\\022=\\\\n;\\\\n9what is the official slogan for the 2018 winter olympics?\"\\n }\\n }\\n }\\n r [...] + " .draw('full-hold');\n", + " });\n", + " }\n", + " document.head.appendChild(datatableScript);\n", + " };\n", + " document.head.appendChild(jqueryScript);\n", + " } else {\n", + " window.interactive_beam_jquery(document).ready(function($){\n", + " \n", + " var dt;\n", + " if ($.fn.dataTable.isDataTable(\"#table_df_e626fa8c14641f539b5385fe8e9e2dff\")) {\n", + " dt = $(\"#table_df_e626fa8c14641f539b5385fe8e9e2dff\").dataTable();\n", + " } else if ($(\"#table_df_e626fa8c14641f539b5385fe8e9e2dff_wrapper\").length == 0) {\n", + " dt = $(\"#table_df_e626fa8c14641f539b5385fe8e9e2dff\").dataTable({\n", + " \n", + " bAutoWidth: false,\n", + " columns: [{'title': ''}, {'title': 'inference.0'}],\n", + " destroy: true,\n", + " responsive: true,\n", + " columnDefs: [\n", + " {\n", + " targets: \"_all\",\n", + " className: \"dt-left\"\n", + " },\n", + " {\n", + " \"targets\": 0,\n", + " \"width\": \"10px\",\n", + " \"title\": \"\"\n", + " }\n", + " ]\n", + " });\n", + " } else {\n", + " return;\n", + " }\n", + " dt.api()\n", + " .clear()\n", + " .rows.add([{1: 'predict_log {\\n request {\\n model_spec {\\n signature_name: \"serving_default\"\\n }\\n inputs {\\n key: \"serialized_examples\"\\n value {\\n dtype: DT_STRING\\n tensor_shape {\\n dim {\\n size: 1\\n }\\n }\\n string_val: \"\\\\nH\\\\nF\\\\n\\\\005input\\\\022=\\\\n;\\\\n9what is the official slogan for the 2018 
winter olympics?\"\\n }\\n }\\n }\\n r [...] + " .draw('full-hold');\n", + " });\n", + " }\n", + " </script>" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "\n", + " if (typeof window.interactive_beam_jquery == 'undefined') {\n", + " var jqueryScript = document.createElement('script');\n", + " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", + " jqueryScript.type = 'text/javascript';\n", + " jqueryScript.onload = function() {\n", + " var datatableScript = document.createElement('script');\n", + " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", + " datatableScript.type = 'text/javascript';\n", + " datatableScript.onload = function() {\n", + " window.interactive_beam_jquery = jQuery.noConflict(true);\n", + " window.interactive_beam_jquery(document).ready(function($){\n", + " \n", + " $(\"#progress_indicator_02aa6852261e6f1837821c6131548b47\").remove();\n", + " });\n", + " }\n", + " document.head.appendChild(datatableScript);\n", + " };\n", + " document.head.appendChild(jqueryScript);\n", + " } else {\n", + " window.interactive_beam_jquery(document).ready(function($){\n", + " \n", + " $(\"#progress_indicator_02aa6852261e6f1837821c6131548b47\").remove();\n", + " });\n", + " }" + ] + }, + "metadata": {} + } + ], + "source": [ + "ib.show(inference)" + ], + "id": "29a84182-baa5-45c4-abcf-d9cab84835c9" + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file