import numpy as np


class GanBaseClass:
    """Trivial reference GAN: "learns" one sample and reproduces it.

    ``train`` stores a copy of the first row of ``dataSet.data0``.
    ``generateData`` returns a copy of that stored row. No real generative
    model is involved.
    """

    def __init__(self):
        # Becomes True once train() has run successfully.
        self.isTrained = False
        # Private copy of the first class-0 sample seen by train(); None until then.
        self.exampleItem = None

    def train(self, dataSet):
        """Record the first class-0 sample of ``dataSet``.

        Args:
            dataSet: object with ``data0`` and ``data1`` attributes that expose
                ``.shape[0]`` (presumably numpy arrays with one sample per
                row — TODO confirm against callers).

        Raises:
            AttributeError: if ``dataSet.data0`` contains no rows.
        """
        if dataSet.data0.shape[0] <= 0:
            # NOTE: ValueError would be more conventional; AttributeError is
            # kept so existing callers' except clauses keep matching.
            raise AttributeError("Train GAN: Expected data class 0 to contain at least one point.")
        print(
            "Train GAN with |class 0|=%d, |class 1|=%d"
            % (dataSet.data0.shape[0], dataSet.data1.shape[0])
        )
        self.isTrained = True
        # Copy so later changes to the caller's data do not leak into the model.
        self.exampleItem = dataSet.data0[0].copy()

    def generateData(self):
        """Return a new copy of the stored example sample.

        Raises:
            ValueError: if called before ``train``.
        """
        if not self.isTrained:
            raise ValueError("Try to generate data with untrained GAN.")
        # Return a copy: handing out the internal array would let callers
        # mutate the model's state through the generated sample.
        return self.exampleItem.copy()


class TesterNetworkBaseClass:
    """No-op classifier stand-in: ignores training data and predicts class 0."""

    def __init__(self):
        pass

    def train(self, data, labels):
        """Accept ``data`` and ``labels`` and do nothing with them."""
        pass

    def predict(self, data):
        """Return a float array of zeros with one entry per row of ``data``."""
        return np.zeros(data.shape[0])