TensorFlow Custom Loop
A pattern I use for pretty progress bars, custom logging, and metric handling in TensorFlow.
'''
This outlines the TensorFlow custom training loop in what I think is its most general form, in four steps:
1. Define Model
2. Define Metrics, Optimizer and Losses
3. Define train and test (validation) functions
4. Write training loop
'''
import wandb
import yaml
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import enlighten
import logging
import time
from dataclasses import dataclass
# STEP 0 - Set up datasets
'''
---------------------------------------------------------
STEP 1 - Define Model
Define inputs, outputs and wrap using keras
'''
# Placeholders: replace with e.g. keras.Input(shape=...) and the layer
# stack that produces the output tensor(s).
inputs = ...
outputs = ...
# Wrap the symbolic input/output tensors into a trainable functional model.
model = keras.Model(inputs=inputs, outputs=outputs)
'''
---------------------------------------------------------
STEP 2 - Define Metrics, Optimizer and Losses
Use keras.metrics.Metric and keras.optimizers
Can subclass if necessary
'''
# Placeholders: pick concrete classes, e.g.
#   keras.metrics.SparseCategoricalAccuracy(), keras.optimizers.Adam(...),
#   keras.losses.SparseCategoricalCrossentropy(from_logits=True).
# Separate metric instances for train and validation so their internal
# state never mixes across the two loops.
train_metric = keras.metrics...
val_metric = keras.metrics...
optimizer = keras.optimizers...
loss_fn = keras.losses...
'''
---------------------------------------------------------
STEP 3 - Define training and test functions
both take inputs and labels
both return a loss value
the train step invokes a gradient tape and applies the loss gradients to the weights
the test step just computes the loss value
'''
@tf.function
def train_step(inputs, labels):
    """Run one optimization step on a single batch.

    Computes the forward pass under a GradientTape, adds any layer
    regularization losses, applies the gradients with `optimizer`, and
    updates `train_metric`.

    Args:
        inputs: batch of model inputs (renamed from `input`, which
            shadowed the builtin; in-file callers pass positionally).
        labels: ground-truth targets for the batch.

    Returns:
        The scalar loss tensor for this batch (data loss + model losses).
    """
    # Record the forward pass so gradients can be taken w.r.t. weights.
    with tf.GradientTape() as tape:
        pred = model(inputs, training=True)
        loss_value = loss_fn(labels, pred)
        # Include layer-level regularization losses (e.g. kernel_regularizer).
        loss_value += sum(model.losses)
    # Gradient of the total loss w.r.t. every trainable weight.
    grads = tape.gradient(loss_value, model.trainable_weights)
    # Apply the update; zip pairs each gradient with its weight.
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_metric.update_state(labels, pred)
    return loss_value
@tf.function
def test_step(inputs, labels):
    """Run one evaluation step on a single batch (no weight updates).

    Args:
        inputs: batch of model inputs (renamed from `input`, which
            shadowed the builtin; in-file callers pass positionally).
        labels: ground-truth targets for the batch.

    Returns:
        The scalar loss tensor for this batch. (The step-3 note says both
        step functions return a loss value; the original returned nothing.)
    """
    pred = model(inputs, training=False)
    # Compute the loss so validation loss can be logged, matching
    # train_step's contract; no gradients are taken here.
    loss_value = loss_fn(labels, pred)
    val_metric.update_state(labels, pred)
    return loss_value
'''
---------------------------------------------------------
STEP 4 - Training/Validation Loop
Remember to reset metric states
'''
#-------------------------------CONFIGURATION---------------------------------#
@dataclass
class ModelConf:
    """Typed container for the hyperparameters read from conf.yaml.

    A dataclass gives attribute access, a readable repr, and positional
    construction straight from the parsed YAML mapping.
    """

    epochs: int           # number of full passes over the training set
    batch_size: int       # samples per optimization step
    learning_rate: float  # optimizer step size
    dropout_rate: float   # dropout probability used when building the model
# Static run metadata; the YAML file contents are merged in below.
# (A bare `...` is not valid inside a dict literal — add further static
# defaults as real key/value pairs.)
wandb_config = {
    "entity": "aadi350",
    "project": "urban_heat_index",
    "model_name": "vit",
    # ... further static defaults go here
}

# safe_load is the shorthand for yaml.load(f, Loader=yaml.SafeLoader):
# it parses plain YAML without instantiating arbitrary Python objects.
with open('path/to/conf.yaml', 'r') as f:
    conf_file = yaml.safe_load(f)

# Build the typed config object once. The original passed a fifth
# positional `...` (a TypeError for a 4-field dataclass) and then
# redundantly reassigned epochs/batch_size/learning_rate afterwards;
# the constructor call alone covers all of that.
m = ModelConf(
    conf_file['epochs'],
    conf_file['batch_size'],
    conf_file['learning_rate'],
    conf_file['dropout_rate'],
)

# Record the full (static + file) configuration with the experiment run.
wandb_config.update(conf_file)
wandb.init(config=wandb_config)
#-----------------------------END CONFIGURATION-------------------------------#
# Batch interval for periodic metric logging.
# NOTE(review): this is never referenced by the loop below — wire it in
# or remove it; TODO confirm intent.
log_batch = 10
# Set up logging
# Dedicated logger for metric output so it can be filtered/routed
# independently of the root logger.
mlog = logging.getLogger("metric_logger")
handler = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
handler.setFormatter(formatter)
mlog.addHandler(handler)
mlog.setLevel(logging.DEBUG)
# set up progress bar for epochs
manager = enlighten.get_manager()
# Outer progress counter: one tick per completed epoch.
epoch_counter = manager.counter(
total=m.epochs, desc="Epoch", unit="epochs", color="green"
)
# Pinned status line used to display the best validation metric so far.
status_bar = manager.status_bar(
"Best metrics", color="white_on_blue", justify=enlighten.Justify.CENTER
)
# Track the best validation metric across epochs so we only checkpoint
# on improvement. (The original compared `best_val` before it was ever
# assigned, which raises NameError on the first epoch.)
best_val = float("-inf")

for epoch in range(m.epochs):
    # progress bar for the training steps of this epoch
    step_counter = manager.counter(
        # NOTE(review): m.epochs // m.batch_size looks wrong as a step
        # count — presumably this should be the number of batches per
        # epoch (e.g. len(train_dataset)). Kept as-is; TODO confirm.
        # (The original was missing the comma after this argument.)
        total=m.epochs // m.batch_size,
        desc="Train Step",
        unit="steps",
        leave=False,
        color="bright_green",
    )

    # ---------------------TRAIN------------------------#
    for step, (in_batch_train, label_batch_train) in enumerate(train_dataset):
        loss_value = train_step(in_batch_train, label_batch_train)
        step_counter.update()
    train_acc = train_metric.result()
    # Last batch's loss stands in for the epoch's training loss.
    train_loss = loss_value
    mlog.info("Training acc over epoch: %.4f" % (float(train_acc),))
    # Reset so next epoch's metric starts from a clean state.
    train_metric.reset_states()
    step_counter.close()

    # ---------------------VAL---------------------------#
    # progress bar for validation step
    step_counter = manager.counter(
        total=m.epochs // m.batch_size,
        desc="Validation Step",
        unit="steps",
        leave=False,
        color="bright_yellow",
    )
    # validation loop (the original never advanced this counter)
    for step, (in_batch_val, label_batch_val) in enumerate(val_dataset):
        test_step(in_batch_val, label_batch_val)
        step_counter.update()
    val_acc = val_metric.result()
    val_metric.reset_states()
    step_counter.close()

    # -------------------STATUS-----------------------------#
    if best_val < val_acc:
        best_val = val_acc
        status_bar.update(f"Best validation metric: {best_val}, epoch: {epoch}")
        # tf.saved_model.save requires the object to save as its first
        # argument; the original call omitted the model entirely.
        tf.saved_model.save(model, 'path/to/model')
    epoch_counter.update()

    # -------------------WANDB------------------------------#
    wandb.log(
        {
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "val_accuracy": val_acc,
        }
    )