ctgan.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import numpy as np
  2. import ctgan
  3. import math
  4. from library.interfaces import GanBaseClass
  5. from library.dataset import DataSet
  6. class CtGAN(GanBaseClass):
  7. """
  8. This is a toy example of a GAN.
  9. It repeats the first point of the training-data-set.
  10. """
  11. def __init__(self, epochs=10, debug=False):
  12. self.isTrained = False
  13. self.epochs = epochs
  14. self.debug = debug
  15. self.ctgan = None
  16. def reset(self, _dataSet):
  17. """
  18. Resets the trained GAN to an random state.
  19. """
  20. self.isTrained = False
  21. ## instanciate generator network and visualize architecture
  22. self.ctgan = ctgan.CTGANSynthesizer(epochs=self.epochs)
  23. def train(self, dataSet):
  24. """
  25. Trains the GAN.
  26. It stores the data points in the training data set and mark as trained.
  27. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  28. We are only interested in the first *maxListSize* points in class 1.
  29. """
  30. if dataSet.data1.shape[0] <= 0:
  31. raise AttributeError("Train: Expected data class 1 to contain at least one point.")
  32. discreteColumns = self.findDiscreteColumns(dataSet.data1)
  33. if discreteColumns != []:
  34. self.ctgan.fit(dataSet.data1, discreteColumns)
  35. else:
  36. self.ctgan.fit(dataSet.data1)
  37. self.isTrained = True
  38. def generateDataPoint(self):
  39. """
  40. Returns one synthetic data point by repeating the stored list.
  41. """
  42. return (self.generateData(1))[0]
  43. def generateData(self, numOfSamples=1):
  44. """
  45. Generates a list of synthetic data-points.
  46. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  47. """
  48. if not self.isTrained:
  49. raise ValueError("Try to generate data with untrained Re.")
  50. return self.ctgan.sample(numOfSamples)
  51. def findDiscreteColumns(self, data):
  52. columns = set(range(data.shape[1]))
  53. for row in data:
  54. for c in list(columns):
  55. x = row[c]
  56. if float(math.floor(x)) != x:
  57. columns.remove(c)
  58. if len(columns) == 0:
  59. break
  60. return columns