[Mlir-commits] [mlir] 94438c8 - [mlir] Add a MemRefCastOp canonicalization pattern.

Nicolas Vasilache llvmlistbot at llvm.org
Wed May 6 06:11:54 PDT 2020


Author: Nicolas Vasilache
Date: 2020-05-06T09:10:05-04:00
New Revision: 94438c86adef0a0f50bc3737253e8c98b4e3dd3e

URL: https://github.com/llvm/llvm-project/commit/94438c86adef0a0f50bc3737253e8c98b4e3dd3e
DIFF: https://github.com/llvm/llvm-project/commit/94438c86adef0a0f50bc3737253e8c98b4e3dd3e.diff

LOG: [mlir] Add a MemRefCastOp canonicalization pattern.

Summary:
This revision adds a conservative canonicalization pattern for the MemRefCastOps that are typically inserted during ViewOp and SubViewOp canonicalization.
Ideally such canonicalizations would propagate the type directly to consumers, but that is not a local transformation. As a consequence, MemRefCastOps are introduced to preserve type compatibility and need to be cleaned up later in the cases where they introduce more dynamic behavior than necessary.
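For illustration, a minimal before/after sketch of the intended fold (mirroring the subview test added at the end of this diff): the cast only erases static shape information, so the consuming op can take the original, more static memref directly.

```mlir
// Before: the cast discards static sizes the consumer does not need to lose.
%0 = memref_cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
%1 = subview %0[][%i, %i][] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>

// After folding the cast into its consumer:
%1 = subview %arg0[][%i, %i][] : memref<4x5xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
```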

Differential Revision: https://reviews.llvm.org/D79438

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 573f9b7c988f..218aff6d11c9 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -301,6 +301,44 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
 
 raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
 
+/// Determines whether MemRefCastOp casts to a more dynamic version of the
+/// source memref. This is useful to fold a memref_cast into a consuming op
+/// and implement canonicalization patterns for ops in different dialects that
+/// may consume the results of memref_cast operations. Such foldable memref_cast
+/// operations are typically inserted as `view` and `subview` ops are
+/// canonicalized, to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked memrefs with strided semantics and same
+/// element type and rank.
+/// 2. each of the source's sizes, offset, and strides has at least as much
+/// static information as the corresponding result size, offset, or stride.
+///
+/// Example 1:
+/// ```mlir
+///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
+///   %2 = consumer %1 ... : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : memref<8x16xf32> ...
+/// ```
+///
+/// Example 2:
+/// ```
+///   %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
+///          to memref<?x?xf32>
+///   consumer %1 : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```
+///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// ```
+bool canFoldIntoConsumerOp(MemRefCastOp castOp);
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index efcbdf63983e..4e04df9b9215 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2606,6 +2606,7 @@ def SubViewOp : Std_Op<"subview", [
   }];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 82ae6de83c83..5803824a3162 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -44,82 +44,16 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
 template <typename NamedStructuredOpType>
 static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
 
-/// Determines whether it is possible to fold it away in the parent Linalg op:
-///
-/// ```mlir
-///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
-///   %2 = linalg.slice %1 ... : memref<?x?xf32> ...
-///   // or
-///   %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
-///          to memref<?x?xf32>
-///   linalg.generic(%1 ...) : memref<?x?xf32> ...
-/// ```
-///
-/// into
-///
-/// ```mlir
-///   %2 = linalg.slice %0 ... : memref<8x16xf32> ...
-///   // or
-///   linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
-/// ```
-///
-static bool canFold(MemRefCastOp castOp) {
-  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
-  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
-
-  // If we don't have MemRefType as source and destination, bail out.
-  if (!sourceType || !resultType)
-    return false;
-
-  // If resultType has a map, it needs to be the same as the source type to
-  // canonicalize.
-  if (!resultType.getAffineMaps().empty() &&
-      sourceType.getAffineMaps() != resultType.getAffineMaps())
-    return false;
-
-  // Ensure that:
-  //   1. source is static
-  //   2. source and target have the same rank (will be extended when needed)
-  //   3. if result is partially static, ensure sizes match.
-  if (!sourceType.hasStaticShape() ||
-      sourceType.getRank() != resultType.getRank())
-    return false;
-
-  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
-    auto sourceSize = std::get<0>(it);
-    auto resultSize = std::get<1>(it);
-    if (ShapedType::isDynamic(resultSize))
-      continue;
-    if (sourceSize != resultSize)
-      return false;
-  }
-
-  // If source has a map, it can only canonicalize if it is the canonical
-  // strided layout map.
-  if (sourceType.getAffineMaps().empty())
-    return true;
-
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  auto res = getStridesAndOffset(sourceType, strides, offset);
-  (void)res;
-  assert(succeeded(res));
-  auto stridedMap =
-      makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
-  AffineMap sourceMap = sourceType.getAffineMaps().front();
-  return sourceMap == stridedMap;
-}
-
 /// This is a common class used for patterns of the form
 /// ```
 ///    someop(memrefcast) -> someop
 /// ```
-/// It folds the source of any memref_cast into the root operation directly.
+/// It folds the source of the memref_cast into the root operation directly.
 static LogicalResult foldMemRefCast(Operation *op) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
     auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
-    if (castOp && canFold(castOp)) {
+    if (castOp && canFoldIntoConsumerOp(castOp)) {
       operand.set(castOp.getOperand());
       folded = true;
     }

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 972a37d20f97..269dd083542c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2519,6 +2519,111 @@ class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
 
 } // end anonymous namespace
 
+/// Determines whether MemRefCastOp casts to a more dynamic version of the
+/// source memref. This is useful to fold a memref_cast into a consuming op
+/// and implement canonicalization patterns for ops in different dialects that
+/// may consume the results of memref_cast operations. Such foldable memref_cast
+/// operations are typically inserted as `view` and `subview` ops are
+/// canonicalized, to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked memrefs with strided semantics and same
+/// element type and rank.
+/// 2. each of the source's sizes, offset, and strides has at least as much
+/// static information as the corresponding result size, offset, or stride.
+///
+/// Example 1:
+/// ```mlir
+///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
+///   %2 = consumer %1 ... : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : memref<8x16xf32> ...
+/// ```
+///
+/// Example 2:
+/// ```
+///   %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
+///          to memref<?x?xf32>
+///   consumer %1 : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```
+///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// ```
+bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
+  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
+  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
+
+  // Requires ranked MemRefType.
+  if (!sourceType || !resultType)
+    return false;
+
+  // Requires same elemental type.
+  if (sourceType.getElementType() != resultType.getElementType())
+    return false;
+
+  // Requires same rank.
+  if (sourceType.getRank() != resultType.getRank())
+    return false;
+
+  // Only fold casts between strided memref forms.
+  int64_t sourceOffset, resultOffset;
+  SmallVector<int64_t, 4> sourceStrides, resultStrides;
+  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
+      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+    return false;
+
+  // If cast is towards more static sizes along any dimension, don't fold.
+  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+    auto ss = std::get<0>(it), st = std::get<1>(it);
+    if (ss != st)
+      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
+        return false;
+  }
+
+  // If cast is towards more static offset along any dimension, don't fold.
+  if (sourceOffset != resultOffset)
+    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
+        !MemRefType::isDynamicStrideOrOffset(resultOffset))
+      return false;
+
+  // If cast is towards more static strides along any dimension, don't fold.
+  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
+    auto ss = std::get<0>(it), st = std::get<1>(it);
+    if (ss != st)
+      if (MemRefType::isDynamicStrideOrOffset(ss) &&
+          !MemRefType::isDynamicStrideOrOffset(st))
+        return false;
+  }
+
+  return true;
+}
+
+OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
+  auto folds = [](Operation *op) {
+    bool folded = false;
+    for (OpOperand &operand : op->getOpOperands()) {
+      auto castOp =
+          dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+      if (castOp && canFoldIntoConsumerOp(castOp)) {
+        operand.set(castOp.getOperand());
+        folded = true;
+      }
+    }
+    return folded ? success() : failure();
+  };
+
+  if (succeeded(folds(*this)))
+    return getResult();
+  return {};
+}
+
 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
   results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder,

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 1cff314d731a..781ce83d95c1 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -919,3 +919,15 @@ func @tensor_divi_unsigned_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
   // CHECK: return %[[ARG]]
   return %res : tensor<4x5xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_folding_subview
+func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
+  %0 = memref_cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
+  // CHECK-NEXT: subview %{{.*}}: memref<4x5xf32>
+  %1 = subview %0[][%i,%i][]: memref<?x?xf32> to memref<?x?xf32, offset:? , strides: [?, ?]>
+  // CHECK-NEXT: return %{{.*}}
+  return %1: memref<?x?xf32, offset:? , strides: [?, ?]>
+}
+


        

