[Mlir-commits] [mlir] [mlir][spirv][cf] Check destination block argument types (PR #70889)

Jakub Kuderski llvmlistbot at llvm.org
Tue Oct 31 20:13:49 PDT 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/70889

>From f54d81de121f9623ef3b088d619fd209e16b7022 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 31 Oct 2023 23:06:53 -0400
Subject: [PATCH 1/2] [mlir][spirv][cf] Check destination block argument types

Emit errors for branches whose destination block arguments have illegal types.

Properly addressing this requires region/block argument type conversion,
which is left for a follow-up.

Issue: https://github.com/llvm/llvm-project/issues/70813
---
 .../ControlFlowToSPIRV/ControlFlowToSPIRV.cpp | 39 +++++++++++++++----
 .../ControlFlowToSPIRVPass.cpp                |  4 +-
 .../ControlFlowToSPIRV/cf-ops-to-spirv.mlir   | 19 ++++++++-
 3 files changed, 52 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index 6787a5ccd3a4a87..6ba4388cb3b5825 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -19,25 +19,41 @@
 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "cf-to-spirv-pattern"
 
 using namespace mlir;
 
+/// Fails with an error on `branchOp` if any `block` argument type is illegal.
+static LogicalResult checkBlockArguments(Block &block, Operation *branchOp,
+                                         const TypeConverter &converter) {
+  for (BlockArgument arg : block.getArguments()) {
+    if (!converter.isLegal(arg.getType())) {
+      return branchOp->emitOpError(
+                 "failed to match, destination argument not legalized (found ")
+             << arg << ")";
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
 
 namespace {
-
 /// Converts cf.br to spirv.Branch.
-struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
-  using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
+struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (failed(checkBlockArguments(*op.getDest(), op, *getTypeConverter())))
+      return failure();
+    // Take the successor operands from the adaptor (remapped values).
     rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
                                                  adaptor.getDestOperands());
     return success();
@@ -45,16 +61,23 @@ struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
 };
 
 /// Converts cf.cond_br to spirv.BranchConditional.
-struct CondBranchOpPattern final
-    : public OpConversionPattern<cf::CondBranchOp> {
-  using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
+struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (failed(checkBlockArguments(*op.getTrueDest(), op, *getTypeConverter())))
+      return failure();
+
+    if (failed(
+            checkBlockArguments(*op.getFalseDest(), op, *getTypeConverter())))
+      return failure();
+    // Take the condition and operands from the adaptor (remapped values).
     rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
-        op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
-        op.getFalseDest(), adaptor.getFalseDestOperands());
+        op, adaptor.getCondition(), op.getTrueDest(),
+        adaptor.getTrueDestOperands(), op.getFalseDest(),
+        adaptor.getFalseDestOperands());
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index d8aecae257b461e..a752b82eac7c343 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -25,7 +25,7 @@ using namespace mlir;
 
 namespace {
 /// A pass converting MLIR ControlFlow operations into the SPIR-V dialect.
-class ConvertControlFlowToSPIRVPass
+class ConvertControlFlowToSPIRVPass final
     : public impl::ConvertControlFlowToSPIRVBase<
           ConvertControlFlowToSPIRVPass> {
   void runOnOperation() override;
@@ -44,6 +44,8 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
   options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 
+  // TODO: We should also take care of block argument type conversion.
+
   RewritePatternSet patterns(context);
   cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
 
diff --git a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
index 98ee6891453496d..0e078dc43208c7b 100644
--- a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-cf-to-spirv -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-cf-to-spirv --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // cf.br, cf.cond_br
@@ -36,3 +36,20 @@ func.func @simple_loop(%begin: i32, %end: i32, %step: i32) {
 }
 
 }
+
+// -----
+
+// TODO: Convert block argument types instead of rejecting branches to such blocks.
+
+func.func @main_graph(%arg0: index) {
+  %c3 = arith.constant 1 : index
+  cf.br ^bb1(%arg0 : index)
+// expected-error@+1 {{'cf.br' op failed to match, destination argument not legalized}}
+^bb1(%0: index):  // 2 preds: ^bb0, ^bb2
+  %1 = arith.cmpi slt, %0, %c3 : index
+  cf.cond_br %1, ^bb2, ^bb3
+^bb2:  // pred: ^bb1
+  cf.br ^bb1(%c3 : index)
+^bb3:  // pred: ^bb1
+  return
+}

>From 54cab9910deafe85f2468928b54b77f923e4c0dc Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 31 Oct 2023 23:13:37 -0400
Subject: [PATCH 2/2] fix test

---
 mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
index 0e078dc43208c7b..7303ea6e9ca6242 100644
--- a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
@@ -43,12 +43,13 @@ func.func @simple_loop(%begin: i32, %end: i32, %step: i32) {
 
 func.func @main_graph(%arg0: index) {
   %c3 = arith.constant 1 : index
-  cf.br ^bb1(%arg0 : index)
 // expected-error@+1 {{'cf.br' op failed to match, destination argument not legalized}}
+  cf.br ^bb1(%arg0 : index)
 ^bb1(%0: index):  // 2 preds: ^bb0, ^bb2
   %1 = arith.cmpi slt, %0, %c3 : index
   cf.cond_br %1, ^bb2, ^bb3
 ^bb2:  // pred: ^bb1
+// expected-error@+1 {{'cf.br' op failed to match, destination argument not legalized}}
   cf.br ^bb1(%c3 : index)
 ^bb3:  // pred: ^bb1
   return



More information about the Mlir-commits mailing list