add Python_Keras_set_weights function (dimeinsons in vector_to_numpy wrong?)

This commit is contained in:
Hannes Signer 2025-01-09 20:59:03 +01:00
parent ef77d755ab
commit 4169cf8a20
2 changed files with 76 additions and 2 deletions

View File

@ -8,6 +8,7 @@
#include <Python.h>
#include <Rmath.h>
#include <condition_variable>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
@ -908,8 +909,6 @@ void naa_training(EigenModel *Eigen_model, EigenModel *Eigen_model_reactive,
// update model weights with received weights
EigenModel deserializedModel =
deserializeModelWeights(serializedModel, modelSize);
fprintf(stdout, "After deserialization: %f\n",
deserializedModel.weight_matrices[0](0, 0));
Eigen_model_mutex->lock();
@ -917,7 +916,18 @@ void naa_training(EigenModel *Eigen_model, EigenModel *Eigen_model_reactive,
Eigen_model->biases = deserializedModel.biases;
Eigen_model_mutex->unlock();
std::vector<std::vector<std::vector<double>>> cpp_weights =
Python_Keras_get_weights(model_name);
fprintf(stdout, "size of cpp weights: %zu\n", cpp_weights.size());
for(size_t i = 0; i<cpp_weights.size(); i++){
fprintf(stdout, "size of cpp weights: %zu\n", cpp_weights[i].size());
fprintf(stdout, "size of cpp weights: %zu\n", cpp_weights[i][0].size());
}
Python_keras_set_weights(model_name, cpp_weights);
// for (int i = 0; i < Eigen_model->weight_matrices[0].rows(); i++) {
@ -1126,6 +1136,66 @@ Python_Keras_get_weights(std::string model_name) {
return cpp_weights;
}
int Python_keras_set_weights(std::string model_name, std::vector<std::vector<std::vector<double>>> weights){
// Acquire the Python GIL
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* py_weights = PyList_New(weights.size());
for(size_t i = 0; i < weights.size(); i++){
PyObject* numpy_array = vector_to_numpy_array(weights[i]);
PyList_SetItem(py_weights, i, numpy_array);
}
// Iterate over py_weights and print the shape of each numpy array
for (Py_ssize_t i = 0; i < PyList_Size(py_weights); ++i) {
PyObject* numpy_array = PyList_GetItem(py_weights, i);
// Use numpy's shape attribute to get the shape
PyObject* shape = PyObject_GetAttrString(numpy_array, "shape");
PyObject* shape_str = PyObject_Repr(shape); // Get a string representation of the shape
PyObject* shape_utf8 = PyUnicode_AsEncodedString(shape_str, "utf-8", "strict");
const char* shape_bytes = PyBytes_AS_STRING(shape_utf8);
// Print the shape
std::cout << "Shape of numpy array at index " << i << ": " << shape_bytes << std::endl;
// Clean up
Py_DECREF(shape);
Py_DECREF(shape_str);
Py_DECREF(shape_utf8);
}
fprintf(stdout, "In Python_Keras_set_weights\n");
PyObject *py_main_module = PyImport_AddModule("__main__");
PyObject *py_global_dict = PyModule_GetDict(py_main_module);
PyObject *py_keras_model =
PyDict_GetItemString(py_global_dict, model_name.c_str());
PyObject *py_set_weights_function =
PyDict_GetItemString(py_global_dict, "set_weights");
PyObject *args = Py_BuildValue("(OO)", py_keras_model, py_weights);
PyObject *py_set_weights = PyObject_CallObject(py_set_weights_function, args);
if (!py_set_weights) {
PyErr_Print(); // Gibt den Python-Fehler aus
std::cerr << "Error: Failed to call set_weights function." << std::endl;
}
Py_XDECREF(py_weights);
Py_DECREF(args);
Py_XDECREF(py_set_weights);
PyGILState_Release(gstate);
return 0;
}
/**
* @brief Joins the training thread and winds down the Python environment
* gracefully

View File

@ -89,6 +89,8 @@ std::vector<double> Eigen_predict_clustered(const EigenModel& model, const Eigen
std::vector<double> Eigen_predict(const EigenModel& model, std::vector<std::vector<double>> x, int batch_size,
std::mutex* Eigen_model_mutex);
int Python_keras_set_weights(std::string model_name, std::vector<std::vector<std::vector<double>>> weights);
// Otherwise, define the necessary stubs
#else
inline void Python_Keras_setup(std::string, std::string){}
@ -110,6 +112,8 @@ inline std::vector<double> Eigen_predict_clustered(const EigenModel&, const Eige
std::mutex*, std::vector<int>&){return {};}
inline std::vector<double> Eigen_predict(const EigenModel&, std::vector<std::vector<double>>, int,
std::mutex*){return {};}
inline int Python_keras_set_weights(std::string model_name, std::vector<std::vector<std::vector<double>>> weights);
#endif
} // namespace poet