DCGAN para reconstrucción de imágenes

Enrique Blanco    29 septiembre, 2020

Hacer algo más de un año en este blog hicimos una muy breve introducción a los modelos generativos profundos. Hoy vamos a mostrar como construir y entrenar una DCGAN (Deep convolutional generative adversarial networks) dedicada a la reconstrucción de imágenes haciendo uso de Python y Keras.

Las arquitecturas generativas más usadas, recordemos, son los Variational Autoencoders (VAEs) y las GANs o Generative Adversarial Networks. 

En este artículo, vamos a poner un ejemplo práctico con Keras de implementación de una GAN para la reconstrucción de imágenes. Este tipo de redes son el perfecto ejemplo de una red neuronal que constituye un modelo generativo haciendo uso del paradigma de aprendizaje no supervisado para entrenar dos modelos en paralelo en un juego de suma cero.

Figura 1. Ejemplo de una red neuronal GAN. Fuente
Figura 1. Ejemplo de una red neuronal GAN. Fuente

La misión del generador es crear nuevas imágenes similares al conjunto de datos que deberían ser indistinguibles por las redes discriminadoras. La red discriminadora tiene en cuenta dos salidas, que son las imágenes del conjunto de datos real y las imágenes que salen de la red del generador.

El discriminador funciona como un clasificador binario y determina si una imagen dada por el generador es sintética o real.

Comencemos definiendo la ubicación del checkpoint del generador que vamos a entrenar además de importar las librerías y el dataset con el que vamos a trabajar.

codegenerator_save_name = 'generator.h5'
generator_checkpoint_path = './checkpoints/'

import numpy as np
import random
import matplotlib.pyplot as plt
import datetime
import os
from tqdm import tqdm

from keras.layers import Input, Conv2D
from keras.layers import AveragePooling2D, BatchNormalization
from keras.layers import UpSampling2D, Flatten, Activation
from keras.models import Model, Sequential
from keras.layers.core import Dense, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.callbacks import EarlyStopping, TensorBoard
from keras.optimizers import Adam
from keras import backend as K
from keras.datasets import mnist

Como vemos, contamos con 60000 imágenes de entrenamiento y 10000 imágenes de testeo con una dimensión de 28 x 28 píxeles para un total de 10 clases distintas.

(X_train, y_train), (X_test, y_test) =  mnist.load_data()

print('Size of the training_set: ', X_train.shape)
print('Size of the test_set: ', X_test.shape)
print('Shape of each image: ', X_train[0].shape)
print('Total number of classes: ', len(np.unique(y_train)))
print('Unique class labels: ', np.unique(y_train))
----------------------------------------------------------------
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Size of the training_set:  (60000, 28, 28)
Size of the test_set:  (10000, 28, 28)
Shape of each image:  (28, 28)
Total number of classes:  10
Unique class labels:  [0 1 2 3 4 5 6 7 8 9]


Aquí podemos ver 9 ejemplos de los dígitos que hemos importado. Cada una de estas imágenes viene acompañada de su correspondiente etiqueta.

# Plot of 9 random images
for i in range(0, 9):
    plt.subplot(331+i) # plot of 3 rows and 3 columns
    plt.axis('off') # turn off axis
    plt.imshow(X_train[i], cmap='gray') # gray scale
Figura 2. Muestra de 9 imágenes contenidas en MNIST dataset.
Figura 2. Muestra de 9 imágenes contenidas en MNIST dataset.

Data preprocessing

Las imágenes vienen con un solo canal entre 0 y 255. Nuestra intención aquí para procesar la imagen es hacer una normalización entre -1 y 1.

# Converting integer values to float types 
X_train = X_train.astype(np.float32)
X_test = X_test.astype(np.float32)

# Scaling and centering
X_train = (X_train - 127.5) / 127.5
X_test = (X_test - 127.5)/ 127.5
print('Maximum pixel value in the training_set after Centering and Scaling: ', np.max(X_train))
print('Minimum pixel value in the training_set after Centering and Scaling: ', np.min(X_train))
----------------------------------------------------------------
Maximum pixel value in the training_set after Centering and Scaling:  1.0
Minimum pixel value in the training_set after Centering and Scaling:  -1.0

Enriqueciendo el dataset

Ahora vamos a crear una función que modifique un poco nuestras imágenes. Nuestra misión es incluir algunas máscaras de ceros dentro de la imagen para un pequeño número de píxeles. Un ejemplo del resultado de aplicar esa función lo podemos ver a continuación. 

def noising(image):
    """Masking."""
    import random
    array = np.array(image)
    i = random.choice(range(8, 12))  # x coord for top left corner of the mask
    j = random.choice(range(8, 12))  # y coord for top left corner of the mask
    array[i:i+8, j:j+8] = -1  # setting the pixels in the masked region to 0
    return array

noised_train_data = np.array([*map(noising, X_train)])
noised_test_data = np.array([*map(noising, X_test)])
print('Noised train data Shape/Dimension : ', noised_train_data.shape)
print('Noised test data Shape/Dimension : ', noised_train_data.shape)
----------------------------------------------------------------
Noised train data Shape/Dimension :  (60000, 28, 28)
Noised test data Shape/Dimension :  (60000, 28, 28)
Figura 3. Muestra de las mismas 9 imágenes de la Figura 2, pero enmascaradas.
Figura 3. Muestra de las mismas 9 imágenes de la Figura 2, pero enmascaradas.

Haciendo un reshaping de los datos de entrenamiento y testeo

Para poder alimentar correctamente tanto el discriminador como el generador necesitamos hacer un reshape de nuestro DataSet simplemente añadimos una dimensión más a las ya conocidas.

c# Reshaping the training data
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], X_train.shape[2], 1)
print('Size/Shape of the original training set: ', X_train.shape)

# Reshaping the nosied training data
noised_train_data = noised_train_data.reshape(noised_train_data.shape[0],
                                              noised_train_data.shape[1],
                                              noised_train_data.shape[2], 1)
print('Size/Shape of the noised training set: ', noised_train_data.shape)

# Reshaping the testing data
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], X_test.shape[2], 1)
print('Size/Shape of the original test set: ', X_test.shape)

# Reshaping the noised testing data
noised_test_data = noised_test_data.reshape(noised_test_data.shape[0],
                                            noised_test_data.shape[1],
                                            noised_test_data.shape[2], 1)
print('Size/Shape of the noised test set: ', noised_test_data.shape)
----------------------------------------------------------------
Size/Shape of the original training set:  (60000, 28, 28, 1)
Size/Shape of the noised training set:  (60000, 28, 28, 1)
Size/Shape of the original test set:  (10000, 28, 28, 1)
Size/Shape of the noised test set:  (10000, 28, 28, 1)

A continuación, incluimos algunas funciones auxiliares para revertir el escalado de las imágenes y para representar el resultado del aprendizaje de nuestro generador durante el entrenamiento.

ef upscale(image):
    """Scale the image to 0-255 scale."""
    return (image*127.5 + 127.5).astype(np.uint8)


def generated_images_plot(original, noised_data, generator):
    """Plot subplot of images during training."""
    print('NOISED')
    for i in range(9):
        plt.subplot(331 + i)
        plt.axis('off')
        plt.imshow(upscale(np.squeeze(noised_data[i])), cmap='gray')
    plt.show()
    print('GENERATED')
    for i in range(9):
        pred = generator.predict(noised_data[i:i+1], verbose=0)
        plt.subplot(331 + i)
        plt.axis('off')
        plt.imshow(upscale(np.squeeze(pred[0])), cmap='gray')
    plt.show()
    print('ORIGINAL')
    for i in range(9):
        plt.subplot(331 + i)
        plt.axis('off')
        plt.imshow(upscale(np.squeeze(original[i])), cmap='gray')
    plt.show()


def plot_generated_images_combined(original, noised_data, generator):
    """Another function to plot images during training."""
    rows, cols = 4, 12
    num = rows * cols
    image_size = 28
    generated_images = generator.predict(noised_data[0:num])
    imgs = np.concatenate([original[0:num], noised_data[0:num],
                          generated_images])
    imgs = imgs.reshape((rows * 3, cols, image_size, image_size))
    imgs = np.vstack(np.split(imgs, rows, axis=1))
    imgs = imgs.reshape((rows * 3, -1, image_size, image_size))
    imgs = np.vstack([np.hstack(i) for i in imgs])
    imgs = upscale(imgs)
    plt.figure(figsize=(8, 16))
    plt.axis('off')
    plt.title('Original Images: top rows, '
              'Corrupted Input: middle rows, '
              'Generated Images: bottom rows')
    plt.imshow(imgs, cmap='gray')
    plt.show()


def plot_training_loss(discriminator_losses, generator_losses):
    """Plot the losses."""
    plt.figure()
    plt.plot(range(len(discriminator_losses)), discriminator_losses,
             color='red', label='Discriminator loss')
    plt.plot(range(len(generator_losses)), generator_losses,
             color='blue', label='Adversarial loss')
    plt.title('Discriminator and Adversarial loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss (Adversarial/Discriminator)')
    plt.legend()
    plt.show()

Definición de hiperparametros de nuestra red neuronal

K.clear_session()

# Smoothing value
smooth_real = 0.9

# Number of epochs
epochs = 10

# Batchsize
batch_size = 64

# Optimizer for the generator
optimizer_g = Adam(lr=0.0002, beta_1=0.5)

# Optimizer for the discriminator
optimizer_d = Adam(lr=0.0004, beta_1=0.5)

# Shape of the input image
input_shape = (28,28,1)

Creando el generador

Vamos a crear la arquitectura del generador usando Keras. Probemos en primer lugar con seis bloques convolucionales.

def img_generator(input_shape):
    generator = Sequential()
    generator.add(Conv2D(32, (3, 3), padding='same', input_shape=input_shape)) # 32 filters
    generator.add(BatchNormalization())
    generator.add(Activation('relu'))
    generator.add(AveragePooling2D(pool_size=(2, 2)))
    
    generator.add(Conv2D(64, (3, 3), padding='same')) # 64 filters
    generator.add(BatchNormalization())
    generator.add(Activation('relu'))
    generator.add(AveragePooling2D(pool_size=(2, 2)))
    
    generator.add(Conv2D(128, (3, 3), padding='same')) # 128 filters
    generator.add(BatchNormalization())
    generator.add(Activation('relu')) 
    
    generator.add(Conv2D(128, (3, 3), padding='same')) # 128 filters
    generator.add(Activation('relu'))
    generator.add(UpSampling2D((2,2)))
    
    generator.add(Conv2D(64, (3, 3), padding='same')) # 64 filters
    generator.add(Activation('relu'))
    generator.add(UpSampling2D((2,2)))
    
    generator.add(Conv2D(1, (3, 3), activation='tanh', padding='same')) # 1 filter
    return generator
# print generator summary
img_generator(input_shape).summary()
----------------------------------------------------------------
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 28, 28, 32)        320       
_________________________________________________________________
batch_normalization_1 (Batch (None, 28, 28, 32)        128       
_________________________________________________________________
activation_1 (Activation)    (None, 28, 28, 32)        0         
_________________________________________________________________
average_pooling2d_1 (Average (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 64)        18496     
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
activation_2 (Activation)    (None, 14, 14, 64)        0         
_________________________________________________________________
average_pooling2d_2 (Average (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 128)         73856     
_________________________________________________________________
batch_normalization_3 (Batch (None, 7, 7, 128)         512       
_________________________________________________________________
activation_3 (Activation)    (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 7, 7, 128)         147584    
_________________________________________________________________
activation_4 (Activation)    (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 14, 14, 64)        73792     
_________________________________________________________________
activation_5 (Activation)    (None, 14, 14, 64)        0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 28, 28, 1)         577       
=================================================================
Total params: 315,521
Trainable params: 315,073
Non-trainable params: 448
_________________________________________________________________

Cómo vemos hemos creado un modelo con unos 315 parámetros a optimizar.

Creando el discriminador

Ahora es el turno de crear el discriminador. Vamos a poner tres bloques convolucionales seguidos de una capa densa que termine en una única unidad, con un 0 o un 1.

def img_discriminator(input_shape):
    discriminator = Sequential()
    discriminator.add(Conv2D(64, (3, 3), strides=2, padding='same', input_shape=input_shape, activation = 'linear'))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.2))
    
    discriminator.add(Conv2D(128, (3, 3), strides=2, padding='same', activation = 'linear'))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.2))
    
    discriminator.add(Conv2D(256, (3, 3), padding='same', activation = 'linear'))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.2))
    
    discriminator.add(Flatten())
    discriminator.add(Dense(1, activation='sigmoid'))

    return discriminator
# print summary of the discriminator
img_discriminator(input_shape).summary()
----------------------------------------------------------------
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_7 (Conv2D)            (None, 14, 14, 64)        640       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 7, 7, 128)         73856     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 7, 7, 256)         295168    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 7, 7, 256)         0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 7, 7, 256)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 12544)             0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 12545     
=================================================================
Total params: 382,209
Trainable params: 382,209
Non-trainable params: 0
_________________________________________________________________

El número de parámetros a entrenar esta rama de la GAN está en el mismo orden de magnitud que para el generador.

Creando la GAN

Tras definir el discriminador y el generador, con la siguiente función vamos a ser capaces de construir la red neuronal generativa.

def dcgan(discriminator, generator, input_shape):
    # Don't train the discriminator when compiling GAN
    discriminator.trainable = False

    # Accepts the noised input
    gan_input = Input(shape=input_shape)
    
    # Generates image by passing the above received input to the generator
    gen_img = generator(gan_input)
    
    # Feeds the generated image to the discriminator
    gan_output = discriminator(gen_img)
    
    # Compile everything as a model with binary crossentropy loss
    gan = Model(inputs=gan_input, outputs=gan_output)
    return gan

Entendiendo el entrenamiento

Sigamos los siguientes pasos para entrenar nuestro modelo generativo

  1. Cargamos el generador y el discriminador haciendo uso de las funciones que acabamos de crear;
  2. Compilamos el discriminador con una función de pérdida binary cross entropy;
  3. Creamos la GAN y compilamos con la misma función de pérdida;
  4. Generamos baches tanto de imágenes originales como de imágenes modificadas y alimentamos el generador con imágenes modificadas;
  5. A continuación, alimentamos el discriminador con imágenes originales e imágenes obtenidas desde el generador;
  6. Ponemos discriminator.trainable a True para permitir que el discriminador se entren con este nuevo set de imágenes;
  7. Alimentamos el generador con aquellas imágenes que el discriminador ha etiquetado como 1. No se nos debe olvidar cambiar discriminator.trainable a False. No queremos que el discriminador modifique sus pesos mientras se está entrenando el generador;
  8. Desde el punto cuatro hasta el punto 7 repetimos los pasos durante tantas épocas como hayamos definido.
def train(X_train,      noised_train_data,
          input_shape,  smooth_real,
          epochs,       batch_size,
          optimizer_g, optimizer_d):
    """Training GAN."""
    discriminator_losses = []
    generator_losses = []

    # Number of iteration possible with batches of size 128
    iterations = X_train.shape[0] // batch_size
    # Load the generator and the discriminator
    generator = img_generator(input_shape)
    discriminator = img_discriminator(input_shape)
    # Compile the discriminator with binary_crossentropy loss
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer_d)
    # Feed the generator and the discriminator to the function dcgan
    # to form the DCGAN architecture
    gan = dcgan(discriminator, generator, input_shape)
    # Compile the DCGAN with binary_crossentropy loss
    gan.compile(loss='binary_crossentropy', optimizer=optimizer_g)

    for i in range(epochs):
        print('Epoch %d' % (i+1))
        # Use tqdm to get an estimate of time remaining
        for j in tqdm(range(1, iterations+1)):
            # batch of original images (batch = batchsize)
            original = X_train[np.random.randint(0, X_train.shape[0],
                                                 size=batch_size)]
            # batch of noised images (batch = batchsize)
            noise = noised_train_data[np.random.randint(0,
                                                        noised_train_data.shape[0],
                                                        size=batch_size)]
            # Generate fake images
            generated_images = generator.predict(noise)
            # Labels for generated data
            dis_lab = np.zeros(2*batch_size)
            dis_train = np.concatenate([original, generated_images])
            # label smoothing for original images
            dis_lab[:batch_size] = smooth_real
            # Train discriminator on original images
            discriminator.trainable = True
            discriminator_loss = discriminator.train_on_batch(dis_train,
                                                              dis_lab)
            # save the losses
            discriminator_losses.append(discriminator_loss)
            # Train generator
            gen_lab = np.ones(batch_size)
            discriminator.trainable = False
            sample_indices = np.random.randint(0, X_train.shape[0],
                                               size=batch_size)
            original = X_train[sample_indices]
            noise = noised_train_data[sample_indices]

            generator_loss = gan.train_on_batch(noise, gen_lab)
            # save the losses
            generator_losses.append(generator_loss)
            if i == 0 and j == 1:
                print('Iteration - %d', j)
                generated_images_plot(original, noise, generator)
                plot_generated_images_combined(original, noise, generator)
        
        print("Discriminator Loss: ", discriminator_loss,
              ", Adversarial Loss: ", generator_loss)
        generated_images_plot(original, noise, generator)
        plot_generated_images_combined(original, noise, generator)

        # Save generator model
        generator.save(generator_checkpoint_path) 
        
    # plot the losses
    plot_training_loss(discriminator_losses, generator_losses)

    return generator

Procedamos a lanzar el entrenamiento:

generator = train(X_train, noised_train_data,
                  input_shape, smooth_real,
                  epochs, batch_size,
                  optimizer_g, optimizer_d)
Figura 3. Primera evaluación antes de empezar la primera época de entrenamiento (I)
Figura 4. Primera evaluación antes de empezar la primera época de entrenamiento (II)
Figura 4. Primera evaluación antes de empezar la primera época de entrenamiento (II)
Figura 5. Inferencia del generador tras 10 épocas de entrenamiento (I)
Figura 5. Inferencia del generador tras 10 épocas de entrenamiento (I)
Figura 6. Inferencia del generador tras 10 épocas de entrenamiento (II)
Figura 6. Inferencia del generador tras 10 épocas de entrenamiento (II)

Es fácil observar cómo el generador aprende a reproducir las imágenes que se le facilitan, aunque estas estén enmascaradas. En la figura anterior, por ejemplo, tenemos 4 conjuntos de imágenes, de los cuales la primera fila son las imágenes originales, la segunda son los mismos números enmascarados y la tercera es la reconstrucción del generador. El desempeño es bastante bueno dado el poco tiempo que hemos dejado entrenando la red (1 minuto por época). Fíjense en cómo son las primeras reconstrucciones (Figuras 3 y 4).

Infiriendo con el generador ya entrenado

Al entrenar en un modelo convolucional de clasificación normal nuestro dataset de MNIST, obtendremos una precisión cercana al 98-99% .

# input image shape
input_shape = (28,28,1)

def train_mnist(input_shape, X_train, y_train):
    model = Sequential()
    model.add(Conv2D(32, (3, 3), strides=2, padding='same',
                     input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(Dropout(0.2))

    model.add(Conv2D(64, (3, 3), strides=2, padding='same'))
    model.add(Activation('relu'))
    model.add(Dropout(0.2)) 

    model.add(Conv2D(128, (3, 3), padding='same'))
    model.add(Activation('relu'))
    model.add(Dropout(0.2))
    model.add(Flatten())

    model.add(Dense(1024, activation = 'relu'))
    model.add(Dense(10, activation='softmax'))
    
    # Compilamos el modelo
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer='adam', metrics=['accuracy'])
    
    
    if not os.path.isdir('./logs'):
        os.mkdir('./logs')
    
    # Algunos callbacks que nos pueden ser de utilidad
    early_stopping_callback = EarlyStopping(monitor='val_loss', 
                                            patience=3)
    
    tensorboard_callback = TensorBoard(log_dir='./logs/mnist' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
                                        histogram_freq=0, # It is general issue with keras/tboard that you cannot get histograms with a validation_generator
                                        write_graph=True,
                                        write_images=True,
                                        )
    
    # Ajustamos a los datos de entrenamiento
    model.fit(X_train, y_train,
              batch_size = 128,  
              epochs = 3, 
              callbacks= [early_stopping_callback, tensorboard_callback],
              validation_data=(X_test, y_test), 
              verbose = 2 )
    return model
mnist_model = train_mnist(input_shape, X_train, y_train)
----------------------------------------------------------------
Train on 60000 samples, validate on 10000 samples
Epoch 1/3
 - 13s - loss: 0.1935 - accuracy: 0.9390 - val_loss: 0.0564 - val_accuracy: 0.9817
Epoch 2/3
 - 13s - loss: 0.0603 - accuracy: 0.9808 - val_loss: 0.0364 - val_accuracy: 0.9886
Epoch 3/3
 - 13s - loss: 0.0436 - accuracy: 0.9862 - val_loss: 0.0305 - val_accuracy: 0.9907

Si con ese modelo evaluamos la precisión sobre el conjunto de entrenamiento enmascarado esta misma disminuye hasta el 76,85%, lo cual es entendible.

# prediction on the masked images
pred_labels = mnist_model.predict_classes(noised_test_data)
print('The model median accuracy on the masked images is:',np.mean(pred_labels==y_test)*100)
----------------------------------------------------------------
The model median accuracy on the masked images is: 76.85

Si en vez de alimentar un modelo convolucional normal entrenado con datos sin enmascarar alimentamos el generador de la GAN que acabamos de entrenar, obtenemos una precisión media del 91,17 %. Esto se traduce en un incremento de casi el 15% en precisión.

# predict on the restored/generated digits
gen_pred_lab = mnist_model.predict_classes(gen_imgs_test)
print('The model model accuracy on the generated images is:',np.mean(gen_pred_lab==y_test)*100)
----------------------------------------------------------------
The model model accuracy on the generated images is: 91.17

Por lo tanto, podemos asegurar que nuestro generador ha aprendido correctamente a distinguir números manuscritos a pesar de haber sido alimentado con imágenes enmascaradas. Por ello, podemos asegurar que este modelo será capaz de generar imágenes fidedignas a partir de imágenes ruidosas. Además, nuestro discriminador también las identificará correctamente.

Veamos un ejemplo abajo, donde el discriminador se equivoca (y es entendible su equivocación, cualquiera erraría en la identificación) ante la «rareza» del ‘8’ reconstruido.

Figura 7. Imágenes generadas y etiquetadas por nuestra GAN.
Figura 7. Imágenes generadas y etiquetadas por nuestra GAN.

Como se muestra a continuación, alimentemos el generador cargado desde el último checkpoint con una imagen de un seis enmascarada.

from keras.models import load_model

loaded_generator = load_model(generator_checkpoint_path, compile=False)
random_idx = np.random.randint(0, noised_test_data.shape[0])

plt.imshow(upscale(np.squeeze(noised_test_data[random_idx])), cmap='gray')
Figura 8. Dígito correspondiente a un '6' enmascarado.
Figura 8. Dígito correspondiente a un ‘6’ enmascarado.

La recreación del generador de la imagen es satisfactoria.

reconstructed = loaded_generator.predict(noised_test_data[random_idx].reshape(1, noised_test_data[random_idx].shape[0],
                                            noised_test_data[random_idx].shape[1],
                                            noised_test_data[random_idx].shape[2]))
plt.imshow(upscale(np.squeeze(reconstructed)), cmap='gray')
Figura 9. Dígito de la Figura 8 reconstruida.
Figura 9. Dígito de la Figura 8 reconstruida.

Esperamos que este artículo sirva para dar más luz en el mundo de los modelos generativos sobre cómo se pueden entrenar este tipo de arquitecturas.

Para mantenerte al día con LUCA visita nuestra página web,  suscríbete a LUCA Data Speaks o síguenos en TwitterLinkedIn YouTube.

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *