# Training loop
def _train_step(grayscale_batch, color_batch):
    """Run one alternating discriminator-then-generator update on a batch.

    Args:
        grayscale_batch: Generator input batch; also concatenated onto the
            discriminator inputs as conditioning.
        color_batch: Real target batch paired with ``grayscale_batch``.

    Returns:
        ``(gen_loss, disc_loss)`` as computed for this batch.
    """
    # Fakes for the discriminator step are produced outside any tape: only the
    # discriminator's variables are differentiated below, so the generator's
    # forward pass does not need to be recorded here.
    fake_images = generator(grayscale_batch, training=True)

    # Pair the grayscale input with generated / real color images by
    # concatenating along the last axis before scoring them.
    fake_input = tf.concat([grayscale_batch, fake_images], axis=-1)
    real_input = tf.concat([grayscale_batch, color_batch], axis=-1)

    # Discriminator update.
    with tf.GradientTape() as disc_tape:
        real_output = discriminator(real_input, training=True)
        fake_output = discriminator(fake_input, training=True)
        disc_loss = discriminator_loss(real_output, fake_output)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(
        zip(disc_grads, discriminator.trainable_variables)
    )

    # Generator update. The generator is run again *inside* this tape so its
    # variables receive gradients; the fakes are scored by the discriminator
    # that was just updated above (alternating, not simultaneous, updates).
    with tf.GradientTape() as gen_tape:
        fake_images = generator(grayscale_batch, training=True)
        fake_input = tf.concat([grayscale_batch, fake_images], axis=-1)
        fake_output = discriminator(fake_input, training=True)
        gen_loss = generator_loss(fake_output)
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(
        zip(gen_grads, generator.trainable_variables)
    )

    return gen_loss, disc_loss


def train(dataset, epochs, sample_inputs=None):
    """Train ``generator`` and ``discriminator`` adversarially.

    For every batch, the discriminator is updated first, then the generator.
    After each epoch, sample outputs are written via
    ``generate_and_save_images``.

    Args:
        dataset: Iterable of ``(grayscale_batch, color_batch)`` pairs. It is
            iterated once per epoch, so it must be re-iterable (e.g. a
            ``tf.data.Dataset`` or a list), not a one-shot generator --
            otherwise every epoch after the first sees no batches.
        epochs: Number of passes over ``dataset``.
        sample_inputs: Inputs passed to ``generate_and_save_images`` after
            each epoch. Defaults to ``test_x[:9]`` (the module-level
            ``test_x``), which matches the previous hard-coded behavior.
    """
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        for grayscale_batch, color_batch in dataset:
            _train_step(grayscale_batch, color_batch)

        # Resolve the default lazily so ``test_x`` is only required when no
        # explicit samples are given (and only once an epoch has run).
        samples = test_x[:9] if sample_inputs is None else sample_inputs
        # Generate and save sample images
        generate_and_save_images(generator, epoch + 1, samples)