Procházet zdrojové kódy

Introduced data stream for training.

Kristian Schultz před 3 roky
rodič
revize
9c8c3b6fee
1 změněný soubor s 36 přidáními a 11 odebráními
  1. 36 11
      library/generators/NextConvGeN.py

+ 36 - 11
library/generators/NextConvGeN.py

@@ -104,7 +104,7 @@ class NextConvGeN(GanBaseClass):
             print(self.cg.summary())
             print('\n')
 
-    def train(self, data, discTrainCount=5):
+    def train(self, data, discTrainCount=5, batchSize=8):
         """
         Trains the Network.
 
@@ -129,11 +129,11 @@ class NextConvGeN(GanBaseClass):
         self.timing["NbhSearch"].start()
         # Precalculate neighborhoods
         self.nmbMin = NNSearch(self.neb).fit(haystack=normalizedData)
-        self.nmbMin.basePoints = data
+        self.nmbMin.basePoints = np.array([ [x.astype(np.float32) for x in p] for p in data])
         self.timing["NbhSearch"].stop()
 
         # Do the training.
-        self._rough_learning(data, discTrainCount)
+        self._rough_learning(data, discTrainCount, batchSize=batchSize)
         
         # Neighborhood in majority class is no longer needed. So save memory.
         self.isTrained = True
@@ -318,7 +318,7 @@ class NextConvGeN(GanBaseClass):
 
 
     # Training
-    def _rough_learning(self, data, discTrainCount):
+    def _rough_learning(self, data, discTrainCount, batchSize=8):
         generator = self.conv_sample_generator
         discriminator = self.maj_min_discriminator
         convGeN = self.cg
@@ -367,14 +367,39 @@ class NextConvGeN(GanBaseClass):
             discriminator.trainable = False
             self.timing["Fit"].stop()
 
-        
+        def genSamples():
+            for min_idx in range(minSetSize):
+                samples = createSamples(min_idx)
+                for x in samples[0]:
+                    yield x
+
+                for x in samples[1]:
+                    yield x
+
+        def genLabels():
+            for min_idx in range(minSetSize):
+                for x in labels:
+                    yield x
+            
+                
 
         for neb_epoch_count in range(self.neb_epochs):
-            if discTrainCount > 0:
-                for n in range(discTrainCount):
-                    for min_idx in range(minSetSize):
-                        self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, n / discTrainCount, (min_idx + 1) / minSetSize])
-                        trainDiscriminator(createSamples(min_idx))
+            for n in range(max(0,discTrainCount)):
+                self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, n / discTrainCount, 0.5])
+                samples = genSamples()
+                
+                a = tf.data.Dataset.from_generator(genSamples, output_types=tf.float32)
+                b = tf.data.Dataset.from_generator(genLabels, output_types=tf.float32)
+                samples = tf.data.Dataset.zip((a, b)).batch(batchSize * 2 * self.gen)
+
+                self.timing["Fit"].start()
+                ## switch on discriminator training
+                discriminator.trainable = True
+                ## train the discriminator with the concatenated samples and the one-hot encoded labels
+                discriminator.fit(x=samples, verbose=0)
+                ## switch off the discriminator training again
+                discriminator.trainable = False
+                self.timing["Fit"].stop()
 
             for min_idx in range(minSetSize):
                 self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, 1.0, (min_idx + 1) / minSetSize])
@@ -415,7 +440,7 @@ class NextConvGeN(GanBaseClass):
         ## gen -> convex combinations generated from each neighbourhood
         self.timing["BMB"].start()
         indices = [i for i in range(self.minSetSize) if i not in min_idxs]
-        r = np.array([ [x.astype(np.float32) for x in self.nmbMin.basePoints[i]] for i in shuffle(indices)[0:self.gen]])
+        r = self.nmbMin.basePoints[shuffle(indices)[0:self.gen]]
         self.timing["BMB"].stop()
         return r