[Mlir-commits] [mlir] da1e37a - Fold full-size subview of static shapes.

Ahmed Taei llvmlistbot at llvm.org
Thu Mar 4 09:52:15 PST 2021


Author: Ahmed Taei
Date: 2021-03-04T09:52:06-08:00
New Revision: da1e37a8b06b921ac8f742245bb4d2a6ecd8b9e1

URL: https://github.com/llvm/llvm-project/commit/da1e37a8b06b921ac8f742245bb4d2a6ecd8b9e1
DIFF: https://github.com/llvm/llvm-project/commit/da1e37a8b06b921ac8f742245bb4d2a6ecd8b9e1.diff

LOG: Fold full-size subview of static shapes.

Differential Revision: https://reviews.llvm.org/D97429

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 11a597ce4183..eb7ed2f7418c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3496,9 +3496,13 @@ void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 }
 
 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
-  if (getResult().getType().cast<ShapedType>().getRank() == 0 &&
-      source().getType().cast<ShapedType>().getRank() == 0)
+  auto resultShapedType = getResult().getType().cast<ShapedType>();
+  auto sourceShapedType = source().getType().cast<ShapedType>();
+  // Fold a static-shape subview typed like its source (full-size view).
+  if (resultShapedType.hasStaticShape() &&
+      resultShapedType == sourceShapedType) {
     return getViewSource();
+  }

   return {};
 }

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 72b886a238ff..2d722eaaaa07 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -204,6 +204,17 @@ func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
 
 // -----
 
+// CHECK-LABEL: func @subview_of_static_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
+// CHECK-NOT: subview
+// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
+func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> { // expected to fold to %arg0
+  %0 = subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8> // zero offsets, full sizes, unit strides
+  return %0 : memref<4x6x16x32xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @trivial_subtensor
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 //   CHECK-NOT:   subtensor


        


More information about the Mlir-commits mailing list