| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
"""
This module contains some example Generative Adversarial Networks for testing.

The classes StupidToyPointGan and StupidToyListGan are not really networks. These classes are used
for testing the interface. Hopefully your actual GAN will perform better than these two.

The class SimpleGan is meant to be a simple standard Generative Adversarial Network; at the moment
it only exists as a commented-out draft in this module.
"""
- import numpy as np
- from library.interfaces import GanBaseClass
class StupidToyPointGan(GanBaseClass):
    """
    This is a toy example of a GAN.

    It repeats the first point of the training-data-set. It is not a real
    network; it only exists to exercise the GAN interface.
    """

    def __init__(self):
        """
        Initializes the class and marks it as untrained.
        """
        self.isTrained = False
        # Copy of the first class-1 point seen by train(); None until trained.
        self.exampleItem = None

    def train(self, dataSet):
        """
        Trains the GAN.

        It stores a copy of the first data-point of class 1 and marks the GAN as trained.

        *dataSet* is an instance of /library.dataset.DataSet/. It contains the training dataset.
        We are only interested in class 1 (``dataSet.data1``, which is accessed via
        ``.shape`` and indexing, i.e. presumably a numpy array).

        Raises AttributeError if class 1 contains no points.
        """
        if dataSet.data1.shape[0] <= 0:
            raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
        # Store the item before flagging the GAN as trained so a failure while
        # copying can never leave a "trained" GAN without an example item.
        self.exampleItem = dataSet.data1[0].copy()
        self.isTrained = True

    def generateDataPoint(self):
        """
        Generates one synthetic data-point by copying the stored data point.

        A fresh copy is returned on every call, so callers that modify the
        result cannot corrupt the stored example.

        Raises ValueError if the GAN has not been trained yet.
        """
        if not self.isTrained:
            raise ValueError("Try to generate data with untrained GAN.")
        return self.exampleItem.copy()

    def generateData(self, numOfSamples=1):
        """
        Generates an array of synthetic data-points.

        *numOfSamples* is converted with int() (so fractional values are
        truncated) and must be > 0. It gives the number of new generated samples.

        Returns a numpy array whose first axis has length *numOfSamples*; every
        entry equals the stored data point.

        Raises AttributeError if *numOfSamples* < 1, and ValueError if the GAN
        has not been trained yet.
        """
        numOfSamples = int(numOfSamples)
        if numOfSamples < 1:
            raise AttributeError("Expected numOfSamples to be > 0")
        return np.array([self.generateDataPoint() for _ in range(numOfSamples)])
class StupidToyListGan(GanBaseClass):
    """
    This is a toy example of a GAN.

    It stores up to *maxListSize* points of class 1 of the training-data-set
    and returns them one after another, starting over at the beginning once
    the end of the stored list is reached. It is not a real network; it only
    exists to exercise the GAN interface.
    """

    def __init__(self, maxListSize=100):
        """
        Initializes the class and marks it as untrained.

        *maxListSize* is converted with int() and must be > 0. It limits how many
        class-1 points train() stores.

        Raises AttributeError if *maxListSize* < 1.
        """
        self.isTrained = False
        # Stored class-1 points (at most maxListSize of them); None until trained.
        self.exampleItems = None
        # Index of the stored point that the next generateDataPoint() call returns.
        self.nextIndex = 0
        self.maxListSize = int(maxListSize)
        if self.maxListSize < 1:
            raise AttributeError("Expected maxListSize to be > 0 but got " + str(self.maxListSize))

    def train(self, dataSet):
        """
        Trains the GAN.

        It stores a copy of the first *maxListSize* data-points of class 1,
        restarts the output cycle at the first stored point, and marks the GAN as trained.

        *dataSet* is an instance of /library.dataset.DataSet/. It contains the training dataset.
        We are only interested in the first *maxListSize* points in class 1
        (``dataSet.data1``, accessed via ``.shape`` and slicing, i.e. presumably
        a numpy array).

        Raises AttributeError if class 1 contains no points.
        """
        if dataSet.data1.shape[0] <= 0:
            raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
        self.exampleItems = dataSet.data1[: self.maxListSize].copy()
        # Reset the cycle: an index left over from an earlier training run could
        # point past the end of a smaller new list and cause an IndexError.
        self.nextIndex = 0
        self.isTrained = True

    def generateDataPoint(self):
        """
        Returns one synthetic data point by repeating the stored list.

        Successive calls walk through the stored points in order and wrap
        around to the first one after the last. A copy is returned so callers
        that modify the result cannot corrupt the stored points.

        Raises ValueError if the GAN has not been trained yet.
        """
        if not self.isTrained:
            raise ValueError("Try to generate data with untrained GAN.")
        i = self.nextIndex
        # Advance with wrap-around; i is always a valid index here because
        # nextIndex is kept in [0, len(exampleItems)).
        self.nextIndex = (i + 1) % self.exampleItems.shape[0]
        return self.exampleItems[i].copy()

    def generateData(self, numOfSamples=1):
        """
        Generates an array of synthetic data-points.

        *numOfSamples* is converted with int() (so fractional values are
        truncated) and must be > 0. It gives the number of new generated samples.

        Returns a numpy array whose first axis has length *numOfSamples*,
        continuing the cycle through the stored points from where the previous
        call stopped.

        Raises AttributeError if *numOfSamples* < 1, and ValueError if the GAN
        has not been trained yet.
        """
        numOfSamples = int(numOfSamples)
        if numOfSamples < 1:
            raise AttributeError("Expected numOfSamples to be > 0")
        return np.array([self.generateDataPoint() for _ in range(numOfSamples)])
- # class SimpleGan(GanBaseClass):
- # def __init__(self, maxListSize=100):
- # self.isTrained = False
- # self.exampleItems = None
- # self.nextIndex = 0
- # self.maxListSize = int(maxListSize)
- # if self.maxListSize < 1:
- # raise AttributeError(f"Expected maxListSize to be > 0 but got {self.maxListSize}")
- #
- #
- # def train(self, dataSet):
- # if dataSet.data1.shape[0] <= 0:
- # raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
- #
- # self.isTrained = True
- # self.exampleItems = dataSet.data1[: self.maxListSize].copy()
- #
- # def generateDataPoint(self, numOfSamples=1):
- # if not self.isTrained:
- # raise ValueError("Try to generate data with untrained GAN.")
- #
- # i = self.nextIndex
- # self.nextIndex += 1
- # if self.nextIndex >= self.exampleItems.shape[0]:
- # self.nextIndex = 0
- #
- # return self.exampleItems[i]
- #
- #
- # def generateData(self, numOfSamples=1):
- # numOfSamples = int(numOfSamples)
- # if numOfSamples < 1:
- # raise AttributeError("Expected numOfSamples to be > 0")
- #
- # return np.array([self.generateDataPoint() for _ in range(numOfSamples)])
- #
|