SimpleGan.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. """
  2. This module contains some example Generative Adversarial Networks for testing.
  3. The classes StupidToyPointGan and StupidToyListGan are not really Networks. This classes are used
  4. for testing the interface. Hope your actually GAN will perform better than this two.
  5. The class SimpleGan is a simple standard Generative Adversarial Network.
  6. """
  7. import numpy as np
  8. from library.interfaces import GanBaseClass
  9. from keras.layers import Dense, Dropout, Input
  10. from keras.models import Model, Sequential
  11. try:
  12. from keras.layers.advanced_activations import LeakyReLU
  13. except:
  14. from keras.layers import LeakyReLU
  15. from tensorflow.keras.optimizers import Adam
  16. import tensorflow as tf
  17. class SimpleGan(GanBaseClass):
  18. """
  19. A class for a simple GAN.
  20. """
  21. def __init__(self, numOfFeatures=786, noiseSize=None, epochs=10, batchSize=128, withTanh=False, gLayers=None, dLayers=None):
  22. self.canPredict = False
  23. self.isTrained = False
  24. self.noiseSize = noiseSize if noiseSize is not None else (numOfFeatures * 16)
  25. self.numOfFeatures = numOfFeatures
  26. self.epochs = epochs
  27. self.batchSize = batchSize
  28. self.scaler = 1.0
  29. self.withTanh = withTanh
  30. self.dLayers = dLayers if dLayers is not None else [numOfFeatures * 40, numOfFeatures * 20, numOfFeatures * 10]
  31. self.gLayers = gLayers if gLayers is not None else [self.noiseSize * 2, numOfFeatures * 4, numOfFeatures * 2]
  32. def reset(self, _dataSet):
  33. """
  34. Resets the trained GAN to an random state.
  35. """
  36. self.scaler = 1.0
  37. self.generator = self._createGenerator(self.numOfFeatures, self.noiseSize)
  38. self.discriminator = self._createDiscriminator(self.numOfFeatures)
  39. self.gan = self._createGan(self.noiseSize)
  40. @staticmethod
  41. def _adamOptimizer():
  42. return Adam(learning_rate=0.0002, beta_1=0.5)
  43. def _createGan(self, noiseSize=100):
  44. self.discriminator.trainable=False
  45. gan_input = Input(shape=(noiseSize,))
  46. x = self.generator(gan_input)
  47. gan_output = self.discriminator(x)
  48. gan= Model(inputs=gan_input, outputs=gan_output)
  49. gan.compile(loss='binary_crossentropy', optimizer='adam')
  50. return gan
  51. def _createGenerator(self, numOfFeatures, noiseSize):
  52. generator=Sequential()
  53. for (n, size) in enumerate(self.dLayers):
  54. if n == 0:
  55. generator.add(Dense(units=size, input_dim=noiseSize))
  56. generator.add(LeakyReLU(0.2))
  57. else:
  58. generator.add(Dense(units=size))
  59. generator.add(LeakyReLU(0.2))
  60. if self.withTanh:
  61. generator.add(Dense(units=numOfFeatures, activation='tanh'))
  62. else:
  63. generator.add(Dense(units=numOfFeatures, activation='softsign'))
  64. generator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
  65. return generator
  66. def _createDiscriminator(self, numOfFeatures):
  67. discriminator=Sequential()
  68. for (n, size) in enumerate(self.dLayers):
  69. if n == 0:
  70. discriminator.add(Dense(units=size, input_dim=numOfFeatures))
  71. discriminator.add(LeakyReLU(0.2))
  72. else:
  73. discriminator.add(Dropout(0.3))
  74. discriminator.add(Dense(units=size))
  75. discriminator.add(LeakyReLU(0.2))
  76. discriminator.add(Dense(units=1, activation='sigmoid'))
  77. discriminator.compile(loss='binary_crossentropy', optimizer=self._adamOptimizer())
  78. return discriminator
  79. def train(self, dataset):
  80. trainData = dataset.data1
  81. trainDataSize = trainData.shape[0]
  82. if trainDataSize <= 0:
  83. raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
  84. if self.withTanh:
  85. self.scaler = 1.0
  86. scaleDown = 1.0
  87. else:
  88. self.scaler = max(1.0, 1.1 * tf.reduce_max(tf.abs(trainData)).numpy())
  89. scaleDown = 1.0 / self.scaler
  90. trainData = scaleDown * trainData
  91. for e in range(self.epochs):
  92. print(f"Epoch {e + 1}/{self.epochs}")
  93. for _ in range(self.batchSize):
  94. #generate random noise as an input to initialize the generator
  95. noise= np.random.normal(0, 1, [self.batchSize, self.noiseSize])
  96. # Generate fake MNIST images from noised input
  97. syntheticBatch = self.generator.predict(noise)
  98. # Get a random set of real images
  99. realBatch = trainData[
  100. np.random.randint(low=0, high=trainDataSize, size=self.batchSize)
  101. ]
  102. #Construct different batches of real and fake data
  103. X = np.concatenate([realBatch, syntheticBatch])
  104. # Labels for generated and real data
  105. y_dis=np.zeros(2 * self.batchSize)
  106. y_dis[:self.batchSize] = 0.9
  107. #Pre train discriminator on fake and real data before starting the gan.
  108. self.discriminator.trainable = True
  109. self.discriminator.train_on_batch(X, y_dis)
  110. #Tricking the noised input of the Generator as real data
  111. noise = np.random.normal(0, 1, [self.batchSize, self.noiseSize])
  112. y_gen = np.ones(self.batchSize)
  113. # During the training of gan,
  114. # the weights of discriminator should be fixed.
  115. #We can enforce that by setting the trainable flag.
  116. self.discriminator.trainable=False
  117. #training the GAN by alternating the training of the Discriminator
  118. #and training the chained GAN model with Discriminator’s weights freezed.
  119. self.gan.train_on_batch(noise, y_gen)
  120. def generateDataPoint(self):
  121. return self.generateData(1)[0]
  122. def generateData(self, numOfSamples=1):
  123. #generate random noise as an input to initialize the generator
  124. noise = np.random.normal(0, 1, [numOfSamples, self.noiseSize])
  125. # Generate fake MNIST images from noised input
  126. return self.scaler * self.generator.predict(noise)