| 1234567891011121314151617181920212223242526272829303132333435363738394041 |
- import numpy as np
class GanBaseClass:
    """Placeholder "GAN" that replays stored class-1 training points.

    No generative model is learned: ``train`` keeps a copy of
    ``dataSet.data1`` and ``generateData`` hands back its rows one at a
    time, wrapping around to the first row after the last one.
    """

    def __init__(self):
        # Becomes True after a successful train(); generateData() refuses
        # to run before that.
        self.isTrained = False
        # Copy of dataSet.data1 taken by train(); None until then.
        self.exampleItems = None
        # Row index of exampleItems returned by the next generateData() call.
        self.nextIndex = 0

    def train(self, dataSet):
        """Store the class-1 points of *dataSet* for later replay.

        Args:
            dataSet: object exposing ``data0`` and ``data1`` attributes that
                have a ``.shape`` (presumably numpy arrays holding the points
                of class 0 and class 1 -- confirm against callers). Only
                ``data1`` is kept; ``data0`` is used solely for the printed
                summary.

        Raises:
            AttributeError: if ``dataSet.data1`` contains no rows.
        """
        if dataSet.data1.shape[0] <= 0:
            # NOTE(review): ValueError would describe this better, but the
            # original raised AttributeError and callers may rely on it.
            raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
        print(
            "Train GAN with |class 0|=%d, |class 1|=%d"
            % (dataSet.data0.shape[0], dataSet.data1.shape[0])
        )
        self.isTrained = True
        # Copy so later changes to the caller's array do not affect replay.
        self.exampleItems = dataSet.data1.copy()
        # Restart the replay. A cursor left over from an earlier, larger
        # training set could otherwise point past the end of this one and
        # make the next generateData() call fail with IndexError.
        self.nextIndex = 0

    def generateData(self):
        """Return the next stored class-1 point, cycling through them in order.

        The returned value is ``self.exampleItems[i]``. For a numpy array
        this is a view into the stored copy, so mutating it also changes
        what later calls replay.

        Raises:
            ValueError: if ``train`` has not completed successfully yet.
        """
        if not self.isTrained:
            raise ValueError("Try to generate data with untrained GAN.")
        i = self.nextIndex
        # Advance the cursor, wrapping back to row 0 after the last row.
        self.nextIndex = (i + 1) % self.exampleItems.shape[0]
        return self.exampleItems[i]
class TesterNetworkBaseClass:
    """No-op tester network: training is ignored and every prediction is 0."""

    def __init__(self):
        """Create the tester; this base version keeps no state."""

    def train(self, data, labels):
        """Accept training *data* and *labels* and deliberately ignore them."""

    def predict(self, data):
        """Return a zero-filled float vector with one entry per row of *data*.

        Only ``data.shape[0]`` is inspected; the values in *data* are unused.
        """
        rowCount = data.shape[0]
        return np.zeros((rowCount,))
|