import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import KMeans
from sklearn.cluster import DBSCAN
from fdc.visualize import plotCluster


def _make_agglomerative(n_clusters, affinity, linkage):
    """Build an AgglomerativeClustering estimator across scikit-learn versions.

    scikit-learn 1.2 renamed the ``affinity`` argument to ``metric`` and 1.4
    removed ``affinity`` entirely, so passing ``affinity=`` raises TypeError on
    current releases. Try the new keyword first and fall back to the old one
    on releases that predate ``metric``.
    """
    try:
        return AgglomerativeClustering(
            n_clusters=n_clusters, metric=affinity, linkage=linkage
        )
    except TypeError:
        # scikit-learn < 1.2 does not accept the ``metric`` keyword.
        return AgglomerativeClustering(
            n_clusters=n_clusters, affinity=affinity, linkage=linkage
        )


def _attach_clusters(frame, clusters, visual):
    """Store labels in ``frame['Cluster']``, optionally plot, and summarise.

    Args:
        frame: DataFrame that receives the ``'Cluster'`` column (mutated in
            place). When ``visual`` is true, the plot is asked to use its
            ``'UMAP_0'``/``'UMAP_1'`` columns as axes.
        clusters: Array of cluster labels, one per row of ``frame``.
        visual: If true, draw the clustering with ``plotCluster``.

    Returns:
        ``(labels, counts)``: ``labels`` is the ``'Cluster'`` column as a list.
        ``counts`` holds the size of each distinct label, ordered by ascending
        label value (the order ``np.unique`` returns).
    """
    _, counts = np.unique(clusters, return_counts=True)
    # Deliberate in-place mutation: callers receive the labelled frame back.
    frame["Cluster"] = clusters
    if visual:
        plotCluster(
            frame, clusterName="Cluster", xName="UMAP_0", yName="UMAP_1", stroke=3
        )
    return frame["Cluster"].to_list(), counts


def aglo_clustering(number_of_clusters, affinity, linkage,
                    five_d_embedding, two_d_embedding, visual=False):
    """Run agglomerative clustering and attach the labels to ``two_d_embedding``.

    Args:
        number_of_clusters: Number of clusters to form.
        affinity: Distance metric name, passed to scikit-learn as ``metric``
            (or as ``affinity`` on releases older than 1.2).
        linkage: Linkage criterion passed to AgglomerativeClustering.
        five_d_embedding: Feature matrix that is actually clustered.
        two_d_embedding: DataFrame that receives a ``'Cluster'`` column
            (mutated in place). When ``visual`` is true it is plotted using its
            ``'UMAP_0'``/``'UMAP_1'`` columns.
        visual: If true, plot the result.

    Returns:
        ``(labels, counts)``: ``labels`` is the cluster label of each row as a
        list. ``counts`` is an array of cluster sizes ordered by label.
    """
    # Reseeds NumPy's global RNG. AgglomerativeClustering takes no
    # random_state, so this does not influence the labels here. It is kept
    # only because it is an observable side effect existing callers may rely on.
    np.random.seed(42)
    ag_cluster = _make_agglomerative(number_of_clusters, affinity, linkage)
    clusters = ag_cluster.fit_predict(five_d_embedding)
    return _attach_clusters(two_d_embedding, clusters, visual)


class Clustering:
    """Cluster a high-dimensional embedding and label a low-dimensional one.

    Each method fits on ``high_dim`` and writes the resulting labels into the
    ``'Cluster'`` column of ``low_dim``, which is mutated in place. When
    ``visual`` is true the result is plotted using ``low_dim``'s
    ``'UMAP_0'``/``'UMAP_1'`` columns. Every method returns
    ``(labels_list, counts)``, where ``counts`` is ordered by ascending label.
    """

    def __init__(self, high_dim, low_dim, visual):
        """
        Args:
            high_dim: Feature matrix passed to each estimator's ``fit_predict``.
            low_dim: DataFrame with one row per sample. It receives the
                ``'Cluster'`` column.
            visual: If true, each method plots its result.
        """
        self.high_dim = high_dim
        self.low_dim = low_dim
        self.visual = visual

    def Agglomerative(self, number_of_clusters, affinity, linkage):
        """Run agglomerative clustering on ``high_dim``.

        Args:
            number_of_clusters: Number of clusters to form.
            affinity: Distance metric name, forwarded to scikit-learn as
                ``metric`` (or ``affinity`` on releases older than 1.2).
            linkage: Linkage criterion.

        Returns:
            ``(labels, counts)`` as described on the class.
        """
        # Parameters are recorded on the instance for later inspection.
        self.number_of_clusters = number_of_clusters
        self.affinity = affinity
        self.linkage = linkage
        ag_cluster = _make_agglomerative(number_of_clusters, affinity, linkage)
        clusters = ag_cluster.fit_predict(self.high_dim)
        return _attach_clusters(self.low_dim, clusters, self.visual)

    def DBSCAN(self, eps, min_samples):
        """Run DBSCAN on ``high_dim``.

        Noise points get label ``-1``. That label appears in both the returned
        labels and ``counts``.

        Args:
            eps: Neighbourhood radius.
            min_samples: Minimum neighbourhood size for a core point.

        Returns:
            ``(labels, counts)`` as described on the class.
        """
        self.eps = eps
        self.min_samples = min_samples
        # Despite this method's name, ``DBSCAN`` below is the module-global
        # sklearn class: name lookup inside a method skips the class scope.
        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        clusters = dbscan.fit_predict(self.high_dim)
        return _attach_clusters(self.low_dim, clusters, self.visual)

    def K_means(self, no_of_clusters, random_state=None):
        """Run k-means on ``high_dim``.

        Args:
            no_of_clusters: Number of clusters.
            random_state: Optional seed forwarded to KMeans for reproducible
                results. The default ``None`` keeps the previous behaviour.

        Returns:
            ``(labels, counts)`` as described on the class.
        """
        self.no_of_clusters = no_of_clusters
        kmeans = KMeans(n_clusters=no_of_clusters, random_state=random_state)
        clusters = kmeans.fit_predict(self.high_dim)
        return _attach_clusters(self.low_dim, clusters, self.visual)