[hal/dx12] Mesh Shaders (#8110)

* Features and draw commands added

* Tried to implement the pipeline creation (completely untested)

* Fixed clippy issues

* Fixed something I think

* A little bit of work on the mesh shader example (currently doesn't work on dx12)

* Reached a new kind of error state

* Fixed an alignment issue

* DirectX 12 mesh shaders working :party:

* Removed stupid change and updated changelog

* Fixed typo

* Added backends option to example framework

* Removed silly no write fragment shader from tests to see if anything breaks

* Tried to make mesh shader tests run elsewhere too

* Removed printlns and checked that dx12 mesh shader tests run

* Documented very strange issue

* I'm so lost

* Fixed stupid typos

* Fixed all issues

* Removed unnecessary example stuff, updated tests

* Updated typos.toml

* Updated limits

* Apply suggestion from @cwfitzgerald

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>

* Apply suggestion from @cwfitzgerald

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>

* Removed supported backends, made example & tests always pass the filename to shader compilers

* Removed excessive bools in test params

* Added new tests to the list

* I'm a sinner for this one (unused import)

* Replaced random stuff with test params hashing

* Updated typos.toml

* Updated -Fo typo thing

* Actually fixed typo issue this time

* Update CHANGELOG.md

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>

* Update tests/tests/wgpu-gpu/mesh_shader/mod.rs

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update wgpu-hal/src/dx12/mod.rs

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>

* Addressed comments

* Lmao

---------

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Magnus 2025-09-24 22:24:56 -05:00 committed by GitHub
parent 1f10d0ce8a
commit 05cc6dca82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 739 additions and 225 deletions

View File

@ -202,6 +202,7 @@ By @cwfitzgerald in [#8162](https://github.com/gfx-rs/wgpu/pull/8162).
#### DX12
- Allow disabling waiting for latency waitable object. By @marcpabst in [#7400](https://github.com/gfx-rs/wgpu/pull/7400)
- Add mesh shader support, including to the example. By @SupaMaggie70Incorporated in [#8110](https://github.com/gfx-rs/wgpu/issues/8110)
### Bug Fixes

View File

@ -268,9 +268,9 @@ impl ExampleContext {
async fn init_async<E: Example>(surface: &mut SurfaceWrapper, window: Arc<Window>) -> Self {
log::info!("Initializing wgpu...");
let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::from_env_or_default());
let instance_descriptor = wgpu::InstanceDescriptor::from_env_or_default();
let instance = wgpu::Instance::new(&instance_descriptor);
surface.pre_adapter(&instance, window);
let adapter = get_adapter_with_capabilities_or_from_env(
&instance,
&E::required_features(),

View File

@ -1,15 +1,13 @@
use std::{io::Write, process::Stdio};
use std::process::Stdio;
// Same as in mesh shader tests
fn compile_glsl(
device: &wgpu::Device,
data: &[u8],
shader_stage: &'static str,
) -> wgpu::ShaderModule {
fn compile_glsl(device: &wgpu::Device, shader_stage: &'static str) -> wgpu::ShaderModule {
let cmd = std::process::Command::new("glslc")
.args([
&format!("-fshader-stage={shader_stage}"),
"-",
&format!(
"{}/src/mesh_shader/shader.{shader_stage}",
env!("CARGO_MANIFEST_DIR")
),
"-o",
"-",
"--target-env=vulkan1.2",
@ -19,8 +17,6 @@ fn compile_glsl(
.stdout(Stdio::piped())
.spawn()
.expect("Failed to call glslc");
cmd.stdin.as_ref().unwrap().write_all(data).unwrap();
println!("{shader_stage}");
let output = cmd.wait_with_output().expect("Error waiting for glslc");
assert!(output.status.success());
unsafe {
@ -32,6 +28,38 @@ fn compile_glsl(
})
}
}
fn compile_hlsl(device: &wgpu::Device, entry: &str, stage_str: &str) -> wgpu::ShaderModule {
let out_path = format!(
"{}/src/mesh_shader/shader.{stage_str}.cso",
env!("CARGO_MANIFEST_DIR")
);
let cmd = std::process::Command::new("dxc")
.args([
"-T",
&format!("{stage_str}_6_5"),
"-E",
entry,
&format!("{}/src/mesh_shader/shader.hlsl", env!("CARGO_MANIFEST_DIR")),
"-Fo",
&out_path,
])
.output()
.unwrap();
if !cmd.status.success() {
panic!("DXC failed:\n{}", String::from_utf8(cmd.stderr).unwrap());
}
let file = std::fs::read(&out_path).unwrap();
std::fs::remove_file(out_path).unwrap();
unsafe {
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
entry_point: entry.to_owned(),
label: None,
num_workgroups: (1, 1, 1),
dxil: Some(std::borrow::Cow::Owned(file)),
..Default::default()
})
}
}
pub struct Example {
pipeline: wgpu::RenderPipeline,
@ -39,20 +67,30 @@ pub struct Example {
impl crate::framework::Example for Example {
fn init(
config: &wgpu::SurfaceConfiguration,
_adapter: &wgpu::Adapter,
adapter: &wgpu::Adapter,
device: &wgpu::Device,
_queue: &wgpu::Queue,
) -> Self {
let (ts, ms, fs) = if adapter.get_info().backend == wgpu::Backend::Vulkan {
(
compile_glsl(device, "task"),
compile_glsl(device, "mesh"),
compile_glsl(device, "frag"),
)
} else if adapter.get_info().backend == wgpu::Backend::Dx12 {
(
compile_hlsl(device, "Task", "as"),
compile_hlsl(device, "Mesh", "ms"),
compile_hlsl(device, "Frag", "ps"),
)
} else {
panic!("Example can only run on vulkan or dx12");
};
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[],
push_constant_ranges: &[],
});
let (ts, ms, fs) = (
compile_glsl(device, include_bytes!("shader.task"), "task"),
compile_glsl(device, include_bytes!("shader.mesh"), "mesh"),
compile_glsl(device, include_bytes!("shader.frag"), "frag"),
);
let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor {
label: None,
layout: Some(&pipeline_layout),

View File

@ -0,0 +1,53 @@
struct OutVertex {
float4 Position : SV_POSITION;
float4 Color: COLOR;
};
struct OutPrimitive {
float4 ColorMask : COLOR_MASK : PRIMITIVE;
bool CullPrimitive: SV_CullPrimitive;
};
struct InVertex {
float4 Color: COLOR;
};
struct InPrimitive {
float4 ColorMask : COLOR_MASK : PRIMITIVE;
};
struct PayloadData {
float4 ColorMask;
bool Visible;
};
static const float4 positions[3] = {float4(0., 1.0, 0., 1.0), float4(-1.0, -1.0, 0., 1.0), float4(1.0, -1.0, 0., 1.0)};
static const float4 colors[3] = {float4(0., 1., 0., 1.), float4(0., 0., 1., 1.), float4(1., 0., 0., 1.)};
groupshared PayloadData outPayload;
[numthreads(1, 1, 1)]
void Task() {
outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0);
outPayload.Visible = true;
DispatchMesh(3, 1, 1, outPayload);
}
[outputtopology("triangle")]
[numthreads(1, 1, 1)]
void Mesh(out indices uint3 triangles[1], out vertices OutVertex vertices[3], out primitives OutPrimitive primitives[1], in payload PayloadData payload) {
SetMeshOutputCounts(3, 1);
vertices[0].Position = positions[0];
vertices[1].Position = positions[1];
vertices[2].Position = positions[2];
vertices[0].Color = colors[0] * payload.ColorMask;
vertices[1].Color = colors[1] * payload.ColorMask;
vertices[2].Color = colors[2] * payload.ColorMask;
triangles[0] = uint3(0, 1, 2);
primitives[0].ColorMask = float4(1.0, 0.0, 0.0, 1.0);
primitives[0].CullPrimitive = !payload.Visible;
}
float4 Frag(InVertex vertex, InPrimitive primitive) : SV_Target {
return vertex.Color * primitive.ColorMask;
}

View File

@ -20,18 +20,18 @@ vertexOutput[];
layout(location = 1) perprimitiveEXT out PrimitiveOutput { vec4 colorMask; }
primitiveOutput[];
shared uint sharedData;
layout(triangles, max_vertices = 3, max_primitives = 1) out;
void main() {
sharedData = 5;
SetMeshOutputsEXT(3, 1);
gl_MeshVerticesEXT[0].gl_Position = positions[0];
gl_MeshVerticesEXT[1].gl_Position = positions[1];
gl_MeshVerticesEXT[2].gl_Position = positions[2];
vertexOutput[0].color = colors[0] * payloadData.colorMask;
vertexOutput[1].color = colors[1] * payloadData.colorMask;
vertexOutput[2].color = colors[2] * payloadData.colorMask;
gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uvec3(0, 1, 2);
primitiveOutput[0].colorMask = vec4(1.0, 0.0, 1.0, 1.0);
gl_MeshPrimitivesEXT[0].gl_CullPrimitiveEXT = !payloadData.visible;

View File

@ -1,7 +1,7 @@
#version 450
#extension GL_EXT_mesh_shader : require
layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
struct TaskPayload {
vec4 colorMask;

View File

@ -283,7 +283,8 @@ impl crate::ShaderStage {
Self::Vertex => "vs",
Self::Fragment => "ps",
Self::Compute => "cs",
Self::Task | Self::Mesh => unreachable!(),
Self::Task => "as",
Self::Mesh => "ms",
}
}
}

View File

@ -12,6 +12,7 @@ use crate::{
GpuTestConfiguration,
};
#[derive(Hash)]
/// Parameters and resources handed to the test function.
pub struct TestingContext {
pub instance: Instance,

View File

@ -0,0 +1,41 @@
struct OutVertex {
float4 Position : SV_POSITION;
float4 Color: COLOR;
};
struct InVertex {
float4 Color: COLOR;
};
static const float4 positions[3] = {float4(0., 1.0, 0., 1.0), float4(-1.0, -1.0, 0., 1.0), float4(1.0, -1.0, 0., 1.0)};
static const float4 colors[3] = {float4(0., 1., 0., 1.), float4(0., 0., 1., 1.), float4(1., 0., 0., 1.)};
struct EmptyPayload {
uint _nullField;
};
groupshared EmptyPayload _emptyPayload;
[numthreads(4, 1, 1)]
void Task() {
DispatchMesh(1, 1, 1, _emptyPayload);
}
[outputtopology("triangle")]
[numthreads(1, 1, 1)]
void Mesh(out indices uint3 triangles[1], out vertices OutVertex vertices[3], in payload EmptyPayload _emptyPayload) {
SetMeshOutputCounts(3, 1);
vertices[0].Position = positions[0];
vertices[1].Position = positions[1];
vertices[2].Position = positions[2];
vertices[0].Color = colors[0];
vertices[1].Color = colors[1];
vertices[2].Color = colors[2];
triangles[0] = uint3(0, 1, 2);
}
float4 Frag(InVertex vertex) : SV_Target {
return vertex.Color;
}

View File

@ -1,4 +1,7 @@
use std::{io::Write, process::Stdio};
use std::{
hash::{DefaultHasher, Hash, Hasher},
process::Stdio,
};
use wgpu::util::DeviceExt;
use wgpu_test::{
@ -14,19 +17,19 @@ pub fn all_tests(tests: &mut Vec<GpuTestInitializer>) {
MESH_DRAW_INDIRECT,
MESH_MULTI_DRAW_INDIRECT,
MESH_MULTI_DRAW_INDIRECT_COUNT,
MESH_PIPELINE_BASIC_MESH_NO_DRAW,
MESH_PIPELINE_BASIC_TASK_MESH_FRAG_NO_DRAW,
]);
}
// Same as in mesh shader example
fn compile_glsl(
device: &wgpu::Device,
data: &[u8],
shader_stage: &'static str,
) -> wgpu::ShaderModule {
fn compile_glsl(device: &wgpu::Device, shader_stage: &'static str) -> wgpu::ShaderModule {
let cmd = std::process::Command::new("glslc")
.args([
&format!("-fshader-stage={shader_stage}"),
"-",
&format!(
"{}/tests/wgpu-gpu/mesh_shader/basic.{shader_stage}",
env!("CARGO_MANIFEST_DIR")
),
"-o",
"-",
"--target-env=vulkan1.2",
@ -36,8 +39,6 @@ fn compile_glsl(
.stdout(Stdio::piped())
.spawn()
.expect("Failed to call glslc");
cmd.stdin.as_ref().unwrap().write_all(data).unwrap();
println!("{shader_stage}");
let output = cmd.wait_with_output().expect("Error waiting for glslc");
assert!(output.status.success());
unsafe {
@ -50,6 +51,70 @@ fn compile_glsl(
}
}
fn compile_hlsl(
device: &wgpu::Device,
entry: &str,
stage_str: &str,
test_name: &str,
) -> wgpu::ShaderModule {
// Each test needs its own files
let out_path = format!(
"{}/tests/wgpu-gpu/mesh_shader/{test_name}.{stage_str}.cso",
env!("CARGO_MANIFEST_DIR")
);
let cmd = std::process::Command::new("dxc")
.args([
"-T",
&format!("{stage_str}_6_5"),
"-E",
entry,
&format!(
"{}/tests/wgpu-gpu/mesh_shader/basic.hlsl",
env!("CARGO_MANIFEST_DIR")
),
"-Fo",
&out_path,
])
.output()
.unwrap();
if !cmd.status.success() {
panic!("DXC failed:\n{}", String::from_utf8(cmd.stderr).unwrap());
}
let file = std::fs::read(&out_path).unwrap();
std::fs::remove_file(out_path).unwrap();
unsafe {
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
entry_point: entry.to_owned(),
label: None,
num_workgroups: (1, 1, 1),
dxil: Some(std::borrow::Cow::Owned(file)),
..Default::default()
})
}
}
fn get_shaders(
device: &wgpu::Device,
backend: wgpu::Backend,
test_name: &str,
) -> (wgpu::ShaderModule, wgpu::ShaderModule, wgpu::ShaderModule) {
if backend == wgpu::Backend::Vulkan {
(
compile_glsl(device, "task"),
compile_glsl(device, "mesh"),
compile_glsl(device, "frag"),
)
} else if backend == wgpu::Backend::Dx12 {
(
compile_hlsl(device, "Task", "as", test_name),
compile_hlsl(device, "Mesh", "ms", test_name),
compile_hlsl(device, "Frag", "ps", test_name),
)
} else {
unreachable!()
}
}
fn create_depth(
device: &wgpu::Device,
) -> (wgpu::Texture, wgpu::TextureView, wgpu::DepthStencilState) {
@ -79,18 +144,30 @@ fn create_depth(
(depth_texture, depth_view, state)
}
fn mesh_pipeline_build(
ctx: &TestingContext,
task: Option<&[u8]>,
mesh: &[u8],
frag: Option<&[u8]>,
struct MeshPipelineTestInfo {
use_task: bool,
use_frag: bool,
draw: bool,
) {
}
fn hash_testing_context(ctx: &TestingContext) -> u64 {
let mut hasher = DefaultHasher::new();
ctx.hash(&mut hasher);
hasher.finish()
}
fn mesh_pipeline_build(ctx: &TestingContext, info: MeshPipelineTestInfo) {
let backend = ctx.adapter.get_info().backend;
if backend != wgpu::Backend::Vulkan && backend != wgpu::Backend::Dx12 {
return;
}
let device = &ctx.device;
let (_depth_image, depth_view, depth_state) = create_depth(device);
let task = task.map(|t| compile_glsl(device, t, "task"));
let mesh = compile_glsl(device, mesh, "mesh");
let frag = frag.map(|f| compile_glsl(device, f, "frag"));
let test_hash = hash_testing_context(ctx).to_string();
let (task, mesh, frag) = get_shaders(device, backend, &test_hash);
let task = if info.use_task { Some(task) } else { None };
let frag = if info.use_frag { Some(frag) } else { None };
let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[],
@ -124,7 +201,7 @@ fn mesh_pipeline_build(
multiview: None,
cache: None,
});
if draw {
if info.draw {
let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
{
@ -160,11 +237,14 @@ pub enum DrawType {
}
fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) {
let backend = ctx.adapter.get_info().backend;
if backend != wgpu::Backend::Vulkan && backend != wgpu::Backend::Dx12 {
return;
}
let device = &ctx.device;
let (_depth_image, depth_view, depth_state) = create_depth(device);
let task = compile_glsl(device, BASIC_TASK, "task");
let mesh = compile_glsl(device, BASIC_MESH, "mesh");
let frag = compile_glsl(device, NO_WRITE_FRAG, "frag");
let test_hash = hash_testing_context(ctx).to_string();
let (task, mesh, frag) = get_shaders(device, backend, &test_hash);
let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[],
@ -256,11 +336,6 @@ fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) {
ctx.device.poll(wgpu::PollType::Wait).unwrap();
}
const BASIC_TASK: &[u8] = include_bytes!("basic.task");
const BASIC_MESH: &[u8] = include_bytes!("basic.mesh");
//const BASIC_FRAG: &[u8] = include_bytes!("basic.frag.spv");
const NO_WRITE_FRAG: &[u8] = include_bytes!("no-write.frag");
fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration {
GpuTestConfiguration::new().parameters(
TestParameters::default()
@ -279,47 +354,92 @@ fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration {
)
}
// Mesh pipeline configs
#[gpu_test]
static MESH_PIPELINE_BASIC_MESH: GpuTestConfiguration = default_gpu_test_config(DrawType::Standard)
.run_sync(|ctx| {
mesh_pipeline_build(&ctx, None, BASIC_MESH, None, true);
});
#[gpu_test]
static MESH_PIPELINE_BASIC_TASK_MESH: GpuTestConfiguration =
default_gpu_test_config(DrawType::Standard).run_sync(|ctx| {
mesh_pipeline_build(&ctx, Some(BASIC_TASK), BASIC_MESH, None, true);
});
#[gpu_test]
static MESH_PIPELINE_BASIC_MESH_FRAG: GpuTestConfiguration =
default_gpu_test_config(DrawType::Standard).run_sync(|ctx| {
mesh_pipeline_build(&ctx, None, BASIC_MESH, Some(NO_WRITE_FRAG), true);
});
#[gpu_test]
static MESH_PIPELINE_BASIC_TASK_MESH_FRAG: GpuTestConfiguration =
pub static MESH_PIPELINE_BASIC_MESH: GpuTestConfiguration =
default_gpu_test_config(DrawType::Standard).run_sync(|ctx| {
mesh_pipeline_build(
&ctx,
Some(BASIC_TASK),
BASIC_MESH,
Some(NO_WRITE_FRAG),
true,
MeshPipelineTestInfo {
use_task: false,
use_frag: false,
draw: true,
},
);
});
#[gpu_test]
pub static MESH_PIPELINE_BASIC_TASK_MESH: GpuTestConfiguration =
default_gpu_test_config(DrawType::Standard).run_sync(|ctx| {
mesh_pipeline_build(
&ctx,
MeshPipelineTestInfo {
use_task: true,
use_frag: false,
draw: true,
},
);
});
#[gpu_test]
pub static MESH_PIPELINE_BASIC_MESH_FRAG: GpuTestConfiguration =
default_gpu_test_config(DrawType::Standard).run_sync(|ctx| {
mesh_pipeline_build(
&ctx,
MeshPipelineTestInfo {
use_task: false,
use_frag: true,
draw: true,
},
);
});
#[gpu_test]
pub static MESH_PIPELINE_BASIC_TASK_MESH_FRAG: GpuTestConfiguration =
default_gpu_test_config(DrawType::Standard).run_sync(|ctx| {
mesh_pipeline_build(
&ctx,
MeshPipelineTestInfo {
use_task: true,
use_frag: true,
draw: true,
},
);
});
#[gpu_test]
pub static MESH_PIPELINE_BASIC_MESH_NO_DRAW: GpuTestConfiguration =
default_gpu_test_config(DrawType::Standard).run_sync(|ctx| {
mesh_pipeline_build(
&ctx,
MeshPipelineTestInfo {
use_task: false,
use_frag: false,
draw: false,
},
);
});
#[gpu_test]
pub static MESH_PIPELINE_BASIC_TASK_MESH_FRAG_NO_DRAW: GpuTestConfiguration =
default_gpu_test_config(DrawType::Standard).run_sync(|ctx| {
mesh_pipeline_build(
&ctx,
MeshPipelineTestInfo {
use_task: true,
use_frag: true,
draw: false,
},
);
});
// Mesh draw
#[gpu_test]
static MESH_DRAW_INDIRECT: GpuTestConfiguration = default_gpu_test_config(DrawType::Indirect)
pub static MESH_DRAW_INDIRECT: GpuTestConfiguration = default_gpu_test_config(DrawType::Indirect)
.run_sync(|ctx| {
mesh_draw(&ctx, DrawType::Indirect);
});
#[gpu_test]
static MESH_MULTI_DRAW_INDIRECT: GpuTestConfiguration =
pub static MESH_MULTI_DRAW_INDIRECT: GpuTestConfiguration =
default_gpu_test_config(DrawType::MultiIndirect).run_sync(|ctx| {
mesh_draw(&ctx, DrawType::MultiIndirect);
});
#[gpu_test]
static MESH_MULTI_DRAW_INDIRECT_COUNT: GpuTestConfiguration =
pub static MESH_MULTI_DRAW_INDIRECT_COUNT: GpuTestConfiguration =
default_gpu_test_config(DrawType::MultiIndirectCount).run_sync(|ctx| {
mesh_draw(&ctx, DrawType::MultiIndirectCount);
});

View File

@ -1,7 +0,0 @@
#version 450
#extension GL_EXT_mesh_shader : require
in VertexInput { layout(location = 0) vec4 color; }
vertexInput;
void main() {}

View File

@ -20,6 +20,9 @@ extend-exclude = [
lod = "lod"
metalness = "metalness"
# A DXC command line argument
Fo = "Fo"
# Usernames
Healthire = "Healthire"
REASY = "REASY"

View File

@ -528,6 +528,22 @@ impl super::Adapter {
wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS | wgt::Features::SHADER_INT64_ATOMIC_MIN_MAX,
atomic_int64_on_typed_resource_supported,
);
let mesh_shader_supported = {
let mut features7 = Direct3D12::D3D12_FEATURE_DATA_D3D12_OPTIONS7::default();
unsafe {
device.CheckFeatureSupport(
Direct3D12::D3D12_FEATURE_D3D12_OPTIONS7,
<*mut _>::cast(&mut features7),
size_of_val(&features7) as u32,
)
}
.is_ok()
&& features7.MeshShaderTier != Direct3D12::D3D12_MESH_SHADER_TIER_NOT_SUPPORTED
};
features.set(
wgt::Features::EXPERIMENTAL_MESH_SHADER,
mesh_shader_supported,
);
// TODO: Determine if IPresentationManager is supported
let presentation_timer = auxil::dxgi::time::PresentationTimer::new_dxgi();
@ -648,10 +664,15 @@ impl super::Adapter {
max_buffer_size: i32::MAX as u64,
max_non_sampler_bindings: 1_000_000,
max_task_workgroup_total_count: 0,
max_task_workgroups_per_dimension: 0,
// Source: https://microsoft.github.io/DirectX-Specs/d3d/MeshShader.html#dispatchmesh-api
max_task_workgroup_total_count: 2u32.pow(22),
// Technically it says "64k" but I highly doubt they want 65536 for compute and exactly 64,000 for task workgroups
max_task_workgroups_per_dimension:
Direct3D12::D3D12_CS_DISPATCH_MAX_THREAD_GROUPS_PER_DIMENSION,
// Multiview not supported by WGPU yet
max_mesh_multiview_count: 0,
max_mesh_output_layers: 0,
// This seems to be right, and I can't find anything to suggest it would be less than the 2048 provided here
max_mesh_output_layers: Direct3D12::D3D12_REQ_TEXTURE2D_ARRAY_AXIS_DIMENSION,
max_blas_primitive_count: if supports_ray_tracing {
1 << 29 // 2^29

View File

@ -1228,11 +1228,16 @@ impl crate::CommandEncoder for super::CommandEncoder {
}
unsafe fn draw_mesh_tasks(
&mut self,
_group_count_x: u32,
_group_count_y: u32,
_group_count_z: u32,
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
) {
unreachable!()
self.prepare_dispatch([group_count_x, group_count_y, group_count_z]);
let cmd_list6: Direct3D12::ID3D12GraphicsCommandList6 =
self.list.as_ref().unwrap().cast().unwrap();
unsafe {
cmd_list6.DispatchMesh(group_count_x, group_count_y, group_count_z);
}
}
unsafe fn draw_indirect(
&mut self,
@ -1314,11 +1319,36 @@ impl crate::CommandEncoder for super::CommandEncoder {
}
unsafe fn draw_mesh_tasks_indirect(
&mut self,
_buffer: &<Self::A as crate::Api>::Buffer,
_offset: wgt::BufferAddress,
_draw_count: u32,
buffer: &<Self::A as crate::Api>::Buffer,
offset: wgt::BufferAddress,
draw_count: u32,
) {
unreachable!()
if self
.pass
.layout
.special_constants
.as_ref()
.and_then(|sc| sc.indirect_cmd_signatures.as_ref())
.is_some()
{
self.update_root_elements();
} else {
self.prepare_dispatch([0; 3]);
}
let cmd_list6: Direct3D12::ID3D12GraphicsCommandList6 =
self.list.as_ref().unwrap().cast().unwrap();
let cmd_signature = &self
.pass
.layout
.special_constants
.as_ref()
.and_then(|sc| sc.indirect_cmd_signatures.as_ref())
.unwrap_or_else(|| &self.shared.cmd_signatures)
.draw_mesh;
unsafe {
cmd_list6.ExecuteIndirect(cmd_signature, draw_count, &buffer.resource, offset, None, 0);
}
}
unsafe fn draw_indirect_count(
&mut self,
@ -1362,13 +1392,25 @@ impl crate::CommandEncoder for super::CommandEncoder {
}
unsafe fn draw_mesh_tasks_indirect_count(
&mut self,
_buffer: &<Self::A as crate::Api>::Buffer,
_offset: wgt::BufferAddress,
_count_buffer: &<Self::A as crate::Api>::Buffer,
_count_offset: wgt::BufferAddress,
_max_count: u32,
buffer: &<Self::A as crate::Api>::Buffer,
offset: wgt::BufferAddress,
count_buffer: &<Self::A as crate::Api>::Buffer,
count_offset: wgt::BufferAddress,
max_count: u32,
) {
unreachable!()
self.prepare_dispatch([0; 3]);
let cmd_list6: Direct3D12::ID3D12GraphicsCommandList6 =
self.list.as_ref().unwrap().cast().unwrap();
unsafe {
cmd_list6.ExecuteIndirect(
&self.shared.cmd_signatures.draw_mesh,
max_count,
&buffer.resource,
offset,
&count_buffer.resource,
count_offset,
);
}
}
// compute

View File

@ -140,6 +140,16 @@ impl super::Device {
}],
0,
)?,
draw_mesh: Self::create_command_signature(
&raw,
None,
size_of::<wgt::DispatchIndirectArgs>(),
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH_MESH,
..Default::default()
}],
0,
)?,
dispatch: Self::create_command_signature(
&raw,
None,
@ -1394,6 +1404,19 @@ impl crate::Device for super::Device {
],
0,
)?,
draw_mesh: Self::create_command_signature(
&self.raw,
Some(&raw),
special_constant_buffer_args_len + size_of::<wgt::DispatchIndirectArgs>(),
&[
constant_indirect_argument_desc,
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH_MESH,
..Default::default()
},
],
0,
)?,
dispatch: Self::create_command_signature(
&self.raw,
Some(&raw),
@ -1825,60 +1848,10 @@ impl crate::Device for super::Device {
super::PipelineCache,
>,
) -> Result<super::RenderPipeline, crate::PipelineError> {
let mut shader_stages = wgt::ShaderStages::empty();
let root_signature =
unsafe { borrow_optional_interface_temporarily(&desc.layout.shared.signature) };
let (topology_class, topology) = conv::map_topology(desc.primitive.topology);
let mut shader_stages = wgt::ShaderStages::VERTEX;
let (vertex_stage_desc, vertex_buffers_desc) = match &desc.vertex_processor {
crate::VertexProcessor::Standard {
vertex_buffers,
vertex_stage,
} => (vertex_stage, *vertex_buffers),
crate::VertexProcessor::Mesh { .. } => unreachable!(),
};
let blob_vs = self.load_shader(
vertex_stage_desc,
desc.layout,
naga::ShaderStage::Vertex,
desc.fragment_stage.as_ref(),
)?;
let blob_fs = match desc.fragment_stage {
Some(ref stage) => {
shader_stages |= wgt::ShaderStages::FRAGMENT;
Some(self.load_shader(stage, desc.layout, naga::ShaderStage::Fragment, None)?)
}
None => None,
};
let mut vertex_strides = [None; crate::MAX_VERTEX_BUFFERS];
let mut input_element_descs = Vec::new();
for (i, (stride, vbuf)) in vertex_strides
.iter_mut()
.zip(vertex_buffers_desc)
.enumerate()
{
*stride = NonZeroU32::new(vbuf.array_stride as u32);
let (slot_class, step_rate) = match vbuf.step_mode {
wgt::VertexStepMode::Vertex => {
(Direct3D12::D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA, 0)
}
wgt::VertexStepMode::Instance => {
(Direct3D12::D3D12_INPUT_CLASSIFICATION_PER_INSTANCE_DATA, 1)
}
};
for attribute in vbuf.attributes {
input_element_descs.push(Direct3D12::D3D12_INPUT_ELEMENT_DESC {
SemanticName: windows::core::PCSTR(NAGA_LOCATION_SEMANTIC.as_ptr()),
SemanticIndex: attribute.shader_location,
Format: auxil::dxgi::conv::map_vertex_format(attribute.format),
InputSlot: i as u32,
AlignedByteOffset: attribute.offset as u32,
InputSlotClass: slot_class,
InstanceDataStepRate: step_rate,
});
}
}
let mut rtv_formats = [Dxgi::Common::DXGI_FORMAT_UNKNOWN;
Direct3D12::D3D12_SIMULTANEOUS_RENDER_TARGET_COUNT as usize];
for (rtv_format, ct) in rtv_formats.iter_mut().zip(desc.color_targets) {
@ -1893,7 +1866,7 @@ impl crate::Device for super::Device {
.map(|ds| ds.bias)
.unwrap_or_default();
let raw_rasterizer = Direct3D12::D3D12_RASTERIZER_DESC {
let rasterizer_state = Direct3D12::D3D12_RASTERIZER_DESC {
FillMode: conv::map_polygon_mode(desc.primitive.polygon_mode),
CullMode: match desc.primitive.cull_mode {
None => Direct3D12::D3D12_CULL_MODE_NONE,
@ -1917,80 +1890,193 @@ impl crate::Device for super::Device {
Direct3D12::D3D12_CONSERVATIVE_RASTERIZATION_MODE_OFF
},
};
let raw_desc = Direct3D12::D3D12_GRAPHICS_PIPELINE_STATE_DESC {
pRootSignature: unsafe {
borrow_optional_interface_temporarily(&desc.layout.shared.signature)
},
VS: blob_vs.create_native_shader(),
PS: match &blob_fs {
Some(shader) => shader.create_native_shader(),
None => Direct3D12::D3D12_SHADER_BYTECODE::default(),
},
GS: Direct3D12::D3D12_SHADER_BYTECODE::default(),
DS: Direct3D12::D3D12_SHADER_BYTECODE::default(),
HS: Direct3D12::D3D12_SHADER_BYTECODE::default(),
StreamOutput: Direct3D12::D3D12_STREAM_OUTPUT_DESC {
pSODeclaration: ptr::null(),
NumEntries: 0,
pBufferStrides: ptr::null(),
NumStrides: 0,
RasterizedStream: 0,
},
BlendState: Direct3D12::D3D12_BLEND_DESC {
AlphaToCoverageEnable: Foundation::BOOL::from(
desc.multisample.alpha_to_coverage_enabled,
),
IndependentBlendEnable: true.into(),
RenderTarget: conv::map_render_targets(desc.color_targets),
},
SampleMask: desc.multisample.mask as u32,
RasterizerState: raw_rasterizer,
DepthStencilState: match desc.depth_stencil {
Some(ref ds) => conv::map_depth_stencil(ds),
None => Default::default(),
},
InputLayout: Direct3D12::D3D12_INPUT_LAYOUT_DESC {
pInputElementDescs: if input_element_descs.is_empty() {
ptr::null()
} else {
input_element_descs.as_ptr()
},
NumElements: input_element_descs.len() as u32,
},
IBStripCutValue: match desc.primitive.strip_index_format {
Some(wgt::IndexFormat::Uint16) => {
Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFF
}
Some(wgt::IndexFormat::Uint32) => {
Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFFFFFF
}
None => Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_DISABLED,
},
PrimitiveTopologyType: topology_class,
NumRenderTargets: desc.color_targets.len() as u32,
RTVFormats: rtv_formats,
DSVFormat: desc
.depth_stencil
.as_ref()
.map_or(Dxgi::Common::DXGI_FORMAT_UNKNOWN, |ds| {
auxil::dxgi::conv::map_texture_format(ds.format)
}),
SampleDesc: Dxgi::Common::DXGI_SAMPLE_DESC {
Count: desc.multisample.count,
Quality: 0,
},
NodeMask: 0,
CachedPSO: Direct3D12::D3D12_CACHED_PIPELINE_STATE {
pCachedBlob: ptr::null(),
CachedBlobSizeInBytes: 0,
},
Flags: Direct3D12::D3D12_PIPELINE_STATE_FLAG_NONE,
let blob_fs = match desc.fragment_stage {
Some(ref stage) => {
shader_stages |= wgt::ShaderStages::FRAGMENT;
Some(self.load_shader(stage, desc.layout, naga::ShaderStage::Fragment, None)?)
}
None => None,
};
let pixel_shader = match &blob_fs {
Some(shader) => shader.create_native_shader(),
None => Direct3D12::D3D12_SHADER_BYTECODE::default(),
};
let mut vertex_strides = [None; crate::MAX_VERTEX_BUFFERS];
let stream_output = Direct3D12::D3D12_STREAM_OUTPUT_DESC {
pSODeclaration: ptr::null(),
NumEntries: 0,
pBufferStrides: ptr::null(),
NumStrides: 0,
RasterizedStream: 0,
};
let blend_state = Direct3D12::D3D12_BLEND_DESC {
AlphaToCoverageEnable: Foundation::BOOL::from(
desc.multisample.alpha_to_coverage_enabled,
),
IndependentBlendEnable: true.into(),
RenderTarget: conv::map_render_targets(desc.color_targets),
};
let depth_stencil_state = match desc.depth_stencil {
Some(ref ds) => conv::map_depth_stencil(ds),
None => Default::default(),
};
let dsv_format = desc
.depth_stencil
.as_ref()
.map_or(Dxgi::Common::DXGI_FORMAT_UNKNOWN, |ds| {
auxil::dxgi::conv::map_texture_format(ds.format)
});
let sample_desc = Dxgi::Common::DXGI_SAMPLE_DESC {
Count: desc.multisample.count,
Quality: 0,
};
let cached_pso = Direct3D12::D3D12_CACHED_PIPELINE_STATE {
pCachedBlob: ptr::null(),
CachedBlobSizeInBytes: 0,
};
let flags = Direct3D12::D3D12_PIPELINE_STATE_FLAG_NONE;
let raw: Direct3D12::ID3D12PipelineState = {
profiling::scope!("ID3D12Device::CreateGraphicsPipelineState");
unsafe { self.raw.CreateGraphicsPipelineState(&raw_desc) }
let raw: Direct3D12::ID3D12PipelineState = match &desc.vertex_processor {
&crate::VertexProcessor::Standard {
vertex_buffers,
ref vertex_stage,
} => {
shader_stages |= wgt::ShaderStages::VERTEX;
let blob_vs = self.load_shader(
vertex_stage,
desc.layout,
naga::ShaderStage::Vertex,
desc.fragment_stage.as_ref(),
)?;
let mut input_element_descs = Vec::new();
for (i, (stride, vbuf)) in vertex_strides.iter_mut().zip(vertex_buffers).enumerate()
{
*stride = NonZeroU32::new(vbuf.array_stride as u32);
let (slot_class, step_rate) = match vbuf.step_mode {
wgt::VertexStepMode::Vertex => {
(Direct3D12::D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA, 0)
}
wgt::VertexStepMode::Instance => {
(Direct3D12::D3D12_INPUT_CLASSIFICATION_PER_INSTANCE_DATA, 1)
}
};
for attribute in vbuf.attributes {
input_element_descs.push(Direct3D12::D3D12_INPUT_ELEMENT_DESC {
SemanticName: windows::core::PCSTR(NAGA_LOCATION_SEMANTIC.as_ptr()),
SemanticIndex: attribute.shader_location,
Format: auxil::dxgi::conv::map_vertex_format(attribute.format),
InputSlot: i as u32,
AlignedByteOffset: attribute.offset as u32,
InputSlotClass: slot_class,
InstanceDataStepRate: step_rate,
});
}
}
let raw_desc = Direct3D12::D3D12_GRAPHICS_PIPELINE_STATE_DESC {
pRootSignature: root_signature,
VS: blob_vs.create_native_shader(),
PS: pixel_shader,
GS: Direct3D12::D3D12_SHADER_BYTECODE::default(),
DS: Direct3D12::D3D12_SHADER_BYTECODE::default(),
HS: Direct3D12::D3D12_SHADER_BYTECODE::default(),
StreamOutput: stream_output,
BlendState: blend_state,
SampleMask: desc.multisample.mask as u32,
RasterizerState: rasterizer_state,
DepthStencilState: depth_stencil_state,
InputLayout: Direct3D12::D3D12_INPUT_LAYOUT_DESC {
pInputElementDescs: if input_element_descs.is_empty() {
ptr::null()
} else {
input_element_descs.as_ptr()
},
NumElements: input_element_descs.len() as u32,
},
IBStripCutValue: match desc.primitive.strip_index_format {
Some(wgt::IndexFormat::Uint16) => {
Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFF
}
Some(wgt::IndexFormat::Uint32) => {
Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFFFFFF
}
None => Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_DISABLED,
},
PrimitiveTopologyType: topology_class,
NumRenderTargets: desc.color_targets.len() as u32,
RTVFormats: rtv_formats,
DSVFormat: dsv_format,
SampleDesc: sample_desc,
NodeMask: 0,
CachedPSO: cached_pso,
Flags: flags,
};
unsafe {
profiling::scope!("ID3D12Device::CreateGraphicsPipelineState");
self.raw.CreateGraphicsPipelineState(&raw_desc)
}
}
crate::VertexProcessor::Mesh {
task_stage,
mesh_stage,
} => {
let blob_ts = if let Some(ts) = task_stage {
shader_stages |= wgt::ShaderStages::TASK;
Some(self.load_shader(
ts,
desc.layout,
naga::ShaderStage::Task,
desc.fragment_stage.as_ref(),
)?)
} else {
None
};
let task_shader = if let Some(ts) = &blob_ts {
ts.create_native_shader()
} else {
Default::default()
};
shader_stages |= wgt::ShaderStages::MESH;
let blob_ms = self.load_shader(
mesh_stage,
desc.layout,
naga::ShaderStage::Mesh,
desc.fragment_stage.as_ref(),
)?;
let desc = super::MeshShaderPipelineStateStream {
root_signature: root_signature
.as_ref()
.map(|a| a.as_raw().cast())
.unwrap_or(ptr::null_mut()),
task_shader,
pixel_shader,
mesh_shader: blob_ms.create_native_shader(),
blend_state,
sample_mask: desc.multisample.mask as u32,
rasterizer_state,
depth_stencil_state,
primitive_topology_type: topology_class,
rtv_formats: Direct3D12::D3D12_RT_FORMAT_ARRAY {
RTFormats: rtv_formats,
NumRenderTargets: desc.color_targets.len() as u32,
},
dsv_format,
sample_desc,
node_mask: 0,
cached_pso,
flags,
};
let mut raw_desc = unsafe { desc.to_bytes() };
let stream_desc = Direct3D12::D3D12_PIPELINE_STATE_STREAM_DESC {
SizeInBytes: raw_desc.len(),
pPipelineStateSubobjectStream: raw_desc.as_mut_ptr().cast(),
};
let device: Direct3D12::ID3D12Device2 = self.raw.cast().unwrap();
unsafe {
profiling::scope!("ID3D12Device2::CreatePipelineState");
device.CreatePipelineState(&stream_desc)
}
}
}
.map_err(|err| crate::PipelineError::Linkage(shader_stages, err.to_string()))?;

View File

@ -660,6 +660,7 @@ struct Idler {
struct CommandSignatures {
draw: Direct3D12::ID3D12CommandSignature,
draw_indexed: Direct3D12::ID3D12CommandSignature,
draw_mesh: Direct3D12::ID3D12CommandSignature,
dispatch: Direct3D12::ID3D12CommandSignature,
}
@ -1600,3 +1601,116 @@ pub enum ShaderModuleSource {
DxilPassthrough(DxilPassthroughShader),
HlslPassthrough(HlslPassthroughShader),
}
#[repr(C)]
#[derive(Debug)]
struct MeshShaderPipelineStateStream {
root_signature: *mut Direct3D12::ID3D12RootSignature,
task_shader: Direct3D12::D3D12_SHADER_BYTECODE,
mesh_shader: Direct3D12::D3D12_SHADER_BYTECODE,
pixel_shader: Direct3D12::D3D12_SHADER_BYTECODE,
blend_state: Direct3D12::D3D12_BLEND_DESC,
sample_mask: u32,
rasterizer_state: Direct3D12::D3D12_RASTERIZER_DESC,
depth_stencil_state: Direct3D12::D3D12_DEPTH_STENCIL_DESC,
primitive_topology_type: Direct3D12::D3D12_PRIMITIVE_TOPOLOGY_TYPE,
rtv_formats: Direct3D12::D3D12_RT_FORMAT_ARRAY,
dsv_format: Dxgi::Common::DXGI_FORMAT,
sample_desc: Dxgi::Common::DXGI_SAMPLE_DESC,
node_mask: u32,
cached_pso: Direct3D12::D3D12_CACHED_PIPELINE_STATE,
flags: Direct3D12::D3D12_PIPELINE_STATE_FLAGS,
}
impl MeshShaderPipelineStateStream {
/// # Safety
///
/// Returned bytes contain pointers into this struct, for them to be valid,
/// this struct may be at the same location. As if `as_bytes<'a>(&'a self) -> Vec<u8> + 'a`
pub unsafe fn to_bytes(&self) -> Vec<u8> {
use Direct3D12::*;
let mut bytes = Vec::new();
macro_rules! push_subobject {
($subobject_type:expr, $data:expr) => {{
// Ensure 8-byte alignment for the subobject start
let alignment = 8;
let aligned_length = bytes.len().next_multiple_of(alignment);
bytes.resize(aligned_length, 0);
// Append the type tag (u32)
let tag: u32 = $subobject_type.0 as u32;
bytes.extend_from_slice(&tag.to_ne_bytes());
// Align the data
let obj_align = align_of_val(&$data);
let data_start = bytes.len().next_multiple_of(obj_align);
bytes.resize(data_start, 0);
// Append the data itself
#[allow(clippy::ptr_as_ptr, trivial_casts)]
let data_ptr = &$data as *const _ as *const u8;
let data_size = size_of_val(&$data);
let slice = unsafe { core::slice::from_raw_parts(data_ptr, data_size) };
bytes.extend_from_slice(slice);
}};
}
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_ROOT_SIGNATURE,
self.root_signature
);
if !self.task_shader.pShaderBytecode.is_null() {
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_AS, self.task_shader);
}
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_MS, self.mesh_shader);
if !self.pixel_shader.pShaderBytecode.is_null() {
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PS, self.pixel_shader);
}
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_BLEND, self.blend_state);
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_MASK,
self.sample_mask
);
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RASTERIZER,
self.rasterizer_state
);
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL,
self.depth_stencil_state
);
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PRIMITIVE_TOPOLOGY,
self.primitive_topology_type
);
if self.rtv_formats.NumRenderTargets != 0 {
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RENDER_TARGET_FORMATS,
self.rtv_formats
);
}
if self.dsv_format != Dxgi::Common::DXGI_FORMAT_UNKNOWN {
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL_FORMAT,
self.dsv_format
);
}
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_DESC,
self.sample_desc
);
if self.node_mask != 0 {
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_NODE_MASK,
self.node_mask
);
}
if !self.cached_pso.pCachedBlob.is_null() {
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CACHED_PSO,
self.cached_pso
);
}
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_FLAGS, self.flags);
bytes
}
}