mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-12-08 21:26:17 +00:00
[d3d12] add a shader cache to avoid calling into DXC/FXC (#7729)
This commit is contained in:
parent
23b81da5cc
commit
dcada3d858
@ -26,7 +26,7 @@ use crate::{
|
|||||||
},
|
},
|
||||||
dx12::{
|
dx12::{
|
||||||
borrow_optional_interface_temporarily, shader_compilation, suballocation,
|
borrow_optional_interface_temporarily, shader_compilation, suballocation,
|
||||||
DynamicStorageBufferOffsets, Event,
|
DynamicStorageBufferOffsets, Event, ShaderCacheKey, ShaderCacheValue,
|
||||||
},
|
},
|
||||||
AccelerationStructureEntries, TlasInstance,
|
AccelerationStructureEntries, TlasInstance,
|
||||||
};
|
};
|
||||||
@ -203,6 +203,7 @@ impl super::Device {
|
|||||||
null_rtv_handle,
|
null_rtv_handle,
|
||||||
mem_allocator,
|
mem_allocator,
|
||||||
compiler_container,
|
compiler_container,
|
||||||
|
shader_cache: Default::default(),
|
||||||
counters: Default::default(),
|
counters: Default::default(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -304,50 +305,85 @@ impl super::Device {
|
|||||||
};
|
};
|
||||||
|
|
||||||
//TODO: reuse the writer
|
//TODO: reuse the writer
|
||||||
let mut source = String::new();
|
let (source, entry_point) = {
|
||||||
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
|
let mut source = String::new();
|
||||||
let reflection_info = {
|
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
|
||||||
|
|
||||||
profiling::scope!("naga::back::hlsl::write");
|
profiling::scope!("naga::back::hlsl::write");
|
||||||
writer
|
let mut reflection_info = writer
|
||||||
.write(&module, &info, frag_ep.as_ref())
|
.write(&module, &info, frag_ep.as_ref())
|
||||||
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?
|
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?;
|
||||||
|
|
||||||
|
assert_eq!(reflection_info.entry_point_names.len(), 1);
|
||||||
|
|
||||||
|
let entry_point = reflection_info
|
||||||
|
.entry_point_names
|
||||||
|
.pop()
|
||||||
|
.unwrap()
|
||||||
|
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
|
||||||
|
|
||||||
|
(source, entry_point)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
log::info!(
|
||||||
|
"Naga generated shader for {:?} at {:?}:\n{}",
|
||||||
|
entry_point,
|
||||||
|
naga_stage,
|
||||||
|
source
|
||||||
|
);
|
||||||
|
|
||||||
|
let key = ShaderCacheKey {
|
||||||
|
source,
|
||||||
|
entry_point,
|
||||||
|
stage: naga_stage,
|
||||||
|
shader_model: naga_options.shader_model,
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut shader_cache = self.shader_cache.lock();
|
||||||
|
let nr_of_shaders_compiled = shader_cache.nr_of_shaders_compiled;
|
||||||
|
if let Some(value) = shader_cache.entries.get_mut(&key) {
|
||||||
|
value.last_used = nr_of_shaders_compiled;
|
||||||
|
return Ok(value.shader.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let source_name = stage.module.raw_name.as_deref();
|
||||||
|
|
||||||
let full_stage = format!(
|
let full_stage = format!(
|
||||||
"{}_{}",
|
"{}_{}",
|
||||||
naga_stage.to_hlsl_str(),
|
naga_stage.to_hlsl_str(),
|
||||||
naga_options.shader_model.to_str()
|
naga_options.shader_model.to_str()
|
||||||
);
|
);
|
||||||
|
|
||||||
let raw_ep = reflection_info.entry_point_names[0]
|
let compiled_shader = self.compiler_container.compile(
|
||||||
.as_ref()
|
|
||||||
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
|
|
||||||
|
|
||||||
let source_name = stage.module.raw_name.as_deref();
|
|
||||||
|
|
||||||
let result = self.compiler_container.compile(
|
|
||||||
self,
|
self,
|
||||||
&source,
|
&key.source,
|
||||||
source_name,
|
source_name,
|
||||||
raw_ep,
|
&key.entry_point,
|
||||||
stage_bit,
|
stage_bit,
|
||||||
&full_stage,
|
&full_stage,
|
||||||
);
|
)?;
|
||||||
|
|
||||||
let log_level = if result.is_ok() {
|
{
|
||||||
log::Level::Info
|
let mut shader_cache = self.shader_cache.lock();
|
||||||
} else {
|
shader_cache.nr_of_shaders_compiled += 1;
|
||||||
log::Level::Error
|
let nr_of_shaders_compiled = shader_cache.nr_of_shaders_compiled;
|
||||||
};
|
let value = ShaderCacheValue {
|
||||||
|
last_used: nr_of_shaders_compiled,
|
||||||
|
shader: compiled_shader.clone(),
|
||||||
|
};
|
||||||
|
shader_cache.entries.insert(key, value);
|
||||||
|
|
||||||
log::log!(
|
// Retain all entries that have been used since we compiled the last 100 shaders.
|
||||||
log_level,
|
if shader_cache.entries.len() > 200 {
|
||||||
"Naga generated shader for {:?} at {:?}:\n{}",
|
shader_cache
|
||||||
raw_ep,
|
.entries
|
||||||
naga_stage,
|
.retain(|_, v| v.last_used >= nr_of_shaders_compiled - 100);
|
||||||
source
|
}
|
||||||
);
|
}
|
||||||
result
|
|
||||||
|
Ok(compiled_shader)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn raw_device(&self) -> &Direct3D12::ID3D12Device {
|
pub fn raw_device(&self) -> &Direct3D12::ID3D12Device {
|
||||||
@ -1818,11 +1854,6 @@ impl crate::Device for super::Device {
|
|||||||
}
|
}
|
||||||
.map_err(|err| crate::PipelineError::Linkage(shader_stages, err.to_string()))?;
|
.map_err(|err| crate::PipelineError::Linkage(shader_stages, err.to_string()))?;
|
||||||
|
|
||||||
unsafe { blob_vs.destroy() };
|
|
||||||
if let Some(blob_fs) = blob_fs {
|
|
||||||
unsafe { blob_fs.destroy() };
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(label) = desc.label {
|
if let Some(label) = desc.label {
|
||||||
raw.set_name(label)?;
|
raw.set_name(label)?;
|
||||||
}
|
}
|
||||||
@ -1880,8 +1911,6 @@ impl crate::Device for super::Device {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
unsafe { blob_cs.destroy() };
|
|
||||||
|
|
||||||
let raw: Direct3D12::ID3D12PipelineState = pair.map_err(|err| {
|
let raw: Direct3D12::ID3D12PipelineState = pair.map_err(|err| {
|
||||||
crate::PipelineError::Linkage(wgt::ShaderStages::COMPUTE, err.to_string())
|
crate::PipelineError::Linkage(wgt::ShaderStages::COMPUTE, err.to_string())
|
||||||
})?;
|
})?;
|
||||||
|
|||||||
@ -84,10 +84,11 @@ mod suballocation;
|
|||||||
mod types;
|
mod types;
|
||||||
mod view;
|
mod view;
|
||||||
|
|
||||||
use alloc::{borrow::ToOwned as _, sync::Arc, vec::Vec};
|
use alloc::{borrow::ToOwned as _, string::String, sync::Arc, vec::Vec};
|
||||||
use core::{ffi, fmt, mem, num::NonZeroU32, ops::Deref};
|
use core::{ffi, fmt, mem, num::NonZeroU32, ops::Deref};
|
||||||
|
|
||||||
use arrayvec::ArrayVec;
|
use arrayvec::ArrayVec;
|
||||||
|
use hashbrown::HashMap;
|
||||||
use parking_lot::{Mutex, RwLock};
|
use parking_lot::{Mutex, RwLock};
|
||||||
use suballocation::Allocator;
|
use suballocation::Allocator;
|
||||||
use windows::{
|
use windows::{
|
||||||
@ -656,6 +657,7 @@ pub struct Device {
|
|||||||
null_rtv_handle: descriptor::Handle,
|
null_rtv_handle: descriptor::Handle,
|
||||||
mem_allocator: Allocator,
|
mem_allocator: Allocator,
|
||||||
compiler_container: Arc<shader_compilation::CompilerContainer>,
|
compiler_container: Arc<shader_compilation::CompilerContainer>,
|
||||||
|
shader_cache: Mutex<ShaderCache>,
|
||||||
counters: Arc<wgt::HalCounters>,
|
counters: Arc<wgt::HalCounters>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1077,6 +1079,28 @@ pub struct ShaderModule {
|
|||||||
|
|
||||||
impl crate::DynShaderModule for ShaderModule {}
|
impl crate::DynShaderModule for ShaderModule {}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct ShaderCache {
|
||||||
|
nr_of_shaders_compiled: u32,
|
||||||
|
entries: HashMap<ShaderCacheKey, ShaderCacheValue>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Hash)]
|
||||||
|
pub(super) struct ShaderCacheKey {
|
||||||
|
source: String,
|
||||||
|
entry_point: String,
|
||||||
|
stage: naga::ShaderStage,
|
||||||
|
shader_model: naga::back::hlsl::ShaderModel,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) struct ShaderCacheValue {
|
||||||
|
/// This is the value of [`ShaderCache::nr_of_shaders_compiled`]
|
||||||
|
/// at the time the cache entry was last used.
|
||||||
|
last_used: u32,
|
||||||
|
shader: CompiledShader,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub(super) enum CompiledShader {
|
pub(super) enum CompiledShader {
|
||||||
Dxc(Direct3D::Dxc::IDxcBlob),
|
Dxc(Direct3D::Dxc::IDxcBlob),
|
||||||
Fxc(Direct3D::ID3DBlob),
|
Fxc(Direct3D::ID3DBlob),
|
||||||
@ -1095,8 +1119,6 @@ impl CompiledShader {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn destroy(self) {}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user