[Mlir-commits] [mlir] 13644f0 - [mlir][spirv][cf] Check destination block argument types (#70889)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 1 21:34:55 PDT 2023


Author: Jakub Kuderski
Date: 2023-11-02T00:34:51-04:00
New Revision: 13644f0bda0b11968c70aa051d2c2455229c6970

URL: https://github.com/llvm/llvm-project/commit/13644f0bda0b11968c70aa051d2c2455229c6970
DIFF: https://github.com/llvm/llvm-project/commit/13644f0bda0b11968c70aa051d2c2455229c6970.diff

LOG: [mlir][spirv][cf] Check destination block argument types (#70889)

Do not match branch ops whose destination blocks have arguments with types
that the type converter does not consider legal. Also apply some minor
cleanups.

TODO: We should add region/block argument type conversions to properly
address this.

Issue: https://github.com/llvm/llvm-project/issues/70813

Added: 
    

Modified: 
    mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
    mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
    mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index 6787a5ccd3a4a87..5dba79016120b38 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -18,26 +18,48 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
 
 #define DEBUG_TYPE "cf-to-spirv-pattern"
 
 using namespace mlir;
 
+/// Checks that the target block arguments are legal.
+static LogicalResult checkBlockArguments(Block &block, Operation *op,
+                                         PatternRewriter &rewriter,
+                                         const TypeConverter &converter) {
+  for (BlockArgument arg : block.getArguments()) {
+    if (!converter.isLegal(arg.getType())) {
+      return rewriter.notifyMatchFailure(
+          op,
+          llvm::formatv(
+              "failed to match, destination argument not legalized (found {0})",
+              arg));
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
 
 namespace {
-
 /// Converts cf.br to spirv.Branch.
-struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
-  using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
+struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (failed(checkBlockArguments(*op.getDest(), op, rewriter,
+                                   *getTypeConverter())))
+      return failure();
+
     rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
                                                  adaptor.getDestOperands());
     return success();
@@ -45,16 +67,24 @@ struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
 };
 
 /// Converts cf.cond_br to spirv.BranchConditional.
-struct CondBranchOpPattern final
-    : public OpConversionPattern<cf::CondBranchOp> {
-  using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
+struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (failed(checkBlockArguments(*op.getTrueDest(), op, rewriter,
+                                   *getTypeConverter())))
+      return failure();
+
+    if (failed(checkBlockArguments(*op.getFalseDest(), op, rewriter,
+                                   *getTypeConverter())))
+      return failure();
+
     rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
-        op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
-        op.getFalseDest(), adaptor.getFalseDestOperands());
+        op, adaptor.getCondition(), op.getTrueDest(),
+        adaptor.getTrueDestOperands(), op.getFalseDest(),
+        adaptor.getFalseDestOperands());
     return success();
   }
 };

diff  --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index d8aecae257b461e..a752b82eac7c343 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -25,7 +25,7 @@ using namespace mlir;
 
 namespace {
 /// A pass converting MLIR ControlFlow operations into the SPIR-V dialect.
-class ConvertControlFlowToSPIRVPass
+class ConvertControlFlowToSPIRVPass final
     : public impl::ConvertControlFlowToSPIRVBase<
           ConvertControlFlowToSPIRVPass> {
   void runOnOperation() override;
@@ -44,6 +44,8 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
   options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 
+  // TODO: We should also take care of block argument type conversion.
+
   RewritePatternSet patterns(context);
   cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
 

diff  --git a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
index 98ee6891453496d..a10a1468e582499 100644
--- a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-cf-to-spirv -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-cf-to-spirv --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // cf.br, cf.cond_br
@@ -36,3 +36,20 @@ func.func @simple_loop(%begin: i32, %end: i32, %step: i32) {
 }
 
 }
+
+// -----
+
+// TODO: We should handle blocks whose arguments require type conversion.
+
+// CHECK-LABEL: func.func @main_graph
+func.func @main_graph(%arg0: index) {
+  %c3 = arith.constant 1 : index
+  cf.br ^bb1(%arg0 : index)
+^bb1(%0: index):  // 2 preds: ^bb0, ^bb2
+  %1 = arith.cmpi slt, %0, %c3 : index
+  cf.cond_br %1, ^bb2, ^bb3
+^bb2:  // pred: ^bb1
+  cf.br ^bb1(%c3 : index)
+^bb3:  // pred: ^bb1
+  return
+}


        


More information about the Mlir-commits mailing list