Skip to content

Commit 68a1e91

Browse files
committed
wgsl: add spirv-unknown-naga-wgsl target, transpiling with naga 27
1 parent d637f7c commit 68a1e91

File tree

6 files changed

+210
-4
lines changed

6 files changed

+210
-4
lines changed

Cargo.lock

Lines changed: 46 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/rustc_codegen_spirv/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ itertools = "0.14.0"
6161
tracing.workspace = true
6262
tracing-subscriber.workspace = true
6363
tracing-tree = "0.4.0"
64+
naga = { version = "27.0.3", features = ["spv-in", "wgsl-out"] }
65+
strum = { version = "0.27.2", features = ["derive"] }
6466

6567
[dev-dependencies]
6668
pretty_assertions = "1.0"

crates/rustc_codegen_spirv/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ mod custom_decorations;
133133
mod custom_insts;
134134
mod link;
135135
mod linker;
136+
mod naga_transpile;
136137
mod spirv_type;
137138
mod spirv_type_constraints;
138139
mod symbols;

crates/rustc_codegen_spirv/src/link.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
33

44
use crate::codegen_cx::{CodegenArgs, SpirvMetadata};
5+
use crate::naga_transpile::should_transpile;
56
use crate::target::{SpirvTarget, SpirvTargetVariant};
67
use crate::{SpirvCodegenBackend, SpirvModuleBuffer, linker};
78
use ar::{Archive, GnuBuilder, Header};
@@ -323,6 +324,10 @@ fn post_link_single_module(
323324

324325
drop(save_modules_timer);
325326
}
327+
328+
if let Some(transpile) = should_transpile(sess) {
329+
transpile(sess, cg_args, &spv_binary, out_filename).ok();
330+
}
326331
}
327332

328333
fn do_spirv_opt(
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
use crate::codegen_cx::CodegenArgs;
2+
use crate::target::{NagaTarget, SpirvTarget};
3+
use rustc_session::Session;
4+
use rustc_span::ErrorGuaranteed;
5+
use std::path::Path;
6+
7+
pub type NagaTranspile = fn(
8+
sess: &Session,
9+
cg_args: &CodegenArgs,
10+
spv_binary: &[u32],
11+
out_filename: &Path,
12+
) -> Result<(), ErrorGuaranteed>;
13+
14+
pub fn should_transpile(sess: &Session) -> Option<NagaTranspile> {
15+
let target = SpirvTarget::parse_target(sess.opts.target_triple.tuple())
16+
.expect("parsing should fail earlier");
17+
match target {
18+
SpirvTarget::Naga(NagaTarget::NAGA_WGSL) => Some(transpile::wgsl_transpile),
19+
_ => None,
20+
}
21+
}
22+
23+
mod transpile {
24+
use crate::codegen_cx::CodegenArgs;
25+
use naga::error::ShaderError;
26+
use naga::valid::Capabilities;
27+
use rustc_session::Session;
28+
use rustc_span::ErrorGuaranteed;
29+
use std::path::Path;
30+
31+
pub fn wgsl_transpile(
32+
sess: &Session,
33+
_cg_args: &CodegenArgs,
34+
spv_binary: &[u32],
35+
out_filename: &Path,
36+
) -> Result<(), ErrorGuaranteed> {
37+
// these should be params via spirv-builder
38+
let opts = naga::front::spv::Options::default();
39+
let capabilities = Capabilities::default();
40+
let writer_flags = naga::back::wgsl::WriterFlags::empty();
41+
42+
let module = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(spv_binary), &opts)
43+
.map_err(|err| {
44+
sess.dcx().err(format!(
45+
"Naga failed to parse spv: \n{}",
46+
ShaderError {
47+
source: String::new(),
48+
label: None,
49+
inner: Box::new(err),
50+
}
51+
))
52+
})?;
53+
let mut validator =
54+
naga::valid::Validator::new(naga::valid::ValidationFlags::default(), capabilities);
55+
let info = validator.validate(&module).map_err(|err| {
56+
sess.dcx().err(format!(
57+
"Naga validation failed: \n{}",
58+
ShaderError {
59+
source: String::new(),
60+
label: None,
61+
inner: Box::new(err),
62+
}
63+
))
64+
})?;
65+
66+
let wgsl_dst = out_filename.with_extension("wgsl");
67+
let wgsl = naga::back::wgsl::write_string(&module, &info, writer_flags).map_err(|err| {
68+
sess.dcx()
69+
.err(format!("Naga failed to write wgsl : \n{err}"))
70+
})?;
71+
72+
std::fs::write(&wgsl_dst, wgsl).map_err(|err| {
73+
sess.dcx()
74+
.err(format!("failed to write wgsl to file: {err}"))
75+
})?;
76+
77+
Ok(())
78+
}
79+
}

crates/rustc_codegen_spirv/src/target.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,16 @@ use std::cmp::Ordering;
55
use std::fmt::{Debug, Display, Formatter};
66
use std::ops::{Deref, DerefMut};
77
use std::str::FromStr;
8+
use strum::{Display, EnumString, IntoStaticStr};
89

910
#[derive(Clone, Eq, PartialEq)]
1011
pub enum TargetError {
12+
/// If during parsing a target variant returns `UnknownTarget`, further variants will attempt to parse the string.
13+
/// Returning another error means that you have recognized the target but something else is invalid, and we should
14+
/// abort the parsing with your error.
1115
UnknownTarget(String),
1216
InvalidTargetVersion(SpirvTarget),
17+
InvalidNagaVariant(String),
1318
}
1419

1520
impl Display for TargetError {
@@ -21,6 +26,9 @@ impl Display for TargetError {
2126
TargetError::InvalidTargetVersion(target) => {
2227
write!(f, "Invalid version in target `{}`", target.env())
2328
}
29+
TargetError::InvalidNagaVariant(target) => {
30+
write!(f, "Unknown naga out variant `{target}`")
31+
}
2432
}
2533
}
2634
}
@@ -439,13 +447,71 @@ impl Display for OpenGLTarget {
439447
}
440448
}
441449

450+
/// A naga target
451+
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
452+
pub struct NagaTarget {
453+
pub out: NagaOut,
454+
}
455+
456+
#[derive(Copy, Clone, Debug, Eq, PartialEq, IntoStaticStr, Display, EnumString)]
457+
#[allow(clippy::upper_case_acronyms)]
458+
pub enum NagaOut {
459+
#[strum(to_string = "wgsl")]
460+
WGSL,
461+
}
462+
463+
impl NagaTarget {
464+
pub const NAGA_WGSL: Self = NagaTarget::new(NagaOut::WGSL);
465+
pub const ALL_NAGA_TARGETS: &'static [Self] = &[Self::NAGA_WGSL];
466+
/// emit spirv like naga targets were this target
467+
pub const EMIT_SPIRV_LIKE: SpirvTarget = SpirvTarget::VULKAN_1_3;
468+
469+
pub const fn new(out: NagaOut) -> Self {
470+
Self { out }
471+
}
472+
}
473+
474+
impl SpirvTargetVariant for NagaTarget {
475+
fn validate(&self) -> Result<(), TargetError> {
476+
Ok(())
477+
}
478+
479+
fn to_spirv_tools(&self) -> spirv_tools::TargetEnv {
480+
Self::EMIT_SPIRV_LIKE.to_spirv_tools()
481+
}
482+
483+
fn spirv_version(&self) -> SpirvVersion {
484+
Self::EMIT_SPIRV_LIKE.spirv_version()
485+
}
486+
}
487+
488+
impl FromStr for NagaTarget {
489+
type Err = TargetError;
490+
491+
fn from_str(s: &str) -> Result<Self, Self::Err> {
492+
let s = s
493+
.strip_prefix("naga-")
494+
.ok_or_else(|| TargetError::UnknownTarget(s.to_owned()))?;
495+
Ok(Self::new(FromStr::from_str(s).map_err(|_e| {
496+
TargetError::InvalidNagaVariant(s.to_owned())
497+
})?))
498+
}
499+
}
500+
501+
impl Display for NagaTarget {
502+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
503+
write!(f, "naga-{}", self.out)
504+
}
505+
}
506+
442507
/// A rust-gpu target
443508
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
444509
#[non_exhaustive]
445510
pub enum SpirvTarget {
446511
Universal(UniversalTarget),
447512
Vulkan(VulkanTarget),
448513
OpenGL(OpenGLTarget),
514+
Naga(NagaTarget),
449515
}
450516

451517
impl SpirvTarget {
@@ -467,12 +533,15 @@ impl SpirvTarget {
467533
pub const OPENGL_4_2: Self = Self::OpenGL(OpenGLTarget::OPENGL_4_2);
468534
pub const OPENGL_4_3: Self = Self::OpenGL(OpenGLTarget::OPENGL_4_3);
469535
pub const OPENGL_4_5: Self = Self::OpenGL(OpenGLTarget::OPENGL_4_5);
536+
pub const NAGA_WGSL: Self = Self::Naga(NagaTarget::NAGA_WGSL);
470537

538+
#[allow(clippy::match_same_arms)]
471539
pub const fn memory_model(&self) -> MemoryModel {
472540
match self {
473541
SpirvTarget::Universal(_) => MemoryModel::Simple,
474542
SpirvTarget::Vulkan(_) => MemoryModel::Vulkan,
475543
SpirvTarget::OpenGL(_) => MemoryModel::GLSL450,
544+
SpirvTarget::Naga(_) => MemoryModel::Vulkan,
476545
}
477546
}
478547
}
@@ -483,6 +552,7 @@ impl SpirvTargetVariant for SpirvTarget {
483552
SpirvTarget::Universal(t) => t.validate(),
484553
SpirvTarget::Vulkan(t) => t.validate(),
485554
SpirvTarget::OpenGL(t) => t.validate(),
555+
SpirvTarget::Naga(t) => t.validate(),
486556
}
487557
}
488558

@@ -491,6 +561,7 @@ impl SpirvTargetVariant for SpirvTarget {
491561
SpirvTarget::Universal(t) => t.to_spirv_tools(),
492562
SpirvTarget::Vulkan(t) => t.to_spirv_tools(),
493563
SpirvTarget::OpenGL(t) => t.to_spirv_tools(),
564+
SpirvTarget::Naga(t) => t.to_spirv_tools(),
494565
}
495566
}
496567

@@ -499,6 +570,7 @@ impl SpirvTargetVariant for SpirvTarget {
499570
SpirvTarget::Universal(t) => t.spirv_version(),
500571
SpirvTarget::Vulkan(t) => t.spirv_version(),
501572
SpirvTarget::OpenGL(t) => t.spirv_version(),
573+
SpirvTarget::Naga(t) => t.spirv_version(),
502574
}
503575
}
504576
}
@@ -513,6 +585,9 @@ impl SpirvTarget {
513585
if matches!(result, Err(TargetError::UnknownTarget(..))) {
514586
result = OpenGLTarget::from_str(s).map(Self::OpenGL);
515587
}
588+
if matches!(result, Err(TargetError::UnknownTarget(..))) {
589+
result = NagaTarget::from_str(s).map(Self::Naga);
590+
}
516591
result
517592
}
518593

@@ -533,6 +608,7 @@ impl SpirvTarget {
533608
SpirvTarget::Universal(t) => t.to_string(),
534609
SpirvTarget::Vulkan(t) => t.to_string(),
535610
SpirvTarget::OpenGL(t) => t.to_string(),
611+
SpirvTarget::Naga(t) => t.to_string(),
536612
}
537613
}
538614

@@ -555,6 +631,7 @@ impl SpirvTarget {
555631
.iter()
556632
.map(|t| Self::OpenGL(*t)),
557633
)
634+
.chain(NagaTarget::ALL_NAGA_TARGETS.iter().map(|t| Self::Naga(*t)))
558635
}
559636
}
560637

0 commit comments

Comments
 (0)