Implement subgroup quad ops (#7683)

* Rudimentary impl of quad ops, impl quad ops for spirv

* Impl quad swap for hlsl, msl and wgsl, finish spv front

* Cargo clippy & cargo fmt, impl valid for quad ops

* Enable quad feature

* Add missing feature to glsl

* Simplifying code by making `SubgroupQuadSwap` an instance of `SubgroupGather`

* Add `GroupNonUniformQuad` spv capability to Vulkan

* Adding GPU tests for quad operations

* Validate that broadcast operations use const invocation ids

* Added changelog entry

---------

Co-authored-by: valaphee <32491319+valaphee@users.noreply.github.com>
This commit is contained in:
Dmitry Zamkov 2025-05-26 02:32:01 -05:00 committed by GitHub
parent 4cd8be548c
commit 9c023e5e29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 429 additions and 58 deletions

View File

@ -54,6 +54,7 @@ Bottom level categories:
#### Naga
- When emitting GLSL, Uniform and Storage Buffer memory layouts are now emitted even if no explicit binding is given. By @cloone8 in [#7579](https://github.com/gfx-rs/wgpu/pull/7579).
- Add support for [quad operations](https://www.w3.org/TR/WGSL/#quad-builtin-functions) (requires `SUBGROUP` feature to be enabled). By @dzamkov and @valaphee in [#7683](https://github.com/gfx-rs/wgpu/pull/7683).
### Bug Fixes

View File

@ -379,9 +379,11 @@ impl StatementGraph {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
self.dependencies.push((id, index, "index"))
}
crate::GatherMode::QuadSwap(_) => {}
}
self.dependencies.push((id, argument, "arg"));
self.emits.push((id, result));
@ -392,6 +394,12 @@ impl StatementGraph {
crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown",
crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp",
crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor",
crate::GatherMode::QuadBroadcast(_) => "SubgroupQuadBroadcast",
crate::GatherMode::QuadSwap(direction) => match direction {
crate::Direction::X => "SubgroupQuadSwapX",
crate::Direction::Y => "SubgroupQuadSwapY",
crate::Direction::Diagonal => "SubgroupQuadSwapDiagonal",
},
}
}
};

View File

@ -280,6 +280,7 @@ impl FeaturesManager {
out,
"#extension GL_KHR_shader_subgroup_shuffle_relative : require"
)?;
writeln!(out, "#extension GL_KHR_shader_subgroup_quad : require")?;
}
if self.0.contains(Features::TEXTURE_ATOMICS) {

View File

@ -2717,6 +2717,20 @@ impl<'a, W: Write> Writer<'a, W> {
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "subgroupShuffleXor(")?;
}
crate::GatherMode::QuadBroadcast(_) => {
write!(self.out, "subgroupQuadBroadcast(")?;
}
crate::GatherMode::QuadSwap(direction) => match direction {
crate::Direction::X => {
write!(self.out, "subgroupQuadSwapHorizontal(")?;
}
crate::Direction::Y => {
write!(self.out, "subgroupQuadSwapVertical(")?;
}
crate::Direction::Diagonal => {
write!(self.out, "subgroupQuadSwapDiagonal(")?;
}
},
}
self.write_expr(argument, ctx)?;
match mode {
@ -2725,10 +2739,12 @@ impl<'a, W: Write> Writer<'a, W> {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
write!(self.out, ", ")?;
self.write_expr(index, ctx)?;
}
crate::GatherMode::QuadSwap(_) => {}
}
writeln!(self.out, ");")?;
}

View File

@ -2610,30 +2610,55 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
};
write!(self.out, " {name} = ")?;
self.named_expressions.insert(result, name);
if matches!(mode, crate::GatherMode::BroadcastFirst) {
write!(self.out, "WaveReadLaneFirst(")?;
self.write_expr(module, argument, func_ctx)?;
} else {
write!(self.out, "WaveReadLaneAt(")?;
self.write_expr(module, argument, func_ctx)?;
write!(self.out, ", ")?;
match mode {
crate::GatherMode::BroadcastFirst => unreachable!(),
crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => {
self.write_expr(module, index, func_ctx)?;
match mode {
crate::GatherMode::BroadcastFirst => {
write!(self.out, "WaveReadLaneFirst(")?;
self.write_expr(module, argument, func_ctx)?;
}
crate::GatherMode::QuadBroadcast(index) => {
write!(self.out, "QuadReadLaneAt(")?;
self.write_expr(module, argument, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::QuadSwap(direction) => {
match direction {
crate::Direction::X => {
write!(self.out, "QuadReadAcrossX(")?;
}
crate::Direction::Y => {
write!(self.out, "QuadReadAcrossY(")?;
}
crate::Direction::Diagonal => {
write!(self.out, "QuadReadAcrossDiagonal(")?;
}
}
crate::GatherMode::ShuffleDown(index) => {
write!(self.out, "WaveGetLaneIndex() + ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleUp(index) => {
write!(self.out, "WaveGetLaneIndex() - ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleXor(index) => {
write!(self.out, "WaveGetLaneIndex() ^ ")?;
self.write_expr(module, index, func_ctx)?;
self.write_expr(module, argument, func_ctx)?;
}
_ => {
write!(self.out, "WaveReadLaneAt(")?;
self.write_expr(module, argument, func_ctx)?;
write!(self.out, ", ")?;
match mode {
crate::GatherMode::BroadcastFirst => unreachable!(),
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index) => {
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleDown(index) => {
write!(self.out, "WaveGetLaneIndex() + ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleUp(index) => {
write!(self.out, "WaveGetLaneIndex() - ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleXor(index) => {
write!(self.out, "WaveGetLaneIndex() ^ ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::QuadBroadcast(_) => unreachable!(),
crate::GatherMode::QuadSwap(_) => unreachable!(),
}
}
}

View File

@ -4090,6 +4090,12 @@ impl<W: Write> Writer<W> {
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
}
crate::GatherMode::QuadBroadcast(_) => {
write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
}
crate::GatherMode::QuadSwap(_) => {
write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
}
}
self.put_expression(argument, &context.expression, true)?;
match mode {
@ -4098,10 +4104,25 @@ impl<W: Write> Writer<W> {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
write!(self.out, ", ")?;
self.put_expression(index, &context.expression, true)?;
}
crate::GatherMode::QuadSwap(direction) => {
write!(self.out, ", ")?;
match direction {
crate::Direction::X => {
write!(self.out, "1u")?;
}
crate::Direction::Y => {
write!(self.out, "2u")?;
}
crate::Direction::Diagonal => {
write!(self.out, "3u")?;
}
}
}
}
writeln!(self.out, ");")?;
}

View File

@ -759,9 +759,11 @@ fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut S
| crate::GatherMode::Shuffle(ref mut index)
| crate::GatherMode::ShuffleDown(ref mut index)
| crate::GatherMode::ShuffleUp(ref mut index)
| crate::GatherMode::ShuffleXor(ref mut index) => {
| crate::GatherMode::ShuffleXor(ref mut index)
| crate::GatherMode::QuadBroadcast(ref mut index) => {
adjust(index);
}
crate::GatherMode::QuadSwap(_) => {}
}
adjust(argument);
adjust(result)

View File

@ -1203,6 +1203,22 @@ impl super::Instruction {
}
instruction.add_operand(value);
instruction
}
/// Builds an `Op::GroupNonUniformQuadSwap` instruction.
///
/// The result type and result id are set first; the execution scope, the
/// value being swapped and the direction are then appended as operands,
/// in that order.
pub(super) fn group_non_uniform_quad_swap(
    result_type_id: Word,
    id: Word,
    exec_scope_id: Word,
    value: Word,
    direction: Word,
) -> Self {
    let mut inst = Self::new(Op::GroupNonUniformQuadSwap);
    inst.set_type(result_type_id);
    inst.set_result(id);
    // Operand order matters: scope, value, direction.
    for operand in [exec_scope_id, value, direction] {
        inst.add_operand(operand);
    }
    inst
}
}

View File

@ -125,10 +125,6 @@ impl BlockContext<'_> {
result: Handle<crate::Expression>,
block: &mut Block,
) -> Result<(), Error> {
self.writer.require_any(
"GroupNonUniformBallot",
&[spirv::Capability::GroupNonUniformBallot],
)?;
match *mode {
crate::GatherMode::BroadcastFirst => {
self.writer.require_any(
@ -150,6 +146,12 @@ impl BlockContext<'_> {
&[spirv::Capability::GroupNonUniformShuffleRelative],
)?;
}
crate::GatherMode::QuadBroadcast(_) | crate::GatherMode::QuadSwap(_) => {
self.writer.require_any(
"GroupNonUniformQuad",
&[spirv::Capability::GroupNonUniformQuad],
)?;
}
}
let id = self.gen_id();
@ -174,7 +176,8 @@ impl BlockContext<'_> {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
let index_id = self.cached[index];
let op = match *mode {
crate::GatherMode::BroadcastFirst => unreachable!(),
@ -187,6 +190,8 @@ impl BlockContext<'_> {
crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown,
crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp,
crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor,
crate::GatherMode::QuadBroadcast(_) => spirv::Op::GroupNonUniformQuadBroadcast,
crate::GatherMode::QuadSwap(_) => unreachable!(),
};
block.body.push(Instruction::group_non_uniform_gather(
op,
@ -197,6 +202,20 @@ impl BlockContext<'_> {
index_id,
));
}
// Quad swaps carry no index expression; the swap direction is passed as an
// extra operand instead, encoded as 0 = X, 1 = Y, 2 = diagonal.
crate::GatherMode::QuadSwap(direction) => {
// NOTE(review): `get_index_constant` presumably returns the id of a
// constant holding this value - confirm against its definition.
let direction = self.get_index_constant(match direction {
crate::Direction::X => 0,
crate::Direction::Y => 1,
crate::Direction::Diagonal => 2,
});
block.body.push(Instruction::group_non_uniform_quad_swap(
result_type_id,
id,
exec_scope_id,
arg_id,
direction,
));
}
}
self.cached[result] = id;
Ok(())

View File

@ -945,6 +945,20 @@ impl<W: Write> Writer<W> {
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "subgroupShuffleXor(")?;
}
crate::GatherMode::QuadBroadcast(_) => {
write!(self.out, "quadBroadcast(")?;
}
crate::GatherMode::QuadSwap(direction) => match direction {
crate::Direction::X => {
write!(self.out, "quadSwapX(")?;
}
crate::Direction::Y => {
write!(self.out, "quadSwapY(")?;
}
crate::Direction::Diagonal => {
write!(self.out, "quadSwapDiagonal(")?;
}
},
}
self.write_expr(module, argument, func_ctx)?;
match mode {
@ -953,10 +967,12 @@ impl<W: Write> Writer<W> {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
write!(self.out, ", ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::QuadSwap(_) => {}
}
writeln!(self.out, ");")?;
}

View File

@ -141,9 +141,11 @@ impl FunctionTracer<'_> {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
self.expressions_used.insert(index);
}
crate::GatherMode::QuadSwap(_) => {}
}
self.expressions_used.insert(argument);
self.expressions_used.insert(result);
@ -350,7 +352,9 @@ impl FunctionMap {
| crate::GatherMode::Shuffle(ref mut index)
| crate::GatherMode::ShuffleDown(ref mut index)
| crate::GatherMode::ShuffleUp(ref mut index)
| crate::GatherMode::ShuffleXor(ref mut index) => adjust(index),
| crate::GatherMode::ShuffleXor(ref mut index)
| crate::GatherMode::QuadBroadcast(ref mut index) => adjust(index),
crate::GatherMode::QuadSwap(_) => {}
}
adjust(argument);
adjust(result);

View File

@ -4064,7 +4064,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
| Op::GroupNonUniformShuffle
| Op::GroupNonUniformShuffleDown
| Op::GroupNonUniformShuffleUp
| Op::GroupNonUniformShuffleXor => {
| Op::GroupNonUniformShuffleXor
| Op::GroupNonUniformQuadBroadcast => {
inst.expect(if matches!(inst.op, Op::GroupNonUniformBroadcastFirst) {
5
} else {
@ -4104,6 +4105,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
Op::GroupNonUniformShuffleXor => {
crate::GatherMode::ShuffleXor(index_handle)
}
Op::GroupNonUniformQuadBroadcast => {
crate::GatherMode::QuadBroadcast(index_handle)
}
_ => unreachable!(),
}
};
@ -4135,6 +4139,60 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
);
emitter.start(ctx.expressions);
}
Op::GroupNonUniformQuadSwap => {
    // Five words are read below: result type, result id, execution
    // scope, value and direction.
    inst.expect(6)?;
    block.extend(emitter.finish(ctx.expressions));
    let result_type_id = self.next()?;
    let result_id = self.next()?;
    let exec_scope_id = self.next()?;
    let argument_id = self.next()?;
    let direction_id = self.next()?;

    let argument_lookup = self.lookup_expression.lookup(argument_id)?;
    let argument_handle = get_expr_handle!(argument_id, argument_lookup);

    // Only the subgroup execution scope is accepted.
    let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
    let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
        .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
        .ok_or(Error::InvalidBarrierScope(exec_scope_id))?;

    // The direction operand must resolve to a constant: 0 = X, 1 = Y,
    // 2 = diagonal.
    let direction_const = self.lookup_constant.lookup(direction_id)?;
    let direction_const = resolve_constant(ctx.gctx(), &direction_const.inner)
        .ok_or(Error::InvalidOperand)?;
    let direction = match direction_const {
        0 => crate::Direction::X,
        1 => crate::Direction::Y,
        2 => crate::Direction::Diagonal,
        // The module being parsed is untrusted input, so an out-of-range
        // direction must be reported as an error rather than panic.
        _ => return Err(Error::InvalidOperand),
    };

    let result_type = self.lookup_type.lookup(result_type_id)?;

    let result_handle = ctx.expressions.append(
        crate::Expression::SubgroupOperationResult {
            ty: result_type.handle,
        },
        span,
    );
    self.lookup_expression.insert(
        result_id,
        LookupExpression {
            handle: result_handle,
            type_id: result_type_id,
            block_id,
        },
    );

    block.push(
        crate::Statement::SubgroupGather {
            mode: crate::GatherMode::QuadSwap(direction),
            result: result_handle,
            argument: argument_handle,
        },
        span,
    );
    emitter.start(ctx.expressions);
}
Op::AtomicLoad => {
inst.expect(6)?;
let start = self.data_offset;

View File

@ -1078,6 +1078,7 @@ enum SubgroupGather {
ShuffleDown,
ShuffleUp,
ShuffleXor,
QuadBroadcast,
}
impl SubgroupGather {
@ -1089,6 +1090,7 @@ impl SubgroupGather {
"subgroupShuffleDown" => Self::ShuffleDown,
"subgroupShuffleUp" => Self::ShuffleUp,
"subgroupShuffleXor" => Self::ShuffleXor,
"quadBroadcast" => Self::QuadBroadcast,
_ => return None,
})
}
@ -2940,6 +2942,77 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.push(ir::Statement::SubgroupBallot { result, predicate }, span);
return Ok(Some(result));
}
"quadSwapX" => {
let mut args = ctx.prepare_args(arguments, 1, span);
let argument = self.expression(args.next()?, ctx)?;
args.finish()?;
let ty = ctx.register_type(argument)?;
let result = ctx.interrupt_emitter(
crate::Expression::SubgroupOperationResult { ty },
span,
)?;
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::SubgroupGather {
mode: crate::GatherMode::QuadSwap(crate::Direction::X),
argument,
result,
},
span,
);
return Ok(Some(result));
}
"quadSwapY" => {
let mut args = ctx.prepare_args(arguments, 1, span);
let argument = self.expression(args.next()?, ctx)?;
args.finish()?;
let ty = ctx.register_type(argument)?;
let result = ctx.interrupt_emitter(
crate::Expression::SubgroupOperationResult { ty },
span,
)?;
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::SubgroupGather {
mode: crate::GatherMode::QuadSwap(crate::Direction::Y),
argument,
result,
},
span,
);
return Ok(Some(result));
}
"quadSwapDiagonal" => {
let mut args = ctx.prepare_args(arguments, 1, span);
let argument = self.expression(args.next()?, ctx)?;
args.finish()?;
let ty = ctx.register_type(argument)?;
let result = ctx.interrupt_emitter(
crate::Expression::SubgroupOperationResult { ty },
span,
)?;
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::SubgroupGather {
mode: crate::GatherMode::QuadSwap(crate::Direction::Diagonal),
argument,
result,
},
span,
);
return Ok(Some(result));
}
_ => {
return Err(Box::new(Error::UnknownIdent(function.span, function.name)))
}
@ -3471,12 +3544,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
} else {
let index = self.expression(args.next()?, ctx)?;
match mode {
Sg::BroadcastFirst => unreachable!(),
Sg::Broadcast => ir::GatherMode::Broadcast(index),
Sg::Shuffle => ir::GatherMode::Shuffle(index),
Sg::ShuffleDown => ir::GatherMode::ShuffleDown(index),
Sg::ShuffleUp => ir::GatherMode::ShuffleUp(index),
Sg::ShuffleXor => ir::GatherMode::ShuffleXor(index),
Sg::BroadcastFirst => unreachable!(),
Sg::QuadBroadcast => ir::GatherMode::QuadBroadcast(index),
}
};

View File

@ -1303,6 +1303,20 @@ pub enum GatherMode {
ShuffleUp(Handle<Expression>),
/// Each gathers from their lane xored with the index given by the expression
ShuffleXor(Handle<Expression>),
/// All gather from the same quad lane at the index given by the expression
QuadBroadcast(Handle<Expression>),
/// Each gathers from the opposite quad lane along the given direction
QuadSwap(Direction),
}
/// The direction along which a [`GatherMode::QuadSwap`] exchanges values
/// between the lanes of a quad.
///
/// The explicit discriminants are the same values the SPIR-V front and back
/// ends use for the direction operand of the quad swap instruction.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum Direction {
/// Swap along the X direction (WGSL `quadSwapX`).
X = 0,
/// Swap along the Y direction (WGSL `quadSwapY`).
Y = 1,
/// Swap along the diagonal (WGSL `quadSwapDiagonal`).
Diagonal = 2,
}
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]

View File

@ -1142,9 +1142,11 @@ impl FunctionInfo {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
let _ = self.add_ref(index);
}
crate::GatherMode::QuadSwap(_) => {}
}
FunctionUniformity::new()
}

View File

@ -73,6 +73,8 @@ pub enum SubgroupError {
UnsupportedOperation(super::SubgroupOperationSet),
#[error("Unknown operation")]
UnknownOperation,
#[error("Invocation ID must be a const-expression")]
InvalidInvocationIdExprType(Handle<crate::Expression>),
}
#[derive(Clone, Debug, thiserror::Error)]
@ -248,6 +250,7 @@ struct BlockContext<'a> {
special_types: &'a crate::SpecialTypes,
prev_infos: &'a [FunctionInfo],
return_type: Option<Handle<crate::Type>>,
local_expr_kind: &'a crate::proc::ExpressionKindTracker,
}
impl<'a> BlockContext<'a> {
@ -256,6 +259,7 @@ impl<'a> BlockContext<'a> {
module: &'a crate::Module,
info: &'a FunctionInfo,
prev_infos: &'a [FunctionInfo],
local_expr_kind: &'a crate::proc::ExpressionKindTracker,
) -> Self {
Self {
abilities: ControlFlowAbility::RETURN,
@ -268,6 +272,7 @@ impl<'a> BlockContext<'a> {
special_types: &module.special_types,
prev_infos,
return_type: fun.result.as_ref().map(|fr| fr.ty),
local_expr_kind,
}
}
@ -705,7 +710,8 @@ impl super::Validator {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
let index_ty = context.resolve_type_inner(index, &self.valid_expression_set)?;
match *index_ty {
crate::TypeInner::Scalar(crate::Scalar::U32) => {}
@ -720,6 +726,17 @@ impl super::Validator {
}
}
}
crate::GatherMode::QuadSwap(_) => {}
}
match *mode {
crate::GatherMode::Broadcast(index) | crate::GatherMode::QuadBroadcast(index) => {
if !context.local_expr_kind.is_const(index) {
return Err(SubgroupError::InvalidInvocationIdExprType(index)
.with_span_handle(index, context.expressions)
.into_other());
}
}
_ => {}
}
let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
if !matches!(*argument_inner,
@ -1772,7 +1789,7 @@ impl super::Validator {
let stages = self
.validate_block(
&fun.body,
&BlockContext::new(fun, module, &info, &mod_info.functions),
&BlockContext::new(fun, module, &info, &mod_info.functions, &local_expr_kind),
)?
.stages;
info.available_stages &= stages;

View File

@ -740,7 +740,9 @@ impl super::Validator {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => validate_expr(index)?,
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => validate_expr(index)?,
crate::GatherMode::QuadSwap(_) => {}
}
validate_expr(result)?;
Ok(())

View File

@ -29,7 +29,7 @@ pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, Uniformi
pub use compose::ComposeError;
pub use expression::{check_literal_value, LiteralError};
pub use expression::{ConstExpressionError, ExpressionError};
pub use function::{CallError, FunctionError, LocalVariableError};
pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError};
pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
pub use r#type::{Disalignment, PushConstantError, TypeError, TypeFlags, WidthError};
@ -195,8 +195,8 @@ bitflags::bitflags! {
// We don't support these operations yet
// /// Clustered
// const CLUSTERED = 1 << 6;
// /// Quad supported
// const QUAD_FRAGMENT_COMPUTE = 1 << 7;
/// Quad supported
const QUAD_FRAGMENT_COMPUTE = 1 << 7;
// /// Quad supported in all stages
// const QUAD_ALL_STAGES = 1 << 8;
}
@ -221,6 +221,7 @@ impl super::GatherMode {
Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE,
}
}
}
@ -457,7 +458,13 @@ impl Validator {
pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
use SubgroupOperationSet as S;
S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE
S::BASIC
| S::VOTE
| S::ARITHMETIC
| S::BALLOT
| S::SHUFFLE
| S::SHUFFLE_RELATIVE
| S::QUAD_FRAGMENT_COMPUTE
} else {
SubgroupOperationSet::empty()
};

View File

@ -34,4 +34,9 @@ fn main(
subgroupShuffleDown(subgroup_invocation_id, 1u);
subgroupShuffleUp(subgroup_invocation_id, 1u);
subgroupShuffleXor(subgroup_invocation_id, sizes.subgroup_size - 1u);
quadBroadcast(subgroup_invocation_id, 4u);
quadSwapX(subgroup_invocation_id);
quadSwapY(subgroup_invocation_id);
quadSwapDiagonal(subgroup_invocation_id);
}

View File

@ -3601,3 +3601,35 @@ fn const_eval_value_errors() {
assert!(variant("f32(abs(-9223372036854775807))").is_ok());
assert!(variant("f32(abs(-9223372036854775807 - 1))").is_ok());
}
/// Checks that broadcast operations reject an invocation id that is not a
/// const-expression.
///
/// In both shaders the id argument is the function parameter `id`, so
/// validation is expected to fail with
/// `SubgroupError::InvalidInvocationIdExprType`.
#[test]
fn subgroup_invalid_broadcast() {
// `subgroupBroadcast` with a non-constant invocation id.
check_validation! {
r#"
fn main(id: u32) {
subgroupBroadcast(123, id);
}
"#:
Err(naga::valid::ValidationError::Function {
source: naga::valid::FunctionError::InvalidSubgroup(
naga::valid::SubgroupError::InvalidInvocationIdExprType(_),
),
..
}),
naga::valid::Capabilities::SUBGROUP
}
// `quadBroadcast` with a non-constant invocation id.
check_validation! {
r#"
fn main(id: u32) {
quadBroadcast(123, id);
}
"#:
Err(naga::valid::ValidationError::Function {
source: naga::valid::FunctionError::InvalidSubgroup(
naga::valid::SubgroupError::InvalidInvocationIdExprType(_),
),
..
}),
naga::valid::Capabilities::SUBGROUP
}
}

View File

@ -6,6 +6,7 @@
#extension GL_KHR_shader_subgroup_ballot : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_shuffle_relative : require
#extension GL_KHR_shader_subgroup_quad : require
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
uint global = 0u;

View File

@ -6,6 +6,7 @@
#extension GL_KHR_shader_subgroup_ballot : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_shuffle_relative : require
#extension GL_KHR_shader_subgroup_quad : require
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
struct Structure {
@ -40,6 +41,10 @@ void main() {
uint _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u);
uint _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u);
uint _e41 = subgroupShuffleXor(subgroup_invocation_id, (sizes.subgroup_size - 1u));
uint _e43 = subgroupQuadBroadcast(subgroup_invocation_id, 4u);
uint _e44 = subgroupQuadSwapHorizontal(subgroup_invocation_id);
uint _e45 = subgroupQuadSwapVertical(subgroup_invocation_id);
uint _e46 = subgroupQuadSwapDiagonal(subgroup_invocation_id);
return;
}

View File

@ -34,5 +34,9 @@ void main(ComputeInput_main computeinput_main)
const uint _e35 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() + 1u);
const uint _e37 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() - 1u);
const uint _e41 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() ^ (sizes.subgroup_size - 1u));
const uint _e43 = QuadReadLaneAt(subgroup_invocation_id, 4u);
const uint _e44 = QuadReadAcrossX(subgroup_invocation_id);
const uint _e45 = QuadReadAcrossY(subgroup_invocation_id);
const uint _e46 = QuadReadAcrossDiagonal(subgroup_invocation_id);
return;
}

View File

@ -40,5 +40,9 @@ kernel void main_(
uint unnamed_18 = metal::simd_shuffle_down(subgroup_invocation_id, 1u);
uint unnamed_19 = metal::simd_shuffle_up(subgroup_invocation_id, 1u);
uint unnamed_20 = metal::simd_shuffle_xor(subgroup_invocation_id, sizes.subgroup_size - 1u);
uint unnamed_21 = metal::quad_broadcast(subgroup_invocation_id, 4u);
uint unnamed_22 = metal::quad_shuffle_xor(subgroup_invocation_id, 1u);
uint unnamed_23 = metal::quad_shuffle_xor(subgroup_invocation_id, 2u);
uint unnamed_24 = metal::quad_shuffle_xor(subgroup_invocation_id, 3u);
return;
}

View File

@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.3
; Generator: rspirv
; Bound: 58
; Bound: 62
OpCapability Shader
OpCapability GroupNonUniform
OpCapability GroupNonUniformBallot
@ -9,6 +9,7 @@ OpCapability GroupNonUniformVote
OpCapability GroupNonUniformArithmetic
OpCapability GroupNonUniformShuffle
OpCapability GroupNonUniformShuffleRelative
OpCapability GroupNonUniformQuad
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %17 "main" %8 %11 %13 %15
@ -77,5 +78,9 @@ OpControlBarrier %23 %24 %25
%55 = OpCompositeExtract %3 %7 1
%56 = OpISub %3 %55 %19
%57 = OpGroupNonUniformShuffleXor %3 %23 %16 %56
%58 = OpGroupNonUniformQuadBroadcast %3 %23 %16 %21
%59 = OpGroupNonUniformQuadSwap %3 %23 %16 %20
%60 = OpGroupNonUniformQuadSwap %3 %23 %16 %19
%61 = OpGroupNonUniformQuadSwap %3 %23 %16 %24
OpReturn
OpFunctionEnd

View File

@ -27,5 +27,9 @@ fn main(sizes: Structure, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgr
let _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u);
let _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u);
let _e41 = subgroupShuffleXor(subgroup_invocation_id, (sizes.subgroup_size - 1u));
let _e43 = quadBroadcast(subgroup_invocation_id, 4u);
let _e44 = quadSwapX(subgroup_invocation_id);
let _e45 = quadSwapY(subgroup_invocation_id);
let _e46 = quadSwapDiagonal(subgroup_invocation_id);
return;
}

View File

@ -3,7 +3,7 @@ use std::num::NonZeroU64;
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters};
const THREAD_COUNT: u64 = 128;
const TEST_COUNT: u32 = 32;
const TEST_COUNT: u32 = 37;
#[gpu_test]
static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new()
@ -35,7 +35,7 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new()
let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: THREAD_COUNT * size_of::<u32>() as u64,
size: THREAD_COUNT * size_of::<u64>() as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
@ -50,7 +50,7 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new()
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: NonZeroU64::new(THREAD_COUNT * size_of::<u32>() as u64),
min_binding_size: NonZeroU64::new(THREAD_COUNT * size_of::<u64>() as u64),
},
count: None,
}],
@ -101,10 +101,10 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new()
&storage_buffer.slice(..),
|mapping_buffer_view| {
let mapping_buffer_view = mapping_buffer_view.unwrap();
let result: &[u32; THREAD_COUNT as usize] =
let result: &[u64; THREAD_COUNT as usize] =
bytemuck::from_bytes(&mapping_buffer_view);
let expected_mask = (1u64 << (TEST_COUNT)) - 1; // generate full mask
let expected_array = [expected_mask as u32; THREAD_COUNT as usize];
let expected_array = [expected_mask; THREAD_COUNT as usize];
if result != &expected_array {
use std::fmt::Write;
let mut msg = String::new();
@ -122,7 +122,7 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new()
{
write!(&mut msg, "thread {thread} failed tests:").unwrap();
let difference = result ^ expected;
for i in (0..u32::BITS).filter(|i| (difference & (1 << i)) != 0) {
for i in (0..u64::BITS).filter(|i| (difference & (1 << i)) != 0) {
write!(&mut msg, " {i},").unwrap();
}
writeln!(&mut msg).unwrap();

View File

@ -1,11 +1,11 @@
@group(0)
@binding(0)
var<storage, read_write> storage_buffer: array<u32>;
var<storage, read_write> storage_buffer: array<vec2<u32>>;
var<workgroup> workgroup_buffer: u32;
fn add_result_to_mask(mask: ptr<function, u32>, index: u32, value: bool) {
(*mask) |= u32(value) << index;
fn add_result_to_mask(mask: ptr<function, vec2<u32>>, index: u32, value: bool) {
(*mask)[index / 32u] |= u32(value) << (index % 32u);
}
@compute
@ -17,7 +17,7 @@ fn main(
@builtin(subgroup_size) subgroup_size: u32,
@builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
) {
var passed = 0u;
var passed = vec2<u32>(0u);
var expected: u32;
add_result_to_mask(&passed, 0u, num_subgroups == 128u / subgroup_size);
@ -152,8 +152,14 @@ fn main(
workgroupBarrier();
add_result_to_mask(&passed, 30u, workgroup_buffer == subgroup_size);
add_result_to_mask(&passed, 31u, quadBroadcast(quadBroadcast(subgroup_invocation_id, 0u) ^ subgroup_invocation_id, 0u) == 0u);
add_result_to_mask(&passed, 32u, quadBroadcast(quadBroadcast(subgroup_invocation_id, 1u) ^ quadSwapX(subgroup_invocation_id), 0u) == 0u);
add_result_to_mask(&passed, 33u, quadBroadcast(quadBroadcast(subgroup_invocation_id, 2u) ^ quadSwapY(subgroup_invocation_id), 0u) == 0u);
add_result_to_mask(&passed, 34u, quadBroadcast(quadBroadcast(subgroup_invocation_id, 3u) ^ quadSwapDiagonal(subgroup_invocation_id), 0u) == 0u);
add_result_to_mask(&passed, 35u, quadSwapX(quadSwapY(subgroup_invocation_id)) == quadSwapDiagonal(subgroup_invocation_id));
// Keep this test last, verify we are still convergent after running other tests
add_result_to_mask(&passed, 31u, subgroupAdd(1u) == subgroup_size);
add_result_to_mask(&passed, 36u, subgroupAdd(1u) == subgroup_size);
// Increment TEST_COUNT in subgroup_operations/mod.rs if adding more tests

View File

@ -752,7 +752,8 @@ impl PhysicalDeviceFeatures {
| vk::SubgroupFeatureFlags::ARITHMETIC
| vk::SubgroupFeatureFlags::BALLOT
| vk::SubgroupFeatureFlags::SHUFFLE
| vk::SubgroupFeatureFlags::SHUFFLE_RELATIVE,
| vk::SubgroupFeatureFlags::SHUFFLE_RELATIVE
| vk::SubgroupFeatureFlags::QUAD,
)
{
features.set(
@ -1978,6 +1979,7 @@ impl super::Adapter {
capabilities.push(spv::Capability::GroupNonUniformBallot);
capabilities.push(spv::Capability::GroupNonUniformShuffle);
capabilities.push(spv::Capability::GroupNonUniformShuffleRelative);
capabilities.push(spv::Capability::GroupNonUniformQuad);
}
if features.intersects(