feat!: make ProgrammableStage::entry_point optional in wgpu-core

This commit is contained in:
Erich Gubler 2024-01-29 21:41:55 -05:00
parent 2c66504a59
commit 023b0e063f
11 changed files with 104 additions and 26 deletions

View File

@ -102,6 +102,7 @@ Bottom level categories:
``` ```
- `wgpu::Id` now implements `PartialOrd`/`Ord` allowing it to be put in `BTreeMap`s. By @cwfitzgerald and @9291Sam in [#5176](https://github.com/gfx-rs/wgpu/pull/5176) - `wgpu::Id` now implements `PartialOrd`/`Ord` allowing it to be put in `BTreeMap`s. By @cwfitzgerald and @9291Sam in [#5176](https://github.com/gfx-rs/wgpu/pull/5176)
- `wgpu::CommandEncoder::write_timestamp` requires now the new `wgpu::Features::TIMESTAMP_QUERY_INSIDE_ENCODERS` feature which is available on all native backends but not on WebGPU (due to a spec change `write_timestamp` is no longer supported on WebGPU). By @wumpf in [#5188](https://github.com/gfx-rs/wgpu/pull/5188) - `wgpu::CommandEncoder::write_timestamp` requires now the new `wgpu::Features::TIMESTAMP_QUERY_INSIDE_ENCODERS` feature which is available on all native backends but not on WebGPU (due to a spec change `write_timestamp` is no longer supported on WebGPU). By @wumpf in [#5188](https://github.com/gfx-rs/wgpu/pull/5188)
- Breaking change: [`wgpu_core::pipeline::ProgrammableStageDescriptor`](https://docs.rs/wgpu-core/latest/wgpu_core/pipeline/struct.ProgrammableStageDescriptor.html#structfield.entry_point) is now optional. By @ErichDonGubler in [#5305](https://github.com/gfx-rs/wgpu/pull/5305).
#### GLES #### GLES

View File

@ -110,7 +110,7 @@ pub fn op_webgpu_create_compute_pipeline(
layout: pipeline_layout, layout: pipeline_layout,
stage: wgpu_core::pipeline::ProgrammableStageDescriptor { stage: wgpu_core::pipeline::ProgrammableStageDescriptor {
module: compute_shader_module_resource.1, module: compute_shader_module_resource.1,
entry_point: Cow::from(compute.entry_point), entry_point: Some(Cow::from(compute.entry_point)),
// TODO(lucacasonato): support args.compute.constants // TODO(lucacasonato): support args.compute.constants
}, },
}; };
@ -355,7 +355,7 @@ pub fn op_webgpu_create_render_pipeline(
Some(wgpu_core::pipeline::FragmentState { Some(wgpu_core::pipeline::FragmentState {
stage: wgpu_core::pipeline::ProgrammableStageDescriptor { stage: wgpu_core::pipeline::ProgrammableStageDescriptor {
module: fragment_shader_module_resource.1, module: fragment_shader_module_resource.1,
entry_point: Cow::from(fragment.entry_point), entry_point: Some(Cow::from(fragment.entry_point)),
}, },
targets: Cow::from(fragment.targets), targets: Cow::from(fragment.targets),
}) })
@ -377,7 +377,7 @@ pub fn op_webgpu_create_render_pipeline(
vertex: wgpu_core::pipeline::VertexState { vertex: wgpu_core::pipeline::VertexState {
stage: wgpu_core::pipeline::ProgrammableStageDescriptor { stage: wgpu_core::pipeline::ProgrammableStageDescriptor {
module: vertex_shader_module_resource.1, module: vertex_shader_module_resource.1,
entry_point: Cow::Owned(args.vertex.entry_point), entry_point: Some(Cow::Owned(args.vertex.entry_point)),
}, },
buffers: Cow::Owned(vertex_buffers), buffers: Cow::Owned(vertex_buffers),
}, },

View File

@ -56,7 +56,7 @@
layout: Some(Id(0, 1, Empty)), layout: Some(Id(0, 1, Empty)),
stage: ( stage: (
module: Id(0, 1, Empty), module: Id(0, 1, Empty),
entry_point: "main", entry_point: Some("main"),
), ),
), ),
), ),

View File

@ -29,7 +29,7 @@
layout: Some(Id(0, 1, Empty)), layout: Some(Id(0, 1, Empty)),
stage: ( stage: (
module: Id(0, 1, Empty), module: Id(0, 1, Empty),
entry_point: "main", entry_point: Some("main"),
), ),
), ),
), ),

View File

@ -57,14 +57,14 @@
vertex: ( vertex: (
stage: ( stage: (
module: Id(0, 1, Empty), module: Id(0, 1, Empty),
entry_point: "vs_main", entry_point: Some("vs_main"),
), ),
buffers: [], buffers: [],
), ),
fragment: Some(( fragment: Some((
stage: ( stage: (
module: Id(0, 1, Empty), module: Id(0, 1, Empty),
entry_point: "fs_main", entry_point: Some("fs_main"),
), ),
targets: [ targets: [
Some(( Some((

View File

@ -133,7 +133,7 @@
layout: Some(Id(0, 1, Empty)), layout: Some(Id(0, 1, Empty)),
stage: ( stage: (
module: Id(0, 1, Empty), module: Id(0, 1, Empty),
entry_point: "main", entry_point: Some("main"),
), ),
), ),
), ),

View File

@ -134,7 +134,7 @@
layout: Some(Id(0, 1, Empty)), layout: Some(Id(0, 1, Empty)),
stage: ( stage: (
module: Id(0, 1, Empty), module: Id(0, 1, Empty),
entry_point: "main", entry_point: Some("main"),
), ),
), ),
), ),

View File

@ -2705,14 +2705,21 @@ impl<A: HalApi> Device<A> {
let mut shader_binding_sizes = FastHashMap::default(); let mut shader_binding_sizes = FastHashMap::default();
let io = validation::StageIo::default(); let io = validation::StageIo::default();
let final_entry_point_name;
{ {
let stage = wgt::ShaderStages::COMPUTE; let stage = wgt::ShaderStages::COMPUTE;
final_entry_point_name = shader_module.finalize_entry_point_name(
stage,
desc.stage.entry_point.as_ref().map(|ep| ep.as_ref()),
)?;
if let Some(ref interface) = shader_module.interface { if let Some(ref interface) = shader_module.interface {
let _ = interface.check_stage( let _ = interface.check_stage(
&mut binding_layout_source, &mut binding_layout_source,
&mut shader_binding_sizes, &mut shader_binding_sizes,
&desc.stage.entry_point, &final_entry_point_name,
stage, stage,
io, io,
None, None,
@ -2740,7 +2747,7 @@ impl<A: HalApi> Device<A> {
label: desc.label.to_hal(self.instance_flags), label: desc.label.to_hal(self.instance_flags),
layout: pipeline_layout.raw(), layout: pipeline_layout.raw(),
stage: hal::ProgrammableStage { stage: hal::ProgrammableStage {
entry_point: desc.stage.entry_point.as_ref(), entry_point: final_entry_point_name.as_ref(),
module: shader_module.raw(), module: shader_module.raw(),
}, },
}; };
@ -3115,6 +3122,7 @@ impl<A: HalApi> Device<A> {
}; };
let vertex_shader_module; let vertex_shader_module;
let vertex_entry_point_name;
let vertex_stage = { let vertex_stage = {
let stage_desc = &desc.vertex.stage; let stage_desc = &desc.vertex.stage;
let stage = wgt::ShaderStages::VERTEX; let stage = wgt::ShaderStages::VERTEX;
@ -3131,12 +3139,19 @@ impl<A: HalApi> Device<A> {
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error };
vertex_entry_point_name = vertex_shader_module
.finalize_entry_point_name(
stage,
stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()),
)
.map_err(stage_err)?;
if let Some(ref interface) = vertex_shader_module.interface { if let Some(ref interface) = vertex_shader_module.interface {
io = interface io = interface
.check_stage( .check_stage(
&mut binding_layout_source, &mut binding_layout_source,
&mut shader_binding_sizes, &mut shader_binding_sizes,
&stage_desc.entry_point, &vertex_entry_point_name,
stage, stage,
io, io,
desc.depth_stencil.as_ref().map(|d| d.depth_compare), desc.depth_stencil.as_ref().map(|d| d.depth_compare),
@ -3147,11 +3162,12 @@ impl<A: HalApi> Device<A> {
hal::ProgrammableStage { hal::ProgrammableStage {
module: vertex_shader_module.raw(), module: vertex_shader_module.raw(),
entry_point: stage_desc.entry_point.as_ref(), entry_point: &vertex_entry_point_name,
} }
}; };
let mut fragment_shader_module = None; let mut fragment_shader_module = None;
let fragment_entry_point_name;
let fragment_stage = match desc.fragment { let fragment_stage = match desc.fragment {
Some(ref fragment_state) => { Some(ref fragment_state) => {
let stage = wgt::ShaderStages::FRAGMENT; let stage = wgt::ShaderStages::FRAGMENT;
@ -3167,13 +3183,24 @@ impl<A: HalApi> Device<A> {
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error };
fragment_entry_point_name = shader_module
.finalize_entry_point_name(
stage,
fragment_state
.stage
.entry_point
.as_ref()
.map(|ep| ep.as_ref()),
)
.map_err(stage_err)?;
if validated_stages == wgt::ShaderStages::VERTEX { if validated_stages == wgt::ShaderStages::VERTEX {
if let Some(ref interface) = shader_module.interface { if let Some(ref interface) = shader_module.interface {
io = interface io = interface
.check_stage( .check_stage(
&mut binding_layout_source, &mut binding_layout_source,
&mut shader_binding_sizes, &mut shader_binding_sizes,
&fragment_state.stage.entry_point, &fragment_entry_point_name,
stage, stage,
io, io,
desc.depth_stencil.as_ref().map(|d| d.depth_compare), desc.depth_stencil.as_ref().map(|d| d.depth_compare),
@ -3185,7 +3212,7 @@ impl<A: HalApi> Device<A> {
if let Some(ref interface) = shader_module.interface { if let Some(ref interface) = shader_module.interface {
shader_expects_dual_source_blending = interface shader_expects_dual_source_blending = interface
.fragment_uses_dual_source_blending(&fragment_state.stage.entry_point) .fragment_uses_dual_source_blending(&fragment_entry_point_name)
.map_err(|error| pipeline::CreateRenderPipelineError::Stage { .map_err(|error| pipeline::CreateRenderPipelineError::Stage {
stage, stage,
error, error,
@ -3194,7 +3221,7 @@ impl<A: HalApi> Device<A> {
Some(hal::ProgrammableStage { Some(hal::ProgrammableStage {
module: shader_module.raw(), module: shader_module.raw(),
entry_point: fragment_state.stage.entry_point.as_ref(), entry_point: &fragment_entry_point_name,
}) })
} }
None => None, None => None,

View File

@ -92,6 +92,19 @@ impl<A: HalApi> ShaderModule<A> {
pub(crate) fn raw(&self) -> &A::ShaderModule { pub(crate) fn raw(&self) -> &A::ShaderModule {
self.raw.as_ref().unwrap() self.raw.as_ref().unwrap()
} }
pub(crate) fn finalize_entry_point_name(
&self,
stage_bit: wgt::ShaderStages,
entry_point: Option<&str>,
) -> Result<String, validation::StageError> {
match &self.interface {
Some(interface) => interface.finalize_entry_point_name(stage_bit, entry_point),
None => entry_point
.map(|ep| ep.to_string())
.ok_or(validation::StageError::NoEntryPointFound),
}
}
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -213,9 +226,13 @@ impl CreateShaderModuleError {
pub struct ProgrammableStageDescriptor<'a> { pub struct ProgrammableStageDescriptor<'a> {
/// The compiled shader module for this stage. /// The compiled shader module for this stage.
pub module: ShaderModuleId, pub module: ShaderModuleId,
/// The name of the entry point in the compiled shader. There must be a function with this name /// The name of the entry point in the compiled shader. The name is selected using the
/// in the shader. /// following logic:
pub entry_point: Cow<'a, str>, ///
/// * If `Some(name)` is specified, there must be a function with this name in the shader.
/// * If a single entry point associated with this stage must be in the shader, then proceed as
/// if `Some(…)` was specified with that entry point's name.
pub entry_point: Option<Cow<'a, str>>,
} }
/// Number of implicit bind groups derived at pipeline creation. /// Number of implicit bind groups derived at pipeline creation.

View File

@ -283,6 +283,16 @@ pub enum StageError {
}, },
#[error("Location[{location}] is provided by the previous stage output but is not consumed as input by this stage.")] #[error("Location[{location}] is provided by the previous stage output but is not consumed as input by this stage.")]
InputNotConsumed { location: wgt::ShaderLocation }, InputNotConsumed { location: wgt::ShaderLocation },
#[error(
"Unable to select an entry point: no entry point was found in the provided shader module"
)]
NoEntryPointFound,
#[error(
"Unable to select an entry point: \
multiple entry points were found in the provided shader module, \
but no entry point was specified"
)]
MultipleEntryPointsFound,
} }
fn map_storage_format_to_naga(format: wgt::TextureFormat) -> Option<naga::StorageFormat> { fn map_storage_format_to_naga(format: wgt::TextureFormat) -> Option<naga::StorageFormat> {
@ -971,6 +981,28 @@ impl Interface {
} }
} }
pub fn finalize_entry_point_name(
&self,
stage_bit: wgt::ShaderStages,
entry_point_name: Option<&str>,
) -> Result<String, StageError> {
let stage = Self::shader_stage_from_stage_bit(stage_bit);
entry_point_name
.map(|ep| ep.to_string())
.map(Ok)
.unwrap_or_else(|| {
let mut entry_points = self
.entry_points
.keys()
.filter_map(|(ep_stage, name)| (ep_stage == &stage).then_some(name));
let first = entry_points.next().ok_or(StageError::NoEntryPointFound)?;
if entry_points.next().is_some() {
return Err(StageError::MultipleEntryPointsFound);
}
Ok(first.clone())
})
}
pub(crate) fn shader_stage_from_stage_bit(stage_bit: wgt::ShaderStages) -> naga::ShaderStage { pub(crate) fn shader_stage_from_stage_bit(stage_bit: wgt::ShaderStages) -> naga::ShaderStage {
match stage_bit { match stage_bit {
wgt::ShaderStages::VERTEX => naga::ShaderStage::Vertex, wgt::ShaderStages::VERTEX => naga::ShaderStage::Vertex,
@ -993,10 +1025,11 @@ impl Interface {
// we need to look for one with the right execution model. // we need to look for one with the right execution model.
let shader_stage = Self::shader_stage_from_stage_bit(stage_bit); let shader_stage = Self::shader_stage_from_stage_bit(stage_bit);
let pair = (shader_stage, entry_point_name.to_string()); let pair = (shader_stage, entry_point_name.to_string());
let entry_point = self let entry_point = match self.entry_points.get(&pair) {
.entry_points Some(some) => some,
.get(&pair) None => return Err(StageError::MissingEntryPoint(pair.1)),
.ok_or(StageError::MissingEntryPoint(pair.1))?; };
let (_stage, entry_point_name) = pair;
// check resources visibility // check resources visibility
for &handle in entry_point.resources.iter() { for &handle in entry_point.resources.iter() {

View File

@ -1102,7 +1102,7 @@ impl crate::Context for ContextWgpuCore {
vertex: pipe::VertexState { vertex: pipe::VertexState {
stage: pipe::ProgrammableStageDescriptor { stage: pipe::ProgrammableStageDescriptor {
module: desc.vertex.module.id.into(), module: desc.vertex.module.id.into(),
entry_point: Borrowed(desc.vertex.entry_point), entry_point: Some(Borrowed(desc.vertex.entry_point)),
}, },
buffers: Borrowed(&vertex_buffers), buffers: Borrowed(&vertex_buffers),
}, },
@ -1112,7 +1112,7 @@ impl crate::Context for ContextWgpuCore {
fragment: desc.fragment.as_ref().map(|frag| pipe::FragmentState { fragment: desc.fragment.as_ref().map(|frag| pipe::FragmentState {
stage: pipe::ProgrammableStageDescriptor { stage: pipe::ProgrammableStageDescriptor {
module: frag.module.id.into(), module: frag.module.id.into(),
entry_point: Borrowed(frag.entry_point), entry_point: Some(Borrowed(frag.entry_point)),
}, },
targets: Borrowed(frag.targets), targets: Borrowed(frag.targets),
}), }),
@ -1160,7 +1160,7 @@ impl crate::Context for ContextWgpuCore {
layout: desc.layout.map(|l| l.id.into()), layout: desc.layout.map(|l| l.id.into()),
stage: pipe::ProgrammableStageDescriptor { stage: pipe::ProgrammableStageDescriptor {
module: desc.module.id.into(), module: desc.module.id.into(),
entry_point: Borrowed(desc.entry_point), entry_point: Some(Borrowed(desc.entry_point)),
}, },
}; };