Ver Fonte

added interface for classifying data with generator class.

Kristian Schultz há 4 anos atrás
pai
commit
e70381d03e
2 ficheiros alterados com 14 adições e 0 exclusões
  1. 4 0
      library/generators/convGAN.py
  2. 10 0
      library/interfaces.py

+ 4 - 0
library/generators/convGAN.py

@@ -132,6 +132,10 @@ class ConvGAN(GanBaseClass):
 
         return synth_set
 
+    def predict(self, data):
+        prediction = self.generator.predict(data)
+        return np.array(map(lambda x: x[0], prediction))
+
     # ###############################################################
     # Hidden internal functions
     # ###############################################################

+ 10 - 0
library/interfaces.py

@@ -39,3 +39,13 @@ class GanBaseClass:
         *numOfSamples* is an integer > 0. It gives the number of generated samples.
         """
         raise NotImplementedError
+
+    def predict(self, data):
+        """
+        Takes a list (numpy array) of data points.
+        Returns a list with real values in [0,1] for the propapility
+        that a point is in the minority dataset. With:
+          0.0: point is in majority set
+          1.0: point is in minority set
+        """
+        raise NotImplemented