mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-12-08 21:26:17 +00:00
[naga hlsl-out] Factor out some repetitive code
This commit is contained in:
parent
bfa7ee8de5
commit
1b4eca97cf
@ -45,6 +45,11 @@ pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
|
|||||||
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
|
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
|
||||||
"nagaTextureSampleBaseClampToEdge";
|
"nagaTextureSampleBaseClampToEdge";
|
||||||
|
|
||||||
|
enum Index {
|
||||||
|
Expression(Handle<crate::Expression>),
|
||||||
|
Static(u32),
|
||||||
|
}
|
||||||
|
|
||||||
struct EpStructMember {
|
struct EpStructMember {
|
||||||
name: String,
|
name: String,
|
||||||
ty: Handle<crate::Type>,
|
ty: Handle<crate::Type>,
|
||||||
@ -1797,6 +1802,23 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn write_index(
|
||||||
|
&mut self,
|
||||||
|
module: &Module,
|
||||||
|
index: Index,
|
||||||
|
func_ctx: &back::FunctionCtx<'_>,
|
||||||
|
) -> BackendResult {
|
||||||
|
match index {
|
||||||
|
Index::Static(index) => {
|
||||||
|
write!(self.out, "{index}")?;
|
||||||
|
}
|
||||||
|
Index::Expression(index) => {
|
||||||
|
self.write_expr(module, index, func_ctx)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Helper method used to write statements
|
/// Helper method used to write statements
|
||||||
///
|
///
|
||||||
/// # Notes
|
/// # Notes
|
||||||
@ -1953,13 +1975,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
//
|
//
|
||||||
// We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
|
// We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
|
||||||
// Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
|
// Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
|
||||||
struct MatrixAccess {
|
enum MatrixAccess {
|
||||||
base: Handle<crate::Expression>,
|
Direct {
|
||||||
index: u32,
|
base: Handle<crate::Expression>,
|
||||||
}
|
index: u32,
|
||||||
enum Index {
|
},
|
||||||
Expression(Handle<crate::Expression>),
|
Struct {
|
||||||
Static(u32),
|
columns: crate::VectorSize,
|
||||||
|
base: Handle<crate::Expression>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
let get_members = |expr: Handle<crate::Expression>| {
|
let get_members = |expr: Handle<crate::Expression>| {
|
||||||
@ -1973,187 +1997,28 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut matrix = None;
|
|
||||||
let mut vector = None;
|
|
||||||
let mut scalar = None;
|
|
||||||
|
|
||||||
let mut current_expr = pointer;
|
|
||||||
for _ in 0..3 {
|
|
||||||
let resolved = func_ctx.resolve_type(current_expr, &module.types);
|
|
||||||
|
|
||||||
match (resolved, &func_ctx.expressions[current_expr]) {
|
|
||||||
(
|
|
||||||
&TypeInner::Pointer { base: ty, .. },
|
|
||||||
&crate::Expression::AccessIndex { base, index },
|
|
||||||
) if matches!(
|
|
||||||
module.types[ty].inner,
|
|
||||||
TypeInner::Matrix {
|
|
||||||
rows: crate::VectorSize::Bi,
|
|
||||||
..
|
|
||||||
}
|
|
||||||
) && get_members(base)
|
|
||||||
.map(|members| members[index as usize].binding.is_none())
|
|
||||||
== Some(true) =>
|
|
||||||
{
|
|
||||||
matrix = Some(MatrixAccess { base, index });
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
(
|
|
||||||
&TypeInner::ValuePointer {
|
|
||||||
size: Some(crate::VectorSize::Bi),
|
|
||||||
..
|
|
||||||
},
|
|
||||||
&crate::Expression::Access { base, index },
|
|
||||||
) => {
|
|
||||||
vector = Some(Index::Expression(index));
|
|
||||||
current_expr = base;
|
|
||||||
}
|
|
||||||
(
|
|
||||||
&TypeInner::ValuePointer {
|
|
||||||
size: Some(crate::VectorSize::Bi),
|
|
||||||
..
|
|
||||||
},
|
|
||||||
&crate::Expression::AccessIndex { base, index },
|
|
||||||
) => {
|
|
||||||
vector = Some(Index::Static(index));
|
|
||||||
current_expr = base;
|
|
||||||
}
|
|
||||||
(
|
|
||||||
&TypeInner::ValuePointer { size: None, .. },
|
|
||||||
&crate::Expression::Access { base, index },
|
|
||||||
) => {
|
|
||||||
scalar = Some(Index::Expression(index));
|
|
||||||
current_expr = base;
|
|
||||||
}
|
|
||||||
(
|
|
||||||
&TypeInner::ValuePointer { size: None, .. },
|
|
||||||
&crate::Expression::AccessIndex { base, index },
|
|
||||||
) => {
|
|
||||||
scalar = Some(Index::Static(index));
|
|
||||||
current_expr = base;
|
|
||||||
}
|
|
||||||
_ => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
write!(self.out, "{level}")?;
|
write!(self.out, "{level}")?;
|
||||||
|
|
||||||
if let Some(MatrixAccess { index, base }) = matrix {
|
let matrix_access_on_lhs =
|
||||||
let base_ty_res = &func_ctx.info[base].ty;
|
find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
|
||||||
let resolved = base_ty_res.inner_with(&module.types);
|
|(matrix_expr, vector, scalar)| match (
|
||||||
let ty = match *resolved {
|
func_ctx.resolve_type(matrix_expr, &module.types),
|
||||||
TypeInner::Pointer { base, .. } => base,
|
&func_ctx.expressions[matrix_expr],
|
||||||
_ => base_ty_res.handle().unwrap(),
|
) {
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(Index::Static(vec_index)) = vector {
|
|
||||||
self.write_expr(module, base, func_ctx)?;
|
|
||||||
write!(
|
|
||||||
self.out,
|
|
||||||
".{}_{}",
|
|
||||||
&self.names[&NameKey::StructMember(ty, index)],
|
|
||||||
vec_index
|
|
||||||
)?;
|
|
||||||
|
|
||||||
if let Some(scalar_index) = scalar {
|
|
||||||
write!(self.out, "[")?;
|
|
||||||
match scalar_index {
|
|
||||||
Index::Static(index) => {
|
|
||||||
write!(self.out, "{index}")?;
|
|
||||||
}
|
|
||||||
Index::Expression(index) => {
|
|
||||||
self.write_expr(module, index, func_ctx)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
write!(self.out, "]")?;
|
|
||||||
}
|
|
||||||
|
|
||||||
write!(self.out, " = ")?;
|
|
||||||
self.write_expr(module, value, func_ctx)?;
|
|
||||||
writeln!(self.out, ";")?;
|
|
||||||
} else {
|
|
||||||
let access = WrappedStructMatrixAccess { ty, index };
|
|
||||||
match (&vector, &scalar) {
|
|
||||||
(&Some(_), &Some(_)) => {
|
|
||||||
self.write_wrapped_struct_matrix_set_scalar_function_name(
|
|
||||||
access,
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
(&Some(_), &None) => {
|
|
||||||
self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
|
|
||||||
}
|
|
||||||
(&None, _) => {
|
|
||||||
self.write_wrapped_struct_matrix_set_function_name(access)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
write!(self.out, "(")?;
|
|
||||||
self.write_expr(module, base, func_ctx)?;
|
|
||||||
write!(self.out, ", ")?;
|
|
||||||
self.write_expr(module, value, func_ctx)?;
|
|
||||||
|
|
||||||
if let Some(Index::Expression(vec_index)) = vector {
|
|
||||||
write!(self.out, ", ")?;
|
|
||||||
self.write_expr(module, vec_index, func_ctx)?;
|
|
||||||
|
|
||||||
if let Some(scalar_index) = scalar {
|
|
||||||
write!(self.out, ", ")?;
|
|
||||||
match scalar_index {
|
|
||||||
Index::Static(index) => {
|
|
||||||
write!(self.out, "{index}")?;
|
|
||||||
}
|
|
||||||
Index::Expression(index) => {
|
|
||||||
self.write_expr(module, index, func_ctx)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
writeln!(self.out, ");")?;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// We handle `Store`s to __matCx2 column vectors and scalar elements via
|
|
||||||
// the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
|
|
||||||
struct MatrixData {
|
|
||||||
columns: crate::VectorSize,
|
|
||||||
base: Handle<crate::Expression>,
|
|
||||||
}
|
|
||||||
|
|
||||||
enum Index {
|
|
||||||
Expression(Handle<crate::Expression>),
|
|
||||||
Static(u32),
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut matrix = None;
|
|
||||||
let mut vector = None;
|
|
||||||
let mut scalar = None;
|
|
||||||
|
|
||||||
let mut current_expr = pointer;
|
|
||||||
for _ in 0..3 {
|
|
||||||
let resolved = func_ctx.resolve_type(current_expr, &module.types);
|
|
||||||
match (resolved, &func_ctx.expressions[current_expr]) {
|
|
||||||
(
|
(
|
||||||
&TypeInner::ValuePointer {
|
&TypeInner::Pointer { base: ty, .. },
|
||||||
size: Some(crate::VectorSize::Bi),
|
|
||||||
..
|
|
||||||
},
|
|
||||||
&crate::Expression::Access { base, index },
|
|
||||||
) => {
|
|
||||||
vector = Some(index);
|
|
||||||
current_expr = base;
|
|
||||||
}
|
|
||||||
(
|
|
||||||
&TypeInner::ValuePointer { size: None, .. },
|
|
||||||
&crate::Expression::Access { base, index },
|
|
||||||
) => {
|
|
||||||
scalar = Some(Index::Expression(index));
|
|
||||||
current_expr = base;
|
|
||||||
}
|
|
||||||
(
|
|
||||||
&TypeInner::ValuePointer { size: None, .. },
|
|
||||||
&crate::Expression::AccessIndex { base, index },
|
&crate::Expression::AccessIndex { base, index },
|
||||||
) => {
|
) if matches!(
|
||||||
scalar = Some(Index::Static(index));
|
module.types[ty].inner,
|
||||||
current_expr = base;
|
TypeInner::Matrix {
|
||||||
|
rows: crate::VectorSize::Bi,
|
||||||
|
..
|
||||||
|
}
|
||||||
|
) && get_members(base)
|
||||||
|
.map(|members| members[index as usize].binding.is_none())
|
||||||
|
== Some(true) =>
|
||||||
|
{
|
||||||
|
Some((MatrixAccess::Direct { base, index }, vector, scalar))
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
if let Some(MatrixType {
|
if let Some(MatrixType {
|
||||||
@ -2162,24 +2027,95 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
width: 4,
|
width: 4,
|
||||||
}) = get_inner_matrix_of_struct_array_member(
|
}) = get_inner_matrix_of_struct_array_member(
|
||||||
module,
|
module,
|
||||||
current_expr,
|
matrix_expr,
|
||||||
func_ctx,
|
func_ctx,
|
||||||
true,
|
true,
|
||||||
) {
|
) {
|
||||||
matrix = Some(MatrixData {
|
Some((
|
||||||
columns,
|
MatrixAccess::Struct {
|
||||||
base: current_expr,
|
columns,
|
||||||
});
|
base: matrix_expr,
|
||||||
|
},
|
||||||
|
vector,
|
||||||
|
scalar,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
match matrix_access_on_lhs {
|
||||||
|
Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
|
||||||
|
let base_ty_res = &func_ctx.info[base].ty;
|
||||||
|
let resolved = base_ty_res.inner_with(&module.types);
|
||||||
|
let ty = match *resolved {
|
||||||
|
TypeInner::Pointer { base, .. } => base,
|
||||||
|
_ => base_ty_res.handle().unwrap(),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(Index::Static(vec_index)) = vector {
|
||||||
|
self.write_expr(module, base, func_ctx)?;
|
||||||
|
write!(
|
||||||
|
self.out,
|
||||||
|
".{}_{}",
|
||||||
|
&self.names[&NameKey::StructMember(ty, index)],
|
||||||
|
vec_index
|
||||||
|
)?;
|
||||||
|
|
||||||
|
if let Some(scalar_index) = scalar {
|
||||||
|
write!(self.out, "[")?;
|
||||||
|
self.write_index(module, scalar_index, func_ctx)?;
|
||||||
|
write!(self.out, "]")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
write!(self.out, " = ")?;
|
||||||
|
self.write_expr(module, value, func_ctx)?;
|
||||||
|
writeln!(self.out, ";")?;
|
||||||
|
} else {
|
||||||
|
let access = WrappedStructMatrixAccess { ty, index };
|
||||||
|
match (&vector, &scalar) {
|
||||||
|
(&Some(_), &Some(_)) => {
|
||||||
|
self.write_wrapped_struct_matrix_set_scalar_function_name(
|
||||||
|
access,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
(&Some(_), &None) => {
|
||||||
|
self.write_wrapped_struct_matrix_set_vec_function_name(
|
||||||
|
access,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
(&None, _) => {
|
||||||
|
self.write_wrapped_struct_matrix_set_function_name(access)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
write!(self.out, "(")?;
|
||||||
|
self.write_expr(module, base, func_ctx)?;
|
||||||
|
write!(self.out, ", ")?;
|
||||||
|
self.write_expr(module, value, func_ctx)?;
|
||||||
|
|
||||||
|
if let Some(Index::Expression(vec_index)) = vector {
|
||||||
|
write!(self.out, ", ")?;
|
||||||
|
self.write_expr(module, vec_index, func_ctx)?;
|
||||||
|
|
||||||
|
if let Some(scalar_index) = scalar {
|
||||||
|
write!(self.out, ", ")?;
|
||||||
|
self.write_index(module, scalar_index, func_ctx)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writeln!(self.out, ");")?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Some((
|
||||||
|
MatrixAccess::Struct { columns, base },
|
||||||
|
Some(Index::Expression(vec_index)),
|
||||||
|
scalar,
|
||||||
|
)) => {
|
||||||
|
// We handle `Store`s to __matCx2 column vectors and scalar elements via
|
||||||
|
// the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
|
||||||
|
|
||||||
if let (Some(MatrixData { columns, base }), Some(vec_index)) =
|
|
||||||
(matrix, vector)
|
|
||||||
{
|
|
||||||
if scalar.is_some() {
|
if scalar.is_some() {
|
||||||
write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
|
write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
|
||||||
} else {
|
} else {
|
||||||
@ -2192,21 +2128,17 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
|||||||
|
|
||||||
if let Some(scalar_index) = scalar {
|
if let Some(scalar_index) = scalar {
|
||||||
write!(self.out, ", ")?;
|
write!(self.out, ", ")?;
|
||||||
match scalar_index {
|
self.write_index(module, scalar_index, func_ctx)?;
|
||||||
Index::Static(index) => {
|
|
||||||
write!(self.out, "{index}")?;
|
|
||||||
}
|
|
||||||
Index::Expression(index) => {
|
|
||||||
self.write_expr(module, index, func_ctx)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
write!(self.out, ", ")?;
|
write!(self.out, ", ")?;
|
||||||
self.write_expr(module, value, func_ctx)?;
|
self.write_expr(module, value, func_ctx)?;
|
||||||
|
|
||||||
writeln!(self.out, ");")?;
|
writeln!(self.out, ");")?;
|
||||||
} else {
|
}
|
||||||
|
Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
|
||||||
|
| Some((MatrixAccess::Struct { .. }, None, _))
|
||||||
|
| None => {
|
||||||
self.write_expr(module, pointer, func_ctx)?;
|
self.write_expr(module, pointer, func_ctx)?;
|
||||||
write!(self.out, " = ")?;
|
write!(self.out, " = ")?;
|
||||||
|
|
||||||
@ -4423,12 +4355,17 @@ pub(super) fn get_inner_matrix_data(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
|
||||||
|
/// returns a tuple of the matrix, the column (vector) index (if present), and
|
||||||
|
/// the row (scalar) index (if present).
|
||||||
fn find_matrix_in_access_chain(
|
fn find_matrix_in_access_chain(
|
||||||
module: &Module,
|
module: &Module,
|
||||||
base: Handle<crate::Expression>,
|
base: Handle<crate::Expression>,
|
||||||
func_ctx: &back::FunctionCtx<'_>,
|
func_ctx: &back::FunctionCtx<'_>,
|
||||||
) -> Option<Handle<crate::Expression>> {
|
) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
|
||||||
let mut current_base = base;
|
let mut current_base = base;
|
||||||
|
let mut vector = None;
|
||||||
|
let mut scalar = None;
|
||||||
loop {
|
loop {
|
||||||
let resolved_tr = func_ctx
|
let resolved_tr = func_ctx
|
||||||
.resolve_type(current_base, &module.types)
|
.resolve_type(current_base, &module.types)
|
||||||
@ -4436,15 +4373,22 @@ fn find_matrix_in_access_chain(
|
|||||||
let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
|
let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
|
||||||
|
|
||||||
match *resolved {
|
match *resolved {
|
||||||
|
TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
|
||||||
TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
|
TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
|
||||||
TypeInner::Matrix { .. } => return Some(current_base),
|
|
||||||
_ => return None,
|
_ => return None,
|
||||||
}
|
}
|
||||||
|
|
||||||
current_base = match func_ctx.expressions[current_base] {
|
let index;
|
||||||
crate::Expression::Access { base, .. } => base,
|
(current_base, index) = match func_ctx.expressions[current_base] {
|
||||||
crate::Expression::AccessIndex { base, .. } => base,
|
crate::Expression::Access { base, index } => (base, Index::Expression(index)),
|
||||||
|
crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
|
||||||
_ => return None,
|
_ => return None,
|
||||||
|
};
|
||||||
|
|
||||||
|
match *resolved {
|
||||||
|
TypeInner::Scalar(_) => scalar = Some(index),
|
||||||
|
TypeInner::Vector { .. } => vector = Some(index),
|
||||||
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user