|
|
@@ -33,8 +33,8 @@ class Exercise:
|
|
|
self.debug = print
|
|
|
|
|
|
def run(self, gan, dataset):
|
|
|
- if len(dataset.data0) > len(dataset.data1):
|
|
|
- raise AttributeError("Expected class 0 to be the minority class but class 0 is bigger than class 1.")
|
|
|
+ if len(dataset.data1) > len(dataset.data0):
|
|
|
+ raise AttributeError("Expected class 1 to be the minority class but class 1 is bigger than class 0.")
|
|
|
|
|
|
self.debug("### Start exercise for synthetic point generator")
|
|
|
for shuffleStep in range(self.numOfShuffles):
|
|
|
@@ -58,14 +58,14 @@ class Exercise:
|
|
|
self.debug("-> Train generator for synthetic samples")
|
|
|
gan.train(dataSlice.train)
|
|
|
|
|
|
- numOfNeededSamples = dataSlice.train.size1 - dataSlice.train.size0
|
|
|
+ numOfNeededSamples = dataSlice.train.size0 - dataSlice.train.size1
|
|
|
|
|
|
if numOfNeededSamples > 0:
|
|
|
self.debug(f"-> create {numOfNeededSamples} synthetic samples")
|
|
|
newSamples = np.asarray([gan.generateData() for _ in range(numOfNeededSamples)])
|
|
|
train = DataSet(
|
|
|
- data0=np.concatenate((dataSlice.train.data0, newSamples)),
|
|
|
- data1=dataSlice.train.data1
|
|
|
+ data0=dataSlice.train.data0,
|
|
|
+ data1=np.concatenate((dataSlice.train.data1, newSamples))
|
|
|
)
|
|
|
else:
|
|
|
train = dataSlice.train
|