|
|
@@ -24,6 +24,7 @@ class SimpleGan(GanBaseClass):
|
|
|
A class for a simple GAN.
|
|
|
"""
|
|
|
def __init__(self, numOfFeatures=786, noiseSize=None, epochs=10, batchSize=128, withTanh=False, gLayers=None, dLayers=None):
|
|
|
+ self.canPredict = False
|
|
|
self.isTrained = False
|
|
|
self.noiseSize = noiseSize if noiseSize is not None else (numOfFeatures * 16)
|
|
|
self.numOfFeatures = numOfFeatures
|
|
|
@@ -32,7 +33,7 @@ class SimpleGan(GanBaseClass):
|
|
|
self.scaler = 1.0
|
|
|
self.withTanh = withTanh
|
|
|
self.dLayers = dLayers if dLayers is not None else [numOfFeatures * 40, numOfFeatures * 20, numOfFeatures * 10]
|
|
|
- self.gLayers = gLayers if gLayers is not None else [noiseSize * 2, numOfFeatures * 4, numOfFeatures * 2]
|
|
|
+ self.gLayers = gLayers if gLayers is not None else [self.noiseSize * 2, numOfFeatures * 4, numOfFeatures * 2]
|
|
|
|
|
|
def reset(self, _dataSet):
|
|
|
"""
|