[naga spv] Split workgroup and subgroup memory semantics in Control Barriers

This commit is contained in:
Phena Ildanach 2025-04-24 18:33:39 -05:00 committed by Teodor Tanasoaia
parent 1e031e7a02
commit dd273fd7e2
11 changed files with 157 additions and 45 deletions

View File

@ -1714,6 +1714,8 @@ impl Writer {
pub(super) fn write_control_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
spirv::Scope::Device
} else if flags.contains(crate::Barrier::SUB_GROUP) {
spirv::Scope::Subgroup
} else {
spirv::Scope::Workgroup
};
@ -1726,6 +1728,10 @@ impl Writer {
spirv::MemorySemantics::WORKGROUP_MEMORY,
flags.contains(crate::Barrier::WORK_GROUP),
);
semantics.set(
spirv::MemorySemantics::SUBGROUP_MEMORY,
flags.contains(crate::Barrier::SUB_GROUP),
);
semantics.set(
spirv::MemorySemantics::IMAGE_MEMORY,
flags.contains(crate::Barrier::TEXTURE),

View File

@ -3850,7 +3850,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let semantics = resolve_constant(ctx.gctx(), &semantics_const.inner)
.ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?;
if exec_scope == spirv::Scope::Workgroup as u32 {
if exec_scope == spirv::Scope::Workgroup as u32
|| exec_scope == spirv::Scope::Subgroup as u32
{
let mut flags = crate::Barrier::empty();
flags.set(
crate::Barrier::STORAGE,
@ -3858,11 +3860,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
);
flags.set(
crate::Barrier::WORK_GROUP,
semantics
& (spirv::MemorySemantics::SUBGROUP_MEMORY
| spirv::MemorySemantics::WORKGROUP_MEMORY)
.bits()
!= 0,
semantics & (spirv::MemorySemantics::WORKGROUP_MEMORY).bits() != 0,
);
flags.set(
crate::Barrier::SUB_GROUP,
semantics & spirv::MemorySemantics::SUBGROUP_MEMORY.bits() != 0,
);
flags.set(
crate::Barrier::TEXTURE,

View File

@ -0,0 +1,20 @@
; SPIR-V
; Version: 1.5
; Generator: Google rspirv; 0
; Bound: 14
; Schema: 0
OpCapability Shader
OpMemoryModel Logical Simple
OpEntryPoint GLCompute %1 "main"
OpExecutionMode %1 LocalSize 64 1 1
%void = OpTypeVoid
%6 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%uint_3 = OpConstant %uint 3
%uint_136 = OpConstant %uint 136
%1 = OpFunction %void None %6
%13 = OpLabel
OpMemoryBarrier %uint_3 %uint_136
OpControlBarrier %uint_3 %uint_3 %uint_136
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,5 @@
god_mode = true
targets = "WGSL | SPIRV | GLSL | METAL"
[msl]
lang_version = [2, 0]

View File

@ -0,0 +1,20 @@
#version 310 es
precision highp float;
precision highp int;
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
void function() {
subgroupMemoryBarrier();
barrier();
subgroupMemoryBarrier();
barrier();
return;
}
void main() {
function();
}

View File

@ -21,6 +21,7 @@ uint global_3 = 0u;
void function() {
uint _e5 = global_2;
uint _e6 = global_3;
barrier();
uvec4 _e9 = subgroupBallot(((_e6 & 1u) == 1u));
uvec4 _e10 = subgroupBallot(true);
bool _e12 = subgroupAll((_e6 != 0u));

View File

@ -0,0 +1,18 @@
// language: metal2.0
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
void function(
) {
metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup);
metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup);
return;
}
kernel void main_(
) {
function();
}

View File

@ -11,6 +11,7 @@ void function(
) {
uint _e5 = global_2;
uint _e6 = global_3;
metal::threadgroup_barrier(metal::mem_flags::mem_none);
metal::uint4 unnamed = metal::uint4((uint64_t)metal::simd_ballot((_e6 & 1u) == 1u), 0, 0, 0);
metal::uint4 unnamed_1 = metal::uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0);
bool unnamed_2 = metal::simd_all(_e6 != 0u);

View File

@ -0,0 +1,29 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 14
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %11 "main"
OpExecutionMode %11 LocalSize 64 1 1
%2 = OpTypeVoid
%5 = OpTypeFunction %2
%8 = OpTypeInt 32 0
%7 = OpConstant %8 3
%9 = OpConstant %8 136
%4 = OpFunction %2 None %5
%3 = OpLabel
OpBranch %6
%6 = OpLabel
OpMemoryBarrier %7 %9
OpControlBarrier %7 %7 %9
OpReturn
OpFunctionEnd
%11 = OpFunction %2 None %5
%10 = OpLabel
OpBranch %12
%12 = OpLabel
%13 = OpFunctionCall %2 %4
OpReturn
OpFunctionEnd

View File

@ -34,10 +34,10 @@ OpDecorate %15 BuiltIn SubgroupLocalInvocationId
%20 = OpConstant %3 0
%21 = OpConstant %3 4
%23 = OpConstant %3 3
%24 = OpConstant %3 2
%25 = OpConstant %3 8
%28 = OpTypeVector %3 4
%30 = OpConstantTrue %5
%24 = OpConstant %3 136
%27 = OpTypeVector %3 4
%29 = OpConstantTrue %5
%61 = OpConstant %3 2
%17 = OpFunction %2 None %18
%6 = OpLabel
%10 = OpLoad %3 %8
@ -47,40 +47,40 @@ OpDecorate %15 BuiltIn SubgroupLocalInvocationId
%16 = OpLoad %3 %15
OpBranch %22
%22 = OpLabel
OpControlBarrier %23 %24 %25
%26 = OpBitwiseAnd %3 %16 %19
%27 = OpIEqual %5 %26 %19
%29 = OpGroupNonUniformBallot %28 %23 %27
%31 = OpGroupNonUniformBallot %28 %23 %30
%32 = OpINotEqual %5 %16 %20
%33 = OpGroupNonUniformAll %5 %23 %32
%34 = OpIEqual %5 %16 %20
%35 = OpGroupNonUniformAny %5 %23 %34
%36 = OpGroupNonUniformIAdd %3 %23 Reduce %16
%37 = OpGroupNonUniformIMul %3 %23 Reduce %16
%38 = OpGroupNonUniformUMin %3 %23 Reduce %16
%39 = OpGroupNonUniformUMax %3 %23 Reduce %16
%40 = OpGroupNonUniformBitwiseAnd %3 %23 Reduce %16
%41 = OpGroupNonUniformBitwiseOr %3 %23 Reduce %16
%42 = OpGroupNonUniformBitwiseXor %3 %23 Reduce %16
%43 = OpGroupNonUniformIAdd %3 %23 ExclusiveScan %16
%44 = OpGroupNonUniformIMul %3 %23 ExclusiveScan %16
%45 = OpGroupNonUniformIAdd %3 %23 InclusiveScan %16
%46 = OpGroupNonUniformIMul %3 %23 InclusiveScan %16
%47 = OpGroupNonUniformBroadcastFirst %3 %23 %16
%48 = OpGroupNonUniformShuffle %3 %23 %16 %21
%49 = OpCompositeExtract %3 %7 1
%50 = OpISub %3 %49 %19
%51 = OpISub %3 %50 %16
%52 = OpGroupNonUniformShuffle %3 %23 %16 %51
%53 = OpGroupNonUniformShuffleDown %3 %23 %16 %19
%54 = OpGroupNonUniformShuffleUp %3 %23 %16 %19
%55 = OpCompositeExtract %3 %7 1
%56 = OpISub %3 %55 %19
%57 = OpGroupNonUniformShuffleXor %3 %23 %16 %56
%58 = OpGroupNonUniformQuadBroadcast %3 %23 %16 %21
%59 = OpGroupNonUniformQuadSwap %3 %23 %16 %20
%60 = OpGroupNonUniformQuadSwap %3 %23 %16 %19
%61 = OpGroupNonUniformQuadSwap %3 %23 %16 %24
OpControlBarrier %23 %23 %24
%25 = OpBitwiseAnd %3 %16 %19
%26 = OpIEqual %5 %25 %19
%28 = OpGroupNonUniformBallot %27 %23 %26
%30 = OpGroupNonUniformBallot %27 %23 %29
%31 = OpINotEqual %5 %16 %20
%32 = OpGroupNonUniformAll %5 %23 %31
%33 = OpIEqual %5 %16 %20
%34 = OpGroupNonUniformAny %5 %23 %33
%35 = OpGroupNonUniformIAdd %3 %23 Reduce %16
%36 = OpGroupNonUniformIMul %3 %23 Reduce %16
%37 = OpGroupNonUniformUMin %3 %23 Reduce %16
%38 = OpGroupNonUniformUMax %3 %23 Reduce %16
%39 = OpGroupNonUniformBitwiseAnd %3 %23 Reduce %16
%40 = OpGroupNonUniformBitwiseOr %3 %23 Reduce %16
%41 = OpGroupNonUniformBitwiseXor %3 %23 Reduce %16
%42 = OpGroupNonUniformIAdd %3 %23 ExclusiveScan %16
%43 = OpGroupNonUniformIMul %3 %23 ExclusiveScan %16
%44 = OpGroupNonUniformIAdd %3 %23 InclusiveScan %16
%45 = OpGroupNonUniformIMul %3 %23 InclusiveScan %16
%46 = OpGroupNonUniformBroadcastFirst %3 %23 %16
%47 = OpGroupNonUniformShuffle %3 %23 %16 %21
%48 = OpCompositeExtract %3 %7 1
%49 = OpISub %3 %48 %19
%50 = OpISub %3 %49 %16
%51 = OpGroupNonUniformShuffle %3 %23 %16 %50
%52 = OpGroupNonUniformShuffleDown %3 %23 %16 %19
%53 = OpGroupNonUniformShuffleUp %3 %23 %16 %19
%54 = OpCompositeExtract %3 %7 1
%55 = OpISub %3 %54 %19
%56 = OpGroupNonUniformShuffleXor %3 %23 %16 %55
%57 = OpGroupNonUniformQuadBroadcast %3 %23 %16 %21
%58 = OpGroupNonUniformQuadSwap %3 %23 %16 %20
%59 = OpGroupNonUniformQuadSwap %3 %23 %16 %19
%60 = OpGroupNonUniformQuadSwap %3 %23 %16 %61
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,10 @@
fn function() {
subgroupBarrier();
subgroupBarrier();
return;
}
@compute @workgroup_size(64, 1, 1)
fn main() {
function();
}