Use intrinsics for dot4{I, U}8Packed on spv

This commit is contained in:
Robert Bamler 2025-04-18 18:42:03 +02:00 committed by Teodor Tanasoaia
parent fe05765602
commit 5b20979e9b
7 changed files with 310 additions and 139 deletions

View File

@ -1143,59 +1143,88 @@ impl BlockContext<'_> {
),
},
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
// TODO: consider using packed integer dot product if PackedVectorFormat4x8Bit is available
let (extract_op, arg0_id, arg1_id) = match fun {
Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
Mf::Dot4I8Packed => {
// Convert both packed arguments to signed integers so that we can apply the
// `BitFieldSExtract` operation on them in `write_dot_product` below.
let new_arg0_id = self.gen_id();
block.body.push(Instruction::unary(
spirv::Op::Bitcast,
result_type_id,
new_arg0_id,
arg0_id,
));
if self
.writer
.require_all(&[
spirv::Capability::DotProduct,
spirv::Capability::DotProductInput4x8BitPacked,
])
.is_ok()
{
// Write optimized code using `PackedVectorFormat4x8Bit`.
self.writer.use_extension("SPV_KHR_integer_dot_product");
let new_arg1_id = self.gen_id();
block.body.push(Instruction::unary(
spirv::Op::Bitcast,
result_type_id,
new_arg1_id,
arg1_id,
));
let op = match fun {
Mf::Dot4I8Packed => spirv::Op::SDot,
Mf::Dot4U8Packed => spirv::Op::UDot,
_ => unreachable!(),
};
(spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
}
_ => unreachable!(),
};
block.body.push(Instruction::ternary(
op,
result_type_id,
id,
arg0_id,
arg1_id,
spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
));
} else {
// Fall back to a polyfill since `PackedVectorFormat4x8Bit` is not available.
let (extract_op, arg0_id, arg1_id) = match fun {
Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
Mf::Dot4I8Packed => {
// Convert both packed arguments to signed integers so that we can apply the
// `BitFieldSExtract` operation on them in `write_dot_product` below.
let new_arg0_id = self.gen_id();
block.body.push(Instruction::unary(
spirv::Op::Bitcast,
result_type_id,
new_arg0_id,
arg0_id,
));
let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
let new_arg1_id = self.gen_id();
block.body.push(Instruction::unary(
spirv::Op::Bitcast,
result_type_id,
new_arg1_id,
arg1_id,
));
const VEC_LENGTH: u8 = 4;
let bit_shifts: [_; VEC_LENGTH as usize] = core::array::from_fn(|index| {
self.writer
.get_constant_scalar(crate::Literal::U32(index as u32 * 8))
});
(spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
}
_ => unreachable!(),
};
let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
const VEC_LENGTH: u8 = 4;
let bit_shifts: [_; VEC_LENGTH as usize] =
core::array::from_fn(|index| {
self.writer
.get_constant_scalar(crate::Literal::U32(index as u32 * 8))
});
self.write_dot_product(
id,
result_type_id,
arg0_id,
arg1_id,
VEC_LENGTH as Word,
block,
|result_id, composite_id, index| {
Instruction::ternary(
extract_op,
result_type_id,
result_id,
composite_id,
bit_shifts[index as usize],
eight,
)
},
);
}
self.write_dot_product(
id,
result_type_id,
arg0_id,
arg1_id,
VEC_LENGTH as Word,
block,
|result_id, composite_id, index| {
Instruction::ternary(
extract_op,
result_type_id,
result_id,
composite_id,
bit_shifts[index as usize],
eight,
)
},
);
self.cached[expr_handle] = id;
return Ok(());
}

View File

@ -202,6 +202,43 @@ impl Writer {
}
}
/// Indicate that the code requires all of the listed capabilities.
///
/// If all entries of `capabilities` appear in the available capabilities
/// specified in the [`Options`] from which this `Writer` was created
/// (including the case where [`Options::capabilities`] is `None`), add
/// them all to this `Writer`'s [`capabilities_used`] table, and return
/// `Ok(())`. If at least one of the listed capabilities is not available,
/// do not add anything to the `capabilities_used` table, and return the
/// first unavailable requested capability, wrapped in `Err()`.
///
/// This method is does not return an [`enum@Error`] in case of failure
/// because it may be used in cases where the caller can recover (e.g.,
/// with a polyfill) if the requested capabilities are not available. In
/// this case, it would be unnecessary work to find *all* the unavailable
/// requested capabilities, and to allocate a `Vec` for them, just so we
/// could return an [`Error::MissingCapabilities`]).
///
/// [`capabilities_used`]: Writer::capabilities_used
pub(super) fn require_all(
&mut self,
capabilities: &[spirv::Capability],
) -> Result<(), spirv::Capability> {
if let Some(ref available) = self.capabilities_available {
for requested in capabilities {
if !available.contains(requested) {
return Err(*requested);
}
}
}
for requested in capabilities {
self.capabilities_used.insert(*requested);
}
Ok(())
}
/// Indicate that the code uses the given extension.
pub(super) fn use_extension(&mut self, extension: &'static str) {
self.extensions_used.insert(extension);

View File

@ -1,6 +1,10 @@
# Explicitly turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on HLSL.
# Explicitly turn on optimizations for `dot4I8Packed` and `dot4U8Packed`
# on SPIRV and HLSL.
targets = "HLSL"
targets = "SPIRV | HLSL"
[spv]
capabilities = ["DotProduct", "DotProductInput4x8BitPacked"]
[hlsl]
shader_model = "V6_4"

View File

@ -1,6 +1,13 @@
# Explicitly turn off optimizations for `dot4I8Packed` and `dot4U8Packed` on HLSL.
# Explicitly turn off optimizations for `dot4I8Packed` and `dot4U8Packed`
# on SPIRV and HLSL.
targets = "HLSL"
targets = "SPIRV | HLSL"
[spv]
# Provide some unrelated capability because an empty list of capabilities would
# get mapped to `None`, which would then be interpreted as "all capabilities
# are available".
capabilities = ["Matrix"]
[hlsl]
shader_model = "V6_3"

View File

@ -0,0 +1,46 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 30
OpCapability Shader
OpCapability DotProductKHR
OpCapability DotProductInput4x8BitPackedKHR
OpExtension "SPV_KHR_integer_dot_product"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %26 "main"
OpExecutionMode %26 LocalSize 1 1 1
%2 = OpTypeVoid
%3 = OpTypeInt 32 0
%6 = OpTypeFunction %3
%7 = OpConstant %3 1
%8 = OpConstant %3 2
%9 = OpConstant %3 3
%10 = OpConstant %3 4
%11 = OpConstant %3 5
%12 = OpConstant %3 6
%13 = OpConstant %3 7
%14 = OpConstant %3 8
%16 = OpTypeInt 32 1
%27 = OpTypeFunction %2
%5 = OpFunction %3 None %6
%4 = OpLabel
OpBranch %15
%15 = OpLabel
%17 = OpSDotKHR %16 %7 %8 PackedVectorFormat4x8BitKHR
%18 = OpUDotKHR %3 %9 %10 PackedVectorFormat4x8BitKHR
%19 = OpIAdd %3 %11 %18
%20 = OpIAdd %3 %12 %18
%21 = OpSDotKHR %16 %19 %20 PackedVectorFormat4x8BitKHR
%22 = OpIAdd %3 %13 %18
%23 = OpIAdd %3 %14 %18
%24 = OpUDotKHR %3 %22 %23 PackedVectorFormat4x8BitKHR
OpReturnValue %24
OpFunctionEnd
%26 = OpFunction %2 None %27
%25 = OpLabel
OpBranch %28
%28 = OpLabel
%29 = OpFunctionCall %3 %5
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,112 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 99
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %95 "main"
OpExecutionMode %95 LocalSize 1 1 1
%2 = OpTypeVoid
%3 = OpTypeInt 32 0
%6 = OpTypeFunction %3
%7 = OpConstant %3 1
%8 = OpConstant %3 2
%9 = OpConstant %3 3
%10 = OpConstant %3 4
%11 = OpConstant %3 5
%12 = OpConstant %3 6
%13 = OpConstant %3 7
%14 = OpConstant %3 8
%16 = OpTypeInt 32 1
%20 = OpConstant %3 0
%21 = OpConstant %3 16
%22 = OpConstant %3 24
%23 = OpConstantNull %16
%40 = OpConstantNull %3
%96 = OpTypeFunction %2
%5 = OpFunction %3 None %6
%4 = OpLabel
OpBranch %15
%15 = OpLabel
%18 = OpBitcast %16 %7
%19 = OpBitcast %16 %8
%24 = OpBitFieldSExtract %16 %18 %20 %14
%25 = OpBitFieldSExtract %16 %19 %20 %14
%26 = OpIMul %16 %24 %25
%27 = OpIAdd %16 %23 %26
%28 = OpBitFieldSExtract %16 %18 %14 %14
%29 = OpBitFieldSExtract %16 %19 %14 %14
%30 = OpIMul %16 %28 %29
%31 = OpIAdd %16 %27 %30
%32 = OpBitFieldSExtract %16 %18 %21 %14
%33 = OpBitFieldSExtract %16 %19 %21 %14
%34 = OpIMul %16 %32 %33
%35 = OpIAdd %16 %31 %34
%36 = OpBitFieldSExtract %16 %18 %22 %14
%37 = OpBitFieldSExtract %16 %19 %22 %14
%38 = OpIMul %16 %36 %37
%17 = OpIAdd %16 %35 %38
%41 = OpBitFieldUExtract %3 %9 %20 %14
%42 = OpBitFieldUExtract %3 %10 %20 %14
%43 = OpIMul %3 %41 %42
%44 = OpIAdd %3 %40 %43
%45 = OpBitFieldUExtract %3 %9 %14 %14
%46 = OpBitFieldUExtract %3 %10 %14 %14
%47 = OpIMul %3 %45 %46
%48 = OpIAdd %3 %44 %47
%49 = OpBitFieldUExtract %3 %9 %21 %14
%50 = OpBitFieldUExtract %3 %10 %21 %14
%51 = OpIMul %3 %49 %50
%52 = OpIAdd %3 %48 %51
%53 = OpBitFieldUExtract %3 %9 %22 %14
%54 = OpBitFieldUExtract %3 %10 %22 %14
%55 = OpIMul %3 %53 %54
%39 = OpIAdd %3 %52 %55
%56 = OpIAdd %3 %11 %39
%57 = OpIAdd %3 %12 %39
%59 = OpBitcast %16 %56
%60 = OpBitcast %16 %57
%61 = OpBitFieldSExtract %16 %59 %20 %14
%62 = OpBitFieldSExtract %16 %60 %20 %14
%63 = OpIMul %16 %61 %62
%64 = OpIAdd %16 %23 %63
%65 = OpBitFieldSExtract %16 %59 %14 %14
%66 = OpBitFieldSExtract %16 %60 %14 %14
%67 = OpIMul %16 %65 %66
%68 = OpIAdd %16 %64 %67
%69 = OpBitFieldSExtract %16 %59 %21 %14
%70 = OpBitFieldSExtract %16 %60 %21 %14
%71 = OpIMul %16 %69 %70
%72 = OpIAdd %16 %68 %71
%73 = OpBitFieldSExtract %16 %59 %22 %14
%74 = OpBitFieldSExtract %16 %60 %22 %14
%75 = OpIMul %16 %73 %74
%58 = OpIAdd %16 %72 %75
%76 = OpIAdd %3 %13 %39
%77 = OpIAdd %3 %14 %39
%79 = OpBitFieldUExtract %3 %76 %20 %14
%80 = OpBitFieldUExtract %3 %77 %20 %14
%81 = OpIMul %3 %79 %80
%82 = OpIAdd %3 %40 %81
%83 = OpBitFieldUExtract %3 %76 %14 %14
%84 = OpBitFieldUExtract %3 %77 %14 %14
%85 = OpIMul %3 %83 %84
%86 = OpIAdd %3 %82 %85
%87 = OpBitFieldUExtract %3 %76 %21 %14
%88 = OpBitFieldUExtract %3 %77 %21 %14
%89 = OpIMul %3 %87 %88
%90 = OpIAdd %3 %86 %89
%91 = OpBitFieldUExtract %3 %76 %22 %14
%92 = OpBitFieldUExtract %3 %77 %22 %14
%93 = OpIMul %3 %91 %92
%78 = OpIAdd %3 %90 %93
OpReturnValue %78
OpFunctionEnd
%95 = OpFunction %2 None %96
%94 = OpLabel
OpBranch %97
%97 = OpLabel
%98 = OpFunctionCall %3 %5
OpReturn
OpFunctionEnd

View File

@ -1,12 +1,15 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 162
; Bound: 95
OpCapability Shader
OpCapability DotProductKHR
OpCapability DotProductInput4x8BitPackedKHR
OpExtension "SPV_KHR_integer_dot_product"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %156 "main"
OpExecutionMode %156 LocalSize 1 1 1
OpEntryPoint GLCompute %89 "main"
OpExecutionMode %89 LocalSize 1 1 1
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpTypeVector %4 2
@ -39,10 +42,7 @@ OpExecutionMode %156 LocalSize 1 1 1
%76 = OpConstant %6 6
%77 = OpConstant %6 7
%78 = OpConstant %6 8
%83 = OpConstant %6 0
%84 = OpConstant %6 16
%85 = OpConstant %6 24
%157 = OpTypeFunction %2
%90 = OpTypeFunction %2
%8 = OpFunction %3 None %9
%7 = OpLabel
OpBranch %14
@ -96,86 +96,22 @@ OpFunctionEnd
%69 = OpLabel
OpBranch %79
%79 = OpLabel
%81 = OpBitcast %5 %22
%82 = OpBitcast %5 %72
%86 = OpBitFieldSExtract %5 %81 %83 %78
%87 = OpBitFieldSExtract %5 %82 %83 %78
%88 = OpIMul %5 %86 %87
%89 = OpIAdd %5 %32 %88
%90 = OpBitFieldSExtract %5 %81 %78 %78
%91 = OpBitFieldSExtract %5 %82 %78 %78
%92 = OpIMul %5 %90 %91
%93 = OpIAdd %5 %89 %92
%94 = OpBitFieldSExtract %5 %81 %84 %78
%95 = OpBitFieldSExtract %5 %82 %84 %78
%96 = OpIMul %5 %94 %95
%97 = OpIAdd %5 %93 %96
%98 = OpBitFieldSExtract %5 %81 %85 %78
%99 = OpBitFieldSExtract %5 %82 %85 %78
%100 = OpIMul %5 %98 %99
%80 = OpIAdd %5 %97 %100
%102 = OpBitFieldUExtract %6 %73 %83 %78
%103 = OpBitFieldUExtract %6 %74 %83 %78
%104 = OpIMul %6 %102 %103
%105 = OpIAdd %6 %41 %104
%106 = OpBitFieldUExtract %6 %73 %78 %78
%107 = OpBitFieldUExtract %6 %74 %78 %78
%108 = OpIMul %6 %106 %107
%109 = OpIAdd %6 %105 %108
%110 = OpBitFieldUExtract %6 %73 %84 %78
%111 = OpBitFieldUExtract %6 %74 %84 %78
%112 = OpIMul %6 %110 %111
%113 = OpIAdd %6 %109 %112
%114 = OpBitFieldUExtract %6 %73 %85 %78
%115 = OpBitFieldUExtract %6 %74 %85 %78
%116 = OpIMul %6 %114 %115
%101 = OpIAdd %6 %113 %116
%117 = OpIAdd %6 %75 %101
%118 = OpIAdd %6 %76 %101
%120 = OpBitcast %5 %117
%121 = OpBitcast %5 %118
%122 = OpBitFieldSExtract %5 %120 %83 %78
%123 = OpBitFieldSExtract %5 %121 %83 %78
%124 = OpIMul %5 %122 %123
%125 = OpIAdd %5 %32 %124
%126 = OpBitFieldSExtract %5 %120 %78 %78
%127 = OpBitFieldSExtract %5 %121 %78 %78
%128 = OpIMul %5 %126 %127
%129 = OpIAdd %5 %125 %128
%130 = OpBitFieldSExtract %5 %120 %84 %78
%131 = OpBitFieldSExtract %5 %121 %84 %78
%132 = OpIMul %5 %130 %131
%133 = OpIAdd %5 %129 %132
%134 = OpBitFieldSExtract %5 %120 %85 %78
%135 = OpBitFieldSExtract %5 %121 %85 %78
%136 = OpIMul %5 %134 %135
%119 = OpIAdd %5 %133 %136
%137 = OpIAdd %6 %77 %101
%138 = OpIAdd %6 %78 %101
%140 = OpBitFieldUExtract %6 %137 %83 %78
%141 = OpBitFieldUExtract %6 %138 %83 %78
%142 = OpIMul %6 %140 %141
%143 = OpIAdd %6 %41 %142
%144 = OpBitFieldUExtract %6 %137 %78 %78
%145 = OpBitFieldUExtract %6 %138 %78 %78
%146 = OpIMul %6 %144 %145
%147 = OpIAdd %6 %143 %146
%148 = OpBitFieldUExtract %6 %137 %84 %78
%149 = OpBitFieldUExtract %6 %138 %84 %78
%150 = OpIMul %6 %148 %149
%151 = OpIAdd %6 %147 %150
%152 = OpBitFieldUExtract %6 %137 %85 %78
%153 = OpBitFieldUExtract %6 %138 %85 %78
%154 = OpIMul %6 %152 %153
%139 = OpIAdd %6 %151 %154
OpReturnValue %139
%80 = OpSDotKHR %5 %22 %72 PackedVectorFormat4x8BitKHR
%81 = OpUDotKHR %6 %73 %74 PackedVectorFormat4x8BitKHR
%82 = OpIAdd %6 %75 %81
%83 = OpIAdd %6 %76 %81
%84 = OpSDotKHR %5 %82 %83 PackedVectorFormat4x8BitKHR
%85 = OpIAdd %6 %77 %81
%86 = OpIAdd %6 %78 %81
%87 = OpUDotKHR %6 %85 %86 PackedVectorFormat4x8BitKHR
OpReturnValue %87
OpFunctionEnd
%156 = OpFunction %2 None %157
%155 = OpLabel
OpBranch %158
%158 = OpLabel
%159 = OpFunctionCall %3 %8
%160 = OpFunctionCall %5 %17
%161 = OpFunctionCall %6 %70
%89 = OpFunction %2 None %90
%88 = OpLabel
OpBranch %91
%91 = OpLabel
%92 = OpFunctionCall %3 %8
%93 = OpFunctionCall %5 %17
%94 = OpFunctionCall %6 %70
OpReturn
OpFunctionEnd