diff --git a/Cargo.lock b/Cargo.lock index 74424e5..393b5a1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -207,6 +207,46 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "async-openai" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afb7051804e03daf32cd7e45e7a655bb6cea9283309d2253babfb38c09f4ea03" +dependencies = [ + "async-openai-macros", + "backoff", + "base64 0.22.1", + "bytes", + "derive_builder", + "eventsource-stream", + "futures", + "getrandom 0.3.4", + "rand 0.9.2", + "reqwest", + "reqwest-eventsource", + "secrecy", + "serde", + "serde_json", + "serde_urlencoded", + "thiserror 2.0.17", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", + "url", +] + +[[package]] +name = "async-openai-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81872a8e595e8ceceab71c6ba1f9078e313b452a1e31934e6763ef5d308705e4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -265,26 +305,35 @@ name = "auto-engine-core" version = "0.3.0" dependencies = [ "anyhow", + "async-openai", "async-trait", "auto-engine-macro", + "axum", + "bytes", "convert_case 0.10.0", "enigo", "evalexpr", "futures", + "http", "log", "oar-ocr", "once_cell", "opencv", "regex", "reqwest", + "rmcp", "schemars 1.1.0", "screenshots", + "secrecy", "serde", "serde_json", "serde_yaml", "tauri", "tokio", "tokio-util", + "toml 0.9.10+spec-1.1.0", + "tracing", + "tracing-subscriber", "wasmtime", "wasmtime-wasi", ] @@ -320,7 +369,7 @@ dependencies = [ "log", "num-rational", "num-traits", - "pastey", + "pastey 0.1.1", "rayon", "thiserror 2.0.17", "v_frame", @@ -350,6 +399,72 @@ dependencies = [ "arrayvec", ] +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom 0.2.16", + "instant", + "pin-project-lite", + "rand 0.8.5", + "tokio", +] + [[package]] name = "base64" version = "0.13.1" @@ -496,9 +611,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" dependencies = [ "serde", ] @@ -645,7 +760,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "374b7c592d9c00c1f4972ea58390ac6b18cbb6ab79011f3bdc90a0b82ca06b77" dependencies = [ "serde", - "toml 0.9.8", + "toml 0.9.10+spec-1.1.0", ] [[package]] @@ -715,8 +830,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" dependencies = [ "iana-time-zone", + "js-sys", "num-traits", "serde", + "wasm-bindgen", "windows-link 0.2.1", ] @@ -1267,6 +1384,16 @@ dependencies = [ "darling_macro 0.21.3", ] +[[package]] +name = "darling" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" +dependencies = [ + "darling_core 0.23.0", + "darling_macro 0.23.0", +] + [[package]] name = "darling_core" version = "0.20.11" @@ -1295,6 +1422,19 @@ dependencies = [ "syn 2.0.110", ] +[[package]] +name = "darling_core" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" +dependencies = [ + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.110", +] + [[package]] name = "darling_macro" version = "0.20.11" @@ -1317,6 +1457,17 @@ dependencies = [ "syn 2.0.110", ] +[[package]] +name = "darling_macro" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" +dependencies = [ + "darling_core 0.23.0", + "quote", + "syn 2.0.110", +] + [[package]] name = "dary_heap" version = "0.3.8" @@ -1592,7 +1743,7 @@ dependencies = [ "cc", "memchr", "rustc_version", - "toml 0.9.8", + "toml 0.9.10+spec-1.1.0", "vswhom", "winreg", ] @@ -1711,6 +1862,17 @@ version = "12.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae893d2d5e908b78f151ed89de3bfc272cdf6d368c7ed866942f98e24dea208a" +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom 7.1.3", + "pin-project-lite", +] + [[package]] name = "exr" version = "1.74.0" @@ -1977,6 +2139,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -2521,12 +2689,11 @@ dependencies = [ [[package]] name = "http" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" dependencies = [ "bytes", - "fnv", "itoa", ] @@ -2559,6 +2726,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.7.0" @@ -2572,6 +2745,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "pin-utils", @@ -2590,6 +2764,7 @@ dependencies = [ "hyper", "hyper-util", "rustls", + "rustls-native-certs", "rustls-pki-types", "tokio", "tokio-rustls", @@ -2904,6 +3079,15 @@ dependencies = [ "cfb", ] +[[package]] +name = "instant" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +dependencies = [ + "cfg-if", +] + [[package]] name = "interpolate_name" version = "0.2.4" @@ -3384,6 +3568,12 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -3467,6 +3657,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -3617,7 +3817,7 @@ dependencies = [ "openssl-probe", "openssl-sys", "schannel", - "security-framework", + "security-framework 2.11.1", "security-framework-sys", "tempfile", ] @@ -3859,7 +4059,7 @@ dependencies = [ "serde_json", "thiserror 2.0.17", "tokenizers", - "toml 0.9.8", + "toml 0.9.10+spec-1.1.0", "tracing", "tracing-subscriber", ] @@ -4351,6 +4551,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec" +[[package]] +name = "pastey" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b867cad97c0791bbd3aaa6472142568c6c9e8f71937e98379f584cfb0cf35bec" + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -5232,10 +5438,12 @@ dependencies = [ "hyper-util", "js-sys", "log", + "mime_guess", "percent-encoding", "pin-project-lite", "quinn", "rustls", + "rustls-native-certs", "rustls-pki-types", "serde", "serde_json", @@ -5255,6 +5463,22 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom 7.1.3", + "pin-project-lite", + "reqwest", + "thiserror 1.0.69", +] + [[package]] name = "rgb" version = "0.8.52" @@ -5275,6 +5499,50 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rmcp" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "528d42f8176e6e5e71ea69182b17d1d0a19a6b3b894b564678b74cd7cab13cfa" +dependencies = [ + "async-trait", + "base64 0.22.1", + "bytes", + "chrono", + "futures", + "http", + "http-body", + "http-body-util", + "pastey 0.2.1", + "pin-project-lite", + "rand 0.9.2", + "rmcp-macros", + "schemars 1.1.0", + "serde", + "serde_json", + "sse-stream", + "thiserror 2.0.17", + "tokio", + "tokio-stream", + "tokio-util", + "tower-service", + "tracing", + "uuid", +] + +[[package]] +name = "rmcp-macros" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3f81daaa494eb8e985c9462f7d6ce1ab05e5299f48aafd76cdd3d8b060e6f59" +dependencies = [ + "darling 0.23.0", + "proc-macro2", + "quote", + "serde_json", + "syn 2.0.110", +] + [[package]] name = "rustc-demangle" version = "0.1.26" @@ -5346,6 +5614,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9980d917ebb0c0536119ba501e90834767bffc3d60641457fd84a1f3fd337923" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework 3.5.1", +] + [[package]] name = "rustls-pki-types" version = "1.13.2" @@ -5439,6 +5719,7 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9558e172d4e8533736ba97870c4b2cd63f84b382a3d6eb063da41b91cce17289" dependencies = [ + "chrono", "dyn-clone", "ref-cast", "schemars_derive 1.1.0", @@ -5501,6 +5782,16 @@ dependencies = [ "xcb", ] +[[package]] +name = "secrecy" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" +dependencies = [ + "serde", + "zeroize", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -5514,6 +5805,19 @@ dependencies = [ "security-framework-sys", ] +[[package]] +name = "security-framework" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef" +dependencies = [ + "bitflags 2.10.0", + "core-foundation 0.10.1", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + [[package]] name = "security-framework-sys" version = "2.15.0" @@ -5618,6 +5922,17 @@ dependencies = [ "serde_core", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_repr" version = "0.1.20" @@ -5640,9 +5955,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e24345aa0fe688594e73770a5f6d1b216508b4f93484c0026d521acd30134392" +checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776" dependencies = [ "serde_core", ] @@ -5935,6 +6250,19 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "sse-stream" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a" +dependencies = [ + "bytes", + "futures-util", + "http-body", + "http-body-util", + "pin-project-lite", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -6209,7 +6537,7 @@ dependencies = [ "serde_json", "tauri-utils", "tauri-winres", - "toml 0.9.8", + "toml 0.9.10+spec-1.1.0", "walkdir", ] @@ -6337,7 +6665,7 @@ dependencies = [ "serde_with", "swift-rs", "thiserror 2.0.17", - "toml 0.9.8", + "toml 0.9.10+spec-1.1.0", "url", "urlpattern", "uuid", @@ -6351,7 +6679,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd21509dd1fa9bd355dc29894a6ff10635880732396aa38c0066c1e6c1ab8074" dependencies = [ "embed-resource", - "toml 0.9.8", + "toml 0.9.10+spec-1.1.0", ] [[package]] @@ -6589,6 +6917,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.17" @@ -6616,14 +6955,14 @@ dependencies = [ [[package]] name = "toml" -version = "0.9.8" +version = "0.9.10+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0dc8b1fb61449e27716ec0e1bdf0f6b8f3e8f6b05391e8497b8b6d7804ea6d8" +checksum = "0825052159284a1a8b4d6c0c86cbc801f2da5afd2b225fa548c72f2e74002f48" dependencies = [ "indexmap 2.12.0", "serde_core", - "serde_spanned 1.0.3", - "toml_datetime 0.7.3", + "serde_spanned 1.0.4", + "toml_datetime 0.7.5+spec-1.1.0", "toml_parser", "toml_writer", "winnow 0.7.13", @@ -6640,9 +6979,9 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.7.3" +version = "0.7.5+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2cdb639ebbc97961c51720f858597f7f24c4fc295327923af55b74c3c724533" +checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347" dependencies = [ "serde_core", ] @@ -6678,25 +7017,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6485ef6d0d9b5d0ec17244ff7eb05310113c3f316f2d14200d4de56b3cb98f8d" dependencies = [ "indexmap 2.12.0", - "toml_datetime 0.7.3", + "toml_datetime 0.7.5+spec-1.1.0", "toml_parser", "winnow 0.7.13", ] [[package]] name = "toml_parser" -version = "1.0.4" +version = "1.0.6+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0cbe268d35bdb4bb5a56a2de88d0ad0eb70af5384a99d648cd4b3d04039800e" +checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44" dependencies = [ "winnow 0.7.13", ] [[package]] name = "toml_writer" -version = "1.0.4" +version = "1.0.6+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df8b2b54733674ad286d16267dcfc7a71ed5c776e4ac7aa3c3e2561f7c637bf2" +checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607" [[package]] name = "tower" @@ -6711,6 +7050,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -6749,6 +7089,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -6897,6 +7238,12 @@ dependencies = [ "unic-common", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.22" @@ -7379,7 +7726,7 @@ dependencies = [ "serde", "serde_derive", "sha2", - "toml 0.9.8", + "toml 0.9.10+spec-1.1.0", "windows-sys 0.60.2", "zstd", ] diff --git a/auto-engine-core/Cargo.toml b/auto-engine-core/Cargo.toml index 5b7b8cd..a0fc3e8 100644 --- a/auto-engine-core/Cargo.toml +++ b/auto-engine-core/Cargo.toml @@ -27,6 +27,15 @@ anyhow = "1.0.100" convert_case = "0.10.0" oar-ocr = "0.3.1" reqwest = { version = "0.12.9", default-features = false, features = ["json", "rustls-tls"] } +rmcp = { version = "0.12.0", features = ["transport-streamable-http-server-session", "transport-streamable-http-server"] } +http = "1.4.0" +tracing = "0.1.41" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +axum = "0.8.8" +toml = "0.9.10" +async-openai = { version = "0.32.2", features = ["responses", "completions", "chat-completion"] } +secrecy = "0.10" +bytes = "1.11.0" [features] default = ["types", "context", "event", "pipeline", "runner", "utils"] diff --git a/auto-engine-core/src/context.rs b/auto-engine-core/src/context.rs index 5ff1da9..10d97d9 100644 --- a/auto-engine-core/src/context.rs +++ b/auto-engine-core/src/context.rs @@ -1,14 +1,20 @@ use crate::utils; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::path::{Path, PathBuf}; use std::sync::Arc; -use tauri::async_runtime::RwLock; use tauri::Manager; +use tauri::async_runtime::RwLock; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ValueItem { + pub description: String, + pub value: serde_json::Value, +} #[derive(Debug)] pub struct Context { - pub string_value: Arc>>, + pub value: Arc>>, pub(crate) screen_scale: f64, pub(crate) pipeline_path: PathBuf, pub(crate) workflow_path: PathBuf, @@ -20,7 +26,7 @@ impl Context { #[cfg(feature = "tauri")] pub fn new(path: PathBuf, app_handle: Option) -> Self { Self { - string_value: Arc::new(RwLock::new(HashMap::new())), + value: Arc::new(RwLock::new(HashMap::new())), screen_scale: 1.0, pipeline_path: path.clone(), workflow_path: path.clone(), @@ -31,7 +37,7 @@ impl Context { #[cfg(not(feature = "tauri"))] pub fn new(path: PathBuf) -> Self { Self { - string_value: Arc::new(RwLock::new(HashMap::new())), + value: Arc::new(RwLock::new(HashMap::new())), screen_scale: 1.0, pipeline_path: path.clone(), workflow_path: path.clone(), @@ -44,20 +50,32 @@ impl Context { } pub async fn set_string_value(&self, key: &str, value: &str) -> Result<(), String> { - self.set_value::(key, value.to_string()).await + self.set_value::(key, value.to_string(), String::new()) + .await } - pub async fn set_value(&self, key: &str, value: T) -> Result<(), String> { - let mut map = self.string_value.write().await; + pub async fn set_value( + &self, + key: &str, + value: T, + description: String, + ) -> Result<(), String> { + let mut map = self.value.write().await; map.insert( key.to_string(), - serde_json::to_value(value).map_err(|e| format!("{:?}", e))?, + ValueItem { + description, + value: serde_json::to_value(value).map_err(|e| format!("{:?}", e))?, + }, ); Ok(()) } pub async fn get_value(&self, key: &str) -> Option { - let map = self.string_value.read().await; - map.get(key).cloned() + let map = self.value.read().await; + if let Some(item) = map.get(key).cloned() { + return Some(item.value); + } + None } pub async fn get_value_parse(&self, key: &str) -> Option { @@ -80,7 +98,7 @@ impl Context { default_value } - pub fn load_image_path(&self, image: &str) -> Result { + pub fn path_image(&self, image: &str) -> Result { let image_path = self.workflow_path.join("images").join(image); if !image_path.exists() { return Err(format!("Image {} does not exist", image)); @@ -88,13 +106,22 @@ impl Context { Ok(image_path) } - pub fn resource_path(&self) -> PathBuf { - if let Some(handle) = self.app_handle.clone(){ + pub fn path_resource(&self) -> PathBuf { + if let Some(handle) = self.app_handle.clone() { if cfg!(debug_assertions) { - return PathBuf::from("") + return PathBuf::from(""); } - return handle.path().resource_dir().unwrap().to_path_buf() + return handle.path().resource_dir().unwrap().to_path_buf(); } PathBuf::from("") } + + pub async fn values(&self) -> HashMap { + let mut res: HashMap = HashMap::new(); + let map = self.value.read().await; + map.iter().for_each(|(k, v)| { + res.insert(k.clone(), v.clone()); + }); + res + } } diff --git a/auto-engine-core/src/event/node.rs b/auto-engine-core/src/event/node.rs index 3a11063..010da40 100644 --- a/auto-engine-core/src/event/node.rs +++ b/auto-engine-core/src/event/node.rs @@ -26,6 +26,11 @@ impl NodeEventPayload { pub fn running(name: String) -> NodeEventPayload { NodeEventPayload::new::("running".to_string(), name, None) } + + pub fn running_with_payload(name: String, result: Option) -> NodeEventPayload { + NodeEventPayload::new::("running".to_string(), name, result) + } + pub fn waiting(name: String) -> NodeEventPayload { NodeEventPayload::new::("waiting".to_string(), name, None) } diff --git a/auto-engine-core/src/lib.rs b/auto-engine-core/src/lib.rs index 61c5e00..a7c0f46 100644 --- a/auto-engine-core/src/lib.rs +++ b/auto-engine-core/src/lib.rs @@ -31,3 +31,4 @@ pub mod plugin; mod action; pub mod notification; +pub mod mcp; diff --git a/auto-engine-core/src/mcp.rs b/auto-engine-core/src/mcp.rs new file mode 100644 index 0000000..f156286 --- /dev/null +++ b/auto-engine-core/src/mcp.rs @@ -0,0 +1,4 @@ +pub mod server; +mod state; +mod tool; +mod service; \ No newline at end of file diff --git a/auto-engine-core/src/mcp/server.rs b/auto-engine-core/src/mcp/server.rs new file mode 100644 index 0000000..83d14b8 --- /dev/null +++ b/auto-engine-core/src/mcp/server.rs @@ -0,0 +1,162 @@ +use crate::mcp::service::McpServiceBuilder; +use crate::mcp::state::McpState; +use crate::mcp::tool::{ToolCallBuilder, ToolDefine}; +use crate::node::start::node; +use crate::schema::workflow::WorkflowSchema; +use crate::types::workflow::WorkflowMetaData; +use rmcp::handler::server::wrapper::Parameters; +use schemars::{Schema, SchemaGenerator}; +use serde_json::{Map, Value}; +use std::collections::HashMap; +use std::fs; +use std::path::PathBuf; +use std::sync::Arc; +use tokio_util::sync::CancellationToken; + +struct ParamSchemaBuilder { + params: Vec, +} + +impl ParamSchemaBuilder { + fn new() -> ParamSchemaBuilder { + ParamSchemaBuilder { params: vec![] } + } + + fn add_param(&mut self, name: String) { + self.params.push(name); + } + + fn build(self) -> Schema { + let mut type_string = Map::new(); + type_string.insert("type".to_string(), Value::String("string".to_string())); + let mut params = Map::new(); + let mut required = vec![]; + + for param in self.params { + params.insert(param.clone(), Value::Object(type_string.clone())); + required.push(Value::String(param.clone())); + } + + let mut res = Map::new(); + res.insert("properties".to_string(), Value::Object(params)); + res.insert("type".to_string(), Value::String("object".to_string())); + res.insert("required".to_string(), Value::Array(required)); + + Schema::from(res) + } +} + +pub struct McpServer { + workflow_dir: PathBuf, + token: CancellationToken, +} + +impl McpServer { + pub fn new(workflow_dir: PathBuf) -> Self { + McpServer { + workflow_dir, + token: Default::default(), + } + } + + pub fn load_tools(&self) -> Result>, String> { + if !self.workflow_dir.exists() { + return Ok(vec![]); + } + + let mut workflows = vec![]; + + for entry in self.workflow_dir.read_dir().map_err(|e| e.to_string())? { + let entry = entry.map_err(|e| e.to_string())?; + let path = entry.path(); + + if path.is_dir() { + // Try to load metadata from this directory + let meta_path = path.join("_meta.toml"); + if !meta_path.exists() { + continue; + } + // read metadata + let meta_content = fs::read_to_string(&meta_path).map_err(|e| e.to_string())?; + let meta = toml::from_str::(meta_content.as_str()) + .map_err(|e| e.to_string())?; + + // read workflow content + let workflow_content = + fs::read_to_string(path.join("workflow.yaml")).map_err(|e| e.to_string())?; + + let schema = serde_yaml::from_str::(&workflow_content) + .map_err(|e| e.to_string())?; + + let mut builder = ParamSchemaBuilder::new(); + for node in schema.nodes { + if node.action_type == node::START_NODE_TYPE { + if let Some(value) = node.input_data.unwrap_or(HashMap::new()).get("params") + { + value.as_object().map(|obj| { + for (key, value) in obj { + if value.is_string() { + builder.add_param(key.clone()) + } + } + }); + } + } + } + + workflows.push((meta, builder.build())); + } + } + + let mut tools = vec![]; + + for (workflow, param) in workflows { + let tool = ToolCallBuilder::new( + workflow + .name + .ok_or("No name found in metadata".to_string())?, + ) + .with_description( + workflow + .description + .ok_or("No description found in metadata".to_string())?, + ) + .with_input_schema(param) + .with_call_func(Arc::new(|ctx, params: HashMap| { + Box::pin(async move { Ok(serde_json::Value::String("Hello !".to_string())) }) + })) + .build(); + tools.push(tool); + } + + Ok(tools) + } + + pub async fn run(&self) -> Result<(), String> { + let tools = self.load_tools()?; + + let service = McpServiceBuilder::new() + .with_port(8080) + .with_tool_calls(tools) + .build(); + + log::info!("MCP server listening on {}", self.workflow_dir.display()); + tracing::info!("MCP server started"); + service + .run(self.token.clone()) + .await + .map_err(|e| e.to_string())?; + + Ok(()) + } + + pub async fn stop(&mut self) { + self.token.cancel(); + self.token = Default::default(); + } + + pub async fn restart(&mut self) -> Result<(), String> { + self.stop().await; + self.run().await + } +} diff --git a/auto-engine-core/src/mcp/service.rs b/auto-engine-core/src/mcp/service.rs new file mode 100644 index 0000000..b974da7 --- /dev/null +++ b/auto-engine-core/src/mcp/service.rs @@ -0,0 +1,126 @@ +use crate::mcp::state::McpState; +use crate::mcp::tool::ToolDefine; +use axum::Router; +use rmcp::handler::server::tool::{ToolCallContext, ToolRoute}; +use rmcp::transport::StreamableHttpService; +use rmcp::transport::streamable_http_server::session::local::LocalSessionManager; +use rmcp::{RoleServer, Service}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use std::borrow::Cow; +use std::net::SocketAddr; +use std::str::FromStr; +use std::sync::Arc; +use tokio::select; +use tokio_util::sync::CancellationToken; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +pub struct McpServiceBuilder { + // service: Option, + port: Option, + host: Option, + log_level: Option, + tool_calls: Vec>, +} + +impl McpServiceBuilder { + pub fn new() -> Self { + McpServiceBuilder { + port: None, + host: None, + log_level: None, + tool_calls: vec![], + } + } + pub fn with_port(mut self, port: u16) -> Self { + log::info!("port ===> {}", port); + self.port = Some(port); + self + } + + pub fn with_host(mut self, host: &str) -> Self { + self.host = Some(host.to_string()); + self + } + + pub fn with_log_level(mut self, log_level: &str) -> Self { + self.log_level = Some(log_level.to_string()); + self + } + + pub fn with_tool_call(mut self, tool_call: ToolDefine) -> Self { + self.tool_calls.push(tool_call); + self + } + pub fn with_tool_calls(mut self, tool_call: Vec>) -> Self { + log::info!("tool_call ===> {}", tool_call.len()); + self.tool_calls.extend(tool_call); + self + } + + pub fn build(&self) -> McpService { + log::info!("build ===>"); + + let log_level = self.log_level.clone().unwrap_or("debug".to_string()); + + // tracing_subscriber::registry() + // .with( + // tracing_subscriber::EnvFilter::try_from_default_env() + // .unwrap_or_else(|_| log_level.into()), + // ) + // .with(tracing_subscriber::fmt::layer()) + // .init(); + + let tool_calls = self.tool_calls.clone(); + + let service = StreamableHttpService::new( + move || { + let mut state = McpState::new(); + + for tool in tool_calls.clone() { + state + .tool_router + .map + .insert(tool.attr.name.clone(), tool.clone()); + } + + Ok(state) + }, + LocalSessionManager::default().into(), + Default::default(), + ); + + McpService::Streamable(service, self.port.clone()) + } +} + +pub enum McpService> { + Streamable(StreamableHttpService, Option), +} + +impl McpService { + pub async fn run(self, cancel: CancellationToken) -> Result<(), Box> { + match self { + McpService::Streamable(service, port) => { + let mut p = 23456; + if let Some(port) = port { + p = port + } + let addr = SocketAddr::from_str(&format!("0.0.0.0:{}", p))?; + let tcp_listener = tokio::net::TcpListener::bind(addr).await?; + tracing::info!("listening on {}", addr); + let router = Router::new().nest_service("/mcp", service); + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async move { + select! { + _ = cancel.cancelled() => {} + _ = tokio::signal::ctrl_c() => {} + } + }) + .await; + Ok(()) + } + } + } +} diff --git a/auto-engine-core/src/mcp/state.rs b/auto-engine-core/src/mcp/state.rs new file mode 100644 index 0000000..f7d47a6 --- /dev/null +++ b/auto-engine-core/src/mcp/state.rs @@ -0,0 +1,66 @@ +use rmcp::handler::server::tool::{ToolCallContext, ToolRouter}; +use rmcp::model::{ + CallToolRequestParam, CallToolResult, Implementation, InitializeRequestParam, InitializeResult, + ProtocolVersion, ServerCapabilities, ServerInfo, +}; +use rmcp::service::RequestContext; +use rmcp::{ErrorData, RoleServer, ServerHandler, tool_router}; + +pub struct McpState { + pub tool_router: ToolRouter, +} + +#[tool_router] +impl McpState { + pub fn new() -> Self { + McpState { + tool_router: ToolRouter::new(), + } + } +} + +impl ServerHandler for McpState { + fn ping( + &self, + _context: RequestContext, + ) -> impl Future> + Send + '_ { + async { Ok(()) } + } + + async fn initialize( + &self, + _request: InitializeRequestParam, + _context: RequestContext, + ) -> Result { + Ok(self.get_info()) + } + + async fn call_tool( + &self, + request: CallToolRequestParam, + context: RequestContext, + ) -> Result { + let tcc = ToolCallContext::new(self, request, context); + self.tool_router.call(tcc).await + } + + async fn list_tools( + &self, + _request: Option, + _context: RequestContext, + ) -> Result { + Ok(rmcp::model::ListToolsResult::with_all_items( + self.tool_router.list_all(), + )) + } + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::V_2025_06_18, + capabilities: ServerCapabilities::builder() + .enable_tools() + .build(), + server_info: Implementation::from_build_env(), + instructions: Some("This server provides counter tools and prompts. Tools: increment, decrement, get_value, say_hello, echo, sum. Prompts: example_prompt (takes a message), counter_analysis (analyzes counter state with a goal).".to_string()), + } + } +} diff --git a/auto-engine-core/src/mcp/tool.rs b/auto-engine-core/src/mcp/tool.rs new file mode 100644 index 0000000..31725a8 --- /dev/null +++ b/auto-engine-core/src/mcp/tool.rs @@ -0,0 +1,152 @@ +use rmcp::handler::server::tool::{DynCallToolHandler, ToolCallContext, ToolRoute}; +use rmcp::model::{CallToolResult, Content, Icon, JsonObject, Tool, ToolAnnotations}; +use rmcp::schemars::{JsonSchema, SchemaGenerator}; +use rmcp::{ErrorData, RoleServer, Service}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use serde_json::Value; +use std::borrow::Cow; +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; +use schemars::Schema; +use crate::mcp::state::McpState; + +pub type ToolDefine = ToolRoute; + +pub struct ToolCallBuilder { + /// The name of the tool + name: Cow<'static, str>, + title: Option, + description: Option>, + annotations: Option, + icons: Option>, + call_func: Option>>, + input_schema: Option, + output_schema: Option, +} + +type ToolCallFunc = dyn for<'s> Fn( + ToolCallContext<'s, T>, + I, +) -> Pin< + Box> + Send + Sync + 'static>, +> + Send ++ Sync; + +impl ToolCallBuilder { + pub fn new(name: String) -> Self { + Self { + name: Cow::from(name), + title: None, + description: None, + annotations: None, + icons: None, + call_func: None, + input_schema: None, + output_schema: None, + } + } + + pub fn with_title(mut self, title: impl Into) -> Self { + self.title = Some(title.into()); + self + } + + pub fn with_description(mut self, description: impl Into>) -> Self { + self.description = Some(description.into()); + self + } + + pub fn with_annotations(mut self, annotations: ToolAnnotations) -> Self { + self.annotations = Some(annotations); + self + } + + pub fn with_icons(mut self, icons: Vec) -> Self { + self.icons = Some(icons); + self + } + + pub fn with_input(mut self) -> Self { + let mut generator = SchemaGenerator::new(Default::default()); + let schema = I::json_schema(&mut generator); + let input_schema = serde_json::to_value(schema) + .unwrap() + .as_object() + .unwrap() + .clone(); + self.input_schema = Some(input_schema); + self + } + + pub fn with_input_schema(mut self, schema: Schema) -> Self { + let input_schema = serde_json::to_value(schema) + .unwrap() + .as_object() + .unwrap() + .clone(); + self.input_schema = Some(input_schema); + self + } + pub fn with_output_schema(mut self) -> Self { + let mut generator = SchemaGenerator::new(Default::default()); + let schema = O::json_schema(&mut generator); + let output_schema = serde_json::to_value(schema) + .unwrap() + .as_object() + .unwrap() + .clone(); + self.output_schema = Some(output_schema); + self + } + pub fn with_call_func( + self, + call_func: Arc>, + ) -> Self { + let mut builder = self; + + builder.call_func = Some(Arc::new(move |context| { + let call_func = call_func.clone(); + Box::pin(async move { + let obj = context.arguments.clone().unwrap_or_default(); + let arguments: I = serde_json::from_value(serde_json::Value::Object(obj)) + .unwrap_or_else(|e| { + log::warn!("Failed to convert argument {:?}", e); + Default::default() + }); + match call_func(context, arguments).await { + Ok(res) => Ok(CallToolResult::success(vec![ + Content::json(res).unwrap_or(Content::text("Failed to call tool")), + ])), + Err(e) => Ok(CallToolResult::error(vec![Content::text(e.to_string())])), + } + }) + })); + builder + } + + pub fn build(self) -> ToolDefine { + let output_schema = if let Some(output_schema) = self.output_schema { + Some(Arc::new(output_schema)) + } else { + None + }; + + ToolDefine { + call: self + .call_func + .unwrap_or(Arc::new(|_| panic!("call_func is not set"))), + attr: Tool { + name: self.name, + title: self.title, + description: self.description, + input_schema: Arc::new(self.input_schema.unwrap_or(JsonObject::new())), + output_schema, + annotations: self.annotations, + icons: self.icons, + meta: None, + }, + } + } +} diff --git a/auto-engine-core/src/node.rs b/auto-engine-core/src/node.rs index 228d39c..ec02085 100644 --- a/auto-engine-core/src/node.rs +++ b/auto-engine-core/src/node.rs @@ -10,3 +10,4 @@ pub mod start; pub mod time_wait; #[cfg(feature = "wasm")] pub mod wasm; +pub mod ai; diff --git a/auto-engine-core/src/node/ai.rs b/auto-engine-core/src/node/ai.rs new file mode 100644 index 0000000..cf6a2d1 --- /dev/null +++ b/auto-engine-core/src/node/ai.rs @@ -0,0 +1 @@ +pub mod gpt; \ No newline at end of file diff --git a/auto-engine-core/src/node/ai/gpt.rs b/auto-engine-core/src/node/ai/gpt.rs new file mode 100644 index 0000000..43c099a --- /dev/null +++ b/auto-engine-core/src/node/ai/gpt.rs @@ -0,0 +1,2 @@ +pub mod node; +pub mod runner; diff --git a/auto-engine-core/src/node/ai/gpt/node.rs b/auto-engine-core/src/node/ai/gpt/node.rs new file mode 100644 index 0000000..197635b --- /dev/null +++ b/auto-engine-core/src/node/ai/gpt/node.rs @@ -0,0 +1,213 @@ +use crate::types::field::{ + BooleanConstraint, Condition, FieldCondition, FieldType, SchemaField, ValueConstraint, +}; +use crate::types::node::{I18nValue, NodeDefine}; +use std::collections::HashMap; + +pub const NODE_TYPE: &str = "GPT"; + +#[derive(Default)] +pub struct GptNode; + +impl GptNode { + pub fn new() -> Self { + Self {} + } +} + +impl NodeDefine for GptNode { + fn action_type(&self) -> String { + NODE_TYPE.to_string() + } + + fn name(&self) -> I18nValue { + I18nValue { + zh: "ChatGPT".to_string(), + en: "ChatGPT".to_string(), + } + } + + fn icon(&self) -> String { + String::from( + "data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjRweCIgaGVpZ2h0PSIyNHB4IiB2aWV3Qm94PSItMSAtMSAyNiAyNiIgcm9sZT0iaW1nIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPjx0aXRsZT5PcGVuQUkgaWNvbjwvdGl0bGU+PHBhdGggZD0iTTIyLjI4MTkgOS44MjExYTUuOTg0NyA1Ljk4NDcgMCAwIDAtLjUxNTctNC45MTA4IDYuMDQ2MiA2LjA0NjIgMCAwIDAtNi41MDk4LTIuOUE2LjA2NTEgNi4wNjUxIDAgMCAwIDQuOTgwNyA0LjE4MThhNS45ODQ3IDUuOTg0NyAwIDAgMC0zLjk5NzcgMi45IDYuMDQ2MiA2LjA0NjIgMCAwIDAgLjc0MjcgNy4wOTY2IDUuOTggNS45OCAwIDAgMCAuNTExIDQuOTEwNyA2LjA1MSA2LjA1MSAwIDAgMCA2LjUxNDYgMi45MDAxQTUuOTg0NyA1Ljk4NDcgMCAwIDAgMTMuMjU5OSAyNGE2LjA1NTcgNi4wNTU3IDAgMCAwIDUuNzcxOC00LjIwNTggNS45ODk0IDUuOTg5NCAwIDAgMCAzLjk5NzctMi45MDAxIDYuMDU1NyA2LjA1NTcgMCAwIDAtLjc0NzUtNy4wNzI5em0tOS4wMjIgMTIuNjA4MWE0LjQ3NTUgNC40NzU1IDAgMCAxLTIuODc2NC0xLjA0MDhsLjE0MTktLjA4MDQgNC43NzgzLTIuNzU4MmEuNzk0OC43OTQ4IDAgMCAwIC4zOTI3LS42ODEzdi02LjczNjlsMi4wMiAxLjE2ODZhLjA3MS4wNzEgMCAwIDEgLjAzOC4wNTJ2NS41ODI2YTQuNTA0IDQuNTA0IDAgMCAxLTQuNDk0NSA0LjQ5NDR6bS05LjY2MDctNC4xMjU0YTQuNDcwOCA0LjQ3MDggMCAwIDEtLjUzNDYtMy4wMTM3bC4xNDIuMDg1MiA0Ljc4MyAyLjc1ODJhLjc3MTIuNzcxMiAwIDAgMCAuNzgwNiAwbDUuODQyOC0zLjM2ODV2Mi4zMzI0YS4wODA0LjA4MDQgMCAwIDEtLjAzMzIuMDYxNUw5Ljc0IDE5Ljk1MDJhNC40OTkyIDQuNDk5MiAwIDAgMS02LjE0MDgtMS42NDY0ek0yLjM0MDggNy44OTU2YTQuNDg1IDQuNDg1IDAgMCAxIDIuMzY1NS0xLjk3MjhWMTEuNmEuNzY2NC43NjY0IDAgMCAwIC4zODc5LjY3NjVsNS44MTQ0IDMuMzU0My0yLjAyMDEgMS4xNjg1YS4wNzU3LjA3NTcgMCAwIDEtLjA3MSAwbC00LjgzMDMtMi43ODY1QTQuNTA0IDQuNTA0IDAgMCAxIDIuMzQwOCA3Ljg3MnptMTYuNTk2MyAzLjg1NThMMTMuMTAzOCA4LjM2NCAxNS4xMTkyIDcuMmEuMDc1Ny4wNzU3IDAgMCAxIC4wNzEgMGw0LjgzMDMgMi43OTEzYTQuNDk0NCA0LjQ5NDQgMCAwIDEtLjY3NjUgOC4xMDQydi01LjY3NzJhLjc5Ljc5IDAgMCAwLS40MDctLjY2N3ptMi4wMTA3LTMuMDIzMWwtLjE0Mi0uMDg1Mi00Ljc3MzUtMi43ODE4YS43NzU5Ljc3NTkgMCAwIDAtLjc4NTQgMEw5LjQwOSA5LjIyOTdWNi44OTc0YS4wNjYyLjA2NjIgMCAwIDEgLjAyODQtLjA2MTVsNC44MzAzLTIuNzg2NmE0LjQ5OTIgNC40OTkyIDAgMCAxIDYuNjgwMiA0LjY2ek04LjMwNjUgMTIuODYzbC0yLjAyLTEuMTYzOGEuMDgwNC4wODA0IDAgMCAxLS4wMzgtLjA1NjdWNi4wNzQyYTQuNDk5MiA0LjQ5OTIgMCAwIDEgNy4zNzU3LTMuNDUzN2wtLjE0Mi4wODA1TDguNzA0IDUuNDU5YS43OTQ4Ljc5NDggMCAwIDAtLjM5MjcuNjgxM3ptMS4wOTc2LTIuMzY1NGwyLjYwMi0xLjQ5OTggMi42MDY5IDEuNDk5OHYyLjk5OTRsLTIuNTk3NCAxLjQ5OTctMi42MDY3LTEuNDk5N1oiLz48L3N2Zz4=", + ) + } + + fn category(&self) -> Option { + Some(I18nValue { + zh: "AI".to_string(), + en: "AI".to_string(), + }) + } + + fn description(&self) -> Option { + Some(I18nValue { + zh: "调用 OpenAI 兼容的聊天模型,返回生成文本与用量。".to_string(), + en: "Call an OpenAI-compatible chat model and return the generated text and usage." + .to_string(), + }) + } + + fn output_schema(&self, _input: HashMap) -> Vec { + vec![ + SchemaField { + name: "data".to_string(), + field_type: FieldType::String, + item_type: None, + description: Some(I18nValue { + zh: "ChatGPT 节点生成的返回内容".to_string(), + en: "Return content generated by ChatGPT nodes".to_string(), + }), + enums: vec![], + default: None, + condition: None, + }, + SchemaField { + name: "prompt_tokens".to_string(), + field_type: FieldType::Number, + item_type: None, + description: Some(I18nValue { + zh: "提示词消耗 token 数".to_string(), + en: "Prompt tokens used".to_string(), + }), + enums: vec![], + default: None, + condition: None, + }, + SchemaField { + name: "completion_tokens".to_string(), + field_type: FieldType::Number, + item_type: None, + description: Some(I18nValue { + zh: "生成部分消耗 token 数".to_string(), + en: "Completion tokens used".to_string(), + }), + enums: vec![], + default: None, + condition: None, + }, + SchemaField { + name: "total_tokens".to_string(), + field_type: FieldType::Number, + item_type: None, + description: Some(I18nValue { + zh: "总 token 消耗".to_string(), + en: "Total tokens used".to_string(), + }), + enums: vec![], + default: None, + condition: None, + }, + ] + } + + fn input_schema(&self) -> Vec { + vec![ + SchemaField { + name: "api_key".to_string(), + field_type: FieldType::Password, + item_type: None, + description: Some(I18nValue { + zh: "OpenAI API Key(为空则使用环境变量 OPENAI_API_KEY)".to_string(), + en: "OpenAI API key (falls back to OPENAI_API_KEY env var)".to_string(), + }), + enums: vec![], + default: Some("".to_string()), + condition: None, + }, + SchemaField { + name: "base_url".to_string(), + field_type: FieldType::String, + item_type: None, + description: Some(I18nValue { + zh: "可选的 API Base URL,兼容自建或代理网关".to_string(), + en: "Optional API base URL for self-hosted or proxy gateways".to_string(), + }), + enums: vec![], + default: Some("".to_string()), + condition: None, + }, + SchemaField { + name: "model".to_string(), + field_type: FieldType::String, + item_type: None, + description: Some(I18nValue { + zh: "模型名称,如 gpt-4o-mini".to_string(), + en: "Model name, e.g. gpt-4o-mini".to_string(), + }), + enums: vec![], + default: Some("gpt-4o-mini".to_string()), + condition: None, + }, + SchemaField { + name: "system_prompt_enabled".to_string(), + field_type: FieldType::Boolean, + item_type: None, + description: Some(I18nValue { + zh: "是否设置系统提示词".to_string(), + en: "Should system prompts be enabled?".to_string(), + }), + enums: vec![], + default: None, + condition: None, + }, + SchemaField { + name: "system".to_string(), + field_type: FieldType::String, + item_type: None, + description: Some(I18nValue { + zh: "系统提示词,可为空".to_string(), + en: "Optional system prompt".to_string(), + }), + enums: vec![], + default: Some("".to_string()), + condition: Some(Condition::Field(FieldCondition { + field: "system_prompt_enabled".to_string(), + constraint: ValueConstraint::Boolean(BooleanConstraint::from(true)), + required: false, + })), + }, + SchemaField { + name: "prompt".to_string(), + field_type: FieldType::String, + item_type: None, + description: Some(I18nValue { + zh: "用户提示词".to_string(), + en: "User prompt".to_string(), + }), + enums: vec![], + default: None, + condition: None, + }, + // TODO: Support response json schema format + // SchemaField { + // name: "response_format_enabled".to_string(), + // field_type: FieldType::Boolean, + // item_type: None, + // description: Some(I18nValue { + // zh: "是否自定义Json响应格式".to_string(), + // en: "Should the JSON response format be customized?".to_string(), + // }), + // enums: vec![], + // default: None, + // condition: None, + // }, + // SchemaField { + // name: "response_json_schema".to_string(), + // field_type: FieldType::String, + // item_type: None, + // description: Some(I18nValue { + // zh: "设置JsonSchema,以Json的方式响应".to_string(), + // en: "Response JsonSchema Format".to_string(), + // }), + // enums: vec![], + // default: None, + // condition: Some(Condition::Field(FieldCondition { + // field: "response_format_enabled".to_string(), + // constraint: ValueConstraint::Boolean(BooleanConstraint::from(true)), + // required: false, + // })), + // } + ] + } +} diff --git a/auto-engine-core/src/node/ai/gpt/runner.rs b/auto-engine-core/src/node/ai/gpt/runner.rs new file mode 100644 index 0000000..c437c02 --- /dev/null +++ b/auto-engine-core/src/node/ai/gpt/runner.rs @@ -0,0 +1,230 @@ +use crate::context::Context; +use crate::types::node::{NodeRunner, NodeRunnerControl, NodeRunnerController, NodeRunnerFactory}; +use async_openai::Client; +use async_openai::config::{Config, OpenAIConfig}; +use async_openai::types::chat::{ + ChatCompletionRequestAssistantMessage, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageArgs, + ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageArgs, + CreateChatCompletionRequestArgs, ResponseFormat, ResponseFormatJsonSchema, +}; +use secrecy::ExposeSecret; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::collections::HashMap; + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct GptParams { + #[serde(default)] + pub api_key: Option, + #[serde(default)] + pub base_url: Option, + #[serde(default = "GptParams::default_model")] + pub model: String, + pub prompt: String, + #[serde(default)] + pub system_prompt_enabled: bool, + #[serde(default)] + pub system: Option, + #[serde(default)] + pub response_format_enabled: bool, + pub response_json_schema: Option, +} + +impl GptParams { + fn default_model() -> String { + "gpt-4o-mini".to_string() + } +} + +#[derive(Default, Clone)] +pub struct GptRunner; + +impl GptRunner { + pub fn new() -> Self { + Self {} + } + + fn resolve_api_key(&self, provided: Option) -> Option { + match provided { + Some(key) if !key.trim().is_empty() => Some(key.trim().to_string()), + _ => std::env::var("OPENAI_API_KEY").ok(), + } + } + + async fn build_context_prompt(&self, ctx: &Context) -> String { + format!( + "You are an AI node in the workflow. You can view the output from the preceding node {:?}, formatted as `ctx.{{node_name}}.{{node_output_key}}`. Simply retrieve the value corresponding to the key. Unless otherwise specified, please output using the following response JSON format: `{{\"data\": value}}`.", + ctx.values().await + ) + } + + pub fn default_response_json_schema(&self) -> ResponseFormatJsonSchema { + ResponseFormatJsonSchema { + description: Some(String::from( + "The response data can be of any type, including primitive types, arrays, objects, etc.", + )), + name: "data".to_string(), + schema: Some(json!({ + "type": "object", + "properties": {} + })), + strict: None, + } + } +} + +#[async_trait::async_trait] +impl NodeRunner for GptRunner { + type ParamType = GptParams; + + async fn run( + &mut self, + ctx: &Context, + param: Self::ParamType, + ) -> Result>, String> { + if param.prompt.trim().is_empty() { + return Err("prompt cannot be empty".to_string()); + } + + let api_key = self.resolve_api_key(param.api_key).ok_or_else(|| { + "OpenAI API key is missing; provide api_key or set OPENAI_API_KEY".to_string() + })?; + + let mut config = OpenAIConfig::default().with_api_key(api_key); + if let Some(base_url) = param + .base_url + .as_ref() + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + { + config = config.with_api_base(base_url); + } + + if config.api_key().expose_secret().is_empty() { + return Err("OpenAI API key is empty; please configure it first".to_string()); + } + + let client = Client::with_config(config); + + let mut messages: Vec = vec![]; + + // context prompt + { + let context_prompt = self.build_context_prompt(ctx).await; + messages.push( + ChatCompletionRequestSystemMessageArgs::default() + .content(context_prompt) + .build() + .map_err(|e| format!("failed to build system message: {}", e))? + .into(), + ); + } + + // user set system prompt + if let Some(system_prompt) = param + .system + .as_ref() + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + && param.system_prompt_enabled + { + let system_message = ChatCompletionRequestSystemMessageArgs::default() + .content(system_prompt.to_string()) + .build() + .map_err(|e| format!("failed to build system message: {}", e))?; + messages.push(ChatCompletionRequestSystemMessage::from(system_message).into()); + } + + let user_message = ChatCompletionRequestUserMessageArgs::default() + .content(param.prompt) + .build() + .map_err(|e| format!("failed to build user message: {}", e))?; + messages.push(ChatCompletionRequestUserMessage::from(user_message).into()); + + let mut response_schema = self.default_response_json_schema(); + if let Some(json_schema) = param.response_json_schema.clone() + && param.response_format_enabled + { + response_schema.schema = + Some(serde_json::from_str(&json_schema).map_err(|e| { + format!("failed to parse response JSON schema: {}", e.to_string()) + })?); + } + + let request = CreateChatCompletionRequestArgs::default() + .max_tokens(512u32) + .model(param.model) + .response_format(ResponseFormat::JsonSchema { + json_schema: response_schema, + }) + .messages(messages) + .build() + .map_err(|e| format!("failed to build chat request: {}", e))?; + + let response = client + .chat() + .create(request) + .await + .map_err(|e| format!("OpenAI chat request failed: {}", e))?; + + let choice = response + .choices + .first() + .ok_or_else(|| "OpenAI chat request failed: no choices returned".to_string())?; + + let content = choice.message.content.clone().ok_or_else(|| { + log::error!("OpenAI chat request failed: no message content was returned"); + "OpenAI chat request failed: no message content was returned".to_string() + })?; + + let mut res = HashMap::new(); + + if content.trim().is_empty() { + res.insert("content".to_string(), serde_json::json!(content)); + } else { + let data: serde_json::Map = + serde_json::from_str(content.as_str()) + .inspect_err(|e| log::error!("chat message: {}, error: {}", content, e)) + .map_err(|e| format!("failed to parse chat message: {}", e))?; + + res.insert("data".to_string(), serde_json::json!(data.get("data"))); + } + + let usage = response.usage.unwrap_or_default(); + res.insert( + "prompt_tokens".to_string(), + serde_json::json!(usage.prompt_tokens), + ); + res.insert( + "completion_tokens".to_string(), + serde_json::json!(usage.completion_tokens), + ); + res.insert( + "total_tokens".to_string(), + serde_json::json!(usage.total_tokens), + ); + + Ok(Some(res)) + } +} + +pub struct GptRunnerFactory; + +impl GptRunnerFactory { + pub fn new() -> Self { + Self {} + } +} + +impl Default for GptRunnerFactory { + fn default() -> Self { + Self::new() + } +} + +impl NodeRunnerFactory for GptRunnerFactory { + fn create(&self) -> Box { + Box::new(NodeRunnerController::new(GptRunner::new())) + } +} diff --git a/auto-engine-core/src/node/data_aggregator/runner.rs b/auto-engine-core/src/node/data_aggregator/runner.rs index 9343b73..567c1d4 100644 --- a/auto-engine-core/src/node/data_aggregator/runner.rs +++ b/auto-engine-core/src/node/data_aggregator/runner.rs @@ -31,7 +31,7 @@ impl NodeRunner for DataAggregatorRunner { param: Self::ParamType, ) -> Result>, String> { { - let map = ctx.string_value.read().await; + let map = ctx.value.read().await; log::info!("map: {:?}", map.keys()); } let mut values = vec![]; diff --git a/auto-engine-core/src/node/ocr/node.rs b/auto-engine-core/src/node/ocr/node.rs index 31250eb..3814082 100644 --- a/auto-engine-core/src/node/ocr/node.rs +++ b/auto-engine-core/src/node/ocr/node.rs @@ -60,18 +60,6 @@ impl NodeDefine for OcrNode { default: None, condition: None, }, - SchemaField { - name: "confidence".to_string(), - field_type: FieldType::Number, - item_type: None, - description: Some(I18nValue { - zh: "对应文本的置信度得分".to_string(), - en: "Confidence score for the detected text".to_string(), - }), - enums: vec![], - default: None, - condition: None, - }, ] } diff --git a/auto-engine-core/src/node/ocr/runner.rs b/auto-engine-core/src/node/ocr/runner.rs index c14d3ec..38ef75a 100644 --- a/auto-engine-core/src/node/ocr/runner.rs +++ b/auto-engine-core/src/node/ocr/runner.rs @@ -44,7 +44,7 @@ impl NodeRunner for OcrRunner { ctx: &Context, param: Self::ParamType, ) -> Result>, String> { - let mut resource_path_prefix = ctx.resource_path().to_string_lossy().replace(r"\\?\", "").to_string(); + let mut resource_path_prefix = ctx.path_resource().to_string_lossy().replace(r"\\?\", "").to_string(); if resource_path_prefix != "" { resource_path_prefix = format!("{}/", resource_path_prefix); @@ -66,20 +66,24 @@ impl NodeRunner for OcrRunner { let results = ocr.predict(vec![image]).map_err(|e| e.to_string())?; let mut res = HashMap::new(); - for text_region in &results[0].text_regions { - if let Some((text, confidence)) = text_region.text_with_confidence() { - let text = if param.digits_only { - self.extract_digits(text) - } else { - text.to_string() - }; - res.insert("text".to_string(), serde_json::json!(text)); - res.insert("confidence".to_string(), serde_json::json!(confidence)); - return Ok(Some(res)); + let mut text_res = "".to_string(); + for ocr_res in &results { + for text_region in ocr_res.text_regions.clone() { + if let Some((text, _)) = text_region.text_with_confidence() { + let text = if param.digits_only { + self.extract_digits(text) + } else { + text.to_string() + }; + text_res.push_str(&text); + log::info!("======> {}", text); + log::info!("======> {}", text_res); + } } } + res.insert("text".to_string(), serde_json::json!(text_res)); - Ok(None) + Ok(Some(res)) } } diff --git a/auto-engine-core/src/node/start/node.rs b/auto-engine-core/src/node/start/node.rs index 4e3e306..66ebccb 100644 --- a/auto-engine-core/src/node/start/node.rs +++ b/auto-engine-core/src/node/start/node.rs @@ -2,6 +2,8 @@ use crate::types::field::{FieldType, SchemaField}; use crate::types::node::{I18nValue, NodeDefine}; use std::collections::HashMap; +pub const START_NODE_TYPE: &'static str = "Start"; + pub struct StartNode; impl StartNode { @@ -12,7 +14,7 @@ impl StartNode { impl NodeDefine for StartNode { fn action_type(&self) -> String { - String::from("Start") + String::from(START_NODE_TYPE) } fn name(&self) -> I18nValue { diff --git a/auto-engine-core/src/register/bus.rs b/auto-engine-core/src/register/bus.rs index b41d8e8..2977b2e 100644 --- a/auto-engine-core/src/register/bus.rs +++ b/auto-engine-core/src/register/bus.rs @@ -1,3 +1,5 @@ +use crate::node::ai::gpt::node::GptNode; +use crate::node::ai::gpt::runner::GptRunnerFactory; use crate::node::data_aggregator::node::DataAggregatorNode; use crate::node::data_aggregator::runner::DataAggregatorRunnerFactory; use crate::node::http::node::HttpNode; @@ -74,6 +76,7 @@ impl NodeRegisterBus { Box::new(DataAggregatorNode::new()), Box::new(DataAggregatorRunnerFactory::new()), ); + self.register(Box::new(GptNode::new()), Box::new(GptRunnerFactory::new())); self } diff --git a/auto-engine-core/src/types/conditions.rs b/auto-engine-core/src/types/conditions.rs index eadc784..324281a 100644 --- a/auto-engine-core/src/types/conditions.rs +++ b/auto-engine-core/src/types/conditions.rs @@ -24,7 +24,7 @@ impl Conditions { if let Some(key) = &self.exist && key != "" { - let values = ctx.string_value.read().await; + let values = ctx.value.read().await; if !values.contains_key(key) { log::info!("{} does not exist, {:?}", key, values); return Ok(ConditionResult { @@ -37,7 +37,7 @@ impl Conditions { if let Some(key) = &self.not_exist && key != "" { - let values = ctx.string_value.read().await; + let values = ctx.value.read().await; if values.contains_key(key) { log::info!("{} does not exist, {:?}", key, values); return Ok(ConditionResult { @@ -57,14 +57,14 @@ impl Conditions { log::error!("{}", err); return Ok(ConditionResult { pass: false, - reason: Some(err.to_string()), + reason: Some(format!("{} condition err: {}", condition, err)), }); } }; if condition.trim() == "" { return Ok(ConditionResult { pass: false, - reason: Some(condition.clone()), + reason: Some(format!("{} does not pass condition", condition)), }); } let result = evalexpr::eval_boolean(&condition) @@ -73,12 +73,12 @@ impl Conditions { log::info!("{} does not pass condition", condition); return Ok(ConditionResult { pass: false, - reason: Some(condition.clone()), + reason: Some(format!("{} does not pass condition", condition)), }); } } Ok(ConditionResult { - pass: false, + pass: true, reason: None, }) } diff --git a/auto-engine-core/src/types/field.rs b/auto-engine-core/src/types/field.rs index b7715cc..1d76986 100644 --- a/auto-engine-core/src/types/field.rs +++ b/auto-engine-core/src/types/field.rs @@ -26,6 +26,12 @@ pub struct BooleanConstraint { pub equals: bool, } +impl BooleanConstraint { + pub fn from(value: bool) -> Self { + Self { equals: value } + } +} + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct StringConstraint { #[serde(skip_serializing_if = "Option::is_none")] @@ -110,6 +116,7 @@ pub enum FieldType { Object, Image, File, + Password, } #[derive(Clone, Default, Serialize, Debug, Deserialize)] pub struct SchemaField { diff --git a/auto-engine-core/src/types/mod.rs b/auto-engine-core/src/types/mod.rs index 893ea5f..9f6216a 100644 --- a/auto-engine-core/src/types/mod.rs +++ b/auto-engine-core/src/types/mod.rs @@ -9,3 +9,4 @@ pub use keyboard::*; pub mod field; pub mod node; +pub mod workflow; diff --git a/auto-engine-core/src/types/node.rs b/auto-engine-core/src/types/node.rs index 2120a1a..18a1bb3 100644 --- a/auto-engine-core/src/types/node.rs +++ b/auto-engine-core/src/types/node.rs @@ -47,7 +47,8 @@ pub trait NodeRunnerControl: Send + Sync { ctx: &Context, node_name: &str, params: HashMap, - schema_field: Vec, + input_schema: Vec, + output_schema: Vec, ) -> Result>, String>; } @@ -71,11 +72,12 @@ where ctx: &Context, node_name: &str, params: HashMap, - schema_field: Vec, + input_schema: Vec, + output_schema: Vec, ) -> Result>, String> { let mut params = params; log::info!("params: {:?}, size: {}", params, params.len()); - for field in schema_field.iter() { + for field in input_schema.iter() { log::info!("field: {:?}", field); let default = field.default.clone().unwrap_or_default(); let mut val = params @@ -86,9 +88,10 @@ where if let serde_json::Value::String(s) = &val { let res = utils::parse_variables(ctx, s).await; val = match field.field_type { - FieldType::String | FieldType::Image | FieldType::File => { - serde_json::Value::String(res.clone()) - } + FieldType::String + | FieldType::Image + | FieldType::File + | FieldType::Password => serde_json::Value::String(res.clone()), FieldType::Number => match res.trim() { "" => serde_json::Value::Number(0.into()), @@ -111,7 +114,7 @@ where }, FieldType::Boolean => match res.to_lowercase().as_str() { "true" | "1" => serde_json::Value::Bool(true), - "false" | "0" => serde_json::Value::Bool(false), + "false" | "0" | "" => serde_json::Value::Bool(false), _ => { return Err(format!( "Field '{}' cannot be parsed as a boolean: {}", @@ -139,11 +142,18 @@ where if let Some(result) = self.runner.run(ctx, params).await? { for (name, value) in result.iter() { + let mut description = String::new(); + for field in &output_schema { + if field.name.as_str() == name.as_str() { + description = field.description.clone().unwrap_or(Default::default()).en; + break + } + } log::info!( "set value {}", format!("ctx.{}.{}", node_name, name).as_str() ); - ctx.set_value(format!("ctx.{}.{}", node_name, name).as_str(), value) + ctx.set_value(format!("ctx.{}.{}", node_name, name).as_str(), value, description) .await?; } return Ok(Some(result)); diff --git a/auto-engine-core/src/types/workflow.rs b/auto-engine-core/src/types/workflow.rs new file mode 100644 index 0000000..aa8a65f --- /dev/null +++ b/auto-engine-core/src/types/workflow.rs @@ -0,0 +1,12 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Serialize, Default, Clone)] +pub struct WorkflowMetaData { + pub id: Option, + pub name: Option, + pub description: Option, + pub version: Option, + pub author: Option, + pub tags: Vec, + pub group: Option, +} \ No newline at end of file diff --git a/auto-engine-core/src/utils/mod.rs b/auto-engine-core/src/utils/mod.rs index 2ff862e..14bcf98 100644 --- a/auto-engine-core/src/utils/mod.rs +++ b/auto-engine-core/src/utils/mod.rs @@ -8,7 +8,7 @@ pub static REGEX_PARSE_VARIABLES: Lazy = // String: the value name or key // bool: if need get value from Context pub async fn parse_variables(context: &Context, input: &str) -> String { - let ctx = context.string_value.read().await; + let ctx = context.value.read().await; REGEX_PARSE_VARIABLES .replace_all(input, |caps: ®ex::Captures| { @@ -16,7 +16,7 @@ pub async fn parse_variables(context: &Context, input: &str) -> String { let default = caps.get(2).map(|m| m.as_str()).unwrap_or(""); if let Some(value) = ctx.get(var_name) { - return serde_json::to_string(&value) + return serde_json::to_string(&value.value) .unwrap_or_default() .trim_matches('"') .to_string(); @@ -27,14 +27,14 @@ pub async fn parse_variables(context: &Context, input: &str) -> String { } pub async fn try_parse_variables(context: &Context, input: &str) -> Result { - let ctx = context.string_value.read().await; + let ctx = context.value.read().await; let mut err: Option = None; let result = REGEX_PARSE_VARIABLES.replace_all(input, |caps: ®ex::Captures| { let var_name = &caps[1]; let variable = match ctx.get(var_name) { - Some(value) => serde_json::to_string(value) + Some(value) => serde_json::to_string(&value.value) .unwrap_or_default() .trim_matches('"') .to_string(), diff --git a/auto-engine-core/src/workflow/runner.rs b/auto-engine-core/src/workflow/runner.rs index 13ae2cd..66c72d4 100644 --- a/auto-engine-core/src/workflow/runner.rs +++ b/auto-engine-core/src/workflow/runner.rs @@ -282,6 +282,7 @@ fn handle_nod( &node_name, run_input.clone(), node.input_schema().clone(), + node.output_schema(Default::default()).clone(), ) .await { @@ -316,6 +317,7 @@ fn handle_nod( &node_name, run_input.clone(), node.input_schema().clone(), + node.output_schema(Default::default()).clone(), ) .await {