|
|
@@ -60,11 +60,12 @@ class ConvGAN(GanBaseClass):
|
|
|
This is a toy example of a GAN.
|
|
|
It repeats the first point of the training-data-set.
|
|
|
"""
|
|
|
- def __init__(self, neb, gen):
|
|
|
+ def __init__(self, neb, gen, debug=False):
|
|
|
self.isTrained = False
|
|
|
self.neb = neb
|
|
|
self.gen = gen
|
|
|
self.loss_history = None
|
|
|
+ self.debug = debug
|
|
|
|
|
|
def reset(self):
|
|
|
"""
|
|
|
@@ -93,15 +94,7 @@ class ConvGAN(GanBaseClass):
|
|
|
raise AttributeError("Train: Expected data class 1 to contain at least one point.")
|
|
|
|
|
|
# TODO: do actually training
|
|
|
- self.conv_sample_generator, self.maj_min_discriminator_r , self.cg , self.loss_history = rough_learning(
|
|
|
- neb_epochs,
|
|
|
- dataSet.data1,
|
|
|
- dataSet.data0,
|
|
|
- self.neb,
|
|
|
- self.gen,
|
|
|
- self.conv_sample_generator,
|
|
|
- self.maj_min_discriminator,
|
|
|
- self.cg)
|
|
|
+ self._rough_learning(neb_epochs, dataSet.data1, dataSet.data0)
|
|
|
|
|
|
self.isTrained = True
|
|
|
|
|
|
@@ -127,6 +120,69 @@ class ConvGAN(GanBaseClass):
|
|
|
return np.array(syntheticPoints)
|
|
|
|
|
|
|
|
|
+ # Hidden internal functions
|
|
|
+
|
|
|
+ # Training
|
|
|
+ def _rough_learning(self, neb_epochs, data_min, data_maj):
|
|
|
+ generator = self.conv_sample_generator
|
|
|
+ discriminator = self.maj_min_discriminator
|
|
|
+ GAN = self.cg
|
|
|
+ loss_history=[] ## this is for storing the loss for every run
|
|
|
+ min_idx = 0
|
|
|
+ neb_epoch_count = 1
|
|
|
+
|
|
|
+ labels = []
|
|
|
+ for i in range(2 * self.gen):
|
|
|
+ if i < self.gen:
|
|
|
+ labels.append(np.array([1,0]))
|
|
|
+ else:
|
|
|
+ labels.append(np.array([0,1]))
|
|
|
+ labels = np.array(labels)
|
|
|
+ labels = tf.convert_to_tensor(labels)
|
|
|
+
|
|
|
+
|
|
|
+ for step in range(neb_epochs * len(data_min)):
|
|
|
+ min_batch = NMB_guided(data_min, self.neb, min_idx) ## generate minority neighbourhood batch for every minority class samples by index
|
|
|
+ min_idx = min_idx + 1
|
|
|
+ maj_batch = BMB(data_min,data_maj, self.neb, self.gen) ## generate random proximal majority batch
|
|
|
+
|
|
|
+ conv_samples = generator.predict(min_batch) ## generate synthetic samples from convex space of minority neighbourhood batch using generator
|
|
|
+ concat_sample = tf.concat([conv_samples, maj_batch], axis=0) ## concatenate them with the majority batch
|
|
|
+
|
|
|
+ discriminator.trainable = True ## switch on discriminator training
|
|
|
+ discriminator.fit(x=concat_sample, y=labels, verbose=0) ## train the discriminator with the concatenated samples and the one-hot encoded labels
|
|
|
+ discriminator.trainable = False ## switch off the discriminator training again
|
|
|
+
|
|
|
+ gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0) ## use the GAN to make the generator learn on the decisions made by the previous discriminator training
|
|
|
+
|
|
|
+ loss_history.append(gan_loss_history.history['loss']) ## store the loss for the step
|
|
|
+
|
|
|
+ if self.debug and ((step + 1) % 10 == 0):
|
|
|
+ print(f"{step + 1} neighbourhood batches trained; running neighbourhood epoch {neb_epoch_count}")
|
|
|
+
|
|
|
+ if min_idx == len(data_min) - 1:
|
|
|
+ if self.debug:
|
|
|
+ print(f"Neighbourhood epoch {neb_epoch_count} complete")
|
|
|
+ neb_epoch_count = neb_epoch_count + 1
|
|
|
+ min_idx = 0
|
|
|
+
|
|
|
+ if self.debug:
|
|
|
+ run_range = range(1, len(loss_history) + 1)
|
|
|
+ plt.rcParams["figure.figsize"] = (16,10)
|
|
|
+ plt.xticks(fontsize=20)
|
|
|
+ plt.yticks(fontsize=20)
|
|
|
+ plt.xlabel('runs', fontsize=25)
|
|
|
+ plt.ylabel('loss', fontsize=25)
|
|
|
+ plt.title('Rough learning loss for discriminator', fontsize=25)
|
|
|
+ plt.plot(run_range, loss_history)
|
|
|
+ plt.show()
|
|
|
+
|
|
|
+ self.conv_sample_generator = generator
|
|
|
+ self.maj_min_discriminator = discriminator
|
|
|
+ self.cg = GAN
|
|
|
+ self.loss_history = loss_history
|
|
|
+
|
|
|
+
|
|
|
|
|
|
## convGAN
|
|
|
def unison_shuffled_copies(a, b,seed_perm):
|
|
|
@@ -244,62 +300,6 @@ def convGAN(generator,discriminator):
|
|
|
## this is the first training phase for the discriminator and the only training phase for the generator.
|
|
|
|
|
|
|
|
|
-def rough_learning(neb_epochs,data_min,data_maj,neb,gen,generator, discriminator,GAN):
|
|
|
-
|
|
|
-
|
|
|
- step=1
|
|
|
- loss_history=[] ## this is for stroring the loss for every run
|
|
|
- min_idx=0
|
|
|
- neb_epoch_count=1
|
|
|
-
|
|
|
- labels=[]
|
|
|
- for i in range(2*gen):
|
|
|
- if i<gen:
|
|
|
- labels.append(np.array([1,0]))
|
|
|
- else:
|
|
|
- labels.append(np.array([0,1]))
|
|
|
- labels=np.array(labels)
|
|
|
- labels=tf.convert_to_tensor(labels)
|
|
|
-
|
|
|
-
|
|
|
- while step<(neb_epochs*len(data_min)):
|
|
|
-
|
|
|
-
|
|
|
- min_batch=NMB_guided(data_min, neb, min_idx) ## generate minority neighbourhood batch for every minority class sampls by index
|
|
|
- min_idx=min_idx+1
|
|
|
- maj_batch=BMB(data_min,data_maj,neb,gen) ## generate random proximal majority batch
|
|
|
-
|
|
|
- conv_samples=generator.predict(min_batch) ## generate synthetic samples from convex space of minority neighbourhood batch using generator
|
|
|
- concat_sample=tf.concat([conv_samples,maj_batch],axis=0) ## concatenate them with the majority batch
|
|
|
-
|
|
|
- discriminator.trainable=True ## switch on discriminator training
|
|
|
- discriminator.fit(x=concat_sample,y=labels,verbose=0) ## train the discriminator with the concatenated samples and the one-hot encoded labels
|
|
|
- discriminator.trainable=False ## switch off the discriminator training again
|
|
|
-
|
|
|
- gan_loss_history=GAN.fit(concat_sample,y=labels,verbose=0) ## use the GAN to make the generator learn on the decisions made by the previous discriminator training
|
|
|
-
|
|
|
- loss_history.append(gan_loss_history.history['loss']) ## store the loss for the step
|
|
|
-
|
|
|
- if step%10 == 0:
|
|
|
- print(str(step)+' neighbourhood batches trained; running neighbourhood epoch ' + str(neb_epoch_count))
|
|
|
-
|
|
|
- if min_idx==len(data_min)-1:
|
|
|
- print(str('Neighbourhood epoch '+ str(neb_epoch_count) +' complete'))
|
|
|
- neb_epoch_count=neb_epoch_count+1
|
|
|
- min_idx=0
|
|
|
-
|
|
|
-
|
|
|
- step=step+1
|
|
|
- run_range=range(1,len(loss_history)+1)
|
|
|
- plt.rcParams["figure.figsize"] = (16,10)
|
|
|
- plt.xticks(fontsize=20)
|
|
|
- plt.yticks(fontsize=20)
|
|
|
- plt.xlabel('runs',fontsize=25)
|
|
|
- plt.ylabel('loss', fontsize=25)
|
|
|
- plt.title('Rough learning loss for discriminator', fontsize=25)
|
|
|
- plt.plot(run_range, loss_history)
|
|
|
- plt.show()
|
|
|
- return generator, discriminator, GAN, loss_history
|
|
|
|
|
|
|
|
|
|