21. Image Super-Resolution and Enhancement with SRGAN in TensorFlow

21.1. Overview

In this lesson, we will explore how to apply Super-Resolution Generative Adversarial Networks (SRGAN) to enhance marine images. Super-resolution is valuable in marine science for recovering detail in images taken in challenging environments, such as low-visibility underwater scenes, and in images captured by remote sensing devices.

21.1.1. Learning Objectives

By the end of this section, you will:

  • Understand the importance of super-resolution in marine science applications.

  • Learn how to prepare and load a dataset for super-resolution tasks.

  • Implement and train an SRGAN model using TensorFlow.

  • Understand the theory behind SRGAN and why it uses two convolutional networks.

  • Evaluate and interpret the results of super-resolution on marine images.

  • Reflect on the evaluation metrics and discuss the applicability of SRGANs in marine contexts.


21.2. Marine Science Applications of Super-Resolution

Marine images often lack fine detail. Water absorbs and scatters light, which blurs and attenuates underwater scenes, and underwater cameras have limited resolution. Enhancing these images is valuable for tasks such as:

  • Identifying marine species and habitats.

  • Monitoring coral reef health.

  • Detecting illegal fishing activities.

  • Improving navigation and obstacle avoidance for underwater vehicles.

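21.3. Downloading the Dataset

We will use a dataset archive called digitalstill.zip, which contains low-resolution marine images and their corresponding high-resolution versions. The code in this lesson expects the archive at /content/digitalstill.zip (the default working directory in Google Colab); adjust the paths if you work in a different environment.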

# Import necessary libraries
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, PReLU, Add, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.applications import VGG19
import numpy as np
import matplotlib.pyplot as plt
import zipfile
import os
from PIL import Image

21.4. Preparing the Environment and Data

21.4.1. Extracting the Dataset

First, extract the contents of digitalstill.zip.

# Path to the uploaded dataset
zip_file = '/content/digitalstill.zip'

# Extract the zip file
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
    zip_ref.extractall('/content/digitalstill/')
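If only high-resolution images were available, the low-resolution inputs would have to be synthesized. The SRGAN paper creates them by bicubic downsampling with a factor of 4. The NumPy sketch below uses a simpler box filter instead, averaging each 4×4 block. The helper name box_downsample is hypothetical.

```python
import numpy as np

def box_downsample(hr, factor=4):
    """Downsample an (H, W, C) image by averaging non-overlapping factor x factor blocks.

    A simple stand-in for bicubic downsampling; H and W must be divisible by `factor`.
    """
    h, w, c = hr.shape
    if h % factor or w % factor:
        raise ValueError("height and width must be divisible by factor")
    # Split each spatial axis into (blocks, pixels-within-block) and average the block pixels
    blocks = hr.reshape(h // factor, factor, w // factor, factor, c)
    return blocks.mean(axis=(1, 3))

# Example: a 128x128 RGB image becomes 32x32, matching a 4x super-resolution setup
hr = np.random.rand(128, 128, 3).astype(np.float32)
lr = box_downsample(hr, factor=4)
print(lr.shape)  # (32, 32, 3)
```

Whatever degradation is used should resemble the real one. A model trained on clean synthetic downsampling may generalize poorly to turbid underwater imagery.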

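21.4.2. Loading and Preprocessing the Data

We load the images and prepare paired low-resolution (LR) and high-resolution (HR) training examples. For a 4× super-resolution setup, HR images are resized to 128×128 and LR images to 32×32. All pixel values are scaled to [-1, 1] to match the generator's tanh output.

# Define directories
hr_dir = '/content/digitalstill/high_resolution/'
lr_dir = '/content/digitalstill/low_resolution/'

# Get image file names (pairs are matched by sorted order, so both folders
# must contain the same images with corresponding names)
hr_images = sorted(os.listdir(hr_dir))
lr_images = sorted(os.listdir(lr_dir))

# Target sizes for 4x super-resolution: LR 32x32 -> HR 128x128
HR_SIZE = 128
SCALE = 4
LR_SIZE = HR_SIZE // SCALE

# Function to load an image, resize it, and scale pixel values to [-1, 1]
def load_image(path, size):
    img = Image.open(path).convert('RGB')  # drop alpha / expand grayscale to 3 channels
    img = img.resize((size, size))
    img = np.asarray(img, dtype=np.float32)
    return img / 127.5 - 1.0

# Load datasets as float32 arrays
X_hr = np.array([load_image(os.path.join(hr_dir, f), HR_SIZE) for f in hr_images])
X_lr = np.array([load_image(os.path.join(lr_dir, f), LR_SIZE) for f in lr_images])

print(X_hr.shape, X_lr.shape)  # expected: (N, 128, 128, 3) (N, 32, 32, 3)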
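The loader above pairs HR and LR files purely by sorted order. If one folder has an extra or missing file, including hidden files such as .DS_Store, every later pair is silently misaligned. A safer sketch matches files by their base name, assuming paired files share one. The helper pair_by_stem and the example filenames are hypothetical.

```python
import os

def pair_by_stem(hr_names, lr_names):
    """Match HR/LR files that share a base name (extension ignored).

    Returns the matched (hr, lr) pairs and the stems that appear in only one folder.
    """
    hr_map = {os.path.splitext(n)[0]: n for n in hr_names}
    lr_map = {os.path.splitext(n)[0]: n for n in lr_names}
    common = sorted(hr_map.keys() & lr_map.keys())
    pairs = [(hr_map[s], lr_map[s]) for s in common]
    unmatched = sorted(hr_map.keys() ^ lr_map.keys())  # stems present in only one folder
    return pairs, unmatched

pairs, unmatched = pair_by_stem(['a.png', 'b.png', 'c.png'], ['a.jpg', 'b.jpg', 'd.jpg'])
print(pairs)      # [('a.png', 'a.jpg'), ('b.png', 'b.jpg')]
print(unmatched)  # ['c', 'd']
```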

21.5. Understanding SRGAN

Super-Resolution Generative Adversarial Networks (SRGAN) are designed to produce high-resolution images from low-resolution inputs. They consist of two main components:

  • Generator Network: Attempts to create high-resolution images from low-resolution inputs.

  • Discriminator Network: Tries to distinguish between the generated high-resolution images and real high-resolution images.

21.5.1. Why Use Two Convolutional Networks?

The use of two convolutional networks in SRGAN is rooted in the concept of Generative Adversarial Networks (GANs), where two models are trained simultaneously through an adversarial process.

  • Generator (G): The generator’s role is to generate images that are as close as possible to the real high-resolution images. It learns to map low-resolution images to high-resolution counterparts.

  • Discriminator (D): The discriminator’s role is to differentiate between the real high-resolution images and the images generated by the generator.

This adversarial setup creates a minimax game:

  • The generator tries to minimize the objective by producing images that the discriminator classifies as real, effectively “fooling” the discriminator.

  • The discriminator tries to maximize its ability to correctly classify real and generated images.

21.5.2. Theoretical Background

21.5.2.1. Generative Adversarial Networks (GANs)

GANs are a class of machine learning frameworks where two networks contest with each other in a game. Given a training set, this technique learns to generate new data with the same statistics as the training set.

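  • Objective Function:

$$
\min_G \max_D \; \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$

    Where:

    • \(G\) is the generator.

    • \(D\) is the discriminator.

    • \(x\) is a real data sample drawn from the data distribution \(p_{data}\).

    • \(z\) is a random noise vector drawn from a prior distribution \(p_z\).

In SRGAN, the generator receives a low-resolution image \(I^{LR}\) instead of random noise, and the real samples are high-resolution images \(I^{HR}\). The objective becomes:

$$
\min_G \max_D \; \mathbb{E}_{I^{HR}}[\log D(I^{HR})] + \mathbb{E}_{I^{LR}}[\log(1 - D(G(I^{LR})))]
$$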
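The objective can be estimated on a mini-batch from the discriminator's output probabilities. In the NumPy sketch below, d_real holds \(D(x)\) for real samples and d_fake holds \(D(G(z))\) for generated ones. The discriminator tries to raise this value and the generator tries to lower it. The helper name gan_value is hypothetical.

```python
import numpy as np

def gan_value(d_real, d_fake, eps=1e-7):
    """Mini-batch estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))].

    d_real: discriminator probabilities on real samples, D(x)
    d_fake: discriminator probabilities on generated samples, D(G(z))
    """
    d_real = np.clip(d_real, eps, 1 - eps)  # avoid log(0)
    d_fake = np.clip(d_fake, eps, 1 - eps)
    return float(np.mean(np.log(d_real)) + np.mean(np.log(1 - d_fake)))

# A confident, correct discriminator pushes V toward its upper bound of 0
print(gan_value(np.array([0.99, 0.98]), np.array([0.01, 0.02])))
# A fully fooled discriminator (D = 0.5 everywhere) gives 2 * log(0.5) = -log(4)
print(gan_value(np.array([0.5, 0.5]), np.array([0.5, 0.5])))
```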

21.5.2.2. Perceptual Loss

SRGAN introduces a perceptual loss function designed to recover high-frequency details. Networks trained only with a pixel-wise loss such as Mean Squared Error (MSE) tend to produce overly smooth images, because minimizing MSE averages over the many plausible high-resolution solutions. The perceptual loss combines:

  • Content Loss: Measures the MSE between feature maps of the generated and real high-resolution images, extracted by a pre-trained network (e.g., VGG19), rather than between raw pixels.

  • Adversarial Loss: Encourages the generator to produce images that are indistinguishable from real images to the discriminator.
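For reference, the SRGAN paper (Ledig et al., 2017) defines the perceptual loss as the content loss plus a \(10^{-3}\)-weighted adversarial loss. The content loss is computed on a VGG19 feature map \(\phi_{i,j}\) of width \(W_{i,j}\) and height \(H_{i,j}\). This is the map produced by the \(j\)-th convolution (after activation) before the \(i\)-th max-pooling layer. The adversarial term sums over the \(N\) training images.

$$
l^{SR} = l^{SR}_{X} + 10^{-3}\, l^{SR}_{Gen}
$$

$$
l^{SR}_{VGG/i,j} = \frac{1}{W_{i,j} H_{i,j}} \sum_{x=1}^{W_{i,j}} \sum_{y=1}^{H_{i,j}} \left( \phi_{i,j}(I^{HR})_{x,y} - \phi_{i,j}(G(I^{LR}))_{x,y} \right)^2
$$

$$
l^{SR}_{Gen} = \sum_{n=1}^{N} -\log D(G(I^{LR}))
$$

The \(10^{-3}\) weight corresponds to loss_weights=[1e-3, 1] in the combined model compiled later in this lesson.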

21.5.2.3. Why Two Networks Improve Performance

  • Adversarial Training: The generator improves by learning from the discriminator’s feedback, leading to more realistic and high-quality images.

  • Feature Learning: The discriminator learns to identify intricate details, pushing the generator to enhance these details in generated images.

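  • Balance: The adversarial loss pushes the generator toward realistic textures, counteracting the blurriness of pixel-wise losses. The content loss keeps the output faithful to the input. Weighting the two terms controls the trade-off between blurriness and hallucinated artifacts.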

21.6. Implementing the SRGAN Model

We will now define the generator and discriminator models. Both networks are convolutional neural networks (CNNs) but serve different purposes.

21.6.1. Generator Network

The generator is responsible for upsampling low-resolution images to high-resolution images. It employs residual blocks and upsampling layers to reconstruct high-frequency details.

21.6.1.1. Key Components:

  • Residual Blocks: Help in training deeper networks by mitigating the vanishing gradient problem.

  • Upsampling Layers: Increase the spatial dimensions of the feature maps; two 2× stages give the 4× overall scale factor.

  • Activation Functions: Parametric ReLU (PReLU) learns the slope applied to negative inputs instead of fixing it in advance, as LeakyReLU does.
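The two generator building blocks above can be illustrated in a few lines of NumPy. This is a sketch of the underlying math, not the Keras layers themselves, and the helper names prelu and residual are hypothetical. With shared_axes=[1, 2], as in the generator code, Keras's PReLU learns one slope per channel; here a single scalar alpha stands in for it.

```python
import numpy as np

def prelu(x, alpha):
    """PReLU: identity for positive inputs, slope `alpha` (learned in practice) for the rest."""
    return np.where(x > 0, x, alpha * x)

def residual(x, f):
    """Residual connection: the block outputs x + F(x)."""
    return x + f(x)

x = np.array([-2.0, -0.5, 0.0, 1.5])
print(prelu(x, alpha=0.25))  # values -0.5, -0.125, 0.0, 1.5

# If the learned branch F outputs zeros, the block reduces to the identity,
# so the input (and its gradient) passes through unchanged.
print(residual(x, lambda v: np.zeros_like(v)))
```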

# Define the generator model
def build_generator():
    def residual_block(x):
        res = Conv2D(64, kernel_size=3, strides=1, padding='same')(x)
        res = BatchNormalization(momentum=0.8)(res)
        res = PReLU(shared_axes=[1, 2])(res)
        res = Conv2D(64, kernel_size=3, strides=1, padding='same')(res)
        res = BatchNormalization(momentum=0.8)(res)
        res = Add()([res, x])
        return res

    # Input layer: 32x32 LR images (two 2x upsampling stages below give 128x128 output)
    input_layer = tf.keras.Input(shape=(32, 32, 3))

    # Pre-residual block
    x = Conv2D(64, kernel_size=9, strides=1, padding='same')(input_layer)
    x = PReLU(shared_axes=[1, 2])(x)

    # Store output for skip connection
    skip_connection = x

    # Residual blocks
    for _ in range(16):
        x = residual_block(x)

    # Post-residual block
    x = Conv2D(64, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Add()([x, skip_connection])

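    # Upsampling blocks: two 2x stages (4x overall) using nearest-neighbour
    # UpSampling2D + Conv2D; the SRGAN paper uses sub-pixel convolution instead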
    for _ in range(2):
        x = UpSampling2D(size=2)(x)
        x = Conv2D(256, kernel_size=3, strides=1, padding='same')(x)
        x = PReLU(shared_axes=[1, 2])(x)

    # Output layer
    output_layer = Conv2D(3, kernel_size=9, strides=1, padding='same', activation='tanh')(x)

    # Define model
    model = Model(inputs=input_layer, outputs=output_layer)
    return model

# Build generator
generator = build_generator()
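The generator above upsamples with UpSampling2D followed by a convolution. The SRGAN paper instead uses sub-pixel convolution: a convolution produces 256 channels, and a pixel-shuffle step rearranges them into a 2× larger map with 64 channels. In TensorFlow that rearrangement is tf.nn.depth_to_space. The NumPy sketch below shows what it does for a single NHWC image; the helper name pixel_shuffle is hypothetical.

```python
import numpy as np

def pixel_shuffle(x, r):
    """Rearrange an (H, W, C*r*r) array into (H*r, W*r, C).

    Intended to follow the same channel ordering as tf.nn.depth_to_space with
    data_format='NHWC': input channel (dy * r + dx) * C + c goes to output
    position (h * r + dy, w * r + dx, c).
    """
    h, w, c = x.shape
    if c % (r * r):
        raise ValueError("channels must be divisible by r*r")
    out_c = c // (r * r)
    x = x.reshape(h, w, r, r, out_c)   # split channels into (dy, dx, c)
    x = x.transpose(0, 2, 1, 3, 4)     # reorder to (h, dy, w, dx, c)
    return x.reshape(h * r, w * r, out_c)

# 256 feature channels at 32x32 become 64 channels at 64x64
print(pixel_shuffle(np.zeros((32, 32, 256)), 2).shape)  # (64, 64, 64)
# Four channels of one pixel become a 2x2 patch: rows [1, 2] and [3, 4]
print(pixel_shuffle(np.array([1, 2, 3, 4]).reshape(1, 1, 4), 2)[:, :, 0])
```

Inside a Keras model, the shuffle could be applied with a Lambda layer that wraps tf.nn.depth_to_space(t, 2). Whether it outperforms UpSampling2D on marine data would need to be tested.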

21.6.2. Discriminator Network

The discriminator’s role is to distinguish between real high-resolution images and the images generated by the generator. It is a binary classifier that outputs the probability of an image being real.

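21.6.2.1. Key Components:

  • Convolutional Layers: Extract features at different levels. Stride-2 layers halve the spatial resolution while the number of filters grows from 64 to 512.

  • LeakyReLU Activation: Passes a small gradient for negative inputs (slope 0.2 here), which avoids “dying” neurons that stop updating.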

  • Fully Connected Layers: Aggregate extracted features to make the final classification.

# Define the discriminator model
def build_discriminator():
    def d_block(x, filters, strides=1, bn=True):
        x = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(x)
        if bn:
            x = BatchNormalization(momentum=0.8)(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
        return x

    # Input layer
    input_layer = tf.keras.Input(shape=(128, 128, 3))

    # Convolutional blocks
    x = d_block(input_layer, 64, bn=False)
    x = d_block(x, 64, strides=2)
    x = d_block(x, 128)
    x = d_block(x, 128, strides=2)
    x = d_block(x, 256)
    x = d_block(x, 256, strides=2)
    x = d_block(x, 512)
    x = d_block(x, 512, strides=2)

    # Flatten and dense layers
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(1024)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

    # Output layer
    output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(x)

    # Define model
    model = Model(inputs=input_layer, outputs=output_layer)
    return model

# Build discriminator
discriminator = build_discriminator()
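Each stride-2 block halves the spatial size (with padding='same', the output size is ceil(n / stride)), so it can help to trace the shapes through build_discriminator. The short calculation below assumes the 128×128 input used above.

```python
import math

def conv_out_size(n, stride):
    """Spatial size after a Conv2D with padding='same': ceil(n / stride)."""
    return math.ceil(n / stride)

# Strides and filter counts of the eight d_block calls in build_discriminator, in order
strides = [1, 2, 1, 2, 1, 2, 1, 2]
filters = [64, 64, 128, 128, 256, 256, 512, 512]

size = 128
for s in strides:
    size = conv_out_size(size, s)

flat = size * size * filters[-1]   # number of features after Flatten
dense_params = flat * 1024 + 1024  # weights + biases of Dense(1024)
print(size, flat, dense_params)    # 8 32768 33555456
```

The Dense(1024) layer alone holds about 33.6 million parameters. Because it sits after Flatten, the discriminator only accepts 128×128 inputs.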

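21.6.3. Compiling the Models

We compile the discriminator on its own, then build and compile a combined model in which the generator feeds a frozen discriminator. The discriminator learns to classify images as real or fake. The generator is trained through the combined model to produce images that the discriminator labels as real.

Freezing the discriminator after compiling it relies on tf.keras (Keras 2) behavior, where a model's trainable setting takes effect when compile() is called. Keras 3 may handle this differently. In that case, a custom training step written with tf.GradientTape is a more robust alternative.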

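# Compile the discriminator
discriminator.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(0.0002, 0.5), metrics=['accuracy'])

# Freeze the discriminator inside the combined model. In tf.keras (Keras 2) the
# trainable flag is fixed at compile time, so the discriminator compiled above
# still updates when trained on its own.
discriminator.trainable = False

# VGG19 feature extractor for the content loss. SRGAN's "VGG54" loss uses the
# block5_conv4 feature map rather than the final pooled output.
vgg = VGG19(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
vgg.trainable = False
vgg_features = Model(inputs=vgg.input, outputs=vgg.get_layer('block5_conv4').output)
vgg_features.trainable = False

# VGG content loss: MSE between VGG feature maps of real and generated HR images
def content_loss(y_true, y_pred):
    # Map [-1, 1] back to [0, 255] and apply VGG19's own preprocessing
    y_true = tf.keras.applications.vgg19.preprocess_input((y_true + 1) * 127.5)
    y_pred = tf.keras.applications.vgg19.preprocess_input((y_pred + 1) * 127.5)
    hr_features = vgg_features(y_true)
    sr_features = vgg_features(y_pred)
    return tf.reduce_mean(tf.square(hr_features - sr_features))

# Input for generator (32x32 LR images)
img_lr = tf.keras.Input(shape=(32, 32, 3))

# Generate high-resolution images
generated_hr = generator(img_lr)

# Discriminator determines validity
validity = discriminator(generated_hr)

# Combined model (generator followed by the frozen discriminator)
combined = Model(inputs=img_lr, outputs=[validity, generated_hr])

# Compile the combined model: adversarial loss weighted 1e-3, content loss weighted 1
combined.compile(loss=['binary_crossentropy', content_loss],
                 optimizer=tf.keras.optimizers.Adam(0.0002, 0.5),
                 loss_weights=[1e-3, 1])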
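As a hedged alternative to the compile/train_on_batch pattern, the sketch below performs one SRGAN update with explicit gradient tapes. The function name train_step and its parameters are hypothetical. content_loss_fn stands for the VGG-based loss defined above, and the optimizers would mirror the Adam(0.0002, 0.5) settings used here. Gradients are taken only with respect to each network's own variables, so no trainable flag needs toggling. If you use this approach, leave discriminator.trainable set to True; otherwise its trainable_variables list is empty.

```python
import tensorflow as tf

# Binary cross-entropy on probabilities (the discriminator ends in a sigmoid)
bce = tf.keras.losses.BinaryCrossentropy()

def train_step(generator, discriminator, content_loss_fn, g_opt, d_opt,
               imgs_lr, imgs_hr, adv_weight=1e-3):
    """One alternating SRGAN update; returns (discriminator loss, generator loss)."""
    imgs_lr = tf.convert_to_tensor(imgs_lr, dtype=tf.float32)
    imgs_hr = tf.convert_to_tensor(imgs_hr, dtype=tf.float32)

    # Discriminator update: real HR images -> 1, generated images -> 0
    fake_hr = generator(imgs_lr, training=False)
    with tf.GradientTape() as tape:
        real_pred = discriminator(imgs_hr, training=True)
        fake_pred = discriminator(fake_hr, training=True)
        d_loss = 0.5 * (bce(tf.ones_like(real_pred), real_pred)
                        + bce(tf.zeros_like(fake_pred), fake_pred))
    d_grads = tape.gradient(d_loss, discriminator.trainable_variables)
    d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))

    # Generator update: only the generator's variables receive gradients
    with tf.GradientTape() as tape:
        sr = generator(imgs_lr, training=True)
        # training=False keeps the discriminator's BatchNorm layers in inference mode here
        sr_pred = discriminator(sr, training=False)
        adv_loss = bce(tf.ones_like(sr_pred), sr_pred)  # equals mean(-log D(G(x)))
        g_loss = content_loss_fn(imgs_hr, sr) + adv_weight * adv_loss
    g_grads = tape.gradient(g_loss, generator.trainable_variables)
    g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))

    return float(d_loss), float(g_loss)
```

Wrapping train_step in @tf.function would typically speed it up; it is left in eager mode here for clarity.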

21.7. Training the SRGAN

We will now train the SRGAN model on the prepared image pairs. Each training iteration draws a random mini-batch and alternates between one discriminator update and one generator update.

21.7.1. Training Process

  1. Train Discriminator:

    • Use real high-resolution images and label them as real (1).

    • Use generated high-resolution images from the generator and label them as fake (0).

    • Update the discriminator’s weights based on the loss.

  2. Train Generator:

    • Use low-resolution images as input.

    • The generator tries to produce high-resolution images that the discriminator classifies as real.

    • The combined model’s loss is a weighted sum of the adversarial loss (weight \(10^{-3}\)) and the VGG content loss (weight 1); together these form SRGAN’s perceptual loss.

    • Update the generator’s weights based on the combined loss.
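The two steps can be checked numerically. The NumPy sketch below uses a hypothetical helper, bce, which mirrors Keras's binary cross-entropy up to numerical details such as clipping. It computes the discriminator loss the way the training loop averages it. It also shows that training the combined model with the label 1 minimizes \(-\log D(G(x))\) rather than \(\log(1 - D(G(x)))\) from the minimax objective. This non-saturating form is also what the SRGAN paper uses for its adversarial loss.

```python
import numpy as np

def bce(labels, probs, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    probs = np.clip(probs, eps, 1 - eps)
    return float(np.mean(-(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))))

d_real = np.array([0.9, 0.8])  # discriminator outputs on real HR images
d_fake = np.array([0.1, 0.2])  # discriminator outputs on generated images

# Step 1: discriminator loss, averaged over the real and fake batches
d_loss = 0.5 * (bce(np.ones(2), d_real) + bce(np.zeros(2), d_fake))

# Step 2: generator adversarial loss; label 1 gives mean(-log D(G(x)))
g_adv = bce(np.ones(2), d_fake)
print(d_loss, g_adv)
```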

# Training parameters (each iteration processes one random mini-batch,
# so these are training steps rather than full passes over the dataset)
iterations = 10000
batch_size = 4
sample_interval = 1000

# Labels for real and fake images
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

# Training loop
for iteration in range(iterations):

    # ----------------------
    #  Train Discriminator
    # ----------------------

    # Select a random batch of paired images
    idx = np.random.randint(0, X_hr.shape[0], batch_size)
    imgs_hr = X_hr[idx]
    imgs_lr = X_lr[idx]

    # Generate high-resolution images from low-resolution images
    # (calling the model directly avoids the per-call overhead of predict())
    fake_hr = generator(imgs_lr, training=False).numpy()

    # Train the discriminator (real classified as real and fake as fake)
    d_loss_real = discriminator.train_on_batch(imgs_hr, valid)
    d_loss_fake = discriminator.train_on_batch(fake_hr, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ------------------
    #  Train Generator
    # ------------------

    # Train the generator (wants discriminator to label generated images as real)
    g_loss = combined.train_on_batch(imgs_lr, [valid, imgs_hr])

    # Print the progress
    if iteration % sample_interval == 0:
        print(f"[Iteration {iteration}/{iterations}] [D loss: {d_loss[0]:.5f}, acc: {100*d_loss[1]:.2f}%] [G loss: {g_loss[0]:.5f}]")

21.8. Evaluating and Interpreting the Results

We will now evaluate the performance of the trained SRGAN model by visualizing some results and discussing evaluation metrics.

21.8.1. Evaluation Metrics for SRGAN

Evaluating super-resolution models involves both quantitative metrics and qualitative assessments.

21.8.1.1. Quantitative Metrics:

  • Peak Signal-to-Noise Ratio (PSNR): Relates the maximum possible pixel value to the mean squared error between the reconstructed and reference images, expressed in decibels: \(\mathrm{PSNR} = 10 \log_{10}(\mathrm{MAX}^2 / \mathrm{MSE})\). Higher PSNR indicates lower pixel-wise error, which does not always match perceived quality.

  • Structural Similarity Index (SSIM): Compares local luminance, contrast, and structure between two images. Values range from -1 to 1, with 1 indicating identical images.
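To make the two formulas concrete, the NumPy sketch below implements PSNR directly. It also implements a simplified, global SSIM, computed once over the whole image instead of over sliding windows, so its values will differ from skimage's structural_similarity. The helper names are hypothetical. data_range is the span of pixel values, for example 2.0 for images scaled to [-1, 1] as in this lesson.

```python
import numpy as np

def psnr(reference, test, data_range):
    """PSNR in dB: 10 * log10(data_range**2 / MSE)."""
    reference = np.asarray(reference, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    mse = np.mean((reference - test) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10 * np.log10(data_range ** 2 / mse))

def global_ssim(x, y, data_range):
    """SSIM evaluated once over the whole image (no sliding window)."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    c1 = (0.01 * data_range) ** 2  # stabilizing constants from the SSIM definition
    c2 = (0.03 * data_range) ** 2
    mx, my = x.mean(), y.mean()
    vx = np.mean((x - mx) * (x - mx))
    vy = np.mean((y - my) * (y - my))
    cov = np.mean((x - mx) * (y - my))
    # Luminance, contrast, and structure terms combined into the standard single expression
    return float(((2 * mx * my + c1) * (2 * cov + c2)) /
                 ((mx * mx + my * my + c1) * (vx + vy + c2)))

ref = np.zeros((4, 4))
print(psnr(ref, ref + 0.1, data_range=1.0))  # about 20 dB, since MSE = 0.01
```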

21.8.1.2. Qualitative Assessment:

  • Visual Inspection: Assessing the visual quality of the generated images for artifacts, blurriness, and realistic textures.

# Import additional libraries for evaluation
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Select three distinct random images. They come from the training data, so the
# scores are optimistic; use a held-out split for an unbiased evaluation.
idx = np.random.choice(X_lr.shape[0], 3, replace=False)
imgs_lr = X_lr[idx]
imgs_hr = X_hr[idx]

# Generate high-resolution images
generated_hr = generator.predict(imgs_lr, verbose=0)

# Rescale images from [-1, 1] to [0, 1] for visualization
imgs_lr_vis = 0.5 * imgs_lr + 0.5
generated_hr_vis = 0.5 * generated_hr + 0.5
imgs_hr_vis = 0.5 * imgs_hr + 0.5

# Calculate metrics on the [-1, 1] scale (data_range = 1 - (-1) = 2);
# channel_axis replaces the deprecated multichannel argument in scikit-image
psnr_values = [psnr(imgs_hr[i], generated_hr[i], data_range=2.0) for i in range(3)]
ssim_values = [ssim(imgs_hr[i], generated_hr[i], channel_axis=-1, data_range=2.0) for i in range(3)]

# Plot the results
titles = ['Low-resolution', 'Generated High-resolution', 'Original High-resolution']
fig, axs = plt.subplots(3, 3, figsize=(15, 15))
for i in range(3):
    axs[i, 0].imshow(imgs_lr_vis[i])
    axs[i, 0].set_title(titles[0])
    axs[i, 1].imshow(generated_hr_vis[i])
    axs[i, 1].set_title(f"{titles[1]}\nPSNR: {psnr_values[i]:.2f}, SSIM: {ssim_values[i]:.3f}")
    axs[i, 2].imshow(imgs_hr_vis[i])
    axs[i, 2].set_title(titles[2])
    for j in range(3):
        axs[i, j].axis('off')
plt.show()
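PSNR and SSIM are easier to interpret next to a simple baseline. Super-resolution papers usually compare against bicubic interpolation; the NumPy sketch below uses nearest-neighbour upsampling instead to stay dependency-free. The helper names are hypothetical. Images are assumed to lie in [-1, 1], with LR inputs 4× smaller than the HR targets.

```python
import numpy as np

def nearest_upsample(lr, factor=4):
    """Upsample an (H, W, C) image by repeating each pixel factor x factor times."""
    return np.repeat(np.repeat(lr, factor, axis=0), factor, axis=1)

def psnr(reference, test, data_range=2.0):
    """PSNR in dB for images whose values span `data_range`."""
    mse = np.mean((reference - test) ** 2)
    return float("inf") if mse == 0 else float(10 * np.log10(data_range ** 2 / mse))

def compare_to_baseline(lr, sr, hr, factor=4, data_range=2.0):
    """Return (baseline PSNR, super-resolved PSNR) for one image."""
    baseline = nearest_upsample(lr, factor)
    return psnr(hr, baseline, data_range), psnr(hr, sr, data_range)

# Hypothetical usage with the arrays from the evaluation code:
# compare_to_baseline(imgs_lr[0], generated_hr[0], imgs_hr[0])
```

If the SRGAN output does not beat even this baseline on PSNR, that is worth checking visually. However, GAN-based outputs can score lower on PSNR than smoother methods while still looking sharper.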

21.8.2. Reflection Questions

Consider the following questions to reflect on the results and deepen your understanding:

  1. Detail Enhancement: Observing the generated images, do you notice a significant improvement in details compared to the low-resolution inputs? Provide specific examples.

  2. PSNR and SSIM Metrics: How do the PSNR and SSIM values correlate with the visual quality of the images? Are higher values always indicative of better quality in the context of marine images?

  3. Artifacts: Are there any artifacts introduced in the generated images? What might be causing them, and how could they affect marine image analysis?

  4. Color Accuracy: Does the color reproduction in the generated images match the original high-resolution images? How important is color accuracy in marine applications?

  5. Limitations of SRGAN: Based on your observations, what are some limitations of using SRGANs for marine image enhancement? Consider factors like computational resources, training data requirements, and potential overfitting.

  6. Alternative Approaches: If SRGANs are not suitable for certain marine applications, what alternative methods could be used for image enhancement? Discuss their potential advantages and disadvantages.

21.9. Conclusion

In this lesson, we explored how super-resolution can be applied to enhance marine images in support of scientific and monitoring tasks. We implemented an SRGAN model in TensorFlow, trained it on paired marine images, and evaluated it using both quantitative metrics and visual inspection. We also covered the theoretical underpinnings of SRGAN, including the roles of the generator and discriminator networks.

21.9.1. Key Takeaways

  • Applicability: SRGANs can improve image resolution, but their effectiveness depends on the quality and quantity of training data.

  • Evaluation: Metrics like PSNR and SSIM are helpful but should be complemented with visual inspections.

  • Challenges: High computational costs and the need for extensive training data can limit the use of SRGANs in marine contexts.

  • Alternatives: Other methods may be more suitable depending on the specific marine application and available resources.

Super-resolution techniques like SRGAN hold great promise for improving the quality of marine imagery, leading to better analysis and decision-making. However, it’s essential to consider their limitations and evaluate whether they are the best tool for a given task in marine science.