| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- from library.ext_prowras import ProWRAS_gen
- from library.interfaces import GanBaseClass
- class ProWRAS(GanBaseClass):
- """
- This is a toy example of a GAN.
- It repeats the first point of the training-data-set.
- """
- def __init__(self
- , max_levels = 5
- , convex_nbd = 5
- , n_neighbors = 5
- , max_concov = None
- , theta = 1.0
- , shadow = 100
- , sigma = 0.000001
- , n_jobs = 1
- , debug = False
- ):
- """
- Initializes the class and mark it as untrained.
- """
- self.data = None
- self.max_levels = max_levels
- self.convex_nbd = convex_nbd
- self.n_neighbors = n_neighbors
- self.max_concov = max_concov
- self.theta = theta
- self.shadow = shadow
- self.sigma = sigma
- self.n_jobs = n_jobs
- self.debug = debug
- def reset(self, _dataSet):
- """
- Resets the trained GAN to an random state.
- """
- pass
- def train(self, dataSet):
- """
- Trains the GAN.
- It stores the first data-point in the training data-set and mark the GAN as trained.
- *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
- We are only interested in the class 1.
- """
- self.data = dataSet
- def generateDataPoint(self):
- """
- Generates one synthetic data-point by copying the stored data point.
- """
- return self.generateData(1)[0]
- def generateData(self, numOfSamples=1):
- """
- Generates a list of synthetic data-points.
- *numOfSamples* is a integer > 0. It gives the number of new generated samples.
- """
- if self.max_concov is not None:
- max_concov = self.max_concov
- else:
- max_concov = self.data.data.shape[0]
- return ProWRAS_gen(
- data = self.data.data,
- labels = self.data.labels,
- max_levels = self.max_levels,
- convex_nbd = self.convex_nbd,
- n_neighbors = self.n_neighbors,
- max_concov = max_concov,
- num_samples_to_generate = numOfSamples,
- theta = self.theta,
- shadow = self.shadow,
- sigma = self.sigma,
- n_jobs = self.n_jobs,
- enableDebug = self.debug)[0][:numOfSamples]
|