[Mlir-commits] [mlir] [mlir][math] Fix intrinsic conversions to LLVM for 0D-vector types (PR #141020)

Artem Gindinson llvmlistbot at llvm.org
Thu May 22 01:32:07 PDT 2025


https://github.com/AGindinson created https://github.com/llvm/llvm-project/pull/141020

`vector<t>` (0-D vector) types are not compatible with the LLVM type system and must be explicitly converted into `vector<1xt>` when lowering. Apply this rule in the conversion pattern that lowers `math.ctlz`, `math.cttz` and `math.absi` to LLVM intrinsics.

>From 54c997b027994dad412c01241a86fe993eb92e81 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Thu, 22 May 2025 08:40:28 +0200
Subject: [PATCH] [mlir][math] Fix intrinsic conversions to LLVM for 0D-vector
 types

`vector<t>` (0-D vector) types are not compatible with the LLVM type system
and must be explicitly converted into `vector<1xt>` when lowering. Apply this
rule in the conversion pattern that lowers `math.ctlz`, `math.cttz` and
`math.absi` to LLVM intrinsics.

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 11 ++++++-
 .../Conversion/MathToLLVM/math-to-llvm.mlir   | 33 +++++++++++++++++++
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 97da96afac4cd..19cd960b15294 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -84,6 +84,15 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
 
     auto loc = op.getLoc();
     auto resultType = op.getResult().getType();
+    const auto &typeConverter = *this->getTypeConverter();
+    if (!LLVM::isCompatibleType(resultType)) {
+      resultType = typeConverter.convertType(resultType);
+      if (!resultType)
+        return failure();
+    }
+    if (operandType != resultType)
+      return rewriter.notifyMatchFailure(
+          op, "compatible result type doesn't match operand type");
 
     if (!isa<LLVM::LLVMArrayType>(operandType)) {
       rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
@@ -96,7 +105,7 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
       return failure();
 
     return LLVM::detail::handleMultidimensionalVectors(
-        op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
+        op.getOperation(), adaptor.getOperands(), typeConverter,
         [&](Type llvm1DVectorTy, ValueRange operands) {
           return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
                                          false);
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 974743a55932b..73325a3fd913e 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -19,6 +19,8 @@ func.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) {
 
 // -----
 
+// CHECK-LABEL: func @absi(
+// CHECK-SAME: i32
 func.func @absi(%arg0: i32) -> i32 {
   // CHECK: = "llvm.intr.abs"(%{{.*}}) <{is_int_min_poison = false}> : (i32) -> i32
   %0 = math.absi %arg0 : i32
@@ -27,6 +29,17 @@ func.func @absi(%arg0: i32) -> i32 {
 
 // -----
 
+// CHECK-LABEL: func @absi_0d_vec(
+// CHECK-SAME: i32
+func.func @absi_0d_vec(%arg0 : vector<i32>) {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+  // CHECK: "llvm.intr.abs"(%[[CAST]]) <{is_int_min_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+  %0 = math.absi %arg0 : vector<i32>
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @log1p(
 // CHECK-SAME: f32
 func.func @log1p(%arg0 : f32) {
@@ -201,6 +214,15 @@ func.func @ctlz(%arg0 : i32) {
   func.return
 }
 
+// CHECK-LABEL: func @ctlz_0d_vec(
+// CHECK-SAME: i32
+func.func @ctlz_0d_vec(%arg0 : vector<i32>) {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+  // CHECK: "llvm.intr.ctlz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+  %0 = math.ctlz %arg0 : vector<i32>
+  func.return
+}
+
 // -----
 
 // CHECK-LABEL: func @cttz(
@@ -213,6 +235,17 @@ func.func @cttz(%arg0 : i32) {
 
 // -----
 
+// CHECK-LABEL: func @cttz_0d_vec(
+// CHECK-SAME: i32
+func.func @cttz_0d_vec(%arg0 : vector<i32>) {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+  // CHECK: "llvm.intr.cttz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+  %0 = math.cttz %arg0 : vector<i32>
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @cttz_vec(
 // CHECK-SAME: i32
 func.func @cttz_vec(%arg0 : vector<4xi32>) {



More information about the Mlir-commits mailing list