[Mlir-commits] [mlir] [math] lower rsqrt to sqrt + fdiv (PR #91344)

Corentin Ferry llvmlistbot at llvm.org
Wed May 8 01:25:32 PDT 2024


https://github.com/cferry-AMD updated https://github.com/llvm/llvm-project/pull/91344

>From d4601688da84e38d383e4a5cd5dc9cfc6b2dc1fe Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 15 Apr 2024 10:48:43 +0200
Subject: [PATCH 1/2] [math] lower rsqrt to sqrt + fdiv

---
 .../mlir/Dialect/Math/Transforms/Passes.h     |  1 +
 .../Math/Transforms/ExpandPatterns.cpp        | 22 ++++++++++
 mlir/test/Dialect/Math/expand-math.mlir       | 42 +++++++++++++++++++
 mlir/test/lib/Dialect/Math/TestExpandMath.cpp |  1 +
 4 files changed, 66 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 24e6d9a8d98e0..ba69772515647 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -42,6 +42,7 @@ void populateExpandPowFPattern(RewritePatternSet &patterns);
 void populateExpandFPowIPattern(RewritePatternSet &patterns);
 void populateExpandRoundFPattern(RewritePatternSet &patterns);
 void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
+void populateExpandRsqrtPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 5ccf3b6d72a2c..05d32ad2bc3eb 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -615,6 +615,24 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
   return success();
 }
 
+// Convert float `math.rsqrt` into `arith.divf` + `math.sqrt`: 1 / sqrt(x).
+static LogicalResult convertRsqrtOp(math::RsqrtOp op,
+                                    PatternRewriter &rewriter) {
+  // NOTE(review): any fastmath flags on the rsqrt op are not forwarded.
+  auto operand = op.getOperand();
+  auto operandTy = operand.getType();
+  auto eTy = getElementTypeOrSelf(operandTy);
+  if (!isa<FloatType>(eTy))
+    return failure();
+
+  Location loc = op->getLoc();
+  auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
+  auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op->getOperand(0));
+  rewriter.replaceOpWithNewOp<arith::DivFOp>(op, operandTy,
+                                             ValueRange{constOneFloat, sqrtOp});
+  return success();
+}
+
 void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
   patterns.add(convertCtlzOp);
 }
@@ -678,3 +696,7 @@ void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
 void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
   patterns.add(convertRoundEvenOp);
 }
+
+void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
+  patterns.add(convertRsqrtOp);
+}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 3d94b55126d09..d25f4e571e6a8 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -658,3 +658,45 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
 // CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
 // CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
 // CHECK:       return %[[SEL]] : f32
+
+// -----
+
+// CHECK-LABEL:   func.func @rsqrt
+// CHECK-SAME:     (%[[ARG:.*]]: f32)
+// CHECK-SAME:    -> f32
+// CHECK-DAG:     %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:     %[[SQRT:.*]] = math.sqrt %[[ARG]] : f32
+// CHECK-DAG:     %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f32
+// CHECK:         return %[[DIV]] : f32
+func.func @rsqrt(%float: f32) -> (f32)  {
+  %float_result = math.rsqrt %float : f32
+  return %float_result : f32
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @rsqrt_vec
+// CHECK-SAME:     (%[[ARG:.*]]: vector<5xf32>)
+// CHECK-SAME:    -> vector<5xf32>
+// CHECK-DAG:     %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<5xf32>
+// CHECK-DAG:     %[[SQRT:.*]] = math.sqrt %[[ARG]] : vector<5xf32>
+// CHECK-DAG:     %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : vector<5xf32>
+// CHECK:         return %[[DIV]] : vector<5xf32>
+func.func @rsqrt_vec(%float: vector<5xf32>) -> (vector<5xf32>)  {
+  %float_result = math.rsqrt %float : vector<5xf32>
+  return %float_result : vector<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @rsqrt_tns
+// CHECK-SAME:     (%[[ARG:.*]]: tensor<5x8xf32>)
+// CHECK-SAME:    -> tensor<5x8xf32>
+// CHECK-DAG:     %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<5x8xf32>
+// CHECK-DAG:     %[[SQRT:.*]] = math.sqrt %[[ARG]] : tensor<5x8xf32>
+// CHECK-DAG:     %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : tensor<5x8xf32>
+// CHECK:         return %[[DIV]] : tensor<5x8xf32>
+func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>)  {
+  %float_result = math.rsqrt %float : tensor<5x8xf32>
+  return %float_result : tensor<5x8xf32>
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index da48ccb6e5e08..69af2a08b97bd 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -52,6 +52,7 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandFPowIPattern(patterns);
   populateExpandRoundFPattern(patterns);
   populateExpandRoundEvenPattern(patterns);
+  populateExpandRsqrtPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 

>From 6484141ac4db72f2a03a6f714c9e4183d3e65fc5 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Wed, 8 May 2024 09:25:15 +0100
Subject: [PATCH 2/2] Address review comments (use op.getOperand(), add f16/f64 and runner tests)

---
 .../Math/Transforms/ExpandPatterns.cpp        |  2 +-
 mlir/test/Dialect/Math/expand-math.mlir       | 30 +++++++++++++-
 .../test-expand-math-approx.mlir              | 41 +++++++++++++++++++
 3 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 05d32ad2bc3eb..4d2b6c5638c1f 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -627,7 +627,7 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
 
   Location loc = op->getLoc();
   auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
-  auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op->getOperand(0));
+  auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op.getOperand());
   rewriter.replaceOpWithNewOp<arith::DivFOp>(op, operandTy,
                                              ValueRange{constOneFloat, sqrtOp});
   return success();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index d25f4e571e6a8..016a7bbdeb569 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -661,6 +661,20 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
 
 // -----
 
+// CHECK-LABEL:   func.func @rsqrt16
+// CHECK-SAME:     (%[[ARG:.*]]: f16)
+// CHECK-SAME:    -> f16
+// CHECK-DAG:     %[[CST:.*]] = arith.constant 1.000000e+00 : f16
+// CHECK-DAG:     %[[SQRT:.*]] = math.sqrt %[[ARG]] : f16
+// CHECK-DAG:     %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f16
+// CHECK:         return %[[DIV]] : f16
+func.func @rsqrt16(%float: f16) -> (f16)  {
+  %float_result = math.rsqrt %float : f16
+  return %float_result : f16
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @rsqrt
 // CHECK-SAME:     (%[[ARG:.*]]: f32)
 // CHECK-SAME:    -> f32
@@ -668,13 +682,27 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
 // CHECK-DAG:     %[[SQRT:.*]] = math.sqrt %[[ARG]] : f32
 // CHECK-DAG:     %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f32
 // CHECK:         return %[[DIV]] : f32
-func.func @rsqrt(%float: f32) -> (f32)  {
+func.func @rsqrt32(%float: f32) -> (f32)  {
   %float_result = math.rsqrt %float : f32
   return %float_result : f32
 }
 
 // -----
 
+// CHECK-LABEL:   func.func @rsqrt64
+// CHECK-SAME:     (%[[ARG:.*]]: f64)
+// CHECK-SAME:    -> f64
+// CHECK-DAG:     %[[CST:.*]] = arith.constant 1.000000e+00 : f64
+// CHECK-DAG:     %[[SQRT:.*]] = math.sqrt %[[ARG]] : f64
+// CHECK-DAG:     %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f64
+// CHECK:         return %[[DIV]] : f64
+func.func @rsqrt64(%float: f64) -> (f64)  {
+  %float_result = math.rsqrt %float : f64
+  return %float_result : f64
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @rsqrt_vec
 // CHECK-SAME:     (%[[ARG:.*]]: vector<5xf32>)
 // CHECK-SAME:    -> vector<5xf32>
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 2b72acde6a3bb..9b929b3c864dc 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -833,6 +833,46 @@ func.func @atanh() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// Rsqrt.
+// -------------------------------------------------------------------------- //
+
+func.func @rsqrt_f32(%a : f32) {
+  %r = math.rsqrt %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @rsqrt_3xf32(%a : vector<3xf32>) {
+  %r = math.rsqrt %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @rsqrt() {
+  // CHECK: 1
+  %one = arith.constant 1.0 : f32
+  call @rsqrt_f32(%one) : (f32) -> ()
+
+  // CHECK: 0.707107
+  %cst1 = arith.constant 2.0 : f32
+  call @rsqrt_f32(%cst1) : (f32) -> ()
+
+  // CHECK: inf
+  %zero = arith.constant 0.0 : f32
+  call @rsqrt_f32(%zero) : (f32) -> ()
+  // rsqrt(-1) is NaN; its printed sign is target-dependent, so omit it.
+  // CHECK: nan
+  %cst3 = arith.constant -1.0 : f32
+  call @rsqrt_f32(%cst3) : (f32) -> ()
+
+  // CHECK: 0.5, 1.41421, 0.57735
+  %vec_x = arith.constant dense<[4.0, 0.5, 3.0]> : vector<3xf32>
+  call @rsqrt_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
 func.func @main() {
   call @exp2f() : () -> ()
   call @roundf() : () -> ()
@@ -844,5 +884,6 @@ func.func @main() {
   call @asinh() : () -> ()
   call @acosh() : () -> ()
   call @atanh() : () -> ()
+  call @rsqrt() : () -> ()
   return
 }



More information about the Mlir-commits mailing list