|
@@ -0,0 +1,140 @@
|
|
|
|
|
+"""
|
|
|
|
|
+This module contains some example Generative Adversarial Networks for testing.
|
|
|
|
|
+
|
|
|
|
|
+The classes StupidToyPointGan and StupidToyListGan are not really networks. These classes are used
|
|
|
|
|
+for testing the interface. Hopefully your actual GAN will perform better than these two.
|
|
|
|
|
+
|
|
|
|
|
+The class SimpleGan is a simple standard Generative Adversarial Network.
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+
|
|
|
|
|
+from library.interfaces import GanBaseClass
|
|
|
|
|
+
|
|
|
|
|
+from keras.layers import Dense, Dropout, Input
|
|
|
|
|
+from keras.models import Model, Sequential
|
|
|
|
|
+from keras.layers.advanced_activations import LeakyReLU
|
|
|
|
|
+from keras.optimizers import Adam
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
class SimpleGan(GanBaseClass):
    """
    A simple standard Generative Adversarial Network.

    The generator maps a noise vector of length ``noiseSize`` to a data
    point with ``numOfFeatures`` features (tanh output, so training data
    is expected to be scaled to [-1, 1]).  The discriminator maps a data
    point to a single sigmoid probability of being real.

    ``reset()`` must be called once before ``train()``: the Keras models
    are built there, not in ``__init__``.
    """

    def __init__(self, numOfFeatures=786, noiseSize=100):
        """
        :param numOfFeatures: number of features per data point.
            NOTE(review): the default 786 looks like a typo for 784
            (28*28 MNIST pixels); kept unchanged for backward
            compatibility — confirm with callers.
        :param noiseSize: length of the generator's input noise vector.
        """
        self.isTrained = False
        self.noiseSize = noiseSize
        self.numOfFeatures = numOfFeatures

    def reset(self):
        """
        Resets the trained GAN to a random state by rebuilding the
        generator, the discriminator and the chained GAN model.
        """
        self.generator = self._createGenerator(self.numOfFeatures, self.noiseSize)
        self.discriminator = self._createDiscriminator(self.numOfFeatures)
        self.gan = self._createGan(self.noiseSize)

    @staticmethod
    def _adamOptimizer():
        """Adam optimizer with the hyper-parameters commonly used for GANs."""
        return Adam(lr=0.0002, beta_1=0.5)

    def _createGan(self, noiseSize=100):
        """
        Builds the chained generator -> discriminator model used to train
        the generator.  The discriminator is frozen inside this model so
        that only the generator's weights are updated through it.
        """
        self.discriminator.trainable = False
        gan_input = Input(shape=(noiseSize,))
        x = self.generator(gan_input)
        gan_output = self.discriminator(x)
        gan = Model(inputs=gan_input, outputs=gan_output)
        gan.compile(loss='binary_crossentropy', optimizer='adam')
        return gan

    def _createGenerator(self, numOfFeatures, noiseSize):
        """
        Builds the generator: noise (noiseSize) -> 256 -> 512 -> 1024 ->
        numOfFeatures, LeakyReLU hidden activations, tanh output.
        """
        generator = Sequential()
        generator.add(Dense(units=256, input_dim=noiseSize))
        generator.add(LeakyReLU(0.2))

        generator.add(Dense(units=512))
        generator.add(LeakyReLU(0.2))

        generator.add(Dense(units=1024))
        generator.add(LeakyReLU(0.2))

        generator.add(Dense(units=numOfFeatures, activation='tanh'))

        generator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
        return generator

    def _createDiscriminator(self, numOfFeatures):
        """
        Builds the discriminator: numOfFeatures -> 1024 -> 512 -> 256 -> 1,
        LeakyReLU hidden activations with dropout, sigmoid output.
        """
        discriminator = Sequential()
        discriminator.add(Dense(units=1024, input_dim=numOfFeatures))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))

        discriminator.add(Dense(units=512))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))

        discriminator.add(Dense(units=256))
        discriminator.add(LeakyReLU(0.2))

        discriminator.add(Dense(units=1, activation='sigmoid'))

        discriminator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
        return discriminator

    def train(self, dataset, epochs=1, batchSize=128):
        """
        Trains the GAN on ``dataset.data1``.

        :param dataset: object whose ``data1`` attribute is a 2-D array of
            shape (numSamples, numOfFeatures).
        :param epochs: number of passes over the batch loop.
        :param batchSize: number of real (and of fake) samples per batch.
        :raises AttributeError: if ``dataset.data1`` contains no points.
        """
        trainData = dataset.data1
        trainDataSize = trainData.shape[0]

        if trainDataSize <= 0:
            raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")

        for e in range(epochs):
            print(f"Epoch {e + 1}")
            # NOTE(review): the number of batches per epoch equals batchSize
            # rather than trainDataSize // batchSize — confirm this is intended.
            for _ in range(batchSize):
                # Random noise as input for the generator.
                noise = np.random.normal(0, 1, [batchSize, self.noiseSize])

                # Generate fake samples from the noise.
                generatedImages = self.generator.predict(noise)

                # Random batch of real samples (sampled with replacement).
                image_batch = dataset.data1[
                    np.random.randint(low=0, high=trainDataSize, size=batchSize)
                ]

                # One combined batch: real samples first, then fakes.
                X = np.concatenate([image_batch, generatedImages])

                # Labels: 0.9 for real (one-sided label smoothing), 0 for fake.
                y_dis = np.zeros(2 * batchSize)
                y_dis[:batchSize] = 0.9

                # Train the discriminator on real and fake data.
                self.discriminator.trainable = True
                self.discriminator.train_on_batch(X, y_dis)

                # Train the generator: label fresh noise as "real" so the
                # generator is pushed towards fooling the discriminator.
                # BUG FIX: the noise dimension was hard-coded to 100 here,
                # which broke training whenever noiseSize != 100.
                noise = np.random.normal(0, 1, [batchSize, self.noiseSize])
                y_gen = np.ones(batchSize)

                # The discriminator's weights must stay fixed while the
                # chained GAN model updates the generator.
                self.discriminator.trainable = False
                self.gan.train_on_batch(noise, y_gen)

        # BUG FIX: isTrained was initialized to False but never updated.
        self.isTrained = True

    def generateDataPoint(self):
        """Generates a single data point."""
        return self.generateData(1)[0]

    def generateData(self, numOfSamples=1):
        """
        Generates ``numOfSamples`` data points from random noise.

        :returns: array of shape (numOfSamples, numOfFeatures).
        """
        noise = np.random.normal(0, 1, [numOfSamples, self.noiseSize])
        return self.generator.predict(noise)