From cdb56e2f0e487441d36b4caaea75f1c1f40437e5 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sat, 7 Feb 2026 20:43:11 +0800 Subject: [PATCH 01/29] Add Rust stubgen and integration tests --- rust/Cargo.toml | 2 +- rust/tvm-ffi-stubgen/Cargo.toml | 34 ++ rust/tvm-ffi-stubgen/src/cli.rs | 36 ++ rust/tvm-ffi-stubgen/src/ffi.rs | 220 +++++++++ rust/tvm-ffi-stubgen/src/generate.rs | 695 +++++++++++++++++++++++++++ rust/tvm-ffi-stubgen/src/lib.rs | 71 +++ rust/tvm-ffi-stubgen/src/main.rs | 24 + rust/tvm-ffi-stubgen/src/model.rs | 140 ++++++ rust/tvm-ffi-stubgen/src/schema.rs | 52 ++ rust/tvm-ffi-stubgen/src/tests.rs | 222 +++++++++ rust/tvm-ffi-stubgen/src/utils.rs | 77 +++ 11 files changed, 1572 insertions(+), 1 deletion(-) create mode 100644 rust/tvm-ffi-stubgen/Cargo.toml create mode 100644 rust/tvm-ffi-stubgen/src/cli.rs create mode 100644 rust/tvm-ffi-stubgen/src/ffi.rs create mode 100644 rust/tvm-ffi-stubgen/src/generate.rs create mode 100644 rust/tvm-ffi-stubgen/src/lib.rs create mode 100644 rust/tvm-ffi-stubgen/src/main.rs create mode 100644 rust/tvm-ffi-stubgen/src/model.rs create mode 100644 rust/tvm-ffi-stubgen/src/schema.rs create mode 100644 rust/tvm-ffi-stubgen/src/tests.rs create mode 100644 rust/tvm-ffi-stubgen/src/utils.rs diff --git a/rust/Cargo.toml b/rust/Cargo.toml index d92437681..78899c45f 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -16,6 +16,6 @@ # under the License. [workspace] -members = ["tvm-ffi", "tvm-ffi-sys", "tvm-ffi-macros"] +members = ["tvm-ffi", "tvm-ffi-sys", "tvm-ffi-macros", "tvm-ffi-stubgen"] resolver = "2" diff --git a/rust/tvm-ffi-stubgen/Cargo.toml b/rust/tvm-ffi-stubgen/Cargo.toml new file mode 100644 index 000000000..49a29788a --- /dev/null +++ b/rust/tvm-ffi-stubgen/Cargo.toml @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-ffi-stubgen" +description = "Rust stub generator for tvm-ffi" +version = "0.1.0" +edition = "2021" +license = "Apache-2.0" + +[[bin]] +name = "tvm-ffi-stubgen" +path = "src/main.rs" + +[dependencies] +clap = { version = "4.5", features = ["derive"] } +libloading = "0.8" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +tvm-ffi = { version = "0.1.0-alpha.0", path = "../tvm-ffi" } diff --git a/rust/tvm-ffi-stubgen/src/cli.rs b/rust/tvm-ffi-stubgen/src/cli.rs new file mode 100644 index 000000000..905270c90 --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/cli.rs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use clap::Parser; +use std::path::PathBuf; + +#[derive(Parser, Debug)] +#[command(name = "tvm-ffi-stubgen", about = "Generate Rust stubs from tvm-ffi metadata")] +pub struct Args { + #[arg(value_name = "OUT_DIR")] + pub out_dir: PathBuf, + #[arg(long = "dlls", value_delimiter = ';', num_args = 1..)] + pub dlls: Vec, + #[arg(long = "init-prefix")] + pub init_prefix: String, + #[arg(long = "init-crate")] + pub init_crate: String, + #[arg(long = "tvm-ffi-path")] + pub tvm_ffi_path: Option, + #[arg(long = "overwrite")] + pub overwrite: bool, +} diff --git a/rust/tvm-ffi-stubgen/src/ffi.rs b/rust/tvm-ffi-stubgen/src/ffi.rs new file mode 100644 index 000000000..56d03bfaa --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/ffi.rs @@ -0,0 +1,220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use libloading::Library; +use std::path::PathBuf; +use std::sync::LazyLock; +use tvm_ffi::function_internal::{ArgIntoRef, IntoArgHolder}; +use tvm_ffi::{Any, Error, Function, Result as FfiResult, String as FfiString, TYPE_ERROR}; +use tvm_ffi::tvm_ffi_sys::{ + TVMFFIAny, TVMFFIByteArray, TVMFFIGetTypeInfo, TVMFFIObjectHandle, TVMFFITypeIndex, + TVMFFITypeInfo, TVMFFITypeKeyToIndex, +}; + +#[repr(C)] +#[derive(Debug)] +pub(crate) struct Array { + handle: TVMFFIObjectHandle, +} + +extern "C" { + fn TVMFFIObjectIncRef(handle: TVMFFIObjectHandle) -> i32; + fn TVMFFIObjectDecRef(handle: TVMFFIObjectHandle) -> i32; +} + +impl Clone for Array { + fn clone(&self) -> Self { + unsafe { + TVMFFIObjectIncRef(self.handle); + } + Self { handle: self.handle } + } +} + +impl Drop for Array { + fn drop(&mut self) { + unsafe { + TVMFFIObjectDecRef(self.handle); + } + } +} + +unsafe impl tvm_ffi::type_traits::AnyCompatible for Array { + fn type_str() -> String { + "ffi.Array".to_string() + } + + unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { + data.type_index = TVMFFITypeIndex::kTVMFFIArray as i32; + data.small_str_len = 0; + data.data_union.v_obj = src.handle as *mut tvm_ffi::tvm_ffi_sys::TVMFFIObject; + } + + unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) { + data.type_index = TVMFFITypeIndex::kTVMFFIArray as i32; + data.small_str_len = 0; + data.data_union.v_obj = src.handle as *mut tvm_ffi::tvm_ffi_sys::TVMFFIObject; + } + + unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { + data.type_index == TVMFFITypeIndex::kTVMFFIArray as i32 + } + + unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { + let handle = data.data_union.v_obj as TVMFFIObjectHandle; + TVMFFIObjectIncRef(handle); + Self { handle } + } + + unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self { + let handle = data.data_union.v_obj as TVMFFIObjectHandle; + Self { handle } + } + + unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result { + if data.type_index == TVMFFITypeIndex::kTVMFFIArray as i32 { + Ok(Self::copy_from_any_view_after_check(data)) + } else { + Err(()) + } + } +} + +impl ArgIntoRef for Array { + type Target = Array; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl<'a> ArgIntoRef for &'a Array { + type Target = Array; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl IntoArgHolder for Array { + type Target = Array; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl<'a> IntoArgHolder for &'a Array { + type Target = &'a Array; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl TryFrom for Array { + type Error = Error; + fn try_from(value: Any) -> Result { + if let Some(ret) = value.try_as::() { + Ok(ret) + } else { + Err(Error::new(TYPE_ERROR, "Expected ffi.Array", "")) + } + } +} + +pub(crate) fn load_dlls(paths: &[PathBuf]) -> Result, Box> { + let mut libs = Vec::new(); + for path in paths { + let lib = unsafe { Library::new(path) }?; + libs.push(lib); + } + Ok(libs) +} + +pub(crate) fn list_global_function_names() -> FfiResult> { + let functor_func = Function::get_global("ffi.FunctionListGlobalNamesFunctor")?; + let functor_any = functor_func.call_tuple_with_len::<0, _>(())?; + let functor: Function = functor_any.try_into()?; + let count_any = functor.call_tuple_with_len::<1, _>((-1i64,))?; + let count: i64 = count_any.try_into()?; + let mut out = Vec::new(); + for idx in 0..count { + let name_any = functor.call_tuple_with_len::<1, _>((idx,))?; + let name: FfiString = name_any.try_into()?; + out.push(name.as_str().to_string()); + } + Ok(out) +} + +pub(crate) fn list_registered_type_keys() -> FfiResult> { + let get_keys = Function::get_global("ffi.GetRegisteredTypeKeys")?; + let keys_any = get_keys.call_tuple_with_len::<0, _>(())?; + let keys: Array = keys_any.try_into()?; + let size = array_size(&keys)?; + let mut out = Vec::new(); + for idx in 0..size { + let item = array_get_item(&keys, idx)?; + let key: FfiString = item.try_into()?; + out.push(key.as_str().to_string()); + } + Ok(out) +} + +pub(crate) fn array_size(array: &Array) -> FfiResult { + static FUNC: LazyLock = + LazyLock::new(|| Function::get_global("ffi.ArraySize").expect("ffi.ArraySize missing")); + let func = &*FUNC; + let size_any = func.call_tuple_with_len::<1, _>((array,))?; + size_any.try_into() +} + +pub(crate) fn array_get_item(array: &Array, index: i64) -> FfiResult { + static FUNC: LazyLock = LazyLock::new(|| { + Function::get_global("ffi.ArrayGetItem").expect("ffi.ArrayGetItem missing") + }); + let func = &*FUNC; + func.call_tuple_with_len::<2, _>((array, index)) +} + +pub(crate) fn get_type_info(type_key: &str) -> Option<&'static TVMFFITypeInfo> { + unsafe { + let key = TVMFFIByteArray::from_str(type_key); + let mut tindex = 0; + if TVMFFITypeKeyToIndex(&key, &mut tindex) != 0 { + return None; + } + let info = TVMFFIGetTypeInfo(tindex); + if info.is_null() { + None + } else { + Some(&*info) + } + } +} + +pub(crate) fn get_global_func_metadata(name: &str) -> FfiResult> { + let func = Function::get_global("ffi.GetGlobalFuncMetadata")?; + let name_arg = FfiString::from(name); + let meta_any = func.call_tuple_with_len::<1, _>((name_arg,))?; + let meta: FfiString = meta_any.try_into()?; + Ok(Some(meta.as_str().to_string())) +} + +pub(crate) fn byte_array_to_string_opt(value: &TVMFFIByteArray) -> Option { + if value.data.is_null() || value.size == 0 { + return None; + } + let slice = unsafe { std::slice::from_raw_parts(value.data, value.size) }; + Some(String::from_utf8_lossy(slice).to_string()) +} diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs new file mode 100644 index 000000000..b98102fbb --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -0,0 +1,695 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cli::Args; +use crate::ffi; +use crate::model::{FunctionGen, FunctionSig, MethodGen, ModuleNode, RustType, TypeGen}; +use crate::schema::{extract_type_schema, parse_type_schema, TypeSchema}; +use crate::utils; +use std::collections::BTreeMap; +use std::fmt::Write as _; + +const METHOD_FLAG_STATIC: i64 = 1 << 2; + +pub(crate) fn build_type_map(type_keys: &[String], prefix: &str) -> BTreeMap { + let mut map = BTreeMap::new(); + for key in type_keys { + let (mods, name) = split_name(key, prefix); + let rust_name = sanitize_ident(&name, IdentStyle::Type); + let module_path = module_path(&mods); + let path = if module_path.is_empty() { + format!("crate::types::{}", rust_name) + } else { + format!("crate::types::{}::{}", module_path, rust_name) + }; + map.insert(key.clone(), path); + } + map +} + +pub(crate) fn build_function_entries( + func_names: &[String], + type_map: &BTreeMap, + prefix: &str, +) -> tvm_ffi::Result, FunctionGen)>> { + let mut out = Vec::new(); + for full_name in func_names { + let metadata = ffi::get_global_func_metadata(full_name)?; + let schema = metadata + .and_then(|meta| extract_type_schema(&meta)) + .and_then(|schema| parse_type_schema(&schema)); + let sig = build_function_sig(schema.as_ref(), type_map, None); + let (mods, name) = split_name(full_name, prefix); + let rust_name = sanitize_ident(&name, IdentStyle::Function); + out.push(( + mods, + FunctionGen { + full_name: full_name.clone(), + rust_name, + sig, + }, + )); + } + Ok(out) +} + +pub(crate) fn build_type_entries( + type_keys: &[String], + type_map: &BTreeMap, + prefix: &str, +) -> tvm_ffi::Result, TypeGen)>> { + let mut out = Vec::new(); + for key in type_keys { + let (mods, name) = split_name(key, prefix); + let rust_name = sanitize_ident(&name, IdentStyle::Type); + let mut methods = Vec::new(); + if let Some(info) = ffi::get_type_info(key) { + if info.num_methods > 0 && !info.methods.is_null() { + let method_slice = unsafe { + std::slice::from_raw_parts(info.methods, info.num_methods as usize) + }; + for method in method_slice { + let method_name = match ffi::byte_array_to_string_opt(&method.name) { + Some(name) => name, + None => continue, + }; + let rust_method_name = map_method_name(&method_name); + let is_static = (method.flags & METHOD_FLAG_STATIC) != 0; + let meta = ffi::byte_array_to_string_opt(&method.metadata); + let schema = meta + .as_deref() + .and_then(extract_type_schema) + .and_then(|s| parse_type_schema(&s)); + let sig = build_method_sig(schema.as_ref(), type_map, Some(key.as_str()), is_static); + let full_name = format!("{}.{}", key, method_name); + methods.push(MethodGen { + full_name, + rust_name: rust_method_name, + sig, + is_static, + }); + } + } + } + out.push(( + mods, + TypeGen { + type_key: key.clone(), + rust_name, + methods, + }, + )); + } + Ok(out) +} + +pub(crate) fn build_function_modules( + funcs: Vec<(Vec, FunctionGen)>, + _prefix: &str, +) -> ModuleNode { + let mut root = ModuleNode::default(); + for (mods, func) in funcs { + insert_function(&mut root, &mods, func); + } + root +} + +pub(crate) fn build_type_modules(types: Vec<(Vec, TypeGen)>, _prefix: &str) -> ModuleNode { + let mut root = ModuleNode::default(); + for (mods, ty) in types { + insert_type(&mut root, &mods, ty); + } + root +} + +fn build_function_sig( + schema: Option<&TypeSchema>, + type_map: &BTreeMap, + self_type_key: Option<&str>, +) -> FunctionSig { + match schema { + None => FunctionSig::packed(), + Some(schema) if schema.origin != "ffi.Function" => FunctionSig::packed(), + Some(schema) if schema.args.is_empty() => FunctionSig::packed(), + Some(schema) => { + let ret = rust_type_for_schema(&schema.args[0], type_map, self_type_key); + let args: Vec = schema.args[1..] + .iter() + .map(|arg| rust_type_for_schema(arg, type_map, self_type_key)) + .collect(); + FunctionSig::from_types(args, ret) + } + } +} + +fn build_method_sig( + schema: Option<&TypeSchema>, + type_map: &BTreeMap, + self_type_key: Option<&str>, + is_static: bool, +) -> FunctionSig { + if !is_static { + return FunctionSig::packed(); + } + build_function_sig(schema, type_map, self_type_key) +} + +fn rust_type_for_schema( + schema: &TypeSchema, + type_map: &BTreeMap, + _self_type_key: Option<&str>, +) -> RustType { + match schema.origin.as_str() { + "None" => RustType::supported("()"), + "bool" => RustType::supported("bool"), + "int" => RustType::supported("i64"), + "float" => RustType::supported("f64"), + "Device" => RustType::unsupported("tvm_ffi::Any"), + "DataType" => RustType::unsupported("tvm_ffi::Any"), + "ffi.String" | "std::string" | "const char*" | "ffi.SmallStr" => { + RustType::supported("tvm_ffi::String") + } + "ffi.Bytes" | "TVMFFIByteArray*" | "ffi.SmallBytes" => { + RustType::supported("tvm_ffi::Bytes") + } + "ffi.Function" => RustType::unsupported("tvm_ffi::Any"), + "ffi.Object" => RustType::unsupported("tvm_ffi::Any"), + "ffi.Tensor" | "DLTensor*" => RustType::unsupported("tvm_ffi::Any"), + "ffi.Shape" => RustType::unsupported("tvm_ffi::Any"), + "ffi.Module" => RustType::unsupported("tvm_ffi::Any"), + "Optional" => RustType::unsupported("tvm_ffi::Any"), + "Union" | "Variant" | "tuple" | "list" | "dict" | "ffi.Array" | "ffi.Map" | "Any" => { + RustType::unsupported("tvm_ffi::Any") + } + other => { + if let Some(_path) = type_map.get(other) { + RustType::unsupported("tvm_ffi::Any") + } else { + RustType::unsupported("tvm_ffi::Any") + } + } + } +} + +fn insert_function(root: &mut ModuleNode, mods: &[String], func: FunctionGen) { + let mut node = root; + for module in mods { + node = node + .children + .entry(module.clone()) + .or_insert_with(|| ModuleNode { + name: module.clone(), + ..ModuleNode::default() + }); + } + node.functions.push(func); +} + +fn insert_type(root: &mut ModuleNode, mods: &[String], ty: TypeGen) { + let mut node = root; + for module in mods { + node = node + .children + .entry(module.clone()) + .or_insert_with(|| ModuleNode { + name: module.clone(), + ..ModuleNode::default() + }); + } + node.types.push(ty); +} + +fn split_name(full_name: &str, prefix: &str) -> (Vec, String) { + let mut remainder = full_name; + if !prefix.is_empty() && remainder.starts_with(prefix) { + remainder = &remainder[prefix.len()..]; + } else if !prefix.is_empty() && remainder.starts_with(prefix.trim_end_matches('.')) { + remainder = &remainder[prefix.trim_end_matches('.').len()..]; + remainder = remainder.trim_start_matches('.'); + } + let parts: Vec<&str> = remainder.split('.').filter(|p| !p.is_empty()).collect(); + if parts.is_empty() { + return (Vec::new(), "ffi".to_string()); + } + if parts.len() == 1 { + return (Vec::new(), parts[0].to_string()); + } + let mut mods = Vec::new(); + for part in &parts[..parts.len() - 1] { + mods.push(sanitize_ident(part, IdentStyle::Module)); + } + (mods, parts[parts.len() - 1].to_string()) +} + +fn module_path(mods: &[String]) -> String { + if mods.is_empty() { + return String::new(); + } + mods.join("::") +} + +pub(crate) fn render_cargo_toml( + args: &Args, + _type_map: &BTreeMap, +) -> Result> { + let tvm_ffi_path = match &args.tvm_ffi_path { + Some(path) => path.clone(), + None => utils::default_tvm_ffi_path()?, + }; + let rel_path = utils::relative_path(&args.out_dir, &tvm_ffi_path); + let mut out = String::new(); + writeln!( + &mut out, + "[package]\nname = \"{}\"\nversion = \"0.1.0\"\nedition = \"2021\"\n", + args.init_crate + )?; + writeln!( + &mut out, + "[dependencies]\ntvm-ffi = {{ path = \"{}\" }}\n", + rel_path.display() + )?; + Ok(out) +} + +pub(crate) fn render_lib_rs() -> String { + let mut out = String::new(); + out.push_str("pub mod functions;\n"); + out.push_str("pub mod types;\n\n"); + out.push_str("pub use functions::*;\n"); + out.push_str("pub use types::*;\n\n"); + out.push_str("pub fn load_library(path: &str) -> tvm_ffi::Result {\n"); + out.push_str(" tvm_ffi::Module::load_from_file(path)\n"); + out.push_str("}\n"); + out +} + +pub(crate) fn render_functions_rs(root: &ModuleNode) -> String { + let mut out = String::new(); + out.push_str("use std::sync::LazyLock;\n"); + out.push_str("use tvm_ffi::{Any, AnyView, Function, Result};\n\n"); + render_function_module(&mut out, root, 0); + out +} + +pub(crate) fn render_types_rs(root: &ModuleNode) -> String { + let mut out = String::new(); + out.push_str("use std::sync::LazyLock;\n"); + out.push_str("use tvm_ffi::object::ObjectRef;\n"); + out.push_str("use tvm_ffi::{Any, AnyView, Function, Result};\n\n"); + render_type_module(&mut out, root, 0); + out +} + +fn render_function_module(out: &mut String, node: &ModuleNode, indent: usize) { + let indent_str = " ".repeat(indent); + for func in &node.functions { + render_function(out, func, indent); + } + for child in node.children.values() { + writeln!(out, "{}pub mod {} {{", indent_str, child.name).ok(); + render_function_module(out, child, indent + 4); + writeln!(out, "{}}}", indent_str).ok(); + } +} + +fn render_type_module(out: &mut String, node: &ModuleNode, indent: usize) { + let indent_str = " ".repeat(indent); + for ty in &node.types { + render_type(out, ty, indent); + } + for child in node.children.values() { + writeln!(out, "{}pub mod {} {{", indent_str, child.name).ok(); + render_type_module(out, child, indent + 4); + writeln!(out, "{}}}", indent_str).ok(); + } +} + +fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { + let indent_str = " ".repeat(indent); + let static_name = static_ident("FUNC", &func.full_name); + writeln!( + out, + "{}static {}: LazyLock = LazyLock::new(|| Function::get_global(\"{}\").expect(\"missing global function\"));", + indent_str, static_name, func.full_name + ) + .ok(); + if func.sig.packed { + writeln!( + out, + "{}pub fn {}(args: &[Any]) -> Result {{", + indent_str, func.rust_name + ) + .ok(); + writeln!( + out, + "{} let func = &*{};", + indent_str, static_name + ) + .ok(); + writeln!( + out, + "{} let views: Vec> = args.iter().map(AnyView::from).collect();", + indent_str + ) + .ok(); + writeln!(out, "{} func.call_packed(&views)", indent_str).ok(); + writeln!(out, "{}}}", indent_str).ok(); + writeln!(out).ok(); + return; + } + let args = render_args(&func.sig.args); + writeln!( + out, + "{}pub fn {}({}) -> Result<{}> {{", + indent_str, + func.rust_name, + args, + func.sig.ret.name + ) + .ok(); + writeln!( + out, + "{} let func = &*{};", + indent_str, static_name + ) + .ok(); + writeln!( + out, + "{} let typed = tvm_ffi::into_typed_fn!(func.clone(), Fn({}) -> Result<{}>);", + indent_str, + render_type_list(&func.sig.args), + func.sig.ret.typed_name() + ) + .ok(); + let call_expr = format!("typed({})", render_call_args_typed(&func.sig.args)); + writeln!( + out, + "{} {}", + indent_str, + func.sig.ret.wrap_return(&call_expr) + ) + .ok(); + writeln!(out, "{}}}", indent_str).ok(); + writeln!(out).ok(); +} + +fn render_type(out: &mut String, ty: &TypeGen, indent: usize) { + let indent_str = " ".repeat(indent); + writeln!(out, "{}#[derive(Clone)]", indent_str).ok(); + writeln!(out, "{}pub struct {} {{", indent_str, ty.rust_name).ok(); + writeln!(out, "{} inner: ObjectRef,", indent_str).ok(); + writeln!(out, "{}}}\n", indent_str).ok(); + + writeln!(out, "{}impl {} {{", indent_str, ty.rust_name).ok(); + writeln!( + out, + "{} pub fn from_object(inner: ObjectRef) -> Self {{", + indent_str + ) + .ok(); + writeln!(out, "{} Self {{ inner }}", indent_str).ok(); + writeln!(out, "{} }}", indent_str).ok(); + writeln!( + out, + "{} pub fn as_object_ref(&self) -> &ObjectRef {{", + indent_str + ) + .ok(); + writeln!(out, "{} &self.inner", indent_str).ok(); + writeln!(out, "{} }}", indent_str).ok(); + writeln!(out, "{}}}\n", indent_str).ok(); + + writeln!(out, "{}impl From for {} {{", indent_str, ty.rust_name).ok(); + writeln!( + out, + "{} fn from(inner: ObjectRef) -> Self {{", + indent_str + ) + .ok(); + writeln!(out, "{} Self {{ inner }}", indent_str).ok(); + writeln!(out, "{} }}", indent_str).ok(); + writeln!(out, "{}}}\n", indent_str).ok(); + + for method in &ty.methods { + render_method_static(out, ty, method, indent); + } + + writeln!(out, "{}impl {} {{", indent_str, ty.rust_name).ok(); + for method in &ty.methods { + render_method(out, ty, method, indent + 4); + } + writeln!(out, "{}}}\n", indent_str).ok(); +} + +fn render_method_static(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usize) { + let indent_str = " ".repeat(indent); + let static_name = static_ident("METHOD", &format!("{}::{}", ty.type_key, method.rust_name)); + writeln!( + out, + "{}static {}: LazyLock = LazyLock::new(|| Function::get_global(\"{}\").expect(\"missing method\"));", + indent_str, static_name, method.full_name + ) + .ok(); +} + +fn render_method(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usize) { + let indent_str = " ".repeat(indent); + let static_name = static_ident("METHOD", &format!("{}::{}", ty.type_key, method.rust_name)); + let self_prefix = if method.is_static { "" } else { "&self" }; + if method.sig.packed { + if method.is_static { + writeln!( + out, + "{}pub fn {}(args: &[Any]) -> Result {{", + indent_str, method.rust_name + ) + .ok(); + writeln!( + out, + "{} let func = &*{};", + indent_str, static_name + ) + .ok(); + writeln!( + out, + "{} let views: Vec> = args.iter().map(AnyView::from).collect();", + indent_str + ) + .ok(); + writeln!(out, "{} func.call_packed(&views)", indent_str).ok(); + writeln!(out, "{}}}", indent_str).ok(); + return; + } + writeln!( + out, + "{}pub fn {}(&self, args: &[Any]) -> Result {{", + indent_str, method.rust_name + ) + .ok(); + writeln!( + out, + "{} let func = &*{};", + indent_str, static_name + ) + .ok(); + writeln!( + out, + "{} let mut views: Vec> = Vec::with_capacity(args.len() + 1);", + indent_str + ) + .ok(); + writeln!(out, "{} views.push(AnyView::from(self.as_object_ref()));", indent_str).ok(); + writeln!( + out, + "{} views.extend(args.iter().map(AnyView::from));", + indent_str + ) + .ok(); + writeln!(out, "{} func.call_packed(&views)", indent_str).ok(); + writeln!(out, "{}}}", indent_str).ok(); + return; + } + + let args = render_args(&method.sig.args); + let signature = if method.is_static { + format!("{}({})", method.rust_name, args) + } else if args.is_empty() { + format!("{}({})", method.rust_name, self_prefix) + } else { + format!("{}({}, {})", method.rust_name, self_prefix, args) + }; + writeln!( + out, + "{}pub fn {} -> Result<{}> {{", + indent_str, + signature, + method.sig.ret.name + ) + .ok(); + writeln!( + out, + "{} let func = &*{};", + indent_str, static_name + ) + .ok(); + let type_list = if method.is_static { + render_type_list(&method.sig.args) + } else { + let mut types = vec!["tvm_ffi::object::ObjectRef".to_string()]; + types.extend(method.sig.args.iter().map(|arg| arg.typed_name().to_string())); + types.join(", ") + }; + writeln!( + out, + "{} let typed = tvm_ffi::into_typed_fn!(func.clone(), Fn({}) -> Result<{}>);", + indent_str, + type_list, + method.sig.ret.name + ) + .ok(); + let call_expr = format!("typed({})", render_method_call_args(method)); + writeln!(out, "{} {}", indent_str, method.sig.ret.wrap_return(&call_expr)).ok(); + writeln!(out, "{}}}", indent_str).ok(); +} + +fn render_args(args: &[RustType]) -> String { + let mut out = Vec::new(); + for (i, arg) in args.iter().enumerate() { + out.push(format!("_{}: {}", i, arg.name)); + } + out.join(", ") +} + +fn render_type_list(args: &[RustType]) -> String { + args.iter().map(|arg| arg.typed_name().to_string()).collect::>().join(", ") +} + +fn render_call_args_typed(args: &[RustType]) -> String { + let mut out = Vec::new(); + for (i, arg) in args.iter().enumerate() { + out.push(arg.call_expr(&format!("_{}", i))); + } + out.join(", ") +} + +fn render_method_call_args(method: &MethodGen) -> String { + if method.is_static { + return render_call_args_typed(&method.sig.args); + } + let mut out = Vec::new(); + let self_type = RustType::object_wrapper("Self"); + out.push(self_type.call_expr("self")); + for (i, arg) in method.sig.args.iter().enumerate() { + out.push(arg.call_expr(&format!("_{}", i))); + } + out.join(", ") +} + +fn map_method_name(name: &str) -> String { + if name == "__ffi_init__" { + return "c_ffi_init".to_string(); + } + sanitize_ident(name, IdentStyle::Function) +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum IdentStyle { + Function, + Module, + Type, +} + +fn sanitize_ident(name: &str, style: IdentStyle) -> String { + let mut out = String::new(); + let mut prev_underscore = false; + for (i, ch) in name.chars().enumerate() { + let mut c = ch; + if style == IdentStyle::Module { + if ch.is_ascii_uppercase() { + if i > 0 && !prev_underscore { + out.push('_'); + } + c = ch.to_ascii_lowercase(); + } + } + if c.is_ascii_alphanumeric() || c == '_' { + out.push(c); + prev_underscore = c == '_'; + } else { + out.push('_'); + prev_underscore = true; + } + } + if out.is_empty() { + out.push('_'); + } + if out.chars().next().unwrap().is_ascii_digit() { + out.insert(0, '_'); + } + const KEYWORDS: &[&str] = &[ + "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", + "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", + "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", + "true", "type", "unsafe", "use", "where", "while", "async", "await", "dyn", + ]; + if KEYWORDS.contains(&out.as_str()) { + out.push('_'); + } + match style { + IdentStyle::Type => to_pascal_case(&out), + _ => out, + } +} + +fn to_pascal_case(name: &str) -> String { + let mut out = String::new(); + let mut uppercase = true; + for ch in name.chars() { + if ch == '_' { + uppercase = true; + continue; + } + if uppercase { + out.extend(ch.to_uppercase()); + uppercase = false; + } else { + out.push(ch); + } + } + if out.is_empty() { + "Type".to_string() + } else { + out + } +} + +fn static_ident(prefix: &str, full_name: &str) -> String { + let mut out = String::new(); + out.push_str(prefix); + out.push('_'); + for ch in full_name.chars() { + if ch.is_ascii_alphanumeric() { + out.push(ch.to_ascii_uppercase()); + } else { + out.push('_'); + } + } + if out.chars().next().unwrap().is_ascii_digit() { + out.insert(0, '_'); + } + out +} diff --git a/rust/tvm-ffi-stubgen/src/lib.rs b/rust/tvm-ffi-stubgen/src/lib.rs new file mode 100644 index 000000000..901a0e4dc --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/lib.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod cli; +mod ffi; +mod generate; +mod model; +mod schema; +mod utils; + +#[cfg(test)] +mod tests; + +pub use cli::Args; + +pub fn run(args: Args) -> Result<(), Box> { + let prefix = utils::normalize_prefix(&args.init_prefix); + if args.dlls.is_empty() { + return Err("--dlls is required".into()); + } + utils::ensure_out_dir(&args.out_dir, args.overwrite)?; + + let _loaded_libs = ffi::load_dlls(&args.dlls)?; + + let global_funcs = ffi::list_global_function_names()?; + let filtered_funcs: Vec = global_funcs + .into_iter() + .filter(|name| name.starts_with(&prefix)) + .collect(); + + let type_keys = ffi::list_registered_type_keys()?; + let filtered_types: Vec = type_keys + .into_iter() + .filter(|name| name.starts_with(&prefix)) + .collect(); + + let type_map = generate::build_type_map(&filtered_types, &prefix); + let functions = generate::build_function_entries(&filtered_funcs, &type_map, &prefix)?; + let types = generate::build_type_entries(&filtered_types, &type_map, &prefix)?; + + let functions_root = generate::build_function_modules(functions, &prefix); + let types_root = generate::build_type_modules(types, &prefix); + + let cargo_toml = generate::render_cargo_toml(&args, &type_map)?; + let lib_rs = generate::render_lib_rs(); + let functions_rs = generate::render_functions_rs(&functions_root); + let types_rs = generate::render_types_rs(&types_root); + + let src_dir = args.out_dir.join("src"); + std::fs::create_dir_all(&src_dir)?; + std::fs::write(args.out_dir.join("Cargo.toml"), cargo_toml)?; + std::fs::write(src_dir.join("lib.rs"), lib_rs)?; + std::fs::write(src_dir.join("functions.rs"), functions_rs)?; + std::fs::write(src_dir.join("types.rs"), types_rs)?; + + Ok(()) +} diff --git a/rust/tvm-ffi-stubgen/src/main.rs b/rust/tvm-ffi-stubgen/src/main.rs new file mode 100644 index 000000000..b6568c8a2 --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/main.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use clap::Parser; +use tvm_ffi_stubgen::{run, Args}; + +fn main() -> Result<(), Box> { + let args = Args::parse(); + run(args) +} diff --git a/rust/tvm-ffi-stubgen/src/model.rs b/rust/tvm-ffi-stubgen/src/model.rs new file mode 100644 index 000000000..0997a73e6 --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/model.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::BTreeMap; + +#[derive(Debug, Clone)] +pub(crate) struct RustType { + pub(crate) name: String, + pub(crate) supported: bool, + pub(crate) kind: RustTypeKind, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum RustTypeKind { + Plain, + ObjectWrapper, +} + +#[derive(Debug, Clone)] +pub(crate) struct FunctionSig { + pub(crate) args: Vec, + pub(crate) ret: RustType, + pub(crate) packed: bool, +} + +#[derive(Debug, Clone)] +pub(crate) struct FunctionGen { + pub(crate) full_name: String, + pub(crate) rust_name: String, + pub(crate) sig: FunctionSig, +} + +#[derive(Debug, Clone)] +pub(crate) struct MethodGen { + pub(crate) full_name: String, + pub(crate) rust_name: String, + pub(crate) sig: FunctionSig, + pub(crate) is_static: bool, +} + +#[derive(Debug, Clone)] +pub(crate) struct TypeGen { + pub(crate) type_key: String, + pub(crate) rust_name: String, + pub(crate) methods: Vec, +} + +#[derive(Debug, Default)] +pub(crate) struct ModuleNode { + pub(crate) name: String, + pub(crate) functions: Vec, + pub(crate) types: Vec, + pub(crate) children: BTreeMap, +} + +impl FunctionSig { + pub(crate) fn packed() -> Self { + Self { + args: Vec::new(), + ret: RustType::unsupported("tvm_ffi::Any"), + packed: true, + } + } + + pub(crate) fn from_types(args: Vec, ret: RustType) -> Self { + let typed = args.iter().all(|arg| arg.supported) && ret.supported; + Self { + args, + ret, + packed: !typed, + } + } +} + +impl RustType { + pub(crate) fn supported(name: &str) -> Self { + Self { + name: name.to_string(), + supported: true, + kind: RustTypeKind::Plain, + } + } + + pub(crate) fn unsupported(name: &str) -> Self { + Self { + name: name.to_string(), + supported: false, + kind: RustTypeKind::Plain, + } + } + + pub(crate) fn object_wrapper(name: &str) -> Self { + Self { + name: name.to_string(), + supported: true, + kind: RustTypeKind::ObjectWrapper, + } + } + + pub(crate) fn typed_name(&self) -> &str { + match self.kind { + RustTypeKind::Plain => &self.name, + RustTypeKind::ObjectWrapper => "tvm_ffi::object::ObjectRef", + } + } + + pub(crate) fn call_expr(&self, arg_name: &str) -> String { + match self.kind { + RustTypeKind::Plain => arg_name.to_string(), + RustTypeKind::ObjectWrapper => format!("{}.as_object_ref().clone()", arg_name), + } + } + + pub(crate) fn wrap_return(&self, expr: &str) -> String { + match self.kind { + RustTypeKind::Plain => expr.to_string(), + RustTypeKind::ObjectWrapper => { + if self.name == "Self" { + format!("{}.map(Self::from)", expr) + } else { + format!("{}.map({}::from)", expr, self.name) + } + } + } + } +} diff --git a/rust/tvm-ffi-stubgen/src/schema.rs b/rust/tvm-ffi-stubgen/src/schema.rs new file mode 100644 index 000000000..6cec7e771 --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/schema.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use serde::Deserialize; + +#[derive(Debug, Clone)] +pub(crate) struct TypeSchema { + pub(crate) origin: String, + pub(crate) args: Vec, +} + +#[derive(Deserialize)] +struct TypeSchemaJson { + #[serde(rename = "type")] + ty: String, + #[serde(default)] + args: Vec, +} + +pub(crate) fn extract_type_schema(metadata: &str) -> Option { + let value: serde_json::Value = serde_json::from_str(metadata).ok()?; + value + .get("type_schema") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) +} + +pub(crate) fn parse_type_schema(schema: &str) -> Option { + let json: TypeSchemaJson = serde_json::from_str(schema).ok()?; + Some(parse_type_schema_json(&json)) +} + +fn parse_type_schema_json(json: &TypeSchemaJson) -> TypeSchema { + TypeSchema { + origin: json.ty.clone(), + args: json.args.iter().map(parse_type_schema_json).collect(), + } +} diff --git a/rust/tvm-ffi-stubgen/src/tests.rs b/rust/tvm-ffi-stubgen/src/tests.rs new file mode 100644 index 000000000..dfe227cc8 --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/tests.rs @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{run, Args}; +use crate::utils; +use std::env; +use std::fs; +use std::path::{Path, PathBuf}; +use std::process::Command; + +#[test] +fn stubgen_tvm_ffi_testing() { + let dlls = resolve_testing_dlls().expect("unable to locate tvm_ffi testing libraries"); + let out_dir = unique_temp_dir("tvm_ffi_stubgen_test"); + let args = Args { + out_dir: out_dir.clone(), + dlls: dlls.clone(), + init_prefix: "testing".to_string(), + init_crate: "tvm_ffi_testing_stub".to_string(), + tvm_ffi_path: Some(utils::default_tvm_ffi_path().expect("tvm-ffi path")), + overwrite: true, + }; + + run(args).expect("stubgen run"); + + let cargo_toml = out_dir.join("Cargo.toml"); + let functions_rs = out_dir.join("src").join("functions.rs"); + assert!(cargo_toml.exists(), "Cargo.toml not generated"); + assert!(functions_rs.exists(), "functions.rs not generated"); + + let functions_body = fs::read_to_string(functions_rs).expect("read functions.rs"); + assert!(functions_body.contains("add_one"), "missing add_one stub"); + + write_integration_test(&out_dir).expect("write integration test"); + run_generated_tests(&out_dir, &dlls).expect("run generated tests"); +} + +fn resolve_testing_dlls() -> Result, String> { + if let Ok(value) = env::var("TVM_FFI_TESTING_DLLS") { + let dlls = split_paths(&value); + if !dlls.is_empty() { + return Ok(dlls); + } + } + + if let Ok(dir) = env::var("TVM_FFI_TESTING_LIB_DIR") { + let dir = PathBuf::from(dir); + if let Some(dlls) = dlls_from_dir(&dir) { + return Ok(dlls); + } + } + + let manifest = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let tvm_ffi_root = manifest.join("../tvm-ffi"); + let mut candidates = vec![tvm_ffi_root.join("build/lib")]; + + if let Ok(venv) = env::var("VIRTUAL_ENV") { + if let Some(path) = find_venv_lib_dir(Path::new(&venv)) { + candidates.push(path); + } + } + + for dir in candidates { + if let Some(dlls) = dlls_from_dir(&dir) { + return Ok(dlls); + } + } + + Err("set TVM_FFI_TESTING_DLLS or TVM_FFI_TESTING_LIB_DIR to run tests".to_string()) +} + +fn dlls_from_dir(dir: &Path) -> Option> { + let tvm_ffi = dir.join(lib_filename("tvm_ffi")); + let tvm_ffi_testing = dir.join(lib_filename("tvm_ffi_testing")); + if tvm_ffi.exists() && tvm_ffi_testing.exists() { + Some(vec![tvm_ffi, tvm_ffi_testing]) + } else { + None + } +} + +fn lib_filename(name: &str) -> String { + if cfg!(target_os = "windows") { + format!("{}.dll", name) + } else if cfg!(target_os = "macos") { + format!("lib{}.dylib", name) + } else { + format!("lib{}.so", name) + } +} + +fn split_paths(value: &str) -> Vec { + let normalized = value.replace(';', ":"); + normalized + .split(':') + .filter(|item| !item.trim().is_empty()) + .map(PathBuf::from) + .collect() +} + +fn unique_temp_dir(prefix: &str) -> PathBuf { + let base = env::temp_dir(); + let pid = std::process::id(); + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or(0); + base.join(format!("{}_{}_{}", prefix, pid, nanos)) +} + +fn find_venv_lib_dir(venv: &Path) -> Option { + for lib_root in ["lib", "lib64"] { + let base = venv.join(lib_root); + let entries = fs::read_dir(&base).ok()?; + for entry in entries.flatten() { + let path = entry.path(); + let name = path.file_name()?.to_string_lossy(); + if !name.starts_with("python") { + continue; + } + let candidate = path.join("site-packages").join("tvm_ffi").join("lib"); + if candidate.exists() { + return Some(candidate); + } + } + } + None +} + +fn write_integration_test(out_dir: &Path) -> Result<(), Box> { + let tests_dir = out_dir.join("tests"); + fs::create_dir_all(&tests_dir)?; + let test_body = r#"use tvm_ffi_testing_stub::add_one; + +#[test] +fn add_one_roundtrip() { + let lib_dir = std::env::var("TVM_FFI_TESTING_LIB_DIR").expect("lib dir"); + let lib_path = format!("{}/{}", lib_dir, lib_filename("tvm_ffi_testing")); + tvm_ffi::Module::load_from_file(&lib_path).expect("load tvm_ffi_testing"); + let value = add_one(1).expect("call add_one"); + assert_eq!(value, 2); +} + +fn lib_filename(name: &str) -> String { + if cfg!(target_os = "windows") { + format!("{}.dll", name) + } else if cfg!(target_os = "macos") { + format!("lib{}.dylib", name) + } else { + format!("lib{}.so", name) + } +} +"#; + fs::write(tests_dir.join("integration.rs"), test_body)?; + Ok(()) +} + +fn run_generated_tests(out_dir: &Path, dlls: &[PathBuf]) -> Result<(), Box> { + let mut cmd = Command::new("cargo"); + cmd.arg("test") + .arg("--manifest-path") + .arg(out_dir.join("Cargo.toml")) + .current_dir(out_dir); + + let lib_dir = dlls + .get(0) + .and_then(|path| path.parent()) + .map(|path| path.to_path_buf()) + .ok_or("missing library directory")?; + + let ld_var = if cfg!(target_os = "windows") { + "PATH" + } else if cfg!(target_os = "macos") { + "DYLD_LIBRARY_PATH" + } else { + "LD_LIBRARY_PATH" + }; + + let current_ld = env::var(ld_var).unwrap_or_default(); + let separator = if ld_var == "PATH" { ";" } else { ":" }; + let new_ld = if current_ld.is_empty() { + lib_dir.to_string_lossy().to_string() + } else { + format!("{}{}{}", lib_dir.to_string_lossy(), separator, current_ld) + }; + cmd.env(ld_var, new_ld); + cmd.env("TVM_FFI_TESTING_LIB_DIR", lib_dir); + + let mut path_value = env::var("PATH").unwrap_or_default(); + if let Ok(venv) = env::var("VIRTUAL_ENV") { + let venv_bin = Path::new(&venv).join("bin"); + let venv_str = venv_bin.to_string_lossy(); + if !path_value.split(':').any(|item| item == venv_str) { + if !path_value.is_empty() { + path_value = format!("{}:{}", venv_str, path_value); + } else { + path_value = venv_str.to_string(); + } + } + } + cmd.env("PATH", path_value); + + let status = cmd.status()?; + if !status.success() { + return Err("generated crate tests failed".into()); + } + Ok(()) +} diff --git a/rust/tvm-ffi-stubgen/src/utils.rs b/rust/tvm-ffi-stubgen/src/utils.rs new file mode 100644 index 000000000..e20003f99 --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/utils.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fs; +use std::path::{Path, PathBuf}; + +pub(crate) fn normalize_prefix(prefix: &str) -> String { + if prefix.is_empty() { + return String::new(); + } + if prefix.ends_with('.') { + prefix.to_string() + } else { + format!("{}.", prefix) + } +} + +pub(crate) fn ensure_out_dir(out_dir: &Path, overwrite: bool) -> Result<(), Box> { + if out_dir.exists() { + let mut has_entries = false; + for entry in fs::read_dir(out_dir)? { + let entry = entry?; + if entry.file_name() != "." && entry.file_name() != ".." { + has_entries = true; + break; + } + } + if has_entries && !overwrite { + return Err("output directory is not empty (use --overwrite to proceed)".into()); + } + } else { + fs::create_dir_all(out_dir)?; + } + Ok(()) +} + +pub(crate) fn default_tvm_ffi_path() -> Result> { + let current = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let candidate = current.join("../tvm-ffi"); + if candidate.exists() { + return Ok(candidate); + } + Err("unable to locate tvm-ffi path (use --tvm-ffi-path)".into()) +} + +pub(crate) fn relative_path(from: &Path, to: &Path) -> PathBuf { + let from = from.canonicalize().unwrap_or_else(|_| from.to_path_buf()); + let to = to.canonicalize().unwrap_or_else(|_| to.to_path_buf()); + let from_components: Vec<_> = from.components().collect(); + let to_components: Vec<_> = to.components().collect(); + let mut i = 0; + while i < from_components.len() && i < to_components.len() && from_components[i] == to_components[i] { + i += 1; + } + let mut out = PathBuf::new(); + for _ in i..from_components.len() { + out.push(".."); + } + for comp in &to_components[i..] { + out.push(comp.as_os_str()); + } + out +} From 9c5306484c20ea5d53186e279f2b1d14d3b37f32 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sat, 7 Feb 2026 22:28:50 +0800 Subject: [PATCH 02/29] Refine stubgen tests and build setup --- rust/tvm-ffi-stubgen/build.rs | 63 ++++++++ rust/tvm-ffi-stubgen/src/cli.rs | 5 +- rust/tvm-ffi-stubgen/src/ffi.rs | 8 +- rust/tvm-ffi-stubgen/src/generate.rs | 103 ++++++------ rust/tvm-ffi-stubgen/src/lib.rs | 3 - rust/tvm-ffi-stubgen/src/tests.rs | 222 -------------------------- rust/tvm-ffi-stubgen/src/utils.rs | 10 +- rust/tvm-ffi-stubgen/tests/stubgen.rs | 157 ++++++++++++++++++ 8 files changed, 284 insertions(+), 287 deletions(-) create mode 100644 rust/tvm-ffi-stubgen/build.rs delete mode 100644 rust/tvm-ffi-stubgen/src/tests.rs create mode 100644 rust/tvm-ffi-stubgen/tests/stubgen.rs diff --git a/rust/tvm-ffi-stubgen/build.rs b/rust/tvm-ffi-stubgen/build.rs new file mode 100644 index 000000000..3026cc71c --- /dev/null +++ b/rust/tvm-ffi-stubgen/build.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::env; +use std::process::Command; + +fn main() { + let lib_dir = tvm_ffi_libdir(); + let target_os = env::var("CARGO_CFG_TARGET_OS").unwrap_or_default(); + + if target_os == "linux" || target_os == "macos" { + println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_dir); + } + + let ld_var = match target_os.as_str() { + "windows" => "PATH", + "macos" => "DYLD_LIBRARY_PATH", + "linux" => "LD_LIBRARY_PATH", + _ => "", + }; + if !ld_var.is_empty() { + let current = env::var(ld_var).unwrap_or_default(); + let separator = if ld_var == "PATH" { ";" } else { ":" }; + let value = if current.is_empty() { + lib_dir.clone() + } else { + format!("{}{}{}", lib_dir, separator, current) + }; + println!("cargo:rustc-env={}={}", ld_var, value); + } +} + +fn tvm_ffi_libdir() -> String { + let output = Command::new("tvm-ffi-config") + .arg("--libdir") + .output() + .expect("tvm-ffi-config --libdir"); + if !output.status.success() { + panic!("tvm-ffi-config --libdir failed"); + } + let lib_dir = String::from_utf8(output.stdout) + .expect("tvm-ffi-config output") + .trim() + .to_string(); + if lib_dir.is_empty() { + panic!("tvm-ffi-config returned empty libdir"); + } + lib_dir +} diff --git a/rust/tvm-ffi-stubgen/src/cli.rs b/rust/tvm-ffi-stubgen/src/cli.rs index 905270c90..683b507dc 100644 --- a/rust/tvm-ffi-stubgen/src/cli.rs +++ b/rust/tvm-ffi-stubgen/src/cli.rs @@ -19,7 +19,10 @@ use clap::Parser; use std::path::PathBuf; #[derive(Parser, Debug)] -#[command(name = "tvm-ffi-stubgen", about = "Generate Rust stubs from tvm-ffi metadata")] +#[command( + name = "tvm-ffi-stubgen", + about = "Generate Rust stubs from tvm-ffi metadata" +)] pub struct Args { #[arg(value_name = "OUT_DIR")] pub out_dir: PathBuf, diff --git a/rust/tvm-ffi-stubgen/src/ffi.rs b/rust/tvm-ffi-stubgen/src/ffi.rs index 56d03bfaa..dda322033 100644 --- a/rust/tvm-ffi-stubgen/src/ffi.rs +++ b/rust/tvm-ffi-stubgen/src/ffi.rs @@ -19,11 +19,11 @@ use libloading::Library; use std::path::PathBuf; use std::sync::LazyLock; use tvm_ffi::function_internal::{ArgIntoRef, IntoArgHolder}; -use tvm_ffi::{Any, Error, Function, Result as FfiResult, String as FfiString, TYPE_ERROR}; use tvm_ffi::tvm_ffi_sys::{ TVMFFIAny, TVMFFIByteArray, TVMFFIGetTypeInfo, TVMFFIObjectHandle, TVMFFITypeIndex, TVMFFITypeInfo, TVMFFITypeKeyToIndex, }; +use tvm_ffi::{Any, Error, Function, Result as FfiResult, String as FfiString, TYPE_ERROR}; #[repr(C)] #[derive(Debug)] @@ -41,7 +41,9 @@ impl Clone for Array { unsafe { TVMFFIObjectIncRef(self.handle); } - Self { handle: self.handle } + Self { + handle: self.handle, + } } } @@ -101,7 +103,7 @@ impl ArgIntoRef for Array { } } -impl<'a> ArgIntoRef for &'a Array { +impl ArgIntoRef for &Array { type Target = Array; fn to_ref(&self) -> &Self::Target { self diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index b98102fbb..9ce8a10b1 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -79,9 +79,8 @@ pub(crate) fn build_type_entries( let mut methods = Vec::new(); if let Some(info) = ffi::get_type_info(key) { if info.num_methods > 0 && !info.methods.is_null() { - let method_slice = unsafe { - std::slice::from_raw_parts(info.methods, info.num_methods as usize) - }; + let method_slice = + unsafe { std::slice::from_raw_parts(info.methods, info.num_methods as usize) }; for method in method_slice { let method_name = match ffi::byte_array_to_string_opt(&method.name) { Some(name) => name, @@ -94,7 +93,8 @@ pub(crate) fn build_type_entries( .as_deref() .and_then(extract_type_schema) .and_then(|s| parse_type_schema(&s)); - let sig = build_method_sig(schema.as_ref(), type_map, Some(key.as_str()), is_static); + let sig = + build_method_sig(schema.as_ref(), type_map, Some(key.as_str()), is_static); let full_name = format!("{}.{}", key, method_name); methods.push(MethodGen { full_name, @@ -354,12 +354,7 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { indent_str, func.rust_name ) .ok(); - writeln!( - out, - "{} let func = &*{};", - indent_str, static_name - ) - .ok(); + writeln!(out, "{} let func = &*{};", indent_str, static_name).ok(); writeln!( out, "{} let views: Vec> = args.iter().map(AnyView::from).collect();", @@ -375,18 +370,10 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { writeln!( out, "{}pub fn {}({}) -> Result<{}> {{", - indent_str, - func.rust_name, - args, - func.sig.ret.name - ) - .ok(); - writeln!( - out, - "{} let func = &*{};", - indent_str, static_name + indent_str, func.rust_name, args, func.sig.ret.name ) .ok(); + writeln!(out, "{} let func = &*{};", indent_str, static_name).ok(); writeln!( out, "{} let typed = tvm_ffi::into_typed_fn!(func.clone(), Fn({}) -> Result<{}>);", @@ -433,7 +420,12 @@ fn render_type(out: &mut String, ty: &TypeGen, indent: usize) { writeln!(out, "{} }}", indent_str).ok(); writeln!(out, "{}}}\n", indent_str).ok(); - writeln!(out, "{}impl From for {} {{", indent_str, ty.rust_name).ok(); + writeln!( + out, + "{}impl From for {} {{", + indent_str, ty.rust_name + ) + .ok(); writeln!( out, "{} fn from(inner: ObjectRef) -> Self {{", @@ -478,12 +470,7 @@ fn render_method(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usi indent_str, method.rust_name ) .ok(); - writeln!( - out, - "{} let func = &*{};", - indent_str, static_name - ) - .ok(); + writeln!(out, "{} let func = &*{};", indent_str, static_name).ok(); writeln!( out, "{} let views: Vec> = args.iter().map(AnyView::from).collect();", @@ -500,19 +487,19 @@ fn render_method(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usi indent_str, method.rust_name ) .ok(); + writeln!(out, "{} let func = &*{};", indent_str, static_name).ok(); writeln!( out, - "{} let func = &*{};", - indent_str, static_name + "{} let mut views: Vec> = Vec::with_capacity(args.len() + 1);", + indent_str ) .ok(); writeln!( out, - "{} let mut views: Vec> = Vec::with_capacity(args.len() + 1);", + "{} views.push(AnyView::from(self.as_object_ref()));", indent_str ) .ok(); - writeln!(out, "{} views.push(AnyView::from(self.as_object_ref()));", indent_str).ok(); writeln!( out, "{} views.extend(args.iter().map(AnyView::from));", @@ -535,34 +522,37 @@ fn render_method(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usi writeln!( out, "{}pub fn {} -> Result<{}> {{", - indent_str, - signature, - method.sig.ret.name - ) - .ok(); - writeln!( - out, - "{} let func = &*{};", - indent_str, static_name + indent_str, signature, method.sig.ret.name ) .ok(); + writeln!(out, "{} let func = &*{};", indent_str, static_name).ok(); let type_list = if method.is_static { render_type_list(&method.sig.args) } else { let mut types = vec!["tvm_ffi::object::ObjectRef".to_string()]; - types.extend(method.sig.args.iter().map(|arg| arg.typed_name().to_string())); + types.extend( + method + .sig + .args + .iter() + .map(|arg| arg.typed_name().to_string()), + ); types.join(", ") }; writeln!( out, "{} let typed = tvm_ffi::into_typed_fn!(func.clone(), Fn({}) -> Result<{}>);", - indent_str, - type_list, - method.sig.ret.name + indent_str, type_list, method.sig.ret.name ) .ok(); let call_expr = format!("typed({})", render_method_call_args(method)); - writeln!(out, "{} {}", indent_str, method.sig.ret.wrap_return(&call_expr)).ok(); + writeln!( + out, + "{} {}", + indent_str, + method.sig.ret.wrap_return(&call_expr) + ) + .ok(); writeln!(out, "{}}}", indent_str).ok(); } @@ -575,7 +565,10 @@ fn render_args(args: &[RustType]) -> String { } fn render_type_list(args: &[RustType]) -> String { - args.iter().map(|arg| arg.typed_name().to_string()).collect::>().join(", ") + args.iter() + .map(|arg| arg.typed_name().to_string()) + .collect::>() + .join(", ") } fn render_call_args_typed(args: &[RustType]) -> String { @@ -618,13 +611,11 @@ fn sanitize_ident(name: &str, style: IdentStyle) -> String { let mut prev_underscore = false; for (i, ch) in name.chars().enumerate() { let mut c = ch; - if style == IdentStyle::Module { - if ch.is_ascii_uppercase() { - if i > 0 && !prev_underscore { - out.push('_'); - } - c = ch.to_ascii_lowercase(); + if style == IdentStyle::Module && ch.is_ascii_uppercase() { + if i > 0 && !prev_underscore { + out.push('_'); } + c = ch.to_ascii_lowercase(); } if c.is_ascii_alphanumeric() || c == '_' { out.push(c); @@ -641,10 +632,10 @@ fn sanitize_ident(name: &str, style: IdentStyle) -> String { out.insert(0, '_'); } const KEYWORDS: &[&str] = &[ - "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", - "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", - "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", - "true", "type", "unsafe", "use", "where", "while", "async", "await", "dyn", + "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn", + "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref", + "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe", + "use", "where", "while", "async", "await", "dyn", ]; if KEYWORDS.contains(&out.as_str()) { out.push('_'); diff --git a/rust/tvm-ffi-stubgen/src/lib.rs b/rust/tvm-ffi-stubgen/src/lib.rs index 901a0e4dc..0e88156b1 100644 --- a/rust/tvm-ffi-stubgen/src/lib.rs +++ b/rust/tvm-ffi-stubgen/src/lib.rs @@ -22,9 +22,6 @@ mod model; mod schema; mod utils; -#[cfg(test)] -mod tests; - pub use cli::Args; pub fn run(args: Args) -> Result<(), Box> { diff --git a/rust/tvm-ffi-stubgen/src/tests.rs b/rust/tvm-ffi-stubgen/src/tests.rs deleted file mode 100644 index dfe227cc8..000000000 --- a/rust/tvm-ffi-stubgen/src/tests.rs +++ /dev/null @@ -1,222 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::{run, Args}; -use crate::utils; -use std::env; -use std::fs; -use std::path::{Path, PathBuf}; -use std::process::Command; - -#[test] -fn stubgen_tvm_ffi_testing() { - let dlls = resolve_testing_dlls().expect("unable to locate tvm_ffi testing libraries"); - let out_dir = unique_temp_dir("tvm_ffi_stubgen_test"); - let args = Args { - out_dir: out_dir.clone(), - dlls: dlls.clone(), - init_prefix: "testing".to_string(), - init_crate: "tvm_ffi_testing_stub".to_string(), - tvm_ffi_path: Some(utils::default_tvm_ffi_path().expect("tvm-ffi path")), - overwrite: true, - }; - - run(args).expect("stubgen run"); - - let cargo_toml = out_dir.join("Cargo.toml"); - let functions_rs = out_dir.join("src").join("functions.rs"); - assert!(cargo_toml.exists(), "Cargo.toml not generated"); - assert!(functions_rs.exists(), "functions.rs not generated"); - - let functions_body = fs::read_to_string(functions_rs).expect("read functions.rs"); - assert!(functions_body.contains("add_one"), "missing add_one stub"); - - write_integration_test(&out_dir).expect("write integration test"); - run_generated_tests(&out_dir, &dlls).expect("run generated tests"); -} - -fn resolve_testing_dlls() -> Result, String> { - if let Ok(value) = env::var("TVM_FFI_TESTING_DLLS") { - let dlls = split_paths(&value); - if !dlls.is_empty() { - return Ok(dlls); - } - } - - if let Ok(dir) = env::var("TVM_FFI_TESTING_LIB_DIR") { - let dir = PathBuf::from(dir); - if let Some(dlls) = dlls_from_dir(&dir) { - return Ok(dlls); - } - } - - let manifest = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let tvm_ffi_root = manifest.join("../tvm-ffi"); - let mut candidates = vec![tvm_ffi_root.join("build/lib")]; - - if let Ok(venv) = env::var("VIRTUAL_ENV") { - if let Some(path) = find_venv_lib_dir(Path::new(&venv)) { - candidates.push(path); - } - } - - for dir in candidates { - if let Some(dlls) = dlls_from_dir(&dir) { - return Ok(dlls); - } - } - - Err("set TVM_FFI_TESTING_DLLS or TVM_FFI_TESTING_LIB_DIR to run tests".to_string()) -} - -fn dlls_from_dir(dir: &Path) -> Option> { - let tvm_ffi = dir.join(lib_filename("tvm_ffi")); - let tvm_ffi_testing = dir.join(lib_filename("tvm_ffi_testing")); - if tvm_ffi.exists() && tvm_ffi_testing.exists() { - Some(vec![tvm_ffi, tvm_ffi_testing]) - } else { - None - } -} - -fn lib_filename(name: &str) -> String { - if cfg!(target_os = "windows") { - format!("{}.dll", name) - } else if cfg!(target_os = "macos") { - format!("lib{}.dylib", name) - } else { - format!("lib{}.so", name) - } -} - -fn split_paths(value: &str) -> Vec { - let normalized = value.replace(';', ":"); - normalized - .split(':') - .filter(|item| !item.trim().is_empty()) - .map(PathBuf::from) - .collect() -} - -fn unique_temp_dir(prefix: &str) -> PathBuf { - let base = env::temp_dir(); - let pid = std::process::id(); - let nanos = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_nanos()) - .unwrap_or(0); - base.join(format!("{}_{}_{}", prefix, pid, nanos)) -} - -fn find_venv_lib_dir(venv: &Path) -> Option { - for lib_root in ["lib", "lib64"] { - let base = venv.join(lib_root); - let entries = fs::read_dir(&base).ok()?; - for entry in entries.flatten() { - let path = entry.path(); - let name = path.file_name()?.to_string_lossy(); - if !name.starts_with("python") { - continue; - } - let candidate = path.join("site-packages").join("tvm_ffi").join("lib"); - if candidate.exists() { - return Some(candidate); - } - } - } - None -} - -fn write_integration_test(out_dir: &Path) -> Result<(), Box> { - let tests_dir = out_dir.join("tests"); - fs::create_dir_all(&tests_dir)?; - let test_body = r#"use tvm_ffi_testing_stub::add_one; - -#[test] -fn add_one_roundtrip() { - let lib_dir = std::env::var("TVM_FFI_TESTING_LIB_DIR").expect("lib dir"); - let lib_path = format!("{}/{}", lib_dir, lib_filename("tvm_ffi_testing")); - tvm_ffi::Module::load_from_file(&lib_path).expect("load tvm_ffi_testing"); - let value = add_one(1).expect("call add_one"); - assert_eq!(value, 2); -} - -fn lib_filename(name: &str) -> String { - if cfg!(target_os = "windows") { - format!("{}.dll", name) - } else if cfg!(target_os = "macos") { - format!("lib{}.dylib", name) - } else { - format!("lib{}.so", name) - } -} -"#; - fs::write(tests_dir.join("integration.rs"), test_body)?; - Ok(()) -} - -fn run_generated_tests(out_dir: &Path, dlls: &[PathBuf]) -> Result<(), Box> { - let mut cmd = Command::new("cargo"); - cmd.arg("test") - .arg("--manifest-path") - .arg(out_dir.join("Cargo.toml")) - .current_dir(out_dir); - - let lib_dir = dlls - .get(0) - .and_then(|path| path.parent()) - .map(|path| path.to_path_buf()) - .ok_or("missing library directory")?; - - let ld_var = if cfg!(target_os = "windows") { - "PATH" - } else if cfg!(target_os = "macos") { - "DYLD_LIBRARY_PATH" - } else { - "LD_LIBRARY_PATH" - }; - - let current_ld = env::var(ld_var).unwrap_or_default(); - let separator = if ld_var == "PATH" { ";" } else { ":" }; - let new_ld = if current_ld.is_empty() { - lib_dir.to_string_lossy().to_string() - } else { - format!("{}{}{}", lib_dir.to_string_lossy(), separator, current_ld) - }; - cmd.env(ld_var, new_ld); - cmd.env("TVM_FFI_TESTING_LIB_DIR", lib_dir); - - let mut path_value = env::var("PATH").unwrap_or_default(); - if let Ok(venv) = env::var("VIRTUAL_ENV") { - let venv_bin = Path::new(&venv).join("bin"); - let venv_str = venv_bin.to_string_lossy(); - if !path_value.split(':').any(|item| item == venv_str) { - if !path_value.is_empty() { - path_value = format!("{}:{}", venv_str, path_value); - } else { - path_value = venv_str.to_string(); - } - } - } - cmd.env("PATH", path_value); - - let status = cmd.status()?; - if !status.success() { - return Err("generated crate tests failed".into()); - } - Ok(()) -} diff --git a/rust/tvm-ffi-stubgen/src/utils.rs b/rust/tvm-ffi-stubgen/src/utils.rs index e20003f99..a7f559c9e 100644 --- a/rust/tvm-ffi-stubgen/src/utils.rs +++ b/rust/tvm-ffi-stubgen/src/utils.rs @@ -29,7 +29,10 @@ pub(crate) fn normalize_prefix(prefix: &str) -> String { } } -pub(crate) fn ensure_out_dir(out_dir: &Path, overwrite: bool) -> Result<(), Box> { +pub(crate) fn ensure_out_dir( + out_dir: &Path, + overwrite: bool, +) -> Result<(), Box> { if out_dir.exists() { let mut has_entries = false; for entry in fs::read_dir(out_dir)? { @@ -63,7 +66,10 @@ pub(crate) fn relative_path(from: &Path, to: &Path) -> PathBuf { let from_components: Vec<_> = from.components().collect(); let to_components: Vec<_> = to.components().collect(); let mut i = 0; - while i < from_components.len() && i < to_components.len() && from_components[i] == to_components[i] { + while i < from_components.len() + && i < to_components.len() + && from_components[i] == to_components[i] + { i += 1; } let mut out = PathBuf::new(); diff --git a/rust/tvm-ffi-stubgen/tests/stubgen.rs b/rust/tvm-ffi-stubgen/tests/stubgen.rs new file mode 100644 index 000000000..924085d2e --- /dev/null +++ b/rust/tvm-ffi-stubgen/tests/stubgen.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::env; +use std::fs; +use std::path::{Path, PathBuf}; +use std::process::Command; +use tvm_ffi_stubgen::{run, Args}; + +#[test] +fn stubgen_tvm_ffi_testing() { + let lib_dir = tvm_ffi_libdir().expect("tvm-ffi-config --libdir"); + let dlls = resolve_testing_dlls(&lib_dir).expect("unable to locate tvm_ffi testing libraries"); + let testing_lib = dlls + .iter() + .find(|path| { + path.file_name() + .map(|name| name.to_string_lossy().contains("tvm_ffi_testing")) + .unwrap_or(false) + }) + .cloned() + .expect("tvm_ffi_testing library"); + let out_dir = unique_temp_dir("tvm_ffi_stubgen_test"); + let args = Args { + out_dir: out_dir.clone(), + dlls: vec![testing_lib.clone()], + init_prefix: "testing".to_string(), + init_crate: "tvm_ffi_testing_stub".to_string(), + tvm_ffi_path: None, + overwrite: true, + }; + + run(args).expect("stubgen run"); + + let cargo_toml = out_dir.join("Cargo.toml"); + let functions_rs = out_dir.join("src").join("functions.rs"); + assert!(cargo_toml.exists(), "Cargo.toml not generated"); + assert!(functions_rs.exists(), "functions.rs not generated"); + + let functions_body = fs::read_to_string(functions_rs).expect("read functions.rs"); + assert!(functions_body.contains("add_one"), "missing add_one stub"); + + write_integration_test(&out_dir, &testing_lib).expect("write integration test"); + run_generated_tests(&out_dir, &lib_dir).expect("run generated tests"); +} + +fn resolve_testing_dlls(lib_dir: &Path) -> Result, String> { + if let Some(dlls) = dlls_from_dir(lib_dir) { + return Ok(dlls); + } + Err("tvm-ffi-config --libdir did not contain tvm_ffi libraries".to_string()) +} + +fn dlls_from_dir(dir: &Path) -> Option> { + let tvm_ffi = dir.join(lib_filename("tvm_ffi")); + let tvm_ffi_testing = dir.join(lib_filename("tvm_ffi_testing")); + if tvm_ffi.exists() && tvm_ffi_testing.exists() { + Some(vec![tvm_ffi, tvm_ffi_testing]) + } else { + None + } +} + +fn lib_filename(name: &str) -> String { + if cfg!(target_os = "windows") { + format!("{}.dll", name) + } else if cfg!(target_os = "macos") { + format!("lib{}.dylib", name) + } else { + format!("lib{}.so", name) + } +} + +fn unique_temp_dir(prefix: &str) -> PathBuf { + let base = env::temp_dir(); + let pid = std::process::id(); + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or(0); + base.join(format!("{}_{}_{}", prefix, pid, nanos)) +} + +fn tvm_ffi_libdir() -> Result> { + let output = Command::new("tvm-ffi-config").arg("--libdir").output()?; + if !output.status.success() { + return Err("tvm-ffi-config --libdir failed".into()); + } + let lib_dir = String::from_utf8(output.stdout)?.trim().to_string(); + if lib_dir.is_empty() { + return Err("tvm-ffi-config returned empty libdir".into()); + } + Ok(PathBuf::from(lib_dir)) +} + +fn write_integration_test( + out_dir: &Path, + testing_lib: &Path, +) -> Result<(), Box> { + let tests_dir = out_dir.join("tests"); + fs::create_dir_all(&tests_dir)?; + let test_body = format!( + "use tvm_ffi_testing_stub::add_one;\n\n#[test]\nfn add_one_roundtrip() {{\n let lib_path = \"{}\";\n tvm_ffi::Module::load_from_file(lib_path).expect(\"load tvm_ffi_testing\");\n let value = add_one(1).expect(\"call add_one\");\n assert_eq!(value, 2);\n}}\n", + testing_lib.display() + ); + fs::write(tests_dir.join("integration.rs"), test_body)?; + Ok(()) +} + +fn run_generated_tests(out_dir: &Path, lib_dir: &Path) -> Result<(), Box> { + let mut cmd = Command::new("cargo"); + cmd.arg("test") + .arg("--manifest-path") + .arg(out_dir.join("Cargo.toml")) + .current_dir(out_dir); + + let ld_var = if cfg!(target_os = "windows") { + "PATH" + } else if cfg!(target_os = "macos") { + "DYLD_LIBRARY_PATH" + } else { + "LD_LIBRARY_PATH" + }; + + let current_ld = env::var(ld_var).unwrap_or_default(); + let separator = if ld_var == "PATH" { ";" } else { ":" }; + let lib_dir_str = lib_dir.to_string_lossy(); + let new_ld = if current_ld.is_empty() { + lib_dir_str.to_string() + } else { + format!("{}{}{}", lib_dir_str, separator, current_ld) + }; + cmd.env(ld_var, new_ld); + + let path_value = env::var("PATH").unwrap_or_default(); + cmd.env("PATH", path_value); + + let status = cmd.status()?; + if !status.success() { + return Err("generated crate tests failed".into()); + } + Ok(()) +} From 651314533cb812c6cee054992cf0f40ed50f978e Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sat, 7 Feb 2026 22:54:10 +0800 Subject: [PATCH 03/29] refactor stubgen array usage --- rust/tvm-ffi-stubgen/src/ffi.rs | 141 +-------------------------- rust/tvm-ffi-stubgen/src/generate.rs | 12 +-- 2 files changed, 10 insertions(+), 143 deletions(-) diff --git a/rust/tvm-ffi-stubgen/src/ffi.rs b/rust/tvm-ffi-stubgen/src/ffi.rs index dda322033..790ca5f58 100644 --- a/rust/tvm-ffi-stubgen/src/ffi.rs +++ b/rust/tvm-ffi-stubgen/src/ffi.rs @@ -17,123 +17,11 @@ use libloading::Library; use std::path::PathBuf; -use std::sync::LazyLock; -use tvm_ffi::function_internal::{ArgIntoRef, IntoArgHolder}; +use tvm_ffi::Array; use tvm_ffi::tvm_ffi_sys::{ - TVMFFIAny, TVMFFIByteArray, TVMFFIGetTypeInfo, TVMFFIObjectHandle, TVMFFITypeIndex, - TVMFFITypeInfo, TVMFFITypeKeyToIndex, + TVMFFIByteArray, TVMFFIGetTypeInfo, TVMFFITypeInfo, TVMFFITypeKeyToIndex, }; -use tvm_ffi::{Any, Error, Function, Result as FfiResult, String as FfiString, TYPE_ERROR}; - -#[repr(C)] -#[derive(Debug)] -pub(crate) struct Array { - handle: TVMFFIObjectHandle, -} - -extern "C" { - fn TVMFFIObjectIncRef(handle: TVMFFIObjectHandle) -> i32; - fn TVMFFIObjectDecRef(handle: TVMFFIObjectHandle) -> i32; -} - -impl Clone for Array { - fn clone(&self) -> Self { - unsafe { - TVMFFIObjectIncRef(self.handle); - } - Self { - handle: self.handle, - } - } -} - -impl Drop for Array { - fn drop(&mut self) { - unsafe { - TVMFFIObjectDecRef(self.handle); - } - } -} - -unsafe impl tvm_ffi::type_traits::AnyCompatible for Array { - fn type_str() -> String { - "ffi.Array".to_string() - } - - unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { - data.type_index = TVMFFITypeIndex::kTVMFFIArray as i32; - data.small_str_len = 0; - data.data_union.v_obj = src.handle as *mut tvm_ffi::tvm_ffi_sys::TVMFFIObject; - } - - unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) { - data.type_index = TVMFFITypeIndex::kTVMFFIArray as i32; - data.small_str_len = 0; - data.data_union.v_obj = src.handle as *mut tvm_ffi::tvm_ffi_sys::TVMFFIObject; - } - - unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - data.type_index == TVMFFITypeIndex::kTVMFFIArray as i32 - } - - unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { - let handle = data.data_union.v_obj as TVMFFIObjectHandle; - TVMFFIObjectIncRef(handle); - Self { handle } - } - - unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self { - let handle = data.data_union.v_obj as TVMFFIObjectHandle; - Self { handle } - } - - unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result { - if data.type_index == TVMFFITypeIndex::kTVMFFIArray as i32 { - Ok(Self::copy_from_any_view_after_check(data)) - } else { - Err(()) - } - } -} - -impl ArgIntoRef for Array { - type Target = Array; - fn to_ref(&self) -> &Self::Target { - self - } -} - -impl ArgIntoRef for &Array { - type Target = Array; - fn to_ref(&self) -> &Self::Target { - self - } -} - -impl IntoArgHolder for Array { - type Target = Array; - fn into_arg_holder(self) -> Self::Target { - self - } -} - -impl<'a> IntoArgHolder for &'a Array { - type Target = &'a Array; - fn into_arg_holder(self) -> Self::Target { - self - } -} - -impl TryFrom for Array { - type Error = Error; - fn try_from(value: Any) -> Result { - if let Some(ret) = value.try_as::() { - Ok(ret) - } else { - Err(Error::new(TYPE_ERROR, "Expected ffi.Array", "")) - } - } -} +use tvm_ffi::{Function, Result as FfiResult, String as FfiString}; pub(crate) fn load_dlls(paths: &[PathBuf]) -> Result, Box> { let mut libs = Vec::new(); @@ -162,33 +50,14 @@ pub(crate) fn list_global_function_names() -> FfiResult> { pub(crate) fn list_registered_type_keys() -> FfiResult> { let get_keys = Function::get_global("ffi.GetRegisteredTypeKeys")?; let keys_any = get_keys.call_tuple_with_len::<0, _>(())?; - let keys: Array = keys_any.try_into()?; - let size = array_size(&keys)?; let mut out = Vec::new(); - for idx in 0..size { - let item = array_get_item(&keys, idx)?; - let key: FfiString = item.try_into()?; + let keys: Array = keys_any.try_into()?; + for key in &keys { out.push(key.as_str().to_string()); } Ok(out) } -pub(crate) fn array_size(array: &Array) -> FfiResult { - static FUNC: LazyLock = - LazyLock::new(|| Function::get_global("ffi.ArraySize").expect("ffi.ArraySize missing")); - let func = &*FUNC; - let size_any = func.call_tuple_with_len::<1, _>((array,))?; - size_any.try_into() -} - -pub(crate) fn array_get_item(array: &Array, index: i64) -> FfiResult { - static FUNC: LazyLock = LazyLock::new(|| { - Function::get_global("ffi.ArrayGetItem").expect("ffi.ArrayGetItem missing") - }); - let func = &*FUNC; - func.call_tuple_with_len::<2, _>((array, index)) -} - pub(crate) fn get_type_info(type_key: &str) -> Option<&'static TVMFFITypeInfo> { unsafe { let key = TVMFFIByteArray::from_str(type_key); diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 9ce8a10b1..263068d49 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -234,13 +234,11 @@ fn insert_type(root: &mut ModuleNode, mods: &[String], ty: TypeGen) { } fn split_name(full_name: &str, prefix: &str) -> (Vec, String) { - let mut remainder = full_name; - if !prefix.is_empty() && remainder.starts_with(prefix) { - remainder = &remainder[prefix.len()..]; - } else if !prefix.is_empty() && remainder.starts_with(prefix.trim_end_matches('.')) { - remainder = &remainder[prefix.trim_end_matches('.').len()..]; - remainder = remainder.trim_start_matches('.'); - } + let remainder = if prefix.is_empty() { + full_name + } else { + full_name.strip_prefix(prefix).unwrap_or(full_name) + }; let parts: Vec<&str> = remainder.split('.').filter(|p| !p.is_empty()).collect(); if parts.is_empty() { return (Vec::new(), "ffi".to_string()); From b31a4ea10d77e045c8288f859fc735903a7a8335 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sat, 7 Feb 2026 22:54:58 +0800 Subject: [PATCH 04/29] suppress non-snake-case warnings --- rust/tvm-ffi-stubgen/src/generate.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 263068d49..21db9de4f 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -346,6 +346,12 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { ) .ok(); if func.sig.packed { + writeln!( + out, + "{}#[allow(non_snake_case)]", + indent_str + ) + .ok(); writeln!( out, "{}pub fn {}(args: &[Any]) -> Result {{", @@ -365,6 +371,12 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { return; } let args = render_args(&func.sig.args); + writeln!( + out, + "{}#[allow(non_snake_case)]", + indent_str + ) + .ok(); writeln!( out, "{}pub fn {}({}) -> Result<{}> {{", From 98028dece09be5896e3df85226e3726f662dbf93 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 8 Feb 2026 15:44:07 +0800 Subject: [PATCH 05/29] Improve stubgen typing and Map support --- rust/tvm-ffi-stubgen/src/generate.rs | 66 +++++-- rust/tvm-ffi-stubgen/src/lib.rs | 25 ++- rust/tvm-ffi-stubgen/src/schema.rs | 14 ++ rust/tvm-ffi/src/collections/map.rs | 261 +++++++++++++++++++++++++++ rust/tvm-ffi/src/collections/mod.rs | 1 + rust/tvm-ffi/src/lib.rs | 1 + 6 files changed, 349 insertions(+), 19 deletions(-) create mode 100644 rust/tvm-ffi/src/collections/map.rs diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 21db9de4f..94a81dc40 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -178,30 +178,62 @@ fn rust_type_for_schema( "bool" => RustType::supported("bool"), "int" => RustType::supported("i64"), "float" => RustType::supported("f64"), - "Device" => RustType::unsupported("tvm_ffi::Any"), - "DataType" => RustType::unsupported("tvm_ffi::Any"), + "Device" => RustType::supported("tvm_ffi::DLDevice"), + "DataType" => RustType::supported("tvm_ffi::DLDataType"), "ffi.String" | "std::string" | "const char*" | "ffi.SmallStr" => { RustType::supported("tvm_ffi::String") } "ffi.Bytes" | "TVMFFIByteArray*" | "ffi.SmallBytes" => { RustType::supported("tvm_ffi::Bytes") } - "ffi.Function" => RustType::unsupported("tvm_ffi::Any"), - "ffi.Object" => RustType::unsupported("tvm_ffi::Any"), - "ffi.Tensor" | "DLTensor*" => RustType::unsupported("tvm_ffi::Any"), - "ffi.Shape" => RustType::unsupported("tvm_ffi::Any"), - "ffi.Module" => RustType::unsupported("tvm_ffi::Any"), - "Optional" => RustType::unsupported("tvm_ffi::Any"), - "Union" | "Variant" | "tuple" | "list" | "dict" | "ffi.Array" | "ffi.Map" | "Any" => { - RustType::unsupported("tvm_ffi::Any") - } - other => { - if let Some(_path) = type_map.get(other) { - RustType::unsupported("tvm_ffi::Any") - } else { - RustType::unsupported("tvm_ffi::Any") + "ffi.Function" => RustType::supported("tvm_ffi::Function"), + "ffi.Object" => RustType::supported("tvm_ffi::object::ObjectRef"), + "ffi.Tensor" | "DLTensor*" => RustType::supported("tvm_ffi::Tensor"), + "ffi.Shape" => RustType::supported("tvm_ffi::Shape"), + "ffi.Module" => RustType::supported("tvm_ffi::Module"), + "Optional" => match schema.args.as_slice() { + [inner] => { + let inner_ty = rust_type_for_schema(inner, type_map, _self_type_key); + if inner_ty.supported { + RustType::supported(&format!("Option<{}>", inner_ty.name)) + } else { + RustType::unsupported("tvm_ffi::Any") + } } - } + _ => RustType::unsupported("tvm_ffi::Any"), + }, + "ffi.Array" => match schema.args.as_slice() { + [inner] => { + let inner_ty = rust_type_for_schema(inner, type_map, _self_type_key); + if inner_ty.supported { + RustType::supported(&format!("tvm_ffi::Array<{}>", inner_ty.name)) + } else { + RustType::unsupported("tvm_ffi::Any") + } + } + _ => RustType::unsupported("tvm_ffi::Any"), + }, + "ffi.Map" => match schema.args.as_slice() { + [key, value] => { + let key_ty = rust_type_for_schema(key, type_map, _self_type_key); + let value_ty = rust_type_for_schema(value, type_map, _self_type_key); + if key_ty.supported && value_ty.supported { + RustType::supported(&format!( + "tvm_ffi::Map<{}, {}>", + key_ty.name, value_ty.name + )) + } else { + RustType::unsupported("tvm_ffi::Any") + } + } + _ => RustType::unsupported("tvm_ffi::Any"), + }, + "Any" | "ffi.Any" => RustType::supported("tvm_ffi::Any"), + "Union" | "Variant" | "tuple" | "list" | "dict" => RustType::unsupported("tvm_ffi::Any"), + other => match type_map.get(other) { + Some(path) => RustType::object_wrapper(path), + None => RustType::unsupported("tvm_ffi::Any"), + }, } } diff --git a/rust/tvm-ffi-stubgen/src/lib.rs b/rust/tvm-ffi-stubgen/src/lib.rs index 0e88156b1..07e283466 100644 --- a/rust/tvm-ffi-stubgen/src/lib.rs +++ b/rust/tvm-ffi-stubgen/src/lib.rs @@ -23,6 +23,8 @@ mod schema; mod utils; pub use cli::Args; +use std::collections::{BTreeSet, HashSet}; +use crate::schema::{collect_type_keys, extract_type_schema, parse_type_schema}; pub fn run(args: Args) -> Result<(), Box> { let prefix = utils::normalize_prefix(&args.init_prefix); @@ -40,11 +42,30 @@ pub fn run(args: Args) -> Result<(), Box> { .collect(); let type_keys = ffi::list_registered_type_keys()?; - let filtered_types: Vec = type_keys - .into_iter() + let type_key_set: HashSet = type_keys.iter().cloned().collect(); + let mut filtered_types: Vec = type_keys + .iter() .filter(|name| name.starts_with(&prefix)) + .cloned() .collect(); + let mut referenced_types: BTreeSet = BTreeSet::new(); + for full_name in &filtered_funcs { + let metadata = ffi::get_global_func_metadata(full_name)?; + let schema = metadata + .and_then(|meta| extract_type_schema(&meta)) + .and_then(|schema| parse_type_schema(&schema)); + if let Some(schema) = schema { + collect_type_keys(&schema, &type_key_set, &mut referenced_types); + } + } + + for ty in referenced_types { + if !filtered_types.contains(&ty) { + filtered_types.push(ty); + } + } + let type_map = generate::build_type_map(&filtered_types, &prefix); let functions = generate::build_function_entries(&filtered_funcs, &type_map, &prefix)?; let types = generate::build_type_entries(&filtered_types, &type_map, &prefix)?; diff --git a/rust/tvm-ffi-stubgen/src/schema.rs b/rust/tvm-ffi-stubgen/src/schema.rs index 6cec7e771..3baf5de3b 100644 --- a/rust/tvm-ffi-stubgen/src/schema.rs +++ b/rust/tvm-ffi-stubgen/src/schema.rs @@ -16,6 +16,7 @@ // under the License. use serde::Deserialize; +use std::collections::{BTreeSet, HashSet}; #[derive(Debug, Clone)] pub(crate) struct TypeSchema { @@ -44,6 +45,19 @@ pub(crate) fn parse_type_schema(schema: &str) -> Option { Some(parse_type_schema_json(&json)) } +pub(crate) fn collect_type_keys( + schema: &TypeSchema, + known: &HashSet, + out: &mut BTreeSet, +) { + if known.contains(&schema.origin) { + out.insert(schema.origin.clone()); + } + for arg in &schema.args { + collect_type_keys(arg, known, out); + } +} + fn parse_type_schema_json(json: &TypeSchemaJson) -> TypeSchema { TypeSchema { origin: json.ty.clone(), diff --git a/rust/tvm-ffi/src/collections/map.rs b/rust/tvm-ffi/src/collections/map.rs new file mode 100644 index 000000000..3199e7b5d --- /dev/null +++ b/rust/tvm-ffi/src/collections/map.rs @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +use std::marker::PhantomData; +use std::sync::LazyLock; + +use crate::any::TryFromTemp; +use crate::derive::Object; +use crate::error::Result; +use crate::function::Function; +use crate::object::{Object, ObjectArc, ObjectRefCore}; +use crate::type_traits::AnyCompatible; +use crate::{Any, AnyView}; +use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; +use tvm_ffi_sys::{TVMFFIAny, TVMFFIObject}; + +#[repr(C)] +#[derive(Object)] +#[type_key = "ffi.Map"] +#[type_index(TypeIndex::kTVMFFIMap)] +pub struct MapObj { + pub object: Object, +} + +impl Map +where + K: AnyCompatible + Clone + 'static, + V: AnyCompatible + Clone + 'static, +{ + /// Create a new Map from key/value pairs. + pub fn new>(items: I) -> Result { + static API_FUNC: LazyLock = + LazyLock::new(|| Function::get_global("ffi.Map").unwrap()); + let items: Vec<(K, V)> = items.into_iter().collect(); + let mut args: Vec = Vec::with_capacity(items.len() * 2); + for (key, value) in items.iter() { + args.push(AnyView::from(key)); + args.push(AnyView::from(value)); + } + (*API_FUNC).call_packed(&args)?.try_into() + } + + /// Return the number of entries in the map. + pub fn len(&self) -> Result { + static API_FUNC: LazyLock = + LazyLock::new(|| Function::get_global("ffi.MapSize").unwrap()); + let args = [AnyView::from(self)]; + let size_any = (*API_FUNC).call_packed(&args)?; + let temp: TryFromTemp = TryFromTemp::try_from(size_any)?; + let size = TryFromTemp::into_value(temp); + Ok(size as usize) + } + + /// Return true if the map is empty. + pub fn is_empty(&self) -> Result { + Ok(self.len()? == 0) + } + + /// Return true if the map contains the key. + pub fn contains_key(&self, key: &K) -> Result { + static API_FUNC: LazyLock = + LazyLock::new(|| Function::get_global("ffi.MapCount").unwrap()); + let args = [AnyView::from(self), AnyView::from(key)]; + let count_any = (*API_FUNC).call_packed(&args)?; + let temp: TryFromTemp = TryFromTemp::try_from(count_any)?; + let count = TryFromTemp::into_value(temp); + Ok(count != 0) + } + + /// Return the value for key or raise a KeyError. + pub fn get(&self, key: &K) -> Result { + static API_FUNC: LazyLock = + LazyLock::new(|| Function::get_global("ffi.MapGetItem").unwrap()); + let args = [AnyView::from(self), AnyView::from(key)]; + let value_any = (*API_FUNC).call_packed(&args)?; + let temp: TryFromTemp = TryFromTemp::try_from(value_any)?; + Ok(TryFromTemp::into_value(temp)) + } + + /// Return the value for key or None if missing. + pub fn get_optional(&self, key: &K) -> Result> { + if !self.contains_key(key)? { + return Ok(None); + } + self.get(key).map(Some) + } + + /// Return the value for key or a default value if missing. + pub fn get_or(&self, key: &K, default: V) -> Result { + match self.get_optional(key)? { + Some(value) => Ok(value), + None => Ok(default), + } + } + + /// Iterate over key/value pairs. + pub fn iter(&self) -> Result> { + static API_FUNC: LazyLock = + LazyLock::new(|| Function::get_global("ffi.MapForwardIterFunctor").unwrap()); + let args = [AnyView::from(self)]; + let functor: Function = (*API_FUNC).call_packed(&args)?.try_into()?; + Ok(MapIterator { + functor, + remaining: self.len()?, + _marker: PhantomData, + }) + } +} + +pub struct MapIterator { + functor: Function, + remaining: usize, + _marker: PhantomData<(K, V)>, +} + +impl Iterator for MapIterator +where + K: AnyCompatible + Clone + 'static, + V: AnyCompatible + Clone + 'static, +{ + type Item = (K, V); + + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + let key_any = self + .functor + .call_tuple_with_len::<1, _>((0i64,)) + .ok()?; + let key_temp: TryFromTemp = TryFromTemp::try_from(key_any).ok()?; + let key = TryFromTemp::into_value(key_temp); + + let value_any = self + .functor + .call_tuple_with_len::<1, _>((1i64,)) + .ok()?; + let value_temp: TryFromTemp = TryFromTemp::try_from(value_any).ok()?; + let value = TryFromTemp::into_value(value_temp); + let _ = self.functor.call_tuple_with_len::<1, _>((2i64,)); + self.remaining -= 1; + Some((key, value)) + } +} +#[repr(C)] +#[derive(Clone)] +pub struct Map { + data: ObjectArc, + _marker: PhantomData<(K, V)>, +} + +unsafe impl ObjectRefCore for Map { + type ContainerType = MapObj; + + fn data(this: &Self) -> &ObjectArc { + &this.data + } + + fn into_data(this: Self) -> ObjectArc { + this.data + } + + fn from_data(data: ObjectArc) -> Self { + Self { + data, + _marker: PhantomData, + } + } +} + +// --- Any Type System Conversions --- + +unsafe impl AnyCompatible for Map +where + K: AnyCompatible + Clone + 'static, + V: AnyCompatible + Clone + 'static, +{ + fn type_str() -> String { + format!("Map<{}, {}>", K::type_str(), V::type_str()) + } + + unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { + data.type_index == TypeIndex::kTVMFFIMap as i32 + } + + unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { + data.type_index = TypeIndex::kTVMFFIMap as i32; + data.data_union.v_obj = ObjectArc::as_raw(Self::data(src)) as *mut TVMFFIObject; + data.small_str_len = 0; + } + + unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) { + data.type_index = TypeIndex::kTVMFFIMap as i32; + data.data_union.v_obj = ObjectArc::into_raw(Self::into_data(src)) as *mut TVMFFIObject; + data.small_str_len = 0; + } + + unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { + let ptr = data.data_union.v_obj as *const MapObj; + crate::object::unsafe_::inc_ref(ptr as *mut TVMFFIObject); + Self::from_data(ObjectArc::from_raw(ptr)) + } + + unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self { + let ptr = data.data_union.v_obj as *const MapObj; + let obj = Self::from_data(ObjectArc::from_raw(ptr)); + + data.type_index = TypeIndex::kTVMFFINone as i32; + data.data_union.v_int64 = 0; + + obj + } + + unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result { + if data.type_index != TypeIndex::kTVMFFIMap as i32 { + return Err(()); + } + Ok(Self::copy_from_any_view_after_check(data)) + } +} + +impl TryFrom for Map +where + K: AnyCompatible + Clone + 'static, + V: AnyCompatible + Clone + 'static, +{ + type Error = crate::error::Error; + + fn try_from(value: Any) -> Result { + let temp: TryFromTemp = TryFromTemp::try_from(value)?; + Ok(TryFromTemp::into_value(temp)) + } +} + +impl<'a, K, V> TryFrom> for Map +where + K: AnyCompatible + Clone + 'static, + V: AnyCompatible + Clone + 'static, +{ + type Error = crate::error::Error; + + fn try_from(value: AnyView<'a>) -> Result { + let temp: TryFromTemp = TryFromTemp::try_from(value)?; + Ok(TryFromTemp::into_value(temp)) + } +} diff --git a/rust/tvm-ffi/src/collections/mod.rs b/rust/tvm-ffi/src/collections/mod.rs index ad17dccae..791ff7557 100644 --- a/rust/tvm-ffi/src/collections/mod.rs +++ b/rust/tvm-ffi/src/collections/mod.rs @@ -18,5 +18,6 @@ */ /// Collection types pub mod array; +pub mod map; pub mod shape; pub mod tensor; diff --git a/rust/tvm-ffi/src/lib.rs b/rust/tvm-ffi/src/lib.rs index fad82601c..891f734e1 100644 --- a/rust/tvm-ffi/src/lib.rs +++ b/rust/tvm-ffi/src/lib.rs @@ -33,6 +33,7 @@ pub use tvm_ffi_sys; pub use crate::any::{Any, AnyView}; pub use crate::collections::array::Array; +pub use crate::collections::map::Map; pub use crate::collections::shape::Shape; pub use crate::collections::tensor::{CPUNDAlloc, NDAllocator, Tensor}; pub use crate::device::{current_stream, with_stream}; From 891de62bd2f1e89a0256b281727a80e530e60c30 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 8 Feb 2026 15:47:34 +0800 Subject: [PATCH 06/29] Format Rust sources --- rust/tvm-ffi-stubgen/src/ffi.rs | 2 +- rust/tvm-ffi-stubgen/src/generate.rs | 14 ++------------ rust/tvm-ffi-stubgen/src/lib.rs | 2 +- rust/tvm-ffi/src/collections/map.rs | 10 ++-------- 4 files changed, 6 insertions(+), 22 deletions(-) diff --git a/rust/tvm-ffi-stubgen/src/ffi.rs b/rust/tvm-ffi-stubgen/src/ffi.rs index 790ca5f58..100e5bf16 100644 --- a/rust/tvm-ffi-stubgen/src/ffi.rs +++ b/rust/tvm-ffi-stubgen/src/ffi.rs @@ -17,10 +17,10 @@ use libloading::Library; use std::path::PathBuf; -use tvm_ffi::Array; use tvm_ffi::tvm_ffi_sys::{ TVMFFIByteArray, TVMFFIGetTypeInfo, TVMFFITypeInfo, TVMFFITypeKeyToIndex, }; +use tvm_ffi::Array; use tvm_ffi::{Function, Result as FfiResult, String as FfiString}; pub(crate) fn load_dlls(paths: &[PathBuf]) -> Result, Box> { diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 94a81dc40..2c6cecb56 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -378,12 +378,7 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { ) .ok(); if func.sig.packed { - writeln!( - out, - "{}#[allow(non_snake_case)]", - indent_str - ) - .ok(); + writeln!(out, "{}#[allow(non_snake_case)]", indent_str).ok(); writeln!( out, "{}pub fn {}(args: &[Any]) -> Result {{", @@ -403,12 +398,7 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { return; } let args = render_args(&func.sig.args); - writeln!( - out, - "{}#[allow(non_snake_case)]", - indent_str - ) - .ok(); + writeln!(out, "{}#[allow(non_snake_case)]", indent_str).ok(); writeln!( out, "{}pub fn {}({}) -> Result<{}> {{", diff --git a/rust/tvm-ffi-stubgen/src/lib.rs b/rust/tvm-ffi-stubgen/src/lib.rs index 07e283466..9128e9fc4 100644 --- a/rust/tvm-ffi-stubgen/src/lib.rs +++ b/rust/tvm-ffi-stubgen/src/lib.rs @@ -22,9 +22,9 @@ mod model; mod schema; mod utils; +use crate::schema::{collect_type_keys, extract_type_schema, parse_type_schema}; pub use cli::Args; use std::collections::{BTreeSet, HashSet}; -use crate::schema::{collect_type_keys, extract_type_schema, parse_type_schema}; pub fn run(args: Args) -> Result<(), Box> { let prefix = utils::normalize_prefix(&args.init_prefix); diff --git a/rust/tvm-ffi/src/collections/map.rs b/rust/tvm-ffi/src/collections/map.rs index 3199e7b5d..f1b6be5c9 100644 --- a/rust/tvm-ffi/src/collections/map.rs +++ b/rust/tvm-ffi/src/collections/map.rs @@ -139,17 +139,11 @@ where if self.remaining == 0 { return None; } - let key_any = self - .functor - .call_tuple_with_len::<1, _>((0i64,)) - .ok()?; + let key_any = self.functor.call_tuple_with_len::<1, _>((0i64,)).ok()?; let key_temp: TryFromTemp = TryFromTemp::try_from(key_any).ok()?; let key = TryFromTemp::into_value(key_temp); - let value_any = self - .functor - .call_tuple_with_len::<1, _>((1i64,)) - .ok()?; + let value_any = self.functor.call_tuple_with_len::<1, _>((1i64,)).ok()?; let value_temp: TryFromTemp = TryFromTemp::try_from(value_any).ok()?; let value = TryFromTemp::into_value(value_temp); let _ = self.functor.call_tuple_with_len::<1, _>((2i64,)); From 91af2fdc83f8e6c3a66efe2234c730d6b8509c22 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 8 Feb 2026 16:16:36 +0800 Subject: [PATCH 07/29] Use toml serializer for Cargo.toml --- rust/tvm-ffi-stubgen/Cargo.toml | 1 + rust/tvm-ffi-stubgen/src/generate.rs | 42 +++++++++++++++++++--------- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/rust/tvm-ffi-stubgen/Cargo.toml b/rust/tvm-ffi-stubgen/Cargo.toml index 49a29788a..03375be95 100644 --- a/rust/tvm-ffi-stubgen/Cargo.toml +++ b/rust/tvm-ffi-stubgen/Cargo.toml @@ -31,4 +31,5 @@ clap = { version = "4.5", features = ["derive"] } libloading = "0.8" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +toml = "0.8" tvm-ffi = { version = "0.1.0-alpha.0", path = "../tvm-ffi" } diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 2c6cecb56..ae97ffcd1 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -22,6 +22,7 @@ use crate::schema::{extract_type_schema, parse_type_schema, TypeSchema}; use crate::utils; use std::collections::BTreeMap; use std::fmt::Write as _; +use toml::value::Table; const METHOD_FLAG_STATIC: i64 = 1 << 2; @@ -300,19 +301,34 @@ pub(crate) fn render_cargo_toml( Some(path) => path.clone(), None => utils::default_tvm_ffi_path()?, }; - let rel_path = utils::relative_path(&args.out_dir, &tvm_ffi_path); - let mut out = String::new(); - writeln!( - &mut out, - "[package]\nname = \"{}\"\nversion = \"0.1.0\"\nedition = \"2021\"\n", - args.init_crate - )?; - writeln!( - &mut out, - "[dependencies]\ntvm-ffi = {{ path = \"{}\" }}\n", - rel_path.display() - )?; - Ok(out) + let tvm_ffi_path = tvm_ffi_path.canonicalize().unwrap_or_else(|_| tvm_ffi_path); + let tvm_ffi_path_str = tvm_ffi_path.to_string_lossy().to_string(); + + let mut package = Table::new(); + package.insert( + "name".to_string(), + toml::Value::String(args.init_crate.clone()), + ); + package.insert( + "version".to_string(), + toml::Value::String("0.1.0".to_string()), + ); + package.insert( + "edition".to_string(), + toml::Value::String("2021".to_string()), + ); + + let mut tvm_ffi = Table::new(); + tvm_ffi.insert("path".to_string(), toml::Value::String(tvm_ffi_path_str)); + + let mut dependencies = Table::new(); + dependencies.insert("tvm-ffi".to_string(), toml::Value::Table(tvm_ffi)); + + let mut doc = Table::new(); + doc.insert("package".to_string(), toml::Value::Table(package)); + doc.insert("dependencies".to_string(), toml::Value::Table(dependencies)); + + Ok(toml::to_string(&toml::Value::Table(doc))?) } pub(crate) fn render_lib_rs() -> String { From 72304d7a3eeeefe6910c40439bd1eee948bad4d0 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 8 Feb 2026 18:00:17 +0800 Subject: [PATCH 08/29] Add AnyValue for typed Any support --- rust/tvm-ffi-stubgen/src/generate.rs | 26 +++++- rust/tvm-ffi/src/any.rs | 33 +++++++ rust/tvm-ffi/src/function_internal.rs | 118 ++++++++++++++++++++++++++ rust/tvm-ffi/src/lib.rs | 2 +- rust/tvm-ffi/src/type_traits.rs | 38 ++++++++- 5 files changed, 211 insertions(+), 6 deletions(-) diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index ae97ffcd1..d4e892da7 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -229,7 +229,7 @@ fn rust_type_for_schema( } _ => RustType::unsupported("tvm_ffi::Any"), }, - "Any" | "ffi.Any" => RustType::supported("tvm_ffi::Any"), + "Any" | "ffi.Any" => RustType::supported("tvm_ffi::AnyValue"), "Union" | "Variant" | "tuple" | "list" | "dict" => RustType::unsupported("tvm_ffi::Any"), other => match type_map.get(other) { Some(path) => RustType::object_wrapper(path), @@ -355,13 +355,23 @@ pub(crate) fn render_types_rs(root: &ModuleNode) -> String { let mut out = String::new(); out.push_str("use std::sync::LazyLock;\n"); out.push_str("use tvm_ffi::object::ObjectRef;\n"); - out.push_str("use tvm_ffi::{Any, AnyView, Function, Result};\n\n"); + out.push_str("use tvm_ffi::{Any, AnyView, Result};\n\n"); render_type_module(&mut out, root, 0); out } fn render_function_module(out: &mut String, node: &ModuleNode, indent: usize) { let indent_str = " ".repeat(indent); + if indent > 0 { + writeln!(out, "{}use std::sync::LazyLock;", indent_str).ok(); + writeln!( + out, + "{}use tvm_ffi::{{Any, AnyView, Function, Result}};", + indent_str + ) + .ok(); + writeln!(out).ok(); + } for func in &node.functions { render_function(out, func, indent); } @@ -374,6 +384,12 @@ fn render_function_module(out: &mut String, node: &ModuleNode, indent: usize) { fn render_type_module(out: &mut String, node: &ModuleNode, indent: usize) { let indent_str = " ".repeat(indent); + if indent > 0 { + writeln!(out, "{}use std::sync::LazyLock;", indent_str).ok(); + writeln!(out, "{}use tvm_ffi::object::ObjectRef;", indent_str).ok(); + writeln!(out, "{}use tvm_ffi::{{Any, AnyView, Result}};", indent_str).ok(); + writeln!(out).ok(); + } for ty in &node.types { render_type(out, ty, indent); } @@ -500,7 +516,7 @@ fn render_method_static(out: &mut String, ty: &TypeGen, method: &MethodGen, inde let static_name = static_ident("METHOD", &format!("{}::{}", ty.type_key, method.rust_name)); writeln!( out, - "{}static {}: LazyLock = LazyLock::new(|| Function::get_global(\"{}\").expect(\"missing method\"));", + "{}static {}: LazyLock = LazyLock::new(|| tvm_ffi::Function::get_global(\"{}\").expect(\"missing method\"));", indent_str, static_name, method.full_name ) .ok(); @@ -590,7 +606,9 @@ fn render_method(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usi writeln!( out, "{} let typed = tvm_ffi::into_typed_fn!(func.clone(), Fn({}) -> Result<{}>);", - indent_str, type_list, method.sig.ret.name + indent_str, + type_list, + method.sig.ret.typed_name() ) .ok(); let call_expr = format!("typed({})", render_method_call_args(method)); diff --git a/rust/tvm-ffi/src/any.rs b/rust/tvm-ffi/src/any.rs index ecf8b9ea3..e37a36df8 100644 --- a/rust/tvm-ffi/src/any.rs +++ b/rust/tvm-ffi/src/any.rs @@ -37,6 +37,13 @@ pub struct Any { data: TVMFFIAny, } +/// Managed Any wrapper that participates in typed signatures. +#[repr(transparent)] +#[derive(Clone)] +pub struct AnyValue { + inner: Any, +} + //--------------------- // AnyView //--------------------- @@ -144,6 +151,11 @@ impl Any { pub fn type_index(&self) -> i32 { self.data.type_index } + + #[inline] + pub(crate) fn as_raw_ffi_any(&self) -> TVMFFIAny { + self.data + } /// Try to query if stored typed in Any exactly matches the type T /// /// This function is fast in the case of failure and can be used to check @@ -200,6 +212,26 @@ impl Any { } } +impl AnyValue { + pub fn new(value: Any) -> Self { + Self { inner: value } + } + + pub fn as_any(&self) -> &Any { + &self.inner + } + + pub fn into_any(self) -> Any { + self.inner + } +} + +impl From for AnyValue { + fn from(value: Any) -> Self { + Self { inner: value } + } +} + impl Default for Any { fn default() -> Self { Self::new() @@ -329,6 +361,7 @@ crate::impl_try_from_any!( crate::string::String, crate::string::Bytes, crate::object::ObjectRef, + crate::any::AnyValue, tvm_ffi_sys::dlpack::DLDataType, tvm_ffi_sys::dlpack::DLDevice, ); diff --git a/rust/tvm-ffi/src/function_internal.rs b/rust/tvm-ffi/src/function_internal.rs index e059051c1..a1ae206b4 100644 --- a/rust/tvm-ffi/src/function_internal.rs +++ b/rust/tvm-ffi/src/function_internal.rs @@ -107,6 +107,65 @@ impl IntoArgHolder for &[u8] { } } +impl IntoArgHolder for crate::object::ObjectRef { + type Target = crate::object::ObjectRef; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl IntoArgHolder for crate::any::AnyValue { + type Target = crate::any::AnyValue; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl IntoArgHolder for crate::DLDevice { + type Target = crate::DLDevice; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl IntoArgHolder for crate::DLDataType { + type Target = crate::DLDataType; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl IntoArgHolder for crate::Array +where + T: crate::AnyCompatible + Clone + 'static, +{ + type Target = crate::Array; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl IntoArgHolder for crate::Map +where + K: crate::AnyCompatible + Clone + 'static, + V: crate::AnyCompatible + Clone + 'static, +{ + type Target = crate::Map; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl IntoArgHolder for Option +where + T: IntoArgHolder, +{ + type Target = Option; + fn into_arg_holder(self) -> Self::Target { + self.map(IntoArgHolder::into_arg_holder) + } +} + // helper trait to implement IntoArgHolderTuple to apply into_arg_holder to each element pub trait IntoArgHolderTuple { type Target; @@ -152,6 +211,65 @@ crate::impl_arg_into_ref!( bool, i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, f32, f64, String, Bytes ); +impl ArgIntoRef for crate::object::ObjectRef { + type Target = crate::object::ObjectRef; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl ArgIntoRef for crate::any::AnyValue { + type Target = crate::any::AnyValue; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl ArgIntoRef for crate::DLDevice { + type Target = crate::DLDevice; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl ArgIntoRef for crate::DLDataType { + type Target = crate::DLDataType; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl ArgIntoRef for Option +where + T: AnyCompatible, +{ + type Target = Option; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl ArgIntoRef for crate::Array +where + T: AnyCompatible + Clone, +{ + type Target = crate::Array; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl ArgIntoRef for crate::Map +where + K: AnyCompatible + Clone, + V: AnyCompatible + Clone, +{ + type Target = crate::Map; + fn to_ref(&self) -> &Self::Target { + self + } +} + //----------------------------------------------------------- // TupleAsPackedArgs // diff --git a/rust/tvm-ffi/src/lib.rs b/rust/tvm-ffi/src/lib.rs index 891f734e1..95338a5f6 100644 --- a/rust/tvm-ffi/src/lib.rs +++ b/rust/tvm-ffi/src/lib.rs @@ -31,7 +31,7 @@ pub mod string; pub mod type_traits; pub use tvm_ffi_sys; -pub use crate::any::{Any, AnyView}; +pub use crate::any::{Any, AnyValue, AnyView}; pub use crate::collections::array::Array; pub use crate::collections::map::Map; pub use crate::collections::shape::Shape; diff --git a/rust/tvm-ffi/src/type_traits.rs b/rust/tvm-ffi/src/type_traits.rs index d39da4b46..4b676915e 100644 --- a/rust/tvm-ffi/src/type_traits.rs +++ b/rust/tvm-ffi/src/type_traits.rs @@ -16,8 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +use crate::any::{Any, AnyValue}; use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; -use tvm_ffi_sys::{TVMFFIAny, TVMFFIGetTypeInfo}; +use tvm_ffi_sys::{TVMFFIAny, TVMFFIAnyViewToOwnedAny, TVMFFIGetTypeInfo}; //----------------------------------------------------- // AnyCompatible @@ -198,6 +199,41 @@ unsafe impl AnyCompatible for Option { } } +/// AnyCompatible for AnyValue +unsafe impl AnyCompatible for AnyValue { + fn type_str() -> String { + "Any".to_string() + } + + unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { + *data = src.as_any().as_raw_ffi_any(); + } + + unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) { + *data = Any::into_raw_ffi_any(src.into_any()); + } + + unsafe fn check_any_strict(_data: &TVMFFIAny) -> bool { + true + } + + unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { + let mut owned = TVMFFIAny::new(); + crate::check_safe_call!(TVMFFIAnyViewToOwnedAny(data, &mut owned)).unwrap(); + AnyValue::from(Any::from_raw_ffi_any(owned)) + } + + unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self { + let raw = *data; + *data = TVMFFIAny::new(); + AnyValue::from(Any::from_raw_ffi_any(raw)) + } + + unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result { + Ok(Self::copy_from_any_view_after_check(data)) + } +} + /// AnyCompatible for void* unsafe impl AnyCompatible for *mut core::ffi::c_void { unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { From 6d17273a91f9e88e7d1757efc23f83e6b1a582fe Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 8 Feb 2026 18:01:22 +0800 Subject: [PATCH 09/29] Avoid AnyValue TryFrom overlap --- rust/tvm-ffi/src/any.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/rust/tvm-ffi/src/any.rs b/rust/tvm-ffi/src/any.rs index e37a36df8..7b5bac401 100644 --- a/rust/tvm-ffi/src/any.rs +++ b/rust/tvm-ffi/src/any.rs @@ -361,7 +361,6 @@ crate::impl_try_from_any!( crate::string::String, crate::string::Bytes, crate::object::ObjectRef, - crate::any::AnyValue, tvm_ffi_sys::dlpack::DLDataType, tvm_ffi_sys::dlpack::DLDevice, ); From c40f7996c8f734ffc2da47f1efdcbf76c444b4a1 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 8 Feb 2026 18:05:08 +0800 Subject: [PATCH 10/29] Extend typed macro arity --- rust/tvm-ffi-stubgen/src/model.rs | 2 +- rust/tvm-ffi/src/function_internal.rs | 8 +++ rust/tvm-ffi/src/macros.rs | 85 ++++++++++++++++++++++++++- 3 files changed, 93 insertions(+), 2 deletions(-) diff --git a/rust/tvm-ffi-stubgen/src/model.rs b/rust/tvm-ffi-stubgen/src/model.rs index 0997a73e6..d4d2dcfdc 100644 --- a/rust/tvm-ffi-stubgen/src/model.rs +++ b/rust/tvm-ffi-stubgen/src/model.rs @@ -77,7 +77,7 @@ impl FunctionSig { } pub(crate) fn from_types(args: Vec, ret: RustType) -> Self { - let typed = args.iter().all(|arg| arg.supported) && ret.supported; + let typed = args.len() <= 12 && args.iter().all(|arg| arg.supported) && ret.supported; Self { args, ret, diff --git a/rust/tvm-ffi/src/function_internal.rs b/rust/tvm-ffi/src/function_internal.rs index a1ae206b4..8bf4505e4 100644 --- a/rust/tvm-ffi/src/function_internal.rs +++ b/rust/tvm-ffi/src/function_internal.rs @@ -195,6 +195,10 @@ impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4; 0, 1, 2, 3, 4); impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5; 0, 1, 2, 3, 4, 5); impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5, T6; 0, 1, 2, 3, 4, 5, 6); impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5, T6, T7; 0, 1, 2, 3, 4, 5, 6, 7); +impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8; 0, 1, 2, 3, 4, 5, 6, 7, 8); +impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); +impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); +impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11); //------------------------------------------------------------ // ArgIntoRef @@ -309,3 +313,7 @@ impl_tuple_as_packed_args!(5; T0, T1, T2, T3, T4; 0, 1, 2, 3, 4); impl_tuple_as_packed_args!(6; T0, T1, T2, T3, T4, T5; 0, 1, 2, 3, 4, 5); impl_tuple_as_packed_args!(7; T0, T1, T2, T3, T4, T5, T6; 0, 1, 2, 3, 4, 5, 6); impl_tuple_as_packed_args!(8; T0, T1, T2, T3, T4, T5, T6, T7; 0, 1, 2, 3, 4, 5, 6, 7); +impl_tuple_as_packed_args!(9; T0, T1, T2, T3, T4, T5, T6, T7, T8; 0, 1, 2, 3, 4, 5, 6, 7, 8); +impl_tuple_as_packed_args!(10; T0, T1, T2, T3, T4, T5, T6, T7, T8, T9; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); +impl_tuple_as_packed_args!(11; T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); +impl_tuple_as_packed_args!(12; T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11); diff --git a/rust/tvm-ffi/src/macros.rs b/rust/tvm-ffi/src/macros.rs index 95c2b5ba4..bf3f5ac11 100644 --- a/rust/tvm-ffi/src/macros.rs +++ b/rust/tvm-ffi/src/macros.rs @@ -312,7 +312,7 @@ macro_rules! tvm_ffi_dll_export_typed_func { /// Since the ffi mechanism requires us to pass arguments by reference. /// /// # Supported Argument Counts -/// This macro supports functions with 0 to 8 arguments. +/// This macro supports functions with 0 to 12 arguments. ///----------------------------------------------------------- #[macro_export] macro_rules! into_typed_fn { @@ -395,4 +395,87 @@ macro_rules! into_typed_fn { Ok(_f.call_tuple_with_len::<8, _>(tuple_args)?.try_into()?) } }}; + // Case for 9 arguments + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty) + -> $ret_ty:ty) => {{ + let _f = $f; + move |a0: $t0, + a1: $t1, + a2: $t2, + a3: $t3, + a4: $t4, + a5: $t5, + a6: $t6, + a7: $t7, + a8: $t8| + -> $ret_ty { + use $crate::function_internal::IntoArgHolderTuple; + let tuple_args = (a0, a1, a2, a3, a4, a5, a6, a7, a8).into_arg_holder_tuple(); + Ok(_f.call_tuple_with_len::<9, _>(tuple_args)?.try_into()?) + } + }}; + // Case for 10 arguments + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty, $t9:ty) + -> $ret_ty:ty) => {{ + let _f = $f; + move |a0: $t0, + a1: $t1, + a2: $t2, + a3: $t3, + a4: $t4, + a5: $t5, + a6: $t6, + a7: $t7, + a8: $t8, + a9: $t9| + -> $ret_ty { + use $crate::function_internal::IntoArgHolderTuple; + let tuple_args = (a0, a1, a2, a3, a4, a5, a6, a7, a8, a9).into_arg_holder_tuple(); + Ok(_f.call_tuple_with_len::<10, _>(tuple_args)?.try_into()?) + } + }}; + // Case for 11 arguments + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty, $t9:ty, $t10:ty) + -> $ret_ty:ty) => {{ + let _f = $f; + move |a0: $t0, + a1: $t1, + a2: $t2, + a3: $t3, + a4: $t4, + a5: $t5, + a6: $t6, + a7: $t7, + a8: $t8, + a9: $t9, + a10: $t10| + -> $ret_ty { + use $crate::function_internal::IntoArgHolderTuple; + let tuple_args = (a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10).into_arg_holder_tuple(); + Ok(_f.call_tuple_with_len::<11, _>(tuple_args)?.try_into()?) + } + }}; + // Case for 12 arguments + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty, $t9:ty, $t10:ty, $t11:ty) + -> $ret_ty:ty) => {{ + let _f = $f; + move |a0: $t0, + a1: $t1, + a2: $t2, + a3: $t3, + a4: $t4, + a5: $t5, + a6: $t6, + a7: $t7, + a8: $t8, + a9: $t9, + a10: $t10, + a11: $t11| + -> $ret_ty { + use $crate::function_internal::IntoArgHolderTuple; + let tuple_args = + (a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11).into_arg_holder_tuple(); + Ok(_f.call_tuple_with_len::<12, _>(tuple_args)?.try_into()?) + } + }}; } From be5076fcaa1cecc7ae6c8539efbd4bc2fb5ca9a2 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 8 Feb 2026 19:58:08 +0800 Subject: [PATCH 11/29] Improve stubgen wrappers and output --- rust/tvm-ffi-macros/src/object_macros.rs | 28 +++-- rust/tvm-ffi-stubgen/src/generate.rs | 141 ++++++++++++++--------- rust/tvm-ffi-stubgen/src/lib.rs | 2 + rust/tvm-ffi-stubgen/src/model.rs | 21 +++- rust/tvm-ffi/src/lib.rs | 2 + rust/tvm-ffi/src/macros.rs | 52 +++++++++ rust/tvm-ffi/src/object_wrapper.rs | 124 ++++++++++++++++++++ 7 files changed, 304 insertions(+), 66 deletions(-) create mode 100644 rust/tvm-ffi/src/object_wrapper.rs diff --git a/rust/tvm-ffi-macros/src/object_macros.rs b/rust/tvm-ffi-macros/src/object_macros.rs index 8154709df..b6566eb48 100644 --- a/rust/tvm-ffi-macros/src/object_macros.rs +++ b/rust/tvm-ffi-macros/src/object_macros.rs @@ -169,9 +169,27 @@ pub fn derive_object_ref(input: proc_macro::TokenStream) -> TokenStream { unsafe fn check_any_strict(data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny) -> bool { type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> ::ContainerType; - let type_index = + let target_index = ::type_index(); - data.type_index == type_index as i32 + if data.type_index == target_index as i32 { + return true; + } + let info = #tvm_ffi_crate::tvm_ffi_sys::TVMFFIGetTypeInfo(data.type_index); + if info.is_null() { + return false; + } + let info = &*info; + let ancestors = info.type_acenstors; + if ancestors.is_null() { + return false; + } + for depth in 0..info.type_depth { + let ancestor = *ancestors.add(depth as usize); + if !ancestor.is_null() && (*ancestor).type_index == target_index { + return true; + } + } + false } unsafe fn copy_from_any_view_after_check( @@ -223,11 +241,7 @@ pub fn derive_object_ref(input: proc_macro::TokenStream) -> TokenStream { unsafe fn try_cast_from_any_view( data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny ) -> Result { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - let type_index = - ::type_index(); - if data.type_index == type_index as i32 { + if Self::check_any_strict(data) { Ok(Self::copy_from_any_view_after_check(data)) } else { Err(()) diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index d4e892da7..869d7f2b4 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -333,29 +333,91 @@ pub(crate) fn render_cargo_toml( pub(crate) fn render_lib_rs() -> String { let mut out = String::new(); - out.push_str("pub mod functions;\n"); - out.push_str("pub mod types;\n\n"); - out.push_str("pub use functions::*;\n"); - out.push_str("pub use types::*;\n\n"); - out.push_str("pub fn load_library(path: &str) -> tvm_ffi::Result {\n"); - out.push_str(" tvm_ffi::Module::load_from_file(path)\n"); - out.push_str("}\n"); + out.push_str( + r#"pub mod functions; +pub mod types; + +pub use functions::*; +pub use types::*; + +pub fn load_library(path: &str) -> tvm_ffi::Result { + tvm_ffi::Module::load_from_file(path) +} +"#, + ); + out +} + +pub(crate) fn render_build_rs() -> String { + let mut out = String::new(); + out.push_str( + r#"use std::env; +use std::process::Command; + +fn update_ld_library_path(lib_dir: &str) { + let os_env_var = match env::var("CARGO_CFG_TARGET_OS").as_deref() { + Ok("windows") => "PATH", + Ok("macos") => "DYLD_LIBRARY_PATH", + Ok("linux") => "LD_LIBRARY_PATH", + _ => "", + }; + if os_env_var.is_empty() { + return; + } + let current_val = env::var(os_env_var).unwrap_or_else(|_| String::new()); + let separator = if os_env_var == "PATH" { ";" } else { ":" }; + let new_ld_path = if current_val.is_empty() { + lib_dir.to_string() + } else { + format!("{}{}{}", current_val, separator, lib_dir) + }; + println!("cargo:rustc-env={}={}", os_env_var, new_ld_path); +} + +fn main() { + let output = Command::new("tvm-ffi-config") + .arg("--libdir") + .output() + .expect("Failed to run tvm-ffi-config"); + if !output.status.success() { + panic!("tvm-ffi-config --libdir failed"); + } + let lib_dir = String::from_utf8(output.stdout) + .expect("Invalid UTF-8 output from tvm-ffi-config") + .trim() + .to_string(); + if lib_dir.is_empty() { + panic!("tvm-ffi-config returned empty library path"); + } + println!("cargo:rustc-link-search=native={}", lib_dir); + println!("cargo:rustc-link-lib=dylib=tvm_ffi"); + update_ld_library_path(&lib_dir); +} +"#, + ); out } pub(crate) fn render_functions_rs(root: &ModuleNode) -> String { let mut out = String::new(); - out.push_str("use std::sync::LazyLock;\n"); - out.push_str("use tvm_ffi::{Any, AnyView, Function, Result};\n\n"); + out.push_str( + r#"use std::sync::LazyLock; +use tvm_ffi::{Any, AnyView, Function, Result}; + +"#, + ); render_function_module(&mut out, root, 0); out } pub(crate) fn render_types_rs(root: &ModuleNode) -> String { let mut out = String::new(); - out.push_str("use std::sync::LazyLock;\n"); - out.push_str("use tvm_ffi::object::ObjectRef;\n"); - out.push_str("use tvm_ffi::{Any, AnyView, Result};\n\n"); + out.push_str( + r#"use std::sync::LazyLock; +use tvm_ffi::{Any, AnyView, Result}; + +"#, + ); render_type_module(&mut out, root, 0); out } @@ -386,7 +448,6 @@ fn render_type_module(out: &mut String, node: &ModuleNode, indent: usize) { let indent_str = " ".repeat(indent); if indent > 0 { writeln!(out, "{}use std::sync::LazyLock;", indent_str).ok(); - writeln!(out, "{}use tvm_ffi::object::ObjectRef;", indent_str).ok(); writeln!(out, "{}use tvm_ffi::{{Any, AnyView, Result}};", indent_str).ok(); writeln!(out).ok(); } @@ -443,7 +504,7 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { "{} let typed = tvm_ffi::into_typed_fn!(func.clone(), Fn({}) -> Result<{}>);", indent_str, render_type_list(&func.sig.args), - func.sig.ret.typed_name() + func.sig.ret.typed_ret_name() ) .ok(); let call_expr = format!("typed({})", render_call_args_typed(&func.sig.args)); @@ -451,7 +512,9 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { out, "{} {}", indent_str, - func.sig.ret.wrap_return(&call_expr) + func.sig + .ret + .wrap_typed_return(&call_expr, func.sig.ret.typed_ret_name()) ) .ok(); writeln!(out, "{}}}", indent_str).ok(); @@ -460,45 +523,12 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { fn render_type(out: &mut String, ty: &TypeGen, indent: usize) { let indent_str = " ".repeat(indent); - writeln!(out, "{}#[derive(Clone)]", indent_str).ok(); - writeln!(out, "{}pub struct {} {{", indent_str, ty.rust_name).ok(); - writeln!(out, "{} inner: ObjectRef,", indent_str).ok(); - writeln!(out, "{}}}\n", indent_str).ok(); - - writeln!(out, "{}impl {} {{", indent_str, ty.rust_name).ok(); - writeln!( - out, - "{} pub fn from_object(inner: ObjectRef) -> Self {{", - indent_str - ) - .ok(); - writeln!(out, "{} Self {{ inner }}", indent_str).ok(); - writeln!(out, "{} }}", indent_str).ok(); writeln!( out, - "{} pub fn as_object_ref(&self) -> &ObjectRef {{", - indent_str + "{}tvm_ffi::define_object_wrapper!({}, \"{}\");\n", + indent_str, ty.rust_name, ty.type_key ) .ok(); - writeln!(out, "{} &self.inner", indent_str).ok(); - writeln!(out, "{} }}", indent_str).ok(); - writeln!(out, "{}}}\n", indent_str).ok(); - - writeln!( - out, - "{}impl From for {} {{", - indent_str, ty.rust_name - ) - .ok(); - writeln!( - out, - "{} fn from(inner: ObjectRef) -> Self {{", - indent_str - ) - .ok(); - writeln!(out, "{} Self {{ inner }}", indent_str).ok(); - writeln!(out, "{} }}", indent_str).ok(); - writeln!(out, "{}}}\n", indent_str).ok(); for method in &ty.methods { render_method_static(out, ty, method, indent); @@ -599,7 +629,7 @@ fn render_method(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usi .sig .args .iter() - .map(|arg| arg.typed_name().to_string()), + .map(|arg| arg.typed_arg_name().to_string()), ); types.join(", ") }; @@ -608,7 +638,7 @@ fn render_method(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usi "{} let typed = tvm_ffi::into_typed_fn!(func.clone(), Fn({}) -> Result<{}>);", indent_str, type_list, - method.sig.ret.typed_name() + method.sig.ret.typed_ret_name() ) .ok(); let call_expr = format!("typed({})", render_method_call_args(method)); @@ -616,7 +646,10 @@ fn render_method(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usi out, "{} {}", indent_str, - method.sig.ret.wrap_return(&call_expr) + method + .sig + .ret + .wrap_typed_return(&call_expr, method.sig.ret.typed_ret_name()) ) .ok(); writeln!(out, "{}}}", indent_str).ok(); @@ -632,7 +665,7 @@ fn render_args(args: &[RustType]) -> String { fn render_type_list(args: &[RustType]) -> String { args.iter() - .map(|arg| arg.typed_name().to_string()) + .map(|arg| arg.typed_arg_name().to_string()) .collect::>() .join(", ") } diff --git a/rust/tvm-ffi-stubgen/src/lib.rs b/rust/tvm-ffi-stubgen/src/lib.rs index 9128e9fc4..553754379 100644 --- a/rust/tvm-ffi-stubgen/src/lib.rs +++ b/rust/tvm-ffi-stubgen/src/lib.rs @@ -77,10 +77,12 @@ pub fn run(args: Args) -> Result<(), Box> { let lib_rs = generate::render_lib_rs(); let functions_rs = generate::render_functions_rs(&functions_root); let types_rs = generate::render_types_rs(&types_root); + let build_rs = generate::render_build_rs(); let src_dir = args.out_dir.join("src"); std::fs::create_dir_all(&src_dir)?; std::fs::write(args.out_dir.join("Cargo.toml"), cargo_toml)?; + std::fs::write(args.out_dir.join("build.rs"), build_rs)?; std::fs::write(src_dir.join("lib.rs"), lib_rs)?; std::fs::write(src_dir.join("functions.rs"), functions_rs)?; std::fs::write(src_dir.join("types.rs"), types_rs)?; diff --git a/rust/tvm-ffi-stubgen/src/model.rs b/rust/tvm-ffi-stubgen/src/model.rs index d4d2dcfdc..237f11ac1 100644 --- a/rust/tvm-ffi-stubgen/src/model.rs +++ b/rust/tvm-ffi-stubgen/src/model.rs @@ -111,13 +111,20 @@ impl RustType { } } - pub(crate) fn typed_name(&self) -> &str { + pub(crate) fn typed_arg_name(&self) -> &str { match self.kind { RustTypeKind::Plain => &self.name, RustTypeKind::ObjectWrapper => "tvm_ffi::object::ObjectRef", } } + pub(crate) fn typed_ret_name(&self) -> &str { + match self.kind { + RustTypeKind::Plain => &self.name, + RustTypeKind::ObjectWrapper => &self.name, + } + } + pub(crate) fn call_expr(&self, arg_name: &str) -> String { match self.kind { RustTypeKind::Plain => arg_name.to_string(), @@ -125,14 +132,18 @@ impl RustType { } } - pub(crate) fn wrap_return(&self, expr: &str) -> String { + pub(crate) fn wrap_typed_return(&self, expr: &str, typed_ret_name: &str) -> String { match self.kind { RustTypeKind::Plain => expr.to_string(), RustTypeKind::ObjectWrapper => { - if self.name == "Self" { - format!("{}.map(Self::from)", expr) + if typed_ret_name == "tvm_ffi::object::ObjectRef" { + if self.name == "Self" { + format!("{}.map(Self::from)", expr) + } else { + format!("{}.map({}::from)", expr, self.name) + } } else { - format!("{}.map({}::from)", expr, self.name) + expr.to_string() } } } diff --git a/rust/tvm-ffi/src/lib.rs b/rust/tvm-ffi/src/lib.rs index 95338a5f6..13baa91a4 100644 --- a/rust/tvm-ffi/src/lib.rs +++ b/rust/tvm-ffi/src/lib.rs @@ -27,6 +27,7 @@ pub mod function; pub mod function_internal; pub mod macros; pub mod object; +pub mod object_wrapper; pub mod string; pub mod type_traits; pub use tvm_ffi_sys; @@ -45,6 +46,7 @@ pub use crate::error::{ pub use crate::extra::module::Module; pub use crate::function::Function; pub use crate::object::{Object, ObjectArc, ObjectCore, ObjectCoreWithExtraItems, ObjectRefCore}; +pub use crate::object_wrapper::ObjectWrapper; pub use crate::string::{Bytes, String}; pub use crate::type_traits::AnyCompatible; diff --git a/rust/tvm-ffi/src/macros.rs b/rust/tvm-ffi/src/macros.rs index bf3f5ac11..4199ffa79 100644 --- a/rust/tvm-ffi/src/macros.rs +++ b/rust/tvm-ffi/src/macros.rs @@ -238,6 +238,58 @@ macro_rules! impl_arg_into_ref { } } +/// Define a stubgen-oriented object wrapper type. +/// +/// This macro is intended for code emitted by the Rust stub generator. +/// It is not meant as a general-purpose user-facing API. +#[macro_export] +macro_rules! define_object_wrapper { + ($name:ident, $type_key:expr) => { + #[derive(Clone)] + pub struct $name { + inner: $crate::object::ObjectRef, + } + + impl $name { + pub fn from_object(inner: $crate::object::ObjectRef) -> Self { + Self { inner } + } + + pub fn as_object_ref(&self) -> &$crate::object::ObjectRef { + &self.inner + } + + pub fn into_object_ref(self) -> $crate::object::ObjectRef { + self.inner + } + } + + impl From<$crate::object::ObjectRef> for $name { + fn from(inner: $crate::object::ObjectRef) -> Self { + Self::from_object(inner) + } + } + + impl $crate::object_wrapper::ObjectWrapper for $name { + const TYPE_KEY: &'static str = $type_key; + + fn from_object(inner: $crate::object::ObjectRef) -> Self { + Self::from_object(inner) + } + + fn as_object_ref(&self) -> &$crate::object::ObjectRef { + self.as_object_ref() + } + + fn into_object_ref(self) -> $crate::object::ObjectRef { + self.into_object_ref() + } + } + + $crate::impl_try_from_any!($name); + }; +} + // ---------------------------------------------------------------------------- // Macros for function definitions // ---------------------------------------------------------------------------- diff --git a/rust/tvm-ffi/src/object_wrapper.rs b/rust/tvm-ffi/src/object_wrapper.rs new file mode 100644 index 000000000..7b9c81fea --- /dev/null +++ b/rust/tvm-ffi/src/object_wrapper.rs @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::object::{Object, ObjectArc, ObjectRef, ObjectRefCore}; +use crate::type_traits::AnyCompatible; +use tvm_ffi_sys::{ + TVMFFIAny, TVMFFIByteArray, TVMFFIGetTypeInfo, TVMFFIObject, TVMFFITypeKeyToIndex, +}; + +/// Runtime support for stubgen-generated object wrappers. +/// +/// This module is intended for code emitted by the Rust stub generator and is +/// not meant as a general-purpose user-facing API. +pub trait ObjectWrapper: Clone { + const TYPE_KEY: &'static str; + fn from_object(inner: ObjectRef) -> Self; + fn as_object_ref(&self) -> &ObjectRef; + fn into_object_ref(self) -> ObjectRef; +} + +fn type_index_for_key(type_key: &'static str) -> Option { + let key = unsafe { TVMFFIByteArray::from_str(type_key) }; + let mut index = 0i32; + let code = unsafe { TVMFFITypeKeyToIndex(&key, &mut index) }; + if code == 0 { + Some(index) + } else { + None + } +} + +unsafe fn is_instance_type(type_index: i32, target_index: i32) -> bool { + if type_index == target_index { + return true; + } + let info = TVMFFIGetTypeInfo(type_index); + if info.is_null() { + return false; + } + let info = &*info; + let ancestors = info.type_acenstors; + if ancestors.is_null() { + return false; + } + for depth in 0..info.type_depth { + let ancestor = *ancestors.add(depth as usize); + if !ancestor.is_null() && (*ancestor).type_index == target_index { + return true; + } + } + false +} + +unsafe impl AnyCompatible for T { + fn type_str() -> String { + T::TYPE_KEY.to_string() + } + + unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { + let obj = src.as_object_ref(); + let arc = ::data(obj); + let raw = ObjectArc::as_raw(arc) as *mut TVMFFIObject; + data.type_index = (*raw).type_index; + data.small_str_len = 0; + data.data_union.v_obj = raw; + } + + unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) { + let obj = src.into_object_ref(); + let arc = ::into_data(obj); + let raw = ObjectArc::into_raw(arc) as *mut TVMFFIObject; + data.type_index = (*raw).type_index; + data.small_str_len = 0; + data.data_union.v_obj = raw; + } + + unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { + let Some(target_index) = type_index_for_key(T::TYPE_KEY) else { + return false; + }; + is_instance_type(data.type_index, target_index) + } + + unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { + let ptr = data.data_union.v_obj as *mut TVMFFIObject; + crate::object::unsafe_::inc_ref(ptr); + let arc = ObjectArc::from_raw(ptr as *mut Object); + let obj = ::from_data(arc); + T::from_object(obj) + } + + unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self { + let ptr = data.data_union.v_obj as *mut TVMFFIObject; + let arc = ObjectArc::from_raw(ptr as *mut Object); + data.type_index = crate::TypeIndex::kTVMFFINone as i32; + data.data_union.v_int64 = 0; + let obj = ::from_data(arc); + T::from_object(obj) + } + + unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result { + if Self::check_any_strict(data) { + Ok(Self::copy_from_any_view_after_check(data)) + } else { + Err(()) + } + } +} From addd67dff2d05446cc9145c24db39a73ef8e2946 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Mon, 9 Feb 2026 20:40:32 +0800 Subject: [PATCH 12/29] Add stubgen fields; first commit to build TileLang IR --- rust/tvm-ffi-macros/src/object_macros.rs | 302 +++++++++++++++-------- rust/tvm-ffi-stubgen/src/generate.rs | 73 +++++- rust/tvm-ffi-stubgen/src/model.rs | 8 + rust/tvm-ffi/src/any.rs | 9 + rust/tvm-ffi/src/error.rs | 7 + rust/tvm-ffi/src/function_internal.rs | 11 + rust/tvm-ffi/src/object_wrapper.rs | 128 +++++++++- 7 files changed, 430 insertions(+), 108 deletions(-) diff --git a/rust/tvm-ffi-macros/src/object_macros.rs b/rust/tvm-ffi-macros/src/object_macros.rs index b6566eb48..01ef8afed 100644 --- a/rust/tvm-ffi-macros/src/object_macros.rs +++ b/rust/tvm-ffi-macros/src/object_macros.rs @@ -124,131 +124,221 @@ pub fn derive_object_ref(input: proc_macro::TokenStream) -> TokenStream { } .expect("First field must be `data: ObjectArc`"); - let mut expanded = quote! { - unsafe impl #tvm_ffi_crate::object::ObjectRefCore for #struct_name { - type ContainerType = <#data_ty as std::ops::Deref>::Target; - #[inline] - fn data(this: &Self) -> &ObjectArc { - &this.data - } - #[inline] - fn into_data(this: Self) -> ObjectArc { - this.data - } - #[inline] - fn from_data(data: ObjectArc) -> Self { - Self { data} - } - } + let is_object_ref = struct_name == syn::Ident::new("ObjectRef", struct_name.span()); - // implement AnyCompatible for #struct_name - unsafe impl #tvm_ffi_crate::type_traits::AnyCompatible for #struct_name { - fn type_str() -> String { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - ::TYPE_KEY.into() - } + let any_compatible_tokens = if is_object_ref { + quote! { + // implement AnyCompatible for #struct_name + unsafe impl #tvm_ffi_crate::type_traits::AnyCompatible for #struct_name { + fn type_str() -> String { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + ::TYPE_KEY.into() + } - unsafe fn copy_to_any_view( - src: &Self, - data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny - ) { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - let type_index = - ::type_index(); - data.type_index = type_index as i32; - data.small_str_len = 0; - let data_ptr = #tvm_ffi_crate::object::ObjectArc::::as_raw( - &src.data - ); - data.data_union.v_obj = - data_ptr as *mut ContainerType as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject; - } + unsafe fn copy_to_any_view( + src: &Self, + data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let data_ptr = #tvm_ffi_crate::object::ObjectArc::::as_raw( + &src.data + ) as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject; + data.type_index = (*data_ptr).type_index; + data.small_str_len = 0; + data.data_union.v_obj = data_ptr; + } - unsafe fn check_any_strict(data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny) -> bool { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - let target_index = - ::type_index(); - if data.type_index == target_index as i32 { - return true; + unsafe fn check_any_strict(data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny) -> bool { + data.type_index >= #tvm_ffi_crate::TypeIndex::kTVMFFIStaticObjectBegin as i32 } - let info = #tvm_ffi_crate::tvm_ffi_sys::TVMFFIGetTypeInfo(data.type_index); - if info.is_null() { - return false; + + unsafe fn copy_from_any_view_after_check( + data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) -> Self { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let data_ptr = data.data_union.v_obj; + // need to increase ref because original weak ptr + // do not own the code + #tvm_ffi_crate::object::unsafe_::inc_ref( + data_ptr as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject + ); + Self { + data : #tvm_ffi_crate::object::ObjectArc::from_raw( + data_ptr as *mut ContainerType + ) + } } - let info = &*info; - let ancestors = info.type_acenstors; - if ancestors.is_null() { - return false; + + unsafe fn move_to_any( + src: Self, + data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let data_ptr = #tvm_ffi_crate::object::ObjectArc::into_raw( + src.data + ) as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject; + data.type_index = (*data_ptr).type_index; + data.small_str_len = 0; + data.data_union.v_obj = data_ptr; } - for depth in 0..info.type_depth { - let ancestor = *ancestors.add(depth as usize); - if !ancestor.is_null() && (*ancestor).type_index == target_index { - return true; + + unsafe fn move_from_any_after_check( + data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) -> Self { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let data_ptr = data.data_union.v_obj as *mut ContainerType; + Self { + data : #tvm_ffi_crate::object::ObjectArc::::from_raw(data_ptr) } } - false - } - unsafe fn copy_from_any_view_after_check( - data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny - ) -> Self { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - let data_ptr = data.data_union.v_obj; - // need to increase ref because original weak ptr - // do not own the code - #tvm_ffi_crate::object::unsafe_::inc_ref( - data_ptr as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject - ); - Self { - data : #tvm_ffi_crate::object::ObjectArc::from_raw( - data_ptr as *mut ContainerType - ) + unsafe fn try_cast_from_any_view( + data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) -> Result { + if Self::check_any_strict(data) { + Ok(Self::copy_from_any_view_after_check(data)) + } else { + Err(()) + } } } + } + } else { + quote! { + // implement AnyCompatible for #struct_name + unsafe impl #tvm_ffi_crate::type_traits::AnyCompatible for #struct_name { + fn type_str() -> String { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + ::TYPE_KEY.into() + } - unsafe fn move_to_any( - src: Self, - data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny - ) { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - let type_index = - ::type_index(); - data.type_index = type_index as i32; - data.small_str_len = 0; - let data_ptr = #tvm_ffi_crate::object::ObjectArc::into_raw( - src.data - ); - data.data_union.v_obj = - data_ptr as *mut ContainerType as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject; - } + unsafe fn copy_to_any_view( + src: &Self, + data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let type_index = + ::type_index(); + data.type_index = type_index as i32; + data.small_str_len = 0; + let data_ptr = #tvm_ffi_crate::object::ObjectArc::::as_raw( + &src.data + ); + data.data_union.v_obj = + data_ptr as *mut ContainerType as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject; + } + + unsafe fn check_any_strict(data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny) -> bool { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let target_index = + ::type_index(); + if data.type_index == target_index as i32 { + return true; + } + let info = #tvm_ffi_crate::tvm_ffi_sys::TVMFFIGetTypeInfo(data.type_index); + if info.is_null() { + return false; + } + let info = &*info; + let ancestors = info.type_acenstors; + if ancestors.is_null() { + return false; + } + for depth in 0..info.type_depth { + let ancestor = *ancestors.add(depth as usize); + if !ancestor.is_null() && (*ancestor).type_index == target_index { + return true; + } + } + false + } - unsafe fn move_from_any_after_check( - data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny - ) -> Self { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - let data_ptr = data.data_union.v_obj as *mut ContainerType; - Self { - data : #tvm_ffi_crate::object::ObjectArc::::from_raw(data_ptr) + unsafe fn copy_from_any_view_after_check( + data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) -> Self { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let data_ptr = data.data_union.v_obj; + // need to increase ref because original weak ptr + // do not own the code + #tvm_ffi_crate::object::unsafe_::inc_ref( + data_ptr as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject + ); + Self { + data : #tvm_ffi_crate::object::ObjectArc::from_raw( + data_ptr as *mut ContainerType + ) + } } - } - unsafe fn try_cast_from_any_view( - data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny - ) -> Result { - if Self::check_any_strict(data) { - Ok(Self::copy_from_any_view_after_check(data)) - } else { - Err(()) + unsafe fn move_to_any( + src: Self, + data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let type_index = + ::type_index(); + data.type_index = type_index as i32; + data.small_str_len = 0; + let data_ptr = #tvm_ffi_crate::object::ObjectArc::into_raw( + src.data + ); + data.data_union.v_obj = + data_ptr as *mut ContainerType as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject; + } + + unsafe fn move_from_any_after_check( + data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) -> Self { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let data_ptr = data.data_union.v_obj as *mut ContainerType; + Self { + data : #tvm_ffi_crate::object::ObjectArc::::from_raw(data_ptr) + } + } + + unsafe fn try_cast_from_any_view( + data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) -> Result { + if Self::check_any_strict(data) { + Ok(Self::copy_from_any_view_after_check(data)) + } else { + Err(()) + } } } } }; + + let mut expanded = quote! { + unsafe impl #tvm_ffi_crate::object::ObjectRefCore for #struct_name { + type ContainerType = <#data_ty as std::ops::Deref>::Target; + #[inline] + fn data(this: &Self) -> &ObjectArc { + &this.data + } + #[inline] + fn into_data(this: Self) -> ObjectArc { + this.data + } + #[inline] + fn from_data(data: ObjectArc) -> Self { + Self { data} + } + } + + #any_compatible_tokens + }; // skip ObjectRef since it can create circular dependency with any.rs if struct_name != "ObjectRef" { expanded.extend(quote! { diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 869d7f2b4..27af14080 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -17,7 +17,7 @@ use crate::cli::Args; use crate::ffi; -use crate::model::{FunctionGen, FunctionSig, MethodGen, ModuleNode, RustType, TypeGen}; +use crate::model::{FieldGen, FunctionGen, FunctionSig, MethodGen, ModuleNode, RustType, TypeGen}; use crate::schema::{extract_type_schema, parse_type_schema, TypeSchema}; use crate::utils; use std::collections::BTreeMap; @@ -78,6 +78,7 @@ pub(crate) fn build_type_entries( let (mods, name) = split_name(key, prefix); let rust_name = sanitize_ident(&name, IdentStyle::Type); let mut methods = Vec::new(); + let mut fields = Vec::new(); if let Some(info) = ffi::get_type_info(key) { if info.num_methods > 0 && !info.methods.is_null() { let method_slice = @@ -105,6 +106,31 @@ pub(crate) fn build_type_entries( }); } } + if info.num_fields > 0 && !info.fields.is_null() { + let field_slice = + unsafe { std::slice::from_raw_parts(info.fields, info.num_fields as usize) }; + for field in field_slice { + let field_name = match ffi::byte_array_to_string_opt(&field.name) { + Some(name) => name, + None => continue, + }; + let rust_field_name = sanitize_ident(&field_name, IdentStyle::Function); + let meta = ffi::byte_array_to_string_opt(&field.metadata); + let schema = meta + .as_deref() + .and_then(extract_type_schema) + .and_then(|s| parse_type_schema(&s)); + let ty = match schema.as_ref() { + Some(schema) => rust_type_for_schema(schema, type_map, Some(key.as_str())), + None => RustType::unsupported("tvm_ffi::Any"), + }; + fields.push(FieldGen { + name: field_name, + rust_name: rust_field_name, + ty, + }); + } + } } out.push(( mods, @@ -112,6 +138,7 @@ pub(crate) fn build_type_entries( type_key: key.clone(), rust_name, methods, + fields, }, )); } @@ -530,17 +557,61 @@ fn render_type(out: &mut String, ty: &TypeGen, indent: usize) { ) .ok(); + for field in &ty.fields { + render_field_static(out, ty, field, indent); + } for method in &ty.methods { render_method_static(out, ty, method, indent); } writeln!(out, "{}impl {} {{", indent_str, ty.rust_name).ok(); + for field in &ty.fields { + render_field(out, ty, field, indent + 4); + } for method in &ty.methods { render_method(out, ty, method, indent + 4); } writeln!(out, "{}}}\n", indent_str).ok(); } +fn render_field_static(out: &mut String, ty: &TypeGen, field: &FieldGen, indent: usize) { + let indent_str = " ".repeat(indent); + let static_name = static_ident("FIELD", &format!("{}::{}", ty.type_key, field.name)); + writeln!( + out, + "{}static {}: LazyLock> = LazyLock::new(|| tvm_ffi::object_wrapper::FieldGetter::new(\"{}\", \"{}\").expect(\"missing field\"));", + indent_str, static_name, field.ty.name, ty.type_key, field.name + ) + .ok(); +} + +fn render_field(out: &mut String, ty: &TypeGen, field: &FieldGen, indent: usize) { + let indent_str = " ".repeat(indent); + let static_name = static_ident("FIELD", &format!("{}::{}", ty.type_key, field.name)); + writeln!( + out, + "{}pub fn {}(&self) -> Result<{}> {{", + indent_str, field.rust_name, field.ty.name + ) + .ok(); + if field.ty.name == "tvm_ffi::Any" { + writeln!( + out, + "{} {}.get_any(self.as_object_ref())", + indent_str, static_name + ) + .ok(); + } else { + writeln!( + out, + "{} {}.get(self.as_object_ref())", + indent_str, static_name + ) + .ok(); + } + writeln!(out, "{}}}", indent_str).ok(); +} + fn render_method_static(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usize) { let indent_str = " ".repeat(indent); let static_name = static_ident("METHOD", &format!("{}::{}", ty.type_key, method.rust_name)); diff --git a/rust/tvm-ffi-stubgen/src/model.rs b/rust/tvm-ffi-stubgen/src/model.rs index 237f11ac1..f2c441e5c 100644 --- a/rust/tvm-ffi-stubgen/src/model.rs +++ b/rust/tvm-ffi-stubgen/src/model.rs @@ -52,11 +52,19 @@ pub(crate) struct MethodGen { pub(crate) is_static: bool, } +#[derive(Debug, Clone)] +pub(crate) struct FieldGen { + pub(crate) name: String, + pub(crate) rust_name: String, + pub(crate) ty: RustType, +} + #[derive(Debug, Clone)] pub(crate) struct TypeGen { pub(crate) type_key: String, pub(crate) rust_name: String, pub(crate) methods: Vec, + pub(crate) fields: Vec, } #[derive(Debug, Default)] diff --git a/rust/tvm-ffi/src/any.rs b/rust/tvm-ffi/src/any.rs index 7b5bac401..931398937 100644 --- a/rust/tvm-ffi/src/any.rs +++ b/rust/tvm-ffi/src/any.rs @@ -232,6 +232,15 @@ impl From for AnyValue { } } +impl<'a> TryFrom> for AnyValue { + type Error = crate::error::Error; + + fn try_from(value: AnyView<'a>) -> Result { + Ok(AnyValue::from(Any::from(value))) + } +} + + impl Default for Any { fn default() -> Self { Self::new() diff --git a/rust/tvm-ffi/src/error.rs b/rust/tvm-ffi/src/error.rs index cb6dabb15..2689f3790 100644 --- a/rust/tvm-ffi/src/error.rs +++ b/rust/tvm-ffi/src/error.rs @@ -18,6 +18,7 @@ */ use crate::derive::{Object, ObjectRef}; use crate::object::{Object, ObjectArc}; +use std::convert::Infallible; use std::ffi::c_void; use tvm_ffi_sys::TVMFFIBacktraceUpdateMode::kTVMFFIBacktraceUpdateModeAppend; use tvm_ffi_sys::{ @@ -64,6 +65,12 @@ pub struct Error { data: ObjectArc, } +impl From for Error { + fn from(value: Infallible) -> Self { + match value {} + } +} + /// Default result that uses Error as the error type pub type Result = std::result::Result; diff --git a/rust/tvm-ffi/src/function_internal.rs b/rust/tvm-ffi/src/function_internal.rs index 8bf4505e4..2db0e266b 100644 --- a/rust/tvm-ffi/src/function_internal.rs +++ b/rust/tvm-ffi/src/function_internal.rs @@ -166,6 +166,17 @@ where } } +impl IntoArgHolder for T +where + T: crate::object_wrapper::ObjectWrapper, +{ + type Target = T; + fn into_arg_holder(self) -> Self::Target { + self + } +} + + // helper trait to implement IntoArgHolderTuple to apply into_arg_holder to each element pub trait IntoArgHolderTuple { type Target; diff --git a/rust/tvm-ffi/src/object_wrapper.rs b/rust/tvm-ffi/src/object_wrapper.rs index 7b9c81fea..bd4c3ad2e 100644 --- a/rust/tvm-ffi/src/object_wrapper.rs +++ b/rust/tvm-ffi/src/object_wrapper.rs @@ -17,10 +17,13 @@ * under the License. */ +use crate::any::Any; use crate::object::{Object, ObjectArc, ObjectRef, ObjectRefCore}; use crate::type_traits::AnyCompatible; +use std::marker::PhantomData; use tvm_ffi_sys::{ - TVMFFIAny, TVMFFIByteArray, TVMFFIGetTypeInfo, TVMFFIObject, TVMFFITypeKeyToIndex, + TVMFFIAny, TVMFFIByteArray, TVMFFIFieldGetter, TVMFFIGetTypeInfo, TVMFFIObject, + TVMFFITypeKeyToIndex, }; /// Runtime support for stubgen-generated object wrappers. @@ -34,6 +37,129 @@ pub trait ObjectWrapper: Clone { fn into_object_ref(self) -> ObjectRef; } +struct FieldGetterInner { + offset: usize, + getter: TVMFFIFieldGetter, +} + +impl FieldGetterInner { + fn get_any(&self, obj: &ObjectRef) -> crate::Result { + unsafe { + let arc = ::data(obj); + let raw = ObjectArc::as_raw(arc) as *mut TVMFFIObject; + if raw.is_null() { + crate::bail!(crate::error::ATTRIBUTE_ERROR, "Null object for field access"); + } + let field_ptr = (raw as *mut u8).add(self.offset) as *mut std::ffi::c_void; + let mut out = TVMFFIAny::new(); + crate::check_safe_call!((self.getter)(field_ptr, &mut out))?; + Ok(Any::from_raw_ffi_any(out)) + } + } +} + +pub struct FieldGetter { + inner: FieldGetterInner, + _marker: PhantomData, +} + +// FieldGetter stores only reflection metadata, not values of T. +// It is safe to share across threads regardless of T's Send/Sync. +unsafe impl Send for FieldGetter {} +unsafe impl Sync for FieldGetter {} + +impl FieldGetter { + pub fn new(type_key: &'static str, field_name: &'static str) -> crate::Result { + let inner = resolve_field_by_type_key(type_key, field_name)?; + Ok(Self { + inner, + _marker: PhantomData, + }) + } + + pub fn get_any(&self, obj: &ObjectRef) -> crate::Result { + self.inner.get_any(obj) + } +} + +impl FieldGetter +where + T: TryFrom, + T::Error: Into, +{ + pub fn get(&self, obj: &ObjectRef) -> crate::Result { + self.inner.get_any(obj)?.try_into().map_err(Into::into) + } +} + +fn resolve_field_by_type_key( + type_key: &'static str, + field_name: &'static str, +) -> crate::Result { + unsafe { + let key = TVMFFIByteArray::from_str(type_key); + let mut type_index = 0i32; + crate::check_safe_call!(TVMFFITypeKeyToIndex(&key, &mut type_index))?; + resolve_field_by_type_index(type_index, field_name) + } +} + +fn resolve_field_by_type_index( + type_index: i32, + field_name: &'static str, +) -> crate::Result { + unsafe { + let info = TVMFFIGetTypeInfo(type_index); + if info.is_null() { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Type info missing for field {}", + field_name + ); + } + let info = &*info; + if info.fields.is_null() || info.num_fields <= 0 { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Type {} has no fields", + info.type_key.as_str() + ); + } + let fields = std::slice::from_raw_parts(info.fields, info.num_fields as usize); + for field in fields { + if field.name.as_str() != field_name { + continue; + } + let getter = match field.getter { + Some(getter) => getter, + None => { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Field {} has no getter", + field_name + ); + } + }; + if field.offset < 0 { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Field {} has invalid offset", + field_name + ); + } + return Ok(FieldGetterInner { + offset: field.offset as usize, + getter, + }); + } + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Field {} not found", + field_name + ); + } +} + fn type_index_for_key(type_key: &'static str) -> Option { let key = unsafe { TVMFFIByteArray::from_str(type_key) }; let mut index = 0i32; From 4a7b28eecfc3c9b1430661d907c5fa68eb0f6f3e Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Mon, 9 Feb 2026 21:38:20 +0800 Subject: [PATCH 13/29] Refine stubgen layout and warnings --- rust/tvm-ffi-stubgen/src/generate.rs | 95 ++++++++++++++++++++++++---- rust/tvm-ffi-stubgen/src/lib.rs | 9 +-- rust/tvm-ffi-stubgen/src/utils.rs | 22 ------- 3 files changed, 88 insertions(+), 38 deletions(-) diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 27af14080..a487ab633 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -33,9 +33,9 @@ pub(crate) fn build_type_map(type_keys: &[String], prefix: &str) -> BTreeMap String { +pub(crate) fn render_lib_rs(functions_root: &ModuleNode, types_root: &ModuleNode) -> String { let mut out = String::new(); out.push_str( - r#"pub mod functions; -pub mod types; - -pub use functions::*; -pub use types::*; + r#"pub mod _tvm_ffi_stubgen_detail { + pub mod functions; + pub mod types; +} +"#, + ); + render_facade_module(&mut out, Some(functions_root), Some(types_root), &[], 0, true); + out.push_str( + r#" pub fn load_library(path: &str) -> tvm_ffi::Result { tvm_ffi::Module::load_from_file(path) } @@ -428,7 +432,10 @@ fn main() { pub(crate) fn render_functions_rs(root: &ModuleNode) -> String { let mut out = String::new(); out.push_str( - r#"use std::sync::LazyLock; + r#"#![allow(unused_imports)] +#![allow(non_snake_case, nonstandard_style)] + +use std::sync::LazyLock; use tvm_ffi::{Any, AnyView, Function, Result}; "#, @@ -440,7 +447,10 @@ use tvm_ffi::{Any, AnyView, Function, Result}; pub(crate) fn render_types_rs(root: &ModuleNode) -> String { let mut out = String::new(); out.push_str( - r#"use std::sync::LazyLock; + r#"#![allow(unused_imports)] +#![allow(non_snake_case, nonstandard_style)] + +use std::sync::LazyLock; use tvm_ffi::{Any, AnyView, Result}; "#, @@ -449,6 +459,69 @@ use tvm_ffi::{Any, AnyView, Result}; out } +fn render_facade_module( + out: &mut String, + functions: Option<&ModuleNode>, + types: Option<&ModuleNode>, + path: &[String], + indent: usize, + is_root: bool, +) { + let indent_str = " ".repeat(indent); + if !is_root { + let name = path.last().expect("module path missing"); + writeln!(out, "{}pub mod {} {{", indent_str, name).ok(); + } + + let current_indent = if is_root { indent_str.clone() } else { " ".repeat(indent + 4) }; + let module_path = if path.is_empty() { + String::new() + } else { + format!("::{}", path.join("::")) + }; + + if let Some(node) = functions { + for func in &node.functions { + writeln!( + out, + "{}pub use crate::_tvm_ffi_stubgen_detail::functions{}::{};", + current_indent, module_path, func.rust_name + ) + .ok(); + } + } + if let Some(node) = types { + for ty in &node.types { + writeln!( + out, + "{}pub use crate::_tvm_ffi_stubgen_detail::types{}::{};", + current_indent, module_path, ty.rust_name + ) + .ok(); + } + } + + let mut child_names = std::collections::BTreeSet::new(); + if let Some(node) = functions { + child_names.extend(node.children.keys().cloned()); + } + if let Some(node) = types { + child_names.extend(node.children.keys().cloned()); + } + + for child in child_names { + let mut child_path = path.to_vec(); + child_path.push(child.clone()); + let func_child = functions.and_then(|node| node.children.get(&child)); + let type_child = types.and_then(|node| node.children.get(&child)); + render_facade_module(out, func_child, type_child, &child_path, indent + 4, false); + } + + if !is_root { + writeln!(out, "{}}}", indent_str).ok(); + } +} + fn render_function_module(out: &mut String, node: &ModuleNode, indent: usize) { let indent_str = " ".repeat(indent); if indent > 0 { @@ -498,7 +571,6 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { ) .ok(); if func.sig.packed { - writeln!(out, "{}#[allow(non_snake_case)]", indent_str).ok(); writeln!( out, "{}pub fn {}(args: &[Any]) -> Result {{", @@ -518,7 +590,6 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { return; } let args = render_args(&func.sig.args); - writeln!(out, "{}#[allow(non_snake_case)]", indent_str).ok(); writeln!( out, "{}pub fn {}({}) -> Result<{}> {{", diff --git a/rust/tvm-ffi-stubgen/src/lib.rs b/rust/tvm-ffi-stubgen/src/lib.rs index 553754379..0131c427a 100644 --- a/rust/tvm-ffi-stubgen/src/lib.rs +++ b/rust/tvm-ffi-stubgen/src/lib.rs @@ -74,18 +74,19 @@ pub fn run(args: Args) -> Result<(), Box> { let types_root = generate::build_type_modules(types, &prefix); let cargo_toml = generate::render_cargo_toml(&args, &type_map)?; - let lib_rs = generate::render_lib_rs(); + let lib_rs = generate::render_lib_rs(&functions_root, &types_root); let functions_rs = generate::render_functions_rs(&functions_root); let types_rs = generate::render_types_rs(&types_root); let build_rs = generate::render_build_rs(); let src_dir = args.out_dir.join("src"); - std::fs::create_dir_all(&src_dir)?; + let detail_dir = src_dir.join("_tvm_ffi_stubgen_detail"); + std::fs::create_dir_all(&detail_dir)?; std::fs::write(args.out_dir.join("Cargo.toml"), cargo_toml)?; std::fs::write(args.out_dir.join("build.rs"), build_rs)?; std::fs::write(src_dir.join("lib.rs"), lib_rs)?; - std::fs::write(src_dir.join("functions.rs"), functions_rs)?; - std::fs::write(src_dir.join("types.rs"), types_rs)?; + std::fs::write(detail_dir.join("functions.rs"), functions_rs)?; + std::fs::write(detail_dir.join("types.rs"), types_rs)?; Ok(()) } diff --git a/rust/tvm-ffi-stubgen/src/utils.rs b/rust/tvm-ffi-stubgen/src/utils.rs index a7f559c9e..d1ae369e3 100644 --- a/rust/tvm-ffi-stubgen/src/utils.rs +++ b/rust/tvm-ffi-stubgen/src/utils.rs @@ -59,25 +59,3 @@ pub(crate) fn default_tvm_ffi_path() -> Result PathBuf { - let from = from.canonicalize().unwrap_or_else(|_| from.to_path_buf()); - let to = to.canonicalize().unwrap_or_else(|_| to.to_path_buf()); - let from_components: Vec<_> = from.components().collect(); - let to_components: Vec<_> = to.components().collect(); - let mut i = 0; - while i < from_components.len() - && i < to_components.len() - && from_components[i] == to_components[i] - { - i += 1; - } - let mut out = PathBuf::new(); - for _ in i..from_components.len() { - out.push(".."); - } - for comp in &to_components[i..] { - out.push(comp.as_os_str()); - } - out -} From 0724c46268a9ba3649928f54c2a58ac39aa2acb3 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Fri, 13 Feb 2026 22:16:51 +0800 Subject: [PATCH 14/29] fix(stubgen): update test to match _tvm_ffi_stubgen_detail layout The stubgen test was expecting functions.rs at src/functions.rs, but the generator actually writes it to src/_tvm_ffi_stubgen_detail/functions.rs. This commit updates the test to check the correct path. Also includes formatting fixes from cargo fmt. Co-authored-by: Cursor --- rust/tvm-ffi-stubgen/src/generate.rs | 15 +++++++++++++-- rust/tvm-ffi-stubgen/tests/stubgen.rs | 5 ++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index a487ab633..77eff566d 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -368,7 +368,14 @@ pub(crate) fn render_lib_rs(functions_root: &ModuleNode, types_root: &ModuleNode "#, ); - render_facade_module(&mut out, Some(functions_root), Some(types_root), &[], 0, true); + render_facade_module( + &mut out, + Some(functions_root), + Some(types_root), + &[], + 0, + true, + ); out.push_str( r#" pub fn load_library(path: &str) -> tvm_ffi::Result { @@ -473,7 +480,11 @@ fn render_facade_module( writeln!(out, "{}pub mod {} {{", indent_str, name).ok(); } - let current_indent = if is_root { indent_str.clone() } else { " ".repeat(indent + 4) }; + let current_indent = if is_root { + indent_str.clone() + } else { + " ".repeat(indent + 4) + }; let module_path = if path.is_empty() { String::new() } else { diff --git a/rust/tvm-ffi-stubgen/tests/stubgen.rs b/rust/tvm-ffi-stubgen/tests/stubgen.rs index 924085d2e..218ef71fd 100644 --- a/rust/tvm-ffi-stubgen/tests/stubgen.rs +++ b/rust/tvm-ffi-stubgen/tests/stubgen.rs @@ -47,7 +47,10 @@ fn stubgen_tvm_ffi_testing() { run(args).expect("stubgen run"); let cargo_toml = out_dir.join("Cargo.toml"); - let functions_rs = out_dir.join("src").join("functions.rs"); + let functions_rs = out_dir + .join("src") + .join("_tvm_ffi_stubgen_detail") + .join("functions.rs"); assert!(cargo_toml.exists(), "Cargo.toml not generated"); assert!(functions_rs.exists(), "functions.rs not generated"); From 98251f545883f59da4e366fbec3a52a27f616557 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sat, 14 Feb 2026 21:34:41 +0800 Subject: [PATCH 15/29] feat(rust): implement repr(C) zero-cost stubgen with safe Deref-based subtyping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现 repr(C) 零开销 Rust FFI wrapper 生成,使用 Deref trait 建模继承, 生成 100% safe 的用户代码。 ## 核心变更 ### tvm-ffi 库增强 - **subtyping.rs (新文件)**: 子类型转换辅助函数 - `upcast`: 安全的消费型向上转换 - `try_downcast`: 安全的消费型向下转换(使用 runtime type check) - `is_instance_of`: 继承关系运行时检查 - **macros.rs**: 新增 `impl_object_hierarchy!` 宏 - 自动生成 Deref/From/TryFrom impls 用于继承链 - 零开销引用 upcast(通过 Deref auto-coercion) - 安全的消费型转换(使用标准库 trait) - **object_wrapper.rs**: 移除 `is_instance_type`(移至 subtyping.rs) ### tvm-ffi-macros 修复 - **derive(ObjectRef)**: 修复为委托给 ObjectRef 的 AnyCompatible 实现, 消除对 `pub(crate) unsafe_::inc_ref` 的直接调用 - **derive(Object)**: 将 `proc_macro_error::abort!` 替换为 `panic!`, 避免在外部 crate 编译时失败 ### tvm-ffi-stubgen 重构 - **repr_c.rs (新文件)**: C 兼容性检查逻辑 - `check_repr_c`: 验证类型是否可用 repr(C) 表示 - 根据 field size 映射精确的 Rust POD 类型(i8/i16/i32/i64, f32/f64) - 改进 alignment padding 验证 - **generate.rs**: 全新 repr(C) 代码生成路径 - 生成 `#[repr(C)] *Obj` struct(零开销字段访问) - 生成 `#[repr(C)] *Ref` wrapper(user-facing 类型) - 调用 `impl_object_hierarchy!` 自动生成子类型转换 - 只为直接字段生成 getter(继承字段通过 Deref 自动可用) - 过滤内建类型(ffi.*)避免重复定义 - 对 repr(C) 类型使用 `self as &ObjectRef`(deref coercion) - 对 fallback 类型使用 `self.as_object_ref()` - **model.rs**: 扩展 TypeGen - 新增 `ancestor_chain` 字段存储继承链 - 修改 `call_expr` 使用 `Into` 而非 `as_object_ref().clone()` ## 生成代码特性 ✅ **100% Safe Rust**: 无任何 `unsafe` 关键字出现在生成的 stub 代码中 ✅ **零开销字段访问**: repr(C) 类型直接访问字段无 FFI 调用 ✅ **类型安全子类型转换**: 使用标准库 Deref/From/TryFrom trait ✅ **符合 Rust 惯用法**: 自动 upcast、fallible downcast ## 测试验证 使用 tvm_ffi_testing.so 生成 stub crate,编译通过,无任何 unsafe 代码。 Closes: stubgen repr(C) redesign initiative Co-authored-by: Cursor --- rust/tvm-ffi-macros/src/object_macros.rs | 20 +- rust/tvm-ffi-stubgen/src/generate.rs | 295 +++++++++++++++++++++-- rust/tvm-ffi-stubgen/src/lib.rs | 3 +- rust/tvm-ffi-stubgen/src/model.rs | 25 +- rust/tvm-ffi-stubgen/src/repr_c.rs | 243 +++++++++++++++++++ rust/tvm-ffi/src/any.rs | 1 - rust/tvm-ffi/src/function_internal.rs | 1 - rust/tvm-ffi/src/lib.rs | 1 + rust/tvm-ffi/src/macros.rs | 62 +++++ rust/tvm-ffi/src/object_wrapper.rs | 29 +-- rust/tvm-ffi/src/subtyping.rs | 106 ++++++++ 11 files changed, 731 insertions(+), 55 deletions(-) create mode 100644 rust/tvm-ffi-stubgen/src/repr_c.rs create mode 100644 rust/tvm-ffi/src/subtyping.rs diff --git a/rust/tvm-ffi-macros/src/object_macros.rs b/rust/tvm-ffi-macros/src/object_macros.rs index 01ef8afed..9e6e1ac52 100644 --- a/rust/tvm-ffi-macros/src/object_macros.rs +++ b/rust/tvm-ffi-macros/src/object_macros.rs @@ -58,7 +58,7 @@ pub fn derive_object(input: proc_macro::TokenStream) -> TokenStream { &type_key_arg, &mut tindex ); if ret != 0 { - proc_macro_error::abort!("Failed to get type index for type key: {}", #type_key); + panic!("Failed to get type index for type key: {}", #type_key); } tindex } @@ -266,17 +266,13 @@ pub fn derive_object_ref(input: proc_macro::TokenStream) -> TokenStream { ) -> Self { type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> ::ContainerType; - let data_ptr = data.data_union.v_obj; - // need to increase ref because original weak ptr - // do not own the code - #tvm_ffi_crate::object::unsafe_::inc_ref( - data_ptr as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject - ); - Self { - data : #tvm_ffi_crate::object::ObjectArc::from_raw( - data_ptr as *mut ContainerType - ) - } + // Delegate to ObjectRef to handle reference counting + let obj_ref = <#tvm_ffi_crate::object::ObjectRef as #tvm_ffi_crate::type_traits::AnyCompatible>::copy_from_any_view_after_check(data); + // Use public unsafe API to do pointer cast + let arc = <#tvm_ffi_crate::object::ObjectRef as #tvm_ffi_crate::object::ObjectRefCore>::into_data(obj_ref); + let raw = #tvm_ffi_crate::object::ObjectArc::into_raw(arc); + let typed = #tvm_ffi_crate::object::ObjectArc::from_raw(raw as *const ContainerType); + Self { data: typed } } unsafe fn move_to_any( diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 77eff566d..0c6509650 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -17,7 +17,10 @@ use crate::cli::Args; use crate::ffi; -use crate::model::{FieldGen, FunctionGen, FunctionSig, MethodGen, ModuleNode, RustType, TypeGen}; +use crate::model::{ + FieldGen, FunctionGen, FunctionSig, GetterSpec, MethodGen, ModuleNode, RustType, TypeGen, +}; +use crate::repr_c; use crate::schema::{extract_type_schema, parse_type_schema, TypeSchema}; use crate::utils; use std::collections::BTreeMap; @@ -79,7 +82,10 @@ pub(crate) fn build_type_entries( let rust_name = sanitize_ident(&name, IdentStyle::Type); let mut methods = Vec::new(); let mut fields = Vec::new(); + let mut type_depth = 0i32; + let repr_c_info = repr_c::check_repr_c(key, type_map); if let Some(info) = ffi::get_type_info(key) { + type_depth = info.type_depth; if info.num_methods > 0 && !info.methods.is_null() { let method_slice = unsafe { std::slice::from_raw_parts(info.methods, info.num_methods as usize) }; @@ -136,15 +142,102 @@ pub(crate) fn build_type_entries( mods, TypeGen { type_key: key.clone(), - rust_name, + rust_name: rust_name.clone(), methods, fields, + type_depth, + repr_c_info: repr_c_info.clone(), + getter_specs: Vec::new(), + ancestor_chain: Vec::new(), }, )); } + // Second pass: fill getter_specs and ancestor_chain for repr_c types in dependency order (base before derived). + let mut type_key_to_idx: BTreeMap = BTreeMap::new(); + for (idx, (_, ty)) in out.iter().enumerate() { + type_key_to_idx.insert(ty.type_key.clone(), idx); + } + let mut order: Vec = (0..out.len()).collect(); + order.sort_by_key(|&i| out[i].1.type_depth); + for &idx in &order { + let (_, ref ty) = out[idx]; + let repr_c_info = match &ty.repr_c_info { + Some(r) => r, + None => continue, + }; + let parent_specs: Vec = + if let Some(ref parent_key) = repr_c_info.parent_type_key { + let parent_idx = *type_key_to_idx.get(parent_key).unwrap_or(&idx); + out[parent_idx].1.getter_specs.clone() + } else { + Vec::new() + }; + let getter_specs = build_getter_specs(&ty.type_key, &ty.repr_c_info, &parent_specs); + + // Build ancestor chain: [DirectParent, Grandparent, ..., ObjectRef] + let ancestor_chain = if let Some(ref parent_key) = repr_c_info.parent_type_key { + if parent_key == "ffi.Object" { + vec!["tvm_ffi::object::ObjectRef".to_string()] + } else if let Some(parent_rust) = type_map.get(parent_key) { + let parent_idx = *type_key_to_idx.get(parent_key).unwrap_or(&idx); + let mut chain = vec![parent_rust.clone()]; + // Inherit parent's ancestors + chain.extend(out[parent_idx].1.ancestor_chain.clone()); + chain + } else { + vec!["tvm_ffi::object::ObjectRef".to_string()] + } + } else { + vec!["tvm_ffi::object::ObjectRef".to_string()] + }; + + out[idx].1.getter_specs = getter_specs; + out[idx].1.ancestor_chain = ancestor_chain; + } Ok(out) } +fn build_getter_specs( + _type_key: &str, + repr_c_info: &Option, + parent_specs: &[GetterSpec], +) -> Vec { + let info = match repr_c_info { + Some(i) => i, + None => return Vec::new(), + }; + let mut specs = Vec::new(); + for parent in parent_specs { + let access_expr = if parent.access_expr.starts_with("self.data.") { + format!( + "self.data.parent.{}", + &parent.access_expr["self.data.".len()..] + ) + } else { + parent.access_expr.clone() + }; + specs.push(GetterSpec { + method_name: parent.method_name.clone(), + access_expr, + ret_type: parent.ret_type.clone(), + }); + } + for f in &info.direct_fields { + let method_name = format!("get_{}", f.rust_name); + let access_expr = if f.is_pod { + format!("self.data.{}", f.rust_name) + } else { + format!("self.data.{}.clone()", f.rust_name) + }; + specs.push(GetterSpec { + method_name, + access_expr, + ret_type: f.rust_type.clone(), + }); + } + specs +} + pub(crate) fn build_function_modules( funcs: Vec<(Vec, FunctionGen)>, _prefix: &str, @@ -451,18 +544,18 @@ use tvm_ffi::{Any, AnyView, Function, Result}; out } -pub(crate) fn render_types_rs(root: &ModuleNode) -> String { +pub(crate) fn render_types_rs(root: &ModuleNode, type_map: &BTreeMap) -> String { let mut out = String::new(); out.push_str( r#"#![allow(unused_imports)] #![allow(non_snake_case, nonstandard_style)] use std::sync::LazyLock; -use tvm_ffi::{Any, AnyView, Result}; +use tvm_ffi::{Any, AnyView, ObjectArc, Result}; "#, ); - render_type_module(&mut out, root, 0); + render_type_module(&mut out, root, 0, type_map); out } @@ -503,6 +596,10 @@ fn render_facade_module( } if let Some(node) = types { for ty in &node.types { + // Skip built-in types that are not generated + if is_builtin_type(&ty.type_key) { + continue; + } writeln!( out, "{}pub use crate::_tvm_ffi_stubgen_detail::types{}::{};", @@ -555,7 +652,12 @@ fn render_function_module(out: &mut String, node: &ModuleNode, indent: usize) { } } -fn render_type_module(out: &mut String, node: &ModuleNode, indent: usize) { +fn render_type_module( + out: &mut String, + node: &ModuleNode, + indent: usize, + type_map: &BTreeMap, +) { let indent_str = " ".repeat(indent); if indent > 0 { writeln!(out, "{}use std::sync::LazyLock;", indent_str).ok(); @@ -563,11 +665,11 @@ fn render_type_module(out: &mut String, node: &ModuleNode, indent: usize) { writeln!(out).ok(); } for ty in &node.types { - render_type(out, ty, indent); + render_type(out, ty, indent, type_map); } for child in node.children.values() { writeln!(out, "{}pub mod {} {{", indent_str, child.name).ok(); - render_type_module(out, child, indent + 4); + render_type_module(out, child, indent + 4, type_map); writeln!(out, "{}}}", indent_str).ok(); } } @@ -630,7 +732,159 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { writeln!(out).ok(); } -fn render_type(out: &mut String, ty: &TypeGen, indent: usize) { +fn type_key_to_short_rust_name(type_map: &BTreeMap, type_key: &str) -> String { + type_map + .get(type_key) + .and_then(|path| path.split("::").last().map(String::from)) + .unwrap_or_else(|| type_key.to_string()) +} + +fn render_type(out: &mut String, ty: &TypeGen, indent: usize, type_map: &BTreeMap) { + // Filter out built-in types that are already provided by tvm-ffi + if is_builtin_type(&ty.type_key) { + return; + } + + let _indent_str = " ".repeat(indent); + if let Some(ref info) = ty.repr_c_info { + render_repr_c_type(out, ty, info, indent, type_map); + return; + } + render_fallback_type(out, ty, indent); +} + +fn is_builtin_type(type_key: &str) -> bool { + // Filter ffi.* primitive types and aliases that are provided by tvm-ffi + matches!( + type_key, + "ffi.Object" + | "ffi.String" + | "ffi.Function" + | "ffi.Module" + | "ffi.Tensor" + | "ffi.Shape" + | "ffi.Array" + | "ffi.Map" + | "ffi.Bytes" + | "ffi.SmallStr" + | "ffi.SmallBytes" + | "DLTensor*" + | "DataType" + | "Device" + | "bool" + | "int" + | "float" + ) +} + +fn render_repr_c_type( + out: &mut String, + ty: &TypeGen, + info: &repr_c::ReprCInfo, + indent: usize, + _type_map: &BTreeMap, +) { + let indent_str = " ".repeat(indent); + let obj_name = format!("{}Obj", ty.rust_name); + + // Determine parent type for *Obj struct + let parent_ty = match &info.parent_type_key { + None => "tvm_ffi::object::Object".to_string(), + Some(parent_key) if parent_key == "ffi.Object" => "tvm_ffi::object::Object".to_string(), + Some(parent_key) => { + // Use the type from type_map to get the full Rust path + let parent_rust = _type_map + .get(parent_key) + .map(|s| s.clone()) + .unwrap_or_else(|| format!("{}Obj", sanitize_ident(parent_key, IdentStyle::Type))); + // Extract just the type name and append "Obj" + if let Some(last) = parent_rust.split("::").last() { + format!("{}Obj", last) + } else { + format!("{}Obj", sanitize_ident(parent_key, IdentStyle::Type)) + } + } + }; + + // Generate *Obj struct with #[repr(C)] + writeln!(out, "{}#[repr(C)]", indent_str).ok(); + writeln!(out, "{}#[derive(tvm_ffi::derive::Object)]", indent_str).ok(); + writeln!(out, "{}#[type_key = \"{}\"]", indent_str, ty.type_key).ok(); + writeln!(out, "{}pub struct {} {{", indent_str, obj_name).ok(); + writeln!(out, "{} parent: {},", indent_str, parent_ty).ok(); + for f in &info.direct_fields { + writeln!(out, "{} {}: {},", indent_str, f.rust_name, f.rust_type).ok(); + } + writeln!(out, "{}}}\n", indent_str).ok(); + + // Generate *Ref wrapper with #[repr(C)] + writeln!(out, "{}#[repr(C)]", indent_str).ok(); + writeln!( + out, + "{}#[derive(tvm_ffi::derive::ObjectRef, Clone)]", + indent_str + ) + .ok(); + writeln!(out, "{}pub struct {} {{", indent_str, ty.rust_name).ok(); + writeln!( + out, + "{} data: tvm_ffi::object::ObjectArc<{}>,", + indent_str, obj_name + ) + .ok(); + writeln!(out, "{}}}\n", indent_str).ok(); + + // Generate impl_object_hierarchy! macro call + if !ty.ancestor_chain.is_empty() { + write!( + out, + "{}tvm_ffi::impl_object_hierarchy!({}:", + indent_str, ty.rust_name + ) + .ok(); + for (i, ancestor) in ty.ancestor_chain.iter().enumerate() { + if i == 0 { + write!(out, " {}", ancestor).ok(); + } else { + write!(out, ", {}", ancestor).ok(); + } + } + writeln!(out, ");").ok(); + writeln!(out).ok(); + } + + // Generate getter methods for direct fields only + writeln!(out, "{}impl {} {{", indent_str, ty.rust_name).ok(); + for f in &info.direct_fields { + let method_name = format!("get_{}", f.rust_name); + let access_expr = if f.is_pod { + format!("self.data.{}", f.rust_name) + } else { + format!("self.data.{}.clone()", f.rust_name) + }; + writeln!( + out, + "{} pub fn {}(&self) -> {} {{", + indent_str, method_name, f.rust_type + ) + .ok(); + writeln!(out, "{} {}", indent_str, access_expr).ok(); + writeln!(out, "{} }}", indent_str).ok(); + } + writeln!(out, "{}}}\n", indent_str).ok(); + + // Generate method statics and impls + for method in &ty.methods { + render_method_static(out, ty, method, indent); + } + writeln!(out, "{}impl {} {{", indent_str, ty.rust_name).ok(); + for method in &ty.methods { + render_method(out, ty, method, indent + 4); + } + writeln!(out, "{}}}\n", indent_str).ok(); +} + +fn render_fallback_type(out: &mut String, ty: &TypeGen, indent: usize) { let indent_str = " ".repeat(indent); writeln!( out, @@ -741,12 +995,23 @@ fn render_method(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usi indent_str ) .ok(); - writeln!( - out, - "{} views.push(AnyView::from(self.as_object_ref()));", - indent_str - ) - .ok(); + // For repr(C) types, use deref coercion to upcast to ObjectRef + // For ObjectWrapper types, use the as_object_ref() method + if ty.repr_c_info.is_some() { + writeln!( + out, + "{} views.push(AnyView::from(self as &tvm_ffi::object::ObjectRef));", + indent_str + ) + .ok(); + } else { + writeln!( + out, + "{} views.push(AnyView::from(self.as_object_ref()));", + indent_str + ) + .ok(); + } writeln!( out, "{} views.extend(args.iter().map(AnyView::from));", diff --git a/rust/tvm-ffi-stubgen/src/lib.rs b/rust/tvm-ffi-stubgen/src/lib.rs index 0131c427a..faa9c18cd 100644 --- a/rust/tvm-ffi-stubgen/src/lib.rs +++ b/rust/tvm-ffi-stubgen/src/lib.rs @@ -19,6 +19,7 @@ mod cli; mod ffi; mod generate; mod model; +mod repr_c; mod schema; mod utils; @@ -76,7 +77,7 @@ pub fn run(args: Args) -> Result<(), Box> { let cargo_toml = generate::render_cargo_toml(&args, &type_map)?; let lib_rs = generate::render_lib_rs(&functions_root, &types_root); let functions_rs = generate::render_functions_rs(&functions_root); - let types_rs = generate::render_types_rs(&types_root); + let types_rs = generate::render_types_rs(&types_root, &type_map); let build_rs = generate::render_build_rs(); let src_dir = args.out_dir.join("src"); diff --git a/rust/tvm-ffi-stubgen/src/model.rs b/rust/tvm-ffi-stubgen/src/model.rs index f2c441e5c..994a116cd 100644 --- a/rust/tvm-ffi-stubgen/src/model.rs +++ b/rust/tvm-ffi-stubgen/src/model.rs @@ -59,12 +59,32 @@ pub(crate) struct FieldGen { pub(crate) ty: RustType, } +/// Spec for a single get_* method on a repr(C) Ref type. +#[derive(Debug, Clone)] +pub(crate) struct GetterSpec { + /// Method name, e.g. "get_first" + pub(crate) method_name: String, + /// Expression to produce the value, e.g. "self.data.first.clone()" + pub(crate) access_expr: String, + /// Return type, e.g. "Shape" or "i64" + pub(crate) ret_type: String, +} + #[derive(Debug, Clone)] pub(crate) struct TypeGen { pub(crate) type_key: String, pub(crate) rust_name: String, pub(crate) methods: Vec, pub(crate) fields: Vec, + /// Depth in inheritance hierarchy (0 = Object, 1 = direct subclass of Object, ...). + pub(crate) type_depth: i32, + /// If Some, type passes check_repr_c and we generate repr(C) *Obj + *Ref. + pub(crate) repr_c_info: Option, + /// Getter specs for repr_c types (get_* methods). Empty for non-repr_c. + pub(crate) getter_specs: Vec, + /// Ancestor chain for repr_c types: [DirectParent, Grandparent, ..., ObjectRef]. + /// Empty for non-repr_c or types without proper hierarchy info. + pub(crate) ancestor_chain: Vec, } #[derive(Debug, Default)] @@ -136,7 +156,10 @@ impl RustType { pub(crate) fn call_expr(&self, arg_name: &str) -> String { match self.kind { RustTypeKind::Plain => arg_name.to_string(), - RustTypeKind::ObjectWrapper => format!("{}.as_object_ref().clone()", arg_name), + RustTypeKind::ObjectWrapper => { + // Use Into trait for upcast + format!("{}.into()", arg_name) + } } } diff --git a/rust/tvm-ffi-stubgen/src/repr_c.rs b/rust/tvm-ffi-stubgen/src/repr_c.rs new file mode 100644 index 000000000..381de02b2 --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/repr_c.rs @@ -0,0 +1,243 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Validation that a type has a compact C-compatible layout (check_repr_c) +//! and extraction of field layout for repr(C) code generation. + +use crate::ffi; +use crate::schema::{extract_type_schema, parse_type_schema, TypeSchema}; +use std::collections::BTreeMap; + +/// Result of check_repr_c: type passes and we have full layout for codegen. +#[derive(Debug, Clone)] +pub(crate) struct ReprCInfo { + /// Type key of the immediate parent (Object or a subclass). None for root types. + pub(crate) parent_type_key: Option, + /// Total size of the struct in bytes. + pub(crate) total_size: i32, + /// Direct fields of this type only (not inherited), sorted by offset. + /// For codegen: first field of *Obj is parent (or Object), then these. + pub(crate) direct_fields: Vec, +} + +#[derive(Debug, Clone)] +pub(crate) struct ReprCField { + pub(crate) name: String, + pub(crate) rust_name: String, + pub(crate) offset: i64, + pub(crate) size: i64, + pub(crate) alignment: i64, + /// Rust type name for the field (e.g. "i64", "Shape"). + pub(crate) rust_type: String, + /// True if Copy type (getter returns value); false if Ref (getter returns Ref and clone). + pub(crate) is_pod: bool, +} + +/// Returns ReprCInfo if the type passes check_repr_c; None otherwise. +pub(crate) fn check_repr_c( + type_key: &str, + type_map: &BTreeMap, +) -> Option { + let info = ffi::get_type_info(type_key)?; + let total_size = total_size_from_info(info)?; + if total_size <= 0 { + return None; + } + + let parent_type_key = if info.type_depth > 0 && !info.type_acenstors.is_null() { + let ancestor_ptr = unsafe { *info.type_acenstors.add(0) }; + if ancestor_ptr.is_null() { + return None; + } + let parent_info = unsafe { &*ancestor_ptr }; + let key = ffi::byte_array_to_string_opt(&parent_info.type_key)?; + if !check_repr_c(&key, type_map).is_some() { + return None; + } + Some(key) + } else { + None + }; + + let parent_total_size: i64 = if let Some(ref parent_key) = parent_type_key { + let parent_info = ffi::get_type_info(parent_key)?; + total_size_from_info(parent_info)? as i64 + } else { + // Root type: first field starts after Object header. Use Object's registered size. + let obj_info = ffi::get_type_info("ffi.Object")?; + total_size_from_info(obj_info)? as i64 + }; + + let mut direct_fields: Vec = Vec::new(); + if info.num_fields > 0 && !info.fields.is_null() { + let field_slice = + unsafe { std::slice::from_raw_parts(info.fields, info.num_fields as usize) }; + for field in field_slice { + let name = ffi::byte_array_to_string_opt(&field.name)?; + if field.offset < 0 || field.size < 0 || field.alignment <= 0 { + return None; + } + let meta = ffi::byte_array_to_string_opt(&field.metadata); + let schema = meta + .as_deref() + .and_then(extract_type_schema) + .and_then(|s| parse_type_schema(&s)); + let (rust_type, is_pod) = + repr_c_field_type(schema.as_ref(), type_map, type_key, field.size)?; + direct_fields.push(ReprCField { + rust_name: sanitize_ident(&name, IdentStyle::Function), + name, + offset: field.offset, + size: field.size, + alignment: field.alignment, + rust_type, + is_pod, + }); + } + } + + direct_fields.sort_by_key(|f| f.offset); + + let first_offset = direct_fields + .first() + .map(|f| f.offset) + .unwrap_or(parent_total_size); + if first_offset != parent_total_size { + return None; + } + + let mut pos = parent_total_size; + for f in &direct_fields { + // Allow alignment padding before fields, but no overlap + if f.offset < pos { + return None; + } + pos = f.offset + f.size; + } + if pos != total_size as i64 { + return None; + } + + Some(ReprCInfo { + parent_type_key, + total_size, + direct_fields, + }) +} + +fn total_size_from_info(info: &tvm_ffi::tvm_ffi_sys::TVMFFITypeInfo) -> Option { + if info.metadata.is_null() { + return None; + } + let meta = unsafe { &*info.metadata }; + if meta.total_size <= 0 { + return None; + } + Some(meta.total_size) +} + +/// Map schema to (rust_type_name, is_pod). None if not repr_c compatible. +fn repr_c_field_type( + schema: Option<&TypeSchema>, + type_map: &BTreeMap, + _self_type_key: &str, + field_size: i64, +) -> Option<(String, bool)> { + let schema = schema?; + match schema.origin.as_str() { + "bool" => Some(("bool".to_string(), true)), + "int" => match field_size { + 1 => Some(("i8".to_string(), true)), + 2 => Some(("i16".to_string(), true)), + 4 => Some(("i32".to_string(), true)), + 8 => Some(("i64".to_string(), true)), + _ => None, // Unsupported int size + }, + "float" => match field_size { + 4 => Some(("f32".to_string(), true)), + 8 => Some(("f64".to_string(), true)), + _ => None, // Unsupported float size + }, + "Device" => Some(("tvm_ffi::DLDevice".to_string(), true)), + "DataType" => Some(("tvm_ffi::DLDataType".to_string(), true)), + "ffi.String" | "std::string" | "const char*" | "ffi.SmallStr" => { + Some(("tvm_ffi::String".to_string(), false)) + } + "ffi.Bytes" | "ffi.SmallBytes" => Some(("tvm_ffi::Bytes".to_string(), false)), + "ffi.Function" => Some(("tvm_ffi::Function".to_string(), false)), + "ffi.Object" => Some(("tvm_ffi::object::ObjectRef".to_string(), false)), + "ffi.Shape" => Some(("tvm_ffi::Shape".to_string(), false)), + "ffi.Module" => Some(("tvm_ffi::Module".to_string(), false)), + "ffi.Tensor" | "DLTensor*" => Some(("tvm_ffi::Tensor".to_string(), false)), + "Optional" => match schema.args.as_slice() { + [inner] => repr_c_field_type(Some(inner), type_map, _self_type_key, field_size) + .map(|(inner_ty, pod)| (format!("Option<{}>", inner_ty), pod)), + _ => None, + }, + "ffi.Array" => match schema.args.as_slice() { + [inner] => { + let (inner_ty, _) = + repr_c_field_type(Some(inner), type_map, _self_type_key, field_size)?; + Some((format!("tvm_ffi::Array<{}>", inner_ty), false)) + } + _ => None, + }, + "ffi.Map" => match schema.args.as_slice() { + [k, v] => { + let (k_ty, _) = repr_c_field_type(Some(k), type_map, _self_type_key, field_size)?; + let (v_ty, _) = repr_c_field_type(Some(v), type_map, _self_type_key, field_size)?; + Some((format!("tvm_ffi::Map<{}, {}>", k_ty, v_ty), false)) + } + _ => None, + }, + other => type_map.get(other).map(|path| (path.clone(), false)), + } +} + +#[derive(Clone, Copy)] +enum IdentStyle { + Function, +} + +fn sanitize_ident(name: &str, style: IdentStyle) -> String { + let mut out = String::new(); + for ch in name.chars() { + if ch.is_ascii_alphanumeric() || ch == '_' { + out.push(ch); + } else { + out.push('_'); + } + } + if out.is_empty() { + out.push('_'); + } + if out.chars().next().unwrap().is_ascii_digit() { + out.insert(0, '_'); + } + const KEYWORDS: &[&str] = &[ + "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn", + "for", "if", "in", "let", "loop", "match", "move", "mut", "pub", "ref", "return", "self", + "Self", "static", "struct", "super", "trait", "true", "type", "unsafe", "use", "where", + "while", "async", "await", "dyn", + ]; + if KEYWORDS.contains(&out.as_str()) { + out.push('_'); + } + match style { + IdentStyle::Function => out, + } +} diff --git a/rust/tvm-ffi/src/any.rs b/rust/tvm-ffi/src/any.rs index 931398937..c769494a7 100644 --- a/rust/tvm-ffi/src/any.rs +++ b/rust/tvm-ffi/src/any.rs @@ -240,7 +240,6 @@ impl<'a> TryFrom> for AnyValue { } } - impl Default for Any { fn default() -> Self { Self::new() diff --git a/rust/tvm-ffi/src/function_internal.rs b/rust/tvm-ffi/src/function_internal.rs index 2db0e266b..b10b7a89a 100644 --- a/rust/tvm-ffi/src/function_internal.rs +++ b/rust/tvm-ffi/src/function_internal.rs @@ -176,7 +176,6 @@ where } } - // helper trait to implement IntoArgHolderTuple to apply into_arg_holder to each element pub trait IntoArgHolderTuple { type Target; diff --git a/rust/tvm-ffi/src/lib.rs b/rust/tvm-ffi/src/lib.rs index 13baa91a4..2482a3a04 100644 --- a/rust/tvm-ffi/src/lib.rs +++ b/rust/tvm-ffi/src/lib.rs @@ -29,6 +29,7 @@ pub mod macros; pub mod object; pub mod object_wrapper; pub mod string; +pub mod subtyping; pub mod type_traits; pub use tvm_ffi_sys; diff --git a/rust/tvm-ffi/src/macros.rs b/rust/tvm-ffi/src/macros.rs index 4199ffa79..37390ddea 100644 --- a/rust/tvm-ffi/src/macros.rs +++ b/rust/tvm-ffi/src/macros.rs @@ -290,6 +290,68 @@ macro_rules! define_object_wrapper { }; } +/// Implement object hierarchy relationships (Deref, From, TryFrom). +/// +/// This macro is intended for code emitted by the Rust stub generator to +/// establish parent-child relationships in the object hierarchy. +/// +/// # Syntax +/// ```ignore +/// impl_object_hierarchy!(Self: DirectParent, Grandparent, ..., ObjectRef); +/// ``` +/// +/// # Generated implementations +/// - `Deref` for ergonomic field access +/// - `From for DirectParent` (and all ancestors) for upcasts +/// - `TryFrom for Self` for downcasts +/// +/// # Example +/// ```ignore +/// // Given: Node -> BaseExpr -> Expr -> ObjectRef +/// impl_object_hierarchy!(Node: BaseExpr, Expr, ObjectRef); +/// ``` +#[macro_export] +macro_rules! impl_object_hierarchy { + ($self_ty:ty: $direct_parent:ty $(, $ancestor:ty)* $(,)?) => { + // Implement Deref to the direct parent for ergonomic access + impl std::ops::Deref for $self_ty { + type Target = $direct_parent; + + fn deref(&self) -> &Self::Target { + // Safety: All ObjectRef types are repr(C) with a single pointer field (ObjectArc). + // Self and DirectParent have identical memory layout. + // This is a zero-cost, lifetime-preserving reference cast. + unsafe { &*(self as *const $self_ty as *const $direct_parent) } + } + } + + // Implement From for DirectParent (upcast) + impl From<$self_ty> for $direct_parent { + fn from(value: $self_ty) -> Self { + $crate::subtyping::upcast(value) + } + } + + // Implement From for each ancestor (transitive upcast) + $( + impl From<$self_ty> for $ancestor { + fn from(value: $self_ty) -> Self { + $crate::subtyping::upcast(value) + } + } + + // Implement TryFrom for Self (downcast) + impl TryFrom<$ancestor> for $self_ty { + type Error = $ancestor; + + fn try_from(value: $ancestor) -> Result { + $crate::subtyping::try_downcast(value) + } + } + )* + }; +} + // ---------------------------------------------------------------------------- // Macros for function definitions // ---------------------------------------------------------------------------- diff --git a/rust/tvm-ffi/src/object_wrapper.rs b/rust/tvm-ffi/src/object_wrapper.rs index bd4c3ad2e..ad2ea3086 100644 --- a/rust/tvm-ffi/src/object_wrapper.rs +++ b/rust/tvm-ffi/src/object_wrapper.rs @@ -48,7 +48,10 @@ impl FieldGetterInner { let arc = ::data(obj); let raw = ObjectArc::as_raw(arc) as *mut TVMFFIObject; if raw.is_null() { - crate::bail!(crate::error::ATTRIBUTE_ERROR, "Null object for field access"); + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Null object for field access" + ); } let field_ptr = (raw as *mut u8).add(self.offset) as *mut std::ffi::c_void; let mut out = TVMFFIAny::new(); @@ -171,28 +174,6 @@ fn type_index_for_key(type_key: &'static str) -> Option { } } -unsafe fn is_instance_type(type_index: i32, target_index: i32) -> bool { - if type_index == target_index { - return true; - } - let info = TVMFFIGetTypeInfo(type_index); - if info.is_null() { - return false; - } - let info = &*info; - let ancestors = info.type_acenstors; - if ancestors.is_null() { - return false; - } - for depth in 0..info.type_depth { - let ancestor = *ancestors.add(depth as usize); - if !ancestor.is_null() && (*ancestor).type_index == target_index { - return true; - } - } - false -} - unsafe impl AnyCompatible for T { fn type_str() -> String { T::TYPE_KEY.to_string() @@ -220,7 +201,7 @@ unsafe impl AnyCompatible for T { let Some(target_index) = type_index_for_key(T::TYPE_KEY) else { return false; }; - is_instance_type(data.type_index, target_index) + crate::subtyping::is_instance_of(data.type_index, target_index) } unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { diff --git a/rust/tvm-ffi/src/subtyping.rs b/rust/tvm-ffi/src/subtyping.rs new file mode 100644 index 000000000..d1c944790 --- /dev/null +++ b/rust/tvm-ffi/src/subtyping.rs @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! Subtyping infrastructure for object hierarchy conversions. +//! +//! This module provides type-safe upcast and downcast operations for objects +//! that follow the TVM FFI object hierarchy. + +use crate::object::{Object, ObjectArc, ObjectCore, ObjectRefCore}; +use tvm_ffi_sys::TVMFFIGetTypeInfo; + +/// Check if a type_index is an instance of target_index (including inheritance). +/// +/// # Safety +/// This function accesses the type info table via FFI and follows ancestor pointers. +pub unsafe fn is_instance_of(type_index: i32, target_index: i32) -> bool { + if type_index == target_index { + return true; + } + let info = TVMFFIGetTypeInfo(type_index); + if info.is_null() { + return false; + } + let info = &*info; + let ancestors = info.type_acenstors; + if ancestors.is_null() { + return false; + } + for depth in 0..info.type_depth { + let ancestor = *ancestors.add(depth as usize); + if !ancestor.is_null() && (*ancestor).type_index == target_index { + return true; + } + } + false +} + +/// Upcast an object reference from a subtype to a supertype. +/// +/// This is a consuming operation that transfers ownership. +/// +/// # Type Parameters +/// * `From` - The source type (subtype) +/// * `To` - The target type (supertype) +/// +/// # Safety +/// The caller must ensure that `To` is a valid supertype of `From`. +/// This is typically enforced by the `impl_object_hierarchy!` macro. +pub fn upcast(value: From) -> To { + unsafe { + let arc = ::into_data(value); + let raw = ObjectArc::into_raw(arc); + let casted = ObjectArc::from_raw(raw as *const ::ContainerType); + ::from_data(casted) + } +} + +/// Try to downcast an object reference from a supertype to a subtype. +/// +/// This is a consuming operation that transfers ownership on success. +/// +/// # Type Parameters +/// * `From` - The source type (supertype) +/// * `To` - The target type (subtype) +/// +/// # Returns +/// * `Ok(To)` - If the runtime type check succeeds +/// * `Err(From)` - If the runtime type check fails, returns the original value +/// +/// # Safety +/// This function performs runtime type checking using the TVM FFI type system. +pub fn try_downcast(value: From) -> Result { + unsafe { + let arc = ::data(&value); + let raw = ObjectArc::as_raw(arc) as *const Object as *const tvm_ffi_sys::TVMFFIObject; + let type_index = (*raw).type_index; + let target_index = ::ContainerType::type_index(); + + if is_instance_of(type_index, target_index) { + // Type check passed, perform the downcast + let arc = ::into_data(value); + let raw = ObjectArc::into_raw(arc); + let casted = ObjectArc::from_raw(raw as *const ::ContainerType); + Ok(::from_data(casted)) + } else { + // Type check failed, return the original value + Err(value) + } + } +} From c21e63511bfaebae796fd224f4bba96ab8b1d3cb Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sat, 14 Feb 2026 22:25:21 +0800 Subject: [PATCH 16/29] =?UTF-8?q?stubgen=20=E7=AC=AC=E4=BA=8C=E8=BD=AE?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=EF=BC=9A=E5=AD=90=E7=B1=BB=E5=9E=8B=E8=BD=AC?= =?UTF-8?q?=E6=8D=A2=E3=80=81check=5Frepr=5Fc=E3=80=81fallback=20=E4=B8=8E?= =?UTF-8?q?=E5=AF=BC=E5=87=BA=E6=B8=85=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - impl_object_hierarchy!: 补充 TryFrom for Self,覆盖根类型 ObjectRef 下转 - subtyping: upcast/try_downcast 改为 #[doc(hidden)] pub,供宏展开使用,用户统一走 From/TryFrom - check_repr_c: 仅允许对齐 padding (field.offset == align_up(pos, alignment)),并补齐 Any 映射以支持 Map/Array - define_object_wrapper!: 增加 From for ObjectRef,保证 .into() 兼容 - generate: is_builtin_type 增加 None;render_facade_module 跳过空模块,避免空 pub mod ffi Co-authored-by: Cursor --- rust/tvm-ffi-stubgen/src/generate.rs | 28 ++++++++++++++++++++-------- rust/tvm-ffi-stubgen/src/repr_c.rs | 14 ++++++++++++-- rust/tvm-ffi/src/macros.rs | 15 +++++++++++++++ rust/tvm-ffi/src/subtyping.rs | 14 +++++++++----- 4 files changed, 56 insertions(+), 15 deletions(-) diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 0c6509650..e8e2edbb1 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -567,6 +567,25 @@ fn render_facade_module( indent: usize, is_root: bool, ) { + // Check if this module has any actual content + let has_functions = functions.map_or(false, |node| !node.functions.is_empty()); + let has_types = types.map_or(false, |node| { + node.types.iter().any(|ty| !is_builtin_type(&ty.type_key)) + }); + + let mut child_names = std::collections::BTreeSet::new(); + if let Some(node) = functions { + child_names.extend(node.children.keys().cloned()); + } + if let Some(node) = types { + child_names.extend(node.children.keys().cloned()); + } + + // Skip rendering if the module is empty and has no children + if !is_root && !has_functions && !has_types && child_names.is_empty() { + return; + } + let indent_str = " ".repeat(indent); if !is_root { let name = path.last().expect("module path missing"); @@ -609,14 +628,6 @@ fn render_facade_module( } } - let mut child_names = std::collections::BTreeSet::new(); - if let Some(node) = functions { - child_names.extend(node.children.keys().cloned()); - } - if let Some(node) = types { - child_names.extend(node.children.keys().cloned()); - } - for child in child_names { let mut child_path = path.to_vec(); child_path.push(child.clone()); @@ -774,6 +785,7 @@ fn is_builtin_type(type_key: &str) -> bool { | "bool" | "int" | "float" + | "None" ) } diff --git a/rust/tvm-ffi-stubgen/src/repr_c.rs b/rust/tvm-ffi-stubgen/src/repr_c.rs index 381de02b2..66577d3cd 100644 --- a/rust/tvm-ffi-stubgen/src/repr_c.rs +++ b/rust/tvm-ffi-stubgen/src/repr_c.rs @@ -122,8 +122,9 @@ pub(crate) fn check_repr_c( let mut pos = parent_total_size; for f in &direct_fields { - // Allow alignment padding before fields, but no overlap - if f.offset < pos { + // Only allow alignment padding (field must start at aligned position) + let aligned_pos = align_up(pos, f.alignment); + if f.offset != aligned_pos { return None; } pos = f.offset + f.size; @@ -150,6 +151,14 @@ fn total_size_from_info(info: &tvm_ffi::tvm_ffi_sys::TVMFFITypeInfo) -> Option i64 { + if alignment <= 0 { + return value; + } + (value + alignment - 1) / alignment * alignment +} + /// Map schema to (rust_type_name, is_pod). None if not repr_c compatible. fn repr_c_field_type( schema: Option<&TypeSchema>, @@ -159,6 +168,7 @@ fn repr_c_field_type( ) -> Option<(String, bool)> { let schema = schema?; match schema.origin.as_str() { + "Any" => Some(("tvm_ffi::AnyValue".to_string(), false)), "bool" => Some(("bool".to_string(), true)), "int" => match field_size { 1 => Some(("i8".to_string(), true)), diff --git a/rust/tvm-ffi/src/macros.rs b/rust/tvm-ffi/src/macros.rs index 37390ddea..4fc5ef70b 100644 --- a/rust/tvm-ffi/src/macros.rs +++ b/rust/tvm-ffi/src/macros.rs @@ -270,6 +270,12 @@ macro_rules! define_object_wrapper { } } + impl From<$name> for $crate::object::ObjectRef { + fn from(wrapper: $name) -> Self { + wrapper.into_object_ref() + } + } + impl $crate::object_wrapper::ObjectWrapper for $name { const TYPE_KEY: &'static str = $type_key; @@ -332,6 +338,15 @@ macro_rules! impl_object_hierarchy { } } + // Implement TryFrom for Self (downcast) + impl TryFrom<$direct_parent> for $self_ty { + type Error = $direct_parent; + + fn try_from(value: $direct_parent) -> Result { + $crate::subtyping::try_downcast(value) + } + } + // Implement From for each ancestor (transitive upcast) $( impl From<$self_ty> for $ancestor { diff --git a/rust/tvm-ffi/src/subtyping.rs b/rust/tvm-ffi/src/subtyping.rs index d1c944790..18c427030 100644 --- a/rust/tvm-ffi/src/subtyping.rs +++ b/rust/tvm-ffi/src/subtyping.rs @@ -29,6 +29,7 @@ use tvm_ffi_sys::TVMFFIGetTypeInfo; /// /// # Safety /// This function accesses the type info table via FFI and follows ancestor pointers. +#[doc(hidden)] pub unsafe fn is_instance_of(type_index: i32, target_index: i32) -> bool { if type_index == target_index { return true; @@ -59,9 +60,10 @@ pub unsafe fn is_instance_of(type_index: i32, target_index: i32) -> bool { /// * `From` - The source type (subtype) /// * `To` - The target type (supertype) /// -/// # Safety -/// The caller must ensure that `To` is a valid supertype of `From`. -/// This is typically enforced by the `impl_object_hierarchy!` macro. +/// # Internal Implementation Detail +/// This function is public for macro expansion but should not be called directly. +/// Use `From::from()` or `.into()` for upcasting instead. +#[doc(hidden)] pub fn upcast(value: From) -> To { unsafe { let arc = ::into_data(value); @@ -83,8 +85,10 @@ pub fn upcast(value: From) -> To { /// * `Ok(To)` - If the runtime type check succeeds /// * `Err(From)` - If the runtime type check fails, returns the original value /// -/// # Safety -/// This function performs runtime type checking using the TVM FFI type system. +/// # Internal Implementation Detail +/// This function is public for macro expansion but should not be called directly. +/// Use `TryFrom::try_from()` or `.try_into()` for downcasting instead. +#[doc(hidden)] pub fn try_downcast(value: From) -> Result { unsafe { let arc = ::data(&value); From 228ec250393302c70767a0b28deaf495c381fd7a Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sat, 14 Feb 2026 22:45:38 +0800 Subject: [PATCH 17/29] fix(stubgen): use type_ancestors[type_depth-1] for direct parent check_repr_c previously used ancestor[0] as parent, which for multi-level inheritance returns the root type (ffi.Object). This caused parent_total_size=24 while direct field offset=56, falsely rejected as layout gap. Use ancestor[type_depth-1] for direct parent so TestObjectDerived passes repr(C). Co-authored-by: Cursor --- rust/tvm-ffi-stubgen/src/repr_c.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rust/tvm-ffi-stubgen/src/repr_c.rs b/rust/tvm-ffi-stubgen/src/repr_c.rs index 66577d3cd..9ffeea36d 100644 --- a/rust/tvm-ffi-stubgen/src/repr_c.rs +++ b/rust/tvm-ffi-stubgen/src/repr_c.rs @@ -59,7 +59,8 @@ pub(crate) fn check_repr_c( } let parent_type_key = if info.type_depth > 0 && !info.type_acenstors.is_null() { - let ancestor_ptr = unsafe { *info.type_acenstors.add(0) }; + // Direct parent is ancestor[type_depth - 1]; ancestor[0] is the root. + let ancestor_ptr = unsafe { *info.type_acenstors.add((info.type_depth - 1) as usize) }; if ancestor_ptr.is_null() { return None; } @@ -168,7 +169,7 @@ fn repr_c_field_type( ) -> Option<(String, bool)> { let schema = schema?; match schema.origin.as_str() { - "Any" => Some(("tvm_ffi::AnyValue".to_string(), false)), + "Any" | "ffi.Any" => Some(("tvm_ffi::AnyValue".to_string(), false)), "bool" => Some(("bool".to_string(), true)), "int" => match field_size { 1 => Some(("i8".to_string(), true)), From 6c49df75665beae7160354875cf35f222c5af52b Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sat, 14 Feb 2026 23:42:15 +0800 Subject: [PATCH 18/29] docs(stubgen): split user guide and harden runnable examples Move the Rust stubgen usage guide into docs/packaging so it is grouped with packaging workflows, and keep the crate README focused on generated interface/design details. Update the stubgen integration test to validate the documented usage flow with raw-string test source and stable APIs. Co-authored-by: Cursor --- docs/index.rst | 1 + docs/packaging/rust_stubgen.md | 100 +++++++++++++++ rust/tvm-ffi-stubgen/README.md | 174 ++++++++++++++++++++++++++ rust/tvm-ffi-stubgen/tests/stubgen.rs | 23 +++- 4 files changed, 296 insertions(+), 2 deletions(-) create mode 100644 docs/packaging/rust_stubgen.md create mode 100644 rust/tvm-ffi-stubgen/README.md diff --git a/docs/index.rst b/docs/index.rst index 62768c74e..79da2e877 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -71,6 +71,7 @@ Table of Contents packaging/python_packaging.rst packaging/stubgen.rst + packaging/rust_stubgen.md packaging/cpp_tooling.rst .. toctree:: diff --git a/docs/packaging/rust_stubgen.md b/docs/packaging/rust_stubgen.md new file mode 100644 index 000000000..a7b0e3787 --- /dev/null +++ b/docs/packaging/rust_stubgen.md @@ -0,0 +1,100 @@ + + + + + + + + + + + + + + + + + +# Rust Stubgen Guide + +```{note} +The Rust stub generation flow is currently experimental and may evolve. +``` + +This guide covers practical usage of `tvm-ffi-stubgen`: generation command, output crate, and how to call generated APIs. + +## Generate a Stub Crate + +Run from `3rdparty/tvm/3rdparty/tvm-ffi/rust`: + +```bash +cargo run -p tvm-ffi-stubgen -- \ + --init-prefix testing \ + --init-crate tvm-ffi-testing \ + --dlls /abs/path/to/libtvm_ffi_testing.so \ + --overwrite +``` + +### Arguments + +- `OUT_DIR`: positional output directory +- `--dlls`: one or more dynamic libraries for reflection metadata +- `--init-prefix`: registry prefix filter (functions/types to include) +- `--init-crate`: generated crate name +- `--tvm-ffi-path`: optional local path override for `tvm-ffi` +- `--overwrite`: overwrite non-empty output directory + +## Generated Output Layout + +The output is a standalone Rust crate: + +- `Cargo.toml` +- `src/lib.rs` +- `src/_tvm_ffi_stubgen_detail/functions.rs` +- `src/_tvm_ffi_stubgen_detail/types.rs` + +`src/lib.rs` re-exports generated wrappers and provides: + +```rust +pub fn load_library(path: &str) -> tvm_ffi::Result +``` + +## Using Generated Crate + +Using the generated stubs is straightforward—simply load the runtime library, call exported functions, and work with generated object wrappers and subtyping as needed. The full process is shown in the following example, covering typical usage: + +```rust +use tvm_ffi_testing as stub; + +fn main() -> tvm_ffi::Result<()> { + // Load FFI library (required before any calls) + stub::load_library("/abs/path/to/libtvm_ffi_testing.so")?; + + // Call a generated function with typed arguments + let y = stub::add_one(1)?; + assert_eq!(y, 2); + + // Call a function via packed interface for dynamic signature + let _out = stub::echo(&[tvm_ffi::Any::from(1_i64)])?; + + // Use object-returning wrappers and ObjectRef-based APIs + let obj = stub::make_unregistered_object()?; + let count = stub::object_use_count(obj.clone())?; + assert!(count >= 1); + + // Fallback wrapper can be built from ObjectRef directly + let _wrapped: stub::TestUnregisteredObject = obj.into(); + + Ok(()) +} +``` + +- Load the library once before using the APIs. +- Generated functions support typed signatures when possible and fall back to `Any` for dynamic calling. +- Generated object-returning wrappers integrate with `ObjectRef` APIs and wrapper conversions. + + +## Related Docs + +- Rust language guide: `guides/rust_lang_guide.md` +- Rust stubgen design details (implementation-oriented): `rust/tvm-ffi-stubgen/README.md` diff --git a/rust/tvm-ffi-stubgen/README.md b/rust/tvm-ffi-stubgen/README.md new file mode 100644 index 000000000..c73c78e27 --- /dev/null +++ b/rust/tvm-ffi-stubgen/README.md @@ -0,0 +1,174 @@ +# Rust Stubgen Guide + +`tvm-ffi-stubgen` generates Rust stubs from TVM-FFI reflection metadata. +This document is design-oriented and focuses on generated interface forms and implementation choices. + +## Table of Contents + +- [Document Scope](#document-scope) +- [Generated Interface Forms](#generated-interface-forms) +- [Object Model and Inheritance](#object-model-and-inheritance) +- [Field Accessor Style](#field-accessor-style) +- [Subtyping and Cast Rules](#subtyping-and-cast-rules) +- [repr(C) Decision Rules](#reprc-decision-rules) +- [Safety and Fallback Strategy](#safety-and-fallback-strategy) +- [Related User Guide](#related-user-guide) + +## Document Scope + +This README intentionally does not duplicate full command-line tutorial content. +For command usage and end-to-end calling examples, see: + +- `docs/packaging/rust_stubgen.md` + +## Generated Interface Forms + +Stubgen emits a public facade (`src/lib.rs`) plus detail modules: + +- `src/_tvm_ffi_stubgen_detail/functions.rs` +- `src/_tvm_ffi_stubgen_detail/types.rs` + +### Function Wrappers + +#### Typed wrapper path + +When type schema is fully known, function wrappers are generated as typed Rust APIs: + +```rust +pub fn add_one(_0: i64) -> Result { ... } +``` + +#### Packed fallback path + +When schema is not fully resolved, wrappers use packed calling style: + +```rust +pub fn echo(args: &[Any]) -> Result { ... } +``` + +### Type Wrappers + +#### repr(C) path (preferred) + +For layout-compatible object types: + +- `#[repr(C)] Obj` +- `#[derive(ObjectRef, Clone)] ` +- `impl_object_hierarchy!(...)` +- direct-field `get_` accessors + +Example shape: + +```rust +#[repr(C)] +pub struct TestObjectDerivedObj { + parent: TestObjectBaseObj, + v_map: tvm_ffi::Map, + v_array: tvm_ffi::Array, +} +``` + +#### fallback wrapper path + +For non-repr(C)-compatible types: + +- `define_object_wrapper!(Type, "type.key")` +- field access via `FieldGetter` + +## Object Model and Inheritance + +repr(C) object inheritance is modeled by composition and deref chain: + +### Obj-level layout + +Derived object stores parent object as first field: + +```rust +#[repr(C)] +pub struct DerivedObj { + parent: BaseObj, + extra: i64, +} +``` + +### Ref-level inheritance + +Ref wrappers use `impl_object_hierarchy!` to establish: + +- `Deref Base>` +- `From for Base/ObjectRef` (upcast) +- `TryFrom for Derived` (downcast) + +## Field Accessor Style + +Getter generation follows a single style: + +- name prefix is always `get_` +- only direct fields of current type generate getters +- inherited getters are available via deref auto-coercion + +### Return type rules + +- POD field -> return by value +- object/container field -> clone and return user-facing type + +Example: + +```rust +impl TestObjectDerived { + pub fn get_v_map(&self) -> tvm_ffi::Map { + self.data.v_map.clone() + } +} +``` + +## Subtyping and Cast Rules + +Stubgen-generated repr(C) refs use standard Rust traits as the only user-facing cast API: + +- borrow upcast: `Deref` +- consuming upcast: `From` / `.into()` +- consuming downcast: `TryFrom` / `.try_into()` + +This avoids custom cast traits and keeps compile-time type constraints explicit. + +## repr(C) Decision Rules + +`check_repr_c` gates repr(C) generation. + +### Required metadata checks + +- `total_size > 0` +- valid field `offset/size/alignment` +- field order and no overlap +- aligned placement (`field.offset == align_up(pos, alignment)`) +- parent boundary matches first direct field offset +- parent type is also repr(C)-compatible + +### Schema mapping rules + +Representative mappings include: + +- `Any` / `ffi.Any` -> `tvm_ffi::AnyValue` +- `ffi.Array` -> `tvm_ffi::Array` +- `ffi.Map` -> `tvm_ffi::Map` + +## Safety and Fallback Strategy + +Generated user-facing code is intended to remain safe Rust. + +### Safety boundary + +- unsafe operations are encapsulated in `tvm-ffi` internals and macros +- generated wrappers and getters are safe APIs + +### Built-in filtering and fallback + +- built-in `ffi.*` primitives are not re-generated as wrapper types +- unsupported/non-layout-compatible object types fall back to `define_object_wrapper!` + +## Related User Guide + +For generation command-line usage and step-by-step invocation examples, see: + +- `docs/packaging/rust_stubgen.md` diff --git a/rust/tvm-ffi-stubgen/tests/stubgen.rs b/rust/tvm-ffi-stubgen/tests/stubgen.rs index 218ef71fd..a009d9238 100644 --- a/rust/tvm-ffi-stubgen/tests/stubgen.rs +++ b/rust/tvm-ffi-stubgen/tests/stubgen.rs @@ -117,8 +117,27 @@ fn write_integration_test( let tests_dir = out_dir.join("tests"); fs::create_dir_all(&tests_dir)?; let test_body = format!( - "use tvm_ffi_testing_stub::add_one;\n\n#[test]\nfn add_one_roundtrip() {{\n let lib_path = \"{}\";\n tvm_ffi::Module::load_from_file(lib_path).expect(\"load tvm_ffi_testing\");\n let value = add_one(1).expect(\"call add_one\");\n assert_eq!(value, 2);\n}}\n", - testing_lib.display() + r#"use tvm_ffi_testing_stub as stub; + +#[test] +fn generated_usage_roundtrip() {{ + let lib_path = "{lib_path}"; + stub::load_library(lib_path).expect("load tvm_ffi_testing"); + + let value = stub::add_one(1).expect("call add_one"); + assert_eq!(value, 2); + + let _out = stub::echo(&[tvm_ffi::Any::from(1_i64)]).expect("call echo"); + + let obj = stub::make_unregistered_object().expect("create unregistered object"); + let count = stub::object_use_count(obj.clone()).expect("query object use count"); + assert!(count >= 1); + + // Fallback wrapper can be constructed from ObjectRef directly. + let _wrapped: stub::TestUnregisteredObject = obj.into(); +}} +"#, + lib_path = testing_lib.display() ); fs::write(tests_dir.join("integration.rs"), test_body)?; Ok(()) From 7a7b2d1cb6fbd8559314d4263958d7fb20a868f5 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 15 Feb 2026 00:29:56 +0800 Subject: [PATCH 19/29] feat(stubgen): resolve object methods via type metadata Switch generated object method wrappers from global registry lookup to type reflection method resolution, and map constructor wrappers to Rust-style `new` while removing legacy `c_ffi_init`. Extend integration tests and docs to validate typed constructors and Cxx inheritance roundtrip behavior. Co-authored-by: Cursor --- docs/packaging/rust_stubgen.md | 20 ++++++++ rust/tvm-ffi-stubgen/README.md | 15 ++++++ rust/tvm-ffi-stubgen/src/generate.rs | 26 ++++++---- rust/tvm-ffi-stubgen/src/model.rs | 2 +- rust/tvm-ffi-stubgen/tests/stubgen.rs | 43 +++++++++++++++++ rust/tvm-ffi/src/object_wrapper.rs | 68 ++++++++++++++++++++++++++- 6 files changed, 163 insertions(+), 11 deletions(-) diff --git a/docs/packaging/rust_stubgen.md b/docs/packaging/rust_stubgen.md index a7b0e3787..032d6bad9 100644 --- a/docs/packaging/rust_stubgen.md +++ b/docs/packaging/rust_stubgen.md @@ -77,6 +77,25 @@ fn main() -> tvm_ffi::Result<()> { // Call a function via packed interface for dynamic signature let _out = stub::echo(&[tvm_ffi::Any::from(1_i64)])?; + // Object constructor/method wrappers are resolved from type metadata. + let pair_obj = stub::TestIntPair::new(3, 4)?; + let pair: stub::TestIntPair = pair_obj + .try_into() + .map_err(|_| tvm_ffi::Error::new(tvm_ffi::TYPE_ERROR, "downcast failed", ""))?; + let sum_any = pair.sum(&[])?; + let sum: i64 = sum_any.try_into()?; + assert_eq!(sum, 7); + + // Cxx inheritance sample: construct derived, view as base, then convert back. + let derived_obj = stub::TestCxxClassDerived::new(11, 7, 3.5, 1.25)?; + let base: stub::TestCxxClassBase = derived_obj.clone().into(); + let base_obj: tvm_ffi::object::ObjectRef = base.clone().into(); + let derived_again: stub::TestCxxClassDerived = base_obj.into(); + assert_eq!(base.v_i64()?, 11); + assert_eq!(base.v_i32()?, 7); + assert!((derived_again.v_f64()? - 3.5).abs() < 1e-9); + assert!((derived_again.v_f32()? - 1.25).abs() < 1e-6); + // Use object-returning wrappers and ObjectRef-based APIs let obj = stub::make_unregistered_object()?; let count = stub::object_use_count(obj.clone())?; @@ -91,6 +110,7 @@ fn main() -> tvm_ffi::Result<()> { - Load the library once before using the APIs. - Generated functions support typed signatures when possible and fall back to `Any` for dynamic calling. +- Generated object method wrappers (including constructor `new`) are resolved via type metadata rather than global function lookup. - Generated object-returning wrappers integrate with `ObjectRef` APIs and wrapper conversions. diff --git a/rust/tvm-ffi-stubgen/README.md b/rust/tvm-ffi-stubgen/README.md index c73c78e27..19dee0231 100644 --- a/rust/tvm-ffi-stubgen/README.md +++ b/rust/tvm-ffi-stubgen/README.md @@ -75,6 +75,21 @@ For non-repr(C)-compatible types: - `define_object_wrapper!(Type, "type.key")` - field access via `FieldGetter` +### Object Method Lookup Path + +Object methods (including `__ffi_init__`) are generated from type reflection metadata, +not from global function registry names: + +- generated code calls `tvm_ffi::object_wrapper::resolve_type_method(type_key, method_name)` +- runtime lookup path is `TVMFFITypeKeyToIndex -> TVMFFIGetTypeInfo -> methods[]` +- the `method` entry is converted from `AnyView` to owned `Any`, then to `ffi.Function` + +Global wrappers under `functions.rs` still use `Function::get_global`, but type methods in +`types.rs` no longer assume `.` is globally registered. + +For constructor-like methods (`__ffi_init__`), stubgen emits `new(...)` directly as the public +Rust API (or `ffi_init` only when a user-defined `new` method already exists). + ## Object Model and Inheritance repr(C) object inheritance is modeled by composition and deref chain: diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index e8e2edbb1..063e2a35e 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -89,12 +89,26 @@ pub(crate) fn build_type_entries( if info.num_methods > 0 && !info.methods.is_null() { let method_slice = unsafe { std::slice::from_raw_parts(info.methods, info.num_methods as usize) }; + let has_user_new = method_slice.iter().any(|method| { + matches!( + ffi::byte_array_to_string_opt(&method.name).as_deref(), + Some("new") + ) + }); for method in method_slice { let method_name = match ffi::byte_array_to_string_opt(&method.name) { Some(name) => name, None => continue, }; - let rust_method_name = map_method_name(&method_name); + let rust_method_name = if method_name == "__ffi_init__" { + if has_user_new { + "ffi_init".to_string() + } else { + "new".to_string() + } + } else { + map_method_name(&method_name) + }; let is_static = (method.flags & METHOD_FLAG_STATIC) != 0; let meta = ffi::byte_array_to_string_opt(&method.metadata); let schema = meta @@ -103,9 +117,8 @@ pub(crate) fn build_type_entries( .and_then(|s| parse_type_schema(&s)); let sig = build_method_sig(schema.as_ref(), type_map, Some(key.as_str()), is_static); - let full_name = format!("{}.{}", key, method_name); methods.push(MethodGen { - full_name, + source_name: method_name, rust_name: rust_method_name, sig, is_static, @@ -965,8 +978,8 @@ fn render_method_static(out: &mut String, ty: &TypeGen, method: &MethodGen, inde let static_name = static_ident("METHOD", &format!("{}::{}", ty.type_key, method.rust_name)); writeln!( out, - "{}static {}: LazyLock = LazyLock::new(|| tvm_ffi::Function::get_global(\"{}\").expect(\"missing method\"));", - indent_str, static_name, method.full_name + "{}static {}: LazyLock = LazyLock::new(|| tvm_ffi::object_wrapper::resolve_type_method(\"{}\", \"{}\").expect(\"missing type method\"));", + indent_str, static_name, ty.type_key, method.source_name ) .ok(); } @@ -1122,9 +1135,6 @@ fn render_method_call_args(method: &MethodGen) -> String { } fn map_method_name(name: &str) -> String { - if name == "__ffi_init__" { - return "c_ffi_init".to_string(); - } sanitize_ident(name, IdentStyle::Function) } diff --git a/rust/tvm-ffi-stubgen/src/model.rs b/rust/tvm-ffi-stubgen/src/model.rs index 994a116cd..5e054bd65 100644 --- a/rust/tvm-ffi-stubgen/src/model.rs +++ b/rust/tvm-ffi-stubgen/src/model.rs @@ -46,7 +46,7 @@ pub(crate) struct FunctionGen { #[derive(Debug, Clone)] pub(crate) struct MethodGen { - pub(crate) full_name: String, + pub(crate) source_name: String, pub(crate) rust_name: String, pub(crate) sig: FunctionSig, pub(crate) is_static: bool, diff --git a/rust/tvm-ffi-stubgen/tests/stubgen.rs b/rust/tvm-ffi-stubgen/tests/stubgen.rs index a009d9238..b93357b61 100644 --- a/rust/tvm-ffi-stubgen/tests/stubgen.rs +++ b/rust/tvm-ffi-stubgen/tests/stubgen.rs @@ -51,11 +51,33 @@ fn stubgen_tvm_ffi_testing() { .join("src") .join("_tvm_ffi_stubgen_detail") .join("functions.rs"); + let types_rs = out_dir + .join("src") + .join("_tvm_ffi_stubgen_detail") + .join("types.rs"); assert!(cargo_toml.exists(), "Cargo.toml not generated"); assert!(functions_rs.exists(), "functions.rs not generated"); + assert!(types_rs.exists(), "types.rs not generated"); let functions_body = fs::read_to_string(functions_rs).expect("read functions.rs"); + let types_body = fs::read_to_string(types_rs).expect("read types.rs"); assert!(functions_body.contains("add_one"), "missing add_one stub"); + assert!( + types_body.contains("resolve_type_method"), + "type method wrappers should resolve from type metadata" + ); + assert!( + types_body.contains("pub fn new("), + "constructor `new` should be generated when available" + ); + assert!( + !types_body.contains("c_ffi_init"), + "legacy constructor name c_ffi_init should not be generated" + ); + assert!( + !types_body.contains("Function::get_global(\"testing.TestIntPair.__ffi_init__\")"), + "type methods should not use global lookup path" + ); write_integration_test(&out_dir, &testing_lib).expect("write integration test"); run_generated_tests(&out_dir, &lib_dir).expect("run generated tests"); @@ -129,6 +151,27 @@ fn generated_usage_roundtrip() {{ let _out = stub::echo(&[tvm_ffi::Any::from(1_i64)]).expect("call echo"); + // Constructor + instance method should resolve from type metadata. + let pair_obj = stub::TestIntPair::new(3, 4).expect("construct TestIntPair"); + let pair: stub::TestIntPair = pair_obj + .try_into() + .unwrap_or_else(|_| panic!("object -> TestIntPair downcast failed")); + let sum_any = pair.sum(&[]).expect("call TestIntPair.sum"); + let sum: i64 = sum_any.try_into().expect("sum any -> i64"); + assert_eq!(sum, 7); + + // Verify upcast/downcast roundtrip on Cxx inheritance chain. + let derived_obj = stub::TestCxxClassDerived::new(11, 7, 3.5, 1.25) + .expect("construct TestCxxClassDerived"); + let _derived: stub::TestCxxClassDerived = derived_obj.clone().into(); + let base: stub::TestCxxClassBase = derived_obj.clone().into(); + let base_obj: tvm_ffi::object::ObjectRef = base.clone().into(); + let roundtrip: stub::TestCxxClassDerived = base_obj.into(); + assert_eq!(base.v_i64().expect("base.v_i64"), 11); + assert_eq!(base.v_i32().expect("base.v_i32"), 7); + assert!((roundtrip.v_f64().expect("derived.v_f64") - 3.5).abs() < 1e-9); + assert!((roundtrip.v_f32().expect("derived.v_f32") - 1.25).abs() < 1e-6); + let obj = stub::make_unregistered_object().expect("create unregistered object"); let count = stub::object_use_count(obj.clone()).expect("query object use count"); assert!(count >= 1); diff --git a/rust/tvm-ffi/src/object_wrapper.rs b/rust/tvm-ffi/src/object_wrapper.rs index ad2ea3086..4e86b0813 100644 --- a/rust/tvm-ffi/src/object_wrapper.rs +++ b/rust/tvm-ffi/src/object_wrapper.rs @@ -22,8 +22,8 @@ use crate::object::{Object, ObjectArc, ObjectRef, ObjectRefCore}; use crate::type_traits::AnyCompatible; use std::marker::PhantomData; use tvm_ffi_sys::{ - TVMFFIAny, TVMFFIByteArray, TVMFFIFieldGetter, TVMFFIGetTypeInfo, TVMFFIObject, - TVMFFITypeKeyToIndex, + TVMFFIAny, TVMFFIAnyViewToOwnedAny, TVMFFIByteArray, TVMFFIFieldGetter, TVMFFIGetTypeInfo, + TVMFFIObject, TVMFFITypeKeyToIndex, }; /// Runtime support for stubgen-generated object wrappers. @@ -37,6 +37,20 @@ pub trait ObjectWrapper: Clone { fn into_object_ref(self) -> ObjectRef; } +/// Resolve an object type method from runtime reflection metadata. +/// +/// Unlike `Function::get_global`, this lookup walks `TVMFFITypeInfo.methods` +/// for the given type key and converts the method entry to a callable +/// `ffi.Function`. +pub fn resolve_type_method(type_key: &str, method_name: &str) -> crate::Result { + unsafe { + let key = TVMFFIByteArray::from_str(type_key); + let mut type_index = 0i32; + crate::check_safe_call!(TVMFFITypeKeyToIndex(&key, &mut type_index))?; + resolve_type_method_by_type_index(type_index, type_key, method_name) + } +} + struct FieldGetterInner { offset: usize, getter: TVMFFIFieldGetter, @@ -107,6 +121,56 @@ fn resolve_field_by_type_key( } } +fn resolve_type_method_by_type_index( + type_index: i32, + type_key: &str, + method_name: &str, +) -> crate::Result { + unsafe { + let info = TVMFFIGetTypeInfo(type_index); + if info.is_null() { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Type info missing for type {}", + type_key + ); + } + let info = &*info; + if info.methods.is_null() || info.num_methods <= 0 { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Type {} has no methods", + type_key + ); + } + let methods = std::slice::from_raw_parts(info.methods, info.num_methods as usize); + for method in methods { + if method.name.as_str() != method_name { + continue; + } + let mut owned = TVMFFIAny::new(); + crate::check_safe_call!(TVMFFIAnyViewToOwnedAny(&method.method, &mut owned))?; + let method_any = Any::from_raw_ffi_any(owned); + return method_any.try_into().map_err(|_err: crate::Error| { + crate::Error::new( + crate::TYPE_ERROR, + &format!( + "Method {}.{} is not callable as ffi.Function", + type_key, method_name + ), + "", + ) + }); + } + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Method {}.{} not found in reflection metadata", + type_key, + method_name + ); + } +} + fn resolve_field_by_type_index( type_index: i32, field_name: &'static str, From c445259886fd27e33e41c82a722c16097bf9ed4d Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 15 Feb 2026 00:32:41 +0800 Subject: [PATCH 20/29] chore(stubgen): remove dead code and unused repr-c fields Drop unused helper and stale repr-c struct fields to keep the Rust workspace warning-free after the method lookup and constructor naming refactor. Co-authored-by: Cursor --- rust/tvm-ffi-stubgen/src/generate.rs | 7 ------- rust/tvm-ffi-stubgen/src/repr_c.rs | 5 ----- 2 files changed, 12 deletions(-) diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 063e2a35e..5c67716d1 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -756,13 +756,6 @@ fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { writeln!(out).ok(); } -fn type_key_to_short_rust_name(type_map: &BTreeMap, type_key: &str) -> String { - type_map - .get(type_key) - .and_then(|path| path.split("::").last().map(String::from)) - .unwrap_or_else(|| type_key.to_string()) -} - fn render_type(out: &mut String, ty: &TypeGen, indent: usize, type_map: &BTreeMap) { // Filter out built-in types that are already provided by tvm-ffi if is_builtin_type(&ty.type_key) { diff --git a/rust/tvm-ffi-stubgen/src/repr_c.rs b/rust/tvm-ffi-stubgen/src/repr_c.rs index 9ffeea36d..9a771560a 100644 --- a/rust/tvm-ffi-stubgen/src/repr_c.rs +++ b/rust/tvm-ffi-stubgen/src/repr_c.rs @@ -27,8 +27,6 @@ use std::collections::BTreeMap; pub(crate) struct ReprCInfo { /// Type key of the immediate parent (Object or a subclass). None for root types. pub(crate) parent_type_key: Option, - /// Total size of the struct in bytes. - pub(crate) total_size: i32, /// Direct fields of this type only (not inherited), sorted by offset. /// For codegen: first field of *Obj is parent (or Object), then these. pub(crate) direct_fields: Vec, @@ -36,7 +34,6 @@ pub(crate) struct ReprCInfo { #[derive(Debug, Clone)] pub(crate) struct ReprCField { - pub(crate) name: String, pub(crate) rust_name: String, pub(crate) offset: i64, pub(crate) size: i64, @@ -101,7 +98,6 @@ pub(crate) fn check_repr_c( repr_c_field_type(schema.as_ref(), type_map, type_key, field.size)?; direct_fields.push(ReprCField { rust_name: sanitize_ident(&name, IdentStyle::Function), - name, offset: field.offset, size: field.size, alignment: field.alignment, @@ -136,7 +132,6 @@ pub(crate) fn check_repr_c( Some(ReprCInfo { parent_type_key, - total_size, direct_fields, }) } From 8764dcb8b278b6f41f29abb0a81d9a4b10cfd1b8 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Fri, 27 Feb 2026 00:29:06 +0800 Subject: [PATCH 21/29] [stubgen] support multi-prefix generation for single-crate multi-namespace output Allow --init-prefix to be specified multiple times. With a single prefix, behavior is unchanged (prefix stripped, items at crate root). With multiple prefixes, no stripping occurs and each prefix becomes a top-level module, enabling a single generated crate to cover multiple namespaces (e.g. tl, ir, tir, script). Also fixes cross-module repr(C) parent type resolution by using absolute paths into the types module, and adds `impl`/`mod` to the keyword escape list in sanitize_ident. Made-with: Cursor --- docs/packaging/rust_stubgen.md | 21 ++++++++++++++++-- rust/tvm-ffi-stubgen/README.md | 22 +++++++++++++++++++ rust/tvm-ffi-stubgen/src/cli.rs | 4 ++-- rust/tvm-ffi-stubgen/src/generate.rs | 23 +++++++++++--------- rust/tvm-ffi-stubgen/src/lib.rs | 32 +++++++++++++++++++++------- rust/tvm-ffi-stubgen/src/repr_c.rs | 9 +++++--- 6 files changed, 86 insertions(+), 25 deletions(-) diff --git a/docs/packaging/rust_stubgen.md b/docs/packaging/rust_stubgen.md index 032d6bad9..97c6febad 100644 --- a/docs/packaging/rust_stubgen.md +++ b/docs/packaging/rust_stubgen.md @@ -38,12 +38,29 @@ cargo run -p tvm-ffi-stubgen -- \ ### Arguments - `OUT_DIR`: positional output directory -- `--dlls`: one or more dynamic libraries for reflection metadata -- `--init-prefix`: registry prefix filter (functions/types to include) +- `--dlls`: one or more dynamic libraries for reflection metadata (`;`-separated) +- `--init-prefix`: registry prefix filter (repeatable; see multi-prefix below) - `--init-crate`: generated crate name - `--tvm-ffi-path`: optional local path override for `tvm-ffi` - `--overwrite`: overwrite non-empty output directory +### Multi-Prefix Mode + +`--init-prefix` can be specified multiple times to generate a single crate covering +several namespaces: + +```bash +cargo run -p tvm-ffi-stubgen -- \ + --dlls "libtilelang_module.so;libtvm.so" \ + --init-prefix tl --init-prefix ir --init-prefix tir --init-prefix script \ + --init-crate tilelang-ffi \ + --overwrite +``` + +With a single prefix the prefix is stripped and items land at the crate root. +With multiple prefixes no stripping occurs; each prefix becomes a top-level module +(`crate::tl::*`, `crate::ir::*`, etc.). + ## Generated Output Layout The output is a standalone Rust crate: diff --git a/rust/tvm-ffi-stubgen/README.md b/rust/tvm-ffi-stubgen/README.md index 19dee0231..479a8271c 100644 --- a/rust/tvm-ffi-stubgen/README.md +++ b/rust/tvm-ffi-stubgen/README.md @@ -168,6 +168,28 @@ Representative mappings include: - `ffi.Array` -> `tvm_ffi::Array` - `ffi.Map` -> `tvm_ffi::Map` +## Multi-Prefix Generation + +`--init-prefix` accepts multiple values. Behavior depends on the count: + +- **Single prefix** (e.g. `--init-prefix testing`): the prefix is stripped and items land + at the crate root. This is the default backward-compatible mode. +- **Multiple prefixes** (e.g. `--init-prefix tl --init-prefix ir --init-prefix tir`): + no prefix is stripped; each prefix naturally becomes a top-level module. + +Example with multiple prefixes: + +``` +tl.KernelLaunch → crate::tl::KernelLaunch +ir.Span → crate::ir::Span +tir.BufferLoad → crate::tir::BufferLoad +script.ir_builder.* → crate::script::ir_builder::* +``` + +This allows a single generated crate to cover multiple namespaces. Cross-namespace +`repr(C)` inheritance (e.g. `tir.PrimFunc` extending `ir.BaseFunc`) resolves within +the crate without workarounds. + ## Safety and Fallback Strategy Generated user-facing code is intended to remain safe Rust. diff --git a/rust/tvm-ffi-stubgen/src/cli.rs b/rust/tvm-ffi-stubgen/src/cli.rs index 683b507dc..69022af40 100644 --- a/rust/tvm-ffi-stubgen/src/cli.rs +++ b/rust/tvm-ffi-stubgen/src/cli.rs @@ -28,8 +28,8 @@ pub struct Args { pub out_dir: PathBuf, #[arg(long = "dlls", value_delimiter = ';', num_args = 1..)] pub dlls: Vec, - #[arg(long = "init-prefix")] - pub init_prefix: String, + #[arg(long = "init-prefix", num_args = 1..)] + pub init_prefix: Vec, #[arg(long = "init-crate")] pub init_crate: String, #[arg(long = "tvm-ffi-path")] diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 5c67716d1..120333f78 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -685,7 +685,12 @@ fn render_type_module( let indent_str = " ".repeat(indent); if indent > 0 { writeln!(out, "{}use std::sync::LazyLock;", indent_str).ok(); - writeln!(out, "{}use tvm_ffi::{{Any, AnyView, Result}};", indent_str).ok(); + writeln!( + out, + "{}use tvm_ffi::{{Any, AnyView, ObjectArc, Result}};", + indent_str + ) + .ok(); writeln!(out).ok(); } for ty in &node.types { @@ -805,22 +810,20 @@ fn render_repr_c_type( let indent_str = " ".repeat(indent); let obj_name = format!("{}Obj", ty.rust_name); - // Determine parent type for *Obj struct + // Determine parent type for *Obj struct using absolute path into + // the types module so cross-module references resolve correctly. let parent_ty = match &info.parent_type_key { None => "tvm_ffi::object::Object".to_string(), Some(parent_key) if parent_key == "ffi.Object" => "tvm_ffi::object::Object".to_string(), Some(parent_key) => { - // Use the type from type_map to get the full Rust path let parent_rust = _type_map .get(parent_key) .map(|s| s.clone()) .unwrap_or_else(|| format!("{}Obj", sanitize_ident(parent_key, IdentStyle::Type))); - // Extract just the type name and append "Obj" - if let Some(last) = parent_rust.split("::").last() { - format!("{}Obj", last) - } else { - format!("{}Obj", sanitize_ident(parent_key, IdentStyle::Type)) - } + // Convert crate path to absolute path inside types module and append Obj + let types_path = + parent_rust.replacen("crate::", "crate::_tvm_ffi_stubgen_detail::types::", 1); + format!("{}Obj", types_path) } }; @@ -871,7 +874,7 @@ fn render_repr_c_type( writeln!(out).ok(); } - // Generate getter methods for direct fields only + // Generate getter methods for direct fields writeln!(out, "{}impl {} {{", indent_str, ty.rust_name).ok(); for f in &info.direct_fields { let method_name = format!("get_{}", f.rust_name); diff --git a/rust/tvm-ffi-stubgen/src/lib.rs b/rust/tvm-ffi-stubgen/src/lib.rs index faa9c18cd..d0e1ee0ef 100644 --- a/rust/tvm-ffi-stubgen/src/lib.rs +++ b/rust/tvm-ffi-stubgen/src/lib.rs @@ -28,25 +28,40 @@ pub use cli::Args; use std::collections::{BTreeSet, HashSet}; pub fn run(args: Args) -> Result<(), Box> { - let prefix = utils::normalize_prefix(&args.init_prefix); + if args.init_prefix.is_empty() { + return Err("--init-prefix is required".into()); + } if args.dlls.is_empty() { return Err("--dlls is required".into()); } utils::ensure_out_dir(&args.out_dir, args.overwrite)?; + let prefixes: Vec = args + .init_prefix + .iter() + .map(|p| utils::normalize_prefix(p)) + .collect(); + // Single prefix: strip it so items land at crate root (backward compat). + // Multiple prefixes: don't strip; each prefix becomes a top-level module. + let effective_prefix = if prefixes.len() == 1 { + prefixes[0].clone() + } else { + String::new() + }; + let _loaded_libs = ffi::load_dlls(&args.dlls)?; let global_funcs = ffi::list_global_function_names()?; let filtered_funcs: Vec = global_funcs .into_iter() - .filter(|name| name.starts_with(&prefix)) + .filter(|name| prefixes.iter().any(|p| name.starts_with(p))) .collect(); let type_keys = ffi::list_registered_type_keys()?; let type_key_set: HashSet = type_keys.iter().cloned().collect(); let mut filtered_types: Vec = type_keys .iter() - .filter(|name| name.starts_with(&prefix)) + .filter(|name| prefixes.iter().any(|p| name.starts_with(p))) .cloned() .collect(); @@ -67,12 +82,13 @@ pub fn run(args: Args) -> Result<(), Box> { } } - let type_map = generate::build_type_map(&filtered_types, &prefix); - let functions = generate::build_function_entries(&filtered_funcs, &type_map, &prefix)?; - let types = generate::build_type_entries(&filtered_types, &type_map, &prefix)?; + let type_map = generate::build_type_map(&filtered_types, &effective_prefix); + let functions = + generate::build_function_entries(&filtered_funcs, &type_map, &effective_prefix)?; + let types = generate::build_type_entries(&filtered_types, &type_map, &effective_prefix)?; - let functions_root = generate::build_function_modules(functions, &prefix); - let types_root = generate::build_type_modules(types, &prefix); + let functions_root = generate::build_function_modules(functions, &effective_prefix); + let types_root = generate::build_type_modules(types, &effective_prefix); let cargo_toml = generate::render_cargo_toml(&args, &type_map)?; let lib_rs = generate::render_lib_rs(&functions_root, &types_root); diff --git a/rust/tvm-ffi-stubgen/src/repr_c.rs b/rust/tvm-ffi-stubgen/src/repr_c.rs index 9a771560a..feb6c05e0 100644 --- a/rust/tvm-ffi-stubgen/src/repr_c.rs +++ b/rust/tvm-ffi-stubgen/src/repr_c.rs @@ -63,6 +63,9 @@ pub(crate) fn check_repr_c( } let parent_info = unsafe { &*ancestor_ptr }; let key = ffi::byte_array_to_string_opt(&parent_info.type_key)?; + if key != "ffi.Object" && !type_map.contains_key(&key) { + return None; + } if !check_repr_c(&key, type_map).is_some() { return None; } @@ -236,9 +239,9 @@ fn sanitize_ident(name: &str, style: IdentStyle) -> String { } const KEYWORDS: &[&str] = &[ "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn", - "for", "if", "in", "let", "loop", "match", "move", "mut", "pub", "ref", "return", "self", - "Self", "static", "struct", "super", "trait", "true", "type", "unsafe", "use", "where", - "while", "async", "await", "dyn", + "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref", + "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe", + "use", "where", "while", "async", "await", "dyn", ]; if KEYWORDS.contains(&out.as_str()) { out.push('_'); From 4660f2c6b26f0fe2f6df1aed68deb72c03e1aa45 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 8 Mar 2026 14:11:50 +0800 Subject: [PATCH 22/29] feat(stubgen): format generated crates by default Run cargo fmt on generated Rust stub crates by default and add a --no-format escape hatch for raw generator debugging. Also relax into_typed_fn! so rustfmt-produced typed signatures remain compilable. Made-with: Cursor --- docs/packaging/rust_stubgen.md | 5 +++++ rust/tvm-ffi-stubgen/README.md | 3 +++ rust/tvm-ffi-stubgen/src/cli.rs | 2 ++ rust/tvm-ffi-stubgen/src/lib.rs | 26 ++++++++++++++++++++++++++ rust/tvm-ffi-stubgen/tests/stubgen.rs | 3 ++- rust/tvm-ffi/src/macros.rs | 24 ++++++++++++------------ 6 files changed, 50 insertions(+), 13 deletions(-) diff --git a/docs/packaging/rust_stubgen.md b/docs/packaging/rust_stubgen.md index 97c6febad..039488e9d 100644 --- a/docs/packaging/rust_stubgen.md +++ b/docs/packaging/rust_stubgen.md @@ -43,6 +43,11 @@ cargo run -p tvm-ffi-stubgen -- \ - `--init-crate`: generated crate name - `--tvm-ffi-path`: optional local path override for `tvm-ffi` - `--overwrite`: overwrite non-empty output directory +- `--no-format`: skip the post-generation `cargo fmt` pass + +By default, `tvm-ffi-stubgen` runs `cargo fmt` on the generated crate after emitting +`Cargo.toml`, `build.rs`, and Rust sources. Use `--no-format` only when you need to inspect +the raw generated text before formatting or when debugging generator output itself. ### Multi-Prefix Mode diff --git a/rust/tvm-ffi-stubgen/README.md b/rust/tvm-ffi-stubgen/README.md index 479a8271c..f3cdf2f1c 100644 --- a/rust/tvm-ffi-stubgen/README.md +++ b/rust/tvm-ffi-stubgen/README.md @@ -28,6 +28,9 @@ Stubgen emits a public facade (`src/lib.rs`) plus detail modules: - `src/_tvm_ffi_stubgen_detail/functions.rs` - `src/_tvm_ffi_stubgen_detail/types.rs` +By default the generator runs `cargo fmt` on the emitted crate after writing these files. +Pass `--no-format` to keep the raw generated text when debugging formatting-sensitive output. + ### Function Wrappers #### Typed wrapper path diff --git a/rust/tvm-ffi-stubgen/src/cli.rs b/rust/tvm-ffi-stubgen/src/cli.rs index 69022af40..0184ee1cf 100644 --- a/rust/tvm-ffi-stubgen/src/cli.rs +++ b/rust/tvm-ffi-stubgen/src/cli.rs @@ -36,4 +36,6 @@ pub struct Args { pub tvm_ffi_path: Option, #[arg(long = "overwrite")] pub overwrite: bool, + #[arg(long = "no-format")] + pub no_format: bool, } diff --git a/rust/tvm-ffi-stubgen/src/lib.rs b/rust/tvm-ffi-stubgen/src/lib.rs index d0e1ee0ef..445c45a38 100644 --- a/rust/tvm-ffi-stubgen/src/lib.rs +++ b/rust/tvm-ffi-stubgen/src/lib.rs @@ -26,6 +26,29 @@ mod utils; use crate::schema::{collect_type_keys, extract_type_schema, parse_type_schema}; pub use cli::Args; use std::collections::{BTreeSet, HashSet}; +use std::process::Command; + +fn format_generated_crate(out_dir: &std::path::Path) -> Result<(), Box> { + let manifest_path = out_dir.join("Cargo.toml"); + let output = Command::new("cargo") + .arg("fmt") + .arg("--manifest-path") + .arg(&manifest_path) + .current_dir(out_dir) + .output()?; + if output.status.success() { + return Ok(()); + } + let stderr = String::from_utf8_lossy(&output.stderr); + let stdout = String::from_utf8_lossy(&output.stdout); + Err(format!( + "cargo fmt failed for generated crate {}.\nstdout:\n{}\nstderr:\n{}", + manifest_path.display(), + stdout.trim(), + stderr.trim() + ) + .into()) +} pub fn run(args: Args) -> Result<(), Box> { if args.init_prefix.is_empty() { @@ -104,6 +127,9 @@ pub fn run(args: Args) -> Result<(), Box> { std::fs::write(src_dir.join("lib.rs"), lib_rs)?; std::fs::write(detail_dir.join("functions.rs"), functions_rs)?; std::fs::write(detail_dir.join("types.rs"), types_rs)?; + if !args.no_format { + format_generated_crate(&args.out_dir)?; + } Ok(()) } diff --git a/rust/tvm-ffi-stubgen/tests/stubgen.rs b/rust/tvm-ffi-stubgen/tests/stubgen.rs index b93357b61..d89159d3b 100644 --- a/rust/tvm-ffi-stubgen/tests/stubgen.rs +++ b/rust/tvm-ffi-stubgen/tests/stubgen.rs @@ -38,10 +38,11 @@ fn stubgen_tvm_ffi_testing() { let args = Args { out_dir: out_dir.clone(), dlls: vec![testing_lib.clone()], - init_prefix: "testing".to_string(), + init_prefix: vec!["testing".to_string()], init_crate: "tvm_ffi_testing_stub".to_string(), tvm_ffi_path: None, overwrite: true, + no_format: false, }; run(args).expect("stubgen run"); diff --git a/rust/tvm-ffi/src/macros.rs b/rust/tvm-ffi/src/macros.rs index 4fc5ef70b..7936f71e7 100644 --- a/rust/tvm-ffi/src/macros.rs +++ b/rust/tvm-ffi/src/macros.rs @@ -451,7 +451,7 @@ macro_rules! into_typed_fn { move || -> $ret_ty { Ok(_f.call_tuple_with_len::<0, _>(())?.try_into()?) } }}; // Case for 1 argument - ($f:expr, $trait:ident($t0:ty) -> $ret_ty:ty) => {{ + ($f:expr, $trait:ident($t0:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0| -> $ret_ty { use $crate::function_internal::IntoArgHolderTuple; @@ -460,7 +460,7 @@ macro_rules! into_typed_fn { } }}; // Case for 2 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty) -> $ret_ty:ty) => {{ + ($f:expr, $trait:ident($t0:ty, $t1:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1| -> $ret_ty { use $crate::function_internal::IntoArgHolderTuple; @@ -469,7 +469,7 @@ macro_rules! into_typed_fn { } }}; // Case for 3 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty) -> $ret_ty:ty) => {{ + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1, a2: $t2| -> $ret_ty { use $crate::function_internal::IntoArgHolderTuple; @@ -478,7 +478,7 @@ macro_rules! into_typed_fn { } }}; // Case for 4 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty) -> $ret_ty:ty) => {{ + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1, a2: $t2, a3: $t3| -> $ret_ty { use $crate::function_internal::IntoArgHolderTuple; @@ -487,7 +487,7 @@ macro_rules! into_typed_fn { } }}; // Case for 5 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty) -> $ret_ty:ty) => {{ + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1, a2: $t2, a3: $t3, a4: $t4| -> $ret_ty { use $crate::function_internal::IntoArgHolderTuple; @@ -496,7 +496,7 @@ macro_rules! into_typed_fn { } }}; // Case for 6 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty) -> $ret_ty:ty) => {{ + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1, a2: $t2, a3: $t3, a4: $t4, a5: $t5| -> $ret_ty { use $crate::function_internal::IntoArgHolderTuple; @@ -505,7 +505,7 @@ macro_rules! into_typed_fn { } }}; // Case for 7 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty) + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1, a2: $t2, a3: $t3, a4: $t4, a5: $t5, a6: $t6| -> $ret_ty { @@ -515,7 +515,7 @@ macro_rules! into_typed_fn { } }}; // Case for 8 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty) + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1, a2: $t2, a3: $t3, a4: $t4, a5: $t5, a6: $t6, a7: $t7| -> $ret_ty { @@ -525,7 +525,7 @@ macro_rules! into_typed_fn { } }}; // Case for 9 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty) + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, @@ -544,7 +544,7 @@ macro_rules! into_typed_fn { } }}; // Case for 10 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty, $t9:ty) + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty, $t9:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, @@ -564,7 +564,7 @@ macro_rules! into_typed_fn { } }}; // Case for 11 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty, $t9:ty, $t10:ty) + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty, $t9:ty, $t10:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, @@ -585,7 +585,7 @@ macro_rules! into_typed_fn { } }}; // Case for 12 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty, $t9:ty, $t10:ty, $t11:ty) + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty, $t9:ty, $t10:ty, $t11:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, From 04283cedd9bcd93224f52b8a909fa7a4455eaac8 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 8 Mar 2026 23:15:17 +0800 Subject: [PATCH 23/29] feat(stubgen): gap-filling repr_c layout and relaxed parent resolution Rewrite check_repr_c to use a gap-filling strategy instead of strict layout validation. Byte ranges not covered by registered fields are emitted as `[u8; N]` padding members, handling C++ tail padding, vtable pointers, and unregistered fields uniformly. Key changes: - Replace strict contiguous-field checks with gap-filling layout builder - Parent types that are not in type_map or fail check_repr_c no longer cause the child to fall back; the parent region becomes a gap after the Object header - Fields whose type schema cannot be mapped to Rust are skipped and covered by gaps instead of failing the entire type - Handle Optional and ffi.Array with no type args (map to ObjectRef) - Remove alignment inference (infer_alignment_from_size, align_up) - Add log + env_logger for persistent debug/trace diagnostics (RUST_LOG=debug or RUST_LOG=trace) This reduces define_object_wrapper fallbacks from 160 to ~11 (93%) without any changes to the C++ TVM codebase. Made-with: Cursor --- rust/tvm-ffi-stubgen/Cargo.toml | 2 + rust/tvm-ffi-stubgen/src/generate.rs | 17 +- rust/tvm-ffi-stubgen/src/main.rs | 1 + rust/tvm-ffi-stubgen/src/repr_c.rs | 240 ++++++++++++++++++--------- 4 files changed, 181 insertions(+), 79 deletions(-) diff --git a/rust/tvm-ffi-stubgen/Cargo.toml b/rust/tvm-ffi-stubgen/Cargo.toml index 03375be95..fc106d4ef 100644 --- a/rust/tvm-ffi-stubgen/Cargo.toml +++ b/rust/tvm-ffi-stubgen/Cargo.toml @@ -28,7 +28,9 @@ path = "src/main.rs" [dependencies] clap = { version = "4.5", features = ["derive"] } +env_logger = "0.11.9" libloading = "0.8" +log = "0.4.29" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" toml = "0.8" diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 120333f78..f977371d8 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -235,7 +235,7 @@ fn build_getter_specs( ret_type: parent.ret_type.clone(), }); } - for f in &info.direct_fields { + for f in info.fields() { let method_name = format!("get_{}", f.rust_name); let access_expr = if f.is_pod { format!("self.data.{}", f.rust_name) @@ -833,8 +833,15 @@ fn render_repr_c_type( writeln!(out, "{}#[type_key = \"{}\"]", indent_str, ty.type_key).ok(); writeln!(out, "{}pub struct {} {{", indent_str, obj_name).ok(); writeln!(out, "{} parent: {},", indent_str, parent_ty).ok(); - for f in &info.direct_fields { - writeln!(out, "{} {}: {},", indent_str, f.rust_name, f.rust_type).ok(); + for entry in &info.layout { + match entry { + repr_c::LayoutEntry::Field(f) => { + writeln!(out, "{} {}: {},", indent_str, f.rust_name, f.rust_type).ok(); + } + repr_c::LayoutEntry::Gap { name, size } => { + writeln!(out, "{} {}: [u8; {}],", indent_str, name, size).ok(); + } + } } writeln!(out, "{}}}\n", indent_str).ok(); @@ -874,9 +881,9 @@ fn render_repr_c_type( writeln!(out).ok(); } - // Generate getter methods for direct fields + // Generate getter methods for typed fields writeln!(out, "{}impl {} {{", indent_str, ty.rust_name).ok(); - for f in &info.direct_fields { + for f in info.fields() { let method_name = format!("get_{}", f.rust_name); let access_expr = if f.is_pod { format!("self.data.{}", f.rust_name) diff --git a/rust/tvm-ffi-stubgen/src/main.rs b/rust/tvm-ffi-stubgen/src/main.rs index b6568c8a2..cc5b32576 100644 --- a/rust/tvm-ffi-stubgen/src/main.rs +++ b/rust/tvm-ffi-stubgen/src/main.rs @@ -19,6 +19,7 @@ use clap::Parser; use tvm_ffi_stubgen::{run, Args}; fn main() -> Result<(), Box> { + env_logger::init(); let args = Args::parse(); run(args) } diff --git a/rust/tvm-ffi-stubgen/src/repr_c.rs b/rust/tvm-ffi-stubgen/src/repr_c.rs index feb6c05e0..1e6d45cfb 100644 --- a/rust/tvm-ffi-stubgen/src/repr_c.rs +++ b/rust/tvm-ffi-stubgen/src/repr_c.rs @@ -17,9 +17,16 @@ //! Validation that a type has a compact C-compatible layout (check_repr_c) //! and extraction of field layout for repr(C) code generation. +//! +//! The strategy is gap-filling: given the parent struct size and the registered +//! field offsets/sizes, any byte range not covered by a known field is emitted +//! as a `[u8; N]` padding member. This handles C++ tail padding, vtable +//! pointers, and unregistered fields uniformly without requiring alignment +//! inference. use crate::ffi; use crate::schema::{extract_type_schema, parse_type_schema, TypeSchema}; +use log::{debug, trace}; use std::collections::BTreeMap; /// Result of check_repr_c: type passes and we have full layout for codegen. @@ -27,9 +34,17 @@ use std::collections::BTreeMap; pub(crate) struct ReprCInfo { /// Type key of the immediate parent (Object or a subclass). None for root types. pub(crate) parent_type_key: Option, - /// Direct fields of this type only (not inherited), sorted by offset. - /// For codegen: first field of *Obj is parent (or Object), then these. - pub(crate) direct_fields: Vec, + /// Ordered layout entries (fields and gaps) covering [parent_total_size .. total_size). + pub(crate) layout: Vec, +} + +/// A single entry in the repr(C) struct body after the parent. +#[derive(Debug, Clone)] +pub(crate) enum LayoutEntry { + /// A known, typed field. + Field(ReprCField), + /// An opaque gap (padding, vtable pointer, or unregistered field). + Gap { name: String, size: i64 }, } #[derive(Debug, Clone)] @@ -37,59 +52,114 @@ pub(crate) struct ReprCField { pub(crate) rust_name: String, pub(crate) offset: i64, pub(crate) size: i64, - pub(crate) alignment: i64, /// Rust type name for the field (e.g. "i64", "Shape"). pub(crate) rust_type: String, - /// True if Copy type (getter returns value); false if Ref (getter returns Ref and clone). + /// True if Copy type (getter returns value); false if Ref (getter returns clone). pub(crate) is_pod: bool, } -/// Returns ReprCInfo if the type passes check_repr_c; None otherwise. +impl ReprCInfo { + /// Iterate only the typed fields (skipping gaps). + pub(crate) fn fields(&self) -> impl Iterator { + self.layout.iter().filter_map(|e| match e { + LayoutEntry::Field(f) => Some(f), + LayoutEntry::Gap { .. } => None, + }) + } +} + +/// Returns ReprCInfo if the type can be laid out as repr(C); None otherwise. +/// +/// Failure reasons (all logged at DEBUG level): +/// - No type info registered at all +/// - Metadata missing or total_size unknown +/// - Parent type not in type_map or parent itself fails +/// - A field's type schema cannot be mapped to a Rust type pub(crate) fn check_repr_c( type_key: &str, type_map: &BTreeMap, ) -> Option { - let info = ffi::get_type_info(type_key)?; - let total_size = total_size_from_info(info)?; - if total_size <= 0 { - return None; - } - - let parent_type_key = if info.type_depth > 0 && !info.type_acenstors.is_null() { - // Direct parent is ancestor[type_depth - 1]; ancestor[0] is the root. - let ancestor_ptr = unsafe { *info.type_acenstors.add((info.type_depth - 1) as usize) }; - if ancestor_ptr.is_null() { - return None; - } - let parent_info = unsafe { &*ancestor_ptr }; - let key = ffi::byte_array_to_string_opt(&parent_info.type_key)?; - if key != "ffi.Object" && !type_map.contains_key(&key) { + let info = match ffi::get_type_info(type_key) { + Some(i) => i, + None => { + debug!("{}: no type info registered", type_key); return None; } - if !check_repr_c(&key, type_map).is_some() { + }; + let total_size = match total_size_from_info(info) { + Some(s) if s > 0 => s as i64, + _ => { + debug!("{}: metadata missing or total_size <= 0", type_key); return None; } - Some(key) - } else { - None }; + trace!( + "{}: total_size={}, type_depth={}, num_fields={}, num_methods={}", + type_key, total_size, info.type_depth, info.num_fields, info.num_methods + ); - let parent_total_size: i64 = if let Some(ref parent_key) = parent_type_key { - let parent_info = ffi::get_type_info(parent_key)?; - total_size_from_info(parent_info)? as i64 + // Resolve parent. + // If the direct parent is in type_map and passes check_repr_c, we use it as + // the typed parent field. Otherwise we fall back to ffi.Object as the parent + // and let gap-filling cover the bytes between Object and our first field. + let obj_size = { + let oi = ffi::get_type_info("ffi.Object")?; + total_size_from_info(oi)? as i64 + }; + let (parent_type_key, parent_total_size) = if info.type_depth > 0 && !info.type_acenstors.is_null() { + let ancestor_ptr = unsafe { *info.type_acenstors.add((info.type_depth - 1) as usize) }; + let direct_parent_key = if !ancestor_ptr.is_null() { + let pi = unsafe { &*ancestor_ptr }; + ffi::byte_array_to_string_opt(&pi.type_key) + } else { + None + }; + match direct_parent_key { + Some(ref key) if key == "ffi.Object" => { + (None, obj_size) + } + Some(ref key) if type_map.contains_key(key) && check_repr_c(key, type_map).is_some() => { + let pi = ffi::get_type_info(key)?; + let ps = total_size_from_info(pi)? as i64; + trace!("{}: parent='{}' (typed, size={})", type_key, key, ps); + (Some(key.clone()), ps) + } + Some(ref key) => { + // Parent exists but not mappable — use Object as parent, gap covers the rest. + trace!("{}: parent='{}' not mappable, falling back to Object", type_key, key); + (None, obj_size) + } + None => (None, obj_size), + } } else { - // Root type: first field starts after Object header. Use Object's registered size. - let obj_info = ffi::get_type_info("ffi.Object")?; - total_size_from_info(obj_info)? as i64 + (None, obj_size) }; + trace!("{}: parent={:?}, parent_total_size={}", type_key, parent_type_key, parent_total_size); - let mut direct_fields: Vec = Vec::new(); + // Collect and sort fields that belong to this type (offset >= parent_total_size). + let mut typed_fields: Vec = Vec::new(); if info.num_fields > 0 && !info.fields.is_null() { let field_slice = unsafe { std::slice::from_raw_parts(info.fields, info.num_fields as usize) }; for field in field_slice { - let name = ffi::byte_array_to_string_opt(&field.name)?; - if field.offset < 0 || field.size < 0 || field.alignment <= 0 { + let name = match ffi::byte_array_to_string_opt(&field.name) { + Some(n) => n, + None => { + debug!("{}: a field name is unreadable", type_key); + return None; + } + }; + // Skip inherited fields (registered by parent's ObjectDef) + if field.offset < parent_total_size { + trace!("{}: field '{}' at offset={} belongs to parent, skipping", type_key, name, field.offset); + continue; + } + trace!( + "{}: field '{}': offset={}, size={}", + type_key, name, field.offset, field.size + ); + if field.offset < 0 || field.size < 0 { + debug!("{}: field '{}' has invalid offset/size", type_key, name); return None; } let meta = ffi::byte_array_to_string_opt(&field.metadata); @@ -97,45 +167,80 @@ pub(crate) fn check_repr_c( .as_deref() .and_then(extract_type_schema) .and_then(|s| parse_type_schema(&s)); - let (rust_type, is_pod) = - repr_c_field_type(schema.as_ref(), type_map, type_key, field.size)?; - direct_fields.push(ReprCField { - rust_name: sanitize_ident(&name, IdentStyle::Function), + trace!( + "{}: field '{}' schema origin={:?}", + type_key, name, schema.as_ref().map(|s| &s.origin) + ); + let mapped = repr_c_field_type(schema.as_ref(), type_map, type_key, field.size); + let (rust_type, is_pod) = match mapped { + Some(v) => v, + None => { + debug!( + "{}: field '{}' type not mappable, will be covered by gap (schema_origin={:?})", + type_key, name, schema.as_ref().map(|s| &s.origin) + ); + continue; + } + }; + trace!("{}: field '{}' -> rust_type='{}', is_pod={}", type_key, name, rust_type, is_pod); + typed_fields.push(ReprCField { + rust_name: sanitize_ident(&name), offset: field.offset, size: field.size, - alignment: field.alignment, rust_type, is_pod, }); } } + typed_fields.sort_by_key(|f| f.offset); - direct_fields.sort_by_key(|f| f.offset); - - let first_offset = direct_fields - .first() - .map(|f| f.offset) - .unwrap_or(parent_total_size); - if first_offset != parent_total_size { - return None; - } - + // Build layout by walking [parent_total_size .. total_size) and inserting + // gaps wherever there is no registered field. + let mut layout = Vec::new(); let mut pos = parent_total_size; - for f in &direct_fields { - // Only allow alignment padding (field must start at aligned position) - let aligned_pos = align_up(pos, f.alignment); - if f.offset != aligned_pos { + let mut gap_idx = 0usize; + for f in &typed_fields { + if f.offset > pos { + let gap_size = f.offset - pos; + trace!("{}: gap at {}..{} ({} bytes)", type_key, pos, f.offset, gap_size); + layout.push(LayoutEntry::Gap { + name: format!("_gap{}", gap_idx), + size: gap_size, + }); + gap_idx += 1; + pos = f.offset; + } + if f.offset < pos { + // Overlapping fields — shouldn't happen, bail out. + debug!("{}: field '{}' at offset={} overlaps pos={}", type_key, f.rust_name, f.offset, pos); return None; } + layout.push(LayoutEntry::Field(f.clone())); pos = f.offset + f.size; } - if pos != total_size as i64 { + // Trailing gap (tail padding, or fields after last registered one) + if pos < total_size { + let gap_size = total_size - pos; + trace!("{}: trailing gap at {}..{} ({} bytes)", type_key, pos, total_size, gap_size); + layout.push(LayoutEntry::Gap { + name: format!("_gap{}", gap_idx), + size: gap_size, + }); + } else if pos > total_size { + debug!("{}: fields exceed total_size (pos={} > total_size={})", type_key, pos, total_size); return None; } + debug!( + "{}: repr_c OK ({} fields, {} gaps, {} layout entries)", + type_key, + typed_fields.len(), + layout.iter().filter(|e| matches!(e, LayoutEntry::Gap { .. })).count(), + layout.len() + ); Some(ReprCInfo { parent_type_key, - direct_fields, + layout, }) } @@ -150,14 +255,6 @@ fn total_size_from_info(info: &tvm_ffi::tvm_ffi_sys::TVMFFITypeInfo) -> Option i64 { - if alignment <= 0 { - return value; - } - (value + alignment - 1) / alignment * alignment -} - /// Map schema to (rust_type_name, is_pod). None if not repr_c compatible. fn repr_c_field_type( schema: Option<&TypeSchema>, @@ -174,12 +271,12 @@ fn repr_c_field_type( 2 => Some(("i16".to_string(), true)), 4 => Some(("i32".to_string(), true)), 8 => Some(("i64".to_string(), true)), - _ => None, // Unsupported int size + _ => None, }, "float" => match field_size { 4 => Some(("f32".to_string(), true)), 8 => Some(("f64".to_string(), true)), - _ => None, // Unsupported float size + _ => None, }, "Device" => Some(("tvm_ffi::DLDevice".to_string(), true)), "DataType" => Some(("tvm_ffi::DLDataType".to_string(), true)), @@ -195,6 +292,7 @@ fn repr_c_field_type( "Optional" => match schema.args.as_slice() { [inner] => repr_c_field_type(Some(inner), type_map, _self_type_key, field_size) .map(|(inner_ty, pod)| (format!("Option<{}>", inner_ty), pod)), + [] => Some(("Option".to_string(), false)), _ => None, }, "ffi.Array" => match schema.args.as_slice() { @@ -203,6 +301,7 @@ fn repr_c_field_type( repr_c_field_type(Some(inner), type_map, _self_type_key, field_size)?; Some((format!("tvm_ffi::Array<{}>", inner_ty), false)) } + [] => Some(("tvm_ffi::Array".to_string(), false)), _ => None, }, "ffi.Map" => match schema.args.as_slice() { @@ -217,12 +316,7 @@ fn repr_c_field_type( } } -#[derive(Clone, Copy)] -enum IdentStyle { - Function, -} - -fn sanitize_ident(name: &str, style: IdentStyle) -> String { +fn sanitize_ident(name: &str) -> String { let mut out = String::new(); for ch in name.chars() { if ch.is_ascii_alphanumeric() || ch == '_' { @@ -246,7 +340,5 @@ fn sanitize_ident(name: &str, style: IdentStyle) -> String { if KEYWORDS.contains(&out.as_str()) { out.push('_'); } - match style { - IdentStyle::Function => out, - } + out } From 614a4cd3775236306a9d34ab0ce017f620c00733 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 8 Mar 2026 23:18:41 +0800 Subject: [PATCH 24/29] style(stubgen): cargo fmt repr_c.rs Made-with: Cursor --- rust/tvm-ffi-stubgen/src/repr_c.rs | 130 ++++++++++++++++++++--------- 1 file changed, 92 insertions(+), 38 deletions(-) diff --git a/rust/tvm-ffi-stubgen/src/repr_c.rs b/rust/tvm-ffi-stubgen/src/repr_c.rs index 1e6d45cfb..e4f084015 100644 --- a/rust/tvm-ffi-stubgen/src/repr_c.rs +++ b/rust/tvm-ffi-stubgen/src/repr_c.rs @@ -95,7 +95,11 @@ pub(crate) fn check_repr_c( }; trace!( "{}: total_size={}, type_depth={}, num_fields={}, num_methods={}", - type_key, total_size, info.type_depth, info.num_fields, info.num_methods + type_key, + total_size, + info.type_depth, + info.num_fields, + info.num_methods ); // Resolve parent. @@ -106,35 +110,45 @@ pub(crate) fn check_repr_c( let oi = ffi::get_type_info("ffi.Object")?; total_size_from_info(oi)? as i64 }; - let (parent_type_key, parent_total_size) = if info.type_depth > 0 && !info.type_acenstors.is_null() { - let ancestor_ptr = unsafe { *info.type_acenstors.add((info.type_depth - 1) as usize) }; - let direct_parent_key = if !ancestor_ptr.is_null() { - let pi = unsafe { &*ancestor_ptr }; - ffi::byte_array_to_string_opt(&pi.type_key) + let (parent_type_key, parent_total_size) = + if info.type_depth > 0 && !info.type_acenstors.is_null() { + let ancestor_ptr = unsafe { *info.type_acenstors.add((info.type_depth - 1) as usize) }; + let direct_parent_key = if !ancestor_ptr.is_null() { + let pi = unsafe { &*ancestor_ptr }; + ffi::byte_array_to_string_opt(&pi.type_key) + } else { + None + }; + match direct_parent_key { + Some(ref key) if key == "ffi.Object" => (None, obj_size), + Some(ref key) + if type_map.contains_key(key) && check_repr_c(key, type_map).is_some() => + { + let pi = ffi::get_type_info(key)?; + let ps = total_size_from_info(pi)? as i64; + trace!("{}: parent='{}' (typed, size={})", type_key, key, ps); + (Some(key.clone()), ps) + } + Some(ref key) => { + // Parent exists but not mappable — use Object as parent, gap covers the rest. + trace!( + "{}: parent='{}' not mappable, falling back to Object", + type_key, + key + ); + (None, obj_size) + } + None => (None, obj_size), + } } else { - None + (None, obj_size) }; - match direct_parent_key { - Some(ref key) if key == "ffi.Object" => { - (None, obj_size) - } - Some(ref key) if type_map.contains_key(key) && check_repr_c(key, type_map).is_some() => { - let pi = ffi::get_type_info(key)?; - let ps = total_size_from_info(pi)? as i64; - trace!("{}: parent='{}' (typed, size={})", type_key, key, ps); - (Some(key.clone()), ps) - } - Some(ref key) => { - // Parent exists but not mappable — use Object as parent, gap covers the rest. - trace!("{}: parent='{}' not mappable, falling back to Object", type_key, key); - (None, obj_size) - } - None => (None, obj_size), - } - } else { - (None, obj_size) - }; - trace!("{}: parent={:?}, parent_total_size={}", type_key, parent_type_key, parent_total_size); + trace!( + "{}: parent={:?}, parent_total_size={}", + type_key, + parent_type_key, + parent_total_size + ); // Collect and sort fields that belong to this type (offset >= parent_total_size). let mut typed_fields: Vec = Vec::new(); @@ -151,12 +165,20 @@ pub(crate) fn check_repr_c( }; // Skip inherited fields (registered by parent's ObjectDef) if field.offset < parent_total_size { - trace!("{}: field '{}' at offset={} belongs to parent, skipping", type_key, name, field.offset); + trace!( + "{}: field '{}' at offset={} belongs to parent, skipping", + type_key, + name, + field.offset + ); continue; } trace!( "{}: field '{}': offset={}, size={}", - type_key, name, field.offset, field.size + type_key, + name, + field.offset, + field.size ); if field.offset < 0 || field.size < 0 { debug!("{}: field '{}' has invalid offset/size", type_key, name); @@ -169,7 +191,9 @@ pub(crate) fn check_repr_c( .and_then(|s| parse_type_schema(&s)); trace!( "{}: field '{}' schema origin={:?}", - type_key, name, schema.as_ref().map(|s| &s.origin) + type_key, + name, + schema.as_ref().map(|s| &s.origin) ); let mapped = repr_c_field_type(schema.as_ref(), type_map, type_key, field.size); let (rust_type, is_pod) = match mapped { @@ -182,7 +206,13 @@ pub(crate) fn check_repr_c( continue; } }; - trace!("{}: field '{}' -> rust_type='{}', is_pod={}", type_key, name, rust_type, is_pod); + trace!( + "{}: field '{}' -> rust_type='{}', is_pod={}", + type_key, + name, + rust_type, + is_pod + ); typed_fields.push(ReprCField { rust_name: sanitize_ident(&name), offset: field.offset, @@ -202,7 +232,13 @@ pub(crate) fn check_repr_c( for f in &typed_fields { if f.offset > pos { let gap_size = f.offset - pos; - trace!("{}: gap at {}..{} ({} bytes)", type_key, pos, f.offset, gap_size); + trace!( + "{}: gap at {}..{} ({} bytes)", + type_key, + pos, + f.offset, + gap_size + ); layout.push(LayoutEntry::Gap { name: format!("_gap{}", gap_idx), size: gap_size, @@ -212,7 +248,10 @@ pub(crate) fn check_repr_c( } if f.offset < pos { // Overlapping fields — shouldn't happen, bail out. - debug!("{}: field '{}' at offset={} overlaps pos={}", type_key, f.rust_name, f.offset, pos); + debug!( + "{}: field '{}' at offset={} overlaps pos={}", + type_key, f.rust_name, f.offset, pos + ); return None; } layout.push(LayoutEntry::Field(f.clone())); @@ -221,13 +260,22 @@ pub(crate) fn check_repr_c( // Trailing gap (tail padding, or fields after last registered one) if pos < total_size { let gap_size = total_size - pos; - trace!("{}: trailing gap at {}..{} ({} bytes)", type_key, pos, total_size, gap_size); + trace!( + "{}: trailing gap at {}..{} ({} bytes)", + type_key, + pos, + total_size, + gap_size + ); layout.push(LayoutEntry::Gap { name: format!("_gap{}", gap_idx), size: gap_size, }); } else if pos > total_size { - debug!("{}: fields exceed total_size (pos={} > total_size={})", type_key, pos, total_size); + debug!( + "{}: fields exceed total_size (pos={} > total_size={})", + type_key, pos, total_size + ); return None; } @@ -235,7 +283,10 @@ pub(crate) fn check_repr_c( "{}: repr_c OK ({} fields, {} gaps, {} layout entries)", type_key, typed_fields.len(), - layout.iter().filter(|e| matches!(e, LayoutEntry::Gap { .. })).count(), + layout + .iter() + .filter(|e| matches!(e, LayoutEntry::Gap { .. })) + .count(), layout.len() ); Some(ReprCInfo { @@ -301,7 +352,10 @@ fn repr_c_field_type( repr_c_field_type(Some(inner), type_map, _self_type_key, field_size)?; Some((format!("tvm_ffi::Array<{}>", inner_ty), false)) } - [] => Some(("tvm_ffi::Array".to_string(), false)), + [] => Some(( + "tvm_ffi::Array".to_string(), + false, + )), _ => None, }, "ffi.Map" => match schema.args.as_slice() { From 8e2068cb2abd5cc501ad10e59d119797065d61c2 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 8 Mar 2026 23:23:01 +0800 Subject: [PATCH 25/29] docs(stubgen): update README for gap-filling repr_c strategy Made-with: Cursor --- rust/tvm-ffi-stubgen/README.md | 55 ++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/rust/tvm-ffi-stubgen/README.md b/rust/tvm-ffi-stubgen/README.md index f3cdf2f1c..2db988142 100644 --- a/rust/tvm-ffi-stubgen/README.md +++ b/rust/tvm-ffi-stubgen/README.md @@ -53,9 +53,9 @@ pub fn echo(args: &[Any]) -> Result { ... } #### repr(C) path (preferred) -For layout-compatible object types: +For types with known `total_size`: -- `#[repr(C)] Obj` +- `#[repr(C)] Obj` with typed fields and `[u8; N]` gaps - `#[derive(ObjectRef, Clone)] ` - `impl_object_hierarchy!(...)` - direct-field `get_` accessors @@ -64,16 +64,20 @@ Example shape: ```rust #[repr(C)] -pub struct TestObjectDerivedObj { - parent: TestObjectBaseObj, - v_map: tvm_ffi::Map, - v_array: tvm_ffi::Array, +pub struct PrimExprObj { + parent: BaseExprObj, + dtype: tvm_ffi::DLDataType, + _gap0: [u8; 4], // C++ tail padding } ``` +Gaps cover C++ tail padding, vtable pointers, and fields whose type schema +is not mappable to Rust. This allows the vast majority of types to use +repr(C) layout even when metadata is incomplete. + #### fallback wrapper path -For non-repr(C)-compatible types: +For types without `total_size` metadata (no `ObjectDef` registered): - `define_object_wrapper!(Type, "type.key")` - field access via `FieldGetter` @@ -152,16 +156,23 @@ This avoids custom cast traits and keeps compile-time type constraints explicit. ## repr(C) Decision Rules -`check_repr_c` gates repr(C) generation. +`check_repr_c` gates repr(C) generation using a **gap-filling** strategy. + +### Hard requirements (cause fallback to `define_object_wrapper!`) + +- Type must have `total_size > 0` (i.e. `ObjectDef` was called for it) +- No overlapping fields -### Required metadata checks +### Soft handling (does NOT cause fallback) -- `total_size > 0` -- valid field `offset/size/alignment` -- field order and no overlap -- aligned placement (`field.offset == align_up(pos, alignment)`) -- parent boundary matches first direct field offset -- parent type is also repr(C)-compatible +- **Tail padding / vtable / unregistered fields**: byte ranges between registered + fields (or between the last field and `total_size`) are emitted as `[u8; N]` gap + members in the `#[repr(C)]` struct. +- **Parent type not in type_map or not repr(C)-compatible**: the parent region is + treated as a gap after the `Object` header. The struct uses `tvm_ffi::object::Object` + as the parent field and gap-fills the bytes between Object and the first known field. +- **Field type schema not mappable to Rust**: the field is skipped in the struct layout + (covered by a gap) but still accessible via runtime `FieldGetter` if needed. ### Schema mapping rules @@ -169,7 +180,10 @@ Representative mappings include: - `Any` / `ffi.Any` -> `tvm_ffi::AnyValue` - `ffi.Array` -> `tvm_ffi::Array` +- `ffi.Array` (no args) -> `tvm_ffi::Array` - `ffi.Map` -> `tvm_ffi::Map` +- `Optional` -> `Option` +- `Optional` (no args) -> `Option` ## Multi-Prefix Generation @@ -205,7 +219,16 @@ Generated user-facing code is intended to remain safe Rust. ### Built-in filtering and fallback - built-in `ffi.*` primitives are not re-generated as wrapper types -- unsupported/non-layout-compatible object types fall back to `define_object_wrapper!` +- only types without `total_size` metadata fall back to `define_object_wrapper!` +- types with incomplete field schemas or unmappable parents still get repr(C) layout + via gap-filling + +### Logging + +Stubgen uses the `log` crate. Set `RUST_LOG` to control verbosity: + +- `RUST_LOG=debug` — shows repr(C) pass/fail decisions and field mapping failures +- `RUST_LOG=trace` — additionally shows per-field offset/size/schema details ## Related User Guide From d39c4fbdad671bf359b8f511f6ebda4f80494646 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sun, 8 Mar 2026 23:38:53 +0800 Subject: [PATCH 26/29] fix(stubgen): rename base-class field to __tvm_ffi_object_parent Avoids name collision when a C++ struct has a data field also named "parent". Update README examples and the inherited getter access path in build_getter_specs to match the new field name. docs(stubgen): add TODO section for known open issues - Ancestor chain is truncated when direct parent is not repr(C)-mappable: the chain collapses to [ObjectRef], losing intermediate mappable ancestors and breaking upcast/downcast through those types. - Common interface needed between fallback (define_object_wrapper!) and repr(C) paths to avoid source-compatibility breaks when stubgen version changes cause a type to move between generation strategies. Made-with: Cursor --- rust/tvm-ffi-stubgen/README.md | 54 ++++++++++++++++++++++++++-- rust/tvm-ffi-stubgen/src/generate.rs | 4 +-- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/rust/tvm-ffi-stubgen/README.md b/rust/tvm-ffi-stubgen/README.md index 2db988142..a8d8db50d 100644 --- a/rust/tvm-ffi-stubgen/README.md +++ b/rust/tvm-ffi-stubgen/README.md @@ -12,6 +12,7 @@ This document is design-oriented and focuses on generated interface forms and im - [Subtyping and Cast Rules](#subtyping-and-cast-rules) - [repr(C) Decision Rules](#reprc-decision-rules) - [Safety and Fallback Strategy](#safety-and-fallback-strategy) +- [TODO](#todo) - [Related User Guide](#related-user-guide) ## Document Scope @@ -65,7 +66,7 @@ Example shape: ```rust #[repr(C)] pub struct PrimExprObj { - parent: BaseExprObj, + __tvm_ffi_object_parent: BaseExprObj, dtype: tvm_ffi::DLDataType, _gap0: [u8; 4], // C++ tail padding } @@ -108,7 +109,7 @@ Derived object stores parent object as first field: ```rust #[repr(C)] pub struct DerivedObj { - parent: BaseObj, + __tvm_ffi_object_parent: BaseObj, extra: i64, } ``` @@ -230,6 +231,55 @@ Stubgen uses the `log` crate. Set `RUST_LOG` to control verbosity: - `RUST_LOG=debug` — shows repr(C) pass/fail decisions and field mapping failures - `RUST_LOG=trace` — additionally shows per-field offset/size/schema details +## TODO + +Known gaps and design issues that remain open. + +### Ancestor chain is truncated when direct parent is not repr(C)-mappable + +When `check_repr_c` cannot map the direct parent type, `repr_c.rs` falls back to +`tvm_ffi::object::Object` as the layout parent and fills the missing bytes with a gap. +However, the second pass in `generate.rs` that builds `ancestor_chain` only propagates +through types whose `parent_type_key` is set in `ReprCInfo`; when it is `None` the chain +collapses to `[tvm_ffi::object::ObjectRef]`. + +Consequence: if the C++ hierarchy is `Object → A (mappable) → B (not mappable) → C`, +the generated code for `C` emits + +```rust +tvm_ffi::impl_object_hierarchy!(C: tvm_ffi::object::ObjectRef); +``` + +instead of + +```rust +tvm_ffi::impl_object_hierarchy!(C: A, tvm_ffi::object::ObjectRef); +``` + +This means `From for A` and `TryFrom for C` are not generated, and getters +inherited from `A` are inaccessible via deref on `C` even though the layout is correct. + +The ancestor chain logic should be derived from the runtime type ancestry table +(`TVMFFIGetTypeInfo → type_acenstors`) independently of layout mappability, so that +upcast/downcast correctness is preserved regardless of whether every intermediate type +has a usable repr(C) layout. + +### Common interface between fallback and repr(C) paths + +`define_object_wrapper!` types and repr(C) types currently expose different API surfaces: + +- repr(C) types: `Deref` chain, `From`/`TryFrom`, direct `get_*` accessors +- fallback types: `from_object` / `as_object_ref` / `into_object_ref`, runtime `FieldGetter` + +Code that depends on a given type must know which generation path was used, and that +path can change between stubgen versions as reflection metadata improves. A type that +was a thin wrapper in version N may become a repr(C) type in version N+1, silently +breaking downstream call sites that relied on `from_object` or `as_object_ref`. + +A stable, version-independent interface layer is needed so that user code does not need +to distinguish between the two paths, and so that crates built against one stubgen +version remain source-compatible with crates built against a later one. + ## Related User Guide For generation command-line usage and step-by-step invocation examples, see: diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index f977371d8..cc81a439d 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -223,7 +223,7 @@ fn build_getter_specs( for parent in parent_specs { let access_expr = if parent.access_expr.starts_with("self.data.") { format!( - "self.data.parent.{}", + "self.data.__tvm_ffi_object_parent.{}", &parent.access_expr["self.data.".len()..] ) } else { @@ -832,7 +832,7 @@ fn render_repr_c_type( writeln!(out, "{}#[derive(tvm_ffi::derive::Object)]", indent_str).ok(); writeln!(out, "{}#[type_key = \"{}\"]", indent_str, ty.type_key).ok(); writeln!(out, "{}pub struct {} {{", indent_str, obj_name).ok(); - writeln!(out, "{} parent: {},", indent_str, parent_ty).ok(); + writeln!(out, "{} __tvm_ffi_object_parent: {},", indent_str, parent_ty).ok(); for entry in &info.layout { match entry { repr_c::LayoutEntry::Field(f) => { From 6eb9c9c230a9c83df95575d71cb9d04c8ba8837f Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Mon, 9 Mar 2026 00:35:39 +0800 Subject: [PATCH 27/29] refactor(stubgen): generalize parent_range_fields to non_layout_fields Track ALL registered ObjectDef fields that cannot become direct repr(C) struct members in a unified non_layout_fields list: - Parent-range fields (offset < parent_total_size), with or without mappable schema. - Own-range fields whose schema is not mappable to a Rust type. Generate FieldGetter accessors for every entry: - Typed fields (mappable schema): get_xxx() -> T via FieldGetter::get(). - Untyped fields (unmappable schema): get_xxx() -> tvm_ffi::Any via get_any(), with tvm_ffi::Any as the FieldGetter type parameter. Both variants use the same infallible get_* naming convention as direct struct-field getters, preserving interface compatibility across access paths. Made-with: Cursor --- rust/tvm-ffi-stubgen/src/generate.rs | 75 +++++++++++++++++++++++++++- rust/tvm-ffi-stubgen/src/repr_c.rs | 56 ++++++++++++++++++--- 2 files changed, 124 insertions(+), 7 deletions(-) diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index cc81a439d..55d441430 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -832,7 +832,12 @@ fn render_repr_c_type( writeln!(out, "{}#[derive(tvm_ffi::derive::Object)]", indent_str).ok(); writeln!(out, "{}#[type_key = \"{}\"]", indent_str, ty.type_key).ok(); writeln!(out, "{}pub struct {} {{", indent_str, obj_name).ok(); - writeln!(out, "{} __tvm_ffi_object_parent: {},", indent_str, parent_ty).ok(); + writeln!( + out, + "{} __tvm_ffi_object_parent: {},", + indent_str, parent_ty + ) + .ok(); for entry in &info.layout { match entry { repr_c::LayoutEntry::Field(f) => { @@ -901,6 +906,35 @@ fn render_repr_c_type( } writeln!(out, "{}}}\n", indent_str).ok(); + // Generate FieldGetter statics for non-layout fields. + // Each entry is a registered ObjectDef field that couldn't become a direct struct + // member (parent-range offset or unmappable schema). + for nlf in &info.non_layout_fields { + let static_name = static_ident("FIELD", &format!("{}::{}", ty.type_key, nlf.name)); + // Use the concrete mapped type when available; fall back to tvm_ffi::Any so the + // static can still be constructed (get_any() has no type constraints). + let static_ty = nlf.rust_type.as_deref().unwrap_or("tvm_ffi::Any"); + writeln!( + out, + "{}static {}: std::sync::LazyLock> = std::sync::LazyLock::new(|| {{", + indent_str, static_name, static_ty + ) + .ok(); + writeln!( + out, + "{} tvm_ffi::object_wrapper::FieldGetter::new(\"{}\", \"{}\")", + indent_str, ty.type_key, nlf.name + ) + .ok(); + writeln!( + out, + "{} .expect(\"non-layout field {} must be registered in TVM reflection\")", + indent_str, nlf.name + ) + .ok(); + writeln!(out, "{}}});", indent_str).ok(); + } + // Generate method statics and impls for method in &ty.methods { render_method_static(out, ty, method, indent); @@ -909,6 +943,45 @@ fn render_repr_c_type( for method in &ty.methods { render_method(out, ty, method, indent + 4); } + // Generate FieldGetter accessor methods for non-layout fields. + // Interface matches direct struct-field getters: `get_*` naming, infallible return. + // Typed fields (mappable schema) return the concrete type via get(). + // Untyped fields (unmappable schema) return tvm_ffi::Any via get_any(). + for nlf in &info.non_layout_fields { + let static_name = static_ident("FIELD", &format!("{}::{}", ty.type_key, nlf.name)); + let method_name = format!("get_{}", nlf.rust_name); + let (return_type, call_expr) = if let Some(rt) = &nlf.rust_type { + ( + rt.as_str(), + format!( + "{}.get(&__obj).expect(\"non-layout field {} should be accessible\")", + static_name, nlf.name + ), + ) + } else { + ( + "tvm_ffi::Any", + format!( + "{}.get_any(&__obj).expect(\"non-layout field {} should be accessible\")", + static_name, nlf.name + ), + ) + }; + writeln!( + out, + "{} pub fn {}(&self) -> {} {{", + indent_str, method_name, return_type + ) + .ok(); + writeln!( + out, + "{} let __obj: tvm_ffi::object::ObjectRef = self.clone().into();", + indent_str + ) + .ok(); + writeln!(out, "{} {}", indent_str, call_expr).ok(); + writeln!(out, "{} }}", indent_str).ok(); + } writeln!(out, "{}}}\n", indent_str).ok(); } diff --git a/rust/tvm-ffi-stubgen/src/repr_c.rs b/rust/tvm-ffi-stubgen/src/repr_c.rs index e4f084015..5bb5af0c1 100644 --- a/rust/tvm-ffi-stubgen/src/repr_c.rs +++ b/rust/tvm-ffi-stubgen/src/repr_c.rs @@ -36,6 +36,24 @@ pub(crate) struct ReprCInfo { pub(crate) parent_type_key: Option, /// Ordered layout entries (fields and gaps) covering [parent_total_size .. total_size). pub(crate) layout: Vec, + /// Fields registered in this type's ObjectDef that are NOT part of the repr(C) struct + /// layout. Two causes: (1) offset < parent_total_size — the field occupies a slot + /// within the parent's address range; (2) schema not mappable — the field's type + /// cannot be expressed as a repr(C) Rust type. All such fields can still be read at + /// runtime via FieldGetter. + pub(crate) non_layout_fields: Vec, +} + +/// A registered field that does not appear in the repr(C) struct layout. +#[derive(Debug, Clone)] +pub(crate) struct NonLayoutField { + /// Original C++ field name (used as the FieldGetter key). + pub(crate) name: String, + /// Sanitized Rust identifier (used as the getter method name suffix). + pub(crate) rust_name: String, + /// Mapped Rust type string, or None if the schema could not be mapped. + /// When None, the generated getter returns `tvm_ffi::Any` via get_any(). + pub(crate) rust_type: Option, } /// A single entry in the repr(C) struct body after the parent. @@ -151,7 +169,10 @@ pub(crate) fn check_repr_c( ); // Collect and sort fields that belong to this type (offset >= parent_total_size). + // Any registered field that cannot become a direct struct member is tracked in + // non_layout_fields so it can be exposed via a FieldGetter accessor. let mut typed_fields: Vec = Vec::new(); + let mut non_layout_fields: Vec = Vec::new(); if info.num_fields > 0 && !info.fields.is_null() { let field_slice = unsafe { std::slice::from_raw_parts(info.fields, info.num_fields as usize) }; @@ -163,14 +184,29 @@ pub(crate) fn check_repr_c( return None; } }; - // Skip inherited fields (registered by parent's ObjectDef) + // Fields whose offset falls inside the parent type's address range cannot be + // part of the repr(C) struct layout (they occupy a slot the parent owns). if field.offset < parent_total_size { + let rust_type = if field.offset >= 0 && field.size >= 0 { + let meta = ffi::byte_array_to_string_opt(&field.metadata); + let schema = meta + .as_deref() + .and_then(extract_type_schema) + .and_then(|s| parse_type_schema(&s)); + repr_c_field_type(schema.as_ref(), type_map, type_key, field.size) + .map(|(ty, _)| ty) + } else { + None + }; trace!( - "{}: field '{}' at offset={} belongs to parent, skipping", - type_key, - name, - field.offset + "{}: field '{}' at offset={} is in parent range → non-layout (rust_type={:?})", + type_key, name, field.offset, rust_type ); + non_layout_fields.push(NonLayoutField { + name: name.clone(), + rust_name: sanitize_ident(&name), + rust_type, + }); continue; } trace!( @@ -199,10 +235,17 @@ pub(crate) fn check_repr_c( let (rust_type, is_pod) = match mapped { Some(v) => v, None => { + // Schema not mappable: cannot be a struct field, but still accessible + // at runtime via FieldGetter with an untyped (Any) return. debug!( - "{}: field '{}' type not mappable, will be covered by gap (schema_origin={:?})", + "{}: field '{}' type not mappable, covered by gap + non-layout FieldGetter (schema_origin={:?})", type_key, name, schema.as_ref().map(|s| &s.origin) ); + non_layout_fields.push(NonLayoutField { + name: name.clone(), + rust_name: sanitize_ident(&name), + rust_type: None, + }); continue; } }; @@ -292,6 +335,7 @@ pub(crate) fn check_repr_c( Some(ReprCInfo { parent_type_key, layout, + non_layout_fields, }) } From 0f7886de41802456c48b9be4726f0383644815d7 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Mon, 9 Mar 2026 00:41:22 +0800 Subject: [PATCH 28/29] docs(stubgen): document non_layout_fields and unified getter interface - Field Accessor Style: new section explaining the two accessor paths (direct struct field vs FieldGetter) with the unified get_* convention, return-type rules, and a debugging recipe. - repr(C) Decision Rules: expand soft-handling list to cover parent-range fields and unmappable-schema fields, describing the non_layout_fields mechanism and the emitted get_* accessors. - TODO: add "Parent-range field layout override" item (phase 2 work); update "Common interface" item to reflect current repr(C) state. Made-with: Cursor --- rust/tvm-ffi-stubgen/README.md | 102 ++++++++++++++++++++++++++++----- 1 file changed, 89 insertions(+), 13 deletions(-) diff --git a/rust/tvm-ffi-stubgen/README.md b/rust/tvm-ffi-stubgen/README.md index a8d8db50d..ba668412b 100644 --- a/rust/tvm-ffi-stubgen/README.md +++ b/rust/tvm-ffi-stubgen/README.md @@ -124,27 +124,82 @@ Ref wrappers use `impl_object_hierarchy!` to establish: ## Field Accessor Style -Getter generation follows a single style: +All getter variants share a unified calling convention so callers do not need +to know which code path was used: - name prefix is always `get_` -- only direct fields of current type generate getters -- inherited getters are available via deref auto-coercion +- return is infallible (panics internally rather than returning `Result`) +- only fields of the current type are generated; inherited getters are available + via deref auto-coercion -### Return type rules +### Direct struct field getters (repr(C) layout path) -- POD field -> return by value -- object/container field -> clone and return user-facing type +Fields that map cleanly to the repr(C) struct body are accessed by direct +memory reference: -Example: +- POD field → return by value +- object/container field → clone and return ```rust -impl TestObjectDerived { - pub fn get_v_map(&self) -> tvm_ffi::Map { - self.data.v_map.clone() +impl PrimExpr { + pub fn get_dtype(&self) -> tvm_ffi::DLDataType { + self.data.dtype // POD: copy + } +} +impl ForFrame { + pub fn get_doms(&self) -> tvm_ffi::Array { + self.data.doms.clone() // object: clone + } +} +``` + +### Non-layout field getters (FieldGetter runtime path) + +Some registered ObjectDef fields cannot be placed in the repr(C) struct body: + +1. **Parent-range fields** — offset falls inside the parent type's address + range. Example: `ForFrame.vars` at offset 56 is within `TIRFrame`'s 0..64 + range. These are fields the child type "fills into" a gap slot of the parent. +2. **Schema-unmappable fields** — the field's type schema has no Rust + representation in the current type_map. + +For both cases stubgen generates a `LazyLock>` static and a +matching `get_xxx` method with the **same signature** as a direct getter: + +- typed (schema mappable): `pub fn get_xxx(&self) -> T` via `FieldGetter::get()` +- untyped (schema unmappable): `pub fn get_xxx(&self) -> tvm_ffi::Any` via `get_any()` + +```rust +// auto-generated: parent-range field, typed +static FIELD_FORFRAME__VARS: LazyLock>> = ...; +impl ForFrame { + pub fn get_vars(&self) -> tvm_ffi::Array { + let __obj: tvm_ffi::object::ObjectRef = self.clone().into(); + FIELD_FORFRAME__VARS.get(&__obj).expect("...") } } ``` +Callers use `frame.get_vars()` exactly as they would `frame.get_doms()`, with +no awareness of the access mechanism. + +### Debugging non-layout fields + +To inspect what fields TVM has registered for a type (including those not in +the struct layout), use the reflection API at runtime: + +```rust +// example: inspect ForFrame field offsets and schemas +let info = unsafe { tvm_ffi_sys::TVMFFIGetTypeInfo(type_index) }; +for i in 0..info.num_fields { + let f = &(*info.fields)[i]; + println!("name={:?} offset={} schema={:?}", f.name, f.offset, f.metadata); +} +``` + +Alternatively, set `RUST_LOG=trace` when running stubgen to see all field +offset and schema decisions in the log output. + ## Subtyping and Cast Rules Stubgen-generated repr(C) refs use standard Rust traits as the only user-facing cast API: @@ -172,8 +227,16 @@ This avoids custom cast traits and keeps compile-time type constraints explicit. - **Parent type not in type_map or not repr(C)-compatible**: the parent region is treated as a gap after the `Object` header. The struct uses `tvm_ffi::object::Object` as the parent field and gap-fills the bytes between Object and the first known field. -- **Field type schema not mappable to Rust**: the field is skipped in the struct layout - (covered by a gap) but still accessible via runtime `FieldGetter` if needed. +- **Parent-range fields** (registered by this type but `offset < parent_total_size`): + cannot be placed in the struct body. Tracked in `non_layout_fields`; a FieldGetter + static and `get_*` accessor are emitted instead. These fields physically reside + in a gap slot of the parent type that the child fills with its own data. +- **Field type schema not mappable to Rust**: the field becomes a gap in the struct + body. Also tracked in `non_layout_fields` and exposed via `get_*` returning + `tvm_ffi::Any` (using `FieldGetter::get_any()`). + +In all non-layout cases the emitted `get_*` method has the same naming and +infallible-return signature as a direct struct-field getter. ### Schema mapping rules @@ -268,7 +331,8 @@ has a usable repr(C) layout. `define_object_wrapper!` types and repr(C) types currently expose different API surfaces: -- repr(C) types: `Deref` chain, `From`/`TryFrom`, direct `get_*` accessors +- repr(C) types: `Deref` chain, `From`/`TryFrom`, direct `get_*` accessors, + `get_*` FieldGetter accessors for non-layout fields - fallback types: `from_object` / `as_object_ref` / `into_object_ref`, runtime `FieldGetter` Code that depends on a given type must know which generation path was used, and that @@ -280,6 +344,18 @@ A stable, version-independent interface layer is needed so that user code does n to distinguish between the two paths, and so that crates built against one stubgen version remain source-compatible with crates built against a later one. +### Parent-range field layout override + +When a child type registers a field at an offset within the parent's address range +(e.g. `ForFrame.vars` at offset 56 inside `TIRFrame`'s 0..64 gap), the field is +currently excluded from the repr(C) struct body and exposed only via `FieldGetter`. +This is correct for access but sub-optimal: the child field is physically in a gap +slot that the parent never uses, so it could safely be placed in the struct layout. + +Fix: inspect the parent's layout for the specific offset. If the parent has a +`[u8; N]` gap entry covering that offset, allow the child's field to override it +directly in the struct rather than routing through FieldGetter. + ## Related User Guide For generation command-line usage and step-by-step invocation examples, see: From 7e0887215e5478363c60c627e5cc3d9ad69af880 Mon Sep 17 00:00:00 2001 From: Huanqi Cao Date: Sat, 14 Mar 2026 15:25:12 +0800 Subject: [PATCH 29/29] refactor(rust): upgrade to edition 2024, syn 2.x, fix clippy - Upgrade all Rust crates from edition 2021 to edition 2024 - Set rust-version = "1.85" (MSRV for edition 2024) - Add rust-toolchain.toml (stable + rustfmt/clippy components) - Add .rustfmt.toml with edition 2024 formatting rules - Migrate tvm-ffi-macros from syn 1.x to syn 2.x - Remove unmaintained proc-macro-error dependency - Fix all clippy warnings across tvm-ffi, tvm-ffi-sys, tvm-ffi-macros - Allow unsafe_op_in_unsafe_fn at crate level for incremental migration - Update stubgen template to emit edition 2024 + clippy allows Made-with: Cursor --- rust/.rustfmt.toml | 3 ++ rust/rust-toolchain.toml | 3 ++ rust/tvm-ffi-macros/Cargo.toml | 6 ++-- rust/tvm-ffi-macros/src/lib.rs | 7 ++-- rust/tvm-ffi-macros/src/utils.rs | 46 ++++++-------------------- rust/tvm-ffi-stubgen/Cargo.toml | 3 +- rust/tvm-ffi-stubgen/src/ffi.rs | 8 ++--- rust/tvm-ffi-stubgen/src/generate.rs | 26 ++++++++++----- rust/tvm-ffi-stubgen/src/main.rs | 2 +- rust/tvm-ffi-stubgen/src/repr_c.rs | 39 ++++++---------------- rust/tvm-ffi-stubgen/tests/stubgen.rs | 2 +- rust/tvm-ffi-sys/Cargo.toml | 3 +- rust/tvm-ffi-sys/src/c_api.rs | 6 ++++ rust/tvm-ffi-sys/src/dlpack.rs | 8 ++--- rust/tvm-ffi-sys/src/lib.rs | 2 ++ rust/tvm-ffi/Cargo.toml | 3 +- rust/tvm-ffi/src/any.rs | 4 +-- rust/tvm-ffi/src/collections/array.rs | 2 +- rust/tvm-ffi/src/collections/shape.rs | 2 +- rust/tvm-ffi/src/collections/tensor.rs | 16 ++++----- rust/tvm-ffi/src/device.rs | 2 +- rust/tvm-ffi/src/dtype.rs | 6 ++-- rust/tvm-ffi/src/error.rs | 4 +-- rust/tvm-ffi/src/function.rs | 4 +-- rust/tvm-ffi/src/function_internal.rs | 1 + rust/tvm-ffi/src/lib.rs | 13 +++++++- rust/tvm-ffi/src/macros.rs | 3 +- rust/tvm-ffi/src/object.rs | 11 +++--- rust/tvm-ffi/src/object_wrapper.rs | 10 ++---- rust/tvm-ffi/src/string.rs | 22 ++++++------ rust/tvm-ffi/src/type_traits.rs | 14 +++----- rust/tvm-ffi/tests/test_object.rs | 2 +- 32 files changed, 131 insertions(+), 152 deletions(-) create mode 100644 rust/.rustfmt.toml create mode 100644 rust/rust-toolchain.toml diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml new file mode 100644 index 000000000..15df9a7c3 --- /dev/null +++ b/rust/.rustfmt.toml @@ -0,0 +1,3 @@ +edition = "2024" +max_width = 100 +use_small_heuristics = "Default" diff --git a/rust/rust-toolchain.toml b/rust/rust-toolchain.toml new file mode 100644 index 000000000..73cb934de --- /dev/null +++ b/rust/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "stable" +components = ["rustfmt", "clippy"] diff --git a/rust/tvm-ffi-macros/Cargo.toml b/rust/tvm-ffi-macros/Cargo.toml index f8d29d406..4771ce281 100644 --- a/rust/tvm-ffi-macros/Cargo.toml +++ b/rust/tvm-ffi-macros/Cargo.toml @@ -20,7 +20,8 @@ name = "tvm-ffi-macros" description = "Procedural macro crate for tvm-ffi" version = "0.1.0-alpha.0" -edition = "2021" +edition = "2024" +rust-version = "1.85" license = "Apache-2.0" @@ -30,5 +31,4 @@ proc-macro = true [dependencies] proc-macro2 = "^1.0" quote = "^1.0" -syn = { version = "1.0.48", features = ["full", "parsing", "extra-traits"] } -proc-macro-error = "^1.0" +syn = { version = "^2.0", features = ["full"] } diff --git a/rust/tvm-ffi-macros/src/lib.rs b/rust/tvm-ffi-macros/src/lib.rs index 64fe3f18e..76497bd1b 100644 --- a/rust/tvm-ffi-macros/src/lib.rs +++ b/rust/tvm-ffi-macros/src/lib.rs @@ -18,19 +18,16 @@ */ use proc_macro::TokenStream; -use proc_macro_error::proc_macro_error; mod object_macros; mod utils; -#[proc_macro_error] #[proc_macro_derive(Object, attributes(type_key, type_index))] pub fn derive_object(input: TokenStream) -> TokenStream { - TokenStream::from(object_macros::derive_object(input)) + object_macros::derive_object(input) } -#[proc_macro_error] #[proc_macro_derive(ObjectRef, attributes(type_key, type_index))] pub fn derive_object_ref(input: TokenStream) -> TokenStream { - TokenStream::from(object_macros::derive_object_ref(input)) + object_macros::derive_object_ref(input) } diff --git a/rust/tvm-ffi-macros/src/utils.rs b/rust/tvm-ffi-macros/src/utils.rs index da86534f6..0b3bca625 100644 --- a/rust/tvm-ffi-macros/src/utils.rs +++ b/rust/tvm-ffi-macros/src/utils.rs @@ -20,8 +20,6 @@ use proc_macro2::TokenStream; use quote::quote; use std::env; -/// Get the tvm-rt crate name -/// \return The tvm-rt crate name pub(crate) fn get_tvm_ffi_crate() -> TokenStream { if env::var("CARGO_PKG_NAME").unwrap() == "tvm-ffi" { quote!(crate) @@ -30,49 +28,27 @@ pub(crate) fn get_tvm_ffi_crate() -> TokenStream { } } -/// Get an attribute by name from a derive input -/// -/// # Arguments -/// * `derive_input` - The derive input to get the attribute from -/// * `name` - The name of the attribute to get -/// -/// # Returns -/// * `Option<&syn::Attribute>` - The attribute if it exists pub(crate) fn get_attr<'a>( derive_input: &'a syn::DeriveInput, name: &str, ) -> Option<&'a syn::Attribute> { - derive_input.attrs.iter().find(|a| a.path.is_ident(name)) + derive_input.attrs.iter().find(|a| a.path().is_ident(name)) } -/// Convert an attribute to a string -/// -/// # Arguments -/// * `attr` - The attribute to convert -/// -/// # Returns -/// * `syn::LitStr` - The string value of the attribute pub(crate) fn attr_to_str(attr: &syn::Attribute) -> syn::LitStr { - match attr.parse_meta() { - Ok(syn::Meta::NameValue(syn::MetaNameValue { - lit: syn::Lit::Str(s), + match &attr.meta { + syn::Meta::NameValue(syn::MetaNameValue { + value: + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(s), + .. + }), .. - })) => s, - Ok(_m) => panic!("Expected a string literal, got"), - Err(e) => panic!("{}", e), + }) => s.clone(), + _ => panic!("Expected #[attr = \"string\"] attribute"), } } -/// Convert an attribute to an integer -/// -/// # Arguments -/// * `attr` - The attribute to convert -/// -/// # Returns -/// * `syn::Result` - The integer value of the attribute pub(crate) fn attr_to_expr(attr: &syn::Attribute) -> syn::Result { - let parser = |input: syn::parse::ParseStream| { - input.parse::() // parse expression after '=' - }; - syn::parse::Parser::parse2(parser, attr.tokens.clone()) + attr.parse_args::() } diff --git a/rust/tvm-ffi-stubgen/Cargo.toml b/rust/tvm-ffi-stubgen/Cargo.toml index fc106d4ef..4157065b2 100644 --- a/rust/tvm-ffi-stubgen/Cargo.toml +++ b/rust/tvm-ffi-stubgen/Cargo.toml @@ -19,7 +19,8 @@ name = "tvm-ffi-stubgen" description = "Rust stub generator for tvm-ffi" version = "0.1.0" -edition = "2021" +edition = "2024" +rust-version = "1.85" license = "Apache-2.0" [[bin]] diff --git a/rust/tvm-ffi-stubgen/src/ffi.rs b/rust/tvm-ffi-stubgen/src/ffi.rs index 100e5bf16..4164f6e15 100644 --- a/rust/tvm-ffi-stubgen/src/ffi.rs +++ b/rust/tvm-ffi-stubgen/src/ffi.rs @@ -17,10 +17,10 @@ use libloading::Library; use std::path::PathBuf; +use tvm_ffi::Array; use tvm_ffi::tvm_ffi_sys::{ TVMFFIByteArray, TVMFFIGetTypeInfo, TVMFFITypeInfo, TVMFFITypeKeyToIndex, }; -use tvm_ffi::Array; use tvm_ffi::{Function, Result as FfiResult, String as FfiString}; pub(crate) fn load_dlls(paths: &[PathBuf]) -> Result, Box> { @@ -66,11 +66,7 @@ pub(crate) fn get_type_info(type_key: &str) -> Option<&'static TVMFFITypeInfo> { return None; } let info = TVMFFIGetTypeInfo(tindex); - if info.is_null() { - None - } else { - Some(&*info) - } + if info.is_null() { None } else { Some(&*info) } } } diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs index 55d441430..e9ecbaaa7 100644 --- a/rust/tvm-ffi-stubgen/src/generate.rs +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -21,7 +21,7 @@ use crate::model::{ FieldGen, FunctionGen, FunctionSig, GetterSpec, MethodGen, ModuleNode, RustType, TypeGen, }; use crate::repr_c; -use crate::schema::{extract_type_schema, parse_type_schema, TypeSchema}; +use crate::schema::{TypeSchema, extract_type_schema, parse_type_schema}; use crate::utils; use std::collections::BTreeMap; use std::fmt::Write as _; @@ -434,7 +434,7 @@ pub(crate) fn render_cargo_toml( Some(path) => path.clone(), None => utils::default_tvm_ffi_path()?, }; - let tvm_ffi_path = tvm_ffi_path.canonicalize().unwrap_or_else(|_| tvm_ffi_path); + let tvm_ffi_path = tvm_ffi_path.canonicalize().unwrap_or(tvm_ffi_path); let tvm_ffi_path_str = tvm_ffi_path.to_string_lossy().to_string(); let mut package = Table::new(); @@ -448,7 +448,7 @@ pub(crate) fn render_cargo_toml( ); package.insert( "edition".to_string(), - toml::Value::String("2021".to_string()), + toml::Value::String("2024".to_string()), ); let mut tvm_ffi = Table::new(); @@ -467,7 +467,16 @@ pub(crate) fn render_cargo_toml( pub(crate) fn render_lib_rs(functions_root: &ModuleNode, types_root: &ModuleNode) -> String { let mut out = String::new(); out.push_str( - r#"pub mod _tvm_ffi_stubgen_detail { + r#"#![allow( + clippy::needless_question_mark, + clippy::too_many_arguments, + clippy::enum_variant_names, + clippy::manual_div_ceil, + clippy::just_underscores_and_digits, + non_snake_case +)] + +pub mod _tvm_ffi_stubgen_detail { pub mod functions; pub mod types; } @@ -581,10 +590,9 @@ fn render_facade_module( is_root: bool, ) { // Check if this module has any actual content - let has_functions = functions.map_or(false, |node| !node.functions.is_empty()); - let has_types = types.map_or(false, |node| { - node.types.iter().any(|ty| !is_builtin_type(&ty.type_key)) - }); + let has_functions = functions.is_some_and(|node| !node.functions.is_empty()); + let has_types = + types.is_some_and(|node| node.types.iter().any(|ty| !is_builtin_type(&ty.type_key))); let mut child_names = std::collections::BTreeSet::new(); if let Some(node) = functions { @@ -818,7 +826,7 @@ fn render_repr_c_type( Some(parent_key) => { let parent_rust = _type_map .get(parent_key) - .map(|s| s.clone()) + .cloned() .unwrap_or_else(|| format!("{}Obj", sanitize_ident(parent_key, IdentStyle::Type))); // Convert crate path to absolute path inside types module and append Obj let types_path = diff --git a/rust/tvm-ffi-stubgen/src/main.rs b/rust/tvm-ffi-stubgen/src/main.rs index cc5b32576..6dfa1d097 100644 --- a/rust/tvm-ffi-stubgen/src/main.rs +++ b/rust/tvm-ffi-stubgen/src/main.rs @@ -16,7 +16,7 @@ // under the License. use clap::Parser; -use tvm_ffi_stubgen::{run, Args}; +use tvm_ffi_stubgen::{Args, run}; fn main() -> Result<(), Box> { env_logger::init(); diff --git a/rust/tvm-ffi-stubgen/src/repr_c.rs b/rust/tvm-ffi-stubgen/src/repr_c.rs index 5bb5af0c1..4072bc3ba 100644 --- a/rust/tvm-ffi-stubgen/src/repr_c.rs +++ b/rust/tvm-ffi-stubgen/src/repr_c.rs @@ -25,7 +25,7 @@ //! inference. use crate::ffi; -use crate::schema::{extract_type_schema, parse_type_schema, TypeSchema}; +use crate::schema::{TypeSchema, extract_type_schema, parse_type_schema}; use log::{debug, trace}; use std::collections::BTreeMap; @@ -113,11 +113,7 @@ pub(crate) fn check_repr_c( }; trace!( "{}: total_size={}, type_depth={}, num_fields={}, num_methods={}", - type_key, - total_size, - info.type_depth, - info.num_fields, - info.num_methods + type_key, total_size, info.type_depth, info.num_fields, info.num_methods ); // Resolve parent. @@ -151,8 +147,7 @@ pub(crate) fn check_repr_c( // Parent exists but not mappable — use Object as parent, gap covers the rest. trace!( "{}: parent='{}' not mappable, falling back to Object", - type_key, - key + type_key, key ); (None, obj_size) } @@ -163,9 +158,7 @@ pub(crate) fn check_repr_c( }; trace!( "{}: parent={:?}, parent_total_size={}", - type_key, - parent_type_key, - parent_total_size + type_key, parent_type_key, parent_total_size ); // Collect and sort fields that belong to this type (offset >= parent_total_size). @@ -211,10 +204,7 @@ pub(crate) fn check_repr_c( } trace!( "{}: field '{}': offset={}, size={}", - type_key, - name, - field.offset, - field.size + type_key, name, field.offset, field.size ); if field.offset < 0 || field.size < 0 { debug!("{}: field '{}' has invalid offset/size", type_key, name); @@ -239,7 +229,9 @@ pub(crate) fn check_repr_c( // at runtime via FieldGetter with an untyped (Any) return. debug!( "{}: field '{}' type not mappable, covered by gap + non-layout FieldGetter (schema_origin={:?})", - type_key, name, schema.as_ref().map(|s| &s.origin) + type_key, + name, + schema.as_ref().map(|s| &s.origin) ); non_layout_fields.push(NonLayoutField { name: name.clone(), @@ -251,10 +243,7 @@ pub(crate) fn check_repr_c( }; trace!( "{}: field '{}' -> rust_type='{}', is_pod={}", - type_key, - name, - rust_type, - is_pod + type_key, name, rust_type, is_pod ); typed_fields.push(ReprCField { rust_name: sanitize_ident(&name), @@ -277,10 +266,7 @@ pub(crate) fn check_repr_c( let gap_size = f.offset - pos; trace!( "{}: gap at {}..{} ({} bytes)", - type_key, - pos, - f.offset, - gap_size + type_key, pos, f.offset, gap_size ); layout.push(LayoutEntry::Gap { name: format!("_gap{}", gap_idx), @@ -305,10 +291,7 @@ pub(crate) fn check_repr_c( let gap_size = total_size - pos; trace!( "{}: trailing gap at {}..{} ({} bytes)", - type_key, - pos, - total_size, - gap_size + type_key, pos, total_size, gap_size ); layout.push(LayoutEntry::Gap { name: format!("_gap{}", gap_idx), diff --git a/rust/tvm-ffi-stubgen/tests/stubgen.rs b/rust/tvm-ffi-stubgen/tests/stubgen.rs index d89159d3b..2b69de30f 100644 --- a/rust/tvm-ffi-stubgen/tests/stubgen.rs +++ b/rust/tvm-ffi-stubgen/tests/stubgen.rs @@ -19,7 +19,7 @@ use std::env; use std::fs; use std::path::{Path, PathBuf}; use std::process::Command; -use tvm_ffi_stubgen::{run, Args}; +use tvm_ffi_stubgen::{Args, run}; #[test] fn stubgen_tvm_ffi_testing() { diff --git a/rust/tvm-ffi-sys/Cargo.toml b/rust/tvm-ffi-sys/Cargo.toml index ef87038d4..b0c10751b 100644 --- a/rust/tvm-ffi-sys/Cargo.toml +++ b/rust/tvm-ffi-sys/Cargo.toml @@ -20,7 +20,8 @@ name = "tvm-ffi-sys" description = "Low-level sys crate for tvm-ffi" version = "0.1.0-alpha.0" -edition = "2021" +edition = "2024" +rust-version = "1.85" license = "Apache-2.0" diff --git a/rust/tvm-ffi-sys/src/c_api.rs b/rust/tvm-ffi-sys/src/c_api.rs index e0bf08581..f3c5b8ee5 100644 --- a/rust/tvm-ffi-sys/src/c_api.rs +++ b/rust/tvm-ffi-sys/src/c_api.rs @@ -113,6 +113,12 @@ pub struct TVMFFIObject { __padding: u32, } +impl Default for TVMFFIObject { + fn default() -> Self { + Self::new() + } +} + impl TVMFFIObject { pub fn new() -> Self { Self { diff --git a/rust/tvm-ffi-sys/src/dlpack.rs b/rust/tvm-ffi-sys/src/dlpack.rs index e069ea8e6..a8eeed96a 100644 --- a/rust/tvm-ffi-sys/src/dlpack.rs +++ b/rust/tvm-ffi-sys/src/dlpack.rs @@ -107,8 +107,8 @@ pub struct DLTensor { impl DLDevice { pub fn new(device_type: DLDeviceType, device_id: i32) -> Self { Self { - device_type: device_type, - device_id: device_id, + device_type, + device_id, } } } @@ -117,8 +117,8 @@ impl DLDataType { pub fn new(code: DLDataTypeCode, bits: u8, lanes: u16) -> Self { Self { code: code as u8, - bits: bits, - lanes: lanes, + bits, + lanes, } } } diff --git a/rust/tvm-ffi-sys/src/lib.rs b/rust/tvm-ffi-sys/src/lib.rs index 1530cc713..018421d78 100644 --- a/rust/tvm-ffi-sys/src/lib.rs +++ b/rust/tvm-ffi-sys/src/lib.rs @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#![allow(clippy::new_without_default)] +#![allow(clippy::missing_safety_doc)] pub mod c_api; pub mod c_env_api; pub mod dlpack; diff --git a/rust/tvm-ffi/Cargo.toml b/rust/tvm-ffi/Cargo.toml index b27c8c844..ec73bb1b0 100644 --- a/rust/tvm-ffi/Cargo.toml +++ b/rust/tvm-ffi/Cargo.toml @@ -20,7 +20,8 @@ name = "tvm-ffi" description = "tvm-ffi rust support" version = "0.1.0-alpha.0" -edition = "2021" +edition = "2024" +rust-version = "1.85" license = "Apache-2.0" diff --git a/rust/tvm-ffi/src/any.rs b/rust/tvm-ffi/src/any.rs index c769494a7..ade44db68 100644 --- a/rust/tvm-ffi/src/any.rs +++ b/rust/tvm-ffi/src/any.rs @@ -101,9 +101,9 @@ impl<'a, T: AnyCompatible> From<&'a T> for AnyView<'a> { fn from(value: &'a T) -> Self { unsafe { let mut data = TVMFFIAny::new(); - T::copy_to_any_view(&value, &mut data); + T::copy_to_any_view(value, &mut data); Self { - data: data, + data, _phantom: std::marker::PhantomData, } } diff --git a/rust/tvm-ffi/src/collections/array.rs b/rust/tvm-ffi/src/collections/array.rs index 6f259ba19..06950192b 100644 --- a/rust/tvm-ffi/src/collections/array.rs +++ b/rust/tvm-ffi/src/collections/array.rs @@ -165,7 +165,7 @@ impl Array { #[inline] fn as_container(&self) -> &ArrayObj { unsafe { - let ptr = ObjectArc::as_raw(&self.data) as *const ArrayObj; + let ptr = ObjectArc::as_raw(&self.data); &*ptr } } diff --git a/rust/tvm-ffi/src/collections/shape.rs b/rust/tvm-ffi/src/collections/shape.rs index 39d9a1df1..16ec44197 100644 --- a/rust/tvm-ffi/src/collections/shape.rs +++ b/rust/tvm-ffi/src/collections/shape.rs @@ -127,7 +127,7 @@ impl Eq for Shape {} impl PartialOrd for Shape { #[inline] fn partial_cmp(&self, other: &Self) -> Option { - self.as_slice().partial_cmp(other.as_slice()) + Some(self.cmp(other)) } } diff --git a/rust/tvm-ffi/src/collections/tensor.rs b/rust/tvm-ffi/src/collections/tensor.rs index 6b34613e3..acd4bbe24 100644 --- a/rust/tvm-ffi/src/collections/tensor.rs +++ b/rust/tvm-ffi/src/collections/tensor.rs @@ -22,8 +22,8 @@ use crate::dtype::AsDLDataType; use crate::dtype::DLDataTypeExt; use crate::error::Result; use crate::object::{Object, ObjectArc, ObjectCore, ObjectCoreWithExtraItems}; -use tvm_ffi_sys::dlpack::{DLDataType, DLDevice, DLDeviceType, DLTensor}; use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; +use tvm_ffi_sys::dlpack::{DLDataType, DLDevice, DLDeviceType, DLTensor}; //----------------------------------------------------- // NDAllocator Trait @@ -66,7 +66,7 @@ impl DLTensorExt for DLTensor { } fn item_size(&self) -> usize { - (self.dtype.bits as usize * self.dtype.lanes as usize + 7) / 8 + (self.dtype.bits as usize * self.dtype.lanes as usize).div_ceil(8) } } @@ -267,15 +267,15 @@ impl Tensor { object: Object::new(), dltensor: DLTensor { data: std::ptr::null_mut(), - device: device, + device, ndim: shape.len() as i32, - dtype: dtype, + dtype, shape: std::ptr::null_mut(), strides: std::ptr::null_mut(), byte_offset: 0, }, }, - alloc: alloc, + alloc, }; unsafe { let mut obj_arc = ObjectArc::new_with_extra_items(tensor_obj); @@ -320,16 +320,16 @@ unsafe impl NDAllocator for CPUNDAlloc { const MIN_ALIGN: usize = 64; unsafe fn alloc_data(&mut self, prototype: &DLTensor) -> *mut core::ffi::c_void { - let numel = prototype.numel() as usize; + let numel = prototype.numel(); let item_size = prototype.item_size(); - let size = numel * item_size as usize; + let size = numel * item_size; let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap(); let ptr = std::alloc::alloc(layout); ptr as *mut core::ffi::c_void } unsafe fn free_data(&mut self, tensor: &DLTensor) { - let numel = tensor.numel() as usize; + let numel = tensor.numel(); let item_size = tensor.item_size(); let size = numel * item_size; let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap(); diff --git a/rust/tvm-ffi/src/device.rs b/rust/tvm-ffi/src/device.rs index d6c0418bc..885836d92 100644 --- a/rust/tvm-ffi/src/device.rs +++ b/rust/tvm-ffi/src/device.rs @@ -75,7 +75,7 @@ unsafe impl AnyCompatible for DLDevice { } unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - return data.type_index == TypeIndex::kTVMFFIDevice as i32; + data.type_index == TypeIndex::kTVMFFIDevice as i32 } unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { diff --git a/rust/tvm-ffi/src/dtype.rs b/rust/tvm-ffi/src/dtype.rs index eb471b2b7..7081242bc 100644 --- a/rust/tvm-ffi/src/dtype.rs +++ b/rust/tvm-ffi/src/dtype.rs @@ -18,9 +18,9 @@ */ use crate::error::Result; use crate::type_traits::AnyCompatible; +use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; /// Data type handling use tvm_ffi_sys::dlpack::{DLDataType, DLDataTypeCode}; -use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; use tvm_ffi_sys::{TVMFFIAny, TVMFFIByteArray, TVMFFIDataTypeFromString, TVMFFIDataTypeToString}; /// Extra methods for DLDataType @@ -53,7 +53,7 @@ impl DLDataTypeExt for DLDataType { fn to_string(&self) -> crate::string::String { unsafe { let mut ffi_any = TVMFFIAny::new(); - crate::check_safe_call!(TVMFFIDataTypeToString(&*self, &mut ffi_any)).unwrap(); + crate::check_safe_call!(TVMFFIDataTypeToString(self, &mut ffi_any)).unwrap(); crate::any::Any::from_raw_ffi_any(ffi_any) .try_into() .unwrap() @@ -120,7 +120,7 @@ unsafe impl AnyCompatible for DLDataType { /// # Returns /// `true` if the Any contains a DLDataType, `false` otherwise unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - return data.type_index == TypeIndex::kTVMFFIDataType as i32; + data.type_index == TypeIndex::kTVMFFIDataType as i32 } /// Copy a DLDataType from an Any view (after type check) diff --git a/rust/tvm-ffi/src/error.rs b/rust/tvm-ffi/src/error.rs index 2689f3790..341d93656 100644 --- a/rust/tvm-ffi/src/error.rs +++ b/rust/tvm-ffi/src/error.rs @@ -125,7 +125,7 @@ impl Error { /// # Returns /// The kind of the error pub fn kind(&self) -> ErrorKind<'_> { - ErrorKind(&self.data.cell.kind.as_str()) + ErrorKind(self.data.cell.kind.as_str()) } /// Get the message of the error @@ -186,7 +186,7 @@ impl Error { let mut new_backtrace = String::new(); new_backtrace.push_str(this.backtrace()); new_backtrace.push_str(backtrace); - return Error::new(this.kind(), this.message(), &new_backtrace); + Error::new(this.kind(), this.message(), &new_backtrace) } } } diff --git a/rust/tvm-ffi/src/function.rs b/rust/tvm-ffi/src/function.rs index e488983f8..e2c740970 100644 --- a/rust/tvm-ffi/src/function.rs +++ b/rust/tvm-ffi/src/function.rs @@ -150,7 +150,7 @@ impl Function { heap_args.resize(args_len, AnyView::new()); &mut heap_args[..args_len] }; - (&tuple_args).fill_any_view(packed_args); + tuple_args.fill_any_view(packed_args); self.call_packed(packed_args) } /// Call function with compile-time known argument count @@ -170,7 +170,7 @@ impl Function { TupleType: TupleAsPackedArgs, { let mut packed_args = [AnyView::new(); LEN]; - (&tuple_args).fill_any_view(&mut packed_args); + tuple_args.fill_any_view(&mut packed_args); self.call_packed(&packed_args) } /// Get global function by name diff --git a/rust/tvm-ffi/src/function_internal.rs b/rust/tvm-ffi/src/function_internal.rs index b10b7a89a..249e62cbb 100644 --- a/rust/tvm-ffi/src/function_internal.rs +++ b/rust/tvm-ffi/src/function_internal.rs @@ -184,6 +184,7 @@ pub trait IntoArgHolderTuple { macro_rules! impl_into_arg_holder_tuple { ( $($T:ident),* ; $($idx:tt),* ) => { + #[allow(clippy::unused_unit)] impl<$($T),*> $crate::function_internal::IntoArgHolderTuple for ($($T,)*) where $($T: IntoArgHolder),* { diff --git a/rust/tvm-ffi/src/lib.rs b/rust/tvm-ffi/src/lib.rs index 2482a3a04..e6ba5a9a8 100644 --- a/rust/tvm-ffi/src/lib.rs +++ b/rust/tvm-ffi/src/lib.rs @@ -16,6 +16,17 @@ * specific language governing permissions and limitations * under the License. */ +// TODO: incrementally migrate unsafe fn bodies to use explicit unsafe blocks +#![allow(unsafe_op_in_unsafe_fn)] +#![allow( + clippy::mut_from_ref, + clippy::not_unsafe_ptr_arg_deref, + clippy::missing_safety_doc, + clippy::new_without_default, + clippy::len_without_is_empty, + clippy::result_unit_err +)] + pub mod any; pub mod collections; pub mod derive; @@ -40,10 +51,10 @@ pub use crate::collections::shape::Shape; pub use crate::collections::tensor::{CPUNDAlloc, NDAllocator, Tensor}; pub use crate::device::{current_stream, with_stream}; pub use crate::dtype::DLDataTypeExt; -pub use crate::error::{Error, ErrorKind, Result}; pub use crate::error::{ ATTRIBUTE_ERROR, INDEX_ERROR, KEY_ERROR, RUNTIME_ERROR, TYPE_ERROR, VALUE_ERROR, }; +pub use crate::error::{Error, ErrorKind, Result}; pub use crate::extra::module::Module; pub use crate::function::Function; pub use crate::object::{Object, ObjectArc, ObjectCore, ObjectCoreWithExtraItems, ObjectRefCore}; diff --git a/rust/tvm-ffi/src/macros.rs b/rust/tvm-ffi/src/macros.rs index 7936f71e7..e2c8eb3d2 100644 --- a/rust/tvm-ffi/src/macros.rs +++ b/rust/tvm-ffi/src/macros.rs @@ -47,6 +47,7 @@ macro_rules! function_name { /// /// # Returns /// * `Result<(), Error>` - The result of the safe call +/// /// Macro to check safe calls and automatically update traceback with file/line info /// /// Usage: check_safe_call!(function(args))?; @@ -102,7 +103,7 @@ macro_rules! bail { macro_rules! ensure { ($cond:expr, $error_kind:expr, $fmt:expr $(, $args:expr)* $(,)?) => {{ if !$cond { - crate::bail!($error_kind, $fmt $(, $args)*); + $crate::bail!($error_kind, $fmt $(, $args)*); } }}; } diff --git a/rust/tvm-ffi/src/object.rs b/rust/tvm-ffi/src/object.rs index dc1970a42..00754aae8 100644 --- a/rust/tvm-ffi/src/object.rs +++ b/rust/tvm-ffi/src/object.rs @@ -22,7 +22,7 @@ use std::sync::atomic::AtomicU64; use crate::derive::ObjectRef; pub use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; /// Object related ABI handling -use tvm_ffi_sys::{TVMFFIObject, COMBINED_REF_COUNT_BOTH_ONE}; +use tvm_ffi_sys::{COMBINED_REF_COUNT_BOTH_ONE, TVMFFIObject}; /// Object type is by default the TVMFFIObject #[repr(C)] @@ -60,7 +60,6 @@ pub unsafe trait ObjectCore: Sized + 'static { /// /// # Returns /// * `&mut TVMFFIObject` - The object header - /// \return The object header unsafe fn object_header_mut(this: &mut Self) -> &mut TVMFFIObject; } @@ -119,7 +118,7 @@ pub(crate) mod unsafe_ { }; use std::ffi::c_void; - use std::sync::atomic::{fence, Ordering}; + use std::sync::atomic::{Ordering, fence}; use tvm_ffi_sys::TVMFFIObject; use tvm_ffi_sys::TVMFFIObjectDeleterFlagBitMask::{ kTVMFFIObjectDeleterFlagBitMaskBoth, kTVMFFIObjectDeleterFlagBitMaskStrong, @@ -307,7 +306,7 @@ impl ObjectArc { ); // move into the object arc ptr Self { - ptr: std::ptr::NonNull::new_unchecked(ptr as *mut T), + ptr: std::ptr::NonNull::new_unchecked(ptr), _phantom: std::marker::PhantomData, } } @@ -345,7 +344,7 @@ impl ObjectArc { ); // move into the object arc ptr Self { - ptr: std::ptr::NonNull::new_unchecked(ptr as *mut T), + ptr: std::ptr::NonNull::new_unchecked(ptr), _phantom: std::marker::PhantomData, } } @@ -358,7 +357,6 @@ impl ObjectArc { /// /// # Returns /// * `ObjectArc` - The ObjectArc - /// \return The ObjectArc #[inline] pub unsafe fn from_raw(ptr: *const T) -> Self { Self { @@ -389,7 +387,6 @@ impl ObjectArc { /// /// # Returns /// * `*const T` - The raw pointer - /// \return The raw pointer #[inline] pub unsafe fn as_raw(this: &Self) -> *const T { this.ptr.as_ptr() as *const T diff --git a/rust/tvm-ffi/src/object_wrapper.rs b/rust/tvm-ffi/src/object_wrapper.rs index 4e86b0813..4533d5bef 100644 --- a/rust/tvm-ffi/src/object_wrapper.rs +++ b/rust/tvm-ffi/src/object_wrapper.rs @@ -231,11 +231,7 @@ fn type_index_for_key(type_key: &'static str) -> Option { let key = unsafe { TVMFFIByteArray::from_str(type_key) }; let mut index = 0i32; let code = unsafe { TVMFFITypeKeyToIndex(&key, &mut index) }; - if code == 0 { - Some(index) - } else { - None - } + if code == 0 { Some(index) } else { None } } unsafe impl AnyCompatible for T { @@ -269,7 +265,7 @@ unsafe impl AnyCompatible for T { } unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { - let ptr = data.data_union.v_obj as *mut TVMFFIObject; + let ptr = data.data_union.v_obj; crate::object::unsafe_::inc_ref(ptr); let arc = ObjectArc::from_raw(ptr as *mut Object); let obj = ::from_data(arc); @@ -277,7 +273,7 @@ unsafe impl AnyCompatible for T { } unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self { - let ptr = data.data_union.v_obj as *mut TVMFFIObject; + let ptr = data.data_union.v_obj; let arc = ObjectArc::from_raw(ptr as *mut Object); data.type_index = crate::TypeIndex::kTVMFFINone as i32; data.data_union.v_int64 = 0; diff --git a/rust/tvm-ffi/src/string.rs b/rust/tvm-ffi/src/string.rs index 94739a4c5..2e7a0b128 100644 --- a/rust/tvm-ffi/src/string.rs +++ b/rust/tvm-ffi/src/string.rs @@ -17,7 +17,7 @@ * under the License. */ use crate::derive::Object; -use crate::object::{unsafe_, Object, ObjectArc, ObjectCoreWithExtraItems}; +use crate::object::{Object, ObjectArc, ObjectCoreWithExtraItems, unsafe_}; use crate::type_traits::AnyCompatible; use std::cmp::Ordering; use std::fmt::{Debug, Display}; @@ -91,7 +91,7 @@ unsafe impl ObjectCoreWithExtraItems for BytesObj { #[inline] /// Get the count of extra items (trailing null byte for FFI compatibility) fn extra_items_count(this: &Self) -> usize { - return this.data.size + 1; + this.data.size + 1 } } @@ -114,7 +114,7 @@ where data: TVMFFIAny { type_index: TypeIndex::kTVMFFISmallBytes as i32, small_str_len: value.len() as u32, - data_union: data_union, + data_union, }, } } else { @@ -186,7 +186,7 @@ impl Eq for Bytes {} impl PartialOrd for Bytes { #[inline] fn partial_cmp(&self, other: &Self) -> Option { - self.as_slice().partial_cmp(other.as_slice()) + Some(self.cmp(other)) } } @@ -244,7 +244,7 @@ unsafe impl ObjectCoreWithExtraItems for StringObj { /// Get the count of extra items (trailing null byte for FFI compatibility) fn extra_items_count(this: &Self) -> usize { // extra item is the trailing \0 for ffi compatibility - return this.data.size + 1; + this.data.size + 1 } } @@ -304,7 +304,7 @@ where data: TVMFFIAny { type_index: TypeIndex::kTVMFFISmallStr as i32, small_str_len: bytes.len() as u32, - data_union: data_union, + data_union, }, } } else { @@ -402,7 +402,7 @@ where impl PartialOrd for String { #[inline] fn partial_cmp(&self, other: &Self) -> Option { - self.as_str().partial_cmp(other.as_str()) + Some(self.cmp(other)) } } @@ -452,8 +452,8 @@ unsafe impl AnyCompatible for Bytes { } unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - return data.type_index == TypeIndex::kTVMFFISmallBytes as i32 - || data.type_index == TypeIndex::kTVMFFIBytes as i32; + data.type_index == TypeIndex::kTVMFFISmallBytes as i32 + || data.type_index == TypeIndex::kTVMFFIBytes as i32 } unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { @@ -500,8 +500,8 @@ unsafe impl AnyCompatible for String { } unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - return data.type_index == TypeIndex::kTVMFFISmallStr as i32 - || data.type_index == TypeIndex::kTVMFFIStr as i32; + data.type_index == TypeIndex::kTVMFFISmallStr as i32 + || data.type_index == TypeIndex::kTVMFFIStr as i32 } unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { diff --git a/rust/tvm-ffi/src/type_traits.rs b/rust/tvm-ffi/src/type_traits.rs index 4b676915e..eb96bc990 100644 --- a/rust/tvm-ffi/src/type_traits.rs +++ b/rust/tvm-ffi/src/type_traits.rs @@ -151,7 +151,7 @@ unsafe impl AnyCompatible for Option { } unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { - if let Some(ref value) = src { + if let Some(value) = src { T::copy_to_any_view(value, data); } else { data.type_index = TypeIndex::kTVMFFINone as i32; @@ -171,7 +171,7 @@ unsafe impl AnyCompatible for Option { } unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - return T::check_any_strict(data) || data.type_index == TypeIndex::kTVMFFINone as i32; + T::check_any_strict(data) || data.type_index == TypeIndex::kTVMFFINone as i32 } unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { @@ -348,16 +348,12 @@ unsafe impl AnyCompatible for () { } unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - return data.type_index == TypeIndex::kTVMFFINone as i32; + data.type_index == TypeIndex::kTVMFFINone as i32 } - unsafe fn copy_from_any_view_after_check(_data: &TVMFFIAny) -> Self { - () - } + unsafe fn copy_from_any_view_after_check(_data: &TVMFFIAny) -> Self {} - unsafe fn move_from_any_after_check(_data: &mut TVMFFIAny) -> Self { - () - } + unsafe fn move_from_any_after_check(_data: &mut TVMFFIAny) -> Self {} unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result { if data.type_index == TypeIndex::kTVMFFINone as i32 { diff --git a/rust/tvm-ffi/tests/test_object.rs b/rust/tvm-ffi/tests/test_object.rs index 60378c2ae..650f0b944 100644 --- a/rust/tvm-ffi/tests/test_object.rs +++ b/rust/tvm-ffi/tests/test_object.rs @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicU32, Ordering}; use tvm_ffi::*; // must have repr(C) for the object header stays in the same position