[Mlir-commits] [mlir] f7b09ad - [mlir][LLVM] `ArithToLLVM`: Add 1:N support for `arith.select` lowering (#153944)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 18 00:42:40 PDT 2025


Author: Matthias Springer
Date: 2025-08-18T09:42:37+02:00
New Revision: f7b09ad700f2d8ae9ad230f6fc85de81e3a6565b

URL: https://github.com/llvm/llvm-project/commit/f7b09ad700f2d8ae9ad230f6fc85de81e3a6565b
DIFF: https://github.com/llvm/llvm-project/commit/f7b09ad700f2d8ae9ad230f6fc85de81e3a6565b.diff

LOG: [mlir][LLVM] `ArithToLLVM`: Add 1:N support for `arith.select` lowering (#153944)

Add 1:N support for the `arith.select` lowering. Only cases where the
entire true/false value is selected are supported.

Added: 
    mlir/test/Conversion/ArithToLLVM/type-conversion.mlir

Modified: 
    mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
    mlir/test/lib/Dialect/LLVM/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 18e857c81af8d..cb0c829719565 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+  using Adaptor =
+      typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+
+  LogicalResult
+  matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -479,6 +489,32 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
       rewriter);
 }
 
+//===----------------------------------------------------------------------===//
+// SelectOpOneToNLowering
+//===----------------------------------------------------------------------===//
+
+/// Pattern for arith.select where the true/false values lower to multiple
+/// SSA values (1:N conversion). This pattern generates multiple arith.select
+/// ops that can be lowered by the 1:1 arith.select pattern.
+LogicalResult SelectOpOneToNLowering::matchAndRewrite(
+    arith::SelectOp op, Adaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // In case of a 1:1 conversion, the 1:1 pattern will match.
+  if (llvm::hasSingleElement(adaptor.getTrueValue()))
+    return rewriter.notifyMatchFailure(
+        op, "not a 1:N conversion, 1:1 pattern will match");
+  if (!op.getCondition().getType().isInteger(1))
+    return rewriter.notifyMatchFailure(op,
+                                       "non-i1 conditions are not supported");
+  SmallVector<Value> results;
+  for (auto [trueValue, falseValue] :
+       llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
+    results.push_back(arith::SelectOp::create(
+        rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
+  rewriter.replaceOpWithMultiple(op, {results});
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pass Definition
 //===----------------------------------------------------------------------===//
@@ -587,6 +623,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     RemSIOpLowering,
     RemUIOpLowering,
     SelectOpLowering,
+    SelectOpOneToNLowering,
     ShLIOpLowering,
     ShRSIOpLowering,
     ShRUIOpLowering,

diff  --git a/mlir/test/Conversion/ArithToLLVM/type-conversion.mlir b/mlir/test/Conversion/ArithToLLVM/type-conversion.mlir
new file mode 100644
index 0000000000000..e3a0c82a628ba
--- /dev/null
+++ b/mlir/test/Conversion/ArithToLLVM/type-conversion.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-llvm-legalize-patterns="allow-pattern-rollback=0" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @arith_select(
+//  CHECK-SAME:     %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18, %[[arg3:.*]]: i18, %[[arg4:.*]]: i18) -> !llvm.struct<(i18, i18)>
+//       CHECK:   %[[select0:.*]] = llvm.select %[[arg0]], %[[arg1]], %[[arg3]] : i1, i18
+//       CHECK:   %[[select1:.*]] = llvm.select %[[arg0]], %[[arg2]], %[[arg4]] : i1, i18
+//       CHECK:   %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
+//       CHECK:   %[[i1:.*]] = llvm.insertvalue %[[select0]], %[[i0]][0] : !llvm.struct<(i18, i18)>
+//       CHECK:   %[[i2:.*]] = llvm.insertvalue %[[select1]], %[[i1]][1] : !llvm.struct<(i18, i18)>
+//       CHECK:   llvm.return %[[i2]]
+func.func @arith_select(%arg0: i1, %arg1: i17, %arg2: i17) -> (i17) {
+  %0 = arith.select %arg0, %arg1, %arg2 : i17
+  return %0 : i17
+}

diff  --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
index 9d30ae43cccc1..69a3d98bc09e4 100644
--- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -70,6 +71,7 @@ struct TestLLVMLegalizePatternsPass
     // Populate patterns.
     mlir::RewritePatternSet patterns(ctx);
     patterns.add<TestDirectReplacementOp>(ctx, converter);
+    arith::populateArithToLLVMConversionPatterns(converter, patterns);
     populateFuncToLLVMConversionPatterns(converter, patterns);
     cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
 


        


More information about the Mlir-commits mailing list