convGAN2.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. import numpy as np
  2. from numpy.random import seed
  3. import pandas as pd
  4. import matplotlib.pyplot as plt
  5. from library.interfaces import GanBaseClass
  6. from library.dataset import DataSet
  7. from sklearn.decomposition import PCA
  8. from sklearn.metrics import confusion_matrix
  9. from sklearn.metrics import f1_score
  10. from sklearn.metrics import cohen_kappa_score
  11. from sklearn.metrics import precision_score
  12. from sklearn.metrics import recall_score
  13. from sklearn.neighbors import NearestNeighbors
  14. from sklearn.utils import shuffle
  15. from imblearn.datasets import fetch_datasets
  16. from keras.layers import Dense, Input, Multiply, Flatten, Conv1D, Reshape
  17. from keras.models import Model
  18. from keras import backend as K
  19. from tqdm import tqdm
  20. import tensorflow as tf
  21. from tensorflow.keras.optimizers import Adam
  22. from tensorflow.keras.layers import Lambda
  23. import time
  24. from library.NNSearch import NNSearch
  25. from library.timing import timing
import warnings
## Suppress all warnings process-wide for the rest of this module's lifetime.
warnings.filterwarnings("ignore")
  28. def repeat(x, times):
  29. return [x for _i in range(times)]
  30. def create01Labels(totalSize, sizeFirstHalf):
  31. labels = repeat(np.array([1,0]), sizeFirstHalf)
  32. labels.extend(repeat(np.array([0,1]), totalSize - sizeFirstHalf))
  33. return np.array(labels)
  34. class ConvGAN2(GanBaseClass):
  35. """
  36. This is a toy example of a GAN.
  37. It repeats the first point of the training-data-set.
  38. """
  39. def __init__(self, n_feat, neb=5, gen=5, neb_epochs=10, debug=True):
  40. self.isTrained = False
  41. self.n_feat = n_feat
  42. self.neb = neb
  43. self.gen = gen
  44. self.neb_epochs = 10
  45. self.loss_history = None
  46. self.debug = debug
  47. self.dataSet = None
  48. self.conv_sample_generator = None
  49. self.maj_min_discriminator = None
  50. self.cg = None
  51. self.tNbhFit = 0.0
  52. self.tNbhSearch = 0.0
  53. self.nNbhFit = 0
  54. self.nNbhSearch = 0
  55. self.timing = { name: timing(name) for name in ["reset", "train", "create points", "NMB", "BMB", "_generate_data_for_min_point","predict"]}
  56. if neb > gen:
  57. raise ValueError(f"Expected neb <= gen but got neb={neb} and gen={gen}.")
  58. def reset(self):
  59. """
  60. Resets the trained GAN to an random state.
  61. """
  62. self.timing["reset"].start()
  63. self.isTrained = False
  64. ## instanciate generator network and visualize architecture
  65. self.conv_sample_generator = self._conv_sample_gen()
  66. ## instanciate discriminator network and visualize architecture
  67. self.maj_min_discriminator = self._maj_min_disc()
  68. ## instanciate network and visualize architecture
  69. self.cg = self._convGAN(self.conv_sample_generator, self.maj_min_discriminator)
  70. self.timing["reset"].stop()
  71. if self.debug:
  72. print(self.conv_sample_generator.summary())
  73. print('\n')
  74. print(self.maj_min_discriminator.summary())
  75. print('\n')
  76. print(self.cg.summary())
  77. print('\n')
  78. def train(self, dataSet):
  79. """
  80. Trains the GAN.
  81. It stores the data points in the training data set and mark as trained.
  82. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  83. We are only interested in the first *maxListSize* points in class 1.
  84. """
  85. self.timing["train"].start()
  86. if dataSet.data1.shape[0] <= 0:
  87. raise AttributeError("Train: Expected data class 1 to contain at least one point.")
  88. self.dataSet = dataSet
  89. self._rough_learning(dataSet.data1, dataSet.data0)
  90. self.isTrained = True
  91. self.timing["train"].stop()
  92. def generateDataPoint(self):
  93. """
  94. Returns one synthetic data point by repeating the stored list.
  95. """
  96. return (self.generateData(1))[0]
  97. def generateData(self, numOfSamples=1):
  98. """
  99. Generates a list of synthetic data-points.
  100. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  101. """
  102. self.timing["create points"].start()
  103. if not self.isTrained:
  104. raise ValueError("Try to generate data with untrained Re.")
  105. data_min = self.dataSet.data1
  106. ## roughly claculate the upper bound of the synthetic samples to be generated from each neighbourhood
  107. synth_num = (numOfSamples // len(data_min)) + 1
  108. ## generate synth_num synthetic samples from each minority neighbourhood
  109. synth_set=[]
  110. nmb = self._NMB_prepare(data_min)
  111. for i in range(len(data_min)):
  112. synth_set.extend(self._generate_data_for_min_point(nmb, i, synth_num))
  113. synth_set = np.array(synth_set[:numOfSamples]) ## extract the exact number of synthetic samples needed to exactly balance the two classes
  114. self.timing["create points"].stop()
  115. return synth_set
  116. # ###############################################################
  117. # Hidden internal functions
  118. # ###############################################################
  119. # Creating the GAN
  120. def _conv_sample_gen(self):
  121. """
  122. the generator network to generate synthetic samples from the convex space
  123. of arbitrary minority neighbourhoods
  124. """
  125. ## takes minority batch as input
  126. min_neb_batch = Input(shape=(self.n_feat,))
  127. ## reshaping the 2D tensor to 3D for using 1-D convolution,
  128. ## otherwise 1-D convolution won't work.
  129. x = tf.reshape(min_neb_batch, (1, self.neb, self.n_feat), name=None)
  130. ## using 1-D convolution, feature dimension remains the same
  131. x = Conv1D(self.n_feat, 3, activation='relu')(x)
  132. ## flatten after convolution
  133. x = Flatten()(x)
  134. ## add dense layer to transform the vector to a convenient dimension
  135. x = Dense(self.neb * self.gen, activation='relu')(x)
  136. ## again, witching to 2-D tensor once we have the convenient shape
  137. x = Reshape((self.neb, self.gen))(x)
  138. ## row wise sum
  139. s = K.sum(x, axis=1)
  140. ## adding a small constant to always ensure the row sums are non zero.
  141. ## if this is not done then during initialization the sum can be zero.
  142. s_non_zero = Lambda(lambda x: x + .000001)(s)
  143. ## reprocals of the approximated row sum
  144. sinv = tf.math.reciprocal(s_non_zero)
  145. ## At this step we ensure that row sum is 1 for every row in x.
  146. ## That means, each row is set of convex co-efficient
  147. x = Multiply()([sinv, x])
  148. ## Now we transpose the matrix. So each column is now a set of convex coefficients
  149. aff=tf.transpose(x[0])
  150. ## We now do matrix multiplication of the affine combinations with the original
  151. ## minority batch taken as input. This generates a convex transformation
  152. ## of the input minority batch
  153. synth=tf.matmul(aff, min_neb_batch)
  154. ## finally we compile the generator with an arbitrary minortiy neighbourhood batch
  155. ## as input and a covex space transformation of the same number of samples as output
  156. model = Model(inputs=min_neb_batch, outputs=synth)
  157. opt = Adam(learning_rate=0.001)
  158. model.compile(loss='mean_squared_logarithmic_error', optimizer=opt)
  159. return model
  160. def _maj_min_disc(self):
  161. """
  162. the discriminator is trained intwo phase:
  163. first phase: while training GAN the discriminator learns to differentiate synthetic
  164. minority samples generated from convex minority data space against
  165. the borderline majority samples
  166. second phase: after the GAN generator learns to create synthetic samples,
  167. it can be used to generate synthetic samples to balance the dataset
  168. and then rettrain the discriminator with the balanced dataset
  169. """
  170. ## takes as input synthetic sample generated as input stacked upon a batch of
  171. ## borderline majority samples
  172. samples = Input(shape=(self.n_feat,))
  173. ## passed through two dense layers
  174. y = Dense(250, activation='relu')(samples)
  175. y = Dense(125, activation='relu')(y)
  176. ## two output nodes. outputs have to be one-hot coded (see labels variable before)
  177. output = Dense(2, activation='sigmoid')(y)
  178. ## compile model
  179. model = Model(inputs=samples, outputs=output)
  180. opt = Adam(learning_rate=0.0001)
  181. model.compile(loss='binary_crossentropy', optimizer=opt)
  182. return model
  183. def _convGAN(self, generator, discriminator):
  184. """
  185. for joining the generator and the discriminator
  186. conv_coeff_generator-> generator network instance
  187. maj_min_discriminator -> discriminator network instance
  188. """
  189. ## by default the discriminator trainability is switched off.
  190. ## Thus training the GAN means training the generator network as per previously
  191. ## trained discriminator network.
  192. discriminator.trainable = False
  193. ## input receives a neighbourhood minority batch
  194. ## and a proximal majority batch concatenated
  195. batch_data = Input(shape=(self.n_feat,))
  196. ##- print(f"GAN: 0..{self.neb}/{self.gen}..")
  197. ## extract minority batch
  198. min_batch = Lambda(lambda x: x[:self.neb])(batch_data)
  199. ## extract majority batch
  200. maj_batch = Lambda(lambda x: x[self.gen:])(batch_data)
  201. ## pass minority batch into generator to obtain convex space transformation
  202. ## (synthetic samples) of the minority neighbourhood input batch
  203. conv_samples = generator(min_batch)
  204. ## concatenate the synthetic samples with the majority samples
  205. new_samples = tf.concat([conv_samples, maj_batch],axis=0)
  206. ##- new_samples = tf.concat([conv_samples, conv_samples, conv_samples, conv_samples],axis=0)
  207. ## pass the concatenated vector into the discriminator to know its decisions
  208. output = discriminator(new_samples)
  209. ##- output = Lambda(lambda x: x[:2 * self.gen])(output)
  210. ## note that, the discriminator will not be traied but will make decisions based
  211. ## on its previous training while using this function
  212. model = Model(inputs=batch_data, outputs=output)
  213. opt = Adam(learning_rate=0.0001)
  214. model.compile(loss='mse', optimizer=opt)
  215. return model
  216. # Create synthetic points
  217. def _generate_data_for_min_point(self, nmb, index, synth_num):
  218. """
  219. generate synth_num synthetic points for a particular minoity sample
  220. synth_num -> required number of data points that can be generated from a neighbourhood
  221. data_min -> minority class data
  222. neb -> oversampling neighbourhood
  223. index -> index of the minority sample in a training data whose neighbourhood we want to obtain
  224. """
  225. self.timing["_generate_data_for_min_point"].start()
  226. runs = int(synth_num / self.neb) + 1
  227. synth_set = []
  228. for _run in range(runs):
  229. batch = self._NMB_guided(nmb, index)
  230. self.timing["predict"].start()
  231. synth_batch = self.conv_sample_generator.predict(batch)
  232. self.timing["predict"].stop()
  233. synth_set.extend(synth_batch)
  234. #for x in synth_batch:
  235. # synth_set.append(x)
  236. self.timing["_generate_data_for_min_point"].stop()
  237. return synth_set[:synth_num]
  238. # Training
  239. def _rough_learning(self, data_min, data_maj):
  240. generator = self.conv_sample_generator
  241. discriminator = self.maj_min_discriminator
  242. GAN = self.cg
  243. loss_history = [] ## this is for stroring the loss for every run
  244. min_idx = 0
  245. neb_epoch_count = 1
  246. labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
  247. nmb = self._NMB_prepare(data_min)
  248. for step in range(self.neb_epochs * len(data_min)):
  249. ## generate minority neighbourhood batch for every minority class sampls by index
  250. min_batch = self._NMB_guided(nmb, min_idx)
  251. min_idx = min_idx + 1
  252. ## generate random proximal majority batch
  253. maj_batch = self._BMB(data_min, data_maj)
  254. ## generate synthetic samples from convex space
  255. ## of minority neighbourhood batch using generator
  256. conv_samples = generator.predict(min_batch)
  257. ## concatenate them with the majority batch
  258. concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
  259. ## switch on discriminator training
  260. discriminator.trainable = True
  261. ## train the discriminator with the concatenated samples and the one-hot encoded labels
  262. discriminator.fit(x=concat_sample, y=labels, verbose=0)
  263. ## switch off the discriminator training again
  264. discriminator.trainable = False
  265. ## use the GAN to make the generator learn on the decisions
  266. ## made by the previous discriminator training
  267. ##- print(f"concat sample shape: {concat_sample.shape}/{labels.shape}")
  268. gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0)
  269. ## store the loss for the step
  270. loss_history.append(gan_loss_history.history['loss'])
  271. if self.debug and ((step + 1) % 10 == 0):
  272. print(f"{step + 1} neighbourhood batches trained; running neighbourhood epoch {neb_epoch_count}")
  273. if min_idx == len(data_min) - 1:
  274. if self.debug:
  275. print(f"Neighbourhood epoch {neb_epoch_count} complete")
  276. neb_epoch_count = neb_epoch_count + 1
  277. min_idx = 0
  278. if self.debug:
  279. run_range = range(1, len(loss_history) + 1)
  280. plt.rcParams["figure.figsize"] = (16,10)
  281. plt.xticks(fontsize=20)
  282. plt.yticks(fontsize=20)
  283. plt.xlabel('runs', fontsize=25)
  284. plt.ylabel('loss', fontsize=25)
  285. plt.title('Rough learning loss for discriminator', fontsize=25)
  286. plt.plot(run_range, loss_history)
  287. plt.show()
  288. self.conv_sample_generator = generator
  289. self.maj_min_discriminator = discriminator
  290. self.cg = GAN
  291. self.loss_history = loss_history
  292. ## convGAN
  293. def _BMB(self, data_min, data_maj):
  294. ## Generate a borderline majority batch
  295. ## data_min -> minority class data
  296. ## data_maj -> majority class data
  297. ## neb -> oversampling neighbourhood
  298. ## gen -> convex combinations generated from each neighbourhood
  299. self.timing["BMB"].start()
  300. result = tf.convert_to_tensor(
  301. data_maj[np.random.randint(len(data_maj), size=self.gen)]
  302. )
  303. self.timing["BMB"].stop()
  304. return result
  305. def _NMB_prepare(self, data_min):
  306. self.timing["NMB"].start()
  307. t = time.time()
  308. neigh = NNSearch(self.neb)
  309. #neigh = NearestNeighbors(self.neb)
  310. neigh.fit(data_min)
  311. self.tNbhFit += (time.time() - t)
  312. self.nNbhFit += 1
  313. self.timing["NMB"].stop()
  314. return (data_min, neigh)
  315. def _NMB_guided(self, nmb, index):
  316. ## generate a minority neighbourhood batch for a particular minority sample
  317. ## we need this for minority data generation
  318. ## we will generate synthetic samples for each training data neighbourhood
  319. ## index -> index of the minority sample in a training data whose neighbourhood we want to obtain
  320. ## data_min -> minority class data
  321. ## neb -> oversampling neighbourhood
  322. self.timing["NMB"].start()
  323. (data_min, neigh) = nmb
  324. t = time.time()
  325. #nmbi = neigh.kneighbors([data_min[index]], self.neb, return_distance=False)
  326. nmbi = np.array([neigh.neighbourhoodOfItem(index)])
  327. self.tNbhSearch += (time.time() - t)
  328. self.nNbhSearch += 1
  329. nmbi = shuffle(nmbi)
  330. nmb = data_min[nmbi]
  331. nmb = tf.convert_to_tensor(nmb[0])
  332. self.timing["NMB"].stop()
  333. return nmb