[Mlir-commits] [mlir] d89a0a6 - [mlir][Tosa]: Add folder to ReciprocalOp of splat constant inputs (#78137)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 16 23:05:11 PST 2024


Author: Aviad Cohen
Date: 2024-01-17T09:05:07+02:00
New Revision: d89a0a65947eb0c7bce733ee76991f900209d139

URL: https://github.com/llvm/llvm-project/commit/d89a0a65947eb0c7bce733ee76991f900209d139
DIFF: https://github.com/llvm/llvm-project/commit/d89a0a65947eb0c7bce733ee76991f900209d139.diff

LOG: [mlir][Tosa]: Add folder to ReciprocalOp of splat constant inputs (#78137)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3257ecd9d91f11..0ee9e713724ea2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1114,6 +1114,17 @@ def Tosa_ReciprocalOp : Tosa_ElementwiseOp<"reciprocal",
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let extraClassDeclaration = [{
+    /// Returns 1 / operand in operand's semantics, rounding ties to even.
+    static inline APFloat calcOneElement(const APFloat &operand) {
+      APFloat quotient(operand.getSemantics(), 1);
+      quotient.divide(operand, APFloat::rmNearestTiesToEven);
+      return quotient;
+    }
+  }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 26c39ff3523434..3f683f701e0fcd 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
@@ -25,6 +26,7 @@
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -1036,3 +1038,21 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
   getOperation()->setOperands(concatOperands);
   return getResult();
 }
+
+/// Folds reciprocal of a splat float constant into a splat constant.
+OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
+  auto inputAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  // Fold splat inputs only.
+  if (!inputAttr || !inputAttr.isSplat())
+    return {};
+  // DenseElementsAttr needs a static shape; integer inputs are not folded.
+  auto shapeType = llvm::cast<ShapedType>(getType());
+  if (!shapeType.hasStaticShape() ||
+      !llvm::isa<FloatType>(inputAttr.getElementType()))
+    return {};
+
+  auto floatVal = inputAttr.getSplatValue<APFloat>();
+  return DenseElementsAttr::get(shapeType,
+                                ReciprocalOp::calcOneElement(floatVal));
+}

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 661126f4df9976..729116da45e47d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index d35e911ebe63c4..050f8ca3f32aed 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -30,10 +30,6 @@ using namespace mlir::tosa;
 
 namespace {
 
-/// Rounding mode to be used on floating point operations that require rounding.
-static constexpr llvm::RoundingMode tosaRoundingMode =
-    llvm::APFloat::rmNearestTiesToEven;
-
 /// Apply the given transformation \p toApply to every element of the tensor to
 /// be transformed \p toTransform.
 ///
@@ -44,14 +40,14 @@ static constexpr llvm::RoundingMode tosaRoundingMode =
 template <class SrcValType, class TargetValType, class TargetType>
 DenseElementsAttr applyElementWise(
     const DenseElementsAttr &toTransform,
-    const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
+    const std::function<TargetValType(const SrcValType &)> &toApply,
     TargetType targetType) {
   SmallVector<TargetValType> transformedValues;
   // We already know the amount of values we will insert, reserve space for
   // all of them to avoid dynamic resizing
   transformedValues.reserve(toTransform.getNumElements());
   for (auto val : toTransform.getValues<SrcValType>()) {
-    auto transformedVal = toApply(val, targetType);
+    auto transformedVal = toApply(val);
     transformedValues.push_back(transformedVal);
   }
 
@@ -64,7 +60,7 @@ DenseElementsAttr applyElementWise(
 
 template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
     const DenseElementsAttr &toTransform,
-    const std::function<APFloat(const APFloat &, FloatType)> &toApply,
+    const std::function<APFloat(const APFloat &)> &toApply,
     FloatType targetType);
 
 /// Function that checks if the type contained in \p toCheck is float.
@@ -249,14 +245,6 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
 
   using OpRewritePattern::OpRewritePattern;
 
-  static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
-    auto recipAttr = FloatAttr::get(floatTy, 1.0);
-    APFloat recip = recipAttr.getValue();
-    recip.divide(floatVal, tosaRoundingMode);
-
-    return recip;
-  }
-
   LogicalResult matchAndRewrite(ReciprocalOp recip,
                                 PatternRewriter &rewriter) const override {
     auto inputTensor = recip.getInput1();
@@ -281,7 +269,7 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
 
     // Create a new tensor with the updated values
     auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
-        inputValues, &computeReciprocal,
+        inputValues, &ReciprocalOp::calcOneElement,
         cast<FloatType>(inputValues.getElementType()));
 
     // Replace the use of the reciprocal with the transformed tensor

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fd51d287bca058..e7ede2e0ccef9a 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -613,3 +613,27 @@ func.func nested @fold_tile_rank_zero() -> tensor<i32> {
   %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
   return %1 : tensor<i32>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_reciprocal
+func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
+  // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{value = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
+  // CHECK:           return %[[VAL_0]] : tensor<3x600x1200xf32>
+  // CHECK:         }
+  %0 = "tosa.const"(){ value = dense<116.0>: tensor<f32> }: () -> tensor<f32>
+  %1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
+  %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
+  return %2 : tensor<3x600x1200xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @do_not_fold_reciprocal_int
+func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
+  // CHECK:           tosa.reciprocal
+  %0 = "tosa.const"(){ value = dense<11>: tensor<i32> }: () -> tensor<i32>
+  %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<3x600x1200xi32>
+  %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
+  return %2 : tensor<3x600x1200xi32>
+}


        


More information about the Mlir-commits mailing list