[Mlir-commits] [mlir] [mlir][LLVM] `ControlFlowToLLVM`: Add 1:N type conversion support (PR #153937)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Aug 16 01:28:17 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Add support for 1:N type conversions to the `ControlFlowToLLVM` lowering patterns. This 1:N support does not apply to the `cf.switch` and `cf.assert` lowerings.


---
Full diff: https://github.com/llvm/llvm-project/pull/153937.diff


3 Files Affected:

- (modified) mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (+35-11) 
- (modified) mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir (+17) 
- (modified) mlir/test/lib/Dialect/LLVM/TestPatterns.cpp (+2) 


``````````diff
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index ff6d369176393..fa0023d6a0621 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -125,22 +125,33 @@ static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
   return rewriter.applySignatureConversion(block, *conversion, converter);
 }
 
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const auto &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
 /// Convert the destination block signature (if necessary) and lower the branch
 /// op to llvm.br.
 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
   using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
+  using Adaptor =
+      typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
 
   LogicalResult
-  matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
+  matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
     FailureOr<Block *> convertedBlock =
         getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
-                          TypeRange(adaptor.getOperands()));
+                          TypeRange(flattenedAdaptor));
     if (failed(convertedBlock))
       return failure();
     DictionaryAttr attrs = op->getAttrDictionary();
     Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
-        op, adaptor.getOperands(), *convertedBlock);
+        op, flattenedAdaptor, *convertedBlock);
     // TODO: We should not just forward all attributes like that. But there are
     // existing Flang tests that depend on this behavior.
     newOp->setAttrs(attrs);
@@ -152,29 +163,42 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
 /// branch op to llvm.cond_br.
 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
   using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
+  using Adaptor =
+      typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
 
   LogicalResult
-  matchAndRewrite(cf::CondBranchOp op,
-                  typename cf::CondBranchOp::Adaptor adaptor,
+  matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedAdaptorTrue =
+        flattenValues(adaptor.getTrueDestOperands());
+    SmallVector<Value> flattenedAdaptorFalse =
+        flattenValues(adaptor.getFalseDestOperands());
+    if (!llvm::hasSingleElement(adaptor.getCondition()))
+      return rewriter.notifyMatchFailure(op,
+                                         "expected single element condition");
     FailureOr<Block *> convertedTrueBlock =
         getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
-                          TypeRange(adaptor.getTrueDestOperands()));
+                          TypeRange(flattenedAdaptorTrue));
     if (failed(convertedTrueBlock))
       return failure();
     FailureOr<Block *> convertedFalseBlock =
         getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
-                          TypeRange(adaptor.getFalseDestOperands()));
+                          TypeRange(flattenedAdaptorFalse));
     if (failed(convertedFalseBlock))
       return failure();
     DictionaryAttr attrs = op->getAttrDictionary();
     auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
-        op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
-        adaptor.getFalseDestOperands(), op.getBranchWeightsAttr(),
+        op, llvm::getSingleElement(adaptor.getCondition()),
+        flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
         *convertedTrueBlock, *convertedFalseBlock);
     // TODO: We should not just forward all attributes like that. But there are
-    // existing Flang tests that depend on this behavior.
-    newOp->setAttrs(attrs);
+    // existing Flang tests that depend on this behavior. E.g., it is incorrect
+    // to forward the `operandSegmentSizes` attribute. We cannot hard-code all
+    // attributes that must be excluded from forwarding.
+    for (NamedAttribute attr : attrs) {
+      if (attr.getName() != cf::CondBranchOp::getOperandSegmentSizeAttr())
+        newOp->setAttr(attr.getName(), attr.getValue());
+    }
     return success();
   }
 };
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
index c1751f282b002..6c6756f5097b4 100644
--- a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -138,3 +138,20 @@ func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) {
   %res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1)
   return %res#0, %res#1, %res#0 : i17, i1, i17
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @branch(
+//  CHECK-SAME:     %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18)
+//       CHECK:   llvm.br ^[[bb1:.*]](%[[arg1]], %[[arg2]], %[[arg0]] : i18, i18, i1)
+//       CHECK: ^[[bb1]](%[[arg3:.*]]: i18, %[[arg4:.*]]: i18, %[[arg5:.*]]: i1):
+//       CHECK:   llvm.cond_br %[[arg5]], ^[[bb1]](%[[arg1]], %[[arg2]], %[[arg5]] : i18, i18, i1), ^[[bb2:.*]](%[[arg3]], %[[arg4]] : i18, i18)
+//       CHECK: ^bb2(%{{.*}}: i18, %{{.*}}: i18):
+//       CHECK:   llvm.return
+func.func @branch(%arg0: i1, %arg1: i17) {
+  cf.br ^bb1(%arg1, %arg0: i17, i1)
+^bb1(%arg2: i17, %arg3: i1):
+  cf.cond_br %arg3, ^bb1(%arg1, %arg3 : i17, i1), ^bb2(%arg2 : i17)
+^bb2(%arg4: i17):
+  return
+}
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
index fe9aa0f2a9902..9d30ae43cccc1 100644
--- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -70,6 +71,7 @@ struct TestLLVMLegalizePatternsPass
     mlir::RewritePatternSet patterns(ctx);
     patterns.add<TestDirectReplacementOp>(ctx, converter);
     populateFuncToLLVMConversionPatterns(converter, patterns);
+    cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
 
     // Define the conversion target used for the test.
     ConversionTarget target(*ctx);

``````````

</details>


https://github.com/llvm/llvm-project/pull/153937


More information about the Mlir-commits mailing list