ctgan.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import numpy as np
  2. import ctgan
  3. import math
  4. from library.interfaces import GanBaseClass
  5. from library.dataset import DataSet
  6. class CtGAN(GanBaseClass):
  7. """
  8. This is a toy example of a GAN.
  9. It repeats the first point of the training-data-set.
  10. """
  11. def __init__(self, epochs=10, debug=False):
  12. self.isTrained = False
  13. self.epochs = epochs
  14. self.debug = debug
  15. self.ctgan = None
  16. self.canPredict = False
  17. def reset(self, _dataSet):
  18. """
  19. Resets the trained GAN to an random state.
  20. """
  21. self.isTrained = False
  22. ## instanciate generator network and visualize architecture
  23. self.ctgan = ctgan.CTGANSynthesizer(epochs=self.epochs)
  24. def train(self, dataSet):
  25. """
  26. Trains the GAN.
  27. It stores the data points in the training data set and mark as trained.
  28. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  29. We are only interested in the first *maxListSize* points in class 1.
  30. """
  31. if dataSet.data1.shape[0] <= 0:
  32. raise AttributeError("Train: Expected data class 1 to contain at least one point.")
  33. discreteColumns = self.findDiscreteColumns(dataSet.data1)
  34. if discreteColumns != []:
  35. self.ctgan.fit(dataSet.data1, discreteColumns)
  36. else:
  37. self.ctgan.fit(dataSet.data1)
  38. self.isTrained = True
  39. def generateDataPoint(self):
  40. """
  41. Returns one synthetic data point by repeating the stored list.
  42. """
  43. return (self.generateData(1))[0]
  44. def generateData(self, numOfSamples=1):
  45. """
  46. Generates a list of synthetic data-points.
  47. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  48. """
  49. if not self.isTrained:
  50. raise ValueError("Try to generate data with untrained Re.")
  51. return self.ctgan.sample(numOfSamples)
  52. def findDiscreteColumns(self, data):
  53. columns = set(range(data.shape[1]))
  54. for row in data:
  55. for c in list(columns):
  56. x = row[c]
  57. if float(math.floor(x)) != x:
  58. columns.remove(c)
  59. if len(columns) == 0:
  60. break
  61. return columns