瀏覽代碼

Added missing batch_size information vor convGAN.

Kristian Schultz 3 年之前
父節點
當前提交
9bb1249576
共有 1 個文件被更改,包括 3 次插入4 次删除
  1. 3 4
      library/generators/convGAN.py

+ 3 - 4
library/generators/convGAN.py

@@ -287,7 +287,7 @@ class ConvGAN(GanBaseClass):
         synth_set = []
         for _run in range(runs):
             batch = self.nmbMin.getNbhPointsOfItem(index)
-            synth_batch = self.conv_sample_generator.predict(batch)
+            synth_batch = self.conv_sample_generator.predict(batch, batch_size=self.neb)
             synth_set.extend(synth_batch)
 
         return synth_set[:synth_num]
@@ -318,7 +318,7 @@ class ConvGAN(GanBaseClass):
 
                         ## generate synthetic samples from convex space
                         ## of minority neighbourhood batch using generator
-                        conv_samples = generator.predict(min_batch)
+                        conv_samples = generator.predict(min_batch, batch_size=self.neb)
                         ## concatenate them with the majority batch
                         concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
 
@@ -336,9 +336,8 @@ class ConvGAN(GanBaseClass):
                 ## generate random proximal majority batch
                 maj_batch = self._BMB(data_maj, min_batch_indices)
 
-                ## generate synthetic samples from convex space
                 ## of minority neighbourhood batch using generator
-                conv_samples = generator.predict(min_batch)
+                conv_samples = generator.predict(min_batch, batch_size=self.neb)
                 ## concatenate them with the majority batch
                 concat_sample = tf.concat([conv_samples, maj_batch], axis=0)