|
|
@@ -43,11 +43,12 @@ class ConvGAN(GanBaseClass):
|
|
|
This is a toy example of a GAN.
|
|
|
It repeats the first point of the training-data-set.
|
|
|
"""
|
|
|
- def __init__(self, n_feat, neb, gen, debug=True):
|
|
|
+ def __init__(self, n_feat, neb=5, gen=5, neb_epochs=10, debug=True):
|
|
|
self.isTrained = False
|
|
|
self.n_feat = n_feat
|
|
|
self.neb = neb
|
|
|
self.gen = gen
|
|
|
+ self.neb_epochs = 10
|
|
|
self.loss_history = None
|
|
|
self.debug = debug
|
|
|
self.dataSet = None
|
|
|
@@ -69,7 +70,7 @@ class ConvGAN(GanBaseClass):
|
|
|
## instanciate network and visualize architecture
|
|
|
self.cg = self._convGAN(self.conv_sample_generator, self.maj_min_discriminator)
|
|
|
|
|
|
- def train(self, dataSet, neb_epochs=5):
|
|
|
+ def train(self, dataSet):
|
|
|
"""
|
|
|
Trains the GAN.
|
|
|
|
|
|
@@ -82,7 +83,7 @@ class ConvGAN(GanBaseClass):
|
|
|
raise AttributeError("Train: Expected data class 1 to contain at least one point.")
|
|
|
|
|
|
self.dataSet = dataSet
|
|
|
- self._rough_learning(neb_epochs, dataSet.data1, dataSet.data0)
|
|
|
+ self._rough_learning(dataSet.data1, dataSet.data0)
|
|
|
self.isTrained = True
|
|
|
|
|
|
def generateDataPoint(self):
|
|
|
@@ -109,7 +110,7 @@ class ConvGAN(GanBaseClass):
|
|
|
## generate synth_num synthetic samples from each minority neighbourhood
|
|
|
synth_set=[]
|
|
|
for i in range(len(data_min)):
|
|
|
- synth_set.extend(self.generate_data_for_min_point(data_min, i, synth_num))
|
|
|
+ synth_set.extend(self._generate_data_for_min_point(data_min, i, synth_num))
|
|
|
|
|
|
synth_set = synth_set[:numOfSamples] ## extract the exact number of synthetic samples needed to exactly balance the two classes
|
|
|
|
|
|
@@ -253,7 +254,7 @@ class ConvGAN(GanBaseClass):
|
|
|
|
|
|
|
|
|
# Training
|
|
|
- def _rough_learning(self, neb_epochs, data_min, data_maj):
|
|
|
+ def _rough_learning(self, data_min, data_maj):
|
|
|
generator = self.conv_sample_generator
|
|
|
discriminator = self.maj_min_discriminator
|
|
|
GAN = self.cg
|
|
|
@@ -263,7 +264,7 @@ class ConvGAN(GanBaseClass):
|
|
|
|
|
|
labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
|
|
|
|
|
|
- for step in range(neb_epochs * len(data_min)):
|
|
|
+ for step in range(self.neb_epochs * len(data_min)):
|
|
|
## generate minority neighbourhood batch for every minority class sampls by index
|
|
|
min_batch = self._NMB_guided(data_min, min_idx)
|
|
|
min_idx = min_idx + 1
|