Skip to content

Commit

Permalink
[naga wgsl-in] Ensure constant evaluation correctly handles Composes …
Browse files Browse the repository at this point in the history
…of vector ZeroValues

Constant evaluation often relies on the expressions being evaluated
being either a Literal, or a vector Compose for which it can iterate
over each component.

Currently this is achieved by calling two functions:

* eval_zero_value_and_splat() which transforms a scalar ZeroValue into a
  Literal, or a Splat or vector ZeroValue into a Compose of Literals.

* proc::flatten_compose() takes potentially nested nested Compose and
  Splat expressions and produces an flat iterator that yields each
  component. eg `vec3(vec2(0), 0)` would yield `Literal(0), Literal(0),
  Literal(0)`.

For component-wise vector operations, we can then iterate through each
component of the flattened compose and apply the operation. When there
are multiple operands it is crucial they have both been flattened
correctly so that we use the corresponding component from each operand
together.

Where this falls short is if a *vector* ZeroValue is nested within a
Compose. eg `vec3(vec2(), 0)`. flatten_compose() is unable to flatten
this, and the resulting iterator will yield `ZeroValue, Literal(0)`

This causes various issues. Take binary_op(), for example. If we attempt
to add `vec3(1, 2, 3)` to our unflattenable `vec3(vec2(), 0)` this
should be evaluated component-wise as 0 + 1, 0 + 2, and 0 + 3. As this
has not been correctly flattened, however, we will evaluate vec2() + 1,
and 0 + 2, which is simply incorrect.

To solve this, we make eval_zero_value_and_splat() recursively call
itself for each component if the expression is a Compose. This ensures
no ZeroValues will be present during flatten_compose(), meaning it will
successfully fully flatten the expression.
  • Loading branch information
jamienicol committed Feb 14, 2025
1 parent bae0e70 commit adeadb6
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 110 deletions.
11 changes: 11 additions & 0 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,17 @@ impl<'a> ConstantEvaluator<'a> {
mut expr: Handle<Expression>,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
// If expr is a Compose expression, elimate ZeroValue and Splat expressions for

Check warning on line 1414 in naga/src/proc/constant_evaluator.rs

View workflow job for this annotation

GitHub Actions / Format & Typos

"elimate" should be "eliminate".
// each of its components.
if let Expression::Compose { ty, ref components } = self.expressions[expr] {
let components = components
.clone()
.iter()
.map(|component| self.eval_zero_value_and_splat(*component, span))
.collect::<Result<_, _>>()?;
expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
}

// The result of the splat() for a Splat of a scalar ZeroValue is a
// vector ZeroValue, so we must call eval_zero_value_impl() after
// splat() in order to ensure we have no ZeroValues remaining.
Expand Down
7 changes: 7 additions & 0 deletions naga/tests/in/const-exprs.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,10 @@ fn compose_of_splat() {

const add_vec = vec2(1.0f) + vec2(3.0f, 4.0f);
const compare_vec = vec2(3.0f) == vec2(3.0f, 4.0f);

// Ensure binary ops correctly flatten compositions of vector zero values
fn compose_vector_zero_val_binop() {
var a = vec3(vec2i(), 0) + vec3(1);
var b = vec3(vec2i(), 0) + vec3(0, 1, 2);
var c = vec3(vec2i(), 2) + vec3(1, vec2i());
}
7 changes: 7 additions & 0 deletions naga/tests/out/glsl/const-exprs.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ uint map_texture_kind(int texture_kind) {
}
}

void compose_vector_zero_val_binop() {
ivec3 a = ivec3(1, 1, 1);
ivec3 b = ivec3(0, 1, 2);
ivec3 c = ivec3(1, 0, 2);
return;
}

void main() {
swizzle_of_compose();
index_of_compose();
Expand Down
9 changes: 9 additions & 0 deletions naga/tests/out/hlsl/const-exprs.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ uint map_texture_kind(int texture_kind)
}
}

void compose_vector_zero_val_binop()
{
int3 a = int3(1, 1, 1);
int3 b = int3(0, 1, 2);
int3 c = int3(1, 0, 2);

return;
}

[numthreads(2, 3, 1)]
void main()
{
Expand Down
8 changes: 8 additions & 0 deletions naga/tests/out/msl/const-exprs.msl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ uint map_texture_kind(
}
}

void compose_vector_zero_val_binop(
) {
metal::int3 a = metal::int3(1, 1, 1);
metal::int3 b = metal::int3(0, 1, 2);
metal::int3 c = metal::int3(1, 0, 2);
return;
}

kernel void main_(
) {
swizzle_of_compose();
Expand Down
234 changes: 124 additions & 110 deletions naga/tests/out/spv/const-exprs.spvasm
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 109
; Bound: 120
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %100 "main"
OpExecutionMode %100 LocalSize 2 3 1
OpEntryPoint GLCompute %111 "main"
OpExecutionMode %111 LocalSize 2 3 1
%2 = OpTypeVoid
%3 = OpTypeInt 32 0
%4 = OpTypeInt 32 1
Expand All @@ -16,137 +16,151 @@ OpExecutionMode %100 LocalSize 2 3 1
%8 = OpTypeVector %6 2
%10 = OpTypeBool
%9 = OpTypeVector %10 2
%11 = OpConstant %3 2
%12 = OpConstant %4 3
%13 = OpConstant %4 4
%14 = OpConstant %4 8
%15 = OpConstant %6 3.141
%16 = OpConstant %6 6.282
%17 = OpConstant %6 0.44444445
%18 = OpConstant %6 0.0
%19 = OpConstantComposite %7 %17 %18 %18 %18
%20 = OpConstant %4 0
%21 = OpConstant %4 1
%22 = OpConstant %4 2
%23 = OpConstant %6 4.0
%24 = OpConstant %6 5.0
%25 = OpConstantComposite %8 %23 %24
%26 = OpConstantTrue %10
%27 = OpConstantFalse %10
%28 = OpConstantComposite %9 %26 %27
%31 = OpTypeFunction %2
%32 = OpConstantComposite %5 %13 %12 %22 %21
%34 = OpTypePointer Function %5
%39 = OpTypePointer Function %4
%43 = OpConstant %4 6
%48 = OpConstant %4 30
%49 = OpConstant %4 70
%52 = OpConstantNull %4
%54 = OpConstantNull %4
%57 = OpConstantNull %5
%68 = OpConstant %4 -4
%69 = OpConstantComposite %5 %68 %68 %68 %68
%78 = OpConstant %6 1.0
%79 = OpConstant %6 2.0
%80 = OpConstantComposite %7 %79 %78 %78 %78
%82 = OpTypePointer Function %7
%87 = OpTypeFunction %3 %4
%88 = OpConstant %3 10
%89 = OpConstant %3 20
%90 = OpConstant %3 30
%91 = OpConstant %3 0
%98 = OpConstantNull %3
%30 = OpFunction %2 None %31
%29 = OpLabel
%33 = OpVariable %34 Function %32
OpBranch %35
%35 = OpLabel
OpReturn
OpFunctionEnd
%37 = OpFunction %2 None %31
%11 = OpTypeVector %4 3
%12 = OpConstant %3 2
%13 = OpConstant %4 3
%14 = OpConstant %4 4
%15 = OpConstant %4 8
%16 = OpConstant %6 3.141
%17 = OpConstant %6 6.282
%18 = OpConstant %6 0.44444445
%19 = OpConstant %6 0.0
%20 = OpConstantComposite %7 %18 %19 %19 %19
%21 = OpConstant %4 0
%22 = OpConstant %4 1
%23 = OpConstant %4 2
%24 = OpConstant %6 4.0
%25 = OpConstant %6 5.0
%26 = OpConstantComposite %8 %24 %25
%27 = OpConstantTrue %10
%28 = OpConstantFalse %10
%29 = OpConstantComposite %9 %27 %28
%32 = OpTypeFunction %2
%33 = OpConstantComposite %5 %14 %13 %23 %22
%35 = OpTypePointer Function %5
%40 = OpTypePointer Function %4
%44 = OpConstant %4 6
%49 = OpConstant %4 30
%50 = OpConstant %4 70
%53 = OpConstantNull %4
%55 = OpConstantNull %4
%58 = OpConstantNull %5
%69 = OpConstant %4 -4
%70 = OpConstantComposite %5 %69 %69 %69 %69
%79 = OpConstant %6 1.0
%80 = OpConstant %6 2.0
%81 = OpConstantComposite %7 %80 %79 %79 %79
%83 = OpTypePointer Function %7
%88 = OpTypeFunction %3 %4
%89 = OpConstant %3 10
%90 = OpConstant %3 20
%91 = OpConstant %3 30
%92 = OpConstant %3 0
%99 = OpConstantNull %3
%102 = OpConstantComposite %11 %22 %22 %22
%103 = OpConstantComposite %11 %21 %22 %23
%104 = OpConstantComposite %11 %22 %21 %23
%106 = OpTypePointer Function %11
%31 = OpFunction %2 None %32
%30 = OpLabel
%34 = OpVariable %35 Function %33
OpBranch %36
%36 = OpLabel
%38 = OpVariable %39 Function %22
OpBranch %40
%40 = OpLabel
OpReturn
OpFunctionEnd
%42 = OpFunction %2 None %31
%38 = OpFunction %2 None %32
%37 = OpLabel
%39 = OpVariable %40 Function %23
OpBranch %41
%41 = OpLabel
%44 = OpVariable %39 Function %43
OpBranch %45
%45 = OpLabel
OpReturn
OpFunctionEnd
%47 = OpFunction %2 None %31
%43 = OpFunction %2 None %32
%42 = OpLabel
%45 = OpVariable %40 Function %44
OpBranch %46
%46 = OpLabel
%56 = OpVariable %34 Function %57
%51 = OpVariable %39 Function %52
%55 = OpVariable %39 Function %49
%50 = OpVariable %39 Function %48
%53 = OpVariable %39 Function %54
OpBranch %58
%58 = OpLabel
%59 = OpLoad %4 %50
OpStore %51 %59
%60 = OpLoad %4 %51
OpStore %53 %60
%61 = OpLoad %4 %50
%62 = OpLoad %4 %51
%63 = OpLoad %4 %53
%64 = OpLoad %4 %55
%65 = OpCompositeConstruct %5 %61 %62 %63 %64
OpStore %56 %65
OpReturn
OpFunctionEnd
%67 = OpFunction %2 None %31
%66 = OpLabel
%70 = OpVariable %34 Function %69
OpBranch %71
%71 = OpLabel
%48 = OpFunction %2 None %32
%47 = OpLabel
%57 = OpVariable %35 Function %58
%52 = OpVariable %40 Function %53
%56 = OpVariable %40 Function %50
%51 = OpVariable %40 Function %49
%54 = OpVariable %40 Function %55
OpBranch %59
%59 = OpLabel
%60 = OpLoad %4 %51
OpStore %52 %60
%61 = OpLoad %4 %52
OpStore %54 %61
%62 = OpLoad %4 %51
%63 = OpLoad %4 %52
%64 = OpLoad %4 %54
%65 = OpLoad %4 %56
%66 = OpCompositeConstruct %5 %62 %63 %64 %65
OpStore %57 %66
OpReturn
OpFunctionEnd
%73 = OpFunction %2 None %31
%68 = OpFunction %2 None %32
%67 = OpLabel
%71 = OpVariable %35 Function %70
OpBranch %72
%72 = OpLabel
%74 = OpVariable %34 Function %69
OpBranch %75
%75 = OpLabel
OpReturn
OpFunctionEnd
%77 = OpFunction %2 None %31
%74 = OpFunction %2 None %32
%73 = OpLabel
%75 = OpVariable %35 Function %70
OpBranch %76
%76 = OpLabel
%81 = OpVariable %82 Function %80
OpBranch %83
%83 = OpLabel
OpReturn
OpFunctionEnd
%86 = OpFunction %3 None %87
%85 = OpFunctionParameter %4
%78 = OpFunction %2 None %32
%77 = OpLabel
%82 = OpVariable %83 Function %81
OpBranch %84
%84 = OpLabel
OpBranch %92
%92 = OpLabel
OpSelectionMerge %93 None
OpSwitch %85 %97 0 %94 1 %95 2 %96
%94 = OpLabel
OpReturnValue %88
OpReturn
OpFunctionEnd
%87 = OpFunction %3 None %88
%86 = OpFunctionParameter %4
%85 = OpLabel
OpBranch %93
%93 = OpLabel
OpSelectionMerge %94 None
OpSwitch %86 %98 0 %95 1 %96 2 %97
%95 = OpLabel
OpReturnValue %89
%96 = OpLabel
OpReturnValue %90
%97 = OpLabel
OpReturnValue %91
%93 = OpLabel
OpReturnValue %98
%98 = OpLabel
OpReturnValue %92
%94 = OpLabel
OpReturnValue %99
OpFunctionEnd
%101 = OpFunction %2 None %32
%100 = OpLabel
%105 = OpVariable %106 Function %102
%107 = OpVariable %106 Function %103
%108 = OpVariable %106 Function %104
OpBranch %109
%109 = OpLabel
OpReturn
OpFunctionEnd
%100 = OpFunction %2 None %31
%99 = OpLabel
OpBranch %101
%101 = OpLabel
%102 = OpFunctionCall %2 %30
%103 = OpFunctionCall %2 %37
%104 = OpFunctionCall %2 %42
%105 = OpFunctionCall %2 %47
%106 = OpFunctionCall %2 %67
%107 = OpFunctionCall %2 %73
%108 = OpFunctionCall %2 %77
%111 = OpFunction %2 None %32
%110 = OpLabel
OpBranch %112
%112 = OpLabel
%113 = OpFunctionCall %2 %31
%114 = OpFunctionCall %2 %38
%115 = OpFunctionCall %2 %43
%116 = OpFunctionCall %2 %48
%117 = OpFunctionCall %2 %68
%118 = OpFunctionCall %2 %74
%119 = OpFunctionCall %2 %78
OpReturn
OpFunctionEnd
8 changes: 8 additions & 0 deletions naga/tests/out/wgsl/const-exprs.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ fn map_texture_kind(texture_kind: i32) -> u32 {
}
}

fn compose_vector_zero_val_binop() {
var a: vec3<i32> = vec3<i32>(1i, 1i, 1i);
var b: vec3<i32> = vec3<i32>(0i, 1i, 2i);
var c: vec3<i32> = vec3<i32>(1i, 0i, 2i);

return;
}

@compute @workgroup_size(2, 3, 1)
fn main() {
swizzle_of_compose();
Expand Down

0 comments on commit adeadb6

Please sign in to comment.