# %%
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors
from PIL import Image, ImageDraw, ImageFont
from library.analysis import testSets, generators

# %% [markdown]
# # Constants

# %%
# Dictionary keys under which the two scores of a classifier are stored.
f1Score = "f1 score"
kScore = "cohens kappa score"

# %% [markdown]
# # Settings

# %%
# Data set names that the (currently disabled) ProWRAS import skips.
ignoreSet = ["ozone_level", "yeast_me2"]

# Generator names; "SimpleGAN" is shortened to "GAN".
gans = [name.replace("SimpleGAN", "GAN") for name in generators.keys()]

# Short names of the evaluated classifiers.
algs = ["LR", "RF", "GB", "KNN", "DoG"]

# %%
# Keep only the test sets whose name starts with "folding".
testSets = [name for name in testSets if name.startswith("folding")]

# %% [markdown]
# # ProWRAS Data

# %%
# Row labels belonging to the ProWRAS score lists in the following markdown cell.
dataset = [
    "abalone9-18",
    "abalone_17_vs_7_8_9_10",
    "car-vgood", "car_good",
    "flare-F",
    "hypothyroid",
    "kddcup-guess_passwd_vs_satan",
    "kr-vs-k-three_vs_eleven",
    "kr-vs-k-zero-one_vs_draw",
    "shuttle-2_vs_5",
    "winequality-red-4",
    "yeast4",
    "yeast5",
    "yeast6",
    "ozone_level",
    "yeast_me2",
    "Average",
]
# %% [markdown]
# knn_ProWRAS_f1 = [0.384,0.347,0.818,0.641,0.301,0.553,1.0,0.94,0.9,1.0,0.141,0.308,0.714,0.545,0.556,0.339,0.538]
# knn_ProWRAS_k = [0.35,0.328,0.81,0.622,0.263,0.528,1.0,0.938,0.896,1.0,0.093,0.268,0.704,0.531,0.526,0.305,0.515]
#
# lr_ProWRAS_f1 = [0.488,0.315,0.407,0.103,0.341,0.446,0.99,0.928,0.853,1.0,0.158,0.308,0.591,0.326,0.347,0.295,0.472]
# lr_ProWRAS_k = [0.446,0.287,0.371,0.033,0.3,0.407,0.99,0.926,0.847,1.0,0.119,0.268,0.574,0.3,0.319,0.254,0.441]
#
# gb_ProWRAS_f1 = [0.385,0.335,0.959,0.863,0.320,0.803,0.998,0.995,0.969,1.0,0.156,0.335,0.735,0.514,0.329,0.225,0.600]
# gb_ProWRAS_k = [0.341,0.310,0.957,0.857,0.291,0.794,0.998,0.995,0.967,1.0,0.115,0.303,0.726,0.501,0.303,0.328,0.589]

# %%
# statistic[generator][data set][classifier] -> { f1Score: ..., kScore: ... }
statistic = { }

# %% [markdown]
# statistic = { "ProWRAS": {} }
# for (n, f1, k) in zip(dataset, lr_ProWRAS_f1, lr_ProWRAS_k):
#     if n in ignoreSet:
#         continue
#
#     if n not in statistic["ProWRAS"]:
#         statistic["ProWRAS"][n] = {}
#
#     statistic["ProWRAS"][n]["LR"] = { kScore: k, f1Score: f1 }
#
# for (n, f1, k) in zip(dataset, gb_ProWRAS_f1, gb_ProWRAS_k):
#     if n in ignoreSet:
#         continue
#
#     if n not in statistic["ProWRAS"]:
#         statistic["ProWRAS"][n] = {}
#
#     statistic["ProWRAS"][n]["GB"] = { kScore: k, f1Score: f1 }
#
#
# for (n, f1, k) in zip(dataset, knn_ProWRAS_f1, knn_ProWRAS_k):
#     if n in ignoreSet:
#         continue
#
#     if n not in statistic["ProWRAS"]:
#         statistic["ProWRAS"][n] = {}
#
#     statistic["ProWRAS"][n]["KNN"] = { kScore: k, f1Score: f1 }
#
# dataset = list(filter(lambda n: n not in ignoreSet, dataset))

# %% [markdown]
# # Load data from CSV files

# %%
def cleanupName(name):
    """Return ``name`` without the markers "folding_", "imblearn_" and "kaggle_"."""
    for marker in ("folding_", "imblearn_", "kaggle_"):
        name = name.replace(marker, "")
    return name

# %%
# Display names of all test sets, in the order of ``testSets``.
dataset = [cleanupName(d) for d in testSets]

# %%
def loadDiagnoseData(ganType, datasetName):
    """Read the averaged scores from ``data_result/<ganType>/<datasetName>.csv``.

    The file is read as a sequence of blocks: the first line of a block is
    the block name (a block called "GAN" is stored as "DoG"), a line "---"
    ends the block, and every other line is split at ";". Only lines whose
    first field is "avg" are used; fields 5 and 6 are taken as f1 score and
    kappa score.

    Returns a dict ``{ blockName: { f1Score: float, kScore: float } }``.
    A missing file is reported with a printed message and yields ``{}``.
    """
    fileName = f"data_result/{ganType}/{datasetName}.csv"
    result = {}
    try:
        with open(fileName) as resultFile:
            expectName = True
            blockName = ""
            for line in resultFile:
                line = line.strip()
                if expectName:
                    blockName = "DoG" if line == "GAN" else line
                    expectName = False
                elif line == "---":
                    expectName = True
                else:
                    parts = line.split(";")
                    if parts[0] == "avg":
                        result[blockName] = { f1Score: float(parts[5]), kScore: float(parts[6]) }
    except FileNotFoundError:
        print(f"Missing file: {fileName}")
    return result

# %%
for gan in gans:
    if gan not in statistic:
        statistic[gan] = {}

    for ds in testSets:
        if ds != "Average":
            d = cleanupName(ds)
            statistic[gan][d] = loadDiagnoseData(gan, ds)

            if d not in dataset:
                dataset.append(d)

# %%
def averageScores(resultsByDataset, datasetNames):
    """Average f1 and kappa score per classifier over ``datasetNames``.

    ``resultsByDataset`` maps a data set name to the dict returned by
    ``loadDiagnoseData``. The sums are divided by the number of data sets
    (the name "Average" is skipped), not by the number of available
    results, so a missing result counts as 0.0. A data set that is absent
    from ``resultsByDataset`` is treated like one without results, and an
    empty name list yields 0.0 instead of a division by zero.
    """
    totals = { a: { f1Score: 0.0, kScore: 0.0 } for a in algs }
    count = 0

    for ds in datasetNames:
        if ds == "Average":
            continue
        count += 1
        results = resultsByDataset.get(ds, {})
        for a in algs:
            if a in results:
                totals[a][f1Score] += results[a][f1Score]
                totals[a][kScore] += results[a][kScore]

    divisor = max(count, 1)
    return {
        a: { f1Score: totals[a][f1Score] / divisor, kScore: totals[a][kScore] / divisor }
        for a in algs
    }

# %%
for gan in statistic.keys():
    statistic[gan]["Average"] = averageScores(statistic[gan], dataset)
# %% [markdown]
# # Show Statistics

# %%
def mix(v, a, b):
    """Blend two channel values: ``v`` parts of ``a`` plus ``1 - v`` parts of ``b``.

    The result is truncated to an int and clamped to 0..255.
    """
    return max(0, min(255, int((v * a) + ((1.0 - v) * b))))


def mixPixel(v, a, b):
    """Blend the first three channels of the pixels ``a`` and ``b`` with ``mix``."""
    return (mix(v, a[0], b[0]), mix(v, a[1], b[1]), mix(v, a[2], b[2]))


def drawTransparentRect(img, rect, color, opacity=1.0):
    """Blend ``color`` over the pixels of ``rect`` in ``img`` (modified in place).

    ``rect`` is ``((x0, y0), (x1, y1))`` with exclusive lower-right corner.
    ``img`` must offer ``size``, ``getpixel`` and ``putpixel``. The
    rectangle is clipped to the image, so a rectangle that sticks out no
    longer touches coordinates outside of the image.
    """
    ((x0, y0), (x1, y1)) = rect
    (imgWidth, imgHeight) = img.size

    x0, y0 = max(0, x0), max(0, y0)
    x1, y1 = min(imgWidth, x1), min(imgHeight, y1)

    for y in range(y0, y1):
        for x in range(x0, x1):
            pos = (x, y)
            img.putpixel(pos, mixPixel(opacity, color, img.getpixel(pos)))

# %%
def _textSize(font, text):
    """Return ``(width, height)`` of ``text`` as ``font.getsize`` used to report it.

    ``getsize`` was removed in Pillow 10, so ``getbbox`` is used where it
    exists. The bottom edge is used as height because the old method
    included the vertical offset in the height.
    """
    if hasattr(font, "getbbox"):
        (left, _top, right, bottom) = font.getbbox(text)
        return (right - left, bottom)
    return font.getsize(text)


def drawDiagram(size, rowNames, data, colNames=[], colors=None, border=20, barIndent=10, fontSize=20, markers=[0.25, 0.5, 0.75, 1.00]):
    """Render a grouped bar chart and return it as a PIL image.

    The chart is drawn sideways on a square canvas and then rotated by
    90 degrees: the row names are drawn before the rotation, the marker
    values and the legend afterwards. The bars of one row overlap; they
    are drawn from the largest to the smallest value, each one a bit
    narrower than the previous one.

    Parameters:
        size      (width, height) of the returned image.
        rowNames  one label per entry of ``data``.
        data      non-empty list of rows; each row holds one value per
                  column. A value of 1.0 fills the whole bar length.
        colNames  legend labels; no legend is drawn when empty.
        colors    RGB tuples per column; defaults to a built-in palette.
                  The palette is reused cyclically when there are more
                  columns than colors.
        border    padding in pixels.
        barIndent kept for compatibility only; the indent is always
                  derived from the bar size.
        fontSize  size of the "FreeSans" font, which must be installed.
        markers   values at which the silver guide lines are drawn.

    Raises:
        ValueError if ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("drawDiagram needs at least one data row")

    silver = (204, 204, 204)
    black = (0, 0, 0)
    white = (255, 255, 255)

    defaultColors = [ (31,119,180)
                    , (255,127,14)
                    , (44,160,44)
                    , (214,40,40)
                    , (148,103,189)
                    , (140,86,75)
                    , (227,119,194)
                    , (127,127,127)
                    , (40,40,214)
                    ]

    if colors is None:
        colors = defaultColors

    font = ImageFont.truetype("FreeSans", fontSize)

    # Room needed for the widest marker text.
    markerSize = 0
    for m in markers:
        markerSize = max(markerSize, _textSize(font, f"{m:0.2f}")[0])

    areaTop = 2 * border + markerSize

    barStep = (size[0] - border - areaTop) // len(data)
    barSize = max(border, barStep - border)
    # Every following (smaller) bar of a row is narrowed by this amount.
    barIndent = barSize / (2 + len(data[0]))

    # Create new Image
    w = max(size[0], size[1])
    img = Image.new("RGB", (w, w))
    d = ImageDraw.Draw(img)

    # Set background to white.
    d.rectangle(((0, 0), (w, w)), fill=white)

    # draw row names
    height = size[1]
    left = w - height
    textSize = 0
    for (n, name) in enumerate(rowNames):
        s = _textSize(font, name)
        offset = int(border + barSize - s[1] + 1.5) // 2
        textSize = max(textSize, s[0])
        pos = (left + border, areaTop + offset + (barStep * n))
        d.text(pos, name, fill=black, font=font)

    # Calculate sizes for bar drawing.
    barLength = height - (4 * border) - textSize
    areaSize = (barLength, barSize)
    areaLeft = left + (2 * border) + textSize

    # Draw Lines for bar height comparing.
    markerPos = [areaLeft + int(v * barLength) for v in markers]
    for linePos in markerPos:
        d.line(((linePos, border), (linePos, size[0] - border)), fill=silver)

    # Draw bars.
    for (rowNr, row) in enumerate(data):
        area = ((areaLeft, areaTop + (rowNr * barStep) + (border // 2)), areaSize)

        # Largest value first, so that smaller bars stay visible on top.
        indices = list(range(len(row)))
        indices.sort(key=lambda i: 1.0 - row[i])

        for (rank, i) in enumerate(indices):
            v = row[i]
            c = colors[i % len(colors)]
            offset = barIndent * rank
            tl = (area[0][0], area[0][1] + offset)
            br = (tl[0] + int(v * area[1][0]), area[1][1] + tl[1] - offset)
            d.rectangle((tl, br), fill=c, outline=black)

    # Draw axis.
    d.line(((areaLeft, areaTop), (areaLeft, size[0] - border)), fill=black)
    d.line(((areaLeft, areaTop), (w - border, areaTop)), fill=black)

    # Draw y-axis text.
    img = img.rotate(90)
    d = ImageDraw.Draw(img)

    for (m, linePos) in zip(markers, markerPos):
        d.text((border, size[1] - (linePos - left)), f"{m:0.2f}", fill=black, font=font)

    # Draw legend.
    if len(colNames) > 0:
        colNameWidth = 0
        colNameHeight = fontSize * len(colNames)
        for c in colNames:
            colNameWidth = max(colNameWidth, _textSize(font, c)[0])

        rWidth = int(fontSize * 0.75)
        rHeight = fontSize // 2
        rPadd = (fontSize - rHeight) // 2

        tl = (size[0] - (int(2.5 * border) + colNameWidth + rWidth), 0)
        br = (size[0], int(1.2 * border) + colNameHeight)
        drawTransparentRect(img, (tl, br), white, 0.75)

        for (n, c) in enumerate(colNames):
            t = border + (fontSize * n)
            l = size[0] - border - colNameWidth
            d.rectangle(((l - border - rWidth, t + rPadd - 1), (l - border, t + rPadd + rHeight)), fill=colors[n % len(colors)], outline=black)
            d.text((l, t), c, fill=black, font=font)

    return img.crop((0, 0, size[0], size[1]))
# %%
def showDiagnose(algo, score):
    """Plot ``score`` of classifier ``algo`` for every test set and generator.

    One chart row is drawn per entry of ``testSets`` with one bar per entry
    of ``gans``. Missing results are reported with a printed message and
    drawn as 0.0. The image is stored as
    ``data_result/statistics/byAlgorithm/statistic-<algo>-<score>.png``.
    """
    def valueOf(d, g):
        d = cleanupName(d)
        byDataset = statistic[g]
        if d not in byDataset:
            print(f"Missing '{d}' in '{g}'")
            return 0.0

        if algo not in byDataset[d]:
            print(f"Missing '{algo}' in ('{g}', '{d}')")
            return 0.0

        return byDataset[d][algo][score]

    print(f"{algo}: {score}")

    data = []
    for d in testSets:
        data.append([valueOf(d, g) for g in gans])

    rowNames = [cleanupName(d) for d in testSets]
    img = drawDiagram((1024, 1024), rowNames, data, colNames=gans)
    img.save(f"data_result/statistics/byAlgorithm/statistic-{algo}-{score}.png")

# %%
def showDiagnoseAverage(score, onlyOneBar=False):
    """Plot the averaged ``score`` of every classifier, one bar per generator.

    Classifiers without an average entry are drawn as 0.0. ``onlyOneBar``
    is accepted but not used. The image is stored as
    ``data_result/statistics/average/statistic-Algo-Average-<score>.png``.
    """
    def valueOf(g, algo):
        averages = statistic[g]["Average"]
        if algo not in averages:
            return 0.0
        return averages[algo][score]

    print(f"Average: {score}")

    data = []
    for algo in algs:
        data.append([valueOf(g, algo) for g in gans])

    img = drawDiagram((1024, 1024), algs, data, colNames=gans)
    img.save(f"data_result/statistics/average/statistic-Algo-Average-{score}.png")
# %%
def showDiagnoseDataset(dataset):
    """Plot both scores of one data set: one chart row per generator, one bar per classifier.

    ``dataset`` is the cleaned data set name (or "Average"). Missing
    values are drawn as 0.0. One image per score is stored as
    ``data_result/statistics/byDataset/statistic-Classifier-<dataset>-<score>.png``.
    """
    print(f"{dataset}")

    def valueOf(algo, score, g):
        scores = statistic[g].get(dataset, {}).get(algo, {})
        if score in scores:
            return scores[score]
        return 0.0

    for score in [f1Score, kScore]:
        data = [[valueOf(algo, score, g) for algo in algs] for g in gans]
        img = drawDiagram((1024, 1024), gans, data, colNames=algs)
        img.save(f"data_result/statistics/byDataset/statistic-Classifier-{dataset}-{score}.png")

# %%
# The plot functions read the global ``gans``. For the charts it is narrowed
# to the generators that are not a "convGeN" variant plus
# "convGeN-majority-full". The ``finally`` branch makes sure the full list is
# restored even if one of the plots fails, so that later cells never see the
# narrowed list.
gans = [g for g in statistic.keys() if not g.startswith("convGeN") or g == "convGeN-majority-full"]
try:
    for a in algs:
        showDiagnose(a, f1Score)
        showDiagnose(a, kScore)

    showDiagnoseAverage(f1Score)
    showDiagnoseAverage(kScore)

    for t in testSets:
        showDiagnoseDataset(cleanupName(t))

    showDiagnoseDataset("Average")
finally:
    gans = list(statistic.keys())

# %%
def getValueOf(gan, dataset, algo, score):
    """Return ``statistic[gan][dataset][algo][score]`` or ``None`` if any level below ``gan`` is missing."""
    byAlgo = statistic[gan].get(dataset)
    if byAlgo is None:
        return None

    scores = byAlgo.get(algo)
    if scores is None:
        return None

    return scores.get(score)
# %%
def calcTable(algo, score, ignore=()):
    """Count for every pair of generators how often the first is at least as good as the second.

    Returns a square list of lists: entry ``[i][j]`` is the number of test
    sets (names in ``ignore`` are skipped) for which both ``gans[i]`` and
    ``gans[j]`` have a value of ``score`` for classifier ``algo`` and the
    value of ``gans[i]`` is greater than or equal to that of ``gans[j]``.
    """
    names = [cleanupName(d) for d in testSets]
    names = [d for d in names if d not in ignore]

    # Look every value up once instead of once per generator pair.
    values = {
        g: [getValueOf(g, d, algo, score) for d in names]
        for g in gans
    }

    def calc(gc, g):
        n = 0
        for (vc, v) in zip(values[gc], values[g]):
            if vc is not None and v is not None and vc >= v:
                n += 1
        return n

    return [[calc(gc, g) for g in gans] for gc in gans]

# %%
tables = {}
ignore = [# "webpage"
          #, "mammography"
          #, "protein_homo"
          #, "ozone_level"
          #, "creditcard"
         ]
for a in algs:
    tables[a + " - " + f1Score] = calcTable(a, f1Score, ignore)
    tables[a + " - " + kScore] = calcTable(a, kScore, ignore)

tables[algs[0] + " - " + f1Score]

# %%
# Color map with 512 entries: red to yellow, followed by yellow to blue.
cmap = matplotlib.colors.ListedColormap([
    (1.0, x / 255.0, 0.0)
    for x in range(256)
    ] + [
    ((255 - x) / 255.0, (255 - x) / 255.0, x / 255.0)
    for x in range(256)
    ])

#cmap.set_extremes(bad=cmap(0.0), under=cmap(0.0), over=cmap(1.0))

for title in tables.keys():
    print(title)
    labels = list(gans)
    counts = tables[title]
    if title[0:3] == "DoG":
        # For the "DoG" tables only the last four generators are shown.
        labels = labels[-4:]
        counts = [r[-4:] for r in counts[-4:]]
        fig = plt.figure(figsize=(5, 4))
        fig.add_axes([0.4, 0.45, 0.6, 0.5])
    else:
        fig = plt.figure(figsize=(7, 6))
        fig.add_axes([0.27, 0.25, 0.7, 0.74])
    heatmap = plt.imshow(counts, cmap=cmap)
    plt.colorbar(heatmap)
    plt.xticks(range(len(labels)), labels, rotation="vertical")
    plt.yticks(range(len(labels)), labels)
    plt.savefig(f"data_result/statistics/successCount/statistic-{title}.pdf")
    plt.show()
    # Release the figure explicitly; not every backend does that in show().
    plt.close(fig)
# %%
class Table:
    """A simple text table that can be printed as plain text or as LaTeX.

    Attributes (also modified directly by code outside of the class):
        heading  list of column titles (strings).
        sizes    width of each column in characters.
        rows     list of rows; each row is a list of strings.
    """

    def __init__(self, heading):
        self.heading = [str(h) for h in heading]
        self.sizes = [len(h) for h in self.heading]
        self.rows = []

    def add(self, row):
        """Append ``row`` (cells are converted with ``str``) and widen the columns if needed.

        A row with fewer cells than columns leaves the widths of the
        remaining columns untouched.
        """
        row = [str(r) for r in row]
        self.rows.append(row)
        self.sizes = [
            max(width, len(row[i])) if i < len(row) else width
            for (i, width) in enumerate(self.sizes)
        ]

    def separator(self):
        """Return the line that separates the heading from the rows."""
        return "|".join(["-" * n for n in self.sizes])

    def showRow(self, row):
        """Return the cells of ``row`` padded to the column widths and joined with "|"."""
        return "|".join([t.ljust(n) for (n, t) in zip(self.sizes, row)])

    def show(self):
        """Print heading, separator and all rows."""
        print(self.showRow(self.heading))
        print(self.separator())
        for row in self.rows:
            print(self.showRow(row))

    def showLatex(self, caption, key):
        """Return the table as LaTeX ``table*`` environment with the given caption and label."""
        columnConfig = "|".join(["l"] + ["@{\\hskip3pt}c@{\\hskip3pt}" for h in self.heading[1:]])

        text = "\\begin{table*}[ht]\\scriptsize"
        text += "\\caption{" + caption + "}\\label{" + key + "}"
        text += "\\centering\\tabularnewline\n"

        text += "\\begin{tabular}{" + columnConfig + "}\\hline\n"
        text += " & ".join(["\\textbf{" + h + "}" for h in self.heading])
        text += "\n\\tabularnewline\n\\hline\n"

        for row in self.rows:
            text += " & ".join(row)
            text += "\n\\tabularnewline\n"

        # Backslashes are escaped: "\h" and "\e" are invalid escape sequences.
        text += "\\hline\\end{tabular}\\end{table*}\n"

        return text

# %%
def tableRow(algo, dataset, myGans):
    """Return one ``(f1 score, kappa score)`` pair per generator in ``myGans``.

    A score is ``None`` when ``getValueOf`` finds no value for it.
    """
    return [
        (getValueOf(gan, dataset, algo, f1Score), getValueOf(gan, dataset, algo, kScore))
        for gan in myGans
    ]
# %%
def p(f, bold=False):
    """Format a score for LaTeX: three decimals, "?" for ``None``, optionally in bold.

    The text is surrounded by one space on each side.
    """
    text = "?" if f is None else f"{f:0.3f}"

    if bold:
        return " \\textbf{" + text + "} "
    return " " + text + " "


def latex(text):
    """Return ``text`` with every "_" and "-" replaced by a space."""
    return text.replace("_", " ").replace("-", " ")


def pairMax(row):
    """Return the column-wise maxima ``(a, b)`` of a sequence of pairs.

    Both maxima start at 0.0. ``None`` entries stand for missing results
    and are skipped instead of being compared with a number.
    """
    a = 0.0
    b = 0.0

    for (x, y) in row:
        if x is not None:
            a = max(a, x)
        if y is not None:
            b = max(b, y)

    return (a, b)
# %%
# Write one LaTeX table per classifier to data_result/statistics/Tables.tex.
# A table holds at most ``maxColumns`` columns (name column included); the
# remaining generator columns are moved to a second table.
maxColumns = 6

with open("data_result/statistics/Tables.tex", "w") as latexFile:
    for algo in algs:
        latexFile.write("\n")
        latexFile.write("% ### " + algo + "\n")
        latexFile.write("\n")
        heading = ["dataset ($f_1~$score$~/~\\kappa~$score)"]

        # For "DoG" only the last four generators are listed.
        myGans = gans
        if algo[0:3] == "DoG":
            myGans = list(gans)[-4:]

        for g in myGans:
            heading.append(latex(g))
        table = Table(heading)

        avg = [[0.0, 0.0] for h in heading[1:]]
        mx = [[0.0, 0.0] for h in heading[1:]]
        cnt = 0

        for d in testSets:
            d = cleanupName(d)
            if d not in ignore:
                cnt += 1
                row = tableRow(algo, d, myGans)
                line = [latex(d)]

                # Missing scores (None) take part in the maximum as 0.0.
                m = pairMax([(v or 0.0, w or 0.0) for (v, w) in row])

                for (n, r) in enumerate(row):
                    line.append(f"{p(r[0], r[0] == m[0])} / {p(r[1], r[1] == m[1])}")
                    avg[n][0] += r[0] or 0.0
                    avg[n][1] += r[1] or 0.0
                    mx[n][0] = max(mx[n][0], r[0] or 0.0)
                    mx[n][1] = max(mx[n][1], r[1] or 0.0)
                table.add(line)

        m = pairMax(avg)
        # Without any data set the averages stay 0.0 instead of dividing by zero.
        divisor = max(cnt, 1)
        table.add(["\\hline Average"] + [f"{p(a / divisor, a == m[0])} / {p(b / divisor, b == m[1])}" for (a,b) in avg])
        #table.add(["maximum"] + [f"{p(a)} / {p(b)}" for (a,b) in mx])

        # Split only if there is at least one column left for the second table.
        tableB = None
        if len(table.heading) > maxColumns:
            heading = [table.heading[0]] + table.heading[maxColumns:]
            tableB = Table(heading)
            tableB.sizes = [table.sizes[0]] + table.sizes[maxColumns:]
            tableB.rows = [
                [r[0]] + r[maxColumns:]
                for r in table.rows
                ]
            table.heading = table.heading[0:maxColumns]
            table.sizes = table.sizes[0:maxColumns]
            table.rows = [
                r[0:maxColumns]
                for r in table.rows
                ]

        if tableB is not None:
            latexFile.write(table.showLatex(algo + " (1)", "tab:results:" + algo + ":A") + "\n")
            latexFile.write("\n")
            latexFile.write(tableB.showLatex(algo + " (2)", "tab:results:" + algo + ":B") + "\n")
        else:
            latexFile.write(table.showLatex(algo, "tab:results:" + algo + ":A") + "\n")

# %%
# Write the scores of every classifier to <algo>-f1.csv and <algo>-kappa.csv.
# Only generators that are not a "convGeN" variant, plus
# "convGeN-majority-full", get a column. Missing scores are written as "None".
for algo in algs:
    print("% ### " + algo)
    heading = ["dataset"]
    for g in gans:
        if not g.startswith("convGeN") or g == "convGeN-majority-full":
            heading.append(g)
    table = []

    avg = [[0.0, 0.0] for h in heading[1:]]
    cnt = 0

    for d in testSets:
        d = cleanupName(d)
        if d not in ignore:
            cnt += 1
            row = tableRow(algo, d, heading[1:])
            table.append([(d,d)] + row)

            for (n, r) in enumerate(row):
                avg[n][0] += r[0] or 0.0
                avg[n][1] += r[1] or 0.0

    # Without any data set the averages stay 0.0 instead of dividing by zero.
    divisor = max(cnt, 1)
    table.append([("Average", "Average")] + [(a / divisor, b / divisor) for (a,b) in avg])

    # Every cell is a pair: index 0 holds the f1 score, index 1 the kappa score.
    for (suffix, index) in (("f1", 0), ("kappa", 1)):
        with open(f"data_result/statistics/{algo}-{suffix}.csv", "w") as f:
            f.write((",".join(heading)) + "\n")
            for row in table:
                f.write((",".join([str(x[index]) for x in row])) + "\n")
# %%
def dsWeight(d, score):
    """Sum ``score`` of data set ``d`` over all generators and classifiers.

    Combinations without a value are skipped.
    """
    total = 0.0
    for g in gans:
        for a in algs:
            value = getValueOf(g, d, a, score)
            if value is None:
                continue
            total += value
    return total

# %%
# List the data sets ordered by the sum of all their f1 and kappa scores,
# lowest sum first.
dataNames = sorted(
    [cleanupName(d) for d in testSets],
    key=lambda d: dsWeight(d, f1Score) + dsWeight(d, kScore),
)
for d in dataNames:
    w = dsWeight(d, f1Score) + dsWeight(d, kScore)
    print(f"{w:0.3f} - {d}")