mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 04:48:23 +01:00
fix: roll back to functioning state
This commit is contained in:
parent
c4da86ef98
commit
110d5c810b
@ -16,7 +16,13 @@ def initiate_model(model_file_path):
|
|||||||
model = tf.keras.models.load_model(model_file_path)
|
model = tf.keras.models.load_model(model_file_path)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def prediction_step(model, model_reactive, x, cluster_labels, batch_size):
|
def prediction_step(model, model_reactive, x, cluster_labels, batch_size)
|
||||||
|
# Catch input size mismatches
|
||||||
|
model_input_shape = model.input_shape[1:]
|
||||||
|
if model_input_shape != x.shape[1:]:
|
||||||
|
print(f"Input data size {x.shape[1:]} does not match model input size {model_input_shape}",
|
||||||
|
flush=True)
|
||||||
|
|
||||||
# Predict separately if clustering is used
|
# Predict separately if clustering is used
|
||||||
if cluster_labels:
|
if cluster_labels:
|
||||||
cluster_labels = np.asarray(cluster_labels, dtype=bool)
|
cluster_labels = np.asarray(cluster_labels, dtype=bool)
|
||||||
|
|||||||
@ -344,7 +344,8 @@ std::vector<double> Eigen_predict_clustered(const EigenModel& model, const Eigen
|
|||||||
if (num_features != model.weight_matrices[0].cols() ||
|
if (num_features != model.weight_matrices[0].cols() ||
|
||||||
num_features != model_reactive.weight_matrices[0].cols()) {
|
num_features != model_reactive.weight_matrices[0].cols()) {
|
||||||
throw std::runtime_error("Input data size " + std::to_string(num_features) +
|
throw std::runtime_error("Input data size " + std::to_string(num_features) +
|
||||||
" does not match model input layer sizes");
|
" does not match model input layer sizes" + std::to_string(model.weight_matrices[0].cols()) +
|
||||||
|
" / " + std::to_string(model_reactive.weight_matrices[0].cols()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert input data to Eigen matrix
|
// Convert input data to Eigen matrix
|
||||||
|
|||||||
@ -161,6 +161,15 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
|||||||
mpi_buffer.begin() + this->prop_count * (wp_i + 1));
|
mpi_buffer.begin() + this->prop_count * (wp_i + 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (this->ai_surrogate_enabled) {
|
||||||
|
// Map valid predictions from the ai surrogate in the workpackage
|
||||||
|
for (int i = 0; i < s_curr_wp.size; i++) {
|
||||||
|
if (this->ai_surrogate_validity_vector[wp_start_index + i] == 1) {
|
||||||
|
s_curr_wp.mapping[i] = CHEM_AISURR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// std::cout << this->comm_rank << ":" << counter++ << std::endl;
|
// std::cout << this->comm_rank << ":" << counter++ << std::endl;
|
||||||
if (dht_enabled || interp_enabled) {
|
if (dht_enabled || interp_enabled) {
|
||||||
dht->prepareKeys(s_curr_wp.input, dt);
|
dht->prepareKeys(s_curr_wp.input, dt);
|
||||||
@ -178,15 +187,6 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
|||||||
interp->tryInterpolation(s_curr_wp);
|
interp->tryInterpolation(s_curr_wp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this->ai_surrogate_enabled) {
|
|
||||||
// Map valid predictions from the ai surrogate in the workpackage
|
|
||||||
for (int i = 0; i < s_curr_wp.size; i++) {
|
|
||||||
if (this->ai_surrogate_validity_vector[wp_start_index + i] == 1) {
|
|
||||||
s_curr_wp.mapping[i] = CHEM_AISURR;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
phreeqc_time_start = MPI_Wtime();
|
phreeqc_time_start = MPI_Wtime();
|
||||||
|
|
||||||
|
|||||||
12
src/poet.cpp
12
src/poet.cpp
@ -457,9 +457,14 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
if (params.use_ai_surrogate && !params.disable_training) {
|
if (params.use_ai_surrogate && !params.disable_training) {
|
||||||
// Add values for which the predictions were invalid
|
// Add values for which the predictions were invalid
|
||||||
// to training data buffer
|
// to training data buffer
|
||||||
MSG("AI: Add invalid predictions to training data buffer");
|
MSG("AI: Add to training data buffer");
|
||||||
std::vector<std::vector<double>> invalid_x =
|
std::vector<std::vector<double>> invalid_x =
|
||||||
R.parseEval("get_invalid_values(predictors_scaled, validity_vector)");
|
R.parseEval("get_invalid_values(predictors_scaled, validity_vector)");
|
||||||
|
|
||||||
|
if (!params.train_only_invalid) {
|
||||||
|
// Use all values if not specified otherwise
|
||||||
|
R.parseEval("validity_vector[] <- 0");
|
||||||
|
}
|
||||||
|
|
||||||
R.parseEval("target_scaled <- preprocess(state_C[ai_surrogate_species])");
|
R.parseEval("target_scaled <- preprocess(state_C[ai_surrogate_species])");
|
||||||
std::vector<std::vector<double>> invalid_y =
|
std::vector<std::vector<double>> invalid_y =
|
||||||
@ -675,6 +680,8 @@ int main(int argc, char *argv[]) {
|
|||||||
/* Use dht species for model input and output */
|
/* Use dht species for model input and output */
|
||||||
R["ai_surrogate_species"] =
|
R["ai_surrogate_species"] =
|
||||||
init_list.getChemistryInit().dht_species.getNames();
|
init_list.getChemistryInit().dht_species.getNames();
|
||||||
|
|
||||||
|
// TODO REMOVE!!
|
||||||
R.parseEval("ai_surrogate_species <- ai_surrogate_species[ai_surrogate_species != \"Charge\"]");
|
R.parseEval("ai_surrogate_species <- ai_surrogate_species[ai_surrogate_species != \"Charge\"]");
|
||||||
const std::string ai_surrogate_input_script =
|
const std::string ai_surrogate_input_script =
|
||||||
init_list.getChemistryInit().ai_surrogate_input_script;
|
init_list.getChemistryInit().ai_surrogate_input_script;
|
||||||
@ -709,6 +716,9 @@ R.parseEval("ai_surrogate_species <- ai_surrogate_species[ai_surrogate_species !
|
|||||||
if (Rcpp::as<bool>(R.parseEval("exists(\"disable_training\")"))) {
|
if (Rcpp::as<bool>(R.parseEval("exists(\"disable_training\")"))) {
|
||||||
run_params.disable_training = R["disable_training"];
|
run_params.disable_training = R["disable_training"];
|
||||||
}
|
}
|
||||||
|
if (Rcpp::as<bool>(R.parseEval("exists(\"train_only_invalid\")"))) {
|
||||||
|
run_params.train_only_invalid = R["train_only_invalid"];
|
||||||
|
}
|
||||||
if (Rcpp::as<bool>(R.parseEval("exists(\"save_model_path\")"))) {
|
if (Rcpp::as<bool>(R.parseEval("exists(\"save_model_path\")"))) {
|
||||||
run_params.save_model_path = Rcpp::as<std::string>(R["save_model_path"]);
|
run_params.save_model_path = Rcpp::as<std::string>(R["save_model_path"]);
|
||||||
MSG("AI: Model will be saved as \"" + run_params.save_model_path + "\"");
|
MSG("AI: Model will be saved as \"" + run_params.save_model_path + "\"");
|
||||||
|
|||||||
@ -74,6 +74,7 @@ struct RuntimeParameters {
|
|||||||
bool disable_training; // Can be set in the R input script
|
bool disable_training; // Can be set in the R input script
|
||||||
bool use_k_means_clustering; // Can be set in the R input script
|
bool use_k_means_clustering; // Can be set in the R input script
|
||||||
bool use_Keras_predictions; // Can be set in the R input script
|
bool use_Keras_predictions; // Can be set in the R input script
|
||||||
|
bool train_only_invalid; // Can be set in the R input script
|
||||||
int batch_size; // Can be set in the R input script
|
int batch_size; // Can be set in the R input script
|
||||||
int training_epochs; // Can be set in the R input script
|
int training_epochs; // Can be set in the R input script
|
||||||
int training_data_size; // Can be set in the R input script
|
int training_data_size; // Can be set in the R input script
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user