[Mlir-commits] [mlir] [mlir][spirv] SCFToSPIRV: fix WhileOp block args types conversion (PR #68588)

Ivan Butygin llvmlistbot at llvm.org
Mon Oct 9 07:11:31 PDT 2023


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/68588

The types of the WhileOp before/after region block arguments weren't converted, resulting in invalid IR.

>From 967d7a53eeac2182785aa5e418f17071169b6cf3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 9 Oct 2023 16:06:51 +0200
Subject: [PATCH] [mlir][spirv] SCFToSPIRV: fix WhileOp block args types
 conversion

WhileOp before/after block args types weren't converted, resulting in invalid IR.
---
 mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp |  9 ++-
 mlir/test/Conversion/SCFToSPIRV/while.mlir    | 58 +++++++++++++++++++
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index f6e3053b8ae6a96..e749f2bc101297d 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -344,11 +344,16 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
     auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
     loopOp.addEntryAndMergeBlock();
 
-    OpBuilder::InsertionGuard guard(rewriter);
-
     Region &beforeRegion = whileOp.getBefore();
     Region &afterRegion = whileOp.getAfter();
 
+    if (failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) ||
+        failed(rewriter.convertRegionTypes(&afterRegion, typeConverter)))
+      return rewriter.notifyMatchFailure(whileOp,
+                                         "Failed to convert region types");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+
     Block &entryBlock = *loopOp.getEntryBlock();
     Block &beforeBlock = beforeRegion.front();
     Block &afterBlock = afterRegion.front();
diff --git a/mlir/test/Conversion/SCFToSPIRV/while.mlir b/mlir/test/Conversion/SCFToSPIRV/while.mlir
index a7e07a7086034e4..ff455383a7f05cc 100644
--- a/mlir/test/Conversion/SCFToSPIRV/while.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/while.mlir
@@ -69,4 +69,62 @@ func.func @while_loop2(%arg0: f32) -> i64 {
   return %res : i64
 }
 
+// -----
+
+// CHECK-LABEL: @while_loop_before_typeconv
+func.func @while_loop_before_typeconv(%arg0: index) -> i64 {
+  // CHECK-SAME: (%[[ARG:.*]]: i32)
+  // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<i64, Function>
+  // CHECK: spirv.mlir.loop {
+  // CHECK:   spirv.Branch ^[[HEADER:.*]](%[[ARG]] : i32)
+  // CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: i32):
+  // CHECK:   spirv.BranchConditional %{{.*}}, ^[[BODY:.*]](%{{.*}} : i64), ^[[MERGE:.*]]
+  // CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i64):
+  // CHECK: spirv.Branch ^[[HEADER]](%{{.*}} : i32)
+  // CHECK: ^[[MERGE]]:
+  // CHECK:   spirv.mlir.merge
+  // CHECK: }
+  %res = scf.while (%arg1 = %arg0) : (index) -> i64 {
+    %shared = "foo.shared_compute"(%arg1) : (index) -> i64
+    %condition = "foo.evaluate_condition"(%arg1, %shared) : (index, i64) -> i1
+    scf.condition(%condition) %shared : i64
+  } do {
+  ^bb0(%arg2: i64):
+    %res = "foo.payload"(%arg2) : (i64) -> index
+    scf.yield %res : index
+  }
+  // CHECK: %[[OUT:.*]] = spirv.Load "Function" %[[VAR]] : i64
+  // CHECK: spirv.ReturnValue %[[OUT]] : i64
+  return %res : i64
+}
+
+// -----
+
+// CHECK-LABEL: @while_loop_after_typeconv
+func.func @while_loop_after_typeconv(%arg0: f32) -> index {
+  // CHECK-SAME: (%[[ARG:.*]]: f32)
+  // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<i32, Function>
+  // CHECK: spirv.mlir.loop {
+  // CHECK:   spirv.Branch ^[[HEADER:.*]](%[[ARG]] : f32)
+  // CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: f32):
+  // CHECK:   spirv.BranchConditional %{{.*}}, ^[[BODY:.*]](%{{.*}} : i32), ^[[MERGE:.*]]
+  // CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i32):
+  // CHECK: spirv.Branch ^[[HEADER]](%{{.*}} : f32)
+  // CHECK: ^[[MERGE]]:
+  // CHECK:   spirv.mlir.merge
+  // CHECK: }
+  %res = scf.while (%arg1 = %arg0) : (f32) -> index {
+    %shared = "foo.shared_compute"(%arg1) : (f32) -> index
+    %condition = "foo.evaluate_condition"(%arg1, %shared) : (f32, index) -> i1
+    scf.condition(%condition) %shared : index
+  } do {
+  ^bb0(%arg2: index):
+    %res = "foo.payload"(%arg2) : (index) -> f32
+    scf.yield %res : f32
+  }
+  // CHECK: %[[OUT:.*]] = spirv.Load "Function" %[[VAR]] : i32
+  // CHECK: spirv.ReturnValue %[[OUT]] : i32
+  return %res : index
+}
+
 } // end module



More information about the Mlir-commits mailing list