[Mlir-commits] [mlir] ad4b7fb - [mlir][Math] Support fold Log2Op with constant dense.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jul 10 19:48:43 PDT 2022


Author: jacquesguan
Date: 2022-07-11T10:34:28+08:00
New Revision: ad4b7fb3ce018338e6ad8f5dbc26957434ad820c

URL: https://github.com/llvm/llvm-project/commit/ad4b7fb3ce018338e6ad8f5dbc26957434ad820c
DIFF: https://github.com/llvm/llvm-project/commit/ad4b7fb3ce018338e6ad8f5dbc26957434ad820c.diff

LOG: [mlir][Math] Support fold Log2Op with constant dense.

This patch is similar to D129108. It adds a conditional unary constant folder that allows folding to bail out when the constants do not meet the fold condition, and uses it for Log2Op so that Log2Op can fold constant dense attributes.

Differential Revision: https://reviews.llvm.org/D129251

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/CommonFolders.h
    mlir/lib/Dialect/Math/IR/MathOps.cpp
    mlir/test/Dialect/Math/canonicalize.mlir
    mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 55dc5ec2349ce..868089361e36b 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -98,11 +98,11 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
 
 /// Performs constant folding `calculate` with element-wise behavior on the one
 /// attributes in `operands` and returns the result if possible.
-template <class AttrElementT,
-          class ElementValueT = typename AttrElementT::ValueType,
-          class CalculationT = function_ref<ElementValueT(ElementValueT)>>
-Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
-                           const CalculationT &&calculate) {
+template <
+    class AttrElementT, class ElementValueT = typename AttrElementT::ValueType,
+    class CalculationT = function_ref<Optional<ElementValueT>(ElementValueT)>>
+Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
+                                      const CalculationT &&calculate) {
   assert(operands.size() == 1 && "unary op takes one operands");
   if (!operands[0])
     return {};
@@ -110,7 +110,10 @@ Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
   if (operands[0].isa<AttrElementT>()) {
     auto op = operands[0].cast<AttrElementT>();
 
-    return AttrElementT::get(op.getType(), calculate(op.getValue()));
+    auto res = calculate(op.getValue());
+    if (!res)
+      return {};
+    return AttrElementT::get(op.getType(), *res);
   }
   if (operands[0].isa<SplatElementsAttr>()) {
     // Both operands are splats so we can avoid expanding the values out and
@@ -118,7 +121,9 @@ Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
     auto op = operands[0].cast<SplatElementsAttr>();
 
     auto elementResult = calculate(op.getSplatValue<ElementValueT>());
-    return DenseElementsAttr::get(op.getType(), elementResult);
+    if (!elementResult)
+      return {};
+    return DenseElementsAttr::get(op.getType(), *elementResult);
   } else if (operands[0].isa<ElementsAttr>()) {
     // Operands are ElementsAttr-derived; perform an element-wise fold by
     // expanding the values.
@@ -127,13 +132,27 @@ Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
     auto opIt = op.value_begin<ElementValueT>();
     SmallVector<ElementValueT> elementResults;
     elementResults.reserve(op.getNumElements());
-    for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt)
-      elementResults.push_back(calculate(*opIt));
+    for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
+      auto elementResult = calculate(*opIt);
+      if (!elementResult)
+        return {};
+      elementResults.push_back(*elementResult);
+    }
     return DenseElementsAttr::get(op.getType(), elementResults);
   }
   return {};
 }
 
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType,
+          class CalculationT = function_ref<ElementValueT(ElementValueT)>>
+Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
+                           const CalculationT &&calculate) {
+  return constFoldUnaryOpConditional<AttrElementT>(
+      operands,
+      [&](ElementValueT a) -> Optional<ElementValueT> { return calculate(a); });
+}
+
 template <
     class AttrElementT, class TargetAttrElementT,
     class ElementValueT = typename AttrElementT::ValueType,

diff  --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 34e20724c78a0..035e9b49ba5e2 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -92,28 +92,19 @@ OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
-  if (!constOperand)
-    return {};
-
-  auto attr = constOperand.dyn_cast<FloatAttr>();
-  if (!attr)
-    return {};
+  return constFoldUnaryOpConditional<FloatAttr>(
+      operands, [](const APFloat &a) -> Optional<APFloat> {
+        if (a.isNegative())
+          return {};
 
-  auto ft = getType().cast<FloatType>();
+        if (a.getSizeInBits(a.getSemantics()) == 64)
+          return APFloat(log2(a.convertToDouble()));
 
-  APFloat apf = attr.getValue();
+        if (a.getSizeInBits(a.getSemantics()) == 32)
+          return APFloat(log2f(a.convertToFloat()));
 
-  if (apf.isNegative())
-    return {};
-
-  if (ft.getWidth() == 64)
-    return FloatAttr::get(getType(), log2(apf.convertToDouble()));
-
-  if (ft.getWidth() == 32)
-    return FloatAttr::get(getType(), log2f(apf.convertToFloat()));
-
-  return {};
+        return {};
+      });
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index bcfdf1b9e965c..2ddd766d7d302 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -74,6 +74,15 @@ func.func @log2_nofold2_64() -> f64 {
   return %r : f64
 }
 
+// CHECK-LABEL: @log2_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00, 1.58496249, 2.000000e+00]> : vector<4xf32>
+// CHECK: return %[[cst]]
+func.func @log2_fold_vec() -> (vector<4xf32>) {
+  %v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+  %0 = math.log2 %v1 : vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
 // CHECK-LABEL: @powf_fold
 // CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32
 // CHECK: return %[[cst]]

diff  --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index 3f34bc980d4cd..e0fe1a8a6f99e 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -80,7 +80,7 @@ func.func @log2() {
   %1 = math.log2 %0 : f32
   vector.print %1 : f32
 
-  // CHECK: -2, -0.415037, 0, 0.321928
+  // CHECK: -2, -0.415038, 0, 0.321928
   %2 = arith.constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
   %3 = math.log2 %2 : vector<4xf32>
   vector.print %3 : vector<4xf32>


        


More information about the Mlir-commits mailing list