ctab.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. from library.interfaces import GanBaseClass
  2. from library.dataset import DataSet
  3. from model.synthesizer.ctabgan_synthesizer import CTABGANSynthesizer
  4. import pandas as pd
  5. import warnings
  # Process-wide side effect: as soon as this module is imported, every warning
  # category (the base class Warning covers all of them) is ignored.
  warnings.filterwarnings(action="ignore", category=Warning)
  class CtabGan(GanBaseClass):
      """
      Adapter that exposes a CTABGANSynthesizer through the GanBaseClass interface.

      Training fits the synthesizer on the class-1 points of a
      /library.dataset.DataSet/; generation draws samples from the fitted
      synthesizer.
      """

      def __init__(self, epochs=10, debug=True):
          """
          Creates an untrained wrapper.

          *epochs* is the number of training epochs passed to CTABGANSynthesizer.
          *debug* is stored for interface compatibility; this class does not use it.
          """
          self.isTrained = False
          self.epochs = epochs
          self.debug = debug
          # The synthesizer is normally built by reset(); train() builds one on
          # demand so calling train() without a prior reset() does not fail with
          # an AttributeError on self.synthesizer.
          self.synthesizer = None

      def reset(self, _dataSet):
          """
          Resets the GAN to an untrained state with a freshly constructed synthesizer.

          *_dataSet* is accepted for interface compatibility and ignored.
          """
          self.isTrained = False
          self.synthesizer = CTABGANSynthesizer(epochs=self.epochs)

      def train(self, dataSet):
          """
          Trains the GAN on the class-1 points of *dataSet*.

          *dataSet* is an instance of /library.dataset.DataSet/. Only its *data1*
          attribute is used; it is wrapped in a pandas DataFrame and passed to
          the synthesizer's fit().

          Raises AttributeError if *data1* contains no points.
          """
          if dataSet.data1.shape[0] <= 0:
              raise AttributeError("Train: Expected data class 1 to contain at least one point.")
          if self.synthesizer is None:
              self.reset(dataSet)
          # Clear the flag first so that a fit which raises part-way does not
          # leave the object claiming to be trained.
          self.isTrained = False
          self.synthesizer.fit(train_data=pd.DataFrame(dataSet.data1))
          self.isTrained = True

      def generateDataPoint(self):
          """
          Returns one synthetic data point.

          Draws a single sample via generateData(1) and returns element [0] of
          the result.

          Raises ValueError if the GAN has not been trained.
          """
          # NOTE(review): [0] is the first row if synthesizer.sample() returns a
          # NumPy array; if it returned a DataFrame this would select column 0
          # instead. Confirm against CTABGANSynthesizer.sample.
          return self.generateData(1)[0]

      def generateData(self, numOfSamples=1):
          """
          Generates synthetic data points from the trained synthesizer.

          *numOfSamples* is an integer > 0 giving the number of samples to draw.
          Returns whatever CTABGANSynthesizer.sample returns for that count.

          Raises ValueError if the GAN has not been trained.
          """
          if not self.isTrained:
              raise ValueError("Try to generate data with untrained CtabGan.")
          return self.synthesizer.sample(numOfSamples)