mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-12-08 21:26:17 +00:00
Implement dot4I8Packed and dot4U8Packed
Closes #7481. This implementation roughly follows approach 2 outlined in #7481, i.e., it adds a polyfill for the signed and unsigned dot product of packed vectors for each platform. It doesn't use the specialized instructions that are available for this operation on SPIR-V (with capability DotProductInput4x8BitPacked).
This commit is contained in:
parent
65c56fdee4
commit
c7d0af156d
@ -40,6 +40,12 @@ Bottom level categories:
|
||||
|
||||
## Unreleased
|
||||
|
||||
### New Features
|
||||
|
||||
#### Naga
|
||||
|
||||
- Add polyfills for `dot4U8Packed` and `dot4I8Packed` for all backends. By @robamler in [#7494](https://github.com/gfx-rs/wgpu/pull/7494).
|
||||
|
||||
### Changes
|
||||
|
||||
#### General
|
||||
|
||||
@ -1386,6 +1386,10 @@ impl<'a, W: Write> Writer<'a, W> {
|
||||
self.need_bake_expressions.insert(arg1.unwrap());
|
||||
}
|
||||
}
|
||||
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
|
||||
self.need_bake_expressions.insert(arg);
|
||||
self.need_bake_expressions.insert(arg1.unwrap());
|
||||
}
|
||||
crate::MathFunction::Pack4xI8
|
||||
| crate::MathFunction::Pack4xU8
|
||||
| crate::MathFunction::Unpack4xI8
|
||||
@ -3558,6 +3562,40 @@ impl<'a, W: Write> Writer<'a, W> {
|
||||
"Correct TypeInner for dot product should be already validated"
|
||||
),
|
||||
},
|
||||
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
|
||||
let conversion = match fun {
|
||||
Mf::Dot4I8Packed => "int",
|
||||
Mf::Dot4U8Packed => "",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let arg1 = arg1.unwrap();
|
||||
|
||||
// Write parentheses around the dot product expression to prevent operators
|
||||
// with different precedences from applying earlier.
|
||||
write!(self.out, "(")?;
|
||||
for i in 0..4 {
|
||||
// Since `bitfieldExtract` only sign extends if the value is signed, we
|
||||
// need to convert the inputs to `int` in case of `Dot4I8Packed`. For
|
||||
// `Dot4U8Packed`, the code below only introduces parenthesis around
|
||||
// each factor, which aren't strictly needed because both operands are
|
||||
// baked, but which don't hurt either.
|
||||
write!(self.out, "bitfieldExtract({}(", conversion)?;
|
||||
self.write_expr(arg, ctx)?;
|
||||
write!(self.out, "), {}, 8)", i * 8)?;
|
||||
|
||||
write!(self.out, " * bitfieldExtract({}(", conversion)?;
|
||||
self.write_expr(arg1, ctx)?;
|
||||
write!(self.out, "), {}, 8)", i * 8)?;
|
||||
|
||||
if i != 3 {
|
||||
write!(self.out, " + ")?;
|
||||
}
|
||||
}
|
||||
write!(self.out, ")")?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
Mf::Outer => "outerProduct",
|
||||
Mf::Cross => "cross",
|
||||
Mf::Distance => "distance",
|
||||
|
||||
@ -206,7 +206,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
self.need_bake_expressions.insert(exp_handle);
|
||||
}
|
||||
|
||||
if let Expression::Math { fun, arg, .. } = *expr {
|
||||
if let Expression::Math { fun, arg, arg1, .. } = *expr {
|
||||
match fun {
|
||||
crate::MathFunction::Asinh
|
||||
| crate::MathFunction::Acosh
|
||||
@ -233,6 +233,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
self.need_bake_expressions.insert(arg);
|
||||
}
|
||||
}
|
||||
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
|
||||
self.need_bake_expressions.insert(arg);
|
||||
self.need_bake_expressions.insert(arg1.unwrap());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
@ -3433,6 +3437,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
Unpack4x8unorm,
|
||||
Unpack4xI8,
|
||||
Unpack4xU8,
|
||||
Dot4I8Packed,
|
||||
Dot4U8Packed,
|
||||
QuantizeToF16,
|
||||
Regular(&'static str),
|
||||
MissingIntOverload(&'static str),
|
||||
@ -3484,6 +3490,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
Mf::Pow => Function::Regular("pow"),
|
||||
// geometry
|
||||
Mf::Dot => Function::Regular("dot"),
|
||||
Mf::Dot4I8Packed => Function::Dot4I8Packed,
|
||||
Mf::Dot4U8Packed => Function::Dot4U8Packed,
|
||||
//Mf::Outer => ,
|
||||
Mf::Cross => Function::Regular("cross"),
|
||||
Mf::Distance => Function::Regular("distance"),
|
||||
@ -3706,6 +3714,37 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
self.write_expr(module, arg, func_ctx)?;
|
||||
write!(self.out, " >> 24) << 24 >> 24)")?;
|
||||
}
|
||||
fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
|
||||
let arg1 = arg1.unwrap();
|
||||
|
||||
write!(self.out, "dot(")?;
|
||||
|
||||
if matches!(fun, Function::Dot4U8Packed) {
|
||||
write!(self.out, "u")?;
|
||||
}
|
||||
write!(self.out, "int4(")?;
|
||||
self.write_expr(module, arg, func_ctx)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(module, arg, func_ctx)?;
|
||||
write!(self.out, " >> 8, ")?;
|
||||
self.write_expr(module, arg, func_ctx)?;
|
||||
write!(self.out, " >> 16, ")?;
|
||||
self.write_expr(module, arg, func_ctx)?;
|
||||
write!(self.out, " >> 24) << 24 >> 24, ")?;
|
||||
|
||||
if matches!(fun, Function::Dot4U8Packed) {
|
||||
write!(self.out, "u")?;
|
||||
}
|
||||
write!(self.out, "int4(")?;
|
||||
self.write_expr(module, arg1, func_ctx)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(module, arg1, func_ctx)?;
|
||||
write!(self.out, " >> 8, ")?;
|
||||
self.write_expr(module, arg1, func_ctx)?;
|
||||
write!(self.out, " >> 16, ")?;
|
||||
self.write_expr(module, arg1, func_ctx)?;
|
||||
write!(self.out, " >> 24) << 24 >> 24)")?;
|
||||
}
|
||||
Function::QuantizeToF16 => {
|
||||
write!(self.out, "f16tof32(f32tof16(")?;
|
||||
self.write_expr(module, arg, func_ctx)?;
|
||||
|
||||
@ -1470,12 +1470,14 @@ impl<W: Write> Writer<W> {
|
||||
|
||||
/// Emit code for the arithmetic expression of the dot product.
|
||||
///
|
||||
/// The argument `extractor` is a function that accepts a `Writer`, a handle to a vector,
|
||||
/// and an index. writes out the expression for the component at that index.
|
||||
fn put_dot_product(
|
||||
&mut self,
|
||||
arg: Handle<crate::Expression>,
|
||||
arg1: Handle<crate::Expression>,
|
||||
size: usize,
|
||||
context: &ExpressionContext,
|
||||
extractor: impl Fn(&mut Self, Handle<crate::Expression>, usize) -> BackendResult,
|
||||
) -> BackendResult {
|
||||
// Write parentheses around the dot product expression to prevent operators
|
||||
// with different precedences from applying earlier.
|
||||
@ -1483,22 +1485,12 @@ impl<W: Write> Writer<W> {
|
||||
|
||||
// Cycle through all the components of the vector
|
||||
for index in 0..size {
|
||||
let component = back::COMPONENTS[index];
|
||||
// Write the addition to the previous product
|
||||
// This will print an extra '+' at the beginning but that is fine in msl
|
||||
write!(self.out, " + ")?;
|
||||
// Write the first vector expression, this expression is marked to be
|
||||
// cached so unless it can't be cached (for example, it's a Constant)
|
||||
// it shouldn't produce large expressions.
|
||||
self.put_expression(arg, context, true)?;
|
||||
// Access the current component on the first vector
|
||||
write!(self.out, ".{component} * ")?;
|
||||
// Write the second vector expression, this expression is marked to be
|
||||
// cached so unless it can't be cached (for example, it's a Constant)
|
||||
// it shouldn't produce large expressions.
|
||||
self.put_expression(arg1, context, true)?;
|
||||
// Access the current component on the second vector
|
||||
write!(self.out, ".{component}")?;
|
||||
extractor(self, arg, index)?;
|
||||
write!(self.out, " * ")?;
|
||||
extractor(self, arg1, index)?;
|
||||
}
|
||||
|
||||
write!(self.out, ")")?;
|
||||
@ -2194,12 +2186,48 @@ impl<W: Write> Writer<W> {
|
||||
..
|
||||
} => "dot",
|
||||
crate::TypeInner::Vector { size, .. } => {
|
||||
return self.put_dot_product(arg, arg1.unwrap(), size as usize, context)
|
||||
return self.put_dot_product(
|
||||
arg,
|
||||
arg1.unwrap(),
|
||||
size as usize,
|
||||
|writer, arg, index| {
|
||||
// Write the vector expression; this expression is marked to be
|
||||
// cached so unless it can't be cached (for example, it's a Constant)
|
||||
// it shouldn't produce large expressions.
|
||||
writer.put_expression(arg, context, true)?;
|
||||
// Access the current component on the vector.
|
||||
write!(writer.out, ".{}", back::COMPONENTS[index])?;
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
}
|
||||
_ => unreachable!(
|
||||
"Correct TypeInner for dot product should be already validated"
|
||||
),
|
||||
},
|
||||
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
|
||||
let conversion = match fun {
|
||||
Mf::Dot4I8Packed => "int",
|
||||
Mf::Dot4U8Packed => "",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
return self.put_dot_product(
|
||||
arg,
|
||||
arg1.unwrap(),
|
||||
4,
|
||||
|writer, arg, index| {
|
||||
write!(writer.out, "({}(", conversion)?;
|
||||
writer.put_expression(arg, context, true)?;
|
||||
if index == 3 {
|
||||
write!(writer.out, ") >> 24)")?;
|
||||
} else {
|
||||
write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
}
|
||||
Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
|
||||
Mf::Cross => "cross",
|
||||
Mf::Distance => "distance",
|
||||
@ -3177,6 +3205,10 @@ impl<W: Write> Writer<W> {
|
||||
}
|
||||
}
|
||||
}
|
||||
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
|
||||
self.need_bake_expressions.insert(arg);
|
||||
self.need_bake_expressions.insert(arg1.unwrap());
|
||||
}
|
||||
crate::MathFunction::FirstLeadingBit
|
||||
| crate::MathFunction::Pack4xI8
|
||||
| crate::MathFunction::Pack4xU8
|
||||
|
||||
@ -1126,6 +1126,14 @@ impl BlockContext<'_> {
|
||||
arg1_id,
|
||||
size as u32,
|
||||
block,
|
||||
|result_id, composite_id, index| {
|
||||
Instruction::composite_extract(
|
||||
result_type_id,
|
||||
result_id,
|
||||
composite_id,
|
||||
&[index],
|
||||
)
|
||||
},
|
||||
);
|
||||
self.cached[expr_handle] = id;
|
||||
return Ok(());
|
||||
@ -1134,6 +1142,63 @@ impl BlockContext<'_> {
|
||||
"Correct TypeInner for dot product should be already validated"
|
||||
),
|
||||
},
|
||||
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
|
||||
// TODO: consider using packed integer dot product if PackedVectorFormat4x8Bit is available
|
||||
let (extract_op, arg0_id, arg1_id) = match fun {
|
||||
Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
|
||||
Mf::Dot4I8Packed => {
|
||||
// Convert both packed arguments to signed integers so that we can apply the
|
||||
// `BitFieldSExtract` operation on them in `write_dot_product` below.
|
||||
let new_arg0_id = self.gen_id();
|
||||
block.body.push(Instruction::unary(
|
||||
spirv::Op::Bitcast,
|
||||
result_type_id,
|
||||
new_arg0_id,
|
||||
arg0_id,
|
||||
));
|
||||
|
||||
let new_arg1_id = self.gen_id();
|
||||
block.body.push(Instruction::unary(
|
||||
spirv::Op::Bitcast,
|
||||
result_type_id,
|
||||
new_arg1_id,
|
||||
arg1_id,
|
||||
));
|
||||
|
||||
(spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
|
||||
|
||||
const VEC_LENGTH: u8 = 4;
|
||||
let bit_shifts: [_; VEC_LENGTH as usize] = core::array::from_fn(|index| {
|
||||
self.writer
|
||||
.get_constant_scalar(crate::Literal::U32(index as u32 * 8))
|
||||
});
|
||||
|
||||
self.write_dot_product(
|
||||
id,
|
||||
result_type_id,
|
||||
arg0_id,
|
||||
arg1_id,
|
||||
VEC_LENGTH as Word,
|
||||
block,
|
||||
|result_id, composite_id, index| {
|
||||
Instruction::ternary(
|
||||
extract_op,
|
||||
result_type_id,
|
||||
result_id,
|
||||
composite_id,
|
||||
bit_shifts[index as usize],
|
||||
eight,
|
||||
)
|
||||
},
|
||||
);
|
||||
self.cached[expr_handle] = id;
|
||||
return Ok(());
|
||||
}
|
||||
Mf::Outer => MathOp::Custom(Instruction::binary(
|
||||
spirv::Op::OuterProduct,
|
||||
result_type_id,
|
||||
@ -2540,6 +2605,12 @@ impl BlockContext<'_> {
|
||||
}
|
||||
|
||||
/// Build the instructions for the arithmetic expression of a dot product
|
||||
///
|
||||
/// The argument `extractor` is a function that maps `(result_id,
|
||||
/// composite_id, index)` to an instruction that extracts the `index`th
|
||||
/// entry of the value with ID `composite_id` and assigns it to the slot
|
||||
/// with id `result_id` (which must have type `result_type_id`).
|
||||
#[expect(clippy::too_many_arguments)]
|
||||
fn write_dot_product(
|
||||
&mut self,
|
||||
result_id: Word,
|
||||
@ -2548,25 +2619,16 @@ impl BlockContext<'_> {
|
||||
arg1_id: Word,
|
||||
size: u32,
|
||||
block: &mut Block,
|
||||
extractor: impl Fn(Word, Word, Word) -> Instruction,
|
||||
) {
|
||||
let mut partial_sum = self.writer.get_constant_null(result_type_id);
|
||||
let last_component = size - 1;
|
||||
for index in 0..=last_component {
|
||||
// compute the product of the current components
|
||||
let a_id = self.gen_id();
|
||||
block.body.push(Instruction::composite_extract(
|
||||
result_type_id,
|
||||
a_id,
|
||||
arg0_id,
|
||||
&[index],
|
||||
));
|
||||
block.body.push(extractor(a_id, arg0_id, index));
|
||||
let b_id = self.gen_id();
|
||||
block.body.push(Instruction::composite_extract(
|
||||
result_type_id,
|
||||
b_id,
|
||||
arg1_id,
|
||||
&[index],
|
||||
));
|
||||
block.body.push(extractor(b_id, arg1_id, index));
|
||||
let prod_id = self.gen_id();
|
||||
block.body.push(Instruction::binary(
|
||||
spirv::Op::IMul,
|
||||
|
||||
@ -105,6 +105,8 @@ impl TryToWgsl for crate::MathFunction {
|
||||
Mf::Log2 => "log2",
|
||||
Mf::Pow => "pow",
|
||||
Mf::Dot => "dot",
|
||||
Mf::Dot4I8Packed => "dot4I8Packed",
|
||||
Mf::Dot4U8Packed => "dot4U8Packed",
|
||||
Mf::Cross => "cross",
|
||||
Mf::Distance => "distance",
|
||||
Mf::Length => "length",
|
||||
|
||||
@ -236,6 +236,8 @@ pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
|
||||
"pow" => Mf::Pow,
|
||||
// geometry
|
||||
"dot" => Mf::Dot,
|
||||
"dot4I8Packed" => Mf::Dot4I8Packed,
|
||||
"dot4U8Packed" => Mf::Dot4U8Packed,
|
||||
"cross" => Mf::Cross,
|
||||
"distance" => Mf::Distance,
|
||||
"length" => Mf::Length,
|
||||
|
||||
@ -1148,6 +1148,8 @@ pub enum MathFunction {
|
||||
Pow,
|
||||
// geometry
|
||||
Dot,
|
||||
Dot4I8Packed,
|
||||
Dot4U8Packed,
|
||||
Outer,
|
||||
Cross,
|
||||
Distance,
|
||||
|
||||
@ -1314,6 +1314,8 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
| crate::MathFunction::Frexp
|
||||
| crate::MathFunction::Ldexp
|
||||
| crate::MathFunction::Dot
|
||||
| crate::MathFunction::Dot4I8Packed
|
||||
| crate::MathFunction::Dot4U8Packed
|
||||
| crate::MathFunction::Outer
|
||||
| crate::MathFunction::Distance
|
||||
| crate::MathFunction::Length
|
||||
|
||||
@ -223,6 +223,8 @@ impl super::MathFunction {
|
||||
Self::Pow => 2,
|
||||
// geometry
|
||||
Self::Dot => 2,
|
||||
Self::Dot4I8Packed => 2,
|
||||
Self::Dot4U8Packed => 2,
|
||||
Self::Outer => 2,
|
||||
Self::Cross => 2,
|
||||
Self::Distance => 2,
|
||||
|
||||
@ -82,6 +82,8 @@ impl ir::MathFunction {
|
||||
}
|
||||
Mf::Unpack4xI8 => regular!(1, SCALAR of U32 -> Vec4I).into(),
|
||||
Mf::Unpack4xU8 => regular!(1, SCALAR of U32 -> Vec4U).into(),
|
||||
Mf::Dot4I8Packed => regular!(2, SCALAR of U32 -> I32).into(),
|
||||
Mf::Dot4U8Packed => regular!(2, SCALAR of U32 -> U32).into(),
|
||||
|
||||
// One-off operations
|
||||
Mf::Dot => regular!(2, VECN of NUMERIC -> Scalar).into(),
|
||||
|
||||
@ -209,6 +209,7 @@ pub(in crate::proc::overloads) enum ConclusionRule {
|
||||
Frexp,
|
||||
Modf,
|
||||
U32,
|
||||
I32,
|
||||
Vec2F,
|
||||
Vec4F,
|
||||
Vec4I,
|
||||
@ -223,6 +224,7 @@ impl ConclusionRule {
|
||||
Self::Frexp => Conclusion::for_frexp_modf(ir::MathFunction::Frexp, size, scalar),
|
||||
Self::Modf => Conclusion::for_frexp_modf(ir::MathFunction::Modf, size, scalar),
|
||||
Self::U32 => Conclusion::Value(ir::TypeInner::Scalar(ir::Scalar::U32)),
|
||||
Self::I32 => Conclusion::Value(ir::TypeInner::Scalar(ir::Scalar::I32)),
|
||||
Self::Vec2F => Conclusion::Value(ir::TypeInner::Vector {
|
||||
size: ir::VectorSize::Bi,
|
||||
scalar: ir::Scalar::F32,
|
||||
|
||||
@ -22,8 +22,24 @@ fn test_integer_dot_product() -> i32 {
|
||||
return c_4;
|
||||
}
|
||||
|
||||
fn test_packed_integer_dot_product() -> u32 {
|
||||
let a_5 = 1u;
|
||||
let b_5 = 2u;
|
||||
let c_5: i32 = dot4I8Packed(a_5, b_5);
|
||||
|
||||
let a_6 = 3u;
|
||||
let b_6 = 4u;
|
||||
let c_6: u32 = dot4U8Packed(a_6, b_6);
|
||||
|
||||
// test baking of arguments
|
||||
let c_7: i32 = dot4I8Packed(5u + c_6, 6u + c_6);
|
||||
let c_8: u32 = dot4U8Packed(7u + c_6, 8u + c_6);
|
||||
return c_8;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main() {
|
||||
let a = test_fma();
|
||||
let b = test_integer_dot_product();
|
||||
let c = test_packed_integer_dot_product();
|
||||
}
|
||||
|
||||
@ -26,9 +26,22 @@ int test_integer_dot_product() {
|
||||
return c_4_;
|
||||
}
|
||||
|
||||
uint test_packed_integer_dot_product() {
|
||||
int c_5_ = (bitfieldExtract(int(1u), 0, 8) * bitfieldExtract(int(2u), 0, 8) + bitfieldExtract(int(1u), 8, 8) * bitfieldExtract(int(2u), 8, 8) + bitfieldExtract(int(1u), 16, 8) * bitfieldExtract(int(2u), 16, 8) + bitfieldExtract(int(1u), 24, 8) * bitfieldExtract(int(2u), 24, 8));
|
||||
uint c_6_ = (bitfieldExtract((3u), 0, 8) * bitfieldExtract((4u), 0, 8) + bitfieldExtract((3u), 8, 8) * bitfieldExtract((4u), 8, 8) + bitfieldExtract((3u), 16, 8) * bitfieldExtract((4u), 16, 8) + bitfieldExtract((3u), 24, 8) * bitfieldExtract((4u), 24, 8));
|
||||
uint _e7 = (5u + c_6_);
|
||||
uint _e9 = (6u + c_6_);
|
||||
int c_7_ = (bitfieldExtract(int(_e7), 0, 8) * bitfieldExtract(int(_e9), 0, 8) + bitfieldExtract(int(_e7), 8, 8) * bitfieldExtract(int(_e9), 8, 8) + bitfieldExtract(int(_e7), 16, 8) * bitfieldExtract(int(_e9), 16, 8) + bitfieldExtract(int(_e7), 24, 8) * bitfieldExtract(int(_e9), 24, 8));
|
||||
uint _e12 = (7u + c_6_);
|
||||
uint _e14 = (8u + c_6_);
|
||||
uint c_8_ = (bitfieldExtract((_e12), 0, 8) * bitfieldExtract((_e14), 0, 8) + bitfieldExtract((_e12), 8, 8) * bitfieldExtract((_e14), 8, 8) + bitfieldExtract((_e12), 16, 8) * bitfieldExtract((_e14), 16, 8) + bitfieldExtract((_e12), 24, 8) * bitfieldExtract((_e14), 24, 8));
|
||||
return c_8_;
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 _e0 = test_fma();
|
||||
int _e1 = test_integer_dot_product();
|
||||
uint _e2 = test_packed_integer_dot_product();
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@ -18,10 +18,24 @@ int test_integer_dot_product()
|
||||
return c_4_;
|
||||
}
|
||||
|
||||
uint test_packed_integer_dot_product()
|
||||
{
|
||||
int c_5_ = dot(int4(1u, 1u >> 8, 1u >> 16, 1u >> 24) << 24 >> 24, int4(2u, 2u >> 8, 2u >> 16, 2u >> 24) << 24 >> 24);
|
||||
uint c_6_ = dot(uint4(3u, 3u >> 8, 3u >> 16, 3u >> 24) << 24 >> 24, uint4(4u, 4u >> 8, 4u >> 16, 4u >> 24) << 24 >> 24);
|
||||
uint _e7 = (5u + c_6_);
|
||||
uint _e9 = (6u + c_6_);
|
||||
int c_7_ = dot(int4(_e7, _e7 >> 8, _e7 >> 16, _e7 >> 24) << 24 >> 24, int4(_e9, _e9 >> 8, _e9 >> 16, _e9 >> 24) << 24 >> 24);
|
||||
uint _e12 = (7u + c_6_);
|
||||
uint _e14 = (8u + c_6_);
|
||||
uint c_8_ = dot(uint4(_e12, _e12 >> 8, _e12 >> 16, _e12 >> 24) << 24 >> 24, uint4(_e14, _e14 >> 8, _e14 >> 16, _e14 >> 24) << 24 >> 24);
|
||||
return c_8_;
|
||||
}
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void main()
|
||||
{
|
||||
const float2 _e0 = test_fma();
|
||||
const int _e1 = test_integer_dot_product();
|
||||
const uint _e2 = test_packed_integer_dot_product();
|
||||
return;
|
||||
}
|
||||
|
||||
@ -27,9 +27,23 @@ int test_integer_dot_product(
|
||||
return c_4_;
|
||||
}
|
||||
|
||||
uint test_packed_integer_dot_product(
|
||||
) {
|
||||
int c_5_ = ( + (int(1u) << 24 >> 24) * (int(2u) << 24 >> 24) + (int(1u) << 16 >> 24) * (int(2u) << 16 >> 24) + (int(1u) << 8 >> 24) * (int(2u) << 8 >> 24) + (int(1u) >> 24) * (int(2u) >> 24));
|
||||
uint c_6_ = ( + ((3u) << 24 >> 24) * ((4u) << 24 >> 24) + ((3u) << 16 >> 24) * ((4u) << 16 >> 24) + ((3u) << 8 >> 24) * ((4u) << 8 >> 24) + ((3u) >> 24) * ((4u) >> 24));
|
||||
uint _e7 = 5u + c_6_;
|
||||
uint _e9 = 6u + c_6_;
|
||||
int c_7_ = ( + (int(_e7) << 24 >> 24) * (int(_e9) << 24 >> 24) + (int(_e7) << 16 >> 24) * (int(_e9) << 16 >> 24) + (int(_e7) << 8 >> 24) * (int(_e9) << 8 >> 24) + (int(_e7) >> 24) * (int(_e9) >> 24));
|
||||
uint _e12 = 7u + c_6_;
|
||||
uint _e14 = 8u + c_6_;
|
||||
uint c_8_ = ( + ((_e12) << 24 >> 24) * ((_e14) << 24 >> 24) + ((_e12) << 16 >> 24) * ((_e14) << 16 >> 24) + ((_e12) << 8 >> 24) * ((_e14) << 8 >> 24) + ((_e12) >> 24) * ((_e14) >> 24));
|
||||
return c_8_;
|
||||
}
|
||||
|
||||
kernel void main_(
|
||||
) {
|
||||
metal::float2 _e0 = test_fma();
|
||||
int _e1 = test_integer_dot_product();
|
||||
uint _e2 = test_packed_integer_dot_product();
|
||||
return;
|
||||
}
|
||||
|
||||
@ -1,28 +1,28 @@
|
||||
; SPIR-V
|
||||
; Version: 1.1
|
||||
; Generator: rspirv
|
||||
; Bound: 75
|
||||
; Bound: 162
|
||||
OpCapability Shader
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %70 "main"
|
||||
OpExecutionMode %70 LocalSize 1 1 1
|
||||
OpEntryPoint GLCompute %156 "main"
|
||||
OpExecutionMode %156 LocalSize 1 1 1
|
||||
%2 = OpTypeVoid
|
||||
%4 = OpTypeFloat 32
|
||||
%3 = OpTypeVector %4 2
|
||||
%5 = OpTypeInt 32 1
|
||||
%8 = OpTypeFunction %3
|
||||
%9 = OpConstant %4 2.0
|
||||
%10 = OpConstantComposite %3 %9 %9
|
||||
%11 = OpConstant %4 0.5
|
||||
%12 = OpConstantComposite %3 %11 %11
|
||||
%17 = OpTypeFunction %5
|
||||
%18 = OpConstant %5 1
|
||||
%19 = OpTypeVector %5 2
|
||||
%20 = OpConstantComposite %19 %18 %18
|
||||
%21 = OpTypeInt 32 0
|
||||
%22 = OpConstant %21 1
|
||||
%23 = OpTypeVector %21 3
|
||||
%6 = OpTypeInt 32 0
|
||||
%9 = OpTypeFunction %3
|
||||
%10 = OpConstant %4 2.0
|
||||
%11 = OpConstantComposite %3 %10 %10
|
||||
%12 = OpConstant %4 0.5
|
||||
%13 = OpConstantComposite %3 %12 %12
|
||||
%18 = OpTypeFunction %5
|
||||
%19 = OpConstant %5 1
|
||||
%20 = OpTypeVector %5 2
|
||||
%21 = OpConstantComposite %20 %19 %19
|
||||
%22 = OpConstant %6 1
|
||||
%23 = OpTypeVector %6 3
|
||||
%24 = OpConstantComposite %23 %22 %22 %22
|
||||
%25 = OpConstant %5 4
|
||||
%26 = OpTypeVector %5 4
|
||||
@ -30,39 +30,50 @@ OpExecutionMode %70 LocalSize 1 1 1
|
||||
%28 = OpConstant %5 2
|
||||
%29 = OpConstantComposite %26 %28 %28 %28 %28
|
||||
%32 = OpConstantNull %5
|
||||
%41 = OpConstantNull %21
|
||||
%71 = OpTypeFunction %2
|
||||
%7 = OpFunction %3 None %8
|
||||
%6 = OpLabel
|
||||
OpBranch %13
|
||||
%13 = OpLabel
|
||||
%14 = OpExtInst %3 %1 Fma %10 %12 %12
|
||||
OpReturnValue %14
|
||||
%41 = OpConstantNull %6
|
||||
%71 = OpTypeFunction %6
|
||||
%72 = OpConstant %6 2
|
||||
%73 = OpConstant %6 3
|
||||
%74 = OpConstant %6 4
|
||||
%75 = OpConstant %6 5
|
||||
%76 = OpConstant %6 6
|
||||
%77 = OpConstant %6 7
|
||||
%78 = OpConstant %6 8
|
||||
%83 = OpConstant %6 0
|
||||
%84 = OpConstant %6 16
|
||||
%85 = OpConstant %6 24
|
||||
%157 = OpTypeFunction %2
|
||||
%8 = OpFunction %3 None %9
|
||||
%7 = OpLabel
|
||||
OpBranch %14
|
||||
%14 = OpLabel
|
||||
%15 = OpExtInst %3 %1 Fma %11 %13 %13
|
||||
OpReturnValue %15
|
||||
OpFunctionEnd
|
||||
%16 = OpFunction %5 None %17
|
||||
%15 = OpLabel
|
||||
%17 = OpFunction %5 None %18
|
||||
%16 = OpLabel
|
||||
OpBranch %30
|
||||
%30 = OpLabel
|
||||
%33 = OpCompositeExtract %5 %20 0
|
||||
%34 = OpCompositeExtract %5 %20 0
|
||||
%33 = OpCompositeExtract %5 %21 0
|
||||
%34 = OpCompositeExtract %5 %21 0
|
||||
%35 = OpIMul %5 %33 %34
|
||||
%36 = OpIAdd %5 %32 %35
|
||||
%37 = OpCompositeExtract %5 %20 1
|
||||
%38 = OpCompositeExtract %5 %20 1
|
||||
%37 = OpCompositeExtract %5 %21 1
|
||||
%38 = OpCompositeExtract %5 %21 1
|
||||
%39 = OpIMul %5 %37 %38
|
||||
%31 = OpIAdd %5 %36 %39
|
||||
%42 = OpCompositeExtract %21 %24 0
|
||||
%43 = OpCompositeExtract %21 %24 0
|
||||
%44 = OpIMul %21 %42 %43
|
||||
%45 = OpIAdd %21 %41 %44
|
||||
%46 = OpCompositeExtract %21 %24 1
|
||||
%47 = OpCompositeExtract %21 %24 1
|
||||
%48 = OpIMul %21 %46 %47
|
||||
%49 = OpIAdd %21 %45 %48
|
||||
%50 = OpCompositeExtract %21 %24 2
|
||||
%51 = OpCompositeExtract %21 %24 2
|
||||
%52 = OpIMul %21 %50 %51
|
||||
%40 = OpIAdd %21 %49 %52
|
||||
%42 = OpCompositeExtract %6 %24 0
|
||||
%43 = OpCompositeExtract %6 %24 0
|
||||
%44 = OpIMul %6 %42 %43
|
||||
%45 = OpIAdd %6 %41 %44
|
||||
%46 = OpCompositeExtract %6 %24 1
|
||||
%47 = OpCompositeExtract %6 %24 1
|
||||
%48 = OpIMul %6 %46 %47
|
||||
%49 = OpIAdd %6 %45 %48
|
||||
%50 = OpCompositeExtract %6 %24 2
|
||||
%51 = OpCompositeExtract %6 %24 2
|
||||
%52 = OpIMul %6 %50 %51
|
||||
%40 = OpIAdd %6 %49 %52
|
||||
%54 = OpCompositeExtract %5 %27 0
|
||||
%55 = OpCompositeExtract %5 %29 0
|
||||
%56 = OpIMul %5 %54 %55
|
||||
@ -81,11 +92,90 @@ OpBranch %30
|
||||
%53 = OpIAdd %5 %65 %68
|
||||
OpReturnValue %53
|
||||
OpFunctionEnd
|
||||
%70 = OpFunction %2 None %71
|
||||
%70 = OpFunction %6 None %71
|
||||
%69 = OpLabel
|
||||
OpBranch %72
|
||||
%72 = OpLabel
|
||||
%73 = OpFunctionCall %3 %7
|
||||
%74 = OpFunctionCall %5 %16
|
||||
OpBranch %79
|
||||
%79 = OpLabel
|
||||
%81 = OpBitcast %5 %22
|
||||
%82 = OpBitcast %5 %72
|
||||
%86 = OpBitFieldSExtract %5 %81 %83 %78
|
||||
%87 = OpBitFieldSExtract %5 %82 %83 %78
|
||||
%88 = OpIMul %5 %86 %87
|
||||
%89 = OpIAdd %5 %32 %88
|
||||
%90 = OpBitFieldSExtract %5 %81 %78 %78
|
||||
%91 = OpBitFieldSExtract %5 %82 %78 %78
|
||||
%92 = OpIMul %5 %90 %91
|
||||
%93 = OpIAdd %5 %89 %92
|
||||
%94 = OpBitFieldSExtract %5 %81 %84 %78
|
||||
%95 = OpBitFieldSExtract %5 %82 %84 %78
|
||||
%96 = OpIMul %5 %94 %95
|
||||
%97 = OpIAdd %5 %93 %96
|
||||
%98 = OpBitFieldSExtract %5 %81 %85 %78
|
||||
%99 = OpBitFieldSExtract %5 %82 %85 %78
|
||||
%100 = OpIMul %5 %98 %99
|
||||
%80 = OpIAdd %5 %97 %100
|
||||
%102 = OpBitFieldUExtract %6 %73 %83 %78
|
||||
%103 = OpBitFieldUExtract %6 %74 %83 %78
|
||||
%104 = OpIMul %6 %102 %103
|
||||
%105 = OpIAdd %6 %41 %104
|
||||
%106 = OpBitFieldUExtract %6 %73 %78 %78
|
||||
%107 = OpBitFieldUExtract %6 %74 %78 %78
|
||||
%108 = OpIMul %6 %106 %107
|
||||
%109 = OpIAdd %6 %105 %108
|
||||
%110 = OpBitFieldUExtract %6 %73 %84 %78
|
||||
%111 = OpBitFieldUExtract %6 %74 %84 %78
|
||||
%112 = OpIMul %6 %110 %111
|
||||
%113 = OpIAdd %6 %109 %112
|
||||
%114 = OpBitFieldUExtract %6 %73 %85 %78
|
||||
%115 = OpBitFieldUExtract %6 %74 %85 %78
|
||||
%116 = OpIMul %6 %114 %115
|
||||
%101 = OpIAdd %6 %113 %116
|
||||
%117 = OpIAdd %6 %75 %101
|
||||
%118 = OpIAdd %6 %76 %101
|
||||
%120 = OpBitcast %5 %117
|
||||
%121 = OpBitcast %5 %118
|
||||
%122 = OpBitFieldSExtract %5 %120 %83 %78
|
||||
%123 = OpBitFieldSExtract %5 %121 %83 %78
|
||||
%124 = OpIMul %5 %122 %123
|
||||
%125 = OpIAdd %5 %32 %124
|
||||
%126 = OpBitFieldSExtract %5 %120 %78 %78
|
||||
%127 = OpBitFieldSExtract %5 %121 %78 %78
|
||||
%128 = OpIMul %5 %126 %127
|
||||
%129 = OpIAdd %5 %125 %128
|
||||
%130 = OpBitFieldSExtract %5 %120 %84 %78
|
||||
%131 = OpBitFieldSExtract %5 %121 %84 %78
|
||||
%132 = OpIMul %5 %130 %131
|
||||
%133 = OpIAdd %5 %129 %132
|
||||
%134 = OpBitFieldSExtract %5 %120 %85 %78
|
||||
%135 = OpBitFieldSExtract %5 %121 %85 %78
|
||||
%136 = OpIMul %5 %134 %135
|
||||
%119 = OpIAdd %5 %133 %136
|
||||
%137 = OpIAdd %6 %77 %101
|
||||
%138 = OpIAdd %6 %78 %101
|
||||
%140 = OpBitFieldUExtract %6 %137 %83 %78
|
||||
%141 = OpBitFieldUExtract %6 %138 %83 %78
|
||||
%142 = OpIMul %6 %140 %141
|
||||
%143 = OpIAdd %6 %41 %142
|
||||
%144 = OpBitFieldUExtract %6 %137 %78 %78
|
||||
%145 = OpBitFieldUExtract %6 %138 %78 %78
|
||||
%146 = OpIMul %6 %144 %145
|
||||
%147 = OpIAdd %6 %143 %146
|
||||
%148 = OpBitFieldUExtract %6 %137 %84 %78
|
||||
%149 = OpBitFieldUExtract %6 %138 %84 %78
|
||||
%150 = OpIMul %6 %148 %149
|
||||
%151 = OpIAdd %6 %147 %150
|
||||
%152 = OpBitFieldUExtract %6 %137 %85 %78
|
||||
%153 = OpBitFieldUExtract %6 %138 %85 %78
|
||||
%154 = OpIMul %6 %152 %153
|
||||
%139 = OpIAdd %6 %151 %154
|
||||
OpReturnValue %139
|
||||
OpFunctionEnd
|
||||
%156 = OpFunction %2 None %157
|
||||
%155 = OpLabel
|
||||
OpBranch %158
|
||||
%158 = OpLabel
|
||||
%159 = OpFunctionCall %3 %8
|
||||
%160 = OpFunctionCall %5 %17
|
||||
%161 = OpFunctionCall %6 %70
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
@ -16,9 +16,18 @@ fn test_integer_dot_product() -> i32 {
|
||||
return c_4_;
|
||||
}
|
||||
|
||||
fn test_packed_integer_dot_product() -> u32 {
|
||||
let c_5_ = dot4I8Packed(1u, 2u);
|
||||
let c_6_ = dot4U8Packed(3u, 4u);
|
||||
let c_7_ = dot4I8Packed((5u + c_6_), (6u + c_6_));
|
||||
let c_8_ = dot4U8Packed((7u + c_6_), (8u + c_6_));
|
||||
return c_8_;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(1, 1, 1)
|
||||
fn main() {
|
||||
let _e0 = test_fma();
|
||||
let _e1 = test_integer_dot_product();
|
||||
let _e2 = test_packed_integer_dot_product();
|
||||
return;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user