[Mlir-commits] [mlir] 6d45284 - [mlir][memref] Add better support for identity layouts in memref.collapse_shape canonicalizer
Stephan Herhut
llvmlistbot at llvm.org
Thu Jan 20 06:32:06 PST 2022
Author: Stephan Herhut
Date: 2022-01-20T15:31:43+01:00
New Revision: 6d45284618f08fa28dc515cab96fa573c4c4479e
URL: https://github.com/llvm/llvm-project/commit/6d45284618f08fa28dc515cab96fa573c4c4479e
DIFF: https://github.com/llvm/llvm-project/commit/6d45284618f08fa28dc515cab96fa573c4c4479e.diff
LOG: [mlir][memref] Add better support for identity layouts in memref.collapse_shape canonicalizer
When computing the new type of a collapse_shape operation, we need to at least
take into account whether the type has an identity layout, in which case we can
easily support dynamic strides. Otherwise, the canonicalizer creates invalid
IR.
Longer term, both the verifier and the canonicalizer need to be extended to
support the general case.
Differential Revision: https://reviews.llvm.org/D117772
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 55499f63295f0..211af3045b9d8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1334,6 +1334,7 @@ computeReshapeCollapsedType(MemRefType type,
AffineExpr offset;
SmallVector<AffineExpr, 4> strides;
auto status = getStridesAndOffset(type, strides, offset);
+ auto isIdentityLayout = type.getLayout().isIdentity();
(void)status;
assert(succeeded(status) && "expected strided memref");
@@ -1350,12 +1351,19 @@ computeReshapeCollapsedType(MemRefType type,
unsigned dim = m.getNumResults();
int64_t size = 1;
AffineExpr stride = strides[currentDim + dim - 1];
- if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
+ if (isIdentityLayout ||
+ isReshapableDimBand(currentDim, dim, sizes, strides)) {
+ for (unsigned d = 0; d < dim; ++d) {
+ int64_t currentSize = sizes[currentDim + d];
+ if (ShapedType::isDynamic(currentSize)) {
+ size = ShapedType::kDynamicSize;
+ break;
+ }
+ size *= currentSize;
+ }
+ } else {
size = ShapedType::kDynamicSize;
stride = AffineExpr();
- } else {
- for (unsigned d = 0; d < dim; ++d)
- size *= sizes[currentDim + d];
}
newSizes.push_back(size);
newStrides.push_back(stride);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index bd7a8dd830a80..58083437ca47e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -406,6 +406,8 @@ func @collapse_after_memref_cast_type_change(%arg0 : memref<?x512x1x1xf32>) -> m
return %collapsed : memref<?x?xf32>
}
+// -----
+
// CHECK-LABEL: func @collapse_after_memref_cast(
// CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x?xf32>) -> memref<?x?xf32> {
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
@@ -419,6 +421,21 @@ func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf3
// -----
+// CHECK-LABEL: func @collapse_after_memref_cast_type_change_dynamic(
+// CHECK-SAME: %[[INPUT:.*]]: memref<1x1x1x?xi64>) -> memref<?x?xi64> {
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
+// CHECK-SAME: {{\[\[}}0, 1, 2], [3]] : memref<1x1x1x?xi64> into memref<1x?xi64>
+// CHECK: %[[DYNAMIC:.*]] = memref.cast %[[COLLAPSED]] :
+// CHECK-SAME: memref<1x?xi64> to memref<?x?xi64>
+// CHECK: return %[[DYNAMIC]] : memref<?x?xi64>
+func @collapse_after_memref_cast_type_change_dynamic(%arg0: memref<1x1x1x?xi64>) -> memref<?x?xi64> {
+ %casted = memref.cast %arg0 : memref<1x1x1x?xi64> to memref<1x1x?x?xi64>
+ %collapsed = memref.collapse_shape %casted [[0, 1, 2], [3]] : memref<1x1x?x?xi64> into memref<?x?xi64>
+ return %collapsed : memref<?x?xi64>
+}
+
+// -----
+
func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index)
-> memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> {
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list