Kristian Schultz пре 3 година
родитељ
комит
a7baa3766a
1 измењених фајлова са 15 додато и 2 уклоњено
  1. 15 2
      library/generators/NextConvGeN.py

+ 15 - 2
library/generators/NextConvGeN.py

@@ -3,6 +3,7 @@ import matplotlib.pyplot as plt
 
 from library.interfaces import GanBaseClass
 from library.dataset import DataSet
+from library.timing import timing
 
 from keras.layers import Dense, Input, Multiply, Flatten, Conv1D, Reshape
 from keras.models import Model
@@ -51,6 +52,10 @@ class NextConvGeN(GanBaseClass):
         self.cg = None
         self.canPredict = True
         self.fdc = fdc
+        
+        self.timing = { n: timing(n) for n in [
+            "Train", "BMB", "NbhSearch"
+            ] }
 
         if self.neb is not None and self.gen is not None and self.neb > self.gen:
             raise ValueError(f"Expected neb <= gen but got neb={neb} and gen={gen}.")
@@ -108,6 +113,7 @@ class NextConvGeN(GanBaseClass):
         if data.shape[0] <= 0:
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
 
+        self.timing["Train"].start()
         # Store size of minority class. This is needed during point generation.
         self.minSetSize = data.shape[0]
 
@@ -115,15 +121,18 @@ class NextConvGeN(GanBaseClass):
         if self.fdc is not None:
             normalizedData = self.fdc.normalize(data)
         
+        self.timing["NbhSearch"].start()
         # Precalculate neighborhoods
         self.nmbMin = NNSearch(self.neb).fit(haystack=normalizedData)
         self.nmbMin.basePoints = data
+        self.timing["NbhSearch"].stop()
 
         # Do the training.
         self._rough_learning(data, discTrainCount)
         
         # Neighborhood in majority class is no longer needed. So save memory.
         self.isTrained = True
+        self.timing["Train"].stop()
 
     def generateDataPoint(self):
         """
@@ -315,8 +324,10 @@ class NextConvGeN(GanBaseClass):
         nLabels = 2 * self.gen
 
         for neb_epoch_count in range(self.neb_epochs):
+            print(f"NEB EPOCH #{neb_epoch_count + 1} / {self.neb_epochs}")
             if discTrainCount > 0:
                 for n in range(discTrainCount):
+                    print(f"discTrain #{n + 1} / {discTrainCount}")
                     for min_idx in range(minSetSize):
                         ## generate minority neighbourhood batch for every minority class sampls by index
                         min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
@@ -396,9 +407,11 @@ class NextConvGeN(GanBaseClass):
         ## data_maj -> majority class data
         ## min_idxs -> indices of points in minority class
         ## gen -> convex combinations generated from each neighbourhood
-
+        self.timing["BMB"].start()
         indices = [i for i in range(self.minSetSize) if i not in min_idxs]
-        return self.nmbMin.neighbourhoodOfItemList(shuffle(indices), maxCount=self.gen)
+        r = np.array([self.nmbMin.basePoints[i] for i in shuffle(indices)[0:self.gen]])
+        self.timing["BMB"].stop()
+        return r
 
 
     def retrainDiscriminitor(self, data, labels):