From adeadb650f93e7cbb1be65558a9c599846bc7eda Mon Sep 17 00:00:00 2001 From: Jamie Nicol Date: Fri, 14 Feb 2025 11:40:47 +0000 Subject: [PATCH] [naga wgsl-in] Ensure constant evaluation correctly handles Composes of vector ZeroValues Constant evaluation often relies on the expressions being evaluated being either a Literal, or a vector Compose for which it can iterate over each component. Currently this is achieved by calling two functions: * eval_zero_value_and_splat() transforms a scalar ZeroValue into a Literal, or a Splat or vector ZeroValue into a Compose of Literals. * proc::flatten_compose() takes potentially nested Compose and Splat expressions and produces a flat iterator that yields each component, e.g. `vec3(vec2(0), 0)` would yield `Literal(0), Literal(0), Literal(0)`. For component-wise vector operations, we can then iterate through each component of the flattened compose and apply the operation. When there are multiple operands it is crucial that they have all been flattened correctly so that we use the corresponding component from each operand together. Where this falls short is if a *vector* ZeroValue is nested within a Compose, e.g. `vec3(vec2(), 0)`. flatten_compose() is unable to flatten this, and the resulting iterator will yield `ZeroValue, Literal(0)`. This causes various issues. Take binary_op(), for example. If we attempt to add `vec3(1, 2, 3)` to our unflattenable `vec3(vec2(), 0)` this should be evaluated component-wise as 0 + 1, 0 + 2, and 0 + 3. As this has not been correctly flattened, however, we will evaluate vec2() + 1, and 0 + 2, which is simply incorrect. To solve this, we make eval_zero_value_and_splat() recursively call itself for each component if the expression is a Compose. This ensures no ZeroValues will be present during flatten_compose(), meaning it will successfully fully flatten the expression. 
--- naga/src/proc/constant_evaluator.rs | 11 + naga/tests/in/const-exprs.wgsl | 7 + .../out/glsl/const-exprs.main.Compute.glsl | 7 + naga/tests/out/hlsl/const-exprs.hlsl | 9 + naga/tests/out/msl/const-exprs.msl | 8 + naga/tests/out/spv/const-exprs.spvasm | 234 ++++++++++-------- naga/tests/out/wgsl/const-exprs.wgsl | 8 + 7 files changed, 174 insertions(+), 110 deletions(-) diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index fcb325f539..3988eaa4f2 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -1411,6 +1411,17 @@ impl<'a> ConstantEvaluator<'a> { mut expr: Handle, span: Span, ) -> Result, ConstantEvaluatorError> { + // If expr is a Compose expression, eliminate ZeroValue and Splat expressions for + // each of its components. + if let Expression::Compose { ty, ref components } = self.expressions[expr] { + let components = components + .clone() + .iter() + .map(|component| self.eval_zero_value_and_splat(*component, span)) + .collect::>()?; + expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?; + } + // The result of the splat() for a Splat of a scalar ZeroValue is a // vector ZeroValue, so we must call eval_zero_value_impl() after // splat() in order to ensure we have no ZeroValues remaining. 
diff --git a/naga/tests/in/const-exprs.wgsl b/naga/tests/in/const-exprs.wgsl index ee9304ce45..5165c49ceb 100644 --- a/naga/tests/in/const-exprs.wgsl +++ b/naga/tests/in/const-exprs.wgsl @@ -87,3 +87,10 @@ fn compose_of_splat() { const add_vec = vec2(1.0f) + vec2(3.0f, 4.0f); const compare_vec = vec2(3.0f) == vec2(3.0f, 4.0f); + +// Ensure binary ops correctly flatten compositions of vector zero values +fn compose_vector_zero_val_binop() { + var a = vec3(vec2i(), 0) + vec3(1); + var b = vec3(vec2i(), 0) + vec3(0, 1, 2); + var c = vec3(vec2i(), 2) + vec3(1, vec2i()); +} diff --git a/naga/tests/out/glsl/const-exprs.main.Compute.glsl b/naga/tests/out/glsl/const-exprs.main.Compute.glsl index 0b318a65e3..4b473bed7c 100644 --- a/naga/tests/out/glsl/const-exprs.main.Compute.glsl +++ b/naga/tests/out/glsl/const-exprs.main.Compute.glsl @@ -86,6 +86,13 @@ uint map_texture_kind(int texture_kind) { } } +void compose_vector_zero_val_binop() { + ivec3 a = ivec3(1, 1, 1); + ivec3 b = ivec3(0, 1, 2); + ivec3 c = ivec3(1, 0, 2); + return; +} + void main() { swizzle_of_compose(); index_of_compose(); diff --git a/naga/tests/out/hlsl/const-exprs.hlsl b/naga/tests/out/hlsl/const-exprs.hlsl index aa2ba75ed6..9d62504f78 100644 --- a/naga/tests/out/hlsl/const-exprs.hlsl +++ b/naga/tests/out/hlsl/const-exprs.hlsl @@ -93,6 +93,15 @@ uint map_texture_kind(int texture_kind) } } +void compose_vector_zero_val_binop() +{ + int3 a = int3(1, 1, 1); + int3 b = int3(0, 1, 2); + int3 c = int3(1, 0, 2); + + return; +} + [numthreads(2, 3, 1)] void main() { diff --git a/naga/tests/out/msl/const-exprs.msl b/naga/tests/out/msl/const-exprs.msl index cb0959f72e..dc0c394868 100644 --- a/naga/tests/out/msl/const-exprs.msl +++ b/naga/tests/out/msl/const-exprs.msl @@ -93,6 +93,14 @@ uint map_texture_kind( } } +void compose_vector_zero_val_binop( +) { + metal::int3 a = metal::int3(1, 1, 1); + metal::int3 b = metal::int3(0, 1, 2); + metal::int3 c = metal::int3(1, 0, 2); + return; +} + kernel void main_( ) { 
swizzle_of_compose(); diff --git a/naga/tests/out/spv/const-exprs.spvasm b/naga/tests/out/spv/const-exprs.spvasm index afd9fe8499..63b256d5de 100644 --- a/naga/tests/out/spv/const-exprs.spvasm +++ b/naga/tests/out/spv/const-exprs.spvasm @@ -1,12 +1,12 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 109 +; Bound: 120 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %100 "main" -OpExecutionMode %100 LocalSize 2 3 1 +OpEntryPoint GLCompute %111 "main" +OpExecutionMode %111 LocalSize 2 3 1 %2 = OpTypeVoid %3 = OpTypeInt 32 0 %4 = OpTypeInt 32 1 @@ -16,137 +16,151 @@ OpExecutionMode %100 LocalSize 2 3 1 %8 = OpTypeVector %6 2 %10 = OpTypeBool %9 = OpTypeVector %10 2 -%11 = OpConstant %3 2 -%12 = OpConstant %4 3 -%13 = OpConstant %4 4 -%14 = OpConstant %4 8 -%15 = OpConstant %6 3.141 -%16 = OpConstant %6 6.282 -%17 = OpConstant %6 0.44444445 -%18 = OpConstant %6 0.0 -%19 = OpConstantComposite %7 %17 %18 %18 %18 -%20 = OpConstant %4 0 -%21 = OpConstant %4 1 -%22 = OpConstant %4 2 -%23 = OpConstant %6 4.0 -%24 = OpConstant %6 5.0 -%25 = OpConstantComposite %8 %23 %24 -%26 = OpConstantTrue %10 -%27 = OpConstantFalse %10 -%28 = OpConstantComposite %9 %26 %27 -%31 = OpTypeFunction %2 -%32 = OpConstantComposite %5 %13 %12 %22 %21 -%34 = OpTypePointer Function %5 -%39 = OpTypePointer Function %4 -%43 = OpConstant %4 6 -%48 = OpConstant %4 30 -%49 = OpConstant %4 70 -%52 = OpConstantNull %4 -%54 = OpConstantNull %4 -%57 = OpConstantNull %5 -%68 = OpConstant %4 -4 -%69 = OpConstantComposite %5 %68 %68 %68 %68 -%78 = OpConstant %6 1.0 -%79 = OpConstant %6 2.0 -%80 = OpConstantComposite %7 %79 %78 %78 %78 -%82 = OpTypePointer Function %7 -%87 = OpTypeFunction %3 %4 -%88 = OpConstant %3 10 -%89 = OpConstant %3 20 -%90 = OpConstant %3 30 -%91 = OpConstant %3 0 -%98 = OpConstantNull %3 -%30 = OpFunction %2 None %31 -%29 = OpLabel -%33 = OpVariable %34 Function %32 -OpBranch %35 -%35 = OpLabel -OpReturn 
-OpFunctionEnd -%37 = OpFunction %2 None %31 +%11 = OpTypeVector %4 3 +%12 = OpConstant %3 2 +%13 = OpConstant %4 3 +%14 = OpConstant %4 4 +%15 = OpConstant %4 8 +%16 = OpConstant %6 3.141 +%17 = OpConstant %6 6.282 +%18 = OpConstant %6 0.44444445 +%19 = OpConstant %6 0.0 +%20 = OpConstantComposite %7 %18 %19 %19 %19 +%21 = OpConstant %4 0 +%22 = OpConstant %4 1 +%23 = OpConstant %4 2 +%24 = OpConstant %6 4.0 +%25 = OpConstant %6 5.0 +%26 = OpConstantComposite %8 %24 %25 +%27 = OpConstantTrue %10 +%28 = OpConstantFalse %10 +%29 = OpConstantComposite %9 %27 %28 +%32 = OpTypeFunction %2 +%33 = OpConstantComposite %5 %14 %13 %23 %22 +%35 = OpTypePointer Function %5 +%40 = OpTypePointer Function %4 +%44 = OpConstant %4 6 +%49 = OpConstant %4 30 +%50 = OpConstant %4 70 +%53 = OpConstantNull %4 +%55 = OpConstantNull %4 +%58 = OpConstantNull %5 +%69 = OpConstant %4 -4 +%70 = OpConstantComposite %5 %69 %69 %69 %69 +%79 = OpConstant %6 1.0 +%80 = OpConstant %6 2.0 +%81 = OpConstantComposite %7 %80 %79 %79 %79 +%83 = OpTypePointer Function %7 +%88 = OpTypeFunction %3 %4 +%89 = OpConstant %3 10 +%90 = OpConstant %3 20 +%91 = OpConstant %3 30 +%92 = OpConstant %3 0 +%99 = OpConstantNull %3 +%102 = OpConstantComposite %11 %22 %22 %22 +%103 = OpConstantComposite %11 %21 %22 %23 +%104 = OpConstantComposite %11 %22 %21 %23 +%106 = OpTypePointer Function %11 +%31 = OpFunction %2 None %32 +%30 = OpLabel +%34 = OpVariable %35 Function %33 +OpBranch %36 %36 = OpLabel -%38 = OpVariable %39 Function %22 -OpBranch %40 -%40 = OpLabel OpReturn OpFunctionEnd -%42 = OpFunction %2 None %31 +%38 = OpFunction %2 None %32 +%37 = OpLabel +%39 = OpVariable %40 Function %23 +OpBranch %41 %41 = OpLabel -%44 = OpVariable %39 Function %43 -OpBranch %45 -%45 = OpLabel OpReturn OpFunctionEnd -%47 = OpFunction %2 None %31 +%43 = OpFunction %2 None %32 +%42 = OpLabel +%45 = OpVariable %40 Function %44 +OpBranch %46 %46 = OpLabel -%56 = OpVariable %34 Function %57 -%51 = OpVariable %39 Function %52 -%55 = 
OpVariable %39 Function %49 -%50 = OpVariable %39 Function %48 -%53 = OpVariable %39 Function %54 -OpBranch %58 -%58 = OpLabel -%59 = OpLoad %4 %50 -OpStore %51 %59 -%60 = OpLoad %4 %51 -OpStore %53 %60 -%61 = OpLoad %4 %50 -%62 = OpLoad %4 %51 -%63 = OpLoad %4 %53 -%64 = OpLoad %4 %55 -%65 = OpCompositeConstruct %5 %61 %62 %63 %64 -OpStore %56 %65 OpReturn OpFunctionEnd -%67 = OpFunction %2 None %31 -%66 = OpLabel -%70 = OpVariable %34 Function %69 -OpBranch %71 -%71 = OpLabel +%48 = OpFunction %2 None %32 +%47 = OpLabel +%57 = OpVariable %35 Function %58 +%52 = OpVariable %40 Function %53 +%56 = OpVariable %40 Function %50 +%51 = OpVariable %40 Function %49 +%54 = OpVariable %40 Function %55 +OpBranch %59 +%59 = OpLabel +%60 = OpLoad %4 %51 +OpStore %52 %60 +%61 = OpLoad %4 %52 +OpStore %54 %61 +%62 = OpLoad %4 %51 +%63 = OpLoad %4 %52 +%64 = OpLoad %4 %54 +%65 = OpLoad %4 %56 +%66 = OpCompositeConstruct %5 %62 %63 %64 %65 +OpStore %57 %66 OpReturn OpFunctionEnd -%73 = OpFunction %2 None %31 +%68 = OpFunction %2 None %32 +%67 = OpLabel +%71 = OpVariable %35 Function %70 +OpBranch %72 %72 = OpLabel -%74 = OpVariable %34 Function %69 -OpBranch %75 -%75 = OpLabel OpReturn OpFunctionEnd -%77 = OpFunction %2 None %31 +%74 = OpFunction %2 None %32 +%73 = OpLabel +%75 = OpVariable %35 Function %70 +OpBranch %76 %76 = OpLabel -%81 = OpVariable %82 Function %80 -OpBranch %83 -%83 = OpLabel OpReturn OpFunctionEnd -%86 = OpFunction %3 None %87 -%85 = OpFunctionParameter %4 +%78 = OpFunction %2 None %32 +%77 = OpLabel +%82 = OpVariable %83 Function %81 +OpBranch %84 %84 = OpLabel -OpBranch %92 -%92 = OpLabel -OpSelectionMerge %93 None -OpSwitch %85 %97 0 %94 1 %95 2 %96 -%94 = OpLabel -OpReturnValue %88 +OpReturn +OpFunctionEnd +%87 = OpFunction %3 None %88 +%86 = OpFunctionParameter %4 +%85 = OpLabel +OpBranch %93 +%93 = OpLabel +OpSelectionMerge %94 None +OpSwitch %86 %98 0 %95 1 %96 2 %97 %95 = OpLabel OpReturnValue %89 %96 = OpLabel OpReturnValue %90 %97 = OpLabel 
OpReturnValue %91 -%93 = OpLabel -OpReturnValue %98 +%98 = OpLabel +OpReturnValue %92 +%94 = OpLabel +OpReturnValue %99 +OpFunctionEnd +%101 = OpFunction %2 None %32 +%100 = OpLabel +%105 = OpVariable %106 Function %102 +%107 = OpVariable %106 Function %103 +%108 = OpVariable %106 Function %104 +OpBranch %109 +%109 = OpLabel +OpReturn OpFunctionEnd -%100 = OpFunction %2 None %31 -%99 = OpLabel -OpBranch %101 -%101 = OpLabel -%102 = OpFunctionCall %2 %30 -%103 = OpFunctionCall %2 %37 -%104 = OpFunctionCall %2 %42 -%105 = OpFunctionCall %2 %47 -%106 = OpFunctionCall %2 %67 -%107 = OpFunctionCall %2 %73 -%108 = OpFunctionCall %2 %77 +%111 = OpFunction %2 None %32 +%110 = OpLabel +OpBranch %112 +%112 = OpLabel +%113 = OpFunctionCall %2 %31 +%114 = OpFunctionCall %2 %38 +%115 = OpFunctionCall %2 %43 +%116 = OpFunctionCall %2 %48 +%117 = OpFunctionCall %2 %68 +%118 = OpFunctionCall %2 %74 +%119 = OpFunctionCall %2 %78 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/const-exprs.wgsl b/naga/tests/out/wgsl/const-exprs.wgsl index 411e835b5d..4649807eb8 100644 --- a/naga/tests/out/wgsl/const-exprs.wgsl +++ b/naga/tests/out/wgsl/const-exprs.wgsl @@ -85,6 +85,14 @@ fn map_texture_kind(texture_kind: i32) -> u32 { } } +fn compose_vector_zero_val_binop() { + var a: vec3 = vec3(1i, 1i, 1i); + var b: vec3 = vec3(0i, 1i, 2i); + var c: vec3 = vec3(1i, 0i, 2i); + + return; +} + @compute @workgroup_size(2, 3, 1) fn main() { swizzle_of_compose();