[Mlir-commits] [mlir] [mlir][arith] Fix `arith.select` lowering after #166513 (PR #166692)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 5 19:15:09 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
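#166513 broke the lowering of `arith.select` with unsupported FP4 types (e.g. `f4E2M1FN`). `arith.select` only chooses one of its operands and never interprets their floating-point bits, so for this op it is fine to lower with the operands converted to `i4`.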
---
Full diff: https://github.com/llvm/llvm-project/pull/166692.diff
3 Files Affected:
- (modified) mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h (+8-5)
- (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+4-1)
- (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+4-2)
``````````diff
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cad6cec761ab8..b8e3023b25569 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -86,7 +86,8 @@ class AttrConvertPassThrough {
/// ArrayRef<NamedAttribute>.
template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
- AttrConvertPassThrough>
+ AttrConvertPassThrough,
+ bool FailOnUnsupportedFP = true>
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@@ -123,11 +124,13 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
"unsupported floating point type");
return success();
};
- for (Value operand : op->getOperands())
- if (failed(checkType(operand)))
+ if (FailOnUnsupportedFP) {
+ for (Value operand : op->getOperands())
+ if (failed(checkType(operand)))
+ return failure();
+ if (failed(checkType(op->getResult(0))))
return failure();
- if (failed(checkType(op->getResult(0))))
- return failure();
+ }
// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..55cffa1e22d77 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
@@ -139,7 +140,9 @@ using RemSIOpLowering =
using RemUIOpLowering =
VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
using SelectOpLowering =
- VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
+ VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/false>;
using ShLIOpLowering =
VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
arith::AttrConvertOverflowToLLVM>;
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index b5dcb01d3dc6b..5f1ec66234df2 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
// CHECK: arith.addf {{.*}} : f4E2M1FN
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
-func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+// CHECK: llvm.select {{.*}} : i1, i4
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
- return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+ %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
+ return
}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/166692
More information about the Mlir-commits mailing list