Skip to content

Prepare pgml for publishing to crates.io #1500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pgml-sdks/pgml/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions pgml-sdks/pgml/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ name = "pgml"
crate-type = ["lib", "cdylib"]

[dependencies]
rust_bridge = {path = "../rust-bridge/rust-bridge", version = "0.1.0"}
# rust_bridge = {path = "../rust-bridge/rust-bridge", version = "0.1.0", optional = true }
rust_bridge = {git = "https://github.com/postgresml/postgresml", version = "0.1.0", optional = true }
sqlx = { version = "0.7.3", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid"] }
serde_json = "1.0.9"
anyhow = "1.0.9"
Expand Down Expand Up @@ -50,6 +51,7 @@ serde_with = "3.8.1"

[features]
default = []
python = ["dep:pyo3", "dep:pyo3-asyncio"]
javascript = ["dep:neon"]
c = []
rust_bridge = ["dep:rust_bridge"]
python = ["rust_bridge", "dep:pyo3", "dep:pyo3-asyncio"]
javascript = ["rust_bridge", "dep:neon"]
c = ["rust_bridge"]
22 changes: 14 additions & 8 deletions pgml-sdks/pgml/src/builtins.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
use anyhow::Context;
use rust_bridge::{alias, alias_methods};
use sqlx::Row;
use tracing::instrument;

/// Provides access to builtin database methods
#[derive(alias, Debug, Clone)]
pub struct Builtins {
database_url: Option<String>,
}

use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json};

#[cfg(feature = "rust_bridge")]
use rust_bridge::{alias, alias_methods};

#[cfg(feature = "python")]
use crate::{query_runner::QueryRunnerPython, types::JsonPython};

#[cfg(feature = "c")]
use crate::{languages::c::JsonC, query_runner::QueryRunnerC};

#[alias_methods(new, query, transform, embed, embed_batch)]
/// Provides access to builtin database methods
#[cfg_attr(feature = "rust_bridge", derive(alias))]
#[derive(Debug, Clone)]
pub struct Builtins {
database_url: Option<String>,
}

#[cfg_attr(
feature = "rust_bridge",
alias_methods(new, query, transform, embed, embed_batch)
)]
impl Builtins {
pub fn new(database_url: Option<String>) -> Self {
Self { database_url }
Expand Down
123 changes: 96 additions & 27 deletions pgml-sdks/pgml/src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use anyhow::Context;
use indicatif::MultiProgress;
use itertools::Itertools;
use regex::Regex;
use rust_bridge::{alias, alias_methods};
use sea_query::Alias;
use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query};
use sea_query_binder::SqlxBinder;
Expand Down Expand Up @@ -35,6 +34,12 @@ use crate::{
utils,
};

#[cfg(feature = "rust_bridge")]
use rust_bridge::{alias, alias_methods};

#[cfg(feature = "c")]
use crate::languages::c::GeneralJsonAsyncIteratorC;

#[cfg(feature = "python")]
use crate::{
pipeline::PipelinePython,
Expand All @@ -43,7 +48,7 @@ use crate::{
};

/// A RAGStream Struct
#[derive(alias)]
#[cfg_attr(feature = "rust_bridge", derive(alias))]
#[allow(dead_code)]
pub struct RAGStream {
general_json_async_iterator: Option<GeneralJsonAsyncIterator>,
Expand All @@ -57,7 +62,7 @@ impl Clone for RAGStream {
}
}

#[alias_methods(stream, sources)]
#[cfg_attr(feature = "rust_bridge", alias_methods(stream, sources))]
impl RAGStream {
pub fn stream(&mut self) -> anyhow::Result<GeneralJsonAsyncIterator> {
self.general_json_async_iterator
Expand Down Expand Up @@ -140,7 +145,8 @@ pub(crate) struct CollectionDatabaseData {
}

/// A collection of documents
#[derive(alias, Debug, Clone)]
#[cfg_attr(feature = "rust_bridge", derive(alias))]
#[derive(Debug, Clone)]
pub struct Collection {
pub(crate) name: String,
pub(crate) database_url: Option<String>,
Expand All @@ -149,29 +155,32 @@ pub struct Collection {
pub(crate) database_data: Option<CollectionDatabaseData>,
}

#[alias_methods(
new,
upsert_documents,
get_documents,
delete_documents,
get_pipelines,
get_pipeline,
add_pipeline,
remove_pipeline,
enable_pipeline,
disable_pipeline,
search,
add_search_event,
vector_search,
query,
rag,
rag_stream,
exists,
archive,
upsert_directory,
upsert_file,
generate_er_diagram,
get_pipeline_status
#[cfg_attr(
feature = "rust_bridge",
alias_methods(
new,
upsert_documents,
get_documents,
delete_documents,
get_pipelines,
get_pipeline,
add_pipeline,
remove_pipeline,
enable_pipeline,
disable_pipeline,
search,
add_search_event,
vector_search,
query,
rag,
rag_stream,
exists,
archive,
upsert_directory,
upsert_file,
generate_er_diagram,
get_pipeline_status
)
)]
impl Collection {
/// Creates a new [Collection]
Expand Down Expand Up @@ -1128,6 +1137,65 @@ impl Collection {
.collect())
}

/// Performs rag on the [Collection]
///
/// # Arguments
/// * `query` - The query to search for
/// * `pipeline` - The [Pipeline] to use for the search
///
/// # Example
/// ```
/// use pgml::Collection;
/// use pgml::Pipeline;
/// use serde_json::json;
/// use anyhow::Result;
/// async fn run() -> anyhow::Result<()> {
/// let mut collection = Collection::new("my_collection", None)?;
/// let mut pipeline = Pipeline::new("my_pipeline", None)?;
/// let results = collection.rag(json!({
/// "CONTEXT": {
/// "vector_search": {
/// "query": {
/// "fields": {
/// "body": {
/// "query": "Test document: 2",
/// "parameters": {
/// "prompt": "query: "
/// }
/// },
/// },
/// },
/// "document": {
/// "keys": [
/// "id"
/// ]
/// },
/// "limit": 2
/// },
/// "aggregate": {
/// "join": "\n"
/// }
/// },
/// "CUSTOM": {
/// "sql": "SELECT 'test'"
/// },
/// "chat": {
/// "model": "meta-llama/Meta-Llama-3-8B-Instruct",
/// "messages": [
/// {
/// "role": "system",
/// "content": "You are a friendly and helpful chatbot"
/// },
/// {
/// "role": "user",
/// "content": "Some text with {CONTEXT} - {CUSTOM}",
/// }
/// ],
/// "max_tokens": 10
/// }
/// }).into(), &mut pipeline).await?;
/// Ok(())
/// }
/// ```
#[instrument(skip(self))]
pub async fn rag(&self, query: Json, pipeline: &mut Pipeline) -> anyhow::Result<Json> {
let pool = get_or_initialize_pool(&self.database_url).await?;
Expand All @@ -1138,6 +1206,7 @@ impl Collection {
Ok(std::mem::take(&mut results[0].0))
}

/// Same as rag but returns a stream of results
#[instrument(skip(self))]
pub async fn rag_stream(
&self,
Expand Down
9 changes: 6 additions & 3 deletions pgml-sdks/pgml/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use rust_bridge::{alias, alias_methods};
use sqlx::{Pool, Postgres};
use tracing::instrument;

Expand All @@ -14,6 +13,9 @@ use crate::types::JsonPython;
#[cfg(feature = "c")]
use crate::languages::c::JsonC;

#[cfg(feature = "rust_bridge")]
use rust_bridge::{alias, alias_methods};

/// A few notes on the following enums:
/// - Sqlx does provide type derivation for enums, but it's not very good
/// - Queries using these enums require a number of additional queries to get their oids and
Expand Down Expand Up @@ -55,7 +57,8 @@ pub(crate) struct ModelDatabaseData {
}

/// A model used for embedding, inference, etc...
#[derive(alias, Debug, Clone)]
#[cfg_attr(feature = "rust_bridge", derive(alias))]
#[derive(Debug, Clone)]
pub struct Model {
pub(crate) name: String,
pub(crate) runtime: ModelRuntime,
Expand All @@ -69,7 +72,7 @@ impl Default for Model {
}
}

#[alias_methods(new, transform)]
#[cfg_attr(feature = "rust_bridge", alias_methods(new, transform))]
impl Model {
/// Creates a new [Model]
pub fn new(name: Option<String>, source: Option<String>, parameters: Option<Json>) -> Self {
Expand Down
22 changes: 14 additions & 8 deletions pgml-sdks/pgml/src/open_source_ai.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use anyhow::Context;
use futures::{Stream, StreamExt};
use rust_bridge::{alias, alias_methods};
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;

Expand All @@ -10,6 +9,9 @@ use crate::{
TransformerPipeline,
};

#[cfg(feature = "rust_bridge")]
use rust_bridge::{alias, alias_methods};

#[cfg(feature = "python")]
use crate::types::{GeneralJsonAsyncIteratorPython, GeneralJsonIteratorPython, JsonPython};

Expand All @@ -20,7 +22,8 @@ use crate::{
};

/// A drop in replacement for OpenAI
#[derive(alias, Debug, Clone)]
#[cfg_attr(feature = "rust_bridge", derive(alias))]
#[derive(Debug, Clone)]
pub struct OpenSourceAI {
database_url: Option<String>,
}
Expand Down Expand Up @@ -166,12 +169,15 @@ impl Iterator for AsyncToSyncJsonIterator {
}
}

#[alias_methods(
new,
chat_completions_create,
chat_completions_create_async,
chat_completions_create_stream,
chat_completions_create_stream_async
#[cfg_attr(
feature = "rust_bridge",
alias_methods(
new,
chat_completions_create,
chat_completions_create_async,
chat_completions_create_stream,
chat_completions_create_stream_async
)
)]
impl OpenSourceAI {
/// Creates a new [OpenSourceAI]
Expand Down
9 changes: 6 additions & 3 deletions pgml-sdks/pgml/src/pipeline.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use anyhow::Context;
use rust_bridge::{alias, alias_methods};
use serde::Deserialize;
use serde_json::json;
use sqlx::{Executor, PgConnection, Pool, Postgres, Transaction};
Expand All @@ -16,6 +15,9 @@ use crate::{
types::{DateTime, Json, TryToNumeric},
};

#[cfg(feature = "rust_bridge")]
use rust_bridge::{alias, alias_methods};

#[cfg(feature = "python")]
use crate::types::JsonPython;

Expand Down Expand Up @@ -179,7 +181,8 @@ pub struct PipelineDatabaseData {
}

/// A pipeline that describes transformations to documents
#[derive(alias, Debug, Clone)]
#[cfg_attr(feature = "rust_bridge", derive(alias))]
#[derive(Debug, Clone)]
pub struct Pipeline {
pub(crate) name: String,
pub(crate) schema: Option<Json>,
Expand All @@ -205,7 +208,7 @@ fn json_to_schema(schema: &Json) -> anyhow::Result<ParsedSchema> {
})
}

#[alias_methods(new)]
#[cfg_attr(feature = "rust_bridge", alias_methods(new))]
impl Pipeline {
/// Creates a [Pipeline]
///
Expand Down
Loading