{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-04-20T11:17:18.633969Z", "iopub.status.busy": "2024-04-20T11:17:18.633731Z", "iopub.status.idle": "2024-04-20T11:17:18.637470Z", "shell.execute_reply": "2024-04-20T11:17:18.636923Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "8yo62ffS5TF5" }, "source": [ "# Using text and neural network features\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", "
\n" ] }, { "cell_type": "markdown", "metadata": { "id": "zrCwCCxhiAL7" }, "source": [ "Welcome to the **Intermediate Colab** for **TensorFlow Decision Forests** (**TF-DF**).\n", "In this colab, you will learn about some more advanced capabilities of **TF-DF**, including how to deal with natural language features.\n", "\n", "This colab assumes you are familiar with the concepts presented in the [Beginner colab](beginner_colab.ipynb), notably the installation of TF-DF.\n", "\n", "In this colab, you will:\n", "\n", "1. Train a Random Forest that consumes text features natively as categorical sets.\n", "\n", "1. Train a Random Forest that consumes text features using a [TensorFlow Hub](https://www.tensorflow.org/hub) module. In this setting (transfer learning), the module is already pre-trained on a large text corpus.\n", "\n", "1. Train a Gradient Boosted Decision Trees (GBDT) model and a Neural Network together. The GBDT will consume the output of the Neural Network." ] }, { "cell_type": "markdown", "metadata": { "id": "Rzskapxq7gdo" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:17:18.640607Z", "iopub.status.busy": "2024-04-20T11:17:18.640390Z", "iopub.status.idle": "2024-04-20T11:17:21.458610Z", "shell.execute_reply": "2024-04-20T11:17:21.457811Z" }, "id": "mZiInVYfffAb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow_decision_forests\r\n", " Using cached tensorflow_decision_forests-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.0 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.26.4)\r\n", "Requirement already satisfied: pandas in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.2.2)\r\n", "Requirement already 
satisfied: tensorflow~=2.16.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.16.1)\r\n", "Requirement already satisfied: six in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.16.0)\r\n", "Requirement already satisfied: absl-py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.4.0)\r\n", "Requirement already satisfied: wheel in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (0.41.2)\r\n", "Collecting wurlitzer (from tensorflow_decision_forests)\r\n", " Using cached wurlitzer-3.0.3-py3-none-any.whl.metadata (1.9 kB)\r\n", "Requirement already satisfied: tf-keras~=2.16 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.16.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=23.5.26 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (24.3.25)\r\n", "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (0.5.4)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (0.2.0)\r\n", "Requirement already satisfied: h5py>=3.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (3.11.0)\r\n", "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (18.1.1)\r\n", "Requirement already satisfied: 
ml-dtypes~=0.3.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (0.3.2)\r\n", "Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (3.3.0)\r\n", "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (24.0)\r\n", "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (3.20.3)\r\n", "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (2.31.0)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (69.5.1)\r\n", "Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (2.4.0)\r\n", "Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (4.11.0)\r\n", "Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (1.16.0)\r\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (1.63.0rc2)\r\n", "Requirement already satisfied: tensorboard<2.17,>=2.16 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (2.16.2)\r\n", "Requirement already satisfied: keras>=3.0.0 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (3.2.1)\r\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (0.36.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: python-dateutil>=2.8.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2.9.0.post0)\r\n", "Requirement already satisfied: pytz>=2020.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2024.1)\r\n", "Requirement already satisfied: tzdata>=2022.7 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2024.1)\r\n", "Requirement already satisfied: rich in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (13.7.1)\r\n", "Requirement already satisfied: namex in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (0.0.8)\r\n", "Requirement already satisfied: optree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (0.11.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow~=2.16.1->tensorflow_decision_forests) (3.3.2)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow~=2.16.1->tensorflow_decision_forests) (3.7)\r\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
requests<3,>=2.21.0->tensorflow~=2.16.1->tensorflow_decision_forests) (2.2.1)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow~=2.16.1->tensorflow_decision_forests) (2024.2.2)\r\n", "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (3.6)\r\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (0.7.2)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (3.0.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (7.1.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (2.1.5)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: markdown-it-py>=2.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (3.0.0)\r\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (2.17.2)\r\n", "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages 
(from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (3.18.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: mdurl~=0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown-it-py>=2.2.0->rich->keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (0.1.2)\r\n", "Using cached tensorflow_decision_forests-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.5 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using cached wurlitzer-3.0.3-py3-none-any.whl (7.3 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: wurlitzer, tensorflow_decision_forests\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed tensorflow_decision_forests-1.9.0 wurlitzer-3.0.3\r\n" ] } ], "source": [ "# Install TensorFlow Decision Forests\n", "!pip install tensorflow_decision_forests\n" ] }, { "cell_type": "markdown", "metadata": { "id": "2EFndCFdoJM5" }, "source": [ "[Wurlitzer](https://pypi.org/project/wurlitzer/) is needed to display the detailed training logs in Colabs (when using `verbose=2` in the model constructor)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:17:21.463115Z", "iopub.status.busy": "2024-04-20T11:17:21.462778Z", "iopub.status.idle": "2024-04-20T11:17:23.415566Z", "shell.execute_reply": "2024-04-20T11:17:23.414736Z" }, "id": "L06XWRdSoLj5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: wurlitzer in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (3.0.3)\r\n" ] } ], "source": [ "!pip install wurlitzer" ] }, { "cell_type": "markdown", "metadata": { "id": "i7PlfbnxYcPf" }, "source": [ "Import the necessary libraries." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:17:23.419981Z", "iopub.status.busy": "2024-04-20T11:17:23.419703Z", "iopub.status.idle": "2024-04-20T11:17:25.835024Z", "shell.execute_reply": "2024-04-20T11:17:25.834311Z" }, "id": "RsCV2oAS7gC_" }, "outputs": [], "source": [ "import os\n", "# Keep using Keras 2\n", "os.environ['TF_USE_LEGACY_KERAS'] = '1'\n", "\n", "import tensorflow_decision_forests as tfdf\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import tensorflow as tf\n", "import tf_keras\n", "import math" ] }, { "cell_type": "markdown", "metadata": { "id": "w2fsI0y5x5i5" }, "source": [ "The hidden code cell limits the output height in colab." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-04-20T11:17:25.839123Z", "iopub.status.busy": "2024-04-20T11:17:25.838781Z", "iopub.status.idle": "2024-04-20T11:17:25.843204Z", "shell.execute_reply": "2024-04-20T11:17:25.842432Z" }, "id": "jZXB4o6Tlu0i" }, "outputs": [], "source": [ "#@title\n", "\n", "from IPython.core.magic import register_line_magic\n", "from IPython.display import Javascript\n", "from IPython.display import display as ipy_display\n", "\n", "# Some of the model training logs can cover the full\n", "# screen if not compressed to a smaller viewport.\n", "# This magic allows setting a max height for a cell.\n", "@register_line_magic\n", "def set_cell_height(size):\n", " ipy_display(\n", " Javascript(\"google.colab.output.setIframeHeight(0, true, {maxHeight: \" +\n", " str(size) + \"})\"))" ] }, { "cell_type": "markdown", "metadata": { "id": "M_D4Ft4o65XT" }, "source": [ "## Use raw text as features\n", "\n", "TF-DF can consume [categorical-set](https://arxiv.org/pdf/2009.09991.pdf) features natively. 
Categorical-sets represent text features as bags of words (or n-grams).\n", "\n", "For example: `\"The little blue dog\"` → `{\"the\", \"little\", \"blue\", \"dog\"}`\n", "\n", "In this example, you'll train a Random Forest on the [Stanford Sentiment Treebank](https://nlp.stanford.edu/sentiment/index.html) (SST) dataset. The objective of this dataset is to classify sentences as carrying a *positive* or *negative* sentiment. You'll use the binary classification version of the dataset curated in [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/glue#gluesst2).\n", "\n", "**Note:** Categorical-set features can be expensive to train. In this colab, we will train a small Random Forest with 30 trees." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:17:25.846511Z", "iopub.status.busy": "2024-04-20T11:17:25.846118Z", "iopub.status.idle": "2024-04-20T11:17:27.951356Z", "shell.execute_reply": "2024-04-20T11:17:27.950173Z" }, "id": "SgEiFy23j14S" }, "outputs": [], "source": [ "# Install the TensorFlow Datasets package\n", "!pip install tensorflow-datasets -U --quiet" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:17:27.955701Z", "iopub.status.busy": "2024-04-20T11:17:27.955431Z", "iopub.status.idle": "2024-04-20T11:17:31.953387Z", "shell.execute_reply": "2024-04-20T11:17:31.952581Z" }, "id": "uVN-j0E4Q1T3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'idx': 163, 'label': -1, 'sentence': b'not even the hanson brothers can save it'}\n", "{'idx': 131, 'label': -1, 'sentence': b'strong setup and ambitious goals fade as the film descends into unsophisticated scare tactics and b-film thuggery .'}\n", "{'idx': 1579, 'label': -1, 'sentence': b'too timid to bring a sense of closure to an ugly chapter of the twentieth century .'}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ 
"2024-04-20 11:17:31.940774: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n", "2024-04-20 11:17:31.946131: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] } ], "source": [ "# Load the dataset\n", "import tensorflow_datasets as tfds\n", "all_ds = tfds.load(\"glue/sst2\")\n", "\n", "# Display the first 3 examples of the test fold.\n", "for example in all_ds[\"test\"].take(3):\n", " print({attr_name: attr_tensor.numpy() for attr_name, attr_tensor in example.items()})" ] }, { "cell_type": "markdown", "metadata": { "id": "UHiQUWE2XDYN" }, "source": [ "The dataset is modified as follows:\n", "\n", "1. The raw labels are integers in `{-1, 1}`, but the learning algorithm expects non-negative integer labels, e.g. `{0, 1}`. Therefore, the labels are transformed as follows: `new_labels = (original_labels + 1) // 2`.\n", "1. A batch size of 100 is applied to make reading the dataset more efficient.\n", "1. The `sentence` attribute needs to be tokenized, i.e. `\"hello world\" -> [\"hello\", \"world\"]`.\n", "\n", "\n", "**Note:** This example doesn't use the `test` split of the dataset as it does not have labels. If the `test` split had labels, you could concatenate the `validation` fold into the `train` one (e.g. `all_ds[\"train\"].concatenate(all_ds[\"validation\"])`).\n", "\n", "**Details:** Some decision forest learning algorithms do not need a validation dataset (e.g. Random Forests) while others do (e.g. Gradient Boosted Trees in some cases). 
Since each learning algorithm under TF-DF can use validation data differently, TF-DF handles train/validation splits internally. As a result, when you have both a training and a validation set, they can always be concatenated as input to the learning algorithm." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:17:31.957721Z", "iopub.status.busy": "2024-04-20T11:17:31.957055Z", "iopub.status.idle": "2024-04-20T11:17:32.052722Z", "shell.execute_reply": "2024-04-20T11:17:32.052058Z" }, "id": "yqYDKTKdSPYw" }, "outputs": [], "source": [ "def prepare_dataset(example):\n", " label = (example[\"label\"] + 1) // 2\n", " return {\"sentence\" : tf.strings.split(example[\"sentence\"])}, label\n", "\n", "train_ds = all_ds[\"train\"].batch(100).map(prepare_dataset)\n", "test_ds = all_ds[\"validation\"].batch(100).map(prepare_dataset)" ] }, { "cell_type": "markdown", "metadata": { "id": "YYkIjROI9w43" }, "source": [ "Finally, train and evaluate the model as usual. TF-DF automatically detects multi-valued categorical features as categorical-set.\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:17:32.056273Z", "iopub.status.busy": "2024-04-20T11:17:32.056042Z", "iopub.status.idle": "2024-04-20T11:18:28.384322Z", "shell.execute_reply": "2024-04-20T11:18:28.383620Z" }, "id": "mpxTtYo39wYZ" }, "outputs": [ { "data": { "application/javascript": [ "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. 
Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Use /tmpfs/tmp/tmpx28adgyq as temporary training directory\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Reading training dataset...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training tensor examples:\n", "Features: {'sentence': tf.RaggedTensor(values=Tensor(\"data:0\", shape=(None,), dtype=string), row_splits=Tensor(\"data_1:0\", shape=(None,), dtype=int64))}\n", "Label: Tensor(\"data_2:0\", shape=(None,), dtype=int64)\n", "Weights: None\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Normalized tensor features:\n", " {'sentence': SemanticTensor(semantic=, tensor=tf.RaggedTensor(values=Tensor(\"data:0\", shape=(None,), dtype=string), row_splits=Tensor(\"data_1:0\", shape=(None,), dtype=int64)))}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training dataset read in 0:00:04.709443. Found 67349 examples.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Standard output detected as not visible to the user e.g. running in a notebook. Creating a training log redirection. 
If training gets stuck, try calling tfdf.keras.set_training_logs_redirection(False).\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:17:36.8175 UTC kernel.cc:771] Start Yggdrasil model training\n", "[INFO 24-04-20 11:17:36.8176 UTC kernel.cc:772] Collect training examples\n", "[INFO 24-04-20 11:17:36.8176 UTC kernel.cc:785] Dataspec guide:\n", "column_guides {\n", " column_name_pattern: \"^__LABEL$\"\n", " type: CATEGORICAL\n", " categorial {\n", " min_vocab_frequency: 0\n", " max_vocab_count: -1\n", " }\n", "}\n", "default_column_guide {\n", " categorial {\n", " max_vocab_count: 2000\n", " }\n", " discretized_numerical {\n", " maximum_num_bins: 255\n", " }\n", "}\n", "ignore_columns_without_guides: false\n", "detect_numerical_as_discretized_numerical: false\n", "\n", "[INFO 24-04-20 11:17:36.8179 UTC kernel.cc:391] Number of batches: 674\n", "[INFO 24-04-20 11:17:36.8180 UTC kernel.cc:392] Number of examples: 67349\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:17:36.8602 UTC data_spec_inference.cc:305] 12816 item(s) have been pruned (i.e. they are considered out of dictionary) for the column sentence (2000 item(s) left) because min_value_count=5 and max_number_of_unique_values=2000\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:17:36.9136 UTC kernel.cc:792] Training dataset:\n", "Number of records: 67349\n", "Number of columns: 2\n", "\n", "Number of columns by type:\n", "\tCATEGORICAL_SET: 1 (50%)\n", "\tCATEGORICAL: 1 (50%)\n", "\n", "Columns:\n", "\n", "CATEGORICAL_SET: 1 (50%)\n", "\t1: \"sentence\" CATEGORICAL_SET has-dict vocab-size:2001 num-oods:10187 (15.1257%) most-frequent:\"the\" 27205 (40.3941%)\n", "\n", "CATEGORICAL: 1 (50%)\n", "\t0: \"__LABEL\" CATEGORICAL integerized vocab-size:3 no-ood-item\n", "\n", "Terminology:\n", "\tnas: Number of non-available (i.e. 
missing) values.\n", "\tood: Out of dictionary.\n", "\tmanually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.\n", "\ttokenized: The attribute value is obtained through tokenization.\n", "\thas-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.\n", "\tvocab-size: Number of unique values.\n", "\n", "[INFO 24-04-20 11:17:36.9137 UTC kernel.cc:808] Configure learner\n", "[INFO 24-04-20 11:17:36.9139 UTC kernel.cc:822] Training config:\n", "learner: \"RANDOM_FOREST\"\n", "features: \"^sentence$\"\n", "label: \"^__LABEL$\"\n", "task: CLASSIFICATION\n", "random_seed: 123456\n", "metadata {\n", " framework: \"TF Keras\"\n", "}\n", "pure_serving_model: false\n", "[yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] {\n", " num_trees: 30\n", " decision_tree {\n", " max_depth: 16\n", " min_examples: 5\n", " in_split_min_examples_check: true\n", " keep_non_leaf_label_distribution: true\n", " num_candidate_attributes: 0\n", " missing_value_policy: GLOBAL_IMPUTATION\n", " allow_na_conditions: false\n", " categorical_set_greedy_forward {\n", " sampling: 0.1\n", " max_num_items: -1\n", " min_item_frequency: 1\n", " }\n", " growing_strategy_local {\n", " }\n", " categorical {\n", " cart {\n", " }\n", " }\n", " axis_aligned_split {\n", " }\n", " internal {\n", " sorting_strategy: PRESORTED\n", " }\n", " uplift {\n", " min_examples_in_treatment: 5\n", " split_score: KULLBACK_LEIBLER\n", " }\n", " }\n", " winner_take_all_inference: true\n", " compute_oob_performances: true\n", " compute_oob_variable_importances: false\n", " num_oob_variable_importances_permutations: 1\n", " bootstrap_training_dataset: true\n", " bootstrap_size_ratio: 1\n", " adapt_bootstrap_size_ratio_for_maximum_training_duration: false\n", " sampling_with_replacement: true\n", "}\n", "\n", "[INFO 24-04-20 11:17:36.9143 UTC kernel.cc:825] Deployment config:\n", "cache_path: 
\"/tmpfs/tmp/tmpx28adgyq/working_cache\"\n", "num_threads: 32\n", "try_resume_training: true\n", "\n", "[INFO 24-04-20 11:17:36.9145 UTC kernel.cc:887] Train model\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:17:36.9152 UTC random_forest.cc:416] Training random forest on 67349 example(s) and 1 feature(s).\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:18:08.8767 UTC random_forest.cc:802] Training of tree 1/30 (tree index:1) done accuracy:0.7412 logloss:9.32811\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:18:19.2026 UTC random_forest.cc:802] Training of tree 6/30 (tree index:27) done accuracy:0.775555 logloss:4.88012\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:18:21.1285 UTC random_forest.cc:802] Training of tree 16/30 (tree index:25) done accuracy:0.808699 logloss:1.679\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:18:22.6848 UTC random_forest.cc:802] Training of tree 26/30 (tree index:8) done accuracy:0.818557 logloss:0.904858\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:18:24.0005 UTC random_forest.cc:802] Training of tree 30/30 (tree index:6) done accuracy:0.821274 logloss:0.854486\n", "[INFO 24-04-20 11:18:24.0013 UTC random_forest.cc:882] Final OOB metrics: accuracy:0.821274 logloss:0.854486\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:18:24.0104 UTC kernel.cc:919] Export model in log directory: /tmpfs/tmp/tmpx28adgyq with prefix da59c2f23fdb4012\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:18:24.0388 UTC kernel.cc:937] Save model in resources\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:18:24.0420 UTC abstract_model.cc:881] Model self evaluation:\n", "Number of predictions (without weights): 67349\n", "Number of predictions (with 
weights): 67349\n", "Task: CLASSIFICATION\n", "Label: __LABEL\n", "\n", "Accuracy: 0.821274 CI95[W][0.818828 0.8237]\n", "LogLoss: : 0.854486\n", "ErrorRate: : 0.178726\n", "\n", "Default Accuracy: : 0.557826\n", "Default LogLoss: : 0.686445\n", "Default ErrorRate: : 0.442174\n", "\n", "Confusion Table:\n", "truth\\prediction\n", " 1 2\n", "1 19593 10187\n", "2 1850 35719\n", "Total: 67349\n", "\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:18:24.0658 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmpx28adgyq/model/ with prefix da59c2f23fdb4012\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:18:24.3112 UTC decision_forest.cc:734] Model loaded with 30 root(s), 43180 node(s), and 1 input feature(s).\n", "[INFO 24-04-20 11:18:24.3113 UTC abstract_model.cc:1344] Engine \"RandomForestGeneric\" built\n", "[INFO 24-04-20 11:18:24.3113 UTC kernel.cc:1061] Use fast generic engine\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 0:00:47.515581\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Compiling model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model compiled.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%set_cell_height 300\n", "\n", "# Specify the model.\n", "model_1 = tfdf.keras.RandomForestModel(num_trees=30, verbose=2)\n", "\n", "# Train the model.\n", "model_1.fit(x=train_ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "D9FMFGzwiHCt" }, "source": [ "In the previous logs, note that `sentence` is a `CATEGORICAL_SET` feature.\n", "\n", "The model is evaluated as usual:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:18:28.388062Z", "iopub.status.busy": "2024-04-20T11:18:28.387416Z", "iopub.status.idle": "2024-04-20T11:18:32.437207Z", 
"shell.execute_reply": "2024-04-20T11:18:32.436535Z" }, "id": "cpf-wHl094S1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/9 [==>...........................] - ETA: 31s - loss: 0.0000e+00 - accuracy: 0.8100" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "9/9 [==============================] - 4s 4ms/step - loss: 0.0000e+00 - accuracy: 0.7638\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "BinaryCrossentropyloss: 0.0\n", "Accuracy: 0.7637614607810974\n" ] } ], "source": [ "model_1.compile(metrics=[\"accuracy\"])\n", "evaluation = model_1.evaluate(test_ds)\n", "\n", "print(f\"BinaryCrossentropyloss: {evaluation[0]}\")\n", "print(f\"Accuracy: {evaluation[1]}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "YliBX4GtjncQ" }, "source": [ "The training logs look as follows:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:18:32.440877Z", "iopub.status.busy": "2024-04-20T11:18:32.440603Z", "iopub.status.idle": "2024-04-20T11:18:33.080965Z", "shell.execute_reply": "2024-04-20T11:18:33.080264Z" }, "id": "OnTTtBNmjpo7" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "logs = model_1.make_inspector().training_logs()\n", "plt.plot([log.num_trees for log in logs], [log.evaluation.accuracy for log in logs])\n", "plt.xlabel(\"Number of trees\")\n", "plt.ylabel(\"Out-of-bag accuracy\")\n", "pass" ] }, { "cell_type": "markdown", "metadata": { "id": "d4qJ0ig3kgic" }, "source": [ "More trees would probably be beneficial (I am sure of it because I tried :p)." ] }, { "cell_type": "markdown", "metadata": { "id": "Iil_oyOhCNx6" }, "source": [ "## Use a pretrained text embedding\n", "\n", "The previous example trained a Random Forest using raw text features. This example will use a pre-trained TF-Hub embedding to convert text features into a dense embedding, and then train a Random Forest on top of it. In this situation, the Random Forest will only \"see\" the numerical output of the embedding (i.e. it will not see the raw text). \n", "\n", "In this experiment, you will use the [Universal-Sentence-Encoder](https://tfhub.dev/google/universal-sentence-encoder/4). Different pre-trained embeddings might be suited for different types of text (e.g. different language, different task) but also for other types of structured features (e.g. images).\n", "\n", "**Note:** This embedding is large (1GB) and therefore the final model will be slow to run (compared to classical decision tree inference).\n", "\n", "The embedding module can be applied in one of two places:\n", "\n", "1. During the dataset preparation.\n", "2. 
In the pre-processing stage of the model.\n", "\n", "The second option is often preferable: Packaging the embedding in the model makes the model easier to use (and harder to misuse).\n", "\n", "First install TF-Hub:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:18:33.084882Z", "iopub.status.busy": "2024-04-20T11:18:33.084423Z", "iopub.status.idle": "2024-04-20T11:18:35.276538Z", "shell.execute_reply": "2024-04-20T11:18:35.275411Z" }, "id": "QfYGXim_DskC" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tensorflow-hub in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (0.16.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: numpy>=1.12.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-hub) (1.26.4)\r\n", "Requirement already satisfied: protobuf>=3.19.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-hub) (3.20.3)\r\n", "Requirement already satisfied: tf-keras>=2.14.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-hub) (2.16.0)\r\n", "Requirement already satisfied: tensorflow<2.17,>=2.16 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-keras>=2.14.1->tensorflow-hub) (2.16.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: absl-py>=1.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (1.4.0)\r\n", "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=23.5.26 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (24.3.25)\r\n", "Requirement already satisfied: 
gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (0.5.4)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (0.2.0)\r\n", "Requirement already satisfied: h5py>=3.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (3.11.0)\r\n", "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (18.1.1)\r\n", "Requirement already satisfied: ml-dtypes~=0.3.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (0.3.2)\r\n", "Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (3.3.0)\r\n", "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (24.0)\r\n", "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (2.31.0)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (69.5.1)\r\n", "Requirement already satisfied: six>=1.12.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (1.16.0)\r\n", "Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (2.4.0)\r\n", "Requirement already satisfied: 
typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (4.11.0)\r\n", "Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (1.16.0)\r\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (1.63.0rc2)\r\n", "Requirement already satisfied: tensorboard<2.17,>=2.16 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (2.16.2)\r\n", "Requirement already satisfied: keras>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (3.2.1)\r\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (0.36.0)\r\n", "Requirement already satisfied: wheel<1.0,>=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from astunparse>=1.6.0->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (0.41.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: rich in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (13.7.1)\r\n", "Requirement already satisfied: namex in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (0.0.8)\r\n", "Requirement already satisfied: optree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (0.11.0)\r\n", "Requirement already satisfied: charset-normalizer<4,>=2 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (3.3.2)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (3.7)\r\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (2.2.1)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (2024.2.2)\r\n", "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (3.6)\r\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (0.7.2)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (3.0.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (7.1.0)\r\n", "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (2.1.5)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: 
markdown-it-py>=2.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (3.0.0)\r\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (2.17.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (3.18.1)\r\n", "Requirement already satisfied: mdurl~=0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown-it-py>=2.2.0->rich->keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras>=2.14.1->tensorflow-hub) (0.1.2)\r\n" ] } ], "source": [ "!pip install --upgrade tensorflow-hub" ] }, { "cell_type": "markdown", "metadata": { "id": "kNSEhJgjEXww" }, "source": [ "Unlike before, you don't need to tokenize the text." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:18:35.281111Z", "iopub.status.busy": "2024-04-20T11:18:35.280804Z", "iopub.status.idle": "2024-04-20T11:18:35.342789Z", "shell.execute_reply": "2024-04-20T11:18:35.342218Z" }, "id": "pS5SYqoScbOc" }, "outputs": [], "source": [ "def prepare_dataset(example):\n", " label = (example[\"label\"] + 1) // 2\n", " return {\"sentence\" : example[\"sentence\"]}, label\n", "\n", "train_ds = all_ds[\"train\"].batch(100).map(prepare_dataset)\n", "test_ds = all_ds[\"validation\"].batch(100).map(prepare_dataset)\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:18:35.346035Z", "iopub.status.busy": "2024-04-20T11:18:35.345759Z", "iopub.status.idle": "2024-04-20T11:19:30.805044Z", "shell.execute_reply": "2024-04-20T11:19:30.804385Z" }, "id": "zHEsd8q_ESpC" }, "outputs": [ { "data": { "application/javascript": [ "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Use /tmpfs/tmp/tmpv_kuhy0b as temporary training directory\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Reading training dataset...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training dataset read in 0:00:24.071412. 
Found 67349 examples.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:19:25.9064 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmpv_kuhy0b/model/ with prefix 36a4a9d3f10743e4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 0:00:13.926431\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Compiling model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:19:27.6042 UTC decision_forest.cc:734] Model loaded with 100 root(s), 565608 node(s), and 512 input feature(s).\n", "[INFO 24-04-20 11:19:27.6044 UTC abstract_model.cc:1344] Engine \"RandomForestOptPred\" built\n", "[INFO 24-04-20 11:19:27.6045 UTC kernel.cc:1061] Use fast generic engine\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model compiled.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%set_cell_height 300\n", "\n", "import tensorflow_hub as hub\n", "# NNLM (https://tfhub.dev/google/nnlm-en-dim128/2) is also a good choice.\n", "hub_url = \"https://tfhub.dev/google/universal-sentence-encoder/4\"\n", "embedding = hub.KerasLayer(hub_url)\n", "\n", "sentence = tf_keras.layers.Input(shape=(), name=\"sentence\", dtype=tf.string)\n", "embedded_sentence = embedding(sentence)\n", "\n", "raw_inputs = {\"sentence\": sentence}\n", "processed_inputs = {\"embedded_sentence\": embedded_sentence}\n", "preprocessor = tf_keras.Model(inputs=raw_inputs, outputs=processed_inputs)\n", "\n", "model_2 = tfdf.keras.RandomForestModel(\n", " preprocessing=preprocessor,\n", " num_trees=100)\n", "\n", "model_2.fit(x=train_ds)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:30.808565Z", "iopub.status.busy": "2024-04-20T11:19:30.808272Z", "iopub.status.idle": 
"2024-04-20T11:19:32.752540Z", "shell.execute_reply": "2024-04-20T11:19:32.751811Z" }, "id": "xPLoDqiFKY18" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/9 [==>...........................] - ETA: 14s - loss: 0.0000e+00 - accuracy: 0.8000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "4/9 [============>.................] - ETA: 0s - loss: 0.0000e+00 - accuracy: 0.8100 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "7/9 [======================>.......] - ETA: 0s - loss: 0.0000e+00 - accuracy: 0.7914" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "9/9 [==============================] - 2s 18ms/step - loss: 0.0000e+00 - accuracy: 0.7878\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "BinaryCrossentropyloss: 0.0\n", "Accuracy: 0.7878440618515015\n" ] } ], "source": [ "model_2.compile(metrics=[\"accuracy\"])\n", "evaluation = model_2.evaluate(test_ds)\n", "\n", "print(f\"BinaryCrossentropyloss: {evaluation[0]}\")\n", "print(f\"Accuracy: {evaluation[1]}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "WPsD3LyaMLHm" }, "source": [ "Note that categorical sets represent text differently from a dense embedding, so it may be useful to use both strategies jointly." 
] }, { "cell_type": "markdown", "metadata": { "id": "37AGJamzboZQ" }, "source": [ "## Train a decision tree and neural network together\n", "\n", "The previous example used a pre-trained Neural Network (NN) to \n", "process the text features before passing them to the Random Forest. This example will train both the Neural Network and the Random Forest from scratch.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "YJIxGwwzMkFl" }, "source": [ "TF-DF's Decision Forests do not back-propagate gradients ([although this is the subject of ongoing research](https://arxiv.org/abs/2007.14761)). Therefore, the training happens in two stages:\n", "\n", "1. Train the neural-network as a standard classification task:\n", "\n", "```\n", "example → [Normalize] → [Neural Network*] → [classification head] → prediction\n", "*: Training.\n", "```\n", "\n", "2. Replace the Neural Network's head (the last layer and the soft-max) with a Random Forest. Train the Random Forest as usual:\n", "\n", "```\n", "example → [Normalize] → [Neural Network] → [Random Forest*] → prediction\n", "*: Training.\n", "```\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "YSIvuAhzbjWO" }, "source": [ "### Prepare the dataset\n", "\n", "This example uses the [Palmer's Penguins](https://allisonhorst.github.io/palmerpenguins/articles/intro.html) dataset. See the [Beginner colab](beginner_colab.ipynb) for details." 
] }, { "cell_type": "markdown", "metadata": { "id": "InUot_K2b3Mz" }, "source": [ "First, download the raw data:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:32.756956Z", "iopub.status.busy": "2024-04-20T11:19:32.756326Z", "iopub.status.idle": "2024-04-20T11:19:33.024241Z", "shell.execute_reply": "2024-04-20T11:19:33.023124Z" }, "id": "rNyaeCx0b1be" }, "outputs": [], "source": [ "!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv" ] }, { "cell_type": "markdown", "metadata": { "id": "pNPZzQekb9z_" }, "source": [ "Load the dataset into a Pandas DataFrame." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:33.028635Z", "iopub.status.busy": "2024-04-20T11:19:33.028092Z", "iopub.status.idle": "2024-04-20T11:19:33.047052Z", "shell.execute_reply": "2024-04-20T11:19:33.046432Z" }, "id": "9lA3peQ4sa9a" }, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>species</th><th>island</th><th>bill_length_mm</th><th>bill_depth_mm</th><th>flipper_length_mm</th><th>body_mass_g</th><th>sex</th><th>year</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>Adelie</td><td>Torgersen</td><td>39.1</td><td>18.7</td><td>181.0</td><td>3750.0</td><td>male</td><td>2007</td></tr>\n",
"    <tr><th>1</th><td>Adelie</td><td>Torgersen</td><td>39.5</td><td>17.4</td><td>186.0</td><td>3800.0</td><td>female</td><td>2007</td></tr>\n",
"    <tr><th>2</th><td>Adelie</td><td>Torgersen</td><td>40.3</td><td>18.0</td><td>195.0</td><td>3250.0</td><td>female</td><td>2007</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " species island bill_length_mm bill_depth_mm flipper_length_mm \\\n", "0 Adelie Torgersen 39.1 18.7 181.0 \n", "1 Adelie Torgersen 39.5 17.4 186.0 \n", "2 Adelie Torgersen 40.3 18.0 195.0 \n", "\n", " body_mass_g sex year \n", "0 3750.0 male 2007 \n", "1 3800.0 female 2007 \n", "2 3250.0 female 2007 " ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset_df = pd.read_csv(\"/tmp/penguins.csv\")\n", "\n", "# Display the first 3 examples.\n", "dataset_df.head(3)" ] }, { "cell_type": "markdown", "metadata": { "id": "v-_SZpRWcAoX" }, "source": [ "\n", "Prepare the dataset for training." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:33.049863Z", "iopub.status.busy": "2024-04-20T11:19:33.049632Z", "iopub.status.idle": "2024-04-20T11:19:33.055291Z", "shell.execute_reply": "2024-04-20T11:19:33.054688Z" }, "id": "rtyi8UoqtzhM" }, "outputs": [], "source": [ "label = \"species\"\n", "\n", "# Replaces numerical NaN (representing missing values in Pandas Dataframe) with 0s.\n", "# ...Neural Nets don't work well with numerical NaNs.\n", "for col in dataset_df.columns:\n", " if dataset_df[col].dtype not in [str, object]:\n", " dataset_df[col] = dataset_df[col].fillna(0)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:33.057935Z", "iopub.status.busy": "2024-04-20T11:19:33.057696Z", "iopub.status.idle": "2024-04-20T11:19:33.101492Z", "shell.execute_reply": "2024-04-20T11:19:33.100892Z" }, "id": "GKrW5Yfjso0k" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "248 examples in training, 96 examples for testing.\n" ] } ], "source": [ "# Split the dataset into a training and testing dataset.\n", "\n", "def split_dataset(dataset, test_ratio=0.30):\n", " \"\"\"Splits a panda dataframe in two.\"\"\"\n", " test_indices = np.random.rand(len(dataset)) < 
test_ratio\n", " return dataset[~test_indices], dataset[test_indices]\n", "\n", "train_ds_pd, test_ds_pd = split_dataset(dataset_df)\n", "print(\"{} examples in training, {} examples for testing.\".format(\n", " len(train_ds_pd), len(test_ds_pd)))\n", "\n", "# Convert the datasets into tensorflow datasets\n", "train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label=label)\n", "test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label=label)" ] }, { "cell_type": "markdown", "metadata": { "id": "ore7f6tgcOMh" }, "source": [ "### Build the models\n", "\n", "Next create the neural network model using [Keras' functional style](https://www.tensorflow.org/guide/keras/functional). \n", "\n", "To keep the example simple this model only uses two inputs." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:33.104602Z", "iopub.status.busy": "2024-04-20T11:19:33.104316Z", "iopub.status.idle": "2024-04-20T11:19:33.110653Z", "shell.execute_reply": "2024-04-20T11:19:33.110062Z" }, "id": "S1Jfe4YteBqY" }, "outputs": [], "source": [ "input_1 = tf_keras.Input(shape=(1,), name=\"bill_length_mm\", dtype=\"float\")\n", "input_2 = tf_keras.Input(shape=(1,), name=\"island\", dtype=\"string\")\n", "\n", "nn_raw_inputs = [input_1, input_2]" ] }, { "cell_type": "markdown", "metadata": { "id": "ZjlvAUNGeDM8" }, "source": [ "Use [preprocessing layers](https://www.tensorflow.org/guide/keras/preprocessing_layers) to convert the raw inputs to inputs appropriate for the neural network. 
" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:33.113628Z", "iopub.status.busy": "2024-04-20T11:19:33.113377Z", "iopub.status.idle": "2024-04-20T11:19:37.057164Z", "shell.execute_reply": "2024-04-20T11:19:37.056543Z" }, "id": "9Q09Nkp6ei21" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:max_tokens is deprecated, please use num_tokens instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:max_tokens is deprecated, please use num_tokens instead.\n" ] } ], "source": [ "# Normalization.\n", "Normalization = tf_keras.layers.Normalization\n", "CategoryEncoding = tf_keras.layers.CategoryEncoding\n", "StringLookup = tf_keras.layers.StringLookup\n", "\n", "values = train_ds_pd[\"bill_length_mm\"].values[:, tf.newaxis]\n", "input_1_normalizer = Normalization()\n", "input_1_normalizer.adapt(values)\n", "\n", "values = train_ds_pd[\"island\"].values\n", "input_2_indexer = StringLookup(max_tokens=32)\n", "input_2_indexer.adapt(values)\n", "\n", "input_2_onehot = CategoryEncoding(output_mode=\"binary\", max_tokens=32)\n", "\n", "normalized_input_1 = input_1_normalizer(input_1)\n", "normalized_input_2 = input_2_onehot(input_2_indexer(input_2))\n", "\n", "nn_processed_inputs = [normalized_input_1, normalized_input_2]" ] }, { "cell_type": "markdown", "metadata": { "id": "ZCoQljyhelau" }, "source": [ "Build the body of the neural network:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:37.060674Z", "iopub.status.busy": "2024-04-20T11:19:37.060420Z", "iopub.status.idle": "2024-04-20T11:19:37.110455Z", "shell.execute_reply": "2024-04-20T11:19:37.109862Z" }, "id": "KzocgbYNsH6y" }, "outputs": [], "source": [ "y = tf_keras.layers.Concatenate()(nn_processed_inputs)\n", "y = tf_keras.layers.Dense(16, activation=tf.nn.relu6)(y)\n", "last_layer = 
tf_keras.layers.Dense(8, activation=tf.nn.relu, name=\"last\")(y)\n", "\n", "# \"3\" for the three label classes. If it were a binary classification, the\n", "# output dim would be 1. The head is applied on top of `last_layer` so that\n", "# the features later given to the decision forest are trained.\n", "classification_output = tf_keras.layers.Dense(3)(last_layer)\n", "\n", "nn_model = tf_keras.models.Model(nn_raw_inputs, classification_output)" ] }, { "cell_type": "markdown", "metadata": { "id": "zPbRKf1CfIrj" }, "source": [ "This `nn_model` directly produces classification logits.\n", "\n", "Next, create a decision forest model. It will operate on the high-level features that the neural network extracts in the last layer before the classification head." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:37.113773Z", "iopub.status.busy": "2024-04-20T11:19:37.113508Z", "iopub.status.idle": "2024-04-20T11:19:37.126627Z", "shell.execute_reply": "2024-04-20T11:19:37.126037Z" }, "id": "7fnpGNyTuXvH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. 
Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Use /tmpfs/tmp/tmpgsdeyzk5 as temporary training directory\n" ] } ], "source": [ "# To reduce the risk of mistakes, group both the decision forest and the\n", "# neural network in a single keras model.\n", "nn_without_head = tf_keras.models.Model(inputs=nn_model.inputs, outputs=last_layer)\n", "df_and_nn_model = tfdf.keras.RandomForestModel(preprocessing=nn_without_head)" ] }, { "cell_type": "markdown", "metadata": { "id": "trq07lvMudlz" }, "source": [ "### Train and evaluate the models\n", "\n", "The model will be trained in two stages. First train the neural network with its own classification head:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:37.129864Z", "iopub.status.busy": "2024-04-20T11:19:37.129613Z", "iopub.status.idle": "2024-04-20T11:19:44.586783Z", "shell.execute_reply": "2024-04-20T11:19:44.585988Z" }, "id": "h4OyUWKiupuF" }, "outputs": [ { "data": { "application/javascript": [ "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/__autograph_generated_filetvdtpwer.py:63: UserWarning: Input dict contained keys ['bill_depth_mm', 'flipper_length_mm', 'body_mass_g', 'sex', 'year'] which did not match any model input. They will be ignored by the model.\n", " ag__.converted_call(ag__.ld(warnings).warn, (ag__.converted_call('Input dict contained keys {} which did not match any model input. 
They will be ignored by the model.'.format, ([ag__.ld(n) for n in ag__.converted_call(ag__.ld(tensors).keys, (), None, fscope) if ag__.ld(n) not in ag__.ld(ref_input_names)],), None, fscope),), dict(stacklevel=2), fscope)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1713611983.477059 22965 service.cc:145] XLA service 0x7ff3c4155d00 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n", "I0000 00:00:1713611983.477103 22965 service.cc:153] StreamExecutor device (0): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1713611983.477109 22965 service.cc:153] StreamExecutor device (1): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1713611983.477112 22965 service.cc:153] StreamExecutor device (2): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1713611983.477115 22965 service.cc:153] StreamExecutor device (3): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1713611983.604064 22965 device_compiler.h:188] Compiled cluster using XLA! 
This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.0737 - accuracy: 0.7137" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 7s 7s/step - loss: 1.0737 - accuracy: 0.7137 - val_loss: 1.0569 - val_accuracy: 0.7500\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.0696 - accuracy: 0.7177" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 24ms/step - loss: 1.0696 - accuracy: 0.7177 - val_loss: 1.0529 - val_accuracy: 0.7500\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.0655 - accuracy: 0.7177" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 22ms/step - loss: 1.0655 - accuracy: 0.7177 - val_loss: 1.0489 - val_accuracy: 0.7500\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.0614 - accuracy: 0.7218" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 22ms/step - loss: 1.0614 - accuracy: 0.7218 - val_loss: 1.0450 - val_accuracy: 0.7500\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.0574 - accuracy: 0.7258" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 23ms/step - loss: 1.0574 - accuracy: 0.7258 - val_loss: 1.0410 - val_accuracy: 0.7500\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.0533 - accuracy: 0.7298" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 22ms/step - loss: 1.0533 - accuracy: 0.7298 - val_loss: 1.0371 - val_accuracy: 0.7708\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.0494 - accuracy: 0.7339" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 22ms/step 
- loss: 1.0494 - accuracy: 0.7339 - val_loss: 1.0332 - val_accuracy: 0.7708\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.0454 - accuracy: 0.7379" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 25ms/step - loss: 1.0454 - accuracy: 0.7379 - val_loss: 1.0293 - val_accuracy: 0.7708\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.0415 - accuracy: 0.7419" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 22ms/step - loss: 1.0415 - accuracy: 0.7419 - val_loss: 1.0254 - val_accuracy: 0.7812\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.0376 - accuracy: 0.7460" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 22ms/step - loss: 1.0376 - accuracy: 0.7460 - val_loss: 1.0217 - val_accuracy: 0.7812\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model_1\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"__________________________________________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # Connected to \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "==================================================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " island (InputLayer) [(None, 1)] 0 [] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " bill_length_mm (InputLayer [(None, 1)] 0 [] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " string_lookup (StringLooku (None, 1) 0 ['island[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " p) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " normalization (Normalizati (None, 1) 3 ['bill_length_mm[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " on) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " category_encoding (Categor (None, 32) 0 ['string_lookup[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " yEncoding) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " concatenate (Concatenate) (None, 33) 0 ['normalization[0][0]', \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 'category_encoding[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense (Dense) (None, 16) 544 ['concatenate[0][0]'] \n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_1 (Dense) (None, 3) 51 ['dense[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "==================================================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 598 (2.34 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 595 (2.32 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 3 (16.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "__________________________________________________________________________________________________\n" ] } ], "source": [ "%set_cell_height 300\n", "\n", "nn_model.compile(\n", " optimizer=tf_keras.optimizers.Adam(),\n", " loss=tf_keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=[\"accuracy\"])\n", "\n", "nn_model.fit(x=train_ds, validation_data=test_ds, epochs=10)\n", "nn_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "N2mgMZOpgMQp" }, "source": [ "The neural network layers are shared between the two models. 
Now that the neural network is trained, the decision forest model is fit on the output of the trained neural network layers:" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:44.597280Z", "iopub.status.busy": "2024-04-20T11:19:44.597036Z", "iopub.status.idle": "2024-04-20T11:19:45.485510Z", "shell.execute_reply": "2024-04-20T11:19:45.484799Z" }, "id": "JAc9niXqud7V" }, "outputs": [ { "data": { "application/javascript": [ "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Reading training dataset...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training dataset read in 0:00:00.463558. Found 248 examples.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 0:00:00.043113\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Compiling model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:19:45.0975 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmpgsdeyzk5/model/ with prefix ac3d07af419249e2\n", "[INFO 24-04-20 11:19:45.1138 UTC decision_forest.cc:734] Model loaded with 300 root(s), 5640 node(s), and 8 input feature(s).\n", "[INFO 24-04-20 11:19:45.1138 UTC abstract_model.cc:1344] Engine \"RandomForestGeneric\" built\n", "[INFO 24-04-20 11:19:45.1138 UTC kernel.cc:1061] Use fast generic engine\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model compiled.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%set_cell_height 300\n", "\n", "df_and_nn_model.fit(x=train_ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "HF8Ru2HSv1a5" }, "source": [ "Now evaluate the 
composed model:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:45.488708Z", "iopub.status.busy": "2024-04-20T11:19:45.488459Z", "iopub.status.idle": "2024-04-20T11:19:45.749038Z", "shell.execute_reply": "2024-04-20T11:19:45.748386Z" }, "id": "EPMlcObzuw89" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 0.0000e+00 - accuracy: 0.9479" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 240ms/step - loss: 0.0000e+00 - accuracy: 0.9479\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluation: [0.0, 0.9479166865348816]\n" ] } ], "source": [ "df_and_nn_model.compile(metrics=[\"accuracy\"])\n", "print(\"Evaluation:\", df_and_nn_model.evaluate(test_ds))" ] }, { "cell_type": "markdown", "metadata": { "id": "awiHEznlv5sI" }, "source": [ "Compare it to the Neural Network alone:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:19:45.752389Z", "iopub.status.busy": "2024-04-20T11:19:45.752141Z", "iopub.status.idle": "2024-04-20T11:19:45.774439Z", "shell.execute_reply": "2024-04-20T11:19:45.773731Z" }, "id": "--ompWYTvxM-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.0217 - accuracy: 0.7812" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 14ms/step - loss: 1.0217 - accuracy: 0.7812\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "Evaluation: [1.021651029586792, 0.78125]\n" ] } ], "source": [ "print(\"Evaluation:\", nn_model.evaluate(test_ds))" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "intermediate_colab.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 0 }