[Mlir-commits] [mlir] 448adfe - [mlir] Only conditionally lower CF branching ops to LLVM

Tres Popp llvmlistbot at llvm.org
Thu Aug 4 07:37:09 PDT 2022


Author: Tres Popp
Date: 2022-08-04T16:36:27+02:00
New Revision: 448adfee05b737a26dda34e7ae2cd4948760fff0

URL: https://github.com/llvm/llvm-project/commit/448adfee05b737a26dda34e7ae2cd4948760fff0
DIFF: https://github.com/llvm/llvm-project/commit/448adfee05b737a26dda34e7ae2cd4948760fff0.diff

LOG: [mlir] Only conditionally lower CF branching ops to LLVM

Previously, cf.br, cf.cond_br and cf.switch always lowered to their LLVM
equivalents. Each of these ops passes values of given types to one or
more successor blocks, whose argument lists must have the same types. If
the types are not the same, verification fails later. This led to
confusion: everything works when both the func->llvm and cf->llvm
lowerings run, because func->llvm updates the blocks and their argument
lists while cf->llvm updates the branching ops. Without func->llvm,
however, the types can mismatch.
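
A minimal standalone sketch of the mismatch described above. It assumes a
string-based model of types rather than the real MLIR API: `convertType`,
`convertAll` and `branchVerifies` are hypothetical names, and the index to
i64 mapping assumes a 64-bit index width.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical, string-based model of the problem; not the MLIR API.
// A "type" is just its textual name.
using Type = std::string;

// Stand-in for the shared LLVM type conversion (assumption: 64-bit index).
// `index` becomes `i64`; every other type is left unchanged.
Type convertType(const Type &t) {
  static const std::map<Type, Type> table = {{"index", "i64"}};
  auto it = table.find(t);
  return it == table.end() ? t : it->second;
}

// Converts every type in a list, like converting a block signature or a
// branch's operand list.
std::vector<Type> convertAll(const std::vector<Type> &types) {
  std::vector<Type> out;
  for (const Type &t : types)
    out.push_back(convertType(t));
  return out;
}

// A branch is well formed only if its operand types equal the argument
// types of the block it jumps to, which is what verification checks later.
bool branchVerifies(const std::vector<Type> &operandTypes,
                    const std::vector<Type> &blockArgTypes) {
  return operandTypes == blockArgTypes;
}
```

Converting only the operand types (cf->llvm alone) makes `branchVerifies`
return false for an `index` operand jumping to an `index` block argument,
while converting both the operands and the block arguments (func->llvm plus
cf->llvm) makes it return true.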

With this change, the CF ops are lowered only if the result will pass
verification. This is possible because the parent op and its blocks are
updated before the branching ops they contain, so each branching op can
check its new operand types against the argument types of the blocks it
jumps to.
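
The guard can be sketched as follows for an op with several successors,
such as cf.cond_br or cf.switch. This is a hypothetical model, not MLIR
code: `Successor`, `checkSuccessors` and `tryLower` are illustrative names,
and types are plain strings assumed to be already converted.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model: each successor carries the (already converted)
// operand types passed to it and the argument types of its destination
// block, as updated (or not) by the parent op's lowering.
struct Successor {
  std::vector<std::string> operandTypes;
  std::vector<std::string> blockArgTypes;
};

// Returns a diagnostic for the first mismatched pair, or std::nullopt if
// every successor would pass verification.
std::optional<std::string>
checkSuccessors(const std::vector<Successor> &succs) {
  for (std::size_t s = 0; s < succs.size(); ++s) {
    const Successor &succ = succs[s];
    // Walk operand/argument pairs like llvm::zip (stops at the shorter).
    for (std::size_t i = 0;
         i < succ.operandTypes.size() && i < succ.blockArgTypes.size(); ++i) {
      if (succ.operandTypes[i] != succ.blockArgTypes[i])
        return "successor #" + std::to_string(s) +
               ": mismatched types from operand # " + std::to_string(i) +
               " " + succ.operandTypes[i] +
               " not compatible with destination block argument type " +
               succ.blockArgTypes[i];
    }
  }
  return std::nullopt;
}

// Either "rewrites" the op (returns true) or leaves it untouched and
// records why (returns false), like a pattern reporting a match failure.
bool tryLower(const std::vector<Successor> &succs, std::string &diag) {
  if (auto err = checkSuccessors(succs)) {
    diag = *err;
    return false;
  }
  return true;
}
```

Returning false corresponds to the pattern failing to match, so the
original cf op stays in the IR instead of being replaced by an LLVM op that
would fail verification.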

Another plan was to have func->llvm update only the entry block
signature and let cf->llvm update all other blocks, but this had
two problems:
1. It would require a FuncOp lowering inside the cf->llvm lowering,
   which is awkward.
2. This new pattern would only be applied if the containing FuncOp is
   marked invalid. This is infeasible with the shared LLVM type
   conversion/target infrastructure.

See previous discussions at
https://discourse.llvm.org/t/lowering-cf-to-llvm/63863 and
https://github.com/llvm/llvm-project/issues/55301

Differential Revision: https://reviews.llvm.org/D130971

Added: 
    mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir

Modified: 
    mlir/docs/TargetLLVMIR.md
    mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md
index bf639b81dd64f..5ba7adc2463ae 100644
--- a/mlir/docs/TargetLLVMIR.md
+++ b/mlir/docs/TargetLLVMIR.md
@@ -16,6 +16,14 @@ are expected to closely match the corresponding LLVM IR instructions and
 intrinsics. This minimizes the dependency on LLVM IR libraries in MLIR as well
 as reduces the churn in case of changes.
 
+Note that many different dialects can be lowered to LLVM but are provided as
+different sets of patterns and have different passes available to mlir-opt.
+However, this is primarily useful for testing and prototyping, and using the
+collection of patterns together is highly recommended. One place this is
+important and visible is the ControlFlow dialect's branching operations which
+will fail to apply if their types mismatch with the blocks they jump to in the
+parent op.
+
 SPIR-V to LLVM dialect conversion has a
 [dedicated document](SPIRVToLLVMDialectConversion.md).
 

diff  --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index cc97ef73d7bfa..89012704541d6 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/StringRef.h"
 #include <functional>
 
 using namespace mlir;
@@ -71,34 +72,108 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
   }
 };
 
-// Base class for LLVM IR lowering terminator operations with successors.
-template <typename SourceOp, typename TargetOp>
-struct OneToOneLLVMTerminatorLowering
-    : public ConvertOpToLLVMPattern<SourceOp> {
-  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
-  using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
+/// The cf->LLVM lowerings for branching ops require that the blocks they jump
+/// to first have updated types which should be handled by a pattern operating
+/// on the parent op.
+static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
+                                          ValueRange operands,
+                                          ValueRange blockArgs, Location loc,
+                                          llvm::StringRef messagePrefix) {
+  for (const auto &idxAndTypes :
+       llvm::enumerate(llvm::zip(blockArgs, operands))) {
+    int64_t i = idxAndTypes.index();
+    Value argValue =
+        rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
+    Type operandType = std::get<1>(idxAndTypes.value()).getType();
+    // In the case of an invalid jump, the block argument will have been
+    // remapped to an UnrealizedConversionCast. In the case of a valid jump,
+    // there might still be a no-op conversion cast with both types being equal.
+    // Consider both of these details to see if the jump would be invalid.
+    if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
+            argValue.getDefiningOp())) {
+      if (op.getOperandTypes().front() != operandType) {
+        return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
+          diag << messagePrefix;
+          diag << "mismatched types from operand # " << i << " ";
+          diag << operandType;
+          diag << " not compatible with destination block argument type ";
+          diag << argValue.getType();
+          diag << " which should be converted with the parent op.";
+        });
+      }
+    }
+  }
+  return success();
+}
+
+/// Ensure that all block types were updated and then create an LLVM::BrOp
+struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
+  using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+  matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
-                                          op->getSuccessors(), op->getAttrs());
+    if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
+                                    op.getSuccessor()->getArguments(),
+                                    op.getLoc(),
+                                    /*messagePrefix=*/"")))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::BrOp>(
+        op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
     return success();
   }
 };
 
-// FIXME: this should be tablegen'ed as well.
-struct BranchOpLowering
-    : public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
-  using Base::Base;
-};
-struct CondBranchOpLowering
-    : public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
-  using Base::Base;
+/// Ensure that all block types were updated and then create an LLVM::CondBrOp
+struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
+  using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::CondBranchOp op,
+                  typename cf::CondBranchOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
+                                    op.getFalseDest()->getArguments(),
+                                    op.getLoc(), "in false case branch ")))
+      return failure();
+    if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
+                                    op.getTrueDest()->getArguments(),
+                                    op.getLoc(), "in true case branch ")))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
+        op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
+    return success();
+  }
 };
-struct SwitchOpLowering
-    : public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
-  using Base::Base;
+
+/// Ensure that all block types were updated and then create an LLVM::SwitchOp
+struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
+  using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
+                                    op.getDefaultDestination()->getArguments(),
+                                    op.getLoc(), "in switch default case ")))
+      return failure();
+
+    for (const auto &i : llvm::enumerate(
+             llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
+      if (failed(verifyMatchingValues(
+              rewriter, std::get<0>(i.value()),
+              std::get<1>(i.value())->getArguments(), op.getLoc(),
+              "in switch case " + std::to_string(i.index()) + " "))) {
+        return failure();
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
+        op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
+    return success();
+  }
 };
 
 } // namespace

diff  --git a/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir b/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir
new file mode 100644
index 0000000000000..a2afa233a26e8
--- /dev/null
+++ b/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -convert-cf-to-llvm | FileCheck %s
+
+func.func @name(%flag: i32, %pred: i1){
+    // Test cf.br lowering failure with type mismatch
+    // CHECK: cf.br
+    %c0 = arith.constant 0 : index
+    cf.br ^bb1(%c0 : index)
+
+  // Test cf.cond_br lowering failure with type mismatch in false_dest
+  // CHECK: cf.cond_br
+  ^bb1(%0: index):  // 2 preds: ^bb0, ^bb2
+    %c1 = arith.constant 1 : i1
+    %c2 = arith.constant 1 : index
+    cf.cond_br %pred, ^bb2(%c1: i1), ^bb3(%c2: index)
+
+  // Test cf.cond_br lowering failure with type mismatch in true_dest
+  // CHECK: cf.cond_br
+  ^bb2(%1: i1):
+    %c3 = arith.constant 1 : i1
+    %c4 = arith.constant 1 : index
+    cf.cond_br %pred, ^bb3(%c4: index), ^bb2(%c3: i1)
+
+  // Test cf.switch lowering failure with type mismatch in default case
+  // CHECK: cf.switch
+  ^bb3(%2: index):  // pred: ^bb1
+    %c5 = arith.constant 1 : i1
+    %c6 = arith.constant 1 : index
+    cf.switch %flag : i32, [
+      default: ^bb1(%c6 : index),
+      42: ^bb4(%c5 : i1)
+    ]
+
+  // Test cf.switch lowering failure with type mismatch in non-default case
+  // CHECK: cf.switch
+  ^bb4(%3: i1):  // pred: ^bb1
+    %c7 = arith.constant 1 : i1
+    %c8 = arith.constant 1 : index
+    cf.switch %flag : i32, [
+      default: ^bb2(%c7 : i1),
+      41: ^bb1(%c8 : index)
+    ]
+  }


        

