| 1234567891011121314151617181920212223242526272829303132333435 |
- import numpy as np
class GanBaseClass:
    """Minimal placeholder "GAN" that memorises one sample and replays it.

    ``train`` stores a copy of the first class-0 row of the given data set and
    ``generateData`` hands that row back. Useful as a base class or stand-in
    where a real generator is not needed.

    Attributes:
        isTrained: ``True`` once ``train`` has completed successfully.
        exampleItem: Private copy of the first class-0 row seen by ``train``,
            or ``None`` before training.
    """

    def __init__(self):
        self.isTrained = False
        self.exampleItem = None

    def train(self, dataSet):
        """Record the first class-0 sample of *dataSet*.

        Args:
            dataSet: Object exposing ``data0`` and ``data1`` attributes that
                have a ``.shape`` and (for ``data0``) row indexing with a
                ``.copy()`` method — presumably NumPy arrays holding the
                class-0 / class-1 samples; confirm against callers.

        Raises:
            AttributeError: If ``dataSet.data0`` contains no rows. (Kept as
                ``AttributeError`` for compatibility with existing callers,
                although ``ValueError`` would describe the problem better.)
        """
        if dataSet.data0.shape[0] <= 0:
            raise AttributeError("Train GAN: Expected data class 0 to contain at least one point.")
        print(
            f"Train GAN with |class 0|={dataSet.data0.shape[0]}, "
            f"|class 1|={dataSet.data1.shape[0]}"
        )
        # Copy so later mutation of the caller's array cannot change what
        # this object generates.
        self.exampleItem = dataSet.data0[0].copy()
        # Mark as trained only after the sample is stored, so a failure above
        # never leaves a "trained" object with no example to return.
        self.isTrained = True

    def generateData(self):
        """Return a copy of the memorised sample.

        Returns:
            A fresh copy of ``exampleItem``, so callers may modify the result
            without corrupting the data returned by later calls.

        Raises:
            ValueError: If called before ``train`` has succeeded.
        """
        if not self.isTrained:
            raise ValueError("Try to generate data with untrained GAN.")
        return self.exampleItem.copy()
class TesterNetworkBaseClass:
    """Do-nothing placeholder model.

    ``train`` ignores its inputs and ``predict`` always answers with zeros,
    which makes this a trivial baseline or base class for real tester models.
    """

    def __init__(self):
        """Create the model; it holds no state."""

    def train(self, data, labels):
        """Accept *data* and *labels* and ignore them."""
        return None

    def predict(self, data):
        """Return a float array of zeros with one entry per row of *data*."""
        sampleCount = data.shape[0]
        return np.zeros(sampleCount)
|