Selaa lähdekoodia

Added frame for the class.

Kristian Schultz 4 vuotta sitten
vanhempi
commit
34c9b1799c
1 muutettua tiedostoa jossa 59 lisäystä ja 15 poistoa
  1. 59 15
      library/convGAN.py

+ 59 - 15
library/convGAN.py

@@ -39,6 +39,9 @@ seed_num=1
 seed(seed_num)
 tf.random.set_seed(seed_num) 
 
+from library.interfaces import GanBaseClass
+
+
 ## Import dataset
 data = fetch_datasets()['yeast_me2']
 
@@ -51,6 +54,60 @@ features_x.shape
 
 # Until now we have obtained the data. We divided it into training and test sets, and obtained separate variables for the majority and minority classes and their labels for both sets.
 
+
class ConvGAN(GanBaseClass):
    """
    Toy placeholder implementation of a GAN.

    Training and generation are stubs (see the TODOs below): train() only
    validates its input and sets the trained flag, and generateData()
    currently returns an empty array.
    """
    def __init__(self):
        # True once train() has completed successfully.
        self.isTrained = False

    def reset(self):
        """
        Resets the trained GAN to a random (untrained) state.
        """
        self.isTrained = False

    def train(self, dataSet):
        """
        Trains the GAN and marks it as trained.

        *dataSet* is an instance of /library.dataset.DataSet/ holding the
        training data; only its class-1 points (dataSet.data1) are used.

        Raises AttributeError if class 1 contains no points.
        """
        if dataSet.data1.shape[0] <= 0:
            raise AttributeError("Train: Expected data class 1 to contain at least one point.")

        # TODO: do the actual training

        self.isTrained = True

    def generateDataPoint(self):
        """
        Returns a single synthetic data point (the first element of
        generateData(1)).
        """
        return (self.generateData(1))[0]


    def generateData(self, numOfSamples=1):
        """
        Generates a list of synthetic data points as a numpy array.

        *numOfSamples* is an integer > 0 giving the number of samples to
        generate.

        Raises ValueError if called before train().
        """
        if not self.isTrained:
            # Fixed garbled message ("... untrained Re.") left over from
            # another class.
            raise ValueError("Tried to generate data with an untrained GAN.")

        syntheticPoints = []  # TODO: actually generate samples

        return np.array(syntheticPoints)
+
+
+
 ## convGAN
 def unison_shuffled_copies(a, b,seed_perm):
     'Shuffling the feature matrix along with the labels with same order'
@@ -419,28 +476,15 @@ for seed_perm in range(strata):
     
     features_x,labels_x=unison_shuffled_copies(features_x,labels_x,seed_perm)
 
-    #scaler = StandardScaler()
-    #scaler.fit(features_x)
-    #features_x=(scaler.transform(features_x))
-    
-    
     ### Extracting all features and labels
     print('Extracting all features and labels for seed:'+ str(seed_perm)+'\n')
     
     ## Dividing data into training and testing datasets for 10-fold CV
     print('Dividing data into training and testing datasets for 10-fold CV for seed:'+ str(seed_perm)+'\n')
-    label_1=np.where(labels_x == 1)[0]
-    label_1=list(label_1)
-    
+    label_1=list(np.where(labels_x == 1)[0])
     features_1=features_x[label_1]
     
-    label_0=np.where(labels_x != 1)[0]
-    label_0=list(label_0)
-    len(label_0)
-
-
-
-
+    label_0=list(np.where(labels_x != 1)[0])
     features_0=features_x[label_0]
     
     a=len(features_1)//5