[Mlir-commits] [mlir] 3740368 - [mlir][arith] Fix `arith.select` lowering after #166513 (#166692)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 7 09:59:59 PST 2025
Author: Matthias Springer
Date: 2025-11-07T09:59:55-08:00
New Revision: 37403685298bd3a7ba2c482b4d6fc26fa0491e19
URL: https://github.com/llvm/llvm-project/commit/37403685298bd3a7ba2c482b4d6fc26fa0491e19
DIFF: https://github.com/llvm/llvm-project/commit/37403685298bd3a7ba2c482b4d6fc26fa0491e19.diff
LOG: [mlir][arith] Fix `arith.select` lowering after #166513 (#166692)
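#166513 broke the lowering of `arith.select` on unsupported FP4 types such as `f4E2M1FN`: the unsupported-floating-point check it added to `VectorConvertToLLVMPattern` rejected every op using that pattern, not only floating-point arithmetic ops.
`arith.select` only forwards one of its operands, so it is fine to lower it to `llvm.select` with the FP4 values converted to `i4`. The check is therefore now opt-in through a new `FailOnUnsupportedFP` template parameter (default `false`), which the floating-point arithmetic and conversion lowerings set to `true`.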
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cad6cec761ab8..e7ab63abfeaa1 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -86,7 +86,8 @@ class AttrConvertPassThrough {
/// ArrayRef<NamedAttribute>.
template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
- AttrConvertPassThrough>
+ AttrConvertPassThrough,
+ bool FailOnUnsupportedFP = false>
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@@ -123,11 +124,13 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
"unsupported floating point type");
return success();
};
- for (Value operand : op->getOperands())
- if (failed(checkType(operand)))
+ if (FailOnUnsupportedFP) {
+ for (Value operand : op->getOperands())
+ if (failed(checkType(operand)))
+ return failure();
+ if (failed(checkType(op->getResult(0))))
return failure();
- if (failed(checkType(op->getResult(0))))
- return failure();
+ }
// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..b6099902cc337 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -36,20 +36,23 @@ namespace {
/// attribute.
template <typename SourceOp, typename TargetOp, bool Constrained,
template <typename, typename> typename AttrConvert =
- AttrConvertPassThrough>
+ AttrConvertPassThrough,
+ bool FailOnUnsupportedFP = false>
struct ConstrainedVectorConvertToLLVMPattern
- : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
- using VectorConvertToLLVMPattern<SourceOp, TargetOp,
- AttrConvert>::VectorConvertToLLVMPattern;
+ : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP> {
+ using VectorConvertToLLVMPattern<
+ SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
return failure();
- return VectorConvertToLLVMPattern<SourceOp, TargetOp,
- AttrConvert>::matchAndRewrite(op, adaptor,
- rewriter);
+ return VectorConvertToLLVMPattern<
+ SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
}
};
@@ -78,7 +81,8 @@ struct IdentityBitcastLowering final
using AddFOpLowering =
VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using AddIOpLowering =
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
arith::AttrConvertOverflowToLLVM>;
@@ -87,53 +91,67 @@ using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering =
VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using DivUIOpLowering =
VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
-using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
+using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using ExtSIOpLowering =
VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
using ExtUIOpLowering =
VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
using FPToSIOpLowering =
- VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
+ VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using FPToUIOpLowering =
- VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
+ VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using MaximumFOpLowering =
VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MaxNumFOpLowering =
VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MaxSIOpLowering =
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
using MaxUIOpLowering =
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
using MinimumFOpLowering =
VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MinNumFOpLowering =
VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MinSIOpLowering =
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering =
VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MulIOpLowering =
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
arith::AttrConvertOverflowToLLVM>;
using NegFOpLowering =
VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering =
VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
using RemUIOpLowering =
@@ -151,21 +169,25 @@ using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering =
VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
using TruncFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
- false>;
+ false, AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
- arith::AttrConverterConstrainedFPToLLVM>;
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
arith::AttrConvertOverflowToLLVM>;
using UIToFPOpLowering =
- VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
+ VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index b5dcb01d3dc6b..5f1ec66234df2 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
// CHECK: arith.addf {{.*}} : f4E2M1FN
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
-func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+// CHECK: llvm.select {{.*}} : i1, i4
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
- return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+ %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
+ return
}
// -----
More information about the Mlir-commits mailing list