asdad
🧩 Syntax:
# Training loop
def _discriminator_step(grayscale_batch, color_batch):
    """Apply one discriminator update on a single batch and return its loss.

    The discriminator is conditioned on the grayscale input: both its real
    and its fake input are the grayscale batch concatenated with a color
    image along the last axis (presumably the channel axis - confirm the
    data is channels-last).
    """
    # Generated outside the tape below, so this forward pass is not recorded
    # and only discriminator weights receive gradients in this step.
    fake_images = generator(grayscale_batch, training=True)
    fake_input = tf.concat([grayscale_batch, fake_images], axis=-1)
    real_input = tf.concat([grayscale_batch, color_batch], axis=-1)

    with tf.GradientTape() as disc_tape:
        real_output = discriminator(real_input, training=True)
        fake_output = discriminator(fake_input, training=True)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(
        zip(gradients, discriminator.trainable_variables)
    )
    return disc_loss


def _generator_step(grayscale_batch):
    """Apply one generator update on a single batch and return its loss."""
    with tf.GradientTape() as gen_tape:
        # Regenerate under the tape so gradients flow back into the generator.
        # The discriminator is evaluated with its current weights, i.e. after
        # any discriminator update that ran earlier for this batch.
        fake_images = generator(grayscale_batch, training=True)
        fake_input = tf.concat([grayscale_batch, fake_images], axis=-1)
        fake_output = discriminator(fake_input, training=True)
        gen_loss = generator_loss(fake_output)

    gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(
        zip(gradients, generator.trainable_variables)
    )
    return gen_loss


def train(dataset, epochs, sample_images=None):
    """Train the GAN by alternating discriminator and generator updates.

    For every batch the discriminator is updated first, then the generator.
    After each epoch, sample outputs are written via
    ``generate_and_save_images``.

    Args:
        dataset: Iterable yielding ``(grayscale_batch, color_batch)`` pairs.
            It is iterated once per epoch, so it must be re-iterable
            (e.g. a ``tf.data.Dataset``); a one-shot iterator would be
            exhausted after the first epoch.
        epochs: Number of passes over ``dataset``.
        sample_images: Inputs handed to ``generate_and_save_images`` after
            each epoch. Defaults to ``test_x[:9]`` from module scope, which
            was previously hard-coded.
    """
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        for grayscale_batch, color_batch in dataset:
            _discriminator_step(grayscale_batch, color_batch)
            _generator_step(grayscale_batch)

        # Resolve the default lazily so the module-level ``test_x`` is only
        # required when no explicit samples were supplied.
        samples = test_x[:9] if sample_images is None else sample_images
        generate_and_save_images(generator, epoch + 1, samples)