|
@@ -56,6 +56,9 @@ class ConvGAN(GanBaseClass):
|
|
|
self.maj_min_discriminator = None
|
|
self.maj_min_discriminator = None
|
|
|
self.cg = None
|
|
self.cg = None
|
|
|
|
|
|
|
|
|
|
+ if neb > gen:
|
|
|
|
|
+ raise ValueError(f"Expected neb <= gen but got neb={neb} and gen={gen}.")
|
|
|
|
|
+
|
|
|
def reset(self):
|
|
def reset(self):
|
|
|
"""
|
|
"""
|
|
|
Resets the trained GAN to an random state.
|
|
Resets the trained GAN to an random state.
|
|
@@ -70,6 +73,16 @@ class ConvGAN(GanBaseClass):
|
|
|
## instanciate network and visualize architecture
|
|
## instanciate network and visualize architecture
|
|
|
self.cg = self._convGAN(self.conv_sample_generator, self.maj_min_discriminator)
|
|
self.cg = self._convGAN(self.conv_sample_generator, self.maj_min_discriminator)
|
|
|
|
|
|
|
|
|
|
+ if self.debug:
|
|
|
|
|
+ print(self.conv_sample_generator.summary())
|
|
|
|
|
+ print('\n')
|
|
|
|
|
+
|
|
|
|
|
+ print(self.maj_min_discriminator.summary())
|
|
|
|
|
+ print('\n')
|
|
|
|
|
+
|
|
|
|
|
+ print(self.cg.summary())
|
|
|
|
|
+ print('\n')
|
|
|
|
|
+
|
|
|
def train(self, dataSet):
|
|
def train(self, dataSet):
|
|
|
"""
|
|
"""
|
|
|
Trains the GAN.
|
|
Trains the GAN.
|
|
@@ -212,7 +225,7 @@ class ConvGAN(GanBaseClass):
|
|
|
min_batch = Lambda(lambda x: x[:self.neb])(batch_data)
|
|
min_batch = Lambda(lambda x: x[:self.neb])(batch_data)
|
|
|
|
|
|
|
|
## extract majority batch
|
|
## extract majority batch
|
|
|
- maj_batch = Lambda(lambda x: x[self.neb:])(batch_data)
|
|
|
|
|
|
|
+ maj_batch = Lambda(lambda x: x[self.gen:])(batch_data)
|
|
|
|
|
|
|
|
## pass minority batch into generator to obtain convex space transformation
|
|
## pass minority batch into generator to obtain convex space transformation
|
|
|
## (synthetic samples) of the minority neighbourhood input batch
|
|
## (synthetic samples) of the minority neighbourhood input batch
|