ctab.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from library.interfaces import GanBaseClass
  2. from library.dataset import DataSet
  3. from model.ctabgan import CTABGAN
  4. import warnings
  5. warnings.filterwarnings("ignore")
  6. class CtabGan(GanBaseClass):
  7. """
  8. This is a toy example of a GAN.
  9. It repeats the first point of the training-data-set.
  10. """
  11. def __init__(self, n_feat, epochs=10, debug=True):
  12. self.isTrained = False
  13. self.epochs = 10
  14. def reset(self):
  15. """
  16. Resets the trained GAN to an random state.
  17. """
  18. self.isTrained = False
  19. self.synthesizer = CTABGAN(epochs = self.epochs)
  20. def train(self, dataSet):
  21. """
  22. Trains the GAN.
  23. It stores the data points in the training data set and mark as trained.
  24. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  25. We are only interested in the first *maxListSize* points in class 1.
  26. """
  27. if dataSet.data1.shape[0] <= 0:
  28. raise AttributeError("Train: Expected data class 1 to contain at least one point.")
  29. self.synthesizer.fit(train_data=dataSet.data1)
  30. self.isTrained = True
  31. def generateDataPoint(self):
  32. """
  33. Returns one synthetic data point by repeating the stored list.
  34. """
  35. return (self.generateData(1))[0]
  36. def generateData(self, numOfSamples=1):
  37. """
  38. Generates a list of synthetic data-points.
  39. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  40. """
  41. if not self.isTrained:
  42. raise ValueError("Try to generate data with untrained Re.")
  43. return self.synthesizer.generate_samples(numOfSamples)