[Mlir-commits] [mlir] [mlir][spirv][cf] legalize block arguments when converting cf to spirv (PR #71288)

Xiang Li llvmlistbot at llvm.org
Sat Nov 4 16:28:51 PDT 2023


https://github.com/python3kgae updated https://github.com/llvm/llvm-project/pull/71288

>From 65b3e61ee9fd32062ea92bd380b46618ccda0704 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Sat, 4 Nov 2023 12:13:11 -0400
Subject: [PATCH 1/2] [mlir][spirv][cf] legalize block arguments when
 converting cf to spirv

When converting branches, legalize the target block arguments first.

Fixes llvm#70813
---
 .../ControlFlowToSPIRV/ControlFlowToSPIRV.cpp | 42 +++++++++----------
 .../ControlFlowToSPIRV/cf-ops-to-spirv.mlir   |  6 ++-
 2 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index 5dba79016120b38..c24a5b2e5b05d50 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -28,20 +28,23 @@
 
 using namespace mlir;
 
-/// Checks that the target block arguments are legal.
-static LogicalResult checkBlockArguments(Block &block, Operation *op,
-                                         PatternRewriter &rewriter,
-                                         const TypeConverter &converter) {
-  for (BlockArgument arg : block.getArguments()) {
-    if (!converter.isLegal(arg.getType())) {
-      return rewriter.notifyMatchFailure(
-          op,
-          llvm::formatv(
-              "failed to match, destination argument not legalized (found {0})",
-              arg));
-    }
+/// Legailze target block arguments.
+static void legalizeBlockArguments(Block &block,
+                                   const TypeConverter &converter) {
+  auto builder = OpBuilder::atBlockBegin(&block);
+  for (unsigned i = 0; i < block.getNumArguments(); ++i) {
+    const auto arg = block.getArgument(i);
+    if (converter.isLegal(arg.getType()))
+      continue;
+    unsigned argNum = arg.getArgNumber();
+    Location loc = arg.getLoc();
+    Type ty = arg.getType();
+    Type newTy = converter.convertType(ty);
+    Value newArg = block.insertArgument(argNum, newTy, loc);
+    auto cast = builder.create<UnrealizedConversionCastOp>(loc, ty, newArg);
+    arg.replaceAllUsesWith(cast.getResult(0));
+    block.eraseArgument(argNum + 1);
   }
-  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -56,9 +59,7 @@ struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
   LogicalResult
   matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (failed(checkBlockArguments(*op.getDest(), op, rewriter,
-                                   *getTypeConverter())))
-      return failure();
+    legalizeBlockArguments(*op.getDest(), *getTypeConverter());
 
     rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
                                                  adaptor.getDestOperands());
@@ -73,13 +74,8 @@ struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
   LogicalResult
   matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (failed(checkBlockArguments(*op.getTrueDest(), op, rewriter,
-                                   *getTypeConverter())))
-      return failure();
-
-    if (failed(checkBlockArguments(*op.getFalseDest(), op, rewriter,
-                                   *getTypeConverter())))
-      return failure();
+    legalizeBlockArguments(*op.getTrueDest(), *getTypeConverter());
+    legalizeBlockArguments(*op.getFalseDest(), *getTypeConverter());
 
     rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
         op, adaptor.getCondition(), op.getTrueDest(),
diff --git a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
index a10a1468e582499..4e2f6a12c8e7ee4 100644
--- a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
@@ -39,16 +39,20 @@ func.func @simple_loop(%begin: i32, %end: i32, %step: i32) {
 
 // -----
 
-// TODO: We should handle blocks whose arguments require type conversion.
+// Handle blocks whose arguments require type conversion.
 
 // CHECK-LABEL: func.func @main_graph
 func.func @main_graph(%arg0: index) {
   %c3 = arith.constant 1 : index
+// CHECK:  spirv.Branch ^bb1({{.*}} : i32)
   cf.br ^bb1(%arg0 : index)
+// CHECK:      ^bb1({{.*}}: i32):       // 2 preds: ^bb0, ^bb2
 ^bb1(%0: index):  // 2 preds: ^bb0, ^bb2
   %1 = arith.cmpi slt, %0, %c3 : index
+// CHECK:        spirv.BranchConditional {{.*}}, ^bb2, ^bb3
   cf.cond_br %1, ^bb2, ^bb3
 ^bb2:  // pred: ^bb1
+// CHECK:  spirv.Branch ^bb1({{.*}} : i32)
   cf.br ^bb1(%c3 : index)
 ^bb3:  // pred: ^bb1
   return

>From e67457266e182f5451fdf7035cf260d03e2fecdf Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Sat, 4 Nov 2023 19:28:39 -0400
Subject: [PATCH 2/2] Handle conversion failures and use
 materializeSourceConversion.

---
 .../ControlFlowToSPIRV/ControlFlowToSPIRV.cpp | 37 ++++++++++++++-----
 1 file changed, 28 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index c24a5b2e5b05d50..8907756d33845a0 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -29,22 +29,34 @@
 using namespace mlir;
 
 /// Legailze target block arguments.
-static void legalizeBlockArguments(Block &block,
-                                   const TypeConverter &converter) {
+static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
+                                            PatternRewriter &rewriter,
+                                            const TypeConverter &converter) {
   auto builder = OpBuilder::atBlockBegin(&block);
   for (unsigned i = 0; i < block.getNumArguments(); ++i) {
     const auto arg = block.getArgument(i);
     if (converter.isLegal(arg.getType()))
       continue;
-    unsigned argNum = arg.getArgNumber();
-    Location loc = arg.getLoc();
     Type ty = arg.getType();
     Type newTy = converter.convertType(ty);
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("failed to legalize type for argument {0})", arg));
+    }
+    unsigned argNum = arg.getArgNumber();
+    Location loc = arg.getLoc();
     Value newArg = block.insertArgument(argNum, newTy, loc);
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, ty, newArg);
-    arg.replaceAllUsesWith(cast.getResult(0));
+    Value convertedValue = converter.materializeSourceConversion(
+        builder, op->getLoc(), ty, newArg);
+    if (!convertedValue) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("failed to cast new argument {0} to type {1})",
+                            newArg, ty));
+    }
+    arg.replaceAllUsesWith(convertedValue);
     block.eraseArgument(argNum + 1);
   }
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -59,7 +71,9 @@ struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
   LogicalResult
   matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    legalizeBlockArguments(*op.getDest(), *getTypeConverter());
+    if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter,
+                                      *getTypeConverter())))
+      return failure();
 
     rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
                                                  adaptor.getDestOperands());
@@ -74,8 +88,13 @@ struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
   LogicalResult
   matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    legalizeBlockArguments(*op.getTrueDest(), *getTypeConverter());
-    legalizeBlockArguments(*op.getFalseDest(), *getTypeConverter());
+    if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter,
+                                      *getTypeConverter())))
+      return failure();
+
+    if (failed(legalizeBlockArguments(*op.getFalseDest(), op, rewriter,
+                                      *getTypeConverter())))
+      return failure();
 
     rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
         op, adaptor.getCondition(), op.getTrueDest(),



More information about the Mlir-commits mailing list