[Mlir-commits] [mlir] ca84e9e - [mlir][CF] Add structural type conversion patterns (#165629)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 29 18:12:45 PDT 2025
Author: Matthias Springer
Date: 2025-10-29T18:12:40-07:00
New Revision: ca84e9e8260983d17f02a76de8b1ee51cd3ed896
URL: https://github.com/llvm/llvm-project/commit/ca84e9e8260983d17f02a76de8b1ee51cd3ed896
DIFF: https://github.com/llvm/llvm-project/commit/ca84e9e8260983d17f02a76de8b1ee51cd3ed896.diff
LOG: [mlir][CF] Add structural type conversion patterns (#165629)
Add structural type conversion patterns for CF dialect ops. These
patterns are similar to the SCF structural type conversion patterns.
This commit adds missing functionality and is in preparation of #165180,
which changes the way blocks are converted. (Only entry blocks are
converted.)
Added:
mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
Modified:
mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
mlir/test/Transforms/test-legalize-type-conversion.mlir
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
new file mode 100644
index 0000000000000..a32d9e2025c76
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
@@ -0,0 +1,48 @@
+//===- StructuralTypeConversions.h - CF Type Conversions --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+#define MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+class ConversionTarget;
+class TypeConverter;
+
+namespace cf {
+
+/// Populates patterns for CF structural type conversions and sets up the
+/// provided ConversionTarget with the appropriate legality configuration for
+/// the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types. An example of this is cf.br -- the cf.br op needs to update
+/// its types according to the TypeConverter, but otherwise does not care
+/// what type conversions are happening.
+void populateCFStructuralTypeConversionsAndLegality(
+ const TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target, PatternBenefit benefit = 1);
+
+/// Similar to `populateCFStructuralTypeConversionsAndLegality` but does not
+/// populate the conversion target.
+void populateCFStructuralTypeConversions(const TypeConverter &typeConverter,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Updates the ConversionTarget with dynamic legality of CF operations based
+/// on the provided type converter.
+void populateCFStructuralTypeConversionTarget(
+ const TypeConverter &typeConverter, ConversionTarget &target);
+
+} // namespace cf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
index 47740d31844f4..e9da135ed46f9 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRControlFlowTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
+ StructuralTypeConversions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
new file mode 100644
index 0000000000000..5e2a742c2d64c
--- /dev/null
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,169 @@
+//===- StructuralTypeConversions.cpp - CF Type Conversions ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements structural type conversion patterns for the branch ops
+// of the ControlFlow dialect (cf.br, cf.cond_br and cf.switch).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
+
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Helper function for converting branch ops. This function converts the
+/// signature of the given block. If the new block signature is different from
+/// `expectedTypes`, returns "failure".
+static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
+ const TypeConverter *converter,
+ Operation *branchOp, Block *block,
+ TypeRange expectedTypes) {
+ assert(converter && "expected non-null type converter");
+ assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+ // There is nothing to do if the types already match.
+ if (block->getArgumentTypes() == expectedTypes)
+ return block;
+
+ // Compute the new block argument types and convert the block.
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(block);
+ if (!conversion)
+ return rewriter.notifyMatchFailure(branchOp,
+ "could not compute block signature");
+ if (expectedTypes != conversion->getConvertedTypes())
+ return rewriter.notifyMatchFailure(
+ branchOp,
+ "mismatch between adaptor operand types and computed block signature");
+ return rewriter.applySignatureConversion(block, *conversion, converter);
+}
+
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const ValueRange &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
+/// Convert the destination block signature (if necessary) and change the
+/// operands of the branch op.
+struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
+ using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
+ FailureOr<Block *> convertedBlock =
+ getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
+ TypeRange(ValueRange(flattenedAdaptor)));
+ if (failed(convertedBlock))
+ return failure();
+ rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
+ *convertedBlock);
+ return success();
+ }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the branch op.
+struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
+ using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptorTrue =
+ flattenValues(adaptor.getTrueDestOperands());
+ SmallVector<Value> flattenedAdaptorFalse =
+ flattenValues(adaptor.getFalseDestOperands());
+ if (!llvm::hasSingleElement(adaptor.getCondition()))
+ return rewriter.notifyMatchFailure(op,
+ "expected single element condition");
+ FailureOr<Block *> convertedTrueBlock =
+ getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
+ TypeRange(ValueRange(flattenedAdaptorTrue)));
+ if (failed(convertedTrueBlock))
+ return failure();
+ FailureOr<Block *> convertedFalseBlock =
+ getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
+ TypeRange(ValueRange(flattenedAdaptorFalse)));
+ if (failed(convertedFalseBlock))
+ return failure();
+ rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
+ op, llvm::getSingleElement(adaptor.getCondition()),
+ flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
+ *convertedTrueBlock, *convertedFalseBlock);
+ return success();
+ }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the switch op.
+struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
+ using OpConversionPattern<cf::SwitchOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Get or convert default block.
+ FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
+ rewriter, getTypeConverter(), op, op.getDefaultDestination(),
+ TypeRange(adaptor.getDefaultOperands()));
+ if (failed(convertedDefaultBlock))
+ return failure();
+
+ // Get or convert all case blocks.
+ SmallVector<Block *> caseDestinations;
+ SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
+ for (auto it : llvm::enumerate(op.getCaseDestinations())) {
+ Block *b = it.value();
+ FailureOr<Block *> convertedBlock =
+ getConvertedBlock(rewriter, getTypeConverter(), op, b,
+ TypeRange(caseOperands[it.index()]));
+ if (failed(convertedBlock))
+ return failure();
+ caseDestinations.push_back(*convertedBlock);
+ }
+
+ rewriter.replaceOpWithNewOp<cf::SwitchOp>(
+ op, adaptor.getFlag(), *convertedDefaultBlock,
+ adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
+ caseDestinations, caseOperands);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::cf::populateCFStructuralTypeConversions(
+ const TypeConverter &typeConverter, RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
+ typeConverter, patterns.getContext(), benefit);
+}
+
+void mlir::cf::populateCFStructuralTypeConversionTarget(
+ const TypeConverter &typeConverter, ConversionTarget &target) {
+ target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
+ [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); });
+}
+
+void mlir::cf::populateCFStructuralTypeConversionsAndLegality(
+ const TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target, PatternBenefit benefit) {
+ populateCFStructuralTypeConversions(typeConverter, patterns, benefit);
+ populateCFStructuralTypeConversionTarget(typeConverter, target);
+}
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index c003f8b2cb1cd..91f83a0afaeef 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -143,3 +143,25 @@ func.func @test_signature_conversion_no_converter() {
return
}
+// -----
+
+// CHECK-LABEL: func @test_unstructured_cf_conversion(
+// CHECK-SAME: %[[arg0:.*]]: f64, %[[c:.*]]: i1)
+// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (f64) -> f32
+// CHECK: "test.foo"(%[[cast1]])
+// CHECK: cf.br ^[[bb1:.*]](%[[arg0]] : f64)
+// CHECK: ^[[bb1]](%[[arg1:.*]]: f64):
+// CHECK: cf.cond_br %[[c]], ^[[bb1]](%[[arg1]] : f64), ^[[bb2:.*]](%[[arg1]] : f64)
+// CHECK: ^[[bb2]](%[[arg2:.*]]: f64):
+// CHECK: %[[cast2:.*]] = "test.cast"(%[[arg2]]) : (f64) -> f32
+// CHECK: "test.bar"(%[[cast2]])
+// CHECK: return
+func.func @test_unstructured_cf_conversion(%arg0: f32, %c: i1) {
+ "test.foo"(%arg0) : (f32) -> ()
+ cf.br ^bb1(%arg0: f32)
+^bb1(%arg1: f32):
+ cf.cond_br %c, ^bb1(%arg1 : f32), ^bb2(%arg1 : f32)
+^bb2(%arg2: f32):
+ "test.bar"(%arg2) : (f32) -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index f099d01abd31a..9354a85d984c9 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect
)
mlir_target_link_libraries(MLIRTestDialect PUBLIC
MLIRControlFlowInterfaces
+ MLIRControlFlowTransforms
MLIRDataLayoutInterfaces
MLIRDerivedAttributeOpInterface
MLIRDestinationStyleOpInterface
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index efbdbfb65d65b..fd2b943ff1296 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -11,6 +11,7 @@
#include "TestTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver
});
converter.addConversion([](IndexType type) { return type; });
converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
+ if (type.isInteger(1)) {
+ // i1 is legal.
+ types.push_back(type);
+ }
if (type.isInteger(38)) {
// i38 is legal.
types.push_back(type);
@@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver
converter);
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
converter, patterns, target);
+ mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter,
+ patterns, target);
ConversionConfig config;
config.allowPatternRollback = allowPatternRollback;
More information about the Mlir-commits
mailing list