[Mlir-commits] [mlir] [mlir][LLVM] `ControlFlowToLLVM`: Add 1:N type conversion support (PR #153937)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Aug 16 01:28:17 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add support for 1:N type conversions to the `ControlFlowToLLVM` lowering patterns for `cf.br` and `cf.cond_br`. The `cf.switch` and `cf.assert` lowering patterns are not changed by this PR.
---
Full diff: https://github.com/llvm/llvm-project/pull/153937.diff
3 Files Affected:
- (modified) mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (+35-11)
- (modified) mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir (+17)
- (modified) mlir/test/lib/Dialect/LLVM/TestPatterns.cpp (+2)
``````````diff
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index ff6d369176393..fa0023d6a0621 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -125,22 +125,33 @@ static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
return rewriter.applySignatureConversion(block, *conversion, converter);
}
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const auto &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
/// Convert the destination block signature (if necessary) and lower the branch
/// op to llvm.br.
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
FailureOr<Block *> convertedBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
- TypeRange(adaptor.getOperands()));
+ TypeRange(flattenedAdaptor));
if (failed(convertedBlock))
return failure();
DictionaryAttr attrs = op->getAttrDictionary();
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
- op, adaptor.getOperands(), *convertedBlock);
+ op, flattenedAdaptor, *convertedBlock);
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
newOp->setAttrs(attrs);
@@ -152,29 +163,42 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
/// branch op to llvm.cond_br.
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::CondBranchOp op,
- typename cf::CondBranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptorTrue =
+ flattenValues(adaptor.getTrueDestOperands());
+ SmallVector<Value> flattenedAdaptorFalse =
+ flattenValues(adaptor.getFalseDestOperands());
+ if (!llvm::hasSingleElement(adaptor.getCondition()))
+ return rewriter.notifyMatchFailure(op,
+ "expected single element condition");
FailureOr<Block *> convertedTrueBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
- TypeRange(adaptor.getTrueDestOperands()));
+ TypeRange(flattenedAdaptorTrue));
if (failed(convertedTrueBlock))
return failure();
FailureOr<Block *> convertedFalseBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
- TypeRange(adaptor.getFalseDestOperands()));
+ TypeRange(flattenedAdaptorFalse));
if (failed(convertedFalseBlock))
return failure();
DictionaryAttr attrs = op->getAttrDictionary();
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
- op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
- adaptor.getFalseDestOperands(), op.getBranchWeightsAttr(),
+ op, llvm::getSingleElement(adaptor.getCondition()),
+ flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
*convertedTrueBlock, *convertedFalseBlock);
// TODO: We should not just forward all attributes like that. But there are
- // existing Flang tests that depend on this behavior.
- newOp->setAttrs(attrs);
+ // existing Flang tests that depend on this behavior. E.g., it is incorrect
+ // to forward the `operandSegmentSizes` attribute. We cannot hard-code all
+ // attributes that must be excluded from forwarding.
+ for (NamedAttribute attr : attrs) {
+ if (attr.getName() != cf::CondBranchOp::getOperandSegmentSizeAttr())
+ newOp->setAttr(attr.getName(), attr.getValue());
+ }
return success();
}
};
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
index c1751f282b002..6c6756f5097b4 100644
--- a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -138,3 +138,20 @@ func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) {
%res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1)
return %res#0, %res#1, %res#0 : i17, i1, i17
}
+
+// -----
+
+// CHECK-LABEL: llvm.func @branch(
+// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18)
+// CHECK: llvm.br ^[[bb1:.*]](%[[arg1]], %[[arg2]], %[[arg0]] : i18, i18, i1)
+// CHECK: ^[[bb1]](%[[arg3:.*]]: i18, %[[arg4:.*]]: i18, %[[arg5:.*]]: i1):
+// CHECK: llvm.cond_br %[[arg5]], ^[[bb1]](%[[arg1]], %[[arg2]], %[[arg5]] : i18, i18, i1), ^[[bb2:.*]](%[[arg3]], %[[arg4]] : i18, i18)
+// CHECK: ^bb2(%{{.*}}: i18, %{{.*}}: i18):
+// CHECK: llvm.return
+func.func @branch(%arg0: i1, %arg1: i17) {
+ cf.br ^bb1(%arg1, %arg0: i17, i1)
+^bb1(%arg2: i17, %arg3: i1):
+ cf.cond_br %arg3, ^bb1(%arg1, %arg3 : i17, i1), ^bb2(%arg2 : i17)
+^bb2(%arg4: i17):
+ return
+}
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
index fe9aa0f2a9902..9d30ae43cccc1 100644
--- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -70,6 +71,7 @@ struct TestLLVMLegalizePatternsPass
mlir::RewritePatternSet patterns(ctx);
patterns.add<TestDirectReplacementOp>(ctx, converter);
populateFuncToLLVMConversionPatterns(converter, patterns);
+ cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
// Define the conversion target used for the test.
ConversionTarget target(*ctx);
``````````
</details>
https://github.com/llvm/llvm-project/pull/153937
More information about the Mlir-commits mailing list