
Moved python functions #1374

Merged
merged 1 commit on Mar 18, 2024
33 changes: 27 additions & 6 deletions pgml-extension/src/bindings/mod.rs
@@ -3,10 +3,31 @@
 use std::fmt::Debug;
 use anyhow::{anyhow, Result};
 #[allow(unused_imports)] // used for test macros
 use pgrx::*;
-use pyo3::{PyResult, Python};
+use pyo3::{pyfunction, PyResult, Python};
 
 use crate::orm::*;
 
+#[pyfunction]
+fn r_insert_logs(project_id: i64, model_id: i64, logs: String) -> PyResult<String> {
+    let id_value = Spi::get_one_with_args::<i64>(
+        "INSERT INTO pgml.logs (project_id, model_id, logs) VALUES ($1, $2, $3::JSONB) RETURNING id;",
+        vec![
+            (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()),
+            (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()),
+            (PgBuiltInOids::TEXTOID.oid(), logs.into_datum()),
+        ],
+    )
+    .unwrap()
+    .unwrap();
+    Ok(format!("Inserted logs with id: {}", id_value))
+}
+
+#[pyfunction]
+fn r_print_info(info: String) -> PyResult<String> {
+    info!("{}", info);
+    Ok(info)
+}
+
 #[cfg(feature = "python")]
 #[macro_export]
@@ -16,11 +37,11 @@ macro_rules! create_pymodule {
         pyo3::Python::with_gil(|py| -> anyhow::Result<pyo3::Py<pyo3::types::PyModule>> {
             use $crate::bindings::TracebackError;
             let src = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), $pyfile));
-            Ok(
-                pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__")
-                    .format_traceback(py)?
-                    .into(),
-            )
+            let module = pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__")
+                .format_traceback(py)?;
+            module.add_function(wrap_pyfunction!($crate::bindings::r_insert_logs, module)?)?;
+            module.add_function(wrap_pyfunction!($crate::bindings::r_print_info, module)?)?;
+            Ok(module.into())
         })
     });
 };
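The macro change works because `PyModule::from_code` hands back the freshly built module, so binding it to a local before returning leaves room to inject Rust `#[pyfunction]`s into the module's namespace, where the embedded Python source can then call them as ordinary globals. A minimal self-contained sketch of that pattern, assuming pyo3's 0.20-era GIL-bound API; the `greet` function and `example.py` name are made up for illustration:

```rust
use pyo3::prelude::*;
use pyo3::types::PyModule;

#[pyfunction]
fn r_print_info(info: String) -> PyResult<String> {
    // Stands in for the extension's version, which raises a Postgres INFO.
    println!("{}", info);
    Ok(info)
}

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        // Hypothetical embedded source; in the extension this is transformers.py.
        let src = "def greet():\n    return r_print_info('hello from embedded python')";
        let module = PyModule::from_code(py, src, "example.py", "__main__")?;
        // Binding the module to a local (instead of returning it immediately,
        // as the old macro body did) is what makes this registration possible.
        module.add_function(wrap_pyfunction!(r_print_info, module)?)?;
        // The Rust function is now an ordinary name in the module's globals.
        let out: String = module.getattr("greet")?.call0()?.extract()?;
        assert_eq!(out, "hello from embedded python");
        Ok(())
    })
}
```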
14 changes: 6 additions & 8 deletions pgml-extension/src/bindings/transformers/transformers.py
@@ -55,7 +55,6 @@
 from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
 from trl.trainer import ConstantLengthDataset
 from peft import LoraConfig, get_peft_model
-from pypgrx import print_info, insert_logs
 from abc import abstractmethod
 
 transformers.logging.set_verbosity_info()
@@ -1017,8 +1016,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
             logs["step"] = state.global_step
             logs["max_steps"] = state.max_steps
             logs["timestamp"] = str(datetime.now())
-            print_info(json.dumps(logs, indent=4))
-            insert_logs(self.project_id, self.model_id, json.dumps(logs))
+            r_print_info(json.dumps(logs, indent=4))
 
 
 class FineTuningBase:
@@ -1100,9 +1098,9 @@ def print_number_of_trainable_model_parameters(self, model):
                 trainable_model_params += param.numel()
 
         # Calculate and print the number and percentage of trainable parameters
-        print_info(f"Trainable model parameters: {trainable_model_params}")
-        print_info(f"All model parameters: {all_model_params}")
-        print_info(
+        r_print_info(f"Trainable model parameters: {trainable_model_params}")
+        r_print_info(f"All model parameters: {all_model_params}")
+        r_print_info(
             f"Percentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"
         )
 
@@ -1398,7 +1396,7 @@ def __init__(
                 "bias": "none",
                 "task_type": "CAUSAL_LM",
             }
-            print_info(
+            r_print_info(
                 "LoRA configuration are not set. Using default parameters"
                 + json.dumps(self.lora_config_params)
             )
@@ -1465,7 +1463,7 @@ def formatting_prompts_func(example):
             peft_config=LoraConfig(**self.lora_config_params),
             callbacks=[PGMLCallback(self.project_id, self.model_id)],
         )
-        print_info("Creating Supervised Fine Tuning trainer done. Training ... ")
+        r_print_info("Creating Supervised Fine Tuning trainer done. Training ... ")
 
         # Train
         self.trainer.train()
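With the functions registered by `create_pymodule!`, the Python side drops its `from pypgrx import print_info, insert_logs` line and simply calls the `r_`-prefixed names. A self-contained sketch of the round trip under the same pyo3 assumption, with Postgres SPI swapped for an in-memory table so it runs outside the extension; the `on_log` helper here paraphrases `PGMLCallback.on_log` above:

```rust
use std::sync::Mutex;

use pyo3::prelude::*;
use pyo3::types::PyModule;

// Stand-in for the pgml.logs table so the sketch runs without Postgres.
static LOGS: Mutex<Vec<(i64, i64, String)>> = Mutex::new(Vec::new());

#[pyfunction]
fn r_print_info(info: String) -> PyResult<String> {
    println!("INFO: {}", info); // the real function raises a Postgres INFO message
    Ok(info)
}

#[pyfunction]
fn r_insert_logs(project_id: i64, model_id: i64, logs: String) -> PyResult<String> {
    // The real function INSERTs into pgml.logs via SPI and returns the row id;
    // here the Vec length plays the part of the id.
    let mut table = LOGS.lock().unwrap();
    table.push((project_id, model_id, logs));
    Ok(format!("Inserted logs with id: {}", table.len()))
}

fn main() -> PyResult<()> {
    // Hypothetical callback body modeled on PGMLCallback.on_log above.
    let src = r#"
import json
from datetime import datetime

def on_log(project_id, model_id, loss):
    logs = {"loss": loss, "timestamp": str(datetime.now())}
    r_print_info(json.dumps(logs, indent=4))
    return r_insert_logs(project_id, model_id, json.dumps(logs))
"#;
    Python::with_gil(|py| {
        let module = PyModule::from_code(py, src, "callback.py", "__main__")?;
        module.add_function(wrap_pyfunction!(r_print_info, module)?)?;
        module.add_function(wrap_pyfunction!(r_insert_logs, module)?)?;
        let msg: String = module.getattr("on_log")?.call1((1i64, 7i64, 0.42f64))?.extract()?;
        assert_eq!(msg, "Inserted logs with id: 1");
        Ok(())
    })
}
```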