Bläddra i källkod

Wrapped batch so we can train more than one neighborhood at once in the future.

Kristian Schultz 3 år sedan
förälder
incheckning
260553a9a0
1 ändrade filer med 95 tillägg och 71 borttagningar
  1. 95 71
      library/generators/NextConvGeN.py

+ 95 - 71
library/generators/NextConvGeN.py

@@ -55,7 +55,7 @@ class NextConvGeN(GanBaseClass):
         self.lastProgress = (-1,-1,-1)
         
         self.timing = { n: timing(n) for n in [
-            "Train", "BMB", "NbhSearch", "NBH", "GenSamples", "Fit"
+            "Train", "BMB", "NbhSearch", "NBH", "GenSamples", "Fit", "FixType"
             ] }
 
         if self.neb is not None and self.gen is not None and self.neb > self.gen:
@@ -189,13 +189,10 @@ class NextConvGeN(GanBaseClass):
         """
 
         ## takes minority batch as input
-        min_neb_batch = Input(shape=(self.n_feat,))
+        min_neb_batch = Input(shape=(self.neb, self.n_feat,))
 
-        ## reshaping the 2D tensor to 3D for using 1-D convolution,
-        ## otherwise 1-D convolution won't work.
-        x = tf.reshape(min_neb_batch, (1, self.neb, self.n_feat), name=None)
         ## using 1-D convolution, feature dimension remains the same
-        x = Conv1D(self.n_feat, 3, activation='relu')(x)
+        x = Conv1D(self.n_feat, 3, activation='relu')(min_neb_batch)
         ## flatten after convolution
         x = Flatten()(x)
         ## add dense layer to transform the vector to a convenient dimension
@@ -270,23 +267,26 @@ class NextConvGeN(GanBaseClass):
 
         ## input receives a neighbourhood minority batch
         ## and a proximal majority batch concatenated
-        batch_data = Input(shape=(self.n_feat,))
+        batch_data = Input(shape=(2, self.gen, self.n_feat,))
         
         ## extract minority batch
-        min_batch = Lambda(lambda x: x[:self.neb])(batch_data)
+        min_batch = Lambda(lambda x: x[:, 0, : ,:], name="SplitForGen")(batch_data)
         
         ## extract majority batch
-        maj_batch = Lambda(lambda x: x[self.gen:])(batch_data)
+        maj_batch = Lambda(lambda x: x[:, 1, :, :], name="SplitForDisc")(batch_data)
+        maj_batch = tf.reshape(maj_batch, (-1, self.n_feat), name="ReshapeForDisc")
         
         ## pass minority batch into generator to obtain convex space transformation
         ## (synthetic samples) of the minority neighbourhood input batch
         conv_samples = generator(min_batch)
+        conv_samples = tf.reshape(conv_samples, (-1, self.n_feat), name="ReshapeGenOutput")
+
+        ## pass samples into the discriminator to know its decisions
+        conv_samples = discriminator(conv_samples)
+        maj_batch = discriminator(maj_batch)
         
-        ## concatenate the synthetic samples with the majority samples
-        new_samples = tf.concat([conv_samples, maj_batch],axis=0)
-        
-        ## pass the concatenated vector into the discriminator to know its decisions
-        output = discriminator(new_samples)
+        ## concatenate the decisions
+        output = tf.concat([conv_samples, maj_batch],axis=0)
         
         ## note that, the discriminator will not be trained but will make decisions based
         ## on its previous training while using this function
@@ -309,8 +309,8 @@ class NextConvGeN(GanBaseClass):
         synth_set = []
         for _run in range(runs):
             batch = self.nmbMin.getNbhPointsOfItem(index)
-            synth_batch = self.conv_sample_generator.predict(batch, batch_size=self.neb)
-            synth_batch = self.correct_feature_types(batch, synth_batch)
+            synth_batch = self.conv_sample_generator.predict(tf.reshape(batch, (1, self.neb, self.n_feat)), batch_size=self.neb)
+            synth_batch = self.correct_feature_types(batch, synth_batch[0])
             synth_set.extend(synth_batch)
 
         return synth_set[:synth_num]
@@ -323,79 +323,73 @@ class NextConvGeN(GanBaseClass):
         discriminator = self.maj_min_discriminator
         convGeN = self.cg
         loss_history = [] ## this is for storing the loss for every run
-        step = 0
         minSetSize = len(data)
 
-        labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
         nLabels = 2 * self.gen
+        labels = np.array(create01Labels(nLabels, self.gen))
+        labelsGeN = np.array([labels])
+        
+        
+        def createSamples(min_idx):
+            self.timing["NBH"].start()
+            ## generate minority neighbourhood batch for every minority class samples by index
+            min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
+            min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
+
+            ## generate random proximal majority batch
+            maj_batch = self._BMB(min_batch_indices)
+            self.timing["NBH"].stop()
+
+            self.timing["GenSamples"].start()
+            ## generate synthetic samples from convex space
+            ## of minority neighbourhood batch using generator
+            conv_samples = generator.predict(np.array([min_batch]), batch_size=self.neb)
+            conv_samples = tf.reshape(conv_samples, shape=(self.gen, self.n_feat))
+            self.timing["GenSamples"].stop()
+
+            self.timing["FixType"].start()
+            ## Fix feature types
+            conv_samples = self.correct_feature_types_tf(min_batch, conv_samples)
+            self.timing["FixType"].stop()
+
+            ## concatenate them with the majority batch
+            conv_samples = [conv_samples, maj_batch]
+            return conv_samples
+
+        def trainDiscriminator(samples):
+            concat_samples = tf.concat(samples, axis=0)
+            self.timing["Fit"].start()
+            ## switch on discriminator training
+            discriminator.trainable = True
+            ## train the discriminator with the concatenated samples and the one-hot encoded labels
+            discriminator.fit(x=concat_samples, y=labels, verbose=0, batch_size=20)
+            ## switch off the discriminator training again
+            discriminator.trainable = False
+            self.timing["Fit"].stop()
+
+        
 
         for neb_epoch_count in range(self.neb_epochs):
             if discTrainCount > 0:
                 for n in range(discTrainCount):
                     for min_idx in range(minSetSize):
                         self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, n / discTrainCount, (min_idx + 1) / minSetSize])
-                        self.timing["NBH"].start()
-                        ## generate minority neighbourhood batch for every minority class sampls by index
-                        min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
-                        min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
-                        ## generate random proximal majority batch
-                        maj_batch = self._BMB(min_batch_indices)
-                        self.timing["NBH"].stop()
-
-                        self.timing["GenSamples"].start()
-                        ## generate synthetic samples from convex space
-                        ## of minority neighbourhood batch using generator
-                        conv_samples = generator.predict(min_batch, batch_size=self.neb)
-                        ## concatenate them with the majority batch
-                        concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
-                        self.timing["GenSamples"].stop()
-
-                        self.timing["Fit"].start()
-                        ## switch on discriminator training
-                        discriminator.trainable = True
-                        ## train the discriminator with the concatenated samples and the one-hot encoded labels
-                        discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=20)
-                        ## switch off the discriminator training again
-                        discriminator.trainable = False
-                        self.timing["Fit"].stop()
+                        trainDiscriminator(createSamples(min_idx))
 
             for min_idx in range(minSetSize):
                 self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, 1.0, (min_idx + 1) / minSetSize])
-                ## generate minority neighbourhood batch for every minority class sampls by index
-                min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
-                min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
-                
-                ## generate random proximal majority batch
-                maj_batch = self._BMB(min_batch_indices)
-
-                ## generate synthetic samples from convex space
-                ## of minority neighbourhood batch using generator
-                conv_samples = generator.predict(min_batch, batch_size=self.neb)
-                
-                ## concatenate them with the majority batch
-                concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
-
-                ## switch on discriminator training
-                discriminator.trainable = True
-                ## train the discriminator with the concatenated samples and the one-hot encoded labels
-                discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=20)
-                ## switch off the discriminator training again
-                discriminator.trainable = False
+
+                samples = createSamples(min_idx)
+                trainDiscriminator(samples)
 
                 ## use the complete network to make the generator learn on the decisions
                 ## made by the previous discriminator training
-                gen_loss_history = convGeN.fit(concat_sample, y=labels, verbose=0, batch_size=nLabels)
+                samples = np.array([samples])
+                gen_loss_history = convGeN.fit(samples, y=labelsGeN, verbose=0, batch_size=nLabels)
 
                 ## store the loss for the step
                 loss_history.append(gen_loss_history.history['loss'])
 
-                step += 1
-                if self.debug and (step % 10 == 0):
-                    print(f"{step} neighbourhood batches trained; running neighbourhood epoch {neb_epoch_count}")
-
-            if self.debug:
-                print(f"Neighbourhood epoch {neb_epoch_count + 1} complete")
-
         if self.debug:
             run_range = range(1, len(loss_history) + 1)
             plt.rcParams["figure.figsize"] = (16,10)
@@ -421,7 +415,7 @@ class NextConvGeN(GanBaseClass):
         ## gen -> convex combinations generated from each neighbourhood
         self.timing["BMB"].start()
         indices = [i for i in range(self.minSetSize) if i not in min_idxs]
-        r = np.array([self.nmbMin.basePoints[i] for i in shuffle(indices)[0:self.gen]])
+        r = np.array([ [x.astype(np.float32) for x in self.nmbMin.basePoints[i]] for i in shuffle(indices)[0:self.gen]])
         self.timing["BMB"].stop()
         return r
 
@@ -477,3 +471,33 @@ class NextConvGeN(GanBaseClass):
                 x[i] = bestMatchOf(referenceValues, x[i])
 
         return np.array(synth_batch)
+
+    
+    def correct_feature_types_tf(self, batch, synth_batch):
+        if self.fdc is None:
+            return synth_batch
+        
+        def bestMatchOf(referenceValues, value):
+            if referenceValues is not None:
+                best = referenceValues[0]
+                d = abs(best - value)
+                for x in referenceValues:
+                    dx = abs(x - value)
+                    if dx < d:
+                        best = x
+                        d = dx
+                return best
+            else:
+                return value
+        
+        def correctVector(referenceLists, v):
+            return np.array([bestMatchOf(referenceLists[i], v[i]) for i in range(len(v))])
+            
+        referenceLists = [None for _ in range(self.n_feat)]
+        for i in (self.fdc.nom_list or []):
+            referenceLists[i] = list(set(list(batch[:, i].numpy())))
+
+        for i in (self.fdc.ord_list or []):
+            referenceLists[i] = list(set(list(batch[:, i].numpy())))
+
+        return Lambda(lambda x: np.array([correctVector(referenceLists, y) for y in x]))(synth_batch)