exercise.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import numpy as np
  2. from library.dataset import DataSet, TrainTestData
  3. class Exercise:
  4. """
  5. Exercising a test for a minority class extension class.
  6. """
  7. def __init__(self, createNetworkFunction, shuffleFunction=None, numOfSlices=5, numOfShuffles=5):
  8. self.numOfSlices = numOfSlices
  9. self.numOfShuffles = numOfShuffles
  10. self.createNetworkFunction = createNetworkFunction
  11. self.shuffleFunction = shuffleFunction
  12. self.debug = print
  13. def run(self, gan, dataset):
  14. if len(dataset.data0) > len(dataset.data1):
  15. raise AttributeError("Expected class 0 to be the minority class but class 0 is bigger than class 1.")
  16. self.debug("### Start exercise for synthetic point generator")
  17. for shuffleStep in range(self.numOfShuffles):
  18. stepTitle = "Step {shuffleStep + 1}/{self.numOfShuffles}"
  19. self.debug(f"\n====== {stepTitle} =======")
  20. if self.shuffleFunction is not None:
  21. self.debug("-> Shuffling data")
  22. dataset.shuffleWith(self.shuffleFunction)
  23. self.debug("-> Spliting data to slices")
  24. dataSlices = TrainTestData.splitDataToSlices(dataset, self.numOfSlices)
  25. for (sliceNr, sliceData) in enumerate(dataSlices):
  26. sliceTitle = "Slice {sliceNr + 1}/{self.numOfSlices}"
  27. self.debug(f"\n------ {stepTitle}: {sliceTitle} -------")
  28. self._exerciseWithDataSlice(gan, sliceData)
  29. self.debug("### Exercise is done.")
  30. def _exerciseWithDataSlice(self, gan, dataSlice):
  31. self.debug("-> Train generator for synthetic samples")
  32. gan.train(dataSlice.train)
  33. numOfNeededSamples = dataSlice.train.size1 - dataSlice.train.size0
  34. if numOfNeededSamples > 0:
  35. self.debug(f"-> create {numOfNeededSamples} synthetic samples")
  36. newSamples = np.asarray([gan.generateData() for _ in range(numOfNeededSamples)])
  37. train = DataSet(
  38. data0=np.concatenate((dataSlice.train.data0, newSamples)),
  39. data1=dataSlice.train.data1
  40. )
  41. else:
  42. train = dataSlice.train
  43. self.debug("-> create network")
  44. testNetwork = self.createNetworkFunction()
  45. self.debug("-> train network")
  46. testNetwork.train(train.data, train.labels)
  47. self.debug("-> test network")
  48. results = testNetwork.predict(dataSlice.test.data)
  49. self.debug("-> check results")
  50. self._checkResults(results, dataSlice.test.labels)
  51. def _checkResults(self, results, expectedLabels):
  52. pass