[Mlir-commits] [mlir] [mlir][math] Propagate scalability in `convert-math-to-llvm` (PR #82635)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Feb 22 07:43:44 PST 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/82635
This also generally increases the coverage of scalable vector types in the math-to-llvm tests.
>From 387d656e63f48a8234db3caead2051b824c7d4a0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 22 Feb 2024 15:39:11 +0000
Subject: [PATCH] [mlir][math] Propagate scalability in `convert-math-to-llvm`
---
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 18 ++---
.../Conversion/MathToLLVM/math-to-llvm.mlir | 81 +++++++++++++++++++
2 files changed, 90 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 1b729611a36235..23e957288eb95e 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -148,10 +148,10 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
+ auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
- floatType),
+ mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
+ {numElements.isScalable()}),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
@@ -207,10 +207,10 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
+ auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
- floatType),
+ mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
+ {numElements.isScalable()}),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
@@ -266,10 +266,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
+ auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
- floatType),
+ mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
+ {numElements.isScalable()}),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 3de2f11d1d12c7..ca8bba56ccd57c 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -77,6 +77,18 @@ func.func @log1p_2dvector_fmf(%arg0 : vector<4x3xf32>) {
// -----
+// CHECK-LABEL: func @log1p_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @log1p_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[VEC]] : vector<[4]xf32>
+ // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+ %0 = math.log1p %arg0 : vector<[4]xf32>
+ func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @expm1(
// CHECK-SAME: f32
func.func @expm1(%arg0 : f32) {
@@ -113,6 +125,18 @@ func.func @expm1_vector(%arg0 : vector<4xf32>) {
// -----
+// CHECK-LABEL: func @expm1_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @expm1_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[VEC]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<[4]xf32>
+ %0 = math.expm1 %arg0 : vector<[4]xf32>
+ func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @expm1_vector_fmf(
// CHECK-SAME: vector<4xf32>
func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) {
@@ -177,6 +201,16 @@ func.func @cttz_vec(%arg0 : vector<4xi32>) {
// -----
+// CHECK-LABEL: func @cttz_scalable_vec(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xi32>
+func.func @cttz_scalable_vec(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
+ // CHECK: "llvm.intr.cttz"(%[[VEC]]) <{is_zero_poison = false}> : (vector<[4]xi32>) -> vector<[4]xi32>
+ %0 = math.cttz %arg0 : vector<[4]xi32>
+ func.return %0 : vector<[4]xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @ctpop(
// CHECK-SAME: i32
func.func @ctpop(%arg0 : i32) {
@@ -197,6 +231,16 @@ func.func @ctpop_vector(%arg0 : vector<3xi32>) {
// -----
+// CHECK-LABEL: func @ctpop_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xi32>
+func.func @ctpop_scalable_vector(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
+ // CHECK: llvm.intr.ctpop(%[[VEC]]) : (vector<[4]xi32>) -> vector<[4]xi32>
+ %0 = math.ctpop %arg0 : vector<[4]xi32>
+ func.return %0 : vector<[4]xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt_double(
// CHECK-SAME: f64
func.func @rsqrt_double(%arg0 : f64) {
@@ -233,6 +277,18 @@ func.func @rsqrt_vector(%arg0 : vector<4xf32>) {
// -----
+// CHECK-LABEL: func @rsqrt_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @rsqrt_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+  // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[VEC]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<[4]xf32>
+  %0 = math.rsqrt %arg0 : vector<[4]xf32>
+  func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt_vector_fmf(
// CHECK-SAME: vector<4xf32>
func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) {
@@ -245,6 +301,18 @@ func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) {
// -----
+// CHECK-LABEL: func @rsqrt_scalable_vector_fmf(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @rsqrt_scalable_vector_fmf(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[VEC]]) {fastmathFlags = #llvm.fastmath<fast>} : (vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath<fast>} : vector<[4]xf32>
+ %0 = math.rsqrt %arg0 fastmath<fast> : vector<[4]xf32>
+ func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt_multidim_vector(
func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
// CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
@@ -258,6 +326,19 @@ func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
// -----
+// CHECK-LABEL: func @rsqrt_scalable_multidim_vector(
+func.func @rsqrt_scalable_multidim_vector(%arg0 : vector<4x[4]xf32>) -> vector<4x[4]xf32> {
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<[4]xf32>>
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+  // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[EXTRACT]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<[4]xf32>
+  // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %{{.*}}[0] : !llvm.array<4 x vector<[4]xf32>>
+  %0 = math.rsqrt %arg0 : vector<4x[4]xf32>
+  func.return %0 : vector<4x[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fpowi(
// CHECK-SAME: f64
func.func @fpowi(%arg0 : f64, %arg1 : i32) {
More information about the Mlir-commits
mailing list