[naga hlsl-out] Factor out some repetitive code

This commit is contained in:
Andy Leiserson 2025-03-25 19:09:43 -07:00
parent bfa7ee8de5
commit 1b4eca97cf

View File

@ -45,6 +45,11 @@ pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
"nagaTextureSampleBaseClampToEdge";
enum Index {
Expression(Handle<crate::Expression>),
Static(u32),
}
struct EpStructMember {
name: String,
ty: Handle<crate::Type>,
@ -1797,6 +1802,23 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Ok(())
}
fn write_index(
&mut self,
module: &Module,
index: Index,
func_ctx: &back::FunctionCtx<'_>,
) -> BackendResult {
match index {
Index::Static(index) => {
write!(self.out, "{index}")?;
}
Index::Expression(index) => {
self.write_expr(module, index, func_ctx)?;
}
}
Ok(())
}
/// Helper method used to write statements
///
/// # Notes
@ -1953,13 +1975,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
//
// We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
// Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
struct MatrixAccess {
base: Handle<crate::Expression>,
index: u32,
}
enum Index {
Expression(Handle<crate::Expression>),
Static(u32),
enum MatrixAccess {
Direct {
base: Handle<crate::Expression>,
index: u32,
},
Struct {
columns: crate::VectorSize,
base: Handle<crate::Expression>,
},
}
let get_members = |expr: Handle<crate::Expression>| {
@ -1973,187 +1997,28 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
};
let mut matrix = None;
let mut vector = None;
let mut scalar = None;
let mut current_expr = pointer;
for _ in 0..3 {
let resolved = func_ctx.resolve_type(current_expr, &module.types);
match (resolved, &func_ctx.expressions[current_expr]) {
(
&TypeInner::Pointer { base: ty, .. },
&crate::Expression::AccessIndex { base, index },
) if matches!(
module.types[ty].inner,
TypeInner::Matrix {
rows: crate::VectorSize::Bi,
..
}
) && get_members(base)
.map(|members| members[index as usize].binding.is_none())
== Some(true) =>
{
matrix = Some(MatrixAccess { base, index });
break;
}
(
&TypeInner::ValuePointer {
size: Some(crate::VectorSize::Bi),
..
},
&crate::Expression::Access { base, index },
) => {
vector = Some(Index::Expression(index));
current_expr = base;
}
(
&TypeInner::ValuePointer {
size: Some(crate::VectorSize::Bi),
..
},
&crate::Expression::AccessIndex { base, index },
) => {
vector = Some(Index::Static(index));
current_expr = base;
}
(
&TypeInner::ValuePointer { size: None, .. },
&crate::Expression::Access { base, index },
) => {
scalar = Some(Index::Expression(index));
current_expr = base;
}
(
&TypeInner::ValuePointer { size: None, .. },
&crate::Expression::AccessIndex { base, index },
) => {
scalar = Some(Index::Static(index));
current_expr = base;
}
_ => break,
}
}
write!(self.out, "{level}")?;
if let Some(MatrixAccess { index, base }) = matrix {
let base_ty_res = &func_ctx.info[base].ty;
let resolved = base_ty_res.inner_with(&module.types);
let ty = match *resolved {
TypeInner::Pointer { base, .. } => base,
_ => base_ty_res.handle().unwrap(),
};
if let Some(Index::Static(vec_index)) = vector {
self.write_expr(module, base, func_ctx)?;
write!(
self.out,
".{}_{}",
&self.names[&NameKey::StructMember(ty, index)],
vec_index
)?;
if let Some(scalar_index) = scalar {
write!(self.out, "[")?;
match scalar_index {
Index::Static(index) => {
write!(self.out, "{index}")?;
}
Index::Expression(index) => {
self.write_expr(module, index, func_ctx)?;
}
}
write!(self.out, "]")?;
}
write!(self.out, " = ")?;
self.write_expr(module, value, func_ctx)?;
writeln!(self.out, ";")?;
} else {
let access = WrappedStructMatrixAccess { ty, index };
match (&vector, &scalar) {
(&Some(_), &Some(_)) => {
self.write_wrapped_struct_matrix_set_scalar_function_name(
access,
)?;
}
(&Some(_), &None) => {
self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
}
(&None, _) => {
self.write_wrapped_struct_matrix_set_function_name(access)?;
}
}
write!(self.out, "(")?;
self.write_expr(module, base, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, value, func_ctx)?;
if let Some(Index::Expression(vec_index)) = vector {
write!(self.out, ", ")?;
self.write_expr(module, vec_index, func_ctx)?;
if let Some(scalar_index) = scalar {
write!(self.out, ", ")?;
match scalar_index {
Index::Static(index) => {
write!(self.out, "{index}")?;
}
Index::Expression(index) => {
self.write_expr(module, index, func_ctx)?;
}
}
}
}
writeln!(self.out, ");")?;
}
} else {
// We handle `Store`s to __matCx2 column vectors and scalar elements via
// the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
struct MatrixData {
columns: crate::VectorSize,
base: Handle<crate::Expression>,
}
enum Index {
Expression(Handle<crate::Expression>),
Static(u32),
}
let mut matrix = None;
let mut vector = None;
let mut scalar = None;
let mut current_expr = pointer;
for _ in 0..3 {
let resolved = func_ctx.resolve_type(current_expr, &module.types);
match (resolved, &func_ctx.expressions[current_expr]) {
let matrix_access_on_lhs =
find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
|(matrix_expr, vector, scalar)| match (
func_ctx.resolve_type(matrix_expr, &module.types),
&func_ctx.expressions[matrix_expr],
) {
(
&TypeInner::ValuePointer {
size: Some(crate::VectorSize::Bi),
..
},
&crate::Expression::Access { base, index },
) => {
vector = Some(index);
current_expr = base;
}
(
&TypeInner::ValuePointer { size: None, .. },
&crate::Expression::Access { base, index },
) => {
scalar = Some(Index::Expression(index));
current_expr = base;
}
(
&TypeInner::ValuePointer { size: None, .. },
&TypeInner::Pointer { base: ty, .. },
&crate::Expression::AccessIndex { base, index },
) => {
scalar = Some(Index::Static(index));
current_expr = base;
) if matches!(
module.types[ty].inner,
TypeInner::Matrix {
rows: crate::VectorSize::Bi,
..
}
) && get_members(base)
.map(|members| members[index as usize].binding.is_none())
== Some(true) =>
{
Some((MatrixAccess::Direct { base, index }, vector, scalar))
}
_ => {
if let Some(MatrixType {
@ -2162,24 +2027,95 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
width: 4,
}) = get_inner_matrix_of_struct_array_member(
module,
current_expr,
matrix_expr,
func_ctx,
true,
) {
matrix = Some(MatrixData {
columns,
base: current_expr,
});
Some((
MatrixAccess::Struct {
columns,
base: matrix_expr,
},
vector,
scalar,
))
} else {
None
}
break;
}
},
);
match matrix_access_on_lhs {
Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
let base_ty_res = &func_ctx.info[base].ty;
let resolved = base_ty_res.inner_with(&module.types);
let ty = match *resolved {
TypeInner::Pointer { base, .. } => base,
_ => base_ty_res.handle().unwrap(),
};
if let Some(Index::Static(vec_index)) = vector {
self.write_expr(module, base, func_ctx)?;
write!(
self.out,
".{}_{}",
&self.names[&NameKey::StructMember(ty, index)],
vec_index
)?;
if let Some(scalar_index) = scalar {
write!(self.out, "[")?;
self.write_index(module, scalar_index, func_ctx)?;
write!(self.out, "]")?;
}
write!(self.out, " = ")?;
self.write_expr(module, value, func_ctx)?;
writeln!(self.out, ";")?;
} else {
let access = WrappedStructMatrixAccess { ty, index };
match (&vector, &scalar) {
(&Some(_), &Some(_)) => {
self.write_wrapped_struct_matrix_set_scalar_function_name(
access,
)?;
}
(&Some(_), &None) => {
self.write_wrapped_struct_matrix_set_vec_function_name(
access,
)?;
}
(&None, _) => {
self.write_wrapped_struct_matrix_set_function_name(access)?;
}
}
write!(self.out, "(")?;
self.write_expr(module, base, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, value, func_ctx)?;
if let Some(Index::Expression(vec_index)) = vector {
write!(self.out, ", ")?;
self.write_expr(module, vec_index, func_ctx)?;
if let Some(scalar_index) = scalar {
write!(self.out, ", ")?;
self.write_index(module, scalar_index, func_ctx)?;
}
}
writeln!(self.out, ");")?;
}
}
Some((
MatrixAccess::Struct { columns, base },
Some(Index::Expression(vec_index)),
scalar,
)) => {
// We handle `Store`s to __matCx2 column vectors and scalar elements via
// the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
if let (Some(MatrixData { columns, base }), Some(vec_index)) =
(matrix, vector)
{
if scalar.is_some() {
write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
} else {
@ -2192,21 +2128,17 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if let Some(scalar_index) = scalar {
write!(self.out, ", ")?;
match scalar_index {
Index::Static(index) => {
write!(self.out, "{index}")?;
}
Index::Expression(index) => {
self.write_expr(module, index, func_ctx)?;
}
}
self.write_index(module, scalar_index, func_ctx)?;
}
write!(self.out, ", ")?;
self.write_expr(module, value, func_ctx)?;
writeln!(self.out, ");")?;
} else {
}
Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
| Some((MatrixAccess::Struct { .. }, None, _))
| None => {
self.write_expr(module, pointer, func_ctx)?;
write!(self.out, " = ")?;
@ -4423,12 +4355,17 @@ pub(super) fn get_inner_matrix_data(
}
}
/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
/// returns a tuple of the matrix, the column (vector) index (if present), and
/// the row (scalar) index (if present).
fn find_matrix_in_access_chain(
module: &Module,
base: Handle<crate::Expression>,
func_ctx: &back::FunctionCtx<'_>,
) -> Option<Handle<crate::Expression>> {
) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
let mut current_base = base;
let mut vector = None;
let mut scalar = None;
loop {
let resolved_tr = func_ctx
.resolve_type(current_base, &module.types)
@ -4436,15 +4373,22 @@ fn find_matrix_in_access_chain(
let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
match *resolved {
TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
TypeInner::Matrix { .. } => return Some(current_base),
_ => return None,
}
current_base = match func_ctx.expressions[current_base] {
crate::Expression::Access { base, .. } => base,
crate::Expression::AccessIndex { base, .. } => base,
let index;
(current_base, index) = match func_ctx.expressions[current_base] {
crate::Expression::Access { base, index } => (base, Index::Expression(index)),
crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
_ => return None,
};
match *resolved {
TypeInner::Scalar(_) => scalar = Some(index),
TypeInner::Vector { .. } => vector = Some(index),
_ => unreachable!(),
}
}
}