# ConvGeN.py
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from library.interfaces import GanBaseClass
  4. from library.dataset import DataSet
  5. from keras.layers import Dense, Input, Multiply, Flatten, Conv1D, Reshape
  6. from keras.models import Model
  7. from keras import backend as K
  8. from tqdm import tqdm
  9. import tensorflow as tf
  10. from tensorflow.keras.optimizers import Adam
  11. from tensorflow.keras.layers import Lambda
  12. from sklearn.utils import shuffle
  13. from library.NNSearch import NNSearch
  14. import warnings
  15. warnings.filterwarnings("ignore")
  16. def repeat(x, times):
  17. return [x for _i in range(times)]
  18. def create01Labels(totalSize, sizeFirstHalf):
  19. labels = repeat(np.array([1,0]), sizeFirstHalf)
  20. labels.extend(repeat(np.array([0,1]), totalSize - sizeFirstHalf))
  21. return np.array(labels)
  22. class ConvGeN(GanBaseClass):
  23. """
  24. This is a toy example of a GAN.
  25. It repeats the first point of the training-data-set.
  26. """
  27. def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, withMajorhoodNbSearch=False, debug=False):
  28. self.isTrained = False
  29. self.n_feat = n_feat
  30. self.neb = neb
  31. self.nebInitial = neb
  32. self.genInitial = gen
  33. self.gen = gen if gen is not None else self.neb
  34. self.neb_epochs = 10
  35. self.loss_history = None
  36. self.debug = debug
  37. self.minSetSize = 0
  38. self.conv_sample_generator = None
  39. self.maj_min_discriminator = None
  40. self.withMajorhoodNbSearch = withMajorhoodNbSearch
  41. self.cg = None
  42. self.canPredict = True
  43. if self.neb is not None and self.gen is not None and self.neb > self.gen:
  44. raise ValueError(f"Expected neb <= gen but got neb={neb} and gen={gen}.")
  45. def reset(self, dataSet):
  46. """
  47. Resets the trained GAN to an random state.
  48. """
  49. self.isTrained = False
  50. if dataSet is not None:
  51. nMinoryPoints = dataSet.data1.shape[0]
  52. if self.nebInitial is None:
  53. self.neb = nMinoryPoints
  54. else:
  55. self.neb = min(self.nebInitial, nMinoryPoints)
  56. else:
  57. self.neb = self.nebInitial
  58. self.gen = self.genInitial if self.genInitial is not None else self.neb
  59. ## instanciate generator network and visualize architecture
  60. self.conv_sample_generator = self._conv_sample_gen()
  61. ## instanciate discriminator network and visualize architecture
  62. self.maj_min_discriminator = self._maj_min_disc()
  63. ## instanciate network and visualize architecture
  64. self.cg = self._convGeN(self.conv_sample_generator, self.maj_min_discriminator)
  65. if self.debug:
  66. print(f"neb={self.neb}, gen={self.gen}")
  67. print(self.conv_sample_generator.summary())
  68. print('\n')
  69. print(self.maj_min_discriminator.summary())
  70. print('\n')
  71. print(self.cg.summary())
  72. print('\n')
  73. def train(self, dataSet, discTrainCount=5):
  74. """
  75. Trains the GAN.
  76. It stores the data points in the training data set and mark as trained.
  77. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  78. We are only interested in the first *maxListSize* points in class 1.
  79. """
  80. if dataSet.data1.shape[0] <= 0:
  81. raise AttributeError("Train: Expected data class 1 to contain at least one point.")
  82. # Store size of minority class. This is needed during point generation.
  83. self.minSetSize = dataSet.data1.shape[0]
  84. # Precalculate neighborhoods
  85. self.nmbMin = NNSearch(self.neb).fit(haystack=dataSet.data1)
  86. if self.withMajorhoodNbSearch:
  87. self.nmbMaj = NNSearch(self.neb).fit(haystack=dataSet.data0, needles=dataSet.data1)
  88. else:
  89. self.nmbMaj = None
  90. # Do the training.
  91. self._rough_learning(dataSet.data1, dataSet.data0, discTrainCount)
  92. # Neighborhood in majority class is no longer needed. So save memory.
  93. self.nmbMaj = None
  94. self.isTrained = True
  95. def generateDataPoint(self):
  96. """
  97. Returns one synthetic data point by repeating the stored list.
  98. """
  99. return (self.generateData(1))[0]
  100. def generateData(self, numOfSamples=1):
  101. """
  102. Generates a list of synthetic data-points.
  103. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  104. """
  105. if not self.isTrained:
  106. raise ValueError("Try to generate data with untrained Re.")
  107. ## roughly claculate the upper bound of the synthetic samples to be generated from each neighbourhood
  108. synth_num = (numOfSamples // self.minSetSize) + 1
  109. ## generate synth_num synthetic samples from each minority neighbourhood
  110. synth_set=[]
  111. for i in range(self.minSetSize):
  112. synth_set.extend(self._generate_data_for_min_point(i, synth_num))
  113. ## extract the exact number of synthetic samples needed to exactly balance the two classes
  114. synth_set = np.array(synth_set[:numOfSamples])
  115. return synth_set
  116. def predictReal(self, data):
  117. prediction = self.maj_min_discriminator.predict(data)
  118. return np.array([x[0] for x in prediction])
  119. # ###############################################################
  120. # Hidden internal functions
  121. # ###############################################################
  122. # Creating the GAN
  123. def _conv_sample_gen(self):
  124. """
  125. the generator network to generate synthetic samples from the convex space
  126. of arbitrary minority neighbourhoods
  127. """
  128. ## takes minority batch as input
  129. min_neb_batch = Input(shape=(self.n_feat,))
  130. ## reshaping the 2D tensor to 3D for using 1-D convolution,
  131. ## otherwise 1-D convolution won't work.
  132. x = tf.reshape(min_neb_batch, (1, self.neb, self.n_feat), name=None)
  133. ## using 1-D convolution, feature dimension remains the same
  134. x = Conv1D(self.n_feat, 3, activation='relu')(x)
  135. ## flatten after convolution
  136. x = Flatten()(x)
  137. ## add dense layer to transform the vector to a convenient dimension
  138. x = Dense(self.neb * self.gen, activation='relu')(x)
  139. ## again, witching to 2-D tensor once we have the convenient shape
  140. x = Reshape((self.neb, self.gen))(x)
  141. ## row wise sum
  142. s = K.sum(x, axis=1)
  143. ## adding a small constant to always ensure the row sums are non zero.
  144. ## if this is not done then during initialization the sum can be zero.
  145. s_non_zero = Lambda(lambda x: x + .000001)(s)
  146. ## reprocals of the approximated row sum
  147. sinv = tf.math.reciprocal(s_non_zero)
  148. ## At this step we ensure that row sum is 1 for every row in x.
  149. ## That means, each row is set of convex co-efficient
  150. x = Multiply()([sinv, x])
  151. ## Now we transpose the matrix. So each column is now a set of convex coefficients
  152. aff=tf.transpose(x[0])
  153. ## We now do matrix multiplication of the affine combinations with the original
  154. ## minority batch taken as input. This generates a convex transformation
  155. ## of the input minority batch
  156. synth=tf.matmul(aff, min_neb_batch)
  157. ## finally we compile the generator with an arbitrary minortiy neighbourhood batch
  158. ## as input and a covex space transformation of the same number of samples as output
  159. model = Model(inputs=min_neb_batch, outputs=synth)
  160. opt = Adam(learning_rate=0.001)
  161. model.compile(loss='mean_squared_logarithmic_error', optimizer=opt)
  162. return model
  163. def _maj_min_disc(self):
  164. """
  165. the discriminator is trained intwo phase:
  166. first phase: while training GAN the discriminator learns to differentiate synthetic
  167. minority samples generated from convex minority data space against
  168. the borderline majority samples
  169. second phase: after the GAN generator learns to create synthetic samples,
  170. it can be used to generate synthetic samples to balance the dataset
  171. and then rettrain the discriminator with the balanced dataset
  172. """
  173. ## takes as input synthetic sample generated as input stacked upon a batch of
  174. ## borderline majority samples
  175. samples = Input(shape=(self.n_feat,))
  176. ## passed through two dense layers
  177. y = Dense(250, activation='relu')(samples)
  178. y = Dense(125, activation='relu')(y)
  179. y = Dense(75, activation='relu')(y)
  180. ## two output nodes. outputs have to be one-hot coded (see labels variable before)
  181. output = Dense(2, activation='sigmoid')(y)
  182. ## compile model
  183. model = Model(inputs=samples, outputs=output)
  184. opt = Adam(learning_rate=0.0001)
  185. model.compile(loss='binary_crossentropy', optimizer=opt)
  186. return model
  187. def _convGeN(self, generator, discriminator):
  188. """
  189. for joining the generator and the discriminator
  190. conv_coeff_generator-> generator network instance
  191. maj_min_discriminator -> discriminator network instance
  192. """
  193. ## by default the discriminator trainability is switched off.
  194. ## Thus training the GAN means training the generator network as per previously
  195. ## trained discriminator network.
  196. discriminator.trainable = False
  197. ## input receives a neighbourhood minority batch
  198. ## and a proximal majority batch concatenated
  199. batch_data = Input(shape=(self.n_feat,))
  200. ##- print(f"GAN: 0..{self.neb}/{self.gen}..")
  201. ## extract minority batch
  202. min_batch = Lambda(lambda x: x[:self.neb])(batch_data)
  203. ## extract majority batch
  204. maj_batch = Lambda(lambda x: x[self.gen:])(batch_data)
  205. ## pass minority batch into generator to obtain convex space transformation
  206. ## (synthetic samples) of the minority neighbourhood input batch
  207. conv_samples = generator(min_batch)
  208. ## concatenate the synthetic samples with the majority samples
  209. new_samples = tf.concat([conv_samples, maj_batch],axis=0)
  210. ##- new_samples = tf.concat([conv_samples, conv_samples, conv_samples, conv_samples],axis=0)
  211. ## pass the concatenated vector into the discriminator to know its decisions
  212. output = discriminator(new_samples)
  213. ##- output = Lambda(lambda x: x[:2 * self.gen])(output)
  214. ## note that, the discriminator will not be traied but will make decisions based
  215. ## on its previous training while using this function
  216. model = Model(inputs=batch_data, outputs=output)
  217. opt = Adam(learning_rate=0.0001)
  218. model.compile(loss='mse', optimizer=opt)
  219. return model
  220. # Create synthetic points
  221. def _generate_data_for_min_point(self, index, synth_num):
  222. """
  223. generate synth_num synthetic points for a particular minoity sample
  224. synth_num -> required number of data points that can be generated from a neighbourhood
  225. data_min -> minority class data
  226. neb -> oversampling neighbourhood
  227. index -> index of the minority sample in a training data whose neighbourhood we want to obtain
  228. """
  229. runs = int(synth_num / self.neb) + 1
  230. synth_set = []
  231. for _run in range(runs):
  232. batch = self.nmbMin.getNbhPointsOfItem(index)
  233. synth_batch = self.conv_sample_generator.predict(batch, batch_size=self.neb)
  234. synth_set.extend(synth_batch)
  235. return synth_set[:synth_num]
  236. # Training
  237. def _rough_learning(self, data_min, data_maj, discTrainCount):
  238. generator = self.conv_sample_generator
  239. discriminator = self.maj_min_discriminator
  240. GAN = self.cg
  241. loss_history = [] ## this is for stroring the loss for every run
  242. step = 0
  243. minSetSize = len(data_min)
  244. labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
  245. nLabels = 2 * self.gen
  246. for neb_epoch_count in range(self.neb_epochs):
  247. if discTrainCount > 0:
  248. for n in range(discTrainCount):
  249. for min_idx in range(minSetSize):
  250. ## generate minority neighbourhood batch for every minority class sampls by index
  251. min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
  252. min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
  253. ## generate random proximal majority batch
  254. maj_batch = self._BMB(data_maj, min_batch_indices)
  255. ## generate synthetic samples from convex space
  256. ## of minority neighbourhood batch using generator
  257. conv_samples = generator.predict(min_batch, batch_size=self.neb)
  258. ## concatenate them with the majority batch
  259. concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
  260. ## switch on discriminator training
  261. discriminator.trainable = True
  262. ## train the discriminator with the concatenated samples and the one-hot encoded labels
  263. discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=20)
  264. ## switch off the discriminator training again
  265. discriminator.trainable = False
  266. for min_idx in range(minSetSize):
  267. ## generate minority neighbourhood batch for every minority class sampls by index
  268. min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
  269. min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
  270. ## generate random proximal majority batch
  271. maj_batch = self._BMB(data_maj, min_batch_indices)
  272. ## generate synthetic samples from convex space
  273. ## of minority neighbourhood batch using generator
  274. conv_samples = generator.predict(min_batch, batch_size=self.neb)
  275. ## concatenate them with the majority batch
  276. concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
  277. ## switch on discriminator training
  278. discriminator.trainable = True
  279. ## train the discriminator with the concatenated samples and the one-hot encoded labels
  280. discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=20)
  281. ## switch off the discriminator training again
  282. discriminator.trainable = False
  283. ## use the GAN to make the generator learn on the decisions
  284. ## made by the previous discriminator training
  285. ##- print(f"concat sample shape: {concat_sample.shape}/{labels.shape}")
  286. gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0, batch_size=nLabels)
  287. ## store the loss for the step
  288. loss_history.append(gan_loss_history.history['loss'])
  289. step += 1
  290. if self.debug and (step % 10 == 0):
  291. print(f"{step} neighbourhood batches trained; running neighbourhood epoch {neb_epoch_count}")
  292. if self.debug:
  293. print(f"Neighbourhood epoch {neb_epoch_count + 1} complete")
  294. if self.debug:
  295. run_range = range(1, len(loss_history) + 1)
  296. plt.rcParams["figure.figsize"] = (16,10)
  297. plt.xticks(fontsize=20)
  298. plt.yticks(fontsize=20)
  299. plt.xlabel('runs', fontsize=25)
  300. plt.ylabel('loss', fontsize=25)
  301. plt.title('Rough learning loss for discriminator', fontsize=25)
  302. plt.plot(run_range, loss_history)
  303. plt.show()
  304. self.conv_sample_generator = generator
  305. self.maj_min_discriminator = discriminator
  306. self.cg = GAN
  307. self.loss_history = loss_history
  308. def _BMB(self, data_maj, min_idxs):
  309. ## Generate a borderline majority batch
  310. ## data_maj -> majority class data
  311. ## min_idxs -> indices of points in minority class
  312. ## gen -> convex combinations generated from each neighbourhood
  313. if self.nmbMaj is not None:
  314. return self.nmbMaj.neighbourhoodOfItemList(shuffle(min_idxs), maxCount=self.gen)
  315. else:
  316. return tf.convert_to_tensor(data_maj[np.random.randint(len(data_maj), size=self.gen)])
  317. def retrainDiscriminitor(self, data, labels):
  318. self.maj_min_discriminator.trainable = True
  319. labels = np.array([ [x, 1 - x] for x in labels])
  320. self.maj_min_discriminator.fit(x=data, y=labels, batch_size=20, epochs=self.neb_epochs)
  321. self.maj_min_discriminator.trainable = False