import numpy as np


class GanBaseClass:
    """Placeholder GAN that replays the class-1 training points in order.

    ``train`` stores a copy of ``dataSet.data1``; ``generateData`` then
    returns the stored items one at a time, wrapping around to the first
    item once the last one has been returned.
    """

    def __init__(self):
        """Create an untrained instance with no stored example items."""
        self.isTrained = False
        self.exampleItems = None
        # Position of the item the next generateData() call will return.
        self.nextIndex = 0

    def train(self, dataSet):
        """Store a copy of the class-1 points as the pool of generated items.

        Args:
            dataSet: Object exposing ``data0`` and ``data1`` attributes, each
                with a ``shape`` whose first entry is the number of points.

        Raises:
            AttributeError: If ``dataSet.data1`` contains no points.
        """
        if dataSet.data1.shape[0] <= 0:
            raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")

        print(
            "Train GAN with |class 0|=%d, |class 1|=%d"
            % (dataSet.data0.shape[0], dataSet.data1.shape[0])
        )

        self.isTrained = True
        self.exampleItems = dataSet.data1.copy()
        # Restart the cursor: an index left over from a previous, larger
        # training set could point past the end of the new example array.
        self.nextIndex = 0

    def generateData(self):
        """Return the next stored example item, cycling through all of them.

        Returns:
            The item of ``exampleItems`` at the current position.

        Raises:
            ValueError: If ``train`` has not been called successfully yet.
        """
        if not self.isTrained:
            raise ValueError("Try to generate data with untrained GAN.")

        i = self.nextIndex
        self.nextIndex += 1
        # Wrap around after the last stored item.
        if self.nextIndex >= self.exampleItems.shape[0]:
            self.nextIndex = 0

        return self.exampleItems[i]


class TesterNetworkBaseClass:
    """Placeholder classifier whose prediction is 0 for every point."""

    def __init__(self):
        """Create the classifier; it holds no state."""
        pass

    def train(self, data, labels):
        """Accept training data and labels; this base class ignores them."""
        pass

    def predict(self, data):
        """Return an array of zeros with one entry per row of ``data``."""
        return np.zeros(data.shape[0])