DXIL & HLSL passthrough (#7831)

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
This commit is contained in:
SupaMaggie70Incorporated 2025-06-25 17:12:58 -05:00 committed by GitHub
parent 4c08c37a46
commit e40e66d205
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 259 additions and 56 deletions

View File

@ -984,6 +984,18 @@ impl Global {
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
},
data,
});

View File

@ -1801,6 +1801,22 @@ impl Device {
num_workgroups: inner.num_workgroups,
}
}
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Dxil {
shader: inner.source,
entry_point: inner.entry_point.clone(),
num_workgroups: inner.num_workgroups,
}
}
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Hlsl {
shader: inner.source,
entry_point: inner.entry_point.clone(),
num_workgroups: inner.num_workgroups,
}
}
};
let hal_desc = hal::ShaderModuleDescriptor {

View File

@ -1,3 +1,4 @@
use alloc::borrow::ToOwned;
use alloc::{
borrow::Cow,
string::{String, ToString as _},
@ -264,27 +265,8 @@ impl super::Device {
naga_stage: naga::ShaderStage,
fragment_stage: Option<&crate::ProgrammableStage<super::ShaderModule>>,
) -> Result<super::CompiledShader, crate::PipelineError> {
use naga::back::hlsl;
let frag_ep = fragment_stage
.map(|fs_stage| {
hlsl::FragmentEntryPoint::new(&fs_stage.module.naga.module, fs_stage.entry_point)
.ok_or(crate::PipelineError::EntryPoint(
naga::ShaderStage::Fragment,
))
})
.transpose()?;
let stage_bit = auxil::map_naga_stage(naga_stage);
let (module, info) = naga::back::pipeline_constants::process_overrides(
&stage.module.naga.module,
&stage.module.naga.info,
Some((naga_stage, stage.entry_point)),
stage.constants,
)
.map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}")))?;
let needs_temp_options = stage.zero_initialize_workgroup_memory
!= layout.naga_options.zero_initialize_workgroup_memory
|| stage.module.runtime_checks.bounds_checks != layout.naga_options.restrict_indexing
@ -301,43 +283,90 @@ impl super::Device {
&layout.naga_options
};
let pipeline_options = hlsl::PipelineOptions {
entry_point: Some((naga_stage, stage.entry_point.to_string())),
};
let key = match &stage.module.source {
super::ShaderModuleSource::Naga(naga_shader) => {
use naga::back::hlsl;
//TODO: reuse the writer
let (source, entry_point) = {
let mut source = String::new();
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
let frag_ep = match fragment_stage {
Some(crate::ProgrammableStage {
module:
super::ShaderModule {
source: super::ShaderModuleSource::Naga(naga_shader),
..
},
entry_point,
..
}) => Some(
hlsl::FragmentEntryPoint::new(&naga_shader.module, entry_point).ok_or(
crate::PipelineError::EntryPoint(naga::ShaderStage::Fragment),
),
),
_ => None,
}
.transpose()?;
let (module, info) = naga::back::pipeline_constants::process_overrides(
&naga_shader.module,
&naga_shader.info,
Some((naga_stage, stage.entry_point)),
stage.constants,
)
.map_err(|e| {
crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}"))
})?;
profiling::scope!("naga::back::hlsl::write");
let mut reflection_info = writer
.write(&module, &info, frag_ep.as_ref())
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?;
let pipeline_options = hlsl::PipelineOptions {
entry_point: Some((naga_stage, stage.entry_point.to_string())),
};
assert_eq!(reflection_info.entry_point_names.len(), 1);
//TODO: reuse the writer
let (source, entry_point) = {
let mut source = String::new();
let mut writer =
hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
let entry_point = reflection_info
.entry_point_names
.pop()
.unwrap()
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
profiling::scope!("naga::back::hlsl::write");
let mut reflection_info = writer
.write(&module, &info, frag_ep.as_ref())
.map_err(|e| {
crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}"))
})?;
(source, entry_point)
};
assert_eq!(reflection_info.entry_point_names.len(), 1);
log::info!(
"Naga generated shader for {:?} at {:?}:\n{}",
entry_point,
naga_stage,
source
);
let entry_point = reflection_info
.entry_point_names
.pop()
.unwrap()
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
let key = ShaderCacheKey {
source,
entry_point,
stage: naga_stage,
shader_model: naga_options.shader_model,
(source, entry_point)
};
log::info!(
"Naga generated shader for {:?} at {:?}:\n{}",
entry_point,
naga_stage,
source
);
ShaderCacheKey {
source,
entry_point,
stage: naga_stage,
shader_model: naga_options.shader_model,
}
}
super::ShaderModuleSource::HlslPassthrough(passthrough) => ShaderCacheKey {
source: passthrough.shader.clone(),
entry_point: passthrough.entry_point.clone(),
stage: naga_stage,
shader_model: naga_options.shader_model,
},
super::ShaderModuleSource::DxilPassthrough(passthrough) => {
return Ok(super::CompiledShader::Precompiled(
passthrough.shader.clone(),
))
}
};
{
@ -351,11 +380,7 @@ impl super::Device {
let source_name = stage.module.raw_name.as_deref();
let full_stage = format!(
"{}_{}",
naga_stage.to_hlsl_str(),
naga_options.shader_model.to_str()
);
let full_stage = format!("{}_{}", naga_stage.to_hlsl_str(), key.shader_model.to_str());
let compiled_shader = self.compiler_container.compile(
self,
@ -1671,7 +1696,7 @@ impl crate::Device for super::Device {
.and_then(|label| alloc::ffi::CString::new(label).ok());
match shader {
crate::ShaderInput::Naga(naga) => Ok(super::ShaderModule {
naga,
source: super::ShaderModuleSource::Naga(naga),
raw_name,
runtime_checks: desc.runtime_checks,
}),
@ -1681,6 +1706,32 @@ impl crate::Device for super::Device {
crate::ShaderInput::Msl { .. } => {
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Dxil {
shader,
entry_point,
num_workgroups,
} => Ok(super::ShaderModule {
source: super::ShaderModuleSource::DxilPassthrough(super::DxilPassthroughShader {
shader: shader.to_vec(),
entry_point,
num_workgroups,
}),
raw_name,
runtime_checks: desc.runtime_checks,
}),
crate::ShaderInput::Hlsl {
shader,
entry_point,
num_workgroups,
} => Ok(super::ShaderModule {
source: super::ShaderModuleSource::HlslPassthrough(super::HlslPassthroughShader {
shader: shader.to_owned(),
entry_point,
num_workgroups,
}),
raw_name,
runtime_checks: desc.runtime_checks,
}),
}
}
unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) {

View File

@ -1077,7 +1077,7 @@ impl crate::DynPipelineLayout for PipelineLayout {}
#[derive(Debug)]
pub struct ShaderModule {
naga: crate::NagaShader,
source: ShaderModuleSource,
raw_name: Option<alloc::ffi::CString>,
runtime_checks: wgt::ShaderRuntimeChecks,
}
@ -1109,6 +1109,7 @@ pub(super) struct ShaderCacheValue {
pub(super) enum CompiledShader {
Dxc(Direct3D::Dxc::IDxcBlob),
Fxc(Direct3D::ID3DBlob),
Precompiled(Vec<u8>),
}
impl CompiledShader {
@ -1122,6 +1123,10 @@ impl CompiledShader {
pShaderBytecode: unsafe { shader.GetBufferPointer() },
BytecodeLength: unsafe { shader.GetBufferSize() },
},
CompiledShader::Precompiled(shader) => Direct3D12::D3D12_SHADER_BYTECODE {
pShaderBytecode: shader.as_ptr().cast(),
BytecodeLength: shader.len(),
},
}
}
}
@ -1490,3 +1495,23 @@ impl crate::Queue for Queue {
(1_000_000_000.0 / frequency as f64) as f32
}
}
#[derive(Debug)]
pub struct DxilPassthroughShader {
pub shader: Vec<u8>,
pub entry_point: String,
pub num_workgroups: (u32, u32, u32),
}
#[derive(Debug)]
pub struct HlslPassthroughShader {
pub shader: String,
pub entry_point: String,
pub num_workgroups: (u32, u32, u32),
}
#[derive(Debug)]
pub enum ShaderModuleSource {
Naga(crate::NagaShader),
DxilPassthrough(DxilPassthroughShader),
HlslPassthrough(HlslPassthroughShader),
}

View File

@ -1346,6 +1346,9 @@ impl crate::Device for super::Device {
panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::Naga(naga) => naga,
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
}
},
label: desc.label.map(|str| str.to_string()),
id: self.shared.next_shader_id.fetch_add(1, Ordering::Relaxed),

View File

@ -2104,6 +2104,16 @@ pub enum ShaderInput<'a> {
num_workgroups: (u32, u32, u32),
},
SpirV(&'a [u32]),
Dxil {
shader: &'a [u8],
entry_point: String,
num_workgroups: (u32, u32, u32),
},
Hlsl {
shader: &'a str,
entry_point: String,
num_workgroups: (u32, u32, u32),
},
}
pub struct ShaderModuleDescriptor<'a> {

View File

@ -1039,6 +1039,9 @@ impl crate::Device for super::Device {
crate::ShaderInput::SpirV(_) => {
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled for this backend")
}
}
}

View File

@ -1908,6 +1908,9 @@ impl crate::Device for super::Device {
crate::ShaderInput::Msl { .. } => {
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv),
};

View File

@ -1244,6 +1244,16 @@ bitflags_array! {
///
/// [BlasTriangleGeometrySizeDescriptor::vertex_format]: super::BlasTriangleGeometrySizeDescriptor
const EXTENDED_ACCELERATION_STRUCTURE_VERTEX_FORMATS = 1 << 51;
/// Enables creating shader modules from DirectX HLSL or DXIL shaders (unsafe)
///
/// HLSL/DXIL data is not parsed or interpreted in any way
///
/// Supported platforms:
/// - DX12
///
/// This is a native only feature.
const HLSL_DXIL_SHADER_PASSTHROUGH = 1 << 53;
}
/// Features that are not guaranteed to be supported.

View File

@ -7765,6 +7765,10 @@ pub enum CreateShaderModuleDescriptorPassthrough<'a, L> {
SpirV(ShaderModuleDescriptorSpirV<'a, L>),
/// Passthrough for MSL source code.
Msl(ShaderModuleDescriptorMsl<'a, L>),
/// Passthrough for DXIL compiled with DXC
Dxil(ShaderModuleDescriptorDxil<'a, L>),
/// Passthrough for HLSL
Hlsl(ShaderModuleDescriptorHlsl<'a, L>),
}
impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
@ -7790,6 +7794,22 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
source: inner.source.clone(),
})
}
CreateShaderModuleDescriptorPassthrough::Dxil(inner) => {
CreateShaderModuleDescriptorPassthrough::<'_, K>::Dxil(ShaderModuleDescriptorDxil {
entry_point: inner.entry_point.clone(),
label: fun(&inner.label),
num_workgroups: inner.num_workgroups,
source: inner.source,
})
}
CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => {
CreateShaderModuleDescriptorPassthrough::<'_, K>::Hlsl(ShaderModuleDescriptorHlsl {
entry_point: inner.entry_point.clone(),
label: fun(&inner.label),
num_workgroups: inner.num_workgroups,
source: inner.source,
})
}
}
}
@ -7798,6 +7818,8 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
match self {
CreateShaderModuleDescriptorPassthrough::SpirV(inner) => &inner.label,
CreateShaderModuleDescriptorPassthrough::Msl(inner) => &inner.label,
CreateShaderModuleDescriptorPassthrough::Dxil(inner) => &inner.label,
CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => &inner.label,
}
}
@ -7809,6 +7831,8 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
bytemuck::cast_slice(&inner.source)
}
CreateShaderModuleDescriptorPassthrough::Msl(inner) => inner.source.as_bytes(),
CreateShaderModuleDescriptorPassthrough::Dxil(inner) => inner.source,
CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => inner.source.as_bytes(),
}
}
@ -7818,6 +7842,8 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
match self {
CreateShaderModuleDescriptorPassthrough::SpirV(..) => "spv",
CreateShaderModuleDescriptorPassthrough::Msl(..) => "msl",
CreateShaderModuleDescriptorPassthrough::Dxil(..) => "dxil",
CreateShaderModuleDescriptorPassthrough::Hlsl(..) => "hlsl",
}
}
}
@ -7838,6 +7864,38 @@ pub struct ShaderModuleDescriptorMsl<'a, L> {
pub source: Cow<'a, str>,
}
/// Descriptor for a shader module given by DirectX DXIL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
#[derive(Debug, Clone)]
pub struct ShaderModuleDescriptorDxil<'a, L> {
/// Entrypoint.
pub entry_point: String,
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: L,
/// Number of workgroups in each dimension x, y and z.
pub num_workgroups: (u32, u32, u32),
/// Shader DXIL source.
pub source: &'a [u8],
}
/// Descriptor for a shader module given by DirectX HLSL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
#[derive(Debug, Clone)]
pub struct ShaderModuleDescriptorHlsl<'a, L> {
/// Entrypoint.
pub entry_point: String,
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: L,
/// Number of workgroups in each dimension x, y and z.
pub num_workgroups: (u32, u32, u32),
/// Shader HLSL source.
pub source: &'a str,
}
/// Descriptor for a shader module given by SPIR-V binary.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,

View File

@ -247,3 +247,15 @@ pub type ShaderModuleDescriptorMsl<'a> = wgt::ShaderModuleDescriptorMsl<'a, Labe
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorSpirV<'a> = wgt::ShaderModuleDescriptorSpirV<'a, Label<'a>>;
/// Descriptor for a shader module given by DirectX HLSL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorHlsl<'a> = wgt::ShaderModuleDescriptorHlsl<'a, Label<'a>>;
/// Descriptor for a shader module given by DirectX DXIL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorDxil<'a> = wgt::ShaderModuleDescriptorDxil<'a, Label<'a>>;