# convGAN.py
# Convex-space GAN oversampling for imbalanced binary classification.
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from imblearn.datasets import fetch_datasets
from keras import backend as K
from keras.layers import Conv1D, Dense, Flatten, Input, Multiply, Reshape
from keras.models import Model
from numpy.random import seed
from sklearn.decomposition import PCA
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import shuffle
from tensorflow.keras.layers import Lambda
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm

from library.dataset import DataSet
from library.interfaces import GanBaseClass

warnings.filterwarnings("ignore")
  25. class ConvGAN(GanBaseClass):
  26. """
  27. This is a toy example of a GAN.
  28. It repeats the first point of the training-data-set.
  29. """
  30. def __init__(self, n_feat, neb, gen, debug=True):
  31. self.isTrained = False
  32. self.n_feat = n_feat
  33. self.neb = neb
  34. self.gen = gen
  35. self.loss_history = None
  36. self.debug = debug
  37. self.dataSet = None
  38. self.conv_sample_generator = None
  39. self.maj_min_discriminator = None
  40. self.cg = None
  41. def reset(self):
  42. """
  43. Resets the trained GAN to an random state.
  44. """
  45. self.isTrained = False
  46. ## instanciate generator network and visualize architecture
  47. self.conv_sample_generator = self._conv_sample_gen()
  48. ## instanciate discriminator network and visualize architecture
  49. self.maj_min_discriminator = self._maj_min_disc()
  50. ## instanciate network and visualize architecture
  51. self.cg = self._convGAN(self.conv_sample_generator, self.maj_min_discriminator)
  52. def train(self, dataSet, neb_epochs=5):
  53. """
  54. Trains the GAN.
  55. It stores the data points in the training data set and mark as trained.
  56. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  57. We are only interested in the first *maxListSize* points in class 1.
  58. """
  59. if dataSet.data1.shape[0] <= 0:
  60. raise AttributeError("Train: Expected data class 1 to contain at least one point.")
  61. self.dataSet = dataSet
  62. self._rough_learning(neb_epochs, dataSet.data1, dataSet.data0)
  63. self.isTrained = True
  64. def generateDataPoint(self):
  65. """
  66. Returns one synthetic data point by repeating the stored list.
  67. """
  68. return (self.generateData(1))[0]
  69. def generateData(self, numOfSamples=1):
  70. """
  71. Generates a list of synthetic data-points.
  72. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  73. """
  74. if not self.isTrained:
  75. raise ValueError("Try to generate data with untrained Re.")
  76. data_min = self.dataSet.data1
  77. ## roughly claculate the upper bound of the synthetic samples to be generated from each neighbourhood
  78. synth_num = (numOfSamples // len(data_min)) + 1
  79. ## generate synth_num synthetic samples from each minority neighbourhood
  80. synth_set=[]
  81. for i in range(len(data_min)):
  82. synth_set.extend(self.generate_data_for_min_point(data_min, i, synth_num))
  83. synth_set = synth_set[:numOfSamples] ## extract the exact number of synthetic samples needed to exactly balance the two classes
  84. return np.array(synth_set)
  85. # ###############################################################
  86. # Hidden internal functions
  87. # ###############################################################
  88. # Creating the GAN
  89. def _conv_sample_gen(self):
  90. """
  91. the generator network to generate synthetic samples from the convex space
  92. of arbitrary minority neighbourhoods
  93. """
  94. ## takes minority batch as input
  95. min_neb_batch = Input(shape=(self.n_feat,))
  96. ## reshaping the 2D tensor to 3D for using 1-D convolution,
  97. ## otherwise 1-D convolution won't work.
  98. x = tf.reshape(min_neb_batch, (1, self.neb, self.n_feat), name=None)
  99. ## using 1-D convolution, feature dimension remains the same
  100. x = Conv1D(self.n_feat, 3, activation='relu')(x)
  101. ## flatten after convolution
  102. x = Flatten()(x)
  103. ## add dense layer to transform the vector to a convenient dimension
  104. x = Dense(self.neb * self.gen, activation='relu')(x)
  105. ## again, witching to 2-D tensor once we have the convenient shape
  106. x = Reshape((self.neb, self.gen))(x)
  107. ## row wise sum
  108. s = K.sum(x, axis=1)
  109. ## adding a small constant to always ensure the row sums are non zero.
  110. ## if this is not done then during initialization the sum can be zero.
  111. s_non_zero = Lambda(lambda x: x + .000001)(s)
  112. ## reprocals of the approximated row sum
  113. sinv = tf.math.reciprocal(s_non_zero)
  114. ## At this step we ensure that row sum is 1 for every row in x.
  115. ## That means, each row is set of convex co-efficient
  116. x = Multiply()([sinv, x])
  117. ## Now we transpose the matrix. So each column is now a set of convex coefficients
  118. aff=tf.transpose(x[0])
  119. ## We now do matrix multiplication of the affine combinations with the original
  120. ## minority batch taken as input. This generates a convex transformation
  121. ## of the input minority batch
  122. synth=tf.matmul(aff, min_neb_batch)
  123. ## finally we compile the generator with an arbitrary minortiy neighbourhood batch
  124. ## as input and a covex space transformation of the same number of samples as output
  125. model = Model(inputs=min_neb_batch, outputs=synth)
  126. opt = Adam(learning_rate=0.001)
  127. model.compile(loss='mean_squared_logarithmic_error', optimizer=opt)
  128. return model
  129. def _maj_min_disc(self):
  130. """
  131. the discriminator is trained intwo phase:
  132. first phase: while training GAN the discriminator learns to differentiate synthetic
  133. minority samples generated from convex minority data space against
  134. the borderline majority samples
  135. second phase: after the GAN generator learns to create synthetic samples,
  136. it can be used to generate synthetic samples to balance the dataset
  137. and then rettrain the discriminator with the balanced dataset
  138. """
  139. ## takes as input synthetic sample generated as input stacked upon a batch of
  140. ## borderline majority samples
  141. samples = Input(shape=(self.n_feat,))
  142. ## passed through two dense layers
  143. y = Dense(250, activation='relu')(samples)
  144. y = Dense(125, activation='relu')(y)
  145. ## two output nodes. outputs have to be one-hot coded (see labels variable before)
  146. output = Dense(2, activation='sigmoid')(y)
  147. ## compile model
  148. model = Model(inputs=samples, outputs=output)
  149. opt = Adam(learning_rate=0.0001)
  150. model.compile(loss='binary_crossentropy', optimizer=opt)
  151. return model
  152. def _convGAN(self, generator, discriminator):
  153. """
  154. for joining the generator and the discriminator
  155. conv_coeff_generator-> generator network instance
  156. maj_min_discriminator -> discriminator network instance
  157. """
  158. ## by default the discriminator trainability is switched off.
  159. ## Thus training the GAN means training the generator network as per previously
  160. ## trained discriminator network.
  161. discriminator.trainable = False
  162. ## input receives a neighbourhood minority batch
  163. ## and a proximal majority batch concatenated
  164. batch_data = Input(shape=(self.n_feat,))
  165. ## extract minority batch
  166. min_batch = Lambda(lambda x: x[:self.neb])(batch_data)
  167. ## extract majority batch
  168. maj_batch = Lambda(lambda x: x[self.neb:])(batch_data)
  169. ## pass minority batch into generator to obtain convex space transformation
  170. ## (synthetic samples) of the minority neighbourhood input batch
  171. conv_samples = generator(min_batch)
  172. ## concatenate the synthetic samples with the majority samples
  173. new_samples = tf.concat([conv_samples, maj_batch],axis=0)
  174. ## pass the concatenated vector into the discriminator to know its decisions
  175. output = discriminator(new_samples)
  176. ## note that, the discriminator will not be traied but will make decisions based
  177. ## on its previous training while using this function
  178. model = Model(inputs=batch_data, outputs=output)
  179. opt = Adam(learning_rate=0.0001)
  180. model.compile(loss='mse', optimizer=opt)
  181. return model
  182. # Create synthetic points
  183. def _generate_data_for_min_point(self, data_min, index, synth_num):
  184. """
  185. generate synth_num synthetic points for a particular minoity sample
  186. synth_num -> required number of data points that can be generated from a neighbourhood
  187. data_min -> minority class data
  188. neb -> oversampling neighbourhood
  189. index -> index of the minority sample in a training data whose neighbourhood we want to obtain
  190. """
  191. runs = int(synth_num / self.neb) + 1
  192. synth_set = []
  193. for _run in range(runs):
  194. batch = self._NMB_guided(data_min, index)
  195. synth_batch = self.conv_sample_generator.predict(batch)
  196. for x in synth_batch:
  197. synth_set.append(x)
  198. return synth_set[:synth_num]
  199. # Training
  200. def _rough_learning(self, neb_epochs, data_min, data_maj):
  201. generator = self.conv_sample_generator
  202. discriminator = self.maj_min_discriminator
  203. GAN = self.cg
  204. loss_history=[] ## this is for stroring the loss for every run
  205. min_idx = 0
  206. neb_epoch_count = 1
  207. labels = []
  208. for i in range(2 * self.gen):
  209. if i < self.gen:
  210. labels.append(np.array([1,0]))
  211. else:
  212. labels.append(np.array([0,1]))
  213. labels = np.array(labels)
  214. labels = tf.convert_to_tensor(labels)
  215. for step in range(neb_epochs * len(data_min)):
  216. ## generate minority neighbourhood batch for every minority class sampls by index
  217. min_batch = self._NMB_guided(data_min, min_idx)
  218. min_idx = min_idx + 1
  219. ## generate random proximal majority batch
  220. maj_batch = self._BMB(data_min, data_maj)
  221. ## generate synthetic samples from convex space
  222. ## of minority neighbourhood batch using generator
  223. conv_samples = generator.predict(min_batch)
  224. ## concatenate them with the majority batch
  225. concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
  226. ## switch on discriminator training
  227. discriminator.trainable = True
  228. ## train the discriminator with the concatenated samples and the one-hot encoded labels
  229. discriminator.fit(x=concat_sample, y=labels, verbose=0)
  230. ## switch off the discriminator training again
  231. discriminator.trainable = False
  232. ## use the GAN to make the generator learn on the decisions
  233. ## made by the previous discriminator training
  234. gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0)
  235. ## store the loss for the step
  236. loss_history.append(gan_loss_history.history['loss'])
  237. if self.debug and ((step + 1) % 10 == 0):
  238. print(f"{step + 1} neighbourhood batches trained; running neighbourhood epoch {neb_epoch_count}")
  239. if min_idx == len(data_min) - 1:
  240. if self.debug:
  241. print(f"Neighbourhood epoch {neb_epoch_count} complete")
  242. neb_epoch_count = neb_epoch_count + 1
  243. min_idx = 0
  244. if self.debug:
  245. run_range = range(1, len(loss_history) + 1)
  246. plt.rcParams["figure.figsize"] = (16,10)
  247. plt.xticks(fontsize=20)
  248. plt.yticks(fontsize=20)
  249. plt.xlabel('runs', fontsize=25)
  250. plt.ylabel('loss', fontsize=25)
  251. plt.title('Rough learning loss for discriminator', fontsize=25)
  252. plt.plot(run_range, loss_history)
  253. plt.show()
  254. self.conv_sample_generator = generator
  255. self.maj_min_discriminator = discriminator
  256. self.cg = GAN
  257. self.loss_history = loss_history
  258. ## convGAN
  259. def _BMB(self, data_min, data_maj):
  260. ## Generate a borderline majority batch
  261. ## data_min -> minority class data
  262. ## data_maj -> majority class data
  263. ## neb -> oversampling neighbourhood
  264. ## gen -> convex combinations generated from each neighbourhood
  265. neigh = NearestNeighbors(self.neb)
  266. neigh.fit(data_maj)
  267. bmbi = [
  268. neigh.kneighbors([data_min[i]], self.neb, return_distance=False)
  269. for i in range(len(data_min))
  270. ]
  271. bmbi = np.unique(np.array(bmbi).flatten())
  272. bmbi = shuffle(bmbi)
  273. return tf.convert_to_tensor(
  274. data_maj[np.random.randint(len(data_maj), size=self.gen)]
  275. )
  276. def _NMB_guided(self, data_min, index):
  277. ## generate a minority neighbourhood batch for a particular minority sample
  278. ## we need this for minority data generation
  279. ## we will generate synthetic samples for each training data neighbourhood
  280. ## index -> index of the minority sample in a training data whose neighbourhood we want to obtain
  281. ## data_min -> minority class data
  282. ## neb -> oversampling neighbourhood
  283. neigh = NearestNeighbors(self.neb)
  284. neigh.fit(data_min)
  285. nmbi = neigh.kneighbors([data_min[index]], self.neb, return_distance=False)
  286. nmbi = shuffle(nmbi)
  287. nmb = data_min[nmbi]
  288. nmb = tf.convert_to_tensor(nmb[0])
  289. return nmb
## This (the class above) is the main training process, where the GAN learns to generate appropriate samples from the convex space.
## It is the first training phase for the discriminator and the only training phase for the generator.
  292. def rough_learning_predictions(discriminator,test_data_numpy,test_labels_numpy):
  293. ## after the first phase of training the discriminator can be used for classification
  294. ## it already learns to differentiate the convex minority points with majority points during the first training phase
  295. y_pred_2d=discriminator.predict(tf.convert_to_tensor(test_data_numpy))
  296. ## discretisation of the labels
  297. y_pred=np.digitize(y_pred_2d[:,0], [.5])
  298. ## prediction shows a model with good recall and less precision
  299. c=confusion_matrix(test_labels_numpy, y_pred)
  300. f=f1_score(test_labels_numpy, y_pred)
  301. pr=precision_score(test_labels_numpy, y_pred)
  302. rc=recall_score(test_labels_numpy, y_pred)
  303. k=cohen_kappa_score(test_labels_numpy, y_pred)
  304. print('Rough learning confusion matrix:', c)
  305. print('Rough learning f1 score', f)
  306. print('Rough learning precision score', pr)
  307. print('Rough learning recall score', rc)
  308. print('Rough learning kappa score', k)
  309. return c,f,pr,rc,k
  310. def generate_synthetic_data(gan, data_min, data_maj):
  311. ## roughly claculate the upper bound of the synthetic samples to be generated from each neighbourhood
  312. synth_num=((len(data_maj)-len(data_min))//len(data_min))+1
  313. ## generate synth_num synthetic samples from each minority neighbourhood
  314. synth_set = gan.generateData(synth_num)
  315. ovs_min_class=np.concatenate((data_min,synth_set),axis=0)
  316. ovs_training_dataset=np.concatenate((ovs_min_class,data_maj),axis=0)
  317. ovs_pca_labels=np.concatenate((np.zeros(len(data_min)),np.zeros(len(synth_set))+1,np.zeros(len(data_maj))+2))
  318. # TODO ovs_training_labels=np.concatenate((np.zeros(len(ovs_min_class))+1,np.zeros(len(data_maj))+0))
  319. ovs_training_labels_oh=[]
  320. for i in range(len(ovs_training_dataset)):
  321. if i<len(ovs_min_class):
  322. ovs_training_labels_oh.append(np.array([1,0]))
  323. else:
  324. ovs_training_labels_oh.append(np.array([0,1]))
  325. ovs_training_labels_oh=np.array(ovs_training_labels_oh)
  326. ovs_training_labels_oh=tf.convert_to_tensor(ovs_training_labels_oh)
  327. ## PCA visualization of the synthetic sata
  328. ## observe how the minority samples from convex space have optimal variance and avoids overlap with the majority
  329. pca = PCA(n_components=2)
  330. pca.fit(ovs_training_dataset)
  331. data_pca= pca.transform(ovs_training_dataset)
  332. ## plot PCA
  333. plt.rcParams["figure.figsize"] = (12,12)
  334. # TODO colors=['r', 'b', 'g']
  335. plt.xticks(fontsize=20)
  336. plt.yticks(fontsize=20)
  337. plt.xlabel('PCA1',fontsize=25)
  338. plt.ylabel('PCA2', fontsize=25)
  339. plt.title('PCA plot of oversampled data',fontsize=25)
  340. classes = ['minority', 'synthetic minority', 'majority']
  341. scatter=plt.scatter(data_pca[:,0], data_pca[:,1], c=ovs_pca_labels, cmap='Set1')
  342. plt.legend(handles=scatter.legend_elements()[0], labels=classes, fontsize=20)
  343. plt.show()
  344. return ovs_training_dataset, ovs_pca_labels, ovs_training_labels_oh
  345. def final_learning(discriminator, ovs_training_dataset, ovs_training_labels_oh, test_data_numpy, test_labels_numpy, num_epochs):
  346. print('\n')
  347. print('Final round training of the discrminator as a majority-minority classifier')
  348. print('\n')
  349. ## second phase training of the discriminator with balanced data
  350. history_second_learning=discriminator.fit(x=ovs_training_dataset,y=ovs_training_labels_oh, batch_size=20, epochs=num_epochs)
  351. ## loss of the second phase learning smoothly decreses
  352. ## this is because now the data is fixed and diverse convex combinations are no longer fed into the discriminator at every training step
  353. run_range=range(1,num_epochs+1)
  354. plt.rcParams["figure.figsize"] = (16,10)
  355. plt.xticks(fontsize=20)
  356. plt.yticks(fontsize=20)
  357. plt.xlabel('runs',fontsize=25)
  358. plt.ylabel('loss', fontsize=25)
  359. plt.title('Final learning loss for discriminator', fontsize=25)
  360. plt.plot(run_range, history_second_learning.history['loss'])
  361. plt.show()
  362. ## finally after second phase training the discriminator classifier has a more balanced performance
  363. ## meaning better F1-Score
  364. ## the recall decreases but the precision improves
  365. print('\n')
  366. y_pred_2d=discriminator.predict(tf.convert_to_tensor(test_data_numpy))
  367. y_pred=np.digitize(y_pred_2d[:,0], [.5])
  368. c=confusion_matrix(test_labels_numpy, y_pred)
  369. f=f1_score(test_labels_numpy, y_pred)
  370. pr=precision_score(test_labels_numpy, y_pred)
  371. rc=recall_score(test_labels_numpy, y_pred)
  372. k=cohen_kappa_score(test_labels_numpy, y_pred)
  373. print('Final learning confusion matrix:', c)
  374. print('Final learning f1 score', f)
  375. print('Final learning precision score', pr)
  376. print('Final learning recall score', rc)
  377. print('Final learning kappa score', k)
  378. return c,f,pr,rc,k
  379. def convGAN_train_end_to_end(training_data,training_labels,test_data,test_labels, neb, gen, neb_epochs,epochs_retrain_disc):
  380. ##minority class
  381. data_min=training_data[np.where(training_labels == 1)[0]]
  382. ##majority class
  383. data_maj=training_data[np.where(training_labels == 0)[0]]
  384. dataSet = DataSet(data0=data_maj, data1=data_min)
  385. gan = ConvGAN(data_min.shape[1], neb, gen)
  386. gan.reset()
  387. ## instanciate generator network and visualize architecture
  388. conv_sample_generator = gan.conv_sample_generator
  389. print(conv_sample_generator.summary())
  390. print('\n')
  391. ## instanciate discriminator network and visualize architecture
  392. maj_min_discriminator = gan.maj_min_discriminator
  393. print(maj_min_discriminator.summary())
  394. print('\n')
  395. ## instanciate network and visualize architecture
  396. cg = gan.cg
  397. print(cg.summary())
  398. print('\n')
  399. print('Training the GAN, first round training of the discrminator as a majority-minority classifier')
  400. print('\n')
  401. ## train gan generator ## rough_train_discriminator
  402. gan.train(dataSet, neb_epochs)
  403. print('\n')
  404. ## rough learning results
  405. c_r,f_r,pr_r,rc_r,k_r = rough_learning_predictions(gan.maj_min_discriminator_r, test_data, test_labels)
  406. print('\n')
  407. ## generate synthetic data
  408. ovs_training_dataset, ovs_pca_labels, ovs_training_labels_oh = generate_synthetic_data(gan, data_min, data_maj)
  409. print('\n')
  410. ## final training results
  411. c,f,pr,rc,k=final_learning(gan.maj_min_discriminator, ovs_training_dataset, ovs_training_labels_oh, test_data, test_labels, epochs_retrain_disc)
  412. return ((c_r,f_r,pr_r,rc_r,k_r),(c,f,pr,rc,k))
  413. def unison_shuffled_copies(a, b,seed_perm):
  414. 'Shuffling the feature matrix along with the labels with same order'
  415. np.random.seed(seed_perm)##change seed 1,2,3,4,5
  416. assert len(a) == len(b)
  417. p = np.random.permutation(len(a))
  418. return a[p], b[p]
def runTest():
    """
    Benchmark driver: 5 reshuffles x 5-fold CV of the ConvGAN oversampling
    pipeline on the imblearn 'yeast_me2' dataset; prints averaged metrics
    for the rough and the final discriminator training phases.
    """
    ## fix all random seeds for reproducibility
    seed_num=1
    seed(seed_num)
    tf.random.set_seed(seed_num)
    ## Import dataset
    data = fetch_datasets()['yeast_me2']
    ## Creating label and feature matrices
    labels_x = data.target ## labels of the data
    features_x = data.data ## features of the data
    # Until now we have obtained the data. Below it is divided into training and test
    # sets, with separate variables for the majority and minority classes and their
    # labels for both sets.
    ## specify parameters
    neb=gen=5 ##neb=gen required
    neb_epochs=10
    epochs_retrain_disc=50
    # TODO n_feat=len(features_x[1]) ## number of features
    ## Training
    np.random.seed(42)
    strata=5
    results=[]
    for seed_perm in range(strata):
        ## reshuffle the whole dataset with a per-stratum seed
        features_x,labels_x=unison_shuffled_copies(features_x,labels_x,seed_perm)
        ### Extracting all features and labels
        print('Extracting all features and labels for seed:'+ str(seed_perm)+'\n')
        ## Dividing data into training and testing datasets
        print('Dividing data into training and testing datasets for 10-fold CV for seed:'+ str(seed_perm)+'\n')
        ## split the shuffled data by class
        label_1=list(np.where(labels_x == 1)[0])
        features_1=features_x[label_1]
        label_0=list(np.where(labels_x != 1)[0])
        features_0=features_x[label_0]
        ## per-fold class sizes for a stratified 5-fold split
        a=len(features_1)//5
        b=len(features_0)//5
        ## fold k test set = k-th slice of each class; label 1 = minority
        fold_1_min=features_1[0:a]
        fold_1_maj=features_0[0:b]
        fold_1_tst=np.concatenate((fold_1_min,fold_1_maj))
        lab_1_tst=np.concatenate((np.zeros(len(fold_1_min))+1, np.zeros(len(fold_1_maj))))
        fold_2_min=features_1[a:2*a]
        fold_2_maj=features_0[b:2*b]
        fold_2_tst=np.concatenate((fold_2_min,fold_2_maj))
        ## NOTE(review): lab_2..4_tst reuse len(fold_1_min)/len(fold_1_maj);
        ## folds 1-4 all contain exactly a/b rows, so the lengths coincide here.
        lab_2_tst=np.concatenate((np.zeros(len(fold_1_min))+1, np.zeros(len(fold_1_maj))))
        fold_3_min=features_1[2*a:3*a]
        fold_3_maj=features_0[2*b:3*b]
        fold_3_tst=np.concatenate((fold_3_min,fold_3_maj))
        lab_3_tst=np.concatenate((np.zeros(len(fold_1_min))+1, np.zeros(len(fold_1_maj))))
        fold_4_min=features_1[3*a:4*a]
        fold_4_maj=features_0[3*b:4*b]
        fold_4_tst=np.concatenate((fold_4_min,fold_4_maj))
        lab_4_tst=np.concatenate((np.zeros(len(fold_1_min))+1, np.zeros(len(fold_1_maj))))
        ## the last fold absorbs the remainder rows of each class
        fold_5_min=features_1[4*a:]
        fold_5_maj=features_0[4*b:]
        fold_5_tst=np.concatenate((fold_5_min,fold_5_maj))
        lab_5_tst=np.concatenate((np.zeros(len(fold_5_min))+1, np.zeros(len(fold_5_maj))))
        ## fold k training set = the other four folds (minority first, then majority)
        fold_1_trn=np.concatenate((fold_2_min,fold_3_min,fold_4_min,fold_5_min, fold_2_maj,fold_3_maj,fold_4_maj,fold_5_maj))
        lab_1_trn=np.concatenate((np.zeros(3*a+len(fold_5_min))+1,np.zeros(3*b+len(fold_5_maj))))
        fold_2_trn=np.concatenate((fold_1_min,fold_3_min,fold_4_min,fold_5_min,fold_1_maj,fold_3_maj,fold_4_maj,fold_5_maj))
        lab_2_trn=np.concatenate((np.zeros(3*a+len(fold_5_min))+1,np.zeros(3*b+len(fold_5_maj))))
        fold_3_trn=np.concatenate((fold_2_min,fold_1_min,fold_4_min,fold_5_min,fold_2_maj,fold_1_maj,fold_4_maj,fold_5_maj))
        lab_3_trn=np.concatenate((np.zeros(3*a+len(fold_5_min))+1,np.zeros(3*b+len(fold_5_maj))))
        fold_4_trn=np.concatenate((fold_2_min,fold_3_min,fold_1_min,fold_5_min,fold_2_maj,fold_3_maj,fold_1_maj,fold_5_maj))
        lab_4_trn=np.concatenate((np.zeros(3*a+len(fold_5_min))+1,np.zeros(3*b+len(fold_5_maj))))
        fold_5_trn=np.concatenate((fold_2_min,fold_3_min,fold_4_min,fold_1_min,fold_2_maj,fold_3_maj,fold_4_maj,fold_1_maj))
        lab_5_trn=np.concatenate((np.zeros(4*a)+1,np.zeros(4*b)))
        training_folds_feats=[fold_1_trn,fold_2_trn,fold_3_trn,fold_4_trn,fold_5_trn]
        testing_folds_feats=[fold_1_tst,fold_2_tst,fold_3_tst,fold_4_tst,fold_5_tst]
        training_folds_labels=[lab_1_trn,lab_2_trn,lab_3_trn,lab_4_trn,lab_5_trn]
        testing_folds_labels=[lab_1_tst,lab_2_tst,lab_3_tst,lab_4_tst,lab_5_tst]
        ## run the end-to-end pipeline on every fold; keep (f1, precision,
        ## recall, kappa) for both phases (confusion matrix is dropped)
        for i in range(5):
            print('\n')
            print('Executing fold: '+str(i+1))
            print('\n')
            r1,r2=convGAN_train_end_to_end(training_folds_feats[i],training_folds_labels[i],testing_folds_feats[i],testing_folds_labels[i], neb, gen, neb_epochs, epochs_retrain_disc)
            results.append(np.array([list(r1[1:]),list(r2[1:])]))
    results=np.array(results)
    ## Benchmark: average the metrics over all strata and folds
    mean_rough=np.mean(results[:,0], axis=0)
    data_r={'F1-Score_r':[mean_rough[0]], 'Precision_r' : [mean_rough[1]], 'Recall_r' : [mean_rough[2]], 'Kappa_r': [mean_rough[3]]}
    df_r=pd.DataFrame(data=data_r)
    print('Rough training results:')
    print('\n')
    print(df_r)
    mean_final=np.mean(results[:,1], axis=0)
    data_f={'F1-Score_f':[mean_final[0]], 'Precision_f' : [mean_final[1]], 'Recall_f' : [mean_final[2]], 'Kappa_f': [mean_final[3]]}
    df_f=pd.DataFrame(data=data_f)
    print('Final training results:')
    print('\n')
    print(df_f)
## Script entry point: run the full benchmark when executed directly.
if __name__ == "__main__":
    runTest()