[Mlir-commits] [mlir] [mlir][arith] Fix `arith.select` lowering after #166513 (PR #166692)

Matthias Springer llvmlistbot at llvm.org
Wed Nov 5 19:14:37 PST 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/166692

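#166513 broke the lowering of `arith.select` on unsupported FP4 types (e.g., `f4E2M1FN`). Because `arith.select` only chooses between its operands and performs no floating-point arithmetic, it is fine to convert the type to `i4` for this op.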


From 48876ecfb0f658236ffdb59a402a4c5cec9d8d2e Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 6 Nov 2025 03:12:51 +0000
Subject: [PATCH] [mlir][arith] Fix `arith.select` lowering after #166513

---
 .../mlir/Conversion/LLVMCommon/VectorPattern.h      | 13 ++++++++-----
 mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp     |  5 ++++-
 mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir |  6 ++++--
 3 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cad6cec761ab8..b8e3023b25569 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -86,7 +86,8 @@ class AttrConvertPassThrough {
 /// ArrayRef<NamedAttribute>.
 template <typename SourceOp, typename TargetOp,
           template <typename, typename> typename AttrConvert =
-              AttrConvertPassThrough>
+              AttrConvertPassThrough,
+          bool FailOnUnsupportedFP = true>
 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
 public:
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@@ -123,11 +124,13 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
                                            "unsupported floating point type");
       return success();
     };
-    for (Value operand : op->getOperands())
-      if (failed(checkType(operand)))
+    if (FailOnUnsupportedFP) {
+      for (Value operand : op->getOperands())
+        if (failed(checkType(operand)))
+          return failure();
+      if (failed(checkType(op->getResult(0))))
         return failure();
-    if (failed(checkType(op->getResult(0))))
-      return failure();
+    }
 
     // Determine attributes for the target op
     AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..55cffa1e22d77 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -139,7 +140,9 @@ using RemSIOpLowering =
 using RemUIOpLowering =
     VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
 using SelectOpLowering =
-    VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
+    VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp,
+                               AttrConvertPassThrough,
+                               /*FailOnUnsupportedFP=*/false>;
 using ShLIOpLowering =
     VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
                                arith::AttrConvertOverflowToLLVM>;
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index b5dcb01d3dc6b..5f1ec66234df2 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
 //       CHECK:   arith.addf {{.*}} : f4E2M1FN
 //       CHECK:   arith.addf {{.*}} : vector<4xf4E2M1FN>
 //       CHECK:   arith.addf {{.*}} : vector<8x4xf4E2M1FN>
-func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+//       CHECK:   llvm.select {{.*}} : i1, i4
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
   %0 = arith.addf %arg0, %arg0 : f4E2M1FN
   %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
   %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
-  return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+  %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
+  return
 }
 
 // -----



More information about the Mlir-commits mailing list