[Mlir-commits] [mlir] [mlir][math] Fix intrinsic conversions to LLVM for 0D-vector types (PR #141020)
Artem Gindinson
llvmlistbot at llvm.org
Thu May 29 02:00:06 PDT 2025
https://github.com/AGindinson updated https://github.com/llvm/llvm-project/pull/141020
>From 54c997b027994dad412c01241a86fe993eb92e81 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Thu, 22 May 2025 08:40:28 +0200
Subject: [PATCH 1/5] [mlir][math] Fix intrinsic conversions to LLVM for
0D-vector types
`vector<t>` (0-D vector) types are not compatible with the LLVM type system and
must be explicitly converted into `vector<1xt>` when lowering. Apply this rule in
the conversion pattern for the `math.ctlz`, `math.cttz` and `math.absi` ops, which
lower to LLVM intrinsics.
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 11 ++++++-
.../Conversion/MathToLLVM/math-to-llvm.mlir | 33 +++++++++++++++++++
2 files changed, 43 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 97da96afac4cd..19cd960b15294 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -84,6 +84,15 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
+ const auto &typeConverter = *this->getTypeConverter();
+ if (!LLVM::isCompatibleType(resultType)) {
+ resultType = typeConverter.convertType(resultType);
+ if (!resultType)
+ return failure();
+ }
+ if (operandType != resultType)
+ return rewriter.notifyMatchFailure(
+ op, "compatible result type doesn't match operand type");
if (!isa<LLVM::LLVMArrayType>(operandType)) {
rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
@@ -96,7 +105,7 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
return failure();
return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
+ op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
false);
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 974743a55932b..73325a3fd913e 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -19,6 +19,8 @@ func.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) {
// -----
+// CHECK-LABEL: func @absi(
+// CHECK-SAME: i32
func.func @absi(%arg0: i32) -> i32 {
// CHECK: = "llvm.intr.abs"(%{{.*}}) <{is_int_min_poison = false}> : (i32) -> i32
%0 = math.absi %arg0 : i32
@@ -27,6 +29,17 @@ func.func @absi(%arg0: i32) -> i32 {
// -----
+// CHECK-LABEL: func @absi_0d_vec(
+// CHECK-SAME: i32
+func.func @absi_0d_vec(%arg0 : vector<i32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+ // CHECK: "llvm.intr.abs"(%[[CAST]]) <{is_int_min_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+ %0 = math.absi %arg0 : vector<i32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @log1p(
// CHECK-SAME: f32
func.func @log1p(%arg0 : f32) {
@@ -201,6 +214,15 @@ func.func @ctlz(%arg0 : i32) {
func.return
}
+// CHECK-LABEL: func @ctlz_0d_vec(
+// CHECK-SAME: i32
+func.func @ctlz_0d_vec(%arg0 : vector<i32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+ // CHECK: "llvm.intr.ctlz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+ %0 = math.ctlz %arg0 : vector<i32>
+ func.return
+}
+
// -----
// CHECK-LABEL: func @cttz(
@@ -213,6 +235,17 @@ func.func @cttz(%arg0 : i32) {
// -----
+// CHECK-LABEL: func @cttz_0d_vec(
+// CHECK-SAME: i32
+func.func @cttz_0d_vec(%arg0 : vector<i32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+ // CHECK: "llvm.intr.cttz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+ %0 = math.cttz %arg0 : vector<i32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @cttz_vec(
// CHECK-SAME: i32
func.func @cttz_vec(%arg0 : vector<4xi32>) {
>From 0e56cad016bd0588d06b83a114017c1d2193ee21 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Wed, 28 May 2025 14:47:51 +0000
Subject: [PATCH 2/5] [fixup] Drop obsolete compatibility check
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 3 ---
1 file changed, 3 deletions(-)
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index bbc17739e7a98..b8097d21dc227 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -94,9 +94,6 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
if (!resultType)
return failure();
}
- if (operandType != resultType)
- return rewriter.notifyMatchFailure(
- op, "compatible result type doesn't match operand type");
if (!isa<LLVM::LLVMArrayType>(operandType)) {
rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
>From 738abbc077d2a47cd2e99f522ad148c8baed97ad Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Wed, 28 May 2025 14:53:41 +0000
Subject: [PATCH 3/5] [fixup] Improve code consistency with other lowerings
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 14 ++++++--------
1 file changed, 6 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index b8097d21dc227..de8c76cfb28c6 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -89,19 +89,17 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
const auto &typeConverter = *this->getTypeConverter();
- if (!LLVM::isCompatibleType(resultType)) {
- resultType = typeConverter.convertType(resultType);
- if (!resultType)
- return failure();
- }
+ auto llvmResultType = typeConverter.convertType(resultType);
+ if (!llvmResultType)
+ return failure();
if (!isa<LLVM::LLVMArrayType>(operandType)) {
- rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
- false);
+ rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
+ adaptor.getOperand(), false);
return success();
}
- auto vectorType = dyn_cast<VectorType>(resultType);
+ auto vectorType = dyn_cast<VectorType>(llvmResultType);
if (!vectorType)
return failure();
>From aa4ce6f2ad4fd244efbedd88c1b69c82aa048692 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Wed, 28 May 2025 17:12:50 +0000
Subject: [PATCH 4/5] [fixup] Convert the operand type separately
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index de8c76cfb28c6..c324f93a441aa 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -73,6 +73,8 @@ using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
using ATan2OpLowering =
ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
+// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
+// may be better to separate the patterns.
template <typename MathOp, typename LLVMOp>
struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
@@ -81,26 +83,25 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
LogicalResult
matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *this->getTypeConverter();
auto operandType = adaptor.getOperand().getType();
-
- if (!operandType || !LLVM::isCompatibleType(operandType))
+ auto llvmOperandType = typeConverter.convertType(operandType);
+ if (!llvmOperandType)
return failure();
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
- const auto &typeConverter = *this->getTypeConverter();
auto llvmResultType = typeConverter.convertType(resultType);
if (!llvmResultType)
return failure();
- if (!isa<LLVM::LLVMArrayType>(operandType)) {
+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
adaptor.getOperand(), false);
return success();
}
- auto vectorType = dyn_cast<VectorType>(llvmResultType);
- if (!vectorType)
+ if (!isa<VectorType>(llvmResultType))
return failure();
return LLVM::detail::handleMultidimensionalVectors(
>From 7c668846557bfd6ff08a9567129170923d93f7e5 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Thu, 29 May 2025 08:55:46 +0000
Subject: [PATCH 5/5] [fixup] Consistent function names in LIT
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index e4b79301fbaa5..ee388f1dbe898 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -29,9 +29,9 @@ func.func @absi(%arg0: i32) -> i32 {
// -----
-// CHECK-LABEL: func @absi_0d_vec(
+// CHECK-LABEL: func @absi_0dvector(
// CHECK-SAME: i32
-func.func @absi_0d_vec(%arg0 : vector<i32>) {
+func.func @absi_0dvector(%arg0 : vector<i32>) {
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
// CHECK: "llvm.intr.abs"(%[[CAST]]) <{is_int_min_poison = false}> : (vector<1xi32>) -> vector<1xi32>
%0 = math.absi %arg0 : vector<i32>
@@ -292,9 +292,9 @@ func.func @ctlz(%arg0 : i32) {
func.return
}
-// CHECK-LABEL: func @ctlz_0d_vec(
+// CHECK-LABEL: func @ctlz_0dvector(
// CHECK-SAME: i32
-func.func @ctlz_0d_vec(%arg0 : vector<i32>) {
+func.func @ctlz_0dvector(%arg0 : vector<i32>) {
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
// CHECK: "llvm.intr.ctlz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
%0 = math.ctlz %arg0 : vector<i32>
@@ -313,9 +313,9 @@ func.func @cttz(%arg0 : i32) {
// -----
-// CHECK-LABEL: func @cttz_0d_vec(
+// CHECK-LABEL: func @cttz_0dvector(
// CHECK-SAME: i32
-func.func @cttz_0d_vec(%arg0 : vector<i32>) {
+func.func @cttz_0dvector(%arg0 : vector<i32>) {
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
// CHECK: "llvm.intr.cttz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
%0 = math.cttz %arg0 : vector<i32>
More information about the Mlir-commits
mailing list