# LoGAN.py
  1. import math
  2. import numpy as np
  3. from tqdm import tqdm
  4. from keras.layers import Dense, Dropout, Input
  5. from keras.models import Model,Sequential
  6. from keras.layers.advanced_activations import LeakyReLU
  7. from keras.optimizers import Adam
  8. from sklearn.neighbors import NearestNeighbors
  9. def adam_optimizer():
  10. return Adam(lr=0.0002, beta_1=0.5)
  11. def Neb_grps(data, near_neb):
  12. nbrs = NearestNeighbors(n_neighbors=near_neb, algorithm='ball_tree').fit(data)
  13. _distances, indices = nbrs.kneighbors(data)
  14. neb_class = list(indices)
  15. return np.asarray(neb_class)
  16. class GanTrainParameters:
  17. """
  18. Parameters for Training the GAN Network.
  19. """
  20. def __init__(self, n_feat, batch_size, min_t, features_0_trn, features_1_trn):
  21. self.batch_size = batch_size
  22. self.n_feat = n_feat
  23. self.min_t = min_t
  24. self.features_0_trn = features_0_trn
  25. self.features_1_trn = features_1_trn
  26. def im_batch_creator_min(self):
  27. nbd = Neb_grps(self.min_t, self.batch_size)
  28. rand = np.random.randint(low=0, high=self.features_1_trn.shape[0], size=1)
  29. idx = tuple(list(nbd[rand]))
  30. image_batch = self.features_1_trn[idx]
  31. return image_batch
  32. def im_batch_creator_maj(self):
  33. rand = np.random.randint(low=0, high=self.features_0_trn.shape[0], size=self.batch_size)
  34. image_batch = np.reshape(self.features_0_trn[rand[:,None]], (self.batch_size, self.n_feat))
  35. return image_batch
  36. class TLoRasNoise:
  37. """
  38. Noise function
  39. """
  40. def __init__(self, shadow=50, sigma=.005, num_afcomb=7):
  41. self.shadow = shadow
  42. self.sigma = sigma
  43. self.num_afcomb = num_afcomb
  44. def tLoRAS(self, data, num_samples, num_RACOS):
  45. np.random.seed(42)
  46. data_shadow = np.asarray([
  47. d + np.random.normal(0, self.sigma)
  48. for d in data[:num_samples]
  49. for _c in range(self.shadow)
  50. ])
  51. return np.asarray([
  52. self.shadowLcDataPoint(num_samples, data_shadow)
  53. for _i in range(num_RACOS)
  54. ])
  55. def shadowLcDataPoint(self, num_samples, data_shadow):
  56. idx = np.random.randint(self.shadow * num_samples, size=self.num_afcomb)
  57. w = np.random.randint(100, size=len(idx))
  58. aff_w = np.asarray(w/sum(w))
  59. data_tsl = np.array(data_shadow)[idx,:]
  60. return np.dot(aff_w, data_tsl)
  61. def noise(self, data, batch_size):
  62. return self.tLoRAS(data=data, num_samples=batch_size, num_RACOS=batch_size)
class GAN:
    """
    Class for GAN.

    Couples one generator with two discriminators: `discriminator_min`
    separates real minority samples from generated ones, and
    `discriminator_maj` separates real majority samples from generated
    ones. The chained `gan` model trains the generator against both.
    """
    def __init__(self, n_feat=1, noise=None,
            discriminatorMin=None, discriminatorMax=None, generator=None):
        # n_feat: number of features per sample (input and output width).
        # noise: object exposing .noise(data, batch_size); defaults to TLoRasNoise.
        # Pre-built networks may be injected; any left as None is created here.
        self.n_feat = n_feat
        if noise is None:
            self.noise = TLoRasNoise()
        else:
            self.noise = noise
        self.create_gan(
            discriminatorMin or self.create_discriminator_min(),
            discriminatorMax or self.create_discriminator_maj(),
            generator or self.create_generator())
    def create_generator(self):
        """Build and compile the generator: n_feat -> 25-256-512-256-25 -> n_feat, LeakyReLU throughout."""
        generator=Sequential()
        generator.add(Dense(units=25, input_dim=self.n_feat))
        generator.add(LeakyReLU(0.2))
        generator.add(Dense(units=256))
        generator.add(LeakyReLU(0.2))
        generator.add(Dense(units=512))
        generator.add(LeakyReLU(0.2))
        generator.add(Dense(units=256))
        generator.add(LeakyReLU(0.2))
        generator.add(Dense(units=25))
        generator.add(LeakyReLU(0.2))
        generator.add(Dense(units=self.n_feat))
        generator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())
        return generator
    def create_discriminator_min(self):
        """Build and compile the minority-class discriminator: 1024-512-256 -> sigmoid."""
        discriminator=Sequential()
        discriminator.add(Dense(units=1024,input_dim=self.n_feat))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))
        discriminator.add(Dense(units=512))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))
        discriminator.add(Dense(units=256))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dense(units=1, activation='sigmoid'))
        discriminator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())
        return discriminator
    def create_discriminator_maj(self):
        """Build and compile the majority-class discriminator (same architecture as the minority one)."""
        discriminator=Sequential()
        discriminator.add(Dense(units=1024,input_dim=self.n_feat))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))
        discriminator.add(Dense(units=512))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))
        discriminator.add(Dense(units=256))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dense(units=1, activation='sigmoid'))
        discriminator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())
        return discriminator
    def create_gan(self, discriminator_min, discriminator_maj, generator):
        """Chain generator -> both discriminators into one two-output model.

        Freezes both discriminators so that only the generator learns
        through the chained model, stores all parts on self, and returns
        the compiled `gan` model.
        """
        discriminator_min.trainable = False
        discriminator_maj.trainable = False
        gan_input = Input(shape=(self.n_feat,))
        x = generator(gan_input)
        gan_output_min= discriminator_min(x)
        gan_output_maj= discriminator_maj(x)
        gan = Model(inputs=gan_input, outputs=[gan_output_min,gan_output_maj])
        # NOTE(review): the chained model compiles with the default 'adam'
        # settings, unlike the sub-models which use adam_optimizer() —
        # confirm this asymmetry is intended.
        gan.compile(loss=['binary_crossentropy','binary_crossentropy'], optimizer='adam')
        # store the parts for later usage.
        self.generator = generator
        self.discriminator_min = discriminator_min
        self.discriminator_maj = discriminator_maj
        self.gan = gan
        return gan
    def train(self, parameters):
        """Run the adversarial training loop (30 epochs).

        `parameters` is a GanTrainParameters providing batch size and the
        per-class batch creators. Each inner step trains discriminator_min
        on real-vs-fake minority data, (in later epochs) discriminator_maj
        on real-vs-fake majority data, then the chained GAN.
        """
        for e in range(1,30+1 ):
            print(e)
            for _i in tqdm(range(parameters.batch_size)):
                # Get a random set of real images
                image_batch = parameters.im_batch_creator_min()
                #generate random noise as an input to initialize the generator
                noise_min = self.noise.noise(image_batch, parameters.batch_size)
                # Generate fake samples from noised input
                generated_images = self.generator.predict(noise_min)
                #Construct different batches of real and fake data
                X = np.concatenate((image_batch, generated_images))
                # Labels for generated and real data
                # (real minority rows come first and get the smoothed label 0.9; fakes stay 0).
                y_dis = np.zeros(2* parameters.batch_size)
                y_dis[: parameters.batch_size]=0.9
                #Pre train discriminator_min on fake and real data before starting the gan.
                self.discriminator_min.trainable = True
                _d_loss_min = self.discriminator_min.train_on_batch(X, y_dis)
                # NOTE(review): e starts at 1, so `e==0` can never be true —
                # presumably this was meant to be e==1; confirm before changing.
                if e==0 or e>15:
                    image_batch_maj = parameters.im_batch_creator_maj()
                    X_maj = np.concatenate((image_batch_maj, generated_images))
                    # NOTE(review): majority "real" labels are 2 (ones + 1), not 1,
                    # which is unusual for a sigmoid/binary-crossentropy head —
                    # confirm this is intended.
                    y_dis_maj=np.ones(2* parameters.batch_size)+1
                    y_dis_maj[: parameters.batch_size]=0
                    #Pre train discriminator_maj on fake and real data before starting the gan.
                    self.discriminator_maj.trainable = True
                    _d_loss_maj = self.discriminator_maj.train_on_batch(X_maj, y_dis_maj)
                #Tricking the noised input of the Generator as real data
                noise = self.noise.noise(image_batch, parameters.batch_size)
                y_gen_min = np.ones(parameters.batch_size)
                # During the training of gan,
                # the weights of discriminator should be fixed.
                #We can enforce that by setting the trainable flag
                self.discriminator_min.trainable = False
                self.discriminator_maj.trainable = False
                #training the GAN by alternating the training of the Discriminator
                #and training the chained GAN model with Discriminator’s weights freezed.
                _g_loss_min = self.gan.train_on_batch(noise, [y_gen_min, y_gen_min])
    def genFeat(self, parameters):
        """Generate one batch of synthetic minority-class feature vectors."""
        im_batch = parameters.im_batch_creator_min()
        noise = self.noise.noise(im_batch, parameters.batch_size)
        return self.generator.predict(noise)
    def predict(self, data):
        """Score `data` with the majority discriminator; returns a flat array of len(data) scores."""
        y_pred = self.discriminator_maj.predict(data)
        return np.reshape(y_pred, len(data))
  178. class DataSet:
  179. """
  180. Stores data and Labels.
  181. """
  182. def __init__(self, data=None, labels=None, data0=None, data1=None):
  183. if data is None:
  184. self.fromData01(data0, data1)
  185. else:
  186. if labels is None:
  187. raise "expected labels to be a numpy.array"
  188. else:
  189. self.data = data
  190. self.labels = labels
  191. def fromData01(self, data0=None, data1=None):
  192. if data0 is None and data1 is None:
  193. raise "Expected data, data0 or data1 to be a numpy.array"
  194. elif data0 is None:
  195. self.data = data1
  196. self.labels = np.zeros(len(data1)) + 1
  197. elif data1 is None:
  198. self.data = data0
  199. self.labels = np.zeros(len(data0))
  200. else:
  201. self.data = np.concatenate((data1, data0))
  202. self.labels = np.concatenate(( np.zeros(len(data1)) + 1, np.zeros(len(data0)) ))
  203. class TrainTestData:
  204. """
  205. Stores features, data and labels for class 0 and class 1.
  206. """
  207. def __init__(self, features0, features1, trainFactor=0.9):
  208. self.nFeatures0 = len(features0)
  209. self.nFeatures1 = len(features1)
  210. self.features_0_trn, self.features_0_tst = self.splitUpData(features0, trainFactor)
  211. self.features_1_trn, self.features_1_tst = self.splitUpData(features1, trainFactor)
  212. self.test = DataSet(data1=self.features_1_tst, data0=self.features_0_tst)
  213. self.train = DataSet(data1=self.features_1_trn, data0=self.features_0_trn)
  214. def splitUpData(self, data, trainFactor=0.9):
  215. size = len(data)
  216. trainSize = math.ceil(size * trainFactor)
  217. trn = data[list(range(0, trainSize))]
  218. tst = data[list(range(trainSize, size))]
  219. return trn, tst