Chaithra-Umesh пре 3 година
родитељ
комит
1358aecbc5
1 измењених фајлова са 27 додато и 25 уклоњено
  1. 27 25
      library/generators/NextConvGeN.py

+ 27 - 25
library/generators/NextConvGeN.py

@@ -30,11 +30,11 @@ def create01Labels(totalSize, sizeFirstHalf):
     labels.extend(repeat(np.array([0,1]), totalSize - sizeFirstHalf))
     return np.array(labels)
 
-class ConvGeN(GanBaseClass):
+class NextConvGeN(GanBaseClass):
     """
     This is the ConvGeN class. ConvGeN is a synthetic point generator for imbalanced datasets.
     """
-    def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, maj_proximal=False, debug=False):
+    def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, fdc=None, maj_proximal=False, debug=False):
         self.isTrained = False
         self.n_feat = n_feat
         self.neb = neb
@@ -50,11 +50,12 @@ class ConvGeN(GanBaseClass):
         self.maj_proximal = maj_proximal
         self.cg = None
         self.canPredict = True
+        self.fdc = fdc
 
         if self.neb is not None and self.gen is not None and self.neb > self.gen:
             raise ValueError(f"Expected neb <= gen but got neb={neb} and gen={gen}.")
 
-    def reset(self, dataSet):
+    def reset(self, data):
         """
         Creates the network.
 
@@ -64,8 +65,8 @@ class ConvGeN(GanBaseClass):
         """
         self.isTrained = False
 
-        if dataSet is not None:
-            nMinoryPoints = dataSet.data1.shape[0]
+        if data is not None:
+            nMinoryPoints = data.shape[0]
             if self.nebInitial is None:
                 self.neb = nMinoryPoints
             else:
@@ -96,7 +97,7 @@ class ConvGeN(GanBaseClass):
             print(self.cg.summary())
             print('\n')
 
-    def train(self, dataSet, discTrainCount=5):
+    def train(self, data, discTrainCount=5):
         """
         Trains the Network.
 
@@ -104,24 +105,24 @@ class ConvGeN(GanBaseClass):
         
         *discTrainCount* gives the number of extra training for the discriminator for each epoch. (>= 0)
         """
-        if dataSet.data1.shape[0] <= 0:
+        if data.shape[0] <= 0:
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
 
         # Store size of minority class. This is needed during point generation.
-        self.minSetSize = dataSet.data1.shape[0]
+        self.minSetSize = data.shape[0]
 
+        normalizedData = data
+        if self.fdc is not None:
+            normalizedData = self.fdc.normalize(data)
+        
         # Precalculate neighborhoods
-        self.nmbMin = NNSearch(self.neb).fit(haystack=dataSet.data1)
-        if self.maj_proximal:
-            self.nmbMaj = NNSearch(self.neb).fit(haystack=dataSet.data0, needles=dataSet.data1)
-        else:
-            self.nmbMaj = None
+        self.nmbMin = NNSearch(self.neb).fit(haystack=normalizedData)
+        self.nmbMin.basePoints = data
 
         # Do the training.
-        self._rough_learning(dataSet.data1, dataSet.data0, discTrainCount)
+        self._rough_learning(data, discTrainCount)
         
         # Neighborhood in majority class is no longer needed. So save memory.
-        self.nmbMaj = None
         self.isTrained = True
 
     def generateDataPoint(self):
@@ -149,7 +150,10 @@ class ConvGeN(GanBaseClass):
             synth_set.extend(self._generate_data_for_min_point(i, synth_num))
 
         ## extract the exact number of synthetic samples needed to exactly balance the two classes
-        synth_set = np.array(synth_set[:numOfSamples]) 
+        synth_set = np.array(synth_set[:numOfSamples])
+        
+        if self.fdc is not None:
+            synth_set = self.fdc.fixPointsToDataset(synth_set)
 
         return synth_set
 
@@ -302,13 +306,13 @@ class ConvGeN(GanBaseClass):
 
 
     # Training
-    def _rough_learning(self, data_min, data_maj, discTrainCount):
+    def _rough_learning(self, data, discTrainCount):
         generator = self.conv_sample_generator
         discriminator = self.maj_min_discriminator
         convGeN = self.cg
         loss_history = [] ## this is for stroring the loss for every run
         step = 0
-        minSetSize = len(data_min)
+        minSetSize = len(data)
 
         labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
         nLabels = 2 * self.gen
@@ -321,7 +325,7 @@ class ConvGeN(GanBaseClass):
                         min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
                         min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
                         ## generate random proximal majority batch
-                        maj_batch = self._BMB(data_maj, min_batch_indices)
+                        maj_batch = self._BMB(min_batch_indices)
 
                         ## generate synthetic samples from convex space
                         ## of minority neighbourhood batch using generator
@@ -342,7 +346,7 @@ class ConvGeN(GanBaseClass):
                 min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
                 
                 ## generate random proximal majority batch
-                maj_batch = self._BMB(data_maj, min_batch_indices)
+                maj_batch = self._BMB(min_batch_indices)
 
                 ## generate synthetic samples from convex space
                 ## of minority neighbourhood batch using generator
@@ -389,17 +393,15 @@ class ConvGeN(GanBaseClass):
         self.loss_history = loss_history
 
 
-    def _BMB(self, data_maj, min_idxs):
+    def _BMB(self, min_idxs):
 
         ## Generate a borderline majority batch
         ## data_maj -> majority class data
         ## min_idxs -> indices of points in minority class
         ## gen -> convex combinations generated from each neighbourhood
 
-        if self.nmbMaj is not None:
-            return self.nmbMaj.neighbourhoodOfItemList(shuffle(min_idxs), maxCount=self.gen)
-        else:
-            return tf.convert_to_tensor(data_maj[np.random.randint(len(data_maj), size=self.gen)])
+        indices = [i for i in range(self.minSetSize) if i not in min_idxs]
+        return self.nmbMin.neighbourhoodOfItemList(shuffle(indices), maxCount=self.gen)
 
 
     def retrainDiscriminitor(self, data, labels):