DXIL & HLSL passthrough (#7831)

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
This commit is contained in:
SupaMaggie70Incorporated 2025-06-25 17:12:58 -05:00 committed by GitHub
parent 4c08c37a46
commit e40e66d205
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 259 additions and 56 deletions

View File

@ -984,6 +984,18 @@ impl Global {
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
} }
} }
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
}, },
data, data,
}); });

View File

@ -1801,6 +1801,22 @@ impl Device {
num_workgroups: inner.num_workgroups, num_workgroups: inner.num_workgroups,
} }
} }
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Dxil {
shader: inner.source,
entry_point: inner.entry_point.clone(),
num_workgroups: inner.num_workgroups,
}
}
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Hlsl {
shader: inner.source,
entry_point: inner.entry_point.clone(),
num_workgroups: inner.num_workgroups,
}
}
}; };
let hal_desc = hal::ShaderModuleDescriptor { let hal_desc = hal::ShaderModuleDescriptor {

View File

@ -1,3 +1,4 @@
use alloc::borrow::ToOwned;
use alloc::{ use alloc::{
borrow::Cow, borrow::Cow,
string::{String, ToString as _}, string::{String, ToString as _},
@ -264,27 +265,8 @@ impl super::Device {
naga_stage: naga::ShaderStage, naga_stage: naga::ShaderStage,
fragment_stage: Option<&crate::ProgrammableStage<super::ShaderModule>>, fragment_stage: Option<&crate::ProgrammableStage<super::ShaderModule>>,
) -> Result<super::CompiledShader, crate::PipelineError> { ) -> Result<super::CompiledShader, crate::PipelineError> {
use naga::back::hlsl;
let frag_ep = fragment_stage
.map(|fs_stage| {
hlsl::FragmentEntryPoint::new(&fs_stage.module.naga.module, fs_stage.entry_point)
.ok_or(crate::PipelineError::EntryPoint(
naga::ShaderStage::Fragment,
))
})
.transpose()?;
let stage_bit = auxil::map_naga_stage(naga_stage); let stage_bit = auxil::map_naga_stage(naga_stage);
let (module, info) = naga::back::pipeline_constants::process_overrides(
&stage.module.naga.module,
&stage.module.naga.info,
Some((naga_stage, stage.entry_point)),
stage.constants,
)
.map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}")))?;
let needs_temp_options = stage.zero_initialize_workgroup_memory let needs_temp_options = stage.zero_initialize_workgroup_memory
!= layout.naga_options.zero_initialize_workgroup_memory != layout.naga_options.zero_initialize_workgroup_memory
|| stage.module.runtime_checks.bounds_checks != layout.naga_options.restrict_indexing || stage.module.runtime_checks.bounds_checks != layout.naga_options.restrict_indexing
@ -301,43 +283,90 @@ impl super::Device {
&layout.naga_options &layout.naga_options
}; };
let pipeline_options = hlsl::PipelineOptions { let key = match &stage.module.source {
entry_point: Some((naga_stage, stage.entry_point.to_string())), super::ShaderModuleSource::Naga(naga_shader) => {
}; use naga::back::hlsl;
//TODO: reuse the writer let frag_ep = match fragment_stage {
let (source, entry_point) = { Some(crate::ProgrammableStage {
let mut source = String::new(); module:
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options); super::ShaderModule {
source: super::ShaderModuleSource::Naga(naga_shader),
..
},
entry_point,
..
}) => Some(
hlsl::FragmentEntryPoint::new(&naga_shader.module, entry_point).ok_or(
crate::PipelineError::EntryPoint(naga::ShaderStage::Fragment),
),
),
_ => None,
}
.transpose()?;
let (module, info) = naga::back::pipeline_constants::process_overrides(
&naga_shader.module,
&naga_shader.info,
Some((naga_stage, stage.entry_point)),
stage.constants,
)
.map_err(|e| {
crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}"))
})?;
profiling::scope!("naga::back::hlsl::write"); let pipeline_options = hlsl::PipelineOptions {
let mut reflection_info = writer entry_point: Some((naga_stage, stage.entry_point.to_string())),
.write(&module, &info, frag_ep.as_ref()) };
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?;
assert_eq!(reflection_info.entry_point_names.len(), 1); //TODO: reuse the writer
let (source, entry_point) = {
let mut source = String::new();
let mut writer =
hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
let entry_point = reflection_info profiling::scope!("naga::back::hlsl::write");
.entry_point_names let mut reflection_info = writer
.pop() .write(&module, &info, frag_ep.as_ref())
.unwrap() .map_err(|e| {
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?; crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}"))
})?;
(source, entry_point) assert_eq!(reflection_info.entry_point_names.len(), 1);
};
log::info!( let entry_point = reflection_info
"Naga generated shader for {:?} at {:?}:\n{}", .entry_point_names
entry_point, .pop()
naga_stage, .unwrap()
source .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
);
let key = ShaderCacheKey { (source, entry_point)
source, };
entry_point, log::info!(
stage: naga_stage, "Naga generated shader for {:?} at {:?}:\n{}",
shader_model: naga_options.shader_model, entry_point,
naga_stage,
source
);
ShaderCacheKey {
source,
entry_point,
stage: naga_stage,
shader_model: naga_options.shader_model,
}
}
super::ShaderModuleSource::HlslPassthrough(passthrough) => ShaderCacheKey {
source: passthrough.shader.clone(),
entry_point: passthrough.entry_point.clone(),
stage: naga_stage,
shader_model: naga_options.shader_model,
},
super::ShaderModuleSource::DxilPassthrough(passthrough) => {
return Ok(super::CompiledShader::Precompiled(
passthrough.shader.clone(),
))
}
}; };
{ {
@ -351,11 +380,7 @@ impl super::Device {
let source_name = stage.module.raw_name.as_deref(); let source_name = stage.module.raw_name.as_deref();
let full_stage = format!( let full_stage = format!("{}_{}", naga_stage.to_hlsl_str(), key.shader_model.to_str());
"{}_{}",
naga_stage.to_hlsl_str(),
naga_options.shader_model.to_str()
);
let compiled_shader = self.compiler_container.compile( let compiled_shader = self.compiler_container.compile(
self, self,
@ -1671,7 +1696,7 @@ impl crate::Device for super::Device {
.and_then(|label| alloc::ffi::CString::new(label).ok()); .and_then(|label| alloc::ffi::CString::new(label).ok());
match shader { match shader {
crate::ShaderInput::Naga(naga) => Ok(super::ShaderModule { crate::ShaderInput::Naga(naga) => Ok(super::ShaderModule {
naga, source: super::ShaderModuleSource::Naga(naga),
raw_name, raw_name,
runtime_checks: desc.runtime_checks, runtime_checks: desc.runtime_checks,
}), }),
@ -1681,6 +1706,32 @@ impl crate::Device for super::Device {
crate::ShaderInput::Msl { .. } => { crate::ShaderInput::Msl { .. } => {
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend") panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
} }
crate::ShaderInput::Dxil {
shader,
entry_point,
num_workgroups,
} => Ok(super::ShaderModule {
source: super::ShaderModuleSource::DxilPassthrough(super::DxilPassthroughShader {
shader: shader.to_vec(),
entry_point,
num_workgroups,
}),
raw_name,
runtime_checks: desc.runtime_checks,
}),
crate::ShaderInput::Hlsl {
shader,
entry_point,
num_workgroups,
} => Ok(super::ShaderModule {
source: super::ShaderModuleSource::HlslPassthrough(super::HlslPassthroughShader {
shader: shader.to_owned(),
entry_point,
num_workgroups,
}),
raw_name,
runtime_checks: desc.runtime_checks,
}),
} }
} }
unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) { unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) {

View File

@ -1077,7 +1077,7 @@ impl crate::DynPipelineLayout for PipelineLayout {}
#[derive(Debug)] #[derive(Debug)]
pub struct ShaderModule { pub struct ShaderModule {
naga: crate::NagaShader, source: ShaderModuleSource,
raw_name: Option<alloc::ffi::CString>, raw_name: Option<alloc::ffi::CString>,
runtime_checks: wgt::ShaderRuntimeChecks, runtime_checks: wgt::ShaderRuntimeChecks,
} }
@ -1109,6 +1109,7 @@ pub(super) struct ShaderCacheValue {
pub(super) enum CompiledShader { pub(super) enum CompiledShader {
Dxc(Direct3D::Dxc::IDxcBlob), Dxc(Direct3D::Dxc::IDxcBlob),
Fxc(Direct3D::ID3DBlob), Fxc(Direct3D::ID3DBlob),
Precompiled(Vec<u8>),
} }
impl CompiledShader { impl CompiledShader {
@ -1122,6 +1123,10 @@ impl CompiledShader {
pShaderBytecode: unsafe { shader.GetBufferPointer() }, pShaderBytecode: unsafe { shader.GetBufferPointer() },
BytecodeLength: unsafe { shader.GetBufferSize() }, BytecodeLength: unsafe { shader.GetBufferSize() },
}, },
CompiledShader::Precompiled(shader) => Direct3D12::D3D12_SHADER_BYTECODE {
pShaderBytecode: shader.as_ptr().cast(),
BytecodeLength: shader.len(),
},
} }
} }
} }
@ -1490,3 +1495,23 @@ impl crate::Queue for Queue {
(1_000_000_000.0 / frequency as f64) as f32 (1_000_000_000.0 / frequency as f64) as f32
} }
} }
#[derive(Debug)]
pub struct DxilPassthroughShader {
pub shader: Vec<u8>,
pub entry_point: String,
pub num_workgroups: (u32, u32, u32),
}
#[derive(Debug)]
pub struct HlslPassthroughShader {
pub shader: String,
pub entry_point: String,
pub num_workgroups: (u32, u32, u32),
}
#[derive(Debug)]
pub enum ShaderModuleSource {
Naga(crate::NagaShader),
DxilPassthrough(DxilPassthroughShader),
HlslPassthrough(HlslPassthroughShader),
}

View File

@ -1346,6 +1346,9 @@ impl crate::Device for super::Device {
panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled") panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled")
} }
crate::ShaderInput::Naga(naga) => naga, crate::ShaderInput::Naga(naga) => naga,
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
}
}, },
label: desc.label.map(|str| str.to_string()), label: desc.label.map(|str| str.to_string()),
id: self.shared.next_shader_id.fetch_add(1, Ordering::Relaxed), id: self.shared.next_shader_id.fetch_add(1, Ordering::Relaxed),

View File

@ -2104,6 +2104,16 @@ pub enum ShaderInput<'a> {
num_workgroups: (u32, u32, u32), num_workgroups: (u32, u32, u32),
}, },
SpirV(&'a [u32]), SpirV(&'a [u32]),
Dxil {
shader: &'a [u8],
entry_point: String,
num_workgroups: (u32, u32, u32),
},
Hlsl {
shader: &'a str,
entry_point: String,
num_workgroups: (u32, u32, u32),
},
} }
pub struct ShaderModuleDescriptor<'a> { pub struct ShaderModuleDescriptor<'a> {

View File

@ -1039,6 +1039,9 @@ impl crate::Device for super::Device {
crate::ShaderInput::SpirV(_) => { crate::ShaderInput::SpirV(_) => {
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend") panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
} }
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled for this backend")
}
} }
} }

View File

@ -1908,6 +1908,9 @@ impl crate::Device for super::Device {
crate::ShaderInput::Msl { .. } => { crate::ShaderInput::Msl { .. } => {
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend") panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
} }
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv), crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv),
}; };

View File

@ -1244,6 +1244,16 @@ bitflags_array! {
/// ///
/// [BlasTriangleGeometrySizeDescriptor::vertex_format]: super::BlasTriangleGeometrySizeDescriptor /// [BlasTriangleGeometrySizeDescriptor::vertex_format]: super::BlasTriangleGeometrySizeDescriptor
const EXTENDED_ACCELERATION_STRUCTURE_VERTEX_FORMATS = 1 << 51; const EXTENDED_ACCELERATION_STRUCTURE_VERTEX_FORMATS = 1 << 51;
/// Enables creating shader modules from DirectX HLSL or DXIL shaders (unsafe)
///
/// HLSL/DXIL data is not parsed or interpreted in any way
///
/// Supported platforms:
/// - DX12
///
/// This is a native only feature.
const HLSL_DXIL_SHADER_PASSTHROUGH = 1 << 53;
} }
/// Features that are not guaranteed to be supported. /// Features that are not guaranteed to be supported.

View File

@ -7765,6 +7765,10 @@ pub enum CreateShaderModuleDescriptorPassthrough<'a, L> {
SpirV(ShaderModuleDescriptorSpirV<'a, L>), SpirV(ShaderModuleDescriptorSpirV<'a, L>),
/// Passthrough for MSL source code. /// Passthrough for MSL source code.
Msl(ShaderModuleDescriptorMsl<'a, L>), Msl(ShaderModuleDescriptorMsl<'a, L>),
/// Passthrough for DXIL compiled with DXC
Dxil(ShaderModuleDescriptorDxil<'a, L>),
/// Passthrough for HLSL
Hlsl(ShaderModuleDescriptorHlsl<'a, L>),
} }
impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
@ -7790,6 +7794,22 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
source: inner.source.clone(), source: inner.source.clone(),
}) })
} }
CreateShaderModuleDescriptorPassthrough::Dxil(inner) => {
CreateShaderModuleDescriptorPassthrough::<'_, K>::Dxil(ShaderModuleDescriptorDxil {
entry_point: inner.entry_point.clone(),
label: fun(&inner.label),
num_workgroups: inner.num_workgroups,
source: inner.source,
})
}
CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => {
CreateShaderModuleDescriptorPassthrough::<'_, K>::Hlsl(ShaderModuleDescriptorHlsl {
entry_point: inner.entry_point.clone(),
label: fun(&inner.label),
num_workgroups: inner.num_workgroups,
source: inner.source,
})
}
} }
} }
@ -7798,6 +7818,8 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
match self { match self {
CreateShaderModuleDescriptorPassthrough::SpirV(inner) => &inner.label, CreateShaderModuleDescriptorPassthrough::SpirV(inner) => &inner.label,
CreateShaderModuleDescriptorPassthrough::Msl(inner) => &inner.label, CreateShaderModuleDescriptorPassthrough::Msl(inner) => &inner.label,
CreateShaderModuleDescriptorPassthrough::Dxil(inner) => &inner.label,
CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => &inner.label,
} }
} }
@ -7809,6 +7831,8 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
bytemuck::cast_slice(&inner.source) bytemuck::cast_slice(&inner.source)
} }
CreateShaderModuleDescriptorPassthrough::Msl(inner) => inner.source.as_bytes(), CreateShaderModuleDescriptorPassthrough::Msl(inner) => inner.source.as_bytes(),
CreateShaderModuleDescriptorPassthrough::Dxil(inner) => inner.source,
CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => inner.source.as_bytes(),
} }
} }
@ -7818,6 +7842,8 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
match self { match self {
CreateShaderModuleDescriptorPassthrough::SpirV(..) => "spv", CreateShaderModuleDescriptorPassthrough::SpirV(..) => "spv",
CreateShaderModuleDescriptorPassthrough::Msl(..) => "msl", CreateShaderModuleDescriptorPassthrough::Msl(..) => "msl",
CreateShaderModuleDescriptorPassthrough::Dxil(..) => "dxil",
CreateShaderModuleDescriptorPassthrough::Hlsl(..) => "hlsl",
} }
} }
} }
@ -7838,6 +7864,38 @@ pub struct ShaderModuleDescriptorMsl<'a, L> {
pub source: Cow<'a, str>, pub source: Cow<'a, str>,
} }
/// Descriptor for a shader module given by DirectX DXIL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
#[derive(Debug, Clone)]
pub struct ShaderModuleDescriptorDxil<'a, L> {
/// Entrypoint.
pub entry_point: String,
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: L,
/// Number of workgroups in each dimension x, y and z.
pub num_workgroups: (u32, u32, u32),
/// Shader DXIL source.
pub source: &'a [u8],
}
/// Descriptor for a shader module given by DirectX HLSL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
#[derive(Debug, Clone)]
pub struct ShaderModuleDescriptorHlsl<'a, L> {
/// Entrypoint.
pub entry_point: String,
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: L,
/// Number of workgroups in each dimension x, y and z.
pub num_workgroups: (u32, u32, u32),
/// Shader HLSL source.
pub source: &'a str,
}
/// Descriptor for a shader module given by SPIR-V binary. /// Descriptor for a shader module given by SPIR-V binary.
/// ///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, /// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,

View File

@ -247,3 +247,15 @@ pub type ShaderModuleDescriptorMsl<'a> = wgt::ShaderModuleDescriptorMsl<'a, Labe
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, /// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted. /// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorSpirV<'a> = wgt::ShaderModuleDescriptorSpirV<'a, Label<'a>>; pub type ShaderModuleDescriptorSpirV<'a> = wgt::ShaderModuleDescriptorSpirV<'a, Label<'a>>;
/// Descriptor for a shader module given by DirectX HLSL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorHlsl<'a> = wgt::ShaderModuleDescriptorHlsl<'a, Label<'a>>;
/// Descriptor for a shader module given by DirectX DXIL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorDxil<'a> = wgt::ShaderModuleDescriptorDxil<'a, Label<'a>>;