[Mlir-commits] [mlir] [mlir][arith] Fix `arith.select` lowering after #166513 (PR #166692)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 5 19:15:09 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

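PR #166513 broke the lowering of `arith.select` when its operands have an unsupported FP4 type (e.g. `f4E2M1FN`). For `arith.select` it is fine to convert such operands to `i4`: a select only forwards one of its operands unchanged and performs no floating-point arithmetic, so the unsupported-FP check is skipped for this op.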


---
Full diff: https://github.com/llvm/llvm-project/pull/166692.diff


3 Files Affected:

- (modified) mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h (+8-5) 
- (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+4-1) 
- (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+4-2) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cad6cec761ab8..b8e3023b25569 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -86,7 +86,8 @@ class AttrConvertPassThrough {
 /// ArrayRef<NamedAttribute>.
 template <typename SourceOp, typename TargetOp,
           template <typename, typename> typename AttrConvert =
-              AttrConvertPassThrough>
+              AttrConvertPassThrough,
+          bool FailOnUnsupportedFP = true>
 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
 public:
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@@ -123,11 +124,13 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
                                            "unsupported floating point type");
       return success();
     };
-    for (Value operand : op->getOperands())
-      if (failed(checkType(operand)))
+    if (FailOnUnsupportedFP) {
+      for (Value operand : op->getOperands())
+        if (failed(checkType(operand)))
+          return failure();
+      if (failed(checkType(op->getResult(0))))
         return failure();
-    if (failed(checkType(op->getResult(0))))
-      return failure();
+    }
 
     // Determine attributes for the target op
     AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..55cffa1e22d77 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -139,7 +140,9 @@ using RemSIOpLowering =
 using RemUIOpLowering =
     VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
 using SelectOpLowering =
-    VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
+    VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp,
+                               AttrConvertPassThrough,
+                               /*FailOnUnsupportedFP=*/false>;
 using ShLIOpLowering =
     VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
                                arith::AttrConvertOverflowToLLVM>;
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index b5dcb01d3dc6b..5f1ec66234df2 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
 //       CHECK:   arith.addf {{.*}} : f4E2M1FN
 //       CHECK:   arith.addf {{.*}} : vector<4xf4E2M1FN>
 //       CHECK:   arith.addf {{.*}} : vector<8x4xf4E2M1FN>
-func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+//       CHECK:   llvm.select {{.*}} : i1, i4
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
   %0 = arith.addf %arg0, %arg0 : f4E2M1FN
   %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
   %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
-  return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+  %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
+  return
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/166692


More information about the Mlir-commits mailing list