Implement dot4I8Packed and dot4U8Packed

Closes #7481.

This implementation roughly follows approach 2 outlined in #7481, i.e.,
it adds a polyfill for the signed and unsigned dot product of packed
vectors for each platform. It doesn't use the specialized instructions
that are available for this operation on SPIR-V (with capability
DotProductInput4x8BitPacked).
This commit is contained in:
Robert Bamler 2025-04-06 22:15:52 +02:00 committed by Teodor Tanasoaia
parent 65c56fdee4
commit c7d0af156d
18 changed files with 421 additions and 74 deletions

View File

@ -40,6 +40,12 @@ Bottom level categories:
## Unreleased
### New Features
#### Naga
- Add polyfills for `dot4U8Packed` and `dot4I8Packed` for all backends. By @robamler in [#7494](https://github.com/gfx-rs/wgpu/pull/7494).
### Changes
#### General

View File

@ -1386,6 +1386,10 @@ impl<'a, W: Write> Writer<'a, W> {
self.need_bake_expressions.insert(arg1.unwrap());
}
}
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
}
crate::MathFunction::Pack4xI8
| crate::MathFunction::Pack4xU8
| crate::MathFunction::Unpack4xI8
@ -3558,6 +3562,40 @@ impl<'a, W: Write> Writer<'a, W> {
"Correct TypeInner for dot product should be already validated"
),
},
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
let conversion = match fun {
Mf::Dot4I8Packed => "int",
Mf::Dot4U8Packed => "",
_ => unreachable!(),
};
let arg1 = arg1.unwrap();
// Write parentheses around the dot product expression to prevent operators
// with different precedences from applying earlier.
write!(self.out, "(")?;
for i in 0..4 {
// Since `bitfieldExtract` only sign extends if the value is signed, we
// need to convert the inputs to `int` in case of `Dot4I8Packed`. For
// `Dot4U8Packed`, the code below only introduces parentheses around
// each factor, which aren't strictly needed because both operands are
// baked, but which don't hurt either.
write!(self.out, "bitfieldExtract({}(", conversion)?;
self.write_expr(arg, ctx)?;
write!(self.out, "), {}, 8)", i * 8)?;
write!(self.out, " * bitfieldExtract({}(", conversion)?;
self.write_expr(arg1, ctx)?;
write!(self.out, "), {}, 8)", i * 8)?;
if i != 3 {
write!(self.out, " + ")?;
}
}
write!(self.out, ")")?;
return Ok(());
}
Mf::Outer => "outerProduct",
Mf::Cross => "cross",
Mf::Distance => "distance",

View File

@ -206,7 +206,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.need_bake_expressions.insert(exp_handle);
}
if let Expression::Math { fun, arg, .. } = *expr {
if let Expression::Math { fun, arg, arg1, .. } = *expr {
match fun {
crate::MathFunction::Asinh
| crate::MathFunction::Acosh
@ -233,6 +233,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.need_bake_expressions.insert(arg);
}
}
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
}
_ => {}
}
}
@ -3433,6 +3437,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Unpack4x8unorm,
Unpack4xI8,
Unpack4xU8,
Dot4I8Packed,
Dot4U8Packed,
QuantizeToF16,
Regular(&'static str),
MissingIntOverload(&'static str),
@ -3484,6 +3490,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Pow => Function::Regular("pow"),
// geometry
Mf::Dot => Function::Regular("dot"),
Mf::Dot4I8Packed => Function::Dot4I8Packed,
Mf::Dot4U8Packed => Function::Dot4U8Packed,
//Mf::Outer => ,
Mf::Cross => Function::Regular("cross"),
Mf::Distance => Function::Regular("distance"),
@ -3706,6 +3714,37 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24)")?;
}
fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
let arg1 = arg1.unwrap();
write!(self.out, "dot(")?;
if matches!(fun, Function::Dot4U8Packed) {
write!(self.out, "u")?;
}
write!(self.out, "int4(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 8, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 16, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24, ")?;
if matches!(fun, Function::Dot4U8Packed) {
write!(self.out, "u")?;
}
write!(self.out, "int4(")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 8, ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 16, ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24)")?;
}
Function::QuantizeToF16 => {
write!(self.out, "f16tof32(f32tof16(")?;
self.write_expr(module, arg, func_ctx)?;

View File

@ -1470,12 +1470,14 @@ impl<W: Write> Writer<W> {
/// Emit code for the arithmetic expression of the dot product.
///
/// The argument `extractor` is a function that accepts a `Writer`, a handle to a vector,
/// and an index, and writes out the expression for the component at that index.
fn put_dot_product(
&mut self,
arg: Handle<crate::Expression>,
arg1: Handle<crate::Expression>,
size: usize,
context: &ExpressionContext,
extractor: impl Fn(&mut Self, Handle<crate::Expression>, usize) -> BackendResult,
) -> BackendResult {
// Write parentheses around the dot product expression to prevent operators
// with different precedences from applying earlier.
@ -1483,22 +1485,12 @@ impl<W: Write> Writer<W> {
// Cycle through all the components of the vector
for index in 0..size {
let component = back::COMPONENTS[index];
// Write the addition to the previous product
// This will print an extra '+' at the beginning but that is fine in msl
write!(self.out, " + ")?;
// Write the first vector expression, this expression is marked to be
// cached so unless it can't be cached (for example, it's a Constant)
// it shouldn't produce large expressions.
self.put_expression(arg, context, true)?;
// Access the current component on the first vector
write!(self.out, ".{component} * ")?;
// Write the second vector expression, this expression is marked to be
// cached so unless it can't be cached (for example, it's a Constant)
// it shouldn't produce large expressions.
self.put_expression(arg1, context, true)?;
// Access the current component on the second vector
write!(self.out, ".{component}")?;
extractor(self, arg, index)?;
write!(self.out, " * ")?;
extractor(self, arg1, index)?;
}
write!(self.out, ")")?;
@ -2194,12 +2186,48 @@ impl<W: Write> Writer<W> {
..
} => "dot",
crate::TypeInner::Vector { size, .. } => {
return self.put_dot_product(arg, arg1.unwrap(), size as usize, context)
return self.put_dot_product(
arg,
arg1.unwrap(),
size as usize,
|writer, arg, index| {
// Write the vector expression; this expression is marked to be
// cached so unless it can't be cached (for example, it's a Constant)
// it shouldn't produce large expressions.
writer.put_expression(arg, context, true)?;
// Access the current component on the vector.
write!(writer.out, ".{}", back::COMPONENTS[index])?;
Ok(())
},
);
}
_ => unreachable!(
"Correct TypeInner for dot product should be already validated"
),
},
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
let conversion = match fun {
Mf::Dot4I8Packed => "int",
Mf::Dot4U8Packed => "",
_ => unreachable!(),
};
return self.put_dot_product(
arg,
arg1.unwrap(),
4,
|writer, arg, index| {
write!(writer.out, "({}(", conversion)?;
writer.put_expression(arg, context, true)?;
if index == 3 {
write!(writer.out, ") >> 24)")?;
} else {
write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
}
Ok(())
},
);
}
Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
Mf::Cross => "cross",
Mf::Distance => "distance",
@ -3177,6 +3205,10 @@ impl<W: Write> Writer<W> {
}
}
}
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
}
crate::MathFunction::FirstLeadingBit
| crate::MathFunction::Pack4xI8
| crate::MathFunction::Pack4xU8

View File

@ -1126,6 +1126,14 @@ impl BlockContext<'_> {
arg1_id,
size as u32,
block,
|result_id, composite_id, index| {
Instruction::composite_extract(
result_type_id,
result_id,
composite_id,
&[index],
)
},
);
self.cached[expr_handle] = id;
return Ok(());
@ -1134,6 +1142,63 @@ impl BlockContext<'_> {
"Correct TypeInner for dot product should be already validated"
),
},
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
// TODO: consider using the packed integer dot product instructions if the
// `DotProductInput4x8BitPacked` capability is available
let (extract_op, arg0_id, arg1_id) = match fun {
Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
Mf::Dot4I8Packed => {
// Convert both packed arguments to signed integers so that we can apply the
// `BitFieldSExtract` operation on them in `write_dot_product` below.
let new_arg0_id = self.gen_id();
block.body.push(Instruction::unary(
spirv::Op::Bitcast,
result_type_id,
new_arg0_id,
arg0_id,
));
let new_arg1_id = self.gen_id();
block.body.push(Instruction::unary(
spirv::Op::Bitcast,
result_type_id,
new_arg1_id,
arg1_id,
));
(spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
}
_ => unreachable!(),
};
let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
const VEC_LENGTH: u8 = 4;
let bit_shifts: [_; VEC_LENGTH as usize] = core::array::from_fn(|index| {
self.writer
.get_constant_scalar(crate::Literal::U32(index as u32 * 8))
});
self.write_dot_product(
id,
result_type_id,
arg0_id,
arg1_id,
VEC_LENGTH as Word,
block,
|result_id, composite_id, index| {
Instruction::ternary(
extract_op,
result_type_id,
result_id,
composite_id,
bit_shifts[index as usize],
eight,
)
},
);
self.cached[expr_handle] = id;
return Ok(());
}
Mf::Outer => MathOp::Custom(Instruction::binary(
spirv::Op::OuterProduct,
result_type_id,
@ -2540,6 +2605,12 @@ impl BlockContext<'_> {
}
/// Build the instructions for the arithmetic expression of a dot product
///
/// The argument `extractor` is a function that maps `(result_id,
/// composite_id, index)` to an instruction that extracts the `index`th
/// entry of the value with ID `composite_id` and assigns it to the slot
/// with id `result_id` (which must have type `result_type_id`).
#[expect(clippy::too_many_arguments)]
fn write_dot_product(
&mut self,
result_id: Word,
@ -2548,25 +2619,16 @@ impl BlockContext<'_> {
arg1_id: Word,
size: u32,
block: &mut Block,
extractor: impl Fn(Word, Word, Word) -> Instruction,
) {
let mut partial_sum = self.writer.get_constant_null(result_type_id);
let last_component = size - 1;
for index in 0..=last_component {
// compute the product of the current components
let a_id = self.gen_id();
block.body.push(Instruction::composite_extract(
result_type_id,
a_id,
arg0_id,
&[index],
));
block.body.push(extractor(a_id, arg0_id, index));
let b_id = self.gen_id();
block.body.push(Instruction::composite_extract(
result_type_id,
b_id,
arg1_id,
&[index],
));
block.body.push(extractor(b_id, arg1_id, index));
let prod_id = self.gen_id();
block.body.push(Instruction::binary(
spirv::Op::IMul,

View File

@ -105,6 +105,8 @@ impl TryToWgsl for crate::MathFunction {
Mf::Log2 => "log2",
Mf::Pow => "pow",
Mf::Dot => "dot",
Mf::Dot4I8Packed => "dot4I8Packed",
Mf::Dot4U8Packed => "dot4U8Packed",
Mf::Cross => "cross",
Mf::Distance => "distance",
Mf::Length => "length",

View File

@ -236,6 +236,8 @@ pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
"pow" => Mf::Pow,
// geometry
"dot" => Mf::Dot,
"dot4I8Packed" => Mf::Dot4I8Packed,
"dot4U8Packed" => Mf::Dot4U8Packed,
"cross" => Mf::Cross,
"distance" => Mf::Distance,
"length" => Mf::Length,

View File

@ -1148,6 +1148,8 @@ pub enum MathFunction {
Pow,
// geometry
Dot,
Dot4I8Packed,
Dot4U8Packed,
Outer,
Cross,
Distance,

View File

@ -1314,6 +1314,8 @@ impl<'a> ConstantEvaluator<'a> {
| crate::MathFunction::Frexp
| crate::MathFunction::Ldexp
| crate::MathFunction::Dot
| crate::MathFunction::Dot4I8Packed
| crate::MathFunction::Dot4U8Packed
| crate::MathFunction::Outer
| crate::MathFunction::Distance
| crate::MathFunction::Length

View File

@ -223,6 +223,8 @@ impl super::MathFunction {
Self::Pow => 2,
// geometry
Self::Dot => 2,
Self::Dot4I8Packed => 2,
Self::Dot4U8Packed => 2,
Self::Outer => 2,
Self::Cross => 2,
Self::Distance => 2,

View File

@ -82,6 +82,8 @@ impl ir::MathFunction {
}
Mf::Unpack4xI8 => regular!(1, SCALAR of U32 -> Vec4I).into(),
Mf::Unpack4xU8 => regular!(1, SCALAR of U32 -> Vec4U).into(),
Mf::Dot4I8Packed => regular!(2, SCALAR of U32 -> I32).into(),
Mf::Dot4U8Packed => regular!(2, SCALAR of U32 -> U32).into(),
// One-off operations
Mf::Dot => regular!(2, VECN of NUMERIC -> Scalar).into(),

View File

@ -209,6 +209,7 @@ pub(in crate::proc::overloads) enum ConclusionRule {
Frexp,
Modf,
U32,
I32,
Vec2F,
Vec4F,
Vec4I,
@ -223,6 +224,7 @@ impl ConclusionRule {
Self::Frexp => Conclusion::for_frexp_modf(ir::MathFunction::Frexp, size, scalar),
Self::Modf => Conclusion::for_frexp_modf(ir::MathFunction::Modf, size, scalar),
Self::U32 => Conclusion::Value(ir::TypeInner::Scalar(ir::Scalar::U32)),
Self::I32 => Conclusion::Value(ir::TypeInner::Scalar(ir::Scalar::I32)),
Self::Vec2F => Conclusion::Value(ir::TypeInner::Vector {
size: ir::VectorSize::Bi,
scalar: ir::Scalar::F32,

View File

@ -22,8 +22,24 @@ fn test_integer_dot_product() -> i32 {
return c_4;
}
// Exercises the packed integer dot product builtins `dot4I8Packed` and
// `dot4U8Packed`. Both take two `u32` operands; the signed variant yields an
// `i32` and the unsigned variant a `u32`.
fn test_packed_integer_dot_product() -> u32 {
// Signed variant with plain named operands.
let a_5 = 1u;
let b_5 = 2u;
let c_5: i32 = dot4I8Packed(a_5, b_5);
// Unsigned variant with plain named operands.
let a_6 = 3u;
let b_6 = 4u;
let c_6: u32 = dot4U8Packed(a_6, b_6);
// Test baking of arguments: the operands below are arithmetic expressions
// that depend on `c_6` instead of plain names.
let c_7: i32 = dot4I8Packed(5u + c_6, 6u + c_6);
let c_8: u32 = dot4U8Packed(7u + c_6, 8u + c_6);
return c_8;
}
// Compute entry point: calls the test functions and binds their results to
// `a`, `b` and `c`, which are not used any further.
@compute @workgroup_size(1)
fn main() {
let a = test_fma();
let b = test_integer_dot_product();
let c = test_packed_integer_dot_product();
}

View File

@ -26,9 +26,22 @@ int test_integer_dot_product() {
return c_4_;
}
uint test_packed_integer_dot_product() {
int c_5_ = (bitfieldExtract(int(1u), 0, 8) * bitfieldExtract(int(2u), 0, 8) + bitfieldExtract(int(1u), 8, 8) * bitfieldExtract(int(2u), 8, 8) + bitfieldExtract(int(1u), 16, 8) * bitfieldExtract(int(2u), 16, 8) + bitfieldExtract(int(1u), 24, 8) * bitfieldExtract(int(2u), 24, 8));
uint c_6_ = (bitfieldExtract((3u), 0, 8) * bitfieldExtract((4u), 0, 8) + bitfieldExtract((3u), 8, 8) * bitfieldExtract((4u), 8, 8) + bitfieldExtract((3u), 16, 8) * bitfieldExtract((4u), 16, 8) + bitfieldExtract((3u), 24, 8) * bitfieldExtract((4u), 24, 8));
uint _e7 = (5u + c_6_);
uint _e9 = (6u + c_6_);
int c_7_ = (bitfieldExtract(int(_e7), 0, 8) * bitfieldExtract(int(_e9), 0, 8) + bitfieldExtract(int(_e7), 8, 8) * bitfieldExtract(int(_e9), 8, 8) + bitfieldExtract(int(_e7), 16, 8) * bitfieldExtract(int(_e9), 16, 8) + bitfieldExtract(int(_e7), 24, 8) * bitfieldExtract(int(_e9), 24, 8));
uint _e12 = (7u + c_6_);
uint _e14 = (8u + c_6_);
uint c_8_ = (bitfieldExtract((_e12), 0, 8) * bitfieldExtract((_e14), 0, 8) + bitfieldExtract((_e12), 8, 8) * bitfieldExtract((_e14), 8, 8) + bitfieldExtract((_e12), 16, 8) * bitfieldExtract((_e14), 16, 8) + bitfieldExtract((_e12), 24, 8) * bitfieldExtract((_e14), 24, 8));
return c_8_;
}
void main() {
vec2 _e0 = test_fma();
int _e1 = test_integer_dot_product();
uint _e2 = test_packed_integer_dot_product();
return;
}

View File

@ -18,10 +18,24 @@ int test_integer_dot_product()
return c_4_;
}
uint test_packed_integer_dot_product()
{
int c_5_ = dot(int4(1u, 1u >> 8, 1u >> 16, 1u >> 24) << 24 >> 24, int4(2u, 2u >> 8, 2u >> 16, 2u >> 24) << 24 >> 24);
uint c_6_ = dot(uint4(3u, 3u >> 8, 3u >> 16, 3u >> 24) << 24 >> 24, uint4(4u, 4u >> 8, 4u >> 16, 4u >> 24) << 24 >> 24);
uint _e7 = (5u + c_6_);
uint _e9 = (6u + c_6_);
int c_7_ = dot(int4(_e7, _e7 >> 8, _e7 >> 16, _e7 >> 24) << 24 >> 24, int4(_e9, _e9 >> 8, _e9 >> 16, _e9 >> 24) << 24 >> 24);
uint _e12 = (7u + c_6_);
uint _e14 = (8u + c_6_);
uint c_8_ = dot(uint4(_e12, _e12 >> 8, _e12 >> 16, _e12 >> 24) << 24 >> 24, uint4(_e14, _e14 >> 8, _e14 >> 16, _e14 >> 24) << 24 >> 24);
return c_8_;
}
[numthreads(1, 1, 1)]
void main()
{
const float2 _e0 = test_fma();
const int _e1 = test_integer_dot_product();
const uint _e2 = test_packed_integer_dot_product();
return;
}

View File

@ -27,9 +27,23 @@ int test_integer_dot_product(
return c_4_;
}
uint test_packed_integer_dot_product(
) {
int c_5_ = ( + (int(1u) << 24 >> 24) * (int(2u) << 24 >> 24) + (int(1u) << 16 >> 24) * (int(2u) << 16 >> 24) + (int(1u) << 8 >> 24) * (int(2u) << 8 >> 24) + (int(1u) >> 24) * (int(2u) >> 24));
uint c_6_ = ( + ((3u) << 24 >> 24) * ((4u) << 24 >> 24) + ((3u) << 16 >> 24) * ((4u) << 16 >> 24) + ((3u) << 8 >> 24) * ((4u) << 8 >> 24) + ((3u) >> 24) * ((4u) >> 24));
uint _e7 = 5u + c_6_;
uint _e9 = 6u + c_6_;
int c_7_ = ( + (int(_e7) << 24 >> 24) * (int(_e9) << 24 >> 24) + (int(_e7) << 16 >> 24) * (int(_e9) << 16 >> 24) + (int(_e7) << 8 >> 24) * (int(_e9) << 8 >> 24) + (int(_e7) >> 24) * (int(_e9) >> 24));
uint _e12 = 7u + c_6_;
uint _e14 = 8u + c_6_;
uint c_8_ = ( + ((_e12) << 24 >> 24) * ((_e14) << 24 >> 24) + ((_e12) << 16 >> 24) * ((_e14) << 16 >> 24) + ((_e12) << 8 >> 24) * ((_e14) << 8 >> 24) + ((_e12) >> 24) * ((_e14) >> 24));
return c_8_;
}
kernel void main_(
) {
metal::float2 _e0 = test_fma();
int _e1 = test_integer_dot_product();
uint _e2 = test_packed_integer_dot_product();
return;
}

View File

@ -1,28 +1,28 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 75
; Bound: 162
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %70 "main"
OpExecutionMode %70 LocalSize 1 1 1
OpEntryPoint GLCompute %156 "main"
OpExecutionMode %156 LocalSize 1 1 1
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpTypeVector %4 2
%5 = OpTypeInt 32 1
%8 = OpTypeFunction %3
%9 = OpConstant %4 2.0
%10 = OpConstantComposite %3 %9 %9
%11 = OpConstant %4 0.5
%12 = OpConstantComposite %3 %11 %11
%17 = OpTypeFunction %5
%18 = OpConstant %5 1
%19 = OpTypeVector %5 2
%20 = OpConstantComposite %19 %18 %18
%21 = OpTypeInt 32 0
%22 = OpConstant %21 1
%23 = OpTypeVector %21 3
%6 = OpTypeInt 32 0
%9 = OpTypeFunction %3
%10 = OpConstant %4 2.0
%11 = OpConstantComposite %3 %10 %10
%12 = OpConstant %4 0.5
%13 = OpConstantComposite %3 %12 %12
%18 = OpTypeFunction %5
%19 = OpConstant %5 1
%20 = OpTypeVector %5 2
%21 = OpConstantComposite %20 %19 %19
%22 = OpConstant %6 1
%23 = OpTypeVector %6 3
%24 = OpConstantComposite %23 %22 %22 %22
%25 = OpConstant %5 4
%26 = OpTypeVector %5 4
@ -30,39 +30,50 @@ OpExecutionMode %70 LocalSize 1 1 1
%28 = OpConstant %5 2
%29 = OpConstantComposite %26 %28 %28 %28 %28
%32 = OpConstantNull %5
%41 = OpConstantNull %21
%71 = OpTypeFunction %2
%7 = OpFunction %3 None %8
%6 = OpLabel
OpBranch %13
%13 = OpLabel
%14 = OpExtInst %3 %1 Fma %10 %12 %12
OpReturnValue %14
%41 = OpConstantNull %6
%71 = OpTypeFunction %6
%72 = OpConstant %6 2
%73 = OpConstant %6 3
%74 = OpConstant %6 4
%75 = OpConstant %6 5
%76 = OpConstant %6 6
%77 = OpConstant %6 7
%78 = OpConstant %6 8
%83 = OpConstant %6 0
%84 = OpConstant %6 16
%85 = OpConstant %6 24
%157 = OpTypeFunction %2
%8 = OpFunction %3 None %9
%7 = OpLabel
OpBranch %14
%14 = OpLabel
%15 = OpExtInst %3 %1 Fma %11 %13 %13
OpReturnValue %15
OpFunctionEnd
%16 = OpFunction %5 None %17
%15 = OpLabel
%17 = OpFunction %5 None %18
%16 = OpLabel
OpBranch %30
%30 = OpLabel
%33 = OpCompositeExtract %5 %20 0
%34 = OpCompositeExtract %5 %20 0
%33 = OpCompositeExtract %5 %21 0
%34 = OpCompositeExtract %5 %21 0
%35 = OpIMul %5 %33 %34
%36 = OpIAdd %5 %32 %35
%37 = OpCompositeExtract %5 %20 1
%38 = OpCompositeExtract %5 %20 1
%37 = OpCompositeExtract %5 %21 1
%38 = OpCompositeExtract %5 %21 1
%39 = OpIMul %5 %37 %38
%31 = OpIAdd %5 %36 %39
%42 = OpCompositeExtract %21 %24 0
%43 = OpCompositeExtract %21 %24 0
%44 = OpIMul %21 %42 %43
%45 = OpIAdd %21 %41 %44
%46 = OpCompositeExtract %21 %24 1
%47 = OpCompositeExtract %21 %24 1
%48 = OpIMul %21 %46 %47
%49 = OpIAdd %21 %45 %48
%50 = OpCompositeExtract %21 %24 2
%51 = OpCompositeExtract %21 %24 2
%52 = OpIMul %21 %50 %51
%40 = OpIAdd %21 %49 %52
%42 = OpCompositeExtract %6 %24 0
%43 = OpCompositeExtract %6 %24 0
%44 = OpIMul %6 %42 %43
%45 = OpIAdd %6 %41 %44
%46 = OpCompositeExtract %6 %24 1
%47 = OpCompositeExtract %6 %24 1
%48 = OpIMul %6 %46 %47
%49 = OpIAdd %6 %45 %48
%50 = OpCompositeExtract %6 %24 2
%51 = OpCompositeExtract %6 %24 2
%52 = OpIMul %6 %50 %51
%40 = OpIAdd %6 %49 %52
%54 = OpCompositeExtract %5 %27 0
%55 = OpCompositeExtract %5 %29 0
%56 = OpIMul %5 %54 %55
@ -81,11 +92,90 @@ OpBranch %30
%53 = OpIAdd %5 %65 %68
OpReturnValue %53
OpFunctionEnd
%70 = OpFunction %2 None %71
%70 = OpFunction %6 None %71
%69 = OpLabel
OpBranch %72
%72 = OpLabel
%73 = OpFunctionCall %3 %7
%74 = OpFunctionCall %5 %16
OpBranch %79
%79 = OpLabel
%81 = OpBitcast %5 %22
%82 = OpBitcast %5 %72
%86 = OpBitFieldSExtract %5 %81 %83 %78
%87 = OpBitFieldSExtract %5 %82 %83 %78
%88 = OpIMul %5 %86 %87
%89 = OpIAdd %5 %32 %88
%90 = OpBitFieldSExtract %5 %81 %78 %78
%91 = OpBitFieldSExtract %5 %82 %78 %78
%92 = OpIMul %5 %90 %91
%93 = OpIAdd %5 %89 %92
%94 = OpBitFieldSExtract %5 %81 %84 %78
%95 = OpBitFieldSExtract %5 %82 %84 %78
%96 = OpIMul %5 %94 %95
%97 = OpIAdd %5 %93 %96
%98 = OpBitFieldSExtract %5 %81 %85 %78
%99 = OpBitFieldSExtract %5 %82 %85 %78
%100 = OpIMul %5 %98 %99
%80 = OpIAdd %5 %97 %100
%102 = OpBitFieldUExtract %6 %73 %83 %78
%103 = OpBitFieldUExtract %6 %74 %83 %78
%104 = OpIMul %6 %102 %103
%105 = OpIAdd %6 %41 %104
%106 = OpBitFieldUExtract %6 %73 %78 %78
%107 = OpBitFieldUExtract %6 %74 %78 %78
%108 = OpIMul %6 %106 %107
%109 = OpIAdd %6 %105 %108
%110 = OpBitFieldUExtract %6 %73 %84 %78
%111 = OpBitFieldUExtract %6 %74 %84 %78
%112 = OpIMul %6 %110 %111
%113 = OpIAdd %6 %109 %112
%114 = OpBitFieldUExtract %6 %73 %85 %78
%115 = OpBitFieldUExtract %6 %74 %85 %78
%116 = OpIMul %6 %114 %115
%101 = OpIAdd %6 %113 %116
%117 = OpIAdd %6 %75 %101
%118 = OpIAdd %6 %76 %101
%120 = OpBitcast %5 %117
%121 = OpBitcast %5 %118
%122 = OpBitFieldSExtract %5 %120 %83 %78
%123 = OpBitFieldSExtract %5 %121 %83 %78
%124 = OpIMul %5 %122 %123
%125 = OpIAdd %5 %32 %124
%126 = OpBitFieldSExtract %5 %120 %78 %78
%127 = OpBitFieldSExtract %5 %121 %78 %78
%128 = OpIMul %5 %126 %127
%129 = OpIAdd %5 %125 %128
%130 = OpBitFieldSExtract %5 %120 %84 %78
%131 = OpBitFieldSExtract %5 %121 %84 %78
%132 = OpIMul %5 %130 %131
%133 = OpIAdd %5 %129 %132
%134 = OpBitFieldSExtract %5 %120 %85 %78
%135 = OpBitFieldSExtract %5 %121 %85 %78
%136 = OpIMul %5 %134 %135
%119 = OpIAdd %5 %133 %136
%137 = OpIAdd %6 %77 %101
%138 = OpIAdd %6 %78 %101
%140 = OpBitFieldUExtract %6 %137 %83 %78
%141 = OpBitFieldUExtract %6 %138 %83 %78
%142 = OpIMul %6 %140 %141
%143 = OpIAdd %6 %41 %142
%144 = OpBitFieldUExtract %6 %137 %78 %78
%145 = OpBitFieldUExtract %6 %138 %78 %78
%146 = OpIMul %6 %144 %145
%147 = OpIAdd %6 %143 %146
%148 = OpBitFieldUExtract %6 %137 %84 %78
%149 = OpBitFieldUExtract %6 %138 %84 %78
%150 = OpIMul %6 %148 %149
%151 = OpIAdd %6 %147 %150
%152 = OpBitFieldUExtract %6 %137 %85 %78
%153 = OpBitFieldUExtract %6 %138 %85 %78
%154 = OpIMul %6 %152 %153
%139 = OpIAdd %6 %151 %154
OpReturnValue %139
OpFunctionEnd
%156 = OpFunction %2 None %157
%155 = OpLabel
OpBranch %158
%158 = OpLabel
%159 = OpFunctionCall %3 %8
%160 = OpFunctionCall %5 %17
%161 = OpFunctionCall %6 %70
OpReturn
OpFunctionEnd

View File

@ -16,9 +16,18 @@ fn test_integer_dot_product() -> i32 {
return c_4_;
}
fn test_packed_integer_dot_product() -> u32 {
let c_5_ = dot4I8Packed(1u, 2u);
let c_6_ = dot4U8Packed(3u, 4u);
let c_7_ = dot4I8Packed((5u + c_6_), (6u + c_6_));
let c_8_ = dot4U8Packed((7u + c_6_), (8u + c_6_));
return c_8_;
}
@compute @workgroup_size(1, 1, 1)
fn main() {
let _e0 = test_fma();
let _e1 = test_integer_dot_product();
let _e2 = test_packed_integer_dot_product();
return;
}