{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-04-20T11:24:36.683809Z", "iopub.status.busy": "2024-04-20T11:24:36.683242Z", "iopub.status.idle": "2024-04-20T11:24:36.687044Z", "shell.execute_reply": "2024-04-20T11:24:36.686480Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "8yo62ffS5TF5" }, "source": [ "# Inspect and debug decision forest models\n", "\n", "\n", " \n", " \n", " \n", " \n", "
\n", "
\n" ] }, { "cell_type": "markdown", "metadata": { "id": "84wIz8LPiLDF" }, "source": [ "In this colab, you will learn how to inspect and create the structure of a model directly. We assume you are familiar with the concepts introduced in the\n", "[beginner](beginner_colab.ipynb) and [intermediate](intermediate_colab.ipynb)\n", "colabs.\n", "\n", "In this colab, you will:\n", "\n", "1. Train a Random Forest model and access its structure programmatically.\n", "\n", "1. Create a Random Forest model by hand and use it as a classical model." ] }, { "cell_type": "markdown", "metadata": { "id": "Rzskapxq7gdo" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:36.690640Z", "iopub.status.busy": "2024-04-20T11:24:36.690055Z", "iopub.status.idle": "2024-04-20T11:24:41.452072Z", "shell.execute_reply": "2024-04-20T11:24:41.451100Z" }, "id": "mZiInVYfffAb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow_decision_forests\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Using cached tensorflow_decision_forests-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.0 kB)\r\n", "Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.26.4)\r\n", "Requirement already satisfied: pandas in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.2.2)\r\n", "Requirement already satisfied: tensorflow~=2.16.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.16.1)\r\n", "Requirement already satisfied: six in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.16.0)\r\n", "Requirement already satisfied: absl-py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.4.0)\r\n", "Requirement already 
satisfied: wheel in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (0.41.2)\r\n", "Collecting wurlitzer (from tensorflow_decision_forests)\r\n", " Using cached wurlitzer-3.0.3-py3-none-any.whl.metadata (1.9 kB)\r\n", "Requirement already satisfied: tf-keras~=2.16 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.16.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=23.5.26 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (24.3.25)\r\n", "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (0.5.4)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (0.2.0)\r\n", "Requirement already satisfied: h5py>=3.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (3.11.0)\r\n", "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (18.1.1)\r\n", "Requirement already satisfied: ml-dtypes~=0.3.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (0.3.2)\r\n", "Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (3.3.0)\r\n", "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
tensorflow~=2.16.1->tensorflow_decision_forests) (24.0)\r\n", "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (3.20.3)\r\n", "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (2.31.0)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (69.5.1)\r\n", "Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (2.4.0)\r\n", "Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (4.11.0)\r\n", "Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (1.16.0)\r\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (1.63.0rc2)\r\n", "Requirement already satisfied: tensorboard<2.17,>=2.16 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (2.16.2)\r\n", "Requirement already satisfied: keras>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (3.2.1)\r\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (0.36.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: python-dateutil>=2.8.2 
in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2.9.0.post0)\r\n", "Requirement already satisfied: pytz>=2020.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2024.1)\r\n", "Requirement already satisfied: tzdata>=2022.7 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2024.1)\r\n", "Requirement already satisfied: rich in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (13.7.1)\r\n", "Requirement already satisfied: namex in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (0.0.8)\r\n", "Requirement already satisfied: optree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (0.11.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow~=2.16.1->tensorflow_decision_forests) (3.3.2)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow~=2.16.1->tensorflow_decision_forests) (3.7)\r\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow~=2.16.1->tensorflow_decision_forests) (2.2.1)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow~=2.16.1->tensorflow_decision_forests) (2024.2.2)\r\n", "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (3.6)\r\n", 
"Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (0.7.2)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (3.0.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (7.1.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (2.1.5)\r\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (3.0.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (2.17.2)\r\n", "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (3.18.1)\r\n", "Requirement already satisfied: mdurl~=0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown-it-py>=2.2.0->rich->keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (0.1.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using cached 
tensorflow_decision_forests-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.5 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using cached wurlitzer-3.0.3-py3-none-any.whl (7.3 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: wurlitzer, tensorflow_decision_forests\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed tensorflow_decision_forests-1.9.0 wurlitzer-3.0.3\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: wurlitzer in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (3.0.3)\r\n" ] } ], "source": [ "# Install TensorFlow Decision Forests.\n", "!pip install tensorflow_decision_forests\n", "\n", "# Use wurlitzer to show the training logs.\n", "!pip install wurlitzer" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:41.456245Z", "iopub.status.busy": "2024-04-20T11:24:41.455950Z", "iopub.status.idle": "2024-04-20T11:24:44.147752Z", "shell.execute_reply": "2024-04-20T11:24:44.146995Z" }, "id": "RsCV2oAS7gC_" }, "outputs": [], "source": [ "import os\n", "# Keep using Keras 2\n", "os.environ['TF_USE_LEGACY_KERAS'] = '1'\n", "\n", "import tensorflow_decision_forests as tfdf\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import tensorflow as tf\n", "import tf_keras\n", "import matplotlib.pyplot as plt\n", "import math\n", "import collections" ] }, { "cell_type": "markdown", "metadata": { "id": "xV3klWJnyCgH" }, "source": [ "The hidden code cell limits the output height in colab." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-04-20T11:24:44.152047Z", "iopub.status.busy": "2024-04-20T11:24:44.151605Z", "iopub.status.idle": "2024-04-20T11:24:44.156227Z", "shell.execute_reply": "2024-04-20T11:24:44.155617Z" }, "id": "XAWSjWrQmVE0" }, "outputs": [], "source": [ "#@title\n", "\n", "from IPython.core.magic import register_line_magic\n", "from IPython.display import Javascript\n", "from IPython.display import display as ipy_display\n", "\n", "# Some of the model training logs can cover the full\n", "# screen if not compressed to a smaller viewport.\n", "# This magic allows setting a max height for a cell.\n", "@register_line_magic\n", "def set_cell_height(size):\n", " ipy_display(\n", " Javascript(\"google.colab.output.setIframeHeight(0, true, {maxHeight: \" +\n", " str(size) + \"})\"))" ] }, { "cell_type": "markdown", "metadata": { "id": "M_D4Ft4o65XT" }, "source": [ "## Train a simple Random Forest\n", "\n", "We train a Random Forest like in the [beginner colab](beginner_colab.ipynb):" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:44.159522Z", "iopub.status.busy": "2024-04-20T11:24:44.158991Z", "iopub.status.idle": "2024-04-20T11:24:54.402903Z", "shell.execute_reply": "2024-04-20T11:24:54.402089Z" }, "id": "tTW2aBiVcU3E" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " species island bill_length_mm bill_depth_mm flipper_length_mm \\\n", "0 Adelie Torgersen 39.1 18.7 181.0 \n", "1 Adelie Torgersen 39.5 17.4 186.0 \n", "2 Adelie Torgersen 40.3 18.0 195.0 \n", "\n", " body_mass_g sex year \n", "0 3750.0 male 2007 \n", "1 3800.0 female 2007 \n", "2 3250.0 female 2007 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. 
Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Use /tmpfs/tmp/tmpadwizz7x as temporary training directory\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Reading training dataset...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training dataset read in 0:00:03.574049. Found 344 examples.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 0:00:00.092571\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Compiling model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:24:50.3886 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmpadwizz7x/model/ with prefix 59499fe5fa654879\n", "[INFO 24-04-20 11:24:50.4047 UTC decision_forest.cc:734] Model loaded with 300 root(s), 5080 node(s), and 7 input feature(s).\n", "[INFO 24-04-20 11:24:50.4047 UTC abstract_model.cc:1344] Engine \"RandomForestGeneric\" built\n", "[INFO 24-04-20 11:24:50.4048 UTC kernel.cc:1061] Use fast generic engine\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model compiled.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Download the dataset\n", "!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv\n", "\n", "# Load a dataset into a Pandas Dataframe.\n", "dataset_df = pd.read_csv(\"/tmp/penguins.csv\")\n", "\n", "# Show the first three examples.\n", "print(dataset_df.head(3))\n", "\n", "# Convert the pandas dataframe into a 
tf dataset.\n", "dataset_tf = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label=\"species\")\n", "\n", "# Train the Random Forest\n", "model = tfdf.keras.RandomForestModel(compute_oob_variable_importances=True)\n", "model.fit(x=dataset_tf)" ] }, { "cell_type": "markdown", "metadata": { "id": "b7Xie0bhcw8_" }, "source": [ "Note the `compute_oob_variable_importances=True`\n", "hyper-parameter in the model constructor. This option computes the Out-of-bag (OOB)\n", "variable importance during training. This is a popular\n", "[permutation variable importance](https://christophm.github.io/interpretable-ml-book/feature-importance.html) for Random Forest models.\n", "\n", "Computing the OOB variable importance does not change the final model, but it slows down training on large datasets.\n", "\n", "Check the model summary:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.406750Z", "iopub.status.busy": "2024-04-20T11:24:54.406107Z", "iopub.status.idle": "2024-04-20T11:24:54.417345Z", "shell.execute_reply": "2024-04-20T11:24:54.416704Z" }, "id": "fsQYD-jFc2EH" }, "outputs": [ { "data": { "application/javascript": [ "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"random_forest_model\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 
1 (1.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 1 (1.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Type: \"RANDOM_FOREST\"\n", "Task: CLASSIFICATION\n", "Label: \"__LABEL\"\n", "\n", "Input Features (7):\n", "\tbill_depth_mm\n", "\tbill_length_mm\n", "\tbody_mass_g\n", "\tflipper_length_mm\n", "\tisland\n", "\tsex\n", "\tyear\n", "\n", "No weights\n", "\n", "Variable Importance: INV_MEAN_MIN_DEPTH:\n", " 1. \"flipper_length_mm\" 0.440513 ################\n", " 2. \"bill_length_mm\" 0.438028 ###############\n", " 3. \"bill_depth_mm\" 0.299751 #####\n", " 4. \"island\" 0.295079 #####\n", " 5. \"body_mass_g\" 0.256534 ##\n", " 6. \"sex\" 0.225708 \n", " 7. \"year\" 0.224020 \n", "\n", "Variable Importance: MEAN_DECREASE_IN_ACCURACY:\n", " 1. \"bill_length_mm\" 0.151163 ################\n", " 2. \"island\" 0.008721 #\n", " 3. \"bill_depth_mm\" 0.000000 \n", " 4. \"body_mass_g\" 0.000000 \n", " 5. \"sex\" 0.000000 \n", " 6. \"year\" 0.000000 \n", " 7. \"flipper_length_mm\" -0.002907 \n", "\n", "Variable Importance: MEAN_DECREASE_IN_AP_1_VS_OTHERS:\n", " 1. \"bill_length_mm\" 0.083305 ################\n", " 2. \"island\" 0.007664 #\n", " 3. \"flipper_length_mm\" 0.003400 \n", " 4. \"bill_depth_mm\" 0.002741 \n", " 5. \"body_mass_g\" 0.000722 \n", " 6. \"sex\" 0.000644 \n", " 7. \"year\" 0.000000 \n", "\n", "Variable Importance: MEAN_DECREASE_IN_AP_2_VS_OTHERS:\n", " 1. \"bill_length_mm\" 0.508510 ################\n", " 2. \"island\" 0.023487 \n", " 3. \"bill_depth_mm\" 0.007744 \n", " 4. \"flipper_length_mm\" 0.006008 \n", " 5. \"body_mass_g\" 0.003017 \n", " 6. \"sex\" 0.001537 \n", " 7. 
\"year\" -0.000245 \n", "\n", "Variable Importance: MEAN_DECREASE_IN_AP_3_VS_OTHERS:\n", " 1. \"island\" 0.002192 ################\n", " 2. \"bill_length_mm\" 0.001572 ############\n", " 3. \"bill_depth_mm\" 0.000497 #######\n", " 4. \"sex\" 0.000000 ####\n", " 5. \"year\" 0.000000 ####\n", " 6. \"body_mass_g\" -0.000053 ####\n", " 7. \"flipper_length_mm\" -0.000890 \n", "\n", "Variable Importance: MEAN_DECREASE_IN_AUC_1_VS_OTHERS:\n", " 1. \"bill_length_mm\" 0.071306 ################\n", " 2. \"island\" 0.007299 #\n", " 3. \"flipper_length_mm\" 0.004506 #\n", " 4. \"bill_depth_mm\" 0.002124 \n", " 5. \"body_mass_g\" 0.000548 \n", " 6. \"sex\" 0.000480 \n", " 7. \"year\" 0.000000 \n", "\n", "Variable Importance: MEAN_DECREASE_IN_AUC_2_VS_OTHERS:\n", " 1. \"bill_length_mm\" 0.108642 ################\n", " 2. \"island\" 0.014493 ##\n", " 3. \"bill_depth_mm\" 0.007406 #\n", " 4. \"flipper_length_mm\" 0.005195 \n", " 5. \"body_mass_g\" 0.001012 \n", " 6. \"sex\" 0.000480 \n", " 7. \"year\" -0.000053 \n", "\n", "Variable Importance: MEAN_DECREASE_IN_AUC_3_VS_OTHERS:\n", " 1. \"island\" 0.002126 ################\n", " 2. \"bill_length_mm\" 0.001393 ###########\n", " 3. \"bill_depth_mm\" 0.000293 #####\n", " 4. \"sex\" 0.000000 ###\n", " 5. \"year\" 0.000000 ###\n", " 6. \"body_mass_g\" -0.000037 ###\n", " 7. \"flipper_length_mm\" -0.000550 \n", "\n", "Variable Importance: MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS:\n", " 1. \"bill_length_mm\" 0.083122 ################\n", " 2. \"island\" 0.010887 ##\n", " 3. \"flipper_length_mm\" 0.003425 \n", " 4. \"bill_depth_mm\" 0.002731 \n", " 5. \"body_mass_g\" 0.000719 \n", " 6. \"sex\" 0.000641 \n", " 7. \"year\" 0.000000 \n", "\n", "Variable Importance: MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS:\n", " 1. \"bill_length_mm\" 0.497611 ################\n", " 2. \"island\" 0.024045 \n", " 3. \"bill_depth_mm\" 0.007734 \n", " 4. \"flipper_length_mm\" 0.006017 \n", " 5. \"body_mass_g\" 0.003000 \n", " 6. \"sex\" 0.001528 \n", " 7. 
\"year\" -0.000243 \n", "\n", "Variable Importance: MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS:\n", " 1. \"island\" 0.002187 ################\n", " 2. \"bill_length_mm\" 0.001568 ############\n", " 3. \"bill_depth_mm\" 0.000495 #######\n", " 4. \"sex\" 0.000000 ####\n", " 5. \"year\" 0.000000 ####\n", " 6. \"body_mass_g\" -0.000053 ####\n", " 7. \"flipper_length_mm\" -0.000886 \n", "\n", "Variable Importance: NUM_AS_ROOT:\n", " 1. \"flipper_length_mm\" 157.000000 ################\n", " 2. \"bill_length_mm\" 76.000000 #######\n", " 3. \"bill_depth_mm\" 52.000000 #####\n", " 4. \"island\" 12.000000 \n", " 5. \"body_mass_g\" 3.000000 \n", "\n", "Variable Importance: NUM_NODES:\n", " 1. \"bill_length_mm\" 778.000000 ################\n", " 2. \"bill_depth_mm\" 463.000000 #########\n", " 3. \"flipper_length_mm\" 414.000000 ########\n", " 4. \"island\" 342.000000 ######\n", " 5. \"body_mass_g\" 338.000000 ######\n", " 6. \"sex\" 36.000000 \n", " 7. \"year\" 19.000000 \n", "\n", "Variable Importance: SUM_SCORE:\n", " 1. \"bill_length_mm\" 36515.793787 ################\n", " 2. \"flipper_length_mm\" 35120.434174 ###############\n", " 3. \"island\" 14669.408395 ######\n", " 4. \"bill_depth_mm\" 14515.446617 ######\n", " 5. \"body_mass_g\" 3485.330881 #\n", " 6. \"sex\" 354.201073 \n", " 7. 
\"year\" 49.737758 \n", "\n", "\n", "\n", "Winner takes all: true\n", "Out-of-bag evaluation: accuracy:0.976744 logloss:0.068949\n", "Number of trees: 300\n", "Total number of nodes: 5080\n", "\n", "Number of nodes by tree:\n", "Count: 300 Average: 16.9333 StdDev: 3.10197\n", "Min: 11 Max: 31 Ignored: 0\n", "----------------------------------------------\n", "[ 11, 12) 6 2.00% 2.00% #\n", "[ 12, 13) 0 0.00% 2.00%\n", "[ 13, 14) 46 15.33% 17.33% #####\n", "[ 14, 15) 0 0.00% 17.33%\n", "[ 15, 16) 70 23.33% 40.67% ########\n", "[ 16, 17) 0 0.00% 40.67%\n", "[ 17, 18) 84 28.00% 68.67% ##########\n", "[ 18, 19) 0 0.00% 68.67%\n", "[ 19, 20) 46 15.33% 84.00% #####\n", "[ 20, 21) 0 0.00% 84.00%\n", "[ 21, 22) 30 10.00% 94.00% ####\n", "[ 22, 23) 0 0.00% 94.00%\n", "[ 23, 24) 13 4.33% 98.33% ##\n", "[ 24, 25) 0 0.00% 98.33%\n", "[ 25, 26) 2 0.67% 99.00%\n", "[ 26, 27) 0 0.00% 99.00%\n", "[ 27, 28) 2 0.67% 99.67%\n", "[ 28, 29) 0 0.00% 99.67%\n", "[ 29, 30) 0 0.00% 99.67%\n", "[ 30, 31] 1 0.33% 100.00%\n", "\n", "Depth by leafs:\n", "Count: 2690 Average: 3.53271 StdDev: 1.06789\n", "Min: 2 Max: 7 Ignored: 0\n", "----------------------------------------------\n", "[ 2, 3) 545 20.26% 20.26% ######\n", "[ 3, 4) 747 27.77% 48.03% ########\n", "[ 4, 5) 888 33.01% 81.04% ##########\n", "[ 5, 6) 444 16.51% 97.55% #####\n", "[ 6, 7) 62 2.30% 99.85% #\n", "[ 7, 7] 4 0.15% 100.00%\n", "\n", "Number of training obs by leaf:\n", "Count: 2690 Average: 38.3643 StdDev: 44.8651\n", "Min: 5 Max: 155 Ignored: 0\n", "----------------------------------------------\n", "[ 5, 12) 1474 54.80% 54.80% ##########\n", "[ 12, 20) 124 4.61% 59.41% #\n", "[ 20, 27) 48 1.78% 61.19%\n", "[ 27, 35) 74 2.75% 63.94% #\n", "[ 35, 42) 58 2.16% 66.10%\n", "[ 42, 50) 85 3.16% 69.26% #\n", "[ 50, 57) 96 3.57% 72.83% #\n", "[ 57, 65) 87 3.23% 76.06% #\n", "[ 65, 72) 49 1.82% 77.88%\n", "[ 72, 80) 23 0.86% 78.74%\n", "[ 80, 88) 30 1.12% 79.85%\n", "[ 88, 95) 23 0.86% 80.71%\n", "[ 95, 103) 42 1.56% 82.27%\n", "[ 
103, 110) 62 2.30% 84.57%\n", "[ 110, 118) 115 4.28% 88.85% #\n", "[ 118, 125) 115 4.28% 93.12% #\n", "[ 125, 133) 98 3.64% 96.77% #\n", "[ 133, 140) 49 1.82% 98.59%\n", "[ 140, 148) 31 1.15% 99.74%\n", "[ 148, 155] 7 0.26% 100.00%\n", "\n", "Attribute in nodes:\n", "\t778 : bill_length_mm [NUMERICAL]\n", "\t463 : bill_depth_mm [NUMERICAL]\n", "\t414 : flipper_length_mm [NUMERICAL]\n", "\t342 : island [CATEGORICAL]\n", "\t338 : body_mass_g [NUMERICAL]\n", "\t36 : sex [CATEGORICAL]\n", "\t19 : year [NUMERICAL]\n", "\n", "Attribute in nodes with depth <= 0:\n", "\t157 : flipper_length_mm [NUMERICAL]\n", "\t76 : bill_length_mm [NUMERICAL]\n", "\t52 : bill_depth_mm [NUMERICAL]\n", "\t12 : island [CATEGORICAL]\n", "\t3 : body_mass_g [NUMERICAL]\n", "\n", "Attribute in nodes with depth <= 1:\n", "\t250 : bill_length_mm [NUMERICAL]\n", "\t244 : flipper_length_mm [NUMERICAL]\n", "\t183 : bill_depth_mm [NUMERICAL]\n", "\t170 : island [CATEGORICAL]\n", "\t53 : body_mass_g [NUMERICAL]\n", "\n", "Attribute in nodes with depth <= 2:\n", "\t462 : bill_length_mm [NUMERICAL]\n", "\t320 : flipper_length_mm [NUMERICAL]\n", "\t310 : bill_depth_mm [NUMERICAL]\n", "\t287 : island [CATEGORICAL]\n", "\t162 : body_mass_g [NUMERICAL]\n", "\t9 : sex [CATEGORICAL]\n", "\t5 : year [NUMERICAL]\n", "\n", "Attribute in nodes with depth <= 3:\n", "\t669 : bill_length_mm [NUMERICAL]\n", "\t410 : bill_depth_mm [NUMERICAL]\n", "\t383 : flipper_length_mm [NUMERICAL]\n", "\t328 : island [CATEGORICAL]\n", "\t286 : body_mass_g [NUMERICAL]\n", "\t32 : sex [CATEGORICAL]\n", "\t10 : year [NUMERICAL]\n", "\n", "Attribute in nodes with depth <= 5:\n", "\t778 : bill_length_mm [NUMERICAL]\n", "\t462 : bill_depth_mm [NUMERICAL]\n", "\t413 : flipper_length_mm [NUMERICAL]\n", "\t342 : island [CATEGORICAL]\n", "\t338 : body_mass_g [NUMERICAL]\n", "\t36 : sex [CATEGORICAL]\n", "\t19 : year [NUMERICAL]\n", "\n", "Condition type in nodes:\n", "\t2012 : HigherCondition\n", "\t378 : ContainsBitmapCondition\n", 
"Condition type in nodes with depth <= 0:\n", "\t288 : HigherCondition\n", "\t12 : ContainsBitmapCondition\n", "Condition type in nodes with depth <= 1:\n", "\t730 : HigherCondition\n", "\t170 : ContainsBitmapCondition\n", "Condition type in nodes with depth <= 2:\n", "\t1259 : HigherCondition\n", "\t296 : ContainsBitmapCondition\n", "Condition type in nodes with depth <= 3:\n", "\t1758 : HigherCondition\n", "\t360 : ContainsBitmapCondition\n", "Condition type in nodes with depth <= 5:\n", "\t2010 : HigherCondition\n", "\t378 : ContainsBitmapCondition\n", "Node format: NOT_SET\n", "\n", "Training OOB:\n", "\ttrees: 1, Out-of-bag evaluation: accuracy:0.964286 logloss:1.28727\n", "\ttrees: 13, Out-of-bag evaluation: accuracy:0.94863 logloss:1.38235\n", "\ttrees: 29, Out-of-bag evaluation: accuracy:0.963526 logloss:0.698239\n", "\ttrees: 39, Out-of-bag evaluation: accuracy:0.958824 logloss:0.37345\n", "\ttrees: 54, Out-of-bag evaluation: accuracy:0.973837 logloss:0.171543\n", "\ttrees: 72, Out-of-bag evaluation: accuracy:0.97093 logloss:0.171775\n", "\ttrees: 82, Out-of-bag evaluation: accuracy:0.973837 logloss:0.168111\n", "\ttrees: 92, Out-of-bag evaluation: accuracy:0.976744 logloss:0.167506\n", "\ttrees: 113, Out-of-bag evaluation: accuracy:0.976744 logloss:0.170507\n", "\ttrees: 124, Out-of-bag evaluation: accuracy:0.976744 logloss:0.07406\n", "\ttrees: 135, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0739305\n", "\ttrees: 145, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0741686\n", "\ttrees: 155, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0738562\n", "\ttrees: 166, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0727146\n", "\ttrees: 177, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0721128\n", "\ttrees: 195, Out-of-bag evaluation: accuracy:0.976744 logloss:0.070882\n", "\ttrees: 205, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0705714\n", "\ttrees: 216, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0697382\n", 
"\ttrees: 231, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0695581\n", "\ttrees: 244, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0683962\n", "\ttrees: 255, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0693447\n", "\ttrees: 267, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689024\n", "\ttrees: 279, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0694214\n", "\ttrees: 296, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0691636\n", "\ttrees: 300, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068949\n", "\n" ] } ], "source": [ "%set_cell_height 300\n", "\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "dtvAH26EfSgY" }, "source": [ "Note the multiple variable importances with name `MEAN_DECREASE_IN_*`." ] }, { "cell_type": "markdown", "metadata": { "id": "xTwmx8A0c4TU" }, "source": [ "## Plotting the model\n", "\n", "Next, plot the model.\n", "\n", "A Random Forest is a large model (this model has 300 trees and ~5k nodes; see the summary above). Therefore, only plot the first tree, and limit the nodes to depth 3." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.423361Z", "iopub.status.busy": "2024-04-20T11:24:54.422656Z", "iopub.status.idle": "2024-04-20T11:24:54.429895Z", "shell.execute_reply": "2024-04-20T11:24:54.429284Z" }, "id": "ZRTrXDz_dIAQ" }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "
\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "lOlieoz2c-GA" }, "source": [ "## Inspect the model structure\n", "\n", "The model structure and meta-data are\n", "available through the **inspector** created by `make_inspector()`.\n", "\n", "**Note:** Depending on the learning algorithm and hyper-parameters, the\n", "inspector will expose different specialized attributes. For example, the\n", "`winner_take_all_inference` field is specific to Random Forest models." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.433296Z", "iopub.status.busy": "2024-04-20T11:24:54.432683Z", "iopub.status.idle": "2024-04-20T11:24:54.436951Z", "shell.execute_reply": "2024-04-20T11:24:54.436318Z" }, "id": "KHc8IcW1c8ER" }, "outputs": [], "source": [ "inspector = model.make_inspector()" ] }, { "cell_type": "markdown", "metadata": { "id": "RDdUhqaSsNnQ" }, "source": [ "For our model, the available inspector fields are:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.440561Z", "iopub.status.busy": "2024-04-20T11:24:54.439979Z", "iopub.status.idle": "2024-04-20T11:24:54.444995Z", "shell.execute_reply": "2024-04-20T11:24:54.444350Z" }, "id": "jx54DFjRsA7k" }, "outputs": [ { "data": { "text/plain": [ "['MODEL_NAME',\n", " 'dataspec',\n", " 'directory',\n", " 'evaluation',\n", " 'export_to_tensorboard',\n", " 'extract_all_trees',\n", " 'extract_tree',\n", " 'features',\n", " 'file_prefix',\n", " 'header',\n", " 'iterate_on_nodes',\n", " 'label',\n", " 'label_classes',\n", " 'metadata',\n", " 'model_type',\n", " 'num_trees',\n", " 'objective',\n", " 'specialized_header',\n", " 'task',\n", " 'training_logs',\n", " 'tuning_logs',\n", " 
'variable_importances',\n", " 'winner_take_all_inference']" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[field for field in dir(inspector) if not field.startswith(\"_\")]" ] }, { "cell_type": "markdown", "metadata": { "id": "_QJFITMQsgtK" }, "source": [ "See the [API reference](https://tensorflow.org/decision_forests/api_docs/python/tfdf/inspector/AbstractInspector) or use `?` for the built-in documentation." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.448197Z", "iopub.status.busy": "2024-04-20T11:24:54.447657Z", "iopub.status.idle": "2024-04-20T11:24:54.487393Z", "shell.execute_reply": "2024-04-20T11:24:54.486705Z" }, "id": "YCGkpRkssdCb" }, "outputs": [], "source": [ "?inspector.model_type" ] }, { "cell_type": "markdown", "metadata": { "id": "nd-fOgmjd1oK" }, "source": [ "Some of the model metadata:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.490987Z", "iopub.status.busy": "2024-04-20T11:24:54.490551Z", "iopub.status.idle": "2024-04-20T11:24:54.494787Z", "shell.execute_reply": "2024-04-20T11:24:54.494175Z" }, "id": "Iu_To_z9d35G" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model type: RANDOM_FOREST\n", "Number of trees: 300\n", "Objective: Classification(label=__LABEL, class=None, num_classes=3)\n", "Input features: [\"bill_depth_mm\" (1; #1), \"bill_length_mm\" (1; #2), \"body_mass_g\" (1; #3), \"flipper_length_mm\" (1; #4), \"island\" (4; #5), \"sex\" (4; #6), \"year\" (1; #7)]\n" ] } ], "source": [ "print(\"Model type:\", inspector.model_type())\n", "print(\"Number of trees:\", inspector.num_trees())\n", "print(\"Objective:\", inspector.objective())\n", "print(\"Input features:\", inspector.features())" ] }, { "cell_type": "markdown", "metadata": { "id": "Zs7b8EBud9JM" }, "source": [ "`evaluation()` returns the evaluation
of the model computed during training. The dataset used for this evaluation depends on the algorithm. For example, it can be the validation dataset or the out-of-bag dataset.\n", "\n", "**Note:** Although it is computed during training, `evaluation()` is never an evaluation on the\n", "training dataset." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.498031Z", "iopub.status.busy": "2024-04-20T11:24:54.497556Z", "iopub.status.idle": "2024-04-20T11:24:54.501558Z", "shell.execute_reply": "2024-04-20T11:24:54.500992Z" }, "id": "uVN-j0E4Q1T3" }, "outputs": [ { "data": { "text/plain": [ "Evaluation(num_examples=344, accuracy=0.9767441860465116, loss=0.06894904488784283, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inspector.evaluation()" ] }, { "cell_type": "markdown", "metadata": { "id": "2r6Yrjb7f5KH" }, "source": [ "The variable importances are:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.504764Z", "iopub.status.busy": "2024-04-20T11:24:54.504381Z", "iopub.status.idle": "2024-04-20T11:24:54.508511Z", "shell.execute_reply": "2024-04-20T11:24:54.507907Z" }, "id": "qoqhhmGjf7ED" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Available variable importances:\n", "\t MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS\n", "\t MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS\n", "\t INV_MEAN_MIN_DEPTH\n", "\t MEAN_DECREASE_IN_AUC_1_VS_OTHERS\n", "\t MEAN_DECREASE_IN_AP_2_VS_OTHERS\n", "\t MEAN_DECREASE_IN_AUC_3_VS_OTHERS\n", "\t MEAN_DECREASE_IN_AUC_2_VS_OTHERS\n", "\t MEAN_DECREASE_IN_AP_1_VS_OTHERS\n", "\t NUM_AS_ROOT\n", "\t NUM_NODES\n", "\t MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS\n", "\t MEAN_DECREASE_IN_ACCURACY\n", "\t SUM_SCORE\n", "\t MEAN_DECREASE_IN_AP_3_VS_OTHERS\n" ] } ], "source": [ "print(f\"Available variable
importances:\")\n", "for importance in inspector.variable_importances().keys():\n", "  print(\"\\t\", importance)" ] }, { "cell_type": "markdown", "metadata": { "id": "8QUW8w-UmCoW" }, "source": [ "Different variable importances have different semantics. For example, a\n", "**mean decrease in AUC** of `0.05` for a feature means that removing this feature from\n", "the training dataset would reduce the AUC by 0.05 (5 percentage points)." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.511809Z", "iopub.status.busy": "2024-04-20T11:24:54.511264Z", "iopub.status.idle": "2024-04-20T11:24:54.516008Z", "shell.execute_reply": "2024-04-20T11:24:54.515415Z" }, "id": "OoSG5T8ShSdG" }, "outputs": [ { "data": { "text/plain": [ "[(\"bill_length_mm\" (1; #2), 0.0713061951754389),\n", " (\"island\" (4; #5), 0.007298519736842035),\n", " (\"flipper_length_mm\" (1; #4), 0.004505893640351366),\n", " (\"bill_depth_mm\" (1; #1), 0.0021244517543865804),\n", " (\"body_mass_g\" (1; #3), 0.0005482456140351033),\n", " (\"sex\" (4; #6), 0.00047971491228060437),\n", " (\"year\" (1; #7), 0.0)]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Mean decrease in AUC of the class 1 vs the others.\n", "inspector.variable_importances()[\"MEAN_DECREASE_IN_AUC_1_VS_OTHERS\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "afSPSBg_uJuI" }, "source": [ "Plot the variable importances from the inspector using Matplotlib:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.519423Z", "iopub.status.busy": "2024-04-20T11:24:54.518818Z", "iopub.status.idle": "2024-04-20T11:24:54.779885Z", "shell.execute_reply": "2024-04-20T11:24:54.778963Z" }, "id": "53lM0wDmuI8u" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.figure(figsize=(12, 4))\n", "\n", "# Mean decrease in AUC of the class 1 vs the others.\n", "variable_importance_metric = \"MEAN_DECREASE_IN_AUC_1_VS_OTHERS\"\n", "variable_importances = inspector.variable_importances()[variable_importance_metric]\n", "\n", "# Extract the feature name and importance values.\n", "#\n", "# `variable_importances` is a list of tuples.\n", "feature_names = [vi[0].name for vi in variable_importances]\n", "feature_importances = [vi[1] for vi in variable_importances]\n", "# The features are ordered by decreasing importance value.\n", "feature_ranks = range(len(feature_names))\n", "\n", "bar = plt.barh(feature_ranks, feature_importances, label=[str(x) for x in feature_ranks])\n", "plt.yticks(feature_ranks, feature_names)\n", "plt.gca().invert_yaxis()\n", "\n", "# TODO: Replace with \"plt.bar_label()\" when available.\n", "# Label each bar with its value.\n", "for importance, patch in zip(feature_importances, bar.patches):\n", "  plt.text(patch.get_x() + patch.get_width(), patch.get_y(), f\"{importance:.4f}\", va=\"top\")\n", "\n", "plt.xlabel(variable_importance_metric)\n", "plt.title(\"Mean decrease in AUC of the class 1 vs the others\")\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "q_-kLTNjhaQo" }, "source": [ "Finally, access the actual tree structure:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.783723Z", "iopub.status.busy": "2024-04-20T11:24:54.783235Z", "iopub.status.idle": "2024-04-20T11:24:54.789974Z", "shell.execute_reply": "2024-04-20T11:24:54.789148Z" }, "id": "l4N_heuzhcUS" }, "outputs": [ { "data": { "text/plain": [ "Tree(root=NonLeafNode(condition=(bill_length_mm >= 43.25; miss=True, score=0.5482327342033386), pos_child=NonLeafNode(condition=(island in ['Biscoe']; miss=True,
score=0.6515106558799744), pos_child=NonLeafNode(condition=(bill_depth_mm >= 17.225584030151367; miss=False, score=0.027205035090446472), pos_child=LeafNode(value=ProbabilityValue([0.16666666666666666, 0.0, 0.8333333333333334],n=6.0), idx=7), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=104.0), idx=6), value=ProbabilityValue([0.00909090909090909, 0.0, 0.990909090909091],n=110.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=61.0), idx=5), value=ProbabilityValue([0.005847953216374269, 0.3567251461988304, 0.6374269005847953],n=171.0)), neg_child=NonLeafNode(condition=(bill_depth_mm >= 15.100000381469727; miss=True, score=0.150658518075943), pos_child=NonLeafNode(condition=(flipper_length_mm >= 187.5; miss=True, score=0.036139510571956635), pos_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=104.0), idx=4), neg_child=NonLeafNode(condition=(bill_length_mm >= 42.30000305175781; miss=True, score=0.23430533707141876), pos_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=5.0), idx=3), neg_child=NonLeafNode(condition=(bill_length_mm >= 40.55000305175781; miss=True, score=0.043961383402347565), pos_child=LeafNode(value=ProbabilityValue([0.8, 0.2, 0.0],n=5.0), idx=2), neg_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=53.0), idx=1), value=ProbabilityValue([0.9827586206896551, 0.017241379310344827, 0.0],n=58.0)), value=ProbabilityValue([0.9047619047619048, 0.09523809523809523, 0.0],n=63.0)), value=ProbabilityValue([0.9640718562874252, 0.03592814371257485, 0.0],n=167.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=6.0), idx=0), value=ProbabilityValue([0.930635838150289, 0.03468208092485549, 0.03468208092485549],n=173.0)), value=ProbabilityValue([0.47093023255813954, 0.19476744186046513, 0.33430232558139533],n=344.0)), label_classes=None)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inspector.extract_tree(tree_idx=0)" ] }, { "cell_type": "markdown", 
"metadata": { "id": "B8u_0p80hoeP" }, "source": [ "Extracting a tree is not efficient. If speed is important, inspect the model with the `iterate_on_nodes()` method instead. This method returns an iterator that visits all the nodes of the model in a depth-first, pre-order traversal.\n", "\n", "**Note:** `extract_tree()` is implemented using `iterate_on_nodes()`.\n", "\n", "The following example computes how many times each feature is used (this is a\n", "kind of structural variable importance):" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.793602Z", "iopub.status.busy": "2024-04-20T11:24:54.793006Z", "iopub.status.idle": "2024-04-20T11:24:54.908214Z", "shell.execute_reply": "2024-04-20T11:24:54.907390Z" }, "id": "OUEpes34iHg8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of condition nodes per features:\n", "\t bill_length_mm : 778\n", "\t bill_depth_mm : 463\n", "\t flipper_length_mm : 414\n", "\t island : 342\n", "\t body_mass_g : 338\n", "\t year : 19\n", "\t sex : 36\n" ] } ], "source": [ "# number_of_use[F] will be the number of nodes using feature F in their condition.\n", "number_of_use = collections.defaultdict(lambda: 0)\n", "\n", "# Iterate over all the nodes in a depth-first, pre-order traversal.\n", "for node_iter in inspector.iterate_on_nodes():\n", "\n", "  if not isinstance(node_iter.node, tfdf.py_tree.node.NonLeafNode):\n", "    # Skip the leaf nodes\n", "    continue\n", "\n", "  # Iterate over all the features used in the condition.\n", "  # By default, models are not \"oblique\", i.e. 
each node tests a single feature.\n", "  for feature in node_iter.node.condition.features():\n", "    number_of_use[feature] += 1\n", "\n", "print(\"Number of condition nodes per features:\")\n", "for feature, count in number_of_use.items():\n", "  print(\"\\t\", feature.name, \":\", count)" ] }, { "cell_type": "markdown", "metadata": { "id": "CD39OmGbnPww" }, "source": [ "## Creating a model by hand\n", "\n", "In this section you will create a small Random Forest model by hand. To keep it\n", "simple, the model will contain only one small tree:\n", "\n", "```\n", "3 label classes: red, blue and green.\n", "2 features: f1 (numerical) and f2 (string categorical)\n", "\n", "f1>=1.5\n", "  ├─(pos)─ f2 in [\"cat\",\"dog\"]\n", "  │         ├─(pos)─ value: [0.8, 0.1, 0.1]\n", "  │         └─(neg)─ value: [0.1, 0.8, 0.1]\n", "  └─(neg)─ value: [0.1, 0.1, 0.8]\n", "```" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.911704Z", "iopub.status.busy": "2024-04-20T11:24:54.911246Z", "iopub.status.idle": "2024-04-20T11:24:54.915974Z", "shell.execute_reply": "2024-04-20T11:24:54.915135Z" }, "id": "fGGe5IxdnuEa" }, "outputs": [], "source": [ "# Create the model builder\n", "builder = tfdf.builder.RandomForestBuilder(\n", "    path=\"/tmp/manual_model\",\n", "    objective=tfdf.py_tree.objective.ClassificationObjective(\n", "        label=\"color\", classes=[\"red\", \"blue\", \"green\"]))" ] }, { "cell_type": "markdown", "metadata": { "id": "DRnJ2u-Moqbf" }, "source": [ "Trees are added one at a time.\n", "\n", "**Note:** The tree object (`tfdf.py_tree.tree.Tree`) is the same as the one returned by `extract_tree()` in the previous section."
] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.919723Z", "iopub.status.busy": "2024-04-20T11:24:54.919301Z", "iopub.status.idle": "2024-04-20T11:24:54.926196Z", "shell.execute_reply": "2024-04-20T11:24:54.925349Z" }, "id": "cmAddPhAo0tG" }, "outputs": [], "source": [ "# Aliases\n", "Tree = tfdf.py_tree.tree.Tree\n", "SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec\n", "ColumnType = tfdf.py_tree.dataspec.ColumnType\n", "# Nodes\n", "NonLeafNode = tfdf.py_tree.node.NonLeafNode\n", "LeafNode = tfdf.py_tree.node.LeafNode\n", "# Conditions\n", "NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition\n", "CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition\n", "# Leaf values\n", "ProbabilityValue = tfdf.py_tree.value.ProbabilityValue\n", "\n", "builder.add_tree(\n", "    Tree(\n", "        NonLeafNode(\n", "            condition=NumericalHigherThanCondition(\n", "                feature=SimpleColumnSpec(name=\"f1\", type=ColumnType.NUMERICAL),\n", "                threshold=1.5,\n", "                missing_evaluation=False),\n", "            pos_child=NonLeafNode(\n", "                condition=CategoricalIsInCondition(\n", "                    feature=SimpleColumnSpec(name=\"f2\",type=ColumnType.CATEGORICAL),\n", "                    mask=[\"cat\", \"dog\"],\n", "                    missing_evaluation=False),\n", "                pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.1, 0.1], num_examples=10)),\n", "                neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.8, 0.1], num_examples=20))),\n", "            neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.1, 0.8], num_examples=30)))))" ] }, { "cell_type": "markdown", "metadata": { "id": "DjWdgRNNqEAD" }, "source": [ "Close the builder to finish writing the model:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:54.929661Z", "iopub.status.busy": "2024-04-20T11:24:54.929193Z", "iopub.status.idle": "2024-04-20T11:24:55.946091Z", "shell.execute_reply": 
"2024-04-20T11:24:55.945132Z" }, "id": "cJqn4khxqH6t" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:24:54.9480 UTC kernel.cc:1233] Loading model from path /tmp/manual_model/tmp/ with prefix f938aac6d7ed44f5\n", "[INFO 24-04-20 11:24:54.9483 UTC decision_forest.cc:734] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).\n", "[INFO 24-04-20 11:24:54.9483 UTC kernel.cc:1061] Use fast generic engine\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/manual_model/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/manual_model/assets\n" ] } ], "source": [ "builder.close()" ] }, { "cell_type": "markdown", "metadata": { "id": "_oxxXAn7qK-z" }, "source": [ "Now you can open the model as a regular Keras model, and make predictions:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:55.949899Z", "iopub.status.busy": "2024-04-20T11:24:55.949356Z", "iopub.status.idle": "2024-04-20T11:24:56.134350Z", "shell.execute_reply": "2024-04-20T11:24:56.133412Z" }, "id": "ETwjOJ5uqP5i" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:24:56.1029 UTC kernel.cc:1233] Loading model from path /tmp/manual_model/assets/ with prefix f938aac6d7ed44f5\n", "[INFO 24-04-20 11:24:56.1032 UTC decision_forest.cc:734] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).\n", "[INFO 24-04-20 11:24:56.1032 UTC kernel.cc:1061] Use fast generic engine\n" ] } ], "source": [ "manual_model = tf_keras.models.load_model(\"/tmp/manual_model\")" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:56.137655Z", "iopub.status.busy": "2024-04-20T11:24:56.137384Z", "iopub.status.idle": "2024-04-20T11:24:56.784427Z", "shell.execute_reply":
"2024-04-20T11:24:56.783451Z" }, "id": "qlC4N-LuqWWR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/2 [==============>...............] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "2/2 [==============================] - 1s 3ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "predictions:\n", " [[0.1 0.1 0.8]\n", " [0.8 0.1 0.1]\n", " [0.1 0.8 0.1]]\n" ] } ], "source": [ "examples = tf.data.Dataset.from_tensor_slices({\n", " \"f1\": [1.0, 2.0, 3.0],\n", " \"f2\": [\"cat\", \"cat\", \"bird\"]\n", " }).batch(2)\n", "\n", "predictions = manual_model.predict(examples)\n", "\n", "print(\"predictions:\\n\",predictions)" ] }, { "cell_type": "markdown", "metadata": { "id": "mxJyp1mKFPXb" }, "source": [ "Access the structure:\n", "\n", "**Note:** Because the model is serialized-and-deserialized, you need to use an alternative but equivalent form." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:56.788532Z", "iopub.status.busy": "2024-04-20T11:24:56.787846Z", "iopub.status.idle": "2024-04-20T11:24:56.810667Z", "shell.execute_reply": "2024-04-20T11:24:56.809680Z" }, "id": "IjcyMHJUFO_B" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "yggdrasil_model_path: /tmp/manual_model/assets/\n", "Input features: [\"f1\" (1; #1), \"f2\" (4; #2)]\n" ] } ], "source": [ "yggdrasil_model_path = manual_model.yggdrasil_model_path_tensor().numpy().decode(\"utf-8\")\n", "print(\"yggdrasil_model_path:\",yggdrasil_model_path)\n", "\n", "inspector = tfdf.inspector.make_inspector(yggdrasil_model_path)\n", "print(\"Input features:\", inspector.features())" ] }, { "cell_type": "markdown", "metadata": { "id": "muW1hgmotx8J" }, "source": [ "And of course, you can plot this manually constructed model:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:24:56.814141Z", "iopub.status.busy": "2024-04-20T11:24:56.813597Z", "iopub.status.idle": "2024-04-20T11:24:56.820910Z", "shell.execute_reply": "2024-04-20T11:24:56.820084Z" }, "id": "bqahDVg3t1xM" }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "
\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tfdf.model_plotter.plot_model_in_colab(manual_model)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "advanced_colab.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 0 }