[Mlir-commits] [mlir] [math] lower rsqrt to sqrt + fdiv (PR #91344)
Corentin Ferry
llvmlistbot at llvm.org
Tue May 7 07:56:09 PDT 2024
https://github.com/cferry-AMD created https://github.com/llvm/llvm-project/pull/91344
This commit adds an expansion pattern that lowers `math.rsqrt(x)` into `arith.divf(1.0, math.sqrt(x))`.
>From d4601688da84e38d383e4a5cd5dc9cfc6b2dc1fe Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 15 Apr 2024 10:48:43 +0200
Subject: [PATCH] [math] lower rsqrt to sqrt + fdiv
---
.../mlir/Dialect/Math/Transforms/Passes.h | 1 +
.../Math/Transforms/ExpandPatterns.cpp | 22 ++++++++++
mlir/test/Dialect/Math/expand-math.mlir | 42 +++++++++++++++++++
mlir/test/lib/Dialect/Math/TestExpandMath.cpp | 1 +
4 files changed, 66 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 24e6d9a8d98e..ba6977251564 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -42,6 +42,7 @@ void populateExpandPowFPattern(RewritePatternSet &patterns);
void populateExpandFPowIPattern(RewritePatternSet &patterns);
void populateExpandRoundFPattern(RewritePatternSet &patterns);
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
+void populateExpandRsqrtPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 5ccf3b6d72a2..05d32ad2bc3e 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -615,6 +615,24 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
return success();
}
+// Convert `math.rsqrt(x)` into `arith.divf(1.0, math.sqrt(x))`.
+static LogicalResult convertRsqrtOp(math::RsqrtOp op,
+                                    PatternRewriter &rewriter) {
+  Value operand = op.getOperand();
+  Type operandTy = operand.getType();
+  // For shaped types the `1.0` constant is a dense splat, and dense elements
+  // attributes need a static shape: bail out on e.g. `tensor<?xf32>`.
+  auto shapedTy = dyn_cast<ShapedType>(operandTy);
+  if ((shapedTy && !shapedTy.hasStaticShape()) ||
+      !isa<FloatType>(getElementTypeOrSelf(operandTy)))
+    return failure();
+  Location loc = op->getLoc();
+  Value one = createFloatConst(loc, operandTy, 1.0, rewriter);
+  Value sqrtVal = rewriter.create<math::SqrtOp>(loc, operand);
+  rewriter.replaceOpWithNewOp<arith::DivFOp>(op, one, sqrtVal);
+  return success();
+}
+
void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
patterns.add(convertCtlzOp);
}
@@ -678,3 +696,7 @@ void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
patterns.add(convertRoundEvenOp);
}
+
+void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertRsqrtOp); // rsqrt(x) -> 1.0 / sqrt(x)
+}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 3d94b55126d0..d25f4e571e6a 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -658,3 +658,45 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
// CHECK: return %[[SEL]] : f32
+
+// -----
+
+// CHECK-LABEL: func.func @rsqrt
+// CHECK-SAME:    (%[[ARG:.*]]: f32)
+// CHECK-SAME:    -> f32
+// CHECK-DAG:     %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:     %[[SQRT:.*]] = math.sqrt %[[ARG]] : f32
+// CHECK:         %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f32
+// CHECK:         return %[[DIV]] : f32
+func.func @rsqrt(%x: f32) -> f32 {
+  %0 = math.rsqrt %x : f32
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rsqrt_vec
+// CHECK-SAME:    (%[[ARG:.*]]: vector<5xf32>)
+// CHECK-SAME:    -> vector<5xf32>
+// CHECK-DAG:     %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<5xf32>
+// CHECK-DAG:     %[[SQRT:.*]] = math.sqrt %[[ARG]] : vector<5xf32>
+// CHECK:         %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : vector<5xf32>
+// CHECK:         return %[[DIV]] : vector<5xf32>
+func.func @rsqrt_vec(%v: vector<5xf32>) -> vector<5xf32> {
+  %0 = math.rsqrt %v : vector<5xf32>
+  return %0 : vector<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rsqrt_tns
+// CHECK-SAME:    (%[[ARG:.*]]: tensor<5x8xf32>)
+// CHECK-SAME:    -> tensor<5x8xf32>
+// CHECK-DAG:     %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<5x8xf32>
+// CHECK-DAG:     %[[SQRT:.*]] = math.sqrt %[[ARG]] : tensor<5x8xf32>
+// CHECK:         %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : tensor<5x8xf32>
+// CHECK:         return %[[DIV]] : tensor<5x8xf32>
+func.func @rsqrt_tns(%t: tensor<5x8xf32>) -> tensor<5x8xf32> {
+  %0 = math.rsqrt %t : tensor<5x8xf32>
+  return %0 : tensor<5x8xf32>
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index da48ccb6e5e0..69af2a08b97b 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -52,6 +52,7 @@ void TestExpandMathPass::runOnOperation() {
populateExpandFPowIPattern(patterns);
populateExpandRoundFPattern(patterns);
populateExpandRoundEvenPattern(patterns);
+ populateExpandRsqrtPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
More information about the Mlir-commits
mailing list