mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 20:38:23 +01:00
add Python_Keras_set_weights function (dimeinsons in vector_to_numpy wrong?)
This commit is contained in:
parent
ef77d755ab
commit
4169cf8a20
@ -8,6 +8,7 @@
|
||||
#include <Python.h>
|
||||
#include <Rmath.h>
|
||||
#include <condition_variable>
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
@ -908,8 +909,6 @@ void naa_training(EigenModel *Eigen_model, EigenModel *Eigen_model_reactive,
|
||||
// update model weights with received weights
|
||||
EigenModel deserializedModel =
|
||||
deserializeModelWeights(serializedModel, modelSize);
|
||||
fprintf(stdout, "After deserialization: %f\n",
|
||||
deserializedModel.weight_matrices[0](0, 0));
|
||||
|
||||
Eigen_model_mutex->lock();
|
||||
|
||||
@ -917,7 +916,18 @@ void naa_training(EigenModel *Eigen_model, EigenModel *Eigen_model_reactive,
|
||||
Eigen_model->biases = deserializedModel.biases;
|
||||
|
||||
Eigen_model_mutex->unlock();
|
||||
|
||||
|
||||
std::vector<std::vector<std::vector<double>>> cpp_weights =
|
||||
Python_Keras_get_weights(model_name);
|
||||
fprintf(stdout, "size of cpp weights: %zu\n", cpp_weights.size());
|
||||
for(size_t i = 0; i<cpp_weights.size(); i++){
|
||||
fprintf(stdout, "size of cpp weights: %zu\n", cpp_weights[i].size());
|
||||
fprintf(stdout, "size of cpp weights: %zu\n", cpp_weights[i][0].size());
|
||||
}
|
||||
|
||||
Python_keras_set_weights(model_name, cpp_weights);
|
||||
|
||||
|
||||
|
||||
// for (int i = 0; i < Eigen_model->weight_matrices[0].rows(); i++) {
|
||||
@ -1126,6 +1136,66 @@ Python_Keras_get_weights(std::string model_name) {
|
||||
return cpp_weights;
|
||||
}
|
||||
|
||||
|
||||
int Python_keras_set_weights(std::string model_name, std::vector<std::vector<std::vector<double>>> weights){
|
||||
// Acquire the Python GIL
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
|
||||
PyObject* py_weights = PyList_New(weights.size());
|
||||
|
||||
for(size_t i = 0; i < weights.size(); i++){
|
||||
PyObject* numpy_array = vector_to_numpy_array(weights[i]);
|
||||
PyList_SetItem(py_weights, i, numpy_array);
|
||||
}
|
||||
|
||||
// Iterate over py_weights and print the shape of each numpy array
|
||||
for (Py_ssize_t i = 0; i < PyList_Size(py_weights); ++i) {
|
||||
PyObject* numpy_array = PyList_GetItem(py_weights, i);
|
||||
|
||||
// Use numpy's shape attribute to get the shape
|
||||
PyObject* shape = PyObject_GetAttrString(numpy_array, "shape");
|
||||
PyObject* shape_str = PyObject_Repr(shape); // Get a string representation of the shape
|
||||
PyObject* shape_utf8 = PyUnicode_AsEncodedString(shape_str, "utf-8", "strict");
|
||||
const char* shape_bytes = PyBytes_AS_STRING(shape_utf8);
|
||||
|
||||
// Print the shape
|
||||
std::cout << "Shape of numpy array at index " << i << ": " << shape_bytes << std::endl;
|
||||
|
||||
// Clean up
|
||||
Py_DECREF(shape);
|
||||
Py_DECREF(shape_str);
|
||||
Py_DECREF(shape_utf8);
|
||||
}
|
||||
|
||||
fprintf(stdout, "In Python_Keras_set_weights\n");
|
||||
|
||||
PyObject *py_main_module = PyImport_AddModule("__main__");
|
||||
PyObject *py_global_dict = PyModule_GetDict(py_main_module);
|
||||
PyObject *py_keras_model =
|
||||
PyDict_GetItemString(py_global_dict, model_name.c_str());
|
||||
PyObject *py_set_weights_function =
|
||||
PyDict_GetItemString(py_global_dict, "set_weights");
|
||||
PyObject *args = Py_BuildValue("(OO)", py_keras_model, py_weights);
|
||||
|
||||
|
||||
PyObject *py_set_weights = PyObject_CallObject(py_set_weights_function, args);
|
||||
if (!py_set_weights) {
|
||||
PyErr_Print(); // Gibt den Python-Fehler aus
|
||||
std::cerr << "Error: Failed to call set_weights function." << std::endl;
|
||||
}
|
||||
|
||||
Py_XDECREF(py_weights);
|
||||
Py_DECREF(args);
|
||||
Py_XDECREF(py_set_weights);
|
||||
|
||||
PyGILState_Release(gstate);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* @brief Joins the training thread and winds down the Python environment
|
||||
* gracefully
|
||||
|
||||
@ -89,6 +89,8 @@ std::vector<double> Eigen_predict_clustered(const EigenModel& model, const Eigen
|
||||
std::vector<double> Eigen_predict(const EigenModel& model, std::vector<std::vector<double>> x, int batch_size,
|
||||
std::mutex* Eigen_model_mutex);
|
||||
|
||||
int Python_keras_set_weights(std::string model_name, std::vector<std::vector<std::vector<double>>> weights);
|
||||
|
||||
// Otherwise, define the necessary stubs
|
||||
#else
|
||||
inline void Python_Keras_setup(std::string, std::string){}
|
||||
@ -110,6 +112,8 @@ inline std::vector<double> Eigen_predict_clustered(const EigenModel&, const Eige
|
||||
std::mutex*, std::vector<int>&){return {};}
|
||||
inline std::vector<double> Eigen_predict(const EigenModel&, std::vector<std::vector<double>>, int,
|
||||
std::mutex*){return {};}
|
||||
|
||||
inline int Python_keras_set_weights(std::string model_name, std::vector<std::vector<std::vector<double>>> weights);
|
||||
#endif
|
||||
} // namespace poet
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user