[Mlir-commits] [mlir] f7b09ad - [mlir][LLVM] `ArithToLLVM`: Add 1:N support for `arith.select` lowering (#153944)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 18 00:42:40 PDT 2025
Author: Matthias Springer
Date: 2025-08-18T09:42:37+02:00
New Revision: f7b09ad700f2d8ae9ad230f6fc85de81e3a6565b
URL: https://github.com/llvm/llvm-project/commit/f7b09ad700f2d8ae9ad230f6fc85de81e3a6565b
DIFF: https://github.com/llvm/llvm-project/commit/f7b09ad700f2d8ae9ad230f6fc85de81e3a6565b.diff
LOG: [mlir][LLVM] `ArithToLLVM`: Add 1:N support for `arith.select` lowering (#153944)
Add 1:N support for the `arith.select` lowering. Only cases where the
entire true/false value is selected are supported.
Added:
mlir/test/Conversion/ArithToLLVM/type-conversion.mlir
Modified:
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 18e857c81af8d..cb0c829719565 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
+struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+ // 1:N adaptor: each operand is a range of its converted values.
+ LogicalResult
+ matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -479,6 +489,32 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
rewriter);
}
+//===----------------------------------------------------------------------===//
+// SelectOpOneToNLowering
+//===----------------------------------------------------------------------===//
+
+/// 1:N lowering of arith.select (true/false values lower to several SSA values).
+/// Emits one arith.select per converted value, which the 1:1 arith.select pattern
+/// can then lower. Only i1 conditions are handled; the original condition is reused.
+LogicalResult SelectOpOneToNLowering::matchAndRewrite(
+ arith::SelectOp op, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // In case of a 1:1 conversion, the 1:1 pattern will match.
+ if (llvm::hasSingleElement(adaptor.getTrueValue()))
+ return rewriter.notifyMatchFailure(
+ op, "not a 1:N conversion, 1:1 pattern will match");
+ if (!op.getCondition().getType().isInteger(1))
+ return rewriter.notifyMatchFailure(op,
+ "non-i1 conditions are not supported");
+ SmallVector<Value> results; // One new select per converted true/false pair.
+ for (auto [trueValue, falseValue] :
+ llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
+ results.push_back(arith::SelectOp::create(
+ rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue)); // NOTE(review): reuses the unconverted i1 condition; assumes i1 itself converts 1:1.
+ rewriter.replaceOpWithMultiple(op, {results});
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -587,6 +623,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
RemSIOpLowering,
RemUIOpLowering,
SelectOpLowering,
+ SelectOpOneToNLowering,
ShLIOpLowering,
ShRSIOpLowering,
ShRUIOpLowering,
diff --git a/mlir/test/Conversion/ArithToLLVM/type-conversion.mlir b/mlir/test/Conversion/ArithToLLVM/type-conversion.mlir
new file mode 100644
index 0000000000000..e3a0c82a628ba
--- /dev/null
+++ b/mlir/test/Conversion/ArithToLLVM/type-conversion.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-llvm-legalize-patterns="allow-pattern-rollback=0" -split-input-file | FileCheck %s
+// As the expected output shows, each i17 value lowers to two i18 values, so the select is split into one llvm.select per i18 part.
+// CHECK-LABEL: llvm.func @arith_select(
+// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18, %[[arg3:.*]]: i18, %[[arg4:.*]]: i18) -> !llvm.struct<(i18, i18)>
+// CHECK: %[[select0:.*]] = llvm.select %[[arg0]], %[[arg1]], %[[arg3]] : i1, i18
+// CHECK: %[[select1:.*]] = llvm.select %[[arg0]], %[[arg2]], %[[arg4]] : i1, i18
+// CHECK: %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
+// CHECK: %[[i1:.*]] = llvm.insertvalue %[[select0]], %[[i0]][0] : !llvm.struct<(i18, i18)>
+// CHECK: %[[i2:.*]] = llvm.insertvalue %[[select1]], %[[i1]][1] : !llvm.struct<(i18, i18)>
+// CHECK: llvm.return %[[i2]]
+func.func @arith_select(%arg0: i1, %arg1: i17, %arg2: i17) -> (i17) {
+ %0 = arith.select %arg0, %arg1, %arg2 : i17
+ return %0 : i17
+}
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
index 9d30ae43cccc1..69a3d98bc09e4 100644
--- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -70,6 +71,7 @@ struct TestLLVMLegalizePatternsPass
// Populate patterns.
mlir::RewritePatternSet patterns(ctx);
patterns.add<TestDirectReplacementOp>(ctx, converter);
+ arith::populateArithToLLVMConversionPatterns(converter, patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
More information about the Mlir-commits
mailing list