[Mlir-commits] [mlir] d84d418 - [mlir][tosa] Constant folding for reciprocal
Matthias Gehre
llvmlistbot at llvm.org
Wed Jul 5 02:39:08 PDT 2023
Author: Tina Jung
Date: 2023-07-05T11:38:46+02:00
New Revision: d84d418e2adc421c98e484ab3b09e2f4f3e5c1ef
URL: https://github.com/llvm/llvm-project/commit/d84d418e2adc421c98e484ab3b09e2f4f3e5c1ef
DIFF: https://github.com/llvm/llvm-project/commit/d84d418e2adc421c98e484ab3b09e2f4f3e5c1ef.diff
LOG: [mlir][tosa] Constant folding for reciprocal
Add a constant fold for tosa.reciprocal, which applies when the input is a dense constant tensor. The reciprocal is computed for every element, and the result is a tensor with the same shape as the input.
Because the input tensor might require a lot of memory and the folding might double that requirement, a heuristic decides when to actually apply the fold. Currently, the operation is replaced only if the input constant is a splat (i.e. requires little memory) or has a single user (matching the already existing fold for constant transposes). This keeps the additional memory requirement low; a folded example is sketched below.
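For illustration, a minimal sketch of the splat case (the function name is chosen for this example; the folded form mirrors the new test file):

// The input is a splat constant, so the heuristic always allows the fold.
func.func @fold_splat_example() -> tensor<2xf32> {
  %0 = "tosa.const"() {value = dense<4.0> : tensor<2xf32>} : () -> tensor<2xf32>
  // The pass rewrites %1 into "tosa.const"() {value = dense<2.500000e-01> : tensor<2xf32>}.
  %1 = "tosa.reciprocal"(%0) : (tensor<2xf32>) -> tensor<2xf32>
  return %1 : tensor<2xf32>
}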
Differential Revision: https://reviews.llvm.org/D150578
Added:
mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir
Modified:
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
Removed:
mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index d6ae78196f4cbd..c81f59b3d5d36a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -30,6 +30,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
RewritePatternSet &patterns);
+void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
+ RewritePatternSet &patterns);
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 4f5a54de0c7346..0e6510ba1e9255 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -2,7 +2,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaDecomposeTransposeConv.cpp
TosaDecomposeConv2D.cpp
TosaDecomposeDepthwise.cpp
- TosaFoldConstantTranspose.cpp
+ TosaFolders.cpp
TosaInferShapes.cpp
TosaLayerwiseConstantFoldPass.cpp
TosaMakeBroadcastable.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
deleted file mode 100644
index 302e2793f0a32e..00000000000000
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
+++ /dev/null
@@ -1,138 +0,0 @@
-//===- TosaFoldConstantTranspose.cpp --------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Fold TOSA Transpose operation on constant data
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Pass/Pass.h"
-
-using namespace mlir;
-using namespace mlir::tosa;
-
-namespace {
-
-template <typename BaseType>
-DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
- ShapedType outputType,
- llvm::ArrayRef<int64_t> permValues) {
- if (inputType.getNumElements() == 0)
- return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
-
- auto attrValues = attr.getValues<BaseType>();
- auto inputShape = inputType.getShape();
-
- // The inverted permutation map and strides of the output are used to compute
- // the contribution of a given dimension to the destination linear index in
- // an order-independent way.
- auto outputStrides = computeStrides(outputType.getShape());
- auto invertedPermValues = invertPermutationVector(permValues);
-
- auto initialValue = *std::begin(attrValues);
- SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
-
- for (const auto &it : llvm::enumerate(attrValues)) {
- auto srcLinearIndex = it.index();
-
- uint64_t dstLinearIndex = 0;
- for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
- // Compute the index into the current dimension of the source vector.
- auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
- srcLinearIndex /= inputShape[dim];
-
- // Add the contribution of the current dimension to the output using the
- // permutation map.
- dstLinearIndex +=
- outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
- }
-
- outputValues[dstLinearIndex] = it.value();
- }
-
- return DenseElementsAttr::get(outputType,
- llvm::ArrayRef<BaseType>(outputValues));
-}
-
-// A type specialized transposition of an ElementsAttr.
-// This implementation tries to operate on the underlying data in its raw
-// representation when possible to avoid allocating a large number of Attribute
-// objects.
-DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
- ShapedType outputType,
- llvm::ArrayRef<int64_t> permValues) {
- auto baseType = inputType.getElementType();
-
- // Handle possible integer types
- if (auto intType = dyn_cast<IntegerType>(baseType)) {
- switch (intType.getWidth()) {
- case 1:
- return transposeType<bool>(attr, inputType, outputType, permValues);
- case 8:
- return transposeType<int8_t>(attr, inputType, outputType, permValues);
- case 16:
- return transposeType<int16_t>(attr, inputType, outputType, permValues);
- case 32:
- return transposeType<int32_t>(attr, inputType, outputType, permValues);
- case 64:
- return transposeType<int64_t>(attr, inputType, outputType, permValues);
- default:
- return transposeType<APInt>(attr, inputType, outputType, permValues);
- }
- }
-
- // Handle possible float types
- if (baseType.isF32()) {
- return transposeType<float>(attr, inputType, outputType, permValues);
- }
-
- return transposeType<APFloat>(attr, inputType, outputType, permValues);
-}
-
-struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tosa::TransposeOp op,
- PatternRewriter &rewriter) const override {
- auto outputType = cast<ShapedType>(op.getType());
- // TOSA supports quantized types.
- if (!outputType.getElementType().isIntOrIndexOrFloat())
- return failure();
-
- ElementsAttr inputValues;
- if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
- return failure();
- // Make sure the input is a constant that has a single user.
- if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
- return failure();
-
- DenseIntElementsAttr permAttr;
- if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
- return failure();
- auto permValues = llvm::to_vector<6>(llvm::map_range(
- // TOSA allows both 32- and 64-bit integer tensors here.
- permAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
-
- auto inputType = cast<ShapedType>(op.getInput1().getType());
-
- auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
- rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
- return success();
- }
-};
-
-} // namespace
-
-void mlir::tosa::populateTosaFoldConstantTransposePatterns(
- MLIRContext *ctx, RewritePatternSet &patterns) {
- patterns.add<TosaFoldConstantTranspose>(ctx);
-}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
new file mode 100644
index 00000000000000..58693991952a3b
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -0,0 +1,302 @@
+//===- TosaFolders.cpp ----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Fold TOSA operations
+//
+//===----------------------------------------------------------------------===//
+
+#include <functional>
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/FloatingPointMode.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+/// Rounding mode to be used on floating point operations that require rounding.
+static constexpr llvm::RoundingMode tosaRoundingMode =
+ llvm::APFloat::rmNearestTiesToEven;
+
+/// Apply the given transformation \p toApply to every element of the tensor to
+/// be transformed \p toTransform.
+///
+/// Elements of \p toTransform are extracted as \p SrcValType.
+///
+/// \returns A tensor with the same size as \p toTransform, containing
+/// \p TargetValType values of type \p TargetType.
+template <class SrcValType, class TargetValType, class TargetType>
+DenseElementsAttr applyElementWise(
+ const DenseElementsAttr &toTransform,
+ const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
+ TargetType targetType) {
+ SmallVector<TargetValType> transformedValues;
+ // We already know the number of values we will insert; reserve space for
+ // all of them to avoid dynamic resizing.
+ transformedValues.reserve(toTransform.getNumElements());
+ for (auto val : toTransform.getValues<SrcValType>()) {
+ auto transformedVal = toApply(val, targetType);
+ transformedValues.push_back(transformedVal);
+ }
+
+ // Make sure that the output tensor has the expected output type
+ auto inShape = toTransform.getType();
+ auto outTy = inShape.cloneWith({}, targetType);
+
+ return DenseElementsAttr::get(outTy, transformedValues);
+}
+
+template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
+ const DenseElementsAttr &toTransform,
+ const std::function<APFloat(const APFloat &, FloatType)> &toApply,
+ FloatType targetType);
+
+/// Function that checks whether the element type of \p toCheck is float.
+LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
+ PatternRewriter &rewriter) {
+ if (isa<FloatType>(toCheck.getType().getElementType())) {
+ return success();
+ }
+ return rewriter.notifyMatchFailure(location,
+ "Unexpected input tensor type: the "
+ "TOSA spec only allows floats");
+}
+
+/// Function that checks if \p toCheck is a dense TOSA constant tensor.
+LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
+ TosaOp location,
+ PatternRewriter &rewriter) {
+ // Check whether the tensor is constant and dense
+ // TODO: We currently ensure the tensor is dense by using the matching type
+ // for the bind value; however, we do not actually need the value itself. It
+ // would be nicer to have only a check here.
+ DenseElementsAttr tmp;
+ if (!matchPattern(toCheck, m_Constant(&tmp))) {
+ return rewriter.notifyMatchFailure(location,
+ "Non-const or non-dense input tensor");
+ }
+
+ // Make sure it actually is a TOSA constant (the match allows for other
+ // constants as well)
+ if (isa<ConstOp>(toCheck.getDefiningOp())) {
+ return success();
+ }
+
+ return rewriter.notifyMatchFailure(location,
+ "The reciprocal can only be folded if "
+ "it operates on a TOSA constant");
+}
+
+/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
+LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
+ TosaOp location,
+ PatternRewriter &rewriter) {
+ auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
+ if (failed(floatCheck)) {
+ return floatCheck;
+ }
+ return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
+}
+
+/// Heuristic to decide when to replace a unary operation on a constant with the
+/// folded value.
+/// Folding operations on constants can lead to increased memory usage
+/// whenever the input constant cannot be removed but a new constant is
+/// inserted. Hence, this currently only suggests folding when the memory
+/// impact is negligible.
+/// Takes the \p unaryOp and the constant input \p values.
+/// \returns Whether folding should be applied.
+bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
+ assert(unaryOp->getNumOperands() == 1);
+ auto inputOp = unaryOp->getOperand(0);
+
+ // If the input is a splat, we don't care about the number of users
+ if (isa<SplatElementsAttr>(values)) {
+ return true;
+ }
+
+ // If this is the only use of the tensor, it should be replaced, as no
+ // additional memory is required.
+ return inputOp.hasOneUse();
+}
+
+template <typename BaseType>
+DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
+ ShapedType outputType,
+ llvm::ArrayRef<int64_t> permValues) {
+ if (inputType.getNumElements() == 0)
+ return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
+
+ auto attrValues = attr.getValues<BaseType>();
+ auto inputShape = inputType.getShape();
+
+ // The inverted permutation map and strides of the output are used to compute
+ // the contribution of a given dimension to the destination linear index in
+ // an order-independent way.
+ auto outputStrides = computeStrides(outputType.getShape());
+ auto invertedPermValues = invertPermutationVector(permValues);
+
+ auto initialValue = *std::begin(attrValues);
+ SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
+
+ for (const auto &it : llvm::enumerate(attrValues)) {
+ auto srcLinearIndex = it.index();
+
+ uint64_t dstLinearIndex = 0;
+ for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
+ // Compute the index into the current dimension of the source vector.
+ auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
+ srcLinearIndex /= inputShape[dim];
+
+ // Add the contribution of the current dimension to the output using the
+ // permutation map.
+ dstLinearIndex +=
+ outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
+ }
+
+ outputValues[dstLinearIndex] = it.value();
+ }
+
+ return DenseElementsAttr::get(outputType,
+ llvm::ArrayRef<BaseType>(outputValues));
+}
+
+// A type specialized transposition of an ElementsAttr.
+// This implementation tries to operate on the underlying data in its raw
+// representation when possible to avoid allocating a large number of Attribute
+// objects.
+DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
+ ShapedType outputType,
+ llvm::ArrayRef<int64_t> permValues) {
+ auto baseType = inputType.getElementType();
+
+ // Handle possible integer types
+ if (auto intType = dyn_cast<IntegerType>(baseType)) {
+ switch (intType.getWidth()) {
+ case 1:
+ return transposeType<bool>(attr, inputType, outputType, permValues);
+ case 8:
+ return transposeType<int8_t>(attr, inputType, outputType, permValues);
+ case 16:
+ return transposeType<int16_t>(attr, inputType, outputType, permValues);
+ case 32:
+ return transposeType<int32_t>(attr, inputType, outputType, permValues);
+ case 64:
+ return transposeType<int64_t>(attr, inputType, outputType, permValues);
+ default:
+ return transposeType<APInt>(attr, inputType, outputType, permValues);
+ }
+ }
+
+ // Handle possible float types
+ if (baseType.isF32()) {
+ return transposeType<float>(attr, inputType, outputType, permValues);
+ }
+
+ return transposeType<APFloat>(attr, inputType, outputType, permValues);
+}
+
+struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::TransposeOp op,
+ PatternRewriter &rewriter) const override {
+ auto outputType = cast<ShapedType>(op.getType());
+ // TOSA supports quantized types, which are not folded here.
+ if (!outputType.getElementType().isIntOrIndexOrFloat())
+ return failure();
+
+ ElementsAttr inputValues;
+ if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
+ return failure();
+ // Make sure the input is a constant that has a single user.
+ if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
+ return failure();
+
+ DenseIntElementsAttr permAttr;
+ if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
+ return failure();
+ auto permValues = llvm::to_vector<6>(llvm::map_range(
+ // TOSA allows both 32- and 64-bit integer tensors here.
+ permAttr.getValues<APInt>(),
+ [](const APInt &val) { return val.getSExtValue(); }));
+
+ auto inputType = cast<ShapedType>(op.getInput1().getType());
+
+ auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
+ rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
+ return success();
+ }
+};
+
+struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
+
+ using OpRewritePattern::OpRewritePattern;
+
+ static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
+ auto recipAttr = FloatAttr::get(floatTy, 1.0);
+ APFloat recip = recipAttr.getValue();
+ recip.divide(floatVal, tosaRoundingMode);
+
+ return recip;
+ }
+
+ LogicalResult matchAndRewrite(ReciprocalOp recip,
+ PatternRewriter &rewriter) const override {
+ auto inputTensor = recip.getInput1();
+
+ // Check that we can apply folding
+ auto preCondCheck =
+ notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
+ if (failed(preCondCheck)) {
+ return preCondCheck;
+ }
+
+ // Extract the tensor values
+ DenseElementsAttr inputValues;
+ matchPattern(inputTensor, m_Constant(&inputValues));
+
+ // Check whether this should be folded.
+ if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
+ return rewriter.notifyMatchFailure(
+ recip, "Currently, reciprocals will only be folded if the input "
+ "tensor has a single user");
+ }
+
+ // Create a new tensor with the updated values
+ auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
+ inputValues, &computeReciprocal,
+ cast<FloatType>(inputValues.getElementType()));
+
+ // Replace the use of the reciprocal with the transformed tensor
+ rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaFoldConstantTransposePatterns(
+ MLIRContext *ctx, RewritePatternSet &patterns) {
+ patterns.add<TosaFoldConstantTranspose>(ctx);
+}
+
+void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
+ MLIRContext *ctx, RewritePatternSet &patterns) {
+ patterns.add<TosaFoldConstantReciprocal>(ctx);
+}
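For reference, the moved transpose folder behaves analogously on constant inputs; a minimal sketch (the function name is illustrative and not part of this commit):

func.func @transpose_fold_example() -> tensor<2x3xf32> {
  %input = "tosa.const"() {value = dense<[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
  // Folds to a single "tosa.const" holding [[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]],
  // since %input is a constant with a single user.
  %0 = "tosa.transpose"(%input, %perms) : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}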
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
index a217f66cd84c63..2e2d338abbe4bf 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass
RewritePatternSet patterns(ctx);
auto func = getOperation();
+ mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
populateTosaOpsCanonicalizationPatterns(ctx, patterns);
diff --git a/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir b/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir
new file mode 100644
index 00000000000000..cc71c43d53ce29
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
+
+// CHECK-LABEL: @reciprocal_fold_single_valued
+func.func @reciprocal_fold_single_valued() -> tensor<f32> {
+ // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.5{{0*}}e-01{{.*}}tensor<f32>
+ // CHECK-NOT: tosa.reciprocal
+ // CHECK: return [[RES]]
+ %0 = "tosa.const"() {value = dense<4.0> : tensor<f32>} : () -> tensor<f32>
+ %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+ return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_fold_splat
+func.func @reciprocal_fold_splat() -> tensor<12x7xf32> {
+ // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.5{{0*}}e-01{{.*}}tensor<12x7xf32>
+ // CHECK-NOT: tosa.reciprocal
+ // CHECK: return [[RES]]
+ %0 = "tosa.const"() {value = dense<4.0> : tensor<12x7xf32>} : () -> tensor<12x7xf32>
+ %1 = "tosa.reciprocal"(%0) : (tensor<12x7xf32>) -> tensor<12x7xf32>
+ return %1 : tensor<12x7xf32>
+}
+
+// CHECK-LABEL: @reciprocal_div_zero
+func.func @reciprocal_div_zero() -> tensor<f32> {
+ // 0x7F800000 is the value for +infinity
+ // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000
+ // CHECK-NOT: tosa.reciprocal
+ // CHECK: return [[RES]]
+ %0 = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+ %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+ return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_neg_zero
+func.func @reciprocal_div_neg_zero() -> tensor<f32> {
+ // 0xFF800000 is the value for -infinity
+ // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0xFF800000
+ // CHECK-NOT: tosa.reciprocal
+ // CHECK: return [[RES]]
+ %0 = "tosa.const"() {value = dense<-0.0> : tensor<f32>} : () -> tensor<f32>
+ %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+ return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_nan
+func.func @reciprocal_div_nan() -> tensor<f32> {
+ // 0x7FC00000 is the value for NaN
+ // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7FC00000
+ // CHECK-NOT: tosa.reciprocal
+ // CHECK: return [[RES]]
+ %0 = "tosa.const"() {value = dense<0x7FC00000> : tensor<f32>} : () -> tensor<f32>
+ %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+ return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_infinity
+func.func @reciprocal_div_infinity() -> tensor<f32> {
+ // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<0.{{0*}}e+00>
+ // CHECK-NOT: tosa.reciprocal
+ // CHECK: return [[RES]]
+ %0 = "tosa.const"() {value = dense<0x7F800000> : tensor<f32>} : () -> tensor<f32>
+ %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+ return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_neg_infinity
+func.func @reciprocal_div_neg_infinity() -> tensor<f32> {
+ // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<-0.{{0*}}e+00>
+ // CHECK-NOT: tosa.reciprocal
+ // CHECK: return [[RES]]
+ %0 = "tosa.const"() {value = dense<0xFF800000> : tensor<f32>} : () -> tensor<f32>
+ %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
+ return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @reciprocal_div_underflow
+func.func @reciprocal_div_underflow() -> tensor<2xf16> {
+ // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-0.{{0*}}e+00, 0.{{0*}}e+00
+ // CHECK-NOT: tosa.reciprocal
+ // CHECK: return [[RES]]
+ %0 = "tosa.const"() {value = dense<[-6.0e+15, 6.0e+15]> : tensor<2xf16>} : () -> tensor<2xf16>
+ %1 = "tosa.reciprocal"(%0) : (tensor<2xf16>) -> tensor<2xf16>
+ return %1 : tensor<2xf16>
+}
+
+// CHECK-LABEL: @reciprocal_div_overflow
+func.func @reciprocal_div_overflow() -> tensor<2xf16> {
+ // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7C00, 0xFC00
+ // CHECK-NOT: tosa.reciprocal
+ // CHECK: return [[RES]]
+ %0 = "tosa.const"() {value = dense<[0.0000001, -0.0000001]> : tensor<2xf16>} : () -> tensor<2xf16>
+ %1 = "tosa.reciprocal"(%0) : (tensor<2xf16>) -> tensor<2xf16>
+ return %1 : tensor<2xf16>
+}
+
+// CHECK-LABEL: @reciprocal_no_fold
+// The folding optimization works only intra-procedurally, so we won't be able
+// to fold anything here
+func.func @reciprocal_no_fold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: tosa.reciprocal
+ // CHECK-NEXT: return
+ %0 = "tosa.reciprocal"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @reciprocal_fold
+func.func @reciprocal_fold() -> tensor<4x6xf32> {
+ // CHECK: [[RES:]] ={{.*}}tosa.const
+ // CHECK-SAME{LITERAL}: [[5.68828249, 11.4416485, 1.6880486, 0.680272102, -0.875350117, 0.342313349],
+ // CHECK-SAME{LITERAL}: [-4.81231928, 0.698080301, 0.65432179, -82.6446304, -4.33651352, -0.747551739],
+ // CHECK-SAME{LITERAL}: [-12.4378109, 13.140605, 1.89501607, 0.885582745, 4.08830738, 1.4396776],
+ // CHECK-SAME{LITERAL}: [2.02880907, -1.53280187, 0.552730501, 7.15819644, 0.64495325, -0.973709881]]
+ // CHECK-NOT: tosa.reciprocal
+ // CHECK: return [[RES]]
+ %0 = "tosa.const"() { value = dense<[
+ [ 0.1758, 0.0874, 0.5924, 1.4700, -1.1424, 2.9213],
+ [-0.2078, 1.4325, 1.5283, -0.0121, -0.2306, -1.3377],
+ [-0.0804, 0.0761, 0.5277, 1.1292, 0.2446, 0.6946],
+ [ 0.4929, -0.6524, 1.8092, 0.1397, 1.5505, -1.0270]]>
+ : tensor<4x6xf32>
+ } : () -> tensor<4x6xf32>
+ %1 = "tosa.reciprocal"(%0) : (tensor<4x6xf32>) -> tensor<4x6xf32>
+ return %1 : tensor<4x6xf32>
+}
+
+// CHECK-LABEL: @reciprocal_of_const_sparse
+// Sparse tensors are currently not supported
+func.func @reciprocal_of_const_sparse() -> tensor<32xbf16> {
+ // CHECK: tosa.const
+ // CHECK: tosa.reciprocal
+ %0 = "tosa.const"() { value = sparse<
+ [[0], [3], [11], [17], [20], [23], [25], [30], [31]],
+ [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]>
+ : tensor<32xbf16> } : () -> tensor<32xbf16>
+ %1 = "tosa.reciprocal"(%0) : (tensor<32xbf16>) -> tensor<32xbf16>
+ return %1 : tensor<32xbf16>
+}
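To round out the heuristic described in the commit message, here is a case that is deliberately left unfolded; this sketch is not part of the committed tests:

// Hypothetical example: the constant is non-splat and has two users, so
// folding would duplicate the tensor's storage and the heuristic declines.
func.func @reciprocal_no_fold_multi_user() -> (tensor<3xf32>, tensor<3xf32>) {
  %0 = "tosa.const"() {value = dense<[1.0, 2.0, 4.0]> : tensor<3xf32>} : () -> tensor<3xf32>
  %1 = "tosa.reciprocal"(%0) : (tensor<3xf32>) -> tensor<3xf32>
  return %0, %1 : tensor<3xf32>, tensor<3xf32>
}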