SimpleGan.py

  1. """
  2. This module contains some example Generative Adversarial Networks for testing.
  3. The classes StupidToyPointGan and StupidToyListGan are not really Networks. This classes are used
  4. for testing the interface. Hope your actually GAN will perform better than this two.
  5. The class SimpleGan is a simple standard Generative Adversarial Network.
  6. """
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, Input, LeakyReLU
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam

from library.interfaces import GanBaseClass


class SimpleGan(GanBaseClass):
    """
    A class for a simple GAN.
    """

    def __init__(self, numOfFeatures=786, noiseSize=None, epochs=10, batchSize=128,
                 withTanh=False, gLayers=None, dLayers=None):
        self.canPredict = False
        self.isTrained = False
        self.noiseSize = noiseSize if noiseSize is not None else (numOfFeatures * 16)
        self.numOfFeatures = numOfFeatures
        self.epochs = epochs
        self.batchSize = batchSize
        self.scaler = 1.0
        self.withTanh = withTanh
        self.dLayers = dLayers if dLayers is not None else [numOfFeatures * 40, numOfFeatures * 20, numOfFeatures * 10]
        self.gLayers = gLayers if gLayers is not None else [self.noiseSize * 2, numOfFeatures * 4, numOfFeatures * 2]

    def reset(self, _dataSet):
        """
        Resets the trained GAN to a random state.
        """
        self.isTrained = False
        self.scaler = 1.0
        self.generator = self._createGenerator(self.numOfFeatures, self.noiseSize)
        self.discriminator = self._createDiscriminator(self.numOfFeatures)
        self.gan = self._createGan(self.noiseSize)

    @staticmethod
    def _adamOptimizer():
        return Adam(learning_rate=0.0002, beta_1=0.5)

    def _createGan(self, noiseSize=100):
        # Freeze the discriminator inside the combined model so that
        # gan.train_on_batch only updates the generator weights.
        self.discriminator.trainable = False
        gan_input = Input(shape=(noiseSize,))
        x = self.generator(gan_input)
        gan_output = self.discriminator(x)
        gan = Model(inputs=gan_input, outputs=gan_output)
        gan.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
        return gan

    def _createGenerator(self, numOfFeatures, noiseSize):
        generator = Sequential()
        for (n, size) in enumerate(self.gLayers):
            if n == 0:
                generator.add(Dense(units=size, input_dim=noiseSize))
            else:
                generator.add(Dense(units=size))
            generator.add(LeakyReLU(0.2))
        # The output activation bounds generated features to [-1, 1];
        # the scaler computed in train() maps them back to the data range.
        if self.withTanh:
            generator.add(Dense(units=numOfFeatures, activation='tanh'))
        else:
            generator.add(Dense(units=numOfFeatures, activation='softsign'))
        # Compiling the stand-alone generator is not strictly needed, since it
        # is only ever trained through the combined GAN model.
        generator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
        return generator

    def _createDiscriminator(self, numOfFeatures):
        discriminator = Sequential()
        for (n, size) in enumerate(self.dLayers):
            if n == 0:
                discriminator.add(Dense(units=size, input_dim=numOfFeatures))
            else:
                discriminator.add(Dropout(0.3))
                discriminator.add(Dense(units=size))
            discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dense(units=1, activation='sigmoid'))
        discriminator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
        return discriminator

    def train(self, dataset):
        trainData = dataset.data1
        trainDataSize = trainData.shape[0]
        if trainDataSize <= 0:
            raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
        if self.withTanh:
            self.scaler = 1.0
            scaleDown = 1.0
        else:
            # Scale the data into the output range of the softsign activation;
            # generated samples are scaled back up in generateData().
            self.scaler = max(1.0, 1.1 * tf.reduce_max(tf.abs(trainData)).numpy())
            scaleDown = 1.0 / self.scaler
        trainData = scaleDown * trainData
        batchCount = max(1, trainDataSize // self.batchSize)
        for e in range(self.epochs):
            print(f"Epoch {e + 1}/{self.epochs}")
            for _ in range(batchCount):
                # Generate random noise as input for the generator.
                noise = np.random.normal(0, 1, [self.batchSize, self.noiseSize])
                # Generate a batch of synthetic samples from the noise.
                syntheticBatch = self.generator.predict(noise)
                # Draw a random batch of real samples.
                realBatch = trainData[
                    np.random.randint(low=0, high=trainDataSize, size=self.batchSize)
                ]
                # Combine real and synthetic samples into one batch.
                X = np.concatenate([realBatch, syntheticBatch])
                # Labels: 0.9 for the real half (one-sided label smoothing),
                # 0.0 for the synthetic half.
                y_dis = np.zeros(2 * self.batchSize)
                y_dis[:self.batchSize] = 0.9
                # Train the discriminator on the combined batch.
                self.discriminator.trainable = True
                self.discriminator.train_on_batch(X, y_dis)
                # Train the generator: label fresh noise as "real" so the
                # generator is pushed towards fooling the discriminator.
                noise = np.random.normal(0, 1, [self.batchSize, self.noiseSize])
                y_gen = np.ones(self.batchSize)
                # The discriminator weights must stay frozen while the
                # chained GAN model is trained.
                self.discriminator.trainable = False
                self.gan.train_on_batch(noise, y_gen)
        self.isTrained = True

    def generateDataPoint(self):
        return self.generateData(1)[0]

    def generateData(self, numOfSamples=1):
        # Generate random noise as input for the generator.
        noise = np.random.normal(0, 1, [numOfSamples, self.noiseSize])
        # Generate synthetic samples and scale them back to the data range.
        return self.scaler * self.generator.predict(noise)
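

# A minimal usage sketch, not part of the original file. It assumes a dataset
# object exposing a `data1` numpy array, matching how train() accesses
# `dataset.data1`; the `ToyDataset` name and the feature count are
# illustrative assumptions, not taken from library.interfaces.
if __name__ == "__main__":
    class ToyDataset:
        def __init__(self, data1):
            self.data1 = data1

    # 100 points with 8 features each, drawn from a standard normal distribution.
    dataset = ToyDataset(np.random.normal(0, 1, [100, 8]))
    gan = SimpleGan(numOfFeatures=8, epochs=2, batchSize=16)
    gan.reset(dataset)   # builds generator, discriminator and combined model
    gan.train(dataset)
    print(gan.generateData(5))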