mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-12-08 21:26:17 +00:00
Subgroup Operations (#5301)
Co-authored-by: Jacob Hughes <j@distanthills.org> Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com> Co-authored-by: atlas dostal <rodol@rivalrebels.com>
This commit is contained in:
parent
0dc9dd6bec
commit
ea77d5674d
@ -138,6 +138,7 @@ Bottom level categories:
|
|||||||
### Bug Fixes
|
### Bug Fixes
|
||||||
|
|
||||||
#### General
|
#### General
|
||||||
|
- Add `SUBGROUP, SUBGROUP_VERTEX, SUBGROUP_BARRIER` features. By @exrook and @lichtso in [#5301](https://github.com/gfx-rs/wgpu/pull/5301)
|
||||||
- Fix `serde` feature not compiling for `wgpu-types`. By @KirmesBude in [#5149](https://github.com/gfx-rs/wgpu/pull/5149)
|
- Fix `serde` feature not compiling for `wgpu-types`. By @KirmesBude in [#5149](https://github.com/gfx-rs/wgpu/pull/5149)
|
||||||
- Fix the validation of vertex and index ranges. By @nical in [#5144](https://github.com/gfx-rs/wgpu/pull/5144) and [#5156](https://github.com/gfx-rs/wgpu/pull/5156)
|
- Fix the validation of vertex and index ranges. By @nical in [#5144](https://github.com/gfx-rs/wgpu/pull/5144) and [#5156](https://github.com/gfx-rs/wgpu/pull/5156)
|
||||||
- Fix panic when creating a surface while no backend is available. By @wumpf [#5166](https://github.com/gfx-rs/wgpu/pull/5166)
|
- Fix panic when creating a surface while no backend is available. By @wumpf [#5166](https://github.com/gfx-rs/wgpu/pull/5166)
|
||||||
|
|||||||
@ -424,6 +424,8 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
// Validate the IR before compaction.
|
// Validate the IR before compaction.
|
||||||
let info = match naga::valid::Validator::new(params.validation_flags, validation_caps)
|
let info = match naga::valid::Validator::new(params.validation_flags, validation_caps)
|
||||||
|
.subgroup_stages(naga::valid::ShaderStages::all())
|
||||||
|
.subgroup_operations(naga::valid::SubgroupOperationSet::all())
|
||||||
.validate(&module)
|
.validate(&module)
|
||||||
{
|
{
|
||||||
Ok(info) => Some(info),
|
Ok(info) => Some(info),
|
||||||
@ -760,6 +762,8 @@ fn bulk_validate(args: Args, params: &Parameters) -> Result<(), Box<dyn std::err
|
|||||||
|
|
||||||
let mut validator =
|
let mut validator =
|
||||||
naga::valid::Validator::new(params.validation_flags, naga::valid::Capabilities::all());
|
naga::valid::Validator::new(params.validation_flags, naga::valid::Capabilities::all());
|
||||||
|
validator.subgroup_stages(naga::valid::ShaderStages::all());
|
||||||
|
validator.subgroup_operations(naga::valid::SubgroupOperationSet::all());
|
||||||
|
|
||||||
if let Err(error) = validator.validate(&module) {
|
if let Err(error) = validator.validate(&module) {
|
||||||
invalid.push(input_path.clone());
|
invalid.push(input_path.clone());
|
||||||
|
|||||||
@ -279,6 +279,94 @@ impl StatementGraph {
|
|||||||
crate::RayQueryFunction::Terminate => "RayQueryTerminate",
|
crate::RayQueryFunction::Terminate => "RayQueryTerminate",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
S::SubgroupBallot { result, predicate } => {
|
||||||
|
if let Some(predicate) = predicate {
|
||||||
|
self.dependencies.push((id, predicate, "predicate"));
|
||||||
|
}
|
||||||
|
self.emits.push((id, result));
|
||||||
|
"SubgroupBallot"
|
||||||
|
}
|
||||||
|
S::SubgroupCollectiveOperation {
|
||||||
|
op,
|
||||||
|
collective_op,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
self.dependencies.push((id, argument, "arg"));
|
||||||
|
self.emits.push((id, result));
|
||||||
|
match (collective_op, op) {
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
|
||||||
|
"SubgroupAll"
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
|
||||||
|
"SubgroupAny"
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
|
||||||
|
"SubgroupAdd"
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
|
||||||
|
"SubgroupMul"
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
|
||||||
|
"SubgroupMax"
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
|
||||||
|
"SubgroupMin"
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
|
||||||
|
"SubgroupAnd"
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
|
||||||
|
"SubgroupOr"
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
|
||||||
|
"SubgroupXor"
|
||||||
|
}
|
||||||
|
(
|
||||||
|
crate::CollectiveOperation::ExclusiveScan,
|
||||||
|
crate::SubgroupOperation::Add,
|
||||||
|
) => "SubgroupExclusiveAdd",
|
||||||
|
(
|
||||||
|
crate::CollectiveOperation::ExclusiveScan,
|
||||||
|
crate::SubgroupOperation::Mul,
|
||||||
|
) => "SubgroupExclusiveMul",
|
||||||
|
(
|
||||||
|
crate::CollectiveOperation::InclusiveScan,
|
||||||
|
crate::SubgroupOperation::Add,
|
||||||
|
) => "SubgroupInclusiveAdd",
|
||||||
|
(
|
||||||
|
crate::CollectiveOperation::InclusiveScan,
|
||||||
|
crate::SubgroupOperation::Mul,
|
||||||
|
) => "SubgroupInclusiveMul",
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
S::SubgroupGather {
|
||||||
|
mode,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
match mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {}
|
||||||
|
crate::GatherMode::Broadcast(index)
|
||||||
|
| crate::GatherMode::Shuffle(index)
|
||||||
|
| crate::GatherMode::ShuffleDown(index)
|
||||||
|
| crate::GatherMode::ShuffleUp(index)
|
||||||
|
| crate::GatherMode::ShuffleXor(index) => {
|
||||||
|
self.dependencies.push((id, index, "index"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.dependencies.push((id, argument, "arg"));
|
||||||
|
self.emits.push((id, result));
|
||||||
|
match mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst",
|
||||||
|
crate::GatherMode::Broadcast(_) => "SubgroupBroadcast",
|
||||||
|
crate::GatherMode::Shuffle(_) => "SubgroupShuffle",
|
||||||
|
crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown",
|
||||||
|
crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp",
|
||||||
|
crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor",
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
// Set the last node to the merge node
|
// Set the last node to the merge node
|
||||||
last_node = merge_id;
|
last_node = merge_id;
|
||||||
@ -587,6 +675,8 @@ fn write_function_expressions(
|
|||||||
let ty = if committed { "Committed" } else { "Candidate" };
|
let ty = if committed { "Committed" } else { "Candidate" };
|
||||||
(format!("rayQueryGet{}Intersection", ty).into(), 4)
|
(format!("rayQueryGet{}Intersection", ty).into(), 4)
|
||||||
}
|
}
|
||||||
|
E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4),
|
||||||
|
E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4),
|
||||||
};
|
};
|
||||||
|
|
||||||
// give uniform expressions an outline
|
// give uniform expressions an outline
|
||||||
|
|||||||
@ -50,6 +50,8 @@ bitflags::bitflags! {
|
|||||||
const INSTANCE_INDEX = 1 << 22;
|
const INSTANCE_INDEX = 1 << 22;
|
||||||
/// Sample specific LODs of cube / array shadow textures
|
/// Sample specific LODs of cube / array shadow textures
|
||||||
const TEXTURE_SHADOW_LOD = 1 << 23;
|
const TEXTURE_SHADOW_LOD = 1 << 23;
|
||||||
|
/// Subgroup operations
|
||||||
|
const SUBGROUP_OPERATIONS = 1 << 24;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,6 +119,7 @@ impl FeaturesManager {
|
|||||||
check_feature!(SAMPLE_VARIABLES, 400, 300);
|
check_feature!(SAMPLE_VARIABLES, 400, 300);
|
||||||
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
|
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
|
||||||
check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */);
|
check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */);
|
||||||
|
check_feature!(SUBGROUP_OPERATIONS, 430, 310);
|
||||||
match version {
|
match version {
|
||||||
Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
|
Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
|
||||||
_ => check_feature!(MULTI_VIEW, 140, 310),
|
_ => check_feature!(MULTI_VIEW, 140, 310),
|
||||||
@ -259,6 +262,22 @@ impl FeaturesManager {
|
|||||||
writeln!(out, "#extension GL_EXT_texture_shadow_lod : require")?;
|
writeln!(out, "#extension GL_EXT_texture_shadow_lod : require")?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.0.contains(Features::SUBGROUP_OPERATIONS) {
|
||||||
|
// https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt
|
||||||
|
writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?;
|
||||||
|
writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?;
|
||||||
|
writeln!(
|
||||||
|
out,
|
||||||
|
"#extension GL_KHR_shader_subgroup_arithmetic : require"
|
||||||
|
)?;
|
||||||
|
writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?;
|
||||||
|
writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?;
|
||||||
|
writeln!(
|
||||||
|
out,
|
||||||
|
"#extension GL_KHR_shader_subgroup_shuffle_relative : require"
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -518,6 +537,10 @@ impl<'a, W> Writer<'a, W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Expression::SubgroupBallotResult |
|
||||||
|
Expression::SubgroupOperationResult { .. } => {
|
||||||
|
features.request(Features::SUBGROUP_OPERATIONS)
|
||||||
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2390,6 +2390,125 @@ impl<'a, W: Write> Writer<'a, W> {
|
|||||||
writeln!(self.out, ");")?;
|
writeln!(self.out, ");")?;
|
||||||
}
|
}
|
||||||
Statement::RayQuery { .. } => unreachable!(),
|
Statement::RayQuery { .. } => unreachable!(),
|
||||||
|
Statement::SubgroupBallot { result, predicate } => {
|
||||||
|
write!(self.out, "{level}")?;
|
||||||
|
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||||
|
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
|
||||||
|
self.write_value_type(res_ty)?;
|
||||||
|
write!(self.out, " {res_name} = ")?;
|
||||||
|
self.named_expressions.insert(result, res_name);
|
||||||
|
|
||||||
|
write!(self.out, "subgroupBallot(")?;
|
||||||
|
match predicate {
|
||||||
|
Some(predicate) => self.write_expr(predicate, ctx)?,
|
||||||
|
None => write!(self.out, "true")?,
|
||||||
|
}
|
||||||
|
writeln!(self.out, ");")?;
|
||||||
|
}
|
||||||
|
Statement::SubgroupCollectiveOperation {
|
||||||
|
op,
|
||||||
|
collective_op,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
write!(self.out, "{level}")?;
|
||||||
|
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||||
|
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
|
||||||
|
self.write_value_type(res_ty)?;
|
||||||
|
write!(self.out, " {res_name} = ")?;
|
||||||
|
self.named_expressions.insert(result, res_name);
|
||||||
|
|
||||||
|
match (collective_op, op) {
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
|
||||||
|
write!(self.out, "subgroupAll(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
|
||||||
|
write!(self.out, "subgroupAny(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
|
||||||
|
write!(self.out, "subgroupAdd(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
|
||||||
|
write!(self.out, "subgroupMul(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
|
||||||
|
write!(self.out, "subgroupMax(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
|
||||||
|
write!(self.out, "subgroupMin(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
|
||||||
|
write!(self.out, "subgroupAnd(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
|
||||||
|
write!(self.out, "subgroupOr(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
|
||||||
|
write!(self.out, "subgroupXor(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
|
||||||
|
write!(self.out, "subgroupExclusiveAdd(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
|
||||||
|
write!(self.out, "subgroupExclusiveMul(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
|
||||||
|
write!(self.out, "subgroupInclusiveAdd(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
|
||||||
|
write!(self.out, "subgroupInclusiveMul(")?
|
||||||
|
}
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
self.write_expr(argument, ctx)?;
|
||||||
|
writeln!(self.out, ");")?;
|
||||||
|
}
|
||||||
|
Statement::SubgroupGather {
|
||||||
|
mode,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
write!(self.out, "{level}")?;
|
||||||
|
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||||
|
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
|
||||||
|
self.write_value_type(res_ty)?;
|
||||||
|
write!(self.out, " {res_name} = ")?;
|
||||||
|
self.named_expressions.insert(result, res_name);
|
||||||
|
|
||||||
|
match mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {
|
||||||
|
write!(self.out, "subgroupBroadcastFirst(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::Broadcast(_) => {
|
||||||
|
write!(self.out, "subgroupBroadcast(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::Shuffle(_) => {
|
||||||
|
write!(self.out, "subgroupShuffle(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleDown(_) => {
|
||||||
|
write!(self.out, "subgroupShuffleDown(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleUp(_) => {
|
||||||
|
write!(self.out, "subgroupShuffleUp(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleXor(_) => {
|
||||||
|
write!(self.out, "subgroupShuffleXor(")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.write_expr(argument, ctx)?;
|
||||||
|
match mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {}
|
||||||
|
crate::GatherMode::Broadcast(index)
|
||||||
|
| crate::GatherMode::Shuffle(index)
|
||||||
|
| crate::GatherMode::ShuffleDown(index)
|
||||||
|
| crate::GatherMode::ShuffleUp(index)
|
||||||
|
| crate::GatherMode::ShuffleXor(index) => {
|
||||||
|
write!(self.out, ", ")?;
|
||||||
|
self.write_expr(index, ctx)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writeln!(self.out, ");")?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -3658,7 +3777,9 @@ impl<'a, W: Write> Writer<'a, W> {
|
|||||||
Expression::CallResult(_)
|
Expression::CallResult(_)
|
||||||
| Expression::AtomicResult { .. }
|
| Expression::AtomicResult { .. }
|
||||||
| Expression::RayQueryProceedResult
|
| Expression::RayQueryProceedResult
|
||||||
| Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
|
| Expression::WorkGroupUniformLoadResult { .. }
|
||||||
|
| Expression::SubgroupOperationResult { .. }
|
||||||
|
| Expression::SubgroupBallotResult => unreachable!(),
|
||||||
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
|
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
|
||||||
Expression::ArrayLength(expr) => {
|
Expression::ArrayLength(expr) => {
|
||||||
write!(self.out, "uint(")?;
|
write!(self.out, "uint(")?;
|
||||||
@ -4227,6 +4348,9 @@ impl<'a, W: Write> Writer<'a, W> {
|
|||||||
if flags.contains(crate::Barrier::WORK_GROUP) {
|
if flags.contains(crate::Barrier::WORK_GROUP) {
|
||||||
writeln!(self.out, "{level}memoryBarrierShared();")?;
|
writeln!(self.out, "{level}memoryBarrierShared();")?;
|
||||||
}
|
}
|
||||||
|
if flags.contains(crate::Barrier::SUB_GROUP) {
|
||||||
|
writeln!(self.out, "{level}subgroupMemoryBarrier();")?;
|
||||||
|
}
|
||||||
writeln!(self.out, "{level}barrier();")?;
|
writeln!(self.out, "{level}barrier();")?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -4496,6 +4620,11 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s
|
|||||||
Bi::WorkGroupId => "gl_WorkGroupID",
|
Bi::WorkGroupId => "gl_WorkGroupID",
|
||||||
Bi::WorkGroupSize => "gl_WorkGroupSize",
|
Bi::WorkGroupSize => "gl_WorkGroupSize",
|
||||||
Bi::NumWorkGroups => "gl_NumWorkGroups",
|
Bi::NumWorkGroups => "gl_NumWorkGroups",
|
||||||
|
// subgroup
|
||||||
|
Bi::NumSubgroups => "gl_NumSubgroups",
|
||||||
|
Bi::SubgroupId => "gl_SubgroupID",
|
||||||
|
Bi::SubgroupSize => "gl_SubgroupSize",
|
||||||
|
Bi::SubgroupInvocationId => "gl_SubgroupInvocationID",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -179,6 +179,11 @@ impl crate::BuiltIn {
|
|||||||
// to this field will get replaced with references to `SPECIAL_CBUF_VAR`
|
// to this field will get replaced with references to `SPECIAL_CBUF_VAR`
|
||||||
// in `Writer::write_expr`.
|
// in `Writer::write_expr`.
|
||||||
Self::NumWorkGroups => "SV_GroupID",
|
Self::NumWorkGroups => "SV_GroupID",
|
||||||
|
// These builtins map to functions
|
||||||
|
Self::SubgroupSize
|
||||||
|
| Self::SubgroupInvocationId
|
||||||
|
| Self::NumSubgroups
|
||||||
|
| Self::SubgroupId => unreachable!(),
|
||||||
Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
|
Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
|
||||||
return Err(Error::Unimplemented(format!("builtin {self:?}")))
|
return Err(Error::Unimplemented(format!("builtin {self:?}")))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -77,6 +77,19 @@ enum Io {
|
|||||||
Output,
|
Output,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
|
||||||
|
let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
matches!(
|
||||||
|
builtin,
|
||||||
|
crate::BuiltIn::SubgroupSize
|
||||||
|
| crate::BuiltIn::SubgroupInvocationId
|
||||||
|
| crate::BuiltIn::NumSubgroups
|
||||||
|
| crate::BuiltIn::SubgroupId
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||||
pub fn new(out: W, options: &'a Options) -> Self {
|
pub fn new(out: W, options: &'a Options) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@ -161,6 +174,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for statement in func.body.iter() {
|
||||||
|
match *statement {
|
||||||
|
crate::Statement::SubgroupCollectiveOperation {
|
||||||
|
op: _,
|
||||||
|
collective_op: crate::CollectiveOperation::InclusiveScan,
|
||||||
|
argument,
|
||||||
|
result: _,
|
||||||
|
} => {
|
||||||
|
self.need_bake_expressions.insert(argument);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn write(
|
pub fn write(
|
||||||
@ -401,31 +427,32 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
// if they are struct, so that the `stage` argument here could be omitted.
|
// if they are struct, so that the `stage` argument here could be omitted.
|
||||||
fn write_semantic(
|
fn write_semantic(
|
||||||
&mut self,
|
&mut self,
|
||||||
binding: &crate::Binding,
|
binding: &Option<crate::Binding>,
|
||||||
stage: Option<(ShaderStage, Io)>,
|
stage: Option<(ShaderStage, Io)>,
|
||||||
) -> BackendResult {
|
) -> BackendResult {
|
||||||
match *binding {
|
match *binding {
|
||||||
crate::Binding::BuiltIn(builtin) => {
|
Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
|
||||||
let builtin_str = builtin.to_hlsl_str()?;
|
let builtin_str = builtin.to_hlsl_str()?;
|
||||||
write!(self.out, " : {builtin_str}")?;
|
write!(self.out, " : {builtin_str}")?;
|
||||||
}
|
}
|
||||||
crate::Binding::Location {
|
Some(crate::Binding::Location {
|
||||||
second_blend_source: true,
|
second_blend_source: true,
|
||||||
..
|
..
|
||||||
} => {
|
}) => {
|
||||||
write!(self.out, " : SV_Target1")?;
|
write!(self.out, " : SV_Target1")?;
|
||||||
}
|
}
|
||||||
crate::Binding::Location {
|
Some(crate::Binding::Location {
|
||||||
location,
|
location,
|
||||||
second_blend_source: false,
|
second_blend_source: false,
|
||||||
..
|
..
|
||||||
} => {
|
}) => {
|
||||||
if stage == Some((crate::ShaderStage::Fragment, Io::Output)) {
|
if stage == Some((crate::ShaderStage::Fragment, Io::Output)) {
|
||||||
write!(self.out, " : SV_Target{location}")?;
|
write!(self.out, " : SV_Target{location}")?;
|
||||||
} else {
|
} else {
|
||||||
write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
|
write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -446,17 +473,30 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
write!(self.out, "struct {struct_name}")?;
|
write!(self.out, "struct {struct_name}")?;
|
||||||
writeln!(self.out, " {{")?;
|
writeln!(self.out, " {{")?;
|
||||||
for m in members.iter() {
|
for m in members.iter() {
|
||||||
|
if is_subgroup_builtin_binding(&m.binding) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
write!(self.out, "{}", back::INDENT)?;
|
write!(self.out, "{}", back::INDENT)?;
|
||||||
if let Some(ref binding) = m.binding {
|
if let Some(ref binding) = m.binding {
|
||||||
self.write_modifier(binding)?;
|
self.write_modifier(binding)?;
|
||||||
}
|
}
|
||||||
self.write_type(module, m.ty)?;
|
self.write_type(module, m.ty)?;
|
||||||
write!(self.out, " {}", &m.name)?;
|
write!(self.out, " {}", &m.name)?;
|
||||||
if let Some(ref binding) = m.binding {
|
self.write_semantic(&m.binding, Some(shader_stage))?;
|
||||||
self.write_semantic(binding, Some(shader_stage))?;
|
|
||||||
}
|
|
||||||
writeln!(self.out, ";")?;
|
writeln!(self.out, ";")?;
|
||||||
}
|
}
|
||||||
|
if members.iter().any(|arg| {
|
||||||
|
matches!(
|
||||||
|
arg.binding,
|
||||||
|
Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
|
||||||
|
)
|
||||||
|
}) {
|
||||||
|
writeln!(
|
||||||
|
self.out,
|
||||||
|
"{}uint __local_invocation_index : SV_GroupIndex;",
|
||||||
|
back::INDENT
|
||||||
|
)?;
|
||||||
|
}
|
||||||
writeln!(self.out, "}};")?;
|
writeln!(self.out, "}};")?;
|
||||||
writeln!(self.out)?;
|
writeln!(self.out)?;
|
||||||
|
|
||||||
@ -557,8 +597,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Writes special interface structures for an entry point. The special structures have
|
/// Writes special interface structures for an entry point. The special structures have
|
||||||
/// all the fields flattened into them and sorted by binding. They are only needed for
|
/// all the fields flattened into them and sorted by binding. They are needed to emulate
|
||||||
/// VS outputs and FS inputs, so that these interfaces match.
|
/// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
|
||||||
fn write_ep_interface(
|
fn write_ep_interface(
|
||||||
&mut self,
|
&mut self,
|
||||||
module: &Module,
|
module: &Module,
|
||||||
@ -567,7 +607,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
ep_name: &str,
|
ep_name: &str,
|
||||||
) -> Result<EntryPointInterface, Error> {
|
) -> Result<EntryPointInterface, Error> {
|
||||||
Ok(EntryPointInterface {
|
Ok(EntryPointInterface {
|
||||||
input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment {
|
input: if !func.arguments.is_empty()
|
||||||
|
&& (stage == ShaderStage::Fragment
|
||||||
|
|| func
|
||||||
|
.arguments
|
||||||
|
.iter()
|
||||||
|
.any(|arg| is_subgroup_builtin_binding(&arg.binding)))
|
||||||
|
{
|
||||||
Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
|
Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@ -581,6 +627,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn write_ep_argument_initialization(
|
||||||
|
&mut self,
|
||||||
|
ep: &crate::EntryPoint,
|
||||||
|
ep_input: &EntryPointBinding,
|
||||||
|
fake_member: &EpStructMember,
|
||||||
|
) -> BackendResult {
|
||||||
|
match fake_member.binding {
|
||||||
|
Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
|
||||||
|
write!(self.out, "WaveGetLaneCount()")?
|
||||||
|
}
|
||||||
|
Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
|
||||||
|
write!(self.out, "WaveGetLaneIndex()")?
|
||||||
|
}
|
||||||
|
Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
|
||||||
|
self.out,
|
||||||
|
"({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
|
||||||
|
ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
|
||||||
|
)?,
|
||||||
|
Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
|
||||||
|
write!(
|
||||||
|
self.out,
|
||||||
|
"{}.__local_invocation_index / WaveGetLaneCount()",
|
||||||
|
ep_input.arg_name
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Write an entry point preface that initializes the arguments as specified in IR.
|
/// Write an entry point preface that initializes the arguments as specified in IR.
|
||||||
fn write_ep_arguments_initialization(
|
fn write_ep_arguments_initialization(
|
||||||
&mut self,
|
&mut self,
|
||||||
@ -588,6 +666,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
func: &crate::Function,
|
func: &crate::Function,
|
||||||
ep_index: u16,
|
ep_index: u16,
|
||||||
) -> BackendResult {
|
) -> BackendResult {
|
||||||
|
let ep = &module.entry_points[ep_index as usize];
|
||||||
let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
|
let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
|
||||||
Some(ep_input) => ep_input,
|
Some(ep_input) => ep_input,
|
||||||
None => return Ok(()),
|
None => return Ok(()),
|
||||||
@ -601,8 +680,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
match module.types[arg.ty].inner {
|
match module.types[arg.ty].inner {
|
||||||
TypeInner::Array { base, size, .. } => {
|
TypeInner::Array { base, size, .. } => {
|
||||||
self.write_array_size(module, base, size)?;
|
self.write_array_size(module, base, size)?;
|
||||||
let fake_member = fake_iter.next().unwrap();
|
write!(self.out, " = ")?;
|
||||||
writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
|
self.write_ep_argument_initialization(
|
||||||
|
ep,
|
||||||
|
&ep_input,
|
||||||
|
fake_iter.next().unwrap(),
|
||||||
|
)?;
|
||||||
|
writeln!(self.out, ";")?;
|
||||||
}
|
}
|
||||||
TypeInner::Struct { ref members, .. } => {
|
TypeInner::Struct { ref members, .. } => {
|
||||||
write!(self.out, " = {{ ")?;
|
write!(self.out, " = {{ ")?;
|
||||||
@ -610,14 +694,22 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
if index != 0 {
|
if index != 0 {
|
||||||
write!(self.out, ", ")?;
|
write!(self.out, ", ")?;
|
||||||
}
|
}
|
||||||
let fake_member = fake_iter.next().unwrap();
|
self.write_ep_argument_initialization(
|
||||||
write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
|
ep,
|
||||||
|
&ep_input,
|
||||||
|
fake_iter.next().unwrap(),
|
||||||
|
)?;
|
||||||
}
|
}
|
||||||
writeln!(self.out, " }};")?;
|
writeln!(self.out, " }};")?;
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
let fake_member = fake_iter.next().unwrap();
|
write!(self.out, " = ")?;
|
||||||
writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
|
self.write_ep_argument_initialization(
|
||||||
|
ep,
|
||||||
|
&ep_input,
|
||||||
|
fake_iter.next().unwrap(),
|
||||||
|
)?;
|
||||||
|
writeln!(self.out, ";")?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -932,9 +1024,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref binding) = member.binding {
|
self.write_semantic(&member.binding, shader_stage)?;
|
||||||
self.write_semantic(binding, shader_stage)?;
|
|
||||||
};
|
|
||||||
writeln!(self.out, ";")?;
|
writeln!(self.out, ";")?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1147,7 +1237,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
}
|
}
|
||||||
back::FunctionType::EntryPoint(ep_index) => {
|
back::FunctionType::EntryPoint(ep_index) => {
|
||||||
if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
|
if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
|
||||||
write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?;
|
write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
|
||||||
} else {
|
} else {
|
||||||
let stage = module.entry_points[ep_index as usize].stage;
|
let stage = module.entry_points[ep_index as usize].stage;
|
||||||
for (index, arg) in func.arguments.iter().enumerate() {
|
for (index, arg) in func.arguments.iter().enumerate() {
|
||||||
@ -1164,31 +1254,26 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
self.write_array_size(module, base, size)?;
|
self.write_array_size(module, base, size)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref binding) = arg.binding {
|
self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
|
||||||
self.write_semantic(binding, Some((stage, Io::Input)))?;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if need_workgroup_variables_initialization {
|
if need_workgroup_variables_initialization {
|
||||||
if !func.arguments.is_empty() {
|
if self.entry_point_io[ep_index as usize].input.is_some()
|
||||||
|
|| !func.arguments.is_empty()
|
||||||
|
{
|
||||||
write!(self.out, ", ")?;
|
write!(self.out, ", ")?;
|
||||||
}
|
}
|
||||||
write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
|
write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
// Ends of arguments
|
// Ends of arguments
|
||||||
write!(self.out, ")")?;
|
write!(self.out, ")")?;
|
||||||
|
|
||||||
// Write semantic if it present
|
// Write semantic if it present
|
||||||
if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
|
if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
|
||||||
let stage = module.entry_points[index as usize].stage;
|
let stage = module.entry_points[index as usize].stage;
|
||||||
if let Some(crate::FunctionResult {
|
if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
|
||||||
binding: Some(ref binding),
|
|
||||||
..
|
|
||||||
}) = func.result
|
|
||||||
{
|
|
||||||
self.write_semantic(binding, Some((stage, Io::Output)))?;
|
self.write_semantic(binding, Some((stage, Io::Output)))?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1988,6 +2073,129 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
writeln!(self.out, "{level}}}")?
|
writeln!(self.out, "{level}}}")?
|
||||||
}
|
}
|
||||||
Statement::RayQuery { .. } => unreachable!(),
|
Statement::RayQuery { .. } => unreachable!(),
|
||||||
|
Statement::SubgroupBallot { result, predicate } => {
|
||||||
|
write!(self.out, "{level}")?;
|
||||||
|
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||||
|
write!(self.out, "const uint4 {name} = ")?;
|
||||||
|
self.named_expressions.insert(result, name);
|
||||||
|
|
||||||
|
write!(self.out, "WaveActiveBallot(")?;
|
||||||
|
match predicate {
|
||||||
|
Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
|
||||||
|
None => write!(self.out, "true")?,
|
||||||
|
}
|
||||||
|
writeln!(self.out, ");")?;
|
||||||
|
}
|
||||||
|
Statement::SubgroupCollectiveOperation {
|
||||||
|
op,
|
||||||
|
collective_op,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
write!(self.out, "{level}")?;
|
||||||
|
write!(self.out, "const ")?;
|
||||||
|
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||||
|
match func_ctx.info[result].ty {
|
||||||
|
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
|
||||||
|
proc::TypeResolution::Value(ref value) => {
|
||||||
|
self.write_value_type(module, value)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
write!(self.out, " {name} = ")?;
|
||||||
|
self.named_expressions.insert(result, name);
|
||||||
|
|
||||||
|
match (collective_op, op) {
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
|
||||||
|
write!(self.out, "WaveActiveAllTrue(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
|
||||||
|
write!(self.out, "WaveActiveAnyTrue(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
|
||||||
|
write!(self.out, "WaveActiveSum(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
|
||||||
|
write!(self.out, "WaveActiveProduct(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
|
||||||
|
write!(self.out, "WaveActiveMax(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
|
||||||
|
write!(self.out, "WaveActiveMin(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
|
||||||
|
write!(self.out, "WaveActiveBitAnd(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
|
||||||
|
write!(self.out, "WaveActiveBitOr(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
|
||||||
|
write!(self.out, "WaveActiveBitXor(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
|
||||||
|
write!(self.out, "WavePrefixSum(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
|
||||||
|
write!(self.out, "WavePrefixProduct(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
|
||||||
|
self.write_expr(module, argument, func_ctx)?;
|
||||||
|
write!(self.out, " + WavePrefixSum(")?;
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
|
||||||
|
self.write_expr(module, argument, func_ctx)?;
|
||||||
|
write!(self.out, " * WavePrefixProduct(")?;
|
||||||
|
}
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
self.write_expr(module, argument, func_ctx)?;
|
||||||
|
writeln!(self.out, ");")?;
|
||||||
|
}
|
||||||
|
Statement::SubgroupGather {
|
||||||
|
mode,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
write!(self.out, "{level}")?;
|
||||||
|
write!(self.out, "const ")?;
|
||||||
|
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||||
|
match func_ctx.info[result].ty {
|
||||||
|
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
|
||||||
|
proc::TypeResolution::Value(ref value) => {
|
||||||
|
self.write_value_type(module, value)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
write!(self.out, " {name} = ")?;
|
||||||
|
self.named_expressions.insert(result, name);
|
||||||
|
|
||||||
|
if matches!(mode, crate::GatherMode::BroadcastFirst) {
|
||||||
|
write!(self.out, "WaveReadLaneFirst(")?;
|
||||||
|
self.write_expr(module, argument, func_ctx)?;
|
||||||
|
} else {
|
||||||
|
write!(self.out, "WaveReadLaneAt(")?;
|
||||||
|
self.write_expr(module, argument, func_ctx)?;
|
||||||
|
write!(self.out, ", ")?;
|
||||||
|
match mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => unreachable!(),
|
||||||
|
crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => {
|
||||||
|
self.write_expr(module, index, func_ctx)?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleDown(index) => {
|
||||||
|
write!(self.out, "WaveGetLaneIndex() + ")?;
|
||||||
|
self.write_expr(module, index, func_ctx)?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleUp(index) => {
|
||||||
|
write!(self.out, "WaveGetLaneIndex() - ")?;
|
||||||
|
self.write_expr(module, index, func_ctx)?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleXor(index) => {
|
||||||
|
write!(self.out, "WaveGetLaneIndex() ^ ")?;
|
||||||
|
self.write_expr(module, index, func_ctx)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writeln!(self.out, ");")?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -3134,7 +3342,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
Expression::CallResult(_)
|
Expression::CallResult(_)
|
||||||
| Expression::AtomicResult { .. }
|
| Expression::AtomicResult { .. }
|
||||||
| Expression::WorkGroupUniformLoadResult { .. }
|
| Expression::WorkGroupUniformLoadResult { .. }
|
||||||
| Expression::RayQueryProceedResult => {}
|
| Expression::RayQueryProceedResult
|
||||||
|
| Expression::SubgroupBallotResult
|
||||||
|
| Expression::SubgroupOperationResult { .. } => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !closing_bracket.is_empty() {
|
if !closing_bracket.is_empty() {
|
||||||
@ -3201,6 +3411,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
if barrier.contains(crate::Barrier::WORK_GROUP) {
|
if barrier.contains(crate::Barrier::WORK_GROUP) {
|
||||||
writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
|
writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
|
||||||
}
|
}
|
||||||
|
if barrier.contains(crate::Barrier::SUB_GROUP) {
|
||||||
|
// Does not exist in DirectX
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -436,6 +436,11 @@ impl ResolvedBinding {
|
|||||||
Bi::WorkGroupId => "threadgroup_position_in_grid",
|
Bi::WorkGroupId => "threadgroup_position_in_grid",
|
||||||
Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
|
Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
|
||||||
Bi::NumWorkGroups => "threadgroups_per_grid",
|
Bi::NumWorkGroups => "threadgroups_per_grid",
|
||||||
|
// subgroup
|
||||||
|
Bi::NumSubgroups => "simdgroups_per_threadgroup",
|
||||||
|
Bi::SubgroupId => "simdgroup_index_in_threadgroup",
|
||||||
|
Bi::SubgroupSize => "threads_per_simdgroup",
|
||||||
|
Bi::SubgroupInvocationId => "thread_index_in_simdgroup",
|
||||||
Bi::CullDistance | Bi::ViewIndex => {
|
Bi::CullDistance | Bi::ViewIndex => {
|
||||||
return Err(Error::UnsupportedBuiltIn(built_in))
|
return Err(Error::UnsupportedBuiltIn(built_in))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2042,6 +2042,8 @@ impl<W: Write> Writer<W> {
|
|||||||
crate::Expression::CallResult(_)
|
crate::Expression::CallResult(_)
|
||||||
| crate::Expression::AtomicResult { .. }
|
| crate::Expression::AtomicResult { .. }
|
||||||
| crate::Expression::WorkGroupUniformLoadResult { .. }
|
| crate::Expression::WorkGroupUniformLoadResult { .. }
|
||||||
|
| crate::Expression::SubgroupBallotResult
|
||||||
|
| crate::Expression::SubgroupOperationResult { .. }
|
||||||
| crate::Expression::RayQueryProceedResult => {
|
| crate::Expression::RayQueryProceedResult => {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
@ -3145,6 +3147,121 @@ impl<W: Write> Writer<W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
crate::Statement::SubgroupBallot { result, predicate } => {
|
||||||
|
write!(self.out, "{level}")?;
|
||||||
|
let name = self.namer.call("");
|
||||||
|
self.start_baking_expression(result, &context.expression, &name)?;
|
||||||
|
self.named_expressions.insert(result, name);
|
||||||
|
write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?;
|
||||||
|
if let Some(predicate) = predicate {
|
||||||
|
self.put_expression(predicate, &context.expression, true)?;
|
||||||
|
} else {
|
||||||
|
write!(self.out, "true")?;
|
||||||
|
}
|
||||||
|
writeln!(self.out, "), 0, 0, 0);")?;
|
||||||
|
}
|
||||||
|
crate::Statement::SubgroupCollectiveOperation {
|
||||||
|
op,
|
||||||
|
collective_op,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
write!(self.out, "{level}")?;
|
||||||
|
let name = self.namer.call("");
|
||||||
|
self.start_baking_expression(result, &context.expression, &name)?;
|
||||||
|
self.named_expressions.insert(result, name);
|
||||||
|
match (collective_op, op) {
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_all(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_any(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_sum(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_product(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_max(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_min(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_and(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_or(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_xor(")?
|
||||||
|
}
|
||||||
|
(
|
||||||
|
crate::CollectiveOperation::ExclusiveScan,
|
||||||
|
crate::SubgroupOperation::Add,
|
||||||
|
) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
|
||||||
|
(
|
||||||
|
crate::CollectiveOperation::ExclusiveScan,
|
||||||
|
crate::SubgroupOperation::Mul,
|
||||||
|
) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
|
||||||
|
(
|
||||||
|
crate::CollectiveOperation::InclusiveScan,
|
||||||
|
crate::SubgroupOperation::Add,
|
||||||
|
) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
|
||||||
|
(
|
||||||
|
crate::CollectiveOperation::InclusiveScan,
|
||||||
|
crate::SubgroupOperation::Mul,
|
||||||
|
) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
self.put_expression(argument, &context.expression, true)?;
|
||||||
|
writeln!(self.out, ");")?;
|
||||||
|
}
|
||||||
|
crate::Statement::SubgroupGather {
|
||||||
|
mode,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
write!(self.out, "{level}")?;
|
||||||
|
let name = self.namer.call("");
|
||||||
|
self.start_baking_expression(result, &context.expression, &name)?;
|
||||||
|
self.named_expressions.insert(result, name);
|
||||||
|
match mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::Broadcast(_) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::Shuffle(_) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleDown(_) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleUp(_) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleXor(_) => {
|
||||||
|
write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.put_expression(argument, &context.expression, true)?;
|
||||||
|
match mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {}
|
||||||
|
crate::GatherMode::Broadcast(index)
|
||||||
|
| crate::GatherMode::Shuffle(index)
|
||||||
|
| crate::GatherMode::ShuffleDown(index)
|
||||||
|
| crate::GatherMode::ShuffleUp(index)
|
||||||
|
| crate::GatherMode::ShuffleXor(index) => {
|
||||||
|
write!(self.out, ", ")?;
|
||||||
|
self.put_expression(index, &context.expression, true)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writeln!(self.out, ");")?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4492,6 +4609,12 @@ impl<W: Write> Writer<W> {
|
|||||||
"{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
|
"{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
|
if flags.contains(crate::Barrier::SUB_GROUP) {
|
||||||
|
writeln!(
|
||||||
|
self.out,
|
||||||
|
"{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
|
||||||
|
)?;
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -4762,8 +4885,8 @@ fn test_stack_size() {
|
|||||||
}
|
}
|
||||||
let stack_size = addresses_end - addresses_start;
|
let stack_size = addresses_end - addresses_start;
|
||||||
// check the size (in debug only)
|
// check the size (in debug only)
|
||||||
// last observed macOS value: 19152 (CI)
|
// last observed macOS value: 22256 (CI)
|
||||||
if !(9000..=20000).contains(&stack_size) {
|
if !(15000..=25000).contains(&stack_size) {
|
||||||
panic!("`put_block` stack size {stack_size} has changed!");
|
panic!("`put_block` stack size {stack_size} has changed!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -522,7 +522,9 @@ fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
|
|||||||
ty: _,
|
ty: _,
|
||||||
comparison: _,
|
comparison: _,
|
||||||
}
|
}
|
||||||
| Expression::WorkGroupUniformLoadResult { ty: _ } => {}
|
| Expression::WorkGroupUniformLoadResult { ty: _ }
|
||||||
|
| Expression::SubgroupBallotResult
|
||||||
|
| Expression::SubgroupOperationResult { .. } => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -637,6 +639,41 @@ fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
|
|||||||
adjust(pointer);
|
adjust(pointer);
|
||||||
adjust(result);
|
adjust(result);
|
||||||
}
|
}
|
||||||
|
Statement::SubgroupBallot {
|
||||||
|
ref mut result,
|
||||||
|
ref mut predicate,
|
||||||
|
} => {
|
||||||
|
if let Some(ref mut predicate) = *predicate {
|
||||||
|
adjust(predicate);
|
||||||
|
}
|
||||||
|
adjust(result);
|
||||||
|
}
|
||||||
|
Statement::SubgroupCollectiveOperation {
|
||||||
|
ref mut argument,
|
||||||
|
ref mut result,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
adjust(argument);
|
||||||
|
adjust(result);
|
||||||
|
}
|
||||||
|
Statement::SubgroupGather {
|
||||||
|
ref mut mode,
|
||||||
|
ref mut argument,
|
||||||
|
ref mut result,
|
||||||
|
} => {
|
||||||
|
match *mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {}
|
||||||
|
crate::GatherMode::Broadcast(ref mut index)
|
||||||
|
| crate::GatherMode::Shuffle(ref mut index)
|
||||||
|
| crate::GatherMode::ShuffleDown(ref mut index)
|
||||||
|
| crate::GatherMode::ShuffleUp(ref mut index)
|
||||||
|
| crate::GatherMode::ShuffleXor(ref mut index) => {
|
||||||
|
adjust(index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
adjust(argument);
|
||||||
|
adjust(result)
|
||||||
|
}
|
||||||
Statement::Call {
|
Statement::Call {
|
||||||
ref mut arguments,
|
ref mut arguments,
|
||||||
ref mut result,
|
ref mut result,
|
||||||
|
|||||||
@ -1279,7 +1279,9 @@ impl<'w> BlockContext<'w> {
|
|||||||
crate::Expression::CallResult(_)
|
crate::Expression::CallResult(_)
|
||||||
| crate::Expression::AtomicResult { .. }
|
| crate::Expression::AtomicResult { .. }
|
||||||
| crate::Expression::WorkGroupUniformLoadResult { .. }
|
| crate::Expression::WorkGroupUniformLoadResult { .. }
|
||||||
| crate::Expression::RayQueryProceedResult => self.cached[expr_handle],
|
| crate::Expression::RayQueryProceedResult
|
||||||
|
| crate::Expression::SubgroupBallotResult
|
||||||
|
| crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
|
||||||
crate::Expression::As {
|
crate::Expression::As {
|
||||||
expr,
|
expr,
|
||||||
kind,
|
kind,
|
||||||
@ -2490,6 +2492,27 @@ impl<'w> BlockContext<'w> {
|
|||||||
crate::Statement::RayQuery { query, ref fun } => {
|
crate::Statement::RayQuery { query, ref fun } => {
|
||||||
self.write_ray_query_function(query, fun, &mut block);
|
self.write_ray_query_function(query, fun, &mut block);
|
||||||
}
|
}
|
||||||
|
crate::Statement::SubgroupBallot {
|
||||||
|
result,
|
||||||
|
ref predicate,
|
||||||
|
} => {
|
||||||
|
self.write_subgroup_ballot(predicate, result, &mut block)?;
|
||||||
|
}
|
||||||
|
crate::Statement::SubgroupCollectiveOperation {
|
||||||
|
ref op,
|
||||||
|
ref collective_op,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
|
||||||
|
}
|
||||||
|
crate::Statement::SubgroupGather {
|
||||||
|
ref mode,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
self.write_subgroup_gather(mode, argument, result, &mut block)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1073,6 +1073,73 @@ impl super::Instruction {
|
|||||||
instruction.add_operand(semantics_id);
|
instruction.add_operand(semantics_id);
|
||||||
instruction
|
instruction
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Group Instructions
|
||||||
|
|
||||||
|
pub(super) fn group_non_uniform_ballot(
|
||||||
|
result_type_id: Word,
|
||||||
|
id: Word,
|
||||||
|
exec_scope_id: Word,
|
||||||
|
predicate: Word,
|
||||||
|
) -> Self {
|
||||||
|
let mut instruction = Self::new(Op::GroupNonUniformBallot);
|
||||||
|
instruction.set_type(result_type_id);
|
||||||
|
instruction.set_result(id);
|
||||||
|
instruction.add_operand(exec_scope_id);
|
||||||
|
instruction.add_operand(predicate);
|
||||||
|
|
||||||
|
instruction
|
||||||
|
}
|
||||||
|
pub(super) fn group_non_uniform_broadcast_first(
|
||||||
|
result_type_id: Word,
|
||||||
|
id: Word,
|
||||||
|
exec_scope_id: Word,
|
||||||
|
value: Word,
|
||||||
|
) -> Self {
|
||||||
|
let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst);
|
||||||
|
instruction.set_type(result_type_id);
|
||||||
|
instruction.set_result(id);
|
||||||
|
instruction.add_operand(exec_scope_id);
|
||||||
|
instruction.add_operand(value);
|
||||||
|
|
||||||
|
instruction
|
||||||
|
}
|
||||||
|
pub(super) fn group_non_uniform_gather(
|
||||||
|
op: Op,
|
||||||
|
result_type_id: Word,
|
||||||
|
id: Word,
|
||||||
|
exec_scope_id: Word,
|
||||||
|
value: Word,
|
||||||
|
index: Word,
|
||||||
|
) -> Self {
|
||||||
|
let mut instruction = Self::new(op);
|
||||||
|
instruction.set_type(result_type_id);
|
||||||
|
instruction.set_result(id);
|
||||||
|
instruction.add_operand(exec_scope_id);
|
||||||
|
instruction.add_operand(value);
|
||||||
|
instruction.add_operand(index);
|
||||||
|
|
||||||
|
instruction
|
||||||
|
}
|
||||||
|
pub(super) fn group_non_uniform_arithmetic(
|
||||||
|
op: Op,
|
||||||
|
result_type_id: Word,
|
||||||
|
id: Word,
|
||||||
|
exec_scope_id: Word,
|
||||||
|
group_op: Option<spirv::GroupOperation>,
|
||||||
|
value: Word,
|
||||||
|
) -> Self {
|
||||||
|
let mut instruction = Self::new(op);
|
||||||
|
instruction.set_type(result_type_id);
|
||||||
|
instruction.set_result(id);
|
||||||
|
instruction.add_operand(exec_scope_id);
|
||||||
|
if let Some(group_op) = group_op {
|
||||||
|
instruction.add_operand(group_op as u32);
|
||||||
|
}
|
||||||
|
instruction.add_operand(value);
|
||||||
|
|
||||||
|
instruction
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<crate::StorageFormat> for spirv::ImageFormat {
|
impl From<crate::StorageFormat> for spirv::ImageFormat {
|
||||||
|
|||||||
@ -13,6 +13,7 @@ mod layout;
|
|||||||
mod ray;
|
mod ray;
|
||||||
mod recyclable;
|
mod recyclable;
|
||||||
mod selection;
|
mod selection;
|
||||||
|
mod subgroup;
|
||||||
mod writer;
|
mod writer;
|
||||||
|
|
||||||
pub use spirv::Capability;
|
pub use spirv::Capability;
|
||||||
|
|||||||
207
naga/src/back/spv/subgroup.rs
Normal file
207
naga/src/back/spv/subgroup.rs
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
use super::{Block, BlockContext, Error, Instruction};
|
||||||
|
use crate::{
|
||||||
|
arena::Handle,
|
||||||
|
back::spv::{LocalType, LookupType},
|
||||||
|
TypeInner,
|
||||||
|
};
|
||||||
|
|
||||||
|
impl<'w> BlockContext<'w> {
|
||||||
|
pub(super) fn write_subgroup_ballot(
|
||||||
|
&mut self,
|
||||||
|
predicate: &Option<Handle<crate::Expression>>,
|
||||||
|
result: Handle<crate::Expression>,
|
||||||
|
block: &mut Block,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
self.writer.require_any(
|
||||||
|
"GroupNonUniformBallot",
|
||||||
|
&[spirv::Capability::GroupNonUniformBallot],
|
||||||
|
)?;
|
||||||
|
let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
|
||||||
|
vector_size: Some(crate::VectorSize::Quad),
|
||||||
|
scalar: crate::Scalar::U32,
|
||||||
|
pointer_space: None,
|
||||||
|
}));
|
||||||
|
let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
|
||||||
|
let predicate = if let Some(predicate) = *predicate {
|
||||||
|
self.cached[predicate]
|
||||||
|
} else {
|
||||||
|
self.writer.get_constant_scalar(crate::Literal::Bool(true))
|
||||||
|
};
|
||||||
|
let id = self.gen_id();
|
||||||
|
block.body.push(Instruction::group_non_uniform_ballot(
|
||||||
|
vec4_u32_type_id,
|
||||||
|
id,
|
||||||
|
exec_scope_id,
|
||||||
|
predicate,
|
||||||
|
));
|
||||||
|
self.cached[result] = id;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
pub(super) fn write_subgroup_operation(
|
||||||
|
&mut self,
|
||||||
|
op: &crate::SubgroupOperation,
|
||||||
|
collective_op: &crate::CollectiveOperation,
|
||||||
|
argument: Handle<crate::Expression>,
|
||||||
|
result: Handle<crate::Expression>,
|
||||||
|
block: &mut Block,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
use crate::SubgroupOperation as sg;
|
||||||
|
match *op {
|
||||||
|
sg::All | sg::Any => {
|
||||||
|
self.writer.require_any(
|
||||||
|
"GroupNonUniformVote",
|
||||||
|
&[spirv::Capability::GroupNonUniformVote],
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
self.writer.require_any(
|
||||||
|
"GroupNonUniformArithmetic",
|
||||||
|
&[spirv::Capability::GroupNonUniformArithmetic],
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let id = self.gen_id();
|
||||||
|
let result_ty = &self.fun_info[result].ty;
|
||||||
|
let result_type_id = self.get_expression_type_id(result_ty);
|
||||||
|
let result_ty_inner = result_ty.inner_with(&self.ir_module.types);
|
||||||
|
|
||||||
|
let (is_scalar, scalar) = match *result_ty_inner {
|
||||||
|
TypeInner::Scalar(kind) => (true, kind),
|
||||||
|
TypeInner::Vector { scalar: kind, .. } => (false, kind),
|
||||||
|
_ => unimplemented!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::ScalarKind as sk;
|
||||||
|
let spirv_op = match (scalar.kind, *op) {
|
||||||
|
(sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll,
|
||||||
|
(sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny,
|
||||||
|
(_, sg::All | sg::Any) => unimplemented!(),
|
||||||
|
|
||||||
|
(sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd,
|
||||||
|
(sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd,
|
||||||
|
(sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul,
|
||||||
|
(sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul,
|
||||||
|
(sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax,
|
||||||
|
(sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax,
|
||||||
|
(sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax,
|
||||||
|
(sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin,
|
||||||
|
(sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin,
|
||||||
|
(sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin,
|
||||||
|
(_, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(),
|
||||||
|
|
||||||
|
(sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd,
|
||||||
|
(sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr,
|
||||||
|
(sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor,
|
||||||
|
(sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd,
|
||||||
|
(sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr,
|
||||||
|
(sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor,
|
||||||
|
(_, sg::And | sg::Or | sg::Xor) => unimplemented!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
|
||||||
|
|
||||||
|
use crate::CollectiveOperation as c;
|
||||||
|
let group_op = match *op {
|
||||||
|
sg::All | sg::Any => None,
|
||||||
|
_ => Some(match *collective_op {
|
||||||
|
c::Reduce => spirv::GroupOperation::Reduce,
|
||||||
|
c::InclusiveScan => spirv::GroupOperation::InclusiveScan,
|
||||||
|
c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
let arg_id = self.cached[argument];
|
||||||
|
block.body.push(Instruction::group_non_uniform_arithmetic(
|
||||||
|
spirv_op,
|
||||||
|
result_type_id,
|
||||||
|
id,
|
||||||
|
exec_scope_id,
|
||||||
|
group_op,
|
||||||
|
arg_id,
|
||||||
|
));
|
||||||
|
self.cached[result] = id;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
pub(super) fn write_subgroup_gather(
|
||||||
|
&mut self,
|
||||||
|
mode: &crate::GatherMode,
|
||||||
|
argument: Handle<crate::Expression>,
|
||||||
|
result: Handle<crate::Expression>,
|
||||||
|
block: &mut Block,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
self.writer.require_any(
|
||||||
|
"GroupNonUniformBallot",
|
||||||
|
&[spirv::Capability::GroupNonUniformBallot],
|
||||||
|
)?;
|
||||||
|
match *mode {
|
||||||
|
crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => {
|
||||||
|
self.writer.require_any(
|
||||||
|
"GroupNonUniformBallot",
|
||||||
|
&[spirv::Capability::GroupNonUniformBallot],
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) => {
|
||||||
|
self.writer.require_any(
|
||||||
|
"GroupNonUniformShuffle",
|
||||||
|
&[spirv::Capability::GroupNonUniformShuffle],
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => {
|
||||||
|
self.writer.require_any(
|
||||||
|
"GroupNonUniformShuffleRelative",
|
||||||
|
&[spirv::Capability::GroupNonUniformShuffleRelative],
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let id = self.gen_id();
|
||||||
|
let result_ty = &self.fun_info[result].ty;
|
||||||
|
let result_type_id = self.get_expression_type_id(result_ty);
|
||||||
|
|
||||||
|
let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
|
||||||
|
|
||||||
|
let arg_id = self.cached[argument];
|
||||||
|
match *mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {
|
||||||
|
block
|
||||||
|
.body
|
||||||
|
.push(Instruction::group_non_uniform_broadcast_first(
|
||||||
|
result_type_id,
|
||||||
|
id,
|
||||||
|
exec_scope_id,
|
||||||
|
arg_id,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
crate::GatherMode::Broadcast(index)
|
||||||
|
| crate::GatherMode::Shuffle(index)
|
||||||
|
| crate::GatherMode::ShuffleDown(index)
|
||||||
|
| crate::GatherMode::ShuffleUp(index)
|
||||||
|
| crate::GatherMode::ShuffleXor(index) => {
|
||||||
|
let index_id = self.cached[index];
|
||||||
|
let op = match *mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => unreachable!(),
|
||||||
|
// Use shuffle to emit broadcast to allow the index to
|
||||||
|
// be dynamically uniform on Vulkan 1.1. The argument to
|
||||||
|
// OpGroupNonUniformBroadcast must be a constant pre SPIR-V
|
||||||
|
// 1.5 (vulkan 1.2)
|
||||||
|
crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle,
|
||||||
|
crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle,
|
||||||
|
crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown,
|
||||||
|
crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp,
|
||||||
|
crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor,
|
||||||
|
};
|
||||||
|
block.body.push(Instruction::group_non_uniform_gather(
|
||||||
|
op,
|
||||||
|
result_type_id,
|
||||||
|
id,
|
||||||
|
exec_scope_id,
|
||||||
|
arg_id,
|
||||||
|
index_id,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.cached[result] = id;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1310,7 +1310,11 @@ impl Writer {
|
|||||||
spirv::MemorySemantics::WORKGROUP_MEMORY,
|
spirv::MemorySemantics::WORKGROUP_MEMORY,
|
||||||
flags.contains(crate::Barrier::WORK_GROUP),
|
flags.contains(crate::Barrier::WORK_GROUP),
|
||||||
);
|
);
|
||||||
let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32);
|
let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
|
||||||
|
self.get_index_constant(spirv::Scope::Subgroup as u32)
|
||||||
|
} else {
|
||||||
|
self.get_index_constant(spirv::Scope::Workgroup as u32)
|
||||||
|
};
|
||||||
let mem_scope_id = self.get_index_constant(memory_scope as u32);
|
let mem_scope_id = self.get_index_constant(memory_scope as u32);
|
||||||
let semantics_id = self.get_index_constant(semantics.bits());
|
let semantics_id = self.get_index_constant(semantics.bits());
|
||||||
block.body.push(Instruction::control_barrier(
|
block.body.push(Instruction::control_barrier(
|
||||||
@ -1585,6 +1589,41 @@ impl Writer {
|
|||||||
Bi::WorkGroupId => BuiltIn::WorkgroupId,
|
Bi::WorkGroupId => BuiltIn::WorkgroupId,
|
||||||
Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
|
Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
|
||||||
Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
|
Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
|
||||||
|
// Subgroup
|
||||||
|
Bi::NumSubgroups => {
|
||||||
|
self.require_any(
|
||||||
|
"`num_subgroups` built-in",
|
||||||
|
&[spirv::Capability::GroupNonUniform],
|
||||||
|
)?;
|
||||||
|
BuiltIn::NumSubgroups
|
||||||
|
}
|
||||||
|
Bi::SubgroupId => {
|
||||||
|
self.require_any(
|
||||||
|
"`subgroup_id` built-in",
|
||||||
|
&[spirv::Capability::GroupNonUniform],
|
||||||
|
)?;
|
||||||
|
BuiltIn::SubgroupId
|
||||||
|
}
|
||||||
|
Bi::SubgroupSize => {
|
||||||
|
self.require_any(
|
||||||
|
"`subgroup_size` built-in",
|
||||||
|
&[
|
||||||
|
spirv::Capability::GroupNonUniform,
|
||||||
|
spirv::Capability::SubgroupBallotKHR,
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
BuiltIn::SubgroupSize
|
||||||
|
}
|
||||||
|
Bi::SubgroupInvocationId => {
|
||||||
|
self.require_any(
|
||||||
|
"`subgroup_invocation_id` built-in",
|
||||||
|
&[
|
||||||
|
spirv::Capability::GroupNonUniform,
|
||||||
|
spirv::Capability::SubgroupBallotKHR,
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
BuiltIn::SubgroupLocalInvocationId
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);
|
self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);
|
||||||
|
|||||||
@ -924,8 +924,124 @@ impl<W: Write> Writer<W> {
|
|||||||
if barrier.contains(crate::Barrier::WORK_GROUP) {
|
if barrier.contains(crate::Barrier::WORK_GROUP) {
|
||||||
writeln!(self.out, "{level}workgroupBarrier();")?;
|
writeln!(self.out, "{level}workgroupBarrier();")?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if barrier.contains(crate::Barrier::SUB_GROUP) {
|
||||||
|
writeln!(self.out, "{level}subgroupBarrier();")?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Statement::RayQuery { .. } => unreachable!(),
|
Statement::RayQuery { .. } => unreachable!(),
|
||||||
|
Statement::SubgroupBallot { result, predicate } => {
|
||||||
|
write!(self.out, "{level}")?;
|
||||||
|
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||||
|
self.start_named_expr(module, result, func_ctx, &res_name)?;
|
||||||
|
self.named_expressions.insert(result, res_name);
|
||||||
|
|
||||||
|
write!(self.out, "subgroupBallot(")?;
|
||||||
|
if let Some(predicate) = predicate {
|
||||||
|
self.write_expr(module, predicate, func_ctx)?;
|
||||||
|
}
|
||||||
|
writeln!(self.out, ");")?;
|
||||||
|
}
|
||||||
|
Statement::SubgroupCollectiveOperation {
|
||||||
|
op,
|
||||||
|
collective_op,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
write!(self.out, "{level}")?;
|
||||||
|
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||||
|
self.start_named_expr(module, result, func_ctx, &res_name)?;
|
||||||
|
self.named_expressions.insert(result, res_name);
|
||||||
|
|
||||||
|
match (collective_op, op) {
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
|
||||||
|
write!(self.out, "subgroupAll(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
|
||||||
|
write!(self.out, "subgroupAny(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
|
||||||
|
write!(self.out, "subgroupAdd(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
|
||||||
|
write!(self.out, "subgroupMul(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
|
||||||
|
write!(self.out, "subgroupMax(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
|
||||||
|
write!(self.out, "subgroupMin(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
|
||||||
|
write!(self.out, "subgroupAnd(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
|
||||||
|
write!(self.out, "subgroupOr(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
|
||||||
|
write!(self.out, "subgroupXor(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
|
||||||
|
write!(self.out, "subgroupExclusiveAdd(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
|
||||||
|
write!(self.out, "subgroupExclusiveMul(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
|
||||||
|
write!(self.out, "subgroupInclusiveAdd(")?
|
||||||
|
}
|
||||||
|
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
|
||||||
|
write!(self.out, "subgroupInclusiveMul(")?
|
||||||
|
}
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
self.write_expr(module, argument, func_ctx)?;
|
||||||
|
writeln!(self.out, ");")?;
|
||||||
|
}
|
||||||
|
Statement::SubgroupGather {
|
||||||
|
mode,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
write!(self.out, "{level}")?;
|
||||||
|
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||||
|
self.start_named_expr(module, result, func_ctx, &res_name)?;
|
||||||
|
self.named_expressions.insert(result, res_name);
|
||||||
|
|
||||||
|
match mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {
|
||||||
|
write!(self.out, "subgroupBroadcastFirst(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::Broadcast(_) => {
|
||||||
|
write!(self.out, "subgroupBroadcast(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::Shuffle(_) => {
|
||||||
|
write!(self.out, "subgroupShuffle(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleDown(_) => {
|
||||||
|
write!(self.out, "subgroupShuffleDown(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleUp(_) => {
|
||||||
|
write!(self.out, "subgroupShuffleUp(")?;
|
||||||
|
}
|
||||||
|
crate::GatherMode::ShuffleXor(_) => {
|
||||||
|
write!(self.out, "subgroupShuffleXor(")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.write_expr(module, argument, func_ctx)?;
|
||||||
|
match mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {}
|
||||||
|
crate::GatherMode::Broadcast(index)
|
||||||
|
| crate::GatherMode::Shuffle(index)
|
||||||
|
| crate::GatherMode::ShuffleDown(index)
|
||||||
|
| crate::GatherMode::ShuffleUp(index)
|
||||||
|
| crate::GatherMode::ShuffleXor(index) => {
|
||||||
|
write!(self.out, ", ")?;
|
||||||
|
self.write_expr(module, index, func_ctx)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writeln!(self.out, ");")?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -1698,6 +1814,8 @@ impl<W: Write> Writer<W> {
|
|||||||
Expression::CallResult(_)
|
Expression::CallResult(_)
|
||||||
| Expression::AtomicResult { .. }
|
| Expression::AtomicResult { .. }
|
||||||
| Expression::RayQueryProceedResult
|
| Expression::RayQueryProceedResult
|
||||||
|
| Expression::SubgroupBallotResult
|
||||||
|
| Expression::SubgroupOperationResult { .. }
|
||||||
| Expression::WorkGroupUniformLoadResult { .. } => {}
|
| Expression::WorkGroupUniformLoadResult { .. } => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1799,6 +1917,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> {
|
|||||||
Bi::SampleMask => "sample_mask",
|
Bi::SampleMask => "sample_mask",
|
||||||
Bi::PrimitiveIndex => "primitive_index",
|
Bi::PrimitiveIndex => "primitive_index",
|
||||||
Bi::ViewIndex => "view_index",
|
Bi::ViewIndex => "view_index",
|
||||||
|
Bi::NumSubgroups => "num_subgroups",
|
||||||
|
Bi::SubgroupId => "subgroup_id",
|
||||||
|
Bi::SubgroupSize => "subgroup_size",
|
||||||
|
Bi::SubgroupInvocationId => "subgroup_invocation_id",
|
||||||
Bi::BaseInstance
|
Bi::BaseInstance
|
||||||
| Bi::BaseVertex
|
| Bi::BaseVertex
|
||||||
| Bi::ClipDistance
|
| Bi::ClipDistance
|
||||||
|
|||||||
@ -72,6 +72,7 @@ impl<'tracer> ExpressionTracer<'tracer> {
|
|||||||
| Ex::GlobalVariable(_)
|
| Ex::GlobalVariable(_)
|
||||||
| Ex::LocalVariable(_)
|
| Ex::LocalVariable(_)
|
||||||
| Ex::CallResult(_)
|
| Ex::CallResult(_)
|
||||||
|
| Ex::SubgroupBallotResult
|
||||||
| Ex::RayQueryProceedResult => {}
|
| Ex::RayQueryProceedResult => {}
|
||||||
|
|
||||||
Ex::Constant(handle) => {
|
Ex::Constant(handle) => {
|
||||||
@ -192,6 +193,7 @@ impl<'tracer> ExpressionTracer<'tracer> {
|
|||||||
Ex::AtomicResult { ty, comparison: _ } => self.types_used.insert(ty),
|
Ex::AtomicResult { ty, comparison: _ } => self.types_used.insert(ty),
|
||||||
Ex::WorkGroupUniformLoadResult { ty } => self.types_used.insert(ty),
|
Ex::WorkGroupUniformLoadResult { ty } => self.types_used.insert(ty),
|
||||||
Ex::ArrayLength(expr) => self.expressions_used.insert(expr),
|
Ex::ArrayLength(expr) => self.expressions_used.insert(expr),
|
||||||
|
Ex::SubgroupOperationResult { ty } => self.types_used.insert(ty),
|
||||||
Ex::RayQueryGetIntersection {
|
Ex::RayQueryGetIntersection {
|
||||||
query,
|
query,
|
||||||
committed: _,
|
committed: _,
|
||||||
@ -223,6 +225,7 @@ impl ModuleMap {
|
|||||||
| Ex::GlobalVariable(_)
|
| Ex::GlobalVariable(_)
|
||||||
| Ex::LocalVariable(_)
|
| Ex::LocalVariable(_)
|
||||||
| Ex::CallResult(_)
|
| Ex::CallResult(_)
|
||||||
|
| Ex::SubgroupBallotResult
|
||||||
| Ex::RayQueryProceedResult => {}
|
| Ex::RayQueryProceedResult => {}
|
||||||
|
|
||||||
// All overrides are retained, so their handles never change.
|
// All overrides are retained, so their handles never change.
|
||||||
@ -353,6 +356,7 @@ impl ModuleMap {
|
|||||||
comparison: _,
|
comparison: _,
|
||||||
} => self.types.adjust(ty),
|
} => self.types.adjust(ty),
|
||||||
Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty),
|
Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty),
|
||||||
|
Ex::SubgroupOperationResult { ref mut ty } => self.types.adjust(ty),
|
||||||
Ex::ArrayLength(ref mut expr) => adjust(expr),
|
Ex::ArrayLength(ref mut expr) => adjust(expr),
|
||||||
Ex::RayQueryGetIntersection {
|
Ex::RayQueryGetIntersection {
|
||||||
ref mut query,
|
ref mut query,
|
||||||
|
|||||||
@ -97,6 +97,39 @@ impl FunctionTracer<'_> {
|
|||||||
self.expressions_used.insert(query);
|
self.expressions_used.insert(query);
|
||||||
self.trace_ray_query_function(fun);
|
self.trace_ray_query_function(fun);
|
||||||
}
|
}
|
||||||
|
St::SubgroupBallot { result, predicate } => {
|
||||||
|
if let Some(predicate) = predicate {
|
||||||
|
self.expressions_used.insert(predicate)
|
||||||
|
}
|
||||||
|
self.expressions_used.insert(result)
|
||||||
|
}
|
||||||
|
St::SubgroupCollectiveOperation {
|
||||||
|
op: _,
|
||||||
|
collective_op: _,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
self.expressions_used.insert(argument);
|
||||||
|
self.expressions_used.insert(result)
|
||||||
|
}
|
||||||
|
St::SubgroupGather {
|
||||||
|
mode,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
match mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {}
|
||||||
|
crate::GatherMode::Broadcast(index)
|
||||||
|
| crate::GatherMode::Shuffle(index)
|
||||||
|
| crate::GatherMode::ShuffleDown(index)
|
||||||
|
| crate::GatherMode::ShuffleUp(index)
|
||||||
|
| crate::GatherMode::ShuffleXor(index) => {
|
||||||
|
self.expressions_used.insert(index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.expressions_used.insert(argument);
|
||||||
|
self.expressions_used.insert(result)
|
||||||
|
}
|
||||||
|
|
||||||
// Trivial statements.
|
// Trivial statements.
|
||||||
St::Break
|
St::Break
|
||||||
@ -250,6 +283,40 @@ impl FunctionMap {
|
|||||||
adjust(query);
|
adjust(query);
|
||||||
self.adjust_ray_query_function(fun);
|
self.adjust_ray_query_function(fun);
|
||||||
}
|
}
|
||||||
|
St::SubgroupBallot {
|
||||||
|
ref mut result,
|
||||||
|
ref mut predicate,
|
||||||
|
} => {
|
||||||
|
if let Some(ref mut predicate) = *predicate {
|
||||||
|
adjust(predicate);
|
||||||
|
}
|
||||||
|
adjust(result);
|
||||||
|
}
|
||||||
|
St::SubgroupCollectiveOperation {
|
||||||
|
op: _,
|
||||||
|
collective_op: _,
|
||||||
|
ref mut argument,
|
||||||
|
ref mut result,
|
||||||
|
} => {
|
||||||
|
adjust(argument);
|
||||||
|
adjust(result);
|
||||||
|
}
|
||||||
|
St::SubgroupGather {
|
||||||
|
ref mut mode,
|
||||||
|
ref mut argument,
|
||||||
|
ref mut result,
|
||||||
|
} => {
|
||||||
|
match *mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {}
|
||||||
|
crate::GatherMode::Broadcast(ref mut index)
|
||||||
|
| crate::GatherMode::Shuffle(ref mut index)
|
||||||
|
| crate::GatherMode::ShuffleDown(ref mut index)
|
||||||
|
| crate::GatherMode::ShuffleUp(ref mut index)
|
||||||
|
| crate::GatherMode::ShuffleXor(ref mut index) => adjust(index),
|
||||||
|
}
|
||||||
|
adjust(argument);
|
||||||
|
adjust(result);
|
||||||
|
}
|
||||||
|
|
||||||
// Trivial statements.
|
// Trivial statements.
|
||||||
St::Break
|
St::Break
|
||||||
|
|||||||
@ -153,6 +153,11 @@ pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::B
|
|||||||
Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
|
Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
|
||||||
Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize,
|
Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize,
|
||||||
Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups,
|
Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups,
|
||||||
|
// subgroup
|
||||||
|
Some(Bi::NumSubgroups) => crate::BuiltIn::NumSubgroups,
|
||||||
|
Some(Bi::SubgroupId) => crate::BuiltIn::SubgroupId,
|
||||||
|
Some(Bi::SubgroupSize) => crate::BuiltIn::SubgroupSize,
|
||||||
|
Some(Bi::SubgroupLocalInvocationId) => crate::BuiltIn::SubgroupInvocationId,
|
||||||
_ => return Err(Error::UnsupportedBuiltIn(word)),
|
_ => return Err(Error::UnsupportedBuiltIn(word)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -58,6 +58,8 @@ pub enum Error {
|
|||||||
UnknownBinaryOperator(spirv::Op),
|
UnknownBinaryOperator(spirv::Op),
|
||||||
#[error("unknown relational function {0:?}")]
|
#[error("unknown relational function {0:?}")]
|
||||||
UnknownRelationalFunction(spirv::Op),
|
UnknownRelationalFunction(spirv::Op),
|
||||||
|
#[error("unsupported group operation %{0}")]
|
||||||
|
UnsupportedGroupOperation(spirv::Word),
|
||||||
#[error("invalid parameter {0:?}")]
|
#[error("invalid parameter {0:?}")]
|
||||||
InvalidParameter(spirv::Op),
|
InvalidParameter(spirv::Op),
|
||||||
#[error("invalid operand count {1} for {0:?}")]
|
#[error("invalid operand count {1} for {0:?}")]
|
||||||
|
|||||||
@ -3700,6 +3700,254 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
Op::GroupNonUniformBallot => {
|
||||||
|
inst.expect(5)?;
|
||||||
|
block.extend(emitter.finish(ctx.expressions));
|
||||||
|
let result_type_id = self.next()?;
|
||||||
|
let result_id = self.next()?;
|
||||||
|
let exec_scope_id = self.next()?;
|
||||||
|
let predicate_id = self.next()?;
|
||||||
|
|
||||||
|
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
|
||||||
|
let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
|
||||||
|
.filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
|
||||||
|
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
|
||||||
|
|
||||||
|
let predicate = if self
|
||||||
|
.lookup_constant
|
||||||
|
.lookup(predicate_id)
|
||||||
|
.ok()
|
||||||
|
.filter(|predicate_const| match predicate_const.inner {
|
||||||
|
Constant::Constant(constant) => matches!(
|
||||||
|
ctx.gctx().global_expressions[ctx.gctx().constants[constant].init],
|
||||||
|
crate::Expression::Literal(crate::Literal::Bool(true)),
|
||||||
|
),
|
||||||
|
Constant::Override(_) => false,
|
||||||
|
})
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let predicate_lookup = self.lookup_expression.lookup(predicate_id)?;
|
||||||
|
let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup);
|
||||||
|
Some(predicate_handle)
|
||||||
|
};
|
||||||
|
|
||||||
|
let result_handle = ctx
|
||||||
|
.expressions
|
||||||
|
.append(crate::Expression::SubgroupBallotResult, span);
|
||||||
|
self.lookup_expression.insert(
|
||||||
|
result_id,
|
||||||
|
LookupExpression {
|
||||||
|
handle: result_handle,
|
||||||
|
type_id: result_type_id,
|
||||||
|
block_id,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
block.push(
|
||||||
|
crate::Statement::SubgroupBallot {
|
||||||
|
result: result_handle,
|
||||||
|
predicate,
|
||||||
|
},
|
||||||
|
span,
|
||||||
|
);
|
||||||
|
emitter.start(ctx.expressions);
|
||||||
|
}
|
||||||
|
spirv::Op::GroupNonUniformAll
|
||||||
|
| spirv::Op::GroupNonUniformAny
|
||||||
|
| spirv::Op::GroupNonUniformIAdd
|
||||||
|
| spirv::Op::GroupNonUniformFAdd
|
||||||
|
| spirv::Op::GroupNonUniformIMul
|
||||||
|
| spirv::Op::GroupNonUniformFMul
|
||||||
|
| spirv::Op::GroupNonUniformSMax
|
||||||
|
| spirv::Op::GroupNonUniformUMax
|
||||||
|
| spirv::Op::GroupNonUniformFMax
|
||||||
|
| spirv::Op::GroupNonUniformSMin
|
||||||
|
| spirv::Op::GroupNonUniformUMin
|
||||||
|
| spirv::Op::GroupNonUniformFMin
|
||||||
|
| spirv::Op::GroupNonUniformBitwiseAnd
|
||||||
|
| spirv::Op::GroupNonUniformBitwiseOr
|
||||||
|
| spirv::Op::GroupNonUniformBitwiseXor
|
||||||
|
| spirv::Op::GroupNonUniformLogicalAnd
|
||||||
|
| spirv::Op::GroupNonUniformLogicalOr
|
||||||
|
| spirv::Op::GroupNonUniformLogicalXor => {
|
||||||
|
block.extend(emitter.finish(ctx.expressions));
|
||||||
|
inst.expect(
|
||||||
|
if matches!(
|
||||||
|
inst.op,
|
||||||
|
spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny
|
||||||
|
) {
|
||||||
|
5
|
||||||
|
} else {
|
||||||
|
6
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
let result_type_id = self.next()?;
|
||||||
|
let result_id = self.next()?;
|
||||||
|
let exec_scope_id = self.next()?;
|
||||||
|
let collective_op_id = match inst.op {
|
||||||
|
spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => {
|
||||||
|
crate::CollectiveOperation::Reduce
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let group_op_id = self.next()?;
|
||||||
|
match spirv::GroupOperation::from_u32(group_op_id) {
|
||||||
|
Some(spirv::GroupOperation::Reduce) => {
|
||||||
|
crate::CollectiveOperation::Reduce
|
||||||
|
}
|
||||||
|
Some(spirv::GroupOperation::InclusiveScan) => {
|
||||||
|
crate::CollectiveOperation::InclusiveScan
|
||||||
|
}
|
||||||
|
Some(spirv::GroupOperation::ExclusiveScan) => {
|
||||||
|
crate::CollectiveOperation::ExclusiveScan
|
||||||
|
}
|
||||||
|
_ => return Err(Error::UnsupportedGroupOperation(group_op_id)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let argument_id = self.next()?;
|
||||||
|
|
||||||
|
let argument_lookup = self.lookup_expression.lookup(argument_id)?;
|
||||||
|
let argument_handle = get_expr_handle!(argument_id, argument_lookup);
|
||||||
|
|
||||||
|
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
|
||||||
|
let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
|
||||||
|
.filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
|
||||||
|
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
|
||||||
|
|
||||||
|
let op_id = match inst.op {
|
||||||
|
spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All,
|
||||||
|
spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any,
|
||||||
|
spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => {
|
||||||
|
crate::SubgroupOperation::Add
|
||||||
|
}
|
||||||
|
spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => {
|
||||||
|
crate::SubgroupOperation::Mul
|
||||||
|
}
|
||||||
|
spirv::Op::GroupNonUniformSMax
|
||||||
|
| spirv::Op::GroupNonUniformUMax
|
||||||
|
| spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max,
|
||||||
|
spirv::Op::GroupNonUniformSMin
|
||||||
|
| spirv::Op::GroupNonUniformUMin
|
||||||
|
| spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min,
|
||||||
|
spirv::Op::GroupNonUniformBitwiseAnd
|
||||||
|
| spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And,
|
||||||
|
spirv::Op::GroupNonUniformBitwiseOr
|
||||||
|
| spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or,
|
||||||
|
spirv::Op::GroupNonUniformBitwiseXor
|
||||||
|
| spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let result_type = self.lookup_type.lookup(result_type_id)?;
|
||||||
|
|
||||||
|
let result_handle = ctx.expressions.append(
|
||||||
|
crate::Expression::SubgroupOperationResult {
|
||||||
|
ty: result_type.handle,
|
||||||
|
},
|
||||||
|
span,
|
||||||
|
);
|
||||||
|
self.lookup_expression.insert(
|
||||||
|
result_id,
|
||||||
|
LookupExpression {
|
||||||
|
handle: result_handle,
|
||||||
|
type_id: result_type_id,
|
||||||
|
block_id,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
block.push(
|
||||||
|
crate::Statement::SubgroupCollectiveOperation {
|
||||||
|
result: result_handle,
|
||||||
|
op: op_id,
|
||||||
|
collective_op: collective_op_id,
|
||||||
|
argument: argument_handle,
|
||||||
|
},
|
||||||
|
span,
|
||||||
|
);
|
||||||
|
emitter.start(ctx.expressions);
|
||||||
|
}
|
||||||
|
Op::GroupNonUniformBroadcastFirst
|
||||||
|
| Op::GroupNonUniformBroadcast
|
||||||
|
| Op::GroupNonUniformShuffle
|
||||||
|
| Op::GroupNonUniformShuffleDown
|
||||||
|
| Op::GroupNonUniformShuffleUp
|
||||||
|
| Op::GroupNonUniformShuffleXor => {
|
||||||
|
inst.expect(
|
||||||
|
if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) {
|
||||||
|
5
|
||||||
|
} else {
|
||||||
|
6
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
block.extend(emitter.finish(ctx.expressions));
|
||||||
|
let result_type_id = self.next()?;
|
||||||
|
let result_id = self.next()?;
|
||||||
|
let exec_scope_id = self.next()?;
|
||||||
|
let argument_id = self.next()?;
|
||||||
|
|
||||||
|
let argument_lookup = self.lookup_expression.lookup(argument_id)?;
|
||||||
|
let argument_handle = get_expr_handle!(argument_id, argument_lookup);
|
||||||
|
|
||||||
|
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
|
||||||
|
let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
|
||||||
|
.filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
|
||||||
|
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
|
||||||
|
|
||||||
|
let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) {
|
||||||
|
crate::GatherMode::BroadcastFirst
|
||||||
|
} else {
|
||||||
|
let index_id = self.next()?;
|
||||||
|
let index_lookup = self.lookup_expression.lookup(index_id)?;
|
||||||
|
let index_handle = get_expr_handle!(index_id, index_lookup);
|
||||||
|
match inst.op {
|
||||||
|
spirv::Op::GroupNonUniformBroadcast => {
|
||||||
|
crate::GatherMode::Broadcast(index_handle)
|
||||||
|
}
|
||||||
|
spirv::Op::GroupNonUniformShuffle => {
|
||||||
|
crate::GatherMode::Shuffle(index_handle)
|
||||||
|
}
|
||||||
|
spirv::Op::GroupNonUniformShuffleDown => {
|
||||||
|
crate::GatherMode::ShuffleDown(index_handle)
|
||||||
|
}
|
||||||
|
spirv::Op::GroupNonUniformShuffleUp => {
|
||||||
|
crate::GatherMode::ShuffleUp(index_handle)
|
||||||
|
}
|
||||||
|
spirv::Op::GroupNonUniformShuffleXor => {
|
||||||
|
crate::GatherMode::ShuffleXor(index_handle)
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let result_type = self.lookup_type.lookup(result_type_id)?;
|
||||||
|
|
||||||
|
let result_handle = ctx.expressions.append(
|
||||||
|
crate::Expression::SubgroupOperationResult {
|
||||||
|
ty: result_type.handle,
|
||||||
|
},
|
||||||
|
span,
|
||||||
|
);
|
||||||
|
self.lookup_expression.insert(
|
||||||
|
result_id,
|
||||||
|
LookupExpression {
|
||||||
|
handle: result_handle,
|
||||||
|
type_id: result_type_id,
|
||||||
|
block_id,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
block.push(
|
||||||
|
crate::Statement::SubgroupGather {
|
||||||
|
result: result_handle,
|
||||||
|
mode,
|
||||||
|
argument: argument_handle,
|
||||||
|
},
|
||||||
|
span,
|
||||||
|
);
|
||||||
|
emitter.start(ctx.expressions);
|
||||||
|
}
|
||||||
_ => return Err(Error::UnsupportedInstruction(self.state, inst.op)),
|
_ => return Err(Error::UnsupportedInstruction(self.state, inst.op)),
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -3824,7 +4072,10 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
|
|||||||
| S::Store { .. }
|
| S::Store { .. }
|
||||||
| S::ImageStore { .. }
|
| S::ImageStore { .. }
|
||||||
| S::Atomic { .. }
|
| S::Atomic { .. }
|
||||||
| S::RayQuery { .. } => {}
|
| S::RayQuery { .. }
|
||||||
|
| S::SubgroupBallot { .. }
|
||||||
|
| S::SubgroupCollectiveOperation { .. }
|
||||||
|
| S::SubgroupGather { .. } => {}
|
||||||
S::Call {
|
S::Call {
|
||||||
function: ref mut callee,
|
function: ref mut callee,
|
||||||
ref arguments,
|
ref arguments,
|
||||||
|
|||||||
@ -874,6 +874,29 @@ impl Texture {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The subgroup gather functions that [`SubgroupGather::map`] recognizes by name.
enum SubgroupGather {
    BroadcastFirst,
    Broadcast,
    Shuffle,
    ShuffleDown,
    ShuffleUp,
    ShuffleXor,
}

impl SubgroupGather {
    /// Looks up the gather kind whose function name is exactly `word`.
    ///
    /// The comparison is an exact, case-sensitive string match. Returns `None`
    /// when `word` is not one of the six recognized names.
    pub fn map(word: &str) -> Option<Self> {
        match word {
            "subgroupBroadcastFirst" => Some(Self::BroadcastFirst),
            "subgroupBroadcast" => Some(Self::Broadcast),
            "subgroupShuffle" => Some(Self::Shuffle),
            "subgroupShuffleDown" => Some(Self::ShuffleDown),
            "subgroupShuffleUp" => Some(Self::ShuffleUp),
            "subgroupShuffleXor" => Some(Self::ShuffleXor),
            _ => None,
        }
    }
}
|
||||||
|
|
||||||
pub struct Lowerer<'source, 'temp> {
|
pub struct Lowerer<'source, 'temp> {
|
||||||
index: &'temp Index<'source>,
|
index: &'temp Index<'source>,
|
||||||
layouter: Layouter,
|
layouter: Layouter,
|
||||||
@ -2054,6 +2077,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
|||||||
}
|
}
|
||||||
} else if let Some(fun) = Texture::map(function.name) {
|
} else if let Some(fun) = Texture::map(function.name) {
|
||||||
self.texture_sample_helper(fun, arguments, span, ctx)?
|
self.texture_sample_helper(fun, arguments, span, ctx)?
|
||||||
|
} else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) {
|
||||||
|
return Ok(Some(
|
||||||
|
self.subgroup_operation_helper(span, op, cop, arguments, ctx)?,
|
||||||
|
));
|
||||||
|
} else if let Some(mode) = SubgroupGather::map(function.name) {
|
||||||
|
return Ok(Some(
|
||||||
|
self.subgroup_gather_helper(span, mode, arguments, ctx)?,
|
||||||
|
));
|
||||||
} else {
|
} else {
|
||||||
match function.name {
|
match function.name {
|
||||||
"select" => {
|
"select" => {
|
||||||
@ -2221,6 +2252,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
|||||||
.push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
|
.push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
"subgroupBarrier" => {
|
||||||
|
ctx.prepare_args(arguments, 0, span).finish()?;
|
||||||
|
|
||||||
|
let rctx = ctx.runtime_expression_ctx(span)?;
|
||||||
|
rctx.block
|
||||||
|
.push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span);
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
"workgroupUniformLoad" => {
|
"workgroupUniformLoad" => {
|
||||||
let mut args = ctx.prepare_args(arguments, 1, span);
|
let mut args = ctx.prepare_args(arguments, 1, span);
|
||||||
let expr = args.next()?;
|
let expr = args.next()?;
|
||||||
@ -2428,6 +2467,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
|||||||
)?;
|
)?;
|
||||||
return Ok(Some(handle));
|
return Ok(Some(handle));
|
||||||
}
|
}
|
||||||
|
"subgroupBallot" => {
|
||||||
|
let mut args = ctx.prepare_args(arguments, 0, span);
|
||||||
|
let predicate = if arguments.len() == 1 {
|
||||||
|
Some(self.expression(args.next()?, ctx)?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
args.finish()?;
|
||||||
|
|
||||||
|
let result = ctx
|
||||||
|
.interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?;
|
||||||
|
let rctx = ctx.runtime_expression_ctx(span)?;
|
||||||
|
rctx.block
|
||||||
|
.push(crate::Statement::SubgroupBallot { result, predicate }, span);
|
||||||
|
return Ok(Some(result));
|
||||||
|
}
|
||||||
_ => return Err(Error::UnknownIdent(function.span, function.name)),
|
_ => return Err(Error::UnknownIdent(function.span, function.name)),
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -2619,6 +2674,80 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn subgroup_operation_helper(
|
||||||
|
&mut self,
|
||||||
|
span: Span,
|
||||||
|
op: crate::SubgroupOperation,
|
||||||
|
collective_op: crate::CollectiveOperation,
|
||||||
|
arguments: &[Handle<ast::Expression<'source>>],
|
||||||
|
ctx: &mut ExpressionContext<'source, '_, '_>,
|
||||||
|
) -> Result<Handle<crate::Expression>, Error<'source>> {
|
||||||
|
let mut args = ctx.prepare_args(arguments, 1, span);
|
||||||
|
|
||||||
|
let argument = self.expression(args.next()?, ctx)?;
|
||||||
|
args.finish()?;
|
||||||
|
|
||||||
|
let ty = ctx.register_type(argument)?;
|
||||||
|
|
||||||
|
let result =
|
||||||
|
ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
|
||||||
|
let rctx = ctx.runtime_expression_ctx(span)?;
|
||||||
|
rctx.block.push(
|
||||||
|
crate::Statement::SubgroupCollectiveOperation {
|
||||||
|
op,
|
||||||
|
collective_op,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
},
|
||||||
|
span,
|
||||||
|
);
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn subgroup_gather_helper(
|
||||||
|
&mut self,
|
||||||
|
span: Span,
|
||||||
|
mode: SubgroupGather,
|
||||||
|
arguments: &[Handle<ast::Expression<'source>>],
|
||||||
|
ctx: &mut ExpressionContext<'source, '_, '_>,
|
||||||
|
) -> Result<Handle<crate::Expression>, Error<'source>> {
|
||||||
|
let mut args = ctx.prepare_args(arguments, 2, span);
|
||||||
|
|
||||||
|
let argument = self.expression(args.next()?, ctx)?;
|
||||||
|
|
||||||
|
use SubgroupGather as Sg;
|
||||||
|
let mode = if let Sg::BroadcastFirst = mode {
|
||||||
|
crate::GatherMode::BroadcastFirst
|
||||||
|
} else {
|
||||||
|
let index = self.expression(args.next()?, ctx)?;
|
||||||
|
match mode {
|
||||||
|
Sg::Broadcast => crate::GatherMode::Broadcast(index),
|
||||||
|
Sg::Shuffle => crate::GatherMode::Shuffle(index),
|
||||||
|
Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index),
|
||||||
|
Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index),
|
||||||
|
Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index),
|
||||||
|
Sg::BroadcastFirst => unreachable!(),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
args.finish()?;
|
||||||
|
|
||||||
|
let ty = ctx.register_type(argument)?;
|
||||||
|
|
||||||
|
let result =
|
||||||
|
ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
|
||||||
|
let rctx = ctx.runtime_expression_ctx(span)?;
|
||||||
|
rctx.block.push(
|
||||||
|
crate::Statement::SubgroupGather {
|
||||||
|
mode,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
},
|
||||||
|
span,
|
||||||
|
);
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
fn r#struct(
|
fn r#struct(
|
||||||
&mut self,
|
&mut self,
|
||||||
s: &ast::Struct<'source>,
|
s: &ast::Struct<'source>,
|
||||||
|
|||||||
@ -35,6 +35,11 @@ pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>>
|
|||||||
"local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
|
"local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
|
||||||
"workgroup_id" => crate::BuiltIn::WorkGroupId,
|
"workgroup_id" => crate::BuiltIn::WorkGroupId,
|
||||||
"num_workgroups" => crate::BuiltIn::NumWorkGroups,
|
"num_workgroups" => crate::BuiltIn::NumWorkGroups,
|
||||||
|
// subgroup
|
||||||
|
"num_subgroups" => crate::BuiltIn::NumSubgroups,
|
||||||
|
"subgroup_id" => crate::BuiltIn::SubgroupId,
|
||||||
|
"subgroup_size" => crate::BuiltIn::SubgroupSize,
|
||||||
|
"subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId,
|
||||||
_ => return Err(Error::UnknownBuiltin(span)),
|
_ => return Err(Error::UnknownBuiltin(span)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -260,3 +265,26 @@ pub fn map_conservative_depth(
|
|||||||
_ => Err(Error::UnknownConservativeDepth(span)),
|
_ => Err(Error::UnknownConservativeDepth(span)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn map_subgroup_operation(
|
||||||
|
word: &str,
|
||||||
|
) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> {
|
||||||
|
use crate::CollectiveOperation as co;
|
||||||
|
use crate::SubgroupOperation as sg;
|
||||||
|
Some(match word {
|
||||||
|
"subgroupAll" => (sg::All, co::Reduce),
|
||||||
|
"subgroupAny" => (sg::Any, co::Reduce),
|
||||||
|
"subgroupAdd" => (sg::Add, co::Reduce),
|
||||||
|
"subgroupMul" => (sg::Mul, co::Reduce),
|
||||||
|
"subgroupMin" => (sg::Min, co::Reduce),
|
||||||
|
"subgroupMax" => (sg::Max, co::Reduce),
|
||||||
|
"subgroupAnd" => (sg::And, co::Reduce),
|
||||||
|
"subgroupOr" => (sg::Or, co::Reduce),
|
||||||
|
"subgroupXor" => (sg::Xor, co::Reduce),
|
||||||
|
"subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan),
|
||||||
|
"subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan),
|
||||||
|
"subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan),
|
||||||
|
"subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan),
|
||||||
|
_ => return None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@ -431,6 +431,11 @@ pub enum BuiltIn {
|
|||||||
WorkGroupId,
|
WorkGroupId,
|
||||||
WorkGroupSize,
|
WorkGroupSize,
|
||||||
NumWorkGroups,
|
NumWorkGroups,
|
||||||
|
// subgroup
|
||||||
|
NumSubgroups,
|
||||||
|
SubgroupId,
|
||||||
|
SubgroupSize,
|
||||||
|
SubgroupInvocationId,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Number of bytes per scalar.
|
/// Number of bytes per scalar.
|
||||||
@ -1277,6 +1282,51 @@ pub enum SwizzleComponent {
|
|||||||
W = 3,
|
W = 3,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||||
|
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||||
|
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||||
|
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
|
||||||
|
pub enum GatherMode {
|
||||||
|
/// All gather from the active lane with the smallest index
|
||||||
|
BroadcastFirst,
|
||||||
|
/// All gather from the same lane at the index given by the expression
|
||||||
|
Broadcast(Handle<Expression>),
|
||||||
|
/// Each gathers from a different lane at the index given by the expression
|
||||||
|
Shuffle(Handle<Expression>),
|
||||||
|
/// Each gathers from their lane plus the shift given by the expression
|
||||||
|
ShuffleDown(Handle<Expression>),
|
||||||
|
/// Each gathers from their lane minus the shift given by the expression
|
||||||
|
ShuffleUp(Handle<Expression>),
|
||||||
|
/// Each gathers from their lane xored with the given by the expression
|
||||||
|
ShuffleXor(Handle<Expression>),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The operator applied across a subgroup by a subgroup collective operation.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum SubgroupOperation {
    /// Vote operator behind `subgroupAll`.
    All = 0,
    /// Vote operator behind `subgroupAny`.
    Any = 1,
    /// Addition (`subgroupAdd` and its scan forms).
    Add = 2,
    /// Multiplication (`subgroupMul` and its scan forms).
    Mul = 3,
    /// Minimum (`subgroupMin`).
    Min = 4,
    /// Maximum (`subgroupMax`).
    Max = 5,
    /// Bitwise AND (`subgroupAnd`).
    And = 6,
    /// Bitwise OR (`subgroupOr`).
    Or = 7,
    /// Bitwise XOR (`subgroupXor`).
    Xor = 8,
}
|
||||||
|
|
||||||
|
/// How a [`SubgroupOperation`] is combined across the invocations of a
/// subgroup.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum CollectiveOperation {
    /// Reduction, as selected by the plain `subgroup*` names.
    Reduce = 0,
    /// Inclusive scan, as selected by the `subgroupInclusive*` names.
    InclusiveScan = 1,
    /// Exclusive scan, as selected by the `subgroupExclusive*` names.
    ExclusiveScan = 2,
}
|
||||||
|
|
||||||
bitflags::bitflags! {
|
bitflags::bitflags! {
|
||||||
/// Memory barrier flags.
|
/// Memory barrier flags.
|
||||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||||
@ -1285,9 +1335,11 @@ bitflags::bitflags! {
|
|||||||
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
|
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
|
||||||
pub struct Barrier: u32 {
|
pub struct Barrier: u32 {
|
||||||
/// Barrier affects all `AddressSpace::Storage` accesses.
|
/// Barrier affects all `AddressSpace::Storage` accesses.
|
||||||
const STORAGE = 0x1;
|
const STORAGE = 1 << 0;
|
||||||
/// Barrier affects all `AddressSpace::WorkGroup` accesses.
|
/// Barrier affects all `AddressSpace::WorkGroup` accesses.
|
||||||
const WORK_GROUP = 0x2;
|
const WORK_GROUP = 1 << 1;
|
||||||
|
/// Barrier synchronizes execution across all invocations within a subgroup that execute this instruction.
|
||||||
|
const SUB_GROUP = 1 << 2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1588,6 +1640,15 @@ pub enum Expression {
|
|||||||
query: Handle<Expression>,
|
query: Handle<Expression>,
|
||||||
committed: bool,
|
committed: bool,
|
||||||
},
|
},
|
||||||
|
/// Result of a [`SubgroupBallot`] statement.
|
||||||
|
///
|
||||||
|
/// [`SubgroupBallot`]: Statement::SubgroupBallot
|
||||||
|
SubgroupBallotResult,
|
||||||
|
/// Result of a [`SubgroupCollectiveOperation`] or [`SubgroupGather`] statement.
|
||||||
|
///
|
||||||
|
/// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation
|
||||||
|
/// [`SubgroupGather`]: Statement::SubgroupGather
|
||||||
|
SubgroupOperationResult { ty: Handle<Type> },
|
||||||
}
|
}
|
||||||
|
|
||||||
pub use block::Block;
|
pub use block::Block;
|
||||||
@ -1872,6 +1933,39 @@ pub enum Statement {
|
|||||||
/// The specific operation we're performing on `query`.
|
/// The specific operation we're performing on `query`.
|
||||||
fun: RayQueryFunction,
|
fun: RayQueryFunction,
|
||||||
},
|
},
|
||||||
|
/// Calculate a bitmask using a boolean from each active thread in the subgroup
|
||||||
|
SubgroupBallot {
|
||||||
|
/// The [`SubgroupBallotResult`] expression representing this ballot's result.
|
||||||
|
///
|
||||||
|
/// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult
|
||||||
|
result: Handle<Expression>,
|
||||||
|
/// The value from this thread to store in the ballot
|
||||||
|
predicate: Option<Handle<Expression>>,
|
||||||
|
},
|
||||||
|
/// Gather a value from another active thread in the subgroup
|
||||||
|
SubgroupGather {
|
||||||
|
/// Specifies which thread to gather from
|
||||||
|
mode: GatherMode,
|
||||||
|
/// The value to broadcast over
|
||||||
|
argument: Handle<Expression>,
|
||||||
|
/// The [`SubgroupOperationResult`] expression representing this gather's result.
|
||||||
|
///
|
||||||
|
/// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult
|
||||||
|
result: Handle<Expression>,
|
||||||
|
},
|
||||||
|
/// Compute a collective operation across all active threads in the subgroup
|
||||||
|
SubgroupCollectiveOperation {
|
||||||
|
/// What operation to compute
|
||||||
|
op: SubgroupOperation,
|
||||||
|
/// How to combine the results
|
||||||
|
collective_op: CollectiveOperation,
|
||||||
|
/// The value to compute over
|
||||||
|
argument: Handle<Expression>,
|
||||||
|
/// The [`SubgroupOperationResult`] expression representing this operation's result.
|
||||||
|
///
|
||||||
|
/// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult
|
||||||
|
result: Handle<Expression>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A function argument.
|
/// A function argument.
|
||||||
|
|||||||
@ -476,6 +476,8 @@ pub enum ConstantEvaluatorError {
|
|||||||
ImageExpression,
|
ImageExpression,
|
||||||
#[error("Constants don't support ray query expressions")]
|
#[error("Constants don't support ray query expressions")]
|
||||||
RayQueryExpression,
|
RayQueryExpression,
|
||||||
|
#[error("Constants don't support subgroup expressions")]
|
||||||
|
SubgroupExpression,
|
||||||
#[error("Cannot access the type")]
|
#[error("Cannot access the type")]
|
||||||
InvalidAccessBase,
|
InvalidAccessBase,
|
||||||
#[error("Cannot access at the index")]
|
#[error("Cannot access at the index")]
|
||||||
@ -884,6 +886,12 @@ impl<'a> ConstantEvaluator<'a> {
|
|||||||
Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
|
Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
|
||||||
Err(ConstantEvaluatorError::RayQueryExpression)
|
Err(ConstantEvaluatorError::RayQueryExpression)
|
||||||
}
|
}
|
||||||
|
Expression::SubgroupBallotResult { .. } => {
|
||||||
|
Err(ConstantEvaluatorError::SubgroupExpression)
|
||||||
|
}
|
||||||
|
Expression::SubgroupOperationResult { .. } => {
|
||||||
|
Err(ConstantEvaluatorError::SubgroupExpression)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -37,6 +37,9 @@ pub fn ensure_block_returns(block: &mut crate::Block) {
|
|||||||
| S::RayQuery { .. }
|
| S::RayQuery { .. }
|
||||||
| S::Atomic { .. }
|
| S::Atomic { .. }
|
||||||
| S::WorkGroupUniformLoad { .. }
|
| S::WorkGroupUniformLoad { .. }
|
||||||
|
| S::SubgroupBallot { .. }
|
||||||
|
| S::SubgroupCollectiveOperation { .. }
|
||||||
|
| S::SubgroupGather { .. }
|
||||||
| S::Barrier(_)),
|
| S::Barrier(_)),
|
||||||
)
|
)
|
||||||
| None => block.push(S::Return { value: None }, Default::default()),
|
| None => block.push(S::Return { value: None }, Default::default()),
|
||||||
|
|||||||
@ -598,6 +598,7 @@ impl<'a> ResolveContext<'a> {
|
|||||||
| crate::BinaryOperator::ShiftRight => past(left)?.clone(),
|
| crate::BinaryOperator::ShiftRight => past(left)?.clone(),
|
||||||
},
|
},
|
||||||
crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
|
crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
|
||||||
|
crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty),
|
||||||
crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
|
crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
|
||||||
crate::Expression::Select { accept, .. } => past(accept)?.clone(),
|
crate::Expression::Select { accept, .. } => past(accept)?.clone(),
|
||||||
crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
|
crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
|
||||||
@ -885,6 +886,10 @@ impl<'a> ResolveContext<'a> {
|
|||||||
.ok_or(ResolveError::MissingSpecialType)?;
|
.ok_or(ResolveError::MissingSpecialType)?;
|
||||||
TypeResolution::Handle(result)
|
TypeResolution::Handle(result)
|
||||||
}
|
}
|
||||||
|
crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector {
|
||||||
|
scalar: crate::Scalar::U32,
|
||||||
|
size: crate::VectorSize::Quad,
|
||||||
|
}),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -787,6 +787,14 @@ impl FunctionInfo {
|
|||||||
non_uniform_result: self.add_ref(query),
|
non_uniform_result: self.add_ref(query),
|
||||||
requirements: UniformityRequirements::empty(),
|
requirements: UniformityRequirements::empty(),
|
||||||
},
|
},
|
||||||
|
E::SubgroupBallotResult => Uniformity {
|
||||||
|
non_uniform_result: Some(handle),
|
||||||
|
requirements: UniformityRequirements::empty(),
|
||||||
|
},
|
||||||
|
E::SubgroupOperationResult { .. } => Uniformity {
|
||||||
|
non_uniform_result: Some(handle),
|
||||||
|
requirements: UniformityRequirements::empty(),
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
|
let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
|
||||||
@ -1029,6 +1037,42 @@ impl FunctionInfo {
|
|||||||
}
|
}
|
||||||
FunctionUniformity::new()
|
FunctionUniformity::new()
|
||||||
}
|
}
|
||||||
|
S::SubgroupBallot {
|
||||||
|
result: _,
|
||||||
|
predicate,
|
||||||
|
} => {
|
||||||
|
if let Some(predicate) = predicate {
|
||||||
|
let _ = self.add_ref(predicate);
|
||||||
|
}
|
||||||
|
FunctionUniformity::new()
|
||||||
|
}
|
||||||
|
S::SubgroupCollectiveOperation {
|
||||||
|
op: _,
|
||||||
|
collective_op: _,
|
||||||
|
argument,
|
||||||
|
result: _,
|
||||||
|
} => {
|
||||||
|
let _ = self.add_ref(argument);
|
||||||
|
FunctionUniformity::new()
|
||||||
|
}
|
||||||
|
S::SubgroupGather {
|
||||||
|
mode,
|
||||||
|
argument,
|
||||||
|
result: _,
|
||||||
|
} => {
|
||||||
|
let _ = self.add_ref(argument);
|
||||||
|
match mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {}
|
||||||
|
crate::GatherMode::Broadcast(index)
|
||||||
|
| crate::GatherMode::Shuffle(index)
|
||||||
|
| crate::GatherMode::ShuffleDown(index)
|
||||||
|
| crate::GatherMode::ShuffleUp(index)
|
||||||
|
| crate::GatherMode::ShuffleXor(index) => {
|
||||||
|
let _ = self.add_ref(index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
FunctionUniformity::new()
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
disruptor = disruptor.or(uniformity.exit_disruptor());
|
disruptor = disruptor.or(uniformity.exit_disruptor());
|
||||||
|
|||||||
@ -1641,6 +1641,7 @@ impl super::Validator {
|
|||||||
return Err(ExpressionError::InvalidRayQueryType(query));
|
return Err(ExpressionError::InvalidRayQueryType(query));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
|
||||||
};
|
};
|
||||||
Ok(stages)
|
Ok(stages)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -47,6 +47,19 @@ pub enum AtomicError {
|
|||||||
ResultTypeMismatch(Handle<crate::Expression>),
|
ResultTypeMismatch(Handle<crate::Expression>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, thiserror::Error)]
|
||||||
|
#[cfg_attr(test, derive(PartialEq))]
|
||||||
|
pub enum SubgroupError {
|
||||||
|
#[error("Operand {0:?} has invalid type.")]
|
||||||
|
InvalidOperand(Handle<crate::Expression>),
|
||||||
|
#[error("Result type for {0:?} doesn't match the statement")]
|
||||||
|
ResultTypeMismatch(Handle<crate::Expression>),
|
||||||
|
#[error("Support for subgroup operation {0:?} is required")]
|
||||||
|
UnsupportedOperation(super::SubgroupOperationSet),
|
||||||
|
#[error("Unknown operation")]
|
||||||
|
UnknownOperation,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, thiserror::Error)]
|
#[derive(Clone, Debug, thiserror::Error)]
|
||||||
#[cfg_attr(test, derive(PartialEq))]
|
#[cfg_attr(test, derive(PartialEq))]
|
||||||
pub enum LocalVariableError {
|
pub enum LocalVariableError {
|
||||||
@ -135,6 +148,8 @@ pub enum FunctionError {
|
|||||||
InvalidRayDescriptor(Handle<crate::Expression>),
|
InvalidRayDescriptor(Handle<crate::Expression>),
|
||||||
#[error("Ray Query {0:?} does not have a matching type")]
|
#[error("Ray Query {0:?} does not have a matching type")]
|
||||||
InvalidRayQueryType(Handle<crate::Type>),
|
InvalidRayQueryType(Handle<crate::Type>),
|
||||||
|
#[error("Shader requires capability {0:?}")]
|
||||||
|
MissingCapability(super::Capabilities),
|
||||||
#[error(
|
#[error(
|
||||||
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
|
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
|
||||||
)]
|
)]
|
||||||
@ -155,6 +170,8 @@ pub enum FunctionError {
|
|||||||
WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
|
WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
|
||||||
#[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
|
#[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
|
||||||
WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
|
WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
|
||||||
|
#[error("Subgroup operation is invalid")]
|
||||||
|
InvalidSubgroup(#[from] SubgroupError),
|
||||||
}
|
}
|
||||||
|
|
||||||
bitflags::bitflags! {
|
bitflags::bitflags! {
|
||||||
@ -399,6 +416,127 @@ impl super::Validator {
|
|||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
fn validate_subgroup_operation(
|
||||||
|
&mut self,
|
||||||
|
op: &crate::SubgroupOperation,
|
||||||
|
collective_op: &crate::CollectiveOperation,
|
||||||
|
argument: Handle<crate::Expression>,
|
||||||
|
result: Handle<crate::Expression>,
|
||||||
|
context: &BlockContext,
|
||||||
|
) -> Result<(), WithSpan<FunctionError>> {
|
||||||
|
let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
|
||||||
|
|
||||||
|
let (is_scalar, scalar) = match *argument_inner {
|
||||||
|
crate::TypeInner::Scalar(scalar) => (true, scalar),
|
||||||
|
crate::TypeInner::Vector { scalar, .. } => (false, scalar),
|
||||||
|
_ => {
|
||||||
|
log::error!("Subgroup operand type {:?}", argument_inner);
|
||||||
|
return Err(SubgroupError::InvalidOperand(argument)
|
||||||
|
.with_span_handle(argument, context.expressions)
|
||||||
|
.into_other());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::ScalarKind as sk;
|
||||||
|
use crate::SubgroupOperation as sg;
|
||||||
|
match (scalar.kind, *op) {
|
||||||
|
(sk::Bool, sg::All | sg::Any) if is_scalar => {}
|
||||||
|
(sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
|
||||||
|
(sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
|
||||||
|
|
||||||
|
(_, _) => {
|
||||||
|
log::error!("Subgroup operand type {:?}", argument_inner);
|
||||||
|
return Err(SubgroupError::InvalidOperand(argument)
|
||||||
|
.with_span_handle(argument, context.expressions)
|
||||||
|
.into_other());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::CollectiveOperation as co;
|
||||||
|
match (*collective_op, *op) {
|
||||||
|
(
|
||||||
|
co::Reduce,
|
||||||
|
sg::All
|
||||||
|
| sg::Any
|
||||||
|
| sg::Add
|
||||||
|
| sg::Mul
|
||||||
|
| sg::Min
|
||||||
|
| sg::Max
|
||||||
|
| sg::And
|
||||||
|
| sg::Or
|
||||||
|
| sg::Xor,
|
||||||
|
) => {}
|
||||||
|
(co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {}
|
||||||
|
|
||||||
|
(_, _) => {
|
||||||
|
return Err(SubgroupError::UnknownOperation.with_span().into_other());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
self.emit_expression(result, context)?;
|
||||||
|
match context.expressions[result] {
|
||||||
|
crate::Expression::SubgroupOperationResult { ty }
|
||||||
|
if { &context.types[ty].inner == argument_inner } => {}
|
||||||
|
_ => {
|
||||||
|
return Err(SubgroupError::ResultTypeMismatch(result)
|
||||||
|
.with_span_handle(result, context.expressions)
|
||||||
|
.into_other())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
fn validate_subgroup_gather(
|
||||||
|
&mut self,
|
||||||
|
mode: &crate::GatherMode,
|
||||||
|
argument: Handle<crate::Expression>,
|
||||||
|
result: Handle<crate::Expression>,
|
||||||
|
context: &BlockContext,
|
||||||
|
) -> Result<(), WithSpan<FunctionError>> {
|
||||||
|
match *mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {}
|
||||||
|
crate::GatherMode::Broadcast(index)
|
||||||
|
| crate::GatherMode::Shuffle(index)
|
||||||
|
| crate::GatherMode::ShuffleDown(index)
|
||||||
|
| crate::GatherMode::ShuffleUp(index)
|
||||||
|
| crate::GatherMode::ShuffleXor(index) => {
|
||||||
|
let index_ty = context.resolve_type(index, &self.valid_expression_set)?;
|
||||||
|
match *index_ty {
|
||||||
|
crate::TypeInner::Scalar(crate::Scalar::U32) => {}
|
||||||
|
_ => {
|
||||||
|
log::error!(
|
||||||
|
"Subgroup gather index type {:?}, expected unsigned int",
|
||||||
|
index_ty
|
||||||
|
);
|
||||||
|
return Err(SubgroupError::InvalidOperand(argument)
|
||||||
|
.with_span_handle(index, context.expressions)
|
||||||
|
.into_other());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
|
||||||
|
if !matches!(*argument_inner,
|
||||||
|
crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
|
||||||
|
if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
|
||||||
|
) {
|
||||||
|
log::error!("Subgroup gather operand type {:?}", argument_inner);
|
||||||
|
return Err(SubgroupError::InvalidOperand(argument)
|
||||||
|
.with_span_handle(argument, context.expressions)
|
||||||
|
.into_other());
|
||||||
|
}
|
||||||
|
|
||||||
|
self.emit_expression(result, context)?;
|
||||||
|
match context.expressions[result] {
|
||||||
|
crate::Expression::SubgroupOperationResult { ty }
|
||||||
|
if { &context.types[ty].inner == argument_inner } => {}
|
||||||
|
_ => {
|
||||||
|
return Err(SubgroupError::ResultTypeMismatch(result)
|
||||||
|
.with_span_handle(result, context.expressions)
|
||||||
|
.into_other())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn validate_block_impl(
|
fn validate_block_impl(
|
||||||
&mut self,
|
&mut self,
|
||||||
@ -613,8 +751,30 @@ impl super::Validator {
|
|||||||
stages &= super::ShaderStages::FRAGMENT;
|
stages &= super::ShaderStages::FRAGMENT;
|
||||||
finished = true;
|
finished = true;
|
||||||
}
|
}
|
||||||
S::Barrier(_) => {
|
S::Barrier(barrier) => {
|
||||||
stages &= super::ShaderStages::COMPUTE;
|
stages &= super::ShaderStages::COMPUTE;
|
||||||
|
if barrier.contains(crate::Barrier::SUB_GROUP) {
|
||||||
|
if !self.capabilities.contains(
|
||||||
|
super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER,
|
||||||
|
) {
|
||||||
|
return Err(FunctionError::MissingCapability(
|
||||||
|
super::Capabilities::SUBGROUP
|
||||||
|
| super::Capabilities::SUBGROUP_BARRIER,
|
||||||
|
)
|
||||||
|
.with_span_static(span, "missing capability for this operation"));
|
||||||
|
}
|
||||||
|
if !self
|
||||||
|
.subgroup_operations
|
||||||
|
.contains(super::SubgroupOperationSet::BASIC)
|
||||||
|
{
|
||||||
|
return Err(FunctionError::InvalidSubgroup(
|
||||||
|
SubgroupError::UnsupportedOperation(
|
||||||
|
super::SubgroupOperationSet::BASIC,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.with_span_static(span, "support for this operation is not present"));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
S::Store { pointer, value } => {
|
S::Store { pointer, value } => {
|
||||||
let mut current = pointer;
|
let mut current = pointer;
|
||||||
@ -904,6 +1064,86 @@ impl super::Validator {
|
|||||||
crate::RayQueryFunction::Terminate => {}
|
crate::RayQueryFunction::Terminate => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
S::SubgroupBallot { result, predicate } => {
|
||||||
|
stages &= self.subgroup_stages;
|
||||||
|
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
|
||||||
|
return Err(FunctionError::MissingCapability(
|
||||||
|
super::Capabilities::SUBGROUP,
|
||||||
|
)
|
||||||
|
.with_span_static(span, "missing capability for this operation"));
|
||||||
|
}
|
||||||
|
if !self
|
||||||
|
.subgroup_operations
|
||||||
|
.contains(super::SubgroupOperationSet::BALLOT)
|
||||||
|
{
|
||||||
|
return Err(FunctionError::InvalidSubgroup(
|
||||||
|
SubgroupError::UnsupportedOperation(
|
||||||
|
super::SubgroupOperationSet::BALLOT,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.with_span_static(span, "support for this operation is not present"));
|
||||||
|
}
|
||||||
|
if let Some(predicate) = predicate {
|
||||||
|
let predicate_inner =
|
||||||
|
context.resolve_type(predicate, &self.valid_expression_set)?;
|
||||||
|
if !matches!(
|
||||||
|
*predicate_inner,
|
||||||
|
crate::TypeInner::Scalar(crate::Scalar::BOOL,)
|
||||||
|
) {
|
||||||
|
log::error!(
|
||||||
|
"Subgroup ballot predicate type {:?} expected bool",
|
||||||
|
predicate_inner
|
||||||
|
);
|
||||||
|
return Err(SubgroupError::InvalidOperand(predicate)
|
||||||
|
.with_span_handle(predicate, context.expressions)
|
||||||
|
.into_other());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.emit_expression(result, context)?;
|
||||||
|
}
|
||||||
|
S::SubgroupCollectiveOperation {
|
||||||
|
ref op,
|
||||||
|
ref collective_op,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
stages &= self.subgroup_stages;
|
||||||
|
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
|
||||||
|
return Err(FunctionError::MissingCapability(
|
||||||
|
super::Capabilities::SUBGROUP,
|
||||||
|
)
|
||||||
|
.with_span_static(span, "missing capability for this operation"));
|
||||||
|
}
|
||||||
|
let operation = op.required_operations();
|
||||||
|
if !self.subgroup_operations.contains(operation) {
|
||||||
|
return Err(FunctionError::InvalidSubgroup(
|
||||||
|
SubgroupError::UnsupportedOperation(operation),
|
||||||
|
)
|
||||||
|
.with_span_static(span, "support for this operation is not present"));
|
||||||
|
}
|
||||||
|
self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
|
||||||
|
}
|
||||||
|
S::SubgroupGather {
|
||||||
|
ref mode,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
stages &= self.subgroup_stages;
|
||||||
|
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
|
||||||
|
return Err(FunctionError::MissingCapability(
|
||||||
|
super::Capabilities::SUBGROUP,
|
||||||
|
)
|
||||||
|
.with_span_static(span, "missing capability for this operation"));
|
||||||
|
}
|
||||||
|
let operation = mode.required_operations();
|
||||||
|
if !self.subgroup_operations.contains(operation) {
|
||||||
|
return Err(FunctionError::InvalidSubgroup(
|
||||||
|
SubgroupError::UnsupportedOperation(operation),
|
||||||
|
)
|
||||||
|
.with_span_static(span, "support for this operation is not present"));
|
||||||
|
}
|
||||||
|
self.validate_subgroup_gather(mode, argument, result, context)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(BlockInfo { stages, finished })
|
Ok(BlockInfo { stages, finished })
|
||||||
|
|||||||
@ -420,6 +420,8 @@ impl super::Validator {
|
|||||||
}
|
}
|
||||||
crate::Expression::AtomicResult { .. }
|
crate::Expression::AtomicResult { .. }
|
||||||
| crate::Expression::RayQueryProceedResult
|
| crate::Expression::RayQueryProceedResult
|
||||||
|
| crate::Expression::SubgroupBallotResult
|
||||||
|
| crate::Expression::SubgroupOperationResult { .. }
|
||||||
| crate::Expression::WorkGroupUniformLoadResult { .. } => (),
|
| crate::Expression::WorkGroupUniformLoadResult { .. } => (),
|
||||||
crate::Expression::ArrayLength(array) => {
|
crate::Expression::ArrayLength(array) => {
|
||||||
handle.check_dep(array)?;
|
handle.check_dep(array)?;
|
||||||
@ -565,6 +567,38 @@ impl super::Validator {
|
|||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
crate::Statement::SubgroupBallot { result, predicate } => {
|
||||||
|
validate_expr_opt(predicate)?;
|
||||||
|
validate_expr(result)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
crate::Statement::SubgroupCollectiveOperation {
|
||||||
|
op: _,
|
||||||
|
collective_op: _,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
validate_expr(argument)?;
|
||||||
|
validate_expr(result)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
crate::Statement::SubgroupGather {
|
||||||
|
mode,
|
||||||
|
argument,
|
||||||
|
result,
|
||||||
|
} => {
|
||||||
|
validate_expr(argument)?;
|
||||||
|
match mode {
|
||||||
|
crate::GatherMode::BroadcastFirst => {}
|
||||||
|
crate::GatherMode::Broadcast(index)
|
||||||
|
| crate::GatherMode::Shuffle(index)
|
||||||
|
| crate::GatherMode::ShuffleDown(index)
|
||||||
|
| crate::GatherMode::ShuffleUp(index)
|
||||||
|
| crate::GatherMode::ShuffleXor(index) => validate_expr(index)?,
|
||||||
|
}
|
||||||
|
validate_expr(result)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
crate::Statement::Break
|
crate::Statement::Break
|
||||||
| crate::Statement::Continue
|
| crate::Statement::Continue
|
||||||
| crate::Statement::Kill
|
| crate::Statement::Kill
|
||||||
|
|||||||
@ -77,6 +77,8 @@ pub enum VaryingError {
|
|||||||
location: u32,
|
location: u32,
|
||||||
attribute: &'static str,
|
attribute: &'static str,
|
||||||
},
|
},
|
||||||
|
#[error("Workgroup size is multi dimensional, @builtin(subgroup_id) and @builtin(subgroup_invocation_id) are not supported.")]
|
||||||
|
InvalidMultiDimensionalSubgroupBuiltIn,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, thiserror::Error)]
|
#[derive(Clone, Debug, thiserror::Error)]
|
||||||
@ -140,6 +142,7 @@ struct VaryingContext<'a> {
|
|||||||
impl VaryingContext<'_> {
|
impl VaryingContext<'_> {
|
||||||
fn validate_impl(
|
fn validate_impl(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
ep: &crate::EntryPoint,
|
||||||
ty: Handle<crate::Type>,
|
ty: Handle<crate::Type>,
|
||||||
binding: &crate::Binding,
|
binding: &crate::Binding,
|
||||||
) -> Result<(), VaryingError> {
|
) -> Result<(), VaryingError> {
|
||||||
@ -167,12 +170,24 @@ impl VaryingContext<'_> {
|
|||||||
Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX,
|
Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX,
|
||||||
Bi::ViewIndex => Capabilities::MULTIVIEW,
|
Bi::ViewIndex => Capabilities::MULTIVIEW,
|
||||||
Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING,
|
Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING,
|
||||||
|
Bi::NumSubgroups
|
||||||
|
| Bi::SubgroupId
|
||||||
|
| Bi::SubgroupSize
|
||||||
|
| Bi::SubgroupInvocationId => Capabilities::SUBGROUP,
|
||||||
_ => Capabilities::empty(),
|
_ => Capabilities::empty(),
|
||||||
};
|
};
|
||||||
if !self.capabilities.contains(required) {
|
if !self.capabilities.contains(required) {
|
||||||
return Err(VaryingError::UnsupportedCapability(required));
|
return Err(VaryingError::UnsupportedCapability(required));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if matches!(
|
||||||
|
built_in,
|
||||||
|
crate::BuiltIn::SubgroupId | crate::BuiltIn::SubgroupInvocationId
|
||||||
|
) && ep.workgroup_size[1..].iter().any(|&s| s > 1)
|
||||||
|
{
|
||||||
|
return Err(VaryingError::InvalidMultiDimensionalSubgroupBuiltIn);
|
||||||
|
}
|
||||||
|
|
||||||
let (visible, type_good) = match built_in {
|
let (visible, type_good) = match built_in {
|
||||||
Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => (
|
Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => (
|
||||||
self.stage == St::Vertex && !self.output,
|
self.stage == St::Vertex && !self.output,
|
||||||
@ -254,6 +269,17 @@ impl VaryingContext<'_> {
|
|||||||
scalar: crate::Scalar::U32,
|
scalar: crate::Scalar::U32,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
Bi::NumSubgroups | Bi::SubgroupId => (
|
||||||
|
self.stage == St::Compute && !self.output,
|
||||||
|
*ty_inner == Ti::Scalar(crate::Scalar::U32),
|
||||||
|
),
|
||||||
|
Bi::SubgroupSize | Bi::SubgroupInvocationId => (
|
||||||
|
match self.stage {
|
||||||
|
St::Compute | St::Fragment => !self.output,
|
||||||
|
St::Vertex => false,
|
||||||
|
},
|
||||||
|
*ty_inner == Ti::Scalar(crate::Scalar::U32),
|
||||||
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
if !visible {
|
if !visible {
|
||||||
@ -354,13 +380,14 @@ impl VaryingContext<'_> {
|
|||||||
|
|
||||||
fn validate(
|
fn validate(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
ep: &crate::EntryPoint,
|
||||||
ty: Handle<crate::Type>,
|
ty: Handle<crate::Type>,
|
||||||
binding: Option<&crate::Binding>,
|
binding: Option<&crate::Binding>,
|
||||||
) -> Result<(), WithSpan<VaryingError>> {
|
) -> Result<(), WithSpan<VaryingError>> {
|
||||||
let span_context = self.types.get_span_context(ty);
|
let span_context = self.types.get_span_context(ty);
|
||||||
match binding {
|
match binding {
|
||||||
Some(binding) => self
|
Some(binding) => self
|
||||||
.validate_impl(ty, binding)
|
.validate_impl(ep, ty, binding)
|
||||||
.map_err(|e| e.with_span_context(span_context)),
|
.map_err(|e| e.with_span_context(span_context)),
|
||||||
None => {
|
None => {
|
||||||
match self.types[ty].inner {
|
match self.types[ty].inner {
|
||||||
@ -377,7 +404,7 @@ impl VaryingContext<'_> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(ref binding) => self
|
Some(ref binding) => self
|
||||||
.validate_impl(member.ty, binding)
|
.validate_impl(ep, member.ty, binding)
|
||||||
.map_err(|e| e.with_span_context(span_context))?,
|
.map_err(|e| e.with_span_context(span_context))?,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -609,7 +636,7 @@ impl super::Validator {
|
|||||||
capabilities: self.capabilities,
|
capabilities: self.capabilities,
|
||||||
flags: self.flags,
|
flags: self.flags,
|
||||||
};
|
};
|
||||||
ctx.validate(fa.ty, fa.binding.as_ref())
|
ctx.validate(ep, fa.ty, fa.binding.as_ref())
|
||||||
.map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?;
|
.map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -627,7 +654,7 @@ impl super::Validator {
|
|||||||
capabilities: self.capabilities,
|
capabilities: self.capabilities,
|
||||||
flags: self.flags,
|
flags: self.flags,
|
||||||
};
|
};
|
||||||
ctx.validate(fr.ty, fr.binding.as_ref())
|
ctx.validate(ep, fr.ty, fr.binding.as_ref())
|
||||||
.map_err_inner(|e| EntryPointError::Result(e).with_span())?;
|
.map_err_inner(|e| EntryPointError::Result(e).with_span())?;
|
||||||
if ctx.second_blend_source {
|
if ctx.second_blend_source {
|
||||||
// Only the first location may be used when dual source blending
|
// Only the first location may be used when dual source blending
|
||||||
|
|||||||
@ -77,7 +77,7 @@ bitflags::bitflags! {
|
|||||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||||
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
|
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
|
||||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||||
pub struct Capabilities: u16 {
|
pub struct Capabilities: u32 {
|
||||||
/// Support for [`AddressSpace::PushConstant`].
|
/// Support for [`AddressSpace::PushConstant`].
|
||||||
const PUSH_CONSTANT = 0x1;
|
const PUSH_CONSTANT = 0x1;
|
||||||
/// Float values with width = 8.
|
/// Float values with width = 8.
|
||||||
@ -110,6 +110,10 @@ bitflags::bitflags! {
|
|||||||
const CUBE_ARRAY_TEXTURES = 0x4000;
|
const CUBE_ARRAY_TEXTURES = 0x4000;
|
||||||
/// Support for 64-bit signed and unsigned integers.
|
/// Support for 64-bit signed and unsigned integers.
|
||||||
const SHADER_INT64 = 0x8000;
|
const SHADER_INT64 = 0x8000;
|
||||||
|
/// Support for subgroup operations.
|
||||||
|
const SUBGROUP = 0x10000;
|
||||||
|
/// Support for subgroup barriers.
|
||||||
|
const SUBGROUP_BARRIER = 0x20000;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,6 +123,57 @@ impl Default for Capabilities {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bitflags::bitflags! {
    /// Set of subgroup operation families.
    ///
    /// Each flag names one family of operations; the note on each flag lists
    /// the operations that family covers.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// Elect and barrier.
        const BASIC = 0x01;
        /// Any and all.
        const VOTE = 0x02;
        /// Reductions and scans.
        const ARITHMETIC = 0x04;
        /// Ballot and broadcast.
        const BALLOT = 0x08;
        /// Shuffle and shuffle-xor.
        const SHUFFLE = 0x10;
        /// Shuffle up and shuffle down.
        const SHUFFLE_RELATIVE = 0x20;

        // The families below are not supported yet; they are kept only as
        // placeholders. Note that the last one needs bit 8, which does not fit
        // in the `u8` backing type, so the type must be widened before it can
        // be enabled.
        // /// Clustered
        // const CLUSTERED = 0x40;
        // /// Quad supported
        // const QUAD_FRAGMENT_COMPUTE = 0x80;
        // /// Quad supported in all stages
        // const QUAD_ALL_STAGES = 0x100;
    }
}
|
||||||
|
|
||||||
|
impl super::SubgroupOperation {
|
||||||
|
const fn required_operations(&self) -> SubgroupOperationSet {
|
||||||
|
use SubgroupOperationSet as S;
|
||||||
|
match *self {
|
||||||
|
Self::All | Self::Any => S::VOTE,
|
||||||
|
Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
|
||||||
|
S::ARITHMETIC
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl super::GatherMode {
|
||||||
|
const fn required_operations(&self) -> SubgroupOperationSet {
|
||||||
|
use SubgroupOperationSet as S;
|
||||||
|
match *self {
|
||||||
|
Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
|
||||||
|
Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
|
||||||
|
Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bitflags::bitflags! {
|
bitflags::bitflags! {
|
||||||
/// Validation flags.
|
/// Validation flags.
|
||||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||||
@ -166,6 +221,8 @@ impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
|
|||||||
pub struct Validator {
|
pub struct Validator {
|
||||||
flags: ValidationFlags,
|
flags: ValidationFlags,
|
||||||
capabilities: Capabilities,
|
capabilities: Capabilities,
|
||||||
|
subgroup_stages: ShaderStages,
|
||||||
|
subgroup_operations: SubgroupOperationSet,
|
||||||
types: Vec<r#type::TypeInfo>,
|
types: Vec<r#type::TypeInfo>,
|
||||||
layouter: Layouter,
|
layouter: Layouter,
|
||||||
location_mask: BitSet,
|
location_mask: BitSet,
|
||||||
@ -317,6 +374,8 @@ impl Validator {
|
|||||||
Validator {
|
Validator {
|
||||||
flags,
|
flags,
|
||||||
capabilities,
|
capabilities,
|
||||||
|
subgroup_stages: ShaderStages::empty(),
|
||||||
|
subgroup_operations: SubgroupOperationSet::empty(),
|
||||||
types: Vec::new(),
|
types: Vec::new(),
|
||||||
layouter: Layouter::default(),
|
layouter: Layouter::default(),
|
||||||
location_mask: BitSet::new(),
|
location_mask: BitSet::new(),
|
||||||
@ -329,6 +388,16 @@ impl Validator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Replaces the validator's `subgroup_stages` set with `stages`.
///
/// A freshly constructed validator starts with `ShaderStages::empty()`.
/// Returns `self` so that further configuration calls can be chained.
///
/// NOTE(review): presumably these are the stages in which subgroup
/// operations are accepted; the check that reads this field is not
/// visible from here.
pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
    self.subgroup_stages = stages;
    self
}
|
||||||
|
|
||||||
|
/// Replaces the validator's `subgroup_operations` set with `operations`.
///
/// A freshly constructed validator starts with
/// `SubgroupOperationSet::empty()`. Returns `self` so that further
/// configuration calls can be chained.
///
/// NOTE(review): presumably this is the set of operation families the
/// validator accepts; the check that reads this field is not visible
/// from here.
pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
    self.subgroup_operations = operations;
    self
}
|
||||||
|
|
||||||
/// Reset the validator internals
|
/// Reset the validator internals
|
||||||
pub fn reset(&mut self) {
|
pub fn reset(&mut self) {
|
||||||
self.types.clear();
|
self.types.clear();
|
||||||
|
|||||||
27
naga/tests/in/spv/subgroup-operations-s.param.ron
Normal file
27
naga/tests/in/spv/subgroup-operations-s.param.ron
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
(
|
||||||
|
god_mode: true,
|
||||||
|
spv: (
|
||||||
|
version: (1, 3),
|
||||||
|
),
|
||||||
|
msl: (
|
||||||
|
lang_version: (2, 4),
|
||||||
|
per_entry_point_map: {},
|
||||||
|
inline_samplers: [],
|
||||||
|
spirv_cross_compatibility: false,
|
||||||
|
fake_missing_bindings: false,
|
||||||
|
zero_initialize_workgroup_memory: true,
|
||||||
|
),
|
||||||
|
glsl: (
|
||||||
|
version: Desktop(430),
|
||||||
|
writer_flags: (""),
|
||||||
|
binding_map: { },
|
||||||
|
zero_initialize_workgroup_memory: true,
|
||||||
|
),
|
||||||
|
hlsl: (
|
||||||
|
shader_model: V6_0,
|
||||||
|
binding_map: {},
|
||||||
|
fake_missing_bindings: true,
|
||||||
|
special_constants_binding: None,
|
||||||
|
zero_initialize_workgroup_memory: true,
|
||||||
|
),
|
||||||
|
)
|
||||||
BIN
naga/tests/in/spv/subgroup-operations-s.spv
Normal file
BIN
naga/tests/in/spv/subgroup-operations-s.spv
Normal file
Binary file not shown.
75
naga/tests/in/spv/subgroup-operations-s.spvasm
Normal file
75
naga/tests/in/spv/subgroup-operations-s.spvasm
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
; SPIR-V
|
||||||
|
; Version: 1.3
|
||||||
|
; Generator: rspirv
|
||||||
|
; Bound: 54
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability GroupNonUniform
|
||||||
|
OpCapability GroupNonUniformBallot
|
||||||
|
OpCapability GroupNonUniformVote
|
||||||
|
OpCapability GroupNonUniformArithmetic
|
||||||
|
OpCapability GroupNonUniformShuffle
|
||||||
|
OpCapability GroupNonUniformShuffleRelative
|
||||||
|
%1 = OpExtInstImport "GLSL.std.450"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %15 "main" %6 %9 %11 %13
|
||||||
|
OpExecutionMode %15 LocalSize 1 1 1
|
||||||
|
OpDecorate %6 BuiltIn NumSubgroups
|
||||||
|
OpDecorate %9 BuiltIn SubgroupId
|
||||||
|
OpDecorate %11 BuiltIn SubgroupSize
|
||||||
|
OpDecorate %13 BuiltIn SubgroupLocalInvocationId
|
||||||
|
%2 = OpTypeVoid
|
||||||
|
%3 = OpTypeInt 32 0
|
||||||
|
%4 = OpTypeBool
|
||||||
|
%7 = OpTypePointer Input %3
|
||||||
|
%6 = OpVariable %7 Input
|
||||||
|
%9 = OpVariable %7 Input
|
||||||
|
%11 = OpVariable %7 Input
|
||||||
|
%13 = OpVariable %7 Input
|
||||||
|
%16 = OpTypeFunction %2
|
||||||
|
%17 = OpConstant %3 1
|
||||||
|
%18 = OpConstant %3 0
|
||||||
|
%19 = OpConstant %3 4
|
||||||
|
%21 = OpConstant %3 3
|
||||||
|
%22 = OpConstant %3 2
|
||||||
|
%23 = OpConstant %3 8
|
||||||
|
%26 = OpTypeVector %3 4
|
||||||
|
%28 = OpConstantTrue %4
|
||||||
|
%15 = OpFunction %2 None %16
|
||||||
|
%5 = OpLabel
|
||||||
|
%8 = OpLoad %3 %6
|
||||||
|
%10 = OpLoad %3 %9
|
||||||
|
%12 = OpLoad %3 %11
|
||||||
|
%14 = OpLoad %3 %13
|
||||||
|
OpBranch %20
|
||||||
|
%20 = OpLabel
|
||||||
|
OpControlBarrier %21 %22 %23
|
||||||
|
%24 = OpBitwiseAnd %3 %14 %17
|
||||||
|
%25 = OpIEqual %4 %24 %17
|
||||||
|
%27 = OpGroupNonUniformBallot %26 %21 %25
|
||||||
|
%29 = OpGroupNonUniformBallot %26 %21 %28
|
||||||
|
%30 = OpINotEqual %4 %14 %18
|
||||||
|
%31 = OpGroupNonUniformAll %4 %21 %30
|
||||||
|
%32 = OpIEqual %4 %14 %18
|
||||||
|
%33 = OpGroupNonUniformAny %4 %21 %32
|
||||||
|
%34 = OpGroupNonUniformIAdd %3 %21 Reduce %14
|
||||||
|
%35 = OpGroupNonUniformIMul %3 %21 Reduce %14
|
||||||
|
%36 = OpGroupNonUniformUMin %3 %21 Reduce %14
|
||||||
|
%37 = OpGroupNonUniformUMax %3 %21 Reduce %14
|
||||||
|
%38 = OpGroupNonUniformBitwiseAnd %3 %21 Reduce %14
|
||||||
|
%39 = OpGroupNonUniformBitwiseOr %3 %21 Reduce %14
|
||||||
|
%40 = OpGroupNonUniformBitwiseXor %3 %21 Reduce %14
|
||||||
|
%41 = OpGroupNonUniformIAdd %3 %21 ExclusiveScan %14
|
||||||
|
%42 = OpGroupNonUniformIMul %3 %21 ExclusiveScan %14
|
||||||
|
%43 = OpGroupNonUniformIAdd %3 %21 InclusiveScan %14
|
||||||
|
%44 = OpGroupNonUniformIMul %3 %21 InclusiveScan %14
|
||||||
|
%45 = OpGroupNonUniformBroadcastFirst %3 %21 %14
|
||||||
|
%46 = OpGroupNonUniformBroadcast %3 %21 %14 %19
|
||||||
|
%47 = OpISub %3 %12 %17
|
||||||
|
%48 = OpISub %3 %47 %14
|
||||||
|
%49 = OpGroupNonUniformShuffle %3 %21 %14 %48
|
||||||
|
%50 = OpGroupNonUniformShuffleDown %3 %21 %14 %17
|
||||||
|
%51 = OpGroupNonUniformShuffleUp %3 %21 %14 %17
|
||||||
|
%52 = OpISub %3 %12 %17
|
||||||
|
%53 = OpGroupNonUniformShuffleXor %3 %21 %14 %52
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
27
naga/tests/in/subgroup-operations.param.ron
Normal file
27
naga/tests/in/subgroup-operations.param.ron
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
(
|
||||||
|
god_mode: true,
|
||||||
|
spv: (
|
||||||
|
version: (1, 3),
|
||||||
|
),
|
||||||
|
msl: (
|
||||||
|
lang_version: (2, 4),
|
||||||
|
per_entry_point_map: {},
|
||||||
|
inline_samplers: [],
|
||||||
|
spirv_cross_compatibility: false,
|
||||||
|
fake_missing_bindings: false,
|
||||||
|
zero_initialize_workgroup_memory: true,
|
||||||
|
),
|
||||||
|
glsl: (
|
||||||
|
version: Desktop(430),
|
||||||
|
writer_flags: (""),
|
||||||
|
binding_map: { },
|
||||||
|
zero_initialize_workgroup_memory: true,
|
||||||
|
),
|
||||||
|
hlsl: (
|
||||||
|
shader_model: V6_0,
|
||||||
|
binding_map: {},
|
||||||
|
fake_missing_bindings: true,
|
||||||
|
special_constants_binding: None,
|
||||||
|
zero_initialize_workgroup_memory: true,
|
||||||
|
),
|
||||||
|
)
|
||||||
37
naga/tests/in/subgroup-operations.wgsl
Normal file
37
naga/tests/in/subgroup-operations.wgsl
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
struct Structure {
|
||||||
|
@builtin(num_subgroups) num_subgroups: u32,
|
||||||
|
@builtin(subgroup_size) subgroup_size: u32,
|
||||||
|
};
|
||||||
|
|
||||||
|
@compute @workgroup_size(1)
|
||||||
|
fn main(
|
||||||
|
sizes: Structure,
|
||||||
|
@builtin(subgroup_id) subgroup_id: u32,
|
||||||
|
@builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
|
||||||
|
) {
|
||||||
|
subgroupBarrier();
|
||||||
|
|
||||||
|
subgroupBallot((subgroup_invocation_id & 1u) == 1u);
|
||||||
|
subgroupBallot();
|
||||||
|
|
||||||
|
subgroupAll(subgroup_invocation_id != 0u);
|
||||||
|
subgroupAny(subgroup_invocation_id == 0u);
|
||||||
|
subgroupAdd(subgroup_invocation_id);
|
||||||
|
subgroupMul(subgroup_invocation_id);
|
||||||
|
subgroupMin(subgroup_invocation_id);
|
||||||
|
subgroupMax(subgroup_invocation_id);
|
||||||
|
subgroupAnd(subgroup_invocation_id);
|
||||||
|
subgroupOr(subgroup_invocation_id);
|
||||||
|
subgroupXor(subgroup_invocation_id);
|
||||||
|
subgroupExclusiveAdd(subgroup_invocation_id);
|
||||||
|
subgroupExclusiveMul(subgroup_invocation_id);
|
||||||
|
subgroupInclusiveAdd(subgroup_invocation_id);
|
||||||
|
subgroupInclusiveMul(subgroup_invocation_id);
|
||||||
|
|
||||||
|
subgroupBroadcastFirst(subgroup_invocation_id);
|
||||||
|
subgroupBroadcast(subgroup_invocation_id, 4u);
|
||||||
|
subgroupShuffle(subgroup_invocation_id, sizes.subgroup_size - 1u - subgroup_invocation_id);
|
||||||
|
subgroupShuffleDown(subgroup_invocation_id, 1u);
|
||||||
|
subgroupShuffleUp(subgroup_invocation_id, 1u);
|
||||||
|
subgroupShuffleXor(subgroup_invocation_id, sizes.subgroup_size - 1u);
|
||||||
|
}
|
||||||
58
naga/tests/out/glsl/subgroup-operations-s.main.Compute.glsl
Normal file
58
naga/tests/out/glsl/subgroup-operations-s.main.Compute.glsl
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
#version 430 core
|
||||||
|
#extension GL_ARB_compute_shader : require
|
||||||
|
#extension GL_KHR_shader_subgroup_basic : require
|
||||||
|
#extension GL_KHR_shader_subgroup_vote : require
|
||||||
|
#extension GL_KHR_shader_subgroup_arithmetic : require
|
||||||
|
#extension GL_KHR_shader_subgroup_ballot : require
|
||||||
|
#extension GL_KHR_shader_subgroup_shuffle : require
|
||||||
|
#extension GL_KHR_shader_subgroup_shuffle_relative : require
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
uint num_subgroups_1 = 0u;
|
||||||
|
|
||||||
|
uint subgroup_id_1 = 0u;
|
||||||
|
|
||||||
|
uint subgroup_size_1 = 0u;
|
||||||
|
|
||||||
|
uint subgroup_invocation_id_1 = 0u;
|
||||||
|
|
||||||
|
|
||||||
|
void main_1() {
|
||||||
|
uint _e5 = subgroup_size_1;
|
||||||
|
uint _e6 = subgroup_invocation_id_1;
|
||||||
|
uvec4 _e9 = subgroupBallot(((_e6 & 1u) == 1u));
|
||||||
|
uvec4 _e10 = subgroupBallot(true);
|
||||||
|
bool _e12 = subgroupAll((_e6 != 0u));
|
||||||
|
bool _e14 = subgroupAny((_e6 == 0u));
|
||||||
|
uint _e15 = subgroupAdd(_e6);
|
||||||
|
uint _e16 = subgroupMul(_e6);
|
||||||
|
uint _e17 = subgroupMin(_e6);
|
||||||
|
uint _e18 = subgroupMax(_e6);
|
||||||
|
uint _e19 = subgroupAnd(_e6);
|
||||||
|
uint _e20 = subgroupOr(_e6);
|
||||||
|
uint _e21 = subgroupXor(_e6);
|
||||||
|
uint _e22 = subgroupExclusiveAdd(_e6);
|
||||||
|
uint _e23 = subgroupExclusiveMul(_e6);
|
||||||
|
uint _e24 = subgroupInclusiveAdd(_e6);
|
||||||
|
uint _e25 = subgroupInclusiveMul(_e6);
|
||||||
|
uint _e26 = subgroupBroadcastFirst(_e6);
|
||||||
|
uint _e27 = subgroupBroadcast(_e6, 4u);
|
||||||
|
uint _e30 = subgroupShuffle(_e6, ((_e5 - 1u) - _e6));
|
||||||
|
uint _e31 = subgroupShuffleDown(_e6, 1u);
|
||||||
|
uint _e32 = subgroupShuffleUp(_e6, 1u);
|
||||||
|
uint _e34 = subgroupShuffleXor(_e6, (_e5 - 1u));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
uint num_subgroups = gl_NumSubgroups;
|
||||||
|
uint subgroup_id = gl_SubgroupID;
|
||||||
|
uint subgroup_size = gl_SubgroupSize;
|
||||||
|
uint subgroup_invocation_id = gl_SubgroupInvocationID;
|
||||||
|
num_subgroups_1 = num_subgroups;
|
||||||
|
subgroup_id_1 = subgroup_id;
|
||||||
|
subgroup_size_1 = subgroup_size;
|
||||||
|
subgroup_invocation_id_1 = subgroup_invocation_id;
|
||||||
|
main_1();
|
||||||
|
}
|
||||||
|
|
||||||
45
naga/tests/out/glsl/subgroup-operations.main.Compute.glsl
Normal file
45
naga/tests/out/glsl/subgroup-operations.main.Compute.glsl
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
#version 430 core
|
||||||
|
#extension GL_ARB_compute_shader : require
|
||||||
|
#extension GL_KHR_shader_subgroup_basic : require
|
||||||
|
#extension GL_KHR_shader_subgroup_vote : require
|
||||||
|
#extension GL_KHR_shader_subgroup_arithmetic : require
|
||||||
|
#extension GL_KHR_shader_subgroup_ballot : require
|
||||||
|
#extension GL_KHR_shader_subgroup_shuffle : require
|
||||||
|
#extension GL_KHR_shader_subgroup_shuffle_relative : require
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
struct Structure {
|
||||||
|
uint num_subgroups;
|
||||||
|
uint subgroup_size;
|
||||||
|
};
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
Structure sizes = Structure(gl_NumSubgroups, gl_SubgroupSize);
|
||||||
|
uint subgroup_id = gl_SubgroupID;
|
||||||
|
uint subgroup_invocation_id = gl_SubgroupInvocationID;
|
||||||
|
subgroupMemoryBarrier();
|
||||||
|
barrier();
|
||||||
|
uvec4 _e7 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u));
|
||||||
|
uvec4 _e8 = subgroupBallot(true);
|
||||||
|
bool _e11 = subgroupAll((subgroup_invocation_id != 0u));
|
||||||
|
bool _e14 = subgroupAny((subgroup_invocation_id == 0u));
|
||||||
|
uint _e15 = subgroupAdd(subgroup_invocation_id);
|
||||||
|
uint _e16 = subgroupMul(subgroup_invocation_id);
|
||||||
|
uint _e17 = subgroupMin(subgroup_invocation_id);
|
||||||
|
uint _e18 = subgroupMax(subgroup_invocation_id);
|
||||||
|
uint _e19 = subgroupAnd(subgroup_invocation_id);
|
||||||
|
uint _e20 = subgroupOr(subgroup_invocation_id);
|
||||||
|
uint _e21 = subgroupXor(subgroup_invocation_id);
|
||||||
|
uint _e22 = subgroupExclusiveAdd(subgroup_invocation_id);
|
||||||
|
uint _e23 = subgroupExclusiveMul(subgroup_invocation_id);
|
||||||
|
uint _e24 = subgroupInclusiveAdd(subgroup_invocation_id);
|
||||||
|
uint _e25 = subgroupInclusiveMul(subgroup_invocation_id);
|
||||||
|
uint _e26 = subgroupBroadcastFirst(subgroup_invocation_id);
|
||||||
|
uint _e28 = subgroupBroadcast(subgroup_invocation_id, 4u);
|
||||||
|
uint _e33 = subgroupShuffle(subgroup_invocation_id, ((sizes.subgroup_size - 1u) - subgroup_invocation_id));
|
||||||
|
uint _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u);
|
||||||
|
uint _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u);
|
||||||
|
uint _e41 = subgroupShuffleXor(subgroup_invocation_id, (sizes.subgroup_size - 1u));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
50
naga/tests/out/hlsl/subgroup-operations-s.hlsl
Normal file
50
naga/tests/out/hlsl/subgroup-operations-s.hlsl
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
static uint num_subgroups_1 = (uint)0;
|
||||||
|
static uint subgroup_id_1 = (uint)0;
|
||||||
|
static uint subgroup_size_1 = (uint)0;
|
||||||
|
static uint subgroup_invocation_id_1 = (uint)0;
|
||||||
|
|
||||||
|
struct ComputeInput_main {
|
||||||
|
uint __local_invocation_index : SV_GroupIndex;
|
||||||
|
};
|
||||||
|
|
||||||
|
void main_1()
|
||||||
|
{
|
||||||
|
uint _expr5 = subgroup_size_1;
|
||||||
|
uint _expr6 = subgroup_invocation_id_1;
|
||||||
|
const uint4 _e9 = WaveActiveBallot(((_expr6 & 1u) == 1u));
|
||||||
|
const uint4 _e10 = WaveActiveBallot(true);
|
||||||
|
const bool _e12 = WaveActiveAllTrue((_expr6 != 0u));
|
||||||
|
const bool _e14 = WaveActiveAnyTrue((_expr6 == 0u));
|
||||||
|
const uint _e15 = WaveActiveSum(_expr6);
|
||||||
|
const uint _e16 = WaveActiveProduct(_expr6);
|
||||||
|
const uint _e17 = WaveActiveMin(_expr6);
|
||||||
|
const uint _e18 = WaveActiveMax(_expr6);
|
||||||
|
const uint _e19 = WaveActiveBitAnd(_expr6);
|
||||||
|
const uint _e20 = WaveActiveBitOr(_expr6);
|
||||||
|
const uint _e21 = WaveActiveBitXor(_expr6);
|
||||||
|
const uint _e22 = WavePrefixSum(_expr6);
|
||||||
|
const uint _e23 = WavePrefixProduct(_expr6);
|
||||||
|
const uint _e24 = _expr6 + WavePrefixSum(_expr6);
|
||||||
|
const uint _e25 = _expr6 * WavePrefixProduct(_expr6);
|
||||||
|
const uint _e26 = WaveReadLaneFirst(_expr6);
|
||||||
|
const uint _e27 = WaveReadLaneAt(_expr6, 4u);
|
||||||
|
const uint _e30 = WaveReadLaneAt(_expr6, ((_expr5 - 1u) - _expr6));
|
||||||
|
const uint _e31 = WaveReadLaneAt(_expr6, WaveGetLaneIndex() + 1u);
|
||||||
|
const uint _e32 = WaveReadLaneAt(_expr6, WaveGetLaneIndex() - 1u);
|
||||||
|
const uint _e34 = WaveReadLaneAt(_expr6, WaveGetLaneIndex() ^ (_expr5 - 1u));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
[numthreads(1, 1, 1)]
|
||||||
|
void main(ComputeInput_main computeinput_main)
|
||||||
|
{
|
||||||
|
uint num_subgroups = (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount();
|
||||||
|
uint subgroup_id = computeinput_main.__local_invocation_index / WaveGetLaneCount();
|
||||||
|
uint subgroup_size = WaveGetLaneCount();
|
||||||
|
uint subgroup_invocation_id = WaveGetLaneIndex();
|
||||||
|
num_subgroups_1 = num_subgroups;
|
||||||
|
subgroup_id_1 = subgroup_id;
|
||||||
|
subgroup_size_1 = subgroup_size;
|
||||||
|
subgroup_invocation_id_1 = subgroup_invocation_id;
|
||||||
|
main_1();
|
||||||
|
}
|
||||||
12
naga/tests/out/hlsl/subgroup-operations-s.ron
Normal file
12
naga/tests/out/hlsl/subgroup-operations-s.ron
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
(
|
||||||
|
vertex:[
|
||||||
|
],
|
||||||
|
fragment:[
|
||||||
|
],
|
||||||
|
compute:[
|
||||||
|
(
|
||||||
|
entry_point:"main",
|
||||||
|
target_profile:"cs_6_0",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
38
naga/tests/out/hlsl/subgroup-operations.hlsl
Normal file
38
naga/tests/out/hlsl/subgroup-operations.hlsl
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
struct Structure {
|
||||||
|
uint num_subgroups;
|
||||||
|
uint subgroup_size;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ComputeInput_main {
|
||||||
|
uint __local_invocation_index : SV_GroupIndex;
|
||||||
|
};
|
||||||
|
|
||||||
|
[numthreads(1, 1, 1)]
|
||||||
|
void main(ComputeInput_main computeinput_main)
|
||||||
|
{
|
||||||
|
Structure sizes = { (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount(), WaveGetLaneCount() };
|
||||||
|
uint subgroup_id = computeinput_main.__local_invocation_index / WaveGetLaneCount();
|
||||||
|
uint subgroup_invocation_id = WaveGetLaneIndex();
|
||||||
|
const uint4 _e7 = WaveActiveBallot(((subgroup_invocation_id & 1u) == 1u));
|
||||||
|
const uint4 _e8 = WaveActiveBallot(true);
|
||||||
|
const bool _e11 = WaveActiveAllTrue((subgroup_invocation_id != 0u));
|
||||||
|
const bool _e14 = WaveActiveAnyTrue((subgroup_invocation_id == 0u));
|
||||||
|
const uint _e15 = WaveActiveSum(subgroup_invocation_id);
|
||||||
|
const uint _e16 = WaveActiveProduct(subgroup_invocation_id);
|
||||||
|
const uint _e17 = WaveActiveMin(subgroup_invocation_id);
|
||||||
|
const uint _e18 = WaveActiveMax(subgroup_invocation_id);
|
||||||
|
const uint _e19 = WaveActiveBitAnd(subgroup_invocation_id);
|
||||||
|
const uint _e20 = WaveActiveBitOr(subgroup_invocation_id);
|
||||||
|
const uint _e21 = WaveActiveBitXor(subgroup_invocation_id);
|
||||||
|
const uint _e22 = WavePrefixSum(subgroup_invocation_id);
|
||||||
|
const uint _e23 = WavePrefixProduct(subgroup_invocation_id);
|
||||||
|
const uint _e24 = subgroup_invocation_id + WavePrefixSum(subgroup_invocation_id);
|
||||||
|
const uint _e25 = subgroup_invocation_id * WavePrefixProduct(subgroup_invocation_id);
|
||||||
|
const uint _e26 = WaveReadLaneFirst(subgroup_invocation_id);
|
||||||
|
const uint _e28 = WaveReadLaneAt(subgroup_invocation_id, 4u);
|
||||||
|
const uint _e33 = WaveReadLaneAt(subgroup_invocation_id, ((sizes.subgroup_size - 1u) - subgroup_invocation_id));
|
||||||
|
const uint _e35 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() + 1u);
|
||||||
|
const uint _e37 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() - 1u);
|
||||||
|
const uint _e41 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() ^ (sizes.subgroup_size - 1u));
|
||||||
|
return;
|
||||||
|
}
|
||||||
12
naga/tests/out/hlsl/subgroup-operations.ron
Normal file
12
naga/tests/out/hlsl/subgroup-operations.ron
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
(
|
||||||
|
vertex:[
|
||||||
|
],
|
||||||
|
fragment:[
|
||||||
|
],
|
||||||
|
compute:[
|
||||||
|
(
|
||||||
|
entry_point:"main",
|
||||||
|
target_profile:"cs_6_0",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
55
naga/tests/out/msl/subgroup-operations-s.msl
Normal file
55
naga/tests/out/msl/subgroup-operations-s.msl
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
// language: metal2.4
|
||||||
|
#include <metal_stdlib>
|
||||||
|
#include <simd/simd.h>
|
||||||
|
|
||||||
|
using metal::uint;
|
||||||
|
|
||||||
|
|
||||||
|
void main_1(
|
||||||
|
thread uint& subgroup_size_1,
|
||||||
|
thread uint& subgroup_invocation_id_1
|
||||||
|
) {
|
||||||
|
uint _e5 = subgroup_size_1;
|
||||||
|
uint _e6 = subgroup_invocation_id_1;
|
||||||
|
metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((_e6 & 1u) == 1u), 0, 0, 0);
|
||||||
|
metal::uint4 unnamed_1 = uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0);
|
||||||
|
bool unnamed_2 = metal::simd_all(_e6 != 0u);
|
||||||
|
bool unnamed_3 = metal::simd_any(_e6 == 0u);
|
||||||
|
uint unnamed_4 = metal::simd_sum(_e6);
|
||||||
|
uint unnamed_5 = metal::simd_product(_e6);
|
||||||
|
uint unnamed_6 = metal::simd_min(_e6);
|
||||||
|
uint unnamed_7 = metal::simd_max(_e6);
|
||||||
|
uint unnamed_8 = metal::simd_and(_e6);
|
||||||
|
uint unnamed_9 = metal::simd_or(_e6);
|
||||||
|
uint unnamed_10 = metal::simd_xor(_e6);
|
||||||
|
uint unnamed_11 = metal::simd_prefix_exclusive_sum(_e6);
|
||||||
|
uint unnamed_12 = metal::simd_prefix_exclusive_product(_e6);
|
||||||
|
uint unnamed_13 = metal::simd_prefix_inclusive_sum(_e6);
|
||||||
|
uint unnamed_14 = metal::simd_prefix_inclusive_product(_e6);
|
||||||
|
uint unnamed_15 = metal::simd_broadcast_first(_e6);
|
||||||
|
uint unnamed_16 = metal::simd_broadcast(_e6, 4u);
|
||||||
|
uint unnamed_17 = metal::simd_shuffle(_e6, (_e5 - 1u) - _e6);
|
||||||
|
uint unnamed_18 = metal::simd_shuffle_down(_e6, 1u);
|
||||||
|
uint unnamed_19 = metal::simd_shuffle_up(_e6, 1u);
|
||||||
|
uint unnamed_20 = metal::simd_shuffle_xor(_e6, _e5 - 1u);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct main_Input {
|
||||||
|
};
|
||||||
|
kernel void main_(
|
||||||
|
uint num_subgroups [[simdgroups_per_threadgroup]]
|
||||||
|
, uint subgroup_id [[simdgroup_index_in_threadgroup]]
|
||||||
|
, uint subgroup_size [[threads_per_simdgroup]]
|
||||||
|
, uint subgroup_invocation_id [[thread_index_in_simdgroup]]
|
||||||
|
) {
|
||||||
|
uint num_subgroups_1 = {};
|
||||||
|
uint subgroup_id_1 = {};
|
||||||
|
uint subgroup_size_1 = {};
|
||||||
|
uint subgroup_invocation_id_1 = {};
|
||||||
|
num_subgroups_1 = num_subgroups;
|
||||||
|
subgroup_id_1 = subgroup_id;
|
||||||
|
subgroup_size_1 = subgroup_size;
|
||||||
|
subgroup_invocation_id_1 = subgroup_invocation_id;
|
||||||
|
main_1(subgroup_size_1, subgroup_invocation_id_1);
|
||||||
|
}
|
||||||
44
naga/tests/out/msl/subgroup-operations.msl
Normal file
44
naga/tests/out/msl/subgroup-operations.msl
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
// language: metal2.4
|
||||||
|
#include <metal_stdlib>
|
||||||
|
#include <simd/simd.h>
|
||||||
|
|
||||||
|
using metal::uint;
|
||||||
|
|
||||||
|
struct Structure {
|
||||||
|
uint num_subgroups;
|
||||||
|
uint subgroup_size;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct main_Input {
|
||||||
|
};
|
||||||
|
kernel void main_(
|
||||||
|
uint num_subgroups [[simdgroups_per_threadgroup]]
|
||||||
|
, uint subgroup_size [[threads_per_simdgroup]]
|
||||||
|
, uint subgroup_id [[simdgroup_index_in_threadgroup]]
|
||||||
|
, uint subgroup_invocation_id [[thread_index_in_simdgroup]]
|
||||||
|
) {
|
||||||
|
const Structure sizes = { num_subgroups, subgroup_size };
|
||||||
|
metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup);
|
||||||
|
metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((subgroup_invocation_id & 1u) == 1u), 0, 0, 0);
|
||||||
|
metal::uint4 unnamed_1 = uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0);
|
||||||
|
bool unnamed_2 = metal::simd_all(subgroup_invocation_id != 0u);
|
||||||
|
bool unnamed_3 = metal::simd_any(subgroup_invocation_id == 0u);
|
||||||
|
uint unnamed_4 = metal::simd_sum(subgroup_invocation_id);
|
||||||
|
uint unnamed_5 = metal::simd_product(subgroup_invocation_id);
|
||||||
|
uint unnamed_6 = metal::simd_min(subgroup_invocation_id);
|
||||||
|
uint unnamed_7 = metal::simd_max(subgroup_invocation_id);
|
||||||
|
uint unnamed_8 = metal::simd_and(subgroup_invocation_id);
|
||||||
|
uint unnamed_9 = metal::simd_or(subgroup_invocation_id);
|
||||||
|
uint unnamed_10 = metal::simd_xor(subgroup_invocation_id);
|
||||||
|
uint unnamed_11 = metal::simd_prefix_exclusive_sum(subgroup_invocation_id);
|
||||||
|
uint unnamed_12 = metal::simd_prefix_exclusive_product(subgroup_invocation_id);
|
||||||
|
uint unnamed_13 = metal::simd_prefix_inclusive_sum(subgroup_invocation_id);
|
||||||
|
uint unnamed_14 = metal::simd_prefix_inclusive_product(subgroup_invocation_id);
|
||||||
|
uint unnamed_15 = metal::simd_broadcast_first(subgroup_invocation_id);
|
||||||
|
uint unnamed_16 = metal::simd_broadcast(subgroup_invocation_id, 4u);
|
||||||
|
uint unnamed_17 = metal::simd_shuffle(subgroup_invocation_id, (sizes.subgroup_size - 1u) - subgroup_invocation_id);
|
||||||
|
uint unnamed_18 = metal::simd_shuffle_down(subgroup_invocation_id, 1u);
|
||||||
|
uint unnamed_19 = metal::simd_shuffle_up(subgroup_invocation_id, 1u);
|
||||||
|
uint unnamed_20 = metal::simd_shuffle_xor(subgroup_invocation_id, sizes.subgroup_size - 1u);
|
||||||
|
return;
|
||||||
|
}
|
||||||
81
naga/tests/out/spv/subgroup-operations.spvasm
Normal file
81
naga/tests/out/spv/subgroup-operations.spvasm
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
; SPIR-V
|
||||||
|
; Version: 1.3
|
||||||
|
; Generator: rspirv
|
||||||
|
; Bound: 58
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability GroupNonUniform
|
||||||
|
OpCapability GroupNonUniformBallot
|
||||||
|
OpCapability GroupNonUniformVote
|
||||||
|
OpCapability GroupNonUniformArithmetic
|
||||||
|
OpCapability GroupNonUniformShuffle
|
||||||
|
OpCapability GroupNonUniformShuffleRelative
|
||||||
|
%1 = OpExtInstImport "GLSL.std.450"
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %17 "main" %8 %11 %13 %15
|
||||||
|
OpExecutionMode %17 LocalSize 1 1 1
|
||||||
|
OpMemberDecorate %4 0 Offset 0
|
||||||
|
OpMemberDecorate %4 1 Offset 4
|
||||||
|
OpDecorate %8 BuiltIn NumSubgroups
|
||||||
|
OpDecorate %11 BuiltIn SubgroupSize
|
||||||
|
OpDecorate %13 BuiltIn SubgroupId
|
||||||
|
OpDecorate %15 BuiltIn SubgroupLocalInvocationId
|
||||||
|
%2 = OpTypeVoid
|
||||||
|
%3 = OpTypeInt 32 0
|
||||||
|
%4 = OpTypeStruct %3 %3
|
||||||
|
%5 = OpTypeBool
|
||||||
|
%9 = OpTypePointer Input %3
|
||||||
|
%8 = OpVariable %9 Input
|
||||||
|
%11 = OpVariable %9 Input
|
||||||
|
%13 = OpVariable %9 Input
|
||||||
|
%15 = OpVariable %9 Input
|
||||||
|
%18 = OpTypeFunction %2
|
||||||
|
%19 = OpConstant %3 1
|
||||||
|
%20 = OpConstant %3 0
|
||||||
|
%21 = OpConstant %3 4
|
||||||
|
%23 = OpConstant %3 3
|
||||||
|
%24 = OpConstant %3 2
|
||||||
|
%25 = OpConstant %3 8
|
||||||
|
%28 = OpTypeVector %3 4
|
||||||
|
%30 = OpConstantTrue %5
|
||||||
|
%17 = OpFunction %2 None %18
|
||||||
|
%6 = OpLabel
|
||||||
|
%10 = OpLoad %3 %8
|
||||||
|
%12 = OpLoad %3 %11
|
||||||
|
%7 = OpCompositeConstruct %4 %10 %12
|
||||||
|
%14 = OpLoad %3 %13
|
||||||
|
%16 = OpLoad %3 %15
|
||||||
|
OpBranch %22
|
||||||
|
%22 = OpLabel
|
||||||
|
OpControlBarrier %23 %24 %25
|
||||||
|
%26 = OpBitwiseAnd %3 %16 %19
|
||||||
|
%27 = OpIEqual %5 %26 %19
|
||||||
|
%29 = OpGroupNonUniformBallot %28 %23 %27
|
||||||
|
%31 = OpGroupNonUniformBallot %28 %23 %30
|
||||||
|
%32 = OpINotEqual %5 %16 %20
|
||||||
|
%33 = OpGroupNonUniformAll %5 %23 %32
|
||||||
|
%34 = OpIEqual %5 %16 %20
|
||||||
|
%35 = OpGroupNonUniformAny %5 %23 %34
|
||||||
|
%36 = OpGroupNonUniformIAdd %3 %23 Reduce %16
|
||||||
|
%37 = OpGroupNonUniformIMul %3 %23 Reduce %16
|
||||||
|
%38 = OpGroupNonUniformUMin %3 %23 Reduce %16
|
||||||
|
%39 = OpGroupNonUniformUMax %3 %23 Reduce %16
|
||||||
|
%40 = OpGroupNonUniformBitwiseAnd %3 %23 Reduce %16
|
||||||
|
%41 = OpGroupNonUniformBitwiseOr %3 %23 Reduce %16
|
||||||
|
%42 = OpGroupNonUniformBitwiseXor %3 %23 Reduce %16
|
||||||
|
%43 = OpGroupNonUniformIAdd %3 %23 ExclusiveScan %16
|
||||||
|
%44 = OpGroupNonUniformIMul %3 %23 ExclusiveScan %16
|
||||||
|
%45 = OpGroupNonUniformIAdd %3 %23 InclusiveScan %16
|
||||||
|
%46 = OpGroupNonUniformIMul %3 %23 InclusiveScan %16
|
||||||
|
%47 = OpGroupNonUniformBroadcastFirst %3 %23 %16
|
||||||
|
%48 = OpGroupNonUniformShuffle %3 %23 %16 %21
|
||||||
|
%49 = OpCompositeExtract %3 %7 1
|
||||||
|
%50 = OpISub %3 %49 %19
|
||||||
|
%51 = OpISub %3 %50 %16
|
||||||
|
%52 = OpGroupNonUniformShuffle %3 %23 %16 %51
|
||||||
|
%53 = OpGroupNonUniformShuffleDown %3 %23 %16 %19
|
||||||
|
%54 = OpGroupNonUniformShuffleUp %3 %23 %16 %19
|
||||||
|
%55 = OpCompositeExtract %3 %7 1
|
||||||
|
%56 = OpISub %3 %55 %19
|
||||||
|
%57 = OpGroupNonUniformShuffleXor %3 %23 %16 %56
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
40
naga/tests/out/wgsl/subgroup-operations-s.wgsl
Normal file
40
naga/tests/out/wgsl/subgroup-operations-s.wgsl
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
var<private> num_subgroups_1: u32;
|
||||||
|
var<private> subgroup_id_1: u32;
|
||||||
|
var<private> subgroup_size_1: u32;
|
||||||
|
var<private> subgroup_invocation_id_1: u32;
|
||||||
|
|
||||||
|
fn main_1() {
|
||||||
|
let _e5 = subgroup_size_1;
|
||||||
|
let _e6 = subgroup_invocation_id_1;
|
||||||
|
let _e9 = subgroupBallot(((_e6 & 1u) == 1u));
|
||||||
|
let _e10 = subgroupBallot();
|
||||||
|
let _e12 = subgroupAll((_e6 != 0u));
|
||||||
|
let _e14 = subgroupAny((_e6 == 0u));
|
||||||
|
let _e15 = subgroupAdd(_e6);
|
||||||
|
let _e16 = subgroupMul(_e6);
|
||||||
|
let _e17 = subgroupMin(_e6);
|
||||||
|
let _e18 = subgroupMax(_e6);
|
||||||
|
let _e19 = subgroupAnd(_e6);
|
||||||
|
let _e20 = subgroupOr(_e6);
|
||||||
|
let _e21 = subgroupXor(_e6);
|
||||||
|
let _e22 = subgroupExclusiveAdd(_e6);
|
||||||
|
let _e23 = subgroupExclusiveMul(_e6);
|
||||||
|
let _e24 = subgroupInclusiveAdd(_e6);
|
||||||
|
let _e25 = subgroupInclusiveMul(_e6);
|
||||||
|
let _e26 = subgroupBroadcastFirst(_e6);
|
||||||
|
let _e27 = subgroupBroadcast(_e6, 4u);
|
||||||
|
let _e30 = subgroupShuffle(_e6, ((_e5 - 1u) - _e6));
|
||||||
|
let _e31 = subgroupShuffleDown(_e6, 1u);
|
||||||
|
let _e32 = subgroupShuffleUp(_e6, 1u);
|
||||||
|
let _e34 = subgroupShuffleXor(_e6, (_e5 - 1u));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
@compute @workgroup_size(1, 1, 1)
|
||||||
|
fn main(@builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) {
|
||||||
|
num_subgroups_1 = num_subgroups;
|
||||||
|
subgroup_id_1 = subgroup_id;
|
||||||
|
subgroup_size_1 = subgroup_size;
|
||||||
|
subgroup_invocation_id_1 = subgroup_invocation_id;
|
||||||
|
main_1();
|
||||||
|
}
|
||||||
31
naga/tests/out/wgsl/subgroup-operations.wgsl
Normal file
31
naga/tests/out/wgsl/subgroup-operations.wgsl
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
struct Structure {
|
||||||
|
@builtin(num_subgroups) num_subgroups: u32,
|
||||||
|
@builtin(subgroup_size) subgroup_size: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
@compute @workgroup_size(1, 1, 1)
|
||||||
|
fn main(sizes: Structure, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) {
|
||||||
|
subgroupBarrier();
|
||||||
|
let _e7 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u));
|
||||||
|
let _e8 = subgroupBallot();
|
||||||
|
let _e11 = subgroupAll((subgroup_invocation_id != 0u));
|
||||||
|
let _e14 = subgroupAny((subgroup_invocation_id == 0u));
|
||||||
|
let _e15 = subgroupAdd(subgroup_invocation_id);
|
||||||
|
let _e16 = subgroupMul(subgroup_invocation_id);
|
||||||
|
let _e17 = subgroupMin(subgroup_invocation_id);
|
||||||
|
let _e18 = subgroupMax(subgroup_invocation_id);
|
||||||
|
let _e19 = subgroupAnd(subgroup_invocation_id);
|
||||||
|
let _e20 = subgroupOr(subgroup_invocation_id);
|
||||||
|
let _e21 = subgroupXor(subgroup_invocation_id);
|
||||||
|
let _e22 = subgroupExclusiveAdd(subgroup_invocation_id);
|
||||||
|
let _e23 = subgroupExclusiveMul(subgroup_invocation_id);
|
||||||
|
let _e24 = subgroupInclusiveAdd(subgroup_invocation_id);
|
||||||
|
let _e25 = subgroupInclusiveMul(subgroup_invocation_id);
|
||||||
|
let _e26 = subgroupBroadcastFirst(subgroup_invocation_id);
|
||||||
|
let _e28 = subgroupBroadcast(subgroup_invocation_id, 4u);
|
||||||
|
let _e33 = subgroupShuffle(subgroup_invocation_id, ((sizes.subgroup_size - 1u) - subgroup_invocation_id));
|
||||||
|
let _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u);
|
||||||
|
let _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u);
|
||||||
|
let _e41 = subgroupShuffleXor(subgroup_invocation_id, (sizes.subgroup_size - 1u));
|
||||||
|
return;
|
||||||
|
}
|
||||||
@ -269,10 +269,18 @@ fn check_targets(
|
|||||||
let params = input.read_parameters();
|
let params = input.read_parameters();
|
||||||
let name = &input.file_name;
|
let name = &input.file_name;
|
||||||
|
|
||||||
let capabilities = if params.god_mode {
|
let (capabilities, subgroup_stages, subgroup_operations) = if params.god_mode {
|
||||||
naga::valid::Capabilities::all()
|
(
|
||||||
|
naga::valid::Capabilities::all(),
|
||||||
|
naga::valid::ShaderStages::all(),
|
||||||
|
naga::valid::SubgroupOperationSet::all(),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
naga::valid::Capabilities::default()
|
(
|
||||||
|
naga::valid::Capabilities::default(),
|
||||||
|
naga::valid::ShaderStages::empty(),
|
||||||
|
naga::valid::SubgroupOperationSet::empty(),
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(feature = "serialize")]
|
#[cfg(feature = "serialize")]
|
||||||
@ -285,6 +293,8 @@ fn check_targets(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
|
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
|
||||||
|
.subgroup_stages(subgroup_stages)
|
||||||
|
.subgroup_operations(subgroup_operations)
|
||||||
.validate(module)
|
.validate(module)
|
||||||
.unwrap_or_else(|err| {
|
.unwrap_or_else(|err| {
|
||||||
panic!(
|
panic!(
|
||||||
@ -308,6 +318,8 @@ fn check_targets(
|
|||||||
}
|
}
|
||||||
|
|
||||||
naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
|
naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
|
||||||
|
.subgroup_stages(subgroup_stages)
|
||||||
|
.subgroup_operations(subgroup_operations)
|
||||||
.validate(module)
|
.validate(module)
|
||||||
.unwrap_or_else(|err| {
|
.unwrap_or_else(|err| {
|
||||||
panic!(
|
panic!(
|
||||||
@ -850,6 +862,10 @@ fn convert_wgsl() {
|
|||||||
"int64",
|
"int64",
|
||||||
Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL,
|
Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
"subgroup-operations",
|
||||||
|
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||||
|
),
|
||||||
(
|
(
|
||||||
"overrides",
|
"overrides",
|
||||||
Targets::IR
|
Targets::IR
|
||||||
@ -957,6 +973,11 @@ fn convert_spv_all() {
|
|||||||
);
|
);
|
||||||
convert_spv("builtin-accessed-outside-entrypoint", true, Targets::WGSL);
|
convert_spv("builtin-accessed-outside-entrypoint", true, Targets::WGSL);
|
||||||
convert_spv("spec-constants", true, Targets::IR);
|
convert_spv("spec-constants", true, Targets::IR);
|
||||||
|
convert_spv(
|
||||||
|
"subgroup-operations-s",
|
||||||
|
false,
|
||||||
|
Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "glsl-in")]
|
#[cfg(feature = "glsl-in")]
|
||||||
|
|||||||
@ -33,6 +33,7 @@ mod scissor_tests;
|
|||||||
mod shader;
|
mod shader;
|
||||||
mod shader_primitive_index;
|
mod shader_primitive_index;
|
||||||
mod shader_view_format;
|
mod shader_view_format;
|
||||||
|
mod subgroup_operations;
|
||||||
mod texture_bounds;
|
mod texture_bounds;
|
||||||
mod texture_view_creation;
|
mod texture_view_creation;
|
||||||
mod transfer;
|
mod transfer;
|
||||||
|
|||||||
126
tests/tests/subgroup_operations/mod.rs
Normal file
126
tests/tests/subgroup_operations/mod.rs
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
use std::{borrow::Cow, collections::HashMap, num::NonZeroU64};
|
||||||
|
|
||||||
|
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters};
|
||||||
|
|
||||||
|
const THREAD_COUNT: u64 = 128;
|
||||||
|
const TEST_COUNT: u32 = 32;
|
||||||
|
|
||||||
|
#[gpu_test]
|
||||||
|
static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new()
|
||||||
|
.parameters(
|
||||||
|
TestParameters::default()
|
||||||
|
.features(wgpu::Features::SUBGROUP)
|
||||||
|
.limits(wgpu::Limits::downlevel_defaults())
|
||||||
|
.expect_fail(wgpu_test::FailureCase::molten_vk())
|
||||||
|
.expect_fail(
|
||||||
|
// Expect metal to fail on tests involving operations in divergent control flow
|
||||||
|
wgpu_test::FailureCase::backend(wgpu::Backends::METAL)
|
||||||
|
.panic("thread 0 failed tests: 27, 29,\nthread 1 failed tests: 27, 28, 29,\n"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.run_sync(|ctx| {
|
||||||
|
let device = &ctx.device;
|
||||||
|
|
||||||
|
let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
|
||||||
|
label: None,
|
||||||
|
size: THREAD_COUNT * std::mem::size_of::<u32>() as u64,
|
||||||
|
usage: wgpu::BufferUsages::STORAGE
|
||||||
|
| wgpu::BufferUsages::COPY_DST
|
||||||
|
| wgpu::BufferUsages::COPY_SRC,
|
||||||
|
mapped_at_creation: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||||
|
label: Some("bind group layout"),
|
||||||
|
entries: &[wgpu::BindGroupLayoutEntry {
|
||||||
|
binding: 0,
|
||||||
|
visibility: wgpu::ShaderStages::COMPUTE,
|
||||||
|
ty: wgpu::BindingType::Buffer {
|
||||||
|
ty: wgpu::BufferBindingType::Storage { read_only: false },
|
||||||
|
has_dynamic_offset: false,
|
||||||
|
min_binding_size: NonZeroU64::new(
|
||||||
|
THREAD_COUNT * std::mem::size_of::<u32>() as u64,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
count: None,
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
|
||||||
|
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||||
|
label: None,
|
||||||
|
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
|
||||||
|
});
|
||||||
|
|
||||||
|
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||||
|
label: Some("main"),
|
||||||
|
bind_group_layouts: &[&bind_group_layout],
|
||||||
|
push_constant_ranges: &[],
|
||||||
|
});
|
||||||
|
|
||||||
|
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||||
|
label: None,
|
||||||
|
layout: Some(&pipeline_layout),
|
||||||
|
module: &cs_module,
|
||||||
|
entry_point: "main",
|
||||||
|
constants: &HashMap::default(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||||
|
entries: &[wgpu::BindGroupEntry {
|
||||||
|
binding: 0,
|
||||||
|
resource: storage_buffer.as_entire_binding(),
|
||||||
|
}],
|
||||||
|
layout: &bind_group_layout,
|
||||||
|
label: Some("bind group"),
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut encoder =
|
||||||
|
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
|
||||||
|
{
|
||||||
|
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||||
|
label: None,
|
||||||
|
timestamp_writes: None,
|
||||||
|
});
|
||||||
|
cpass.set_pipeline(&compute_pipeline);
|
||||||
|
cpass.set_bind_group(0, &bind_group, &[]);
|
||||||
|
cpass.dispatch_workgroups(1, 1, 1);
|
||||||
|
}
|
||||||
|
ctx.queue.submit(Some(encoder.finish()));
|
||||||
|
|
||||||
|
wgpu::util::DownloadBuffer::read_buffer(
|
||||||
|
device,
|
||||||
|
&ctx.queue,
|
||||||
|
&storage_buffer.slice(..),
|
||||||
|
|mapping_buffer_view| {
|
||||||
|
let mapping_buffer_view = mapping_buffer_view.unwrap();
|
||||||
|
let result: &[u32; THREAD_COUNT as usize] =
|
||||||
|
bytemuck::from_bytes(&mapping_buffer_view);
|
||||||
|
let expected_mask = (1u64 << (TEST_COUNT)) - 1; // generate full mask
|
||||||
|
let expected_array = [expected_mask as u32; THREAD_COUNT as usize];
|
||||||
|
if result != &expected_array {
|
||||||
|
use std::fmt::Write;
|
||||||
|
let mut msg = String::new();
|
||||||
|
writeln!(
|
||||||
|
&mut msg,
|
||||||
|
"Got from GPU:\n{:x?}\n expected:\n{:x?}",
|
||||||
|
result, &expected_array,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
for (thread, (result, expected)) in result
|
||||||
|
.iter()
|
||||||
|
.zip(expected_array)
|
||||||
|
.enumerate()
|
||||||
|
.filter(|(_, (r, e))| *r != e)
|
||||||
|
{
|
||||||
|
write!(&mut msg, "thread {} failed tests:", thread).unwrap();
|
||||||
|
let difference = result ^ expected;
|
||||||
|
for i in (0..u32::BITS).filter(|i| (difference & (1 << i)) != 0) {
|
||||||
|
write!(&mut msg, " {},", i).unwrap();
|
||||||
|
}
|
||||||
|
writeln!(&mut msg).unwrap();
|
||||||
|
}
|
||||||
|
panic!("{}", msg);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
});
|
||||||
158
tests/tests/subgroup_operations/shader.wgsl
Normal file
158
tests/tests/subgroup_operations/shader.wgsl
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
@group(0)
|
||||||
|
@binding(0)
|
||||||
|
var<storage, read_write> storage_buffer: array<u32>;
|
||||||
|
|
||||||
|
var<workgroup> workgroup_buffer: u32;
|
||||||
|
|
||||||
|
fn add_result_to_mask(mask: ptr<function, u32>, index: u32, value: bool) {
|
||||||
|
(*mask) |= u32(value) << index;
|
||||||
|
}
|
||||||
|
|
||||||
|
@compute
|
||||||
|
@workgroup_size(128)
|
||||||
|
fn main(
|
||||||
|
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||||
|
@builtin(num_subgroups) num_subgroups: u32,
|
||||||
|
@builtin(subgroup_id) subgroup_id: u32,
|
||||||
|
@builtin(subgroup_size) subgroup_size: u32,
|
||||||
|
@builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
|
||||||
|
) {
|
||||||
|
var passed = 0u;
|
||||||
|
var expected: u32;
|
||||||
|
|
||||||
|
add_result_to_mask(&passed, 0u, num_subgroups == 128u / subgroup_size);
|
||||||
|
add_result_to_mask(&passed, 1u, subgroup_id == global_id.x / subgroup_size);
|
||||||
|
add_result_to_mask(&passed, 2u, subgroup_invocation_id == global_id.x % subgroup_size);
|
||||||
|
|
||||||
|
var expected_ballot = vec4<u32>(0u);
|
||||||
|
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||||
|
expected_ballot[i / 32u] |= ((global_id.x - subgroup_invocation_id + i) & 1u) << (i % 32u);
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 3u, dot(vec4<u32>(1u), vec4<u32>(subgroupBallot((subgroup_invocation_id & 1u) == 1u) == expected_ballot)) == 4u);
|
||||||
|
|
||||||
|
add_result_to_mask(&passed, 4u, subgroupAll(true));
|
||||||
|
add_result_to_mask(&passed, 5u, !subgroupAll(subgroup_invocation_id != 0u));
|
||||||
|
|
||||||
|
add_result_to_mask(&passed, 6u, subgroupAny(subgroup_invocation_id == 0u));
|
||||||
|
add_result_to_mask(&passed, 7u, !subgroupAny(false));
|
||||||
|
|
||||||
|
expected = 0u;
|
||||||
|
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||||
|
expected += global_id.x - subgroup_invocation_id + i + 1u;
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 8u, subgroupAdd(global_id.x + 1u) == expected);
|
||||||
|
|
||||||
|
expected = 1u;
|
||||||
|
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||||
|
expected *= global_id.x - subgroup_invocation_id + i + 1u;
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 9u, subgroupMul(global_id.x + 1u) == expected);
|
||||||
|
|
||||||
|
expected = 0u;
|
||||||
|
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||||
|
expected = max(expected, global_id.x - subgroup_invocation_id + i + 1u);
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 10u, subgroupMax(global_id.x + 1u) == expected);
|
||||||
|
|
||||||
|
expected = 0xFFFFFFFFu;
|
||||||
|
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||||
|
expected = min(expected, global_id.x - subgroup_invocation_id + i + 1u);
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 11u, subgroupMin(global_id.x + 1u) == expected);
|
||||||
|
|
||||||
|
expected = 0xFFFFFFFFu;
|
||||||
|
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||||
|
expected &= global_id.x - subgroup_invocation_id + i + 1u;
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 12u, subgroupAnd(global_id.x + 1u) == expected);
|
||||||
|
|
||||||
|
expected = 0u;
|
||||||
|
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||||
|
expected |= global_id.x - subgroup_invocation_id + i + 1u;
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 13u, subgroupOr(global_id.x + 1u) == expected);
|
||||||
|
|
||||||
|
expected = 0u;
|
||||||
|
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||||
|
expected ^= global_id.x - subgroup_invocation_id + i + 1u;
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 14u, subgroupXor(global_id.x + 1u) == expected);
|
||||||
|
|
||||||
|
expected = 0u;
|
||||||
|
for(var i = 0u; i < subgroup_invocation_id; i += 1u) {
|
||||||
|
expected += global_id.x - subgroup_invocation_id + i + 1u;
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 15u, subgroupExclusiveAdd(global_id.x + 1u) == expected);
|
||||||
|
|
||||||
|
expected = 1u;
|
||||||
|
for(var i = 0u; i < subgroup_invocation_id; i += 1u) {
|
||||||
|
expected *= global_id.x - subgroup_invocation_id + i + 1u;
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 16u, subgroupExclusiveMul(global_id.x + 1u) == expected);
|
||||||
|
|
||||||
|
expected = 0u;
|
||||||
|
for(var i = 0u; i <= subgroup_invocation_id; i += 1u) {
|
||||||
|
expected += global_id.x - subgroup_invocation_id + i + 1u;
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 17u, subgroupInclusiveAdd(global_id.x + 1u) == expected);
|
||||||
|
|
||||||
|
expected = 1u;
|
||||||
|
for(var i = 0u; i <= subgroup_invocation_id; i += 1u) {
|
||||||
|
expected *= global_id.x - subgroup_invocation_id + i + 1u;
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 18u, subgroupInclusiveMul(global_id.x + 1u) == expected);
|
||||||
|
|
||||||
|
add_result_to_mask(&passed, 19u, subgroupBroadcastFirst(u32(subgroup_invocation_id != 0u)) == 0u);
|
||||||
|
add_result_to_mask(&passed, 20u, subgroupBroadcastFirst(u32(subgroup_invocation_id == 0u)) == 1u);
|
||||||
|
add_result_to_mask(&passed, 21u, subgroupBroadcast(subgroup_invocation_id, 1u) == 1u);
|
||||||
|
add_result_to_mask(&passed, 22u, subgroupShuffle(subgroup_invocation_id, subgroup_invocation_id) == subgroup_invocation_id);
|
||||||
|
add_result_to_mask(&passed, 23u, subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id) == subgroup_size - 1u - subgroup_invocation_id);
|
||||||
|
add_result_to_mask(&passed, 24u, subgroup_invocation_id == subgroup_size - 1u || subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u);
|
||||||
|
add_result_to_mask(&passed, 25u, subgroup_invocation_id == 0u || subgroupShuffleUp(subgroup_invocation_id, 1u) == subgroup_invocation_id - 1u);
|
||||||
|
add_result_to_mask(&passed, 26u, subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u) == (subgroup_invocation_id ^ (subgroup_size - 1u)));
|
||||||
|
|
||||||
|
var passed_27 = false;
|
||||||
|
if subgroup_invocation_id % 2u == 0u {
|
||||||
|
passed_27 |= subgroupAdd(1u) == (subgroup_size / 2u);
|
||||||
|
} else {
|
||||||
|
passed_27 |= subgroupAdd(1u) == (subgroup_size / 2u);
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 27u, passed_27);
|
||||||
|
|
||||||
|
var passed_28 = false;
|
||||||
|
switch subgroup_invocation_id % 3u {
|
||||||
|
case 0u: {
|
||||||
|
passed_28 = subgroupBroadcastFirst(subgroup_invocation_id) == 0u;
|
||||||
|
}
|
||||||
|
case 1u: {
|
||||||
|
passed_28 = subgroupBroadcastFirst(subgroup_invocation_id) == 1u;
|
||||||
|
}
|
||||||
|
case 2u: {
|
||||||
|
passed_28 = subgroupBroadcastFirst(subgroup_invocation_id) == 2u;
|
||||||
|
}
|
||||||
|
default { }
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 28u, passed_28);
|
||||||
|
|
||||||
|
expected = 0u;
|
||||||
|
for (var i = subgroup_size; i >= 0u; i -= 1u) {
|
||||||
|
expected = subgroupAdd(1u);
|
||||||
|
if i == subgroup_invocation_id {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
add_result_to_mask(&passed, 29u, expected == (subgroup_invocation_id + 1u));
|
||||||
|
|
||||||
|
if global_id.x == 0u {
|
||||||
|
workgroup_buffer = subgroup_size;
|
||||||
|
}
|
||||||
|
workgroupBarrier();
|
||||||
|
add_result_to_mask(&passed, 30u, workgroup_buffer == subgroup_size);
|
||||||
|
|
||||||
|
// Keep this test last, verify we are still convergent after running other tests
|
||||||
|
add_result_to_mask(&passed, 31u, subgroupAdd(1u) == subgroup_size);
|
||||||
|
|
||||||
|
// Increment TEST_COUNT in subgroup_operations/mod.rs if adding more tests
|
||||||
|
|
||||||
|
storage_buffer[global_id.x] = passed;
|
||||||
|
}
|
||||||
@ -1537,6 +1537,15 @@ impl<A: HalApi> Device<A> {
|
|||||||
.flags
|
.flags
|
||||||
.contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES),
|
.contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES),
|
||||||
);
|
);
|
||||||
|
caps.set(
|
||||||
|
Caps::SUBGROUP,
|
||||||
|
self.features
|
||||||
|
.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX),
|
||||||
|
);
|
||||||
|
caps.set(
|
||||||
|
Caps::SUBGROUP_BARRIER,
|
||||||
|
self.features.intersects(wgt::Features::SUBGROUP_BARRIER),
|
||||||
|
);
|
||||||
|
|
||||||
let debug_source =
|
let debug_source =
|
||||||
if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) && !source.is_empty() {
|
if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) && !source.is_empty() {
|
||||||
@ -1552,7 +1561,26 @@ impl<A: HalApi> Device<A> {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut subgroup_stages = naga::valid::ShaderStages::empty();
|
||||||
|
subgroup_stages.set(
|
||||||
|
naga::valid::ShaderStages::COMPUTE | naga::valid::ShaderStages::FRAGMENT,
|
||||||
|
self.features.contains(wgt::Features::SUBGROUP),
|
||||||
|
);
|
||||||
|
subgroup_stages.set(
|
||||||
|
naga::valid::ShaderStages::VERTEX,
|
||||||
|
self.features.contains(wgt::Features::SUBGROUP_VERTEX),
|
||||||
|
);
|
||||||
|
|
||||||
|
let subgroup_operations = if caps.contains(Caps::SUBGROUP) {
|
||||||
|
use naga::valid::SubgroupOperationSet as S;
|
||||||
|
S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE
|
||||||
|
} else {
|
||||||
|
naga::valid::SubgroupOperationSet::empty()
|
||||||
|
};
|
||||||
|
|
||||||
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), caps)
|
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), caps)
|
||||||
|
.subgroup_stages(subgroup_stages)
|
||||||
|
.subgroup_operations(subgroup_operations)
|
||||||
.validate(&module)
|
.validate(&module)
|
||||||
.map_err(|inner| {
|
.map_err(|inner| {
|
||||||
pipeline::CreateShaderModuleError::Validation(pipeline::ShaderError {
|
pipeline::CreateShaderModuleError::Validation(pipeline::ShaderError {
|
||||||
|
|||||||
@ -127,6 +127,11 @@ impl super::Adapter {
|
|||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// If we don't have dxc, we reduce the max to 5.1
|
||||||
|
if dxc_container.is_none() {
|
||||||
|
shader_model_support.HighestShaderModel = d3d12_ty::D3D_SHADER_MODEL_5_1;
|
||||||
|
}
|
||||||
|
|
||||||
let mut workarounds = super::Workarounds::default();
|
let mut workarounds = super::Workarounds::default();
|
||||||
|
|
||||||
let info = wgt::AdapterInfo {
|
let info = wgt::AdapterInfo {
|
||||||
@ -343,11 +348,7 @@ impl super::Adapter {
|
|||||||
bgra8unorm_storage_supported,
|
bgra8unorm_storage_supported,
|
||||||
);
|
);
|
||||||
|
|
||||||
// we must be using DXC because uint64_t was added with Shader Model 6
|
let mut features1: d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS1 = unsafe { mem::zeroed() };
|
||||||
// and FXC only supports up to 5.1
|
|
||||||
let int64_shader_ops_supported = dxc_container.is_some() && {
|
|
||||||
let mut features1: d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS1 =
|
|
||||||
unsafe { mem::zeroed() };
|
|
||||||
let hr = unsafe {
|
let hr = unsafe {
|
||||||
device.CheckFeatureSupport(
|
device.CheckFeatureSupport(
|
||||||
d3d12_ty::D3D12_FEATURE_D3D12_OPTIONS1,
|
d3d12_ty::D3D12_FEATURE_D3D12_OPTIONS1,
|
||||||
@ -355,9 +356,20 @@ impl super::Adapter {
|
|||||||
mem::size_of::<d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS1>() as _,
|
mem::size_of::<d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS1>() as _,
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
hr == 0 && features1.Int64ShaderOps != 0
|
|
||||||
};
|
// we must be using DXC because uint64_t was added with Shader Model 6
|
||||||
features.set(wgt::Features::SHADER_INT64, int64_shader_ops_supported);
|
// and FXC only supports up to 5.1
|
||||||
|
features.set(
|
||||||
|
wgt::Features::SHADER_INT64,
|
||||||
|
dxc_container.is_some() && hr == 0 && features1.Int64ShaderOps != 0,
|
||||||
|
);
|
||||||
|
|
||||||
|
features.set(
|
||||||
|
wgt::Features::SUBGROUP,
|
||||||
|
shader_model_support.HighestShaderModel >= d3d12_ty::D3D_SHADER_MODEL_6_0
|
||||||
|
&& hr == 0
|
||||||
|
&& features1.WaveOps != 0,
|
||||||
|
);
|
||||||
|
|
||||||
// float32-filterable should always be available on d3d12
|
// float32-filterable should always be available on d3d12
|
||||||
features.set(wgt::Features::FLOAT32_FILTERABLE, true);
|
features.set(wgt::Features::FLOAT32_FILTERABLE, true);
|
||||||
@ -425,6 +437,8 @@ impl super::Adapter {
|
|||||||
.min(crate::MAX_VERTEX_BUFFERS as u32),
|
.min(crate::MAX_VERTEX_BUFFERS as u32),
|
||||||
max_vertex_attributes: d3d12_ty::D3D12_IA_VERTEX_INPUT_RESOURCE_SLOT_COUNT,
|
max_vertex_attributes: d3d12_ty::D3D12_IA_VERTEX_INPUT_RESOURCE_SLOT_COUNT,
|
||||||
max_vertex_buffer_array_stride: d3d12_ty::D3D12_SO_BUFFER_MAX_STRIDE_IN_BYTES,
|
max_vertex_buffer_array_stride: d3d12_ty::D3D12_SO_BUFFER_MAX_STRIDE_IN_BYTES,
|
||||||
|
min_subgroup_size: 4, // Not using `features1.WaveLaneCountMin` as it is unreliable
|
||||||
|
max_subgroup_size: 128,
|
||||||
// The push constants are part of the root signature which
|
// The push constants are part of the root signature which
|
||||||
// has a limit of 64 DWORDS (256 bytes), but other resources
|
// has a limit of 64 DWORDS (256 bytes), but other resources
|
||||||
// also share the root signature:
|
// also share the root signature:
|
||||||
|
|||||||
@ -748,6 +748,8 @@ impl super::Adapter {
|
|||||||
} else {
|
} else {
|
||||||
!0
|
!0
|
||||||
},
|
},
|
||||||
|
min_subgroup_size: 0,
|
||||||
|
max_subgroup_size: 0,
|
||||||
max_push_constant_size: super::MAX_PUSH_CONSTANTS as u32 * 4,
|
max_push_constant_size: super::MAX_PUSH_CONSTANTS as u32 * 4,
|
||||||
min_uniform_buffer_offset_alignment,
|
min_uniform_buffer_offset_alignment,
|
||||||
min_storage_buffer_offset_alignment,
|
min_storage_buffer_offset_alignment,
|
||||||
|
|||||||
@ -813,6 +813,10 @@ impl super::PrivateCapabilities {
|
|||||||
None
|
None
|
||||||
},
|
},
|
||||||
timestamp_query_support,
|
timestamp_query_support,
|
||||||
|
supports_simd_scoped_operations: family_check
|
||||||
|
&& (device.supports_family(MTLGPUFamily::Metal3)
|
||||||
|
|| device.supports_family(MTLGPUFamily::Mac2)
|
||||||
|
|| device.supports_family(MTLGPUFamily::Apple7)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -898,6 +902,10 @@ impl super::PrivateCapabilities {
|
|||||||
features.set(F::RG11B10UFLOAT_RENDERABLE, self.format_rg11b10_all);
|
features.set(F::RG11B10UFLOAT_RENDERABLE, self.format_rg11b10_all);
|
||||||
features.set(F::SHADER_UNUSED_VERTEX_OUTPUT, true);
|
features.set(F::SHADER_UNUSED_VERTEX_OUTPUT, true);
|
||||||
|
|
||||||
|
if self.supports_simd_scoped_operations {
|
||||||
|
features.insert(F::SUBGROUP | F::SUBGROUP_BARRIER);
|
||||||
|
}
|
||||||
|
|
||||||
features
|
features
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -952,6 +960,8 @@ impl super::PrivateCapabilities {
|
|||||||
max_vertex_buffers: self.max_vertex_buffers,
|
max_vertex_buffers: self.max_vertex_buffers,
|
||||||
max_vertex_attributes: 31,
|
max_vertex_attributes: 31,
|
||||||
max_vertex_buffer_array_stride: base.max_vertex_buffer_array_stride,
|
max_vertex_buffer_array_stride: base.max_vertex_buffer_array_stride,
|
||||||
|
min_subgroup_size: 4,
|
||||||
|
max_subgroup_size: 64,
|
||||||
max_push_constant_size: 0x1000,
|
max_push_constant_size: 0x1000,
|
||||||
min_uniform_buffer_offset_alignment: self.buffer_alignment as u32,
|
min_uniform_buffer_offset_alignment: self.buffer_alignment as u32,
|
||||||
min_storage_buffer_offset_alignment: self.buffer_alignment as u32,
|
min_storage_buffer_offset_alignment: self.buffer_alignment as u32,
|
||||||
|
|||||||
@ -269,6 +269,7 @@ struct PrivateCapabilities {
|
|||||||
supports_shader_primitive_index: bool,
|
supports_shader_primitive_index: bool,
|
||||||
has_unified_memory: Option<bool>,
|
has_unified_memory: Option<bool>,
|
||||||
timestamp_query_support: TimestampQuerySupport,
|
timestamp_query_support: TimestampQuerySupport,
|
||||||
|
supports_simd_scoped_operations: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
|||||||
@ -101,6 +101,9 @@ pub struct PhysicalDeviceFeatures {
|
|||||||
/// to Vulkan 1.3.
|
/// to Vulkan 1.3.
|
||||||
zero_initialize_workgroup_memory:
|
zero_initialize_workgroup_memory:
|
||||||
Option<vk::PhysicalDeviceZeroInitializeWorkgroupMemoryFeatures>,
|
Option<vk::PhysicalDeviceZeroInitializeWorkgroupMemoryFeatures>,
|
||||||
|
|
||||||
|
/// Features provided by `VK_EXT_subgroup_size_control`, promoted to Vulkan 1.3.
|
||||||
|
subgroup_size_control: Option<vk::PhysicalDeviceSubgroupSizeControlFeatures>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is safe because the structs have `p_next: *mut c_void`, which we null out/never read.
|
// This is safe because the structs have `p_next: *mut c_void`, which we null out/never read.
|
||||||
@ -148,6 +151,9 @@ impl PhysicalDeviceFeatures {
|
|||||||
if let Some(ref mut feature) = self.ray_query {
|
if let Some(ref mut feature) = self.ray_query {
|
||||||
info = info.push_next(feature);
|
info = info.push_next(feature);
|
||||||
}
|
}
|
||||||
|
if let Some(ref mut feature) = self.subgroup_size_control {
|
||||||
|
info = info.push_next(feature);
|
||||||
|
}
|
||||||
info
|
info
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -434,6 +440,17 @@ impl PhysicalDeviceFeatures {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
},
|
},
|
||||||
|
subgroup_size_control: if device_api_version >= vk::API_VERSION_1_3
|
||||||
|
|| enabled_extensions.contains(&vk::ExtSubgroupSizeControlFn::name())
|
||||||
|
{
|
||||||
|
Some(
|
||||||
|
vk::PhysicalDeviceSubgroupSizeControlFeatures::builder()
|
||||||
|
.subgroup_size_control(true)
|
||||||
|
.build(),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -638,6 +655,34 @@ impl PhysicalDeviceFeatures {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(ref subgroup) = caps.subgroup {
|
||||||
|
if (caps.device_api_version >= vk::API_VERSION_1_3
|
||||||
|
|| caps.supports_extension(vk::ExtSubgroupSizeControlFn::name()))
|
||||||
|
&& subgroup.supported_operations.contains(
|
||||||
|
vk::SubgroupFeatureFlags::BASIC
|
||||||
|
| vk::SubgroupFeatureFlags::VOTE
|
||||||
|
| vk::SubgroupFeatureFlags::ARITHMETIC
|
||||||
|
| vk::SubgroupFeatureFlags::BALLOT
|
||||||
|
| vk::SubgroupFeatureFlags::SHUFFLE
|
||||||
|
| vk::SubgroupFeatureFlags::SHUFFLE_RELATIVE,
|
||||||
|
)
|
||||||
|
{
|
||||||
|
features.set(
|
||||||
|
F::SUBGROUP,
|
||||||
|
subgroup
|
||||||
|
.supported_stages
|
||||||
|
.contains(vk::ShaderStageFlags::COMPUTE | vk::ShaderStageFlags::FRAGMENT),
|
||||||
|
);
|
||||||
|
features.set(
|
||||||
|
F::SUBGROUP_VERTEX,
|
||||||
|
subgroup
|
||||||
|
.supported_stages
|
||||||
|
.contains(vk::ShaderStageFlags::VERTEX),
|
||||||
|
);
|
||||||
|
features.insert(F::SUBGROUP_BARRIER);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let supports_depth_format = |format| {
|
let supports_depth_format = |format| {
|
||||||
supports_format(
|
supports_format(
|
||||||
instance,
|
instance,
|
||||||
@ -773,6 +818,13 @@ pub struct PhysicalDeviceProperties {
|
|||||||
/// `VK_KHR_driver_properties` extension, promoted to Vulkan 1.2.
|
/// `VK_KHR_driver_properties` extension, promoted to Vulkan 1.2.
|
||||||
driver: Option<vk::PhysicalDeviceDriverPropertiesKHR>,
|
driver: Option<vk::PhysicalDeviceDriverPropertiesKHR>,
|
||||||
|
|
||||||
|
/// Additional `vk::PhysicalDevice` properties from Vulkan 1.1.
|
||||||
|
subgroup: Option<vk::PhysicalDeviceSubgroupProperties>,
|
||||||
|
|
||||||
|
/// Additional `vk::PhysicalDevice` properties from the
|
||||||
|
/// `VK_EXT_subgroup_size_control` extension, promoted to Vulkan 1.3.
|
||||||
|
subgroup_size_control: Option<vk::PhysicalDeviceSubgroupSizeControlProperties>,
|
||||||
|
|
||||||
/// The device API version.
|
/// The device API version.
|
||||||
///
|
///
|
||||||
/// Which is the version of Vulkan supported for device-level functionality.
|
/// Which is the version of Vulkan supported for device-level functionality.
|
||||||
@ -888,6 +940,11 @@ impl PhysicalDeviceProperties {
|
|||||||
if self.supports_extension(vk::ExtImageRobustnessFn::name()) {
|
if self.supports_extension(vk::ExtImageRobustnessFn::name()) {
|
||||||
extensions.push(vk::ExtImageRobustnessFn::name());
|
extensions.push(vk::ExtImageRobustnessFn::name());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Require `VK_EXT_subgroup_size_control` if the associated feature was requested
|
||||||
|
if requested_features.contains(wgt::Features::SUBGROUP) {
|
||||||
|
extensions.push(vk::ExtSubgroupSizeControlFn::name());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optional `VK_KHR_swapchain_mutable_format`
|
// Optional `VK_KHR_swapchain_mutable_format`
|
||||||
@ -987,6 +1044,14 @@ impl PhysicalDeviceProperties {
|
|||||||
.min(crate::MAX_VERTEX_BUFFERS as u32),
|
.min(crate::MAX_VERTEX_BUFFERS as u32),
|
||||||
max_vertex_attributes: limits.max_vertex_input_attributes,
|
max_vertex_attributes: limits.max_vertex_input_attributes,
|
||||||
max_vertex_buffer_array_stride: limits.max_vertex_input_binding_stride,
|
max_vertex_buffer_array_stride: limits.max_vertex_input_binding_stride,
|
||||||
|
min_subgroup_size: self
|
||||||
|
.subgroup_size_control
|
||||||
|
.map(|subgroup_size| subgroup_size.min_subgroup_size)
|
||||||
|
.unwrap_or(0),
|
||||||
|
max_subgroup_size: self
|
||||||
|
.subgroup_size_control
|
||||||
|
.map(|subgroup_size| subgroup_size.max_subgroup_size)
|
||||||
|
.unwrap_or(0),
|
||||||
max_push_constant_size: limits.max_push_constants_size,
|
max_push_constant_size: limits.max_push_constants_size,
|
||||||
min_uniform_buffer_offset_alignment: limits.min_uniform_buffer_offset_alignment as u32,
|
min_uniform_buffer_offset_alignment: limits.min_uniform_buffer_offset_alignment as u32,
|
||||||
min_storage_buffer_offset_alignment: limits.min_storage_buffer_offset_alignment as u32,
|
min_storage_buffer_offset_alignment: limits.min_storage_buffer_offset_alignment as u32,
|
||||||
@ -1042,6 +1107,9 @@ impl super::InstanceShared {
|
|||||||
let supports_driver_properties = capabilities.device_api_version
|
let supports_driver_properties = capabilities.device_api_version
|
||||||
>= vk::API_VERSION_1_2
|
>= vk::API_VERSION_1_2
|
||||||
|| capabilities.supports_extension(vk::KhrDriverPropertiesFn::name());
|
|| capabilities.supports_extension(vk::KhrDriverPropertiesFn::name());
|
||||||
|
let supports_subgroup_size_control = capabilities.device_api_version
|
||||||
|
>= vk::API_VERSION_1_3
|
||||||
|
|| capabilities.supports_extension(vk::ExtSubgroupSizeControlFn::name());
|
||||||
|
|
||||||
let supports_acceleration_structure =
|
let supports_acceleration_structure =
|
||||||
capabilities.supports_extension(vk::KhrAccelerationStructureFn::name());
|
capabilities.supports_extension(vk::KhrAccelerationStructureFn::name());
|
||||||
@ -1075,6 +1143,20 @@ impl super::InstanceShared {
|
|||||||
builder = builder.push_next(next);
|
builder = builder.push_next(next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if capabilities.device_api_version >= vk::API_VERSION_1_1 {
|
||||||
|
let next = capabilities
|
||||||
|
.subgroup
|
||||||
|
.insert(vk::PhysicalDeviceSubgroupProperties::default());
|
||||||
|
builder = builder.push_next(next);
|
||||||
|
}
|
||||||
|
|
||||||
|
if supports_subgroup_size_control {
|
||||||
|
let next = capabilities
|
||||||
|
.subgroup_size_control
|
||||||
|
.insert(vk::PhysicalDeviceSubgroupSizeControlProperties::default());
|
||||||
|
builder = builder.push_next(next);
|
||||||
|
}
|
||||||
|
|
||||||
let mut properties2 = builder.build();
|
let mut properties2 = builder.build();
|
||||||
unsafe {
|
unsafe {
|
||||||
get_device_properties.get_physical_device_properties2(phd, &mut properties2);
|
get_device_properties.get_physical_device_properties2(phd, &mut properties2);
|
||||||
@ -1190,6 +1272,16 @@ impl super::InstanceShared {
|
|||||||
builder = builder.push_next(next);
|
builder = builder.push_next(next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// `VK_EXT_subgroup_size_control` is promoted to 1.3
|
||||||
|
if capabilities.device_api_version >= vk::API_VERSION_1_3
|
||||||
|
|| capabilities.supports_extension(vk::ExtSubgroupSizeControlFn::name())
|
||||||
|
{
|
||||||
|
let next = features
|
||||||
|
.subgroup_size_control
|
||||||
|
.insert(vk::PhysicalDeviceSubgroupSizeControlFeatures::default());
|
||||||
|
builder = builder.push_next(next);
|
||||||
|
}
|
||||||
|
|
||||||
let mut features2 = builder.build();
|
let mut features2 = builder.build();
|
||||||
unsafe {
|
unsafe {
|
||||||
get_device_properties.get_physical_device_features2(phd, &mut features2);
|
get_device_properties.get_physical_device_features2(phd, &mut features2);
|
||||||
@ -1382,6 +1474,9 @@ impl super::Instance {
|
|||||||
}),
|
}),
|
||||||
image_format_list: phd_capabilities.device_api_version >= vk::API_VERSION_1_2
|
image_format_list: phd_capabilities.device_api_version >= vk::API_VERSION_1_2
|
||||||
|| phd_capabilities.supports_extension(vk::KhrImageFormatListFn::name()),
|
|| phd_capabilities.supports_extension(vk::KhrImageFormatListFn::name()),
|
||||||
|
subgroup_size_control: phd_features
|
||||||
|
.subgroup_size_control
|
||||||
|
.map_or(false, |ext| ext.subgroup_size_control == vk::TRUE),
|
||||||
};
|
};
|
||||||
let capabilities = crate::Capabilities {
|
let capabilities = crate::Capabilities {
|
||||||
limits: phd_capabilities.to_wgpu_limits(),
|
limits: phd_capabilities.to_wgpu_limits(),
|
||||||
@ -1581,6 +1676,15 @@ impl super::Adapter {
|
|||||||
capabilities.push(spv::Capability::Geometry);
|
capabilities.push(spv::Capability::Geometry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if features.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX) {
|
||||||
|
capabilities.push(spv::Capability::GroupNonUniform);
|
||||||
|
capabilities.push(spv::Capability::GroupNonUniformVote);
|
||||||
|
capabilities.push(spv::Capability::GroupNonUniformArithmetic);
|
||||||
|
capabilities.push(spv::Capability::GroupNonUniformBallot);
|
||||||
|
capabilities.push(spv::Capability::GroupNonUniformShuffle);
|
||||||
|
capabilities.push(spv::Capability::GroupNonUniformShuffleRelative);
|
||||||
|
}
|
||||||
|
|
||||||
if features.intersects(
|
if features.intersects(
|
||||||
wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING
|
wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING
|
||||||
| wgt::Features::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
|
| wgt::Features::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
|
||||||
@ -1616,7 +1720,13 @@ impl super::Adapter {
|
|||||||
true, // could check `super::Workarounds::SEPARATE_ENTRY_POINTS`
|
true, // could check `super::Workarounds::SEPARATE_ENTRY_POINTS`
|
||||||
);
|
);
|
||||||
spv::Options {
|
spv::Options {
|
||||||
lang_version: (1, 0),
|
lang_version: if features
|
||||||
|
.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX)
|
||||||
|
{
|
||||||
|
(1, 3)
|
||||||
|
} else {
|
||||||
|
(1, 0)
|
||||||
|
},
|
||||||
flags,
|
flags,
|
||||||
capabilities: Some(capabilities.iter().cloned().collect()),
|
capabilities: Some(capabilities.iter().cloned().collect()),
|
||||||
bounds_check_policies: naga::proc::BoundsCheckPolicies {
|
bounds_check_policies: naga::proc::BoundsCheckPolicies {
|
||||||
|
|||||||
@ -782,8 +782,14 @@ impl super::Device {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut flags = vk::PipelineShaderStageCreateFlags::empty();
|
||||||
|
if self.shared.private_caps.subgroup_size_control {
|
||||||
|
flags |= vk::PipelineShaderStageCreateFlags::ALLOW_VARYING_SUBGROUP_SIZE
|
||||||
|
}
|
||||||
|
|
||||||
let entry_point = CString::new(stage.entry_point).unwrap();
|
let entry_point = CString::new(stage.entry_point).unwrap();
|
||||||
let create_info = vk::PipelineShaderStageCreateInfo::builder()
|
let create_info = vk::PipelineShaderStageCreateInfo::builder()
|
||||||
|
.flags(flags)
|
||||||
.stage(conv::map_shader_stage(stage_flags))
|
.stage(conv::map_shader_stage(stage_flags))
|
||||||
.module(vk_module)
|
.module(vk_module)
|
||||||
.name(&entry_point)
|
.name(&entry_point)
|
||||||
|
|||||||
@ -238,6 +238,7 @@ struct PrivateCapabilities {
|
|||||||
robust_image_access2: bool,
|
robust_image_access2: bool,
|
||||||
zero_initialize_workgroup_memory: bool,
|
zero_initialize_workgroup_memory: bool,
|
||||||
image_format_list: bool,
|
image_format_list: bool,
|
||||||
|
subgroup_size_control: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
bitflags::bitflags!(
|
bitflags::bitflags!(
|
||||||
|
|||||||
@ -143,6 +143,8 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize
|
|||||||
max_vertex_buffers,
|
max_vertex_buffers,
|
||||||
max_vertex_attributes,
|
max_vertex_attributes,
|
||||||
max_vertex_buffer_array_stride,
|
max_vertex_buffer_array_stride,
|
||||||
|
min_subgroup_size,
|
||||||
|
max_subgroup_size,
|
||||||
max_push_constant_size,
|
max_push_constant_size,
|
||||||
min_uniform_buffer_offset_alignment,
|
min_uniform_buffer_offset_alignment,
|
||||||
min_storage_buffer_offset_alignment,
|
min_storage_buffer_offset_alignment,
|
||||||
@ -176,6 +178,8 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize
|
|||||||
writeln!(output, "\t\t Max Vertex Buffers: {max_vertex_buffers}")?;
|
writeln!(output, "\t\t Max Vertex Buffers: {max_vertex_buffers}")?;
|
||||||
writeln!(output, "\t\t Max Vertex Attributes: {max_vertex_attributes}")?;
|
writeln!(output, "\t\t Max Vertex Attributes: {max_vertex_attributes}")?;
|
||||||
writeln!(output, "\t\t Max Vertex Buffer Array Stride: {max_vertex_buffer_array_stride}")?;
|
writeln!(output, "\t\t Max Vertex Buffer Array Stride: {max_vertex_buffer_array_stride}")?;
|
||||||
|
writeln!(output, "\t\t Min Subgroup Size: {min_subgroup_size}")?;
|
||||||
|
writeln!(output, "\t\t Max Subgroup Size: {max_subgroup_size}")?;
|
||||||
writeln!(output, "\t\t Max Push Constant Size: {max_push_constant_size}")?;
|
writeln!(output, "\t\t Max Push Constant Size: {max_push_constant_size}")?;
|
||||||
writeln!(output, "\t\t Min Uniform Buffer Offset Alignment: {min_uniform_buffer_offset_alignment}")?;
|
writeln!(output, "\t\t Min Uniform Buffer Offset Alignment: {min_uniform_buffer_offset_alignment}")?;
|
||||||
writeln!(output, "\t\t Min Storage Buffer Offset Alignment: {min_storage_buffer_offset_alignment}")?;
|
writeln!(output, "\t\t Min Storage Buffer Offset Alignment: {min_storage_buffer_offset_alignment}")?;
|
||||||
|
|||||||
@ -890,6 +890,30 @@ bitflags::bitflags! {
|
|||||||
///
|
///
|
||||||
/// This is a native only feature.
|
/// This is a native only feature.
|
||||||
const SHADER_INT64 = 1 << 55;
|
const SHADER_INT64 = 1 << 55;
|
||||||
|
/// Allows compute and fragment shaders to use the subgroup operation built-ins
|
||||||
|
///
|
||||||
|
/// Supported Platforms:
|
||||||
|
/// - Vulkan
|
||||||
|
/// - DX12
|
||||||
|
/// - Metal
|
||||||
|
///
|
||||||
|
/// This is a native only feature.
|
||||||
|
const SUBGROUP = 1 << 56;
|
||||||
|
/// Allows vertex shaders to use the subgroup operation built-ins
|
||||||
|
///
|
||||||
|
/// Supported Platforms:
|
||||||
|
/// - Vulkan
|
||||||
|
///
|
||||||
|
/// This is a native only feature.
|
||||||
|
const SUBGROUP_VERTEX = 1 << 57;
|
||||||
|
/// Allows shaders to use the subgroup barrier
|
||||||
|
///
|
||||||
|
/// Supported Platforms:
|
||||||
|
/// - Vulkan
|
||||||
|
/// - Metal
|
||||||
|
///
|
||||||
|
/// This is a native only feature.
|
||||||
|
const SUBGROUP_BARRIER = 1 << 58;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1136,6 +1160,11 @@ pub struct Limits {
|
|||||||
/// The maximum value for each dimension of a `ComputePass::dispatch(x, y, z)` operation.
|
/// The maximum value for each dimension of a `ComputePass::dispatch(x, y, z)` operation.
|
||||||
/// Defaults to 65535. Higher is "better".
|
/// Defaults to 65535. Higher is "better".
|
||||||
pub max_compute_workgroups_per_dimension: u32,
|
pub max_compute_workgroups_per_dimension: u32,
|
||||||
|
|
||||||
|
/// Minimal number of invocations in a subgroup. Higher is "better".
|
||||||
|
pub min_subgroup_size: u32,
|
||||||
|
/// Maximal number of invocations in a subgroup. Lower is "better".
|
||||||
|
pub max_subgroup_size: u32,
|
||||||
/// Amount of storage available for push constants in bytes. Defaults to 0. Higher is "better".
|
/// Amount of storage available for push constants in bytes. Defaults to 0. Higher is "better".
|
||||||
/// Requesting more than 0 during device creation requires [`Features::PUSH_CONSTANTS`] to be enabled.
|
/// Requesting more than 0 during device creation requires [`Features::PUSH_CONSTANTS`] to be enabled.
|
||||||
///
|
///
|
||||||
@ -1146,7 +1175,6 @@ pub struct Limits {
|
|||||||
/// - OpenGL doesn't natively support push constants, and are emulated with uniforms,
|
/// - OpenGL doesn't natively support push constants, and are emulated with uniforms,
|
||||||
/// so this number is less useful but likely 256.
|
/// so this number is less useful but likely 256.
|
||||||
pub max_push_constant_size: u32,
|
pub max_push_constant_size: u32,
|
||||||
|
|
||||||
/// Maximum number of live non-sampler bindings.
|
/// Maximum number of live non-sampler bindings.
|
||||||
///
|
///
|
||||||
/// This limit only affects the d3d12 backend. Using a large number will allow the device
|
/// This limit only affects the d3d12 backend. Using a large number will allow the device
|
||||||
@ -1187,6 +1215,8 @@ impl Default for Limits {
|
|||||||
max_compute_workgroup_size_y: 256,
|
max_compute_workgroup_size_y: 256,
|
||||||
max_compute_workgroup_size_z: 64,
|
max_compute_workgroup_size_z: 64,
|
||||||
max_compute_workgroups_per_dimension: 65535,
|
max_compute_workgroups_per_dimension: 65535,
|
||||||
|
min_subgroup_size: 0,
|
||||||
|
max_subgroup_size: 0,
|
||||||
max_push_constant_size: 0,
|
max_push_constant_size: 0,
|
||||||
max_non_sampler_bindings: 1_000_000,
|
max_non_sampler_bindings: 1_000_000,
|
||||||
}
|
}
|
||||||
@ -1218,6 +1248,8 @@ impl Limits {
|
|||||||
/// max_vertex_buffers: 8,
|
/// max_vertex_buffers: 8,
|
||||||
/// max_vertex_attributes: 16,
|
/// max_vertex_attributes: 16,
|
||||||
/// max_vertex_buffer_array_stride: 2048,
|
/// max_vertex_buffer_array_stride: 2048,
|
||||||
|
/// min_subgroup_size: 0,
|
||||||
|
/// max_subgroup_size: 0,
|
||||||
/// max_push_constant_size: 0,
|
/// max_push_constant_size: 0,
|
||||||
/// min_uniform_buffer_offset_alignment: 256,
|
/// min_uniform_buffer_offset_alignment: 256,
|
||||||
/// min_storage_buffer_offset_alignment: 256,
|
/// min_storage_buffer_offset_alignment: 256,
|
||||||
@ -1254,6 +1286,8 @@ impl Limits {
|
|||||||
max_vertex_buffers: 8,
|
max_vertex_buffers: 8,
|
||||||
max_vertex_attributes: 16,
|
max_vertex_attributes: 16,
|
||||||
max_vertex_buffer_array_stride: 2048,
|
max_vertex_buffer_array_stride: 2048,
|
||||||
|
min_subgroup_size: 0,
|
||||||
|
max_subgroup_size: 0,
|
||||||
max_push_constant_size: 0,
|
max_push_constant_size: 0,
|
||||||
min_uniform_buffer_offset_alignment: 256,
|
min_uniform_buffer_offset_alignment: 256,
|
||||||
min_storage_buffer_offset_alignment: 256,
|
min_storage_buffer_offset_alignment: 256,
|
||||||
@ -1296,6 +1330,8 @@ impl Limits {
|
|||||||
/// max_vertex_buffers: 8,
|
/// max_vertex_buffers: 8,
|
||||||
/// max_vertex_attributes: 16,
|
/// max_vertex_attributes: 16,
|
||||||
/// max_vertex_buffer_array_stride: 255, // +
|
/// max_vertex_buffer_array_stride: 255, // +
|
||||||
|
/// min_subgroup_size: 0,
|
||||||
|
/// max_subgroup_size: 0,
|
||||||
/// max_push_constant_size: 0,
|
/// max_push_constant_size: 0,
|
||||||
/// min_uniform_buffer_offset_alignment: 256,
|
/// min_uniform_buffer_offset_alignment: 256,
|
||||||
/// min_storage_buffer_offset_alignment: 256,
|
/// min_storage_buffer_offset_alignment: 256,
|
||||||
@ -1326,6 +1362,8 @@ impl Limits {
|
|||||||
max_compute_workgroup_size_y: 0,
|
max_compute_workgroup_size_y: 0,
|
||||||
max_compute_workgroup_size_z: 0,
|
max_compute_workgroup_size_z: 0,
|
||||||
max_compute_workgroups_per_dimension: 0,
|
max_compute_workgroups_per_dimension: 0,
|
||||||
|
min_subgroup_size: 0,
|
||||||
|
max_subgroup_size: 0,
|
||||||
|
|
||||||
// Value supported by Intel Celeron B830 on Windows (OpenGL 3.1)
|
// Value supported by Intel Celeron B830 on Windows (OpenGL 3.1)
|
||||||
max_inter_stage_shader_components: 31,
|
max_inter_stage_shader_components: 31,
|
||||||
@ -1418,6 +1456,10 @@ impl Limits {
|
|||||||
compare!(max_vertex_buffers, Less);
|
compare!(max_vertex_buffers, Less);
|
||||||
compare!(max_vertex_attributes, Less);
|
compare!(max_vertex_attributes, Less);
|
||||||
compare!(max_vertex_buffer_array_stride, Less);
|
compare!(max_vertex_buffer_array_stride, Less);
|
||||||
|
if self.min_subgroup_size > 0 && self.max_subgroup_size > 0 {
|
||||||
|
compare!(min_subgroup_size, Greater);
|
||||||
|
compare!(max_subgroup_size, Less);
|
||||||
|
}
|
||||||
compare!(max_push_constant_size, Less);
|
compare!(max_push_constant_size, Less);
|
||||||
compare!(min_uniform_buffer_offset_alignment, Greater);
|
compare!(min_uniform_buffer_offset_alignment, Greater);
|
||||||
compare!(min_storage_buffer_offset_alignment, Greater);
|
compare!(min_storage_buffer_offset_alignment, Greater);
|
||||||
|
|||||||
@ -737,6 +737,8 @@ fn map_wgt_limits(limits: webgpu_sys::GpuSupportedLimits) -> wgt::Limits {
|
|||||||
max_compute_workgroup_size_z: limits.max_compute_workgroup_size_z(),
|
max_compute_workgroup_size_z: limits.max_compute_workgroup_size_z(),
|
||||||
max_compute_workgroups_per_dimension: limits.max_compute_workgroups_per_dimension(),
|
max_compute_workgroups_per_dimension: limits.max_compute_workgroups_per_dimension(),
|
||||||
// The following are not part of WebGPU
|
// The following are not part of WebGPU
|
||||||
|
min_subgroup_size: wgt::Limits::default().min_subgroup_size,
|
||||||
|
max_subgroup_size: wgt::Limits::default().max_subgroup_size,
|
||||||
max_push_constant_size: wgt::Limits::default().max_push_constant_size,
|
max_push_constant_size: wgt::Limits::default().max_push_constant_size,
|
||||||
max_non_sampler_bindings: wgt::Limits::default().max_non_sampler_bindings,
|
max_non_sampler_bindings: wgt::Limits::default().max_non_sampler_bindings,
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user