mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-12-08 21:26:17 +00:00
[naga] Write only the current entrypoint (#7626)
Changes the MSL and HLSL backends to support writing only a single entry point, and uses them that way in wgpu-hal. This is working towards a fix for #5885. * Increase the limit in test_stack_size
This commit is contained in:
parent
9fccdf5cf3
commit
850c3d4310
@ -79,6 +79,7 @@ Naga now infers the correct binding layout when a resource appears only in an as
|
||||
#### Naga
|
||||
|
||||
- Mark `readonly_and_readwrite_storage_textures` & `packed_4x8_integer_dot_product` language extensions as implemented. By @teoxoy in [#7543](https://github.com/gfx-rs/wgpu/pull/7543)
|
||||
- `naga::back::hlsl::Writer::new` has a new `pipeline_options` argument. `hlsl::PipelineOptions::default()` can be passed as a default. The `shader_stage` and `entry_point` members of `pipeline_options` can be used to write only a single entry point when using the HLSL and MSL backends (GLSL and SPIR-V already had this functionality). The Metal and DX12 HALs now write only a single entry point when loading shaders. By @andyleiserson in [#7626](https://github.com/gfx-rs/wgpu/pull/7626).
|
||||
|
||||
#### D3D12
|
||||
|
||||
|
||||
@ -349,7 +349,9 @@ fn backends(c: &mut Criterion) {
|
||||
let options = naga::back::hlsl::Options::default();
|
||||
let mut string = String::new();
|
||||
for input in &inputs.inner {
|
||||
let mut writer = naga::back::hlsl::Writer::new(&mut string, &options);
|
||||
let pipeline_options = Default::default();
|
||||
let mut writer =
|
||||
naga::back::hlsl::Writer::new(&mut string, &options, &pipeline_options);
|
||||
let _ = writer.write(
|
||||
input.module.as_ref().unwrap(),
|
||||
input.module_info.as_ref().unwrap(),
|
||||
|
||||
@ -824,7 +824,8 @@ fn write_output(
|
||||
.unwrap_pretty();
|
||||
|
||||
let mut buffer = String::new();
|
||||
let mut writer = hlsl::Writer::new(&mut buffer, ¶ms.hlsl);
|
||||
let pipeline_options = Default::default();
|
||||
let mut writer = hlsl::Writer::new(&mut buffer, ¶ms.hlsl, &pipeline_options);
|
||||
writer.write(&module, &info, None).unwrap_pretty();
|
||||
fs::write(output_path, buffer)?;
|
||||
}
|
||||
|
||||
@ -349,7 +349,8 @@ pub struct PipelineOptions {
|
||||
pub shader_stage: ShaderStage,
|
||||
/// The name of the entry point.
|
||||
///
|
||||
/// If no entry point that matches is found while creating a [`Writer`], a error will be thrown.
|
||||
/// If no entry point that matches is found while creating a [`Writer`], an
|
||||
/// error will be thrown.
|
||||
pub entry_point: String,
|
||||
/// How many views to render to, if doing multiview rendering.
|
||||
pub multiview: Option<core::num::NonZeroU32>,
|
||||
|
||||
@ -119,7 +119,7 @@ use core::fmt::Error as FmtError;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{back, proc};
|
||||
use crate::{back, ir, proc};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
|
||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||
@ -434,6 +434,22 @@ pub struct ReflectionInfo {
|
||||
pub entry_point_names: Vec<Result<String, EntryPointError>>,
|
||||
}
|
||||
|
||||
/// A subset of options that are meant to be changed per pipeline.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
|
||||
#[cfg_attr(feature = "deserialize", serde(default))]
|
||||
pub struct PipelineOptions {
|
||||
/// The entry point to write.
|
||||
///
|
||||
/// Entry points are identified by a shader stage specification,
|
||||
/// and a name.
|
||||
///
|
||||
/// If `None`, all entry points will be written. If `Some` and the entry
|
||||
/// point is not found, an error will be thrown while writing.
|
||||
pub entry_point: Option<(ir::ShaderStage, String)>,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error(transparent)]
|
||||
@ -448,6 +464,8 @@ pub enum Error {
|
||||
Override,
|
||||
#[error(transparent)]
|
||||
ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
|
||||
#[error("entry point with stage {0:?} and name '{1}' not found")]
|
||||
EntryPointNotFound(ir::ShaderStage, String),
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash)]
|
||||
@ -519,8 +537,10 @@ pub struct Writer<'a, W> {
|
||||
namer: proc::Namer,
|
||||
/// HLSL backend options
|
||||
options: &'a Options,
|
||||
/// Per-stage backend options
|
||||
pipeline_options: &'a PipelineOptions,
|
||||
/// Information about entry point arguments and result types.
|
||||
entry_point_io: Vec<writer::EntryPointInterface>,
|
||||
entry_point_io: crate::FastHashMap<usize, writer::EntryPointInterface>,
|
||||
/// Set of expressions that have associated temporary variables
|
||||
named_expressions: crate::NamedExpressions,
|
||||
wrapped: Wrapped,
|
||||
|
||||
@ -12,10 +12,10 @@ use super::{
|
||||
WrappedZeroValue,
|
||||
},
|
||||
storage::StoreValue,
|
||||
BackendResult, Error, FragmentEntryPoint, Options, ShaderModel,
|
||||
BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
|
||||
};
|
||||
use crate::{
|
||||
back::{self, Baked},
|
||||
back::{self, get_entry_points, Baked},
|
||||
common,
|
||||
proc::{self, index, NameKey},
|
||||
valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
|
||||
@ -123,13 +123,14 @@ struct BindingArraySamplerInfo {
|
||||
}
|
||||
|
||||
impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
pub fn new(out: W, options: &'a Options) -> Self {
|
||||
pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
|
||||
Self {
|
||||
out,
|
||||
names: crate::FastHashMap::default(),
|
||||
namer: proc::Namer::default(),
|
||||
options,
|
||||
entry_point_io: Vec::new(),
|
||||
pipeline_options,
|
||||
entry_point_io: crate::FastHashMap::default(),
|
||||
named_expressions: crate::NamedExpressions::default(),
|
||||
wrapped: super::Wrapped::default(),
|
||||
written_committed_intersection: false,
|
||||
@ -387,8 +388,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
writeln!(self.out)?;
|
||||
}
|
||||
|
||||
let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
|
||||
.map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
|
||||
|
||||
// Write all entry points wrapped structs
|
||||
for (index, ep) in module.entry_points.iter().enumerate() {
|
||||
for index in ep_range.clone() {
|
||||
let ep = &module.entry_points[index];
|
||||
let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
|
||||
let ep_io = self.write_ep_interface(
|
||||
module,
|
||||
@ -397,7 +402,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
&ep_name,
|
||||
fragment_entry_point,
|
||||
)?;
|
||||
self.entry_point_io.push(ep_io);
|
||||
self.entry_point_io.insert(index, ep_io);
|
||||
}
|
||||
|
||||
// Write all regular functions
|
||||
@ -442,10 +447,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
writeln!(self.out)?;
|
||||
}
|
||||
|
||||
let mut entry_point_names = Vec::with_capacity(module.entry_points.len());
|
||||
let mut translated_ep_names = Vec::with_capacity(ep_range.len());
|
||||
|
||||
// Write all entry points
|
||||
for (index, ep) in module.entry_points.iter().enumerate() {
|
||||
for index in ep_range {
|
||||
let ep = &module.entry_points[index];
|
||||
let info = module_info.get_entry_point(index);
|
||||
|
||||
if !self.options.fake_missing_bindings {
|
||||
@ -462,7 +468,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
}
|
||||
}
|
||||
if let Some(err) = ep_error {
|
||||
entry_point_names.push(Err(err));
|
||||
translated_ep_names.push(Err(err));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@ -493,10 +499,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
writeln!(self.out)?;
|
||||
}
|
||||
|
||||
entry_point_names.push(Ok(name));
|
||||
translated_ep_names.push(Ok(name));
|
||||
}
|
||||
|
||||
Ok(super::ReflectionInfo { entry_point_names })
|
||||
Ok(super::ReflectionInfo {
|
||||
entry_point_names: translated_ep_names,
|
||||
})
|
||||
}
|
||||
|
||||
fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
|
||||
@ -816,7 +824,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
ep_index: u16,
|
||||
) -> BackendResult {
|
||||
let ep = &module.entry_points[ep_index as usize];
|
||||
let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
|
||||
let ep_input = match self
|
||||
.entry_point_io
|
||||
.get_mut(&(ep_index as usize))
|
||||
.unwrap()
|
||||
.input
|
||||
.take()
|
||||
{
|
||||
Some(ep_input) => ep_input,
|
||||
None => return Ok(()),
|
||||
};
|
||||
@ -1432,7 +1446,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
}
|
||||
}
|
||||
back::FunctionType::EntryPoint(index) => {
|
||||
if let Some(ref ep_output) = self.entry_point_io[index as usize].output {
|
||||
if let Some(ref ep_output) =
|
||||
self.entry_point_io.get(&(index as usize)).unwrap().output
|
||||
{
|
||||
write!(self.out, "{}", ep_output.ty_name)?;
|
||||
} else {
|
||||
self.write_type(module, result.ty)?;
|
||||
@ -1479,7 +1495,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
}
|
||||
}
|
||||
back::FunctionType::EntryPoint(ep_index) => {
|
||||
if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
|
||||
if let Some(ref ep_input) =
|
||||
self.entry_point_io.get(&(ep_index as usize)).unwrap().input
|
||||
{
|
||||
write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
|
||||
} else {
|
||||
let stage = module.entry_points[ep_index as usize].stage;
|
||||
@ -1501,7 +1519,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
}
|
||||
}
|
||||
if need_workgroup_variables_initialization {
|
||||
if self.entry_point_io[ep_index as usize].input.is_some()
|
||||
if self
|
||||
.entry_point_io
|
||||
.get(&(ep_index as usize))
|
||||
.unwrap()
|
||||
.input
|
||||
.is_some()
|
||||
|| !func.arguments.is_empty()
|
||||
{
|
||||
write!(self.out, ", ")?;
|
||||
@ -1870,9 +1893,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
// for entry point returns, we may need to reshuffle the outputs into a different struct
|
||||
let ep_output = match func_ctx.ty {
|
||||
back::FunctionType::Function(_) => None,
|
||||
back::FunctionType::EntryPoint(index) => {
|
||||
self.entry_point_io[index as usize].output.as_ref()
|
||||
}
|
||||
back::FunctionType::EntryPoint(index) => self
|
||||
.entry_point_io
|
||||
.get(&(index as usize))
|
||||
.unwrap()
|
||||
.output
|
||||
.as_ref(),
|
||||
};
|
||||
let final_name = match ep_output {
|
||||
Some(ep_output) => {
|
||||
|
||||
@ -79,6 +79,33 @@ impl core::fmt::Display for Level {
|
||||
}
|
||||
}
|
||||
|
||||
/// Locate the entry point(s) to write.
|
||||
///
|
||||
/// If `entry_point` is given, and the specified entry point exists, returns a
|
||||
/// length-1 `Range` containing the index of that entry point. If no
|
||||
/// `entry_point` is given, returns the complete range of entry point indices.
|
||||
/// If `entry_point` is given but does not exist, returns an error.
|
||||
#[cfg(any(hlsl_out, msl_out))]
|
||||
fn get_entry_points(
|
||||
module: &crate::ir::Module,
|
||||
entry_point: Option<&(crate::ir::ShaderStage, String)>,
|
||||
) -> Result<core::ops::Range<usize>, (crate::ir::ShaderStage, String)> {
|
||||
use alloc::borrow::ToOwned;
|
||||
|
||||
if let Some(&(stage, ref name)) = entry_point {
|
||||
let Some(ep_index) = module
|
||||
.entry_points
|
||||
.iter()
|
||||
.position(|ep| ep.stage == stage && ep.name == *name)
|
||||
else {
|
||||
return Err((stage, name.to_owned()));
|
||||
};
|
||||
Ok(ep_index..ep_index + 1)
|
||||
} else {
|
||||
Ok(0..module.entry_points.len())
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether we're generating an entry point or a regular function.
|
||||
///
|
||||
/// Backend languages often require different code for a [`Function`]
|
||||
|
||||
@ -52,7 +52,7 @@ use alloc::{
|
||||
};
|
||||
use core::fmt::{Error as FmtError, Write};
|
||||
|
||||
use crate::{arena::Handle, proc::index, valid::ModuleInfo};
|
||||
use crate::{arena::Handle, ir, proc::index, valid::ModuleInfo};
|
||||
|
||||
mod keywords;
|
||||
pub mod sampler;
|
||||
@ -184,7 +184,7 @@ pub enum Error {
|
||||
#[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
|
||||
UnsupportedWriteableStorageBuffer,
|
||||
#[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
|
||||
UnsupportedWriteableStorageTexture(crate::ShaderStage),
|
||||
UnsupportedWriteableStorageTexture(ir::ShaderStage),
|
||||
#[error("can not use read-write storage textures prior to MSL 1.2")]
|
||||
UnsupportedRWStorageTexture,
|
||||
#[error("array of '{0}' is not supported for target MSL version")]
|
||||
@ -199,6 +199,8 @@ pub enum Error {
|
||||
UnsupportedBitCast(crate::TypeInner),
|
||||
#[error(transparent)]
|
||||
ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError),
|
||||
#[error("entry point with stage {0:?} and name '{1}' not found")]
|
||||
EntryPointNotFound(ir::ShaderStage, String),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
|
||||
@ -420,6 +422,15 @@ pub struct VertexBufferMapping {
|
||||
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
|
||||
#[cfg_attr(feature = "deserialize", serde(default))]
|
||||
pub struct PipelineOptions {
|
||||
/// The entry point to write.
|
||||
///
|
||||
/// Entry points are identified by a shader stage specification,
|
||||
/// and a name.
|
||||
///
|
||||
/// If `None`, all entry points will be written. If `Some` and the entry
|
||||
/// point is not found, an error will be thrown while writing.
|
||||
pub entry_point: Option<(ir::ShaderStage, String)>,
|
||||
|
||||
/// Allow `BuiltIn::PointSize` and inject it if doesn't exist.
|
||||
///
|
||||
/// Metal doesn't like this for non-point primitive topologies and requires it for
|
||||
@ -737,5 +748,5 @@ pub fn write_string(
|
||||
|
||||
#[test]
|
||||
fn test_error_size() {
|
||||
assert_eq!(size_of::<Error>(), 32);
|
||||
assert_eq!(size_of::<Error>(), 40);
|
||||
}
|
||||
|
||||
@ -16,7 +16,7 @@ use half::f16;
|
||||
use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
|
||||
use crate::{
|
||||
arena::{Handle, HandleSet},
|
||||
back::{self, Baked},
|
||||
back::{self, get_entry_points, Baked},
|
||||
common,
|
||||
proc::{
|
||||
self,
|
||||
@ -5872,10 +5872,15 @@ template <typename A>
|
||||
self.named_expressions.clear();
|
||||
}
|
||||
|
||||
let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
|
||||
.map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
|
||||
|
||||
let mut info = TranslationInfo {
|
||||
entry_point_names: Vec::with_capacity(module.entry_points.len()),
|
||||
entry_point_names: Vec::with_capacity(ep_range.len()),
|
||||
};
|
||||
for (ep_index, ep) in module.entry_points.iter().enumerate() {
|
||||
|
||||
for ep_index in ep_range {
|
||||
let ep = &module.entry_points[ep_index];
|
||||
let fun = &ep.function;
|
||||
let fun_info = mod_info.get_entry_point(ep_index);
|
||||
let mut ep_error = None;
|
||||
@ -7076,8 +7081,8 @@ fn test_stack_size() {
|
||||
}
|
||||
let stack_size = addresses_end - addresses_start;
|
||||
// check the size (in debug only)
|
||||
// last observed macOS value: 20528 (CI)
|
||||
if !(11000..=25000).contains(&stack_size) {
|
||||
// last observed macOS value: 25904 (CI), 2025-04-29
|
||||
if !(11000..=27000).contains(&stack_size) {
|
||||
panic!("`put_expression` stack size {stack_size} has changed!");
|
||||
}
|
||||
}
|
||||
|
||||
@ -741,7 +741,8 @@ fn write_output_hlsl(
|
||||
.expect("override evaluation failed");
|
||||
|
||||
let mut buffer = String::new();
|
||||
let mut writer = hlsl::Writer::new(&mut buffer, options);
|
||||
let pipeline_options = Default::default();
|
||||
let mut writer = hlsl::Writer::new(&mut buffer, options, &pipeline_options);
|
||||
let reflection_info = writer
|
||||
.write(&module, &info, frag_ep.as_ref())
|
||||
.expect("HLSL write failed");
|
||||
|
||||
@ -299,9 +299,13 @@ impl super::Device {
|
||||
&layout.naga_options
|
||||
};
|
||||
|
||||
let pipeline_options = hlsl::PipelineOptions {
|
||||
entry_point: Some((naga_stage, stage.entry_point.to_string())),
|
||||
};
|
||||
|
||||
//TODO: reuse the writer
|
||||
let mut source = String::new();
|
||||
let mut writer = hlsl::Writer::new(&mut source, naga_options);
|
||||
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
|
||||
let reflection_info = {
|
||||
profiling::scope!("naga::back::hlsl::write");
|
||||
writer
|
||||
@ -315,13 +319,7 @@ impl super::Device {
|
||||
naga_options.shader_model.to_str()
|
||||
);
|
||||
|
||||
let ep_index = module
|
||||
.entry_points
|
||||
.iter()
|
||||
.position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point)
|
||||
.ok_or(crate::PipelineError::EntryPoint(naga_stage))?;
|
||||
|
||||
let raw_ep = reflection_info.entry_point_names[ep_index]
|
||||
let raw_ep = reflection_info.entry_point_names[0]
|
||||
.as_ref()
|
||||
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
|
||||
|
||||
|
||||
@ -181,6 +181,7 @@ impl super::Device {
|
||||
};
|
||||
|
||||
let pipeline_options = naga::back::msl::PipelineOptions {
|
||||
entry_point: Some((naga_stage, stage.entry_point.to_owned())),
|
||||
allow_and_force_point_size: match primitive_class {
|
||||
MTLPrimitiveTopologyClass::Point => true,
|
||||
_ => false,
|
||||
@ -223,7 +224,7 @@ impl super::Device {
|
||||
.position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point)
|
||||
.ok_or(crate::PipelineError::EntryPoint(naga_stage))?;
|
||||
let ep = &module.entry_points[ep_index];
|
||||
let ep_name = info.entry_point_names[ep_index]
|
||||
let translated_ep_name = info.entry_point_names[0]
|
||||
.as_ref()
|
||||
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{}", e)))?;
|
||||
|
||||
@ -233,10 +234,12 @@ impl super::Device {
|
||||
depth: ep.workgroup_size[2] as _,
|
||||
};
|
||||
|
||||
let function = library.get_function(ep_name, None).map_err(|e| {
|
||||
log::error!("get_function: {:?}", e);
|
||||
crate::PipelineError::EntryPoint(naga_stage)
|
||||
})?;
|
||||
let function = library
|
||||
.get_function(translated_ep_name, None)
|
||||
.map_err(|e| {
|
||||
log::error!("get_function: {:?}", e);
|
||||
crate::PipelineError::EntryPoint(naga_stage)
|
||||
})?;
|
||||
|
||||
// collect sizes indices, immutable buffers, and work group memory sizes
|
||||
let ep_info = &module_info.get_entry_point(ep_index);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user