[Mlir-commits] [mlir] [mlir][arith] Fix `arith.select` lowering after #166513 (PR #166692)
Matthias Springer
llvmlistbot at llvm.org
Wed Nov 5 19:14:37 PST 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/166692
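PR #166513 broke the lowering of `arith.select` when its operands have an unsupported FP4 type (e.g. `f4E2M1FN`). Unlike arithmetic ops such as `arith.addf`, `arith.select` only chooses between its operands and never interprets their bits, so it is fine to lower it to an `llvm.select` on `i4`. This patch adds a `FailOnUnsupportedFP` template parameter (default `true`) to `VectorConvertToLLVMPattern` and sets it to `false` for `SelectOpLowering`.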
From 48876ecfb0f658236ffdb59a402a4c5cec9d8d2e Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 6 Nov 2025 03:12:51 +0000
Subject: [PATCH] [mlir][arith] Fix `arith.select` lowering after #166513
---
.../mlir/Conversion/LLVMCommon/VectorPattern.h | 13 ++++++++-----
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 5 ++++-
mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir | 6 ++++--
3 files changed, 16 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cad6cec761ab8..b8e3023b25569 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -86,7 +86,8 @@ class AttrConvertPassThrough {
/// ArrayRef<NamedAttribute>.
template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
- AttrConvertPassThrough>
+ AttrConvertPassThrough,
+ bool FailOnUnsupportedFP = true>
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@@ -123,11 +124,13 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
"unsupported floating point type");
return success();
};
- for (Value operand : op->getOperands())
- if (failed(checkType(operand)))
+ if (FailOnUnsupportedFP) {
+ for (Value operand : op->getOperands())
+ if (failed(checkType(operand)))
+ return failure();
+ if (failed(checkType(op->getResult(0))))
return failure();
- if (failed(checkType(op->getResult(0))))
- return failure();
+ }
// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..55cffa1e22d77 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
@@ -139,7 +140,9 @@ using RemSIOpLowering =
using RemUIOpLowering =
VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
using SelectOpLowering =
- VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
+ VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/false>;
using ShLIOpLowering =
VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
arith::AttrConvertOverflowToLLVM>;
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index b5dcb01d3dc6b..5f1ec66234df2 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
// CHECK: arith.addf {{.*}} : f4E2M1FN
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
-func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+// CHECK: llvm.select {{.*}} : i1, i4
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
- return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+ %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
+ return
}
// -----
More information about the Mlir-commits mailing list