|
|
@@ -23,16 +23,16 @@ class SimpleGan(GanBaseClass):
|
|
|
"""
|
|
|
A class for a simple GAN.
|
|
|
"""
|
|
|
- def __init__(self, numOfFeatures=786, noiseSize=None, epochs=3, batchSize=128, withTanh=False, gLayers=None, dLayers=None):
|
|
|
+ def __init__(self, numOfFeatures=786, noiseSize=None, epochs=10, batchSize=128, withTanh=False, gLayers=None, dLayers=None):
|
|
|
self.isTrained = False
|
|
|
- self.noiseSize = noiseSize if noiseSize is not None else numOfFeatures
|
|
|
+ self.noiseSize = noiseSize if noiseSize is not None else (numOfFeatures * 16)
|
|
|
self.numOfFeatures = numOfFeatures
|
|
|
self.epochs = epochs
|
|
|
self.batchSize = batchSize
|
|
|
self.scaler = 1.0
|
|
|
self.withTanh = withTanh
|
|
|
- self.dLayers = dLayers if dLayers is not None else [1024, 512, 256]
|
|
|
- self.gLayers = gLayers if gLayers is not None else [256, 512, 1024]
|
|
|
+ self.dLayers = dLayers if dLayers is not None else [numOfFeatures * 40, numOfFeatures * 20, numOfFeatures * 10]
|
|
|
+ self.gLayers = gLayers if gLayers is not None else [noiseSize * 2, numOfFeatures * 4, numOfFeatures * 2]
|
|
|
|
|
|
def reset(self, _dataSet):
|
|
|
"""
|
|
|
@@ -106,6 +106,8 @@ class SimpleGan(GanBaseClass):
|
|
|
self.scaler = max(1.0, 1.1 * tf.reduce_max(tf.abs(trainData)).numpy())
|
|
|
scaleDown = 1.0 / self.scaler
|
|
|
|
|
|
+ trainData = scaleDown * trainData
|
|
|
+
|
|
|
for e in range(self.epochs):
|
|
|
print(f"Epoch {e + 1}/{self.epochs}")
|
|
|
for _ in range(self.batchSize):
|
|
|
@@ -116,12 +118,12 @@ class SimpleGan(GanBaseClass):
|
|
|
syntheticBatch = self.generator.predict(noise)
|
|
|
|
|
|
# Get a random set of real images
|
|
|
- realBatch = dataset.data1[
|
|
|
+ realBatch = trainData[
|
|
|
np.random.randint(low=0, high=trainDataSize, size=self.batchSize)
|
|
|
]
|
|
|
|
|
|
#Construct different batches of real and fake data
|
|
|
- X = np.concatenate([scaleDown * realBatch, syntheticBatch])
|
|
|
+ X = np.concatenate([realBatch, syntheticBatch])
|
|
|
|
|
|
# Labels for generated and real data
|
|
|
y_dis=np.zeros(2 * self.batchSize)
|