[Mlir-commits] [mlir] [mlir][LLVM] `ArithToLLVM`: Add 1:N support for `arith.select` lowering (PR #153944)

Matthias Springer llvmlistbot at llvm.org
Sat Aug 16 03:31:22 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/153944

Add 1:N support for the `arith.select` lowering. Only cases where the entire true/false value is selected are supported.


>From ae9611fc06baff96c37d2cb5078cfd08e2b2fd51 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 16 Aug 2025 10:27:59 +0000
Subject: [PATCH] [mlir][LLVM] `ArithToLLVM`: Add 1:N support for
 `arith.select` lowering

---
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 41 +++++++++++++++++++
 .../MemRefToLLVM/type-conversion.mlir         | 15 +++++++
 mlir/test/lib/Dialect/LLVM/TestPatterns.cpp   |  2 +
 3 files changed, 58 insertions(+)

diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 18e857c81af8d..3d759e0fb6361 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+  using Adaptor =
+      typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+
+  LogicalResult
+  matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -479,6 +489,36 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
       rewriter);
 }
 
+//===----------------------------------------------------------------------===//
+// SelectOpOneToNLowering
+//===----------------------------------------------------------------------===//
+
+/// Pattern for arith.select where the true/false values lower to multiple
+/// SSA values (1:N conversion). This pattern generates multiple arith.select
+/// ops that can be lowered by the 1:1 arith.select pattern.
+LogicalResult SelectOpOneToNLowering::matchAndRewrite(
+    arith::SelectOp op, Adaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // In case of a 1:1 conversion, the 1:1 pattern will match.
+  if (llvm::hasSingleElement(adaptor.getTrueValue()))
+    return rewriter.notifyMatchFailure(
+        op, "not a 1:N conversion, 1:1 pattern will match");
+  if (!op.getCondition().getType().isInteger(1))
+    return rewriter.notifyMatchFailure(op,
+                                       "non-i1 conditions are not supported");
+  if (!llvm::hasSingleElement(adaptor.getCondition()))
+    return rewriter.notifyMatchFailure(
+        op, "1:N condition conversion is not supported");
+  SmallVector<Value> results;
+  for (auto [trueValue, falseValue] :
+       llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
+    results.push_back(arith::SelectOp::create(
+        rewriter, op.getLoc(), llvm::getSingleElement(adaptor.getCondition()),
+        trueValue, falseValue));
+  rewriter.replaceOpWithMultiple(op, {results});
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pass Definition
 //===----------------------------------------------------------------------===//
@@ -587,6 +627,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     RemSIOpLowering,
     RemUIOpLowering,
     SelectOpLowering,
+    SelectOpOneToNLowering,
     ShLIOpLowering,
     ShRSIOpLowering,
     ShRUIOpLowering,
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
index c1751f282b002..2e050887cc1d3 100644
--- a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -138,3 +138,18 @@ func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) {
   %res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1)
   return %res#0, %res#1, %res#0 : i17, i1, i17
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @arith_select(
+//  CHECK-SAME:     %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18, %[[arg3:.*]]: i18, %[[arg4:.*]]: i18) -> !llvm.struct<(i18, i18)>
+//       CHECK:   %[[select0:.*]] = llvm.select %[[arg0]], %[[arg1]], %[[arg3]] : i1, i18
+//       CHECK:   %[[select1:.*]] = llvm.select %[[arg0]], %[[arg2]], %[[arg4]] : i1, i18
+//       CHECK:   %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
+//       CHECK:   %[[i1:.*]] = llvm.insertvalue %[[select0]], %[[i0]][0] : !llvm.struct<(i18, i18)>
+//       CHECK:   %[[i2:.*]] = llvm.insertvalue %[[select1]], %[[i1]][1] : !llvm.struct<(i18, i18)>
+//       CHECK:   llvm.return %[[i2]]
+func.func @arith_select(%arg0: i1, %arg1: i17, %arg2: i17) -> (i17) {
+  %0 = arith.select %arg0, %arg1, %arg2 : i17
+  return %0 : i17
+}
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
index fe9aa0f2a9902..c2a75836b77b9 100644
--- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -69,6 +70,7 @@ struct TestLLVMLegalizePatternsPass
     // Populate patterns.
     mlir::RewritePatternSet patterns(ctx);
     patterns.add<TestDirectReplacementOp>(ctx, converter);
+    arith::populateArithToLLVMConversionPatterns(converter, patterns);
     populateFuncToLLVMConversionPatterns(converter, patterns);
 
     // Define the conversion target used for the test.



More information about the Mlir-commits mailing list