[Mlir-commits] [mlir] 088d388 - [mlir][Arithmetic] Add constant folder for negf.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 8 00:36:52 PDT 2022


Author: jacquesguan
Date: 2022-04-08T07:36:29Z
New Revision: 088d38890ccee92d5ef6ae13ec1c50f9b0083866

URL: https://github.com/llvm/llvm-project/commit/088d38890ccee92d5ef6ae13ec1c50f9b0083866
DIFF: https://github.com/llvm/llvm-project/commit/088d38890ccee92d5ef6ae13ec1c50f9b0083866.diff

LOG: [mlir][Arithmetic] Add constant folder for negf.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D123293

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/include/mlir/Dialect/CommonFolders.h
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index f7fd294be31ff..6305c6947c753 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -585,6 +585,7 @@ def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
     %x = arith.negf %y : tensor<4x?xf8>
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index b3f600ed4ff49..7ba43c92e7563 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -69,6 +69,45 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
   }
   return {};
 }
+
+/// Performs constant folding `calculate` with element-wise behavior on the
+/// single attribute in `operands` and returns the result if possible.
+/// Handles a scalar `AttrElementT`, a splat, or any other ElementsAttr; a null
+/// operand (e.g. a non-constant input) or any other kind yields a null result.
+template <class AttrElementT,
+          class ElementValueT = typename AttrElementT::ValueType,
+          class CalculationT = function_ref<ElementValueT(ElementValueT)>>
+Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
+                           const CalculationT &calculate) {
+  assert(operands.size() == 1 && "unary op takes one operand");
+  if (!operands[0])
+    return {};
+
+  if (auto op = operands[0].dyn_cast<AttrElementT>())
+    return AttrElementT::get(op.getType(), calculate(op.getValue()));
+
+  if (auto op = operands[0].dyn_cast<SplatElementsAttr>()) {
+    // The operand is a splat, so avoid expanding the values out and just fold
+    // the single splat value.
+    auto elementResult = calculate(op.getSplatValue<ElementValueT>());
+    return DenseElementsAttr::get(op.getType(), elementResult);
+  }
+
+  if (auto op = operands[0].dyn_cast<ElementsAttr>()) {
+    // The operand is ElementsAttr-derived; perform an element-wise fold by
+    // expanding the values.
+    auto opIt = op.value_begin<ElementValueT>();
+    SmallVector<ElementValueT> elementResults;
+    elementResults.reserve(op.getNumElements());
+    for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt)
+      elementResults.push_back(calculate(*opIt));
+    return DenseElementsAttr::get(op.getType(), elementResults);
+  }
+
+  // Any other attribute kind cannot be folded by this helper.
+  return {};
+}
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_COMMONFOLDERS_H

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 74b77dc9cbd63..d38f4f1c9994e 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -573,6 +573,15 @@ void arith::XOrIOp::getCanonicalizationPatterns(
   patterns.add<XOrINotCmpI>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// NegFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
+  // Flip the sign of the constant operand, whether scalar or elements attr.
+  return constFoldUnaryOp<FloatAttr>(operands, [](APFloat v) { return -v; });
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 550e84c97118d..b4c92d6089e9b 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1166,3 +1166,14 @@ func @nofoldShrs2() -> i64 {
   %r = arith.shrsi %c1, %cm32 : i64
   return %r : i64
 }
+
+// -----
+
+// CHECK-LABEL: @test_negf(
+// CHECK: %[[res:.+]] = arith.constant -2.000000e+00 : f32
+// CHECK-NEXT: return %[[res]]
+func @test_negf() -> f32 {
+  %c = arith.constant 2.0 : f32
+  %0 = arith.negf %c : f32
+  return %0 : f32
+}


        


More information about the Mlir-commits mailing list