Parcourir la source

Cleaned ConvGeN code.

Kristian Schultz il y a 3 ans
Parent
commit
ac1fda6a35
1 fichiers modifiés avec 36 ajouts et 30 suppressions
  1. 36 30
      library/generators/ConvGeN.py

+ 36 - 30
library/generators/ConvGeN.py

@@ -32,8 +32,7 @@ def create01Labels(totalSize, sizeFirstHalf):
 
 
 class ConvGeN(GanBaseClass):
 class ConvGeN(GanBaseClass):
     """
     """
-    This is a toy example of a GAN.
-    It repeats the first point of the training-data-set.
+    This is the ConvGeN class. ConvGeN is a synthetic point generator for imbalanced datasets.
     """
     """
     def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, maj_proximal=False, debug=False):
     def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, maj_proximal=False, debug=False):
         self.isTrained = False
         self.isTrained = False
@@ -57,7 +56,11 @@ class ConvGeN(GanBaseClass):
 
 
     def reset(self, dataSet):
     def reset(self, dataSet):
         """
         """
-        Resets the trained GAN to an random state.
+        Creates the network.
+
+        *dataSet* is an instance of /library.dataset.DataSet/ or None.
+        It contains the training dataset.
+        It is used to determine the neighbourhood size if /neb/ in /__init__/ was None.
         """
         """
         self.isTrained = False
         self.isTrained = False
 
 
@@ -95,12 +98,11 @@ class ConvGeN(GanBaseClass):
 
 
     def train(self, dataSet, discTrainCount=5):
     def train(self, dataSet, discTrainCount=5):
         """
         """
-        Trains the GAN.
-
-        It stores the data points in the training data set and mark as trained.
+        Trains the Network.
 
 
         *dataSet* is an instance of /library.dataset.DataSet/. It contains the training dataset.
         *dataSet* is an instance of /library.dataset.DataSet/. It contains the training dataset.
-        We are only interested in the first *maxListSize* points in class 1.
+        
+        *discTrainCount* gives the number of extra training steps for the discriminator for each epoch. (>= 0)
         """
         """
         if dataSet.data1.shape[0] <= 0:
         if dataSet.data1.shape[0] <= 0:
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
@@ -136,7 +138,7 @@ class ConvGeN(GanBaseClass):
         *numOfSamples* is a integer > 0. It gives the number of new generated samples.
         *numOfSamples* is a integer > 0. It gives the number of new generated samples.
         """
         """
         if not self.isTrained:
         if not self.isTrained:
-            raise ValueError("Try to generate data with untrained Re.")
+            raise ValueError("Try to generate data with untrained network.")
 
 
         ## roughly calculate the upper bound of the synthetic samples to be generated from each neighbourhood
         ## roughly calculate the upper bound of the synthetic samples to be generated from each neighbourhood
         synth_num = (numOfSamples // self.minSetSize) + 1
         synth_num = (numOfSamples // self.minSetSize) + 1
@@ -152,6 +154,11 @@ class ConvGeN(GanBaseClass):
         return synth_set
         return synth_set
 
 
     def predictReal(self, data):
     def predictReal(self, data):
+        """
+        Uses the discriminator on data.
+        
+        *data* is a numpy array of shape (n, n_feat) where n is the number of datapoints and n_feat the number of features.
+        """
         prediction = self.maj_min_discriminator.predict(data)
         prediction = self.maj_min_discriminator.predict(data)
         return np.array([x[0] for x in prediction])
         return np.array([x[0] for x in prediction])
 
 
@@ -159,10 +166,10 @@ class ConvGeN(GanBaseClass):
     # Hidden internal functions
     # Hidden internal functions
     # ###############################################################
     # ###############################################################
 
 
-    # Creating the GAN
+    # Creating the Network: Generator
     def _conv_sample_gen(self):
     def _conv_sample_gen(self):
         """
         """
-        the generator network to generate synthetic samples from the convex space
+        The generator network to generate synthetic samples from the convex space
         of arbitrary minority neighbourhoods
         of arbitrary minority neighbourhoods
         """
         """
 
 
@@ -181,17 +188,17 @@ class ConvGeN(GanBaseClass):
 
 
         ## again, switching to 2-D tensor once we have the convenient shape
         ## again, switching to 2-D tensor once we have the convenient shape
         x = Reshape((self.neb, self.gen))(x)
         x = Reshape((self.neb, self.gen))(x)
-        ## row wise sum
+        ## column wise sum
         s = K.sum(x, axis=1)
         s = K.sum(x, axis=1)
-        ## adding a small constant to always ensure the row sums are non zero.
+        ## adding a small constant to always ensure the column sums are non zero.
         ## if this is not done then during initialization the sum can be zero.
         ## if this is not done then during initialization the sum can be zero.
         s_non_zero = Lambda(lambda x: x + .000001)(s)
         s_non_zero = Lambda(lambda x: x + .000001)(s)
-        ## reprocals of the approximated row sum
+        ## reciprocals of the approximated column sum
         sinv = tf.math.reciprocal(s_non_zero)
         sinv = tf.math.reciprocal(s_non_zero)
-        ## At this step we ensure that row sum is 1 for every row in x.
-        ## That means, each row is set of convex co-efficient
+        ## At this step we ensure that column sum is 1 for every row in x.
+        ## That means, each column is a set of convex coefficients
         x = Multiply()([sinv, x])
         x = Multiply()([sinv, x])
-        ## Now we transpose the matrix. So each column is now a set of convex coefficients
+        ## Now we transpose the matrix. So each row is now a set of convex coefficients
         aff=tf.transpose(x[0])
         aff=tf.transpose(x[0])
         ## We now do matrix multiplication of the affine combinations with the original
         ## We now do matrix multiplication of the affine combinations with the original
         ## minority batch taken as input. This generates a convex transformation
         ## minority batch taken as input. This generates a convex transformation
@@ -204,13 +211,14 @@ class ConvGeN(GanBaseClass):
         model.compile(loss='mean_squared_logarithmic_error', optimizer=opt)
         model.compile(loss='mean_squared_logarithmic_error', optimizer=opt)
         return model
         return model
 
 
+    # Creating the Network: discriminator
     def _maj_min_disc(self):
     def _maj_min_disc(self):
         """
         """
-        the discriminator is trained intwo phase:
-        first phase:  while training GAN the discriminator learns to differentiate synthetic
+        the discriminator is trained in two phases:
+        first phase:  while training ConvGeN the discriminator learns to differentiate synthetic
                       minority samples generated from convex minority data space against
                       minority samples generated from convex minority data space against
                       the borderline majority samples
                       the borderline majority samples
-        second phase: after the GAN generator learns to create synthetic samples,
+        second phase: after the ConvGeN generator learns to create synthetic samples,
                       it can be used to generate synthetic samples to balance the dataset
                       it can be used to generate synthetic samples to balance the dataset
                       and then retrain the discriminator with the balanced dataset
                       and then retrain the discriminator with the balanced dataset
         """
         """
@@ -233,6 +241,7 @@ class ConvGeN(GanBaseClass):
         model.compile(loss='binary_crossentropy', optimizer=opt)
         model.compile(loss='binary_crossentropy', optimizer=opt)
         return model
         return model
 
 
+    # Creating the Network: ConvGeN
     def _convGeN(self, generator, discriminator):
     def _convGeN(self, generator, discriminator):
         """
         """
         for joining the generator and the discriminator
         for joining the generator and the discriminator
@@ -240,7 +249,7 @@ class ConvGeN(GanBaseClass):
         maj_min_discriminator -> discriminator network instance
         maj_min_discriminator -> discriminator network instance
         """
         """
         ## by default the discriminator trainability is switched off.
         ## by default the discriminator trainability is switched off.
-        ## Thus training the GAN means training the generator network as per previously
+        ## Thus training ConvGeN means training the generator network as per previously
         ## trained discriminator network.
         ## trained discriminator network.
         discriminator.trainable = False
         discriminator.trainable = False
 
 
@@ -248,8 +257,6 @@ class ConvGeN(GanBaseClass):
         ## and a proximal majority batch concatenated
         ## and a proximal majority batch concatenated
         batch_data = Input(shape=(self.n_feat,))
         batch_data = Input(shape=(self.n_feat,))
         
         
-        ##- print(f"GAN: 0..{self.neb}/{self.gen}..")
-
         ## extract minority batch
         ## extract minority batch
         min_batch = Lambda(lambda x: x[:self.neb])(batch_data)
         min_batch = Lambda(lambda x: x[:self.neb])(batch_data)
         
         
@@ -262,11 +269,9 @@ class ConvGeN(GanBaseClass):
         
         
         ## concatenate the synthetic samples with the majority samples
         ## concatenate the synthetic samples with the majority samples
         new_samples = tf.concat([conv_samples, maj_batch],axis=0)
         new_samples = tf.concat([conv_samples, maj_batch],axis=0)
-        ##- new_samples = tf.concat([conv_samples, conv_samples, conv_samples, conv_samples],axis=0)
         
         
         ## pass the concatenated vector into the discriminator to know its decisions
         ## pass the concatenated vector into the discriminator to know its decisions
         output = discriminator(new_samples)
         output = discriminator(new_samples)
-        ##- output = Lambda(lambda x: x[:2 * self.gen])(output)
         
         
         ## note that, the discriminator will not be trained but will make decisions based
         ## note that, the discriminator will not be trained but will make decisions based
         ## on its previous training while using this function
         ## on its previous training while using this function
@@ -300,7 +305,7 @@ class ConvGeN(GanBaseClass):
     def _rough_learning(self, data_min, data_maj, discTrainCount):
     def _rough_learning(self, data_min, data_maj, discTrainCount):
         generator = self.conv_sample_generator
         generator = self.conv_sample_generator
         discriminator = self.maj_min_discriminator
         discriminator = self.maj_min_discriminator
-        GAN = self.cg
+        convGeN = self.cg
         loss_history = [] ## this is for storing the loss for every run
         loss_history = [] ## this is for storing the loss for every run
         step = 0
         step = 0
         minSetSize = len(data_min)
         minSetSize = len(data_min)
@@ -335,12 +340,14 @@ class ConvGeN(GanBaseClass):
                 ## generate minority neighbourhood batch for every minority class sampls by index
                 ## generate minority neighbourhood batch for every minority class sample by index
                 ## generate minority neighbourhood batch for every minority class sample by index
                 min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
                 min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
                 min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
+                
                 ## generate random proximal majority batch
                 ## generate random proximal majority batch
                 maj_batch = self._BMB(data_maj, min_batch_indices)
                 maj_batch = self._BMB(data_maj, min_batch_indices)
 
 
                 ## generate synthetic samples from convex space
                 ## generate synthetic samples from convex space
                 ## of minority neighbourhood batch using generator
                 ## of minority neighbourhood batch using generator
                 conv_samples = generator.predict(min_batch, batch_size=self.neb)
                 conv_samples = generator.predict(min_batch, batch_size=self.neb)
+                
                 ## concatenate them with the majority batch
                 ## concatenate them with the majority batch
                 concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
                 concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
 
 
@@ -351,13 +358,12 @@ class ConvGeN(GanBaseClass):
                 ## switch off the discriminator training again
                 ## switch off the discriminator training again
                 discriminator.trainable = False
                 discriminator.trainable = False
 
 
-                ## use the GAN to make the generator learn on the decisions
+                ## use the complete network to make the generator learn on the decisions
                 ## made by the previous discriminator training
                 ## made by the previous discriminator training
-                ##- print(f"concat sample shape: {concat_sample.shape}/{labels.shape}")
-                gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0, batch_size=nLabels)
+                gen_loss_history = convGeN.fit(concat_sample, y=labels, verbose=0, batch_size=nLabels)
 
 
                 ## store the loss for the step
                 ## store the loss for the step
-                loss_history.append(gan_loss_history.history['loss'])
+                loss_history.append(gen_loss_history.history['loss'])
 
 
                 step += 1
                 step += 1
                 if self.debug and (step % 10 == 0):
                 if self.debug and (step % 10 == 0):
@@ -379,7 +385,7 @@ class ConvGeN(GanBaseClass):
 
 
         self.conv_sample_generator = generator
         self.conv_sample_generator = generator
         self.maj_min_discriminator = discriminator
         self.maj_min_discriminator = discriminator
-        self.cg = GAN
+        self.cg = convGeN
         self.loss_history = loss_history
         self.loss_history = loss_history