ctgan.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import numpy as np
  2. import ctgan
  3. from library.interfaces import GanBaseClass
  4. from library.dataset import DataSet
  class CtGAN(GanBaseClass):
      """
      Adapter that exposes a CTGAN synthesizer through the GanBaseClass interface.

      The synthesizer is created by *reset*, fitted on class 1 of a data set by
      *train* and afterwards sampled by *generateData* / *generateDataPoint*.
      """

      def __init__(self, epochs=10, debug=False):
          """
          Stores the configuration. No synthesizer is built here.

          *epochs* is handed to the CTGAN synthesizer as its number of training epochs.

          *debug* is kept in self.debug; nothing in this class reads it.
          """
          self.isTrained = False
          self.epochs = epochs
          self.debug = debug
          # Created by reset() (or lazily by train()).
          self.ctgan = None

      def reset(self):
          """
          Resets the GAN to an untrained state.

          A fresh synthesizer configured with *self.epochs* replaces the old one
          and the instance is marked as not trained.
          """
          self.isTrained = False
          self.ctgan = ctgan.CTGANSynthesizer(epochs=self.epochs)

      def train(self, dataSet):
          """
          Trains the GAN.

          *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
          Only the points in *dataSet.data1* (class 1) are used for fitting.

          Raises AttributeError if class 1 contains no points.
          """
          if dataSet.data1.shape[0] <= 0:
              raise AttributeError("Train: Expected data class 1 to contain at least one point.")

          # __init__ leaves self.ctgan as None. Build the synthesizer here so that
          # calling train() without a preceding reset() does not fail on None.fit.
          if self.ctgan is None:
              self.reset()

          self.ctgan.fit(dataSet.data1)
          self.isTrained = True

      def generateDataPoint(self):
          """
          Returns one synthetic data point.

          It asks *generateData* for a single sample and returns element 0 of the result.
          Raises ValueError if the GAN is not trained.
          """
          # NOTE(review): assumes the sampled result can be indexed with [0] to get
          # the first point (e.g. a numpy array) - confirm against the ctgan version in use.
          return self.generateData(1)[0]

      def generateData(self, numOfSamples=1):
          """
          Generates synthetic data-points.

          *numOfSamples* is a integer > 0. It gives the number of new generated samples.

          Returns whatever the synthesizer's *sample* method returns for that count.
          Raises ValueError if the GAN is not trained.
          """
          if not self.isTrained:
              raise ValueError("Try to generate data with untrained CtGAN.")

          return self.ctgan.sample(numOfSamples)