
We know transfer learning is powerful: it can reduce training time, improve model generalization, reduce training memory, and more.
So in this article, I will demonstrate how to build a GAN (CycleGAN) with transfer learning from scratch.
With transfer learning applied to CycleGAN:
- You only need one generator.
- No input labels are needed: the model learns to output the opposite-style image by itself.
- Reduced VRAM usage.
- Only a few hours of training are needed.
- Much more stable GAN results.
You will need a GPU with at least 16GB of VRAM for best results.
However, you can adjust the batch size or the CNN filter counts to reduce memory usage and make it work on GPUs with 12GB or 8GB of VRAM.
You can find code in my GitHub:
https://github.com/Seachaos/Tree.Rocks/blob/main/TransferLearningGAN/TransferLearningCycleGAN.ipynb
CycleGAN idea
CycleGAN uses two sets of images, which we call A and B. For example, set A could be many images of gray cats and set B many images of tabby cats.
The deep learning model learns the style and features of each image set and translates between them, so it can convert images from A to B and back from B to A.
For more detail, see the CycleGAN tutorial on the TensorFlow website.
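To make the cycle idea concrete, here is a minimal sketch of the cycle-consistency loss. The two generators here are placeholder identity functions, purely for illustration; a real CycleGAN would use trained networks.
import tensorflow as tf

# Two hypothetical generators: g_ab maps A -> B, g_ba maps B -> A.
# Identity placeholders stand in for real networks in this sketch.
g_ab = lambda x: x
g_ba = lambda x: x

x_a = tf.random.uniform((1, 128, 128, 3))         # an image from set A
x_ab = g_ab(x_a)                                  # translate A -> B
x_aba = g_ba(x_ab)                                # translate back B -> A
cycle_loss = tf.reduce_mean(tf.abs(x_a - x_aba))  # cycle-consistency: should be small
print(float(cycle_loss))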
Transfer Learning Idea
Transfer learning reuses an existing deep learning model (a pre-trained model) for another task. It’s like extracting the features we need from an existing model and feeding them into another model.
With this idea, we can use image classifier models (such as VGG16) to build GAN models: the GAN benefits from the classifier’s learned features, so it can more easily understand the input and output images.
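As a minimal sketch of this idea (we build the actual feature extractor in section 1 below), you can load VGG16 pre-trained on ImageNet without its classification head, freeze it, and reuse its convolutional features:
import tensorflow as tf

# Load VGG16 pre-trained on ImageNet, without the classifier head.
vgg = tf.keras.applications.VGG16(input_shape=(128, 128, 3), include_top=False)
vgg.trainable = False  # we only extract features; we don't retrain VGG16

features = vgg(tf.zeros((1, 128, 128, 3)))
print(features.shape)  # (1, 4, 4, 512): the final feature map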
Combine together: Transfer Learning with CycleGAN
Now we leverage transfer learning by using the pre-trained model as the input layers for the GAN, so the model can easily recognize the input image and output the corresponding image without any labels.
Let’s take a closer look at our model.
It feeds the input image into the classifier model to get “x_cmd”, which captures the source image information, like letting the model see the entire image at once. This information is then combined with a U-Net-like convolutional neural network (CNN) to produce the output.
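As a rough illustration of that gating idea (the shapes here are hypothetical, chosen only to show the broadcasting), a global vector can scale the channels of a feature map:
import tensorflow as tf

# Hypothetical shapes, for illustration only
feat = tf.random.uniform((1, 16, 16, 64))          # a CNN feature map
cmd = tf.random.uniform((1, 64))                   # a global image descriptor
gate = tf.reshape(tf.sigmoid(cmd), (1, 1, 1, 64))  # per-channel gate in (0, 1)
gated = feat * gate                                # broadcast channel-wise
print(gated.shape)  # (1, 16, 16, 64)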
We use TensorFlow, an image input size of 128×128, and horse-to-zebra images as our example.
You need to install TensorFlow Datasets for the training data (skip this if it’s already installed):
pip install tensorflow-datasets
First things first, let’s import what we need:
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import trange, tqdm
import random
Prepare Datasets
We use horse2zebra as our training dataset.
This section includes a lot of code for data preparation and augmentation.
Let’s get image sets A and B (horse and zebra):
dataset, dataset_info = tfds.load('cycle_gan/horse2zebra', with_info=True, as_supervised=True)
train_a, train_b = dataset['trainA'], dataset['trainB']
test_a, test_b = dataset['testA'], dataset['testB']
Now, let’s set up some variables we will use.
If your GPU has less than 16GB of VRAM, try reducing the batch size to fit your GPU (you may also need to train for more epochs).
batch_size = 32  # set to 16 or less if you don't have enough VRAM
img_size = 128
big_img_size = 192
LR = 0.00012
- big_img_size: used for image augmentation.
- LR: learning rate
Now we extract the datasets for training and testing:
def _process_img(image, label):
    image = tf.image.resize(image, (big_img_size, big_img_size))
    image = (image / 127.5) - 1.0
    return image, label

def prepare_data(data, b=batch_size):
    return data \
        .cache() \
        .map(_process_img, num_parallel_calls=tf.data.AUTOTUNE) \
        .shuffle(b) \
        .batch(b)
ds_train_a, ds_train_b = prepare_data(train_a), prepare_data(train_b)
ds_test_a, ds_test_b = prepare_data(test_a), prepare_data(test_b)
x_train_sets = [
    tf.concat([a[0] for a in ds_train_a], axis=0),
    tf.concat([b[0] for b in ds_train_b], axis=0),
]
x_test_sets = [
    tf.concat([a[0] for a in ds_test_a], axis=0),
    tf.concat([b[0] for b in ds_test_b], axis=0),
]
print('x_train_all: ', sum([s.shape[0] for s in x_train_sets]), x_train_sets[0].numpy().min(), x_train_sets[0].numpy().max())
print('x_test_all: ', sum([s.shape[0] for s in x_test_sets]), x_test_sets[0].numpy().min(), x_test_sets[0].numpy().max())
You should see output like this:
x_train_all: 2401 -1.0 1.0
x_test_all: 260 -1.0 1.0
Then, we need image augmentation for training, so there are two functions we will use: “get_x_train” and “get_x_test”.
These functions will give us image sets A and B for training.
def _rand_pick(data, augment=True):
    idx = np.random.choice(range(len(data)), size=batch_size, replace=False)
    x = tf.gather(data, idx, axis=0)
    if augment:
        cx = random.uniform(1.0, 1.5)
        cy = random.uniform(1.0, 1.5)
        x = tf.image.random_crop(x, size=(batch_size, int(img_size * cx), int(img_size * cy), 3))
        x = tf.image.random_flip_left_right(x)
    x = tf.image.resize(x, (img_size, img_size))
    return x

def get_x_train():
    xa = _rand_pick(x_train_sets[0])
    xb = _rand_pick(x_train_sets[1])
    return xa, xb

def get_x_test():
    xa = _rand_pick(x_test_sets[0], augment=False)
    xb = _rand_pick(x_test_sets[1], augment=False)
    return xa, xb
Let’s verify that “get_x_train” works:
# Verify "get_x_train" output
def cvtImg(x):
    return (x + 1.0) / 2.0

def show(x, S=12):
    x = cvtImg(x)
    plt.figure(figsize=(15, 3))
    for i in range(min(len(x), S)):
        plt.subplot(1, S, i + 1)
        plt.imshow(x[i])
        plt.axis('off')
    plt.show()

for _ in range(1):
    xa, xb = get_x_train()
    xa = xa.numpy()
    print(xa.min(), xa.max(), xa.shape)
    show(xa)
    show(xb.numpy())
You should see image output like:
1. Build Model — Transfer Learning
Now we start building the GAN model, but first we need the image input layers. Let’s extract them from the VGG16 model.
We will use the layers “block2_conv2”, “block3_conv3”, “block4_conv3”… from the VGG16 model.
These have output sizes of 64×64, 32×32, 16×16, and so on.
You can use “base_model.summary()” to see more detail about the VGG16 model.
base_model = tf.keras.applications.VGG16(input_shape=(img_size, img_size, 3), include_top=False)
x = x_input = base_model.input
outputs = [
    'block2_conv2',
    'block3_conv3',
    'block4_conv3',
    'block5_conv1',
    'block5_pool',
]
x_output = [base_model.get_layer(n).output for n in outputs]
base_model = tf.keras.models.Model(x_input, x_output)
base_model.trainable = False
# base_model.summary()  # if you want to see more detail about VGG16
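A quick optional sanity check: the five extracted layers should give 64×64, 32×32, 16×16, 8×8, and 4×4 feature maps.
# Optional: confirm the feature map sizes from base_model
for t in base_model(tf.zeros((1, img_size, img_size, 3))):
    print(t.shape)  # (1, 64, 64, 128), (1, 32, 32, 256), ...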
2. Build Model — Generator
We use GELU as our activation function. For convenience, we’ll define an “act” helper that applies normalization followed by the activation.
act_name = 'gelu'

def act(x):
    x = layers.LayerNormalization()(x)
    x = layers.Activation(act_name)(x)
    return x
This is the layer function for the generator model.
It takes “x_cmd”, which encodes what the model observed in the input image, and uses it to decide what the output should look like.
def conv_with_cmd(x_img_input, x_cmd, f=64, sp=4):
    x = layers.Dense(128)(x_cmd)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(act_name)(x)
    x = layers.Dense(f)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('sigmoid')(x)
    x_g = layers.Reshape((1, 1, f))(x)
    # ---
    x = layers.Conv2D(f, kernel_size=3, padding='same')(x_img_input)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(act_name)(x)
    x = x * x_g
    return x
Now we build the generator model.
- x_input is the source image; it goes through base_model (VGG16), which outputs [x64, x32, x16, x8, x4].
- “x_cmd” comes from the last VGG16 output (x4, 4×4 pixels); GlobalMaxPool2D and Dense layers extract its information.
- The entire model is like a U-Net: it uses “UpSampling2D” and “Concatenate” with the base_model outputs, going from x4 up to x8, x8 up to x16, and so on until the output size, each step mixing in the x_cmd information.
- If you don’t have enough VRAM, you can try reducing the CNN filter counts, but the output quality may suffer.
def create_gen_model():
    # img input
    x_input = layers.Input(shape=(img_size, img_size, 3))

    # load base model
    x_base_out = base_model(x_input)
    [x64, x32, x16, x8, x4] = x_base_out

    # x_cmd
    x = x4
    x = layers.Conv2D(256, kernel_size=3, padding='same')(x)
    x = act(x)
    x = layers.GlobalMaxPool2D()(x)
    x = layers.Dense(128)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(act_name)(x)
    x_cmd = x

    # GAN up
    x = conv_with_cmd(x4, x_cmd, f=512)
    # if you don't have enough VRAM, try reducing the filters
    for i, (x_cat, f) in enumerate([
        (x8, 512),
        (x16, 384),
        (x32, 256),
        (x64, 256),
        (x_input, 256),
    ]):
        # upsample, merge the skip connection, and mix in x_cmd
        x = layers.UpSampling2D()(x)
        x = layers.Concatenate()([x, x_cat])
        x = conv_with_cmd(x, x_cmd, f=f)

    # final output
    x = layers.Conv2D(3, kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('tanh')(x)
    return tf.keras.models.Model(x_input, x)

gen = create_gen_model()
# gen.summary()  # if you want to see more detail about the model
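A quick optional sanity check: the generator should map an image to an image of the same size, with tanh output in [-1, 1].
# Optional: the generator output should match the input image shape
out = gen(tf.zeros((1, img_size, img_size, 3)))
print(out.shape)  # (1, 128, 128, 3)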
3. Build Model — Discriminator
Let’s make the discriminator model.
- It uses the same base_model.
- It uses x8 and x4 from base_model.
- It has a softmax output as the classifier output.
We will explain the output later.
def create_dis_model():
    x = x_input = layers.Input(shape=(img_size, img_size, 3))
    [x64, x32, x16, x8, x4] = base_model(x_input)
    x = x8
    x = layers.Conv2D(512, kernel_size=3, padding='same')(x)
    x = act(x)
    x = layers.MaxPool2D()(x)
    x = layers.Concatenate()([x, x4])
    x = layers.Conv2D(512, kernel_size=3, padding='same')(x)
    x = act(x)
    x = layers.GlobalMaxPool2D()(x)
    x = layers.Dense(384)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(act_name)(x)
    x = layers.Dense(128)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(act_name)(x)
    x = layers.Dense(4)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('softmax')(x)
    return tf.keras.models.Model(x_input, x)

dis = create_dis_model()
# dis.summary()  # if you want to see more detail about the model
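A quick optional sanity check: the discriminator should output 4 class probabilities that sum to 1.
# Optional: the discriminator outputs a softmax over 4 classes
p = dis(tf.zeros((1, img_size, img_size, 3)))
print(p.shape, float(tf.reduce_sum(p)))  # (1, 4) 1.0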
Here we define the output labels for training.
The discriminator has 4 output classes:
- false A, label 0 (fake image A)
- false B, label 1 (fake image B)
- true A, label 2 (real image A)
- true B, label 3 (real image B)
Each label array has batch_size entries for training.
y_false_a = np.zeros(batch_size)
y_false_b = np.full_like(y_false_a, 1)
y_true_a = np.full_like(y_false_a, 2)
y_true_b = np.full_like(y_false_a, 3)
And here are the optimizers for the dis and gen models; we use AdamW.
opt_gen = tf.keras.optimizers.AdamW(learning_rate=LR)
opt_dis = tf.keras.optimizers.AdamW(learning_rate=LR)
4. Training Model — Discriminator
Now we have the generator model (gen) and the discriminator model (dis).
We define the `train_dis` and `train_gen` functions first, then a train function that runs everything.
Let’s see the discriminator training code first:
@tf.function
def _train_dis(x, y):
    with tf.GradientTape() as tape:
        y_p = dis(x)
        loss = tf.reduce_mean(tf.losses.sparse_categorical_crossentropy(y, y_p))
    g = tape.gradient(loss, dis.trainable_variables)
    opt_dis.apply_gradients(zip(g, dis.trainable_variables))
    return loss

def train_dis():
    dis.trainable = True
    gen.trainable = False
    base_model.trainable = False
    xa, xb = get_x_train()
    xa_fake = gen.predict(xb, verbose=False)
    xb_fake = gen.predict(xa, verbose=False)
    loss_a = _train_dis(xa, y_true_a) + _train_dis(xa_fake, y_false_a)
    loss_b = _train_dis(xb, y_true_b) + _train_dis(xb_fake, y_false_b)
    return float(loss_a), float(loss_b)

train_dis()
Explanation:
- We need to disable every model except the discriminator for this step, so we set: dis.trainable = True, gen.trainable = False, base_model.trainable = False.
- Get xa and xb as training images (image sets A and B).
- We use the generator to get `xa_fake` via `gen.predict(xb, verbose=False)` (and `xb_fake` from `xa`).
- Feed `_train_dis(xa, y_true_a)` to teach the discriminator that source image A is a real image (y_true_a).
- Feed `_train_dis(xa_fake, y_false_a)` to teach the discriminator that the generated image A is fake (y_false_a).
- xb and xb_fake follow the same steps.
5. Training Model — Generator
Now, the code for training the generator:
@tf.function
def _train_gen_cycle(x_real, y_t, y_f):
    with tf.GradientTape(persistent=True) as tape:
        x_fake = gen(x_real)  # forward

        # discriminator
        y_p = dis(x_fake)
        loss_dis = tf.losses.sparse_categorical_crossentropy(y_t, y_p)

        # revert
        x_revert = gen(x_fake)
        loss_revert = tf.losses.mse(x_real, x_revert)

        loss = tf.reduce_mean(loss_dis) + tf.reduce_mean(loss_revert)
    g = tape.gradient(loss, gen.trainable_variables)
    g = zip(g, gen.trainable_variables)
    opt_gen.apply_gradients(g)
    return loss

def train_gen():
    gen.trainable = True
    dis.trainable = False
    base_model.trainable = False
    xa, xb = get_x_train()
    loss_a = _train_gen_cycle(xa, y_true_b, y_true_a)
    loss_b = _train_gen_cycle(xb, y_true_a, y_true_b)
    return float(loss_a), float(loss_b)

train_gen()
Explanation:
- As in the discriminator training, we disable every model except the generator, setting: gen.trainable = True, dis.trainable = False, base_model.trainable = False.
- Get xa and xb as training images (image sets A and B).
- The function _train_gen_cycle takes (source image, y true, y false).
Inside the `_train_gen_cycle` function:
- We use the gen model to generate a fake image, feed it to dis (the discriminator) to get the output (y_p), and backpropagate to the gen model as if the fake were a true image (cross-entropy loss between y_t and y_p).
- We use gen to revert the image back to itself, for example A->B then B->A, which should match the original input (MSE loss).
6. Preview — Before total training
Let’s preview the output before we start:
def _preview(x_real, title=None):
    x_fake = gen.predict(x_real, verbose=0)
    x_real = cvtImg(x_real.numpy())
    x_fake = cvtImg(x_fake)

    plt.figure(figsize=(25, 5))
    if title:
        plt.suptitle(title)
    s = min(batch_size, 9)
    for i in range(s):
        plt.subplot(2, s, i + 1)
        plt.axis('off')
        plt.imshow(x_real[i])
        plt.subplot(2, s, i + 1 + s)
        plt.axis('off')
        plt.imshow(x_fake[i])
    plt.show()

def preview(useTest=True):
    if useTest:
        xa, xb = get_x_test()
    else:
        xa, xb = get_x_train()
    _preview(xa[:9], 'A -> B')
    _preview(xb[:9], 'B -> A')

preview()
You should see output like this (results may differ):
7. Total Training
Here we run the full training:
def train():
    bar = trange(200)
    for _ in bar:
        lda, ldb = train_dis()
        lga, lgb = train_gen()
        msg = f'gen: {lga:.5f}, {lgb:.5f} | dis: {lda:.5f}, {ldb:.5f}'
        bar.set_description(msg)

def go():
    for i in trange(50):
        train()
        if i % 5 == 0:
            preview()
        opt_dis.learning_rate = opt_dis.learning_rate * 0.98
        opt_gen.learning_rate = opt_gen.learning_rate * 0.98
        lg = opt_gen.learning_rate.numpy()
        ld = opt_dis.learning_rate.numpy()
        print(f'run: {i}')
        print(f'LR gen: {lg:.7f}')
        print(f'LR dis: {ld:.7f}')

go()
preview()
We reduce the learning rate during training.
It may take a few hours depending on your GPU.
During training, you may see some output for preview:
You may get some results like:
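If you want to keep the trained generator for later use, here is a minimal sketch (the file name is just an example):
# Save the trained generator, then translate one test image with it
gen.save('transfer_cyclegan_gen.keras')

xa, _ = get_x_test()
x_fake = gen.predict(xa[:1], verbose=0)
plt.imshow(cvtImg(x_fake[0]))
plt.axis('off')
plt.show()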
That’s all 😀