SimpleGan.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. """
  2. This module contains some example Generative Adversarial Networks for testing.
  3. The classes StupidToyPointGan and StupidToyListGan are not really networks. These classes are used
  4. for testing the interface. Hopefully your actual GAN will perform better than these two.
  5. The class SimpleGan is a simple standard Generative Adversarial Network.
  6. """
  7. import numpy as np
  8. from library.interfaces import GanBaseClass
  9. from keras.layers import Dense, Dropout, Input
  10. from keras.models import Model, Sequential
  11. from keras.layers.advanced_activations import LeakyReLU
  12. from keras.optimizers import Adam
  13. class SimpleGan(GanBaseClass):
  14. """
  15. A class for a simple GAN.
  16. """
  17. def __init__(self, numOfFeatures=786, noiseSize=100, epochs=3, batchSize=128):
  18. self.isTrained = False
  19. self.noiseSize = noiseSize
  20. self.numOfFeatures = numOfFeatures
  21. self.epochs = epochs
  22. self.batchSize = batchSize
  23. def reset(self):
  24. """
  25. Resets the trained GAN to an random state.
  26. """
  27. self.generator = self._createGenerator(self.numOfFeatures, self.noiseSize)
  28. self.discriminator = self._createDiscriminator(self.numOfFeatures)
  29. self.gan = self._createGan(self.noiseSize)
  30. @staticmethod
  31. def _adamOptimizer():
  32. return Adam(lr=0.0002, beta_1=0.5)
  33. def _createGan(self, noiseSize=100):
  34. self.discriminator.trainable=False
  35. gan_input = Input(shape=(noiseSize,))
  36. x = self.generator(gan_input)
  37. gan_output = self.discriminator(x)
  38. gan= Model(inputs=gan_input, outputs=gan_output)
  39. gan.compile(loss='binary_crossentropy', optimizer='adam')
  40. return gan
  41. def _createGenerator(self, numOfFeatures, noiseSize):
  42. generator=Sequential()
  43. generator.add(Dense(units=256, input_dim=noiseSize))
  44. generator.add(LeakyReLU(0.2))
  45. generator.add(Dense(units=512))
  46. generator.add(LeakyReLU(0.2))
  47. generator.add(Dense(units=1024))
  48. generator.add(LeakyReLU(0.2))
  49. generator.add(Dense(units=numOfFeatures, activation='tanh'))
  50. generator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
  51. return generator
  52. def _createDiscriminator(self, numOfFeatures):
  53. discriminator=Sequential()
  54. discriminator.add(Dense(units=1024, input_dim=numOfFeatures))
  55. discriminator.add(LeakyReLU(0.2))
  56. discriminator.add(Dropout(0.3))
  57. discriminator.add(Dense(units=512))
  58. discriminator.add(LeakyReLU(0.2))
  59. discriminator.add(Dropout(0.3))
  60. discriminator.add(Dense(units=256))
  61. discriminator.add(LeakyReLU(0.2))
  62. discriminator.add(Dense(units=1, activation='sigmoid'))
  63. discriminator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
  64. return discriminator
  65. def train(self, dataset):
  66. trainData = dataset.data1
  67. trainDataSize = trainData.shape[0]
  68. if trainDataSize <= 0:
  69. raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
  70. for e in range(self.epochs):
  71. print(f"Epoch {e + 1}/{self.epochs}")
  72. for _ in range(self.batchSize):
  73. #generate random noise as an input to initialize the generator
  74. noise= np.random.normal(0, 1, [self.batchSize, self.noiseSize])
  75. # Generate fake MNIST images from noised input
  76. generatedImages = self.generator.predict(noise)
  77. # Get a random set of real images
  78. image_batch = dataset.data1[
  79. np.random.randint(low=0, high=trainDataSize, size=self.batchSize)
  80. ]
  81. #Construct different batches of real and fake data
  82. X = np.concatenate([image_batch, generatedImages])
  83. # Labels for generated and real data
  84. y_dis=np.zeros(2 * self.batchSize)
  85. y_dis[:self.batchSize] = 0.9
  86. #Pre train discriminator on fake and real data before starting the gan.
  87. self.discriminator.trainable = True
  88. self.discriminator.train_on_batch(X, y_dis)
  89. #Tricking the noised input of the Generator as real data
  90. noise = np.random.normal(0, 1, [self.batchSize, 100])
  91. y_gen = np.ones(self.batchSize)
  92. # During the training of gan,
  93. # the weights of discriminator should be fixed.
  94. #We can enforce that by setting the trainable flag
  95. self.discriminator.trainable=False
  96. #training the GAN by alternating the training of the Discriminator
  97. #and training the chained GAN model with Discriminator’s weights freezed.
  98. self.gan.train_on_batch(noise, y_gen)
  99. def generateDataPoint(self):
  100. return self.generateData(1)[0]
  101. def generateData(self, numOfSamples=1):
  102. #generate random noise as an input to initialize the generator
  103. noise = np.random.normal(0, 1, [numOfSamples, self.noiseSize])
  104. # Generate fake MNIST images from noised input
  105. return self.generator.predict(noise)