[Mlir-commits] [mlir] 3740368 - [mlir][arith] Fix `arith.select` lowering after #166513 (#166692)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 7 09:59:59 PST 2025


Author: Matthias Springer
Date: 2025-11-07T09:59:55-08:00
New Revision: 37403685298bd3a7ba2c482b4d6fc26fa0491e19

URL: https://github.com/llvm/llvm-project/commit/37403685298bd3a7ba2c482b4d6fc26fa0491e19
DIFF: https://github.com/llvm/llvm-project/commit/37403685298bd3a7ba2c482b4d6fc26fa0491e19.diff

LOG: [mlir][arith] Fix `arith.select` lowering after #166513 (#166692)

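#166513 broke the lowering of `arith.select` on operands with unsupported
FP4 types (e.g. `f4E2M1FN`). `arith.select` only picks one of its operands
and performs no floating-point arithmetic, so for this op it is fine to
convert such operands to `i4`. The unsupported-FP check is therefore now
opt-in (`FailOnUnsupportedFP`) and enabled only for the FP arithmetic and
conversion patterns.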

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
    mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
    mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cad6cec761ab8..e7ab63abfeaa1 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -86,7 +86,8 @@ class AttrConvertPassThrough {
 /// ArrayRef<NamedAttribute>.
 template <typename SourceOp, typename TargetOp,
           template <typename, typename> typename AttrConvert =
-              AttrConvertPassThrough>
+              AttrConvertPassThrough,
+          bool FailOnUnsupportedFP = false>
 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
 public:
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@@ -123,11 +124,13 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
                                            "unsupported floating point type");
       return success();
     };
-    for (Value operand : op->getOperands())
-      if (failed(checkType(operand)))
+    if (FailOnUnsupportedFP) {
+      for (Value operand : op->getOperands())
+        if (failed(checkType(operand)))
+          return failure();
+      if (failed(checkType(op->getResult(0))))
         return failure();
-    if (failed(checkType(op->getResult(0))))
-      return failure();
+    }
 
     // Determine attributes for the target op
     AttrConvert<SourceOp, TargetOp> attrConvert(op);

diff  --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..b6099902cc337 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -36,20 +36,23 @@ namespace {
 /// attribute.
 template <typename SourceOp, typename TargetOp, bool Constrained,
           template <typename, typename> typename AttrConvert =
-              AttrConvertPassThrough>
+              AttrConvertPassThrough,
+          bool FailOnUnsupportedFP = false>
 struct ConstrainedVectorConvertToLLVMPattern
-    : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
-  using VectorConvertToLLVMPattern<SourceOp, TargetOp,
-                                   AttrConvert>::VectorConvertToLLVMPattern;
+    : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
+                                        FailOnUnsupportedFP> {
+  using VectorConvertToLLVMPattern<
+      SourceOp, TargetOp, AttrConvert,
+      FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
       return failure();
-    return VectorConvertToLLVMPattern<SourceOp, TargetOp,
-                                      AttrConvert>::matchAndRewrite(op, adaptor,
-                                                                    rewriter);
+    return VectorConvertToLLVMPattern<
+        SourceOp, TargetOp, AttrConvert,
+        FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
   }
 };
 
@@ -78,7 +81,8 @@ struct IdentityBitcastLowering final
 
 using AddFOpLowering =
     VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
-                               arith::AttrConvertFastMathToLLVM>;
+                               arith::AttrConvertFastMathToLLVM,
+                               /*FailOnUnsupportedFP=*/true>;
 using AddIOpLowering =
     VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
                                arith::AttrConvertOverflowToLLVM>;
@@ -87,53 +91,67 @@ using BitcastOpLowering =
     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
 using DivFOpLowering =
     VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
-                               arith::AttrConvertFastMathToLLVM>;
+                               arith::AttrConvertFastMathToLLVM,
+                               /*FailOnUnsupportedFP=*/true>;
 using DivSIOpLowering =
     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
 using DivUIOpLowering =
     VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
-using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
+using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp,
+                                                  AttrConvertPassThrough,
+                                                  /*FailOnUnsupportedFP=*/true>;
 using ExtSIOpLowering =
     VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
 using ExtUIOpLowering =
     VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
 using FPToSIOpLowering =
-    VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
+    VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp,
+                               AttrConvertPassThrough,
+                               /*FailOnUnsupportedFP=*/true>;
 using FPToUIOpLowering =
-    VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
+    VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp,
+                               AttrConvertPassThrough,
+                               /*FailOnUnsupportedFP=*/true>;
 using MaximumFOpLowering =
     VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
-                               arith::AttrConvertFastMathToLLVM>;
+                               arith::AttrConvertFastMathToLLVM,
+                               /*FailOnUnsupportedFP=*/true>;
 using MaxNumFOpLowering =
     VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
-                               arith::AttrConvertFastMathToLLVM>;
+                               arith::AttrConvertFastMathToLLVM,
+                               /*FailOnUnsupportedFP=*/true>;
 using MaxSIOpLowering =
     VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
 using MaxUIOpLowering =
     VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
 using MinimumFOpLowering =
     VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
-                               arith::AttrConvertFastMathToLLVM>;
+                               arith::AttrConvertFastMathToLLVM,
+                               /*FailOnUnsupportedFP=*/true>;
 using MinNumFOpLowering =
     VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
-                               arith::AttrConvertFastMathToLLVM>;
+                               arith::AttrConvertFastMathToLLVM,
+                               /*FailOnUnsupportedFP=*/true>;
 using MinSIOpLowering =
     VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
 using MinUIOpLowering =
     VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
 using MulFOpLowering =
     VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
-                               arith::AttrConvertFastMathToLLVM>;
+                               arith::AttrConvertFastMathToLLVM,
+                               /*FailOnUnsupportedFP=*/true>;
 using MulIOpLowering =
     VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
                                arith::AttrConvertOverflowToLLVM>;
 using NegFOpLowering =
     VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
-                               arith::AttrConvertFastMathToLLVM>;
+                               arith::AttrConvertFastMathToLLVM,
+                               /*FailOnUnsupportedFP=*/true>;
 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
 using RemFOpLowering =
     VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
-                               arith::AttrConvertFastMathToLLVM>;
+                               arith::AttrConvertFastMathToLLVM,
+                               /*FailOnUnsupportedFP=*/true>;
 using RemSIOpLowering =
     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
 using RemUIOpLowering =
@@ -151,21 +169,25 @@ using SIToFPOpLowering =
     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
 using SubFOpLowering =
     VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
-                               arith::AttrConvertFastMathToLLVM>;
+                               arith::AttrConvertFastMathToLLVM,
+                               /*FailOnUnsupportedFP=*/true>;
 using SubIOpLowering =
     VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
                                arith::AttrConvertOverflowToLLVM>;
 using TruncFOpLowering =
     ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
-                                          false>;
+                                          false, AttrConvertPassThrough,
+                                          /*FailOnUnsupportedFP=*/true>;
 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
     arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
-    arith::AttrConverterConstrainedFPToLLVM>;
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using TruncIOpLowering =
     VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
                                arith::AttrConvertOverflowToLLVM>;
 using UIToFPOpLowering =
-    VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
+    VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp,
+                               AttrConvertPassThrough,
+                               /*FailOnUnsupportedFP=*/true>;
 using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index b5dcb01d3dc6b..5f1ec66234df2 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
 //       CHECK:   arith.addf {{.*}} : f4E2M1FN
 //       CHECK:   arith.addf {{.*}} : vector<4xf4E2M1FN>
 //       CHECK:   arith.addf {{.*}} : vector<8x4xf4E2M1FN>
-func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+//       CHECK:   llvm.select {{.*}} : i1, i4
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
   %0 = arith.addf %arg0, %arg0 : f4E2M1FN
   %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
   %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
-  return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+  %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
+  return
 }
 
 // -----


        


More information about the Mlir-commits mailing list