import numpy as np

from library.dataset import DataSet, TrainTestData


class Exercise:
    """
    Exercise a synthetic-sample generator on a two-class dataset whose
    class 0 is the minority class.

    For every shuffle step the dataset is optionally shuffled and split into
    slices. For each slice:

    1. the generator is trained on the slice's training data;
    2. class 0 of the training data is topped up with generated samples until
       it is as large as class 1;
    3. a freshly created network is trained on that data and run on the
       slice's test data;
    4. the predictions are handed to ``_checkResults``.
    """

    def __init__(self, createNetworkFunction, shuffleFunction=None, numOfSlices=5, numOfShuffles=5):
        """
        Parameters
        ----------
        createNetworkFunction : callable
            Zero-argument factory called once per slice; the returned object
            must provide ``train(data, labels)`` and ``predict(data)``.
        shuffleFunction : callable or None
            Passed to ``dataset.shuffleWith`` at the start of every shuffle
            step. If ``None`` the dataset is not shuffled.
        numOfSlices : int
            Number of slices passed to ``TrainTestData.splitDataToSlices``.
        numOfShuffles : int
            How many times the shuffle/split/exercise cycle is repeated.
        """
        self.numOfSlices = numOfSlices
        self.numOfShuffles = numOfShuffles
        self.createNetworkFunction = createNetworkFunction
        self.shuffleFunction = shuffleFunction
        # Sink for progress messages; reassign to redirect or silence output.
        self.debug = print

    def run(self, gan, dataset):
        """
        Run the complete exercise.

        Parameters
        ----------
        gan :
            Sample generator providing ``train(dataSet)`` and
            ``generateData()``.
        dataset : DataSet
            Two-class dataset (``data0`` / ``data1``); class 0 must not be
            larger than class 1.

        Raises
        ------
        AttributeError
            If class 0 contains more samples than class 1.
        """
        if len(dataset.data0) > len(dataset.data1):
            raise AttributeError("Expected class 0 to be the minority class but class 0 is bigger than class 1.")

        self.debug("### Start exercise for synthetic point generator")
        for shuffleStep in range(self.numOfShuffles):
            stepTitle = f"Step {shuffleStep + 1}/{self.numOfShuffles}"
            self.debug(f"\n====== {stepTitle} =======")

            if self.shuffleFunction is not None:
                self.debug("-> Shuffling data")
                # Called for its side effect on `dataset` (return value unused).
                dataset.shuffleWith(self.shuffleFunction)

            self.debug("-> Spliting data to slices")
            dataSlices = TrainTestData.splitDataToSlices(dataset, self.numOfSlices)

            for (sliceNr, sliceData) in enumerate(dataSlices):
                sliceTitle = f"Slice {sliceNr + 1}/{self.numOfSlices}"
                self.debug(f"\n------ {stepTitle}: {sliceTitle} -------")
                self._exerciseWithDataSlice(gan, sliceData)

        self.debug("### Exercise is done.")

    def _exerciseWithDataSlice(self, gan, dataSlice):
        """
        Train the generator, balance the training data, then train and test
        a new network on one slice.

        Parameters
        ----------
        gan :
            Sample generator providing ``train(dataSet)`` and
            ``generateData()``.
        dataSlice :
            Object with ``train`` and ``test`` DataSet attributes, as produced
            by ``TrainTestData.splitDataToSlices``.
        """
        self.debug("-> Train generator for synthetic samples")
        gan.train(dataSlice.train)

        # Number of class-0 samples missing for both classes to be equal in size.
        numOfNeededSamples = dataSlice.train.size1 - dataSlice.train.size0

        if numOfNeededSamples > 0:
            self.debug(f"-> create {numOfNeededSamples} synthetic samples")
            # One generateData() call per synthetic sample; assumes each call
            # returns a single sample compatible with data0 rows — TODO confirm.
            newSamples = np.asarray([gan.generateData() for _ in range(numOfNeededSamples)])
            train = DataSet(
                data0=np.concatenate((dataSlice.train.data0, newSamples)),
                data1=dataSlice.train.data1
            )
        else:
            train = dataSlice.train

        self.debug("-> create network")
        testNetwork = self.createNetworkFunction()

        self.debug("-> train network")
        testNetwork.train(train.data, train.labels)

        self.debug("-> test network")
        results = testNetwork.predict(dataSlice.test.data)

        self.debug("-> check results")
        self._checkResults(results, dataSlice.test.labels)

    def _checkResults(self, results, expectedLabels):
        """
        Compare network predictions with the expected test labels.

        Currently a no-op placeholder; presumably meant to be implemented or
        overridden to evaluate ``results`` against ``expectedLabels`` — TODO.
        """
        pass