clustering.py 766 B

12345678910111213141516171819202122
  1. import numpy as np
  2. from sklearn.cluster import AgglomerativeClustering
  3. from fdc.visualize import plotCluster
  def aglo_clustering(number_of_clusters, affinity, linkage,
                      five_d_embedding, two_d_embedding,
                      visual=False):
      """Cluster an embedding agglomeratively and label a second embedding.

      The model is fitted on ``five_d_embedding``; the resulting labels are
      written into ``two_d_embedding`` as a new ``'Cluster'`` column, so the
      caller's object is modified in place.

      Args:
          number_of_clusters: Forwarded as ``n_clusters`` to
              ``AgglomerativeClustering``.
          affinity: Distance metric for the linkage computation. Forwarded as
              ``metric`` on scikit-learn releases that accept it, otherwise
              as the legacy ``affinity`` keyword.
          linkage: Linkage criterion, forwarded unchanged.
          five_d_embedding: Data the clustering is fitted on.
          two_d_embedding: Container that receives the ``'Cluster'`` column
              and is handed to ``plotCluster``. Presumably a pandas
              DataFrame with ``UMAP_0`` / ``UMAP_1`` columns and one row per
              sample of ``five_d_embedding`` -- TODO confirm against callers.
          visual: When truthy, draw the clusters with ``plotCluster``.

      Returns:
          A tuple ``(labels, counts)``: ``labels`` is the ``'Cluster'`` column
          as a Python list; ``counts`` is a NumPy array holding the number of
          samples per cluster, ordered by ascending cluster label.
      """
      # Kept for backward compatibility: this reseeds NumPy's *global* RNG on
      # every call. NOTE(review): AgglomerativeClustering takes no
      # random_state, so the seed probably only matters to callers relying on
      # the global RNG state afterwards -- confirm before removing.
      np.random.seed(42)

      # scikit-learn 1.2 renamed ``affinity`` to ``metric`` and 1.4 removed the
      # old keyword. An unknown constructor keyword raises TypeError, which is
      # used here to fall back to the legacy spelling on older installs.
      try:
          ag_cluster = AgglomerativeClustering(
              n_clusters=number_of_clusters,
              metric=affinity,
              linkage=linkage,
          )
      except TypeError:
          ag_cluster = AgglomerativeClustering(
              n_clusters=number_of_clusters,
              affinity=affinity,
              linkage=linkage,
          )

      clusters = ag_cluster.fit_predict(five_d_embedding)

      # Only the per-cluster sizes are needed; the unique labels themselves
      # are discarded.
      _, counts = np.unique(clusters, return_counts=True)

      # In-place side effect: attaches the labels to the caller's object.
      two_d_embedding['Cluster'] = clusters

      if visual:
          plotCluster(two_d_embedding, clusterName="Cluster", xName="UMAP_0", yName="UMAP_1", stroke=3)

      return two_d_embedding['Cluster'].to_list(), counts