# interfaces.py

import numpy as np
  2. class GanBaseClass:
  3. def __init__(self):
  4. self.isTrained = False
  5. self.exampleItem = None
  6. pass
  7. def train(self, dataSet):
  8. if dataSet.data0.shape[0] <= 0:
  9. raise AttributeError("Train GAN: Expected data class 0 to contain at least one point.")
  10. print(
  11. "Train GAN with |class 0|=%d, |class 1|=%d"
  12. % (dataSet.data0.shape[0], dataSet.data1.shape[0])
  13. )
  14. self.isTrained = True
  15. self.exampleItem = dataSet.data0[0].copy()
  16. def generateData(self):
  17. if not self.isTrained:
  18. raise ValueError("Try to generate data with untrained GAN.")
  19. return self.exampleItem
  20. class TesterNetworkBaseClass:
  21. def __init__(self):
  22. pass
  23. def train(self, data, labels):
  24. pass
  25. def predict(self, data):
  26. return np.zeros(data.shape[0])