# runConvGanTest.py
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from numpy.random import seed
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import shuffle
from imblearn.datasets import fetch_datasets
from keras.layers import Dense, Input, Multiply, Flatten, Conv1D, Reshape
from keras.models import Model
from keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Lambda
from tqdm import tqdm

from library.interfaces import GanBaseClass
from library.dataset import DataSet
from library.convGAN import ConvGAN, create01Labels

warnings.filterwarnings("ignore")
## This is the main training process where the GAN learns to generate
## appropriate samples from the convex space. It is the first training phase
## for the discriminator and the only training phase for the generator.
  28. def rough_learning_predictions(discriminator,test_data_numpy,test_labels_numpy):
  29. """
  30. after the first phase of training the discriminator can be used for classification
  31. it already learns to differentiate the convex minority points with majority points
  32. during the first training phase
  33. """
  34. y_pred_2d = discriminator.predict(tf.convert_to_tensor(test_data_numpy))
  35. ## discretisation of the labels
  36. y_pred = np.digitize(y_pred_2d[:,0], [.5])
  37. ## prediction shows a model with good recall and less precision
  38. c = confusion_matrix(test_labels_numpy, y_pred)
  39. f = f1_score(test_labels_numpy, y_pred)
  40. pr = precision_score(test_labels_numpy, y_pred)
  41. rc = recall_score(test_labels_numpy, y_pred)
  42. k = cohen_kappa_score(test_labels_numpy, y_pred)
  43. print('Rough learning confusion matrix:', c)
  44. print('Rough learning f1 score', f)
  45. print('Rough learning precision score', pr)
  46. print('Rough learning recall score', rc)
  47. print('Rough learning kappa score', k)
  48. return c,f,pr,rc,k
  49. def generate_synthetic_data(gan, data_min, data_maj):
  50. ## roughly claculate the upper bound of the synthetic samples
  51. ## to be generated from each neighbourhood
  52. synth_num = ((len(data_maj) - len(data_min)) // len(data_min)) + 1
  53. ## generate synth_num synthetic samples from each minority neighbourhood
  54. synth_set = gan.generateData(synth_num)
  55. ovs_min_class = np.concatenate((data_min,synth_set), axis=0)
  56. ovs_training_dataset = np.concatenate((ovs_min_class,data_maj), axis=0)
  57. ovs_pca_labels = np.concatenate((
  58. np.zeros(len(data_min)),
  59. np.zeros(len(synth_set)) + 1,
  60. np.zeros(len(data_maj)) + 2
  61. ))
  62. ovs_training_labels_oh = create01Labels(len(ovs_training_dataset), len(ovs_min_class))
  63. ovs_training_labels_oh = tf.convert_to_tensor(ovs_training_labels_oh)
  64. ## PCA visualization of the synthetic sata
  65. ## observe how the minority samples from convex space have optimal variance
  66. ## and avoids overlap with the majority
  67. pca = PCA(n_components=2)
  68. pca.fit(ovs_training_dataset)
  69. data_pca = pca.transform(ovs_training_dataset)
  70. ## plot PCA
  71. plt.rcParams["figure.figsize"] = (12,12)
  72. plt.xticks(fontsize=20)
  73. plt.yticks(fontsize=20)
  74. plt.xlabel('PCA1',fontsize=25)
  75. plt.ylabel('PCA2', fontsize=25)
  76. plt.title('PCA plot of oversampled data',fontsize=25)
  77. classes = ['minority', 'synthetic minority', 'majority']
  78. scatter=plt.scatter(data_pca[:,0], data_pca[:,1], c=ovs_pca_labels, cmap='Set1')
  79. plt.legend(handles=scatter.legend_elements()[0], labels=classes, fontsize=20)
  80. plt.show()
  81. return ovs_training_dataset, ovs_pca_labels, ovs_training_labels_oh
  82. def final_learning(discriminator, ovs_training_dataset, ovs_training_labels_oh, test_data_numpy, test_labels_numpy, num_epochs):
  83. print('\n')
  84. print('Final round training of the discrminator as a majority-minority classifier')
  85. print('\n')
  86. ## second phase training of the discriminator with balanced data
  87. history_second_learning = discriminator.fit(x=ovs_training_dataset, y=ovs_training_labels_oh, batch_size=20, epochs=num_epochs)
  88. ## loss of the second phase learning smoothly decreses
  89. ## this is because now the data is fixed and diverse convex combinations are no longer fed into the discriminator at every training step
  90. run_range = range(1, num_epochs + 1)
  91. plt.rcParams["figure.figsize"] = (16,10)
  92. plt.xticks(fontsize=20)
  93. plt.yticks(fontsize=20)
  94. plt.xlabel('runs',fontsize=25)
  95. plt.ylabel('loss', fontsize=25)
  96. plt.title('Final learning loss for discriminator', fontsize=25)
  97. plt.plot(run_range, history_second_learning.history['loss'])
  98. plt.show()
  99. ## finally after second phase training the discriminator classifier has a more balanced performance
  100. ## meaning better F1-Score
  101. ## the recall decreases but the precision improves
  102. print('\n')
  103. y_pred_2d = discriminator.predict(tf.convert_to_tensor(test_data_numpy))
  104. y_pred = np.digitize(y_pred_2d[:,0], [.5])
  105. c = confusion_matrix(test_labels_numpy, y_pred)
  106. f = f1_score(test_labels_numpy, y_pred)
  107. pr = precision_score(test_labels_numpy, y_pred)
  108. rc = recall_score(test_labels_numpy, y_pred)
  109. k = cohen_kappa_score(test_labels_numpy, y_pred)
  110. print('Final learning confusion matrix:', c)
  111. print('Final learning f1 score', f)
  112. print('Final learning precision score', pr)
  113. print('Final learning recall score', rc)
  114. print('Final learning kappa score', k)
  115. return c, f, pr, rc, k
  116. def convGAN_train_end_to_end(training_data, training_labels, test_data, test_labels, neb, gen, neb_epochs, epochs_retrain_disc):
  117. ##minority class
  118. data_min=training_data[np.where(training_labels == 1)[0]]
  119. ##majority class
  120. data_maj=training_data[np.where(training_labels == 0)[0]]
  121. dataSet = DataSet(data0=data_maj, data1=data_min)
  122. gan = ConvGAN(data_min.shape[1], neb, gen)
  123. gan.reset()
  124. ## instanciate generator network and visualize architecture
  125. conv_sample_generator = gan.conv_sample_generator
  126. print(conv_sample_generator.summary())
  127. print('\n')
  128. ## instanciate discriminator network and visualize architecture
  129. maj_min_discriminator = gan.maj_min_discriminator
  130. print(maj_min_discriminator.summary())
  131. print('\n')
  132. ## instanciate network and visualize architecture
  133. cg = gan.cg
  134. print(cg.summary())
  135. print('\n')
  136. print('Training the GAN, first round training of the discrminator as a majority-minority classifier')
  137. print('\n')
  138. ## train gan generator ## rough_train_discriminator
  139. gan.train(dataSet, neb_epochs)
  140. print('\n')
  141. ## rough learning results
  142. c_r,f_r,pr_r,rc_r,k_r = rough_learning_predictions(gan.maj_min_discriminator_r, test_data, test_labels)
  143. print('\n')
  144. ## generate synthetic data
  145. ovs_training_dataset, ovs_pca_labels, ovs_training_labels_oh = generate_synthetic_data(gan, data_min, data_maj)
  146. print('\n')
  147. ## final training results
  148. c,f,pr,rc,k = final_learning(gan.maj_min_discriminator, ovs_training_dataset, ovs_training_labels_oh, test_data, test_labels, epochs_retrain_disc)
  149. return ((c_r,f_r,pr_r,rc_r,k_r),(c,f,pr,rc,k))
  150. def unison_shuffled_copies(a, b,seed_perm):
  151. 'Shuffling the feature matrix along with the labels with same order'
  152. np.random.seed(seed_perm)##change seed 1,2,3,4,5
  153. assert len(a) == len(b)
  154. p = np.random.permutation(len(a))
  155. return a[p], b[p]
  156. def runTest():
  157. seed_num=1
  158. seed(seed_num)
  159. tf.random.set_seed(seed_num)
  160. ## Import dataset
  161. data = fetch_datasets()['yeast_me2']
  162. ## Creating label and feature matrices
  163. labels_x = data.target ## labels of the data
  164. features_x = data.data ## features of the data
  165. # Until now we have obtained the data. We divided it into training and test sets. we separated obtained seperate variables for the majority and miority classes and their labels for both sets.
  166. ## specify parameters
  167. neb=gen=5 ##neb=gen required
  168. neb_epochs=10
  169. epochs_retrain_disc=50
  170. ## Training
  171. np.random.seed(42)
  172. strata=5
  173. results=[]
  174. for seed_perm in range(strata):
  175. features_x,labels_x=unison_shuffled_copies(features_x,labels_x,seed_perm)
  176. ### Extracting all features and labels
  177. print('Extracting all features and labels for seed:'+ str(seed_perm)+'\n')
  178. ## Dividing data into training and testing datasets for 10-fold CV
  179. print('Dividing data into training and testing datasets for 10-fold CV for seed:'+ str(seed_perm)+'\n')
  180. label_1=list(np.where(labels_x == 1)[0])
  181. features_1=features_x[label_1]
  182. label_0=list(np.where(labels_x != 1)[0])
  183. features_0=features_x[label_0]
  184. a=len(features_1)//5
  185. b=len(features_0)//5
  186. fold_1_min=features_1[0:a]
  187. fold_1_maj=features_0[0:b]
  188. fold_1_tst=np.concatenate((fold_1_min,fold_1_maj))
  189. lab_1_tst=np.concatenate((np.zeros(len(fold_1_min))+1, np.zeros(len(fold_1_maj))))
  190. fold_2_min=features_1[a:2*a]
  191. fold_2_maj=features_0[b:2*b]
  192. fold_2_tst=np.concatenate((fold_2_min,fold_2_maj))
  193. lab_2_tst=np.concatenate((np.zeros(len(fold_1_min))+1, np.zeros(len(fold_1_maj))))
  194. fold_3_min=features_1[2*a:3*a]
  195. fold_3_maj=features_0[2*b:3*b]
  196. fold_3_tst=np.concatenate((fold_3_min,fold_3_maj))
  197. lab_3_tst=np.concatenate((np.zeros(len(fold_1_min))+1, np.zeros(len(fold_1_maj))))
  198. fold_4_min=features_1[3*a:4*a]
  199. fold_4_maj=features_0[3*b:4*b]
  200. fold_4_tst=np.concatenate((fold_4_min,fold_4_maj))
  201. lab_4_tst=np.concatenate((np.zeros(len(fold_1_min))+1, np.zeros(len(fold_1_maj))))
  202. fold_5_min=features_1[4*a:]
  203. fold_5_maj=features_0[4*b:]
  204. fold_5_tst=np.concatenate((fold_5_min,fold_5_maj))
  205. lab_5_tst=np.concatenate((np.zeros(len(fold_5_min))+1, np.zeros(len(fold_5_maj))))
  206. fold_1_trn=np.concatenate((fold_2_min,fold_3_min,fold_4_min,fold_5_min, fold_2_maj,fold_3_maj,fold_4_maj,fold_5_maj))
  207. lab_1_trn=np.concatenate((np.zeros(3*a+len(fold_5_min))+1,np.zeros(3*b+len(fold_5_maj))))
  208. fold_2_trn=np.concatenate((fold_1_min,fold_3_min,fold_4_min,fold_5_min,fold_1_maj,fold_3_maj,fold_4_maj,fold_5_maj))
  209. lab_2_trn=np.concatenate((np.zeros(3*a+len(fold_5_min))+1,np.zeros(3*b+len(fold_5_maj))))
  210. fold_3_trn=np.concatenate((fold_2_min,fold_1_min,fold_4_min,fold_5_min,fold_2_maj,fold_1_maj,fold_4_maj,fold_5_maj))
  211. lab_3_trn=np.concatenate((np.zeros(3*a+len(fold_5_min))+1,np.zeros(3*b+len(fold_5_maj))))
  212. fold_4_trn=np.concatenate((fold_2_min,fold_3_min,fold_1_min,fold_5_min,fold_2_maj,fold_3_maj,fold_1_maj,fold_5_maj))
  213. lab_4_trn=np.concatenate((np.zeros(3*a+len(fold_5_min))+1,np.zeros(3*b+len(fold_5_maj))))
  214. fold_5_trn=np.concatenate((fold_2_min,fold_3_min,fold_4_min,fold_1_min,fold_2_maj,fold_3_maj,fold_4_maj,fold_1_maj))
  215. lab_5_trn=np.concatenate((np.zeros(4*a)+1,np.zeros(4*b)))
  216. training_folds_feats=[fold_1_trn,fold_2_trn,fold_3_trn,fold_4_trn,fold_5_trn]
  217. testing_folds_feats=[fold_1_tst,fold_2_tst,fold_3_tst,fold_4_tst,fold_5_tst]
  218. training_folds_labels=[lab_1_trn,lab_2_trn,lab_3_trn,lab_4_trn,lab_5_trn]
  219. testing_folds_labels=[lab_1_tst,lab_2_tst,lab_3_tst,lab_4_tst,lab_5_tst]
  220. for i in range(5):
  221. print('\n')
  222. print('Executing fold: '+str(i+1))
  223. print('\n')
  224. r1,r2=convGAN_train_end_to_end(training_folds_feats[i],training_folds_labels[i],testing_folds_feats[i],testing_folds_labels[i], neb, gen, neb_epochs, epochs_retrain_disc)
  225. results.append(np.array([list(r1[1:]),list(r2[1:])]))
  226. results=np.array(results)
  227. ## Benchmark
  228. mean_rough=np.mean(results[:,0], axis=0)
  229. data_r={'F1-Score_r':[mean_rough[0]], 'Precision_r' : [mean_rough[1]], 'Recall_r' : [mean_rough[2]], 'Kappa_r': [mean_rough[3]]}
  230. df_r=pd.DataFrame(data=data_r)
  231. print('Rough training results:')
  232. print('\n')
  233. print(df_r)
  234. mean_final=np.mean(results[:,1], axis=0)
  235. data_f={'F1-Score_f':[mean_final[0]], 'Precision_f' : [mean_final[1]], 'Recall_f' : [mean_final[2]], 'Kappa_f': [mean_final[3]]}
  236. df_f=pd.DataFrame(data=data_f)
  237. print('Final training results:')
  238. print('\n')
  239. print(df_f)
## script entry point: run the full 5-strata x 5-fold benchmark
if __name__ == "__main__":
    runTest()