[naga hlsl-out] Factor out some repetitive code

This commit is contained in:
Andy Leiserson 2025-03-25 19:09:43 -07:00
parent bfa7ee8de5
commit 1b4eca97cf

View File

@ -45,6 +45,11 @@ pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
"nagaTextureSampleBaseClampToEdge";
enum Index {
Expression(Handle<crate::Expression>),
Static(u32),
}
struct EpStructMember {
name: String,
ty: Handle<crate::Type>,
@ -1797,6 +1802,23 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Ok(())
}
fn write_index(
&mut self,
module: &Module,
index: Index,
func_ctx: &back::FunctionCtx<'_>,
) -> BackendResult {
match index {
Index::Static(index) => {
write!(self.out, "{index}")?;
}
Index::Expression(index) => {
self.write_expr(module, index, func_ctx)?;
}
}
Ok(())
}
/// Helper method used to write statements
///
/// # Notes
@ -1953,13 +1975,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
//
// We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
// Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
struct MatrixAccess {
enum MatrixAccess {
Direct {
base: Handle<crate::Expression>,
index: u32,
}
enum Index {
Expression(Handle<crate::Expression>),
Static(u32),
},
Struct {
columns: crate::VectorSize,
base: Handle<crate::Expression>,
},
}
let get_members = |expr: Handle<crate::Expression>| {
@ -1973,15 +1997,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
};
let mut matrix = None;
let mut vector = None;
let mut scalar = None;
write!(self.out, "{level}")?;
let mut current_expr = pointer;
for _ in 0..3 {
let resolved = func_ctx.resolve_type(current_expr, &module.types);
match (resolved, &func_ctx.expressions[current_expr]) {
let matrix_access_on_lhs =
find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
|(matrix_expr, vector, scalar)| match (
func_ctx.resolve_type(matrix_expr, &module.types),
&func_ctx.expressions[matrix_expr],
) {
(
&TypeInner::Pointer { base: ty, .. },
&crate::Expression::AccessIndex { base, index },
@ -1995,50 +2018,36 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
.map(|members| members[index as usize].binding.is_none())
== Some(true) =>
{
matrix = Some(MatrixAccess { base, index });
break;
Some((MatrixAccess::Direct { base, index }, vector, scalar))
}
(
&TypeInner::ValuePointer {
size: Some(crate::VectorSize::Bi),
..
_ => {
if let Some(MatrixType {
columns,
rows: crate::VectorSize::Bi,
width: 4,
}) = get_inner_matrix_of_struct_array_member(
module,
matrix_expr,
func_ctx,
true,
) {
Some((
MatrixAccess::Struct {
columns,
base: matrix_expr,
},
&crate::Expression::Access { base, index },
) => {
vector = Some(Index::Expression(index));
current_expr = base;
vector,
scalar,
))
} else {
None
}
}
(
&TypeInner::ValuePointer {
size: Some(crate::VectorSize::Bi),
..
},
&crate::Expression::AccessIndex { base, index },
) => {
vector = Some(Index::Static(index));
current_expr = base;
}
(
&TypeInner::ValuePointer { size: None, .. },
&crate::Expression::Access { base, index },
) => {
scalar = Some(Index::Expression(index));
current_expr = base;
}
(
&TypeInner::ValuePointer { size: None, .. },
&crate::Expression::AccessIndex { base, index },
) => {
scalar = Some(Index::Static(index));
current_expr = base;
}
_ => break,
}
}
);
write!(self.out, "{level}")?;
if let Some(MatrixAccess { index, base }) = matrix {
match matrix_access_on_lhs {
Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
let base_ty_res = &func_ctx.info[base].ty;
let resolved = base_ty_res.inner_with(&module.types);
let ty = match *resolved {
@ -2057,14 +2066,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if let Some(scalar_index) = scalar {
write!(self.out, "[")?;
match scalar_index {
Index::Static(index) => {
write!(self.out, "{index}")?;
}
Index::Expression(index) => {
self.write_expr(module, index, func_ctx)?;
}
}
self.write_index(module, scalar_index, func_ctx)?;
write!(self.out, "]")?;
}
@ -2080,7 +2082,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
)?;
}
(&Some(_), &None) => {
self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
self.write_wrapped_struct_matrix_set_vec_function_name(
access,
)?;
}
(&None, _) => {
self.write_wrapped_struct_matrix_set_function_name(access)?;
@ -2098,88 +2102,20 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if let Some(scalar_index) = scalar {
write!(self.out, ", ")?;
match scalar_index {
Index::Static(index) => {
write!(self.out, "{index}")?;
}
Index::Expression(index) => {
self.write_expr(module, index, func_ctx)?;
}
}
self.write_index(module, scalar_index, func_ctx)?;
}
}
writeln!(self.out, ");")?;
}
} else {
}
Some((
MatrixAccess::Struct { columns, base },
Some(Index::Expression(vec_index)),
scalar,
)) => {
// We handle `Store`s to __matCx2 column vectors and scalar elements via
// the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
struct MatrixData {
columns: crate::VectorSize,
base: Handle<crate::Expression>,
}
enum Index {
Expression(Handle<crate::Expression>),
Static(u32),
}
let mut matrix = None;
let mut vector = None;
let mut scalar = None;
let mut current_expr = pointer;
for _ in 0..3 {
let resolved = func_ctx.resolve_type(current_expr, &module.types);
match (resolved, &func_ctx.expressions[current_expr]) {
(
&TypeInner::ValuePointer {
size: Some(crate::VectorSize::Bi),
..
},
&crate::Expression::Access { base, index },
) => {
vector = Some(index);
current_expr = base;
}
(
&TypeInner::ValuePointer { size: None, .. },
&crate::Expression::Access { base, index },
) => {
scalar = Some(Index::Expression(index));
current_expr = base;
}
(
&TypeInner::ValuePointer { size: None, .. },
&crate::Expression::AccessIndex { base, index },
) => {
scalar = Some(Index::Static(index));
current_expr = base;
}
_ => {
if let Some(MatrixType {
columns,
rows: crate::VectorSize::Bi,
width: 4,
}) = get_inner_matrix_of_struct_array_member(
module,
current_expr,
func_ctx,
true,
) {
matrix = Some(MatrixData {
columns,
base: current_expr,
});
}
break;
}
}
}
if let (Some(MatrixData { columns, base }), Some(vec_index)) =
(matrix, vector)
{
if scalar.is_some() {
write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
} else {
@ -2192,21 +2128,17 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if let Some(scalar_index) = scalar {
write!(self.out, ", ")?;
match scalar_index {
Index::Static(index) => {
write!(self.out, "{index}")?;
}
Index::Expression(index) => {
self.write_expr(module, index, func_ctx)?;
}
}
self.write_index(module, scalar_index, func_ctx)?;
}
write!(self.out, ", ")?;
self.write_expr(module, value, func_ctx)?;
writeln!(self.out, ");")?;
} else {
}
Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
| Some((MatrixAccess::Struct { .. }, None, _))
| None => {
self.write_expr(module, pointer, func_ctx)?;
write!(self.out, " = ")?;
@ -4423,12 +4355,17 @@ pub(super) fn get_inner_matrix_data(
}
}
/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
/// returns a tuple of the matrix, the column (vector) index (if present), and
/// the row (scalar) index (if present).
fn find_matrix_in_access_chain(
module: &Module,
base: Handle<crate::Expression>,
func_ctx: &back::FunctionCtx<'_>,
) -> Option<Handle<crate::Expression>> {
) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
let mut current_base = base;
let mut vector = None;
let mut scalar = None;
loop {
let resolved_tr = func_ctx
.resolve_type(current_base, &module.types)
@ -4436,15 +4373,22 @@ fn find_matrix_in_access_chain(
let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
match *resolved {
TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
TypeInner::Matrix { .. } => return Some(current_base),
_ => return None,
}
current_base = match func_ctx.expressions[current_base] {
crate::Expression::Access { base, .. } => base,
crate::Expression::AccessIndex { base, .. } => base,
let index;
(current_base, index) = match func_ctx.expressions[current_base] {
crate::Expression::Access { base, index } => (base, Index::Expression(index)),
crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
_ => return None,
};
match *resolved {
TypeInner::Scalar(_) => scalar = Some(index),
TypeInner::Vector { .. } => vector = Some(index),
_ => unreachable!(),
}
}
}