mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 12:54:50 +01:00
feat: remove sklearn k means
This commit is contained in:
parent
f648f618de
commit
fc689383d4
@ -1,16 +1,10 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.cluster import KMeans
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"
|
os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"
|
||||||
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=" + cuda_dir
|
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=" + cuda_dir
|
||||||
|
|
||||||
def k_means(data, k=2, tol=1e-6):
|
|
||||||
kmeans = KMeans(n_clusters=k, tol=tol)
|
|
||||||
labels = kmeans.fit_predict(data)
|
|
||||||
return labels
|
|
||||||
|
|
||||||
def initiate_model(model_file_path):
|
def initiate_model(model_file_path):
|
||||||
print("AI: Model loaded from: " + model_file_path, flush=True)
|
print("AI: Model loaded from: " + model_file_path, flush=True)
|
||||||
model = tf.keras.models.load_model(model_file_path)
|
model = tf.keras.models.load_model(model_file_path)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user