[Mlir-commits] [mlir] [mlir][math] Fix intrinsic conversions to LLVM for 0D-vector types (PR #141020)

Artem Gindinson llvmlistbot at llvm.org
Mon Jun 2 02:57:39 PDT 2025


https://github.com/AGindinson updated https://github.com/llvm/llvm-project/pull/141020

>From 54c997b027994dad412c01241a86fe993eb92e81 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Thu, 22 May 2025 08:40:28 +0200
Subject: [PATCH 1/8] [mlir][math] Fix intrinsic conversions to LLVM for
 0D-vector types

0-D `vector<t>` types are not compatible with the LLVM type system and must be
explicitly converted into `vector<1xt>` when lowering. Apply this rule in the
conversion pattern for the `math.ctlz`, `math.cttz` and `math.absi` ops, which
lower to LLVM intrinsics.

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 11 ++++++-
 .../Conversion/MathToLLVM/math-to-llvm.mlir   | 33 +++++++++++++++++++
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 97da96afac4cd..19cd960b15294 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -84,6 +84,15 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
 
     auto loc = op.getLoc();
     auto resultType = op.getResult().getType();
+    const auto &typeConverter = *this->getTypeConverter();
+    if (!LLVM::isCompatibleType(resultType)) {
+      resultType = typeConverter.convertType(resultType);
+      if (!resultType)
+        return failure();
+    }
+    if (operandType != resultType)
+      return rewriter.notifyMatchFailure(
+          op, "compatible result type doesn't match operand type");
 
     if (!isa<LLVM::LLVMArrayType>(operandType)) {
       rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
@@ -96,7 +105,7 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
       return failure();
 
     return LLVM::detail::handleMultidimensionalVectors(
-        op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
+        op.getOperation(), adaptor.getOperands(), typeConverter,
         [&](Type llvm1DVectorTy, ValueRange operands) {
           return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
                                          false);
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 974743a55932b..73325a3fd913e 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -19,6 +19,8 @@ func.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) {
 
 // -----
 
+// CHECK-LABEL: func @absi(
+// CHECK-SAME: i32
 func.func @absi(%arg0: i32) -> i32 {
   // CHECK: = "llvm.intr.abs"(%{{.*}}) <{is_int_min_poison = false}> : (i32) -> i32
   %0 = math.absi %arg0 : i32
@@ -27,6 +29,17 @@ func.func @absi(%arg0: i32) -> i32 {
 
 // -----
 
+// CHECK-LABEL: func @absi_0d_vec(
+// CHECK-SAME: i32
+func.func @absi_0d_vec(%arg0 : vector<i32>) {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+  // CHECK: "llvm.intr.abs"(%[[CAST]]) <{is_int_min_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+  %0 = math.absi %arg0 : vector<i32>
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @log1p(
 // CHECK-SAME: f32
 func.func @log1p(%arg0 : f32) {
@@ -201,6 +214,15 @@ func.func @ctlz(%arg0 : i32) {
   func.return
 }
 
+// CHECK-LABEL: func @ctlz_0d_vec(
+// CHECK-SAME: i32
+func.func @ctlz_0d_vec(%arg0 : vector<i32>) {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+  // CHECK: "llvm.intr.ctlz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+  %0 = math.ctlz %arg0 : vector<i32>
+  func.return
+}
+
 // -----
 
 // CHECK-LABEL: func @cttz(
@@ -213,6 +235,17 @@ func.func @cttz(%arg0 : i32) {
 
 // -----
 
+// CHECK-LABEL: func @cttz_0d_vec(
+// CHECK-SAME: i32
+func.func @cttz_0d_vec(%arg0 : vector<i32>) {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+  // CHECK: "llvm.intr.cttz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+  %0 = math.cttz %arg0 : vector<i32>
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @cttz_vec(
 // CHECK-SAME: i32
 func.func @cttz_vec(%arg0 : vector<4xi32>) {

>From 0e56cad016bd0588d06b83a114017c1d2193ee21 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Wed, 28 May 2025 14:47:51 +0000
Subject: [PATCH 2/8] [fixup] Drop obsolete compatibility check

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index bbc17739e7a98..b8097d21dc227 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -94,9 +94,6 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
       if (!resultType)
         return failure();
     }
-    if (operandType != resultType)
-      return rewriter.notifyMatchFailure(
-          op, "compatible result type doesn't match operand type");
 
     if (!isa<LLVM::LLVMArrayType>(operandType)) {
       rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),

>From 738abbc077d2a47cd2e99f522ad148c8baed97ad Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Wed, 28 May 2025 14:53:41 +0000
Subject: [PATCH 3/8] [fixup] Improve code consistency with other lowerings

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index b8097d21dc227..de8c76cfb28c6 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -89,19 +89,17 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
     auto loc = op.getLoc();
     auto resultType = op.getResult().getType();
     const auto &typeConverter = *this->getTypeConverter();
-    if (!LLVM::isCompatibleType(resultType)) {
-      resultType = typeConverter.convertType(resultType);
-      if (!resultType)
-        return failure();
-    }
+    auto llvmResultType = typeConverter.convertType(resultType);
+    if (!llvmResultType)
+      return failure();
 
     if (!isa<LLVM::LLVMArrayType>(operandType)) {
-      rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
-                                          false);
+      rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
+                                          adaptor.getOperand(), false);
       return success();
     }
 
-    auto vectorType = dyn_cast<VectorType>(resultType);
+    auto vectorType = dyn_cast<VectorType>(llvmResultType);
     if (!vectorType)
       return failure();
 

>From aa4ce6f2ad4fd244efbedd88c1b69c82aa048692 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Wed, 28 May 2025 17:12:50 +0000
Subject: [PATCH 4/8] [fixup] Convert the operand type separately

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index de8c76cfb28c6..c324f93a441aa 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -73,6 +73,8 @@ using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
 using ATan2OpLowering =
     ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
+// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
+// may be better to separate the patterns.
 template <typename MathOp, typename LLVMOp>
 struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
   using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
@@ -81,26 +83,25 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
   LogicalResult
   matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *this->getTypeConverter();
     auto operandType = adaptor.getOperand().getType();
-
-    if (!operandType || !LLVM::isCompatibleType(operandType))
+    auto llvmOperandType = typeConverter.convertType(operandType);
+    if (!llvmOperandType)
       return failure();
 
     auto loc = op.getLoc();
     auto resultType = op.getResult().getType();
-    const auto &typeConverter = *this->getTypeConverter();
     auto llvmResultType = typeConverter.convertType(resultType);
     if (!llvmResultType)
       return failure();
 
-    if (!isa<LLVM::LLVMArrayType>(operandType)) {
+    if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
       rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
                                           adaptor.getOperand(), false);
       return success();
     }
 
-    auto vectorType = dyn_cast<VectorType>(llvmResultType);
-    if (!vectorType)
+    if (!isa<VectorType>(llvmResultType))
       return failure();
 
     return LLVM::detail::handleMultidimensionalVectors(

>From 7c668846557bfd6ff08a9567129170923d93f7e5 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Thu, 29 May 2025 08:55:46 +0000
Subject: [PATCH 5/8] [fixup] Consistent function names in LIT

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index e4b79301fbaa5..ee388f1dbe898 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -29,9 +29,9 @@ func.func @absi(%arg0: i32) -> i32 {
 
 // -----
 
-// CHECK-LABEL: func @absi_0d_vec(
+// CHECK-LABEL: func @absi_0dvector(
 // CHECK-SAME: i32
-func.func @absi_0d_vec(%arg0 : vector<i32>) {
+func.func @absi_0dvector(%arg0 : vector<i32>) {
   // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
   // CHECK: "llvm.intr.abs"(%[[CAST]]) <{is_int_min_poison = false}> : (vector<1xi32>) -> vector<1xi32>
   %0 = math.absi %arg0 : vector<i32>
@@ -292,9 +292,9 @@ func.func @ctlz(%arg0 : i32) {
   func.return
 }
 
-// CHECK-LABEL: func @ctlz_0d_vec(
+// CHECK-LABEL: func @ctlz_0dvector(
 // CHECK-SAME: i32
-func.func @ctlz_0d_vec(%arg0 : vector<i32>) {
+func.func @ctlz_0dvector(%arg0 : vector<i32>) {
   // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
   // CHECK: "llvm.intr.ctlz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
   %0 = math.ctlz %arg0 : vector<i32>
@@ -313,9 +313,9 @@ func.func @cttz(%arg0 : i32) {
 
 // -----
 
-// CHECK-LABEL: func @cttz_0d_vec(
+// CHECK-LABEL: func @cttz_0dvector(
 // CHECK-SAME: i32
-func.func @cttz_0d_vec(%arg0 : vector<i32>) {
+func.func @cttz_0dvector(%arg0 : vector<i32>) {
   // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
   // CHECK: "llvm.intr.cttz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
   %0 = math.cttz %arg0 : vector<i32>

>From c035cbe1a5ffea4c5bed9b4cfdd7fc16e834ee2e Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Thu, 29 May 2025 09:38:46 +0000
Subject: [PATCH 6/8] [fixup] LIT func arg checks

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index ee388f1dbe898..95f5debf35119 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -30,7 +30,7 @@ func.func @absi(%arg0: i32) -> i32 {
 // -----
 
 // CHECK-LABEL: func @absi_0dvector(
-// CHECK-SAME: i32
+// CHECK-SAME: vector<i32>
 func.func @absi_0dvector(%arg0 : vector<i32>) {
   // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
   // CHECK: "llvm.intr.abs"(%[[CAST]]) <{is_int_min_poison = false}> : (vector<1xi32>) -> vector<1xi32>
@@ -293,7 +293,7 @@ func.func @ctlz(%arg0 : i32) {
 }
 
 // CHECK-LABEL: func @ctlz_0dvector(
-// CHECK-SAME: i32
+// CHECK-SAME: vector<i32>
 func.func @ctlz_0dvector(%arg0 : vector<i32>) {
   // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
   // CHECK: "llvm.intr.ctlz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
@@ -314,7 +314,7 @@ func.func @cttz(%arg0 : i32) {
 // -----
 
 // CHECK-LABEL: func @cttz_0dvector(
-// CHECK-SAME: i32
+// CHECK-SAME: vector<i32>
 func.func @cttz_0dvector(%arg0 : vector<i32>) {
   // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
   // CHECK: "llvm.intr.cttz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>

>From b8cb54986fe3eab885e484644f006b14f195de7c Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Thu, 29 May 2025 09:14:02 +0000
Subject: [PATCH 7/8] [fixup] implementation for expm1, log1p, rsqrt

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 82 ++++++++++---------
 .../Conversion/MathToLLVM/math-to-llvm.mlir   | 39 +++++++++
 2 files changed, 83 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index c324f93a441aa..5f2b00c1b95dd 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -128,40 +128,42 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
   LogicalResult
   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *this->getTypeConverter();
     auto operandType = adaptor.getOperand().getType();
-
-    if (!operandType || !LLVM::isCompatibleType(operandType))
+    auto llvmOperandType = typeConverter.convertType(operandType);
+    if (!llvmOperandType)
       return failure();
 
     auto loc = op.getLoc();
     auto resultType = op.getResult().getType();
-    auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
+    auto floatType = cast<FloatType>(
+        typeConverter.convertType(getElementTypeOrSelf(resultType)));
     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
     ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
     ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
 
-    if (!isa<LLVM::LLVMArrayType>(operandType)) {
+    if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
       LLVM::ConstantOp one;
-      if (LLVM::isCompatibleVectorType(operandType)) {
+      if (LLVM::isCompatibleVectorType(llvmOperandType)) {
         one = rewriter.create<LLVM::ConstantOp>(
-            loc, operandType,
-            SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
+            loc, llvmOperandType,
+            SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
+                                   floatOne));
       } else {
-        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+        one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
       }
       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
                                               expAttrs.getAttrs());
       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
-          op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
+          op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
       return success();
     }
 
-    auto vectorType = dyn_cast<VectorType>(resultType);
-    if (!vectorType)
+    if (!isa<VectorType>(resultType))
       return rewriter.notifyMatchFailure(op, "expected vector result type");
 
     return LLVM::detail::handleMultidimensionalVectors(
-        op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+        op.getOperation(), adaptor.getOperands(), typeConverter,
         [&](Type llvm1DVectorTy, ValueRange operands) {
           auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
           auto splatAttr = SplatElementsAttr::get(
@@ -186,41 +188,43 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
   LogicalResult
   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *this->getTypeConverter();
     auto operandType = adaptor.getOperand().getType();
-
-    if (!operandType || !LLVM::isCompatibleType(operandType))
+    auto llvmOperandType = typeConverter.convertType(operandType);
+    if (!llvmOperandType)
       return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
     auto loc = op.getLoc();
     auto resultType = op.getResult().getType();
-    auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
+    auto floatType = cast<FloatType>(
+        typeConverter.convertType(getElementTypeOrSelf(resultType)));
     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
     ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
     ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
 
-    if (!isa<LLVM::LLVMArrayType>(operandType)) {
+    if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
       LLVM::ConstantOp one =
-          LLVM::isCompatibleVectorType(operandType)
+          isa<VectorType>(llvmOperandType)
               ? rewriter.create<LLVM::ConstantOp>(
-                    loc, operandType,
-                    SplatElementsAttr::get(cast<ShapedType>(resultType),
+                    loc, llvmOperandType,
+                    SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
                                            floatOne))
-              : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+              : rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType,
+                                                  floatOne);
 
       auto add = rewriter.create<LLVM::FAddOp>(
-          loc, operandType, ValueRange{one, adaptor.getOperand()},
+          loc, llvmOperandType, ValueRange{one, adaptor.getOperand()},
           addAttrs.getAttrs());
-      rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
-                                               logAttrs.getAttrs());
+      rewriter.replaceOpWithNewOp<LLVM::LogOp>(
+          op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
       return success();
     }
 
-    auto vectorType = dyn_cast<VectorType>(resultType);
-    if (!vectorType)
+    if (!isa<VectorType>(resultType))
       return rewriter.notifyMatchFailure(op, "expected vector result type");
 
     return LLVM::detail::handleMultidimensionalVectors(
-        op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+        op.getOperation(), adaptor.getOperands(), typeConverter,
         [&](Type llvm1DVectorTy, ValueRange operands) {
           auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
           auto splatAttr = SplatElementsAttr::get(
@@ -246,40 +250,42 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
   LogicalResult
   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *this->getTypeConverter();
     auto operandType = adaptor.getOperand().getType();
-
-    if (!operandType || !LLVM::isCompatibleType(operandType))
+    auto llvmOperandType = typeConverter.convertType(operandType);
+    if (!llvmOperandType)
       return failure();
 
     auto loc = op.getLoc();
     auto resultType = op.getResult().getType();
-    auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
+    auto floatType = cast<FloatType>(
+        typeConverter.convertType(getElementTypeOrSelf(resultType)));
     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
     ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
     ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
 
-    if (!isa<LLVM::LLVMArrayType>(operandType)) {
+    if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
       LLVM::ConstantOp one;
-      if (LLVM::isCompatibleVectorType(operandType)) {
+      if (isa<VectorType>(llvmOperandType)) {
         one = rewriter.create<LLVM::ConstantOp>(
-            loc, operandType,
-            SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
+            loc, llvmOperandType,
+            SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
+                                   floatOne));
       } else {
-        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+        one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
       }
       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
                                                 sqrtAttrs.getAttrs());
       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
-          op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
+          op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
       return success();
     }
 
-    auto vectorType = dyn_cast<VectorType>(resultType);
-    if (!vectorType)
+    if (!isa<VectorType>(resultType))
       return failure();
 
     return LLVM::detail::handleMultidimensionalVectors(
-        op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+        op.getOperation(), adaptor.getOperands(), typeConverter,
         [&](Type llvm1DVectorTy, ValueRange operands) {
           auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
           auto splatAttr = SplatElementsAttr::get(
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 95f5debf35119..461c74997d872 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -102,6 +102,19 @@ func.func @log1p_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
 
 // -----
 
+// CHECK-LABEL: func @log1p_0dvector(
+// CHECK-SAME: vector<f32>
+func.func @log1p_0dvector(%arg0 : vector<f32>) {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<1xf32>) : vector<1xf32>
+  // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[CAST]]  : vector<1xf32>
+  // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]])  : (vector<1xf32>) -> vector<1xf32>
+  %0 = math.log1p %arg0 : vector<f32>
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @expm1(
 // CHECK-SAME: f32
 func.func @expm1(%arg0 : f32) {
@@ -162,6 +175,19 @@ func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @expm1_0dvector(
+// CHECK-SAME: vector<f32>
+func.func @expm1_0dvector(%arg0 : vector<f32>) {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<1xf32>) : vector<1xf32>
+  // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[CAST]]) : (vector<1xf32>) -> vector<1xf32>
+  // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<1xf32>
+  %0 = math.expm1 %arg0 : vector<f32>
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt(
 // CHECK-SAME: f32
 func.func @rsqrt(%arg0 : f32) {
@@ -174,6 +200,19 @@ func.func @rsqrt(%arg0 : f32) {
 
 // -----
 
+// CHECK-LABEL: func @rsqrt_0dvector(
+// CHECK-SAME: vector<f32>
+func.func @rsqrt_0dvector(%arg0 : vector<f32>) {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<1xf32>) : vector<1xf32>
+  // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[CAST]]) : (vector<1xf32>) -> vector<1xf32>
+  // CHECK: %[[SUB:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<1xf32>
+  %0 = math.rsqrt %arg0 : vector<f32>
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @trigonometrics
 // CHECK-SAME: [[ARG0:%.+]]: f32
 func.func @trigonometrics(%arg0: f32) {

>From 75923816afe2b75c1ca69c074ca701f1a193fb70 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Thu, 29 May 2025 09:40:06 +0000
Subject: [PATCH 8/8] [fixup] isNaN, isFinite implementation

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 20 ++++++++++-------
 .../Conversion/MathToLLVM/math-to-llvm.mlir   | 22 +++++++++++++++++++
 2 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 5f2b00c1b95dd..f4d69ce8235bb 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -309,13 +309,15 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
   LogicalResult
   matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto operandType = adaptor.getOperand().getType();
-
-    if (!operandType || !LLVM::isCompatibleType(operandType))
+    const auto &typeConverter = *this->getTypeConverter();
+    auto operandType =
+        typeConverter.convertType(adaptor.getOperand().getType());
+    auto resultType = typeConverter.convertType(op.getResult().getType());
+    if (!operandType || !resultType)
       return failure();
 
     rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
-        op, op.getType(), adaptor.getOperand(), llvm::fcNan);
+        op, resultType, adaptor.getOperand(), llvm::fcNan);
     return success();
   }
 };
@@ -326,13 +328,15 @@ struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
   LogicalResult
   matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto operandType = adaptor.getOperand().getType();
-
-    if (!operandType || !LLVM::isCompatibleType(operandType))
+    const auto &typeConverter = *this->getTypeConverter();
+    auto operandType =
+        typeConverter.convertType(adaptor.getOperand().getType());
+    auto resultType = typeConverter.convertType(op.getResult().getType());
+    if (!operandType || !resultType)
       return failure();
 
     rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
-        op, op.getType(), adaptor.getOperand(), llvm::fcFinite);
+        op, resultType, adaptor.getOperand(), llvm::fcFinite);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 461c74997d872..92904082a6f46 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -423,6 +423,17 @@ func.func @isnan_double(%arg0 : f64) {
 
 // -----
 
+// CHECK-LABEL: func @isnan_0dvector(
+// CHECK-SAME: vector<f32>
+func.func @isnan_0dvector(%arg0 : vector<f32>) {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
+  // CHECK: "llvm.intr.is.fpclass"(%[[CAST]]) <{bit = 3 : i32}> : (vector<1xf32>) -> vector<1xi1>
+  %0 = math.isnan %arg0 : vector<f32>
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @isfinite_double(
 // CHECK-SAME: f64
 func.func @isfinite_double(%arg0 : f64) {
@@ -433,6 +444,17 @@ func.func @isfinite_double(%arg0 : f64) {
 
 // -----
 
+// CHECK-LABEL: func @isfinite_0dvector(
+// CHECK-SAME: vector<f32>
+func.func @isfinite_0dvector(%arg0 : vector<f32>) {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
+  // CHECK: "llvm.intr.is.fpclass"(%[[CAST]]) <{bit = 504 : i32}> : (vector<1xf32>) -> vector<1xi1>
+  %0 = math.isfinite %arg0 : vector<f32>
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt_double(
 // CHECK-SAME: f64
 func.func @rsqrt_double(%arg0 : f64) {



More information about the Mlir-commits mailing list