[Mlir-commits] [mlir] 9c62bb5 - Implementation of `ReshapeNoopOptimization` canonicalizer.
Rob Suderman
llvmlistbot at llvm.org
Tue Oct 19 16:08:48 PDT 2021
Author: Kojo Acquah
Date: 2021-10-19T16:07:34-07:00
New Revision: 9c62bb55f473a9d0db16b894708ed09f2346ae9d
URL: https://github.com/llvm/llvm-project/commit/9c62bb55f473a9d0db16b894708ed09f2346ae9d
DIFF: https://github.com/llvm/llvm-project/commit/9c62bb55f473a9d0db16b894708ed09f2346ae9d.diff
LOG: Implementation of `ReshapeNoopOptimization` canonicalizer.
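This canonicalizer replaces a reshape of a constant tensor with a new constant that already has the updated shape, eliminating the reshape operation. It applies only when the constant is a splat or the reshape is its only use.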
Differential Revision: https://reviews.llvm.org/D112038
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 2ad14f5bf1809..744eaa0b3bb8a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -154,9 +154,43 @@ struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
}
};
+/// Folds a `tosa.reshape` whose input is a constant into a single
+/// `tosa.const` that already carries the requested shape.
+///
+/// The rewrite is applied only when the constant is a splat or the reshape is
+/// its only use (presumably so that a non-splat constant that is still needed
+/// elsewhere is not duplicated), and only when `new_shape` is fully static and
+/// preserves the element count.
+struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
+
+  /// Returns success() after replacing `op` with a reshaped `tosa.const`,
+  /// or failure() (leaving the IR untouched) when the pattern does not apply.
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input1();
+    ArrayAttr newShape = op.new_shape();
+
+    // Only a constant input can be folded.
+    DenseElementsAttr inputAttr;
+    if (!matchPattern(input, m_Constant(&inputAttr)))
+      return failure();
+
+    // Bail out when the constant has more than one use and is not a splat.
+    if (!input.hasOneUse() && !inputAttr.isSplat())
+      return failure();
+
+    // Read the requested shape out of the `new_shape` attribute.
+    SmallVector<int64_t, 6> newShapeValues = llvm::to_vector<6>(
+        llvm::map_range(newShape.getValue(), [](Attribute dim) {
+          return dim.cast<IntegerAttr>().getValue().getSExtValue();
+        }));
+
+    // A negative entry (e.g. -1 for an inferred dimension) cannot be used to
+    // build the static type that DenseElementsAttr::reshape requires, so leave
+    // such reshapes alone instead of hitting an assertion.
+    if (llvm::any_of(newShapeValues, [](int64_t dim) { return dim < 0; }))
+      return failure();
+
+    // The reshaped attribute must hold exactly as many elements as the input.
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType outputType = inputType.clone(newShapeValues);
+    if (outputType.getNumElements() != inputAttr.getType().getNumElements())
+      return failure();
+
+    // Build the new const op with the requested output shape.
+    DenseElementsAttr outputAttr = inputAttr.reshape(outputType);
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputAttr.getType(),
+                                               outputAttr);
+    return success();
+  }
+};
+
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ReshapeReshapeOptimization>(context);
+ results.insert<ReshapeConstOptimization>(context); // Also register the reshape-of-constant rewrite pattern.
}
struct ConstantTransposeOptimization
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 983ce58c0c652..e6cf1a15ac67f 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -174,6 +174,39 @@ func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
// -----
+// CHECK-LABEL: @reshape_canonicalize_const
+func @reshape_canonicalize_const() -> tensor<1x10xi32> { // Single-use splat constant: the reshape is expected to fold into one reshaped tosa.const.
+ // CHECK: %[[VAR0:.+]] = "tosa.const"() {value = dense<0> : tensor<1x10xi32>}
+ // CHECK: return %[[VAR0]]
+ %0 = "tosa.const"() {value = dense<0> : tensor<10xi32>} : () -> tensor<10xi32>
+ %1 = "tosa.reshape"(%0) {new_shape = [1, 10]} : (tensor<10xi32>) -> tensor<1x10xi32>
+ return %1 : tensor<1x10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_canonicalize_const_spat
+func @reshape_canonicalize_const_spat() -> (tensor<10xi32>, tensor<1x10xi32>) { // Splat constant with two uses (the reshape and the return): both the original and the reshaped tosa.const are expected. NOTE(review): "spat" in the name looks like a typo for "splat".
+ // CHECK-DAG: %[[VAR0:.+]] = "tosa.const"() {value = dense<0> : tensor<10xi32>}
+ // CHECK-DAG: %[[VAR1:.+]] = "tosa.const"() {value = dense<0> : tensor<1x10xi32>}
+ // CHECK: return %[[VAR0]], %[[VAR1]]
+ %0 = "tosa.const"() {value = dense<0> : tensor<10xi32>} : () -> tensor<10xi32>
+ %1 = "tosa.reshape"(%0) {new_shape = [1, 10]} : (tensor<10xi32>) -> tensor<1x10xi32>
+ return %0 , %1 : tensor<10xi32>, tensor<1x10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_canonicalize_const_sparse
+func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32>) { // Non-splat constant with two uses: the tosa.reshape is expected to remain. NOTE(review): the constant is a dense non-splat, not a sparse attribute, despite the test name.
+ //CHECK: "tosa.reshape"
+ %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : ()-> tensor<3xi32>
+ %1 = "tosa.reshape"(%0) {new_shape = [1, 3]} : (tensor<3xi32>) -> tensor<1x3xi32>
+ return %0 , %1 : tensor<3xi32>, tensor<1x3xi32>
+}
+
+// -----
+
// CHECK-LABEL: @slice_fold
func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
More information about the Mlir-commits
mailing list