[d3d12] add a shader cache to avoid calling into DXC/FXC (#7729)

This commit is contained in:
Teodor Tanasoaia 2025-06-05 15:20:51 +02:00 committed by GitHub
parent 23b81da5cc
commit dcada3d858
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 90 additions and 39 deletions

View File

@ -26,7 +26,7 @@ use crate::{
}, },
dx12::{ dx12::{
borrow_optional_interface_temporarily, shader_compilation, suballocation, borrow_optional_interface_temporarily, shader_compilation, suballocation,
DynamicStorageBufferOffsets, Event, DynamicStorageBufferOffsets, Event, ShaderCacheKey, ShaderCacheValue,
}, },
AccelerationStructureEntries, TlasInstance, AccelerationStructureEntries, TlasInstance,
}; };
@ -203,6 +203,7 @@ impl super::Device {
null_rtv_handle, null_rtv_handle,
mem_allocator, mem_allocator,
compiler_container, compiler_container,
shader_cache: Default::default(),
counters: Default::default(), counters: Default::default(),
}) })
} }
@ -304,50 +305,85 @@ impl super::Device {
}; };
//TODO: reuse the writer //TODO: reuse the writer
let (source, entry_point) = {
let mut source = String::new(); let mut source = String::new();
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options); let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
let reflection_info = {
profiling::scope!("naga::back::hlsl::write"); profiling::scope!("naga::back::hlsl::write");
writer let mut reflection_info = writer
.write(&module, &info, frag_ep.as_ref()) .write(&module, &info, frag_ep.as_ref())
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))? .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?;
assert_eq!(reflection_info.entry_point_names.len(), 1);
let entry_point = reflection_info
.entry_point_names
.pop()
.unwrap()
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
(source, entry_point)
}; };
log::info!(
"Naga generated shader for {:?} at {:?}:\n{}",
entry_point,
naga_stage,
source
);
let key = ShaderCacheKey {
source,
entry_point,
stage: naga_stage,
shader_model: naga_options.shader_model,
};
{
let mut shader_cache = self.shader_cache.lock();
let nr_of_shaders_compiled = shader_cache.nr_of_shaders_compiled;
if let Some(value) = shader_cache.entries.get_mut(&key) {
value.last_used = nr_of_shaders_compiled;
return Ok(value.shader.clone());
}
}
let source_name = stage.module.raw_name.as_deref();
let full_stage = format!( let full_stage = format!(
"{}_{}", "{}_{}",
naga_stage.to_hlsl_str(), naga_stage.to_hlsl_str(),
naga_options.shader_model.to_str() naga_options.shader_model.to_str()
); );
let raw_ep = reflection_info.entry_point_names[0] let compiled_shader = self.compiler_container.compile(
.as_ref()
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
let source_name = stage.module.raw_name.as_deref();
let result = self.compiler_container.compile(
self, self,
&source, &key.source,
source_name, source_name,
raw_ep, &key.entry_point,
stage_bit, stage_bit,
&full_stage, &full_stage,
); )?;
let log_level = if result.is_ok() { {
log::Level::Info let mut shader_cache = self.shader_cache.lock();
} else { shader_cache.nr_of_shaders_compiled += 1;
log::Level::Error let nr_of_shaders_compiled = shader_cache.nr_of_shaders_compiled;
let value = ShaderCacheValue {
last_used: nr_of_shaders_compiled,
shader: compiled_shader.clone(),
}; };
shader_cache.entries.insert(key, value);
log::log!( // Retain all entries that have been used since we compiled the last 100 shaders.
log_level, if shader_cache.entries.len() > 200 {
"Naga generated shader for {:?} at {:?}:\n{}", shader_cache
raw_ep, .entries
naga_stage, .retain(|_, v| v.last_used >= nr_of_shaders_compiled - 100);
source }
); }
result
Ok(compiled_shader)
} }
pub fn raw_device(&self) -> &Direct3D12::ID3D12Device { pub fn raw_device(&self) -> &Direct3D12::ID3D12Device {
@ -1818,11 +1854,6 @@ impl crate::Device for super::Device {
} }
.map_err(|err| crate::PipelineError::Linkage(shader_stages, err.to_string()))?; .map_err(|err| crate::PipelineError::Linkage(shader_stages, err.to_string()))?;
unsafe { blob_vs.destroy() };
if let Some(blob_fs) = blob_fs {
unsafe { blob_fs.destroy() };
};
if let Some(label) = desc.label { if let Some(label) = desc.label {
raw.set_name(label)?; raw.set_name(label)?;
} }
@ -1880,8 +1911,6 @@ impl crate::Device for super::Device {
} }
}; };
unsafe { blob_cs.destroy() };
let raw: Direct3D12::ID3D12PipelineState = pair.map_err(|err| { let raw: Direct3D12::ID3D12PipelineState = pair.map_err(|err| {
crate::PipelineError::Linkage(wgt::ShaderStages::COMPUTE, err.to_string()) crate::PipelineError::Linkage(wgt::ShaderStages::COMPUTE, err.to_string())
})?; })?;

View File

@ -84,10 +84,11 @@ mod suballocation;
mod types; mod types;
mod view; mod view;
use alloc::{borrow::ToOwned as _, sync::Arc, vec::Vec}; use alloc::{borrow::ToOwned as _, string::String, sync::Arc, vec::Vec};
use core::{ffi, fmt, mem, num::NonZeroU32, ops::Deref}; use core::{ffi, fmt, mem, num::NonZeroU32, ops::Deref};
use arrayvec::ArrayVec; use arrayvec::ArrayVec;
use hashbrown::HashMap;
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use suballocation::Allocator; use suballocation::Allocator;
use windows::{ use windows::{
@ -656,6 +657,7 @@ pub struct Device {
null_rtv_handle: descriptor::Handle, null_rtv_handle: descriptor::Handle,
mem_allocator: Allocator, mem_allocator: Allocator,
compiler_container: Arc<shader_compilation::CompilerContainer>, compiler_container: Arc<shader_compilation::CompilerContainer>,
shader_cache: Mutex<ShaderCache>,
counters: Arc<wgt::HalCounters>, counters: Arc<wgt::HalCounters>,
} }
@ -1077,6 +1079,28 @@ pub struct ShaderModule {
impl crate::DynShaderModule for ShaderModule {} impl crate::DynShaderModule for ShaderModule {}
#[derive(Default)]
pub struct ShaderCache {
nr_of_shaders_compiled: u32,
entries: HashMap<ShaderCacheKey, ShaderCacheValue>,
}
#[derive(PartialEq, Eq, Hash)]
pub(super) struct ShaderCacheKey {
source: String,
entry_point: String,
stage: naga::ShaderStage,
shader_model: naga::back::hlsl::ShaderModel,
}
pub(super) struct ShaderCacheValue {
/// This is the value of [`ShaderCache::nr_of_shaders_compiled`]
/// at the time the cache entry was last used.
last_used: u32,
shader: CompiledShader,
}
#[derive(Clone)]
pub(super) enum CompiledShader { pub(super) enum CompiledShader {
Dxc(Direct3D::Dxc::IDxcBlob), Dxc(Direct3D::Dxc::IDxcBlob),
Fxc(Direct3D::ID3DBlob), Fxc(Direct3D::ID3DBlob),
@ -1095,8 +1119,6 @@ impl CompiledShader {
}, },
} }
} }
unsafe fn destroy(self) {}
} }
#[derive(Debug)] #[derive(Debug)]