[Mlir-commits] [mlir] da1e37a - Fold full-size subview of static shapes.
Ahmed Taei
llvmlistbot at llvm.org
Thu Mar 4 09:52:15 PST 2021
Author: Ahmed Taei
Date: 2021-03-04T09:52:06-08:00
New Revision: da1e37a8b06b921ac8f742245bb4d2a6ecd8b9e1
URL: https://github.com/llvm/llvm-project/commit/da1e37a8b06b921ac8f742245bb4d2a6ecd8b9e1
DIFF: https://github.com/llvm/llvm-project/commit/da1e37a8b06b921ac8f742245bb4d2a6ecd8b9e1.diff
LOG: Fold full-size subview of static shapes.
Differential Revision: https://reviews.llvm.org/D97429
Added:
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 11a597ce4183..eb7ed2f7418c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3496,9 +3496,28 @@ void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
}
OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
- if (getResult().getType().cast<ShapedType>().getRank() == 0 &&
- source().getType().cast<ShapedType>().getRank() == 0)
+ auto resultMemRefType = getResult().getType().cast<MemRefType>();
+ auto sourceMemRefType = source().getType().cast<MemRefType>();
+
+ // A rank-0 subview has no offsets, sizes or strides to apply, so it is
+ // always the identity.
+ if (resultMemRefType.getRank() == 0 && sourceMemRefType.getRank() == 0)
+ return getViewSource();
+
+ // Otherwise, a static-shape subview whose result type equals its source
+ // type selects the whole source only if the layout is fully static. With
+ // a dynamic offset or dynamic strides, the op can change the runtime
+ // strides (e.g. of a unit dimension) without changing the type, and
+ // folding it away would silently drop that change.
+ SmallVector<int64_t, 4> strides;
+ int64_t offset;
+ if (resultMemRefType.hasStaticShape() &&
+ resultMemRefType == sourceMemRefType &&
+ succeeded(getStridesAndOffset(resultMemRefType, strides, offset)) &&
+ offset != MemRefType::getDynamicStrideOrOffset() &&
+ !llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
return getViewSource();
+ }
return {};
}
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 72b886a238ff..2d722eaaaa07 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -204,6 +204,28 @@ func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
// -----
+// CHECK-LABEL: func @subview_of_static_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
+// CHECK-NOT: subview
+// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
+func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
+ %0 = subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
+ return %0 : memref<4x6x16x32xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @subview_of_static_full_size_dynamic_layout
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<1x4xf32
+// CHECK: %[[SV:.+]] = subview %[[ARG0]]
+// CHECK: return %[[SV]]
+func @subview_of_static_full_size_dynamic_layout(%arg0 : memref<1x4xf32, offset: ?, strides: [?, 1]>) -> memref<1x4xf32, offset: ?, strides: [?, 1]> {
+ %0 = subview %arg0[0, 0] [1, 4] [2, 1] : memref<1x4xf32, offset: ?, strides: [?, 1]> to memref<1x4xf32, offset: ?, strides: [?, 1]>
+ return %0 : memref<1x4xf32, offset: ?, strides: [?, 1]>
+}
+
+// -----
+
// CHECK-LABEL: func @trivial_subtensor
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK-NOT: subtensor
More information about the Mlir-commits mailing list