[naga spv-in] Add support for Memory Barriers

This commit is contained in:
Phena Ildanach 2025-04-24 18:30:44 -05:00 committed by Teodor Tanasoaia
parent 65337894f6
commit 1e031e7a02
26 changed files with 339 additions and 32 deletions

View File

@ -112,7 +112,8 @@ impl StatementGraph {
}
"Continue"
}
S::Barrier(_flags) => "Barrier",
S::ControlBarrier(_flags) => "ControlBarrier",
S::MemoryBarrier(_flags) => "MemoryBarrier",
S::Block(ref b) => {
let (other, last) = self.add(b, targets);
self.flow.push((id, other, ""));

View File

@ -2512,7 +2512,7 @@ impl<'a, W: Write> Writer<'a, W> {
// keyword which ceases all further processing in a fragment shader, it's called OpKill
// in spir-v that's why it's called `Statement::Kill`
Statement::Kill => writeln!(self.out, "{level}discard;")?,
Statement::Barrier(flags) => {
Statement::ControlBarrier(flags) | Statement::MemoryBarrier(flags) => {
self.write_barrier(flags, level)?;
}
// Stores in glsl are just variable assignments written as `pointer = value;`

View File

@ -1647,7 +1647,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
writeln!(self.out, "{level}}}")?;
self.write_barrier(crate::Barrier::WORK_GROUP, level)
self.write_control_barrier(crate::Barrier::WORK_GROUP, level)
}
/// Helper method used to write switches
@ -2291,8 +2291,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, "{level}continue;")?
}
}
Statement::Barrier(barrier) => {
self.write_barrier(barrier, level)?;
Statement::ControlBarrier(barrier) => {
self.write_control_barrier(barrier, level)?;
}
Statement::MemoryBarrier(barrier) => {
self.write_memory_barrier(barrier, level)?;
}
Statement::ImageStore {
image,
@ -2464,12 +2467,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, ");")?;
}
Statement::WorkGroupUniformLoad { pointer, result } => {
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
write!(self.out, "{level}")?;
let name = Baked(result).to_string();
self.write_named_expr(module, pointer, name, result, func_ctx)?;
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
}
Statement::Switch {
selector,
@ -4287,7 +4290,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Ok(())
}
fn write_barrier(&mut self, barrier: crate::Barrier, level: back::Level) -> BackendResult {
fn write_control_barrier(
&mut self,
barrier: crate::Barrier,
level: back::Level,
) -> BackendResult {
if barrier.contains(crate::Barrier::STORAGE) {
writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
}
@ -4303,6 +4310,26 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Ok(())
}
/// Writes the HLSL memory-barrier intrinsics (the variants without
/// `WithGroupSync`) requested by `barrier`, one statement per line at `level`.
///
/// Statements are written in a fixed order: storage, work group, texture.
/// Storage and texture each write their own `DeviceMemoryBarrier();` line, so
/// requesting both writes that line twice.
fn write_memory_barrier(
    &mut self,
    barrier: crate::Barrier,
    level: back::Level,
) -> BackendResult {
    let wants_storage = barrier.contains(crate::Barrier::STORAGE);
    let wants_work_group = barrier.contains(crate::Barrier::WORK_GROUP);
    let wants_texture = barrier.contains(crate::Barrier::TEXTURE);

    if wants_storage {
        writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
    }

    if wants_work_group {
        writeln!(self.out, "{level}GroupMemoryBarrier();")?;
    }

    // `Barrier::SUB_GROUP` is deliberately ignored: it does not exist in DirectX,
    // so nothing is written for it.

    if wants_texture {
        writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
    }

    Ok(())
}
/// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
fn emit_hlsl_atomic_tail(
&mut self,

View File

@ -3700,7 +3700,8 @@ impl<W: Write> Writer<W> {
crate::Statement::Kill => {
writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
}
crate::Statement::Barrier(flags) => {
crate::Statement::ControlBarrier(flags)
| crate::Statement::MemoryBarrier(flags) => {
self.write_barrier(flags, level)?;
}
crate::Statement::Store { pointer, value } => {

View File

@ -816,7 +816,11 @@ fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut S
crate::RayQueryFunction::Terminate => {}
}
}
Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {}
Statement::Break
| Statement::Continue
| Statement::Kill
| Statement::ControlBarrier(_)
| Statement::MemoryBarrier(_) => {}
}
}

View File

@ -3240,8 +3240,11 @@ impl BlockContext<'_> {
self.function.consume(block, Instruction::kill());
return Ok(BlockExitDisposition::Discarded);
}
Statement::Barrier(flags) => {
self.writer.write_barrier(flags, &mut block);
Statement::ControlBarrier(flags) => {
self.writer.write_control_barrier(flags, &mut block);
}
Statement::MemoryBarrier(flags) => {
self.writer.write_memory_barrier(flags, &mut block);
}
Statement::Store { pointer, value } => {
let value_id = self.cached[value];
@ -3576,7 +3579,7 @@ impl BlockContext<'_> {
}
Statement::WorkGroupUniformLoad { pointer, result } => {
self.writer
.write_barrier(crate::Barrier::WORK_GROUP, &mut block);
.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block);
let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
// Embed the body of
match self.write_access_chain(
@ -3616,7 +3619,7 @@ impl BlockContext<'_> {
}
}
self.writer
.write_barrier(crate::Barrier::WORK_GROUP, &mut block);
.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block);
}
Statement::RayQuery { query, ref fun } => {
self.write_ray_query_function(query, fun, &mut block);

View File

@ -1138,6 +1138,12 @@ impl super::Instruction {
instruction.add_operand(semantics_id);
instruction
}
/// Builds an `Op::MemoryBarrier` instruction.
///
/// The operands are added in this order: `mem_scope_id`, then `semantics_id`.
pub(super) fn memory_barrier(mem_scope_id: Word, semantics_id: Word) -> Self {
    let mut barrier = Self::new(Op::MemoryBarrier);
    for &operand in &[mem_scope_id, semantics_id] {
        barrier.add_operand(operand);
    }
    barrier
}
// Group Instructions

View File

@ -1711,7 +1711,7 @@ impl Writer {
Ok(id)
}
pub(super) fn write_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
pub(super) fn write_control_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
spirv::Scope::Device
} else {
@ -1744,6 +1744,37 @@ impl Writer {
));
}
/// Appends a memory-barrier instruction for `flags` to the body of `block`.
///
/// The memory semantics always include `ACQUIRE_RELEASE`, plus one memory bit
/// for each requested barrier flag. The memory scope is `Device` when
/// `STORAGE` is requested, otherwise `Subgroup` when `SUB_GROUP` is requested,
/// otherwise `Workgroup`.
pub(super) fn write_memory_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
    // Barrier flag -> memory-semantics bit it switches on.
    let memory_bits = [
        (
            crate::Barrier::STORAGE,
            spirv::MemorySemantics::UNIFORM_MEMORY,
        ),
        (
            crate::Barrier::WORK_GROUP,
            spirv::MemorySemantics::WORKGROUP_MEMORY,
        ),
        (
            crate::Barrier::SUB_GROUP,
            spirv::MemorySemantics::SUBGROUP_MEMORY,
        ),
        (
            crate::Barrier::TEXTURE,
            spirv::MemorySemantics::IMAGE_MEMORY,
        ),
    ];

    let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
    for &(flag, bit) in &memory_bits {
        semantics.set(bit, flags.contains(flag));
    }

    // `STORAGE` takes precedence over `SUB_GROUP`; anything else is workgroup scope.
    let scope = if flags.contains(crate::Barrier::STORAGE) {
        spirv::Scope::Device
    } else if flags.contains(crate::Barrier::SUB_GROUP) {
        spirv::Scope::Subgroup
    } else {
        spirv::Scope::Workgroup
    };

    // Request the scope constant first and the semantics constant second.
    let mem_scope_id = self.get_index_constant(scope as u32);
    let semantics_id = self.get_index_constant(semantics.bits());

    let barrier = Instruction::memory_barrier(mem_scope_id, semantics_id);
    block.body.push(barrier);
}
fn generate_workgroup_vars_init_block(
&mut self,
entry_id: Word,
@ -1844,7 +1875,7 @@ impl Writer {
let mut post_if_block = Block::new(merge_id);
self.write_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block);
self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block);
let next_id = self.id_gen.next();
function.consume(post_if_block, Instruction::branch(next_id));

View File

@ -830,7 +830,7 @@ impl<W: Write> Writer<W> {
Statement::Continue => {
writeln!(self.out, "{level}continue;")?;
}
Statement::Barrier(barrier) => {
Statement::ControlBarrier(barrier) | Statement::MemoryBarrier(barrier) => {
if barrier.contains(crate::Barrier::STORAGE) {
writeln!(self.out, "{level}storageBarrier();")?;
}

View File

@ -157,7 +157,8 @@ impl FunctionTracer<'_> {
St::Break
| St::Continue
| St::Kill
| St::Barrier(_)
| St::ControlBarrier(_)
| St::MemoryBarrier(_)
| St::Return { value: None } => {}
}
}
@ -375,7 +376,8 @@ impl FunctionMap {
St::Break
| St::Continue
| St::Kill
| St::Barrier(_)
| St::ControlBarrier(_)
| St::MemoryBarrier(_)
| St::Return { value: None } => {}
}
}

View File

@ -2035,8 +2035,10 @@ impl MacroCall {
)?,
MacroCall::Barrier => {
ctx.emit_restart();
ctx.body
.push(crate::Statement::Barrier(crate::Barrier::all()), meta);
ctx.body.push(
crate::Statement::ControlBarrier(crate::Barrier::all()),
meta,
);
return Ok(None);
}
MacroCall::SmoothStep { splatted } => {

View File

@ -3868,11 +3868,50 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
crate::Barrier::TEXTURE,
semantics & spirv::MemorySemantics::IMAGE_MEMORY.bits() != 0,
);
block.push(crate::Statement::Barrier(flags), span);
block.push(crate::Statement::ControlBarrier(flags), span);
} else {
log::warn!("Unsupported barrier execution scope: {}", exec_scope);
}
}
// Lowers a SPIR-V `OpMemoryBarrier` into a `Statement::MemoryBarrier`.
// The instruction is expected to be 3 words long; the two operands read
// below are the memory-scope id and the memory-semantics id.
Op::MemoryBarrier => {
inst.expect(3)?;
let mem_scope_id = self.next()?;
let semantics_id = self.next()?;
// Both operands are ids of constants: look each one up, then resolve it
// to a plain integer. A constant that cannot be resolved is reported with
// the id of the offending operand.
let mem_scope_const = self.lookup_constant.lookup(mem_scope_id)?;
let semantics_const = self.lookup_constant.lookup(semantics_id)?;
let mem_scope = resolve_constant(ctx.gctx(), &mem_scope_const.inner)
.ok_or(Error::InvalidBarrierScope(mem_scope_id))?;
let semantics = resolve_constant(ctx.gctx(), &semantics_const.inner)
.ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?;
// Seed the flags from the memory scope. Any scope other than Device,
// Workgroup or Subgroup yields no flag here, with no warning or error.
//
// NOTE(review): each flag seeded here (STORAGE, WORK_GROUP, SUB_GROUP) is
// also passed to `flags.set(..)` below. If `set(flag, false)` clears the
// flag (as in the `bitflags` API), the seed never affects the final value
// and the result depends on `semantics` alone. Confirm whether a union of
// scope and semantics was intended.
let mut flags = if mem_scope == spirv::Scope::Device as u32 {
crate::Barrier::STORAGE
} else if mem_scope == spirv::Scope::Workgroup as u32 {
crate::Barrier::WORK_GROUP
} else if mem_scope == spirv::Scope::Subgroup as u32 {
crate::Barrier::SUB_GROUP
} else {
crate::Barrier::empty()
};
// Map each memory-semantics bit onto the matching barrier flag.
flags.set(
crate::Barrier::STORAGE,
semantics & spirv::MemorySemantics::UNIFORM_MEMORY.bits() != 0,
);
flags.set(
crate::Barrier::WORK_GROUP,
semantics & (spirv::MemorySemantics::WORKGROUP_MEMORY).bits() != 0,
);
flags.set(
crate::Barrier::SUB_GROUP,
semantics & spirv::MemorySemantics::SUBGROUP_MEMORY.bits() != 0,
);
flags.set(
crate::Barrier::TEXTURE,
semantics & spirv::MemorySemantics::IMAGE_MEMORY.bits() != 0,
);
block.push(crate::Statement::MemoryBarrier(flags), span);
}
Op::CopyObject => {
inst.expect(4)?;
let result_type_id = self.next()?;
@ -4566,7 +4605,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
| S::Continue
| S::Return { .. }
| S::Kill
| S::Barrier(_)
| S::ControlBarrier(_)
| S::MemoryBarrier(_)
| S::Store { .. }
| S::ImageStore { .. }
| S::Atomic { .. }

View File

@ -2737,7 +2737,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block
.push(ir::Statement::Barrier(ir::Barrier::STORAGE), span);
.push(ir::Statement::ControlBarrier(ir::Barrier::STORAGE), span);
return Ok(None);
}
"workgroupBarrier" => {
@ -2745,7 +2745,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block
.push(ir::Statement::Barrier(ir::Barrier::WORK_GROUP), span);
.push(ir::Statement::ControlBarrier(ir::Barrier::WORK_GROUP), span);
return Ok(None);
}
"subgroupBarrier" => {
@ -2753,7 +2753,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block
.push(ir::Statement::Barrier(ir::Barrier::SUB_GROUP), span);
.push(ir::Statement::ControlBarrier(ir::Barrier::SUB_GROUP), span);
return Ok(None);
}
"textureBarrier" => {
@ -2761,7 +2761,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block
.push(ir::Statement::Barrier(ir::Barrier::TEXTURE), span);
.push(ir::Statement::ControlBarrier(ir::Barrier::TEXTURE), span);
return Ok(None);
}
"workgroupUniformLoad" => {

View File

@ -1916,7 +1916,12 @@ pub enum Statement {
/// Synchronize invocations within the work group.
/// The `Barrier` flags control which memory accesses should be synchronized.
/// If empty, this becomes purely an execution barrier.
Barrier(Barrier),
ControlBarrier(Barrier),
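/// Order memory accesses without requesting execution synchronization.
/// The `Barrier` flags control which memory accesses should be ordered.
/// Backends without a memory-only barrier may emit a full control barrier instead.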
MemoryBarrier(Barrier),
/// Stores a value at an address.
///
/// For [`TypeInner::Atomic`] type behind the pointer, the value

View File

@ -42,7 +42,8 @@ pub fn ensure_block_returns(block: &mut crate::Block) {
| S::SubgroupBallot { .. }
| S::SubgroupCollectiveOperation { .. }
| S::SubgroupGather { .. }
| S::Barrier(_)),
| S::ControlBarrier(_)
| S::MemoryBarrier(_)),
)
| None => block.push(S::Return { value: None }, Default::default()),
}

View File

@ -902,7 +902,7 @@ impl FunctionInfo {
ExitFlags::empty()
},
},
S::Barrier(_) => FunctionUniformity {
S::ControlBarrier(_) | S::MemoryBarrier(_) => FunctionUniformity {
result: Uniformity {
non_uniform_result: None,
requirements: UniformityRequirements::WORK_GROUP_BARRIER,

View File

@ -1006,7 +1006,7 @@ impl super::Validator {
S::Kill => {
stages &= super::ShaderStages::FRAGMENT;
}
S::Barrier(barrier) => {
S::ControlBarrier(barrier) | S::MemoryBarrier(barrier) => {
stages &= super::ShaderStages::COMPUTE;
if barrier.contains(crate::Barrier::SUB_GROUP) {
if !self.capabilities.contains(

View File

@ -837,7 +837,8 @@ impl super::Validator {
crate::Statement::Break
| crate::Statement::Continue
| crate::Statement::Kill
| crate::Statement::Barrier(_) => Ok(()),
| crate::Statement::ControlBarrier(_)
| crate::Statement::MemoryBarrier(_) => Ok(()),
})
}
}

View File

@ -0,0 +1,27 @@
; SPIR-V
; Version: 1.5
; Generator: Google rspirv; 0
; Bound: 14
; Schema: 0
OpCapability Shader
OpMemoryModel Logical Simple
OpEntryPoint GLCompute %1 "main"
OpExecutionMode %1 LocalSize 64 1 1
%void = OpTypeVoid
%6 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%uint_264 = OpConstant %uint 264
%uint_1 = OpConstant %uint 1
%uint_2120 = OpConstant %uint 2120
%uint_2376 = OpConstant %uint 2376
%1 = OpFunction %void None %6
%13 = OpLabel
OpMemoryBarrier %uint_2 %uint_264
OpControlBarrier %uint_2 %uint_2 %uint_264
OpMemoryBarrier %uint_1 %uint_2120
OpControlBarrier %uint_2 %uint_1 %uint_2120
OpMemoryBarrier %uint_1 %uint_2376
OpControlBarrier %uint_2 %uint_1 %uint_2376
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,4 @@
targets = "WGSL | SPIRV | GLSL | HLSL | METAL"
[msl]
lang_version = [2, 0]

View File

@ -0,0 +1,34 @@
#version 310 es
precision highp float;
precision highp int;
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
void function() {
memoryBarrierShared();
barrier();
memoryBarrierShared();
barrier();
memoryBarrierBuffer();
memoryBarrierImage();
barrier();
memoryBarrierBuffer();
memoryBarrierImage();
barrier();
memoryBarrierBuffer();
memoryBarrierShared();
memoryBarrierImage();
barrier();
memoryBarrierBuffer();
memoryBarrierShared();
memoryBarrierImage();
barrier();
return;
}
void main() {
function();
}

View File

@ -0,0 +1,22 @@
void function()
{
GroupMemoryBarrier();
GroupMemoryBarrierWithGroupSync();
DeviceMemoryBarrier();
DeviceMemoryBarrier();
DeviceMemoryBarrierWithGroupSync();
DeviceMemoryBarrierWithGroupSync();
DeviceMemoryBarrier();
GroupMemoryBarrier();
DeviceMemoryBarrier();
DeviceMemoryBarrierWithGroupSync();
GroupMemoryBarrierWithGroupSync();
DeviceMemoryBarrierWithGroupSync();
return;
}
[numthreads(64, 1, 1)]
void main()
{
function();
}

View File

@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_5_1",
),
],
)

View File

@ -0,0 +1,28 @@
// language: metal2.0
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
void function(
) {
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
metal::threadgroup_barrier(metal::mem_flags::mem_device);
metal::threadgroup_barrier(metal::mem_flags::mem_texture);
metal::threadgroup_barrier(metal::mem_flags::mem_device);
metal::threadgroup_barrier(metal::mem_flags::mem_texture);
metal::threadgroup_barrier(metal::mem_flags::mem_device);
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
metal::threadgroup_barrier(metal::mem_flags::mem_texture);
metal::threadgroup_barrier(metal::mem_flags::mem_device);
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
metal::threadgroup_barrier(metal::mem_flags::mem_texture);
return;
}
kernel void main_(
) {
function();
}

View File

@ -0,0 +1,36 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 17
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %14 "main"
OpExecutionMode %14 LocalSize 64 1 1
%2 = OpTypeVoid
%5 = OpTypeFunction %2
%8 = OpTypeInt 32 0
%7 = OpConstant %8 2
%9 = OpConstant %8 264
%10 = OpConstant %8 1
%11 = OpConstant %8 2120
%12 = OpConstant %8 2376
%4 = OpFunction %2 None %5
%3 = OpLabel
OpBranch %6
%6 = OpLabel
OpMemoryBarrier %7 %9
OpControlBarrier %7 %7 %9
OpMemoryBarrier %10 %11
OpControlBarrier %7 %10 %11
OpMemoryBarrier %10 %12
OpControlBarrier %7 %10 %12
OpReturn
OpFunctionEnd
%14 = OpFunction %2 None %5
%13 = OpLabel
OpBranch %15
%15 = OpLabel
%16 = OpFunctionCall %2 %4
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,20 @@
fn function() {
workgroupBarrier();
workgroupBarrier();
storageBarrier();
textureBarrier();
storageBarrier();
textureBarrier();
storageBarrier();
workgroupBarrier();
textureBarrier();
storageBarrier();
workgroupBarrier();
textureBarrier();
return;
}
@compute @workgroup_size(64, 1, 1)
fn main() {
function();
}