[hal/metal] Mesh Shaders (#8139)

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
Co-authored-by: Magnus <85136135+SupaMaggie70Incorporated@users.noreply.github.com>
This commit is contained in:
Inner Daemons 2025-11-14 23:11:43 -05:00 committed by GitHub
parent 92fa99af1b
commit d0cf78c8c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 906 additions and 389 deletions

View File

@ -125,6 +125,9 @@ By @SupaMaggie70Incorporated in [#8206](https://github.com/gfx-rs/wgpu/pull/8206
- `util::StagingBelt` now takes a `Device` when it is created instead of when it is used. By @kpreid in [#8462](https://github.com/gfx-rs/wgpu/pull/8462).
- `wgpu_hal::vulkan::Device::texture_from_raw` now takes an `external_memory` argument. By @s-ol in [#8512](https://github.com/gfx-rs/wgpu/pull/8512)
#### Metal
- Add support for mesh shaders. By @SupaMaggie70Incorporated in [#8139](https://github.com/gfx-rs/wgpu/pull/8139)
#### Naga
- Prevent UB with invalid ray query calls on spirv. By @Vecvec in [#8390](https://github.com/gfx-rs/wgpu/pull/8390).

View File

@ -49,6 +49,7 @@ fn all_tests() -> Vec<wgpu_test::GpuTestInitializer> {
cube::TEST,
cube::TEST_LINES,
hello_synchronization::tests::SYNC,
mesh_shader::TEST,
mipmap::TEST,
mipmap::TEST_QUERY,
msaa_line::TEST,

View File

@ -61,6 +61,18 @@ fn compile_hlsl(device: &wgpu::Device, entry: &str, stage_str: &str) -> wgpu::Sh
}
}
fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule {
unsafe {
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
entry_point: entry.to_owned(),
label: None,
msl: Some(std::borrow::Cow::Borrowed(include_str!("shader.metal"))),
num_workgroups: (1, 1, 1),
..Default::default()
})
}
}
pub struct Example {
pipeline: wgpu::RenderPipeline,
}
@ -71,20 +83,23 @@ impl crate::framework::Example for Example {
device: &wgpu::Device,
_queue: &wgpu::Queue,
) -> Self {
let (ts, ms, fs) = if adapter.get_info().backend == wgpu::Backend::Vulkan {
(
let (ts, ms, fs) = match adapter.get_info().backend {
wgpu::Backend::Vulkan => (
compile_glsl(device, "task"),
compile_glsl(device, "mesh"),
compile_glsl(device, "frag"),
)
} else if adapter.get_info().backend == wgpu::Backend::Dx12 {
(
),
wgpu::Backend::Dx12 => (
compile_hlsl(device, "Task", "as"),
compile_hlsl(device, "Mesh", "ms"),
compile_hlsl(device, "Frag", "ps"),
)
} else {
panic!("Example can only run on vulkan or dx12");
),
wgpu::Backend::Metal => (
compile_msl(device, "taskShader"),
compile_msl(device, "meshShader"),
compile_msl(device, "fragShader"),
),
_ => panic!("Example can currently only run on vulkan, dx12 or metal"),
};
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
@ -179,3 +194,21 @@ impl crate::framework::Example for Example {
pub fn main() {
crate::framework::run::<Example>("mesh_shader");
}
#[cfg(test)]
#[wgpu_test::gpu_test]
pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams {
name: "mesh_shader",
image_path: "/examples/features/src/mesh_shader/screenshot.png",
width: 1024,
height: 768,
optional_features: wgpu::Features::default(),
base_test_parameters: wgpu_test::TestParameters::default()
.features(
wgpu::Features::EXPERIMENTAL_MESH_SHADER
| wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS,
)
.limits(wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values()),
comparisons: &[wgpu_test::ComparisonType::Mean(0.01)],
_phantom: std::marker::PhantomData::<Example>,
};

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

View File

@ -0,0 +1,77 @@
using namespace metal;
struct OutVertex {
float4 Position [[position]];
float4 Color [[user(locn0)]];
};
struct OutPrimitive {
float4 ColorMask [[flat]] [[user(locn1)]];
bool CullPrimitive [[primitive_culled]];
};
struct InVertex {
};
struct InPrimitive {
float4 ColorMask [[flat]] [[user(locn1)]];
};
struct FragmentIn {
float4 Color [[user(locn0)]];
float4 ColorMask [[flat]] [[user(locn1)]];
};
struct PayloadData {
float4 ColorMask;
bool Visible;
};
using Meshlet = metal::mesh<OutVertex, OutPrimitive, 3, 1, topology::triangle>;
constant float4 positions[3] = {
float4(0.0, 1.0, 0.0, 1.0),
float4(-1.0, -1.0, 0.0, 1.0),
float4(1.0, -1.0, 0.0, 1.0)
};
constant float4 colors[3] = {
float4(0.0, 1.0, 0.0, 1.0),
float4(0.0, 0.0, 1.0, 1.0),
float4(1.0, 0.0, 0.0, 1.0)
};
[[object]]
void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], mesh_grid_properties grid) {
outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0);
outPayload.Visible = true;
grid.set_threadgroups_per_grid(uint3(3, 1, 1));
}
[[mesh]]
void meshShader(
object_data PayloadData const& payload [[payload]],
Meshlet out
)
{
out.set_primitive_count(1);
for(int i = 0;i < 3;i++) {
OutVertex vert;
vert.Position = positions[i];
vert.Color = colors[i] * payload.ColorMask;
out.set_vertex(i, vert);
out.set_index(i, i);
}
OutPrimitive prim;
prim.ColorMask = float4(1.0, 0.0, 0.0, 1.0);
prim.CullPrimitive = !payload.Visible;
out.set_primitive(0, prim);
}
fragment float4 fragShader(FragmentIn data [[stage_in]]) {
return data.Color * data.ColorMask;
}

View File

@ -3,15 +3,11 @@ use std::{
process::Stdio,
};
use wgpu::{util::DeviceExt, Backends};
use wgpu::util::DeviceExt;
use wgpu_test::{
fail, gpu_test, FailureCase, GpuTestConfiguration, GpuTestInitializer, TestParameters,
TestingContext,
fail, gpu_test, GpuTestConfiguration, GpuTestInitializer, TestParameters, TestingContext,
};
/// Backends that support mesh shaders
const MESH_SHADER_BACKENDS: Backends = Backends::DX12.union(Backends::VULKAN);
pub fn all_tests(tests: &mut Vec<GpuTestInitializer>) {
tests.extend([
MESH_PIPELINE_BASIC_MESH,
@ -98,6 +94,18 @@ fn compile_hlsl(
}
}
fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule {
unsafe {
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
entry_point: entry.to_owned(),
label: None,
msl: Some(std::borrow::Cow::Borrowed(include_str!("shader.metal"))),
num_workgroups: (1, 1, 1),
..Default::default()
})
}
}
fn get_shaders(
device: &wgpu::Device,
backend: wgpu::Backend,
@ -114,8 +122,8 @@ fn get_shaders(
// (In the case that the platform does support mesh shaders, the dummy
// shader is used to avoid requiring EXPERIMENTAL_PASSTHROUGH_SHADERS.)
let dummy_shader = device.create_shader_module(wgpu::include_wgsl!("non_mesh.wgsl"));
if backend == wgpu::Backend::Vulkan {
(
match backend {
wgpu::Backend::Vulkan => (
info.use_task.then(|| compile_glsl(device, "task")),
if info.use_mesh {
compile_glsl(device, "mesh")
@ -123,9 +131,8 @@ fn get_shaders(
dummy_shader
},
info.use_frag.then(|| compile_glsl(device, "frag")),
)
} else if backend == wgpu::Backend::Dx12 {
(
),
wgpu::Backend::Dx12 => (
info.use_task
.then(|| compile_hlsl(device, "Task", "as", test_name)),
if info.use_mesh {
@ -135,11 +142,20 @@ fn get_shaders(
},
info.use_frag
.then(|| compile_hlsl(device, "Frag", "ps", test_name)),
)
} else {
assert!(!MESH_SHADER_BACKENDS.contains(Backends::from(backend)));
assert!(!info.use_task && !info.use_mesh && !info.use_frag);
(None, dummy_shader, None)
),
wgpu::Backend::Metal => (
info.use_task.then(|| compile_msl(device, "taskShader")),
if info.use_mesh {
compile_msl(device, "meshShader")
} else {
dummy_shader
},
info.use_frag.then(|| compile_msl(device, "fragShader")),
),
_ => {
assert!(!info.use_task && !info.use_mesh && !info.use_frag);
(None, dummy_shader, None)
}
}
}
@ -377,7 +393,6 @@ fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) {
fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration {
GpuTestConfiguration::new().parameters(
TestParameters::default()
.skip(FailureCase::backend(!MESH_SHADER_BACKENDS))
.test_features_limits()
.features(
wgpu::Features::EXPERIMENTAL_MESH_SHADER

View File

@ -0,0 +1,77 @@
using namespace metal;
struct OutVertex {
float4 Position [[position]];
float4 Color [[user(locn0)]];
};
struct OutPrimitive {
float4 ColorMask [[flat]] [[user(locn1)]];
bool CullPrimitive [[primitive_culled]];
};
struct InVertex {
};
struct InPrimitive {
float4 ColorMask [[flat]] [[user(locn1)]];
};
struct FragmentIn {
float4 Color [[user(locn0)]];
float4 ColorMask [[flat]] [[user(locn1)]];
};
struct PayloadData {
float4 ColorMask;
bool Visible;
};
using Meshlet = metal::mesh<OutVertex, OutPrimitive, 3, 1, topology::triangle>;
constant float4 positions[3] = {
float4(0.0, 1.0, 0.0, 1.0),
float4(-1.0, -1.0, 0.0, 1.0),
float4(1.0, -1.0, 0.0, 1.0)
};
constant float4 colors[3] = {
float4(0.0, 1.0, 0.0, 1.0),
float4(0.0, 0.0, 1.0, 1.0),
float4(1.0, 0.0, 0.0, 1.0)
};
[[object]]
void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], mesh_grid_properties grid) {
outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0);
outPayload.Visible = true;
grid.set_threadgroups_per_grid(uint3(3, 1, 1));
}
[[mesh]]
void meshShader(
object_data PayloadData const& payload [[payload]],
Meshlet out
)
{
out.set_primitive_count(1);
for(int i = 0;i < 3;i++) {
OutVertex vert;
vert.Position = positions[i];
vert.Color = colors[i] * payload.ColorMask;
out.set_vertex(i, vert);
out.set_index(i, i);
}
OutPrimitive prim;
prim.ColorMask = float4(1.0, 0.0, 0.0, 1.0);
prim.CullPrimitive = !payload.Visible;
out.set_primitive(0, prim);
}
fragment float4 fragShader(FragmentIn data [[stage_in]]) {
return data.Color * data.ColorMask;
}

View File

@ -607,6 +607,8 @@ impl super::PrivateCapabilities {
let argument_buffers = device.argument_buffers_support();
let is_virtual = device.name().to_lowercase().contains("virtual");
Self {
family_check,
msl_version: if os_is_xr || version.at_least((14, 0), (17, 0), os_is_mac) {
@ -902,6 +904,12 @@ impl super::PrivateCapabilities {
&& (device.supports_family(MTLGPUFamily::Apple7)
|| device.supports_family(MTLGPUFamily::Mac2)),
supports_shared_event: version.at_least((10, 14), (12, 0), os_is_mac),
mesh_shaders: family_check
&& (device.supports_family(MTLGPUFamily::Metal3)
|| device.supports_family(MTLGPUFamily::Apple7)
|| device.supports_family(MTLGPUFamily::Mac2))
// Mesh shaders don't work on virtual devices even if they should be supported.
&& !is_virtual,
supported_vertex_amplification_factor: {
let mut factor = 1;
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf#page=8
@ -1023,6 +1031,8 @@ impl super::PrivateCapabilities {
features.insert(F::SUBGROUP | F::SUBGROUP_BARRIER);
}
features.set(F::EXPERIMENTAL_MESH_SHADER, self.mesh_shaders);
if self.supported_vertex_amplification_factor > 1 {
features.insert(F::MULTIVIEW);
}
@ -1102,10 +1112,11 @@ impl super::PrivateCapabilities {
max_buffer_size: self.max_buffer_size,
max_non_sampler_bindings: u32::MAX,
max_task_workgroup_total_count: 0,
max_task_workgroups_per_dimension: 0,
// See https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf, Maximum threadgroups per mesh shader grid
max_task_workgroup_total_count: 1024,
max_task_workgroups_per_dimension: 1024,
max_mesh_multiview_view_count: 0,
max_mesh_output_layers: 0,
max_mesh_output_layers: self.max_texture_layers as u32,
max_blas_primitive_count: 0, // When added: 2^28 from https://developer.apple.com/documentation/metal/mtlaccelerationstructureusage/extendedlimits
max_blas_geometry_count: 0, // When added: 2^24

View File

@ -22,11 +22,9 @@ impl Default for super::CommandState {
compute: None,
raw_primitive_type: MTLPrimitiveType::Point,
index: None,
raw_wg_size: MTLSize::new(0, 0, 0),
stage_infos: Default::default(),
storage_buffer_length_map: Default::default(),
vertex_buffer_size_map: Default::default(),
work_group_memory_sizes: Vec::new(),
push_constants: Vec::new(),
pending_timer_queries: Vec::new(),
}
@ -146,6 +144,127 @@ impl super::CommandEncoder {
self.state.reset();
self.leave_blit();
}
/// Updates the bindings for a single shader stage, called in `set_bind_group`.
#[expect(clippy::too_many_arguments)]
fn update_bind_group_state(
&mut self,
stage: naga::ShaderStage,
render_encoder: Option<&metal::RenderCommandEncoder>,
compute_encoder: Option<&metal::ComputeCommandEncoder>,
index_base: super::ResourceData<u32>,
bg_info: &super::BindGroupLayoutInfo,
dynamic_offsets: &[wgt::DynamicOffset],
group_index: u32,
group: &super::BindGroup,
) {
let resource_indices = match stage {
naga::ShaderStage::Vertex => &bg_info.base_resource_indices.vs,
naga::ShaderStage::Fragment => &bg_info.base_resource_indices.fs,
naga::ShaderStage::Task => &bg_info.base_resource_indices.ts,
naga::ShaderStage::Mesh => &bg_info.base_resource_indices.ms,
naga::ShaderStage::Compute => &bg_info.base_resource_indices.cs,
};
let buffers = match stage {
naga::ShaderStage::Vertex => group.counters.vs.buffers,
naga::ShaderStage::Fragment => group.counters.fs.buffers,
naga::ShaderStage::Task => group.counters.ts.buffers,
naga::ShaderStage::Mesh => group.counters.ms.buffers,
naga::ShaderStage::Compute => group.counters.cs.buffers,
};
let mut changes_sizes_buffer = false;
for index in 0..buffers {
let buf = &group.buffers[(index_base.buffers + index) as usize];
let mut offset = buf.offset;
if let Some(dyn_index) = buf.dynamic_index {
offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress;
}
let a1 = (resource_indices.buffers + index) as u64;
let a2 = Some(buf.ptr.as_native());
let a3 = offset;
match stage {
naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_buffer(a1, a2, a3),
naga::ShaderStage::Fragment => {
render_encoder.unwrap().set_fragment_buffer(a1, a2, a3)
}
naga::ShaderStage::Task => render_encoder.unwrap().set_object_buffer(a1, a2, a3),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_buffer(a1, a2, a3),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_buffer(a1, a2, a3),
}
if let Some(size) = buf.binding_size {
let br = naga::ResourceBinding {
group: group_index,
binding: buf.binding_location,
};
self.state.storage_buffer_length_map.insert(br, size);
changes_sizes_buffer = true;
}
}
if changes_sizes_buffer {
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(stage, &mut self.temp.binding_sizes)
{
let a1 = index as _;
let a2 = (sizes.len() * WORD_SIZE) as u64;
let a3 = sizes.as_ptr().cast();
match stage {
naga::ShaderStage::Vertex => {
render_encoder.unwrap().set_vertex_bytes(a1, a2, a3)
}
naga::ShaderStage::Fragment => {
render_encoder.unwrap().set_fragment_bytes(a1, a2, a3)
}
naga::ShaderStage::Task => render_encoder.unwrap().set_object_bytes(a1, a2, a3),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_bytes(a1, a2, a3),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_bytes(a1, a2, a3),
}
}
}
let samplers = match stage {
naga::ShaderStage::Vertex => group.counters.vs.samplers,
naga::ShaderStage::Fragment => group.counters.fs.samplers,
naga::ShaderStage::Task => group.counters.ts.samplers,
naga::ShaderStage::Mesh => group.counters.ms.samplers,
naga::ShaderStage::Compute => group.counters.cs.samplers,
};
for index in 0..samplers {
let res = group.samplers[(index_base.samplers + index) as usize];
let a1 = (resource_indices.samplers + index) as u64;
let a2 = Some(res.as_native());
match stage {
naga::ShaderStage::Vertex => {
render_encoder.unwrap().set_vertex_sampler_state(a1, a2)
}
naga::ShaderStage::Fragment => {
render_encoder.unwrap().set_fragment_sampler_state(a1, a2)
}
naga::ShaderStage::Task => render_encoder.unwrap().set_object_sampler_state(a1, a2),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_sampler_state(a1, a2),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_sampler_state(a1, a2),
}
}
let textures = match stage {
naga::ShaderStage::Vertex => group.counters.vs.textures,
naga::ShaderStage::Fragment => group.counters.fs.textures,
naga::ShaderStage::Task => group.counters.ts.textures,
naga::ShaderStage::Mesh => group.counters.ms.textures,
naga::ShaderStage::Compute => group.counters.cs.textures,
};
for index in 0..textures {
let res = group.textures[(index_base.textures + index) as usize];
let a1 = (resource_indices.textures + index) as u64;
let a2 = Some(res.as_native());
match stage {
naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_texture(a1, a2),
naga::ShaderStage::Fragment => render_encoder.unwrap().set_fragment_texture(a1, a2),
naga::ShaderStage::Task => render_encoder.unwrap().set_object_texture(a1, a2),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_texture(a1, a2),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_texture(a1, a2),
}
}
}
}
impl super::CommandState {
@ -155,7 +274,8 @@ impl super::CommandState {
self.stage_infos.vs.clear();
self.stage_infos.fs.clear();
self.stage_infos.cs.clear();
self.work_group_memory_sizes.clear();
self.stage_infos.ts.clear();
self.stage_infos.ms.clear();
self.push_constants.clear();
}
@ -702,168 +822,90 @@ impl crate::CommandEncoder for super::CommandEncoder {
dynamic_offsets: &[wgt::DynamicOffset],
) {
let bg_info = &layout.bind_group_infos[group_index as usize];
if let Some(ref encoder) = self.state.render {
let mut changes_sizes_buffer = false;
for index in 0..group.counters.vs.buffers {
let buf = &group.buffers[index as usize];
let mut offset = buf.offset;
if let Some(dyn_index) = buf.dynamic_index {
offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress;
}
encoder.set_vertex_buffer(
(bg_info.base_resource_indices.vs.buffers + index) as u64,
Some(buf.ptr.as_native()),
offset,
);
if let Some(size) = buf.binding_size {
let br = naga::ResourceBinding {
group: group_index,
binding: buf.binding_location,
};
self.state.storage_buffer_length_map.insert(br, size);
changes_sizes_buffer = true;
}
}
if changes_sizes_buffer {
if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
naga::ShaderStage::Vertex,
&mut self.temp.binding_sizes,
) {
encoder.set_vertex_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr().cast(),
);
}
}
changes_sizes_buffer = false;
for index in 0..group.counters.fs.buffers {
let buf = &group.buffers[(group.counters.vs.buffers + index) as usize];
let mut offset = buf.offset;
if let Some(dyn_index) = buf.dynamic_index {
offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress;
}
encoder.set_fragment_buffer(
(bg_info.base_resource_indices.fs.buffers + index) as u64,
Some(buf.ptr.as_native()),
offset,
);
if let Some(size) = buf.binding_size {
let br = naga::ResourceBinding {
group: group_index,
binding: buf.binding_location,
};
self.state.storage_buffer_length_map.insert(br, size);
changes_sizes_buffer = true;
}
}
if changes_sizes_buffer {
if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
naga::ShaderStage::Fragment,
&mut self.temp.binding_sizes,
) {
encoder.set_fragment_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr().cast(),
);
}
}
for index in 0..group.counters.vs.samplers {
let res = group.samplers[index as usize];
encoder.set_vertex_sampler_state(
(bg_info.base_resource_indices.vs.samplers + index) as u64,
Some(res.as_native()),
);
}
for index in 0..group.counters.fs.samplers {
let res = group.samplers[(group.counters.vs.samplers + index) as usize];
encoder.set_fragment_sampler_state(
(bg_info.base_resource_indices.fs.samplers + index) as u64,
Some(res.as_native()),
);
}
for index in 0..group.counters.vs.textures {
let res = group.textures[index as usize];
encoder.set_vertex_texture(
(bg_info.base_resource_indices.vs.textures + index) as u64,
Some(res.as_native()),
);
}
for index in 0..group.counters.fs.textures {
let res = group.textures[(group.counters.vs.textures + index) as usize];
encoder.set_fragment_texture(
(bg_info.base_resource_indices.fs.textures + index) as u64,
Some(res.as_native()),
);
}
let render_encoder = self.state.render.clone();
let compute_encoder = self.state.compute.clone();
if let Some(encoder) = render_encoder {
self.update_bind_group_state(
naga::ShaderStage::Vertex,
Some(&encoder),
None,
// All zeros, as vs comes first
super::ResourceData::default(),
bg_info,
dynamic_offsets,
group_index,
group,
);
self.update_bind_group_state(
naga::ShaderStage::Task,
Some(&encoder),
None,
// All zeros, as ts comes first
super::ResourceData::default(),
bg_info,
dynamic_offsets,
group_index,
group,
);
self.update_bind_group_state(
naga::ShaderStage::Mesh,
Some(&encoder),
None,
group.counters.ts.clone(),
bg_info,
dynamic_offsets,
group_index,
group,
);
self.update_bind_group_state(
naga::ShaderStage::Fragment,
Some(&encoder),
None,
super::ResourceData {
buffers: group.counters.vs.buffers
+ group.counters.ts.buffers
+ group.counters.ms.buffers,
textures: group.counters.vs.textures
+ group.counters.ts.textures
+ group.counters.ms.textures,
samplers: group.counters.vs.samplers
+ group.counters.ts.samplers
+ group.counters.ms.samplers,
},
bg_info,
dynamic_offsets,
group_index,
group,
);
// Call useResource on all textures and buffers used indirectly so they are alive
for (resource, use_info) in group.resources_to_use.iter() {
encoder.use_resource_at(resource.as_native(), use_info.uses, use_info.stages);
}
}
if let Some(ref encoder) = self.state.compute {
let index_base = super::ResourceData {
buffers: group.counters.vs.buffers + group.counters.fs.buffers,
samplers: group.counters.vs.samplers + group.counters.fs.samplers,
textures: group.counters.vs.textures + group.counters.fs.textures,
};
let mut changes_sizes_buffer = false;
for index in 0..group.counters.cs.buffers {
let buf = &group.buffers[(index_base.buffers + index) as usize];
let mut offset = buf.offset;
if let Some(dyn_index) = buf.dynamic_index {
offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress;
}
encoder.set_buffer(
(bg_info.base_resource_indices.cs.buffers + index) as u64,
Some(buf.ptr.as_native()),
offset,
);
if let Some(size) = buf.binding_size {
let br = naga::ResourceBinding {
group: group_index,
binding: buf.binding_location,
};
self.state.storage_buffer_length_map.insert(br, size);
changes_sizes_buffer = true;
}
}
if changes_sizes_buffer {
if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
naga::ShaderStage::Compute,
&mut self.temp.binding_sizes,
) {
encoder.set_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr().cast(),
);
}
}
for index in 0..group.counters.cs.samplers {
let res = group.samplers[(index_base.samplers + index) as usize];
encoder.set_sampler_state(
(bg_info.base_resource_indices.cs.samplers + index) as u64,
Some(res.as_native()),
);
}
for index in 0..group.counters.cs.textures {
let res = group.textures[(index_base.textures + index) as usize];
encoder.set_texture(
(bg_info.base_resource_indices.cs.textures + index) as u64,
Some(res.as_native()),
);
}
if let Some(encoder) = compute_encoder {
self.update_bind_group_state(
naga::ShaderStage::Compute,
None,
Some(&encoder),
super::ResourceData {
buffers: group.counters.vs.buffers
+ group.counters.ts.buffers
+ group.counters.ms.buffers
+ group.counters.fs.buffers,
textures: group.counters.vs.textures
+ group.counters.ts.textures
+ group.counters.ms.textures
+ group.counters.fs.textures,
samplers: group.counters.vs.samplers
+ group.counters.ts.samplers
+ group.counters.ms.samplers
+ group.counters.fs.samplers,
},
bg_info,
dynamic_offsets,
group_index,
group,
);
// Call useResource on all textures and buffers used indirectly so they are alive
for (resource, use_info) in group.resources_to_use.iter() {
if !use_info.visible_in_compute {
@ -911,6 +953,20 @@ impl crate::CommandEncoder for super::CommandEncoder {
state_pc.as_ptr().cast(),
)
}
if stages.contains(wgt::ShaderStages::TASK) {
self.state.render.as_ref().unwrap().set_object_bytes(
layout.push_constants_infos.ts.unwrap().buffer_index as _,
(layout.total_push_constants as usize * WORD_SIZE) as _,
state_pc.as_ptr().cast(),
)
}
if stages.contains(wgt::ShaderStages::MESH) {
self.state.render.as_ref().unwrap().set_object_bytes(
layout.push_constants_infos.ms.unwrap().buffer_index as _,
(layout.total_push_constants as usize * WORD_SIZE) as _,
state_pc.as_ptr().cast(),
)
}
}
unsafe fn insert_debug_marker(&mut self, label: &str) {
@ -935,11 +991,22 @@ impl crate::CommandEncoder for super::CommandEncoder {
unsafe fn set_render_pipeline(&mut self, pipeline: &super::RenderPipeline) {
self.state.raw_primitive_type = pipeline.raw_primitive_type;
self.state.stage_infos.vs.assign_from(&pipeline.vs_info);
match pipeline.vs_info {
Some(ref info) => self.state.stage_infos.vs.assign_from(info),
None => self.state.stage_infos.vs.clear(),
}
match pipeline.fs_info {
Some(ref info) => self.state.stage_infos.fs.assign_from(info),
None => self.state.stage_infos.fs.clear(),
}
match pipeline.ts_info {
Some(ref info) => self.state.stage_infos.ts.assign_from(info),
None => self.state.stage_infos.ts.clear(),
}
match pipeline.ms_info {
Some(ref info) => self.state.stage_infos.ms.assign_from(info),
None => self.state.stage_infos.ms.clear(),
}
let encoder = self.state.render.as_ref().unwrap();
encoder.set_render_pipeline_state(&pipeline.raw);
@ -954,7 +1021,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
encoder.set_depth_bias(bias.constant as f32, bias.slope_scale, bias.clamp);
}
{
if pipeline.vs_info.is_some() {
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(naga::ShaderStage::Vertex, &mut self.temp.binding_sizes)
@ -966,7 +1033,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
);
}
}
if pipeline.fs_lib.is_some() {
if pipeline.fs_info.is_some() {
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(naga::ShaderStage::Fragment, &mut self.temp.binding_sizes)
@ -978,6 +1045,56 @@ impl crate::CommandEncoder for super::CommandEncoder {
);
}
}
if let Some(ts_info) = &pipeline.ts_info {
// update the threadgroup memory sizes
while self.state.stage_infos.ms.work_group_memory_sizes.len()
< ts_info.work_group_memory_sizes.len()
{
self.state.stage_infos.ms.work_group_memory_sizes.push(0);
}
for (index, (cur_size, pipeline_size)) in self
.state
.stage_infos
.ms
.work_group_memory_sizes
.iter_mut()
.zip(ts_info.work_group_memory_sizes.iter())
.enumerate()
{
let size = pipeline_size.next_multiple_of(16);
if *cur_size != size {
*cur_size = size;
encoder.set_object_threadgroup_memory_length(index as _, size as _);
}
}
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(naga::ShaderStage::Task, &mut self.temp.binding_sizes)
{
encoder.set_object_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr().cast(),
);
}
}
if let Some(_ms_info) = &pipeline.ms_info {
// So there isn't an equivalent to
// https://developer.apple.com/documentation/metal/mtlrendercommandencoder/setthreadgroupmemorylength(_:offset:index:)
// for mesh shaders. This is probably because the CPU has less control over the dispatch sizes and such. Interestingly
// it also affects mesh shaders without task/object shaders, even though none of compute, task or fragment shaders
// behave this way.
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(naga::ShaderStage::Mesh, &mut self.temp.binding_sizes)
{
encoder.set_mesh_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr().cast(),
);
}
}
}
unsafe fn set_index_buffer<'a>(
@ -1140,11 +1257,21 @@ impl crate::CommandEncoder for super::CommandEncoder {
unsafe fn draw_mesh_tasks(
&mut self,
_group_count_x: u32,
_group_count_y: u32,
_group_count_z: u32,
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
) {
unreachable!()
let encoder = self.state.render.as_ref().unwrap();
let size = MTLSize {
width: group_count_x as u64,
height: group_count_y as u64,
depth: group_count_z as u64,
};
encoder.draw_mesh_threadgroups(
size,
self.state.stage_infos.ts.raw_wg_size,
self.state.stage_infos.ms.raw_wg_size,
);
}
unsafe fn draw_indirect(
@ -1183,11 +1310,20 @@ impl crate::CommandEncoder for super::CommandEncoder {
unsafe fn draw_mesh_tasks_indirect(
&mut self,
_buffer: &<Self::A as crate::Api>::Buffer,
_offset: wgt::BufferAddress,
_draw_count: u32,
buffer: &<Self::A as crate::Api>::Buffer,
mut offset: wgt::BufferAddress,
draw_count: u32,
) {
unreachable!()
let encoder = self.state.render.as_ref().unwrap();
for _ in 0..draw_count {
encoder.draw_mesh_threadgroups_with_indirect_buffer(
&buffer.raw,
offset,
self.state.stage_infos.ts.raw_wg_size,
self.state.stage_infos.ms.raw_wg_size,
);
offset += size_of::<wgt::DispatchIndirectArgs>() as wgt::BufferAddress;
}
}
unsafe fn draw_indirect_count(
@ -1295,7 +1431,8 @@ impl crate::CommandEncoder for super::CommandEncoder {
}
unsafe fn set_compute_pipeline(&mut self, pipeline: &super::ComputePipeline) {
self.state.raw_wg_size = pipeline.work_group_size;
let previous_sizes =
core::mem::take(&mut self.state.stage_infos.cs.work_group_memory_sizes);
self.state.stage_infos.cs.assign_from(&pipeline.cs_info);
let encoder = self.state.compute.as_ref().unwrap();
@ -1313,20 +1450,23 @@ impl crate::CommandEncoder for super::CommandEncoder {
}
// update the threadgroup memory sizes
while self.state.work_group_memory_sizes.len() < pipeline.work_group_memory_sizes.len() {
self.state.work_group_memory_sizes.push(0);
}
for (index, (cur_size, pipeline_size)) in self
for (i, current_size) in self
.state
.stage_infos
.cs
.work_group_memory_sizes
.iter_mut()
.zip(pipeline.work_group_memory_sizes.iter())
.enumerate()
{
let size = pipeline_size.next_multiple_of(16);
if *cur_size != size {
*cur_size = size;
encoder.set_threadgroup_memory_length(index as _, size as _);
let prev_size = if i < previous_sizes.len() {
previous_sizes[i]
} else {
u32::MAX
};
let size: u32 = current_size.next_multiple_of(16);
*current_size = size;
if size != prev_size {
encoder.set_threadgroup_memory_length(i as _, size as _);
}
}
}
@ -1339,13 +1479,17 @@ impl crate::CommandEncoder for super::CommandEncoder {
height: count[1] as u64,
depth: count[2] as u64,
};
encoder.dispatch_thread_groups(raw_count, self.state.raw_wg_size);
encoder.dispatch_thread_groups(raw_count, self.state.stage_infos.cs.raw_wg_size);
}
}
unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
let encoder = self.state.compute.as_ref().unwrap();
encoder.dispatch_thread_groups_indirect(&buffer.raw, offset, self.state.raw_wg_size);
encoder.dispatch_thread_groups_indirect(
&buffer.raw,
offset,
self.state.stage_infos.cs.raw_wg_size,
);
}
unsafe fn build_acceleration_structures<'a, T>(

View File

@ -1113,16 +1113,261 @@ impl crate::Device for super::Device {
super::PipelineCache,
>,
) -> Result<super::RenderPipeline, crate::PipelineError> {
let (desc_vertex_stage, desc_vertex_buffers) = match &desc.vertex_processor {
crate::VertexProcessor::Standard {
vertex_buffers,
vertex_stage,
} => (vertex_stage, *vertex_buffers),
crate::VertexProcessor::Mesh { .. } => unreachable!(),
};
objc::rc::autoreleasepool(|| {
let descriptor = metal::RenderPipelineDescriptor::new();
enum MetalGenericRenderPipelineDescriptor {
Standard(metal::RenderPipelineDescriptor),
Mesh(metal::MeshRenderPipelineDescriptor),
}
macro_rules! descriptor_fn {
($descriptor:ident . $method:ident $( ( $($args:expr),* ) )? ) => {
match $descriptor {
MetalGenericRenderPipelineDescriptor::Standard(ref inner) => inner.$method$(($($args),*))?,
MetalGenericRenderPipelineDescriptor::Mesh(ref inner) => inner.$method$(($($args),*))?,
}
};
}
impl MetalGenericRenderPipelineDescriptor {
fn set_fragment_function(&self, function: Option<&metal::FunctionRef>) {
descriptor_fn!(self.set_fragment_function(function));
}
fn fragment_buffers(&self) -> Option<&metal::PipelineBufferDescriptorArrayRef> {
descriptor_fn!(self.fragment_buffers())
}
fn set_depth_attachment_pixel_format(&self, pixel_format: MTLPixelFormat) {
descriptor_fn!(self.set_depth_attachment_pixel_format(pixel_format));
}
fn color_attachments(
&self,
) -> &metal::RenderPipelineColorAttachmentDescriptorArrayRef {
descriptor_fn!(self.color_attachments())
}
fn set_stencil_attachment_pixel_format(&self, pixel_format: MTLPixelFormat) {
descriptor_fn!(self.set_stencil_attachment_pixel_format(pixel_format));
}
fn set_alpha_to_coverage_enabled(&self, enabled: bool) {
descriptor_fn!(self.set_alpha_to_coverage_enabled(enabled));
}
fn set_label(&self, label: &str) {
descriptor_fn!(self.set_label(label));
}
fn set_max_vertex_amplification_count(&self, count: metal::NSUInteger) {
descriptor_fn!(self.set_max_vertex_amplification_count(count))
}
}
let (primitive_class, raw_primitive_type) =
conv::map_primitive_topology(desc.primitive.topology);
let vs_info;
let ts_info;
let ms_info;
// Create the pipeline descriptor and do vertex/mesh pipeline specific setup
let descriptor = match desc.vertex_processor {
crate::VertexProcessor::Standard {
vertex_buffers,
ref vertex_stage,
} => {
// Vertex pipeline specific setup
let descriptor = metal::RenderPipelineDescriptor::new();
ts_info = None;
ms_info = None;
// Collect vertex buffer mappings
let mut vertex_buffer_mappings =
Vec::<naga::back::msl::VertexBufferMapping>::new();
for (i, vbl) in vertex_buffers.iter().enumerate() {
let mut attributes = Vec::<naga::back::msl::AttributeMapping>::new();
for attribute in vbl.attributes.iter() {
attributes.push(naga::back::msl::AttributeMapping {
shader_location: attribute.shader_location,
offset: attribute.offset as u32,
format: convert_vertex_format_to_naga(attribute.format),
});
}
let mapping = naga::back::msl::VertexBufferMapping {
id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32,
stride: if vbl.array_stride > 0 {
vbl.array_stride.try_into().unwrap()
} else {
vbl.attributes
.iter()
.map(|attribute| attribute.offset + attribute.format.size())
.max()
.unwrap_or(0)
.try_into()
.unwrap()
},
step_mode: match (vbl.array_stride == 0, vbl.step_mode) {
(true, _) => naga::back::msl::VertexBufferStepMode::Constant,
(false, wgt::VertexStepMode::Vertex) => {
naga::back::msl::VertexBufferStepMode::ByVertex
}
(false, wgt::VertexStepMode::Instance) => {
naga::back::msl::VertexBufferStepMode::ByInstance
}
},
attributes,
};
vertex_buffer_mappings.push(mapping);
}
// Setup vertex shader
{
let vs = self.load_shader(
vertex_stage,
&vertex_buffer_mappings,
desc.layout,
primitive_class,
naga::ShaderStage::Vertex,
)?;
descriptor.set_vertex_function(Some(&vs.function));
if self.shared.private_caps.supports_mutability {
Self::set_buffers_mutability(
descriptor.vertex_buffers().unwrap(),
vs.immutable_buffer_mask,
);
}
vs_info = Some(super::PipelineStageInfo {
push_constants: desc.layout.push_constants_infos.vs,
sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer,
sized_bindings: vs.sized_bindings,
vertex_buffer_mappings,
library: Some(vs.library),
raw_wg_size: Default::default(),
work_group_memory_sizes: vec![],
});
}
// Validate vertex buffer count
if desc.layout.total_counters.vs.buffers + (vertex_buffers.len() as u32)
> self.shared.private_caps.max_vertex_buffers
{
let msg = format!(
"pipeline needs too many buffers in the vertex stage: {} vertex and {} layout",
vertex_buffers.len(),
desc.layout.total_counters.vs.buffers
);
return Err(crate::PipelineError::Linkage(
wgt::ShaderStages::VERTEX,
msg,
));
}
// Set the pipeline vertex buffer info
if !vertex_buffers.is_empty() {
let vertex_descriptor = metal::VertexDescriptor::new();
for (i, vb) in vertex_buffers.iter().enumerate() {
let buffer_index =
self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64;
let buffer_desc =
vertex_descriptor.layouts().object_at(buffer_index).unwrap();
// Metal expects the stride to be the actual size of the attributes.
// The semantics of array_stride == 0 can be achieved by setting
// the step function to constant and rate to 0.
if vb.array_stride == 0 {
let stride = vb
.attributes
.iter()
.map(|attribute| attribute.offset + attribute.format.size())
.max()
.unwrap_or(0);
buffer_desc.set_stride(wgt::math::align_to(stride, 4));
buffer_desc.set_step_function(MTLVertexStepFunction::Constant);
buffer_desc.set_step_rate(0);
} else {
buffer_desc.set_stride(vb.array_stride);
buffer_desc.set_step_function(conv::map_step_mode(vb.step_mode));
}
for at in vb.attributes {
let attribute_desc = vertex_descriptor
.attributes()
.object_at(at.shader_location as u64)
.unwrap();
attribute_desc.set_format(conv::map_vertex_format(at.format));
attribute_desc.set_buffer_index(buffer_index);
attribute_desc.set_offset(at.offset);
}
}
descriptor.set_vertex_descriptor(Some(vertex_descriptor));
}
MetalGenericRenderPipelineDescriptor::Standard(descriptor)
}
crate::VertexProcessor::Mesh {
ref task_stage,
ref mesh_stage,
} => {
// Mesh pipeline specific setup
vs_info = None;
let descriptor = metal::MeshRenderPipelineDescriptor::new();
// Setup task stage
if let Some(ref task_stage) = task_stage {
let ts = self.load_shader(
task_stage,
&[],
desc.layout,
primitive_class,
naga::ShaderStage::Task,
)?;
descriptor.set_object_function(Some(&ts.function));
if self.shared.private_caps.supports_mutability {
Self::set_buffers_mutability(
descriptor.mesh_buffers().unwrap(),
ts.immutable_buffer_mask,
);
}
ts_info = Some(super::PipelineStageInfo {
push_constants: desc.layout.push_constants_infos.ts,
sizes_slot: desc.layout.per_stage_map.ts.sizes_buffer,
sized_bindings: ts.sized_bindings,
vertex_buffer_mappings: vec![],
library: Some(ts.library),
raw_wg_size: ts.wg_size,
work_group_memory_sizes: ts.wg_memory_sizes,
});
} else {
ts_info = None;
}
// Setup mesh stage
{
let ms = self.load_shader(
mesh_stage,
&[],
desc.layout,
primitive_class,
naga::ShaderStage::Mesh,
)?;
descriptor.set_mesh_function(Some(&ms.function));
if self.shared.private_caps.supports_mutability {
Self::set_buffers_mutability(
descriptor.mesh_buffers().unwrap(),
ms.immutable_buffer_mask,
);
}
ms_info = Some(super::PipelineStageInfo {
push_constants: desc.layout.push_constants_infos.ms,
sizes_slot: desc.layout.per_stage_map.ms.sizes_buffer,
sized_bindings: ms.sized_bindings,
vertex_buffer_mappings: vec![],
library: Some(ms.library),
raw_wg_size: ms.wg_size,
work_group_memory_sizes: ms.wg_memory_sizes,
});
}
MetalGenericRenderPipelineDescriptor::Mesh(descriptor)
}
};
let raw_triangle_fill_mode = match desc.primitive.polygon_mode {
wgt::PolygonMode::Fill => MTLTriangleFillMode::Fill,
@ -1133,76 +1378,8 @@ impl crate::Device for super::Device {
),
};
let (primitive_class, raw_primitive_type) =
conv::map_primitive_topology(desc.primitive.topology);
// Vertex shader
let (vs_lib, vs_info) = {
let mut vertex_buffer_mappings = Vec::<naga::back::msl::VertexBufferMapping>::new();
for (i, vbl) in desc_vertex_buffers.iter().enumerate() {
let mut attributes = Vec::<naga::back::msl::AttributeMapping>::new();
for attribute in vbl.attributes.iter() {
attributes.push(naga::back::msl::AttributeMapping {
shader_location: attribute.shader_location,
offset: attribute.offset as u32,
format: convert_vertex_format_to_naga(attribute.format),
});
}
vertex_buffer_mappings.push(naga::back::msl::VertexBufferMapping {
id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32,
stride: if vbl.array_stride > 0 {
vbl.array_stride.try_into().unwrap()
} else {
vbl.attributes
.iter()
.map(|attribute| attribute.offset + attribute.format.size())
.max()
.unwrap_or(0)
.try_into()
.unwrap()
},
step_mode: match (vbl.array_stride == 0, vbl.step_mode) {
(true, _) => naga::back::msl::VertexBufferStepMode::Constant,
(false, wgt::VertexStepMode::Vertex) => {
naga::back::msl::VertexBufferStepMode::ByVertex
}
(false, wgt::VertexStepMode::Instance) => {
naga::back::msl::VertexBufferStepMode::ByInstance
}
},
attributes,
});
}
let vs = self.load_shader(
desc_vertex_stage,
&vertex_buffer_mappings,
desc.layout,
primitive_class,
naga::ShaderStage::Vertex,
)?;
descriptor.set_vertex_function(Some(&vs.function));
if self.shared.private_caps.supports_mutability {
Self::set_buffers_mutability(
descriptor.vertex_buffers().unwrap(),
vs.immutable_buffer_mask,
);
}
let info = super::PipelineStageInfo {
push_constants: desc.layout.push_constants_infos.vs,
sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer,
sized_bindings: vs.sized_bindings,
vertex_buffer_mappings,
};
(vs.library, info)
};
// Fragment shader
let (fs_lib, fs_info) = match desc.fragment_stage {
let fs_info = match desc.fragment_stage {
Some(ref stage) => {
let fs = self.load_shader(
stage,
@ -1220,14 +1397,15 @@ impl crate::Device for super::Device {
);
}
let info = super::PipelineStageInfo {
Some(super::PipelineStageInfo {
push_constants: desc.layout.push_constants_infos.fs,
sizes_slot: desc.layout.per_stage_map.fs.sizes_buffer,
sized_bindings: fs.sized_bindings,
vertex_buffer_mappings: vec![],
};
(Some(fs.library), Some(info))
library: Some(fs.library),
raw_wg_size: Default::default(),
work_group_memory_sizes: vec![],
})
}
None => {
// TODO: This is a workaround for what appears to be a Metal validation bug
@ -1235,10 +1413,11 @@ impl crate::Device for super::Device {
if desc.color_targets.is_empty() && desc.depth_stencil.is_none() {
descriptor.set_depth_attachment_pixel_format(MTLPixelFormat::Depth32Float);
}
(None, None)
None
}
};
// Setup pipeline color attachments
for (i, ct) in desc.color_targets.iter().enumerate() {
let at_descriptor = descriptor.color_attachments().object_at(i as u64).unwrap();
let ct = if let Some(color_target) = ct.as_ref() {
@ -1267,6 +1446,7 @@ impl crate::Device for super::Device {
}
}
// Setup depth stencil state
let depth_stencil = match desc.depth_stencil {
Some(ref ds) => {
let raw_format = self.shared.private_caps.map_format(ds.format);
@ -1289,94 +1469,54 @@ impl crate::Device for super::Device {
None => None,
};
if desc.layout.total_counters.vs.buffers + (desc_vertex_buffers.len() as u32)
> self.shared.private_caps.max_vertex_buffers
{
let msg = format!(
"pipeline needs too many buffers in the vertex stage: {} vertex and {} layout",
desc_vertex_buffers.len(),
desc.layout.total_counters.vs.buffers
);
return Err(crate::PipelineError::Linkage(
wgt::ShaderStages::VERTEX,
msg,
));
}
if !desc_vertex_buffers.is_empty() {
let vertex_descriptor = metal::VertexDescriptor::new();
for (i, vb) in desc_vertex_buffers.iter().enumerate() {
let buffer_index =
self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64;
let buffer_desc = vertex_descriptor.layouts().object_at(buffer_index).unwrap();
// Metal expects the stride to be the actual size of the attributes.
// The semantics of array_stride == 0 can be achieved by setting
// the step function to constant and rate to 0.
if vb.array_stride == 0 {
let stride = vb
.attributes
.iter()
.map(|attribute| attribute.offset + attribute.format.size())
.max()
.unwrap_or(0);
buffer_desc.set_stride(wgt::math::align_to(stride, 4));
buffer_desc.set_step_function(MTLVertexStepFunction::Constant);
buffer_desc.set_step_rate(0);
} else {
buffer_desc.set_stride(vb.array_stride);
buffer_desc.set_step_function(conv::map_step_mode(vb.step_mode));
}
for at in vb.attributes {
let attribute_desc = vertex_descriptor
.attributes()
.object_at(at.shader_location as u64)
.unwrap();
attribute_desc.set_format(conv::map_vertex_format(at.format));
attribute_desc.set_buffer_index(buffer_index);
attribute_desc.set_offset(at.offset);
}
}
descriptor.set_vertex_descriptor(Some(vertex_descriptor));
}
// Setup multisample state
if desc.multisample.count != 1 {
//TODO: handle sample mask
descriptor.set_sample_count(desc.multisample.count as u64);
match descriptor {
MetalGenericRenderPipelineDescriptor::Standard(ref inner) => {
inner.set_sample_count(desc.multisample.count as u64);
}
MetalGenericRenderPipelineDescriptor::Mesh(ref inner) => {
inner.set_raster_sample_count(desc.multisample.count as u64);
}
}
descriptor
.set_alpha_to_coverage_enabled(desc.multisample.alpha_to_coverage_enabled);
//descriptor.set_alpha_to_one_enabled(desc.multisample.alpha_to_one_enabled);
}
// Set debug label
if let Some(name) = desc.label {
descriptor.set_label(name);
}
if let Some(mv) = desc.multiview_mask {
descriptor.set_max_vertex_amplification_count(mv.get().count_ones() as u64);
}
let raw = self
.shared
.device
.lock()
.new_render_pipeline_state(&descriptor)
.map_err(|e| {
crate::PipelineError::Linkage(
wgt::ShaderStages::VERTEX | wgt::ShaderStages::FRAGMENT,
format!("new_render_pipeline_state: {e:?}"),
)
})?;
// Create the pipeline from descriptor
let raw = match descriptor {
MetalGenericRenderPipelineDescriptor::Standard(d) => {
self.shared.device.lock().new_render_pipeline_state(&d)
}
MetalGenericRenderPipelineDescriptor::Mesh(d) => {
self.shared.device.lock().new_mesh_render_pipeline_state(&d)
}
}
.map_err(|e| {
crate::PipelineError::Linkage(
wgt::ShaderStages::VERTEX | wgt::ShaderStages::FRAGMENT,
format!("new_render_pipeline_state: {e:?}"),
)
})?;
self.counters.render_pipelines.add(1);
Ok(super::RenderPipeline {
raw,
vs_lib,
fs_lib,
vs_info,
fs_info,
ts_info,
ms_info,
raw_primitive_type,
raw_triangle_fill_mode,
raw_front_winding: conv::map_winding(desc.primitive.front_face),
@ -1444,10 +1584,13 @@ impl crate::Device for super::Device {
}
let cs_info = super::PipelineStageInfo {
library: Some(cs.library),
push_constants: desc.layout.push_constants_infos.cs,
sizes_slot: desc.layout.per_stage_map.cs.sizes_buffer,
sized_bindings: cs.sized_bindings,
vertex_buffer_mappings: vec![],
raw_wg_size: cs.wg_size,
work_group_memory_sizes: cs.wg_memory_sizes,
};
if let Some(name) = desc.label {
@ -1468,13 +1611,7 @@ impl crate::Device for super::Device {
self.counters.compute_pipelines.add(1);
Ok(super::ComputePipeline {
raw,
cs_info,
cs_lib: cs.library,
work_group_size: cs.wg_size,
work_group_memory_sizes: cs.wg_memory_sizes,
})
Ok(super::ComputePipeline { raw, cs_info })
})
}

View File

@ -302,6 +302,7 @@ struct PrivateCapabilities {
int64_atomics: bool,
float_atomics: bool,
supports_shared_event: bool,
mesh_shaders: bool,
supported_vertex_amplification_factor: u32,
shader_barycentrics: bool,
supports_memoryless_storage: bool,
@ -609,12 +610,16 @@ struct MultiStageData<T> {
vs: T,
fs: T,
cs: T,
ts: T,
ms: T,
}
const NAGA_STAGES: MultiStageData<naga::ShaderStage> = MultiStageData {
vs: naga::ShaderStage::Vertex,
fs: naga::ShaderStage::Fragment,
cs: naga::ShaderStage::Compute,
ts: naga::ShaderStage::Task,
ms: naga::ShaderStage::Mesh,
};
impl<T> ops::Index<naga::ShaderStage> for MultiStageData<T> {
@ -624,7 +629,8 @@ impl<T> ops::Index<naga::ShaderStage> for MultiStageData<T> {
naga::ShaderStage::Vertex => &self.vs,
naga::ShaderStage::Fragment => &self.fs,
naga::ShaderStage::Compute => &self.cs,
naga::ShaderStage::Task | naga::ShaderStage::Mesh => unreachable!(),
naga::ShaderStage::Task => &self.ts,
naga::ShaderStage::Mesh => &self.ms,
}
}
}
@ -635,6 +641,8 @@ impl<T> MultiStageData<T> {
vs: fun(&self.vs),
fs: fun(&self.fs),
cs: fun(&self.cs),
ts: fun(&self.ts),
ms: fun(&self.ms),
}
}
fn map<Y>(self, fun: impl Fn(T) -> Y) -> MultiStageData<Y> {
@ -642,17 +650,23 @@ impl<T> MultiStageData<T> {
vs: fun(self.vs),
fs: fun(self.fs),
cs: fun(self.cs),
ts: fun(self.ts),
ms: fun(self.ms),
}
}
fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T> {
iter::once(&self.vs)
.chain(iter::once(&self.fs))
.chain(iter::once(&self.cs))
.chain(iter::once(&self.ts))
.chain(iter::once(&self.ms))
}
fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut T> {
iter::once(&mut self.vs)
.chain(iter::once(&mut self.fs))
.chain(iter::once(&mut self.cs))
.chain(iter::once(&mut self.ts))
.chain(iter::once(&mut self.ms))
}
}
@ -816,6 +830,8 @@ impl crate::DynShaderModule for ShaderModule {}
#[derive(Debug, Default)]
struct PipelineStageInfo {
#[allow(dead_code)]
library: Option<metal::Library>,
push_constants: Option<PushConstantsInfo>,
/// The buffer argument table index at which we pass runtime-sized arrays' buffer sizes.
@ -830,6 +846,12 @@ struct PipelineStageInfo {
/// Info on all bound vertex buffers.
vertex_buffer_mappings: Vec<naga::back::msl::VertexBufferMapping>,
/// The workgroup size for compute, task or mesh stages
raw_wg_size: MTLSize,
/// The workgroup memory sizes for compute task or mesh stages
work_group_memory_sizes: Vec<u32>,
}
impl PipelineStageInfo {
@ -838,6 +860,9 @@ impl PipelineStageInfo {
self.sizes_slot = None;
self.sized_bindings.clear();
self.vertex_buffer_mappings.clear();
self.library = None;
self.work_group_memory_sizes.clear();
self.raw_wg_size = Default::default();
}
fn assign_from(&mut self, other: &Self) {
@ -848,18 +873,21 @@ impl PipelineStageInfo {
self.vertex_buffer_mappings.clear();
self.vertex_buffer_mappings
.extend_from_slice(&other.vertex_buffer_mappings);
self.library = Some(other.library.as_ref().unwrap().clone());
self.raw_wg_size = other.raw_wg_size;
self.work_group_memory_sizes.clear();
self.work_group_memory_sizes
.extend_from_slice(&other.work_group_memory_sizes);
}
}
#[derive(Debug)]
pub struct RenderPipeline {
raw: metal::RenderPipelineState,
#[allow(dead_code)]
vs_lib: metal::Library,
#[allow(dead_code)]
fs_lib: Option<metal::Library>,
vs_info: PipelineStageInfo,
vs_info: Option<PipelineStageInfo>,
fs_info: Option<PipelineStageInfo>,
ts_info: Option<PipelineStageInfo>,
ms_info: Option<PipelineStageInfo>,
raw_primitive_type: MTLPrimitiveType,
raw_triangle_fill_mode: MTLTriangleFillMode,
raw_front_winding: MTLWinding,
@ -876,11 +904,7 @@ impl crate::DynRenderPipeline for RenderPipeline {}
#[derive(Debug)]
pub struct ComputePipeline {
raw: metal::ComputePipelineState,
#[allow(dead_code)]
cs_lib: metal::Library,
cs_info: PipelineStageInfo,
work_group_size: MTLSize,
work_group_memory_sizes: Vec<u32>,
}
unsafe impl Send for ComputePipeline {}
@ -954,7 +978,6 @@ struct CommandState {
compute: Option<metal::ComputeCommandEncoder>,
raw_primitive_type: MTLPrimitiveType,
index: Option<IndexState>,
raw_wg_size: MTLSize,
stage_infos: MultiStageData<PipelineStageInfo>,
/// Sizes of currently bound [`wgt::BufferBindingType::Storage`] buffers.
@ -980,7 +1003,6 @@ struct CommandState {
vertex_buffer_size_map: FastHashMap<u64, wgt::BufferSize>,
work_group_memory_sizes: Vec<u32>,
push_constants: Vec<u32>,
/// Timer query that should be executed when the next pass starts.

View File

@ -1169,12 +1169,11 @@ bitflags_array! {
/// This is a native only feature.
const UNIFORM_BUFFER_BINDING_ARRAYS = 1 << 47;
/// Enables mesh shaders and task shaders in mesh shader pipelines.
/// Enables mesh shaders and task shaders in mesh shader pipelines. This extension does NOT imply support for
/// compiling mesh shaders at runtime. Rather, the user must use custom passthrough shaders.
///
/// Supported platforms:
/// - Vulkan (with [VK_EXT_mesh_shader](https://registry.khronos.org/vulkan/specs/latest/man/html/VK_EXT_mesh_shader.html))
///
/// Potential Platforms:
/// - DX12
/// - Metal
///

View File

@ -1062,14 +1062,12 @@ impl Limits {
#[must_use]
pub const fn using_recommended_minimum_mesh_shader_values(self) -> Self {
Self {
// Literally just made this up as 256^2 or 2^16.
// My GPU supports 2^22, and compute shaders don't have this kind of limit.
// This very likely is never a real limiter
max_task_workgroup_total_count: 65536,
max_task_workgroups_per_dimension: 256,
// This is a common limit for apple devices. It's not immediately clear why.
max_task_workgroup_total_count: 1024,
max_task_workgroups_per_dimension: 1024,
// llvmpipe reports 0 multiview count, which just means no multiview is allowed
max_mesh_multiview_view_count: 0,
// llvmpipe once again requires this to be 8. An RTX 3060 supports well over 1024.
// llvmpipe once again requires this to be <=8. An RTX 3060 supports well over 1024.
max_mesh_output_layers: 8,
..self
}