# SimpleGan.py
"""
This module contains some example Generative Adversarial Networks for testing.
The classes StupidToyPointGan and StupidToyListGan are not really networks; they are used
for testing the interface. Hopefully your actual GAN will perform better than these two.
The class SimpleGan is a simple, standard Generative Adversarial Network.
"""
import numpy as np
import tensorflow as tf
from keras.layers import Dense, Dropout, Input
# NOTE(review): `keras.layers.advanced_activations` is a legacy import path; in
# current Keras, LeakyReLU is exported from `keras.layers` — confirm against the
# pinned Keras version before changing it.
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam

from library.interfaces import GanBaseClass
  14. class SimpleGan(GanBaseClass):
  15. """
  16. A class for a simple GAN.
  17. """
  18. def __init__(self, numOfFeatures=786, noiseSize=None, epochs=3, batchSize=128, withTanh=False, gLayers=None, dLayers=None):
  19. self.isTrained = False
  20. self.noiseSize = noiseSize if noiseSize is not None else numOfFeatures
  21. self.numOfFeatures = numOfFeatures
  22. self.epochs = epochs
  23. self.batchSize = batchSize
  24. self.scaler = 1.0
  25. self.withTanh = withTanh
  26. self.dLayers = dLayers if dLayers is not None else [1024, 512, 256]
  27. self.gLayers = gLayers if gLayers is not None else [256, 512, 1024]
  28. def reset(self, _dataSet):
  29. """
  30. Resets the trained GAN to an random state.
  31. """
  32. self.scaler = 1.0
  33. self.generator = self._createGenerator(self.numOfFeatures, self.noiseSize)
  34. self.discriminator = self._createDiscriminator(self.numOfFeatures)
  35. self.gan = self._createGan(self.noiseSize)
  36. @staticmethod
  37. def _adamOptimizer():
  38. return Adam(learning_rate=0.0002, beta_1=0.5)
  39. def _createGan(self, noiseSize=100):
  40. self.discriminator.trainable=False
  41. gan_input = Input(shape=(noiseSize,))
  42. x = self.generator(gan_input)
  43. gan_output = self.discriminator(x)
  44. gan= Model(inputs=gan_input, outputs=gan_output)
  45. gan.compile(loss='binary_crossentropy', optimizer='adam')
  46. return gan
  47. def _createGenerator(self, numOfFeatures, noiseSize):
  48. generator=Sequential()
  49. for (n, size) in enumerate(self.dLayers):
  50. if n == 0:
  51. generator.add(Dense(units=size, input_dim=noiseSize))
  52. generator.add(LeakyReLU(0.2))
  53. else:
  54. generator.add(Dense(units=size))
  55. generator.add(LeakyReLU(0.2))
  56. if self.withTanh:
  57. generator.add(Dense(units=numOfFeatures, activation='tanh'))
  58. else:
  59. generator.add(Dense(units=numOfFeatures, activation='softsign'))
  60. generator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
  61. return generator
  62. def _createDiscriminator(self, numOfFeatures):
  63. discriminator=Sequential()
  64. for (n, size) in enumerate(self.dLayers):
  65. if n == 0:
  66. discriminator.add(Dense(units=size, input_dim=numOfFeatures))
  67. discriminator.add(LeakyReLU(0.2))
  68. else:
  69. discriminator.add(Dropout(0.3))
  70. discriminator.add(Dense(units=size))
  71. discriminator.add(LeakyReLU(0.2))
  72. discriminator.add(Dense(units=1, activation='sigmoid'))
  73. discriminator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
  74. return discriminator
  75. def train(self, dataset):
  76. trainData = dataset.data1
  77. trainDataSize = trainData.shape[0]
  78. if trainDataSize <= 0:
  79. raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
  80. if self.withTanh:
  81. self.scaler = 1.0
  82. scaleDown = 1.0
  83. else:
  84. self.scaler = max(1.0, 1.1 * tf.reduce_max(tf.abs(trainData)).numpy())
  85. scaleDown = 1.0 / self.scaler
  86. for e in range(self.epochs):
  87. print(f"Epoch {e + 1}/{self.epochs}")
  88. for _ in range(self.batchSize):
  89. #generate random noise as an input to initialize the generator
  90. noise= np.random.normal(0, 1, [self.batchSize, self.noiseSize])
  91. # Generate fake MNIST images from noised input
  92. syntheticBatch = self.generator.predict(noise)
  93. # Get a random set of real images
  94. realBatch = dataset.data1[
  95. np.random.randint(low=0, high=trainDataSize, size=self.batchSize)
  96. ]
  97. #Construct different batches of real and fake data
  98. X = np.concatenate([scaleDown * realBatch, syntheticBatch])
  99. # Labels for generated and real data
  100. y_dis=np.zeros(2 * self.batchSize)
  101. y_dis[:self.batchSize] = 0.9
  102. #Pre train discriminator on fake and real data before starting the gan.
  103. self.discriminator.trainable = True
  104. self.discriminator.train_on_batch(X, y_dis)
  105. #Tricking the noised input of the Generator as real data
  106. noise = np.random.normal(0, 1, [self.batchSize, self.noiseSize])
  107. y_gen = np.ones(self.batchSize)
  108. # During the training of gan,
  109. # the weights of discriminator should be fixed.
  110. #We can enforce that by setting the trainable flag.
  111. self.discriminator.trainable=False
  112. #training the GAN by alternating the training of the Discriminator
  113. #and training the chained GAN model with Discriminator’s weights freezed.
  114. self.gan.train_on_batch(noise, y_gen)
  115. def generateDataPoint(self):
  116. return self.generateData(1)[0]
  117. def generateData(self, numOfSamples=1):
  118. #generate random noise as an input to initialize the generator
  119. noise = np.random.normal(0, 1, [numOfSamples, self.noiseSize])
  120. # Generate fake MNIST images from noised input
  121. return self.scaler * self.generator.predict(noise)