[Mlir-commits] [mlir] Fix SSA handling in SPIR-V -> MLIR translation (PR #123371)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 17 08:54:55 PST 2025


https://github.com/mishaobu created https://github.com/llvm/llvm-project/pull/123371

None (no pull request description was provided).

>From b91b29a4b90ba2965ba7e8ed609c12f23a1e42e8 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 17:51:40 +0100
Subject: [PATCH 1/2] [mlir][spirv] Add round-trip and deserialization tests for SSA handling

---
 mlir/test/Target/SPIRV/branch-load.mlir | 35 ++++++++++++++++++++++
 mlir/test/Target/SPIRV/ssa.mlir         | 40 +++++++++++++++++++++++++
 mlir/test/Target/SPIRV/ssa2.mlir        | 36 ++++++++++++++++++++++
 3 files changed, 111 insertions(+)
 create mode 100644 mlir/test/Target/SPIRV/branch-load.mlir
 create mode 100644 mlir/test/Target/SPIRV/ssa.mlir
 create mode 100644 mlir/test/Target/SPIRV/ssa2.mlir

diff --git a/mlir/test/Target/SPIRV/branch-load.mlir b/mlir/test/Target/SPIRV/branch-load.mlir
new file mode 100644
index 00000000000000..2a0c376e9302ce
--- /dev/null
+++ b/mlir/test/Target/SPIRV/branch-load.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+
+// CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []>
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+  // CHECK: spirv.func @main() "None"
+  spirv.func @main() "None" {
+    // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<vector<3xf32>, Function>
+    %0 = spirv.Variable : !spirv.ptr<vector<3xf32>, Function>
+    spirv.Branch ^bb1
+  ^bb1:  // pred: ^bb0
+    // CHECK: spirv.mlir.selection
+    spirv.mlir.selection {
+      // CHECK: %[[COND:.*]] = spirv.Constant true
+      // CHECK: spirv.BranchConditional %[[COND]]
+      %true = spirv.Constant true
+      spirv.BranchConditional %true, ^bb1, ^bb2
+    ^bb1:  // pred: ^bb0
+      // CHECK: %[[CONST:.*]] = spirv.Constant dense<0.000000e+00> : vector<3xf32>
+      // CHECK: spirv.Store "Function" %[[VAR]], %[[CONST]] : vector<3xf32>
+      %cst_vec_3xf32 = spirv.Constant dense<0.000000e+00> : vector<3xf32>
+      spirv.Store "Function" %0, %cst_vec_3xf32 : vector<3xf32>
+      spirv.Branch ^bb2
+    ^bb2:  // 2 preds: ^bb0, ^bb1
+      spirv.mlir.merge
+    }
+    // CHECK: %[[RESULT:.*]] = spirv.Load "Function" %[[VAR]] : vector<3xf32>
+    // CHECK: spirv.Return
+    %1 = spirv.Load "Function" %0 : vector<3xf32>
+    spirv.Return
+  }
+  // CHECK: spirv.EntryPoint "Fragment" @main
+  // CHECK: spirv.ExecutionMode @main "OriginUpperLeft"
+  spirv.EntryPoint "Fragment" @main
+  spirv.ExecutionMode @main "OriginUpperLeft"
+}
diff --git a/mlir/test/Target/SPIRV/ssa.mlir b/mlir/test/Target/SPIRV/ssa.mlir
new file mode 100644
index 00000000000000..d3556991ba4630
--- /dev/null
+++ b/mlir/test/Target/SPIRV/ssa.mlir
@@ -0,0 +1,40 @@
+# RUN: split-file %s %t
+# RUN: spirv-as --target-env spv1.0 %t/spv.spvasm -o %t.spv
+# RUN: mlir-translate --deserialize-spirv %t.spv -o - | FileCheck %s
+
+// CHECK: module
+// CHECK: spirv.func @main
+// CHECK: spirv.Variable
+// CHECK: spirv.Return
+//--- spv.spvasm
+; SPIR-V
+; Version: 1.0
+; Generator: Khronos SPIR-V Tools Assembler; 0
+; Bound: 20
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+       %void = OpTypeVoid
+      %float = OpTypeFloat 32
+    %v3float = OpTypeVector %float 3
+    %ptr_v3f = OpTypePointer Function %v3float
+         %fn = OpTypeFunction %void
+    %float_0 = OpConstant %float 0
+    %float_1 = OpConstant %float 1
+       %bool = OpTypeBool
+      %true = OpConstantTrue %bool
+    %v3_zero = OpConstantComposite %v3float %float_0 %float_0 %float_0
+       %main = OpFunction %void None %fn
+      %entry = OpLabel
+         %var = OpVariable %ptr_v3f Function
+               OpSelectionMerge %merge None
+               OpBranchConditional %true %then %merge
+       %then = OpLabel
+               OpStore %var %v3_zero
+               OpBranch %merge
+      %merge = OpLabel
+         %load = OpLoad %v3float %var
+               OpReturn
+               OpFunctionEnd
\ No newline at end of file
diff --git a/mlir/test/Target/SPIRV/ssa2.mlir b/mlir/test/Target/SPIRV/ssa2.mlir
new file mode 100644
index 00000000000000..e50d5cb571bbce
--- /dev/null
+++ b/mlir/test/Target/SPIRV/ssa2.mlir
@@ -0,0 +1,36 @@
+# RUN: split-file %s %t
+# RUN: spirv-as --target-env spv1.0 %t/spv.spvasm -o %t.spv
+# RUN: mlir-translate --deserialize-spirv %t.spv -o - | FileCheck %s
+// CHECK: module
+// CHECK: spirv.func @main
+// CHECK: spirv.Variable
+// CHECK: spirv.Return
+//--- spv.spvasm
+; SPIR-V
+; Version: 1.0
+; Generator: Khronos SPIR-V Tools Assembler; 0
+; Bound: 20
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+       %void = OpTypeVoid
+      %float = OpTypeFloat 32
+    %ptr_f = OpTypePointer Function %float
+         %fn = OpTypeFunction %void
+    %float_1 = OpConstant %float 1.0
+       %bool = OpTypeBool
+      %true = OpConstantTrue %bool
+       %main = OpFunction %void None %fn
+      %entry = OpLabel
+         %var = OpVariable %ptr_f Function
+               OpSelectionMerge %merge None
+               OpBranchConditional %true %then %merge
+       %then = OpLabel
+               OpStore %var %float_1
+               OpBranch %merge
+      %merge = OpLabel
+         %load = OpLoad %float %var
+               OpReturn
+               OpFunctionEnd
\ No newline at end of file

>From f61b88e59a45c2a9b74693bc7a76c894fc533e46 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 17:52:16 +0100
Subject: [PATCH 2/2] [mlir][spirv] Fix SSA handling in SPIR-V -> MLIR translation

---
 .../SPIRV/Deserialization/Deserializer.cpp    | 22 ++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 04469f1933819b..ee62b7da66fc22 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1937,8 +1937,16 @@ LogicalResult ControlFlowStructurizer::structurize() {
                  << "[cf] block " << block << " is a function entry block\n");
     }
 
-    for (auto &op : *block)
+    for (auto &op : *block) {
+      if (auto varOp = dyn_cast<spirv::VariableOp>(op)) {
+        if (varOp.getStorageClass() == spirv::StorageClass::Function) { // This prevents %1 variable duplication in composite4anti
+          // For function-scoped variables, ensure proper mapping but maintain their original location
+          mapper.map(&op, &op);
+          continue;
+        }
+      }
       newBlock->push_back(op.clone(mapper));
+    }
   }
 
   // Go through all ops and remap the operands.
@@ -2006,6 +2014,12 @@ LogicalResult ControlFlowStructurizer::structurize() {
   // the SelectionOp/LoopOp's region, there is no escape for it:
   // SelectionOp/LooOp does not support yield values right now.
   for (auto *block : constructBlocks) {
+    block->walk([&](spirv::VariableOp varOp) {
+      if (varOp.getStorageClass() == spirv::StorageClass::Function) {
+        // Move function variables to the entry block to preserve their lifetime
+        varOp->moveBefore(&body.front().front());
+      }
+    });
     for (Operation &op : *block)
       if (!op.use_empty())
         return op.emitOpError(
@@ -2070,6 +2084,12 @@ LogicalResult ControlFlowStructurizer::structurize() {
     }
   }
 
+  if (auto selectionOp = llvm::dyn_cast<spirv::SelectionOp>(op)) {
+    selectionOp.walk([&](spirv::VariableOp varOp) {
+      varOp->moveBefore(&op->getParentRegion()->front().front());
+    });
+  }
+
   LLVM_DEBUG(logger.startLine()
              << "[cf] after structurizing construct with header block "
              << headerBlock << ":\n"



More information about the Mlir-commits mailing list