mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-12-08 21:26:17 +00:00
DXIL & HLSL passthrough (#7831)
Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
This commit is contained in:
parent
4c08c37a46
commit
e40e66d205
@ -984,6 +984,18 @@ impl Global {
|
||||
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
|
||||
}
|
||||
}
|
||||
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
|
||||
pipeline::ShaderModuleDescriptor {
|
||||
label: inner.label.clone(),
|
||||
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
|
||||
}
|
||||
}
|
||||
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
|
||||
pipeline::ShaderModuleDescriptor {
|
||||
label: inner.label.clone(),
|
||||
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
|
||||
}
|
||||
}
|
||||
},
|
||||
data,
|
||||
});
|
||||
|
||||
@ -1801,6 +1801,22 @@ impl Device {
|
||||
num_workgroups: inner.num_workgroups,
|
||||
}
|
||||
}
|
||||
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
|
||||
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
|
||||
hal::ShaderInput::Dxil {
|
||||
shader: inner.source,
|
||||
entry_point: inner.entry_point.clone(),
|
||||
num_workgroups: inner.num_workgroups,
|
||||
}
|
||||
}
|
||||
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
|
||||
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
|
||||
hal::ShaderInput::Hlsl {
|
||||
shader: inner.source,
|
||||
entry_point: inner.entry_point.clone(),
|
||||
num_workgroups: inner.num_workgroups,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let hal_desc = hal::ShaderModuleDescriptor {
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
use alloc::borrow::ToOwned;
|
||||
use alloc::{
|
||||
borrow::Cow,
|
||||
string::{String, ToString as _},
|
||||
@ -264,27 +265,8 @@ impl super::Device {
|
||||
naga_stage: naga::ShaderStage,
|
||||
fragment_stage: Option<&crate::ProgrammableStage<super::ShaderModule>>,
|
||||
) -> Result<super::CompiledShader, crate::PipelineError> {
|
||||
use naga::back::hlsl;
|
||||
|
||||
let frag_ep = fragment_stage
|
||||
.map(|fs_stage| {
|
||||
hlsl::FragmentEntryPoint::new(&fs_stage.module.naga.module, fs_stage.entry_point)
|
||||
.ok_or(crate::PipelineError::EntryPoint(
|
||||
naga::ShaderStage::Fragment,
|
||||
))
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let stage_bit = auxil::map_naga_stage(naga_stage);
|
||||
|
||||
let (module, info) = naga::back::pipeline_constants::process_overrides(
|
||||
&stage.module.naga.module,
|
||||
&stage.module.naga.info,
|
||||
Some((naga_stage, stage.entry_point)),
|
||||
stage.constants,
|
||||
)
|
||||
.map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}")))?;
|
||||
|
||||
let needs_temp_options = stage.zero_initialize_workgroup_memory
|
||||
!= layout.naga_options.zero_initialize_workgroup_memory
|
||||
|| stage.module.runtime_checks.bounds_checks != layout.naga_options.restrict_indexing
|
||||
@ -301,43 +283,90 @@ impl super::Device {
|
||||
&layout.naga_options
|
||||
};
|
||||
|
||||
let pipeline_options = hlsl::PipelineOptions {
|
||||
entry_point: Some((naga_stage, stage.entry_point.to_string())),
|
||||
};
|
||||
let key = match &stage.module.source {
|
||||
super::ShaderModuleSource::Naga(naga_shader) => {
|
||||
use naga::back::hlsl;
|
||||
|
||||
//TODO: reuse the writer
|
||||
let (source, entry_point) = {
|
||||
let mut source = String::new();
|
||||
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
|
||||
let frag_ep = match fragment_stage {
|
||||
Some(crate::ProgrammableStage {
|
||||
module:
|
||||
super::ShaderModule {
|
||||
source: super::ShaderModuleSource::Naga(naga_shader),
|
||||
..
|
||||
},
|
||||
entry_point,
|
||||
..
|
||||
}) => Some(
|
||||
hlsl::FragmentEntryPoint::new(&naga_shader.module, entry_point).ok_or(
|
||||
crate::PipelineError::EntryPoint(naga::ShaderStage::Fragment),
|
||||
),
|
||||
),
|
||||
_ => None,
|
||||
}
|
||||
.transpose()?;
|
||||
let (module, info) = naga::back::pipeline_constants::process_overrides(
|
||||
&naga_shader.module,
|
||||
&naga_shader.info,
|
||||
Some((naga_stage, stage.entry_point)),
|
||||
stage.constants,
|
||||
)
|
||||
.map_err(|e| {
|
||||
crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}"))
|
||||
})?;
|
||||
|
||||
profiling::scope!("naga::back::hlsl::write");
|
||||
let mut reflection_info = writer
|
||||
.write(&module, &info, frag_ep.as_ref())
|
||||
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?;
|
||||
let pipeline_options = hlsl::PipelineOptions {
|
||||
entry_point: Some((naga_stage, stage.entry_point.to_string())),
|
||||
};
|
||||
|
||||
assert_eq!(reflection_info.entry_point_names.len(), 1);
|
||||
//TODO: reuse the writer
|
||||
let (source, entry_point) = {
|
||||
let mut source = String::new();
|
||||
let mut writer =
|
||||
hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
|
||||
|
||||
let entry_point = reflection_info
|
||||
.entry_point_names
|
||||
.pop()
|
||||
.unwrap()
|
||||
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
|
||||
profiling::scope!("naga::back::hlsl::write");
|
||||
let mut reflection_info = writer
|
||||
.write(&module, &info, frag_ep.as_ref())
|
||||
.map_err(|e| {
|
||||
crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}"))
|
||||
})?;
|
||||
|
||||
(source, entry_point)
|
||||
};
|
||||
assert_eq!(reflection_info.entry_point_names.len(), 1);
|
||||
|
||||
log::info!(
|
||||
"Naga generated shader for {:?} at {:?}:\n{}",
|
||||
entry_point,
|
||||
naga_stage,
|
||||
source
|
||||
);
|
||||
let entry_point = reflection_info
|
||||
.entry_point_names
|
||||
.pop()
|
||||
.unwrap()
|
||||
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
|
||||
|
||||
let key = ShaderCacheKey {
|
||||
source,
|
||||
entry_point,
|
||||
stage: naga_stage,
|
||||
shader_model: naga_options.shader_model,
|
||||
(source, entry_point)
|
||||
};
|
||||
log::info!(
|
||||
"Naga generated shader for {:?} at {:?}:\n{}",
|
||||
entry_point,
|
||||
naga_stage,
|
||||
source
|
||||
);
|
||||
|
||||
ShaderCacheKey {
|
||||
source,
|
||||
entry_point,
|
||||
stage: naga_stage,
|
||||
shader_model: naga_options.shader_model,
|
||||
}
|
||||
}
|
||||
super::ShaderModuleSource::HlslPassthrough(passthrough) => ShaderCacheKey {
|
||||
source: passthrough.shader.clone(),
|
||||
entry_point: passthrough.entry_point.clone(),
|
||||
stage: naga_stage,
|
||||
shader_model: naga_options.shader_model,
|
||||
},
|
||||
|
||||
super::ShaderModuleSource::DxilPassthrough(passthrough) => {
|
||||
return Ok(super::CompiledShader::Precompiled(
|
||||
passthrough.shader.clone(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
{
|
||||
@ -351,11 +380,7 @@ impl super::Device {
|
||||
|
||||
let source_name = stage.module.raw_name.as_deref();
|
||||
|
||||
let full_stage = format!(
|
||||
"{}_{}",
|
||||
naga_stage.to_hlsl_str(),
|
||||
naga_options.shader_model.to_str()
|
||||
);
|
||||
let full_stage = format!("{}_{}", naga_stage.to_hlsl_str(), key.shader_model.to_str());
|
||||
|
||||
let compiled_shader = self.compiler_container.compile(
|
||||
self,
|
||||
@ -1671,7 +1696,7 @@ impl crate::Device for super::Device {
|
||||
.and_then(|label| alloc::ffi::CString::new(label).ok());
|
||||
match shader {
|
||||
crate::ShaderInput::Naga(naga) => Ok(super::ShaderModule {
|
||||
naga,
|
||||
source: super::ShaderModuleSource::Naga(naga),
|
||||
raw_name,
|
||||
runtime_checks: desc.runtime_checks,
|
||||
}),
|
||||
@ -1681,6 +1706,32 @@ impl crate::Device for super::Device {
|
||||
crate::ShaderInput::Msl { .. } => {
|
||||
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
|
||||
}
|
||||
crate::ShaderInput::Dxil {
|
||||
shader,
|
||||
entry_point,
|
||||
num_workgroups,
|
||||
} => Ok(super::ShaderModule {
|
||||
source: super::ShaderModuleSource::DxilPassthrough(super::DxilPassthroughShader {
|
||||
shader: shader.to_vec(),
|
||||
entry_point,
|
||||
num_workgroups,
|
||||
}),
|
||||
raw_name,
|
||||
runtime_checks: desc.runtime_checks,
|
||||
}),
|
||||
crate::ShaderInput::Hlsl {
|
||||
shader,
|
||||
entry_point,
|
||||
num_workgroups,
|
||||
} => Ok(super::ShaderModule {
|
||||
source: super::ShaderModuleSource::HlslPassthrough(super::HlslPassthroughShader {
|
||||
shader: shader.to_owned(),
|
||||
entry_point,
|
||||
num_workgroups,
|
||||
}),
|
||||
raw_name,
|
||||
runtime_checks: desc.runtime_checks,
|
||||
}),
|
||||
}
|
||||
}
|
||||
unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) {
|
||||
|
||||
@ -1077,7 +1077,7 @@ impl crate::DynPipelineLayout for PipelineLayout {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ShaderModule {
|
||||
naga: crate::NagaShader,
|
||||
source: ShaderModuleSource,
|
||||
raw_name: Option<alloc::ffi::CString>,
|
||||
runtime_checks: wgt::ShaderRuntimeChecks,
|
||||
}
|
||||
@ -1109,6 +1109,7 @@ pub(super) struct ShaderCacheValue {
|
||||
pub(super) enum CompiledShader {
|
||||
Dxc(Direct3D::Dxc::IDxcBlob),
|
||||
Fxc(Direct3D::ID3DBlob),
|
||||
Precompiled(Vec<u8>),
|
||||
}
|
||||
|
||||
impl CompiledShader {
|
||||
@ -1122,6 +1123,10 @@ impl CompiledShader {
|
||||
pShaderBytecode: unsafe { shader.GetBufferPointer() },
|
||||
BytecodeLength: unsafe { shader.GetBufferSize() },
|
||||
},
|
||||
CompiledShader::Precompiled(shader) => Direct3D12::D3D12_SHADER_BYTECODE {
|
||||
pShaderBytecode: shader.as_ptr().cast(),
|
||||
BytecodeLength: shader.len(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1490,3 +1495,23 @@ impl crate::Queue for Queue {
|
||||
(1_000_000_000.0 / frequency as f64) as f32
|
||||
}
|
||||
}
|
||||
#[derive(Debug)]
|
||||
pub struct DxilPassthroughShader {
|
||||
pub shader: Vec<u8>,
|
||||
pub entry_point: String,
|
||||
pub num_workgroups: (u32, u32, u32),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct HlslPassthroughShader {
|
||||
pub shader: String,
|
||||
pub entry_point: String,
|
||||
pub num_workgroups: (u32, u32, u32),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ShaderModuleSource {
|
||||
Naga(crate::NagaShader),
|
||||
DxilPassthrough(DxilPassthroughShader),
|
||||
HlslPassthrough(HlslPassthroughShader),
|
||||
}
|
||||
|
||||
@ -1346,6 +1346,9 @@ impl crate::Device for super::Device {
|
||||
panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled")
|
||||
}
|
||||
crate::ShaderInput::Naga(naga) => naga,
|
||||
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
|
||||
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
|
||||
}
|
||||
},
|
||||
label: desc.label.map(|str| str.to_string()),
|
||||
id: self.shared.next_shader_id.fetch_add(1, Ordering::Relaxed),
|
||||
|
||||
@ -2104,6 +2104,16 @@ pub enum ShaderInput<'a> {
|
||||
num_workgroups: (u32, u32, u32),
|
||||
},
|
||||
SpirV(&'a [u32]),
|
||||
Dxil {
|
||||
shader: &'a [u8],
|
||||
entry_point: String,
|
||||
num_workgroups: (u32, u32, u32),
|
||||
},
|
||||
Hlsl {
|
||||
shader: &'a str,
|
||||
entry_point: String,
|
||||
num_workgroups: (u32, u32, u32),
|
||||
},
|
||||
}
|
||||
|
||||
pub struct ShaderModuleDescriptor<'a> {
|
||||
|
||||
@ -1039,6 +1039,9 @@ impl crate::Device for super::Device {
|
||||
crate::ShaderInput::SpirV(_) => {
|
||||
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
|
||||
}
|
||||
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
|
||||
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled for this backend")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1908,6 +1908,9 @@ impl crate::Device for super::Device {
|
||||
crate::ShaderInput::Msl { .. } => {
|
||||
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
|
||||
}
|
||||
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
|
||||
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
|
||||
}
|
||||
crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv),
|
||||
};
|
||||
|
||||
|
||||
@ -1244,6 +1244,16 @@ bitflags_array! {
|
||||
///
|
||||
/// [BlasTriangleGeometrySizeDescriptor::vertex_format]: super::BlasTriangleGeometrySizeDescriptor
|
||||
const EXTENDED_ACCELERATION_STRUCTURE_VERTEX_FORMATS = 1 << 51;
|
||||
|
||||
/// Enables creating shader modules from DirectX HLSL or DXIL shaders (unsafe)
|
||||
///
|
||||
/// HLSL/DXIL data is not parsed or interpreted in any way
|
||||
///
|
||||
/// Supported platforms:
|
||||
/// - DX12
|
||||
///
|
||||
/// This is a native only feature.
|
||||
const HLSL_DXIL_SHADER_PASSTHROUGH = 1 << 53;
|
||||
}
|
||||
|
||||
/// Features that are not guaranteed to be supported.
|
||||
|
||||
@ -7765,6 +7765,10 @@ pub enum CreateShaderModuleDescriptorPassthrough<'a, L> {
|
||||
SpirV(ShaderModuleDescriptorSpirV<'a, L>),
|
||||
/// Passthrough for MSL source code.
|
||||
Msl(ShaderModuleDescriptorMsl<'a, L>),
|
||||
/// Passthrough for DXIL compiled with DXC
|
||||
Dxil(ShaderModuleDescriptorDxil<'a, L>),
|
||||
/// Passthrough for HLSL
|
||||
Hlsl(ShaderModuleDescriptorHlsl<'a, L>),
|
||||
}
|
||||
|
||||
impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
|
||||
@ -7790,6 +7794,22 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
|
||||
source: inner.source.clone(),
|
||||
})
|
||||
}
|
||||
CreateShaderModuleDescriptorPassthrough::Dxil(inner) => {
|
||||
CreateShaderModuleDescriptorPassthrough::<'_, K>::Dxil(ShaderModuleDescriptorDxil {
|
||||
entry_point: inner.entry_point.clone(),
|
||||
label: fun(&inner.label),
|
||||
num_workgroups: inner.num_workgroups,
|
||||
source: inner.source,
|
||||
})
|
||||
}
|
||||
CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => {
|
||||
CreateShaderModuleDescriptorPassthrough::<'_, K>::Hlsl(ShaderModuleDescriptorHlsl {
|
||||
entry_point: inner.entry_point.clone(),
|
||||
label: fun(&inner.label),
|
||||
num_workgroups: inner.num_workgroups,
|
||||
source: inner.source,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -7798,6 +7818,8 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
|
||||
match self {
|
||||
CreateShaderModuleDescriptorPassthrough::SpirV(inner) => &inner.label,
|
||||
CreateShaderModuleDescriptorPassthrough::Msl(inner) => &inner.label,
|
||||
CreateShaderModuleDescriptorPassthrough::Dxil(inner) => &inner.label,
|
||||
CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => &inner.label,
|
||||
}
|
||||
}
|
||||
|
||||
@ -7809,6 +7831,8 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
|
||||
bytemuck::cast_slice(&inner.source)
|
||||
}
|
||||
CreateShaderModuleDescriptorPassthrough::Msl(inner) => inner.source.as_bytes(),
|
||||
CreateShaderModuleDescriptorPassthrough::Dxil(inner) => inner.source,
|
||||
CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => inner.source.as_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -7818,6 +7842,8 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
|
||||
match self {
|
||||
CreateShaderModuleDescriptorPassthrough::SpirV(..) => "spv",
|
||||
CreateShaderModuleDescriptorPassthrough::Msl(..) => "msl",
|
||||
CreateShaderModuleDescriptorPassthrough::Dxil(..) => "dxil",
|
||||
CreateShaderModuleDescriptorPassthrough::Hlsl(..) => "hlsl",
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -7838,6 +7864,38 @@ pub struct ShaderModuleDescriptorMsl<'a, L> {
|
||||
pub source: Cow<'a, str>,
|
||||
}
|
||||
|
||||
/// Descriptor for a shader module given by DirectX DXIL source.
|
||||
///
|
||||
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
|
||||
/// only WGSL source code strings are accepted.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ShaderModuleDescriptorDxil<'a, L> {
|
||||
/// Entrypoint.
|
||||
pub entry_point: String,
|
||||
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
|
||||
pub label: L,
|
||||
/// Number of workgroups in each dimension x, y and z.
|
||||
pub num_workgroups: (u32, u32, u32),
|
||||
/// Shader DXIL source.
|
||||
pub source: &'a [u8],
|
||||
}
|
||||
|
||||
/// Descriptor for a shader module given by DirectX HLSL source.
|
||||
///
|
||||
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
|
||||
/// only WGSL source code strings are accepted.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ShaderModuleDescriptorHlsl<'a, L> {
|
||||
/// Entrypoint.
|
||||
pub entry_point: String,
|
||||
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
|
||||
pub label: L,
|
||||
/// Number of workgroups in each dimension x, y and z.
|
||||
pub num_workgroups: (u32, u32, u32),
|
||||
/// Shader HLSL source.
|
||||
pub source: &'a str,
|
||||
}
|
||||
|
||||
/// Descriptor for a shader module given by SPIR-V binary.
|
||||
///
|
||||
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
|
||||
|
||||
@ -247,3 +247,15 @@ pub type ShaderModuleDescriptorMsl<'a> = wgt::ShaderModuleDescriptorMsl<'a, Labe
|
||||
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
|
||||
/// only WGSL source code strings are accepted.
|
||||
pub type ShaderModuleDescriptorSpirV<'a> = wgt::ShaderModuleDescriptorSpirV<'a, Label<'a>>;
|
||||
|
||||
/// Descriptor for a shader module given by DirectX HLSL source.
|
||||
///
|
||||
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
|
||||
/// only WGSL source code strings are accepted.
|
||||
pub type ShaderModuleDescriptorHlsl<'a> = wgt::ShaderModuleDescriptorHlsl<'a, Label<'a>>;
|
||||
|
||||
/// Descriptor for a shader module given by DirectX DXIL source.
|
||||
///
|
||||
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
|
||||
/// only WGSL source code strings are accepted.
|
||||
pub type ShaderModuleDescriptorDxil<'a> = wgt::ShaderModuleDescriptorDxil<'a, Label<'a>>;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user