[Mlir-commits] [mlir] 3fbc6fd - [TOSA] Loosen folding restrictions for tosa.add, tosa.sub, tosa.mul
Robert Suderman
llvmlistbot at llvm.org
Thu Mar 30 11:22:39 PDT 2023
Author: SJW
Date: 2023-03-30T18:22:20Z
New Revision: 3fbc6fd4931f91003a5441866b674d3d635d8a60
URL: https://github.com/llvm/llvm-project/commit/3fbc6fd4931f91003a5441866b674d3d635d8a60
DIFF: https://github.com/llvm/llvm-project/commit/3fbc6fd4931f91003a5441866b674d3d635d8a60.diff
LOG: [TOSA] Loosen folding restrictions for tosa.add, tosa.sub, tosa.mul
Allow folding when the operand tensor types differ, provided the constant operand is a broadcastable splat.
Removed the redundant and incorrect AddZero and MulOne canonicalization patterns.
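For illustration (a sketch mirroring the new test case in the diff below; %arg0 stands in for an arbitrary value), an add of a broadcastable splat-zero constant now folds away even though the operand shapes differ:

  // The splat zero broadcasts, and input1's type already matches the
  // result type, so the whole add folds to %arg0.
  %zeros = "tosa.const"() {value = dense<0> : tensor<1x1x1xi32>} : () -> tensor<1x1x1xi32>
  %0 = "tosa.add"(%arg0, %zeros) : (tensor<4x2x3xi32>, tensor<1x1x1xi32>) -> tensor<4x2x3xi32>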
Reviewed By: rsuderman, eric-k256
Differential Revision: https://reviews.llvm.org/D145738
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 33741c00d8495..043098f65a9ee 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -477,7 +477,6 @@ def Tosa_AddOp : Tosa_Op<"add", [
Tosa_Tensor:$output
);
- let hasCanonicalizer = 1;
let hasFolder = 1;
}
@@ -796,7 +795,6 @@ def Tosa_MulOp : Tosa_Op<"mul", [
Tosa_Tensor:$output
);
- let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index ef93e1955b60b..19a80c783c475 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -246,92 +246,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}
-struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tosa::AddOp op,
- PatternRewriter &rewriter) const override {
- auto input1 = op.getInput1();
- auto input2 = op.getInput2();
-
- DenseElementsAttr input1Attr;
- if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
- input2.getType() == op.getType()) {
- if (input1Attr.getType().getElementType().isa<IntegerType>() &&
- input1Attr.getSplatValue<APInt>().isZero()) {
- rewriter.replaceOp(op, op.getInput2());
- return success();
- }
- }
-
- DenseElementsAttr input2Attr;
- if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
- input1.getType() == op.getType()) {
- if (input2Attr.getType().getElementType().isa<IntegerType>() &&
- input2Attr.getSplatValue<APInt>().isZero()) {
- rewriter.replaceOp(op, op.getInput1());
- return success();
- }
- }
-
- return failure();
- }
-};
-
-void AddOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<AddZeroOptimization>(context);
-}
-
-struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tosa::MulOp op,
- PatternRewriter &rewriter) const override {
- auto input1 = op.getInput1();
- auto input2 = op.getInput2();
-
- DenseElementsAttr input1Attr;
- if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
- input2.getType() == op.getType()) {
- if (input1Attr.getType().getElementType().isa<FloatType>() &&
- input1Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
- rewriter.replaceOp(op, op.getInput2());
- return success();
- }
-
- if (input1Attr.getType().getElementType().isa<IntegerType>() &&
- matchPattern(input1, m_One())) {
- rewriter.replaceOp(op, op.getInput2());
- return success();
- }
- }
-
- DenseElementsAttr input2Attr;
- if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
- input1.getType() == op.getType()) {
- if (input2Attr.getType().getElementType().isa<FloatType>() &&
- input2Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
- rewriter.replaceOp(op, op.getInput1());
- return success();
- }
-
- if (input2Attr.getType().getElementType().isa<IntegerType>() &&
- matchPattern(input2, m_One())) {
- rewriter.replaceOp(op, op.getInput1());
- return success();
- }
- }
-
- return failure();
- }
-};
-
-void MulOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<MulOneOptimization>(context);
-}
-
struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
using OpRewritePattern::OpRewritePattern;
@@ -609,44 +523,47 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
return {};
}
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+ if (elemType.isa<FloatType>())
+ return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
+ if (elemType.isa<IntegerType>())
+ return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+ return false;
+}
+
+static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
+ if (elemType.isa<FloatType>())
+ return val && val.isSplat() &&
+ val.getSplatValue<APFloat>().isExactlyValue(1.0);
+ if (elemType.isa<IntegerType>()) {
+ const int64_t shifted = 1LL << shift;
+ return val && val.isSplat() &&
+ val.getSplatValue<APInt>().getSExtValue() == shifted;
+ }
+ return false;
+}
+
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
auto resultTy = getType().dyn_cast<RankedTensorType>();
if (!lhsTy || !rhsTy || !resultTy)
return {};
- if (lhsTy != rhsTy)
- return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
- if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
- if (lhsAttr.getSplatValue<APFloat>().isZero())
- return getInput2();
- }
-
- if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
- if (rhsAttr.getSplatValue<APFloat>().isZero())
- return getInput1();
- }
-
- if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
- if (lhsAttr.getSplatValue<APInt>().isZero())
- return getInput2();
- }
-
- if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
- if (rhsAttr.getSplatValue<APInt>().isZero())
- return getInput1();
- }
+ if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
+ return getInput1();
+ if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
+ return getInput2();
if (!lhsAttr || !rhsAttr)
return {};
return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
- lhsTy);
+ resultTy);
}
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
@@ -724,50 +641,26 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto resultTy = getType().dyn_cast<RankedTensorType>();
if (!lhsTy || !rhsTy || !resultTy)
return {};
- if (lhsTy != rhsTy)
- return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
- if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
- auto val = lhsAttr.getSplatValue<APFloat>();
- if (val.isZero())
+ const int64_t shift = resultETy.isa<IntegerType>() ? getShift() : 0;
+ if (rhsTy == resultTy) {
+ if (isSplatZero(resultETy, lhsAttr))
return lhsAttr;
- if (val.isExactlyValue(1.0))
+ if (isSplatOne(resultETy, lhsAttr, shift))
return rhs;
}
-
- if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
- auto val = rhsAttr.getSplatValue<APFloat>();
- if (val.isZero())
- return rhsAttr;
- if (val.isExactlyValue(1.0))
- return lhs;
- }
-
- if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
- auto val = lhsAttr.getSplatValue<APInt>();
- if (val.isZero())
- return lhsAttr;
- const int64_t shift = getShift();
- const int64_t shifted = 1LL << shift;
- if (val.getSExtValue() == shifted)
- return rhs;
- }
-
- if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
- auto val = rhsAttr.getSplatValue<APInt>();
- const int64_t shift = getShift();
- const int64_t shifted = 1LL << shift;
- if (val.isZero())
+ if (lhsTy == resultTy) {
+ if (isSplatZero(resultETy, rhsAttr))
return rhsAttr;
- if (val.getSExtValue() == shifted)
+ if (isSplatOne(resultETy, rhsAttr, shift))
return lhs;
}
- return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift());
+ return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
}
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
@@ -776,28 +669,19 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
auto resultTy = getType().dyn_cast<RankedTensorType>();
if (!lhsTy || !rhsTy || !resultTy)
return {};
- if (lhsTy != rhsTy)
- return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
- if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
- if (rhsAttr.getSplatValue<APFloat>().isZero())
- return getInput1();
- }
-
- if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
- if (rhsAttr.getSplatValue<APInt>().isZero())
- return getInput1();
- }
+ if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
+ return getInput1();
if (!lhsAttr || !rhsAttr)
return {};
return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
- lhsTy);
+ resultTy);
}
namespace {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 77627d8c8ba62..bdd4021cb39a1 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -7,15 +7,15 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
-// CHECK-LABEL: @add_zero_different_shape
-func.func @add_zero_different_shape(%arg0: tensor<2x3xi32>) -> tensor<4x2x3xi32> {
- // CHECK: tosa.add
- %zeros = "tosa.const"() {value = dense<0> : tensor<4x2x3xi32>} : () -> tensor<4x2x3xi32>
- %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xi32>, tensor<4x2x3xi32>) -> tensor<4x2x3xi32>
+// CHECK-LABEL: @add_bcast_zero_int
+func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> {
+ // CHECK-NOT: tosa.add
+ // CHECK: return %arg0
+ %zeros = "tosa.const"() {value = dense<0> : tensor<1x1x1xi32>} : () -> tensor<1x1x1xi32>
+ %1 = "tosa.add"(%arg0, %zeros) : (tensor<4x2x3xi32>, tensor<1x1x1xi32>) -> tensor<4x2x3xi32>
return %1 : tensor<4x2x3xi32>
}
-
// CHECK-LABEL: @add_zero_int
func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK: return %arg0
@@ -176,14 +176,6 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
return %1 : tensor<?x?xi32>
}
-// CHECK-LABEL: @mul_one_different_shape
-func.func @mul_one_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
- // CHECK: tosa.mul
- %ones = "tosa.const"() {value = dense<1.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32>
- %1 = "tosa.mul"(%arg0, %ones) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32>
- return %1 : tensor<4x2x3xf32>
-}
-
// CHECK-LABEL: @mul_one_float
func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
@@ -193,6 +185,15 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
return %1 : tensor<2x3xf32>
}
+// CHECK-LABEL: @mul_bcast_one_float
+func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+ // CHECK: return %arg0
+ // CHECK-NOT: tosa.mul
+ %ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
+ %1 = "tosa.mul"(%ones, %arg0) {shift = 0 : i32} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ return %1 : tensor<2x3xf32>
+}
+
// CHECK-LABEL: @mul_one_int
func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK: return %arg0
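A note on the folder logic above (an illustrative sketch; the IR below is hypothetical and not from this commit's tests). The lhsTy == resultTy and rhsTy == resultTy guards are what make the loosened folds safe: a fold may only return an existing operand whose type already equals the result type, so only the constant side is allowed to have a different, broadcastable shape. For integer tosa.mul the product is scaled down by the shift attribute, so the multiplicative identity is 1 << shift rather than 1, which is what isSplatOne checks:

  // With shift = 1 the op computes roughly (a * b) >> 1, so a splat
  // constant of 2 (= 1 << 1) acts as the identity and the multiply
  // folds away to %arg0.
  %two = "tosa.const"() {value = dense<2> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
  %0 = "tosa.mul"(%arg0, %two) {shift = 1 : i32} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>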