[Mlir-commits] [mlir] [mlir][CF] Add structural type conversion patterns (PR #165629)

Matthias Springer llvmlistbot at llvm.org
Wed Oct 29 15:05:42 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/165629

Add structural type conversion patterns for CF dialect ops. These patterns are similar to the SCF structural type conversion patterns.


>From d224c1becdfd18f9d68b1cf35c90e773922b7953 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 29 Oct 2025 22:03:27 +0000
Subject: [PATCH] [mlir][CF] Add structural type conversion patterns

---
 .../Transforms/StructuralTypeConversions.h    |  48 +++++
 .../ControlFlow/Transforms/CMakeLists.txt     |   1 +
 .../Transforms/StructuralTypeConversions.cpp  | 169 ++++++++++++++++++
 .../test-legalize-type-conversion.mlir        |  22 +++
 mlir/test/lib/Dialect/Test/CMakeLists.txt     |   1 +
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |   7 +
 6 files changed, 248 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
 create mode 100644 mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp

diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
new file mode 100644
index 0000000000000..a32d9e2025c76
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
@@ -0,0 +1,48 @@
+//===- StructuralTypeConversions.h - CF Type Conversions --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+#define MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+class ConversionTarget;
+class TypeConverter;
+
+namespace cf {
+
+/// Populates patterns for CF structural type conversions and sets up the
+/// provided ConversionTarget with the appropriate legality configuration for
+/// the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types. An example of this is cf.br -- the cf.br op and its successor
+/// block need to update their types according to the TypeConverter, but
+/// otherwise do not care what type conversions are happening.
+void populateCFStructuralTypeConversionsAndLegality(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, PatternBenefit benefit = 1);
+
+/// Similar to `populateCFStructuralTypeConversionsAndLegality` but does not
+/// populate the conversion target.
+void populateCFStructuralTypeConversions(const TypeConverter &typeConverter,
+                                         RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
+
+/// Marks cf.br, cf.cond_br and cf.switch as dynamically legal iff all of
+/// their operands are legal for `typeConverter` (which must outlive `target`).
+void populateCFStructuralTypeConversionTarget(
+    const TypeConverter &typeConverter, ConversionTarget &target);
+
+} // namespace cf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
index 47740d31844f4..e9da135ed46f9 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRControlFlowTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
+  StructuralTypeConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
new file mode 100644
index 0000000000000..5e2a742c2d64c
--- /dev/null
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,188 @@
+//===- StructuralTypeConversions.cpp - CF Structural Type Conversions -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements structural type conversion patterns for the CF dialect
+// terminators (cf.br, cf.cond_br, cf.switch) and their legality configuration.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
+
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+/// Helper function for converting branch ops. This function converts the
+/// signature of the given successor block. If the new block signature is
+/// different from `expectedTypes` (the types of the flattened, converted
+/// branch operands), returns "failure".
+static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
+                                            const TypeConverter *converter,
+                                            Operation *branchOp, Block *block,
+                                            TypeRange expectedTypes) {
+  assert(converter && "expected non-null type converter");
+  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+  // There is nothing to do if the types already match.
+  if (block->getArgumentTypes() == expectedTypes)
+    return block;
+
+  // Compute the new block argument types and convert the block.
+  std::optional<TypeConverter::SignatureConversion> conversion =
+      converter->convertBlockSignature(block);
+  if (!conversion)
+    return rewriter.notifyMatchFailure(branchOp,
+                                       "could not compute block signature");
+  if (expectedTypes != conversion->getConvertedTypes())
+    return rewriter.notifyMatchFailure(
+        branchOp,
+        "mismatch between adaptor operand types and computed block signature");
+  return rewriter.applySignatureConversion(block, *conversion, converter);
+}
+
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const ValueRange &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
+namespace {
+
+/// Convert the destination block signature (if necessary) and change the
+/// operands of the branch op.
+struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
+  using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
+    FailureOr<Block *> convertedBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
+                          TypeRange(ValueRange(flattenedAdaptor)));
+    if (failed(convertedBlock))
+      return failure();
+    rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
+                                              *convertedBlock);
+    return success();
+  }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the branch op.
+struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
+  using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedAdaptorTrue =
+        flattenValues(adaptor.getTrueDestOperands());
+    SmallVector<Value> flattenedAdaptorFalse =
+        flattenValues(adaptor.getFalseDestOperands());
+    if (!llvm::hasSingleElement(adaptor.getCondition()))
+      return rewriter.notifyMatchFailure(op,
+                                         "expected single element condition");
+    FailureOr<Block *> convertedTrueBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
+                          TypeRange(ValueRange(flattenedAdaptorTrue)));
+    if (failed(convertedTrueBlock))
+      return failure();
+    FailureOr<Block *> convertedFalseBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
+                          TypeRange(ValueRange(flattenedAdaptorFalse)));
+    if (failed(convertedFalseBlock))
+      return failure();
+    rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
+        op, llvm::getSingleElement(adaptor.getCondition()),
+        flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
+        *convertedTrueBlock, *convertedFalseBlock);
+    return success();
+  }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the switch op. Like the other branch patterns, this pattern
+/// accepts 1:N type conversions of the forwarded successor operands.
+struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
+  using OpConversionPattern<cf::SwitchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::SwitchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!llvm::hasSingleElement(adaptor.getFlag()))
+      return rewriter.notifyMatchFailure(op, "expected single element flag");
+
+    // Get or convert default block.
+    SmallVector<Value> defaultOperands =
+        flattenValues(adaptor.getDefaultOperands());
+    FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
+        rewriter, getTypeConverter(), op, op.getDefaultDestination(),
+        TypeRange(ValueRange(defaultOperands)));
+    if (failed(convertedDefaultBlock))
+      return failure();
+
+    // Flatten the operands of every case up front: the ValueRanges collected in
+    // `caseOperands` below point into these vectors, which must not move.
+    SmallVector<SmallVector<Value>> flatCaseOperands;
+    for (ArrayRef<ValueRange> operands : adaptor.getCaseOperands())
+      flatCaseOperands.push_back(flattenValues(operands));
+
+    // Get or convert all case blocks.
+    SmallVector<Block *> caseDestinations;
+    SmallVector<ValueRange> caseOperands;
+    for (auto [dest, operands] :
+         llvm::zip_equal(op.getCaseDestinations(), flatCaseOperands)) {
+      FailureOr<Block *> convertedBlock =
+          getConvertedBlock(rewriter, getTypeConverter(), op, dest,
+                            TypeRange(ValueRange(operands)));
+      if (failed(convertedBlock))
+        return failure();
+      caseDestinations.push_back(*convertedBlock);
+      caseOperands.push_back(ValueRange(operands));
+    }
+
+    rewriter.replaceOpWithNewOp<cf::SwitchOp>(
+        op, llvm::getSingleElement(adaptor.getFlag()), *convertedDefaultBlock,
+        defaultOperands, op.getCaseValuesAttr(), caseDestinations,
+        caseOperands);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::cf::populateCFStructuralTypeConversions(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
+  patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
+      typeConverter, patterns.getContext(), benefit);
+}
+
+void mlir::cf::populateCFStructuralTypeConversionTarget(
+    const TypeConverter &typeConverter, ConversionTarget &target) {
+  // A CF branch op is legal once all of its operands (including the
+  // condition/flag) have legal types. The converter is captured by reference,
+  // so it must outlive every use of `target`.
+  target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
+      [&typeConverter](Operation *op) {
+        return typeConverter.isLegal(op->getOperands());
+      });
+}
+
+void mlir::cf::populateCFStructuralTypeConversionsAndLegality(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, PatternBenefit benefit) {
+  populateCFStructuralTypeConversions(typeConverter, patterns, benefit);
+  populateCFStructuralTypeConversionTarget(typeConverter, target);
+}
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index c003f8b2cb1cd..91f83a0afaeef 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -143,3 +143,25 @@ func.func @test_signature_conversion_no_converter() {
   return
 }
 
+// -----
+
+// CHECK-LABEL: func @test_unstructured_cf_conversion(
+//  CHECK-SAME:     %[[arg0:.*]]: f64, %[[c:.*]]: i1)
+//       CHECK:   %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (f64) -> f32
+//       CHECK:   "test.foo"(%[[cast1]])
+//       CHECK:   cf.br ^[[bb1:.*]](%[[arg0]] : f64)
+//       CHECK: ^[[bb1]](%[[arg1:.*]]: f64):
+//       CHECK:   cf.cond_br %[[c]], ^[[bb1]](%[[arg1]] : f64), ^[[bb2:.*]](%[[arg1]] : f64)
+//       CHECK: ^[[bb2]](%[[arg2:.*]]: f64):
+//       CHECK:   %[[cast2:.*]] = "test.cast"(%[[arg2]]) : (f64) -> f32
+//       CHECK:   "test.bar"(%[[cast2]])
+//       CHECK: return
+func.func @test_unstructured_cf_conversion(%arg0: f32, %c: i1) {  // f32 is converted to f64; i1 is already legal.
+  "test.foo"(%arg0) : (f32) -> ()  // Expected to consume a test.cast back to f32.
+  cf.br ^bb1(%arg0: f32)  // Rewritten by BranchOpConversion.
+^bb1(%arg1: f32):  // Reached from both the cf.br and the cf.cond_br back edge.
+  cf.cond_br %c, ^bb1(%arg1 : f32), ^bb2(%arg1 : f32)  // Rewritten by CondBranchOpConversion.
+^bb2(%arg2: f32):
+  "test.bar"(%arg2) : (f32) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index f099d01abd31a..9354a85d984c9 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect
   )
 mlir_target_link_libraries(MLIRTestDialect PUBLIC
   MLIRControlFlowInterfaces
+  MLIRControlFlowTransforms
   MLIRDataLayoutInterfaces
   MLIRDerivedAttributeOpInterface
   MLIRDestinationStyleOpInterface
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index efbdbfb65d65b..fd2b943ff1296 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -11,6 +11,7 @@
 #include "TestTypes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver
     });
     converter.addConversion([](IndexType type) { return type; });
     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
+      if (type.isInteger(1)) {
+        // i1 is legal.
+        types.push_back(type);
+      }
       if (type.isInteger(38)) {
         // i38 is legal.
         types.push_back(type);
@@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver
                                                               converter);
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
         converter, patterns, target);
+    mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter,
+                                                             patterns, target);
 
     ConversionConfig config;
     config.allowPatternRollback = allowPatternRollback;



More information about the Mlir-commits mailing list