Optimize dot4{I, U}8Packed for all spv versions

Emit optimized code for `dot4{I, U}8Packed` regardless of SPIR-V version
as long as the required capabilities are available. On SPIR-V < 1.6,
require the extension "SPV_KHR_integer_dot_product" for this. On
SPIR-V >= 1.6, don't require the extension because the corresponding
capabilities are part of SPIR-V >= 1.6 proper.
This commit is contained in:
Robert Bamler 2025-04-27 15:37:47 +02:00 committed by Teodor Tanasoaia
parent 065d6546c4
commit bb83976ddb
11 changed files with 122 additions and 108 deletions

View File

@ -1143,17 +1143,21 @@ impl BlockContext<'_> {
),
},
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
if self.writer.lang_version() >= (1, 6)
&& self
.writer
.require_all(&[
spirv::Capability::DotProduct,
spirv::Capability::DotProductInput4x8BitPacked,
])
.is_ok()
if self
.writer
.require_all(&[
spirv::Capability::DotProduct,
spirv::Capability::DotProductInput4x8BitPacked,
])
.is_ok()
{
// Write optimized code using `PackedVectorFormat4x8Bit`.
self.writer.use_extension("SPV_KHR_integer_dot_product");
if self.writer.lang_version() < (1, 6) {
// SPIR-V 1.6 supports the required capabilities natively, so the extension
// is only required for earlier versions. See right column of
// <https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSDot>.
self.writer.use_extension("SPV_KHR_integer_dot_product");
}
let op = match fun {
Mf::Dot4I8Packed => spirv::Op::SDot,

View File

@ -0,0 +1,9 @@
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` by enabling the
# required capabilities on a SPIR-V version where these capabilities are only
# available via the extension "SPV_KHR_integer_dot_product".
targets = "SPIRV"
[spv]
capabilities = ["DotProduct", "DotProductInput4x8BitPacked"]
version = [1, 0]

View File

@ -0,0 +1,12 @@
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V and HLSL by
# using a version of SPIR-V / shader model that supports these without any extensions.
targets = "SPIRV | HLSL"
[spv]
# We also need to provide the corresponding capabilities (which are part of SPIR-V >= 1.6).
capabilities = ["DotProduct", "DotProductInput4x8BitPacked"]
version = [1, 6]
[hlsl]
shader_model = "V6_4"

View File

@ -0,0 +1,19 @@
fn test_packed_integer_dot_product() -> u32 {
let a_5 = 1u;
let b_5 = 2u;
let c_5: i32 = dot4I8Packed(a_5, b_5);
let a_6 = 3u;
let b_6 = 4u;
let c_6: u32 = dot4U8Packed(a_6, b_6);
// test baking of arguments
let c_7: i32 = dot4I8Packed(5u + c_6, 6u + c_6);
let c_8: u32 = dot4U8Packed(7u + c_6, 8u + c_6);
return c_8;
}
@compute @workgroup_size(1)
fn main() {
let c = test_packed_integer_dot_product();
}

View File

@ -1,11 +0,0 @@
# Explicitly turn on optimizations for `dot4I8Packed` and `dot4U8Packed`
# on SPIRV and HLSL.
targets = "SPIRV | HLSL"
[spv]
capabilities = ["DotProduct", "DotProductInput4x8BitPacked"]
version = [1, 6]
[hlsl]
shader_model = "V6_4"

View File

@ -1,5 +1,5 @@
; SPIR-V
; Version: 1.6
; Version: 1.0
; Generator: rspirv
; Bound: 30
OpCapability Shader

View File

@ -0,0 +1,45 @@
; SPIR-V
; Version: 1.6
; Generator: rspirv
; Bound: 30
OpCapability Shader
OpCapability DotProductKHR
OpCapability DotProductInput4x8BitPackedKHR
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %26 "main"
OpExecutionMode %26 LocalSize 1 1 1
%2 = OpTypeVoid
%3 = OpTypeInt 32 0
%6 = OpTypeFunction %3
%7 = OpConstant %3 1
%8 = OpConstant %3 2
%9 = OpConstant %3 3
%10 = OpConstant %3 4
%11 = OpConstant %3 5
%12 = OpConstant %3 6
%13 = OpConstant %3 7
%14 = OpConstant %3 8
%16 = OpTypeInt 32 1
%27 = OpTypeFunction %2
%5 = OpFunction %3 None %6
%4 = OpLabel
OpBranch %15
%15 = OpLabel
%17 = OpSDotKHR %16 %7 %8 PackedVectorFormat4x8BitKHR
%18 = OpUDotKHR %3 %9 %10 PackedVectorFormat4x8BitKHR
%19 = OpIAdd %3 %11 %18
%20 = OpIAdd %3 %12 %18
%21 = OpSDotKHR %16 %19 %20 PackedVectorFormat4x8BitKHR
%22 = OpIAdd %3 %13 %18
%23 = OpIAdd %3 %14 %18
%24 = OpUDotKHR %3 %22 %23 PackedVectorFormat4x8BitKHR
OpReturnValue %24
OpFunctionEnd
%26 = OpFunction %2 None %27
%25 = OpLabel
OpBranch %28
%28 = OpLabel
%29 = OpFunctionCall %3 %5
OpReturn
OpFunctionEnd

View File

@ -1,12 +1,15 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 162
; Bound: 95
OpCapability Shader
OpCapability DotProductKHR
OpCapability DotProductInput4x8BitPackedKHR
OpExtension "SPV_KHR_integer_dot_product"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %156 "main"
OpExecutionMode %156 LocalSize 1 1 1
OpEntryPoint GLCompute %89 "main"
OpExecutionMode %89 LocalSize 1 1 1
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpTypeVector %4 2
@ -39,10 +42,7 @@ OpExecutionMode %156 LocalSize 1 1 1
%76 = OpConstant %6 6
%77 = OpConstant %6 7
%78 = OpConstant %6 8
%83 = OpConstant %6 0
%84 = OpConstant %6 16
%85 = OpConstant %6 24
%157 = OpTypeFunction %2
%90 = OpTypeFunction %2
%8 = OpFunction %3 None %9
%7 = OpLabel
OpBranch %14
@ -96,86 +96,22 @@ OpFunctionEnd
%69 = OpLabel
OpBranch %79
%79 = OpLabel
%81 = OpBitcast %5 %22
%82 = OpBitcast %5 %72
%86 = OpBitFieldSExtract %5 %81 %83 %78
%87 = OpBitFieldSExtract %5 %82 %83 %78
%88 = OpIMul %5 %86 %87
%89 = OpIAdd %5 %32 %88
%90 = OpBitFieldSExtract %5 %81 %78 %78
%91 = OpBitFieldSExtract %5 %82 %78 %78
%92 = OpIMul %5 %90 %91
%93 = OpIAdd %5 %89 %92
%94 = OpBitFieldSExtract %5 %81 %84 %78
%95 = OpBitFieldSExtract %5 %82 %84 %78
%96 = OpIMul %5 %94 %95
%97 = OpIAdd %5 %93 %96
%98 = OpBitFieldSExtract %5 %81 %85 %78
%99 = OpBitFieldSExtract %5 %82 %85 %78
%100 = OpIMul %5 %98 %99
%80 = OpIAdd %5 %97 %100
%102 = OpBitFieldUExtract %6 %73 %83 %78
%103 = OpBitFieldUExtract %6 %74 %83 %78
%104 = OpIMul %6 %102 %103
%105 = OpIAdd %6 %41 %104
%106 = OpBitFieldUExtract %6 %73 %78 %78
%107 = OpBitFieldUExtract %6 %74 %78 %78
%108 = OpIMul %6 %106 %107
%109 = OpIAdd %6 %105 %108
%110 = OpBitFieldUExtract %6 %73 %84 %78
%111 = OpBitFieldUExtract %6 %74 %84 %78
%112 = OpIMul %6 %110 %111
%113 = OpIAdd %6 %109 %112
%114 = OpBitFieldUExtract %6 %73 %85 %78
%115 = OpBitFieldUExtract %6 %74 %85 %78
%116 = OpIMul %6 %114 %115
%101 = OpIAdd %6 %113 %116
%117 = OpIAdd %6 %75 %101
%118 = OpIAdd %6 %76 %101
%120 = OpBitcast %5 %117
%121 = OpBitcast %5 %118
%122 = OpBitFieldSExtract %5 %120 %83 %78
%123 = OpBitFieldSExtract %5 %121 %83 %78
%124 = OpIMul %5 %122 %123
%125 = OpIAdd %5 %32 %124
%126 = OpBitFieldSExtract %5 %120 %78 %78
%127 = OpBitFieldSExtract %5 %121 %78 %78
%128 = OpIMul %5 %126 %127
%129 = OpIAdd %5 %125 %128
%130 = OpBitFieldSExtract %5 %120 %84 %78
%131 = OpBitFieldSExtract %5 %121 %84 %78
%132 = OpIMul %5 %130 %131
%133 = OpIAdd %5 %129 %132
%134 = OpBitFieldSExtract %5 %120 %85 %78
%135 = OpBitFieldSExtract %5 %121 %85 %78
%136 = OpIMul %5 %134 %135
%119 = OpIAdd %5 %133 %136
%137 = OpIAdd %6 %77 %101
%138 = OpIAdd %6 %78 %101
%140 = OpBitFieldUExtract %6 %137 %83 %78
%141 = OpBitFieldUExtract %6 %138 %83 %78
%142 = OpIMul %6 %140 %141
%143 = OpIAdd %6 %41 %142
%144 = OpBitFieldUExtract %6 %137 %78 %78
%145 = OpBitFieldUExtract %6 %138 %78 %78
%146 = OpIMul %6 %144 %145
%147 = OpIAdd %6 %143 %146
%148 = OpBitFieldUExtract %6 %137 %84 %78
%149 = OpBitFieldUExtract %6 %138 %84 %78
%150 = OpIMul %6 %148 %149
%151 = OpIAdd %6 %147 %150
%152 = OpBitFieldUExtract %6 %137 %85 %78
%153 = OpBitFieldUExtract %6 %138 %85 %78
%154 = OpIMul %6 %152 %153
%139 = OpIAdd %6 %151 %154
OpReturnValue %139
%80 = OpSDotKHR %5 %22 %72 PackedVectorFormat4x8BitKHR
%81 = OpUDotKHR %6 %73 %74 PackedVectorFormat4x8BitKHR
%82 = OpIAdd %6 %75 %81
%83 = OpIAdd %6 %76 %81
%84 = OpSDotKHR %5 %82 %83 PackedVectorFormat4x8BitKHR
%85 = OpIAdd %6 %77 %81
%86 = OpIAdd %6 %78 %81
%87 = OpUDotKHR %6 %85 %86 PackedVectorFormat4x8BitKHR
OpReturnValue %87
OpFunctionEnd
%156 = OpFunction %2 None %157
%155 = OpLabel
OpBranch %158
%158 = OpLabel
%159 = OpFunctionCall %3 %8
%160 = OpFunctionCall %5 %17
%161 = OpFunctionCall %6 %70
%89 = OpFunction %2 None %90
%88 = OpLabel
OpBranch %91
%91 = OpLabel
%92 = OpFunctionCall %3 %8
%93 = OpFunctionCall %5 %17
%94 = OpFunctionCall %6 %70
OpReturn
OpFunctionEnd