[Mlir-commits] [mlir] [mlir][Tosa]: Add folder to ReciprocalOp of splat constant inputs (PR #78137)
Aviad Cohen
llvmlistbot at llvm.org
Mon Jan 15 01:54:05 PST 2024
https://github.com/AviadCo created https://github.com/llvm/llvm-project/pull/78137
None
>From 488ef7211c45be86746c4b38594ed0328f9e0549 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Mon, 15 Jan 2024 08:39:01 +0200
Subject: [PATCH] [mlir][Tosa]: Add folder for ReciprocalOp with splat constant
inputs
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 7 +++++++
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 19 +++++++++++++++++++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 9 +++++++++
.../Dialect/Tosa/Transforms/TosaFolders.cpp | 15 ++-------------
.../Dialect/Tosa/Utils/ConversionUtils.cpp | 1 +
mlir/test/Dialect/Tosa/canonicalize.mlir | 13 +++++++++++++
6 files changed, 51 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3257ecd9d91f11..d8fc960563bf29 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1114,6 +1114,13 @@ def Tosa_ReciprocalOp : Tosa_ElementwiseOp<"reciprocal",
let results = (outs
Tosa_Tensor:$output
);
+
+ let extraClassDeclaration = [{
+ /// Computes reciprocal on a float element (input must be from float type).
+ static llvm::APFloat computeFloatElemOne(const llvm::APFloat &floatVal, FloatType floatTy);
+ }];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 26c39ff3523434..fb3cd378f2c84b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
@@ -25,6 +26,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -1036,3 +1038,20 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
getOperation()->setOperands(concatOperands);
return getResult();
}
+
+/// Folds `reciprocal` of a splat float constant into a splat constant.
+OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
+ // Fold splat inputs only. DenseElementsAttr::get needs a static result
+ // shape, and the result must carry the input's float element type.
+ auto inputAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto shapeType = llvm::cast<ShapedType>(getType());
+ if (!inputAttr || !inputAttr.isSplat() || !shapeType.hasStaticShape())
+ return {};
+ auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType());
+ if (!floatType || shapeType.getElementType() != floatType)
+ return {};
+ auto floatVal = inputAttr.getSplatValue<APFloat>();
+ return DenseElementsAttr::get(shapeType,
+ computeFloatElemOne(floatVal, floatType));
+}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 661126f4df9976..a2af9ef0c069f2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -1778,6 +1779,14 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
return std::nullopt;
}
+/// Returns 1 / floatVal in floatVal's own semantics, rounded to nearest-even.
+APFloat tosa::ReciprocalOp::computeFloatElemOne(const APFloat &floatVal, FloatType /*floatTy*/) {
+ // Build 1.0 from the operand's semantics instead of uniquing a FloatAttr.
+ APFloat recip(floatVal.getSemantics(), 1);
+ recip.divide(floatVal, llvm::APFloat::rmNearestTiesToEven);
+ return recip;
+}
+
// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions for 'then'.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index d35e911ebe63c4..6208b38900ebad 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -30,10 +31,6 @@ using namespace mlir::tosa;
namespace {
-/// Rounding mode to be used on floating point operations that require rounding.
-static constexpr llvm::RoundingMode tosaRoundingMode =
- llvm::APFloat::rmNearestTiesToEven;
-
/// Apply the given transformation \p toApply to every element of the tensor to
/// be transformed \p toTransform.
///
@@ -249,14 +246,6 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
using OpRewritePattern::OpRewritePattern;
- static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
- auto recipAttr = FloatAttr::get(floatTy, 1.0);
- APFloat recip = recipAttr.getValue();
- recip.divide(floatVal, tosaRoundingMode);
-
- return recip;
- }
-
LogicalResult matchAndRewrite(ReciprocalOp recip,
PatternRewriter &rewriter) const override {
auto inputTensor = recip.getInput1();
@@ -281,7 +270,7 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
// Create a new tensor with the updated values
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
- inputValues, &computeReciprocal,
+ inputValues, &ReciprocalOp::computeFloatElemOne,
cast<FloatType>(inputValues.getElementType()));
// Replace the use of the reciprocal with the transformed tensor
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index ee428b201d0073..9fc864463d95bf 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
using namespace mlir;
using namespace mlir::tosa;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fd51d287bca058..de9d13b1453232 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -613,3 +613,16 @@ func.func nested @fold_tile_rank_zero() -> tensor<i32> {
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
+
+// -----
+
+// CHECK-LABEL: @fold_reciprocal
+func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> { // Reciprocal of a splat constant should fold away.
+ // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
+ // CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32>
+ // CHECK: }
+ %0 = "tosa.const"(){ value = dense<116.0>: tensor<f32> }: () -> tensor<f32> // 1/116 ~= 8.620690e-03, the expected folded value.
+ %1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32> // Presumably folded into a 3x600x1200 splat constant first.
+ %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32> // Op under test; expected to become a splat tosa.const.
+ return %2 : tensor<3x600x1200xf32>
+}
More information about the Mlir-commits
mailing list