"""
This module contains some example Generative Adversarial Networks for testing.

The classes StupidToyPointGan and StupidToyListGan are not really networks.
These classes are used for testing the interface. Hopefully your actual GAN
will perform better than these two.

The class SimpleGan is a simple standard Generative Adversarial Network.
"""

import numpy as np

from library.interfaces import GanBaseClass

from keras.layers import Dense, Dropout, Input
from keras.models import Model, Sequential
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam


class SimpleGan(GanBaseClass):
    """
    A class for a simple GAN.

    The networks are not built by the constructor; call ``reset()`` to create
    them before calling ``train()`` or one of the ``generate*`` methods.
    """

    # NOTE(review): the default of 786 features is kept for backward
    # compatibility; confirm it matches the data this class is used with.
    def __init__(self, numOfFeatures=786, noiseSize=100, epochs=3, batchSize=128):
        """
        Store the hyper-parameters of the GAN.

        numOfFeatures: number of values in one generated / real data point.
        noiseSize: length of the random noise vector fed to the generator.
        epochs: number of training epochs executed by ``train()``.
        batchSize: number of real (and of generated) samples per training step.
        """
        self.isTrained = False
        self.noiseSize = noiseSize
        self.numOfFeatures = numOfFeatures
        self.epochs = epochs
        self.batchSize = batchSize

    def reset(self):
        """
        Resets the trained GAN to a random state.

        Creates a fresh generator, discriminator and combined GAN model and
        marks the instance as untrained.
        """
        self.isTrained = False
        self.generator = self._createGenerator(self.numOfFeatures, self.noiseSize)
        self.discriminator = self._createDiscriminator(self.numOfFeatures)
        self.gan = self._createGan(self.noiseSize)

    @staticmethod
    def _adamOptimizer():
        """Return the Adam optimizer used for generator and discriminator."""
        return Adam(lr=0.0002, beta_1=0.5)

    def _createGan(self, noiseSize=100):
        """
        Chain generator and discriminator into one model: noise -> prediction.

        Requires ``self.generator`` and ``self.discriminator`` to exist.
        """
        self.discriminator.trainable = False
        gan_input = Input(shape=(noiseSize,))
        x = self.generator(gan_input)
        gan_output = self.discriminator(x)
        gan = Model(inputs=gan_input, outputs=gan_output)
        # NOTE(review): this uses the default 'adam' optimizer while generator
        # and discriminator use _adamOptimizer(); confirm this is intended.
        gan.compile(loss='binary_crossentropy', optimizer='adam')
        return gan

    def _createGenerator(self, numOfFeatures, noiseSize):
        """
        Build the generator: a dense network mapping a noise vector of length
        ``noiseSize`` to ``numOfFeatures`` values (tanh output).
        """
        generator = Sequential()
        generator.add(Dense(units=256, input_dim=noiseSize))
        generator.add(LeakyReLU(0.2))

        generator.add(Dense(units=512))
        generator.add(LeakyReLU(0.2))

        generator.add(Dense(units=1024))
        generator.add(LeakyReLU(0.2))

        generator.add(Dense(units=numOfFeatures, activation='tanh'))

        generator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
        return generator

    def _createDiscriminator(self, numOfFeatures):
        """
        Build the discriminator: a dense network mapping ``numOfFeatures``
        values to a single sigmoid output.
        """
        discriminator = Sequential()
        discriminator.add(Dense(units=1024, input_dim=numOfFeatures))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))

        discriminator.add(Dense(units=512))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))

        discriminator.add(Dense(units=256))
        discriminator.add(LeakyReLU(0.2))

        discriminator.add(Dense(units=1, activation='sigmoid'))

        discriminator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
        return discriminator

    def train(self, dataset):
        """
        Train the GAN on ``dataset.data1``.

        dataset: object whose ``data1`` attribute holds the real samples; it
            must provide ``shape`` and support indexing with an integer array
            (presumably a numpy array with one row per sample - confirm
            against the caller).

        Raises AttributeError if ``dataset.data1`` contains no samples.
        ``reset()`` must have been called before, otherwise the networks do
        not exist yet.
        """
        trainData = dataset.data1
        trainDataSize = trainData.shape[0]

        # NOTE(review): AttributeError is kept because callers may catch it;
        # ValueError would describe this condition better.
        if trainDataSize <= 0:
            raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")

        for e in range(self.epochs):
            print(f"Epoch {e + 1}/{self.epochs}")
            # NOTE(review): the number of steps per epoch equals batchSize;
            # confirm this is intended (rather than trainDataSize // batchSize).
            for _ in range(self.batchSize):
                # Generate random noise as an input for the generator.
                noise = np.random.normal(0, 1, [self.batchSize, self.noiseSize])

                # Generate fake data points from the noise.
                generatedImages = self.generator.predict(noise)

                # Get a random set of real data points (drawn with replacement).
                image_batch = trainData[
                    np.random.randint(low=0, high=trainDataSize, size=self.batchSize)
                ]

                # Construct one batch holding the real data followed by the fake data.
                X = np.concatenate([image_batch, generatedImages])

                # Labels: the first half (real data) gets 0.9, the second half
                # (generated data) stays 0.
                y_dis = np.zeros(2 * self.batchSize)
                y_dis[:self.batchSize] = 0.9

                # Train the discriminator on fake and real data before the gan step.
                self.discriminator.trainable = True
                self.discriminator.train_on_batch(X, y_dis)

                # Present fresh noise to the chained model labelled as real data.
                # The noise width must match the generator input (noiseSize),
                # not a hard-coded 100.
                noise = np.random.normal(0, 1, [self.batchSize, self.noiseSize])
                y_gen = np.ones(self.batchSize)

                # During the training of the gan the weights of the
                # discriminator should be fixed; the trainable flag is
                # cleared for that purpose.
                self.discriminator.trainable = False

                # Train the GAN by alternating the training of the discriminator
                # and the training of the chained GAN model with the
                # discriminator's weights frozen.
                self.gan.train_on_batch(noise, y_gen)

        self.isTrained = True

    def generateDataPoint(self):
        """Generate and return one single synthetic data point."""
        return self.generateData(1)[0]

    def generateData(self, numOfSamples=1):
        """
        Generate ``numOfSamples`` synthetic data points.

        Returns the generator's prediction for ``numOfSamples`` random noise
        vectors, i.e. one entry per requested sample.
        """
        # Generate random noise as an input for the generator.
        noise = np.random.normal(0, 1, [numOfSamples, self.noiseSize])

        # Generate fake data points from the noise.
        return self.generator.predict(noise)