mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-12-08 21:26:17 +00:00
Potentially optimize dot4{I,U}8Packed on Metal (#7653)
* Potentially optimize `dot4{I,U}8Packed` on Metal
This might allow the Metal compiler to emit faster code (but that's not
confirmed). See
<https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226>
for the optimization. The limitation to Metal 2.1+ is discussed here:
<https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
* [naga] Factor out new part of `put_block` on msl
CI on test failed because the latest changes to `put_block` made its
stack too big. Factoring out the new code into a separate method fixes
this issue.
This commit is contained in:
parent
bc0a023cd1
commit
d7e6a0e1fa
@ -65,7 +65,7 @@ Bottom level categories:
|
||||
|
||||
Naga now infers the correct binding layout when a resource appears only in an assignment to `_`. By @andyleiserson in [#7540](https://github.com/gfx-rs/wgpu/pull/7540).
|
||||
|
||||
- Implement `dot4U8Packed` and `dot4I8Packed` for all backends, using specialized intrinsics on SPIR-V and HSLS if available, and polyfills everywhere else. By @robamler in [#7494](https://github.com/gfx-rs/wgpu/pull/7494) and [#7574](https://github.com/gfx-rs/wgpu/pull/7574).
|
||||
- Implement `dot4U8Packed` and `dot4I8Packed` for all backends, using specialized intrinsics on SPIR-V, HSLS, and Metal if available, and polyfills everywhere else. By @robamler in [#7494](https://github.com/gfx-rs/wgpu/pull/7494), [#7574](https://github.com/gfx-rs/wgpu/pull/7574), and [#7653](https://github.com/gfx-rs/wgpu/pull/7653).
|
||||
- Add polyfilled `pack4x{I,U}8Clamped` built-ins to all backends and WGSL frontend. By @ErichDonGubler in [#7546](https://github.com/gfx-rs/wgpu/pull/7546).
|
||||
- Allow textureLoad's sample index arg to be unsigned. By @jimblandy in [#7625](https://github.com/gfx-rs/wgpu/pull/7625).
|
||||
- Properly convert arguments to atomic operations. By @jimblandy in [#7573](https://github.com/gfx-rs/wgpu/pull/7573).
|
||||
|
||||
@ -121,6 +121,9 @@ const fn scalar_is_int(scalar: crate::Scalar) -> bool {
|
||||
/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
|
||||
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
|
||||
|
||||
/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
|
||||
const REINTERPRET_PREFIX: &str = "reinterpreted_";
|
||||
|
||||
/// Wrapper for identifier names for clamped level-of-detail values
|
||||
///
|
||||
/// Values of this type implement [`core::fmt::Display`], formatting as
|
||||
@ -156,6 +159,30 @@ impl Display for ArraySizeMember {
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
|
||||
///
|
||||
/// Implements [`core::fmt::Display`], formatting as a name derived from
|
||||
/// `target_type` and the variable name of `orig`.
|
||||
#[derive(Clone, Copy)]
|
||||
struct Reinterpreted<'a> {
|
||||
target_type: &'a str,
|
||||
orig: Handle<crate::Expression>,
|
||||
}
|
||||
|
||||
impl<'a> Reinterpreted<'a> {
|
||||
const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
|
||||
Self { target_type, orig }
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Reinterpreted<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_str(REINTERPRET_PREFIX)?;
|
||||
f.write_str(self.target_type)?;
|
||||
self.orig.write_prefixed(f, "_e")
|
||||
}
|
||||
}
|
||||
|
||||
struct TypeContext<'a> {
|
||||
handle: Handle<crate::Type>,
|
||||
gctx: proc::GlobalCtx<'a>,
|
||||
@ -1470,14 +1497,14 @@ impl<W: Write> Writer<W> {
|
||||
|
||||
/// Emit code for the arithmetic expression of the dot product.
|
||||
///
|
||||
/// The argument `extractor` is a function that accepts a `Writer`, a handle to a vector,
|
||||
/// and an index. writes out the expression for the component at that index.
|
||||
fn put_dot_product(
|
||||
/// The argument `extractor` is a function that accepts a `Writer`, a vector, and
|
||||
/// an index. It writes out the expression for the vector component at that index.
|
||||
fn put_dot_product<T: Copy>(
|
||||
&mut self,
|
||||
arg: Handle<crate::Expression>,
|
||||
arg1: Handle<crate::Expression>,
|
||||
arg: T,
|
||||
arg1: T,
|
||||
size: usize,
|
||||
extractor: impl Fn(&mut Self, Handle<crate::Expression>, usize) -> BackendResult,
|
||||
extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
|
||||
) -> BackendResult {
|
||||
// Write parentheses around the dot product expression to prevent operators
|
||||
// with different precedences from applying earlier.
|
||||
@ -2206,27 +2233,53 @@ impl<W: Write> Writer<W> {
|
||||
),
|
||||
},
|
||||
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
|
||||
let conversion = match fun {
|
||||
Mf::Dot4I8Packed => "int",
|
||||
Mf::Dot4U8Packed => "",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
if context.lang_version >= (2, 1) {
|
||||
// Write potentially optimizable code using `packed_(u?)char4`.
|
||||
// The two function arguments were already reinterpreted as packed (signed
|
||||
// or unsigned) chars in `Self::put_block`.
|
||||
let packed_type = match fun {
|
||||
Mf::Dot4I8Packed => "packed_char4",
|
||||
Mf::Dot4U8Packed => "packed_uchar4",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
return self.put_dot_product(
|
||||
arg,
|
||||
arg1.unwrap(),
|
||||
4,
|
||||
|writer, arg, index| {
|
||||
write!(writer.out, "({}(", conversion)?;
|
||||
writer.put_expression(arg, context, true)?;
|
||||
if index == 3 {
|
||||
write!(writer.out, ") >> 24)")?;
|
||||
} else {
|
||||
write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
return self.put_dot_product(
|
||||
Reinterpreted::new(packed_type, arg),
|
||||
Reinterpreted::new(packed_type, arg1.unwrap()),
|
||||
4,
|
||||
|writer, arg, index| {
|
||||
// MSL implicitly promotes these (signed or unsigned) chars to
|
||||
// `int` or `uint` in the multiplication, so no overflow can occur.
|
||||
write!(writer.out, "{arg}[{index}]")?;
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
} else {
|
||||
// Fall back to a polyfill since MSL < 2.1 doesn't seem to support
|
||||
// bitcasting from uint to `packed_char4` or `packed_uchar4`.
|
||||
// See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
|
||||
let conversion = match fun {
|
||||
Mf::Dot4I8Packed => "int",
|
||||
Mf::Dot4U8Packed => "",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
return self.put_dot_product(
|
||||
arg,
|
||||
arg1.unwrap(),
|
||||
4,
|
||||
|writer, arg, index| {
|
||||
write!(writer.out, "({}(", conversion)?;
|
||||
writer.put_expression(arg, context, true)?;
|
||||
if index == 3 {
|
||||
write!(writer.out, ") >> 24)")?;
|
||||
} else {
|
||||
write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
|
||||
Mf::Cross => "cross",
|
||||
@ -3346,6 +3399,38 @@ impl<W: Write> Writer<W> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
|
||||
///
|
||||
/// Caches the results in temporary variables (whose names are derived from
|
||||
/// the original variable names). This caching avoids the need to redo the
|
||||
/// casting for each vector component when emitting the dot product.
|
||||
fn put_casting_to_packed_chars(
|
||||
&mut self,
|
||||
fun: crate::MathFunction,
|
||||
arg0: Handle<crate::Expression>,
|
||||
arg1: Handle<crate::Expression>,
|
||||
indent: back::Level,
|
||||
context: &StatementContext<'_>,
|
||||
) -> Result<(), Error> {
|
||||
let packed_type = match fun {
|
||||
crate::MathFunction::Dot4I8Packed => "packed_char4",
|
||||
crate::MathFunction::Dot4U8Packed => "packed_uchar4",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
for arg in [arg0, arg1] {
|
||||
write!(
|
||||
self.out,
|
||||
"{indent}{packed_type} {0} = as_type<{packed_type}>(",
|
||||
Reinterpreted::new(packed_type, arg)
|
||||
)?;
|
||||
self.put_expression(arg, &context.expression, true)?;
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn put_block(
|
||||
&mut self,
|
||||
level: back::Level,
|
||||
@ -3362,17 +3447,45 @@ impl<W: Write> Writer<W> {
|
||||
match *statement {
|
||||
crate::Statement::Emit(ref range) => {
|
||||
for handle in range.clone() {
|
||||
// `ImageLoad` expressions covered by the `Restrict` bounds check policy
|
||||
// may need to cache a clamped version of their level-of-detail argument.
|
||||
if let crate::Expression::ImageLoad {
|
||||
image,
|
||||
level: mip_level,
|
||||
..
|
||||
} = context.expression.function.expressions[handle]
|
||||
{
|
||||
self.put_cache_restricted_level(
|
||||
handle, image, mip_level, level, context,
|
||||
)?;
|
||||
use crate::MathFunction as Mf;
|
||||
|
||||
match context.expression.function.expressions[handle] {
|
||||
// `ImageLoad` expressions covered by the `Restrict` bounds check policy
|
||||
// may need to cache a clamped version of their level-of-detail argument.
|
||||
crate::Expression::ImageLoad {
|
||||
image,
|
||||
level: mip_level,
|
||||
..
|
||||
} => {
|
||||
self.put_cache_restricted_level(
|
||||
handle, image, mip_level, level, context,
|
||||
)?;
|
||||
}
|
||||
|
||||
// If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
|
||||
// 2.1+ then we introduce two intermediate variables that recast the two
|
||||
// arguments as packed (signed or unsigned) chars. The actual dot product
|
||||
// is implemented in `Self::put_expression`, and it uses both of these
|
||||
// intermediate variables multiple times. There's no danger that the
|
||||
// original arguments get modified between the definition of these
|
||||
// intermediate variables and the implementation of the actual dot
|
||||
// product since we require the inputs of `Dot4{I, U}Packed` to be baked.
|
||||
crate::Expression::Math {
|
||||
fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
|
||||
arg,
|
||||
arg1,
|
||||
..
|
||||
} if context.expression.lang_version >= (2, 1) => {
|
||||
self.put_casting_to_packed_chars(
|
||||
fun,
|
||||
arg,
|
||||
arg1.unwrap(),
|
||||
level,
|
||||
context,
|
||||
)?;
|
||||
}
|
||||
|
||||
_ => (),
|
||||
}
|
||||
|
||||
let ptr_class = context.expression.resolve_type(handle).pointer_space();
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V and HLSL by
|
||||
# using a version of SPIR-V / shader model that supports these without any extensions.
|
||||
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V, HLSL, and Metal
|
||||
# by using a language version / shader model that supports these (without any extensions).
|
||||
|
||||
targets = "SPIRV | HLSL"
|
||||
targets = "SPIRV | HLSL | METAL"
|
||||
|
||||
[spv]
|
||||
# We also need to provide the corresponding capabilities (which are part of SPIR-V >= 1.6).
|
||||
@ -10,3 +10,6 @@ version = [1, 6]
|
||||
|
||||
[hlsl]
|
||||
shader_model = "V6_4"
|
||||
|
||||
[msl]
|
||||
lang_version = [2, 1]
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Explicitly turn off optimizations for `dot4I8Packed` and `dot4U8Packed`
|
||||
# on SPIRV and HLSL.
|
||||
# on SPIRV, HLSL, and Metal.
|
||||
|
||||
targets = "SPIRV | HLSL"
|
||||
targets = "SPIRV | HLSL | METAL"
|
||||
|
||||
[spv]
|
||||
# Provide some unrelated capability because an empty list of capabilities would
|
||||
@ -11,3 +11,6 @@ capabilities = ["Matrix"]
|
||||
|
||||
[hlsl]
|
||||
shader_model = "V6_3"
|
||||
|
||||
[msl]
|
||||
lang_version = [2, 0]
|
||||
|
||||
33
naga/tests/out/msl/wgsl-functions-optimized-by-version.msl
Normal file
33
naga/tests/out/msl/wgsl-functions-optimized-by-version.msl
Normal file
@ -0,0 +1,33 @@
|
||||
// language: metal2.1
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using metal::uint;
|
||||
|
||||
|
||||
uint test_packed_integer_dot_product(
|
||||
) {
|
||||
packed_char4 reinterpreted_packed_char4_e0 = as_type<packed_char4>(1u);
|
||||
packed_char4 reinterpreted_packed_char4_e1 = as_type<packed_char4>(2u);
|
||||
int c_5_ = ( + reinterpreted_packed_char4_e0[0] * reinterpreted_packed_char4_e1[0] + reinterpreted_packed_char4_e0[1] * reinterpreted_packed_char4_e1[1] + reinterpreted_packed_char4_e0[2] * reinterpreted_packed_char4_e1[2] + reinterpreted_packed_char4_e0[3] * reinterpreted_packed_char4_e1[3]);
|
||||
packed_uchar4 reinterpreted_packed_uchar4_e3 = as_type<packed_uchar4>(3u);
|
||||
packed_uchar4 reinterpreted_packed_uchar4_e4 = as_type<packed_uchar4>(4u);
|
||||
uint c_6_ = ( + reinterpreted_packed_uchar4_e3[0] * reinterpreted_packed_uchar4_e4[0] + reinterpreted_packed_uchar4_e3[1] * reinterpreted_packed_uchar4_e4[1] + reinterpreted_packed_uchar4_e3[2] * reinterpreted_packed_uchar4_e4[2] + reinterpreted_packed_uchar4_e3[3] * reinterpreted_packed_uchar4_e4[3]);
|
||||
uint _e7 = 5u + c_6_;
|
||||
uint _e9 = 6u + c_6_;
|
||||
packed_char4 reinterpreted_packed_char4_e7 = as_type<packed_char4>(_e7);
|
||||
packed_char4 reinterpreted_packed_char4_e9 = as_type<packed_char4>(_e9);
|
||||
int c_7_ = ( + reinterpreted_packed_char4_e7[0] * reinterpreted_packed_char4_e9[0] + reinterpreted_packed_char4_e7[1] * reinterpreted_packed_char4_e9[1] + reinterpreted_packed_char4_e7[2] * reinterpreted_packed_char4_e9[2] + reinterpreted_packed_char4_e7[3] * reinterpreted_packed_char4_e9[3]);
|
||||
uint _e12 = 7u + c_6_;
|
||||
uint _e14 = 8u + c_6_;
|
||||
packed_uchar4 reinterpreted_packed_uchar4_e12 = as_type<packed_uchar4>(_e12);
|
||||
packed_uchar4 reinterpreted_packed_uchar4_e14 = as_type<packed_uchar4>(_e14);
|
||||
uint c_8_ = ( + reinterpreted_packed_uchar4_e12[0] * reinterpreted_packed_uchar4_e14[0] + reinterpreted_packed_uchar4_e12[1] * reinterpreted_packed_uchar4_e14[1] + reinterpreted_packed_uchar4_e12[2] * reinterpreted_packed_uchar4_e14[2] + reinterpreted_packed_uchar4_e12[3] * reinterpreted_packed_uchar4_e14[3]);
|
||||
return c_8_;
|
||||
}
|
||||
|
||||
kernel void main_(
|
||||
) {
|
||||
uint _e0 = test_packed_integer_dot_product();
|
||||
return;
|
||||
}
|
||||
25
naga/tests/out/msl/wgsl-functions-unoptimized.msl
Normal file
25
naga/tests/out/msl/wgsl-functions-unoptimized.msl
Normal file
@ -0,0 +1,25 @@
|
||||
// language: metal2.0
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using metal::uint;
|
||||
|
||||
|
||||
uint test_packed_integer_dot_product(
|
||||
) {
|
||||
int c_5_ = ( + (int(1u) << 24 >> 24) * (int(2u) << 24 >> 24) + (int(1u) << 16 >> 24) * (int(2u) << 16 >> 24) + (int(1u) << 8 >> 24) * (int(2u) << 8 >> 24) + (int(1u) >> 24) * (int(2u) >> 24));
|
||||
uint c_6_ = ( + ((3u) << 24 >> 24) * ((4u) << 24 >> 24) + ((3u) << 16 >> 24) * ((4u) << 16 >> 24) + ((3u) << 8 >> 24) * ((4u) << 8 >> 24) + ((3u) >> 24) * ((4u) >> 24));
|
||||
uint _e7 = 5u + c_6_;
|
||||
uint _e9 = 6u + c_6_;
|
||||
int c_7_ = ( + (int(_e7) << 24 >> 24) * (int(_e9) << 24 >> 24) + (int(_e7) << 16 >> 24) * (int(_e9) << 16 >> 24) + (int(_e7) << 8 >> 24) * (int(_e9) << 8 >> 24) + (int(_e7) >> 24) * (int(_e9) >> 24));
|
||||
uint _e12 = 7u + c_6_;
|
||||
uint _e14 = 8u + c_6_;
|
||||
uint c_8_ = ( + ((_e12) << 24 >> 24) * ((_e14) << 24 >> 24) + ((_e12) << 16 >> 24) * ((_e14) << 16 >> 24) + ((_e12) << 8 >> 24) * ((_e14) << 8 >> 24) + ((_e12) >> 24) * ((_e14) >> 24));
|
||||
return c_8_;
|
||||
}
|
||||
|
||||
kernel void main_(
|
||||
) {
|
||||
uint _e0 = test_packed_integer_dot_product();
|
||||
return;
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user