Parcourir la source

Replaced point generation with a faster data-stream version.

Kristian Schultz il y a 3 ans
Parent
commit
18f43b0cfc
1 fichiers modifiés avec 30 ajouts et 62 suppressions
  1. 30 62
      library/generators/NextConvGeN.py

+ 30 - 62
library/generators/NextConvGeN.py

@@ -157,16 +157,34 @@ class NextConvGeN(GanBaseClass):
 
         ## roughly claculate the upper bound of the synthetic samples to be generated from each neighbourhood
         synth_num = (numOfSamples // self.minSetSize) + 1
+        runs = (synth_num // self.gen) + 1
 
-        ## generate synth_num synthetic samples from each minority neighbourhood
-        synth_set=[]
-        for i in range(self.minSetSize):
-            synth_set.extend(self._generate_data_for_min_point(i, synth_num))
+        ## Get a random list of all indices
+        indices = randomIndices(self.minSetSize)
+
+        ## generate all neighborhoods
+        def neighborhoodGenerator():
+            for index in indices:
+                yield self.nmbMin.getNbhPointsOfItem(index)
+
+        neighborhoods = (tf.data.Dataset
+            .from_generator(neighborhoodGenerator, output_types=tf.float32)
+            .repeat()
+            )
+        batch = neighborhoods.take(runs * self.minSetSize).batch(32)
+
+        synth_batch = self.conv_sample_generator.predict(batch)
+
+        n = 0
+        synth_set = []
+        for (x,y) in zip(neighborhoods, synth_batch):
+            synth_set.extend(self.correct_feature_types(x.numpy(), y))
+            n += len(y)
+            if n >= numOfSamples:
+                break
 
         ## extract the exact number of synthetic samples needed to exactly balance the two classes
-        synth_set = np.array(synth_set[:numOfSamples])
-        
-        return synth_set
+        return np.array(synth_set[:numOfSamples])
 
     def predictReal(self, data):
         """
@@ -310,28 +328,6 @@ class NextConvGeN(GanBaseClass):
         model.compile(loss='mse', optimizer=opt)
         return model
 
-    # Create synthetic points
-    def _generate_data_for_min_point(self, index, synth_num):
-        """
-        generate synth_num synthetic points for a particular minoity sample
-        synth_num -> required number of data points that can be generated from a neighbourhood
-        data_min -> minority class data
-        neb -> oversampling neighbourhood
-        index -> index of the minority sample in a training data whose neighbourhood we want to obtain
-        """
-
-        runs = int(synth_num / self.neb) + 1
-        synth_set = []
-        for _run in range(runs):
-            batch = self.nmbMin.getNbhPointsOfItem(index)
-            synth_batch = self.conv_sample_generator.predict(tf.reshape(batch, (1, self.neb, self.n_feat)), batch_size=self.neb)
-            synth_batch = self.correct_feature_types(batch, synth_batch[0])
-            synth_set.extend(synth_batch)
-
-        return synth_set[:synth_num]
-
-
-
     # Training
     def _rough_learning(self, data, discTrainCount, batchSize=32):
         generator = self.conv_sample_generator
@@ -369,7 +365,7 @@ class NextConvGeN(GanBaseClass):
 
             self.timing["FixType"].start()
             ## Fix feature types
-            conv_samples = self.correct_feature_types_tf(min_batch, conv_samples)
+            conv_samples = self.correct_feature_types(min_batch.numpy(), conv_samples)
             self.timing["FixType"].stop()
 
             ## concatenate them with the majority batch
@@ -494,36 +490,6 @@ class NextConvGeN(GanBaseClass):
         print(f"[{s[0]}] [{s[1]}] [{s[2]}]", end="\r")
         
     def correct_feature_types(self, batch, synth_batch):
-
-        if self.fdc is None:
-            return synth_batch
-
-        def bestMatchOf(referenceValues, value):
-            best = referenceValues[0]
-            d = abs(best - value)
-            for x in referenceValues:
-                dx = abs(x - value)
-                if dx < d:
-                    best = x
-                    d = dx
-            return best
-
-        synth_batch = list(synth_batch)
-        for i in (self.fdc.nom_list or []):
-            referenceValues = list(set(list(batch[:, i].numpy())))
-            for x in synth_batch:
-                y = x[i]
-                x[i] = bestMatchOf(referenceValues, y)
-
-        for i in (self.fdc.ord_list or []):
-            referenceValues = list(set(list(batch[:, i].numpy())))
-            for x in synth_batch:
-                x[i] = bestMatchOf(referenceValues, x[i])
-
-        return np.array(synth_batch)
-
-    
-    def correct_feature_types_tf(self, batch, synth_batch):
         if self.fdc is None:
             return synth_batch
         
@@ -545,9 +511,11 @@ class NextConvGeN(GanBaseClass):
             
         referenceLists = [None for _ in range(self.n_feat)]
         for i in (self.fdc.nom_list or []):
-            referenceLists[i] = list(set(list(batch[:, i].numpy())))
+            referenceLists[i] = list(set(list(batch[:, i])))
 
         for i in (self.fdc.ord_list or []):
-            referenceLists[i] = list(set(list(batch[:, i].numpy())))
+            referenceLists[i] = list(set(list(batch[:, i])))
+
+        # print(batch.shape, synth_batch.shape)
 
         return Lambda(lambda x: np.array([correctVector(referenceLists, y) for y in x]))(synth_batch)