# NextConvGeN.py: ConvGeN synthetic point generator (Keras/TensorFlow).

import numpy as np
import matplotlib.pyplot as plt
from library.interfaces import GanBaseClass
# NOTE(review): DataSet is not referenced in the code shown here.
from library.dataset import DataSet
from library.timing import timing
from keras.layers import Dense, Input, Multiply, Flatten, Conv1D, Reshape
from keras.models import Model
from keras import backend as K
# NOTE(review): tqdm is not referenced in the code shown here.
from tqdm import tqdm
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Lambda
from sklearn.utils import shuffle
from library.NNSearch import NNSearch
import warnings
# NOTE(review): this runs at import time and adds an "ignore" entry for every
# warning category to the process-wide warnings filter list, so it also hides
# warnings raised by any code that merely imports this module.
warnings.filterwarnings("ignore")
  17. def repeat(x, times):
  18. return [x for _i in range(times)]
  19. def create01Labels(totalSize, sizeFirstHalf):
  20. labels = repeat(np.array([1,0]), sizeFirstHalf)
  21. labels.extend(repeat(np.array([0,1]), totalSize - sizeFirstHalf))
  22. return np.array(labels)
  23. class NextConvGeN(GanBaseClass):
  24. """
  25. This is the ConvGeN class. ConvGeN is a synthetic point generator for imbalanced datasets.
  26. """
  27. def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, fdc=None, maj_proximal=False, debug=False):
  28. self.isTrained = False
  29. self.n_feat = n_feat
  30. self.neb = neb
  31. self.nebInitial = neb
  32. self.genInitial = gen
  33. self.gen = gen if gen is not None else self.neb
  34. self.neb_epochs = neb_epochs
  35. self.loss_history = None
  36. self.debug = debug
  37. self.minSetSize = 0
  38. self.conv_sample_generator = None
  39. self.maj_min_discriminator = None
  40. self.maj_proximal = maj_proximal
  41. self.cg = None
  42. self.canPredict = True
  43. self.fdc = fdc
  44. self.lastProgress = (-1,-1,-1)
  45. self.timing = { n: timing(n) for n in [
  46. "Train", "BMB", "NbhSearch", "NBH", "GenSamples", "Fit"
  47. ] }
  48. if self.neb is not None and self.gen is not None and self.neb > self.gen:
  49. raise ValueError(f"Expected neb <= gen but got neb={neb} and gen={gen}.")
  50. def reset(self, data):
  51. """
  52. Creates the network.
  53. *dataSet* is a instance of /library.dataset.DataSet/ or None.
  54. It contains the training dataset.
  55. It is used to determine the neighbourhood size if /neb/ in /__init__/ was None.
  56. """
  57. self.isTrained = False
  58. if data is not None:
  59. nMinoryPoints = data.shape[0]
  60. if self.nebInitial is None:
  61. self.neb = nMinoryPoints
  62. else:
  63. self.neb = min(self.nebInitial, nMinoryPoints)
  64. else:
  65. self.neb = self.nebInitial
  66. self.gen = self.genInitial if self.genInitial is not None else self.neb
  67. ## instanciate generator network and visualize architecture
  68. self.conv_sample_generator = self._conv_sample_gen()
  69. ## instanciate discriminator network and visualize architecture
  70. self.maj_min_discriminator = self._maj_min_disc()
  71. ## instanciate network and visualize architecture
  72. self.cg = self._convGeN(self.conv_sample_generator, self.maj_min_discriminator)
  73. self.lastProgress = (-1,-1,-1)
  74. if self.debug:
  75. print(f"neb={self.neb}, gen={self.gen}")
  76. print(self.conv_sample_generator.summary())
  77. print('\n')
  78. print(self.maj_min_discriminator.summary())
  79. print('\n')
  80. print(self.cg.summary())
  81. print('\n')
  82. def train(self, data, discTrainCount=5):
  83. """
  84. Trains the Network.
  85. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  86. *discTrainCount* gives the number of extra training for the discriminator for each epoch. (>= 0)
  87. """
  88. if data.shape[0] <= 0:
  89. raise AttributeError("Train: Expected data class 1 to contain at least one point.")
  90. self.timing["Train"].start()
  91. # Store size of minority class. This is needed during point generation.
  92. self.minSetSize = data.shape[0]
  93. normalizedData = data
  94. if self.fdc is not None:
  95. normalizedData = self.fdc.normalize(data)
  96. print(f"|N| = {normalizedData.shape}")
  97. print(f"|D| = {data.shape}")
  98. self.timing["NbhSearch"].start()
  99. # Precalculate neighborhoods
  100. self.nmbMin = NNSearch(self.neb).fit(haystack=normalizedData)
  101. self.nmbMin.basePoints = data
  102. self.timing["NbhSearch"].stop()
  103. # Do the training.
  104. self._rough_learning(data, discTrainCount)
  105. # Neighborhood in majority class is no longer needed. So save memory.
  106. self.isTrained = True
  107. self.timing["Train"].stop()
  108. def generateDataPoint(self):
  109. """
  110. Returns one synthetic data point by repeating the stored list.
  111. """
  112. return (self.generateData(1))[0]
  113. def generateData(self, numOfSamples=1):
  114. """
  115. Generates a list of synthetic data-points.
  116. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  117. """
  118. if not self.isTrained:
  119. raise ValueError("Try to generate data with untrained network.")
  120. ## roughly claculate the upper bound of the synthetic samples to be generated from each neighbourhood
  121. synth_num = (numOfSamples // self.minSetSize) + 1
  122. ## generate synth_num synthetic samples from each minority neighbourhood
  123. synth_set=[]
  124. for i in range(self.minSetSize):
  125. synth_set.extend(self._generate_data_for_min_point(i, synth_num))
  126. ## extract the exact number of synthetic samples needed to exactly balance the two classes
  127. synth_set = np.array(synth_set[:numOfSamples])
  128. return synth_set
  129. def predictReal(self, data):
  130. """
  131. Uses the discriminator on data.
  132. *data* is a numpy array of shape (n, n_feat) where n is the number of datapoints and n_feat the number of features.
  133. """
  134. prediction = self.maj_min_discriminator.predict(data)
  135. return np.array([x[0] for x in prediction])
  136. # ###############################################################
  137. # Hidden internal functions
  138. # ###############################################################
  139. # Creating the Network: Generator
  140. def _conv_sample_gen(self):
  141. """
  142. The generator network to generate synthetic samples from the convex space
  143. of arbitrary minority neighbourhoods
  144. """
  145. ## takes minority batch as input
  146. min_neb_batch = Input(shape=(self.n_feat,))
  147. ## reshaping the 2D tensor to 3D for using 1-D convolution,
  148. ## otherwise 1-D convolution won't work.
  149. x = tf.reshape(min_neb_batch, (1, self.neb, self.n_feat), name=None)
  150. ## using 1-D convolution, feature dimension remains the same
  151. x = Conv1D(self.n_feat, 3, activation='relu')(x)
  152. ## flatten after convolution
  153. x = Flatten()(x)
  154. ## add dense layer to transform the vector to a convenient dimension
  155. x = Dense(self.neb * self.gen, activation='relu')(x)
  156. ## again, witching to 2-D tensor once we have the convenient shape
  157. x = Reshape((self.neb, self.gen))(x)
  158. ## column wise sum
  159. s = K.sum(x, axis=1)
  160. ## adding a small constant to always ensure the column sums are non zero.
  161. ## if this is not done then during initialization the sum can be zero.
  162. s_non_zero = Lambda(lambda x: x + .000001)(s)
  163. ## reprocals of the approximated column sum
  164. sinv = tf.math.reciprocal(s_non_zero)
  165. ## At this step we ensure that column sum is 1 for every row in x.
  166. ## That means, each column is set of convex co-efficient
  167. x = Multiply()([sinv, x])
  168. ## Now we transpose the matrix. So each row is now a set of convex coefficients
  169. aff=tf.transpose(x[0])
  170. ## We now do matrix multiplication of the affine combinations with the original
  171. ## minority batch taken as input. This generates a convex transformation
  172. ## of the input minority batch
  173. synth=tf.matmul(aff, min_neb_batch)
  174. ## finally we compile the generator with an arbitrary minortiy neighbourhood batch
  175. ## as input and a covex space transformation of the same number of samples as output
  176. model = Model(inputs=min_neb_batch, outputs=synth)
  177. opt = Adam(learning_rate=0.001)
  178. model.compile(loss='mean_squared_logarithmic_error', optimizer=opt)
  179. return model
  180. # Creating the Network: discriminator
  181. def _maj_min_disc(self):
  182. """
  183. the discriminator is trained in two phase:
  184. first phase: while training ConvGeN the discriminator learns to differentiate synthetic
  185. minority samples generated from convex minority data space against
  186. the borderline majority samples
  187. second phase: after the ConvGeN generator learns to create synthetic samples,
  188. it can be used to generate synthetic samples to balance the dataset
  189. and then rettrain the discriminator with the balanced dataset
  190. """
  191. ## takes as input synthetic sample generated as input stacked upon a batch of
  192. ## borderline majority samples
  193. samples = Input(shape=(self.n_feat,))
  194. ## passed through two dense layers
  195. y = Dense(250, activation='relu')(samples)
  196. y = Dense(125, activation='relu')(y)
  197. y = Dense(75, activation='relu')(y)
  198. ## two output nodes. outputs have to be one-hot coded (see labels variable before)
  199. output = Dense(2, activation='sigmoid')(y)
  200. ## compile model
  201. model = Model(inputs=samples, outputs=output)
  202. opt = Adam(learning_rate=0.0001)
  203. model.compile(loss='binary_crossentropy', optimizer=opt)
  204. return model
  205. # Creating the Network: ConvGeN
  206. def _convGeN(self, generator, discriminator):
  207. """
  208. for joining the generator and the discriminator
  209. conv_coeff_generator-> generator network instance
  210. maj_min_discriminator -> discriminator network instance
  211. """
  212. ## by default the discriminator trainability is switched off.
  213. ## Thus training ConvGeN means training the generator network as per previously
  214. ## trained discriminator network.
  215. discriminator.trainable = False
  216. ## input receives a neighbourhood minority batch
  217. ## and a proximal majority batch concatenated
  218. batch_data = Input(shape=(self.n_feat,))
  219. ## extract minority batch
  220. min_batch = Lambda(lambda x: x[:self.neb])(batch_data)
  221. ## extract majority batch
  222. maj_batch = Lambda(lambda x: x[self.gen:])(batch_data)
  223. ## pass minority batch into generator to obtain convex space transformation
  224. ## (synthetic samples) of the minority neighbourhood input batch
  225. conv_samples = generator(min_batch)
  226. ## concatenate the synthetic samples with the majority samples
  227. new_samples = tf.concat([conv_samples, maj_batch],axis=0)
  228. ## pass the concatenated vector into the discriminator to know its decisions
  229. output = discriminator(new_samples)
  230. ## note that, the discriminator will not be traied but will make decisions based
  231. ## on its previous training while using this function
  232. model = Model(inputs=batch_data, outputs=output)
  233. opt = Adam(learning_rate=0.0001)
  234. model.compile(loss='mse', optimizer=opt)
  235. return model
  236. # Create synthetic points
  237. def _generate_data_for_min_point(self, index, synth_num):
  238. """
  239. generate synth_num synthetic points for a particular minoity sample
  240. synth_num -> required number of data points that can be generated from a neighbourhood
  241. data_min -> minority class data
  242. neb -> oversampling neighbourhood
  243. index -> index of the minority sample in a training data whose neighbourhood we want to obtain
  244. """
  245. runs = int(synth_num / self.neb) + 1
  246. synth_set = []
  247. for _run in range(runs):
  248. batch = self.nmbMin.getNbhPointsOfItem(index)
  249. synth_batch = self.conv_sample_generator.predict(batch, batch_size=self.neb)
  250. synth_batch = self.correct_feature_types(batch, synth_batch)
  251. synth_set.extend(synth_batch)
  252. return synth_set[:synth_num]
  253. # Training
  254. def _rough_learning(self, data, discTrainCount):
  255. generator = self.conv_sample_generator
  256. discriminator = self.maj_min_discriminator
  257. convGeN = self.cg
  258. loss_history = [] ## this is for stroring the loss for every run
  259. step = 0
  260. minSetSize = len(data)
  261. labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
  262. nLabels = 2 * self.gen
  263. for neb_epoch_count in range(self.neb_epochs):
  264. if discTrainCount > 0:
  265. for n in range(discTrainCount):
  266. for min_idx in range(minSetSize):
  267. self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, n / discTrainCount, (min_idx + 1) / minSetSize])
  268. self.timing["NBH"].start()
  269. ## generate minority neighbourhood batch for every minority class sampls by index
  270. min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
  271. min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
  272. ## generate random proximal majority batch
  273. maj_batch = self._BMB(min_batch_indices)
  274. self.timing["NBH"].stop()
  275. self.timing["GenSamples"].start()
  276. ## generate synthetic samples from convex space
  277. ## of minority neighbourhood batch using generator
  278. conv_samples = generator.predict(min_batch, batch_size=self.neb)
  279. ## concatenate them with the majority batch
  280. concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
  281. self.timing["GenSamples"].stop()
  282. self.timing["Fit"].start()
  283. ## switch on discriminator training
  284. discriminator.trainable = True
  285. ## train the discriminator with the concatenated samples and the one-hot encoded labels
  286. discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=20)
  287. ## switch off the discriminator training again
  288. discriminator.trainable = False
  289. self.timing["Fit"].stop()
  290. for min_idx in range(minSetSize):
  291. self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, 1.0, (min_idx + 1) / minSetSize])
  292. ## generate minority neighbourhood batch for every minority class sampls by index
  293. min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
  294. min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
  295. ## generate random proximal majority batch
  296. maj_batch = self._BMB(min_batch_indices)
  297. ## generate synthetic samples from convex space
  298. ## of minority neighbourhood batch using generator
  299. conv_samples = generator.predict(min_batch, batch_size=self.neb)
  300. ## concatenate them with the majority batch
  301. concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
  302. ## switch on discriminator training
  303. discriminator.trainable = True
  304. ## train the discriminator with the concatenated samples and the one-hot encoded labels
  305. discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=20)
  306. ## switch off the discriminator training again
  307. discriminator.trainable = False
  308. ## use the complete network to make the generator learn on the decisions
  309. ## made by the previous discriminator training
  310. gen_loss_history = convGeN.fit(concat_sample, y=labels, verbose=0, batch_size=nLabels)
  311. ## store the loss for the step
  312. loss_history.append(gen_loss_history.history['loss'])
  313. step += 1
  314. if self.debug and (step % 10 == 0):
  315. print(f"{step} neighbourhood batches trained; running neighbourhood epoch {neb_epoch_count}")
  316. if self.debug:
  317. print(f"Neighbourhood epoch {neb_epoch_count + 1} complete")
  318. if self.debug:
  319. run_range = range(1, len(loss_history) + 1)
  320. plt.rcParams["figure.figsize"] = (16,10)
  321. plt.xticks(fontsize=20)
  322. plt.yticks(fontsize=20)
  323. plt.xlabel('runs', fontsize=25)
  324. plt.ylabel('loss', fontsize=25)
  325. plt.title('Rough learning loss for discriminator', fontsize=25)
  326. plt.plot(run_range, loss_history)
  327. plt.show()
  328. self.conv_sample_generator = generator
  329. self.maj_min_discriminator = discriminator
  330. self.cg = convGeN
  331. self.loss_history = loss_history
  332. def _BMB(self, min_idxs):
  333. ## Generate a borderline majority batch
  334. ## data_maj -> majority class data
  335. ## min_idxs -> indices of points in minority class
  336. ## gen -> convex combinations generated from each neighbourhood
  337. self.timing["BMB"].start()
  338. indices = [i for i in range(self.minSetSize) if i not in min_idxs]
  339. r = np.array([self.nmbMin.basePoints[i] for i in shuffle(indices)[0:self.gen]])
  340. self.timing["BMB"].stop()
  341. return r
  342. def retrainDiscriminitor(self, data, labels):
  343. self.maj_min_discriminator.trainable = True
  344. labels = np.array([ [x, 1 - x] for x in labels])
  345. self.maj_min_discriminator.fit(x=data, y=labels, batch_size=20, epochs=self.neb_epochs)
  346. self.maj_min_discriminator.trainable = False
  347. def progressBar(self, x):
  348. x = [int(v * 10) for v in x]
  349. if True not in [self.lastProgress[i] != x[i] for i in range(len(self.lastProgress))]:
  350. return
  351. def bar(v):
  352. r = ""
  353. for n in range(10):
  354. if n > v:
  355. r += " "
  356. else:
  357. r += "="
  358. return r
  359. s = [bar(v) for v in x]
  360. print(f"[{s[0]}] [{s[1]}] [{s[2]}]", end="\r")
  361. def correct_feature_types(self, batch, synth_batch):
  362. if self.fdc is None:
  363. return synth_batch
  364. def bestMatchOf(referenceValues, value):
  365. best = referenceValues[0]
  366. d = abs(best - value)
  367. for x in referenceValues:
  368. dx = abs(x - value)
  369. if dx < d:
  370. best = x
  371. d = dx
  372. return best
  373. synth_batch = list(synth_batch)
  374. for i in (self.fdc.nom_list or []):
  375. referenceValues = list(set(list(batch[:, i].numpy())))
  376. for x in synth_batch:
  377. y = x[i]
  378. x[i] = bestMatchOf(referenceValues, y)
  379. for i in (self.fdc.ord_list or []):
  380. referenceValues = list(set(list(batch[:, i].numpy())))
  381. for x in synth_batch:
  382. x[i] = bestMatchOf(referenceValues, x[i])
  383. return np.array(synth_batch)