{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "notebook-overview",
   "metadata": {},
   "source": [
    "# Score comparison: ProWRAS vs. GAN variants\n",
    "\n",
    "Collects F1 and Cohen's kappa scores per dataset and algorithm (`algs`) for the hard-coded ProWRAS results and for every method listed in `gans` (read from `data_result/<method>/folding_<dataset>.csv`), recomputes the per-method averages and plots the scores as grouped bar charts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pretty-performer",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "engaging-warehouse",
   "metadata": {},
   "source": [
    "## Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "crazy-taxation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keys of the two scores stored for every (method, dataset, algorithm).\n",
    "kScore = \"cohens kappa score\"\n",
    "f1Score = \"f1 score\"\n",
    "\n",
    "# Name of the pseudo-dataset that holds the mean over all real datasets.\n",
    "averageKey = \"Average\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "extensive-future",
   "metadata": {},
   "source": [
    "## Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "warming-department",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets that are left out of the statistics, the averages and the plots.\n",
    "ignoreSet = [\"ozone_level\", \"yeast_me2\"]\n",
    "\n",
    "gans = [\"SimpleGAN\", \"Repeater\", \"SpheredNoise\", \"convGAN\"]\n",
    "\n",
    "# A list (not a set) so the algorithms are processed and plotted in the same\n",
    "# order on every run; set iteration order of strings changes between kernels.\n",
    "algs = [\"LR\", \"GB\", \"KNN\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "seasonal-greek",
   "metadata": {},
   "source": [
    "## ProWRAS Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "needed-birmingham",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Complete, unfiltered list of names; the score lists below are aligned with it.\n",
    "# Kept separate from the filtered `dataset` so every cell can be re-run safely.\n",
    "allDatasets = [\n",
    "    \"abalone9-18\",\n",
    "    \"abalone_17_vs_7_8_9_10\",\n",
    "    \"car-vgood\",\n",
    "    \"car_good\",\n",
    "    \"flare-F\",\n",
    "    \"hypothyroid\",\n",
    "    \"kddcup-guess_passwd_vs_satan\",\n",
    "    \"kr-vs-k-three_vs_eleven\",\n",
    "    \"kr-vs-k-zero-one_vs_draw\",\n",
    "    \"shuttle-2_vs_5\",\n",
    "    \"winequality-red-4\",\n",
    "    \"yeast4\",\n",
    "    \"yeast5\",\n",
    "    \"yeast6\",\n",
    "    \"ozone_level\",\n",
    "    \"yeast_me2\",\n",
    "    averageKey\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "brilliant-phoenix",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One value per entry of `allDatasets`, in the same order.\n",
    "knn_ProWRAS_f1 = [0.384,0.347,0.818,0.641,0.301,0.553,1.0,0.94,0.9,1.0,0.141,0.308,0.714,0.545,0.556,0.339,0.538]\n",
    "knn_ProWRAS_k = [0.35,0.328,0.81,0.622,0.263,0.528,1.0,0.938,0.896,1.0,0.093,0.268,0.704,0.531,0.526,0.305,0.515]\n",
    "\n",
    "lr_ProWRAS_f1 = [0.488,0.315,0.407,0.103,0.341,0.446,0.99,0.928,0.853,1.0,0.158,0.308,0.591,0.326,0.347,0.295,0.472]\n",
    "lr_ProWRAS_k = [0.446,0.287,0.371,0.033,0.3,0.407,0.99,0.926,0.847,1.0,0.119,0.268,0.574,0.3,0.319,0.254,0.441]\n",
    "\n",
    "gb_ProWRAS_f1 = [0.385,0.335,0.959,0.863,0.320,0.803,0.998,0.995,0.969,1.0,0.156,0.335,0.735,0.514,0.329,0.225,0.600]\n",
    "gb_ProWRAS_k = [0.341,0.310,0.957,0.857,0.291,0.794,0.998,0.995,0.967,1.0,0.115,0.303,0.726,0.501,0.303,0.328,0.589]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ordered-roman",
   "metadata": {},
   "outputs": [],
   "source": [
    "# algorithm -> (f1 values, kappa values); insertion order LR, GB, KNN.\n",
    "prowrasScores = {\n",
    "    \"LR\": (lr_ProWRAS_f1, lr_ProWRAS_k),\n",
    "    \"GB\": (gb_ProWRAS_f1, gb_ProWRAS_k),\n",
    "    \"KNN\": (knn_ProWRAS_f1, knn_ProWRAS_k),\n",
    "}\n",
    "\n",
    "# statistic[method][dataset][algorithm] = { kScore: ..., f1Score: ... }\n",
    "statistic = { \"ProWRAS\": {} }\n",
    "for alg, (f1Values, kValues) in prowrasScores.items():\n",
    "    # zip() truncates silently, which would shift scores to the wrong dataset.\n",
    "    assert len(f1Values) == len(kValues) == len(allDatasets), \\\n",
    "        f\"{alg}: expected {len(allDatasets)} f1 and kappa values\"\n",
    "\n",
    "    for (n, f1, k) in zip(allDatasets, f1Values, kValues):\n",
    "        if n in ignoreSet:\n",
    "            continue\n",
    "        statistic[\"ProWRAS\"].setdefault(n, {})[alg] = { kScore: k, f1Score: f1 }\n",
    "\n",
    "# Names used by all following cells. Derived from `allDatasets` instead of\n",
    "# overwriting it, so re-running this cell cannot misalign names and scores.\n",
    "dataset = [n for n in allDatasets if n not in ignoreSet]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "selective-connecticut",
   "metadata": {},
   "source": [
    "## Load data from CSV files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "intended-watts",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-based positions of the scores in a \";\"-separated \"avg\" row.\n",
    "f1Column = 5\n",
    "kColumn = 6\n",
    "\n",
    "\n",
    "def loadDiagnoseData(ganType, datasetName, dataDir=\"data_result\"):\n",
    "    \"\"\"Read the averaged scores from one result file.\n",
    "\n",
    "    The file `<dataDir>/<ganType>/folding_<datasetName>.csv` is read as a\n",
    "    sequence of blocks separated by lines containing only `---`. The first\n",
    "    line of a block is used as the key of that block (presumably the\n",
    "    algorithm name, since the averaging cell looks the keys up via `algs`).\n",
    "    Within a block only the row whose first `;`-separated field is `avg`\n",
    "    is used. Blank lines are skipped.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ganType : str\n",
    "        Name of the sub-directory of `dataDir`.\n",
    "    datasetName : str\n",
    "        Dataset part of the file name.\n",
    "    dataDir : str, optional\n",
    "        Directory that contains the per-method sub-directories.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        `{ blockName: { f1Score: float, kScore: float } }`. Empty when the\n",
    "        file does not exist; in that case a message is printed.\n",
    "    \"\"\"\n",
    "    fileName = f\"{dataDir}/{ganType}/folding_{datasetName}.csv\"\n",
    "    r = {}\n",
    "    try:\n",
    "        with open(fileName) as f:\n",
    "            newBlock = True\n",
    "            n = \"\"\n",
    "            for line in f:\n",
    "                line = line.strip()\n",
    "                if not line:\n",
    "                    # A blank line must not be taken as the name of a block.\n",
    "                    continue\n",
    "\n",
    "                if newBlock:\n",
    "                    n = line\n",
    "                    newBlock = False\n",
    "                elif line == \"---\":\n",
    "                    newBlock = True\n",
    "                else:\n",
    "                    parts = line.split(\";\")\n",
    "                    if parts[0] == \"avg\":\n",
    "                        r[n] = { f1Score: float(parts[f1Column]), kScore: float(parts[kColumn]) }\n",
    "    except FileNotFoundError:\n",
    "        # Best effort: a missing result simply yields no scores.\n",
    "        print(f\"Missing file: {fileName}\")\n",
    "    return r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "classical-rescue",
   "metadata": {},
   "outputs": [],
   "source": [
    "for gan in gans:\n",
    "    ganStatistic = statistic.setdefault(gan, {})\n",
    "\n",
    "    for ds in dataset:\n",
    "        # The average is computed in the next cell, there is no file for it.\n",
    "        if ds != averageKey:\n",
    "            ganStatistic[ds] = loadDiagnoseData(gan, ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "unable-entrance",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Re)compute the \"Average\" entry of every method, ProWRAS included, over the\n",
    "# datasets that are not ignored.\n",
    "for gan in statistic:\n",
    "    f1Sum = { n: 0.0 for n in algs }\n",
    "    kSum = { n: 0.0 for n in algs }\n",
    "    # Counted per algorithm: a dataset without a result for an algorithm\n",
    "    # (e.g. a missing CSV file) must not pull that algorithm's mean towards zero.\n",
    "    count = { n: 0 for n in algs }\n",
    "\n",
    "    for ds in dataset:\n",
    "        if ds == averageKey:\n",
    "            continue\n",
    "\n",
    "        for n in algs:\n",
    "            scores = statistic[gan].get(ds, {}).get(n)\n",
    "            if scores is not None:\n",
    "                f1Sum[n] += scores[f1Score]\n",
    "                kSum[n] += scores[kScore]\n",
    "                count[n] += 1\n",
    "\n",
    "    statistic[gan][averageKey] = {\n",
    "        # max(..., 1): an algorithm without any result stays at 0.0 instead of\n",
    "        # raising ZeroDivisionError.\n",
    "        n: { f1Score: f1Sum[n] / max(count[n], 1), kScore: kSum[n] / max(count[n], 1) }\n",
    "        for n in algs\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "public-collins",
   "metadata": {},
   "source": [
    "## Show Statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "combined-courage",
   "metadata": {},
   "outputs": [],
   "source": [
    "def showDiagnose(algo, score):\n",
    "    \"\"\"Draw a grouped bar chart of one score for one algorithm.\n",
    "\n",
    "    There is one group of bars per entry of `dataset` and one bar per\n",
    "    method in `statistic`. Combinations without a result are drawn with\n",
    "    height 0.0.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    algo : str\n",
    "        Algorithm key, e.g. an entry of `algs`.\n",
    "    score : str\n",
    "        Score key, `f1Score` or `kScore`.\n",
    "    \"\"\"\n",
    "    methods = list(statistic.keys())\n",
    "    w = 0.8 / len(methods)\n",
    "    positions = np.arange(len(dataset))\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(20, 8))\n",
    "    for i, g in enumerate(methods):\n",
    "        values = [\n",
    "            (statistic[g][d][algo][score] if algo in statistic[g][d] else 0.0)\n",
    "            for d in dataset\n",
    "        ]\n",
    "        # Every method is shifted by one bar width inside the group.\n",
    "        ax.bar(positions + i * w, values, w, label=g)\n",
    "\n",
    "    ax.set_title(f\"{algo}: {score}\")\n",
    "    ax.set_xlabel(\"Dataset\")\n",
    "    ax.set_ylabel(score)\n",
    "    # Put each tick under the middle of its group of bars.\n",
    "    ax.set_xticks(positions + w * (len(methods) - 1) / 2)\n",
    "    ax.set_xticklabels(dataset, rotation=\"vertical\")\n",
    "    ax.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "familiar-private",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for a in algs:\n",
    "    showDiagnose(a, f1Score)\n",
    "    showDiagnose(a, kScore)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}