# SimpleGan.py
  1. """
  2. This module contains some example Generative Adversarial Networks for testing.
  3. The classes StupidToyPointGan and StupidToyListGan are not really Networks. This classes are used
  4. for testing the interface. Hope your actually GAN will perform better than this two.
  5. The class SimpleGan is a simple standard Generative Adversarial Network.
  6. """
import numpy as np
from library.interfaces import GanBaseClass
from keras.layers import Dense, Dropout, Input
from keras.models import Model, Sequential
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam
  13. class SimpleGan(GanBaseClass):
  14. """
  15. A class for a simple GAN.
  16. """
  17. def __init__(self, numOfFeatures=786, noiseSize=100):
  18. self.isTrained = False
  19. self.noiseSize = noiseSize
  20. self.numOfFeatures = numOfFeatures
  21. def reset(self):
  22. """
  23. Resets the trained GAN to an random state.
  24. """
  25. self.generator = self._createGenerator(self.numOfFeatures, self.noiseSize)
  26. self.discriminator = self._createDiscriminator(self.numOfFeatures)
  27. self.gan = self._createGan(self.noiseSize)
  28. @staticmethod
  29. def _adamOptimizer():
  30. return Adam(lr=0.0002, beta_1=0.5)
  31. def _createGan(self, noiseSize=100):
  32. self.discriminator.trainable=False
  33. gan_input = Input(shape=(noiseSize,))
  34. x = self.generator(gan_input)
  35. gan_output = self.discriminator(x)
  36. gan= Model(inputs=gan_input, outputs=gan_output)
  37. gan.compile(loss='binary_crossentropy', optimizer='adam')
  38. return gan
  39. def _createGenerator(self, numOfFeatures, noiseSize):
  40. generator=Sequential()
  41. generator.add(Dense(units=256, input_dim=noiseSize))
  42. generator.add(LeakyReLU(0.2))
  43. generator.add(Dense(units=512))
  44. generator.add(LeakyReLU(0.2))
  45. generator.add(Dense(units=1024))
  46. generator.add(LeakyReLU(0.2))
  47. generator.add(Dense(units=numOfFeatures, activation='tanh'))
  48. generator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
  49. return generator
  50. def _createDiscriminator(self, numOfFeatures):
  51. discriminator=Sequential()
  52. discriminator.add(Dense(units=1024, input_dim=numOfFeatures))
  53. discriminator.add(LeakyReLU(0.2))
  54. discriminator.add(Dropout(0.3))
  55. discriminator.add(Dense(units=512))
  56. discriminator.add(LeakyReLU(0.2))
  57. discriminator.add(Dropout(0.3))
  58. discriminator.add(Dense(units=256))
  59. discriminator.add(LeakyReLU(0.2))
  60. discriminator.add(Dense(units=1, activation='sigmoid'))
  61. discriminator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
  62. return discriminator
  63. def train(self, dataset, epochs=1, batchSize=128):
  64. trainData = dataset.data1
  65. trainDataSize = trainData.shape[0]
  66. if trainDataSize <= 0:
  67. raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
  68. for e in range(epochs):
  69. print(f"Epoch {e + 1}")
  70. for _ in range(batchSize):
  71. #generate random noise as an input to initialize the generator
  72. noise= np.random.normal(0, 1, [batchSize, self.noiseSize])
  73. # Generate fake MNIST images from noised input
  74. generatedImages = self.generator.predict(noise)
  75. # Get a random set of real images
  76. image_batch = dataset.data1[
  77. np.random.randint(low=0, high=trainDataSize, size=batchSize)
  78. ]
  79. #Construct different batches of real and fake data
  80. X = np.concatenate([image_batch, generatedImages])
  81. # Labels for generated and real data
  82. y_dis=np.zeros(2 * batchSize)
  83. y_dis[:batchSize] = 0.9
  84. #Pre train discriminator on fake and real data before starting the gan.
  85. self.discriminator.trainable = True
  86. self.discriminator.train_on_batch(X, y_dis)
  87. #Tricking the noised input of the Generator as real data
  88. noise = np.random.normal(0, 1, [batchSize, 100])
  89. y_gen = np.ones(batchSize)
  90. # During the training of gan,
  91. # the weights of discriminator should be fixed.
  92. #We can enforce that by setting the trainable flag
  93. self.discriminator.trainable=False
  94. #training the GAN by alternating the training of the Discriminator
  95. #and training the chained GAN model with Discriminator’s weights freezed.
  96. self.gan.train_on_batch(noise, y_gen)
  97. def generateDataPoint(self):
  98. return self.generateData(1)[0]
  99. def generateData(self, numOfSamples=1):
  100. #generate random noise as an input to initialize the generator
  101. noise = np.random.normal(0, 1, [numOfSamples, self.noiseSize])
  102. # Generate fake MNIST images from noised input
  103. return self.generator.predict(noise)