[Mlir-commits] [mlir] 94438c8 - [mlir] Add a MemRefCastOp canonicalization pattern.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed May 6 06:11:54 PDT 2020
Author: Nicolas Vasilache
Date: 2020-05-06T09:10:05-04:00
New Revision: 94438c86adef0a0f50bc3737253e8c98b4e3dd3e
URL: https://github.com/llvm/llvm-project/commit/94438c86adef0a0f50bc3737253e8c98b4e3dd3e
DIFF: https://github.com/llvm/llvm-project/commit/94438c86adef0a0f50bc3737253e8c98b4e3dd3e.diff
LOG: [mlir] Add a MemRefCastOp canonicalization pattern.
Summary:
This revision adds a conservative canonicalization pattern for the MemRefCastOps that are typically inserted during ViewOp and SubViewOp canonicalization.
Ideally such canonicalizations would propagate the new type to consumers, but that is not a local transformation. As a consequence, MemRefCastOps are introduced to keep types compatible and need to be cleaned up later when they introduce more dynamic behavior than necessary.
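For illustration, a minimal sketch of the situation, adapted from the examples in the doc comment below (`consumer` is a placeholder for any op that can use the folded operand, not a real op):
```mlir
// A canonicalized view/subview producer may yield a more static type, so a
// memref_cast is inserted to keep the consumer's operand type unchanged:
%1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
%2 = consumer %1 : memref<?x?xf32>

// Since the cast only erases static information, it can be folded into the
// consumer:
%2 = consumer %0 : memref<8x16xf32>
```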
Differential Revision: https://reviews.llvm.org/D79438
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 573f9b7c988f..218aff6d11c9 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -301,6 +301,44 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
+/// Determines whether MemRefCastOp casts to a more dynamic version of the
+/// source memref. This is useful to fold a memref_cast into a consuming op
+/// and implement canonicalization patterns for ops in different dialects that
+/// may consume the results of memref_cast operations. Such foldable memref_cast
+/// operations are typically inserted as `view` and `subview` ops are
+/// canonicalized, to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked memrefs with strided semantics and same
+/// element type and rank.
+/// 2. each of the source's size, offset or stride has more static information
+/// than the corresponding result's size, offset or stride.
+///
+/// Example 1:
+/// ```mlir
+/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
+/// %2 = consumer %1 ... : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```mlir
+/// %2 = consumer %0 ... : memref<8x16xf32> ...
+/// ```
+///
+/// Example 2:
+/// ```
+/// %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// to memref<?x?xf32>
+/// consumer %1 : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```
+/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// ```
+bool canFoldIntoConsumerOp(MemRefCastOp castOp);
} // end namespace mlir
#endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index efcbdf63983e..4e04df9b9215 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2606,6 +2606,7 @@ def SubViewOp : Std_Op<"subview", [
}];
let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 82ae6de83c83..5803824a3162 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -44,82 +44,16 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
template <typename NamedStructuredOpType>
static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
-/// Determines whether it is possible to fold it away in the parent Linalg op:
-///
-/// ```mlir
-/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
-/// %2 = linalg.slice %1 ... : memref<?x?xf32> ...
-/// // or
-/// %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
-/// to memref<?x?xf32>
-/// linalg.generic(%1 ...) : memref<?x?xf32> ...
-/// ```
-///
-/// into
-///
-/// ```mlir
-/// %2 = linalg.slice %0 ... : memref<8x16xf32> ...
-/// // or
-/// linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
-/// ```
-///
-static bool canFold(MemRefCastOp castOp) {
- MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
- MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
-
- // If we don't have MemRefType as source and destination, bail out.
- if (!sourceType || !resultType)
- return false;
-
- // If resultType has a map, it needs to be the same as the source type to
- // canonicalize.
- if (!resultType.getAffineMaps().empty() &&
- sourceType.getAffineMaps() != resultType.getAffineMaps())
- return false;
-
- // Ensure that:
- // 1. source is static
- // 2. source and target have the same rank (will be extended when needed)
- // 3. if result is partially static, ensure sizes match.
- if (!sourceType.hasStaticShape() ||
- sourceType.getRank() != resultType.getRank())
- return false;
-
- for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
- auto sourceSize = std::get<0>(it);
- auto resultSize = std::get<1>(it);
- if (ShapedType::isDynamic(resultSize))
- continue;
- if (sourceSize != resultSize)
- return false;
- }
-
- // If source has a map, it can only canonicalize if it is the canonical
- // strided layout map.
- if (sourceType.getAffineMaps().empty())
- return true;
-
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- auto res = getStridesAndOffset(sourceType, strides, offset);
- (void)res;
- assert(succeeded(res));
- auto stridedMap =
- makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
- AffineMap sourceMap = sourceType.getAffineMaps().front();
- return sourceMap == stridedMap;
-}
-
/// This is a common class used for patterns of the form
/// ```
/// someop(memrefcast) -> someop
/// ```
-/// It folds the source of any memref_cast into the root operation directly.
+/// It folds the source of the memref_cast into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
- if (castOp && canFold(castOp)) {
+ if (castOp && canFoldIntoConsumerOp(castOp)) {
operand.set(castOp.getOperand());
folded = true;
}
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 972a37d20f97..269dd083542c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2519,6 +2519,111 @@ class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
} // end anonymous namespace
+/// Determines whether MemRefCastOp casts to a more dynamic version of the
+/// source memref. This is useful to fold a memref_cast into a consuming op
+/// and implement canonicalization patterns for ops in different dialects that
+/// may consume the results of memref_cast operations. Such foldable memref_cast
+/// operations are typically inserted as `view` and `subview` ops are
+/// canonicalized, to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked memrefs with strided semantics and same
+/// element type and rank.
+/// 2. each of the source's size, offset or stride has more static information
+/// than the corresponding result's size, offset or stride.
+///
+/// Example 1:
+/// ```mlir
+/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
+/// %2 = consumer %1 ... : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```mlir
+/// %2 = consumer %0 ... : memref<8x16xf32> ...
+/// ```
+///
+/// Example 2:
+/// ```
+/// %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// to memref<?x?xf32>
+/// consumer %1 : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```
+/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// ```
+bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
+ MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
+ MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
+
+ // Requires ranked MemRefType.
+ if (!sourceType || !resultType)
+ return false;
+
+ // Requires same elemental type.
+ if (sourceType.getElementType() != resultType.getElementType())
+ return false;
+
+ // Requires same rank.
+ if (sourceType.getRank() != resultType.getRank())
+ return false;
+
+ // Only fold casts between strided memref forms.
+ int64_t sourceOffset, resultOffset;
+ SmallVector<int64_t, 4> sourceStrides, resultStrides;
+ if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
+ failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+ return false;
+
+ // If cast is towards more static sizes along any dimension, don't fold.
+ for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+ auto ss = std::get<0>(it), st = std::get<1>(it);
+ if (ss != st)
+ if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
+ return false;
+ }
+
+ // If cast is towards more static offset along any dimension, don't fold.
+ if (sourceOffset != resultOffset)
+ if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
+ !MemRefType::isDynamicStrideOrOffset(resultOffset))
+ return false;
+
+ // If cast is towards more static strides along any dimension, don't fold.
+ for (auto it : llvm::zip(sourceStrides, resultStrides)) {
+ auto ss = std::get<0>(it), st = std::get<1>(it);
+ if (ss != st)
+ if (MemRefType::isDynamicStrideOrOffset(ss) &&
+ !MemRefType::isDynamicStrideOrOffset(st))
+ return false;
+ }
+
+ return true;
+}
+
+OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
+ auto folds = [](Operation *op) {
+ bool folded = false;
+ for (OpOperand &operand : op->getOpOperands()) {
+ auto castOp =
+ dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+ if (castOp && canFoldIntoConsumerOp(castOp)) {
+ operand.set(castOp.getOperand());
+ folded = true;
+ }
+ }
+ return folded ? success() : failure();
+ };
+
+ if (succeeded(folds(*this)))
+ return getResult();
+ return {};
+}
+
void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder,
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 1cff314d731a..781ce83d95c1 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -919,3 +919,15 @@ func @tensor_divi_unsigned_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
// CHECK: return %[[ARG]]
return %res : tensor<4x5xi32>
}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_folding_subview
+func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
+ %0 = memref_cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
+ // CHECK-NEXT: subview %{{.*}}: memref<4x5xf32>
+ %1 = subview %0[][%i,%i][]: memref<?x?xf32> to memref<?x?xf32, offset:? , strides: [?, ?]>
+ // CHECK-NEXT: return %{{.*}}
+ return %1: memref<?x?xf32, offset:? , strides: [?, ?]>
+}
+