# NextConvGeN.py
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from library.interfaces import GanBaseClass
  4. from library.dataset import DataSet
  5. from library.timing import timing
  6. from keras.layers import Dense, Input, Multiply, Flatten, Conv1D, Reshape
  7. from keras.models import Model
  8. from keras import backend as K
  9. from tqdm import tqdm
  10. import tensorflow as tf
  11. from tensorflow.keras.optimizers import Adam
  12. from tensorflow.keras.layers import Lambda
  13. from sklearn.utils import shuffle
  14. from library.NNSearch import NNSearch
  15. import warnings
  16. warnings.filterwarnings("ignore")
  17. def repeat(x, times):
  18. return [x for _i in range(times)]
  19. def create01Labels(totalSize, sizeFirstHalf):
  20. labels = repeat(np.array([1,0]), sizeFirstHalf)
  21. labels.extend(repeat(np.array([0,1]), totalSize - sizeFirstHalf))
  22. return np.array(labels)
  23. class NextConvGeN(GanBaseClass):
  24. """
  25. This is the ConvGeN class. ConvGeN is a synthetic point generator for imbalanced datasets.
  26. """
  27. def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, fdc=None, maj_proximal=False, debug=False):
  28. self.isTrained = False
  29. self.n_feat = n_feat
  30. self.neb = neb
  31. self.nebInitial = neb
  32. self.genInitial = gen
  33. self.gen = gen if gen is not None else self.neb
  34. self.neb_epochs = neb_epochs
  35. self.loss_history = None
  36. self.debug = debug
  37. self.minSetSize = 0
  38. self.conv_sample_generator = None
  39. self.maj_min_discriminator = None
  40. self.maj_proximal = maj_proximal
  41. self.cg = None
  42. self.canPredict = True
  43. self.fdc = fdc
  44. self.timing = { n: timing(n) for n in [
  45. "Train", "BMB", "NbhSearch"
  46. ] }
  47. if self.neb is not None and self.gen is not None and self.neb > self.gen:
  48. raise ValueError(f"Expected neb <= gen but got neb={neb} and gen={gen}.")
  49. def reset(self, data):
  50. """
  51. Creates the network.
  52. *dataSet* is a instance of /library.dataset.DataSet/ or None.
  53. It contains the training dataset.
  54. It is used to determine the neighbourhood size if /neb/ in /__init__/ was None.
  55. """
  56. self.isTrained = False
  57. if data is not None:
  58. nMinoryPoints = data.shape[0]
  59. if self.nebInitial is None:
  60. self.neb = nMinoryPoints
  61. else:
  62. self.neb = min(self.nebInitial, nMinoryPoints)
  63. else:
  64. self.neb = self.nebInitial
  65. self.gen = self.genInitial if self.genInitial is not None else self.neb
  66. ## instanciate generator network and visualize architecture
  67. self.conv_sample_generator = self._conv_sample_gen()
  68. ## instanciate discriminator network and visualize architecture
  69. self.maj_min_discriminator = self._maj_min_disc()
  70. ## instanciate network and visualize architecture
  71. self.cg = self._convGeN(self.conv_sample_generator, self.maj_min_discriminator)
  72. if self.debug:
  73. print(f"neb={self.neb}, gen={self.gen}")
  74. print(self.conv_sample_generator.summary())
  75. print('\n')
  76. print(self.maj_min_discriminator.summary())
  77. print('\n')
  78. print(self.cg.summary())
  79. print('\n')
  80. def train(self, data, discTrainCount=5):
  81. """
  82. Trains the Network.
  83. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  84. *discTrainCount* gives the number of extra training for the discriminator for each epoch. (>= 0)
  85. """
  86. if data.shape[0] <= 0:
  87. raise AttributeError("Train: Expected data class 1 to contain at least one point.")
  88. self.timing["Train"].start()
  89. # Store size of minority class. This is needed during point generation.
  90. self.minSetSize = data.shape[0]
  91. normalizedData = data
  92. if self.fdc is not None:
  93. normalizedData = self.fdc.normalize(data)
  94. self.timing["NbhSearch"].start()
  95. # Precalculate neighborhoods
  96. self.nmbMin = NNSearch(self.neb).fit(haystack=normalizedData)
  97. self.nmbMin.basePoints = data
  98. self.timing["NbhSearch"].stop()
  99. # Do the training.
  100. self._rough_learning(data, discTrainCount)
  101. # Neighborhood in majority class is no longer needed. So save memory.
  102. self.isTrained = True
  103. self.timing["Train"].stop()
  104. def generateDataPoint(self):
  105. """
  106. Returns one synthetic data point by repeating the stored list.
  107. """
  108. return (self.generateData(1))[0]
  109. def generateData(self, numOfSamples=1):
  110. """
  111. Generates a list of synthetic data-points.
  112. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  113. """
  114. if not self.isTrained:
  115. raise ValueError("Try to generate data with untrained network.")
  116. ## roughly claculate the upper bound of the synthetic samples to be generated from each neighbourhood
  117. synth_num = (numOfSamples // self.minSetSize) + 1
  118. ## generate synth_num synthetic samples from each minority neighbourhood
  119. synth_set=[]
  120. for i in range(self.minSetSize):
  121. synth_set.extend(self._generate_data_for_min_point(i, synth_num))
  122. ## extract the exact number of synthetic samples needed to exactly balance the two classes
  123. synth_set = np.array(synth_set[:numOfSamples])
  124. return synth_set
  125. def predictReal(self, data):
  126. """
  127. Uses the discriminator on data.
  128. *data* is a numpy array of shape (n, n_feat) where n is the number of datapoints and n_feat the number of features.
  129. """
  130. prediction = self.maj_min_discriminator.predict(data)
  131. return np.array([x[0] for x in prediction])
  132. # ###############################################################
  133. # Hidden internal functions
  134. # ###############################################################
  135. # Creating the Network: Generator
  136. def _conv_sample_gen(self):
  137. """
  138. The generator network to generate synthetic samples from the convex space
  139. of arbitrary minority neighbourhoods
  140. """
  141. ## takes minority batch as input
  142. min_neb_batch = Input(shape=(self.n_feat,))
  143. ## reshaping the 2D tensor to 3D for using 1-D convolution,
  144. ## otherwise 1-D convolution won't work.
  145. x = tf.reshape(min_neb_batch, (1, self.neb, self.n_feat), name=None)
  146. ## using 1-D convolution, feature dimension remains the same
  147. x = Conv1D(self.n_feat, 3, activation='relu')(x)
  148. ## flatten after convolution
  149. x = Flatten()(x)
  150. ## add dense layer to transform the vector to a convenient dimension
  151. x = Dense(self.neb * self.gen, activation='relu')(x)
  152. ## again, witching to 2-D tensor once we have the convenient shape
  153. x = Reshape((self.neb, self.gen))(x)
  154. ## column wise sum
  155. s = K.sum(x, axis=1)
  156. ## adding a small constant to always ensure the column sums are non zero.
  157. ## if this is not done then during initialization the sum can be zero.
  158. s_non_zero = Lambda(lambda x: x + .000001)(s)
  159. ## reprocals of the approximated column sum
  160. sinv = tf.math.reciprocal(s_non_zero)
  161. ## At this step we ensure that column sum is 1 for every row in x.
  162. ## That means, each column is set of convex co-efficient
  163. x = Multiply()([sinv, x])
  164. ## Now we transpose the matrix. So each row is now a set of convex coefficients
  165. aff=tf.transpose(x[0])
  166. ## We now do matrix multiplication of the affine combinations with the original
  167. ## minority batch taken as input. This generates a convex transformation
  168. ## of the input minority batch
  169. synth=tf.matmul(aff, min_neb_batch)
  170. ## finally we compile the generator with an arbitrary minortiy neighbourhood batch
  171. ## as input and a covex space transformation of the same number of samples as output
  172. model = Model(inputs=min_neb_batch, outputs=synth)
  173. opt = Adam(learning_rate=0.001)
  174. model.compile(loss='mean_squared_logarithmic_error', optimizer=opt)
  175. return model
  176. # Creating the Network: discriminator
  177. def _maj_min_disc(self):
  178. """
  179. the discriminator is trained in two phase:
  180. first phase: while training ConvGeN the discriminator learns to differentiate synthetic
  181. minority samples generated from convex minority data space against
  182. the borderline majority samples
  183. second phase: after the ConvGeN generator learns to create synthetic samples,
  184. it can be used to generate synthetic samples to balance the dataset
  185. and then rettrain the discriminator with the balanced dataset
  186. """
  187. ## takes as input synthetic sample generated as input stacked upon a batch of
  188. ## borderline majority samples
  189. samples = Input(shape=(self.n_feat,))
  190. ## passed through two dense layers
  191. y = Dense(250, activation='relu')(samples)
  192. y = Dense(125, activation='relu')(y)
  193. y = Dense(75, activation='relu')(y)
  194. ## two output nodes. outputs have to be one-hot coded (see labels variable before)
  195. output = Dense(2, activation='sigmoid')(y)
  196. ## compile model
  197. model = Model(inputs=samples, outputs=output)
  198. opt = Adam(learning_rate=0.0001)
  199. model.compile(loss='binary_crossentropy', optimizer=opt)
  200. return model
  201. # Creating the Network: ConvGeN
  202. def _convGeN(self, generator, discriminator):
  203. """
  204. for joining the generator and the discriminator
  205. conv_coeff_generator-> generator network instance
  206. maj_min_discriminator -> discriminator network instance
  207. """
  208. ## by default the discriminator trainability is switched off.
  209. ## Thus training ConvGeN means training the generator network as per previously
  210. ## trained discriminator network.
  211. discriminator.trainable = False
  212. ## input receives a neighbourhood minority batch
  213. ## and a proximal majority batch concatenated
  214. batch_data = Input(shape=(self.n_feat,))
  215. ## extract minority batch
  216. min_batch = Lambda(lambda x: x[:self.neb])(batch_data)
  217. ## extract majority batch
  218. maj_batch = Lambda(lambda x: x[self.gen:])(batch_data)
  219. ## pass minority batch into generator to obtain convex space transformation
  220. ## (synthetic samples) of the minority neighbourhood input batch
  221. conv_samples = generator(min_batch)
  222. ## concatenate the synthetic samples with the majority samples
  223. new_samples = tf.concat([conv_samples, maj_batch],axis=0)
  224. ## pass the concatenated vector into the discriminator to know its decisions
  225. output = discriminator(new_samples)
  226. ## note that, the discriminator will not be traied but will make decisions based
  227. ## on its previous training while using this function
  228. model = Model(inputs=batch_data, outputs=output)
  229. opt = Adam(learning_rate=0.0001)
  230. model.compile(loss='mse', optimizer=opt)
  231. return model
  232. # Create synthetic points
  233. def _generate_data_for_min_point(self, index, synth_num):
  234. """
  235. generate synth_num synthetic points for a particular minoity sample
  236. synth_num -> required number of data points that can be generated from a neighbourhood
  237. data_min -> minority class data
  238. neb -> oversampling neighbourhood
  239. index -> index of the minority sample in a training data whose neighbourhood we want to obtain
  240. """
  241. runs = int(synth_num / self.neb) + 1
  242. synth_set = []
  243. for _run in range(runs):
  244. batch = self.nmbMin.getNbhPointsOfItem(index)
  245. synth_batch = self.conv_sample_generator.predict(batch, batch_size=self.neb)
  246. synth_set.extend(synth_batch)
  247. return synth_set[:synth_num]
  248. # Training
  249. def _rough_learning(self, data, discTrainCount):
  250. generator = self.conv_sample_generator
  251. discriminator = self.maj_min_discriminator
  252. convGeN = self.cg
  253. loss_history = [] ## this is for stroring the loss for every run
  254. step = 0
  255. minSetSize = len(data)
  256. labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
  257. nLabels = 2 * self.gen
  258. for neb_epoch_count in range(self.neb_epochs):
  259. print(f"NEB EPOCH #{neb_epoch_count + 1} / {self.neb_epochs}")
  260. if discTrainCount > 0:
  261. for n in range(discTrainCount):
  262. print(f"discTrain #{n + 1} / {discTrainCount}")
  263. for min_idx in range(minSetSize):
  264. ## generate minority neighbourhood batch for every minority class sampls by index
  265. min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
  266. min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
  267. ## generate random proximal majority batch
  268. maj_batch = self._BMB(min_batch_indices)
  269. ## generate synthetic samples from convex space
  270. ## of minority neighbourhood batch using generator
  271. conv_samples = generator.predict(min_batch, batch_size=self.neb)
  272. ## concatenate them with the majority batch
  273. concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
  274. ## switch on discriminator training
  275. discriminator.trainable = True
  276. ## train the discriminator with the concatenated samples and the one-hot encoded labels
  277. discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=20)
  278. ## switch off the discriminator training again
  279. discriminator.trainable = False
  280. for min_idx in range(minSetSize):
  281. ## generate minority neighbourhood batch for every minority class sampls by index
  282. min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
  283. min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
  284. ## generate random proximal majority batch
  285. maj_batch = self._BMB(min_batch_indices)
  286. ## generate synthetic samples from convex space
  287. ## of minority neighbourhood batch using generator
  288. conv_samples = generator.predict(min_batch, batch_size=self.neb)
  289. ## concatenate them with the majority batch
  290. concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
  291. ## switch on discriminator training
  292. discriminator.trainable = True
  293. ## train the discriminator with the concatenated samples and the one-hot encoded labels
  294. discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=20)
  295. ## switch off the discriminator training again
  296. discriminator.trainable = False
  297. ## use the complete network to make the generator learn on the decisions
  298. ## made by the previous discriminator training
  299. gen_loss_history = convGeN.fit(concat_sample, y=labels, verbose=0, batch_size=nLabels)
  300. ## store the loss for the step
  301. loss_history.append(gen_loss_history.history['loss'])
  302. step += 1
  303. if self.debug and (step % 10 == 0):
  304. print(f"{step} neighbourhood batches trained; running neighbourhood epoch {neb_epoch_count}")
  305. if self.debug:
  306. print(f"Neighbourhood epoch {neb_epoch_count + 1} complete")
  307. if self.debug:
  308. run_range = range(1, len(loss_history) + 1)
  309. plt.rcParams["figure.figsize"] = (16,10)
  310. plt.xticks(fontsize=20)
  311. plt.yticks(fontsize=20)
  312. plt.xlabel('runs', fontsize=25)
  313. plt.ylabel('loss', fontsize=25)
  314. plt.title('Rough learning loss for discriminator', fontsize=25)
  315. plt.plot(run_range, loss_history)
  316. plt.show()
  317. self.conv_sample_generator = generator
  318. self.maj_min_discriminator = discriminator
  319. self.cg = convGeN
  320. self.loss_history = loss_history
  321. def _BMB(self, min_idxs):
  322. ## Generate a borderline majority batch
  323. ## data_maj -> majority class data
  324. ## min_idxs -> indices of points in minority class
  325. ## gen -> convex combinations generated from each neighbourhood
  326. self.timing["BMB"].start()
  327. indices = [i for i in range(self.minSetSize) if i not in min_idxs]
  328. r = np.array([self.nmbMin.basePoints[i] for i in shuffle(indices)[0:self.gen]])
  329. self.timing["BMB"].stop()
  330. return r
  331. def retrainDiscriminitor(self, data, labels):
  332. self.maj_min_discriminator.trainable = True
  333. labels = np.array([ [x, 1 - x] for x in labels])
  334. self.maj_min_discriminator.fit(x=data, y=labels, batch_size=20, epochs=self.neb_epochs)
  335. self.maj_min_discriminator.trainable = False