Use intrinsics for dot4{I, U}8Packed in HLSL

This commit is contained in:
Robert Bamler 2025-04-18 20:14:14 +02:00 committed by Teodor Tanasoaia
parent 892f629025
commit fe05765602
9 changed files with 152 additions and 25 deletions

View File

@ -12,7 +12,7 @@ use super::{
WrappedZeroValue,
},
storage::StoreValue,
BackendResult, Error, FragmentEntryPoint, Options,
BackendResult, Error, FragmentEntryPoint, Options, ShaderModel,
};
use crate::{
back::{self, Baked},
@ -3751,33 +3751,48 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
let arg1 = arg1.unwrap();
write!(self.out, "dot(")?;
if self.options.shader_model >= ShaderModel::V6_4 {
// Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
let function_name = match fun {
Function::Dot4I8Packed => "dot4add_i8packed",
Function::Dot4U8Packed => "dot4add_u8packed",
_ => unreachable!(),
};
write!(self.out, "{function_name}(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, ", 0)")?;
} else {
// Fall back to a polyfill as `dot4add_u8packed` is not available.
write!(self.out, "dot(")?;
if matches!(fun, Function::Dot4U8Packed) {
write!(self.out, "u")?;
}
write!(self.out, "int4(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 8, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 16, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24, ")?;
if matches!(fun, Function::Dot4U8Packed) {
write!(self.out, "u")?;
}
write!(self.out, "int4(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 8, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 16, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24, ")?;
if matches!(fun, Function::Dot4U8Packed) {
write!(self.out, "u")?;
if matches!(fun, Function::Dot4U8Packed) {
write!(self.out, "u")?;
}
write!(self.out, "int4(")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 8, ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 16, ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24)")?;
}
write!(self.out, "int4(")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 8, ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 16, ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24)")?;
}
Function::QuantizeToF16 => {
write!(self.out, "f16tof32(f32tof16(")?;

View File

@ -0,0 +1,6 @@
# Explicitly turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on HLSL.
targets = "HLSL"
[hlsl]
shader_model = "V6_4"

View File

@ -0,0 +1,19 @@
fn test_packed_integer_dot_product() -> u32 {
let a_5 = 1u;
let b_5 = 2u;
let c_5: i32 = dot4I8Packed(a_5, b_5);
let a_6 = 3u;
let b_6 = 4u;
let c_6: u32 = dot4U8Packed(a_6, b_6);
// test baking of arguments
let c_7: i32 = dot4I8Packed(5u + c_6, 6u + c_6);
let c_8: u32 = dot4U8Packed(7u + c_6, 8u + c_6);
return c_8;
}
@compute @workgroup_size(1)
fn main() {
let c = test_packed_integer_dot_product();
}

View File

@ -0,0 +1,6 @@
# Explicitly turn off optimizations for `dot4I8Packed` and `dot4U8Packed` on HLSL.
targets = "HLSL"
[hlsl]
shader_model = "V6_3"

View File

@ -0,0 +1,19 @@
fn test_packed_integer_dot_product() -> u32 {
let a_5 = 1u;
let b_5 = 2u;
let c_5: i32 = dot4I8Packed(a_5, b_5);
let a_6 = 3u;
let b_6 = 4u;
let c_6: u32 = dot4U8Packed(a_6, b_6);
// test baking of arguments
let c_7: i32 = dot4I8Packed(5u + c_6, 6u + c_6);
let c_8: u32 = dot4U8Packed(7u + c_6, 8u + c_6);
return c_8;
}
@compute @workgroup_size(1)
fn main() {
let c = test_packed_integer_dot_product();
}

View File

@ -0,0 +1,19 @@
uint test_packed_integer_dot_product()
{
int c_5_ = dot4add_i8packed(1u, 2u, 0);
uint c_6_ = dot4add_u8packed(3u, 4u, 0);
uint _e7 = (5u + c_6_);
uint _e9 = (6u + c_6_);
int c_7_ = dot4add_i8packed(_e7, _e9, 0);
uint _e12 = (7u + c_6_);
uint _e14 = (8u + c_6_);
uint c_8_ = dot4add_u8packed(_e12, _e14, 0);
return c_8_;
}
[numthreads(1, 1, 1)]
void main()
{
const uint _e0 = test_packed_integer_dot_product();
return;
}

View File

@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_6_4",
),
],
)

View File

@ -0,0 +1,19 @@
uint test_packed_integer_dot_product()
{
int c_5_ = dot(int4(1u, 1u >> 8, 1u >> 16, 1u >> 24) << 24 >> 24, int4(2u, 2u >> 8, 2u >> 16, 2u >> 24) << 24 >> 24);
uint c_6_ = dot(uint4(3u, 3u >> 8, 3u >> 16, 3u >> 24) << 24 >> 24, uint4(4u, 4u >> 8, 4u >> 16, 4u >> 24) << 24 >> 24);
uint _e7 = (5u + c_6_);
uint _e9 = (6u + c_6_);
int c_7_ = dot(int4(_e7, _e7 >> 8, _e7 >> 16, _e7 >> 24) << 24 >> 24, int4(_e9, _e9 >> 8, _e9 >> 16, _e9 >> 24) << 24 >> 24);
uint _e12 = (7u + c_6_);
uint _e14 = (8u + c_6_);
uint c_8_ = dot(uint4(_e12, _e12 >> 8, _e12 >> 16, _e12 >> 24) << 24 >> 24, uint4(_e14, _e14 >> 8, _e14 >> 16, _e14 >> 24) << 24 >> 24);
return c_8_;
}
[numthreads(1, 1, 1)]
void main()
{
const uint _e0 = test_packed_integer_dot_product();
return;
}

View File

@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_6_3",
),
],
)