analysis.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from library.exercise import Exercise
  2. from library.dataset import DataSet, TrainTestData
  3. from library.GanExamples import StupidToyListGan
  4. from library.SimpleGan import SimpleGan
  5. from library.Repeater import Repeater
  6. from library.SpheredNoise import SpheredNoise
  7. import pickle
  8. import numpy as np
  9. import random
  10. from imblearn.datasets import fetch_datasets
  11. def loadDataset(datasetName):
  12. pickle_in = open(f"{datasetName}.pickle", "rb")
  13. pickle_dict = pickle.load(pickle_in)
  14. myData = pickle_dict["folding"]
  15. k = myData[0]
  16. labels = np.concatenate((k[1], k[3]), axis=0).astype(float)
  17. features = np.concatenate((k[0], k[2]), axis=0).astype(float)
  18. label_1 = list(np.where(labels == 1)[0])
  19. label_0 = list(np.where(labels == 0)[0])
  20. features_1 = features[label_1]
  21. features_0 = features[label_0]
  22. return DataSet(data0=features_0, data1=features_1)
  23. def getRandGen(initValue, incValue=257, multValue=101, modulus=65537):
  24. value = initValue
  25. while True:
  26. value = ((multValue * value) + incValue) % modulus
  27. yield value
  28. def genShuffler():
  29. randGen = getRandGen(2021)
  30. def shuffler(data):
  31. data = list(data)
  32. size = len(data)
  33. shuffled = []
  34. while size > 0:
  35. p = next(randGen) % size
  36. size -= 1
  37. shuffled.append(data[p])
  38. data = data[0:p] + data[(p + 1):]
  39. return np.array(shuffled)
  40. return shuffler
  41. def runExerciseForSimpleGAN(datasetName):
  42. ganName = "SimpleGAN"
  43. print()
  44. print()
  45. print("///////////////////////////////////////////")
  46. print(f"// Running {ganName} on {datasetName}")
  47. print("///////////////////////////////////////////")
  48. print()
  49. data = loadDataset(f"data_input/{datasetName}")
  50. gan = SimpleGan(numOfFeatures=data.data0.shape[1])
  51. random.seed(2021)
  52. shuffler = genShuffler()
  53. exercise = Exercise(shuffleFunction=shuffler, numOfShuffles=5, numOfSlices=5)
  54. exercise.run(gan, data)
  55. exercise.saveResultsTo(f"data_result/{datasetName}-{ganName}.csv")
  56. exercise.saveResultsTo(f"data_result/{ganName}-{datasetName}.csv")
  57. def runExerciseForRepeater(datasetName):
  58. ganName = "Repeater"
  59. print()
  60. print()
  61. print("///////////////////////////////////////////")
  62. print(f"// Running {ganName} on {datasetName}")
  63. print("///////////////////////////////////////////")
  64. print()
  65. data = loadDataset(f"data_input/{datasetName}")
  66. gan = Repeater()
  67. random.seed(2021)
  68. shuffler = genShuffler()
  69. exercise = Exercise(shuffleFunction=shuffler, numOfShuffles=5, numOfSlices=5)
  70. exercise.run(gan, data)
  71. exercise.saveResultsTo(f"data_result/{datasetName}-{ganName}.csv")
  72. exercise.saveResultsTo(f"data_result/{ganName}-{datasetName}.csv")
  73. def runExerciseForSpheredNoise(datasetName):
  74. ganName = "SpheredNoise"
  75. print()
  76. print()
  77. print("///////////////////////////////////////////")
  78. print(f"// Running {ganName} on {datasetName}")
  79. print("///////////////////////////////////////////")
  80. print()
  81. data = loadDataset(f"data_input/{datasetName}")
  82. gan = SpheredNoise()
  83. random.seed(2021)
  84. shuffler = genShuffler()
  85. exercise = Exercise(shuffleFunction=shuffler, numOfShuffles=5, numOfSlices=5)
  86. exercise.run(gan, data)
  87. exercise.saveResultsTo(f"data_result/{datasetName}-{ganName}.csv")
  88. exercise.saveResultsTo(f"data_result/{ganName}-{datasetName}.csv")
  89. testSets = [
  90. "folding_abalone_17_vs_7_8_9_10",
  91. "folding_abalone9-18",
  92. "folding_car_good",
  93. "folding_car-vgood",
  94. "folding_flare-F",
  95. "folding_hypothyroid",
  96. "folding_kddcup-guess_passwd_vs_satan",
  97. "folding_kr-vs-k-three_vs_eleven",
  98. "folding_kr-vs-k-zero-one_vs_draw",
  99. "folding_shuttle-2_vs_5",
  100. "folding_winequality-red-4",
  101. "folding_yeast4",
  102. "folding_yeast5",
  103. "folding_yeast6"
  104. ]
  105. def runAllTestSets(dataSetList):
  106. for dsFileName in dataSetList:
  107. runExerciseForSimpleGAN(dataSetList)
  108. runExerciseForRepeater(dataSetList)