Commit c86d18b

Add sam3-image model (#164)
1 parent 8d81ad3 commit c86d18b

12 files changed: +749 / -6 lines changed

Cargo.toml

Lines changed: 8 additions & 2 deletions
```diff
@@ -1,7 +1,7 @@
 [package]
 name = "usls"
 edition = "2021"
-version = "0.1.10"
+version = "0.1.11"
 rust-version = "1.85"
 description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
 repository = "https://github.com/jamjamjon/usls"
@@ -131,7 +131,8 @@ rfdetr = []
 rtdetr = []
 rtmo = []
 rtmpose = []
-sam = []
+sam = [] # SAM, SAM2
+sam3 = ["tokenizers"] # SAM3
 slanet = ["pipeline"]
 smolvlm = ["tokenizers"]
 sapiens = []
@@ -148,6 +149,7 @@ all-models = [
     "yolo",
     "yoloe",
     "sam",
+    "sam3",
     "clip",
     "apisr",
     "image-classifier",
@@ -330,6 +332,10 @@ required-features = ["sam"]
 name = "sam2"
 required-features = ["sam"]
 
+[[example]]
+name = "sam3"
+required-features = ["sam3"]
+
 [[example]]
 name = "sapiens"
 required-features = ["sapiens"]
```

README.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -134,7 +134,8 @@ usls = { version = "latest-version", features = [ "cuda" ] }
 | [RTMW](https://arxiv.org/abs/2407.08634) | Keypoint Detection | `rtmpose` | [demo](examples/rtmw) |
 | [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | `rtmo` | [demo](examples/rtmo) |
 | [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | `sam` | [demo](examples/sam) |
-| [SAM2](https://github.com/facebookresearch/segment-anything-2) | Segment Anything | `sam` | [demo](examples/sam) |
+| [SAM2](https://github.com/facebookresearch/segment-anything-2) | Segment Anything | `sam2` | [demo](examples/sam2) |
+| [SAM3](https://github.com/facebookresearch/segment-anything-3) | Segment Anything | `sam3` | [demo](examples/sam3) |
 | [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) | Segment Anything | `sam` | [demo](examples/sam) |
 | [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) | Segment Anything | `sam` | [demo](examples/sam) |
 | [SAM-HQ](https://github.com/SysCV/sam-hq) | Segment Anything | `sam` | [demo](examples/sam) |
```

assets/000000136466.jpg

98.2 KB

assets/sam3-demo.jpg

69 KB

examples/sam3/README.md

Lines changed: 40 additions & 0 deletions
New file:

````markdown
### Quick Start

```bash
# Text prompt
cargo run -r -F sam3 -F cuda --example sam3 -- --device cuda --dtype q4f16 --source ./assets/sam3-demo.jpg -p shoe
cargo run -r -F sam3 -F cuda --example sam3 -- --device cuda --dtype bnb4 --source ./assets/sam3-demo.jpg -p "person in red vest"
cargo run -r -F sam3 -F cuda --example sam3 -- --device cuda --dtype q8 --source ./assets/sam3-demo.jpg -p "boy in blue vest"

# Visual prompt: a single bbox
cargo run -r -F sam3 -F cuda --example sam3 -- --device cuda --source ./assets/sam3-demo.jpg -p "visual;pos:480,290,110,360"

# Visual prompt: multiple boxes (positive and negative)
cargo run -r -F sam3 -F cuda --example sam3 -- --device cuda --source ./assets/sam3-demo.jpg -p "visual;pos:480,290,110,360;neg:370,280,115,375"

# Text prompt, with and without a negative box
cargo run -r -F sam3 -F cuda --example sam3 -- --device cuda --dtype fp16 --source ./assets/000000136466.jpg -p "handle"
cargo run -r -F sam3 -F cuda --example sam3 -- --device cuda --dtype fp16 --source ./assets/000000136466.jpg -p "handle;neg:40,183,278,21"

# Multiple prompts (queries) over multiple sources
cargo run -r -F sam3 -F cuda --example sam3 -- --device cuda --dtype fp16 --source ./assets/sam3-demo.jpg --source ./assets/bus.jpg -p shoe -p face -p person
```

### Prompt Format

```
"text;pos:x,y,w,h;neg:x,y,w,h"
```

- `text`: Text description
- `pos:x,y,w,h`: Positive box (find similar regions)
- `neg:x,y,w,h`: Negative box (exclude region)

### Results

![](https://github.com/jamjamjon/assets/releases/download/sam3/demo.jpg)
![](https://github.com/jamjamjon/assets/releases/download/sam3/demo2.jpg)
````
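The prompt strings documented above are parsed into `Sam3Prompt` values via `FromStr`; that is exactly what `examples/sam3/main.rs` does with `s.parse()`. A minimal sketch of building prompts programmatically from that grammar — the concrete error type of the `FromStr` impl isn't shown in this commit, so it is stringified here the same way the example does:

```rust
use usls::models::Sam3Prompt;

fn build_prompts() -> anyhow::Result<Vec<Sam3Prompt>> {
    // One prompt of each documented flavor:
    // text-only, visual (positive + negative boxes), and text + negative box.
    [
        "shoe",
        "visual;pos:480,290,110,360;neg:370,280,115,375",
        "handle;neg:40,183,278,21",
    ]
    .iter()
    .map(|s| s.parse::<Sam3Prompt>())
    .collect::<Result<Vec<_>, _>>()
    // The FromStr error type isn't known here; surface it as a string, as main.rs does.
    .map_err(|e| anyhow::anyhow!("{}", e))
}
```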

examples/sam3/main.rs

Lines changed: 110 additions & 0 deletions
New file:

```rust
use anyhow::Result;
use usls::{
    models::{Sam3Prompt, SAM3},
    Annotator, Config, DataLoader,
};

#[derive(argh::FromArgs)]
/// SAM3 - Segment Anything Model 3
struct Args {
    /// device (cpu:0, cuda:0, etc.)
    #[argh(option, default = "String::from(\"cpu:0\")")]
    device: String,

    /// source image paths (can specify multiple)
    #[argh(
        option,
        default = "vec![
            String::from(\"./assets/sam3-demo.jpg\"),
            // String::from(\"./assets/bus.jpg\")
        ]"
    )]
    source: Vec<String>,

    /// prompts: "text;pos:x,y,w,h;neg:x,y,w,h" (can specify multiple)
    #[argh(option, short = 'p')]
    prompt: Vec<String>,

    /// confidence threshold (default: 0.5)
    #[argh(option, default = "0.5")]
    conf: f32,

    /// batch size min (default: 1)
    #[argh(option, default = "1")]
    batch_min: usize,

    /// batch size (default: 1)
    #[argh(option, default = "1")]
    batch: usize,

    /// batch size max (default: 4)
    #[argh(option, default = "4")]
    batch_max: usize,

    /// dtype
    #[argh(option, default = "String::from(\"q4f16\")")]
    dtype: String,

    /// show mask
    #[argh(switch)]
    show_mask: bool,
}

fn main() -> Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
        .init();

    let args: Args = argh::from_env();

    // Parse prompts
    if args.prompt.is_empty() {
        anyhow::bail!("No prompt. Use -p \"text\" or -p \"visual;pos:x,y,w,h\"");
    }
    let prompts: Vec<Sam3Prompt> = args
        .prompt
        .iter()
        .map(|s| s.parse())
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(|e| anyhow::anyhow!("{}", e))?;

    // Build model
    let config = Config::sam3_image_predictor()
        .with_batch_size_all_min_opt_max(args.batch_min, args.batch, args.batch_max)
        .with_device_all(args.device.parse()?)
        .with_dtype_all(args.dtype.parse()?)
        .with_class_confs(&[args.conf])
        .with_num_dry_run_all(1)
        .commit()?;
    let mut model = SAM3::new(config)?;

    // Annotator
    let annotator = Annotator::default().with_mask_style(
        usls::Style::mask()
            .with_draw_mask_polygon_largest(true)
            .with_visible(args.show_mask),
    );
    let output_dir = usls::Dir::Current.base_dir_with_subs(&["runs", model.spec()])?;

    // DataLoader with batch iteration
    let dataloader = DataLoader::from_paths(&args.source)?
        .with_batch(args.batch)
        .with_progress_bar(true)
        .build()?;

    // Process in batches
    for batch in dataloader {
        let ys = model.forward(&batch, &prompts)?;
        println!("ys: {:?}", ys);

        for (img, y) in batch.iter().zip(ys.iter()) {
            annotator
                .annotate(img, y)?
                .save(output_dir.join(format!("{}.jpg", usls::timestamp(None))))?;
        }
    }

    usls::perf(false);
    Ok(())
}
```

src/core/dataloader.rs

Lines changed: 19 additions & 0 deletions
```diff
@@ -96,6 +96,25 @@ impl FromStr for DataLoader {
 }
 
 impl DataLoader {
+    /// Create DataLoader from multiple paths
+    pub fn from_paths<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
+        let paths: VecDeque<PathBuf> = paths.iter().map(|p| p.as_ref().to_path_buf()).collect();
+        let nf = paths.len() as u64;
+
+        if paths.is_empty() {
+            anyhow::bail!("No paths provided");
+        }
+
+        info!("Found {:?} x{}", MediaType::Image(Location::Local), nf);
+
+        Ok(Self {
+            paths: Some(paths),
+            media_type: MediaType::Image(Location::Local),
+            nf,
+            ..Default::default()
+        })
+    }
+
     pub fn new(source: &str) -> Result<Self> {
         // paths & media_type
         let (paths, media_type) = Self::try_load_all(source)?;
```
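For context, a short usage sketch of the new constructor as the sam3 example drives it; the image paths are placeholders, and `with_batch`, `with_progress_bar`, and `build` come from the existing builder:

```rust
use usls::DataLoader;

fn iterate_images() -> anyhow::Result<()> {
    // Placeholder paths; from_paths() treats every entry as a local image,
    // unlike DataLoader::new(), which probes the source string for its media type.
    let paths = ["./assets/sam3-demo.jpg", "./assets/bus.jpg"];

    let dataloader = DataLoader::from_paths(&paths)?
        .with_batch(2)
        .with_progress_bar(true)
        .build()?;

    for batch in dataloader {
        println!("batch of {} image(s)", batch.iter().count());
    }
    Ok(())
}
```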

src/models/mod.rs

Lines changed: 9 additions & 3 deletions
```diff
@@ -136,6 +136,9 @@ mod yoloe;
 #[cfg(feature = "sam")]
 mod sam2;
 
+#[cfg(feature = "sam3")]
+mod sam3;
+
 #[cfg(feature = "rtdetr")]
 mod d_fine;
 
@@ -160,6 +163,12 @@ pub use yolo::*;
 #[cfg(feature = "sam")]
 pub use sam::*;
 
+#[cfg(feature = "sam")]
+pub use sam2::*;
+
+#[cfg(feature = "sam3")]
+pub use sam3::*;
+
 #[cfg(feature = "clip")]
 pub use clip::*;
 
@@ -244,9 +253,6 @@ pub use yolop::*;
 #[cfg(feature = "yoloe")]
 pub use yoloe::*;
 
-#[cfg(feature = "sam")]
-pub use sam2::*;
-
 #[cfg(feature = "rtdetr")]
 pub use d_fine::*;
 
```

src/models/sam3/README.md

Lines changed: 11 additions & 0 deletions
New file:

```markdown
# SAM3: Segment Anything with Concepts

A powerful multimodal segmentation model supporting text, bounding box, and combined prompts.

## References

- Official: [facebookresearch/sam3](https://github.com/facebookresearch/sam3)

## Example

See [examples/sam3](../../../examples/sam3)
```

src/models/sam3/config.rs

Lines changed: 40 additions & 0 deletions
New file:

```rust
use crate::Config;

/// Model configuration for `SAM3`
impl Config {
    /// SAM3 base configuration
    ///
    /// - Input size: 1008x1008 (FitExact, no aspect ratio preserved)
    /// - Normalization: mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
    /// - Tokenizer: CLIP BPE (max_length=32)
    /// - Confidence threshold: 0.5
    pub fn sam3() -> Self {
        Self::default()
            .with_name("sam3")
            .with_batch_size_all_min_opt_max(1, 1, 4)
            .with_visual_encoder_ixx(0, 1, 3.into())
            .with_visual_encoder_ixx(0, 2, 1008.into())
            .with_visual_encoder_ixx(0, 3, 1008.into())
            .with_textual_encoder_ixx(0, 1, 32.into())
            .with_resize_mode(crate::ResizeMode::FitExact)
            .with_resize_filter("Bilinear")
            .with_image_mean(&[0.5, 0.5, 0.5])
            .with_image_std(&[0.5, 0.5, 0.5])
            .with_normalize(true)
            .with_find_contours(true)
            .with_class_confs(&[0.5])
            .with_model_max_length(32) // CLIP max length, enables auto padding/truncation
            .with_tokenizer_file("sam3/tokenizer.json")
            .with_tokenizer_config_file("sam3/tokenizer_config.json")
            .with_special_tokens_map_file("sam3/special_tokens_map.json")
            .with_config_file("sam3/config.json")
    }

    pub fn sam3_image_predictor() -> Self {
        Self::sam3()
            .with_visual_encoder_file("vision-encoder.onnx")
            .with_textual_encoder_file("text-encoder.onnx")
            .with_encoder_file("geometry-encoder.onnx")
            .with_decoder_file("decoder.onnx")
    }
}
```
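These defaults are meant to be overridden per run before `commit()`. A condensed sketch of how the image-predictor config is consumed, mirroring `examples/sam3/main.rs`; the CUDA device, `q4f16` dtype, and the 0.4 threshold are illustrative choices, not requirements:

```rust
use usls::{models::SAM3, Config};

fn build_sam3() -> anyhow::Result<SAM3> {
    // Start from the sam3 defaults above, then override device/dtype/threshold.
    let config = Config::sam3_image_predictor()
        .with_device_all("cuda:0".parse()?) // or "cpu:0"
        .with_dtype_all("q4f16".parse()?)   // or "fp16", "q8", "bnb4"
        .with_class_confs(&[0.4])           // override the 0.5 default
        .commit()?;
    Ok(SAM3::new(config)?)
}
```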
