[Mlir-commits] [mlir] d84d418 - [mlir][tosa] Constant folding for reciprocal

Matthias Gehre llvmlistbot at llvm.org
Wed Jul 5 02:39:08 PDT 2023


Author: Tina Jung
Date: 2023-07-05T11:38:46+02:00
New Revision: d84d418e2adc421c98e484ab3b09e2f4f3e5c1ef

URL: https://github.com/llvm/llvm-project/commit/d84d418e2adc421c98e484ab3b09e2f4f3e5c1ef
DIFF: https://github.com/llvm/llvm-project/commit/d84d418e2adc421c98e484ab3b09e2f4f3e5c1ef.diff

LOG: [mlir][tosa] Constant folding for reciprocal

Add a constant fold for tosa.reciprocal, which can be applied when the input is a dense constant tensor. The reciprocal is computed element-wise, and the result is a tensor with the same shape as the input tensor.
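
For illustration, a minimal sketch mirroring the single-value test case added below: the fold rewrites

    %0 = "tosa.const"() {value = dense<4.0> : tensor<f32>} : () -> tensor<f32>
    %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>

into a single constant

    %1 = "tosa.const"() {value = dense<2.5e-01> : tensor<f32>} : () -> tensor<f32>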

As the input tensor might require a lot of memory and folding might double that requirement, a heuristic decides when to actually apply the fold. Currently, the operation is replaced only if the input constant is a splat (i.e. requires little memory) or has a single user (similar to the existing fold for constant transposes). This keeps the additionally required space low.
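
As a hypothetical example of the heuristic declining (this exact IR is not one of the added test cases): below, the non-splat constant %0 has two users, so folding would duplicate the tensor data and the pattern leaves the reciprocal in place.

    %0 = "tosa.const"() {value = dense<[1.0, 2.0]> : tensor<2xf32>} : () -> tensor<2xf32>
    %1 = "tosa.reciprocal"(%0) : (tensor<2xf32>) -> tensor<2xf32>
    %2 = "tosa.add"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>

The fold runs as part of the layerwise constant-fold pass, e.g. via mlir-opt --tosa-layerwise-constant-fold (see the RUN line of the new test file).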

Differential Revision: https://reviews.llvm.org/D150578

Added: 
    mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
    mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir

Modified: 
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
    mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp

Removed: 
    mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index d6ae78196f4cbd..c81f59b3d5d36a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -30,6 +30,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
                                         RewritePatternSet &patterns);
 void populateTosaDecomposeDepthwise(MLIRContext *ctx,
                                     RewritePatternSet &patterns);
+void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
+                                                RewritePatternSet &patterns);
 void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
                                                RewritePatternSet &patterns);
 

diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 4f5a54de0c7346..0e6510ba1e9255 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -2,7 +2,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeTransposeConv.cpp
   TosaDecomposeConv2D.cpp
   TosaDecomposeDepthwise.cpp
-  TosaFoldConstantTranspose.cpp
+  TosaFolders.cpp
   TosaInferShapes.cpp
   TosaLayerwiseConstantFoldPass.cpp
   TosaMakeBroadcastable.cpp

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
deleted file mode 100644
index 302e2793f0a32e..00000000000000
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
+++ /dev/null
@@ -1,138 +0,0 @@
-//===- TosaFoldConstantTranspose.cpp --------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Fold TOSA Transpose operation on constant data
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Pass/Pass.h"
-
-using namespace mlir;
-using namespace mlir::tosa;
-
-namespace {
-
-template <typename BaseType>
-DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
-                                ShapedType outputType,
-                                llvm::ArrayRef<int64_t> permValues) {
-  if (inputType.getNumElements() == 0)
-    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
-
-  auto attrValues = attr.getValues<BaseType>();
-  auto inputShape = inputType.getShape();
-
-  // The inverted permutation map and strides of the output are used to compute
-  // the contribution of a given dimension to the destination linear index in
-  // an order-independent way.
-  auto outputStrides = computeStrides(outputType.getShape());
-  auto invertedPermValues = invertPermutationVector(permValues);
-
-  auto initialValue = *std::begin(attrValues);
-  SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
-
-  for (const auto &it : llvm::enumerate(attrValues)) {
-    auto srcLinearIndex = it.index();
-
-    uint64_t dstLinearIndex = 0;
-    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
-      // Compute the index into the current dimension of the source vector.
-      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
-      srcLinearIndex /= inputShape[dim];
-
-      // Add the contribution of the current dimension to the output using the
-      // permutation map.
-      dstLinearIndex +=
-          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
-    }
-
-    outputValues[dstLinearIndex] = it.value();
-  }
-
-  return DenseElementsAttr::get(outputType,
-                                llvm::ArrayRef<BaseType>(outputValues));
-}
-
-// A type specialized transposition of an ElementsAttr.
-// This implementation tries to operate on the underlying data in its raw
-// representation when possible to avoid allocating a large number of Attribute
-// objects.
-DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
-                            ShapedType outputType,
-                            llvm::ArrayRef<int64_t> permValues) {
-  auto baseType = inputType.getElementType();
-
-  // Handle possible integer types
-  if (auto intType = dyn_cast<IntegerType>(baseType)) {
-    switch (intType.getWidth()) {
-    case 1:
-      return transposeType<bool>(attr, inputType, outputType, permValues);
-    case 8:
-      return transposeType<int8_t>(attr, inputType, outputType, permValues);
-    case 16:
-      return transposeType<int16_t>(attr, inputType, outputType, permValues);
-    case 32:
-      return transposeType<int32_t>(attr, inputType, outputType, permValues);
-    case 64:
-      return transposeType<int64_t>(attr, inputType, outputType, permValues);
-    default:
-      return transposeType<APInt>(attr, inputType, outputType, permValues);
-    }
-  }
-
-  // Handle possible float types
-  if (baseType.isF32()) {
-    return transposeType<float>(attr, inputType, outputType, permValues);
-  }
-
-  return transposeType<APFloat>(attr, inputType, outputType, permValues);
-}
-
-struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    auto outputType = cast<ShapedType>(op.getType());
-    // TOSA supports quantized types.
-    if (!outputType.getElementType().isIntOrIndexOrFloat())
-      return failure();
-
-    ElementsAttr inputValues;
-    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
-      return failure();
-    // Make sure the input is a constant that has a single user.
-    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
-      return failure();
-
-    DenseIntElementsAttr permAttr;
-    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
-      return failure();
-    auto permValues = llvm::to_vector<6>(llvm::map_range(
-        // TOSA allows both 32- and 64-bit integer tensors here.
-        permAttr.getValues<APInt>(),
-        [](const APInt &val) { return val.getSExtValue(); }));
-
-    auto inputType = cast<ShapedType>(op.getInput1().getType());
-
-    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::tosa::populateTosaFoldConstantTransposePatterns(
-    MLIRContext *ctx, RewritePatternSet &patterns) {
-  patterns.add<TosaFoldConstantTranspose>(ctx);
-}

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
new file mode 100644
index 00000000000000..58693991952a3b
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -0,0 +1,302 @@
+//===- TosaFolders.cpp ----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Fold TOSA operations
+//
+//===----------------------------------------------------------------------===//
+
+#include <functional>
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/FloatingPointMode.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+/// Rounding mode to be used on floating point operations that require rounding.
+static constexpr llvm::RoundingMode tosaRoundingMode =
+    llvm::APFloat::rmNearestTiesToEven;
+
+/// Apply the given transformation \p toApply to every element of the tensor to
+/// be transformed \p toTransform.
+///
+/// Elements of \p toTransform are extracted as \p SrcValType.
+///
+/// \returns A tensor with the same size as \p toTransform, containing
+/// \p TargetValType values of type \p TargetType.
+template <class SrcValType, class TargetValType, class TargetType>
+DenseElementsAttr applyElementWise(
+    const DenseElementsAttr &toTransform,
+    const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
+    TargetType targetType) {
+  SmallVector<TargetValType> transformedValues;
+  // We already know the number of values we will insert; reserve space for
+  // all of them to avoid dynamic resizing.
+  transformedValues.reserve(toTransform.getNumElements());
+  for (auto val : toTransform.getValues<SrcValType>()) {
+    auto transformedVal = toApply(val, targetType);
+    transformedValues.push_back(transformedVal);
+  }
+
+  // Make sure that the output tensor has the expected output type
+  auto inShape = toTransform.getType();
+  auto outTy = inShape.cloneWith({}, targetType);
+
+  return DenseElementsAttr::get(outTy, transformedValues);
+}
+
+template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
+    const DenseElementsAttr &toTransform,
+    const std::function<APFloat(const APFloat &, FloatType)> &toApply,
+    FloatType targetType);
+
+/// Function that checks if the type contained in \p toCheck is float.
+LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
+                               PatternRewriter &rewriter) {
+  if (isa<FloatType>(toCheck.getType().getElementType())) {
+    return success();
+  }
+  return rewriter.notifyMatchFailure(location,
+                                     "Unexpected input tensor type: the "
+                                     "TOSA spec only allows floats");
+}
+
+/// Function that checks if \p toCheck is a dense TOSA constant tensor.
+LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
+                                                TosaOp location,
+                                                PatternRewriter &rewriter) {
+  // Check whether the tensor is constant and dense
+  // TODO We currently ensure the tensor is dense by using the correct type for
+  // the bind_value, however we do not actually need this value. It would be
+  // nicer to only have a check here.
+  DenseElementsAttr tmp;
+  if (!matchPattern(toCheck, m_Constant(&tmp))) {
+    return rewriter.notifyMatchFailure(location,
+                                       "Non-const or non-dense input tensor");
+  }
+
+  // Make sure it actually is a TOSA constant (the match allows for other
+  // constants as well)
+  if (isa<ConstOp>(toCheck.getDefiningOp())) {
+    return success();
+  }
+
+  return rewriter.notifyMatchFailure(location,
+                                     "The reciprocal can only be folded if "
+                                     "it operates on a TOSA constant");
+}
+
+/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
+LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
+                                                 TosaOp location,
+                                                 PatternRewriter &rewriter) {
+  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
+  if (failed(floatCheck)) {
+    return floatCheck;
+  }
+  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
+}
+
+/// Heuristic to decide when to replace a unary operation on a constant with the
+/// folded value.
+/// Folding operations on constants can lead to an increased memory usage
+/// whenever the input cannot be replaced but a new constant is inserted. Hence,
+/// this will currently only suggest folding when the memory impact is
+/// negligible.
+/// Takes the \p unaryOp and the constant input \p values.
+/// \returns Whether folding should be applied.
+bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
+  assert(unaryOp->getNumOperands() == 1);
+  auto inputOp = unaryOp->getOperand(0);
+
+  // If the input is a splat, we don't care about the number of users
+  if (isa<SplatElementsAttr>(values)) {
+    return true;
+  }
+
+  // If this is the only use of the tensor, it should be replaced, as no
+  // additional memory is required
+  return inputOp.hasOneUse();
+}
+
+template <typename BaseType>
+DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
+                                ShapedType outputType,
+                                llvm::ArrayRef<int64_t> permValues) {
+  if (inputType.getNumElements() == 0)
+    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
+
+  auto attrValues = attr.getValues<BaseType>();
+  auto inputShape = inputType.getShape();
+
+  // The inverted permutation map and strides of the output are used to compute
+  // the contribution of a given dimension to the destination linear index in
+  // an order-independent way.
+  auto outputStrides = computeStrides(outputType.getShape());
+  auto invertedPermValues = invertPermutationVector(permValues);
+
+  auto initialValue = *std::begin(attrValues);
+  SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
+
+  for (const auto &it : llvm::enumerate(attrValues)) {
+    auto srcLinearIndex = it.index();
+
+    uint64_t dstLinearIndex = 0;
+    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
+      // Compute the index into the current dimension of the source vector.
+      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
+      srcLinearIndex /= inputShape[dim];
+
+      // Add the contribution of the current dimension to the output using the
+      // permutation map.
+      dstLinearIndex +=
+          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
+    }
+
+    outputValues[dstLinearIndex] = it.value();
+  }
+
+  return DenseElementsAttr::get(outputType,
+                                llvm::ArrayRef<BaseType>(outputValues));
+}
+
+// A type specialized transposition of an ElementsAttr.
+// This implementation tries to operate on the underlying data in its raw
+// representation when possible to avoid allocating a large number of Attribute
+// objects.
+DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
+                            ShapedType outputType,
+                            llvm::ArrayRef<int64_t> permValues) {
+  auto baseType = inputType.getElementType();
+
+  // Handle possible integer types
+  if (auto intType = dyn_cast<IntegerType>(baseType)) {
+    switch (intType.getWidth()) {
+    case 1:
+      return transposeType<bool>(attr, inputType, outputType, permValues);
+    case 8:
+      return transposeType<int8_t>(attr, inputType, outputType, permValues);
+    case 16:
+      return transposeType<int16_t>(attr, inputType, outputType, permValues);
+    case 32:
+      return transposeType<int32_t>(attr, inputType, outputType, permValues);
+    case 64:
+      return transposeType<int64_t>(attr, inputType, outputType, permValues);
+    default:
+      return transposeType<APInt>(attr, inputType, outputType, permValues);
+    }
+  }
+
+  // Handle possible float types
+  if (baseType.isF32()) {
+    return transposeType<float>(attr, inputType, outputType, permValues);
+  }
+
+  return transposeType<APFloat>(attr, inputType, outputType, permValues);
+}
+
+struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto outputType = cast<ShapedType>(op.getType());
+    // TOSA supports quantized types.
+    if (!outputType.getElementType().isIntOrIndexOrFloat())
+      return failure();
+
+    ElementsAttr inputValues;
+    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
+      return failure();
+    // Make sure the input is a constant that has a single user.
+    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
+      return failure();
+
+    DenseIntElementsAttr permAttr;
+    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
+      return failure();
+    auto permValues = llvm::to_vector<6>(llvm::map_range(
+        // TOSA allows both 32- and 64-bit integer tensors here.
+        permAttr.getValues<APInt>(),
+        [](const APInt &val) { return val.getSExtValue(); }));
+
+    auto inputType = cast<ShapedType>(op.getInput1().getType());
+
+    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
+    return success();
+  }
+};
+
+struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
+
+  using OpRewritePattern::OpRewritePattern;
+
+  static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
+    auto recipAttr = FloatAttr::get(floatTy, 1.0);
+    APFloat recip = recipAttr.getValue();
+    recip.divide(floatVal, tosaRoundingMode);
+
+    return recip;
+  }
+
+  LogicalResult matchAndRewrite(ReciprocalOp recip,
+                                PatternRewriter &rewriter) const override {
+    auto inputTensor = recip.getInput1();
+
+    // Check that we can apply folding
+    auto preCondCheck =
+        notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
+    if (failed(preCondCheck)) {
+      return preCondCheck;
+    }
+
+    // Extract the tensor values
+    DenseElementsAttr inputValues;
+    matchPattern(inputTensor, m_Constant(&inputValues));
+
+    // Check whether this should be folded.
+    if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
+      return rewriter.notifyMatchFailure(
+          recip, "Currently, reciprocals will only be folded if the input "
+                 "tensor has a single user");
+    }
+
+    // Create a new tensor with the updated values
+    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
+        inputValues, &computeReciprocal,
+        cast<FloatType>(inputValues.getElementType()));
+
+    // Replace the use of the reciprocal with the transformed tensor
+    rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaFoldConstantTransposePatterns(
+    MLIRContext *ctx, RewritePatternSet &patterns) {
+  patterns.add<TosaFoldConstantTranspose>(ctx);
+}
+
+void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
+    MLIRContext *ctx, RewritePatternSet &patterns) {
+  patterns.add<TosaFoldConstantReciprocal>(ctx);
+}

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
index a217f66cd84c63..2e2d338abbe4bf 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass
     RewritePatternSet patterns(ctx);
     auto func = getOperation();
 
+    mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
     mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
     populateTosaOpsCanonicalizationPatterns(ctx, patterns);
 

diff --git a/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir b/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir
new file mode 100644
index 00000000000000..cc71c43d53ce29
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
+
+// CHECK-LABEL: @reciprocal_fold_single_valued
+func.func @reciprocal_fold_single_valued() -> tensor<f32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.5{{0*}}e-01{{.*}}tensor<f32>
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_fold_splat
+func.func @reciprocal_fold_splat() -> tensor<12x7xf32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.5{{0*}}e-01{{.*}}tensor<12x7xf32>
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<12x7xf32>} : () -> tensor<12x7xf32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<12x7xf32>) -> tensor<12x7xf32>
+  return %1 : tensor<12x7xf32>
+}
+
+// CHECK-LABEL: @reciprocal_div_zero
+func.func @reciprocal_div_zero() -> tensor<f32> {
+  // 0x7F800000 is the value for +infinity
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_neg_zero
+func.func @reciprocal_div_neg_zero() -> tensor<f32> {
+  // 0xFF800000 is the value for -infinity
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0xFF800000
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<-0.0> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_nan
+func.func @reciprocal_div_nan() -> tensor<f32> {
+  // 0x7FC00000 is the value for NaN
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7FC00000
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<0x7FC00000> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_infinity
+func.func @reciprocal_div_infinity() -> tensor<f32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<0.{{0*}}e+00>
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<0x7F800000> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_neg_infinity
+func.func @reciprocal_div_neg_infinity() -> tensor<f32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<-0.{{0*}}e+00>
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<0xFF800000> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_underflow
+func.func @reciprocal_div_underflow() -> tensor<2xf16> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-0.{{0*}}e+00, 0.{{0*}}e+00
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<[-6.0e+15, 6.0e+15]> : tensor<2xf16>} : () -> tensor<2xf16>
+  %1 = "tosa.reciprocal"(%0) : (tensor<2xf16>) -> tensor<2xf16>
+  return %1 : tensor<2xf16>
+}
+
+// CHECK-LABEL: @reciprocal_div_overflow
+func.func @reciprocal_div_overflow() -> tensor<2xf16> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7C00, 0xFC00
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<[0.0000001, -0.0000001]> : tensor<2xf16>} : () -> tensor<2xf16>
+  %1 = "tosa.reciprocal"(%0) : (tensor<2xf16>) -> tensor<2xf16>
+  return %1 : tensor<2xf16>
+}
+
+// CHECK-LABEL: @reciprocal_no_fold
+// The folding optimization works only intra-procedurally, so we won't be able
+// to fold anything here
+func.func @reciprocal_no_fold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: tosa.reciprocal
+  // CHECK-NEXT: return
+  %0 = "tosa.reciprocal"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @reciprocal_fold
+func.func @reciprocal_fold() -> tensor<4x6xf32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const
+  // CHECK-SAME{LITERAL}: [[5.68828249, 11.4416485, 1.6880486, 0.680272102, -0.875350117, 0.342313349],
+  // CHECK-SAME{LITERAL}:  [-4.81231928, 0.698080301, 0.65432179, -82.6446304, -4.33651352, -0.747551739],
+  // CHECK-SAME{LITERAL}:  [-12.4378109, 13.140605, 1.89501607, 0.885582745, 4.08830738, 1.4396776],
+  // CHECK-SAME{LITERAL}:  [2.02880907, -1.53280187, 0.552730501, 7.15819644, 0.64495325, -0.973709881]]
+  // CHECK-NOT: tosa.reciprocal
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() { value = dense<[
+                        [ 0.1758,  0.0874,  0.5924,  1.4700, -1.1424,  2.9213],
+                        [-0.2078,  1.4325,  1.5283, -0.0121, -0.2306, -1.3377],
+                        [-0.0804,  0.0761,  0.5277,  1.1292,  0.2446,  0.6946],
+                        [ 0.4929, -0.6524,  1.8092,  0.1397,  1.5505, -1.0270]]>
+                        : tensor<4x6xf32>
+                      } : () -> tensor<4x6xf32>
+  %1 = "tosa.reciprocal"(%0) : (tensor<4x6xf32>) -> tensor<4x6xf32>
+  return %1 : tensor<4x6xf32>
+}
+
+// CHECK-LABEL: @reciprocal_of_const_sparse
+// Sparse tensors are currently not supported
+func.func @reciprocal_of_const_sparse() -> tensor<32xbf16> {
+  // CHECK: tosa.const
+  // CHECK: tosa.reciprocal
+    %0 = "tosa.const"() { value = sparse<
+          [[0], [3], [11], [17], [20], [23], [25], [30], [31]],
+          [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]>
+          : tensor<32xbf16> } : () -> tensor<32xbf16>
+    %1 = "tosa.reciprocal"(%0) : (tensor<32xbf16>) -> tensor<32xbf16>
+    return %1 : tensor<32xbf16>
+}


        

