Ver código fonte

added Simple GAN

Kristian Schultz 4 anos atrás
pai
commit
b8050d8279

Diferenças do arquivo suprimidas por serem muito extensas
+ 7 - 18
Example Exercise.ipynb


Diferenças do arquivo suprimidas por serem muito extensas
+ 22 - 47
Example Toy Exercise.ipynb


+ 13 - 38
library/GanExamples.py

@@ -26,6 +26,13 @@ class StupidToyPointGan(GanBaseClass):
         self.isTrained = False
         self.exampleItem = None
 
+    def reset(self):
+        """
+        Reset the trained GAN back to its initial, untrained state.
+        """
+        self.isTrained = False
+        self.exampleItem = None
+
     def train(self, dataSet):
         """
         Trains the GAN.
@@ -76,6 +83,12 @@ class StupidToyListGan(GanBaseClass):
         if self.maxListSize < 1:
             raise AttributeError("Expected maxListSize to be > 0 but got " + str(self.maxListSize))
 
+    def reset(self):
+        """
+        Reset the trained GAN back to its initial, untrained state.
+        """
+        self.isTrained = False
+        self.exampleItems = None
 
     def train(self, dataSet):
         """
@@ -118,41 +131,3 @@ class StupidToyListGan(GanBaseClass):
             raise AttributeError("Expected numOfSamples to be > 0")
 
         return np.array([self.generateDataPoint() for _ in range(numOfSamples)])
-
-
-# class SimpleGan(GanBaseClass):
-#     def __init__(self, maxListSize=100):
-#         self.isTrained = False
-#         self.exampleItems = None
-#         self.nextIndex = 0
-#         self.maxListSize = int(maxListSize)
-#         if self.maxListSize < 1:
-#             raise AttributeError(f"Expected maxListSize to be > 0 but got {self.maxListSize}")
-#
-#
-#     def train(self, dataSet):
-#         if dataSet.data1.shape[0] <= 0:
-#             raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
-#
-#         self.isTrained = True
-#         self.exampleItems = dataSet.data1[: self.maxListSize].copy()
-#
-#     def generateDataPoint(self, numOfSamples=1):
-#         if not self.isTrained:
-#             raise ValueError("Try to generate data with untrained GAN.")
-#
-#         i = self.nextIndex
-#         self.nextIndex += 1
-#         if self.nextIndex >= self.exampleItems.shape[0]:
-#             self.nextIndex = 0
-#
-#         return self.exampleItems[i]
-#
-#
-#     def generateData(self, numOfSamples=1):
-#         numOfSamples = int(numOfSamples)
-#         if numOfSamples < 1:
-#             raise AttributeError("Expected numOfSamples to be > 0")
-#
-#         return np.array([self.generateDataPoint() for _ in range(numOfSamples)])
-#

+ 140 - 0
library/SimpleGan.py

@@ -0,0 +1,140 @@
+"""
+This module contains some example Generative Adversarial Networks for testing.
+
+The classes StupidToyPointGan and StupidToyListGan are not really networks. These classes are used
+for testing the interface. Hopefully your actual GAN will perform better than these two.
+
+The class SimpleGan is a simple standard Generative Adversarial Network.
+"""
+
+
+import numpy as np
+
+from library.interfaces import GanBaseClass
+
+from keras.layers import Dense, Dropout, Input
+from keras.models import Model, Sequential
+from keras.layers.advanced_activations import LeakyReLU
+from keras.optimizers import Adam
+
+
class SimpleGan(GanBaseClass):
    """
    A simple standard Generative Adversarial Network built on Keras.

    The generator maps a latent noise vector of length ``noiseSize`` to a
    flat sample of ``numOfFeatures`` values in [-1, 1] (tanh output); the
    discriminator scores samples with a sigmoid.  ``reset()`` must be
    called once to build the models before training or generating
    (exercise.py calls ``gan.reset()`` before ``gan.train()``).
    """

    def __init__(self, numOfFeatures=786, noiseSize=100):
        """
        :param numOfFeatures: length of one (flattened) data sample.
            NOTE(review): 786 looks like a typo for 784 (28*28 MNIST
            pixels, the comments below mention MNIST) -- confirm.
        :param noiseSize: dimension of the generator's latent input.
        """
        self.isTrained = False
        self.noiseSize = noiseSize
        self.numOfFeatures = numOfFeatures
        # The Keras models are built lazily by reset(), not here.
        self.generator = None
        self.discriminator = None
        self.gan = None

    def reset(self):
        """
        Reset the GAN to an untrained state with fresh random weights.
        """
        # Clear the trained flag, consistent with the StupidToy* GANs.
        self.isTrained = False
        self.generator = self._createGenerator(self.numOfFeatures, self.noiseSize)
        self.discriminator = self._createDiscriminator(self.numOfFeatures)
        self.gan = self._createGan(self.noiseSize)

    @staticmethod
    def _adamOptimizer():
        """Adam with the hyper-parameters commonly used for GAN training."""
        return Adam(lr=0.0002, beta_1=0.5)

    def _createGan(self, noiseSize=100):
        """Chain generator and (frozen) discriminator into one model."""
        # Freeze the discriminator inside the combined model so that
        # gan.train_on_batch() only updates the generator's weights.
        self.discriminator.trainable = False
        gan_input = Input(shape=(noiseSize,))
        gan_output = self.discriminator(self.generator(gan_input))
        gan = Model(inputs=gan_input, outputs=gan_output)
        # NOTE(review): compiled with the default 'adam' while the
        # submodels use _adamOptimizer() -- possibly unintended, confirm.
        gan.compile(loss='binary_crossentropy', optimizer='adam')
        return gan

    def _createGenerator(self, numOfFeatures, noiseSize):
        """Build the generator: noise -> 256 -> 512 -> 1024 -> sample."""
        generator = Sequential()
        generator.add(Dense(units=256, input_dim=noiseSize))
        generator.add(LeakyReLU(0.2))

        generator.add(Dense(units=512))
        generator.add(LeakyReLU(0.2))

        generator.add(Dense(units=1024))
        generator.add(LeakyReLU(0.2))

        # tanh output: real samples are expected to be scaled to [-1, 1].
        generator.add(Dense(units=numOfFeatures, activation='tanh'))

        generator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
        return generator

    def _createDiscriminator(self, numOfFeatures):
        """Build the discriminator: sample -> 1024 -> 512 -> 256 -> P(real)."""
        discriminator = Sequential()
        discriminator.add(Dense(units=1024, input_dim=numOfFeatures))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))

        discriminator.add(Dense(units=512))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))

        discriminator.add(Dense(units=256))
        discriminator.add(LeakyReLU(0.2))

        discriminator.add(Dense(units=1, activation='sigmoid'))

        discriminator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
        return discriminator

    def train(self, dataset, epochs=1, batchSize=128):
        """
        Train the GAN on ``dataset.data1``.

        :param dataset: object whose ``data1`` attribute is a 2-d array
            of real samples, one sample per row.
        :param epochs: number of training epochs.
        :param batchSize: mini-batch size.
        :raises ValueError: if reset() has not been called yet.
        :raises AttributeError: if the data set is empty.
        """
        if self.gan is None:
            # Fail with a clear message instead of an AttributeError.
            raise ValueError("Call reset() before training the GAN.")

        trainData = dataset.data1
        trainDataSize = trainData.shape[0]

        if trainDataSize <= 0:
            raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")

        for e in range(epochs):
            print(f"Epoch {e + 1}")
            # NOTE(review): this runs batchSize mini-batches per epoch
            # regardless of the data size; trainDataSize // batchSize was
            # probably intended -- confirm before changing.
            for _ in range(batchSize):
                # Random noise as input for the generator.
                noise = np.random.normal(0, 1, [batchSize, self.noiseSize])

                # Fake samples from the current generator.
                generatedImages = self.generator.predict(noise)

                # A random batch of real samples.
                image_batch = trainData[
                    np.random.randint(low=0, high=trainDataSize, size=batchSize)
                    ]

                # One combined batch: real samples first, then fakes.
                X = np.concatenate([image_batch, generatedImages])

                # Labels: 0.9 for real (one-sided label smoothing), 0 for fake.
                y_dis = np.zeros(2 * batchSize)
                y_dis[:batchSize] = 0.9

                # Train the discriminator on the combined batch.
                self.discriminator.trainable = True
                self.discriminator.train_on_batch(X, y_dis)

                # Train the generator: label generated samples as real and
                # backpropagate through the frozen discriminator.
                # BUGFIX: use self.noiseSize instead of the hard-coded 100,
                # which broke any noiseSize != 100.
                noise = np.random.normal(0, 1, [batchSize, self.noiseSize])
                y_gen = np.ones(batchSize)

                # Discriminator weights stay fixed while the combined
                # gan model trains.
                self.discriminator.trainable = False
                self.gan.train_on_batch(noise, y_gen)

        # Mark as trained, consistent with the StupidToy* GANs.
        self.isTrained = True

    def generateDataPoint(self):
        """Generate and return a single synthetic sample."""
        return self.generateData(1)[0]

    def generateData(self, numOfSamples=1):
        """
        Generate ``numOfSamples`` synthetic samples from random noise.

        :raises ValueError: if reset() has not been called yet.
        """
        if self.generator is None:
            # Fail with a clear message instead of an AttributeError.
            raise ValueError("Call reset() before generating data.")

        # Feed fresh latent noise vectors through the generator.
        noise = np.random.normal(0, 1, [numOfSamples, self.noiseSize])
        return self.generator.predict(noise)

+ 4 - 0
library/exercise.py

@@ -121,6 +121,10 @@ class Exercise:
         one data slice with training and testing data.
         """
 
+        # Start over with a new GAN instance.
+        self.debug("-> Reset the GAN")
+        gan.reset()
+
         # Train the gan so it can produce synthetic samples.
         self.debug("-> Train generator for synthetic samples")
         gan.train(dataSlice.train)

+ 6 - 0
library/interfaces.py

@@ -14,6 +14,12 @@ class GanBaseClass:
         Initializes the class.
         """
 
+    def reset(self):
+        """
+        Reset the GAN to an untrained state with fresh random parameters.
+
+        Subclasses must override this method.
+        """
+        raise NotImplementedError
+
     def train(self, dataSet):
         """
         Trains the GAN.

Alguns arquivos não foram mostrados porque muitos arquivos mudaram nesse diff