瀏覽代碼

Fixed bug that forces neb=gen

Kristian Schultz 4 年之前
父節點
當前提交
9a8eef6cc7
共有 1 個文件被更改,包括 14 次插入1 次删除
  1. 14 1
      library/convGAN.py

+ 14 - 1
library/convGAN.py

@@ -56,6 +56,9 @@ class ConvGAN(GanBaseClass):
         self.maj_min_discriminator = None
         self.cg = None
 
+        if neb > gen:
+            raise ValueError(f"Expected neb <= gen but got neb={neb} and gen={gen}.")
+
     def reset(self):
         """
         Resets the trained GAN to an random state.
@@ -70,6 +73,16 @@ class ConvGAN(GanBaseClass):
         ## instanciate network and visualize architecture
         self.cg = self._convGAN(self.conv_sample_generator, self.maj_min_discriminator)
 
+        if self.debug:
+            print(self.conv_sample_generator.summary())
+            print('\n')
+            
+            print(self.maj_min_discriminator.summary())
+            print('\n')
+
+            print(self.cg.summary())
+            print('\n')
+
     def train(self, dataSet):
         """
         Trains the GAN.
@@ -212,7 +225,7 @@ class ConvGAN(GanBaseClass):
         min_batch = Lambda(lambda x: x[:self.neb])(batch_data)
         
         ## extract majority batch
-        maj_batch = Lambda(lambda x: x[self.neb:])(batch_data)
+        maj_batch = Lambda(lambda x: x[self.gen:])(batch_data)
         
         ## pass minority batch into generator to obtain convex space transformation
         ## (synthetic samples) of the minority neighbourhood input batch