| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- import numpy as np
- from library.dataset import DataSet, TrainTestData
class Exercise:
    """
    Exercise a synthetic-sample generator used to oversample a minority class.

    For each of ``numOfShuffles`` rounds the dataset is optionally shuffled and
    split into ``numOfSlices`` slices. For every slice:

    1. the generator is trained on the slice's training part;
    2. class 0 (required to be the minority class) is padded with generated
       samples until it is as large as class 1;
    3. a freshly created network is trained on the padded data and used to
       predict the slice's test part.
    """

    def __init__(self, createNetworkFunction, shuffleFunction=None, numOfSlices=5, numOfShuffles=5):
        """
        Args:
            createNetworkFunction: zero-argument callable returning a new network
                object that exposes ``train(data, labels)`` and ``predict(data)``.
            shuffleFunction: value passed to ``dataset.shuffleWith`` at the start
                of every round; ``None`` disables shuffling.
            numOfSlices: number of slices requested from
                ``TrainTestData.splitDataToSlices``.
            numOfShuffles: number of shuffle/split rounds to run.
        """
        self.numOfSlices = numOfSlices
        self.numOfShuffles = numOfShuffles
        self.createNetworkFunction = createNetworkFunction
        self.shuffleFunction = shuffleFunction
        # Progress sink; replace it (e.g. with a no-op) to silence or capture output.
        self.debug = print

    @staticmethod
    def _progressTitle(label, index, total):
        """Return a 1-based progress label for a 0-based ``index``, e.g. ``"Step 1/5"``."""
        return f"{label} {index + 1}/{total}"

    def run(self, gan, dataset):
        """
        Run every shuffle/slice round of the exercise.

        Args:
            gan: generator exposing ``train(trainData)`` and ``generateData()``.
            dataset: object with ``data0``/``data1`` sample collections and a
                ``shuffleWith`` method; class 0 must not be larger than class 1.

        Raises:
            AttributeError: if class 0 has more samples than class 1.
        """
        # NOTE(review): ValueError would be the conventional type for this check;
        # AttributeError is kept so existing callers that catch it keep working.
        if len(dataset.data0) > len(dataset.data1):
            raise AttributeError("Expected class 0 to be the minority class but class 0 is bigger than class 1.")

        self.debug("### Start exercise for synthetic point generator")
        for shuffleStep in range(self.numOfShuffles):
            # The counters must be interpolated; a plain string literal would
            # print the braces verbatim.
            stepTitle = self._progressTitle("Step", shuffleStep, self.numOfShuffles)
            self.debug(f"\n====== {stepTitle} =======")

            if self.shuffleFunction is not None:
                self.debug("-> Shuffling data")
                dataset.shuffleWith(self.shuffleFunction)

            self.debug("-> Spliting data to slices")
            dataSlices = TrainTestData.splitDataToSlices(dataset, self.numOfSlices)
            for sliceNr, sliceData in enumerate(dataSlices):
                sliceTitle = self._progressTitle("Slice", sliceNr, self.numOfSlices)
                self.debug(f"\n------ {stepTitle}: {sliceTitle} -------")
                self._exerciseWithDataSlice(gan, sliceData)

        self.debug("### Exercise is done.")

    def _exerciseWithDataSlice(self, gan, dataSlice):
        """
        Train the generator, oversample class 0, then train and test a new network.

        Args:
            gan: generator exposing ``train`` and ``generateData``.
            dataSlice: object with a ``train`` part (exposing ``size0``, ``size1``,
                ``data0``, ``data1``, ``data`` and ``labels``) and a ``test`` part
                (exposing ``data`` and ``labels``).
        """
        self.debug("-> Train generator for synthetic samples")
        gan.train(dataSlice.train)

        # Generate exactly enough class-0 samples to match the class-1 count.
        numOfNeededSamples = dataSlice.train.size1 - dataSlice.train.size0
        if numOfNeededSamples > 0:
            self.debug(f"-> create {numOfNeededSamples} synthetic samples")
            newSamples = np.asarray([gan.generateData() for _ in range(numOfNeededSamples)])
            train = DataSet(
                data0=np.concatenate((dataSlice.train.data0, newSamples)),
                data1=dataSlice.train.data1,
            )
        else:
            # Already balanced (or class 0 is larger): train on the slice as-is.
            train = dataSlice.train

        self.debug("-> create network")
        testNetwork = self.createNetworkFunction()
        self.debug("-> train network")
        testNetwork.train(train.data, train.labels)
        self.debug("-> test network")
        results = testNetwork.predict(dataSlice.test.data)
        self.debug("-> check results")
        self._checkResults(results, dataSlice.test.labels)

    def _checkResults(self, results, expectedLabels):
        """
        Hook for evaluating predicted ``results`` against ``expectedLabels``.

        NOTE(review): currently a no-op, so predictions are discarded; presumably
        meant to be implemented or overridden — confirm intent.
        """
        pass
|