NextConvGeN.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from library.interfaces import GanBaseClass
  4. from library.dataset import DataSet
  5. from keras.layers import Dense, Input, Multiply, Flatten, Conv1D, Reshape
  6. from keras.models import Model
  7. from keras import backend as K
  8. from tqdm import tqdm
  9. import tensorflow as tf
  10. from tensorflow.keras.optimizers import Adam
  11. from tensorflow.keras.layers import Lambda
  12. from sklearn.utils import shuffle
  13. from library.NNSearch import NNSearch
  14. import warnings
  15. warnings.filterwarnings("ignore")
  16. def repeat(x, times):
  17. return [x for _i in range(times)]
  18. def create01Labels(totalSize, sizeFirstHalf):
  19. labels = repeat(np.array([1,0]), sizeFirstHalf)
  20. labels.extend(repeat(np.array([0,1]), totalSize - sizeFirstHalf))
  21. return np.array(labels)
class ConvGeN(GanBaseClass):
    """
    This is the ConvGeN class. ConvGeN is a synthetic point generator for imbalanced datasets.

    A generator network learns convex combinations of minority-class
    neighbourhoods, and a discriminator learns to tell those synthetic
    minority samples apart from (borderline) majority samples.
    """

    def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, maj_proximal=False, debug=False):
        """
        *n_feat* is the number of features per data point.
        *neb* is the minority neighbourhood size; None means "use the whole
        minority class" (resolved in /reset/).
        *gen* is the number of convex combinations generated per neighbourhood;
        None means "same as neb". Requires neb <= gen.
        *neb_epochs* is the number of passes over all minority neighbourhoods.
        *maj_proximal* selects borderline (proximal) majority batches instead
        of random majority batches during training.
        *debug* enables progress printing and a loss plot.
        """
        self.isTrained = False
        self.n_feat = n_feat
        self.neb = neb
        # Keep the raw constructor values so reset() can re-derive the
        # effective neb/gen from a concrete dataset later on.
        self.nebInitial = neb
        self.genInitial = gen
        self.gen = gen if gen is not None else self.neb
        self.neb_epochs = neb_epochs
        self.loss_history = None
        self.debug = debug
        self.minSetSize = 0
        self.conv_sample_generator = None
        self.maj_min_discriminator = None
        self.maj_proximal = maj_proximal
        self.cg = None
        # Flag read by callers (interface of GanBaseClass) indicating that
        # predictReal() is available.
        self.canPredict = True
        # The generator emits gen samples from a neb-sized neighbourhood;
        # neb > gen would make the internal slicing inconsistent.
        if self.neb is not None and self.gen is not None and self.neb > self.gen:
            raise ValueError(f"Expected neb <= gen but got neb={neb} and gen={gen}.")

    def reset(self, dataSet):
        """
        Creates the network.
        *dataSet* is an instance of /library.dataset.DataSet/ or None.
        It contains the training dataset.
        It is used to determine the neighbourhood size if /neb/ in /__init__/ was None.
        """
        self.isTrained = False
        if dataSet is not None:
            # data1 holds the minority class; the neighbourhood can never be
            # larger than the minority class itself.
            nMinoryPoints = dataSet.data1.shape[0]
            if self.nebInitial is None:
                self.neb = nMinoryPoints
            else:
                self.neb = min(self.nebInitial, nMinoryPoints)
        else:
            self.neb = self.nebInitial
        self.gen = self.genInitial if self.genInitial is not None else self.neb
        ## instantiate generator network
        self.conv_sample_generator = self._conv_sample_gen()
        ## instantiate discriminator network
        self.maj_min_discriminator = self._maj_min_disc()
        ## instantiate the combined network (generator + frozen discriminator)
        self.cg = self._convGeN(self.conv_sample_generator, self.maj_min_discriminator)
        if self.debug:
            # Visualize the architectures of all three models.
            print(f"neb={self.neb}, gen={self.gen}")
            print(self.conv_sample_generator.summary())
            print('\n')
            print(self.maj_min_discriminator.summary())
            print('\n')
            print(self.cg.summary())
            print('\n')

    def train(self, dataSet, discTrainCount=5):
        """
        Trains the Network.
        *dataSet* is an instance of /library.dataset.DataSet/. It contains the training dataset.
        *discTrainCount* gives the number of extra training passes for the discriminator for each epoch. (>= 0)
        """
        if dataSet.data1.shape[0] <= 0:
            raise AttributeError("Train: Expected data class 1 to contain at least one point.")
        # Store size of minority class. This is needed during point generation.
        self.minSetSize = dataSet.data1.shape[0]
        # Precalculate neighbourhoods of each minority point within the
        # minority class.
        self.nmbMin = NNSearch(self.neb).fit(haystack=dataSet.data1)
        if self.maj_proximal:
            # Neighbourhoods of minority points within the majority class,
            # used to draw borderline majority batches in _BMB().
            self.nmbMaj = NNSearch(self.neb).fit(haystack=dataSet.data0, needles=dataSet.data1)
        else:
            self.nmbMaj = None
        # Do the training.
        self._rough_learning(dataSet.data1, dataSet.data0, discTrainCount)
        # Neighbourhood in majority class is no longer needed. So save memory.
        self.nmbMaj = None
        self.isTrained = True

    def generateDataPoint(self):
        """
        Returns one synthetic data point.
        """
        return (self.generateData(1))[0]

    def generateData(self, numOfSamples=1):
        """
        Generates a list of synthetic data-points.
        *numOfSamples* is an integer > 0. It gives the number of new generated samples.
        Raises ValueError if the network has not been trained yet.
        """
        if not self.isTrained:
            raise ValueError("Try to generate data with untrained network.")
        ## roughly calculate the upper bound of the synthetic samples to be
        ## generated from each neighbourhood (over-generates, trimmed below)
        synth_num = (numOfSamples // self.minSetSize) + 1
        ## generate synth_num synthetic samples from each minority neighbourhood
        synth_set=[]
        for i in range(self.minSetSize):
            synth_set.extend(self._generate_data_for_min_point(i, synth_num))
        ## extract the exact number of synthetic samples needed to exactly balance the two classes
        synth_set = np.array(synth_set[:numOfSamples])
        return synth_set

    def predictReal(self, data):
        """
        Uses the discriminator on data.
        *data* is a numpy array of shape (n, n_feat) where n is the number of datapoints and n_feat the number of features.
        Returns the first component of the discriminator's two-node output for
        each point, i.e. the score of the class one-hot coded as [1, 0]
        (the synthetic/minority side during rough learning).
        """
        prediction = self.maj_min_discriminator.predict(data)
        return np.array([x[0] for x in prediction])

    # ###############################################################
    # Hidden internal functions
    # ###############################################################

    # Creating the Network: Generator
    def _conv_sample_gen(self):
        """
        The generator network to generate synthetic samples from the convex space
        of arbitrary minority neighbourhoods.
        Input: a minority neighbourhood batch of shape (neb, n_feat).
        Output: gen convex combinations of those points, shape (gen, n_feat).
        """
        ## takes minority batch as input
        min_neb_batch = Input(shape=(self.n_feat,))
        ## reshaping the 2D tensor to 3D for using 1-D convolution,
        ## otherwise 1-D convolution won't work.
        ## NOTE(review): the fixed batch dimension 1 here hard-wires the
        ## model to inputs of exactly neb rows — confirm callers always
        ## predict with batch_size=self.neb.
        x = tf.reshape(min_neb_batch, (1, self.neb, self.n_feat), name=None)
        ## using 1-D convolution, feature dimension remains the same
        x = Conv1D(self.n_feat, 3, activation='relu')(x)
        ## flatten after convolution
        x = Flatten()(x)
        ## add dense layer to transform the vector to a convenient dimension
        x = Dense(self.neb * self.gen, activation='relu')(x)
        ## again, switching to 2-D tensor once we have the convenient shape
        x = Reshape((self.neb, self.gen))(x)
        ## column wise sum
        s = K.sum(x, axis=1)
        ## adding a small constant to always ensure the column sums are non zero.
        ## if this is not done then during initialization the sum can be zero.
        s_non_zero = Lambda(lambda x: x + .000001)(s)
        ## reciprocals of the approximated column sum
        sinv = tf.math.reciprocal(s_non_zero)
        ## At this step we ensure that column sum is 1 for every row in x.
        ## That means, each column is a set of convex coefficients.
        x = Multiply()([sinv, x])
        ## Now we transpose the matrix. So each row is now a set of convex coefficients
        aff=tf.transpose(x[0])
        ## We now do matrix multiplication of the affine combinations with the original
        ## minority batch taken as input. This generates a convex transformation
        ## of the input minority batch
        synth=tf.matmul(aff, min_neb_batch)
        ## finally we compile the generator with an arbitrary minority neighbourhood batch
        ## as input and a convex space transformation of the same number of samples as output
        model = Model(inputs=min_neb_batch, outputs=synth)
        opt = Adam(learning_rate=0.001)
        model.compile(loss='mean_squared_logarithmic_error', optimizer=opt)
        return model

    # Creating the Network: discriminator
    def _maj_min_disc(self):
        """
        The discriminator is trained in two phases:
        first phase: while training ConvGeN the discriminator learns to differentiate synthetic
                     minority samples generated from convex minority data space against
                     the borderline majority samples
        second phase: after the ConvGeN generator learns to create synthetic samples,
                      it can be used to generate synthetic samples to balance the dataset
                      and then retrain the discriminator with the balanced dataset
        """
        ## takes as input synthetic samples stacked upon a batch of
        ## borderline majority samples
        samples = Input(shape=(self.n_feat,))
        ## passed through three dense layers
        y = Dense(250, activation='relu')(samples)
        y = Dense(125, activation='relu')(y)
        y = Dense(75, activation='relu')(y)
        ## two output nodes. outputs have to be one-hot coded (see labels variable before)
        output = Dense(2, activation='sigmoid')(y)
        ## compile model
        model = Model(inputs=samples, outputs=output)
        opt = Adam(learning_rate=0.0001)
        model.compile(loss='binary_crossentropy', optimizer=opt)
        return model

    # Creating the Network: ConvGeN
    def _convGeN(self, generator, discriminator):
        """
        For joining the generator and the discriminator.
        generator -> generator network instance
        discriminator -> discriminator network instance
        """
        ## by default the discriminator trainability is switched off.
        ## Thus training ConvGeN means training the generator network as per previously
        ## trained discriminator network.
        discriminator.trainable = False
        ## input receives a minority-derived batch and a proximal majority
        ## batch concatenated (during rough learning this is gen synthetic
        ## samples followed by gen majority samples, 2*gen rows total)
        batch_data = Input(shape=(self.n_feat,))
        ## extract the first neb rows as the minority-side batch
        min_batch = Lambda(lambda x: x[:self.neb])(batch_data)
        ## extract majority batch: rows from index gen onwards.
        ## NOTE(review): this assumes the minority part of the input has
        ## exactly self.gen rows (as produced in _rough_learning); if the
        ## input were neb minority + gen majority instead, this index would
        ## be wrong — confirm against callers.
        maj_batch = Lambda(lambda x: x[self.gen:])(batch_data)
        ## pass minority batch into generator to obtain convex space transformation
        ## (synthetic samples) of the minority neighbourhood input batch
        conv_samples = generator(min_batch)
        ## concatenate the synthetic samples with the majority samples
        new_samples = tf.concat([conv_samples, maj_batch],axis=0)
        ## pass the concatenated vector into the discriminator to know its decisions
        output = discriminator(new_samples)
        ## note that the discriminator will not be trained but will make decisions based
        ## on its previous training while using this function
        model = Model(inputs=batch_data, outputs=output)
        opt = Adam(learning_rate=0.0001)
        model.compile(loss='mse', optimizer=opt)
        return model

    # Create synthetic points
    def _generate_data_for_min_point(self, index, synth_num):
        """
        Generate synth_num synthetic points for a particular minority sample.
        synth_num -> required number of data points that can be generated from a neighbourhood
        index -> index of the minority sample in the training data whose neighbourhood we want to use
        """
        ## each generator call yields one batch; enough runs to cover synth_num
        runs = int(synth_num / self.neb) + 1
        synth_set = []
        for _run in range(runs):
            ## NOTE(review): the same neighbourhood batch is fed every run and
            ## predict() is deterministic, so repeated runs presumably yield
            ## identical synthetic batches — confirm this is intended.
            batch = self.nmbMin.getNbhPointsOfItem(index)
            synth_batch = self.conv_sample_generator.predict(batch, batch_size=self.neb)
            synth_set.extend(synth_batch)
        ## trim the over-generated samples down to exactly synth_num
        return synth_set[:synth_num]

    # Training
    def _rough_learning(self, data_min, data_maj, discTrainCount):
        """
        First-phase ("rough") adversarial learning.
        data_min -> minority class data
        data_maj -> majority class data
        discTrainCount -> number of extra discriminator-only passes per epoch
        """
        generator = self.conv_sample_generator
        discriminator = self.maj_min_discriminator
        convGeN = self.cg
        loss_history = []  ## this is for storing the loss for every run
        step = 0
        minSetSize = len(data_min)
        ## one-hot labels: first gen rows (synthetic) -> [1,0], last gen rows
        ## (majority) -> [0,1]; reused for every batch
        labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
        nLabels = 2 * self.gen
        for neb_epoch_count in range(self.neb_epochs):
            if discTrainCount > 0:
                ## extra discriminator-only training passes
                for n in range(discTrainCount):
                    for min_idx in range(minSetSize):
                        ## generate minority neighbourhood batch for every minority class sample by index
                        min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
                        min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
                        ## generate random/proximal majority batch
                        maj_batch = self._BMB(data_maj, min_batch_indices)
                        ## generate synthetic samples from convex space
                        ## of minority neighbourhood batch using generator
                        conv_samples = generator.predict(min_batch, batch_size=self.neb)
                        ## concatenate them with the majority batch
                        concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
                        ## switch on discriminator training
                        discriminator.trainable = True
                        ## train the discriminator with the concatenated samples and the one-hot encoded labels
                        discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=20)
                        ## switch off the discriminator training again
                        discriminator.trainable = False
            for min_idx in range(minSetSize):
                ## generate minority neighbourhood batch for every minority class sample by index
                min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
                min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
                ## generate random/proximal majority batch
                maj_batch = self._BMB(data_maj, min_batch_indices)
                ## generate synthetic samples from convex space
                ## of minority neighbourhood batch using generator
                conv_samples = generator.predict(min_batch, batch_size=self.neb)
                ## concatenate them with the majority batch
                concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
                ## switch on discriminator training
                discriminator.trainable = True
                ## train the discriminator with the concatenated samples and the one-hot encoded labels
                discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=20)
                ## switch off the discriminator training again
                discriminator.trainable = False
                ## use the complete network to make the generator learn on the decisions
                ## made by the previous discriminator training
                gen_loss_history = convGeN.fit(concat_sample, y=labels, verbose=0, batch_size=nLabels)
                ## store the loss for the step
                loss_history.append(gen_loss_history.history['loss'])
                step += 1
                if self.debug and (step % 10 == 0):
                    print(f"{step} neighbourhood batches trained; running neighbourhood epoch {neb_epoch_count}")
            if self.debug:
                print(f"Neighbourhood epoch {neb_epoch_count + 1} complete")
        if self.debug:
            ## plot the generator loss over all training steps
            run_range = range(1, len(loss_history) + 1)
            plt.rcParams["figure.figsize"] = (16,10)
            plt.xticks(fontsize=20)
            plt.yticks(fontsize=20)
            plt.xlabel('runs', fontsize=25)
            plt.ylabel('loss', fontsize=25)
            plt.title('Rough learning loss for discriminator', fontsize=25)
            plt.plot(run_range, loss_history)
            plt.show()
        ## write the trained models and the loss history back to the instance
        self.conv_sample_generator = generator
        self.maj_min_discriminator = discriminator
        self.cg = convGeN
        self.loss_history = loss_history

    def _BMB(self, data_maj, min_idxs):
        """
        Generate a borderline majority batch of self.gen points.
        data_maj -> majority class data
        min_idxs -> indices of points in the minority class
        """
        if self.nmbMaj is not None:
            ## proximal mode: majority points neighbouring the given minority points
            return self.nmbMaj.neighbourhoodOfItemList(shuffle(min_idxs), maxCount=self.gen)
        else:
            ## random mode: uniform sample (with replacement) from the majority class
            return tf.convert_to_tensor(data_maj[np.random.randint(len(data_maj), size=self.gen)])

    def retrainDiscriminitor(self, data, labels):
        """
        Second-phase training: retrain the discriminator on a (balanced) dataset.
        *data* is a numpy array of shape (n, n_feat).
        *labels* is a sequence of 0/1 values; label x is one-hot coded as [x, 1 - x].
        """
        self.maj_min_discriminator.trainable = True
        labels = np.array([ [x, 1 - x] for x in labels])
        self.maj_min_discriminator.fit(x=data, y=labels, batch_size=20, epochs=self.neb_epochs)
        self.maj_min_discriminator.trainable = False