Allow obtaining custom implementation from wgpu api types (#7541)

This commit is contained in:
sagudev 2025-04-18 22:58:49 +02:00 committed by GitHub
parent a9a3ea3a41
commit 6666d528b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 262 additions and 17 deletions

View File

@ -3,9 +3,9 @@ use std::pin::Pin;
use std::sync::Arc;
use wgpu::custom::{
AdapterInterface, DeviceInterface, DispatchAdapter, DispatchDevice, DispatchQueue,
DispatchShaderModule, DispatchSurface, InstanceInterface, QueueInterface, RequestAdapterFuture,
ShaderModuleInterface,
AdapterInterface, ComputePipelineInterface, DeviceInterface, DispatchAdapter, DispatchDevice,
DispatchQueue, DispatchShaderModule, DispatchSurface, InstanceInterface, QueueInterface,
RequestAdapterFuture, ShaderModuleInterface,
};
#[derive(Debug, Clone)]
@ -163,9 +163,10 @@ impl DeviceInterface for CustomDevice {
fn create_compute_pipeline(
&self,
_desc: &wgpu::ComputePipelineDescriptor<'_>,
desc: &wgpu::ComputePipelineDescriptor<'_>,
) -> wgpu::custom::DispatchComputePipeline {
unimplemented!()
let module = desc.module.as_custom::<CustomShaderModule>().unwrap();
wgpu::custom::DispatchComputePipeline::custom(CustomComputePipeline(module.0.clone()))
}
unsafe fn create_pipeline_cache(
@ -265,7 +266,7 @@ impl DeviceInterface for CustomDevice {
}
#[derive(Debug)]
struct CustomShaderModule(Counter);
pub struct CustomShaderModule(pub Counter);
impl ShaderModuleInterface for CustomShaderModule {
fn get_compilation_info(&self) -> Pin<Box<dyn wgpu::custom::ShaderCompilationInfoFuture>> {
@ -346,3 +347,12 @@ impl QueueInterface for CustomQueue {
unimplemented!()
}
}
#[derive(Debug)]
pub struct CustomComputePipeline(pub Counter);
impl ComputePipelineInterface for CustomComputePipeline {
fn get_bind_group_layout(&self, _index: u32) -> wgpu::custom::DispatchBindGroupLayout {
unimplemented!()
}
}

View File

@ -1,6 +1,6 @@
use std::marker::PhantomData;
use custom::Counter;
use custom::{Counter, CustomShaderModule};
use wgpu::{DeviceDescriptor, RequestAdapterOptions};
mod custom;
@ -31,12 +31,26 @@ async fn main() {
.unwrap();
assert_eq!(counter.count(), 5);
let _module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("shader"),
source: wgpu::ShaderSource::Dummy(PhantomData),
});
let custom_module = module.as_custom::<CustomShaderModule>().unwrap();
assert_eq!(custom_module.0.count(), 6);
let _module_clone = module.clone();
assert_eq!(counter.count(), 6);
let _pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module: &module,
entry_point: None,
compilation_options: Default::default(),
cache: None,
});
assert_eq!(counter.count(), 7);
}
assert_eq!(counter.count(), 1);
}

View File

@ -133,6 +133,12 @@ impl Adapter {
}
}
#[cfg(custom)]
/// Returns custom implementation of adapter (if custom backend and is internally T)
pub fn as_custom<T: custom::AdapterInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
#[cfg(custom)]
/// Creates Adapter from custom implementation
pub fn from_custom<T: custom::AdapterInterface>(adapter: T) -> Self {

View File

@ -17,6 +17,14 @@ static_assertions::assert_impl_all!(BindGroup: Send, Sync);
crate::cmp::impl_eq_ord_hash_proxy!(BindGroup => .inner);
impl BindGroup {
#[cfg(custom)]
/// Returns custom implementation of BindGroup (if custom backend and is internally T)
pub fn as_custom<T: custom::BindGroupInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// Resource to be bound by a [`BindGroup`] for use with a pipeline.
///
/// The pipelines [`BindGroupLayout`] must contain a matching [`BindingType`].

View File

@ -20,6 +20,14 @@ static_assertions::assert_impl_all!(BindGroupLayout: Send, Sync);
crate::cmp::impl_eq_ord_hash_proxy!(BindGroupLayout => .inner);
impl BindGroupLayout {
#[cfg(custom)]
/// Returns custom implementation of BindGroupLayout (if custom backend and is internally T)
pub fn as_custom<T: custom::BindGroupLayoutInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// Describes a [`BindGroupLayout`].
///
/// For use with [`Device::create_bind_group_layout`].

View File

@ -174,6 +174,12 @@ impl Blas {
hal_blas_callback(None)
}
}
#[cfg(custom)]
/// Returns custom implementation of Blas (if custom backend and is internally T)
pub fn as_custom<T: crate::custom::BlasInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// Context version of [BlasTriangleGeometry].

View File

@ -386,6 +386,12 @@ impl Buffer {
) -> BufferViewMut<'_> {
self.slice(bounds).get_mapped_range_mut()
}
#[cfg(custom)]
/// Returns custom implementation of Buffer (if custom backend and is internally T)
pub fn as_custom<T: custom::BufferInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// A slice of a [`Buffer`], to be mapped, used for vertex or index data, or the like.

View File

@ -13,3 +13,11 @@ pub struct CommandBuffer {
}
#[cfg(send_sync)]
static_assertions::assert_impl_all!(CommandBuffer: Send, Sync);
impl CommandBuffer {
#[cfg(custom)]
/// Returns custom implementation of CommandBuffer (if custom backend and is internally T)
pub fn as_custom<T: custom::CommandBufferInterface>(&self) -> Option<&T> {
self.buffer.as_custom()
}
}

View File

@ -262,6 +262,12 @@ impl CommandEncoder {
hal_command_encoder_callback(None)
}
}
#[cfg(custom)]
/// Returns custom implementation of CommandEncoder (if custom backend and is internally T)
pub fn as_custom<T: custom::CommandEncoderInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// [`Features::TIMESTAMP_QUERY_INSIDE_ENCODERS`] must be enabled on the device in order to call these functions.

View File

@ -94,6 +94,12 @@ impl ComputePass<'_> {
self.inner
.dispatch_workgroups_indirect(&indirect_buffer.inner, indirect_offset);
}
#[cfg(custom)]
/// Returns custom implementation of ComputePass (if custom backend and is internally T)
pub fn as_custom<T: custom::ComputePassInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// [`Features::PUSH_CONSTANTS`] must be enabled on the device in order to call these functions.

View File

@ -27,6 +27,12 @@ impl ComputePipeline {
let bind_group = self.inner.get_bind_group_layout(index);
BindGroupLayout { inner: bind_group }
}
#[cfg(custom)]
/// Returns custom implementation of ComputePipeline (if custom backend and is internally T)
pub fn as_custom<T: custom::ComputePipelineInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// Describes a compute pipeline.

View File

@ -34,6 +34,12 @@ pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>;
static_assertions::assert_impl_all!(DeviceDescriptor<'_>: Send, Sync);
impl Device {
#[cfg(custom)]
/// Returns custom implementation of Device (if custom backend and is internally T)
pub fn as_custom<T: custom::DeviceInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
#[cfg(custom)]
/// Creates Device from custom implementation
pub fn from_custom<T: custom::DeviceInterface>(device: T) -> Self {

View File

@ -208,6 +208,12 @@ impl Instance {
}
}
#[cfg(custom)]
/// Returns custom implementation of Instance (if custom backend and is internally T)
pub fn as_custom<T: custom::InstanceInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
/// Retrieves all available [`Adapter`]s that match the given [`Backends`].
///
/// # Arguments

View File

@ -87,4 +87,10 @@ impl PipelineCache {
pub fn get_data(&self) -> Option<Vec<u8>> {
self.inner.get_data()
}
#[cfg(custom)]
/// Returns custom implementation of PipelineCache (if custom backend and is internally T)
pub fn as_custom<T: custom::PipelineCacheInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}

View File

@ -15,6 +15,14 @@ static_assertions::assert_impl_all!(PipelineLayout: Send, Sync);
crate::cmp::impl_eq_ord_hash_proxy!(PipelineLayout => .inner);
impl PipelineLayout {
#[cfg(custom)]
/// Returns custom implementation of PipelineLayout (if custom backend and is internally T)
pub fn as_custom<T: custom::PipelineLayoutInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// Describes a [`PipelineLayout`].
///
/// For use with [`Device::create_pipeline_layout`].

View File

@ -15,6 +15,14 @@ static_assertions::assert_impl_all!(QuerySet: Send, Sync);
crate::cmp::impl_eq_ord_hash_proxy!(QuerySet => .inner);
impl QuerySet {
#[cfg(custom)]
/// Returns custom implementation of QuerySet (if custom backend and is internally T)
pub fn as_custom<T: custom::QuerySetInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// Describes a [`QuerySet`].
///
/// For use with [`Device::create_query_set`].

View File

@ -19,6 +19,22 @@ static_assertions::assert_impl_all!(Queue: Send, Sync);
crate::cmp::impl_eq_ord_hash_proxy!(Queue => .inner);
impl Queue {
#[cfg(custom)]
/// Returns custom implementation of Queue (if custom backend and is internally T)
pub fn as_custom<T: custom::QueueInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
#[cfg(custom)]
/// Creates Queue from custom implementation
pub fn from_custom<T: custom::QueueInterface>(queue: T) -> Self {
Self {
inner: dispatch::DispatchQueue::custom(queue),
}
}
}
/// Identifier for a particular call to [`Queue::submit`]. Can be used
/// as part of an argument to [`Device::poll`] to block for a particular
/// submission to finish.
@ -51,6 +67,14 @@ pub struct QueueWriteBufferView<'a> {
#[cfg(send_sync)]
static_assertions::assert_impl_all!(QueueWriteBufferView<'_>: Send, Sync);
impl QueueWriteBufferView<'_> {
#[cfg(custom)]
/// Returns custom implementation of QueueWriteBufferView (if custom backend and is internally T)
pub fn as_custom<T: custom::QueueWriteBufferInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
impl Deref for QueueWriteBufferView<'_> {
type Target = [u8];
@ -81,14 +105,6 @@ impl Drop for QueueWriteBufferView<'_> {
}
impl Queue {
#[cfg(custom)]
/// Creates Queue from custom implementation
pub fn from_custom<T: custom::QueueInterface>(queue: T) -> Self {
Self {
inner: dispatch::DispatchQueue::custom(queue),
}
}
/// Copies the bytes of `data` into `buffer` starting at `offset`.
///
/// The data must be written fully in-bounds, that is, `offset + data.len() <= buffer.len()`.

View File

@ -18,6 +18,14 @@ static_assertions::assert_impl_all!(RenderBundle: Send, Sync);
crate::cmp::impl_eq_ord_hash_proxy!(RenderBundle => .inner);
impl RenderBundle {
#[cfg(custom)]
/// Returns custom implementation of RenderBundle (if custom backend and is internally T)
pub fn as_custom<T: custom::RenderBundleInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// Describes a [`RenderBundle`].
///
/// For use with [`RenderBundleEncoder::finish`].

View File

@ -188,6 +188,12 @@ impl<'a> RenderBundleEncoder<'a> {
self.inner
.draw_indexed_indirect(&indirect_buffer.inner, indirect_offset);
}
#[cfg(custom)]
/// Returns custom implementation of RenderBundleEncoder (if custom backend and is internally T)
pub fn as_custom<T: custom::RenderBundleEncoderInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// [`Features::PUSH_CONSTANTS`] must be enabled on the device in order to call these functions.

View File

@ -306,6 +306,12 @@ impl RenderPass<'_> {
self.inner
.multi_draw_indexed_indirect(&indirect_buffer.inner, indirect_offset, count);
}
#[cfg(custom)]
/// Returns custom implementation of RenderPass (if custom backend and is internally T)
pub fn as_custom<T: custom::RenderPassInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// [`Features::MULTI_DRAW_INDIRECT_COUNT`] must be enabled on the device in order to call these functions.

View File

@ -28,6 +28,12 @@ impl RenderPipeline {
let layout = self.inner.get_bind_group_layout(index);
BindGroupLayout { inner: layout }
}
#[cfg(custom)]
/// Returns custom implementation of RenderPipeline (if custom backend and is internally T)
pub fn as_custom<T: custom::RenderPipelineInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// Specifies an interpretation of the bytes of a vertex buffer as vertex attributes.

View File

@ -18,6 +18,14 @@ static_assertions::assert_impl_all!(Sampler: Send, Sync);
crate::cmp::impl_eq_ord_hash_proxy!(Sampler => .inner);
impl Sampler {
#[cfg(custom)]
/// Returns custom implementation of Sampler (if custom backend and is internally T)
pub fn as_custom<T: custom::SamplerInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// Describes a [`Sampler`].
///
/// For use with [`Device::create_sampler`].

View File

@ -29,6 +29,12 @@ impl ShaderModule {
pub fn get_compilation_info(&self) -> impl Future<Output = CompilationInfo> + WasmNotSend {
self.inner.get_compilation_info()
}
#[cfg(custom)]
/// Returns custom implementation of ShaderModule (if custom backend and is internally T)
pub fn as_custom<T: custom::ShaderModuleInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// Compilation information for a shader module.

View File

@ -171,6 +171,12 @@ impl Surface<'_> {
hal_surface_callback(None)
}
}
#[cfg(custom)]
/// Returns custom implementation of Surface (if custom backend and is internally T)
pub fn as_custom<T: custom::SurfaceInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
// This custom implementation is required because [`Surface::_surface`] doesn't

View File

@ -38,6 +38,12 @@ impl SurfaceTexture {
self.presented = true;
self.detail.present();
}
#[cfg(custom)]
/// Returns custom implementation of SurfaceTexture (if custom backend and is internally T)
pub fn as_custom<T: crate::custom::SurfaceOutputDetailInterface>(&self) -> Option<&T> {
self.detail.as_custom()
}
}
impl Drop for SurfaceTexture {

View File

@ -37,6 +37,12 @@ impl Texture {
}
}
#[cfg(custom)]
/// Returns custom implementation of Texture (if custom backend and is internally T)
pub fn as_custom<T: custom::TextureInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
/// Creates a view of this texture, specifying an interpretation of its texels and
/// possibly a subset of its layers and mip levels.
///

View File

@ -40,6 +40,12 @@ impl TextureView {
hal_texture_view_callback(None)
}
}
#[cfg(custom)]
/// Returns custom implementation of TextureView (if custom backend and is internally T)
pub fn as_custom<T: custom::TextureViewInterface>(&self) -> Option<&T> {
self.inner.as_custom()
}
}
/// Describes a [`TextureView`].

View File

@ -57,6 +57,12 @@ impl Tlas {
hal_tlas_callback(None)
}
}
#[cfg(custom)]
/// Returns custom implementation of Tlas (if custom backend and is internally T)
pub fn as_custom<T: crate::custom::TlasInterface>(&self) -> Option<&T> {
self.shared.inner.as_custom()
}
}
/// Entry for a top level acceleration structure build.

View File

@ -18,6 +18,11 @@ macro_rules! dyn_type {
pub(crate) fn new<T: $interface>(t: T) -> Self {
Self(Arc::new(t))
}
#[allow(clippy::allow_attributes, dead_code)]
pub(crate) fn downcast<T: $interface>(&self) -> Option<&T> {
self.0.as_ref().as_any().downcast_ref()
}
}
impl core::ops::Deref for $name {
@ -46,6 +51,10 @@ macro_rules! dyn_type {
pub(crate) fn new<T: $interface>(t: T) -> Self {
Self(Arc::new(t))
}
pub(crate) fn downcast<T: $interface>(&self) -> Option<&T> {
self.0.as_ref().as_any().downcast_ref()
}
}
impl core::ops::Deref for $name {

View File

@ -54,8 +54,20 @@ pub type BufferMapCallback = Box<dyn FnOnce(Result<(), crate::BufferAsyncError>)
#[cfg(not(send_sync))]
pub type BufferMapCallback = Box<dyn FnOnce(Result<(), crate::BufferAsyncError>) + 'static>;
// remove when rust 1.86
#[cfg_attr(not(custom), expect(dead_code))]
pub trait AsAny {
fn as_any(&self) -> &dyn Any;
}
impl<T: 'static> AsAny for T {
fn as_any(&self) -> &dyn Any {
self
}
}
// Common traits on all the interface traits
trait_alias!(CommonTraits: Any + Debug + WasmNotSendSync);
trait_alias!(CommonTraits: AsAny + Any + Debug + WasmNotSendSync);
pub trait InstanceInterface: CommonTraits {
fn new(desc: &crate::InstanceDescriptor) -> Self
@ -575,6 +587,16 @@ macro_rules! dispatch_types {
}
}
#[cfg(custom)]
#[inline]
#[allow(clippy::allow_attributes, unused)]
pub fn as_custom<T: $interface>(&self) -> Option<&T> {
match self {
Self::Custom(value) => value.downcast(),
_ => None,
}
}
#[cfg(webgpu)]
#[inline]
#[allow(clippy::allow_attributes, unused)]
@ -693,6 +715,16 @@ macro_rules! dispatch_types {
}
}
#[cfg(custom)]
#[inline]
#[allow(clippy::allow_attributes, unused)]
pub fn as_custom<T: $interface>(&self) -> Option<&T> {
match self {
Self::Custom(value) => value.downcast(),
_ => None,
}
}
#[cfg(webgpu)]
#[inline]
#[allow(clippy::allow_attributes, unused)]