[Mlir-commits] [mlir] 26c95ae - [mlir][Math] Add constant folder for sqrt.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 18 01:02:03 PDT 2022


Author: jacquesguan
Date: 2022-03-18T16:01:44+08:00
New Revision: 26c95ae38940b5b6ccfc65188ba9931eb51e468e

URL: https://github.com/llvm/llvm-project/commit/26c95ae38940b5b6ccfc65188ba9931eb51e468e
DIFF: https://github.com/llvm/llvm-project/commit/26c95ae38940b5b6ccfc65188ba9931eb51e468e.diff

LOG: [mlir][Math] Add constant folder for sqrt.

Differential Revision: https://reviews.llvm.org/D121980

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/lib/Dialect/Math/IR/MathOps.cpp
    mlir/test/Dialect/Math/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index ca91d5353b584..b0ccbc21439b0 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -724,6 +724,7 @@ def Math_SqrtOp : Math_FloatUnaryOp<"sqrt"> {
     %x = math.sqrt %y : tensor<4x?xf32>
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4410e93ef6e0f..42f8334403c1e 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -101,6 +101,37 @@ OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+/// Constant-folds `math.sqrt` when its operand is a scalar f32 or f64
+/// constant.
+///
+/// Strictly negative inputs (whose square root is NaN) are left unfolded.
+/// -0.0 is folded to -0.0, which is the IEEE 754 result of sqrt(-0.0).
+/// Other float types (e.g. f16, bf16) and non-scalar types are not folded.
+OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
+  // The op is unary; the attribute is null when the operand is not constant.
+  auto attr = operands.front().dyn_cast_or_null<FloatAttr>();
+  if (!attr)
+    return {};
+
+  APFloat apf = attr.getValue();
+
+  // Negative inputs other than -0.0 produce NaN; keep them unfolded. -0.0 has
+  // a well-defined square root (-0.0), so it is safe to fold.
+  if (apf.isNegative() && !apf.isZero())
+    return {};
+
+  // Check the result type explicitly instead of an unchecked
+  // `cast<FloatType>()`, so an unexpected type simply declines to fold.
+  Type type = getType();
+  if (type.isF64())
+    return FloatAttr::get(type, sqrt(apf.convertToDouble()));
+
+  if (type.isF32())
+    return FloatAttr::get(type, sqrtf(apf.convertToFloat()));
+
+  return {};
+}
+
 /// Materialize an integer or floating point constant.
 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,

diff  --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 5ee63b60bec1a..27a92908eaec4 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -82,3 +82,31 @@ func @powf_fold() -> f32 {
   %r = math.powf %c, %c : f32
   return %r : f32
 }
+
+// CHECK-LABEL: @sqrt_fold
+// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @sqrt_fold() -> f32 {
+  %c = arith.constant 4.0 : f32
+  %r = math.sqrt %c : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @sqrt_f64_fold
+// CHECK: %[[cst:.+]] = arith.constant 3.000000e+00 : f64
+// CHECK: return %[[cst]]
+func @sqrt_f64_fold() -> f64 {
+  %c = arith.constant 9.0 : f64
+  %r = math.sqrt %c : f64
+  return %r : f64
+}
+
+// CHECK-LABEL: @sqrt_negative_no_fold
+// CHECK: %[[cst:.+]] = arith.constant -4.000000e+00 : f32
+// CHECK: %[[res:.+]] = math.sqrt %[[cst]] : f32
+// CHECK: return %[[res]]
+func @sqrt_negative_no_fold() -> f32 {
+  %c = arith.constant -4.0 : f32
+  %r = math.sqrt %c : f32
+  return %r : f32
+}


        


More information about the Mlir-commits mailing list