[naga] Write only the current entrypoint (#7626)

Changes the MSL and HLSL backends to support writing only a single entry
point, and uses them that way in wgpu-hal.

This is working towards a fix for #5885.

* Increase the limit in test_stack_size
This commit is contained in:
Andy Leiserson 2025-04-30 00:59:41 -07:00 committed by GitHub
parent 9fccdf5cf3
commit 850c3d4310
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 141 additions and 45 deletions

View File

@ -79,6 +79,7 @@ Naga now infers the correct binding layout when a resource appears only in an as
#### Naga
- Mark `readonly_and_readwrite_storage_textures` & `packed_4x8_integer_dot_product` language extensions as implemented. By @teoxoy in [#7543](https://github.com/gfx-rs/wgpu/pull/7543)
- `naga::back::hlsl::Writer::new` has a new `pipeline_options` argument. `hlsl::PipelineOptions::default()` can be passed as a default. The `shader_stage` and `entry_point` members of `pipeline_options` can be used to write only a single entry point when using the HLSL and MSL backends (GLSL and SPIR-V already had this functionality). The Metal and DX12 HALs now write only a single entry point when loading shaders. By @andyleiserson in [#7626](https://github.com/gfx-rs/wgpu/pull/7626).
#### D3D12

View File

@ -349,7 +349,9 @@ fn backends(c: &mut Criterion) {
let options = naga::back::hlsl::Options::default();
let mut string = String::new();
for input in &inputs.inner {
let mut writer = naga::back::hlsl::Writer::new(&mut string, &options);
let pipeline_options = Default::default();
let mut writer =
naga::back::hlsl::Writer::new(&mut string, &options, &pipeline_options);
let _ = writer.write(
input.module.as_ref().unwrap(),
input.module_info.as_ref().unwrap(),

View File

@ -824,7 +824,8 @@ fn write_output(
.unwrap_pretty();
let mut buffer = String::new();
let mut writer = hlsl::Writer::new(&mut buffer, &params.hlsl);
let pipeline_options = Default::default();
let mut writer = hlsl::Writer::new(&mut buffer, &params.hlsl, &pipeline_options);
writer.write(&module, &info, None).unwrap_pretty();
fs::write(output_path, buffer)?;
}

View File

@ -349,7 +349,8 @@ pub struct PipelineOptions {
pub shader_stage: ShaderStage,
/// The name of the entry point.
///
/// If no entry point that matches is found while creating a [`Writer`], a error will be thrown.
/// If no entry point that matches is found while creating a [`Writer`], an
/// error will be thrown.
pub entry_point: String,
/// How many views to render to, if doing multiview rendering.
pub multiview: Option<core::num::NonZeroU32>,

View File

@ -119,7 +119,7 @@ use core::fmt::Error as FmtError;
use thiserror::Error;
use crate::{back, proc};
use crate::{back, ir, proc};
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
@ -434,6 +434,22 @@ pub struct ReflectionInfo {
pub entry_point_names: Vec<Result<String, EntryPointError>>,
}
/// A subset of options that are meant to be changed per pipeline.
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(feature = "deserialize", serde(default))]
pub struct PipelineOptions {
/// The entry point to write.
///
/// Entry points are identified by a shader stage specification,
/// and a name.
///
/// If `None`, all entry points will be written. If `Some` and the entry
/// point is not found, an error will be thrown while writing.
pub entry_point: Option<(ir::ShaderStage, String)>,
}
#[derive(Error, Debug)]
pub enum Error {
#[error(transparent)]
@ -448,6 +464,8 @@ pub enum Error {
Override,
#[error(transparent)]
ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
#[error("entry point with stage {0:?} and name '{1}' not found")]
EntryPointNotFound(ir::ShaderStage, String),
}
#[derive(PartialEq, Eq, Hash)]
@ -519,8 +537,10 @@ pub struct Writer<'a, W> {
namer: proc::Namer,
/// HLSL backend options
options: &'a Options,
/// Per-stage backend options
pipeline_options: &'a PipelineOptions,
/// Information about entry point arguments and result types.
entry_point_io: Vec<writer::EntryPointInterface>,
entry_point_io: crate::FastHashMap<usize, writer::EntryPointInterface>,
/// Set of expressions that have associated temporary variables
named_expressions: crate::NamedExpressions,
wrapped: Wrapped,

View File

@ -12,10 +12,10 @@ use super::{
WrappedZeroValue,
},
storage::StoreValue,
BackendResult, Error, FragmentEntryPoint, Options, ShaderModel,
BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
};
use crate::{
back::{self, Baked},
back::{self, get_entry_points, Baked},
common,
proc::{self, index, NameKey},
valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
@ -123,13 +123,14 @@ struct BindingArraySamplerInfo {
}
impl<'a, W: fmt::Write> super::Writer<'a, W> {
pub fn new(out: W, options: &'a Options) -> Self {
pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
Self {
out,
names: crate::FastHashMap::default(),
namer: proc::Namer::default(),
options,
entry_point_io: Vec::new(),
pipeline_options,
entry_point_io: crate::FastHashMap::default(),
named_expressions: crate::NamedExpressions::default(),
wrapped: super::Wrapped::default(),
written_committed_intersection: false,
@ -387,8 +388,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out)?;
}
let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
.map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
// Write all entry points wrapped structs
for (index, ep) in module.entry_points.iter().enumerate() {
for index in ep_range.clone() {
let ep = &module.entry_points[index];
let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
let ep_io = self.write_ep_interface(
module,
@ -397,7 +402,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
&ep_name,
fragment_entry_point,
)?;
self.entry_point_io.push(ep_io);
self.entry_point_io.insert(index, ep_io);
}
// Write all regular functions
@ -442,10 +447,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out)?;
}
let mut entry_point_names = Vec::with_capacity(module.entry_points.len());
let mut translated_ep_names = Vec::with_capacity(ep_range.len());
// Write all entry points
for (index, ep) in module.entry_points.iter().enumerate() {
for index in ep_range {
let ep = &module.entry_points[index];
let info = module_info.get_entry_point(index);
if !self.options.fake_missing_bindings {
@ -462,7 +468,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
if let Some(err) = ep_error {
entry_point_names.push(Err(err));
translated_ep_names.push(Err(err));
continue;
}
}
@ -493,10 +499,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out)?;
}
entry_point_names.push(Ok(name));
translated_ep_names.push(Ok(name));
}
Ok(super::ReflectionInfo { entry_point_names })
Ok(super::ReflectionInfo {
entry_point_names: translated_ep_names,
})
}
fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
@ -816,7 +824,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
ep_index: u16,
) -> BackendResult {
let ep = &module.entry_points[ep_index as usize];
let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
let ep_input = match self
.entry_point_io
.get_mut(&(ep_index as usize))
.unwrap()
.input
.take()
{
Some(ep_input) => ep_input,
None => return Ok(()),
};
@ -1432,7 +1446,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
back::FunctionType::EntryPoint(index) => {
if let Some(ref ep_output) = self.entry_point_io[index as usize].output {
if let Some(ref ep_output) =
self.entry_point_io.get(&(index as usize)).unwrap().output
{
write!(self.out, "{}", ep_output.ty_name)?;
} else {
self.write_type(module, result.ty)?;
@ -1479,7 +1495,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
back::FunctionType::EntryPoint(ep_index) => {
if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
if let Some(ref ep_input) =
self.entry_point_io.get(&(ep_index as usize)).unwrap().input
{
write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
} else {
let stage = module.entry_points[ep_index as usize].stage;
@ -1501,7 +1519,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
if need_workgroup_variables_initialization {
if self.entry_point_io[ep_index as usize].input.is_some()
if self
.entry_point_io
.get(&(ep_index as usize))
.unwrap()
.input
.is_some()
|| !func.arguments.is_empty()
{
write!(self.out, ", ")?;
@ -1870,9 +1893,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// for entry point returns, we may need to reshuffle the outputs into a different struct
let ep_output = match func_ctx.ty {
back::FunctionType::Function(_) => None,
back::FunctionType::EntryPoint(index) => {
self.entry_point_io[index as usize].output.as_ref()
}
back::FunctionType::EntryPoint(index) => self
.entry_point_io
.get(&(index as usize))
.unwrap()
.output
.as_ref(),
};
let final_name = match ep_output {
Some(ep_output) => {

View File

@ -79,6 +79,33 @@ impl core::fmt::Display for Level {
}
}
/// Locate the entry point(s) to write.
///
/// If `entry_point` is given, and the specified entry point exists, returns a
/// length-1 `Range` containing the index of that entry point. If no
/// `entry_point` is given, returns the complete range of entry point indices.
/// If `entry_point` is given but does not exist, returns an error.
#[cfg(any(hlsl_out, msl_out))]
fn get_entry_points(
module: &crate::ir::Module,
entry_point: Option<&(crate::ir::ShaderStage, String)>,
) -> Result<core::ops::Range<usize>, (crate::ir::ShaderStage, String)> {
use alloc::borrow::ToOwned;
if let Some(&(stage, ref name)) = entry_point {
let Some(ep_index) = module
.entry_points
.iter()
.position(|ep| ep.stage == stage && ep.name == *name)
else {
return Err((stage, name.to_owned()));
};
Ok(ep_index..ep_index + 1)
} else {
Ok(0..module.entry_points.len())
}
}
/// Whether we're generating an entry point or a regular function.
///
/// Backend languages often require different code for a [`Function`]

View File

@ -52,7 +52,7 @@ use alloc::{
};
use core::fmt::{Error as FmtError, Write};
use crate::{arena::Handle, proc::index, valid::ModuleInfo};
use crate::{arena::Handle, ir, proc::index, valid::ModuleInfo};
mod keywords;
pub mod sampler;
@ -184,7 +184,7 @@ pub enum Error {
#[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
UnsupportedWriteableStorageBuffer,
#[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
UnsupportedWriteableStorageTexture(crate::ShaderStage),
UnsupportedWriteableStorageTexture(ir::ShaderStage),
#[error("can not use read-write storage textures prior to MSL 1.2")]
UnsupportedRWStorageTexture,
#[error("array of '{0}' is not supported for target MSL version")]
@ -199,6 +199,8 @@ pub enum Error {
UnsupportedBitCast(crate::TypeInner),
#[error(transparent)]
ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError),
#[error("entry point with stage {0:?} and name '{1}' not found")]
EntryPointNotFound(ir::ShaderStage, String),
}
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
@ -420,6 +422,15 @@ pub struct VertexBufferMapping {
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(feature = "deserialize", serde(default))]
pub struct PipelineOptions {
/// The entry point to write.
///
/// Entry points are identified by a shader stage specification,
/// and a name.
///
/// If `None`, all entry points will be written. If `Some` and the entry
/// point is not found, an error will be thrown while writing.
pub entry_point: Option<(ir::ShaderStage, String)>,
/// Allow `BuiltIn::PointSize` and inject it if doesn't exist.
///
/// Metal doesn't like this for non-point primitive topologies and requires it for
@ -737,5 +748,5 @@ pub fn write_string(
#[test]
fn test_error_size() {
assert_eq!(size_of::<Error>(), 32);
assert_eq!(size_of::<Error>(), 40);
}

View File

@ -16,7 +16,7 @@ use half::f16;
use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
use crate::{
arena::{Handle, HandleSet},
back::{self, Baked},
back::{self, get_entry_points, Baked},
common,
proc::{
self,
@ -5872,10 +5872,15 @@ template <typename A>
self.named_expressions.clear();
}
let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
.map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
let mut info = TranslationInfo {
entry_point_names: Vec::with_capacity(module.entry_points.len()),
entry_point_names: Vec::with_capacity(ep_range.len()),
};
for (ep_index, ep) in module.entry_points.iter().enumerate() {
for ep_index in ep_range {
let ep = &module.entry_points[ep_index];
let fun = &ep.function;
let fun_info = mod_info.get_entry_point(ep_index);
let mut ep_error = None;
@ -7076,8 +7081,8 @@ fn test_stack_size() {
}
let stack_size = addresses_end - addresses_start;
// check the size (in debug only)
// last observed macOS value: 20528 (CI)
if !(11000..=25000).contains(&stack_size) {
// last observed macOS value: 25904 (CI), 2025-04-29
if !(11000..=27000).contains(&stack_size) {
panic!("`put_expression` stack size {stack_size} has changed!");
}
}

View File

@ -741,7 +741,8 @@ fn write_output_hlsl(
.expect("override evaluation failed");
let mut buffer = String::new();
let mut writer = hlsl::Writer::new(&mut buffer, options);
let pipeline_options = Default::default();
let mut writer = hlsl::Writer::new(&mut buffer, options, &pipeline_options);
let reflection_info = writer
.write(&module, &info, frag_ep.as_ref())
.expect("HLSL write failed");

View File

@ -299,9 +299,13 @@ impl super::Device {
&layout.naga_options
};
let pipeline_options = hlsl::PipelineOptions {
entry_point: Some((naga_stage, stage.entry_point.to_string())),
};
//TODO: reuse the writer
let mut source = String::new();
let mut writer = hlsl::Writer::new(&mut source, naga_options);
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
let reflection_info = {
profiling::scope!("naga::back::hlsl::write");
writer
@ -315,13 +319,7 @@ impl super::Device {
naga_options.shader_model.to_str()
);
let ep_index = module
.entry_points
.iter()
.position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point)
.ok_or(crate::PipelineError::EntryPoint(naga_stage))?;
let raw_ep = reflection_info.entry_point_names[ep_index]
let raw_ep = reflection_info.entry_point_names[0]
.as_ref()
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;

View File

@ -181,6 +181,7 @@ impl super::Device {
};
let pipeline_options = naga::back::msl::PipelineOptions {
entry_point: Some((naga_stage, stage.entry_point.to_owned())),
allow_and_force_point_size: match primitive_class {
MTLPrimitiveTopologyClass::Point => true,
_ => false,
@ -223,7 +224,7 @@ impl super::Device {
.position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point)
.ok_or(crate::PipelineError::EntryPoint(naga_stage))?;
let ep = &module.entry_points[ep_index];
let ep_name = info.entry_point_names[ep_index]
let translated_ep_name = info.entry_point_names[0]
.as_ref()
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{}", e)))?;
@ -233,10 +234,12 @@ impl super::Device {
depth: ep.workgroup_size[2] as _,
};
let function = library.get_function(ep_name, None).map_err(|e| {
log::error!("get_function: {:?}", e);
crate::PipelineError::EntryPoint(naga_stage)
})?;
let function = library
.get_function(translated_ep_name, None)
.map_err(|e| {
log::error!("get_function: {:?}", e);
crate::PipelineError::EntryPoint(naga_stage)
})?;
// collect sizes indices, immutable buffers, and work group memory sizes
let ep_info = &module_info.get_entry_point(ep_index);