TensorFlow Model Training Using GradientTape | by Rashida Nasrin Sucky | Oct, 2023


Photo by Sivani Bandaru on Unsplash

Use of GradientTape to Update the Weights

TensorFlow is arguably the most popular library for deep learning. I have written many tutorials on TensorFlow and continue to do so. TensorFlow is a well-organized, easy-to-use package where you do not need to worry too much about model development and training; the package itself takes care of most of the work. That is probably why it has become so popular in industry. At the same time, it is sometimes nice to have control over what happens behind the scenes. That control gives you a lot of power to experiment with your models. And if you are a job seeker, this extra knowledge may give you an edge.

Previously, I wrote an article on how to develop custom activation functions, layers, and loss functions. In this article, we will see how you can train a model manually and update the weights yourself. But don't worry: you don't have to relearn differential calculus. TensorFlow's own GradientTape (tf.GradientTape) computes the gradients for you.
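To see what the tape actually does, here is a minimal sketch on a single scalar variable rather than a full model. The toy loss (w - 1)^2 and the learning rate of 0.1 are illustrative choices, not values from this article. The tape records the computation, `tape.gradient` returns d(loss)/dw, and `assign_sub` performs one plain gradient-descent step by hand.

```python
import tensorflow as tf

# A single trainable variable standing in for one model weight
w = tf.Variable(3.0)
learning_rate = 0.1  # illustrative value

with tf.GradientTape() as tape:
    # Operations on trainable variables inside this block are recorded
    loss = (w - 1.0) ** 2  # toy loss, minimized at w = 1

# The tape applies the chain rule for us: d(loss)/dw = 2 * (w - 1)
grad = tape.gradient(loss, w)
print(grad.numpy())  # 4.0

# One manual gradient-descent step: w <- w - learning_rate * grad
w.assign_sub(learning_rate * grad)
print(w.numpy())
```

Training a real network follows the same pattern, only with many weights and a loss computed from the model's predictions.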

If GradientTape() is totally new to you, please feel free to check out these exercises, which show how GradientTape() works: Introduction to GradientTape in TensorFlow — Regenerative (regenerativetoday.com)

Data Preparation

In this article, we build a simple classification model in TensorFlow using GradientTape(). Please download the dataset from this link:
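For orientation, a single manual training step for a binary classifier might look roughly like the sketch below. It is a sketch under assumptions, not the model built in this article: the two-layer network, the input width of 4, the Adam optimizer and the synthetic data are all illustrative stand-ins. The pattern is what matters: run the forward pass inside the tape, ask the tape for gradients with respect to `model.trainable_weights`, and hand them to the optimizer's `apply_gradients`.

```python
import numpy as np
import tensorflow as tf

# Hypothetical tiny binary classifier; input width and layer sizes are illustrative
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),  # outputs a probability
])
loss_fn = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)


def train_step(x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)  # forward pass is recorded on the tape
        loss = loss_fn(y, y_pred)
    # Gradients of the loss with respect to every trainable weight
    grads = tape.gradient(loss, model.trainable_weights)
    # The optimizer applies its update rule (here Adam) to each weight
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss


# Synthetic data (assumption): label is 1 when the four features sum to a positive number
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4)).astype("float32")
y = (x.sum(axis=1, keepdims=True) > 0).astype("float32")

losses = []
for epoch in range(100):
    losses.append(float(train_step(x, y)))
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Once the loop works eagerly, decorating `train_step` with `@tf.function` is a common way to speed it up.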

Heart Failure Prediction Dataset (kaggle.com)

This dataset has an open database license.

These are the necessary imports:

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import itertools
from tqdm import tqdm
import tensorflow_datasets as tfds

Creating the DataFrame with the dataset:

import pandas as pd
df = pd.read_csv('heart.csv')
df
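Before any training, the DataFrame needs to become numeric tensors. The following is one possible preparation step, sketched under assumptions: it treats `HeartDisease` as the 0/1 label column, one-hot encodes the text columns with `pd.get_dummies`, standardizes the features with training-set statistics, and wraps the splits in `tf.data` pipelines. The function name `prepare_data` and its defaults are hypothetical, and the author's own preprocessing may differ.

```python
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split


def prepare_data(df, target="HeartDisease", test_size=0.2, batch_size=32, seed=42):
    """Hypothetical helper: turn the raw DataFrame into batched tf.data datasets."""
    # Assumption: the label column is called "HeartDisease" and holds 0/1 values
    features = pd.get_dummies(df.drop(columns=[target]), dtype="float32")  # one-hot text columns
    X = features.to_numpy(dtype="float32")
    y = df[target].to_numpy(dtype="float32").reshape(-1, 1)

    # Stratify on the label so both splits keep a similar class balance
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y.ravel()
    )

    # Standardize with training-set statistics only, so no test information leaks in
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0) + 1e-7
    X_train = (X_train - mean) / std
    X_test = (X_test - mean) / std

    train_ds = (
        tf.data.Dataset.from_tensor_slices((X_train, y_train))
        .shuffle(len(X_train), seed=seed)
        .batch(batch_size)
    )
    test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(batch_size)
    return train_ds, test_ds, X_train.shape[1]
```

With the DataFrame loaded above, `train_ds, test_ds, n_features = prepare_data(df)` would return batched training and test datasets plus the number of input features for the model's `Input` layer.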
