Przeglądaj źródła

Fixed cluster_wise_F1score

Kristian Schultz 3 lat temu
rodzic
commit
cefa0af4dd
1 zmienionych plików z 61 dodań i 27 usunięć
  1. 61 27
      agglo_5dim_2NN_v3.ipynb

+ 61 - 27
agglo_5dim_2NN_v3.ipynb

@@ -22024,39 +22024,35 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 59,
    "id": "fe9ea10e",
    "metadata": {},
    "outputs": [],
    "source": [
     "def cluster_wise_F1score(ref_list,pred_list):\n",
-    "    F1_score_list=[]\n",
-    "    Geometric_mean_list=[]\n",
-    "    cluster_score_list=[]\n",
+    "    def safeDiv(a, b):\n",
+    "        if b != 0:\n",
+    "            return a / b\n",
+    "        return 0.0\n",
+    "    \n",
+    "    F1_score_list = []\n",
+    "    Geometric_mean_list = []\n",
+    "    cluster_score_list = []\n",
+    "    true_positive_total = 0\n",
     "    for i in np.unique(ref_list):\n",
-    "        indices=[j for j,val in enumerate(ref_list) if val==i]\n",
-    "        true_positive=0\n",
+    "        indices = [j for j,val in enumerate(ref_list) if val == i]\n",
+    "        true_positive = 0\n",
     "        for index in indices:\n",
-    "            if ref_list[index]==pred_list[index]:\n",
-    "                true_positive+=1\n",
-    "            else:\n",
-    "                pass\n",
-    "        if pred_list.count(i)==0:\n",
-    "            precision=0\n",
-    "        else:\n",
-    "            precision=true_positive/pred_list.count(i)\n",
-    "        if ref_list.count(i)==0:\n",
-    "            recall=0\n",
-    "        else:\n",
-    "            recall=true_positive/ref_list.count(i)\n",
-    "        if precision==0 and recall==0:\n",
-    "            F1_score=0\n",
-    "            GM=0\n",
-    "            cluster_score=0\n",
-    "        else:\n",
-    "            F1_score=2*((precision * recall)/(precision + recall))\n",
-    "            GM=np.sqrt(precision * recall)\n",
-    "            cluster_score=recall*100\n",
+    "            if i == pred_list[index]:\n",
+    "                true_positive += 1\n",
+    "        true_positive_total += true_positive\n",
+    "        \n",
+    "        precision = safeDiv(true_positive, pred_list.count(i))\n",
+    "        recall = safeDiv(true_positive, len(indices))\n",
+    "        F1_score = safeDiv(2.0 * precision * recall, precision + recall)\n",
+    "        GM = np.sqrt(precision * recall)\n",
+    "        cluster_score = recall * 100.0\n",
+    "        \n",
     "        print(\"F1_Score of cluster \"+str(i)+\" is {}\".format(F1_score))\n",
     "        print(\"Geometric mean of cluster \"+str(i)+\" is {}\".format(GM))\n",
     "        print(\"Correctly predicted data points in cluster \"+str(i)+\" is {}%\".format(cluster_score))\n",
@@ -22064,9 +22060,12 @@
     "        F1_score_list.append(F1_score)\n",
     "        Geometric_mean_list.append(GM)\n",
     "        cluster_score_list.append(cluster_score)\n",
+    "\n",
+    "    correctly_predicted = safeDiv(100.0 * true_positive_total, len(ref_list))\n",
+    "\n",
     "    print(\"average F1_Score of clusters is {}\".format(np.mean(F1_score_list)))\n",
     "    print(\"average Geometric mean of clusters is {}\".format(np.mean(Geometric_mean_list)))\n",
-    "    print(\"Correctly predicted data points in clusters is {}%\".format(np.mean(cluster_score_list)))"
+    "    print(\"Correctly predicted data points in clusters is {}%\".format(correctly_predicted))"
    ]
   },
   {
@@ -22083,6 +22082,41 @@
     "    colnames.append('c'+str(i+1))"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 60,
+   "id": "1ef4ae81",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "F1_Score of cluster 1 is 0.8571428571428571\n",
+      "Geometric mean of cluster 1 is 0.8660254037844386\n",
+      "Correctly predicted data points in cluster 1 is 75.0%\n",
+      "\n",
+      "\n",
+      "F1_Score of cluster 2 is 1.0\n",
+      "Geometric mean of cluster 2 is 1.0\n",
+      "Correctly predicted data points in cluster 2 is 100.0%\n",
+      "\n",
+      "\n",
+      "F1_Score of cluster 3 is 0.8\n",
+      "Geometric mean of cluster 3 is 0.816496580927726\n",
+      "Correctly predicted data points in cluster 3 is 66.66666666666666%\n",
+      "\n",
+      "\n",
+      "average F1_Score of clusters is 0.8857142857142858\n",
+      "average Geometric mean of clusters is 0.8941739949040549\n",
+      "Correctly predicted data points in clusters is 80.0%\n"
+     ]
+    }
+   ],
+   "source": [
+    "cluster_wise_F1score([1,2,3,1,2,3,1,2,3,1],[1,2,3,1,2,3,1,2,0,0])"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 29,