Add mesh shading api to wgpu & wgpu-core (#7345)

This commit is contained in:
SupaMaggie70Incorporated 2025-07-24 19:58:56 -05:00 committed by GitHub
parent 68a10a01d9
commit 074c0e7191
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 2208 additions and 673 deletions

View File

@ -157,6 +157,7 @@ By @Vecvec in [#7829](https://github.com/gfx-rs/wgpu/pull/7829).
- Add acceleration structure limits. By @Vecvec in [#7845](https://github.com/gfx-rs/wgpu/pull/7845).
- Add support for clip-distances feature for Vulkan and GL backends. By @dzamkov in [#7730](https://github.com/gfx-rs/wgpu/pull/7730)
- Added `wgpu_types::error::{ErrorType, WebGpuError}` for classification of errors according to WebGPU's [`GPUError`]'s classification scheme, and implement `WebGpuError` for existing errors. This allows users of `wgpu-core` to offload error classification onto the WGPU ecosystem, rather than having to do it themselves without sufficient information. By @ErichDonGubler in [#6547](https://github.com/gfx-rs/wgpu/pull/6547).
- Added mesh shader support to `wgpu`, with examples. Requires passthrough. By @SupaMaggie70Incorporated in [#7345](https://github.com/gfx-rs/wgpu/pull/7345).
[`GPUError`]: https://www.w3.org/TR/webgpu/#gpuerror

View File

@ -34,7 +34,7 @@ impl Display for JsError {
impl Debug for JsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self)
write!(f, "{self}")
}
}

View File

@ -43,6 +43,7 @@ These examples use a common framework to handle wgpu init, window creation, and
- `ray_cube_fragment` - Demonstrates using ray queries with a fragment shader.
- `ray_scene` - Demonstrates using ray queries and model loading
- `ray_shadows` - Demonstrates a simple use of ray queries - high quality shadows - uses a light set with push constants to raytrace through an untransformed scene and detect whether there is something obstructing the light.
- `mesh_shader` - Rrenders a triangle to a window with mesh shaders, while showcasing most mesh shader related features(task shaders, payloads, per primitive data).
#### Compute

View File

@ -13,6 +13,7 @@ pub mod hello_synchronization;
pub mod hello_triangle;
pub mod hello_windows;
pub mod hello_workgroups;
pub mod mesh_shader;
pub mod mipmap;
pub mod msaa_line;
pub mod multiple_render_targets;

View File

@ -182,6 +182,12 @@ const EXAMPLES: &[ExampleDesc] = &[
webgl: false, // No Ray-tracing extensions
webgpu: false, // No Ray-tracing extensions (yet)
},
ExampleDesc {
name: "mesh_shader",
function: wgpu_examples::mesh_shader::main,
webgl: false,
webgpu: false,
},
];
fn get_example_name() -> Option<String> {

View File

@ -0,0 +1,9 @@
# mesh_shader
This example renders a triangle to a window with mesh shaders, while showcasing most mesh shader related features(task shaders, payloads, per primitive data).
## To Run
```
cargo run --bin wgpu-examples mesh_shader
```

View File

@ -0,0 +1,142 @@
use std::{io::Write, process::Stdio};
// Same as in mesh shader tests
fn compile_glsl(
device: &wgpu::Device,
data: &[u8],
shader_stage: &'static str,
) -> wgpu::ShaderModule {
let cmd = std::process::Command::new("glslc")
.args([
&format!("-fshader-stage={shader_stage}"),
"-",
"-o",
"-",
"--target-env=vulkan1.2",
"--target-spv=spv1.4",
])
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
.expect("Failed to call glslc");
cmd.stdin.as_ref().unwrap().write_all(data).unwrap();
println!("{shader_stage}");
let output = cmd.wait_with_output().expect("Error waiting for glslc");
assert!(output.status.success());
unsafe {
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV(
wgpu::ShaderModuleDescriptorSpirV {
label: None,
source: wgpu::util::make_spirv_raw(&output.stdout),
},
))
}
}
pub struct Example {
pipeline: wgpu::RenderPipeline,
}
impl crate::framework::Example for Example {
fn init(
config: &wgpu::SurfaceConfiguration,
_adapter: &wgpu::Adapter,
device: &wgpu::Device,
_queue: &wgpu::Queue,
) -> Self {
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[],
push_constant_ranges: &[],
});
let (ts, ms, fs) = (
compile_glsl(device, include_bytes!("shader.task"), "task"),
compile_glsl(device, include_bytes!("shader.mesh"), "mesh"),
compile_glsl(device, include_bytes!("shader.frag"), "frag"),
);
let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor {
label: None,
layout: Some(&pipeline_layout),
task: Some(wgpu::TaskState {
module: &ts,
entry_point: Some("main"),
compilation_options: Default::default(),
}),
mesh: wgpu::MeshState {
module: &ms,
entry_point: Some("main"),
compilation_options: Default::default(),
},
fragment: Some(wgpu::FragmentState {
module: &fs,
entry_point: Some("main"),
compilation_options: Default::default(),
targets: &[Some(config.view_formats[0].into())],
}),
primitive: wgpu::PrimitiveState {
cull_mode: Some(wgpu::Face::Back),
..Default::default()
},
depth_stencil: None,
multisample: Default::default(),
multiview: None,
cache: None,
});
Self { pipeline }
}
fn render(&mut self, view: &wgpu::TextureView, device: &wgpu::Device, queue: &wgpu::Queue) {
let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
{
let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
label: None,
color_attachments: &[Some(wgpu::RenderPassColorAttachment {
view,
resolve_target: None,
ops: wgpu::Operations {
load: wgpu::LoadOp::Clear(wgpu::Color {
r: 0.1,
g: 0.2,
b: 0.3,
a: 1.0,
}),
store: wgpu::StoreOp::Store,
},
depth_slice: None,
})],
depth_stencil_attachment: None,
timestamp_writes: None,
occlusion_query_set: None,
});
rpass.push_debug_group("Prepare data for draw.");
rpass.set_pipeline(&self.pipeline);
rpass.pop_debug_group();
rpass.insert_debug_marker("Draw!");
rpass.draw_mesh_tasks(1, 1, 1);
}
queue.submit(Some(encoder.finish()));
}
fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities {
Default::default()
}
fn required_features() -> wgpu::Features {
wgpu::Features::EXPERIMENTAL_MESH_SHADER | wgpu::Features::SPIRV_SHADER_PASSTHROUGH
}
fn required_limits() -> wgpu::Limits {
wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values()
}
fn resize(
&mut self,
_config: &wgpu::SurfaceConfiguration,
_device: &wgpu::Device,
_queue: &wgpu::Queue,
) {
// empty
}
fn update(&mut self, _event: winit::event::WindowEvent) {
// empty
}
}
pub fn main() {
crate::framework::run::<Example>("mesh_shader");
}

View File

@ -0,0 +1,11 @@
#version 450
#extension GL_EXT_mesh_shader : require
in VertexInput { layout(location = 0) vec4 color; }
vertexInput;
layout(location = 1) perprimitiveEXT in PrimitiveInput { vec4 colorMask; }
primitiveInput;
layout(location = 0) out vec4 fragColor;
void main() { fragColor = vertexInput.color * primitiveInput.colorMask; }

View File

@ -0,0 +1,38 @@
#version 450
#extension GL_EXT_mesh_shader : require
const vec4[3] positions = {vec4(0., 1.0, 0., 1.0), vec4(-1.0, -1.0, 0., 1.0),
vec4(1.0, -1.0, 0., 1.0)};
const vec4[3] colors = {vec4(0., 1., 0., 1.), vec4(0., 0., 1., 1.),
vec4(1., 0., 0., 1.)};
// This is an inefficient workgroup size.Ideally the total thread count would be
// a multiple of 64
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
struct PayloadData {
vec4 colorMask;
bool visible;
};
taskPayloadSharedEXT PayloadData payloadData;
out VertexOutput { layout(location = 0) vec4 color; }
vertexOutput[];
layout(location = 1) perprimitiveEXT out PrimitiveOutput { vec4 colorMask; }
primitiveOutput[];
shared uint sharedData;
layout(triangles, max_vertices = 3, max_primitives = 1) out;
void main() {
sharedData = 5;
SetMeshOutputsEXT(3, 1);
gl_MeshVerticesEXT[0].gl_Position = positions[0];
gl_MeshVerticesEXT[1].gl_Position = positions[1];
gl_MeshVerticesEXT[2].gl_Position = positions[2];
vertexOutput[0].color = colors[0] * payloadData.colorMask;
vertexOutput[1].color = colors[1] * payloadData.colorMask;
vertexOutput[2].color = colors[2] * payloadData.colorMask;
gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uvec3(0, 1, 2);
primitiveOutput[0].colorMask = vec4(1.0, 0.0, 1.0, 1.0);
gl_MeshPrimitivesEXT[0].gl_CullPrimitiveEXT = !payloadData.visible;
}

View File

@ -0,0 +1,16 @@
#version 450
#extension GL_EXT_mesh_shader : require
layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in;
struct TaskPayload {
vec4 colorMask;
bool visible;
};
taskPayloadSharedEXT TaskPayload taskPayload;
void main() {
taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0);
taskPayload.visible = true;
EmitMeshTasksEXT(3, 1, 1);
}

View File

@ -161,6 +161,13 @@ impl DeviceInterface for CustomDevice {
unimplemented!()
}
fn create_mesh_pipeline(
&self,
_desc: &wgpu::MeshPipelineDescriptor<'_>,
) -> wgpu::custom::DispatchRenderPipeline {
unimplemented!()
}
fn create_compute_pipeline(
&self,
desc: &wgpu::ComputePipelineDescriptor<'_>,

View File

@ -333,6 +333,12 @@ impl GlobalPlay for wgc::global::Global {
panic!("{e}");
}
}
Action::CreateMeshPipeline { id, desc } => {
let (_, error) = self.device_create_mesh_pipeline(device, &desc, Some(id));
if let Some(e) = error {
panic!("{e}");
}
}
Action::DestroyRenderPipeline(id) => {
self.render_pipeline_drop(id);
}

View File

@ -34,6 +34,7 @@ mod image_atomics;
mod instance;
mod life_cycle;
mod mem_leaks;
mod mesh_shader;
mod nv12_texture;
mod occlusion_query;
mod oob_indexing;

View File

@ -0,0 +1,9 @@
#version 450
#extension GL_EXT_mesh_shader : require
in VertexInput { layout(location = 0) vec4 color; }
vertexInput;
layout(location = 0) out vec4 fragColor;
void main() { fragColor = vertexInput.color; }

View File

@ -0,0 +1,25 @@
#version 450
#extension GL_EXT_mesh_shader : require
const vec4[3] positions = {vec4(0., 1.0, 0., 1.0), vec4(-1.0, -1.0, 0., 1.0),
vec4(1.0, -1.0, 0., 1.0)};
const vec4[3] colors = {vec4(0., 1., 0., 1.), vec4(0., 0., 1., 1.),
vec4(1., 0., 0., 1.)};
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
out VertexOutput { layout(location = 0) vec4 color; }
vertexOutput[];
layout(triangles, max_vertices = 3, max_primitives = 1) out;
void main() {
SetMeshOutputsEXT(3, 1);
gl_MeshVerticesEXT[0].gl_Position = positions[0];
gl_MeshVerticesEXT[1].gl_Position = positions[1];
gl_MeshVerticesEXT[2].gl_Position = positions[2];
vertexOutput[0].color = colors[0];
vertexOutput[1].color = colors[1];
vertexOutput[2].color = colors[2];
gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uvec3(0, 1, 2);
}

View File

@ -0,0 +1,6 @@
#version 450
#extension GL_EXT_mesh_shader : require
layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in;
void main() { EmitMeshTasksEXT(1, 1, 1); }

View File

@ -0,0 +1,310 @@
use std::{io::Write, process::Stdio};
use wgpu::util::DeviceExt;
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};
// Same as in mesh shader example
fn compile_glsl(
device: &wgpu::Device,
data: &[u8],
shader_stage: &'static str,
) -> wgpu::ShaderModule {
let cmd = std::process::Command::new("glslc")
.args([
&format!("-fshader-stage={shader_stage}"),
"-",
"-o",
"-",
"--target-env=vulkan1.2",
"--target-spv=spv1.4",
])
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
.expect("Failed to call glslc");
cmd.stdin.as_ref().unwrap().write_all(data).unwrap();
println!("{shader_stage}");
let output = cmd.wait_with_output().expect("Error waiting for glslc");
assert!(output.status.success());
unsafe {
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV(
wgpu::ShaderModuleDescriptorSpirV {
label: None,
source: wgpu::util::make_spirv_raw(&output.stdout),
},
))
}
}
fn create_depth(
device: &wgpu::Device,
) -> (wgpu::Texture, wgpu::TextureView, wgpu::DepthStencilState) {
let image_size = wgpu::Extent3d {
width: 64,
height: 64,
depth_or_array_layers: 1,
};
let depth_texture = device.create_texture(&wgpu::TextureDescriptor {
label: None,
size: image_size,
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
format: wgpu::TextureFormat::Depth32Float,
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
view_formats: &[],
});
let depth_view = depth_texture.create_view(&Default::default());
let state = wgpu::DepthStencilState {
format: wgpu::TextureFormat::Depth32Float,
depth_write_enabled: true,
depth_compare: wgpu::CompareFunction::Less, // 1.
stencil: wgpu::StencilState::default(), // 2.
bias: wgpu::DepthBiasState::default(),
};
(depth_texture, depth_view, state)
}
fn mesh_pipeline_build(
ctx: &TestingContext,
task: Option<&[u8]>,
mesh: &[u8],
frag: Option<&[u8]>,
draw: bool,
) {
let device = &ctx.device;
let (_depth_image, depth_view, depth_state) = create_depth(device);
let task = task.map(|t| compile_glsl(device, t, "task"));
let mesh = compile_glsl(device, mesh, "mesh");
let frag = frag.map(|f| compile_glsl(device, f, "frag"));
let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[],
push_constant_ranges: &[],
});
let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor {
label: None,
layout: Some(&layout),
task: task.as_ref().map(|task| wgpu::TaskState {
module: task,
entry_point: Some("main"),
compilation_options: Default::default(),
}),
mesh: wgpu::MeshState {
module: &mesh,
entry_point: Some("main"),
compilation_options: Default::default(),
},
fragment: frag.as_ref().map(|frag| wgpu::FragmentState {
module: frag,
entry_point: Some("main"),
targets: &[],
compilation_options: Default::default(),
}),
primitive: wgpu::PrimitiveState {
cull_mode: Some(wgpu::Face::Back),
..Default::default()
},
depth_stencil: Some(depth_state),
multisample: Default::default(),
multiview: None,
cache: None,
});
if draw {
let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
{
let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
label: None,
color_attachments: &[],
depth_stencil_attachment: Some(wgpu::RenderPassDepthStencilAttachment {
view: &depth_view,
depth_ops: Some(wgpu::Operations {
load: wgpu::LoadOp::Clear(1.0),
store: wgpu::StoreOp::Store,
}),
stencil_ops: None,
}),
timestamp_writes: None,
occlusion_query_set: None,
});
pass.set_pipeline(&pipeline);
pass.draw_mesh_tasks(1, 1, 1);
}
ctx.queue.submit(Some(encoder.finish()));
ctx.device.poll(wgpu::PollType::Wait).unwrap();
}
}
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum DrawType {
#[allow(dead_code)]
Standard,
Indirect,
MultiIndirect,
MultiIndirectCount,
}
fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) {
let device = &ctx.device;
let (_depth_image, depth_view, depth_state) = create_depth(device);
let task = compile_glsl(device, BASIC_TASK, "task");
let mesh = compile_glsl(device, BASIC_MESH, "mesh");
let frag = compile_glsl(device, NO_WRITE_FRAG, "frag");
let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[],
push_constant_ranges: &[],
});
let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor {
label: None,
layout: Some(&layout),
task: Some(wgpu::TaskState {
module: &task,
entry_point: Some("main"),
compilation_options: Default::default(),
}),
mesh: wgpu::MeshState {
module: &mesh,
entry_point: Some("main"),
compilation_options: Default::default(),
},
fragment: Some(wgpu::FragmentState {
module: &frag,
entry_point: Some("main"),
targets: &[],
compilation_options: Default::default(),
}),
primitive: wgpu::PrimitiveState {
cull_mode: Some(wgpu::Face::Back),
..Default::default()
},
depth_stencil: Some(depth_state),
multisample: Default::default(),
multiview: None,
cache: None,
});
let buffer = match draw_type {
DrawType::Standard => None,
DrawType::Indirect | DrawType::MultiIndirect | DrawType::MultiIndirectCount => Some(
device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
usage: wgpu::BufferUsages::INDIRECT,
contents: bytemuck::bytes_of(&[1u32; 4]),
}),
),
};
let count_buffer = match draw_type {
DrawType::MultiIndirectCount => Some(device.create_buffer_init(
&wgpu::util::BufferInitDescriptor {
label: None,
usage: wgpu::BufferUsages::INDIRECT,
contents: bytemuck::bytes_of(&[1u32; 1]),
},
)),
_ => None,
};
let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
{
let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
label: None,
color_attachments: &[],
depth_stencil_attachment: Some(wgpu::RenderPassDepthStencilAttachment {
view: &depth_view,
depth_ops: Some(wgpu::Operations {
load: wgpu::LoadOp::Clear(1.0),
store: wgpu::StoreOp::Store,
}),
stencil_ops: None,
}),
timestamp_writes: None,
occlusion_query_set: None,
});
pass.set_pipeline(&pipeline);
match draw_type {
DrawType::Standard => pass.draw_mesh_tasks(1, 1, 1),
DrawType::Indirect => pass.draw_mesh_tasks_indirect(buffer.as_ref().unwrap(), 0),
DrawType::MultiIndirect => {
pass.multi_draw_mesh_tasks_indirect(buffer.as_ref().unwrap(), 0, 1)
}
DrawType::MultiIndirectCount => pass.multi_draw_mesh_tasks_indirect_count(
buffer.as_ref().unwrap(),
0,
count_buffer.as_ref().unwrap(),
0,
1,
),
}
pass.draw_mesh_tasks_indirect(buffer.as_ref().unwrap(), 0);
}
ctx.queue.submit(Some(encoder.finish()));
ctx.device.poll(wgpu::PollType::Wait).unwrap();
}
const BASIC_TASK: &[u8] = include_bytes!("basic.task");
const BASIC_MESH: &[u8] = include_bytes!("basic.mesh");
//const BASIC_FRAG: &[u8] = include_bytes!("basic.frag.spv");
const NO_WRITE_FRAG: &[u8] = include_bytes!("no-write.frag");
fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration {
GpuTestConfiguration::new().parameters(
TestParameters::default()
.test_features_limits()
.features(
wgpu::Features::EXPERIMENTAL_MESH_SHADER
| wgpu::Features::SPIRV_SHADER_PASSTHROUGH
| match draw_type {
DrawType::Standard | DrawType::Indirect => wgpu::Features::empty(),
DrawType::MultiIndirect => wgpu::Features::MULTI_DRAW_INDIRECT,
DrawType::MultiIndirectCount => wgpu::Features::MULTI_DRAW_INDIRECT_COUNT,
},
)
.limits(wgpu::Limits::default().using_recommended_minimum_mesh_shader_values()),
)
}
// Mesh pipeline configs
#[gpu_test]
static MESH_PIPELINE_BASIC_MESH: GpuTestConfiguration = default_gpu_test_config(DrawType::Standard)
.run_sync(|ctx| {
mesh_pipeline_build(&ctx, None, BASIC_MESH, None, true);
});
#[gpu_test]
static MESH_PIPELINE_BASIC_TASK_MESH: GpuTestConfiguration =
default_gpu_test_config(DrawType::Standard).run_sync(|ctx| {
mesh_pipeline_build(&ctx, Some(BASIC_TASK), BASIC_MESH, None, true);
});
#[gpu_test]
static MESH_PIPELINE_BASIC_MESH_FRAG: GpuTestConfiguration =
default_gpu_test_config(DrawType::Standard).run_sync(|ctx| {
mesh_pipeline_build(&ctx, None, BASIC_MESH, Some(NO_WRITE_FRAG), true);
});
#[gpu_test]
static MESH_PIPELINE_BASIC_TASK_MESH_FRAG: GpuTestConfiguration =
default_gpu_test_config(DrawType::Standard).run_sync(|ctx| {
mesh_pipeline_build(
&ctx,
Some(BASIC_TASK),
BASIC_MESH,
Some(NO_WRITE_FRAG),
true,
);
});
// Mesh draw
#[gpu_test]
static MESH_DRAW_INDIRECT: GpuTestConfiguration = default_gpu_test_config(DrawType::Indirect)
.run_sync(|ctx| {
mesh_draw(&ctx, DrawType::Indirect);
});
#[gpu_test]
static MESH_MULTI_DRAW_INDIRECT: GpuTestConfiguration =
default_gpu_test_config(DrawType::MultiIndirect).run_sync(|ctx| {
mesh_draw(&ctx, DrawType::MultiIndirect);
});
#[gpu_test]
static MESH_MULTI_DRAW_INDIRECT_COUNT: GpuTestConfiguration =
default_gpu_test_config(DrawType::MultiIndirectCount).run_sync(|ctx| {
mesh_draw(&ctx, DrawType::MultiIndirectCount);
});

View File

@ -0,0 +1,7 @@
#version 450
#extension GL_EXT_mesh_shader : require
in VertexInput { layout(location = 0) vec4 color; }
vertexInput;
void main() {}

View File

@ -123,7 +123,7 @@ use crate::{
use super::{
pass,
render_command::{ArcRenderCommand, RenderCommand},
DrawKind,
DrawCommandFamily, DrawKind,
};
/// Describes a [`RenderBundleEncoder`].
@ -380,7 +380,7 @@ impl RenderBundleEncoder {
} => {
let scope = PassErrorScope::Draw {
kind: DrawKind::Draw,
indexed: false,
family: DrawCommandFamily::Draw,
};
draw(
&mut state,
@ -401,7 +401,7 @@ impl RenderBundleEncoder {
} => {
let scope = PassErrorScope::Draw {
kind: DrawKind::Draw,
indexed: true,
family: DrawCommandFamily::DrawIndexed,
};
draw_indexed(
&mut state,
@ -414,15 +414,33 @@ impl RenderBundleEncoder {
)
.map_pass_err(scope)?;
}
RenderCommand::DrawMeshTasks {
group_count_x,
group_count_y,
group_count_z,
} => {
let scope = PassErrorScope::Draw {
kind: DrawKind::Draw,
family: DrawCommandFamily::DrawMeshTasks,
};
draw_mesh_tasks(
&mut state,
&base.dynamic_offsets,
group_count_x,
group_count_y,
group_count_z,
)
.map_pass_err(scope)?;
}
RenderCommand::DrawIndirect {
buffer_id,
offset,
count: 1,
indexed,
family,
} => {
let scope = PassErrorScope::Draw {
kind: DrawKind::DrawIndirect,
indexed,
family,
};
multi_draw_indirect(
&mut state,
@ -430,7 +448,7 @@ impl RenderBundleEncoder {
&buffer_guard,
buffer_id,
offset,
indexed,
family,
)
.map_pass_err(scope)?;
}
@ -787,13 +805,48 @@ fn draw_indexed(
Ok(())
}
fn draw_mesh_tasks(
state: &mut State,
dynamic_offsets: &[u32],
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
) -> Result<(), RenderBundleErrorInner> {
let pipeline = state.pipeline()?;
let used_bind_groups = pipeline.used_bind_groups;
let groups_size_limit = state.device.limits.max_task_workgroups_per_dimension;
let max_groups = state.device.limits.max_task_workgroup_total_count;
if group_count_x > groups_size_limit
|| group_count_y > groups_size_limit
|| group_count_z > groups_size_limit
|| group_count_x * group_count_y * group_count_z > max_groups
{
return Err(RenderBundleErrorInner::Draw(DrawError::InvalidGroupSize {
current: [group_count_x, group_count_y, group_count_z],
limit: groups_size_limit,
max_total: max_groups,
}));
}
if group_count_x > 0 && group_count_y > 0 && group_count_z > 0 {
state.flush_binds(used_bind_groups, dynamic_offsets);
state.commands.push(ArcRenderCommand::DrawMeshTasks {
group_count_x,
group_count_y,
group_count_z,
});
}
Ok(())
}
fn multi_draw_indirect(
state: &mut State,
dynamic_offsets: &[u32],
buffer_guard: &crate::storage::Storage<Fallible<Buffer>>,
buffer_id: id::Id<id::markers::Buffer>,
offset: u64,
indexed: bool,
family: DrawCommandFamily,
) -> Result<(), RenderBundleErrorInner> {
state
.device
@ -809,7 +862,7 @@ fn multi_draw_indirect(
let vertex_limits = super::VertexLimits::new(state.vertex_buffer_sizes(), &pipeline.steps);
let stride = super::get_stride_of_indirect_args(indexed);
let stride = super::get_stride_of_indirect_args(family);
state
.buffer_memory_init_actions
.extend(buffer.initialization_status.read().create_action(
@ -818,7 +871,7 @@ fn multi_draw_indirect(
MemoryInitKind::NeedsInitializedMemory,
));
let vertex_or_index_limit = if indexed {
let vertex_or_index_limit = if family == DrawCommandFamily::DrawIndexed {
let index = match state.index {
Some(ref mut index) => index,
None => return Err(DrawError::MissingIndexBuffer.into()),
@ -844,7 +897,7 @@ fn multi_draw_indirect(
buffer,
offset,
count: 1,
indexed,
family,
vertex_or_index_limit,
instance_limit,
@ -1066,11 +1119,18 @@ impl RenderBundle {
)
};
}
Cmd::DrawMeshTasks {
group_count_x,
group_count_y,
group_count_z,
} => unsafe {
raw.draw_mesh_tasks(*group_count_x, *group_count_y, *group_count_z);
},
Cmd::DrawIndirect {
buffer,
offset,
count: 1,
indexed,
family,
vertex_or_index_limit,
instance_limit,
@ -1081,7 +1141,7 @@ impl RenderBundle {
&self.device,
buffer,
*offset,
*indexed,
*family,
*vertex_or_index_limit,
*instance_limit,
)?;
@ -1092,10 +1152,14 @@ impl RenderBundle {
} else {
(buffer.try_raw(snatch_guard)?, *offset)
};
if *indexed {
unsafe { raw.draw_indexed_indirect(buffer, offset, 1) };
} else {
unsafe { raw.draw_indirect(buffer, offset, 1) };
match family {
DrawCommandFamily::Draw => unsafe { raw.draw_indirect(buffer, offset, 1) },
DrawCommandFamily::DrawIndexed => unsafe {
raw.draw_indexed_indirect(buffer, offset, 1)
},
DrawCommandFamily::DrawMeshTasks => unsafe {
raw.draw_mesh_tasks_indirect(buffer, offset, 1);
},
}
}
Cmd::DrawIndirect { .. } | Cmd::MultiDrawIndirectCount { .. } => {
@ -1597,7 +1661,7 @@ where
pub mod bundle_ffi {
use super::{RenderBundleEncoder, RenderCommand};
use crate::{id, RawString};
use crate::{command::DrawCommandFamily, id, RawString};
use core::{convert::TryInto, slice};
use wgt::{BufferAddress, BufferSize, DynamicOffset, IndexFormat};
@ -1752,7 +1816,7 @@ pub mod bundle_ffi {
buffer_id,
offset,
count: 1,
indexed: false,
family: DrawCommandFamily::Draw,
});
}
@ -1765,7 +1829,7 @@ pub mod bundle_ffi {
buffer_id,
offset,
count: 1,
indexed: true,
family: DrawCommandFamily::DrawIndexed,
});
}

View File

@ -56,6 +56,21 @@ pub enum DrawError {
},
#[error(transparent)]
BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
#[error(
"Wrong pipeline type for this draw command. Attempted to call {} draw command on {} pipeline",
if *wanted_mesh_pipeline {"mesh shader"} else {"standard"},
if *wanted_mesh_pipeline {"standard"} else {"mesh shader"},
)]
WrongPipelineType { wanted_mesh_pipeline: bool },
#[error(
"Each current draw group size dimension ({current:?}) must be less or equal to {limit}, and the product must be less or equal to {max_total}"
)]
InvalidGroupSize {
current: [u32; 3],
limit: u32,
max_total: u32,
},
}
impl WebGpuError for DrawError {

View File

@ -1442,6 +1442,15 @@ pub enum DrawKind {
MultiDrawIndirectCount,
}
/// The type of draw command(indexed or not, or mesh shader)
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DrawCommandFamily {
Draw,
DrawIndexed,
DrawMeshTasks,
}
/// A command that can be recorded in a pass or bundle.
///
/// This is used to provide context for errors during command recording.
@ -1480,7 +1489,10 @@ pub enum PassErrorScope {
#[error("In a set_scissor_rect command")]
SetScissorRect,
#[error("In a draw command, kind: {kind:?}")]
Draw { kind: DrawKind, indexed: bool },
Draw {
kind: DrawKind,
family: DrawCommandFamily,
},
#[error("In a write_timestamp command")]
WriteTimestamp,
#[error("In a begin_occlusion_query command")]

View File

@ -54,7 +54,7 @@ use super::{
memory_init::TextureSurfaceDiscard, CommandBufferTextureMemoryActions, CommandEncoder,
QueryResetMap,
};
use super::{DrawKind, Rect};
use super::{DrawCommandFamily, DrawKind, Rect};
use crate::binding_model::{BindError, PushConstantUploadError};
pub use wgt::{LoadOp, StoreOp};
@ -513,7 +513,7 @@ struct State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
{
fn is_ready(&self, indexed: bool) -> Result<(), DrawError> {
fn is_ready(&self, family: DrawCommandFamily) -> Result<(), DrawError> {
if let Some(pipeline) = self.pipeline.as_ref() {
self.general.binder.check_compatibility(pipeline.as_ref())?;
self.general.binder.check_late_buffer_bindings()?;
@ -537,7 +537,7 @@ impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
});
}
if indexed {
if family == DrawCommandFamily::DrawIndexed {
// Pipeline expects an index buffer
if let Some(pipeline_index_format) = pipeline.strip_index_format {
// We have a buffer bound
@ -556,6 +556,11 @@ impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
}
}
}
if (family == DrawCommandFamily::DrawMeshTasks) != pipeline.is_mesh {
return Err(DrawError::WrongPipelineType {
wanted_mesh_pipeline: !pipeline.is_mesh,
});
}
Ok(())
} else {
Err(DrawError::MissingPipeline(pass::MissingPipeline))
@ -2013,7 +2018,7 @@ impl Global {
} => {
let scope = PassErrorScope::Draw {
kind: DrawKind::Draw,
indexed: false,
family: DrawCommandFamily::Draw,
};
draw(
&mut state,
@ -2033,7 +2038,7 @@ impl Global {
} => {
let scope = PassErrorScope::Draw {
kind: DrawKind::Draw,
indexed: true,
family: DrawCommandFamily::DrawIndexed,
};
draw_indexed(
&mut state,
@ -2045,11 +2050,28 @@ impl Global {
)
.map_pass_err(scope)?;
}
ArcRenderCommand::DrawMeshTasks {
group_count_x,
group_count_y,
group_count_z,
} => {
let scope = PassErrorScope::Draw {
kind: DrawKind::Draw,
family: DrawCommandFamily::DrawMeshTasks,
};
draw_mesh_tasks(
&mut state,
group_count_x,
group_count_y,
group_count_z,
)
.map_pass_err(scope)?;
}
ArcRenderCommand::DrawIndirect {
buffer,
offset,
count,
indexed,
family,
vertex_or_index_limit: _,
instance_limit: _,
@ -2060,7 +2082,7 @@ impl Global {
} else {
DrawKind::DrawIndirect
},
indexed,
family,
};
multi_draw_indirect(
&mut state,
@ -2070,7 +2092,7 @@ impl Global {
buffer,
offset,
count,
indexed,
family,
)
.map_pass_err(scope)?;
}
@ -2080,11 +2102,11 @@ impl Global {
count_buffer,
count_buffer_offset,
max_count,
indexed,
family,
} => {
let scope = PassErrorScope::Draw {
kind: DrawKind::MultiDrawIndirectCount,
indexed,
family,
};
multi_draw_indirect_count(
&mut state,
@ -2094,7 +2116,7 @@ impl Global {
count_buffer,
count_buffer_offset,
max_count,
indexed,
family,
)
.map_pass_err(scope)?;
}
@ -2545,7 +2567,7 @@ fn draw(
) -> Result<(), DrawError> {
api_log!("RenderPass::draw {vertex_count} {instance_count} {first_vertex} {first_instance}");
state.is_ready(false)?;
state.is_ready(DrawCommandFamily::Draw)?;
state
.vertex
@ -2579,7 +2601,7 @@ fn draw_indexed(
) -> Result<(), DrawError> {
api_log!("RenderPass::draw_indexed {index_count} {instance_count} {first_index} {base_vertex} {first_instance}");
state.is_ready(true)?;
state.is_ready(DrawCommandFamily::DrawIndexed)?;
let last_index = first_index as u64 + index_count as u64;
let index_limit = state.index.limit;
@ -2608,6 +2630,45 @@ fn draw_indexed(
Ok(())
}
fn draw_mesh_tasks(
state: &mut State,
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
) -> Result<(), DrawError> {
api_log!("RenderPass::draw_mesh_tasks {group_count_x} {group_count_y} {group_count_z}");
state.is_ready(DrawCommandFamily::DrawMeshTasks)?;
let groups_size_limit = state
.general
.device
.limits
.max_task_workgroups_per_dimension;
let max_groups = state.general.device.limits.max_task_workgroup_total_count;
if group_count_x > groups_size_limit
|| group_count_y > groups_size_limit
|| group_count_z > groups_size_limit
|| group_count_x * group_count_y * group_count_z > max_groups
{
return Err(DrawError::InvalidGroupSize {
current: [group_count_x, group_count_y, group_count_z],
limit: groups_size_limit,
max_total: max_groups,
});
}
unsafe {
if group_count_x > 0 && group_count_y > 0 && group_count_z > 0 {
state
.general
.raw_encoder
.draw_mesh_tasks(group_count_x, group_count_y, group_count_z);
}
}
Ok(())
}
fn multi_draw_indirect(
state: &mut State,
indirect_draw_validation_resources: &mut crate::indirect_validation::DrawResources,
@ -2616,14 +2677,14 @@ fn multi_draw_indirect(
indirect_buffer: Arc<crate::resource::Buffer>,
offset: u64,
count: u32,
indexed: bool,
family: DrawCommandFamily,
) -> Result<(), RenderPassErrorInner> {
api_log!(
"RenderPass::draw_indirect (indexed:{indexed}) {} {offset} {count:?}",
"RenderPass::draw_indirect (family:{family:?}) {} {offset} {count:?}",
indirect_buffer.error_ident()
);
state.is_ready(indexed)?;
state.is_ready(family)?;
if count != 1 {
state
@ -2645,7 +2706,7 @@ fn multi_draw_indirect(
return Err(RenderPassErrorInner::UnalignedIndirectBufferOffset(offset));
}
let stride = get_stride_of_indirect_args(indexed);
let stride = get_stride_of_indirect_args(family);
let end_offset = offset + stride * count as u64;
if end_offset > indirect_buffer.size {
@ -2667,18 +2728,21 @@ fn multi_draw_indirect(
fn draw(
raw_encoder: &mut dyn hal::DynCommandEncoder,
indexed: bool,
family: DrawCommandFamily,
indirect_buffer: &dyn hal::DynBuffer,
offset: u64,
count: u32,
) {
match indexed {
false => unsafe {
match family {
DrawCommandFamily::Draw => unsafe {
raw_encoder.draw_indirect(indirect_buffer, offset, count);
},
true => unsafe {
DrawCommandFamily::DrawIndexed => unsafe {
raw_encoder.draw_indexed_indirect(indirect_buffer, offset, count);
},
DrawCommandFamily::DrawMeshTasks => unsafe {
raw_encoder.draw_mesh_tasks_indirect(indirect_buffer, offset, count);
},
}
}
@ -2703,7 +2767,7 @@ fn multi_draw_indirect(
indirect_draw_validation_batcher: &'a mut crate::indirect_validation::DrawBatcher,
indirect_buffer: Arc<crate::resource::Buffer>,
indexed: bool,
family: DrawCommandFamily,
vertex_or_index_limit: u64,
instance_limit: u64,
}
@ -2715,7 +2779,7 @@ fn multi_draw_indirect(
self.device,
&self.indirect_buffer,
offset,
self.indexed,
self.family,
self.vertex_or_index_limit,
self.instance_limit,
)?;
@ -2731,7 +2795,7 @@ fn multi_draw_indirect(
.get_dst_buffer(draw_data.buffer_index);
draw(
self.raw_encoder,
self.indexed,
self.family,
dst_buffer,
draw_data.offset,
draw_data.count,
@ -2745,8 +2809,8 @@ fn multi_draw_indirect(
indirect_draw_validation_resources,
indirect_draw_validation_batcher,
indirect_buffer,
indexed,
vertex_or_index_limit: if indexed {
family,
vertex_or_index_limit: if family == DrawCommandFamily::DrawIndexed {
state.index.limit
} else {
state.vertex.limits.vertex_limit
@ -2781,7 +2845,7 @@ fn multi_draw_indirect(
draw(
state.general.raw_encoder,
indexed,
family,
indirect_buffer.try_raw(state.general.snatch_guard)?,
offset,
count,
@ -2799,17 +2863,17 @@ fn multi_draw_indirect_count(
count_buffer: Arc<crate::resource::Buffer>,
count_buffer_offset: u64,
max_count: u32,
indexed: bool,
family: DrawCommandFamily,
) -> Result<(), RenderPassErrorInner> {
api_log!(
"RenderPass::multi_draw_indirect_count (indexed:{indexed}) {} {offset} {} {count_buffer_offset:?} {max_count:?}",
"RenderPass::multi_draw_indirect_count (family:{family:?}) {} {offset} {} {count_buffer_offset:?} {max_count:?}",
indirect_buffer.error_ident(),
count_buffer.error_ident()
);
state.is_ready(indexed)?;
state.is_ready(family)?;
let stride = get_stride_of_indirect_args(indexed);
let stride = get_stride_of_indirect_args(family);
state
.general
@ -2879,8 +2943,8 @@ fn multi_draw_indirect_count(
),
);
match indexed {
false => unsafe {
match family {
DrawCommandFamily::Draw => unsafe {
state.general.raw_encoder.draw_indirect_count(
indirect_raw,
offset,
@ -2889,7 +2953,7 @@ fn multi_draw_indirect_count(
max_count,
);
},
true => unsafe {
DrawCommandFamily::DrawIndexed => unsafe {
state.general.raw_encoder.draw_indexed_indirect_count(
indirect_raw,
offset,
@ -2898,6 +2962,15 @@ fn multi_draw_indirect_count(
max_count,
);
},
DrawCommandFamily::DrawMeshTasks => unsafe {
state.general.raw_encoder.draw_mesh_tasks_indirect_count(
indirect_raw,
offset,
count_raw,
count_buffer_offset,
max_count,
);
},
}
Ok(())
}
@ -3250,7 +3323,7 @@ impl Global {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::Draw {
kind: DrawKind::Draw,
indexed: false,
family: DrawCommandFamily::Draw,
};
let base = pass_base!(pass, scope);
@ -3275,7 +3348,7 @@ impl Global {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::Draw {
kind: DrawKind::Draw,
indexed: true,
family: DrawCommandFamily::DrawIndexed,
};
let base = pass_base!(pass, scope);
@ -3290,6 +3363,27 @@ impl Global {
Ok(())
}
pub fn render_pass_draw_mesh_tasks(
&self,
pass: &mut RenderPass,
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
) -> Result<(), RenderPassError> {
let scope = PassErrorScope::Draw {
kind: DrawKind::Draw,
family: DrawCommandFamily::DrawMeshTasks,
};
let base = pass_base!(pass, scope);
base.commands.push(ArcRenderCommand::DrawMeshTasks {
group_count_x,
group_count_y,
group_count_z,
});
Ok(())
}
pub fn render_pass_draw_indirect(
&self,
pass: &mut RenderPass,
@ -3298,7 +3392,7 @@ impl Global {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::Draw {
kind: DrawKind::DrawIndirect,
indexed: false,
family: DrawCommandFamily::Draw,
};
let base = pass_base!(pass, scope);
@ -3306,7 +3400,7 @@ impl Global {
buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)),
offset,
count: 1,
indexed: false,
family: DrawCommandFamily::Draw,
vertex_or_index_limit: 0,
instance_limit: 0,
@ -3323,7 +3417,7 @@ impl Global {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::Draw {
kind: DrawKind::DrawIndirect,
indexed: true,
family: DrawCommandFamily::DrawIndexed,
};
let base = pass_base!(pass, scope);
@ -3331,7 +3425,32 @@ impl Global {
buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)),
offset,
count: 1,
indexed: true,
family: DrawCommandFamily::DrawIndexed,
vertex_or_index_limit: 0,
instance_limit: 0,
});
Ok(())
}
pub fn render_pass_draw_mesh_tasks_indirect(
&self,
pass: &mut RenderPass,
buffer_id: id::BufferId,
offset: BufferAddress,
) -> Result<(), RenderPassError> {
let scope = PassErrorScope::Draw {
kind: DrawKind::DrawIndirect,
family: DrawCommandFamily::DrawMeshTasks,
};
let base = pass_base!(pass, scope);
base.commands.push(ArcRenderCommand::DrawIndirect {
buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)),
offset,
count: 1,
family: DrawCommandFamily::DrawMeshTasks,
vertex_or_index_limit: 0,
instance_limit: 0,
@ -3349,7 +3468,7 @@ impl Global {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::Draw {
kind: DrawKind::MultiDrawIndirect,
indexed: false,
family: DrawCommandFamily::Draw,
};
let base = pass_base!(pass, scope);
@ -3357,7 +3476,7 @@ impl Global {
buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)),
offset,
count,
indexed: false,
family: DrawCommandFamily::Draw,
vertex_or_index_limit: 0,
instance_limit: 0,
@ -3375,7 +3494,7 @@ impl Global {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::Draw {
kind: DrawKind::MultiDrawIndirect,
indexed: true,
family: DrawCommandFamily::DrawIndexed,
};
let base = pass_base!(pass, scope);
@ -3383,7 +3502,33 @@ impl Global {
buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)),
offset,
count,
indexed: true,
family: DrawCommandFamily::DrawIndexed,
vertex_or_index_limit: 0,
instance_limit: 0,
});
Ok(())
}
pub fn render_pass_multi_draw_mesh_tasks_indirect(
&self,
pass: &mut RenderPass,
buffer_id: id::BufferId,
offset: BufferAddress,
count: u32,
) -> Result<(), RenderPassError> {
let scope = PassErrorScope::Draw {
kind: DrawKind::MultiDrawIndirect,
family: DrawCommandFamily::DrawMeshTasks,
};
let base = pass_base!(pass, scope);
base.commands.push(ArcRenderCommand::DrawIndirect {
buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)),
offset,
count,
family: DrawCommandFamily::DrawMeshTasks,
vertex_or_index_limit: 0,
instance_limit: 0,
@ -3403,7 +3548,7 @@ impl Global {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::Draw {
kind: DrawKind::MultiDrawIndirectCount,
indexed: false,
family: DrawCommandFamily::Draw,
};
let base = pass_base!(pass, scope);
@ -3418,7 +3563,7 @@ impl Global {
),
count_buffer_offset,
max_count,
indexed: false,
family: DrawCommandFamily::Draw,
});
Ok(())
@ -3435,7 +3580,7 @@ impl Global {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::Draw {
kind: DrawKind::MultiDrawIndirectCount,
indexed: true,
family: DrawCommandFamily::DrawIndexed,
};
let base = pass_base!(pass, scope);
@ -3450,7 +3595,39 @@ impl Global {
),
count_buffer_offset,
max_count,
indexed: true,
family: DrawCommandFamily::DrawIndexed,
});
Ok(())
}
pub fn render_pass_multi_draw_mesh_tasks_indirect_count(
&self,
pass: &mut RenderPass,
buffer_id: id::BufferId,
offset: BufferAddress,
count_buffer_id: id::BufferId,
count_buffer_offset: BufferAddress,
max_count: u32,
) -> Result<(), RenderPassError> {
let scope = PassErrorScope::Draw {
kind: DrawKind::MultiDrawIndirectCount,
family: DrawCommandFamily::DrawMeshTasks,
};
let base = pass_base!(pass, scope);
base.commands
.push(ArcRenderCommand::MultiDrawIndirectCount {
buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)),
offset,
count_buffer: pass_try!(
base,
scope,
self.resolve_render_pass_buffer_id(count_buffer_id)
),
count_buffer_offset,
max_count,
family: DrawCommandFamily::DrawMeshTasks,
});
Ok(())
@ -3607,9 +3784,10 @@ impl Global {
}
}
pub(crate) const fn get_stride_of_indirect_args(indexed: bool) -> u64 {
match indexed {
false => size_of::<wgt::DrawIndirectArgs>() as u64,
true => size_of::<wgt::DrawIndexedIndirectArgs>() as u64,
pub(crate) const fn get_stride_of_indirect_args(family: DrawCommandFamily) -> u64 {
match family {
DrawCommandFamily::Draw => size_of::<wgt::DrawIndirectArgs>() as u64,
DrawCommandFamily::DrawIndexed => size_of::<wgt::DrawIndexedIndirectArgs>() as u64,
DrawCommandFamily::DrawMeshTasks => size_of::<wgt::DispatchIndirectArgs>() as u64,
}
}

View File

@ -2,7 +2,7 @@ use alloc::sync::Arc;
use wgt::{BufferAddress, BufferSize, Color};
use super::{Rect, RenderBundle};
use super::{DrawCommandFamily, Rect, RenderBundle};
use crate::{
binding_model::BindGroup,
id,
@ -82,11 +82,16 @@ pub enum RenderCommand {
base_vertex: i32,
first_instance: u32,
},
DrawMeshTasks {
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
},
DrawIndirect {
buffer_id: id::BufferId,
offset: BufferAddress,
count: u32,
indexed: bool,
family: DrawCommandFamily,
},
MultiDrawIndirectCount {
buffer_id: id::BufferId,
@ -94,7 +99,7 @@ pub enum RenderCommand {
count_buffer_id: id::BufferId,
count_buffer_offset: BufferAddress,
max_count: u32,
indexed: bool,
family: DrawCommandFamily,
},
PushDebugGroup {
color: u32,
@ -310,12 +315,21 @@ impl RenderCommand {
base_vertex,
first_instance,
},
RenderCommand::DrawMeshTasks {
group_count_x,
group_count_y,
group_count_z,
} => ArcRenderCommand::DrawMeshTasks {
group_count_x,
group_count_y,
group_count_z,
},
RenderCommand::DrawIndirect {
buffer_id,
offset,
count,
indexed,
family,
} => ArcRenderCommand::DrawIndirect {
buffer: buffers_guard.get(buffer_id).get().map_err(|e| {
RenderPassError {
@ -325,14 +339,14 @@ impl RenderCommand {
} else {
DrawKind::DrawIndirect
},
indexed,
family,
},
inner: e.into(),
}
})?,
offset,
count,
indexed,
family,
vertex_or_index_limit: 0,
instance_limit: 0,
@ -344,11 +358,11 @@ impl RenderCommand {
count_buffer_id,
count_buffer_offset,
max_count,
indexed,
family,
} => {
let scope = PassErrorScope::Draw {
kind: DrawKind::MultiDrawIndirectCount,
indexed,
family,
};
ArcRenderCommand::MultiDrawIndirectCount {
buffer: buffers_guard.get(buffer_id).get().map_err(|e| {
@ -366,7 +380,7 @@ impl RenderCommand {
)?,
count_buffer_offset,
max_count,
indexed,
family,
}
}
@ -473,11 +487,16 @@ pub enum ArcRenderCommand {
base_vertex: i32,
first_instance: u32,
},
DrawMeshTasks {
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
},
DrawIndirect {
buffer: Arc<Buffer>,
offset: BufferAddress,
count: u32,
indexed: bool,
family: DrawCommandFamily,
/// This limit is only populated for commands in a [`RenderBundle`].
vertex_or_index_limit: u64,
@ -490,7 +509,7 @@ pub enum ArcRenderCommand {
count_buffer: Arc<Buffer>,
count_buffer_offset: BufferAddress,
max_count: u32,
indexed: bool,
family: DrawCommandFamily,
},
PushDebugGroup {
#[cfg_attr(not(any(feature = "serde", feature = "replay")), allow(dead_code))]

View File

@ -16,8 +16,9 @@ use crate::{
id::{self, AdapterId, DeviceId, QueueId, SurfaceId},
instance::{self, Adapter, Surface},
pipeline::{
self, ResolvedComputePipelineDescriptor, ResolvedFragmentState,
ResolvedProgrammableStageDescriptor, ResolvedRenderPipelineDescriptor, ResolvedVertexState,
self, RenderPipelineVertexProcessor, ResolvedComputePipelineDescriptor,
ResolvedFragmentState, ResolvedGeneralRenderPipelineDescriptor, ResolvedMeshState,
ResolvedProgrammableStageDescriptor, ResolvedTaskState, ResolvedVertexState,
},
present,
resource::{
@ -1346,17 +1347,55 @@ impl Global {
let fid = hub.render_pipelines.prepare(id_in);
let device = self.hub.devices.get(device_id);
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
trace.add(trace::Action::CreateRenderPipeline {
id: fid.id(),
desc: desc.clone(),
});
}
self.device_create_general_render_pipeline(desc.clone().into(), device, fid)
}
pub fn device_create_mesh_pipeline(
&self,
device_id: DeviceId,
desc: &pipeline::MeshPipelineDescriptor,
id_in: Option<id::RenderPipelineId>,
) -> (
id::RenderPipelineId,
Option<pipeline::CreateRenderPipelineError>,
) {
let hub = &self.hub;
let fid = hub.render_pipelines.prepare(id_in);
let device = self.hub.devices.get(device_id);
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
trace.add(trace::Action::CreateMeshPipeline {
id: fid.id(),
desc: desc.clone(),
});
}
self.device_create_general_render_pipeline(desc.clone().into(), device, fid)
}
fn device_create_general_render_pipeline(
&self,
desc: pipeline::GeneralRenderPipelineDescriptor,
device: Arc<crate::device::resource::Device>,
fid: crate::registry::FutureId<Fallible<pipeline::RenderPipeline>>,
) -> (
id::RenderPipelineId,
Option<pipeline::CreateRenderPipelineError>,
) {
profiling::scope!("Device::create_general_render_pipeline");
let hub = &self.hub;
let error = 'error: {
let device = self.hub.devices.get(device_id);
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
trace.add(trace::Action::CreateRenderPipeline {
id: fid.id(),
desc: desc.clone(),
});
}
if let Err(e) = device.check_is_valid() {
break 'error e.into();
}
@ -1379,31 +1418,83 @@ impl Global {
Err(e) => break 'error e.into(),
};
let vertex = {
let module = hub
.shader_modules
.get(desc.vertex.stage.module)
.get()
.map_err(|e| pipeline::CreateRenderPipelineError::Stage {
stage: wgt::ShaderStages::VERTEX,
error: e.into(),
});
let module = match module {
Ok(module) => module,
Err(e) => break 'error e,
};
let stage = ResolvedProgrammableStageDescriptor {
module,
entry_point: desc.vertex.stage.entry_point.clone(),
constants: desc.vertex.stage.constants.clone(),
zero_initialize_workgroup_memory: desc
.vertex
.stage
.zero_initialize_workgroup_memory,
};
ResolvedVertexState {
stage,
buffers: desc.vertex.buffers.clone(),
let vertex = match desc.vertex {
RenderPipelineVertexProcessor::Vertex(ref vertex) => {
let module = hub
.shader_modules
.get(vertex.stage.module)
.get()
.map_err(|e| pipeline::CreateRenderPipelineError::Stage {
stage: wgt::ShaderStages::VERTEX,
error: e.into(),
});
let module = match module {
Ok(module) => module,
Err(e) => break 'error e,
};
let stage = ResolvedProgrammableStageDescriptor {
module,
entry_point: vertex.stage.entry_point.clone(),
constants: vertex.stage.constants.clone(),
zero_initialize_workgroup_memory: vertex
.stage
.zero_initialize_workgroup_memory,
};
RenderPipelineVertexProcessor::Vertex(ResolvedVertexState {
stage,
buffers: vertex.buffers.clone(),
})
}
RenderPipelineVertexProcessor::Mesh(ref task, ref mesh) => {
let task_module = if let Some(task) = task {
let module = hub
.shader_modules
.get(task.stage.module)
.get()
.map_err(|e| pipeline::CreateRenderPipelineError::Stage {
stage: wgt::ShaderStages::VERTEX,
error: e.into(),
});
let module = match module {
Ok(module) => module,
Err(e) => break 'error e,
};
let state = ResolvedProgrammableStageDescriptor {
module,
entry_point: task.stage.entry_point.clone(),
constants: task.stage.constants.clone(),
zero_initialize_workgroup_memory: task
.stage
.zero_initialize_workgroup_memory,
};
Some(ResolvedTaskState { stage: state })
} else {
None
};
let mesh_module =
hub.shader_modules
.get(mesh.stage.module)
.get()
.map_err(|e| pipeline::CreateRenderPipelineError::Stage {
stage: wgt::ShaderStages::MESH,
error: e.into(),
});
let mesh_module = match mesh_module {
Ok(module) => module,
Err(e) => break 'error e,
};
let mesh_stage = ResolvedProgrammableStageDescriptor {
module: mesh_module,
entry_point: mesh.stage.entry_point.clone(),
constants: mesh.stage.constants.clone(),
zero_initialize_workgroup_memory: mesh
.stage
.zero_initialize_workgroup_memory,
};
RenderPipelineVertexProcessor::Mesh(
task_module,
ResolvedMeshState { stage: mesh_stage },
)
}
};
@ -1424,10 +1515,7 @@ impl Global {
module,
entry_point: state.stage.entry_point.clone(),
constants: state.stage.constants.clone(),
zero_initialize_workgroup_memory: desc
.vertex
.stage
.zero_initialize_workgroup_memory,
zero_initialize_workgroup_memory: state.stage.zero_initialize_workgroup_memory,
};
Some(ResolvedFragmentState {
stage,
@ -1437,7 +1525,7 @@ impl Global {
None
};
let desc = ResolvedRenderPipelineDescriptor {
let desc = ResolvedGeneralRenderPipelineDescriptor {
label: desc.label.clone(),
layout,
vertex,

View File

@ -3472,7 +3472,7 @@ impl Device {
pub(crate) fn create_render_pipeline(
self: &Arc<Self>,
desc: pipeline::ResolvedRenderPipelineDescriptor,
desc: pipeline::ResolvedGeneralRenderPipelineDescriptor,
) -> Result<Arc<pipeline::RenderPipeline>, pipeline::CreateRenderPipelineError> {
use wgt::TextureFormatFeatureFlags as Tfff;
@ -3513,127 +3513,137 @@ impl Device {
let mut io = validation::StageIo::default();
let mut validated_stages = wgt::ShaderStages::empty();
let mut vertex_steps = Vec::with_capacity(desc.vertex.buffers.len());
let mut vertex_buffers = Vec::with_capacity(desc.vertex.buffers.len());
let mut total_attributes = 0;
let mut vertex_steps;
let mut vertex_buffers;
let mut total_attributes;
let mut shader_expects_dual_source_blending = false;
let mut pipeline_expects_dual_source_blending = false;
for (i, vb_state) in desc.vertex.buffers.iter().enumerate() {
// https://gpuweb.github.io/gpuweb/#abstract-opdef-validating-gpuvertexbufferlayout
if let pipeline::RenderPipelineVertexProcessor::Vertex(ref vertex) = desc.vertex {
vertex_steps = Vec::with_capacity(vertex.buffers.len());
vertex_buffers = Vec::with_capacity(vertex.buffers.len());
total_attributes = 0;
shader_expects_dual_source_blending = false;
pipeline_expects_dual_source_blending = false;
for (i, vb_state) in vertex.buffers.iter().enumerate() {
// https://gpuweb.github.io/gpuweb/#abstract-opdef-validating-gpuvertexbufferlayout
if vb_state.array_stride > self.limits.max_vertex_buffer_array_stride as u64 {
return Err(pipeline::CreateRenderPipelineError::VertexStrideTooLarge {
index: i as u32,
given: vb_state.array_stride as u32,
limit: self.limits.max_vertex_buffer_array_stride,
});
}
if vb_state.array_stride % wgt::VERTEX_ALIGNMENT != 0 {
return Err(pipeline::CreateRenderPipelineError::UnalignedVertexStride {
index: i as u32,
if vb_state.array_stride > self.limits.max_vertex_buffer_array_stride as u64 {
return Err(pipeline::CreateRenderPipelineError::VertexStrideTooLarge {
index: i as u32,
given: vb_state.array_stride as u32,
limit: self.limits.max_vertex_buffer_array_stride,
});
}
if vb_state.array_stride % wgt::VERTEX_ALIGNMENT != 0 {
return Err(pipeline::CreateRenderPipelineError::UnalignedVertexStride {
index: i as u32,
stride: vb_state.array_stride,
});
}
let max_stride = if vb_state.array_stride == 0 {
self.limits.max_vertex_buffer_array_stride as u64
} else {
vb_state.array_stride
};
let mut last_stride = 0;
for attribute in vb_state.attributes.iter() {
let attribute_stride = attribute.offset + attribute.format.size();
if attribute_stride > max_stride {
return Err(
pipeline::CreateRenderPipelineError::VertexAttributeStrideTooLarge {
location: attribute.shader_location,
given: attribute_stride as u32,
limit: max_stride as u32,
},
);
}
let required_offset_alignment = attribute.format.size().min(4);
if attribute.offset % required_offset_alignment != 0 {
return Err(
pipeline::CreateRenderPipelineError::InvalidVertexAttributeOffset {
location: attribute.shader_location,
offset: attribute.offset,
},
);
}
if attribute.shader_location >= self.limits.max_vertex_attributes {
return Err(
pipeline::CreateRenderPipelineError::TooManyVertexAttributes {
given: attribute.shader_location,
limit: self.limits.max_vertex_attributes,
},
);
}
last_stride = last_stride.max(attribute_stride);
}
vertex_steps.push(pipeline::VertexStep {
stride: vb_state.array_stride,
last_stride,
mode: vb_state.step_mode,
});
if vb_state.attributes.is_empty() {
continue;
}
vertex_buffers.push(hal::VertexBufferLayout {
array_stride: vb_state.array_stride,
step_mode: vb_state.step_mode,
attributes: vb_state.attributes.as_ref(),
});
for attribute in vb_state.attributes.iter() {
if attribute.offset >= 0x10000000 {
return Err(
pipeline::CreateRenderPipelineError::InvalidVertexAttributeOffset {
location: attribute.shader_location,
offset: attribute.offset,
},
);
}
if let wgt::VertexFormat::Float64
| wgt::VertexFormat::Float64x2
| wgt::VertexFormat::Float64x3
| wgt::VertexFormat::Float64x4 = attribute.format
{
self.require_features(wgt::Features::VERTEX_ATTRIBUTE_64BIT)?;
}
let previous = io.insert(
attribute.shader_location,
validation::InterfaceVar::vertex_attribute(attribute.format),
);
if previous.is_some() {
return Err(pipeline::CreateRenderPipelineError::ShaderLocationClash(
attribute.shader_location,
));
}
}
total_attributes += vb_state.attributes.len();
}
if vertex_buffers.len() > self.limits.max_vertex_buffers as usize {
return Err(pipeline::CreateRenderPipelineError::TooManyVertexBuffers {
given: vertex_buffers.len() as u32,
limit: self.limits.max_vertex_buffers,
});
}
let max_stride = if vb_state.array_stride == 0 {
self.limits.max_vertex_buffer_array_stride as u64
} else {
vb_state.array_stride
};
let mut last_stride = 0;
for attribute in vb_state.attributes.iter() {
let attribute_stride = attribute.offset + attribute.format.size();
if attribute_stride > max_stride {
return Err(
pipeline::CreateRenderPipelineError::VertexAttributeStrideTooLarge {
location: attribute.shader_location,
given: attribute_stride as u32,
limit: max_stride as u32,
},
);
}
let required_offset_alignment = attribute.format.size().min(4);
if attribute.offset % required_offset_alignment != 0 {
return Err(
pipeline::CreateRenderPipelineError::InvalidVertexAttributeOffset {
location: attribute.shader_location,
offset: attribute.offset,
},
);
}
if attribute.shader_location >= self.limits.max_vertex_attributes {
return Err(
pipeline::CreateRenderPipelineError::TooManyVertexAttributes {
given: attribute.shader_location,
limit: self.limits.max_vertex_attributes,
},
);
}
last_stride = last_stride.max(attribute_stride);
}
vertex_steps.push(pipeline::VertexStep {
stride: vb_state.array_stride,
last_stride,
mode: vb_state.step_mode,
});
if vb_state.attributes.is_empty() {
continue;
}
vertex_buffers.push(hal::VertexBufferLayout {
array_stride: vb_state.array_stride,
step_mode: vb_state.step_mode,
attributes: vb_state.attributes.as_ref(),
});
for attribute in vb_state.attributes.iter() {
if attribute.offset >= 0x10000000 {
return Err(
pipeline::CreateRenderPipelineError::InvalidVertexAttributeOffset {
location: attribute.shader_location,
offset: attribute.offset,
},
);
}
if let wgt::VertexFormat::Float64
| wgt::VertexFormat::Float64x2
| wgt::VertexFormat::Float64x3
| wgt::VertexFormat::Float64x4 = attribute.format
{
self.require_features(wgt::Features::VERTEX_ATTRIBUTE_64BIT)?;
}
let previous = io.insert(
attribute.shader_location,
validation::InterfaceVar::vertex_attribute(attribute.format),
if total_attributes > self.limits.max_vertex_attributes as usize {
return Err(
pipeline::CreateRenderPipelineError::TooManyVertexAttributes {
given: total_attributes as u32,
limit: self.limits.max_vertex_attributes,
},
);
if previous.is_some() {
return Err(pipeline::CreateRenderPipelineError::ShaderLocationClash(
attribute.shader_location,
));
}
}
total_attributes += vb_state.attributes.len();
}
if vertex_buffers.len() > self.limits.max_vertex_buffers as usize {
return Err(pipeline::CreateRenderPipelineError::TooManyVertexBuffers {
given: vertex_buffers.len() as u32,
limit: self.limits.max_vertex_buffers,
});
}
if total_attributes > self.limits.max_vertex_attributes as usize {
return Err(
pipeline::CreateRenderPipelineError::TooManyVertexAttributes {
given: total_attributes as u32,
limit: self.limits.max_vertex_attributes,
},
);
}
} else {
vertex_steps = Vec::new();
vertex_buffers = Vec::new();
};
if desc.primitive.strip_index_format.is_some() && !desc.primitive.topology.is_strip() {
return Err(
@ -3843,44 +3853,132 @@ impl Device {
sc
};
let vertex_entry_point_name;
let vertex_stage = {
let stage_desc = &desc.vertex.stage;
let stage = wgt::ShaderStages::VERTEX;
let mut vertex_stage = None;
let mut task_stage = None;
let mut mesh_stage = None;
let mut _vertex_entry_point_name = String::new();
let mut _task_entry_point_name = String::new();
let mut _mesh_entry_point_name = String::new();
match desc.vertex {
pipeline::RenderPipelineVertexProcessor::Vertex(ref vertex) => {
vertex_stage = {
let stage_desc = &vertex.stage;
let stage = wgt::ShaderStages::VERTEX;
let vertex_shader_module = &stage_desc.module;
vertex_shader_module.same_device(self)?;
let vertex_shader_module = &stage_desc.module;
vertex_shader_module.same_device(self)?;
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error };
let stage_err =
|error| pipeline::CreateRenderPipelineError::Stage { stage, error };
vertex_entry_point_name = vertex_shader_module
.finalize_entry_point_name(
stage,
stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()),
)
.map_err(stage_err)?;
_vertex_entry_point_name = vertex_shader_module
.finalize_entry_point_name(
stage,
stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()),
)
.map_err(stage_err)?;
if let Some(ref interface) = vertex_shader_module.interface {
io = interface
.check_stage(
&mut binding_layout_source,
&mut shader_binding_sizes,
&vertex_entry_point_name,
stage,
io,
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
)
.map_err(stage_err)?;
validated_stages |= stage;
if let Some(ref interface) = vertex_shader_module.interface {
io = interface
.check_stage(
&mut binding_layout_source,
&mut shader_binding_sizes,
&_vertex_entry_point_name,
stage,
io,
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
)
.map_err(stage_err)?;
validated_stages |= stage;
}
Some(hal::ProgrammableStage {
module: vertex_shader_module.raw(),
entry_point: &_vertex_entry_point_name,
constants: &stage_desc.constants,
zero_initialize_workgroup_memory: stage_desc
.zero_initialize_workgroup_memory,
})
};
}
pipeline::RenderPipelineVertexProcessor::Mesh(ref task, ref mesh) => {
task_stage = if let Some(task) = task {
let stage_desc = &task.stage;
let stage = wgt::ShaderStages::TASK;
let task_shader_module = &stage_desc.module;
task_shader_module.same_device(self)?;
hal::ProgrammableStage {
module: vertex_shader_module.raw(),
entry_point: &vertex_entry_point_name,
constants: &stage_desc.constants,
zero_initialize_workgroup_memory: stage_desc.zero_initialize_workgroup_memory,
let stage_err =
|error| pipeline::CreateRenderPipelineError::Stage { stage, error };
_task_entry_point_name = task_shader_module
.finalize_entry_point_name(
stage,
stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()),
)
.map_err(stage_err)?;
if let Some(ref interface) = task_shader_module.interface {
io = interface
.check_stage(
&mut binding_layout_source,
&mut shader_binding_sizes,
&_task_entry_point_name,
stage,
io,
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
)
.map_err(stage_err)?;
validated_stages |= stage;
}
Some(hal::ProgrammableStage {
module: task_shader_module.raw(),
entry_point: &_task_entry_point_name,
constants: &stage_desc.constants,
zero_initialize_workgroup_memory: stage_desc
.zero_initialize_workgroup_memory,
})
} else {
None
};
mesh_stage = {
let stage_desc = &mesh.stage;
let stage = wgt::ShaderStages::MESH;
let mesh_shader_module = &stage_desc.module;
mesh_shader_module.same_device(self)?;
let stage_err =
|error| pipeline::CreateRenderPipelineError::Stage { stage, error };
_mesh_entry_point_name = mesh_shader_module
.finalize_entry_point_name(
stage,
stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()),
)
.map_err(stage_err)?;
if let Some(ref interface) = mesh_shader_module.interface {
io = interface
.check_stage(
&mut binding_layout_source,
&mut shader_binding_sizes,
&_mesh_entry_point_name,
stage,
io,
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
)
.map_err(stage_err)?;
validated_stages |= stage;
}
Some(hal::ProgrammableStage {
module: mesh_shader_module.raw(),
entry_point: &_mesh_entry_point_name,
constants: &stage_desc.constants,
zero_initialize_workgroup_memory: stage_desc
.zero_initialize_workgroup_memory,
})
};
}
};
}
let fragment_entry_point_name;
let fragment_stage = match desc.fragment {
@ -4029,20 +4127,29 @@ impl Device {
None => None,
};
let pipeline_desc = hal::RenderPipelineDescriptor {
label: desc.label.to_hal(self.instance_flags),
layout: pipeline_layout.raw(),
vertex_buffers: &vertex_buffers,
vertex_stage,
primitive: desc.primitive,
depth_stencil: desc.depth_stencil.clone(),
multisample: desc.multisample,
fragment_stage,
color_targets,
multiview: desc.multiview,
cache: cache.as_ref().map(|it| it.raw()),
};
let raw =
let is_mesh = mesh_stage.is_some();
let raw = {
let pipeline_desc = hal::RenderPipelineDescriptor {
label: desc.label.to_hal(self.instance_flags),
layout: pipeline_layout.raw(),
vertex_processor: match vertex_stage {
Some(vertex_stage) => hal::VertexProcessor::Standard {
vertex_buffers: &vertex_buffers,
vertex_stage,
},
None => hal::VertexProcessor::Mesh {
task_stage,
mesh_stage: mesh_stage.unwrap(),
},
},
primitive: desc.primitive,
depth_stencil: desc.depth_stencil.clone(),
multisample: desc.multisample,
fragment_stage,
color_targets,
multiview: desc.multiview,
cache: cache.as_ref().map(|it| it.raw()),
};
unsafe { self.raw().create_render_pipeline(&pipeline_desc) }.map_err(
|err| match err {
hal::PipelineError::Device(error) => {
@ -4061,7 +4168,8 @@ impl Device {
pipeline::CreateRenderPipelineError::PipelineConstants { stage, error }
}
},
)?;
)?
};
let pass_context = RenderPassContext {
attachments: AttachmentData {
@ -4095,10 +4203,19 @@ impl Device {
flags |= pipeline::PipelineFlags::WRITES_STENCIL;
}
}
let shader_modules = {
let mut shader_modules = ArrayVec::new();
shader_modules.push(desc.vertex.stage.module);
match desc.vertex {
pipeline::RenderPipelineVertexProcessor::Vertex(vertex) => {
shader_modules.push(vertex.stage.module)
}
pipeline::RenderPipelineVertexProcessor::Mesh(task, mesh) => {
if let Some(task) = task {
shader_modules.push(task.stage.module);
}
shader_modules.push(mesh.stage.module);
}
}
shader_modules.extend(desc.fragment.map(|f| f.stage.module));
shader_modules
};
@ -4115,6 +4232,7 @@ impl Device {
late_sized_buffer_groups,
label: desc.label.to_string(),
tracking_data: TrackingData::new(self.tracker_indices.render_pipelines.clone()),
is_mesh,
};
let pipeline = Arc::new(pipeline);

View File

@ -103,6 +103,10 @@ pub enum Action<'a> {
id: id::RenderPipelineId,
desc: crate::pipeline::RenderPipelineDescriptor<'a>,
},
CreateMeshPipeline {
id: id::RenderPipelineId,
desc: crate::pipeline::MeshPipelineDescriptor<'a>,
},
DestroyRenderPipeline(id::RenderPipelineId),
CreatePipelineCache {
id: id::PipelineCacheId,

View File

@ -919,7 +919,7 @@ impl DrawBatcher {
device: &Device,
src_buffer: &Arc<crate::resource::Buffer>,
offset: u64,
indexed: bool,
family: crate::command::DrawCommandFamily,
vertex_or_index_limit: u64,
instance_limit: u64,
) -> Result<(usize, u64), DeviceError> {
@ -929,7 +929,7 @@ impl DrawBatcher {
} else {
0
};
let stride = extra + crate::command::get_stride_of_indirect_args(indexed);
let stride = extra + crate::command::get_stride_of_indirect_args(family);
let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
.get_dst_subrange(stride, &mut self.current_dst_entry)?;
@ -941,7 +941,7 @@ impl DrawBatcher {
let src_buffer_tracker_index = src_buffer.tracker_index();
let entry = MetadataEntry::new(
indexed,
family == crate::command::DrawCommandFamily::DrawIndexed,
src_offset,
dst_offset,
vertex_or_index_limit,

View File

@ -402,6 +402,33 @@ pub struct FragmentState<'a, SM = ShaderModuleId> {
/// cbindgen:ignore
pub type ResolvedFragmentState<'a> = FragmentState<'a, Arc<ShaderModule>>;
/// Describes the task shader in a mesh shader pipeline.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TaskState<'a, SM = ShaderModuleId> {
/// The compiled task stage and its entry point.
pub stage: ProgrammableStageDescriptor<'a, SM>,
}
pub type ResolvedTaskState<'a> = TaskState<'a, Arc<ShaderModule>>;
/// Describes the mesh shader in a mesh shader pipeline.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MeshState<'a, SM = ShaderModuleId> {
/// The compiled mesh stage and its entry point.
pub stage: ProgrammableStageDescriptor<'a, SM>,
}
pub type ResolvedMeshState<'a> = MeshState<'a, Arc<ShaderModule>>;
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) enum RenderPipelineVertexProcessor<'a, SM = ShaderModuleId> {
Vertex(VertexState<'a, SM>),
Mesh(Option<TaskState<'a, SM>>, MeshState<'a, SM>),
}
/// Describes a render (graphics) pipeline.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
@ -433,10 +460,109 @@ pub struct RenderPipelineDescriptor<
/// The pipeline cache to use when creating this pipeline.
pub cache: Option<PLC>,
}
/// Describes a mesh shader pipeline.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MeshPipelineDescriptor<
'a,
PLL = PipelineLayoutId,
SM = ShaderModuleId,
PLC = PipelineCacheId,
> {
pub label: Label<'a>,
/// The layout of bind groups for this pipeline.
pub layout: Option<PLL>,
/// The task processing state for this pipeline.
pub task: Option<TaskState<'a, SM>>,
/// The mesh processing state for this pipeline
pub mesh: MeshState<'a, SM>,
/// The properties of the pipeline at the primitive assembly and rasterization level.
#[cfg_attr(feature = "serde", serde(default))]
pub primitive: wgt::PrimitiveState,
/// The effect of draw calls on the depth and stencil aspects of the output target, if any.
#[cfg_attr(feature = "serde", serde(default))]
pub depth_stencil: Option<wgt::DepthStencilState>,
/// The multi-sampling properties of the pipeline.
#[cfg_attr(feature = "serde", serde(default))]
pub multisample: wgt::MultisampleState,
/// The fragment processing state for this pipeline.
pub fragment: Option<FragmentState<'a, SM>>,
/// If the pipeline will be used with a multiview render pass, this indicates how many array
/// layers the attachments will have.
pub multiview: Option<NonZeroU32>,
/// The pipeline cache to use when creating this pipeline.
pub cache: Option<PLC>,
}
/// Describes a render (graphics) pipeline.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct GeneralRenderPipelineDescriptor<
'a,
PLL = PipelineLayoutId,
SM = ShaderModuleId,
PLC = PipelineCacheId,
> {
pub label: Label<'a>,
/// The layout of bind groups for this pipeline.
pub layout: Option<PLL>,
/// The vertex processing state for this pipeline.
pub vertex: RenderPipelineVertexProcessor<'a, SM>,
/// The properties of the pipeline at the primitive assembly and rasterization level.
#[cfg_attr(feature = "serde", serde(default))]
pub primitive: wgt::PrimitiveState,
/// The effect of draw calls on the depth and stencil aspects of the output target, if any.
#[cfg_attr(feature = "serde", serde(default))]
pub depth_stencil: Option<wgt::DepthStencilState>,
/// The multi-sampling properties of the pipeline.
#[cfg_attr(feature = "serde", serde(default))]
pub multisample: wgt::MultisampleState,
/// The fragment processing state for this pipeline.
pub fragment: Option<FragmentState<'a, SM>>,
/// If the pipeline will be used with a multiview render pass, this indicates how many array
/// layers the attachments will have.
pub multiview: Option<NonZeroU32>,
/// The pipeline cache to use when creating this pipeline.
pub cache: Option<PLC>,
}
impl<'a, PLL, SM, PLC> From<RenderPipelineDescriptor<'a, PLL, SM, PLC>>
for GeneralRenderPipelineDescriptor<'a, PLL, SM, PLC>
{
fn from(value: RenderPipelineDescriptor<'a, PLL, SM, PLC>) -> Self {
Self {
label: value.label,
layout: value.layout,
vertex: RenderPipelineVertexProcessor::Vertex(value.vertex),
primitive: value.primitive,
depth_stencil: value.depth_stencil,
multisample: value.multisample,
fragment: value.fragment,
multiview: value.multiview,
cache: value.cache,
}
}
}
impl<'a, PLL, SM, PLC> From<MeshPipelineDescriptor<'a, PLL, SM, PLC>>
for GeneralRenderPipelineDescriptor<'a, PLL, SM, PLC>
{
fn from(value: MeshPipelineDescriptor<'a, PLL, SM, PLC>) -> Self {
Self {
label: value.label,
layout: value.layout,
vertex: RenderPipelineVertexProcessor::Mesh(value.task, value.mesh),
primitive: value.primitive,
depth_stencil: value.depth_stencil,
multisample: value.multisample,
fragment: value.fragment,
multiview: value.multiview,
cache: value.cache,
}
}
}
/// cbindgen:ignore
pub type ResolvedRenderPipelineDescriptor<'a> =
RenderPipelineDescriptor<'a, Arc<PipelineLayout>, Arc<ShaderModule>, Arc<PipelineCache>>;
pub(crate) type ResolvedGeneralRenderPipelineDescriptor<'a> =
GeneralRenderPipelineDescriptor<'a, Arc<PipelineLayout>, Arc<ShaderModule>, Arc<PipelineCache>>;
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
@ -649,6 +775,8 @@ pub struct RenderPipeline {
/// The `label` from the descriptor used to create the resource.
pub(crate) label: String,
pub(crate) tracking_data: TrackingData,
/// Whether this is a mesh shader pipeline
pub(crate) is_mesh: bool,
}
impl Drop for RenderPipeline {

View File

@ -1288,6 +1288,7 @@ impl Interface {
)
}
naga::ShaderStage::Compute => (false, 0),
// TODO: add validation for these, see https://github.com/gfx-rs/wgpu/issues/8003
naga::ShaderStage::Task | naga::ShaderStage::Mesh => {
unreachable!()
}

View File

@ -254,13 +254,15 @@ impl<A: hal::Api> Example<A> {
let pipeline_desc = hal::RenderPipelineDescriptor {
label: None,
layout: &pipeline_layout,
vertex_stage: hal::ProgrammableStage {
module: &shader,
entry_point: "vs_main",
constants: &constants,
zero_initialize_workgroup_memory: true,
vertex_processor: hal::VertexProcessor::Standard {
vertex_stage: hal::ProgrammableStage {
module: &shader,
entry_point: "vs_main",
constants: &constants,
zero_initialize_workgroup_memory: true,
},
vertex_buffers: &[],
},
vertex_buffers: &[],
fragment_stage: Some(hal::ProgrammableStage {
module: &shader,
entry_point: "fs_main",

View File

@ -613,6 +613,12 @@ impl super::Adapter {
// store buffer sizes using 32 bit ints (a situation we have already encountered with vulkan).
max_buffer_size: i32::MAX as u64,
max_non_sampler_bindings: 1_000_000,
max_task_workgroup_total_count: 0,
max_task_workgroups_per_dimension: 0,
max_mesh_multiview_count: 0,
max_mesh_output_layers: 0,
max_blas_primitive_count: if supports_ray_tracing {
1 << 29 // 2^29
} else {

View File

@ -1747,8 +1747,16 @@ impl crate::Device for super::Device {
let (topology_class, topology) = conv::map_topology(desc.primitive.topology);
let mut shader_stages = wgt::ShaderStages::VERTEX;
let (vertex_stage_desc, vertex_buffers_desc) = match &desc.vertex_processor {
crate::VertexProcessor::Standard {
vertex_buffers,
vertex_stage,
} => (vertex_stage, *vertex_buffers),
crate::VertexProcessor::Mesh { .. } => unreachable!(),
};
let blob_vs = self.load_shader(
&desc.vertex_stage,
vertex_stage_desc,
desc.layout,
naga::ShaderStage::Vertex,
desc.fragment_stage.as_ref(),
@ -1765,7 +1773,7 @@ impl crate::Device for super::Device {
let mut input_element_descs = Vec::new();
for (i, (stride, vbuf)) in vertex_strides
.iter_mut()
.zip(desc.vertex_buffers)
.zip(vertex_buffers_desc)
.enumerate()
{
*stride = NonZeroU32::new(vbuf.array_stride as u32);
@ -1919,17 +1927,6 @@ impl crate::Device for super::Device {
})
}
unsafe fn create_mesh_pipeline(
&self,
_desc: &crate::MeshPipelineDescriptor<
<Self::A as crate::Api>::PipelineLayout,
<Self::A as crate::Api>::ShaderModule,
<Self::A as crate::Api>::PipelineCache,
>,
) -> Result<<Self::A as crate::Api>::RenderPipeline, crate::PipelineError> {
unreachable!()
}
unsafe fn destroy_render_pipeline(&self, _pipeline: super::RenderPipeline) {
self.counters.render_pipelines.sub(1);
}

View File

@ -4,10 +4,10 @@ use crate::{
AccelerationStructureBuildSizes, AccelerationStructureDescriptor, Api, BindGroupDescriptor,
BindGroupLayoutDescriptor, BufferDescriptor, BufferMapping, CommandEncoderDescriptor,
ComputePipelineDescriptor, Device, DeviceError, FenceValue,
GetAccelerationStructureBuildSizesDescriptor, Label, MemoryRange, MeshPipelineDescriptor,
PipelineCacheDescriptor, PipelineCacheError, PipelineError, PipelineLayoutDescriptor,
RenderPipelineDescriptor, SamplerDescriptor, ShaderError, ShaderInput, ShaderModuleDescriptor,
TextureDescriptor, TextureViewDescriptor, TlasInstance,
GetAccelerationStructureBuildSizesDescriptor, Label, MemoryRange, PipelineCacheDescriptor,
PipelineCacheError, PipelineError, PipelineLayoutDescriptor, RenderPipelineDescriptor,
SamplerDescriptor, ShaderError, ShaderInput, ShaderModuleDescriptor, TextureDescriptor,
TextureViewDescriptor, TlasInstance,
};
use super::{
@ -100,14 +100,6 @@ pub trait DynDevice: DynResource {
dyn DynPipelineCache,
>,
) -> Result<Box<dyn DynRenderPipeline>, PipelineError>;
unsafe fn create_mesh_pipeline(
&self,
desc: &MeshPipelineDescriptor<
dyn DynPipelineLayout,
dyn DynShaderModule,
dyn DynPipelineCache,
>,
) -> Result<Box<dyn DynRenderPipeline>, PipelineError>;
unsafe fn destroy_render_pipeline(&self, pipeline: Box<dyn DynRenderPipeline>);
unsafe fn create_compute_pipeline(
@ -394,8 +386,22 @@ impl<D: Device + DynResource> DynDevice for D {
let desc = RenderPipelineDescriptor {
label: desc.label,
layout: desc.layout.expect_downcast_ref(),
vertex_buffers: desc.vertex_buffers,
vertex_stage: desc.vertex_stage.clone().expect_downcast(),
vertex_processor: match &desc.vertex_processor {
crate::VertexProcessor::Standard {
vertex_buffers,
vertex_stage,
} => crate::VertexProcessor::Standard {
vertex_buffers,
vertex_stage: vertex_stage.clone().expect_downcast(),
},
crate::VertexProcessor::Mesh {
task_stage: task,
mesh_stage: mesh,
} => crate::VertexProcessor::Mesh {
task_stage: task.as_ref().map(|a| a.clone().expect_downcast()),
mesh_stage: mesh.clone().expect_downcast(),
},
},
primitive: desc.primitive,
depth_stencil: desc.depth_stencil.clone(),
multisample: desc.multisample,
@ -409,32 +415,6 @@ impl<D: Device + DynResource> DynDevice for D {
.map(|b| -> Box<dyn DynRenderPipeline> { Box::new(b) })
}
unsafe fn create_mesh_pipeline(
&self,
desc: &MeshPipelineDescriptor<
dyn DynPipelineLayout,
dyn DynShaderModule,
dyn DynPipelineCache,
>,
) -> Result<Box<dyn DynRenderPipeline>, PipelineError> {
let desc = MeshPipelineDescriptor {
label: desc.label,
layout: desc.layout.expect_downcast_ref(),
task_stage: desc.task_stage.clone().map(|f| f.expect_downcast()),
mesh_stage: desc.mesh_stage.clone().expect_downcast(),
primitive: desc.primitive,
depth_stencil: desc.depth_stencil.clone(),
multisample: desc.multisample,
fragment_stage: desc.fragment_stage.clone().map(|f| f.expect_downcast()),
color_targets: desc.color_targets,
multiview: desc.multiview,
cache: desc.cache.map(|c| c.expect_downcast_ref()),
};
unsafe { D::create_mesh_pipeline(self, &desc) }
.map(|b| -> Box<dyn DynRenderPipeline> { Box::new(b) })
}
unsafe fn destroy_render_pipeline(&self, pipeline: Box<dyn DynRenderPipeline>) {
unsafe { D::destroy_render_pipeline(self, pipeline.unbox()) };
}

View File

@ -801,6 +801,12 @@ impl super::Adapter {
max_compute_workgroups_per_dimension,
max_buffer_size: i32::MAX as u64,
max_non_sampler_bindings: u32::MAX,
max_task_workgroup_total_count: 0,
max_task_workgroups_per_dimension: 0,
max_mesh_multiview_count: 0,
max_mesh_output_layers: 0,
max_blas_primitive_count: 0,
max_blas_geometry_count: 0,
max_tlas_instance_count: 0,

View File

@ -1363,9 +1363,16 @@ impl crate::Device for super::Device {
super::PipelineCache,
>,
) -> Result<super::RenderPipeline, crate::PipelineError> {
let (vertex_stage, vertex_buffers) = match &desc.vertex_processor {
crate::VertexProcessor::Standard {
vertex_buffers,
ref vertex_stage,
} => (vertex_stage, vertex_buffers),
crate::VertexProcessor::Mesh { .. } => unreachable!(),
};
let gl = &self.shared.context.lock();
let mut shaders = ArrayVec::new();
shaders.push((naga::ShaderStage::Vertex, &desc.vertex_stage));
shaders.push((naga::ShaderStage::Vertex, vertex_stage));
if let Some(ref fs) = desc.fragment_stage {
shaders.push((naga::ShaderStage::Fragment, fs));
}
@ -1375,7 +1382,7 @@ impl crate::Device for super::Device {
let (vertex_buffers, vertex_attributes) = {
let mut buffers = Vec::new();
let mut attributes = Vec::new();
for (index, vb_layout) in desc.vertex_buffers.iter().enumerate() {
for (index, vb_layout) in vertex_buffers.iter().enumerate() {
buffers.push(super::VertexBufferDesc {
step: vb_layout.step_mode,
stride: vb_layout.array_stride as u32,
@ -1430,16 +1437,6 @@ impl crate::Device for super::Device {
alpha_to_coverage_enabled: desc.multisample.alpha_to_coverage_enabled,
})
}
unsafe fn create_mesh_pipeline(
&self,
_desc: &crate::MeshPipelineDescriptor<
<Self::A as crate::Api>::PipelineLayout,
<Self::A as crate::Api>::ShaderModule,
<Self::A as crate::Api>::PipelineCache,
>,
) -> Result<<Self::A as crate::Api>::RenderPipeline, crate::PipelineError> {
unreachable!()
}
unsafe fn destroy_render_pipeline(&self, pipeline: super::RenderPipeline) {
// If the pipeline only has 2 strong references remaining, they're `pipeline` and `program_cache`

View File

@ -931,15 +931,6 @@ pub trait Device: WasmNotSendSync {
<Self::A as Api>::PipelineCache,
>,
) -> Result<<Self::A as Api>::RenderPipeline, PipelineError>;
#[allow(clippy::type_complexity)]
unsafe fn create_mesh_pipeline(
&self,
desc: &MeshPipelineDescriptor<
<Self::A as Api>::PipelineLayout,
<Self::A as Api>::ShaderModule,
<Self::A as Api>::PipelineCache,
>,
) -> Result<<Self::A as Api>::RenderPipeline, PipelineError>;
unsafe fn destroy_render_pipeline(&self, pipeline: <Self::A as Api>::RenderPipeline);
#[allow(clippy::type_complexity)]
@ -2323,6 +2314,20 @@ pub struct VertexBufferLayout<'a> {
pub attributes: &'a [wgt::VertexAttribute],
}
#[derive(Clone, Debug)]
pub enum VertexProcessor<'a, M: DynShaderModule + ?Sized> {
Standard {
/// The format of any vertex buffers used with this pipeline.
vertex_buffers: &'a [VertexBufferLayout<'a>],
/// The vertex stage for this pipeline.
vertex_stage: ProgrammableStage<'a, M>,
},
Mesh {
task_stage: Option<ProgrammableStage<'a, M>>,
mesh_stage: ProgrammableStage<'a, M>,
},
}
/// Describes a render (graphics) pipeline.
#[derive(Clone, Debug)]
pub struct RenderPipelineDescriptor<
@ -2334,37 +2339,8 @@ pub struct RenderPipelineDescriptor<
pub label: Label<'a>,
/// The layout of bind groups for this pipeline.
pub layout: &'a Pl,
/// The format of any vertex buffers used with this pipeline.
pub vertex_buffers: &'a [VertexBufferLayout<'a>],
/// The vertex stage for this pipeline.
pub vertex_stage: ProgrammableStage<'a, M>,
/// The properties of the pipeline at the primitive assembly and rasterization level.
pub primitive: wgt::PrimitiveState,
/// The effect of draw calls on the depth and stencil aspects of the output target, if any.
pub depth_stencil: Option<wgt::DepthStencilState>,
/// The multi-sampling properties of the pipeline.
pub multisample: wgt::MultisampleState,
/// The fragment stage for this pipeline.
pub fragment_stage: Option<ProgrammableStage<'a, M>>,
/// The effect of draw calls on the color aspect of the output target.
pub color_targets: &'a [Option<wgt::ColorTargetState>],
/// If the pipeline will be used with a multiview render pass, this indicates how many array
/// layers the attachments will have.
pub multiview: Option<NonZeroU32>,
/// The cache which will be used and filled when compiling this pipeline
pub cache: Option<&'a Pc>,
}
pub struct MeshPipelineDescriptor<
'a,
Pl: DynPipelineLayout + ?Sized,
M: DynShaderModule + ?Sized,
Pc: DynPipelineCache + ?Sized,
> {
pub label: Label<'a>,
/// The layout of bind groups for this pipeline.
pub layout: &'a Pl,
pub task_stage: Option<ProgrammableStage<'a, M>>,
pub mesh_stage: ProgrammableStage<'a, M>,
/// The vertex processing state(vertex shader + buffers or task + mesh shaders)
pub vertex_processor: VertexProcessor<'a, M>,
/// The properties of the pipeline at the primitive assembly and rasterization level.
pub primitive: wgt::PrimitiveState,
/// The effect of draw calls on the depth and stencil aspects of the output target, if any.

View File

@ -1077,6 +1077,12 @@ impl super::PrivateCapabilities {
max_compute_workgroups_per_dimension: 0xFFFF,
max_buffer_size: self.max_buffer_size,
max_non_sampler_bindings: u32::MAX,
max_task_workgroup_total_count: 0,
max_task_workgroups_per_dimension: 0,
max_mesh_multiview_count: 0,
max_mesh_output_layers: 0,
max_blas_primitive_count: 0, // When added: 2^28 from https://developer.apple.com/documentation/metal/mtlaccelerationstructureusage/extendedlimits
max_blas_geometry_count: 0, // When added: 2^24
max_tlas_instance_count: 0, // When added: 2^24

View File

@ -1056,6 +1056,14 @@ impl crate::Device for super::Device {
super::PipelineCache,
>,
) -> Result<super::RenderPipeline, crate::PipelineError> {
let (desc_vertex_stage, desc_vertex_buffers) = match &desc.vertex_processor {
crate::VertexProcessor::Standard {
vertex_buffers,
vertex_stage,
} => (vertex_stage, *vertex_buffers),
crate::VertexProcessor::Mesh { .. } => unreachable!(),
};
objc::rc::autoreleasepool(|| {
let descriptor = metal::RenderPipelineDescriptor::new();
@ -1074,7 +1082,7 @@ impl crate::Device for super::Device {
// Vertex shader
let (vs_lib, vs_info) = {
let mut vertex_buffer_mappings = Vec::<naga::back::msl::VertexBufferMapping>::new();
for (i, vbl) in desc.vertex_buffers.iter().enumerate() {
for (i, vbl) in desc_vertex_buffers.iter().enumerate() {
let mut attributes = Vec::<naga::back::msl::AttributeMapping>::new();
for attribute in vbl.attributes.iter() {
attributes.push(naga::back::msl::AttributeMapping {
@ -1103,7 +1111,7 @@ impl crate::Device for super::Device {
}
let vs = self.load_shader(
&desc.vertex_stage,
desc_vertex_stage,
&vertex_buffer_mappings,
desc.layout,
primitive_class,
@ -1216,12 +1224,12 @@ impl crate::Device for super::Device {
None => None,
};
if desc.layout.total_counters.vs.buffers + (desc.vertex_buffers.len() as u32)
if desc.layout.total_counters.vs.buffers + (desc_vertex_buffers.len() as u32)
> self.shared.private_caps.max_vertex_buffers
{
let msg = format!(
"pipeline needs too many buffers in the vertex stage: {} vertex and {} layout",
desc.vertex_buffers.len(),
desc_vertex_buffers.len(),
desc.layout.total_counters.vs.buffers
);
return Err(crate::PipelineError::Linkage(
@ -1230,9 +1238,9 @@ impl crate::Device for super::Device {
));
}
if !desc.vertex_buffers.is_empty() {
if !desc_vertex_buffers.is_empty() {
let vertex_descriptor = metal::VertexDescriptor::new();
for (i, vb) in desc.vertex_buffers.iter().enumerate() {
for (i, vb) in desc_vertex_buffers.iter().enumerate() {
let buffer_index =
self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64;
let buffer_desc = vertex_descriptor.layouts().object_at(buffer_index).unwrap();
@ -1318,17 +1326,6 @@ impl crate::Device for super::Device {
})
}
unsafe fn create_mesh_pipeline(
&self,
_desc: &crate::MeshPipelineDescriptor<
<Self::A as crate::Api>::PipelineLayout,
<Self::A as crate::Api>::ShaderModule,
<Self::A as crate::Api>::PipelineCache,
>,
) -> Result<<Self::A as crate::Api>::RenderPipeline, crate::PipelineError> {
unreachable!()
}
unsafe fn destroy_render_pipeline(&self, _pipeline: super::RenderPipeline) {
self.counters.render_pipelines.sub(1);
}

View File

@ -179,6 +179,12 @@ const CAPABILITIES: crate::Capabilities = {
max_subgroup_size: ALLOC_MAX_U32,
max_push_constant_size: ALLOC_MAX_U32,
max_non_sampler_bindings: ALLOC_MAX_U32,
max_task_workgroup_total_count: 0,
max_task_workgroups_per_dimension: 0,
max_mesh_multiview_count: 0,
max_mesh_output_layers: 0,
max_blas_primitive_count: ALLOC_MAX_U32,
max_blas_geometry_count: ALLOC_MAX_U32,
max_tlas_instance_count: ALLOC_MAX_U32,
@ -375,16 +381,6 @@ impl crate::Device for Context {
) -> Result<Resource, crate::PipelineError> {
Ok(Resource)
}
unsafe fn create_mesh_pipeline(
&self,
desc: &crate::MeshPipelineDescriptor<
<Self::A as crate::Api>::PipelineLayout,
<Self::A as crate::Api>::ShaderModule,
<Self::A as crate::Api>::PipelineCache,
>,
) -> Result<<Self::A as crate::Api>::RenderPipeline, crate::PipelineError> {
Ok(Resource)
}
unsafe fn destroy_render_pipeline(&self, pipeline: Resource) {}
unsafe fn create_compute_pipeline(
&self,

View File

@ -932,7 +932,7 @@ pub struct PhysicalDeviceProperties {
/// Additional `vk::PhysicalDevice` properties from the
/// `VK_EXT_mesh_shader` extension.
_mesh_shader: Option<vk::PhysicalDeviceMeshShaderPropertiesEXT<'static>>,
mesh_shader: Option<vk::PhysicalDeviceMeshShaderPropertiesEXT<'static>>,
/// The device API version.
///
@ -1160,6 +1160,20 @@ impl PhysicalDeviceProperties {
let max_compute_workgroups_per_dimension = limits.max_compute_work_group_count[0]
.min(limits.max_compute_work_group_count[1])
.min(limits.max_compute_work_group_count[2]);
let (
max_task_workgroup_total_count,
max_task_workgroups_per_dimension,
max_mesh_multiview_count,
max_mesh_output_layers,
) = match self.mesh_shader {
Some(m) => (
m.max_task_work_group_total_count,
m.max_task_work_group_count.into_iter().min().unwrap(),
m.max_mesh_multiview_view_count,
m.max_mesh_output_layers,
),
None => (0, 0, 0, 0),
};
// Prevent very large buffers on mesa and most android devices.
let is_nvidia = self.properties.vendor_id == crate::auxil::db::nvidia::VENDOR;
@ -1267,6 +1281,12 @@ impl PhysicalDeviceProperties {
max_compute_workgroups_per_dimension,
max_buffer_size,
max_non_sampler_bindings: u32::MAX,
max_task_workgroup_total_count,
max_task_workgroups_per_dimension,
max_mesh_multiview_count,
max_mesh_output_layers,
max_blas_primitive_count,
max_blas_geometry_count,
max_tlas_instance_count,
@ -1401,7 +1421,7 @@ impl super::InstanceShared {
if supports_mesh_shader {
let next = capabilities
._mesh_shader
.mesh_shader
.insert(vk::PhysicalDeviceMeshShaderPropertiesEXT::default());
properties2 = properties2.push_next(next);
}

View File

@ -747,6 +747,12 @@ pub fn map_shader_stage(stage: wgt::ShaderStages) -> vk::ShaderStageFlags {
if stage.contains(wgt::ShaderStages::COMPUTE) {
flags |= vk::ShaderStageFlags::COMPUTE;
}
if stage.contains(wgt::ShaderStages::TASK) {
flags |= vk::ShaderStageFlags::TASK_EXT;
}
if stage.contains(wgt::ShaderStages::MESH) {
flags |= vk::ShaderStageFlags::MESH_EXT;
}
flags
}

View File

@ -1979,25 +1979,32 @@ impl crate::Device for super::Device {
..Default::default()
};
let mut stages = ArrayVec::<_, { crate::MAX_CONCURRENT_SHADER_STAGES }>::new();
let mut vertex_buffers = Vec::with_capacity(desc.vertex_buffers.len());
let mut vertex_buffers = Vec::new();
let mut vertex_attributes = Vec::new();
for (i, vb) in desc.vertex_buffers.iter().enumerate() {
vertex_buffers.push(vk::VertexInputBindingDescription {
binding: i as u32,
stride: vb.array_stride as u32,
input_rate: match vb.step_mode {
wgt::VertexStepMode::Vertex => vk::VertexInputRate::VERTEX,
wgt::VertexStepMode::Instance => vk::VertexInputRate::INSTANCE,
},
});
for at in vb.attributes {
vertex_attributes.push(vk::VertexInputAttributeDescription {
location: at.shader_location,
if let crate::VertexProcessor::Standard {
vertex_buffers: desc_vertex_buffers,
vertex_stage: _,
} = &desc.vertex_processor
{
vertex_buffers = Vec::with_capacity(desc_vertex_buffers.len());
for (i, vb) in desc_vertex_buffers.iter().enumerate() {
vertex_buffers.push(vk::VertexInputBindingDescription {
binding: i as u32,
format: conv::map_vertex_format(at.format),
offset: at.offset as u32,
stride: vb.array_stride as u32,
input_rate: match vb.step_mode {
wgt::VertexStepMode::Vertex => vk::VertexInputRate::VERTEX,
wgt::VertexStepMode::Instance => vk::VertexInputRate::INSTANCE,
},
});
for at in vb.attributes {
vertex_attributes.push(vk::VertexInputAttributeDescription {
location: at.shader_location,
binding: i as u32,
format: conv::map_vertex_format(at.format),
offset: at.offset as u32,
});
}
}
}
@ -2009,12 +2016,41 @@ impl crate::Device for super::Device {
.topology(conv::map_topology(desc.primitive.topology))
.primitive_restart_enable(desc.primitive.strip_index_format.is_some());
let compiled_vs = self.compile_stage(
&desc.vertex_stage,
naga::ShaderStage::Vertex,
&desc.layout.binding_arrays,
)?;
stages.push(compiled_vs.create_info);
let mut compiled_vs = None;
let mut compiled_ms = None;
let mut compiled_ts = None;
match &desc.vertex_processor {
crate::VertexProcessor::Standard {
vertex_buffers: _,
vertex_stage,
} => {
compiled_vs = Some(self.compile_stage(
vertex_stage,
naga::ShaderStage::Vertex,
&desc.layout.binding_arrays,
)?);
stages.push(compiled_vs.as_ref().unwrap().create_info);
}
crate::VertexProcessor::Mesh {
task_stage,
mesh_stage,
} => {
if let Some(t) = task_stage.as_ref() {
compiled_ts = Some(self.compile_stage(
t,
naga::ShaderStage::Task,
&desc.layout.binding_arrays,
)?);
stages.push(compiled_ts.as_ref().unwrap().create_info);
}
compiled_ms = Some(self.compile_stage(
mesh_stage,
naga::ShaderStage::Mesh,
&desc.layout.binding_arrays,
)?);
stages.push(compiled_ms.as_ref().unwrap().create_info);
}
}
let compiled_fs = match desc.fragment_stage {
Some(ref stage) => {
let compiled = self.compile_stage(
@ -2177,228 +2213,13 @@ impl crate::Device for super::Device {
unsafe { self.shared.set_object_name(raw, label) };
}
if let Some(raw_module) = compiled_vs.temp_raw_module {
unsafe { self.shared.raw.destroy_shader_module(raw_module, None) };
}
if let Some(CompiledStage {
temp_raw_module: Some(raw_module),
..
}) = compiled_fs
}) = compiled_vs
{
unsafe { self.shared.raw.destroy_shader_module(raw_module, None) };
}
self.counters.render_pipelines.add(1);
Ok(super::RenderPipeline { raw })
}
unsafe fn create_mesh_pipeline(
&self,
desc: &crate::MeshPipelineDescriptor<
<Self::A as crate::Api>::PipelineLayout,
<Self::A as crate::Api>::ShaderModule,
<Self::A as crate::Api>::PipelineCache,
>,
) -> Result<<Self::A as crate::Api>::RenderPipeline, crate::PipelineError> {
let dynamic_states = [
vk::DynamicState::VIEWPORT,
vk::DynamicState::SCISSOR,
vk::DynamicState::BLEND_CONSTANTS,
vk::DynamicState::STENCIL_REFERENCE,
];
let mut compatible_rp_key = super::RenderPassKey {
sample_count: desc.multisample.count,
multiview: desc.multiview,
..Default::default()
};
let mut stages = ArrayVec::<_, { crate::MAX_CONCURRENT_SHADER_STAGES }>::new();
let vk_input_assembly = vk::PipelineInputAssemblyStateCreateInfo::default()
.topology(conv::map_topology(desc.primitive.topology))
.primitive_restart_enable(desc.primitive.strip_index_format.is_some());
let compiled_ts = match desc.task_stage {
Some(ref stage) => {
let mut compiled = self.compile_stage(
stage,
naga::ShaderStage::Task,
&desc.layout.binding_arrays,
)?;
compiled.create_info.stage = vk::ShaderStageFlags::TASK_EXT;
stages.push(compiled.create_info);
Some(compiled)
}
None => None,
};
let mut compiled_ms = self.compile_stage(
&desc.mesh_stage,
naga::ShaderStage::Mesh,
&desc.layout.binding_arrays,
)?;
compiled_ms.create_info.stage = vk::ShaderStageFlags::MESH_EXT;
stages.push(compiled_ms.create_info);
let compiled_fs = match desc.fragment_stage {
Some(ref stage) => {
let compiled = self.compile_stage(
stage,
naga::ShaderStage::Fragment,
&desc.layout.binding_arrays,
)?;
stages.push(compiled.create_info);
Some(compiled)
}
None => None,
};
let mut vk_rasterization = vk::PipelineRasterizationStateCreateInfo::default()
.polygon_mode(conv::map_polygon_mode(desc.primitive.polygon_mode))
.front_face(conv::map_front_face(desc.primitive.front_face))
.line_width(1.0)
.depth_clamp_enable(desc.primitive.unclipped_depth);
if let Some(face) = desc.primitive.cull_mode {
vk_rasterization = vk_rasterization.cull_mode(conv::map_cull_face(face))
}
let mut vk_rasterization_conservative_state =
vk::PipelineRasterizationConservativeStateCreateInfoEXT::default()
.conservative_rasterization_mode(
vk::ConservativeRasterizationModeEXT::OVERESTIMATE,
);
if desc.primitive.conservative {
vk_rasterization = vk_rasterization.push_next(&mut vk_rasterization_conservative_state);
}
let mut vk_depth_stencil = vk::PipelineDepthStencilStateCreateInfo::default();
if let Some(ref ds) = desc.depth_stencil {
let vk_format = self.shared.private_caps.map_texture_format(ds.format);
let vk_layout = if ds.is_read_only(desc.primitive.cull_mode) {
vk::ImageLayout::DEPTH_STENCIL_READ_ONLY_OPTIMAL
} else {
vk::ImageLayout::DEPTH_STENCIL_ATTACHMENT_OPTIMAL
};
compatible_rp_key.depth_stencil = Some(super::DepthStencilAttachmentKey {
base: super::AttachmentKey::compatible(vk_format, vk_layout),
stencil_ops: crate::AttachmentOps::all(),
});
if ds.is_depth_enabled() {
vk_depth_stencil = vk_depth_stencil
.depth_test_enable(true)
.depth_write_enable(ds.depth_write_enabled)
.depth_compare_op(conv::map_comparison(ds.depth_compare));
}
if ds.stencil.is_enabled() {
let s = &ds.stencil;
let front = conv::map_stencil_face(&s.front, s.read_mask, s.write_mask);
let back = conv::map_stencil_face(&s.back, s.read_mask, s.write_mask);
vk_depth_stencil = vk_depth_stencil
.stencil_test_enable(true)
.front(front)
.back(back);
}
if ds.bias.is_enabled() {
vk_rasterization = vk_rasterization
.depth_bias_enable(true)
.depth_bias_constant_factor(ds.bias.constant as f32)
.depth_bias_clamp(ds.bias.clamp)
.depth_bias_slope_factor(ds.bias.slope_scale);
}
}
let vk_viewport = vk::PipelineViewportStateCreateInfo::default()
.flags(vk::PipelineViewportStateCreateFlags::empty())
.scissor_count(1)
.viewport_count(1);
let vk_sample_mask = [
desc.multisample.mask as u32,
(desc.multisample.mask >> 32) as u32,
];
let vk_multisample = vk::PipelineMultisampleStateCreateInfo::default()
.rasterization_samples(vk::SampleCountFlags::from_raw(desc.multisample.count))
.alpha_to_coverage_enable(desc.multisample.alpha_to_coverage_enabled)
.sample_mask(&vk_sample_mask);
let mut vk_attachments = Vec::with_capacity(desc.color_targets.len());
for cat in desc.color_targets {
let (key, attarchment) = if let Some(cat) = cat.as_ref() {
let mut vk_attachment = vk::PipelineColorBlendAttachmentState::default()
.color_write_mask(vk::ColorComponentFlags::from_raw(cat.write_mask.bits()));
if let Some(ref blend) = cat.blend {
let (color_op, color_src, color_dst) = conv::map_blend_component(&blend.color);
let (alpha_op, alpha_src, alpha_dst) = conv::map_blend_component(&blend.alpha);
vk_attachment = vk_attachment
.blend_enable(true)
.color_blend_op(color_op)
.src_color_blend_factor(color_src)
.dst_color_blend_factor(color_dst)
.alpha_blend_op(alpha_op)
.src_alpha_blend_factor(alpha_src)
.dst_alpha_blend_factor(alpha_dst);
}
let vk_format = self.shared.private_caps.map_texture_format(cat.format);
(
Some(super::ColorAttachmentKey {
base: super::AttachmentKey::compatible(
vk_format,
vk::ImageLayout::COLOR_ATTACHMENT_OPTIMAL,
),
resolve: None,
}),
vk_attachment,
)
} else {
(None, vk::PipelineColorBlendAttachmentState::default())
};
compatible_rp_key.colors.push(key);
vk_attachments.push(attarchment);
}
let vk_color_blend =
vk::PipelineColorBlendStateCreateInfo::default().attachments(&vk_attachments);
let vk_dynamic_state =
vk::PipelineDynamicStateCreateInfo::default().dynamic_states(&dynamic_states);
let raw_pass = self.shared.make_render_pass(compatible_rp_key)?;
let vk_infos = [{
vk::GraphicsPipelineCreateInfo::default()
.layout(desc.layout.raw)
.stages(&stages)
.input_assembly_state(&vk_input_assembly)
.rasterization_state(&vk_rasterization)
.viewport_state(&vk_viewport)
.multisample_state(&vk_multisample)
.depth_stencil_state(&vk_depth_stencil)
.color_blend_state(&vk_color_blend)
.dynamic_state(&vk_dynamic_state)
.render_pass(raw_pass)
}];
let pipeline_cache = desc
.cache
.map(|it| it.raw)
.unwrap_or(vk::PipelineCache::null());
let mut raw_vec = {
profiling::scope!("vkCreateGraphicsPipelines");
unsafe {
self.shared
.raw
.create_graphics_pipelines(pipeline_cache, &vk_infos, None)
.map_err(|(_, e)| super::map_pipeline_err(e))
}?
};
let raw = raw_vec.pop().unwrap();
if let Some(label) = desc.label {
unsafe { self.shared.set_object_name(raw, label) };
}
// NOTE: this could leak shaders in case of an error.
if let Some(CompiledStage {
temp_raw_module: Some(raw_module),
..
@ -2406,7 +2227,11 @@ impl crate::Device for super::Device {
{
unsafe { self.shared.raw.destroy_shader_module(raw_module, None) };
}
if let Some(raw_module) = compiled_ms.temp_raw_module {
if let Some(CompiledStage {
temp_raw_module: Some(raw_module),
..
}) = compiled_ms
{
unsafe { self.shared.raw.destroy_shader_module(raw_module, None) };
}
if let Some(CompiledStage {

View File

@ -160,6 +160,12 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize
max_subgroup_size,
max_push_constant_size,
max_non_sampler_bindings,
max_task_workgroup_total_count,
max_task_workgroups_per_dimension,
max_mesh_multiview_count,
max_mesh_output_layers,
max_blas_primitive_count,
max_blas_geometry_count,
max_tlas_instance_count,
@ -200,6 +206,12 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize
writeln!(output, "\t\t Max Compute Workgroup Size Y: {max_compute_workgroup_size_y}")?;
writeln!(output, "\t\t Max Compute Workgroup Size Z: {max_compute_workgroup_size_z}")?;
writeln!(output, "\t\t Max Compute Workgroups Per Dimension: {max_compute_workgroups_per_dimension}")?;
writeln!(output, "\t\t Max Task Workgroup Total Count: {max_task_workgroup_total_count}")?;
writeln!(output, "\t\t Max Task Workgroups Per Dimension: {max_task_workgroups_per_dimension}")?;
writeln!(output, "\t\t Max Mesh Multiview Count: {max_mesh_multiview_count}")?;
writeln!(output, "\t\t Max Mesh Output Layers: {max_mesh_output_layers}")?;
writeln!(output, "\t\t Max BLAS Primitive count: {max_blas_primitive_count}")?;
writeln!(output, "\t\t Max BLAS Geometry count: {max_blas_geometry_count}")?;
writeln!(output, "\t\t Max TLAS Instance count: {max_tlas_instance_count}")?;

View File

@ -615,6 +615,17 @@ pub struct Limits {
/// This limit only affects the d3d12 backend. Using a large number will allow the device
/// to create many bind groups at the cost of a large up-front allocation at device creation.
pub max_non_sampler_bindings: u32,
/// The maximum total value of x*y*z for a given `draw_mesh_tasks` command
pub max_task_workgroup_total_count: u32,
/// The maximum value for each dimension of a `RenderPass::draw_mesh_tasks(x, y, z)` operation.
/// Defaults to 65535. Higher is "better".
pub max_task_workgroups_per_dimension: u32,
/// The maximum number of layers that can be output from a mesh shader
pub max_mesh_output_layers: u32,
/// The maximum number of views that can be used by a mesh shader
pub max_mesh_multiview_count: u32,
/// The maximum number of primitive (ex: triangles, aabbs) a BLAS is allowed to have. Requesting
/// more than 0 during device creation only makes sense if [`Features::EXPERIMENTAL_RAY_QUERY`]
/// is enabled.
@ -683,6 +694,10 @@ impl Limits {
/// max_subgroup_size: 0,
/// max_push_constant_size: 0,
/// max_non_sampler_bindings: 1_000_000,
/// max_task_workgroup_total_count: 0,
/// max_task_workgroups_per_dimension: 0,
/// max_mesh_multiview_count: 0,
/// max_mesh_output_layers: 0,
/// max_blas_primitive_count: 0,
/// max_blas_geometry_count: 0,
/// max_tlas_instance_count: 0,
@ -731,6 +746,12 @@ impl Limits {
max_subgroup_size: 0,
max_push_constant_size: 0,
max_non_sampler_bindings: 1_000_000,
max_task_workgroup_total_count: 0,
max_task_workgroups_per_dimension: 0,
max_mesh_multiview_count: 0,
max_mesh_output_layers: 0,
max_blas_primitive_count: 0,
max_blas_geometry_count: 0,
max_tlas_instance_count: 0,
@ -780,6 +801,12 @@ impl Limits {
/// max_compute_workgroups_per_dimension: 65535,
/// max_buffer_size: 256 << 20, // (256 MiB)
/// max_non_sampler_bindings: 1_000_000,
///
/// max_task_workgroup_total_count: 0,
/// max_task_workgroups_per_dimension: 0,
/// max_mesh_multiview_count: 0,
/// max_mesh_output_layers: 0,
///
/// max_blas_primitive_count: 0,
/// max_blas_geometry_count: 0,
/// max_tlas_instance_count: 0,
@ -797,6 +824,11 @@ impl Limits {
max_color_attachments: 4,
// see: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf#page=7
max_compute_workgroup_storage_size: 16352,
max_task_workgroups_per_dimension: 0,
max_task_workgroup_total_count: 0,
max_mesh_multiview_count: 0,
max_mesh_output_layers: 0,
..Self::defaults()
}
}
@ -844,6 +876,12 @@ impl Limits {
/// max_compute_workgroups_per_dimension: 0, // +
/// max_buffer_size: 256 << 20, // (256 MiB),
/// max_non_sampler_bindings: 1_000_000,
///
/// max_task_workgroup_total_count: 0,
/// max_task_workgroups_per_dimension: 0,
/// max_mesh_multiview_count: 0,
/// max_mesh_output_layers: 0,
///
/// max_blas_primitive_count: 0,
/// max_blas_geometry_count: 0,
/// max_tlas_instance_count: 0,
@ -929,6 +967,24 @@ impl Limits {
}
}
/// The recommended minimum limits for mesh shaders if you enable [`Features::EXPERIMENTAL_MESH_SHADER`]
///
/// These are chosen somewhat arbitrarily. They are small enough that they should cover all physical devices,
/// but not necessarily all use cases.
#[must_use]
pub const fn using_recommended_minimum_mesh_shader_values(self) -> Self {
Self {
// Literally just made this up as 256^2 or 2^16.
// My GPU supports 2^22, and compute shaders don't have this kind of limit.
// This very likely is never a real limiter
max_task_workgroup_total_count: 65536,
max_task_workgroups_per_dimension: 256,
max_mesh_multiview_count: 1,
max_mesh_output_layers: 1024,
..self
}
}
/// Compares every limits within self is within the limits given in `allowed`.
///
/// If you need detailed information on failures, look at [`Limits::check_limits_with_fail_fn`].
@ -1008,6 +1064,12 @@ impl Limits {
}
compare!(max_push_constant_size, Less);
compare!(max_non_sampler_bindings, Less);
compare!(max_task_workgroup_total_count, Less);
compare!(max_task_workgroups_per_dimension, Less);
compare!(max_mesh_multiview_count, Less);
compare!(max_mesh_output_layers, Less);
compare!(max_blas_primitive_count, Less);
compare!(max_blas_geometry_count, Less);
compare!(max_tlas_instance_count, Less);
@ -1402,9 +1464,9 @@ bitflags::bitflags! {
const COMPUTE = 1 << 2;
/// Binding is visible from the vertex and fragment shaders of a render pipeline.
const VERTEX_FRAGMENT = Self::VERTEX.bits() | Self::FRAGMENT.bits();
/// Binding is visible from the task shader of a mesh pipeline
/// Binding is visible from the task shader of a mesh pipeline.
const TASK = 1 << 3;
/// Binding is visible from the mesh shader of a mesh pipeline
/// Binding is visible from the mesh shader of a mesh pipeline.
const MESH = 1 << 4;
}
}

View File

@ -245,6 +245,13 @@ impl Device {
RenderPipeline { inner: pipeline }
}
/// Creates a mesh shader based [`RenderPipeline`].
#[must_use]
pub fn create_mesh_pipeline(&self, desc: &MeshPipelineDescriptor<'_>) -> RenderPipeline {
let pipeline = self.inner.create_mesh_pipeline(desc);
RenderPipeline { inner: pipeline }
}
/// Creates a [`ComputePipeline`].
#[must_use]
pub fn create_compute_pipeline(&self, desc: &ComputePipelineDescriptor<'_>) -> ComputePipeline {

View File

@ -226,6 +226,12 @@ impl RenderPass<'_> {
self.inner.draw_indexed(indices, base_vertex, instances);
}
/// Draws using a mesh shader pipeline
pub fn draw_mesh_tasks(&mut self, group_count_x: u32, group_count_y: u32, group_count_z: u32) {
self.inner
.draw_mesh_tasks(group_count_x, group_count_y, group_count_z);
}
/// Draws primitives from the active vertex buffer(s) based on the contents of the `indirect_buffer`.
///
/// This is like calling [`RenderPass::draw`] but the contents of the call are specified in the `indirect_buffer`.
@ -249,6 +255,25 @@ impl RenderPass<'_> {
.draw_indexed_indirect(&indirect_buffer.inner, indirect_offset);
}
/// Draws using a mesh shader pipeline,
/// based on the contents of the `indirect_buffer`
///
/// This is like calling [`RenderPass::draw_mesh_tasks`] but the contents of the call are specified in the `indirect_buffer`.
/// The structure expected in the `indirect_buffer` must conform to [`DispatchIndirectArgs`](crate::util::DispatchIndirectArgs).
///
/// Indirect drawing has some caveats depending on the features available. We are not currently able to validate
/// these and issue an error.
///
/// See details on the individual flags for more information.
pub fn draw_mesh_tasks_indirect(
&mut self,
indirect_buffer: &Buffer,
indirect_offset: BufferAddress,
) {
self.inner
.draw_mesh_tasks_indirect(&indirect_buffer.inner, indirect_offset);
}
/// Execute a [render bundle][RenderBundle], which is a set of pre-recorded commands
/// that can be run together.
///
@ -307,6 +332,23 @@ impl RenderPass<'_> {
.multi_draw_indexed_indirect(&indirect_buffer.inner, indirect_offset, count);
}
/// Dispatches multiple draw calls based on the contents of the `indirect_buffer`.
/// `count` draw calls are issued.
///
/// The structure expected in the `indirect_buffer` must conform to [`DispatchIndirectArgs`](crate::util::DispatchIndirectArgs).
///
/// This drawing command uses the current render state, as set by preceding `set_*()` methods.
/// It is not affected by changes to the state that are performed after it is called.
pub fn multi_draw_mesh_tasks_indirect(
&mut self,
indirect_buffer: &Buffer,
indirect_offset: BufferAddress,
count: u32,
) {
self.inner
.multi_draw_mesh_tasks_indirect(&indirect_buffer.inner, indirect_offset, count);
}
#[cfg(custom)]
/// Returns custom implementation of RenderPass (if custom backend and is internally T)
pub fn as_custom<T: custom::RenderPassInterface>(&self) -> Option<&T> {
@ -395,6 +437,34 @@ impl RenderPass<'_> {
max_count,
);
}
/// Dispatches multiple draw calls based on the contents of the `indirect_buffer`. The count buffer is read to determine how many draws to issue.
///
/// The indirect buffer must be long enough to account for `max_count` draws, however only `count`
/// draws will be read. If `count` is greater than `max_count`, `max_count` will be used.
///
/// The structure expected in the `indirect_buffer` must conform to [`DispatchIndirectArgs`](crate::util::DispatchIndirectArgs).
///
/// These draw structures are expected to be tightly packed.
///
/// This drawing command uses the current render state, as set by preceding `set_*()` methods.
/// It is not affected by changes to the state that are performed after it is called.
pub fn multi_draw_mesh_tasks_indirect_count(
&mut self,
indirect_buffer: &Buffer,
indirect_offset: BufferAddress,
count_buffer: &Buffer,
count_offset: BufferAddress,
max_count: u32,
) {
self.inner.multi_draw_mesh_tasks_indirect_count(
&indirect_buffer.inner,
indirect_offset,
&count_buffer.inner,
count_offset,
max_count,
);
}
}
/// [`Features::PUSH_CONSTANTS`] must be enabled on the device in order to call these functions.

View File

@ -145,6 +145,48 @@ pub struct FragmentState<'a> {
#[cfg(send_sync)]
static_assertions::assert_impl_all!(FragmentState<'_>: Send, Sync);
/// Describes the task shader stage in a mesh shader pipeline.
///
/// For use in [`MeshPipelineDescriptor`]
#[derive(Clone, Debug)]
pub struct TaskState<'a> {
/// The compiled shader module for this stage.
pub module: &'a ShaderModule,
/// The name of the entry point in the compiled shader to use.
///
/// If [`Some`], there must be a vertex-stage shader entry point with this name in `module`.
/// Otherwise, expect exactly one vertex-stage entry point in `module`, which will be
/// selected.
pub entry_point: Option<&'a str>,
/// Advanced options for when this pipeline is compiled
///
/// This implements `Default`, and for most users can be set to `Default::default()`
pub compilation_options: PipelineCompilationOptions<'a>,
}
#[cfg(send_sync)]
static_assertions::assert_impl_all!(TaskState<'_>: Send, Sync);
/// Describes the mesh shader stage in a mesh shader pipeline.
///
/// For use in [`MeshPipelineDescriptor`]
#[derive(Clone, Debug)]
pub struct MeshState<'a> {
/// The compiled shader module for this stage.
pub module: &'a ShaderModule,
/// The name of the entry point in the compiled shader to use.
///
/// If [`Some`], there must be a vertex-stage shader entry point with this name in `module`.
/// Otherwise, expect exactly one vertex-stage entry point in `module`, which will be
/// selected.
pub entry_point: Option<&'a str>,
/// Advanced options for when this pipeline is compiled
///
/// This implements `Default`, and for most users can be set to `Default::default()`
pub compilation_options: PipelineCompilationOptions<'a>,
}
#[cfg(send_sync)]
static_assertions::assert_impl_all!(MeshState<'_>: Send, Sync);
/// Describes a render (graphics) pipeline.
///
/// For use with [`Device::create_render_pipeline`].
@ -193,3 +235,51 @@ pub struct RenderPipelineDescriptor<'a> {
}
#[cfg(send_sync)]
static_assertions::assert_impl_all!(RenderPipelineDescriptor<'_>: Send, Sync);
/// Describes a mesh shader (graphics) pipeline.
///
/// For use with [`Device::create_mesh_pipeline`].
#[derive(Clone, Debug)]
pub struct MeshPipelineDescriptor<'a> {
/// Debug label of the pipeline. This will show up in graphics debuggers for easy identification.
pub label: Label<'a>,
/// The layout of bind groups for this pipeline.
///
/// If this is set, then [`Device::create_render_pipeline`] will raise a validation error if
/// the layout doesn't match what the shader module(s) expect.
///
/// Using the same [`PipelineLayout`] for many [`RenderPipeline`] or [`ComputePipeline`]
/// pipelines guarantees that you don't have to rebind any resources when switching between
/// those pipelines.
///
/// ## Default pipeline layout
///
/// If `layout` is `None`, then the pipeline has a [default layout] created and used instead.
/// The default layout is deduced from the shader modules.
///
/// You can use [`RenderPipeline::get_bind_group_layout`] to create bind groups for use with the
/// default layout. However, these bind groups cannot be used with any other pipelines. This is
/// convenient for simple pipelines, but using an explicit layout is recommended in most cases.
///
/// [default layout]: https://www.w3.org/TR/webgpu/#default-pipeline-layout
pub layout: Option<&'a PipelineLayout>,
/// The compiled task stage, its entry point, and the color targets.
pub task: Option<TaskState<'a>>,
/// The compiled mesh stage and its entry point
pub mesh: MeshState<'a>,
/// The properties of the pipeline at the primitive assembly and rasterization level.
pub primitive: PrimitiveState,
/// The effect of draw calls on the depth and stencil aspects of the output target, if any.
pub depth_stencil: Option<DepthStencilState>,
/// The multi-sampling properties of the pipeline.
pub multisample: MultisampleState,
/// The compiled fragment stage, its entry point, and the color targets.
pub fragment: Option<FragmentState<'a>>,
/// If the pipeline will be used with a multiview render pass, this indicates how many array
/// layers the attachments will have.
pub multiview: Option<NonZeroU32>,
/// The pipeline cache to use when creating this pipeline.
pub cache: Option<&'a PipelineCache>,
}
#[cfg(send_sync)]
static_assertions::assert_impl_all!(MeshPipelineDescriptor<'_>: Send, Sync);

View File

@ -823,6 +823,12 @@ fn map_wgt_limits(limits: webgpu_sys::GpuSupportedLimits) -> wgt::Limits {
max_push_constant_size: wgt::Limits::default().max_push_constant_size,
max_non_sampler_bindings: wgt::Limits::default().max_non_sampler_bindings,
max_inter_stage_shader_components: wgt::Limits::default().max_inter_stage_shader_components,
max_task_workgroup_total_count: wgt::Limits::default().max_task_workgroup_total_count,
max_task_workgroups_per_dimension: wgt::Limits::default().max_task_workgroups_per_dimension,
max_mesh_output_layers: wgt::Limits::default().max_mesh_output_layers,
max_mesh_multiview_count: wgt::Limits::default().max_mesh_multiview_count,
max_blas_primitive_count: wgt::Limits::default().max_blas_primitive_count,
max_blas_geometry_count: wgt::Limits::default().max_blas_geometry_count,
max_tlas_instance_count: wgt::Limits::default().max_tlas_instance_count,
@ -2174,6 +2180,13 @@ impl dispatch::DeviceInterface for WebDevice {
.into()
}
fn create_mesh_pipeline(
&self,
_desc: &crate::MeshPipelineDescriptor<'_>,
) -> dispatch::DispatchRenderPipeline {
panic!("MESH_SHADER feature must be enabled to call create_mesh_pipeline")
}
fn create_compute_pipeline(
&self,
desc: &crate::ComputePipelineDescriptor<'_>,
@ -3415,6 +3428,10 @@ impl dispatch::RenderPassInterface for WebRenderPassEncoder {
)
}
fn draw_mesh_tasks(&mut self, _group_count_x: u32, _group_count_y: u32, _group_count_z: u32) {
panic!("MESH_SHADER feature must be enabled to call draw_mesh_tasks")
}
fn draw_indirect(
&mut self,
indirect_buffer: &dispatch::DispatchBuffer,
@ -3435,6 +3452,14 @@ impl dispatch::RenderPassInterface for WebRenderPassEncoder {
.draw_indexed_indirect_with_f64(&buffer.inner, indirect_offset as f64);
}
fn draw_mesh_tasks_indirect(
&mut self,
_indirect_buffer: &dispatch::DispatchBuffer,
_indirect_offset: crate::BufferAddress,
) {
panic!("MESH_SHADER feature must be enabled to call draw_mesh_tasks_indirect")
}
fn multi_draw_indirect(
&mut self,
indirect_buffer: &dispatch::DispatchBuffer,
@ -3465,6 +3490,15 @@ impl dispatch::RenderPassInterface for WebRenderPassEncoder {
}
}
fn multi_draw_mesh_tasks_indirect(
&mut self,
_indirect_buffer: &dispatch::DispatchBuffer,
_indirect_offset: crate::BufferAddress,
_count: u32,
) {
panic!("MESH_SHADER feature must be enabled to call multi_draw_mesh_tasks_indirect")
}
fn multi_draw_indirect_count(
&mut self,
_indirect_buffer: &dispatch::DispatchBuffer,
@ -3489,6 +3523,17 @@ impl dispatch::RenderPassInterface for WebRenderPassEncoder {
panic!("MULTI_DRAW_INDIRECT_COUNT feature must be enabled to call multi_draw_indexed_indirect_count")
}
fn multi_draw_mesh_tasks_indirect_count(
&mut self,
_indirect_buffer: &dispatch::DispatchBuffer,
_indirect_offset: crate::BufferAddress,
_count_buffer: &dispatch::DispatchBuffer,
_count_buffer_offset: crate::BufferAddress,
_max_count: u32,
) {
panic!("MESH_SHADER feature must be enabled to call multi_draw_mesh_tasks_indirect_count")
}
fn insert_debug_marker(&mut self, _label: &str) {
// Not available in gecko yet
// self.inner.insert_debug_marker(label);

View File

@ -1354,6 +1354,102 @@ impl dispatch::DeviceInterface for CoreDevice {
.into()
}
fn create_mesh_pipeline(
&self,
desc: &crate::MeshPipelineDescriptor<'_>,
) -> dispatch::DispatchRenderPipeline {
use wgc::pipeline as pipe;
let mesh_constants = desc
.mesh
.compilation_options
.constants
.iter()
.map(|&(key, value)| (String::from(key), value))
.collect();
let descriptor = pipe::MeshPipelineDescriptor {
label: desc.label.map(Borrowed),
task: desc.task.as_ref().map(|task| {
let task_constants = task
.compilation_options
.constants
.iter()
.map(|&(key, value)| (String::from(key), value))
.collect();
pipe::TaskState {
stage: pipe::ProgrammableStageDescriptor {
module: task.module.inner.as_core().id,
entry_point: task.entry_point.map(Borrowed),
constants: task_constants,
zero_initialize_workgroup_memory: desc
.mesh
.compilation_options
.zero_initialize_workgroup_memory,
},
}
}),
mesh: pipe::MeshState {
stage: pipe::ProgrammableStageDescriptor {
module: desc.mesh.module.inner.as_core().id,
entry_point: desc.mesh.entry_point.map(Borrowed),
constants: mesh_constants,
zero_initialize_workgroup_memory: desc
.mesh
.compilation_options
.zero_initialize_workgroup_memory,
},
},
layout: desc.layout.map(|layout| layout.inner.as_core().id),
primitive: desc.primitive,
depth_stencil: desc.depth_stencil.clone(),
multisample: desc.multisample,
fragment: desc.fragment.as_ref().map(|frag| {
let frag_constants = frag
.compilation_options
.constants
.iter()
.map(|&(key, value)| (String::from(key), value))
.collect();
pipe::FragmentState {
stage: pipe::ProgrammableStageDescriptor {
module: frag.module.inner.as_core().id,
entry_point: frag.entry_point.map(Borrowed),
constants: frag_constants,
zero_initialize_workgroup_memory: frag
.compilation_options
.zero_initialize_workgroup_memory,
},
targets: Borrowed(frag.targets),
}
}),
multiview: desc.multiview,
cache: desc.cache.map(|cache| cache.inner.as_core().id),
};
let (id, error) = self
.context
.0
.device_create_mesh_pipeline(self.id, &descriptor, None);
if let Some(cause) = error {
if let wgc::pipeline::CreateRenderPipelineError::Internal { stage, ref error } = cause {
log::error!("Shader translation error for stage {stage:?}: {error}");
log::error!("Please report it to https://github.com/gfx-rs/wgpu");
}
self.context.handle_error(
&self.error_sink,
cause,
desc.label,
"Device::create_render_pipeline",
);
}
CoreRenderPipeline {
context: self.context.clone(),
id,
error_sink: Arc::clone(&self.error_sink),
}
.into()
}
fn create_compute_pipeline(
&self,
desc: &crate::ComputePipelineDescriptor<'_>,
@ -3125,6 +3221,22 @@ impl dispatch::RenderPassInterface for CoreRenderPass {
}
}
fn draw_mesh_tasks(&mut self, group_count_x: u32, group_count_y: u32, group_count_z: u32) {
if let Err(cause) = self.context.0.render_pass_draw_mesh_tasks(
&mut self.pass,
group_count_x,
group_count_y,
group_count_z,
) {
self.context.handle_error(
&self.error_sink,
cause,
self.pass.label(),
"RenderPass::draw_mesh_tasks",
);
}
}
fn draw_indirect(
&mut self,
indirect_buffer: &dispatch::DispatchBuffer,
@ -3167,6 +3279,27 @@ impl dispatch::RenderPassInterface for CoreRenderPass {
}
}
fn draw_mesh_tasks_indirect(
&mut self,
indirect_buffer: &dispatch::DispatchBuffer,
indirect_offset: crate::BufferAddress,
) {
let indirect_buffer = indirect_buffer.as_core();
if let Err(cause) = self.context.0.render_pass_draw_mesh_tasks_indirect(
&mut self.pass,
indirect_buffer.id,
indirect_offset,
) {
self.context.handle_error(
&self.error_sink,
cause,
self.pass.label(),
"RenderPass::draw_mesh_tasks_indirect",
);
}
}
fn multi_draw_indirect(
&mut self,
indirect_buffer: &dispatch::DispatchBuffer,
@ -3213,6 +3346,29 @@ impl dispatch::RenderPassInterface for CoreRenderPass {
}
}
fn multi_draw_mesh_tasks_indirect(
&mut self,
indirect_buffer: &dispatch::DispatchBuffer,
indirect_offset: crate::BufferAddress,
count: u32,
) {
let indirect_buffer = indirect_buffer.as_core();
if let Err(cause) = self.context.0.render_pass_multi_draw_mesh_tasks_indirect(
&mut self.pass,
indirect_buffer.id,
indirect_offset,
count,
) {
self.context.handle_error(
&self.error_sink,
cause,
self.pass.label(),
"RenderPass::multi_draw_mesh_tasks_indirect",
);
}
}
fn multi_draw_indirect_count(
&mut self,
indirect_buffer: &dispatch::DispatchBuffer,
@ -3273,6 +3429,38 @@ impl dispatch::RenderPassInterface for CoreRenderPass {
}
}
fn multi_draw_mesh_tasks_indirect_count(
&mut self,
indirect_buffer: &dispatch::DispatchBuffer,
indirect_offset: crate::BufferAddress,
count_buffer: &dispatch::DispatchBuffer,
count_buffer_offset: crate::BufferAddress,
max_count: u32,
) {
let indirect_buffer = indirect_buffer.as_core();
let count_buffer = count_buffer.as_core();
if let Err(cause) = self
.context
.0
.render_pass_multi_draw_mesh_tasks_indirect_count(
&mut self.pass,
indirect_buffer.id,
indirect_offset,
count_buffer.id,
count_buffer_offset,
max_count,
)
{
self.context.handle_error(
&self.error_sink,
cause,
self.pass.label(),
"RenderPass::multi_draw_mesh_tasks_indirect_count",
);
}
}
fn insert_debug_marker(&mut self, label: &str) {
if let Err(cause) = self
.context

View File

@ -147,6 +147,10 @@ pub trait DeviceInterface: CommonTraits {
&self,
desc: &crate::RenderPipelineDescriptor<'_>,
) -> DispatchRenderPipeline;
fn create_mesh_pipeline(
&self,
desc: &crate::MeshPipelineDescriptor<'_>,
) -> DispatchRenderPipeline;
fn create_compute_pipeline(
&self,
desc: &crate::ComputePipelineDescriptor<'_>,
@ -420,6 +424,7 @@ pub trait RenderPassInterface: CommonTraits {
fn draw(&mut self, vertices: Range<u32>, instances: Range<u32>);
fn draw_indexed(&mut self, indices: Range<u32>, base_vertex: i32, instances: Range<u32>);
fn draw_mesh_tasks(&mut self, group_count_x: u32, group_count_y: u32, group_count_z: u32);
fn draw_indirect(
&mut self,
indirect_buffer: &DispatchBuffer,
@ -430,6 +435,11 @@ pub trait RenderPassInterface: CommonTraits {
indirect_buffer: &DispatchBuffer,
indirect_offset: crate::BufferAddress,
);
fn draw_mesh_tasks_indirect(
&mut self,
indirect_buffer: &DispatchBuffer,
indirect_offset: crate::BufferAddress,
);
fn multi_draw_indirect(
&mut self,
@ -451,6 +461,12 @@ pub trait RenderPassInterface: CommonTraits {
count_buffer_offset: crate::BufferAddress,
max_count: u32,
);
fn multi_draw_mesh_tasks_indirect(
&mut self,
indirect_buffer: &DispatchBuffer,
indirect_offset: crate::BufferAddress,
count: u32,
);
fn multi_draw_indexed_indirect_count(
&mut self,
indirect_buffer: &DispatchBuffer,
@ -459,6 +475,14 @@ pub trait RenderPassInterface: CommonTraits {
count_buffer_offset: crate::BufferAddress,
max_count: u32,
);
fn multi_draw_mesh_tasks_indirect_count(
&mut self,
indirect_buffer: &DispatchBuffer,
indirect_offset: crate::BufferAddress,
count_buffer: &DispatchBuffer,
count_buffer_offset: crate::BufferAddress,
max_count: u32,
);
fn insert_debug_marker(&mut self, label: &str);
fn push_debug_group(&mut self, group_label: &str);