[hal/metal] Mesh Shaders (#8139)

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
Co-authored-by: Magnus <85136135+SupaMaggie70Incorporated@users.noreply.github.com>
This commit is contained in:
Inner Daemons 2025-11-14 23:11:43 -05:00 committed by GitHub
parent 92fa99af1b
commit d0cf78c8c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 906 additions and 389 deletions

View File

@ -125,6 +125,9 @@ By @SupaMaggie70Incorporated in [#8206](https://github.com/gfx-rs/wgpu/pull/8206
- `util::StagingBelt` now takes a `Device` when it is created instead of when it is used. By @kpreid in [#8462](https://github.com/gfx-rs/wgpu/pull/8462). - `util::StagingBelt` now takes a `Device` when it is created instead of when it is used. By @kpreid in [#8462](https://github.com/gfx-rs/wgpu/pull/8462).
- `wgpu_hal::vulkan::Device::texture_from_raw` now takes an `external_memory` argument. By @s-ol in [#8512](https://github.com/gfx-rs/wgpu/pull/8512) - `wgpu_hal::vulkan::Device::texture_from_raw` now takes an `external_memory` argument. By @s-ol in [#8512](https://github.com/gfx-rs/wgpu/pull/8512)
#### Metal
- Add support for mesh shaders. By @SupaMaggie70Incorporated in [#8139](https://github.com/gfx-rs/wgpu/pull/8139)
#### Naga #### Naga
- Prevent UB with invalid ray query calls on spirv. By @Vecvec in [#8390](https://github.com/gfx-rs/wgpu/pull/8390). - Prevent UB with invalid ray query calls on spirv. By @Vecvec in [#8390](https://github.com/gfx-rs/wgpu/pull/8390).

View File

@ -49,6 +49,7 @@ fn all_tests() -> Vec<wgpu_test::GpuTestInitializer> {
cube::TEST, cube::TEST,
cube::TEST_LINES, cube::TEST_LINES,
hello_synchronization::tests::SYNC, hello_synchronization::tests::SYNC,
mesh_shader::TEST,
mipmap::TEST, mipmap::TEST,
mipmap::TEST_QUERY, mipmap::TEST_QUERY,
msaa_line::TEST, msaa_line::TEST,

View File

@ -61,6 +61,18 @@ fn compile_hlsl(device: &wgpu::Device, entry: &str, stage_str: &str) -> wgpu::Sh
} }
} }
fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule {
unsafe {
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
entry_point: entry.to_owned(),
label: None,
msl: Some(std::borrow::Cow::Borrowed(include_str!("shader.metal"))),
num_workgroups: (1, 1, 1),
..Default::default()
})
}
}
pub struct Example { pub struct Example {
pipeline: wgpu::RenderPipeline, pipeline: wgpu::RenderPipeline,
} }
@ -71,20 +83,23 @@ impl crate::framework::Example for Example {
device: &wgpu::Device, device: &wgpu::Device,
_queue: &wgpu::Queue, _queue: &wgpu::Queue,
) -> Self { ) -> Self {
let (ts, ms, fs) = if adapter.get_info().backend == wgpu::Backend::Vulkan { let (ts, ms, fs) = match adapter.get_info().backend {
( wgpu::Backend::Vulkan => (
compile_glsl(device, "task"), compile_glsl(device, "task"),
compile_glsl(device, "mesh"), compile_glsl(device, "mesh"),
compile_glsl(device, "frag"), compile_glsl(device, "frag"),
) ),
} else if adapter.get_info().backend == wgpu::Backend::Dx12 { wgpu::Backend::Dx12 => (
(
compile_hlsl(device, "Task", "as"), compile_hlsl(device, "Task", "as"),
compile_hlsl(device, "Mesh", "ms"), compile_hlsl(device, "Mesh", "ms"),
compile_hlsl(device, "Frag", "ps"), compile_hlsl(device, "Frag", "ps"),
) ),
} else { wgpu::Backend::Metal => (
panic!("Example can only run on vulkan or dx12"); compile_msl(device, "taskShader"),
compile_msl(device, "meshShader"),
compile_msl(device, "fragShader"),
),
_ => panic!("Example can currently only run on vulkan, dx12 or metal"),
}; };
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None, label: None,
@ -179,3 +194,21 @@ impl crate::framework::Example for Example {
pub fn main() { pub fn main() {
crate::framework::run::<Example>("mesh_shader"); crate::framework::run::<Example>("mesh_shader");
} }
#[cfg(test)]
#[wgpu_test::gpu_test]
pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams {
name: "mesh_shader",
image_path: "/examples/features/src/mesh_shader/screenshot.png",
width: 1024,
height: 768,
optional_features: wgpu::Features::default(),
base_test_parameters: wgpu_test::TestParameters::default()
.features(
wgpu::Features::EXPERIMENTAL_MESH_SHADER
| wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS,
)
.limits(wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values()),
comparisons: &[wgpu_test::ComparisonType::Mean(0.01)],
_phantom: std::marker::PhantomData::<Example>,
};

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

View File

@ -0,0 +1,77 @@
using namespace metal;
struct OutVertex {
float4 Position [[position]];
float4 Color [[user(locn0)]];
};
struct OutPrimitive {
float4 ColorMask [[flat]] [[user(locn1)]];
bool CullPrimitive [[primitive_culled]];
};
struct InVertex {
};
struct InPrimitive {
float4 ColorMask [[flat]] [[user(locn1)]];
};
struct FragmentIn {
float4 Color [[user(locn0)]];
float4 ColorMask [[flat]] [[user(locn1)]];
};
struct PayloadData {
float4 ColorMask;
bool Visible;
};
using Meshlet = metal::mesh<OutVertex, OutPrimitive, 3, 1, topology::triangle>;
constant float4 positions[3] = {
float4(0.0, 1.0, 0.0, 1.0),
float4(-1.0, -1.0, 0.0, 1.0),
float4(1.0, -1.0, 0.0, 1.0)
};
constant float4 colors[3] = {
float4(0.0, 1.0, 0.0, 1.0),
float4(0.0, 0.0, 1.0, 1.0),
float4(1.0, 0.0, 0.0, 1.0)
};
[[object]]
void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], mesh_grid_properties grid) {
outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0);
outPayload.Visible = true;
grid.set_threadgroups_per_grid(uint3(3, 1, 1));
}
[[mesh]]
void meshShader(
object_data PayloadData const& payload [[payload]],
Meshlet out
)
{
out.set_primitive_count(1);
for(int i = 0;i < 3;i++) {
OutVertex vert;
vert.Position = positions[i];
vert.Color = colors[i] * payload.ColorMask;
out.set_vertex(i, vert);
out.set_index(i, i);
}
OutPrimitive prim;
prim.ColorMask = float4(1.0, 0.0, 0.0, 1.0);
prim.CullPrimitive = !payload.Visible;
out.set_primitive(0, prim);
}
fragment float4 fragShader(FragmentIn data [[stage_in]]) {
return data.Color * data.ColorMask;
}

View File

@ -3,15 +3,11 @@ use std::{
process::Stdio, process::Stdio,
}; };
use wgpu::{util::DeviceExt, Backends}; use wgpu::util::DeviceExt;
use wgpu_test::{ use wgpu_test::{
fail, gpu_test, FailureCase, GpuTestConfiguration, GpuTestInitializer, TestParameters, fail, gpu_test, GpuTestConfiguration, GpuTestInitializer, TestParameters, TestingContext,
TestingContext,
}; };
/// Backends that support mesh shaders
const MESH_SHADER_BACKENDS: Backends = Backends::DX12.union(Backends::VULKAN);
pub fn all_tests(tests: &mut Vec<GpuTestInitializer>) { pub fn all_tests(tests: &mut Vec<GpuTestInitializer>) {
tests.extend([ tests.extend([
MESH_PIPELINE_BASIC_MESH, MESH_PIPELINE_BASIC_MESH,
@ -98,6 +94,18 @@ fn compile_hlsl(
} }
} }
fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule {
unsafe {
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
entry_point: entry.to_owned(),
label: None,
msl: Some(std::borrow::Cow::Borrowed(include_str!("shader.metal"))),
num_workgroups: (1, 1, 1),
..Default::default()
})
}
}
fn get_shaders( fn get_shaders(
device: &wgpu::Device, device: &wgpu::Device,
backend: wgpu::Backend, backend: wgpu::Backend,
@ -114,8 +122,8 @@ fn get_shaders(
// (In the case that the platform does support mesh shaders, the dummy // (In the case that the platform does support mesh shaders, the dummy
// shader is used to avoid requiring EXPERIMENTAL_PASSTHROUGH_SHADERS.) // shader is used to avoid requiring EXPERIMENTAL_PASSTHROUGH_SHADERS.)
let dummy_shader = device.create_shader_module(wgpu::include_wgsl!("non_mesh.wgsl")); let dummy_shader = device.create_shader_module(wgpu::include_wgsl!("non_mesh.wgsl"));
if backend == wgpu::Backend::Vulkan { match backend {
( wgpu::Backend::Vulkan => (
info.use_task.then(|| compile_glsl(device, "task")), info.use_task.then(|| compile_glsl(device, "task")),
if info.use_mesh { if info.use_mesh {
compile_glsl(device, "mesh") compile_glsl(device, "mesh")
@ -123,9 +131,8 @@ fn get_shaders(
dummy_shader dummy_shader
}, },
info.use_frag.then(|| compile_glsl(device, "frag")), info.use_frag.then(|| compile_glsl(device, "frag")),
) ),
} else if backend == wgpu::Backend::Dx12 { wgpu::Backend::Dx12 => (
(
info.use_task info.use_task
.then(|| compile_hlsl(device, "Task", "as", test_name)), .then(|| compile_hlsl(device, "Task", "as", test_name)),
if info.use_mesh { if info.use_mesh {
@ -135,11 +142,20 @@ fn get_shaders(
}, },
info.use_frag info.use_frag
.then(|| compile_hlsl(device, "Frag", "ps", test_name)), .then(|| compile_hlsl(device, "Frag", "ps", test_name)),
) ),
} else { wgpu::Backend::Metal => (
assert!(!MESH_SHADER_BACKENDS.contains(Backends::from(backend))); info.use_task.then(|| compile_msl(device, "taskShader")),
assert!(!info.use_task && !info.use_mesh && !info.use_frag); if info.use_mesh {
(None, dummy_shader, None) compile_msl(device, "meshShader")
} else {
dummy_shader
},
info.use_frag.then(|| compile_msl(device, "fragShader")),
),
_ => {
assert!(!info.use_task && !info.use_mesh && !info.use_frag);
(None, dummy_shader, None)
}
} }
} }
@ -377,7 +393,6 @@ fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) {
fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration { fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration {
GpuTestConfiguration::new().parameters( GpuTestConfiguration::new().parameters(
TestParameters::default() TestParameters::default()
.skip(FailureCase::backend(!MESH_SHADER_BACKENDS))
.test_features_limits() .test_features_limits()
.features( .features(
wgpu::Features::EXPERIMENTAL_MESH_SHADER wgpu::Features::EXPERIMENTAL_MESH_SHADER

View File

@ -0,0 +1,77 @@
using namespace metal;
struct OutVertex {
float4 Position [[position]];
float4 Color [[user(locn0)]];
};
struct OutPrimitive {
float4 ColorMask [[flat]] [[user(locn1)]];
bool CullPrimitive [[primitive_culled]];
};
struct InVertex {
};
struct InPrimitive {
float4 ColorMask [[flat]] [[user(locn1)]];
};
struct FragmentIn {
float4 Color [[user(locn0)]];
float4 ColorMask [[flat]] [[user(locn1)]];
};
struct PayloadData {
float4 ColorMask;
bool Visible;
};
using Meshlet = metal::mesh<OutVertex, OutPrimitive, 3, 1, topology::triangle>;
constant float4 positions[3] = {
float4(0.0, 1.0, 0.0, 1.0),
float4(-1.0, -1.0, 0.0, 1.0),
float4(1.0, -1.0, 0.0, 1.0)
};
constant float4 colors[3] = {
float4(0.0, 1.0, 0.0, 1.0),
float4(0.0, 0.0, 1.0, 1.0),
float4(1.0, 0.0, 0.0, 1.0)
};
[[object]]
void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], mesh_grid_properties grid) {
outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0);
outPayload.Visible = true;
grid.set_threadgroups_per_grid(uint3(3, 1, 1));
}
[[mesh]]
void meshShader(
object_data PayloadData const& payload [[payload]],
Meshlet out
)
{
out.set_primitive_count(1);
for(int i = 0;i < 3;i++) {
OutVertex vert;
vert.Position = positions[i];
vert.Color = colors[i] * payload.ColorMask;
out.set_vertex(i, vert);
out.set_index(i, i);
}
OutPrimitive prim;
prim.ColorMask = float4(1.0, 0.0, 0.0, 1.0);
prim.CullPrimitive = !payload.Visible;
out.set_primitive(0, prim);
}
fragment float4 fragShader(FragmentIn data [[stage_in]]) {
return data.Color * data.ColorMask;
}

View File

@ -607,6 +607,8 @@ impl super::PrivateCapabilities {
let argument_buffers = device.argument_buffers_support(); let argument_buffers = device.argument_buffers_support();
let is_virtual = device.name().to_lowercase().contains("virtual");
Self { Self {
family_check, family_check,
msl_version: if os_is_xr || version.at_least((14, 0), (17, 0), os_is_mac) { msl_version: if os_is_xr || version.at_least((14, 0), (17, 0), os_is_mac) {
@ -902,6 +904,12 @@ impl super::PrivateCapabilities {
&& (device.supports_family(MTLGPUFamily::Apple7) && (device.supports_family(MTLGPUFamily::Apple7)
|| device.supports_family(MTLGPUFamily::Mac2)), || device.supports_family(MTLGPUFamily::Mac2)),
supports_shared_event: version.at_least((10, 14), (12, 0), os_is_mac), supports_shared_event: version.at_least((10, 14), (12, 0), os_is_mac),
mesh_shaders: family_check
&& (device.supports_family(MTLGPUFamily::Metal3)
|| device.supports_family(MTLGPUFamily::Apple7)
|| device.supports_family(MTLGPUFamily::Mac2))
// Mesh shaders don't work on virtual devices even if they should be supported.
&& !is_virtual,
supported_vertex_amplification_factor: { supported_vertex_amplification_factor: {
let mut factor = 1; let mut factor = 1;
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf#page=8 // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf#page=8
@ -1023,6 +1031,8 @@ impl super::PrivateCapabilities {
features.insert(F::SUBGROUP | F::SUBGROUP_BARRIER); features.insert(F::SUBGROUP | F::SUBGROUP_BARRIER);
} }
features.set(F::EXPERIMENTAL_MESH_SHADER, self.mesh_shaders);
if self.supported_vertex_amplification_factor > 1 { if self.supported_vertex_amplification_factor > 1 {
features.insert(F::MULTIVIEW); features.insert(F::MULTIVIEW);
} }
@ -1102,10 +1112,11 @@ impl super::PrivateCapabilities {
max_buffer_size: self.max_buffer_size, max_buffer_size: self.max_buffer_size,
max_non_sampler_bindings: u32::MAX, max_non_sampler_bindings: u32::MAX,
max_task_workgroup_total_count: 0, // See https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf, Maximum threadgroups per mesh shader grid
max_task_workgroups_per_dimension: 0, max_task_workgroup_total_count: 1024,
max_task_workgroups_per_dimension: 1024,
max_mesh_multiview_view_count: 0, max_mesh_multiview_view_count: 0,
max_mesh_output_layers: 0, max_mesh_output_layers: self.max_texture_layers as u32,
max_blas_primitive_count: 0, // When added: 2^28 from https://developer.apple.com/documentation/metal/mtlaccelerationstructureusage/extendedlimits max_blas_primitive_count: 0, // When added: 2^28 from https://developer.apple.com/documentation/metal/mtlaccelerationstructureusage/extendedlimits
max_blas_geometry_count: 0, // When added: 2^24 max_blas_geometry_count: 0, // When added: 2^24

View File

@ -22,11 +22,9 @@ impl Default for super::CommandState {
compute: None, compute: None,
raw_primitive_type: MTLPrimitiveType::Point, raw_primitive_type: MTLPrimitiveType::Point,
index: None, index: None,
raw_wg_size: MTLSize::new(0, 0, 0),
stage_infos: Default::default(), stage_infos: Default::default(),
storage_buffer_length_map: Default::default(), storage_buffer_length_map: Default::default(),
vertex_buffer_size_map: Default::default(), vertex_buffer_size_map: Default::default(),
work_group_memory_sizes: Vec::new(),
push_constants: Vec::new(), push_constants: Vec::new(),
pending_timer_queries: Vec::new(), pending_timer_queries: Vec::new(),
} }
@ -146,6 +144,127 @@ impl super::CommandEncoder {
self.state.reset(); self.state.reset();
self.leave_blit(); self.leave_blit();
} }
/// Updates the bindings for a single shader stage, called in `set_bind_group`.
#[expect(clippy::too_many_arguments)]
fn update_bind_group_state(
&mut self,
stage: naga::ShaderStage,
render_encoder: Option<&metal::RenderCommandEncoder>,
compute_encoder: Option<&metal::ComputeCommandEncoder>,
index_base: super::ResourceData<u32>,
bg_info: &super::BindGroupLayoutInfo,
dynamic_offsets: &[wgt::DynamicOffset],
group_index: u32,
group: &super::BindGroup,
) {
let resource_indices = match stage {
naga::ShaderStage::Vertex => &bg_info.base_resource_indices.vs,
naga::ShaderStage::Fragment => &bg_info.base_resource_indices.fs,
naga::ShaderStage::Task => &bg_info.base_resource_indices.ts,
naga::ShaderStage::Mesh => &bg_info.base_resource_indices.ms,
naga::ShaderStage::Compute => &bg_info.base_resource_indices.cs,
};
let buffers = match stage {
naga::ShaderStage::Vertex => group.counters.vs.buffers,
naga::ShaderStage::Fragment => group.counters.fs.buffers,
naga::ShaderStage::Task => group.counters.ts.buffers,
naga::ShaderStage::Mesh => group.counters.ms.buffers,
naga::ShaderStage::Compute => group.counters.cs.buffers,
};
let mut changes_sizes_buffer = false;
for index in 0..buffers {
let buf = &group.buffers[(index_base.buffers + index) as usize];
let mut offset = buf.offset;
if let Some(dyn_index) = buf.dynamic_index {
offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress;
}
let a1 = (resource_indices.buffers + index) as u64;
let a2 = Some(buf.ptr.as_native());
let a3 = offset;
match stage {
naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_buffer(a1, a2, a3),
naga::ShaderStage::Fragment => {
render_encoder.unwrap().set_fragment_buffer(a1, a2, a3)
}
naga::ShaderStage::Task => render_encoder.unwrap().set_object_buffer(a1, a2, a3),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_buffer(a1, a2, a3),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_buffer(a1, a2, a3),
}
if let Some(size) = buf.binding_size {
let br = naga::ResourceBinding {
group: group_index,
binding: buf.binding_location,
};
self.state.storage_buffer_length_map.insert(br, size);
changes_sizes_buffer = true;
}
}
if changes_sizes_buffer {
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(stage, &mut self.temp.binding_sizes)
{
let a1 = index as _;
let a2 = (sizes.len() * WORD_SIZE) as u64;
let a3 = sizes.as_ptr().cast();
match stage {
naga::ShaderStage::Vertex => {
render_encoder.unwrap().set_vertex_bytes(a1, a2, a3)
}
naga::ShaderStage::Fragment => {
render_encoder.unwrap().set_fragment_bytes(a1, a2, a3)
}
naga::ShaderStage::Task => render_encoder.unwrap().set_object_bytes(a1, a2, a3),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_bytes(a1, a2, a3),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_bytes(a1, a2, a3),
}
}
}
let samplers = match stage {
naga::ShaderStage::Vertex => group.counters.vs.samplers,
naga::ShaderStage::Fragment => group.counters.fs.samplers,
naga::ShaderStage::Task => group.counters.ts.samplers,
naga::ShaderStage::Mesh => group.counters.ms.samplers,
naga::ShaderStage::Compute => group.counters.cs.samplers,
};
for index in 0..samplers {
let res = group.samplers[(index_base.samplers + index) as usize];
let a1 = (resource_indices.samplers + index) as u64;
let a2 = Some(res.as_native());
match stage {
naga::ShaderStage::Vertex => {
render_encoder.unwrap().set_vertex_sampler_state(a1, a2)
}
naga::ShaderStage::Fragment => {
render_encoder.unwrap().set_fragment_sampler_state(a1, a2)
}
naga::ShaderStage::Task => render_encoder.unwrap().set_object_sampler_state(a1, a2),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_sampler_state(a1, a2),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_sampler_state(a1, a2),
}
}
let textures = match stage {
naga::ShaderStage::Vertex => group.counters.vs.textures,
naga::ShaderStage::Fragment => group.counters.fs.textures,
naga::ShaderStage::Task => group.counters.ts.textures,
naga::ShaderStage::Mesh => group.counters.ms.textures,
naga::ShaderStage::Compute => group.counters.cs.textures,
};
for index in 0..textures {
let res = group.textures[(index_base.textures + index) as usize];
let a1 = (resource_indices.textures + index) as u64;
let a2 = Some(res.as_native());
match stage {
naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_texture(a1, a2),
naga::ShaderStage::Fragment => render_encoder.unwrap().set_fragment_texture(a1, a2),
naga::ShaderStage::Task => render_encoder.unwrap().set_object_texture(a1, a2),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_texture(a1, a2),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_texture(a1, a2),
}
}
}
} }
impl super::CommandState { impl super::CommandState {
@ -155,7 +274,8 @@ impl super::CommandState {
self.stage_infos.vs.clear(); self.stage_infos.vs.clear();
self.stage_infos.fs.clear(); self.stage_infos.fs.clear();
self.stage_infos.cs.clear(); self.stage_infos.cs.clear();
self.work_group_memory_sizes.clear(); self.stage_infos.ts.clear();
self.stage_infos.ms.clear();
self.push_constants.clear(); self.push_constants.clear();
} }
@ -702,168 +822,90 @@ impl crate::CommandEncoder for super::CommandEncoder {
dynamic_offsets: &[wgt::DynamicOffset], dynamic_offsets: &[wgt::DynamicOffset],
) { ) {
let bg_info = &layout.bind_group_infos[group_index as usize]; let bg_info = &layout.bind_group_infos[group_index as usize];
let render_encoder = self.state.render.clone();
if let Some(ref encoder) = self.state.render { let compute_encoder = self.state.compute.clone();
let mut changes_sizes_buffer = false; if let Some(encoder) = render_encoder {
for index in 0..group.counters.vs.buffers { self.update_bind_group_state(
let buf = &group.buffers[index as usize]; naga::ShaderStage::Vertex,
let mut offset = buf.offset; Some(&encoder),
if let Some(dyn_index) = buf.dynamic_index { None,
offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; // All zeros, as vs comes first
} super::ResourceData::default(),
encoder.set_vertex_buffer( bg_info,
(bg_info.base_resource_indices.vs.buffers + index) as u64, dynamic_offsets,
Some(buf.ptr.as_native()), group_index,
offset, group,
); );
if let Some(size) = buf.binding_size { self.update_bind_group_state(
let br = naga::ResourceBinding { naga::ShaderStage::Task,
group: group_index, Some(&encoder),
binding: buf.binding_location, None,
}; // All zeros, as ts comes first
self.state.storage_buffer_length_map.insert(br, size); super::ResourceData::default(),
changes_sizes_buffer = true; bg_info,
} dynamic_offsets,
} group_index,
if changes_sizes_buffer { group,
if let Some((index, sizes)) = self.state.make_sizes_buffer_update( );
naga::ShaderStage::Vertex, self.update_bind_group_state(
&mut self.temp.binding_sizes, naga::ShaderStage::Mesh,
) { Some(&encoder),
encoder.set_vertex_bytes( None,
index as _, group.counters.ts.clone(),
(sizes.len() * WORD_SIZE) as u64, bg_info,
sizes.as_ptr().cast(), dynamic_offsets,
); group_index,
} group,
} );
self.update_bind_group_state(
changes_sizes_buffer = false; naga::ShaderStage::Fragment,
for index in 0..group.counters.fs.buffers { Some(&encoder),
let buf = &group.buffers[(group.counters.vs.buffers + index) as usize]; None,
let mut offset = buf.offset; super::ResourceData {
if let Some(dyn_index) = buf.dynamic_index { buffers: group.counters.vs.buffers
offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; + group.counters.ts.buffers
} + group.counters.ms.buffers,
encoder.set_fragment_buffer( textures: group.counters.vs.textures
(bg_info.base_resource_indices.fs.buffers + index) as u64, + group.counters.ts.textures
Some(buf.ptr.as_native()), + group.counters.ms.textures,
offset, samplers: group.counters.vs.samplers
); + group.counters.ts.samplers
if let Some(size) = buf.binding_size { + group.counters.ms.samplers,
let br = naga::ResourceBinding { },
group: group_index, bg_info,
binding: buf.binding_location, dynamic_offsets,
}; group_index,
self.state.storage_buffer_length_map.insert(br, size); group,
changes_sizes_buffer = true; );
}
}
if changes_sizes_buffer {
if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
naga::ShaderStage::Fragment,
&mut self.temp.binding_sizes,
) {
encoder.set_fragment_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr().cast(),
);
}
}
for index in 0..group.counters.vs.samplers {
let res = group.samplers[index as usize];
encoder.set_vertex_sampler_state(
(bg_info.base_resource_indices.vs.samplers + index) as u64,
Some(res.as_native()),
);
}
for index in 0..group.counters.fs.samplers {
let res = group.samplers[(group.counters.vs.samplers + index) as usize];
encoder.set_fragment_sampler_state(
(bg_info.base_resource_indices.fs.samplers + index) as u64,
Some(res.as_native()),
);
}
for index in 0..group.counters.vs.textures {
let res = group.textures[index as usize];
encoder.set_vertex_texture(
(bg_info.base_resource_indices.vs.textures + index) as u64,
Some(res.as_native()),
);
}
for index in 0..group.counters.fs.textures {
let res = group.textures[(group.counters.vs.textures + index) as usize];
encoder.set_fragment_texture(
(bg_info.base_resource_indices.fs.textures + index) as u64,
Some(res.as_native()),
);
}
// Call useResource on all textures and buffers used indirectly so they are alive // Call useResource on all textures and buffers used indirectly so they are alive
for (resource, use_info) in group.resources_to_use.iter() { for (resource, use_info) in group.resources_to_use.iter() {
encoder.use_resource_at(resource.as_native(), use_info.uses, use_info.stages); encoder.use_resource_at(resource.as_native(), use_info.uses, use_info.stages);
} }
} }
if let Some(encoder) = compute_encoder {
if let Some(ref encoder) = self.state.compute { self.update_bind_group_state(
let index_base = super::ResourceData { naga::ShaderStage::Compute,
buffers: group.counters.vs.buffers + group.counters.fs.buffers, None,
samplers: group.counters.vs.samplers + group.counters.fs.samplers, Some(&encoder),
textures: group.counters.vs.textures + group.counters.fs.textures, super::ResourceData {
}; buffers: group.counters.vs.buffers
+ group.counters.ts.buffers
let mut changes_sizes_buffer = false; + group.counters.ms.buffers
for index in 0..group.counters.cs.buffers { + group.counters.fs.buffers,
let buf = &group.buffers[(index_base.buffers + index) as usize]; textures: group.counters.vs.textures
let mut offset = buf.offset; + group.counters.ts.textures
if let Some(dyn_index) = buf.dynamic_index { + group.counters.ms.textures
offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; + group.counters.fs.textures,
} samplers: group.counters.vs.samplers
encoder.set_buffer( + group.counters.ts.samplers
(bg_info.base_resource_indices.cs.buffers + index) as u64, + group.counters.ms.samplers
Some(buf.ptr.as_native()), + group.counters.fs.samplers,
offset, },
); bg_info,
if let Some(size) = buf.binding_size { dynamic_offsets,
let br = naga::ResourceBinding { group_index,
group: group_index, group,
binding: buf.binding_location, );
};
self.state.storage_buffer_length_map.insert(br, size);
changes_sizes_buffer = true;
}
}
if changes_sizes_buffer {
if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
naga::ShaderStage::Compute,
&mut self.temp.binding_sizes,
) {
encoder.set_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr().cast(),
);
}
}
for index in 0..group.counters.cs.samplers {
let res = group.samplers[(index_base.samplers + index) as usize];
encoder.set_sampler_state(
(bg_info.base_resource_indices.cs.samplers + index) as u64,
Some(res.as_native()),
);
}
for index in 0..group.counters.cs.textures {
let res = group.textures[(index_base.textures + index) as usize];
encoder.set_texture(
(bg_info.base_resource_indices.cs.textures + index) as u64,
Some(res.as_native()),
);
}
// Call useResource on all textures and buffers used indirectly so they are alive // Call useResource on all textures and buffers used indirectly so they are alive
for (resource, use_info) in group.resources_to_use.iter() { for (resource, use_info) in group.resources_to_use.iter() {
if !use_info.visible_in_compute { if !use_info.visible_in_compute {
@ -911,6 +953,20 @@ impl crate::CommandEncoder for super::CommandEncoder {
state_pc.as_ptr().cast(), state_pc.as_ptr().cast(),
) )
} }
if stages.contains(wgt::ShaderStages::TASK) {
self.state.render.as_ref().unwrap().set_object_bytes(
layout.push_constants_infos.ts.unwrap().buffer_index as _,
(layout.total_push_constants as usize * WORD_SIZE) as _,
state_pc.as_ptr().cast(),
)
}
if stages.contains(wgt::ShaderStages::MESH) {
self.state.render.as_ref().unwrap().set_object_bytes(
layout.push_constants_infos.ms.unwrap().buffer_index as _,
(layout.total_push_constants as usize * WORD_SIZE) as _,
state_pc.as_ptr().cast(),
)
}
} }
unsafe fn insert_debug_marker(&mut self, label: &str) { unsafe fn insert_debug_marker(&mut self, label: &str) {
@ -935,11 +991,22 @@ impl crate::CommandEncoder for super::CommandEncoder {
unsafe fn set_render_pipeline(&mut self, pipeline: &super::RenderPipeline) { unsafe fn set_render_pipeline(&mut self, pipeline: &super::RenderPipeline) {
self.state.raw_primitive_type = pipeline.raw_primitive_type; self.state.raw_primitive_type = pipeline.raw_primitive_type;
self.state.stage_infos.vs.assign_from(&pipeline.vs_info); match pipeline.vs_info {
Some(ref info) => self.state.stage_infos.vs.assign_from(info),
None => self.state.stage_infos.vs.clear(),
}
match pipeline.fs_info { match pipeline.fs_info {
Some(ref info) => self.state.stage_infos.fs.assign_from(info), Some(ref info) => self.state.stage_infos.fs.assign_from(info),
None => self.state.stage_infos.fs.clear(), None => self.state.stage_infos.fs.clear(),
} }
match pipeline.ts_info {
Some(ref info) => self.state.stage_infos.ts.assign_from(info),
None => self.state.stage_infos.ts.clear(),
}
match pipeline.ms_info {
Some(ref info) => self.state.stage_infos.ms.assign_from(info),
None => self.state.stage_infos.ms.clear(),
}
let encoder = self.state.render.as_ref().unwrap(); let encoder = self.state.render.as_ref().unwrap();
encoder.set_render_pipeline_state(&pipeline.raw); encoder.set_render_pipeline_state(&pipeline.raw);
@ -954,7 +1021,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
encoder.set_depth_bias(bias.constant as f32, bias.slope_scale, bias.clamp); encoder.set_depth_bias(bias.constant as f32, bias.slope_scale, bias.clamp);
} }
{ if pipeline.vs_info.is_some() {
if let Some((index, sizes)) = self if let Some((index, sizes)) = self
.state .state
.make_sizes_buffer_update(naga::ShaderStage::Vertex, &mut self.temp.binding_sizes) .make_sizes_buffer_update(naga::ShaderStage::Vertex, &mut self.temp.binding_sizes)
@ -966,7 +1033,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
); );
} }
} }
if pipeline.fs_lib.is_some() { if pipeline.fs_info.is_some() {
if let Some((index, sizes)) = self if let Some((index, sizes)) = self
.state .state
.make_sizes_buffer_update(naga::ShaderStage::Fragment, &mut self.temp.binding_sizes) .make_sizes_buffer_update(naga::ShaderStage::Fragment, &mut self.temp.binding_sizes)
@ -978,6 +1045,56 @@ impl crate::CommandEncoder for super::CommandEncoder {
); );
} }
} }
if let Some(ts_info) = &pipeline.ts_info {
// update the threadgroup memory sizes
while self.state.stage_infos.ms.work_group_memory_sizes.len()
< ts_info.work_group_memory_sizes.len()
{
self.state.stage_infos.ms.work_group_memory_sizes.push(0);
}
for (index, (cur_size, pipeline_size)) in self
.state
.stage_infos
.ms
.work_group_memory_sizes
.iter_mut()
.zip(ts_info.work_group_memory_sizes.iter())
.enumerate()
{
let size = pipeline_size.next_multiple_of(16);
if *cur_size != size {
*cur_size = size;
encoder.set_object_threadgroup_memory_length(index as _, size as _);
}
}
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(naga::ShaderStage::Task, &mut self.temp.binding_sizes)
{
encoder.set_object_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr().cast(),
);
}
}
if let Some(_ms_info) = &pipeline.ms_info {
// So there isn't an equivalent to
// https://developer.apple.com/documentation/metal/mtlrendercommandencoder/setthreadgroupmemorylength(_:offset:index:)
// for mesh shaders. This is probably because the CPU has less control over the dispatch sizes and such. Interestingly
// it also affects mesh shaders without task/object shaders, even though none of compute, task or fragment shaders
// behave this way.
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(naga::ShaderStage::Mesh, &mut self.temp.binding_sizes)
{
encoder.set_mesh_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr().cast(),
);
}
}
} }
unsafe fn set_index_buffer<'a>( unsafe fn set_index_buffer<'a>(
@ -1140,11 +1257,21 @@ impl crate::CommandEncoder for super::CommandEncoder {
unsafe fn draw_mesh_tasks( unsafe fn draw_mesh_tasks(
&mut self, &mut self,
_group_count_x: u32, group_count_x: u32,
_group_count_y: u32, group_count_y: u32,
_group_count_z: u32, group_count_z: u32,
) { ) {
unreachable!() let encoder = self.state.render.as_ref().unwrap();
let size = MTLSize {
width: group_count_x as u64,
height: group_count_y as u64,
depth: group_count_z as u64,
};
encoder.draw_mesh_threadgroups(
size,
self.state.stage_infos.ts.raw_wg_size,
self.state.stage_infos.ms.raw_wg_size,
);
} }
unsafe fn draw_indirect( unsafe fn draw_indirect(
@ -1183,11 +1310,20 @@ impl crate::CommandEncoder for super::CommandEncoder {
unsafe fn draw_mesh_tasks_indirect( unsafe fn draw_mesh_tasks_indirect(
&mut self, &mut self,
_buffer: &<Self::A as crate::Api>::Buffer, buffer: &<Self::A as crate::Api>::Buffer,
_offset: wgt::BufferAddress, mut offset: wgt::BufferAddress,
_draw_count: u32, draw_count: u32,
) { ) {
unreachable!() let encoder = self.state.render.as_ref().unwrap();
for _ in 0..draw_count {
encoder.draw_mesh_threadgroups_with_indirect_buffer(
&buffer.raw,
offset,
self.state.stage_infos.ts.raw_wg_size,
self.state.stage_infos.ms.raw_wg_size,
);
offset += size_of::<wgt::DispatchIndirectArgs>() as wgt::BufferAddress;
}
} }
unsafe fn draw_indirect_count( unsafe fn draw_indirect_count(
@ -1295,7 +1431,8 @@ impl crate::CommandEncoder for super::CommandEncoder {
} }
unsafe fn set_compute_pipeline(&mut self, pipeline: &super::ComputePipeline) { unsafe fn set_compute_pipeline(&mut self, pipeline: &super::ComputePipeline) {
self.state.raw_wg_size = pipeline.work_group_size; let previous_sizes =
core::mem::take(&mut self.state.stage_infos.cs.work_group_memory_sizes);
self.state.stage_infos.cs.assign_from(&pipeline.cs_info); self.state.stage_infos.cs.assign_from(&pipeline.cs_info);
let encoder = self.state.compute.as_ref().unwrap(); let encoder = self.state.compute.as_ref().unwrap();
@ -1313,20 +1450,23 @@ impl crate::CommandEncoder for super::CommandEncoder {
} }
// update the threadgroup memory sizes // update the threadgroup memory sizes
while self.state.work_group_memory_sizes.len() < pipeline.work_group_memory_sizes.len() { for (i, current_size) in self
self.state.work_group_memory_sizes.push(0);
}
for (index, (cur_size, pipeline_size)) in self
.state .state
.stage_infos
.cs
.work_group_memory_sizes .work_group_memory_sizes
.iter_mut() .iter_mut()
.zip(pipeline.work_group_memory_sizes.iter())
.enumerate() .enumerate()
{ {
let size = pipeline_size.next_multiple_of(16); let prev_size = if i < previous_sizes.len() {
if *cur_size != size { previous_sizes[i]
*cur_size = size; } else {
encoder.set_threadgroup_memory_length(index as _, size as _); u32::MAX
};
let size: u32 = current_size.next_multiple_of(16);
*current_size = size;
if size != prev_size {
encoder.set_threadgroup_memory_length(i as _, size as _);
} }
} }
} }
@ -1339,13 +1479,17 @@ impl crate::CommandEncoder for super::CommandEncoder {
height: count[1] as u64, height: count[1] as u64,
depth: count[2] as u64, depth: count[2] as u64,
}; };
encoder.dispatch_thread_groups(raw_count, self.state.raw_wg_size); encoder.dispatch_thread_groups(raw_count, self.state.stage_infos.cs.raw_wg_size);
} }
} }
unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) { unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
let encoder = self.state.compute.as_ref().unwrap(); let encoder = self.state.compute.as_ref().unwrap();
encoder.dispatch_thread_groups_indirect(&buffer.raw, offset, self.state.raw_wg_size); encoder.dispatch_thread_groups_indirect(
&buffer.raw,
offset,
self.state.stage_infos.cs.raw_wg_size,
);
} }
unsafe fn build_acceleration_structures<'a, T>( unsafe fn build_acceleration_structures<'a, T>(

View File

@ -1113,16 +1113,261 @@ impl crate::Device for super::Device {
super::PipelineCache, super::PipelineCache,
>, >,
) -> Result<super::RenderPipeline, crate::PipelineError> { ) -> Result<super::RenderPipeline, crate::PipelineError> {
let (desc_vertex_stage, desc_vertex_buffers) = match &desc.vertex_processor {
crate::VertexProcessor::Standard {
vertex_buffers,
vertex_stage,
} => (vertex_stage, *vertex_buffers),
crate::VertexProcessor::Mesh { .. } => unreachable!(),
};
objc::rc::autoreleasepool(|| { objc::rc::autoreleasepool(|| {
let descriptor = metal::RenderPipelineDescriptor::new(); enum MetalGenericRenderPipelineDescriptor {
Standard(metal::RenderPipelineDescriptor),
Mesh(metal::MeshRenderPipelineDescriptor),
}
macro_rules! descriptor_fn {
($descriptor:ident . $method:ident $( ( $($args:expr),* ) )? ) => {
match $descriptor {
MetalGenericRenderPipelineDescriptor::Standard(ref inner) => inner.$method$(($($args),*))?,
MetalGenericRenderPipelineDescriptor::Mesh(ref inner) => inner.$method$(($($args),*))?,
}
};
}
impl MetalGenericRenderPipelineDescriptor {
fn set_fragment_function(&self, function: Option<&metal::FunctionRef>) {
descriptor_fn!(self.set_fragment_function(function));
}
fn fragment_buffers(&self) -> Option<&metal::PipelineBufferDescriptorArrayRef> {
descriptor_fn!(self.fragment_buffers())
}
fn set_depth_attachment_pixel_format(&self, pixel_format: MTLPixelFormat) {
descriptor_fn!(self.set_depth_attachment_pixel_format(pixel_format));
}
fn color_attachments(
&self,
) -> &metal::RenderPipelineColorAttachmentDescriptorArrayRef {
descriptor_fn!(self.color_attachments())
}
fn set_stencil_attachment_pixel_format(&self, pixel_format: MTLPixelFormat) {
descriptor_fn!(self.set_stencil_attachment_pixel_format(pixel_format));
}
fn set_alpha_to_coverage_enabled(&self, enabled: bool) {
descriptor_fn!(self.set_alpha_to_coverage_enabled(enabled));
}
fn set_label(&self, label: &str) {
descriptor_fn!(self.set_label(label));
}
fn set_max_vertex_amplification_count(&self, count: metal::NSUInteger) {
descriptor_fn!(self.set_max_vertex_amplification_count(count))
}
}
let (primitive_class, raw_primitive_type) =
conv::map_primitive_topology(desc.primitive.topology);
let vs_info;
let ts_info;
let ms_info;
// Create the pipeline descriptor and do vertex/mesh pipeline specific setup
let descriptor = match desc.vertex_processor {
crate::VertexProcessor::Standard {
vertex_buffers,
ref vertex_stage,
} => {
// Vertex pipeline specific setup
let descriptor = metal::RenderPipelineDescriptor::new();
ts_info = None;
ms_info = None;
// Collect vertex buffer mappings
let mut vertex_buffer_mappings =
Vec::<naga::back::msl::VertexBufferMapping>::new();
for (i, vbl) in vertex_buffers.iter().enumerate() {
let mut attributes = Vec::<naga::back::msl::AttributeMapping>::new();
for attribute in vbl.attributes.iter() {
attributes.push(naga::back::msl::AttributeMapping {
shader_location: attribute.shader_location,
offset: attribute.offset as u32,
format: convert_vertex_format_to_naga(attribute.format),
});
}
let mapping = naga::back::msl::VertexBufferMapping {
id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32,
stride: if vbl.array_stride > 0 {
vbl.array_stride.try_into().unwrap()
} else {
vbl.attributes
.iter()
.map(|attribute| attribute.offset + attribute.format.size())
.max()
.unwrap_or(0)
.try_into()
.unwrap()
},
step_mode: match (vbl.array_stride == 0, vbl.step_mode) {
(true, _) => naga::back::msl::VertexBufferStepMode::Constant,
(false, wgt::VertexStepMode::Vertex) => {
naga::back::msl::VertexBufferStepMode::ByVertex
}
(false, wgt::VertexStepMode::Instance) => {
naga::back::msl::VertexBufferStepMode::ByInstance
}
},
attributes,
};
vertex_buffer_mappings.push(mapping);
}
// Setup vertex shader
{
let vs = self.load_shader(
vertex_stage,
&vertex_buffer_mappings,
desc.layout,
primitive_class,
naga::ShaderStage::Vertex,
)?;
descriptor.set_vertex_function(Some(&vs.function));
if self.shared.private_caps.supports_mutability {
Self::set_buffers_mutability(
descriptor.vertex_buffers().unwrap(),
vs.immutable_buffer_mask,
);
}
vs_info = Some(super::PipelineStageInfo {
push_constants: desc.layout.push_constants_infos.vs,
sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer,
sized_bindings: vs.sized_bindings,
vertex_buffer_mappings,
library: Some(vs.library),
raw_wg_size: Default::default(),
work_group_memory_sizes: vec![],
});
}
// Validate vertex buffer count
if desc.layout.total_counters.vs.buffers + (vertex_buffers.len() as u32)
> self.shared.private_caps.max_vertex_buffers
{
let msg = format!(
"pipeline needs too many buffers in the vertex stage: {} vertex and {} layout",
vertex_buffers.len(),
desc.layout.total_counters.vs.buffers
);
return Err(crate::PipelineError::Linkage(
wgt::ShaderStages::VERTEX,
msg,
));
}
// Set the pipeline vertex buffer info
if !vertex_buffers.is_empty() {
let vertex_descriptor = metal::VertexDescriptor::new();
for (i, vb) in vertex_buffers.iter().enumerate() {
let buffer_index =
self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64;
let buffer_desc =
vertex_descriptor.layouts().object_at(buffer_index).unwrap();
// Metal expects the stride to be the actual size of the attributes.
// The semantics of array_stride == 0 can be achieved by setting
// the step function to constant and rate to 0.
if vb.array_stride == 0 {
let stride = vb
.attributes
.iter()
.map(|attribute| attribute.offset + attribute.format.size())
.max()
.unwrap_or(0);
buffer_desc.set_stride(wgt::math::align_to(stride, 4));
buffer_desc.set_step_function(MTLVertexStepFunction::Constant);
buffer_desc.set_step_rate(0);
} else {
buffer_desc.set_stride(vb.array_stride);
buffer_desc.set_step_function(conv::map_step_mode(vb.step_mode));
}
for at in vb.attributes {
let attribute_desc = vertex_descriptor
.attributes()
.object_at(at.shader_location as u64)
.unwrap();
attribute_desc.set_format(conv::map_vertex_format(at.format));
attribute_desc.set_buffer_index(buffer_index);
attribute_desc.set_offset(at.offset);
}
}
descriptor.set_vertex_descriptor(Some(vertex_descriptor));
}
MetalGenericRenderPipelineDescriptor::Standard(descriptor)
}
crate::VertexProcessor::Mesh {
ref task_stage,
ref mesh_stage,
} => {
// Mesh pipeline specific setup
vs_info = None;
let descriptor = metal::MeshRenderPipelineDescriptor::new();
// Setup task stage
if let Some(ref task_stage) = task_stage {
let ts = self.load_shader(
task_stage,
&[],
desc.layout,
primitive_class,
naga::ShaderStage::Task,
)?;
descriptor.set_object_function(Some(&ts.function));
if self.shared.private_caps.supports_mutability {
Self::set_buffers_mutability(
descriptor.mesh_buffers().unwrap(),
ts.immutable_buffer_mask,
);
}
ts_info = Some(super::PipelineStageInfo {
push_constants: desc.layout.push_constants_infos.ts,
sizes_slot: desc.layout.per_stage_map.ts.sizes_buffer,
sized_bindings: ts.sized_bindings,
vertex_buffer_mappings: vec![],
library: Some(ts.library),
raw_wg_size: ts.wg_size,
work_group_memory_sizes: ts.wg_memory_sizes,
});
} else {
ts_info = None;
}
// Setup mesh stage
{
let ms = self.load_shader(
mesh_stage,
&[],
desc.layout,
primitive_class,
naga::ShaderStage::Mesh,
)?;
descriptor.set_mesh_function(Some(&ms.function));
if self.shared.private_caps.supports_mutability {
Self::set_buffers_mutability(
descriptor.mesh_buffers().unwrap(),
ms.immutable_buffer_mask,
);
}
ms_info = Some(super::PipelineStageInfo {
push_constants: desc.layout.push_constants_infos.ms,
sizes_slot: desc.layout.per_stage_map.ms.sizes_buffer,
sized_bindings: ms.sized_bindings,
vertex_buffer_mappings: vec![],
library: Some(ms.library),
raw_wg_size: ms.wg_size,
work_group_memory_sizes: ms.wg_memory_sizes,
});
}
MetalGenericRenderPipelineDescriptor::Mesh(descriptor)
}
};
let raw_triangle_fill_mode = match desc.primitive.polygon_mode { let raw_triangle_fill_mode = match desc.primitive.polygon_mode {
wgt::PolygonMode::Fill => MTLTriangleFillMode::Fill, wgt::PolygonMode::Fill => MTLTriangleFillMode::Fill,
@ -1133,76 +1378,8 @@ impl crate::Device for super::Device {
), ),
}; };
let (primitive_class, raw_primitive_type) =
conv::map_primitive_topology(desc.primitive.topology);
// Vertex shader
let (vs_lib, vs_info) = {
let mut vertex_buffer_mappings = Vec::<naga::back::msl::VertexBufferMapping>::new();
for (i, vbl) in desc_vertex_buffers.iter().enumerate() {
let mut attributes = Vec::<naga::back::msl::AttributeMapping>::new();
for attribute in vbl.attributes.iter() {
attributes.push(naga::back::msl::AttributeMapping {
shader_location: attribute.shader_location,
offset: attribute.offset as u32,
format: convert_vertex_format_to_naga(attribute.format),
});
}
vertex_buffer_mappings.push(naga::back::msl::VertexBufferMapping {
id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32,
stride: if vbl.array_stride > 0 {
vbl.array_stride.try_into().unwrap()
} else {
vbl.attributes
.iter()
.map(|attribute| attribute.offset + attribute.format.size())
.max()
.unwrap_or(0)
.try_into()
.unwrap()
},
step_mode: match (vbl.array_stride == 0, vbl.step_mode) {
(true, _) => naga::back::msl::VertexBufferStepMode::Constant,
(false, wgt::VertexStepMode::Vertex) => {
naga::back::msl::VertexBufferStepMode::ByVertex
}
(false, wgt::VertexStepMode::Instance) => {
naga::back::msl::VertexBufferStepMode::ByInstance
}
},
attributes,
});
}
let vs = self.load_shader(
desc_vertex_stage,
&vertex_buffer_mappings,
desc.layout,
primitive_class,
naga::ShaderStage::Vertex,
)?;
descriptor.set_vertex_function(Some(&vs.function));
if self.shared.private_caps.supports_mutability {
Self::set_buffers_mutability(
descriptor.vertex_buffers().unwrap(),
vs.immutable_buffer_mask,
);
}
let info = super::PipelineStageInfo {
push_constants: desc.layout.push_constants_infos.vs,
sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer,
sized_bindings: vs.sized_bindings,
vertex_buffer_mappings,
};
(vs.library, info)
};
// Fragment shader // Fragment shader
let (fs_lib, fs_info) = match desc.fragment_stage { let fs_info = match desc.fragment_stage {
Some(ref stage) => { Some(ref stage) => {
let fs = self.load_shader( let fs = self.load_shader(
stage, stage,
@ -1220,14 +1397,15 @@ impl crate::Device for super::Device {
); );
} }
let info = super::PipelineStageInfo { Some(super::PipelineStageInfo {
push_constants: desc.layout.push_constants_infos.fs, push_constants: desc.layout.push_constants_infos.fs,
sizes_slot: desc.layout.per_stage_map.fs.sizes_buffer, sizes_slot: desc.layout.per_stage_map.fs.sizes_buffer,
sized_bindings: fs.sized_bindings, sized_bindings: fs.sized_bindings,
vertex_buffer_mappings: vec![], vertex_buffer_mappings: vec![],
}; library: Some(fs.library),
raw_wg_size: Default::default(),
(Some(fs.library), Some(info)) work_group_memory_sizes: vec![],
})
} }
None => { None => {
// TODO: This is a workaround for what appears to be a Metal validation bug // TODO: This is a workaround for what appears to be a Metal validation bug
@ -1235,10 +1413,11 @@ impl crate::Device for super::Device {
if desc.color_targets.is_empty() && desc.depth_stencil.is_none() { if desc.color_targets.is_empty() && desc.depth_stencil.is_none() {
descriptor.set_depth_attachment_pixel_format(MTLPixelFormat::Depth32Float); descriptor.set_depth_attachment_pixel_format(MTLPixelFormat::Depth32Float);
} }
(None, None) None
} }
}; };
// Setup pipeline color attachments
for (i, ct) in desc.color_targets.iter().enumerate() { for (i, ct) in desc.color_targets.iter().enumerate() {
let at_descriptor = descriptor.color_attachments().object_at(i as u64).unwrap(); let at_descriptor = descriptor.color_attachments().object_at(i as u64).unwrap();
let ct = if let Some(color_target) = ct.as_ref() { let ct = if let Some(color_target) = ct.as_ref() {
@ -1267,6 +1446,7 @@ impl crate::Device for super::Device {
} }
} }
// Setup depth stencil state
let depth_stencil = match desc.depth_stencil { let depth_stencil = match desc.depth_stencil {
Some(ref ds) => { Some(ref ds) => {
let raw_format = self.shared.private_caps.map_format(ds.format); let raw_format = self.shared.private_caps.map_format(ds.format);
@ -1289,94 +1469,54 @@ impl crate::Device for super::Device {
None => None, None => None,
}; };
if desc.layout.total_counters.vs.buffers + (desc_vertex_buffers.len() as u32) // Setup multisample state
> self.shared.private_caps.max_vertex_buffers
{
let msg = format!(
"pipeline needs too many buffers in the vertex stage: {} vertex and {} layout",
desc_vertex_buffers.len(),
desc.layout.total_counters.vs.buffers
);
return Err(crate::PipelineError::Linkage(
wgt::ShaderStages::VERTEX,
msg,
));
}
if !desc_vertex_buffers.is_empty() {
let vertex_descriptor = metal::VertexDescriptor::new();
for (i, vb) in desc_vertex_buffers.iter().enumerate() {
let buffer_index =
self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64;
let buffer_desc = vertex_descriptor.layouts().object_at(buffer_index).unwrap();
// Metal expects the stride to be the actual size of the attributes.
// The semantics of array_stride == 0 can be achieved by setting
// the step function to constant and rate to 0.
if vb.array_stride == 0 {
let stride = vb
.attributes
.iter()
.map(|attribute| attribute.offset + attribute.format.size())
.max()
.unwrap_or(0);
buffer_desc.set_stride(wgt::math::align_to(stride, 4));
buffer_desc.set_step_function(MTLVertexStepFunction::Constant);
buffer_desc.set_step_rate(0);
} else {
buffer_desc.set_stride(vb.array_stride);
buffer_desc.set_step_function(conv::map_step_mode(vb.step_mode));
}
for at in vb.attributes {
let attribute_desc = vertex_descriptor
.attributes()
.object_at(at.shader_location as u64)
.unwrap();
attribute_desc.set_format(conv::map_vertex_format(at.format));
attribute_desc.set_buffer_index(buffer_index);
attribute_desc.set_offset(at.offset);
}
}
descriptor.set_vertex_descriptor(Some(vertex_descriptor));
}
if desc.multisample.count != 1 { if desc.multisample.count != 1 {
//TODO: handle sample mask //TODO: handle sample mask
descriptor.set_sample_count(desc.multisample.count as u64); match descriptor {
MetalGenericRenderPipelineDescriptor::Standard(ref inner) => {
inner.set_sample_count(desc.multisample.count as u64);
}
MetalGenericRenderPipelineDescriptor::Mesh(ref inner) => {
inner.set_raster_sample_count(desc.multisample.count as u64);
}
}
descriptor descriptor
.set_alpha_to_coverage_enabled(desc.multisample.alpha_to_coverage_enabled); .set_alpha_to_coverage_enabled(desc.multisample.alpha_to_coverage_enabled);
//descriptor.set_alpha_to_one_enabled(desc.multisample.alpha_to_one_enabled); //descriptor.set_alpha_to_one_enabled(desc.multisample.alpha_to_one_enabled);
} }
// Set debug label
if let Some(name) = desc.label { if let Some(name) = desc.label {
descriptor.set_label(name); descriptor.set_label(name);
} }
if let Some(mv) = desc.multiview_mask { if let Some(mv) = desc.multiview_mask {
descriptor.set_max_vertex_amplification_count(mv.get().count_ones() as u64); descriptor.set_max_vertex_amplification_count(mv.get().count_ones() as u64);
} }
let raw = self // Create the pipeline from descriptor
.shared let raw = match descriptor {
.device MetalGenericRenderPipelineDescriptor::Standard(d) => {
.lock() self.shared.device.lock().new_render_pipeline_state(&d)
.new_render_pipeline_state(&descriptor) }
.map_err(|e| { MetalGenericRenderPipelineDescriptor::Mesh(d) => {
crate::PipelineError::Linkage( self.shared.device.lock().new_mesh_render_pipeline_state(&d)
wgt::ShaderStages::VERTEX | wgt::ShaderStages::FRAGMENT, }
format!("new_render_pipeline_state: {e:?}"), }
) .map_err(|e| {
})?; crate::PipelineError::Linkage(
wgt::ShaderStages::VERTEX | wgt::ShaderStages::FRAGMENT,
format!("new_render_pipeline_state: {e:?}"),
)
})?;
self.counters.render_pipelines.add(1); self.counters.render_pipelines.add(1);
Ok(super::RenderPipeline { Ok(super::RenderPipeline {
raw, raw,
vs_lib,
fs_lib,
vs_info, vs_info,
fs_info, fs_info,
ts_info,
ms_info,
raw_primitive_type, raw_primitive_type,
raw_triangle_fill_mode, raw_triangle_fill_mode,
raw_front_winding: conv::map_winding(desc.primitive.front_face), raw_front_winding: conv::map_winding(desc.primitive.front_face),
@ -1444,10 +1584,13 @@ impl crate::Device for super::Device {
} }
let cs_info = super::PipelineStageInfo { let cs_info = super::PipelineStageInfo {
library: Some(cs.library),
push_constants: desc.layout.push_constants_infos.cs, push_constants: desc.layout.push_constants_infos.cs,
sizes_slot: desc.layout.per_stage_map.cs.sizes_buffer, sizes_slot: desc.layout.per_stage_map.cs.sizes_buffer,
sized_bindings: cs.sized_bindings, sized_bindings: cs.sized_bindings,
vertex_buffer_mappings: vec![], vertex_buffer_mappings: vec![],
raw_wg_size: cs.wg_size,
work_group_memory_sizes: cs.wg_memory_sizes,
}; };
if let Some(name) = desc.label { if let Some(name) = desc.label {
@ -1468,13 +1611,7 @@ impl crate::Device for super::Device {
self.counters.compute_pipelines.add(1); self.counters.compute_pipelines.add(1);
Ok(super::ComputePipeline { Ok(super::ComputePipeline { raw, cs_info })
raw,
cs_info,
cs_lib: cs.library,
work_group_size: cs.wg_size,
work_group_memory_sizes: cs.wg_memory_sizes,
})
}) })
} }

View File

@ -302,6 +302,7 @@ struct PrivateCapabilities {
int64_atomics: bool, int64_atomics: bool,
float_atomics: bool, float_atomics: bool,
supports_shared_event: bool, supports_shared_event: bool,
mesh_shaders: bool,
supported_vertex_amplification_factor: u32, supported_vertex_amplification_factor: u32,
shader_barycentrics: bool, shader_barycentrics: bool,
supports_memoryless_storage: bool, supports_memoryless_storage: bool,
@ -609,12 +610,16 @@ struct MultiStageData<T> {
vs: T, vs: T,
fs: T, fs: T,
cs: T, cs: T,
ts: T,
ms: T,
} }
const NAGA_STAGES: MultiStageData<naga::ShaderStage> = MultiStageData { const NAGA_STAGES: MultiStageData<naga::ShaderStage> = MultiStageData {
vs: naga::ShaderStage::Vertex, vs: naga::ShaderStage::Vertex,
fs: naga::ShaderStage::Fragment, fs: naga::ShaderStage::Fragment,
cs: naga::ShaderStage::Compute, cs: naga::ShaderStage::Compute,
ts: naga::ShaderStage::Task,
ms: naga::ShaderStage::Mesh,
}; };
impl<T> ops::Index<naga::ShaderStage> for MultiStageData<T> { impl<T> ops::Index<naga::ShaderStage> for MultiStageData<T> {
@ -624,7 +629,8 @@ impl<T> ops::Index<naga::ShaderStage> for MultiStageData<T> {
naga::ShaderStage::Vertex => &self.vs, naga::ShaderStage::Vertex => &self.vs,
naga::ShaderStage::Fragment => &self.fs, naga::ShaderStage::Fragment => &self.fs,
naga::ShaderStage::Compute => &self.cs, naga::ShaderStage::Compute => &self.cs,
naga::ShaderStage::Task | naga::ShaderStage::Mesh => unreachable!(), naga::ShaderStage::Task => &self.ts,
naga::ShaderStage::Mesh => &self.ms,
} }
} }
} }
@ -635,6 +641,8 @@ impl<T> MultiStageData<T> {
vs: fun(&self.vs), vs: fun(&self.vs),
fs: fun(&self.fs), fs: fun(&self.fs),
cs: fun(&self.cs), cs: fun(&self.cs),
ts: fun(&self.ts),
ms: fun(&self.ms),
} }
} }
fn map<Y>(self, fun: impl Fn(T) -> Y) -> MultiStageData<Y> { fn map<Y>(self, fun: impl Fn(T) -> Y) -> MultiStageData<Y> {
@ -642,17 +650,23 @@ impl<T> MultiStageData<T> {
vs: fun(self.vs), vs: fun(self.vs),
fs: fun(self.fs), fs: fun(self.fs),
cs: fun(self.cs), cs: fun(self.cs),
ts: fun(self.ts),
ms: fun(self.ms),
} }
} }
fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T> { fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T> {
iter::once(&self.vs) iter::once(&self.vs)
.chain(iter::once(&self.fs)) .chain(iter::once(&self.fs))
.chain(iter::once(&self.cs)) .chain(iter::once(&self.cs))
.chain(iter::once(&self.ts))
.chain(iter::once(&self.ms))
} }
fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut T> { fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut T> {
iter::once(&mut self.vs) iter::once(&mut self.vs)
.chain(iter::once(&mut self.fs)) .chain(iter::once(&mut self.fs))
.chain(iter::once(&mut self.cs)) .chain(iter::once(&mut self.cs))
.chain(iter::once(&mut self.ts))
.chain(iter::once(&mut self.ms))
} }
} }
@ -816,6 +830,8 @@ impl crate::DynShaderModule for ShaderModule {}
#[derive(Debug, Default)] #[derive(Debug, Default)]
struct PipelineStageInfo { struct PipelineStageInfo {
#[allow(dead_code)]
library: Option<metal::Library>,
push_constants: Option<PushConstantsInfo>, push_constants: Option<PushConstantsInfo>,
/// The buffer argument table index at which we pass runtime-sized arrays' buffer sizes. /// The buffer argument table index at which we pass runtime-sized arrays' buffer sizes.
@ -830,6 +846,12 @@ struct PipelineStageInfo {
/// Info on all bound vertex buffers. /// Info on all bound vertex buffers.
vertex_buffer_mappings: Vec<naga::back::msl::VertexBufferMapping>, vertex_buffer_mappings: Vec<naga::back::msl::VertexBufferMapping>,
/// The workgroup size for compute, task or mesh stages
raw_wg_size: MTLSize,
/// The workgroup memory sizes for compute task or mesh stages
work_group_memory_sizes: Vec<u32>,
} }
impl PipelineStageInfo { impl PipelineStageInfo {
@ -838,6 +860,9 @@ impl PipelineStageInfo {
self.sizes_slot = None; self.sizes_slot = None;
self.sized_bindings.clear(); self.sized_bindings.clear();
self.vertex_buffer_mappings.clear(); self.vertex_buffer_mappings.clear();
self.library = None;
self.work_group_memory_sizes.clear();
self.raw_wg_size = Default::default();
} }
fn assign_from(&mut self, other: &Self) { fn assign_from(&mut self, other: &Self) {
@ -848,18 +873,21 @@ impl PipelineStageInfo {
self.vertex_buffer_mappings.clear(); self.vertex_buffer_mappings.clear();
self.vertex_buffer_mappings self.vertex_buffer_mappings
.extend_from_slice(&other.vertex_buffer_mappings); .extend_from_slice(&other.vertex_buffer_mappings);
self.library = Some(other.library.as_ref().unwrap().clone());
self.raw_wg_size = other.raw_wg_size;
self.work_group_memory_sizes.clear();
self.work_group_memory_sizes
.extend_from_slice(&other.work_group_memory_sizes);
} }
} }
#[derive(Debug)] #[derive(Debug)]
pub struct RenderPipeline { pub struct RenderPipeline {
raw: metal::RenderPipelineState, raw: metal::RenderPipelineState,
#[allow(dead_code)] vs_info: Option<PipelineStageInfo>,
vs_lib: metal::Library,
#[allow(dead_code)]
fs_lib: Option<metal::Library>,
vs_info: PipelineStageInfo,
fs_info: Option<PipelineStageInfo>, fs_info: Option<PipelineStageInfo>,
ts_info: Option<PipelineStageInfo>,
ms_info: Option<PipelineStageInfo>,
raw_primitive_type: MTLPrimitiveType, raw_primitive_type: MTLPrimitiveType,
raw_triangle_fill_mode: MTLTriangleFillMode, raw_triangle_fill_mode: MTLTriangleFillMode,
raw_front_winding: MTLWinding, raw_front_winding: MTLWinding,
@ -876,11 +904,7 @@ impl crate::DynRenderPipeline for RenderPipeline {}
#[derive(Debug)] #[derive(Debug)]
pub struct ComputePipeline { pub struct ComputePipeline {
raw: metal::ComputePipelineState, raw: metal::ComputePipelineState,
#[allow(dead_code)]
cs_lib: metal::Library,
cs_info: PipelineStageInfo, cs_info: PipelineStageInfo,
work_group_size: MTLSize,
work_group_memory_sizes: Vec<u32>,
} }
unsafe impl Send for ComputePipeline {} unsafe impl Send for ComputePipeline {}
@ -954,7 +978,6 @@ struct CommandState {
compute: Option<metal::ComputeCommandEncoder>, compute: Option<metal::ComputeCommandEncoder>,
raw_primitive_type: MTLPrimitiveType, raw_primitive_type: MTLPrimitiveType,
index: Option<IndexState>, index: Option<IndexState>,
raw_wg_size: MTLSize,
stage_infos: MultiStageData<PipelineStageInfo>, stage_infos: MultiStageData<PipelineStageInfo>,
/// Sizes of currently bound [`wgt::BufferBindingType::Storage`] buffers. /// Sizes of currently bound [`wgt::BufferBindingType::Storage`] buffers.
@ -980,7 +1003,6 @@ struct CommandState {
vertex_buffer_size_map: FastHashMap<u64, wgt::BufferSize>, vertex_buffer_size_map: FastHashMap<u64, wgt::BufferSize>,
work_group_memory_sizes: Vec<u32>,
push_constants: Vec<u32>, push_constants: Vec<u32>,
/// Timer query that should be executed when the next pass starts. /// Timer query that should be executed when the next pass starts.

View File

@ -1169,12 +1169,11 @@ bitflags_array! {
/// This is a native only feature. /// This is a native only feature.
const UNIFORM_BUFFER_BINDING_ARRAYS = 1 << 47; const UNIFORM_BUFFER_BINDING_ARRAYS = 1 << 47;
/// Enables mesh shaders and task shaders in mesh shader pipelines. /// Enables mesh shaders and task shaders in mesh shader pipelines. This extension does NOT imply support for
/// compiling mesh shaders at runtime. Rather, the user must use custom passthrough shaders.
/// ///
/// Supported platforms: /// Supported platforms:
/// - Vulkan (with [VK_EXT_mesh_shader](https://registry.khronos.org/vulkan/specs/latest/man/html/VK_EXT_mesh_shader.html)) /// - Vulkan (with [VK_EXT_mesh_shader](https://registry.khronos.org/vulkan/specs/latest/man/html/VK_EXT_mesh_shader.html))
///
/// Potential Platforms:
/// - DX12 /// - DX12
/// - Metal /// - Metal
/// ///

View File

@ -1062,14 +1062,12 @@ impl Limits {
#[must_use] #[must_use]
pub const fn using_recommended_minimum_mesh_shader_values(self) -> Self { pub const fn using_recommended_minimum_mesh_shader_values(self) -> Self {
Self { Self {
// Literally just made this up as 256^2 or 2^16. // This is a common limit for apple devices. It's not immediately clear why.
// My GPU supports 2^22, and compute shaders don't have this kind of limit. max_task_workgroup_total_count: 1024,
// This very likely is never a real limiter max_task_workgroups_per_dimension: 1024,
max_task_workgroup_total_count: 65536,
max_task_workgroups_per_dimension: 256,
// llvmpipe reports 0 multiview count, which just means no multiview is allowed // llvmpipe reports 0 multiview count, which just means no multiview is allowed
max_mesh_multiview_view_count: 0, max_mesh_multiview_view_count: 0,
// llvmpipe once again requires this to be 8. An RTX 3060 supports well over 1024. // llvmpipe once again requires this to be <=8. An RTX 3060 supports well over 1024.
max_mesh_output_layers: 8, max_mesh_output_layers: 8,
..self ..self
} }