interfaces.py 1.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import numpy as np
  2. class GanBaseClass:
  3. def __init__(self):
  4. self.isTrained = False
  5. self.exampleItems = None
  6. self.nextIndex = 0
  7. pass
  8. def train(self, dataSet):
  9. if dataSet.data1.shape[0] <= 0:
  10. raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
  11. print(
  12. "Train GAN with |class 0|=%d, |class 1|=%d"
  13. % (dataSet.data0.shape[0], dataSet.data1.shape[0])
  14. )
  15. self.isTrained = True
  16. self.exampleItems = dataSet.data1.copy()
  17. def generateData(self):
  18. if not self.isTrained:
  19. raise ValueError("Try to generate data with untrained GAN.")
  20. i = self.nextIndex
  21. self.nextIndex += 1
  22. if self.nextIndex >= self.exampleItems.shape[0]:
  23. self.nextIndex = 0
  24. return self.exampleItems[i]
  25. class TesterNetworkBaseClass:
  26. def __init__(self):
  27. pass
  28. def train(self, data, labels):
  29. pass
  30. def predict(self, data):
  31. return np.zeros(data.shape[0])